All three standalone tools (esp-cmd, esp-fleet, esp-ota) now fetch the device uptime before signing commands, matching what esp-ctl already does. This also adds a 60 ms delay after the uptime fetch to avoid the firmware rate limiter (50 ms inter-command throttle).
313 lines
9.8 KiB
Python
Executable File
313 lines
9.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Query all ESP32 CSI sensors in parallel."""
|
|
|
|
import concurrent.futures
|
|
import http.server
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
from esp_ctl.auth import get_secret, sign_command
|
|
|
|
# UDP port every sensor is sent commands on (STATUS, OTA, ...).
DEFAULT_PORT = 5501

# Port of the temporary HTTP server run_ota_parallel() uses to serve firmware.
DEFAULT_HTTP_PORT = 8070

# Firmware image run_ota_parallel() serves when no path is given.
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)

# Default UDP socket timeout, in seconds, for a single command/reply exchange.
TIMEOUT = 2.0

# Seconds run_ota_parallel() sleeps after sending OTA before verifying devices.
REBOOT_WAIT = 25.0

# Default number of STATUS attempts _verify_device() makes after an OTA.
STATUS_RETRIES = 10

# Default delay, in seconds, between those STATUS attempts.
STATUS_INTERVAL = 3.0

# Known sensors as (display name, hostname) pairs; results are printed in this
# order.
SENSORS = [
    ("muddy-storm", "muddy-storm.local"),
    ("amber-maple", "amber-maple.local"),
    ("hollow-acorn", "hollow-acorn.local"),
]

# Path of the esp-ota script next to this file; run_ota() runs it per sensor.
ESP_OTA = os.path.join(os.path.dirname(os.path.abspath(__file__)), "esp-ota")
|
|
|
|
USAGE = """\
|
|
Usage: esp-fleet <command> [args...]
|
|
|
|
Sends a command to all known sensors in parallel and prints results.
|
|
|
|
Commands:
|
|
status Query all devices
|
|
identify Blink LEDs on all devices
|
|
rate <10-100> Set ping rate on all devices
|
|
power <2-20> Set TX power on all devices
|
|
reboot Reboot all devices
|
|
ota [firmware.bin] OTA update all devices (sequentially)
|
|
ota --parallel [firmware] OTA update all devices in parallel
|
|
|
|
Examples:
|
|
esp-fleet status
|
|
esp-fleet identify
|
|
esp-fleet rate 50
|
|
esp-fleet ota
|
|
esp-fleet ota --parallel
|
|
esp-fleet ota --parallel /path/to/firmware.bin"""
|
|
|
|
|
|
def _get_uptime(ip, timeout=TIMEOUT):
    """Fetch the device's ``uptime_s`` value, used as the HMAC timestamp.

    Sends an unauthenticated ``STATUS`` request to *ip* and parses the first
    ``uptime_s=<int>`` field of the reply.

    Returns the uptime as an int, or 0 when it cannot be obtained (no reply
    within *timeout*, socket error, field missing or not an integer).
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as udp:
        udp.settimeout(timeout)
        try:
            udp.sendto(b"STATUS", (ip, DEFAULT_PORT))
            payload = udp.recvfrom(1500)[0]
            fields = payload.decode().split()
            uptime_field = next(
                (field for field in fields if field.startswith("uptime_s=")),
                None,
            )
            if uptime_field is not None:
                return int(uptime_field.partition("=")[2])
        except (socket.timeout, OSError, ValueError):
            # Best effort: any failure just means "uptime unknown".
            pass
    return 0
|
|
|
|
|
|
def query(name, host, cmd):
    """Send *cmd* to one sensor and return ``(name, reply_or_error)``.

    If get_secret() returns a secret and the device uptime can be fetched,
    the command is signed with sign_command() first; otherwise it is sent
    unsigned.

    Resolution failures, timeouts and socket errors are returned as
    ``"ERR: ..."`` strings rather than raised, so one bad sensor cannot abort
    a fleet-wide query.
    """
    ip = _resolve(host)
    if ip is None:
        return (name, f"ERR: cannot resolve {host}")

    secret = get_secret()
    uptime = _get_uptime(ip) if secret else 0
    if secret and uptime:
        time.sleep(0.06)  # avoid firmware rate limiter (50ms)
        cmd = sign_command(cmd, uptime, secret)

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(TIMEOUT)
            sock.sendto(cmd.encode(), (ip, DEFAULT_PORT))
            # Same 1500-byte buffer _get_uptime uses for STATUS; a 512-byte
            # buffer would truncate a longer reply.
            data, _ = sock.recvfrom(1500)
    except socket.timeout:  # must precede OSError: it is a subclass
        return (name, f"ERR: timeout ({ip})")
    except OSError as e:
        return (name, f"ERR: {e}")

    # errors="replace": a garbled reply must not raise UnicodeDecodeError,
    # which would escape this function and abort the caller's result loop.
    return (name, data.decode(errors="replace").strip())
|
|
|
|
|
|
def _resolve(host):
    """Return the first IPv4 address for *host*, or None if it does not resolve."""
    try:
        addrinfo = socket.getaddrinfo(
            host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM
        )
    except socket.gaierror:
        return None
    _family, _kind, _proto, _canonname, sockaddr = addrinfo[0]
    return sockaddr[0]
|
|
|
|
|
|
def _udp_cmd(ip, cmd, timeout=TIMEOUT):
    """Send *cmd* to the sensor at *ip* and return its reply as a string.

    If get_secret() returns a secret, the command is first signed with
    sign_command() using the uptime from _get_uptime(). When the uptime
    cannot be fetched, the command is sent unsigned.

    Raises socket.timeout if no reply arrives within *timeout* seconds, and
    OSError for other socket errors.
    """
    secret = get_secret()
    uptime = _get_uptime(ip, timeout) if secret else 0
    if secret and uptime:
        time.sleep(0.06)  # avoid firmware rate limiter (50ms)
        cmd = sign_command(cmd, uptime, secret)

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(cmd.encode(), (ip, DEFAULT_PORT))
        # Same 1500-byte buffer _get_uptime uses for STATUS; a 512-byte
        # buffer would truncate a longer reply.
        data, _ = sock.recvfrom(1500)

    # errors="replace": UnicodeDecodeError is not an OSError, so callers that
    # catch (socket.timeout, OSError) would not handle a garbled reply.
    return data.decode(errors="replace").strip()
|
|
|
|
|
|
def _get_local_ip(target_ip):
    """Return the local IPv4 address the OS would use to reach *target_ip*.

    connect() on a UDP socket sends no packets; it only makes the kernel
    choose a route and source address, which getsockname() then reports.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect((target_ip, 80))
        local_addr, _local_port = probe.getsockname()
    return local_addr
|
|
|
|
|
|
def _serve_firmware(directory, port):
    """Serve files from *directory* over HTTP on *port* in a background thread.

    Uses ThreadingHTTPServer so several devices can download the image at
    once; run_ota_parallel() points every sensor at this one server. A plain
    HTTPServer handles one request at a time, so each device would have to
    wait for the previous transfer to finish.

    The serving thread is a daemon. Returns the server object; call
    ``shutdown()`` (and ``server_close()``) to stop it.
    """
    import functools

    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=directory
    )
    server = http.server.ThreadingHTTPServer(("0.0.0.0", port), handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    return server
|
|
|
|
|
|
def _verify_device(name, host, retries=STATUS_RETRIES, interval=STATUS_INTERVAL):
    """Wait for a device to answer STATUS again after an OTA reboot.

    Makes up to *retries* attempts, sleeping *interval* seconds between
    attempts. There is no sleep after the last attempt, so a failure is
    reported immediately.

    Returns ``(name, status_reply)`` on success, or
    ``(name, "ERR: did not come back after OTA")`` if every attempt fails.
    """
    for attempt in range(retries):
        if attempt:
            time.sleep(interval)
        # Re-resolve every attempt: the name may not resolve while the device
        # is still rebooting.
        ip = _resolve(host)
        if not ip:
            continue
        try:
            return (name, _udp_cmd(ip, "STATUS", timeout=3.0))
        except (socket.timeout, OSError, UnicodeDecodeError):
            pass  # not back yet (or garbled reply); try again
    return (name, "ERR: did not come back after OTA")
|
|
|
|
|
|
def _probe_devices():
    """Resolve each sensor in SENSORS and check that it answers STATUS.

    Prints one line per sensor. Returns a dict mapping sensor name to
    ``(host, ip)`` for every device that replied.
    """
    devices = {}  # name -> (host, ip)
    for name, host in SENSORS:
        ip = _resolve(host)
        if not ip:
            print(f" {name}: ERR cannot resolve {host}", file=sys.stderr)
            continue
        try:
            reply = _udp_cmd(ip, "STATUS")
        except (socket.timeout, OSError, UnicodeDecodeError):
            print(f" {name}: ERR not responding ({ip})", file=sys.stderr)
            continue
        # Extract version from the space-separated key=value status reply.
        version = "?"
        for part in reply.split():
            if part.startswith("version="):
                version = part[len("version="):]
                break
        print(f" {name}: alive ({ip}) version={version}")
        devices[name] = (host, ip)
    return devices


def _send_ota_all(devices, ota_url):
    """Send ``OTA <ota_url>`` to every device in *devices* concurrently.

    Prints each reply as it arrives. Returns the names of the devices whose
    reply did not start with "OK".
    """
    def _send_ota(name, ip):
        try:
            return (name, _udp_cmd(ip, f"OTA {ota_url}"))
        except (socket.timeout, OSError, UnicodeDecodeError) as e:
            return (name, f"ERR: {e}")

    failed = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(devices)) as pool:
        futures = [
            pool.submit(_send_ota, name, ip)
            for name, (_host, ip) in devices.items()
        ]
        for f in concurrent.futures.as_completed(futures):
            name, reply = f.result()
            print(f" {name}: {reply}")
            if not reply.startswith("OK"):
                failed.append(name)
    return failed


def _verify_all(active):
    """Run _verify_device() for every ``name -> host`` in *active* concurrently.

    Returns a dict mapping name to the STATUS reply or error string.
    """
    results = {}
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(active)) as pool:
        futures = [pool.submit(_verify_device, n, h) for n, h in active.items()]
        for f in concurrent.futures.as_completed(futures):
            name, reply = f.result()
            results[name] = reply
    return results


def run_ota_parallel(firmware=None):
    """Run OTA on all reachable sensors in parallel.

    Probes every sensor with STATUS, then serves *firmware* (default
    DEFAULT_FW) from one local HTTP server on DEFAULT_HTTP_PORT. It sends
    ``OTA <url>`` to all reachable devices concurrently, waits REBOOT_WAIT
    seconds, then checks that each device answers STATUS again.

    Exits with status 1 if any of the following happens:
    - the firmware file is missing;
    - no device is reachable;
    - the HTTP server cannot start;
    - any device rejects the OTA command;
    - any device fails verification.

    Sensors that are unreachable during the initial probe are skipped and do
    not affect the exit status.
    """
    fw_path = os.path.abspath(firmware or DEFAULT_FW)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)

    fw_size = os.path.getsize(fw_path)
    fw_dir, fw_name = os.path.split(fw_path)
    print(f"Firmware: {fw_path} ({fw_size // 1024} KB)")

    devices = _probe_devices()
    if not devices:
        print("ERR: no devices reachable", file=sys.stderr)
        sys.exit(1)

    # Advertise the local address that routes to the sensors.
    first_ip = next(iter(devices.values()))[1]
    local_ip = _get_local_ip(first_ip)
    ota_url = f"http://{local_ip}:{DEFAULT_HTTP_PORT}/{fw_name}"
    try:
        server = _serve_firmware(fw_dir, DEFAULT_HTTP_PORT)
    except OSError as e:
        print(f"ERR: cannot start HTTP server on :{DEFAULT_HTTP_PORT}: {e}",
              file=sys.stderr)
        sys.exit(1)
    print(f"HTTP server on :{DEFAULT_HTTP_PORT}, OTA URL: {ota_url}")

    try:
        failed = _send_ota_all(devices, ota_url)
        if failed:
            print(f"ERR: OTA rejected by: {', '.join(failed)}", file=sys.stderr)

        # Wait for the remaining devices to download, flash, and reboot.
        active = {n: h for n, (h, _ip) in devices.items() if n not in failed}
        if not active:
            sys.exit(1)

        print(f"Waiting for {len(active)} devices to reboot...")
        time.sleep(REBOOT_WAIT)
        results = _verify_all(active)
    finally:
        # Stop serving on every exit path, including sys.exit and Ctrl-C.
        server.shutdown()
        server.server_close()

    # Print results in sensor order. Devices that rejected OTA count as
    # failures even if every remaining device verified.
    max_name = max(len(n) for n, _ in SENSORS)
    all_ok = not failed
    for name, _ in SENSORS:
        if name in results:
            if results[name].startswith("ERR"):
                all_ok = False
            print(f" {name:<{max_name}} {results[name]}")

    if all_ok:
        print(f"\nAll {len(active)} devices updated.")
    else:
        print("\nSome devices failed.", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
def run_ota(firmware=None):
    """Run OTA on each sensor in turn by invoking the esp-ota script.

    *firmware*, if given, is passed to esp-ota as ``-f <firmware>``. Stops at
    the first sensor whose esp-ota run fails and exits with status 1. A run
    fails on a non-zero exit status or when the script cannot be started.
    """
    banner = "=" * 40
    for name, host in SENSORS:
        print(f"\n{banner}")
        print(f"OTA: {name} ({host})")
        print(banner)
        cmd = [ESP_OTA, host]
        if firmware:
            cmd += ["-f", firmware]
        try:
            result = subprocess.run(cmd)
        except OSError as e:
            # e.g. esp-ota missing or not executable: report instead of a traceback.
            print(f"ERR: cannot run {ESP_OTA}: {e}", file=sys.stderr)
            sys.exit(1)
        if result.returncode != 0:
            print(f"ERR: OTA failed for {name}, stopping fleet OTA", file=sys.stderr)
            sys.exit(1)
    print(f"\nAll {len(SENSORS)} devices updated.")
|
|
|
|
|
|
def main():
    """Command-line entry point: ``esp-fleet <command> [args...]``.

    No arguments prints USAGE and exits with status 2; ``-h``/``--help``
    prints it and exits 0. ``ota`` dispatches to run_ota(), or to
    run_ota_parallel() when ``--parallel`` is present. Any other command is
    upper-cased, sent to every sensor concurrently via query(), and the
    replies are printed in SENSORS order.
    """
    argv = sys.argv[1:]
    help_flags = ("-h", "--help")
    if not argv or argv[0] in help_flags:
        print(USAGE)
        sys.exit(0 if argv else 2)

    # Handle OTA subcommand
    if argv[0].lower() == "ota":
        ota_args = argv[1:]
        parallel = "--parallel" in ota_args
        if parallel:
            ota_args.remove("--parallel")
        firmware = ota_args[0] if ota_args else None
        runner = run_ota_parallel if parallel else run_ota
        runner(firmware)
        return

    cmd = " ".join(argv).strip().upper()

    with concurrent.futures.ThreadPoolExecutor(max_workers=len(SENSORS)) as pool:
        pending = [pool.submit(query, name, host, cmd) for name, host in SENSORS]
        # query() returns (name, reply) pairs, which dict() consumes directly.
        results = dict(
            fut.result() for fut in concurrent.futures.as_completed(pending)
        )

    # Print in sensor order
    width = max(len(name) for name, _host in SENSORS)
    for name, _host in SENSORS:
        print(f"{name:<{width}} {results[name]}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|