Add runtime capability scaffolding and align docs
This commit is contained in:
parent
fb8b985049
commit
cbf212bb7b
19 changed files with 1048 additions and 71 deletions
|
|
@ -7,6 +7,18 @@ from typing import Any
|
|||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
|
||||
INTERNET_PROBE_COMMAND = (
|
||||
'python3 -c "import urllib.request; '
|
||||
"print(urllib.request.urlopen('https://example.com', timeout=10).status)"
|
||||
'"'
|
||||
)
|
||||
|
||||
|
||||
def _demo_command(status: dict[str, Any]) -> str:
|
||||
if bool(status.get("network_enabled")) and status.get("execution_mode") == "guest_vsock":
|
||||
return INTERNET_PROBE_COMMAND
|
||||
return "git --version"
|
||||
|
||||
|
||||
def run_demo() -> dict[str, Any]:
|
||||
"""Create/start/exec/delete a VM and return command output."""
|
||||
|
|
@ -14,7 +26,8 @@ def run_demo() -> dict[str, Any]:
|
|||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
executed = manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
||||
status = manager.status_vm(vm_id)
|
||||
executed = manager.exec_vm(vm_id, command=_demo_command(status), timeout_seconds=30)
|
||||
return executed
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -16,6 +16,12 @@ __all__ = ["VmManager", "run_ollama_tool_demo"]
|
|||
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
|
||||
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
|
||||
MAX_TOOL_ROUNDS: Final[int] = 12
|
||||
# Directory (relative to the command's working directory) the proof clone goes into.
CLONE_TARGET_DIR: Final[str] = "hello-world"
# Removes any previous clone, shallow-clones a public repository and prints
# "true" when the result is a git work tree. Built from CLONE_TARGET_DIR so the
# directory name is defined in exactly one place.
NETWORK_PROOF_COMMAND: Final[str] = (
    f"rm -rf {CLONE_TARGET_DIR} "
    "&& git clone --depth 1 https://github.com/octocat/Hello-World.git "
    f"{CLONE_TARGET_DIR} >/dev/null "
    f"&& git -C {CLONE_TARGET_DIR} rev-parse --is-inside-work-tree"
)
|
||||
|
||||
TOOL_SPECS: Final[list[dict[str, Any]]] = [
|
||||
{
|
||||
|
|
@ -210,7 +216,7 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
|||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
||||
return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
|
||||
|
||||
|
||||
def _is_vm_id_placeholder(value: str) -> bool:
|
||||
|
|
@ -240,8 +246,10 @@ def _normalize_tool_arguments(
|
|||
return normalized_arguments, last_created_vm_id
|
||||
|
||||
|
||||
def _summarize_message_for_log(message: dict[str, Any]) -> str:
|
||||
def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
|
||||
role = str(message.get("role", "unknown"))
|
||||
if not verbose:
|
||||
return role
|
||||
content = str(message.get("content") or "").strip()
|
||||
if content == "":
|
||||
return f"{role}: <empty>"
|
||||
|
|
@ -257,6 +265,7 @@ def run_ollama_tool_demo(
|
|||
model: str = DEFAULT_OLLAMA_MODEL,
|
||||
*,
|
||||
strict: bool = True,
|
||||
verbose: bool = False,
|
||||
log: Callable[[str], None] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
|
||||
|
|
@ -267,10 +276,12 @@ def run_ollama_tool_demo(
|
|||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
|
||||
"Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
|
||||
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
|
||||
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
|
||||
"returned by vm_create into vm_start/vm_exec.\n"
|
||||
f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
|
||||
f"Success means the clone completes and the command prints `true`.\n"
|
||||
"If a tool returns an error, fix arguments and retry."
|
||||
),
|
||||
}
|
||||
|
|
@ -280,7 +291,7 @@ def run_ollama_tool_demo(
|
|||
last_created_vm_id: str | None = None
|
||||
|
||||
for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
|
||||
emit(f"[model] input {_summarize_message_for_log(messages[-1])}")
|
||||
emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}")
|
||||
response = _post_chat_completion(
|
||||
base_url,
|
||||
{
|
||||
|
|
@ -292,7 +303,7 @@ def run_ollama_tool_demo(
|
|||
},
|
||||
)
|
||||
assistant_message = _extract_message(response)
|
||||
emit(f"[model] output {_summarize_message_for_log(assistant_message)}")
|
||||
emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}")
|
||||
tool_calls = assistant_message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
final_response = str(assistant_message.get("content") or "")
|
||||
|
|
@ -320,15 +331,24 @@ def run_ollama_tool_demo(
|
|||
if not isinstance(tool_name, str):
|
||||
raise RuntimeError("tool call function name is invalid")
|
||||
arguments = _parse_tool_arguments(function.get("arguments"))
|
||||
emit(f"[model] tool_call {tool_name} args={arguments}")
|
||||
if verbose:
|
||||
emit(f"[model] tool_call {tool_name} args={arguments}")
|
||||
else:
|
||||
emit(f"[model] tool_call {tool_name}")
|
||||
arguments, normalized_vm_id = _normalize_tool_arguments(
|
||||
tool_name,
|
||||
arguments,
|
||||
last_created_vm_id=last_created_vm_id,
|
||||
)
|
||||
if normalized_vm_id is not None:
|
||||
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
|
||||
emit(f"[tool] calling {tool_name} with args={arguments}")
|
||||
if verbose:
|
||||
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
|
||||
else:
|
||||
emit("[tool] resolved vm_id placeholder")
|
||||
if verbose:
|
||||
emit(f"[tool] calling {tool_name} with args={arguments}")
|
||||
else:
|
||||
emit(f"[tool] calling {tool_name}")
|
||||
try:
|
||||
result = _dispatch_tool_call(manager, tool_name, arguments)
|
||||
success = True
|
||||
|
|
@ -341,7 +361,10 @@ def run_ollama_tool_demo(
|
|||
result = _format_tool_error(tool_name, arguments, exc)
|
||||
success = False
|
||||
emit(f"[tool] {tool_name} failed: {exc}")
|
||||
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
|
||||
if verbose:
|
||||
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
|
||||
else:
|
||||
emit(f"[tool] result {tool_name}")
|
||||
tool_events.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
|
@ -379,7 +402,7 @@ def run_ollama_tool_demo(
|
|||
tool_events.append(
|
||||
{
|
||||
"tool_name": "vm_exec_fallback",
|
||||
"arguments": {"command": "git --version"},
|
||||
"arguments": {"command": NETWORK_PROOF_COMMAND},
|
||||
"result": exec_result,
|
||||
"success": True,
|
||||
}
|
||||
|
|
@ -390,13 +413,13 @@ def run_ollama_tool_demo(
|
|||
raise RuntimeError("vm_exec result shape is invalid")
|
||||
if int(exec_result.get("exit_code", -1)) != 0:
|
||||
raise RuntimeError("vm_exec failed; expected exit_code=0")
|
||||
if "git version" not in str(exec_result.get("stdout", "")):
|
||||
raise RuntimeError("vm_exec output did not contain `git version`")
|
||||
if str(exec_result.get("stdout", "")).strip() != "true":
|
||||
raise RuntimeError("vm_exec output did not confirm repository clone success")
|
||||
emit("[done] command execution succeeded")
|
||||
|
||||
return {
|
||||
"model": model,
|
||||
"command": "git --version",
|
||||
"command": NETWORK_PROOF_COMMAND,
|
||||
"exec_result": exec_result,
|
||||
"tool_events": tool_events,
|
||||
"final_response": final_response,
|
||||
|
|
@ -408,6 +431,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.")
|
||||
parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
|
||||
parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
|
||||
parser.add_argument("-v", "--verbose", action="store_true")
|
||||
return parser
|
||||
|
||||
|
||||
|
|
@ -418,6 +442,7 @@ def main() -> None:
|
|||
result = run_ollama_tool_demo(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
verbose=args.verbose,
|
||||
log=lambda message: print(message, flush=True),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
|
|
@ -428,7 +453,9 @@ def main() -> None:
|
|||
raise RuntimeError("demo produced invalid execution result")
|
||||
print(
|
||||
f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
|
||||
f"fallback_used={bool(result.get('fallback_used'))}",
|
||||
f"fallback_used={bool(result.get('fallback_used'))} "
|
||||
f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
|
||||
flush=True,
|
||||
)
|
||||
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
|
||||
if args.verbose:
|
||||
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from dataclasses import dataclass
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pyro_mcp.vm_network import TapNetworkManager
|
||||
|
||||
DEFAULT_PLATFORM = "linux-x86_64"
|
||||
|
||||
|
||||
|
|
@ -27,6 +29,16 @@ class RuntimePaths:
|
|||
manifest: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuntimeCapabilities:
    """Feature flags inferred from the bundled runtime.

    Instances are produced by ``runtime_capabilities``, which reads the
    ``capabilities`` mapping of the runtime manifest and reports every flag as
    False when the firecracker binary is the bundled shim or the manifest has
    no such mapping.
    """

    # Manifest key "vm_boot".
    supports_vm_boot: bool
    # Manifest key "guest_exec".
    supports_guest_exec: bool
    # Manifest key "guest_network".
    supports_guest_network: bool
    # Human-readable explanation when VM boot is unavailable; None otherwise.
    reason: str | None = None
|
||||
|
||||
|
||||
def _sha256(path: Path) -> str:
|
||||
digest = hashlib.sha256()
|
||||
with path.open("rb") as fp:
|
||||
|
|
@ -135,6 +147,40 @@ def resolve_runtime_paths(
|
|||
)
|
||||
|
||||
|
||||
def runtime_capabilities(paths: RuntimePaths) -> RuntimeCapabilities:
    """Infer what the current bundled runtime can actually do.

    Args:
        paths: Resolved runtime bundle; ``firecracker_bin`` and ``manifest``
            are the only attributes consulted.

    Returns:
        All-False capabilities (with a ``reason``) when the firecracker binary
        is the bundled shim or the manifest has no ``capabilities`` mapping;
        otherwise the truthiness of the manifest's ``vm_boot``, ``guest_exec``
        and ``guest_network`` entries.

    Raises:
        OSError: If the firecracker binary cannot be read.
    """
    # The marker is plain ASCII, so search the raw bytes: this avoids decoding
    # a potentially large executable as text and cannot produce a match that
    # only exists after undecodable bytes were dropped.
    if b"bundled firecracker shim" in paths.firecracker_bin.read_bytes():
        return RuntimeCapabilities(
            supports_vm_boot=False,
            supports_guest_exec=False,
            supports_guest_network=False,
            reason="bundled runtime uses shim firecracker/jailer binaries",
        )

    declared = paths.manifest.get("capabilities")
    if not isinstance(declared, dict):
        return RuntimeCapabilities(
            supports_vm_boot=False,
            supports_guest_exec=False,
            supports_guest_network=False,
            reason="runtime manifest does not declare guest boot/exec/network capabilities",
        )

    supports_vm_boot = bool(declared.get("vm_boot"))
    return RuntimeCapabilities(
        supports_vm_boot=supports_vm_boot,
        supports_guest_exec=bool(declared.get("guest_exec")),
        supports_guest_network=bool(declared.get("guest_network")),
        # Only a missing boot capability is explained; the other flags speak for themselves.
        reason=(
            None
            if supports_vm_boot
            else "runtime manifest does not advertise real VM boot support"
        ),
    )
|
||||
|
||||
|
||||
def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
|
||||
"""Build a runtime diagnostics report."""
|
||||
report: dict[str, Any] = {
|
||||
|
|
@ -146,13 +192,28 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
|
|||
"readable": os.access("/dev/kvm", os.R_OK),
|
||||
"writable": os.access("/dev/kvm", os.W_OK),
|
||||
},
|
||||
"networking": {
|
||||
"enabled_by_default": TapNetworkManager().enabled,
|
||||
},
|
||||
}
|
||||
network = TapNetworkManager.diagnostics()
|
||||
report["networking"].update(
|
||||
{
|
||||
"tun_available": network.tun_available,
|
||||
"ip_binary": network.ip_binary,
|
||||
"nft_binary": network.nft_binary,
|
||||
"iptables_binary": network.iptables_binary,
|
||||
"ip_forward_enabled": network.ip_forward_enabled,
|
||||
}
|
||||
)
|
||||
try:
|
||||
paths = resolve_runtime_paths(platform=platform, verify_checksums=True)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
report["issues"] = [str(exc)]
|
||||
return report
|
||||
|
||||
capabilities = runtime_capabilities(paths)
|
||||
|
||||
profiles = paths.manifest.get("profiles", {})
|
||||
profile_names = sorted(profiles.keys()) if isinstance(profiles, dict) else []
|
||||
report["runtime_ok"] = True
|
||||
|
|
@ -165,6 +226,12 @@ def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
|
|||
"notice_path": str(paths.notice_path),
|
||||
"bundle_version": paths.manifest.get("bundle_version"),
|
||||
"profiles": profile_names,
|
||||
"capabilities": {
|
||||
"supports_vm_boot": capabilities.supports_vm_boot,
|
||||
"supports_guest_exec": capabilities.supports_guest_exec,
|
||||
"supports_guest_network": capabilities.supports_guest_network,
|
||||
"reason": capabilities.reason,
|
||||
},
|
||||
}
|
||||
if not report["kvm"]["exists"]:
|
||||
report["issues"] = ["/dev/kvm is not available on this host"]
|
||||
|
|
|
|||
|
|
@ -59,6 +59,11 @@ def create_server(manager: VmManager | None = None) -> FastMCP:
|
|||
"""Get the current state and metadata for a VM."""
|
||||
return vm_manager.status_vm(vm_id)
|
||||
|
||||
# Tool wrapper around VmManager.network_info_vm; the manager's result is
# returned unchanged.
@server.tool()
async def vm_network_info(vm_id: str) -> dict[str, Any]:
    """Get the current network configuration assigned to a VM."""
    return vm_manager.network_info_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_reap_expired() -> dict[str, Any]:
|
||||
"""Delete VMs whose TTL has expired."""
|
||||
|
|
|
|||
109
src/pyro_mcp/vm_firecracker.py
Normal file
109
src/pyro_mcp/vm_firecracker.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Firecracker launch-plan generation for pyro VMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Protocol
|
||||
|
||||
# Base value added to the number derived from a VM id when computing its guest CID.
DEFAULT_GUEST_CID_OFFSET = 10_000
# Port recorded in the launch plan and in guest-exec.json for the guest exec transport.
DEFAULT_VSOCK_PORT = 5005
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FirecrackerLaunchPlan:
    """On-disk config artifacts needed to boot a guest.

    Produced by ``build_launch_plan``; every path lives in the instance workdir.
    """

    # "<workdir>/firecracker.sock"; build_launch_plan only computes this path.
    api_socket_path: Path
    # JSON file holding the same content as ``config``.
    config_path: Path
    # JSON file with guest IP, gateway, subnet, DNS servers and TAP name
    # (null/empty values when the instance has no network).
    guest_network_path: Path
    # JSON file with the exec transport name, guest CID and port.
    guest_exec_path: Path
    # CID derived from the VM id.
    guest_cid: int
    # Port recorded for the guest exec transport.
    vsock_port: int
    # In-memory copy of the configuration written to ``config_path``.
    config: dict[str, Any]
|
||||
|
||||
|
||||
class VmInstanceLike(Protocol):
    """Structural type covering the instance attributes ``build_launch_plan`` reads."""

    # The leading four characters are parsed as hexadecimal to derive the guest CID.
    vm_id: str
    # Written to the "machine-config" section.
    vcpu_count: int
    mem_mib: int
    # Directory that receives the generated JSON files and socket paths.
    workdir: Path
    # Must contain "kernel_image" and "rootfs_image".
    metadata: dict[str, str]
    # None, or an object exposing mac_address, tap_name, guest_ip, gateway_ip,
    # subnet_cidr and dns_servers.
    network: Any
|
||||
|
||||
|
||||
def _guest_cid(vm_id: str) -> int:
|
||||
return DEFAULT_GUEST_CID_OFFSET + int(vm_id[:4], 16)
|
||||
|
||||
|
||||
def build_launch_plan(instance: VmInstanceLike) -> FirecrackerLaunchPlan:
    """Write the Firecracker and guest config files for *instance* and describe them.

    Three JSON documents are written into ``instance.workdir``:
    ``firecracker-config.json``, ``guest-network.json`` and ``guest-exec.json``.

    Args:
        instance: VM whose ``metadata`` already holds ``kernel_image`` and
            ``rootfs_image``.

    Returns:
        A plan holding the generated paths, the guest CID/port and the
        in-memory Firecracker configuration.
    """
    workdir = instance.workdir
    cid = _guest_cid(instance.vm_id)
    port = DEFAULT_VSOCK_PORT

    config: dict[str, Any] = {
        "boot-source": {
            "kernel_image_path": instance.metadata["kernel_image"],
            "boot_args": "console=ttyS0 reboot=k panic=1 pci=off",
        },
        "drives": [
            {
                "drive_id": "rootfs",
                "path_on_host": instance.metadata["rootfs_image"],
                "is_root_device": True,
                "is_read_only": False,
            }
        ],
        "machine-config": {
            "vcpu_count": instance.vcpu_count,
            "mem_size_mib": instance.mem_mib,
            "smt": False,
        },
        "network-interfaces": [],
        "vsock": {
            "guest_cid": cid,
            "uds_path": str(workdir / "vsock.sock"),
        },
    }

    net = instance.network
    if net is None:
        guest_network: dict[str, Any] = {
            "guest_ip": None,
            "gateway_ip": None,
            "subnet_cidr": None,
            "dns_servers": [],
            "tap_name": None,
        }
    else:
        # NOTE(review): the interface is attached whenever a network config
        # object exists; confirm the TAP device is really present on the host
        # when host networking is disabled.
        config["network-interfaces"] = [
            {
                "iface_id": "eth0",
                "guest_mac": net.mac_address,
                "host_dev_name": net.tap_name,
            }
        ]
        guest_network = {
            "guest_ip": net.guest_ip,
            "gateway_ip": net.gateway_ip,
            "subnet_cidr": net.subnet_cidr,
            "dns_servers": list(net.dns_servers),
            "tap_name": net.tap_name,
        }

    guest_exec = {"transport": "vsock", "guest_cid": cid, "port": port}

    plan = FirecrackerLaunchPlan(
        api_socket_path=workdir / "firecracker.sock",
        config_path=workdir / "firecracker-config.json",
        guest_network_path=workdir / "guest-network.json",
        guest_exec_path=workdir / "guest-exec.json",
        guest_cid=cid,
        vsock_port=port,
        config=config,
    )

    # Same serialization for all three documents: sorted keys, two-space indent.
    documents = (
        (plan.config_path, config),
        (plan.guest_network_path, guest_network),
        (plan.guest_exec_path, guest_exec),
    )
    for target, document in documents:
        target.write_text(json.dumps(document, indent=2, sort_keys=True), encoding="utf-8")

    return plan
|
||||
72
src/pyro_mcp/vm_guest.py
Normal file
72
src/pyro_mcp/vm_guest.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
"""Guest command transport over vsock-compatible JSON protocol."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Protocol
|
||||
|
||||
|
||||
class SocketLike(Protocol):
    """Subset of the socket interface that ``VsockExecClient.exec`` relies on."""

    def settimeout(self, timeout: int) -> None: ...

    # The client connects with a (guest_cid, port) pair.
    def connect(self, address: tuple[int, int]) -> None: ...

    def sendall(self, data: bytes) -> None: ...

    # The client treats an empty bytes result as end of the response.
    def recv(self, size: int) -> bytes: ...

    def close(self) -> None: ...
|
||||
|
||||
|
||||
SocketFactory = Callable[[int, int], SocketLike]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GuestExecResponse:
    """Result of one guest command as parsed from the agent's JSON reply."""

    # JSON key "stdout"; empty string when the reply omits it.
    stdout: str
    # JSON key "stderr"; empty string when the reply omits it.
    stderr: str
    # JSON key "exit_code"; -1 when the reply omits it.
    exit_code: int
    # JSON key "duration_ms"; 0 when the reply omits it.
    duration_ms: int
|
||||
|
||||
|
||||
class VsockExecClient:
    """Minimal JSON-over-stream client for a guest exec agent.

    One request is sent per connection as a single newline-terminated JSON
    object; the reply is everything the peer sends before closing the stream.
    """

    def __init__(self, socket_factory: SocketFactory | None = None) -> None:
        """Store the socket factory; ``socket.socket`` is used when none is given."""
        self._socket_factory = socket_factory or socket.socket

    def exec(
        self, guest_cid: int, port: int, command: str, timeout_seconds: int
    ) -> GuestExecResponse:
        """Run *command* through the guest agent and return its parsed reply.

        Args:
            guest_cid: vsock context id of the guest.
            port: vsock port the agent listens on.
            command: Command string forwarded verbatim to the agent.
            timeout_seconds: Sent to the agent and also applied as the socket timeout.

        Raises:
            RuntimeError: If vsock is unsupported by this Python runtime, or the
                reply is empty, not valid UTF-8 JSON, not a JSON object, or
                carries non-numeric ``exit_code``/``duration_ms`` fields.
            OSError: Propagated from the socket on connect/send/receive failure.
        """
        family = getattr(socket, "AF_VSOCK", None)
        if family is None:
            raise RuntimeError("vsock sockets are not supported on this host Python runtime")

        request = {
            "command": command,
            "timeout_seconds": timeout_seconds,
        }
        sock = self._socket_factory(family, socket.SOCK_STREAM)
        try:
            sock.settimeout(timeout_seconds)
            sock.connect((guest_cid, port))
            sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
            chunks: list[bytes] = []
            # An empty read means the peer closed the stream: the reply is complete.
            while data := sock.recv(65536):
                chunks.append(data)
        finally:
            sock.close()

        raw = b"".join(chunks)
        if not raw.strip():
            raise RuntimeError("guest exec agent closed the connection without a response")
        try:
            payload = json.loads(raw.decode("utf-8"))
        except ValueError as exc:
            # UnicodeDecodeError and json.JSONDecodeError both derive from
            # ValueError; report them like the other protocol violations.
            raise RuntimeError(f"guest exec response is not valid JSON: {exc}") from exc
        if not isinstance(payload, dict):
            raise RuntimeError("guest exec response must be a JSON object")
        try:
            exit_code = int(payload.get("exit_code", -1))
            duration_ms = int(payload.get("duration_ms", 0))
        except (TypeError, ValueError) as exc:
            raise RuntimeError(f"guest exec response has invalid numeric fields: {exc}") from exc
        return GuestExecResponse(
            stdout=str(payload.get("stdout", "")),
            stderr=str(payload.get("stderr", "")),
            exit_code=exit_code,
            duration_ms=duration_ms,
        )
|
||||
|
|
@ -12,7 +12,15 @@ from dataclasses import dataclass, field
|
|||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
|
||||
from pyro_mcp.runtime import (
|
||||
RuntimeCapabilities,
|
||||
RuntimePaths,
|
||||
resolve_runtime_paths,
|
||||
runtime_capabilities,
|
||||
)
|
||||
from pyro_mcp.vm_firecracker import build_launch_plan
|
||||
from pyro_mcp.vm_guest import VsockExecClient
|
||||
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
|
||||
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
|
||||
|
||||
VmState = Literal["created", "started", "stopped"]
|
||||
|
|
@ -34,6 +42,7 @@ class VmInstance:
|
|||
firecracker_pid: int | None = None
|
||||
last_error: str | None = None
|
||||
metadata: dict[str, str] = field(default_factory=dict)
|
||||
network: NetworkConfig | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
|
@ -119,10 +128,21 @@ class MockBackend(VmBackend):
|
|||
class FirecrackerBackend(VmBackend): # pragma: no cover
|
||||
"""Host-gated backend that validates Firecracker prerequisites."""
|
||||
|
||||
def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
artifacts_dir: Path,
|
||||
firecracker_bin: Path,
|
||||
jailer_bin: Path,
|
||||
runtime_capabilities: RuntimeCapabilities,
|
||||
network_manager: TapNetworkManager | None = None,
|
||||
guest_exec_client: VsockExecClient | None = None,
|
||||
) -> None:
|
||||
self._artifacts_dir = artifacts_dir
|
||||
self._firecracker_bin = firecracker_bin
|
||||
self._jailer_bin = jailer_bin
|
||||
self._runtime_capabilities = runtime_capabilities
|
||||
self._network_manager = network_manager or TapNetworkManager()
|
||||
self._guest_exec_client = guest_exec_client or VsockExecClient()
|
||||
if not self._firecracker_bin.exists():
|
||||
raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
|
||||
if not self._jailer_bin.exists():
|
||||
|
|
@ -132,16 +152,29 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
|
|||
|
||||
def create(self, instance: VmInstance) -> None:
    """Create the instance workdir, record its boot artifacts and allocate networking.

    The workdir must not exist yet. If anything fails after it was created,
    the directory is removed again and the original exception propagates.

    Raises:
        RuntimeError: When the profile's kernel or rootfs image is missing.
    """
    workdir = instance.workdir
    workdir.mkdir(parents=True, exist_ok=False)
    try:
        resolved = resolve_artifacts(self._artifacts_dir, instance.profile)
        kernel = resolved.kernel_image
        rootfs = resolved.rootfs_image
        if not (kernel.exists() and rootfs.exists()):
            raise RuntimeError(
                f"missing profile artifacts for {instance.profile}; expected "
                f"{kernel} and {rootfs}"
            )
        instance.metadata["kernel_image"] = str(kernel)
        instance.metadata["rootfs_image"] = str(rootfs)

        allocated = self._network_manager.allocate(instance.vm_id)
        instance.network = allocated
        instance.metadata.update(self._network_manager.to_metadata(allocated))
    except Exception:
        # Leave no half-initialized workdir behind.
        shutil.rmtree(workdir, ignore_errors=True)
        raise
|
||||
|
||||
def start(self, instance: VmInstance) -> None:
|
||||
launch_plan = build_launch_plan(instance)
|
||||
instance.metadata["firecracker_config_path"] = str(launch_plan.config_path)
|
||||
instance.metadata["guest_network_path"] = str(launch_plan.guest_network_path)
|
||||
instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path)
|
||||
instance.metadata["guest_cid"] = str(launch_plan.guest_cid)
|
||||
instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port)
|
||||
proc = subprocess.run( # noqa: S603
|
||||
[str(self._firecracker_bin), "--version"],
|
||||
text=True,
|
||||
|
|
@ -152,15 +185,35 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
|
|||
raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
|
||||
instance.metadata["firecracker_version"] = proc.stdout.strip()
|
||||
instance.metadata["jailer_path"] = str(self._jailer_bin)
|
||||
if not self._runtime_capabilities.supports_vm_boot:
|
||||
instance.metadata["execution_mode"] = "host_compat"
|
||||
instance.metadata["boot_mode"] = "shim"
|
||||
if self._runtime_capabilities.reason is not None:
|
||||
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
|
||||
return
|
||||
instance.metadata["execution_mode"] = "guest_vsock"
|
||||
instance.metadata["boot_mode"] = "native"
|
||||
|
||||
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
    """Run *command* for *instance*, inside the guest when possible.

    The guest (vsock) path is taken only when the runtime supports both
    booting a VM and guest exec. Otherwise the command runs through the
    host compatibility path and ``execution_mode`` is recorded as
    ``"host_compat"`` in the instance metadata.

    Args:
        instance: Started VM; on the guest path its metadata must hold
            ``guest_cid`` and ``guest_exec_port``.
        command: Command to execute.
        timeout_seconds: Execution timeout forwarded to the chosen transport.
    """
    capabilities = self._runtime_capabilities
    # start() marks the instance "host_compat" and returns early when VM boot
    # is unsupported, so guest exec alone is not enough to use the vsock path.
    if capabilities.supports_vm_boot and capabilities.supports_guest_exec:
        guest_cid = int(instance.metadata["guest_cid"])
        port = int(instance.metadata["guest_exec_port"])
        response = self._guest_exec_client.exec(guest_cid, port, command, timeout_seconds)
        return VmExecResult(
            stdout=response.stdout,
            stderr=response.stderr,
            exit_code=response.exit_code,
            duration_ms=response.duration_ms,
        )
    # Temporary compatibility path until guest-side execution agent is integrated.
    instance.metadata["execution_mode"] = "host_compat"
    return _run_host_command(instance.workdir, command, timeout_seconds)
|
||||
|
||||
def stop(self, instance: VmInstance) -> None:
    """Accept a stop request; this backend performs no action for it."""
    # Only the local reference is dropped.
    del instance
|
||||
|
||||
def delete(self, instance: VmInstance) -> None:
    """Release the instance's network resources and remove its workdir.

    The workdir is removed even when network cleanup raises; the cleanup
    error still propagates to the caller afterwards.
    """
    try:
        if instance.network is not None:
            self._network_manager.cleanup(instance.network)
    finally:
        shutil.rmtree(instance.workdir, ignore_errors=True)
|
||||
|
||||
|
||||
|
|
@ -182,6 +235,7 @@ class VmManager:
|
|||
artifacts_dir: Path | None = None,
|
||||
max_active_vms: int = 4,
|
||||
runtime_paths: RuntimePaths | None = None,
|
||||
network_manager: TapNetworkManager | None = None,
|
||||
) -> None:
|
||||
self._backend_name = backend_name or "firecracker"
|
||||
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
|
||||
|
|
@ -189,11 +243,19 @@ class VmManager:
|
|||
if self._backend_name == "firecracker":
|
||||
self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
|
||||
self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
|
||||
self._runtime_capabilities = runtime_capabilities(self._runtime_paths)
|
||||
else:
|
||||
self._artifacts_dir = artifacts_dir or Path(
|
||||
os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
|
||||
)
|
||||
self._runtime_capabilities = RuntimeCapabilities(
|
||||
supports_vm_boot=False,
|
||||
supports_guest_exec=False,
|
||||
supports_guest_network=False,
|
||||
reason="mock backend does not boot a guest",
|
||||
)
|
||||
self._max_active_vms = max_active_vms
|
||||
self._network_manager = network_manager or TapNetworkManager()
|
||||
self._lock = threading.Lock()
|
||||
self._instances: dict[str, VmInstance] = {}
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -209,6 +271,8 @@ class VmManager:
|
|||
self._artifacts_dir,
|
||||
firecracker_bin=self._runtime_paths.firecracker_bin,
|
||||
jailer_bin=self._runtime_paths.jailer_bin,
|
||||
runtime_capabilities=self._runtime_capabilities,
|
||||
network_manager=self._network_manager,
|
||||
)
|
||||
raise ValueError("invalid backend; expected one of: mock, firecracker")
|
||||
|
||||
|
|
@ -262,6 +326,7 @@ class VmManager:
|
|||
if instance.state != "started":
|
||||
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
|
||||
exec_result = self._backend.exec(instance, command, timeout_seconds)
|
||||
execution_mode = instance.metadata.get("execution_mode", "host_compat")
|
||||
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
|
||||
return {
|
||||
"vm_id": vm_id,
|
||||
|
|
@ -270,6 +335,7 @@ class VmManager:
|
|||
"stderr": exec_result.stderr,
|
||||
"exit_code": exec_result.exit_code,
|
||||
"duration_ms": exec_result.duration_ms,
|
||||
"execution_mode": execution_mode,
|
||||
"cleanup": cleanup,
|
||||
}
|
||||
|
||||
|
|
@ -296,6 +362,19 @@ class VmManager:
|
|||
self._ensure_not_expired_locked(instance, time.time())
|
||||
return self._serialize(instance)
|
||||
|
||||
def network_info_vm(self, vm_id: str) -> dict[str, Any]:
    """Report the network configuration of the VM identified by *vm_id*.

    The instance is looked up and checked for expiry while the manager lock
    is held. VMs without a network configuration get a fixed "unavailable"
    answer; otherwise the network manager's description is returned with the
    VM id added.
    """
    with self._lock:
        instance = self._get_instance_locked(vm_id)
        self._ensure_not_expired_locked(instance, time.time())
        network = instance.network
        if network is not None:
            details = self._network_manager.network_info(network)
            return {"vm_id": vm_id, **details}
        return {
            "vm_id": vm_id,
            "network_enabled": False,
            "outbound_connectivity_expected": False,
            "reason": "network configuration is unavailable for this VM",
        }
|
||||
|
||||
def reap_expired(self) -> dict[str, Any]:
|
||||
now = time.time()
|
||||
with self._lock:
|
||||
|
|
@ -331,6 +410,10 @@ class VmManager:
|
|||
"created_at": instance.created_at,
|
||||
"expires_at": instance.expires_at,
|
||||
"state": instance.state,
|
||||
"network_enabled": instance.network is not None,
|
||||
"guest_ip": instance.network.guest_ip if instance.network is not None else None,
|
||||
"tap_name": instance.network.tap_name if instance.network is not None else None,
|
||||
"execution_mode": instance.metadata.get("execution_mode", "host_compat"),
|
||||
"metadata": instance.metadata,
|
||||
}
|
||||
|
||||
|
|
|
|||
191
src/pyro_mcp/vm_network.py
Normal file
191
src/pyro_mcp/vm_network.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
"""Host-side network allocation and diagnostics for Firecracker VMs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
CommandRunner = Callable[[list[str]], subprocess.CompletedProcess[str]]
|
||||
|
||||
# First two octets of every per-VM /24; allocate() appends the third octet.
DEFAULT_NETWORK_PREFIX = "172.29."
# Host-side (gateway) address within each per-VM subnet.
DEFAULT_GATEWAY_SUFFIX = 1
# Guest address within each per-VM subnet.
DEFAULT_GUEST_SUFFIX = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NetworkConfig:
    """Addressing computed for one VM by ``TapNetworkManager.allocate``."""

    vm_id: str
    # "pyro" followed by the first eight characters of the VM id.
    tap_name: str
    # Address ending in DEFAULT_GUEST_SUFFIX inside ``subnet_cidr``.
    guest_ip: str
    # Address ending in DEFAULT_GATEWAY_SUFFIX; assigned to the TAP device on the host.
    gateway_ip: str
    # Per-VM /24 whose third octet is derived from the VM id.
    subnet_cidr: str
    # "06:00:" followed by the first four byte pairs of the VM id.
    mac_address: str
    # Resolvers written to the guest network description.
    dns_servers: tuple[str, ...] = ("1.1.1.1", "8.8.8.8")
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NetworkDiagnostics:
    """Host networking prerequisites as probed by ``TapNetworkManager.diagnostics``."""

    # Whether /dev/net/tun exists.
    tun_available: bool
    # Results of shutil.which for each tool; None when it is not on PATH.
    ip_binary: str | None
    nft_binary: str | None
    iptables_binary: str | None
    # True when /proc/sys/net/ipv4/ip_forward reads "1"; False when it is unreadable.
    ip_forward_enabled: bool
|
||||
|
||||
|
||||
class TapNetworkManager:
    """Allocate per-VM TAP/NAT configuration on the host.

    Addressing is always computed. Host devices and nftables rules are only
    created or removed when the manager is enabled, either explicitly or via
    the ``PYRO_VM_ENABLE_NETWORK=1`` environment variable.
    """

    def __init__(
        self,
        *,
        enabled: bool | None = None,
        runner: CommandRunner | None = None,
    ) -> None:
        """Configure the manager.

        Args:
            enabled: Force host networking on/off; ``None`` reads the
                ``PYRO_VM_ENABLE_NETWORK`` environment variable.
            runner: Callable executing a command list; defaults to ``_run``,
                which raises ``RuntimeError`` on a non-zero exit status.
        """
        if enabled is None:
            self._enabled = os.environ.get("PYRO_VM_ENABLE_NETWORK") == "1"
        else:
            self._enabled = enabled
        self._runner = runner or self._run

    @staticmethod
    def diagnostics() -> NetworkDiagnostics:
        """Probe the host for the prerequisites of TAP/NAT networking."""
        try:
            forwarding = Path("/proc/sys/net/ipv4/ip_forward").read_text(encoding="utf-8")
            ip_forward = forwarding.strip() == "1"
        except OSError:
            # An unreadable or missing sysctl file is reported as "disabled".
            ip_forward = False
        return NetworkDiagnostics(
            tun_available=Path("/dev/net/tun").exists(),
            ip_binary=shutil.which("ip"),
            nft_binary=shutil.which("nft"),
            iptables_binary=shutil.which("iptables"),
            ip_forward_enabled=ip_forward,
        )

    @property
    def enabled(self) -> bool:
        """Whether this manager touches host network state."""
        return self._enabled

    def allocate(self, vm_id: str) -> NetworkConfig:
        """Compute the network configuration for *vm_id* and, if enabled, apply it.

        Raises:
            ValueError: If *vm_id* does not start with two hexadecimal digits.
            RuntimeError: If enabled and a host prerequisite is missing or a
                setup command fails.
        """
        # NOTE(review): the subnet is chosen from the first id byte modulo 200,
        # so distinct VMs can share a subnet -- confirm this is acceptable.
        octet = int(vm_id[:2], 16)
        third_octet = 1 + (octet % 200)
        subnet_cidr = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.0/24"
        gateway_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GATEWAY_SUFFIX}"
        guest_ip = f"{DEFAULT_NETWORK_PREFIX}{third_octet}.{DEFAULT_GUEST_SUFFIX}"
        tap_name = f"pyro{vm_id[:8]}"
        mac_address = f"06:00:{vm_id[0:2]}:{vm_id[2:4]}:{vm_id[4:6]}:{vm_id[6:8]}"
        config = NetworkConfig(
            vm_id=vm_id,
            tap_name=tap_name,
            guest_ip=guest_ip,
            gateway_ip=gateway_ip,
            subnet_cidr=subnet_cidr,
            mac_address=mac_address,
        )
        if self._enabled:
            self._ensure_host_network(config)
        return config

    def cleanup(self, config: NetworkConfig) -> None:
        """Best-effort removal of the nftables table and TAP device for *config*."""
        if not self._enabled:
            return
        table_name = self._nft_table_name(config.vm_id)
        self._run_ignore(["nft", "delete", "table", "ip", table_name])
        self._run_ignore(["ip", "link", "del", config.tap_name])

    def to_metadata(self, config: NetworkConfig) -> dict[str, str]:
        """Flatten *config* into string-only metadata ("true"/"false" for the flag)."""
        return {
            "network_enabled": "true" if self._enabled else "false",
            "tap_name": config.tap_name,
            "guest_ip": config.guest_ip,
            "gateway_ip": config.gateway_ip,
            "subnet_cidr": config.subnet_cidr,
            "mac_address": config.mac_address,
            "dns_servers": ",".join(config.dns_servers),
        }

    def network_info(self, config: NetworkConfig) -> dict[str, object]:
        """Describe *config* with native types for API responses."""
        return {
            "network_enabled": self._enabled,
            "tap_name": config.tap_name,
            "guest_ip": config.guest_ip,
            "gateway_ip": config.gateway_ip,
            "subnet_cidr": config.subnet_cidr,
            "mac_address": config.mac_address,
            "dns_servers": list(config.dns_servers),
            "outbound_connectivity_expected": self._enabled,
        }

    def _ensure_host_network(self, config: NetworkConfig) -> None:
        """Verify host prerequisites, then create the TAP device and NAT rules."""
        diagnostics = self.diagnostics()
        if not diagnostics.tun_available:
            raise RuntimeError("/dev/net/tun is not available on this host")
        if diagnostics.ip_binary is None:
            raise RuntimeError("`ip` binary is required for TAP setup")
        if diagnostics.nft_binary is None:
            raise RuntimeError("`nft` binary is required for outbound NAT setup")
        if not diagnostics.ip_forward_enabled:
            raise RuntimeError("IPv4 forwarding is disabled on this host")
        self._configure_host_network(config)

    def _configure_host_network(self, config: NetworkConfig) -> None:
        """Run the setup commands, undoing completed steps if a later one fails."""
        table_name = self._nft_table_name(config.vm_id)
        tap_created = False
        table_created = False
        try:
            self._runner(["ip", "tuntap", "add", "dev", config.tap_name, "mode", "tap"])
            tap_created = True
            self._runner(["ip", "addr", "add", f"{config.gateway_ip}/24", "dev", config.tap_name])
            self._runner(["ip", "link", "set", config.tap_name, "up"])

            self._runner(["nft", "add", "table", "ip", table_name])
            table_created = True
            self._runner(
                [
                    "nft",
                    "add",
                    "chain",
                    "ip",
                    table_name,
                    "postrouting",
                    "{",
                    "type",
                    "nat",
                    "hook",
                    "postrouting",
                    "priority",
                    "srcnat",
                    ";",
                    "}",
                ]
            )
            self._runner(
                [
                    "nft",
                    "add",
                    "rule",
                    "ip",
                    table_name,
                    "postrouting",
                    "ip",
                    "saddr",
                    config.subnet_cidr,
                    "masquerade",
                ]
            )
        except Exception:
            # Undo only what this call created, so a failed "add" never removes
            # a device or table that was not set up here.
            if table_created:
                self._run_ignore(["nft", "delete", "table", "ip", table_name])
            if tap_created:
                self._run_ignore(["ip", "link", "del", config.tap_name])
            raise

    def _run_ignore(self, command: list[str]) -> None:
        """Run *command*, ignoring failures (used for best-effort teardown)."""
        try:
            self._runner(command)
        except (RuntimeError, OSError):
            # RuntimeError: non-zero exit reported by _run. OSError: the
            # executable could not be started at all (e.g. binary missing).
            return

    @staticmethod
    def _nft_table_name(vm_id: str) -> str:
        """Name of the per-VM nftables table: "pyro_" plus 12 id characters."""
        return f"pyro_{vm_id[:12]}"

    @staticmethod
    def _run(command: list[str]) -> subprocess.CompletedProcess[str]:
        """Execute *command*, raising ``RuntimeError`` on a non-zero exit status."""
        completed = subprocess.run(command, text=True, capture_output=True, check=False)
        if completed.returncode != 0:
            stderr = completed.stderr.strip() or completed.stdout.strip()
            raise RuntimeError(f"command {' '.join(command)!r} failed: {stderr}")
        return completed
|
||||
Loading…
Add table
Add a link
Reference in a new issue