Refactor public API around environments

This commit is contained in:
Thales Maciel 2026-03-08 16:02:02 -03:00
parent 57dae52cc2
commit 5d5243df23
41 changed files with 1301 additions and 459 deletions

View file

@ -19,10 +19,10 @@ from pyro_mcp.runtime import (
resolve_runtime_paths,
runtime_capabilities,
)
from pyro_mcp.vm_environments import EnvironmentStore, get_environment
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
VmState = Literal["created", "started", "stopped"]
@ -32,7 +32,7 @@ class VmInstance:
"""In-memory VM lifecycle record."""
vm_id: str
profile: str
environment: str
vcpu_count: int
mem_mib: int
ttl_seconds: int
@ -85,6 +85,23 @@ def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmEx
)
def _copy_rootfs(source: Path, dest: Path) -> str:
dest.parent.mkdir(parents=True, exist_ok=True)
try:
proc = subprocess.run( # noqa: S603
["cp", "--reflink=auto", str(source), str(dest)],
text=True,
capture_output=True,
check=False,
)
if proc.returncode == 0:
return "reflink_or_copy"
except OSError:
pass
shutil.copy2(source, dest)
return "copy2"
class VmBackend:
"""Backend interface for lifecycle operations."""
@ -132,14 +149,14 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
def __init__(
self,
artifacts_dir: Path,
environment_store: EnvironmentStore,
firecracker_bin: Path,
jailer_bin: Path,
runtime_capabilities: RuntimeCapabilities,
network_manager: TapNetworkManager | None = None,
guest_exec_client: VsockExecClient | None = None,
) -> None:
self._artifacts_dir = artifacts_dir
self._environment_store = environment_store
self._firecracker_bin = firecracker_bin
self._jailer_bin = jailer_bin
self._runtime_capabilities = runtime_capabilities
@ -156,15 +173,26 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
def create(self, instance: VmInstance) -> None:
instance.workdir.mkdir(parents=True, exist_ok=False)
try:
artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
installed_environment = self._environment_store.ensure_installed(instance.environment)
if (
not installed_environment.kernel_image.exists()
or not installed_environment.rootfs_image.exists()
):
raise RuntimeError(
f"missing profile artifacts for {instance.profile}; expected "
f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
f"missing environment artifacts for {instance.environment}; expected "
f"{installed_environment.kernel_image} and {installed_environment.rootfs_image}"
)
instance.metadata["kernel_image"] = str(artifacts.kernel_image)
instance.metadata["environment_version"] = installed_environment.version
instance.metadata["environment_source"] = installed_environment.source
if installed_environment.source_digest is not None:
instance.metadata["environment_digest"] = installed_environment.source_digest
instance.metadata["environment_install_dir"] = str(installed_environment.install_dir)
instance.metadata["kernel_image"] = str(installed_environment.kernel_image)
rootfs_copy = instance.workdir / "rootfs.ext4"
shutil.copy2(artifacts.rootfs_image, rootfs_copy)
instance.metadata["rootfs_clone_mode"] = _copy_rootfs(
installed_environment.rootfs_image,
rootfs_copy,
)
instance.metadata["rootfs_image"] = str(rootfs_copy)
if instance.network_requested:
network = self._network_manager.allocate(instance.vm_id)
@ -320,28 +348,35 @@ class VmManager:
*,
backend_name: str | None = None,
base_dir: Path | None = None,
artifacts_dir: Path | None = None,
cache_dir: Path | None = None,
max_active_vms: int = 4,
runtime_paths: RuntimePaths | None = None,
network_manager: TapNetworkManager | None = None,
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
resolved_cache_dir = cache_dir or self._base_dir / ".environment-cache"
self._runtime_paths = runtime_paths
if self._backend_name == "firecracker":
self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
self._runtime_capabilities = runtime_capabilities(self._runtime_paths)
else:
self._artifacts_dir = artifacts_dir or Path(
os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
self._environment_store = EnvironmentStore(
runtime_paths=self._runtime_paths,
cache_dir=resolved_cache_dir,
)
else:
self._runtime_capabilities = RuntimeCapabilities(
supports_vm_boot=False,
supports_guest_exec=False,
supports_guest_network=False,
reason="mock backend does not boot a guest",
)
if self._runtime_paths is None:
self._runtime_paths = resolve_runtime_paths(verify_checksums=False)
self._environment_store = EnvironmentStore(
runtime_paths=self._runtime_paths,
cache_dir=resolved_cache_dir,
)
self._max_active_vms = max_active_vms
if network_manager is not None:
self._network_manager = network_manager
@ -361,7 +396,7 @@ class VmManager:
if self._runtime_paths is None:
raise RuntimeError("runtime paths were not initialized for firecracker backend")
return FirecrackerBackend(
self._artifacts_dir,
self._environment_store,
firecracker_bin=self._runtime_paths.firecracker_bin,
jailer_bin=self._runtime_paths.jailer_bin,
runtime_capabilities=self._runtime_capabilities,
@ -369,20 +404,29 @@ class VmManager:
)
raise ValueError("invalid backend; expected one of: mock, firecracker")
def list_profiles(self) -> list[dict[str, object]]:
return list_profiles()
def list_environments(self) -> list[dict[str, object]]:
return self._environment_store.list_environments()
def pull_environment(self, environment: str) -> dict[str, object]:
return self._environment_store.pull_environment(environment)
def inspect_environment(self, environment: str) -> dict[str, object]:
return self._environment_store.inspect_environment(environment)
def prune_environments(self) -> dict[str, object]:
return self._environment_store.prune_environments()
def create_vm(
self,
*,
profile: str,
environment: str,
vcpu_count: int,
mem_mib: int,
ttl_seconds: int,
network: bool = False,
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_profile(profile)
get_environment(environment, runtime_paths=self._runtime_paths)
now = time.time()
with self._lock:
self._reap_expired_locked(now)
@ -394,7 +438,7 @@ class VmManager:
vm_id = uuid.uuid4().hex[:12]
instance = VmInstance(
vm_id=vm_id,
profile=profile,
environment=environment,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
@ -410,7 +454,7 @@ class VmManager:
def run_vm(
self,
*,
profile: str,
environment: str,
command: str,
vcpu_count: int,
mem_mib: int,
@ -419,7 +463,7 @@ class VmManager:
network: bool = False,
) -> dict[str, Any]:
created = self.create_vm(
profile=profile,
environment=environment,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
@ -459,6 +503,8 @@ class VmManager:
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
"environment": instance.environment,
"environment_version": instance.metadata.get("environment_version"),
"command": command,
"stdout": exec_result.stdout,
"stderr": exec_result.stderr,
@ -532,7 +578,8 @@ class VmManager:
def _serialize(self, instance: VmInstance) -> dict[str, Any]:
return {
"vm_id": instance.vm_id,
"profile": instance.profile,
"environment": instance.environment,
"environment_version": instance.metadata.get("environment_version"),
"vcpu_count": instance.vcpu_count,
"mem_mib": instance.mem_mib,
"ttl_seconds": instance.ttl_seconds,