"""Lifecycle manager for ephemeral VM environments."""

from __future__ import annotations

import os
import shutil
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal

from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts

# Lifecycle states a VmInstance moves through (see VmManager.start_vm/stop_vm).
VmState = Literal["created", "started", "stopped"]


@dataclass
class VmInstance:
    """In-memory VM lifecycle record."""

    vm_id: str
    profile: str
    vcpu_count: int
    mem_mib: int
    ttl_seconds: int
    created_at: float  # wall-clock timestamp from time.time() at creation
    expires_at: float  # created_at + ttl_seconds
    workdir: Path  # per-VM directory under the manager's base_dir
    state: VmState = "created"
    firecracker_pid: int | None = None
    last_error: str | None = None
    metadata: dict[str, str] = field(default_factory=dict)


@dataclass(frozen=True)
class VmExecResult:
    """Command execution output."""

    stdout: str
    stderr: str
    exit_code: int
    duration_ms: int


def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
    """Run ``command`` via ``bash -lc`` on the host, inside ``workdir``.

    The child process gets a minimal environment: only ``PATH`` (inherited)
    and ``HOME`` (set to ``workdir``). A timeout is not raised; it is reported
    as exit code 124 with an explanatory stderr message and empty stdout.
    """
    started = time.monotonic()
    env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
    try:
        proc = subprocess.run(  # noqa: S603
            ["bash", "-lc", command],  # noqa: S607
            cwd=workdir,
            env=env,
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
    except subprocess.TimeoutExpired:
        # 124 is presumably chosen to match coreutils `timeout` -- confirm.
        return VmExecResult(
            stdout="",
            stderr=f"command timed out after {timeout_seconds}s",
            exit_code=124,
            duration_ms=int((time.monotonic() - started) * 1000),
        )
    return VmExecResult(
        stdout=proc.stdout,
        stderr=proc.stderr,
        exit_code=proc.returncode,
        duration_ms=int((time.monotonic() - started) * 1000),
    )


class VmBackend:
    """Backend interface for lifecycle operations."""

    def create(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def start(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def exec(  # pragma: no cover
        self, instance: VmInstance, command: str, timeout_seconds: int
    ) -> VmExecResult:
        raise NotImplementedError

    def stop(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def delete(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError


class MockBackend(VmBackend):
    """Host-process backend used for development and testability."""

    def create(self, instance: VmInstance) -> None:
        instance.workdir.mkdir(parents=True, exist_ok=False)

    def start(self, instance: VmInstance) -> None:
        # Marker file only; no process is actually launched.
        marker_path = instance.workdir / ".started"
        marker_path.write_text("started\n", encoding="utf-8")

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        return _run_host_command(instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        marker_path = instance.workdir / ".stopped"
        marker_path.write_text("stopped\n", encoding="utf-8")

    def delete(self, instance: VmInstance) -> None:
        shutil.rmtree(instance.workdir, ignore_errors=True)


class FirecrackerBackend(VmBackend):  # pragma: no cover
    """Host-gated backend that validates Firecracker prerequisites."""

    def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
        """Record binary/artifact locations and fail fast if the host is unusable.

        Raises:
            RuntimeError: if the firecracker or jailer binary is missing, or
                ``/dev/kvm`` does not exist.
        """
        self._artifacts_dir = artifacts_dir
        self._firecracker_bin = firecracker_bin
        self._jailer_bin = jailer_bin
        if not self._firecracker_bin.exists():
            raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
        if not self._jailer_bin.exists():
            raise RuntimeError(f"bundled jailer binary not found at {self._jailer_bin}")
        if not Path("/dev/kvm").exists():
            raise RuntimeError("/dev/kvm is not available on this host")

    def create(self, instance: VmInstance) -> None:
        """Validate the profile's artifacts, then create the VM workdir.

        Raises:
            RuntimeError: if the kernel or rootfs image for the profile is missing.
        """
        # Check artifacts *before* creating the workdir: the manager only
        # registers an instance after create() succeeds, so a workdir made
        # here and then abandoned by a raise would never be cleaned up.
        artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
        if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
            raise RuntimeError(
                f"missing profile artifacts for {instance.profile}; expected "
                f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
            )
        instance.workdir.mkdir(parents=True, exist_ok=False)
        instance.metadata["kernel_image"] = str(artifacts.kernel_image)
        instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)

    def start(self, instance: VmInstance) -> None:
        """Run a ``firecracker --version`` preflight and record version/jailer info.

        Raises:
            RuntimeError: if the preflight command exits non-zero.
        """
        proc = subprocess.run(  # noqa: S603
            [str(self._firecracker_bin), "--version"],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
        instance.metadata["firecracker_version"] = proc.stdout.strip()
        instance.metadata["jailer_path"] = str(self._jailer_bin)

    def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
        # Temporary compatibility path until guest-side execution agent is integrated.
        return _run_host_command(instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        del instance

    def delete(self, instance: VmInstance) -> None:
        shutil.rmtree(instance.workdir, ignore_errors=True)


class VmManager:
    """In-process lifecycle manager for ephemeral VM environments.

    The instance table (``self._instances``) is only touched while holding
    ``self._lock``; helpers with a ``_locked`` suffix expect the caller to
    already hold it.
    """

    MIN_VCPUS = 1
    MAX_VCPUS = 8
    MIN_MEM_MIB = 256
    MAX_MEM_MIB = 32768
    MIN_TTL_SECONDS = 60
    MAX_TTL_SECONDS = 3600

    def __init__(
        self,
        *,
        backend_name: str | None = None,
        base_dir: Path | None = None,
        artifacts_dir: Path | None = None,
        max_active_vms: int = 4,
        runtime_paths: RuntimePaths | None = None,
    ) -> None:
        """Configure the manager and construct its backend.

        Args:
            backend_name: ``"mock"`` or ``"firecracker"`` (default).
            base_dir: parent of per-VM workdirs; defaults to ``/tmp/pyro-mcp``.
            artifacts_dir: profile artifact directory; for firecracker it
                defaults to the resolved runtime paths, otherwise to
                ``$PYRO_VM_ARTIFACTS_DIR`` or ``/opt/pyro-mcp/artifacts``.
            max_active_vms: cap on simultaneously tracked VMs.
            runtime_paths: pre-resolved runtime paths (firecracker only).

        Raises:
            ValueError: for an unknown backend name.
        """
        self._backend_name = backend_name or "firecracker"
        self._base_dir = base_dir or Path("/tmp/pyro-mcp")
        self._runtime_paths = runtime_paths
        if self._backend_name == "firecracker":
            self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
            self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
        else:
            self._artifacts_dir = artifacts_dir or Path(
                os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
            )
        self._max_active_vms = max_active_vms
        self._lock = threading.Lock()
        self._instances: dict[str, VmInstance] = {}
        self._base_dir.mkdir(parents=True, exist_ok=True)
        self._backend = self._build_backend()

    def _build_backend(self) -> VmBackend:
        """Instantiate the backend selected by ``self._backend_name``."""
        if self._backend_name == "mock":
            return MockBackend()
        if self._backend_name == "firecracker":
            if self._runtime_paths is None:
                raise RuntimeError("runtime paths were not initialized for firecracker backend")
            return FirecrackerBackend(
                self._artifacts_dir,
                firecracker_bin=self._runtime_paths.firecracker_bin,
                jailer_bin=self._runtime_paths.jailer_bin,
            )
        raise ValueError("invalid backend; expected one of: mock, firecracker")

    def list_profiles(self) -> list[dict[str, object]]:
        """Return the profile listing from ``pyro_mcp.vm_profiles.list_profiles``."""
        return list_profiles()

    def create_vm(
        self, *, profile: str, vcpu_count: int, mem_mib: int, ttl_seconds: int
    ) -> dict[str, Any]:
        """Create (but do not start) a VM and return its serialized record.

        Raises:
            ValueError: if vcpu/memory/TTL limits are out of range.
            RuntimeError: if ``max_active_vms`` VMs are already tracked.
        """
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        # Called for validation only; presumably raises on an unknown profile.
        get_profile(profile)
        now = time.time()
        with self._lock:
            # Drop expired VMs first so they do not count against the cap.
            self._reap_expired_locked(now)
            if len(self._instances) >= self._max_active_vms:
                raise RuntimeError(
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
            vm_id = uuid.uuid4().hex[:12]
            instance = VmInstance(
                vm_id=vm_id,
                profile=profile,
                vcpu_count=vcpu_count,
                mem_mib=mem_mib,
                ttl_seconds=ttl_seconds,
                created_at=now,
                expires_at=now + ttl_seconds,
                workdir=self._base_dir / vm_id,
            )
            self._backend.create(instance)
            # Only register after the backend succeeded.
            self._instances[vm_id] = instance
            return self._serialize(instance)

    def start_vm(self, vm_id: str) -> dict[str, Any]:
        """Start a VM that is in ``created`` or ``stopped`` state.

        Raises:
            ValueError: if the VM does not exist.
            RuntimeError: if the VM expired or is already started.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state not in {"created", "stopped"}:
                raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
            self._backend.start(instance)
            instance.state = "started"
            return self._serialize(instance)

    def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
        """Run ``command`` in a started VM, then delete the VM.

        The command runs without holding the manager lock; the VM is deleted
        afterwards via ``delete_vm`` with reason ``"post_exec_cleanup"``.

        Raises:
            ValueError: if ``timeout_seconds`` is not positive or the VM does not exist.
            RuntimeError: if the VM expired or is not in ``started`` state.
        """
        if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.state != "started":
                raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
        exec_result = self._backend.exec(instance, command, timeout_seconds)
        cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
        return {
            "vm_id": vm_id,
            "command": command,
            "stdout": exec_result.stdout,
            "stderr": exec_result.stderr,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "cleanup": cleanup,
        }

    def stop_vm(self, vm_id: str) -> dict[str, Any]:
        """Stop a VM and mark it ``stopped``.

        Raises:
            ValueError: if the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._backend.stop(instance)
            instance.state = "stopped"
            return self._serialize(instance)

    def delete_vm(self, vm_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
        """Stop (if running) and delete a VM, returning a deletion receipt.

        Raises:
            ValueError: if the VM does not exist.
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._teardown_locked(instance)
        return {"vm_id": vm_id, "deleted": True, "reason": reason}

    def status_vm(self, vm_id: str) -> dict[str, Any]:
        """Return the serialized record of a live VM.

        Raises:
            ValueError: if the VM does not exist.
            RuntimeError: if the VM expired (it is deleted as a side effect).
        """
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            return self._serialize(instance)

    def reap_expired(self) -> dict[str, Any]:
        """Delete every VM whose TTL has elapsed and report which were removed."""
        now = time.time()
        with self._lock:
            expired_vm_ids = self._reap_expired_locked(now)
        return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}

    def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
        """Raise ValueError if any requested resource is outside the class limits."""
        if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
            raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
        if not self.MIN_MEM_MIB <= mem_mib <= self.MAX_MEM_MIB:
            raise ValueError(f"mem_mib must be between {self.MIN_MEM_MIB} and {self.MAX_MEM_MIB}")
        if not self.MIN_TTL_SECONDS <= ttl_seconds <= self.MAX_TTL_SECONDS:
            raise ValueError(
                f"ttl_seconds must be between {self.MIN_TTL_SECONDS} and {self.MAX_TTL_SECONDS}"
            )

    def _serialize(self, instance: VmInstance) -> dict[str, Any]:
        """Return a plain-dict snapshot of ``instance`` for API responses."""
        return {
            "vm_id": instance.vm_id,
            "profile": instance.profile,
            "vcpu_count": instance.vcpu_count,
            "mem_mib": instance.mem_mib,
            "ttl_seconds": instance.ttl_seconds,
            "created_at": instance.created_at,
            "expires_at": instance.expires_at,
            "state": instance.state,
            # Copy so callers cannot mutate internal state outside the lock.
            "metadata": dict(instance.metadata),
        }

    def _get_instance_locked(self, vm_id: str) -> VmInstance:
        """Look up a tracked VM, raising ValueError if it is unknown."""
        try:
            return self._instances[vm_id]
        except KeyError as exc:
            raise ValueError(f"vm {vm_id!r} does not exist") from exc

    def _teardown_locked(self, instance: VmInstance) -> None:
        """Stop ``instance`` if running, delete its backend resources, and untrack it."""
        if instance.state == "started":
            self._backend.stop(instance)
            instance.state = "stopped"
        self._backend.delete(instance)
        del self._instances[instance.vm_id]

    def _reap_expired_locked(self, now: float) -> list[str]:
        """Tear down every VM with ``expires_at <= now``; return the removed ids."""
        # Materialize the ids first: teardown mutates self._instances.
        expired_vm_ids = [
            vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
        ]
        for vm_id in expired_vm_ids:
            self._teardown_locked(self._instances[vm_id])
        return expired_vm_ids

    def _ensure_not_expired_locked(self, instance: VmInstance, now: float) -> None:
        """Raise RuntimeError if ``instance`` expired, reaping all expired VMs first."""
        if instance.expires_at <= now:
            vm_id = instance.vm_id
            self._reap_expired_locked(now)
            raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")