refactor: switch Mumble voice to pymumble transport
asyncio's SSL memory-BIO transport silently drops voice packets even though text works fine. pymumble uses blocking ssl.SSLSocket.send() which reliably delivers voice data. - Rewrite MumbleBot to use pymumble for connection, SSL, ping, and voice encoding/sending - Bridge pymumble thread callbacks to asyncio via run_coroutine_threadsafe for text dispatch - Voice via sound_output.add_sound(pcm) -- pymumble handles Opus encoding, packetization, and timing - Remove custom protobuf codec, voice varint, and opus ctypes wrapper - Add container patches for pymumble ssl.wrap_socket (Python 3.13) and opuslib find_library (musl/Alpine) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,17 @@
|
||||
FROM python:3.13-alpine
|
||||
|
||||
RUN apk add --no-cache opus ffmpeg yt-dlp
|
||||
RUN apk add --no-cache opus ffmpeg yt-dlp && \
|
||||
ln -s /usr/lib/libopus.so.0 /usr/lib/libopus.so
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Patch pymumble for Python 3.13 (ssl.wrap_socket was removed)
|
||||
COPY patches/apply_pymumble_ssl.py /tmp/apply_pymumble_ssl.py
|
||||
RUN python3 /tmp/apply_pymumble_ssl.py && rm /tmp/apply_pymumble_ssl.py
|
||||
|
||||
ENV PYTHONPATH=/app/src
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENTRYPOINT ["python", "-m", "derp"]
|
||||
|
||||
45
patches/apply_pymumble_ssl.py
Normal file
45
patches/apply_pymumble_ssl.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Patch pymumble deps for Python 3.13+ / musl (Alpine).
|
||||
|
||||
1. pymumble: ssl.wrap_socket was removed in 3.13
|
||||
2. opuslib: ctypes.util.find_library fails on musl-based distros
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
import sysconfig
|
||||
|
||||
site = sysconfig.get_path("purelib")
|
||||
|
||||
# -- pymumble: replace ssl.wrap_socket with SSLContext --
|
||||
p = pathlib.Path(f"{site}/pymumble_py3/mumble.py")
|
||||
src = p.read_text()
|
||||
|
||||
old = """\
|
||||
try:
|
||||
self.control_socket = ssl.wrap_socket(std_sock, certfile=self.certfile, keyfile=self.keyfile, ssl_version=ssl.PROTOCOL_TLS)
|
||||
except AttributeError:
|
||||
self.control_socket = ssl.wrap_socket(std_sock, certfile=self.certfile, keyfile=self.keyfile, ssl_version=ssl.PROTOCOL_TLSv1)
|
||||
try:"""
|
||||
|
||||
new = """\
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
if self.certfile:
|
||||
ctx.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile)
|
||||
self.control_socket = ctx.wrap_socket(std_sock, server_hostname=self.host)
|
||||
try:"""
|
||||
|
||||
assert old in src, "pymumble ssl patch target not found"
|
||||
p.write_text(src.replace(old, new))
|
||||
print("pymumble ssl patch applied")
|
||||
|
||||
# -- opuslib: find_library fails on musl, use direct CDLL fallback --
|
||||
p = pathlib.Path(f"{site}/opuslib/api/__init__.py")
|
||||
src = p.read_text()
|
||||
|
||||
old_opus = "lib_location = find_library('opus')"
|
||||
new_opus = "lib_location = find_library('opus') or 'libopus.so.0'"
|
||||
|
||||
assert old_opus in src, "opuslib find_library patch target not found"
|
||||
p.write_text(src.replace(old_opus, new_opus))
|
||||
print("opuslib musl patch applied")
|
||||
@@ -252,6 +252,17 @@ async def cmd_np(bot, message):
|
||||
)
|
||||
|
||||
|
||||
@command("testtone", help="Music: !testtone -- debug sine wave")
|
||||
async def cmd_testtone(bot, message):
|
||||
"""Send a 3-second test tone for voice debugging."""
|
||||
if not _is_mumble(bot):
|
||||
await bot.reply(message, "Mumble-only feature")
|
||||
return
|
||||
await bot.reply(message, "Sending 440Hz test tone (3s)...")
|
||||
await bot.test_tone(3.0)
|
||||
await bot.reply(message, "Test tone complete")
|
||||
|
||||
|
||||
@command("volume", help="Music: !volume [0-100]")
|
||||
async def cmd_volume(bot, message):
|
||||
"""Get or set playback volume.
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
maxminddb>=2.0
|
||||
pymumble>=1.6
|
||||
PySocks>=1.7.1
|
||||
urllib3[socks]>=2.0
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""Mumble adapter: TLS/TCP over SOCKS5, protobuf control channel + voice."""
|
||||
"""Mumble adapter: pymumble transport with asyncio plugin dispatch."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -6,15 +6,19 @@ import array
|
||||
import asyncio
|
||||
import html
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import ssl
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from derp import http
|
||||
import pymumble_py3 as pymumble
|
||||
from pymumble_py3.constants import (
|
||||
PYMUMBLE_CLBK_CONNECTED,
|
||||
PYMUMBLE_CLBK_DISCONNECTED,
|
||||
PYMUMBLE_CLBK_TEXTMESSAGERECEIVED,
|
||||
)
|
||||
|
||||
from derp.bot import _TokenBucket
|
||||
from derp.plugin import TIERS, PluginRegistry
|
||||
from derp.state import StateStore
|
||||
@@ -23,117 +27,10 @@ log = logging.getLogger(__name__)
|
||||
|
||||
_AMBIGUOUS = object() # sentinel for ambiguous prefix matches
|
||||
|
||||
# -- Mumble message types ----------------------------------------------------
|
||||
|
||||
MSG_VERSION = 0
|
||||
MSG_UDPTUNNEL = 1
|
||||
MSG_AUTHENTICATE = 2
|
||||
MSG_PING = 3
|
||||
MSG_SERVER_SYNC = 5
|
||||
MSG_CHANNEL_REMOVE = 6
|
||||
MSG_CHANNEL_STATE = 7
|
||||
MSG_USER_REMOVE = 8
|
||||
MSG_USER_STATE = 9
|
||||
MSG_TEXT_MESSAGE = 11
|
||||
|
||||
# -- Protobuf wire helpers (minimal, no external dep) ------------------------
|
||||
|
||||
_WIRE_VARINT = 0
|
||||
_WIRE_LEN = 2
|
||||
|
||||
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
"""Encode an unsigned integer as a protobuf varint."""
|
||||
buf = bytearray()
|
||||
while value > 0x7F:
|
||||
buf.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
buf.append(value & 0x7F)
|
||||
return bytes(buf)
|
||||
|
||||
|
||||
def _decode_varint(data: bytes, offset: int) -> tuple[int, int]:
|
||||
"""Decode a protobuf varint, returning (value, new_offset)."""
|
||||
result = 0
|
||||
shift = 0
|
||||
while offset < len(data):
|
||||
byte = data[offset]
|
||||
offset += 1
|
||||
result |= (byte & 0x7F) << shift
|
||||
if not (byte & 0x80):
|
||||
return result, offset
|
||||
shift += 7
|
||||
return result, offset
|
||||
|
||||
|
||||
def _encode_field(field_num: int, wire_type: int, value: int | bytes) -> bytes:
|
||||
"""Encode a single protobuf field."""
|
||||
tag = _encode_varint((field_num << 3) | wire_type)
|
||||
if wire_type == _WIRE_VARINT:
|
||||
return tag + _encode_varint(value)
|
||||
# wire_type == _WIRE_LEN
|
||||
if isinstance(value, str):
|
||||
value = value.encode("utf-8")
|
||||
return tag + _encode_varint(len(value)) + value
|
||||
|
||||
|
||||
def _decode_fields(data: bytes) -> dict[int, list]:
|
||||
"""Decode protobuf fields, returning field_num -> list of values."""
|
||||
fields: dict[int, list] = {}
|
||||
offset = 0
|
||||
while offset < len(data):
|
||||
tag, offset = _decode_varint(data, offset)
|
||||
field_num = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
if wire_type == _WIRE_VARINT:
|
||||
value, offset = _decode_varint(data, offset)
|
||||
elif wire_type == _WIRE_LEN:
|
||||
length, offset = _decode_varint(data, offset)
|
||||
value = data[offset:offset + length]
|
||||
offset += length
|
||||
elif wire_type == 5: # 32-bit fixed
|
||||
value = data[offset:offset + 4]
|
||||
offset += 4
|
||||
elif wire_type == 1: # 64-bit fixed
|
||||
value = data[offset:offset + 8]
|
||||
offset += 8
|
||||
else:
|
||||
break # unknown wire type, stop parsing
|
||||
fields.setdefault(field_num, []).append(value)
|
||||
return fields
|
||||
|
||||
|
||||
def _pack_msg(msg_type: int, payload: bytes = b"") -> bytes:
|
||||
"""Wrap a protobuf payload in a Mumble 6-byte header."""
|
||||
return struct.pack("!HI", msg_type, len(payload)) + payload
|
||||
|
||||
|
||||
def _unpack_header(data: bytes) -> tuple[int, int]:
|
||||
"""Unpack a 6-byte Mumble header into (msg_type, payload_length)."""
|
||||
return struct.unpack("!HI", data)
|
||||
|
||||
|
||||
def _field_str(fields: dict[int, list], num: int) -> str | None:
|
||||
"""Extract a string field."""
|
||||
vals = fields.get(num)
|
||||
if vals and isinstance(vals[0], bytes):
|
||||
return vals[0].decode("utf-8", errors="replace")
|
||||
return None
|
||||
|
||||
|
||||
def _field_int(fields: dict[int, list], num: int) -> int | None:
|
||||
"""Extract an integer field."""
|
||||
vals = fields.get(num)
|
||||
if vals and isinstance(vals[0], int):
|
||||
return vals[0]
|
||||
return None
|
||||
|
||||
|
||||
def _field_ints(fields: dict[int, list], num: int) -> list[int]:
|
||||
"""Extract repeated integer fields."""
|
||||
vals = fields.get(num, [])
|
||||
return [v for v in vals if isinstance(v, int)]
|
||||
|
||||
# PCM constants for audio streaming
|
||||
_SAMPLE_RATE = 48000
|
||||
_FRAME_SIZE = 960 # 20ms at 48kHz
|
||||
_FRAME_BYTES = 1920 # 960 samples * 2 bytes (s16le)
|
||||
|
||||
# -- HTML helpers ------------------------------------------------------------
|
||||
|
||||
@@ -155,66 +52,6 @@ def _shell_quote(s: str) -> str:
|
||||
return "'" + s.replace("'", "'\\''") + "'"
|
||||
|
||||
|
||||
# -- Mumble voice helpers ----------------------------------------------------
|
||||
|
||||
|
||||
def _encode_voice_varint(value: int) -> bytes:
|
||||
"""Encode an integer using Mumble's voice varint format.
|
||||
|
||||
NOT the same as protobuf varint. Mumble voice varints use a prefix
|
||||
code based on leading bits:
|
||||
0xxxxxxx -- 7-bit (0-127)
|
||||
10xxxxxx yyyyyyyy -- 14-bit
|
||||
110xxxxx yyyyyyyy yyyyyyyy -- 21-bit
|
||||
1110xxxx yyyyyyyy yyyyyyyy yyyyyyyy -- 28-bit
|
||||
11110000 + 8 bytes -- 64-bit
|
||||
"""
|
||||
if value < 0:
|
||||
raise ValueError("voice varint must be non-negative")
|
||||
if value < 0x80:
|
||||
return bytes([value])
|
||||
if value < 0x4000:
|
||||
return bytes([0x80 | (value >> 8), value & 0xFF])
|
||||
if value < 0x200000:
|
||||
return bytes([
|
||||
0xC0 | (value >> 16),
|
||||
(value >> 8) & 0xFF,
|
||||
value & 0xFF,
|
||||
])
|
||||
if value < 0x10000000:
|
||||
return bytes([
|
||||
0xE0 | (value >> 24),
|
||||
(value >> 16) & 0xFF,
|
||||
(value >> 8) & 0xFF,
|
||||
value & 0xFF,
|
||||
])
|
||||
# 64-bit fallback
|
||||
return b"\xf0" + value.to_bytes(8, "big")
|
||||
|
||||
|
||||
def _build_voice_packet(
|
||||
sequence: int,
|
||||
opus_data: bytes,
|
||||
*,
|
||||
last: bool = False,
|
||||
) -> bytes:
|
||||
"""Build a Mumble voice packet for client-to-server Opus audio.
|
||||
|
||||
Format (client-to-server, no session ID):
|
||||
1 byte : header (type=4 << 5 | target=0 -> 0x80)
|
||||
varint : sequence number (increments by 1 per frame)
|
||||
varint : opus frame length (bit 13 = terminator on last)
|
||||
N bytes : raw opus data
|
||||
"""
|
||||
header = bytes([0x80]) # type=4 (Opus), target=0
|
||||
seq = _encode_voice_varint(sequence)
|
||||
length = len(opus_data)
|
||||
if last:
|
||||
length |= 0x2000 # bit 13 = terminator flag
|
||||
size = _encode_voice_varint(length)
|
||||
return header + seq + size + opus_data
|
||||
|
||||
|
||||
def _scale_pcm(data: bytes, volume: float) -> bytes:
|
||||
"""Scale s16le PCM samples by a volume factor, clamped to [-32768, 32767]."""
|
||||
samples = array.array("h")
|
||||
@@ -241,7 +78,7 @@ class MumbleMessage:
|
||||
and ``msg.tags`` work without modification.
|
||||
"""
|
||||
|
||||
raw: dict # decoded protobuf fields
|
||||
raw: dict # original message data
|
||||
nick: str | None # sender username (from session lookup)
|
||||
prefix: str | None # sender username (for ACL)
|
||||
text: str | None # message text (HTML stripped)
|
||||
@@ -252,115 +89,16 @@ class MumbleMessage:
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# -- Message builders --------------------------------------------------------
|
||||
|
||||
|
||||
def _build_version_payload() -> bytes:
|
||||
"""Build a Version message payload."""
|
||||
payload = b""
|
||||
# field 1: version_v1 (uint32) -- legacy: 1.5.0
|
||||
payload += _encode_field(1, _WIRE_VARINT, (1 << 16) | (5 << 8))
|
||||
# field 2: release (string)
|
||||
payload += _encode_field(2, _WIRE_LEN, "derp 1.5.0")
|
||||
# field 3: os (string)
|
||||
payload += _encode_field(3, _WIRE_LEN, "Linux")
|
||||
# field 4: os_version (string)
|
||||
payload += _encode_field(4, _WIRE_LEN, "")
|
||||
# field 5: version_v2 (uint64) -- semantic: 1.5.0
|
||||
payload += _encode_field(5, _WIRE_VARINT, (1 << 48) | (5 << 32))
|
||||
return payload
|
||||
|
||||
|
||||
def _build_authenticate_payload(username: str, password: str) -> bytes:
|
||||
"""Build an Authenticate message payload."""
|
||||
payload = b""
|
||||
# field 1: username (string)
|
||||
payload += _encode_field(1, _WIRE_LEN, username)
|
||||
# field 2: password (string)
|
||||
if password:
|
||||
payload += _encode_field(2, _WIRE_LEN, password)
|
||||
# field 5: opus (bool/varint) -- True
|
||||
payload += _encode_field(5, _WIRE_VARINT, 1)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_ping_payload(timestamp: int) -> bytes:
|
||||
"""Build a Ping message payload."""
|
||||
return _encode_field(1, _WIRE_VARINT, timestamp)
|
||||
|
||||
|
||||
def _build_text_message_payload(
|
||||
channel_ids: list[int] | None = None,
|
||||
session_ids: list[int] | None = None,
|
||||
tree_ids: list[int] | None = None,
|
||||
message: str = "",
|
||||
) -> bytes:
|
||||
"""Build a TextMessage payload."""
|
||||
payload = b""
|
||||
for sid in (session_ids or []):
|
||||
payload += _encode_field(2, _WIRE_VARINT, sid)
|
||||
for cid in (channel_ids or []):
|
||||
payload += _encode_field(3, _WIRE_VARINT, cid)
|
||||
for tid in (tree_ids or []):
|
||||
payload += _encode_field(4, _WIRE_VARINT, tid)
|
||||
# field 5: message (string)
|
||||
payload += _encode_field(5, _WIRE_LEN, message)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_mumble_message(
|
||||
fields: dict[int, list],
|
||||
users: dict[int, str],
|
||||
our_session: int,
|
||||
) -> MumbleMessage | None:
|
||||
"""Build a MumbleMessage from decoded TextMessage fields."""
|
||||
# field 5: message text
|
||||
raw_text = _field_str(fields, 5)
|
||||
if raw_text is None:
|
||||
return None
|
||||
|
||||
text = _strip_html(raw_text)
|
||||
|
||||
# field 1: actor (sender session)
|
||||
actor = _field_int(fields, 1)
|
||||
nick = users.get(actor) if actor is not None else None
|
||||
prefix = nick # use username for ACL
|
||||
|
||||
# Determine target: channel_id (field 3) or DM (field 2)
|
||||
channel_ids = _field_ints(fields, 3)
|
||||
session_ids = _field_ints(fields, 2)
|
||||
|
||||
if channel_ids:
|
||||
target = str(channel_ids[0])
|
||||
is_channel = True
|
||||
elif session_ids:
|
||||
target = "dm"
|
||||
is_channel = False
|
||||
else:
|
||||
target = None
|
||||
is_channel = True
|
||||
|
||||
return MumbleMessage(
|
||||
raw=dict(fields),
|
||||
nick=nick,
|
||||
prefix=prefix,
|
||||
text=text,
|
||||
target=target,
|
||||
is_channel=is_channel,
|
||||
params=[target or "", text],
|
||||
)
|
||||
|
||||
|
||||
# -- MumbleBot --------------------------------------------------------------
|
||||
|
||||
|
||||
class MumbleBot:
|
||||
"""Mumble bot adapter via TCP/TLS protobuf control channel (text only).
|
||||
"""Mumble bot adapter using pymumble for connection and voice.
|
||||
|
||||
Exposes the same public API as :class:`derp.bot.Bot` so that
|
||||
protocol-agnostic plugins work without modification.
|
||||
TCP is routed through ``derp.http.create_connection`` (SOCKS5
|
||||
optional via ``mumble.proxy`` config).
|
||||
Voice uses pymumble's sound_output (blocking SSL socket, proven
|
||||
reliable for audio delivery).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: dict, registry: PluginRegistry) -> None:
|
||||
@@ -370,12 +108,10 @@ class MumbleBot:
|
||||
self._pstate: dict = {}
|
||||
|
||||
mu_cfg = config.get("mumble", {})
|
||||
self._proxy: bool = mu_cfg.get("proxy", True)
|
||||
self._host: str = mu_cfg.get("host", "127.0.0.1")
|
||||
self._port: int = mu_cfg.get("port", 64738)
|
||||
self._username: str = mu_cfg.get("username", "derp")
|
||||
self._password: str = mu_cfg.get("password", "")
|
||||
self._tls_verify: bool = mu_cfg.get("tls_verify", False)
|
||||
self.nick: str = self._username
|
||||
self.prefix: str = (
|
||||
mu_cfg.get("prefix")
|
||||
@@ -384,19 +120,14 @@ class MumbleBot:
|
||||
self._running = False
|
||||
self._started: float = time.monotonic()
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
self._reconnect_delay: float = 5.0
|
||||
self._admins: list[str] = [str(x) for x in mu_cfg.get("admins", [])]
|
||||
self._operators: list[str] = [str(x) for x in mu_cfg.get("operators", [])]
|
||||
self._trusted: list[str] = [str(x) for x in mu_cfg.get("trusted", [])]
|
||||
self.state = StateStore(f"data/state-{name}.db")
|
||||
|
||||
# Protocol state
|
||||
self._session: int = 0 # our session ID (from ServerSync)
|
||||
self._channels: dict[int, str] = {} # channel_id -> name
|
||||
self._users: dict[int, str] = {} # session_id -> username
|
||||
self._user_channels: dict[int, int] = {} # session_id -> channel_id
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
# pymumble state
|
||||
self._mumble: pymumble.Mumble | None = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
rate_cfg = config.get("bot", {})
|
||||
self._bucket = _TokenBucket(
|
||||
@@ -406,163 +137,99 @@ class MumbleBot:
|
||||
|
||||
# -- Connection ----------------------------------------------------------
|
||||
|
||||
def _create_ssl_context(self) -> ssl.SSLContext:
|
||||
"""Build an SSL context for Mumble TLS."""
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
if not self._tls_verify:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
||||
async def _connect(self) -> None:
|
||||
"""Establish TLS connection, optionally through SOCKS5 proxy."""
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = await loop.run_in_executor(
|
||||
None, lambda: http.create_connection(
|
||||
(self._host, self._port), proxy=self._proxy,
|
||||
),
|
||||
def _connect_sync(self) -> None:
|
||||
"""Create and start pymumble connection (blocking, run in executor)."""
|
||||
self._mumble = pymumble.Mumble(
|
||||
self._host, self._username,
|
||||
port=self._port, password=self._password,
|
||||
reconnect=True,
|
||||
)
|
||||
ssl_ctx = self._create_ssl_context()
|
||||
self._reader, self._writer = await asyncio.open_connection(
|
||||
sock=sock, ssl=ssl_ctx, server_hostname=self._host,
|
||||
self._mumble.callbacks.set_callback(
|
||||
PYMUMBLE_CLBK_TEXTMESSAGERECEIVED,
|
||||
self._on_text_message,
|
||||
)
|
||||
self._mumble.callbacks.set_callback(
|
||||
PYMUMBLE_CLBK_CONNECTED,
|
||||
self._on_connected,
|
||||
)
|
||||
self._mumble.callbacks.set_callback(
|
||||
PYMUMBLE_CLBK_DISCONNECTED,
|
||||
self._on_disconnected,
|
||||
)
|
||||
self._mumble.set_receive_sound(False)
|
||||
self._mumble.start()
|
||||
self._mumble.is_ready()
|
||||
|
||||
async def _send_msg(self, msg_type: int, payload: bytes = b"") -> None:
|
||||
"""Send a framed Mumble message."""
|
||||
if self._writer is None:
|
||||
def _on_connected(self) -> None:
|
||||
"""Callback from pymumble thread: connection established."""
|
||||
session = getattr(self._mumble.users, "myself_session", "?")
|
||||
log.info("mumble: connected as %s on %s:%d (session=%s)",
|
||||
self._username, self._host, self._port, session)
|
||||
|
||||
def _on_disconnected(self) -> None:
|
||||
"""Callback from pymumble thread: connection lost."""
|
||||
log.warning("mumble: disconnected")
|
||||
|
||||
def _on_text_message(self, message) -> None:
|
||||
"""Callback from pymumble thread: text message received.
|
||||
|
||||
Bridges to the asyncio event loop for command dispatch.
|
||||
"""
|
||||
if self._loop is None:
|
||||
return
|
||||
data = _pack_msg(msg_type, payload)
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._handle_text(message), self._loop,
|
||||
)
|
||||
|
||||
async def _read_msg(self) -> tuple[int, bytes] | None:
|
||||
"""Read one framed Mumble message. Returns (msg_type, payload)."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
header = await self._reader.readexactly(6)
|
||||
msg_type, length = _unpack_header(header)
|
||||
payload = await self._reader.readexactly(length) if length else b""
|
||||
return msg_type, payload
|
||||
async def _handle_text(self, pb_msg) -> None:
|
||||
"""Process a text message from pymumble (runs on asyncio loop)."""
|
||||
text = _strip_html(pb_msg.message)
|
||||
actor = pb_msg.actor
|
||||
|
||||
async def _close(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._writer:
|
||||
try:
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
# Look up sender username
|
||||
nick = None
|
||||
try:
|
||||
nick = self._mumble.users[actor]["name"]
|
||||
except (KeyError, TypeError):
|
||||
pass
|
||||
|
||||
# Determine target: channel or DM
|
||||
if pb_msg.channel_id:
|
||||
target = str(pb_msg.channel_id[0])
|
||||
is_channel = True
|
||||
elif pb_msg.session:
|
||||
target = "dm"
|
||||
is_channel = False
|
||||
else:
|
||||
target = None
|
||||
is_channel = True
|
||||
|
||||
msg = MumbleMessage(
|
||||
raw={},
|
||||
nick=nick,
|
||||
prefix=nick,
|
||||
text=text,
|
||||
target=target,
|
||||
is_channel=is_channel,
|
||||
params=[target or "", text],
|
||||
)
|
||||
await self._dispatch_command(msg)
|
||||
|
||||
# -- Lifecycle -----------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect and enter the message loop with reconnect backoff."""
|
||||
"""Connect via pymumble and run until stopped."""
|
||||
self._running = True
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_and_run()
|
||||
self._reconnect_delay = 5.0
|
||||
except (OSError, ConnectionError, asyncio.IncompleteReadError) as exc:
|
||||
log.error("mumble: connection lost: %s", exc)
|
||||
if self._running:
|
||||
jitter = self._reconnect_delay * 0.25 * (2 * random.random() - 1)
|
||||
delay = self._reconnect_delay + jitter
|
||||
log.info("mumble: reconnecting in %.0fs...", delay)
|
||||
await asyncio.sleep(delay)
|
||||
self._reconnect_delay = min(self._reconnect_delay * 2, 300.0)
|
||||
self._loop = asyncio.get_running_loop()
|
||||
|
||||
async def _connect_and_run(self) -> None:
|
||||
"""Single connection lifecycle."""
|
||||
await self._connect()
|
||||
await self._loop.run_in_executor(None, self._connect_sync)
|
||||
try:
|
||||
await self._handshake()
|
||||
await self._loop()
|
||||
while self._running:
|
||||
await asyncio.sleep(1)
|
||||
finally:
|
||||
await self._close()
|
||||
|
||||
async def _handshake(self) -> None:
|
||||
"""Send Version and Authenticate messages."""
|
||||
# Version
|
||||
await self._send_msg(MSG_VERSION, _build_version_payload())
|
||||
# Authenticate
|
||||
await self._send_msg(
|
||||
MSG_AUTHENTICATE,
|
||||
_build_authenticate_payload(self._username, self._password),
|
||||
)
|
||||
log.info("mumble: authenticating as %s on %s:%d",
|
||||
self._username, self._host, self._port)
|
||||
|
||||
async def _loop(self) -> None:
|
||||
"""Read and dispatch messages until disconnect."""
|
||||
while self._running:
|
||||
result = await self._read_msg()
|
||||
if result is None:
|
||||
log.warning("mumble: server closed connection")
|
||||
return
|
||||
msg_type, payload = result
|
||||
await self._handle(msg_type, payload)
|
||||
|
||||
async def _handle(self, msg_type: int, payload: bytes) -> None:
|
||||
"""Dispatch a received Mumble message by type."""
|
||||
if msg_type == MSG_PING:
|
||||
fields = _decode_fields(payload)
|
||||
ts = _field_int(fields, 1) or 0
|
||||
await self._send_msg(MSG_PING, _build_ping_payload(ts))
|
||||
return
|
||||
|
||||
if msg_type == MSG_SERVER_SYNC:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
if session is not None:
|
||||
self._session = session
|
||||
welcome = _field_str(fields, 3) or ""
|
||||
log.info("mumble: connected, session=%d, welcome=%s",
|
||||
self._session, _strip_html(welcome)[:80])
|
||||
return
|
||||
|
||||
if msg_type == MSG_CHANNEL_STATE:
|
||||
fields = _decode_fields(payload)
|
||||
cid = _field_int(fields, 1)
|
||||
name = _field_str(fields, 3)
|
||||
if cid is not None and name is not None:
|
||||
self._channels[cid] = name
|
||||
return
|
||||
|
||||
if msg_type == MSG_CHANNEL_REMOVE:
|
||||
fields = _decode_fields(payload)
|
||||
cid = _field_int(fields, 1)
|
||||
if cid is not None:
|
||||
self._channels.pop(cid, None)
|
||||
return
|
||||
|
||||
if msg_type == MSG_USER_STATE:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
name = _field_str(fields, 3)
|
||||
channel_id = _field_int(fields, 5)
|
||||
if session is not None:
|
||||
if name is not None:
|
||||
self._users[session] = name
|
||||
if channel_id is not None:
|
||||
self._user_channels[session] = channel_id
|
||||
return
|
||||
|
||||
if msg_type == MSG_USER_REMOVE:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
if session is not None:
|
||||
self._users.pop(session, None)
|
||||
self._user_channels.pop(session, None)
|
||||
return
|
||||
|
||||
if msg_type == MSG_TEXT_MESSAGE:
|
||||
fields = _decode_fields(payload)
|
||||
msg = _build_mumble_message(fields, self._users, self._session)
|
||||
if msg is not None:
|
||||
await self._dispatch_command(msg)
|
||||
return
|
||||
if self._mumble:
|
||||
self._mumble.stop()
|
||||
self._mumble = None
|
||||
|
||||
# -- Command dispatch ----------------------------------------------------
|
||||
|
||||
@@ -607,11 +274,7 @@ class MumbleBot:
|
||||
log.exception("mumble: error in command handler '%s'", cmd_name)
|
||||
|
||||
def _resolve_command(self, name: str):
|
||||
"""Resolve command name with unambiguous prefix matching.
|
||||
|
||||
Returns the Handler on exact or unique prefix match, the sentinel
|
||||
``_AMBIGUOUS`` if multiple commands match, or None if nothing matches.
|
||||
"""
|
||||
"""Resolve command name with unambiguous prefix matching."""
|
||||
handler = self.registry.commands.get(name)
|
||||
if handler is not None:
|
||||
return handler
|
||||
@@ -630,10 +293,7 @@ class MumbleBot:
|
||||
# -- Permission tiers ----------------------------------------------------
|
||||
|
||||
def _get_tier(self, msg: MumbleMessage) -> str:
|
||||
"""Determine permission tier from username.
|
||||
|
||||
Matches exact string comparison of username against config lists.
|
||||
"""
|
||||
"""Determine permission tier from username."""
|
||||
if not msg.prefix:
|
||||
return "user"
|
||||
for name in self._admins:
|
||||
@@ -653,25 +313,31 @@ class MumbleBot:
|
||||
|
||||
# -- Public API for plugins ----------------------------------------------
|
||||
|
||||
def _send_text_sync(self, channel_id: int, html_text: str) -> None:
|
||||
"""Send a text message via pymumble (blocking, thread-safe)."""
|
||||
try:
|
||||
channel = self._mumble.channels[channel_id]
|
||||
channel.send_text_message(html_text)
|
||||
except Exception:
|
||||
log.exception("mumble: failed to send text to channel %d",
|
||||
channel_id)
|
||||
|
||||
async def _send_html(self, target: str, html_text: str) -> None:
|
||||
"""Send a TextMessage with pre-formatted HTML (no escaping)."""
|
||||
if self._mumble is None:
|
||||
return
|
||||
await self._bucket.acquire()
|
||||
try:
|
||||
channel_id = int(target)
|
||||
except (ValueError, TypeError):
|
||||
channel_id = 0 # root channel fallback
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[channel_id], message=html_text,
|
||||
channel_id = 0
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None, self._send_text_sync, channel_id, html_text,
|
||||
)
|
||||
await self._send_msg(MSG_TEXT_MESSAGE, payload)
|
||||
|
||||
async def send(self, target: str, text: str) -> None:
|
||||
"""Send a TextMessage to a channel (HTML-escaped, rate-limited).
|
||||
|
||||
``target`` is a channel_id as string. For DMs, sends to the
|
||||
current channel instead (Mumble DMs require session IDs which
|
||||
are not stable).
|
||||
"""
|
||||
"""Send a TextMessage to a channel (HTML-escaped, rate-limited)."""
|
||||
await self._send_html(target, _escape_html(text))
|
||||
|
||||
async def reply(self, msg, text: str) -> None:
|
||||
@@ -679,17 +345,13 @@ class MumbleBot:
|
||||
if msg.target and msg.target != "dm":
|
||||
await self.send(msg.target, text)
|
||||
elif msg.target == "dm":
|
||||
# Best-effort: send to root channel
|
||||
await self.send("0", text)
|
||||
|
||||
async def long_reply(
|
||||
self, msg, lines: list[str], *,
|
||||
label: str = "",
|
||||
) -> None:
|
||||
"""Reply with a list of lines; paste overflow to FlaskPaste.
|
||||
|
||||
Same overflow logic as :meth:`derp.bot.Bot.long_reply`.
|
||||
"""
|
||||
"""Reply with a list of lines; paste overflow to FlaskPaste."""
|
||||
threshold = self.config.get("bot", {}).get("paste_threshold", 4)
|
||||
if not lines or not msg.target:
|
||||
return
|
||||
@@ -699,7 +361,6 @@ class MumbleBot:
|
||||
await self.send(msg.target, line)
|
||||
return
|
||||
|
||||
# Attempt paste overflow
|
||||
fp = self.registry._modules.get("flaskpaste")
|
||||
paste_url = None
|
||||
if fp:
|
||||
@@ -729,9 +390,32 @@ class MumbleBot:
|
||||
|
||||
# -- Voice streaming -----------------------------------------------------
|
||||
|
||||
async def _send_voice_packet(self, packet: bytes) -> None:
|
||||
"""Send a voice packet via UDPTunnel (msg type 1)."""
|
||||
await self._send_msg(MSG_UDPTUNNEL, packet)
|
||||
async def test_tone(self, duration: float = 3.0) -> None:
|
||||
"""Send a 440Hz sine test tone for debugging voice output."""
|
||||
import math
|
||||
|
||||
if self._mumble is None:
|
||||
return
|
||||
|
||||
log.info("test_tone: sending %.1fs of 440Hz sine", duration)
|
||||
total_frames = int(duration / 0.02)
|
||||
|
||||
for i in range(total_frames):
|
||||
samples = []
|
||||
for j in range(_FRAME_SIZE):
|
||||
t = (i * _FRAME_SIZE + j) / _SAMPLE_RATE
|
||||
samples.append(int(16000 * math.sin(2 * math.pi * 440 * t)))
|
||||
pcm = struct.pack(f"<{_FRAME_SIZE}h", *samples)
|
||||
self._mumble.sound_output.add_sound(pcm)
|
||||
|
||||
# Keep buffer shallow so we can cancel promptly
|
||||
while self._mumble.sound_output.get_buffer_size() > 0.5:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Wait for buffer to drain
|
||||
while self._mumble.sound_output.get_buffer_size() > 0:
|
||||
await asyncio.sleep(0.1)
|
||||
log.info("test_tone: done")
|
||||
|
||||
async def stream_audio(
|
||||
self,
|
||||
@@ -746,16 +430,14 @@ class MumbleBot:
|
||||
yt-dlp -o - -f bestaudio <url>
|
||||
| ffmpeg -i pipe:0 -f s16le -ar 48000 -ac 1 pipe:1
|
||||
|
||||
Reads 1920 bytes (20ms frames), scales volume, encodes Opus,
|
||||
wraps in voice packets, sends at 20ms intervals. Sets terminator
|
||||
flag on the last frame.
|
||||
|
||||
Args:
|
||||
url: Audio URL (YouTube, SoundCloud, etc.)
|
||||
volume: Volume scale factor (0.0 to 1.0).
|
||||
on_done: Optional asyncio.Event to set when playback ends.
|
||||
Feeds raw PCM to pymumble's sound_output which handles Opus
|
||||
encoding, packetization, and timing.
|
||||
"""
|
||||
from derp.opus import FRAME_BYTES, OpusEncoder
|
||||
if self._mumble is None:
|
||||
return
|
||||
|
||||
log.info("stream_audio: starting pipeline for %s (vol=%.0f%%)",
|
||||
url, volume * 100)
|
||||
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
"sh", "-c",
|
||||
@@ -763,44 +445,50 @@ class MumbleBot:
|
||||
f" | ffmpeg -i pipe:0 -f s16le -ar 48000 -ac 1"
|
||||
f" -loglevel error pipe:1",
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.DEVNULL,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
encoder = OpusEncoder()
|
||||
sequence = 0
|
||||
frames = 0
|
||||
try:
|
||||
while True:
|
||||
pcm = await proc.stdout.read(FRAME_BYTES)
|
||||
pcm = await proc.stdout.read(_FRAME_BYTES)
|
||||
if not pcm:
|
||||
break
|
||||
if len(pcm) < FRAME_BYTES:
|
||||
pcm += b"\x00" * (FRAME_BYTES - len(pcm))
|
||||
if len(pcm) < _FRAME_BYTES:
|
||||
pcm += b"\x00" * (_FRAME_BYTES - len(pcm))
|
||||
|
||||
if volume != 1.0:
|
||||
pcm = _scale_pcm(pcm, volume)
|
||||
|
||||
opus_data = encoder.encode(pcm)
|
||||
pkt = _build_voice_packet(sequence, opus_data)
|
||||
await self._send_voice_packet(pkt)
|
||||
sequence += 1
|
||||
self._mumble.sound_output.add_sound(pcm)
|
||||
frames += 1
|
||||
|
||||
# Pace at 20ms per frame
|
||||
await asyncio.sleep(0.02)
|
||||
if frames == 1:
|
||||
log.info("stream_audio: first frame fed to pymumble")
|
||||
|
||||
# Send terminator frame (silence)
|
||||
silence = b"\x00" * FRAME_BYTES
|
||||
opus_data = encoder.encode(silence)
|
||||
pkt = _build_voice_packet(sequence, opus_data, last=True)
|
||||
await self._send_voice_packet(pkt)
|
||||
# Keep buffer at most 1 second ahead
|
||||
while self._mumble.sound_output.get_buffer_size() > 1.0:
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# Wait for buffer to drain
|
||||
while self._mumble.sound_output.get_buffer_size() > 0:
|
||||
await asyncio.sleep(0.1)
|
||||
log.info("stream_audio: finished, %d frames", frames)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._mumble.sound_output.clear_buffer()
|
||||
log.info("stream_audio: cancelled at frame %d", frames)
|
||||
except Exception:
|
||||
log.exception("stream_audio: error at frame %d", frames)
|
||||
finally:
|
||||
encoder.close()
|
||||
try:
|
||||
proc.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
stderr_out = await proc.stderr.read()
|
||||
await proc.wait()
|
||||
if stderr_out:
|
||||
log.warning("stream_audio: subprocess stderr: %s",
|
||||
stderr_out.decode(errors="replace")[:500])
|
||||
if on_done is not None:
|
||||
on_done.set()
|
||||
|
||||
@@ -819,11 +507,9 @@ class MumbleBot:
|
||||
|
||||
async def join(self, channel: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: join() is a no-op")
|
||||
|
||||
async def part(self, channel: str, reason: str = "") -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: part() is a no-op")
|
||||
|
||||
async def quit(self, reason: str = "bye") -> None:
|
||||
"""Stop the Mumble adapter."""
|
||||
@@ -831,15 +517,12 @@ class MumbleBot:
|
||||
|
||||
async def kick(self, channel: str, nick: str, reason: str = "") -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: kick() is a no-op")
|
||||
|
||||
async def mode(self, target: str, mode_str: str, *args: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: mode() is a no-op")
|
||||
|
||||
async def set_topic(self, channel: str, topic: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: set_topic() is a no-op")
|
||||
|
||||
# -- Plugin management (delegated to registry) ---------------------------
|
||||
|
||||
|
||||
100
src/derp/opus.py
100
src/derp/opus.py
@@ -1,100 +0,0 @@
|
||||
"""Minimal ctypes wrapper around system libopus for encoding only."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import ctypes
|
||||
import ctypes.util
|
||||
from ctypes import POINTER, c_char_p, c_int, c_int32
|
||||
|
||||
SAMPLE_RATE = 48000
|
||||
CHANNELS = 1
|
||||
FRAME_SIZE = 960 # 20ms at 48kHz mono
|
||||
FRAME_BYTES = 1920 # FRAME_SIZE * CHANNELS * 2 (s16le)
|
||||
|
||||
_APPLICATION_AUDIO = 2049
|
||||
|
||||
_OPUS_SET_BITRATE_REQUEST = 4002
|
||||
_OPUS_OK = 0
|
||||
|
||||
_lib: ctypes.CDLL | None = None
|
||||
|
||||
|
||||
def _load_lib() -> ctypes.CDLL:
|
||||
"""Find and load libopus, cached after first call."""
|
||||
global _lib
|
||||
if _lib is not None:
|
||||
return _lib
|
||||
|
||||
path = ctypes.util.find_library("opus")
|
||||
if path is None:
|
||||
path = "libopus.so.0"
|
||||
|
||||
lib = ctypes.cdll.LoadLibrary(path)
|
||||
|
||||
lib.opus_encoder_get_size.argtypes = [c_int]
|
||||
lib.opus_encoder_get_size.restype = c_int
|
||||
|
||||
lib.opus_encoder_init.argtypes = [c_char_p, c_int32, c_int, c_int]
|
||||
lib.opus_encoder_init.restype = c_int
|
||||
|
||||
lib.opus_encode.argtypes = [
|
||||
c_char_p, # encoder state
|
||||
c_char_p, # pcm input
|
||||
c_int, # frame_size (samples per channel)
|
||||
POINTER(ctypes.c_ubyte), # output buffer
|
||||
c_int32, # max output bytes
|
||||
]
|
||||
lib.opus_encode.restype = c_int
|
||||
|
||||
lib.opus_encoder_ctl.argtypes = [c_char_p, c_int]
|
||||
lib.opus_encoder_ctl.restype = c_int
|
||||
|
||||
_lib = lib
|
||||
return lib
|
||||
|
||||
|
||||
class OpusEncoder:
|
||||
"""Opus encoder for 48kHz mono s16le PCM -> Opus frames."""
|
||||
|
||||
def __init__(self, bitrate: int = 64000) -> None:
|
||||
lib = _load_lib()
|
||||
size = lib.opus_encoder_get_size(CHANNELS)
|
||||
self._state = ctypes.create_string_buffer(size)
|
||||
rc = lib.opus_encoder_init(
|
||||
self._state, SAMPLE_RATE, CHANNELS, _APPLICATION_AUDIO,
|
||||
)
|
||||
if rc != _OPUS_OK:
|
||||
raise RuntimeError(f"opus_encoder_init failed: {rc}")
|
||||
|
||||
rc = lib.opus_encoder_ctl(
|
||||
self._state, _OPUS_SET_BITRATE_REQUEST, c_int32(bitrate),
|
||||
)
|
||||
if rc != _OPUS_OK:
|
||||
raise RuntimeError(f"opus_encoder_ctl set bitrate failed: {rc}")
|
||||
|
||||
self._lib = lib
|
||||
self._out = (ctypes.c_ubyte * 4000)()
|
||||
|
||||
def encode(self, pcm: bytes) -> bytes:
|
||||
"""Encode one 20ms frame of s16le PCM to an Opus packet.
|
||||
|
||||
Args:
|
||||
pcm: Exactly 1920 bytes (960 samples, 48kHz mono s16le).
|
||||
|
||||
Returns:
|
||||
Opus-encoded frame bytes.
|
||||
"""
|
||||
if len(pcm) != FRAME_BYTES:
|
||||
raise ValueError(
|
||||
f"expected {FRAME_BYTES} bytes, got {len(pcm)}"
|
||||
)
|
||||
n = self._lib.opus_encode(
|
||||
self._state, pcm, FRAME_SIZE, self._out, len(self._out),
|
||||
)
|
||||
if n < 0:
|
||||
raise RuntimeError(f"opus_encode failed: {n}")
|
||||
return bytes(self._out[:n])
|
||||
|
||||
def close(self) -> None:
|
||||
"""Release encoder state."""
|
||||
self._state = None
|
||||
@@ -1,41 +1,16 @@
|
||||
"""Tests for the Mumble adapter."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
import struct
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from derp.mumble import (
|
||||
_WIRE_LEN,
|
||||
_WIRE_VARINT,
|
||||
MSG_CHANNEL_REMOVE,
|
||||
MSG_CHANNEL_STATE,
|
||||
MSG_PING,
|
||||
MSG_SERVER_SYNC,
|
||||
MSG_TEXT_MESSAGE,
|
||||
MSG_UDPTUNNEL,
|
||||
MSG_USER_REMOVE,
|
||||
MSG_USER_STATE,
|
||||
MumbleBot,
|
||||
MumbleMessage,
|
||||
_build_authenticate_payload,
|
||||
_build_mumble_message,
|
||||
_build_ping_payload,
|
||||
_build_text_message_payload,
|
||||
_build_version_payload,
|
||||
_build_voice_packet,
|
||||
_decode_fields,
|
||||
_decode_varint,
|
||||
_encode_field,
|
||||
_encode_varint,
|
||||
_encode_voice_varint,
|
||||
_escape_html,
|
||||
_field_int,
|
||||
_field_ints,
|
||||
_field_str,
|
||||
_pack_msg,
|
||||
_scale_pcm,
|
||||
_shell_quote,
|
||||
_strip_html,
|
||||
_unpack_header,
|
||||
)
|
||||
from derp.plugin import PluginRegistry
|
||||
|
||||
@@ -51,7 +26,6 @@ def _make_bot(admins=None, operators=None, trusted=None, prefix=None):
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"admins": admins or [],
|
||||
"operators": operators or [],
|
||||
"trusted": trusted or [],
|
||||
@@ -94,120 +68,6 @@ async def _admin_handler(bot, msg):
|
||||
await bot.reply(msg, "admin action done")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestProtobufCodec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProtobufCodec:
|
||||
def test_encode_varint_zero(self):
|
||||
assert _encode_varint(0) == b"\x00"
|
||||
|
||||
def test_encode_varint_small(self):
|
||||
assert _encode_varint(1) == b"\x01"
|
||||
assert _encode_varint(127) == b"\x7f"
|
||||
|
||||
def test_encode_varint_two_byte(self):
|
||||
assert _encode_varint(128) == b"\x80\x01"
|
||||
assert _encode_varint(300) == b"\xac\x02"
|
||||
|
||||
def test_encode_varint_large(self):
|
||||
# 16384 = 0x4000
|
||||
encoded = _encode_varint(16384)
|
||||
val, _ = _decode_varint(encoded, 0)
|
||||
assert val == 16384
|
||||
|
||||
def test_decode_varint_zero(self):
|
||||
val, off = _decode_varint(b"\x00", 0)
|
||||
assert val == 0
|
||||
assert off == 1
|
||||
|
||||
def test_decode_varint_small(self):
|
||||
val, off = _decode_varint(b"\x01", 0)
|
||||
assert val == 1
|
||||
assert off == 1
|
||||
|
||||
def test_decode_varint_multi_byte(self):
|
||||
val, off = _decode_varint(b"\xac\x02", 0)
|
||||
assert val == 300
|
||||
assert off == 2
|
||||
|
||||
def test_varint_roundtrip(self):
|
||||
for n in [0, 1, 127, 128, 300, 16384, 2**21, 2**28]:
|
||||
encoded = _encode_varint(n)
|
||||
decoded, _ = _decode_varint(encoded, 0)
|
||||
assert decoded == n, f"roundtrip failed for {n}"
|
||||
|
||||
def test_encode_field_varint(self):
|
||||
# field 1, wire type 0, value 42
|
||||
data = _encode_field(1, _WIRE_VARINT, 42)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[1] == [42]
|
||||
|
||||
def test_encode_field_string(self):
|
||||
data = _encode_field(5, _WIRE_LEN, "hello")
|
||||
fields = _decode_fields(data)
|
||||
assert fields[5] == [b"hello"]
|
||||
|
||||
def test_encode_field_bytes(self):
|
||||
data = _encode_field(3, _WIRE_LEN, b"\x00\x01\x02")
|
||||
fields = _decode_fields(data)
|
||||
assert fields[3] == [b"\x00\x01\x02"]
|
||||
|
||||
def test_decode_multiple_fields(self):
|
||||
data = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(2, _WIRE_LEN, "test")
|
||||
+ _encode_field(3, _WIRE_VARINT, 99)
|
||||
)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[1] == [10]
|
||||
assert fields[2] == [b"test"]
|
||||
assert fields[3] == [99]
|
||||
|
||||
def test_decode_repeated_fields(self):
|
||||
data = (
|
||||
_encode_field(3, _WIRE_VARINT, 1)
|
||||
+ _encode_field(3, _WIRE_VARINT, 2)
|
||||
+ _encode_field(3, _WIRE_VARINT, 3)
|
||||
)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[3] == [1, 2, 3]
|
||||
|
||||
def test_pack_unpack_header(self):
|
||||
packed = _pack_msg(11, b"hello")
|
||||
msg_type, length = _unpack_header(packed[:6])
|
||||
assert msg_type == 11
|
||||
assert length == 5
|
||||
assert packed[6:] == b"hello"
|
||||
|
||||
def test_pack_empty_payload(self):
|
||||
packed = _pack_msg(3)
|
||||
assert len(packed) == 6
|
||||
msg_type, length = _unpack_header(packed)
|
||||
assert msg_type == 3
|
||||
assert length == 0
|
||||
|
||||
def test_field_str(self):
|
||||
fields = {5: [b"hello"]}
|
||||
assert _field_str(fields, 5) == "hello"
|
||||
assert _field_str(fields, 1) is None
|
||||
|
||||
def test_field_int(self):
|
||||
fields = {1: [42]}
|
||||
assert _field_int(fields, 1) == 42
|
||||
assert _field_int(fields, 2) is None
|
||||
|
||||
def test_field_ints(self):
|
||||
fields = {3: [1, 2, 3]}
|
||||
assert _field_ints(fields, 3) == [1, 2, 3]
|
||||
assert _field_ints(fields, 9) == []
|
||||
|
||||
def test_decode_empty(self):
|
||||
fields = _decode_fields(b"")
|
||||
assert fields == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleMessage
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -253,78 +113,6 @@ class TestMumbleMessage:
|
||||
assert msg.prefix == "admin_user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestBuildMumbleMessage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildMumbleMessage:
|
||||
def test_channel_message(self):
|
||||
fields = {
|
||||
1: [42], # actor session
|
||||
3: [5], # channel_id
|
||||
5: [b"<b>Hello</b>"], # message HTML
|
||||
}
|
||||
users = {42: "Alice"}
|
||||
msg = _build_mumble_message(fields, users, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.nick == "Alice"
|
||||
assert msg.prefix == "Alice"
|
||||
assert msg.text == "Hello"
|
||||
assert msg.target == "5"
|
||||
assert msg.is_channel is True
|
||||
|
||||
def test_dm_message(self):
|
||||
fields = {
|
||||
1: [42], # actor session
|
||||
2: [1], # target session (DM)
|
||||
5: [b"secret"], # message
|
||||
}
|
||||
users = {42: "Bob"}
|
||||
msg = _build_mumble_message(fields, users, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.target == "dm"
|
||||
assert msg.is_channel is False
|
||||
|
||||
def test_missing_sender(self):
|
||||
fields = {
|
||||
3: [0], # channel_id
|
||||
5: [b"anonymous"], # message
|
||||
}
|
||||
msg = _build_mumble_message(fields, {}, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.nick is None
|
||||
assert msg.prefix is None
|
||||
|
||||
def test_missing_message(self):
|
||||
fields = {
|
||||
1: [42],
|
||||
3: [0],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1)
|
||||
assert msg is None
|
||||
|
||||
def test_html_stripped(self):
|
||||
fields = {
|
||||
1: [1],
|
||||
3: [0],
|
||||
5: [b"<a href='link'>click & go</a>"],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {1: "User"}, our_session=0)
|
||||
assert msg is not None
|
||||
assert msg.text == "click & go"
|
||||
|
||||
def test_no_target(self):
|
||||
fields = {
|
||||
1: [42],
|
||||
5: [b"orphan message"],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.target is None
|
||||
assert msg.is_channel is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHtmlHelpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -356,36 +144,27 @@ class TestHtmlHelpers:
|
||||
|
||||
|
||||
class TestMumbleBotReply:
|
||||
def test_send_builds_text_message(self):
|
||||
def test_send_calls_send_html(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
async def _fake_send_html(target, html_text):
|
||||
sent.append((target, html_text))
|
||||
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
with patch.object(bot, "_send_html", side_effect=_fake_send_html):
|
||||
asyncio.run(bot.send("5", "hello"))
|
||||
assert len(sent) == 1
|
||||
assert sent[0][0] == MSG_TEXT_MESSAGE
|
||||
# Verify payload contains the message
|
||||
fields = _decode_fields(sent[0][1])
|
||||
text = _field_str(fields, 5)
|
||||
assert text == "hello"
|
||||
assert _field_ints(fields, 3) == [5]
|
||||
assert sent == [("5", "hello")]
|
||||
|
||||
def test_send_escapes_html(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
async def _fake_send_html(target, html_text):
|
||||
sent.append((target, html_text))
|
||||
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
with patch.object(bot, "_send_html", side_effect=_fake_send_html):
|
||||
asyncio.run(bot.send("0", "<script>alert('xss')"))
|
||||
fields = _decode_fields(sent[0][1])
|
||||
text = _field_str(fields, 5)
|
||||
assert "<script>" not in text
|
||||
assert "<script>" in text
|
||||
assert "<script>" in sent[0][1]
|
||||
|
||||
def test_reply_sends_to_target(self):
|
||||
bot = _make_bot()
|
||||
@@ -409,7 +188,6 @@ class TestMumbleBotReply:
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot.reply(msg, "dm reply"))
|
||||
# DM falls back to root channel "0"
|
||||
assert sent == [("0", "dm reply")]
|
||||
|
||||
def test_long_reply_under_threshold(self):
|
||||
@@ -655,149 +433,6 @@ class TestMumbleBotNoOps:
|
||||
assert bot._running is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotProtocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotProtocol:
|
||||
def test_version_payload_structure(self):
|
||||
payload = _build_version_payload()
|
||||
fields = _decode_fields(payload)
|
||||
# field 1: version_v1 (uint32)
|
||||
assert _field_int(fields, 1) == (1 << 16) | (5 << 8)
|
||||
# field 2: release (string)
|
||||
assert _field_str(fields, 2) == "derp 1.5.0"
|
||||
# field 3: os (string)
|
||||
assert _field_str(fields, 3) == "Linux"
|
||||
|
||||
def test_authenticate_payload_structure(self):
|
||||
payload = _build_authenticate_payload("testuser", "testpass")
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 1) == "testuser"
|
||||
assert _field_str(fields, 2) == "testpass"
|
||||
# field 5: opus (bool=1)
|
||||
assert _field_int(fields, 5) == 1
|
||||
|
||||
def test_authenticate_no_password(self):
|
||||
payload = _build_authenticate_payload("testuser", "")
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 1) == "testuser"
|
||||
assert 2 not in fields # no password field
|
||||
|
||||
def test_ping_payload_roundtrip(self):
|
||||
payload = _build_ping_payload(123456789)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_int(fields, 1) == 123456789
|
||||
|
||||
def test_text_message_payload(self):
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[5], message="<b>hello</b>",
|
||||
)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 5) == "<b>hello</b>"
|
||||
assert _field_ints(fields, 3) == [5]
|
||||
|
||||
def test_text_message_multiple_channels(self):
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[1, 2, 3], message="broadcast",
|
||||
)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_ints(fields, 3) == [1, 2, 3]
|
||||
|
||||
def test_handle_ping_echoes(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
|
||||
ping_payload = _build_ping_payload(42)
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
asyncio.run(bot._handle(MSG_PING, ping_payload))
|
||||
assert len(sent) == 1
|
||||
assert sent[0][0] == MSG_PING
|
||||
fields = _decode_fields(sent[0][1])
|
||||
assert _field_int(fields, 1) == 42
|
||||
|
||||
def test_handle_server_sync(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 99)
|
||||
+ _encode_field(3, _WIRE_LEN, "Welcome!")
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_SERVER_SYNC, payload))
|
||||
assert bot._session == 99
|
||||
|
||||
def test_handle_channel_state(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 5)
|
||||
+ _encode_field(3, _WIRE_LEN, "General")
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_CHANNEL_STATE, payload))
|
||||
assert bot._channels[5] == "General"
|
||||
|
||||
def test_handle_channel_remove(self):
|
||||
bot = _make_bot()
|
||||
bot._channels[5] = "General"
|
||||
payload = _encode_field(1, _WIRE_VARINT, 5)
|
||||
asyncio.run(bot._handle(MSG_CHANNEL_REMOVE, payload))
|
||||
assert 5 not in bot._channels
|
||||
|
||||
def test_handle_user_state(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(3, _WIRE_LEN, "Alice")
|
||||
+ _encode_field(5, _WIRE_VARINT, 0)
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_USER_STATE, payload))
|
||||
assert bot._users[10] == "Alice"
|
||||
assert bot._user_channels[10] == 0
|
||||
|
||||
def test_handle_user_state_channel_change(self):
|
||||
bot = _make_bot()
|
||||
bot._users[10] = "Alice"
|
||||
bot._user_channels[10] = 0
|
||||
# User moves to channel 5
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(5, _WIRE_VARINT, 5)
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_USER_STATE, payload))
|
||||
assert bot._users[10] == "Alice" # name unchanged
|
||||
assert bot._user_channels[10] == 5
|
||||
|
||||
def test_handle_user_remove(self):
|
||||
bot = _make_bot()
|
||||
bot._users[10] = "Alice"
|
||||
bot._user_channels[10] = 0
|
||||
payload = _encode_field(1, _WIRE_VARINT, 10)
|
||||
asyncio.run(bot._handle(MSG_USER_REMOVE, payload))
|
||||
assert 10 not in bot._users
|
||||
assert 10 not in bot._user_channels
|
||||
|
||||
def test_handle_text_message_dispatch(self):
|
||||
bot = _make_bot()
|
||||
bot._users[42] = "Alice"
|
||||
bot.registry.register_command(
|
||||
"echo", _echo_handler, help="echo", plugin="test")
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 42) # actor
|
||||
+ _encode_field(3, _WIRE_VARINT, 0) # channel_id
|
||||
+ _encode_field(5, _WIRE_LEN, "!echo test") # message
|
||||
)
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._handle(MSG_TEXT_MESSAGE, payload))
|
||||
assert sent == ["test"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestPluginManagement
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -850,7 +485,6 @@ class TestMumbleBotConfig:
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"prefix": "/",
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
@@ -869,7 +503,6 @@ class TestMumbleBotConfig:
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
@@ -887,140 +520,10 @@ class TestMumbleBotConfig:
|
||||
bot = _make_bot()
|
||||
assert bot._port == 64738
|
||||
|
||||
def test_tls_verify_default(self):
|
||||
bot = _make_bot()
|
||||
assert bot._tls_verify is False
|
||||
|
||||
def test_nick_from_username(self):
|
||||
bot = _make_bot()
|
||||
assert bot.nick == "derp"
|
||||
|
||||
def test_proxy_default_true(self):
|
||||
bot = _make_bot()
|
||||
assert bot._proxy is True
|
||||
|
||||
def test_proxy_disabled(self):
|
||||
config = {
|
||||
"mumble": {
|
||||
"enabled": True,
|
||||
"host": "127.0.0.1",
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"proxy": False,
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
},
|
||||
"bot": {"prefix": "!", "rate_limit": 2.0, "rate_burst": 5},
|
||||
}
|
||||
bot = MumbleBot("test", config, PluginRegistry())
|
||||
assert bot._proxy is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestVoiceVarint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVoiceVarint:
|
||||
def test_zero(self):
|
||||
assert _encode_voice_varint(0) == b"\x00"
|
||||
|
||||
def test_7bit_max(self):
|
||||
assert _encode_voice_varint(127) == b"\x7f"
|
||||
|
||||
def test_7bit_small(self):
|
||||
assert _encode_voice_varint(1) == b"\x01"
|
||||
assert _encode_voice_varint(42) == b"\x2a"
|
||||
|
||||
def test_14bit_min(self):
|
||||
# 128 = 0x80 -> prefix 10, value 128
|
||||
result = _encode_voice_varint(128)
|
||||
assert len(result) == 2
|
||||
assert result[0] & 0xC0 == 0x80 # top 2 bits = 10
|
||||
|
||||
def test_14bit_max(self):
|
||||
result = _encode_voice_varint(0x3FFF)
|
||||
assert len(result) == 2
|
||||
|
||||
def test_21bit(self):
|
||||
result = _encode_voice_varint(0x4000)
|
||||
assert len(result) == 3
|
||||
assert result[0] & 0xE0 == 0xC0 # top 3 bits = 110
|
||||
|
||||
def test_28bit(self):
|
||||
result = _encode_voice_varint(0x200000)
|
||||
assert len(result) == 4
|
||||
assert result[0] & 0xF0 == 0xE0 # top 4 bits = 1110
|
||||
|
||||
def test_64bit(self):
|
||||
result = _encode_voice_varint(0x10000000)
|
||||
assert len(result) == 9
|
||||
assert result[0] == 0xF0
|
||||
|
||||
def test_negative_raises(self):
|
||||
import pytest
|
||||
with pytest.raises(ValueError, match="non-negative"):
|
||||
_encode_voice_varint(-1)
|
||||
|
||||
def test_14bit_roundtrip(self):
|
||||
"""Value encoded and decoded back correctly (manual decode)."""
|
||||
val = 300
|
||||
data = _encode_voice_varint(val)
|
||||
assert len(data) == 2
|
||||
decoded = ((data[0] & 0x3F) << 8) | data[1]
|
||||
assert decoded == val
|
||||
|
||||
def test_7bit_roundtrip(self):
|
||||
for v in range(128):
|
||||
data = _encode_voice_varint(v)
|
||||
assert len(data) == 1
|
||||
assert data[0] == v
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestVoicePacket
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestVoicePacket:
|
||||
def test_header_byte(self):
|
||||
pkt = _build_voice_packet(0, b"\xaa\xbb")
|
||||
assert pkt[0] == 0x80 # type=4, target=0
|
||||
|
||||
def test_sequence_encoding(self):
|
||||
pkt = _build_voice_packet(42, b"\x00")
|
||||
# byte 0: header, byte 1: sequence=42 (7-bit)
|
||||
assert pkt[1] == 42
|
||||
|
||||
def test_opus_data_present(self):
|
||||
opus = b"\xde\xad\xbe\xef"
|
||||
pkt = _build_voice_packet(0, opus)
|
||||
assert pkt.endswith(opus)
|
||||
|
||||
def test_length_field(self):
|
||||
opus = b"\x00" * 10
|
||||
pkt = _build_voice_packet(0, opus)
|
||||
# header(1) + seq(1, val=0) + length(1, val=10) + data(10) = 13
|
||||
assert len(pkt) == 13
|
||||
assert pkt[2] == 10 # length varint
|
||||
|
||||
def test_terminator_flag(self):
|
||||
opus = b"\x00" * 5
|
||||
pkt = _build_voice_packet(0, opus, last=True)
|
||||
# length with bit 13 set: 5 | 0x2000 = 0x2005
|
||||
# 0x2005 in 14-bit varint: 10_100000 00000101
|
||||
length_bytes = _encode_voice_varint(5 | 0x2000)
|
||||
assert length_bytes in pkt
|
||||
|
||||
def test_no_terminator_by_default(self):
|
||||
opus = b"\x00" * 5
|
||||
pkt = _build_voice_packet(0, opus, last=False)
|
||||
# length=5, no bit 13
|
||||
assert pkt[2] == 5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestPcmScaling
|
||||
@@ -1029,41 +532,36 @@ class TestVoicePacket:
|
||||
|
||||
class TestPcmScaling:
|
||||
def test_unity_volume(self):
|
||||
import struct as _s
|
||||
pcm = _s.pack("<hh", 1000, -1000)
|
||||
pcm = struct.pack("<hh", 1000, -1000)
|
||||
result = _scale_pcm(pcm, 1.0)
|
||||
assert result == pcm
|
||||
|
||||
def test_half_volume(self):
|
||||
import struct as _s
|
||||
pcm = _s.pack("<h", 1000)
|
||||
pcm = struct.pack("<h", 1000)
|
||||
result = _scale_pcm(pcm, 0.5)
|
||||
samples = _s.unpack("<h", result)
|
||||
samples = struct.unpack("<h", result)
|
||||
assert samples[0] == 500
|
||||
|
||||
def test_clamp_positive(self):
|
||||
import struct as _s
|
||||
pcm = _s.pack("<h", 32767)
|
||||
pcm = struct.pack("<h", 32767)
|
||||
result = _scale_pcm(pcm, 2.0)
|
||||
samples = _s.unpack("<h", result)
|
||||
samples = struct.unpack("<h", result)
|
||||
assert samples[0] == 32767
|
||||
|
||||
def test_clamp_negative(self):
|
||||
import struct as _s
|
||||
pcm = _s.pack("<h", -32768)
|
||||
pcm = struct.pack("<h", -32768)
|
||||
result = _scale_pcm(pcm, 2.0)
|
||||
samples = _s.unpack("<h", result)
|
||||
samples = struct.unpack("<h", result)
|
||||
assert samples[0] == -32768
|
||||
|
||||
def test_zero_volume(self):
|
||||
import struct as _s
|
||||
pcm = _s.pack("<hh", 32767, -32768)
|
||||
pcm = struct.pack("<hh", 32767, -32768)
|
||||
result = _scale_pcm(pcm, 0.0)
|
||||
samples = _s.unpack("<hh", result)
|
||||
samples = struct.unpack("<hh", result)
|
||||
assert samples == (0, 0)
|
||||
|
||||
def test_preserves_length(self):
|
||||
pcm = b"\x00" * 1920 # 960 samples
|
||||
pcm = b"\x00" * 1920
|
||||
result = _scale_pcm(pcm, 0.5)
|
||||
assert len(result) == 1920
|
||||
|
||||
@@ -1085,13 +583,3 @@ class TestShellQuote:
|
||||
quoted = _shell_quote(url)
|
||||
assert quoted.startswith("'")
|
||||
assert quoted.endswith("'")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMsgUdpTunnel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMsgUdpTunnel:
|
||||
def test_constant(self):
|
||||
assert MSG_UDPTUNNEL == 1
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
"""Tests for the Opus ctypes wrapper."""
|
||||
|
||||
import math
|
||||
import struct
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from derp.opus import CHANNELS, FRAME_BYTES, FRAME_SIZE, SAMPLE_RATE
|
||||
|
||||
# -- Helpers -----------------------------------------------------------------
|
||||
|
||||
|
||||
def _silence() -> bytes:
|
||||
"""Generate one frame of silence (1920 bytes of zeros)."""
|
||||
return b"\x00" * FRAME_BYTES
|
||||
|
||||
|
||||
def _sine_frame(freq: float = 440.0) -> bytes:
|
||||
"""Generate one 20ms frame of a sine wave at the given frequency."""
|
||||
samples = []
|
||||
for i in range(FRAME_SIZE):
|
||||
t = i / SAMPLE_RATE
|
||||
val = int(16000 * math.sin(2 * math.pi * freq * t))
|
||||
samples.append(struct.pack("<h", max(-32768, min(32767, val))))
|
||||
return b"".join(samples)
|
||||
|
||||
|
||||
# -- Mock libopus for unit testing without system library --------------------
|
||||
|
||||
|
||||
def _mock_lib():
|
||||
"""Build a mock ctypes CDLL that simulates libopus calls."""
|
||||
lib = MagicMock()
|
||||
lib.opus_encoder_get_size.return_value = 256
|
||||
lib.opus_encoder_init.return_value = 0
|
||||
lib.opus_encoder_ctl.return_value = 0
|
||||
lib.opus_encode.return_value = 10 # 10 bytes output
|
||||
return lib
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_lib_cache():
|
||||
"""Reset the cached _lib before each test."""
|
||||
import derp.opus as _mod
|
||||
old = _mod._lib
|
||||
_mod._lib = None
|
||||
yield
|
||||
_mod._lib = old
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestOpusConstants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpusConstants:
|
||||
def test_sample_rate(self):
|
||||
assert SAMPLE_RATE == 48000
|
||||
|
||||
def test_channels(self):
|
||||
assert CHANNELS == 1
|
||||
|
||||
def test_frame_size(self):
|
||||
assert FRAME_SIZE == 960
|
||||
|
||||
def test_frame_bytes(self):
|
||||
assert FRAME_BYTES == FRAME_SIZE * CHANNELS * 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestOpusEncoder (mocked libopus)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpusEncoder:
|
||||
def test_encode_silence(self):
|
||||
"""Encoding silence produces bytes output."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
result = enc.encode(_silence())
|
||||
assert isinstance(result, bytes)
|
||||
assert len(result) > 0
|
||||
enc.close()
|
||||
|
||||
def test_encode_sine(self):
|
||||
"""Encoding a sine wave produces bytes output."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
result = enc.encode(_sine_frame())
|
||||
assert isinstance(result, bytes)
|
||||
enc.close()
|
||||
|
||||
def test_encode_wrong_size(self):
|
||||
"""Passing wrong buffer size raises ValueError."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
with pytest.raises(ValueError, match="expected 1920"):
|
||||
enc.encode(b"\x00" * 100)
|
||||
enc.close()
|
||||
|
||||
def test_encode_multi_frame(self):
|
||||
"""Multiple sequential encodes work."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
for _ in range(5):
|
||||
result = enc.encode(_silence())
|
||||
assert isinstance(result, bytes)
|
||||
enc.close()
|
||||
|
||||
def test_custom_bitrate(self):
|
||||
"""Custom bitrate is passed to opus_encoder_ctl."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder(bitrate=96000)
|
||||
assert lib.opus_encoder_ctl.called
|
||||
enc.close()
|
||||
|
||||
def test_init_failure(self):
|
||||
"""RuntimeError on encoder init failure."""
|
||||
lib = _mock_lib()
|
||||
lib.opus_encoder_init.return_value = -1
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
with pytest.raises(RuntimeError, match="opus_encoder_init"):
|
||||
OpusEncoder()
|
||||
|
||||
def test_encode_failure(self):
|
||||
"""RuntimeError on encode failure."""
|
||||
lib = _mock_lib()
|
||||
lib.opus_encode.return_value = -1
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
with pytest.raises(RuntimeError, match="opus_encode"):
|
||||
enc.encode(_silence())
|
||||
|
||||
def test_close_clears_state(self):
|
||||
"""close() sets internal state to None."""
|
||||
lib = _mock_lib()
|
||||
with patch("derp.opus._load_lib", return_value=lib):
|
||||
from derp.opus import OpusEncoder
|
||||
enc = OpusEncoder()
|
||||
enc.close()
|
||||
assert enc._state is None
|
||||
Reference in New Issue
Block a user