#!/usr/bin/env python3
"""Audio diagnostics — verify devices, codec, and loopback.

Usage:
  audio-diag                   Run all checks
  audio-diag --devices         List audio devices only
  audio-diag --codec           Test Opus encode/decode roundtrip
  audio-diag --loopback [SEC]  Capture mic -> encode -> decode -> playback
                               (SEC defaults to 3 seconds)
  audio-diag --version         Show version
  audio-diag --help            Show this help
"""

from __future__ import annotations

import argparse
import struct
import sys
import time


# ANSI escape sequences for terminal colouring.
RST = "\033[0m"  # reset all attributes
DIM = "\033[2m"  # dim / faint intensity
# 256-colour palette foregrounds (ESC[38;5;<n>m).
GRN = "\033[38;5;108m"
RED = "\033[38;5;131m"
YEL = "\033[38;5;179m"
CYN = "\033[38;5;109m"

# Audio format shared by every check: 48 kHz, mono, signed 16-bit PCM,
# handled in frames of 960 samples (960 / 48000 s = 20 ms).
RATE = 48000
CHANNELS = 1
FRAME = 960
DTYPE = "int16"


def _status_line(colour: str, glyph: str, msg: str) -> None:
    """Print an indented status line: a coloured *glyph*, a space, *msg*."""
    print("  {}{}{} {}".format(colour, glyph, RST, msg))


def ok(msg: str) -> None:
    """Report *msg* as a passed check (green check mark)."""
    _status_line(GRN, "\u2713", msg)


def fail(msg: str) -> None:
    """Report *msg* as a failed check (red cross)."""
    _status_line(RED, "\u2717", msg)


def warn(msg: str) -> None:
    """Report *msg* as a warning (yellow warning sign)."""
    _status_line(YEL, "\u26a0", msg)


def info(msg: str) -> None:
    """Print *msg* as indented, dimmed supplementary text."""
    print("  {}{}{}".format(DIM, msg, RST))


def heading(msg: str) -> None:
    """Print *msg* in cyan as a section title, preceded by a blank line."""
    print("\n{}{}{}".format(CYN, msg, RST))


# -- checks ------------------------------------------------------------------


def check_portaudio() -> bool:
    """Report whether the ``sounddevice`` binding and PortAudio can be loaded.

    Distinguishes a missing Python package from a missing PortAudio C
    library and prints an install hint for each.

    Returns:
        True if ``import sounddevice`` succeeds, False otherwise.
    """
    heading("PortAudio")
    try:
        import sounddevice as sd  # noqa: F401
    except ImportError as e:
        # The Python package itself is absent (ModuleNotFoundError is an
        # ImportError subclass, not an OSError, so it needs its own branch).
        fail(f"sounddevice not installed: {e}")
        info("install: pip install sounddevice")
        return False
    except OSError as e:
        # sounddevice is importable but cannot locate the PortAudio library.
        fail(f"not found: {e}")
        info("install: apt install libportaudio2")
        return False

    ok("library loaded")
    return True


def check_devices() -> bool:
    """Print a numbered table of audio devices and check both directions.

    Each row shows the device index, name and input/output channel counts,
    tagged ``default-in`` / ``default-out`` when it is a current default.

    Returns:
        True when at least one device has input channels and at least one
        has output channels; False otherwise (including when no devices
        are reported at all).
    """
    import sounddevice as sd

    heading("Audio Devices")
    device_list = sd.query_devices()
    if not device_list:
        fail("no devices found")
        return False

    # (label, device index) pairs for the current default devices.
    defaults = (
        ("default-in", sd.default.device[0]),
        ("default-out", sd.default.device[1]),
    )

    for index, dev in enumerate(device_list):
        labels = [label for label, dflt in defaults if index == dflt]
        suffix = ""
        if labels:
            suffix = f" {YEL}({', '.join(labels)}){RST}"
        n_in = dev["max_input_channels"]
        n_out = dev["max_output_channels"]
        print(
            f"  {DIM}{index:>2}{RST} {dev['name']}"
            f"  {DIM}in={n_in} out={n_out}{RST}{suffix}"
        )

    has_input = any(dev["max_input_channels"] > 0 for dev in device_list)
    has_output = any(dev["max_output_channels"] > 0 for dev in device_list)

    if not has_input:
        warn("no input devices")
    if not has_output:
        warn("no output devices")
    usable = has_input and has_output
    if usable:
        ok(f"{len(device_list)} devices, input+output available")
    return usable


def check_opus() -> bool:
    """Verify opuslib loads and an Opus encode/decode roundtrip works.

    Encodes one FRAME (20 ms) of a 440 Hz sine wave, decodes it again and
    compares the RMS level of input and output. A large RMS drop is only
    reported as a warning.

    Returns:
        False if opuslib or the native libopus cannot be loaded, if the
        codec raises ``opuslib.OpusError``, or if the decoder returns no
        samples; True otherwise.
    """
    import math

    heading("Opus Codec")
    try:
        import opuslib
    except ImportError:
        fail("opuslib not installed")
        info("install: pip install opuslib")
        return False
    except Exception as e:  # noqa: BLE001
        # opuslib raises a plain Exception (not ImportError) at import time
        # when the native libopus shared library cannot be located.
        fail(f"libopus not found: {e}")
        info("install: apt install libopus0")
        return False

    ok("opuslib loaded")

    # 20 ms of 440 Hz sine as int16, each sample repeated once per channel
    # (interleaved), built in one pack call instead of quadratic bytes +=.
    tone = [
        int(16000 * math.sin(2 * math.pi * 440 * t / RATE))
        for t in range(FRAME)
    ]
    samples_in = [s for s in tone for _ in range(CHANNELS)]
    pcm = struct.pack(f"<{len(samples_in)}h", *samples_in)

    try:
        encoder = opuslib.Encoder(RATE, CHANNELS, opuslib.APPLICATION_VOIP)
        decoder = opuslib.Decoder(RATE, CHANNELS)
        encoded = encoder.encode(pcm, FRAME)
        decoded = decoder.decode(encoded, FRAME)
    except opuslib.OpusError as e:
        fail(f"roundtrip failed: {e}")
        return False

    ok(f"encode: {len(pcm)} bytes -> {len(encoded)} bytes")
    ok(f"decode: {len(encoded)} bytes -> {len(decoded)} bytes")

    # Verify the decoded signal is close to the original (not silent).
    # Size the unpack from the actual output rather than assuming FRAME.
    n_out = len(decoded) // 2
    if n_out == 0:
        fail("decoder returned no samples")
        return False
    samples_out = struct.unpack(f"<{n_out}h", decoded[: n_out * 2])
    rms_in = (sum(s * s for s in samples_in) / len(samples_in)) ** 0.5
    rms_out = (sum(s * s for s in samples_out) / n_out) ** 0.5

    if rms_out < rms_in * 0.5:
        warn(f"decoded RMS ({rms_out:.0f}) much lower than input ({rms_in:.0f})")
    else:
        ok(f"roundtrip RMS: {rms_in:.0f} -> {rms_out:.0f}")

    return True


def check_streams() -> bool:
    """Open, briefly start, and close one raw input and one raw output stream.

    Each stream uses the shared RATE/CHANNELS/DTYPE/FRAME settings, runs for
    about 50 ms, and is then stopped and closed. Any exception is reported
    per direction rather than propagated.

    Returns:
        True if both the input and the output stream opened and ran.
    """
    import sounddevice as sd

    heading("Stream Open/Close")
    errors = []

    for label, stream_cls in (
        ("input", sd.RawInputStream),
        ("output", sd.RawOutputStream),
    ):
        try:
            stream = stream_cls(
                samplerate=RATE, channels=CHANNELS,
                dtype=DTYPE, blocksize=FRAME,
            )
            try:
                stream.start()
                time.sleep(0.05)
                stream.stop()
            finally:
                # Release the PortAudio stream even if start()/stop() raised.
                stream.close()
        except Exception as e:  # diagnostic boundary: report any failure
            fail(f"{label} stream: {e}")
            errors.append(label)
        else:
            ok(f"{label} stream")

    return not errors


def run_loopback(seconds: float) -> None:
    """Capture -> Opus encode -> decode -> playback in real time.

    The input stream's callback encodes and immediately decodes each
    captured frame and queues the PCM. The output stream's callback plays
    queued frames, or silence when the queue is empty. The test runs for
    *seconds* (Ctrl-C ends it early) and then prints frame counters.

    Problems are reported with ``fail`` and the function returns early
    without raising. These are a non-positive duration, a missing
    opuslib/libopus, and a PortAudio error while opening or starting the
    streams.

    Args:
        seconds: How long to run the loopback, in seconds; must be > 0.
    """
    import queue

    import sounddevice as sd

    heading(f"Loopback Test ({seconds:.1f}s)")
    if not seconds > 0:  # also rejects NaN
        # time.sleep() raises ValueError for negative durations.
        fail(f"duration must be positive, got {seconds}")
        return

    try:
        import opuslib
    except ImportError:
        fail("opuslib not installed")
        info("install: pip install opuslib")
        return
    except Exception as e:  # noqa: BLE001
        # opuslib raises a plain Exception at import when libopus is missing.
        fail(f"libopus not found: {e}")
        info("install: apt install libopus0")
        return

    info("mic -> opus encode -> opus decode -> speakers")

    encoder = opuslib.Encoder(RATE, CHANNELS, opuslib.APPLICATION_VOIP)
    decoder = opuslib.Decoder(RATE, CHANNELS)

    # Bounded queue: frames that do not fit are counted as dropped.
    buf: queue.Queue[bytes] = queue.Queue(maxsize=100)
    stats = {"captured": 0, "played": 0, "dropped": 0, "errors": 0}

    def capture_cb(indata, frames, time_info, status):
        try:
            pcm = decoder.decode(encoder.encode(bytes(indata), FRAME), FRAME)
        except opuslib.OpusError:
            # Count codec failures instead of letting an exception escape
            # the audio callback on every frame.
            stats["errors"] += 1
            return
        try:
            buf.put_nowait(pcm)
        except queue.Full:
            stats["dropped"] += 1
        else:
            stats["captured"] += 1

    def playback_cb(outdata, frames, time_info, status):
        try:
            pcm = buf.get_nowait()
            n = min(len(pcm), len(outdata))
            outdata[:n] = pcm[:n]
            if n < len(outdata):
                # Pad a short frame with silence.
                outdata[n:] = b"\x00" * (len(outdata) - n)
            stats["played"] += 1
        except queue.Empty:
            # Nothing captured yet: play silence.
            outdata[:] = b"\x00" * len(outdata)

    inp = out = None
    try:
        inp = sd.RawInputStream(
            samplerate=RATE, channels=CHANNELS,
            dtype=DTYPE, blocksize=FRAME, callback=capture_cb,
        )
        out = sd.RawOutputStream(
            samplerate=RATE, channels=CHANNELS,
            dtype=DTYPE, blocksize=FRAME, callback=playback_cb,
        )
        out.start()
        inp.start()
        info("speak into mic now...")

        started = time.monotonic()
        try:
            time.sleep(seconds)
        except KeyboardInterrupt:
            pass  # Ctrl-C just ends the test early
        elapsed = time.monotonic() - started

        inp.stop()
        out.stop()
    except (sd.PortAudioError, ValueError) as e:
        fail(f"stream error: {e}")
        return
    finally:
        # Always release PortAudio resources, even after an error.
        for stream in (inp, out):
            if stream is not None:
                stream.close()

    # Use the measured run time so an early Ctrl-C does not skew the estimate.
    expected = int(elapsed * RATE / FRAME)
    ok(f"captured={stats['captured']} played={stats['played']} "
       f"dropped={stats['dropped']} (expected ~{expected})")
    if stats["errors"]:
        warn(f"{stats['errors']} frames failed to encode/decode")


# -- main --------------------------------------------------------------------


def main() -> int:
    """Parse the command line and run the requested diagnostics.

    With no mode flag, every check runs in sequence and a summary is
    printed. ``--devices``, ``--codec`` and ``--loopback`` each run a
    single check and are mutually exclusive.

    Returns:
        Process exit status: 0 when the requested check(s) passed, else 1.
        Invalid arguments make ``argparse`` exit with status 2.
    """
    parser = argparse.ArgumentParser(
        description="Audio pipeline diagnostics",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # At most one single-check mode per run, so no flag is silently ignored.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--devices", action="store_true", help="list audio devices",
    )
    mode.add_argument(
        "--codec", action="store_true", help="test opus roundtrip",
    )
    mode.add_argument(
        "--loopback", nargs="?", const=3.0, type=float, metavar="SEC",
        help="mic->encode->decode->speakers (default: 3s)",
    )
    parser.add_argument(
        "--version", action="version", version="audio-diag 0.1.0",
    )
    args = parser.parse_args()

    # time.sleep() rejects negative durations; `not x > 0` also catches NaN.
    if args.loopback is not None and not args.loopback > 0:
        parser.error("--loopback SEC must be a positive number")

    # specific mode
    if args.devices:
        if not check_portaudio():
            return 1
        # Exit status reflects the device check, as it does for --codec.
        return 0 if check_devices() else 1

    if args.codec:
        return 0 if check_opus() else 1

    if args.loopback is not None:
        if not check_portaudio():
            return 1
        run_loopback(args.loopback)
        return 0

    # full diagnostic
    print(f"{CYN}tuimble audio diagnostics{RST}")

    # PortAudio is a prerequisite for every remaining check.
    if not check_portaudio():
        print(f"\n{RED}cannot continue without PortAudio{RST}")
        return 1

    remaining = (check_devices, check_opus, check_streams)
    passed = 1  # the PortAudio check above
    for check in remaining:
        if check():
            passed += 1
    total = 1 + len(remaining)

    heading("Summary")
    color = GRN if passed == total else YEL
    print(f"  {color}{passed}/{total} checks passed{RST}")
    return 0 if passed == total else 1


if __name__ == "__main__":
    # main() returns the exit status; SystemExit hands it to the interpreter.
    raise SystemExit(main())
