"""Tests for the shared DNS wire-format helpers."""

import struct
from unittest.mock import mock_open, patch

from derp.dns import (
    QTYPE_NAMES,
    QTYPES,
    build_query,
    decode_name,
    encode_name,
    get_resolver,
    parse_response,
    reverse_name,
)


class TestEncodeDecode:
    """Checks for domain-name label encoding and decoding."""

    def test_encode_simple(self):
        # Each label is length-prefixed; the name ends with a zero-length label.
        result = encode_name("example.com")
        assert result == b"\x07example\x03com\x00"

    def test_encode_trailing_dot(self):
        # A fully-qualified name (trailing dot) encodes the same as the bare name.
        assert encode_name("example.com.") == encode_name("example.com")

    def test_decode_simple(self):
        data = b"\x07example\x03com\x00"
        name, offset = decode_name(data, 0)
        assert name == "example.com"
        # The returned offset points just past the terminating zero byte.
        assert offset == len(data)

    def test_roundtrip(self):
        for domain in ("a.b.c", "example.com", "sub.domain.example.org"):
            encoded = encode_name(domain)
            decoded, _ = decode_name(encoded, 0)
            assert decoded == domain

    def test_decode_pointer(self):
        # Name at offset 0 (13 bytes long), then a compression pointer
        # (0xC0 0x00 -> "jump to offset 0") at offset 13.
        data = b"\x07example\x03com\x00" + b"\xc0\x00"
        name, offset = decode_name(data, 13)
        assert name == "example.com"
        # After a pointer, the offset only advances past the 2-byte pointer.
        assert offset == 15


class TestBuildQuery:
    """Checks for the layout of an outgoing query packet."""

    def test_packet_structure(self):
        pkt = build_query("example.com", QTYPES["A"])
        # 2 TID + 2 flags + 8 counts + encoded name + 4 qtype/qclass
        encoded = encode_name("example.com")
        assert len(pkt) == 12 + len(encoded) + 4

    def test_flags_rd_set(self):
        pkt = build_query("example.com", QTYPES["A"])
        flags = struct.unpack_from("!H", pkt, 2)[0]
        assert flags & 0x0100  # RD bit set

    def test_qdcount_one(self):
        pkt = build_query("example.com", QTYPES["A"])
        qdcount = struct.unpack_from("!H", pkt, 4)[0]
        assert qdcount == 1

    def test_qtype_embedded(self):
        pkt = build_query("example.com", QTYPES["AAAA"])
        encoded = encode_name("example.com")
        # QTYPE immediately follows the 12-byte header and the encoded QNAME.
        qtype = struct.unpack_from("!H", pkt, 12 + len(encoded))[0]
        assert qtype == 28


class TestParseResponse:
    """Checks for decoding responses built by ``_make_response``."""

    def _make_response(self, rcode=0, answers=None):
        """Build a minimal DNS response packet.

        Args:
            rcode: Response code OR-ed into the low bits of the flags word.
            answers: Optional list of ``(rtype, rdata)`` pairs. Each pair
                becomes one answer record for ``example.com`` with class 1
                and TTL 300.

        Returns:
            The raw packet bytes: header, one question, then the answers.
        """
        answers = answers or []
        tid = b"\x00\x01"
        # 0x8180 = QR (response) | RD | RA, per the RFC 1035 header layout.
        flags = struct.pack("!H", 0x8180 | rcode)
        # QDCOUNT=1, ANCOUNT=len(answers), NSCOUNT=0, ARCOUNT=0
        counts = struct.pack("!HHHH", 1, len(answers), 0, 0)
        # Question section: QNAME, QTYPE=1, QCLASS=1
        qname = encode_name("example.com")
        question = qname + struct.pack("!HH", 1, 1)
        # Answer section: uncompressed owner name, TYPE, CLASS, TTL, RDLENGTH, RDATA
        ans_bytes = b"".join(
            qname + struct.pack("!HHIH", rtype, 1, 300, len(rdata)) + rdata
            for rtype, rdata in answers
        )
        return tid + flags + counts + question + ans_bytes

    def test_a_record(self):
        rdata = bytes([1, 2, 3, 4])
        pkt = self._make_response(answers=[(1, rdata)])
        rcode, results = parse_response(pkt)
        assert rcode == 0
        assert results == ["1.2.3.4"]

    def test_nxdomain(self):
        pkt = self._make_response(rcode=3)
        rcode, results = parse_response(pkt)
        assert rcode == 3
        assert results == []

    def test_short_packet(self):
        # Input shorter than the 12-byte header is reported as rcode 2
        # (SERVFAIL) with no results.
        rcode, results = parse_response(b"\x00" * 5)
        assert rcode == 2
        assert results == []


class TestReverseName:
    """Checks for PTR-lookup name construction."""

    def test_ipv4(self):
        assert reverse_name("1.2.3.4") == "4.3.2.1.in-addr.arpa"

    def test_ipv6(self):
        result = reverse_name("::1")
        assert result.endswith(".ip6.arpa")
        # Full expansion: 32 nibble chars separated by dots
        parts = result.replace(".ip6.arpa", "").split(".")
        assert len(parts) == 32

    def test_invalid_raises(self):
        # Raise AssertionError explicitly (not via `assert False`) so the
        # check is not stripped when Python runs with -O.
        try:
            reverse_name("not-an-ip")
        except ValueError:
            pass
        else:
            raise AssertionError("should have raised")


class TestGetResolver:
    """Checks for resolver selection from a mocked resolv.conf."""

    def test_reads_nameserver(self):
        # Comments are skipped; the first nameserver line wins.
        content = "# comment\nnameserver 192.168.1.1\nnameserver 8.8.8.8\n"
        with patch("builtins.open", mock_open(read_data=content)):
            assert get_resolver() == "192.168.1.1"

    def test_skips_ipv6(self):
        content = "nameserver ::1\nnameserver 9.9.9.9\n"
        with patch("builtins.open", mock_open(read_data=content)):
            assert get_resolver() == "9.9.9.9"

    def test_fallback(self):
        # An unreadable config falls back to a hard-coded public resolver.
        with patch("builtins.open", side_effect=OSError):
            assert get_resolver() == "8.8.8.8"


class TestConstants:
    """Consistency checks for the QTYPE lookup tables."""

    def test_qtype_names_reverse(self):
        for name, num in QTYPES.items():
            assert QTYPE_NAMES[num] == name