#!/usr/bin/env python3
"""Query all ESP32 CSI sensors in parallel."""

import concurrent.futures
import functools
import http.server
import os
import socket
import subprocess
import sys
import threading
import time

from esp_ctl.auth import sign_command

# UDP port the sensors listen on for signed control commands (see query()/_udp_cmd()).
DEFAULT_PORT = 5501
# TCP port of the temporary firmware HTTP server started by run_ota_parallel().
DEFAULT_HTTP_PORT = 8070
# Firmware image used by run_ota_parallel() when no path is given. The
# sequential run_ota() does not use this; without a path it leaves the
# choice of image to the esp-ota helper.
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)
# Seconds to wait for a UDP reply (socket timeout) in query()/_udp_cmd().
TIMEOUT = 2.0
# Seconds run_ota_parallel() sleeps after sending OTA before polling devices.
REBOOT_WAIT = 25.0
# Post-OTA STATUS polling in _verify_device(): attempts, and seconds between them.
STATUS_RETRIES = 10
STATUS_INTERVAL = 3.0

# Fleet inventory as (display name, hostname) pairs, in output order.
# Hostnames are resolved with getaddrinfo (IPv4 only); the .local names
# presumably rely on mDNS resolution on this machine — confirm for new setups.
SENSORS = [
    ("muddy-storm",  "muddy-storm.local"),
    ("amber-maple",  "amber-maple.local"),
    ("hollow-acorn", "hollow-acorn.local"),
]

# Single-device OTA helper expected next to this file; run_ota() runs it
# once per sensor.
ESP_OTA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "esp-ota")

# Help text printed for -h/--help (exit 0) or when no command is given (exit 2).
USAGE = """\
Usage: esp-fleet <command> [args...]

Sends a command to all known sensors in parallel and prints results.

Commands:
  status              Query all devices
  identify            Blink LEDs on all devices
  rate <10-100>       Set ping rate on all devices
  power <2-20>        Set TX power on all devices
  reboot              Reboot all devices
  ota [firmware.bin]           OTA update all devices (sequentially)
  ota --parallel [firmware]    OTA update all devices in parallel

Examples:
  esp-fleet status
  esp-fleet identify
  esp-fleet rate 50
  esp-fleet ota
  esp-fleet ota --parallel
  esp-fleet ota --parallel /path/to/firmware.bin"""


def query(name, host, cmd):
    """Send one command to a single sensor and capture the outcome as text.

    Never raises for network or decoding problems: every failure is folded
    into an ``"ERR: ..."`` string, so a fleet-wide fan-out can always print
    one line per sensor.

    Args:
        name: Display name of the sensor, echoed back in the result.
        host: Hostname to resolve (IPv4 only).
        cmd: Unsigned command text; _udp_cmd() signs it before sending.

    Returns:
        (name, reply) where reply is the sensor's response or an error string.
    """
    # Reuse the shared helpers instead of duplicating resolve/send logic.
    ip = _resolve(host)
    if ip is None:
        return (name, f"ERR: cannot resolve {host}")
    try:
        return (name, _udp_cmd(ip, cmd))
    except socket.timeout:
        return (name, f"ERR: timeout ({ip})")
    except OSError as e:
        return (name, f"ERR: {e}")
    except UnicodeDecodeError:
        # A reply that is not valid UTF-8 must not crash the whole fan-out.
        return (name, f"ERR: undecodable reply ({ip})")


def _resolve(host):
    """Look up the first IPv4 address for ``host``; None if the lookup fails."""
    try:
        addrinfo = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    except socket.gaierror:
        return None
    _family, _socktype, _proto, _canonname, sockaddr = addrinfo[0]
    return sockaddr[0]


def _udp_cmd(ip, cmd, timeout=TIMEOUT):
    """Send a signed command to one sensor over UDP and return its reply.

    Args:
        ip: IPv4 address of the sensor.
        cmd: Plain command text (e.g. "STATUS"); signed with sign_command()
            before sending.
        timeout: Seconds to wait for the reply datagram.

    Returns:
        The reply (at most 512 bytes are read) decoded as UTF-8, with
        undecodable bytes replaced by U+FFFD and surrounding whitespace
        stripped.

    Raises:
        socket.timeout: No reply arrived within ``timeout``.
        OSError: The send or receive failed.
    """
    payload = sign_command(cmd).encode()
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(payload, (ip, DEFAULT_PORT))
        data, _ = sock.recvfrom(512)
    # A corrupt or truncated datagram (e.g. cut mid-character at 512 bytes)
    # must not surface as UnicodeDecodeError. Callers only handle
    # socket.timeout/OSError.
    return data.decode("utf-8", errors="replace").strip()


def _get_local_ip(target_ip):
    """Determine which local IP the target can reach us on."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
    try:
        sock.connect((target_ip, 80))
        return sock.getsockname()[0]
    finally:
        sock.close()


def _serve_firmware(directory, port):
    """Start HTTP server in background thread, return server object."""
    handler = lambda *a, **k: http.server.SimpleHTTPRequestHandler(
        *a, directory=directory, **k
    )
    server = http.server.HTTPServer(("0.0.0.0", port), handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server


def _verify_device(name, host, retries=STATUS_RETRIES, interval=STATUS_INTERVAL):
    """Poll a device with STATUS until it answers after an OTA reboot.

    ``host`` is re-resolved on every attempt, so a name that does not yet
    resolve mid-reboot simply gets another try. Each STATUS request waits
    up to 3 s for a reply.

    Args:
        name: Display name echoed back in the result.
        host: Hostname of the device.
        retries: Maximum number of attempts.
        interval: Seconds to wait between attempts.

    Returns:
        (name, status_reply) on the first successful reply, otherwise
        (name, "ERR: did not come back after OTA").
    """
    for attempt in range(retries):
        if attempt:
            # Pause only *between* attempts. Sleeping after the final
            # failure would just delay the error report.
            time.sleep(interval)
        ip = _resolve(host)
        if ip is None:
            continue
        try:
            return (name, _udp_cmd(ip, "STATUS", timeout=3.0))
        except (socket.timeout, OSError):
            continue  # not back yet; try again
    return (name, "ERR: did not come back after OTA")


def _status_version(reply):
    """Return the value of the first ``version=`` token in a STATUS reply, or "?"."""
    prefix = "version="
    for token in reply.split():
        if token.startswith(prefix):
            return token[len(prefix):]
    return "?"


def _probe_devices():
    """Resolve every sensor in SENSORS and keep the ones that answer STATUS.

    Prints one line per sensor: reachable ones to stdout with their reported
    version, unreachable ones to stderr.

    Returns:
        dict mapping sensor name -> (host, ip), in SENSORS order.
    """
    devices = {}
    for name, host in SENSORS:
        ip = _resolve(host)
        if not ip:
            print(f"  {name}: ERR cannot resolve {host}", file=sys.stderr)
            continue
        try:
            reply = _udp_cmd(ip, "STATUS")
        except (socket.timeout, OSError):
            print(f"  {name}: ERR not responding ({ip})", file=sys.stderr)
            continue
        print(f"  {name}: alive ({ip}) version={_status_version(reply)}")
        devices[name] = (host, ip)
    return devices


def _send_ota_all(devices, ota_url):
    """Send ``OTA <ota_url>`` to every device concurrently and print each reply.

    Args:
        devices: dict of name -> (host, ip), as returned by _probe_devices().
        ota_url: URL the devices should download the firmware from.

    Returns:
        Names of devices whose reply did not start with "OK" (send errors
        included), in completion order.
    """
    def send(name, ip):
        try:
            return (name, _udp_cmd(ip, f"OTA {ota_url}"))
        except (socket.timeout, OSError) as e:
            return (name, f"ERR: {e}")

    rejected = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [pool.submit(send, name, ip) for name, (_host, ip) in devices.items()]
        for future in concurrent.futures.as_completed(futures):
            name, reply = future.result()
            print(f"  {name}: {reply}")
            if not reply.startswith("OK"):
                rejected.append(name)
    return rejected


def _verify_all(active):
    """Poll every device in ``active`` (name -> host) concurrently after OTA.

    Returns:
        dict mapping name -> STATUS reply, or the "ERR: ..." string from
        _verify_device() if the device never answered.
    """
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(active)) as pool:
        futures = [pool.submit(_verify_device, name, host) for name, host in active.items()]
        return dict(f.result() for f in concurrent.futures.as_completed(futures))


def run_ota_parallel(firmware=None):
    """OTA-update all sensors at once from a single local HTTP server.

    Steps:
      1. Check that the firmware file exists.
      2. Probe every sensor with STATUS.
      3. Serve the firmware directory over HTTP.
      4. Send ``OTA <url>`` to all reachable sensors in parallel.
      5. Sleep REBOOT_WAIT seconds.
      6. Poll each updated sensor until it answers STATUS again.

    Args:
        firmware: Path to the firmware image; DEFAULT_FW when None/empty.

    Exits with status 1 if the firmware is missing, if no sensor is
    reachable, or if any sensor in SENSORS was unreachable, rejected the OTA
    command, or did not come back afterwards.
    """
    fw_path = os.path.abspath(firmware or DEFAULT_FW)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)

    fw_dir, fw_name = os.path.split(fw_path)
    print(f"Firmware: {fw_path} ({os.path.getsize(fw_path) // 1024} KB)")

    devices = _probe_devices()
    if not devices:
        print("ERR: no devices reachable", file=sys.stderr)
        sys.exit(1)

    # Advertise the local address that routes to the fleet (via the first device).
    first_ip = next(iter(devices.values()))[1]
    local_ip = _get_local_ip(first_ip)
    ota_url = f"http://{local_ip}:{DEFAULT_HTTP_PORT}/{fw_name}"
    server = _serve_firmware(fw_dir, DEFAULT_HTTP_PORT)
    print(f"HTTP server on :{DEFAULT_HTTP_PORT}, OTA URL: {ota_url}")

    try:
        failed = _send_ota_all(devices, ota_url)
        if failed:
            print(f"ERR: OTA rejected by: {', '.join(failed)}", file=sys.stderr)

        active = {name: host for name, (host, _ip) in devices.items() if name not in failed}
        if not active:
            sys.exit(1)

        print(f"Waiting for {len(active)} devices to reboot...")
        time.sleep(REBOOT_WAIT)
        results = _verify_all(active)
    finally:
        # Runs on normal completion, sys.exit() and Ctrl-C alike, so the
        # server thread stops and the listening socket is always released.
        server.shutdown()
        server.server_close()

    # Print results in sensor order
    max_name = max(len(n) for n, _ in SENSORS)
    not_back = []
    for name, _ in SENSORS:
        if name in results:
            if results[name].startswith("ERR"):
                not_back.append(name)
            print(f"  {name:<{max_name}}  {results[name]}")

    # Sensors that were never reachable or that rejected the OTA are failures
    # too; otherwise a partial fleet update would report success and exit 0.
    skipped = [name for name, _ in SENSORS if name not in devices]
    problems = skipped + failed + not_back
    if problems:
        print(f"\nSome devices failed: {', '.join(problems)}", file=sys.stderr)
        sys.exit(1)
    print(f"\nAll {len(active)} devices updated.")


def run_ota(firmware=None):
    """Update every sensor one at a time by running the esp-ota helper.

    Each device gets a banner and its own esp-ota run (with ``-f firmware``
    when a path is given). The first non-zero exit code aborts the fleet
    update with status 1.
    """
    banner = "=" * 40
    extra_args = ["-f", firmware] if firmware else []
    for name, host in SENSORS:
        print(f"\n{banner}")
        print(f"OTA: {name} ({host})")
        print(banner)
        if subprocess.run([ESP_OTA, host, *extra_args]).returncode:
            print(f"ERR: OTA failed for {name}, stopping fleet OTA", file=sys.stderr)
            sys.exit(1)
    print(f"\nAll {len(SENSORS)} devices updated.")


def main():
    """Command-line entry point for esp-fleet.

    ``esp-fleet ota [--parallel] [firmware]`` runs a fleet OTA. Any other
    arguments are joined, upper-cased and sent to every sensor concurrently,
    and one reply line is printed per sensor in SENSORS order. USAGE is
    printed with exit 0 for -h/--help, or exit 2 when no command is given.
    """
    argv = sys.argv[1:]
    help_requested = bool(argv) and argv[0] in ("-h", "--help")
    if help_requested or not argv:
        print(USAGE)
        sys.exit(0 if help_requested else 2)

    # OTA has its own flow: sequential esp-ota runs or a parallel HTTP push.
    if argv[0].lower() == "ota":
        ota_args = argv[1:]
        parallel = "--parallel" in ota_args
        if parallel:
            ota_args.remove("--parallel")
        firmware = ota_args[0] if ota_args else None
        runner = run_ota_parallel if parallel else run_ota
        runner(firmware)
        return

    command = " ".join(argv).strip().upper()
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(SENSORS)) as pool:
        replies = dict(pool.map(lambda sensor: query(sensor[0], sensor[1], command), SENSORS))

    width = max(len(name) for name, _ in SENSORS)
    for name, _ in SENSORS:
        print(f"{name:<{width}}  {replies[name]}")


if __name__ == "__main__":
    main()
