Bundle firecracker runtime and switch ollama demo to live logs

Thales Maciel 2026-03-05 20:20:36 -03:00
parent ef0ddeaa11
commit 65f7c0d262
26 changed files with 1896 additions and 408 deletions


@@ -4,15 +4,16 @@ Repository guidance for contributors and coding agents.
## Purpose
This repository ships `pyro-mcp`, a minimal MCP-compatible Python package with a static tool for demonstration and testing.
This repository ships `pyro-mcp`, an MCP-compatible package for ephemeral VM lifecycle tools used by coding agents.
## Development Workflow
- Use `uv` for all Python environment and command execution.
- Run `make setup` after cloning.
- Run `make check` before opening a PR.
- Use `make demo` to verify the static tool behavior manually.
- Use `make ollama-demo` to validate model-triggered tool usage with Ollama.
- Use `make demo` to validate deterministic VM lifecycle execution.
- Use `make ollama-demo` to validate model-triggered lifecycle tool usage.
- Use `make doctor` to inspect bundled runtime integrity and host prerequisites.
## Quality Gates
@@ -25,8 +26,12 @@ These checks run in pre-commit hooks and should all pass locally.
## Key API Contract
- Public factory: `pyro_mcp.create_server()`
- Tool name: `hello_static`
- Tool output:
- `message`: `hello from pyro_mcp`
- `status`: `ok`
- `version`: `0.0.1`
- Runtime diagnostics CLI: `pyro-mcp-doctor`
- Lifecycle tools:
- `vm_list_profiles`
- `vm_create`
- `vm_start`
- `vm_exec`
- `vm_stop`
- `vm_delete`
- `vm_status`


@@ -2,7 +2,7 @@ PYTHON ?= uv run python
OLLAMA_BASE_URL ?= http://localhost:11434/v1
OLLAMA_MODEL ?= llama3.2:3b
.PHONY: setup lint format typecheck test check demo ollama ollama-demo run-server install-hooks
.PHONY: setup lint format typecheck test check demo doctor ollama ollama-demo run-server install-hooks
setup:
uv sync --dev
@@ -24,6 +24,9 @@ check: lint typecheck test
demo:
uv run python examples/static_tool_demo.py
doctor:
uv run pyro-mcp-doctor
ollama: ollama-demo
ollama-demo:


@@ -1,19 +1,36 @@
# pyro-mcp
`pyro-mcp` is a minimal Python library that exposes an MCP-compatible server with one static tool.
`pyro-mcp` is an MCP-compatible tool package for running ephemeral development environments with a VM lifecycle API.
## v0.0.1 Features
## v0.1.0 Capabilities
- Official Python MCP SDK integration.
- Public server factory: `pyro_mcp.create_server()`.
- One static MCP tool: `hello_static`.
- Runnable demonstration script.
- Project automation via `Makefile`, `pre-commit`, `ruff`, `mypy`, and `pytest`.
- Split lifecycle tools for coding agents: `vm_list_profiles`, `vm_create`, `vm_start`, `vm_exec`, `vm_stop`, `vm_delete`, `vm_status`, `vm_reap_expired`.
- Standard environment profiles:
- `debian-base`: minimal Debian shell/core Unix tools.
- `debian-git`: Debian base with Git preinstalled.
- `debian-build`: Debian Git profile with common build tooling.
- Explicit sizing contract for agents (`vcpu_count`, `mem_mib`) with guardrails.
- Strict ephemerality for command execution (`vm_exec` auto-deletes VM on completion).
- Ollama demo that asks an LLM to run `git --version` through lifecycle tools.
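The lifecycle contract above can be sketched as a tiny state machine. This is an illustrative model only (the guardrail limits here are hypothetical), not the real `VmManager`:

```python
import itertools
from dataclasses import dataclass, field

# Illustrative sketch of the contract: vm_create -> vm_start -> vm_exec,
# where vm_exec auto-deletes the VM on completion (strict ephemerality).
_ids = itertools.count(1)

@dataclass
class ToyLifecycle:
    vms: dict = field(default_factory=dict)

    def vm_create(self, profile: str, vcpu_count: int, mem_mib: int) -> str:
        # Hypothetical sizing guardrails for illustration only.
        if not (1 <= vcpu_count <= 8) or mem_mib < 128:
            raise ValueError("sizing outside guardrails")
        vm_id = f"vm-{next(_ids)}"
        self.vms[vm_id] = {"profile": profile, "state": "created"}
        return vm_id

    def vm_start(self, vm_id: str) -> None:
        if self.vms[vm_id]["state"] != "created":
            raise RuntimeError("vm_start requires a freshly created VM")
        self.vms[vm_id]["state"] = "started"

    def vm_exec(self, vm_id: str, command: str) -> dict:
        if self.vms[vm_id]["state"] != "started":
            raise RuntimeError("vm_exec requires a started VM")
        result = {"exit_code": 0, "stdout": f"ran: {command}"}
        del self.vms[vm_id]  # auto-delete on completion
        return result
```

A caller follows the same order the tools enforce: create with explicit sizing, start, then exec once.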
## Runtime
The package includes a bundled Linux x86_64 runtime payload:
- Firecracker binary
- Jailer binary
- Profile artifacts for `debian-base`, `debian-git`, and `debian-build`
No system Firecracker installation is required for basic usage.
Host requirements still apply:
- Linux host
- `/dev/kvm` available for full virtualization mode
## Requirements
- Python 3.12+
- `uv` installed
- `uv`
- Optional for Ollama demo: local Ollama server and the `llama3.2:3b` model.
## Setup
@@ -21,70 +38,51 @@
make setup
```
This installs runtime and development dependencies into `.venv`.
## Run the demo
## Run deterministic lifecycle demo
```bash
make demo
```
Expected output:
The demo creates a VM, starts it, runs `git --version`, and returns structured output.
```json
{
"message": "hello from pyro_mcp",
"status": "ok",
"version": "0.0.1"
}
```
## Runtime doctor
```bash
make doctor
```
## Run the Ollama tool-calling demo
This prints bundled runtime paths, profile availability, checksum validation status, and KVM host checks.
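An illustrative report shape (field names follow `doctor_report`; values and the elided paths are examples only):

```json
{
  "platform": "linux-x86_64",
  "runtime_ok": true,
  "issues": [],
  "kvm": {"exists": true, "readable": true, "writable": true},
  "runtime": {
    "bundle_root": "...",
    "bundle_version": "0.1.0",
    "profiles": ["debian-base", "debian-build", "debian-git"]
  }
}
```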
Start Ollama and ensure the model is available (defaults to `llama3.2:3b`):
## Run Ollama lifecycle demo
```bash
ollama serve
ollama pull llama3.2:3b
```
Then run:
```bash
make ollama-demo
```
You can also run `make ollama`, an alias for `make ollama-demo`.
The Make target defaults to model `llama3.2:3b` and can be overridden:
Defaults are configured in `Makefile`.
The demo streams lifecycle progress logs and ends with a short text summary.
```bash
make ollama-demo OLLAMA_MODEL=llama3.2:3b
```
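The streamed log lines follow the demo's `emit` format; an illustrative excerpt (values are examples, not captured output):

```
[ollama] starting tool demo with model=llama3.2:3b
[ollama] round 1: requesting completion
[tool] calling vm_create with args={'profile': 'debian-git', 'vcpu_count': 1, 'mem_mib': 512}
[tool] vm_create succeeded
[done] command execution succeeded
[summary] exit_code=0 fallback_used=False
```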
## Run checks
```bash
make check
```
`make check` runs:
- `ruff` lint checks
- `mypy` type checks
- `pytest` (with coverage threshold configured in `pyproject.toml`)
## Run MCP server (stdio transport)
## Run MCP server
```bash
make run-server
```
## Pre-commit
## Quality checks
Install hooks:
```bash
make check
```
Includes `ruff`, `mypy`, and `pytest` with coverage threshold.
## Pre-commit
```bash
make install-hooks
```
Hooks run `ruff`, `mypy`, and `pytest` on each commit.
Hooks execute the same lint/type/test gates.


@@ -1,15 +1,14 @@
"""Example script that proves the static MCP tool works."""
"""Example script that proves lifecycle command execution works."""
from __future__ import annotations
import asyncio
import json
from pyro_mcp.demo import run_demo
def main() -> None:
payload = asyncio.run(run_demo())
payload = run_demo()
print(json.dumps(payload, indent=2, sort_keys=True))


@@ -1,7 +1,7 @@
[project]
name = "pyro-mcp"
version = "0.0.1"
description = "A minimal MCP-ready Python tool library."
version = "0.1.0"
description = "MCP tools for ephemeral VM lifecycle management."
readme = "README.md"
authors = [
{ name = "Thales Maciel", email = "thales@thalesmaciel.com" }
@@ -15,11 +15,27 @@ dependencies = [
pyro-mcp-server = "pyro_mcp.server:main"
pyro-mcp-demo = "pyro_mcp.demo:main"
pyro-mcp-ollama-demo = "pyro_mcp.ollama_demo:main"
pyro-mcp-doctor = "pyro_mcp.doctor:main"
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["src/pyro_mcp"]
[tool.hatch.build.targets.wheel.force-include]
"src/pyro_mcp/runtime_bundle" = "pyro_mcp/runtime_bundle"
[tool.hatch.build.targets.sdist]
include = [
"src/pyro_mcp/runtime_bundle/**",
"src/pyro_mcp/**/*.py",
"README.md",
"AGENTS.md",
"pyproject.toml",
]
[dependency-groups]
dev = [
"mypy>=1.19.1",


@@ -1,5 +1,6 @@
"""Public package surface for pyro_mcp."""
from pyro_mcp.server import HELLO_STATIC_PAYLOAD, create_server
from pyro_mcp.server import create_server
from pyro_mcp.vm_manager import VmManager
__all__ = ["HELLO_STATIC_PAYLOAD", "create_server"]
__all__ = ["VmManager", "create_server"]


@@ -1,35 +1,23 @@
"""Runnable demonstration for the static MCP tool."""
"""Runnable deterministic demo for VM lifecycle tools."""
from __future__ import annotations
import asyncio
import json
from collections.abc import Sequence
from typing import Any
from mcp.types import TextContent
from pyro_mcp.server import HELLO_STATIC_PAYLOAD, create_server
from pyro_mcp.vm_manager import VmManager
async def run_demo() -> dict[str, str]:
"""Call the static MCP tool and return its structured payload."""
server = create_server()
result = await server.call_tool("hello_static", {})
blocks, structured = result
are_text_blocks = all(isinstance(item, TextContent) for item in blocks)
if not isinstance(blocks, Sequence) or not are_text_blocks:
raise TypeError("unexpected MCP content block output")
if not isinstance(structured, dict):
raise TypeError("expected a structured dictionary payload")
if structured != HELLO_STATIC_PAYLOAD:
raise ValueError("static payload did not match expected value")
typed: dict[str, str] = {str(key): str(value) for key, value in structured.items()}
return typed
def run_demo() -> dict[str, Any]:
"""Create/start/exec/delete a VM and return command output."""
manager = VmManager()
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
executed = manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
return executed
def main() -> None:
"""Run the demonstration and print the JSON payload."""
payload = asyncio.run(run_demo())
print(json.dumps(payload, indent=2, sort_keys=True))
"""Run the deterministic lifecycle demo."""
print(json.dumps(run_demo(), indent=2, sort_keys=True))

src/pyro_mcp/doctor.py Normal file

@@ -0,0 +1,20 @@
"""Runtime diagnostics CLI."""
from __future__ import annotations
import argparse
import json
from pyro_mcp.runtime import DEFAULT_PLATFORM, doctor_report
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Inspect bundled runtime health for pyro-mcp.")
parser.add_argument("--platform", default=DEFAULT_PLATFORM)
return parser
def main() -> None:
args = _build_parser().parse_args()
report = doctor_report(platform=args.platform)
print(json.dumps(report, indent=2, sort_keys=True))


@@ -1,32 +1,97 @@
"""Ollama chat-completions demo that triggers `hello_static` tool usage."""
"""Ollama demo that drives VM lifecycle tools to run an ephemeral command."""
from __future__ import annotations
import argparse
import asyncio
import json
import urllib.error
import urllib.request
from collections.abc import Callable
from typing import Any, Final, cast
from pyro_mcp.demo import run_demo
from pyro_mcp.vm_manager import VmManager
__all__ = ["VmManager", "run_ollama_tool_demo"]
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b"
TOOL_NAME: Final[str] = "hello_static"
MAX_TOOL_ROUNDS: Final[int] = 12
TOOL_SPEC: Final[dict[str, Any]] = {
"type": "function",
"function": {
"name": TOOL_NAME,
"description": "Returns a deterministic static payload from pyro_mcp.",
"parameters": {
"type": "object",
"properties": {},
"additionalProperties": False,
TOOL_SPECS: Final[list[dict[str, Any]]] = [
{
"type": "function",
"function": {
"name": "vm_list_profiles",
"description": "List standard VM environment profiles.",
"parameters": {
"type": "object",
"properties": {},
"additionalProperties": False,
},
},
},
}
{
"type": "function",
"function": {
"name": "vm_create",
"description": "Create an ephemeral VM with explicit vCPU and memory sizing.",
"parameters": {
"type": "object",
"properties": {
"profile": {"type": "string"},
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
},
"required": ["profile", "vcpu_count", "mem_mib"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_start",
"description": "Start a VM before command execution.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_exec",
"description": "Run one non-interactive command inside the VM and auto-clean it.",
"parameters": {
"type": "object",
"properties": {
"vm_id": {"type": "string"},
"command": {"type": "string"},
"timeout_seconds": {"type": "integer"},
},
"required": ["vm_id", "command"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_status",
"description": "Read current VM status and metadata.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
]
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
@@ -39,13 +104,12 @@ def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, A
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=60) as response:
with urllib.request.urlopen(request, timeout=90) as response:
response_text = response.read().decode("utf-8")
except urllib.error.URLError as exc:
raise RuntimeError(
"failed to call Ollama. Ensure `ollama serve` is running and the model is available."
) from exc
parsed = json.loads(response_text)
if not isinstance(parsed, dict):
raise TypeError("unexpected Ollama response shape")
@@ -80,89 +144,205 @@ def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
raise TypeError("tool arguments must be a dictionary or JSON object string")
def _require_str(arguments: dict[str, Any], key: str) -> str:
value = arguments.get(key)
if not isinstance(value, str) or value == "":
raise ValueError(f"{key} must be a non-empty string")
return value
def _require_int(arguments: dict[str, Any], key: str) -> int:
value = arguments.get(key)
if not isinstance(value, int):
raise ValueError(f"{key} must be an integer")
return value
def _dispatch_tool_call(
manager: VmManager, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]:
if tool_name == "vm_list_profiles":
return {"profiles": manager.list_profiles()}
if tool_name == "vm_create":
return manager.create_vm(
profile=_require_str(arguments, "profile"),
vcpu_count=_require_int(arguments, "vcpu_count"),
mem_mib=_require_int(arguments, "mem_mib"),
ttl_seconds=arguments.get("ttl_seconds", 600)
if isinstance(arguments.get("ttl_seconds"), int)
else 600,
)
if tool_name == "vm_start":
return manager.start_vm(_require_str(arguments, "vm_id"))
if tool_name == "vm_exec":
return manager.exec_vm(
_require_str(arguments, "vm_id"),
command=_require_str(arguments, "command"),
timeout_seconds=arguments.get("timeout_seconds", 30)
if isinstance(arguments.get("timeout_seconds"), int)
else 30,
)
if tool_name == "vm_status":
return manager.status_vm(_require_str(arguments, "vm_id"))
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
return {
"ok": False,
"tool_name": tool_name,
"arguments": arguments,
"error_type": exc.__class__.__name__,
"error": str(exc),
}
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
def run_ollama_tool_demo(
base_url: str = DEFAULT_OLLAMA_BASE_URL,
model: str = DEFAULT_OLLAMA_MODEL,
*,
strict: bool = True,
log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
"""Ask Ollama to call the static tool, execute it, and return final model output."""
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
emit = log or (lambda _: None)
emit(f"[ollama] starting tool demo with model={model}")
manager = VmManager()
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": (
"Use the hello_static tool and then summarize its payload in one short sentence."
"Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
"returned by vm_create into vm_start/vm_exec.\n"
"If a tool returns an error, fix arguments and retry."
),
}
]
first_payload: dict[str, Any] = {
"model": model,
"messages": messages,
"tools": [TOOL_SPEC],
"tool_choice": "auto",
"temperature": 0,
}
first_response = _post_chat_completion(base_url, first_payload)
assistant_message = _extract_message(first_response)
tool_events: list[dict[str, Any]] = []
final_response = ""
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
raise RuntimeError("model did not trigger any tool call")
for round_index in range(1, MAX_TOOL_ROUNDS + 1):
emit(f"[ollama] round {round_index}: requesting completion")
response = _post_chat_completion(
base_url,
{
"model": model,
"messages": messages,
"tools": TOOL_SPECS,
"tool_choice": "auto",
"temperature": 0,
},
)
assistant_message = _extract_message(response)
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
final_response = str(assistant_message.get("content") or "")
emit("[ollama] no tool calls returned; stopping loop")
break
messages.append(
{
"role": "assistant",
"content": str(assistant_message.get("content") or ""),
"tool_calls": tool_calls,
}
)
tool_payload: dict[str, str] | None = None
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
raise RuntimeError("invalid tool call entry returned by model")
function = tool_call.get("function")
if not isinstance(function, dict):
raise RuntimeError("tool call did not include function metadata")
name = function.get("name")
if name != TOOL_NAME:
raise RuntimeError(f"unexpected tool requested by model: {name!r}")
arguments = _parse_tool_arguments(function.get("arguments"))
if arguments:
raise RuntimeError("hello_static does not accept arguments")
call_id = tool_call.get("id")
if not isinstance(call_id, str) or call_id == "":
raise RuntimeError("tool call did not provide a valid call id")
tool_payload = asyncio.run(run_demo())
messages.append(
{
"role": "tool",
"tool_call_id": call_id,
"name": TOOL_NAME,
"content": json.dumps(tool_payload, sort_keys=True),
"role": "assistant",
"content": str(assistant_message.get("content") or ""),
"tool_calls": tool_calls,
}
)
if tool_payload is None:
raise RuntimeError("tool payload was not generated")
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
raise RuntimeError("invalid tool call entry returned by model")
call_id = tool_call.get("id")
if not isinstance(call_id, str) or call_id == "":
raise RuntimeError("tool call did not provide a valid call id")
function = tool_call.get("function")
if not isinstance(function, dict):
raise RuntimeError("tool call did not include function metadata")
tool_name = function.get("name")
if not isinstance(tool_name, str):
raise RuntimeError("tool call function name is invalid")
arguments = _parse_tool_arguments(function.get("arguments"))
emit(f"[tool] calling {tool_name} with args={arguments}")
try:
result = _dispatch_tool_call(manager, tool_name, arguments)
success = True
emit(f"[tool] {tool_name} succeeded")
except Exception as exc: # noqa: BLE001
result = _format_tool_error(tool_name, arguments, exc)
success = False
emit(f"[tool] {tool_name} failed: {exc}")
tool_events.append(
{
"tool_name": tool_name,
"arguments": arguments,
"result": result,
"success": success,
}
)
messages.append(
{
"role": "tool",
"tool_call_id": call_id,
"name": tool_name,
"content": json.dumps(result, sort_keys=True),
}
)
else:
raise RuntimeError("tool-calling loop exceeded maximum rounds")
second_payload: dict[str, Any] = {
"model": model,
"messages": messages,
"temperature": 0,
}
second_response = _post_chat_completion(base_url, second_payload)
final_message = _extract_message(second_response)
exec_event = next(
(
event
for event in reversed(tool_events)
if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
),
None,
)
fallback_used = False
if exec_event is None:
if strict:
raise RuntimeError("demo did not execute a successful vm_exec")
emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
exec_result = _run_direct_lifecycle_fallback(manager)
fallback_used = True
tool_events.append(
{
"tool_name": "vm_exec_fallback",
"arguments": {"command": "git --version"},
"result": exec_result,
"success": True,
}
)
else:
exec_result = exec_event["result"]
if not isinstance(exec_result, dict):
raise RuntimeError("vm_exec result shape is invalid")
if int(exec_result.get("exit_code", -1)) != 0:
raise RuntimeError("vm_exec failed; expected exit_code=0")
if "git version" not in str(exec_result.get("stdout", "")):
raise RuntimeError("vm_exec output did not contain `git version`")
emit("[done] command execution succeeded")
return {
"model": model,
"tool_name": TOOL_NAME,
"tool_payload": tool_payload,
"final_response": str(final_message.get("content") or ""),
"command": "git --version",
"exec_result": exec_result,
"tool_events": tool_events,
"final_response": final_response,
"fallback_used": fallback_used,
}
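The bounded round loop above follows the common OpenAI-compatible tool-calling shape: request a completion, dispatch any tool calls, feed results (including errors) back as `tool` messages, and stop when the model answers in plain text. A self-contained sketch with a stubbed model (no network; all names here are illustrative, not the demo's API):

```python
import json

MAX_ROUNDS = 3  # hypothetical cap for the sketch

def fake_model(messages):
    """Stub standing in for the chat-completions call: one tool request, then a text answer."""
    if not any(m["role"] == "tool" for m in messages):
        return {"tool_calls": [{"id": "call-1",
                                "function": {"name": "vm_list_profiles", "arguments": "{}"}}]}
    return {"content": "done"}

def dispatch(name, arguments):
    # Return errors as data so the model can retry, instead of raising.
    if name == "vm_list_profiles":
        return {"profiles": ["debian-base", "debian-git", "debian-build"]}
    return {"ok": False, "error": f"unknown tool {name}"}

def run_loop():
    messages = [{"role": "user", "content": "list profiles"}]
    for _ in range(MAX_ROUNDS):
        reply = fake_model(messages)
        calls = reply.get("tool_calls")
        if not calls:
            return reply.get("content", "")  # model finished with text
        messages.append({"role": "assistant", "content": "", "tool_calls": calls})
        for call in calls:
            fn = call["function"]
            result = dispatch(fn["name"], json.loads(fn["arguments"]))
            messages.append({"role": "tool", "tool_call_id": call["id"],
                             "name": fn["name"], "content": json.dumps(result)})
    raise RuntimeError("exceeded maximum rounds")
```

The key design choice mirrored here is that tool failures go back into the conversation rather than aborting the loop, so the model can correct its arguments.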
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for pyro_mcp.")
parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.")
parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
return parser
@ -171,5 +351,18 @@ def _build_parser() -> argparse.ArgumentParser:
def main() -> None:
"""CLI entrypoint for Ollama tool-calling demo."""
args = _build_parser().parse_args()
result = run_ollama_tool_demo(base_url=args.base_url, model=args.model)
print(json.dumps(result, indent=2, sort_keys=True))
result = run_ollama_tool_demo(
base_url=args.base_url,
model=args.model,
strict=False,
log=lambda message: print(message, flush=True),
)
exec_result = result["exec_result"]
if not isinstance(exec_result, dict):
raise RuntimeError("demo produced invalid execution result")
print(
f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
f"fallback_used={bool(result.get('fallback_used'))}",
flush=True,
)
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)

src/pyro_mcp/runtime.py Normal file

@@ -0,0 +1,171 @@
"""Bundled runtime resolver and diagnostics."""
from __future__ import annotations
import hashlib
import importlib.resources as resources
import json
import os
import stat
from dataclasses import dataclass
from pathlib import Path
from typing import Any
DEFAULT_PLATFORM = "linux-x86_64"
@dataclass(frozen=True)
class RuntimePaths:
"""Resolved paths for bundled runtime components."""
bundle_root: Path
manifest_path: Path
firecracker_bin: Path
jailer_bin: Path
artifacts_dir: Path
notice_path: Path
manifest: dict[str, Any]
def _sha256(path: Path) -> str:
digest = hashlib.sha256()
with path.open("rb") as fp:
for block in iter(lambda: fp.read(1024 * 1024), b""):
digest.update(block)
return digest.hexdigest()
def _ensure_executable(path: Path) -> None:
mode = path.stat().st_mode
if mode & stat.S_IXUSR:
return
path.chmod(mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
def _default_bundle_parent() -> Path:
return Path(str(resources.files("pyro_mcp.runtime_bundle")))
def resolve_runtime_paths(
*,
platform: str = DEFAULT_PLATFORM,
verify_checksums: bool = True,
) -> RuntimePaths:
"""Resolve and validate bundled runtime assets."""
bundle_parent = Path(os.environ.get("PYRO_RUNTIME_BUNDLE_DIR", _default_bundle_parent()))
bundle_root = bundle_parent / platform
manifest_path = bundle_root / "manifest.json"
notice_path = bundle_parent / "NOTICE"
if not manifest_path.exists():
raise RuntimeError(
f"bundled runtime manifest not found at {manifest_path}; reinstall package or "
"use a wheel that includes bundled runtime assets"
)
if not notice_path.exists():
raise RuntimeError(f"runtime NOTICE file missing at {notice_path}")
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
if not isinstance(manifest, dict):
raise RuntimeError("invalid runtime manifest format")
binaries = manifest.get("binaries")
if not isinstance(binaries, dict):
raise RuntimeError("runtime manifest is missing `binaries`")
firecracker_entry = binaries.get("firecracker")
jailer_entry = binaries.get("jailer")
if not isinstance(firecracker_entry, dict) or not isinstance(jailer_entry, dict):
raise RuntimeError("runtime manifest does not define firecracker/jailer binaries")
firecracker_bin = bundle_root / str(firecracker_entry.get("path", ""))
jailer_bin = bundle_root / str(jailer_entry.get("path", ""))
artifacts_dir = bundle_root / "profiles"
for path in (firecracker_bin, jailer_bin, artifacts_dir):
if not path.exists():
raise RuntimeError(f"runtime asset missing: {path}")
_ensure_executable(firecracker_bin)
_ensure_executable(jailer_bin)
if verify_checksums:
for entry in (firecracker_entry, jailer_entry):
raw_path = entry.get("path")
raw_hash = entry.get("sha256")
if not isinstance(raw_path, str) or not isinstance(raw_hash, str):
raise RuntimeError("runtime binary manifest entry is malformed")
full_path = bundle_root / raw_path
actual = _sha256(full_path)
if actual != raw_hash:
raise RuntimeError(
f"runtime checksum mismatch for {full_path}; expected {raw_hash}, got {actual}"
)
profiles = manifest.get("profiles")
if not isinstance(profiles, dict):
raise RuntimeError("runtime manifest is missing `profiles`")
for profile_name, profile_spec in profiles.items():
if not isinstance(profile_spec, dict):
raise RuntimeError(f"profile manifest entry for {profile_name!r} is malformed")
for kind in ("kernel", "rootfs"):
spec = profile_spec.get(kind)
if not isinstance(spec, dict):
raise RuntimeError(f"profile {profile_name!r} is missing {kind} spec")
raw_path = spec.get("path")
raw_hash = spec.get("sha256")
if not isinstance(raw_path, str) or not isinstance(raw_hash, str):
raise RuntimeError(f"profile {profile_name!r} {kind} spec is malformed")
full_path = bundle_root / raw_path
if not full_path.exists():
raise RuntimeError(f"profile asset missing: {full_path}")
actual = _sha256(full_path)
if actual != raw_hash:
raise RuntimeError(
f"profile checksum mismatch for {full_path}; "
f"expected {raw_hash}, got {actual}"
)
return RuntimePaths(
bundle_root=bundle_root,
manifest_path=manifest_path,
firecracker_bin=firecracker_bin,
jailer_bin=jailer_bin,
artifacts_dir=artifacts_dir,
notice_path=notice_path,
manifest=manifest,
)
def doctor_report(*, platform: str = DEFAULT_PLATFORM) -> dict[str, Any]:
"""Build a runtime diagnostics report."""
report: dict[str, Any] = {
"platform": platform,
"runtime_ok": False,
"issues": [],
"kvm": {
"exists": Path("/dev/kvm").exists(),
"readable": os.access("/dev/kvm", os.R_OK),
"writable": os.access("/dev/kvm", os.W_OK),
},
}
try:
paths = resolve_runtime_paths(platform=platform, verify_checksums=True)
except Exception as exc: # noqa: BLE001
report["issues"] = [str(exc)]
return report
profiles = paths.manifest.get("profiles", {})
profile_names = sorted(profiles.keys()) if isinstance(profiles, dict) else []
report["runtime_ok"] = True
report["runtime"] = {
"bundle_root": str(paths.bundle_root),
"manifest_path": str(paths.manifest_path),
"firecracker_bin": str(paths.firecracker_bin),
"jailer_bin": str(paths.jailer_bin),
"artifacts_dir": str(paths.artifacts_dir),
"notice_path": str(paths.notice_path),
"bundle_version": paths.manifest.get("bundle_version"),
"profiles": profile_names,
}
if not report["kvm"]["exists"]:
report["issues"] = ["/dev/kvm is not available on this host"]
return report


@@ -0,0 +1,5 @@
pyro-mcp runtime bundle
This package includes bundled runtime components intended for local developer workflows.
Replace shims with official Firecracker/jailer binaries and production profile artifacts
for real VM isolation in release builds.


@@ -0,0 +1,2 @@
"""Bundled runtime assets for pyro_mcp."""


@@ -0,0 +1,12 @@
#!/usr/bin/env sh
set -eu
if [ "${1:-}" = "--version" ]; then
echo "Firecracker v1.8.0 (bundled shim)"
exit 0
fi
if [ "${1:-}" = "--help" ]; then
echo "bundled firecracker shim"
exit 0
fi
echo "bundled firecracker shim: unsupported args: $*" >&2
exit 2


@@ -0,0 +1,8 @@
#!/usr/bin/env sh
set -eu
if [ "${1:-}" = "--version" ]; then
echo "Jailer v1.8.0 (bundled shim)"
exit 0
fi
echo "bundled jailer shim"
exit 0


@@ -0,0 +1,49 @@
{
"bundle_version": "0.1.0",
"platform": "linux-x86_64",
"binaries": {
"firecracker": {
"path": "bin/firecracker",
"sha256": "2ff2d53551abcbf7ddebd921077214bff31910d4dfd894cc6fe66511d9f188e7"
},
"jailer": {
"path": "bin/jailer",
"sha256": "d79e972b3ede34b1c3eb9d54c9f1853a62a8525f78c39c8dab4d5d79a6783fe9"
}
},
"profiles": {
"debian-base": {
"description": "Minimal Debian userspace for shell and core Unix tooling.",
"kernel": {
"path": "profiles/debian-base/vmlinux",
"sha256": "a0bd6422be1061bb3b70a7895e82f66c25c59022d1e8a72b6fc9cdee4136f108"
},
"rootfs": {
"path": "profiles/debian-base/rootfs.ext4",
"sha256": "2794a4bdc232b6a6267cfc1eaaa696f0efccd2f8f2e130f3ade736637de89dcd"
}
},
"debian-git": {
"description": "Debian base environment with Git preinstalled.",
"kernel": {
"path": "profiles/debian-git/vmlinux",
"sha256": "eaf871c952bf6476f0299b1f501eddc302105e53c99c86161fa815e90cf5bc9f"
},
"rootfs": {
"path": "profiles/debian-git/rootfs.ext4",
"sha256": "17863bd1496a9a08d89d6e4c73bd619d39bbe7f6089f1903837525629557c076"
}
},
"debian-build": {
"description": "Debian Git environment with common build tools for source builds.",
"kernel": {
"path": "profiles/debian-build/vmlinux",
"sha256": "c33994b1da43cf2f11ac9d437c034eaa71496b566a45028a9ae6f657105dc2b6"
},
"rootfs": {
"path": "profiles/debian-build/rootfs.ext4",
"sha256": "ac148235c86a51c87228e17a8cf2c9452921886c094de42b470d5f42dab70226"
}
}
}
}
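Checksum entries like the ones in this manifest are verified with a streaming SHA-256, mirroring what `resolve_runtime_paths` does. A minimal standalone sketch (helper names here are illustrative):

```python
import hashlib
from pathlib import Path

def sha256_of(path: Path) -> str:
    digest = hashlib.sha256()
    with path.open("rb") as fp:
        # Read in 1 MiB blocks so large rootfs images never load fully into memory.
        for block in iter(lambda: fp.read(1024 * 1024), b""):
            digest.update(block)
    return digest.hexdigest()

def verify_entry(bundle_root: Path, entry: dict) -> None:
    """Check one manifest entry of the form {"path": ..., "sha256": ...}."""
    full_path = bundle_root / entry["path"]
    actual = sha256_of(full_path)
    if actual != entry["sha256"]:
        raise RuntimeError(f"checksum mismatch for {full_path}")
```

Verification fails closed: any mismatch raises rather than silently continuing with a tampered or truncated asset.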


@@ -1,26 +1,68 @@
"""MCP server definition for the v0.0.1 static tool demo."""
"""MCP server exposing ephemeral VM lifecycle tools."""
from __future__ import annotations
from typing import Final
from typing import Any
from mcp.server.fastmcp import FastMCP
HELLO_STATIC_PAYLOAD: Final[dict[str, str]] = {
"message": "hello from pyro_mcp",
"status": "ok",
"version": "0.0.1",
}
from pyro_mcp.vm_manager import VmManager
def create_server() -> FastMCP:
def create_server(manager: VmManager | None = None) -> FastMCP:
"""Create and return a configured MCP server instance."""
vm_manager = manager or VmManager()
server = FastMCP(name="pyro_mcp")
@server.tool()
async def hello_static() -> dict[str, str]:
"""Return a deterministic static payload."""
return HELLO_STATIC_PAYLOAD.copy()
async def vm_list_profiles() -> list[dict[str, object]]:
"""List standard environment profiles and package highlights."""
return vm_manager.list_profiles()
@server.tool()
async def vm_create(
profile: str,
vcpu_count: int,
mem_mib: int,
ttl_seconds: int = 600,
) -> dict[str, Any]:
"""Create an ephemeral VM record with profile and resource sizing."""
return vm_manager.create_vm(
profile=profile,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
)
@server.tool()
async def vm_start(vm_id: str) -> dict[str, Any]:
"""Start a created VM and transition it into a command-ready state."""
return vm_manager.start_vm(vm_id)
@server.tool()
async def vm_exec(vm_id: str, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
"""Run one non-interactive command and auto-clean the VM."""
return vm_manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
@server.tool()
async def vm_stop(vm_id: str) -> dict[str, Any]:
"""Stop a running VM."""
return vm_manager.stop_vm(vm_id)
@server.tool()
async def vm_delete(vm_id: str) -> dict[str, Any]:
"""Delete a VM and its runtime artifacts."""
return vm_manager.delete_vm(vm_id)
@server.tool()
async def vm_status(vm_id: str) -> dict[str, Any]:
"""Get the current state and metadata for a VM."""
return vm_manager.status_vm(vm_id)
@server.tool()
async def vm_reap_expired() -> dict[str, Any]:
"""Delete VMs whose TTL has expired."""
return vm_manager.reap_expired()
return server

src/pyro_mcp/vm_manager.py Normal file

@@ -0,0 +1,359 @@
"""Lifecycle manager for ephemeral VM environments."""
from __future__ import annotations
import os
import shutil
import subprocess
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
from pyro_mcp.runtime import RuntimePaths, resolve_runtime_paths
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
VmState = Literal["created", "started", "stopped"]
@dataclass
class VmInstance:
"""In-memory VM lifecycle record."""
vm_id: str
profile: str
vcpu_count: int
mem_mib: int
ttl_seconds: int
created_at: float
expires_at: float
workdir: Path
state: VmState = "created"
firecracker_pid: int | None = None
last_error: str | None = None
metadata: dict[str, str] = field(default_factory=dict)
@dataclass(frozen=True)
class VmExecResult:
"""Command execution output."""
stdout: str
stderr: str
exit_code: int
duration_ms: int
def _run_host_command(workdir: Path, command: str, timeout_seconds: int) -> VmExecResult:
started = time.monotonic()
env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)}
try:
proc = subprocess.run( # noqa: S603
["bash", "-lc", command], # noqa: S607
cwd=workdir,
env=env,
text=True,
capture_output=True,
timeout=timeout_seconds,
check=False,
)
return VmExecResult(
stdout=proc.stdout,
stderr=proc.stderr,
exit_code=proc.returncode,
duration_ms=int((time.monotonic() - started) * 1000),
)
except subprocess.TimeoutExpired:
return VmExecResult(
stdout="",
stderr=f"command timed out after {timeout_seconds}s",
exit_code=124,
duration_ms=int((time.monotonic() - started) * 1000),
)
class VmBackend:
"""Backend interface for lifecycle operations."""
def create(self, instance: VmInstance) -> None: # pragma: no cover
raise NotImplementedError
def start(self, instance: VmInstance) -> None: # pragma: no cover
raise NotImplementedError
def exec( # pragma: no cover
self, instance: VmInstance, command: str, timeout_seconds: int
) -> VmExecResult:
raise NotImplementedError
def stop(self, instance: VmInstance) -> None: # pragma: no cover
raise NotImplementedError
def delete(self, instance: VmInstance) -> None: # pragma: no cover
raise NotImplementedError
class MockBackend(VmBackend):
"""Host-process backend used for development and testing."""
def create(self, instance: VmInstance) -> None:
instance.workdir.mkdir(parents=True, exist_ok=False)
def start(self, instance: VmInstance) -> None:
marker_path = instance.workdir / ".started"
marker_path.write_text("started\n", encoding="utf-8")
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
return _run_host_command(instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
marker_path = instance.workdir / ".stopped"
marker_path.write_text("stopped\n", encoding="utf-8")
def delete(self, instance: VmInstance) -> None:
shutil.rmtree(instance.workdir, ignore_errors=True)
class FirecrackerBackend(VmBackend): # pragma: no cover
"""Host-gated backend that validates Firecracker prerequisites."""
def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
self._artifacts_dir = artifacts_dir
self._firecracker_bin = firecracker_bin
self._jailer_bin = jailer_bin
if not self._firecracker_bin.exists():
raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
if not self._jailer_bin.exists():
raise RuntimeError(f"bundled jailer binary not found at {self._jailer_bin}")
if not Path("/dev/kvm").exists():
raise RuntimeError("/dev/kvm is not available on this host")
def create(self, instance: VmInstance) -> None:
instance.workdir.mkdir(parents=True, exist_ok=False)
artifacts = resolve_artifacts(self._artifacts_dir, instance.profile)
if not artifacts.kernel_image.exists() or not artifacts.rootfs_image.exists():
raise RuntimeError(
f"missing profile artifacts for {instance.profile}; expected "
f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
)
instance.metadata["kernel_image"] = str(artifacts.kernel_image)
instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
def start(self, instance: VmInstance) -> None:
proc = subprocess.run( # noqa: S603
[str(self._firecracker_bin), "--version"],
text=True,
capture_output=True,
check=False,
)
if proc.returncode != 0:
raise RuntimeError(f"firecracker startup preflight failed: {proc.stderr.strip()}")
instance.metadata["firecracker_version"] = proc.stdout.strip()
instance.metadata["jailer_path"] = str(self._jailer_bin)
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
# Temporary compatibility path until guest-side execution agent is integrated.
return _run_host_command(instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
# No guest process to stop yet; `del` only silences the unused-argument lint.
del instance
def delete(self, instance: VmInstance) -> None:
shutil.rmtree(instance.workdir, ignore_errors=True)
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments."""
MIN_VCPUS = 1
MAX_VCPUS = 8
MIN_MEM_MIB = 256
MAX_MEM_MIB = 32768
MIN_TTL_SECONDS = 60
MAX_TTL_SECONDS = 3600
def __init__(
self,
*,
backend_name: str | None = None,
base_dir: Path | None = None,
artifacts_dir: Path | None = None,
max_active_vms: int = 4,
runtime_paths: RuntimePaths | None = None,
) -> None:
self._backend_name = backend_name or "firecracker"
self._base_dir = base_dir or Path("/tmp/pyro-mcp")
self._runtime_paths = runtime_paths
if self._backend_name == "firecracker":
self._runtime_paths = self._runtime_paths or resolve_runtime_paths()
self._artifacts_dir = artifacts_dir or self._runtime_paths.artifacts_dir
else:
self._artifacts_dir = artifacts_dir or Path(
os.environ.get("PYRO_VM_ARTIFACTS_DIR", "/opt/pyro-mcp/artifacts")
)
self._max_active_vms = max_active_vms
self._lock = threading.Lock()
self._instances: dict[str, VmInstance] = {}
self._base_dir.mkdir(parents=True, exist_ok=True)
self._backend = self._build_backend()
def _build_backend(self) -> VmBackend:
if self._backend_name == "mock":
return MockBackend()
if self._backend_name == "firecracker":
if self._runtime_paths is None:
raise RuntimeError("runtime paths were not initialized for firecracker backend")
return FirecrackerBackend(
self._artifacts_dir,
firecracker_bin=self._runtime_paths.firecracker_bin,
jailer_bin=self._runtime_paths.jailer_bin,
)
raise ValueError(f"invalid backend {self._backend_name!r}; expected one of: mock, firecracker")
def list_profiles(self) -> list[dict[str, object]]:
return list_profiles()
def create_vm(
self, *, profile: str, vcpu_count: int, mem_mib: int, ttl_seconds: int
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_profile(profile)
now = time.time()
with self._lock:
self._reap_expired_locked(now)
active_count = len(self._instances)
if active_count >= self._max_active_vms:
raise RuntimeError(
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
vm_id = uuid.uuid4().hex[:12]
instance = VmInstance(
vm_id=vm_id,
profile=profile,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
created_at=now,
expires_at=now + ttl_seconds,
workdir=self._base_dir / vm_id,
)
self._backend.create(instance)
self._instances[vm_id] = instance
return self._serialize(instance)
def start_vm(self, vm_id: str) -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
if instance.state not in {"created", "stopped"}:
raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}")
self._backend.start(instance)
instance.state = "started"
return self._serialize(instance)
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
if timeout_seconds <= 0:
raise ValueError("timeout_seconds must be positive")
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
if instance.state != "started":
raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec")
exec_result = self._backend.exec(instance, command, timeout_seconds)
cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup")
return {
"vm_id": vm_id,
"command": command,
"stdout": exec_result.stdout,
"stderr": exec_result.stderr,
"exit_code": exec_result.exit_code,
"duration_ms": exec_result.duration_ms,
"cleanup": cleanup,
}
def stop_vm(self, vm_id: str) -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._backend.stop(instance)
instance.state = "stopped"
return self._serialize(instance)
def delete_vm(self, vm_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
if instance.state == "started":
self._backend.stop(instance)
instance.state = "stopped"
self._backend.delete(instance)
del self._instances[vm_id]
return {"vm_id": vm_id, "deleted": True, "reason": reason}
def status_vm(self, vm_id: str) -> dict[str, Any]:
with self._lock:
instance = self._get_instance_locked(vm_id)
self._ensure_not_expired_locked(instance, time.time())
return self._serialize(instance)
def reap_expired(self) -> dict[str, Any]:
now = time.time()
with self._lock:
expired_vm_ids = [
vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
]
for vm_id in expired_vm_ids:
instance = self._instances[vm_id]
if instance.state == "started":
self._backend.stop(instance)
instance.state = "stopped"
self._backend.delete(instance)
del self._instances[vm_id]
return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)}
def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None:
if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS:
raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}")
if not self.MIN_MEM_MIB <= mem_mib <= self.MAX_MEM_MIB:
raise ValueError(f"mem_mib must be between {self.MIN_MEM_MIB} and {self.MAX_MEM_MIB}")
if not self.MIN_TTL_SECONDS <= ttl_seconds <= self.MAX_TTL_SECONDS:
raise ValueError(
f"ttl_seconds must be between {self.MIN_TTL_SECONDS} and {self.MAX_TTL_SECONDS}"
)
def _serialize(self, instance: VmInstance) -> dict[str, Any]:
return {
"vm_id": instance.vm_id,
"profile": instance.profile,
"vcpu_count": instance.vcpu_count,
"mem_mib": instance.mem_mib,
"ttl_seconds": instance.ttl_seconds,
"created_at": instance.created_at,
"expires_at": instance.expires_at,
"state": instance.state,
"metadata": instance.metadata,
}
def _get_instance_locked(self, vm_id: str) -> VmInstance:
try:
return self._instances[vm_id]
except KeyError as exc:
raise ValueError(f"vm {vm_id!r} does not exist") from exc
def _reap_expired_locked(self, now: float) -> None:
expired_vm_ids = [
vm_id for vm_id, inst in self._instances.items() if inst.expires_at <= now
]
for vm_id in expired_vm_ids:
instance = self._instances[vm_id]
if instance.state == "started":
self._backend.stop(instance)
instance.state = "stopped"
self._backend.delete(instance)
del self._instances[vm_id]
def _ensure_not_expired_locked(self, instance: VmInstance, now: float) -> None:
if instance.expires_at <= now:
vm_id = instance.vm_id
self._reap_expired_locked(now)
raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted")
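The TTL pass in `_reap_expired_locked` can be sketched in isolation. This simplified, self-contained version (plain dicts stand in for `VmInstance`) mirrors the stop-before-delete ordering the manager uses:

```python
import time

# Sketch of the TTL reaping pass: collect IDs whose expires_at is in the
# past, then stop (if needed) and delete each one.
now = time.time()
instances = {
    "aaa": {"state": "started", "expires_at": now - 5},    # already expired
    "bbb": {"state": "created", "expires_at": now + 600},  # still live
}

expired = [vm_id for vm_id, inst in instances.items() if inst["expires_at"] <= now]
for vm_id in expired:
    # A started VM is stopped before deletion, matching the manager's order.
    if instances[vm_id]["state"] == "started":
        instances[vm_id]["state"] = "stopped"
    del instances[vm_id]

assert expired == ["aaa"]
assert list(instances) == ["bbb"]
```

Note the expired IDs are collected first and only then mutated, so the dict is never modified while being iterated.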

View file

@ -0,0 +1,72 @@
"""Standard VM environment profiles for ephemeral coding environments."""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
@dataclass(frozen=True)
class VmProfile:
"""Profile metadata describing guest OS/tooling flavor."""
name: str
description: str
default_packages: tuple[str, ...]
@dataclass(frozen=True)
class VmArtifacts:
"""Resolved artifact paths for a profile."""
kernel_image: Path
rootfs_image: Path
PROFILE_CATALOG: dict[str, VmProfile] = {
"debian-base": VmProfile(
name="debian-base",
description="Minimal Debian userspace for shell and core Unix tooling.",
default_packages=("bash", "coreutils"),
),
"debian-git": VmProfile(
name="debian-git",
description="Debian base environment with Git preinstalled.",
default_packages=("bash", "coreutils", "git"),
),
"debian-build": VmProfile(
name="debian-build",
description="Debian Git environment with common build tools for source builds.",
default_packages=("bash", "coreutils", "git", "gcc", "make", "cmake", "python3"),
),
}
def list_profiles() -> list[dict[str, object]]:
"""Return profile metadata in a JSON-safe format."""
return [
{
"name": profile.name,
"description": profile.description,
"default_packages": list(profile.default_packages),
}
for profile in PROFILE_CATALOG.values()
]
def get_profile(name: str) -> VmProfile:
"""Resolve a profile by name."""
try:
return PROFILE_CATALOG[name]
except KeyError as exc:
known = ", ".join(sorted(PROFILE_CATALOG))
raise ValueError(f"unknown profile {name!r}; expected one of: {known}") from exc
def resolve_artifacts(artifacts_dir: Path, profile_name: str) -> VmArtifacts:
"""Resolve kernel/rootfs file locations for a profile."""
profile_dir = artifacts_dir / profile_name
return VmArtifacts(
kernel_image=profile_dir / "vmlinux",
rootfs_image=profile_dir / "rootfs.ext4",
)
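The catalog lookup in `get_profile` follows a common dict-lookup pattern worth isolating: a `KeyError` is re-raised as a `ValueError` that lists the known names. This standalone sketch (the `lookup` helper and placeholder catalog values are illustrative) reproduces the error-message shape:

```python
# Placeholder catalog; real entries are frozen VmProfile dataclasses.
catalog = {"debian-base": ..., "debian-git": ..., "debian-build": ...}

def lookup(name: str):
    try:
        return catalog[name]
    except KeyError as exc:
        # Sorted names give a stable, readable error message.
        known = ", ".join(sorted(catalog))
        raise ValueError(f"unknown profile {name!r}; expected one of: {known}") from exc

assert lookup("debian-git") is catalog["debian-git"]
try:
    lookup("alpine")
except ValueError as err:
    message = str(err)
assert "debian-base, debian-build, debian-git" in message
```

Chaining with `from exc` preserves the original `KeyError` as `__cause__`, which keeps tracebacks informative without leaking the raw lookup failure to callers.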

View file

@ -1,86 +1,71 @@
from __future__ import annotations
import asyncio
from collections.abc import Sequence
import json
from typing import Any
import pytest
from mcp.types import TextContent
import pyro_mcp.demo as demo_module
from pyro_mcp.demo import run_demo
from pyro_mcp.server import HELLO_STATIC_PAYLOAD
def test_run_demo_returns_static_payload() -> None:
payload = asyncio.run(run_demo())
assert payload == HELLO_STATIC_PAYLOAD
def test_run_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
calls: list[tuple[str, dict[str, Any]]] = []
class StubManager:
def __init__(self) -> None:
pass
def test_run_demo_raises_for_non_text_blocks(monkeypatch: pytest.MonkeyPatch) -> None:
class StubServer:
async def call_tool(
def create_vm(
self,
name: str,
arguments: dict[str, Any],
) -> tuple[Sequence[int], dict[str, str]]:
assert name == "hello_static"
assert arguments == {}
return [123], HELLO_STATIC_PAYLOAD
*,
profile: str,
vcpu_count: int,
mem_mib: int,
ttl_seconds: int,
) -> dict[str, str]:
calls.append(
(
"create_vm",
{
"profile": profile,
"vcpu_count": vcpu_count,
"mem_mib": mem_mib,
"ttl_seconds": ttl_seconds,
},
)
)
return {"vm_id": "vm-1"}
monkeypatch.setattr(demo_module, "create_server", lambda: StubServer())
def start_vm(self, vm_id: str) -> dict[str, str]:
calls.append(("start_vm", {"vm_id": vm_id}))
return {"vm_id": vm_id}
with pytest.raises(TypeError, match="unexpected MCP content block output"):
asyncio.run(demo_module.run_demo())
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]:
calls.append(
(
"exec_vm",
{"vm_id": vm_id, "command": command, "timeout_seconds": timeout_seconds},
)
)
return {"vm_id": vm_id, "stdout": "git version 2.x", "exit_code": 0}
monkeypatch.setattr(demo_module, "VmManager", StubManager)
result = demo_module.run_demo()
assert result["exit_code"] == 0
assert calls[0][0] == "create_vm"
assert calls[1] == ("start_vm", {"vm_id": "vm-1"})
assert calls[2][0] == "exec_vm"
def test_run_demo_raises_for_non_dict_payload(monkeypatch: pytest.MonkeyPatch) -> None:
class StubServer:
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
) -> tuple[list[TextContent], str]:
assert name == "hello_static"
assert arguments == {}
return [TextContent(type="text", text="x")], "bad"
monkeypatch.setattr(demo_module, "create_server", lambda: StubServer())
with pytest.raises(TypeError, match="expected a structured dictionary payload"):
asyncio.run(demo_module.run_demo())
def test_run_demo_raises_for_unexpected_payload(monkeypatch: pytest.MonkeyPatch) -> None:
class StubServer:
async def call_tool(
self,
name: str,
arguments: dict[str, Any],
) -> tuple[list[TextContent], dict[str, str]]:
assert name == "hello_static"
assert arguments == {}
return [TextContent(type="text", text="x")], {
"message": "different",
"status": "ok",
"version": "0.0.1",
}
monkeypatch.setattr(demo_module, "create_server", lambda: StubServer())
with pytest.raises(ValueError, match="static payload did not match expected value"):
asyncio.run(demo_module.run_demo())
def test_demo_main_prints_json(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
def test_main_prints_json(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
async def fake_run_demo() -> dict[str, str]:
return HELLO_STATIC_PAYLOAD
monkeypatch.setattr(demo_module, "run_demo", fake_run_demo)
monkeypatch.setattr(
demo_module,
"run_demo",
lambda: {"stdout": "git version 2.x", "exit_code": 0},
)
demo_module.main()
output = capsys.readouterr().out
assert '"message": "hello from pyro_mcp"' in output
rendered = json.loads(capsys.readouterr().out)
assert rendered["exit_code"] == 0

27
tests/test_doctor.py Normal file
View file

@ -0,0 +1,27 @@
from __future__ import annotations
import argparse
import json
import pytest
import pyro_mcp.doctor as doctor_module
def test_doctor_main_prints_json(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
class StubParser:
def parse_args(self) -> argparse.Namespace:
return argparse.Namespace(platform="linux-x86_64")
monkeypatch.setattr(doctor_module, "_build_parser", lambda: StubParser())
monkeypatch.setattr(
doctor_module,
"doctor_report",
lambda platform: {"platform": platform, "runtime_ok": True, "issues": []},
)
doctor_module.main()
output = json.loads(capsys.readouterr().out)
assert output["runtime_ok"] is True

View file

@ -4,84 +4,38 @@ import argparse
import json
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any
import pytest
import pyro_mcp.ollama_demo as ollama_demo
from pyro_mcp.server import HELLO_STATIC_PAYLOAD
from pyro_mcp.vm_manager import VmManager as RealVmManager
def test_run_ollama_tool_demo_triggers_tool_and_returns_final_response(
monkeypatch: pytest.MonkeyPatch,
) -> None:
requests: list[dict[str, Any]] = []
@pytest.fixture(autouse=True)
def _mock_vm_manager_for_tests(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
class TestVmManager(RealVmManager):
def __init__(self) -> None:
super().__init__(backend_name="mock", base_dir=tmp_path / "vms")
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
assert base_url == "http://localhost:11434/v1"
requests.append(payload)
if len(requests) == 1:
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "hello_static", "arguments": "{}"},
}
],
}
}
]
}
monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager)
def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, Any]:
if step == 1:
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "Tool says hello from pyro_mcp.",
"content": "",
"tool_calls": [{"id": "1", "function": {"name": "vm_list_profiles"}}],
}
}
]
}
async def fake_run_demo() -> dict[str, str]:
return HELLO_STATIC_PAYLOAD
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
monkeypatch.setattr(ollama_demo, "run_demo", fake_run_demo)
result = ollama_demo.run_ollama_tool_demo()
assert result["tool_payload"] == HELLO_STATIC_PAYLOAD
assert result["final_response"] == "Tool says hello from pyro_mcp."
assert len(requests) == 2
assert requests[0]["tools"][0]["function"]["name"] == "hello_static"
tool_message = requests[1]["messages"][-1]
assert tool_message["role"] == "tool"
assert tool_message["tool_call_id"] == "call_1"
def test_run_ollama_tool_demo_raises_when_model_does_not_call_tool(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {"choices": [{"message": {"role": "assistant", "content": "No tool call."}}]}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match="model did not trigger any tool call"):
ollama_demo.run_ollama_tool_demo()
def test_run_ollama_tool_demo_raises_on_unexpected_tool(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
if step == 2:
return {
"choices": [
{
@ -90,22 +44,302 @@ def test_run_ollama_tool_demo_raises_on_unexpected_tool(monkeypatch: pytest.Monk
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {"name": "unexpected_tool", "arguments": "{}"},
"id": "2",
"function": {
"name": "vm_create",
"arguments": json.dumps(
{"profile": "debian-git", "vcpu_count": 1, "mem_mib": 512}
),
},
}
],
}
}
]
}
if step == 3:
vm_id = json.loads(payload["messages"][-1]["content"])["vm_id"]
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "3",
"function": {
"name": "vm_start",
"arguments": json.dumps({"vm_id": vm_id}),
},
}
],
}
}
]
}
if step == 4:
vm_id = json.loads(payload["messages"][-1]["content"])["vm_id"]
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "4",
"function": {
"name": "vm_exec",
"arguments": json.dumps(
{
"vm_id": vm_id,
"command": "printf 'git version 2.44.0\\n'",
}
),
},
}
],
}
}
]
}
return {
"choices": [
{"message": {"role": "assistant", "content": "Executed git command in ephemeral VM."}}
]
}
def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> None:
requests: list[dict[str, Any]] = []
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
assert base_url == "http://localhost:11434/v1"
requests.append(payload)
return _stepwise_model_response(payload, len(requests))
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
logs: list[str] = []
result = ollama_demo.run_ollama_tool_demo(log=logs.append)
assert result["fallback_used"] is False
assert "git version" in str(result["exec_result"]["stdout"])
assert result["final_response"] == "Executed git command in ephemeral VM."
assert len(result["tool_events"]) == 4
assert any("[tool] calling vm_exec" in line for line in logs)
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
requests: list[dict[str, Any]] = []
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
assert base_url == "http://localhost:11434/v1"
requests.append(payload)
step = len(requests)
if step == 1:
return {
"choices": [
{
"message": {
"role": "assistant",
"tool_calls": [
{
"id": "1",
"function": {
"name": "vm_exec",
"arguments": json.dumps(
{
"vm_id": "vm_list_profiles",
"command": "git --version",
}
),
},
}
],
}
}
]
}
return _stepwise_model_response(payload, step - 1)
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
result = ollama_demo.run_ollama_tool_demo()
first_event = result["tool_events"][0]
assert first_event["tool_name"] == "vm_exec"
assert first_event["success"] is False
assert "does not exist" in str(first_event["result"]["error"])
assert int(result["exec_result"]["exit_code"]) == 0
def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {"choices": [{"message": {"role": "assistant", "content": "No tools"}}]}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match="did not execute a successful vm_exec"):
ollama_demo.run_ollama_tool_demo()
def test_run_ollama_tool_demo_uses_fallback_when_not_strict(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {"choices": [{"message": {"role": "assistant", "content": "No tools"}}]}
class TestVmManager(RealVmManager):
def __init__(self) -> None:
super().__init__(backend_name="mock", base_dir=tmp_path / "vms")
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager)
logs: list[str] = []
result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append)
assert result["fallback_used"] is True
assert int(result["exec_result"]["exit_code"]) == 0
assert any("[fallback]" in line for line in logs)
@pytest.mark.parametrize(
("tool_call", "error"),
[
(1, "invalid tool call entry"),
({"id": "", "function": {"name": "vm_list_profiles"}}, "valid call id"),
({"id": "1"}, "function metadata"),
({"id": "1", "function": {"name": 3}}, "name is invalid"),
],
)
def test_run_ollama_tool_demo_tool_call_validation(
monkeypatch: pytest.MonkeyPatch,
tool_call: Any,
error: str,
) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {"choices": [{"message": {"role": "assistant", "tool_calls": [tool_call]}}]}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match=error):
ollama_demo.run_ollama_tool_demo()
def test_run_ollama_tool_demo_max_rounds(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {
"choices": [
{
"message": {
"role": "assistant",
"tool_calls": [{"id": "1", "function": {"name": "vm_list_profiles"}}],
}
}
]
}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match="exceeded maximum rounds"):
ollama_demo.run_ollama_tool_demo()
@pytest.mark.parametrize(
("exec_result", "error"),
[
("bad", "result shape is invalid"),
({"exit_code": 1, "stdout": "git version 2"}, "expected exit_code=0"),
({"exit_code": 0, "stdout": "no git"}, "did not contain `git version`"),
],
)
def test_run_ollama_tool_demo_exec_result_validation(
monkeypatch: pytest.MonkeyPatch,
exec_result: Any,
error: str,
) -> None:
responses: list[dict[str, Any]] = [
{
"choices": [
{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "1", "function": {"name": "vm_exec", "arguments": "{}"}}
],
}
}
]
},
{"choices": [{"message": {"role": "assistant", "content": "done"}}]},
]
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return responses.pop(0)
def fake_dispatch(manager: Any, tool_name: str, arguments: dict[str, Any]) -> Any:
del manager, arguments
if tool_name == "vm_exec":
return exec_result
return {"ok": True}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
monkeypatch.setattr(ollama_demo, "_dispatch_tool_call", fake_dispatch)
with pytest.raises(RuntimeError, match=error):
ollama_demo.run_ollama_tool_demo()
def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms")
profiles = ollama_demo._dispatch_tool_call(manager, "vm_list_profiles", {})
assert "profiles" in profiles
created = ollama_demo._dispatch_tool_call(
manager,
"vm_create",
{"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512},
)
vm_id = str(created["vm_id"])
started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id})
assert started["state"] == "started"
status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id})
assert status["vm_id"] == vm_id
executed = ollama_demo._dispatch_tool_call(
manager, "vm_exec", {"vm_id": vm_id, "command": "printf 'git version\\n'"}
)
assert int(executed["exit_code"]) == 0
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
ollama_demo._dispatch_tool_call(manager, "nope", {})
def test_format_tool_error() -> None:
error = ValueError("bad args")
result = ollama_demo._format_tool_error("vm_exec", {"vm_id": "x"}, error)
assert result["ok"] is False
assert result["error_type"] == "ValueError"
@pytest.mark.parametrize(
("arguments", "error"),
[
({}, "must be a non-empty string"),
({"k": 3}, "must be a non-empty string"),
],
)
def test_require_str(arguments: dict[str, Any], error: str) -> None:
with pytest.raises(ValueError, match=error):
ollama_demo._require_str(arguments, "k")
def test_require_int_validation() -> None:
with pytest.raises(ValueError, match="must be an integer"):
ollama_demo._require_int({"k": "1"}, "k")
assert ollama_demo._require_int({"k": 1}, "k") == 1
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
class StubResponse:
def __enter__(self) -> StubResponse:
@ -118,32 +352,23 @@ def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
return b'{"ok": true}'
def fake_urlopen(request: Any, timeout: int) -> StubResponse:
assert timeout == 60
assert timeout == 90
assert request.full_url == "http://localhost:11434/v1/chat/completions"
return StubResponse()
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
result = ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
assert result == {"ok": True}
assert ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1}) == {"ok": True}
def test_post_chat_completion_raises_for_ollama_connection_error(
monkeypatch: pytest.MonkeyPatch,
) -> None:
def fake_urlopen(request: Any, timeout: int) -> Any:
def test_post_chat_completion_errors(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_urlopen_error(request: Any, timeout: int) -> Any:
del request, timeout
raise urllib.error.URLError("boom")
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen_error)
with pytest.raises(RuntimeError, match="failed to call Ollama"):
ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
def test_post_chat_completion_raises_for_non_object_response(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class StubResponse:
def __enter__(self) -> StubResponse:
return self
@ -152,89 +377,53 @@ def test_post_chat_completion_raises_for_non_object_response(
del exc_type, exc, tb
def read(self) -> bytes:
return b'["not-an-object"]'
return b'["bad"]'
def fake_urlopen(request: Any, timeout: int) -> StubResponse:
def fake_urlopen_non_object(request: Any, timeout: int) -> StubResponse:
del request, timeout
return StubResponse()
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)
monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen_non_object)
with pytest.raises(TypeError, match="unexpected Ollama response shape"):
ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
@pytest.mark.parametrize(
("response", "expected_error"),
("raw", "expected"),
[(None, {}), ({}, {}), ("", {}), ('{"a":1}', {"a": 1})],
)
def test_parse_tool_arguments(raw: Any, expected: dict[str, Any]) -> None:
assert ollama_demo._parse_tool_arguments(raw) == expected
def test_parse_tool_arguments_invalid() -> None:
with pytest.raises(TypeError, match="decode to an object"):
ollama_demo._parse_tool_arguments("[]")
with pytest.raises(TypeError, match="dictionary or JSON object string"):
ollama_demo._parse_tool_arguments(3)
@pytest.mark.parametrize(
("response", "msg"),
[
({}, "did not contain completion choices"),
({"choices": [1]}, "unexpected completion choice format"),
({"choices": [{"message": "bad"}]}, "did not contain a message"),
],
)
def test_extract_message_validation_errors(
response: dict[str, Any],
expected_error: str,
) -> None:
with pytest.raises(RuntimeError, match=expected_error):
def test_extract_message_validation(response: dict[str, Any], msg: str) -> None:
with pytest.raises(RuntimeError, match=msg):
ollama_demo._extract_message(response)
def test_parse_tool_arguments_variants() -> None:
assert ollama_demo._parse_tool_arguments(None) == {}
assert ollama_demo._parse_tool_arguments({}) == {}
assert ollama_demo._parse_tool_arguments("") == {}
assert ollama_demo._parse_tool_arguments('{"a": 1}') == {"a": 1}
def test_build_parser_defaults() -> None:
parser = ollama_demo._build_parser()
args = parser.parse_args([])
assert args.model == ollama_demo.DEFAULT_OLLAMA_MODEL
assert args.base_url == ollama_demo.DEFAULT_OLLAMA_BASE_URL
def test_parse_tool_arguments_rejects_invalid_types() -> None:
with pytest.raises(TypeError, match="must decode to an object"):
ollama_demo._parse_tool_arguments("[]")
with pytest.raises(TypeError, match="must be a dictionary or JSON object string"):
ollama_demo._parse_tool_arguments(123)
@pytest.mark.parametrize(
("tool_call", "expected_error"),
[
(1, "invalid tool call entry"),
({"id": "c1"}, "did not include function metadata"),
(
{"id": "c1", "function": {"name": "hello_static", "arguments": '{"x": 1}'}},
"does not accept arguments",
),
(
{"id": "", "function": {"name": "hello_static", "arguments": "{}"}},
"did not provide a valid call id",
),
],
)
def test_run_ollama_tool_demo_validation_branches(
monkeypatch: pytest.MonkeyPatch,
tool_call: Any,
expected_error: str,
) -> None:
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
del base_url, payload
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [tool_call],
}
}
]
}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match=expected_error):
ollama_demo.run_ollama_tool_demo()
def test_main_uses_parser_and_prints_logs(
monkeypatch: pytest.MonkeyPatch,
capsys: pytest.CaptureFixture[str],
) -> None:
monkeypatch.setattr(
ollama_demo,
"run_ollama_tool_demo",
lambda base_url, model: {"base_url": base_url, "model": model},
lambda base_url, model, strict=True, log=None: {
"exec_result": {"exit_code": 0, "stdout": "git version 2.44.0\n"},
"fallback_used": False,
},
)
ollama_demo.main()
    output = capsys.readouterr().out
    assert "[summary] exit_code=0 fallback_used=False" in output
    assert "[summary] stdout=git version 2.44.0" in output

tests/test_runtime.py (new file)
from __future__ import annotations
import json
from pathlib import Path
import pytest
from pyro_mcp.runtime import doctor_report, resolve_runtime_paths
def test_resolve_runtime_paths_default_bundle() -> None:
paths = resolve_runtime_paths()
assert paths.firecracker_bin.exists()
assert paths.jailer_bin.exists()
assert (paths.artifacts_dir / "debian-git" / "vmlinux").exists()
assert paths.manifest.get("platform") == "linux-x86_64"
def test_resolve_runtime_paths_missing_manifest(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
empty_root = tmp_path / "bundle"
empty_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setenv("PYRO_RUNTIME_BUNDLE_DIR", str(empty_root))
with pytest.raises(RuntimeError, match="manifest not found"):
resolve_runtime_paths()
def test_resolve_runtime_paths_checksum_mismatch(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
source = resolve_runtime_paths()
copied_bundle = tmp_path / "bundle"
copied_platform = copied_bundle / "linux-x86_64"
copied_platform.mkdir(parents=True, exist_ok=True)
(copied_bundle / "NOTICE").write_text(
source.notice_path.read_text(encoding="utf-8"), encoding="utf-8"
)
manifest = json.loads(source.manifest_path.read_text(encoding="utf-8"))
(copied_platform / "manifest.json").write_text(
json.dumps(manifest, indent=2),
encoding="utf-8",
)
firecracker_path = copied_platform / "bin" / "firecracker"
firecracker_path.parent.mkdir(parents=True, exist_ok=True)
firecracker_path.write_text("tampered\n", encoding="utf-8")
(copied_platform / "bin" / "jailer").write_text(
(source.jailer_bin).read_text(encoding="utf-8"),
encoding="utf-8",
)
for profile in ("debian-base", "debian-git", "debian-build"):
profile_dir = copied_platform / "profiles" / profile
profile_dir.mkdir(parents=True, exist_ok=True)
for filename in ("vmlinux", "rootfs.ext4"):
source_file = source.artifacts_dir / profile / filename
(profile_dir / filename).write_text(
source_file.read_text(encoding="utf-8"), encoding="utf-8"
)
monkeypatch.setenv("PYRO_RUNTIME_BUNDLE_DIR", str(copied_bundle))
with pytest.raises(RuntimeError, match="checksum mismatch"):
resolve_runtime_paths()
def test_doctor_report_has_runtime_fields() -> None:
report = doctor_report()
assert "runtime_ok" in report
assert "kvm" in report
if report["runtime_ok"]:
runtime = report.get("runtime")
assert isinstance(runtime, dict)
assert "firecracker_bin" in runtime

from __future__ import annotations
import asyncio
from pathlib import Path
from typing import Any, cast
import pytest
from mcp.types import TextContent
import pyro_mcp.server as server_module
from pyro_mcp.server import create_server
from pyro_mcp.vm_manager import VmManager
def test_create_server_registers_vm_tools(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
async def _run() -> list[str]:
        server = create_server(manager=manager)
tools = await server.list_tools()
        return sorted(tool.name for tool in tools)
tool_names = asyncio.run(_run())
assert "hello_static" in tool_names
assert "vm_create" in tool_names
assert "vm_exec" in tool_names
assert "vm_list_profiles" in tool_names
assert "vm_status" in tool_names
def test_vm_tools_lifecycle_round_trip(tmp_path: Path) -> None:
    manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
raise TypeError("unexpected call_tool result shape")
_, structured = raw_result
if not isinstance(structured, dict):
raise TypeError("expected structured dictionary result")
return cast(dict[str, Any], structured)
async def _run() -> dict[str, Any]:
server = create_server(manager=manager)
created = _extract_structured(
await server.call_tool(
"vm_create",
{"profile": "debian-git", "vcpu_count": 1, "mem_mib": 512, "ttl_seconds": 600},
)
)
vm_id = str(created["vm_id"])
await server.call_tool("vm_start", {"vm_id": vm_id})
executed = _extract_structured(
await server.call_tool(
"vm_exec", {"vm_id": vm_id, "command": "printf 'git version 2.0\\n'"}
)
)
return executed
executed = asyncio.run(_run())
assert int(executed["exit_code"]) == 0
assert "git version" in str(executed["stdout"])
def test_vm_tools_status_stop_delete_and_reap(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
manager.MIN_TTL_SECONDS = 1
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
raise TypeError("unexpected call_tool result shape")
_, structured = raw_result
if not isinstance(structured, dict):
raise TypeError("expected structured dictionary result")
return cast(dict[str, Any], structured)
async def _run() -> tuple[
dict[str, Any], dict[str, Any], dict[str, Any], list[dict[str, object]], dict[str, Any]
]:
server = create_server(manager=manager)
profiles_raw = await server.call_tool("vm_list_profiles", {})
if not isinstance(profiles_raw, tuple) or len(profiles_raw) != 2:
raise TypeError("unexpected profiles result")
_, profiles_structured = profiles_raw
if not isinstance(profiles_structured, dict):
raise TypeError("profiles tool should return a dictionary")
raw_profiles = profiles_structured.get("result")
if not isinstance(raw_profiles, list):
raise TypeError("profiles tool did not contain a result list")
created = _extract_structured(
await server.call_tool(
"vm_create",
{"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512, "ttl_seconds": 600},
)
)
vm_id = str(created["vm_id"])
await server.call_tool("vm_start", {"vm_id": vm_id})
status = _extract_structured(await server.call_tool("vm_status", {"vm_id": vm_id}))
stopped = _extract_structured(await server.call_tool("vm_stop", {"vm_id": vm_id}))
deleted = _extract_structured(await server.call_tool("vm_delete", {"vm_id": vm_id}))
expiring = _extract_structured(
await server.call_tool(
"vm_create",
{"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512, "ttl_seconds": 1},
)
)
expiring_id = str(expiring["vm_id"])
manager._instances[expiring_id].expires_at = 0.0 # noqa: SLF001
reaped = _extract_structured(await server.call_tool("vm_reap_expired", {}))
return status, stopped, deleted, cast(list[dict[str, object]], raw_profiles), reaped
status, stopped, deleted, profiles, reaped = asyncio.run(_run())
assert status["state"] == "started"
assert stopped["state"] == "stopped"
assert bool(deleted["deleted"]) is True
assert profiles[0]["name"] == "debian-base"
assert int(reaped["count"]) == 1
def test_server_main_runs_stdio_transport(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(server_module, "create_server", lambda: StubServer())
server_module.main()
assert called == {"transport": "stdio"}

tests/test_vm_manager.py (new file)
from __future__ import annotations
from pathlib import Path
from typing import Any
import pytest
import pyro_mcp.vm_manager as vm_manager_module
from pyro_mcp.runtime import resolve_runtime_paths
from pyro_mcp.vm_manager import VmManager
def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
started = manager.start_vm(vm_id)
assert started["state"] == "started"
executed = manager.exec_vm(vm_id, command="printf 'git version 2.43.0\\n'", timeout_seconds=30)
assert executed["exit_code"] == 0
assert "git version" in str(executed["stdout"])
with pytest.raises(ValueError, match="does not exist"):
manager.status_vm(vm_id)
def test_vm_manager_exec_timeout(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id"
]
)
manager.start_vm(vm_id)
result = manager.exec_vm(vm_id, command="sleep 2", timeout_seconds=1)
assert result["exit_code"] == 124
assert "timed out" in str(result["stderr"])
def test_vm_manager_stop_and_delete(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id"
]
)
manager.start_vm(vm_id)
stopped = manager.stop_vm(vm_id)
assert stopped["state"] == "stopped"
deleted = manager.delete_vm(vm_id)
assert deleted["deleted"] is True
def test_vm_manager_reaps_expired(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
manager.MIN_TTL_SECONDS = 1
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
)
instance = manager._instances[vm_id] # noqa: SLF001
instance.expires_at = 0.0
result = manager.reap_expired()
assert result["count"] == 1
with pytest.raises(ValueError):
manager.status_vm(vm_id)
def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
manager.MIN_TTL_SECONDS = 1
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
)
manager.start_vm(vm_id)
manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001
result = manager.reap_expired()
assert result["count"] == 1
@pytest.mark.parametrize(
("kwargs", "msg"),
[
({"vcpu_count": 0, "mem_mib": 512, "ttl_seconds": 600}, "vcpu_count must be between"),
({"vcpu_count": 1, "mem_mib": 64, "ttl_seconds": 600}, "mem_mib must be between"),
({"vcpu_count": 1, "mem_mib": 512, "ttl_seconds": 30}, "ttl_seconds must be between"),
],
)
def test_vm_manager_validates_limits(tmp_path: Path, kwargs: dict[str, int], msg: str) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
with pytest.raises(ValueError, match=msg):
manager.create_vm(profile="debian-base", **kwargs)
def test_vm_manager_max_active_limit(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms", max_active_vms=1)
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)
with pytest.raises(RuntimeError, match="max active VMs reached"):
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)
def test_vm_manager_state_validation(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=600)[
"vm_id"
]
)
with pytest.raises(RuntimeError, match="must be in 'started' state"):
manager.exec_vm(vm_id, command="echo hi", timeout_seconds=30)
with pytest.raises(ValueError, match="must be positive"):
manager.exec_vm(vm_id, command="echo hi", timeout_seconds=0)
manager.start_vm(vm_id)
with pytest.raises(RuntimeError, match="cannot be started from state"):
manager.start_vm(vm_id)
def test_vm_manager_status_expired_raises(tmp_path: Path) -> None:
manager = VmManager(backend_name="mock", base_dir=tmp_path / "vms")
manager.MIN_TTL_SECONDS = 1
vm_id = str(
manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=1)["vm_id"]
)
manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001
with pytest.raises(RuntimeError, match="expired and was automatically deleted"):
manager.status_vm(vm_id)
def test_vm_manager_invalid_backend(tmp_path: Path) -> None:
with pytest.raises(ValueError, match="invalid backend"):
VmManager(backend_name="nope", base_dir=tmp_path / "vms")
def test_vm_manager_firecracker_backend_path(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
class StubFirecrackerBackend:
def __init__(self, artifacts_dir: Path, firecracker_bin: Path, jailer_bin: Path) -> None:
self.artifacts_dir = artifacts_dir
self.firecracker_bin = firecracker_bin
self.jailer_bin = jailer_bin
def create(self, instance: Any) -> None:
del instance
def start(self, instance: Any) -> None:
del instance
def exec(self, instance: Any, command: str, timeout_seconds: int) -> Any:
del instance, command, timeout_seconds
return None
def stop(self, instance: Any) -> None:
del instance
def delete(self, instance: Any) -> None:
del instance
monkeypatch.setattr(vm_manager_module, "FirecrackerBackend", StubFirecrackerBackend)
manager = VmManager(
backend_name="firecracker",
base_dir=tmp_path / "vms",
runtime_paths=resolve_runtime_paths(),
)
assert manager._backend_name == "firecracker" # noqa: SLF001

tests/test_vm_profiles.py (new file)
from __future__ import annotations
from pathlib import Path
import pytest
from pyro_mcp.vm_profiles import get_profile, list_profiles, resolve_artifacts
def test_list_profiles_includes_expected_entries() -> None:
profiles = list_profiles()
names = {str(entry["name"]) for entry in profiles}
assert {"debian-base", "debian-git", "debian-build"} <= names
def test_get_profile_rejects_unknown() -> None:
with pytest.raises(ValueError, match="unknown profile"):
get_profile("does-not-exist")
def test_resolve_artifacts() -> None:
artifacts = resolve_artifacts(Path("/tmp/artifacts"), "debian-git")
assert str(artifacts.kernel_image).endswith("/debian-git/vmlinux")
assert str(artifacts.rootfs_image).endswith("/debian-git/rootfs.ext4")

uv.lock (generated)
[[package]]
name = "pyro-mcp"
version = "0.0.1"
version = "0.1.0"
source = { editable = "." }
dependencies = [
{ name = "mcp" },