- Add serial_task: UART console for AUTH management with physical access. AUTH shows the full secret, AUTH <secret> sets it, AUTH OFF clears it.
- Add esp-provision tool: provision the auth secret via serial or NVS flash. Supports auto-generated and custom secrets, plus --serial and --generate-only.
- Fix esp-ota uptime cache: avoid the firmware rate limiter on consecutive udp_cmd calls by caching uptime_s for 3s.
188 lines
5.8 KiB
Python
Executable File
188 lines
5.8 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""OTA firmware update for ESP32 CSI devices."""
|
|
|
|
import argparse
|
|
import http.server
|
|
import os
|
|
import socket
|
|
import sys
|
|
import threading
|
|
import time
|
|
|
|
from esp_ctl.auth import get_secret, sign_command
|
|
|
|
# UDP port the device listens on for text commands (STATUS, OTA, ...).
DEFAULT_CMD_PORT = 5501

# Default TCP port for the temporary HTTP server that serves the firmware.
DEFAULT_HTTP_PORT = 8070

# Firmware image used when -f/--firmware is not given.
DEFAULT_FW = os.path.expanduser(
    "~/git/esp32-hacking/get-started/csi_recv_router/build/csi_recv_router.bin"
)

# Default socket timeout, in seconds, for a single UDP request/reply.
TIMEOUT = 2.0

# Seconds to sleep after a accepted OTA command before polling the device.
REBOOT_WAIT = 25.0

# Number of STATUS polls attempted after the reboot wait ...
STATUS_RETRIES = 10

# ... and the pause, in seconds, between two failed polls.
STATUS_INTERVAL = 3.0
|
|
|
|
|
|
def resolve(host: str) -> str:
    """Return the IPv4 address of *host*, terminating the script on failure.

    The lookup is restricted to IPv4/UDP results for the command port. If
    the name cannot be resolved, an ``ERR:`` line is written to stderr and
    the process exits with status 1 (no exception reaches the caller).
    """
    try:
        candidates = socket.getaddrinfo(
            host, DEFAULT_CMD_PORT, socket.AF_INET, socket.SOCK_DGRAM
        )
    except socket.gaierror as err:
        print(f"ERR: cannot resolve {host}: {err}", file=sys.stderr)
        sys.exit(1)

    # Each entry is (family, type, proto, canonname, sockaddr); for AF_INET
    # the sockaddr is (address, port). Only the first candidate is used.
    _family, _socktype, _proto, _canonname, sockaddr = candidates[0]
    return sockaddr[0]
|
|
|
|
|
|
# Last uptime reading: which device it came from, the reported value, and
# the time.monotonic() timestamp taken when the query was started.
_uptime_cache = {"ip": None, "value": 0, "time": 0}


def get_uptime(ip: str, timeout: float = TIMEOUT) -> int:
    """Return the device's ``uptime_s`` as used for the HMAC timestamp.

    The value is obtained with an unauthenticated ``STATUS`` query. A reading
    for the same *ip* that is less than 3 seconds old is reused (advanced by
    the whole seconds elapsed since it was taken) so that several ``udp_cmd``
    calls in quick succession do not trip the firmware rate limiter.

    Returns 0 when the device does not answer, the reply carries no
    ``uptime_s=`` field, or the reply cannot be parsed.
    """
    started = time.monotonic()
    age = started - _uptime_cache["time"]
    if _uptime_cache["ip"] == ip and age < 3:
        return _uptime_cache["value"] + int(age)

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        try:
            sock.sendto(b"STATUS", (ip, DEFAULT_CMD_PORT))
            payload, _ = sock.recvfrom(1500)
            # The reply is a whitespace-separated list of key=value fields.
            for field in payload.decode().split():
                key, sep, raw = field.partition("=")
                if key == "uptime_s" and sep:
                    uptime = int(raw)
                    _uptime_cache.update(ip=ip, value=uptime, time=started)
                    return uptime
        except (socket.timeout, OSError, ValueError):
            # Best effort: any network or parse failure means "unknown".
            pass
    return 0
|
|
|
|
|
|
def udp_cmd(ip: str, cmd: str, timeout: float = TIMEOUT) -> str:
    """Send a UDP command to the device and return its reply.

    If ``get_secret()`` yields a secret and the device uptime could be
    obtained via ``get_uptime()``, the command is passed through
    ``sign_command(cmd, uptime, secret)`` before being sent; otherwise it is
    sent as given.

    Args:
        ip: Device IPv4 address.
        cmd: Command text, e.g. ``"STATUS"``.
        timeout: Socket timeout in seconds, used for both the uptime query
            and the command itself.

    Returns:
        The reply decoded as text with surrounding whitespace stripped.
        Bytes that are not valid UTF-8 are replaced with U+FFFD instead of
        raising, because callers only handle ``socket.timeout``/``OSError``.

    Raises:
        socket.timeout: No reply arrived within *timeout*.
        OSError: Any other socket-level failure.
    """
    secret = get_secret()
    uptime = get_uptime(ip, timeout) if secret else 0
    if secret and uptime:
        time.sleep(0.06)  # avoid firmware rate limiter (50ms)
        cmd = sign_command(cmd, uptime, secret)

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(timeout)
        sock.sendto(cmd.encode(), (ip, DEFAULT_CMD_PORT))
        data, _ = sock.recvfrom(512)

    # errors="replace": a garbled datagram must not surface as an uncaught
    # UnicodeDecodeError in callers that expect only socket errors.
    return data.decode(errors="replace").strip()
|
|
|
|
|
|
def get_local_ip(target_ip: str) -> str:
    """Return the local address the OS would use to reach *target_ip*.

    A UDP socket is connected to the target (port 80) purely so the local
    end of the association can be read back with ``getsockname()``; the
    socket is closed before returning.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.connect((target_ip, 80))
        local_addr = probe.getsockname()
    return local_addr[0]
|
|
|
|
|
|
def serve_firmware(directory: str, port: int) -> http.server.HTTPServer:
    """Serve *directory* over HTTP from a daemon thread and return the server.

    The server binds to all interfaces on *port*. The caller is responsible
    for stopping it (e.g. with ``shutdown()``) once the download is done.
    """

    def make_handler(*args, **kwargs):
        # Pin the served directory; the server supplies the other arguments.
        return http.server.SimpleHTTPRequestHandler(
            *args, directory=directory, **kwargs
        )

    httpd = http.server.HTTPServer(("0.0.0.0", port), make_handler)
    worker = threading.Thread(target=httpd.serve_forever, daemon=True)
    worker.start()
    return httpd
|
|
|
|
|
|
def main():
    """Run an OTA update from the command line.

    Steps: validate the firmware file, resolve the device and check that it
    answers ``STATUS``, start a local HTTP server for the firmware directory,
    send ``OTA <url>`` to the device, then (unless ``--no-wait`` is given)
    wait for the reboot and poll ``STATUS`` until the device answers again.

    Exits with status 1 on any failure. The HTTP server is always shut down
    and closed before the function returns or exits.
    """
    parser = argparse.ArgumentParser(description="OTA firmware update for ESP32 CSI devices")
    parser.add_argument("host", help="Device hostname or IP (e.g., amber-maple.local)")
    parser.add_argument("-f", "--firmware", default=DEFAULT_FW, help="Path to firmware .bin")
    parser.add_argument("-p", "--port", type=int, default=DEFAULT_HTTP_PORT, help="HTTP server port")
    parser.add_argument("--no-wait", action="store_true", help="Don't wait for reboot verification")
    args = parser.parse_args()

    fw_path = os.path.abspath(args.firmware)
    if not os.path.isfile(fw_path):
        print(f"ERR: firmware not found: {fw_path}", file=sys.stderr)
        sys.exit(1)

    fw_size = os.path.getsize(fw_path)
    fw_dir = os.path.dirname(fw_path)
    fw_name = os.path.basename(fw_path)

    print(f"Firmware: {fw_path} ({fw_size // 1024} KB)")

    # Resolve device (exits on failure: nothing has been started yet).
    ip = resolve(args.host)
    print(f"Device: {args.host} ({ip})")

    # Check device is alive
    try:
        reply = udp_cmd(ip, "STATUS")
        print(f"Status: {reply}")
    except (socket.timeout, OSError) as e:
        print(f"ERR: device not responding: {e}", file=sys.stderr)
        sys.exit(1)

    # Determine the local address the device must download from.
    local_ip = get_local_ip(ip)
    ota_url = f"http://{local_ip}:{args.port}/{fw_name}"
    print(f"OTA URL: {ota_url}")

    # Start HTTP server
    server = serve_firmware(fw_dir, args.port)
    print(f"HTTP server on port {args.port}")

    # Everything below runs with the server up; the finally clause tears it
    # down on every exit path, including sys.exit().
    try:
        # Send OTA command
        try:
            reply = udp_cmd(ip, f"OTA {ota_url}")
            print(f"OTA cmd: {reply}")
        except (socket.timeout, OSError) as e:
            print(f"ERR: OTA command failed: {e}", file=sys.stderr)
            sys.exit(1)

        if not reply.startswith("OK"):
            print(f"ERR: device rejected OTA: {reply}", file=sys.stderr)
            sys.exit(1)

        if args.no_wait:
            print("OTA started (--no-wait, not verifying)")
            # Keep server alive briefly for download
            time.sleep(30)
            return

        # Wait for device to download, flash, and reboot
        print("Waiting for reboot...")
        time.sleep(REBOOT_WAIT)

        # Verify device comes back
        for attempt in range(1, STATUS_RETRIES + 1):
            try:
                # Re-resolve on every attempt, but with a lookup that raises
                # instead of exiting: a name that is briefly unresolvable
                # while the device reboots must count as a retry, not abort.
                ip = _lookup_ipv4(args.host)
                reply = udp_cmd(ip, "STATUS", timeout=3.0)
            except (socket.timeout, OSError):
                print(f" retry {attempt}/{STATUS_RETRIES}...")
                time.sleep(STATUS_INTERVAL)
            else:
                print(f"Verified: {reply}")
                return

        print("ERR: device did not come back after OTA", file=sys.stderr)
        sys.exit(1)
    finally:
        server.shutdown()
        server.server_close()


def _lookup_ipv4(host: str) -> str:
    """Resolve *host* to an IPv4 address, raising ``OSError`` on failure.

    ``socket.gaierror`` is a subclass of ``OSError``, so callers that already
    handle socket errors can treat a failed lookup as retryable.
    """
    info = socket.getaddrinfo(host, DEFAULT_CMD_PORT, socket.AF_INET, socket.SOCK_DGRAM)
    return info[0][4][0]
|
|
|
|
|
|
# Run the updater only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|