#!/usr/bin/env python3
"""Query all ESP32 CSI sensors in parallel."""

import concurrent.futures
import socket
import sys

# UDP port every sensor is contacted on (used for both resolution and send).
DEFAULT_PORT = 5501
# Socket timeout, in seconds, applied to each per-sensor query socket.
TIMEOUT = 2.0

# Fleet roster as (name, host) pairs: `name` labels the output row,
# `host` is the hostname resolved to an IPv4 address before sending.
SENSORS = [
    ("muddy-storm",  "muddy-storm.local"),
    ("amber-maple",  "amber-maple.local"),
    ("hollow-acorn", "hollow-acorn.local"),
]

# Help text printed by main() for -h/--help or when no command is supplied.
USAGE = """\
Usage: esp-fleet <command> [args...]

Sends a command to all known sensors in parallel and prints results.

Commands:
  status              Query all devices
  identify            Blink LEDs on all devices
  rate <10-100>       Set ping rate on all devices
  power <2-20>        Set TX power on all devices
  reboot              Reboot all devices

Examples:
  esp-fleet status
  esp-fleet identify
  esp-fleet rate 50"""


def query(name, host, cmd):
    """Send one command to a single sensor over UDP and return its reply.

    Resolution, socket and text-encoding failures are reported as an
    ``ERR: ...`` string instead of being raised, so a single misbehaving
    sensor cannot abort a run that queries several sensors.

    Args:
        name: Label identifying the sensor; returned unchanged.
        host: Hostname (or IPv4 literal) to resolve and contact.
        cmd: Command text, sent UTF-8 encoded as a single datagram.

    Returns:
        A ``(name, reply_or_error)`` tuple. On success the second item is the
        first datagram received (up to 512 bytes), decoded as UTF-8 with
        invalid bytes replaced, and stripped of surrounding whitespace.
    """
    try:
        info = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM)
        ip = info[0][4][0]
    except (socket.gaierror, UnicodeError):
        # UnicodeError: the IDNA codec rejects malformed hostnames (empty or
        # over-long labels) before any lookup is attempted.
        return (name, f"ERR: cannot resolve {host}")

    try:
        payload = cmd.encode()
        # Creating the socket inside the try turns resource exhaustion into an
        # ERR result; the context manager closes it on every exit path.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(TIMEOUT)
            sock.sendto(payload, (ip, DEFAULT_PORT))
            data, _ = sock.recvfrom(512)
    except socket.timeout:
        # Must precede OSError: socket.timeout is a subclass of it.
        return (name, f"ERR: timeout ({ip})")
    except (OSError, UnicodeError) as e:
        return (name, f"ERR: {e}")

    # A reply containing invalid UTF-8 must not raise out of the caller's
    # worker thread; show replacement characters instead.
    return (name, data.decode(errors="replace").strip())


def main():
    """Command-line entry point: send one command to all sensors and print replies.

    Every argument after the program name is joined with spaces and
    upper-cased to form the command, which is passed to ``query`` for each
    entry in ``SENSORS`` concurrently. Results are printed one per line in
    ``SENSORS`` order, with names left-aligned to a common width.

    Exit status:
        0 after printing usage for ``-h``/``--help``;
        2 after printing usage when no command (or only a blank one) is given;
        1 when ``SENSORS`` is empty.
    """
    args = sys.argv[1:]
    wants_help = bool(args) and args[0] in ("-h", "--help")
    cmd = " ".join(args).strip().upper()
    if wants_help or not cmd:
        # A blank command (e.g. `esp-fleet ""`) would only send empty
        # datagrams and wait for timeouts, so treat it as a usage error.
        print(USAGE)
        sys.exit(0 if wants_help else 2)

    if not SENSORS:
        # ThreadPoolExecutor rejects max_workers=0 and max() rejects an empty
        # sequence; fail with a clear message instead of a traceback.
        print("ERR: no sensors configured", file=sys.stderr)
        sys.exit(1)

    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(SENSORS)) as pool:
        futures = {pool.submit(query, name, host, cmd): name for name, host in SENSORS}
        for done in concurrent.futures.as_completed(futures):
            name = futures[done]
            try:
                _, reply = done.result()
            except Exception as e:
                # Boundary handler: an unexpected failure in one worker must
                # not discard the replies already gathered from the others.
                reply = f"ERR: {type(e).__name__}: {e}"
            results[name] = reply

    # Print in sensor order rather than completion order.
    max_name = max(len(n) for n, _ in SENSORS)
    for name, _ in SENSORS:
        print(f"{name:<{max_name}}  {results[name]}")


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
