#!/usr/bin/env python3
"""Send management commands to ESP32 CSI devices over UDP."""

import socket
import sys
import time

from esp_ctl.auth import get_secret, sign_command

# UDP port used for every lookup and datagram this tool sends to a device.
DEFAULT_PORT = 5501
# Socket timeout, in seconds, applied while waiting for a device reply.
TIMEOUT = 2.0

# Help text printed by main() for -h/--help or when too few arguments are given.
USAGE = """\
Usage: esp-cmd <host> <command> [args...]

Host can be an IP address or mDNS hostname (e.g., amber-maple.local).

Commands:
  STATUS              Query device state (uptime, heap, RSSI, tx_power, rate)
  REBOOT              Restart the ESP32
  IDENTIFY            Blink LED solid for 5 seconds
  RATE <10-100>       Set ping frequency in Hz (saved to NVS)
  POWER <2-20>        Set TX power in dBm (saved to NVS)

Examples:
  esp-cmd amber-maple.local STATUS
  esp-cmd 192.168.129.30 RATE 50
  esp-cmd amber-maple.local IDENTIFY"""


def resolve(host):
    """Translate *host* into an IPv4 address string.

    The lookup is delegated to ``socket.getaddrinfo`` (IPv4, datagram), so
    any name the system resolver understands is accepted; literal IPv4
    addresses pass through unchanged.

    Returns the address from the first result. If the name cannot be
    resolved, an error line is written to stderr and the process exits
    with status 1.
    """
    try:
        candidates = socket.getaddrinfo(
            host, DEFAULT_PORT, socket.AF_INET, socket.SOCK_DGRAM
        )
    except socket.gaierror as e:
        print(f"ERR: cannot resolve {host}: {e}", file=sys.stderr)
        sys.exit(1)

    # Each entry is (family, type, proto, canonname, sockaddr); for IPv4
    # the sockaddr is an (address, port) pair.
    first = candidates[0]
    sockaddr = first[4]
    return sockaddr[0]


def get_uptime(ip):
    """Fetch the device's ``uptime_s`` value for use as the HMAC timestamp.

    Sends a plain (unsigned) ``STATUS`` datagram to *ip* and scans the
    whitespace-separated reply for a ``uptime_s=<int>`` field.

    Returns the parsed integer, or 0 when the device does not answer in
    time, a socket error occurs, the reply cannot be decoded or parsed,
    or no ``uptime_s=`` field is present.
    """
    prefix = "uptime_s="
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
        probe.settimeout(TIMEOUT)
        try:
            probe.sendto(b"STATUS", (ip, DEFAULT_PORT))
            reply, _sender = probe.recvfrom(1500)
            for field in reply.decode().split():
                if field.startswith(prefix):
                    return int(field[len(prefix):])
        except (socket.timeout, OSError, ValueError):
            # Best-effort probe: any failure falls through to the 0 default.
            pass
    return 0


def main():
    """Command-line entry point: send one command to a device, print the reply.

    Arguments are read from ``sys.argv``: the host first, then the command
    and its arguments, which are joined with single spaces.

    Exit status:
        0 -- reply received and printed, or help was requested (-h/--help).
        1 -- host could not be resolved, no reply within ``TIMEOUT``, or a
             socket/OS error occurred.
        2 -- too few arguments (usage text is printed).
    """
    args = sys.argv[1:]
    wants_help = bool(args) and args[0] in ("-h", "--help")
    if wants_help or len(args) < 2:
        print(USAGE)
        sys.exit(0 if wants_help else 2)

    host = args[0]
    ip = resolve(host)

    # Only probe the device for its uptime when a secret is configured;
    # without one the command is passed to sign_command with uptime 0.
    secret = get_secret()
    uptime = get_uptime(ip) if secret else 0
    if secret and uptime:
        time.sleep(0.06)  # avoid firmware rate limiter (50ms)
    cmd = sign_command(" ".join(args[1:]).strip(), uptime, secret)

    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
        sock.settimeout(TIMEOUT)
        try:
            sock.sendto(cmd.encode(), (ip, DEFAULT_PORT))
            # Same buffer size get_uptime() uses for replies: a datagram
            # larger than the buffer would otherwise lose its tail.
            data, _ = sock.recvfrom(1500)
            # Replace undecodable bytes instead of crashing with an
            # uncaught UnicodeDecodeError on a malformed reply.
            print(data.decode(errors="replace").strip())
        except socket.timeout:
            print(f"ERR: no reply from {host} ({ip}:{DEFAULT_PORT}), timeout {TIMEOUT}s", file=sys.stderr)
            sys.exit(1)
        except OSError as e:
            print(f"ERR: {e}", file=sys.stderr)
            sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
