"""Runtime implementation for ``browser-cli serve``.

The Click command lives in ``browser_cli.commands.serve``. This module owns the
connection lifecycle; auth, control commands and browser proxying live in small
mixins so each piece can be tested/refactored independently.
"""

from __future__ import annotations

import asyncio
import json
import socket
from dataclasses import dataclass, field
from pathlib import Path

from browser_cli import transport
from browser_cli.command_security import assert_command_allowed
from browser_cli.compat import adapt_auth
from browser_cli.constants import REMOTE_SESSION_IDLE_TIMEOUT
from browser_cli.framing import async_recv_frame, async_send_frame
from browser_cli.serve.auth import ServeAuthMixin
from browser_cli.serve.challenge import (
    build_challenge as _build_challenge,
    load_auth_keys as _load_auth_keys,
)
from browser_cli.serve.control import ServeControlMixin
from browser_cli.serve.logging import console, log_request
from browser_cli.serve.proxy import ServeProxyMixin
from browser_cli.serve.security import ServeSecurity

# Maximum number of connections served at once. Connections arriving while the
# limit is reached are closed immediately instead of queuing.
_MAX_CONCURRENT_CONNECTIONS = 64
# listen() backlog handed to asyncio.start_server.
_LISTEN_BACKLOG = 16


async def _async_framed_send(writer: asyncio.StreamWriter, data: bytes) -> None:
    """Send *data* to the peer as a single frame via ``async_send_frame``."""
    await async_send_frame(writer, data)


async def _async_recv_all(reader: asyncio.StreamReader) -> bytes:
    """Read one frame from *reader*; a falsy result (e.g. ``None``) becomes ``b""``."""
    return await async_recv_frame(reader) or b""


@dataclass
class ServeRequest(ServeAuthMixin, ServeControlMixin, ServeProxyMixin):
    """State and handlers for one client connection.

    ``validate_client``, ``authenticate``, ``handle_control_command`` and
    ``forward_to_browser`` are not defined here; they come from the mixins.
    This class owns response framing/encryption, the read loop and dispatch.
    """

    reader: asyncio.StreamReader
    writer: asyncio.StreamWriter
    addr: tuple  # peer address; addr[0] is the rate-limit key when no pubkey is known
    profile: str | None
    auth_keys: list[str] | None
    auth_keys_path: Path | None
    nonce: str  # nonce returned by build_challenge for this connection
    pq_private_key: object | None = None
    compress: bool = True
    security: ServeSecurity = field(default_factory=ServeSecurity)
    # When not None, responses are wrapped with pq_encrypt and run() keeps the
    # connection open for further encrypted commands.
    response_secret: bytes | None = None
    accept_encoding: dict | None = None
    client_ver: str = "0"
    msg_id: object = None  # echoed as "id" in responses
    command: str = "?"  # current command name, used for policy checks and logs
    auth_pubkey: str | None = None
    auth_label: str | None = None

    async def send_payload(self, data: bytes) -> None:
        """Frame and send *data*, encrypting it first if a session secret exists.

        Raises:
            OSError: propagated from the transport; callers decide whether to ignore it.
        """
        if self.response_secret is not None:
            from browser_cli.auth import pq_encrypt

            envelope = {"encrypted": pq_encrypt(self.response_secret, "response", data)}
            data = json.dumps(envelope).encode()
        await _async_framed_send(self.writer, data)

    async def send_error(self, msg: str, msg_id=None) -> None:
        """Send a ``{"success": False}`` response carrying *msg*.

        *msg_id* overrides the id; ``None`` means "use ``self.msg_id``".
        Transport errors are swallowed: the peer may already be gone.
        """
        err = json.dumps(
            {"id": self.msg_id if msg_id is None else msg_id, "success": False, "error": msg}
        ).encode()
        try:
            await self.send_payload(err)
        except OSError:
            pass

    async def send_ok(self, payload, command: str | None = None) -> None:
        """Send a ``{"success": True}`` response whose ``data`` is *payload*.

        The response is encoded by ``transport.encode_response`` using the client's
        ``accept_encoding`` only when compression is enabled. Transport errors are
        swallowed.
        """
        obj = {"id": self.msg_id, "success": True, "data": payload}
        encoding = self.accept_encoding if self.compress else None
        try:
            await self.send_payload(transport.encode_response(obj, encoding, command))
        except OSError:
            pass

    async def read_message(self) -> dict | None:
        """Read and parse the first (unencrypted) request.

        Returns:
            The decoded JSON object, or ``None`` when the connection should close:
            transport error, EOF/empty frame, malformed JSON, or a non-object value.
        """
        try:
            payload = await _async_recv_all(self.reader)
        except OSError as exc:  # ConnectionError is a subclass of OSError
            if "too large" in str(exc):
                await self.send_error(str(exc), msg_id=None)
            return None
        if not payload:
            # The peer closed (or sent an empty frame) before making a request.
            # There is nothing to answer, and it is not an "invalid JSON" event.
            return None
        try:
            msg = json.loads(payload)
        except (ValueError, RecursionError):
            # ValueError covers JSONDecodeError and undecodable bytes;
            # RecursionError covers pathologically nested untrusted input.
            await self.send_error("invalid JSON", msg_id=None)
            log_request(self.addr, "?", None, "ERROR", "invalid JSON")
            return None
        return msg if isinstance(msg, dict) else None

    async def run(self) -> None:
        """Serve the connection: validate, authenticate, dispatch, then loop while encrypted."""
        msg = await self.read_message()
        if msg is None or not await self.validate_client(msg):
            return
        msg = adapt_auth(msg, self.client_ver)
        self.command = msg.get("command", "?")
        msg = await self.authenticate(msg)
        if msg is None:
            return
        self._apply_identity(msg)
        await self._dispatch(msg)
        # Once an encrypted session is established, keep serving further commands on
        # the same connection — the client may reuse it without re-authenticating.
        # Safe because every frame carries a fresh AEAD nonce (see pq_encrypt).
        # NOTE(review): fresh nonces prevent nonce reuse, but on their own do not stop
        # an on-path attacker from replaying a captured request frame within this
        # session — confirm pq_decrypt (or the frame format) rejects duplicates.
        while self.response_secret is not None:
            nxt = await self._read_session_message()
            if nxt is None:
                return
            await self._dispatch(nxt)

    def _apply_identity(self, msg: dict) -> None:
        """Record the authenticated pubkey (if any) for per-key policy and audit logs."""
        raw = msg.get("pubkey")
        # The field comes from the client; anything that is not a string is
        # treated as "no pubkey" instead of crashing on .strip().
        pub = raw.strip().lower() if isinstance(raw, str) else ""
        self.auth_pubkey = pub or None
        self.auth_label = self.security.label_for(self.auth_pubkey)

    async def _enforce_rate_limit(self) -> bool:
        """Return True if the request may proceed; otherwise reply with an error and log it.

        The limiter key is the authenticated pubkey, falling back to the peer host.
        """
        limiter = self.security.rate_limiter
        if limiter is None or limiter.allow(self.auth_pubkey or str(self.addr[0])):
            return True
        await self.send_error("rate limit exceeded; slow down and retry")
        log_request(self.addr, self.command, None, "DENIED", "rate limit exceeded", identity=self.auth_label)
        return False

    async def _dispatch(self, msg: dict) -> None:
        """Rate-limit and policy-check *msg*, then run it as a control command or proxy it."""
        self.accept_encoding = msg.get("accept_encoding")
        if not await self._enforce_rate_limit():
            return
        # Gate every command — including server control commands like the key-management
        # ones — so the policy is enforced before handle_control_command acts on it.
        try:
            assert_command_allowed(self.command, self.security.effective_policy(self.auth_pubkey))
        except PermissionError as exc:
            await self.send_error(str(exc))
            log_request(self.addr, self.command, None, "DENIED", "blocked by command policy", identity=self.auth_label)
            return
        if await self.handle_control_command(msg):
            return
        await self.forward_to_browser(msg)

    async def _read_session_message(self) -> dict | None:
        """Read the next command on an established encrypted session, or None to close.

        The session closes on idle timeout (``REMOTE_SESSION_IDLE_TIMEOUT``), transport
        error, EOF, malformed JSON, a plaintext frame, or a frame that fails to decrypt.
        """
        try:
            payload = await asyncio.wait_for(_async_recv_all(self.reader), timeout=REMOTE_SESSION_IDLE_TIMEOUT)
        except (asyncio.TimeoutError, OSError):  # OSError includes ConnectionError
            return None
        if not payload:
            return None
        try:
            outer = json.loads(payload)
        except (ValueError, RecursionError):
            return None
        if not isinstance(outer, dict) or "encrypted" not in outer:
            return None  # an authenticated session only accepts encrypted frames
        from browser_cli.auth import pq_decrypt

        try:
            inner = json.loads(pq_decrypt(self.response_secret, "request", outer["encrypted"]))
        except Exception:
            # Any decryption or parse failure (tampered frame, wrong key, bad
            # plaintext) ends the session instead of being answered.
            return None
        if not isinstance(inner, dict):
            return None
        inner = adapt_auth(inner, self.client_ver)
        self.msg_id = inner.get("id")
        self.command = inner.get("command", "?")
        return inner


async def _async_proxy_request(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    addr: tuple,
    profile: str | None,
    auth_keys: list[str] | None,
    auth_keys_path: Path | None,
    nonce: str,
    pq_private_key=None,
    compress: bool = True,
    security: ServeSecurity | None = None,
) -> None:
    """Build a ServeRequest for one connection and run it to completion."""
    # Keyword arguments keep this call correct even if the dataclass fields are reordered.
    await ServeRequest(
        reader=reader,
        writer=writer,
        addr=addr,
        profile=profile,
        auth_keys=auth_keys,
        auth_keys_path=auth_keys_path,
        nonce=nonce,
        pq_private_key=pq_private_key,
        compress=compress,
        security=security if security is not None else ServeSecurity(),
    ).run()


async def _close_writer(writer: asyncio.StreamWriter) -> None:
    """Close *writer* and wait for it; errors while waiting are ignored (best effort)."""
    writer.close()
    try:
        await writer.wait_closed()
    except Exception:
        pass


async def _async_handle_client(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    addr: tuple,
    profile: str | None,
    auth_keys_path: Path | None,
    compress: bool = True,
    conn_limit: asyncio.Semaphore | None = None,
    security: ServeSecurity | None = None,
) -> None:
    """Send the auth challenge, then serve the connection's request session.

    If *conn_limit* has no free slots, the connection is closed immediately. The
    writer is always closed and the slot released when handling finishes.
    """
    if conn_limit is None:
        conn_limit = asyncio.Semaphore(_MAX_CONCURRENT_CONNECTIONS)
    if conn_limit.locked():
        # At capacity: refuse now rather than queue behind the semaphore.
        await _close_writer(writer)
        return
    # No await between locked() and acquire(), so this does not block.
    await conn_limit.acquire()
    try:
        auth_keys = await _load_auth_keys(auth_keys_path)
        nonce, pq_private_key, challenge_msg = await _build_challenge(auth_keys_path)
        try:
            await _async_framed_send(writer, json.dumps(challenge_msg).encode())
        except OSError:
            return
        await _async_proxy_request(
            reader, writer, addr, profile, auth_keys, auth_keys_path, nonce, pq_private_key, compress, security
        )
    finally:
        conn_limit.release()
        await _close_writer(writer)


def _handle_client(
    client_sock: socket.socket,
    addr: tuple,
    profile: str | None,
    auth_keys_path: Path | None,
    compress: bool = True,
    security: ServeSecurity | None = None,
) -> None:
    """Run one accepted socket through the async serve pipeline."""

    async def _run() -> None:
        reader, writer = await asyncio.open_connection(sock=client_sock)
        await _async_handle_client(reader, writer, addr, profile, auth_keys_path, compress, None, security)

    try:
        asyncio.run(_run())
    except OSError:
        try:
            client_sock.close()
        except OSError:
            pass


async def _serve_async(
    host: str,
    port: int,
    profile: str | None,
    auth_keys_path: Path | None,
    compress: bool,
    security: ServeSecurity | None = None,
) -> None:
    """Listen on *host*:*port* forever, serving clients under a shared connection cap."""
    conn_limit = asyncio.Semaphore(_MAX_CONCURRENT_CONNECTIONS)

    async def _client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        peer = writer.get_extra_info("peername") or ("?", 0)
        await _async_handle_client(reader, writer, peer, profile, auth_keys_path, compress, conn_limit, security)

    server = await asyncio.start_server(_client_connected, host, port, backlog=_LISTEN_BACKLOG)
    async with server:
        await server.serve_forever()