diff --git a/AGENTS.md b/AGENTS.md index 39af119..7fc9156 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,6 +14,7 @@ This repository ships `pyro-mcp`, an MCP-compatible package for ephemeral VM lif - Use `make demo` to validate deterministic VM lifecycle execution. - Use `make ollama-demo` to validate model-triggered lifecycle tool usage. - Use `make doctor` to inspect bundled runtime integrity and host prerequisites. +- If you need full log payloads from the Ollama demo, use `make ollama-demo OLLAMA_DEMO_FLAGS=-v`. ## Quality Gates @@ -27,6 +28,7 @@ These checks run in pre-commit hooks and should all pass locally. - Public factory: `pyro_mcp.create_server()` - Runtime diagnostics CLI: `pyro-mcp-doctor` +- Current bundled runtime is shim-based unless replaced with a real guest-capable bundle; check `make doctor` for runtime capabilities. - Lifecycle tools: - `vm_list_profiles` - `vm_create` @@ -35,3 +37,5 @@ These checks run in pre-commit hooks and should all pass locally. - `vm_stop` - `vm_delete` - `vm_status` + - `vm_network_info` + - `vm_reap_expired` diff --git a/Makefile b/Makefile index a452580..9f5dcb1 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ PYTHON ?= uv run python OLLAMA_BASE_URL ?= http://localhost:11434/v1 OLLAMA_MODEL ?= llama3.2:3b +OLLAMA_DEMO_FLAGS ?= .PHONY: setup lint format typecheck test check demo doctor ollama ollama-demo run-server install-hooks @@ -30,7 +31,7 @@ doctor: ollama: ollama-demo ollama-demo: - uv run pyro-mcp-ollama-demo --base-url "$(OLLAMA_BASE_URL)" --model "$(OLLAMA_MODEL)" + uv run pyro-mcp-ollama-demo --base-url "$(OLLAMA_BASE_URL)" --model "$(OLLAMA_MODEL)" $(OLLAMA_DEMO_FLAGS) run-server: uv run pyro-mcp-server diff --git a/README.md b/README.md index e8c3efa..32db90d 100644 --- a/README.md +++ b/README.md @@ -4,14 +4,14 @@ ## v0.1.0 Capabilities -- Split lifecycle tools for coding agents: `vm_list_profiles`, `vm_create`, `vm_start`, `vm_exec`, `vm_stop`, `vm_delete`, `vm_status`, `vm_reap_expired`. 
+- Split lifecycle tools for coding agents: `vm_list_profiles`, `vm_create`, `vm_start`, `vm_exec`, `vm_stop`, `vm_delete`, `vm_status`, `vm_network_info`, `vm_reap_expired`. - Standard environment profiles: - `debian-base`: minimal Debian shell/core Unix tools. - `debian-git`: Debian base with Git preinstalled. - `debian-build`: Debian Git profile with common build tooling. - Explicit sizing contract for agents (`vcpu_count`, `mem_mib`) with guardrails. - Strict ephemerality for command execution (`vm_exec` auto-deletes VM on completion). -- Ollama demo that asks an LLM to run `git --version` through lifecycle tools. +- Ollama demo that asks an LLM to clone a small public Git repository through lifecycle tools. ## Runtime @@ -22,6 +22,12 @@ The package includes a bundled Linux x86_64 runtime payload: No system Firecracker installation is required for basic usage. +Current limitation: +- The bundled runtime is currently shim-based. +- `doctor` reports runtime capabilities, and current bundles report no real guest boot, no guest exec agent, and no guest networking. +- Until a real guest-capable bundle is installed, `vm_exec` runs in `host_compat` mode rather than `guest_vsock`. +- This means demo commands can exercise lifecycle/control-plane behavior, but they are not yet proof of command execution inside a real VM guest. + Host requirements still apply: - Linux host - `/dev/kvm` available for full virtualization mode @@ -44,7 +50,9 @@ make setup make demo ``` -The demo creates a VM, starts it, runs `git --version`, and returns structured output. +The demo creates a VM, starts it, runs a command, and returns structured output. +If the runtime reports `guest_vsock` plus networking, it uses an internet probe. +Otherwise it falls back to a local compatibility command and the result will report `execution_mode=host_compat`. 
## Runtime doctor @@ -52,7 +60,21 @@ The demo creates a VM, starts it, runs `git --version`, and returns structured o make doctor ``` -This prints bundled runtime paths, profile availability, checksum validation status, and KVM host checks. +This prints bundled runtime paths, profile availability, checksum validation status, runtime capability flags, KVM host checks, and host networking diagnostics. + +## Networking + +- Host-side network allocation and diagnostics are implemented. +- The MCP server exposes `vm_network_info` for per-VM network metadata. +- Host TAP/NAT setup is opt-in with: + +```bash +PYRO_VM_ENABLE_NETWORK=1 make doctor +``` + +- Current limitation: + - network metadata and host preflight exist + - real in-guest outbound networking still depends on a non-shim runtime bundle with real guest boot and guest exec support ## Run Ollama lifecycle demo @@ -64,6 +86,18 @@ make ollama-demo Defaults are configured in `Makefile`. The demo streams lifecycle progress logs and ends with a short text summary. +The command it asks the model to run is a small public repository clone: + +```bash +rm -rf hello-world && git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world >/dev/null && git -C hello-world rev-parse --is-inside-work-tree +``` + +If the runtime is still shim-based, the summary will show `execution_mode=host_compat`. 
+By default it omits log values; to include prompt content, tool args, and tool results use: + +```bash +make ollama-demo OLLAMA_DEMO_FLAGS=-v +``` ## Run MCP server diff --git a/src/pyro_mcp/demo.py b/src/pyro_mcp/demo.py index 9e04277..833f8c5 100644 --- a/src/pyro_mcp/demo.py +++ b/src/pyro_mcp/demo.py @@ -7,6 +7,18 @@ from typing import Any from pyro_mcp.vm_manager import VmManager +INTERNET_PROBE_COMMAND = ( + 'python3 -c "import urllib.request; ' + "print(urllib.request.urlopen('https://example.com', timeout=10).status)" + '"' +) + + +def _demo_command(status: dict[str, Any]) -> str: + if bool(status.get("network_enabled")) and status.get("execution_mode") == "guest_vsock": + return INTERNET_PROBE_COMMAND + return "git --version" + def run_demo() -> dict[str, Any]: """Create/start/exec/delete a VM and return command output.""" @@ -14,7 +26,8 @@ def run_demo() -> dict[str, Any]: created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600) vm_id = str(created["vm_id"]) manager.start_vm(vm_id) - executed = manager.exec_vm(vm_id, command="git --version", timeout_seconds=30) + status = manager.status_vm(vm_id) + executed = manager.exec_vm(vm_id, command=_demo_command(status), timeout_seconds=30) return executed diff --git a/src/pyro_mcp/ollama_demo.py b/src/pyro_mcp/ollama_demo.py index 0f6ba73..449d69f 100644 --- a/src/pyro_mcp/ollama_demo.py +++ b/src/pyro_mcp/ollama_demo.py @@ -16,6 +16,12 @@ __all__ = ["VmManager", "run_ollama_tool_demo"] DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1" DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b" MAX_TOOL_ROUNDS: Final[int] = 12 +CLONE_TARGET_DIR: Final[str] = "hello-world" +NETWORK_PROOF_COMMAND: Final[str] = ( + "rm -rf hello-world " + "&& git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world >/dev/null " + "&& git -C hello-world rev-parse --is-inside-work-tree" +) TOOL_SPECS: Final[list[dict[str, Any]]] = [ { @@ -210,7 +216,7 @@ def 
_run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]: created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600) vm_id = str(created["vm_id"]) manager.start_vm(vm_id) - return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30) + return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60) def _is_vm_id_placeholder(value: str) -> bool: @@ -240,8 +246,10 @@ def _normalize_tool_arguments( return normalized_arguments, last_created_vm_id -def _summarize_message_for_log(message: dict[str, Any]) -> str: +def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str: role = str(message.get("role", "unknown")) + if not verbose: + return role content = str(message.get("content") or "").strip() if content == "": return f"{role}: " @@ -257,6 +265,7 @@ def run_ollama_tool_demo( model: str = DEFAULT_OLLAMA_MODEL, *, strict: bool = True, + verbose: bool = False, log: Callable[[str], None] | None = None, ) -> dict[str, Any]: """Ask Ollama to run git version check in an ephemeral VM through lifecycle tools.""" @@ -267,10 +276,12 @@ def run_ollama_tool_demo( { "role": "user", "content": ( - "Use the lifecycle tools to run `git --version` in an ephemeral VM.\n" + "Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n" "Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n" "Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` " "returned by vm_create into vm_start/vm_exec.\n" + f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n" + "Success means the clone completes and the command prints `true`.\n" "If a tool returns an error, fix arguments and retry." 
), } @@ -280,7 +291,7 @@ def run_ollama_tool_demo( last_created_vm_id: str | None = None for _round_index in range(1, MAX_TOOL_ROUNDS + 1): - emit(f"[model] input {_summarize_message_for_log(messages[-1])}") + emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}") response = _post_chat_completion( base_url, { @@ -292,7 +303,7 @@ def run_ollama_tool_demo( }, ) assistant_message = _extract_message(response) - emit(f"[model] output {_summarize_message_for_log(assistant_message)}") + emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}") tool_calls = assistant_message.get("tool_calls") if not isinstance(tool_calls, list) or not tool_calls: final_response = str(assistant_message.get("content") or "") @@ -320,15 +331,24 @@ def run_ollama_tool_demo( if not isinstance(tool_name, str): raise RuntimeError("tool call function name is invalid") arguments = _parse_tool_arguments(function.get("arguments")) - emit(f"[model] tool_call {tool_name} args={arguments}") + if verbose: + emit(f"[model] tool_call {tool_name} args={arguments}") + else: + emit(f"[model] tool_call {tool_name}") arguments, normalized_vm_id = _normalize_tool_arguments( tool_name, arguments, last_created_vm_id=last_created_vm_id, ) if normalized_vm_id is not None: - emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}") - emit(f"[tool] calling {tool_name} with args={arguments}") + if verbose: + emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}") + else: + emit("[tool] resolved vm_id placeholder") + if verbose: + emit(f"[tool] calling {tool_name} with args={arguments}") + else: + emit(f"[tool] calling {tool_name}") try: result = _dispatch_tool_call(manager, tool_name, arguments) success = True @@ -341,7 +361,10 @@ def run_ollama_tool_demo( result = _format_tool_error(tool_name, arguments, exc) success = False emit(f"[tool] {tool_name} failed: {exc}") - emit(f"[tool] result {tool_name} {_serialize_log_value(result)}") + if 
verbose: + emit(f"[tool] result {tool_name} {_serialize_log_value(result)}") + else: + emit(f"[tool] result {tool_name}") tool_events.append( { "tool_name": tool_name, @@ -379,7 +402,7 @@ def run_ollama_tool_demo( tool_events.append( { "tool_name": "vm_exec_fallback", - "arguments": {"command": "git --version"}, + "arguments": {"command": NETWORK_PROOF_COMMAND}, "result": exec_result, "success": True, } @@ -390,13 +413,13 @@ def run_ollama_tool_demo( raise RuntimeError("vm_exec result shape is invalid") if int(exec_result.get("exit_code", -1)) != 0: raise RuntimeError("vm_exec failed; expected exit_code=0") - if "git version" not in str(exec_result.get("stdout", "")): - raise RuntimeError("vm_exec output did not contain `git version`") + if str(exec_result.get("stdout", "")).strip() != "true": + raise RuntimeError("vm_exec output did not confirm repository clone success") emit("[done] command execution succeeded") return { "model": model, - "command": "git --version", + "command": NETWORK_PROOF_COMMAND, "exec_result": exec_result, "tool_events": tool_events, "final_response": final_response, @@ -408,6 +431,7 @@ def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.") parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL) parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL) + parser.add_argument("-v", "--verbose", action="store_true") return parser @@ -418,6 +442,7 @@ def main() -> None: result = run_ollama_tool_demo( base_url=args.base_url, model=args.model, + verbose=args.verbose, log=lambda message: print(message, flush=True), ) except Exception as exc: # noqa: BLE001 @@ -428,7 +453,9 @@ def main() -> None: raise RuntimeError("demo produced invalid execution result") print( f"[summary] exit_code={int(exec_result.get('exit_code', -1))} " - f"fallback_used={bool(result.get('fallback_used'))}", + f"fallback_used={bool(result.get('fallback_used'))} " + 
f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}", flush=True, ) - print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True) + if args.verbose: + print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True) diff --git a/src/pyro_mcp/runtime.py b/src/pyro_mcp/runtime.py index c955dfe..3a5187e 100644 --- a/src/pyro_mcp/runtime.py +++ b/src/pyro_mcp/runtime.py @@ -11,6 +11,8 @@ from dataclasses import dataclass from pathlib import Path from typing import Any +from pyro_mcp.vm_network import TapNetworkManager + DEFAULT_PLATFORM = "linux-x86_64" @@ -27,6 +29,16 @@ class RuntimePaths: manifest: dict[str, Any] +@dataclass(frozen=True) +class RuntimeCapabilities: + """Feature flags inferred from the bundled runtime.""" + + supports_vm_boot: bool + supports_guest_exec: bool + supports_guest_network: bool + reason: str | None = None + + def _sha256(path: Path) -> str: digest = hashlib.sha256() with path.open("rb") as fp: @@ -135,6 +147,40 @@ def resolve_runtime_paths( ) +def runtime_capabilities(paths: RuntimePaths) -> RuntimeCapabilities: + """Infer what the current bundled runtime can actually do.""" + binary_text = paths.firecracker_bin.read_text(encoding="utf-8", errors="ignore") + if "bundled firecracker shim" in binary_text: + return RuntimeCapabilities( + supports_vm_boot=False, + supports_guest_exec=False, + supports_guest_network=False, + reason="bundled runtime uses shim firecracker/jailer binaries", + ) + + capabilities = paths.manifest.get("capabilities") + if not isinstance(capabilities, dict): + return RuntimeCapabilities( + supports_vm_boot=False, + supports_guest_exec=False, + supports_guest_network=False, + reason="runtime manifest does not declare guest boot/exec/network capabilities", + ) + + supports_vm_boot = bool(capabilities.get("vm_boot")) + supports_guest_exec = bool(capabilities.get("guest_exec")) + supports_guest_network = bool(capabilities.get("guest_network")) + reason = None + 
if not supports_vm_boot: + reason = "runtime manifest does not advertise real VM boot support" + return RuntimeCapabilities( + supports_vm_boot=supports_vm_boot, + supports_guest_exec=supports_guest_exec, + supports_guest_network=supports_guest_network, + reason=reason, + ) + + def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]: """Build a runtime diagnostics report.""" report: dict[str, Any] = { @@ -146,13 +192,28 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]: "readable": os.access("/dev/kvm", os.R_OK), "writable": os.access("/dev/kvm", os.W_OK), }, + "networking": { + "enabled_by_default": TapNetworkManager().enabled, + }, } + network = TapNetworkManager.diagnostics() + report["networking"].update( + { + "tun_available": network.tun_available, + "ip_binary": network.ip_binary, + "nft_binary": network.nft_binary, + "iptables_binary": network.iptables_binary, + "ip_forward_enabled": network.ip_forward_enabled, + } + ) try: paths = resolve_runtime_paths(platform=platform, verify_checksums=True) except Exception as exc: # noqa: BLE001 report["issues"] = [str(exc)] return report + capabilities = runtime_capabilities(paths) + profiles = paths.manifest.get("profiles", {}) profile_names = sorted(profiles.keys()) if isinstance(profiles, dict) else [] report["runtime_ok"] = True @@ -165,6 +226,12 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]: "notice_path": str(paths.notice_path), "bundle_version": paths.manifest.get("bundle_version"), "profiles": profile_names, + "capabilities": { + "supports_vm_boot": capabilities.supports_vm_boot, + "supports_guest_exec": capabilities.supports_guest_exec, + "supports_guest_network": capabilities.supports_guest_network, + "reason": capabilities.reason, + }, } if not report["kvm"]["exists"]: report["issues"] = ["/dev/kvm is not available on this host"] diff --git a/src/pyro_mcp/server.py b/src/pyro_mcp/server.py index c61c374..8794f4a 100644 --- 
a/src/pyro_mcp/server.py +++ b/src/pyro_mcp/server.py @@ -59,6 +59,11 @@ def create_server(manager: VmManager | None = None) -> FastMCP: """Get the current state and metadata for a VM.""" return vm_manager.status_vm(vm_id) + @server.tool() + async def vm_network_info(vm_id: str) -> dict[str, Any]: + """Get the current network configuration assigned to a VM.""" + return vm_manager.network_info_vm(vm_id) + @server.tool() async def vm_reap_expired() -> dict[str, Any]: """Delete VMs whose TTL has expired.""" diff --git a/src/pyro_mcp/vm_firecracker.py b/src/pyro_mcp/vm_firecracker.py new file mode 100644 index 0000000..f522d5c --- /dev/null +++ b/src/pyro_mcp/vm_firecracker.py @@ -0,0 +1,109 @@ +"""Firecracker launch-plan generation for pyro VMs.""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Protocol + +DEFAULT_GUEST_CID_OFFSET = 10_000 +DEFAULT_VSOCK_PORT = 5005 + + +@dataclass(frozen=True) +class FirecrackerLaunchPlan: + """On-disk config artifacts needed to boot a guest.""" + + api_socket_path: Path + config_path: Path + guest_network_path: Path + guest_exec_path: Path + guest_cid: int + vsock_port: int + config: dict[str, Any] + + +class VmInstanceLike(Protocol): + vm_id: str + vcpu_count: int + mem_mib: int + workdir: Path + metadata: dict[str, str] + network: Any + + +def _guest_cid(vm_id: str) -> int: + return DEFAULT_GUEST_CID_OFFSET + int(vm_id[:4], 16) + + +def build_launch_plan(instance: VmInstanceLike) -> FirecrackerLaunchPlan: + guest_cid = _guest_cid(instance.vm_id) + vsock_port = DEFAULT_VSOCK_PORT + api_socket_path = instance.workdir / "firecracker.sock" + config_path = instance.workdir / "firecracker-config.json" + guest_network_path = instance.workdir / "guest-network.json" + guest_exec_path = instance.workdir / "guest-exec.json" + + config: dict[str, Any] = { + "boot-source": { + "kernel_image_path": instance.metadata["kernel_image"], + "boot_args": 
"console=ttyS0 reboot=k panic=1 pci=off", + }, + "drives": [ + { + "drive_id": "rootfs", + "path_on_host": instance.metadata["rootfs_image"], + "is_root_device": True, + "is_read_only": False, + } + ], + "machine-config": { + "vcpu_count": instance.vcpu_count, + "mem_size_mib": instance.mem_mib, + "smt": False, + }, + "network-interfaces": [], + "vsock": { + "guest_cid": guest_cid, + "uds_path": str(instance.workdir / "vsock.sock"), + }, + } + + if instance.network is not None: + config["network-interfaces"] = [ + { + "iface_id": "eth0", + "guest_mac": instance.network.mac_address, + "host_dev_name": instance.network.tap_name, + } + ] + + guest_network = { + "guest_ip": instance.network.guest_ip if instance.network is not None else None, + "gateway_ip": instance.network.gateway_ip if instance.network is not None else None, + "subnet_cidr": instance.network.subnet_cidr if instance.network is not None else None, + "dns_servers": list(instance.network.dns_servers) if instance.network is not None else [], + "tap_name": instance.network.tap_name if instance.network is not None else None, + } + guest_exec = { + "transport": "vsock", + "guest_cid": guest_cid, + "port": vsock_port, + } + + config_path.write_text(json.dumps(config, indent=2, sort_keys=True), encoding="utf-8") + guest_network_path.write_text( + json.dumps(guest_network, indent=2, sort_keys=True), encoding="utf-8" + ) + guest_exec_path.write_text(json.dumps(guest_exec, indent=2, sort_keys=True), encoding="utf-8") + + return FirecrackerLaunchPlan( + api_socket_path=api_socket_path, + config_path=config_path, + guest_network_path=guest_network_path, + guest_exec_path=guest_exec_path, + guest_cid=guest_cid, + vsock_port=vsock_port, + config=config, + ) diff --git a/src/pyro_mcp/vm_guest.py b/src/pyro_mcp/vm_guest.py new file mode 100644 index 0000000..7d9a13b --- /dev/null +++ b/src/pyro_mcp/vm_guest.py @@ -0,0 +1,72 @@ +"""Guest command transport over vsock-compatible JSON protocol.""" + +from __future__ import 
annotations + +import json +import socket +from dataclasses import dataclass +from typing import Callable, Protocol + + +class SocketLike(Protocol): + def settimeout(self, timeout: int) -> None: ... + + def connect(self, address: tuple[int, int]) -> None: ... + + def sendall(self, data: bytes) -> None: ... + + def recv(self, size: int) -> bytes: ... + + def close(self) -> None: ... + + +SocketFactory = Callable[[int, int], SocketLike] + + +@dataclass(frozen=True) +class GuestExecResponse: + stdout: str + stderr: str + exit_code: int + duration_ms: int + + +class VsockExecClient: + """Minimal JSON-over-stream client for a guest exec agent.""" + + def __init__(self, socket_factory: SocketFactory | None = None) -> None: + self._socket_factory = socket_factory or socket.socket + + def exec( + self, guest_cid: int, port: int, command: str, timeout_seconds: int + ) -> GuestExecResponse: + request = { + "command": command, + "timeout_seconds": timeout_seconds, + } + family = getattr(socket, "AF_VSOCK", None) + if family is None: + raise RuntimeError("vsock sockets are not supported on this host Python runtime") + sock = self._socket_factory(family, socket.SOCK_STREAM) + try: + sock.settimeout(timeout_seconds) + sock.connect((guest_cid, port)) + sock.sendall((json.dumps(request) + "\n").encode("utf-8")) + chunks: list[bytes] = [] + while True: + data = sock.recv(65536) + if data == b"": + break + chunks.append(data) + finally: + sock.close() + + payload = json.loads(b"".join(chunks).decode("utf-8")) + if not isinstance(payload, dict): + raise RuntimeError("guest exec response must be a JSON object") + return GuestExecResponse( + stdout=str(payload.get("stdout", "")), + stderr=str(payload.get("stderr", "")), + exit_code=int(payload.get("exit_code", -1)), + duration_ms=int(payload.get("duration_ms", 0)), + ) diff --git a/src/pyro_mcp/vm_manager.py b/src/pyro_mcp/vm_manager.py index 53c20cf..4e38464 100644 --- a/src/pyro_mcp/vm_manager.py +++ b/src/pyro_mcp/vm_manager.py @@ 
-12,7 +12,15 @@ from dataclasses import dataclass, field from pathlib import Path from typing import Any, Literal -from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths +from pyro_mcp.runtime import ( + RuntimeCapabilities, + RuntimePaths, + resolve_runtime_paths, + runtime_capabilities, +) +from pyro_mcp.vm_firecracker import build_launch_plan +from pyro_mcp.vm_guest import VsockExecClient +from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts VmState = Literal["created", "started", "stopped"] @@ -34,6 +42,7 @@ class VmInstance: firecracker_pid: int | None = None last_error: str | None = None metadata: dict[str, str] = field(default_factory=dict) + network: NetworkConfig | None = None @dataclass(frozen=True) @@ -119,10 +128,21 @@ class MockBackend(VmBackend): class FirecrackerBackend(VmBackend): # pragma: no cover """Host-gated backend that validates Firecracker prerequisites.""" - def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None: + def __init__( + self, + artifacts_dir: Path, + firecracker_bin: Path, + jailer_bin: Path, + runtime_capabilities: RuntimeCapabilities, + network_manager: TapNetworkManager | None = None, + guest_exec_client: VsockExecClient | None = None, + ) -> None: self._artifacts_dir = artifacts_dir self._firecracker_bin = firecracker_bin self._jailer_bin = jailer_bin + self._runtime_capabilities = runtime_capabilities + self._network_manager = network_manager or TapNetworkManager() + self._guest_exec_client = guest_exec_client or VsockExecClient() if not self._firecracker_bin.exists(): raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}") if not self._jailer_bin.exists(): @@ -132,16 +152,29 @@ class FirecrackerBackend(VmBackend): # pragma: no cover def create(self, instance: VmInstance) -> None: instance.workdir.mkdir(parents=True, exist_ok=False) - artifacts = 
resolve_artifacts(self._artifacts_dir, instance.profile) - if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists(): - raise RuntimeError( - f"missing profile artifacts for {instance.profile}; expected " - f"{artifacts.kernel_image} and {artifacts.rootfs_image}" - ) - instance.metadata["kernel_image"] = str(artifacts.kernel_image) - instance.metadata["rootfs_image"] = str(artifacts.rootfs_image) + try: + artifacts = resolve_artifacts(self._artifacts_dir, instance.profile) + if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists(): + raise RuntimeError( + f"missing profile artifacts for {instance.profile}; expected " + f"{artifacts.kernel_image} and {artifacts.rootfs_image}" + ) + instance.metadata["kernel_image"] = str(artifacts.kernel_image) + instance.metadata["rootfs_image"] = str(artifacts.rootfs_image) + network = self._network_manager.allocate(instance.vm_id) + instance.network = network + instance.metadata.update(self._network_manager.to_metadata(network)) + except Exception: + shutil.rmtree(instance.workdir, ignore_errors=True) + raise def start(self, instance: VmInstance) -> None: + launch_plan = build_launch_plan(instance) + instance.metadata["firecracker_config_path"] = str(launch_plan.config_path) + instance.metadata["guest_network_path"] = str(launch_plan.guest_network_path) + instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path) + instance.metadata["guest_cid"] = str(launch_plan.guest_cid) + instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port) proc = subprocess.run( # noqa: S603 [str(self._firecracker_bin), "--version"], text=True, @@ -152,15 +185,35 @@ class FirecrackerBackend(VmBackend): # pragma: no cover raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}") instance.metadata["firecracker_version"] = proc.stdout.strip() instance.metadata["jailer_path"] = str(self._jailer_bin) + if not self._runtime_capabilities.supports_vm_boot: + 
instance.metadata["execution_mode"] = "host_compat" + instance.metadata["boot_mode"] = "shim" + if self._runtime_capabilities.reason is not None: + instance.metadata["runtime_reason"] = self._runtime_capabilities.reason + return + instance.metadata["execution_mode"] = "guest_vsock" + instance.metadata["boot_mode"] = "native" def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult: - # Temporary compatibility path until guest-side execution agent is integrated. + if self._runtime_capabilities.supports_guest_exec: + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + response = self._guest_exec_client.exec(guest_cid, port, command, timeout_seconds) + return VmExecResult( + stdout=response.stdout, + stderr=response.stderr, + exit_code=response.exit_code, + duration_ms=response.duration_ms, + ) + instance.metadata["execution_mode"] = "host_compat" return _run_host_command(instance.workdir, command, timeout_seconds) def stop(self, instance: VmInstance) -> None: del instance def delete(self, instance: VmInstance) -> None: + if instance.network is not None: + self._network_manager.cleanup(instance.network) shutil.rmtree(instance.workdir, ignore_errors=True) @@ -182,6 +235,7 @@ class VmManager: artifacts_dir: Path | None = None, max_active_vms: int = 4, runtime_paths: RuntimePaths | None = None, + network_manager: TapNetworkManager | None = None, ) -> None: self._backend_name = backend_name or "firecracker" self._base_dir = base_dir or Path("/tmp/pyro-mcp") @@ -189,11 +243,19 @@ class VmManager: if self._backend_name == "firecracker": self._runtime_paths = self._runtime_paths or resolve_runtime_paths() self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir + self._runtime_capabilities = runtime_capabilities(self._runtime_paths) else: self._artifacts_dir = artifacts_dir or Path( os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts") ) + self._runtime_capabilities 
= RuntimeCapabilities( + supports_vm_boot=False, + supports_guest_exec=False, + supports_guest_network=False, + reason="mock backend does not boot a guest", + ) self._max_active_vms = max_active_vms + self._network_manager = network_manager or TapNetworkManager() self._lock = threading.Lock() self._instances: dict[str, VmInstance] = {} self._base_dir.mkdir(parents=True, exist_ok=True) @@ -209,6 +271,8 @@ class VmManager: self._artifacts_dir, firecracker_bin=self._runtime_paths.firecracker_bin, jailer_bin=self._runtime_paths.jailer_bin, + runtime_capabilities=self._runtime_capabilities, + network_manager=self._network_manager, ) raise ValueError("invalid backend; expected one of: mock, firecracker") @@ -262,6 +326,7 @@ class VmManager: if instance.state != "started": raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec") exec_result = self._backend.exec(instance, command, timeout_seconds) + execution_mode = instance.metadata.get("execution_mode", "host_compat") cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup") return { "vm_id": vm_id, @@ -270,6 +335,7 @@ class VmManager: "stderr": exec_result.stderr, "exit_code": exec_result.exit_code, "duration_ms": exec_result.duration_ms, + "execution_mode": execution_mode, "cleanup": cleanup, } @@ -296,6 +362,19 @@ class VmManager: self._ensure_not_expired_locked(instance, time.time()) return self._serialize(instance) + def network_info_vm(self, vm_id: str) -> dict[str, Any]: + with self._lock: + instance = self._get_instance_locked(vm_id) + self._ensure_not_expired_locked(instance, time.time()) + if instance.network is None: + return { + "vm_id": vm_id, + "network_enabled": False, + "outbound_connectivity_expected": False, + "reason": "network configuration is unavailable for this VM", + } + return {"vm_id": vm_id, **self._network_manager.network_info(instance.network)} + def reap_expired(self) -> dict[str, Any]: now = time.time() with self._lock: @@ -331,6 +410,10 @@ class VmManager: 
"created_at": instance.created_at, "expires_at": instance.expires_at, "state": instance.state, + "network_enabled": instance.network is not None, + "guest_ip": instance.network.guest_ip if instance.network is not None else None, + "tap_name": instance.network.tap_name if instance.network is not None else None, + "execution_mode": instance.metadata.get("execution_mode", "host_compat"), "metadata": instance.metadata, } diff --git a/src/pyro_mcp/vm_network.py b/src/pyro_mcp/vm_network.py new file mode 100644 index 0000000..f24b49f --- /dev/null +++ b/src/pyro_mcp/vm_network.py @@ -0,0 +1,191 @@ +"""Host-side network allocation and diagnostics for Firecracker VMs.""" + +from __future__ import annotations + +import os +import shutil +import subprocess +from dataclasses import dataclass +from pathlib import Path +from typing import Callable + +CommandRunner = Callable[[list[str]], subprocess.CompletedProcess[str]] + +DEFAULT_NETWORK_PREFIX = "172.29." +DEFAULT_GATEWAY_SUFFIX = 1 +DEFAULT_GUEST_SUFFIX = 2 + + +@dataclass(frozen=True) +class NetworkConfig: + vm_id: str + tap_name: str + guest_ip: str + gateway_ip: str + subnet_cidr: str + mac_address: str + dns_servers: tuple[str, ...] 
= ("1.1.1.1", "8.8.8.8") + + +@dataclass(frozen=True) +class NetworkDiagnostics: + tun_available: bool + ip_binary: str | None + nft_binary: str | None + iptables_binary: str | None + ip_forward_enabled: bool + + +class TapNetworkManager: + """Allocate per-VM TAP/NAT configuration on the host.""" + + def __init__( + self, + *, + enabled: bool | None = None, + runner: CommandRunner | None = None, + ) -> None: + if enabled is None: + self._enabled = os.environ.get("PYRO_VM_ENABLE_NETWORK") == "1" + else: + self._enabled = enabled + self._runner = runner or self._run + + @staticmethod + def diagnostics() -> NetworkDiagnostics: + ip_forward = False + try: + ip_forward = ( + Path("/proc/sys/net/ipv4/ip_forward").read_text(encoding="utf-8").strip() + == "1" + ) + except OSError: + ip_forward = False + return NetworkDiagnostics( + tun_available=Path("/dev/net/tun").exists(), + ip_binary=shutil.which("ip"), + nft_binary=shutil.which("nft"), + iptables_binary=shutil.which("iptables"), + ip_forward_enabled=ip_forward, + ) + + @property + def enabled(self) -> bool: + return self._enabled + + def allocate(self, vm_id: str) -> NetworkConfig: + octet = int(vm_id[:2], 16) + third_octet = 1 + (octet % 200) + subnet_cidr = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.0/24" + gateway_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GATEWAY_SUFFIX}" + guest_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GUEST_SUFFIX}" + tap_name = f"pyro{vm_id[:8]}" + mac_address = f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}" + config = NetworkConfig( + vm_id=vm_id, + tap_name=tap_name, + guest_ip=guest_ip, + gateway_ip=gateway_ip, + subnet_cidr=subnet_cidr, + mac_address=mac_address, + ) + if self._enabled: + self._ensure_host_network(config) + return config + + def cleanup(self, config: NetworkConfig) -> None: + if not self._enabled: + return + table_name = self._nft_table_name(config.vm_id) + self._run_ignore(["nft", "delete", "table", "ip", table_name]) + 
self._run_ignore(["ip", "link", "del", config.tap_name]) + + def to_metadata(self, config: NetworkConfig) -> dict[str, str]: + return { + "network_enabled": "true" if self._enabled else "false", + "tap_name": config.tap_name, + "guest_ip": config.guest_ip, + "gateway_ip": config.gateway_ip, + "subnet_cidr": config.subnet_cidr, + "mac_address": config.mac_address, + "dns_servers": ",".join(config.dns_servers), + } + + def network_info(self, config: NetworkConfig) -> dict[str, object]: + return { + "network_enabled": self._enabled, + "tap_name": config.tap_name, + "guest_ip": config.guest_ip, + "gateway_ip": config.gateway_ip, + "subnet_cidr": config.subnet_cidr, + "mac_address": config.mac_address, + "dns_servers": list(config.dns_servers), + "outbound_connectivity_expected": self._enabled, + } + + def _ensure_host_network(self, config: NetworkConfig) -> None: + diagnostics = self.diagnostics() + if not diagnostics.tun_available: + raise RuntimeError("/dev/net/tun is not available on this host") + if diagnostics.ip_binary is None: + raise RuntimeError("`ip` binary is required for TAP setup") + if diagnostics.nft_binary is None: + raise RuntimeError("`nft` binary is required for outbound NAT setup") + if not diagnostics.ip_forward_enabled: + raise RuntimeError("IPv4 forwarding is disabled on this host") + + self._runner(["ip", "tuntap", "add", "dev", config.tap_name, "mode", "tap"]) + self._runner(["ip", "addr", "add", f"{config.gateway_ip}/24", "dev", config.tap_name]) + self._runner(["ip", "link", "set", config.tap_name, "up"]) + + table_name = self._nft_table_name(config.vm_id) + self._runner(["nft", "add", "table", "ip", table_name]) + self._runner( + [ + "nft", + "add", + "chain", + "ip", + table_name, + "postrouting", + "{", + "type", + "nat", + "hook", + "postrouting", + "priority", + "srcnat", + ";", + "}", + ] + ) + self._runner([ + "nft", + "add", + "rule", + "ip", + table_name, + "postrouting", + "ip", + "saddr", + config.subnet_cidr, + "masquerade", + ]) 
+ + def _run_ignore(self, command: list[str]) -> None: + try: + self._runner(command) + except RuntimeError: + return + + @staticmethod + def _nft_table_name(vm_id: str) -> str: + return f"pyro_{vm_id[:12]}" + + @staticmethod + def _run(command: list[str]) -> subprocess.CompletedProcess[str]: + completed = subprocess.run(command, text=True, capture_output=True, check=False) + if completed.returncode != 0: + stderr = completed.stderr.strip() or completed.stdout.strip() + raise RuntimeError(f"command {' '.join(command)!r} failed: {stderr}") + return completed diff --git a/tests/test_demo.py b/tests/test_demo.py index d9b247a..6947042 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -40,6 +40,10 @@ def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None: calls.append(("start_vm", {"vm_id": vm_id})) return {"vm_id": vm_id} + def status_vm(self, vm_id: str) -> dict[str, Any]: + calls.append(("status_vm", {"vm_id": vm_id})) + return {"vm_id": vm_id, "network_enabled": False, "execution_mode": "host_compat"} + def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]: calls.append( ( @@ -55,7 +59,13 @@ def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None: assert result["exit_code"] == 0 assert calls[0][0] == "create_vm" assert calls[1] == ("start_vm", {"vm_id": "vm-1"}) - assert calls[2][0] == "exec_vm" + assert calls[2] == ("status_vm", {"vm_id": "vm-1"}) + assert calls[3][0] == "exec_vm" + + +def test_demo_command_prefers_network_probe_for_guest_vsock() -> None: + status = {"network_enabled": True, "execution_mode": "guest_vsock"} + assert "https://example.com" in demo_module._demo_command(status) def test_main_prints_json( diff --git a/tests/test_ollama_demo.py b/tests/test_ollama_demo.py index 1e61040..82b1235 100644 --- a/tests/test_ollama_demo.py +++ b/tests/test_ollama_demo.py @@ -94,7 +94,7 @@ def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, An "arguments": 
json.dumps( { "vm_id": vm_id, - "command": "printf 'git version 2.44.0\\n'", + "command": "printf 'true\\n'", } ), }, @@ -125,14 +125,14 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non result = ollama_demo.run_ollama_tool_demo(log=logs.append) assert result["fallback_used"] is False - assert "git version" in str(result["exec_result"]["stdout"]) + assert str(result["exec_result"]["stdout"]).strip() == "true" assert result["final_response"] == "Executed git command in ephemeral VM." assert len(result["tool_events"]) == 4 - assert any("[model] input user:" in line for line in logs) - assert any("[model] output assistant:" in line for line in logs) + assert any(line == "[model] input user" for line in logs) + assert any(line == "[model] output assistant" for line in logs) assert any("[model] tool_call vm_exec" in line for line in logs) - assert any("[tool] calling vm_exec" in line for line in logs) - assert any("[tool] result vm_exec " in line for line in logs) + assert any(line == "[tool] calling vm_exec" for line in logs) + assert any(line == "[tool] result vm_exec" for line in logs) def test_run_ollama_tool_demo_recovers_from_bad_vm_id( @@ -158,7 +158,7 @@ def test_run_ollama_tool_demo_recovers_from_bad_vm_id( "arguments": json.dumps( { "vm_id": "vm_list_profiles", - "command": "git --version", + "command": ollama_demo.NETWORK_PROOF_COMMAND, } ), }, @@ -219,7 +219,7 @@ def test_run_ollama_tool_demo_resolves_vm_id_placeholder( "arguments": json.dumps( { "vm_id": "", - "command": "printf 'git version 2.44.0\\n'", + "command": "printf 'true\\n'", "timeout_seconds": "300", } ), @@ -292,14 +292,49 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict( monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager) + monkeypatch.setattr( + ollama_demo, + "_run_direct_lifecycle_fallback", + lambda manager: { + "vm_id": "vm-1", + "command": 
ollama_demo.NETWORK_PROOF_COMMAND, + "stdout": "true\n", + "stderr": "", + "exit_code": 0, + "duration_ms": 5, + "execution_mode": "host_compat", + "cleanup": {"deleted": True, "reason": "post_exec_cleanup", "vm_id": "vm-1"}, + }, + ) logs: list[str] = [] result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append) assert result["fallback_used"] is True assert int(result["exec_result"]["exit_code"]) == 0 - assert any("[model] output assistant: No tools" in line for line in logs) + assert any(line == "[model] output assistant" for line in logs) assert any("[fallback]" in line for line in logs) +def test_run_ollama_tool_demo_verbose_logs_values(monkeypatch: pytest.MonkeyPatch) -> None: + requests = 0 + + def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: + del base_url + nonlocal requests + requests += 1 + return _stepwise_model_response(payload, requests) + + monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) + + logs: list[str] = [] + result = ollama_demo.run_ollama_tool_demo(verbose=True, log=logs.append) + + assert result["fallback_used"] is False + assert str(result["exec_result"]["stdout"]).strip() == "true" + assert any("[model] input user:" in line for line in logs) + assert any("[model] tool_call vm_list_profiles args={}" in line for line in logs) + assert any("[tool] result vm_exec " in line for line in logs) + + @pytest.mark.parametrize( ("tool_call", "error"), [ @@ -346,8 +381,8 @@ def test_run_ollama_tool_demo_max_rounds(monkeypatch: pytest.MonkeyPatch) -> Non ("exec_result", "error"), [ ("bad", "result shape is invalid"), - ({"exit_code": 1, "stdout": "git version 2"}, "expected exit_code=0"), - ({"exit_code": 0, "stdout": "no git"}, "did not contain `git version`"), + ({"exit_code": 1, "stdout": "true"}, "expected exit_code=0"), + ({"exit_code": 0, "stdout": "false"}, "did not confirm repository clone success"), ], ) def test_run_ollama_tool_demo_exec_result_validation( 
@@ -404,7 +439,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None: executed = ollama_demo._dispatch_tool_call( manager, "vm_exec", - {"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"}, + {"vm_id": vm_id, "command": "printf 'true\\n'", "timeout_seconds": "30"}, ) assert int(executed["exit_code"]) == 0 with pytest.raises(RuntimeError, match="unexpected tool requested by model"): @@ -529,6 +564,13 @@ def test_build_parser_defaults() -> None: args = parser.parse_args([]) assert args.model == ollama_demo.DEFAULT_OLLAMA_MODEL assert args.base_url == ollama_demo.DEFAULT_OLLAMA_BASE_URL + assert args.verbose is False + + +def test_build_parser_verbose_flag() -> None: + parser = ollama_demo._build_parser() + args = parser.parse_args(["-v"]) + assert args.verbose is True def test_main_uses_parser_and_prints_logs( @@ -537,21 +579,51 @@ def test_main_uses_parser_and_prints_logs( ) -> None: class StubParser: def parse_args(self) -> argparse.Namespace: - return argparse.Namespace(base_url="http://x", model="m") + return argparse.Namespace(base_url="http://x", model="m", verbose=False) monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser()) monkeypatch.setattr( ollama_demo, "run_ollama_tool_demo", - lambda base_url, model, strict=True, log=None: { - "exec_result": {"exit_code": 0, "stdout": "git version 2.44.0\n"}, + lambda base_url, model, strict=True, verbose=False, log=None: { + "exec_result": { + "exit_code": 0, + "stdout": "true\n", + "execution_mode": "host_compat", + }, "fallback_used": False, }, ) ollama_demo.main() output = capsys.readouterr().out - assert "[summary] exit_code=0 fallback_used=False" in output - assert "[summary] stdout=git version 2.44.0" in output + assert "[summary] exit_code=0 fallback_used=False execution_mode=host_compat" in output + assert "[summary] stdout=" not in output + + +def test_main_verbose_prints_stdout( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) 
-> None: + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace(base_url="http://x", model="m", verbose=True) + + monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser()) + monkeypatch.setattr( + ollama_demo, + "run_ollama_tool_demo", + lambda base_url, model, strict=True, verbose=False, log=None: { + "exec_result": { + "exit_code": 0, + "stdout": "true\n", + "execution_mode": "host_compat", + }, + "fallback_used": False, + }, + ) + ollama_demo.main() + output = capsys.readouterr().out + assert "[summary] stdout=true" in output def test_main_logs_error_and_exits_nonzero( @@ -560,12 +632,18 @@ def test_main_logs_error_and_exits_nonzero( ) -> None: class StubParser: def parse_args(self) -> argparse.Namespace: - return argparse.Namespace(base_url="http://x", model="m") + return argparse.Namespace(base_url="http://x", model="m", verbose=False) monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser()) - def fake_run(base_url: str, model: str, strict: bool = True, log: Any = None) -> dict[str, Any]: - del base_url, model, strict, log + def fake_run( + base_url: str, + model: str, + strict: bool = True, + verbose: bool = False, + log: Any = None, + ) -> dict[str, Any]: + del base_url, model, strict, verbose, log raise RuntimeError("demo did not execute a successful vm_exec") monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 096ed25..4ba5972 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -5,7 +5,7 @@ from pathlib import Path import pytest -from pyro_mcp.runtime import doctor_report, resolve_runtime_paths +from pyro_mcp.runtime import doctor_report, resolve_runtime_paths, runtime_capabilities def test_resolve_runtime_paths_default_bundle() -> None: @@ -67,7 +67,19 @@ def test_doctor_report_has_runtime_fields() -> None: report = doctor_report() assert "runtime_ok" in report assert "kvm" in report + assert 
"networking" in report if report["runtime_ok"]: runtime = report.get("runtime") assert isinstance(runtime, dict) assert "firecracker_bin" in runtime + networking = report["networking"] + assert isinstance(networking, dict) + assert "tun_available" in networking + + +def test_runtime_capabilities_reports_shim_bundle() -> None: + paths = resolve_runtime_paths() + capabilities = runtime_capabilities(paths) + assert capabilities.supports_vm_boot is False + assert capabilities.supports_guest_exec is False + assert capabilities.supports_guest_network is False diff --git a/tests/test_server.py b/tests/test_server.py index dd0ecd0..ba38313 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -9,10 +9,15 @@ import pytest import pyro_mcp.server as server_module from pyro_mcp.server import create_server from pyro_mcp.vm_manager import VmManager +from pyro_mcp.vm_network import TapNetworkManager def test_create_server_registers_vm_tools(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) async def _run() -> list[str]: server = create_server(manager=manager) @@ -23,11 +28,16 @@ def test_create_server_registers_vm_tools(tmp_path: Path) -> None: assert "vm_create" in tool_names assert "vm_exec" in tool_names assert "vm_list_profiles" in tool_names + assert "vm_network_info" in tool_names assert "vm_status" in tool_names def test_vm_tools_lifecycle_round_trip(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) def _extract_structured(raw_result: object) -> dict[str, Any]: if not isinstance(raw_result, tuple) or len(raw_result) != 2: @@ -60,7 +70,11 @@ def test_vm_tools_lifecycle_round_trip(tmp_path: Path) -> None: def 
test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) manager.MIN_TTL_SECONDS = 1 def _extract_structured(raw_result: object) -> dict[str, Any]: @@ -72,7 +86,12 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None: return cast(dict[str, Any], structured) async def _run() -> tuple[ - dict[str, Any], dict[str, Any], dict[str, Any], list[dict[str, object]], dict[str, Any] + dict[str, Any], + dict[str, Any], + dict[str, Any], + dict[str, Any], + list[dict[str, object]], + dict[str, Any], ]: server = create_server(manager=manager) profiles_raw = await server.call_tool("vm_list_profiles", {}) @@ -93,6 +112,7 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None: vm_id = str(created["vm_id"]) await server.call_tool("vm_start", {"vm_id": vm_id}) status = _extract_structured(await server.call_tool("vm_status", {"vm_id": vm_id})) + network = _extract_structured(await server.call_tool("vm_network_info", {"vm_id": vm_id})) stopped = _extract_structured(await server.call_tool("vm_stop", {"vm_id": vm_id})) deleted = _extract_structured(await server.call_tool("vm_delete", {"vm_id": vm_id})) @@ -105,10 +125,18 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None: expiring_id = str(expiring["vm_id"]) manager._instances[expiring_id].expires_at = 0.0 # noqa: SLF001 reaped = _extract_structured(await server.call_tool("vm_reap_expired", {})) - return status, stopped, deleted, cast(list[dict[str, object]], raw_profiles), reaped + return ( + status, + network, + stopped, + deleted, + cast(list[dict[str, object]], raw_profiles), + reaped, + ) - status, stopped, deleted, profiles, reaped = asyncio.run(_run()) + status, network, stopped, deleted, profiles, reaped = asyncio.run(_run()) assert status["state"] == "started" + 
assert network["network_enabled"] is False assert stopped["state"] == "stopped" assert bool(deleted["deleted"]) is True assert profiles[0]["name"] == "debian-base" diff --git a/tests/test_vm_firecracker.py b/tests/test_vm_firecracker.py new file mode 100644 index 0000000..7cbc22b --- /dev/null +++ b/tests/test_vm_firecracker.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +from pyro_mcp.vm_firecracker import build_launch_plan +from pyro_mcp.vm_network import NetworkConfig + + +@dataclass +class StubInstance: + vm_id: str + vcpu_count: int + mem_mib: int + workdir: Path + metadata: dict[str, str] = field(default_factory=dict) + network: Any = None + + +def test_build_launch_plan_writes_expected_files(tmp_path: Path) -> None: + instance = StubInstance( + vm_id="abcdef123456", + vcpu_count=2, + mem_mib=2048, + workdir=tmp_path, + metadata={ + "kernel_image": "/bundle/profiles/debian-git/vmlinux", + "rootfs_image": "/bundle/profiles/debian-git/rootfs.ext4", + }, + network=NetworkConfig( + vm_id="abcdef123456", + tap_name="pyroabcdef12", + guest_ip="172.29.100.2", + gateway_ip="172.29.100.1", + subnet_cidr="172.29.100.0/24", + mac_address="06:00:ab:cd:ef:12", + ), + ) + + plan = build_launch_plan(instance) + + assert plan.config_path.exists() + assert plan.guest_network_path.exists() + assert plan.guest_exec_path.exists() + rendered = json.loads(plan.config_path.read_text(encoding="utf-8")) + assert rendered["machine-config"]["vcpu_count"] == 2 + assert rendered["network-interfaces"][0]["host_dev_name"] == "pyroabcdef12" + guest_exec = json.loads(plan.guest_exec_path.read_text(encoding="utf-8")) + assert guest_exec["transport"] == "vsock" diff --git a/tests/test_vm_guest.py b/tests/test_vm_guest.py new file mode 100644 index 0000000..32e629b --- /dev/null +++ b/tests/test_vm_guest.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import socket + 
+import pytest + +from pyro_mcp.vm_guest import VsockExecClient + + +class StubSocket: + def __init__(self, response: bytes) -> None: + self.response = response + self.connected: tuple[int, int] | None = None + self.sent = b"" + self.timeout: int | None = None + self.closed = False + + def settimeout(self, timeout: int) -> None: + self.timeout = timeout + + def connect(self, address: tuple[int, int]) -> None: + self.connected = address + + def sendall(self, data: bytes) -> None: + self.sent += data + + def recv(self, size: int) -> bytes: + del size + if self.response == b"": + return b"" + data, self.response = self.response, b"" + return data + + def close(self) -> None: + self.closed = True + + +def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) + stub = StubSocket( + b'{"stdout":"ok\\n","stderr":"","exit_code":0,"duration_ms":7}' + ) + + def socket_factory(family: int, sock_type: int) -> StubSocket: + assert family == socket.AF_VSOCK + assert sock_type == socket.SOCK_STREAM + return stub + + client = VsockExecClient(socket_factory=socket_factory) + response = client.exec(1234, 5005, "echo ok", 30) + + assert response.exit_code == 0 + assert response.stdout == "ok\n" + assert stub.connected == (1234, 5005) + assert b'"command": "echo ok"' in stub.sent + assert stub.closed is True + + +def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) + stub = StubSocket(b"[]") + client = VsockExecClient(socket_factory=lambda family, sock_type: stub) + with pytest.raises(RuntimeError, match="JSON object"): + client.exec(1234, 5005, "echo ok", 30) diff --git a/tests/test_vm_manager.py b/tests/test_vm_manager.py index b8f6c04..e32c057 100644 --- a/tests/test_vm_manager.py +++ b/tests/test_vm_manager.py @@ -8,10 +8,15 @@ import pytest import pyro_mcp.vm_manager as vm_manager_module from 
pyro_mcp.runtime import resolve_runtime_paths from pyro_mcp.vm_manager import VmManager +from pyro_mcp.vm_network import TapNetworkManager def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600) vm_id = str(created["vm_id"]) started = manager.start_vm(vm_id) @@ -19,13 +24,18 @@ def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None: executed = manager.exec_vm(vm_id, command="printf 'git version 2.43.0\\n'", timeout_seconds=30) assert executed["exit_code"] == 0 + assert executed["execution_mode"] == "host_compat" assert "git version" in str(executed["stdout"]) with pytest.raises(ValueError, match="does not exist"): manager.status_vm(vm_id) def test_vm_manager_exec_timeout(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[ "vm_id" @@ -38,7 +48,11 @@ def test_vm_manager_exec_timeout(tmp_path: Path) -> None: def test_vm_manager_stop_and_delete(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[ "vm_id" @@ -52,7 +66,11 @@ def test_vm_manager_stop_and_delete(tmp_path: Path) -> None: def test_vm_manager_reaps_expired(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + 
manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] @@ -66,7 +84,11 @@ def test_vm_manager_reaps_expired(tmp_path: Path) -> None: def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] @@ -86,20 +108,33 @@ def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None: ], ) def test_vm_manager_validates_limits(tmp_path: Path, kwargs: dict[str, int], msg: str) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) with pytest.raises(ValueError, match=msg): manager.create_vm(profile="debian-base", **kwargs) def test_vm_manager_max_active_limit(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms", max_active_vms=1) + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + max_active_vms=1, + network_manager=TapNetworkManager(enabled=False), + ) manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) with pytest.raises(RuntimeError, match="max active VMs reached"): manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) def test_vm_manager_state_validation(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + 
network_manager=TapNetworkManager(enabled=False), + ) vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[ "vm_id" @@ -115,7 +150,11 @@ def test_vm_manager_state_validation(tmp_path: Path) -> None: def test_vm_manager_status_expired_raises(tmp_path: Path) -> None: - manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] @@ -127,17 +166,45 @@ def test_vm_manager_status_expired_raises(tmp_path: Path) -> None: def test_vm_manager_invalid_backend(tmp_path: Path) -> None: with pytest.raises(ValueError, match="invalid backend"): - VmManager(backend_name="nope", base_dir=tmp_path / "vms") + VmManager( + backend_name="nope", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + +def test_vm_manager_network_info(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + created = manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) + vm_id = str(created["vm_id"]) + status = manager.status_vm(vm_id) + info = manager.network_info_vm(vm_id) + assert status["network_enabled"] is False + assert status["guest_ip"] is None + assert info["network_enabled"] is False def test_vm_manager_firecracker_backend_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: class StubFirecrackerBackend: - def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None: + def __init__( + self, + artifacts_dir: Path, + firecracker_bin: Path, + jailer_bin: Path, + runtime_capabilities: Any, + network_manager: TapNetworkManager, + ) -> None: self.artifacts_dir = artifacts_dir 
self.firecracker_bin = firecracker_bin self.jailer_bin = jailer_bin + self.runtime_capabilities = runtime_capabilities + self.network_manager = network_manager def create(self, instance: Any) -> None: del instance @@ -160,5 +227,6 @@ def test_vm_manager_firecracker_backend_path( backend_name="firecracker", base_dir=tmp_path / "vms", runtime_paths=resolve_runtime_paths(), + network_manager=TapNetworkManager(enabled=False), ) assert manager._backend_name == "firecracker" # noqa: SLF001 diff --git a/tests/test_vm_network.py b/tests/test_vm_network.py new file mode 100644 index 0000000..2a71bec --- /dev/null +++ b/tests/test_vm_network.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import subprocess +from pathlib import Path + +import pytest + +from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager + + +def test_tap_network_manager_allocation_disabled_by_default( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.delenv("PYRO_VM_ENABLE_NETWORK", raising=False) + manager = TapNetworkManager() + config = manager.allocate("abcdef123456") + assert manager.enabled is False + assert config.tap_name == "pyroabcdef12" + assert config.guest_ip.startswith("172.29.") + metadata = manager.to_metadata(config) + assert metadata["network_enabled"] == "false" + + +def test_tap_network_manager_network_info() -> None: + manager = TapNetworkManager(enabled=False) + config = NetworkConfig( + vm_id="abcdef123456", + tap_name="pyroabcdef12", + guest_ip="172.29.100.2", + gateway_ip="172.29.100.1", + subnet_cidr="172.29.100.0/24", + mac_address="06:00:ab:cd:ef:12", + ) + info = manager.network_info(config) + assert info["tap_name"] == "pyroabcdef12" + assert info["outbound_connectivity_expected"] is False + + +def test_tap_network_manager_enabled_runs_host_commands() -> None: + commands: list[list[str]] = [] + + def runner(command: list[str]) -> subprocess.CompletedProcess[str]: + commands.append(command) + return subprocess.CompletedProcess(command, 0, "", "") 
+ + manager = TapNetworkManager(enabled=True, runner=runner) + config = manager.allocate("abcdef123456") + manager.cleanup(config) + assert commands[0][:4] == ["ip", "tuntap", "add", "dev"] + assert commands[-1][:3] in (["ip", "link", "del"], ["nft", "delete", "table"]) + + +def test_tap_network_manager_missing_host_support(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(Path, "exists", lambda self: False if str(self) == "/dev/net/tun" else True) + manager = TapNetworkManager( + enabled=True, + runner=lambda command: subprocess.CompletedProcess(command, 0, "", ""), + ) + with pytest.raises(RuntimeError, match="/dev/net/tun"): + manager.allocate("abcdef123456")
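
Note on the address derivation in `TapNetworkManager.allocate`: the sketch below restates, outside the patch, how the diff appears to derive the TAP name, /24 subnet, and MAC address from the first characters of `vm_id`. The names `DerivedNetwork` and `derive_network` are hypothetical and exist only for illustration; the constants mirror `DEFAULT_NETWORK_PREFIX`, `DEFAULT_GATEWAY_SUFFIX`, and `DEFAULT_GUEST_SUFFIX` from the diff. It assumes `vm_id` begins with at least eight hexadecimal characters, as the test IDs in the diff do.

```python
from dataclasses import dataclass

# Values copied from the constants shown in the diff.
NETWORK_PREFIX = "172.29."
GATEWAY_SUFFIX = 1
GUEST_SUFFIX = 2


@dataclass(frozen=True)
class DerivedNetwork:
    """Hypothetical container for illustration; not part of the patch."""

    tap_name: str
    guest_ip: str
    gateway_ip: str
    subnet_cidr: str
    mac_address: str


def derive_network(vm_id: str) -> DerivedNetwork:
    """Mirror the arithmetic in `allocate` without touching the host."""
    # The first two hex characters select the third octet, mapped into 1..200.
    octet = int(vm_id[:2], 16)
    third_octet = 1 + (octet % 200)
    base = f"{NETWORK_PREFIX}{third_octet}"
    return DerivedNetwork(
        # "pyro" plus eight characters stays within Linux's 15-character
        # interface name limit.
        tap_name=f"pyro{vm_id[:8]}",
        guest_ip=f"{base}.{GUEST_SUFFIX}",
        gateway_ip=f"{base}.{GATEWAY_SUFFIX}",
        subnet_cidr=f"{base}.0/24",
        # The MAC embeds the first four byte pairs of the id after a fixed
        # "06:00" prefix.
        mac_address=(
            f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}"
        ),
    )


if __name__ == "__main__":
    print(derive_network("abcdef123456"))
```

Two things follow from this arithmetic. First, VMs whose IDs share their first two hex characters get the same subnet, and so do IDs whose leading bytes differ by 200 (for example `00` and `c8`), so concurrent VMs can collide on addressing unless something else prevents it. Second, the hand-built `NetworkConfig` fixtures in the tests use `172.29.100.x`, which differs from what this derivation yields for `abcdef123456`; that looks harmless, since those fixtures never pass through `allocate`.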