Fix E501 line-too-long in backlog.py, network.py, test_network.py. Fix F541 f-string-without-placeholders in network.py. Fix I001 unsorted imports in network.py. Remove unused datetime import in test_cert.py (F401). Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1585 lines
53 KiB
Python
1585 lines
53 KiB
Python
"""Tests for network connection manager."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import random
|
|
import time
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from bouncer.config import BouncerConfig, NetworkConfig, ProxyConfig
|
|
from bouncer.irc import IRCMessage, parse
|
|
from bouncer.network import (
|
|
_STARTERS,
|
|
_VOWELS,
|
|
Network,
|
|
State,
|
|
_markov_word,
|
|
_nick_for_host,
|
|
_password_for_host,
|
|
_random_nick,
|
|
_rng_for_key,
|
|
_seeded_markov,
|
|
)
|
|
|
|
# -- helpers -----------------------------------------------------------------
|
|
|
|
def _cfg(
    name: str = "testnet",
    host: str = "irc.test.net",
    port: int = 6697,
    tls: bool = True,
    nick: str = "",
    channels: list[str] | None = None,
    password: str | None = None,
    auth_service: str = "nickserv",
) -> NetworkConfig:
    """Return a NetworkConfig for tests, with every field overridable.

    An omitted (or empty) ``channels`` argument becomes a fresh empty list.
    """
    chans = channels if channels else []
    return NetworkConfig(
        name=name,
        host=host,
        port=port,
        tls=tls,
        nick=nick,
        channels=chans,
        password=password,
        auth_service=auth_service,
    )
|
|
|
|
|
|
def _proxy() -> ProxyConfig:
    """Return a ProxyConfig pointing at 127.0.0.1 port 1080."""
    return ProxyConfig(port=1080, host="127.0.0.1")
|
|
|
|
|
|
def _bouncer(**overrides: object) -> BouncerConfig:
    """Return a BouncerConfig tuned so tests don't wait on timers.

    Probation and nick timeouts are set to 1, rejoin delay and every backoff
    step to 0. Any keyword in ``overrides`` takes precedence over these.
    """
    fast = {
        "probation_seconds": 1,
        "nick_timeout": 1,
        "rejoin_delay": 0,
        "backoff_steps": [0],
    }
    merged: dict[str, object] = {**fast, **overrides}
    return BouncerConfig(**merged)  # type: ignore[arg-type]
|
|
|
|
|
|
def _net(
    cfg: NetworkConfig | None = None,
    backlog: AsyncMock | None = None,
    bouncer_cfg: BouncerConfig | None = None,
    on_message: MagicMock | None = None,
) -> Network:
    """Construct a Network wired to test configuration.

    A falsy ``cfg`` / ``bouncer_cfg`` is replaced by ``_cfg()`` /
    ``_bouncer()``. The proxy config always comes from ``_proxy()``.
    ``backlog`` and ``on_message`` are passed through unchanged (they may
    be None).
    """
    net_cfg = cfg or _cfg()
    proxy_cfg = _proxy()
    bnc_cfg = bouncer_cfg or _bouncer()
    return Network(
        cfg=net_cfg,
        proxy_cfg=proxy_cfg,
        backlog=backlog,
        on_message=on_message,
        bouncer_cfg=bnc_cfg,
    )
|
|
|
|
|
|
def _mock_backlog(**kw: object) -> AsyncMock:
    """Build an AsyncMock backlog whose lookup coroutines return canned values.

    Recognised keywords are ``creds_by_network``, ``creds_by_host`` and
    ``pending``. Any omitted keyword makes the matching lookup resolve to None.
    """
    backlog = AsyncMock()
    canned = {
        "get_nickserv_creds_by_network": kw.get("creds_by_network"),
        "get_nickserv_creds_by_host": kw.get("creds_by_host"),
        "get_pending_registration": kw.get("pending"),
    }
    for method_name, result in canned.items():
        getattr(backlog, method_name).return_value = result
    # Give the two recording coroutines their own AsyncMocks so tests can
    # use assert_awaited_* on them directly.
    backlog.save_nickserv_creds = AsyncMock()
    backlog.mark_nickserv_verified = AsyncMock()
    return backlog
|
|
|
|
|
|
def _msg(raw: str) -> IRCMessage:
    """Encode *raw* as UTF-8 and hand it to the IRC line parser."""
    line = raw.encode("utf-8")
    return parse(line)
|
|
|
|
|
|
# -- markov / nick generation -----------------------------------------------
|
|
|
|
class TestMarkovWord:
    """Shape constraints on words from the unseeded Markov generator."""

    def test_length_bounds(self) -> None:
        lengths = [len(_markov_word(4, 6)) for _ in range(100)]
        assert all(4 <= size <= 6 for size in lengths)

    def test_all_alpha(self) -> None:
        words = [_markov_word(5, 8) for _ in range(50)]
        assert all(w.isalpha() for w in words)

    def test_no_triple_consonants(self) -> None:
        for _ in range(200):
            word = _markov_word(5, 10)
            # Length of the current run of non-vowels; any vowel resets it.
            streak = 0
            for letter in word:
                streak = 0 if letter in _VOWELS else streak + 1
                assert streak <= 2, f"triple consonant in {word!r}"

    def test_starts_with_starter(self) -> None:
        initials = [_markov_word(3, 5)[0] for _ in range(50)]
        assert all(ch in _STARTERS for ch in initials)
|
|
|
|
|
|
class TestSeededMarkov:
    """The seeded generator depends only on the RNG it is given."""

    def test_deterministic(self) -> None:
        first, second = (_seeded_markov(random.Random(42), 5, 8) for _ in range(2))
        assert first == second

    def test_different_seeds_differ(self) -> None:
        distinct = set()
        for seed in range(20):
            distinct.add(_seeded_markov(random.Random(seed), 5, 8))
        assert len(distinct) > 1
|
|
|
|
|
|
class TestRandomNick:
    """Shape of nicks produced by _random_nick()."""

    def test_length(self) -> None:
        for nick in (_random_nick() for _ in range(50)):
            # 5-8 letter base word plus an optional suffix of up to 2 digits.
            assert 5 <= len(nick) <= 10

    def test_starts_alpha(self) -> None:
        first_chars = [_random_nick()[0] for _ in range(50)]
        assert all(ch.isalpha() for ch in first_chars)
|
|
|
|
|
|
class TestNickForHost:
    """Per-host nicks are stable for a host and vary across hosts."""

    def test_deterministic(self) -> None:
        host = "example.com"
        assert _nick_for_host(host) == _nick_for_host(host)

    def test_different_hosts_differ(self) -> None:
        hosts = [f"host{i}.example.com" for i in range(10)]
        assert len(set(map(_nick_for_host, hosts))) > 1
|
|
|
|
|
|
class TestPasswordForHost:
    """Per-host passwords: stable, 16 lowercase hex chars, distinct from the nick."""

    def test_deterministic(self) -> None:
        first = _password_for_host("h.com")
        again = _password_for_host("h.com")
        assert first == again

    def test_length_16(self) -> None:
        pw = _password_for_host("any.host")
        assert len(pw) == 16

    def test_hex_chars(self) -> None:
        pw = _password_for_host("foo.bar")
        assert set(pw) <= set("0123456789abcdef")

    def test_different_from_nick(self) -> None:
        """Password and nick use different hash domains."""
        host = "same.host"
        assert _password_for_host(host) != _nick_for_host(host)
|
|
|
|
|
|
class TestRngForKey:
    """Key-derived RNGs are reproducible and key-sensitive."""

    def test_deterministic(self) -> None:
        draws = [_rng_for_key("test").random() for _ in range(2)]
        assert draws[0] == draws[1]

    def test_different_keys(self) -> None:
        first, second = (_rng_for_key(key).random() for key in ("key1", "key2"))
        assert first != second
|
|
|
|
|
|
class TestEmailForHost:
    """_email_for_host against a patched list of mail domains."""

    @patch("bouncer.network._email_domains", return_value=["mail.tm", "mail.gw"])
    def test_returns_email(self, _mock: MagicMock) -> None:
        from bouncer.network import _email_for_host

        address = _email_for_host("test.host")
        assert address is not None
        assert "@" in address

    @patch("bouncer.network._email_domains", return_value=["mail.tm"])
    def test_excludes_domain(self, _mock: MagicMock) -> None:
        from bouncer.network import _email_for_host

        # The only available domain is excluded, so no address can be built.
        assert _email_for_host("host", excluded={"mail.tm"}) is None

    @patch("bouncer.network._email_domains", return_value=[])
    def test_no_domains_returns_none(self, _mock: MagicMock) -> None:
        from bouncer.network import _email_for_host

        result = _email_for_host("host")
        assert result is None
|
|
|
|
|
|
# -- state properties --------------------------------------------------------
|
|
|
|
class TestStateProperties:
    """Derived connected/registered/ready flags for each State."""

    @staticmethod
    def _net_in(state: State) -> Network:
        """Return a fresh Network forced into *state*."""
        net = _net()
        net.state = state
        return net

    def test_initial_state(self) -> None:
        net = _net()
        assert net.state == State.DISCONNECTED
        assert not (net.connected or net.registered or net.ready)

    def test_connecting_not_connected(self) -> None:
        assert not self._net_in(State.CONNECTING).connected

    def test_registering_is_connected(self) -> None:
        net = self._net_in(State.REGISTERING)
        assert net.connected
        assert not net.registered

    def test_probation_is_registered(self) -> None:
        net = self._net_in(State.PROBATION)
        assert net.connected and net.registered
        assert not net.ready

    def test_ready(self) -> None:
        net = self._net_in(State.READY)
        assert net.connected and net.registered and net.ready
|
|
|
|
|
|
class TestInit:
    """Attribute values right after construction."""

    def test_defaults(self) -> None:
        net = _net()
        assert net.nick == "*"
        # No channel state has been learned yet.
        assert net.channels == set()
        assert net.topics == {} and net.names == {}
        assert net.visible_host is None

    def test_nick_from_config(self) -> None:
        configured = _net(cfg=_cfg(nick="mynick"))
        assert configured.nick == "mynick"
|
|
|
|
|
|
# -- send helpers ------------------------------------------------------------
|
|
|
|
class TestSend:
    """Wire format and writer guards of Network.send / send_raw."""

    @staticmethod
    def _open_writer(net: Network) -> MagicMock:
        """Attach a fake writer that reports itself open, and return it."""
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return writer

    @pytest.mark.asyncio
    async def test_send_writes_to_writer(self) -> None:
        net = _net()
        writer = self._open_writer(net)

        await net.send(IRCMessage(command="PRIVMSG", params=["#test", "hello world"]))

        writer.write.assert_called_once_with(b"PRIVMSG #test :hello world\r\n")
        writer.drain.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_send_skips_closing_writer(self) -> None:
        net = _net()
        closing = MagicMock(**{"is_closing.return_value": True})
        net._writer = closing

        await net.send(IRCMessage(command="PING"))

        closing.write.assert_not_called()

    @pytest.mark.asyncio
    async def test_send_skips_none_writer(self) -> None:
        net = _net()
        net._writer = None
        # With no writer attached, send() must simply do nothing (no raise).
        await net.send(IRCMessage(command="PING"))

    @pytest.mark.asyncio
    async def test_send_raw(self) -> None:
        net = _net()
        writer = self._open_writer(net)

        await net.send_raw("JOIN", "#channel")

        writer.write.assert_called_once_with(b"JOIN #channel\r\n")
|
|
|
|
|
|
# -- _handle: PING/PONG -----------------------------------------------------
|
|
|
|
class TestHandlePing:
    """Server PINGs are answered with a matching PONG."""

    @pytest.mark.asyncio
    async def test_ping_pong(self) -> None:
        net = _net()
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer

        await net._handle(_msg("PING :server.example.com"))

        writer.write.assert_called_once_with(b"PONG server.example.com\r\n")
|
|
|
|
|
|
# -- _handle: ERROR ---------------------------------------------------------
|
|
|
|
class TestHandleError:
    """A server ERROR line is surfaced through on_status."""

    @pytest.mark.asyncio
    async def test_error_sets_status(self) -> None:
        net = _net()
        net.on_status = on_status = MagicMock()

        await net._handle(_msg(":server ERROR :K-lined"))

        on_status.assert_called_once_with("testnet", "ERROR: K-lined")
|
|
|
|
|
|
# -- _handle: 001 (RPL_WELCOME) ---------------------------------------------
|
|
|
|
class TestHandleWelcome:
    """RPL_WELCOME (001) bookkeeping."""

    @staticmethod
    def _registering() -> Network:
        """Return a Network mid-registration with an open fake writer."""
        net = _net()
        net.state = State.REGISTERING
        net._writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        return net

    @pytest.mark.asyncio
    async def test_sets_nick_from_params(self) -> None:
        net = self._registering()
        welcome = (
            ":server 001 coolguy "
            ":Welcome to the network coolguy!user@host.example.com"
        )

        await net._handle(_msg(welcome))

        assert net.nick == "coolguy"
        assert net.state == State.PROBATION

    @pytest.mark.asyncio
    async def test_extracts_visible_host(self) -> None:
        net = self._registering()
        welcome = (
            ":server 001 nick "
            ":Welcome to the IRC Network nick!user@visible.host.com"
        )

        await net._handle(_msg(welcome))

        assert net.visible_host == "visible.host.com"
|
|
|
|
|
|
# -- _handle: 396 (RPL_VISIBLEHOST) -----------------------------------------
|
|
|
|
class TestHandleVisibleHost:
    """RPL_VISIBLEHOST (396) replaces the tracked visible host."""

    @pytest.mark.asyncio
    async def test_updates_visible_host(self) -> None:
        net = _net()
        line = ":server 396 nick new.host.com :is now your displayed host"

        await net._handle(_msg(line))

        assert net.visible_host == "new.host.com"
|
|
|
|
|
|
# -- _handle: NICK ----------------------------------------------------------
|
|
|
|
class TestHandleNick:
    """NICK messages update our nick only when they are about us."""

    @staticmethod
    def _named(nick: str) -> Network:
        """Return a fresh Network currently using *nick*."""
        net = _net()
        net.nick = nick
        return net

    @pytest.mark.asyncio
    async def test_own_nick_change(self) -> None:
        net = self._named("oldnick")

        await net._handle(_msg(":oldnick!user@host NICK newnick"))

        assert net.nick == "newnick"
        assert net._nick_confirmed.is_set()

    @pytest.mark.asyncio
    async def test_other_nick_change_ignored(self) -> None:
        net = self._named("mynick")

        await net._handle(_msg(":othernick!user@host NICK somethingelse"))

        assert net.nick == "mynick"
        assert not net._nick_confirmed.is_set()
|
|
|
|
|
|
# -- _handle: JOIN/PART -----------------------------------------------------
|
|
|
|
class TestHandleJoinPart:
    """Channel membership tracking for our own JOIN/PART."""

    @staticmethod
    def _as_me() -> Network:
        """Return a fresh Network whose nick is ``me``."""
        net = _net()
        net.nick = "me"
        return net

    @pytest.mark.asyncio
    async def test_own_join(self) -> None:
        net = self._as_me()

        await net._handle(_msg(":me!user@host JOIN #test"))

        assert "#test" in net.channels
        assert "#test" in net.names

    @pytest.mark.asyncio
    async def test_other_join_ignored(self) -> None:
        net = self._as_me()

        await net._handle(_msg(":other!user@host JOIN #test"))

        assert "#test" not in net.channels

    @pytest.mark.asyncio
    async def test_own_part(self) -> None:
        net = self._as_me()
        net.channels = {"#test"}
        net.names["#test"] = {"me", "other"}
        net.topics["#test"] = "Topic"

        await net._handle(_msg(":me!user@host PART #test"))

        # Leaving forgets membership, the nick list and the topic alike.
        for tracked in (net.channels, net.names, net.topics):
            assert "#test" not in tracked
|
|
|
|
|
|
# -- _handle: 332 (RPL_TOPIC), 353 (RPL_NAMREPLY) --------------------------
|
|
|
|
class TestHandleTopicNames:
    """RPL_TOPIC (332) and RPL_NAMREPLY (353) state tracking."""

    @pytest.mark.asyncio
    async def test_topic(self) -> None:
        net = _net()

        await net._handle(_msg(":server 332 me #test :Welcome to the channel"))

        assert net.topics["#test"] == "Welcome to the channel"

    @pytest.mark.asyncio
    async def test_namreply(self) -> None:
        net = _net()

        await net._handle(_msg(":server 353 me = #test :@op +voice regular"))

        # Mode prefixes are kept as part of each stored name.
        assert net.names["#test"] == {"@op", "+voice", "regular"}

    @pytest.mark.asyncio
    async def test_namreply_accumulates(self) -> None:
        net = _net()
        batches = (
            ":server 353 me = #test :nick1 nick2",
            ":server 353 me = #test :nick3",
        )

        for line in batches:
            await net._handle(_msg(line))

        # Successive 353 lines for one channel merge rather than overwrite.
        assert net.names["#test"] == {"nick1", "nick2", "nick3"}
|
|
|
|
|
|
# -- _handle: 433 (ERR_NICKNAMEINUSE) ---------------------------------------
|
|
|
|
class TestHandleNickInUse:
    """ERR_NICKNAMEINUSE (433) during registration vs. once READY."""

    @staticmethod
    def _wired(state: State, nick: str) -> tuple[Network, MagicMock]:
        """Return a Network in *state* using *nick*, plus its open fake writer."""
        net = _net()
        net.state = state
        net.nick = nick
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return net, writer

    @pytest.mark.asyncio
    async def test_during_registration_picks_random(self) -> None:
        net, writer = self._wired(State.REGISTERING, "taken")

        await net._handle(_msg(":server 433 * taken :Nickname is already in use"))

        # A replacement nick must be chosen locally...
        assert net.nick != "taken"
        # ...and actually requested from the server. Only checking that
        # *something* was written would also pass for an unrelated line or for
        # a retry of the rejected nick.
        sent = [c.args[0] for c in writer.write.call_args_list]
        assert any(
            line.startswith(b"NICK ") and line != b"NICK taken\r\n" for line in sent
        ), sent

    @pytest.mark.asyncio
    async def test_when_ready_appends_underscore(self) -> None:
        net, writer = self._wired(State.READY, "desired")

        await net._handle(_msg(":server 433 * desired :Nickname is already in use"))

        assert net.nick == "desired_"
        writer.write.assert_called_once_with(b"NICK desired_\r\n")
|
|
|
|
|
|
# -- _handle: KICK ----------------------------------------------------------
|
|
|
|
class TestHandleKick:
    """KICK handling for our own nick versus other users."""

    @staticmethod
    def _ready_in(cfg: NetworkConfig, channel: str) -> tuple[Network, MagicMock]:
        """Return a running READY Network named ``me`` that has joined *channel*,
        plus the open fake writer attached to it."""
        net = _net(cfg=cfg)
        net.state = State.READY
        net._running = True
        net.nick = "me"
        net.channels = {channel}
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return net, writer

    @pytest.mark.asyncio
    async def test_kicked_removes_channel(self) -> None:
        net, writer = self._ready_in(_cfg(channels=["#test"]), "#test")

        await net._handle(_msg(":op!user@host KICK #test me :reason"))

        assert "#test" not in net.channels
        # _bouncer() sets rejoin_delay=0, so the rejoin is expected already.
        writer.write.assert_called_with(b"JOIN #test\r\n")

    @pytest.mark.asyncio
    async def test_kick_rejoin_with_key(self) -> None:
        cfg = _cfg(channels=["#secret"])
        cfg.channel_keys = {"#secret": "hunter2"}
        net, writer = self._ready_in(cfg, "#secret")

        await net._handle(_msg(":op!user@host KICK #secret me :reason"))

        assert "#secret" not in net.channels
        # The configured channel key is included in the rejoin.
        writer.write.assert_called_with(b"JOIN #secret hunter2\r\n")

    @pytest.mark.asyncio
    async def test_kick_other_user_ignored(self) -> None:
        net = _net()
        net.nick = "me"
        net.channels = {"#test"}

        await net._handle(_msg(":op!user@host KICK #test other :reason"))

        assert "#test" in net.channels
|
|
|
|
|
|
# -- _handle: NOTICE (hostname extraction) -----------------------------------
|
|
|
|
class TestHandleNotice:
    """Hostname discovery from the server's connect-time NOTICE."""

    @pytest.mark.asyncio
    async def test_found_hostname_notice(self) -> None:
        net = _net()
        notice = ":server NOTICE * :*** Found your hostname: some.isp.example.com"

        await net._handle(_msg(notice))

        assert net.visible_host == "some.isp.example.com"
|
|
|
|
|
|
# -- _handle: CAP (SASL negotiation) ----------------------------------------
|
|
|
|
class TestHandleCap:
    """CAP ACK/NAK responses for the sasl capability."""

    @staticmethod
    def _wired(mechanism: str, nick: str | None = None,
               password: str | None = None) -> tuple[Network, MagicMock]:
        """Return a Network with the given SASL settings and an open fake writer.

        ``nick`` / ``password`` are only assigned when not None.
        """
        net = _net()
        net._sasl_mechanism = mechanism
        if nick is not None:
            net._sasl_nick = nick
        if password is not None:
            net._sasl_pass = password
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return net, writer

    @pytest.mark.asyncio
    async def test_cap_ack_sasl_starts_authenticate(self) -> None:
        net, writer = self._wired("PLAIN")

        await net._handle(_msg(":server CAP * ACK :sasl"))

        writer.write.assert_called_with(b"AUTHENTICATE PLAIN\r\n")

    @pytest.mark.asyncio
    async def test_cap_nak_sasl_ends_cap(self) -> None:
        net, writer = self._wired("PLAIN", nick="nick", password="pass")

        await net._handle(_msg(":server CAP * NAK :sasl"))

        writer.write.assert_called_with(b"CAP END\r\n")
        # A refused capability also clears the pending SASL settings.
        assert net._sasl_mechanism == ""
        assert net._sasl_nick == ""
|
|
|
|
|
|
# -- _handle: AUTHENTICATE --------------------------------------------------
|
|
|
|
class TestHandleAuthenticate:
    """Replies to the server's ``AUTHENTICATE +`` continuation prompt."""

    @staticmethod
    def _wired(mechanism: str, nick: str,
               password: str) -> tuple[Network, MagicMock]:
        """Return a Network with the given SASL settings and an open fake writer."""
        net = _net()
        net._sasl_mechanism = mechanism
        net._sasl_nick = nick
        net._sasl_pass = password
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return net, writer

    @pytest.mark.asyncio
    async def test_plain_credentials(self) -> None:
        net, writer = self._wired("PLAIN", "mynick", "mypass")

        await net._handle(_msg("AUTHENTICATE +"))

        # PLAIN payload: authzid NUL authcid NUL password, then base64.
        blob = b"\0".join((b"mynick", b"mynick", b"mypass"))
        token = base64.b64encode(blob)
        writer.write.assert_called_with(b"AUTHENTICATE " + token + b"\r\n")

    @pytest.mark.asyncio
    async def test_external_sends_nick(self) -> None:
        net, writer = self._wired("EXTERNAL", "extnick", "pass")

        await net._handle(_msg("AUTHENTICATE +"))

        token = base64.b64encode(b"extnick")
        writer.write.assert_called_with(b"AUTHENTICATE " + token + b"\r\n")

    @pytest.mark.asyncio
    async def test_no_creds_aborts(self) -> None:
        net, writer = self._wired("PLAIN", "", "")

        await net._handle(_msg("AUTHENTICATE +"))

        writer.write.assert_called_with(b"AUTHENTICATE *\r\n")
|
|
|
|
|
|
# -- _handle: 903 (RPL_SASLSUCCESS) ----------------------------------------
|
|
|
|
class TestHandleSaslSuccess:
    """RPL_SASLSUCCESS (903) marks SASL done and ends capability negotiation."""

    @pytest.mark.asyncio
    async def test_sasl_success_sets_event(self) -> None:
        net = _net()
        net._sasl_mechanism, net._sasl_nick = "PLAIN", "nick"
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer

        await net._handle(_msg(":server 903 nick :SASL authentication successful"))

        assert net._sasl_complete.is_set()
        writer.write.assert_called_with(b"CAP END\r\n")
|
|
|
|
|
|
# -- _handle: 904/905 (SASL failure) ----------------------------------------
|
|
|
|
class TestHandleSaslFailure:
    """SASL failure (904): EXTERNAL falls back to PLAIN, and PLAIN gives up."""

    @staticmethod
    def _failing(mechanism: str) -> tuple[Network, MagicMock]:
        """Return a Network mid-SASL with *mechanism*, plus its open fake writer."""
        net = _net()
        net._sasl_mechanism = mechanism
        net._sasl_nick = "nick"
        net._sasl_pass = "pass"
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer
        return net, writer

    @pytest.mark.asyncio
    async def test_external_falls_back_to_plain(self) -> None:
        net, writer = self._failing("EXTERNAL")

        await net._handle(_msg(":server 904 nick :SASL authentication failed"))

        assert net._sasl_mechanism == "PLAIN"
        writer.write.assert_called_with(b"AUTHENTICATE PLAIN\r\n")

    @pytest.mark.asyncio
    async def test_plain_failure_ends_cap(self) -> None:
        net, writer = self._failing("PLAIN")

        await net._handle(_msg(":server 904 nick :SASL authentication failed"))

        assert net._sasl_mechanism == ""
        assert net._sasl_nick == ""
        writer.write.assert_called_with(b"CAP END\r\n")
|
|
|
|
|
|
# -- _handle: 906/908 (SASL aborted/mechs) ----------------------------------
|
|
|
|
class TestHandleSaslAborted:
    """ERR_SASLABORTED (906) still ends capability negotiation."""

    @pytest.mark.asyncio
    async def test_906_ends_cap(self) -> None:
        net = _net()
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        net._writer = writer

        await net._handle(_msg(":server 906 nick :SASL authentication aborted"))

        writer.write.assert_called_with(b"CAP END\r\n")
|
|
|
|
|
|
# -- message forwarding to router -------------------------------------------
|
|
|
|
class TestOnMessage:
    """Forwarding of parsed server messages to the on_message callback."""

    @pytest.mark.asyncio
    async def test_on_message_called_for_non_ping(self) -> None:
        callback = MagicMock()
        net = _net(on_message=callback)

        await net._handle(_msg(":nick!user@host PRIVMSG #test :hello"))

        callback.assert_called_once()
        forwarded = callback.call_args.args
        assert forwarded[0] == "testnet"
        assert forwarded[1].command == "PRIVMSG"

    @pytest.mark.asyncio
    async def test_on_message_not_called_for_ping(self) -> None:
        callback = MagicMock()
        net = _net(on_message=callback)
        # PING is answered with a PONG, so an open writer must be present.
        net._writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})

        await net._handle(_msg("PING :server"))

        callback.assert_not_called()
|
|
|
|
|
|
# -- probation timer ---------------------------------------------------------
|
|
|
|
class TestProbation:
    """The post-welcome probation timer and its hand-off to _go_ready."""

    @staticmethod
    def _registering(probation_seconds: int) -> Network:
        """Return a running Network in REGISTERING with ``_go_ready`` stubbed."""
        net = _net(bouncer_cfg=_bouncer(probation_seconds=probation_seconds))
        net.state = State.REGISTERING
        net._running = True
        # Stub _go_ready so the timer firing doesn't start the NickServ flow.
        net._go_ready = AsyncMock()
        return net

    @pytest.mark.asyncio
    async def test_enter_probation(self) -> None:
        net = self._registering(0)

        await net._enter_probation()
        assert net.state == State.PROBATION
        assert net._probation_task is not None

        # Let the zero-length probation timer run to completion.
        await net._probation_task
        net._go_ready.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_probation_cancelled(self) -> None:
        net = self._registering(60)

        await net._enter_probation()
        net._probation_task.cancel()
        try:
            await net._probation_task
        except asyncio.CancelledError:
            pass  # expected: the test cancelled the timer itself
        net._go_ready.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_probation_resets_reconnect_counter(self) -> None:
        net = self._registering(0)
        net._reconnect_attempt = 5

        await net._enter_probation()
        await net._probation_task

        assert net._reconnect_attempt == 0
|
|
|
|
|
|
# -- reconnection backoff ---------------------------------------------------
|
|
|
|
class TestReconnect:
    """Backoff bookkeeping in Network._schedule_reconnect."""

    @pytest.mark.asyncio
    async def test_backoff_increments(self) -> None:
        net = _net(bouncer_cfg=_bouncer(backoff_steps=[0, 0, 0]))
        net._running = True

        with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect:
            assert net._reconnect_attempt == 0
            net._schedule_reconnect()
            await net._reconnect_task
            assert net._reconnect_attempt == 1
            mock_connect.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_backoff_clamps_to_last_step(self) -> None:
        steps = [1, 2, 5]
        net = _net(bouncer_cfg=_bouncer(backoff_steps=steps))
        net._running = True
        net._reconnect_attempt = 100

        # Stub the backoff sleep. Otherwise the test really waits the full
        # last-step delay on every run, and it cannot tell which step was used.
        with patch.object(net, "_connect", new_callable=AsyncMock):
            with patch("asyncio.sleep", new_callable=AsyncMock) as mock_sleep:
                net._schedule_reconnect()
                await net._reconnect_task

        assert net._reconnect_attempt == 101
        # An attempt index past the end of the table must reuse the last step.
        mock_sleep.assert_any_await(steps[-1])

    @pytest.mark.asyncio
    async def test_reconnect_cancelled(self) -> None:
        net = _net(bouncer_cfg=_bouncer(backoff_steps=[60]))
        net._running = True

        with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect:
            net._schedule_reconnect()
            net._reconnect_task.cancel()
            try:
                await net._reconnect_task
            except asyncio.CancelledError:
                pass  # expected: the test cancelled the pending reconnect
            mock_connect.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_no_reconnect_when_stopped(self) -> None:
        net = _net(bouncer_cfg=_bouncer(backoff_steps=[0]))
        net._running = False

        with patch.object(net, "_connect", new_callable=AsyncMock) as mock_connect:
            net._schedule_reconnect()
            await net._reconnect_task
            mock_connect.assert_not_awaited()
|
|
|
|
|
|
# -- disconnect --------------------------------------------------------------
|
|
|
|
class TestDisconnect:
    """Network._disconnect() teardown."""

    @pytest.mark.asyncio
    async def test_disconnect_closes_writer(self) -> None:
        net = _net()
        writer = MagicMock(wait_closed=AsyncMock(), **{"is_closing.return_value": False})
        net._writer, net._reader = writer, MagicMock()

        await net._disconnect()

        assert net.state == State.DISCONNECTED
        writer.close.assert_called_once()
        # Both stream halves are dropped.
        assert net._writer is None
        assert net._reader is None

    @pytest.mark.asyncio
    async def test_disconnect_cancels_probation(self) -> None:
        net = _net()
        pending = MagicMock(**{"done.return_value": False})
        net._probation_task = pending

        await net._disconnect()

        pending.cancel.assert_called_once()
        assert net._probation_task is None
|
|
|
|
|
|
# -- stop --------------------------------------------------------------------
|
|
|
|
class TestStop:
    """Network.stop() shutdown behaviour."""

    @pytest.mark.asyncio
    async def test_stop_cancels_tasks(self) -> None:
        net = _net()
        net._running = True
        # Fake tasks that report themselves as still pending.
        net._read_task = read_task = MagicMock(**{"done.return_value": False})
        net._reconnect_task = reconnect_task = MagicMock(**{"done.return_value": False})

        await net.stop()

        assert not net._running
        read_task.cancel.assert_called_once()
        reconnect_task.cancel.assert_called_once()
        assert net.state == State.DISCONNECTED
|
|
|
|
|
|
# -- _connect ----------------------------------------------------------------
|
|
|
|
class TestConnect:
    """Connection setup in Network._connect."""

    @staticmethod
    def _stream() -> tuple[MagicMock, MagicMock]:
        """Return a fake (reader, writer) pair; the writer reports itself open."""
        reader = MagicMock()
        writer = MagicMock(drain=AsyncMock(), **{"is_closing.return_value": False})
        return reader, writer

    @staticmethod
    async def _connect_via(net: Network, reader: MagicMock, writer: MagicMock) -> None:
        """Run ``net._connect()`` against the fake stream with the read loop stubbed."""
        with patch("bouncer.proxy.connect", new_callable=AsyncMock,
                   return_value=(reader, writer)):
            # Stub the read loop so it doesn't actually run.
            with patch.object(net, "_read_loop", new_callable=AsyncMock):
                await net._connect()

    @staticmethod
    def _sent(writer: MagicMock) -> list[bytes]:
        """Return every payload passed to ``writer.write``, in call order."""
        return [c.args[0] for c in writer.write.call_args_list]

    @pytest.mark.asyncio
    async def test_connect_no_sasl(self) -> None:
        """Without stored creds, uses random nick and no SASL."""
        net = _net()
        reader, writer = self._stream()

        await self._connect_via(net, reader, writer)

        assert net.state == State.REGISTERING
        sent = self._sent(writer)

        def wrote(fragment: bytes) -> bool:
            return any(fragment in line for line in sent)

        # Registration: both NICK and USER went out.
        assert wrote(b"NICK ")
        assert wrote(b"USER ")
        # server-time is requested, but sasl is not.
        assert wrote(b"CAP REQ server-time")
        assert not wrote(b"CAP REQ sasl")

    @pytest.mark.asyncio
    async def test_connect_with_sasl_plain(self) -> None:
        """With stored creds and no cert, uses SASL PLAIN."""
        backlog = _mock_backlog(creds_by_network=("registered_nick", "secret"))
        net = _net(backlog=backlog)
        reader, writer = self._stream()

        await self._connect_via(net, reader, writer)

        assert net._sasl_mechanism == "PLAIN"
        assert net._sasl_nick == "registered_nick"
        assert net._connect_nick == "registered_nick"
        assert any(b"CAP REQ sasl" in line for line in self._sent(writer))

    @pytest.mark.asyncio
    async def test_connect_failure_schedules_reconnect(self) -> None:
        """Connection failure should schedule a reconnect."""
        net = _net()
        net._running = True
        refused = ConnectionRefusedError("refused")

        with patch("bouncer.proxy.connect", new_callable=AsyncMock, side_effect=refused):
            with patch.object(net, "_schedule_reconnect") as mock_sched:
                await net._connect()

        assert net.state == State.DISCONNECTED
        mock_sched.assert_called_once()

    @pytest.mark.asyncio
    async def test_connect_sends_server_password(self) -> None:
        """A configured server password is sent with PASS."""
        net = _net(cfg=_cfg(password="serverpass"))
        reader, writer = self._stream()

        await self._connect_via(net, reader, writer)

        assert any(b"PASS serverpass" in line for line in self._sent(writer))
|
|
|
|
|
|
# -- _go_ready ---------------------------------------------------------------
|
|
|
|
class TestGoReady:
    """Transition from probation into READY."""

    @pytest.mark.asyncio
    async def test_sasl_skips_nickserv(self) -> None:
        """When SASL already authenticated, skip NickServ flow."""
        net = _net(backlog=_mock_backlog(), cfg=_cfg(auth_service="nickserv"))
        net.state = State.PROBATION
        net.nick = "autheduser"
        net._sasl_complete.set()

        await net._go_ready()

        assert net.state == State.READY

    @pytest.mark.asyncio
    async def test_auth_service_none_skips_auth(self) -> None:
        """When auth_service=none, no authentication attempted."""
        net = _net(cfg=_cfg(auth_service="none"))
        net.on_status = status = MagicMock()
        net.state = State.PROBATION
        net.nick = "anon"

        await net._go_ready()

        assert net.state == State.READY
        status.assert_called_with("testnet", "ready as anon (no auth service)")
|
|
|
|
|
|
# -- NickServ response parsing -----------------------------------------------
|
|
|
|
class TestNickServMatchers:
    """Text classifiers applied to NickServ replies."""

    def test_registration_confirmed(self) -> None:
        confirmed = _net()._registration_confirmed
        assert confirmed("a passcode has been sent to your email")
        assert confirmed("activation instructions have been sent")
        assert not confirmed("you are now identified")

    def test_registration_immediate(self) -> None:
        immediate = _net()._registration_immediate
        assert immediate("nickname registered under your account")
        assert not immediate("nickname registered, email verification required")

    def test_verification_succeeded(self) -> None:
        succeeded = _net()._verification_succeeded
        for reply in (
            "your nick has been verified",
            "has now been activated for use",
            "you are now identified for this nick",
        ):
            assert succeeded(reply), reply
        assert not succeeded("verification pending")
|
|
|
|
|
|
# -- _handle_nickserv --------------------------------------------------------
|
|
|
|
class TestHandleNickserv:
    """Transitions of the NickServ state machine driven by ``_handle_nickserv``."""

    @staticmethod
    def _live_writer() -> MagicMock:
        """Build a writer mock that reports an open connection."""
        writer = MagicMock()
        writer.is_closing.return_value = False
        writer.drain = AsyncMock()
        return writer

    @pytest.mark.asyncio
    async def test_identify_success(self) -> None:
        """A successful IDENTIFY clears the step, signals done, and saves creds."""
        backlog_mock = _mock_backlog()
        network = _net(backlog=backlog_mock)
        network.nick = "mynick"
        network.visible_host = "user/mynick"
        network._nickserv_pending = "identify"
        network._nickserv_password = "secret"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("You are now identified for mynick")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()
        backlog_mock.save_nickserv_creds.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_identify_not_registered_triggers_register(self) -> None:
        """IDENTIFY against an unregistered nick falls through to REGISTER."""
        network = _net()
        network._nickserv_pending = "identify"
        network._nickserv_done = asyncio.Event()

        register_patch = patch.object(
            network, "_nickserv_register", new_callable=AsyncMock,
        )
        with register_patch as register_mock:
            await network._handle_nickserv("mynick is not a registered nickname")

        register_mock.assert_awaited_once()
        assert network._nickserv_pending == ""

    @pytest.mark.asyncio
    async def test_identify_wrong_password(self) -> None:
        """A rejected password ends the IDENTIFY step and signals done."""
        network = _net()
        network._nickserv_pending = "identify"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("Invalid password for mynick")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_register_email_sent(self) -> None:
        """An emailed passcode moves REGISTER into VERIFY and stores creds."""
        backlog_mock = _mock_backlog()
        network = _net(backlog=backlog_mock)
        network.nick = "newnick"
        network._nickserv_pending = "register"
        network._nickserv_password = "pass"
        network._nickserv_email = "test@mail.tm"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("A passcode has been sent to test@mail.tm")

        assert network._nickserv_pending == "verify"
        backlog_mock.save_nickserv_creds.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_register_domain_rejected(self) -> None:
        """A refused email domain is remembered in _rejected_email_domains."""
        network = _net()
        network._nickserv_pending = "register"
        network._nickserv_email = "user@bad.domain"
        network._nickserv_done = asyncio.Event()

        with patch.object(network, "_nickserv_register", new_callable=AsyncMock):
            await network._handle_nickserv(
                "bad.domain do not accept email from that address"
            )

        assert "bad.domain" in network._rejected_email_domains

    @pytest.mark.asyncio
    async def test_register_nick_taken(self) -> None:
        """A nick already owned by someone else aborts REGISTER."""
        network = _net()
        network._nickserv_pending = "register"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("mynick is already registered by someone else")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_register_rate_limited(self) -> None:
        """Registration throttling aborts REGISTER and signals done."""
        network = _net()
        network._nickserv_pending = "register"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("You have sent too many registration requests")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_verify_success(self) -> None:
        """A verification notice clears VERIFY and marks the nick verified."""
        backlog_mock = _mock_backlog()
        network = _net(backlog=backlog_mock)
        network.nick = "verified_nick"
        network._nickserv_pending = "verify"
        network._nickserv_password = "pass"

        await network._handle_nickserv("verified_nick has been verified")

        assert network._nickserv_pending == ""
        backlog_mock.mark_nickserv_verified.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_verify_success_signals_completion(self) -> None:
        """_on_verify_success must signal _nickserv_done."""
        network = _net(backlog=_mock_backlog())
        network.nick = "verified_nick"
        network._nickserv_pending = "verify"
        network._nickserv_password = "pass"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("verified_nick has been verified")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_registration_immediate_signals_completion(self) -> None:
        """Immediate registration (no email) must signal _nickserv_done."""
        backlog_mock = _mock_backlog()
        network = _net(backlog=backlog_mock)
        network.nick = "fastnick"
        network._nickserv_pending = "register"
        network._nickserv_password = "pass"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("Nickname registered under your account")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()
        backlog_mock.mark_nickserv_verified.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_verify_failure(self) -> None:
        """A bad verification code leaves no step pending."""
        network = _net()
        network._nickserv_pending = "verify"
        network.on_status = MagicMock()

        await network._handle_nickserv("Invalid verification code")

        assert network._nickserv_pending == ""

    @pytest.mark.asyncio
    async def test_resume_pending_restores_verify_url(self) -> None:
        """Cross-session resume must restore the verify_url from DB."""
        stored = ("oldnick", "pass", "e@mail.tm", "host", "https://oftc/verify/abc")
        network = _net(backlog=_mock_backlog(pending=stored))
        network.state = State.READY
        network._running = True
        network.nick = "oldnick"
        network._writer = self._live_writer()

        resumed = await network._resume_pending_verification()

        assert resumed is True
        assert network._verify_url == "https://oftc/verify/abc"
        assert network._nickserv_email == "e@mail.tm"
        assert network._nickserv_password == "pass"

    @pytest.mark.asyncio
    async def test_resume_pending_no_url(self) -> None:
        """Resuming still succeeds when the stored verify_url is empty."""
        stored = ("nick", "pass", "e@mail.tm", "host", "")
        network = _net(backlog=_mock_backlog(pending=stored))
        network.state = State.READY
        network._running = True
        network.nick = "nick"
        network._writer = self._live_writer()

        resumed = await network._resume_pending_verification()

        assert resumed is True
        assert network._verify_url == ""

    @pytest.mark.asyncio
    async def test_late_registration_confirmation(self) -> None:
        """A confirmation arriving after the timeout cleared the step is still saved."""
        backlog_mock = _mock_backlog()
        network = _net(backlog=backlog_mock)
        network.nick = "latenick"
        network._nickserv_pending = ""
        network._nickserv_password = "pass"
        network._nickserv_email = "late@mail.tm"
        network._nickserv_done = asyncio.Event()

        await network._handle_nickserv("A passcode has been sent to your email")

        backlog_mock.save_nickserv_creds.assert_awaited_once()
|
|
|
|
|
|
# -- _handle_qbot -----------------------------------------------------------
|
|
|
|
class TestHandleQbot:
    """Replies from the Q service as processed by ``_handle_qbot``."""

    @pytest.mark.asyncio
    async def test_qbot_auth_success(self) -> None:
        """A login confirmation completes a pending qbot_auth step."""
        network = _net(cfg=_cfg(nick="quser", auth_service="qbot"))
        network.nick = "quser"
        network._nickserv_pending = "qbot_auth"
        network._nickserv_done = asyncio.Event()
        open_writer = MagicMock()
        open_writer.is_closing.return_value = False
        open_writer.drain = AsyncMock()
        network._writer = open_writer

        await network._handle_qbot("You are now logged in as quser.")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_qbot_auth_failure(self) -> None:
        """A rejected password still ends the qbot_auth step."""
        network = _net(cfg=_cfg(auth_service="qbot"))
        network._nickserv_pending = "qbot_auth"
        network._nickserv_done = asyncio.Event()

        await network._handle_qbot("Incorrect password for quser")

        assert network._nickserv_pending == ""
        assert network._nickserv_done.is_set()

    @pytest.mark.asyncio
    async def test_qbot_ignores_when_not_pending(self) -> None:
        """Q replies are ignored when no qbot_auth step is outstanding."""
        network = _net()
        network._nickserv_pending = ""
        network._nickserv_done = asyncio.Event()

        await network._handle_qbot("You are now logged in as quser.")

        # Nothing was waiting on Q, so the completion event must stay clear.
        assert not network._nickserv_done.is_set()
|
|
|
|
|
|
# -- read loop ---------------------------------------------------------------
|
|
|
|
class TestReadLoop:
    """Line framing and dispatch performed by ``Network._read_loop``."""

    @staticmethod
    def _wire(network: Network, chunks: list[bytes]) -> MagicMock:
        """Attach an open writer and a reader that yields *chunks*; return the writer."""
        writer = MagicMock()
        writer.is_closing.return_value = False
        writer.drain = AsyncMock()
        network._writer = writer
        reader = AsyncMock()
        reader.read = AsyncMock(side_effect=chunks)
        network._reader = reader
        return writer

    @staticmethod
    async def _pump(network: Network) -> None:
        """Run the read loop until EOF with disconnect/reconnect stubbed out."""
        no_disconnect = patch.object(network, "_disconnect", new_callable=AsyncMock)
        no_reconnect = patch.object(network, "_schedule_reconnect")
        with no_disconnect, no_reconnect:
            await network._read_loop()

    @pytest.mark.asyncio
    async def test_processes_complete_lines(self) -> None:
        """A complete CRLF-terminated PING is parsed and answered with PONG."""
        network = _net()
        network.state = State.REGISTERING
        network._running = True
        writer = self._wire(network, [
            b"PING :test\r\n",
            b"",  # EOF
        ])

        await self._pump(network)

        writer.write.assert_called_with(b"PONG test\r\n")

    @pytest.mark.asyncio
    async def test_handles_partial_reads(self) -> None:
        """A line split across two reads is buffered and then handled."""
        network = _net()
        network.state = State.REGISTERING
        network._running = True
        writer = self._wire(network, [
            b"PING :tes",  # first half of the line
            b"t\r\n",      # remainder plus terminator
            b"",           # EOF
        ])

        await self._pump(network)

        writer.write.assert_called_with(b"PONG test\r\n")

    @pytest.mark.asyncio
    async def test_multiple_lines_in_one_read(self) -> None:
        """Two lines delivered in one read produce two message callbacks."""
        on_message = MagicMock()
        network = _net(on_message=on_message)
        network.state = State.READY
        network._running = True
        network.nick = "me"
        self._wire(network, [
            b":a!u@h PRIVMSG #ch :msg1\r\n:b!u@h PRIVMSG #ch :msg2\r\n",
            b"",
        ])

        await self._pump(network)

        assert on_message.call_count == 2
|
|
|
|
|
|
# -- PING watchdog -----------------------------------------------------------
|
|
|
|
class TestPingWatchdog:
    """Keep-alive watchdog started in READY and torn down on disconnect/stop."""

    @staticmethod
    def _live_writer() -> MagicMock:
        """Build a writer mock that reports an open connection."""
        writer = MagicMock()
        writer.is_closing.return_value = False
        writer.drain = AsyncMock()
        return writer

    @pytest.mark.asyncio
    async def test_timeout_triggers_disconnect(self) -> None:
        """A connection silent past the timeout is dropped and rescheduled."""
        network = _net(bouncer_cfg=_bouncer(ping_interval=0, ping_timeout=0))
        network.state = State.READY
        network._running = True
        network._last_recv = time.monotonic() - 1000  # long silence: stale

        network._writer = self._live_writer()

        disconnect_patch = patch.object(network, "_disconnect", new_callable=AsyncMock)
        reconnect_patch = patch.object(network, "_schedule_reconnect")
        with disconnect_patch as disconnect_mock, reconnect_patch as reconnect_mock:
            await network._ping_watchdog()

        disconnect_mock.assert_awaited_once()
        reconnect_mock.assert_called_once()

    @pytest.mark.asyncio
    async def test_healthy_connection_stays_alive(self) -> None:
        """Traffic arriving inside the timeout window prevents a disconnect."""
        network = _net(bouncer_cfg=_bouncer(ping_interval=10, ping_timeout=5))
        network.state = State.READY
        network._running = True

        writer = self._live_writer()
        network._writer = writer

        # Scripted clock: starts at T=100 and is advanced by each fake sleep.
        now = {"t": 100.0}
        real_sleep = asyncio.sleep  # captured before asyncio.sleep is patched
        sleeps = 0

        async def fake_sleep(delay: float) -> None:
            nonlocal sleeps
            sleeps += 1
            if sleeps == 1:
                # Interval elapsed with old data (20s > interval=10) -> PING goes out.
                now["t"] = 120.0
                network._last_recv = 100.0
            elif sleeps == 2:
                # Inside the timeout window a PONG lands, so data is fresh.
                now["t"] = 122.0
                network._last_recv = 122.0
            elif sleeps == 3:
                # Next interval sleep: make the watchdog loop exit.
                network.state = State.DISCONNECTED
            await real_sleep(0)

        clock_patch = patch("time.monotonic", side_effect=lambda: now["t"])
        sleep_patch = patch("asyncio.sleep", side_effect=fake_sleep)
        disconnect_patch = patch.object(network, "_disconnect", new_callable=AsyncMock)
        with clock_patch, sleep_patch, disconnect_patch as disconnect_mock:
            await network._ping_watchdog()

        # A PING was written, yet the fresh reply means no disconnect happened.
        ping_sent = any(
            b"PING" in entry.args[0] for entry in writer.write.call_args_list
        )
        assert ping_sent
        disconnect_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_watchdog_cancelled_on_disconnect(self) -> None:
        """_disconnect cancels a running watchdog and forgets the task."""
        network = _net()
        network.state = State.READY
        network._running = True
        network._last_recv = time.monotonic()

        fake_task = MagicMock()
        fake_task.done.return_value = False
        network._ping_task = fake_task

        await network._disconnect()

        fake_task.cancel.assert_called_once()
        assert network._ping_task is None

    @pytest.mark.asyncio
    async def test_ping_task_started_in_go_ready(self) -> None:
        """Reaching READY spawns the watchdog task."""
        network = _net(bouncer_cfg=_bouncer(probation_seconds=0, ping_interval=999))
        network.state = State.PROBATION
        network._running = True
        network._sasl_complete.set()

        network._writer = self._live_writer()
        network.backlog = _mock_backlog()

        await network._go_ready()

        watchdog = network._ping_task
        assert watchdog is not None
        # Stop the real task so it does not outlive this test.
        watchdog.cancel()
        try:
            await watchdog
        except asyncio.CancelledError:
            pass

    @pytest.mark.asyncio
    async def test_ping_task_in_stop(self) -> None:
        """stop() cancels the ping task."""
        network = _net()
        network._running = True
        fake_task = MagicMock()
        fake_task.done.return_value = False
        network._ping_task = fake_task

        await network.stop()

        # Both stop() and _disconnect() may cancel, so require at least one call.
        assert fake_task.cancel.call_count >= 1
|
|
|
|
|
|
# -- IRCv3 CAP negotiation (server-time) ------------------------------------
|
|
|
|
class TestCapServerTime:
    """IRCv3 CAP negotiation of the ``server-time`` capability (with/without SASL)."""

    @staticmethod
    def _attach_writer(net: Network) -> MagicMock:
        """Install a writer mock that reports an open connection; return it."""
        writer = MagicMock()
        writer.is_closing.return_value = False
        writer.drain = AsyncMock()
        net._writer = writer
        return writer

    @staticmethod
    def _cap_end_count(writer: MagicMock) -> int:
        """Count the ``write`` calls whose payload contains ``CAP END``."""
        return sum(1 for c in writer.write.call_args_list if b"CAP END" in c.args[0])

    @pytest.mark.asyncio
    async def test_server_time_ack_sets_flag(self) -> None:
        """An ACK for server-time turns the flag on."""
        net = _net()
        self._attach_writer(net)
        net._caps_pending = 1

        await net._handle(_msg(":server CAP * ACK :server-time"))

        assert net._server_time is True

    @pytest.mark.asyncio
    async def test_server_time_nak_handled(self) -> None:
        """A NAK for server-time leaves the flag off."""
        net = _net()
        self._attach_writer(net)
        net._caps_pending = 1

        await net._handle(_msg(":server CAP * NAK :server-time"))

        assert net._server_time is False

    @pytest.mark.asyncio
    async def test_combined_sasl_and_server_time(self) -> None:
        """Both caps requested: ACK server-time + ACK sasl."""
        net = _net()
        net._sasl_mechanism = "PLAIN"
        net._sasl_nick = "nick"
        net._sasl_pass = "pass"
        net._caps_pending = 2
        writer = self._attach_writer(net)

        # ACK server-time first
        await net._handle(_msg(":server CAP * ACK :server-time"))
        assert net._server_time is True
        # CAP END must not be sent yet: SASL is still pending.
        assert self._cap_end_count(writer) == 0

        # ACK sasl starts the AUTHENTICATE flow
        await net._handle(_msg(":server CAP * ACK :sasl"))
        writer.write.assert_called_with(b"AUTHENTICATE PLAIN\r\n")

        # SASL success (903) resolves the last cap -> exactly one CAP END
        await net._handle(_msg(":server 903 nick :SASL authentication successful"))
        assert self._cap_end_count(writer) == 1

    @pytest.mark.asyncio
    async def test_cap_end_after_all_resolved(self) -> None:
        """CAP END sent only after all caps are resolved."""
        net = _net()
        writer = self._attach_writer(net)
        net._caps_pending = 1  # only server-time, no SASL

        await net._handle(_msg(":server CAP * ACK :server-time"))

        assert self._cap_end_count(writer) == 1

    def test_server_time_property(self) -> None:
        """The public server_time property mirrors the private _server_time flag.

        Plain sync test: nothing here awaits, so neither ``async def`` nor the
        ``pytest.mark.asyncio`` marker is needed.
        """
        net = _net()
        assert net.server_time is False
        net._server_time = True
        assert net.server_time is True
|
|
|
|
|
|
# -- cred_network + ephemeral -----------------------------------------------
|
|
|
|
class TestCredNetwork:
    """Resolution of the network name that NickServ credentials are keyed on."""

    def test_defaults_to_cfg_name(self) -> None:
        """Without an explicit value, cred_network falls back to cfg.name."""
        assert _net().cred_network == "testnet"

    def test_override(self) -> None:
        """An explicit cred_network argument takes precedence over cfg.name."""
        farm_network = Network(
            cfg=_cfg(name="_farm_libera"),
            proxy_cfg=_proxy(),
            bouncer_cfg=_bouncer(),
            cred_network="libera",
        )
        assert farm_network.cred_network == "libera"
|
|
|
|
|
|
class TestEphemeral:
|
|
def test_status_suppressed(self) -> None:
|
|
"""Ephemeral _status() logs but doesn't call on_status."""
|
|
status_cb = MagicMock()
|
|
net = Network(
|
|
cfg=_cfg(),
|
|
proxy_cfg=_proxy(),
|
|
bouncer_cfg=_bouncer(),
|
|
on_status=status_cb,
|
|
ephemeral=True,
|
|
)
|
|
net._status("test message")
|
|
status_cb.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_go_ready_registers_directly(self) -> None:
|
|
"""Ephemeral _go_ready() calls _nickserv_register() directly."""
|
|
net = Network(
|
|
cfg=_cfg(),
|
|
proxy_cfg=_proxy(),
|
|
bouncer_cfg=_bouncer(),
|
|
ephemeral=True,
|
|
)
|
|
net.state = State.PROBATION
|
|
net._running = True
|
|
net.nick = "ephemeral_nick"
|
|
|
|
register_called = False
|
|
|
|
async def mock_register() -> None:
|
|
nonlocal register_called
|
|
register_called = True
|
|
# Signal done so _go_ready doesn't block forever
|
|
net._nickserv_done.set()
|
|
|
|
with patch.object(net, "_nickserv_register", side_effect=mock_register):
|
|
await net._go_ready()
|
|
|
|
assert register_called
|
|
assert net.state == State.READY
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_skips_sasl(self) -> None:
|
|
"""Ephemeral _connect() uses random nick, no SASL."""
|
|
bl = _mock_backlog(creds_by_network=("stored_nick", "secret"))
|
|
net = Network(
|
|
cfg=_cfg(),
|
|
proxy_cfg=_proxy(),
|
|
backlog=bl,
|
|
bouncer_cfg=_bouncer(),
|
|
ephemeral=True,
|
|
)
|
|
reader = MagicMock()
|
|
writer = MagicMock()
|
|
writer.is_closing.return_value = False
|
|
writer.drain = AsyncMock()
|
|
|
|
with patch("bouncer.proxy.connect", new_callable=AsyncMock,
|
|
return_value=(reader, writer)):
|
|
with patch.object(net, "_read_loop", new_callable=AsyncMock):
|
|
await net._connect()
|
|
|
|
# Should NOT have looked up creds (ephemeral skips SASL)
|
|
bl.get_nickserv_creds_by_network.assert_not_called()
|
|
assert net._sasl_mechanism == ""
|
|
# Should have sent NICK with random nick, not stored
|
|
calls = [c.args[0] for c in writer.write.call_args_list]
|
|
cap_sasl = any(b"CAP REQ sasl" in c for c in calls)
|
|
assert not cap_sasl
|