Add runtime capability scaffolding and align docs

This commit is contained in:
Thales Maciel 2026-03-05 22:57:09 -03:00
parent fb8b985049
commit cbf212bb7b
19 changed files with 1048 additions and 71 deletions

View file

@ -7,6 +7,18 @@ from typing import Any
from pyro_mcp.vm_manager import VmManager
INTERNET_PROBE_COMMAND = (
'python3 -c "import urllib.request; '
"print(urllib.request.urlopen('https://example.com', timeout=10).status)"
'"'
)
def _demo_command(status: dict[str, Any]) -> str:
if bool(status.get("network_enabled")) and status.get("execution_mode") == "guest_vsock":
return INTERNET_PROBE_COMMAND
return "git --version"
def run_demo() -> dict[str, Any]:
"""Create/start/exec/delete a VM and return command output."""
@ -14,7 +26,8 @@ def run_demo() -> dict[str, Any]:
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
executed = manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
status = manager.status_vm(vm_id)
executed = manager.exec_vm(vm_id, command=_demo_command(status), timeout_seconds=30)
return executed

View file

@ -16,6 +16,12 @@ __all__ = ["VmManager", "run_ollama_tool_demo"]
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
MAX_TOOL_ROUNDS: Final[int] = 12
# Directory (relative to the command's working directory) that the proof
# command clones into.
CLONE_TARGET_DIR: Final[str] = "hello-world"
# Shell pipeline used to prove outbound network access: remove any previous
# clone, shallow-clone a public repository, then ask git whether the target is a
# work tree (prints `true` on success). It is built from CLONE_TARGET_DIR so the
# directory name is defined in exactly one place.
NETWORK_PROOF_COMMAND: Final[str] = (
    f"rm -rf {CLONE_TARGET_DIR} "
    "&& git clone --depth 1 https://github.com/octocat/Hello-World.git "
    f"{CLONE_TARGET_DIR} >/dev/null "
    f"&& git -C {CLONE_TARGET_DIR} rev-parse --is-inside-work-tree"
)
TOOL_SPECS: Final[list[dict[str, Any]]] = [
{
@ -210,7 +216,7 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
def _is_vm_id_placeholder(value: str) -> bool:
@ -240,8 +246,10 @@ def _normalize_tool_arguments(
return normalized_arguments, last_created_vm_id
def _summarize_message_for_log(message: dict[str, Any]) -> str:
def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
role = str(message.get("role", "unknown"))
if not verbose:
return role
content = str(message.get("content") or "").strip()
if content == "":
return f"{role}: <empty>"
@ -257,6 +265,7 @@ def run_ollama_tool_demo(
model: str = DEFAULT_OLLAMA_MODEL,
*,
strict: bool = True,
verbose: bool = False,
log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
@ -267,10 +276,12 @@ def run_ollama_tool_demo(
{
"role": "user",
"content": (
"Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
"Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
"returned by vm_create into vm_start/vm_exec.\n"
f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
f"Success means the clone completes and the command prints `true`.\n"
"If a tool returns an error, fix arguments and retry."
),
}
@ -280,7 +291,7 @@ def run_ollama_tool_demo(
last_created_vm_id: str | None = None
for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
emit(f"[model] input {_summarize_message_for_log(messages[-1])}")
emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}")
response = _post_chat_completion(
base_url,
{
@ -292,7 +303,7 @@ def run_ollama_tool_demo(
},
)
assistant_message = _extract_message(response)
emit(f"[model] output {_summarize_message_for_log(assistant_message)}")
emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}")
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
final_response = str(assistant_message.get("content") or "")
@ -320,15 +331,24 @@ def run_ollama_tool_demo(
if not isinstance(tool_name, str):
raise RuntimeError("tool call function name is invalid")
arguments = _parse_tool_arguments(function.get("arguments"))
emit(f"[model] tool_call {tool_name} args={arguments}")
if verbose:
emit(f"[model] tool_call {tool_name} args={arguments}")
else:
emit(f"[model] tool_call {tool_name}")
arguments, normalized_vm_id = _normalize_tool_arguments(
tool_name,
arguments,
last_created_vm_id=last_created_vm_id,
)
if normalized_vm_id is not None:
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
emit(f"[tool] calling {tool_name} with args={arguments}")
if verbose:
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
else:
emit("[tool] resolved vm_id placeholder")
if verbose:
emit(f"[tool] calling {tool_name} with args={arguments}")
else:
emit(f"[tool] calling {tool_name}")
try:
result = _dispatch_tool_call(manager, tool_name, arguments)
success = True
@ -341,7 +361,10 @@ def run_ollama_tool_demo(
result = _format_tool_error(tool_name, arguments, exc)
success = False
emit(f"[tool] {tool_name} failed: {exc}")
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
if verbose:
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
else:
emit(f"[tool] result {tool_name}")
tool_events.append(
{
"tool_name": tool_name,
@ -379,7 +402,7 @@ def run_ollama_tool_demo(
tool_events.append(
{
"tool_name": "vm_exec_fallback",
"arguments": {"command": "git --version"},
"arguments": {"command": NETWORK_PROOF_COMMAND},
"result": exec_result,
"success": True,
}
@ -390,13 +413,13 @@ def run_ollama_tool_demo(
raise RuntimeError("vm_exec result shape is invalid")
if int(exec_result.get("exit_code", -1)) != 0:
raise RuntimeError("vm_exec failed; expected exit_code=0")
if "git version" not in str(exec_result.get("stdout", "")):
raise RuntimeError("vm_exec output did not contain `git version`")
if str(exec_result.get("stdout", "")).strip() != "true":
raise RuntimeError("vm_exec output did not confirm repository clone success")
emit("[done] command execution succeeded")
return {
"model": model,
"command": "git --version",
"command": NETWORK_PROOF_COMMAND,
"exec_result": exec_result,
"tool_events": tool_events,
"final_response": final_response,
@ -408,6 +431,7 @@ def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.")
parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
parser.add_argument("-v", "--verbose", action="store_true")
return parser
@ -418,6 +442,7 @@ def main() -> None:
result = run_ollama_tool_demo(
base_url=args.base_url,
model=args.model,
verbose=args.verbose,
log=lambda message: print(message, flush=True),
)
except Exception as exc: # noqa: BLE001
@ -428,7 +453,9 @@ def main() -> None:
raise RuntimeError("demo produced invalid execution result")
print(
f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
f"fallback_used={bool(result.get('fallback_used'))}",
f"fallback_used={bool(result.get('fallback_used'))} "
f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
flush=True,
)
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
if args.verbose:
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)

View file

@ -11,6 +11,8 @@ from dataclasses import dataclass
from pathlib import Path
from typing import Any
from pyro_mcp.vm_network import TapNetworkManager
DEFAULT_PLATFORM = "linux-x86_64"
@ -27,6 +29,16 @@ class RuntimePaths:
manifest: dict[str, Any]
@dataclass(frozen=True)
class RuntimeCapabilities:
    """Feature flags inferred from the bundled runtime."""

    # Whether the runtime is able to boot a real guest VM.
    supports_vm_boot: bool
    # Whether commands can be executed inside the guest.
    supports_guest_exec: bool
    # Whether guest networking is available.
    supports_guest_network: bool
    # Optional human-readable explanation of why capabilities are limited.
    reason: str | None = None
def _sha256(path: Path) -> str:
digest = hashlib.sha256()
with path.open("rb") as fp:
@ -135,6 +147,40 @@ def resolve_runtime_paths(
)
def runtime_capabilities(paths: RuntimePaths) -> RuntimeCapabilities:
    """Infer what the current bundled runtime can actually do.

    Resolution order:

    1. If the firecracker binary's text contains the shim marker, every
       capability is reported as unavailable.
    2. If the manifest has no ``capabilities`` mapping, every capability is
       reported as unavailable.
    3. Otherwise the manifest's ``vm_boot``, ``guest_exec`` and
       ``guest_network`` entries are used. Guest exec and guest networking are
       only reported when ``vm_boot`` is also advertised, so the flags can never
       claim guest features for a runtime that does not boot a guest.

    Args:
        paths: Resolved runtime paths; ``firecracker_bin`` and ``manifest`` are
            read.

    Returns:
        The inferred capabilities, with ``reason`` populated whenever VM boot is
        unavailable.
    """
    binary_text = paths.firecracker_bin.read_text(encoding="utf-8", errors="ignore")
    if "bundled firecracker shim" in binary_text:
        return RuntimeCapabilities(
            supports_vm_boot=False,
            supports_guest_exec=False,
            supports_guest_network=False,
            reason="bundled runtime uses shim firecracker/jailer binaries",
        )

    capabilities = paths.manifest.get("capabilities")
    if not isinstance(capabilities, dict):
        return RuntimeCapabilities(
            supports_vm_boot=False,
            supports_guest_exec=False,
            supports_guest_network=False,
            reason="runtime manifest does not declare guest boot/exec/network capabilities",
        )

    supports_vm_boot = bool(capabilities.get("vm_boot"))
    # Guest-side features depend on a booted guest; gate them on vm_boot so a
    # manifest cannot produce a contradictory combination of flags.
    supports_guest_exec = supports_vm_boot and bool(capabilities.get("guest_exec"))
    supports_guest_network = supports_vm_boot and bool(capabilities.get("guest_network"))
    reason = None
    if not supports_vm_boot:
        reason = "runtime manifest does not advertise real VM boot support"
    return RuntimeCapabilities(
        supports_vm_boot=supports_vm_boot,
        supports_guest_exec=supports_guest_exec,
        supports_guest_network=supports_guest_network,
        reason=reason,
    )
def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
"""Build a runtime diagnostics report."""
report: dict[str, Any] = {
@ -146,13 +192,28 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
"readable": os.access("/dev/kvm", os.R_OK),
"writable": os.access("/dev/kvm", os.W_OK),
},
"networking": {
"enabled_by_default": TapNetworkManager().enabled,
},
}
network = TapNetworkManager.diagnostics()
report["networking"].update(
{
"tun_available": network.tun_available,
"ip_binary": network.ip_binary,
"nft_binary": network.nft_binary,
"iptables_binary": network.iptables_binary,
"ip_forward_enabled": network.ip_forward_enabled,
}
)
try:
paths = resolve_runtime_paths(platform=platform, verify_checksums=True)
except Exception as exc: # noqa: BLE001
report["issues"] = [str(exc)]
return report
capabilities = runtime_capabilities(paths)
profiles = paths.manifest.get("profiles", {})
profile_names = sorted(profiles.keys()) if isinstance(profiles, dict) else []
report["runtime_ok"] = True
@ -165,6 +226,12 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
"notice_path": str(paths.notice_path),
"bundle_version": paths.manifest.get("bundle_version"),
"profiles": profile_names,
"capabilities": {
"supports_vm_boot": capabilities.supports_vm_boot,
"supports_guest_exec": capabilities.supports_guest_exec,
"supports_guest_network": capabilities.supports_guest_network,
"reason": capabilities.reason,
},
}
if not report["kvm"]["exists"]:
report["issues"] = ["/dev/kvm is not available on this host"]

View file

@ -59,6 +59,11 @@ def create_server(manager: VmManager | None = None) -> FastMCP:
"""Get the current state and metadata for a VM."""
return vm_manager.status_vm(vm_id)
@server.tool()
async def vm_network_info(vm_id: str) -> dict[str, Any]:
"""Get the current network configuration assigned to a VM."""
return vm_manager.network_info_vm(vm_id)
@server.tool()
async def vm_reap_expired() -> dict[str, Any]:
"""Delete VMs whose TTL has expired."""

View file

@ -0,0 +1,109 @@
"""Firecracker launch-plan generation for pyro VMs."""
from __future__ import annotations
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Protocol
# Base value added to the number parsed from a VM id prefix when deriving the
# guest's vsock context id (see _guest_cid).
DEFAULT_GUEST_CID_OFFSET = 10_000
# Vsock port recorded in the launch plan and in the guest-exec descriptor.
DEFAULT_VSOCK_PORT = 5005
@dataclass(frozen=True)
class FirecrackerLaunchPlan:
    """On-disk config artifacts needed to boot a guest."""

    # Path of the Firecracker API socket inside the VM workdir.
    api_socket_path: Path
    # Path of the JSON file holding the Firecracker configuration.
    config_path: Path
    # Path of the JSON file describing the guest network settings.
    guest_network_path: Path
    # Path of the JSON file describing the guest exec transport.
    guest_exec_path: Path
    # Vsock context id assigned to the guest.
    guest_cid: int
    # Vsock port used for guest command execution.
    vsock_port: int
    # The Firecracker configuration that was written to ``config_path``.
    config: dict[str, Any]
class VmInstanceLike(Protocol):
    """Structural type listing the VM attributes the launch-plan builder reads."""

    # Identifier of the VM; its leading characters are parsed as hexadecimal.
    vm_id: str
    # Number of virtual CPUs written into the machine config.
    vcpu_count: int
    # Guest memory size in MiB written into the machine config.
    mem_mib: int
    # Directory that receives the generated config files and sockets.
    workdir: Path
    # String metadata; must provide "kernel_image" and "rootfs_image".
    metadata: dict[str, str]
    # Network configuration object, or None when the VM has none.
    network: Any
def _guest_cid(vm_id: str) -> int:
    """Derive a vsock context id from the first four hex characters of ``vm_id``.

    Raises:
        ValueError: if that prefix is not valid hexadecimal.
    """
    hex_prefix = vm_id[:4]
    return DEFAULT_GUEST_CID_OFFSET + int(hex_prefix, 16)
def build_launch_plan(instance: VmInstanceLike) -> FirecrackerLaunchPlan:
    """Write the launch artifacts for ``instance`` and describe them.

    Three JSON documents are written into ``instance.workdir``: the Firecracker
    configuration, the guest network description and the guest exec transport
    description. The returned plan records their paths together with the guest
    CID, the vsock port and the in-memory configuration.

    Raises:
        KeyError: if ``instance.metadata`` lacks ``kernel_image`` or
            ``rootfs_image``.
    """
    workdir = instance.workdir
    net = instance.network
    guest_cid = _guest_cid(instance.vm_id)
    vsock_port = DEFAULT_VSOCK_PORT

    api_socket_path = workdir / "firecracker.sock"
    config_path = workdir / "firecracker-config.json"
    guest_network_path = workdir / "guest-network.json"
    guest_exec_path = workdir / "guest-exec.json"

    # A single interface is attached only when a network config is present.
    if net is None:
        interfaces: list[dict[str, Any]] = []
        guest_network: dict[str, Any] = {
            "guest_ip": None,
            "gateway_ip": None,
            "subnet_cidr": None,
            "dns_servers": [],
            "tap_name": None,
        }
    else:
        interfaces = [
            {
                "iface_id": "eth0",
                "guest_mac": net.mac_address,
                "host_dev_name": net.tap_name,
            }
        ]
        guest_network = {
            "guest_ip": net.guest_ip,
            "gateway_ip": net.gateway_ip,
            "subnet_cidr": net.subnet_cidr,
            "dns_servers": list(net.dns_servers),
            "tap_name": net.tap_name,
        }

    config: dict[str, Any] = {
        "boot-source": {
            "kernel_image_path": instance.metadata["kernel_image"],
            "boot_args": "console=ttyS0 reboot=k panic=1 pci=off",
        },
        "drives": [
            {
                "drive_id": "rootfs",
                "path_on_host": instance.metadata["rootfs_image"],
                "is_root_device": True,
                "is_read_only": False,
            }
        ],
        "machine-config": {
            "vcpu_count": instance.vcpu_count,
            "mem_size_mib": instance.mem_mib,
            "smt": False,
        },
        "network-interfaces": interfaces,
        "vsock": {
            "guest_cid": guest_cid,
            "uds_path": str(workdir / "vsock.sock"),
        },
    }
    guest_exec = {
        "transport": "vsock",
        "guest_cid": guest_cid,
        "port": vsock_port,
    }

    # Persist each document with the same stable, sorted formatting.
    documents = (
        (config_path, config),
        (guest_network_path, guest_network),
        (guest_exec_path, guest_exec),
    )
    for target, document in documents:
        target.write_text(json.dumps(document, indent=2, sort_keys=True), encoding="utf-8")

    return FirecrackerLaunchPlan(
        api_socket_path=api_socket_path,
        config_path=config_path,
        guest_network_path=guest_network_path,
        guest_exec_path=guest_exec_path,
        guest_cid=guest_cid,
        vsock_port=vsock_port,
        config=config,
    )

72
src/pyro_mcp/vm_guest.py Normal file
View file

@ -0,0 +1,72 @@
"""Guest command transport over vsock-compatible JSON protocol."""
from __future__ import annotations
import json
import socket
from dataclasses import dataclass
from typing import Callable, Protocol
class SocketLike(Protocol):
    """Subset of the socket interface that the exec client relies on."""

    def settimeout(self, timeout: int) -> None:
        """Set the timeout applied to subsequent socket operations."""
        ...

    def connect(self, address: tuple[int, int]) -> None:
        """Connect to the given two-integer address."""
        ...

    def sendall(self, data: bytes) -> None:
        """Send all of ``data``."""
        ...

    def recv(self, size: int) -> bytes:
        """Receive up to ``size`` bytes; empty bytes ends the client's read loop."""
        ...

    def close(self) -> None:
        """Close the socket."""
        ...


# Factory called with (address family, socket type) to create a SocketLike.
SocketFactory = Callable[[int, int], SocketLike]
@dataclass(frozen=True)
class GuestExecResponse:
    """Result fields decoded from a guest exec agent reply."""

    # Text taken from the reply's "stdout" field.
    stdout: str
    # Text taken from the reply's "stderr" field.
    stderr: str
    # Integer taken from the reply's "exit_code" field.
    exit_code: int
    # Integer taken from the reply's "duration_ms" field.
    duration_ms: int
class VsockExecClient:
    """Minimal JSON-over-stream client for a guest exec agent."""

    def __init__(self, socket_factory: SocketFactory | None = None) -> None:
        """Remember the socket factory, defaulting to ``socket.socket``."""
        self._socket_factory = socket_factory if socket_factory else socket.socket

    def exec(
        self, guest_cid: int, port: int, command: str, timeout_seconds: int
    ) -> GuestExecResponse:
        """Send ``command`` to the guest agent and decode its reply.

        One newline-terminated JSON request is written, then the stream is read
        until the peer closes it and the collected bytes are parsed as a single
        JSON object. The socket is always closed, even when an operation fails.

        Raises:
            RuntimeError: if ``socket.AF_VSOCK`` is unavailable, or the reply is
                valid JSON but not an object.
        """
        message = {
            "command": command,
            "timeout_seconds": timeout_seconds,
        }
        vsock_family = getattr(socket, "AF_VSOCK", None)
        if vsock_family is None:
            raise RuntimeError("vsock sockets are not supported on this host Python runtime")

        conn = self._socket_factory(vsock_family, socket.SOCK_STREAM)
        try:
            conn.settimeout(timeout_seconds)
            conn.connect((guest_cid, port))
            conn.sendall(f"{json.dumps(message)}\n".encode("utf-8"))
            # Keep reading until recv() yields the empty sentinel (peer closed).
            raw_reply = b"".join(iter(lambda: conn.recv(65536), b""))
        finally:
            conn.close()

        reply = json.loads(raw_reply.decode("utf-8"))
        if not isinstance(reply, dict):
            raise RuntimeError("guest exec response must be a JSON object")

        stdout_text = str(reply.get("stdout", ""))
        stderr_text = str(reply.get("stderr", ""))
        exit_status = int(reply.get("exit_code", -1))
        elapsed_ms = int(reply.get("duration_ms", 0))
        return GuestExecResponse(
            stdout=stdout_text,
            stderr=stderr_text,
            exit_code=exit_status,
            duration_ms=elapsed_ms,
        )

View file

@ -12,7 +12,15 @@ from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
from pyro_mcp.runtime import (
RuntimeCapabilities,
RuntimePaths,
resolve_runtime_paths,
runtime_capabilities,
)
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
VmState = Literal["created", "started", "stopped"]
@ -34,6 +42,7 @@ class VmInstance:
firecracker_pid: int | None = None
last_error: str | None = None
metadata: dict[str, str] = field(default_factory=dict)
network: NetworkConfig | None = None
@dataclass(frozen=True)
@ -119,10 +128,21 @@ class MockBackend(VmBackend):
class FirecrackerBackend(VmBackend): # pragma: no cover
"""Host-gated backend that validates Firecracker prerequisites."""
def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
def __init__(
self,
artifacts_dir: Path,
firecracker_bin: Path,
jailer_bin: Path,
runtime_capabilities: RuntimeCapabilities,
network_manager: TapNetworkManager | None = None,
guest_exec_client: VsockExecClient | None = None,
) -> None:
self._artifacts_dir = artifacts_dir
self._firecracker_bin = firecracker_bin
self._jailer_bin = jailer_bin
self._runtime_capabilities = runtime_capabilities
self._network_manager = network_manager or TapNetworkManager()
self._guest_exec_client = guest_exec_client or VsockExecClient()
if not self._firecracker_bin.exists():
raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
if not self._jailer_bin.exists():
@ -132,16 +152,29 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
def create(self, instance: VmInstance) -> None:
instance.workdir.mkdir(parents=True, exist_ok=False)
artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
raise RuntimeError(
f"missing profile artifacts for {instance.profile}; expected "
f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
)
instance.metadata["kernel_image"] = str(artifacts.kernel_image)
instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
try:
artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
raise RuntimeError(
f"missing profile artifacts for {instance.profile}; expected "
f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
)
instance.metadata["kernel_image"] = str(artifacts.kernel_image)
instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
network = self._network_manager.allocate(instance.vm_id)
instance.network = network
instance.metadata.update(self._network_manager.to_metadata(network))
except Exception:
shutil.rmtree(instance.workdir, ignore_errors=True)
raise
def start(self, instance: VmInstance) -> None:
launch_plan = build_launch_plan(instance)
instance.metadata["firecracker_config_path"] = str(launch_plan.config_path)
instance.metadata["guest_network_path"] = str(launch_plan.guest_network_path)
instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path)
instance.metadata["guest_cid"] = str(launch_plan.guest_cid)
instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port)
proc = subprocess.run( # noqa: S603
[str(self._firecracker_bin), "--version"],
text=True,
@ -152,15 +185,35 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
instance.metadata["firecracker_version"] = proc.stdout.strip()
instance.metadata["jailer_path"] = str(self._jailer_bin)
if not self._runtime_capabilities.supports_vm_boot:
instance.metadata["execution_mode"] = "host_compat"
instance.metadata["boot_mode"] = "shim"
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
return
instance.metadata["execution_mode"] = "guest_vsock"
instance.metadata["boot_mode"] = "native"
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
    """Run ``command`` for ``instance`` and return its captured result.

    When the runtime supports guest exec, the command is sent through the
    guest exec client using the ``guest_cid`` and ``guest_exec_port`` values
    stored in ``instance.metadata``. Otherwise it is delegated to
    ``_run_host_command`` with the instance workdir.

    ``instance.metadata["execution_mode"]`` is updated on both paths so it
    always names the transport that was actually used.

    Raises:
        KeyError: on the guest path, if the CID/port metadata is missing.
    """
    if self._runtime_capabilities.supports_guest_exec:
        guest_cid = int(instance.metadata["guest_cid"])
        port = int(instance.metadata["guest_exec_port"])
        # Record the transport before the call so the metadata is accurate
        # even when the guest request fails.
        instance.metadata["execution_mode"] = "guest_vsock"
        response = self._guest_exec_client.exec(guest_cid, port, command, timeout_seconds)
        return VmExecResult(
            stdout=response.stdout,
            stderr=response.stderr,
            exit_code=response.exit_code,
            duration_ms=response.duration_ms,
        )
    # Temporary compatibility path until guest-side execution agent is integrated.
    instance.metadata["execution_mode"] = "host_compat"
    return _run_host_command(instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
    # Nothing is done for this backend on stop; the argument is only discarded.
    del instance
def delete(self, instance: VmInstance) -> None:
    """Release the VM's network resources and remove its workdir.

    Network cleanup runs first when the instance has a network config. The
    workdir is removed in a ``finally`` clause so it is not leaked when network
    cleanup raises; any such exception still propagates to the caller.
    """
    try:
        if instance.network is not None:
            self._network_manager.cleanup(instance.network)
    finally:
        # Best-effort removal: missing or partially removed directories are ignored.
        shutil.rmtree(instance.workdir, ignore_errors=True)
@ -182,6 +235,7 @@ class VmManager:
artifacts_dir: Path | None = None,
max_active_vms: int = 4,
runtime_paths: RuntimePaths | None = None,
network_manager: TapNetworkManager | None = None,
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
@ -189,11 +243,19 @@ class VmManager:
if self._backend_name == "firecracker":
self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
self._runtime_capabilities = runtime_capabilities(self._runtime_paths)
else:
self._artifacts_dir = artifacts_dir or Path(
os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
)
self._runtime_capabilities = RuntimeCapabilities(
supports_vm_boot=False,
supports_guest_exec=False,
supports_guest_network=False,
reason="mock backend does not boot a guest",
)
self._max_active_vms = max_active_vms
self._network_manager = network_manager or TapNetworkManager()
self._lock = threading.Lock()
self._instances: dict[str, VmInstance] = {}
self._base_dir.mkdir(parents=True, exist_ok=True)
@ -209,6 +271,8 @@ class VmManager:
self._artifacts_dir,
firecracker_bin=self._runtime_paths.firecracker_bin,
jailer_bin=self._runtime_paths.jailer_bin,
runtime_capabilities=self._runtime_capabilities,
network_manager=self._network_manager,
)
raise ValueError("invalid backend; expected one of: mock, firecracker")
@ -262,6 +326,7 @@ class VmManager:
if instance.state != "started":
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
exec_result = self._backend.exec(instance, command, timeout_seconds)
execution_mode = instance.metadata.get("execution_mode", "host_compat")
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
@ -270,6 +335,7 @@ class VmManager:
"stderr": exec_result.stderr,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"execution_mode": execution_mode,
"cleanup": cleanup,
}
@ -296,6 +362,19 @@ class VmManager:
self._ensure_not_expired_locked(instance, time.time())
return self._serialize(instance)
def network_info_vm(self, vm_id: str) -> dict[str, Any]:
    """Describe the network configuration assigned to the VM ``vm_id``.

    The instance is looked up and checked for expiry while holding the manager
    lock. A VM without a network config yields a fixed "unavailable" payload;
    otherwise the network manager's info is returned with ``vm_id`` added.
    """
    with self._lock:
        vm = self._get_instance_locked(vm_id)
        self._ensure_not_expired_locked(vm, time.time())
        assigned = vm.network
        if assigned is not None:
            details: dict[str, Any] = {"vm_id": vm_id}
            details.update(self._network_manager.network_info(assigned))
            return details
        return {
            "vm_id": vm_id,
            "network_enabled": False,
            "outbound_connectivity_expected": False,
            "reason": "network configuration is unavailable for this VM",
        }
def reap_expired(self) -> dict[str, Any]:
now = time.time()
with self._lock:
@ -331,6 +410,10 @@ class VmManager:
"created_at": instance.created_at,
"expires_at": instance.expires_at,
"state": instance.state,
"network_enabled": instance.network is not None,
"guest_ip": instance.network.guest_ip if instance.network is not None else None,
"tap_name": instance.network.tap_name if instance.network is not None else None,
"execution_mode": instance.metadata.get("execution_mode", "host_compat"),
"metadata": instance.metadata,
}

191
src/pyro_mcp/vm_network.py Normal file
View file

@ -0,0 +1,191 @@
"""Host-side network allocation and diagnostics for Firecracker VMs."""
from __future__ import annotations
import os
import shutil
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Callable
# Callable that executes an argv-style command list and returns the completed
# process; TapNetworkManager accepts one so host commands can be substituted.
CommandRunner = Callable[[list[str]], subprocess.CompletedProcess[str]]
# First two octets (with trailing dot) shared by every per-VM subnet; the third
# octet is derived from the VM id in TapNetworkManager.allocate.
DEFAULT_NETWORK_PREFIX = "172.29."
# Last octet of the host-side gateway address within a VM subnet.
DEFAULT_GATEWAY_SUFFIX = 1
# Last octet of the guest address within a VM subnet.
DEFAULT_GUEST_SUFFIX = 2
@dataclass(frozen=True)
class NetworkConfig:
    """Immutable network assignment computed for a single VM."""

    # Identifier of the VM this configuration belongs to.
    vm_id: str
    # Name of the host TAP device for the VM.
    tap_name: str
    # IPv4 address assigned to the guest.
    guest_ip: str
    # IPv4 address of the gateway on the host side.
    gateway_ip: str
    # Subnet of the VM in CIDR notation.
    subnet_cidr: str
    # MAC address assigned to the guest interface.
    mac_address: str
    # DNS resolvers handed to the guest.
    dns_servers: tuple[str, ...] = ("1.1.1.1", "8.8.8.8")
@dataclass(frozen=True)
class NetworkDiagnostics:
    """Snapshot of the host prerequisites relevant to VM networking."""

    # Whether /dev/net/tun exists on the host.
    tun_available: bool
    # Resolved path of the `ip` binary, or None when it is not on PATH.
    ip_binary: str | None
    # Resolved path of the `nft` binary, or None when it is not on PATH.
    nft_binary: str | None
    # Resolved path of the `iptables` binary, or None when it is not on PATH.
    iptables_binary: str | None
    # Whether /proc/sys/net/ipv4/ip_forward reads "1".
    ip_forward_enabled: bool
class TapNetworkManager:
    """Allocate per-VM TAP/NAT configuration on the host.

    Addresses and device names are always computed; the host itself is only
    modified (TAP device, nftables table) when the manager is enabled.
    """

    def __init__(
        self,
        *,
        enabled: bool | None = None,
        runner: CommandRunner | None = None,
    ) -> None:
        """Configure the manager.

        Args:
            enabled: Explicit on/off switch. When ``None`` the manager is
                enabled only if ``PYRO_VM_ENABLE_NETWORK`` equals ``"1"``.
            runner: Callable used to execute host commands; defaults to
                ``_run``, which raises ``RuntimeError`` on a non-zero exit.
        """
        if enabled is None:
            self._enabled = os.environ.get("PYRO_VM_ENABLE_NETWORK") == "1"
        else:
            self._enabled = enabled
        self._runner = runner or self._run

    @staticmethod
    def diagnostics() -> NetworkDiagnostics:
        """Inspect the host for the prerequisites of TAP/NAT networking."""
        try:
            ip_forward = (
                Path("/proc/sys/net/ipv4/ip_forward").read_text(encoding="utf-8").strip()
                == "1"
            )
        except OSError:
            # An unreadable sysctl file is reported as forwarding disabled.
            ip_forward = False
        return NetworkDiagnostics(
            tun_available=Path("/dev/net/tun").exists(),
            ip_binary=shutil.which("ip"),
            nft_binary=shutil.which("nft"),
            iptables_binary=shutil.which("iptables"),
            ip_forward_enabled=ip_forward,
        )

    @property
    def enabled(self) -> bool:
        """Whether this manager modifies host networking."""
        return self._enabled

    def allocate(self, vm_id: str) -> NetworkConfig:
        """Compute the network assignment for ``vm_id`` and set it up if enabled.

        The subnet's third octet, the TAP name and the MAC address are all
        derived from leading characters of ``vm_id``, whose first two characters
        must be hexadecimal.

        Raises:
            ValueError: if ``vm_id`` does not start with two hex characters.
            RuntimeError: if the manager is enabled and a host prerequisite is
                missing or a setup command fails.
        """
        octet = int(vm_id[:2], 16)
        # Map the 0-255 value into the 1-200 range used for the third octet.
        third_octet = 1 + (octet % 200)
        subnet_cidr = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.0/24"
        gateway_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GATEWAY_SUFFIX}"
        guest_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GUEST_SUFFIX}"
        tap_name = f"pyro{vm_id[:8]}"
        mac_address = f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}"
        config = NetworkConfig(
            vm_id=vm_id,
            tap_name=tap_name,
            guest_ip=guest_ip,
            gateway_ip=gateway_ip,
            subnet_cidr=subnet_cidr,
            mac_address=mac_address,
        )
        if self._enabled:
            self._ensure_host_network(config)
        return config

    def cleanup(self, config: NetworkConfig) -> None:
        """Best-effort removal of the nftables table and TAP device for ``config``.

        Does nothing when the manager is disabled.
        """
        if not self._enabled:
            return
        table_name = self._nft_table_name(config.vm_id)
        self._run_ignore(["nft", "delete", "table", "ip", table_name])
        self._run_ignore(["ip", "link", "del", config.tap_name])

    def to_metadata(self, config: NetworkConfig) -> dict[str, str]:
        """Flatten ``config`` into string-only metadata entries."""
        return {
            "network_enabled": "true" if self._enabled else "false",
            "tap_name": config.tap_name,
            "guest_ip": config.guest_ip,
            "gateway_ip": config.gateway_ip,
            "subnet_cidr": config.subnet_cidr,
            "mac_address": config.mac_address,
            "dns_servers": ",".join(config.dns_servers),
        }

    def network_info(self, config: NetworkConfig) -> dict[str, object]:
        """Describe ``config`` together with whether networking is enabled."""
        return {
            "network_enabled": self._enabled,
            "tap_name": config.tap_name,
            "guest_ip": config.guest_ip,
            "gateway_ip": config.gateway_ip,
            "subnet_cidr": config.subnet_cidr,
            "mac_address": config.mac_address,
            "dns_servers": list(config.dns_servers),
            "outbound_connectivity_expected": self._enabled,
        }

    def _ensure_host_network(self, config: NetworkConfig) -> None:
        """Create the TAP device and NAT rules for ``config`` on the host.

        Prerequisites are validated before any command runs. If a setup command
        fails part-way, the resources created so far are removed again before
        the original exception is re-raised, so a failed allocation does not
        leave a TAP device or nftables table behind.

        Raises:
            RuntimeError: if a prerequisite is missing or a command fails.
        """
        diagnostics = self.diagnostics()
        if not diagnostics.tun_available:
            raise RuntimeError("/dev/net/tun is not available on this host")
        if diagnostics.ip_binary is None:
            raise RuntimeError("`ip` binary is required for TAP setup")
        if diagnostics.nft_binary is None:
            raise RuntimeError("`nft` binary is required for outbound NAT setup")
        if not diagnostics.ip_forward_enabled:
            raise RuntimeError("IPv4 forwarding is disabled on this host")

        table_name = self._nft_table_name(config.vm_id)
        # Track what this call created so rollback never removes a resource
        # whose creation command did not succeed.
        tap_created = False
        table_created = False
        try:
            self._runner(["ip", "tuntap", "add", "dev", config.tap_name, "mode", "tap"])
            tap_created = True
            self._runner(["ip", "addr", "add", f"{config.gateway_ip}/24", "dev", config.tap_name])
            self._runner(["ip", "link", "set", config.tap_name, "up"])
            self._runner(["nft", "add", "table", "ip", table_name])
            table_created = True
            self._runner(
                [
                    "nft",
                    "add",
                    "chain",
                    "ip",
                    table_name,
                    "postrouting",
                    "{",
                    "type",
                    "nat",
                    "hook",
                    "postrouting",
                    "priority",
                    "srcnat",
                    ";",
                    "}",
                ]
            )
            self._runner(
                [
                    "nft",
                    "add",
                    "rule",
                    "ip",
                    table_name,
                    "postrouting",
                    "ip",
                    "saddr",
                    config.subnet_cidr,
                    "masquerade",
                ]
            )
        except Exception:
            # Undo partial setup in reverse order, then surface the failure.
            if table_created:
                self._run_ignore(["nft", "delete", "table", "ip", table_name])
            if tap_created:
                self._run_ignore(["ip", "link", "del", config.tap_name])
            raise

    def _run_ignore(self, command: list[str]) -> None:
        """Run ``command`` and swallow failures (used for best-effort teardown).

        ``OSError`` is ignored as well as ``RuntimeError`` so that a missing
        binary cannot abort cleanup.
        """
        try:
            self._runner(command)
        except (RuntimeError, OSError):
            return

    @staticmethod
    def _nft_table_name(vm_id: str) -> str:
        """Return the nftables table name built from the first 12 id characters."""
        return f"pyro_{vm_id[:12]}"

    @staticmethod
    def _run(command: list[str]) -> subprocess.CompletedProcess[str]:
        """Execute ``command`` and raise ``RuntimeError`` on a non-zero exit.

        The error message carries stderr, falling back to stdout when stderr is
        empty.
        """
        completed = subprocess.run(command, text=True, capture_output=True, check=False)
        if completed.returncode != 0:
            stderr = completed.stderr.strip() or completed.stdout.strip()
            raise RuntimeError(f"command {' '.join(command)!r} failed: {stderr}")
        return completed