diff --git a/config/bouncer.example.toml b/config/bouncer.example.toml index 54c4b2c..04a2f36 100644 --- a/config/bouncer.example.toml +++ b/config/bouncer.example.toml @@ -49,6 +49,7 @@ port = 6697 tls = true # nick = "mynick" # optional: override host-derived nick channels = ["#test"] +# channel_keys = { "#secret" = "hunter2" } # keys for +k channels autojoin = true # [networks.oftc] diff --git a/src/bouncer/commands.py b/src/bouncer/commands.py index 998fbc4..8da3d81 100644 --- a/src/bouncer/commands.py +++ b/src/bouncer/commands.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import TYPE_CHECKING from bouncer.network import State +from bouncer.notify import Notifier if TYPE_CHECKING: from bouncer.client import Client @@ -448,14 +449,15 @@ def _cmd_version() -> list[str]: # --- Config Management --- -async def _cmd_rehash(router: Router) -> list[str]: - """Reload config, add/remove networks (proxy/bind unchanged).""" - if not CONFIG_PATH: - return ["[REHASH] config path not set"] +async def rehash(router: Router, config_path: Path) -> list[str]: + """Reload config and apply changes. Returns status lines. + Reusable core -- called by both the REHASH command and SIGHUP handler. + """ from bouncer.config import load + try: - new_cfg = load(CONFIG_PATH) + new_cfg = load(config_path) except Exception as exc: return [f"[REHASH] config error: {exc}"] @@ -493,16 +495,48 @@ async def _cmd_rehash(router: Router) -> list[str]: else: # Update mutable config fields old_net.cfg.channels = new_net_cfg.channels + old_net.cfg.channel_keys = new_net_cfg.channel_keys old_net.cfg.nick = new_net_cfg.nick old_net.cfg.password = new_net_cfg.password lines.append(f" unchanged: {name}") + # Propagate bouncer-level settings to live objects + old_b = router.config.bouncer + new_b = new_cfg.bouncer + + # Warn about immutable fields + for field_name in ("bind", "port", "password", "client_tls"): + old_val = getattr(old_b, field_name) + new_val = getattr(new_b, field_name) + if old_val != new_val: + lines.append(f" warning: {field_name} changed (restart required)") + + # Update notifier settings + router._notifier = Notifier(new_b, new_cfg.proxy) + + # Update farm settings + farm_was_enabled = old_b.farm_enabled + router._farm._cfg = new_b + if new_b.farm_enabled and not farm_was_enabled: + await router._farm.start() + lines.append(" farm: started") + elif not new_b.farm_enabled and farm_was_enabled: + await router._farm.stop() + lines.append(" farm: stopped") + router.config = new_cfg lines.append(f" {len(new_cfg.networks)} network(s) loaded") return lines +async def _cmd_rehash(router: Router) -> list[str]: + """Reload config, add/remove networks (proxy/bind unchanged).""" + if not CONFIG_PATH: + return ["[REHASH] config path not set"] + return await rehash(router, CONFIG_PATH) + + async def _cmd_addnetwork(router: Router, arg: str) -> list[str]: """Create a network at runtime from key=value pairs.""" from bouncer.config import NetworkConfig @@ -510,7 +544,7 @@ async def _cmd_addnetwork(router: Router, arg: str) -> list[str]: parts = arg.split() if not parts: return ["Usage: ADDNETWORK host= [port=N] [tls=yes|no]", - " [nick=N] [channels=#a,#b] [password=P]"] + " [nick=N] [channels=#a,#b] [channel_keys=#c=key,...] [password=P]"] name = parts[0].lower() if "/" in name: @@ -533,6 +567,14 @@ async def _cmd_addnetwork(router: Router, arg: str) -> list[str]: port = int(kvs.get("port", str(default_port))) channels = kvs.get("channels", "").split(",") if kvs.get("channels") else [] + # Parse channel_keys: #secret=hunter2,#vip=pass + channel_keys: dict[str, str] = {} + if kvs.get("channel_keys"): + for pair in kvs["channel_keys"].split(","): + if "=" in pair: + ch, k = pair.split("=", 1) + channel_keys[ch] = k + cfg = NetworkConfig( name=name, host=kvs["host"], @@ -540,6 +582,7 @@ async def _cmd_addnetwork(router: Router, arg: str) -> list[str]: tls=tls, nick=kvs.get("nick", ""), channels=channels, + channel_keys=channel_keys, password=kvs.get("password"), auth_service=kvs.get("auth_service", "nickserv"), ) @@ -564,15 +607,15 @@ async def _cmd_delnetwork(router: Router, arg: str) -> list[str]: def _cmd_autojoin(router: Router, arg: str) -> list[str]: """Add or remove a channel from a network's autojoin list.""" - parts = arg.split(None, 1) + parts = arg.split() if len(parts) < 2: - return ["Usage: AUTOJOIN +#channel | -#channel"] + return ["Usage: AUTOJOIN +#channel [key] | -#channel"] net, err = _resolve_network(router, parts[0]) if err: return err - spec = parts[1].strip() + spec = parts[1] if not spec or spec[0] not in ("+", "-"): return ["Channel must start with + (add) or - (remove)"] @@ -581,15 +624,22 @@ def _cmd_autojoin(router: Router, arg: str) -> list[str]: if not channel: return ["Channel name required after +/-"] + key = parts[2] if len(parts) >= 3 and action == "+" else "" + lines = [f"[AUTOJOIN] {net.cfg.name}"] if action == "+": if channel not in net.cfg.channels: net.cfg.channels.append(channel) + if key: + net.cfg.channel_keys[channel] = key lines.append(f" added: {channel}") # Join immediately if network is ready if net.ready: - asyncio.create_task(net.send_raw("JOIN", channel)) + if key: + asyncio.create_task(net.send_raw("JOIN", channel, key)) + else: + asyncio.create_task(net.send_raw("JOIN", channel)) lines.append(f" joining {channel}") else: try: @@ -597,6 +647,7 @@ def _cmd_autojoin(router: Router, arg: str) -> list[str]: lines.append(f" removed: {channel}") except ValueError: lines.append(f" {channel} not in autojoin list") + net.cfg.channel_keys.pop(channel, None) lines.append(f" autojoin: {', '.join(net.cfg.channels) or '(empty)'}") return lines diff --git a/src/bouncer/config.py b/src/bouncer/config.py index 518f8e3..043b789 100644 --- a/src/bouncer/config.py +++ b/src/bouncer/config.py @@ -43,6 +43,7 @@ class NetworkConfig: user: str = "" realname: str = "" channels: list[str] = field(default_factory=list) + channel_keys: dict[str, str] = field(default_factory=dict) autojoin: bool = False password: str | None = None proxy_host: str | None = None @@ -167,6 +168,7 @@ def load(path: Path) -> Config: user=net_raw.get("user", ""), realname=net_raw.get("realname", ""), channels=net_raw.get("channels", []), + channel_keys=dict(net_raw.get("channel_keys", {})), autojoin=net_raw.get("autojoin", True), password=net_raw.get("password"), proxy_host=net_raw.get("proxy_host"), diff --git a/src/bouncer/network.py b/src/bouncer/network.py index cde5cad..4b4c7bf 100644 --- a/src/bouncer/network.py +++ b/src/bouncer/network.py @@ -1232,7 +1232,11 @@ class Network: # Rejoin after a brief delay await asyncio.sleep(self.bouncer_cfg.rejoin_delay) if channel in set(self.cfg.channels) and self._running and self.ready: - await self.send_raw("JOIN", channel) + key = self.cfg.channel_keys.get(channel, "") + if key: + await self.send_raw("JOIN", channel, key) + else: + await self.send_raw("JOIN", channel) # Forward to router if self.on_message: diff --git a/tests/test_commands.py b/tests/test_commands.py index 99e2992..c3e0772 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -14,7 +14,8 @@ from bouncer.network import State def _make_network(name: str, state: State, nick: str = "testnick", host: str | None = None, channels: set[str] | None = None, - topics: dict[str, str] | None = None) -> MagicMock: + topics: dict[str, str] | None = None, + channel_keys: dict[str, str] | None = None) -> MagicMock: """Create a mock Network.""" net = MagicMock() net.cfg.name = name @@ -22,6 +23,7 @@ def _make_network(name: str, state: State, nick: str = "testnick", net.cfg.port = 6697 net.cfg.tls = True net.cfg.channels = list(channels) if channels else [] + net.cfg.channel_keys = dict(channel_keys) if channel_keys else {} net.cfg.nick = nick net.cfg.password = None net.state = state @@ -514,6 +516,11 @@ class TestRehash: old_net = _make_network("libera", State.READY) router = _make_router(old_net) + router._notifier = MagicMock() + router._farm = MagicMock() + router._farm._cfg = BouncerConfig() + router._farm.start = AsyncMock() + router._farm.stop = AsyncMock() new_cfg = Config( bouncer=BouncerConfig(), @@ -535,6 +542,101 @@ class TestRehash: router.add_network.assert_awaited() +class TestRehashFunction: + @pytest.mark.asyncio + async def test_rehash_function_directly(self) -> None: + from bouncer.commands import rehash + from bouncer.config import BouncerConfig, Config, NetworkConfig, ProxyConfig + + old_net = _make_network("libera", State.READY) + router = _make_router(old_net) + router._notifier = MagicMock() + router._farm = MagicMock() + router._farm._cfg = BouncerConfig() + router._farm.start = AsyncMock() + router._farm.stop = AsyncMock() + + new_cfg = Config( + bouncer=BouncerConfig(), + proxy=ProxyConfig(), + networks={ + "oftc": NetworkConfig(name="oftc", host="irc.oftc.net", port=6697, tls=True), + }, + ) + + with patch("bouncer.config.load", return_value=new_cfg): + lines = await rehash(router, Path("/tmp/test.toml")) + + assert lines[0] == "[REHASH]" + assert any("removed: libera" in line for line in lines) + assert any("added: oftc" in line for line in lines) + + @pytest.mark.asyncio + async def test_rehash_updates_bouncer_config(self) -> None: + from bouncer.commands import rehash + from bouncer.config import BouncerConfig, Config, NetworkConfig, ProxyConfig + + net = _make_network("libera", State.READY) + router = _make_router(net) + router._notifier = MagicMock() + router._farm = MagicMock() + router._farm._cfg = BouncerConfig() + router._farm.start = AsyncMock() + router._farm.stop = AsyncMock() + + new_cfg = Config( + bouncer=BouncerConfig(notify_url="https://ntfy.sh/test"), + proxy=ProxyConfig(), + networks={ + "libera": NetworkConfig(name="libera", host="irc.libera.chat", + port=6697, tls=True, + channel_keys={"#secret": "key"}), + }, + ) + + with patch("bouncer.config.load", return_value=new_cfg): + result = await rehash(router, Path("/tmp/test.toml")) + + assert result[0] == "[REHASH]" + assert router.config == new_cfg + # Notifier was replaced (new instance) + assert router._notifier is not None + + @pytest.mark.asyncio + async def test_rehash_propagates_channel_keys(self) -> None: + from bouncer.commands import rehash + from bouncer.config import BouncerConfig, Config, NetworkConfig, ProxyConfig + + net = _make_network("libera", State.READY) + net.cfg.host = "irc.libera.chat" + net.cfg.port = 6697 + net.cfg.tls = True + net.cfg.proxy_host = None + net.cfg.proxy_port = None + net.cfg.channel_keys = {} + router = _make_router(net) + router._notifier = MagicMock() + router._farm = MagicMock() + router._farm._cfg = BouncerConfig() + router._farm.start = AsyncMock() + router._farm.stop = AsyncMock() + + new_cfg = Config( + bouncer=BouncerConfig(), + proxy=ProxyConfig(), + networks={ + "libera": NetworkConfig(name="libera", host="irc.libera.chat", + port=6697, tls=True, + channel_keys={"#secret": "key123"}), + }, + ) + + with patch("bouncer.config.load", return_value=new_cfg): + await rehash(router, Path("/tmp/test.toml")) + + assert net.cfg.channel_keys == {"#secret": "key123"} + + class TestAddNetwork: @pytest.mark.asyncio async def test_addnetwork_missing_args(self) -> None: @@ -648,6 +750,33 @@ class TestAutojoin: lines = await commands.dispatch("AUTOJOIN libera -#missing", router, client) assert any("not in autojoin" in line for line in lines) + @pytest.mark.asyncio + async def test_autojoin_with_key(self) -> None: + net = _make_network("libera", State.READY) + net.cfg.channels = [] + net.cfg.channel_keys = {} + router = _make_router(net) + client = _make_client() + lines = await commands.dispatch("AUTOJOIN libera +#secret hunter2", router, client) + assert "[AUTOJOIN]" in lines[0] + assert any("added: #secret" in line for line in lines) + assert "#secret" in net.cfg.channels + assert net.cfg.channel_keys["#secret"] == "hunter2" + + @pytest.mark.asyncio + async def test_autojoin_remove_clears_key(self) -> None: + net = _make_network("libera", State.READY, + channels={"#secret"}, + channel_keys={"#secret": "hunter2"}) + net.cfg.channels = ["#secret"] + net.cfg.channel_keys = {"#secret": "hunter2"} + router = _make_router(net) + client = _make_client() + lines = await commands.dispatch("AUTOJOIN libera -#secret", router, client) + assert any("removed: #secret" in line for line in lines) + assert "#secret" not in net.cfg.channels + assert "#secret" not in net.cfg.channel_keys + @pytest.mark.asyncio async def test_autojoin_invalid_spec(self) -> None: net = _make_network("libera", State.READY) diff --git a/tests/test_config.py b/tests/test_config.py index c463b8c..6d20fce 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -124,6 +124,28 @@ tls = true cfg = load(_write_config(config)) assert cfg.networks["test"].port == 6697 + def test_channel_keys_parsed(self): + config = """\ +[bouncer] +password = "x" + +[proxy] + +[networks.test] +host = "irc.example.com" +channels = ["#secret", "#public"] +channel_keys = { "#secret" = "hunter2" } +""" + cfg = load(_write_config(config)) + net = cfg.networks["test"] + assert net.channel_keys == {"#secret": "hunter2"} + assert "#secret" in net.channels + + def test_channel_keys_default_empty(self): + cfg = load(_write_config(MINIMAL_CONFIG)) + net = cfg.networks["test"] + assert net.channel_keys == {} + def test_operational_defaults(self): """Ensure all operational values have sane defaults.""" cfg = load(_write_config(MINIMAL_CONFIG)) diff --git a/tests/test_network.py b/tests/test_network.py index 3391bcb..5b5e263 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -479,6 +479,25 @@ class TestHandleKick: # Should rejoin (rejoin_delay=0) writer.write.assert_called_with(b"JOIN #test\r\n") + @pytest.mark.asyncio + async def test_kick_rejoin_with_key(self) -> None: + cfg = _cfg(channels=["#secret"]) + cfg.channel_keys = {"#secret": "hunter2"} + net = _net(cfg=cfg) + net.state = State.READY + net._running = True + net.nick = "me" + net.channels = {"#secret"} + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":op!user@host KICK #secret me :reason")) + assert "#secret" not in net.channels + # Should rejoin with key (rejoin_delay=0) + writer.write.assert_called_with(b"JOIN #secret hunter2\r\n") + @pytest.mark.asyncio async def test_kick_other_user_ignored(self) -> None: net = _net()