#!/usr/bin/env python3
"""Provision auth secret to ESP32 CSI device NVS via USB serial."""

import argparse
import csv
import os
import secrets
import subprocess
import sys
import tempfile

# Path to ESP-IDF's NVS partition generator script, resolved under the
# current user's home directory.
NVS_GEN = os.path.expanduser(
    "~/esp/esp-idf/components/nvs_flash/nvs_partition_generator/nvs_partition_gen.py"
)
# Namespace row written to the NVS CSV ahead of the auth_secret entry.
NVS_NAMESPACE = "csi_config"
# Flash address handed to esptool's write_flash for the generated image.
# NOTE(review): presumably must match the firmware's partition table — confirm.
NVS_OFFSET = "0x9000"
# Partition size argument handed to the NVS generator.
# NOTE(review): presumably must match the firmware's partition table — confirm.
NVS_SIZE = "0x4000"
# Default value of the --port option.
DEFAULT_PORT = "/dev/ttyUSB0"
# Default value of the --baud option (flash baud rate); kept as a string
# because it is passed straight through as a command-line argument.
DEFAULT_BAUD = "460800"


def generate_secret() -> str:
    """Return a new random secret: 16 random bytes rendered as 32 hex chars."""
    random_bytes = secrets.token_bytes(16)
    return random_bytes.hex()


def create_nvs_csv(secret: str, path: str) -> None:
    """Write the NVS CSV that stores ``secret`` under the ``auth_secret`` key.

    The file has a header row, a namespace row for ``NVS_NAMESPACE`` and a
    single string entry named ``auth_secret``.

    Args:
        secret: Value to store. Written through the ``csv`` module so that
            commas, quotes or line breaks in the secret are quoted instead
            of corrupting the row.
        path: Destination file; created or overwritten.
    """
    rows = (
        ("key", "type", "encoding", "value"),
        (NVS_NAMESPACE, "namespace", "", ""),
        ("auth_secret", "data", "string", secret),
    )
    # newline="" lets the csv writer control line endings; "\n" keeps the
    # output identical to the previous hand-written format for plain secrets.
    with open(path, "w", encoding="utf-8", newline="") as f:
        csv.writer(f, lineterminator="\n").writerows(rows)


def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the provisioning tool."""
    parser = argparse.ArgumentParser(
        description="Provision auth secret to ESP32 NVS partition"
    )
    parser.add_argument(
        "secret", nargs="?", default=None,
        help="Auth secret (8-64 chars). Auto-generated if omitted."
    )
    parser.add_argument(
        "-p", "--port", default=DEFAULT_PORT,
        help=f"USB serial port (default: {DEFAULT_PORT})"
    )
    parser.add_argument(
        "-b", "--baud", default=DEFAULT_BAUD,
        help=f"Flash baud rate (default: {DEFAULT_BAUD})"
    )
    parser.add_argument(
        "--serial", action="store_true",
        help="Set secret via serial console instead of NVS flash"
    )
    parser.add_argument(
        "--generate-only", action="store_true",
        help="Generate and print a secret without flashing"
    )
    return parser


def _run_tool(cmd: list, failure_label: str) -> bool:
    """Run an external command quietly; report its output if it fails.

    Args:
        cmd: Argument vector passed to ``subprocess.run`` (no shell).
        failure_label: Text placed after ``ERR:`` when the command fails.

    Returns:
        True when the command exits with status 0, False otherwise.
    """
    result = subprocess.run(cmd, capture_output=True, text=True)
    if result.returncode == 0:
        return True
    # Report both streams: a tool may print its failure reason on stdout,
    # in which case showing stderr alone would yield an empty message.
    details = "\n".join(
        stream.strip()
        for stream in (result.stdout, result.stderr)
        if stream and stream.strip()
    )
    print(f"ERR: {failure_label}:\n{details}", file=sys.stderr)
    return False


def _provision_via_serial(secret: str, port: str) -> int:
    """Send ``AUTH <secret>`` over the serial console and check the reply.

    The device must already be running. The console is opened at 921600
    baud, independent of the ``--baud`` flash option.

    Returns:
        Process exit status: 0 when a reply line contains ``OK AUTH on``,
        1 on a missing pyserial, a serial error, or any other reply.
    """
    try:
        import serial
    except ImportError:
        print("ERR: pyserial required (pip install pyserial)", file=sys.stderr)
        return 1

    import time

    try:
        # Context manager guarantees the port is closed even if I/O raises.
        with serial.Serial(port, 921600, timeout=2) as conn:
            conn.reset_input_buffer()
            conn.write(f"AUTH {secret}\n".encode())
            # Give the device a moment to answer before draining the buffer.
            time.sleep(0.5)
            raw = conn.read(conn.in_waiting or 256)
    except (serial.SerialException, OSError) as exc:
        print(f"ERR: serial communication failed on {port}: {exc}", file=sys.stderr)
        return 1

    out = raw.decode("utf-8", errors="replace")
    if any("OK AUTH on" in line for line in out.splitlines()):
        print(f"Secret set: {secret}")
        return 0

    print(f"ERR: unexpected response: {out.strip()}", file=sys.stderr)
    return 1


def _provision_via_nvs(secret: str, port: str, baud: str) -> int:
    """Generate an NVS partition image holding ``secret`` and flash it.

    The CSV and binary live in a temporary directory that is removed before
    this function returns.

    Returns:
        Process exit status: 0 on success, 1 if the generator script is
        missing or either external tool fails.
    """
    if not os.path.isfile(NVS_GEN):
        print(f"ERR: NVS generator not found: {NVS_GEN}", file=sys.stderr)
        print("  Ensure ESP-IDF is installed at ~/esp/esp-idf/", file=sys.stderr)
        return 1

    with tempfile.TemporaryDirectory() as tmpdir:
        csv_path = os.path.join(tmpdir, "nvs.csv")
        bin_path = os.path.join(tmpdir, "nvs.bin")

        # Generate NVS partition binary
        create_nvs_csv(secret, csv_path)
        generate_cmd = [
            sys.executable, NVS_GEN, "generate", csv_path, bin_path, NVS_SIZE,
        ]
        if not _run_tool(generate_cmd, "NVS generation failed"):
            return 1

        # Flash NVS partition
        print(f"Flashing auth secret to NVS at {NVS_OFFSET}...")
        flash_cmd = [
            sys.executable, "-m", "esptool",
            "--port", port, "--baud", baud,
            "write_flash", NVS_OFFSET, bin_path,
        ]
        if not _run_tool(flash_cmd, "flash failed"):
            return 1

    print(f"Secret provisioned: {secret}")
    print(f"  Add to environment: export ESP_CMD_SECRET=\"{secret}\"")
    return 0


def main():
    """Parse arguments, pick or generate the secret, and provision it.

    Exits with status 1 on any validation or provisioning failure. The
    ``--generate-only`` and ``--serial`` paths always end in ``sys.exit``;
    a successful NVS flash returns normally.
    """
    args = _build_parser().parse_args()

    # Generate or validate secret. Test against None rather than truthiness
    # so an explicitly empty secret (e.g. an unset shell variable) is
    # rejected by the length check instead of being silently replaced.
    secret = generate_secret() if args.secret is None else args.secret
    if not 8 <= len(secret) <= 64:
        print("ERR: secret must be 8-64 characters", file=sys.stderr)
        sys.exit(1)

    if args.generate_only:
        print(secret)
        sys.exit(0)

    if args.serial:
        # Set secret via serial console (device must be running)
        sys.exit(_provision_via_serial(secret, args.port))

    # NVS flash method
    status = _provision_via_nvs(secret, args.port, args.baud)
    if status != 0:
        sys.exit(status)


if __name__ == "__main__":
    main()
