feat: concurrent command dispatch and profiling test client
Replace sequential await in command/event dispatch with asyncio.create_task() so slow commands (whois, httpcheck, tlscheck) no longer block the read loop. Add _spawn() for task lifecycle tracking. Enable cProfile in docker-compose for profiling. Add scripts/test_client.py for end-to-end plugin testing. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -8,4 +8,5 @@ services:
|
||||
volumes:
|
||||
- ./config/derp.toml:/app/config/derp.toml:ro,Z
|
||||
- ./plugins:/app/plugins:ro,Z
|
||||
command: ["--verbose"]
|
||||
- ./profile:/app/profile:Z
|
||||
command: ["--verbose", "--cprofile", "/app/profile/derp.prof"]
|
||||
|
||||
211
scripts/test_client.py
Normal file
211
scripts/test_client.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Gentle IRC test client for profiling derp bot commands."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import ssl
|
||||
import sys
|
||||
import time
|
||||
|
||||
HOST = "mymx.me"
|
||||
PORT = 6697
|
||||
PASSWORD = "irc$1234="
|
||||
NICK = "tester"
|
||||
CHANNEL = "#derp"
|
||||
DELAY = 2.5 # seconds between commands -- be nice
|
||||
|
||||
# Commands to test, grouped by plugin
|
||||
TESTS: list[tuple[str, str]] = [
|
||||
# -- core --
|
||||
("core", "!ping"),
|
||||
("core", "!help"),
|
||||
("core", "!help ping"),
|
||||
("core", "!help dns"),
|
||||
("core", "!version"),
|
||||
("core", "!uptime"),
|
||||
("core", "!plugins"),
|
||||
|
||||
# -- example --
|
||||
("example", "!echo profiling test run"),
|
||||
|
||||
# -- dns --
|
||||
("dns", "!dns example.com"),
|
||||
("dns", "!dns example.com MX"),
|
||||
("dns", "!dns 1.1.1.1"),
|
||||
|
||||
# -- encode --
|
||||
("encode", "!encode b64 hello world"),
|
||||
("encode", "!decode b64 aGVsbG8gd29ybGQ="),
|
||||
("encode", "!encode hex derp"),
|
||||
("encode", "!encode rot13 hello"),
|
||||
|
||||
# -- hash --
|
||||
("hash", "!hash hello"),
|
||||
("hash", "!hash sha512 hello"),
|
||||
("hash", "!hashid 5d41402abc4b2a76b9719d911017c592"),
|
||||
|
||||
# -- defang --
|
||||
("defang", "!defang https://evil.com/path?q=1"),
|
||||
("defang", "!refang hxxps[://]evil[.]com/path"),
|
||||
|
||||
# -- revshell --
|
||||
("revshell", "!revshell list"),
|
||||
("revshell", "!revshell bash 10.0.0.1 4444"),
|
||||
|
||||
# -- cidr --
|
||||
("cidr", "!cidr 192.168.1.0/24"),
|
||||
("cidr", "!cidr contains 10.0.0.0/8 10.1.2.3"),
|
||||
|
||||
# -- whois --
|
||||
("whois", "!whois example.com"),
|
||||
|
||||
# -- portcheck --
|
||||
("portcheck", "!portcheck example.com 80,443"),
|
||||
|
||||
# -- httpcheck --
|
||||
("httpcheck", "!httpcheck https://example.com"),
|
||||
|
||||
# -- tlscheck --
|
||||
("tlscheck", "!tlscheck example.com"),
|
||||
|
||||
# -- blacklist (use a known-safe public DNS) --
|
||||
("blacklist", "!blacklist 8.8.8.8"),
|
||||
|
||||
# -- rand --
|
||||
("rand", "!rand password"),
|
||||
("rand", "!rand hex 16"),
|
||||
("rand", "!rand uuid"),
|
||||
("rand", "!rand int 100"),
|
||||
("rand", "!rand coin"),
|
||||
("rand", "!rand dice 2d6"),
|
||||
|
||||
# -- timer --
|
||||
("timer", "!timer 5s profile-test"),
|
||||
("timer", "!timer list"),
|
||||
|
||||
# -- crtsh (external API, single domain) --
|
||||
("crtsh", "!cert example.com"),
|
||||
|
||||
# -- shorthand --
|
||||
("shorthand", "!pi"),
|
||||
("shorthand", "!ver"),
|
||||
]
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Connect to IRC and run through all bot commands."""
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
print(f"[*] connecting to {HOST}:{PORT} (TLS)...")
|
||||
reader, writer = await asyncio.open_connection(HOST, PORT, ssl=ctx)
|
||||
|
||||
async def send(line: str) -> None:
|
||||
writer.write(f"{line}\r\n".encode())
|
||||
await writer.drain()
|
||||
print(f">>> {line}")
|
||||
|
||||
async def read_until(match: str, timeout: float = 15.0) -> list[str]:
|
||||
lines: list[str] = []
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.readline(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if not data:
|
||||
break
|
||||
line = data.decode("utf-8", errors="replace").strip()
|
||||
print(f"<<< {line}")
|
||||
lines.append(line)
|
||||
if line.startswith("PING"):
|
||||
pong = line.replace("PING", "PONG", 1)
|
||||
await send(pong)
|
||||
if match in line:
|
||||
break
|
||||
return lines
|
||||
|
||||
# -- Register --
|
||||
await send(f"PASS {PASSWORD}")
|
||||
await send(f"NICK {NICK}")
|
||||
await send(f"USER {NICK} 0 * :derp test client")
|
||||
await read_until("376") # End of MOTD
|
||||
|
||||
await asyncio.sleep(1)
|
||||
await send(f"JOIN {CHANNEL}")
|
||||
await read_until("366") # End of NAMES
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# -- Run tests --
|
||||
total = len(TESTS)
|
||||
passed = 0
|
||||
results: list[tuple[str, str, bool]] = []
|
||||
|
||||
for i, (plugin, cmd) in enumerate(TESTS, 1):
|
||||
print(f"\n[{i}/{total}] ({plugin}) {cmd}")
|
||||
await send(f"PRIVMSG {CHANNEL} :{cmd}")
|
||||
|
||||
# Wait for bot response (look for PRIVMSG from derp)
|
||||
got_reply = False
|
||||
deadline = time.monotonic() + 15.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.readline(), timeout=3.0)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if not data:
|
||||
break
|
||||
line = data.decode("utf-8", errors="replace").strip()
|
||||
print(f"<<< {line}")
|
||||
if line.startswith("PING"):
|
||||
pong = line.replace("PING", "PONG", 1)
|
||||
await send(pong)
|
||||
if "PRIVMSG" in line and ":derp!" in line.lower():
|
||||
got_reply = True
|
||||
break
|
||||
|
||||
status = "OK" if got_reply else "NO REPLY"
|
||||
results.append((plugin, cmd, got_reply))
|
||||
if got_reply:
|
||||
passed += 1
|
||||
print(f" -> {status}")
|
||||
|
||||
await asyncio.sleep(DELAY)
|
||||
|
||||
# -- Wait for timer callback --
|
||||
print("\n[*] waiting for timer notification (5s)...")
|
||||
await asyncio.sleep(6)
|
||||
deadline = time.monotonic() + 5.0
|
||||
while time.monotonic() < deadline:
|
||||
try:
|
||||
data = await asyncio.wait_for(reader.readline(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
if not data:
|
||||
break
|
||||
line = data.decode("utf-8", errors="replace").strip()
|
||||
print(f"<<< {line}")
|
||||
|
||||
# -- Summary --
|
||||
print(f"\n{'=' * 50}")
|
||||
print(f"Results: {passed}/{total} commands got replies")
|
||||
print(f"{'=' * 50}")
|
||||
for plugin, cmd, ok in results:
|
||||
mark = "+" if ok else "-"
|
||||
print(f" [{mark}] {plugin:12s} {cmd}")
|
||||
|
||||
# -- Disconnect --
|
||||
await send(f"PART {CHANNEL} :test complete")
|
||||
await send("QUIT :profiling done")
|
||||
writer.close()
|
||||
try:
|
||||
await writer.wait_closed()
|
||||
except ssl.SSLError:
|
||||
pass # server may close before SSL shutdown completes
|
||||
|
||||
sys.exit(0 if passed == total else 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -32,6 +32,7 @@ class Bot:
|
||||
self.prefix: str = config["bot"]["prefix"]
|
||||
self._running = False
|
||||
self._started: float = time.monotonic()
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Connect, register, join channels, and enter the main loop."""
|
||||
@@ -91,20 +92,17 @@ class Bot:
|
||||
await self.conn.send(format_msg("NICK", self.nick))
|
||||
return
|
||||
|
||||
# Dispatch to event handlers
|
||||
# Dispatch to event handlers (fire-and-forget)
|
||||
event_type = msg.command
|
||||
for handler in self.registry.events.get(event_type, []):
|
||||
try:
|
||||
await handler.callback(self, msg)
|
||||
except Exception:
|
||||
log.exception("error in event handler %s", handler.name)
|
||||
self._spawn(handler.callback(self, msg), name=f"event:{handler.name}")
|
||||
|
||||
# Dispatch to command handlers (PRIVMSG only)
|
||||
if msg.command == "PRIVMSG" and msg.text:
|
||||
await self._dispatch_command(msg)
|
||||
self._dispatch_command(msg)
|
||||
|
||||
async def _dispatch_command(self, msg: Message) -> None:
|
||||
"""Check if a PRIVMSG is a bot command and dispatch it."""
|
||||
def _dispatch_command(self, msg: Message) -> None:
|
||||
"""Check if a PRIVMSG is a bot command and spawn it."""
|
||||
text = msg.text
|
||||
if not text or not text.startswith(self.prefix):
|
||||
return
|
||||
@@ -117,14 +115,26 @@ class Bot:
|
||||
if handler is _AMBIGUOUS:
|
||||
matches = [k for k in self.registry.commands if k.startswith(cmd_name)]
|
||||
names = ", ".join(self.prefix + m for m in sorted(matches))
|
||||
await self.reply(msg, f"Ambiguous command '{self.prefix}{cmd_name}': {names}")
|
||||
self._spawn(self.reply(msg, f"Ambiguous command '{self.prefix}{cmd_name}': {names}"),
|
||||
name=f"cmd:{cmd_name}:ambiguous")
|
||||
return
|
||||
|
||||
self._spawn(self._run_command(handler, cmd_name, msg), name=f"cmd:{cmd_name}")
|
||||
|
||||
async def _run_command(self, handler: Handler, cmd_name: str, msg: Message) -> None:
|
||||
"""Execute a command handler with error logging."""
|
||||
try:
|
||||
await handler.callback(self, msg)
|
||||
except Exception:
|
||||
log.exception("error in command handler '%s'", cmd_name)
|
||||
|
||||
def _spawn(self, coro, *, name: str | None = None) -> asyncio.Task:
|
||||
"""Spawn a background task and track it for cleanup."""
|
||||
task = asyncio.create_task(coro, name=name)
|
||||
self._tasks.add(task)
|
||||
task.add_done_callback(self._tasks.discard)
|
||||
return task
|
||||
|
||||
def _resolve_command(self, name: str) -> Handler | None:
|
||||
"""Resolve a command name, supporting unambiguous prefix matching.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user