#!/usr/bin/env python3
"""OTA firmware update for ESP32 CSI devices."""

import argparse
import functools
import http.server
import os
import socket
import sys
import threading
import time

from esp_ctl.auth import sign_command

DEFAULT_CMD_PORT = 5501  # device port that UDP commands are sent to
DEFAULT_HTTP_PORT = 8070  # local port for the firmware HTTP server (-p/--port)
# Firmware image used when -f/--firmware is not given.
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)
TIMEOUT = 2.0  # seconds udp_cmd() waits for a reply by default
REBOOT_WAIT = 25.0  # seconds to sleep after an accepted OTA before polling STATUS
STATUS_RETRIES = 10  # number of post-reboot STATUS attempts before giving up
STATUS_INTERVAL = 3.0  # seconds to sleep after each failed STATUS attempt


def resolve(host: str) -> str:
    """Resolve *host* to its first IPv4 address.

    Prints an error to stderr and exits with status 1 if the name cannot be
    resolved.

    Args:
        host: Hostname (e.g. an mDNS ``.local`` name) or dotted IPv4 string.

    Returns:
        The IPv4 address as a dotted-quad string.
    """
    try:
        infos = socket.getaddrinfo(host, DEFAULT_CMD_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    except (socket.gaierror, UnicodeError) as e:
        # UnicodeError: the IDNA codec rejects malformed names (empty or
        # over-long labels) before any lookup happens; report it the same
        # way instead of crashing with a traceback.
        print(f"ERR: cannot resolve {host}: {e}", file=sys.stderr)
        sys.exit(1)
    # Each entry is (family, type, proto, canonname, sockaddr); for AF_INET
    # sockaddr is (address, port).
    return infos[0][4][0]


def udp_cmd(ip: str, cmd: str, timeout: float = TIMEOUT) -> str:
    """Sign *cmd*, send it to the device over UDP, and return the reply.

    Args:
        ip: Device IPv4 address.
        cmd: Unsigned command text, e.g. ``"STATUS"``.
        timeout: Seconds to wait for the reply.

    Returns:
        The reply decoded as UTF-8 with surrounding whitespace stripped.
        Undecodable bytes are replaced rather than raising, because callers
        only handle socket errors.

    Raises:
        socket.timeout: No reply arrived within *timeout*.
        OSError: Sending or receiving failed.
    """
    payload = sign_command(cmd).encode()
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(payload, (ip, DEFAULT_CMD_PORT))
        data, _ = sock.recvfrom(512)
    # errors="replace": a garbled datagram must not surface as a
    # UnicodeDecodeError (a ValueError), which callers do not catch.
    return data.decode(errors="replace").strip()


def get_local_ip(target_ip: str) -> str:
    """Return the local IPv4 address the OS would use to reach *target_ip*.

    ``connect`` on a UDP socket only selects a route and a local address;
    no datagram is sent, so the port number (80) is arbitrary.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect((target_ip, 80))
        local_addr, _local_port = probe.getsockname()
    return local_addr


def serve_firmware(directory: str, port: int) -> http.server.HTTPServer:
    """Serve *directory* over HTTP from a background daemon thread.

    Args:
        directory: Directory whose files are served (the firmware's folder).
        port: TCP port to bind on all interfaces (0 picks a free port).

    Returns:
        The running server. Callers should ``shutdown()`` and then
        ``server_close()`` it when done.

    Raises:
        OSError: The port could not be bound (e.g. already in use).
    """
    # NOTE(review): every file under *directory* (plus directory listings)
    # is reachable from the network while the server runs, not just the
    # firmware image.
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    # ThreadingHTTPServer handles each request on its own daemon thread. The
    # plain HTTPServer handles requests inline inside serve_forever(), so
    # shutdown() would block until a slow or stalled device download ended.
    server = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server


def _lookup_ipv4(host: str) -> str:
    """Return the first IPv4 address for *host* without exiting on failure.

    Unlike ``resolve()``, a lookup failure propagates as ``socket.gaierror``
    (a subclass of ``OSError``), so the post-OTA polling loop can keep
    retrying while the rebooting device's name is not yet resolvable.
    """
    infos = socket.getaddrinfo(host, DEFAULT_CMD_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    return infos[0][4][0]


def _verify_after_reboot(host: str) -> bool:
    """Poll the device with STATUS until it answers or the retries run out.

    *host* is re-resolved on every attempt in case its address changed
    across the reboot. Prints the reply on success and a retry line after
    each failed attempt.

    Returns:
        True if the device replied, False after STATUS_RETRIES failures.
    """
    for attempt in range(1, STATUS_RETRIES + 1):
        try:
            ip = _lookup_ipv4(host)
            reply = udp_cmd(ip, "STATUS", timeout=3.0)
        except (socket.timeout, OSError):
            print(f"  retry {attempt}/{STATUS_RETRIES}...")
            time.sleep(STATUS_INTERVAL)
            continue
        print(f"Verified: {reply}")
        return True
    return False


def main():
    """Push a firmware image to a device over the air and verify the reboot.

    Flow:
      1. Check that the device answers STATUS.
      2. Serve the image over HTTP from this machine.
      3. Send ``OTA <url>``.
      4. Unless ``--no-wait`` is given, wait for the device to reboot and
         answer STATUS again.

    Exits with status 1 on the failures it checks for.
    """
    parser = argparse.ArgumentParser(description="OTA firmware update for ESP32 CSI devices")
    parser.add_argument("host", help="Device hostname or IP (e.g., amber-maple.local)")
    parser.add_argument("-f", "--firmware", default=DEFAULT_FW, help="Path to firmware .bin")
    parser.add_argument("-p", "--port", type=int, default=DEFAULT_HTTP_PORT, help="HTTP server port")
    parser.add_argument("--no-wait", action="store_true", help="Don't wait for reboot verification")
    args = parser.parse_args()

    fw_path = os.path.abspath(args.firmware)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)

    fw_size = os.path.getsize(fw_path)
    fw_dir = os.path.dirname(fw_path)
    fw_name = os.path.basename(fw_path)

    print(f"Firmware: {fw_path} ({fw_size // 1024} KB)")

    # Resolve device (resolve() exits on failure).
    ip = resolve(args.host)
    print(f"Device:   {args.host} ({ip})")

    # Check that the device is alive before starting anything.
    try:
        reply = udp_cmd(ip, "STATUS")
    except (socket.timeout, OSError) as e:
        print(f"ERR: device not responding: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Status:   {reply}")

    # The URL must use the local address the device can reach us on.
    local_ip = get_local_ip(ip)
    ota_url = f"http://{local_ip}:{args.port}/{fw_name}"
    print(f"OTA URL:  {ota_url}")

    try:
        server = serve_firmware(fw_dir, args.port)
    except OSError as e:
        print(f"ERR: cannot start HTTP server on port {args.port}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"HTTP server on port {args.port}")

    # From here on the server is always torn down, including on sys.exit()
    # (SystemExit) and Ctrl-C (KeyboardInterrupt).
    try:
        try:
            reply = udp_cmd(ip, f"OTA {ota_url}")
        except (socket.timeout, OSError) as e:
            print(f"ERR: OTA command failed: {e}", file=sys.stderr)
            sys.exit(1)
        print(f"OTA cmd:  {reply}")

        if not reply.startswith("OK"):
            print(f"ERR: device rejected OTA: {reply}", file=sys.stderr)
            sys.exit(1)

        if args.no_wait:
            print("OTA started (--no-wait, not verifying)")
            # Keep serving briefly so the device can finish downloading.
            time.sleep(30)
            return

        # Wait for the device to download, flash, and reboot.
        # NOTE(review): if the download takes longer than REBOOT_WAIT, the old
        # firmware may still answer STATUS and be reported as "Verified";
        # confirm whether STATUS exposes a version/uptime worth checking.
        print("Waiting for reboot...")
        time.sleep(REBOOT_WAIT)

        if not _verify_after_reboot(args.host):
            print("ERR: device did not come back after OTA", file=sys.stderr)
            sys.exit(1)
    finally:
        server.shutdown()
        server.server_close()


if __name__ == "__main__":
    # main() returns None on success, which sys.exit() maps to status 0.
    sys.exit(main())
