diff --git a/fpaste b/fpaste index 7684818..61a16cf 100755 --- a/fpaste +++ b/fpaste @@ -1,6 +1,8 @@ #!/usr/bin/env python3 """FlaskPaste command-line client.""" +from __future__ import annotations + import argparse import base64 import hashlib @@ -12,6 +14,10 @@ import urllib.error import urllib.request from datetime import UTC, datetime, timedelta from pathlib import Path +from typing import TYPE_CHECKING, Any, NoReturn + +if TYPE_CHECKING: + from collections.abc import Mapping # Optional cryptography support (for encryption and cert generation) try: @@ -25,74 +31,103 @@ try: except ImportError: HAS_CRYPTO = False +# Constants +CONFIG_DIR = Path.home() / ".config" / "fpaste" +CONFIG_FILE = CONFIG_DIR / "config" +CONFIG_KEYS = frozenset({"server", "cert_sha1", "client_cert", "client_key", "ca_cert"}) -def get_config(): - """Load configuration from environment or config file.""" - config = { - "server": os.environ.get("FLASKPASTE_SERVER", "http://localhost:5000"), - "cert_sha1": os.environ.get("FLASKPASTE_CERT_SHA1", ""), - "client_cert": os.environ.get("FLASKPASTE_CLIENT_CERT", ""), - "client_key": os.environ.get("FLASKPASTE_CLIENT_KEY", ""), - "ca_cert": os.environ.get("FLASKPASTE_CA_CERT", ""), +MIME_EXTENSIONS: dict[str, str] = { + "text/plain": ".txt", + "text/html": ".html", + "text/css": ".css", + "text/javascript": ".js", + "text/markdown": ".md", + "text/x-python": ".py", + "application/json": ".json", + "application/xml": ".xml", + "application/javascript": ".js", + "application/octet-stream": ".bin", + "image/png": ".png", + "image/jpeg": ".jpg", + "image/gif": ".gif", + "image/webp": ".webp", + "image/svg+xml": ".svg", + "application/pdf": ".pdf", + "application/zip": ".zip", + "application/gzip": ".gz", + "application/x-tar": ".tar", +} + +FILE_EXTENSIONS = frozenset( + { + "txt", + "md", + "py", + "js", + "json", + "yaml", + "yml", + "xml", + "html", + "css", + "sh", + "bash", + "c", + "cpp", + "h", + "go", + "rs", + "java", + "rb", + "php", + "sql", + "log", + "conf", + "cfg", + "ini", + "png", + "jpg", + "jpeg", + "gif", + "pdf", + "zip", + "tar", + "gz", } +) - # Try config file - config_file = Path.home() / ".config" / "fpaste" / "config" - if config_file.exists(): - for line in config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - key, value = line.split("=", 1) - key = key.strip().lower() - value = value.strip().strip('"').strip("'") - if key == "server": - config["server"] = value - elif key == "cert_sha1": - config["cert_sha1"] = value - elif key == "client_cert": - config["client_cert"] = value - elif key == "client_key": - config["client_key"] = value - elif key == "ca_cert": - config["ca_cert"] = value - - return config +DATE_FORMATS = ( + "%Y-%m-%d", + "%Y-%m-%d %H:%M", + "%Y-%m-%d %H:%M:%S", + "%Y-%m-%dT%H:%M:%S", + "%Y-%m-%dT%H:%M:%SZ", +) -def create_ssl_context(config): - """Create SSL context for mTLS if certificates are configured.""" - client_cert = config.get("client_cert", "") - client_key = config.get("client_key", "") - ca_cert = config.get("ca_cert", "") - - if not client_cert: - return None - - ctx = ssl.create_default_context() - - # Load CA certificate if specified - if ca_cert: - ctx.load_verify_locations(ca_cert) - - # Load client certificate and key - try: - ctx.load_cert_chain(certfile=client_cert, keyfile=client_key or None) - except ssl.SSLError as e: - die(f"failed to load client certificate: {e}") - except FileNotFoundError as e: - die(f"certificate file not found: {e}") - - return ctx +# ----------------------------------------------------------------------------- +# Core utilities +# ----------------------------------------------------------------------------- -def request(url, method="GET", data=None, headers=None, ssl_context=None): - """Make HTTP request and return response.""" +def die(msg: str, code: int = 1) -> NoReturn: + """Print error and exit.""" + print(f"error: {msg}", file=sys.stderr) + sys.exit(code) + + +def request( + url: str, + method: str = "GET", + data: bytes | None = None, + headers: dict[str, str] | None = None, + ssl_context: ssl.SSLContext | None = None, +) -> tuple[int, bytes, dict[str, str]]: + """Make HTTP request and return (status, body, headers).""" headers = headers or {} - # User-configured server URL, audit is expected req = urllib.request.Request(url, data=data, headers=headers, method=method) # noqa: S310 try: - # User-configured server URL, audit is expected with urllib.request.urlopen(req, timeout=30, context=ssl_context) as resp: # noqa: S310 return resp.status, resp.read(), dict(resp.headers) except urllib.error.HTTPError as e: @@ -101,24 +136,112 @@ def request(url, method="GET", data=None, headers=None, ssl_context=None): die(f"Connection failed: {e.reason}") -def die(msg, code=1): - """Print error and exit.""" - print(f"error: {msg}", file=sys.stderr) - sys.exit(code) +def parse_error(body: bytes, default: str = "request failed") -> str: + """Parse error message from JSON response body.""" + try: + return json.loads(body).get("error", default) + except (json.JSONDecodeError, UnicodeDecodeError): + return default -def encrypt_content(plaintext): +# ----------------------------------------------------------------------------- +# Configuration +# ----------------------------------------------------------------------------- + + +def read_config_file(path: Path | None = None) -> dict[str, str]: + """Read config file and return key-value pairs.""" + path = path or CONFIG_FILE + result: dict[str, str] = {} + + if not path.exists(): + return result + + for line in path.read_text().splitlines(): + line = line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + key, value = line.split("=", 1) + key = key.strip().lower() + if key in CONFIG_KEYS: + result[key] = value.strip().strip('"').strip("'") + + return result + + +def write_config_file( + updates: dict[str, str], + path: Path | None = None, +) -> Path: + """Update config file with new values, preserving existing entries.""" + path = path or CONFIG_FILE + path.parent.mkdir(parents=True, exist_ok=True) + + existing = read_config_file(path) + existing.update(updates) + + lines = [f"{k} = {v}" for k, v in sorted(existing.items())] + path.write_text("\n".join(lines) + "\n") + return path + + +def get_config() -> dict[str, Any]: + """Load configuration from environment and config file.""" + config: dict[str, Any] = { + "server": os.environ.get("FLASKPASTE_SERVER", "http://localhost:5000"), + "cert_sha1": os.environ.get("FLASKPASTE_CERT_SHA1", ""), + "client_cert": os.environ.get("FLASKPASTE_CLIENT_CERT", ""), + "client_key": os.environ.get("FLASKPASTE_CLIENT_KEY", ""), + "ca_cert": os.environ.get("FLASKPASTE_CA_CERT", ""), + } + + # Config file values (lower priority than environment) + file_config = read_config_file() + for key in CONFIG_KEYS: + if not config.get(key) and key in file_config: + config[key] = file_config[key] + + return config + + +def create_ssl_context(config: Mapping[str, Any]) -> ssl.SSLContext | None: + """Create SSL context for mTLS if certificates are configured.""" + client_cert = config.get("client_cert", "") + if not client_cert: + return None + + ctx = ssl.create_default_context() + + if ca_cert := config.get("ca_cert", ""): + ctx.load_verify_locations(ca_cert) + + try: + ctx.load_cert_chain(certfile=client_cert, keyfile=config.get("client_key") or None) + except ssl.SSLError as e: + die(f"failed to load client certificate: {e}") + except FileNotFoundError as e: + die(f"certificate file not found: {e}") + + return ctx + + +# ----------------------------------------------------------------------------- +# Encryption +# ----------------------------------------------------------------------------- + + +def encrypt_content(plaintext: bytes) -> tuple[bytes, bytes]: """Encrypt content with AES-256-GCM. Returns (ciphertext, key).""" if not HAS_CRYPTO: die("encryption requires 'cryptography' package: pip install cryptography") key = os.urandom(32) - nonce = os.urandom(12) # 96-bit nonce for GCM + nonce = os.urandom(12) aesgcm = AESGCM(key) ciphertext = aesgcm.encrypt(nonce, plaintext, None) return nonce + ciphertext, key -def decrypt_content(blob, key): +def decrypt_content(blob: bytes, key: bytes) -> bytes: """Decrypt AES-256-GCM encrypted content.""" if not HAS_CRYPTO: die("decryption requires 'cryptography' package: pip install cryptography") @@ -132,14 +255,13 @@ def decrypt_content(blob, key): die("decryption failed (wrong key or corrupted data)") -def encode_key(key): +def encode_key(key: bytes) -> str: """Encode key as URL-safe base64.""" return base64.urlsafe_b64encode(key).decode().rstrip("=") -def decode_key(encoded): +def decode_key(encoded: str) -> bytes: """Decode URL-safe base64 key.""" - # Add padding if needed padding = 4 - (len(encoded) % 4) if padding != 4: encoded += "=" * padding @@ -149,19 +271,20 @@ def decode_key(encoded): die("invalid encryption key in URL") -def solve_pow(nonce, difficulty): - """Solve proof-of-work challenge. +# ----------------------------------------------------------------------------- +# Proof-of-work +# ----------------------------------------------------------------------------- - Find a number N such that SHA256(nonce:N) has `difficulty` leading zero bits. - """ + +def solve_pow(nonce: str, difficulty: int) -> int: + """Solve proof-of-work: find N where SHA256(nonce:N) has `difficulty` leading zero bits.""" n = 0 - target_bytes = (difficulty + 7) // 8 # Bytes to check + target_bytes = (difficulty + 7) // 8 while True: work = f"{nonce}:{n}".encode() hash_bytes = hashlib.sha256(work).digest() - # Count leading zero bits zero_bits = 0 for byte in hash_bytes[: target_bytes + 1]: if byte == 0: @@ -174,14 +297,14 @@ def solve_pow(nonce, difficulty): return n n += 1 - # Progress indicator for high difficulty if n % 100000 == 0: print(f"\rsolving pow: {n} attempts...", end="", file=sys.stderr) - return n - -def get_challenge(config, endpoint="/challenge"): +def get_challenge( + config: Mapping[str, Any], + endpoint: str = "/challenge", +) -> dict[str, Any] | None: """Fetch PoW challenge from server.""" url = config["server"].rstrip("/") + endpoint status, body, _ = request(url, ssl_context=config.get("ssl_context")) @@ -190,89 +313,209 @@ def get_challenge(config, endpoint="/challenge"): return None data = json.loads(body) - if not data.get("enabled"): - return None - - return data + return data if data.get("enabled") else None -def get_register_challenge(config): - """Fetch registration PoW challenge from server.""" - return get_challenge(config, endpoint="/register/challenge") +# ----------------------------------------------------------------------------- +# Formatting utilities +# ----------------------------------------------------------------------------- -def cmd_create(args, config): +def format_size(size: int) -> str: + """Format byte size as human-readable string.""" + if size < 1024: + return f"{size}B" + if size < 1024 * 1024: + return f"{size / 1024:.1f}K" + return f"{size / (1024 * 1024):.1f}M" + + +def format_timestamp(ts: int | float) -> str: + """Format Unix timestamp as human-readable date.""" + dt = datetime.fromtimestamp(ts, tz=UTC) + return dt.strftime("%Y-%m-%d %H:%M") + + +def parse_date(date_str: str) -> int: + """Parse date string to Unix timestamp.""" + if not date_str: + return 0 + + for fmt in DATE_FORMATS: + try: + dt = datetime.strptime(date_str, fmt).replace(tzinfo=UTC) + return int(dt.timestamp()) + except ValueError: + continue + + try: + return int(date_str) + except ValueError: + die(f"invalid date format: {date_str}") + + +def get_extension_for_mime(mime_type: str) -> str: + """Get file extension for MIME type.""" + return MIME_EXTENSIONS.get(mime_type, ".bin") + + +def format_paste_row(paste: dict[str, Any]) -> str: + """Format a paste as a table row.""" + paste_id = paste["id"] + mime_type = paste.get("mime_type", "unknown")[:16] + size = format_size(paste.get("size", 0)) + created = format_timestamp(paste.get("created_at", 0)) + + flags = [] + if paste.get("burn_after_read"): + flags.append("burn") + if paste.get("password_protected"): + flags.append("pass") + if paste.get("expires_at"): + flags.append("exp") + + return f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {' '.join(flags)}" + + +def print_paste_list( + pastes: list[dict[str, Any]], + summary: str, + as_json: bool = False, + data: dict[str, Any] | None = None, +) -> None: + """Print a list of pastes in table or JSON format.""" + if as_json: + print(json.dumps(data or {"pastes": pastes}, indent=2)) + return + + if not pastes: + print("no pastes found") + return + + print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS") + for paste in pastes: + print(format_paste_row(paste)) + print(f"\n{summary}") + + +# ----------------------------------------------------------------------------- +# Content helpers +# ----------------------------------------------------------------------------- + + +def read_content(file_arg: str | None) -> bytes: + """Read content from file or stdin.""" + if file_arg: + if file_arg == "-": + return sys.stdin.buffer.read() + path = Path(file_arg) + if not path.exists(): + die(f"file not found: {file_arg}") + return path.read_bytes() + + if sys.stdin.isatty(): + die("no input provided (pipe data or specify file)") + return sys.stdin.buffer.read() + + +def prepare_content( + content: bytes, + encrypt: bool, + quiet: bool = False, +) -> tuple[bytes, bytes | None]: + """Optionally encrypt content. Returns (content, encryption_key or None).""" + if not encrypt: + return content, None + + if not HAS_CRYPTO: + die("encryption requires 'cryptography' package (use -E to disable)") + + if not quiet: + print("encrypting...", end="", file=sys.stderr) + encrypted, key = encrypt_content(content) + if not quiet: + print(" done", file=sys.stderr) + return encrypted, key + + +def extract_paste_id(url_or_id: str) -> tuple[str, bytes | None]: + """Extract paste ID and optional encryption key from URL or ID.""" + encryption_key = None + + if "#" in url_or_id: + url_or_id, key_encoded = url_or_id.rsplit("#", 1) + if key_encoded: + encryption_key = decode_key(key_encoded) + + paste_id = url_or_id.split("/")[-1] + return paste_id, encryption_key + + +def auth_headers(config: Mapping[str, Any]) -> dict[str, str]: + """Build authentication headers.""" + if cert_sha1 := config.get("cert_sha1"): + return {"X-SSL-Client-SHA1": cert_sha1} + return {} + + +def require_auth(config: Mapping[str, Any]) -> None: + """Ensure authentication is configured.""" + if not config.get("cert_sha1"): + die("authentication required (set FLASKPASTE_CERT_SHA1)") + + +# ----------------------------------------------------------------------------- +# Commands +# ----------------------------------------------------------------------------- + + +def cmd_create(args: argparse.Namespace, config: dict[str, Any]) -> None: """Create a new paste.""" - # Read content from file or stdin - if args.file: - if args.file == "-": - content = sys.stdin.buffer.read() - else: - path = Path(args.file) - if not path.exists(): - die(f"file not found: {args.file}") - content = path.read_bytes() - else: - # No file specified, read from stdin - if sys.stdin.isatty(): - die("no input provided (pipe data or specify file)") - content = sys.stdin.buffer.read() - + content = read_content(args.file) if not content: die("empty content") - # Encrypt by default (unless --no-encrypt) - encryption_key = None - if not getattr(args, "no_encrypt", False): - if not HAS_CRYPTO: - die("encryption requires 'cryptography' package (use -E to disable)") - if not args.quiet: - print("encrypting...", end="", file=sys.stderr) - content, encryption_key = encrypt_content(content) - if not args.quiet: - print(" done", file=sys.stderr) + content, encryption_key = prepare_content( + content, + encrypt=not getattr(args, "no_encrypt", False), + quiet=args.quiet, + ) - # Build base headers (without PoW) - base_headers = {} - if config["cert_sha1"]: - base_headers["X-SSL-Client-SHA1"] = config["cert_sha1"] - - # Add burn-after-read header + # Build headers + base_headers = auth_headers(config) if args.burn: base_headers["X-Burn-After-Read"] = "true" - - # Add custom expiry header if args.expiry: base_headers["X-Expiry"] = str(args.expiry) - - # Add password header if args.password: base_headers["X-Paste-Password"] = args.password url = config["server"].rstrip("/") + "/" max_retries = 5 - last_error = None + last_error = "" for attempt in range(max_retries): headers = dict(base_headers) - # Get and solve PoW challenge if required - challenge = get_challenge(config) - if challenge: + if challenge := get_challenge(config): if attempt > 0 and not args.quiet: print(f"retry {attempt}/{max_retries - 1}...", file=sys.stderr) + if not args.quiet: diff = challenge["difficulty"] base_diff = challenge.get("base_difficulty", diff) elevated = challenge.get("elevated", False) - if elevated: - msg = f"solving pow ({diff} bits, elevated from {base_diff})..." - else: - msg = f"solving pow ({diff} bits)..." + msg = ( + f"solving pow ({diff} bits, elevated from {base_diff})..." + if elevated + else f"solving pow ({diff} bits)..." + ) print(msg, end="", file=sys.stderr) + solution = solve_pow(challenge["nonce"], challenge["difficulty"]) if not args.quiet: print(" done", file=sys.stderr) + headers["X-PoW-Token"] = challenge["token"] headers["X-PoW-Solution"] = str(solution) @@ -282,63 +525,43 @@ def cmd_create(args, config): if status == 201: data = json.loads(body) - # Append encryption key to URL fragment if encrypted - key_fragment = "" - if encryption_key: - key_fragment = "#" + encode_key(encryption_key) + key_fragment = f"#{encode_key(encryption_key)}" if encryption_key else "" + base_url = config["server"].rstrip("/") if args.raw: - print(config["server"].rstrip("/") + data["raw"] + key_fragment) + print(base_url + data["raw"] + key_fragment) elif args.quiet: print(data["id"] + key_fragment) else: - print(config["server"].rstrip("/") + data["url"] + key_fragment) + print(base_url + data["url"] + key_fragment) return - # Parse error - try: - last_error = json.loads(body).get("error", body.decode()) - except (json.JSONDecodeError, UnicodeDecodeError): - last_error = body.decode(errors="replace") - - # Check if PoW-related error (worth retrying) + last_error = parse_error(body, body.decode(errors="replace")) err_lower = last_error.lower() is_pow_error = status == 400 and ("pow" in err_lower or "proof-of-work" in err_lower) + if not is_pow_error: - # Non-PoW error, don't retry die(f"create failed ({status}): {last_error}") - # PoW error, will retry if attempts remain if not args.quiet: print(f"pow rejected: {last_error}", file=sys.stderr) - # All retries exhausted die(f"create failed after {max_retries} attempts: {last_error}") -def cmd_get(args, config): +def cmd_get(args: argparse.Namespace, config: dict[str, Any]) -> None: """Retrieve a paste.""" - # Parse URL for paste ID and optional encryption key fragment - url_input = args.id - encryption_key = None - - # Extract key from URL fragment (#...) - if "#" in url_input: - url_input, key_encoded = url_input.rsplit("#", 1) - if key_encoded: - encryption_key = decode_key(key_encoded) - - paste_id = url_input.split("/")[-1] # Handle full URLs + paste_id, encryption_key = extract_paste_id(args.id) base = config["server"].rstrip("/") - # Build headers for password-protected pastes - headers = {} + headers: dict[str, str] = {} if args.password: headers["X-Paste-Password"] = args.password if args.meta: url = f"{base}/{paste_id}" status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + if status == 200: data = json.loads(body) print(f"id: {data['id']}") @@ -357,10 +580,9 @@ def cmd_get(args, config): die(f"not found: {paste_id}") else: url = f"{base}/{paste_id}/raw" - ssl_ctx = config.get("ssl_context") - status, body, _ = request(url, headers=headers, ssl_context=ssl_ctx) + status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + if status == 200: - # Decrypt if encryption key was provided if encryption_key: body = decrypt_content(body, encryption_key) @@ -368,9 +590,7 @@ def cmd_get(args, config): Path(args.output).write_bytes(body) print(f"saved: {args.output}", file=sys.stderr) else: - # Write binary to stdout sys.stdout.buffer.write(body) - # Add newline if content doesn't end with one and stdout is tty if sys.stdout.isatty() and body and not body.endswith(b"\n"): sys.stdout.buffer.write(b"\n") elif status == 401: @@ -381,18 +601,15 @@ def cmd_get(args, config): die(f"not found: {paste_id}") -def cmd_delete(args, config): +def cmd_delete(args: argparse.Namespace, config: dict[str, Any]) -> None: """Delete a paste.""" - if not config["cert_sha1"]: - die("authentication required (set FLASKPASTE_CERT_SHA1)") + require_auth(config) paste_id = args.id.split("/")[-1] - base = config["server"].rstrip("/") - url = f"{base}/{paste_id}" + url = f"{config['server'].rstrip('/')}/{paste_id}" - headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} status, _, _ = request( - url, method="DELETE", headers=headers, ssl_context=config.get("ssl_context") + url, method="DELETE", headers=auth_headers(config), ssl_context=config.get("ssl_context") ) if status == 200: @@ -407,249 +624,113 @@ def cmd_delete(args, config): die(f"delete failed ({status})") -def cmd_info(args, config): +def cmd_info(args: argparse.Namespace, config: dict[str, Any]) -> None: """Show server info.""" url = config["server"].rstrip("/") + "/" status, body, _ = request(url, ssl_context=config.get("ssl_context")) - if status == 200: - data = json.loads(body) - print(f"server: {config['server']}") - print(f"name: {data.get('name', 'unknown')}") - print(f"version: {data.get('version', 'unknown')}") - - # Fetch PoW info - challenge = get_challenge(config) - if challenge: - difficulty = challenge.get("difficulty", 0) - base_diff = challenge.get("base_difficulty", difficulty) - elevated = challenge.get("elevated", False) - if elevated: - print(f"pow: {difficulty} bits (elevated from {base_diff})") - else: - print(f"pow: {difficulty} bits") - else: - print("pow: disabled") - else: + if status != 200: die("failed to connect to server") + data = json.loads(body) + print(f"server: {config['server']}") + print(f"name: {data.get('name', 'unknown')}") + print(f"version: {data.get('version', 'unknown')}") -def format_size(size): - """Format byte size as human-readable string.""" - if size < 1024: - return f"{size}B" - elif size < 1024 * 1024: - return f"{size / 1024:.1f}K" + if challenge := get_challenge(config): + difficulty = challenge.get("difficulty", 0) + base_diff = challenge.get("base_difficulty", difficulty) + if challenge.get("elevated"): + print(f"pow: {difficulty} bits (elevated from {base_diff})") + else: + print(f"pow: {difficulty} bits") else: - return f"{size / (1024 * 1024):.1f}M" + print("pow: disabled") -def format_timestamp(ts): - """Format Unix timestamp as human-readable date.""" - from datetime import datetime - - dt = datetime.fromtimestamp(ts, tz=UTC) - return dt.strftime("%Y-%m-%d %H:%M") - - -def cmd_list(args, config): +def cmd_list(args: argparse.Namespace, config: dict[str, Any]) -> None: """List user's pastes.""" - if not config["cert_sha1"]: - die("authentication required (set FLASKPASTE_CERT_SHA1)") + require_auth(config) - base = config["server"].rstrip("/") params = [] if args.limit: params.append(f"limit={args.limit}") if args.offset: params.append(f"offset={args.offset}") - url = f"{base}/pastes" + url = f"{config['server'].rstrip('/')}/pastes" if params: url += "?" + "&".join(params) - headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} - status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + status, body, _ = request( + url, headers=auth_headers(config), ssl_context=config.get("ssl_context") + ) if status == 401: die("authentication failed") - elif status != 200: + if status != 200: die(f"failed to list pastes ({status})") data = json.loads(body) pastes = data.get("pastes", []) - - if args.json: - print(json.dumps(data, indent=2)) - return - - if not pastes: - print("no pastes found") - return - - # Print header - print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS") - - for p in pastes: - paste_id = p["id"] - mime_type = p.get("mime_type", "unknown")[:16] - size = format_size(p.get("size", 0)) - created = format_timestamp(p.get("created_at", 0)) - - flags = [] - if p.get("burn_after_read"): - flags.append("burn") - if p.get("password_protected"): - flags.append("pass") - if p.get("expires_at"): - flags.append("exp") - - flag_str = " ".join(flags) - print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}") - - # Print summary - print(f"\n{data.get('count', 0)} of {data.get('total', 0)} pastes shown") + summary = f"{data.get('count', 0)} of {data.get('total', 0)} pastes shown" + print_paste_list(pastes, summary, as_json=args.json, data=data) -def parse_date(date_str): - """Parse date string to Unix timestamp.""" - from datetime import datetime - - if not date_str: - return 0 - - # Try various formats - formats = [ - "%Y-%m-%d", - "%Y-%m-%d %H:%M", - "%Y-%m-%d %H:%M:%S", - "%Y-%m-%dT%H:%M:%S", - "%Y-%m-%dT%H:%M:%SZ", - ] - for fmt in formats: - try: - dt = datetime.strptime(date_str, fmt) - dt = dt.replace(tzinfo=UTC) - return int(dt.timestamp()) - except ValueError: - continue - - # Try as Unix timestamp - try: - return int(date_str) - except ValueError: - pass - - die(f"invalid date format: {date_str}") - - -def cmd_search(args, config): +def cmd_search(args: argparse.Namespace, config: dict[str, Any]) -> None: """Search user's pastes.""" - if not config["cert_sha1"]: - die("authentication required (set FLASKPASTE_CERT_SHA1)") + require_auth(config) - base = config["server"].rstrip("/") params = [] - if args.type: params.append(f"type={args.type}") if args.after: - ts = parse_date(args.after) - params.append(f"after={ts}") + params.append(f"after={parse_date(args.after)}") if args.before: - ts = parse_date(args.before) - params.append(f"before={ts}") + params.append(f"before={parse_date(args.before)}") if args.limit: params.append(f"limit={args.limit}") - url = f"{base}/pastes" + url = f"{config['server'].rstrip('/')}/pastes" if params: url += "?" + "&".join(params) - headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} - status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + status, body, _ = request( + url, headers=auth_headers(config), ssl_context=config.get("ssl_context") + ) if status == 401: die("authentication failed") - elif status != 200: + if status != 200: die(f"failed to search pastes ({status})") data = json.loads(body) pastes = data.get("pastes", []) - - if args.json: - print(json.dumps(data, indent=2)) - return - - if not pastes: - print("no matching pastes found") - return - - # Print header - print(f"{'ID':<12} {'TYPE':<16} {'SIZE':>6} {'CREATED':<16} FLAGS") - - for p in pastes: - paste_id = p["id"] - mime_type = p.get("mime_type", "unknown")[:16] - size = format_size(p.get("size", 0)) - created = format_timestamp(p.get("created_at", 0)) - - flags = [] - if p.get("burn_after_read"): - flags.append("burn") - if p.get("password_protected"): - flags.append("pass") - if p.get("expires_at"): - flags.append("exp") - - flag_str = " ".join(flags) - print(f"{paste_id:<12} {mime_type:<16} {size:>6} {created:<16} {flag_str}") - - # Print summary - print(f"\n{data.get('count', 0)} matching pastes found") + summary = f"{data.get('count', 0)} matching pastes found" + print_paste_list(pastes, summary, as_json=args.json, data=data) -def cmd_update(args, config): +def cmd_update(args: argparse.Namespace, config: dict[str, Any]) -> None: """Update an existing paste.""" - if not config["cert_sha1"]: - die("authentication required (set FLASKPASTE_CERT_SHA1)") + require_auth(config) - paste_id = args.id.split("/")[-1] # Handle full URLs - if "#" in paste_id: - paste_id = paste_id.split("#")[0] # Remove key fragment + paste_id, _ = extract_paste_id(args.id) + url = f"{config['server'].rstrip('/')}/{paste_id}" - base = config["server"].rstrip("/") - url = f"{base}/{paste_id}" + headers = auth_headers(config) + content: bytes | None = None + encryption_key: bytes | None = None - headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} - content = None - - # Read content from file if provided if args.file: - if args.file == "-": - content = sys.stdin.buffer.read() - else: - path = Path(args.file) - if not path.exists(): - die(f"file not found: {args.file}") - content = path.read_bytes() - - if not content: + raw_content = read_content(args.file) + if not raw_content: die("empty content") + content, encryption_key = prepare_content( + raw_content, + encrypt=not getattr(args, "no_encrypt", False), + quiet=args.quiet, + ) - # Encrypt if requested (default is to encrypt) - if not getattr(args, "no_encrypt", False): - if not HAS_CRYPTO: - die("encryption requires 'cryptography' package (use -E to disable)") - if not args.quiet: - print("encrypting...", end="", file=sys.stderr) - content, encryption_key = encrypt_content(content) - if not args.quiet: - print(" done", file=sys.stderr) - else: - encryption_key = None - - # Set metadata update headers if args.password: headers["X-Paste-Password"] = args.password if args.remove_password: @@ -657,7 +738,6 @@ def cmd_update(args, config): if args.expiry: headers["X-Extend-Expiry"] = str(args.expiry) - # Make request status, body, _ = request( url, method="PUT", data=content, headers=headers, ssl_context=config.get("ssl_context") ) @@ -675,16 +755,11 @@ def cmd_update(args, config): if data.get("password_protected"): print(" password: protected") - # Show new encryption key if content was updated and encrypted - if content and "encryption_key" in dir() and encryption_key: - key_fragment = "#" + encode_key(encryption_key) - print(f" key: {base}/{paste_id}{key_fragment}") + if content and encryption_key: + base = config["server"].rstrip("/") + print(f" key: {base}/{paste_id}#{encode_key(encryption_key)}") elif status == 400: - try: - err = json.loads(body).get("error", "bad request") - except (json.JSONDecodeError, UnicodeDecodeError): - err = "bad request" - die(err) + die(parse_error(body, "bad request")) elif status == 401: die("authentication failed") elif status == 403: @@ -695,54 +770,43 @@ def cmd_update(args, config): die(f"update failed ({status})") -def cmd_export(args, config): +def cmd_export(args: argparse.Namespace, config: dict[str, Any]) -> None: """Export user's pastes to a directory.""" - if not config["cert_sha1"]: - die("authentication required (set FLASKPASTE_CERT_SHA1)") + require_auth(config) - base = config["server"].rstrip("/") out_dir = Path(args.output) if args.output else Path("fpaste-export") - - # Create output directory out_dir.mkdir(parents=True, exist_ok=True) - # Load key file if provided - keys = {} + # Load key file + keys: dict[str, str] = {} if args.keyfile: keyfile_path = Path(args.keyfile) if not keyfile_path.exists(): die(f"key file not found: {args.keyfile}") for line in keyfile_path.read_text().splitlines(): line = line.strip() - if not line or line.startswith("#"): - continue - if "=" not in line: - continue - paste_id, key_encoded = line.split("=", 1) - keys[paste_id.strip()] = key_encoded.strip() + if line and not line.startswith("#") and "=" in line: + paste_id, key_encoded = line.split("=", 1) + keys[paste_id.strip()] = key_encoded.strip() # Fetch paste list - headers = {"X-SSL-Client-SHA1": config["cert_sha1"]} - url = f"{base}/pastes?limit=1000" # Fetch all pastes - status, body, _ = request(url, headers=headers, ssl_context=config.get("ssl_context")) + url = f"{config['server'].rstrip('/')}/pastes?limit=1000" + status, body, _ = request( + url, headers=auth_headers(config), ssl_context=config.get("ssl_context") + ) if status == 401: die("authentication failed") - elif status != 200: + if status != 200: die(f"failed to list pastes ({status})") - data = json.loads(body) - pastes = data.get("pastes", []) - + pastes = json.loads(body).get("pastes", []) if not pastes: print("no pastes to export") return - # Export each paste - exported = 0 - skipped = 0 - errors = 0 - manifest = [] + exported, skipped, errors = 0, 0, 0 + manifest: list[dict[str, Any]] = [] for p in pastes: paste_id = p["id"] @@ -751,24 +815,22 @@ def cmd_export(args, config): if not args.quiet: print(f"exporting {paste_id}...", end=" ", file=sys.stderr) - # Skip burn-after-read pastes if p.get("burn_after_read"): if not args.quiet: print("skipped (burn-after-read)", file=sys.stderr) skipped += 1 continue - # Fetch raw content - raw_url = f"{base}/{paste_id}/raw" - req_headers = dict(headers) if p.get("password_protected"): if not args.quiet: print("skipped (password-protected)", file=sys.stderr) skipped += 1 continue - ssl_ctx = config.get("ssl_context") - status, content, _ = request(raw_url, headers=req_headers, ssl_context=ssl_ctx) + raw_url = f"{config['server'].rstrip('/')}/{paste_id}/raw" + status, content, _ = request( + raw_url, headers=auth_headers(config), ssl_context=config.get("ssl_context") + ) if status != 200: if not args.quiet: @@ -776,7 +838,6 @@ def cmd_export(args, config): errors += 1 continue - # Decrypt if key available decrypted = False if paste_id in keys: try: @@ -784,19 +845,12 @@ def cmd_export(args, config): content = decrypt_content(content, key) decrypted = True except SystemExit: - # Decryption failed, keep encrypted content if not args.quiet: print("decryption failed, keeping encrypted", file=sys.stderr, end=" ") - # Determine file extension from MIME type - ext = get_extension_for_mime(mime_type) - filename = f"{paste_id}{ext}" - filepath = out_dir / filename + filename = f"{paste_id}{get_extension_for_mime(mime_type)}" + (out_dir / filename).write_bytes(content) - # Write content - filepath.write_bytes(content) - - # Add to manifest manifest.append( { "id": paste_id, @@ -815,56 +869,27 @@ def cmd_export(args, config): exported += 1 - # Write manifest if args.manifest: manifest_path = out_dir / "manifest.json" manifest_path.write_text(json.dumps(manifest, indent=2)) if not args.quiet: print(f"manifest: {manifest_path}", file=sys.stderr) - # Summary print(f"\nexported: {exported}, skipped: {skipped}, errors: {errors}") print(f"output: {out_dir}") -def get_extension_for_mime(mime_type): - """Get file extension for MIME type.""" - mime_map = { - "text/plain": ".txt", - "text/html": ".html", - "text/css": ".css", - "text/javascript": ".js", - "text/markdown": ".md", - "text/x-python": ".py", - "application/json": ".json", - "application/xml": ".xml", - "application/javascript": ".js", - "application/octet-stream": ".bin", - "image/png": ".png", - "image/jpeg": ".jpg", - "image/gif": ".gif", - "image/webp": ".webp", - "image/svg+xml": ".svg", - "application/pdf": ".pdf", - "application/zip": ".zip", - "application/gzip": ".gz", - "application/x-tar": ".tar", - } - return mime_map.get(mime_type, ".bin") - - -def cmd_pki_status(args, config): +def cmd_pki_status(args: argparse.Namespace, config: dict[str, Any]) -> None: """Show PKI status and CA information.""" url = config["server"].rstrip("/") + "/pki" status, body, _ = request(url, ssl_context=config.get("ssl_context")) if status == 404: die("PKI not enabled on this server") - elif status != 200: + if status != 200: die(f"failed to get PKI status ({status})") data = json.loads(body) - print(f"pki enabled: {data.get('enabled', False)}") print(f"ca exists: {data.get('ca_exists', False)}") @@ -876,59 +901,42 @@ def cmd_pki_status(args, config): if data.get("expires_at"): print(f"expires: {data.get('expires_at')}") print(f"download: {config['server'].rstrip('/')}{data.get('download', '/pki/ca.crt')}") - elif data.get("hint"): - print(f"hint: {data.get('hint')}") + elif hint := data.get("hint"): + print(f"hint: {hint}") -def cmd_pki_issue(args, config): +def cmd_pki_issue(args: argparse.Namespace, config: dict[str, Any]) -> None: """Request a new client certificate from the server CA.""" url = config["server"].rstrip("/") + "/pki/issue" - headers = {"Content-Type": "application/json"} - if config["cert_sha1"]: - headers["X-SSL-Client-SHA1"] = config["cert_sha1"] - - payload = {"common_name": args.name} - data = json.dumps(payload).encode() + headers = {"Content-Type": "application/json", **auth_headers(config)} + payload = json.dumps({"common_name": args.name}).encode() status, body, _ = request( - url, method="POST", data=data, headers=headers, ssl_context=config.get("ssl_context") + url, method="POST", data=payload, headers=headers, ssl_context=config.get("ssl_context") ) if status == 404: - # Could be PKI disabled or no CA - try: - err = json.loads(body).get("error", "PKI not available") - except (json.JSONDecodeError, UnicodeDecodeError): - err = "PKI not available" - die(err) - elif status == 400: - try: - err = json.loads(body).get("error", "bad request") - except (json.JSONDecodeError, UnicodeDecodeError): - err = "bad request" - die(err) - elif status != 201: + die(parse_error(body, "PKI not available")) + if status == 400: + die(parse_error(body, "bad request")) + if status != 201: die(f"certificate issuance failed ({status})") result = json.loads(body) - # Determine output directory - out_dir = Path(args.output) if args.output else Path.home() / ".config" / "fpaste" + out_dir = Path(args.output) if args.output else CONFIG_DIR out_dir.mkdir(parents=True, exist_ok=True) - # File paths key_file = out_dir / "client.key" cert_file = out_dir / "client.crt" - # Check for existing files if not args.force: if key_file.exists(): die(f"key file exists: {key_file} (use --force)") if cert_file.exists(): die(f"cert file exists: {cert_file} (use --force)") - # Write files key_file.write_text(result["private_key_pem"]) key_file.chmod(0o600) cert_file.write_text(result["certificate_pem"]) @@ -941,144 +949,90 @@ def cmd_pki_issue(args, config): print(f"serial: {result.get('serial', 'unknown')}", file=sys.stderr) print(f"common name: {result.get('common_name', args.name)}", file=sys.stderr) - # Update config file if requested if args.configure: - config_file = Path.home() / ".config" / "fpaste" / "config" - config_file.parent.mkdir(parents=True, exist_ok=True) + cfg_path = write_config_file( + { + "client_cert": str(cert_file), + "client_key": str(key_file), + "cert_sha1": fingerprint, + } + ) + print(f"config: {cfg_path} (updated)", file=sys.stderr) - # Read existing config - existing = {} - if config_file.exists(): - for line in config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - existing[k.strip().lower()] = v.strip() - - # Update values - existing["client_cert"] = str(cert_file) - existing["client_key"] = str(key_file) - existing["cert_sha1"] = fingerprint - - # Write config - lines = [f"{k} = {v}" for k, v in sorted(existing.items())] - config_file.write_text("\n".join(lines) + "\n") - print(f"config: {config_file} (updated)", file=sys.stderr) - - # Output fingerprint to stdout for easy capture print(fingerprint) -def cmd_pki_download(args, config): +def cmd_pki_download(args: argparse.Namespace, config: dict[str, Any]) -> None: """Download the CA certificate from the server.""" url = config["server"].rstrip("/") + "/pki/ca.crt" status, body, _ = request(url, ssl_context=config.get("ssl_context")) if status == 404: die("CA certificate not available (PKI disabled or CA not generated)") - elif status != 200: + if status != 200: die(f"failed to download CA certificate ({status})") - # Determine output if args.output: out_path = Path(args.output) out_path.write_bytes(body) print(f"saved: {out_path}", file=sys.stderr) - # Calculate and show fingerprint if cryptography available if HAS_CRYPTO: cert = x509.load_pem_x509_certificate(body) - # SHA1 is standard for X.509 fingerprints fp = hashlib.sha1(cert.public_bytes(serialization.Encoding.DER)).hexdigest() # noqa: S324 print(f"fingerprint: {fp}", file=sys.stderr) - # Update config if requested if args.configure: - config_file = Path.home() / ".config" / "fpaste" / "config" - config_file.parent.mkdir(parents=True, exist_ok=True) - - existing = {} - if config_file.exists(): - for line in config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - existing[k.strip().lower()] = v.strip() - - existing["ca_cert"] = str(out_path) - - lines = [f"{k} = {v}" for k, v in sorted(existing.items())] - config_file.write_text("\n".join(lines) + "\n") - print(f"config: {config_file} (updated)", file=sys.stderr) + cfg_path = write_config_file({"ca_cert": str(out_path)}) + print(f"config: {cfg_path} (updated)", file=sys.stderr) else: - # Output to stdout sys.stdout.buffer.write(body) -def cmd_register(args, config): +def cmd_register(args: argparse.Namespace, config: dict[str, Any]) -> None: """Register and obtain a client certificate from the server.""" if not HAS_CRYPTO: die("register requires 'cryptography' package: pip install cryptography") - # Import pkcs12 for parsing the response from cryptography.hazmat.primitives.serialization import pkcs12 url = config["server"].rstrip("/") + "/register" - - # Build headers headers = {"Content-Type": "application/json"} - # Prepare payload - payload = {} + payload: dict[str, str] = {} if args.name: payload["common_name"] = args.name - # Get and solve PoW challenge if required - challenge = get_register_challenge(config) - if challenge: + if challenge := get_challenge(config, endpoint="/register/challenge"): if not args.quiet: - diff = challenge["difficulty"] - print(f"solving pow ({diff} bits)...", end="", file=sys.stderr) + print(f"solving pow ({challenge['difficulty']} bits)...", end="", file=sys.stderr) solution = solve_pow(challenge["nonce"], challenge["difficulty"]) if not args.quiet: print(" done", file=sys.stderr) headers["X-PoW-Token"] = challenge["token"] headers["X-PoW-Solution"] = str(solution) - # Make request data = json.dumps(payload).encode() if payload else b"{}" status, body, resp_headers = request( url, method="POST", data=data, headers=headers, ssl_context=config.get("ssl_context") ) if status == 400: - try: - err = json.loads(body).get("error", "bad request") - except (json.JSONDecodeError, UnicodeDecodeError): - err = "bad request" - die(err) - elif status == 500: - try: - err = json.loads(body).get("error", "server error") - except (json.JSONDecodeError, UnicodeDecodeError): - err = "server error" - die(err) - elif status != 200: + die(parse_error(body, "bad request")) + if status == 500: + die(parse_error(body, "server error")) + if status != 200: die(f"registration failed ({status})") - # Get fingerprint from response header fingerprint = resp_headers.get("X-Fingerprint-SHA1", "unknown") - # Determine output directory - out_dir = Path(args.output) if args.output else Path.home() / ".config" / "fpaste" + out_dir = Path(args.output) if args.output else CONFIG_DIR out_dir.mkdir(parents=True, exist_ok=True) - # File paths p12_file = out_dir / "client.p12" key_file = out_dir / "client.key" cert_file = out_dir / "client.crt" - # Check for existing files if not args.force: if p12_file.exists(): die(f"p12 file exists: {p12_file} (use --force)") @@ -1088,30 +1042,23 @@ def cmd_register(args, config): if cert_file.exists(): die(f"cert file exists: {cert_file} (use --force)") - # Save PKCS#12 bundle p12_file.write_bytes(body) p12_file.chmod(0o600) print(f"pkcs12: {p12_file}", file=sys.stderr) - # Extract certificate and key unless --p12-only if not args.p12_only: - # Parse PKCS#12 bundle (no password) - private_key, certificate, _additional_certs = pkcs12.load_key_and_certificates(body, None) + private_key, certificate, _ = pkcs12.load_key_and_certificates(body, None) if private_key is None or certificate is None: die("failed to parse PKCS#12 bundle") - # Serialize private key key_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) - - # Serialize certificate cert_pem = certificate.public_bytes(serialization.Encoding.PEM) - # Write files key_file.write_bytes(key_pem) key_file.chmod(0o600) cert_file.write_bytes(cert_pem) @@ -1119,56 +1066,36 @@ def cmd_register(args, config): print(f"key: {key_file}", file=sys.stderr) print(f"certificate: {cert_file}", file=sys.stderr) - print(f"fingerprint: {fingerprint}", file=sys.stderr) - - # Get common name from certificate - if not args.p12_only: cn = certificate.subject.get_attributes_for_oid(NameOID.COMMON_NAME) if cn: print(f"common name: {cn[0].value}", file=sys.stderr) - # Update config file if requested + print(f"fingerprint: {fingerprint}", file=sys.stderr) + if args.configure and not args.p12_only: - config_file = Path.home() / ".config" / "fpaste" / "config" - config_file.parent.mkdir(parents=True, exist_ok=True) + cfg_path = write_config_file( + { + "client_cert": str(cert_file), + "client_key": str(key_file), + "cert_sha1": fingerprint, + } + ) + print(f"config: {cfg_path} (updated)", file=sys.stderr) - # Read existing config - existing = {} - if config_file.exists(): - for line in config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - existing[k.strip().lower()] = v.strip() - - # Update values - existing["client_cert"] = str(cert_file) - existing["client_key"] = str(key_file) - existing["cert_sha1"] = fingerprint - - # Write config - lines = [f"{k} = {v}" for k, v in sorted(existing.items())] - config_file.write_text("\n".join(lines) + "\n") - print(f"config: {config_file} (updated)", file=sys.stderr) - - # Output fingerprint to stdout for easy capture print(fingerprint) -def cmd_cert(args, config): +def cmd_cert(args: argparse.Namespace, config: dict[str, Any]) -> None: """Generate a self-signed client certificate for mTLS authentication.""" if not HAS_CRYPTO: die("certificate generation requires 'cryptography' package: pip install cryptography") - # Determine output directory - out_dir = Path(args.output) if args.output else Path.home() / ".config" / "fpaste" + out_dir = Path(args.output) if args.output else CONFIG_DIR out_dir.mkdir(parents=True, exist_ok=True) - # File paths key_file = out_dir / "client.key" cert_file = out_dir / "client.crt" - # Check for existing files if not args.force: if key_file.exists(): die(f"key file exists: {key_file} (use --force)") @@ -1179,10 +1106,7 @@ def cmd_cert(args, config): if args.algorithm == "rsa": key_size = args.bits or 4096 print(f"generating {key_size}-bit RSA key...", file=sys.stderr) - private_key = rsa.generate_private_key( - public_exponent=65537, - key_size=key_size, - ) + private_key = rsa.generate_private_key(public_exponent=65537, key_size=key_size) elif args.algorithm == "ec": curve_name = args.curve or "secp384r1" curves = { @@ -1197,7 +1121,6 @@ def cmd_cert(args, config): else: die(f"unsupported algorithm: {args.algorithm}") - # Certificate subject cn = args.name or os.environ.get("USER", "fpaste-client") subject = issuer = x509.Name( [ @@ -1206,11 +1129,9 @@ def cmd_cert(args, config): ] ) - # Validity period days = args.days or 365 now = datetime.now(UTC) - # Build certificate cert_builder = ( x509.CertificateBuilder() .subject_name(subject) @@ -1219,10 +1140,7 @@ def cmd_cert(args, config): .serial_number(x509.random_serial_number()) .not_valid_before(now) .not_valid_after(now + timedelta(days=days)) - .add_extension( - x509.BasicConstraints(ca=False, path_length=None), - critical=True, - ) + .add_extension(x509.BasicConstraints(ca=False, path_length=None), critical=True) .add_extension( x509.KeyUsage( digital_signature=True, @@ -1243,32 +1161,27 @@ def cmd_cert(args, config): ) ) - # Sign certificate print("signing certificate...", file=sys.stderr) certificate = cert_builder.sign(private_key, hashes.SHA256()) - # Calculate SHA1 fingerprint (standard for X.509) cert_der = certificate.public_bytes(serialization.Encoding.DER) fingerprint = hashlib.sha1(cert_der).hexdigest() # noqa: S324 - # Serialize private key - if args.password_key: - key_encryption = serialization.BestAvailableEncryption(args.password_key.encode("utf-8")) - else: - key_encryption = serialization.NoEncryption() + key_encryption = ( + serialization.BestAvailableEncryption(args.password_key.encode()) + if args.password_key + else serialization.NoEncryption() + ) key_pem = private_key.private_bytes( encoding=serialization.Encoding.PEM, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=key_encryption, ) - - # Serialize certificate cert_pem = certificate.public_bytes(serialization.Encoding.PEM) - # Write files key_file.write_bytes(key_pem) - key_file.chmod(0o600) # Restrict permissions + key_file.chmod(0o600) cert_file.write_bytes(cert_pem) print(f"key: {key_file}", file=sys.stderr) @@ -1277,156 +1190,46 @@ def cmd_cert(args, config): print(f"valid for: {days} days", file=sys.stderr) print(f"common name: {cn}", file=sys.stderr) - # Update config file if requested if args.configure: - config_file = Path.home() / ".config" / "fpaste" / "config" - config_file.parent.mkdir(parents=True, exist_ok=True) + cfg_path = write_config_file( + { + "client_cert": str(cert_file), + "client_key": str(key_file), + "cert_sha1": fingerprint, + } + ) + print(f"config: {cfg_path} (updated)", file=sys.stderr) - # Read existing config - existing = {} - if config_file.exists(): - for line in config_file.read_text().splitlines(): - line = line.strip() - if line and not line.startswith("#") and "=" in line: - k, v = line.split("=", 1) - existing[k.strip().lower()] = v.strip() - - # Update values - existing["client_cert"] = str(cert_file) - existing["client_key"] = str(key_file) - existing["cert_sha1"] = fingerprint - - # Write config - lines = [f"{k} = {v}" for k, v in sorted(existing.items())] - config_file.write_text("\n".join(lines) + "\n") - print(f"config: {config_file} (updated)", file=sys.stderr) - - # Output fingerprint to stdout for easy capture print(fingerprint) -def is_file_path(arg): +# ----------------------------------------------------------------------------- +# Argument parsing +# ----------------------------------------------------------------------------- + + +def is_file_path(arg: str) -> bool: """Check if argument looks like a file path.""" if not arg or arg.startswith("-"): return False - # Check if it's an existing file if Path(arg).exists(): return True - # Check if it looks like a path (contains / or \ or common extensions) if "/" in arg or "\\" in arg: return True - # Check for common file extensions if "." in arg and not arg.startswith("."): ext = arg.rsplit(".", 1)[-1].lower() - if ext in ( - "txt", - "md", - "py", - "js", - "json", - "yaml", - "yml", - "xml", - "html", - "css", - "sh", - "bash", - "c", - "cpp", - "h", - "go", - "rs", - "java", - "rb", - "php", - "sql", - "log", - "conf", - "cfg", - "ini", - "png", - "jpg", - "jpeg", - "gif", - "pdf", - "zip", - "tar", - "gz", - ): - return True + return ext in FILE_EXTENSIONS return False -def main(): - # Pre-process arguments: if first positional looks like a file, insert "create" - args_to_parse = sys.argv[1:] - commands = { - "create", - "c", - "new", - "get", - "g", - "delete", - "d", - "rm", - "info", - "i", - "list", - "ls", - "search", - "s", - "find", - "update", - "u", - "export", - "register", - "cert", - "pki", - } - - # Find insertion point for "create" command - insert_pos = 0 - has_command = False - file_pos = -1 - - i = 0 - while i < len(args_to_parse): - arg = args_to_parse[i] - if arg in ("-s", "--server"): - insert_pos = i + 2 # After -s value - i += 2 - continue - if arg in ("-h", "--help"): - i += 1 - insert_pos = i - continue - if arg.startswith("-"): - # Unknown option - might be for create subcommand - i += 1 - continue - # Found positional argument - if arg in commands: - has_command = True - break - elif is_file_path(arg): - file_pos = i - break - i += 1 - - # Insert "create" if no command found and we have input (file path or piped stdin) - if not has_command and (file_pos >= 0 or not sys.stdin.isatty()): - args_to_parse.insert(insert_pos, "create") - +def build_parser() -> argparse.ArgumentParser: + """Build and return the argument parser.""" parser = argparse.ArgumentParser( prog="fpaste", description="FlaskPaste command-line client", epilog="Shortcut: fpaste is equivalent to fpaste create ", ) - parser.add_argument( - "-s", - "--server", - help="server URL (env: FLASKPASTE_SERVER)", - ) + parser.add_argument("-s", "--server", help="server URL (env: FLASKPASTE_SERVER)") subparsers = parser.add_subparsers(dest="command", metavar="command") # create @@ -1487,15 +1290,9 @@ def main(): # register p_register = subparsers.add_parser("register", help="register and get client certificate") p_register.add_argument("-n", "--name", metavar="CN", help="common name (optional)") - p_register.add_argument( - "-o", "--output", metavar="DIR", help="output directory (default: ~/.config/fpaste)" - ) - p_register.add_argument( - "--configure", action="store_true", help="update config file with cert paths" - ) - p_register.add_argument( - "--p12-only", action="store_true", help="save only PKCS#12, don't extract cert/key" - ) + p_register.add_argument("-o", "--output", metavar="DIR", help="output directory") + p_register.add_argument("--configure", action="store_true", help="update config file") + p_register.add_argument("--p12-only", action="store_true", help="save only PKCS#12") p_register.add_argument("-f", "--force", action="store_true", help="overwrite existing files") p_register.add_argument("-q", "--quiet", action="store_true", help="minimal output") @@ -1503,66 +1300,122 @@ def main(): p_cert = subparsers.add_parser("cert", help="generate client certificate") p_cert.add_argument("-o", "--output", metavar="DIR", help="output directory") p_cert.add_argument( - "-a", "--algorithm", choices=["rsa", "ec"], default="ec", help="key algorithm (default: ec)" + "-a", "--algorithm", choices=["rsa", "ec"], default="ec", help="key algorithm" ) p_cert.add_argument("-b", "--bits", type=int, metavar="N", help="RSA key size (default: 4096)") p_cert.add_argument( - "-c", "--curve", metavar="CURVE", help="EC curve: secp256r1, secp384r1, secp521r1" + "-c", "--curve", metavar="CURVE", help="EC curve (secp256r1/secp384r1/secp521r1)" ) p_cert.add_argument("-d", "--days", type=int, metavar="N", help="validity period in days") p_cert.add_argument("-n", "--name", metavar="CN", help="common name (default: $USER)") - p_cert.add_argument("--password-key", metavar="PASS", help="encrypt private key with password") - p_cert.add_argument( - "--configure", action="store_true", help="update config file with generated cert paths" - ) + p_cert.add_argument("--password-key", metavar="PASS", help="encrypt private key") + p_cert.add_argument("--configure", action="store_true", help="update config file") p_cert.add_argument("-f", "--force", action="store_true", help="overwrite existing files") - # pki (with subcommands) + # pki p_pki = subparsers.add_parser("pki", help="PKI operations (server-issued certificates)") pki_sub = p_pki.add_subparsers(dest="pki_command", metavar="subcommand") - # pki status pki_sub.add_parser("status", help="show PKI status and CA info") - # pki issue p_pki_issue = pki_sub.add_parser("issue", help="request certificate from server CA") p_pki_issue.add_argument( - "-n", "--name", required=True, metavar="CN", help="common name for certificate (required)" - ) - p_pki_issue.add_argument( - "-o", "--output", metavar="DIR", help="output directory (default: ~/.config/fpaste)" - ) - p_pki_issue.add_argument( - "--configure", action="store_true", help="update config file with issued cert paths" + "-n", "--name", required=True, metavar="CN", help="common name (required)" ) + p_pki_issue.add_argument("-o", "--output", metavar="DIR", help="output directory") + p_pki_issue.add_argument("--configure", action="store_true", help="update config file") p_pki_issue.add_argument("-f", "--force", action="store_true", help="overwrite existing files") - # pki download p_pki_download = pki_sub.add_parser("download", aliases=["dl"], help="download CA certificate") + p_pki_download.add_argument("-o", "--output", metavar="FILE", help="save to file") p_pki_download.add_argument( - "-o", "--output", metavar="FILE", help="save to file (default: stdout)" - ) - p_pki_download.add_argument( - "--configure", - action="store_true", - help="update config file with CA cert path (requires -o)", + "--configure", action="store_true", help="update config file (requires -o)" ) + return parser + + +# Command dispatch table +COMMANDS: dict[str, Any] = { + "create": cmd_create, + "c": cmd_create, + "new": cmd_create, + "get": cmd_get, + "g": cmd_get, + "delete": cmd_delete, + "d": cmd_delete, + "rm": cmd_delete, + "info": cmd_info, + "i": cmd_info, + "list": cmd_list, + "ls": cmd_list, + "search": cmd_search, + "s": cmd_search, + "find": cmd_search, + "update": cmd_update, + "u": cmd_update, + "export": cmd_export, + "register": cmd_register, + "cert": cmd_cert, +} + +PKI_COMMANDS: dict[str, Any] = { + "status": cmd_pki_status, + "issue": cmd_pki_issue, + "download": cmd_pki_download, + "dl": cmd_pki_download, +} + + +def main() -> None: + """Main entry point.""" + args_to_parse = sys.argv[1:] + command_names = set(COMMANDS.keys()) | {"pki"} + + # Auto-insert "create" if first positional looks like a file + insert_pos = 0 + has_command = False + file_pos = -1 + + i = 0 + while i < len(args_to_parse): + arg = args_to_parse[i] + if arg in ("-s", "--server"): + insert_pos = i + 2 + i += 2 + continue + if arg in ("-h", "--help"): + i += 1 + insert_pos = i + continue + if arg.startswith("-"): + i += 1 + continue + if arg in command_names: + has_command = True + break + if is_file_path(arg): + file_pos = i + break + i += 1 + + if not has_command and (file_pos >= 0 or not sys.stdin.isatty()): + args_to_parse.insert(insert_pos, "create") + + parser = build_parser() args = parser.parse_args(args_to_parse) config = get_config() if args.server: config["server"] = args.server - # Create SSL context for mTLS if configured config["ssl_context"] = create_ssl_context(config) if not args.command: - # Default: create from stdin if data is piped if not sys.stdin.isatty(): args.command = "create" args.file = None - args.no_encrypt = False # Encrypt by default + args.no_encrypt = False args.burn = False args.expiry = None args.password = None @@ -1572,36 +1425,13 @@ def main(): parser.print_help() sys.exit(0) - if args.command in ("create", "c", "new"): - cmd_create(args, config) - elif args.command in ("get", "g"): - cmd_get(args, config) - elif args.command in ("delete", "d", "rm"): - cmd_delete(args, config) - elif args.command in ("info", "i"): - cmd_info(args, config) - elif args.command in ("list", "ls"): - cmd_list(args, config) - elif args.command in ("search", "s", "find"): - cmd_search(args, config) - elif args.command in ("update", "u"): - cmd_update(args, config) - elif args.command == "export": - cmd_export(args, config) - elif args.command == "register": - cmd_register(args, config) - elif args.command == "cert": - cmd_cert(args, config) - elif args.command == "pki": - if args.pki_command == "status": - cmd_pki_status(args, config) - elif args.pki_command == "issue": - cmd_pki_issue(args, config) - elif args.pki_command in ("download", "dl"): - cmd_pki_download(args, config) + if args.command == "pki": + if args.pki_command in PKI_COMMANDS: + PKI_COMMANDS[args.pki_command](args, config) else: - # Show pki help if no subcommand parser.parse_args(["pki", "--help"]) + elif args.command in COMMANDS: + COMMANDS[args.command](args, config) if __name__ == "__main__":