From 4dd817ea75c99a847432a5e84ccc778d82a845a0 Mon Sep 17 00:00:00 2001 From: user Date: Sat, 21 Feb 2026 17:11:58 +0100 Subject: [PATCH] test: add 94 tests for network connection manager Cover state machine, markov nick generation, SASL negotiation, NickServ/Q-bot auth flows, probation timer, reconnection backoff, read loop buffering, and IRC message handling. Co-Authored-By: Claude Opus 4.6 --- tests/test_network.py | 1203 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1203 insertions(+) create mode 100644 tests/test_network.py diff --git a/tests/test_network.py b/tests/test_network.py new file mode 100644 index 0000000..1057ca1 --- /dev/null +++ b/tests/test_network.py @@ -0,0 +1,1203 @@ +"""Tests for network connection manager.""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import random +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from bouncer.config import BouncerConfig, NetworkConfig, ProxyConfig +from bouncer.irc import IRCMessage, parse +from bouncer.network import ( + Network, + State, + _BIGRAMS, + _GENERIC_IDENTS, + _GENERIC_REALNAMES, + _STARTERS, + _VOWELS, + _markov_word, + _nick_for_host, + _password_for_host, + _random_nick, + _rng_for_key, + _seeded_markov, +) + + +# -- helpers ----------------------------------------------------------------- + +def _cfg(name: str = "testnet", host: str = "irc.test.net", port: int = 6697, + tls: bool = True, nick: str = "", channels: list[str] | None = None, + password: str | None = None, auth_service: str = "nickserv") -> NetworkConfig: + return NetworkConfig( + name=name, host=host, port=port, tls=tls, nick=nick, + channels=channels or [], password=password, auth_service=auth_service, + ) + + +def _proxy() -> ProxyConfig: + return ProxyConfig(host="127.0.0.1", port=1080) + + +def _bouncer(**overrides: object) -> BouncerConfig: + defaults: dict[str, object] = { + "probation_seconds": 1, + "nick_timeout": 1, + "rejoin_delay": 0, + "backoff_steps": [0], + } + defaults.update(overrides) + return BouncerConfig(**defaults) # type: ignore[arg-type] + + +def _net(cfg: NetworkConfig | None = None, backlog: AsyncMock | None = None, + bouncer_cfg: BouncerConfig | None = None, + on_message: MagicMock | None = None) -> Network: + return Network( + cfg=cfg or _cfg(), + proxy_cfg=_proxy(), + backlog=backlog, + on_message=on_message, + bouncer_cfg=bouncer_cfg or _bouncer(), + ) + + +def _mock_backlog(**kw: object) -> AsyncMock: + bl = AsyncMock() + bl.get_nickserv_creds_by_network.return_value = kw.get("creds_by_network") + bl.get_nickserv_creds_by_host.return_value = kw.get("creds_by_host") + bl.get_pending_registration.return_value = kw.get("pending") + bl.save_nickserv_creds = AsyncMock() + bl.mark_nickserv_verified = AsyncMock() + return bl + + +def _msg(raw: str) -> IRCMessage: + """Parse a raw IRC line into an IRCMessage.""" + return parse(raw.encode()) + + +# -- markov / nick generation ----------------------------------------------- + +class TestMarkovWord: + def test_length_bounds(self) -> None: + for _ in range(100): + word = _markov_word(4, 6) + assert 4 <= len(word) <= 6 + + def test_all_alpha(self) -> None: + for _ in range(50): + assert _markov_word(5, 8).isalpha() + + def test_no_triple_consonants(self) -> None: + for _ in range(200): + word = _markov_word(5, 10) + consonant_run = 0 + for ch in word: + if ch not in _VOWELS: + consonant_run += 1 + else: + consonant_run = 0 + assert consonant_run <= 2, f"triple consonant in {word!r}" + + def test_starts_with_starter(self) -> None: + for _ in range(50): + word = _markov_word(3, 5) + assert word[0] in _STARTERS + + +class TestSeededMarkov: + def test_deterministic(self) -> None: + rng1 = random.Random(42) + rng2 = random.Random(42) + assert _seeded_markov(rng1, 5, 8) == _seeded_markov(rng2, 5, 8) + + def test_different_seeds_differ(self) -> None: + results = {_seeded_markov(random.Random(i), 5, 8) for i in range(20)} + assert len(results) > 1 + + +class TestRandomNick: + def test_length(self) -> None: + for _ in range(50): + nick = _random_nick() + assert 5 <= len(nick) <= 10 # 5-8 base + optional 0-2 digit suffix + + def test_starts_alpha(self) -> None: + for _ in range(50): + assert _random_nick()[0].isalpha() + + +class TestNickForHost: + def test_deterministic(self) -> None: + assert _nick_for_host("example.com") == _nick_for_host("example.com") + + def test_different_hosts_differ(self) -> None: + nicks = {_nick_for_host(f"host{i}.example.com") for i in range(10)} + assert len(nicks) > 1 + + +class TestPasswordForHost: + def test_deterministic(self) -> None: + assert _password_for_host("h.com") == _password_for_host("h.com") + + def test_length_16(self) -> None: + assert len(_password_for_host("any.host")) == 16 + + def test_hex_chars(self) -> None: + pw = _password_for_host("foo.bar") + assert all(c in "0123456789abcdef" for c in pw) + + def test_different_from_nick(self) -> None: + """Password and nick use different hash domains.""" + host = "same.host" + pw = _password_for_host(host) + nick = _nick_for_host(host) + assert pw != nick + + +class TestRngForKey: + def test_deterministic(self) -> None: + a = _rng_for_key("test").random() + b = _rng_for_key("test").random() + assert a == b + + def test_different_keys(self) -> None: + a = _rng_for_key("key1").random() + b = _rng_for_key("key2").random() + assert a != b + + +class TestEmailForHost: + @patch("bouncer.network._email_domains", return_value=["mail.tm", "mail.gw"]) + def test_returns_email(self, _mock: MagicMock) -> None: + from bouncer.network import _email_for_host + email = _email_for_host("test.host") + assert email is not None + assert "@" in email + + @patch("bouncer.network._email_domains", return_value=["mail.tm"]) + def test_excludes_domain(self, _mock: MagicMock) -> None: + from bouncer.network import _email_for_host + result = _email_for_host("host", excluded={"mail.tm"}) + assert result is None + + @patch("bouncer.network._email_domains", return_value=[]) + def test_no_domains_returns_none(self, _mock: MagicMock) -> None: + from bouncer.network import _email_for_host + assert _email_for_host("host") is None + + +# -- state properties -------------------------------------------------------- + +class TestStateProperties: + def test_initial_state(self) -> None: + net = _net() + assert net.state == State.DISCONNECTED + assert not net.connected + assert not net.registered + assert not net.ready + + def test_connecting_not_connected(self) -> None: + net = _net() + net.state = State.CONNECTING + assert not net.connected + + def test_registering_is_connected(self) -> None: + net = _net() + net.state = State.REGISTERING + assert net.connected + assert not net.registered + + def test_probation_is_registered(self) -> None: + net = _net() + net.state = State.PROBATION + assert net.connected + assert net.registered + assert not net.ready + + def test_ready(self) -> None: + net = _net() + net.state = State.READY + assert net.connected + assert net.registered + assert net.ready + + +class TestInit: + def test_defaults(self) -> None: + net = _net() + assert net.nick == "*" + assert net.channels == set() + assert net.topics == {} + assert net.names == {} + assert net.visible_host is None + + def test_nick_from_config(self) -> None: + net = _net(cfg=_cfg(nick="mynick")) + assert net.nick == "mynick" + + +# -- send helpers ------------------------------------------------------------ + +class TestSend: + @pytest.mark.asyncio + async def test_send_writes_to_writer(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + msg = IRCMessage(command="PRIVMSG", params=["#test", "hello world"]) + await net.send(msg) + writer.write.assert_called_once_with(b"PRIVMSG #test :hello world\r\n") + writer.drain.assert_awaited_once() + + @pytest.mark.asyncio + async def test_send_skips_closing_writer(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = True + net._writer = writer + + await net.send(IRCMessage(command="PING")) + writer.write.assert_not_called() + + @pytest.mark.asyncio + async def test_send_skips_none_writer(self) -> None: + net = _net() + net._writer = None + # Should not raise + await net.send(IRCMessage(command="PING")) + + @pytest.mark.asyncio + async def test_send_raw(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net.send_raw("JOIN", "#channel") + writer.write.assert_called_once_with(b"JOIN #channel\r\n") + + +# -- _handle: PING/PONG ----------------------------------------------------- + +class TestHandlePing: + @pytest.mark.asyncio + async def test_ping_pong(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg("PING :server.example.com")) + writer.write.assert_called_once_with(b"PONG server.example.com\r\n") + + +# -- _handle: ERROR --------------------------------------------------------- + +class TestHandleError: + @pytest.mark.asyncio + async def test_error_sets_status(self) -> None: + status = MagicMock() + net = _net() + net.on_status = status + + await net._handle(_msg(":server ERROR :K-lined")) + status.assert_called_once_with("testnet", "ERROR: K-lined") + + +# -- _handle: 001 (RPL_WELCOME) --------------------------------------------- + +class TestHandleWelcome: + @pytest.mark.asyncio + async def test_sets_nick_from_params(self) -> None: + net = _net() + net.state = State.REGISTERING + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 001 coolguy :Welcome to the network coolguy!user@host.example.com")) + assert net.nick == "coolguy" + assert net.state == State.PROBATION + + @pytest.mark.asyncio + async def test_extracts_visible_host(self) -> None: + net = _net() + net.state = State.REGISTERING + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 001 nick :Welcome to the IRC Network nick!user@visible.host.com")) + assert net.visible_host == "visible.host.com" + + +# -- _handle: 396 (RPL_VISIBLEHOST) ----------------------------------------- + +class TestHandleVisibleHost: + @pytest.mark.asyncio + async def test_updates_visible_host(self) -> None: + net = _net() + await net._handle(_msg(":server 396 nick new.host.com :is now your displayed host")) + assert net.visible_host == "new.host.com" + + +# -- _handle: NICK ---------------------------------------------------------- + +class TestHandleNick: + @pytest.mark.asyncio + async def test_own_nick_change(self) -> None: + net = _net() + net.nick = "oldnick" + await net._handle(_msg(":oldnick!user@host NICK newnick")) + assert net.nick == "newnick" + assert net._nick_confirmed.is_set() + + @pytest.mark.asyncio + async def test_other_nick_change_ignored(self) -> None: + net = _net() + net.nick = "mynick" + await net._handle(_msg(":othernick!user@host NICK somethingelse")) + assert net.nick == "mynick" + assert not net._nick_confirmed.is_set() + + +# -- _handle: JOIN/PART ----------------------------------------------------- + +class TestHandleJoinPart: + @pytest.mark.asyncio + async def test_own_join(self) -> None: + net = _net() + net.nick = "me" + await net._handle(_msg(":me!user@host JOIN #test")) + assert "#test" in net.channels + assert "#test" in net.names + + @pytest.mark.asyncio + async def test_other_join_ignored(self) -> None: + net = _net() + net.nick = "me" + await net._handle(_msg(":other!user@host JOIN #test")) + assert "#test" not in net.channels + + @pytest.mark.asyncio + async def test_own_part(self) -> None: + net = _net() + net.nick = "me" + net.channels = {"#test"} + net.names["#test"] = {"me", "other"} + net.topics["#test"] = "Topic" + + await net._handle(_msg(":me!user@host PART #test")) + assert "#test" not in net.channels + assert "#test" not in net.names + assert "#test" not in net.topics + + +# -- _handle: 332 (RPL_TOPIC), 353 (RPL_NAMREPLY) -------------------------- + +class TestHandleTopicNames: + @pytest.mark.asyncio + async def test_topic(self) -> None: + net = _net() + await net._handle(_msg(":server 332 me #test :Welcome to the channel")) + assert net.topics["#test"] == "Welcome to the channel" + + @pytest.mark.asyncio + async def test_namreply(self) -> None: + net = _net() + await net._handle(_msg(":server 353 me = #test :@op +voice regular")) + assert net.names["#test"] == {"@op", "+voice", "regular"} + + @pytest.mark.asyncio + async def test_namreply_accumulates(self) -> None: + net = _net() + await net._handle(_msg(":server 353 me = #test :nick1 nick2")) + await net._handle(_msg(":server 353 me = #test :nick3")) + assert net.names["#test"] == {"nick1", "nick2", "nick3"} + + +# -- _handle: 433 (ERR_NICKNAMEINUSE) --------------------------------------- + +class TestHandleNickInUse: + @pytest.mark.asyncio + async def test_during_registration_picks_random(self) -> None: + net = _net() + net.state = State.REGISTERING + net.nick = "taken" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 433 * taken :Nickname is already in use")) + # Should have changed to a new random nick + assert net.nick != "taken" + assert writer.write.called + + @pytest.mark.asyncio + async def test_when_ready_appends_underscore(self) -> None: + net = _net() + net.state = State.READY + net.nick = "desired" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 433 * desired :Nickname is already in use")) + assert net.nick == "desired_" + writer.write.assert_called_once_with(b"NICK desired_\r\n") + + +# -- _handle: KICK ---------------------------------------------------------- + +class TestHandleKick: + @pytest.mark.asyncio + async def test_kicked_removes_channel(self) -> None: + net = _net(cfg=_cfg(channels=["#test"])) + net.state = State.READY + net._running = True + net.nick = "me" + net.channels = {"#test"} + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":op!user@host KICK #test me :reason")) + assert "#test" not in net.channels + # Should rejoin (rejoin_delay=0) + writer.write.assert_called_with(b"JOIN #test\r\n") + + @pytest.mark.asyncio + async def test_kick_other_user_ignored(self) -> None: + net = _net() + net.nick = "me" + net.channels = {"#test"} + await net._handle(_msg(":op!user@host KICK #test other :reason")) + assert "#test" in net.channels + + +# -- _handle: NOTICE (hostname extraction) ----------------------------------- + +class TestHandleNotice: + @pytest.mark.asyncio + async def test_found_hostname_notice(self) -> None: + net = _net() + await net._handle( + _msg(":server NOTICE * :*** Found your hostname: some.isp.example.com") + ) + assert net.visible_host == "some.isp.example.com" + + +# -- _handle: CAP (SASL negotiation) ---------------------------------------- + +class TestHandleCap: + @pytest.mark.asyncio + async def test_cap_ack_sasl_starts_authenticate(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server CAP * ACK :sasl")) + writer.write.assert_called_with(b"AUTHENTICATE PLAIN\r\n") + + @pytest.mark.asyncio + async def test_cap_nak_sasl_ends_cap(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + net._sasl_nick = "nick" + net._sasl_pass = "pass" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server CAP * NAK :sasl")) + writer.write.assert_called_with(b"CAP END\r\n") + assert net._sasl_mechanism == "" + assert net._sasl_nick == "" + + +# -- _handle: AUTHENTICATE -------------------------------------------------- + +class TestHandleAuthenticate: + @pytest.mark.asyncio + async def test_plain_credentials(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + net._sasl_nick = "mynick" + net._sasl_pass = "mypass" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg("AUTHENTICATE +")) + expected = base64.b64encode(b"mynick\0mynick\0mypass").decode() + writer.write.assert_called_with(f"AUTHENTICATE {expected}\r\n".encode()) + + @pytest.mark.asyncio + async def test_external_sends_nick(self) -> None: + net = _net() + net._sasl_mechanism = "EXTERNAL" + net._sasl_nick = "extnick" + net._sasl_pass = "pass" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg("AUTHENTICATE +")) + expected = base64.b64encode(b"extnick").decode() + writer.write.assert_called_with(f"AUTHENTICATE {expected}\r\n".encode()) + + @pytest.mark.asyncio + async def test_no_creds_aborts(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + net._sasl_nick = "" + net._sasl_pass = "" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg("AUTHENTICATE +")) + writer.write.assert_called_with(b"AUTHENTICATE *\r\n") + + +# -- _handle: 903 (RPL_SASLSUCCESS) ---------------------------------------- + +class TestHandleSaslSuccess: + @pytest.mark.asyncio + async def test_sasl_success_sets_event(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + net._sasl_nick = "nick" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 903 nick :SASL authentication successful")) + assert net._sasl_complete.is_set() + writer.write.assert_called_with(b"CAP END\r\n") + + +# -- _handle: 904/905 (SASL failure) ---------------------------------------- + +class TestHandleSaslFailure: + @pytest.mark.asyncio + async def test_external_falls_back_to_plain(self) -> None: + net = _net() + net._sasl_mechanism = "EXTERNAL" + net._sasl_nick = "nick" + net._sasl_pass = "pass" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 904 nick :SASL authentication failed")) + assert net._sasl_mechanism == "PLAIN" + writer.write.assert_called_with(b"AUTHENTICATE PLAIN\r\n") + + @pytest.mark.asyncio + async def test_plain_failure_ends_cap(self) -> None: + net = _net() + net._sasl_mechanism = "PLAIN" + net._sasl_nick = "nick" + net._sasl_pass = "pass" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 904 nick :SASL authentication failed")) + assert net._sasl_mechanism == "" + assert net._sasl_nick == "" + writer.write.assert_called_with(b"CAP END\r\n") + + +# -- _handle: 906/908 (SASL aborted/mechs) ---------------------------------- + +class TestHandleSaslAborted: + @pytest.mark.asyncio + async def test_906_ends_cap(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg(":server 906 nick :SASL authentication aborted")) + writer.write.assert_called_with(b"CAP END\r\n") + + +# -- message forwarding to router ------------------------------------------- + +class TestOnMessage: + @pytest.mark.asyncio + async def test_on_message_called_for_non_ping(self) -> None: + cb = MagicMock() + net = _net(on_message=cb) + await net._handle(_msg(":nick!user@host PRIVMSG #test :hello")) + cb.assert_called_once() + assert cb.call_args[0][0] == "testnet" + assert cb.call_args[0][1].command == "PRIVMSG" + + @pytest.mark.asyncio + async def test_on_message_not_called_for_ping(self) -> None: + cb = MagicMock() + net = _net(on_message=cb) + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle(_msg("PING :server")) + cb.assert_not_called() + + +# -- probation timer --------------------------------------------------------- + +class TestProbation: + @pytest.mark.asyncio + async def test_enter_probation(self) -> None: + net = _net(bouncer_cfg=_bouncer(probation_seconds=0)) + net.state = State.REGISTERING + net._running = True + # Mock _go_ready to prevent full NickServ flow + net._go_ready = AsyncMock() + + await net._enter_probation() + assert net.state == State.PROBATION + assert net._probation_task is not None + + # Wait for the probation timer to fire + await net._probation_task + net._go_ready.assert_awaited_once() + + @pytest.mark.asyncio + async def test_probation_cancelled(self) -> None: + net = _net(bouncer_cfg=_bouncer(probation_seconds=60)) + net.state = State.REGISTERING + net._running = True + net._go_ready = AsyncMock() + + await net._enter_probation() + net._probation_task.cancel() + try: + await net._probation_task + except asyncio.CancelledError: + pass + net._go_ready.assert_not_awaited() + + @pytest.mark.asyncio + async def test_probation_resets_reconnect_counter(self) -> None: + net = _net(bouncer_cfg=_bouncer(probation_seconds=0)) + net.state = State.REGISTERING + net._running = True + net._reconnect_attempt = 5 + net._go_ready = AsyncMock() + + await net._enter_probation() + await net._probation_task + assert net._reconnect_attempt == 0 + + +# -- reconnection backoff --------------------------------------------------- + +class TestReconnect: + @pytest.mark.asyncio + async def test_backoff_increments(self) -> None: + net = _net(bouncer_cfg=_bouncer(backoff_steps=[0, 0, 0])) + net._running = True + + with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect: + assert net._reconnect_attempt == 0 + net._schedule_reconnect() + await net._reconnect_task + assert net._reconnect_attempt == 1 + mock_connect.assert_awaited_once() + + @pytest.mark.asyncio + async def test_backoff_clamps_to_last_step(self) -> None: + steps = [1, 2, 5] + net = _net(bouncer_cfg=_bouncer(backoff_steps=steps)) + net._running = True + net._reconnect_attempt = 100 + + with patch.object(net, "_connect", new_callable=AsyncMock): + net._schedule_reconnect() + # The delay should use the last step (5), but since we can't + # easily measure the sleep duration, verify the attempt incremented + await net._reconnect_task + assert net._reconnect_attempt == 101 + + @pytest.mark.asyncio + async def test_reconnect_cancelled(self) -> None: + net = _net(bouncer_cfg=_bouncer(backoff_steps=[60])) + net._running = True + + with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect: + net._schedule_reconnect() + net._reconnect_task.cancel() + try: + await net._reconnect_task + except asyncio.CancelledError: + pass + mock_connect.assert_not_awaited() + + @pytest.mark.asyncio + async def test_no_reconnect_when_stopped(self) -> None: + net = _net(bouncer_cfg=_bouncer(backoff_steps=[0])) + net._running = False + + with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect: + net._schedule_reconnect() + await net._reconnect_task + mock_connect.assert_not_awaited() + + +# -- disconnect -------------------------------------------------------------- + +class TestDisconnect: + @pytest.mark.asyncio + async def test_disconnect_closes_writer(self) -> None: + net = _net() + writer = MagicMock() + writer.is_closing.return_value = False + writer.wait_closed = AsyncMock() + net._writer = writer + net._reader = MagicMock() + + await net._disconnect() + assert net.state == State.DISCONNECTED + writer.close.assert_called_once() + assert net._writer is None + assert net._reader is None + + @pytest.mark.asyncio + async def test_disconnect_cancels_probation(self) -> None: + net = _net() + task = MagicMock() + task.done.return_value = False + net._probation_task = task + + await net._disconnect() + task.cancel.assert_called_once() + assert net._probation_task is None + + +# -- stop -------------------------------------------------------------------- + +class TestStop: + @pytest.mark.asyncio + async def test_stop_cancels_tasks(self) -> None: + net = _net() + net._running = True + read_task = MagicMock() + read_task.done.return_value = False + reconnect_task = MagicMock() + reconnect_task.done.return_value = False + net._read_task = read_task + net._reconnect_task = reconnect_task + + await net.stop() + assert not net._running + read_task.cancel.assert_called_once() + reconnect_task.cancel.assert_called_once() + assert net.state == State.DISCONNECTED + + +# -- _connect ---------------------------------------------------------------- + +class TestConnect: + @pytest.mark.asyncio + async def test_connect_no_sasl(self) -> None: + """Without stored creds, uses random nick and no SASL.""" + net = _net() + reader = MagicMock() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + + with patch("bouncer.proxy.connect", new_callable=AsyncMock, + return_value=(reader, writer)): + # Prevent the read loop from actually running + with patch.object(net, "_read_loop", new_callable=AsyncMock): + await net._connect() + + assert net.state == State.REGISTERING + # Should have sent NICK and USER + calls = [c.args[0] for c in writer.write.call_args_list] + nick_sent = any(b"NICK " in c for c in calls) + user_sent = any(b"USER " in c for c in calls) + assert nick_sent + assert user_sent + # Should NOT have sent CAP REQ sasl + cap_sent = any(b"CAP REQ" in c for c in calls) + assert not cap_sent + + @pytest.mark.asyncio + async def test_connect_with_sasl_plain(self) -> None: + """With stored creds and no cert, uses SASL PLAIN.""" + bl = _mock_backlog(creds_by_network=("registered_nick", "secret")) + net = _net(backlog=bl) + reader = MagicMock() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + + with patch("bouncer.proxy.connect", new_callable=AsyncMock, + return_value=(reader, writer)): + with patch.object(net, "_read_loop", new_callable=AsyncMock): + await net._connect() + + assert net._sasl_mechanism == "PLAIN" + assert net._sasl_nick == "registered_nick" + assert net._connect_nick == "registered_nick" + calls = [c.args[0] for c in writer.write.call_args_list] + cap_sent = any(b"CAP REQ sasl" in c for c in calls) + assert cap_sent + + @pytest.mark.asyncio + async def test_connect_failure_schedules_reconnect(self) -> None: + """Connection failure should schedule a reconnect.""" + net = _net() + net._running = True + + with patch("bouncer.proxy.connect", new_callable=AsyncMock, + side_effect=ConnectionRefusedError("refused")): + with patch.object(net, "_schedule_reconnect") as mock_sched: + await net._connect() + + assert net.state == State.DISCONNECTED + mock_sched.assert_called_once() + + @pytest.mark.asyncio + async def test_connect_sends_server_password(self) -> None: + """If network has a password, sends PASS before NICK.""" + net = _net(cfg=_cfg(password="serverpass")) + reader = MagicMock() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + + with patch("bouncer.proxy.connect", new_callable=AsyncMock, + return_value=(reader, writer)): + with patch.object(net, "_read_loop", new_callable=AsyncMock): + await net._connect() + + calls = [c.args[0] for c in writer.write.call_args_list] + pass_sent = any(b"PASS serverpass" in c for c in calls) + assert pass_sent + + +# -- _go_ready --------------------------------------------------------------- + +class TestGoReady: + @pytest.mark.asyncio + async def test_sasl_skips_nickserv(self) -> None: + """When SASL already authenticated, skip NickServ flow.""" + bl = _mock_backlog() + net = _net(backlog=bl, cfg=_cfg(auth_service="nickserv")) + net.state = State.PROBATION + net.nick = "autheduser" + net._sasl_complete.set() + + await net._go_ready() + assert net.state == State.READY + + @pytest.mark.asyncio + async def test_auth_service_none_skips_auth(self) -> None: + """When auth_service=none, no authentication attempted.""" + status = MagicMock() + net = _net(cfg=_cfg(auth_service="none")) + net.on_status = status + net.state = State.PROBATION + net.nick = "anon" + + await net._go_ready() + assert net.state == State.READY + status.assert_called_with("testnet", "ready as anon (no auth service)") + + +# -- NickServ response parsing ----------------------------------------------- + +class TestNickServMatchers: + def test_registration_confirmed(self) -> None: + net = _net() + assert net._registration_confirmed("a passcode has been sent to your email") + assert net._registration_confirmed("activation instructions have been sent") + assert not net._registration_confirmed("you are now identified") + + def test_registration_immediate(self) -> None: + net = _net() + assert net._registration_immediate("nickname registered under your account") + assert not net._registration_immediate("nickname registered, email verification required") + + def test_verification_succeeded(self) -> None: + net = _net() + assert net._verification_succeeded("your nick has been verified") + assert net._verification_succeeded("has now been activated for use") + assert net._verification_succeeded("you are now identified for this nick") + assert not net._verification_succeeded("verification pending") + + +# -- _handle_nickserv -------------------------------------------------------- + +class TestHandleNickserv: + @pytest.mark.asyncio + async def test_identify_success(self) -> None: + bl = _mock_backlog() + net = _net(backlog=bl) + net.nick = "mynick" + net.visible_host = "user/mynick" + net._nickserv_pending = "identify" + net._nickserv_password = "secret" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("You are now identified for mynick") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + bl.save_nickserv_creds.assert_awaited_once() + + @pytest.mark.asyncio + async def test_identify_not_registered_triggers_register(self) -> None: + net = _net() + net._nickserv_pending = "identify" + net._nickserv_done = asyncio.Event() + + with patch.object(net, "_nickserv_register", new_callable=AsyncMock) as mock_reg: + await net._handle_nickserv("mynick is not a registered nickname") + mock_reg.assert_awaited_once() + assert net._nickserv_pending == "" + + @pytest.mark.asyncio + async def test_identify_wrong_password(self) -> None: + net = _net() + net._nickserv_pending = "identify" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("Invalid password for mynick") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + + @pytest.mark.asyncio + async def test_register_email_sent(self) -> None: + bl = _mock_backlog() + net = _net(backlog=bl) + net.nick = "newnick" + net._nickserv_pending = "register" + net._nickserv_password = "pass" + net._nickserv_email = "test@mail.tm" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("A passcode has been sent to test@mail.tm") + assert net._nickserv_pending == "verify" + bl.save_nickserv_creds.assert_awaited_once() + + @pytest.mark.asyncio + async def test_register_domain_rejected(self) -> None: + net = _net() + net._nickserv_pending = "register" + net._nickserv_email = "user@bad.domain" + net._nickserv_done = asyncio.Event() + + with patch.object(net, "_nickserv_register", new_callable=AsyncMock): + await net._handle_nickserv("bad.domain do not accept email from that address") + assert "bad.domain" in net._rejected_email_domains + + @pytest.mark.asyncio + async def test_register_nick_taken(self) -> None: + net = _net() + net._nickserv_pending = "register" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("mynick is already registered by someone else") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + + @pytest.mark.asyncio + async def test_register_rate_limited(self) -> None: + net = _net() + net._nickserv_pending = "register" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("You have sent too many registration requests") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + + @pytest.mark.asyncio + async def test_verify_success(self) -> None: + bl = _mock_backlog() + net = _net(backlog=bl) + net.nick = "verified_nick" + net._nickserv_pending = "verify" + net._nickserv_password = "pass" + + await net._handle_nickserv("verified_nick has been verified") + assert net._nickserv_pending == "" + bl.mark_nickserv_verified.assert_awaited_once() + + @pytest.mark.asyncio + async def test_verify_failure(self) -> None: + net = _net() + net._nickserv_pending = "verify" + status = MagicMock() + net.on_status = status + + await net._handle_nickserv("Invalid verification code") + assert net._nickserv_pending == "" + + @pytest.mark.asyncio + async def test_late_registration_confirmation(self) -> None: + """After timeout clears pending state, late confirmation still works.""" + bl = _mock_backlog() + net = _net(backlog=bl) + net.nick = "latenick" + net._nickserv_pending = "" + net._nickserv_password = "pass" + net._nickserv_email = "late@mail.tm" + net._nickserv_done = asyncio.Event() + + await net._handle_nickserv("A passcode has been sent to your email") + bl.save_nickserv_creds.assert_awaited_once() + + +# -- _handle_qbot ----------------------------------------------------------- + +class TestHandleQbot: + @pytest.mark.asyncio + async def test_qbot_auth_success(self) -> None: + net = _net(cfg=_cfg(nick="quser", auth_service="qbot")) + net.nick = "quser" + net._nickserv_pending = "qbot_auth" + net._nickserv_done = asyncio.Event() + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + await net._handle_qbot("You are now logged in as quser.") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + + @pytest.mark.asyncio + async def test_qbot_auth_failure(self) -> None: + net = _net(cfg=_cfg(auth_service="qbot")) + net._nickserv_pending = "qbot_auth" + net._nickserv_done = asyncio.Event() + + await net._handle_qbot("Incorrect password for quser") + assert net._nickserv_pending == "" + assert net._nickserv_done.is_set() + + @pytest.mark.asyncio + async def test_qbot_ignores_when_not_pending(self) -> None: + net = _net() + net._nickserv_pending = "" + net._nickserv_done = asyncio.Event() + + await net._handle_qbot("You are now logged in as quser.") + # Should not have set the event (wasn't waiting for qbot) + assert not net._nickserv_done.is_set() + + +# -- read loop --------------------------------------------------------------- + +class TestReadLoop: + @pytest.mark.asyncio + async def test_processes_complete_lines(self) -> None: + """Read loop should parse complete CRLF-terminated lines.""" + net = _net() + net.state = State.REGISTERING + net._running = True + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[ + b"PING :test\r\n", + b"", # EOF + ]) + net._reader = reader + + with patch.object(net, "_disconnect", new_callable=AsyncMock): + with patch.object(net, "_schedule_reconnect"): + await net._read_loop() + + writer.write.assert_called_with(b"PONG test\r\n") + + @pytest.mark.asyncio + async def test_handles_partial_reads(self) -> None: + """Read loop should buffer partial lines across reads.""" + net = _net() + net.state = State.REGISTERING + net._running = True + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[ + b"PING :tes", # partial + b"t\r\n", # completion + b"", # EOF + ]) + net._reader = reader + + with patch.object(net, "_disconnect", new_callable=AsyncMock): + with patch.object(net, "_schedule_reconnect"): + await net._read_loop() + + writer.write.assert_called_with(b"PONG test\r\n") + + @pytest.mark.asyncio + async def test_multiple_lines_in_one_read(self) -> None: + """Read loop should handle multiple lines in a single read.""" + cb = MagicMock() + net = _net(on_message=cb) + net.state = State.READY + net._running = True + net.nick = "me" + writer = MagicMock() + writer.is_closing.return_value = False + writer.drain = AsyncMock() + net._writer = writer + + reader = AsyncMock() + reader.read = AsyncMock(side_effect=[ + b":a!u@h PRIVMSG #ch :msg1\r\n:b!u@h PRIVMSG #ch :msg2\r\n", + b"", + ]) + net._reader = reader + + with patch.object(net, "_disconnect", new_callable=AsyncMock): + with patch.object(net, "_schedule_reconnect"): + await net._read_loop() + + assert cb.call_count == 2