pyro-mcp/src/pyro_mcp/vm_manager.py

359 lines
14 KiB
Python

"""Lifecycle manager for ephemeral VM environments."""
from __future__ import annotations
import os
import shutil
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
# Lifecycle states recorded on a VmInstance: new records default to "created";
# VmManager sets "started" and "stopped" as the VM is started and stopped.
VmState = Literal["created", "started", "stopped"]
@dataclass
class VmInstance:
    """Mutable in-memory record describing one managed VM.

    The manager creates one record per VM and updates ``state`` as the VM
    moves through its lifecycle; backends may add entries to ``metadata``.
    """

    # Identity and requested sizing.
    vm_id: str
    profile: str
    vcpu_count: int
    mem_mib: int

    # Lifetime bookkeeping: the manager sets ``expires_at`` to
    # ``created_at + ttl_seconds`` when it builds the record.
    ttl_seconds: int
    created_at: float
    expires_at: float

    # Per-VM working directory on the host.
    workdir: Path

    # Current lifecycle state; every new record starts as "created".
    state: VmState = "created"

    # NOTE(review): nothing visible in this module assigns these two fields;
    # presumably reserved for a real Firecracker process handle and error
    # reporting -- confirm before relying on them.
    firecracker_pid: int | None = None
    last_error: str | None = None

    # Free-form string annotations filled in by the backend; a fresh dict is
    # created for every record.
    metadata: dict[str, str] = field(default_factory=dict)
@dataclass(frozen=True)
class VmExecResult:
    """Immutable outcome of running a single command.

    Attributes are the captured standard output and standard error, the
    process exit code, and the elapsed run time in whole milliseconds.
    """

    stdout: str  # captured standard output
    stderr: str  # captured standard error
    exit_code: int  # process return code
    duration_ms: int  # elapsed time in milliseconds
def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
    """Run ``command`` through ``bash -lc`` on the host and capture its output.

    The child process runs with ``workdir`` as its current directory and a
    minimal environment containing only ``PATH`` (copied from this process)
    and ``HOME`` (set to ``workdir``).

    Args:
        workdir: Directory used as both cwd and ``HOME`` for the command.
        command: Shell command line handed to ``bash -lc``.
        timeout_seconds: Maximum run time before the command is abandoned.

    Returns:
        The captured output, exit code and elapsed milliseconds. On timeout
        the result has empty stdout, an explanatory stderr message and exit
        code 124.
    """
    t0 = time.monotonic()

    def _elapsed_ms() -> int:
        return int((time.monotonic() - t0) * 1000)

    minimal_env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
    argv = ["bash", "-lc", command]  # noqa: S607

    try:
        completed = subprocess.run(  # noqa: S603
            argv,
            cwd=workdir,
            env=minimal_env,
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
    except subprocess.TimeoutExpired:
        # Report the timeout as a regular result instead of propagating.
        return VmExecResult(
            stdout="",
            stderr=f"command timed out after {timeout_seconds}s",
            exit_code=124,
            duration_ms=_elapsed_ms(),
        )

    return VmExecResult(
        stdout=completed.stdout,
        stderr=completed.stderr,
        exit_code=completed.returncode,
        duration_ms=_elapsed_ms(),
    )
class VmBackend:
    """Interface every VM backend implements.

    Each lifecycle hook raises ``NotImplementedError`` here; concrete
    backends override all five.
    """

    def create(self, instance: VmInstance) -> None:  # pragma: no cover
        """Prepare whatever resources ``instance`` needs before it can start."""
        raise NotImplementedError

    def start(self, instance: VmInstance) -> None:  # pragma: no cover
        """Bring ``instance`` into a runnable state."""
        raise NotImplementedError

    def exec(  # pragma: no cover
        self, instance: VmInstance, command: str, timeout_seconds: int
    ) -> VmExecResult:
        """Run ``command`` for ``instance`` and return the captured result."""
        raise NotImplementedError

    def stop(self, instance: VmInstance) -> None:  # pragma: no cover
        """Halt ``instance``."""
        raise NotImplementedError

    def delete(self, instance: VmInstance) -> None:  # pragma: no cover
        """Release the resources held for ``instance``."""
        raise NotImplementedError
class MockBackend(VmBackend):
    """Host-process backend used for development and testability.

    Lifecycle transitions are recorded as marker files inside the VM's
    working directory, and commands run directly on the host.
    """

    def create(self, instance: VmInstance) -> None:
        """Create the working directory; fails if it already exists."""
        workdir = instance.workdir
        workdir.mkdir(parents=True, exist_ok=False)

    def start(self, instance: VmInstance) -> None:
        """Drop a ``.started`` marker file into the working directory."""
        (instance.workdir / ".started").write_text("started\n", encoding="utf-8")

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        """Run ``command`` on the host with the working directory as cwd."""
        result = _run_host_command(instance.workdir, command, timeout_seconds)
        return result

    def stop(self, instance: VmInstance) -> None:
        """Drop a ``.stopped`` marker file into the working directory."""
        (instance.workdir / ".stopped").write_text("stopped\n", encoding="utf-8")

    def delete(self, instance: VmInstance) -> None:
        """Remove the working directory tree, ignoring removal errors."""
        shutil.rmtree(instance.workdir, ignore_errors=True)
class FirecrackerBackend(VmBackend):  # pragma: no cover
    """Host-gated backend that validates Firecracker prerequisites.

    Construction fails fast unless the bundled ``firecracker`` and ``jailer``
    binaries exist and ``/dev/kvm`` is present. Commands are still executed
    on the host through ``_run_host_command``.
    """

    def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
        """Store the paths and verify the host prerequisites.

        Args:
            artifacts_dir: Directory the profile artifacts are resolved from.
            firecracker_bin: Path to the bundled ``firecracker`` binary.
            jailer_bin: Path to the bundled ``jailer`` binary.

        Raises:
            RuntimeError: If either binary or ``/dev/kvm`` is missing.
        """
        self._artifacts_dir = artifacts_dir
        self._firecracker_bin = firecracker_bin
        self._jailer_bin = jailer_bin
        if not self._firecracker_bin.exists():
            raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
        if not self._jailer_bin.exists():
            raise RuntimeError(f"bundled jailer binary not found at {self._jailer_bin}")
        if not Path("/dev/kvm").exists():
            raise RuntimeError("/dev/kvm is not available on this host")

    def create(self, instance: VmInstance) -> None:
        """Validate the profile artifacts, then create the working directory.

        Raises:
            RuntimeError: If the kernel or rootfs image for the profile is
                missing.
            FileExistsError: If the working directory already exists.
        """
        # Check artifacts before touching the filesystem so a failed create
        # does not leave an orphaned working directory behind.
        artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
        if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
            raise RuntimeError(
                f"missing profile artifacts for {instance.profile}; expected "
                f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
            )
        instance.workdir.mkdir(parents=True, exist_ok=False)
        instance.metadata["kernel_image"] = str(artifacts.kernel_image)
        instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)

    def start(self, instance: VmInstance) -> None:
        """Run a ``firecracker --version`` preflight and record the result.

        Raises:
            RuntimeError: If the preflight exits with a non-zero status.
        """
        proc = subprocess.run(  # noqa: S603
            [str(self._firecracker_bin), "--version"],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
        instance.metadata["firecracker_version"] = proc.stdout.strip()
        instance.metadata["jailer_path"] = str(self._jailer_bin)

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        """Run ``command`` on the host with the working directory as cwd."""
        # Temporary compatibility path until guest-side execution agent is integrated.
        return _run_host_command(instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        """Do nothing; the argument is discarded."""
        del instance

    def delete(self, instance: VmInstance) -> None:
        """Remove the working directory tree, ignoring removal errors."""
        shutil.rmtree(instance.workdir, ignore_errors=True)
class VmManager:
    """In-process lifecycle manager for ephemeral VM environments.

    Instances are tracked in an in-memory dict guarded by a single lock.
    Each VM has a TTL; expired VMs are torn down when a new VM is created,
    when an expired VM is touched by start/exec/status, or when
    :meth:`reap_expired` is called.
    """

    MIN_VCPUS = 1
    MAX_VCPUS = 8
    MIN_MEM_MIB = 256
    MAX_MEM_MIB = 32768
    MIN_TTL_SECONDS = 60
    MAX_TTL_SECONDS = 3600

    def __init__(
        self,
        *,
        backend_name: str | None = None,
        base_dir: Path | None = None,
        artifacts_dir: Path | None = None,
        max_active_vms: int = 4,
        runtime_paths: RuntimePaths | None = None,
    ) -> None:
        """Configure the manager and build its backend.

        Args:
            backend_name: ``"mock"`` or ``"firecracker"``; defaults to
                ``"firecracker"``.
            base_dir: Parent directory for per-VM working directories;
                defaults to ``/tmp/pyro-mcp`` and is created if missing.
            artifacts_dir: Artifact directory override. For the firecracker
                backend the fallback is the runtime paths' artifacts
                directory; otherwise ``PYRO_VM_ARTIFACTS_DIR`` or
                ``/opt/pyro-mcp/artifacts``.
            max_active_vms: Maximum number of tracked VMs at any time.
            runtime_paths: Pre-resolved runtime paths; resolved on demand
                for the firecracker backend when omitted.

        Raises:
            ValueError: If ``backend_name`` is not a known backend.
        """
        self._backend_name = backend_name or "firecracker"
        self._base_dir = base_dir or Path("/tmp/pyro-mcp")
        self._runtime_paths = runtime_paths
        if self._backend_name == "firecracker":
            self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
            self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
        else:
            self._artifacts_dir = artifacts_dir or Path(
                os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
            )
        self._max_active_vms = max_active_vms
        self._lock = threading.Lock()
        self._instances: dict[str, VmInstance] = {}
        self._base_dir.mkdir(parents=True, exist_ok=True)
        self._backend = self._build_backend()

    def _build_backend(self) -> VmBackend:
        """Instantiate the backend selected by ``backend_name``."""
        if self._backend_name == "mock":
            return MockBackend()
        if self._backend_name == "firecracker":
            if self._runtime_paths is None:
                raise RuntimeError("runtime paths were not initialized for firecracker backend")
            return FirecrackerBackend(
                self._artifacts_dir,
                firecracker_bin=self._runtime_paths.firecracker_bin,
                jailer_bin=self._runtime_paths.jailer_bin,
            )
        raise ValueError("invalid backend; expected one of: mock, firecracker")

    def list_profiles(self) -> list[dict[str, object]]:
        """Return the available VM profiles."""
        return list_profiles()

    def create_vm(
        self, *, profile: str, vcpu_count: int, mem_mib: int, ttl_seconds: int
    ) -> dict[str, Any]:
        """Create and register a new VM, returning its serialized record.

        Expired VMs are reaped first so they do not count against the
        active-VM limit.

        Raises:
            ValueError: If a sizing or TTL value is out of range.
            RuntimeError: If the active-VM limit has been reached.
        """
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        # Called only for validation; the returned profile is not used here.
        get_profile(profile)
        now = time.time()
        with self._lock:
            self._reap_expired_locked(now)
            active_count = len(self._instances)
            if active_count >= self._max_active_vms:
                raise RuntimeError(
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
            vm_id = uuid.uuid4().hex[:12]
            instance = VmInstance(
                vm_id=vm_id,
                profile=profile,
                vcpu_count=vcpu_count,
                mem_mib=mem_mib,
                ttl_seconds=ttl_seconds,
                created_at=now,
                expires_at=now + ttl_seconds,
                workdir=self._base_dir / vm_id,
            )
            # Register only after the backend succeeds so a failed create
            # never leaves a half-initialized record behind.
            self._backend.create(instance)
            self._instances[vm_id] = instance
            return self._serialize(instance)

    def start_vm(self, vm_id: str) -> dict[str, Any]:
        """Start a VM that is in the ``created`` or ``stopped`` state.

        Raises:
            ValueError: If the VM does not exist.
            RuntimeError: If the VM has expired or is already started.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state not in {"created", "stopped"}:
                raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
            self._backend.start(instance)
            instance.state = "started"
            return self._serialize(instance)

    def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
        """Run one command in a started VM, then delete the VM.

        Raises:
            ValueError: If ``timeout_seconds`` is not positive or the VM does
                not exist.
            RuntimeError: If the VM has expired or is not started.
        """
        if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state != "started":
                raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
        # The command runs without the lock held so a long-running command
        # does not block operations on other VMs; delete_vm re-acquires the
        # (non-reentrant) lock itself.
        exec_result = self._backend.exec(instance, command, timeout_seconds)
        cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
        return {
            "vm_id": vm_id,
            "command": command,
            "stdout": exec_result.stdout,
            "stderr": exec_result.stderr,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "cleanup": cleanup,
        }

    def stop_vm(self, vm_id: str) -> dict[str, Any]:
        """Stop a VM and return its serialized record.

        Raises:
            ValueError: If the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._backend.stop(instance)
            instance.state = "stopped"
            return self._serialize(instance)

    def delete_vm(self, vm_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
        """Stop the VM if it is running, delete it, and forget it.

        Raises:
            ValueError: If the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._destroy_locked(instance)
        return {"vm_id": vm_id, "deleted": True, "reason": reason}

    def status_vm(self, vm_id: str) -> dict[str, Any]:
        """Return the serialized record of a live VM.

        Raises:
            ValueError: If the VM does not exist.
            RuntimeError: If the VM has expired (it is deleted as a result).
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            return self._serialize(instance)

    def reap_expired(self) -> dict[str, Any]:
        """Delete every VM whose TTL has elapsed and report which were removed."""
        now = time.time()
        with self._lock:
            expired_vm_ids = self._reap_expired_locked(now)
        return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}

    def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
        """Raise ``ValueError`` if any requested value is outside its class limits."""
        if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
            raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
        if not self.MIN_MEM_MIB <= mem_mib <= self.MAX_MEM_MIB:
            raise ValueError(f"mem_mib must be between {self.MIN_MEM_MIB} and {self.MAX_MEM_MIB}")
        if not self.MIN_TTL_SECONDS <= ttl_seconds <= self.MAX_TTL_SECONDS:
            raise ValueError(
                f"ttl_seconds must be between {self.MIN_TTL_SECONDS} and {self.MAX_TTL_SECONDS}"
            )

    def _serialize(self, instance: VmInstance) -> dict[str, Any]:
        """Return a plain-dict snapshot of ``instance`` for callers."""
        return {
            "vm_id": instance.vm_id,
            "profile": instance.profile,
            "vcpu_count": instance.vcpu_count,
            "mem_mib": instance.mem_mib,
            "ttl_seconds": instance.ttl_seconds,
            "created_at": instance.created_at,
            "expires_at": instance.expires_at,
            "state": instance.state,
            # Copy so callers cannot mutate the manager's record through the
            # returned snapshot.
            "metadata": dict(instance.metadata),
        }

    def _get_instance_locked(self, vm_id: str) -> VmInstance:
        """Look up a VM by id; the caller must hold the lock.

        Raises:
            ValueError: If no VM with that id is tracked.
        """
        try:
            return self._instances[vm_id]
        except KeyError as exc:
            raise ValueError(f"vm {vm_id!r} does not exist") from exc

    def _destroy_locked(self, instance: VmInstance) -> None:
        """Stop (if started), delete, and unregister ``instance``.

        The caller must hold the lock. Shared by explicit deletion and TTL
        reaping so both follow the same teardown sequence.
        """
        if instance.state == "started":
            self._backend.stop(instance)
            instance.state = "stopped"
        self._backend.delete(instance)
        del self._instances[instance.vm_id]

    def _reap_expired_locked(self, now: float) -> list[str]:
        """Tear down every VM with ``expires_at <= now``; caller holds the lock.

        Returns:
            The ids of the VMs that were removed, in registration order.
        """
        # Snapshot the ids first: the dict is mutated while tearing down.
        expired_vm_ids = [
            vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
        ]
        for vm_id in expired_vm_ids:
            self._destroy_locked(self._instances[vm_id])
        return expired_vm_ids

    def _ensure_not_expired_locked(self, instance: VmInstance, now: float) -> None:
        """Raise ``RuntimeError`` if ``instance`` has expired, reaping first.

        The caller must hold the lock. All expired VMs are reaped, not only
        the one that was checked.
        """
        if instance.expires_at <= now:
            vm_id = instance.vm_id
            self._reap_expired_locked(now)
            raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")