"""Guest command transport over vsock-compatible JSON protocol.""" from __future__ import annotations import json import socket from dataclasses import dataclass from typing import Any, Callable, Protocol class SocketLike(Protocol): def settimeout(self, timeout: int) -> None: ... def connect(self, address: Any) -> None: ... def sendall(self, data: bytes) -> None: ... def recv(self, size: int) -> bytes: ... def close(self) -> None: ... SocketFactory = Callable[[int, int], SocketLike] @dataclass(frozen=True) class GuestExecResponse: stdout: str stderr: str exit_code: int duration_ms: int class VsockExecClient: """Minimal JSON-over-stream client for a guest exec agent.""" def __init__(self, socket_factory: SocketFactory | None = None) -> None: self._socket_factory = socket_factory or socket.socket def exec( self, guest_cid: int, port: int, command: str, timeout_seconds: int, *, uds_path: str | None = None, ) -> GuestExecResponse: request = { "command": command, "timeout_seconds": timeout_seconds, } family = getattr(socket, "AF_VSOCK", None) if family is not None: sock = self._socket_factory(family, socket.SOCK_STREAM) connect_address: Any = (guest_cid, port) elif uds_path is not None: sock = self._socket_factory(socket.AF_UNIX, socket.SOCK_STREAM) connect_address = uds_path else: raise RuntimeError("vsock sockets are not supported on this host Python runtime") try: sock.settimeout(timeout_seconds) sock.connect(connect_address) if family is None: sock.sendall(f"CONNECT {port}\n".encode("utf-8")) status = self._recv_line(sock) if not status.startswith("OK "): raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}") sock.sendall((json.dumps(request) + "\n").encode("utf-8")) chunks: list[bytes] = [] while True: data = sock.recv(65536) if data == b"": break chunks.append(data) finally: sock.close() payload = json.loads(b"".join(chunks).decode("utf-8")) if not isinstance(payload, dict): raise RuntimeError("guest exec response must be a JSON object") return GuestExecResponse( stdout=str(payload.get("stdout", "")), stderr=str(payload.get("stderr", "")), exit_code=int(payload.get("exit_code", -1)), duration_ms=int(payload.get("duration_ms", 0)), ) @staticmethod def _recv_line(sock: SocketLike) -> str: chunks: list[bytes] = [] while True: data = sock.recv(1) if data == b"": break chunks.append(data) if data == b"\n": break return b"".join(chunks).decode("utf-8", errors="replace")