"""Post-quantum ML-KEM key exchange and app-layer transport encryption.""" from __future__ import annotations import secrets from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 from cryptography.hazmat.primitives.kdf.hkdf import HKDF from browser_cli.constants import PQ_TRANSPORT_ALG def pq_kex_server_keypair(): """Return an ephemeral ML-KEM-768 private key and raw public key bytes. Returns ``None`` when the installed cryptography/OpenSSL backend does not support ML-KEM yet. The serve/client protocol treats this as graceful downgrade instead of breaking local installs on older OpenSSL builds. """ try: from cryptography.hazmat.primitives.asymmetric import mlkem private_key = mlkem.MLKEM768PrivateKey.generate() public_key = private_key.public_key().public_bytes_raw() return private_key, public_key except Exception: return None def pq_kex_client_encapsulate(public_key_hex: str) -> tuple[str, bytes]: """Encapsulate to a server ML-KEM public key. Returns (ciphertext_hex, secret).""" from cryptography.hazmat.primitives.asymmetric import mlkem public_key = mlkem.MLKEM768PublicKey.from_public_bytes(bytes.fromhex(public_key_hex)) shared_secret, ciphertext = public_key.encapsulate() return ciphertext.hex(), shared_secret def pq_kex_server_decapsulate(private_key, ciphertext_hex: str) -> bytes: """Decapsulate a client ML-KEM ciphertext and return the shared secret.""" return private_key.decapsulate(bytes.fromhex(ciphertext_hex)) def pq_transport_key(shared_secret: bytes, direction: str) -> bytes: return HKDF( algorithm=hashes.SHA256(), length=32, salt=None, info=f"browser-cli pq transport v1 {direction}".encode("ascii"), ).derive(shared_secret) def pq_encrypt(shared_secret: bytes, direction: str, plaintext: bytes) -> dict: """Encrypt an app-layer frame with a key derived from the ML-KEM secret.""" nonce = secrets.token_bytes(12) key = pq_transport_key(shared_secret, direction) ciphertext = ChaCha20Poly1305(key).encrypt(nonce, plaintext, None) return {"alg": PQ_TRANSPORT_ALG, "nonce": nonce.hex(), "ciphertext": ciphertext.hex()} def pq_decrypt(shared_secret: bytes, direction: str, envelope: dict) -> bytes: """Decrypt an app-layer frame produced by pq_encrypt().""" if not isinstance(envelope, dict) or envelope.get("alg") != PQ_TRANSPORT_ALG: raise ValueError("unsupported encrypted transport envelope") key = pq_transport_key(shared_secret, direction) return ChaCha20Poly1305(key).decrypt( bytes.fromhex(str(envelope["nonce"])), bytes.fromhex(str(envelope["ciphertext"])), None, ) def new_nonce() -> str: return secrets.token_hex(32)