"""Lifecycle manager for ephemeral VM environments and persistent tasks."""

from __future__ import annotations

import json
import os
import shlex
import shutil
import signal
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, cast

from pyro_mcp.runtime import (
    RuntimeCapabilities,
    RuntimePaths,
    resolve_runtime_paths,
    runtime_capabilities,
)
from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_environment
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager

VmState = Literal["created", "started", "stopped"]

DEFAULT_VCPU_COUNT = 1
DEFAULT_MEM_MIB = 1024
DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False

# Each task lives under <base_dir>/tasks/<task_id>/ and contains task.json plus the
# workspace/, commands/ and runtime/ subdirectories named below.
TASK_LAYOUT_VERSION = 1
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"


@dataclass
class VmInstance:
    """In-memory VM lifecycle record."""

    vm_id: str
    environment: str
    vcpu_count: int
    mem_mib: int
    ttl_seconds: int
    created_at: float
    expires_at: float
    workdir: Path
    state: VmState = "created"
    network_requested: bool = False
    allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT
    firecracker_pid: int | None = None
    last_error: str | None = None
    metadata: dict[str, str] = field(default_factory=dict)
    network: NetworkConfig | None = None


@dataclass
class TaskRecord:
    """Persistent task metadata stored on disk."""

    task_id: str
    environment: str
    vcpu_count: int
    mem_mib: int
    ttl_seconds: int
    created_at: float
    expires_at: float
    state: VmState
    network_requested: bool
    allow_host_compat: bool
    firecracker_pid: int | None = None
    last_error: str | None = None
    metadata: dict[str, str] = field(default_factory=dict)
    network: NetworkConfig | None = None
    command_count: int = 0
    last_command: dict[str, Any] | None = None

    @classmethod
    def from_instance(
        cls,
        instance: VmInstance,
        *,
        command_count: int = 0,
        last_command: dict[str, Any] | None = None,
    ) -> TaskRecord:
        return cls(
            task_id=instance.vm_id,
            environment=instance.environment,
            vcpu_count=instance.vcpu_count,
            mem_mib=instance.mem_mib,
            ttl_seconds=instance.ttl_seconds,
            created_at=instance.created_at,
            expires_at=instance.expires_at,
            state=instance.state,
            network_requested=instance.network_requested,
            allow_host_compat=instance.allow_host_compat,
            firecracker_pid=instance.firecracker_pid,
            last_error=instance.last_error,
            metadata=dict(instance.metadata),
            network=instance.network,
            command_count=command_count,
            last_command=last_command,
        )

    def to_instance(self, *, workdir: Path) -> VmInstance:
        return VmInstance(
            vm_id=self.task_id,
            environment=self.environment,
            vcpu_count=self.vcpu_count,
            mem_mib=self.mem_mib,
            ttl_seconds=self.ttl_seconds,
            created_at=self.created_at,
            expires_at=self.expires_at,
            workdir=workdir,
            state=self.state,
            network_requested=self.network_requested,
            allow_host_compat=self.allow_host_compat,
            firecracker_pid=self.firecracker_pid,
            last_error=self.last_error,
            metadata=dict(self.metadata),
            network=self.network,
        )

    def to_payload(self) -> dict[str, Any]:
        return {
            "layout_version": TASK_LAYOUT_VERSION,
            "task_id": self.task_id,
            "environment": self.environment,
            "vcpu_count": self.vcpu_count,
            "mem_mib": self.mem_mib,
            "ttl_seconds": self.ttl_seconds,
            "created_at": self.created_at,
            "expires_at": self.expires_at,
            "state": self.state,
            "network_requested": self.network_requested,
            "allow_host_compat": self.allow_host_compat,
            "firecracker_pid": self.firecracker_pid,
            "last_error": self.last_error,
            "metadata": self.metadata,
            "network": _serialize_network(self.network),
            "command_count": self.command_count,
            "last_command": self.last_command,
        }

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> TaskRecord:
        return cls(
            task_id=str(payload["task_id"]),
            environment=str(payload["environment"]),
            vcpu_count=int(payload["vcpu_count"]),
            mem_mib=int(payload["mem_mib"]),
            ttl_seconds=int(payload["ttl_seconds"]),
            created_at=float(payload["created_at"]),
            expires_at=float(payload["expires_at"]),
            state=cast(VmState, str(payload.get("state", "stopped"))),
            network_requested=bool(payload.get("network_requested", False)),
            allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
            firecracker_pid=_optional_int(payload.get("firecracker_pid")),
            last_error=_optional_str(payload.get("last_error")),
            metadata=_string_dict(payload.get("metadata")),
            network=_deserialize_network(payload.get("network")),
            command_count=int(payload.get("command_count", 0)),
            last_command=_optional_dict(payload.get("last_command")),
        )


@dataclass(frozen=True)
class VmExecResult:
    """Command execution output."""

    stdout: str
    stderr: str
    exit_code: int
    duration_ms: int


def _optional_int(value: object) -> int | None:
    if value is None:
        return None
    if isinstance(value, bool):
        return int(value)
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        return int(value)
    if isinstance(value, str):
        return int(value)
    raise TypeError("expected integer-compatible payload")


def _optional_str(value: object) -> str | None:
    if value is None:
        return None
    return str(value)


def _optional_dict(value: object) -> dict[str, Any] | None:
    if value is None:
        return None
    if not isinstance(value, dict):
        raise TypeError("expected dictionary payload")
    return dict(value)


def _string_dict(value: object) -> dict[str, str]:
    if not isinstance(value, dict):
        return {}
    return {str(key): str(item) for key, item in value.items()}


def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
    if network is None:
        return None
    return {
        "vm_id": network.vm_id,
        "tap_name": network.tap_name,
        "guest_ip": network.guest_ip,
        "gateway_ip": network.gateway_ip,
        "subnet_cidr": network.subnet_cidr,
        "mac_address": network.mac_address,
        "dns_servers": list(network.dns_servers),
    }


def _deserialize_network(payload: object) -> NetworkConfig | None:
    if payload is None:
        return None
    if not isinstance(payload, dict):
        raise TypeError("expected dictionary payload")
    dns_servers = payload.get("dns_servers", [])
    dns_values = tuple(str(item) for item in dns_servers) if isinstance(dns_servers, list) else ()
    return NetworkConfig(
        vm_id=str(payload["vm_id"]),
        tap_name=str(payload["tap_name"]),
        guest_ip=str(payload["guest_ip"]),
        gateway_ip=str(payload["gateway_ip"]),
        subnet_cidr=str(payload["subnet_cidr"]),
        mac_address=str(payload["mac_address"]),
        dns_servers=dns_values,
    )


def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
    started = time.monotonic()
    env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
    try:
        proc = subprocess.run(  # noqa: S603
            ["bash", "-lc", command],  # noqa: S607
            cwd=workdir,
            env=env,
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
        return VmExecResult(
            stdout=proc.stdout,
            stderr=proc.stderr,
            exit_code=proc.returncode,
            duration_ms=int((time.monotonic() - started) * 1000),
        )
    except subprocess.TimeoutExpired:
        return VmExecResult(
            stdout="",
            stderr=f"command timed out after {timeout_seconds}s",
            exit_code=124,
            duration_ms=int((time.monotonic() - started) * 1000),
        )


def _copy_rootfs(source: Path, dest: Path) -> str:
    dest.parent.mkdir(parents=True, exist_ok=True)
    try:
        proc = subprocess.run(  # noqa: S603
            ["cp", "--reflink=auto", str(source), str(dest)],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode == 0:
            return "reflink_or_copy"
    except OSError:
        pass
    shutil.copy2(source, dest)
    return "copy2"


def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
    if cwd is None:
        return command
    quoted_cwd = shlex.quote(cwd)
    return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"


def _pid_is_running(pid: int | None) -> bool:
    if pid is None:
        return False
    try:
        os.kill(pid, 0)
    except ProcessLookupError:
        return False
    except PermissionError:
        return True
    return True


class VmBackend:
    """Backend interface for lifecycle operations."""

    def create(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def start(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def exec(  # pragma: no cover
        self,
        instance: VmInstance,
        command: str,
        timeout_seconds: int,
        *,
        workdir: Path | None = None,
    ) -> VmExecResult:
        raise NotImplementedError

    def stop(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def delete(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError


class MockBackend(VmBackend):
    """Host-process backend used for development and testability."""

    def create(self, instance: VmInstance) -> None:
        instance.workdir.mkdir(parents=True, exist_ok=False)

    def start(self, instance: VmInstance) -> None:
        marker_path = instance.workdir / ".started"
        marker_path.write_text("started\n", encoding="utf-8")

    def exec(
        self,
        instance: VmInstance,
        command: str,
        timeout_seconds: int,
        *,
        workdir: Path | None = None,
    ) -> VmExecResult:
        return _run_host_command(workdir or instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        marker_path = instance.workdir / ".stopped"
        marker_path.write_text("stopped\n", encoding="utf-8")

    def delete(self, instance: VmInstance) -> None:
        shutil.rmtree(instance.workdir, ignore_errors=True)


class FirecrackerBackend(VmBackend):  # pragma: no cover
    """Host-gated backend that validates Firecracker prerequisites."""

    def __init__(
        self,
        environment_store: EnvironmentStore,
        firecracker_bin: Path,
        jailer_bin: Path,
        runtime_capabilities: RuntimeCapabilities,
        network_manager: TapNetworkManager | None = None,
        guest_exec_client: VsockExecClient | None = None,
    ) -> None:
        self._environment_store = environment_store
        self._firecracker_bin = firecracker_bin
        self._jailer_bin = jailer_bin
        self._runtime_capabilities = runtime_capabilities
        self._network_manager = network_manager or TapNetworkManager()
        self._guest_exec_client = guest_exec_client or VsockExecClient()
        self._processes: dict[str, subprocess.Popen[str]] = {}
        if not self._firecracker_bin.exists():
            raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
        if not self._jailer_bin.exists():
            raise RuntimeError(f"bundled jailer binary not found at {self._jailer_bin}")
        if not Path("/dev/kvm").exists():
            raise RuntimeError("/dev/kvm is not available on this host")

    def create(self, instance: VmInstance) -> None:
        instance.workdir.mkdir(parents=True, exist_ok=False)
        try:
            installed_environment = self._environment_store.ensure_installed(instance.environment)
            if (
                not installed_environment.kernel_image.exists()
                or not installed_environment.rootfs_image.exists()
            ):
                raise RuntimeError(
                    f"missing environment artifacts for {instance.environment}; expected "
                    f"{installed_environment.kernel_image} and {installed_environment.rootfs_image}"
                )
            instance.metadata["environment_version"] = installed_environment.version
            instance.metadata["environment_source"] = installed_environment.source
            if installed_environment.source_digest is not None:
                instance.metadata["environment_digest"] = installed_environment.source_digest
            instance.metadata["environment_install_dir"] = str(installed_environment.install_dir)
            instance.metadata["kernel_image"] = str(installed_environment.kernel_image)
            # Each VM gets its own writable copy of the environment's root filesystem.
            rootfs_copy = instance.workdir / "rootfs.ext4"
            instance.metadata["rootfs_clone_mode"] = _copy_rootfs(
                installed_environment.rootfs_image,
                rootfs_copy,
            )
            instance.metadata["rootfs_image"] = str(rootfs_copy)
            if instance.network_requested:
                network = self._network_manager.allocate(instance.vm_id)
                instance.network = network
                instance.metadata.update(self._network_manager.to_metadata(network))
            else:
                instance.network = None
                instance.metadata["network_enabled"] = "false"
        except Exception:
            shutil.rmtree(instance.workdir, ignore_errors=True)
            raise

    def start(self, instance: VmInstance) -> None:
        launch_plan = build_launch_plan(instance)
        instance.metadata["firecracker_config_path"] = str(launch_plan.config_path)
        instance.metadata["guest_network_path"] = str(launch_plan.guest_network_path)
        instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path)
        instance.metadata["guest_cid"] = str(launch_plan.guest_cid)
        instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port)
        instance.metadata["guest_exec_uds_path"] = str(instance.workdir / "vsock.sock")
        serial_log_path = instance.workdir / "serial.log"
        firecracker_log_path = instance.workdir / "firecracker.log"
        firecracker_log_path.touch()
        instance.metadata["serial_log_path"] = str(serial_log_path)
        instance.metadata["firecracker_log_path"] = str(firecracker_log_path)
        # Preflight: make sure the bundled binary actually runs on this host.
        proc = subprocess.run(  # noqa: S603
            [str(self._firecracker_bin), "--version"],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode != 0:
            raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
        instance.metadata["firecracker_version"] = proc.stdout.strip()
        instance.metadata["jailer_path"] = str(self._jailer_bin)
        if not self._runtime_capabilities.supports_vm_boot:
            instance.metadata["execution_mode"] = "host_compat"
            instance.metadata["boot_mode"] = "shim"
            if self._runtime_capabilities.reason is not None:
                instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
            return
        with serial_log_path.open("w", encoding="utf-8") as serial_fp:
            process = subprocess.Popen(  # noqa: S603
                [
                    str(self._firecracker_bin),
                    "--no-api",
                    "--config-file",
                    str(launch_plan.config_path),
                    "--log-path",
                    str(firecracker_log_path),
                    "--level",
                    "Info",
                ],
                stdout=serial_fp,
                stderr=subprocess.STDOUT,
                text=True,
                start_new_session=True,
            )
        self._processes[instance.vm_id] = process
        # Give the microVM a moment to boot, then treat an early exit as a startup failure.
        time.sleep(2)
        if process.poll() is not None:
            serial_log = serial_log_path.read_text(encoding="utf-8", errors="ignore")
            firecracker_log = firecracker_log_path.read_text(encoding="utf-8", errors="ignore")
            self._processes.pop(instance.vm_id, None)
            raise RuntimeError(
                "firecracker microVM exited during startup: "
                f"{(serial_log or firecracker_log).strip()}"
            )
        instance.firecracker_pid = process.pid
        instance.metadata["execution_mode"] = (
            "guest_vsock" if self._runtime_capabilities.supports_guest_exec else "guest_boot_only"
        )
        instance.metadata["boot_mode"] = "native"

    def exec(
        self,
        instance: VmInstance,
        command: str,
        timeout_seconds: int,
        *,
        workdir: Path | None = None,
    ) -> VmExecResult:
        if self._runtime_capabilities.supports_guest_exec:
            guest_cid = int(instance.metadata["guest_cid"])
            port = int(instance.metadata["guest_exec_port"])
            uds_path = instance.metadata.get("guest_exec_uds_path")
            # Retry until the guest agent's vsock endpoint accepts connections (at most 10s).
            deadline = time.monotonic() + min(timeout_seconds, 10)
            while True:
                try:
                    response = self._guest_exec_client.exec(
                        guest_cid,
                        port,
                        command,
                        timeout_seconds,
                        uds_path=uds_path,
                    )
                    break
                except (OSError, RuntimeError) as exc:
                    if time.monotonic() >= deadline:
                        raise RuntimeError(
                            f"guest exec transport did not become ready: {exc}"
                        ) from exc
                    time.sleep(0.2)
            return VmExecResult(
                stdout=response.stdout,
                stderr=response.stderr,
                exit_code=response.exit_code,
                duration_ms=response.duration_ms,
            )
        instance.metadata["execution_mode"] = "host_compat"
        return _run_host_command(workdir or instance.workdir, command, timeout_seconds)

    def stop(self, instance: VmInstance) -> None:
        process = self._processes.pop(instance.vm_id, None)
        if process is not None:
            process.terminate()
            try:
                process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait(timeout=5)
            instance.firecracker_pid = None
            return
        # No Popen handle (e.g. a task reloaded from disk): fall back to signalling the PID.
        if instance.firecracker_pid is None:
            return
        try:
            os.kill(instance.firecracker_pid, signal.SIGTERM)
        except ProcessLookupError:
            instance.firecracker_pid = None
            return
        deadline = time.monotonic() + 5
        while time.monotonic() < deadline:
            try:
                os.kill(instance.firecracker_pid, 0)
            except ProcessLookupError:
                instance.firecracker_pid = None
                return
            time.sleep(0.1)
        # The process may exit between the last probe and SIGKILL; tolerate that race.
        try:
            os.kill(instance.firecracker_pid, signal.SIGKILL)
        except ProcessLookupError:
            pass
        instance.firecracker_pid = None

    def delete(self, instance: VmInstance) -> None:
        self._processes.pop(instance.vm_id, None)
        if instance.network is not None:
            self._network_manager.cleanup(instance.network)
        shutil.rmtree(instance.workdir, ignore_errors=True)


class VmManager:
    """In-process lifecycle manager for ephemeral VM environments and tasks."""

    MIN_VCPUS = 1
    MAX_VCPUS = 8
    MIN_MEM_MIB = 256
    MAX_MEM_MIB = 32768
    MIN_TTL_SECONDS = 60
    MAX_TTL_SECONDS = 3600
    DEFAULT_VCPU_COUNT = DEFAULT_VCPU_COUNT
    DEFAULT_MEM_MIB = DEFAULT_MEM_MIB
    DEFAULT_TIMEOUT_SECONDS = DEFAULT_TIMEOUT_SECONDS
    DEFAULT_TTL_SECONDS = DEFAULT_TTL_SECONDS
    DEFAULT_ALLOW_HOST_COMPAT = DEFAULT_ALLOW_HOST_COMPAT

    def __init__(
        self,
        *,
        backend_name: str | None = None,
        base_dir: Path | None = None,
        cache_dir: Path | None = None,
        max_active_vms: int = 4,
        runtime_paths: RuntimePaths | None = None,
        network_manager: TapNetworkManager | None = None,
    ) -> None:
        self._backend_name = backend_name or "firecracker"
        self._base_dir = base_dir or Path("/tmp/pyro-mcp")
        self._tasks_dir = self._base_dir / "tasks"
        resolved_cache_dir = cache_dir or default_cache_dir()
        self._runtime_paths = runtime_paths
        if self._backend_name == "firecracker":
            self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
            self._runtime_capabilities = runtime_capabilities(self._runtime_paths)
            self._environment_store = EnvironmentStore(
                runtime_paths=self._runtime_paths,
                cache_dir=resolved_cache_dir,
            )
        else:
            self._runtime_capabilities = RuntimeCapabilities(
                supports_vm_boot=False,
                supports_guest_exec=False,
                supports_guest_network=False,
                reason="mock backend does not boot a guest",
            )
            if self._runtime_paths is None:
                self._runtime_paths = resolve_runtime_paths(verify_checksums=False)
            self._environment_store = EnvironmentStore(
                runtime_paths=self._runtime_paths,
                cache_dir=resolved_cache_dir,
            )
        self._max_active_vms = max_active_vms
        if network_manager is not None:
            self._network_manager = network_manager
        elif self._backend_name == "firecracker":
            self._network_manager = TapNetworkManager(enabled=True)
        else:
            self._network_manager = TapNetworkManager(enabled=False)
        self._lock = threading.Lock()
        self._instances: dict[str, VmInstance] = {}
        self._base_dir.mkdir(parents=True, exist_ok=True)
        self._tasks_dir.mkdir(parents=True, exist_ok=True)
        self._backend = self._build_backend()

    def _build_backend(self) -> VmBackend:
        if self._backend_name == "mock":
            return MockBackend()
        if self._backend_name == "firecracker":
            if self._runtime_paths is None:
                raise RuntimeError("runtime paths were not initialized for firecracker backend")
            return FirecrackerBackend(
                self._environment_store,
                firecracker_bin=self._runtime_paths.firecracker_bin,
                jailer_bin=self._runtime_paths.jailer_bin,
                runtime_capabilities=self._runtime_capabilities,
                network_manager=self._network_manager,
            )
        raise ValueError("invalid backend; expected one of: mock, firecracker")

    def list_environments(self) -> list[dict[str, object]]:
        return self._environment_store.list_environments()

    def pull_environment(self, environment: str) -> dict[str, object]:
        return self._environment_store.pull_environment(environment)

    def inspect_environment(self, environment: str) -> dict[str, object]:
        return self._environment_store.inspect_environment(environment)

    def prune_environments(self) -> dict[str, object]:
        return self._environment_store.prune_environments()

    def create_vm(
        self,
        *,
        environment: str,
        vcpu_count: int = DEFAULT_VCPU_COUNT,
        mem_mib: int = DEFAULT_MEM_MIB,
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
        network: bool = False,
        allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
    ) -> dict[str, Any]:
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        get_environment(environment, runtime_paths=self._runtime_paths)
        now = time.time()
        with self._lock:
            self._reap_expired_locked(now)
            self._reap_expired_tasks_locked(now)
            active_count = len(self._instances) + self._count_tasks_locked()
            if active_count >= self._max_active_vms:
                raise RuntimeError(
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
            vm_id = uuid.uuid4().hex[:12]
            instance = VmInstance(
                vm_id=vm_id,
                environment=environment,
                vcpu_count=vcpu_count,
                mem_mib=mem_mib,
                ttl_seconds=ttl_seconds,
                created_at=now,
                expires_at=now + ttl_seconds,
                workdir=self._base_dir / vm_id,
                network_requested=network,
                allow_host_compat=allow_host_compat,
            )
            instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
            self._backend.create(instance)
            self._instances[vm_id] = instance
            return self._serialize(instance)

    def run_vm(
        self,
        *,
        environment: str,
        command: str,
        vcpu_count: int = DEFAULT_VCPU_COUNT,
        mem_mib: int = DEFAULT_MEM_MIB,
        timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
        network: bool = False,
        allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
    ) -> dict[str, Any]:
        created = self.create_vm(
            environment=environment,
            vcpu_count=vcpu_count,
            mem_mib=mem_mib,
            ttl_seconds=ttl_seconds,
            network=network,
            allow_host_compat=allow_host_compat,
        )
        vm_id = str(created["vm_id"])
        try:
            self.start_vm(vm_id)
            return self.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
        except Exception:
            try:
                self.delete_vm(vm_id, reason="run_vm_error_cleanup")
            except ValueError:
                pass
            raise

    def start_vm(self, vm_id: str) -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            self._start_instance_locked(instance)
            return self._serialize(instance)

    def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            exec_instance = instance
        # Run outside the lock: the command may be slow, and delete_vm re-acquires the lock.
        exec_result, execution_mode = self._exec_instance(
            exec_instance,
            command=command,
            timeout_seconds=timeout_seconds,
        )
        cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
        return {
            "vm_id": vm_id,
            "environment": exec_instance.environment,
            "environment_version": exec_instance.metadata.get("environment_version"),
            "command": command,
            "stdout": exec_result.stdout,
            "stderr": exec_result.stderr,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "execution_mode": execution_mode,
            "cleanup": cleanup,
        }

    def stop_vm(self, vm_id: str) -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._backend.stop(instance)
            instance.state = "stopped"
            return self._serialize(instance)

    def delete_vm(self, vm_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            if instance.state == "started":
                self._backend.stop(instance)
                instance.state = "stopped"
            self._backend.delete(instance)
            del self._instances[vm_id]
            return {"vm_id": vm_id, "deleted": True, "reason": reason}

    def status_vm(self, vm_id: str) -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            return self._serialize(instance)

    def network_info_vm(self, vm_id: str) -> dict[str, Any]:
        with self._lock:
            instance = self._get_instance_locked(vm_id)
            self._ensure_not_expired_locked(instance, time.time())
            if instance.network is None:
                return {
                    "vm_id": vm_id,
                    "network_enabled": False,
                    "outbound_connectivity_expected": False,
                    "reason": "network configuration is unavailable for this VM",
                }
            return {"vm_id": vm_id, **self._network_manager.network_info(instance.network)}

    def reap_expired(self) -> dict[str, Any]:
        now = time.time()
        with self._lock:
            expired_vm_ids = [
                vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
            ]
            for vm_id in expired_vm_ids:
                instance = self._instances[vm_id]
                if instance.state == "started":
                    self._backend.stop(instance)
                    instance.state = "stopped"
                self._backend.delete(instance)
                del self._instances[vm_id]
        return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}

    def create_task(
        self,
        *,
        environment: str,
        vcpu_count: int = DEFAULT_VCPU_COUNT,
        mem_mib: int = DEFAULT_MEM_MIB,
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
        network: bool = False,
        allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
    ) -> dict[str, Any]:
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        get_environment(environment, runtime_paths=self._runtime_paths)
        now = time.time()
        task_id = uuid.uuid4().hex[:12]
        task_dir = self._task_dir(task_id)
        runtime_dir = self._task_runtime_dir(task_id)
        workspace_dir = self._task_workspace_dir(task_id)
        commands_dir = self._task_commands_dir(task_id)
        task_dir.mkdir(parents=True, exist_ok=False)
        workspace_dir.mkdir(parents=True, exist_ok=True)
        commands_dir.mkdir(parents=True, exist_ok=True)
        instance = VmInstance(
            vm_id=task_id,
            environment=environment,
            vcpu_count=vcpu_count,
            mem_mib=mem_mib,
            ttl_seconds=ttl_seconds,
            created_at=now,
            expires_at=now + ttl_seconds,
            workdir=runtime_dir,
            network_requested=network,
            allow_host_compat=allow_host_compat,
        )
        instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
        instance.metadata["workspace_path"] = TASK_WORKSPACE_GUEST_PATH
        instance.metadata["workspace_host_dir"] = str(workspace_dir)
        try:
            with self._lock:
                self._reap_expired_locked(now)
                self._reap_expired_tasks_locked(now)
                active_count = len(self._instances) + self._count_tasks_locked()
                if active_count >= self._max_active_vms:
                    raise RuntimeError(
                        f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                    )
                self._backend.create(instance)
            with self._lock:
                self._start_instance_locked(instance)
                self._require_guest_exec_or_opt_in(instance)
                if self._runtime_capabilities.supports_guest_exec:
                    self._backend.exec(
                        instance,
                        f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
                        10,
                    )
                else:
                    instance.metadata["execution_mode"] = "host_compat"
                task = TaskRecord.from_instance(instance)
                self._save_task_locked(task)
                return self._serialize_task(task)
        except Exception:
            # Best-effort rollback: stop and delete the backend instance, then drop the task dir.
            if runtime_dir.exists():
                try:
                    if instance.state == "started":
                        self._backend.stop(instance)
                        instance.state = "stopped"
                except Exception:
                    pass
                try:
                    self._backend.delete(instance)
                except Exception:
                    pass
            shutil.rmtree(task_dir, ignore_errors=True)
            raise

    def exec_task(
        self,
        task_id: str,
        *,
        command: str,
        timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
    ) -> dict[str, Any]:
        if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")
        with self._lock:
            task = self._load_task_locked(task_id)
            self._ensure_task_not_expired_locked(task, time.time())
            self._refresh_task_liveness_locked(task)
            if task.state != "started":
                raise RuntimeError(f"task {task_id} must be in 'started' state before task_exec")
            instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
        exec_result, execution_mode = self._exec_instance(
            instance,
            command=command,
            timeout_seconds=timeout_seconds,
            host_workdir=self._task_workspace_dir(task.task_id),
            guest_cwd=TASK_WORKSPACE_GUEST_PATH,
        )
        with self._lock:
            # Reload the record so concurrent updates made while the command ran are not lost.
            task = self._load_task_locked(task_id)
            task.state = instance.state
            task.firecracker_pid = instance.firecracker_pid
            task.last_error = instance.last_error
            task.metadata = dict(instance.metadata)
            entry = self._record_task_command_locked(
                task,
                command=command,
                exec_result=exec_result,
                execution_mode=execution_mode,
                cwd=TASK_WORKSPACE_GUEST_PATH,
            )
            self._save_task_locked(task)
            return {
                "task_id": task_id,
                "environment": task.environment,
                "environment_version": task.metadata.get("environment_version"),
                "command": command,
                "stdout": exec_result.stdout,
                "stderr": exec_result.stderr,
                "exit_code": exec_result.exit_code,
                "duration_ms": exec_result.duration_ms,
                "execution_mode": execution_mode,
                "sequence": entry["sequence"],
                "cwd": TASK_WORKSPACE_GUEST_PATH,
            }

    def status_task(self, task_id: str) -> dict[str, Any]:
        with self._lock:
            task = self._load_task_locked(task_id)
            self._ensure_task_not_expired_locked(task, time.time())
            self._refresh_task_liveness_locked(task)
            self._save_task_locked(task)
            return self._serialize_task(task)

    def logs_task(self, task_id: str) -> dict[str, Any]:
        with self._lock:
            task = self._load_task_locked(task_id)
            self._ensure_task_not_expired_locked(task, time.time())
            self._refresh_task_liveness_locked(task)
            self._save_task_locked(task)
            entries = self._read_task_logs_locked(task.task_id)
            return {"task_id": task.task_id, "count": len(entries), "entries": entries}

    def delete_task(self, task_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
        with self._lock:
            task = self._load_task_locked(task_id)
            instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
            if task.state == "started":
                self._backend.stop(instance)
                task.state = "stopped"
            self._backend.delete(instance)
            shutil.rmtree(self._task_dir(task_id), ignore_errors=True)
            return {"task_id": task_id, "deleted": True, "reason": reason}

    def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
        if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
            raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
        if not self.MIN_MEM_MIB <= mem_mib <= self.MAX_MEM_MIB:
            raise ValueError(f"mem_mib must be between {self.MIN_MEM_MIB} and {self.MAX_MEM_MIB}")
        if not self.MIN_TTL_SECONDS <= ttl_seconds <= self.MAX_TTL_SECONDS:
            raise ValueError(
                f"ttl_seconds must be between {self.MIN_TTL_SECONDS} and {self.MAX_TTL_SECONDS}"
            )

    def _serialize(self, instance: VmInstance) -> dict[str, Any]:
        return {
            "vm_id": instance.vm_id,
            "environment": instance.environment,
            "environment_version": instance.metadata.get("environment_version"),
            "vcpu_count": instance.vcpu_count,
            "mem_mib": instance.mem_mib,
            "ttl_seconds": instance.ttl_seconds,
            "created_at": instance.created_at,
            "expires_at": instance.expires_at,
            "state": instance.state,
            "network_enabled": instance.network is not None,
            "allow_host_compat": instance.allow_host_compat,
            "guest_ip": instance.network.guest_ip if instance.network is not None else None,
            "tap_name": instance.network.tap_name if instance.network is not None else None,
            "execution_mode": instance.metadata.get("execution_mode", "pending"),
            "metadata": instance.metadata,
        }

    def _serialize_task(self, task: TaskRecord) -> dict[str, Any]:
        return {
            "task_id": task.task_id,
            "environment": task.environment,
            "environment_version": task.metadata.get("environment_version"),
            "vcpu_count": task.vcpu_count,
            "mem_mib": task.mem_mib,
            "ttl_seconds": task.ttl_seconds,
            "created_at": task.created_at,
            "expires_at": task.expires_at,
            "state": task.state,
            "network_enabled": task.network is not None,
            "allow_host_compat": task.allow_host_compat,
            "guest_ip": task.network.guest_ip if task.network is not None else None,
            "tap_name": task.network.tap_name if task.network is not None else None,
            "execution_mode": task.metadata.get("execution_mode", "pending"),
            "workspace_path": TASK_WORKSPACE_GUEST_PATH,
            "command_count": task.command_count,
            "last_command": task.last_command,
            "metadata": task.metadata,
        }

    def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
        if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
            return
        reason = self._runtime_capabilities.reason or "runtime does not support real VM boot"
        raise RuntimeError(
            "guest boot is unavailable and host compatibility mode is disabled: "
            f"{reason}. Set allow_host_compat=True (CLI: --allow-host-compat) to opt into "
            "host execution."
        )

    def _require_guest_exec_or_opt_in(self, instance: VmInstance) -> None:
        if self._runtime_capabilities.supports_guest_exec or instance.allow_host_compat:
            return
        reason = self._runtime_capabilities.reason or (
            "runtime does not support guest command execution"
        )
        raise RuntimeError(
            "guest command execution is unavailable and host compatibility mode is disabled: "
            f"{reason}. Set allow_host_compat=True (CLI: --allow-host-compat) to opt into "
            "host execution."
        )

    def _get_instance_locked(self, vm_id: str) -> VmInstance:
        try:
            return self._instances[vm_id]
        except KeyError as exc:
            raise ValueError(f"vm {vm_id!r} does not exist") from exc

    def _reap_expired_locked(self, now: float) -> None:
        expired_vm_ids = [
            vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
        ]
        for vm_id in expired_vm_ids:
            instance = self._instances[vm_id]
            if instance.state == "started":
                self._backend.stop(instance)
                instance.state = "stopped"
            self._backend.delete(instance)
            del self._instances[vm_id]

    def _ensure_not_expired_locked(self, instance: VmInstance, now: float) -> None:
        if instance.expires_at <= now:
            vm_id = instance.vm_id
            self._reap_expired_locked(now)
            raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")

    def _start_instance_locked(self, instance: VmInstance) -> None:
        if instance.state not in {"created", "stopped"}:
            raise RuntimeError(
                f"vm {instance.vm_id} cannot be started from state {instance.state!r}"
            )
        self._require_guest_boot_or_opt_in(instance)
        if not self._runtime_capabilities.supports_vm_boot:
            instance.metadata["execution_mode"] = "host_compat"
            instance.metadata["boot_mode"] = "compat"
            if self._runtime_capabilities.reason is not None:
                instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
        self._backend.start(instance)
        instance.state = "started"

    def _exec_instance(
        self,
        instance: VmInstance,
        *,
        command: str,
        timeout_seconds: int,
        host_workdir: Path | None = None,
        guest_cwd: str | None = None,
    ) -> tuple[VmExecResult, str]:
        if timeout_seconds <= 0:
            raise ValueError("timeout_seconds must be positive")
        if instance.state != "started":
            raise RuntimeError(f"vm {instance.vm_id} must be in 'started' state before execution")
        self._require_guest_exec_or_opt_in(instance)
        prepared_command = command
        if self._runtime_capabilities.supports_guest_exec:
            # In the guest, the working directory is set by prefixing the command with `cd`.
            prepared_command = _wrap_guest_command(command, cwd=guest_cwd)
            workdir = None
        else:
            # On the host, the working directory is passed to the backend instead.
            instance.metadata["execution_mode"] = "host_compat"
            workdir = host_workdir
        exec_result = self._backend.exec(
            instance,
            prepared_command,
            timeout_seconds,
            workdir=workdir,
        )
        execution_mode = instance.metadata.get("execution_mode", "unknown")
        return exec_result, execution_mode

    def _task_dir(self, task_id: str) -> Path:
        return self._tasks_dir / task_id

    def _task_runtime_dir(self, task_id: str) -> Path:
        return self._task_dir(task_id) / TASK_RUNTIME_DIRNAME

    def _task_workspace_dir(self, task_id: str) -> Path:
        return self._task_dir(task_id) / TASK_WORKSPACE_DIRNAME

    def _task_commands_dir(self, task_id: str) -> Path:
        return self._task_dir(task_id) / TASK_COMMANDS_DIRNAME

    def _task_metadata_path(self, task_id: str) -> Path:
        return self._task_dir(task_id) / "task.json"

    def _count_tasks_locked(self) -> int:
        return sum(1 for _ in self._tasks_dir.glob("*/task.json"))

    def _load_task_locked(self, task_id: str) -> TaskRecord:
        metadata_path = self._task_metadata_path(task_id)
        if not metadata_path.exists():
            raise ValueError(f"task {task_id!r} does not exist")
        payload = json.loads(metadata_path.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            raise RuntimeError(f"task record at {metadata_path} is invalid")
        return TaskRecord.from_payload(payload)

    def _save_task_locked(self, task: TaskRecord) -> None:
        metadata_path = self._task_metadata_path(task.task_id)
        metadata_path.parent.mkdir(parents=True, exist_ok=True)
        metadata_path.write_text(
            json.dumps(task.to_payload(), indent=2, sort_keys=True),
            encoding="utf-8",
        )

    def _reap_expired_tasks_locked(self, now: float) -> None:
        for metadata_path in list(self._tasks_dir.glob("*/task.json")):
            payload = json.loads(metadata_path.read_text(encoding="utf-8"))
            if not isinstance(payload, dict):
                shutil.rmtree(metadata_path.parent, ignore_errors=True)
                continue
            task = TaskRecord.from_payload(payload)
            if task.expires_at > now:
                continue
            instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
            if task.state == "started":
                self._backend.stop(instance)
                task.state = "stopped"
            self._backend.delete(instance)
            shutil.rmtree(self._task_dir(task.task_id), ignore_errors=True)

    def _ensure_task_not_expired_locked(self, task: TaskRecord, now: float) -> None:
        if task.expires_at <= now:
            task_id = task.task_id
            self._reap_expired_tasks_locked(now)
            raise RuntimeError(f"task {task_id!r} expired and was automatically deleted")

    def _refresh_task_liveness_locked(self, task: TaskRecord) -> None:
        if task.state != "started":
            return
        # Liveness is tracked only through the Firecracker PID, so host-compat tasks are skipped.
        execution_mode = task.metadata.get("execution_mode")
        if execution_mode == "host_compat":
            return
        if _pid_is_running(task.firecracker_pid):
            return
        task.state = "stopped"
        task.firecracker_pid = None
        task.last_error = "backing guest process is no longer running"

    def _record_task_command_locked(
        self,
        task: TaskRecord,
        *,
        command: str,
        exec_result: VmExecResult,
        execution_mode: str,
        cwd: str,
    ) -> dict[str, Any]:
        sequence = task.command_count + 1
        commands_dir = self._task_commands_dir(task.task_id)
        commands_dir.mkdir(parents=True, exist_ok=True)
        # Zero-padded names keep lexical order equal to command order (used when reading logs).
        base_name = f"{sequence:06d}"
        stdout_path = commands_dir / f"{base_name}.stdout"
        stderr_path = commands_dir / f"{base_name}.stderr"
        record_path = commands_dir / f"{base_name}.json"
        stdout_path.write_text(exec_result.stdout, encoding="utf-8")
        stderr_path.write_text(exec_result.stderr, encoding="utf-8")
        entry: dict[str, Any] = {
            "sequence": sequence,
            "command": command,
            "cwd": cwd,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "execution_mode": execution_mode,
            "stdout_file": stdout_path.name,
            "stderr_file": stderr_path.name,
            "recorded_at": time.time(),
        }
        record_path.write_text(json.dumps(entry, indent=2, sort_keys=True), encoding="utf-8")
        task.command_count = sequence
        task.last_command = {
            "sequence": sequence,
            "command": command,
            "cwd": cwd,
            "exit_code": exec_result.exit_code,
            "duration_ms": exec_result.duration_ms,
            "execution_mode": execution_mode,
        }
        return entry

    def _read_task_logs_locked(self, task_id: str) -> list[dict[str, Any]]:
        entries: list[dict[str, Any]] = []
        commands_dir = self._task_commands_dir(task_id)
        if not commands_dir.exists():
            return entries
        for record_path in sorted(commands_dir.glob("*.json")):
            payload = json.loads(record_path.read_text(encoding="utf-8"))
            if not isinstance(payload, dict):
                continue
            stdout_name = str(payload.get("stdout_file", ""))
            stderr_name = str(payload.get("stderr_file", ""))
            stdout = ""
            stderr = ""
            if stdout_name != "":
                stdout_path = commands_dir / stdout_name
                if stdout_path.exists():
                    stdout = stdout_path.read_text(encoding="utf-8")
            if stderr_name != "":
                stderr_path = commands_dir / stderr_name
                if stderr_path.exists():
                    stderr = stderr_path.read_text(encoding="utf-8")
            entry = dict(payload)
            entry["stdout"] = stdout
            entry["stderr"] = stderr
            entries.append(entry)
        return entries