#!/usr/bin/env python3
"""
Full backup of one user's recordings: every Firestore document in
`recordings/{uid}/recordings/*` plus every audio blob those documents
reference in Firebase Storage.

Layout
------
    <output-dir>/
        firestore/<recordingId>.json     # full Firestore doc, formatted
        audio/<fileName>.<ext>           # downloaded audio blob
        manifest.json                    # counters + uid + timestamp

Restartability
--------------
The script is purely filesystem-stateful. It re-runs safely:
- A Firestore doc that already exists locally with non-zero size is skipped.
- An audio blob that already exists locally with non-zero size is skipped.
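- Files are written under a temporary name (`.part` for audio, `.json.tmp`
  for Firestore docs) and renamed only once complete, so a partial download
  from a killed run is never mistaken for a finished file; to force a
  re-download, just delete the final file.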

If you killed the script mid-page, simply rerun — it will page through the
Firestore collection again from the start (the reads are repeated, but the
cost is minimal) and skip any docs already on disk.

Parallelism
-----------
Firestore is paginated single-threaded (it's already fast — ~50 docs per
~30 ms). Audio downloads run on a thread pool (default 16 workers); the
network is the bottleneck and Storage tolerates many concurrent ranges.

Usage
-----
    python3 backup_recordings.py \\
        --email ericmigi@gmail.com \\
        --output-dir "$HOME/Downloads/2026-04-29 Eric Firestore recording db"
"""

from __future__ import annotations

import argparse
import json
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Any

import firebase_admin
from firebase_admin import auth as fb_auth
from firebase_admin import credentials, firestore
from google.cloud import storage as gcs


# Default Firebase Storage bucket — value pulled out of google-services.json.
# Used as the default for the --bucket flag; pass --bucket to point at a
# different project.
DEFAULT_BUCKET = "coreapp-ce061.firebasestorage.app"


def _positive_int(value: str) -> int:
    """argparse ``type=`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {number}")
    return number


def parse_args() -> argparse.Namespace:
    """Parse the command line (``sys.argv``) into a namespace.

    The module docstring is reused verbatim as the ``--help`` description.
    ``--workers`` and ``--page-size`` must be positive integers: a value of
    zero or below is rejected here with a normal argparse usage error
    (exit status 2) instead of failing later inside the thread pool or the
    Firestore query.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--email", default="ericmigi@gmail.com",
                   help="User email; resolved to a UID via Firebase Auth.")
    p.add_argument("--uid", default=None, help="Skip the email→UID lookup and use this UID directly.")
    p.add_argument("--service-account", default=None,
                   help="Service-account JSON key. Defaults to newest "
                        "coreapp-ce061-firebase-adminsdk-*.json in ~/Downloads.")
    p.add_argument("--output-dir", required=True,
                   help="Where to write the backup. Created if missing.")
    p.add_argument("--bucket", default=DEFAULT_BUCKET,
                   help=f"GCS bucket name (default: {DEFAULT_BUCKET}).")
    p.add_argument("--workers", type=_positive_int, default=16,
                   help="Parallel audio download workers (default 16).")
    p.add_argument("--page-size", type=_positive_int, default=500,
                   help="Firestore page size for the recordings collection scan.")
    p.add_argument("--skip-audio", action="store_true",
                   help="Only download Firestore docs, skip audio blobs.")
    p.add_argument("--skip-firestore", action="store_true",
                   help="Skip Firestore (use existing local JSON), only download audio.")
    return p.parse_args()


def find_default_service_account() -> Path | None:
    """Return the most recently modified admin-SDK key in ~/Downloads.

    Only files named ``coreapp-ce061-firebase-adminsdk-*.json`` are
    considered; ``None`` is returned when there is no such file.
    """
    key_files = (Path.home() / "Downloads").glob("coreapp-ce061-firebase-adminsdk-*.json")
    # max() keeps the first of several equally-new files, exactly like taking
    # element 0 of a stable newest-first sort.
    return max(key_files, key=lambda key_file: key_file.stat().st_mtime, default=None)


def init_firebase(service_account_path: Path):
    """Initialise firebase_admin with the given key file and return a Firestore client."""
    firebase_admin.initialize_app(credentials.Certificate(str(service_account_path)))
    return firestore.client()


def init_gcs(service_account_path: Path, bucket_name: str):
    """Build a Storage client from the service-account key file and return
    a handle to the bucket called *bucket_name*."""
    key_file = str(service_account_path)
    return gcs.Client.from_service_account_json(key_file).bucket(bucket_name)


def resolve_uid(email: str) -> str:
    """Look *email* up through Firebase Auth and return that user's UID."""
    user_record = fb_auth.get_user_by_email(email)
    return user_record.uid


# ---------- Firestore export ----------

def _json_default(o):
    """Make Firestore special types JSON-serialisable."""
    # Firestore Timestamp / datetime
    if hasattr(o, "isoformat"):
        return o.isoformat()
    # GeoPoint, DocumentReference, Bytes — just stringify
    return str(o)


def export_firestore(db, uid: str, output_dir: Path, page_size: int) -> tuple[int, int, set[str]]:
    """Stream every recording doc to <output>/firestore/<id>.json.

    The collection ``recordings/{uid}/recordings`` is read in pages of
    *page_size*, using the last snapshot of each page as the cursor for the
    next. Pages are ordered by document ID (``__name__``) rather than by a
    data field: per the Firestore documentation an order-by on a field also
    filters out documents that lack that field, which would silently drop
    such documents from the backup.

    A document whose JSON file already exists with non-zero size is not
    rewritten, but its entries are still scanned for fileNames.

    Returns (total_seen, written_now, all_filenames) — filenames are
    every fileName across all entries[] (used by the audio pass)."""
    out = output_dir / "firestore"
    out.mkdir(parents=True, exist_ok=True)

    coll = db.collection("recordings").document(uid).collection("recordings")
    total = 0
    written = 0
    skipped = 0
    filenames: set[str] = set()
    last = None
    while True:
        q = coll.order_by("__name__").limit(page_size)
        if last is not None:
            q = q.start_after(last)
        page = list(q.stream())
        if not page:
            break
        for snap in page:
            total += 1
            data = snap.to_dict() or {}
            entries = data.get("entries")
            if not isinstance(entries, list):
                entries = []
            for entry in entries:
                # Tolerate malformed entries instead of aborting the export.
                if not isinstance(entry, dict):
                    continue
                fn = entry.get("fileName")
                if fn:
                    filenames.add(fn)
            target = out / f"{snap.id}.json"
            if target.exists() and target.stat().st_size > 0:
                skipped += 1
                continue
            # Write to a sibling temp file first so an interrupted run never
            # leaves a truncated <id>.json behind.
            tmp = target.with_suffix(".json.tmp")
            with tmp.open("w", encoding="utf-8") as f:
                json.dump(data, f, default=_json_default, indent=2, sort_keys=True)
            # replace() also overwrites a zero-byte leftover target, which
            # rename() refuses to do on Windows.
            tmp.replace(target)
            written += 1
        last = page[-1]
        print(f"  firestore: total={total} written={written} skipped={skipped} unique_files={len(filenames)}", file=sys.stderr)
        # A short page means the collection is exhausted; skip the extra query.
        if len(page) < page_size:
            break
    return total, written, filenames


# ---------- Audio download ----------

def _ext_for_content_type(ct: str | None) -> str:
    if not ct:
        return ""
    ct = ct.lower().split(";")[0].strip()
    return {
        "audio/mp4": ".m4a",
        "audio/m4a": ".m4a",
        "audio/x-m4a": ".m4a",
        "audio/wav": ".wav",
        "audio/x-wav": ".wav",
        "audio/wave": ".wav",
        "audio/raw": ".pcm",
        "audio/pcm": ".pcm",
        "audio/aac": ".aac",
        "audio/mpeg": ".mp3",
        "audio/ogg": ".ogg",
    }.get(ct, "")


def download_one_audio(bucket, uid: str, file_name: str, out_dir: Path) -> tuple[str, str]:
    """Download the blob ``recordings/{uid}/{file_name}`` into *out_dir*.

    The local name is ``<file_name><ext>`` where the extension is derived
    from the blob's content type (``.bin`` when unrecognised).

    Returns (file_name, status). Status is one of:
    'skipped' (already present), 'downloaded', 'missing', 'error: <msg>'.
    Never raises for ordinary failures: any ``Exception`` is folded into the
    'error: ...' status so one bad blob cannot abort the worker pool."""
    blob_path = f"recordings/{uid}/{file_name}"
    try:
        # One metadata request: get_blob() returns None for a missing object
        # and otherwise a blob with its metadata (content_type) already
        # loaded, replacing the separate exists() + reload() round-trips.
        blob = bucket.get_blob(blob_path)
        if blob is None:
            return file_name, "missing"
        ext = _ext_for_content_type(blob.content_type) or ".bin"
        target = out_dir / f"{file_name}{ext}"
        if target.exists() and target.stat().st_size > 0:
            return file_name, "skipped"
        # Download under a .part name so an interrupted transfer is never
        # mistaken for a finished file on the next run.
        tmp = target.with_suffix(target.suffix + ".part")
        blob.download_to_filename(str(tmp))
        # replace() also overwrites a zero-byte leftover target, which
        # rename() refuses to do on Windows.
        tmp.replace(target)
        return file_name, "downloaded"
    except Exception as e:
        return file_name, f"error: {type(e).__name__}: {e}"


def download_all_audio(bucket, uid: str, file_names: set[str], output_dir: Path, workers: int) -> dict[str, int]:
    """Download every name in *file_names* into <output_dir>/audio in parallel.

    Each name is handed to ``download_one_audio`` on a thread pool of
    *workers* threads (values below 1 are treated as 1, since the executor
    rejects a non-positive pool size). Missing blobs and errors are reported
    on stderr as they occur, plus a progress line roughly every 2% of the
    work and at the end.

    Returns a dict with the keys 'downloaded', 'skipped', 'missing' and
    'error', each counting the files that ended in that state.
    """
    out = output_dir / "audio"
    out.mkdir(parents=True, exist_ok=True)

    workers = max(1, workers)
    counters = {"downloaded": 0, "skipped": 0, "missing": 0, "error": 0}
    todo = list(file_names)
    print(f"\nAudio: {len(todo)} unique fileNames to check, {workers} workers", file=sys.stderr)

    progress_every = max(1, len(todo) // 50)
    with ThreadPoolExecutor(max_workers=workers) as pool:
        futs = [pool.submit(download_one_audio, bucket, uid, fn, out) for fn in todo]
        # Results are consumed on this thread only, so the counters need no lock.
        for done, fut in enumerate(as_completed(futs), start=1):
            fn, status = fut.result()
            # Any status that is not one of the known keys is an 'error: ...' string.
            outcome = status if status in counters else "error"
            counters[outcome] += 1
            if outcome == "missing":
                print(f"    missing: {fn}", file=sys.stderr)
            elif outcome == "error":
                print(f"    {status}: {fn}", file=sys.stderr)
            if done % progress_every == 0 or done == len(todo):
                print(
                    f"  audio: done={done}/{len(todo)} "
                    f"downloaded={counters['downloaded']} skipped={counters['skipped']} "
                    f"missing={counters['missing']} error={counters['error']}",
                    file=sys.stderr,
                )
    return counters


# ---------- Main ----------

def main() -> int:
    """Run the backup: Firestore export, audio download, then manifest.

    Returns the process exit status: 2 when no usable service-account key
    is available, otherwise 0. Per-file audio failures do not change the
    exit status; they are printed to stderr and counted in manifest.json.
    """
    # Local import: only the manifest timestamp needs it.
    from datetime import datetime, timezone

    args = parse_args()
    if args.service_account:
        sa = Path(args.service_account).expanduser()
        if not sa.exists():
            # Distinguish a bad explicit path from "nothing found".
            print(f"ERROR: --service-account file not found: {sa}", file=sys.stderr)
            return 2
    else:
        sa = find_default_service_account()
        if sa is None or not sa.exists():
            print("ERROR: No --service-account provided and none found in ~/Downloads.", file=sys.stderr)
            return 2
    print(f"Using service account: {sa}", file=sys.stderr)

    # Firebase is initialised even with --skip-firestore because the
    # email→UID lookup below goes through Firebase Auth.
    db = init_firebase(sa)
    bucket = init_gcs(sa, args.bucket)

    uid = args.uid or resolve_uid(args.email)
    print(f"User: {args.email}  uid={uid}  bucket={args.bucket}", file=sys.stderr)

    output_dir = Path(args.output_dir).expanduser()
    output_dir.mkdir(parents=True, exist_ok=True)

    # ---- Firestore ----
    if args.skip_firestore:
        # Reconstruct fileName set from already-on-disk JSON
        fs_dir = output_dir / "firestore"
        filenames: set[str] = set()
        total = 0
        for fp in fs_dir.glob("*.json"):
            try:
                d = json.loads(fp.read_text(encoding="utf-8"))
            except (OSError, ValueError) as exc:
                # Still best-effort, but say which file was ignored and why.
                print(f"  WARNING: ignoring unreadable {fp.name}: {exc}", file=sys.stderr)
                continue
            if not isinstance(d, dict):
                print(f"  WARNING: ignoring {fp.name}: not a JSON object", file=sys.stderr)
                continue
            total += 1
            entries = d.get("entries")
            if not isinstance(entries, list):
                entries = []
            for e in entries:
                if not isinstance(e, dict):
                    continue
                fn = e.get("fileName")
                if fn:
                    filenames.add(fn)
        written = 0
        print(f"Firestore: skipped (loaded {total} docs from disk, {len(filenames)} unique fileNames)", file=sys.stderr)
    else:
        print(f"\nExporting Firestore → {output_dir / 'firestore'}", file=sys.stderr)
        total, written, filenames = export_firestore(db, uid, output_dir, args.page_size)
        print(f"Firestore: total={total} written={written} unique_files={len(filenames)}", file=sys.stderr)

    # ---- Audio ----
    audio_counters = {"downloaded": 0, "skipped": 0, "missing": 0, "error": 0}
    if not args.skip_audio:
        audio_counters = download_all_audio(bucket, uid, filenames, output_dir, args.workers)
    else:
        print("\nAudio download: skipped (--skip-audio)", file=sys.stderr)

    # ---- Manifest ----
    manifest = {
        "email": args.email,
        "uid": uid,
        "bucket": args.bucket,
        # UTC completion time of this run, ISO-8601.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "totals": {
            "firestore_docs": total,
            "firestore_written_this_run": written,
            "unique_audio_filenames": len(filenames),
            "audio": audio_counters,
        },
    }
    (output_dir / "manifest.json").write_text(json.dumps(manifest, indent=2), encoding="utf-8")
    print(f"\nDone. Manifest at {output_dir / 'manifest.json'}", file=sys.stderr)
    return 0


# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
