# Tests for VsockExecClient, driven through an in-memory socket stub so that
# no real vsock or UNIX socket is ever opened.
from __future__ import annotations

import io
import json
import socket
import tarfile
from pathlib import Path

import pytest

from pyro_mcp.vm_guest import VsockExecClient


class StubSocket:
    """In-memory stand-in for a socket object that records how it is used.

    Each queued response is served in order by ``recv``; once the queue and
    the current buffer are both empty, ``recv`` returns ``b""``.

    Args:
        responses: Either a single ``bytes`` payload or a list of payloads
            that ``recv`` hands out one after another.
    """

    def __init__(self, responses: list[bytes] | bytes) -> None:
        if isinstance(responses, bytes):
            self.responses = [responses]
        else:
            # Copy the list: recv() drains this queue with pop(0), and the
            # caller's own list must not be emptied as a side effect.
            self.responses = list(responses)
        self._buffer = b""  # unread remainder of the response being served
        self.connected: object | None = None  # address passed to connect()
        self.sent = b""  # concatenation of everything passed to sendall()
        self.timeout: float | None = None  # last value given to settimeout()
        self.closed = False  # flipped to True by close()

    def settimeout(self, timeout: float) -> None:
        """Record the requested timeout."""
        self.timeout = timeout

    def connect(self, address: tuple[int, int] | str) -> None:
        """Record the target address (a tuple or a path string in these tests)."""
        self.connected = address

    def sendall(self, data: bytes) -> None:
        """Append ``data`` to the log of sent bytes."""
        self.sent += data

    def recv(self, size: int) -> bytes:
        """Return up to ``size`` bytes of the current response.

        A new queued response is loaded only when the previous one has been
        fully consumed, so a single ``recv`` never spans two responses.
        """
        if not self._buffer and self.responses:
            self._buffer = self.responses.pop(0)
        if not self._buffer:
            return b""
        data = self._buffer[:size]
        self._buffer = self._buffer[size:]
        return data

    def close(self) -> None:
        """Mark the stub as closed."""
        self.closed = True


def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
    """``exec`` connects to (cid, port), sends the command and parses the reply."""
    # raising=False lets the attribute be created when the platform lacks it.
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
    # The doubled backslash keeps a JSON "\n" escape inside the payload.
    stub = StubSocket(
        b'{"stdout":"ok\\n","stderr":"","exit_code":0,"duration_ms":7}'
    )

    def socket_factory(family: int, sock_type: int) -> StubSocket:
        assert family == socket.AF_VSOCK
        assert sock_type == socket.SOCK_STREAM
        return stub

    client = VsockExecClient(socket_factory=socket_factory)
    response = client.exec(1234, 5005, "echo ok", 30)

    assert response.exit_code == 0
    assert response.stdout == "ok\n"
    assert stub.connected == (1234, 5005)
    assert b'"command": "echo ok"' in stub.sent
    assert stub.closed is True


def test_vsock_exec_client_upload_archive_round_trip(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """``upload_archive`` sends a JSON header line followed by the raw archive."""
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)

    # Build a one-file gzip tarball to upload.
    archive_path = tmp_path / "seed.tgz"
    with tarfile.open(archive_path, "w:gz") as archive:
        payload = b"hello\n"
        info = tarfile.TarInfo(name="note.txt")
        info.size = len(payload)
        archive.addfile(info, io.BytesIO(payload))

    stub = StubSocket(
        b'{"destination":"/workspace","entry_count":1,"bytes_written":6}'
    )

    def socket_factory(family: int, sock_type: int) -> StubSocket:
        assert family == socket.AF_VSOCK
        assert sock_type == socket.SOCK_STREAM
        return stub

    client = VsockExecClient(socket_factory=socket_factory)
    response = client.upload_archive(
        1234,
        5005,
        archive_path,
        destination="/workspace",
        timeout_seconds=60,
    )

    # Split on the first newline only: the archive bytes may contain b"\n".
    request_payload, archive_payload = stub.sent.split(b"\n", 1)
    request = json.loads(request_payload.decode("utf-8"))

    assert request["action"] == "extract_archive"
    assert request["destination"] == "/workspace"
    assert int(request["archive_size"]) == archive_path.stat().st_size
    assert archive_payload == archive_path.read_bytes()
    assert response.entry_count == 1
    assert response.bytes_written == 6
    assert stub.closed is True


def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
    """A reply that is valid JSON but not an object raises ``RuntimeError``."""
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
    stub = StubSocket(b"[]")
    client = VsockExecClient(socket_factory=lambda family, sock_type: stub)

    with pytest.raises(RuntimeError, match="JSON object"):
        client.exec(1234, 5005, "echo ok", 30)


def test_vsock_exec_client_uses_unix_bridge_when_vsock_is_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without ``AF_VSOCK``, ``exec`` goes through the UNIX socket at ``uds_path``."""
    monkeypatch.delattr(socket, "AF_VSOCK", raising=False)
    stub = StubSocket(
        [
            # NOTE(review): presumably the bridge's acknowledgement of the
            # CONNECT line asserted below -- confirm against the client.
            b"OK 1073746829\n",
            b'{"stdout":"ready\\n","stderr":"","exit_code":0,"duration_ms":5}',
        ]
    )

    def socket_factory(family: int, sock_type: int) -> StubSocket:
        assert family == socket.AF_UNIX
        assert sock_type == socket.SOCK_STREAM
        return stub

    client = VsockExecClient(socket_factory=socket_factory)
    response = client.exec(1234, 5005, "echo ready", 30, uds_path="/tmp/vsock.sock")

    assert response.stdout == "ready\n"
    assert stub.connected == "/tmp/vsock.sock"
    assert stub.sent.startswith(b"CONNECT 5005\n")


def test_vsock_exec_client_requires_transport_when_vsock_is_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without ``AF_VSOCK`` and without ``uds_path``, ``exec`` raises ``RuntimeError``."""
    monkeypatch.delattr(socket, "AF_VSOCK", raising=False)
    client = VsockExecClient(socket_factory=lambda family, sock_type: StubSocket(b""))

    with pytest.raises(RuntimeError, match="not supported"):
        client.exec(1234, 5005, "echo ok", 30)