Add runtime capability scaffolding and align docs

Thales Maciel 2026-03-05 22:57:09 -03:00
parent fb8b985049
commit cbf212bb7b
19 changed files with 1048 additions and 71 deletions

View file

@@ -14,6 +14,7 @@ This repository ships `pyro-mcp`, an MCP-compatible package for ephemeral VM lif
 - Use `make demo` to validate deterministic VM lifecycle execution.
 - Use `make ollama-demo` to validate model-triggered lifecycle tool usage.
 - Use `make doctor` to inspect bundled runtime integrity and host prerequisites.
+- If you need full log payloads from the Ollama demo, use `make ollama-demo OLLAMA_DEMO_FLAGS=-v`.
 ## Quality Gates
@@ -27,6 +28,7 @@ These checks run in pre-commit hooks and should all pass locally.
 - Public factory: `pyro_mcp.create_server()`
 - Runtime diagnostics CLI: `pyro-mcp-doctor`
+- Current bundled runtime is shim-based unless replaced with a real guest-capable bundle; check `make doctor` for runtime capabilities.
 - Lifecycle tools:
   - `vm_list_profiles`
   - `vm_create`
@@ -35,3 +37,5 @@ These checks run in pre-commit hooks and should all pass locally.
   - `vm_stop`
   - `vm_delete`
   - `vm_status`
+  - `vm_network_info`
+  - `vm_reap_expired`

View file

@@ -1,6 +1,7 @@
 PYTHON ?= uv run python
 OLLAMA_BASE_URL ?= http://localhost:11434/v1
 OLLAMA_MODEL ?= llama3.2:3b
+OLLAMA_DEMO_FLAGS ?=
 .PHONY: setup lint format typecheck test check demo doctor ollama ollama-demo run-server install-hooks
@@ -30,7 +31,7 @@ doctor:
 ollama: ollama-demo
 ollama-demo:
-	uv run pyro-mcp-ollama-demo --base-url "$(OLLAMA_BASE_URL)" --model "$(OLLAMA_MODEL)"
+	uv run pyro-mcp-ollama-demo --base-url "$(OLLAMA_BASE_URL)" --model "$(OLLAMA_MODEL)" $(OLLAMA_DEMO_FLAGS)
 run-server:
 	uv run pyro-mcp-server

View file

@@ -4,14 +4,14 @@
 ## v0.1.0 Capabilities
-- Split lifecycle tools for coding agents: `vm_list_profiles`, `vm_create`, `vm_start`, `vm_exec`, `vm_stop`, `vm_delete`, `vm_status`, `vm_reap_expired`.
+- Split lifecycle tools for coding agents: `vm_list_profiles`, `vm_create`, `vm_start`, `vm_exec`, `vm_stop`, `vm_delete`, `vm_status`, `vm_network_info`, `vm_reap_expired`.
 - Standard environment profiles:
   - `debian-base`: minimal Debian shell/core Unix tools.
   - `debian-git`: Debian base with Git preinstalled.
   - `debian-build`: Debian Git profile with common build tooling.
 - Explicit sizing contract for agents (`vcpu_count`, `mem_mib`) with guardrails.
 - Strict ephemerality for command execution (`vm_exec` auto-deletes VM on completion).
-- Ollama demo that asks an LLM to run `git --version` through lifecycle tools.
+- Ollama demo that asks an LLM to clone a small public Git repository through lifecycle tools.
 ## Runtime
@@ -22,6 +22,12 @@ The package includes a bundled Linux x86_64 runtime payload:
 No system Firecracker installation is required for basic usage.
+Current limitation:
+- The bundled runtime is currently shim-based.
+- `doctor` reports runtime capabilities, and current bundles report no real guest boot, no guest exec agent, and no guest networking.
+- Until a real guest-capable bundle is installed, `vm_exec` runs in `host_compat` mode rather than `guest_vsock`.
+- This means demo commands can exercise lifecycle/control-plane behavior, but they are not yet proof of command execution inside a real VM guest.
 Host requirements still apply:
 - Linux host
 - `/dev/kvm` available for full virtualization mode
@@ -44,7 +50,9 @@ make setup
 make demo
 ```
-The demo creates a VM, starts it, runs `git --version`, and returns structured output.
+The demo creates a VM, starts it, runs a command, and returns structured output.
+If the runtime reports `guest_vsock` plus networking, it uses an internet probe.
+Otherwise it falls back to a local compatibility command and the result will report `execution_mode=host_compat`.
 ## Runtime doctor
@@ -52,7 +60,21 @@ The demo creates a VM, starts it, runs `git --version`, and returns structured o
 make doctor
 ```
-This prints bundled runtime paths, profile availability, checksum validation status, and KVM host checks.
+This prints bundled runtime paths, profile availability, checksum validation status, runtime capability flags, KVM host checks, and host networking diagnostics.
+## Networking
+- Host-side network allocation and diagnostics are implemented.
+- The MCP server exposes `vm_network_info` for per-VM network metadata.
+- Host TAP/NAT setup is opt-in with:
+```bash
+PYRO_VM_ENABLE_NETWORK=1 make doctor
+```
+- Current limitation:
+  - network metadata and host preflight exist
+  - real in-guest outbound networking still depends on a non-shim runtime bundle with real guest boot and guest exec support
 ## Run Ollama lifecycle demo
@@ -64,6 +86,18 @@ make ollama-demo
 Defaults are configured in `Makefile`.
 The demo streams lifecycle progress logs and ends with a short text summary.
+The command it asks the model to run is a small public repository clone:
+```bash
+rm -rf hello-world && git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world >/dev/null && git -C hello-world rev-parse --is-inside-work-tree
+```
+If the runtime is still shim-based, the summary will show `execution_mode=host_compat`.
+By default it omits log values; to include prompt content, tool args, and tool results use:
+```bash
+make ollama-demo OLLAMA_DEMO_FLAGS=-v
+```
 ## Run MCP server

View file

@@ -7,6 +7,18 @@ from typing import Any
 from pyro_mcp.vm_manager import VmManager
+INTERNET_PROBE_COMMAND = (
+    'python3 -c "import urllib.request; '
+    "print(urllib.request.urlopen('https://example.com', timeout=10).status)"
+    '"'
+)
+def _demo_command(status: dict[str, Any]) -> str:
+    if bool(status.get("network_enabled")) and status.get("execution_mode") == "guest_vsock":
+        return INTERNET_PROBE_COMMAND
+    return "git --version"
 def run_demo() -> dict[str, Any]:
     """Create/start/exec/delete a VM and return command output."""
@@ -14,7 +26,8 @@ def run_demo() -> dict[str, Any]:
     created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
     vm_id = str(created["vm_id"])
     manager.start_vm(vm_id)
-    executed = manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
+    status = manager.status_vm(vm_id)
+    executed = manager.exec_vm(vm_id, command=_demo_command(status), timeout_seconds=30)
     return executed
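The demo-command selection above only fires the internet probe when the status reports a real guest. As an illustration, the same branch condition can be exercised standalone (the `network_enabled` and `execution_mode` status keys are taken from the diff; `demo_command` here is a renamed sketch, not the module's private helper):

```python
from typing import Any

INTERNET_PROBE_COMMAND = (
    'python3 -c "import urllib.request; '
    "print(urllib.request.urlopen('https://example.com', timeout=10).status)"
    '"'
)

def demo_command(status: dict[str, Any]) -> str:
    # Use the internet probe only when the runtime both boots a real guest
    # (guest_vsock) and has networking; otherwise stay on a shim-safe command.
    if bool(status.get("network_enabled")) and status.get("execution_mode") == "guest_vsock":
        return INTERNET_PROBE_COMMAND
    return "git --version"
```

Any status dict missing either key falls back to `git --version`, which is why shim-based runs never attempt the probe.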

View file

@@ -16,6 +16,12 @@ __all__ = ["VmManager", "run_ollama_tool_demo"]
 DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
 DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
 MAX_TOOL_ROUNDS: Final[int] = 12
+CLONE_TARGET_DIR: Final[str] = "hello-world"
+NETWORK_PROOF_COMMAND: Final[str] = (
+    "rm -rf hello-world "
+    "&& git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world >/dev/null "
+    "&& git -C hello-world rev-parse --is-inside-work-tree"
+)
 TOOL_SPECS: Final[list[dict[str, Any]]] = [
     {
@@ -210,7 +216,7 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
     created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
     vm_id = str(created["vm_id"])
     manager.start_vm(vm_id)
-    return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
+    return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
 def _is_vm_id_placeholder(value: str) -> bool:
@@ -240,8 +246,10 @@ def _normalize_tool_arguments(
     return normalized_arguments, last_created_vm_id
-def _summarize_message_for_log(message: dict[str, Any]) -> str:
+def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
     role = str(message.get("role", "unknown"))
+    if not verbose:
+        return role
     content = str(message.get("content") or "").strip()
     if content == "":
         return f"{role}: <empty>"
@@ -257,6 +265,7 @@ def run_ollama_tool_demo(
     model: str = DEFAULT_OLLAMA_MODEL,
     *,
     strict: bool = True,
+    verbose: bool = False,
     log: Callable[[str], None] | None = None,
 ) -> dict[str, Any]:
     """Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
@@ -267,10 +276,12 @@ def run_ollama_tool_demo(
         {
             "role": "user",
             "content": (
-                "Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
+                "Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
                 "Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
                 "Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
                 "returned by vm_create into vm_start/vm_exec.\n"
+                f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
+                f"Success means the clone completes and the command prints `true`.\n"
                 "If a tool returns an error, fix arguments and retry."
             ),
         }
@@ -280,7 +291,7 @@ def run_ollama_tool_demo(
     last_created_vm_id: str | None = None
     for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
-        emit(f"[model] input {_summarize_message_for_log(messages[-1])}")
+        emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}")
         response = _post_chat_completion(
             base_url,
             {
@@ -292,7 +303,7 @@ def run_ollama_tool_demo(
             },
         )
         assistant_message = _extract_message(response)
-        emit(f"[model] output {_summarize_message_for_log(assistant_message)}")
+        emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}")
         tool_calls = assistant_message.get("tool_calls")
         if not isinstance(tool_calls, list) or not tool_calls:
             final_response = str(assistant_message.get("content") or "")
@@ -320,15 +331,24 @@ def run_ollama_tool_demo(
             if not isinstance(tool_name, str):
                 raise RuntimeError("tool call function name is invalid")
             arguments = _parse_tool_arguments(function.get("arguments"))
-            emit(f"[model] tool_call {tool_name} args={arguments}")
+            if verbose:
+                emit(f"[model] tool_call {tool_name} args={arguments}")
+            else:
+                emit(f"[model] tool_call {tool_name}")
             arguments, normalized_vm_id = _normalize_tool_arguments(
                 tool_name,
                 arguments,
                 last_created_vm_id=last_created_vm_id,
             )
             if normalized_vm_id is not None:
-                emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
-            emit(f"[tool] calling {tool_name} with args={arguments}")
+                if verbose:
+                    emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
+                else:
+                    emit("[tool] resolved vm_id placeholder")
+            if verbose:
+                emit(f"[tool] calling {tool_name} with args={arguments}")
+            else:
+                emit(f"[tool] calling {tool_name}")
             try:
                 result = _dispatch_tool_call(manager, tool_name, arguments)
                 success = True
@@ -341,7 +361,10 @@ def run_ollama_tool_demo(
                 result = _format_tool_error(tool_name, arguments, exc)
                 success = False
                 emit(f"[tool] {tool_name} failed: {exc}")
-            emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
+            if verbose:
+                emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
+            else:
+                emit(f"[tool] result {tool_name}")
             tool_events.append(
                 {
                     "tool_name": tool_name,
@@ -379,7 +402,7 @@ def run_ollama_tool_demo(
         tool_events.append(
             {
                 "tool_name": "vm_exec_fallback",
-                "arguments": {"command": "git --version"},
+                "arguments": {"command": NETWORK_PROOF_COMMAND},
                 "result": exec_result,
                 "success": True,
             }
@@ -390,13 +413,13 @@ def run_ollama_tool_demo(
         raise RuntimeError("vm_exec result shape is invalid")
     if int(exec_result.get("exit_code", -1)) != 0:
         raise RuntimeError("vm_exec failed; expected exit_code=0")
-    if "git version" not in str(exec_result.get("stdout", "")):
-        raise RuntimeError("vm_exec output did not contain `git version`")
+    if str(exec_result.get("stdout", "")).strip() != "true":
+        raise RuntimeError("vm_exec output did not confirm repository clone success")
     emit("[done] command execution succeeded")
     return {
         "model": model,
-        "command": "git --version",
+        "command": NETWORK_PROOF_COMMAND,
         "exec_result": exec_result,
         "tool_events": tool_events,
         "final_response": final_response,
@@ -408,6 +431,7 @@ def _build_parser() -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.")
     parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
    parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
+    parser.add_argument("-v", "--verbose", action="store_true")
     return parser
@@ -418,6 +442,7 @@ def main() -> None:
         result = run_ollama_tool_demo(
             base_url=args.base_url,
             model=args.model,
+            verbose=args.verbose,
             log=lambda message: print(message, flush=True),
         )
     except Exception as exc:  # noqa: BLE001
@@ -428,7 +453,9 @@ def main() -> None:
         raise RuntimeError("demo produced invalid execution result")
     print(
         f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
-        f"fallback_used={bool(result.get('fallback_used'))}",
+        f"fallback_used={bool(result.get('fallback_used'))} "
+        f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
         flush=True,
     )
-    print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
+    if args.verbose:
+        print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)

View file

@@ -11,6 +11,8 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Any
+from pyro_mcp.vm_network import TapNetworkManager
 DEFAULT_PLATFORM = "linux-x86_64"
@@ -27,6 +29,16 @@ class RuntimePaths:
     manifest: dict[str, Any]
+@dataclass(frozen=True)
+class RuntimeCapabilities:
+    """Feature flags inferred from the bundled runtime."""
+    supports_vm_boot: bool
+    supports_guest_exec: bool
+    supports_guest_network: bool
+    reason: str | None = None
 def _sha256(path: Path) -> str:
     digest = hashlib.sha256()
     with path.open("rb") as fp:
@@ -135,6 +147,40 @@ def resolve_runtime_paths(
     )
+def runtime_capabilities(paths: RuntimePaths) -> RuntimeCapabilities:
+    """Infer what the current bundled runtime can actually do."""
+    binary_text = paths.firecracker_bin.read_text(encoding="utf-8", errors="ignore")
+    if "bundled firecracker shim" in binary_text:
+        return RuntimeCapabilities(
+            supports_vm_boot=False,
+            supports_guest_exec=False,
+            supports_guest_network=False,
+            reason="bundled runtime uses shim firecracker/jailer binaries",
+        )
+    capabilities = paths.manifest.get("capabilities")
+    if not isinstance(capabilities, dict):
+        return RuntimeCapabilities(
+            supports_vm_boot=False,
+            supports_guest_exec=False,
+            supports_guest_network=False,
+            reason="runtime manifest does not declare guest boot/exec/network capabilities",
+        )
+    supports_vm_boot = bool(capabilities.get("vm_boot"))
+    supports_guest_exec = bool(capabilities.get("guest_exec"))
+    supports_guest_network = bool(capabilities.get("guest_network"))
+    reason = None
+    if not supports_vm_boot:
+        reason = "runtime manifest does not advertise real VM boot support"
+    return RuntimeCapabilities(
+        supports_vm_boot=supports_vm_boot,
+        supports_guest_exec=supports_guest_exec,
+        supports_guest_network=supports_guest_network,
+        reason=reason,
+    )
 def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
     """Build a runtime diagnostics report."""
     report: dict[str, Any] = {
@@ -146,13 +192,28 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
             "readable": os.access("/dev/kvm", os.R_OK),
             "writable": os.access("/dev/kvm", os.W_OK),
         },
+        "networking": {
+            "enabled_by_default": TapNetworkManager().enabled,
+        },
     }
+    network = TapNetworkManager.diagnostics()
+    report["networking"].update(
+        {
+            "tun_available": network.tun_available,
+            "ip_binary": network.ip_binary,
+            "nft_binary": network.nft_binary,
+            "iptables_binary": network.iptables_binary,
+            "ip_forward_enabled": network.ip_forward_enabled,
+        }
+    )
     try:
         paths = resolve_runtime_paths(platform=platform, verify_checksums=True)
     except Exception as exc:  # noqa: BLE001
         report["issues"] = [str(exc)]
         return report
+    capabilities = runtime_capabilities(paths)
     profiles = paths.manifest.get("profiles", {})
     profile_names = sorted(profiles.keys()) if isinstance(profiles, dict) else []
     report["runtime_ok"] = True
@@ -165,6 +226,12 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
         "notice_path": str(paths.notice_path),
         "bundle_version": paths.manifest.get("bundle_version"),
         "profiles": profile_names,
+        "capabilities": {
+            "supports_vm_boot": capabilities.supports_vm_boot,
+            "supports_guest_exec": capabilities.supports_guest_exec,
+            "supports_guest_network": capabilities.supports_guest_network,
+            "reason": capabilities.reason,
+        },
     }
     if not report["kvm"]["exists"]:
         report["issues"] = ["/dev/kvm is not available on this host"]

View file

@@ -59,6 +59,11 @@ def create_server(manager: VmManager | None = None) -> FastMCP:
         """Get the current state and metadata for a VM."""
         return vm_manager.status_vm(vm_id)
+    @server.tool()
+    async def vm_network_info(vm_id: str) -> dict[str, Any]:
+        """Get the current network configuration assigned to a VM."""
+        return vm_manager.network_info_vm(vm_id)
     @server.tool()
     async def vm_reap_expired() -> dict[str, Any]:
         """Delete VMs whose TTL has expired."""

View file

@@ -0,0 +1,109 @@
"""Firecracker launch-plan generation for pyro VMs."""

from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol

DEFAULT_GUEST_CID_OFFSET = 10_000
DEFAULT_VSOCK_PORT = 5005


@dataclass(frozen=True)
class FirecrackerLaunchPlan:
    """On-disk config artifacts needed to boot a guest."""

    api_socket_path: Path
    config_path: Path
    guest_network_path: Path
    guest_exec_path: Path
    guest_cid: int
    vsock_port: int
    config: dict[str, Any]


class VmInstanceLike(Protocol):
    vm_id: str
    vcpu_count: int
    mem_mib: int
    workdir: Path
    metadata: dict[str, str]
    network: Any


def _guest_cid(vm_id: str) -> int:
    return DEFAULT_GUEST_CID_OFFSET + int(vm_id[:4], 16)


def build_launch_plan(instance: VmInstanceLike) -> FirecrackerLaunchPlan:
    guest_cid = _guest_cid(instance.vm_id)
    vsock_port = DEFAULT_VSOCK_PORT
    api_socket_path = instance.workdir / "firecracker.sock"
    config_path = instance.workdir / "firecracker-config.json"
    guest_network_path = instance.workdir / "guest-network.json"
    guest_exec_path = instance.workdir / "guest-exec.json"
    config: dict[str, Any] = {
        "boot-source": {
            "kernel_image_path": instance.metadata["kernel_image"],
            "boot_args": "console=ttyS0 reboot=k panic=1 pci=off",
        },
        "drives": [
            {
                "drive_id": "rootfs",
                "path_on_host": instance.metadata["rootfs_image"],
                "is_root_device": True,
                "is_read_only": False,
            }
        ],
        "machine-config": {
            "vcpu_count": instance.vcpu_count,
            "mem_size_mib": instance.mem_mib,
            "smt": False,
        },
        "network-interfaces": [],
        "vsock": {
            "guest_cid": guest_cid,
            "uds_path": str(instance.workdir / "vsock.sock"),
        },
    }
    if instance.network is not None:
        config["network-interfaces"] = [
            {
                "iface_id": "eth0",
                "guest_mac": instance.network.mac_address,
                "host_dev_name": instance.network.tap_name,
            }
        ]
    guest_network = {
        "guest_ip": instance.network.guest_ip if instance.network is not None else None,
        "gateway_ip": instance.network.gateway_ip if instance.network is not None else None,
        "subnet_cidr": instance.network.subnet_cidr if instance.network is not None else None,
        "dns_servers": list(instance.network.dns_servers) if instance.network is not None else [],
        "tap_name": instance.network.tap_name if instance.network is not None else None,
    }
    guest_exec = {
        "transport": "vsock",
        "guest_cid": guest_cid,
        "port": vsock_port,
    }
    config_path.write_text(json.dumps(config, indent=2, sort_keys=True), encoding="utf-8")
    guest_network_path.write_text(
        json.dumps(guest_network, indent=2, sort_keys=True), encoding="utf-8"
    )
    guest_exec_path.write_text(json.dumps(guest_exec, indent=2, sort_keys=True), encoding="utf-8")
    return FirecrackerLaunchPlan(
        api_socket_path=api_socket_path,
        config_path=config_path,
        guest_network_path=guest_network_path,
        guest_exec_path=guest_exec_path,
        guest_cid=guest_cid,
        vsock_port=vsock_port,
        config=config,
    )
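The CID derivation in `_guest_cid` maps the first four hex characters of the VM id onto a per-VM vsock CID above a fixed offset. A standalone sketch (the constant is copied from the diff; `guest_cid` is a renamed copy for illustration):

```python
DEFAULT_GUEST_CID_OFFSET = 10_000  # same constant as in the diff

def guest_cid(vm_id: str) -> int:
    # int(vm_id[:4], 16) ranges over 0..65535, so resulting CIDs land in
    # 10_000..75_535, well clear of the reserved low vsock CIDs (0-2).
    return DEFAULT_GUEST_CID_OFFSET + int(vm_id[:4], 16)
```

One design implication worth noting: two VM ids that share their first four hex characters map to the same CID, which would only matter if such VMs were booted concurrently on one host.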

src/pyro_mcp/vm_guest.py (new file, 72 lines)
View file

@@ -0,0 +1,72 @@
"""Guest command transport over vsock-compatible JSON protocol."""

from __future__ import annotations

import json
import socket
from dataclasses import dataclass
from typing import Callable, Protocol


class SocketLike(Protocol):
    def settimeout(self, timeout: int) -> None: ...
    def connect(self, address: tuple[int, int]) -> None: ...
    def sendall(self, data: bytes) -> None: ...
    def recv(self, size: int) -> bytes: ...
    def close(self) -> None: ...


SocketFactory = Callable[[int, int], SocketLike]


@dataclass(frozen=True)
class GuestExecResponse:
    stdout: str
    stderr: str
    exit_code: int
    duration_ms: int


class VsockExecClient:
    """Minimal JSON-over-stream client for a guest exec agent."""

    def __init__(self, socket_factory: SocketFactory | None = None) -> None:
        self._socket_factory = socket_factory or socket.socket

    def exec(
        self, guest_cid: int, port: int, command: str, timeout_seconds: int
    ) -> GuestExecResponse:
        request = {
            "command": command,
            "timeout_seconds": timeout_seconds,
        }
        family = getattr(socket, "AF_VSOCK", None)
        if family is None:
            raise RuntimeError("vsock sockets are not supported on this host Python runtime")
        sock = self._socket_factory(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout_seconds)
            sock.connect((guest_cid, port))
            sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
            chunks: list[bytes] = []
            while True:
                data = sock.recv(65536)
                if data == b"":
                    break
                chunks.append(data)
        finally:
            sock.close()
        payload = json.loads(b"".join(chunks).decode("utf-8"))
        if not isinstance(payload, dict):
            raise RuntimeError("guest exec response must be a JSON object")
        return GuestExecResponse(
            stdout=str(payload.get("stdout", "")),
            stderr=str(payload.get("stderr", "")),
            exit_code=int(payload.get("exit_code", -1)),
            duration_ms=int(payload.get("duration_ms", 0)),
        )
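The client above assumes a guest-side exec agent that reads one newline-terminated JSON request, replies with a single JSON object, and closes the connection. No such agent is part of this commit; as a sketch of the protocol counterpart under those assumptions, the per-request handler could look like this (`handle_exec_request` is a hypothetical name, and the transport loop around it is omitted):

```python
import json
import subprocess
import time

def handle_exec_request(line: bytes) -> bytes:
    """Serve one request of the client's wire protocol: a newline-terminated
    JSON object with `command` and `timeout_seconds` in, a JSON object with
    stdout/stderr/exit_code/duration_ms out."""
    request = json.loads(line.decode("utf-8"))
    started = time.monotonic()
    try:
        # shell=True mirrors the free-form command strings the client sends.
        proc = subprocess.run(
            request["command"],
            shell=True,
            capture_output=True,
            text=True,
            timeout=int(request["timeout_seconds"]),
        )
        stdout, stderr, exit_code = proc.stdout, proc.stderr, proc.returncode
    except subprocess.TimeoutExpired as exc:
        stdout = exc.stdout or ""
        stderr = exc.stderr or ""
        exit_code = -1  # matches the client's "unknown exit" default
    duration_ms = int((time.monotonic() - started) * 1000)
    return json.dumps(
        {"stdout": stdout, "stderr": stderr, "exit_code": exit_code, "duration_ms": duration_ms}
    ).encode("utf-8")
```

Because the client reads until EOF rather than framing the response, the agent must close the connection after writing, which is why a one-request-per-connection loop is the natural fit here.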

View file

@@ -12,7 +12,15 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Any, Literal
-from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
+from pyro_mcp.runtime import (
+    RuntimeCapabilities,
+    RuntimePaths,
+    resolve_runtime_paths,
+    runtime_capabilities,
+)
+from pyro_mcp.vm_firecracker import build_launch_plan
+from pyro_mcp.vm_guest import VsockExecClient
+from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
 from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
 VmState = Literal["created", "started", "stopped"]
@@ -34,6 +42,7 @@ class VmInstance:
     firecracker_pid: int | None = None
     last_error: str | None = None
     metadata: dict[str, str] = field(default_factory=dict)
+    network: NetworkConfig | None = None
@@ -119,10 +128,21 @@ class MockBackend(VmBackend):
 class FirecrackerBackend(VmBackend):  # pragma: no cover
     """Host-gated backend that validates Firecracker prerequisites."""
-    def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
+    def __init__(
+        self,
+        artifacts_dir: Path,
+        firecracker_bin: Path,
+        jailer_bin: Path,
+        runtime_capabilities: RuntimeCapabilities,
+        network_manager: TapNetworkManager | None = None,
+        guest_exec_client: VsockExecClient | None = None,
+    ) -> None:
         self._artifacts_dir = artifacts_dir
         self._firecracker_bin = firecracker_bin
         self._jailer_bin = jailer_bin
+        self._runtime_capabilities = runtime_capabilities
+        self._network_manager = network_manager or TapNetworkManager()
+        self._guest_exec_client = guest_exec_client or VsockExecClient()
         if not self._firecracker_bin.exists():
             raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
         if not self._jailer_bin.exists():
@@ -132,16 +152,29 @@ class FirecrackerBackend(VmBackend):  # pragma: no cover
     def create(self, instance: VmInstance) -> None:
         instance.workdir.mkdir(parents=True, exist_ok=False)
-        artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
-        if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
-            raise RuntimeError(
-                f"missing profile artifacts for {instance.profile}; expected "
-                f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
-            )
-        instance.metadata["kernel_image"] = str(artifacts.kernel_image)
-        instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
+        try:
+            artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
+            if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
+                raise RuntimeError(
+                    f"missing profile artifacts for {instance.profile}; expected "
+                    f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
+                )
+            instance.metadata["kernel_image"] = str(artifacts.kernel_image)
+            instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
+            network = self._network_manager.allocate(instance.vm_id)
+            instance.network = network
+            instance.metadata.update(self._network_manager.to_metadata(network))
+        except Exception:
+            shutil.rmtree(instance.workdir, ignore_errors=True)
+            raise
     def start(self, instance: VmInstance) -> None:
+        launch_plan = build_launch_plan(instance)
+        instance.metadata["firecracker_config_path"] = str(launch_plan.config_path)
+        instance.metadata["guest_network_path"] = str(launch_plan.guest_network_path)
+        instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path)
+        instance.metadata["guest_cid"] = str(launch_plan.guest_cid)
+        instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port)
         proc = subprocess.run(  # noqa: S603
             [str(self._firecracker_bin), "--version"],
             text=True,
@ -152,15 +185,35 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}") raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
instance.metadata["firecracker_version"] = proc.stdout.strip() instance.metadata["firecracker_version"] = proc.stdout.strip()
instance.metadata["jailer_path"] = str(self._jailer_bin) instance.metadata["jailer_path"] = str(self._jailer_bin)
if not self._runtime_capabilities.supports_vm_boot:
instance.metadata["execution_mode"] = "host_compat"
instance.metadata["boot_mode"] = "shim"
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
return
instance.metadata["execution_mode"] = "guest_vsock"
instance.metadata["boot_mode"] = "native"
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
if self._runtime_capabilities.supports_guest_exec:
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
response = self._guest_exec_client.exec(guest_cid, port, command, timeout_seconds)
return VmExecResult(
stdout=response.stdout,
stderr=response.stderr,
exit_code=response.exit_code,
duration_ms=response.duration_ms,
)
instance.metadata["execution_mode"] = "host_compat"
return _run_host_command(instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
del instance
def delete(self, instance: VmInstance) -> None:
if instance.network is not None:
self._network_manager.cleanup(instance.network)
shutil.rmtree(instance.workdir, ignore_errors=True)
@@ -182,6 +235,7 @@ class VmManager:
artifacts_dir: Path | None = None,
max_active_vms: int = 4,
runtime_paths: RuntimePaths | None = None,
network_manager: TapNetworkManager | None = None,
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
@@ -189,11 +243,19 @@ class VmManager:
if self._backend_name == "firecracker":
self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
self._runtime_capabilities = runtime_capabilities(self._runtime_paths)
else:
self._artifacts_dir = artifacts_dir or Path(
os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
)
self._runtime_capabilities = RuntimeCapabilities(
supports_vm_boot=False,
supports_guest_exec=False,
supports_guest_network=False,
reason="mock backend does not boot a guest",
)
self._max_active_vms = max_active_vms
self._network_manager = network_manager or TapNetworkManager()
self._lock = threading.Lock()
self._instances: dict[str, VmInstance] = {}
self._base_dir.mkdir(parents=True, exist_ok=True)
@@ -209,6 +271,8 @@ class VmManager:
self._artifacts_dir,
firecracker_bin=self._runtime_paths.firecracker_bin,
jailer_bin=self._runtime_paths.jailer_bin,
runtime_capabilities=self._runtime_capabilities,
network_manager=self._network_manager,
)
raise ValueError("invalid backend; expected one of: mock, firecracker")
@@ -262,6 +326,7 @@ class VmManager:
if instance.state != "started":
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
exec_result = self._backend.exec(instance, command, timeout_seconds)
execution_mode = instance.metadata.get("execution_mode", "host_compat")
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
@@ -270,6 +335,7 @@ class VmManager:
"stderr": exec_result.stderr,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"execution_mode": execution_mode,
"cleanup": cleanup,
}
@@ -296,6 +362,19 @@ class VmManager:
self._ensure_not_expired_locked(instance, time.time())
return self._serialize(instance)
def network_info_vm(self, vm_id: str) -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
if instance.network is None:
return {
"vm_id": vm_id,
"network_enabled": False,
"outbound_connectivity_expected": False,
"reason": "network configuration is unavailable for this VM",
}
return {"vm_id": vm_id, **self._network_manager.network_info(instance.network)}
def reap_expired(self) -> dict[str, Any]:
now = time.time()
with self._lock:
@@ -331,6 +410,10 @@ class VmManager:
"created_at": instance.created_at,
"expires_at": instance.expires_at,
"state": instance.state,
"network_enabled": instance.network is not None,
"guest_ip": instance.network.guest_ip if instance.network is not None else None,
"tap_name": instance.network.tap_name if instance.network is not None else None,
"execution_mode": instance.metadata.get("execution_mode", "host_compat"),
"metadata": instance.metadata,
}

src/pyro_mcp/vm_network.py (new file)
@@ -0,0 +1,191 @@
"""Host-side network allocation and diagnostics for Firecracker VMs."""
from __future__ import annotations
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
CommandRunner = Callable[[list[str]], subprocess.CompletedProcess[str]]
DEFAULT_NETWORK_PREFIX = "172.29."
DEFAULT_GATEWAY_SUFFIX = 1
DEFAULT_GUEST_SUFFIX = 2
@dataclass(frozen=True)
class NetworkConfig:
vm_id: str
tap_name: str
guest_ip: str
gateway_ip: str
subnet_cidr: str
mac_address: str
dns_servers: tuple[str, ...] = ("1.1.1.1", "8.8.8.8")
@dataclass(frozen=True)
class NetworkDiagnostics:
tun_available: bool
ip_binary: str | None
nft_binary: str | None
iptables_binary: str | None
ip_forward_enabled: bool
class TapNetworkManager:
"""Allocate per-VM TAP/NAT configuration on the host."""
def __init__(
self,
*,
enabled: bool | None = None,
runner: CommandRunner | None = None,
) -> None:
if enabled is None:
self._enabled = os.environ.get("PYRO_VM_ENABLE_NETWORK") == "1"
else:
self._enabled = enabled
self._runner = runner or self._run
@staticmethod
def diagnostics() -> NetworkDiagnostics:
ip_forward = False
try:
ip_forward = (
Path("/proc/sys/net/ipv4/ip_forward").read_text(encoding="utf-8").strip()
== "1"
)
except OSError:
ip_forward = False
return NetworkDiagnostics(
tun_available=Path("/dev/net/tun").exists(),
ip_binary=shutil.which("ip"),
nft_binary=shutil.which("nft"),
iptables_binary=shutil.which("iptables"),
ip_forward_enabled=ip_forward,
)
@property
def enabled(self) -> bool:
return self._enabled
def allocate(self, vm_id: str) -> NetworkConfig:
octet = int(vm_id[:2], 16)
third_octet = 1 + (octet % 200)
subnet_cidr = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.0/24"
gateway_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GATEWAY_SUFFIX}"
guest_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GUEST_SUFFIX}"
tap_name = f"pyro{vm_id[:8]}"
mac_address = f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}"
config = NetworkConfig(
vm_id=vm_id,
tap_name=tap_name,
guest_ip=guest_ip,
gateway_ip=gateway_ip,
subnet_cidr=subnet_cidr,
mac_address=mac_address,
)
if self._enabled:
self._ensure_host_network(config)
return config
def cleanup(self, config: NetworkConfig) -> None:
if not self._enabled:
return
table_name = self._nft_table_name(config.vm_id)
self._run_ignore(["nft", "delete", "table", "ip", table_name])
self._run_ignore(["ip", "link", "del", config.tap_name])
def to_metadata(self, config: NetworkConfig) -> dict[str, str]:
return {
"network_enabled": "true" if self._enabled else "false",
"tap_name": config.tap_name,
"guest_ip": config.guest_ip,
"gateway_ip": config.gateway_ip,
"subnet_cidr": config.subnet_cidr,
"mac_address": config.mac_address,
"dns_servers": ",".join(config.dns_servers),
}
def network_info(self, config: NetworkConfig) -> dict[str, object]:
return {
"network_enabled": self._enabled,
"tap_name": config.tap_name,
"guest_ip": config.guest_ip,
"gateway_ip": config.gateway_ip,
"subnet_cidr": config.subnet_cidr,
"mac_address": config.mac_address,
"dns_servers": list(config.dns_servers),
"outbound_connectivity_expected": self._enabled,
}
def _ensure_host_network(self, config: NetworkConfig) -> None:
diagnostics = self.diagnostics()
if not diagnostics.tun_available:
raise RuntimeError("/dev/net/tun is not available on this host")
if diagnostics.ip_binary is None:
raise RuntimeError("`ip` binary is required for TAP setup")
if diagnostics.nft_binary is None:
raise RuntimeError("`nft` binary is required for outbound NAT setup")
if not diagnostics.ip_forward_enabled:
raise RuntimeError("IPv4 forwarding is disabled on this host")
self._runner(["ip", "tuntap", "add", "dev", config.tap_name, "mode", "tap"])
self._runner(["ip", "addr", "add", f"{config.gateway_ip}/24", "dev", config.tap_name])
self._runner(["ip", "link", "set", config.tap_name, "up"])
table_name = self._nft_table_name(config.vm_id)
self._runner(["nft", "add", "table", "ip", table_name])
self._runner(
[
"nft",
"add",
"chain",
"ip",
table_name,
"postrouting",
"{",
"type",
"nat",
"hook",
"postrouting",
"priority",
"srcnat",
";",
"}",
]
)
self._runner([
"nft",
"add",
"rule",
"ip",
table_name,
"postrouting",
"ip",
"saddr",
config.subnet_cidr,
"masquerade",
])
def _run_ignore(self, command: list[str]) -> None:
try:
self._runner(command)
except RuntimeError:
return
@staticmethod
def _nft_table_name(vm_id: str) -> str:
return f"pyro_{vm_id[:12]}"
@staticmethod
def _run(command: list[str]) -> subprocess.CompletedProcess[str]:
completed = subprocess.run(command, text=True, capture_output=True, check=False)
if completed.returncode != 0:
stderr = completed.stderr.strip() or completed.stdout.strip()
raise RuntimeError(f"command {' '.join(command)!r} failed: {stderr}")
return completed
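The addressing scheme in `allocate` is deterministic, so it can be reproduced outside the class. A standalone sketch of the same arithmetic (the helper name is illustrative): the first two hex digits of the vm_id select one of 200 /24 subnets under 172.29.0.0/16, and the TAP and MAC are derived from the id prefix.

```python
def derive_network(vm_id: str) -> dict[str, str]:
    # Re-derives the per-VM addressing the same way TapNetworkManager.allocate does.
    third_octet = 1 + (int(vm_id[:2], 16) % 200)
    return {
        "tap_name": f"pyro{vm_id[:8]}",
        "subnet_cidr": f"172.29.{third_octet}.0/24",
        "gateway_ip": f"172.29.{third_octet}.1",   # host side of the TAP
        "guest_ip": f"172.29.{third_octet}.2",     # handed to the guest
        "mac_address": f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}",
    }


print(derive_network("deadbeef1234"))
```

Note the collision property: two live VMs whose ids share the first two hex digits map to the same subnet, which is acceptable under the default `max_active_vms=4` but worth keeping in mind if that limit grows.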

@@ -40,6 +40,10 @@ def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
calls.append(("start_vm", {"vm_id": vm_id}))
return {"vm_id": vm_id}
def status_vm(self, vm_id: str) -> dict[str, Any]:
calls.append(("status_vm", {"vm_id": vm_id}))
return {"vm_id": vm_id, "network_enabled": False, "execution_mode": "host_compat"}
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
calls.append(
(
@@ -55,7 +59,13 @@ def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
assert result["exit_code"] == 0
assert calls[0][0] == "create_vm"
assert calls[1] == ("start_vm", {"vm_id": "vm-1"})
assert calls[2] == ("status_vm", {"vm_id": "vm-1"})
assert calls[3][0] == "exec_vm"
def test_demo_command_prefers_network_probe_for_guest_vsock() -> None:
status = {"network_enabled": True, "execution_mode": "guest_vsock"}
assert "https://example.com" in demo_module._demo_command(status)
def test_main_prints_json(

@@ -94,7 +94,7 @@ def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, Any]:
"arguments": json.dumps(
{
"vm_id": vm_id,
"command": "printf 'true\\n'",
}
),
},
@@ -125,14 +125,14 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
result = ollama_demo.run_ollama_tool_demo(log=logs.append)
assert result["fallback_used"] is False
assert str(result["exec_result"]["stdout"]).strip() == "true"
assert result["final_response"] == "Executed git command in ephemeral VM."
assert len(result["tool_events"]) == 4
assert any(line == "[model] input user" for line in logs)
assert any(line == "[model] output assistant" for line in logs)
assert any("[model] tool_call vm_exec" in line for line in logs)
assert any(line == "[tool] calling vm_exec" for line in logs)
assert any(line == "[tool] result vm_exec" for line in logs)
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
@@ -158,7 +158,7 @@ def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
"arguments": json.dumps(
{
"vm_id": "vm_list_profiles",
"command": ollama_demo.NETWORK_PROOF_COMMAND,
}
),
},
@@ -219,7 +219,7 @@ def test_run_ollama_tool_demo_resolves_vm_id_placeholder(
"arguments": json.dumps(
{
"vm_id": "<vm_id_returned_by_vm_create>",
"command": "printf 'true\\n'",
"timeout_seconds": "300",
}
),
@@ -292,14 +292,49 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict(
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager)
monkeypatch.setattr(
ollama_demo,
"_run_direct_lifecycle_fallback",
lambda manager: {
"vm_id": "vm-1",
"command": ollama_demo.NETWORK_PROOF_COMMAND,
"stdout": "true\n",
"stderr": "",
"exit_code": 0,
"duration_ms": 5,
"execution_mode": "host_compat",
"cleanup": {"deleted": True, "reason": "post_exec_cleanup", "vm_id": "vm-1"},
},
)
logs: list[str] = []
result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append)
assert result["fallback_used"] is True
assert int(result["exec_result"]["exit_code"]) == 0
assert any(line == "[model] output assistant" for line in logs)
assert any("[fallback]" in line for line in logs)
def test_run_ollama_tool_demo_verbose_logs_values(monkeypatch: pytest.MonkeyPatch) -> None:
requests = 0
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url
nonlocal requests
requests += 1
return _stepwise_model_response(payload, requests)
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
logs: list[str] = []
result = ollama_demo.run_ollama_tool_demo(verbose=True, log=logs.append)
assert result["fallback_used"] is False
assert str(result["exec_result"]["stdout"]).strip() == "true"
assert any("[model] input user:" in line for line in logs)
assert any("[model] tool_call vm_list_profiles args={}" in line for line in logs)
assert any("[tool] result vm_exec " in line for line in logs)
@pytest.mark.parametrize(
("tool_call", "error"),
[
@@ -346,8 +381,8 @@ def test_run_ollama_tool_demo_max_rounds(monkeypatch: pytest.MonkeyPatch) -> None:
("exec_result", "error"),
[
("bad", "result shape is invalid"),
({"exit_code": 1, "stdout": "true"}, "expected exit_code=0"),
({"exit_code": 0, "stdout": "false"}, "did not confirm repository clone success"),
],
)
def test_run_ollama_tool_demo_exec_result_validation(
@@ -404,7 +439,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
executed = ollama_demo._dispatch_tool_call(
manager,
"vm_exec",
{"vm_id": vm_id, "command": "printf 'true\\n'", "timeout_seconds": "30"},
)
assert int(executed["exit_code"]) == 0
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
@@ -529,6 +564,13 @@ def test_build_parser_defaults() -> None:
args = parser.parse_args([])
assert args.model == ollama_demo.DEFAULT_OLLAMA_MODEL
assert args.base_url == ollama_demo.DEFAULT_OLLAMA_BASE_URL
assert args.verbose is False
def test_build_parser_verbose_flag() -> None:
parser = ollama_demo._build_parser()
args = parser.parse_args(["-v"])
assert args.verbose is True
def test_main_uses_parser_and_prints_logs(
@@ -537,21 +579,51 @@ def test_main_uses_parser_and_prints_logs(
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(base_url="http://x", model="m", verbose=False)
monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser())
monkeypatch.setattr(
ollama_demo,
"run_ollama_tool_demo",
lambda base_url, model, strict=True, verbose=False, log=None: {
"exec_result": {
"exit_code": 0,
"stdout": "true\n",
"execution_mode": "host_compat",
},
"fallback_used": False,
},
)
ollama_demo.main()
output = capsys.readouterr().out
assert "[summary] exit_code=0 fallback_used=False execution_mode=host_compat" in output
assert "[summary] stdout=" not in output
def test_main_verbose_prints_stdout(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(base_url="http://x", model="m", verbose=True)
monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser())
monkeypatch.setattr(
ollama_demo,
"run_ollama_tool_demo",
lambda base_url, model, strict=True, verbose=False, log=None: {
"exec_result": {
"exit_code": 0,
"stdout": "true\n",
"execution_mode": "host_compat",
},
"fallback_used": False,
},
)
ollama_demo.main()
output = capsys.readouterr().out
assert "[summary] stdout=true" in output
def test_main_logs_error_and_exits_nonzero(
@@ -560,12 +632,18 @@ def test_main_logs_error_and_exits_nonzero(
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(base_url="http://x", model="m", verbose=False)
monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser())
def fake_run(
base_url: str,
model: str,
strict: bool = True,
verbose: bool = False,
log: Any = None,
) -> dict[str, Any]:
del base_url, model, strict, verbose, log
raise RuntimeError("demo did not execute a successful vm_exec")
monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run)

@@ -5,7 +5,7 @@ from pathlib import Path
import pytest
from pyro_mcp.runtime import doctor_report, resolve_runtime_paths, runtime_capabilities
def test_resolve_runtime_paths_default_bundle() -> None:
@@ -67,7 +67,19 @@ def test_doctor_report_has_runtime_fields() -> None:
report = doctor_report()
assert "runtime_ok" in report
assert "kvm" in report
assert "networking" in report
if report["runtime_ok"]:
runtime = report.get("runtime")
assert isinstance(runtime, dict)
assert "firecracker_bin" in runtime
networking = report["networking"]
assert isinstance(networking, dict)
assert "tun_available" in networking
def test_runtime_capabilities_reports_shim_bundle() -> None:
paths = resolve_runtime_paths()
capabilities = runtime_capabilities(paths)
assert capabilities.supports_vm_boot is False
assert capabilities.supports_guest_exec is False
assert capabilities.supports_guest_network is False

@@ -9,10 +9,15 @@ import pytest
import pyro_mcp.server as server_module
from pyro_mcp.server import create_server
from pyro_mcp.vm_manager import VmManager
from pyro_mcp.vm_network import TapNetworkManager
def test_create_server_registers_vm_tools(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
async def _run() -> list[str]:
server = create_server(manager=manager)
@@ -23,11 +28,16 @@ def test_create_server_registers_vm_tools(tmp_path: Path) -> None:
assert "vm_create" in tool_names
assert "vm_exec" in tool_names
assert "vm_list_profiles" in tool_names
assert "vm_network_info" in tool_names
assert "vm_status" in tool_names
def test_vm_tools_lifecycle_round_trip(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
@@ -60,7 +70,11 @@ def test_vm_tools_lifecycle_round_trip(tmp_path: Path) -> None:
def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
manager.MIN_TTL_SECONDS = 1
def _extract_structured(raw_result: object) -> dict[str, Any]:
@@ -72,7 +86,12 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
return cast(dict[str, Any], structured)
async def _run() -> tuple[
dict[str, Any],
dict[str, Any],
dict[str, Any],
dict[str, Any],
list[dict[str, object]],
dict[str, Any],
]:
server = create_server(manager=manager)
profiles_raw = await server.call_tool("vm_list_profiles", {})
@@ -93,6 +112,7 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
vm_id = str(created["vm_id"])
await server.call_tool("vm_start", {"vm_id": vm_id})
status = _extract_structured(await server.call_tool("vm_status", {"vm_id": vm_id}))
network = _extract_structured(await server.call_tool("vm_network_info", {"vm_id": vm_id}))
stopped = _extract_structured(await server.call_tool("vm_stop", {"vm_id": vm_id}))
deleted = _extract_structured(await server.call_tool("vm_delete", {"vm_id": vm_id}))
@@ -105,10 +125,18 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
expiring_id = str(expiring["vm_id"])
manager._instances[expiring_id].expires_at = 0.0 # noqa: SLF001
reaped = _extract_structured(await server.call_tool("vm_reap_expired", {}))
return (
status,
network,
stopped,
deleted,
cast(list[dict[str, object]], raw_profiles),
reaped,
)
status, network, stopped, deleted, profiles, reaped = asyncio.run(_run())
assert status["state"] == "started"
assert network["network_enabled"] is False
assert stopped["state"] == "stopped"
assert bool(deleted["deleted"]) is True
assert profiles[0]["name"] == "debian-base"

@@ -0,0 +1,51 @@
from __future__ import annotations
import json
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_network import NetworkConfig
@dataclass
class StubInstance:
vm_id: str
vcpu_count: int
mem_mib: int
workdir: Path
metadata: dict[str, str] = field(default_factory=dict)
network: Any = None
def test_build_launch_plan_writes_expected_files(tmp_path: Path) -> None:
instance = StubInstance(
vm_id="abcdef123456",
vcpu_count=2,
mem_mib=2048,
workdir=tmp_path,
metadata={
"kernel_image": "/bundle/profiles/debian-git/vmlinux",
"rootfs_image": "/bundle/profiles/debian-git/rootfs.ext4",
},
network=NetworkConfig(
vm_id="abcdef123456",
tap_name="pyroabcdef12",
guest_ip="172.29.100.2",
gateway_ip="172.29.100.1",
subnet_cidr="172.29.100.0/24",
mac_address="06:00:ab:cd:ef:12",
),
)
plan = build_launch_plan(instance)
assert plan.config_path.exists()
assert plan.guest_network_path.exists()
assert plan.guest_exec_path.exists()
rendered = json.loads(plan.config_path.read_text(encoding="utf-8"))
assert rendered["machine-config"]["vcpu_count"] == 2
assert rendered["network-interfaces"][0]["host_dev_name"] == "pyroabcdef12"
guest_exec = json.loads(plan.guest_exec_path.read_text(encoding="utf-8"))
assert guest_exec["transport"] == "vsock"
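The test above pins only a few fields of the rendered config, so the full shape of what `build_launch_plan` writes is not shown here. A hypothetical sketch of a config writer that would satisfy those assertions, using the standard Firecracker config sections (`boot-source`, `drives`, `machine-config`, `network-interfaces`); the helper name and the boot args are assumptions, not the package's actual implementation:

```python
import json
import tempfile
from pathlib import Path


def write_firecracker_config(workdir: Path, vcpu_count: int, mem_mib: int,
                             kernel: str, rootfs: str, tap_name: str, mac: str) -> Path:
    # Hypothetical helper: emits a Firecracker machine config with the fields
    # the test above asserts on (machine-config.vcpu_count and the TAP
    # host_dev_name). The real build_launch_plan also writes the guest
    # network and exec descriptors alongside it.
    config = {
        "boot-source": {"kernel_image_path": kernel,
                        "boot_args": "console=ttyS0 reboot=k panic=1"},
        "drives": [{"drive_id": "rootfs", "path_on_host": rootfs,
                    "is_root_device": True, "is_read_only": False}],
        "machine-config": {"vcpu_count": vcpu_count, "mem_size_mib": mem_mib},
        "network-interfaces": [{"iface_id": "eth0", "host_dev_name": tap_name,
                                "guest_mac": mac}],
    }
    path = workdir / "firecracker-config.json"
    path.write_text(json.dumps(config, indent=2), encoding="utf-8")
    return path


workdir = Path(tempfile.mkdtemp())
path = write_firecracker_config(workdir, 2, 2048,
                                "/bundle/profiles/debian-git/vmlinux",
                                "/bundle/profiles/debian-git/rootfs.ext4",
                                "pyroabcdef12", "06:00:ab:cd:ef:12")
print(json.loads(path.read_text(encoding="utf-8"))["machine-config"]["vcpu_count"])  # 2
```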

tests/test_vm_guest.py (new file)
@@ -0,0 +1,64 @@
from __future__ import annotations
import socket
import pytest
from pyro_mcp.vm_guest import VsockExecClient
class StubSocket:
def __init__(self, response: bytes) -> None:
self.response = response
self.connected: tuple[int, int] | None = None
self.sent = b""
self.timeout: int | None = None
self.closed = False
def settimeout(self, timeout: int) -> None:
self.timeout = timeout
def connect(self, address: tuple[int, int]) -> None:
self.connected = address
def sendall(self, data: bytes) -> None:
self.sent += data
def recv(self, size: int) -> bytes:
del size
if self.response == b"":
return b""
data, self.response = self.response, b""
return data
def close(self) -> None:
self.closed = True
def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
stub = StubSocket(
b'{"stdout":"ok\\n","stderr":"","exit_code":0,"duration_ms":7}'
)
def socket_factory(family: int, sock_type: int) -> StubSocket:
assert family == socket.AF_VSOCK
assert sock_type == socket.SOCK_STREAM
return stub
client = VsockExecClient(socket_factory=socket_factory)
response = client.exec(1234, 5005, "echo ok", 30)
assert response.exit_code == 0
assert response.stdout == "ok\n"
assert stub.connected == (1234, 5005)
assert b'"command": "echo ok"' in stub.sent
assert stub.closed is True
def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
stub = StubSocket(b"[]")
client = VsockExecClient(socket_factory=lambda family, sock_type: stub)
with pytest.raises(RuntimeError, match="JSON object"):
client.exec(1234, 5005, "echo ok", 30)
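The stub above implies the wire protocol: a single JSON object in each direction, with the writer half-closing the stream to mark the end of its payload (`recv` returning `b""` signals EOF). A standalone sketch of that framing over `socket.socketpair()`; the server side here is a stand-in for the guest agent, and the field names simply mirror the stub, not a documented spec:

```python
import json
import socket
import threading


def serve_one(server_sock: socket.socket) -> None:
    # Stand-in guest agent: read one JSON request, reply with the response
    # shape VsockExecClient expects, then half-close to signal end-of-reply.
    request = json.loads(server_sock.recv(65536).decode("utf-8"))
    response = {"stdout": f"ran: {request['command']}\n", "stderr": "",
                "exit_code": 0, "duration_ms": 1}
    server_sock.sendall(json.dumps(response).encode("utf-8"))
    server_sock.shutdown(socket.SHUT_WR)


client_sock, server_sock = socket.socketpair()
thread = threading.Thread(target=serve_one, args=(server_sock,))
thread.start()

# Client side: send the request, half-close, then read until EOF.
client_sock.sendall(json.dumps({"command": "echo ok", "timeout_seconds": 30}).encode("utf-8"))
client_sock.shutdown(socket.SHUT_WR)
chunks = []
while True:
    data = client_sock.recv(4096)
    if not data:
        break
    chunks.append(data)
thread.join()
reply = json.loads(b"".join(chunks).decode("utf-8"))
print(reply["exit_code"], reply["stdout"])
```

In production the client would use `socket.AF_VSOCK` with a `(cid, port)` address instead of a socketpair; the framing is identical.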

@@ -8,10 +8,15 @@ import pytest
import pyro_mcp.vm_manager as vm_manager_module
from pyro_mcp.runtime import resolve_runtime_paths
from pyro_mcp.vm_manager import VmManager
from pyro_mcp.vm_network import TapNetworkManager
def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
started = manager.start_vm(vm_id)
@@ -19,13 +24,18 @@ def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None:
executed = manager.exec_vm(vm_id, command="printf 'git version 2.43.0\\n'", timeout_seconds=30)
assert executed["exit_code"] == 0
assert executed["execution_mode"] == "host_compat"
assert "git version" in str(executed["stdout"])
with pytest.raises(ValueError, match="does not exist"):
manager.status_vm(vm_id)
def test_vm_manager_exec_timeout(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id"
@@ -38,7 +48,11 @@ def test_vm_manager_exec_timeout(tmp_path: Path) -> None:
def test_vm_manager_stop_and_delete(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id"
@ -52,7 +66,11 @@ def test_vm_manager_stop_and_delete(tmp_path: Path) -> None:
def test_vm_manager_reaps_expired(tmp_path: Path) -> None: def test_vm_manager_reaps_expired(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
manager.MIN_TTL_SECONDS = 1 manager.MIN_TTL_SECONDS = 1
vm_id = str( vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
@ -66,7 +84,11 @@ def test_vm_manager_reaps_expired(tmp_path: Path) -> None:
def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None: def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
manager.MIN_TTL_SECONDS = 1 manager.MIN_TTL_SECONDS = 1
vm_id = str( vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
@ -86,20 +108,33 @@ def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None:
], ],
) )
def test_vm_manager_validates_limits(tmp_path: Path, kwargs: dict[str, int], msg: str) -> None: def test_vm_manager_validates_limits(tmp_path: Path, kwargs: dict[str, int], msg: str) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
with pytest.raises(ValueError, match=msg): with pytest.raises(ValueError, match=msg):
manager.create_vm(profile="debian-base", **kwargs) manager.create_vm(profile="debian-base", **kwargs)
def test_vm_manager_max_active_limit(tmp_path: Path) -> None: def test_vm_manager_max_active_limit(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms", max_active_vms=1) manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
max_active_vms=1,
network_manager=TapNetworkManager(enabled=False),
)
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)
with pytest.raises(RuntimeError, match="max active VMs reached"): with pytest.raises(RuntimeError, match="max active VMs reached"):
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)
def test_vm_manager_state_validation(tmp_path: Path) -> None: def test_vm_manager_state_validation(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
vm_id = str( vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[ manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id" "vm_id"
@ -115,7 +150,11 @@ def test_vm_manager_state_validation(tmp_path: Path) -> None:
def test_vm_manager_status_expired_raises(tmp_path: Path) -> None: def test_vm_manager_status_expired_raises(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms") manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
manager.MIN_TTL_SECONDS = 1 manager.MIN_TTL_SECONDS = 1
vm_id = str( vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"] manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
@ -127,17 +166,45 @@ def test_vm_manager_status_expired_raises(tmp_path: Path) -> None:
def test_vm_manager_invalid_backend(tmp_path: Path) -> None: def test_vm_manager_invalid_backend(tmp_path: Path) -> None:
with pytest.raises(ValueError, match="invalid backend"): with pytest.raises(ValueError, match="invalid backend"):
VmManager(backend_name="nope", base_dir=tmp_path / "vms") VmManager(
backend_name="nope",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
def test_vm_manager_network_info(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
created = manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
status = manager.status_vm(vm_id)
info = manager.network_info_vm(vm_id)
assert status["network_enabled"] is False
assert status["guest_ip"] is None
assert info["network_enabled"] is False
def test_vm_manager_firecracker_backend_path( def test_vm_manager_firecracker_backend_path(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None: ) -> None:
class StubFirecrackerBackend: class StubFirecrackerBackend:
def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None: def __init__(
self,
artifacts_dir: Path,
firecracker_bin: Path,
jailer_bin: Path,
runtime_capabilities: Any,
network_manager: TapNetworkManager,
) -> None:
self.artifacts_dir = artifacts_dir self.artifacts_dir = artifacts_dir
self.firecracker_bin = firecracker_bin self.firecracker_bin = firecracker_bin
self.jailer_bin = jailer_bin self.jailer_bin = jailer_bin
self.runtime_capabilities = runtime_capabilities
self.network_manager = network_manager
def create(self, instance: Any) -> None: def create(self, instance: Any) -> None:
del instance del instance
@ -160,5 +227,6 @@ def test_vm_manager_firecracker_backend_path(
backend_name="firecracker", backend_name="firecracker",
base_dir=tmp_path / "vms", base_dir=tmp_path / "vms",
runtime_paths=resolve_runtime_paths(), runtime_paths=resolve_runtime_paths(),
network_manager=TapNetworkManager(enabled=False),
) )
assert manager._backend_name == "firecracker" # noqa: SLF001 assert manager._backend_name == "firecracker" # noqa: SLF001
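Every hunk above makes the same change: a `network_manager` collaborator is injected and disabled for unit tests, so constructing a manager never touches host networking. Stripped of the real classes, the pattern looks like this (`FakeNetworkManager` and `FakeVmManager` are illustrative stand-ins, not `pyro_mcp` APIs):

```python
from dataclasses import dataclass, field


@dataclass
class FakeNetworkManager:
    # Disabled by default, so constructing a manager in a test never
    # requires /dev/net/tun or root privileges on the host.
    enabled: bool = False


@dataclass
class FakeVmManager:
    backend_name: str
    network_manager: FakeNetworkManager = field(default_factory=FakeNetworkManager)

    def status(self) -> dict[str, object]:
        # With networking off, status reports no guest IP — the same shape
        # test_vm_manager_network_info asserts. The sample address for the
        # enabled branch is taken from the NetworkConfig literals in the tests.
        return {
            "network_enabled": self.network_manager.enabled,
            "guest_ip": "172.29.100.2" if self.network_manager.enabled else None,
        }
```

The payoff is that the backend name alone decides VM behavior while networking stays an orthogonal, swappable dependency — tests pass a disabled instance explicitly instead of monkeypatching globals.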

tests/test_vm_network.py Normal file

@@ -0,0 +1,60 @@
from __future__ import annotations

import subprocess
from pathlib import Path

import pytest

from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager


def test_tap_network_manager_allocation_disabled_by_default(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    monkeypatch.delenv("PYRO_VM_ENABLE_NETWORK", raising=False)
    manager = TapNetworkManager()
    config = manager.allocate("abcdef123456")
    assert manager.enabled is False
    assert config.tap_name == "pyroabcdef12"
    assert config.guest_ip.startswith("172.29.")
    metadata = manager.to_metadata(config)
    assert metadata["network_enabled"] == "false"


def test_tap_network_manager_network_info() -> None:
    manager = TapNetworkManager(enabled=False)
    config = NetworkConfig(
        vm_id="abcdef123456",
        tap_name="pyroabcdef12",
        guest_ip="172.29.100.2",
        gateway_ip="172.29.100.1",
        subnet_cidr="172.29.100.0/24",
        mac_address="06:00:ab:cd:ef:12",
    )
    info = manager.network_info(config)
    assert info["tap_name"] == "pyroabcdef12"
    assert info["outbound_connectivity_expected"] is False


def test_tap_network_manager_enabled_runs_host_commands() -> None:
    commands: list[list[str]] = []

    def runner(command: list[str]) -> subprocess.CompletedProcess[str]:
        commands.append(command)
        return subprocess.CompletedProcess(command, 0, "", "")

    manager = TapNetworkManager(enabled=True, runner=runner)
    config = manager.allocate("abcdef123456")
    manager.cleanup(config)
    assert commands[0][:4] == ["ip", "tuntap", "add", "dev"]
    assert commands[-1][:3] in (["ip", "link", "del"], ["nft", "delete", "table"])


def test_tap_network_manager_missing_host_support(monkeypatch: pytest.MonkeyPatch) -> None:
    monkeypatch.setattr(Path, "exists", lambda self: False if str(self) == "/dev/net/tun" else True)
    manager = TapNetworkManager(
        enabled=True,
        runner=lambda command: subprocess.CompletedProcess(command, 0, "", ""),
    )
    with pytest.raises(RuntimeError, match="/dev/net/tun"):
        manager.allocate("abcdef123456")
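The constants asserted above fit a simple derivation from the first eight hex characters of the `vm_id`. The reconstruction below is inferred from the test data only and may not match the real `TapNetworkManager` internals:

```python
def derive_network_identity(vm_id: str, subnet_octet: int = 100) -> dict[str, str]:
    # "abcdef123456" -> tap "pyroabcdef12", MAC "06:00:ab:cd:ef:12".
    # The 172.29.<octet>.0/24 subnet with fixed .1 gateway and .2 guest
    # addresses matches the NetworkConfig literals used in the tests;
    # how the real manager picks subnet_octet per VM is not shown here.
    prefix = vm_id[:8]
    return {
        "tap_name": f"pyro{prefix}",
        "mac_address": "06:00:" + ":".join(prefix[i : i + 2] for i in range(0, 8, 2)),
        "subnet_cidr": f"172.29.{subnet_octet}.0/24",
        "gateway_ip": f"172.29.{subnet_octet}.1",
        "guest_ip": f"172.29.{subnet_octet}.2",
    }
```

Deriving everything from the `vm_id` keeps allocation deterministic and collision-free per VM, which is why the disabled-by-default test can still assert concrete tap and IP values without any host state.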