"""Tests for the bulk DNS resolve plugin."""

import asyncio
import importlib.util
import struct
import sys
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

from derp.dns import encode_name
from derp.irc import Message

# plugins/ is not a Python package -- load the module from file path.
# Registering it in sys.modules before exec lets the ``from plugins.resolve``
# import below resolve to this same module object, so ``patch.object(_mod, ...)``
# replaces the globals that the imported functions actually look up.
_spec = importlib.util.spec_from_file_location(
    "plugins.resolve",
    Path(__file__).resolve().parent.parent / "plugins" / "resolve.py",
)
_mod = importlib.util.module_from_spec(_spec)
sys.modules[_spec.name] = _mod
_spec.loader.exec_module(_mod)

from plugins.resolve import _query_tcp, _resolve_one, cmd_resolve  # noqa: E402

# -- Helpers -----------------------------------------------------------------


def _make_a_response(ip_bytes: bytes = b"\x01\x02\x03\x04") -> bytes:
    """Build a minimal A-record DNS response.

    The message has one question and one answer for ``example.com``
    (type A, class IN, TTL 300) whose RDATA is *ip_bytes*.
    """
    tid = b"\x00\x01"
    flags = struct.pack("!H", 0x8180)  # QR, RD, RA set; rcode=0
    counts = struct.pack("!HHHH", 1, 1, 0, 0)  # qdcount, ancount, nscount, arcount
    qname = encode_name("example.com")
    question = qname + struct.pack("!HH", 1, 1)  # qtype=A, qclass=IN
    answer = qname + struct.pack("!HHIH", 1, 1, 300, len(ip_bytes)) + ip_bytes
    return tid + flags + counts + question + answer


def _make_nxdomain_response() -> bytes:
    """Build a minimal NXDOMAIN DNS response (question only, no answers)."""
    tid = b"\x00\x01"
    flags = struct.pack("!H", 0x8183)  # rcode=3
    counts = struct.pack("!HHHH", 1, 0, 0, 0)
    qname = encode_name("nope.invalid")
    question = qname + struct.pack("!HH", 1, 1)
    return tid + flags + counts + question


def _tcp_connection_mock(response: bytes) -> AsyncMock:
    """Return a stand-in for ``_open_connection`` that serves *response*.

    The fake reader returns the 2-byte big-endian length prefix on the first
    ``readexactly`` call and the message body on the second, matching
    DNS-over-TCP framing. The fake writer's ``drain``/``wait_closed`` are
    awaitable no-ops.
    """
    framed = struct.pack("!H", len(response)) + response
    reader = AsyncMock()
    reader.readexactly = AsyncMock(side_effect=[framed[:2], framed[2:]])
    writer = MagicMock()
    writer.drain = AsyncMock()
    writer.wait_closed = AsyncMock()
    return AsyncMock(return_value=(reader, writer))


class _FakeBot:
    """Minimal bot double that records every reply text in ``replied``."""

    def __init__(self):
        self.replied: list[str] = []

    async def reply(self, message, text: str) -> None:
        self.replied.append(text)


def _msg(text: str) -> Message:
    """Build a channel PRIVMSG from ``alice`` to ``#test`` carrying *text*."""
    return Message(
        raw="",
        prefix="alice!~alice@host",
        nick="alice",
        command="PRIVMSG",
        params=["#test", text],
        tags={},
    )


# -- _query_tcp --------------------------------------------------------------


class TestQueryTcp:
    """Wire-level parsing of framed TCP DNS responses."""

    def test_a_record(self):
        mock_open = _tcp_connection_mock(_make_a_response())
        with patch.object(_mod, "_open_connection", mock_open):
            rcode, results = asyncio.run(_query_tcp("example.com", 1, "1.1.1.1"))
        assert rcode == 0
        assert results == ["1.2.3.4"]

    def test_nxdomain(self):
        mock_open = _tcp_connection_mock(_make_nxdomain_response())
        with patch.object(_mod, "_open_connection", mock_open):
            rcode, results = asyncio.run(_query_tcp("nope.invalid", 1, "1.1.1.1"))
        assert rcode == 3
        assert results == []


# -- _resolve_one ------------------------------------------------------------


class TestResolveOne:
    """Formatting of a single lookup result, with ``_query_tcp`` mocked."""

    def test_success(self):
        mock_tcp = AsyncMock(return_value=(0, ["1.2.3.4"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("example.com", "A", "1.1.1.1"))
        assert result == "example.com -> 1.2.3.4"

    def test_nxdomain(self):
        mock_tcp = AsyncMock(return_value=(3, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("bad.invalid", "A", "1.1.1.1"))
        assert "NXDOMAIN" in result

    def test_timeout(self):
        mock_tcp = AsyncMock(side_effect=asyncio.TimeoutError())
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("slow.example.com", "A", "1.1.1.1"))
        assert "timeout" in result

    def test_error(self):
        mock_tcp = AsyncMock(side_effect=OSError("connection refused"))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("down.example.com", "A", "1.1.1.1"))
        assert "error" in result

    def test_ptr_auto(self):
        mock_tcp = AsyncMock(return_value=(0, ["dns.google"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("8.8.8.8", "PTR", "1.1.1.1"))
        assert "dns.google" in result

    def test_ptr_invalid_ip(self):
        # Patched so a validation regression fails here instead of sending a
        # real DNS query over the network.
        mock_tcp = AsyncMock(return_value=(0, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("not-an-ip", "PTR", "1.1.1.1"))
        assert "invalid IP" in result
        mock_tcp.assert_not_called()

    def test_no_records(self):
        mock_tcp = AsyncMock(return_value=(0, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("empty.example.com", "A", "1.1.1.1"))
        assert "no records" in result

    def test_multiple_results(self):
        mock_tcp = AsyncMock(return_value=(0, ["1.1.1.1", "1.0.0.1"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            result = asyncio.run(_resolve_one("multi.example.com", "A", "1.1.1.1"))
        assert "1.1.1.1, 1.0.0.1" in result


# -- Command handler ---------------------------------------------------------


class TestCmdResolve:
    """End-to-end behaviour of the ``!resolve`` command handler."""

    def test_no_args(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(0, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve")))
        assert "Usage" in bot.replied[0]
        mock_tcp.assert_not_called()

    def test_single_host(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(0, ["93.184.216.34"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve example.com")))
        assert len(bot.replied) == 1
        assert "example.com -> 93.184.216.34" in bot.replied[0]

    def test_multiple_hosts(self):
        bot = _FakeBot()

        async def fake_tcp(name, qtype, server, timeout=5.0):
            if "example" in name:
                return 0, ["93.184.216.34"]
            return 0, ["140.82.121.3"]

        with patch.object(_mod, "_query_tcp", fake_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve example.com github.com")))
        # Replies must come back in the order the hosts were given.
        assert len(bot.replied) == 2
        assert "93.184.216.34" in bot.replied[0]
        assert "140.82.121.3" in bot.replied[1]

    def test_explicit_type(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(0, ["2606:2800:220:1:248:1893:25c8:1946"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve example.com AAAA")))
        assert "2606:" in bot.replied[0]
        # Verify AAAA qtype (28) was passed as the second positional argument
        assert mock_tcp.call_args.args[1] == 28

    def test_ip_auto_ptr(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(0, ["dns.google"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve 8.8.8.8")))
        assert "dns.google" in bot.replied[0]

    def test_type_only_no_hosts(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(0, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve AAAA")))
        assert "Usage" in bot.replied[0]
        mock_tcp.assert_not_called()

    def test_nxdomain(self):
        bot = _FakeBot()
        mock_tcp = AsyncMock(return_value=(3, []))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg("!resolve bad.invalid")))
        assert "NXDOMAIN" in bot.replied[0]

    def test_max_hosts(self):
        """Hosts beyond MAX_HOSTS are truncated."""
        bot = _FakeBot()
        hosts = " ".join(f"h{i}.example.com" for i in range(15))
        mock_tcp = AsyncMock(return_value=(0, ["1.2.3.4"]))
        with patch.object(_mod, "_query_tcp", mock_tcp):
            asyncio.run(cmd_resolve(bot, _msg(f"!resolve {hosts}")))
        # 10 results + 1 truncation note
        assert len(bot.replied) == 11
        assert "showing first 10" in bot.replied[-1]