diff --git a/tools/esp-fleet b/tools/esp-fleet index c7a45a9..ba681e5 100755 --- a/tools/esp-fleet +++ b/tools/esp-fleet @@ -2,15 +2,25 @@ """Query all ESP32 CSI sensors in parallel.""" import concurrent.futures +import http.server import os import socket import subprocess import sys +import threading +import time from esp_ctl.auth import sign_command DEFAULT_PORT = 5501 +DEFAULT_HTTP_PORT = 8070 +DEFAULT_FW = os.path.expanduser( + "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin" +) TIMEOUT = 2.0 +REBOOT_WAIT = 25.0 +STATUS_RETRIES = 10 +STATUS_INTERVAL = 3.0 SENSORS = [ ("muddy-storm", "muddy-storm.local"), @@ -31,14 +41,16 @@ Commands: rate <10-100> Set ping rate on all devices power <2-20> Set TX power on all devices reboot Reboot all devices - ota [firmware.bin] OTA update all devices (sequentially) + ota [firmware.bin] OTA update all devices (sequentially) + ota --parallel [firmware] OTA update all devices in parallel Examples: esp-fleet status esp-fleet identify esp-fleet rate 50 esp-fleet ota - esp-fleet ota /path/to/firmware.bin""" + esp-fleet ota --parallel + esp-fleet ota --parallel /path/to/firmware.bin""" def query(name, host, cmd): @@ -64,6 +76,163 @@ def query(name, host, cmd): sock.close() +def _resolve(host): + """Resolve hostname to IP, return None on failure.""" + try: + info = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM) + return info[0][4][0] + except socket.gaierror: + return None + + +def _udp_cmd(ip, cmd, timeout=TIMEOUT): + """Send signed UDP command and return reply string.""" + cmd = sign_command(cmd) + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + sock.settimeout(timeout) + try: + sock.sendto(cmd.encode(), (ip, DEFAULT_PORT)) + data, _ = sock.recvfrom(512) + return data.decode().strip() + finally: + sock.close() + + +def _get_local_ip(target_ip): + """Determine which local IP the target can reach us on.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + sock.connect((target_ip, 80)) + return sock.getsockname()[0] + finally: + sock.close() + + +def _serve_firmware(directory, port): + """Start HTTP server in background thread, return server object.""" + handler = lambda *a, **k: http.server.SimpleHTTPRequestHandler( + *a, directory=directory, **k + ) + server = http.server.HTTPServer(("0.0.0.0", port), handler) + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + return server + + +def _verify_device(name, host, retries=STATUS_RETRIES, interval=STATUS_INTERVAL): + """Wait for device to come back after OTA. Returns (name, status_or_error).""" + for attempt in range(1, retries + 1): + ip = _resolve(host) + if ip: + try: + reply = _udp_cmd(ip, "STATUS", timeout=3.0) + return (name, reply) + except (socket.timeout, OSError): + pass + time.sleep(interval) + return (name, "ERR: did not come back after OTA") + + +def run_ota_parallel(firmware=None): + """Run OTA on all sensors in parallel: one HTTP server, concurrent OTA commands.""" + fw_path = os.path.abspath(firmware or DEFAULT_FW) + if not os.path.isfile(fw_path): + print(f"ERR: firmware not found: {fw_path}", file=sys.stderr) + sys.exit(1) + + fw_size = os.path.getsize(fw_path) + fw_dir = os.path.dirname(fw_path) + fw_name = os.path.basename(fw_path) + print(f"Firmware: {fw_path} ({fw_size // 1024} KB)") + + # Resolve all devices and check they're alive + devices = {} # name -> ip + for name, host in SENSORS: + ip = _resolve(host) + if not ip: + print(f" {name}: ERR cannot resolve {host}", file=sys.stderr) + continue + try: + reply = _udp_cmd(ip, "STATUS") + # Extract version from status + version = "?" + for part in reply.split(): + if part.startswith("version="): + version = part[8:] + break + print(f" {name}: alive ({ip}) version={version}") + devices[name] = (host, ip) + except (socket.timeout, OSError): + print(f" {name}: ERR not responding ({ip})", file=sys.stderr) + + if not devices: + print("ERR: no devices reachable", file=sys.stderr) + sys.exit(1) + + # Start one HTTP server + first_ip = next(iter(devices.values()))[1] + local_ip = _get_local_ip(first_ip) + ota_url = f"http://{local_ip}:{DEFAULT_HTTP_PORT}/{fw_name}" + server = _serve_firmware(fw_dir, DEFAULT_HTTP_PORT) + print(f"HTTP server on :{DEFAULT_HTTP_PORT}, OTA URL: {ota_url}") + + # Send OTA command to all devices in parallel + failed = [] + with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool: + def _send_ota(name, ip): + try: + reply = _udp_cmd(ip, f"OTA {ota_url}") + return (name, reply) + except (socket.timeout, OSError) as e: + return (name, f"ERR: {e}") + + futures = {pool.submit(_send_ota, n, ip): n for n, (host, ip) in devices.items()} + for f in concurrent.futures.as_completed(futures): + name, reply = f.result() + ok = reply.startswith("OK") + print(f" {name}: {reply}") + if not ok: + failed.append(name) + + if failed: + print(f"ERR: OTA rejected by: {', '.join(failed)}", file=sys.stderr) + + # Wait for all devices to download, flash, and reboot + active = {n: h for n, (h, ip) in devices.items() if n not in failed} + if not active: + server.shutdown() + sys.exit(1) + + print(f"Waiting for {len(active)} devices to reboot...") + time.sleep(REBOOT_WAIT) + + # Verify all in parallel + with concurrent.futures.ThreadPoolExecutor(max_workers=len(active)) as pool: + futures = {pool.submit(_verify_device, n, h): n for n, h in active.items()} + results = {} + for f in concurrent.futures.as_completed(futures): + name, reply = f.result() + results[name] = reply + + server.shutdown() + + # Print results in sensor order + max_name = max(len(n) for n, _ in SENSORS) + all_ok = True + for name, _ in SENSORS: + if name in results: + ok = not results[name].startswith("ERR") + if not ok: + all_ok = False + print(f" {name:<{max_name}} {results[name]}") + + if all_ok: + print(f"\nAll {len(active)} devices updated.") + else: + print(f"\nSome devices failed.", file=sys.stderr) + sys.exit(1) + + def run_ota(firmware=None): """Run OTA on each sensor sequentially.""" for name, host in SENSORS: @@ -85,10 +254,17 @@ def main(): print(USAGE) sys.exit(0 if sys.argv[1:] and sys.argv[1] in ("-h", "--help") else 2) - # Handle OTA subcommand separately (sequential, not parallel) + # Handle OTA subcommand if sys.argv[1].lower() == "ota": - firmware = sys.argv[2] if len(sys.argv) > 2 else None - run_ota(firmware) + args = sys.argv[2:] + parallel = "--parallel" in args + if parallel: + args.remove("--parallel") + firmware = args[0] if args else None + if parallel: + run_ota_parallel(firmware) + else: + run_ota(firmware) return cmd = " ".join(sys.argv[1:]).strip().upper()