#!/usr/bin/env python3
"""
Find (and optionally delete) duplicate `recordings` Firestore docs for a
single user, identified by email.

Path layout in Firestore:
    recordings/{uid}/recordings/{recordingId}

A "duplicate" here means: two or more docs whose `timestamp` field is
identical down to the millisecond (the default grouping; pass `--by` to
group on a different field instead). The local Room mirror dedups by
timestamp on download (`existingTimestamps[tsKey]` in
`SettingsViewModel.performFeedHistoryDownload`), so cloud-side dupes
just bloat the dataset and slow the paginated sync without ever
producing duplicate local rows.

Usage
-----
    python3 find_duplicate_recordings.py \\
        --email ericmigi@gmail.com \\
        --service-account ~/Downloads/coreapp-ce061-firebase-adminsdk-fbsvc-0159a91677.json

By default the script is read-only — it lists groups of duplicate docs
and prints a summary. Add `--delete` to actually remove the dupes (the
*newest* doc in each group is kept, older copies removed). Use `--keep
oldest` to invert that. There's a confirmation prompt.

You can also pass `--uid <UID>` directly if you already know it (skips
the Firebase Auth lookup).
"""

from __future__ import annotations

import argparse
import sys
from collections import defaultdict
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable

import firebase_admin
from firebase_admin import auth as fb_auth
from firebase_admin import credentials, firestore


def parse_args() -> argparse.Namespace:
    """Parse the command-line options from ``sys.argv``.

    The module docstring is reused verbatim as the ``--help`` description
    (hence the raw-description formatter, which keeps its line breaks).

    Returns:
        A namespace with ``email``, ``uid``, ``service_account``, ``delete``,
        ``keep``, ``page_size``, ``show_all_groups``, ``by`` and ``analyze``.

    Raises:
        SystemExit: raised by argparse (exit status 2) on invalid options,
            including a ``--page-size`` that is not a positive integer.
    """

    def positive_int(text: str) -> int:
        # A page size of 0 or less is passed straight to Query.limit(), which
        # either fails server-side or makes the scan look empty — reject it
        # up front with a normal argparse usage error instead.
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}") from None
        if value < 1:
            raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
        return value

    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--email", default="ericmigi@gmail.com",
                   help="User email; resolved to a UID via Firebase Auth.")
    p.add_argument("--uid", default=None,
                   help="Skip the email→UID lookup and use this UID directly.")
    p.add_argument("--service-account", default=None,
                   help="Path to a Firebase Admin SDK service-account JSON key. "
                        "Defaults to the newest coreapp-ce061-firebase-adminsdk-*.json in ~/Downloads.")
    p.add_argument("--delete", action="store_true",
                   help="Delete duplicate docs (read-only by default).")
    p.add_argument("--keep", choices=("newest", "oldest"), default="newest",
                   help="When --delete, which doc to keep per dup group (default: newest by `updated`).")
    p.add_argument("--page-size", type=positive_int, default=500,
                   help="Firestore page size for the recordings collection scan.")
    p.add_argument("--show-all-groups", action="store_true",
                   help="Print every dup group (default: print the first 30).")
    p.add_argument("--by", choices=("timestamp", "filename", "transcription", "transferIndex"),
                   default="timestamp",
                   help="Field to group on for dedup. `filename` = entries[0].fileName "
                        "(strongest, unique-per-real-upload). `transcription` = exact content match. "
                        "`transferIndex` = entries[0].ringTransferInfo.advertisementReceived. "
                        "`timestamp` = recording start time (default, may collide on legitimate close-spaced recordings).")
    p.add_argument("--analyze", action="store_true",
                   help="Skip dedup; print field-distribution stats so you can see what's actually in cloud.")
    return p.parse_args()


def find_default_service_account() -> Path | None:
    """Pick the most recently modified Admin SDK key file in ``~/Downloads``.

    Only files matching ``coreapp-ce061-firebase-adminsdk-*.json`` are
    considered. Returns ``None`` when nothing matches.
    """
    key_files = (Path.home() / "Downloads").glob("coreapp-ce061-firebase-adminsdk-*.json")
    # max() yields the first file with the highest mtime, which is the same
    # file a stable newest-first sort would put at index 0.
    return max(key_files, key=lambda key_file: key_file.stat().st_mtime, default=None)


def init_firebase(service_account_path: Path) -> firestore.firestore.Client:
    """Initialise the Firebase Admin app from a key file and return a Firestore client.

    `service_account_path` is the service-account JSON key; it is handed to
    `credentials.Certificate` as a plain string path.
    """
    certificate = credentials.Certificate(str(service_account_path))
    firebase_admin.initialize_app(certificate)
    client = firestore.client()
    return client


def resolve_uid(email: str) -> str:
    """Look the user up in Firebase Auth by email and return their UID."""
    return fb_auth.get_user_by_email(email).uid


def fetch_all_recordings(db: firestore.firestore.Client, uid: str, page_size: int) -> list[firestore.firestore.DocumentSnapshot]:
    """Load every doc under `recordings/{uid}/recordings` into one list.

    The collection is read in `timestamp` order, `page_size` docs per query,
    using the last snapshot of each page as the cursor for the next one. A
    running count is written to stderr after every page. The whole result is
    held in memory, which is fine for a few thousand docs; a much larger
    collection would need to be processed page by page instead.

    NOTE(review): Firestore documents that an order_by clause only returns
    docs containing the ordered field, so recordings without a `timestamp`
    field are presumably never returned by this scan — confirm that is
    intended.
    """
    recordings = db.collection("recordings").document(uid).collection("recordings")
    ordered = recordings.order_by("timestamp").limit(page_size)
    collected: list[firestore.firestore.DocumentSnapshot] = []
    cursor = None
    while True:
        query = ordered if cursor is None else ordered.start_after(cursor)
        chunk = list(query.stream())
        if not chunk:
            break
        collected += chunk
        cursor = chunk[-1]
        print(f"  fetched {len(collected)} so far...", file=sys.stderr)
        # A short page means the collection is exhausted; skip the extra
        # round trip that would only return an empty page.
        if len(chunk) < page_size:
            break
    return collected


def fmt_ts(ms: int | None) -> str:
    """Format epoch milliseconds as a UTC string such as ``2023-11-14 22:13:20.123Z``.

    Returns ``"(no timestamp)"`` for ``None`` and ``"(invalid: <value>)"`` for
    anything that cannot be turned into a date (wrong type, out-of-range or
    absurdly large values).
    """
    if ms is None:
        return "(no timestamp)"
    try:
        # %f renders six microsecond digits; dropping the last three leaves
        # millisecond precision.
        return datetime.fromtimestamp(ms / 1000, tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] + "Z"
    except (TypeError, ValueError, OverflowError, OSError):
        # OverflowError covers both `ms / 1000` on a huge int and
        # fromtimestamp() being handed a value outside the platform's
        # time_t range.
        return f"(invalid: {ms})"


def get_ts_ms(doc) -> int | None:
    """Return the doc's `timestamp` as epoch milliseconds, or None.

    Recordings store `timestamp` as a kotlinx.datetime.Instant
    serialised via `InstantComponentSerializer`, which on the wire is
    a struct `{epochSeconds, nanosecondsOfSecond}`. Older records or
    other code paths may use a flat long (millis) or a Firestore
    Timestamp / datetime — all three are handled.

    None is returned when the field is absent, null, or cannot be
    converted to an integer.
    """
    try:
        raw = doc.get("timestamp")
    except KeyError:
        # DocumentSnapshot.get() raises KeyError for a field the document
        # does not have, rather than returning None.
        return None
    if raw is None:
        return None
    # Struct form (kotlinx Instant default)
    if isinstance(raw, dict):
        sec = raw.get("epochSeconds")
        if sec is None:
            return None
        nanos = raw.get("nanosecondsOfSecond") or 0
        try:
            return int(sec) * 1000 + int(nanos) // 1_000_000
        except (TypeError, ValueError, OverflowError):
            return None
    # Firestore Timestamp / datetime
    if hasattr(raw, "timestamp"):
        return int(raw.timestamp() * 1000)
    # Flat long (epoch ms)
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int() of an infinite float.
        return None


def get_updated_ms(doc) -> int:
    """Return the doc's `updated` field as epoch milliseconds.

    Used for the newest/oldest tie-break. Falls back to 0 when the field is
    missing, null, or cannot be converted.

    Accepts a Firestore Timestamp / datetime, a flat epoch-millis number,
    and also the `{epochSeconds, nanosecondsOfSecond}` struct form in case
    `updated` is serialised the same way as `timestamp`.
    """
    try:
        raw = doc.get("updated")
    except KeyError:
        # DocumentSnapshot.get() raises KeyError for a field the document
        # does not have; treat that the same as a null value.
        return 0
    if raw is None:
        return 0
    if isinstance(raw, dict):
        sec = raw.get("epochSeconds")
        if sec is None:
            return 0
        nanos = raw.get("nanosecondsOfSecond") or 0
        try:
            return int(sec) * 1000 + int(nanos) // 1_000_000
        except (TypeError, ValueError, OverflowError):
            return 0
    if hasattr(raw, "timestamp"):
        return int(raw.timestamp() * 1000)
    try:
        return int(raw)
    except (TypeError, ValueError, OverflowError):
        return 0


def first_entry(snap) -> dict | None:
    """Return `entries[0]` from the snapshot's data, or None when there is none.

    A snapshot with no data, no `entries` key, or an empty/null `entries`
    value all yield None.
    """
    payload = snap.to_dict()
    if not payload:
        return None
    items = payload.get("entries")
    if not items:
        return None
    return items[0]


def get_filename(snap) -> str | None:
    """Return `entries[0].fileName`, or None when the entry or field is absent."""
    entry = first_entry(snap)
    if not entry:
        return None
    return entry.get("fileName")


def get_transcription(snap) -> str | None:
    """Return `entries[0].transcription`; empty or missing values become None."""
    entry = first_entry(snap)
    text = entry.get("transcription") if entry else None
    return text or None


def get_transfer_advert(snap) -> int | None:
    """Return `entries[0].ringTransferInfo.advertisementReceived` as an int.

    None is returned when the entry, the transfer info, or the value is
    missing, or when the value cannot be converted to an int.
    """
    entry = first_entry(snap)
    transfer = entry.get("ringTransferInfo") if entry else None
    if not transfer:
        return None
    value = transfer.get("advertisementReceived")
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def analyze(docs) -> None:
    """Print how many recordings carry each field of interest.

    For every snapshot this counts whether a timestamp can be extracted,
    whether there is at least one entry and what `entries[0]` contains
    (fileName, transcription, ringTransferInfo, status), plus the top-level
    `assistant_session` and `encrypted` fields. The totals are printed to
    stdout with percentages, followed by a status breakdown ordered by
    descending count, so the user can pick a sensible dedup key.
    """
    total = len(docs)
    counts: dict[str, int] = defaultdict(int)
    statuses: dict[str, int] = defaultdict(int)

    for snap in docs:
        counts["ts"] += get_ts_ms(snap) is not None
        data = snap.to_dict() or {}
        entries = data.get("entries") or []
        if entries:
            head = entries[0]
            counts["entries"] += 1
            counts["filename"] += bool(head.get("fileName"))
            counts["transcription"] += bool(head.get("transcription"))
            counts["transfer"] += bool(head.get("ringTransferInfo"))
            status = head.get("status")
            if status:
                status_name = str(status)
                statuses[status_name] += 1
                counts["completed"] += status_name == "completed"
        counts["session"] += bool(data.get("assistant_session"))
        counts["encrypted"] += bool(data.get("encrypted"))

    def pct(n: int) -> str:
        if not total:
            return "0"
        return f"{n} ({n * 100 / total:.1f}%)"

    # Labels carry their own padding so the colons line up as before.
    rows = [
        ("  has timestamp        : ", counts["ts"]),
        ("  has entries (>=1)    : ", counts["entries"]),
        ("  has entries[0].fileName     : ", counts["filename"]),
        ("  has entries[0].transcription: ", counts["transcription"]),
        ("  entries[0].status==completed: ", counts["completed"]),
        ("  has ringTransferInfo : ", counts["transfer"]),
        ("  has assistant_session: ", counts["session"]),
        ("  has encrypted        : ", counts["encrypted"]),
    ]

    print(f"\n--- Field distribution across {total} recordings ---")
    for label, n in rows:
        print(f"{label}{pct(n)}")
    if statuses:
        ranked = sorted(statuses.items(), key=lambda kv: -kv[1])
        breakdown = ", ".join(f"{name}={n}" for name, n in ranked)
        print("  status breakdown     : " + breakdown)


def _dup_sort_key(key: object) -> tuple[int, float, str]:
    """Sort key for group keys: numbers first in numeric order, the rest by str()."""
    if isinstance(key, (int, float)) and not isinstance(key, bool):
        return (0, key, "")
    return (1, 0, str(key))


def _print_dup_groups(dup_groups: list, by: str, show_all: bool, preview: int = 30) -> None:
    """Print each duplicate group (or only the first `preview` of them) to stdout."""
    shown = dup_groups if show_all else dup_groups[:preview]
    print("\n--- Duplicate groups ---")
    for key, snaps in shown:
        if by == "timestamp" and isinstance(key, int):
            label = fmt_ts(key)
        elif isinstance(key, str):
            label = key[:80]
        else:
            label = str(key)
        print(f"  {label}  ({len(snaps)} copies):")
        for snap in snaps:
            # Field is `assistant_session` per @SerialName on RecordingDocument.
            # Use snapshot.to_dict() to avoid raising on missing keys.
            data = snap.to_dict() or {}
            title = (data.get("assistant_session") or {}).get("title")
            print(f"    {snap.id}  updated={fmt_ts(get_updated_ms(snap))}  title={(title or '')[:60]}")
    hidden = len(dup_groups) - len(shown)
    if hidden > 0:
        print(f"  ... ({hidden} more groups suppressed; use --show-all-groups)")


def _delete_extras(db, dup_groups: list, keep: str, batch_limit: int = 450) -> int:
    """Delete all but one doc per group using batched writes; return the number deleted."""
    # Firestore allows up to 500 ops per batch; commit a little earlier.
    deleted = 0
    pending = 0
    batch = db.batch()
    for _key, snaps in dup_groups:
        # Survivor first: newest `updated` when keep == "newest", else oldest.
        ranked = sorted(snaps, key=get_updated_ms, reverse=(keep == "newest"))
        for snap in ranked[1:]:
            batch.delete(snap.reference)
            pending += 1
            deleted += 1
            if pending >= batch_limit:
                batch.commit()
                batch = db.batch()
                pending = 0
    if pending:
        batch.commit()
    return deleted


def main() -> int:
    """Run the scan and, when asked, the deletion.

    Steps: locate the service-account key, initialise Firebase, resolve the
    UID, fetch every recording, then either print field statistics
    (`--analyze`) or group the docs by the `--by` key, report the duplicate
    groups and — with `--delete`, after a typed confirmation — delete all but
    one doc per group.

    Returns:
        0 on success, 1 when the deletion prompt is not confirmed, 2 when no
        usable service-account key file is available.
    """
    args = parse_args()

    # Service account
    if args.service_account:
        sa = Path(args.service_account).expanduser()
        if not sa.exists():
            print(f"ERROR: --service-account file not found: {sa}", file=sys.stderr)
            return 2
    else:
        sa = find_default_service_account()
        if sa is None:
            print("ERROR: No --service-account provided and none found in ~/Downloads.", file=sys.stderr)
            return 2
    print(f"Using service account: {sa}", file=sys.stderr)

    db = init_firebase(sa)

    # UID
    if args.uid:
        # --email still holds its default here, so don't print it as if it
        # had been used for the lookup.
        uid = args.uid
        print(f"User: uid={uid} (from --uid; email lookup skipped)", file=sys.stderr)
    else:
        uid = resolve_uid(args.email)
        print(f"User: {args.email}  uid={uid}", file=sys.stderr)

    # Scan
    print(f"Scanning recordings/{uid}/recordings ...", file=sys.stderr)
    docs = fetch_all_recordings(db, uid, args.page_size)
    print(f"Total recordings in cloud: {len(docs)}", file=sys.stderr)

    if args.analyze:
        analyze(docs)
        return 0

    # Group snapshots by the chosen key
    key_fn = {
        "timestamp": get_ts_ms,
        "filename": get_filename,
        "transcription": get_transcription,
        "transferIndex": get_transfer_advert,
    }[args.by]
    by_key: dict[object, list] = defaultdict(list)
    missing_key = 0
    for snap in docs:
        key = key_fn(snap)
        if key is None or (isinstance(key, str) and not key.strip()):
            missing_key += 1
        else:
            by_key[key].append(snap)

    dup_groups = [(key, snaps) for key, snaps in by_key.items() if len(snaps) > 1]
    dup_groups.sort(key=lambda group: _dup_sort_key(group[0]))

    total_extra = sum(len(snaps) - 1 for _, snaps in dup_groups)
    print(
        f"\nDedup key: {args.by}  "
        f"Duplicate groups: {len(dup_groups)}  "
        f"Extra docs (would be removed): {total_extra}  "
        f"Docs missing the key: {missing_key}",
        file=sys.stderr,
    )

    if dup_groups:
        _print_dup_groups(dup_groups, args.by, args.show_all_groups)

    # Delete?
    if args.delete and dup_groups:
        print(f"\n--- About to delete {total_extra} docs ---", file=sys.stderr)
        print(f"Strategy: keep {args.keep} per group (by `updated` field)", file=sys.stderr)
        try:
            confirm = input("Type 'DELETE' to proceed: ")
        except EOFError:
            # No interactive stdin (piped / closed): treat as "not confirmed".
            confirm = ""
        if confirm != "DELETE":
            print("Aborted.", file=sys.stderr)
            return 1

        deleted = _delete_extras(db, dup_groups, args.keep)
        print(f"Deleted {deleted} duplicate docs.", file=sys.stderr)

    return 0


if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())
