"""Plugin: DNS record lookup (raw UDP, pure stdlib).""" from __future__ import annotations import asyncio import ipaddress import os import socket import struct from derp.plugin import command _QTYPES = { "A": 1, "NS": 2, "CNAME": 5, "SOA": 6, "PTR": 12, "MX": 15, "TXT": 16, "AAAA": 28, } _QTYPE_NAMES = {v: k for k, v in _QTYPES.items()} _RCODES = { 0: "", 1: "FORMERR", 2: "SERVFAIL", 3: "NXDOMAIN", 4: "NOTIMP", 5: "REFUSED", } # -- wire format helpers -- def _get_resolver() -> str: """Read first IPv4 nameserver from /etc/resolv.conf.""" try: with open("/etc/resolv.conf") as f: for line in f: line = line.strip() if line.startswith("nameserver"): addr = line.split()[1] try: ipaddress.IPv4Address(addr) return addr except ValueError: continue except (OSError, IndexError): pass return "8.8.8.8" def _encode_name(name: str) -> bytes: """Encode a domain name into DNS wire format.""" out = b"" for label in name.rstrip(".").split("."): out += bytes([len(label)]) + label.encode("ascii") return out + b"\x00" def _decode_name(data: bytes, offset: int) -> tuple[str, int]: """Decode a DNS name with pointer compression.""" labels: list[str] = [] jumped = False ret_offset = offset jumps = 0 while offset < len(data): length = data[offset] if length == 0: if not jumped: ret_offset = offset + 1 break if (length & 0xC0) == 0xC0: if not jumped: ret_offset = offset + 2 ptr = struct.unpack_from("!H", data, offset)[0] & 0x3FFF offset = ptr jumped = True jumps += 1 if jumps > 20: break continue offset += 1 labels.append(data[offset:offset + length].decode("ascii", errors="replace")) offset += length if not jumped: ret_offset = offset return ".".join(labels), ret_offset def _build_query(name: str, qtype: int) -> bytes: """Build a DNS query packet.""" tid = os.urandom(2) flags = struct.pack("!H", 0x0100) counts = struct.pack("!HHHH", 1, 0, 0, 0) return tid + flags + counts + _encode_name(name) + struct.pack("!HH", qtype, 1) def _parse_rdata(rtype: int, data: bytes, offset: int, rdlength: int) -> str: """Parse an RR's rdata into a human-readable string.""" rdata = data[offset:offset + rdlength] if rtype == 1 and rdlength == 4: return socket.inet_ntoa(rdata) if rtype == 28 and rdlength == 16: return socket.inet_ntop(socket.AF_INET6, rdata) if rtype in (2, 5, 12): # NS, CNAME, PTR name, _ = _decode_name(data, offset) return name if rtype == 15: # MX pref = struct.unpack_from("!H", rdata, 0)[0] mx, _ = _decode_name(data, offset + 2) return f"{pref} {mx}" if rtype == 16: # TXT parts: list[str] = [] pos = 0 while pos < rdlength: tlen = rdata[pos] pos += 1 parts.append(rdata[pos:pos + tlen].decode("utf-8", errors="replace")) pos += tlen return "".join(parts) if rtype == 6: # SOA mname, off = _decode_name(data, offset) rname, off = _decode_name(data, off) serial = struct.unpack_from("!I", data, off)[0] return f"{mname} {rname} {serial}" return rdata.hex() def _parse_response(data: bytes) -> tuple[int, list[str]]: """Parse a DNS response, returning (rcode, [values]).""" if len(data) < 12: return 2, [] _, flags, qdcount, ancount = struct.unpack_from("!HHHH", data, 0) rcode = flags & 0x0F offset = 12 for _ in range(qdcount): _, offset = _decode_name(data, offset) offset += 4 results: list[str] = [] for _ in range(ancount): if offset + 10 > len(data): break _, offset = _decode_name(data, offset) rtype, _, _, rdlength = struct.unpack_from("!HHIH", data, offset) offset += 10 if offset + rdlength > len(data): break results.append(_parse_rdata(rtype, data, offset, rdlength)) offset += rdlength return rcode, results async def _query(name: str, qtype: int, server: str, timeout: float = 5.0) -> tuple[int, list[str]]: """Send a DNS query and return (rcode, [values]).""" query = _build_query(name, qtype) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock.settimeout(timeout) loop = asyncio.get_running_loop() try: await loop.run_in_executor(None, sock.sendto, query, (server, 53)) data = await asyncio.wait_for( loop.run_in_executor(None, sock.recv, 4096), timeout=timeout, ) return _parse_response(data) except (TimeoutError, socket.timeout): return -1, [] except OSError: return -2, [] finally: sock.close() def _reverse_name(addr: str) -> str: """Convert an IP address to its reverse DNS name.""" ip = ipaddress.ip_address(addr) if isinstance(ip, ipaddress.IPv4Address): return ".".join(reversed(addr.split("."))) + ".in-addr.arpa" expanded = ip.exploded.replace(":", "") return ".".join(reversed(expanded)) + ".ip6.arpa" @command("dns", help="DNS lookup: !dns [A|AAAA|MX|NS|TXT|CNAME|PTR|SOA]") async def cmd_dns(bot, message): """Query DNS records for a domain or reverse-lookup an IP.""" parts = message.text.split(None, 3) if len(parts) < 2: await bot.reply(message, "Usage: !dns [type]") return target = parts[1] qtype_str = parts[2].upper() if len(parts) > 2 else None # Auto-detect: IP -> PTR, domain -> A if qtype_str is None: try: ipaddress.ip_address(target) qtype_str = "PTR" except ValueError: qtype_str = "A" qtype = _QTYPES.get(qtype_str) if qtype is None: valid = ", ".join(sorted(_QTYPES)) await bot.reply(message, f"Unknown type: {qtype_str} (valid: {valid})") return lookup = target if qtype_str == "PTR": try: lookup = _reverse_name(target) except ValueError: await bot.reply(message, f"Invalid IP for PTR: {target}") return server = _get_resolver() rcode, results = await _query(lookup, qtype, server) if rcode == -1: await bot.reply(message, f"{target} {qtype_str}: timeout") elif rcode == -2: await bot.reply(message, f"{target} {qtype_str}: network error") elif rcode != 0: err = _RCODES.get(rcode, f"error {rcode}") await bot.reply(message, f"{target} {qtype_str}: {err}") elif not results: await bot.reply(message, f"{target} {qtype_str}: no records") else: await bot.reply(message, f"{target} {qtype_str}: {', '.join(results)}")