"""Local IPC server loops used by the native messaging host.

On POSIX, CLI clients are served over a Unix domain socket using asyncio; on
Windows they are served over a named pipe using
``multiprocessing.connection.Listener``. Each connection carries a single
request payload and receives a single response payload before it is closed.
"""

import asyncio
import os
import socket
import threading
from collections.abc import Callable
from multiprocessing.connection import Listener
from pathlib import Path

from browser_cli import framing, local_transport
from browser_cli.platform import is_windows

# Maps one raw request payload to the raw response payload sent back.
PayloadHandler = Callable[[bytes], bytes]
# Builds the response payload sent when serving a request raised an exception.
ErrorHandler = Callable[[Exception], bytes]


def _bind_unix_socket(sock_path: str) -> socket.socket:
    """Create a listening Unix stream socket at *sock_path* with mode 0o600.

    Any existing file at *sock_path* is removed first. If binding or
    configuring the socket fails, the socket is closed before the error
    propagates so no file descriptor is leaked.
    """
    # A leftover socket file would make bind() fail; missing_ok avoids the
    # exists()/unlink() race.
    Path(sock_path).unlink(missing_ok=True)
    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.bind(sock_path)
        os.chmod(sock_path, 0o600)
        sock.listen(16)
    except BaseException:
        sock.close()
        raise
    return sock


async def async_socket_server(
    sock_path: str,
    handle_payload: PayloadHandler,
    error_response: ErrorHandler,
    *,
    bound_sock: socket.socket | None = None,
) -> None:
    """Serve CLI requests on a Unix domain socket until the task is cancelled.

    Args:
        sock_path: Filesystem path to bind. Not used when *bound_sock* is given.
        handle_payload: Called (in a worker thread) with each request payload;
            its return value is sent back as the response.
        error_response: Builds the response sent when handling a request fails.
        bound_sock: An already-bound socket to serve instead of creating one
            at *sock_path*.
    """
    sock = bound_sock if bound_sock is not None else _bind_unix_socket(sock_path)

    async def handle(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
        await async_handle_cli_connection(reader, writer, handle_payload, error_response)

    server = await asyncio.start_unix_server(handle, sock=sock)
    async with server:
        await server.serve_forever()


def socket_server(
    sock_path: str,
    handle_payload: PayloadHandler,
    error_response: ErrorHandler,
    *,
    bound_sock: socket.socket | None = None,
) -> None:
    """Run the platform-appropriate IPC server, blocking the calling thread.

    On Windows this serves a named pipe at *sock_path* (and *bound_sock* is
    not used); elsewhere it runs :func:`async_socket_server` in a fresh event
    loop via ``asyncio.run``.
    """
    if is_windows():
        windows_pipe_server(sock_path, handle_payload, error_response)
        return
    asyncio.run(async_socket_server(sock_path, handle_payload, error_response, bound_sock=bound_sock))


def windows_pipe_server(sock_path: str, handle_payload: PayloadHandler, error_response: ErrorHandler) -> None:
    """Serve CLI requests on a Windows named pipe until accepting fails.

    Each accepted connection is handled on its own daemon thread. The function
    returns when the listener cannot be created or ``accept()`` raises
    ``OSError``.

    A single ``Listener`` is reused for all connections. The stdlib pipe
    listener creates its first pipe instance with
    ``FILE_FLAG_FIRST_PIPE_INSTANCE``, so constructing a new ``Listener`` for
    the same pipe name while an earlier connection is still open fails with
    ``OSError``.
    """
    try:
        listener = Listener(sock_path, family="AF_PIPE")
    except OSError:
        return
    try:
        while True:
            try:
                conn = listener.accept()
            except OSError:
                break
            threading.Thread(
                target=handle_cli_connection,
                args=(conn, handle_payload, error_response),
                daemon=True,
            ).start()
    finally:
        try:
            listener.close()
        except Exception:
            # Best-effort cleanup; the server is shutting down either way.
            pass


async def async_handle_cli_connection(
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    handle_payload: PayloadHandler,
    error_response: ErrorHandler,
) -> None:
    """Serve one request/response exchange on an asyncio stream pair.

    Reads one payload with ``local_transport.async_recv_all``. If it is empty,
    the connection is closed without a reply. Otherwise *handle_payload* runs
    via ``asyncio.to_thread`` (so a blocking handler does not stall the event
    loop) and its result is sent back. On any exception, the payload from
    *error_response* is sent instead. The writer is always closed.
    """
    try:
        data = await local_transport.async_recv_all(reader)
        if not data:
            return
        response = await asyncio.to_thread(handle_payload, data)
        await local_transport.async_send_all(writer, response)
    except Exception as exc:
        try:
            await local_transport.async_send_all(writer, error_response(exc))
        except Exception:
            # The peer may already be gone; nothing more can be reported.
            pass
    finally:
        writer.close()
        try:
            await writer.wait_closed()
        except Exception:
            pass


def send_cli_response(conn, response: bytes) -> None:
    """Send *response* on a blocking connection using the platform's framing.

    Windows pipe connections use ``send_bytes``; other connections use
    ``framing.send_frame``.
    """
    if is_windows():
        conn.send_bytes(response)
    else:
        framing.send_frame(conn, response)


def handle_cli_connection(conn, handle_payload: PayloadHandler, error_response: ErrorHandler, listener=None) -> None:
    """Serve one request/response exchange on a blocking connection.

    Reads one payload (``recv_bytes`` on Windows, ``framing.recv_frame`` with
    ``allow_eof=True`` elsewhere). If it is empty, the connection is closed
    without a reply. Otherwise the result of *handle_payload* is sent back,
    or the payload from *error_response* if anything raises.

    *conn* is always closed. If a *listener* is passed, it is closed
    afterwards, even when closing *conn* raises.
    """
    try:
        data = conn.recv_bytes() if is_windows() else framing.recv_frame(conn, allow_eof=True)
        if not data:
            return
        send_cli_response(conn, handle_payload(data))
    except Exception as exc:
        try:
            send_cli_response(conn, error_response(exc))
        except Exception:
            # The peer may already be gone; nothing more can be reported.
            pass
    finally:
        try:
            conn.close()
        finally:
            if listener is not None:
                listener.close()