#!/usr/bin/env python3
"""OTA firmware update for ESP32 CSI devices.

Flow: serve the firmware image from a short-lived local HTTP server, send the
device an ``OTA <url>`` command on its UDP command port, then (unless
``--no-wait``) wait for it to reboot and answer ``STATUS`` again.
"""

import argparse
import functools
import http.server
import os
import socket
import sys
import threading
import time
from typing import Optional
from urllib.parse import quote

from esp_ctl.auth import get_secret, sign_command

DEFAULT_CMD_PORT = 5501
DEFAULT_HTTP_PORT = 8070
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)
TIMEOUT = 2.0
REBOOT_WAIT = 25.0
STATUS_RETRIES = 10
STATUS_INTERVAL = 3.0
UPTIME_CACHE_TTL = 3.0  # seconds a cached uptime_s reading stays valid
NO_WAIT_SERVE_TIME = 30.0  # seconds to keep serving the image with --no-wait
RECV_BUFSIZE = 1500  # shared by all UDP reads so long STATUS replies are not truncated


def resolve(host: str, *, exit_on_error: bool = True) -> str:
    """Resolve *host* to an IPv4 address string.

    Args:
        host: hostname or dotted-quad IP.
        exit_on_error: when True (default, used for up-front CLI validation),
            print an error and exit with status 1 if resolution fails. When
            False, re-raise ``socket.gaierror`` so the caller can retry.
    """
    try:
        result = socket.getaddrinfo(host, DEFAULT_CMD_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    except socket.gaierror as e:
        if not exit_on_error:
            raise
        print(f"ERR: cannot resolve {host}: {e}", file=sys.stderr)
        sys.exit(1)
    return result[0][4][0]


# Last successful uptime query: device IP, uptime_s value, and the
# time.monotonic() at which it was taken.
_uptime_cache = {"ip": None, "value": 0, "time": 0}


def _parse_uptime(reply: str) -> Optional[int]:
    """Return the ``uptime_s=`` field of a STATUS reply, or None if missing or not an int."""
    for part in reply.split():
        if part.startswith("uptime_s="):
            try:
                return int(part.split("=", 1)[1])
            except ValueError:
                return None
    return None


def get_uptime(ip: str, timeout: float = TIMEOUT) -> int:
    """Query device uptime_s for HMAC timestamp (unauthenticated).

    Caches result for 3s to avoid hitting the firmware rate limiter when
    multiple udp_cmd calls happen in quick succession. A cached value is
    advanced by the whole seconds elapsed since it was read.

    Returns 0 if the device does not answer within *timeout* or the reply has
    no parsable ``uptime_s=`` field (failures are not cached).
    """
    now = time.monotonic()
    if _uptime_cache["ip"] == ip and (now - _uptime_cache["time"]) < UPTIME_CACHE_TTL:
        elapsed = int(now - _uptime_cache["time"])
        return _uptime_cache["value"] + elapsed
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
            sock.settimeout(timeout)
            sock.sendto(b"STATUS", (ip, DEFAULT_CMD_PORT))
            data, _ = sock.recvfrom(RECV_BUFSIZE)
    except OSError:  # socket.timeout is a subclass of OSError
        return 0
    val = _parse_uptime(data.decode(errors="replace"))
    if val is None:
        return 0
    _uptime_cache.update(ip=ip, value=val, time=now)
    return val


def udp_cmd(ip: str, cmd: str, timeout: float = TIMEOUT) -> str:
    """Send a UDP command to the device and return its reply, whitespace-stripped.

    If a secret is configured (``get_secret()``) and the device's uptime can
    be read, the command is signed with ``sign_command`` using that uptime as
    the timestamp. Otherwise (no secret, or the uptime query returned 0) it
    is sent unsigned.

    Raises:
        socket.timeout: no reply within *timeout*.
        OSError: other socket errors.
    """
    secret = get_secret()
    uptime = get_uptime(ip, timeout) if secret else 0
    if secret and uptime:
        time.sleep(0.06)  # avoid firmware rate limiter (50ms)
        cmd = sign_command(cmd, uptime, secret)
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(cmd.encode(), (ip, DEFAULT_CMD_PORT))
        data, _ = sock.recvfrom(RECV_BUFSIZE)
    # errors="replace": a garbled reply must not raise UnicodeDecodeError,
    # which callers (catching only socket errors) would not handle.
    return data.decode(errors="replace").strip()


def get_local_ip(target_ip: str) -> str:
    """Return the local IPv4 address the OS would use to reach *target_ip*.

    ``connect`` on a UDP socket only selects a route and sends nothing, so the
    port number is arbitrary.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.connect((target_ip, 80))
        return sock.getsockname()[0]


def serve_firmware(directory: str, port: int) -> http.server.HTTPServer:
    """Start an HTTP server for *directory* on all interfaces in a daemon thread.

    The caller should ``shutdown()`` and ``server_close()`` the returned
    server. Bind failures (e.g. port already in use) raise ``OSError``.
    """
    handler = functools.partial(http.server.SimpleHTTPRequestHandler, directory=directory)
    server = http.server.HTTPServer(("0.0.0.0", port), handler)
    threading.Thread(target=server.serve_forever, daemon=True).start()
    return server


def _send_ota(ip: str, ota_url: str) -> None:
    """Send ``OTA <url>`` to the device; exit with status 1 on failure or rejection."""
    try:
        reply = udp_cmd(ip, f"OTA {ota_url}")
    except (socket.timeout, OSError) as e:
        print(f"ERR: OTA command failed: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"OTA cmd: {reply}")
    if not reply.startswith("OK"):
        print(f"ERR: device rejected OTA: {reply}", file=sys.stderr)
        sys.exit(1)


def _wait_for_reboot(host: str, ota_sent: float) -> bool:
    """Wait for the device to download, flash and reboot, then poll STATUS.

    Args:
        host: hostname or IP from the command line. It is re-resolved on every
            attempt, and a resolution failure is retried like a timeout (a
            name such as ``*.local`` may not resolve right after reboot).
        ota_sent: ``time.monotonic()`` when the device accepted the OTA command.

    Returns:
        True once a STATUS reply shows an uptime shorter than the time since
        *ota_sent* (i.e. the device rebooted), or has no parsable uptime_s.
        False after STATUS_RETRIES attempts without such a reply.
    """
    print("Waiting for reboot...")
    time.sleep(REBOOT_WAIT)
    for attempt in range(1, STATUS_RETRIES + 1):
        note = ""
        try:
            ip = resolve(host, exit_on_error=False)
            reply = udp_cmd(ip, "STATUS", timeout=3.0)
        except (socket.timeout, OSError):  # gaierror is an OSError subclass
            reply = None
        if reply is not None:
            uptime = _parse_uptime(reply)
            # An uptime at least as long as the time since the OTA command
            # means the device has not rebooted yet (still downloading, or the
            # OTA failed). Keep polling, and keep serving the image meanwhile.
            if uptime is None or uptime < time.monotonic() - ota_sent:
                print(f"Verified: {reply}")
                return True
            note = f" (not rebooted yet, uptime_s={uptime})"
        print(f" retry {attempt}/{STATUS_RETRIES}...{note}")
        if attempt < STATUS_RETRIES:
            time.sleep(STATUS_INTERVAL)
    return False


def main():
    """CLI entry point: push a firmware image to a device via OTA and optionally verify."""
    parser = argparse.ArgumentParser(description="OTA firmware update for ESP32 CSI devices")
    parser.add_argument("host", help="Device hostname or IP (e.g., amber-maple.local)")
    parser.add_argument("-f", "--firmware", default=DEFAULT_FW, help="Path to firmware .bin")
    parser.add_argument("-p", "--port", type=int, default=DEFAULT_HTTP_PORT, help="HTTP server port")
    parser.add_argument("--no-wait", action="store_true", help="Don't wait for reboot verification")
    args = parser.parse_args()

    fw_path = os.path.abspath(args.firmware)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)
    fw_size = os.path.getsize(fw_path)
    fw_dir, fw_name = os.path.split(fw_path)
    print(f"Firmware: {fw_path} ({fw_size // 1024} KB)")

    # Resolve device
    ip = resolve(args.host)
    print(f"Device: {args.host} ({ip})")

    # Check device is alive
    try:
        reply = udp_cmd(ip, "STATUS")
    except (socket.timeout, OSError) as e:
        print(f"ERR: device not responding: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"Status: {reply}")

    # URL the device will fetch; quote() keeps odd filenames (spaces etc.) valid.
    local_ip = get_local_ip(ip)
    ota_url = f"http://{local_ip}:{args.port}/{quote(fw_name)}"
    print(f"OTA URL: {ota_url}")

    # Start HTTP server
    try:
        server = serve_firmware(fw_dir, args.port)
    except OSError as e:
        print(f"ERR: cannot start HTTP server on port {args.port}: {e}", file=sys.stderr)
        sys.exit(1)
    print(f"HTTP server on port {args.port}")

    # From here on, every exit path (return, sys.exit, Ctrl-C) stops the server.
    try:
        _send_ota(ip, ota_url)
        ota_sent = time.monotonic()
        if args.no_wait:
            print("OTA started (--no-wait, not verifying)")
            # Keep server alive briefly for download
            time.sleep(NO_WAIT_SERVE_TIME)
            return
        if not _wait_for_reboot(args.host, ota_sent):
            print("ERR: device did not come back after OTA", file=sys.stderr)
            sys.exit(1)
    finally:
        server.shutdown()
        server.server_close()


if __name__ == "__main__":
    main()