diff --git a/README.md b/README.md index 57a1494..f00c2f0 100644 --- a/README.md +++ b/README.md @@ -15,7 +15,8 @@ make run ### Container ```bash -make up # Build + start with podman-compose +make build # Build image (only after dependency changes) +make up # Start with podman-compose make logs # Follow logs make down # Stop ``` @@ -37,6 +38,7 @@ make down # Stop |--------|----------|-------------| | core | ping, help, version, uptime, whoami, admins, load, reload, unload, plugins, state | Bot management | | dns | dns | Raw UDP DNS resolver (A/AAAA/MX/NS/TXT/CNAME/PTR/SOA) | +| tdns | tdns | TCP DNS resolver via SOCKS5 proxy (same record types) | | encode | encode, decode | Base64, hex, URL, ROT13 | | hash | hash, hashid | Hash generation + type identification | | defang | defang, refang | IOC defanging for safe sharing | @@ -63,6 +65,11 @@ make down # Stop | payload | payload | SQLi/XSS/SSTI/LFI/CMDi/XXE templates | | dork | dork | Google dork query builder | | wayback | wayback | Wayback Machine snapshot lookup | +| username | username | Username enumeration across ~25 services | +| remind | remind | One-shot, repeating, and calendar reminders | +| rss | rss | RSS/Atom feed subscriptions with polling | +| youtube | yt | YouTube channel follow with new-video alerts | +| twitch | twitch | Twitch livestream notifications (public GQL) | | chanmgmt | kick, ban, unban, topic, mode | Channel management (admin) | | example | echo | Demo plugin | @@ -90,7 +97,7 @@ async def on_join(bot, message): | `make lint` | Lint with ruff | | `make run` | Start the bot (bare metal) | | `make link` | Symlink to `~/.local/bin/` | -| `make build` | Build container image | +| `make build` | Build container image (only for dependency changes) | | `make up` | Start with podman-compose | | `make down` | Stop with podman-compose | | `make logs` | Follow compose logs | diff --git a/docs/CHEATSHEET.md b/docs/CHEATSHEET.md index ee8d9a7..cff100d 100644 --- a/docs/CHEATSHEET.md +++ b/docs/CHEATSHEET.md @@ -152,9 +152,11 @@ files, login. !username list # List services by category !username john # Full scan (~25 services) !username john github # Check single service -!dns example.com # A record lookup +!dns example.com # A record lookup (UDP, local resolver) !dns 1.2.3.4 # Reverse PTR lookup !dns example.com MX # Specific type (A/AAAA/MX/NS/TXT/CNAME/PTR/SOA) +!tdns example.com # A record lookup (TCP via SOCKS5 proxy) +!tdns example.com MX @8.8.8.8 # Explicit type + custom server !cert example.com # CT log lookup (max 5 domains) !whois example.com # WHOIS domain lookup !whois 8.8.8.8 # WHOIS IP lookup diff --git a/docs/USAGE.md b/docs/USAGE.md index be9204b..deecbbd 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -82,6 +82,7 @@ format = "text" # Log format: "text" (default) or "json" | `!topic [text]` | Set or query channel topic (admin) | | `!mode [args]` | Set channel mode (admin) | | `!dns [type]` | DNS lookup (A, AAAA, MX, NS, TXT, CNAME, PTR, SOA) | +| `!tdns [type] [@server]` | TCP DNS lookup via SOCKS5 proxy | | `!encode ` | Encode text (b64, hex, url, rot13) | | `!decode ` | Decode text (b64, hex, url, rot13) | | `!hash [algo] ` | Generate hash digests (md5, sha1, sha256, sha512) | diff --git a/plugins/dns.py b/plugins/dns.py index 5b029ef..36436c5 100644 --- a/plugins/dns.py +++ b/plugins/dns.py @@ -4,148 +4,23 @@ from __future__ import annotations import asyncio import ipaddress -import os import socket -import struct +from derp.dns import ( + QTYPES, + RCODES, + build_query, + get_resolver, + parse_response, + reverse_name, +) from derp.plugin import command -_QTYPES = { - "A": 1, "NS": 2, "CNAME": 5, "SOA": 6, - "PTR": 12, "MX": 15, "TXT": 16, "AAAA": 28, -} -_QTYPE_NAMES = {v: k for k, v in _QTYPES.items()} -_RCODES = { - 0: "", 1: "FORMERR", 2: "SERVFAIL", 3: "NXDOMAIN", - 4: "NOTIMP", 5: "REFUSED", -} - - -# -- wire format helpers -- - -def _get_resolver() -> str: - """Read first IPv4 nameserver from /etc/resolv.conf.""" - try: - with open("/etc/resolv.conf") as f: - for line in f: - line = line.strip() - if line.startswith("nameserver"): - addr = line.split()[1] - try: - ipaddress.IPv4Address(addr) - return addr - except ValueError: - continue - except (OSError, IndexError): - pass - return "8.8.8.8" - - -def _encode_name(name: str) -> bytes: - """Encode a domain name into DNS wire format.""" - out = b"" - for label in name.rstrip(".").split("."): - out += bytes([len(label)]) + label.encode("ascii") - return out + b"\x00" - - -def _decode_name(data: bytes, offset: int) -> tuple[str, int]: - """Decode a DNS name with pointer compression.""" - labels: list[str] = [] - jumped = False - ret_offset = offset - jumps = 0 - while offset < len(data): - length = data[offset] - if length == 0: - if not jumped: - ret_offset = offset + 1 - break - if (length & 0xC0) == 0xC0: - if not jumped: - ret_offset = offset + 2 - ptr = struct.unpack_from("!H", data, offset)[0] & 0x3FFF - offset = ptr - jumped = True - jumps += 1 - if jumps > 20: - break - continue - offset += 1 - labels.append(data[offset:offset + length].decode("ascii", errors="replace")) - offset += length - if not jumped: - ret_offset = offset - return ".".join(labels), ret_offset - - -def _build_query(name: str, qtype: int) -> bytes: - """Build a DNS query packet.""" - tid = os.urandom(2) - flags = struct.pack("!H", 0x0100) - counts = struct.pack("!HHHH", 1, 0, 0, 0) - return tid + flags + counts + _encode_name(name) + struct.pack("!HH", qtype, 1) - - -def _parse_rdata(rtype: int, data: bytes, offset: int, rdlength: int) -> str: - """Parse an RR's rdata into a human-readable string.""" - rdata = data[offset:offset + rdlength] - if rtype == 1 and rdlength == 4: - return socket.inet_ntoa(rdata) - if rtype == 28 and rdlength == 16: - return socket.inet_ntop(socket.AF_INET6, rdata) - if rtype in (2, 5, 12): # NS, CNAME, PTR - name, _ = _decode_name(data, offset) - return name - if rtype == 15: # MX - pref = struct.unpack_from("!H", rdata, 0)[0] - mx, _ = _decode_name(data, offset + 2) - return f"{pref} {mx}" - if rtype == 16: # TXT - parts: list[str] = [] - pos = 0 - while pos < rdlength: - tlen = rdata[pos] - pos += 1 - parts.append(rdata[pos:pos + tlen].decode("utf-8", errors="replace")) - pos += tlen - return "".join(parts) - if rtype == 6: # SOA - mname, off = _decode_name(data, offset) - rname, off = _decode_name(data, off) - serial = struct.unpack_from("!I", data, off)[0] - return f"{mname} {rname} {serial}" - return rdata.hex() - - -def _parse_response(data: bytes) -> tuple[int, list[str]]: - """Parse a DNS response, returning (rcode, [values]).""" - if len(data) < 12: - return 2, [] - _, flags, qdcount, ancount = struct.unpack_from("!HHHH", data, 0) - rcode = flags & 0x0F - offset = 12 - for _ in range(qdcount): - _, offset = _decode_name(data, offset) - offset += 4 - results: list[str] = [] - for _ in range(ancount): - if offset + 10 > len(data): - break - _, offset = _decode_name(data, offset) - rtype, _, _, rdlength = struct.unpack_from("!HHIH", data, offset) - offset += 10 - if offset + rdlength > len(data): - break - results.append(_parse_rdata(rtype, data, offset, rdlength)) - offset += rdlength - return rcode, results - async def _query(name: str, qtype: int, server: str, timeout: float = 5.0) -> tuple[int, list[str]]: - """Send a DNS query and return (rcode, [values]).""" - query = _build_query(name, qtype) + """Send a DNS query over UDP and return (rcode, [values]).""" + query = build_query(name, qtype) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(timeout) loop = asyncio.get_running_loop() @@ -155,7 +30,7 @@ async def _query(name: str, qtype: int, server: str, loop.run_in_executor(None, sock.recv, 4096), timeout=timeout, ) - return _parse_response(data) + return parse_response(data) except (TimeoutError, socket.timeout): return -1, [] except OSError: @@ -164,15 +39,6 @@ async def _query(name: str, qtype: int, server: str, sock.close() -def _reverse_name(addr: str) -> str: - """Convert an IP address to its reverse DNS name.""" - ip = ipaddress.ip_address(addr) - if isinstance(ip, ipaddress.IPv4Address): - return ".".join(reversed(addr.split("."))) + ".in-addr.arpa" - expanded = ip.exploded.replace(":", "") - return ".".join(reversed(expanded)) + ".ip6.arpa" - - @command("dns", help="DNS lookup: !dns [A|AAAA|MX|NS|TXT|CNAME|PTR|SOA]") async def cmd_dns(bot, message): """Query DNS records for a domain or reverse-lookup an IP.""" @@ -192,21 +58,21 @@ async def cmd_dns(bot, message): except ValueError: qtype_str = "A" - qtype = _QTYPES.get(qtype_str) + qtype = QTYPES.get(qtype_str) if qtype is None: - valid = ", ".join(sorted(_QTYPES)) + valid = ", ".join(sorted(QTYPES)) await bot.reply(message, f"Unknown type: {qtype_str} (valid: {valid})") return lookup = target if qtype_str == "PTR": try: - lookup = _reverse_name(target) + lookup = reverse_name(target) except ValueError: await bot.reply(message, f"Invalid IP for PTR: {target}") return - server = _get_resolver() + server = get_resolver() rcode, results = await _query(lookup, qtype, server) if rcode == -1: @@ -214,7 +80,7 @@ async def cmd_dns(bot, message): elif rcode == -2: await bot.reply(message, f"{target} {qtype_str}: network error") elif rcode != 0: - err = _RCODES.get(rcode, f"error {rcode}") + err = RCODES.get(rcode, f"error {rcode}") await bot.reply(message, f"{target} {qtype_str}: {err}") elif not results: await bot.reply(message, f"{target} {qtype_str}: no records") diff --git a/plugins/tdns.py b/plugins/tdns.py new file mode 100644 index 0000000..5c49242 --- /dev/null +++ b/plugins/tdns.py @@ -0,0 +1,96 @@ +"""Plugin: DNS record lookup over TCP (SOCKS5-proxied).""" + +from __future__ import annotations + +import asyncio +import ipaddress +import struct + +from derp.dns import ( + QTYPES, + RCODES, + build_query, + parse_response, + reverse_name, +) +from derp.http import open_connection as _open_connection +from derp.plugin import command + +_DEFAULT_SERVER = "1.1.1.1" +_TIMEOUT = 5.0 + + +async def _query_tcp(name: str, qtype: int, server: str, + timeout: float = _TIMEOUT) -> tuple[int, list[str]]: + """Send a DNS query over TCP and return (rcode, [values]).""" + reader, writer = await asyncio.wait_for( + _open_connection(server, 53, timeout=timeout), timeout=timeout, + ) + try: + pkt = build_query(name, qtype) + writer.write(struct.pack("!H", len(pkt)) + pkt) + await writer.drain() + length = struct.unpack("!H", await reader.readexactly(2))[0] + data = await reader.readexactly(length) + return parse_response(data) + finally: + writer.close() + await writer.wait_closed() + + +@command("tdns", help="TCP DNS lookup: !tdns [type] [@server]") +async def cmd_tdns(bot, message): + """Query DNS records over TCP (routed through SOCKS5 proxy).""" + parts = message.text.split() + if len(parts) < 2: + await bot.reply(message, "Usage: !tdns [type] [@server]") + return + + target = parts[1] + qtype_str = None + server = _DEFAULT_SERVER + + for arg in parts[2:]: + if arg.startswith("@"): + server = arg[1:] + elif qtype_str is None: + qtype_str = arg.upper() + + # Auto-detect: IP -> PTR, domain -> A + if qtype_str is None: + try: + ipaddress.ip_address(target) + qtype_str = "PTR" + except ValueError: + qtype_str = "A" + + qtype = QTYPES.get(qtype_str) + if qtype is None: + valid = ", ".join(sorted(QTYPES)) + await bot.reply(message, f"Unknown type: {qtype_str} (valid: {valid})") + return + + lookup = target + if qtype_str == "PTR": + try: + lookup = reverse_name(target) + except ValueError: + await bot.reply(message, f"Invalid IP for PTR: {target}") + return + + try: + rcode, results = await _query_tcp(lookup, qtype, server) + except (TimeoutError, asyncio.TimeoutError): + await bot.reply(message, f"{target} {qtype_str}: timeout") + return + except OSError as exc: + await bot.reply(message, f"{target} {qtype_str}: connection error: {exc}") + return + + if rcode != 0: + err = RCODES.get(rcode, f"error {rcode}") + await bot.reply(message, f"{target} {qtype_str}: {err}") + elif not results: + await bot.reply(message, f"{target} {qtype_str}: no records") + else: + await bot.reply(message, f"{target} {qtype_str}: {', '.join(results)}") diff --git a/src/derp/dns.py b/src/derp/dns.py new file mode 100644 index 0000000..e3a9085 --- /dev/null +++ b/src/derp/dns.py @@ -0,0 +1,146 @@ +"""Shared DNS wire-format helpers (encode, decode, build, parse).""" + +from __future__ import annotations + +import ipaddress +import os +import socket +import struct + +QTYPES: dict[str, int] = { + "A": 1, "NS": 2, "CNAME": 5, "SOA": 6, + "PTR": 12, "MX": 15, "TXT": 16, "AAAA": 28, +} +QTYPE_NAMES: dict[int, str] = {v: k for k, v in QTYPES.items()} +RCODES: dict[int, str] = { + 0: "", 1: "FORMERR", 2: "SERVFAIL", 3: "NXDOMAIN", + 4: "NOTIMP", 5: "REFUSED", +} + + +def get_resolver() -> str: + """Read first IPv4 nameserver from /etc/resolv.conf.""" + try: + with open("/etc/resolv.conf") as f: + for line in f: + line = line.strip() + if line.startswith("nameserver"): + addr = line.split()[1] + try: + ipaddress.IPv4Address(addr) + return addr + except ValueError: + continue + except (OSError, IndexError): + pass + return "8.8.8.8" + + +def encode_name(name: str) -> bytes: + """Encode a domain name into DNS wire format.""" + out = b"" + for label in name.rstrip(".").split("."): + out += bytes([len(label)]) + label.encode("ascii") + return out + b"\x00" + + +def decode_name(data: bytes, offset: int) -> tuple[str, int]: + """Decode a DNS name with pointer compression.""" + labels: list[str] = [] + jumped = False + ret_offset = offset + jumps = 0 + while offset < len(data): + length = data[offset] + if length == 0: + if not jumped: + ret_offset = offset + 1 + break + if (length & 0xC0) == 0xC0: + if not jumped: + ret_offset = offset + 2 + ptr = struct.unpack_from("!H", data, offset)[0] & 0x3FFF + offset = ptr + jumped = True + jumps += 1 + if jumps > 20: + break + continue + offset += 1 + labels.append(data[offset:offset + length].decode("ascii", errors="replace")) + offset += length + if not jumped: + ret_offset = offset + return ".".join(labels), ret_offset + + +def build_query(name: str, qtype: int) -> bytes: + """Build a DNS query packet.""" + tid = os.urandom(2) + flags = struct.pack("!H", 0x0100) + counts = struct.pack("!HHHH", 1, 0, 0, 0) + return tid + flags + counts + encode_name(name) + struct.pack("!HH", qtype, 1) + + +def parse_rdata(rtype: int, data: bytes, offset: int, rdlength: int) -> str: + """Parse an RR's rdata into a human-readable string.""" + rdata = data[offset:offset + rdlength] + if rtype == 1 and rdlength == 4: + return socket.inet_ntoa(rdata) + if rtype == 28 and rdlength == 16: + return socket.inet_ntop(socket.AF_INET6, rdata) + if rtype in (2, 5, 12): # NS, CNAME, PTR + name, _ = decode_name(data, offset) + return name + if rtype == 15: # MX + pref = struct.unpack_from("!H", rdata, 0)[0] + mx, _ = decode_name(data, offset + 2) + return f"{pref} {mx}" + if rtype == 16: # TXT + parts: list[str] = [] + pos = 0 + while pos < rdlength: + tlen = rdata[pos] + pos += 1 + parts.append(rdata[pos:pos + tlen].decode("utf-8", errors="replace")) + pos += tlen + return "".join(parts) + if rtype == 6: # SOA + mname, off = decode_name(data, offset) + rname, off = decode_name(data, off) + serial = struct.unpack_from("!I", data, off)[0] + return f"{mname} {rname} {serial}" + return rdata.hex() + + +def parse_response(data: bytes) -> tuple[int, list[str]]: + """Parse a DNS response, returning (rcode, [values]).""" + if len(data) < 12: + return 2, [] + _, flags, qdcount, ancount = struct.unpack_from("!HHHH", data, 0) + rcode = flags & 0x0F + offset = 12 + for _ in range(qdcount): + _, offset = decode_name(data, offset) + offset += 4 + results: list[str] = [] + for _ in range(ancount): + if offset + 10 > len(data): + break + _, offset = decode_name(data, offset) + rtype, _, _, rdlength = struct.unpack_from("!HHIH", data, offset) + offset += 10 + if offset + rdlength > len(data): + break + results.append(parse_rdata(rtype, data, offset, rdlength)) + offset += rdlength + return rcode, results + + +def reverse_name(addr: str) -> str: + """Convert an IP address to its reverse DNS name.""" + ip = ipaddress.ip_address(addr) + if isinstance(ip, ipaddress.IPv4Address): + return ".".join(reversed(addr.split("."))) + ".in-addr.arpa" + expanded = ip.exploded.replace(":", "") + return ".".join(reversed(expanded)) + ".ip6.arpa" diff --git a/tests/test_dns.py b/tests/test_dns.py new file mode 100644 index 0000000..70ab251 --- /dev/null +++ b/tests/test_dns.py @@ -0,0 +1,143 @@ +"""Tests for the shared DNS wire-format helpers.""" + +import struct +from unittest.mock import mock_open, patch + +from derp.dns import ( + QTYPE_NAMES, + QTYPES, + build_query, + decode_name, + encode_name, + get_resolver, + parse_response, + reverse_name, +) + + +class TestEncodeDecode: + def test_encode_simple(self): + result = encode_name("example.com") + assert result == b"\x07example\x03com\x00" + + def test_encode_trailing_dot(self): + assert encode_name("example.com.") == encode_name("example.com") + + def test_decode_simple(self): + data = b"\x07example\x03com\x00" + name, offset = decode_name(data, 0) + assert name == "example.com" + assert offset == len(data) + + def test_roundtrip(self): + for domain in ("a.b.c", "example.com", "sub.domain.example.org"): + encoded = encode_name(domain) + decoded, _ = decode_name(encoded, 0) + assert decoded == domain + + def test_decode_pointer(self): + # Name at offset 0, then a pointer back to it at offset 13 + data = b"\x07example\x03com\x00" + b"\xc0\x00" + name, offset = decode_name(data, 13) + assert name == "example.com" + assert offset == 15 + + +class TestBuildQuery: + def test_packet_structure(self): + pkt = build_query("example.com", QTYPES["A"]) + # 2 TID + 2 flags + 8 counts + encoded name + 4 qtype/qclass + encoded = encode_name("example.com") + assert len(pkt) == 12 + len(encoded) + 4 + + def test_flags_rd_set(self): + pkt = build_query("example.com", QTYPES["A"]) + flags = struct.unpack_from("!H", pkt, 2)[0] + assert flags & 0x0100 # RD bit set + + def test_qdcount_one(self): + pkt = build_query("example.com", QTYPES["A"]) + qdcount = struct.unpack_from("!H", pkt, 4)[0] + assert qdcount == 1 + + def test_qtype_embedded(self): + pkt = build_query("example.com", QTYPES["AAAA"]) + encoded = encode_name("example.com") + qtype = struct.unpack_from("!H", pkt, 12 + len(encoded))[0] + assert qtype == 28 + + +class TestParseResponse: + def _make_response(self, rcode=0, answers=None): + """Build a minimal DNS response packet.""" + answers = answers or [] + tid = b"\x00\x01" + flags = struct.pack("!H", 0x8180 | rcode) + counts = struct.pack("!HHHH", 1, len(answers), 0, 0) + # Question section + qname = encode_name("example.com") + question = qname + struct.pack("!HH", 1, 1) + # Answer section + ans_bytes = b"" + for rtype, rdata in answers: + ans_bytes += qname + struct.pack("!HHIH", rtype, 1, 300, len(rdata)) + rdata + return tid + flags + counts + question + ans_bytes + + def test_a_record(self): + rdata = bytes([1, 2, 3, 4]) + pkt = self._make_response(answers=[(1, rdata)]) + rcode, results = parse_response(pkt) + assert rcode == 0 + assert results == ["1.2.3.4"] + + def test_nxdomain(self): + pkt = self._make_response(rcode=3) + rcode, results = parse_response(pkt) + assert rcode == 3 + assert results == [] + + def test_short_packet(self): + rcode, results = parse_response(b"\x00" * 5) + assert rcode == 2 + assert results == [] + + +class TestReverseName: + def test_ipv4(self): + assert reverse_name("1.2.3.4") == "4.3.2.1.in-addr.arpa" + + def test_ipv6(self): + result = reverse_name("::1") + assert result.endswith(".ip6.arpa") + # Full expansion: 32 nibble chars separated by dots + parts = result.replace(".ip6.arpa", "").split(".") + assert len(parts) == 32 + + def test_invalid_raises(self): + try: + reverse_name("not-an-ip") + assert False, "should have raised" + except ValueError: + pass + + +class TestGetResolver: + def test_reads_nameserver(self): + content = "# comment\nnameserver 192.168.1.1\nnameserver 8.8.8.8\n" + with patch("builtins.open", mock_open(read_data=content)): + assert get_resolver() == "192.168.1.1" + + def test_skips_ipv6(self): + content = "nameserver ::1\nnameserver 9.9.9.9\n" + with patch("builtins.open", mock_open(read_data=content)): + assert get_resolver() == "9.9.9.9" + + def test_fallback(self): + with patch("builtins.open", side_effect=OSError): + assert get_resolver() == "8.8.8.8" + + +class TestConstants: + def test_qtype_names_reverse(self): + for name, num in QTYPES.items(): + assert QTYPE_NAMES[num] == name diff --git a/tests/test_tdns.py b/tests/test_tdns.py new file mode 100644 index 0000000..a064c5d --- /dev/null +++ b/tests/test_tdns.py @@ -0,0 +1,128 @@ +"""Tests for the TCP DNS plugin.""" + +import asyncio +import importlib.util +import struct +import sys +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from derp.dns import encode_name + +# plugins/ is not a Python package -- load the module from file path +_spec = importlib.util.spec_from_file_location( + "plugins.tdns", Path(__file__).resolve().parent.parent / "plugins" / "tdns.py", +) +_mod = importlib.util.module_from_spec(_spec) +sys.modules[_spec.name] = _mod +_spec.loader.exec_module(_mod) + +from plugins.tdns import _query_tcp, cmd_tdns # noqa: E402 + + +def _make_a_response(ip_bytes: bytes = b"\x01\x02\x03\x04") -> bytes: + """Build a minimal A-record DNS response.""" + tid = b"\x00\x01" + flags = struct.pack("!H", 0x8180) + counts = struct.pack("!HHHH", 1, 1, 0, 0) + qname = encode_name("example.com") + question = qname + struct.pack("!HH", 1, 1) + answer = qname + struct.pack("!HHIH", 1, 1, 300, len(ip_bytes)) + ip_bytes + return tid + flags + counts + question + answer + + +class TestQueryTcp: + def test_sends_length_prefixed_query(self): + response = _make_a_response() + framed = struct.pack("!H", len(response)) + response + + reader = AsyncMock() + reader.readexactly = AsyncMock(side_effect=[ + framed[:2], # length prefix + framed[2:], # payload + ]) + writer = MagicMock() + writer.drain = AsyncMock() + writer.wait_closed = AsyncMock() + + mock_open = AsyncMock(return_value=(reader, writer)) + with patch.object(_mod, "_open_connection", mock_open): + rcode, results = asyncio.run(_query_tcp("example.com", 1, "1.1.1.1")) + + assert rcode == 0 + assert results == ["1.2.3.4"] + + # Verify the written data has a 2-byte length prefix + written = writer.write.call_args[0][0] + pkt_len = struct.unpack("!H", written[:2])[0] + assert pkt_len == len(written) - 2 + + def test_closes_writer_on_error(self): + reader = AsyncMock() + reader.readexactly = AsyncMock(side_effect=asyncio.IncompleteReadError(b"", 2)) + writer = MagicMock() + writer.drain = AsyncMock() + writer.wait_closed = AsyncMock() + + mock_open = AsyncMock(return_value=(reader, writer)) + with patch.object(_mod, "_open_connection", mock_open): + with pytest.raises(asyncio.IncompleteReadError): + asyncio.run(_query_tcp("example.com", 1, "1.1.1.1")) + + writer.close.assert_called_once() + + +class TestCmdTdns: + def test_no_args(self): + bot = AsyncMock() + msg = MagicMock() + msg.text = "!tdns" + asyncio.run(cmd_tdns(bot, msg)) + bot.reply.assert_called_once() + assert "Usage" in bot.reply.call_args[0][1] + + def test_ip_auto_ptr(self): + bot = AsyncMock() + msg = MagicMock() + msg.text = "!tdns 1.2.3.4" + + mock_tcp = AsyncMock(return_value=(0, ["host.example.com"])) + with patch.object(_mod, "_query_tcp", mock_tcp): + asyncio.run(cmd_tdns(bot, msg)) + + reply = bot.reply.call_args[0][1] + assert "PTR" in reply + assert "host.example.com" in reply + + def test_explicit_type(self): + bot = AsyncMock() + msg = MagicMock() + msg.text = "!tdns example.com MX" + + mock_tcp = AsyncMock(return_value=(0, ["10 mail.example.com"])) + with patch.object(_mod, "_query_tcp", mock_tcp): + asyncio.run(cmd_tdns(bot, msg)) + + reply = bot.reply.call_args[0][1] + assert "MX" in reply + + def test_custom_server(self): + bot = AsyncMock() + msg = MagicMock() + msg.text = "!tdns example.com A @8.8.8.8" + + mock_tcp = AsyncMock(return_value=(0, ["93.184.216.34"])) + with patch.object(_mod, "_query_tcp", mock_tcp): + asyncio.run(cmd_tdns(bot, msg)) + + assert mock_tcp.call_args[0][2] == "8.8.8.8" + + def test_unknown_type(self): + bot = AsyncMock() + msg = MagicMock() + msg.text = "!tdns example.com BOGUS" + asyncio.run(cmd_tdns(bot, msg)) + reply = bot.reply.call_args[0][1] + assert "Unknown type" in reply