feat: add TCP DNS plugin with SOCKS5 proxy support
Extract shared DNS wire-format helpers into src/derp/dns.py so both the UDP plugin (dns.py) and the new TCP plugin (tdns.py) share the same encode/decode/build/parse logic. The !tdns command routes queries through the SOCKS5 proxy via derp.http.open_connection, using TCP framing (2-byte length prefix). Default server: 1.1.1.1. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
143
tests/test_dns.py
Normal file
143
tests/test_dns.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for the shared DNS wire-format helpers."""
|
||||
|
||||
import struct
|
||||
from unittest.mock import mock_open, patch
|
||||
|
||||
from derp.dns import (
|
||||
QTYPE_NAMES,
|
||||
QTYPES,
|
||||
build_query,
|
||||
decode_name,
|
||||
encode_name,
|
||||
get_resolver,
|
||||
parse_response,
|
||||
reverse_name,
|
||||
)
|
||||
|
||||
|
||||
class TestEncodeDecode:
|
||||
def test_encode_simple(self):
|
||||
result = encode_name("example.com")
|
||||
assert result == b"\x07example\x03com\x00"
|
||||
|
||||
def test_encode_trailing_dot(self):
|
||||
assert encode_name("example.com.") == encode_name("example.com")
|
||||
|
||||
def test_decode_simple(self):
|
||||
data = b"\x07example\x03com\x00"
|
||||
name, offset = decode_name(data, 0)
|
||||
assert name == "example.com"
|
||||
assert offset == len(data)
|
||||
|
||||
def test_roundtrip(self):
|
||||
for domain in ("a.b.c", "example.com", "sub.domain.example.org"):
|
||||
encoded = encode_name(domain)
|
||||
decoded, _ = decode_name(encoded, 0)
|
||||
assert decoded == domain
|
||||
|
||||
def test_decode_pointer(self):
|
||||
# Name at offset 0, then a pointer back to it at offset 13
|
||||
data = b"\x07example\x03com\x00" + b"\xc0\x00"
|
||||
name, offset = decode_name(data, 13)
|
||||
assert name == "example.com"
|
||||
assert offset == 15
|
||||
|
||||
|
||||
class TestBuildQuery:
|
||||
def test_packet_structure(self):
|
||||
pkt = build_query("example.com", QTYPES["A"])
|
||||
# 2 TID + 2 flags + 8 counts + encoded name + 4 qtype/qclass
|
||||
encoded = encode_name("example.com")
|
||||
assert len(pkt) == 12 + len(encoded) + 4
|
||||
|
||||
def test_flags_rd_set(self):
|
||||
pkt = build_query("example.com", QTYPES["A"])
|
||||
flags = struct.unpack_from("!H", pkt, 2)[0]
|
||||
assert flags & 0x0100 # RD bit set
|
||||
|
||||
def test_qdcount_one(self):
|
||||
pkt = build_query("example.com", QTYPES["A"])
|
||||
qdcount = struct.unpack_from("!H", pkt, 4)[0]
|
||||
assert qdcount == 1
|
||||
|
||||
def test_qtype_embedded(self):
|
||||
pkt = build_query("example.com", QTYPES["AAAA"])
|
||||
encoded = encode_name("example.com")
|
||||
qtype = struct.unpack_from("!H", pkt, 12 + len(encoded))[0]
|
||||
assert qtype == 28
|
||||
|
||||
|
||||
class TestParseResponse:
|
||||
def _make_response(self, rcode=0, answers=None):
|
||||
"""Build a minimal DNS response packet."""
|
||||
answers = answers or []
|
||||
tid = b"\x00\x01"
|
||||
flags = struct.pack("!H", 0x8180 | rcode)
|
||||
counts = struct.pack("!HHHH", 1, len(answers), 0, 0)
|
||||
# Question section
|
||||
qname = encode_name("example.com")
|
||||
question = qname + struct.pack("!HH", 1, 1)
|
||||
# Answer section
|
||||
ans_bytes = b""
|
||||
for rtype, rdata in answers:
|
||||
ans_bytes += qname + struct.pack("!HHIH", rtype, 1, 300, len(rdata)) + rdata
|
||||
return tid + flags + counts + question + ans_bytes
|
||||
|
||||
def test_a_record(self):
|
||||
rdata = bytes([1, 2, 3, 4])
|
||||
pkt = self._make_response(answers=[(1, rdata)])
|
||||
rcode, results = parse_response(pkt)
|
||||
assert rcode == 0
|
||||
assert results == ["1.2.3.4"]
|
||||
|
||||
def test_nxdomain(self):
|
||||
pkt = self._make_response(rcode=3)
|
||||
rcode, results = parse_response(pkt)
|
||||
assert rcode == 3
|
||||
assert results == []
|
||||
|
||||
def test_short_packet(self):
|
||||
rcode, results = parse_response(b"\x00" * 5)
|
||||
assert rcode == 2
|
||||
assert results == []
|
||||
|
||||
|
||||
class TestReverseName:
|
||||
def test_ipv4(self):
|
||||
assert reverse_name("1.2.3.4") == "4.3.2.1.in-addr.arpa"
|
||||
|
||||
def test_ipv6(self):
|
||||
result = reverse_name("::1")
|
||||
assert result.endswith(".ip6.arpa")
|
||||
# Full expansion: 32 nibble chars separated by dots
|
||||
parts = result.replace(".ip6.arpa", "").split(".")
|
||||
assert len(parts) == 32
|
||||
|
||||
def test_invalid_raises(self):
|
||||
try:
|
||||
reverse_name("not-an-ip")
|
||||
assert False, "should have raised"
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
|
||||
class TestGetResolver:
|
||||
def test_reads_nameserver(self):
|
||||
content = "# comment\nnameserver 192.168.1.1\nnameserver 8.8.8.8\n"
|
||||
with patch("builtins.open", mock_open(read_data=content)):
|
||||
assert get_resolver() == "192.168.1.1"
|
||||
|
||||
def test_skips_ipv6(self):
|
||||
content = "nameserver ::1\nnameserver 9.9.9.9\n"
|
||||
with patch("builtins.open", mock_open(read_data=content)):
|
||||
assert get_resolver() == "9.9.9.9"
|
||||
|
||||
def test_fallback(self):
|
||||
with patch("builtins.open", side_effect=OSError):
|
||||
assert get_resolver() == "8.8.8.8"
|
||||
|
||||
|
||||
class TestConstants:
|
||||
def test_qtype_names_reverse(self):
|
||||
for name, num in QTYPES.items():
|
||||
assert QTYPE_NAMES[num] == name
|
||||
128
tests/test_tdns.py
Normal file
128
tests/test_tdns.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Tests for the TCP DNS plugin."""
|
||||
|
||||
import asyncio
|
||||
import importlib.util
|
||||
import struct
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from derp.dns import encode_name
|
||||
|
||||
# plugins/ is not a Python package -- load the module from file path
|
||||
_spec = importlib.util.spec_from_file_location(
|
||||
"plugins.tdns", Path(__file__).resolve().parent.parent / "plugins" / "tdns.py",
|
||||
)
|
||||
_mod = importlib.util.module_from_spec(_spec)
|
||||
sys.modules[_spec.name] = _mod
|
||||
_spec.loader.exec_module(_mod)
|
||||
|
||||
from plugins.tdns import _query_tcp, cmd_tdns # noqa: E402
|
||||
|
||||
|
||||
def _make_a_response(ip_bytes: bytes = b"\x01\x02\x03\x04") -> bytes:
|
||||
"""Build a minimal A-record DNS response."""
|
||||
tid = b"\x00\x01"
|
||||
flags = struct.pack("!H", 0x8180)
|
||||
counts = struct.pack("!HHHH", 1, 1, 0, 0)
|
||||
qname = encode_name("example.com")
|
||||
question = qname + struct.pack("!HH", 1, 1)
|
||||
answer = qname + struct.pack("!HHIH", 1, 1, 300, len(ip_bytes)) + ip_bytes
|
||||
return tid + flags + counts + question + answer
|
||||
|
||||
|
||||
class TestQueryTcp:
|
||||
def test_sends_length_prefixed_query(self):
|
||||
response = _make_a_response()
|
||||
framed = struct.pack("!H", len(response)) + response
|
||||
|
||||
reader = AsyncMock()
|
||||
reader.readexactly = AsyncMock(side_effect=[
|
||||
framed[:2], # length prefix
|
||||
framed[2:], # payload
|
||||
])
|
||||
writer = MagicMock()
|
||||
writer.drain = AsyncMock()
|
||||
writer.wait_closed = AsyncMock()
|
||||
|
||||
mock_open = AsyncMock(return_value=(reader, writer))
|
||||
with patch.object(_mod, "_open_connection", mock_open):
|
||||
rcode, results = asyncio.run(_query_tcp("example.com", 1, "1.1.1.1"))
|
||||
|
||||
assert rcode == 0
|
||||
assert results == ["1.2.3.4"]
|
||||
|
||||
# Verify the written data has a 2-byte length prefix
|
||||
written = writer.write.call_args[0][0]
|
||||
pkt_len = struct.unpack("!H", written[:2])[0]
|
||||
assert pkt_len == len(written) - 2
|
||||
|
||||
def test_closes_writer_on_error(self):
|
||||
reader = AsyncMock()
|
||||
reader.readexactly = AsyncMock(side_effect=asyncio.IncompleteReadError(b"", 2))
|
||||
writer = MagicMock()
|
||||
writer.drain = AsyncMock()
|
||||
writer.wait_closed = AsyncMock()
|
||||
|
||||
mock_open = AsyncMock(return_value=(reader, writer))
|
||||
with patch.object(_mod, "_open_connection", mock_open):
|
||||
with pytest.raises(asyncio.IncompleteReadError):
|
||||
asyncio.run(_query_tcp("example.com", 1, "1.1.1.1"))
|
||||
|
||||
writer.close.assert_called_once()
|
||||
|
||||
|
||||
class TestCmdTdns:
|
||||
def test_no_args(self):
|
||||
bot = AsyncMock()
|
||||
msg = MagicMock()
|
||||
msg.text = "!tdns"
|
||||
asyncio.run(cmd_tdns(bot, msg))
|
||||
bot.reply.assert_called_once()
|
||||
assert "Usage" in bot.reply.call_args[0][1]
|
||||
|
||||
def test_ip_auto_ptr(self):
|
||||
bot = AsyncMock()
|
||||
msg = MagicMock()
|
||||
msg.text = "!tdns 1.2.3.4"
|
||||
|
||||
mock_tcp = AsyncMock(return_value=(0, ["host.example.com"]))
|
||||
with patch.object(_mod, "_query_tcp", mock_tcp):
|
||||
asyncio.run(cmd_tdns(bot, msg))
|
||||
|
||||
reply = bot.reply.call_args[0][1]
|
||||
assert "PTR" in reply
|
||||
assert "host.example.com" in reply
|
||||
|
||||
def test_explicit_type(self):
|
||||
bot = AsyncMock()
|
||||
msg = MagicMock()
|
||||
msg.text = "!tdns example.com MX"
|
||||
|
||||
mock_tcp = AsyncMock(return_value=(0, ["10 mail.example.com"]))
|
||||
with patch.object(_mod, "_query_tcp", mock_tcp):
|
||||
asyncio.run(cmd_tdns(bot, msg))
|
||||
|
||||
reply = bot.reply.call_args[0][1]
|
||||
assert "MX" in reply
|
||||
|
||||
def test_custom_server(self):
|
||||
bot = AsyncMock()
|
||||
msg = MagicMock()
|
||||
msg.text = "!tdns example.com A @8.8.8.8"
|
||||
|
||||
mock_tcp = AsyncMock(return_value=(0, ["93.184.216.34"]))
|
||||
with patch.object(_mod, "_query_tcp", mock_tcp):
|
||||
asyncio.run(cmd_tdns(bot, msg))
|
||||
|
||||
assert mock_tcp.call_args[0][2] == "8.8.8.8"
|
||||
|
||||
def test_unknown_type(self):
|
||||
bot = AsyncMock()
|
||||
msg = MagicMock()
|
||||
msg.text = "!tdns example.com BOGUS"
|
||||
asyncio.run(cmd_tdns(bot, msg))
|
||||
reply = bot.reply.call_args[0][1]
|
||||
assert "Unknown type" in reply
|
||||
Reference in New Issue
Block a user