diff --git a/plugins/remind.py b/plugins/remind.py new file mode 100644 index 0000000..7178fb5 --- /dev/null +++ b/plugins/remind.py @@ -0,0 +1,186 @@ +"""Plugin: one-shot and repeating reminders with short ID tracking.""" + +from __future__ import annotations + +import asyncio +import hashlib +import re +import time +from datetime import datetime, timezone + +from derp.plugin import command + +_DURATION_RE = re.compile(r"(?:(\d+)d)?(?:(\d+)h)?(?:(\d+)m)?(?:(\d+)s)?$") + + +def _make_id(nick: str, label: str) -> str: + """Generate a short hex ID from nick + label + timestamp.""" + raw = f"{nick}:{label}:{time.monotonic()}".encode() + return hashlib.sha256(raw).hexdigest()[:6] + + +def _parse_duration(spec: str) -> int | None: + """Parse a duration like '5m', '1h30m', '2d', '90s', or raw seconds.""" + try: + secs = int(spec) + return secs if secs > 0 else None + except ValueError: + pass + m = _DURATION_RE.match(spec.lower()) + if not m or not any(m.groups()): + return None + days = int(m.group(1) or 0) + hours = int(m.group(2) or 0) + mins = int(m.group(3) or 0) + secs = int(m.group(4) or 0) + total = days * 86400 + hours * 3600 + mins * 60 + secs + return total if total > 0 else None + + +def _format_duration(secs: int) -> str: + """Format seconds into compact duration.""" + parts = [] + if secs >= 86400: + parts.append(f"{secs // 86400}d") + secs %= 86400 + if secs >= 3600: + parts.append(f"{secs // 3600}h") + secs %= 3600 + if secs >= 60: + parts.append(f"{secs // 60}m") + secs %= 60 + if secs or not parts: + parts.append(f"{secs}s") + return "".join(parts) + + +# In-memory tracking: {rid: (task, target, nick, label, created, repeating)} +_reminders: dict[str, tuple[asyncio.Task, str, str, str, str, bool]] = {} +# Reverse lookup: (target, nick) -> [rid, ...] +_by_user: dict[tuple[str, str], list[str]] = {} + + +def _cleanup(rid: str, target: str, nick: str) -> None: + """Remove a reminder from tracking structures.""" + _reminders.pop(rid, None) + ukey = (target, nick) + if ukey in _by_user: + _by_user[ukey] = [r for r in _by_user[ukey] if r != rid] + if not _by_user[ukey]: + del _by_user[ukey] + + +async def _remind_once(bot, rid: str, target: str, nick: str, label: str, + duration: int, created: str) -> None: + """One-shot reminder: sleep, fire, clean up.""" + try: + await asyncio.sleep(duration) + await bot.send(target, f"{nick}: reminder #{rid} (set {created})") + if label: + await bot.send(target, label) + except asyncio.CancelledError: + pass + finally: + _cleanup(rid, target, nick) + + +async def _remind_repeat(bot, rid: str, target: str, nick: str, label: str, + interval: int, created: str) -> None: + """Repeating reminder: fire every interval until cancelled.""" + try: + while True: + await asyncio.sleep(interval) + await bot.send(target, f"{nick}: reminder #{rid} (every {_format_duration(interval)})") + if label: + await bot.send(target, label) + except asyncio.CancelledError: + pass + finally: + _cleanup(rid, target, nick) + + +@command("remind", help="Reminder: !remind [every] | list | cancel ") +async def cmd_remind(bot, message): + """Set a one-shot or repeating reminder. + + Usage: + !remind 5m check the oven + !remind every 1h drink water + !remind 2d12h renew cert + !remind list + !remind cancel + """ + parts = message.text.split(None, 2) + if len(parts) < 2: + await bot.reply(message, "Usage: !remind [every] | list | cancel ") + return + + target = message.target if message.is_channel else message.nick + nick = message.nick + sub = parts[1].lower() + ukey = (target, nick) + + # List active reminders + if sub == "list": + rids = _by_user.get(ukey, []) + active = [] + for rid in rids: + entry = _reminders.get(rid) + if entry and not entry[0].done(): + tag = f"#{rid} (repeat)" if entry[5] else f"#{rid}" + active.append(tag) + if not active: + await bot.reply(message, "No active reminders") + return + await bot.reply(message, f"Reminders: {', '.join(active)}") + return + + # Cancel by ID + if sub == "cancel": + rid = parts[2].lstrip("#") if len(parts) > 2 else "" + if not rid: + await bot.reply(message, "Usage: !remind cancel ") + return + entry = _reminders.get(rid) + if entry and not entry[0].done() and entry[2] == nick: + entry[0].cancel() + await bot.reply(message, f"Cancelled #{rid}") + else: + await bot.reply(message, f"No active reminder #{rid}") + return + + # Detect repeating flag + repeating = False + if sub == "every": + repeating = True + rest = parts[2] if len(parts) > 2 else "" + parts = ["", "", *rest.split(None, 1)] # re-split: [_, _, duration, text] + + duration = _parse_duration(parts[1] if not repeating else parts[2]) + if duration is None: + await bot.reply(message, "Invalid duration (use: 5m, 1h30m, 2d, 90s)") + return + + label = "" + if repeating: + label = parts[3] if len(parts) > 3 else "" + else: + label = parts[2] if len(parts) > 2 else "" + + rid = _make_id(nick, label) + created = datetime.now(timezone.utc).strftime("%H:%M:%S UTC") + + if repeating: + task = asyncio.create_task( + _remind_repeat(bot, rid, target, nick, label, duration, created), + ) + else: + task = asyncio.create_task( + _remind_once(bot, rid, target, nick, label, duration, created), + ) + + _reminders[rid] = (task, target, nick, label, created, repeating) + _by_user.setdefault(ukey, []).append(rid) + + kind = f"every {_format_duration(duration)}" if repeating else _format_duration(duration) + await bot.reply(message, f"Reminder #{rid} set ({kind})") diff --git a/tests/test_remind.py b/tests/test_remind.py new file mode 100644 index 0000000..0836882 --- /dev/null +++ b/tests/test_remind.py @@ -0,0 +1,104 @@ +"""Tests for the remind plugin.""" + +import importlib.util +import sys +from pathlib import Path + +# plugins/ is not a Python package -- load the module from file path +_spec = importlib.util.spec_from_file_location( + "plugins.remind", Path(__file__).resolve().parent.parent / "plugins" / "remind.py", +) +_mod = importlib.util.module_from_spec(_spec) +sys.modules[_spec.name] = _mod +_spec.loader.exec_module(_mod) + +from plugins.remind import ( # noqa: E402 + _format_duration, + _make_id, + _parse_duration, +) + + +class TestParseDurationRawSeconds: + def test_positive_integer(self): + assert _parse_duration("60") == 60 + + def test_zero_returns_none(self): + assert _parse_duration("0") is None + + def test_negative_returns_none(self): + assert _parse_duration("-5") is None + + def test_large_value(self): + assert _parse_duration("86400") == 86400 + + +class TestParseDurationSpecs: + def test_minutes(self): + assert _parse_duration("5m") == 300 + + def test_hours_and_minutes(self): + assert _parse_duration("1h30m") == 5400 + + def test_days(self): + assert _parse_duration("2d") == 172800 + + def test_seconds_suffix(self): + assert _parse_duration("90s") == 90 + + def test_full_combo(self): + assert _parse_duration("1d12h30m15s") == 131415 + + +class TestParseDurationInvalid: + def test_empty_string(self): + assert _parse_duration("") is None + + def test_letters_only(self): + assert _parse_duration("abc") is None + + def test_all_zeros(self): + assert _parse_duration("0m0s") is None + + +class TestParseDurationEdgeCases: + def test_uppercase_works(self): + assert _parse_duration("5M") == 300 + + def test_hours_with_zero_minutes(self): + assert _parse_duration("1h0m") == 3600 + + def test_all_zeros_except_one_second(self): + assert _parse_duration("0d0h0m1s") == 1 + + +class TestFormatDuration: + def test_minutes_and_seconds(self): + assert _format_duration(90) == "1m30s" + + def test_exact_hour(self): + assert _format_duration(3600) == "1h" + + def test_exact_day(self): + assert _format_duration(86400) == "1d" + + def test_zero(self): + assert _format_duration(0) == "0s" + + def test_full_combo(self): + assert _format_duration(90061) == "1d1h1m1s" + + def test_seconds_only(self): + assert _format_duration(45) == "45s" + + +class TestMakeId: + def test_returns_six_char_hex(self): + rid = _make_id("user", "check oven") + assert len(rid) == 6 + assert all(c in "0123456789abcdef" for c in rid) + + def test_different_inputs_differ(self): + rid1 = _make_id("alice", "task one") + rid2 = _make_id("bob", "task two") + assert rid1 != rid2