diff --git a/plugins/chanmgmt.py b/plugins/chanmgmt.py index c6f25ba..3155411 100644 --- a/plugins/chanmgmt.py +++ b/plugins/chanmgmt.py @@ -1,6 +1,10 @@ -"""Channel management: kick, ban, unban, topic, mode.""" +"""Channel management: kick, ban, unban, topic, mode, invite-join.""" -from derp.plugin import command +import logging + +from derp.plugin import command, event + +log = logging.getLogger(__name__) def _require_channel(message): @@ -79,3 +83,16 @@ async def cmd_mode(bot, message): mode_str = parts[1] args = parts[2:] await bot.mode(message.target, mode_str, *args) + + +@event("INVITE") +async def on_invite(bot, message): + """Join a channel when invited by an admin or IRC operator.""" + if not bot._is_admin(message): + log.info("ignoring invite from non-admin %s", message.nick) + return + channel = message.params[1] if len(message.params) > 1 else None + if not channel: + return + log.info("accepting invite to %s from %s", channel, message.nick) + await bot.join(channel) diff --git a/tests/test_chanmgmt.py b/tests/test_chanmgmt.py new file mode 100644 index 0000000..f39dd57 --- /dev/null +++ b/tests/test_chanmgmt.py @@ -0,0 +1,81 @@ +"""Tests for the chanmgmt plugin (invite-join).""" + +import asyncio +import importlib.util +import sys +from pathlib import Path + +from derp.irc import Message + +# plugins/ is not a Python package -- load the module from file path +_spec = importlib.util.spec_from_file_location( + "plugins.chanmgmt", + Path(__file__).resolve().parent.parent / "plugins" / "chanmgmt.py", +) +_mod = importlib.util.module_from_spec(_spec) +sys.modules[_spec.name] = _mod +_spec.loader.exec_module(_mod) + +from plugins.chanmgmt import on_invite # noqa: E402 # isort: skip + + +# -- Helpers ----------------------------------------------------------------- + +class _FakeConn: + """Minimal connection stand-in.""" + + def __init__(self): + self.sent: list[str] = [] + + async def send(self, raw: str) -> None: + self.sent.append(raw) + + +class _FakeBot: + """Minimal bot stand-in.""" + + def __init__(self, *, admin: bool = False): + self.joined: list[str] = [] + self._admin = admin + self.conn = _FakeConn() + + def _is_admin(self, message) -> bool: + return self._admin + + async def join(self, channel: str) -> None: + self.joined.append(channel) + + +def _invite(nick: str, channel: str) -> Message: + """Create an INVITE message: :nick!user@host INVITE botname #channel.""" + return Message( + raw="", + prefix=f"{nick}!~{nick}@host", + nick=nick, + command="INVITE", + params=["derp", channel], + tags={}, + ) + + +# -- Tests ------------------------------------------------------------------- + +class TestOnInvite: + def test_admin_joins(self): + bot = _FakeBot(admin=True) + asyncio.run(on_invite(bot, _invite("oper", "#secret"))) + assert "#secret" in bot.joined + + def test_non_admin_ignored(self): + bot = _FakeBot(admin=False) + asyncio.run(on_invite(bot, _invite("rando", "#trap"))) + assert bot.joined == [] + + def test_missing_channel_ignored(self): + bot = _FakeBot(admin=True) + msg = Message( + raw="", prefix="oper!~oper@host", nick="oper", + command="INVITE", params=["derp"], tags={}, + ) + asyncio.run(on_invite(bot, msg)) + assert bot.joined == []