359 lines
14 KiB
Python
359 lines
14 KiB
Python
"""Lifecycle manager for ephemeral VM environments."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import shutil
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
|
|
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
|
|
|
|
# Lifecycle states stored on ``VmInstance.state``.
VmState = Literal["created", "started", "stopped"]
|
|
|
|
|
|
@dataclass
class VmInstance:
    """Mutable in-memory bookkeeping record for a single VM.

    The fields without defaults are supplied when the record is built; the
    fields with defaults start at the values shown below.
    """

    # Identity and requested resources.
    vm_id: str
    profile: str
    vcpu_count: int
    mem_mib: int
    # Lifetime: TTL in seconds plus absolute timestamps from time.time().
    ttl_seconds: int
    created_at: float
    expires_at: float
    # Host directory that backs this VM.
    workdir: Path
    # Current lifecycle state.
    state: VmState = field(default="created")
    # NOTE(review): not written anywhere in this module; presumably reserved
    # for a real Firecracker process id / failure message.
    firecracker_pid: int | None = field(default=None)
    last_error: str | None = field(default=None)
    # Free-form string annotations written by backends (e.g. artifact paths).
    metadata: dict[str, str] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass(frozen=True)
class VmExecResult:
    """Immutable outcome of running one command.

    ``exit_code`` is the process return code (the host runner reports 124 on
    timeout) and ``duration_ms`` is elapsed wall-clock time in whole
    milliseconds.
    """

    stdout: str  # captured standard output
    stderr: str  # captured standard error
    exit_code: int  # process exit status
    duration_ms: int  # elapsed wall-clock time, truncated to milliseconds
|
|
|
|
|
|
def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
|
|
started = time.monotonic()
|
|
env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
|
|
try:
|
|
proc = subprocess.run( # noqa: S603
|
|
["bash", "-lc", command], # noqa: S607
|
|
cwd=workdir,
|
|
env=env,
|
|
text=True,
|
|
capture_output=True,
|
|
timeout=timeout_seconds,
|
|
check=False,
|
|
)
|
|
return VmExecResult(
|
|
stdout=proc.stdout,
|
|
stderr=proc.stderr,
|
|
exit_code=proc.returncode,
|
|
duration_ms=int((time.monotonic() - started) * 1000),
|
|
)
|
|
except subprocess.TimeoutExpired:
|
|
return VmExecResult(
|
|
stdout="",
|
|
stderr=f"command timed out after {timeout_seconds}s",
|
|
exit_code=124,
|
|
duration_ms=int((time.monotonic() - started) * 1000),
|
|
)
|
|
|
|
|
|
class VmBackend:
    """Base class listing the lifecycle hooks a VM backend must provide.

    Every hook here raises ``NotImplementedError``; concrete backends override
    all of them.
    """

    def create(self, instance: VmInstance) -> None:  # pragma: no cover
        """Hook run when a new VM record is being created."""
        raise NotImplementedError

    def start(self, instance: VmInstance) -> None:  # pragma: no cover
        """Hook run when a VM transitions to the started state."""
        raise NotImplementedError

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:  # pragma: no cover
        """Run ``command`` for ``instance`` and return its captured result."""
        raise NotImplementedError

    def stop(self, instance: VmInstance) -> None:  # pragma: no cover
        """Hook run when a VM is stopped."""
        raise NotImplementedError

    def delete(self, instance: VmInstance) -> None:  # pragma: no cover
        """Hook run when a VM is removed; should release its resources."""
        raise NotImplementedError
|
|
|
|
|
|
class MockBackend(VmBackend):
    """Development/test backend that emulates VMs with plain host directories.

    Each "VM" is just its ``workdir``: lifecycle transitions leave marker files
    behind and commands run directly on the host.
    """

    def create(self, instance: VmInstance) -> None:
        """Create the instance workdir (and parents); fail if it already exists."""
        workdir = instance.workdir
        workdir.mkdir(exist_ok=False, parents=True)

    def start(self, instance: VmInstance) -> None:
        """Write a ``.started`` marker file into the workdir."""
        (instance.workdir / ".started").write_text("started\n", encoding="utf-8")

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        """Run ``command`` on the host with the workdir as cwd and HOME."""
        workdir = instance.workdir
        return _run_host_command(workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        """Write a ``.stopped`` marker file into the workdir."""
        (instance.workdir / ".stopped").write_text("stopped\n", encoding="utf-8")

    def delete(self, instance: VmInstance) -> None:
        """Remove the whole workdir tree, ignoring any errors (e.g. already gone)."""
        workdir = instance.workdir
        shutil.rmtree(workdir, ignore_errors=True)
|
|
|
|
|
|
class FirecrackerBackend(VmBackend):  # pragma: no cover
    """Host-gated backend that validates Firecracker prerequisites.

    Construction fails fast unless the bundled ``firecracker`` and ``jailer``
    binaries exist and ``/dev/kvm`` is present. NOTE: no microVM is booted yet;
    ``start`` only runs a ``firecracker --version`` preflight and ``exec`` still
    runs commands on the host.
    """

    # Upper bound for the ``--version`` preflight so a wedged binary cannot hang
    # ``start`` indefinitely (VmManager.start_vm calls it while holding its lock).
    PREFLIGHT_TIMEOUT_SECONDS = 10

    def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
        """Record binary/artifact locations and verify host prerequisites.

        Raises:
            RuntimeError: if either bundled binary is missing or ``/dev/kvm``
                does not exist on this host.
        """
        self._artifacts_dir = artifacts_dir
        self._firecracker_bin = firecracker_bin
        self._jailer_bin = jailer_bin
        if not self._firecracker_bin.exists():
            raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
        if not self._jailer_bin.exists():
            raise RuntimeError(f"bundled jailer binary not found at {self._jailer_bin}")
        if not Path("/dev/kvm").exists():
            raise RuntimeError("/dev/kvm is not available on this host")

    def create(self, instance: VmInstance) -> None:
        """Validate the profile's artifacts, then create the instance workdir.

        Artifacts are checked *before* the workdir is created. The manager never
        registers an instance whose ``create`` raised, so a directory created
        first would be orphaned on disk when the artifacts are missing.

        Raises:
            RuntimeError: if the kernel image or rootfs image does not exist.
            FileExistsError: if the workdir already exists.
        """
        artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
        if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
            raise RuntimeError(
                f"missing profile artifacts for {instance.profile}; expected "
                f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
            )
        instance.workdir.mkdir(parents=True, exist_ok=False)
        instance.metadata["kernel_image"] = str(artifacts.kernel_image)
        instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)

    def start(self, instance: VmInstance) -> None:
        """Run the ``firecracker --version`` preflight and record the result.

        Raises:
            RuntimeError: if the binary cannot be executed, exits non-zero, or
                does not answer within ``PREFLIGHT_TIMEOUT_SECONDS``.
        """
        try:
            proc = subprocess.run(  # noqa: S603
                [str(self._firecracker_bin), "--version"],
                text=True,
                capture_output=True,
                timeout=self.PREFLIGHT_TIMEOUT_SECONDS,
                check=False,
            )
        except subprocess.TimeoutExpired as exc:
            raise RuntimeError(
                "firecracker startup preflight timed out after "
                f"{self.PREFLIGHT_TIMEOUT_SECONDS}s"
            ) from exc
        except OSError as exc:
            # e.g. the binary exists but is not executable on this host.
            raise RuntimeError(f"firecracker startup preflight failed: {exc}") from exc
        if proc.returncode != 0:
            raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
        instance.metadata["firecracker_version"] = proc.stdout.strip()
        instance.metadata["jailer_path"] = str(self._jailer_bin)

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        """Run ``command`` for ``instance`` (currently on the host, not in a guest)."""
        # Temporary compatibility path until guest-side execution agent is integrated.
        return _run_host_command(instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        """No-op: there is no running microVM process to stop yet."""
        del instance

    def delete(self, instance: VmInstance) -> None:
        """Remove the instance workdir tree, ignoring errors."""
        shutil.rmtree(instance.workdir, ignore_errors=True)
|
|
|
|
|
|
class VmManager:
    """In-process lifecycle manager for ephemeral VM environments.

    Keeps an in-memory registry of ``VmInstance`` records guarded by a single
    ``threading.Lock``. It enforces resource/TTL limits and a cap on the number
    of registered VMs, and delegates lifecycle work to a backend selected by
    name (``"mock"`` or ``"firecracker"``). Expired VMs are reaped lazily on
    access and on creation, or explicitly via ``reap_expired``.
    """

    MIN_VCPUS = 1
    MAX_VCPUS = 8
    MIN_MEM_MIB = 256
    MAX_MEM_MIB = 32768
    MIN_TTL_SECONDS = 60
    MAX_TTL_SECONDS = 3600

    def __init__(
        self,
        *,
        backend_name: str | None = None,
        base_dir: Path | None = None,
        artifacts_dir: Path | None = None,
        max_active_vms: int = 4,
        runtime_paths: RuntimePaths | None = None,
    ) -> None:
        """Configure the manager and construct its backend.

        Args:
            backend_name: ``"mock"`` or ``"firecracker"`` (used when ``None``).
            base_dir: Parent directory for per-VM workdirs; created if missing.
                Defaults to ``/tmp/pyro-mcp``.
            artifacts_dir: Profile artifact directory. For firecracker it
                defaults to the resolved runtime paths' ``artifacts_dir``;
                otherwise to ``$PYRO_VM_ARTIFACTS_DIR`` or
                ``/opt/pyro-mcp/artifacts``.
            max_active_vms: Maximum number of VMs registered at once.
            runtime_paths: Pre-resolved runtime paths (firecracker only);
                resolved automatically when omitted.

        Raises:
            ValueError: if ``backend_name`` is not a supported backend.
        """
        self._backend_name = backend_name or "firecracker"
        self._base_dir = base_dir or Path("/tmp/pyro-mcp")
        self._runtime_paths = runtime_paths
        if self._backend_name == "firecracker":
            self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
            self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
        else:
            self._artifacts_dir = artifacts_dir or Path(
                os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
            )
        self._max_active_vms = max_active_vms
        self._lock = threading.Lock()
        self._instances: dict[str, VmInstance] = {}
        self._base_dir.mkdir(parents=True, exist_ok=True)
        self._backend = self._build_backend()

    def _build_backend(self) -> VmBackend:
        """Instantiate the backend named by ``self._backend_name``."""
        if self._backend_name == "mock":
            return MockBackend()
        if self._backend_name == "firecracker":
            if self._runtime_paths is None:
                raise RuntimeError("runtime paths were not initialized for firecracker backend")
            return FirecrackerBackend(
                self._artifacts_dir,
                firecracker_bin=self._runtime_paths.firecracker_bin,
                jailer_bin=self._runtime_paths.jailer_bin,
            )
        raise ValueError("invalid backend; expected one of: mock, firecracker")

    def list_profiles(self) -> list[dict[str, object]]:
        """Return the available VM profiles (delegates to ``vm_profiles``)."""
        # Inside a method the bare name resolves to the module-level function,
        # not to this method.
        return list_profiles()

    def create_vm(
        self, *, profile: str, vcpu_count: int, mem_mib: int, ttl_seconds: int
    ) -> dict[str, Any]:
        """Validate inputs, create a VM via the backend and register it.

        Expired VMs are reaped first so they do not count against the cap.

        Raises:
            ValueError: if a resource/TTL limit is out of range.
            RuntimeError: if ``max_active_vms`` VMs are already registered.
        """
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        # NOTE(review): presumably raises for unknown profiles; confirm in vm_profiles.
        get_profile(profile)
        now = time.time()
        with self._lock:
            self._reap_expired_locked(now)
            active_count = len(self._instances)
            if active_count >= self._max_active_vms:
                raise RuntimeError(
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
            vm_id = uuid.uuid4().hex[:12]
            instance = VmInstance(
                vm_id=vm_id,
                profile=profile,
                vcpu_count=vcpu_count,
                mem_mib=mem_mib,
                ttl_seconds=ttl_seconds,
                created_at=now,
                expires_at=now + ttl_seconds,
                workdir=self._base_dir / vm_id,
            )
            # Register only after the backend succeeded.
            self._backend.create(instance)
            self._instances[vm_id] = instance
            return self._serialize(instance)

    def start_vm(self, vm_id: str) -> dict[str, Any]:
        """Start a VM that is in the ``created`` or ``stopped`` state.

        Raises:
            ValueError: if the VM does not exist.
            RuntimeError: if the VM has expired or is already started.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state not in {"created", "stopped"}:
                raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
            self._backend.start(instance)
            instance.state = "started"
            return self._serialize(instance)

    def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
        """Run ``command`` in a started VM, then delete the VM.

        VMs are single-use: the VM is torn down after the command finishes,
        and also when the backend's ``exec`` raises (the exception is then
        re-raised).

        Raises:
            ValueError: if ``timeout_seconds`` is not positive or the VM does
                not exist.
            RuntimeError: if the VM has expired or is not started.
        """
        if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state != "started":
                raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
        # The command runs without holding the lock so other VMs stay usable.
        try:
            exec_result = self._backend.exec(instance, command, timeout_seconds)
        finally:
            cleanup = self._cleanup_after_exec(vm_id)
        return {
            "vm_id": vm_id,
            "command": command,
            "stdout": exec_result.stdout,
            "stderr": exec_result.stderr,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "cleanup": cleanup,
        }

    def stop_vm(self, vm_id: str) -> dict[str, Any]:
        """Stop a VM via the backend and mark it ``stopped``.

        NOTE(review): unlike start/exec/status, this neither enforces the TTL
        nor checks the current state.

        Raises:
            ValueError: if the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._backend.stop(instance)
            instance.state = "stopped"
            return self._serialize(instance)

    def delete_vm(self, vm_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
        """Stop (if started) and delete a VM, removing it from the registry.

        Raises:
            ValueError: if the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._teardown_locked(instance)
            return {"vm_id": vm_id, "deleted": True, "reason": reason}

    def status_vm(self, vm_id: str) -> dict[str, Any]:
        """Return the serialized state of a VM.

        Raises:
            ValueError: if the VM does not exist.
            RuntimeError: if the VM has expired (it is deleted first).
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            return self._serialize(instance)

    def reap_expired(self) -> dict[str, Any]:
        """Delete every VM whose TTL has elapsed and report which were removed."""
        now = time.time()
        with self._lock:
            expired_vm_ids = self._reap_expired_locked(now)
        return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}

    def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
        """Raise ``ValueError`` if any requested resource is outside its bounds."""
        if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
            raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
        if not self.MIN_MEM_MIB <= mem_mib <= self.MAX_MEM_MIB:
            raise ValueError(f"mem_mib must be between {self.MIN_MEM_MIB} and {self.MAX_MEM_MIB}")
        if not self.MIN_TTL_SECONDS <= ttl_seconds <= self.MAX_TTL_SECONDS:
            raise ValueError(
                f"ttl_seconds must be between {self.MIN_TTL_SECONDS} and {self.MAX_TTL_SECONDS}"
            )

    def _serialize(self, instance: VmInstance) -> dict[str, Any]:
        """Return a plain-dict snapshot of ``instance`` for API responses."""
        return {
            "vm_id": instance.vm_id,
            "profile": instance.profile,
            "vcpu_count": instance.vcpu_count,
            "mem_mib": instance.mem_mib,
            "ttl_seconds": instance.ttl_seconds,
            "created_at": instance.created_at,
            "expires_at": instance.expires_at,
            "state": instance.state,
            # Copy so callers cannot mutate the registry's record through the response.
            "metadata": dict(instance.metadata),
        }

    def _get_instance_locked(self, vm_id: str) -> VmInstance:
        """Look up a registered VM; caller must hold ``self._lock``.

        Raises:
            ValueError: if ``vm_id`` is not registered.
        """
        try:
            return self._instances[vm_id]
        except KeyError as exc:
            raise ValueError(f"vm {vm_id!r} does not exist") from exc

    def _teardown_locked(self, instance: VmInstance) -> None:
        """Stop (if started), delete and unregister ``instance``; lock must be held."""
        if instance.state == "started":
            self._backend.stop(instance)
            instance.state = "stopped"
        self._backend.delete(instance)
        del self._instances[instance.vm_id]

    def _cleanup_after_exec(self, vm_id: str) -> dict[str, Any]:
        """Tear down ``vm_id`` after an exec, tolerating concurrent removal.

        The VM may already be gone (for example, reaped on expiry by another
        caller while a long command ran). That must not discard the command's
        result, so a missing VM is treated as already cleaned up.
        """
        with self._lock:
            instance = self._instances.get(vm_id)
            if instance is not None:
                self._teardown_locked(instance)
        return {"vm_id": vm_id, "deleted": True, "reason": "post_exec_cleanup"}

    def _reap_expired_locked(self, now: float) -> list[str]:
        """Tear down every VM with ``expires_at <= now``; lock must be held.

        Returns:
            The ids of the VMs that were removed.
        """
        expired_vm_ids = [
            vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
        ]
        for vm_id in expired_vm_ids:
            self._teardown_locked(self._instances[vm_id])
        return expired_vm_ids

    def _ensure_not_expired_locked(self, instance: VmInstance, now: float) -> None:
        """Raise if ``instance`` has expired, reaping all expired VMs first.

        Raises:
            RuntimeError: if the instance's TTL has elapsed.
        """
        if instance.expires_at <= now:
            vm_id = instance.vm_id
            self._reap_expired_locked(now)
            raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")
|