Add persistent task workspace alpha

Start the first workspace milestone toward the task-oriented product without changing the existing one-shot vm_run/pyro run contract.

Add a disk-backed task registry in the manager, auto-started task workspaces rooted at /workspace, repeated non-cleaning exec, and persisted command journals exposed through task create/exec/status/logs/delete across the CLI, Python SDK, and MCP server.

Update the public contract, docs, examples, and version/catalog metadata for 2.1.0, and cover the new surface with manager, CLI, SDK, and MCP tests. Validation: UV_CACHE_DIR=.uv-cache make check and UV_CACHE_DIR=.uv-cache make dist-check.
This commit is contained in:
Thales Maciel 2026-03-11 20:10:10 -03:00
parent 6e16e74fd5
commit 58df176148
19 changed files with 1730 additions and 48 deletions

View file

@ -77,6 +77,43 @@ class Pyro:
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
    """Run one command in a one-shot VM; delegates to the underlying manager."""
    return self._manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
def create_task(
    self,
    *,
    environment: str,
    vcpu_count: int = DEFAULT_VCPU_COUNT,
    mem_mib: int = DEFAULT_MEM_MIB,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
    network: bool = False,
    allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]:
    """Create and start a persistent task workspace; returns its serialized record.

    Thin SDK pass-through: all validation happens in the manager.
    """
    return self._manager.create_task(
        environment=environment,
        vcpu_count=vcpu_count,
        mem_mib=mem_mib,
        ttl_seconds=ttl_seconds,
        network=network,
        allow_host_compat=allow_host_compat,
    )
def exec_task(
    self,
    task_id: str,
    *,
    command: str,
    timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> dict[str, Any]:
    """Run one command inside an existing task workspace via the manager."""
    return self._manager.exec_task(task_id, command=command, timeout_seconds=timeout_seconds)
def status_task(self, task_id: str) -> dict[str, Any]:
    """Return the serialized state of one task workspace."""
    return self._manager.status_task(task_id)
def logs_task(self, task_id: str) -> dict[str, Any]:
    """Return the persisted command journal for one task workspace."""
    return self._manager.logs_task(task_id)
def delete_task(self, task_id: str) -> dict[str, Any]:
    """Delete a task workspace and its backing sandbox via the manager."""
    return self._manager.delete_task(task_id)
def stop_vm(self, vm_id: str) -> dict[str, Any]:
    """Stop a running VM; delegates to the underlying manager."""
    return self._manager.stop_vm(vm_id)
@ -200,4 +237,47 @@ class Pyro:
"""Delete VMs whose TTL has expired."""
return self.reap_expired()
@server.tool()
async def task_create(
    environment: str,
    vcpu_count: int = DEFAULT_VCPU_COUNT,
    mem_mib: int = DEFAULT_MEM_MIB,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
    network: bool = False,
    allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]:
    """Create and start a persistent task workspace."""
    # `self` is closed over from the enclosing server-builder method.
    return self.create_task(
        environment=environment,
        vcpu_count=vcpu_count,
        mem_mib=mem_mib,
        ttl_seconds=ttl_seconds,
        network=network,
        allow_host_compat=allow_host_compat,
    )
@server.tool()
async def task_exec(
    task_id: str,
    command: str,
    timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
) -> dict[str, Any]:
    """Run one command inside an existing task workspace."""
    # Delegates to the SDK method captured from the enclosing scope.
    return self.exec_task(task_id, command=command, timeout_seconds=timeout_seconds)
@server.tool()
async def task_status(task_id: str) -> dict[str, Any]:
    """Inspect task state and latest command metadata."""
    return self.status_task(task_id)
@server.tool()
async def task_logs(task_id: str) -> dict[str, Any]:
    """Return persisted command history for one task."""
    return self.logs_task(task_id)
@server.tool()
async def task_delete(task_id: str) -> dict[str, Any]:
    """Delete a task workspace and its backing sandbox."""
    return self.delete_task(task_id)
return server

View file

@ -17,6 +17,7 @@ from pyro_mcp.vm_environments import DEFAULT_CATALOG_VERSION
from pyro_mcp.vm_manager import (
DEFAULT_MEM_MIB,
DEFAULT_VCPU_COUNT,
TASK_WORKSPACE_GUEST_PATH,
)
@ -149,6 +150,67 @@ def _print_doctor_human(payload: dict[str, Any]) -> None:
print(f"- {issue}")
def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None:
print(f"{action}: {str(payload.get('task_id', 'unknown'))}")
print(f"Environment: {str(payload.get('environment', 'unknown'))}")
print(f"State: {str(payload.get('state', 'unknown'))}")
print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}")
print(
f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / "
f"{int(payload.get('mem_mib', 0))} MiB"
)
print(f"Command count: {int(payload.get('command_count', 0))}")
last_command = payload.get("last_command")
if isinstance(last_command, dict):
print(
"Last command: "
f"{str(last_command.get('command', 'unknown'))} "
f"(exit_code={int(last_command.get('exit_code', -1))})"
)
def _print_task_exec_human(payload: dict[str, Any]) -> None:
    """Print task-exec output: captured streams first, then a summary trailer."""
    stdout = str(payload.get("stdout", ""))
    stderr = str(payload.get("stderr", ""))
    # Relay the command's captured stdout/stderr to the matching local streams.
    _write_stream(stdout, stream=sys.stdout)
    _write_stream(stderr, stream=sys.stderr)
    # Summary line goes to stderr so the command's stdout stays clean for piping.
    print(
        "[task-exec] "
        f"task_id={str(payload.get('task_id', 'unknown'))} "
        f"sequence={int(payload.get('sequence', 0))} "
        f"cwd={str(payload.get('cwd', TASK_WORKSPACE_GUEST_PATH))} "
        f"execution_mode={str(payload.get('execution_mode', 'unknown'))} "
        f"exit_code={int(payload.get('exit_code', 1))} "
        f"duration_ms={int(payload.get('duration_ms', 0))}",
        file=sys.stderr,
        flush=True,
    )
def _print_task_logs_human(payload: dict[str, Any]) -> None:
entries = payload.get("entries")
if not isinstance(entries, list) or not entries:
print("No task logs found.")
return
for entry in entries:
if not isinstance(entry, dict):
continue
print(
f"#{int(entry.get('sequence', 0))} "
f"exit_code={int(entry.get('exit_code', -1))} "
f"duration_ms={int(entry.get('duration_ms', 0))} "
f"cwd={str(entry.get('cwd', TASK_WORKSPACE_GUEST_PATH))}"
)
print(f"$ {str(entry.get('command', ''))}")
stdout = str(entry.get("stdout", ""))
stderr = str(entry.get("stderr", ""))
if stdout != "":
print(stdout, end="" if stdout.endswith("\n") else "\n")
if stderr != "":
print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr)
class _HelpFormatter(
argparse.RawDescriptionHelpFormatter,
argparse.ArgumentDefaultsHelpFormatter,
@ -178,6 +240,9 @@ def _build_parser() -> argparse.ArgumentParser:
pyro env pull debian:12
pyro run debian:12 -- git --version
Need repeated commands in one workspace after that?
pyro task create debian:12
Use `pyro mcp serve` only after the CLI validation path works.
"""
),
@ -371,6 +436,152 @@ def _build_parser() -> argparse.ArgumentParser:
),
)
task_parser = subparsers.add_parser(
"task",
help="Manage persistent task workspaces.",
description=(
"Create a persistent workspace when you need repeated commands in one "
"sandbox instead of one-shot `pyro run`."
),
epilog=dedent(
"""
Examples:
pyro task create debian:12
pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
pyro task logs TASK_ID
"""
),
formatter_class=_HelpFormatter,
)
task_subparsers = task_parser.add_subparsers(dest="task_command", required=True, metavar="TASK")
task_create_parser = task_subparsers.add_parser(
"create",
help="Create and start a persistent task workspace.",
description="Create a task workspace that stays alive across repeated exec calls.",
epilog="Example:\n pyro task create debian:12",
formatter_class=_HelpFormatter,
)
task_create_parser.add_argument(
"environment",
metavar="ENVIRONMENT",
help="Curated environment to boot, for example `debian:12`.",
)
task_create_parser.add_argument(
"--vcpu-count",
type=int,
default=DEFAULT_VCPU_COUNT,
help="Number of virtual CPUs to allocate to the task guest.",
)
task_create_parser.add_argument(
"--mem-mib",
type=int,
default=DEFAULT_MEM_MIB,
help="Guest memory allocation in MiB.",
)
task_create_parser.add_argument(
"--ttl-seconds",
type=int,
default=600,
help="Time-to-live for the task before automatic cleanup.",
)
task_create_parser.add_argument(
"--network",
action="store_true",
help="Enable outbound guest networking for the task guest.",
)
task_create_parser.add_argument(
"--allow-host-compat",
action="store_true",
help=(
"Opt into host-side compatibility execution if guest boot or guest exec "
"is unavailable."
),
)
task_create_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
task_exec_parser = task_subparsers.add_parser(
"exec",
help="Run one command inside an existing task workspace.",
description="Run one non-interactive command in the persistent `/workspace` for a task.",
epilog="Example:\n pyro task exec TASK_ID -- cat note.txt",
formatter_class=_HelpFormatter,
)
task_exec_parser.add_argument("task_id", metavar="TASK_ID", help="Persistent task identifier.")
task_exec_parser.add_argument(
"--timeout-seconds",
type=int,
default=30,
help="Maximum time allowed for the task command.",
)
task_exec_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
task_exec_parser.add_argument(
"command_args",
nargs="*",
metavar="ARG",
help=(
"Command and arguments to run inside the task workspace. Prefix them with `--`, "
"for example `pyro task exec TASK_ID -- cat note.txt`."
),
)
task_status_parser = task_subparsers.add_parser(
"status",
help="Inspect one task workspace.",
description="Show task state, sizing, workspace path, and latest command metadata.",
epilog="Example:\n pyro task status TASK_ID",
formatter_class=_HelpFormatter,
)
task_status_parser.add_argument(
"task_id",
metavar="TASK_ID",
help="Persistent task identifier.",
)
task_status_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
task_logs_parser = task_subparsers.add_parser(
"logs",
help="Show command history for one task.",
description="Show persisted command history, including stdout and stderr, for one task.",
epilog="Example:\n pyro task logs TASK_ID",
formatter_class=_HelpFormatter,
)
task_logs_parser.add_argument(
"task_id",
metavar="TASK_ID",
help="Persistent task identifier.",
)
task_logs_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
task_delete_parser = task_subparsers.add_parser(
"delete",
help="Delete one task workspace.",
description="Stop the backing sandbox if needed and remove the task workspace.",
epilog="Example:\n pyro task delete TASK_ID",
formatter_class=_HelpFormatter,
)
task_delete_parser.add_argument(
"task_id",
metavar="TASK_ID",
help="Persistent task identifier.",
)
task_delete_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
doctor_parser = subparsers.add_parser(
"doctor",
help="Inspect runtime and host diagnostics.",
@ -451,7 +662,7 @@ def _require_command(command_args: list[str]) -> str:
if command_args and command_args[0] == "--":
command_args = command_args[1:]
if not command_args:
raise ValueError("command is required after `pyro run --`")
raise ValueError("command is required after `--`")
return " ".join(command_args)
@ -544,6 +755,70 @@ def main() -> None:
if exit_code != 0:
raise SystemExit(exit_code)
return
if args.command == "task":
if args.task_command == "create":
payload = pyro.create_task(
environment=args.environment,
vcpu_count=args.vcpu_count,
mem_mib=args.mem_mib,
ttl_seconds=args.ttl_seconds,
network=args.network,
allow_host_compat=args.allow_host_compat,
)
if bool(args.json):
_print_json(payload)
else:
_print_task_summary_human(payload, action="Task")
return
if args.task_command == "exec":
command = _require_command(args.command_args)
if bool(args.json):
try:
payload = pyro.exec_task(
args.task_id,
command=command,
timeout_seconds=args.timeout_seconds,
)
except Exception as exc: # noqa: BLE001
_print_json({"ok": False, "error": str(exc)})
raise SystemExit(1) from exc
_print_json(payload)
else:
try:
payload = pyro.exec_task(
args.task_id,
command=command,
timeout_seconds=args.timeout_seconds,
)
except Exception as exc: # noqa: BLE001
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
_print_task_exec_human(payload)
exit_code = int(payload.get("exit_code", 1))
if exit_code != 0:
raise SystemExit(exit_code)
return
if args.task_command == "status":
payload = pyro.status_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
_print_task_summary_human(payload, action="Task")
return
if args.task_command == "logs":
payload = pyro.logs_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
_print_task_logs_human(payload)
return
if args.task_command == "delete":
payload = pyro.delete_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
print(f"Deleted task: {str(payload.get('task_id', 'unknown'))}")
return
if args.command == "doctor":
payload = doctor_report(platform=args.platform)
if bool(args.json):

View file

@ -2,9 +2,10 @@
from __future__ import annotations
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run")
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status")
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
"--mem-mib",
@ -17,17 +18,22 @@ PUBLIC_CLI_RUN_FLAGS = (
PUBLIC_SDK_METHODS = (
"create_server",
"create_task",
"create_vm",
"delete_task",
"delete_vm",
"exec_task",
"exec_vm",
"inspect_environment",
"list_environments",
"logs_task",
"network_info_vm",
"prune_environments",
"pull_environment",
"reap_expired",
"run_in_vm",
"start_vm",
"status_task",
"status_vm",
"stop_vm",
)
@ -43,4 +49,9 @@ PUBLIC_MCP_TOOLS = (
"vm_start",
"vm_status",
"vm_stop",
"task_create",
"task_delete",
"task_exec",
"task_logs",
"task_status",
)

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.0.0"
DEFAULT_CATALOG_VERSION = "2.1.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -1,8 +1,10 @@
"""Lifecycle manager for ephemeral VM environments."""
"""Lifecycle manager for ephemeral VM environments and persistent tasks."""
from __future__ import annotations
import json
import os
import shlex
import shutil
import signal
import subprocess
@ -11,7 +13,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, cast
from pyro_mcp.runtime import (
RuntimeCapabilities,
@ -32,6 +34,12 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
TASK_LAYOUT_VERSION = 1
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"
@dataclass
class VmInstance:
@ -54,6 +62,116 @@ class VmInstance:
network: NetworkConfig | None = None
@dataclass
class TaskRecord:
    """Persistent task metadata stored on disk.

    Mirrors a ``VmInstance`` plus journal counters, and round-trips through
    JSON via ``to_payload``/``from_payload``.
    """

    task_id: str
    environment: str
    vcpu_count: int
    mem_mib: int
    ttl_seconds: int
    created_at: float
    expires_at: float
    state: VmState
    network_requested: bool
    allow_host_compat: bool
    firecracker_pid: int | None = None
    last_error: str | None = None
    metadata: dict[str, str] = field(default_factory=dict)
    network: NetworkConfig | None = None
    # Journal counters: number of commands run, and the latest entry payload.
    command_count: int = 0
    last_command: dict[str, Any] | None = None

    @classmethod
    def from_instance(
        cls,
        instance: VmInstance,
        *,
        command_count: int = 0,
        last_command: dict[str, Any] | None = None,
    ) -> TaskRecord:
        """Snapshot a live VmInstance into a persistable record."""
        return cls(
            task_id=instance.vm_id,
            environment=instance.environment,
            vcpu_count=instance.vcpu_count,
            mem_mib=instance.mem_mib,
            ttl_seconds=instance.ttl_seconds,
            created_at=instance.created_at,
            expires_at=instance.expires_at,
            state=instance.state,
            network_requested=instance.network_requested,
            allow_host_compat=instance.allow_host_compat,
            firecracker_pid=instance.firecracker_pid,
            last_error=instance.last_error,
            # Copy mutable fields so the record does not alias the live instance.
            metadata=dict(instance.metadata),
            network=instance.network,
            command_count=command_count,
            last_command=last_command,
        )

    def to_instance(self, *, workdir: Path) -> VmInstance:
        """Rehydrate a VmInstance rooted at *workdir* from this record."""
        return VmInstance(
            vm_id=self.task_id,
            environment=self.environment,
            vcpu_count=self.vcpu_count,
            mem_mib=self.mem_mib,
            ttl_seconds=self.ttl_seconds,
            created_at=self.created_at,
            expires_at=self.expires_at,
            workdir=workdir,
            state=self.state,
            network_requested=self.network_requested,
            allow_host_compat=self.allow_host_compat,
            firecracker_pid=self.firecracker_pid,
            last_error=self.last_error,
            metadata=dict(self.metadata),
            network=self.network,
        )

    def to_payload(self) -> dict[str, Any]:
        """Serialize to a JSON-safe dict; ``layout_version`` tags the disk format."""
        return {
            "layout_version": TASK_LAYOUT_VERSION,
            "task_id": self.task_id,
            "environment": self.environment,
            "vcpu_count": self.vcpu_count,
            "mem_mib": self.mem_mib,
            "ttl_seconds": self.ttl_seconds,
            "created_at": self.created_at,
            "expires_at": self.expires_at,
            "state": self.state,
            "network_requested": self.network_requested,
            "allow_host_compat": self.allow_host_compat,
            "firecracker_pid": self.firecracker_pid,
            "last_error": self.last_error,
            "metadata": self.metadata,
            "network": _serialize_network(self.network),
            "command_count": self.command_count,
            "last_command": self.last_command,
        }

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> TaskRecord:
        """Parse a dict produced by ``to_payload``, tolerating missing optional keys.

        Required keys raise KeyError; malformed values raise TypeError/ValueError.
        """
        return cls(
            task_id=str(payload["task_id"]),
            environment=str(payload["environment"]),
            vcpu_count=int(payload["vcpu_count"]),
            mem_mib=int(payload["mem_mib"]),
            ttl_seconds=int(payload["ttl_seconds"]),
            created_at=float(payload["created_at"]),
            expires_at=float(payload["expires_at"]),
            # cast is safe only if the disk payload came from to_payload.
            state=cast(VmState, str(payload.get("state", "stopped"))),
            network_requested=bool(payload.get("network_requested", False)),
            allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
            firecracker_pid=_optional_int(payload.get("firecracker_pid")),
            last_error=_optional_str(payload.get("last_error")),
            metadata=_string_dict(payload.get("metadata")),
            network=_deserialize_network(payload.get("network")),
            command_count=int(payload.get("command_count", 0)),
            last_command=_optional_dict(payload.get("last_command")),
        )
@dataclass(frozen=True)
class VmExecResult:
"""Command execution output."""
@ -64,6 +182,72 @@ class VmExecResult:
duration_ms: int
def _optional_int(value: object) -> int | None:
if value is None:
return None
if isinstance(value, bool):
return int(value)
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str):
return int(value)
raise TypeError("expected integer-compatible payload")
def _optional_str(value: object) -> str | None:
if value is None:
return None
return str(value)
def _optional_dict(value: object) -> dict[str, Any] | None:
if value is None:
return None
if not isinstance(value, dict):
raise TypeError("expected dictionary payload")
return dict(value)
def _string_dict(value: object) -> dict[str, str]:
if not isinstance(value, dict):
return {}
return {str(key): str(item) for key, item in value.items()}
def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
if network is None:
return None
return {
"vm_id": network.vm_id,
"tap_name": network.tap_name,
"guest_ip": network.guest_ip,
"gateway_ip": network.gateway_ip,
"subnet_cidr": network.subnet_cidr,
"mac_address": network.mac_address,
"dns_servers": list(network.dns_servers),
}
def _deserialize_network(payload: object) -> NetworkConfig | None:
if payload is None:
return None
if not isinstance(payload, dict):
raise TypeError("expected dictionary payload")
dns_servers = payload.get("dns_servers", [])
dns_values = tuple(str(item) for item in dns_servers) if isinstance(dns_servers, list) else ()
return NetworkConfig(
vm_id=str(payload["vm_id"]),
tap_name=str(payload["tap_name"]),
guest_ip=str(payload["guest_ip"]),
gateway_ip=str(payload["gateway_ip"]),
subnet_cidr=str(payload["subnet_cidr"]),
mac_address=str(payload["mac_address"]),
dns_servers=dns_values,
)
def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
started = time.monotonic()
env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
@ -109,6 +293,25 @@ def _copy_rootfs(source: Path, dest: Path) -> str:
return "copy2"
def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
if cwd is None:
return command
quoted_cwd = shlex.quote(cwd)
return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"
def _pid_is_running(pid: int | None) -> bool:
if pid is None:
return False
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
return True
class VmBackend:
"""Backend interface for lifecycle operations."""
@ -119,7 +322,12 @@ class VmBackend:
raise NotImplementedError
def exec( # pragma: no cover
self, instance: VmInstance, command: str, timeout_seconds: int
self,
instance: VmInstance,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> VmExecResult:
raise NotImplementedError
@ -140,8 +348,15 @@ class MockBackend(VmBackend):
marker_path = instance.workdir / ".started"
marker_path.write_text("started\n", encoding="utf-8")
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
return _run_host_command(instance.workdir, command, timeout_seconds)
def exec(
    self,
    instance: VmInstance,
    command: str,
    timeout_seconds: int,
    *,
    workdir: Path | None = None,
) -> VmExecResult:
    """Run *command* on the host; *workdir* overrides the instance workdir (tasks)."""
    return _run_host_command(workdir or instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
marker_path = instance.workdir / ".stopped"
@ -256,6 +471,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
stdout=serial_fp,
stderr=subprocess.STDOUT,
text=True,
start_new_session=True,
)
self._processes[instance.vm_id] = process
time.sleep(2)
@ -273,7 +489,14 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
)
instance.metadata["boot_mode"] = "native"
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
def exec(
self,
instance: VmInstance,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> VmExecResult:
if self._runtime_capabilities.supports_guest_exec:
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
@ -302,7 +525,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
duration_ms=response.duration_ms,
)
instance.metadata["execution_mode"] = "host_compat"
return _run_host_command(instance.workdir, command, timeout_seconds)
return _run_host_command(workdir or instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
process = self._processes.pop(instance.vm_id, None)
@ -341,7 +564,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments."""
"""In-process lifecycle manager for ephemeral VM environments and tasks."""
MIN_VCPUS = 1
MAX_VCPUS = 8
@ -367,6 +590,7 @@ class VmManager:
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
self._tasks_dir = self._base_dir / "tasks"
resolved_cache_dir = cache_dir or default_cache_dir()
self._runtime_paths = runtime_paths
if self._backend_name == "firecracker":
@ -399,6 +623,7 @@ class VmManager:
self._lock = threading.Lock()
self._instances: dict[str, VmInstance] = {}
self._base_dir.mkdir(parents=True, exist_ok=True)
self._tasks_dir.mkdir(parents=True, exist_ok=True)
self._backend = self._build_backend()
def _build_backend(self) -> VmBackend:
@ -443,7 +668,8 @@ class VmManager:
now = time.time()
with self._lock:
self._reap_expired_locked(now)
active_count = len(self._instances)
self._reap_expired_tasks_locked(now)
active_count = len(self._instances) + self._count_tasks_locked()
if active_count >= self._max_active_vms:
raise RuntimeError(
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
@ -501,36 +727,24 @@ class VmManager:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
if instance.state not in {"created", "stopped"}:
raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
self._require_guest_boot_or_opt_in(instance)
if not self._runtime_capabilities.supports_vm_boot:
instance.metadata["execution_mode"] = "host_compat"
instance.metadata["boot_mode"] = "compat"
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
self._backend.start(instance)
instance.state = "started"
self._start_instance_locked(instance)
return self._serialize(instance)
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
if timeout_seconds <= 0:
raise ValueError("timeout_seconds must be positive")
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
if instance.state != "started":
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
self._require_guest_exec_or_opt_in(instance)
if not self._runtime_capabilities.supports_guest_exec:
instance.metadata["execution_mode"] = "host_compat"
exec_result = self._backend.exec(instance, command, timeout_seconds)
execution_mode = instance.metadata.get("execution_mode", "unknown")
exec_instance = instance
exec_result, execution_mode = self._exec_instance(
exec_instance,
command=command,
timeout_seconds=timeout_seconds,
)
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
"environment": instance.environment,
"environment_version": instance.metadata.get("environment_version"),
"environment": exec_instance.environment,
"environment_version": exec_instance.metadata.get("environment_version"),
"command": command,
"stdout": exec_result.stdout,
"stderr": exec_result.stderr,
@ -591,6 +805,154 @@ class VmManager:
del self._instances[vm_id]
return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}
def create_task(
    self,
    *,
    environment: str,
    vcpu_count: int = DEFAULT_VCPU_COUNT,
    mem_mib: int = DEFAULT_MEM_MIB,
    ttl_seconds: int = DEFAULT_TTL_SECONDS,
    network: bool = False,
    allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]:
    """Create, boot, and persist a new task workspace; returns its serialized record.

    Raises ValueError for out-of-range sizing and RuntimeError when the
    active VM/task budget is exhausted or required capabilities are missing.
    On any failure the partially-created task directory is torn down.
    """
    self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
    # Fail fast on an unknown environment before touching the disk.
    get_environment(environment, runtime_paths=self._runtime_paths)
    now = time.time()
    task_id = uuid.uuid4().hex[:12]
    task_dir = self._task_dir(task_id)
    runtime_dir = self._task_runtime_dir(task_id)
    workspace_dir = self._task_workspace_dir(task_id)
    commands_dir = self._task_commands_dir(task_id)
    # exist_ok=False on the root directory guards against task-id collisions.
    task_dir.mkdir(parents=True, exist_ok=False)
    workspace_dir.mkdir(parents=True, exist_ok=True)
    commands_dir.mkdir(parents=True, exist_ok=True)
    instance = VmInstance(
        vm_id=task_id,
        environment=environment,
        vcpu_count=vcpu_count,
        mem_mib=mem_mib,
        ttl_seconds=ttl_seconds,
        created_at=now,
        expires_at=now + ttl_seconds,
        workdir=runtime_dir,
        network_requested=network,
        allow_host_compat=allow_host_compat,
    )
    instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
    instance.metadata["workspace_path"] = TASK_WORKSPACE_GUEST_PATH
    instance.metadata["workspace_host_dir"] = str(workspace_dir)
    try:
        with self._lock:
            # Reap first so expired VMs/tasks do not count against the budget.
            self._reap_expired_locked(now)
            self._reap_expired_tasks_locked(now)
            active_count = len(self._instances) + self._count_tasks_locked()
            if active_count >= self._max_active_vms:
                raise RuntimeError(
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
        # NOTE(review): backend create appears to run outside the lock (it can
        # be slow); the boot + journal bootstrap reacquire it — confirm nesting.
        self._backend.create(instance)
        with self._lock:
            self._start_instance_locked(instance)
            self._require_guest_exec_or_opt_in(instance)
            if self._runtime_capabilities.supports_guest_exec:
                # Pre-create /workspace in the guest so later execs can cd into it.
                self._backend.exec(
                    instance,
                    f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
                    10,
                )
            else:
                instance.metadata["execution_mode"] = "host_compat"
            task = TaskRecord.from_instance(instance)
            self._save_task_locked(task)
            return self._serialize_task(task)
    except Exception:
        # Best-effort teardown; never mask the original failure.
        if runtime_dir.exists():
            try:
                if instance.state == "started":
                    self._backend.stop(instance)
                    instance.state = "stopped"
            except Exception:
                pass
            try:
                self._backend.delete(instance)
            except Exception:
                pass
        shutil.rmtree(task_dir, ignore_errors=True)
        raise
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
    """Run one command in a task's persistent workspace and journal the result.

    Raises ValueError for a non-positive timeout and RuntimeError when the
    task is expired or not in the 'started' state.
    """
    if timeout_seconds <= 0:
        raise ValueError("timeout_seconds must be positive")
    with self._lock:
        task = self._load_task_locked(task_id)
        self._ensure_task_not_expired_locked(task, time.time())
        self._refresh_task_liveness_locked(task)
        if task.state != "started":
            raise RuntimeError(f"task {task_id} must be in 'started' state before task_exec")
        instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
    # NOTE(review): the command itself appears to run outside the lock so a
    # long command does not block other task operations — confirm.
    exec_result, execution_mode = self._exec_instance(
        instance,
        command=command,
        timeout_seconds=timeout_seconds,
        host_workdir=self._task_workspace_dir(task.task_id),
        guest_cwd=TASK_WORKSPACE_GUEST_PATH,
    )
    with self._lock:
        # Reload: the on-disk record may have changed while the command ran.
        task = self._load_task_locked(task_id)
        task.state = instance.state
        task.firecracker_pid = instance.firecracker_pid
        task.last_error = instance.last_error
        task.metadata = dict(instance.metadata)
        entry = self._record_task_command_locked(
            task,
            command=command,
            exec_result=exec_result,
            execution_mode=execution_mode,
            cwd=TASK_WORKSPACE_GUEST_PATH,
        )
        self._save_task_locked(task)
    return {
        "task_id": task_id,
        "environment": task.environment,
        "environment_version": task.metadata.get("environment_version"),
        "command": command,
        "stdout": exec_result.stdout,
        "stderr": exec_result.stderr,
        "exit_code": exec_result.exit_code,
        "duration_ms": exec_result.duration_ms,
        "execution_mode": execution_mode,
        "sequence": entry["sequence"],
        "cwd": TASK_WORKSPACE_GUEST_PATH,
    }
def status_task(self, task_id: str) -> dict[str, Any]:
    """Refresh and persist a task's liveness, then return its serialized state."""
    with self._lock:
        task = self._load_task_locked(task_id)
        self._ensure_task_not_expired_locked(task, time.time())
        # Liveness probe may flip state (e.g. dead firecracker PID); persist it.
        self._refresh_task_liveness_locked(task)
        self._save_task_locked(task)
    return self._serialize_task(task)
def logs_task(self, task_id: str) -> dict[str, Any]:
    """Return the persisted command journal for one task."""
    with self._lock:
        task = self._load_task_locked(task_id)
        self._ensure_task_not_expired_locked(task, time.time())
        # Keep the stored record fresh even on a read-only operation.
        self._refresh_task_liveness_locked(task)
        self._save_task_locked(task)
        entries = self._read_task_logs_locked(task.task_id)
    return {"task_id": task.task_id, "count": len(entries), "entries": entries}
def delete_task(self, task_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
    """Stop the backing sandbox if needed and remove the task's on-disk state."""
    with self._lock:
        task = self._load_task_locked(task_id)
        instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
        if task.state == "started":
            self._backend.stop(instance)
            task.state = "stopped"
        self._backend.delete(instance)
        # ignore_errors: the directory may already be partially removed.
        shutil.rmtree(self._task_dir(task_id), ignore_errors=True)
    return {"task_id": task_id, "deleted": True, "reason": reason}
def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
@ -620,6 +982,28 @@ class VmManager:
"metadata": instance.metadata,
}
def _serialize_task(self, task: TaskRecord) -> dict[str, Any]:
    """Shape a TaskRecord into the public task payload returned by the API."""
    return {
        "task_id": task.task_id,
        "environment": task.environment,
        "environment_version": task.metadata.get("environment_version"),
        "vcpu_count": task.vcpu_count,
        "mem_mib": task.mem_mib,
        "ttl_seconds": task.ttl_seconds,
        "created_at": task.created_at,
        "expires_at": task.expires_at,
        "state": task.state,
        "network_enabled": task.network is not None,
        "allow_host_compat": task.allow_host_compat,
        "guest_ip": task.network.guest_ip if task.network is not None else None,
        "tap_name": task.network.tap_name if task.network is not None else None,
        # "pending" until the first boot/exec decides native vs host_compat.
        "execution_mode": task.metadata.get("execution_mode", "pending"),
        "workspace_path": TASK_WORKSPACE_GUEST_PATH,
        "command_count": task.command_count,
        "last_command": task.last_command,
        "metadata": task.metadata,
    }
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
return
@ -665,3 +1049,184 @@ class VmManager:
vm_id = instance.vm_id
self._reap_expired_locked(now)
raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")
def _start_instance_locked(self, instance: VmInstance) -> None:
    """Transition a VM from 'created'/'stopped' into 'started' via the backend.

    Raises:
        RuntimeError: when the VM is in any other state, or (via the
            opt-in check) when guest boot is unsupported and not opted in.
    """
    if instance.state not in ("created", "stopped"):
        raise RuntimeError(
            f"vm {instance.vm_id} cannot be started from state {instance.state!r}"
        )
    self._require_guest_boot_or_opt_in(instance)
    caps = self._runtime_capabilities
    if not caps.supports_vm_boot:
        # No real guest boot available: record that we fell back to
        # host-compat execution, and why, in the instance metadata.
        meta = instance.metadata
        meta["execution_mode"] = "host_compat"
        meta["boot_mode"] = "compat"
        if caps.reason is not None:
            meta["runtime_reason"] = caps.reason
    self._backend.start(instance)
    instance.state = "started"
def _exec_instance(
    self,
    instance: VmInstance,
    *,
    command: str,
    timeout_seconds: int,
    host_workdir: Path | None = None,
    guest_cwd: str | None = None,
) -> tuple[VmExecResult, str]:
    """Run one command on a started VM and return (result, execution mode).

    In guest mode the command is wrapped so it executes in ``guest_cwd``;
    in host-compat mode the raw command runs in ``host_workdir`` instead.

    Raises:
        ValueError: if ``timeout_seconds`` is not positive.
        RuntimeError: if the VM is not in the 'started' state, or (via the
            opt-in check) guest exec is unsupported and not opted in.
    """
    if timeout_seconds <= 0:
        raise ValueError("timeout_seconds must be positive")
    if instance.state != "started":
        raise RuntimeError(f"vm {instance.vm_id} must be in 'started' state before execution")
    self._require_guest_exec_or_opt_in(instance)
    if self._runtime_capabilities.supports_guest_exec:
        runnable = _wrap_guest_command(command, cwd=guest_cwd)
        effective_workdir: Path | None = None
    else:
        # Host-compat fallback: note the mode and run in a host directory.
        instance.metadata["execution_mode"] = "host_compat"
        runnable = command
        effective_workdir = host_workdir
    result = self._backend.exec(
        instance,
        runnable,
        timeout_seconds,
        workdir=effective_workdir,
    )
    return result, instance.metadata.get("execution_mode", "unknown")
def _task_dir(self, task_id: str) -> Path:
    """Root directory holding all on-disk state for one task."""
    return self._tasks_dir.joinpath(task_id)
def _task_runtime_dir(self, task_id: str) -> Path:
    """Per-task directory used as the backend instance's working dir."""
    return self._task_dir(task_id).joinpath(TASK_RUNTIME_DIRNAME)
def _task_workspace_dir(self, task_id: str) -> Path:
    """Per-task host directory backing the task's workspace."""
    return self._task_dir(task_id).joinpath(TASK_WORKSPACE_DIRNAME)
def _task_commands_dir(self, task_id: str) -> Path:
    """Per-task directory holding the persisted command journal."""
    return self._task_dir(task_id).joinpath(TASK_COMMANDS_DIRNAME)
def _task_metadata_path(self, task_id: str) -> Path:
    """Path of the task's serialized record (``task.json``)."""
    return self._task_dir(task_id).joinpath("task.json")
def _count_tasks_locked(self) -> int:
    """Number of tasks currently registered on disk (one task.json each)."""
    return len(list(self._tasks_dir.glob("*/task.json")))
def _load_task_locked(self, task_id: str) -> TaskRecord:
    """Read and deserialize a task's ``task.json`` from disk.

    Raises:
        ValueError: when no task with ``task_id`` exists.
        RuntimeError: when the stored record is not a JSON object.
    """
    record_path = self._task_metadata_path(task_id)
    if not record_path.exists():
        raise ValueError(f"task {task_id!r} does not exist")
    raw = json.loads(record_path.read_text(encoding="utf-8"))
    if not isinstance(raw, dict):
        raise RuntimeError(f"task record at {record_path} is invalid")
    return TaskRecord.from_payload(raw)
def _save_task_locked(self, task: TaskRecord) -> None:
    """Persist a task record to its ``task.json``, creating parents as needed."""
    target = self._task_metadata_path(task.task_id)
    target.parent.mkdir(parents=True, exist_ok=True)
    # sort_keys keeps the file diff-stable across rewrites.
    serialized = json.dumps(task.to_payload(), indent=2, sort_keys=True)
    target.write_text(serialized, encoding="utf-8")
def _reap_expired_tasks_locked(self, now: float) -> None:
    """Delete every task whose TTL has elapsed, plus any corrupt records.

    A record that is not a JSON object — or that cannot be decoded at all —
    is removed outright, so one bad ``task.json`` cannot wedge the reaper
    (previously a half-written file raised ``json.JSONDecodeError`` and
    aborted the whole sweep). Expired tasks are stopped when still running,
    deleted from the backend, and their directories removed.
    """
    for metadata_path in list(self._tasks_dir.glob("*/task.json")):
        try:
            payload = json.loads(metadata_path.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            # Undecodable records are as unusable as non-dict ones: drop them.
            shutil.rmtree(metadata_path.parent, ignore_errors=True)
            continue
        if not isinstance(payload, dict):
            shutil.rmtree(metadata_path.parent, ignore_errors=True)
            continue
        task = TaskRecord.from_payload(payload)
        if task.expires_at > now:
            continue
        instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
        if task.state == "started":
            self._backend.stop(instance)
            task.state = "stopped"
        self._backend.delete(instance)
        shutil.rmtree(self._task_dir(task.task_id), ignore_errors=True)
def _ensure_task_not_expired_locked(self, task: TaskRecord, now: float) -> None:
    """Raise (after reaping all expired tasks) when ``task`` outlived its TTL."""
    if task.expires_at > now:
        return
    # Capture the id before reaping, which deletes the task's state.
    expired_id = task.task_id
    self._reap_expired_tasks_locked(now)
    raise RuntimeError(f"task {expired_id!r} expired and was automatically deleted")
def _refresh_task_liveness_locked(self, task: TaskRecord) -> None:
    """Mark a 'started' task as stopped when its guest process has died.

    Host-compat tasks have no backing guest process, so they are skipped.
    """
    if task.state != "started":
        return
    if task.metadata.get("execution_mode") == "host_compat":
        return
    if _pid_is_running(task.firecracker_pid):
        return
    # The guest went away underneath us: reflect that in the record.
    task.state = "stopped"
    task.firecracker_pid = None
    task.last_error = "backing guest process is no longer running"
def _record_task_command_locked(
    self,
    task: TaskRecord,
    *,
    command: str,
    exec_result: VmExecResult,
    execution_mode: str,
    cwd: str,
) -> dict[str, Any]:
    """Append one executed command to the task's on-disk journal.

    Writes captured stdout/stderr to sibling files, persists a JSON record
    next to them, and updates ``task.command_count``/``task.last_command``
    in memory. Returns the persisted journal entry.
    """
    seq = task.command_count + 1
    journal_dir = self._task_commands_dir(task.task_id)
    journal_dir.mkdir(parents=True, exist_ok=True)
    stem = f"{seq:06d}"
    stdout_file = journal_dir / f"{stem}.stdout"
    stderr_file = journal_dir / f"{stem}.stderr"
    stdout_file.write_text(exec_result.stdout, encoding="utf-8")
    stderr_file.write_text(exec_result.stderr, encoding="utf-8")
    # Compact summary kept in memory as the task's "last command" view.
    summary = {
        "sequence": seq,
        "command": command,
        "cwd": cwd,
        "exit_code": exec_result.exit_code,
        "duration_ms": exec_result.duration_ms,
        "execution_mode": execution_mode,
    }
    entry: dict[str, Any] = dict(summary)
    entry["stdout_file"] = stdout_file.name
    entry["stderr_file"] = stderr_file.name
    entry["recorded_at"] = time.time()
    record_file = journal_dir / f"{stem}.json"
    record_file.write_text(json.dumps(entry, indent=2, sort_keys=True), encoding="utf-8")
    task.command_count = seq
    task.last_command = summary
    return entry
def _read_task_logs_locked(self, task_id: str) -> list[dict[str, Any]]:
    """Load the task's command journal, inlining captured stdout/stderr.

    Records are returned in sequence order (the journal file names are
    zero-padded sequence numbers). Records that are not JSON objects — or
    that cannot be decoded at all — are skipped, so one corrupt record no
    longer aborts log retrieval for the whole task. Missing output files
    degrade to empty strings.
    """

    def read_capture(name: str) -> str:
        # Empty name or a missing file yields an empty capture.
        if not name:
            return ""
        capture_path = commands_dir / name
        return capture_path.read_text(encoding="utf-8") if capture_path.exists() else ""

    entries: list[dict[str, Any]] = []
    commands_dir = self._task_commands_dir(task_id)
    if not commands_dir.exists():
        return entries
    for record_path in sorted(commands_dir.glob("*.json")):
        try:
            payload = json.loads(record_path.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            # Half-written or corrupt record: skip it, keep the rest.
            continue
        if not isinstance(payload, dict):
            continue
        entry = dict(payload)
        entry["stdout"] = read_capture(str(payload.get("stdout_file", "")))
        entry["stderr"] = read_capture(str(payload.get("stderr_file", "")))
        entries.append(entry)
    return entries