Files
derp/tests/test_dns.py
user 26063a0e8f feat: add TCP DNS plugin with SOCKS5 proxy support
Extract shared DNS wire-format helpers into src/derp/dns.py so both
the UDP plugin (dns.py) and the new TCP plugin (tdns.py) share the
same encode/decode/build/parse logic.

The !tdns command routes queries through the SOCKS5 proxy via
derp.http.open_connection, using TCP framing (2-byte length prefix).
Default server: 1.1.1.1.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-15 16:09:35 +01:00

144 lines
4.5 KiB
Python

"""Tests for the shared DNS wire-format helpers."""
import struct
from unittest.mock import mock_open, patch
from derp.dns import (
QTYPE_NAMES,
QTYPES,
build_query,
decode_name,
encode_name,
get_resolver,
parse_response,
reverse_name,
)
class TestEncodeDecode:
    """Tests for DNS name <-> wire-label conversion (encode_name / decode_name)."""

    def test_encode_simple(self):
        """A dotted name becomes length-prefixed labels ending in a zero byte."""
        assert encode_name("example.com") == b"\x07example\x03com\x00"

    def test_encode_trailing_dot(self):
        """A fully-qualified trailing dot does not change the encoding."""
        with_dot = encode_name("example.com.")
        without_dot = encode_name("example.com")
        assert with_dot == without_dot

    def test_decode_simple(self):
        """Decoding returns the dotted name and the offset just past it."""
        wire = b"\x07example\x03com\x00"
        decoded, end = decode_name(wire, 0)
        assert decoded == "example.com"
        assert end == len(wire)

    def test_roundtrip(self):
        """decode_name(encode_name(x)) gives x back for several depths."""
        samples = ["a.b.c", "example.com", "sub.domain.example.org"]
        for sample in samples:
            decoded, _ = decode_name(encode_name(sample), 0)
            assert decoded == sample

    def test_decode_pointer(self):
        """A 0xC0 compression pointer is followed back to the earlier name."""
        target = b"\x07example\x03com\x00"  # uncompressed name at offset 0
        wire = target + b"\xc0\x00"  # two-byte pointer to offset 0
        decoded, end = decode_name(wire, len(target))
        assert decoded == "example.com"
        # The returned offset advances past the 2-byte pointer only.
        assert end == len(target) + 2
class TestBuildQuery:
    """Checks on the wire layout produced by ``build_query``."""

    def test_packet_structure(self):
        """Header is followed by the encoded QNAME and then QTYPE/QCLASS."""
        pkt = build_query("example.com", QTYPES["A"])
        encoded = encode_name("example.com")
        # 12-byte header (2 TID + 2 flags + 4 x 2-byte counts),
        # then the encoded name, then 2-byte QTYPE + 2-byte QCLASS.
        assert len(pkt) == 12 + len(encoded) + 4
        # The length check alone cannot tell whether the name was written at
        # the right offset; pin it to the position right after the header.
        assert pkt[12:12 + len(encoded)] == encoded

    def test_flags_rd_set(self):
        """The query asks for recursion."""
        pkt = build_query("example.com", QTYPES["A"])
        flags = struct.unpack_from("!H", pkt, 2)[0]
        assert flags & 0x0100  # RD bit set

    def test_qdcount_one(self):
        """Exactly one question is declared in the header."""
        pkt = build_query("example.com", QTYPES["A"])
        qdcount = struct.unpack_from("!H", pkt, 4)[0]
        assert qdcount == 1

    def test_qtype_embedded(self):
        """The requested record type follows the encoded name."""
        pkt = build_query("example.com", QTYPES["AAAA"])
        encoded = encode_name("example.com")
        qtype = struct.unpack_from("!H", pkt, 12 + len(encoded))[0]
        assert qtype == 28
class TestParseResponse:
    """Tests for ``parse_response`` against hand-built response packets."""

    def _make_response(self, rcode=0, answers=None):
        """Build a minimal DNS response packet.

        The header uses TID 1 and flags 0x8180 (QR, RD, RA) OR-ed with *rcode*.
        There is one A/IN question for example.com. Each ``(rtype, rdata)``
        pair in *answers* becomes an uncompressed IN record with TTL 300.
        """
        records = answers or []
        name = encode_name("example.com")
        header = struct.pack(
            "!HHHHHH",
            1,                # transaction id
            0x8180 | rcode,   # flags
            1,                # QDCOUNT
            len(records),     # ANCOUNT
            0,                # NSCOUNT
            0,                # ARCOUNT
        )
        question = name + struct.pack("!HH", 1, 1)
        answer_section = b"".join(
            name + struct.pack("!HHIH", rtype, 1, 300, len(rdata)) + rdata
            for rtype, rdata in records
        )
        return header + question + answer_section

    def test_a_record(self):
        """A single A record is rendered as a dotted-quad string."""
        packet = self._make_response(answers=[(1, bytes([1, 2, 3, 4]))])
        code, values = parse_response(packet)
        assert code == 0
        assert values == ["1.2.3.4"]

    def test_nxdomain(self):
        """An NXDOMAIN response reports rcode 3 and no results."""
        packet = self._make_response(rcode=3)
        code, values = parse_response(packet)
        assert code == 3
        assert values == []

    def test_short_packet(self):
        """A packet shorter than a header yields rcode 2 and no results."""
        code, values = parse_response(bytes(5))
        assert code == 2
        assert values == []
class TestReverseName:
    """Tests for building PTR lookup names from IP address strings."""

    def test_ipv4(self):
        """IPv4 octets are reversed under in-addr.arpa."""
        assert reverse_name("1.2.3.4") == "4.3.2.1.in-addr.arpa"

    def test_ipv6(self):
        """IPv6 is fully expanded to 32 reversed nibbles under ip6.arpa."""
        suffix = ".ip6.arpa"
        result = reverse_name("::1")
        assert result.endswith(suffix)
        # Strip only the trailing suffix; str.replace would remove it anywhere.
        nibbles = result[:-len(suffix)].split(".")
        # Full expansion: 32 nibble labels separated by dots.
        assert len(nibbles) == 32
        # ::1 has its only set bit in the last nibble, which comes first once
        # reversed.
        assert nibbles == ["1"] + ["0"] * 31

    def test_invalid_raises(self):
        """A string that is not an IP address raises ValueError."""
        # Use an explicit raise rather than ``assert False``: asserts are
        # stripped under ``python -O``, which would let this test pass
        # silently.
        try:
            reverse_name("not-an-ip")
        except ValueError:
            return
        raise AssertionError("reverse_name should have raised ValueError")
class TestGetResolver:
    """Tests for picking a nameserver out of a resolv.conf-style file."""

    @staticmethod
    def _resolve_from(text):
        """Run get_resolver() with ``open`` patched to yield *text*."""
        with patch("builtins.open", mock_open(read_data=text)):
            return get_resolver()

    def test_reads_nameserver(self):
        """The first IPv4 nameserver wins; comments are ignored."""
        conf = "# comment\nnameserver 192.168.1.1\nnameserver 8.8.8.8\n"
        assert self._resolve_from(conf) == "192.168.1.1"

    def test_skips_ipv6(self):
        """IPv6 nameserver entries are passed over."""
        conf = "nameserver ::1\nnameserver 9.9.9.9\n"
        assert self._resolve_from(conf) == "9.9.9.9"

    def test_fallback(self):
        """An unreadable file yields the built-in fallback resolver."""
        with patch("builtins.open", side_effect=OSError):
            resolver = get_resolver()
        assert resolver == "8.8.8.8"
class TestConstants:
    """Consistency checks between the record-type lookup tables."""

    def test_qtype_names_reverse(self):
        """QTYPE_NAMES maps each QTYPES number back to its mnemonic."""
        for mnemonic, code in QTYPES.items():
            looked_up = QTYPE_NAMES[code]
            assert looked_up == mnemonic