#!/usr/bin/env python3
"""Query all ESP32 CSI sensors in parallel.

Each command is signed with ``sign_command`` and sent as one UDP datagram to
every sensor in ``SENSORS``; replies are printed in ``SENSORS`` order.  The
``ota`` subcommand updates firmware either sequentially (delegating to the
sibling ``esp-ota`` script) or in parallel (serving the image over HTTP).
"""

import concurrent.futures
import functools
import http.server
import os
import socket
import subprocess
import sys
import threading
import time
import urllib.parse

from esp_ctl.auth import sign_command

DEFAULT_PORT = 5501  # UDP control port on each sensor
DEFAULT_HTTP_PORT = 8070  # local port the parallel-OTA HTTP server listens on
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)
TIMEOUT = 2.0  # seconds to wait for a UDP reply
REBOOT_WAIT = 25.0  # seconds to wait after sending OTA before polling devices
STATUS_RETRIES = 10  # STATUS polls per device after OTA
STATUS_INTERVAL = 3.0  # seconds between those polls

SENSORS = [
    ("muddy-storm", "muddy-storm.local"),
    ("amber-maple", "amber-maple.local"),
    ("hollow-acorn", "hollow-acorn.local"),
]

# Sequential OTA helper script, expected to live next to this file.
ESP_OTA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "esp-ota")

USAGE = """\
Usage: esp-fleet <command> [args...]

Sends a command to all known sensors in parallel and prints results.

Commands:
  status                      Query all devices
  identify                    Blink LEDs on all devices
  rate <10-100>               Set ping rate on all devices
  power <2-20>                Set TX power on all devices
  reboot                      Reboot all devices
  ota [firmware.bin]          OTA update all devices (sequentially)
  ota --parallel [firmware]   OTA update all devices in parallel

Examples:
  esp-fleet status
  esp-fleet identify
  esp-fleet rate 50
  esp-fleet ota
  esp-fleet ota --parallel
  esp-fleet ota --parallel /path/to/firmware.bin"""


def query(name, host, cmd):
    """Send *cmd* to one sensor and return ``(name, reply_or_error)``.

    Network problems never raise: resolution failures, timeouts and socket
    errors are reported as an ``"ERR: ..."`` string in place of the reply.
    """
    ip = _resolve(host)
    if ip is None:
        return (name, f"ERR: cannot resolve {host}")
    try:
        return (name, _udp_cmd(ip, cmd))
    except socket.timeout:
        return (name, f"ERR: timeout ({ip})")
    except OSError as e:
        return (name, f"ERR: {e}")


def _resolve(host):
    """Resolve *host* to an IPv4 address string; return None on failure."""
    try:
        info = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    except socket.gaierror:
        return None
    return info[0][4][0]


def _udp_cmd(ip, cmd, timeout=TIMEOUT):
    """Send the signed *cmd* to *ip* over UDP and return the stripped reply.

    Raises socket.timeout if no reply arrives within *timeout* seconds and
    OSError for other socket failures.
    """
    payload = sign_command(cmd).encode()
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(payload, (ip, DEFAULT_PORT))
        data, _ = sock.recvfrom(512)
    # errors="replace": a garbled datagram must not raise UnicodeDecodeError,
    # which none of the callers handle.
    return data.decode(errors="replace").strip()


def _get_local_ip(target_ip):
    """Return the local IP address the OS would use to reach *target_ip*.

    Connecting a UDP socket only selects a route (no datagram is sent); the
    chosen source address is then read back with ``getsockname``.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.connect((target_ip, 80))
        return sock.getsockname()[0]


def _serve_firmware(directory, port):
    """Serve *directory* over HTTP on *port* from a daemon thread.

    Returns the server object; callers should ``shutdown()`` and
    ``server_close()`` it when done.  Note that every file in *directory* is
    exposed on all interfaces while the server runs.
    """
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    # Threading server: several devices download the image at the same time
    # during a parallel OTA; a single-threaded server would queue them.
    server = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def _verify_device(name, host, retries=STATUS_RETRIES, interval=STATUS_INTERVAL):
    """Wait for a device to answer STATUS again after OTA.

    Makes up to *retries* attempts, *interval* seconds apart, re-resolving
    *host* on every attempt.  Returns ``(name, status_reply)`` on success or
    ``(name, "ERR: did not come back after OTA")`` once attempts run out.
    """
    for attempt in range(retries):
        if attempt:
            time.sleep(interval)  # pause between attempts, not after the last one
        ip = _resolve(host)
        if ip is None:
            continue
        try:
            return (name, _udp_cmd(ip, "STATUS", timeout=3.0))
        except (socket.timeout, OSError):
            pass
    return (name, "ERR: did not come back after OTA")


def _status_field(reply, key, default="?"):
    """Return the value of the first ``key=value`` token in *reply*, else *default*."""
    prefix = f"{key}="
    for token in reply.split():
        if token.startswith(prefix):
            return token[len(prefix):]
    return default


def _probe_devices():
    """Check every sensor in SENSORS answers STATUS, printing one line each.

    Returns ``(devices, unreachable)``: *devices* maps name -> (host, ip) for
    sensors that replied; *unreachable* lists the names that could not be
    resolved or did not reply.
    """
    devices = {}
    unreachable = []
    for name, host in SENSORS:
        ip = _resolve(host)
        if not ip:
            print(f"  {name}: ERR cannot resolve {host}", file=sys.stderr)
            unreachable.append(name)
            continue
        try:
            reply = _udp_cmd(ip, "STATUS")
        except (socket.timeout, OSError):
            print(f"  {name}: ERR not responding ({ip})", file=sys.stderr)
            unreachable.append(name)
            continue
        print(f"  {name}: alive ({ip}) version={_status_field(reply, 'version')}")
        devices[name] = (host, ip)
    return devices, unreachable


def _send_ota_all(devices, ota_url):
    """Send ``OTA <ota_url>`` to all *devices* concurrently.

    *devices* maps name -> (host, ip).  Prints each reply as it arrives and
    returns the names whose reply did not start with "OK".
    """
    def send(name, ip):
        try:
            return (name, _udp_cmd(ip, f"OTA {ota_url}"))
        except (socket.timeout, OSError) as e:
            return (name, f"ERR: {e}")

    rejected = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [pool.submit(send, name, ip) for name, (_host, ip) in devices.items()]
        for future in concurrent.futures.as_completed(futures):
            name, reply = future.result()
            print(f"  {name}: {reply}")
            if not reply.startswith("OK"):
                rejected.append(name)
    return rejected


def _verify_all(active):
    """Run _verify_device concurrently for *active* (name -> host); return name -> reply."""
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(active)) as pool:
        futures = [pool.submit(_verify_device, name, host) for name, host in active.items()]
        return dict(f.result() for f in concurrent.futures.as_completed(futures))


def run_ota_parallel(firmware=None):
    """Run OTA on all sensors in parallel: one HTTP server, concurrent OTA commands.

    Exits with status 1 if the firmware file is missing, if no sensor is
    reachable, or if any sensor was unreachable, rejected the OTA command, or
    did not answer STATUS afterwards.
    """
    fw_path = os.path.abspath(firmware or DEFAULT_FW)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)
    fw_dir, fw_name = os.path.split(fw_path)
    print(f"Firmware: {fw_path} ({os.path.getsize(fw_path) // 1024} KB)")

    devices, unreachable = _probe_devices()
    if not devices:
        print("ERR: no devices reachable", file=sys.stderr)
        sys.exit(1)

    # NOTE: the advertised address is the local IP on the route to the first
    # reachable sensor; this assumes all sensors can reach us on that address.
    first_ip = next(iter(devices.values()))[1]
    local_ip = _get_local_ip(first_ip)
    # Quote the file name so spaces/special characters survive in the URL.
    ota_url = f"http://{local_ip}:{DEFAULT_HTTP_PORT}/{urllib.parse.quote(fw_name)}"

    server = _serve_firmware(fw_dir, DEFAULT_HTTP_PORT)
    try:
        print(f"HTTP server on :{DEFAULT_HTTP_PORT}, OTA URL: {ota_url}")
        rejected = _send_ota_all(devices, ota_url)
        if rejected:
            print(f"ERR: OTA rejected by: {', '.join(rejected)}", file=sys.stderr)

        # Wait for the accepted devices to download, flash, and reboot.
        active = {
            name: host for name, (host, _ip) in devices.items() if name not in rejected
        }
        if not active:
            sys.exit(1)
        print(f"Waiting for {len(active)} devices to reboot...")
        time.sleep(REBOOT_WAIT)
        results = _verify_all(active)
    finally:
        # Runs on sys.exit / Ctrl-C too, so the listening socket is released.
        server.shutdown()
        server.server_close()

    # Every sensor is in exactly one of: results, rejected, unreachable.
    max_name = max(len(name) for name, _ in SENSORS)
    all_ok = True
    for name, _ in SENSORS:
        if name in results:
            reply = results[name]
        elif name in rejected:
            reply = "ERR: OTA rejected"
        else:
            reply = "ERR: unreachable, skipped"
        if reply.startswith("ERR"):
            all_ok = False
        print(f"  {name:<{max_name}} {reply}")

    if all_ok:
        print(f"\nAll {len(active)} devices updated.")
    else:
        print("\nSome devices failed.", file=sys.stderr)
        sys.exit(1)


def run_ota(firmware=None):
    """Run OTA on each sensor sequentially via the ``esp-ota`` script.

    Stops and exits with status 1 at the first device whose ``esp-ota`` run
    fails or cannot be started.
    """
    banner = "=" * 40
    for name, host in SENSORS:
        print(f"\n{banner}")
        print(f"OTA: {name} ({host})")
        print(banner)
        cmd = [ESP_OTA, host]
        if firmware:
            cmd += ["-f", firmware]
        try:
            result = subprocess.run(cmd)
        except OSError as e:  # e.g. esp-ota missing or not executable
            print(f"ERR: cannot run {ESP_OTA}: {e}", file=sys.stderr)
            sys.exit(1)
        if result.returncode != 0:
            print(f"ERR: OTA failed for {name}, stopping fleet OTA", file=sys.stderr)
            sys.exit(1)
    print(f"\nAll {len(SENSORS)} devices updated.")


def main():
    """CLI entry point: dispatch ``ota`` or broadcast any other command.

    With no arguments prints usage and exits 2; ``-h``/``--help`` prints
    usage and exits 0.  Non-OTA commands are joined, upper-cased, and sent to
    every sensor concurrently.
    """
    args = sys.argv[1:]
    if not args or args[0] in ("-h", "--help"):
        print(USAGE)
        sys.exit(0 if args else 2)

    if args[0].lower() == "ota":
        ota_args = args[1:]
        parallel = "--parallel" in ota_args
        if parallel:
            ota_args.remove("--parallel")
        firmware = ota_args[0] if ota_args else None
        if parallel:
            run_ota_parallel(firmware)
        else:
            run_ota(firmware)
        return

    cmd = " ".join(args).strip().upper()
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(SENSORS)) as pool:
        futures = [pool.submit(query, name, host, cmd) for name, host in SENSORS]
    # Futures are kept in SENSORS order, so results print in that order.
    results = [future.result() for future in futures]

    max_name = max(len(name) for name, _ in SENSORS)
    for name, reply in results:
        print(f"{name:<{max_name}} {reply}")


if __name__ == "__main__":
    main()