"""Server-side authorization, per-key policy and rate limiting for ``browser-cli serve``. This bundles the three serve-time security concerns that travel together through the connection-handling chain: - ``policy`` the server-wide default ``CommandPolicy`` (from ``--allow-*``) - ``key_policies`` optional per-pubkey overrides parsed from the ``allow:`` token in the ``authorized_keys`` file - ``key_names`` pubkey -> friendly name (from authorized_keys), for audit logs - ``rate_limiter`` optional per-identity token-bucket throttle """ from __future__ import annotations import threading import time from dataclasses import dataclass, field from pathlib import Path from browser_cli.command_security import CommandPolicy # ── per-key authorization ─────────────────────────────────────────────────────── _CATEGORY_FLAGS = { "read-page": "allow_read_page", "control": "allow_control", "dangerous": "allow_dangerous", "keys": "allow_keys", } def policy_from_categories(categories) -> CommandPolicy: """Build a CommandPolicy from category strings (``all``/``safe``/``read-page``/``control``/``dangerous``).""" cats = [str(c).strip().lower() for c in categories] if "all" in cats: return CommandPolicy.unrestricted() kwargs: dict[str, bool] = {} for cat in cats: if cat in ("", "safe"): continue flag = _CATEGORY_FLAGS.get(cat) if flag is None: raise ValueError( f"unknown command category {cat!r}; expected one of: all, safe, read-page, control, dangerous" ) kwargs[flag] = True return CommandPolicy(**kwargs) def key_policies_from_authorized_keys(path: Path | str | None) -> dict[str, CommandPolicy]: """Build ``{pubkey: CommandPolicy}`` from the ``allow:`` tokens in authorized_keys. Only keys that carry an explicit ``allow:`` token get an entry; keys without one fall back to the server-wide default policy. Pubkeys are normalised to lowercase hex. Raises ``ValueError`` on an unknown category so the server fails loudly at startup rather than silently mis-gating. """ if path is None: return {} from browser_cli.auth import load_authorized_keys_with_policies out: dict[str, CommandPolicy] = {} for pubkey, _name, categories in load_authorized_keys_with_policies(Path(path)): if categories is not None: out[pubkey.strip().lower()] = policy_from_categories(categories) return out # ── per-identity rate limiting ─────────────────────────────────────────────────── class RateLimiter: """Token bucket keyed by identity (pubkey, or client address when unauthenticated). ``rate`` is the sustained refill in tokens/second; ``burst`` is the bucket capacity (defaults to ``rate``). ``rate <= 0`` disables limiting entirely. Thread-safe so it can be shared across all connections of one serve process. """ def __init__(self, rate: float, burst: float | None = None) -> None: self.rate = float(rate) self.capacity = float(burst) if burst is not None else max(float(rate), 1.0) self._buckets: dict[str, tuple[float, float]] = {} self._lock = threading.Lock() def allow(self, key: str) -> bool: if self.rate <= 0: return True now = time.monotonic() with self._lock: tokens, last = self._buckets.get(key, (self.capacity, now)) tokens = min(self.capacity, tokens + (now - last) * self.rate) if tokens < 1.0: self._buckets[key] = (tokens, now) return False self._buckets[key] = (tokens - 1.0, now) return True # ── bundled server security context ────────────────────────────────────────────── @dataclass(frozen=True) class ServeSecurity: policy: CommandPolicy = field(default_factory=CommandPolicy.unrestricted) key_policies: dict[str, CommandPolicy] = field(default_factory=dict) key_names: dict[str, str] = field(default_factory=dict) rate_limiter: RateLimiter | None = None def effective_policy(self, pubkey: str | None) -> CommandPolicy: """Per-key override if one exists for this pubkey, else the server default.""" if pubkey and pubkey in self.key_policies: return self.key_policies[pubkey] return self.policy def label_for(self, pubkey: str | None) -> str | None: """Audit label for log lines: `` …`` or just the short pubkey.""" if not pubkey: return None short = f"{pubkey[:8]}…" name = self.key_names.get(pubkey, "") return f"{name} {short}".strip() if name else short