Ship trust-first CLI and runtime defaults

This commit is contained in:
Thales Maciel 2026-03-09 20:52:49 -03:00
parent fb718af154
commit 5d63e4c16e
26 changed files with 894 additions and 134 deletions

View file

@ -30,6 +30,7 @@ This repository ships `pyro-mcp`, an MCP-compatible package for ephemeral VM lif
- Use `make doctor` to inspect bundled runtime integrity and host prerequisites. - Use `make doctor` to inspect bundled runtime integrity and host prerequisites.
- Network-enabled flows require host privilege for TAP/NAT setup; the current implementation uses `sudo -n` for `ip`, `nft`, and `iptables` when available. - Network-enabled flows require host privilege for TAP/NAT setup; the current implementation uses `sudo -n` for `ip`, `nft`, and `iptables` when available.
- If you need full log payloads from the Ollama demo, use `make ollama-demo OLLAMA_DEMO_FLAGS=-v`. - If you need full log payloads from the Ollama demo, use `make ollama-demo OLLAMA_DEMO_FLAGS=-v`.
- `pyro run` now defaults to `1 vCPU / 1024 MiB`, human-readable output, and fail-closed guest execution unless `--allow-host-compat` is passed.
- After heavy runtime work, reclaim local space with `rm -rf build` and `git lfs prune`. - After heavy runtime work, reclaim local space with `rm -rf build` and `git lfs prune`.
- The pre-migration `pre-lfs-*` tag is local backup material only; do not push it or it will keep the old giant blobs reachable. - The pre-migration `pre-lfs-*` tag is local backup material only; do not push it or it will keep the old giant blobs reachable.
- Public contract documentation lives in `docs/public-contract.md`. - Public contract documentation lives in `docs/public-contract.md`.

View file

@ -45,6 +45,8 @@ The package ships the embedded Firecracker runtime and a package-controlled envi
Official environments are pulled as OCI artifacts from public Docker Hub repositories into a local Official environments are pulled as OCI artifacts from public Docker Hub repositories into a local
cache on first use or through `pyro env pull`. cache on first use or through `pyro env pull`.
End users do not need registry credentials to pull or run official environments. End users do not need registry credentials to pull or run official environments.
The default cache location is `~/.cache/pyro-mcp/environments`; override it with
`PYRO_ENVIRONMENT_CACHE_DIR`.
## CLI ## CLI
@ -63,13 +65,13 @@ pyro env pull debian:12
Run one command in an ephemeral VM: Run one command in an ephemeral VM:
```bash ```bash
pyro run debian:12 --vcpu-count 1 --mem-mib 1024 -- git --version pyro run debian:12 -- git --version
``` ```
Run with outbound internet enabled: Run with outbound internet enabled:
```bash ```bash
pyro run debian:12 --vcpu-count 1 --mem-mib 1024 --network -- \ pyro run debian:12 --network -- \
"git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world && git -C hello-world rev-parse --is-inside-work-tree" "git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world && git -C hello-world rev-parse --is-inside-work-tree"
``` ```
@ -77,8 +79,13 @@ Show runtime and host diagnostics:
```bash ```bash
pyro doctor pyro doctor
pyro doctor --json
``` ```
`pyro run` defaults to `1 vCPU / 1024 MiB`.
It fails closed when guest boot or guest exec is unavailable.
Use `--allow-host-compat` only if you explicitly want host execution.
Run the deterministic demo: Run the deterministic demo:
```bash ```bash
@ -103,8 +110,6 @@ pyro = Pyro()
result = pyro.run_in_vm( result = pyro.run_in_vm(
environment="debian:12", environment="debian:12",
command="git --version", command="git --version",
vcpu_count=1,
mem_mib=1024,
timeout_seconds=30, timeout_seconds=30,
network=False, network=False,
) )
@ -119,8 +124,6 @@ from pyro_mcp import Pyro
pyro = Pyro() pyro = Pyro()
created = pyro.create_vm( created = pyro.create_vm(
environment="debian:12", environment="debian:12",
vcpu_count=1,
mem_mib=1024,
ttl_seconds=600, ttl_seconds=600,
network=True, network=True,
) )
@ -144,12 +147,12 @@ print(pyro.inspect_environment("debian:12"))
Primary agent-facing tool: Primary agent-facing tool:
- `vm_run(environment, command, vcpu_count, mem_mib, timeout_seconds=30, ttl_seconds=600, network=false)` - `vm_run(environment, command, vcpu_count=1, mem_mib=1024, timeout_seconds=30, ttl_seconds=600, network=false, allow_host_compat=false)`
Advanced lifecycle tools: Advanced lifecycle tools:
- `vm_list_environments()` - `vm_list_environments()`
- `vm_create(environment, vcpu_count, mem_mib, ttl_seconds=600, network=false)` - `vm_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)`
- `vm_start(vm_id)` - `vm_start(vm_id)`
- `vm_exec(vm_id, command, timeout_seconds=30)` - `vm_exec(vm_id, command, timeout_seconds=30)`
- `vm_stop(vm_id)` - `vm_stop(vm_id)`
@ -180,6 +183,7 @@ The package ships an embedded Linux x86_64 runtime payload with:
No system Firecracker installation is required. No system Firecracker installation is required.
`pyro` installs curated environments into a local cache and reports their status through `pyro env inspect` and `pyro doctor`. `pyro` installs curated environments into a local cache and reports their status through `pyro env inspect` and `pyro doctor`.
The public CLI is human-readable by default; add `--json` for structured output.
## Contributor Workflow ## Contributor Workflow

View file

@ -30,7 +30,7 @@ uvx --from pyro-mcp pyro env pull debian:12
Run one command in a curated environment: Run one command in a curated environment:
```bash ```bash
uvx --from pyro-mcp pyro run debian:12 --vcpu-count 1 --mem-mib 1024 -- git --version uvx --from pyro-mcp pyro run debian:12 -- git --version
``` ```
Inspect the official environment catalog: Inspect the official environment catalog:
@ -48,8 +48,13 @@ pyro env list
pyro env pull debian:12 pyro env pull debian:12
pyro env inspect debian:12 pyro env inspect debian:12
pyro doctor pyro doctor
pyro run debian:12 -- git --version
``` ```
`pyro run` defaults to `1 vCPU / 1024 MiB`.
If guest execution is unavailable, the command fails unless you explicitly pass
`--allow-host-compat`.
## Contributor Clone ## Contributor Clone
```bash ```bash

View file

@ -1,6 +1,6 @@
# Public Contract # Public Contract
This document defines the supported public interface for `pyro-mcp` `1.x`. This document defines the supported public interface for `pyro-mcp` `2.x`.
## Package Identity ## Package Identity
@ -31,12 +31,14 @@ Stable `pyro run` interface:
- `--timeout-seconds` - `--timeout-seconds`
- `--ttl-seconds` - `--ttl-seconds`
- `--network` - `--network`
- `--allow-host-compat`
- `--json`
Behavioral guarantees: Behavioral guarantees:
- `pyro run <environment> --vcpu-count <n> --mem-mib <mib> -- <command>` returns structured JSON. - `pyro run <environment> -- <command>` defaults to `1 vCPU / 1024 MiB`.
- `pyro env list`, `pyro env pull`, `pyro env inspect`, and `pyro env prune` return structured JSON. - `pyro run` fails if guest boot or guest exec is unavailable unless `--allow-host-compat` is set.
- `pyro doctor` returns structured JSON diagnostics. - `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`.
- `pyro demo ollama` prints log lines plus a final summary line. - `pyro demo ollama` prints log lines plus a final summary line.
## Python SDK Contract ## Python SDK Contract
@ -80,6 +82,11 @@ Stable public method names:
- `reap_expired()` - `reap_expired()`
- `run_in_vm(...)` - `run_in_vm(...)`
Behavioral defaults:
- `Pyro.create_vm(...)` and `Pyro.run_in_vm(...)` default to `vcpu_count=1` and `mem_mib=1024`.
- `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`.
## MCP Contract ## MCP Contract
Primary tool: Primary tool:
@ -98,6 +105,11 @@ Advanced lifecycle tools:
- `vm_network_info` - `vm_network_info`
- `vm_reap_expired` - `vm_reap_expired`
Behavioral defaults:
- `vm_run` and `vm_create` default to `vcpu_count=1` and `mem_mib=1024`.
- `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`.
## Versioning Rule ## Versioning Rule
- `pyro-mcp` uses SemVer. - `pyro-mcp` uses SemVer.

View file

@ -20,6 +20,26 @@ pyro env pull debian:12
If you are validating a freshly published official environment, also verify that the corresponding If you are validating a freshly published official environment, also verify that the corresponding
Docker Hub repository is public. Docker Hub repository is public.
## `pyro run` fails closed before the command executes
Cause:
- the bundled runtime cannot boot a guest
- guest boot works but guest exec is unavailable
- you are using a mock or shim runtime path that only supports host compatibility mode
Fix:
```bash
pyro doctor
```
If you intentionally want host execution for a one-off compatibility run, rerun with:
```bash
pyro run --allow-host-compat debian:12 -- git --version
```
## `pyro run --network` fails before the guest starts ## `pyro run --network` fails before the guest starts
Cause: Cause:

View file

@ -6,6 +6,13 @@ import json
from typing import Any from typing import Any
from pyro_mcp import Pyro from pyro_mcp import Pyro
from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
)
VM_RUN_TOOL: dict[str, Any] = { VM_RUN_TOOL: dict[str, Any] = {
"name": "vm_run", "name": "vm_run",
@ -20,8 +27,9 @@ VM_RUN_TOOL: dict[str, Any] = {
"timeout_seconds": {"type": "integer", "default": 30}, "timeout_seconds": {"type": "integer", "default": 30},
"ttl_seconds": {"type": "integer", "default": 600}, "ttl_seconds": {"type": "integer", "default": 600},
"network": {"type": "boolean", "default": False}, "network": {"type": "boolean", "default": False},
"allow_host_compat": {"type": "boolean", "default": False},
}, },
"required": ["environment", "command", "vcpu_count", "mem_mib"], "required": ["environment", "command"],
}, },
} }
@ -31,11 +39,12 @@ def call_vm_run(arguments: dict[str, Any]) -> dict[str, Any]:
return pyro.run_in_vm( return pyro.run_in_vm(
environment=str(arguments["environment"]), environment=str(arguments["environment"]),
command=str(arguments["command"]), command=str(arguments["command"]),
vcpu_count=int(arguments["vcpu_count"]), vcpu_count=int(arguments.get("vcpu_count", DEFAULT_VCPU_COUNT)),
mem_mib=int(arguments["mem_mib"]), mem_mib=int(arguments.get("mem_mib", DEFAULT_MEM_MIB)),
timeout_seconds=int(arguments.get("timeout_seconds", 30)), timeout_seconds=int(arguments.get("timeout_seconds", DEFAULT_TIMEOUT_SECONDS)),
ttl_seconds=int(arguments.get("ttl_seconds", 600)), ttl_seconds=int(arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS)),
network=bool(arguments.get("network", False)), network=bool(arguments.get("network", False)),
allow_host_compat=bool(arguments.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
) )
@ -43,8 +52,6 @@ def main() -> None:
tool_arguments: dict[str, Any] = { tool_arguments: dict[str, Any] = {
"environment": "debian:12", "environment": "debian:12",
"command": "git --version", "command": "git --version",
"vcpu_count": 1,
"mem_mib": 1024,
"timeout_seconds": 30, "timeout_seconds": 30,
"network": False, "network": False,
} }

View file

@ -13,6 +13,13 @@ import json
from typing import Any, Callable, TypeVar, cast from typing import Any, Callable, TypeVar, cast
from pyro_mcp import Pyro from pyro_mcp import Pyro
from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
)
F = TypeVar("F", bound=Callable[..., Any]) F = TypeVar("F", bound=Callable[..., Any])
@ -21,11 +28,12 @@ def run_vm_run_tool(
*, *,
environment: str, environment: str,
command: str, command: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
timeout_seconds: int = 30, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> str: ) -> str:
pyro = Pyro() pyro = Pyro()
result = pyro.run_in_vm( result = pyro.run_in_vm(
@ -36,6 +44,7 @@ def run_vm_run_tool(
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
return json.dumps(result, sort_keys=True) return json.dumps(result, sort_keys=True)
@ -55,11 +64,12 @@ def build_langchain_vm_run_tool() -> Any:
def vm_run( def vm_run(
environment: str, environment: str,
command: str, command: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
timeout_seconds: int = 30, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> str: ) -> str:
"""Run one command in an ephemeral Firecracker VM and clean it up.""" """Run one command in an ephemeral Firecracker VM and clean it up."""
return run_vm_run_tool( return run_vm_run_tool(
@ -70,6 +80,7 @@ def build_langchain_vm_run_tool() -> Any:
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
return vm_run return vm_run

View file

@ -15,6 +15,13 @@ import os
from typing import Any from typing import Any
from pyro_mcp import Pyro from pyro_mcp import Pyro
from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
)
DEFAULT_MODEL = "gpt-5" DEFAULT_MODEL = "gpt-5"
@ -33,8 +40,9 @@ OPENAI_VM_RUN_TOOL: dict[str, Any] = {
"timeout_seconds": {"type": "integer"}, "timeout_seconds": {"type": "integer"},
"ttl_seconds": {"type": "integer"}, "ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"}, "network": {"type": "boolean"},
"allow_host_compat": {"type": "boolean"},
}, },
"required": ["environment", "command", "vcpu_count", "mem_mib"], "required": ["environment", "command"],
"additionalProperties": False, "additionalProperties": False,
}, },
} }
@ -45,11 +53,12 @@ def call_vm_run(arguments: dict[str, Any]) -> dict[str, Any]:
return pyro.run_in_vm( return pyro.run_in_vm(
environment=str(arguments["environment"]), environment=str(arguments["environment"]),
command=str(arguments["command"]), command=str(arguments["command"]),
vcpu_count=int(arguments["vcpu_count"]), vcpu_count=int(arguments.get("vcpu_count", DEFAULT_VCPU_COUNT)),
mem_mib=int(arguments["mem_mib"]), mem_mib=int(arguments.get("mem_mib", DEFAULT_MEM_MIB)),
timeout_seconds=int(arguments.get("timeout_seconds", 30)), timeout_seconds=int(arguments.get("timeout_seconds", DEFAULT_TIMEOUT_SECONDS)),
ttl_seconds=int(arguments.get("ttl_seconds", 600)), ttl_seconds=int(arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS)),
network=bool(arguments.get("network", False)), network=bool(arguments.get("network", False)),
allow_host_compat=bool(arguments.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
) )
@ -88,7 +97,7 @@ def main() -> None:
model = os.environ.get("OPENAI_MODEL", DEFAULT_MODEL) model = os.environ.get("OPENAI_MODEL", DEFAULT_MODEL)
prompt = ( prompt = (
"Use the vm_run tool to run `git --version` in an ephemeral VM. " "Use the vm_run tool to run `git --version` in an ephemeral VM. "
"Use the `debian:12` environment with 1 vCPU and 1024 MiB of memory. " "Use the `debian:12` environment. "
"Do not use networking for this request." "Do not use networking for this request."
) )
print(run_openai_vm_run_example(prompt=prompt, model=model)) print(run_openai_vm_run_example(prompt=prompt, model=model))

View file

@ -11,8 +11,6 @@ def main() -> None:
pyro = Pyro() pyro = Pyro()
created = pyro.create_vm( created = pyro.create_vm(
environment="debian:12", environment="debian:12",
vcpu_count=1,
mem_mib=1024,
ttl_seconds=600, ttl_seconds=600,
network=False, network=False,
) )

View file

@ -12,8 +12,6 @@ def main() -> None:
result = pyro.run_in_vm( result = pyro.run_in_vm(
environment="debian:12", environment="debian:12",
command="git --version", command="git --version",
vcpu_count=1,
mem_mib=1024,
timeout_seconds=30, timeout_seconds=30,
network=False, network=False,
) )

View file

@ -1,6 +1,6 @@
[project] [project]
name = "pyro-mcp" name = "pyro-mcp"
version = "1.0.0" version = "2.0.0"
description = "Curated Linux environments for ephemeral Firecracker-backed VM execution." description = "Curated Linux environments for ephemeral Firecracker-backed VM execution."
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }

View file

@ -7,7 +7,14 @@ from typing import Any
from mcp.server.fastmcp import FastMCP from mcp.server.fastmcp import FastMCP
from pyro_mcp.vm_manager import VmManager from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
VmManager,
)
class Pyro: class Pyro:
@ -49,10 +56,11 @@ class Pyro:
self, self,
*, *,
environment: str, environment: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
return self._manager.create_vm( return self._manager.create_vm(
environment=environment, environment=environment,
@ -60,6 +68,7 @@ class Pyro:
mem_mib=mem_mib, mem_mib=mem_mib,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
def start_vm(self, vm_id: str) -> dict[str, Any]: def start_vm(self, vm_id: str) -> dict[str, Any]:
@ -88,11 +97,12 @@ class Pyro:
*, *,
environment: str, environment: str,
command: str, command: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
timeout_seconds: int = 30, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
return self._manager.run_vm( return self._manager.run_vm(
environment=environment, environment=environment,
@ -102,6 +112,7 @@ class Pyro:
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
def create_server(self) -> FastMCP: def create_server(self) -> FastMCP:
@ -111,11 +122,12 @@ class Pyro:
async def vm_run( async def vm_run(
environment: str, environment: str,
command: str, command: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
timeout_seconds: int = 30, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create, start, execute, and clean up an ephemeral VM.""" """Create, start, execute, and clean up an ephemeral VM."""
return self.run_in_vm( return self.run_in_vm(
@ -126,6 +138,7 @@ class Pyro:
timeout_seconds=timeout_seconds, timeout_seconds=timeout_seconds,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
@server.tool() @server.tool()
@ -136,10 +149,11 @@ class Pyro:
@server.tool() @server.tool()
async def vm_create( async def vm_create(
environment: str, environment: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create an ephemeral VM record with environment and resource sizing.""" """Create an ephemeral VM record with environment and resource sizing."""
return self.create_vm( return self.create_vm(
@ -148,6 +162,7 @@ class Pyro:
mem_mib=mem_mib, mem_mib=mem_mib,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
@server.tool() @server.tool()

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import argparse import argparse
import json import json
import sys
from typing import Any from typing import Any
from pyro_mcp import __version__ from pyro_mcp import __version__
@ -12,12 +13,135 @@ from pyro_mcp.demo import run_demo
from pyro_mcp.ollama_demo import DEFAULT_OLLAMA_BASE_URL, DEFAULT_OLLAMA_MODEL, run_ollama_tool_demo from pyro_mcp.ollama_demo import DEFAULT_OLLAMA_BASE_URL, DEFAULT_OLLAMA_MODEL, run_ollama_tool_demo
from pyro_mcp.runtime import DEFAULT_PLATFORM, doctor_report from pyro_mcp.runtime import DEFAULT_PLATFORM, doctor_report
from pyro_mcp.vm_environments import DEFAULT_CATALOG_VERSION from pyro_mcp.vm_environments import DEFAULT_CATALOG_VERSION
from pyro_mcp.vm_manager import (
DEFAULT_MEM_MIB,
DEFAULT_VCPU_COUNT,
)
def _print_json(payload: dict[str, Any]) -> None: def _print_json(payload: dict[str, Any]) -> None:
print(json.dumps(payload, indent=2, sort_keys=True)) print(json.dumps(payload, indent=2, sort_keys=True))
def _write_stream(text: str, *, stream: Any) -> None:
if text == "":
return
stream.write(text)
stream.flush()
def _print_run_human(payload: dict[str, Any]) -> None:
stdout = str(payload.get("stdout", ""))
stderr = str(payload.get("stderr", ""))
_write_stream(stdout, stream=sys.stdout)
_write_stream(stderr, stream=sys.stderr)
print(
"[run] "
f"environment={str(payload.get('environment', 'unknown'))} "
f"execution_mode={str(payload.get('execution_mode', 'unknown'))} "
f"exit_code={int(payload.get('exit_code', 1))} "
f"duration_ms={int(payload.get('duration_ms', 0))}",
file=sys.stderr,
flush=True,
)
def _print_env_list_human(payload: dict[str, Any]) -> None:
print(f"Catalog version: {payload.get('catalog_version', 'unknown')}")
environments = payload.get("environments")
if not isinstance(environments, list) or not environments:
print("No environments found.")
return
for entry in environments:
if not isinstance(entry, dict):
continue
status = "installed" if bool(entry.get("installed")) else "not installed"
print(
f"{str(entry.get('name', 'unknown'))} [{status}] "
f"{str(entry.get('description', '')).strip()}".rstrip()
)
def _print_env_detail_human(payload: dict[str, Any], *, action: str) -> None:
print(f"{action}: {str(payload.get('name', 'unknown'))}")
print(f"Version: {str(payload.get('version', 'unknown'))}")
print(
f"Distribution: {str(payload.get('distribution', 'unknown'))} "
f"{str(payload.get('distribution_version', 'unknown'))}"
)
print(f"Installed: {'yes' if bool(payload.get('installed')) else 'no'}")
print(f"Cache dir: {str(payload.get('cache_dir', 'unknown'))}")
packages = payload.get("default_packages")
if isinstance(packages, list) and packages:
print("Default packages: " + ", ".join(str(item) for item in packages))
description = str(payload.get("description", "")).strip()
if description != "":
print(f"Description: {description}")
if payload.get("installed"):
print(f"Install dir: {str(payload.get('install_dir', 'unknown'))}")
install_manifest = payload.get("install_manifest")
if install_manifest is not None:
print(f"Install manifest: {str(install_manifest)}")
kernel_image = payload.get("kernel_image")
if kernel_image is not None:
print(f"Kernel image: {str(kernel_image)}")
rootfs_image = payload.get("rootfs_image")
if rootfs_image is not None:
print(f"Rootfs image: {str(rootfs_image)}")
registry = payload.get("oci_registry")
repository = payload.get("oci_repository")
reference = payload.get("oci_reference")
if isinstance(registry, str) and isinstance(repository, str) and isinstance(reference, str):
print(f"OCI source: {registry}/{repository}:{reference}")
def _print_prune_human(payload: dict[str, Any]) -> None:
count = int(payload.get("count", 0))
print(f"Deleted {count} cached environment entr{'y' if count == 1 else 'ies'}.")
deleted = payload.get("deleted_environment_dirs")
if isinstance(deleted, list):
for entry in deleted:
print(f"- {entry}")
def _print_doctor_human(payload: dict[str, Any]) -> None:
issues = payload.get("issues")
runtime_ok = bool(payload.get("runtime_ok"))
print(f"Platform: {str(payload.get('platform', 'unknown'))}")
print(f"Runtime: {'PASS' if runtime_ok else 'FAIL'}")
kvm = payload.get("kvm")
if isinstance(kvm, dict):
print(
"KVM: "
f"exists={'yes' if bool(kvm.get('exists')) else 'no'} "
f"readable={'yes' if bool(kvm.get('readable')) else 'no'} "
f"writable={'yes' if bool(kvm.get('writable')) else 'no'}"
)
runtime = payload.get("runtime")
if isinstance(runtime, dict):
print(f"Environment cache: {str(runtime.get('cache_dir', 'unknown'))}")
capabilities = runtime.get("capabilities")
if isinstance(capabilities, dict):
print(
"Capabilities: "
f"vm_boot={'yes' if bool(capabilities.get('supports_vm_boot')) else 'no'} "
f"guest_exec={'yes' if bool(capabilities.get('supports_guest_exec')) else 'no'} "
"guest_network="
f"{'yes' if bool(capabilities.get('supports_guest_network')) else 'no'}"
)
networking = payload.get("networking")
if isinstance(networking, dict):
print(
"Networking: "
f"tun={'yes' if bool(networking.get('tun_available')) else 'no'} "
f"ip_forward={'yes' if bool(networking.get('ip_forward_enabled')) else 'no'}"
)
if isinstance(issues, list) and issues:
print("Issues:")
for issue in issues:
print(f"- {issue}")
def _build_parser() -> argparse.ArgumentParser: def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="pyro CLI for curated ephemeral Linux environments." description="pyro CLI for curated ephemeral Linux environments."
@ -27,15 +151,19 @@ def _build_parser() -> argparse.ArgumentParser:
env_parser = subparsers.add_parser("env", help="Inspect and manage curated environments.") env_parser = subparsers.add_parser("env", help="Inspect and manage curated environments.")
env_subparsers = env_parser.add_subparsers(dest="env_command", required=True) env_subparsers = env_parser.add_subparsers(dest="env_command", required=True)
env_subparsers.add_parser("list", help="List official environments.") list_parser = env_subparsers.add_parser("list", help="List official environments.")
list_parser.add_argument("--json", action="store_true")
pull_parser = env_subparsers.add_parser( pull_parser = env_subparsers.add_parser(
"pull", "pull",
help="Install an environment into the local cache.", help="Install an environment into the local cache.",
) )
pull_parser.add_argument("environment") pull_parser.add_argument("environment")
pull_parser.add_argument("--json", action="store_true")
inspect_parser = env_subparsers.add_parser("inspect", help="Inspect one environment.") inspect_parser = env_subparsers.add_parser("inspect", help="Inspect one environment.")
inspect_parser.add_argument("environment") inspect_parser.add_argument("environment")
env_subparsers.add_parser("prune", help="Delete stale cached environments.") inspect_parser.add_argument("--json", action="store_true")
prune_parser = env_subparsers.add_parser("prune", help="Delete stale cached environments.")
prune_parser.add_argument("--json", action="store_true")
mcp_parser = subparsers.add_parser("mcp", help="Run the MCP server.") mcp_parser = subparsers.add_parser("mcp", help="Run the MCP server.")
mcp_subparsers = mcp_parser.add_subparsers(dest="mcp_command", required=True) mcp_subparsers = mcp_parser.add_subparsers(dest="mcp_command", required=True)
@ -43,15 +171,18 @@ def _build_parser() -> argparse.ArgumentParser:
run_parser = subparsers.add_parser("run", help="Run one command inside an ephemeral VM.") run_parser = subparsers.add_parser("run", help="Run one command inside an ephemeral VM.")
run_parser.add_argument("environment") run_parser.add_argument("environment")
run_parser.add_argument("--vcpu-count", type=int, required=True) run_parser.add_argument("--vcpu-count", type=int, default=DEFAULT_VCPU_COUNT)
run_parser.add_argument("--mem-mib", type=int, required=True) run_parser.add_argument("--mem-mib", type=int, default=DEFAULT_MEM_MIB)
run_parser.add_argument("--timeout-seconds", type=int, default=30) run_parser.add_argument("--timeout-seconds", type=int, default=30)
run_parser.add_argument("--ttl-seconds", type=int, default=600) run_parser.add_argument("--ttl-seconds", type=int, default=600)
run_parser.add_argument("--network", action="store_true") run_parser.add_argument("--network", action="store_true")
run_parser.add_argument("--allow-host-compat", action="store_true")
run_parser.add_argument("--json", action="store_true")
run_parser.add_argument("command_args", nargs="*") run_parser.add_argument("command_args", nargs="*")
doctor_parser = subparsers.add_parser("doctor", help="Inspect runtime and host diagnostics.") doctor_parser = subparsers.add_parser("doctor", help="Inspect runtime and host diagnostics.")
doctor_parser.add_argument("--platform", default=DEFAULT_PLATFORM) doctor_parser.add_argument("--platform", default=DEFAULT_PLATFORM)
doctor_parser.add_argument("--json", action="store_true")
demo_parser = subparsers.add_parser("demo", help="Run built-in demos.") demo_parser = subparsers.add_parser("demo", help="Run built-in demos.")
demo_subparsers = demo_parser.add_subparsers(dest="demo_command") demo_subparsers = demo_parser.add_subparsers(dest="demo_command")
@ -77,27 +208,42 @@ def main() -> None:
pyro = Pyro() pyro = Pyro()
if args.command == "env": if args.command == "env":
if args.env_command == "list": if args.env_command == "list":
_print_json( list_payload: dict[str, Any] = {
{
"catalog_version": DEFAULT_CATALOG_VERSION, "catalog_version": DEFAULT_CATALOG_VERSION,
"environments": pyro.list_environments(), "environments": pyro.list_environments(),
} }
) if bool(args.json):
_print_json(list_payload)
else:
_print_env_list_human(list_payload)
return return
if args.env_command == "pull": if args.env_command == "pull":
_print_json(dict(pyro.pull_environment(args.environment))) pull_payload = pyro.pull_environment(args.environment)
if bool(args.json):
_print_json(pull_payload)
else:
_print_env_detail_human(pull_payload, action="Pulled")
return return
if args.env_command == "inspect": if args.env_command == "inspect":
_print_json(dict(pyro.inspect_environment(args.environment))) inspect_payload = pyro.inspect_environment(args.environment)
if bool(args.json):
_print_json(inspect_payload)
else:
_print_env_detail_human(inspect_payload, action="Environment")
return return
if args.env_command == "prune": if args.env_command == "prune":
_print_json(dict(pyro.prune_environments())) prune_payload = pyro.prune_environments()
if bool(args.json):
_print_json(prune_payload)
else:
_print_prune_human(prune_payload)
return return
if args.command == "mcp": if args.command == "mcp":
pyro.create_server().run(transport="stdio") pyro.create_server().run(transport="stdio")
return return
if args.command == "run": if args.command == "run":
command = _require_command(args.command_args) command = _require_command(args.command_args)
try:
result = pyro.run_in_vm( result = pyro.run_in_vm(
environment=args.environment, environment=args.environment,
command=command, command=command,
@ -106,11 +252,28 @@ def main() -> None:
timeout_seconds=args.timeout_seconds, timeout_seconds=args.timeout_seconds,
ttl_seconds=args.ttl_seconds, ttl_seconds=args.ttl_seconds,
network=args.network, network=args.network,
allow_host_compat=args.allow_host_compat,
) )
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(result) _print_json(result)
else:
_print_run_human(result)
exit_code = int(result.get("exit_code", 1))
if exit_code != 0:
raise SystemExit(exit_code)
return return
if args.command == "doctor": if args.command == "doctor":
_print_json(doctor_report(platform=args.platform)) payload = doctor_report(platform=args.platform)
if bool(args.json):
_print_json(payload)
else:
_print_doctor_human(payload)
return return
if args.command == "demo" and args.demo_command == "ollama": if args.command == "demo" and args.demo_command == "ollama":
try: try:

View file

@ -11,6 +11,8 @@ PUBLIC_CLI_RUN_FLAGS = (
"--timeout-seconds", "--timeout-seconds",
"--ttl-seconds", "--ttl-seconds",
"--network", "--network",
"--allow-host-compat",
"--json",
) )
PUBLIC_SDK_METHODS = ( PUBLIC_SDK_METHODS = (

View file

@ -6,6 +6,7 @@ import json
from typing import Any from typing import Any
from pyro_mcp.api import Pyro from pyro_mcp.api import Pyro
from pyro_mcp.vm_manager import DEFAULT_MEM_MIB, DEFAULT_TTL_SECONDS, DEFAULT_VCPU_COUNT
INTERNET_PROBE_COMMAND = ( INTERNET_PROBE_COMMAND = (
'python3 -c "import urllib.request; ' 'python3 -c "import urllib.request; '
@ -30,10 +31,10 @@ def run_demo(*, network: bool = False) -> dict[str, Any]:
return pyro.run_in_vm( return pyro.run_in_vm(
environment="debian:12", environment="debian:12",
command=_demo_command(status), command=_demo_command(status),
vcpu_count=1, vcpu_count=DEFAULT_VCPU_COUNT,
mem_mib=512, mem_mib=DEFAULT_MEM_MIB,
timeout_seconds=30, timeout_seconds=30,
ttl_seconds=600, ttl_seconds=DEFAULT_TTL_SECONDS,
network=network, network=network,
) )

View file

@ -10,6 +10,13 @@ from collections.abc import Callable
from typing import Any, Final, cast from typing import Any, Final, cast
from pyro_mcp.api import Pyro from pyro_mcp.api import Pyro
from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
)
__all__ = ["Pyro", "run_ollama_tool_demo"] __all__ = ["Pyro", "run_ollama_tool_demo"]
@ -39,8 +46,9 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
"timeout_seconds": {"type": "integer"}, "timeout_seconds": {"type": "integer"},
"ttl_seconds": {"type": "integer"}, "ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"}, "network": {"type": "boolean"},
"allow_host_compat": {"type": "boolean"},
}, },
"required": ["environment", "command", "vcpu_count", "mem_mib"], "required": ["environment", "command"],
"additionalProperties": False, "additionalProperties": False,
}, },
}, },
@ -61,7 +69,7 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
"type": "function", "type": "function",
"function": { "function": {
"name": "vm_create", "name": "vm_create",
"description": "Create an ephemeral VM with explicit vCPU and memory sizing.", "description": "Create an ephemeral VM with optional resource sizing.",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
@ -70,8 +78,9 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
"mem_mib": {"type": "integer"}, "mem_mib": {"type": "integer"},
"ttl_seconds": {"type": "integer"}, "ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"}, "network": {"type": "boolean"},
"allow_host_compat": {"type": "boolean"},
}, },
"required": ["environment", "vcpu_count", "mem_mib"], "required": ["environment"],
"additionalProperties": False, "additionalProperties": False,
}, },
}, },
@ -192,6 +201,12 @@ def _require_int(arguments: dict[str, Any], key: str) -> int:
raise ValueError(f"{key} must be an integer") raise ValueError(f"{key} must be an integer")
def _optional_int(arguments: dict[str, Any], key: str, *, default: int) -> int:
if key not in arguments:
return default
return _require_int(arguments, key)
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool: def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
value = arguments.get(key, default) value = arguments.get(key, default)
if isinstance(value, bool): if isinstance(value, bool):
@ -211,27 +226,37 @@ def _dispatch_tool_call(
pyro: Pyro, tool_name: str, arguments: dict[str, Any] pyro: Pyro, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]: ) -> dict[str, Any]:
if tool_name == "vm_run": if tool_name == "vm_run":
ttl_seconds = arguments.get("ttl_seconds", 600) ttl_seconds = arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS)
timeout_seconds = arguments.get("timeout_seconds", 30) timeout_seconds = arguments.get("timeout_seconds", DEFAULT_TIMEOUT_SECONDS)
return pyro.run_in_vm( return pyro.run_in_vm(
environment=_require_str(arguments, "environment"), environment=_require_str(arguments, "environment"),
command=_require_str(arguments, "command"), command=_require_str(arguments, "command"),
vcpu_count=_require_int(arguments, "vcpu_count"), vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
mem_mib=_require_int(arguments, "mem_mib"), mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"), timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"), ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
network=_require_bool(arguments, "network", default=False), network=_require_bool(arguments, "network", default=False),
allow_host_compat=_require_bool(
arguments,
"allow_host_compat",
default=DEFAULT_ALLOW_HOST_COMPAT,
),
) )
if tool_name == "vm_list_environments": if tool_name == "vm_list_environments":
return {"environments": pyro.list_environments()} return {"environments": pyro.list_environments()}
if tool_name == "vm_create": if tool_name == "vm_create":
ttl_seconds = arguments.get("ttl_seconds", 600) ttl_seconds = arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS)
return pyro.create_vm( return pyro.create_vm(
environment=_require_str(arguments, "environment"), environment=_require_str(arguments, "environment"),
vcpu_count=_require_int(arguments, "vcpu_count"), vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
mem_mib=_require_int(arguments, "mem_mib"), mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"), ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
network=_require_bool(arguments, "network", default=False), network=_require_bool(arguments, "network", default=False),
allow_host_compat=_require_bool(
arguments,
"allow_host_compat",
default=DEFAULT_ALLOW_HOST_COMPAT,
),
) )
if tool_name == "vm_start": if tool_name == "vm_start":
return pyro.start_vm(_require_str(arguments, "vm_id")) return pyro.start_vm(_require_str(arguments, "vm_id"))
@ -275,10 +300,10 @@ def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
return pyro.run_in_vm( return pyro.run_in_vm(
environment="debian:12", environment="debian:12",
command=NETWORK_PROOF_COMMAND, command=NETWORK_PROOF_COMMAND,
vcpu_count=1, vcpu_count=DEFAULT_VCPU_COUNT,
mem_mib=512, mem_mib=DEFAULT_MEM_MIB,
timeout_seconds=60, timeout_seconds=60,
ttl_seconds=600, ttl_seconds=DEFAULT_TTL_SECONDS,
network=True, network=True,
) )

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0" DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "1.0.0" DEFAULT_CATALOG_VERSION = "2.0.0"
OCI_MANIFEST_ACCEPT = ", ".join( OCI_MANIFEST_ACCEPT = ", ".join(
( (
"application/vnd.oci.image.index.v1+json", "application/vnd.oci.image.index.v1+json",
@ -48,7 +48,7 @@ class VmEnvironment:
oci_repository: str | None = None oci_repository: str | None = None
oci_reference: str | None = None oci_reference: str | None = None
source_digest: str | None = None source_digest: str | None = None
compatibility: str = ">=1.0.0,<2.0.0" compatibility: str = ">=2.0.0,<3.0.0"
@dataclass(frozen=True) @dataclass(frozen=True)
@ -114,6 +114,11 @@ def _default_cache_dir() -> Path:
) )
def default_cache_dir() -> Path:
"""Return the canonical default environment cache directory."""
return _default_cache_dir()
def _manifest_profile_digest(runtime_paths: RuntimePaths, profile_name: str) -> str | None: def _manifest_profile_digest(runtime_paths: RuntimePaths, profile_name: str) -> str | None:
profiles = runtime_paths.manifest.get("profiles") profiles = runtime_paths.manifest.get("profiles")
if not isinstance(profiles, dict): if not isinstance(profiles, dict):

View file

@ -19,13 +19,19 @@ from pyro_mcp.runtime import (
resolve_runtime_paths, resolve_runtime_paths,
runtime_capabilities, runtime_capabilities,
) )
from pyro_mcp.vm_environments import EnvironmentStore, get_environment from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_environment
from pyro_mcp.vm_firecracker import build_launch_plan from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
VmState = Literal["created", "started", "stopped"] VmState = Literal["created", "started", "stopped"]
DEFAULT_VCPU_COUNT = 1
DEFAULT_MEM_MIB = 1024
DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
@dataclass @dataclass
class VmInstance: class VmInstance:
@ -41,6 +47,7 @@ class VmInstance:
workdir: Path workdir: Path
state: VmState = "created" state: VmState = "created"
network_requested: bool = False network_requested: bool = False
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT
firecracker_pid: int | None = None firecracker_pid: int | None = None
last_error: str | None = None last_error: str | None = None
metadata: dict[str, str] = field(default_factory=dict) metadata: dict[str, str] = field(default_factory=dict)
@ -262,7 +269,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
) )
instance.firecracker_pid = process.pid instance.firecracker_pid = process.pid
instance.metadata["execution_mode"] = ( instance.metadata["execution_mode"] = (
"guest_vsock" if self._runtime_capabilities.supports_guest_exec else "host_compat" "guest_vsock" if self._runtime_capabilities.supports_guest_exec else "guest_boot_only"
) )
instance.metadata["boot_mode"] = "native" instance.metadata["boot_mode"] = "native"
@ -342,6 +349,11 @@ class VmManager:
MAX_MEM_MIB = 32768 MAX_MEM_MIB = 32768
MIN_TTL_SECONDS = 60 MIN_TTL_SECONDS = 60
MAX_TTL_SECONDS = 3600 MAX_TTL_SECONDS = 3600
DEFAULT_VCPU_COUNT = DEFAULT_VCPU_COUNT
DEFAULT_MEM_MIB = DEFAULT_MEM_MIB
DEFAULT_TIMEOUT_SECONDS = DEFAULT_TIMEOUT_SECONDS
DEFAULT_TTL_SECONDS = DEFAULT_TTL_SECONDS
DEFAULT_ALLOW_HOST_COMPAT = DEFAULT_ALLOW_HOST_COMPAT
def __init__( def __init__(
self, self,
@ -355,7 +367,7 @@ class VmManager:
) -> None: ) -> None:
self._backend_name = backend_name or "firecracker" self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp") self._base_dir = base_dir or Path("/tmp/pyro-mcp")
resolved_cache_dir = cache_dir or self._base_dir / ".environment-cache" resolved_cache_dir = cache_dir or default_cache_dir()
self._runtime_paths = runtime_paths self._runtime_paths = runtime_paths
if self._backend_name == "firecracker": if self._backend_name == "firecracker":
self._runtime_paths = self._runtime_paths or resolve_runtime_paths() self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
@ -420,10 +432,11 @@ class VmManager:
self, self,
*, *,
environment: str, environment: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds) self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_environment(environment, runtime_paths=self._runtime_paths) get_environment(environment, runtime_paths=self._runtime_paths)
@ -446,7 +459,9 @@ class VmManager:
expires_at=now + ttl_seconds, expires_at=now + ttl_seconds,
workdir=self._base_dir / vm_id, workdir=self._base_dir / vm_id,
network_requested=network, network_requested=network,
allow_host_compat=allow_host_compat,
) )
instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
self._backend.create(instance) self._backend.create(instance)
self._instances[vm_id] = instance self._instances[vm_id] = instance
return self._serialize(instance) return self._serialize(instance)
@ -456,11 +471,12 @@ class VmManager:
*, *,
environment: str, environment: str,
command: str, command: str,
vcpu_count: int, vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int, mem_mib: int = DEFAULT_MEM_MIB,
timeout_seconds: int = 30, timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
ttl_seconds: int = 600, ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False, network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
) -> dict[str, Any]: ) -> dict[str, Any]:
created = self.create_vm( created = self.create_vm(
environment=environment, environment=environment,
@ -468,6 +484,7 @@ class VmManager:
mem_mib=mem_mib, mem_mib=mem_mib,
ttl_seconds=ttl_seconds, ttl_seconds=ttl_seconds,
network=network, network=network,
allow_host_compat=allow_host_compat,
) )
vm_id = str(created["vm_id"]) vm_id = str(created["vm_id"])
try: try:
@ -486,6 +503,12 @@ class VmManager:
self._ensure_not_expired_locked(instance, time.time()) self._ensure_not_expired_locked(instance, time.time())
if instance.state not in {"created", "stopped"}: if instance.state not in {"created", "stopped"}:
raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}") raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
self._require_guest_boot_or_opt_in(instance)
if not self._runtime_capabilities.supports_vm_boot:
instance.metadata["execution_mode"] = "host_compat"
instance.metadata["boot_mode"] = "compat"
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
self._backend.start(instance) self._backend.start(instance)
instance.state = "started" instance.state = "started"
return self._serialize(instance) return self._serialize(instance)
@ -498,8 +521,11 @@ class VmManager:
self._ensure_not_expired_locked(instance, time.time()) self._ensure_not_expired_locked(instance, time.time())
if instance.state != "started": if instance.state != "started":
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec") raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
self._require_guest_exec_or_opt_in(instance)
if not self._runtime_capabilities.supports_guest_exec:
instance.metadata["execution_mode"] = "host_compat"
exec_result = self._backend.exec(instance, command, timeout_seconds) exec_result = self._backend.exec(instance, command, timeout_seconds)
execution_mode = instance.metadata.get("execution_mode", "host_compat") execution_mode = instance.metadata.get("execution_mode", "unknown")
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup") cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return { return {
"vm_id": vm_id, "vm_id": vm_id,
@ -587,12 +613,35 @@ class VmManager:
"expires_at": instance.expires_at, "expires_at": instance.expires_at,
"state": instance.state, "state": instance.state,
"network_enabled": instance.network is not None, "network_enabled": instance.network is not None,
"allow_host_compat": instance.allow_host_compat,
"guest_ip": instance.network.guest_ip if instance.network is not None else None, "guest_ip": instance.network.guest_ip if instance.network is not None else None,
"tap_name": instance.network.tap_name if instance.network is not None else None, "tap_name": instance.network.tap_name if instance.network is not None else None,
"execution_mode": instance.metadata.get("execution_mode", "host_compat"), "execution_mode": instance.metadata.get("execution_mode", "pending"),
"metadata": instance.metadata, "metadata": instance.metadata,
} }
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
return
reason = self._runtime_capabilities.reason or "runtime does not support real VM boot"
raise RuntimeError(
"guest boot is unavailable and host compatibility mode is disabled: "
f"{reason}. Set allow_host_compat=True (CLI: --allow-host-compat) to opt into "
"host execution."
)
def _require_guest_exec_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_guest_exec or instance.allow_host_compat:
return
reason = self._runtime_capabilities.reason or (
"runtime does not support guest command execution"
)
raise RuntimeError(
"guest command execution is unavailable and host compatibility mode is disabled: "
f"{reason}. Set allow_host_compat=True (CLI: --allow-host-compat) to opt into "
"host execution."
)
def _get_instance_locked(self, vm_id: str) -> VmInstance: def _get_instance_locked(self, vm_id: str) -> VmInstance:
try: try:
return self._instances[vm_id] return self._instances[vm_id]

View file

@ -25,6 +25,7 @@ def test_pyro_run_in_vm_delegates_to_manager(tmp_path: Path) -> None:
timeout_seconds=30, timeout_seconds=30,
ttl_seconds=600, ttl_seconds=600,
network=False, network=False,
allow_host_compat=True,
) )
assert int(result["exit_code"]) == 0 assert int(result["exit_code"]) == 0
assert str(result["stdout"]) == "ok\n" assert str(result["stdout"]) == "ok\n"
@ -74,12 +75,30 @@ def test_pyro_vm_run_tool_executes(tmp_path: Path) -> None:
{ {
"environment": "debian:12-base", "environment": "debian:12-base",
"command": "printf 'ok\\n'", "command": "printf 'ok\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"network": False, "network": False,
"allow_host_compat": True,
}, },
) )
) )
result = asyncio.run(_run()) result = asyncio.run(_run())
assert int(result["exit_code"]) == 0 assert int(result["exit_code"]) == 0
def test_pyro_create_vm_defaults_sizing_and_host_compat(tmp_path: Path) -> None:
pyro = Pyro(
manager=VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
)
created = pyro.create_vm(
environment="debian:12-base",
allow_host_compat=True,
)
assert created["vcpu_count"] == 1
assert created["mem_mib"] == 1024
assert created["allow_host_compat"] is True

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import argparse import argparse
import json import json
import sys
from typing import Any from typing import Any
import pytest import pytest
@ -29,6 +30,8 @@ def test_cli_run_prints_json(
timeout_seconds=30, timeout_seconds=30,
ttl_seconds=600, ttl_seconds=600,
network=True, network=True,
allow_host_compat=False,
json=True,
command_args=["--", "echo", "hi"], command_args=["--", "echo", "hi"],
) )
@ -44,7 +47,7 @@ def test_cli_doctor_prints_json(
) -> None: ) -> None:
class StubParser: class StubParser:
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="doctor", platform="linux-x86_64") return argparse.Namespace(command="doctor", platform="linux-x86_64", json=True)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr( monkeypatch.setattr(
@ -93,7 +96,7 @@ def test_cli_env_list_prints_json(
class StubParser: class StubParser:
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="env", env_command="list") return argparse.Namespace(command="env", env_command="list", json=True)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro) monkeypatch.setattr(cli, "Pyro", StubPyro)
@ -102,6 +105,372 @@ def test_cli_env_list_prints_json(
assert output["environments"][0]["name"] == "debian:12" assert output["environments"][0]["name"] == "debian:12"
def test_cli_run_prints_human_output(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubPyro:
def run_in_vm(self, **kwargs: Any) -> dict[str, Any]:
assert kwargs["vcpu_count"] == 1
assert kwargs["mem_mib"] == 1024
return {
"environment": kwargs["environment"],
"execution_mode": "guest_vsock",
"exit_code": 0,
"duration_ms": 12,
"stdout": "hi\n",
"stderr": "",
}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="run",
environment="debian:12",
vcpu_count=1,
mem_mib=1024,
timeout_seconds=30,
ttl_seconds=600,
network=False,
allow_host_compat=False,
json=False,
command_args=["--", "echo", "hi"],
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
captured = capsys.readouterr()
assert captured.out == "hi\n"
assert "[run] environment=debian:12 execution_mode=guest_vsock exit_code=0" in captured.err
def test_cli_run_exits_with_command_status(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubPyro:
def run_in_vm(self, **kwargs: Any) -> dict[str, Any]:
del kwargs
return {
"environment": "debian:12",
"execution_mode": "guest_vsock",
"exit_code": 7,
"duration_ms": 5,
"stdout": "",
"stderr": "bad\n",
}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="run",
environment="debian:12",
vcpu_count=1,
mem_mib=1024,
timeout_seconds=30,
ttl_seconds=600,
network=False,
allow_host_compat=False,
json=False,
command_args=["--", "false"],
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
with pytest.raises(SystemExit, match="7"):
cli.main()
captured = capsys.readouterr()
assert "bad\n" in captured.err
def test_cli_requires_run_command() -> None: def test_cli_requires_run_command() -> None:
with pytest.raises(ValueError, match="command is required"): with pytest.raises(ValueError, match="command is required"):
cli._require_command([]) cli._require_command([])
def test_print_env_helpers_render_human_output(capsys: pytest.CaptureFixture[str]) -> None:
cli._print_env_list_human(
{
"catalog_version": "2.0.0",
"environments": [
{"name": "debian:12", "installed": True, "description": "Git environment"},
"ignored",
],
}
)
cli._print_env_detail_human(
{
"name": "debian:12",
"version": "1.0.0",
"distribution": "debian",
"distribution_version": "12",
"installed": True,
"cache_dir": "/cache",
"default_packages": ["bash", "git"],
"description": "Git environment",
"install_dir": "/cache/linux-x86_64/debian_12-1.0.0",
"install_manifest": "/cache/linux-x86_64/debian_12-1.0.0/environment.json",
"kernel_image": "/cache/vmlinux",
"rootfs_image": "/cache/rootfs.ext4",
"oci_registry": "registry-1.docker.io",
"oci_repository": "thalesmaciel/pyro-environment-debian-12",
"oci_reference": "1.0.0",
},
action="Environment",
)
cli._print_prune_human({"count": 2, "deleted_environment_dirs": ["a", "b"]})
cli._print_doctor_human(
{
"platform": "linux-x86_64",
"runtime_ok": False,
"issues": ["broken"],
"kvm": {"exists": True, "readable": True, "writable": False},
"runtime": {
"cache_dir": "/cache",
"capabilities": {
"supports_vm_boot": True,
"supports_guest_exec": False,
"supports_guest_network": True,
},
},
"networking": {"tun_available": True, "ip_forward_enabled": False},
}
)
captured = capsys.readouterr().out
assert "Catalog version: 2.0.0" in captured
assert "debian:12 [installed] Git environment" in captured
assert "Install manifest: /cache/linux-x86_64/debian_12-1.0.0/environment.json" in captured
assert "Deleted 2 cached environment entries." in captured
assert "Runtime: FAIL" in captured
assert "Issues:" in captured
def test_print_env_list_human_handles_empty(capsys: pytest.CaptureFixture[str]) -> None:
cli._print_env_list_human({"catalog_version": "2.0.0", "environments": []})
output = capsys.readouterr().out
assert "No environments found." in output
def test_write_stream_skips_empty(capsys: pytest.CaptureFixture[str]) -> None:
cli._write_stream("", stream=sys.stdout)
cli._write_stream("x", stream=sys.stdout)
captured = capsys.readouterr()
assert captured.out == "x"
def test_cli_env_pull_prints_human(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def pull_environment(self, environment: str) -> dict[str, object]:
assert environment == "debian:12"
return {
"name": "debian:12",
"version": "1.0.0",
"distribution": "debian",
"distribution_version": "12",
"installed": True,
"cache_dir": "/cache",
}
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="env",
env_command="pull",
environment="debian:12",
json=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
output = capsys.readouterr().out
assert "Pulled: debian:12" in output
def test_cli_env_inspect_and_prune_print_human(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def inspect_environment(self, environment: str) -> dict[str, object]:
assert environment == "debian:12"
return {
"name": "debian:12",
"version": "1.0.0",
"distribution": "debian",
"distribution_version": "12",
"installed": False,
"cache_dir": "/cache",
}
def prune_environments(self) -> dict[str, object]:
return {"count": 1, "deleted_environment_dirs": ["stale"]}
class InspectParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="env",
env_command="inspect",
environment="debian:12",
json=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: InspectParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
class PruneParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="env", env_command="prune", json=False)
monkeypatch.setattr(cli, "_build_parser", lambda: PruneParser())
cli.main()
output = capsys.readouterr().out
assert "Environment: debian:12" in output
assert "Deleted 1 cached environment entry." in output
def test_cli_doctor_prints_human(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="doctor", platform="linux-x86_64", json=False)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(
cli,
"doctor_report",
lambda platform: {
"platform": platform,
"runtime_ok": True,
"issues": [],
"kvm": {"exists": True, "readable": True, "writable": True},
},
)
cli.main()
output = capsys.readouterr().out
assert "Runtime: PASS" in output
def test_cli_run_json_error_exits_nonzero(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def run_in_vm(self, **kwargs: Any) -> dict[str, Any]:
del kwargs
raise RuntimeError("guest boot is unavailable")
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="run",
environment="debian:12",
vcpu_count=1,
mem_mib=1024,
timeout_seconds=30,
ttl_seconds=600,
network=False,
allow_host_compat=False,
json=True,
command_args=["--", "echo", "hi"],
)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
with pytest.raises(SystemExit, match="1"):
cli.main()
payload = json.loads(capsys.readouterr().out)
assert payload["ok"] is False
def test_cli_mcp_runs_stdio_transport(monkeypatch: pytest.MonkeyPatch) -> None:
observed: dict[str, str] = {}
class StubPyro:
def create_server(self) -> Any:
return type(
"StubServer",
(),
{"run": staticmethod(lambda transport: observed.update({"transport": transport}))},
)()
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="mcp", mcp_command="serve")
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "Pyro", StubPyro)
cli.main()
assert observed == {"transport": "stdio"}
def test_cli_demo_default_prints_json(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(command="demo", demo_command=None, network=False)
monkeypatch.setattr(cli, "_build_parser", lambda: StubParser())
monkeypatch.setattr(cli, "run_demo", lambda network: {"exit_code": 0, "network": network})
cli.main()
output = json.loads(capsys.readouterr().out)
assert output["exit_code"] == 0
def test_cli_demo_ollama_verbose_and_error_paths(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class VerboseParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="demo",
demo_command="ollama",
base_url="http://localhost:11434/v1",
model="llama3.2:3b",
verbose=True,
)
monkeypatch.setattr(cli, "_build_parser", lambda: VerboseParser())
monkeypatch.setattr(
cli,
"run_ollama_tool_demo",
lambda **kwargs: {
"exec_result": {"exit_code": 0, "execution_mode": "guest_vsock", "stdout": "true\n"},
"fallback_used": False,
},
)
cli.main()
output = capsys.readouterr().out
assert "[summary] stdout=true" in output
class ErrorParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(
command="demo",
demo_command="ollama",
base_url="http://localhost:11434/v1",
model="llama3.2:3b",
verbose=False,
)
monkeypatch.setattr(cli, "_build_parser", lambda: ErrorParser())
monkeypatch.setattr(
cli,
"run_ollama_tool_demo",
lambda **kwargs: (_ for _ in ()).throw(RuntimeError("tool loop failed")),
)
with pytest.raises(SystemExit, match="1"):
cli.main()
assert "[error] tool loop failed" in capsys.readouterr().out

View file

@ -53,7 +53,7 @@ def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
"environment": "debian:12", "environment": "debian:12",
"command": "git --version", "command": "git --version",
"vcpu_count": 1, "vcpu_count": 1,
"mem_mib": 512, "mem_mib": 1024,
"timeout_seconds": 30, "timeout_seconds": 30,
"ttl_seconds": 600, "ttl_seconds": 600,
"network": False, "network": False,
@ -95,3 +95,4 @@ def test_run_demo_network_uses_probe(monkeypatch: pytest.MonkeyPatch) -> None:
demo_module.run_demo(network=True) demo_module.run_demo(network=True)
assert "https://example.com" in str(captured["command"]) assert "https://example.com" in str(captured["command"])
assert captured["network"] is True assert captured["network"] is True
assert captured["mem_mib"] == 1024

View file

@ -52,9 +52,8 @@ def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, An
{ {
"environment": "debian:12", "environment": "debian:12",
"command": "printf 'true\\n'", "command": "printf 'true\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"network": True, "network": True,
"allow_host_compat": True,
} }
), ),
}, },
@ -119,9 +118,8 @@ def test_run_ollama_tool_demo_accepts_legacy_profile_and_string_network(
{ {
"profile": "debian:12", "profile": "debian:12",
"command": "printf 'true\\n'", "command": "printf 'true\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"network": "true", "network": "true",
"allow_host_compat": True,
} }
), ),
}, },
@ -224,8 +222,7 @@ def test_run_ollama_tool_demo_resolves_vm_id_placeholder(
"arguments": json.dumps( "arguments": json.dumps(
{ {
"environment": "debian:12", "environment": "debian:12",
"vcpu_count": "2", "allow_host_compat": True,
"mem_mib": "2048",
} }
), ),
}, },
@ -280,6 +277,7 @@ def test_dispatch_tool_call_vm_exec_autostarts_created_vm(tmp_path: Path) -> Non
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=60, ttl_seconds=60,
allow_host_compat=True,
) )
vm_id = str(created["vm_id"]) vm_id = str(created["vm_id"])
@ -458,6 +456,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
"mem_mib": "512", "mem_mib": "512",
"ttl_seconds": "60", "ttl_seconds": "60",
"network": False, "network": False,
"allow_host_compat": True,
}, },
) )
vm_id = str(created["vm_id"]) vm_id = str(created["vm_id"])
@ -477,10 +476,9 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
{ {
"environment": "debian:12-base", "environment": "debian:12-base",
"command": "printf 'true\\n'", "command": "printf 'true\\n'",
"vcpu_count": "1",
"mem_mib": "512",
"timeout_seconds": "30", "timeout_seconds": "30",
"network": False, "network": False,
"allow_host_compat": True,
}, },
) )
assert int(executed_run["exit_code"]) == 0 assert int(executed_run["exit_code"]) == 0

View file

@ -49,11 +49,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
assert command_name in help_text assert command_name in help_text
run_parser = _build_parser() run_parser = _build_parser()
run_help = run_parser.parse_args( run_help = run_parser.parse_args(["run", "debian:12-base", "--", "true"])
["run", "debian:12-base", "--vcpu-count", "1", "--mem-mib", "512", "--", "true"]
)
assert run_help.command == "run" assert run_help.command == "run"
assert run_help.environment == "debian:12-base" assert run_help.environment == "debian:12-base"
assert run_help.vcpu_count == 1
assert run_help.mem_mib == 1024
run_help_text = _subparser_choice(parser, "run").format_help() run_help_text = _subparser_choice(parser, "run").format_help()
for flag in PUBLIC_CLI_RUN_FLAGS: for flag in PUBLIC_CLI_RUN_FLAGS:

View file

@ -56,10 +56,9 @@ def test_vm_run_round_trip(tmp_path: Path) -> None:
{ {
"environment": "debian:12", "environment": "debian:12",
"command": "printf 'git version 2.0\\n'", "command": "printf 'git version 2.0\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"ttl_seconds": 600, "ttl_seconds": 600,
"network": False, "network": False,
"allow_host_compat": True,
}, },
) )
) )
@ -109,9 +108,8 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
"vm_create", "vm_create",
{ {
"environment": "debian:12-base", "environment": "debian:12-base",
"vcpu_count": 1,
"mem_mib": 512,
"ttl_seconds": 600, "ttl_seconds": 600,
"allow_host_compat": True,
}, },
) )
) )
@ -127,9 +125,8 @@ def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
"vm_create", "vm_create",
{ {
"environment": "debian:12-base", "environment": "debian:12-base",
"vcpu_count": 1,
"mem_mib": 512,
"ttl_seconds": 1, "ttl_seconds": 1,
"allow_host_compat": True,
}, },
) )
) )

View file

@ -22,6 +22,7 @@ def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=600, ttl_seconds=600,
allow_host_compat=True,
) )
vm_id = str(created["vm_id"]) vm_id = str(created["vm_id"])
started = manager.start_vm(vm_id) started = manager.start_vm(vm_id)
@ -47,6 +48,7 @@ def test_vm_manager_exec_timeout(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=600, ttl_seconds=600,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
manager.start_vm(vm_id) manager.start_vm(vm_id)
@ -67,6 +69,7 @@ def test_vm_manager_stop_and_delete(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=600, ttl_seconds=600,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
manager.start_vm(vm_id) manager.start_vm(vm_id)
@ -89,6 +92,7 @@ def test_vm_manager_reaps_expired(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=1, ttl_seconds=1,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
instance = manager._instances[vm_id] # noqa: SLF001 instance = manager._instances[vm_id] # noqa: SLF001
@ -112,6 +116,7 @@ def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=1, ttl_seconds=1,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
manager.start_vm(vm_id) manager.start_vm(vm_id)
@ -145,9 +150,21 @@ def test_vm_manager_max_active_limit(tmp_path: Path) -> None:
max_active_vms=1, max_active_vms=1,
network_manager=TapNetworkManager(enabled=False), network_manager=TapNetworkManager(enabled=False),
) )
manager.create_vm(environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) manager.create_vm(
environment="debian:12-base",
vcpu_count=1,
mem_mib=512,
ttl_seconds=600,
allow_host_compat=True,
)
with pytest.raises(RuntimeError, match="max active VMs reached"): with pytest.raises(RuntimeError, match="max active VMs reached"):
manager.create_vm(environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600) manager.create_vm(
environment="debian:12-base",
vcpu_count=1,
mem_mib=512,
ttl_seconds=600,
allow_host_compat=True,
)
def test_vm_manager_state_validation(tmp_path: Path) -> None: def test_vm_manager_state_validation(tmp_path: Path) -> None:
@ -162,6 +179,7 @@ def test_vm_manager_state_validation(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=600, ttl_seconds=600,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
with pytest.raises(RuntimeError, match="must be in 'started' state"): with pytest.raises(RuntimeError, match="must be in 'started' state"):
@ -186,6 +204,7 @@ def test_vm_manager_status_expired_raises(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=1, ttl_seconds=1,
allow_host_compat=True,
)["vm_id"] )["vm_id"]
) )
manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001 manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001
@ -213,6 +232,7 @@ def test_vm_manager_network_info(tmp_path: Path) -> None:
vcpu_count=1, vcpu_count=1,
mem_mib=512, mem_mib=512,
ttl_seconds=600, ttl_seconds=600,
allow_host_compat=True,
) )
vm_id = str(created["vm_id"]) vm_id = str(created["vm_id"])
status = manager.status_vm(vm_id) status = manager.status_vm(vm_id)
@ -236,6 +256,7 @@ def test_vm_manager_run_vm(tmp_path: Path) -> None:
timeout_seconds=30, timeout_seconds=30,
ttl_seconds=600, ttl_seconds=600,
network=False, network=False,
allow_host_compat=True,
) )
assert int(result["exit_code"]) == 0 assert int(result["exit_code"]) == 0
assert str(result["stdout"]) == "ok\n" assert str(result["stdout"]) == "ok\n"
@ -283,3 +304,33 @@ def test_vm_manager_firecracker_backend_path(
network_manager=TapNetworkManager(enabled=False), network_manager=TapNetworkManager(enabled=False),
) )
assert manager._backend_name == "firecracker" # noqa: SLF001 assert manager._backend_name == "firecracker" # noqa: SLF001
def test_vm_manager_fails_closed_without_host_compat_opt_in(tmp_path: Path) -> None:
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
vm_id = str(
manager.create_vm(
environment="debian:12-base",
ttl_seconds=600,
)["vm_id"]
)
with pytest.raises(RuntimeError, match="guest boot is unavailable"):
manager.start_vm(vm_id)
def test_vm_manager_uses_canonical_default_cache_dir(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
monkeypatch.setenv("PYRO_ENVIRONMENT_CACHE_DIR", str(tmp_path / "cache"))
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
assert manager._environment_store.cache_dir == tmp_path / "cache" # noqa: SLF001

2
uv.lock generated
View file

@ -706,7 +706,7 @@ crypto = [
[[package]] [[package]]
name = "pyro-mcp" name = "pyro-mcp"
version = "1.0.0" version = "2.0.0"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "mcp" }, { name = "mcp" },