feat: Add parallel OTA to esp-fleet (--parallel flag)

Start one HTTP server, send OTA commands to all devices simultaneously,
wait for reboot, then verify all in parallel. Cuts fleet OTA from ~90s
to ~30s. Sequential mode remains the default.

Usage: esp-fleet ota --parallel [firmware.bin]
This commit is contained in:
user
2026-02-04 21:18:17 +01:00
parent 7511814976
commit 6066832271

View File

@@ -2,15 +2,25 @@
"""Query all ESP32 CSI sensors in parallel.""" """Query all ESP32 CSI sensors in parallel."""
import concurrent.futures import concurrent.futures
import http.server
import os import os
import socket import socket
import subprocess import subprocess
import sys import sys
import threading
import time
from esp_ctl.auth import sign_command from esp_ctl.auth import sign_command
DEFAULT_PORT = 5501 DEFAULT_PORT = 5501
DEFAULT_HTTP_PORT = 8070
DEFAULT_FW = os.path.expanduser(
"~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)
TIMEOUT = 2.0 TIMEOUT = 2.0
REBOOT_WAIT = 25.0
STATUS_RETRIES = 10
STATUS_INTERVAL = 3.0
SENSORS = [ SENSORS = [
("muddy-storm", "muddy-storm.local"), ("muddy-storm", "muddy-storm.local"),
@@ -31,14 +41,16 @@ Commands:
rate <10-100> Set ping rate on all devices rate <10-100> Set ping rate on all devices
power <2-20> Set TX power on all devices power <2-20> Set TX power on all devices
reboot Reboot all devices reboot Reboot all devices
ota [firmware.bin] OTA update all devices (sequentially) ota [firmware.bin] OTA update all devices (sequentially)
ota --parallel [firmware] OTA update all devices in parallel
Examples: Examples:
esp-fleet status esp-fleet status
esp-fleet identify esp-fleet identify
esp-fleet rate 50 esp-fleet rate 50
esp-fleet ota esp-fleet ota
esp-fleet ota /path/to/firmware.bin""" esp-fleet ota --parallel
esp-fleet ota --parallel /path/to/firmware.bin"""
def query(name, host, cmd): def query(name, host, cmd):
@@ -64,6 +76,163 @@ def query(name, host, cmd):
sock.close() sock.close()
def _resolve(host):
"""Resolve hostname to IP, return None on failure."""
try:
info = socket.getaddrinfo(host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM)
return info[0][4][0]
except socket.gaierror:
return None
def _udp_cmd(ip, cmd, timeout=TIMEOUT):
"""Send signed UDP command and return reply string."""
cmd = sign_command(cmd)
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.settimeout(timeout)
try:
sock.sendto(cmd.encode(), (ip, DEFAULT_PORT))
data, _ = sock.recvfrom(512)
return data.decode().strip()
finally:
sock.close()
def _get_local_ip(target_ip):
"""Determine which local IP the target can reach us on."""
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
try:
sock.connect((target_ip, 80))
return sock.getsockname()[0]
finally:
sock.close()
def _serve_firmware(directory, port):
"""Start HTTP server in background thread, return server object."""
handler = lambda *a, **k: http.server.SimpleHTTPRequestHandler(
*a, directory=directory, **k
)
server = http.server.HTTPServer(("0.0.0.0", port), handler)
thread = threading.Thread(target=server.serve_forever, daemon=True)
thread.start()
return server
def _verify_device(name, host, retries=STATUS_RETRIES, interval=STATUS_INTERVAL):
"""Wait for device to come back after OTA. Returns (name, status_or_error)."""
for attempt in range(1, retries + 1):
ip = _resolve(host)
if ip:
try:
reply = _udp_cmd(ip, "STATUS", timeout=3.0)
return (name, reply)
except (socket.timeout, OSError):
pass
time.sleep(interval)
return (name, "ERR: did not come back after OTA")
def run_ota_parallel(firmware=None):
"""Run OTA on all sensors in parallel: one HTTP server, concurrent OTA commands."""
fw_path = os.path.abspath(firmware or DEFAULT_FW)
if not os.path.isfile(fw_path):
print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
sys.exit(1)
fw_size = os.path.getsize(fw_path)
fw_dir = os.path.dirname(fw_path)
fw_name = os.path.basename(fw_path)
print(f"Firmware: {fw_path} ({fw_size // 1024} KB)")
# Resolve all devices and check they're alive
devices = {} # name -> ip
for name, host in SENSORS:
ip = _resolve(host)
if not ip:
print(f" {name}: ERR cannot resolve {host}", file=sys.stderr)
continue
try:
reply = _udp_cmd(ip, "STATUS")
# Extract version from status
version = "?"
for part in reply.split():
if part.startswith("version="):
version = part[8:]
break
print(f" {name}: alive ({ip}) version={version}")
devices[name] = (host, ip)
except (socket.timeout, OSError):
print(f" {name}: ERR not responding ({ip})", file=sys.stderr)
if not devices:
print("ERR: no devices reachable", file=sys.stderr)
sys.exit(1)
# Start one HTTP server
first_ip = next(iter(devices.values()))[1]
local_ip = _get_local_ip(first_ip)
ota_url = f"http://{local_ip}:{DEFAULT_HTTP_PORT}/{fw_name}"
server = _serve_firmware(fw_dir, DEFAULT_HTTP_PORT)
print(f"HTTP server on :{DEFAULT_HTTP_PORT}, OTA URL: {ota_url}")
# Send OTA command to all devices in parallel
failed = []
with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
def _send_ota(name, ip):
try:
reply = _udp_cmd(ip, f"OTA {ota_url}")
return (name, reply)
except (socket.timeout, OSError) as e:
return (name, f"ERR: {e}")
futures = {pool.submit(_send_ota, n, ip): n for n, (host, ip) in devices.items()}
for f in concurrent.futures.as_completed(futures):
name, reply = f.result()
ok = reply.startswith("OK")
print(f" {name}: {reply}")
if not ok:
failed.append(name)
if failed:
print(f"ERR: OTA rejected by: {', '.join(failed)}", file=sys.stderr)
# Wait for all devices to download, flash, and reboot
active = {n: h for n, (h, ip) in devices.items() if n not in failed}
if not active:
server.shutdown()
sys.exit(1)
print(f"Waiting for {len(active)} devices to reboot...")
time.sleep(REBOOT_WAIT)
# Verify all in parallel
with concurrent.futures.ThreadPoolExecutor(max_workers=len(active)) as pool:
futures = {pool.submit(_verify_device, n, h): n for n, h in active.items()}
results = {}
for f in concurrent.futures.as_completed(futures):
name, reply = f.result()
results[name] = reply
server.shutdown()
# Print results in sensor order
max_name = max(len(n) for n, _ in SENSORS)
all_ok = True
for name, _ in SENSORS:
if name in results:
ok = not results[name].startswith("ERR")
if not ok:
all_ok = False
print(f" {name:<{max_name}} {results[name]}")
if all_ok:
print(f"\nAll {len(active)} devices updated.")
else:
print(f"\nSome devices failed.", file=sys.stderr)
sys.exit(1)
def run_ota(firmware=None): def run_ota(firmware=None):
"""Run OTA on each sensor sequentially.""" """Run OTA on each sensor sequentially."""
for name, host in SENSORS: for name, host in SENSORS:
@@ -85,10 +254,17 @@ def main():
print(USAGE) print(USAGE)
sys.exit(0 if sys.argv[1:] and sys.argv[1] in ("-h", "--help") else 2) sys.exit(0 if sys.argv[1:] and sys.argv[1] in ("-h", "--help") else 2)
# Handle OTA subcommand separately (sequential, not parallel) # Handle OTA subcommand
if sys.argv[1].lower() == "ota": if sys.argv[1].lower() == "ota":
firmware = sys.argv[2] if len(sys.argv) > 2 else None args = sys.argv[2:]
run_ota(firmware) parallel = "--parallel" in args
if parallel:
args.remove("--parallel")
firmware = args[0] if args else None
if parallel:
run_ota_parallel(firmware)
else:
run_ota(firmware)
return return
cmd = " ".join(sys.argv[1:]).strip().upper() cmd = " ".join(sys.argv[1:]).strip().upper()