"""Ed25519 keypair management, ML-KEM key exchange, and auth helpers.""" import hashlib import json import os import secrets import socket import struct from dataclasses import dataclass from pathlib import Path from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey from cryptography.hazmat.primitives.serialization import ( Encoding, NoEncryption, PrivateFormat, PublicFormat, load_pem_private_key, ) _CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", str(Path.home() / ".config"))) / "browser-cli" DEFAULT_KEY_PATH = _CONFIG_DIR / "client.key.pem" DEFAULT_AUTHORIZED_KEYS_PATH = _CONFIG_DIR / "authorized_keys" # ── SSH agent protocol constants ─────────────────────────────────────────────── _SSH_AGENTC_REQUEST_IDENTITIES = 11 _SSH_AGENT_IDENTITIES_ANSWER = 12 _SSH_AGENTC_SIGN_REQUEST = 13 _SSH_AGENT_SIGN_RESPONSE = 14 def _pack_str(s: bytes) -> bytes: return struct.pack(">I", len(s)) + s def _unpack_str(data: bytes, off: int) -> tuple[bytes, int]: n = struct.unpack_from(">I", data, off)[0] return data[off + 4 : off + 4 + n], off + 4 + n def _agent_roundtrip(msg: bytes) -> bytes: sock_path = os.environ.get("SSH_AUTH_SOCK") if not sock_path: raise RuntimeError("SSH_AUTH_SOCK not set — is gpg-agent / ssh-agent running?") with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock: sock.settimeout(10) sock.connect(sock_path) sock.sendall(struct.pack(">I", len(msg)) + msg) raw_len = b"" while len(raw_len) < 4: chunk = sock.recv(4 - len(raw_len)) if not chunk: raise RuntimeError("SSH agent closed connection") raw_len += chunk n = struct.unpack(">I", raw_len)[0] resp = b"" while len(resp) < n: chunk = sock.recv(n - len(resp)) if not chunk: raise RuntimeError("SSH agent closed connection mid-response") resp += chunk return resp # ── AgentKey ─────────────────────────────────────────────────────────────────── @dataclass class AgentKey: """Ed25519 key backed by an SSH agent (YubiKey, TPM, ssh-agent, gpg-agent …).""" blob: bytes comment: str @property def pubkey_bytes(self) -> bytes: _algo, off = _unpack_str(self.blob, 0) key_bytes, _ = _unpack_str(self.blob, off) return key_bytes # ── Agent helpers ────────────────────────────────────────────────────────────── def agent_list_keys() -> list[AgentKey]: """Return all Ed25519 keys currently held by the SSH agent.""" resp = _agent_roundtrip(bytes([_SSH_AGENTC_REQUEST_IDENTITIES])) if resp[0] != _SSH_AGENT_IDENTITIES_ANSWER: raise RuntimeError(f"Unexpected agent response: {resp[0]}") n_keys = struct.unpack_from(">I", resp, 1)[0] keys: list[AgentKey] = [] off = 5 for _ in range(n_keys): blob, off = _unpack_str(resp, off) comment, off = _unpack_str(resp, off) algo, _ = _unpack_str(blob, 0) if algo == b"ssh-ed25519": keys.append(AgentKey(blob=blob, comment=comment.decode("utf-8", errors="replace"))) return keys def agent_find_key(selector: str | None = None) -> AgentKey | None: """Return the first agent Ed25519 key whose comment contains selector (or any if None).""" try: keys = agent_list_keys() except Exception: return None for key in keys: if key.comment == "(none)": continue if selector is None or selector in key.comment: return key return None def agent_sign_raw(key: AgentKey, data: bytes) -> bytes: """Ask the SSH agent to sign data and return the raw 64-byte Ed25519 signature.""" msg = ( bytes([_SSH_AGENTC_SIGN_REQUEST]) + _pack_str(key.blob) + _pack_str(data) + struct.pack(">I", 0) ) resp = _agent_roundtrip(msg) if resp[0] != _SSH_AGENT_SIGN_RESPONSE: raise RuntimeError(f"SSH agent refused to sign (response code {resp[0]})") sig_blob, _ = _unpack_str(resp, 1) _algo, soff = _unpack_str(sig_blob, 0) raw_sig, _ = _unpack_str(sig_blob, soff) if len(raw_sig) != 64: raise RuntimeError(f"Unexpected signature length {len(raw_sig)}") return raw_sig # ── File-based key helpers ───────────────────────────────────────────────────── def generate_keypair() -> tuple[bytes, str]: """Return (private_key_pem_bytes, public_key_hex).""" priv = Ed25519PrivateKey.generate() pem = priv.private_bytes(Encoding.PEM, PrivateFormat.PKCS8, NoEncryption()) pub_hex = priv.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw).hex() return pem, pub_hex def load_private_key(path: Path) -> Ed25519PrivateKey: return load_pem_private_key(path.read_bytes(), password=None) def public_key_hex(key: Ed25519PrivateKey | AgentKey) -> str: if isinstance(key, AgentKey): return key.pubkey_bytes.hex() return key.public_key().public_bytes(Encoding.Raw, PublicFormat.Raw).hex() # ── Canonical payload + sign/verify ─────────────────────────────────────────── def canonical_payload(msg: dict) -> bytes: """Deterministic JSON encoding of msg without auth protocol fields.""" return json.dumps( {k: v for k, v in msg.items() if k not in {"pubkey", "sig", "pq_kex"}}, sort_keys=True, separators=(",", ":"), ).encode("utf-8") def _auth_message(nonce: bytes, msg: dict, pq_shared_secret: bytes | None = None) -> bytes: """Bytes signed for auth; optionally binds a post-quantum KEX secret.""" data = nonce + hashlib.sha256(canonical_payload(msg)).digest() if pq_shared_secret is not None: data += hashlib.sha256(b"browser-cli ml-kem-768 v1" + pq_shared_secret).digest() return data def sign(key: Ed25519PrivateKey | AgentKey, nonce: bytes, msg: dict, pq_shared_secret: bytes | None = None) -> bytes: """Sign nonce + payload hash, optionally bound to an ML-KEM shared secret.""" data = _auth_message(nonce, msg, pq_shared_secret) if isinstance(key, AgentKey): return agent_sign_raw(key, data) return key.sign(data) def verify(pub_hex: str, nonce: bytes, msg: dict, sig_hex: str, pq_shared_secret: bytes | None = None) -> bool: """Return True if sig_hex is a valid signature over the canonical payload/auth secret.""" try: pub_bytes = bytes.fromhex(pub_hex) pub_key = Ed25519PublicKey.from_public_bytes(pub_bytes) pub_key.verify(bytes.fromhex(sig_hex), _auth_message(nonce, msg, pq_shared_secret)) return True except (InvalidSignature, ValueError): return False # ── Post-quantum key exchange (ML-KEM / Kyber) ──────────────────────────────── PQ_KEX_ALG = "ML-KEM-768" def pq_kex_server_keypair(): """Return an ephemeral ML-KEM-768 private key and raw public key bytes. Returns ``None`` when the installed cryptography/OpenSSL backend does not support ML-KEM yet. The serve/client protocol treats this as graceful downgrade instead of breaking local installs on older OpenSSL builds. """ try: from cryptography.hazmat.primitives.asymmetric import mlkem priv = mlkem.MLKEM768PrivateKey.generate() pub = priv.public_key().public_bytes_raw() return priv, pub except Exception: return None def pq_kex_client_encapsulate(public_key_hex: str) -> tuple[str, bytes]: """Encapsulate to a server ML-KEM public key. Returns (ciphertext_hex, secret).""" from cryptography.hazmat.primitives.asymmetric import mlkem pub = mlkem.MLKEM768PublicKey.from_public_bytes(bytes.fromhex(public_key_hex)) ciphertext, shared_secret = pub.encapsulate() return ciphertext.hex(), shared_secret def pq_kex_server_decapsulate(private_key, ciphertext_hex: str) -> bytes: """Decapsulate a client ML-KEM ciphertext and return the shared secret.""" return private_key.decapsulate(bytes.fromhex(ciphertext_hex)) def new_nonce() -> str: return secrets.token_hex(32) def load_authorized_keys_with_names(path: Path) -> list[tuple[str, str]]: """Return list of (pubkey_hex, name) pairs. Name is empty string if not set.""" if not path.exists(): return [] result = [] for line in path.read_text(encoding="utf-8").splitlines(): line = line.strip() if not line or line.startswith("#"): continue parts = line.split(None, 1) pubkey = parts[0] name = parts[1].strip() if len(parts) > 1 else "" result.append((pubkey, name)) return result def load_authorized_keys(path: Path) -> list[str]: return [pk for pk, _ in load_authorized_keys_with_names(path)] def add_authorized_key(path: Path, pub_hex: str, name: str = "") -> bool: """Append pub_hex to authorized_keys. Returns False if already present.""" path.parent.mkdir(parents=True, exist_ok=True) existing = {pk for pk, _ in load_authorized_keys_with_names(path)} if pub_hex in existing: return False line = (f"{pub_hex} {name}".rstrip()) + "\n" with open(path, "a", encoding="utf-8") as f: f.write(line) return True