feat: add Mumble bot adapter with minimal protobuf codec
TCP/TLS connection over SOCKS5 proxy to Mumble servers for text chat. Minimal varint/field protobuf encoder/decoder (no external dep) handles Version, Authenticate, Ping, ServerSync, ChannelState, UserState, and TextMessage message types. MumbleBot exposes the same duck-typed plugin API as Bot/TeamsBot/TelegramBot. 93 new tests (1470 total).
This commit is contained in:
@@ -148,6 +148,13 @@ def main(argv: list[str] | None = None) -> int:
|
||||
tg_bot = TelegramBot("telegram", config, registry)
|
||||
bots.append(tg_bot)
|
||||
|
||||
# Mumble adapter (optional)
|
||||
if config.get("mumble", {}).get("enabled"):
|
||||
from derp.mumble import MumbleBot
|
||||
|
||||
mumble_bot = MumbleBot("mumble", config, registry)
|
||||
bots.append(mumble_bot)
|
||||
|
||||
names = ", ".join(b.name for b in bots)
|
||||
log.info("servers: %s", names)
|
||||
|
||||
|
||||
@@ -58,6 +58,17 @@ DEFAULTS: dict = {
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
},
|
||||
"mumble": {
|
||||
"enabled": False,
|
||||
"host": "127.0.0.1",
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
},
|
||||
"logging": {
|
||||
"level": "info",
|
||||
"format": "text",
|
||||
|
||||
728
src/derp/mumble.py
Normal file
728
src/derp/mumble.py
Normal file
@@ -0,0 +1,728 @@
|
||||
"""Mumble adapter: TLS/TCP over SOCKS5, protobuf control channel (text only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import html
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import ssl
|
||||
import struct
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from derp import http
|
||||
from derp.bot import _TokenBucket
|
||||
from derp.plugin import TIERS, PluginRegistry
|
||||
from derp.state import StateStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_AMBIGUOUS = object() # sentinel for ambiguous prefix matches
|
||||
|
||||
# -- Mumble message types ----------------------------------------------------
|
||||
|
||||
MSG_VERSION = 0
|
||||
MSG_AUTHENTICATE = 2
|
||||
MSG_PING = 3
|
||||
MSG_SERVER_SYNC = 5
|
||||
MSG_CHANNEL_REMOVE = 6
|
||||
MSG_CHANNEL_STATE = 7
|
||||
MSG_USER_REMOVE = 8
|
||||
MSG_USER_STATE = 9
|
||||
MSG_TEXT_MESSAGE = 11
|
||||
|
||||
# -- Protobuf wire helpers (minimal, no external dep) ------------------------
|
||||
|
||||
_WIRE_VARINT = 0
|
||||
_WIRE_LEN = 2
|
||||
|
||||
|
||||
def _encode_varint(value: int) -> bytes:
|
||||
"""Encode an unsigned integer as a protobuf varint."""
|
||||
buf = bytearray()
|
||||
while value > 0x7F:
|
||||
buf.append((value & 0x7F) | 0x80)
|
||||
value >>= 7
|
||||
buf.append(value & 0x7F)
|
||||
return bytes(buf)
|
||||
|
||||
|
||||
def _decode_varint(data: bytes, offset: int) -> tuple[int, int]:
|
||||
"""Decode a protobuf varint, returning (value, new_offset)."""
|
||||
result = 0
|
||||
shift = 0
|
||||
while offset < len(data):
|
||||
byte = data[offset]
|
||||
offset += 1
|
||||
result |= (byte & 0x7F) << shift
|
||||
if not (byte & 0x80):
|
||||
return result, offset
|
||||
shift += 7
|
||||
return result, offset
|
||||
|
||||
|
||||
def _encode_field(field_num: int, wire_type: int, value: int | bytes) -> bytes:
|
||||
"""Encode a single protobuf field."""
|
||||
tag = _encode_varint((field_num << 3) | wire_type)
|
||||
if wire_type == _WIRE_VARINT:
|
||||
return tag + _encode_varint(value)
|
||||
# wire_type == _WIRE_LEN
|
||||
if isinstance(value, str):
|
||||
value = value.encode("utf-8")
|
||||
return tag + _encode_varint(len(value)) + value
|
||||
|
||||
|
||||
def _decode_fields(data: bytes) -> dict[int, list]:
|
||||
"""Decode protobuf fields, returning field_num -> list of values."""
|
||||
fields: dict[int, list] = {}
|
||||
offset = 0
|
||||
while offset < len(data):
|
||||
tag, offset = _decode_varint(data, offset)
|
||||
field_num = tag >> 3
|
||||
wire_type = tag & 0x07
|
||||
if wire_type == _WIRE_VARINT:
|
||||
value, offset = _decode_varint(data, offset)
|
||||
elif wire_type == _WIRE_LEN:
|
||||
length, offset = _decode_varint(data, offset)
|
||||
value = data[offset:offset + length]
|
||||
offset += length
|
||||
elif wire_type == 5: # 32-bit fixed
|
||||
value = data[offset:offset + 4]
|
||||
offset += 4
|
||||
elif wire_type == 1: # 64-bit fixed
|
||||
value = data[offset:offset + 8]
|
||||
offset += 8
|
||||
else:
|
||||
break # unknown wire type, stop parsing
|
||||
fields.setdefault(field_num, []).append(value)
|
||||
return fields
|
||||
|
||||
|
||||
def _pack_msg(msg_type: int, payload: bytes = b"") -> bytes:
|
||||
"""Wrap a protobuf payload in a Mumble 6-byte header."""
|
||||
return struct.pack("!HI", msg_type, len(payload)) + payload
|
||||
|
||||
|
||||
def _unpack_header(data: bytes) -> tuple[int, int]:
|
||||
"""Unpack a 6-byte Mumble header into (msg_type, payload_length)."""
|
||||
return struct.unpack("!HI", data)
|
||||
|
||||
|
||||
def _field_str(fields: dict[int, list], num: int) -> str | None:
|
||||
"""Extract a string field."""
|
||||
vals = fields.get(num)
|
||||
if vals and isinstance(vals[0], bytes):
|
||||
return vals[0].decode("utf-8", errors="replace")
|
||||
return None
|
||||
|
||||
|
||||
def _field_int(fields: dict[int, list], num: int) -> int | None:
|
||||
"""Extract an integer field."""
|
||||
vals = fields.get(num)
|
||||
if vals and isinstance(vals[0], int):
|
||||
return vals[0]
|
||||
return None
|
||||
|
||||
|
||||
def _field_ints(fields: dict[int, list], num: int) -> list[int]:
|
||||
"""Extract repeated integer fields."""
|
||||
vals = fields.get(num, [])
|
||||
return [v for v in vals if isinstance(v, int)]
|
||||
|
||||
|
||||
# -- HTML helpers ------------------------------------------------------------
|
||||
|
||||
_TAG_RE = re.compile(r"<[^>]+>")
|
||||
|
||||
|
||||
def _strip_html(text: str) -> str:
|
||||
"""Strip HTML tags and unescape entities."""
|
||||
return html.unescape(_TAG_RE.sub("", text))
|
||||
|
||||
|
||||
def _escape_html(text: str) -> str:
|
||||
"""Escape text for Mumble HTML messages."""
|
||||
return html.escape(text, quote=False)
|
||||
|
||||
|
||||
# -- MumbleMessage -----------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class MumbleMessage:
|
||||
"""Parsed Mumble TextMessage, duck-typed with IRC Message.
|
||||
|
||||
Plugins that use only ``msg.nick``, ``msg.text``, ``msg.target``,
|
||||
``msg.is_channel``, ``msg.prefix``, ``msg.command``, ``msg.params``,
|
||||
and ``msg.tags`` work without modification.
|
||||
"""
|
||||
|
||||
raw: dict # decoded protobuf fields
|
||||
nick: str | None # sender username (from session lookup)
|
||||
prefix: str | None # sender username (for ACL)
|
||||
text: str | None # message text (HTML stripped)
|
||||
target: str | None # channel_id as string (or "dm" for DMs)
|
||||
is_channel: bool = True # True for channel msgs, False for DMs
|
||||
command: str = "PRIVMSG" # compat shim
|
||||
params: list[str] = field(default_factory=list)
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
# -- Message builders --------------------------------------------------------
|
||||
|
||||
|
||||
def _build_version_payload() -> bytes:
|
||||
"""Build a Version message payload."""
|
||||
payload = b""
|
||||
# field 1: version_v1 (uint32) -- legacy: 1.5.0
|
||||
payload += _encode_field(1, _WIRE_VARINT, (1 << 16) | (5 << 8))
|
||||
# field 2: release (string)
|
||||
payload += _encode_field(2, _WIRE_LEN, "derp 1.5.0")
|
||||
# field 3: os (string)
|
||||
payload += _encode_field(3, _WIRE_LEN, "Linux")
|
||||
# field 4: os_version (string)
|
||||
payload += _encode_field(4, _WIRE_LEN, "")
|
||||
# field 5: version_v2 (uint64) -- semantic: 1.5.0
|
||||
payload += _encode_field(5, _WIRE_VARINT, (1 << 48) | (5 << 32))
|
||||
return payload
|
||||
|
||||
|
||||
def _build_authenticate_payload(username: str, password: str) -> bytes:
|
||||
"""Build an Authenticate message payload."""
|
||||
payload = b""
|
||||
# field 1: username (string)
|
||||
payload += _encode_field(1, _WIRE_LEN, username)
|
||||
# field 2: password (string)
|
||||
if password:
|
||||
payload += _encode_field(2, _WIRE_LEN, password)
|
||||
# field 5: opus (bool/varint) -- True
|
||||
payload += _encode_field(5, _WIRE_VARINT, 1)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_ping_payload(timestamp: int) -> bytes:
|
||||
"""Build a Ping message payload."""
|
||||
return _encode_field(1, _WIRE_VARINT, timestamp)
|
||||
|
||||
|
||||
def _build_text_message_payload(
|
||||
channel_ids: list[int] | None = None,
|
||||
session_ids: list[int] | None = None,
|
||||
tree_ids: list[int] | None = None,
|
||||
message: str = "",
|
||||
) -> bytes:
|
||||
"""Build a TextMessage payload."""
|
||||
payload = b""
|
||||
for sid in (session_ids or []):
|
||||
payload += _encode_field(2, _WIRE_VARINT, sid)
|
||||
for cid in (channel_ids or []):
|
||||
payload += _encode_field(3, _WIRE_VARINT, cid)
|
||||
for tid in (tree_ids or []):
|
||||
payload += _encode_field(4, _WIRE_VARINT, tid)
|
||||
# field 5: message (string)
|
||||
payload += _encode_field(5, _WIRE_LEN, message)
|
||||
return payload
|
||||
|
||||
|
||||
def _build_mumble_message(
|
||||
fields: dict[int, list],
|
||||
users: dict[int, str],
|
||||
our_session: int,
|
||||
) -> MumbleMessage | None:
|
||||
"""Build a MumbleMessage from decoded TextMessage fields."""
|
||||
# field 5: message text
|
||||
raw_text = _field_str(fields, 5)
|
||||
if raw_text is None:
|
||||
return None
|
||||
|
||||
text = _strip_html(raw_text)
|
||||
|
||||
# field 1: actor (sender session)
|
||||
actor = _field_int(fields, 1)
|
||||
nick = users.get(actor) if actor is not None else None
|
||||
prefix = nick # use username for ACL
|
||||
|
||||
# Determine target: channel_id (field 3) or DM (field 2)
|
||||
channel_ids = _field_ints(fields, 3)
|
||||
session_ids = _field_ints(fields, 2)
|
||||
|
||||
if channel_ids:
|
||||
target = str(channel_ids[0])
|
||||
is_channel = True
|
||||
elif session_ids:
|
||||
target = "dm"
|
||||
is_channel = False
|
||||
else:
|
||||
target = None
|
||||
is_channel = True
|
||||
|
||||
return MumbleMessage(
|
||||
raw=dict(fields),
|
||||
nick=nick,
|
||||
prefix=prefix,
|
||||
text=text,
|
||||
target=target,
|
||||
is_channel=is_channel,
|
||||
params=[target or "", text],
|
||||
)
|
||||
|
||||
|
||||
# -- MumbleBot --------------------------------------------------------------
|
||||
|
||||
|
||||
class MumbleBot:
|
||||
"""Mumble bot adapter via TCP/TLS protobuf control channel (text only).
|
||||
|
||||
Exposes the same public API as :class:`derp.bot.Bot` so that
|
||||
protocol-agnostic plugins work without modification.
|
||||
All TCP goes through ``derp.http.create_connection`` (SOCKS5 proxy).
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, config: dict, registry: PluginRegistry) -> None:
|
||||
self.name = name
|
||||
self.config = config
|
||||
self.registry = registry
|
||||
self._pstate: dict = {}
|
||||
|
||||
mu_cfg = config.get("mumble", {})
|
||||
self._host: str = mu_cfg.get("host", "127.0.0.1")
|
||||
self._port: int = mu_cfg.get("port", 64738)
|
||||
self._username: str = mu_cfg.get("username", "derp")
|
||||
self._password: str = mu_cfg.get("password", "")
|
||||
self._tls_verify: bool = mu_cfg.get("tls_verify", False)
|
||||
self.nick: str = self._username
|
||||
self.prefix: str = (
|
||||
mu_cfg.get("prefix")
|
||||
or config.get("bot", {}).get("prefix", "!")
|
||||
)
|
||||
self._running = False
|
||||
self._started: float = time.monotonic()
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
self._reconnect_delay: float = 5.0
|
||||
self._admins: list[str] = [str(x) for x in mu_cfg.get("admins", [])]
|
||||
self._operators: list[str] = [str(x) for x in mu_cfg.get("operators", [])]
|
||||
self._trusted: list[str] = [str(x) for x in mu_cfg.get("trusted", [])]
|
||||
self.state = StateStore(f"data/state-{name}.db")
|
||||
|
||||
# Protocol state
|
||||
self._session: int = 0 # our session ID (from ServerSync)
|
||||
self._channels: dict[int, str] = {} # channel_id -> name
|
||||
self._users: dict[int, str] = {} # session_id -> username
|
||||
self._user_channels: dict[int, int] = {} # session_id -> channel_id
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
|
||||
rate_cfg = config.get("bot", {})
|
||||
self._bucket = _TokenBucket(
|
||||
rate=rate_cfg.get("rate_limit", 2.0),
|
||||
burst=rate_cfg.get("rate_burst", 5),
|
||||
)
|
||||
|
||||
# -- Connection ----------------------------------------------------------
|
||||
|
||||
def _create_ssl_context(self) -> ssl.SSLContext:
|
||||
"""Build an SSL context for Mumble TLS."""
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
if not self._tls_verify:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
return ctx
|
||||
|
||||
async def _connect(self) -> None:
|
||||
"""Establish TLS connection over SOCKS5 proxy."""
|
||||
loop = asyncio.get_running_loop()
|
||||
sock = await loop.run_in_executor(
|
||||
None, http.create_connection,
|
||||
(self._host, self._port),
|
||||
)
|
||||
ssl_ctx = self._create_ssl_context()
|
||||
hostname = self._host if self._tls_verify else None
|
||||
self._reader, self._writer = await asyncio.open_connection(
|
||||
sock=sock, ssl=ssl_ctx, server_hostname=hostname,
|
||||
)
|
||||
|
||||
async def _send_msg(self, msg_type: int, payload: bytes = b"") -> None:
|
||||
"""Send a framed Mumble message."""
|
||||
if self._writer is None:
|
||||
return
|
||||
data = _pack_msg(msg_type, payload)
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def _read_msg(self) -> tuple[int, bytes] | None:
|
||||
"""Read one framed Mumble message. Returns (msg_type, payload)."""
|
||||
if self._reader is None:
|
||||
return None
|
||||
header = await self._reader.readexactly(6)
|
||||
msg_type, length = _unpack_header(header)
|
||||
payload = await self._reader.readexactly(length) if length else b""
|
||||
return msg_type, payload
|
||||
|
||||
async def _close(self) -> None:
|
||||
"""Close the connection."""
|
||||
if self._writer:
|
||||
try:
|
||||
self._writer.close()
|
||||
await self._writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
self._reader = None
|
||||
self._writer = None
|
||||
|
||||
# -- Lifecycle -----------------------------------------------------------
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect and enter the message loop with reconnect backoff."""
|
||||
self._running = True
|
||||
while self._running:
|
||||
try:
|
||||
await self._connect_and_run()
|
||||
self._reconnect_delay = 5.0
|
||||
except (OSError, ConnectionError, asyncio.IncompleteReadError) as exc:
|
||||
log.error("mumble: connection lost: %s", exc)
|
||||
if self._running:
|
||||
jitter = self._reconnect_delay * 0.25 * (2 * random.random() - 1)
|
||||
delay = self._reconnect_delay + jitter
|
||||
log.info("mumble: reconnecting in %.0fs...", delay)
|
||||
await asyncio.sleep(delay)
|
||||
self._reconnect_delay = min(self._reconnect_delay * 2, 300.0)
|
||||
|
||||
async def _connect_and_run(self) -> None:
|
||||
"""Single connection lifecycle."""
|
||||
await self._connect()
|
||||
try:
|
||||
await self._handshake()
|
||||
await self._loop()
|
||||
finally:
|
||||
await self._close()
|
||||
|
||||
async def _handshake(self) -> None:
|
||||
"""Send Version and Authenticate messages."""
|
||||
# Version
|
||||
await self._send_msg(MSG_VERSION, _build_version_payload())
|
||||
# Authenticate
|
||||
await self._send_msg(
|
||||
MSG_AUTHENTICATE,
|
||||
_build_authenticate_payload(self._username, self._password),
|
||||
)
|
||||
log.info("mumble: authenticating as %s on %s:%d",
|
||||
self._username, self._host, self._port)
|
||||
|
||||
async def _loop(self) -> None:
|
||||
"""Read and dispatch messages until disconnect."""
|
||||
while self._running:
|
||||
result = await self._read_msg()
|
||||
if result is None:
|
||||
log.warning("mumble: server closed connection")
|
||||
return
|
||||
msg_type, payload = result
|
||||
await self._handle(msg_type, payload)
|
||||
|
||||
async def _handle(self, msg_type: int, payload: bytes) -> None:
|
||||
"""Dispatch a received Mumble message by type."""
|
||||
if msg_type == MSG_PING:
|
||||
fields = _decode_fields(payload)
|
||||
ts = _field_int(fields, 1) or 0
|
||||
await self._send_msg(MSG_PING, _build_ping_payload(ts))
|
||||
return
|
||||
|
||||
if msg_type == MSG_SERVER_SYNC:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
if session is not None:
|
||||
self._session = session
|
||||
welcome = _field_str(fields, 3) or ""
|
||||
log.info("mumble: connected, session=%d, welcome=%s",
|
||||
self._session, _strip_html(welcome)[:80])
|
||||
return
|
||||
|
||||
if msg_type == MSG_CHANNEL_STATE:
|
||||
fields = _decode_fields(payload)
|
||||
cid = _field_int(fields, 1)
|
||||
name = _field_str(fields, 3)
|
||||
if cid is not None and name is not None:
|
||||
self._channels[cid] = name
|
||||
return
|
||||
|
||||
if msg_type == MSG_CHANNEL_REMOVE:
|
||||
fields = _decode_fields(payload)
|
||||
cid = _field_int(fields, 1)
|
||||
if cid is not None:
|
||||
self._channels.pop(cid, None)
|
||||
return
|
||||
|
||||
if msg_type == MSG_USER_STATE:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
name = _field_str(fields, 3)
|
||||
channel_id = _field_int(fields, 5)
|
||||
if session is not None:
|
||||
if name is not None:
|
||||
self._users[session] = name
|
||||
if channel_id is not None:
|
||||
self._user_channels[session] = channel_id
|
||||
return
|
||||
|
||||
if msg_type == MSG_USER_REMOVE:
|
||||
fields = _decode_fields(payload)
|
||||
session = _field_int(fields, 1)
|
||||
if session is not None:
|
||||
self._users.pop(session, None)
|
||||
self._user_channels.pop(session, None)
|
||||
return
|
||||
|
||||
if msg_type == MSG_TEXT_MESSAGE:
|
||||
fields = _decode_fields(payload)
|
||||
msg = _build_mumble_message(fields, self._users, self._session)
|
||||
if msg is not None:
|
||||
await self._dispatch_command(msg)
|
||||
return
|
||||
|
||||
# -- Command dispatch ----------------------------------------------------
|
||||
|
||||
async def _dispatch_command(self, msg: MumbleMessage) -> None:
|
||||
"""Parse and dispatch a command from a Mumble text message."""
|
||||
text = msg.text
|
||||
if not text or not text.startswith(self.prefix):
|
||||
return
|
||||
|
||||
parts = text[len(self.prefix):].split(None, 1)
|
||||
cmd_name = parts[0].lower() if parts else ""
|
||||
handler = self._resolve_command(cmd_name)
|
||||
if handler is None:
|
||||
return
|
||||
if handler is _AMBIGUOUS:
|
||||
matches = [k for k in self.registry.commands
|
||||
if k.startswith(cmd_name)]
|
||||
names = ", ".join(self.prefix + m for m in sorted(matches))
|
||||
await self.reply(
|
||||
msg,
|
||||
f"Ambiguous command '{self.prefix}{cmd_name}': {names}",
|
||||
)
|
||||
return
|
||||
|
||||
if not self._plugin_allowed(handler.plugin, msg.target):
|
||||
return
|
||||
|
||||
required = handler.tier
|
||||
if required != "user":
|
||||
sender = self._get_tier(msg)
|
||||
if TIERS.index(sender) < TIERS.index(required):
|
||||
await self.reply(
|
||||
msg,
|
||||
f"Permission denied: {self.prefix}{cmd_name} "
|
||||
f"requires {required}",
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
await handler.callback(self, msg)
|
||||
except Exception:
|
||||
log.exception("mumble: error in command handler '%s'", cmd_name)
|
||||
|
||||
def _resolve_command(self, name: str):
|
||||
"""Resolve command name with unambiguous prefix matching.
|
||||
|
||||
Returns the Handler on exact or unique prefix match, the sentinel
|
||||
``_AMBIGUOUS`` if multiple commands match, or None if nothing matches.
|
||||
"""
|
||||
handler = self.registry.commands.get(name)
|
||||
if handler is not None:
|
||||
return handler
|
||||
matches = [v for k, v in self.registry.commands.items()
|
||||
if k.startswith(name)]
|
||||
if len(matches) == 1:
|
||||
return matches[0]
|
||||
if len(matches) > 1:
|
||||
return _AMBIGUOUS
|
||||
return None
|
||||
|
||||
def _plugin_allowed(self, plugin_name: str, channel: str | None) -> bool:
|
||||
"""Channel filtering is IRC-only; all plugins are allowed on Mumble."""
|
||||
return True
|
||||
|
||||
# -- Permission tiers ----------------------------------------------------
|
||||
|
||||
def _get_tier(self, msg: MumbleMessage) -> str:
|
||||
"""Determine permission tier from username.
|
||||
|
||||
Matches exact string comparison of username against config lists.
|
||||
"""
|
||||
if not msg.prefix:
|
||||
return "user"
|
||||
for name in self._admins:
|
||||
if msg.prefix == name:
|
||||
return "admin"
|
||||
for name in self._operators:
|
||||
if msg.prefix == name:
|
||||
return "oper"
|
||||
for name in self._trusted:
|
||||
if msg.prefix == name:
|
||||
return "trusted"
|
||||
return "user"
|
||||
|
||||
def _is_admin(self, msg: MumbleMessage) -> bool:
|
||||
"""Check if the message sender is a bot admin."""
|
||||
return self._get_tier(msg) == "admin"
|
||||
|
||||
# -- Public API for plugins ----------------------------------------------
|
||||
|
||||
async def _send_html(self, target: str, html_text: str) -> None:
|
||||
"""Send a TextMessage with pre-formatted HTML (no escaping)."""
|
||||
await self._bucket.acquire()
|
||||
try:
|
||||
channel_id = int(target)
|
||||
except (ValueError, TypeError):
|
||||
channel_id = 0 # root channel fallback
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[channel_id], message=html_text,
|
||||
)
|
||||
await self._send_msg(MSG_TEXT_MESSAGE, payload)
|
||||
|
||||
async def send(self, target: str, text: str) -> None:
|
||||
"""Send a TextMessage to a channel (HTML-escaped, rate-limited).
|
||||
|
||||
``target`` is a channel_id as string. For DMs, sends to the
|
||||
current channel instead (Mumble DMs require session IDs which
|
||||
are not stable).
|
||||
"""
|
||||
await self._send_html(target, _escape_html(text))
|
||||
|
||||
async def reply(self, msg, text: str) -> None:
|
||||
"""Reply to the source channel."""
|
||||
if msg.target and msg.target != "dm":
|
||||
await self.send(msg.target, text)
|
||||
elif msg.target == "dm":
|
||||
# Best-effort: send to root channel
|
||||
await self.send("0", text)
|
||||
|
||||
async def long_reply(
|
||||
self, msg, lines: list[str], *,
|
||||
label: str = "",
|
||||
) -> None:
|
||||
"""Reply with a list of lines; paste overflow to FlaskPaste.
|
||||
|
||||
Same overflow logic as :meth:`derp.bot.Bot.long_reply`.
|
||||
"""
|
||||
threshold = self.config.get("bot", {}).get("paste_threshold", 4)
|
||||
if not lines or not msg.target:
|
||||
return
|
||||
|
||||
if len(lines) <= threshold:
|
||||
for line in lines:
|
||||
await self.send(msg.target, line)
|
||||
return
|
||||
|
||||
# Attempt paste overflow
|
||||
fp = self.registry._modules.get("flaskpaste")
|
||||
paste_url = None
|
||||
if fp:
|
||||
full_text = "\n".join(lines)
|
||||
loop = asyncio.get_running_loop()
|
||||
paste_url = await loop.run_in_executor(
|
||||
None, fp.create_paste, self, full_text,
|
||||
)
|
||||
|
||||
if paste_url:
|
||||
preview_count = min(2, threshold - 1)
|
||||
for line in lines[:preview_count]:
|
||||
await self.send(msg.target, line)
|
||||
remaining = len(lines) - preview_count
|
||||
suffix = f" ({label})" if label else ""
|
||||
await self.send(
|
||||
msg.target,
|
||||
f"... {remaining} more lines{suffix}: {paste_url}",
|
||||
)
|
||||
else:
|
||||
for line in lines:
|
||||
await self.send(msg.target, line)
|
||||
|
||||
async def action(self, target: str, text: str) -> None:
|
||||
"""Send an action as italic HTML text."""
|
||||
await self._send_html(target, f"<i>{_escape_html(text)}</i>")
|
||||
|
||||
async def shorten_url(self, url: str) -> str:
|
||||
"""Shorten a URL via FlaskPaste. Returns original on failure."""
|
||||
fp = self.registry._modules.get("flaskpaste")
|
||||
if not fp:
|
||||
return url
|
||||
loop = asyncio.get_running_loop()
|
||||
try:
|
||||
return await loop.run_in_executor(None, fp.shorten_url, self, url)
|
||||
except Exception:
|
||||
return url
|
||||
|
||||
# -- IRC no-ops ----------------------------------------------------------
|
||||
|
||||
async def join(self, channel: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: join() is a no-op")
|
||||
|
||||
async def part(self, channel: str, reason: str = "") -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: part() is a no-op")
|
||||
|
||||
async def quit(self, reason: str = "bye") -> None:
|
||||
"""Stop the Mumble adapter."""
|
||||
self._running = False
|
||||
|
||||
async def kick(self, channel: str, nick: str, reason: str = "") -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: kick() is a no-op")
|
||||
|
||||
async def mode(self, target: str, mode_str: str, *args: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: mode() is a no-op")
|
||||
|
||||
async def set_topic(self, channel: str, topic: str) -> None:
|
||||
"""No-op: IRC-only concept."""
|
||||
log.debug("mumble: set_topic() is a no-op")
|
||||
|
||||
# -- Plugin management (delegated to registry) ---------------------------
|
||||
|
||||
def load_plugins(self, plugins_dir: str | Path | None = None) -> None:
|
||||
"""Load plugins from the configured directory."""
|
||||
if plugins_dir is None:
|
||||
plugins_dir = self.config.get("bot", {}).get(
|
||||
"plugins_dir", "plugins")
|
||||
path = Path(plugins_dir)
|
||||
self.registry.load_directory(path)
|
||||
|
||||
@property
|
||||
def plugins_dir(self) -> Path:
|
||||
"""Resolved path to the plugins directory."""
|
||||
return Path(self.config.get("bot", {}).get("plugins_dir", "plugins"))
|
||||
|
||||
def load_plugin(self, name: str) -> tuple[bool, str]:
|
||||
"""Hot-load a new plugin by name from the plugins directory."""
|
||||
if name in self.registry._modules:
|
||||
return False, f"plugin already loaded: {name}"
|
||||
path = self.plugins_dir / f"{name}.py"
|
||||
if not path.is_file():
|
||||
return False, f"{name}.py not found"
|
||||
count = self.registry.load_plugin(path)
|
||||
if count < 0:
|
||||
return False, f"failed to load {name}"
|
||||
return True, f"{count} handlers"
|
||||
|
||||
def reload_plugin(self, name: str) -> tuple[bool, str]:
|
||||
"""Reload a plugin, picking up any file changes."""
|
||||
return self.registry.reload_plugin(name)
|
||||
|
||||
def unload_plugin(self, name: str) -> tuple[bool, str]:
|
||||
"""Unload a plugin, removing all its handlers."""
|
||||
if self.registry.unload_plugin(name):
|
||||
return True, ""
|
||||
if name == "core":
|
||||
return False, "cannot unload core"
|
||||
return False, f"plugin not loaded: {name}"
|
||||
|
||||
def _spawn(self, coro, *, name: str | None = None) -> asyncio.Task:
|
||||
"""Spawn a background task and track it for cleanup."""
|
||||
task = asyncio.create_task(coro, name=name)
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard)
|
||||
return task
|
||||
891
tests/test_mumble.py
Normal file
891
tests/test_mumble.py
Normal file
@@ -0,0 +1,891 @@
|
||||
"""Tests for the Mumble adapter."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from derp.mumble import (
|
||||
_WIRE_LEN,
|
||||
_WIRE_VARINT,
|
||||
MSG_CHANNEL_REMOVE,
|
||||
MSG_CHANNEL_STATE,
|
||||
MSG_PING,
|
||||
MSG_SERVER_SYNC,
|
||||
MSG_TEXT_MESSAGE,
|
||||
MSG_USER_REMOVE,
|
||||
MSG_USER_STATE,
|
||||
MumbleBot,
|
||||
MumbleMessage,
|
||||
_build_authenticate_payload,
|
||||
_build_mumble_message,
|
||||
_build_ping_payload,
|
||||
_build_text_message_payload,
|
||||
_build_version_payload,
|
||||
_decode_fields,
|
||||
_decode_varint,
|
||||
_encode_field,
|
||||
_encode_varint,
|
||||
_escape_html,
|
||||
_field_int,
|
||||
_field_ints,
|
||||
_field_str,
|
||||
_pack_msg,
|
||||
_strip_html,
|
||||
_unpack_header,
|
||||
)
|
||||
from derp.plugin import PluginRegistry
|
||||
|
||||
# -- Helpers -----------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bot(admins=None, operators=None, trusted=None, prefix=None):
|
||||
"""Create a MumbleBot with test config."""
|
||||
config = {
|
||||
"mumble": {
|
||||
"enabled": True,
|
||||
"host": "127.0.0.1",
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"admins": admins or [],
|
||||
"operators": operators or [],
|
||||
"trusted": trusted or [],
|
||||
},
|
||||
"bot": {
|
||||
"prefix": prefix or "!",
|
||||
"paste_threshold": 4,
|
||||
"plugins_dir": "plugins",
|
||||
"rate_limit": 2.0,
|
||||
"rate_burst": 5,
|
||||
},
|
||||
}
|
||||
registry = PluginRegistry()
|
||||
bot = MumbleBot("mu-test", config, registry)
|
||||
return bot
|
||||
|
||||
|
||||
def _mu_msg(text="!ping", nick="Alice", prefix="Alice",
|
||||
target="0", is_channel=True):
|
||||
"""Create a MumbleMessage for command testing."""
|
||||
return MumbleMessage(
|
||||
raw={}, nick=nick, prefix=prefix, text=text, target=target,
|
||||
is_channel=is_channel,
|
||||
params=[target, text],
|
||||
)
|
||||
|
||||
|
||||
# -- Test helpers for registering commands -----------------------------------
|
||||
|
||||
|
||||
async def _echo_handler(bot, msg):
|
||||
"""Simple command handler that echoes text."""
|
||||
args = msg.text.split(None, 1)
|
||||
reply = args[1] if len(args) > 1 else "no args"
|
||||
await bot.reply(msg, reply)
|
||||
|
||||
|
||||
async def _admin_handler(bot, msg):
|
||||
"""Admin-only command handler."""
|
||||
await bot.reply(msg, "admin action done")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestProtobufCodec
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProtobufCodec:
|
||||
def test_encode_varint_zero(self):
|
||||
assert _encode_varint(0) == b"\x00"
|
||||
|
||||
def test_encode_varint_small(self):
|
||||
assert _encode_varint(1) == b"\x01"
|
||||
assert _encode_varint(127) == b"\x7f"
|
||||
|
||||
def test_encode_varint_two_byte(self):
|
||||
assert _encode_varint(128) == b"\x80\x01"
|
||||
assert _encode_varint(300) == b"\xac\x02"
|
||||
|
||||
def test_encode_varint_large(self):
|
||||
# 16384 = 0x4000
|
||||
encoded = _encode_varint(16384)
|
||||
val, _ = _decode_varint(encoded, 0)
|
||||
assert val == 16384
|
||||
|
||||
def test_decode_varint_zero(self):
|
||||
val, off = _decode_varint(b"\x00", 0)
|
||||
assert val == 0
|
||||
assert off == 1
|
||||
|
||||
def test_decode_varint_small(self):
|
||||
val, off = _decode_varint(b"\x01", 0)
|
||||
assert val == 1
|
||||
assert off == 1
|
||||
|
||||
def test_decode_varint_multi_byte(self):
|
||||
val, off = _decode_varint(b"\xac\x02", 0)
|
||||
assert val == 300
|
||||
assert off == 2
|
||||
|
||||
def test_varint_roundtrip(self):
|
||||
for n in [0, 1, 127, 128, 300, 16384, 2**21, 2**28]:
|
||||
encoded = _encode_varint(n)
|
||||
decoded, _ = _decode_varint(encoded, 0)
|
||||
assert decoded == n, f"roundtrip failed for {n}"
|
||||
|
||||
def test_encode_field_varint(self):
|
||||
# field 1, wire type 0, value 42
|
||||
data = _encode_field(1, _WIRE_VARINT, 42)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[1] == [42]
|
||||
|
||||
def test_encode_field_string(self):
|
||||
data = _encode_field(5, _WIRE_LEN, "hello")
|
||||
fields = _decode_fields(data)
|
||||
assert fields[5] == [b"hello"]
|
||||
|
||||
def test_encode_field_bytes(self):
|
||||
data = _encode_field(3, _WIRE_LEN, b"\x00\x01\x02")
|
||||
fields = _decode_fields(data)
|
||||
assert fields[3] == [b"\x00\x01\x02"]
|
||||
|
||||
def test_decode_multiple_fields(self):
|
||||
data = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(2, _WIRE_LEN, "test")
|
||||
+ _encode_field(3, _WIRE_VARINT, 99)
|
||||
)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[1] == [10]
|
||||
assert fields[2] == [b"test"]
|
||||
assert fields[3] == [99]
|
||||
|
||||
def test_decode_repeated_fields(self):
|
||||
data = (
|
||||
_encode_field(3, _WIRE_VARINT, 1)
|
||||
+ _encode_field(3, _WIRE_VARINT, 2)
|
||||
+ _encode_field(3, _WIRE_VARINT, 3)
|
||||
)
|
||||
fields = _decode_fields(data)
|
||||
assert fields[3] == [1, 2, 3]
|
||||
|
||||
def test_pack_unpack_header(self):
|
||||
packed = _pack_msg(11, b"hello")
|
||||
msg_type, length = _unpack_header(packed[:6])
|
||||
assert msg_type == 11
|
||||
assert length == 5
|
||||
assert packed[6:] == b"hello"
|
||||
|
||||
def test_pack_empty_payload(self):
|
||||
packed = _pack_msg(3)
|
||||
assert len(packed) == 6
|
||||
msg_type, length = _unpack_header(packed)
|
||||
assert msg_type == 3
|
||||
assert length == 0
|
||||
|
||||
def test_field_str(self):
|
||||
fields = {5: [b"hello"]}
|
||||
assert _field_str(fields, 5) == "hello"
|
||||
assert _field_str(fields, 1) is None
|
||||
|
||||
def test_field_int(self):
|
||||
fields = {1: [42]}
|
||||
assert _field_int(fields, 1) == 42
|
||||
assert _field_int(fields, 2) is None
|
||||
|
||||
def test_field_ints(self):
|
||||
fields = {3: [1, 2, 3]}
|
||||
assert _field_ints(fields, 3) == [1, 2, 3]
|
||||
assert _field_ints(fields, 9) == []
|
||||
|
||||
def test_decode_empty(self):
|
||||
fields = _decode_fields(b"")
|
||||
assert fields == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleMessage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleMessage:
|
||||
def test_defaults(self):
|
||||
msg = MumbleMessage(raw={}, nick=None, prefix=None, text=None,
|
||||
target=None)
|
||||
assert msg.is_channel is True
|
||||
assert msg.command == "PRIVMSG"
|
||||
assert msg.params == []
|
||||
assert msg.tags == {}
|
||||
|
||||
def test_custom_values(self):
|
||||
msg = MumbleMessage(
|
||||
raw={"field": 1}, nick="Alice", prefix="Alice",
|
||||
text="hello", target="0", is_channel=True,
|
||||
command="PRIVMSG", params=["0", "hello"],
|
||||
tags={"key": "val"},
|
||||
)
|
||||
assert msg.nick == "Alice"
|
||||
assert msg.prefix == "Alice"
|
||||
assert msg.text == "hello"
|
||||
assert msg.target == "0"
|
||||
assert msg.tags == {"key": "val"}
|
||||
|
||||
def test_duck_type_compat(self):
|
||||
"""MumbleMessage has the same attribute names as IRC Message."""
|
||||
msg = _mu_msg()
|
||||
attrs = ["raw", "nick", "prefix", "text", "target",
|
||||
"is_channel", "command", "params", "tags"]
|
||||
for attr in attrs:
|
||||
assert hasattr(msg, attr), f"missing attribute: {attr}"
|
||||
|
||||
def test_dm_message(self):
|
||||
msg = _mu_msg(target="dm", is_channel=False)
|
||||
assert msg.is_channel is False
|
||||
assert msg.target == "dm"
|
||||
|
||||
def test_prefix_is_username(self):
|
||||
msg = _mu_msg(prefix="admin_user")
|
||||
assert msg.prefix == "admin_user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestBuildMumbleMessage
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildMumbleMessage:
|
||||
def test_channel_message(self):
|
||||
fields = {
|
||||
1: [42], # actor session
|
||||
3: [5], # channel_id
|
||||
5: [b"<b>Hello</b>"], # message HTML
|
||||
}
|
||||
users = {42: "Alice"}
|
||||
msg = _build_mumble_message(fields, users, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.nick == "Alice"
|
||||
assert msg.prefix == "Alice"
|
||||
assert msg.text == "Hello"
|
||||
assert msg.target == "5"
|
||||
assert msg.is_channel is True
|
||||
|
||||
def test_dm_message(self):
|
||||
fields = {
|
||||
1: [42], # actor session
|
||||
2: [1], # target session (DM)
|
||||
5: [b"secret"], # message
|
||||
}
|
||||
users = {42: "Bob"}
|
||||
msg = _build_mumble_message(fields, users, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.target == "dm"
|
||||
assert msg.is_channel is False
|
||||
|
||||
def test_missing_sender(self):
|
||||
fields = {
|
||||
3: [0], # channel_id
|
||||
5: [b"anonymous"], # message
|
||||
}
|
||||
msg = _build_mumble_message(fields, {}, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.nick is None
|
||||
assert msg.prefix is None
|
||||
|
||||
def test_missing_message(self):
|
||||
fields = {
|
||||
1: [42],
|
||||
3: [0],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1)
|
||||
assert msg is None
|
||||
|
||||
def test_html_stripped(self):
|
||||
fields = {
|
||||
1: [1],
|
||||
3: [0],
|
||||
5: [b"<a href='link'>click & go</a>"],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {1: "User"}, our_session=0)
|
||||
assert msg is not None
|
||||
assert msg.text == "click & go"
|
||||
|
||||
def test_no_target(self):
|
||||
fields = {
|
||||
1: [42],
|
||||
5: [b"orphan message"],
|
||||
}
|
||||
msg = _build_mumble_message(fields, {42: "Alice"}, our_session=1)
|
||||
assert msg is not None
|
||||
assert msg.target is None
|
||||
assert msg.is_channel is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestHtmlHelpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHtmlHelpers:
|
||||
def test_strip_html_simple(self):
|
||||
assert _strip_html("<b>bold</b>") == "bold"
|
||||
|
||||
def test_strip_html_entities(self):
|
||||
assert _strip_html("& < > "") == '& < > "'
|
||||
|
||||
def test_strip_html_nested(self):
|
||||
assert _strip_html("<p><b>hello</b> <i>world</i></p>") == "hello world"
|
||||
|
||||
def test_strip_html_plain(self):
|
||||
assert _strip_html("no tags here") == "no tags here"
|
||||
|
||||
def test_escape_html(self):
|
||||
assert _escape_html("<script>alert('xss')") == "<script>alert('xss')"
|
||||
|
||||
def test_escape_html_ampersand(self):
|
||||
assert _escape_html("a & b") == "a & b"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotReply
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotReply:
|
||||
def test_send_builds_text_message(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
asyncio.run(bot.send("5", "hello"))
|
||||
assert len(sent) == 1
|
||||
assert sent[0][0] == MSG_TEXT_MESSAGE
|
||||
# Verify payload contains the message
|
||||
fields = _decode_fields(sent[0][1])
|
||||
text = _field_str(fields, 5)
|
||||
assert text == "hello"
|
||||
assert _field_ints(fields, 3) == [5]
|
||||
|
||||
def test_send_escapes_html(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
asyncio.run(bot.send("0", "<script>alert('xss')"))
|
||||
fields = _decode_fields(sent[0][1])
|
||||
text = _field_str(fields, 5)
|
||||
assert "<script>" not in text
|
||||
assert "<script>" in text
|
||||
|
||||
def test_reply_sends_to_target(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(target="5")
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append((target, text))
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot.reply(msg, "pong"))
|
||||
assert sent == [("5", "pong")]
|
||||
|
||||
def test_reply_dm_fallback(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(target="dm", is_channel=False)
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append((target, text))
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot.reply(msg, "dm reply"))
|
||||
# DM falls back to root channel "0"
|
||||
assert sent == [("0", "dm reply")]
|
||||
|
||||
def test_long_reply_under_threshold(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg()
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot.long_reply(msg, ["a", "b", "c"]))
|
||||
assert sent == ["a", "b", "c"]
|
||||
|
||||
def test_long_reply_over_threshold_no_paste(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg()
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot.long_reply(msg, ["a", "b", "c", "d", "e"]))
|
||||
assert sent == ["a", "b", "c", "d", "e"]
|
||||
|
||||
def test_long_reply_empty(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg()
|
||||
with patch.object(bot, "send") as mock_send:
|
||||
asyncio.run(bot.long_reply(msg, []))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_action_format(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send_html(target, html_text):
|
||||
sent.append((target, html_text))
|
||||
|
||||
with patch.object(bot, "_send_html", side_effect=_fake_send_html):
|
||||
asyncio.run(bot.action("0", "does a thing"))
|
||||
assert sent == [("0", "<i>does a thing</i>")]
|
||||
|
||||
def test_action_escapes_content(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_send_html(target, html_text):
|
||||
sent.append((target, html_text))
|
||||
|
||||
with patch.object(bot, "_send_html", side_effect=_fake_send_html):
|
||||
asyncio.run(bot.action("0", "<script>"))
|
||||
assert sent == [("0", "<i><script></i>")]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotDispatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotDispatch:
|
||||
def test_dispatch_known_command(self):
|
||||
bot = _make_bot()
|
||||
bot.registry.register_command(
|
||||
"echo", _echo_handler, help="echo", plugin="test")
|
||||
msg = _mu_msg(text="!echo world")
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
assert sent == ["world"]
|
||||
|
||||
def test_dispatch_unknown_command(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(text="!nonexistent")
|
||||
with patch.object(bot, "send") as mock_send:
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_dispatch_no_prefix(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(text="just a message")
|
||||
with patch.object(bot, "send") as mock_send:
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_dispatch_empty_text(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(text="")
|
||||
with patch.object(bot, "send") as mock_send:
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_dispatch_none_text(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg()
|
||||
msg.text = None
|
||||
with patch.object(bot, "send") as mock_send:
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
mock_send.assert_not_called()
|
||||
|
||||
def test_dispatch_ambiguous(self):
|
||||
bot = _make_bot()
|
||||
bot.registry.register_command("ping", _echo_handler, plugin="test")
|
||||
bot.registry.register_command("plugins", _echo_handler, plugin="test")
|
||||
msg = _mu_msg(text="!p")
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
assert len(sent) == 1
|
||||
assert "Ambiguous" in sent[0]
|
||||
|
||||
def test_dispatch_tier_denied(self):
|
||||
bot = _make_bot()
|
||||
bot.registry.register_command(
|
||||
"secret", _admin_handler, plugin="test", tier="admin")
|
||||
msg = _mu_msg(text="!secret", prefix="nobody")
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
assert len(sent) == 1
|
||||
assert "Permission denied" in sent[0]
|
||||
|
||||
def test_dispatch_tier_allowed(self):
|
||||
bot = _make_bot(admins=["Alice"])
|
||||
bot.registry.register_command(
|
||||
"secret", _admin_handler, plugin="test", tier="admin")
|
||||
msg = _mu_msg(text="!secret", prefix="Alice")
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
assert sent == ["admin action done"]
|
||||
|
||||
def test_dispatch_prefix_match(self):
|
||||
bot = _make_bot()
|
||||
bot.registry.register_command("echo", _echo_handler, plugin="test")
|
||||
msg = _mu_msg(text="!ec hello")
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._dispatch_command(msg))
|
||||
assert sent == ["hello"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotTier
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotTier:
|
||||
def test_admin_tier(self):
|
||||
bot = _make_bot(admins=["AdminUser"])
|
||||
msg = _mu_msg(prefix="AdminUser")
|
||||
assert bot._get_tier(msg) == "admin"
|
||||
|
||||
def test_oper_tier(self):
|
||||
bot = _make_bot(operators=["OperUser"])
|
||||
msg = _mu_msg(prefix="OperUser")
|
||||
assert bot._get_tier(msg) == "oper"
|
||||
|
||||
def test_trusted_tier(self):
|
||||
bot = _make_bot(trusted=["TrustedUser"])
|
||||
msg = _mu_msg(prefix="TrustedUser")
|
||||
assert bot._get_tier(msg) == "trusted"
|
||||
|
||||
def test_user_tier_default(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(prefix="RandomUser")
|
||||
assert bot._get_tier(msg) == "user"
|
||||
|
||||
def test_no_prefix(self):
|
||||
bot = _make_bot(admins=["Admin"])
|
||||
msg = _mu_msg()
|
||||
msg.prefix = None
|
||||
assert bot._get_tier(msg) == "user"
|
||||
|
||||
def test_is_admin_true(self):
|
||||
bot = _make_bot(admins=["Admin"])
|
||||
msg = _mu_msg(prefix="Admin")
|
||||
assert bot._is_admin(msg) is True
|
||||
|
||||
def test_is_admin_false(self):
|
||||
bot = _make_bot()
|
||||
msg = _mu_msg(prefix="Nobody")
|
||||
assert bot._is_admin(msg) is False
|
||||
|
||||
def test_priority_order(self):
|
||||
"""Admin takes priority over oper and trusted."""
|
||||
bot = _make_bot(admins=["User"], operators=["User"], trusted=["User"])
|
||||
msg = _mu_msg(prefix="User")
|
||||
assert bot._get_tier(msg) == "admin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotNoOps
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotNoOps:
|
||||
def test_join_noop(self):
|
||||
bot = _make_bot()
|
||||
asyncio.run(bot.join("#channel"))
|
||||
|
||||
def test_part_noop(self):
|
||||
bot = _make_bot()
|
||||
asyncio.run(bot.part("#channel", "reason"))
|
||||
|
||||
def test_kick_noop(self):
|
||||
bot = _make_bot()
|
||||
asyncio.run(bot.kick("#channel", "nick", "reason"))
|
||||
|
||||
def test_mode_noop(self):
|
||||
bot = _make_bot()
|
||||
asyncio.run(bot.mode("#channel", "+o", "nick"))
|
||||
|
||||
def test_set_topic_noop(self):
|
||||
bot = _make_bot()
|
||||
asyncio.run(bot.set_topic("#channel", "new topic"))
|
||||
|
||||
def test_quit_stops(self):
|
||||
bot = _make_bot()
|
||||
bot._running = True
|
||||
asyncio.run(bot.quit())
|
||||
assert bot._running is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotProtocol
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotProtocol:
|
||||
def test_version_payload_structure(self):
|
||||
payload = _build_version_payload()
|
||||
fields = _decode_fields(payload)
|
||||
# field 1: version_v1 (uint32)
|
||||
assert _field_int(fields, 1) == (1 << 16) | (5 << 8)
|
||||
# field 2: release (string)
|
||||
assert _field_str(fields, 2) == "derp 1.5.0"
|
||||
# field 3: os (string)
|
||||
assert _field_str(fields, 3) == "Linux"
|
||||
|
||||
def test_authenticate_payload_structure(self):
|
||||
payload = _build_authenticate_payload("testuser", "testpass")
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 1) == "testuser"
|
||||
assert _field_str(fields, 2) == "testpass"
|
||||
# field 5: opus (bool=1)
|
||||
assert _field_int(fields, 5) == 1
|
||||
|
||||
def test_authenticate_no_password(self):
|
||||
payload = _build_authenticate_payload("testuser", "")
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 1) == "testuser"
|
||||
assert 2 not in fields # no password field
|
||||
|
||||
def test_ping_payload_roundtrip(self):
|
||||
payload = _build_ping_payload(123456789)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_int(fields, 1) == 123456789
|
||||
|
||||
def test_text_message_payload(self):
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[5], message="<b>hello</b>",
|
||||
)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_str(fields, 5) == "<b>hello</b>"
|
||||
assert _field_ints(fields, 3) == [5]
|
||||
|
||||
def test_text_message_multiple_channels(self):
|
||||
payload = _build_text_message_payload(
|
||||
channel_ids=[1, 2, 3], message="broadcast",
|
||||
)
|
||||
fields = _decode_fields(payload)
|
||||
assert _field_ints(fields, 3) == [1, 2, 3]
|
||||
|
||||
def test_handle_ping_echoes(self):
|
||||
bot = _make_bot()
|
||||
sent: list[tuple[int, bytes]] = []
|
||||
|
||||
async def _fake_send_msg(msg_type, payload=b""):
|
||||
sent.append((msg_type, payload))
|
||||
|
||||
ping_payload = _build_ping_payload(42)
|
||||
with patch.object(bot, "_send_msg", side_effect=_fake_send_msg):
|
||||
asyncio.run(bot._handle(MSG_PING, ping_payload))
|
||||
assert len(sent) == 1
|
||||
assert sent[0][0] == MSG_PING
|
||||
fields = _decode_fields(sent[0][1])
|
||||
assert _field_int(fields, 1) == 42
|
||||
|
||||
def test_handle_server_sync(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 99)
|
||||
+ _encode_field(3, _WIRE_LEN, "Welcome!")
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_SERVER_SYNC, payload))
|
||||
assert bot._session == 99
|
||||
|
||||
def test_handle_channel_state(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 5)
|
||||
+ _encode_field(3, _WIRE_LEN, "General")
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_CHANNEL_STATE, payload))
|
||||
assert bot._channels[5] == "General"
|
||||
|
||||
def test_handle_channel_remove(self):
|
||||
bot = _make_bot()
|
||||
bot._channels[5] = "General"
|
||||
payload = _encode_field(1, _WIRE_VARINT, 5)
|
||||
asyncio.run(bot._handle(MSG_CHANNEL_REMOVE, payload))
|
||||
assert 5 not in bot._channels
|
||||
|
||||
def test_handle_user_state(self):
|
||||
bot = _make_bot()
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(3, _WIRE_LEN, "Alice")
|
||||
+ _encode_field(5, _WIRE_VARINT, 0)
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_USER_STATE, payload))
|
||||
assert bot._users[10] == "Alice"
|
||||
assert bot._user_channels[10] == 0
|
||||
|
||||
def test_handle_user_state_channel_change(self):
|
||||
bot = _make_bot()
|
||||
bot._users[10] = "Alice"
|
||||
bot._user_channels[10] = 0
|
||||
# User moves to channel 5
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 10)
|
||||
+ _encode_field(5, _WIRE_VARINT, 5)
|
||||
)
|
||||
asyncio.run(bot._handle(MSG_USER_STATE, payload))
|
||||
assert bot._users[10] == "Alice" # name unchanged
|
||||
assert bot._user_channels[10] == 5
|
||||
|
||||
def test_handle_user_remove(self):
|
||||
bot = _make_bot()
|
||||
bot._users[10] = "Alice"
|
||||
bot._user_channels[10] = 0
|
||||
payload = _encode_field(1, _WIRE_VARINT, 10)
|
||||
asyncio.run(bot._handle(MSG_USER_REMOVE, payload))
|
||||
assert 10 not in bot._users
|
||||
assert 10 not in bot._user_channels
|
||||
|
||||
def test_handle_text_message_dispatch(self):
|
||||
bot = _make_bot()
|
||||
bot._users[42] = "Alice"
|
||||
bot.registry.register_command(
|
||||
"echo", _echo_handler, help="echo", plugin="test")
|
||||
payload = (
|
||||
_encode_field(1, _WIRE_VARINT, 42) # actor
|
||||
+ _encode_field(3, _WIRE_VARINT, 0) # channel_id
|
||||
+ _encode_field(5, _WIRE_LEN, "!echo test") # message
|
||||
)
|
||||
sent: list[str] = []
|
||||
|
||||
async def _fake_send(target, text):
|
||||
sent.append(text)
|
||||
|
||||
with patch.object(bot, "send", side_effect=_fake_send):
|
||||
asyncio.run(bot._handle(MSG_TEXT_MESSAGE, payload))
|
||||
assert sent == ["test"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestPluginManagement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPluginManagement:
|
||||
def test_load_plugin_not_found(self):
|
||||
bot = _make_bot()
|
||||
ok, msg = bot.load_plugin("nonexistent_xyz")
|
||||
assert ok is False
|
||||
assert "not found" in msg
|
||||
|
||||
def test_load_plugin_already_loaded(self):
|
||||
bot = _make_bot()
|
||||
bot.registry._modules["test"] = object()
|
||||
ok, msg = bot.load_plugin("test")
|
||||
assert ok is False
|
||||
assert "already loaded" in msg
|
||||
|
||||
def test_unload_core_refused(self):
|
||||
bot = _make_bot()
|
||||
ok, msg = bot.unload_plugin("core")
|
||||
assert ok is False
|
||||
assert "cannot unload core" in msg
|
||||
|
||||
def test_unload_not_loaded(self):
|
||||
bot = _make_bot()
|
||||
ok, msg = bot.unload_plugin("nonexistent")
|
||||
assert ok is False
|
||||
assert "not loaded" in msg
|
||||
|
||||
def test_reload_delegates(self):
|
||||
bot = _make_bot()
|
||||
ok, msg = bot.reload_plugin("nonexistent")
|
||||
assert ok is False
|
||||
assert "not loaded" in msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TestMumbleBotConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMumbleBotConfig:
|
||||
def test_prefix_from_mumble_section(self):
|
||||
config = {
|
||||
"mumble": {
|
||||
"enabled": True,
|
||||
"host": "127.0.0.1",
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"prefix": "/",
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
},
|
||||
"bot": {"prefix": "!", "rate_limit": 2.0, "rate_burst": 5},
|
||||
}
|
||||
bot = MumbleBot("test", config, PluginRegistry())
|
||||
assert bot.prefix == "/"
|
||||
|
||||
def test_prefix_falls_back_to_bot(self):
|
||||
config = {
|
||||
"mumble": {
|
||||
"enabled": True,
|
||||
"host": "127.0.0.1",
|
||||
"port": 64738,
|
||||
"username": "derp",
|
||||
"password": "",
|
||||
"tls_verify": False,
|
||||
"admins": [],
|
||||
"operators": [],
|
||||
"trusted": [],
|
||||
},
|
||||
"bot": {"prefix": "!", "rate_limit": 2.0, "rate_burst": 5},
|
||||
}
|
||||
bot = MumbleBot("test", config, PluginRegistry())
|
||||
assert bot.prefix == "!"
|
||||
|
||||
def test_admins_coerced_to_str(self):
|
||||
bot = _make_bot(admins=[111, 222])
|
||||
assert bot._admins == ["111", "222"]
|
||||
|
||||
def test_default_port(self):
|
||||
bot = _make_bot()
|
||||
assert bot._port == 64738
|
||||
|
||||
def test_tls_verify_default(self):
|
||||
bot = _make_bot()
|
||||
assert bot._tls_verify is False
|
||||
|
||||
def test_nick_from_username(self):
|
||||
bot = _make_bot()
|
||||
assert bot.nick == "derp"
|
||||
Reference in New Issue
Block a user