diff --git a/src/derp/cli.py b/src/derp/cli.py index c1f0d4c..f9ad621 100644 --- a/src/derp/cli.py +++ b/src/derp/cli.py @@ -148,6 +148,13 @@ def main(argv: list[str] | None = None) -> int: tg_bot = TelegramBot("telegram", config, registry) bots.append(tg_bot) + # Mumble adapter (optional) + if config.get("mumble", {}).get("enabled"): + from derp.mumble import MumbleBot + + mumble_bot = MumbleBot("mumble", config, registry) + bots.append(mumble_bot) + names = ", ".join(b.name for b in bots) log.info("servers: %s", names) diff --git a/src/derp/config.py b/src/derp/config.py index c4d40a0..55aa019 100644 --- a/src/derp/config.py +++ b/src/derp/config.py @@ -58,6 +58,17 @@ DEFAULTS: dict = { "operators": [], "trusted": [], }, + "mumble": { + "enabled": False, + "host": "127.0.0.1", + "port": 64738, + "username": "derp", + "password": "", + "tls_verify": False, + "admins": [], + "operators": [], + "trusted": [], + }, "logging": { "level": "info", "format": "text", diff --git a/src/derp/mumble.py b/src/derp/mumble.py new file mode 100644 index 0000000..123d978 --- /dev/null +++ b/src/derp/mumble.py @@ -0,0 +1,728 @@ +"""Mumble adapter: TLS/TCP over SOCKS5, protobuf control channel (text only).""" + +from __future__ import annotations + +import asyncio +import html +import logging +import random +import re +import ssl +import struct +import time +from dataclasses import dataclass, field +from pathlib import Path + +from derp import http +from derp.bot import _TokenBucket +from derp.plugin import TIERS, PluginRegistry +from derp.state import StateStore + +log = logging.getLogger(__name__) + +_AMBIGUOUS = object() # sentinel for ambiguous prefix matches + +# -- Mumble message types ---------------------------------------------------- + +MSG_VERSION = 0 +MSG_AUTHENTICATE = 2 +MSG_PING = 3 +MSG_SERVER_SYNC = 5 +MSG_CHANNEL_REMOVE = 6 +MSG_CHANNEL_STATE = 7 +MSG_USER_REMOVE = 8 +MSG_USER_STATE = 9 +MSG_TEXT_MESSAGE = 11 + +# -- Protobuf wire helpers (minimal, no external dep) ------------------------ + +_WIRE_VARINT = 0 +_WIRE_LEN = 2 + + +def _encode_varint(value: int) -> bytes: + """Encode an unsigned integer as a protobuf varint.""" + buf = bytearray() + while value > 0x7F: + buf.append((value & 0x7F) | 0x80) + value >>= 7 + buf.append(value & 0x7F) + return bytes(buf) + + +def _decode_varint(data: bytes, offset: int) -> tuple[int, int]: + """Decode a protobuf varint, returning (value, new_offset).""" + result = 0 + shift = 0 + while offset < len(data): + byte = data[offset] + offset += 1 + result |= (byte & 0x7F) << shift + if not (byte & 0x80): + return result, offset + shift += 7 + return result, offset + + +def _encode_field(field_num: int, wire_type: int, value: int | bytes) -> bytes: + """Encode a single protobuf field.""" + tag = _encode_varint((field_num << 3) | wire_type) + if wire_type == _WIRE_VARINT: + return tag + _encode_varint(value) + # wire_type == _WIRE_LEN + if isinstance(value, str): + value = value.encode("utf-8") + return tag + _encode_varint(len(value)) + value + + +def _decode_fields(data: bytes) -> dict[int, list]: + """Decode protobuf fields, returning field_num -> list of values.""" + fields: dict[int, list] = {} + offset = 0 + while offset < len(data): + tag, offset = _decode_varint(data, offset) + field_num = tag >> 3 + wire_type = tag & 0x07 + if wire_type == _WIRE_VARINT: + value, offset = _decode_varint(data, offset) + elif wire_type == _WIRE_LEN: + length, offset = _decode_varint(data, offset) + value = data[offset:offset + length] + offset += length + elif wire_type == 5: # 32-bit fixed + value = data[offset:offset + 4] + offset += 4 + elif wire_type == 1: # 64-bit fixed + value = data[offset:offset + 8] + offset += 8 + else: + break # unknown wire type, stop parsing + fields.setdefault(field_num, []).append(value) + return fields + + +def _pack_msg(msg_type: int, payload: bytes = b"") -> bytes: + """Wrap a protobuf payload in a Mumble 6-byte header.""" + return struct.pack("!HI", msg_type, len(payload)) + payload + + +def _unpack_header(data: bytes) -> tuple[int, int]: + """Unpack a 6-byte Mumble header into (msg_type, payload_length).""" + return struct.unpack("!HI", data) + + +def _field_str(fields: dict[int, list], num: int) -> str | None: + """Extract a string field.""" + vals = fields.get(num) + if vals and isinstance(vals[0], bytes): + return vals[0].decode("utf-8", errors="replace") + return None + + +def _field_int(fields: dict[int, list], num: int) -> int | None: + """Extract an integer field.""" + vals = fields.get(num) + if vals and isinstance(vals[0], int): + return vals[0] + return None + + +def _field_ints(fields: dict[int, list], num: int) -> list[int]: + """Extract repeated integer fields.""" + vals = fields.get(num, []) + return [v for v in vals if isinstance(v, int)] + + +# -- HTML helpers ------------------------------------------------------------ + +_TAG_RE = re.compile(r"<[^>]+>") + + +def _strip_html(text: str) -> str: + """Strip HTML tags and unescape entities.""" + return html.unescape(_TAG_RE.sub("", text)) + + +def _escape_html(text: str) -> str: + """Escape text for Mumble HTML messages.""" + return html.escape(text, quote=False) + + +# -- MumbleMessage ----------------------------------------------------------- + + +@dataclass(slots=True) +class MumbleMessage: + """Parsed Mumble TextMessage, duck-typed with IRC Message. + + Plugins that use only ``msg.nick``, ``msg.text``, ``msg.target``, + ``msg.is_channel``, ``msg.prefix``, ``msg.command``, ``msg.params``, + and ``msg.tags`` work without modification. + """ + + raw: dict # decoded protobuf fields + nick: str | None # sender username (from session lookup) + prefix: str | None # sender username (for ACL) + text: str | None # message text (HTML stripped) + target: str | None # channel_id as string (or "dm" for DMs) + is_channel: bool = True # True for channel msgs, False for DMs + command: str = "PRIVMSG" # compat shim + params: list[str] = field(default_factory=list) + tags: dict[str, str] = field(default_factory=dict) + + +# -- Message builders -------------------------------------------------------- + + +def _build_version_payload() -> bytes: + """Build a Version message payload.""" + payload = b"" + # field 1: version_v1 (uint32) -- legacy: 1.5.0 + payload += _encode_field(1, _WIRE_VARINT, (1 << 16) | (5 << 8)) + # field 2: release (string) + payload += _encode_field(2, _WIRE_LEN, "derp 1.5.0") + # field 3: os (string) + payload += _encode_field(3, _WIRE_LEN, "Linux") + # field 4: os_version (string) + payload += _encode_field(4, _WIRE_LEN, "") + # field 5: version_v2 (uint64) -- semantic: 1.5.0 + payload += _encode_field(5, _WIRE_VARINT, (1 << 48) | (5 << 32)) + return payload + + +def _build_authenticate_payload(username: str, password: str) -> bytes: + """Build an Authenticate message payload.""" + payload = b"" + # field 1: username (string) + payload += _encode_field(1, _WIRE_LEN, username) + # field 2: password (string) + if password: + payload += _encode_field(2, _WIRE_LEN, password) + # field 5: opus (bool/varint) -- True + payload += _encode_field(5, _WIRE_VARINT, 1) + return payload + + +def _build_ping_payload(timestamp: int) -> bytes: + """Build a Ping message payload.""" + return _encode_field(1, _WIRE_VARINT, timestamp) + + +def _build_text_message_payload( + channel_ids: list[int] | None = None, + session_ids: list[int] | None = None, + tree_ids: list[int] | None = None, + message: str = "", +) -> bytes: + """Build a TextMessage payload.""" + payload = b"" + for sid in (session_ids or []): + payload += _encode_field(2, _WIRE_VARINT, sid) + for cid in (channel_ids or []): + payload += _encode_field(3, _WIRE_VARINT, cid) + for tid in (tree_ids or []): + payload += _encode_field(4, _WIRE_VARINT, tid) + # field 5: message (string) + payload += _encode_field(5, _WIRE_LEN, message) + return payload + + +def _build_mumble_message( + fields: dict[int, list], + users: dict[int, str], + our_session: int, +) -> MumbleMessage | None: + """Build a MumbleMessage from decoded TextMessage fields.""" + # field 5: message text + raw_text = _field_str(fields, 5) + if raw_text is None: + return None + + text = _strip_html(raw_text) + + # field 1: actor (sender session) + actor = _field_int(fields, 1) + nick = users.get(actor) if actor is not None else None + prefix = nick # use username for ACL + + # Determine target: channel_id (field 3) or DM (field 2) + channel_ids = _field_ints(fields, 3) + session_ids = _field_ints(fields, 2) + + if channel_ids: + target = str(channel_ids[0]) + is_channel = True + elif session_ids: + target = "dm" + is_channel = False + else: + target = None + is_channel = True + + return MumbleMessage( + raw=dict(fields), + nick=nick, + prefix=prefix, + text=text, + target=target, + is_channel=is_channel, + params=[target or "", text], + ) + + +# -- MumbleBot -------------------------------------------------------------- + + +class MumbleBot: + """Mumble bot adapter via TCP/TLS protobuf control channel (text only). + + Exposes the same public API as :class:`derp.bot.Bot` so that + protocol-agnostic plugins work without modification. + All TCP goes through ``derp.http.create_connection`` (SOCKS5 proxy). + """ + + def __init__(self, name: str, config: dict, registry: PluginRegistry) -> None: + self.name = name + self.config = config + self.registry = registry + self._pstate: dict = {} + + mu_cfg = config.get("mumble", {}) + self._host: str = mu_cfg.get("host", "127.0.0.1") + self._port: int = mu_cfg.get("port", 64738) + self._username: str = mu_cfg.get("username", "derp") + self._password: str = mu_cfg.get("password", "") + self._tls_verify: bool = mu_cfg.get("tls_verify", False) + self.nick: str = self._username + self.prefix: str = ( + mu_cfg.get("prefix") + or config.get("bot", {}).get("prefix", "!") + ) + self._running = False + self._started: float = time.monotonic() + self._tasks: set[asyncio.Task] = set() + self._reconnect_delay: float = 5.0 + self._admins: list[str] = [str(x) for x in mu_cfg.get("admins", [])] + self._operators: list[str] = [str(x) for x in mu_cfg.get("operators", [])] + self._trusted: list[str] = [str(x) for x in mu_cfg.get("trusted", [])] + self.state = StateStore(f"data/state-{name}.db") + + # Protocol state + self._session: int = 0 # our session ID (from ServerSync) + self._channels: dict[int, str] = {} # channel_id -> name + self._users: dict[int, str] = {} # session_id -> username + self._user_channels: dict[int, int] = {} # session_id -> channel_id + self._reader: asyncio.StreamReader | None = None + self._writer: asyncio.StreamWriter | None = None + + rate_cfg = config.get("bot", {}) + self._bucket = _TokenBucket( + rate=rate_cfg.get("rate_limit", 2.0), + burst=rate_cfg.get("rate_burst", 5), + ) + + # -- Connection ---------------------------------------------------------- + + def _create_ssl_context(self) -> ssl.SSLContext: + """Build an SSL context for Mumble TLS.""" + ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + if not self._tls_verify: + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + return ctx + + async def _connect(self) -> None: + """Establish TLS connection over SOCKS5 proxy.""" + loop = asyncio.get_running_loop() + sock = await loop.run_in_executor( + None, http.create_connection, + (self._host, self._port), + ) + ssl_ctx = self._create_ssl_context() + hostname = self._host if self._tls_verify else None + self._reader, self._writer = await asyncio.open_connection( + sock=sock, ssl=ssl_ctx, server_hostname=hostname, + ) + + async def _send_msg(self, msg_type: int, payload: bytes = b"") -> None: + """Send a framed Mumble message.""" + if self._writer is None: + return + data = _pack_msg(msg_type, payload) + self._writer.write(data) + await self._writer.drain() + + async def _read_msg(self) -> tuple[int, bytes] | None: + """Read one framed Mumble message. Returns (msg_type, payload).""" + if self._reader is None: + return None + header = await self._reader.readexactly(6) + msg_type, length = _unpack_header(header) + payload = await self._reader.readexactly(length) if length else b"" + return msg_type, payload + + async def _close(self) -> None: + """Close the connection.""" + if self._writer: + try: + self._writer.close() + await self._writer.wait_closed() + except Exception: + pass + self._reader = None + self._writer = None + + # -- Lifecycle ----------------------------------------------------------- + + async def start(self) -> None: + """Connect and enter the message loop with reconnect backoff.""" + self._running = True + while self._running: + try: + await self._connect_and_run() + self._reconnect_delay = 5.0 + except (OSError, ConnectionError, asyncio.IncompleteReadError) as exc: + log.error("mumble: connection lost: %s", exc) + if self._running: + jitter = self._reconnect_delay * 0.25 * (2 * random.random() - 1) + delay = self._reconnect_delay + jitter + log.info("mumble: reconnecting in %.0fs...", delay) + await asyncio.sleep(delay) + self._reconnect_delay = min(self._reconnect_delay * 2, 300.0) + + async def _connect_and_run(self) -> None: + """Single connection lifecycle.""" + await self._connect() + try: + await self._handshake() + await self._loop() + finally: + await self._close() + + async def _handshake(self) -> None: + """Send Version and Authenticate messages.""" + # Version + await self._send_msg(MSG_VERSION, _build_version_payload()) + # Authenticate + await self._send_msg( + MSG_AUTHENTICATE, + _build_authenticate_payload(self._username, self._password), + ) + log.info("mumble: authenticating as %s on %s:%d", + self._username, self._host, self._port) + + async def _loop(self) -> None: + """Read and dispatch messages until disconnect.""" + while self._running: + result = await self._read_msg() + if result is None: + log.warning("mumble: server closed connection") + return + msg_type, payload = result + await self._handle(msg_type, payload) + + async def _handle(self, msg_type: int, payload: bytes) -> None: + """Dispatch a received Mumble message by type.""" + if msg_type == MSG_PING: + fields = _decode_fields(payload) + ts = _field_int(fields, 1) or 0 + await self._send_msg(MSG_PING, _build_ping_payload(ts)) + return + + if msg_type == MSG_SERVER_SYNC: + fields = _decode_fields(payload) + session = _field_int(fields, 1) + if session is not None: + self._session = session + welcome = _field_str(fields, 3) or "" + log.info("mumble: connected, session=%d, welcome=%s", + self._session, _strip_html(welcome)[:80]) + return + + if msg_type == MSG_CHANNEL_STATE: + fields = _decode_fields(payload) + cid = _field_int(fields, 1) + name = _field_str(fields, 3) + if cid is not None and name is not None: + self._channels[cid] = name + return + + if msg_type == MSG_CHANNEL_REMOVE: + fields = _decode_fields(payload) + cid = _field_int(fields, 1) + if cid is not None: + self._channels.pop(cid, None) + return + + if msg_type == MSG_USER_STATE: + fields = _decode_fields(payload) + session = _field_int(fields, 1) + name = _field_str(fields, 3) + channel_id = _field_int(fields, 5) + if session is not None: + if name is not None: + self._users[session] = name + if channel_id is not None: + self._user_channels[session] = channel_id + return + + if msg_type == MSG_USER_REMOVE: + fields = _decode_fields(payload) + session = _field_int(fields, 1) + if session is not None: + self._users.pop(session, None) + self._user_channels.pop(session, None) + return + + if msg_type == MSG_TEXT_MESSAGE: + fields = _decode_fields(payload) + msg = _build_mumble_message(fields, self._users, self._session) + if msg is not None: + await self._dispatch_command(msg) + return + + # -- Command dispatch ---------------------------------------------------- + + async def _dispatch_command(self, msg: MumbleMessage) -> None: + """Parse and dispatch a command from a Mumble text message.""" + text = msg.text + if not text or not text.startswith(self.prefix): + return + + parts = text[len(self.prefix):].split(None, 1) + cmd_name = parts[0].lower() if parts else "" + handler = self._resolve_command(cmd_name) + if handler is None: + return + if handler is _AMBIGUOUS: + matches = [k for k in self.registry.commands + if k.startswith(cmd_name)] + names = ", ".join(self.prefix + m for m in sorted(matches)) + await self.reply( + msg, + f"Ambiguous command '{self.prefix}{cmd_name}': {names}", + ) + return + + if not self._plugin_allowed(handler.plugin, msg.target): + return + + required = handler.tier + if required != "user": + sender = self._get_tier(msg) + if TIERS.index(sender) < TIERS.index(required): + await self.reply( + msg, + f"Permission denied: {self.prefix}{cmd_name} " + f"requires {required}", + ) + return + + try: + await handler.callback(self, msg) + except Exception: + log.exception("mumble: error in command handler '%s'", cmd_name) + + def _resolve_command(self, name: str): + """Resolve command name with unambiguous prefix matching. + + Returns the Handler on exact or unique prefix match, the sentinel + ``_AMBIGUOUS`` if multiple commands match, or None if nothing matches. + """ + handler = self.registry.commands.get(name) + if handler is not None: + return handler + matches = [v for k, v in self.registry.commands.items() + if k.startswith(name)] + if len(matches) == 1: + return matches[0] + if len(matches) > 1: + return _AMBIGUOUS + return None + + def _plugin_allowed(self, plugin_name: str, channel: str | None) -> bool: + """Channel filtering is IRC-only; all plugins are allowed on Mumble.""" + return True + + # -- Permission tiers ---------------------------------------------------- + + def _get_tier(self, msg: MumbleMessage) -> str: + """Determine permission tier from username. + + Matches exact string comparison of username against config lists. + """ + if not msg.prefix: + return "user" + for name in self._admins: + if msg.prefix == name: + return "admin" + for name in self._operators: + if msg.prefix == name: + return "oper" + for name in self._trusted: + if msg.prefix == name: + return "trusted" + return "user" + + def _is_admin(self, msg: MumbleMessage) -> bool: + """Check if the message sender is a bot admin.""" + return self._get_tier(msg) == "admin" + + # -- Public API for plugins ---------------------------------------------- + + async def _send_html(self, target: str, html_text: str) -> None: + """Send a TextMessage with pre-formatted HTML (no escaping).""" + await self._bucket.acquire() + try: + channel_id = int(target) + except (ValueError, TypeError): + channel_id = 0 # root channel fallback + payload = _build_text_message_payload( + channel_ids=[channel_id], message=html_text, + ) + await self._send_msg(MSG_TEXT_MESSAGE, payload) + + async def send(self, target: str, text: str) -> None: + """Send a TextMessage to a channel (HTML-escaped, rate-limited). + + ``target`` is a channel_id as string. For DMs, sends to the + current channel instead (Mumble DMs require session IDs which + are not stable). + """ + await self._send_html(target, _escape_html(text)) + + async def reply(self, msg, text: str) -> None: + """Reply to the source channel.""" + if msg.target and msg.target != "dm": + await self.send(msg.target, text) + elif msg.target == "dm": + # Best-effort: send to root channel + await self.send("0", text) + + async def long_reply( + self, msg, lines: list[str], *, + label: str = "", + ) -> None: + """Reply with a list of lines; paste overflow to FlaskPaste. + + Same overflow logic as :meth:`derp.bot.Bot.long_reply`. + """ + threshold = self.config.get("bot", {}).get("paste_threshold", 4) + if not lines or not msg.target: + return + + if len(lines) <= threshold: + for line in lines: + await self.send(msg.target, line) + return + + # Attempt paste overflow + fp = self.registry._modules.get("flaskpaste") + paste_url = None + if fp: + full_text = "\n".join(lines) + loop = asyncio.get_running_loop() + paste_url = await loop.run_in_executor( + None, fp.create_paste, self, full_text, + ) + + if paste_url: + preview_count = min(2, threshold - 1) + for line in lines[:preview_count]: + await self.send(msg.target, line) + remaining = len(lines) - preview_count + suffix = f" ({label})" if label else "" + await self.send( + msg.target, + f"... {remaining} more lines{suffix}: {paste_url}", + ) + else: + for line in lines: + await self.send(msg.target, line) + + async def action(self, target: str, text: str) -> None: + """Send an action as italic HTML text.""" + await self._send_html(target, f"{_escape_html(text)}") + + async def shorten_url(self, url: str) -> str: + """Shorten a URL via FlaskPaste. Returns original on failure.""" + fp = self.registry._modules.get("flaskpaste") + if not fp: + return url + loop = asyncio.get_running_loop() + try: + return await loop.run_in_executor(None, fp.shorten_url, self, url) + except Exception: + return url + + # -- IRC no-ops ---------------------------------------------------------- + + async def join(self, channel: str) -> None: + """No-op: IRC-only concept.""" + log.debug("mumble: join() is a no-op") + + async def part(self, channel: str, reason: str = "") -> None: + """No-op: IRC-only concept.""" + log.debug("mumble: part() is a no-op") + + async def quit(self, reason: str = "bye") -> None: + """Stop the Mumble adapter.""" + self._running = False + + async def kick(self, channel: str, nick: str, reason: str = "") -> None: + """No-op: IRC-only concept.""" + log.debug("mumble: kick() is a no-op") + + async def mode(self, target: str, mode_str: str, *args: str) -> None: + """No-op: IRC-only concept.""" + log.debug("mumble: mode() is a no-op") + + async def set_topic(self, channel: str, topic: str) -> None: + """No-op: IRC-only concept.""" + log.debug("mumble: set_topic() is a no-op") + + # -- Plugin management (delegated to registry) --------------------------- + + def load_plugins(self, plugins_dir: str | Path | None = None) -> None: + """Load plugins from the configured directory.""" + if plugins_dir is None: + plugins_dir = self.config.get("bot", {}).get( + "plugins_dir", "plugins") + path = Path(plugins_dir) + self.registry.load_directory(path) + + @property + def plugins_dir(self) -> Path: + """Resolved path to the plugins directory.""" + return Path(self.config.get("bot", {}).get("plugins_dir", "plugins")) + + def load_plugin(self, name: str) -> tuple[bool, str]: + """Hot-load a new plugin by name from the plugins directory.""" + if name in self.registry._modules: + return False, f"plugin already loaded: {name}" + path = self.plugins_dir / f"{name}.py" + if not path.is_file(): + return False, f"{name}.py not found" + count = self.registry.load_plugin(path) + if count < 0: + return False, f"failed to load {name}" + return True, f"{count} handlers" + + def reload_plugin(self, name: str) -> tuple[bool, str]: + """Reload a plugin, picking up any file changes.""" + return self.registry.reload_plugin(name) + + def unload_plugin(self, name: str) -> tuple[bool, str]: + """Unload a plugin, removing all its handlers.""" + if self.registry.unload_plugin(name): + return True, "" + if name == "core": + return False, "cannot unload core" + return False, f"plugin not loaded: {name}" + + def _spawn(self, coro, *, name: str | None = None) -> asyncio.Task: + """Spawn a background task and track it for cleanup.""" + task = asyncio.create_task(coro, name=name) + self._tasks.add(task) + task.add_done_callback(self._tasks.discard) + return task diff --git a/tests/test_mumble.py b/tests/test_mumble.py new file mode 100644 index 0000000..4c81d31 --- /dev/null +++ b/tests/test_mumble.py @@ -0,0 +1,891 @@ +"""Tests for the Mumble adapter.""" + +import asyncio +from unittest.mock import patch + +from derp.mumble import ( + _WIRE_LEN, + _WIRE_VARINT, + MSG_CHANNEL_REMOVE, + MSG_CHANNEL_STATE, + MSG_PING, + MSG_SERVER_SYNC, + MSG_TEXT_MESSAGE, + MSG_USER_REMOVE, + MSG_USER_STATE, + MumbleBot, + MumbleMessage, + _build_authenticate_payload, + _build_mumble_message, + _build_ping_payload, + _build_text_message_payload, + _build_version_payload, + _decode_fields, + _decode_varint, + _encode_field, + _encode_varint, + _escape_html, + _field_int, + _field_ints, + _field_str, + _pack_msg, + _strip_html, + _unpack_header, +) +from derp.plugin import PluginRegistry + +# -- Helpers ----------------------------------------------------------------- + + +def _make_bot(admins=None, operators=None, trusted=None, prefix=None): + """Create a MumbleBot with test config.""" + config = { + "mumble": { + "enabled": True, + "host": "127.0.0.1", + "port": 64738, + "username": "derp", + "password": "", + "tls_verify": False, + "admins": admins or [], + "operators": operators or [], + "trusted": trusted or [], + }, + "bot": { + "prefix": prefix or "!", + "paste_threshold": 4, + "plugins_dir": "plugins", + "rate_limit": 2.0, + "rate_burst": 5, + }, + } + registry = PluginRegistry() + bot = MumbleBot("mu-test", config, registry) + return bot + + +def _mu_msg(text="!ping", nick="Alice", prefix="Alice", + target="0", is_channel=True): + """Create a MumbleMessage for command testing.""" + return MumbleMessage( + raw={}, nick=nick, prefix=prefix, text=text, target=target, + is_channel=is_channel, + params=[target, text], + ) + + +# -- Test helpers for registering commands ----------------------------------- + + +async def _echo_handler(bot, msg): + """Simple command handler that echoes text.""" + args = msg.text.split(None, 1) + reply = args[1] if len(args) > 1 else "no args" + await bot.reply(msg, reply) + + +async def _admin_handler(bot, msg): + """Admin-only command handler.""" + await bot.reply(msg, "admin action done") + + +# --------------------------------------------------------------------------- +# TestProtobufCodec +# --------------------------------------------------------------------------- + + +class TestProtobufCodec: + def test_encode_varint_zero(self): + assert _encode_varint(0) == b"\x00" + + def test_encode_varint_small(self): + assert _encode_varint(1) == b"\x01" + assert _encode_varint(127) == b"\x7f" + + def test_encode_varint_two_byte(self): + assert _encode_varint(128) == b"\x80\x01" + assert _encode_varint(300) == b"\xac\x02" + + def test_encode_varint_large(self): + # 16384 = 0x4000 + encoded = _encode_varint(16384) + val, _ = _decode_varint(encoded, 0) + assert val == 16384 + + def test_decode_varint_zero(self): + val, off = _decode_varint(b"\x00", 0) + assert val == 0 + assert off == 1 + + def test_decode_varint_small(self): + val, off = _decode_varint(b"\x01", 0) + assert val == 1 + assert off == 1 + + def test_decode_varint_multi_byte(self): + val, off = _decode_varint(b"\xac\x02", 0) + assert val == 300 + assert off == 2 + + def test_varint_roundtrip(self): + for n in [0, 1, 127, 128, 300, 16384, 2**21, 2**28]: + encoded = _encode_varint(n) + decoded, _ = _decode_varint(encoded, 0) + assert decoded == n, f"roundtrip failed for {n}" + + def test_encode_field_varint(self): + # field 1, wire type 0, value 42 + data = _encode_field(1, _WIRE_VARINT, 42) + fields = _decode_fields(data) + assert fields[1] == [42] + + def test_encode_field_string(self): + data = _encode_field(5, _WIRE_LEN, "hello") + fields = _decode_fields(data) + assert fields[5] == [b"hello"] + + def test_encode_field_bytes(self): + data = _encode_field(3, _WIRE_LEN, b"\x00\x01\x02") + fields = _decode_fields(data) + assert fields[3] == [b"\x00\x01\x02"] + + def test_decode_multiple_fields(self): + data = ( + _encode_field(1, _WIRE_VARINT, 10) + + _encode_field(2, _WIRE_LEN, "test") + + _encode_field(3, _WIRE_VARINT, 99) + ) + fields = _decode_fields(data) + assert fields[1] == [10] + assert fields[2] == [b"test"] + assert fields[3] == [99] + + def test_decode_repeated_fields(self): + data = ( + _encode_field(3, _WIRE_VARINT, 1) + + _encode_field(3, _WIRE_VARINT, 2) + + _encode_field(3, _WIRE_VARINT, 3) + ) + fields = _decode_fields(data) + assert fields[3] == [1, 2, 3] + + def test_pack_unpack_header(self): + packed = _pack_msg(11, b"hello") + msg_type, length = _unpack_header(packed[:6]) + assert msg_type == 11 + assert length == 5 + assert packed[6:] == b"hello" + + def test_pack_empty_payload(self): + packed = _pack_msg(3) + assert len(packed) == 6 + msg_type, length = _unpack_header(packed) + assert msg_type == 3 + assert length == 0 + + def test_field_str(self): + fields = {5: [b"hello"]} + assert _field_str(fields, 5) == "hello" + assert _field_str(fields, 1) is None + + def test_field_int(self): + fields = {1: [42]} + assert _field_int(fields, 1) == 42 + assert _field_int(fields, 2) is None + + def test_field_ints(self): + fields = {3: [1, 2, 3]} + assert _field_ints(fields, 3) == [1, 2, 3] + assert _field_ints(fields, 9) == [] + + def test_decode_empty(self): + fields = _decode_fields(b"") + assert fields == {} + + +# --------------------------------------------------------------------------- +# TestMumbleMessage +# --------------------------------------------------------------------------- + + +class TestMumbleMessage: + def test_defaults(self): + msg = MumbleMessage(raw={}, nick=None, prefix=None, text=None, + target=None) + assert msg.is_channel is True + assert msg.command == "PRIVMSG" + assert msg.params == [] + assert msg.tags == {} + + def test_custom_values(self): + msg = MumbleMessage( + raw={"field": 1}, nick="Alice", prefix="Alice", + text="hello", target="0", is_channel=True, + command="PRIVMSG", params=["0", "hello"], + tags={"key": "val"}, + ) + assert msg.nick == "Alice" + assert msg.prefix == "Alice" + assert msg.text == "hello" + assert msg.target == "0" + assert msg.tags == {"key": "val"} + + def test_duck_type_compat(self): + """MumbleMessage has the same attribute names as IRC Message.""" + msg = _mu_msg() + attrs = ["raw", "nick", "prefix", "text", "target", + "is_channel", "command", "params", "tags"] + for attr in attrs: + assert hasattr(msg, attr), f"missing attribute: {attr}" + + def test_dm_message(self): + msg = _mu_msg(target="dm", is_channel=False) + assert msg.is_channel is False + assert msg.target == "dm" + + def test_prefix_is_username(self): + msg = _mu_msg(prefix="admin_user") + assert msg.prefix == "admin_user" + + +# --------------------------------------------------------------------------- +# TestBuildMumbleMessage +# --------------------------------------------------------------------------- + + +class TestBuildMumbleMessage: + def test_channel_message(self): + fields = { + 1: [42], # actor session + 3: [5], # channel_id + 5: [b"Hello"], # message HTML + } + users = {42: "Alice"} + msg = _build_mumble_message(fields, users, our_session=1) + assert msg is not None + assert msg.nick == "Alice" + assert msg.prefix == "Alice" + assert msg.text == "Hello" + assert msg.target == "5" + assert msg.is_channel is True + + def test_dm_message(self): + fields = { + 1: [42], # actor session + 2: [1], # target session (DM) + 5: [b"secret"], # message + } + users = {42: "Bob"} + msg = _build_mumble_message(fields, users, our_session=1) + assert msg is not None + assert msg.target == "dm" + assert msg.is_channel is False + + def test_missing_sender(self): + fields = { + 3: [0], # channel_id + 5: [b"anonymous"], # message + } + msg = _build_mumble_message(fields, {}, our_session=1) + assert msg is not None + assert msg.nick is None + assert msg.prefix is None + + def test_missing_message(self): + fields = { + 1: [42], + 3: [0], + } + msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1) + assert msg is None + + def test_html_stripped(self): + fields = { + 1: [1], + 3: [0], + 5: [b"click & go"], + } + msg = _build_mumble_message(fields, {1: "User"}, our_session=0) + assert msg is not None + assert msg.text == "click & go" + + def test_no_target(self): + fields = { + 1: [42], + 5: [b"orphan message"], + } + msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1) + assert msg is not None + assert msg.target is None + assert msg.is_channel is True + + +# --------------------------------------------------------------------------- +# TestHtmlHelpers +# --------------------------------------------------------------------------- + + +class TestHtmlHelpers: + def test_strip_html_simple(self): + assert _strip_html("bold") == "bold" + + def test_strip_html_entities(self): + assert _strip_html("& < > "") == '& < > "' + + def test_strip_html_nested(self): + assert _strip_html("
hello world
") == "hello world" + + def test_strip_html_plain(self): + assert _strip_html("no tags here") == "no tags here" + + def test_escape_html(self): + assert _escape_html("