#!/usr/bin/env python3
"""td — time-track markdown todos identified by ^tid block-refs.

Usage:
    td start <tid> [--file PATH] [--desc TEXT]
    td stop
    td report [FILTER]       # FILTER matches tid or file substring
    td current               # print the currently-running tid (or nothing)
    td tidgen                # print a new unique tid
    td week [--since DATE]   # task create/complete events from git log, joined with the time log
    td weekmd [--since DATE] # the same git task events, printed as a markdown table

Log location: $TD_LOG (default ./time.csv)
"""
import argparse, csv, os, re, subprocess, sys
from datetime import datetime, timedelta
from collections import defaultdict

# Path of the CSV time log; set TD_LOG in the environment to override it.
LOG = os.getenv("TD_LOG", "./time.csv")
# Column order of the time log, also written as its header row.
FIELDS = "started_at stopped_at tid file description".split()

def now():
    """Return the current local time as an ISO-8601 string, whole seconds only."""
    return datetime.now().replace(microsecond=0).isoformat()

def read_rows():
    """Return every entry of the time log at LOG as a list of dicts.

    Keys come from the CSV header row. A missing log file yields [].
    Rows with fewer columns than the header (e.g. a hand-edited or truncated
    last line) get "" for the missing columns instead of None, so callers
    can treat every value as a string.
    """
    if not os.path.exists(LOG):
        return []
    with open(LOG, newline="") as f:
        return list(csv.DictReader(f, restval=""))

def write_rows(rows):
    """Rewrite the whole time log at LOG as a FIELDS header plus *rows*.

    The CSV is written to a temporary sibling file first and then moved over
    the log with os.replace. An interruption mid-write (Ctrl-C, crash, full
    disk) therefore leaves the previous log intact instead of truncated.
    """
    target = os.path.realpath(LOG)  # replace the real file, not a symlink pointing at it
    tmp = f"{target}.tmp-{os.getpid()}"
    try:
        with open(tmp, "w", newline="") as f:
            w = csv.DictWriter(f, fieldnames=FIELDS)
            w.writeheader()
            w.writerows(rows)
        os.replace(tmp, target)
    except BaseException:
        # Clean up the partial temp file, then propagate the original error.
        try:
            os.remove(tmp)
        except OSError:
            pass
        raise

def lookup_task(tid):
    try:
        out = subprocess.check_output(
            ["rg", "--no-heading", "--with-filename", rf"\^{tid}\b", "."],
            text=True, stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return "", ""
    first = out.splitlines()[0] if out else ""
    if not first:
        return "", ""
    path, _, line = first.partition(":")
    file = path.removeprefix("./")
    desc = re.sub(r"^\s*-\s*\[.\]\s*", "", line)
    desc = re.sub(r"\s*\^\S+\s*$", "", desc).strip()
    return file, desc

def cmd_start(args):
    """Start timing args.tid, first closing every entry that is still open.

    The file and description come from --file/--desc. When both are empty
    they are looked up from the ^tid block-ref with lookup_task, and a
    warning goes to stderr if no file is found. A new open entry is then
    appended to the log.
    """
    log = read_rows()
    stamp = now()
    still_open = [row for row in log if not row["stopped_at"]]
    interrupted = [row["tid"] for row in still_open]
    for row in still_open:
        row["stopped_at"] = stamp

    file = args.file or ""
    desc = args.desc or ""
    if not file and not desc:
        file, desc = lookup_task(args.tid)
        if not file:
            warning = f"warning: ^{args.tid} not found in {os.getcwd()} — logging with empty file/desc"
            print(warning, file=sys.stderr)

    log.append({"started_at": stamp, "stopped_at": "", "tid": args.tid, "file": file, "description": desc})
    write_rows(log)

    if interrupted:
        print(f"  stopped {', '.join(interrupted)} first")
    snippet = "" if not desc else f" ({file}: {desc[:50]})"
    print(f"started {args.tid}{snippet}")

def cmd_stop(args):
    """Close every still-open log entry at the current time.

    Exits with status 1 when read_rows returns nothing (no log, or a
    header-only log) or when no entry is open.
    """
    log = read_rows()
    if not log:
        print(f"no log at {LOG}")
        sys.exit(1)
    stamp = now()
    open_rows = [row for row in log if not row["stopped_at"]]
    if not open_rows:
        print("nothing running")
        sys.exit(1)
    for row in open_rows:
        row["stopped_at"] = stamp
    write_rows(log)
    print("stopped " + ", ".join(row["tid"] for row in open_rows))

def cmd_report(args):
    rows = read_rows()
    if not rows:
        print(f"no log at {LOG}"); sys.exit(1)
    totals, files, running = defaultdict(int), {}, []
    for r in rows:
        if args.filter and args.filter not in r["tid"] and args.filter not in r["file"]:
            continue
        start = datetime.fromisoformat(r["started_at"])
        files[r["tid"]] = r["file"]
        if r["stopped_at"]:
            stop = datetime.fromisoformat(r["stopped_at"])
            totals[r["tid"]] += int((stop - start).total_seconds())
        else:
            running.append((r["tid"], start, r["file"], r["description"]))
    for tid in sorted(totals):
        t = totals[tid]
        h, m = t // 3600, (t % 3600) // 60
        print(f"{tid:<24} {h:>3}h{m:02d}m  {files.get(tid,'')}")
    for tid, start, file, desc in running:
        print(f"{tid:<24} RUNNING since {start.strftime('%H:%M')}  {file}: {desc[:40]}")

def cmd_current(args):
    """Print the tid of the first log entry without a stop time, if any."""
    running = [row["tid"] for row in read_rows() if not row["stopped_at"]]
    if running:
        print(running[0])

def cmd_tidgen(args):
    """Print a tid of the form tid-YYYYMMDD-HHMMSS built from the local clock.

    Two calls within the same second print the same tid.
    """
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    print("tid-" + stamp)

TID_RE = re.compile(r"\^(tid-[\w-]+)")
TASK_ADD_RE = re.compile(r"^\+(?!\+\+)\s*-\s*\[([ xX])\]\s*(.*)$")
TASK_REM_RE = re.compile(r"^-(?!--)\s*-\s*\[([ xX])\]\s*(.*)$")

def _tid_totals():
    totals = defaultdict(int)
    for r in read_rows():
        if r["stopped_at"]:
            start = datetime.fromisoformat(r["started_at"])
            stop = datetime.fromisoformat(r["stopped_at"])
            totals[r["tid"]] += int((stop - start).total_seconds())
    return totals

def _scan_git(since):
    SEP = "---TDWEEK-COMMIT---"
    try:
        out = subprocess.check_output(
            ["git", "log", f"--since={since}", "--reverse",
             f"--format={SEP}%n%cI", "-p", "--no-color", "--no-renames"],
            text=True, stderr=subprocess.DEVNULL,
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        return
    for chunk in out.split(SEP + "\n"):
        if not chunk.strip():
            continue
        lines = chunk.split("\n")
        ts = lines[0]
        current_file = None
        adds, removes = [], []
        for ln in lines[1:]:
            if ln.startswith("+++ b/"):
                current_file = ln[6:].split("\t")[0].rstrip(); continue
            if ln.startswith("--- ") or ln.startswith("diff --git") or ln.startswith("index ") or ln.startswith("@@"):
                continue
            if not current_file or current_file == "/dev/null":
                continue
            m = TASK_ADD_RE.match(ln)
            if m:
                status = m.group(1).lower().strip() or " "
                tm = TID_RE.search(m.group(2))
                if tm:
                    desc = TID_RE.sub("", m.group(2)).strip()
                    adds.append((current_file, status, tm.group(1), desc))
                continue
            m = TASK_REM_RE.match(ln)
            if m:
                status = m.group(1).lower().strip() or " "
                tm = TID_RE.search(m.group(2))
                if tm:
                    removes.append((current_file, status, tm.group(1)))
        rem_by_tid = {t[2]: t for t in removes}
        for file, status, tid, desc in adds:
            prior = rem_by_tid.get(tid)
            if prior:
                if prior[1] == " " and status == "x":
                    yield (ts, "done", tid, file, desc)
                elif prior[1] == "x" and status == " ":
                    yield (ts, "reopen", tid, file, desc)
            else:
                yield (ts, "done+new" if status == "x" else "new", tid, file, desc)

def _default_since():
    today = datetime.now()
    monday = today - timedelta(days=today.weekday())
    return monday.strftime("%Y-%m-%d 00:00")

def _fmt_dur(secs):
    return f"{secs // 3600}h{(secs % 3600) // 60:02d}m"

def cmd_week(args):
    since = args.since or _default_since()
    events = list(_scan_git(since))
    totals = _tid_totals()
    by_day = defaultdict(list)
    git_tids = set()
    for ts, kind, tid, file, desc in events:
        day = ts[:10]
        by_day[day].append((ts, kind, tid, file, desc))
        git_tids.add(tid)
    seen_tid_days = set()
    for r in read_rows():
        if not r["stopped_at"]:
            continue
        day = r["started_at"][:10]
        if day < since[:10]:
            continue
        tid = r["tid"]
        if tid in git_tids or (day, tid) in seen_tid_days:
            continue
        seen_tid_days.add((day, tid))
        by_day[day].append((r["started_at"], "time", tid, r["file"], r["description"]))
    if not by_day:
        print(f"no task events since {since}"); return
    marker = {"done": "[x]", "new": "[ ]", "done+new": "[+]", "reopen": "[o]", "time": "[-]"}
    n_done = n_new = total_secs = 0
    counted_tids = set()
    print(f"Week since {since}\n")
    for day in sorted(by_day):
        dt = datetime.fromisoformat(day)
        print(f"{dt.strftime('%a %Y-%m-%d')}")
        for ts, kind, tid, file, desc in by_day[day]:
            t = totals.get(tid, 0)
            dur = _fmt_dur(t) if t else "     "
            print(f"  {marker.get(kind,'?  ')} {tid:<24} {dur:>6}  {file:<36}  {desc[:60]}")
            if kind in ("done", "done+new"): n_done += 1
            if kind in ("new", "done+new"): n_new += 1
            if tid not in counted_tids:
                counted_tids.add(tid); total_secs += t
        print()
    print(f"Totals: completed {n_done}  created {n_new}  time {_fmt_dur(total_secs)}")

def _lookup_line(file, tid):
    if not file or not os.path.exists(file):
        return ""
    try:
        with open(file) as f:
            for line in f:
                if f"^{tid}" in line:
                    return line.rstrip()
    except OSError:
        pass
    return ""

def cmd_weekmd(args):
    since = args.since or _default_since()
    events = list(_scan_git(since))
    totals = _tid_totals()
    if not events:
        print(f"no task events since {since}"); return
    by_day = defaultdict(list)
    for ts, kind, tid, file, desc in events:
        by_day[ts[:10]].append((ts, kind, tid, file, desc))
    marker = {"done": "[x]", "new": "[ ]", "done+new": "[+]", "reopen": "[o]"}
    n_done = n_new = total_secs = 0
    counted_tids = set()
    rows = []
    for day in sorted(by_day):
        dt = datetime.fromisoformat(day)
        day_label = dt.strftime("%a %Y-%m-%d")
        first = True
        for ts, kind, tid, file, desc in by_day[day]:
            t = totals.get(tid, 0)
            raw = _lookup_line(file, tid) or desc
            line = re.sub(r"^\s*-\s*\[.\]\s*", "", raw)
            line = re.sub(r"\s*\^\S+.*$", "", line).strip()
            rows.append((day_label if first else "", marker.get(kind, "?"), tid, _fmt_dur(t) if t else "", file, line))
            first = False
            if kind in ("done", "done+new"): n_done += 1
            if kind in ("new", "done+new"): n_new += 1
            if tid not in counted_tids:
                counted_tids.add(tid); total_secs += t
    print(f"# Week since {since}\n")
    print("| Day | | TID | Time | File | Task |")
    print("|-----|---|-----|------|------|------|")
    for day_label, mk, tid, dur, file, line in rows:
        dl = day_label.replace("|", "\\|")
        fn = file.replace("|", "\\|")
        ln = line.replace("|", "\\|")
        print(f"| {dl} | {mk} | {tid} | {dur} | {fn} | {ln} |")
    print()
    print(f"**Totals:** completed {n_done} · created {n_new} · time {_fmt_dur(total_secs)}")

def main():
    """Parse the command line and dispatch to the chosen cmd_* handler."""
    parser = argparse.ArgumentParser(prog="td", description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    commands = parser.add_subparsers(dest="cmd", required=True)

    start = commands.add_parser("start")
    start.add_argument("tid")
    start.add_argument("--file", default="")
    start.add_argument("--desc", default="")
    start.set_defaults(func=cmd_start)

    commands.add_parser("stop").set_defaults(func=cmd_stop)

    report = commands.add_parser("report")
    report.add_argument("filter", nargs="?")
    report.set_defaults(func=cmd_report)

    commands.add_parser("current").set_defaults(func=cmd_current)
    commands.add_parser("tidgen").set_defaults(func=cmd_tidgen)

    # Both weekly views take the same optional --since.
    for name, handler in (("week", cmd_week), ("weekmd", cmd_weekmd)):
        sub = commands.add_parser(name)
        sub.add_argument("--since", default=None)
        sub.set_defaults(func=handler)

    args = parser.parse_args()
    args.func(args)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
