"""Tests for TCP server with optional TLS.""" from __future__ import annotations import asyncio import ssl from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from bouncer.cert import generate_listener_cert from bouncer.config import BouncerConfig from bouncer.server import start def _bouncer_cfg(**overrides) -> BouncerConfig: defaults = {"bind": "127.0.0.1", "port": 0} # port 0 = OS-assigned defaults.update(overrides) return BouncerConfig(**defaults) def _mock_router() -> MagicMock: return MagicMock() def _make_ssl_ctx(data_dir: Path) -> ssl.SSLContext: """Build a server SSL context from an auto-generated listener cert.""" pem = generate_listener_cert(data_dir) ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.minimum_version = ssl.TLSVersion.TLSv1_2 ctx.load_cert_chain(certfile=str(pem)) return ctx def _make_client_ssl_ctx() -> ssl.SSLContext: """Build a client SSL context that trusts any self-signed cert.""" ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ctx.check_hostname = False ctx.verify_mode = ssl.CERT_NONE return ctx @pytest.fixture def data_dir(tmp_path: Path) -> Path: return tmp_path class TestStartPlaintext: async def test_accepts_connection(self) -> None: """Plaintext listener starts and accepts a TCP connection.""" cfg = _bouncer_cfg() router = _mock_router() with patch("bouncer.server.Client") as mock_client_cls: mock_client_cls.return_value.handle = AsyncMock() server = await start(cfg, router) addr = server.sockets[0].getsockname() reader, writer = await asyncio.open_connection(addr[0], addr[1]) await asyncio.sleep(0.05) assert mock_client_cls.called writer.close() await writer.wait_closed() server.close() class TestStartWithTLS: async def test_accepts_tls_connection(self, data_dir: Path) -> None: """TLS listener starts and accepts a TLS connection.""" cfg = _bouncer_cfg() router = _mock_router() ssl_ctx = _make_ssl_ctx(data_dir) with patch("bouncer.server.Client") as mock_client_cls: mock_client_cls.return_value.handle = AsyncMock() server = await start(cfg, router, ssl_ctx=ssl_ctx) addr = server.sockets[0].getsockname() client_ctx = _make_client_ssl_ctx() reader, writer = await asyncio.open_connection( addr[0], addr[1], ssl=client_ctx, ) await asyncio.sleep(0.05) assert mock_client_cls.called writer.close() await writer.wait_closed() server.close() async def test_tls_handshake_and_auth(self, data_dir: Path) -> None: """TLS handshake succeeds and IRC data flows encrypted.""" cfg = _bouncer_cfg() router = _mock_router() ssl_ctx = _make_ssl_ctx(data_dir) received_lines: list[bytes] = [] async def _fake_handle(obj: MagicMock) -> None: """Minimal handler: read one line, echo a 001.""" data = await obj._reader.readline() received_lines.append(data) obj._writer.write(b":bouncer 001 test :Welcome\r\n") await obj._writer.drain() def _make_client(reader, writer, router_, password_): obj = MagicMock() obj._reader = reader obj._writer = writer obj.handle = lambda: _fake_handle(obj) return obj with patch("bouncer.server.Client", side_effect=_make_client): server = await start(cfg, router, ssl_ctx=ssl_ctx) addr = server.sockets[0].getsockname() client_ctx = _make_client_ssl_ctx() reader, writer = await asyncio.open_connection( addr[0], addr[1], ssl=client_ctx, ) writer.write(b"PASS testpass\r\n") await writer.drain() response = await asyncio.wait_for(reader.readline(), timeout=2.0) assert b"001" in response writer.close() await writer.wait_closed() server.close() assert len(received_lines) == 1 assert b"PASS testpass" in received_lines[0] async def test_plaintext_rejected_on_tls(self, data_dir: Path) -> None: """Non-TLS bytes on a TLS listener get dropped.""" cfg = _bouncer_cfg() router = _mock_router() ssl_ctx = _make_ssl_ctx(data_dir) with patch("bouncer.server.Client") as mock_client_cls: mock_client_cls.return_value.handle = AsyncMock() server = await start(cfg, router, ssl_ctx=ssl_ctx) addr = server.sockets[0].getsockname() # Connect without TLS to a TLS listener reader, writer = await asyncio.open_connection(addr[0], addr[1]) writer.write(b"PASS hello\r\n") await writer.drain() # Server should close the connection (EOF) data = await asyncio.wait_for(reader.read(1024), timeout=2.0) assert data == b"" writer.close() await writer.wait_closed() server.close()