103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
from __future__ import annotations
|
|
|
|
import socket
|
|
|
|
import pytest
|
|
|
|
from pyro_mcp.vm_guest import VsockExecClient
|
|
|
|
|
|
class StubSocket:
    """In-memory stand-in for a ``socket.socket`` used by the client tests.

    It records what the client did: the ``connect`` address, all bytes passed
    to ``sendall``, the last ``settimeout`` value, and whether ``close`` was
    called. It also replays canned ``responses`` from ``recv``.

    Each response is served as its own chunk. ``recv`` never returns bytes
    that span two responses, so a test can model several distinct server
    replies. Once every response has been consumed, ``recv`` returns ``b""``
    (end of stream).
    """

    def __init__(self, responses: list[bytes] | bytes) -> None:
        # Copy the sequence so that popping in recv() never mutates the
        # caller's list.
        if isinstance(responses, bytes):
            self.responses = [responses]
        else:
            self.responses = list(responses)
        self._buffer = b""  # unread remainder of the response being served
        # Address passed to connect(): a tuple for AF_VSOCK, a path for AF_UNIX.
        self.connected: object | None = None
        self.sent = b""  # concatenation of everything passed to sendall()
        self.timeout: float | None = None
        self.closed = False

    def settimeout(self, timeout: float | None) -> None:
        """Record the timeout the client configured."""
        self.timeout = timeout

    def connect(self, address: object) -> None:
        """Record the address; accepts a ``(x, y)`` tuple or a filesystem path."""
        self.connected = address

    def sendall(self, data: bytes) -> None:
        """Append *data* to the transcript of bytes sent by the client."""
        self.sent += data

    def recv(self, size: int) -> bytes:
        """Return up to *size* bytes of the current response, or ``b""`` at EOF."""
        # Only advance to the next canned response once the current one is drained.
        if not self._buffer and self.responses:
            self._buffer = self.responses.pop(0)
        if not self._buffer:
            return b""
        data = self._buffer[:size]
        self._buffer = self._buffer[size:]
        return data

    def close(self) -> None:
        """Mark the socket as closed so tests can assert it was released."""
        self.closed = True
|
|
|
|
|
|
def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
    """Native AF_VSOCK path: connect, send the command, parse the JSON reply."""
    # Force AF_VSOCK to exist so the client takes the native transport.
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
    canned_reply = b'{"stdout":"ok\\n","stderr":"","exit_code":0,"duration_ms":7}'
    fake_sock = StubSocket(canned_reply)

    def make_socket(family: int, sock_type: int) -> StubSocket:
        # The client must request a stream socket in the vsock family.
        assert family == socket.AF_VSOCK
        assert sock_type == socket.SOCK_STREAM
        return fake_sock

    client = VsockExecClient(socket_factory=make_socket)
    result = client.exec(1234, 5005, "echo ok", 30)

    # Parsed response fields.
    assert result.exit_code == 0
    assert result.stdout == "ok\n"
    # Socket-level side effects: address used, payload sent, socket released.
    assert fake_sock.connected == (1234, 5005)
    assert b'"command": "echo ok"' in fake_sock.sent
    assert fake_sock.closed is True
|
|
|
|
|
|
def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
    """A reply that parses as JSON but is not an object raises RuntimeError."""
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
    # ``[]`` is valid JSON, but it is an array rather than the expected object.
    fake_sock = StubSocket(b"[]")

    def make_socket(family: int, sock_type: int) -> StubSocket:
        return fake_sock

    client = VsockExecClient(socket_factory=make_socket)
    with pytest.raises(RuntimeError, match="JSON object"):
        client.exec(1234, 5005, "echo ok", 30)
|
|
|
|
|
|
def test_vsock_exec_client_uses_unix_bridge_when_vsock_is_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without AF_VSOCK, exec() talks to the Unix-socket bridge at ``uds_path``.

    The client is expected to send ``CONNECT <port>`` first. The stub replays
    an ``OK ...`` acknowledgement line followed by the JSON result.
    """
    # Simulate a platform whose socket module has no AF_VSOCK at all.
    monkeypatch.delattr(socket, "AF_VSOCK", raising=False)
    stub = StubSocket(
        [
            b"OK 1073746829\n",
            b'{"stdout":"ready\\n","stderr":"","exit_code":0,"duration_ms":5}',
        ]
    )

    def socket_factory(family: int, sock_type: int) -> StubSocket:
        # The fallback must use a Unix stream socket, not vsock.
        assert family == socket.AF_UNIX
        assert sock_type == socket.SOCK_STREAM
        return stub

    client = VsockExecClient(socket_factory=socket_factory)
    response = client.exec(1234, 5005, "echo ready", 30, uds_path="/tmp/vsock.sock")

    assert response.exit_code == 0
    assert response.stdout == "ready\n"
    assert stub.connected == "/tmp/vsock.sock"
    # The bridge handshake must precede the JSON request on the wire.
    assert stub.sent.startswith(b"CONNECT 5005\n")
    # Same guarantees as the native round-trip test: the command is sent and
    # the socket is released afterwards.
    assert b'"command": "echo ready"' in stub.sent
    assert stub.closed is True
|
|
|
|
|
|
def test_vsock_exec_client_requires_transport_when_vsock_is_unavailable(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Without AF_VSOCK and without a ``uds_path``, exec() must refuse to run."""
    # Simulate a platform whose socket module lacks AF_VSOCK entirely.
    monkeypatch.delattr(socket, "AF_VSOCK", raising=False)

    def make_socket(family: int, sock_type: int) -> StubSocket:
        # Hand back a fresh, empty stub on every call.
        return StubSocket(b"")

    client = VsockExecClient(socket_factory=make_socket)
    with pytest.raises(RuntimeError, match="not supported"):
        client.exec(1234, 5005, "echo ok", 30)
|