#!/usr/bin/env python3
"""Query all ESP32 CSI sensors in parallel."""

import concurrent.futures
import os
import socket
import subprocess
import sys

# UDP port that sensor commands are sent to (and replies come back from).
DEFAULT_PORT = 5501
# Seconds to wait for one sensor's reply before reporting a timeout.
TIMEOUT = 2.0

# (display name, hostname) pairs; every sensor is addressed as "<name>.local".
SENSORS = [
    (sensor, f"{sensor}.local")
    for sensor in ("muddy-storm", "amber-maple", "hollow-acorn")
]

# The OTA helper is looked up next to this script, not on $PATH.
ESP_OTA = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "esp-ota",
)

# Help text printed for -h/--help (exit 0) or when no command is given (exit 2).
USAGE = """\
Usage: esp-fleet <command> [args...]

Sends a command to all known sensors in parallel and prints results.

Commands:
  status              Query all devices
  identify            Blink LEDs on all devices
  rate <10-100>       Set ping rate on all devices
  power <2-20>        Set TX power on all devices
  reboot              Reboot all devices
  ota [firmware.bin]   OTA update all devices (sequentially)

Examples:
  esp-fleet status
  esp-fleet identify
  esp-fleet rate 50
  esp-fleet ota
  esp-fleet ota /path/to/firmware.bin"""


def query(name, host, cmd):
    """Send one command to one sensor and return its reply.

    Resolves *host* to an IPv4 address, sends *cmd* as a single UDP datagram
    to DEFAULT_PORT, and waits up to TIMEOUT seconds for one reply datagram
    (read with a 512-byte buffer).

    Args:
        name: Display name, returned unchanged so callers can match results.
        host: Hostname or IPv4 address of the sensor.
        cmd: Command text to send (UTF-8 encoded on the wire).

    Returns:
        ``(name, text)``. *text* is the whitespace-stripped reply, or a
        string starting with ``"ERR: "`` for resolution failures, timeouts,
        and other socket errors. Undecodable reply bytes are replaced with
        U+FFFD instead of raising.
    """
    try:
        info = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM)
        ip = info[0][4][0]
    except socket.gaierror:
        return (name, f"ERR: cannot resolve {host}")

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(TIMEOUT)
        try:
            sock.sendto(cmd.encode(), (ip, DEFAULT_PORT))
            data, _ = sock.recvfrom(512)
        # socket.timeout is an OSError subclass, so it must be caught first.
        except socket.timeout:
            return (name, f"ERR: timeout ({ip})")
        except OSError as e:
            return (name, f"ERR: {e}")

    # errors="replace": a corrupted or non-UTF-8 datagram must not raise here,
    # or the exception would resurface in the caller and abort the whole
    # fleet query instead of just this sensor's line.
    return (name, data.decode("utf-8", errors="replace").strip())


def run_ota(firmware=None):
    """OTA-update every sensor in SENSORS, one at a time.

    For each sensor, prints a header and then runs the helper as
    ``ESP_OTA <host> [-f <firmware>]``. The helper's output goes straight
    to this process's stdout/stderr. The loop stops at the first failure.

    Args:
        firmware: Optional firmware path passed to the helper via ``-f``;
            when falsy, ``-f`` is omitted.

    Raises:
        SystemExit: with status 1 if the helper cannot be started or exits
            non-zero for any sensor.
    """
    for name, host in SENSORS:
        print(f"\n{'='*40}")
        print(f"OTA: {name} ({host})")
        print(f"{'='*40}")
        cmd = [ESP_OTA, host]
        if firmware:
            cmd += ["-f", firmware]
        # The child writes to the inherited stdout directly; flush our own
        # buffered header first, otherwise it shows up after the child's
        # output whenever stdout is a pipe or file.
        sys.stdout.flush()
        try:
            result = subprocess.run(cmd)
        except OSError as e:
            # Missing or non-executable helper: report it like any other
            # OTA failure instead of dumping a traceback.
            print(f"ERR: cannot run {ESP_OTA}: {e}", file=sys.stderr)
            sys.exit(1)
        if result.returncode != 0:
            print(f"ERR: OTA failed for {name}, stopping fleet OTA", file=sys.stderr)
            sys.exit(1)
    print(f"\nAll {len(SENSORS)} devices updated.")


def main():
    """Command-line entry point.

    With no arguments, prints USAGE and exits 2; with ``-h``/``--help`` as the
    first argument, prints USAGE and exits 0. ``ota [firmware]`` is handed to
    run_ota(). Any other arguments are joined into a single upper-cased
    command, sent to every sensor concurrently, and the replies are printed
    as an aligned two-column table in SENSORS order.
    """
    args = sys.argv[1:]
    wants_help = bool(args) and args[0] in ("-h", "--help")
    if wants_help or not args:
        print(USAGE)
        sys.exit(0 if wants_help else 2)

    # OTA goes through its own path: devices are updated one after another.
    if args[0].lower() == "ota":
        run_ota(args[1] if len(args) > 1 else None)
        return

    cmd = " ".join(args).strip().upper()
    names = [sensor_name for sensor_name, _ in SENSORS]
    hosts = [sensor_host for _, sensor_host in SENSORS]

    # One worker per sensor so all requests are in flight together; map()
    # submits every call up front and yields the (name, reply) pairs in order.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(SENSORS)) as pool:
        results = dict(pool.map(query, names, hosts, [cmd] * len(names)))

    width = max(map(len, names))
    for name in names:
        print(f"{name:<{width}}  {results[name]}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
