feat: add IRCv3 cap negotiation, channel management, state persistence
Implement CAP LS 302 flow with configurable ircv3_caps list, replacing the minimal SASL-only registration. Parse IRCv3 message tags (@key=value) with proper value unescaping. Add channel management plugin (kick, ban, unban, topic, mode) and bot API methods. Add SQLite key-value StateStore for plugin state persistence with !state inspection command. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
107
src/derp/bot.py
107
src/derp/bot.py
@@ -13,6 +13,7 @@ from pathlib import Path
|
||||
from derp import __version__
|
||||
from derp.irc import IRCConnection, Message, format_msg, parse
|
||||
from derp.plugin import Handler, PluginRegistry
|
||||
from derp.state import StateStore
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -63,6 +64,8 @@ class Bot:
|
||||
self._tasks: set[asyncio.Task] = set()
|
||||
self._admins: list[str] = config.get("bot", {}).get("admins", [])
|
||||
self._opers: set[str] = set() # hostmasks of known IRC operators
|
||||
self._caps: set[str] = set() # negotiated IRCv3 caps
|
||||
self.state = StateStore()
|
||||
# Rate limiter: default 2 msg/sec, burst of 5
|
||||
rate_cfg = config.get("bot", {})
|
||||
self._bucket = _TokenBucket(
|
||||
@@ -92,40 +95,86 @@ class Bot:
|
||||
await self.conn.close()
|
||||
|
||||
async def _register(self) -> None:
|
||||
"""Send NICK/USER registration, with optional SASL PLAIN."""
|
||||
"""IRCv3 CAP negotiation followed by NICK/USER registration."""
|
||||
srv = self.config["server"]
|
||||
sasl_user = srv.get("sasl_user", "")
|
||||
sasl_pass = srv.get("sasl_pass", "")
|
||||
|
||||
if sasl_user and sasl_pass:
|
||||
await self._sasl_auth(sasl_user, sasl_pass)
|
||||
# 1. Request server capabilities
|
||||
await self.conn.send("CAP LS 302")
|
||||
available = await self._cap_ls()
|
||||
|
||||
# 2. Determine desired caps
|
||||
wanted = set(srv.get("ircv3_caps", [
|
||||
"multi-prefix", "away-notify", "server-time",
|
||||
"cap-notify", "account-notify",
|
||||
]))
|
||||
if srv.get("sasl_user") and srv.get("sasl_pass"):
|
||||
wanted.add("sasl")
|
||||
|
||||
to_request = wanted & available
|
||||
if to_request:
|
||||
await self.conn.send(f"CAP REQ :{' '.join(sorted(to_request))}")
|
||||
acked = await self._cap_ack()
|
||||
self._caps = acked
|
||||
log.info("negotiated caps: %s", " ".join(sorted(acked)))
|
||||
else:
|
||||
self._caps = set()
|
||||
|
||||
# 3. SASL auth if negotiated
|
||||
if "sasl" in self._caps:
|
||||
await self._sasl_auth(srv["sasl_user"], srv["sasl_pass"])
|
||||
|
||||
# 4. End capability negotiation
|
||||
await self.conn.send("CAP END")
|
||||
|
||||
# 5. Standard registration
|
||||
if srv.get("password"):
|
||||
await self.conn.send(format_msg("PASS", srv["password"]))
|
||||
await self.conn.send(format_msg("NICK", self.nick))
|
||||
await self.conn.send(format_msg("USER", srv["user"], "0", "*", srv["realname"]))
|
||||
|
||||
async def _sasl_auth(self, user: str, password: str) -> None:
|
||||
"""Perform SASL PLAIN authentication during registration."""
|
||||
await self.conn.send("CAP REQ :sasl")
|
||||
|
||||
# Wait for CAP ACK or NAK
|
||||
async def _cap_ls(self) -> set[str]:
|
||||
"""Read CAP LS response(s). Returns set of capability names."""
|
||||
caps: set[str] = set()
|
||||
while True:
|
||||
line = await self.conn.readline()
|
||||
if line is None:
|
||||
log.error("connection closed during SASL negotiation")
|
||||
return
|
||||
log.error("connection closed during CAP LS")
|
||||
return caps
|
||||
msg = parse(line)
|
||||
if msg.command == "CAP" and len(msg.params) >= 3:
|
||||
sub = msg.params[1].upper()
|
||||
if sub == "ACK" and "sasl" in msg.params[-1].lower():
|
||||
break
|
||||
if sub == "NAK":
|
||||
log.warning("server rejected SASL capability")
|
||||
await self.conn.send("CAP END")
|
||||
return
|
||||
if sub == "LS":
|
||||
# Multi-line: CAP * LS * :caps...
|
||||
# Final: CAP * LS :caps...
|
||||
cap_str = msg.params[-1]
|
||||
for token in cap_str.split():
|
||||
name = token.split("=", 1)[0]
|
||||
caps.add(name)
|
||||
# Check for continuation marker
|
||||
if len(msg.params) >= 4 and msg.params[2] == "*":
|
||||
continue
|
||||
return caps
|
||||
return caps # pragma: no cover
|
||||
|
||||
# Send AUTHENTICATE PLAIN
|
||||
async def _cap_ack(self) -> set[str]:
|
||||
"""Read CAP ACK/NAK response. Returns set of acknowledged caps."""
|
||||
while True:
|
||||
line = await self.conn.readline()
|
||||
if line is None:
|
||||
log.error("connection closed during CAP REQ")
|
||||
return set()
|
||||
msg = parse(line)
|
||||
if msg.command == "CAP" and len(msg.params) >= 3:
|
||||
sub = msg.params[1].upper()
|
||||
if sub == "ACK":
|
||||
return set(msg.params[-1].split())
|
||||
if sub == "NAK":
|
||||
log.warning("server rejected caps: %s", msg.params[-1])
|
||||
return set()
|
||||
return set() # pragma: no cover
|
||||
|
||||
async def _sasl_auth(self, user: str, password: str) -> None:
|
||||
"""Perform SASL PLAIN authentication (within CAP negotiation)."""
|
||||
await self.conn.send("AUTHENTICATE PLAIN")
|
||||
|
||||
# Wait for AUTHENTICATE +
|
||||
@@ -152,11 +201,10 @@ class Bot:
|
||||
log.info("SASL authentication successful")
|
||||
break
|
||||
if msg.command in ("904", "905", "906"):
|
||||
log.error("SASL authentication failed: %s", msg.params[-1] if msg.params else "")
|
||||
log.error("SASL authentication failed: %s",
|
||||
msg.params[-1] if msg.params else "")
|
||||
break
|
||||
|
||||
await self.conn.send("CAP END")
|
||||
|
||||
async def _loop(self) -> None:
|
||||
"""Read and dispatch messages until disconnect."""
|
||||
while self._running:
|
||||
@@ -343,6 +391,21 @@ class Bot:
|
||||
await self.conn.send(format_msg("QUIT", reason))
|
||||
await self.conn.close()
|
||||
|
||||
async def kick(self, channel: str, nick: str, reason: str = "") -> None:
|
||||
"""Kick a user from a channel."""
|
||||
if reason:
|
||||
await self.conn.send(format_msg("KICK", channel, nick, reason))
|
||||
else:
|
||||
await self.conn.send(format_msg("KICK", channel, nick))
|
||||
|
||||
async def mode(self, target: str, mode_str: str, *args: str) -> None:
|
||||
"""Set a mode on a target (channel or nick)."""
|
||||
await self.conn.send(format_msg("MODE", target, mode_str, *args))
|
||||
|
||||
async def set_topic(self, channel: str, topic: str) -> None:
|
||||
"""Set the channel topic."""
|
||||
await self.conn.send(format_msg("TOPIC", channel, topic))
|
||||
|
||||
def load_plugins(self, plugins_dir: str | Path | None = None) -> None:
|
||||
"""Load plugins from the configured directory."""
|
||||
if plugins_dir is None:
|
||||
|
||||
@@ -16,6 +16,10 @@ DEFAULTS: dict = {
|
||||
"password": "",
|
||||
"sasl_user": "",
|
||||
"sasl_pass": "",
|
||||
"ircv3_caps": [
|
||||
"multi-prefix", "away-notify", "server-time",
|
||||
"cap-notify", "account-notify",
|
||||
],
|
||||
},
|
||||
"bot": {
|
||||
"prefix": "!",
|
||||
|
||||
@@ -10,15 +10,56 @@ from dataclasses import dataclass
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _unescape_tag_value(value: str) -> str:
|
||||
"""Unescape an IRCv3 message tag value per the spec."""
|
||||
out: list[str] = []
|
||||
i = 0
|
||||
while i < len(value):
|
||||
if value[i] == "\\" and i + 1 < len(value):
|
||||
nxt = value[i + 1]
|
||||
if nxt == ":":
|
||||
out.append(";")
|
||||
elif nxt == "s":
|
||||
out.append(" ")
|
||||
elif nxt == "\\":
|
||||
out.append("\\")
|
||||
elif nxt == "r":
|
||||
out.append("\r")
|
||||
elif nxt == "n":
|
||||
out.append("\n")
|
||||
else:
|
||||
out.append(nxt)
|
||||
i += 2
|
||||
else:
|
||||
out.append(value[i])
|
||||
i += 1
|
||||
return "".join(out)
|
||||
|
||||
|
||||
def _parse_tags(raw_tags: str) -> dict[str, str]:
|
||||
"""Parse an IRCv3 tags string into a dict."""
|
||||
tags: dict[str, str] = {}
|
||||
for part in raw_tags.split(";"):
|
||||
if not part:
|
||||
continue
|
||||
if "=" in part:
|
||||
key, value = part.split("=", 1)
|
||||
tags[key] = _unescape_tag_value(value)
|
||||
else:
|
||||
tags[part] = ""
|
||||
return tags
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class Message:
|
||||
"""Parsed IRC message (RFC 1459)."""
|
||||
"""Parsed IRC message (RFC 1459 + IRCv3 tags)."""
|
||||
|
||||
raw: str
|
||||
prefix: str | None
|
||||
nick: str | None
|
||||
command: str
|
||||
params: list[str]
|
||||
tags: dict[str, str]
|
||||
|
||||
@property
|
||||
def target(self) -> str | None:
|
||||
@@ -39,11 +80,17 @@ class Message:
|
||||
def parse(line: str) -> Message:
|
||||
"""Parse a raw IRC line into a Message.
|
||||
|
||||
Format: [:prefix] command [params...] [:trailing]
|
||||
Format: [@tags] [:prefix] command [params...] [:trailing]
|
||||
"""
|
||||
raw = line
|
||||
prefix = None
|
||||
nick = None
|
||||
tags: dict[str, str] = {}
|
||||
|
||||
# IRCv3 message tags
|
||||
if line.startswith("@"):
|
||||
tag_str, line = line[1:].split(" ", 1)
|
||||
tags = _parse_tags(tag_str)
|
||||
|
||||
if line.startswith(":"):
|
||||
prefix, line = line[1:].split(" ", 1)
|
||||
@@ -63,7 +110,10 @@ def parse(line: str) -> Message:
|
||||
if trailing is not None:
|
||||
params.append(trailing)
|
||||
|
||||
return Message(raw=raw, prefix=prefix, nick=nick, command=command, params=params)
|
||||
return Message(
|
||||
raw=raw, prefix=prefix, nick=nick, command=command,
|
||||
params=params, tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def format_msg(command: str, *params: str) -> str:
|
||||
|
||||
89
src/derp/state.py
Normal file
89
src/derp/state.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""SQLite key-value store for plugin state persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_SCHEMA = """\
|
||||
CREATE TABLE IF NOT EXISTS state (
|
||||
plugin TEXT NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
PRIMARY KEY (plugin, key)
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
class StateStore:
|
||||
"""Persistent key-value store backed by SQLite.
|
||||
|
||||
Each plugin gets its own namespace (plugin, key) so keys never collide
|
||||
across plugins. The database is created lazily in ``data/state.db``.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: str | Path = "data/state.db") -> None:
|
||||
self._path = Path(db_path)
|
||||
self._conn: sqlite3.Connection | None = None
|
||||
|
||||
def _db(self) -> sqlite3.Connection:
|
||||
"""Return (and lazily create) the database connection."""
|
||||
if self._conn is None:
|
||||
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._conn = sqlite3.connect(str(self._path))
|
||||
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||
self._conn.executescript(_SCHEMA)
|
||||
log.debug("state store opened: %s", self._path)
|
||||
return self._conn
|
||||
|
||||
def get(self, plugin: str, key: str, default: str | None = None) -> str | None:
|
||||
"""Get a value by plugin and key."""
|
||||
row = self._db().execute(
|
||||
"SELECT value FROM state WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
).fetchone()
|
||||
return row[0] if row else default
|
||||
|
||||
def set(self, plugin: str, key: str, value: str) -> None:
|
||||
"""Set a value, creating or updating as needed."""
|
||||
db = self._db()
|
||||
db.execute(
|
||||
"INSERT INTO state (plugin, key, value) VALUES (?, ?, ?)"
|
||||
" ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, value),
|
||||
)
|
||||
db.commit()
|
||||
|
||||
def delete(self, plugin: str, key: str) -> bool:
|
||||
"""Delete a key. Returns True if a row was removed."""
|
||||
db = self._db()
|
||||
cur = db.execute(
|
||||
"DELETE FROM state WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
)
|
||||
db.commit()
|
||||
return cur.rowcount > 0
|
||||
|
||||
def keys(self, plugin: str) -> list[str]:
|
||||
"""List all keys for a plugin."""
|
||||
rows = self._db().execute(
|
||||
"SELECT key FROM state WHERE plugin = ? ORDER BY key",
|
||||
(plugin,),
|
||||
).fetchall()
|
||||
return [r[0] for r in rows]
|
||||
|
||||
def clear(self, plugin: str) -> int:
|
||||
"""Delete all state for a plugin. Returns number of rows removed."""
|
||||
db = self._db()
|
||||
cur = db.execute("DELETE FROM state WHERE plugin = ?", (plugin,))
|
||||
db.commit()
|
||||
return cur.rowcount
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the database connection."""
|
||||
if self._conn is not None:
|
||||
self._conn.close()
|
||||
self._conn = None
|
||||
Reference in New Issue
Block a user