#!/usr/bin/env python3
"""
One-shot cleanup of items whose `sourceRecordingId` is broken.

Two failure modes are repaired by nulling out the `sourceRecordingId`:

  1. The item still references a per-device `local:N` Room PK. Those keys
     never resolve cross-device and stop resolving even on the source device
     after a delete-and-resync renumber.

  2. The item references a real firestoreId but the corresponding recording
     doesn't exist OR has a corrupted timestamp (`epoch == 0`, i.e. parsed
     to Dec 31 1969). These items got "repaired" by the now-deleted
     `repairLocalSourceIds` to point at whatever ghost recording happened
     to land at a given Pixel-local PK during a delete-and-resync.

Setting `sourceRecordingId = None` is reversible-ish: the UI shows
"no source" instead of a wrong link, and the item itself is intact.
The original source recording id is already gone by this point (the
stored value no longer identifies it), so it's safer to leave items
orphaned than to keep pointing them at the wrong recording.

Usage
-----
    python3 repair_item_source_links.py \
        --email ericmigi@gmail.com \
        --service-account ~/Downloads/coreapp-ce061-firebase-adminsdk-fbsvc-0159a91677.json

Read-only by default. Add `--apply` to perform writes (with a confirm prompt).
"""

from __future__ import annotations

import argparse
import sys
from datetime import datetime, timezone
from pathlib import Path

import firebase_admin
from firebase_admin import auth as fb_auth
from firebase_admin import credentials, firestore


def parse_args() -> argparse.Namespace:
    """Build this script's command-line parser and parse the process arguments.

    Options:
      --email            string, defaults to "ericmigi@gmail.com"
      --uid              string, defaults to None
      --service-account  string, defaults to None
      --apply            boolean flag, False unless given

    The module docstring is used verbatim as the help description.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.RawDescriptionHelpFormatter,
        description=__doc__,
    )

    # Plain string-valued options, declared as (flag, default) pairs.
    string_options = (
        ("--email", "ericmigi@gmail.com"),
        ("--uid", None),
        ("--service-account", None),
    )
    for flag, fallback in string_options:
        parser.add_argument(flag, default=fallback)

    parser.add_argument(
        "--apply",
        help="Actually write nulls to Firestore (read-only by default).",
        action="store_true",
    )

    parsed = parser.parse_args()
    return parsed


def resolve_service_account(arg: str | None) -> Path:
    if arg:
        return Path(arg).expanduser()
    downloads = Path("~/Downloads").expanduser()
    candidates = sorted(
        downloads.glob("coreapp-ce061-firebase-adminsdk-*.json"),
        key=lambda p: p.stat().st_mtime,
        reverse=True,
    )
    if not candidates:
        sys.exit(
            "No --service-account provided and no coreapp-ce061-*.json found in ~/Downloads. "
            "Pass --service-account explicitly."
        )
    return candidates[0]


def resolve_uid(email: str, uid_arg: str | None) -> str:
    if uid_arg:
        return uid_arg
    return fb_auth.get_user_by_email(email).uid


def is_epoch_zero_timestamp(value) -> bool:
    """Report whether a recording timestamp should be treated as epoch 0
    (the signature of the corrupted ghost recordings).

    - `None` counts as epoch 0.
    - A `datetime` counts when its `.timestamp()` equals 0.
    - A dict counts when the first truthy value among its `epochSeconds` /
      `seconds` keys equals 0, or when neither key holds a truthy value
      (missing, `None`, or 0).
    - Any other type is not considered epoch 0.
    """
    if isinstance(value, datetime):
        return value.timestamp() == 0

    if isinstance(value, dict):
        for key in ("epochSeconds", "seconds"):
            candidate = value.get(key)
            if candidate:
                return candidate == 0
        # Neither key carried a usable (truthy) value.
        return True

    return value is None


def epoch_seconds(value) -> int | None:
    """Parse either a Firestore native Timestamp or a kotlinx-serialization
    `{epochSeconds, nanosecondsOfSecond}` map into epoch seconds."""
    if value is None:
        return None
    if isinstance(value, datetime):
        return int(value.timestamp())
    if isinstance(value, dict):
        secs = value.get("epochSeconds") or value.get("seconds")
        if secs is not None:
            return int(secs)
    return None


def looks_misrouted(item_data: dict, rec_data: dict) -> bool:
    """Guess whether an item is linked to the wrong recording.

    Compares the item's `createdAt` with the recording's `timestamp` (both
    read through `epoch_seconds`). A gap of more than 24 hours (86,400 s) is
    taken as evidence that the link came from the bad `repairLocalSourceIds`
    pass, where a Pixel-local PK collision pointed an item at an unrelated
    recording from years earlier.

    The check is deliberately conservative: items genuinely created from a
    recording have a `createdAt` within seconds of the recording's timestamp,
    and if either side cannot be parsed the link is NOT flagged.
    """
    created = epoch_seconds(item_data.get("createdAt"))
    recorded = epoch_seconds(rec_data.get("timestamp"))

    if created is not None and recorded is not None:
        return abs(created - recorded) > 86_400
    # At least one timestamp is unreadable: don't guess.
    return False


def main() -> int:
    """Find items with a broken `sourceRecordingId` and optionally clear it.

    Flow:
      1. Parse arguments, initialise the Firebase app and resolve the uid.
      2. Load every item under `items/{uid}/items` and fetch each distinct
         non-`local:` recording it references from
         `recordings/{uid}/recordings`.
      3. Flag an item when its source is a `local:` key, the recording is
         missing, the recording's timestamp is epoch 0, or the item/recording
         timestamps diverge by more than 24h.
      4. Print a per-reason summary (first five rows each).
      5. Without `--apply`, stop there. With `--apply`, ask for a typed
         confirmation, then delete `sourceRecordingId` and bump `updatedAt`
         on every flagged item.

    Returns:
        The process exit status: 0 when there was nothing to do, in read-only
        mode, or when every write succeeded; 1 when the confirmation was not
        given or when at least one write failed.
    """
    args = parse_args()

    cred_path = resolve_service_account(args.service_account)
    print(f"Using service account: {cred_path}")
    firebase_admin.initialize_app(credentials.Certificate(str(cred_path)))

    uid = resolve_uid(args.email, args.uid)
    print(f"User: {args.email}  uid={uid}")

    db = firestore.client()
    items_ref = db.collection(f"items/{uid}/items")
    recordings_ref = db.collection(f"recordings/{uid}/recordings")

    print("Loading items...")
    # Convert each snapshot to a dict exactly once; both passes below reuse it
    # instead of calling to_dict() repeatedly per item.
    items = [(snap.id, snap.to_dict()) for snap in items_ref.stream()]
    print(f"  {len(items)} items in Firestore")

    # Collect the distinct recordings referenced by these items so each one
    # is fetched once and then resolved from the local cache.
    referenced_ids = set()
    for _, data in items:
        src = data.get("sourceRecordingId")
        if src and not str(src).startswith("local:"):
            referenced_ids.add(src)
    print(f"  {len(referenced_ids)} distinct recording references to resolve")
    rec_cache: dict[str, dict | None] = {}
    for rid in referenced_ids:
        snap = recordings_ref.document(rid).get()
        rec_cache[rid] = snap.to_dict() if snap.exists else None

    to_clear: list[tuple[str, str, str]] = []  # (item_id, reason, current_src)
    for item_id, data in items:
        src = data.get("sourceRecordingId")
        if src is None or src == "":
            continue
        if str(src).startswith("local:"):
            to_clear.append((item_id, "local:N (per-device key)", src))
            continue
        rec = rec_cache.get(src)
        if rec is None:
            to_clear.append((item_id, "recording missing in Firestore", src))
            continue
        if is_epoch_zero_timestamp(rec.get("timestamp")):
            to_clear.append((item_id, "recording has epoch-0 timestamp", src))
            continue
        if looks_misrouted(data, rec):
            to_clear.append((item_id, "createdAt vs recording.timestamp diverges (>24h)", src))
            continue

    print()
    print(f"Items to repair: {len(to_clear)}")
    if not to_clear:
        print("Nothing to do.")
        return 0

    # Group by reason for readability.
    by_reason: dict[str, list] = {}
    for item_id, reason, src in to_clear:
        by_reason.setdefault(reason, []).append((item_id, src))
    for reason, rows in by_reason.items():
        print(f"  {reason}: {len(rows)}")
        for item_id, src in rows[:5]:
            print(f"    {item_id}  src={src}")
        if len(rows) > 5:
            print(f"    ... and {len(rows) - 5} more")
    print()

    if not args.apply:
        print("Read-only mode. Re-run with --apply to clear the broken sourceRecordingIds.")
        return 0

    try:
        confirm = input(f"Type 'CLEAR {len(to_clear)}' to proceed: ")
    except EOFError:
        # No interactive stdin (closed or piped from nothing): nothing was
        # confirmed, so take the abort path instead of crashing.
        confirm = ""
    if confirm != f"CLEAR {len(to_clear)}":
        print("Aborted.")
        return 1

    # Bump updatedAt at the same time so each device's pull listener
    # actually applies the change to Room. Without the bump, remote and
    # local updatedAt stay equal and the pull listener's
    # `remote.updatedAt > local.updatedAt` check skips the row — so
    # Room keeps the stale `sourceRecordingId` and on the next push
    # observer cycle would round-trip it right back to Firestore,
    # undoing this cleanup.
    now = datetime.now(timezone.utc)
    new_updated_at = {
        "epochSeconds": int(now.timestamp()),
        # Taken from the integer microsecond field so the sub-second part is
        # exact rather than subject to float rounding of `timestamp() % 1`.
        "nanosecondsOfSecond": now.microsecond * 1_000,
    }

    cleared = 0
    failed = 0
    for item_id, _, _ in to_clear:
        try:
            items_ref.document(item_id).update({
                "sourceRecordingId": firestore.DELETE_FIELD,
                "updatedAt": new_updated_at,
            })
            cleared += 1
        except Exception as e:
            # Best-effort: report the failure and keep going with the rest.
            print(f"  failed: {item_id}: {e}")
            failed += 1
    print(f"Cleared {cleared}; failed {failed}.")
    # Report partial failure through the exit status so callers and shell
    # pipelines don't mistake an incomplete cleanup for success.
    return 1 if failed else 0


# Script entry point: run main() and hand its return value to the interpreter
# as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
