diff --git a/plugins/rss.py b/plugins/rss.py new file mode 100644 index 0000000..f2776bc --- /dev/null +++ b/plugins/rss.py @@ -0,0 +1,504 @@ +"""Plugin: per-channel RSS/Atom feed subscriptions with periodic polling.""" + +from __future__ import annotations + +import asyncio +import json +import re +import ssl +import urllib.request +import xml.etree.ElementTree as ET +from datetime import datetime, timezone +from urllib.parse import urlparse + +from derp.plugin import command, event + +# -- Constants --------------------------------------------------------------- + +_NAME_RE = re.compile(r"^[a-z0-9][a-z0-9-]{0,19}$") +_MAX_SEEN = 200 +_MAX_ANNOUNCE = 5 +_DEFAULT_INTERVAL = 600 +_MAX_INTERVAL = 3600 +_FETCH_TIMEOUT = 15 +_USER_AGENT = "derp/1.0" +_MAX_TITLE_LEN = 80 +_MAX_FEEDS = 20 +_ATOM_NS = "{http://www.w3.org/2005/Atom}" +_DC_NS = "{http://purl.org/dc/elements/1.1/}" + +# -- Module-level tracking --------------------------------------------------- + +_pollers: dict[str, asyncio.Task] = {} +_feeds: dict[str, dict] = {} +_errors: dict[str, int] = {} + + +# -- Pure helpers ------------------------------------------------------------ + +def _state_key(channel: str, name: str) -> str: + """Build composite state key.""" + return f"{channel}:{name}" + + +def _validate_name(name: str) -> bool: + """Check name against allowed pattern.""" + return bool(_NAME_RE.match(name)) + + +def _derive_name(url: str) -> str: + """Derive a short feed name from URL hostname.""" + try: + hostname = urlparse(url).hostname or "" + except Exception: + hostname = "" + # Strip www. prefix, take first label, lowercase + hostname = hostname.lower().removeprefix("www.") + name = hostname.split(".")[0] if hostname else "feed" + # Sanitize to allowed chars + name = re.sub(r"[^a-z0-9-]", "", name) + if not name or not name[0].isalnum(): + name = "feed" + return name[:20] + + +def _truncate(text: str, max_len: int = _MAX_TITLE_LEN) -> str: + """Truncate text with ellipsis if needed.""" + if len(text) <= max_len: + return text + return text[: max_len - 3].rstrip() + "..." + + +# -- State helpers ----------------------------------------------------------- + +def _save(bot, key: str, data: dict) -> None: + """Persist feed data to bot.state.""" + bot.state.set("rss", key, json.dumps(data)) + + +def _load(bot, key: str) -> dict | None: + """Load feed data from bot.state.""" + raw = bot.state.get("rss", key) + if raw is None: + return None + try: + return json.loads(raw) + except json.JSONDecodeError: + return None + + +def _delete(bot, key: str) -> None: + """Remove feed data from bot.state.""" + bot.state.delete("rss", key) + + +# -- Feed fetching (blocking, for executor) ---------------------------------- + +def _fetch_feed(url: str, etag: str = "", last_modified: str = "") -> dict: + """Blocking HTTP GET for feed content. Run via executor.""" + result: dict = { + "status": 0, + "body": b"", + "etag": "", + "last_modified": "", + "error": "", + } + + req = urllib.request.Request(url, method="GET") + req.add_header("User-Agent", _USER_AGENT) + if etag: + req.add_header("If-None-Match", etag) + if last_modified: + req.add_header("If-Modified-Since", last_modified) + + ctx = ssl.create_default_context() + + try: + resp = urllib.request.urlopen(req, timeout=_FETCH_TIMEOUT, context=ctx) + result["status"] = resp.status + result["body"] = resp.read() + result["etag"] = resp.headers.get("ETag", "") + result["last_modified"] = resp.headers.get("Last-Modified", "") + resp.close() + except urllib.error.HTTPError as exc: + result["status"] = exc.code + if exc.code == 304: + result["etag"] = etag + result["last_modified"] = last_modified + else: + result["error"] = f"HTTP {exc.code}" + except urllib.error.URLError as exc: + result["error"] = str(exc.reason) + except Exception as exc: + result["error"] = str(exc) + + return result + + +# -- Feed parsing ------------------------------------------------------------ + +def _parse_rss(root: ET.Element) -> tuple[str, list[dict]]: + """Parse RSS 2.0 feed.""" + channel = root.find("channel") + if channel is None: + return ("", []) + title = (channel.findtext("title") or "").strip() + items = [] + for item in channel.findall("item"): + item_id = item.findtext("guid") or item.findtext("link") or "" + item_title = (item.findtext("title") or "").strip() + item_link = (item.findtext("link") or "").strip() + if item_id: + items.append({"id": item_id, "title": item_title, "link": item_link}) + return (title, items) + + +def _parse_atom(root: ET.Element) -> tuple[str, list[dict]]: + """Parse Atom feed.""" + title = (root.findtext(f"{_ATOM_NS}title") or "").strip() + items = [] + for entry in root.findall(f"{_ATOM_NS}entry"): + entry_id = (entry.findtext(f"{_ATOM_NS}id") or "").strip() + link_el = entry.find(f"{_ATOM_NS}link") + entry_link = (link_el.get("href", "") if link_el is not None else "").strip() + if not entry_id: + entry_id = entry_link + entry_title = (entry.findtext(f"{_ATOM_NS}title") or "").strip() + if entry_id: + items.append({"id": entry_id, "title": entry_title, "link": entry_link}) + return (title, items) + + +def _parse_feed(body: bytes) -> tuple[str, list[dict]]: + """Auto-detect RSS/Atom and parse. Returns (feed_title, items).""" + root = ET.fromstring(body) + tag = root.tag + local = tag.split("}")[-1].lower() if "}" in tag else tag.lower() + if local == "rss": + return _parse_rss(root) + if local == "feed": + return _parse_atom(root) + raise ValueError(f"Unknown feed format: {root.tag}") + + +# -- Polling ----------------------------------------------------------------- + +async def _poll_once(bot, key: str, announce: bool = True) -> None: + """Single poll cycle for one feed.""" + data = _feeds.get(key) + if data is None: + data = _load(bot, key) + if data is None: + return + _feeds[key] = data + + url = data["url"] + etag = data.get("etag", "") + last_modified = data.get("last_modified", "") + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, _fetch_feed, url, etag, last_modified, + ) + + now = datetime.now(timezone.utc).isoformat() + data["last_poll"] = now + + if result["error"]: + data["last_error"] = result["error"] + _errors[key] = _errors.get(key, 0) + 1 + _feeds[key] = data + _save(bot, key, data) + return + + # HTTP 304 -- not modified + if result["status"] == 304: + data["last_error"] = "" + _errors[key] = 0 + _feeds[key] = data + _save(bot, key, data) + return + + # Update conditional headers + data["etag"] = result["etag"] + data["last_modified"] = result["last_modified"] + data["last_error"] = "" + _errors[key] = 0 + + try: + feed_title, items = _parse_feed(result["body"]) + except Exception as exc: + data["last_error"] = f"Parse error: {exc}" + _errors[key] = _errors.get(key, 0) + 1 + _feeds[key] = data + _save(bot, key, data) + return + + if feed_title and not data.get("title"): + data["title"] = feed_title + + seen = set(data.get("seen", [])) + seen_list = list(data.get("seen", [])) + new_items = [item for item in items if item["id"] not in seen] + + if announce and new_items: + channel = data["channel"] + name = data["name"] + shown = new_items[:_MAX_ANNOUNCE] + for item in shown: + title = _truncate(item["title"]) if item["title"] else "(no title)" + link = item["link"] + line = f"[{name}] {title}" + if link: + line += f" -- {link}" + await bot.send(channel, line) + remaining = len(new_items) - len(shown) + if remaining > 0: + await bot.send(channel, f"[{name}] ... and {remaining} more") + + # Update seen list + for item in new_items: + seen_list.append(item["id"]) + if len(seen_list) > _MAX_SEEN: + seen_list = seen_list[-_MAX_SEEN:] + data["seen"] = seen_list + + _feeds[key] = data + _save(bot, key, data) + + +async def _poll_loop(bot, key: str) -> None: + """Infinite poll loop for one feed.""" + try: + while True: + data = _feeds.get(key) or _load(bot, key) + if data is None: + return + interval = data.get("interval", _DEFAULT_INTERVAL) + # Back off on consecutive errors + errs = _errors.get(key, 0) + if errs >= 5: + interval = min(interval * 2, _MAX_INTERVAL) + await asyncio.sleep(interval) + await _poll_once(bot, key, announce=True) + except asyncio.CancelledError: + pass + + +def _start_poller(bot, key: str) -> None: + """Create and track a poller task.""" + existing = _pollers.get(key) + if existing and not existing.done(): + return + task = asyncio.create_task(_poll_loop(bot, key)) + _pollers[key] = task + + +def _stop_poller(key: str) -> None: + """Cancel and remove a poller task.""" + task = _pollers.pop(key, None) + if task and not task.done(): + task.cancel() + _feeds.pop(key, None) + _errors.pop(key, 0) + + +# -- Restore on connect ----------------------------------------------------- + +def _restore(bot) -> None: + """Rebuild pollers from persisted state.""" + for key in bot.state.keys("rss"): + existing = _pollers.get(key) + if existing and not existing.done(): + continue + data = _load(bot, key) + if data is None: + continue + _feeds[key] = data + _start_poller(bot, key) + + +@event("001") +async def on_connect(bot, message): + """Restore RSS feed pollers on connect.""" + _restore(bot) + + +# -- Command handler --------------------------------------------------------- + +@command("rss", help="RSS: !rss add|del|list|check") +async def cmd_rss(bot, message): + """Per-channel RSS/Atom feed subscriptions. + + Usage: + !rss add [name] Subscribe a feed (admin) + !rss del Unsubscribe a feed (admin) + !rss list List feeds in this channel + !rss check Force-poll a feed now + """ + parts = message.text.split(None, 3) + if len(parts) < 2: + await bot.reply(message, "Usage: !rss [args]") + return + + sub = parts[1].lower() + + # -- list (any user, any context) ---------------------------------------- + if sub == "list": + if not message.is_channel: + await bot.reply(message, "Use this command in a channel") + return + channel = message.target + prefix = f"{channel}:" + feeds = [] + for key in bot.state.keys("rss"): + if key.startswith(prefix): + data = _load(bot, key) + if data: + name = data["name"] + err = data.get("last_error", "") + if err: + feeds.append(f"{name} (error)") + else: + feeds.append(name) + if not feeds: + await bot.reply(message, "No feeds in this channel") + return + await bot.reply(message, f"Feeds: {', '.join(feeds)}") + return + + # -- check (any user, channel only) -------------------------------------- + if sub == "check": + if not message.is_channel: + await bot.reply(message, "Use this command in a channel") + return + if len(parts) < 3: + await bot.reply(message, "Usage: !rss check ") + return + name = parts[2].lower() + channel = message.target + key = _state_key(channel, name) + data = _load(bot, key) + if data is None: + await bot.reply(message, f"No feed '{name}' in this channel") + return + _feeds[key] = data + await _poll_once(bot, key, announce=True) + data = _feeds.get(key, data) + if data.get("last_error"): + await bot.reply(message, f"{name}: error -- {data['last_error']}") + else: + await bot.reply(message, f"{name}: checked") + return + + # -- add (admin, channel only) ------------------------------------------- + if sub == "add": + if not bot._is_admin(message): + await bot.reply(message, "Permission denied: add requires admin") + return + if not message.is_channel: + await bot.reply(message, "Use this command in a channel") + return + if len(parts) < 3: + await bot.reply(message, "Usage: !rss add [name]") + return + + url = parts[2] + if not url.startswith(("http://", "https://")): + url = f"https://{url}" + + name = parts[3].lower() if len(parts) > 3 else _derive_name(url) + if not _validate_name(name): + await bot.reply( + message, + "Invalid name (lowercase alphanumeric + hyphens, 1-20 chars)", + ) + return + + channel = message.target + key = _state_key(channel, name) + + # Check for duplicate + if _load(bot, key) is not None: + await bot.reply(message, f"Feed '{name}' already exists in this channel") + return + + # Check per-channel limit + prefix = f"{channel}:" + count = sum(1 for k in bot.state.keys("rss") if k.startswith(prefix)) + if count >= _MAX_FEEDS: + await bot.reply(message, f"Channel feed limit reached ({_MAX_FEEDS})") + return + + # Test-fetch to validate URL and seed seen list + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, _fetch_feed, url, "", "") + + if result["error"]: + await bot.reply(message, f"Fetch failed: {result['error']}") + return + + feed_title = "" + seen = [] + try: + feed_title, items = _parse_feed(result["body"]) + seen = [item["id"] for item in items] + if len(seen) > _MAX_SEEN: + seen = seen[-_MAX_SEEN:] + except Exception as exc: + await bot.reply(message, f"Parse failed: {exc}") + return + + now = datetime.now(timezone.utc).isoformat() + data = { + "url": url, + "name": name, + "channel": channel, + "interval": _DEFAULT_INTERVAL, + "added_by": message.nick, + "added_at": now, + "seen": seen, + "last_poll": now, + "last_error": "", + "etag": result["etag"], + "last_modified": result["last_modified"], + "title": feed_title, + } + _save(bot, key, data) + _feeds[key] = data + _start_poller(bot, key) + + display = feed_title or name + item_count = len(seen) + await bot.reply( + message, + f"Subscribed '{name}' ({display}, {item_count} existing items)", + ) + return + + # -- del (admin, channel only) ------------------------------------------- + if sub == "del": + if not bot._is_admin(message): + await bot.reply(message, "Permission denied: del requires admin") + return + if not message.is_channel: + await bot.reply(message, "Use this command in a channel") + return + if len(parts) < 3: + await bot.reply(message, "Usage: !rss del ") + return + + name = parts[2].lower() + channel = message.target + key = _state_key(channel, name) + + if _load(bot, key) is None: + await bot.reply(message, f"No feed '{name}' in this channel") + return + + _stop_poller(key) + _delete(bot, key) + await bot.reply(message, f"Unsubscribed '{name}'") + return + + await bot.reply(message, "Usage: !rss [args]") diff --git a/tests/test_rss.py b/tests/test_rss.py new file mode 100644 index 0000000..614a021 --- /dev/null +++ b/tests/test_rss.py @@ -0,0 +1,1075 @@ +"""Tests for the RSS feed plugin.""" + +import asyncio +import importlib.util +import sys +from pathlib import Path +from unittest.mock import patch + +from derp.irc import Message + +# plugins/ is not a Python package -- load the module from file path +_spec = importlib.util.spec_from_file_location( + "plugins.rss", Path(__file__).resolve().parent.parent / "plugins" / "rss.py", +) +_mod = importlib.util.module_from_spec(_spec) +sys.modules[_spec.name] = _mod +_spec.loader.exec_module(_mod) + +from plugins.rss import ( # noqa: E402 + _MAX_ANNOUNCE, + _MAX_SEEN, + _delete, + _derive_name, + _errors, + _feeds, + _load, + _parse_atom, + _parse_feed, + _parse_rss, + _poll_once, + _pollers, + _restore, + _save, + _start_poller, + _state_key, + _stop_poller, + _truncate, + _validate_name, + cmd_rss, + on_connect, +) + +# -- Fixtures ---------------------------------------------------------------- + +RSS_FEED = b"""\ + + + + Test RSS Feed + https://example.com + + item-1 + First Post + https://example.com/1 + + + item-2 + Second Post + https://example.com/2 + + + item-3 + Third Post + https://example.com/3 + + + +""" + +RSS_NO_GUID = b"""\ + + + + No GUID Feed + + Linkonly + https://example.com/linkonly + + + +""" + +ATOM_FEED = b"""\ + + + Test Atom Feed + + atom-1 + Atom First + + + + atom-2 + Atom Second + + + +""" + +ATOM_NO_ID = b"""\ + + + Atom No ID + + No ID Entry + + + +""" + +INVALID_XML = b"Not a feed" + +EMPTY_RSS = b"""\ + + + + Empty Feed + + +""" + + +# -- Helpers ----------------------------------------------------------------- + +class _FakeState: + """In-memory stand-in for bot.state.""" + + def __init__(self): + self._store: dict[str, dict[str, str]] = {} + + def get(self, plugin: str, key: str, default: str | None = None) -> str | None: + return self._store.get(plugin, {}).get(key, default) + + def set(self, plugin: str, key: str, value: str) -> None: + self._store.setdefault(plugin, {})[key] = value + + def delete(self, plugin: str, key: str) -> bool: + try: + del self._store[plugin][key] + return True + except KeyError: + return False + + def keys(self, plugin: str) -> list[str]: + return sorted(self._store.get(plugin, {}).keys()) + + +class _FakeBot: + """Minimal bot stand-in that captures sent/replied messages.""" + + def __init__(self, *, admin: bool = False): + self.sent: list[tuple[str, str]] = [] + self.replied: list[str] = [] + self.state = _FakeState() + self._admin = admin + + async def send(self, target: str, text: str) -> None: + self.sent.append((target, text)) + + async def reply(self, message, text: str) -> None: + self.replied.append(text) + + def _is_admin(self, message) -> bool: + return self._admin + + +def _msg(text: str, nick: str = "alice", target: str = "#test") -> Message: + """Create a channel PRIVMSG.""" + return Message( + raw="", prefix=f"{nick}!~{nick}@host", nick=nick, + command="PRIVMSG", params=[target, text], tags={}, + ) + + +def _pm(text: str, nick: str = "alice") -> Message: + """Create a private PRIVMSG.""" + return Message( + raw="", prefix=f"{nick}!~{nick}@host", nick=nick, + command="PRIVMSG", params=["botname", text], tags={}, + ) + + +def _clear() -> None: + """Reset module-level state between tests.""" + for task in _pollers.values(): + if task and not task.done(): + task.cancel() + _pollers.clear() + _feeds.clear() + _errors.clear() + + +def _fake_fetch_ok(url, etag="", last_modified=""): + """Fake fetch that returns RSS_FEED.""" + return { + "status": 200, + "body": RSS_FEED, + "etag": '"abc"', + "last_modified": "Sat, 15 Feb 2026 12:00:00 GMT", + "error": "", + } + + +def _fake_fetch_error(url, etag="", last_modified=""): + """Fake fetch that returns an error.""" + return { + "status": 0, + "body": b"", + "etag": "", + "last_modified": "", + "error": "Connection refused", + } + + +def _fake_fetch_304(url, etag="", last_modified=""): + """Fake fetch that returns 304 Not Modified.""" + return { + "status": 304, + "body": b"", + "etag": etag, + "last_modified": last_modified, + "error": "", + } + + +# --------------------------------------------------------------------------- +# TestValidateName +# --------------------------------------------------------------------------- + +class TestValidateName: + def test_valid_simple(self): + assert _validate_name("hn") is True + + def test_valid_with_hyphens(self): + assert _validate_name("my-feed") is True + + def test_valid_with_numbers(self): + assert _validate_name("feed123") is True + + def test_valid_single_char(self): + assert _validate_name("a") is True + + def test_valid_max_length(self): + assert _validate_name("a" * 20) is True + + def test_invalid_too_long(self): + assert _validate_name("a" * 21) is False + + def test_invalid_uppercase(self): + assert _validate_name("Feed") is False + + def test_invalid_starts_with_hyphen(self): + assert _validate_name("-feed") is False + + def test_invalid_special_chars(self): + assert _validate_name("feed!") is False + + def test_invalid_spaces(self): + assert _validate_name("my feed") is False + + def test_invalid_empty(self): + assert _validate_name("") is False + + +# --------------------------------------------------------------------------- +# TestDeriveName +# --------------------------------------------------------------------------- + +class TestDeriveName: + def test_simple_domain(self): + assert _derive_name("https://hnrss.org/newest") == "hnrss" + + def test_www_stripped(self): + assert _derive_name("https://www.example.com/feed") == "example" + + def test_subdomain(self): + assert _derive_name("https://blog.example.com/rss") == "blog" + + def test_invalid_url(self): + result = _derive_name("not a url") + assert _validate_name(result) + + def test_empty_url(self): + result = _derive_name("") + assert _validate_name(result) + + def test_long_hostname_truncated(self): + result = _derive_name("https://abcdefghijklmnopqrstuvwxyz.com/feed") + assert len(result) <= 20 + + +# --------------------------------------------------------------------------- +# TestTruncate +# --------------------------------------------------------------------------- + +class TestTruncate: + def test_short_text_unchanged(self): + assert _truncate("hello", 80) == "hello" + + def test_exact_length_unchanged(self): + text = "a" * 80 + assert _truncate(text, 80) == text + + def test_long_text_truncated(self): + text = "a" * 100 + result = _truncate(text, 80) + assert len(result) == 80 + assert result.endswith("...") + + def test_default_max_length(self): + text = "a" * 100 + result = _truncate(text) + assert len(result) == 80 + + def test_trailing_space_stripped(self): + text = "word " * 20 + result = _truncate(text, 20) + assert not result.endswith(" ...") + + +# --------------------------------------------------------------------------- +# TestParseRSS +# --------------------------------------------------------------------------- + +class TestParseRSS: + def test_parses_items(self): + import xml.etree.ElementTree as ET + root = ET.fromstring(RSS_FEED) + title, items = _parse_rss(root) + assert title == "Test RSS Feed" + assert len(items) == 3 + assert items[0]["id"] == "item-1" + assert items[0]["title"] == "First Post" + assert items[0]["link"] == "https://example.com/1" + + def test_fallback_to_link_as_id(self): + import xml.etree.ElementTree as ET + root = ET.fromstring(RSS_NO_GUID) + title, items = _parse_rss(root) + assert title == "No GUID Feed" + assert len(items) == 1 + assert items[0]["id"] == "https://example.com/linkonly" + + def test_empty_channel(self): + import xml.etree.ElementTree as ET + root = ET.fromstring(EMPTY_RSS) + title, items = _parse_rss(root) + assert title == "Empty Feed" + assert items == [] + + +# --------------------------------------------------------------------------- +# TestParseAtom +# --------------------------------------------------------------------------- + +class TestParseAtom: + def test_parses_entries(self): + import xml.etree.ElementTree as ET + root = ET.fromstring(ATOM_FEED) + title, items = _parse_atom(root) + assert title == "Test Atom Feed" + assert len(items) == 2 + assert items[0]["id"] == "atom-1" + assert items[0]["title"] == "Atom First" + assert items[0]["link"] == "https://example.com/a1" + + def test_fallback_to_link_as_id(self): + import xml.etree.ElementTree as ET + root = ET.fromstring(ATOM_NO_ID) + title, items = _parse_atom(root) + assert len(items) == 1 + assert items[0]["id"] == "https://example.com/noid" + + +# --------------------------------------------------------------------------- +# TestParseFeedDetect +# --------------------------------------------------------------------------- + +class TestParseFeedDetect: + def test_detects_rss(self): + title, items = _parse_feed(RSS_FEED) + assert title == "Test RSS Feed" + assert len(items) == 3 + + def test_detects_atom(self): + title, items = _parse_feed(ATOM_FEED) + assert title == "Test Atom Feed" + assert len(items) == 2 + + def test_rejects_html(self): + import pytest + with pytest.raises(ValueError, match="Unknown feed format"): + _parse_feed(INVALID_XML) + + def test_rejects_malformed_xml(self): + import pytest + with pytest.raises(Exception): + _parse_feed(b"<<>>") + + +# --------------------------------------------------------------------------- +# TestDeduplication +# --------------------------------------------------------------------------- + +class TestDeduplication: + def test_new_items_detected(self): + seen = {"item-1"} + _, items = _parse_feed(RSS_FEED) + new = [i for i in items if i["id"] not in seen] + assert len(new) == 2 + assert new[0]["id"] == "item-2" + assert new[1]["id"] == "item-3" + + def test_all_seen_no_new(self): + seen = {"item-1", "item-2", "item-3"} + _, items = _parse_feed(RSS_FEED) + new = [i for i in items if i["id"] not in seen] + assert len(new) == 0 + + def test_fifo_cap(self): + seen_list = list(range(250)) + seen_list = seen_list[-_MAX_SEEN:] + assert len(seen_list) == _MAX_SEEN + assert seen_list[0] == 50 + + +# --------------------------------------------------------------------------- +# TestStateHelpers +# --------------------------------------------------------------------------- + +class TestStateHelpers: + def test_save_and_load(self): + bot = _FakeBot() + data = {"url": "https://example.com/feed", "name": "test"} + _save(bot, "#ch:test", data) + loaded = _load(bot, "#ch:test") + assert loaded == data + + def test_load_missing(self): + bot = _FakeBot() + assert _load(bot, "nonexistent") is None + + def test_delete(self): + bot = _FakeBot() + _save(bot, "#ch:test", {"name": "test"}) + _delete(bot, "#ch:test") + assert _load(bot, "#ch:test") is None + + def test_state_key(self): + assert _state_key("#ops", "hn") == "#ops:hn" + + def test_load_invalid_json(self): + bot = _FakeBot() + bot.state.set("rss", "bad", "not json{{{") + assert _load(bot, "bad") is None + + +# --------------------------------------------------------------------------- +# TestCmdRssAdd +# --------------------------------------------------------------------------- + +class TestCmdRssAdd: + def test_add_success(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://example.com/feed testfeed")) + await asyncio.sleep(0) + assert len(bot.replied) == 1 + assert "Subscribed 'testfeed'" in bot.replied[0] + assert "3 existing items" in bot.replied[0] + data = _load(bot, "#test:testfeed") + assert data is not None + assert data["name"] == "testfeed" + assert data["channel"] == "#test" + assert len(data["seen"]) == 3 + assert "#test:testfeed" in _pollers + _stop_poller("#test:testfeed") + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_add_derives_name(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://hnrss.org/newest")) + await asyncio.sleep(0) + assert "Subscribed 'hnrss'" in bot.replied[0] + _stop_poller("#test:hnrss") + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_add_requires_admin(self): + _clear() + bot = _FakeBot(admin=False) + asyncio.run(cmd_rss(bot, _msg("!rss add https://example.com/feed"))) + assert "Permission denied" in bot.replied[0] + + def test_add_requires_channel(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _pm("!rss add https://example.com/feed"))) + assert "Use this command in a channel" in bot.replied[0] + + def test_add_invalid_name(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _msg("!rss add https://example.com/feed BAD!"))) + assert "Invalid name" in bot.replied[0] + + def test_add_duplicate(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://example.com/feed myfeed")) + await asyncio.sleep(0) + bot.replied.clear() + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://other.com/feed myfeed")) + assert "already exists" in bot.replied[0] + _stop_poller("#test:myfeed") + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_add_fetch_error(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_error): + await cmd_rss(bot, _msg("!rss add https://example.com/feed")) + assert "Fetch failed" in bot.replied[0] + + asyncio.run(inner()) + + def test_add_no_url(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _msg("!rss add"))) + assert "Usage:" in bot.replied[0] + + def test_add_feed_limit(self): + _clear() + bot = _FakeBot(admin=True) + # Pre-fill state with max feeds + for i in range(20): + _save(bot, f"#test:feed{i}", {"name": f"feed{i}", "channel": "#test"}) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://example.com/feed overflow")) + assert "limit reached" in bot.replied[0] + + asyncio.run(inner()) + + def test_add_prepends_https(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add example.com/feed test")) + await asyncio.sleep(0) + data = _load(bot, "#test:test") + assert data["url"] == "https://example.com/feed" + _stop_poller("#test:test") + await asyncio.sleep(0) + + asyncio.run(inner()) + + +# --------------------------------------------------------------------------- +# TestCmdRssDel +# --------------------------------------------------------------------------- + +class TestCmdRssDel: + def test_del_success(self): + _clear() + bot = _FakeBot(admin=True) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss add https://example.com/feed delfeed")) + await asyncio.sleep(0) + bot.replied.clear() + await cmd_rss(bot, _msg("!rss del delfeed")) + assert "Unsubscribed 'delfeed'" in bot.replied[0] + assert _load(bot, "#test:delfeed") is None + assert "#test:delfeed" not in _pollers + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_del_requires_admin(self): + _clear() + bot = _FakeBot(admin=False) + asyncio.run(cmd_rss(bot, _msg("!rss del somefeed"))) + assert "Permission denied" in bot.replied[0] + + def test_del_requires_channel(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _pm("!rss del somefeed"))) + assert "Use this command in a channel" in bot.replied[0] + + def test_del_nonexistent(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _msg("!rss del nosuchfeed"))) + assert "No feed" in bot.replied[0] + + def test_del_no_name(self): + _clear() + bot = _FakeBot(admin=True) + asyncio.run(cmd_rss(bot, _msg("!rss del"))) + assert "Usage:" in bot.replied[0] + + +# --------------------------------------------------------------------------- +# TestCmdRssList +# --------------------------------------------------------------------------- + +class TestCmdRssList: + def test_list_empty(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _msg("!rss list"))) + assert "No feeds" in bot.replied[0] + + def test_list_populated(self): + _clear() + bot = _FakeBot() + _save(bot, "#test:hn", { + "name": "hn", "channel": "#test", "url": "https://hn.example.com", + "last_error": "", + }) + _save(bot, "#test:reddit", { + "name": "reddit", "channel": "#test", "url": "https://reddit.example.com", + "last_error": "", + }) + asyncio.run(cmd_rss(bot, _msg("!rss list"))) + assert "Feeds:" in bot.replied[0] + assert "hn" in bot.replied[0] + assert "reddit" in bot.replied[0] + + def test_list_shows_error(self): + _clear() + bot = _FakeBot() + _save(bot, "#test:broken", { + "name": "broken", "channel": "#test", "url": "https://broken.example.com", + "last_error": "Connection refused", + }) + asyncio.run(cmd_rss(bot, _msg("!rss list"))) + assert "broken (error)" in bot.replied[0] + + def test_list_requires_channel(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _pm("!rss list"))) + assert "Use this command in a channel" in bot.replied[0] + + def test_list_only_this_channel(self): + _clear() + bot = _FakeBot() + _save(bot, "#test:mine", { + "name": "mine", "channel": "#test", "url": "https://mine.example.com", + "last_error": "", + }) + _save(bot, "#other:theirs", { + "name": "theirs", "channel": "#other", "url": "https://theirs.example.com", + "last_error": "", + }) + asyncio.run(cmd_rss(bot, _msg("!rss list"))) + assert "mine" in bot.replied[0] + assert "theirs" not in bot.replied[0] + + +# --------------------------------------------------------------------------- +# TestCmdRssCheck +# --------------------------------------------------------------------------- + +class TestCmdRssCheck: + def test_check_success(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "chk", "channel": "#test", + "interval": 600, "seen": ["item-1", "item-2", "item-3"], + "last_poll": "", "last_error": "", "etag": "", "last_modified": "", + "title": "Test", + } + _save(bot, "#test:chk", data) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss check chk")) + assert "chk: checked" in bot.replied[0] + + asyncio.run(inner()) + + def test_check_nonexistent(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _msg("!rss check nope"))) + assert "No feed" in bot.replied[0] + + def test_check_requires_channel(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _pm("!rss check something"))) + assert "Use this command in a channel" in bot.replied[0] + + def test_check_shows_error(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "errfeed", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + _save(bot, "#test:errfeed", data) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_error): + await cmd_rss(bot, _msg("!rss check errfeed")) + assert "error" in bot.replied[0].lower() + + asyncio.run(inner()) + + def test_check_announces_new_items(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "news", "channel": "#test", + "interval": 600, "seen": ["item-1"], + "last_poll": "", "last_error": "", "etag": "", "last_modified": "", + "title": "Test", + } + _save(bot, "#test:news", data) + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await cmd_rss(bot, _msg("!rss check news")) + # Should have sent announcements for item-2 and item-3 + announcements = [s for t, s in bot.sent if t == "#test"] + assert len(announcements) == 2 + assert "[news]" in announcements[0] + assert "Second Post" in announcements[0] + + asyncio.run(inner()) + + def test_check_no_name(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _msg("!rss check"))) + assert "Usage:" in bot.replied[0] + + +# --------------------------------------------------------------------------- +# TestPollOnce +# --------------------------------------------------------------------------- + +class TestPollOnce: + def test_poll_304_clears_error(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "f304", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "old err", + "etag": '"xyz"', "last_modified": "", "title": "", + } + key = "#test:f304" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_304): + await _poll_once(bot, key) + updated = _load(bot, key) + assert updated["last_error"] == "" + + asyncio.run(inner()) + + def test_poll_error_increments(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "ferr", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + key = "#test:ferr" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_error): + await _poll_once(bot, key) + await _poll_once(bot, key) + assert _errors[key] == 2 + updated = _load(bot, key) + assert updated["last_error"] == "Connection refused" + + asyncio.run(inner()) + + def test_poll_max_announce(self): + """Only MAX_ANNOUNCE items are individually announced.""" + _clear() + bot = _FakeBot() + # Build a feed with 8 items + items_xml = "" + for i in range(8): + items_xml += f""" + + big-{i} + Item {i} + https://example.com/{i} + """ + big_feed = f"""\ + + + Big Feed{items_xml} +""".encode() + + def fake_big(url, etag="", lm=""): + return {"status": 200, "body": big_feed, "etag": "", "last_modified": "", "error": ""} + + data = { + "url": "https://example.com/big", "name": "big", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + key = "#test:big" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + with patch.object(_mod, "_fetch_feed", fake_big): + await _poll_once(bot, key, announce=True) + messages = [s for t, s in bot.sent if t == "#test"] + # 5 individual + 1 "... and N more" + assert len(messages) == _MAX_ANNOUNCE + 1 + assert "... and 3 more" in messages[-1] + + asyncio.run(inner()) + + def test_poll_no_announce_flag(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "quiet", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + key = "#test:quiet" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await _poll_once(bot, key, announce=False) + # No channel messages even though items are new + assert len(bot.sent) == 0 + # But seen should be updated + updated = _load(bot, key) + assert len(updated["seen"]) == 3 + + asyncio.run(inner()) + + def test_poll_updates_etag(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "etag", "channel": "#test", + "interval": 600, "seen": ["item-1", "item-2", "item-3"], + "last_poll": "", "last_error": "", "etag": "", "last_modified": "", + "title": "", + } + key = "#test:etag" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + with patch.object(_mod, "_fetch_feed", _fake_fetch_ok): + await _poll_once(bot, key) + updated = _load(bot, key) + assert updated["etag"] == '"abc"' + + asyncio.run(inner()) + + +# --------------------------------------------------------------------------- +# TestRestore +# --------------------------------------------------------------------------- + +class TestRestore: + def test_restore_spawns_pollers(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "restored", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + _save(bot, "#test:restored", data) + + async def inner(): + _restore(bot) + assert "#test:restored" in _pollers + task = _pollers["#test:restored"] + assert not task.done() + _stop_poller("#test:restored") + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_restore_skips_active(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "active", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + _save(bot, "#test:active", data) + + async def inner(): + # Pre-place an active task + dummy = asyncio.create_task(asyncio.sleep(9999)) + _pollers["#test:active"] = dummy + _restore(bot) + # Should not have replaced it + assert _pollers["#test:active"] is dummy + dummy.cancel() + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_restore_replaces_done_task(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "done", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + _save(bot, "#test:done", data) + + async def inner(): + # Place a completed task + done_task = asyncio.create_task(asyncio.sleep(0)) + await done_task + _pollers["#test:done"] = done_task + _restore(bot) + # Should have been replaced + new_task = _pollers["#test:done"] + assert new_task is not done_task + assert not new_task.done() + _stop_poller("#test:done") + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_restore_skips_bad_json(self): + _clear() + bot = _FakeBot() + bot.state.set("rss", "#test:bad", "not json{{{") + + async def inner(): + _restore(bot) + assert "#test:bad" not in _pollers + + asyncio.run(inner()) + + def test_on_connect_calls_restore(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "conn", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + _save(bot, "#test:conn", data) + + async def inner(): + msg = _msg("", target="botname") + await on_connect(bot, msg) + assert "#test:conn" in _pollers + _stop_poller("#test:conn") + await asyncio.sleep(0) + + asyncio.run(inner()) + + +# --------------------------------------------------------------------------- +# TestPollerManagement +# --------------------------------------------------------------------------- + +class TestPollerManagement: + def test_start_and_stop(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "mgmt", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + key = "#test:mgmt" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + _start_poller(bot, key) + assert key in _pollers + assert not _pollers[key].done() + _stop_poller(key) + await asyncio.sleep(0) + assert key not in _pollers + assert key not in _feeds + + asyncio.run(inner()) + + def test_start_idempotent(self): + _clear() + bot = _FakeBot() + data = { + "url": "https://example.com/feed", "name": "idem", "channel": "#test", + "interval": 600, "seen": [], "last_poll": "", "last_error": "", + "etag": "", "last_modified": "", "title": "", + } + key = "#test:idem" + _save(bot, key, data) + _feeds[key] = data + + async def inner(): + _start_poller(bot, key) + first = _pollers[key] + _start_poller(bot, key) + assert _pollers[key] is first + _stop_poller(key) + await asyncio.sleep(0) + + asyncio.run(inner()) + + def test_stop_nonexistent(self): + _clear() + # Should not raise + _stop_poller("#test:nonexistent") + + +# --------------------------------------------------------------------------- +# TestCmdRssUsage +# --------------------------------------------------------------------------- + +class TestCmdRssUsage: + def test_no_args(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _msg("!rss"))) + assert "Usage:" in bot.replied[0] + + def test_unknown_subcommand(self): + _clear() + bot = _FakeBot() + asyncio.run(cmd_rss(bot, _msg("!rss foobar"))) + assert "Usage:" in bot.replied[0]