Unify public UX around pyro CLI and Pyro facade

This commit is contained in:
Thales Maciel 2026-03-07 16:28:28 -03:00
parent d16aadd03f
commit 23a2dfb330
19 changed files with 936 additions and 407 deletions

View file

@ -9,9 +9,9 @@ import urllib.request
from collections.abc import Callable
from typing import Any, Final, cast
from pyro_mcp.vm_manager import VmManager
from pyro_mcp.api import Pyro
__all__ = ["VmManager", "run_ollama_tool_demo"]
__all__ = ["Pyro", "run_ollama_tool_demo"]
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
@ -24,6 +24,27 @@ NETWORK_PROOF_COMMAND: Final[str] = (
)
TOOL_SPECS: Final[list[dict[str, Any]]] = [
{
"type": "function",
"function": {
"name": "vm_run",
"description": "Create, start, execute, and clean up an ephemeral VM in one call.",
"parameters": {
"type": "object",
"properties": {
"profile": {"type": "string"},
"command": {"type": "string"},
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"timeout_seconds": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"},
},
"required": ["profile", "command", "vcpu_count", "mem_mib"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
@ -48,6 +69,7 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"},
},
"required": ["profile", "vcpu_count", "mem_mib"],
"additionalProperties": False,
@ -170,35 +192,55 @@ def _require_int(arguments: dict[str, Any], key: str) -> int:
raise ValueError(f"{key} must be an integer")
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
value = arguments.get(key, default)
if isinstance(value, bool):
return value
raise ValueError(f"{key} must be a boolean")
def _dispatch_tool_call(
manager: VmManager, tool_name: str, arguments: dict[str, Any]
pyro: Pyro, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]:
if tool_name == "vm_run":
ttl_seconds = arguments.get("ttl_seconds", 600)
timeout_seconds = arguments.get("timeout_seconds", 30)
return pyro.run_in_vm(
profile=_require_str(arguments, "profile"),
command=_require_str(arguments, "command"),
vcpu_count=_require_int(arguments, "vcpu_count"),
mem_mib=_require_int(arguments, "mem_mib"),
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
network=_require_bool(arguments, "network", default=False),
)
if tool_name == "vm_list_profiles":
return {"profiles": manager.list_profiles()}
return {"profiles": pyro.list_profiles()}
if tool_name == "vm_create":
ttl_seconds = arguments.get("ttl_seconds", 600)
return manager.create_vm(
return pyro.create_vm(
profile=_require_str(arguments, "profile"),
vcpu_count=_require_int(arguments, "vcpu_count"),
mem_mib=_require_int(arguments, "mem_mib"),
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
network=_require_bool(arguments, "network", default=False),
)
if tool_name == "vm_start":
return manager.start_vm(_require_str(arguments, "vm_id"))
return pyro.start_vm(_require_str(arguments, "vm_id"))
if tool_name == "vm_exec":
timeout_seconds = arguments.get("timeout_seconds", 30)
vm_id = _require_str(arguments, "vm_id")
status = manager.status_vm(vm_id)
status = pyro.status_vm(vm_id)
state = status.get("state")
if state in {"created", "stopped"}:
manager.start_vm(vm_id)
return manager.exec_vm(
pyro.start_vm(vm_id)
return pyro.exec_vm(
vm_id,
command=_require_str(arguments, "command"),
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
)
if tool_name == "vm_status":
return manager.status_vm(_require_str(arguments, "vm_id"))
return pyro.status_vm(_require_str(arguments, "vm_id"))
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
@ -212,11 +254,16 @@ def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception
}
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
return pyro.run_in_vm(
profile="debian-git",
command=NETWORK_PROOF_COMMAND,
vcpu_count=1,
mem_mib=512,
timeout_seconds=60,
ttl_seconds=600,
network=True,
)
def _is_vm_id_placeholder(value: str) -> bool:
@ -264,6 +311,7 @@ def run_ollama_tool_demo(
base_url: str = DEFAULT_OLLAMA_BASE_URL,
model: str = DEFAULT_OLLAMA_MODEL,
*,
pyro: Pyro | None = None,
strict: bool = True,
verbose: bool = False,
log: Callable[[str], None] | None = None,
@ -271,15 +319,15 @@ def run_ollama_tool_demo(
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
emit = log or (lambda _: None)
emit(f"[ollama] starting tool demo with model={model}")
manager = VmManager()
pyro_client = pyro or Pyro()
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": (
"Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
"returned by vm_create into vm_start/vm_exec.\n"
"Use the VM tools to prove outbound internet access in an ephemeral VM.\n"
"Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n"
"Use profile `debian-git`, choose adequate vCPU/memory, "
"and set `network` to true.\n"
f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
f"Success means the clone completes and the command prints `true`.\n"
"If a tool returns an error, fix arguments and retry."
@ -350,7 +398,7 @@ def run_ollama_tool_demo(
else:
emit(f"[tool] calling {tool_name}")
try:
result = _dispatch_tool_call(manager, tool_name, arguments)
result = _dispatch_tool_call(pyro_client, tool_name, arguments)
success = True
emit(f"[tool] {tool_name} succeeded")
if tool_name == "vm_create":
@ -382,26 +430,26 @@ def run_ollama_tool_demo(
}
)
else:
raise RuntimeError("tool-calling loop exceeded maximum rounds")
raise RuntimeError("tool-calling loop exceeded maximum rounds")
exec_event = next(
(
event
for event in reversed(tool_events)
if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success"))
),
None,
)
fallback_used = False
if exec_event is None:
if strict:
raise RuntimeError("demo did not execute a successful vm_exec")
emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
exec_result = _run_direct_lifecycle_fallback(manager)
raise RuntimeError("demo did not execute a successful vm_run or vm_exec")
emit("[fallback] model did not complete vm_run; running direct lifecycle command")
exec_result = _run_direct_lifecycle_fallback(pyro_client)
fallback_used = True
tool_events.append(
{
"tool_name": "vm_exec_fallback",
"tool_name": "vm_run_fallback",
"arguments": {"command": NETWORK_PROOF_COMMAND},
"result": exec_result,
"success": True,