Unify public UX around pyro CLI and Pyro facade
This commit is contained in:
parent
d16aadd03f
commit
23a2dfb330
19 changed files with 936 additions and 407 deletions
|
|
@ -1,6 +1,7 @@
|
|||
"""Public package surface for pyro_mcp."""
|
||||
|
||||
from pyro_mcp.api import Pyro
|
||||
from pyro_mcp.server import create_server
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
|
||||
__all__ = ["VmManager", "create_server"]
|
||||
__all__ = ["Pyro", "VmManager", "create_server"]
|
||||
|
|
|
|||
179
src/pyro_mcp/api.py
Normal file
179
src/pyro_mcp/api.py
Normal file
|
|
@ -0,0 +1,179 @@
|
|||
"""Public facade shared by the Python SDK and MCP server."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
|
||||
|
||||
class Pyro:
|
||||
"""High-level facade over the ephemeral VM runtime."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
manager: VmManager | None = None,
|
||||
*,
|
||||
backend_name: str | None = None,
|
||||
base_dir: Path | None = None,
|
||||
artifacts_dir: Path | None = None,
|
||||
max_active_vms: int = 4,
|
||||
) -> None:
|
||||
self._manager = manager or VmManager(
|
||||
backend_name=backend_name,
|
||||
base_dir=base_dir,
|
||||
artifacts_dir=artifacts_dir,
|
||||
max_active_vms=max_active_vms,
|
||||
)
|
||||
|
||||
@property
|
||||
def manager(self) -> VmManager:
|
||||
return self._manager
|
||||
|
||||
def list_profiles(self) -> list[dict[str, object]]:
|
||||
return self._manager.list_profiles()
|
||||
|
||||
def create_vm(
|
||||
self,
|
||||
*,
|
||||
profile: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
ttl_seconds: int = 600,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.create_vm(
|
||||
profile=profile,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
)
|
||||
|
||||
def start_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
return self._manager.start_vm(vm_id)
|
||||
|
||||
def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
|
||||
return self._manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
|
||||
|
||||
def stop_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
return self._manager.stop_vm(vm_id)
|
||||
|
||||
def delete_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
return self._manager.delete_vm(vm_id)
|
||||
|
||||
def status_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
return self._manager.status_vm(vm_id)
|
||||
|
||||
def network_info_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
return self._manager.network_info_vm(vm_id)
|
||||
|
||||
def reap_expired(self) -> dict[str, Any]:
|
||||
return self._manager.reap_expired()
|
||||
|
||||
def run_in_vm(
|
||||
self,
|
||||
*,
|
||||
profile: str,
|
||||
command: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
timeout_seconds: int = 30,
|
||||
ttl_seconds: int = 600,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.run_vm(
|
||||
profile=profile,
|
||||
command=command,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
timeout_seconds=timeout_seconds,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
)
|
||||
|
||||
def create_server(self) -> FastMCP:
|
||||
server = FastMCP(name="pyro_mcp")
|
||||
|
||||
@server.tool()
|
||||
async def vm_run(
|
||||
profile: str,
|
||||
command: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
timeout_seconds: int = 30,
|
||||
ttl_seconds: int = 600,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Create, start, execute, and clean up an ephemeral VM."""
|
||||
return self.run_in_vm(
|
||||
profile=profile,
|
||||
command=command,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
timeout_seconds=timeout_seconds,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def vm_list_profiles() -> list[dict[str, object]]:
|
||||
"""List standard environment profiles and package highlights."""
|
||||
return self.list_profiles()
|
||||
|
||||
@server.tool()
|
||||
async def vm_create(
|
||||
profile: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
ttl_seconds: int = 600,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
"""Create an ephemeral VM record with profile and resource sizing."""
|
||||
return self.create_vm(
|
||||
profile=profile,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def vm_start(vm_id: str) -> dict[str, Any]:
|
||||
"""Start a created VM and transition it into a command-ready state."""
|
||||
return self.start_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_exec(vm_id: str, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
|
||||
"""Run one non-interactive command and auto-clean the VM."""
|
||||
return self.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
|
||||
|
||||
@server.tool()
|
||||
async def vm_stop(vm_id: str) -> dict[str, Any]:
|
||||
"""Stop a running VM."""
|
||||
return self.stop_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_delete(vm_id: str) -> dict[str, Any]:
|
||||
"""Delete a VM and its runtime artifacts."""
|
||||
return self.delete_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_status(vm_id: str) -> dict[str, Any]:
|
||||
"""Get the current state and metadata for a VM."""
|
||||
return self.status_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_network_info(vm_id: str) -> dict[str, Any]:
|
||||
"""Get the current network configuration assigned to a VM."""
|
||||
return self.network_info_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_reap_expired() -> dict[str, Any]:
|
||||
"""Delete VMs whose TTL has expired."""
|
||||
return self.reap_expired()
|
||||
|
||||
return server
|
||||
103
src/pyro_mcp/cli.py
Normal file
103
src/pyro_mcp/cli.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Public CLI for pyro-mcp."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from pyro_mcp.api import Pyro
|
||||
from pyro_mcp.demo import run_demo
|
||||
from pyro_mcp.ollama_demo import DEFAULT_OLLAMA_BASE_URL, DEFAULT_OLLAMA_MODEL, run_ollama_tool_demo
|
||||
from pyro_mcp.runtime import DEFAULT_PLATFORM, doctor_report
|
||||
|
||||
|
||||
def _print_json(payload: dict[str, Any]) -> None:
|
||||
print(json.dumps(payload, indent=2, sort_keys=True))
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="pyro CLI for ephemeral Firecracker VMs.")
|
||||
subparsers = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
mcp_parser = subparsers.add_parser("mcp", help="Run the MCP server.")
|
||||
mcp_subparsers = mcp_parser.add_subparsers(dest="mcp_command", required=True)
|
||||
mcp_subparsers.add_parser("serve", help="Run the MCP server over stdio.")
|
||||
|
||||
run_parser = subparsers.add_parser("run", help="Run one command inside an ephemeral VM.")
|
||||
run_parser.add_argument("--profile", required=True)
|
||||
run_parser.add_argument("--vcpu-count", type=int, required=True)
|
||||
run_parser.add_argument("--mem-mib", type=int, required=True)
|
||||
run_parser.add_argument("--timeout-seconds", type=int, default=30)
|
||||
run_parser.add_argument("--ttl-seconds", type=int, default=600)
|
||||
run_parser.add_argument("--network", action="store_true")
|
||||
run_parser.add_argument("command_args", nargs=argparse.REMAINDER)
|
||||
|
||||
doctor_parser = subparsers.add_parser("doctor", help="Inspect runtime and host diagnostics.")
|
||||
doctor_parser.add_argument("--platform", default=DEFAULT_PLATFORM)
|
||||
|
||||
demo_parser = subparsers.add_parser("demo", help="Run built-in demos.")
|
||||
demo_subparsers = demo_parser.add_subparsers(dest="demo_command")
|
||||
demo_parser.add_argument("--network", action="store_true")
|
||||
ollama_parser = demo_subparsers.add_parser("ollama", help="Run the Ollama MCP demo.")
|
||||
ollama_parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
|
||||
ollama_parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
|
||||
ollama_parser.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _require_command(command_args: list[str]) -> str:
|
||||
if command_args and command_args[0] == "--":
|
||||
command_args = command_args[1:]
|
||||
if not command_args:
|
||||
raise ValueError("command is required after `pyro run --`")
|
||||
return " ".join(command_args)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _build_parser().parse_args()
|
||||
if args.command == "mcp":
|
||||
Pyro().create_server().run(transport="stdio")
|
||||
return
|
||||
if args.command == "run":
|
||||
command = _require_command(args.command_args)
|
||||
result = Pyro().run_in_vm(
|
||||
profile=args.profile,
|
||||
command=command,
|
||||
vcpu_count=args.vcpu_count,
|
||||
mem_mib=args.mem_mib,
|
||||
timeout_seconds=args.timeout_seconds,
|
||||
ttl_seconds=args.ttl_seconds,
|
||||
network=args.network,
|
||||
)
|
||||
_print_json(result)
|
||||
return
|
||||
if args.command == "doctor":
|
||||
_print_json(doctor_report(platform=args.platform))
|
||||
return
|
||||
if args.command == "demo" and args.demo_command == "ollama":
|
||||
try:
|
||||
result = run_ollama_tool_demo(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
verbose=args.verbose,
|
||||
log=lambda message: print(message, flush=True),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[error] {exc}", flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
exec_result = result["exec_result"]
|
||||
if not isinstance(exec_result, dict):
|
||||
raise RuntimeError("demo produced invalid execution result")
|
||||
print(
|
||||
f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
|
||||
f"fallback_used={bool(result.get('fallback_used'))} "
|
||||
f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
|
||||
flush=True,
|
||||
)
|
||||
if args.verbose:
|
||||
print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
|
||||
return
|
||||
result = run_demo(network=bool(args.network))
|
||||
_print_json(result)
|
||||
|
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||
import json
|
||||
from typing import Any
|
||||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
from pyro_mcp.api import Pyro
|
||||
|
||||
INTERNET_PROBE_COMMAND = (
|
||||
'python3 -c "import urllib.request; '
|
||||
|
|
@ -20,15 +20,22 @@ def _demo_command(status: dict[str, Any]) -> str:
|
|||
return "git --version"
|
||||
|
||||
|
||||
def run_demo() -> dict[str, Any]:
|
||||
def run_demo(*, network: bool = False) -> dict[str, Any]:
|
||||
"""Create/start/exec/delete a VM and return command output."""
|
||||
manager = VmManager()
|
||||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
status = manager.status_vm(vm_id)
|
||||
executed = manager.exec_vm(vm_id, command=_demo_command(status), timeout_seconds=30)
|
||||
return executed
|
||||
pyro = Pyro()
|
||||
status = {
|
||||
"network_enabled": network,
|
||||
"execution_mode": "guest_vsock" if network else "host_compat",
|
||||
}
|
||||
return pyro.run_in_vm(
|
||||
profile="debian-git",
|
||||
command=_demo_command(status),
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
timeout_seconds=30,
|
||||
ttl_seconds=600,
|
||||
network=network,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
|||
|
|
@ -9,9 +9,9 @@ import urllib.request
|
|||
from collections.abc import Callable
|
||||
from typing import Any, Final, cast
|
||||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
from pyro_mcp.api import Pyro
|
||||
|
||||
__all__ = ["VmManager", "run_ollama_tool_demo"]
|
||||
__all__ = ["Pyro", "run_ollama_tool_demo"]
|
||||
|
||||
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
|
||||
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
|
||||
|
|
@ -24,6 +24,27 @@ NETWORK_PROOF_COMMAND: Final[str] = (
|
|||
)
|
||||
|
||||
TOOL_SPECS: Final[list[dict[str, Any]]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "vm_run",
|
||||
"description": "Create, start, execute, and clean up an ephemeral VM in one call.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"profile": {"type": "string"},
|
||||
"command": {"type": "string"},
|
||||
"vcpu_count": {"type": "integer"},
|
||||
"mem_mib": {"type": "integer"},
|
||||
"timeout_seconds": {"type": "integer"},
|
||||
"ttl_seconds": {"type": "integer"},
|
||||
"network": {"type": "boolean"},
|
||||
},
|
||||
"required": ["profile", "command", "vcpu_count", "mem_mib"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -48,6 +69,7 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
|
|||
"vcpu_count": {"type": "integer"},
|
||||
"mem_mib": {"type": "integer"},
|
||||
"ttl_seconds": {"type": "integer"},
|
||||
"network": {"type": "boolean"},
|
||||
},
|
||||
"required": ["profile", "vcpu_count", "mem_mib"],
|
||||
"additionalProperties": False,
|
||||
|
|
@ -170,35 +192,55 @@ def _require_int(arguments: dict[str, Any], key: str) -> int:
|
|||
raise ValueError(f"{key} must be an integer")
|
||||
|
||||
|
||||
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
|
||||
value = arguments.get(key, default)
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
raise ValueError(f"{key} must be a boolean")
|
||||
|
||||
|
||||
def _dispatch_tool_call(
|
||||
manager: VmManager, tool_name: str, arguments: dict[str, Any]
|
||||
pyro: Pyro, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if tool_name == "vm_run":
|
||||
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||
return pyro.run_in_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
command=_require_str(arguments, "command"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||
network=_require_bool(arguments, "network", default=False),
|
||||
)
|
||||
if tool_name == "vm_list_profiles":
|
||||
return {"profiles": manager.list_profiles()}
|
||||
return {"profiles": pyro.list_profiles()}
|
||||
if tool_name == "vm_create":
|
||||
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||
return manager.create_vm(
|
||||
return pyro.create_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||
network=_require_bool(arguments, "network", default=False),
|
||||
)
|
||||
if tool_name == "vm_start":
|
||||
return manager.start_vm(_require_str(arguments, "vm_id"))
|
||||
return pyro.start_vm(_require_str(arguments, "vm_id"))
|
||||
if tool_name == "vm_exec":
|
||||
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||
vm_id = _require_str(arguments, "vm_id")
|
||||
status = manager.status_vm(vm_id)
|
||||
status = pyro.status_vm(vm_id)
|
||||
state = status.get("state")
|
||||
if state in {"created", "stopped"}:
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(
|
||||
pyro.start_vm(vm_id)
|
||||
return pyro.exec_vm(
|
||||
vm_id,
|
||||
command=_require_str(arguments, "command"),
|
||||
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||
)
|
||||
if tool_name == "vm_status":
|
||||
return manager.status_vm(_require_str(arguments, "vm_id"))
|
||||
return pyro.status_vm(_require_str(arguments, "vm_id"))
|
||||
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
|
||||
|
||||
|
||||
|
|
@ -212,11 +254,16 @@ def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception
|
|||
}
|
||||
|
||||
|
||||
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
||||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
|
||||
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
|
||||
return pyro.run_in_vm(
|
||||
profile="debian-git",
|
||||
command=NETWORK_PROOF_COMMAND,
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
timeout_seconds=60,
|
||||
ttl_seconds=600,
|
||||
network=True,
|
||||
)
|
||||
|
||||
|
||||
def _is_vm_id_placeholder(value: str) -> bool:
|
||||
|
|
@ -264,6 +311,7 @@ def run_ollama_tool_demo(
|
|||
base_url: str = DEFAULT_OLLAMA_BASE_URL,
|
||||
model: str = DEFAULT_OLLAMA_MODEL,
|
||||
*,
|
||||
pyro: Pyro | None = None,
|
||||
strict: bool = True,
|
||||
verbose: bool = False,
|
||||
log: Callable[[str], None] | None = None,
|
||||
|
|
@ -271,15 +319,15 @@ def run_ollama_tool_demo(
|
|||
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
|
||||
emit = log or (lambda _: None)
|
||||
emit(f"[ollama] starting tool demo with model={model}")
|
||||
manager = VmManager()
|
||||
pyro_client = pyro or Pyro()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
|
||||
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
|
||||
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
|
||||
"returned by vm_create into vm_start/vm_exec.\n"
|
||||
"Use the VM tools to prove outbound internet access in an ephemeral VM.\n"
|
||||
"Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n"
|
||||
"Use profile `debian-git`, choose adequate vCPU/memory, "
|
||||
"and set `network` to true.\n"
|
||||
f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
|
||||
f"Success means the clone completes and the command prints `true`.\n"
|
||||
"If a tool returns an error, fix arguments and retry."
|
||||
|
|
@ -350,7 +398,7 @@ def run_ollama_tool_demo(
|
|||
else:
|
||||
emit(f"[tool] calling {tool_name}")
|
||||
try:
|
||||
result = _dispatch_tool_call(manager, tool_name, arguments)
|
||||
result = _dispatch_tool_call(pyro_client, tool_name, arguments)
|
||||
success = True
|
||||
emit(f"[tool] {tool_name} succeeded")
|
||||
if tool_name == "vm_create":
|
||||
|
|
@ -382,26 +430,26 @@ def run_ollama_tool_demo(
|
|||
}
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("tool-calling loop exceeded maximum rounds")
|
||||
raise RuntimeError("tool-calling loop exceeded maximum rounds")
|
||||
|
||||
exec_event = next(
|
||||
(
|
||||
event
|
||||
for event in reversed(tool_events)
|
||||
if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
|
||||
if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success"))
|
||||
),
|
||||
None,
|
||||
)
|
||||
fallback_used = False
|
||||
if exec_event is None:
|
||||
if strict:
|
||||
raise RuntimeError("demo did not execute a successful vm_exec")
|
||||
emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
|
||||
exec_result = _run_direct_lifecycle_fallback(manager)
|
||||
raise RuntimeError("demo did not execute a successful vm_run or vm_exec")
|
||||
emit("[fallback] model did not complete vm_run; running direct lifecycle command")
|
||||
exec_result = _run_direct_lifecycle_fallback(pyro_client)
|
||||
fallback_used = True
|
||||
tool_events.append(
|
||||
{
|
||||
"tool_name": "vm_exec_fallback",
|
||||
"tool_name": "vm_run_fallback",
|
||||
"arguments": {"command": NETWORK_PROOF_COMMAND},
|
||||
"result": exec_result,
|
||||
"success": True,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,7 @@ import argparse
|
|||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
from pyro_mcp.vm_network import TapNetworkManager
|
||||
from pyro_mcp.api import Pyro
|
||||
|
||||
NETWORK_CHECK_COMMAND = (
|
||||
"rm -rf hello-world "
|
||||
|
|
@ -36,32 +35,24 @@ def run_network_check(
|
|||
timeout_seconds: int = 120,
|
||||
base_dir: Path | None = None,
|
||||
) -> NetworkCheckResult: # pragma: no cover - integration helper
|
||||
manager = VmManager(
|
||||
base_dir=base_dir,
|
||||
network_manager=TapNetworkManager(enabled=True),
|
||||
)
|
||||
created = manager.create_vm(
|
||||
pyro = Pyro(base_dir=base_dir)
|
||||
result = pyro.run_in_vm(
|
||||
profile=profile,
|
||||
command=NETWORK_CHECK_COMMAND,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
status = manager.status_vm(vm_id)
|
||||
executed = manager.exec_vm(
|
||||
vm_id,
|
||||
command=NETWORK_CHECK_COMMAND,
|
||||
timeout_seconds=timeout_seconds,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=True,
|
||||
)
|
||||
return NetworkCheckResult(
|
||||
vm_id=vm_id,
|
||||
execution_mode=str(executed["execution_mode"]),
|
||||
network_enabled=bool(status["network_enabled"]),
|
||||
exit_code=int(executed["exit_code"]),
|
||||
stdout=str(executed["stdout"]),
|
||||
stderr=str(executed["stderr"]),
|
||||
cleanup=dict(executed["cleanup"]),
|
||||
vm_id=str(result["vm_id"]),
|
||||
execution_mode=str(result["execution_mode"]),
|
||||
network_enabled=True,
|
||||
exit_code=int(result["exit_code"]),
|
||||
stdout=str(result["stdout"]),
|
||||
stderr=str(result["stderr"]),
|
||||
cleanup=dict(result["cleanup"]),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,74 +2,15 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from pyro_mcp.api import Pyro
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
|
||||
|
||||
def create_server(manager: VmManager | None = None) -> FastMCP:
|
||||
"""Create and return a configured MCP server instance."""
|
||||
vm_manager = manager or VmManager()
|
||||
server = FastMCP(name="pyro_mcp")
|
||||
|
||||
@server.tool()
|
||||
async def vm_list_profiles() -> list[dict[str, object]]:
|
||||
"""List standard environment profiles and package highlights."""
|
||||
return vm_manager.list_profiles()
|
||||
|
||||
@server.tool()
|
||||
async def vm_create(
|
||||
profile: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
ttl_seconds: int = 600,
|
||||
) -> dict[str, Any]:
|
||||
"""Create an ephemeral VM record with profile and resource sizing."""
|
||||
return vm_manager.create_vm(
|
||||
profile=profile,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def vm_start(vm_id: str) -> dict[str, Any]:
|
||||
"""Start a created VM and transition it into a command-ready state."""
|
||||
return vm_manager.start_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_exec(vm_id: str, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
|
||||
"""Run one non-interactive command and auto-clean the VM."""
|
||||
return vm_manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
|
||||
|
||||
@server.tool()
|
||||
async def vm_stop(vm_id: str) -> dict[str, Any]:
|
||||
"""Stop a running VM."""
|
||||
return vm_manager.stop_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_delete(vm_id: str) -> dict[str, Any]:
|
||||
"""Delete a VM and its runtime artifacts."""
|
||||
return vm_manager.delete_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_status(vm_id: str) -> dict[str, Any]:
|
||||
"""Get the current state and metadata for a VM."""
|
||||
return vm_manager.status_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_network_info(vm_id: str) -> dict[str, Any]:
|
||||
"""Get the current network configuration assigned to a VM."""
|
||||
return vm_manager.network_info_vm(vm_id)
|
||||
|
||||
@server.tool()
|
||||
async def vm_reap_expired() -> dict[str, Any]:
|
||||
"""Delete VMs whose TTL has expired."""
|
||||
return vm_manager.reap_expired()
|
||||
|
||||
return server
|
||||
return Pyro(manager=manager).create_server()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class VmInstance:
|
|||
expires_at: float
|
||||
workdir: Path
|
||||
state: VmState = "created"
|
||||
network_requested: bool = False
|
||||
firecracker_pid: int | None = None
|
||||
last_error: str | None = None
|
||||
metadata: dict[str, str] = field(default_factory=dict)
|
||||
|
|
@ -165,7 +166,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
|
|||
rootfs_copy = instance.workdir / "rootfs.ext4"
|
||||
shutil.copy2(artifacts.rootfs_image, rootfs_copy)
|
||||
instance.metadata["rootfs_image"] = str(rootfs_copy)
|
||||
if self._network_manager.enabled:
|
||||
if instance.network_requested:
|
||||
network = self._network_manager.allocate(instance.vm_id)
|
||||
instance.network = network
|
||||
instance.metadata.update(self._network_manager.to_metadata(network))
|
||||
|
|
@ -342,7 +343,12 @@ class VmManager:
|
|||
reason="mock backend does not boot a guest",
|
||||
)
|
||||
self._max_active_vms = max_active_vms
|
||||
self._network_manager = network_manager or TapNetworkManager()
|
||||
if network_manager is not None:
|
||||
self._network_manager = network_manager
|
||||
elif self._backend_name == "firecracker":
|
||||
self._network_manager = TapNetworkManager(enabled=True)
|
||||
else:
|
||||
self._network_manager = TapNetworkManager(enabled=False)
|
||||
self._lock = threading.Lock()
|
||||
self._instances: dict[str, VmInstance] = {}
|
||||
self._base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
|
@ -367,7 +373,13 @@ class VmManager:
|
|||
return list_profiles()
|
||||
|
||||
def create_vm(
|
||||
self, *, profile: str, vcpu_count: int, mem_mib: int, ttl_seconds: int
|
||||
self,
|
||||
*,
|
||||
profile: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
ttl_seconds: int,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
|
||||
get_profile(profile)
|
||||
|
|
@ -389,11 +401,41 @@ class VmManager:
|
|||
created_at=now,
|
||||
expires_at=now + ttl_seconds,
|
||||
workdir=self._base_dir / vm_id,
|
||||
network_requested=network,
|
||||
)
|
||||
self._backend.create(instance)
|
||||
self._instances[vm_id] = instance
|
||||
return self._serialize(instance)
|
||||
|
||||
def run_vm(
|
||||
self,
|
||||
*,
|
||||
profile: str,
|
||||
command: str,
|
||||
vcpu_count: int,
|
||||
mem_mib: int,
|
||||
timeout_seconds: int = 30,
|
||||
ttl_seconds: int = 600,
|
||||
network: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
created = self.create_vm(
|
||||
profile=profile,
|
||||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
)
|
||||
vm_id = str(created["vm_id"])
|
||||
try:
|
||||
self.start_vm(vm_id)
|
||||
return self.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
|
||||
except Exception:
|
||||
try:
|
||||
self.delete_vm(vm_id, reason="run_vm_error_cleanup")
|
||||
except ValueError:
|
||||
pass
|
||||
raise
|
||||
|
||||
def start_vm(self, vm_id: str) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
instance = self._get_instance_locked(vm_id)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue