"""Plugin: decode and inspect JSON Web Tokens.""" from __future__ import annotations import base64 import json import logging import time from derp.plugin import command log = logging.getLogger(__name__) _DANGEROUS_ALGS = {"none", ""} def _b64url_decode(s: str) -> bytes: """Base64url decode with padding correction.""" s = s.replace("-", "+").replace("_", "/") pad = 4 - len(s) % 4 if pad != 4: s += "=" * pad return base64.b64decode(s) def _decode_jwt(token: str) -> tuple[dict, dict, bytes]: """Decode JWT into (header, payload, signature_bytes). Raises ValueError on malformed tokens. """ parts = token.split(".") if len(parts) != 3: raise ValueError(f"expected 3 parts, got {len(parts)}") try: header = json.loads(_b64url_decode(parts[0])) except (json.JSONDecodeError, Exception) as exc: raise ValueError(f"invalid header: {exc}") from exc try: payload = json.loads(_b64url_decode(parts[1])) except (json.JSONDecodeError, Exception) as exc: raise ValueError(f"invalid payload: {exc}") from exc try: sig = _b64url_decode(parts[2]) if parts[2] else b"" except Exception: sig = b"" return header, payload, sig def _check_issues(header: dict, payload: dict) -> list[str]: """Return list of warning strings for common JWT issues.""" issues = [] now = time.time() alg = str(header.get("alg", "")).lower() if alg in _DANGEROUS_ALGS: issues.append(f'alg="{header.get("alg", "")}" (unsigned)') exp = payload.get("exp") if isinstance(exp, (int, float)): from datetime import datetime, timezone exp_dt = datetime.fromtimestamp(exp, tz=timezone.utc) if exp < now: issues.append(f"expired ({exp_dt:%Y-%m-%d %H:%M} UTC)") nbf = payload.get("nbf") if isinstance(nbf, (int, float)): from datetime import datetime, timezone nbf_dt = datetime.fromtimestamp(nbf, tz=timezone.utc) if nbf > now: issues.append(f"not yet valid (nbf={nbf_dt:%Y-%m-%d %H:%M} UTC)") return issues def _format_claims(payload: dict) -> str: """Format payload claims as compact key=value pairs.""" parts = [] for key, val in payload.items(): if key in ("exp", "nbf", "iat") and isinstance(val, (int, float)): from datetime import datetime, timezone dt = datetime.fromtimestamp(val, tz=timezone.utc) parts.append(f"{key}={dt:%Y-%m-%d %H:%M} UTC") elif isinstance(val, str): parts.append(f"{key}={val}") else: parts.append(f"{key}={json.dumps(val, separators=(',', ':'))}") return " | ".join(parts) @command("jwt", help="Decode JWT: !jwt ") async def cmd_jwt(bot, message): """Decode a JSON Web Token and display header, claims, and issues.""" parts = message.text.split(None, 2) if len(parts) < 2: await bot.reply(message, "Usage: !jwt ") return token = parts[1].strip() try: header, payload, sig = _decode_jwt(token) except ValueError as exc: await bot.reply(message, f"Invalid JWT: {exc}") return # Line 1: header alg = header.get("alg", "?") typ = header.get("typ", "?") sig_len = len(sig) hdr_line = f"Header: alg={alg} typ={typ} | sig={sig_len} bytes" # Line 2: claims if payload: claims_line = _format_claims(payload) else: claims_line = "(empty payload)" await bot.reply(message, hdr_line) await bot.reply(message, claims_line) # Line 3: warnings issues = _check_issues(header, payload) if issues: await bot.reply(message, "WARN: " + " | ".join(issues))