Unify public UX around pyro CLI and Pyro facade
This commit is contained in:
parent
d16aadd03f
commit
23a2dfb330
19 changed files with 936 additions and 407 deletions
|
|
@ -9,9 +9,9 @@ import urllib.request
|
|||
from collections.abc import Callable
|
||||
from typing import Any, Final, cast
|
||||
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
from pyro_mcp.api import Pyro
|
||||
|
||||
__all__ = ["VmManager", "run_ollama_tool_demo"]
|
||||
__all__ = ["Pyro", "run_ollama_tool_demo"]
|
||||
|
||||
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
|
||||
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
|
||||
|
|
@ -24,6 +24,27 @@ NETWORK_PROOF_COMMAND: Final[str] = (
|
|||
)
|
||||
|
||||
TOOL_SPECS: Final[list[dict[str, Any]]] = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "vm_run",
|
||||
"description": "Create, start, execute, and clean up an ephemeral VM in one call.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"profile": {"type": "string"},
|
||||
"command": {"type": "string"},
|
||||
"vcpu_count": {"type": "integer"},
|
||||
"mem_mib": {"type": "integer"},
|
||||
"timeout_seconds": {"type": "integer"},
|
||||
"ttl_seconds": {"type": "integer"},
|
||||
"network": {"type": "boolean"},
|
||||
},
|
||||
"required": ["profile", "command", "vcpu_count", "mem_mib"],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
|
|
@ -48,6 +69,7 @@ TOOL_SPECS: Final[list[dict[str, Any]]] = [
|
|||
"vcpu_count": {"type": "integer"},
|
||||
"mem_mib": {"type": "integer"},
|
||||
"ttl_seconds": {"type": "integer"},
|
||||
"network": {"type": "boolean"},
|
||||
},
|
||||
"required": ["profile", "vcpu_count", "mem_mib"],
|
||||
"additionalProperties": False,
|
||||
|
|
@ -170,35 +192,55 @@ def _require_int(arguments: dict[str, Any], key: str) -> int:
|
|||
raise ValueError(f"{key} must be an integer")
|
||||
|
||||
|
||||
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
|
||||
value = arguments.get(key, default)
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
raise ValueError(f"{key} must be a boolean")
|
||||
|
||||
|
||||
def _dispatch_tool_call(
|
||||
manager: VmManager, tool_name: str, arguments: dict[str, Any]
|
||||
pyro: Pyro, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if tool_name == "vm_run":
|
||||
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||
return pyro.run_in_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
command=_require_str(arguments, "command"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||
network=_require_bool(arguments, "network", default=False),
|
||||
)
|
||||
if tool_name == "vm_list_profiles":
|
||||
return {"profiles": manager.list_profiles()}
|
||||
return {"profiles": pyro.list_profiles()}
|
||||
if tool_name == "vm_create":
|
||||
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||
return manager.create_vm(
|
||||
return pyro.create_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||
network=_require_bool(arguments, "network", default=False),
|
||||
)
|
||||
if tool_name == "vm_start":
|
||||
return manager.start_vm(_require_str(arguments, "vm_id"))
|
||||
return pyro.start_vm(_require_str(arguments, "vm_id"))
|
||||
if tool_name == "vm_exec":
|
||||
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||
vm_id = _require_str(arguments, "vm_id")
|
||||
status = manager.status_vm(vm_id)
|
||||
status = pyro.status_vm(vm_id)
|
||||
state = status.get("state")
|
||||
if state in {"created", "stopped"}:
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(
|
||||
pyro.start_vm(vm_id)
|
||||
return pyro.exec_vm(
|
||||
vm_id,
|
||||
command=_require_str(arguments, "command"),
|
||||
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||
)
|
||||
if tool_name == "vm_status":
|
||||
return manager.status_vm(_require_str(arguments, "vm_id"))
|
||||
return pyro.status_vm(_require_str(arguments, "vm_id"))
|
||||
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
|
||||
|
||||
|
||||
|
|
@ -212,11 +254,16 @@ def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception
|
|||
}
|
||||
|
||||
|
||||
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
||||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(vm_id, command=NETWORK_PROOF_COMMAND, timeout_seconds=60)
|
||||
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
|
||||
return pyro.run_in_vm(
|
||||
profile="debian-git",
|
||||
command=NETWORK_PROOF_COMMAND,
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
timeout_seconds=60,
|
||||
ttl_seconds=600,
|
||||
network=True,
|
||||
)
|
||||
|
||||
|
||||
def _is_vm_id_placeholder(value: str) -> bool:
|
||||
|
|
@ -264,6 +311,7 @@ def run_ollama_tool_demo(
|
|||
base_url: str = DEFAULT_OLLAMA_BASE_URL,
|
||||
model: str = DEFAULT_OLLAMA_MODEL,
|
||||
*,
|
||||
pyro: Pyro | None = None,
|
||||
strict: bool = True,
|
||||
verbose: bool = False,
|
||||
log: Callable[[str], None] | None = None,
|
||||
|
|
@ -271,15 +319,15 @@ def run_ollama_tool_demo(
|
|||
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
|
||||
emit = log or (lambda _: None)
|
||||
emit(f"[ollama] starting tool demo with model={model}")
|
||||
manager = VmManager()
|
||||
pyro_client = pyro or Pyro()
|
||||
messages: list[dict[str, Any]] = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
"Use the lifecycle tools to prove outbound internet access in an ephemeral VM.\n"
|
||||
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
|
||||
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
|
||||
"returned by vm_create into vm_start/vm_exec.\n"
|
||||
"Use the VM tools to prove outbound internet access in an ephemeral VM.\n"
|
||||
"Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n"
|
||||
"Use profile `debian-git`, choose adequate vCPU/memory, "
|
||||
"and set `network` to true.\n"
|
||||
f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
|
||||
f"Success means the clone completes and the command prints `true`.\n"
|
||||
"If a tool returns an error, fix arguments and retry."
|
||||
|
|
@ -350,7 +398,7 @@ def run_ollama_tool_demo(
|
|||
else:
|
||||
emit(f"[tool] calling {tool_name}")
|
||||
try:
|
||||
result = _dispatch_tool_call(manager, tool_name, arguments)
|
||||
result = _dispatch_tool_call(pyro_client, tool_name, arguments)
|
||||
success = True
|
||||
emit(f"[tool] {tool_name} succeeded")
|
||||
if tool_name == "vm_create":
|
||||
|
|
@ -382,26 +430,26 @@ def run_ollama_tool_demo(
|
|||
}
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("tool-calling loop exceeded maximum rounds")
|
||||
raise RuntimeError("tool-calling loop exceeded maximum rounds")
|
||||
|
||||
exec_event = next(
|
||||
(
|
||||
event
|
||||
for event in reversed(tool_events)
|
||||
if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
|
||||
if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success"))
|
||||
),
|
||||
None,
|
||||
)
|
||||
fallback_used = False
|
||||
if exec_event is None:
|
||||
if strict:
|
||||
raise RuntimeError("demo did not execute a successful vm_exec")
|
||||
emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
|
||||
exec_result = _run_direct_lifecycle_fallback(manager)
|
||||
raise RuntimeError("demo did not execute a successful vm_run or vm_exec")
|
||||
emit("[fallback] model did not complete vm_run; running direct lifecycle command")
|
||||
exec_result = _run_direct_lifecycle_fallback(pyro_client)
|
||||
fallback_used = True
|
||||
tool_events.append(
|
||||
{
|
||||
"tool_name": "vm_exec_fallback",
|
||||
"tool_name": "vm_run_fallback",
|
||||
"arguments": {"command": NETWORK_PROOF_COMMAND},
|
||||
"result": exec_result,
|
||||
"success": True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue