Add persistent task workspace alpha

Start the first workspace milestone toward the task-oriented product without changing the existing one-shot vm_run/pyro run contract.

Add a disk-backed task registry in the manager, auto-started task workspaces rooted at /workspace, repeatable exec that does not auto-clean the sandbox, and persisted command journals, exposed through task create/exec/status/logs/delete across the CLI, Python SDK, and MCP server.

Update the public contract, docs, examples, and version/catalog metadata for 2.1.0, and cover the new surface with manager, CLI, SDK, and MCP tests. Validation: UV_CACHE_DIR=.uv-cache make check and UV_CACHE_DIR=.uv-cache make dist-check.
Thales Maciel 2026-03-11 20:10:10 -03:00
parent 6e16e74fd5
commit 58df176148
19 changed files with 1730 additions and 48 deletions


@@ -2,6 +2,16 @@
All notable user-visible changes to `pyro-mcp` are documented here.
## 2.1.0
- Added the first persistent task workspace alpha across the CLI, Python SDK, and MCP server.
- Shipped `task create`, `task exec`, `task status`, `task logs`, and `task delete` as an additive
  surface alongside the existing one-shot VM contract.
- Made task workspaces persistent across separate CLI/SDK/MCP processes by storing task records on
  disk under the runtime base directory.
- Added per-task command journaling so repeated workspace commands can be inspected through
  `pyro task logs` or the matching SDK/MCP methods.
## 2.0.1
- Fixed the default `pyro env pull` path so empty local profile directories no longer produce


@@ -1,6 +1,6 @@
# pyro-mcp
-`pyro-mcp` runs commands inside ephemeral Firecracker microVMs using curated Linux environments such as `debian:12`.
+`pyro-mcp` runs one-shot commands and repeated task workspaces inside ephemeral Firecracker microVMs using curated Linux environments such as `debian:12`.
[![PyPI version](https://img.shields.io/pypi/v/pyro-mcp.svg)](https://pypi.org/project/pyro-mcp/)
@@ -18,7 +18,7 @@ It exposes the same runtime in three public forms:
- First run transcript: [docs/first-run.md](docs/first-run.md)
- Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif)
- PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/)
-- What's new in 2.0.1: [CHANGELOG.md#201](CHANGELOG.md#201)
+- What's new in 2.1.0: [CHANGELOG.md#210](CHANGELOG.md#210)
- Host requirements: [docs/host-requirements.md](docs/host-requirements.md)
- Integration targets: [docs/integrations.md](docs/integrations.md)
- Public contract: [docs/public-contract.md](docs/public-contract.md)

@@ -55,7 +55,7 @@ What success looks like:
```bash
Platform: linux-x86_64
Runtime: PASS
-Catalog version: 2.0.0
+Catalog version: 2.1.0
...
[pull] phase=install environment=debian:12
[pull] phase=ready environment=debian:12
@@ -74,6 +74,7 @@ access to `registry-1.docker.io`, and needs local cache space for the guest imag
After the quickstart works:
- prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo`
- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12`
- move to Python or MCP via [docs/integrations.md](docs/integrations.md)
## Supported Hosts
@@ -127,7 +128,7 @@ uvx --from pyro-mcp pyro env list
Expected output:
```bash
-Catalog version: 2.0.0
+Catalog version: 2.1.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -191,11 +192,27 @@ When you are done evaluating and want to remove stale cached environments, run `
If you prefer a fuller copy-pasteable transcript, see [docs/first-run.md](docs/first-run.md).
The walkthrough GIF above was rendered from [docs/assets/first-run.tape](docs/assets/first-run.tape) using [scripts/render_tape.sh](scripts/render_tape.sh).
## Persistent Tasks
Use `pyro run` for one-shot commands. Use `pyro task ...` when you need repeated commands in one
workspace without recreating the sandbox every time.
```bash
pyro task create debian:12
pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
pyro task exec TASK_ID -- cat note.txt
pyro task logs TASK_ID
pyro task delete TASK_ID
```
Task workspaces start in `/workspace` and keep command history until you delete them. For machine
consumption, add `--json` and read the returned `task_id`.
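For scripting, the `task_id` can be pulled out of the `--json` payload. A minimal sketch — the sample payload below is illustrative and only the documented `task_id` field is relied on:

```python
import json

def task_id_from_create(raw: str) -> str:
    """Extract the task identifier from `pyro task create ... --json` output."""
    payload = json.loads(raw)
    return str(payload["task_id"])

# Illustrative payload; the real output carries additional fields.
sample = '{"task_id": "task-0001", "environment": "debian:12", "state": "started"}'
print(task_id_from_create(sample))  # task-0001
```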
## Public Interfaces
The public user-facing interface is `pyro` and `Pyro`. After the CLI validation path works, you can choose one of three surfaces:
-- `pyro` for direct CLI usage
+- `pyro` for direct CLI usage, including one-shot `run` and persistent `task` workflows
- `from pyro_mcp import Pyro` for Python orchestration
- `pyro mcp serve` for MCP clients
@@ -325,6 +342,22 @@ print(pyro.list_environments())
print(pyro.inspect_environment("debian:12"))
```
For repeated commands in one workspace:

```python
from pyro_mcp import Pyro

pyro = Pyro()
task = pyro.create_task(environment="debian:12")
task_id = task["task_id"]
try:
    pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt")
    result = pyro.exec_task(task_id, command="cat note.txt")
    print(result["stdout"], end="")
finally:
    pyro.delete_task(task_id)
```
## MCP Tools
Primary agent-facing tool:

@@ -343,10 +376,19 @@ Advanced lifecycle tools:
- `vm_network_info(vm_id)`
- `vm_reap_expired()`
Persistent workspace tools:
- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)`
- `task_exec(task_id, command, timeout_seconds=30)`
- `task_status(task_id)`
- `task_logs(task_id)`
- `task_delete(task_id)`
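Assuming a standard MCP transport, a `task_exec` invocation would look like this on the wire (the `task_id` value is hypothetical; the tool name and argument names come from the list above):

```json
{
  "jsonrpc": "2.0",
  "id": 1,
  "method": "tools/call",
  "params": {
    "name": "task_exec",
    "arguments": {
      "task_id": "task-0001",
      "command": "cat note.txt",
      "timeout_seconds": 30
    }
  }
}
```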
## Integration Examples
- Python one-shot SDK example: [examples/python_run.py](examples/python_run.py)
- Python lifecycle example: [examples/python_lifecycle.py](examples/python_lifecycle.py)
- Python task workspace example: [examples/python_task.py](examples/python_task.py)
- MCP client config example: [examples/mcp_client_config.md](examples/mcp_client_config.md)
- Claude Desktop MCP config: [examples/claude_desktop_mcp_config.json](examples/claude_desktop_mcp_config.json)
- Cursor MCP config: [examples/cursor_mcp_config.json](examples/cursor_mcp_config.json)


@@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes
```bash
$ uvx --from pyro-mcp pyro env list
-Catalog version: 2.0.0
+Catalog version: 2.1.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -70,11 +70,32 @@ deterministic structured result.
```bash
$ uvx --from pyro-mcp pyro demo
$ uvx --from pyro-mcp pyro task create debian:12
$ uvx --from pyro-mcp pyro mcp serve
```
`pyro demo` proves the one-shot create/start/exec/delete VM lifecycle works end to end.
When you need repeated commands in one sandbox, switch to `pyro task ...`:
```bash
$ uvx --from pyro-mcp pyro task create debian:12
Task: ...
Environment: debian:12
State: started
Workspace: /workspace
Execution mode: guest_vsock
Resources: 1 vCPU / 1024 MiB
Command count: 0
$ uvx --from pyro-mcp pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
[task-exec] task_id=... sequence=1 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=...
$ uvx --from pyro-mcp pyro task exec TASK_ID -- cat note.txt
hello from task
[task-exec] task_id=... sequence=2 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=...
```
Example output:
```json


@@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list
Expected output:
```bash
-Catalog version: 2.0.0
+Catalog version: 2.1.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -174,10 +174,26 @@ pyro run debian:12 -- git --version
After the CLI path works, you can move on to:
- persistent workspaces: `pyro task create debian:12`
- MCP: `pyro mcp serve`
- Python SDK: `from pyro_mcp import Pyro`
- Demos: `pyro demo` or `pyro demo --network`
## Persistent Task Workspace
Use `pyro task ...` when you need repeated commands in one sandbox instead of one-shot `pyro run`.
```bash
pyro task create debian:12
pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
pyro task exec TASK_ID -- cat note.txt
pyro task logs TASK_ID
pyro task delete TASK_ID
```
Task commands default to the persistent `/workspace` directory inside the guest. If you need the
task identifier programmatically, use `--json` and read the `task_id` field.
## Contributor Clone
```bash


@@ -7,7 +7,7 @@ CLI path in [install.md](install.md) or [first-run.md](first-run.md).
## Recommended Default
-Use `vm_run` first.
+Use `vm_run` first for one-shot commands.
That keeps the model-facing contract small:
@@ -16,7 +16,8 @@ That keeps the model-facing contract small:
- one ephemeral VM
- automatic cleanup
-Only move to lifecycle tools when the agent truly needs VM state across multiple calls.
+Move to `task_*` only when the agent truly needs repeated commands in one workspace across
+multiple calls.
## OpenAI Responses API
@@ -29,6 +30,7 @@ Best when:
Recommended surface:
- `vm_run`
- `task_create` + `task_exec` when the agent needs persistent workspace state
Canonical example:
@@ -63,17 +65,20 @@ Best when:
Recommended default:
- `Pyro.run_in_vm(...)`
- `Pyro.create_task(...)` + `Pyro.exec_task(...)` when repeated workspace commands are required
Lifecycle note:
- `Pyro.exec_vm(...)` runs one command and auto-cleans the VM afterward
- use `create_vm(...)` + `start_vm(...)` only when you need pre-exec inspection or status before
  that final exec
- use `create_task(...)` when the agent needs repeated commands in one persistent `/workspace`
Examples:
- [examples/python_run.py](../examples/python_run.py)
- [examples/python_lifecycle.py](../examples/python_lifecycle.py)
- [examples/python_task.py](../examples/python_task.py)
## Agent Framework Wrappers
@@ -91,8 +96,8 @@ Best when:
Recommended pattern:
- keep the framework wrapper thin
-- map framework tool input directly onto `vm_run`
-- avoid exposing lifecycle tools unless the framework truly needs them
+- map one-shot framework tool input directly onto `vm_run`
+- expose `task_*` only when the framework truly needs repeated commands in one workspace
Concrete example:


@@ -19,6 +19,11 @@ Top-level commands:
- `pyro env prune`
- `pyro mcp serve`
- `pyro run`
- `pyro task create`
- `pyro task exec`
- `pyro task status`
- `pyro task logs`
- `pyro task delete`
- `pyro doctor`
- `pyro demo`
- `pyro demo ollama`
@@ -40,6 +45,9 @@ Behavioral guarantees:
- `pyro run` fails if guest boot or guest exec is unavailable unless `--allow-host-compat` is set.
- `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`.
- `pyro demo ollama` prints log lines plus a final summary line.
- `pyro task create` auto-starts a persistent workspace.
- `pyro task exec` runs in the persistent `/workspace` for that task and does not auto-clean.
- `pyro task logs` returns persisted command history for that task until `pyro task delete`.
## Python SDK Contract

@@ -56,11 +64,16 @@ Supported public entrypoints:
- `Pyro.inspect_environment(environment)`
- `Pyro.prune_environments()`
- `Pyro.create_vm(...)`
- `Pyro.create_task(...)`
- `Pyro.start_vm(vm_id)`
- `Pyro.exec_vm(vm_id, *, command, timeout_seconds=30)`
- `Pyro.exec_task(task_id, *, command, timeout_seconds=30)`
- `Pyro.stop_vm(vm_id)`
- `Pyro.delete_vm(vm_id)`
- `Pyro.delete_task(task_id)`
- `Pyro.status_vm(vm_id)`
- `Pyro.status_task(task_id)`
- `Pyro.logs_task(task_id)`
- `Pyro.network_info_vm(vm_id)`
- `Pyro.reap_expired()`
- `Pyro.run_in_vm(...)`
@@ -73,11 +86,16 @@ Stable public method names:
- `inspect_environment(environment)`
- `prune_environments()`
- `create_vm(...)`
- `create_task(...)`
- `start_vm(vm_id)`
- `exec_vm(vm_id, *, command, timeout_seconds=30)`
- `exec_task(task_id, *, command, timeout_seconds=30)`
- `stop_vm(vm_id)`
- `delete_vm(vm_id)`
- `delete_task(task_id)`
- `status_vm(vm_id)`
- `status_task(task_id)`
- `logs_task(task_id)`
- `network_info_vm(vm_id)`
- `reap_expired()`
- `run_in_vm(...)`
@@ -85,8 +103,11 @@ Stable public method names:
Behavioral defaults:
- `Pyro.create_vm(...)` and `Pyro.run_in_vm(...)` default to `vcpu_count=1` and `mem_mib=1024`.
- `Pyro.create_task(...)` defaults to `vcpu_count=1` and `mem_mib=1024`.
- `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`.
- `allow_host_compat` defaults to `False` on `create_task(...)`.
- `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes.
- `Pyro.exec_task(...)` runs one command in the persistent task workspace and leaves the task alive.

## MCP Contract
@@ -106,11 +127,22 @@ Advanced lifecycle tools:
- `vm_network_info`
- `vm_reap_expired`
Task workspace tools:
- `task_create`
- `task_exec`
- `task_status`
- `task_logs`
- `task_delete`
Behavioral defaults:
- `vm_run` and `vm_create` default to `vcpu_count=1` and `mem_mib=1024`.
- `task_create` defaults to `vcpu_count=1` and `mem_mib=1024`.
- `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`.
- `task_create` exposes `allow_host_compat`, which defaults to `false`.
- `vm_exec` runs one command and auto-cleans that VM after the exec completes.
- `task_exec` runs one command in a persistent `/workspace` and leaves the task alive.

## Versioning Rule

examples/python_task.py (new file, 21 lines)

@@ -0,0 +1,21 @@
from __future__ import annotations

from pyro_mcp import Pyro


def main() -> None:
    pyro = Pyro()
    created = pyro.create_task(environment="debian:12")
    task_id = str(created["task_id"])
    try:
        pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt")
        result = pyro.exec_task(task_id, command="cat note.txt")
        print(result["stdout"], end="")
        logs = pyro.logs_task(task_id)
        print(f"task_id={task_id} command_count={logs['count']}")
    finally:
        pyro.delete_task(task_id)


if __name__ == "__main__":
    main()


@@ -1,6 +1,6 @@
[project]
name = "pyro-mcp"
-version = "2.0.1"
+version = "2.1.0"
description = "Curated Linux environments for ephemeral Firecracker-backed VM execution."
readme = "README.md"
license = { file = "LICENSE" }


@@ -77,6 +77,43 @@ class Pyro:
    def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
        return self._manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)

    def create_task(
        self,
        *,
        environment: str,
        vcpu_count: int = DEFAULT_VCPU_COUNT,
        mem_mib: int = DEFAULT_MEM_MIB,
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
        network: bool = False,
        allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
    ) -> dict[str, Any]:
        return self._manager.create_task(
            environment=environment,
            vcpu_count=vcpu_count,
            mem_mib=mem_mib,
            ttl_seconds=ttl_seconds,
            network=network,
            allow_host_compat=allow_host_compat,
        )

    def exec_task(
        self,
        task_id: str,
        *,
        command: str,
        timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
    ) -> dict[str, Any]:
        return self._manager.exec_task(task_id, command=command, timeout_seconds=timeout_seconds)

    def status_task(self, task_id: str) -> dict[str, Any]:
        return self._manager.status_task(task_id)

    def logs_task(self, task_id: str) -> dict[str, Any]:
        return self._manager.logs_task(task_id)

    def delete_task(self, task_id: str) -> dict[str, Any]:
        return self._manager.delete_task(task_id)

    def stop_vm(self, vm_id: str) -> dict[str, Any]:
        return self._manager.stop_vm(vm_id)
@@ -200,4 +237,47 @@ class Pyro:
            """Delete VMs whose TTL has expired."""
            return self.reap_expired()

        @server.tool()
        async def task_create(
            environment: str,
            vcpu_count: int = DEFAULT_VCPU_COUNT,
            mem_mib: int = DEFAULT_MEM_MIB,
            ttl_seconds: int = DEFAULT_TTL_SECONDS,
            network: bool = False,
            allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
        ) -> dict[str, Any]:
            """Create and start a persistent task workspace."""
            return self.create_task(
                environment=environment,
                vcpu_count=vcpu_count,
                mem_mib=mem_mib,
                ttl_seconds=ttl_seconds,
                network=network,
                allow_host_compat=allow_host_compat,
            )

        @server.tool()
        async def task_exec(
            task_id: str,
            command: str,
            timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
        ) -> dict[str, Any]:
            """Run one command inside an existing task workspace."""
            return self.exec_task(task_id, command=command, timeout_seconds=timeout_seconds)

        @server.tool()
        async def task_status(task_id: str) -> dict[str, Any]:
            """Inspect task state and latest command metadata."""
            return self.status_task(task_id)

        @server.tool()
        async def task_logs(task_id: str) -> dict[str, Any]:
            """Return persisted command history for one task."""
            return self.logs_task(task_id)

        @server.tool()
        async def task_delete(task_id: str) -> dict[str, Any]:
            """Delete a task workspace and its backing sandbox."""
            return self.delete_task(task_id)

        return server


@@ -17,6 +17,7 @@ from pyro_mcp.vm_environments import DEFAULT_CATALOG_VERSION
from pyro_mcp.vm_manager import (
    DEFAULT_MEM_MIB,
    DEFAULT_VCPU_COUNT,
    TASK_WORKSPACE_GUEST_PATH,
)
@@ -149,6 +150,67 @@ def _print_doctor_human(payload: dict[str, Any]) -> None:
        print(f"- {issue}")


def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None:
    print(f"{action}: {str(payload.get('task_id', 'unknown'))}")
    print(f"Environment: {str(payload.get('environment', 'unknown'))}")
    print(f"State: {str(payload.get('state', 'unknown'))}")
    print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
    print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}")
    print(
        f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / "
        f"{int(payload.get('mem_mib', 0))} MiB"
    )
    print(f"Command count: {int(payload.get('command_count', 0))}")
    last_command = payload.get("last_command")
    if isinstance(last_command, dict):
        print(
            "Last command: "
            f"{str(last_command.get('command', 'unknown'))} "
            f"(exit_code={int(last_command.get('exit_code', -1))})"
        )


def _print_task_exec_human(payload: dict[str, Any]) -> None:
    stdout = str(payload.get("stdout", ""))
    stderr = str(payload.get("stderr", ""))
    _write_stream(stdout, stream=sys.stdout)
    _write_stream(stderr, stream=sys.stderr)
    print(
        "[task-exec] "
        f"task_id={str(payload.get('task_id', 'unknown'))} "
        f"sequence={int(payload.get('sequence', 0))} "
        f"cwd={str(payload.get('cwd', TASK_WORKSPACE_GUEST_PATH))} "
        f"execution_mode={str(payload.get('execution_mode', 'unknown'))} "
        f"exit_code={int(payload.get('exit_code', 1))} "
        f"duration_ms={int(payload.get('duration_ms', 0))}",
        file=sys.stderr,
        flush=True,
    )


def _print_task_logs_human(payload: dict[str, Any]) -> None:
    entries = payload.get("entries")
    if not isinstance(entries, list) or not entries:
        print("No task logs found.")
        return
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        print(
            f"#{int(entry.get('sequence', 0))} "
            f"exit_code={int(entry.get('exit_code', -1))} "
            f"duration_ms={int(entry.get('duration_ms', 0))} "
            f"cwd={str(entry.get('cwd', TASK_WORKSPACE_GUEST_PATH))}"
        )
        print(f"$ {str(entry.get('command', ''))}")
        stdout = str(entry.get("stdout", ""))
        stderr = str(entry.get("stderr", ""))
        if stdout != "":
            print(stdout, end="" if stdout.endswith("\n") else "\n")
        if stderr != "":
            print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr)
class _HelpFormatter(
    argparse.RawDescriptionHelpFormatter,
    argparse.ArgumentDefaultsHelpFormatter,

@@ -178,6 +240,9 @@ def _build_parser() -> argparse.ArgumentParser:
            pyro env pull debian:12
            pyro run debian:12 -- git --version

            Need repeated commands in one workspace after that?
            pyro task create debian:12

            Use `pyro mcp serve` only after the CLI validation path works.
            """
        ),
@@ -371,6 +436,152 @@ def _build_parser() -> argparse.ArgumentParser:
        ),
    )
    task_parser = subparsers.add_parser(
        "task",
        help="Manage persistent task workspaces.",
        description=(
            "Create a persistent workspace when you need repeated commands in one "
            "sandbox instead of one-shot `pyro run`."
        ),
        epilog=dedent(
            """
            Examples:
              pyro task create debian:12
              pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
              pyro task logs TASK_ID
            """
        ),
        formatter_class=_HelpFormatter,
    )
    task_subparsers = task_parser.add_subparsers(dest="task_command", required=True, metavar="TASK")

    task_create_parser = task_subparsers.add_parser(
        "create",
        help="Create and start a persistent task workspace.",
        description="Create a task workspace that stays alive across repeated exec calls.",
        epilog="Example:\n pyro task create debian:12",
        formatter_class=_HelpFormatter,
    )
    task_create_parser.add_argument(
        "environment",
        metavar="ENVIRONMENT",
        help="Curated environment to boot, for example `debian:12`.",
    )
    task_create_parser.add_argument(
        "--vcpu-count",
        type=int,
        default=DEFAULT_VCPU_COUNT,
        help="Number of virtual CPUs to allocate to the task guest.",
    )
    task_create_parser.add_argument(
        "--mem-mib",
        type=int,
        default=DEFAULT_MEM_MIB,
        help="Guest memory allocation in MiB.",
    )
    task_create_parser.add_argument(
        "--ttl-seconds",
        type=int,
        default=600,
        help="Time-to-live for the task before automatic cleanup.",
    )
    task_create_parser.add_argument(
        "--network",
        action="store_true",
        help="Enable outbound guest networking for the task guest.",
    )
    task_create_parser.add_argument(
        "--allow-host-compat",
        action="store_true",
        help=(
            "Opt into host-side compatibility execution if guest boot or guest exec "
            "is unavailable."
        ),
    )
    task_create_parser.add_argument(
        "--json",
        action="store_true",
        help="Print structured JSON instead of human-readable output.",
    )

    task_exec_parser = task_subparsers.add_parser(
        "exec",
        help="Run one command inside an existing task workspace.",
        description="Run one non-interactive command in the persistent `/workspace` for a task.",
        epilog="Example:\n pyro task exec TASK_ID -- cat note.txt",
        formatter_class=_HelpFormatter,
    )
    task_exec_parser.add_argument("task_id", metavar="TASK_ID", help="Persistent task identifier.")
    task_exec_parser.add_argument(
        "--timeout-seconds",
        type=int,
        default=30,
        help="Maximum time allowed for the task command.",
    )
    task_exec_parser.add_argument(
        "--json",
        action="store_true",
        help="Print structured JSON instead of human-readable output.",
    )
    task_exec_parser.add_argument(
        "command_args",
        nargs="*",
        metavar="ARG",
        help=(
            "Command and arguments to run inside the task workspace. Prefix them with `--`, "
            "for example `pyro task exec TASK_ID -- cat note.txt`."
        ),
    )

    task_status_parser = task_subparsers.add_parser(
        "status",
        help="Inspect one task workspace.",
        description="Show task state, sizing, workspace path, and latest command metadata.",
        epilog="Example:\n pyro task status TASK_ID",
        formatter_class=_HelpFormatter,
    )
    task_status_parser.add_argument(
        "task_id",
        metavar="TASK_ID",
        help="Persistent task identifier.",
    )
    task_status_parser.add_argument(
        "--json",
        action="store_true",
        help="Print structured JSON instead of human-readable output.",
    )

    task_logs_parser = task_subparsers.add_parser(
        "logs",
        help="Show command history for one task.",
        description="Show persisted command history, including stdout and stderr, for one task.",
        epilog="Example:\n pyro task logs TASK_ID",
        formatter_class=_HelpFormatter,
    )
    task_logs_parser.add_argument(
        "task_id",
        metavar="TASK_ID",
        help="Persistent task identifier.",
    )
    task_logs_parser.add_argument(
        "--json",
        action="store_true",
        help="Print structured JSON instead of human-readable output.",
    )

    task_delete_parser = task_subparsers.add_parser(
        "delete",
        help="Delete one task workspace.",
        description="Stop the backing sandbox if needed and remove the task workspace.",
        epilog="Example:\n pyro task delete TASK_ID",
        formatter_class=_HelpFormatter,
    )
    task_delete_parser.add_argument(
        "task_id",
        metavar="TASK_ID",
        help="Persistent task identifier.",
    )
    task_delete_parser.add_argument(
        "--json",
        action="store_true",
        help="Print structured JSON instead of human-readable output.",
    )
    doctor_parser = subparsers.add_parser(
        "doctor",
        help="Inspect runtime and host diagnostics.",
@@ -451,7 +662,7 @@ def _require_command(command_args: list[str]) -> str:
    if command_args and command_args[0] == "--":
        command_args = command_args[1:]
    if not command_args:
-        raise ValueError("command is required after `pyro run --`")
+        raise ValueError("command is required after `--`")
    return " ".join(command_args)
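The relaxed error message reflects that this helper now serves both `pyro run` and `pyro task exec`. Its separator handling can be exercised on its own; a minimal sketch mirroring the logic above:

```python
def require_command(command_args: list[str]) -> str:
    # Drop an optional leading `--` separator, then join argv into one command string.
    if command_args and command_args[0] == "--":
        command_args = command_args[1:]
    if not command_args:
        raise ValueError("command is required after `--`")
    return " ".join(command_args)

print(require_command(["--", "cat", "note.txt"]))  # cat note.txt
print(require_command(["git", "--version"]))       # git --version
```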
@ -544,6 +755,70 @@ def main() -> None:
if exit_code != 0: if exit_code != 0:
raise SystemExit(exit_code) raise SystemExit(exit_code)
return return
if args.command == "task":
if args.task_command == "create":
payload = pyro.create_task(
environment=args.environment,
vcpu_count=args.vcpu_count,
mem_mib=args.mem_mib,
ttl_seconds=args.ttl_seconds,
network=args.network,
allow_host_compat=args.allow_host_compat,
)
if bool(args.json):
_print_json(payload)
else:
_print_task_summary_human(payload, action="Task")
return
if args.task_command == "exec":
command = _require_command(args.command_args)
if bool(args.json):
try:
payload = pyro.exec_task(
args.task_id,
command=command,
timeout_seconds=args.timeout_seconds,
)
except Exception as exc: # noqa: BLE001
_print_json({"ok": False, "error": str(exc)})
raise SystemExit(1) from exc
_print_json(payload)
else:
try:
payload = pyro.exec_task(
args.task_id,
command=command,
timeout_seconds=args.timeout_seconds,
)
except Exception as exc: # noqa: BLE001
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
_print_task_exec_human(payload)
exit_code = int(payload.get("exit_code", 1))
if exit_code != 0:
raise SystemExit(exit_code)
return
if args.task_command == "status":
payload = pyro.status_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
_print_task_summary_human(payload, action="Task")
return
if args.task_command == "logs":
payload = pyro.logs_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
_print_task_logs_human(payload)
return
if args.task_command == "delete":
payload = pyro.delete_task(args.task_id)
if bool(args.json):
_print_json(payload)
else:
print(f"Deleted task: {str(payload.get('task_id', 'unknown'))}")
return
if args.command == "doctor":
payload = doctor_report(platform=args.platform)
if bool(args.json):

View file

@ -2,9 +2,10 @@
from __future__ import annotations
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status")
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
"--mem-mib",
@ -17,17 +18,22 @@ PUBLIC_CLI_RUN_FLAGS = (
PUBLIC_SDK_METHODS = (
"create_server",
"create_task",
"create_vm",
"delete_task",
"delete_vm",
"exec_task",
"exec_vm",
"inspect_environment",
"list_environments",
"logs_task",
"network_info_vm",
"prune_environments",
"pull_environment",
"reap_expired",
"run_in_vm",
"start_vm",
"status_task",
"status_vm",
"stop_vm",
)
@ -43,4 +49,9 @@ PUBLIC_MCP_TOOLS = (
"vm_start",
"vm_status",
"vm_stop",
"task_create",
"task_delete",
"task_exec",
"task_logs",
"task_status",
)

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.1.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -1,8 +1,10 @@
"""Lifecycle manager for ephemeral VM environments and persistent tasks."""
from __future__ import annotations
import json
import os
import shlex
import shutil
import signal
import subprocess
@ -11,7 +13,7 @@ import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal, cast
from pyro_mcp.runtime import (
RuntimeCapabilities,
@ -32,6 +34,12 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
TASK_LAYOUT_VERSION = 1
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"
@dataclass
class VmInstance:
@ -54,6 +62,116 @@ class VmInstance:
network: NetworkConfig | None = None
@dataclass
class TaskRecord:
"""Persistent task metadata stored on disk."""
task_id: str
environment: str
vcpu_count: int
mem_mib: int
ttl_seconds: int
created_at: float
expires_at: float
state: VmState
network_requested: bool
allow_host_compat: bool
firecracker_pid: int | None = None
last_error: str | None = None
metadata: dict[str, str] = field(default_factory=dict)
network: NetworkConfig | None = None
command_count: int = 0
last_command: dict[str, Any] | None = None
@classmethod
def from_instance(
cls,
instance: VmInstance,
*,
command_count: int = 0,
last_command: dict[str, Any] | None = None,
) -> TaskRecord:
return cls(
task_id=instance.vm_id,
environment=instance.environment,
vcpu_count=instance.vcpu_count,
mem_mib=instance.mem_mib,
ttl_seconds=instance.ttl_seconds,
created_at=instance.created_at,
expires_at=instance.expires_at,
state=instance.state,
network_requested=instance.network_requested,
allow_host_compat=instance.allow_host_compat,
firecracker_pid=instance.firecracker_pid,
last_error=instance.last_error,
metadata=dict(instance.metadata),
network=instance.network,
command_count=command_count,
last_command=last_command,
)
def to_instance(self, *, workdir: Path) -> VmInstance:
return VmInstance(
vm_id=self.task_id,
environment=self.environment,
vcpu_count=self.vcpu_count,
mem_mib=self.mem_mib,
ttl_seconds=self.ttl_seconds,
created_at=self.created_at,
expires_at=self.expires_at,
workdir=workdir,
state=self.state,
network_requested=self.network_requested,
allow_host_compat=self.allow_host_compat,
firecracker_pid=self.firecracker_pid,
last_error=self.last_error,
metadata=dict(self.metadata),
network=self.network,
)
def to_payload(self) -> dict[str, Any]:
return {
"layout_version": TASK_LAYOUT_VERSION,
"task_id": self.task_id,
"environment": self.environment,
"vcpu_count": self.vcpu_count,
"mem_mib": self.mem_mib,
"ttl_seconds": self.ttl_seconds,
"created_at": self.created_at,
"expires_at": self.expires_at,
"state": self.state,
"network_requested": self.network_requested,
"allow_host_compat": self.allow_host_compat,
"firecracker_pid": self.firecracker_pid,
"last_error": self.last_error,
"metadata": self.metadata,
"network": _serialize_network(self.network),
"command_count": self.command_count,
"last_command": self.last_command,
}
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> TaskRecord:
return cls(
task_id=str(payload["task_id"]),
environment=str(payload["environment"]),
vcpu_count=int(payload["vcpu_count"]),
mem_mib=int(payload["mem_mib"]),
ttl_seconds=int(payload["ttl_seconds"]),
created_at=float(payload["created_at"]),
expires_at=float(payload["expires_at"]),
state=cast(VmState, str(payload.get("state", "stopped"))),
network_requested=bool(payload.get("network_requested", False)),
allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
firecracker_pid=_optional_int(payload.get("firecracker_pid")),
last_error=_optional_str(payload.get("last_error")),
metadata=_string_dict(payload.get("metadata")),
network=_deserialize_network(payload.get("network")),
command_count=int(payload.get("command_count", 0)),
last_command=_optional_dict(payload.get("last_command")),
)
@dataclass(frozen=True)
class VmExecResult:
"""Command execution output."""
@ -64,6 +182,72 @@ class VmExecResult:
duration_ms: int
def _optional_int(value: object) -> int | None:
if value is None:
return None
if isinstance(value, bool):
return int(value)
if isinstance(value, int):
return value
if isinstance(value, float):
return int(value)
if isinstance(value, str):
return int(value)
raise TypeError("expected integer-compatible payload")
def _optional_str(value: object) -> str | None:
if value is None:
return None
return str(value)
def _optional_dict(value: object) -> dict[str, Any] | None:
if value is None:
return None
if not isinstance(value, dict):
raise TypeError("expected dictionary payload")
return dict(value)
def _string_dict(value: object) -> dict[str, str]:
if not isinstance(value, dict):
return {}
return {str(key): str(item) for key, item in value.items()}
def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
if network is None:
return None
return {
"vm_id": network.vm_id,
"tap_name": network.tap_name,
"guest_ip": network.guest_ip,
"gateway_ip": network.gateway_ip,
"subnet_cidr": network.subnet_cidr,
"mac_address": network.mac_address,
"dns_servers": list(network.dns_servers),
}
def _deserialize_network(payload: object) -> NetworkConfig | None:
if payload is None:
return None
if not isinstance(payload, dict):
raise TypeError("expected dictionary payload")
dns_servers = payload.get("dns_servers", [])
dns_values = tuple(str(item) for item in dns_servers) if isinstance(dns_servers, list) else ()
return NetworkConfig(
vm_id=str(payload["vm_id"]),
tap_name=str(payload["tap_name"]),
guest_ip=str(payload["guest_ip"]),
gateway_ip=str(payload["gateway_ip"]),
subnet_cidr=str(payload["subnet_cidr"]),
mac_address=str(payload["mac_address"]),
dns_servers=dns_values,
)
def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
started = time.monotonic()
env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
@ -109,6 +293,25 @@ def _copy_rootfs(source: Path, dest: Path) -> str:
return "copy2"
def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
if cwd is None:
return command
quoted_cwd = shlex.quote(cwd)
return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"
def _pid_is_running(pid: int | None) -> bool:
if pid is None:
return False
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
return True
class VmBackend:
"""Backend interface for lifecycle operations."""
@ -119,7 +322,12 @@
raise NotImplementedError
def exec(  # pragma: no cover
self,
instance: VmInstance,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> VmExecResult:
raise NotImplementedError
@ -140,8 +348,15 @@ class MockBackend(VmBackend):
marker_path = instance.workdir / ".started"
marker_path.write_text("started\n", encoding="utf-8")
def exec(
self,
instance: VmInstance,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> VmExecResult:
return _run_host_command(workdir or instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
marker_path = instance.workdir / ".stopped"
@ -256,6 +471,7 @@ class FirecrackerBackend(VmBackend):  # pragma: no cover
stdout=serial_fp,
stderr=subprocess.STDOUT,
text=True,
start_new_session=True,
)
self._processes[instance.vm_id] = process
time.sleep(2)
@ -273,7 +489,14 @@
)
instance.metadata["boot_mode"] = "native"
def exec(
self,
instance: VmInstance,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> VmExecResult:
if self._runtime_capabilities.supports_guest_exec:
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
@ -302,7 +525,7 @@
duration_ms=response.duration_ms,
)
instance.metadata["execution_mode"] = "host_compat"
return _run_host_command(workdir or instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
process = self._processes.pop(instance.vm_id, None)
@ -341,7 +564,7 @@
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments and tasks."""
MIN_VCPUS = 1
MAX_VCPUS = 8
@ -367,6 +590,7 @@ class VmManager:
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
self._tasks_dir = self._base_dir / "tasks"
resolved_cache_dir = cache_dir or default_cache_dir()
self._runtime_paths = runtime_paths
if self._backend_name == "firecracker":
@ -399,6 +623,7 @@
self._lock = threading.Lock()
self._instances: dict[str, VmInstance] = {}
self._base_dir.mkdir(parents=True, exist_ok=True)
self._tasks_dir.mkdir(parents=True, exist_ok=True)
self._backend = self._build_backend()
def _build_backend(self) -> VmBackend:
@ -443,7 +668,8 @@
now = time.time()
with self._lock:
self._reap_expired_locked(now)
self._reap_expired_tasks_locked(now)
active_count = len(self._instances) + self._count_tasks_locked()
if active_count >= self._max_active_vms:
raise RuntimeError(
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
@ -501,36 +727,24 @@ class VmManager:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
self._start_instance_locked(instance)
return self._serialize(instance)
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
exec_instance = instance
exec_result, execution_mode = self._exec_instance(
exec_instance,
command=command,
timeout_seconds=timeout_seconds,
)
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
"environment": exec_instance.environment,
"environment_version": exec_instance.metadata.get("environment_version"),
"command": command,
"stdout": exec_result.stdout,
"stderr": exec_result.stderr,
@ -591,6 +805,154 @@
del self._instances[vm_id]
return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}
def create_task(
self,
*,
environment: str,
vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_environment(environment, runtime_paths=self._runtime_paths)
now = time.time()
task_id = uuid.uuid4().hex[:12]
task_dir = self._task_dir(task_id)
runtime_dir = self._task_runtime_dir(task_id)
workspace_dir = self._task_workspace_dir(task_id)
commands_dir = self._task_commands_dir(task_id)
task_dir.mkdir(parents=True, exist_ok=False)
workspace_dir.mkdir(parents=True, exist_ok=True)
commands_dir.mkdir(parents=True, exist_ok=True)
instance = VmInstance(
vm_id=task_id,
environment=environment,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
created_at=now,
expires_at=now + ttl_seconds,
workdir=runtime_dir,
network_requested=network,
allow_host_compat=allow_host_compat,
)
instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
instance.metadata["workspace_path"] = TASK_WORKSPACE_GUEST_PATH
instance.metadata["workspace_host_dir"] = str(workspace_dir)
try:
with self._lock:
self._reap_expired_locked(now)
self._reap_expired_tasks_locked(now)
active_count = len(self._instances) + self._count_tasks_locked()
if active_count >= self._max_active_vms:
raise RuntimeError(
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
self._backend.create(instance)
with self._lock:
self._start_instance_locked(instance)
self._require_guest_exec_or_opt_in(instance)
if self._runtime_capabilities.supports_guest_exec:
self._backend.exec(
instance,
f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
10,
)
else:
instance.metadata["execution_mode"] = "host_compat"
task = TaskRecord.from_instance(instance)
self._save_task_locked(task)
return self._serialize_task(task)
except Exception:
if runtime_dir.exists():
try:
if instance.state == "started":
self._backend.stop(instance)
instance.state = "stopped"
except Exception:
pass
try:
self._backend.delete(instance)
except Exception:
pass
shutil.rmtree(task_dir, ignore_errors=True)
raise
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
if timeout_seconds <= 0:
raise ValueError("timeout_seconds must be positive")
with self._lock:
task = self._load_task_locked(task_id)
self._ensure_task_not_expired_locked(task, time.time())
self._refresh_task_liveness_locked(task)
if task.state != "started":
raise RuntimeError(f"task {task_id} must be in 'started' state before task_exec")
instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
exec_result, execution_mode = self._exec_instance(
instance,
command=command,
timeout_seconds=timeout_seconds,
host_workdir=self._task_workspace_dir(task.task_id),
guest_cwd=TASK_WORKSPACE_GUEST_PATH,
)
with self._lock:
task = self._load_task_locked(task_id)
task.state = instance.state
task.firecracker_pid = instance.firecracker_pid
task.last_error = instance.last_error
task.metadata = dict(instance.metadata)
entry = self._record_task_command_locked(
task,
command=command,
exec_result=exec_result,
execution_mode=execution_mode,
cwd=TASK_WORKSPACE_GUEST_PATH,
)
self._save_task_locked(task)
return {
"task_id": task_id,
"environment": task.environment,
"environment_version": task.metadata.get("environment_version"),
"command": command,
"stdout": exec_result.stdout,
"stderr": exec_result.stderr,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"execution_mode": execution_mode,
"sequence": entry["sequence"],
"cwd": TASK_WORKSPACE_GUEST_PATH,
}
def status_task(self, task_id: str) -> dict[str, Any]:
with self._lock:
task = self._load_task_locked(task_id)
self._ensure_task_not_expired_locked(task, time.time())
self._refresh_task_liveness_locked(task)
self._save_task_locked(task)
return self._serialize_task(task)
def logs_task(self, task_id: str) -> dict[str, Any]:
with self._lock:
task = self._load_task_locked(task_id)
self._ensure_task_not_expired_locked(task, time.time())
self._refresh_task_liveness_locked(task)
self._save_task_locked(task)
entries = self._read_task_logs_locked(task.task_id)
return {"task_id": task.task_id, "count": len(entries), "entries": entries}
def delete_task(self, task_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
with self._lock:
task = self._load_task_locked(task_id)
instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
if task.state == "started":
self._backend.stop(instance)
task.state = "stopped"
self._backend.delete(instance)
shutil.rmtree(self._task_dir(task_id), ignore_errors=True)
return {"task_id": task_id, "deleted": True, "reason": reason}
def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
@ -620,6 +982,28 @@
"metadata": instance.metadata,
}
def _serialize_task(self, task: TaskRecord) -> dict[str, Any]:
return {
"task_id": task.task_id,
"environment": task.environment,
"environment_version": task.metadata.get("environment_version"),
"vcpu_count": task.vcpu_count,
"mem_mib": task.mem_mib,
"ttl_seconds": task.ttl_seconds,
"created_at": task.created_at,
"expires_at": task.expires_at,
"state": task.state,
"network_enabled": task.network is not None,
"allow_host_compat": task.allow_host_compat,
"guest_ip": task.network.guest_ip if task.network is not None else None,
"tap_name": task.network.tap_name if task.network is not None else None,
"execution_mode": task.metadata.get("execution_mode", "pending"),
"workspace_path": TASK_WORKSPACE_GUEST_PATH,
"command_count": task.command_count,
"last_command": task.last_command,
"metadata": task.metadata,
}
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
return
@ -665,3 +1049,184 @@
vm_id = instance.vm_id
self._reap_expired_locked(now)
raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")
def _start_instance_locked(self, instance: VmInstance) -> None:
if instance.state not in {"created", "stopped"}:
raise RuntimeError(
f"vm {instance.vm_id} cannot be started from state {instance.state!r}"
)
self._require_guest_boot_or_opt_in(instance)
if not self._runtime_capabilities.supports_vm_boot:
instance.metadata["execution_mode"] = "host_compat"
instance.metadata["boot_mode"] = "compat"
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
self._backend.start(instance)
instance.state = "started"
def _exec_instance(
self,
instance: VmInstance,
*,
command: str,
timeout_seconds: int,
host_workdir: Path | None = None,
guest_cwd: str | None = None,
) -> tuple[VmExecResult, str]:
if timeout_seconds <= 0:
raise ValueError("timeout_seconds must be positive")
if instance.state != "started":
raise RuntimeError(f"vm {instance.vm_id} must be in 'started' state before execution")
self._require_guest_exec_or_opt_in(instance)
prepared_command = command
if self._runtime_capabilities.supports_guest_exec:
prepared_command = _wrap_guest_command(command, cwd=guest_cwd)
workdir = None
else:
instance.metadata["execution_mode"] = "host_compat"
workdir = host_workdir
exec_result = self._backend.exec(
instance,
prepared_command,
timeout_seconds,
workdir=workdir,
)
execution_mode = instance.metadata.get("execution_mode", "unknown")
return exec_result, execution_mode
def _task_dir(self, task_id: str) -> Path:
return self._tasks_dir / task_id
def _task_runtime_dir(self, task_id: str) -> Path:
return self._task_dir(task_id) / TASK_RUNTIME_DIRNAME
def _task_workspace_dir(self, task_id: str) -> Path:
return self._task_dir(task_id) / TASK_WORKSPACE_DIRNAME
def _task_commands_dir(self, task_id: str) -> Path:
return self._task_dir(task_id) / TASK_COMMANDS_DIRNAME
def _task_metadata_path(self, task_id: str) -> Path:
return self._task_dir(task_id) / "task.json"
def _count_tasks_locked(self) -> int:
return sum(1 for _ in self._tasks_dir.glob("*/task.json"))
def _load_task_locked(self, task_id: str) -> TaskRecord:
metadata_path = self._task_metadata_path(task_id)
if not metadata_path.exists():
raise ValueError(f"task {task_id!r} does not exist")
payload = json.loads(metadata_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise RuntimeError(f"task record at {metadata_path} is invalid")
return TaskRecord.from_payload(payload)
def _save_task_locked(self, task: TaskRecord) -> None:
metadata_path = self._task_metadata_path(task.task_id)
metadata_path.parent.mkdir(parents=True, exist_ok=True)
metadata_path.write_text(
json.dumps(task.to_payload(), indent=2, sort_keys=True),
encoding="utf-8",
)
def _reap_expired_tasks_locked(self, now: float) -> None:
for metadata_path in list(self._tasks_dir.glob("*/task.json")):
payload = json.loads(metadata_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
shutil.rmtree(metadata_path.parent, ignore_errors=True)
continue
task = TaskRecord.from_payload(payload)
if task.expires_at > now:
continue
instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
if task.state == "started":
self._backend.stop(instance)
task.state = "stopped"
self._backend.delete(instance)
shutil.rmtree(self._task_dir(task.task_id), ignore_errors=True)
def _ensure_task_not_expired_locked(self, task: TaskRecord, now: float) -> None:
if task.expires_at <= now:
task_id = task.task_id
self._reap_expired_tasks_locked(now)
raise RuntimeError(f"task {task_id!r} expired and was automatically deleted")
def _refresh_task_liveness_locked(self, task: TaskRecord) -> None:
if task.state != "started":
return
execution_mode = task.metadata.get("execution_mode")
if execution_mode == "host_compat":
return
if _pid_is_running(task.firecracker_pid):
return
task.state = "stopped"
task.firecracker_pid = None
task.last_error = "backing guest process is no longer running"
def _record_task_command_locked(
self,
task: TaskRecord,
*,
command: str,
exec_result: VmExecResult,
execution_mode: str,
cwd: str,
) -> dict[str, Any]:
sequence = task.command_count + 1
commands_dir = self._task_commands_dir(task.task_id)
commands_dir.mkdir(parents=True, exist_ok=True)
base_name = f"{sequence:06d}"
stdout_path = commands_dir / f"{base_name}.stdout"
stderr_path = commands_dir / f"{base_name}.stderr"
record_path = commands_dir / f"{base_name}.json"
stdout_path.write_text(exec_result.stdout, encoding="utf-8")
stderr_path.write_text(exec_result.stderr, encoding="utf-8")
entry: dict[str, Any] = {
"sequence": sequence,
"command": command,
"cwd": cwd,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"execution_mode": execution_mode,
"stdout_file": stdout_path.name,
"stderr_file": stderr_path.name,
"recorded_at": time.time(),
}
record_path.write_text(json.dumps(entry, indent=2, sort_keys=True), encoding="utf-8")
task.command_count = sequence
task.last_command = {
"sequence": sequence,
"command": command,
"cwd": cwd,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"execution_mode": execution_mode,
}
return entry
def _read_task_logs_locked(self, task_id: str) -> list[dict[str, Any]]:
entries: list[dict[str, Any]] = []
commands_dir = self._task_commands_dir(task_id)
if not commands_dir.exists():
return entries
for record_path in sorted(commands_dir.glob("*.json")):
payload = json.loads(record_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
continue
stdout_name = str(payload.get("stdout_file", ""))
stderr_name = str(payload.get("stderr_file", ""))
stdout = ""
stderr = ""
if stdout_name != "":
stdout_path = commands_dir / stdout_name
if stdout_path.exists():
stdout = stdout_path.read_text(encoding="utf-8")
if stderr_name != "":
stderr_path = commands_dir / stderr_name
if stderr_path.exists():
stderr = stderr_path.read_text(encoding="utf-8")
entry = dict(payload)
entry["stdout"] = stdout
entry["stderr"] = stderr
entries.append(entry)
return entries
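The journal layout used by `_record_task_command_locked` and `_read_task_logs_locked` above (a zero-padded `NNNNNN.json` record per command with `.stdout`/`.stderr` siblings) needs only stdlib calls; a minimal sketch under that assumed layout, with `append_entry`/`read_entries` as hypothetical names:

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any

def append_entry(commands_dir: Path, sequence: int, command: str, stdout: str) -> None:
    # Zero-padded basenames make lexical sort order equal execution order.
    base = f"{sequence:06d}"
    (commands_dir / f"{base}.stdout").write_text(stdout, encoding="utf-8")
    record = {"sequence": sequence, "command": command, "stdout_file": f"{base}.stdout"}
    (commands_dir / f"{base}.json").write_text(json.dumps(record), encoding="utf-8")

def read_entries(commands_dir: Path) -> list[dict[str, Any]]:
    entries: list[dict[str, Any]] = []
    for record_path in sorted(commands_dir.glob("*.json")):
        entry = json.loads(record_path.read_text(encoding="utf-8"))
        # Re-attach the captured stdout stored next to the record.
        entry["stdout"] = (commands_dir / entry["stdout_file"]).read_text(encoding="utf-8")
        entries.append(entry)
    return entries

with TemporaryDirectory() as tmp:
    commands_dir = Path(tmp)
    append_entry(commands_dir, 1, "printf ok", "ok\n")
    append_entry(commands_dir, 2, "pwd", "/workspace\n")
    entries = read_entries(commands_dir)
    assert [e["sequence"] for e in entries] == [1, 2]
    assert entries[1]["stdout"] == "/workspace\n"
```

Keeping stdout/stderr in sibling files rather than inline JSON keeps each record small and lets large outputs be streamed or truncated independently of the metadata.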

View file

@ -48,6 +48,7 @@ def test_pyro_create_server_registers_vm_run(tmp_path: Path) -> None:
tool_names = asyncio.run(_run())
assert "vm_run" in tool_names
assert "vm_create" in tool_names
assert "task_create" in tool_names
def test_pyro_vm_run_tool_executes(tmp_path: Path) -> None:
@ -102,3 +103,25 @@ def test_pyro_create_vm_defaults_sizing_and_host_compat(tmp_path: Path) -> None:
assert created["vcpu_count"] == 1
assert created["mem_mib"] == 1024
assert created["allow_host_compat"] is True
def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None:
pyro = Pyro(
manager=VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
)
created = pyro.create_task(environment="debian:12-base", allow_host_compat=True)
task_id = str(created["task_id"])
executed = pyro.exec_task(task_id, command="printf 'ok\\n'")
status = pyro.status_task(task_id)
logs = pyro.logs_task(task_id)
deleted = pyro.delete_task(task_id)
assert executed["stdout"] == "ok\n"
assert status["command_count"] == 1
assert logs["count"] == 1
assert deleted["deleted"] is True


@@ -59,6 +59,14 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None:
assert "Expose pyro tools over stdio for an MCP client." in mcp_help
assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help
task_help = _subparser_choice(parser, "task").format_help()
assert "pyro task create debian:12" in task_help
assert "pyro task exec TASK_ID" in task_help
task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help()
assert "persistent `/workspace`" in task_exec_help
assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help
def test_cli_run_prints_json(
monkeypatch: pytest.MonkeyPatch,
@@ -318,6 +326,243 @@ def test_cli_requires_run_command() -> None:
cli._require_command([])
def test_cli_task_create_prints_json(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def create_task(self, **kwargs: Any) -> dict[str, Any]:
assert kwargs["environment"] == "debian:12"
return {"task_id": "task-123", "state": "started"}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="create",
environment="debian:12",
vcpu_count=1,
mem_mib=1024,
ttl_seconds=600,
network=False,
allow_host_compat=False,
json=True,
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
output = json.loads(capsys.readouterr().out)
assert output["task_id"] == "task-123"
def test_cli_task_create_prints_human(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def create_task(self, **kwargs: Any) -> dict[str, Any]:
del kwargs
return {
"task_id": "task-123",
"environment": "debian:12",
"state": "started",
"workspace_path": "/workspace",
"execution_mode": "guest_vsock",
"vcpu_count": 1,
"mem_mib": 1024,
"command_count": 0,
"last_command": None,
}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="create",
environment="debian:12",
vcpu_count=1,
mem_mib=1024,
ttl_seconds=600,
network=False,
allow_host_compat=False,
json=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
output = capsys.readouterr().out
assert "Task: task-123" in output
assert "Workspace: /workspace" in output
def test_cli_task_exec_prints_human_output(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubPyro:
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
assert task_id == "task-123"
assert command == "cat note.txt"
assert timeout_seconds == 30
return {
"task_id": task_id,
"sequence": 2,
"cwd": "/workspace",
"execution_mode": "guest_vsock",
"exit_code": 0,
"duration_ms": 4,
"stdout": "hello\n",
"stderr": "",
}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="exec",
task_id="task-123",
timeout_seconds=30,
json=False,
command_args=["--", "cat", "note.txt"],
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
captured = capsys.readouterr()
assert captured.out == "hello\n"
assert "[task-exec] task_id=task-123 sequence=2 cwd=/workspace" in captured.err
def test_cli_task_logs_and_delete_print_human(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubPyro:
def logs_task(self, task_id: str) -> dict[str, Any]:
assert task_id == "task-123"
return {
"task_id": task_id,
"count": 1,
"entries": [
{
"sequence": 1,
"exit_code": 0,
"duration_ms": 2,
"cwd": "/workspace",
"command": "printf 'ok\\n'",
"stdout": "ok\n",
"stderr": "",
}
],
}
def delete_task(self, task_id: str) -> dict[str, Any]:
assert task_id == "task-123"
return {"task_id": task_id, "deleted": True}
class LogsParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="logs",
task_id="task-123",
json=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: LogsParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
class DeleteParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="delete",
task_id="task-123",
json=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: DeleteParser())
cli.main()
output = capsys.readouterr().out
assert "#1 exit_code=0 duration_ms=2 cwd=/workspace" in output
assert "Deleted task: task-123" in output
def test_cli_task_status_and_delete_print_json(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubPyro:
def status_task(self, task_id: str) -> dict[str, Any]:
assert task_id == "task-123"
return {"task_id": task_id, "state": "started"}
def delete_task(self, task_id: str) -> dict[str, Any]:
assert task_id == "task-123"
return {"task_id": task_id, "deleted": True}
class StatusParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="status",
task_id="task-123",
json=True,
)
monkeypatch.setattr(cli, "_build_parser", lambda: StatusParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
status = json.loads(capsys.readouterr().out)
assert status["state"] == "started"
class DeleteParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="delete",
task_id="task-123",
json=True,
)
monkeypatch.setattr(cli, "_build_parser", lambda: DeleteParser())
cli.main()
deleted = json.loads(capsys.readouterr().out)
assert deleted["deleted"] is True
def test_cli_task_exec_json_error_exits_nonzero(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
del task_id, command, timeout_seconds
raise RuntimeError("task is unavailable")
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="task",
task_command="exec",
task_id="task-123",
timeout_seconds=30,
json=True,
command_args=["--", "true"],
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
with pytest.raises(SystemExit, match="1"):
cli.main()
payload = json.loads(capsys.readouterr().out)
assert payload["ok"] is False
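The stub `Namespace` objects in these CLI tests mirror what the real parser would produce. A minimal argparse sketch of the `task exec` surface they assume (the wiring below is illustrative, not the project's actual `_build_parser`; note that `nargs=argparse.REMAINDER` keeps the leading `--` in `command_args`, matching the stubs):

```python
import argparse

parser = argparse.ArgumentParser(prog="pyro")
subparsers = parser.add_subparsers(dest="command")
task = subparsers.add_parser("task")
task_sub = task.add_subparsers(dest="task_command")
task_exec = task_sub.add_parser("exec")
task_exec.add_argument("task_id")
task_exec.add_argument("--timeout-seconds", dest="timeout_seconds", type=int, default=30)
task_exec.add_argument("--json", action="store_true")
# Everything after the task id (typically introduced by "--") is the guest command.
task_exec.add_argument("command_args", nargs=argparse.REMAINDER)

args = parser.parse_args(["task", "exec", "task-123", "--", "cat", "note.txt"])
```

Collecting the trailing words with `REMAINDER` is what lets `pyro task exec TASK_ID -- cat note.txt` pass an arbitrary command through without the CLI trying to interpret its flags.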
def test_print_env_helpers_render_human_output(capsys: pytest.CaptureFixture[str]) -> None:
cli._print_env_list_human(
{


@@ -17,6 +17,7 @@ from pyro_mcp.contract import (
PUBLIC_CLI_DEMO_SUBCOMMANDS,
PUBLIC_CLI_ENV_SUBCOMMANDS,
PUBLIC_CLI_RUN_FLAGS,
PUBLIC_CLI_TASK_SUBCOMMANDS,
PUBLIC_MCP_TOOLS,
PUBLIC_SDK_METHODS,
)
@@ -63,6 +64,10 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
for subcommand_name in PUBLIC_CLI_ENV_SUBCOMMANDS:
assert subcommand_name in env_help_text
task_help_text = _subparser_choice(parser, "task").format_help()
for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS:
assert subcommand_name in task_help_text
demo_help_text = _subparser_choice(parser, "demo").format_help()
for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:
assert subcommand_name in demo_help_text


@@ -31,6 +31,8 @@ def test_create_server_registers_vm_tools(tmp_path: Path) -> None:
assert "vm_network_info" in tool_names
assert "vm_run" in tool_names
assert "vm_status" in tool_names
assert "task_create" in tool_names
assert "task_logs" in tool_names
def test_vm_run_round_trip(tmp_path: Path) -> None:
@@ -161,3 +163,50 @@ def test_server_main_runs_stdio_transport(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(server_module, "create_server", lambda: StubServer())
server_module.main()
assert called == {"transport": "stdio"}
def test_task_tools_round_trip(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
raise TypeError("unexpected call_tool result shape")
_, structured = raw_result
if not isinstance(structured, dict):
raise TypeError("expected structured dictionary result")
return cast(dict[str, Any], structured)
async def _run() -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
server = create_server(manager=manager)
created = _extract_structured(
await server.call_tool(
"task_create",
{
"environment": "debian:12-base",
"allow_host_compat": True,
},
)
)
task_id = str(created["task_id"])
executed = _extract_structured(
await server.call_tool(
"task_exec",
{
"task_id": task_id,
"command": "printf 'ok\\n'",
},
)
)
logs = _extract_structured(await server.call_tool("task_logs", {"task_id": task_id}))
deleted = _extract_structured(await server.call_tool("task_delete", {"task_id": task_id}))
return created, executed, logs, deleted
created, executed, logs, deleted = asyncio.run(_run())
assert created["state"] == "started"
assert executed["stdout"] == "ok\n"
assert logs["count"] == 1
assert deleted["deleted"] is True


@@ -1,14 +1,17 @@
from __future__ import annotations
import json
import subprocess
import time
from pathlib import Path
from typing import Any
import pytest
import pyro_mcp.vm_manager as vm_manager_module
from pyro_mcp.runtime import RuntimeCapabilities, resolve_runtime_paths
from pyro_mcp.vm_manager import VmManager
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None:
@@ -262,6 +265,95 @@ def test_vm_manager_run_vm(tmp_path: Path) -> None:
assert str(result["stdout"]) == "ok\n"
def test_task_lifecycle_and_logs(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
created = manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
)
task_id = str(created["task_id"])
assert created["state"] == "started"
assert created["workspace_path"] == "/workspace"
first = manager.exec_task(
task_id,
command="printf 'hello\\n' > note.txt",
timeout_seconds=30,
)
second = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
assert first["exit_code"] == 0
assert second["stdout"] == "hello\n"
status = manager.status_task(task_id)
assert status["command_count"] == 2
assert status["last_command"] is not None
logs = manager.logs_task(task_id)
assert logs["count"] == 2
entries = logs["entries"]
assert isinstance(entries, list)
assert entries[1]["stdout"] == "hello\n"
deleted = manager.delete_task(task_id)
assert deleted["deleted"] is True
with pytest.raises(ValueError, match="does not exist"):
manager.status_task(task_id)
def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None:
base_dir = tmp_path / "vms"
manager = VmManager(
backend_name="mock",
base_dir=base_dir,
network_manager=TapNetworkManager(enabled=False),
)
task_id = str(
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
)["task_id"]
)
other = VmManager(
backend_name="mock",
base_dir=base_dir,
network_manager=TapNetworkManager(enabled=False),
)
executed = other.exec_task(task_id, command="printf 'ok\\n'", timeout_seconds=30)
assert executed["exit_code"] == 0
assert executed["stdout"] == "ok\n"
logs = other.logs_task(task_id)
assert logs["count"] == 1
def test_task_requires_started_state(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
task_id = str(
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
)["task_id"]
)
task_dir = tmp_path / "vms" / "tasks" / task_id / "task.json"
payload = json.loads(task_dir.read_text(encoding="utf-8"))
payload["state"] = "stopped"
task_dir.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
with pytest.raises(RuntimeError, match="must be in 'started' state"):
manager.exec_task(task_id, command="true", timeout_seconds=30)
def test_vm_manager_firecracker_backend_path(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
@@ -334,3 +426,193 @@ def test_vm_manager_uses_canonical_default_cache_dir(
)
assert manager._environment_store.cache_dir == tmp_path / "cache"  # noqa: SLF001
def test_vm_manager_helper_round_trips() -> None:
network = NetworkConfig(
vm_id="abc123",
tap_name="tap0",
guest_ip="172.29.1.2",
gateway_ip="172.29.1.1",
subnet_cidr="172.29.1.0/24",
mac_address="06:00:aa:bb:cc:dd",
dns_servers=("1.1.1.1", "8.8.8.8"),
)
assert vm_manager_module._optional_int(None) is None # noqa: SLF001
assert vm_manager_module._optional_int(True) == 1 # noqa: SLF001
assert vm_manager_module._optional_int(7) == 7 # noqa: SLF001
assert vm_manager_module._optional_int(7.2) == 7 # noqa: SLF001
assert vm_manager_module._optional_int("9") == 9 # noqa: SLF001
with pytest.raises(TypeError, match="integer-compatible"):
vm_manager_module._optional_int(object()) # noqa: SLF001
assert vm_manager_module._optional_str(None) is None # noqa: SLF001
assert vm_manager_module._optional_str(1) == "1" # noqa: SLF001
assert vm_manager_module._optional_dict(None) is None # noqa: SLF001
assert vm_manager_module._optional_dict({"x": 1}) == {"x": 1} # noqa: SLF001
with pytest.raises(TypeError, match="dictionary payload"):
vm_manager_module._optional_dict("bad") # noqa: SLF001
assert vm_manager_module._string_dict({"x": 1}) == {"x": "1"} # noqa: SLF001
assert vm_manager_module._string_dict("bad") == {} # noqa: SLF001
serialized = vm_manager_module._serialize_network(network) # noqa: SLF001
assert serialized is not None
restored = vm_manager_module._deserialize_network(serialized) # noqa: SLF001
assert restored == network
assert vm_manager_module._deserialize_network(None) is None # noqa: SLF001
with pytest.raises(TypeError, match="dictionary payload"):
vm_manager_module._deserialize_network("bad") # noqa: SLF001
assert vm_manager_module._wrap_guest_command("echo hi") == "echo hi" # noqa: SLF001
wrapped = vm_manager_module._wrap_guest_command("echo hi", cwd="/workspace") # noqa: SLF001
assert "cd /workspace" in wrapped
assert vm_manager_module._pid_is_running(None) is False # noqa: SLF001
def test_copy_rootfs_falls_back_to_copy2(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
source = tmp_path / "rootfs.ext4"
source.write_text("payload", encoding="utf-8")
dest = tmp_path / "dest" / "rootfs.ext4"
def _raise_oserror(*args: Any, **kwargs: Any) -> Any:
del args, kwargs
raise OSError("no cp")
monkeypatch.setattr(subprocess, "run", _raise_oserror)
clone_mode = vm_manager_module._copy_rootfs(source, dest) # noqa: SLF001
assert clone_mode == "copy2"
assert dest.read_text(encoding="utf-8") == "payload"
def test_task_create_cleans_up_on_start_failure(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
def _boom(instance: Any) -> None:
del instance
raise RuntimeError("boom")
monkeypatch.setattr(manager._backend, "start", _boom) # noqa: SLF001
with pytest.raises(RuntimeError, match="boom"):
manager.create_task(environment="debian:12-base", allow_host_compat=True)
assert list((tmp_path / "vms" / "tasks").iterdir()) == []
def test_exec_instance_wraps_guest_workspace_command(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
supports_vm_boot=True,
supports_guest_exec=True,
supports_guest_network=False,
reason=None,
)
captured: dict[str, Any] = {}
class StubBackend:
def exec(
self,
instance: Any,
command: str,
timeout_seconds: int,
*,
workdir: Path | None = None,
) -> vm_manager_module.VmExecResult:
del instance, timeout_seconds
captured["command"] = command
captured["workdir"] = workdir
return vm_manager_module.VmExecResult(
stdout="",
stderr="",
exit_code=0,
duration_ms=1,
)
manager._backend = StubBackend() # type: ignore[assignment] # noqa: SLF001
instance = vm_manager_module.VmInstance( # noqa: SLF001
vm_id="vm-123",
environment="debian:12-base",
vcpu_count=1,
mem_mib=512,
ttl_seconds=600,
created_at=time.time(),
expires_at=time.time() + 600,
workdir=tmp_path / "runtime",
state="started",
)
result, execution_mode = manager._exec_instance( # noqa: SLF001
instance,
command="echo hi",
timeout_seconds=30,
guest_cwd="/workspace",
)
assert result.exit_code == 0
assert execution_mode == "unknown"
assert "cd /workspace" in str(captured["command"])
assert captured["workdir"] is None
def test_status_task_marks_dead_backing_process_stopped(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
task_id = str(
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
)["task_id"]
)
task_path = tmp_path / "vms" / "tasks" / task_id / "task.json"
payload = json.loads(task_path.read_text(encoding="utf-8"))
payload["metadata"]["execution_mode"] = "guest_vsock"
payload["firecracker_pid"] = 999999
task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
status = manager.status_task(task_id)
assert status["state"] == "stopped"
updated_payload = json.loads(task_path.read_text(encoding="utf-8"))
assert "backing guest process" in str(updated_payload.get("last_error", ""))
def test_reap_expired_tasks_removes_invalid_and_expired_records(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
invalid_dir = tmp_path / "vms" / "tasks" / "invalid"
invalid_dir.mkdir(parents=True)
(invalid_dir / "task.json").write_text("[]", encoding="utf-8")
task_id = str(
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
)["task_id"]
)
task_path = tmp_path / "vms" / "tasks" / task_id / "task.json"
payload = json.loads(task_path.read_text(encoding="utf-8"))
payload["expires_at"] = 0.0
task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
with manager._lock: # noqa: SLF001
manager._reap_expired_tasks_locked(time.time()) # noqa: SLF001
assert not invalid_dir.exists()
assert not (tmp_path / "vms" / "tasks" / task_id).exists()
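The tests above edit `task.json` directly, so they pin down the on-disk record shape. A self-contained sketch of that record and the two behaviors being exercised, expiry reaping and probing a dead backing pid (field names come from the tests; `pid_is_running` and `reap` only approximate the manager's `_pid_is_running` and `_reap_expired_tasks_locked`):

```python
import json
import os
import tempfile
import time
from pathlib import Path

tasks_root = Path(tempfile.mkdtemp()) / "tasks"
task_dir = tasks_root / "task-123"
task_dir.mkdir(parents=True)
record = {
    "task_id": "task-123",
    "state": "started",
    "expires_at": time.time() + 600,
    "firecracker_pid": None,
    "metadata": {"execution_mode": "guest_vsock"},
}
(task_dir / "task.json").write_text(json.dumps(record, indent=2, sort_keys=True), encoding="utf-8")

def pid_is_running(pid):
    # Approximation of _pid_is_running: signal 0 probes process existence.
    if pid is None:
        return False
    try:
        os.kill(pid, 0)
    except OSError:
        return False
    return True

def reap(now):
    # Approximation of _reap_expired_tasks_locked: drop records that are
    # unreadable, not dictionaries, or past their expires_at deadline.
    for candidate in tasks_root.iterdir():
        try:
            payload = json.loads((candidate / "task.json").read_text(encoding="utf-8"))
        except (OSError, json.JSONDecodeError):
            payload = None
        if not isinstance(payload, dict) or float(payload.get("expires_at", 0.0)) <= now:
            for child in candidate.iterdir():
                child.unlink()
            candidate.rmdir()

# Force the record past its deadline, then reap it.
payload = json.loads((task_dir / "task.json").read_text(encoding="utf-8"))
payload["expires_at"] = 0.0
(task_dir / "task.json").write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
reap(time.time())
```

Because the record lives on disk under the runtime base directory, any later manager process can rehydrate it, which is what makes the cross-process test above work.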

uv.lock (generated)

@@ -706,7 +706,7 @@ crypto = [
[[package]]
name = "pyro-mcp"
version = "2.1.0"
source = { editable = "." }
dependencies = [
{ name = "mcp" },