"""Tests for the TCP DNS plugin.""" import asyncio import importlib.util import struct import sys from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest from derp.dns import encode_name # plugins/ is not a Python package -- load the module from file path _spec = importlib.util.spec_from_file_location( "plugins.tdns", Path(__file__).resolve().parent.parent / "plugins" / "tdns.py", ) _mod = importlib.util.module_from_spec(_spec) sys.modules[_spec.name] = _mod _spec.loader.exec_module(_mod) from plugins.tdns import _query_tcp, cmd_tdns # noqa: E402 def _make_a_response(ip_bytes: bytes = b"\x01\x02\x03\x04") -> bytes: """Build a minimal A-record DNS response.""" tid = b"\x00\x01" flags = struct.pack("!H", 0x8180) counts = struct.pack("!HHHH", 1, 1, 0, 0) qname = encode_name("example.com") question = qname + struct.pack("!HH", 1, 1) answer = qname + struct.pack("!HHIH", 1, 1, 300, len(ip_bytes)) + ip_bytes return tid + flags + counts + question + answer class TestQueryTcp: def test_sends_length_prefixed_query(self): response = _make_a_response() framed = struct.pack("!H", len(response)) + response reader = AsyncMock() reader.readexactly = AsyncMock(side_effect=[ framed[:2], # length prefix framed[2:], # payload ]) writer = MagicMock() writer.drain = AsyncMock() writer.wait_closed = AsyncMock() mock_open = AsyncMock(return_value=(reader, writer)) with patch.object(_mod, "_open_connection", mock_open): rcode, results = asyncio.run(_query_tcp("example.com", 1, "1.1.1.1")) assert rcode == 0 assert results == ["1.2.3.4"] # Verify the written data has a 2-byte length prefix written = writer.write.call_args[0][0] pkt_len = struct.unpack("!H", written[:2])[0] assert pkt_len == len(written) - 2 def test_closes_writer_on_error(self): reader = AsyncMock() reader.readexactly = AsyncMock(side_effect=asyncio.IncompleteReadError(b"", 2)) writer = MagicMock() writer.drain = AsyncMock() writer.wait_closed = AsyncMock() mock_open = AsyncMock(return_value=(reader, writer)) with patch.object(_mod, "_open_connection", mock_open): with pytest.raises(asyncio.IncompleteReadError): asyncio.run(_query_tcp("example.com", 1, "1.1.1.1")) writer.close.assert_called_once() class TestCmdTdns: def test_no_args(self): bot = AsyncMock() msg = MagicMock() msg.text = "!tdns" asyncio.run(cmd_tdns(bot, msg)) bot.reply.assert_called_once() assert "Usage" in bot.reply.call_args[0][1] def test_ip_auto_ptr(self): bot = AsyncMock() msg = MagicMock() msg.text = "!tdns 1.2.3.4" mock_tcp = AsyncMock(return_value=(0, ["host.example.com"])) with patch.object(_mod, "_query_tcp", mock_tcp): asyncio.run(cmd_tdns(bot, msg)) reply = bot.reply.call_args[0][1] assert "PTR" in reply assert "host.example.com" in reply def test_explicit_type(self): bot = AsyncMock() msg = MagicMock() msg.text = "!tdns example.com MX" mock_tcp = AsyncMock(return_value=(0, ["10 mail.example.com"])) with patch.object(_mod, "_query_tcp", mock_tcp): asyncio.run(cmd_tdns(bot, msg)) reply = bot.reply.call_args[0][1] assert "MX" in reply def test_custom_server(self): bot = AsyncMock() msg = MagicMock() msg.text = "!tdns example.com A @8.8.8.8" mock_tcp = AsyncMock(return_value=(0, ["93.184.216.34"])) with patch.object(_mod, "_query_tcp", mock_tcp): asyncio.run(cmd_tdns(bot, msg)) assert mock_tcp.call_args[0][2] == "8.8.8.8" def test_unknown_type(self): bot = AsyncMock() msg = MagicMock() msg.text = "!tdns example.com BOGUS" asyncio.run(cmd_tdns(bot, msg)) reply = bot.reply.call_args[0][1] assert "Unknown type" in reply