pyro-mcp/src/pyro_mcp/ollama_demo.py

558 lines
21 KiB
Python

"""Ollama demo that drives VM lifecycle tools to run an ephemeral command."""
from __future__ import annotations
import argparse
import json
import urllib.error
import urllib.request
from collections.abc import Callable
from typing import Any, Final, cast
from pyro_mcp.api import Pyro
from pyro_mcp.vm_manager import (
DEFAULT_ALLOW_HOST_COMPAT,
DEFAULT_MEM_MIB,
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
)
# Public names of this module: the demo entry point plus `Pyro`, re-exported from pyro_mcp.api.
__all__ = ["Pyro", "run_ollama_tool_demo"]
# Default Ollama API root; the demo posts to `<base>/chat/completions` under it.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# Default model asked to drive the VM tools.
DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b"
# Maximum number of model round-trips before the demo gives up.
MAX_TOOL_ROUNDS: Final[int] = 12
# Command run inside the VM: fetches https://example.com and prints the HTTP status code.
NETWORK_PROOF_COMMAND: Final[str] = (
    "python3 -c \"import urllib.request as u; print(u.urlopen('https://example.com').status)\""
)
TOOL_SPECS: Final[list[dict[str, Any]]] = [
{
"type": "function",
"function": {
"name": "vm_run",
"description": "Create, start, execute, and clean up an ephemeral VM in one call.",
"parameters": {
"type": "object",
"properties": {
"environment": {"type": "string"},
"command": {"type": "string"},
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"timeout_seconds": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"},
"allow_host_compat": {"type": "boolean"},
},
"required": ["environment", "command"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_list_environments",
"description": "List curated Linux environments and installation status.",
"parameters": {
"type": "object",
"properties": {},
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_create",
"description": "Create an ephemeral VM with optional resource sizing.",
"parameters": {
"type": "object",
"properties": {
"environment": {"type": "string"},
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
"network": {"type": "boolean"},
"allow_host_compat": {"type": "boolean"},
},
"required": ["environment"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_start",
"description": "Start a VM before command execution.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_exec",
"description": "Run one non-interactive command inside the VM and auto-clean it.",
"parameters": {
"type": "object",
"properties": {
"vm_id": {"type": "string"},
"command": {"type": "string"},
"timeout_seconds": {"type": "integer"},
},
"required": ["vm_id", "command"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_status",
"description": "Read current VM status and metadata.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
]
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
endpoint = f"{base_url.rstrip('/')}/chat/completions"
body = json.dumps(payload).encode("utf-8")
request = urllib.request.Request(
endpoint,
data=body,
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=90) as response:
response_text = response.read().decode("utf-8")
except urllib.error.URLError as exc:
raise RuntimeError(
"failed to call Ollama. Ensure `ollama serve` is running and the model is available."
) from exc
parsed = json.loads(response_text)
if not isinstance(parsed, dict):
raise TypeError("unexpected Ollama response shape")
return cast(dict[str, Any], parsed)
def _extract_message(response: dict[str, Any]) -> dict[str, Any]:
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
raise RuntimeError("Ollama response did not contain completion choices")
first = choices[0]
if not isinstance(first, dict):
raise RuntimeError("unexpected completion choice format")
message = first.get("message")
if not isinstance(message, dict):
raise RuntimeError("completion choice did not contain a message")
return cast(dict[str, Any], message)
def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
if raw_arguments is None:
return {}
if isinstance(raw_arguments, dict):
return cast(dict[str, Any], raw_arguments)
if isinstance(raw_arguments, str):
if raw_arguments.strip() == "":
return {}
parsed = json.loads(raw_arguments)
if not isinstance(parsed, dict):
raise TypeError("tool arguments must decode to an object")
return cast(dict[str, Any], parsed)
raise TypeError("tool arguments must be a dictionary or JSON object string")
def _require_str(arguments: dict[str, Any], key: str) -> str:
value = arguments.get(key)
if not isinstance(value, str) or value == "":
raise ValueError(f"{key} must be a non-empty string")
return value
def _require_int(arguments: dict[str, Any], key: str) -> int:
value = arguments.get(key)
if isinstance(value, bool):
raise ValueError(f"{key} must be an integer")
if isinstance(value, int):
return value
if isinstance(value, str):
normalized = value.strip()
if normalized.isdigit():
return int(normalized)
raise ValueError(f"{key} must be an integer")
def _optional_int(arguments: dict[str, Any], key: str, *, default: int) -> int:
    """Return ``default`` when ``key`` is absent, else the value validated by ``_require_int``.

    A key that is present but holds an invalid value (including ``None``) still raises.
    """
    return _require_int(arguments, key) if key in arguments else default
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
value = arguments.get(key, default)
if isinstance(value, bool):
return value
if isinstance(value, int) and value in {0, 1}:
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes", "on"}:
return True
if normalized in {"false", "0", "no", "off"}:
return False
raise ValueError(f"{key} must be a boolean")
def _dispatch_tool_call(
    pyro: Pyro, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]:
    """Validate a model tool call's arguments and run the matching ``Pyro`` operation.

    Omitted optional integers fall back to the defaults imported from
    ``pyro_mcp.vm_manager``; ``vm_exec`` uses a 30 second command timeout.

    Args:
        pyro: Client used to perform the VM operations.
        tool_name: One of the tool names advertised in ``TOOL_SPECS``.
        arguments: Normalized, JSON-decoded arguments supplied by the model.

    Returns:
        The dict returned by the corresponding ``Pyro`` method; for
        ``vm_list_environments`` the list is wrapped under ``"environments"``.

    Raises:
        ValueError: If a required argument is missing or an argument has the wrong type.
        RuntimeError: If ``tool_name`` is not a supported tool.
    """
    if tool_name == "vm_run":
        return pyro.run_in_vm(
            environment=_require_str(arguments, "environment"),
            command=_require_str(arguments, "command"),
            vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
            mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
            timeout_seconds=_optional_int(
                arguments, "timeout_seconds", default=DEFAULT_TIMEOUT_SECONDS
            ),
            ttl_seconds=_optional_int(arguments, "ttl_seconds", default=DEFAULT_TTL_SECONDS),
            network=_require_bool(arguments, "network", default=False),
            allow_host_compat=_require_bool(
                arguments,
                "allow_host_compat",
                default=DEFAULT_ALLOW_HOST_COMPAT,
            ),
        )
    if tool_name == "vm_list_environments":
        return {"environments": pyro.list_environments()}
    if tool_name == "vm_create":
        return pyro.create_vm(
            environment=_require_str(arguments, "environment"),
            vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
            mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
            ttl_seconds=_optional_int(arguments, "ttl_seconds", default=DEFAULT_TTL_SECONDS),
            network=_require_bool(arguments, "network", default=False),
            allow_host_compat=_require_bool(
                arguments,
                "allow_host_compat",
                default=DEFAULT_ALLOW_HOST_COMPAT,
            ),
        )
    if tool_name == "vm_start":
        return pyro.start_vm(_require_str(arguments, "vm_id"))
    if tool_name == "vm_exec":
        # Validate every argument before touching the VM, so a malformed call does
        # not query or start a VM as a side effect before failing.
        vm_id = _require_str(arguments, "vm_id")
        command = _require_str(arguments, "command")
        timeout_seconds = _optional_int(arguments, "timeout_seconds", default=30)
        status = pyro.status_vm(vm_id)
        # The model may call vm_exec without an explicit vm_start; start the VM
        # first when it is still "created" or has been "stopped".
        if status.get("state") in {"created", "stopped"}:
            pyro.start_vm(vm_id)
        return pyro.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds)
    if tool_name == "vm_status":
        return pyro.status_vm(_require_str(arguments, "vm_id"))
    raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
payload = {
"ok": False,
"tool_name": tool_name,
"arguments": arguments,
"error_type": exc.__class__.__name__,
"error": str(exc),
}
error_text = str(exc)
if "must be a boolean" in error_text:
payload["hint"] = "Use JSON booleans true or false, not quoted strings."
if (
"environment must be a non-empty string" in error_text
and isinstance(arguments.get("profile"), str)
):
payload["hint"] = "Use `environment` instead of `profile`."
return payload
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
    """Run ``NETWORK_PROOF_COMMAND`` directly via ``pyro.run_in_vm``, bypassing the model.

    Uses ``debian:12`` with networking enabled, the default vCPU/memory/TTL values
    and a 60 second command timeout. ``allow_host_compat`` is not passed, so
    ``run_in_vm``'s own default applies.
    """
    fallback_kwargs: dict[str, Any] = {
        "environment": "debian:12",
        "command": NETWORK_PROOF_COMMAND,
        "vcpu_count": DEFAULT_VCPU_COUNT,
        "mem_mib": DEFAULT_MEM_MIB,
        "timeout_seconds": 60,
        "ttl_seconds": DEFAULT_TTL_SECONDS,
        "network": True,
    }
    return pyro.run_in_vm(**fallback_kwargs)
def _is_vm_id_placeholder(value: str) -> bool:
normalized = value.strip().lower()
if normalized in {
"vm_id_returned_by_vm_create",
"<vm_id_returned_by_vm_create>",
"{vm_id_returned_by_vm_create}",
}:
return True
return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized
def _normalize_tool_arguments(
    tool_name: str,
    arguments: dict[str, Any],
    *,
    last_created_vm_id: str | None,
) -> tuple[dict[str, Any], str | None]:
    """Repair common model mistakes in tool arguments without mutating the input.

    * ``vm_run``/``vm_create``: a legacy ``profile`` key is dropped; its string value
      becomes ``environment`` when no ``environment`` was supplied.
    * ``vm_start``/``vm_exec``/``vm_status``: a placeholder ``vm_id`` is replaced by
      ``last_created_vm_id`` when one is known.

    Returns:
        The adjusted copy of the arguments, and the vm_id substituted for a
        placeholder (``None`` when no substitution happened).
    """
    adjusted = {**arguments}
    resolved_vm_id: str | None = None
    if tool_name in ("vm_run", "vm_create"):
        profile = adjusted.pop("profile", None)
        if isinstance(profile, str) and "environment" not in adjusted:
            adjusted["environment"] = profile
    elif tool_name in ("vm_start", "vm_exec", "vm_status") and last_created_vm_id is not None:
        candidate = adjusted.get("vm_id")
        if isinstance(candidate, str) and _is_vm_id_placeholder(candidate):
            adjusted["vm_id"] = last_created_vm_id
            resolved_vm_id = last_created_vm_id
    return adjusted, resolved_vm_id
def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
role = str(message.get("role", "unknown"))
if not verbose:
return role
content = str(message.get("content") or "").strip()
if content == "":
return f"{role}: <empty>"
return f"{role}: {content}"
def _serialize_log_value(value: Any) -> str:
return json.dumps(value, sort_keys=True, separators=(",", ":"))
# NETWORK_PROOF_COMMAND prints the HTTP status of its request to https://example.com,
# so a successful run prints "200".
_NETWORK_PROOF_EXPECTED_STDOUT: Final[str] = "200"


def _validate_exec_result(exec_result: Any) -> dict[str, Any]:
    """Check that a vm_run/vm_exec result proves outbound network access.

    Returns:
        ``exec_result`` itself (same object), narrowed to a dict.

    Raises:
        RuntimeError: If the result is not a dict, the command exited non-zero, or
            stdout is not the HTTP status expected from ``NETWORK_PROOF_COMMAND``.
    """
    if not isinstance(exec_result, dict):
        raise RuntimeError("vm_exec result shape is invalid")
    if int(exec_result.get("exit_code", -1)) != 0:
        raise RuntimeError("vm_exec failed; expected exit_code=0")
    stdout = str(exec_result.get("stdout", "")).strip()
    if stdout != _NETWORK_PROOF_EXPECTED_STDOUT:
        raise RuntimeError(
            "vm_exec output did not confirm outbound network access; "
            f"expected stdout {_NETWORK_PROOF_EXPECTED_STDOUT!r}, got {stdout!r}"
        )
    return cast(dict[str, Any], exec_result)


def _execute_tool_call(
    pyro_client: Pyro,
    tool_name: str,
    arguments: dict[str, Any],
    *,
    last_created_vm_id: str | None,
    verbose: bool,
    emit: Callable[[str], None],
) -> tuple[dict[str, Any], dict[str, Any], bool, str | None]:
    """Normalize and dispatch one tool call, turning any failure into an error payload.

    Returns:
        The normalized arguments, the tool result (or error payload), whether the
        call succeeded, and the most recent vm_id returned by ``vm_create``.
    """
    arguments, normalized_vm_id = _normalize_tool_arguments(
        tool_name,
        arguments,
        last_created_vm_id=last_created_vm_id,
    )
    if normalized_vm_id is not None:
        if verbose:
            emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
        else:
            emit("[tool] resolved vm_id placeholder")
    if verbose:
        emit(f"[tool] calling {tool_name} with args={arguments}")
    else:
        emit(f"[tool] calling {tool_name}")
    try:
        result = _dispatch_tool_call(pyro_client, tool_name, arguments)
    except Exception as exc:  # noqa: BLE001
        # Any tool failure is reported back to the model so it can fix its arguments.
        error_payload = _format_tool_error(tool_name, arguments, exc)
        emit(f"[tool] {tool_name} failed: {exc} args={_serialize_log_value(arguments)}")
        return arguments, error_payload, False, last_created_vm_id
    emit(f"[tool] {tool_name} succeeded")
    if tool_name == "vm_create":
        created_vm_id = result.get("vm_id")
        if isinstance(created_vm_id, str) and created_vm_id != "":
            # Remembered so later placeholder vm_ids can be resolved to this VM.
            last_created_vm_id = created_vm_id
    return arguments, result, True, last_created_vm_id


def run_ollama_tool_demo(
    base_url: str = DEFAULT_OLLAMA_BASE_URL,
    model: str = DEFAULT_OLLAMA_MODEL,
    *,
    pyro: Pyro | None = None,
    strict: bool = True,
    verbose: bool = False,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Ask Ollama to prove outbound network access from an ephemeral VM via the VM tools.

    The model is offered ``TOOL_SPECS`` and asked to run ``NETWORK_PROOF_COMMAND`` in a
    ``debian:12`` VM with networking enabled. Each tool call is executed against the
    Pyro client and its result (or an error payload) is fed back to the model, for at
    most ``MAX_TOOL_ROUNDS`` rounds.

    Args:
        base_url: OpenAI-compatible Ollama API root.
        model: Ollama model name.
        pyro: VM client to use; a new ``Pyro()`` is created when omitted.
        strict: When True, fail if the model never completes a successful ``vm_run`` or
            ``vm_exec``; when False, run the command directly as a fallback instead.
        verbose: Include message contents, arguments and results in log lines.
        log: Optional sink for progress lines.

    Returns:
        A dict with ``model``, ``command``, ``exec_result``, ``tool_events``,
        ``final_response`` and ``fallback_used``.

    Raises:
        RuntimeError: On malformed model responses or tool calls, when the loop exceeds
            ``MAX_TOOL_ROUNDS``, when ``strict`` and no execution succeeded, or when the
            execution did not exit 0 and print the expected HTTP status.
    """
    emit = log or (lambda _: None)
    emit(f"[ollama] starting tool demo with model={model}")
    pyro_client = pyro or Pyro()
    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                "Use the VM tools to prove outbound internet access in an ephemeral VM.\n"
                "Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n"
                "Use environment `debian:12`, choose adequate vCPU/memory, "
                "and set `network` to true.\n"
                f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
                "Success means the command exits with code 0 and prints "
                f"`{_NETWORK_PROOF_EXPECTED_STDOUT}`.\n"
                "If a tool returns an error, fix arguments and retry."
            ),
        }
    ]
    tool_events: list[dict[str, Any]] = []
    final_response = ""
    last_created_vm_id: str | None = None
    for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
        emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}")
        response = _post_chat_completion(
            base_url,
            {
                "model": model,
                "messages": messages,
                "tools": TOOL_SPECS,
                "tool_choice": "auto",
                "temperature": 0,
            },
        )
        assistant_message = _extract_message(response)
        emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}")
        tool_calls = assistant_message.get("tool_calls")
        if not isinstance(tool_calls, list) or not tool_calls:
            # A reply without tool calls is treated as the model's final answer.
            final_response = str(assistant_message.get("content") or "")
            emit("[ollama] no tool calls returned; stopping loop")
            break
        messages.append(
            {
                "role": "assistant",
                "content": str(assistant_message.get("content") or ""),
                "tool_calls": tool_calls,
            }
        )
        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                raise RuntimeError("invalid tool call entry returned by model")
            call_id = tool_call.get("id")
            if not isinstance(call_id, str) or call_id == "":
                raise RuntimeError("tool call did not provide a valid call id")
            function = tool_call.get("function")
            if not isinstance(function, dict):
                raise RuntimeError("tool call did not include function metadata")
            tool_name = function.get("name")
            if not isinstance(tool_name, str):
                raise RuntimeError("tool call function name is invalid")
            try:
                arguments = _parse_tool_arguments(function.get("arguments"))
            except (TypeError, ValueError) as exc:
                # Malformed arguments (e.g. invalid JSON) are returned to the model as a
                # tool error so it can retry, instead of aborting the whole demo.
                emit(f"[model] tool_call {tool_name} with unparseable arguments")
                emit(f"[tool] {tool_name} failed: {exc}")
                arguments = {}
                result = _format_tool_error(tool_name, arguments, exc)
                success = False
            else:
                if verbose:
                    emit(f"[model] tool_call {tool_name} args={arguments}")
                else:
                    emit(f"[model] tool_call {tool_name}")
                arguments, result, success, last_created_vm_id = _execute_tool_call(
                    pyro_client,
                    tool_name,
                    arguments,
                    last_created_vm_id=last_created_vm_id,
                    verbose=verbose,
                    emit=emit,
                )
            if verbose:
                emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
            else:
                emit(f"[tool] result {tool_name}")
            tool_events.append(
                {
                    "tool_name": tool_name,
                    "arguments": arguments,
                    "result": result,
                    "success": success,
                }
            )
            # Send the result back to the model as the reply to this tool call.
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "content": json.dumps(result, sort_keys=True),
                }
            )
    else:
        # Every round produced tool calls; the model never gave a final answer.
        raise RuntimeError("tool-calling loop exceeded maximum rounds")
    exec_event = next(
        (
            event
            for event in reversed(tool_events)
            if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success"))
        ),
        None,
    )
    fallback_used = False
    if exec_event is None:
        if strict:
            raise RuntimeError("demo did not execute a successful vm_run or vm_exec")
        emit("[fallback] model did not complete vm_run; running direct lifecycle command")
        raw_exec_result: Any = _run_direct_lifecycle_fallback(pyro_client)
        fallback_used = True
        tool_events.append(
            {
                "tool_name": "vm_run_fallback",
                "arguments": {"command": NETWORK_PROOF_COMMAND},
                "result": raw_exec_result,
                "success": True,
            }
        )
    else:
        raw_exec_result = exec_event["result"]
    exec_result = _validate_exec_result(raw_exec_result)
    emit("[done] command execution succeeded")
    return {
        "model": model,
        "command": NETWORK_PROOF_COMMAND,
        "exec_result": exec_result,
        "tool_events": tool_events,
        "final_response": final_response,
        "fallback_used": fallback_used,
    }
def _build_parser() -> argparse.ArgumentParser:
    """Create the CLI parser: ``--base-url``, ``--model`` and ``-v/--verbose``."""
    cli = argparse.ArgumentParser(
        description="Run Ollama tool-calling demo for ephemeral VMs."
    )
    cli.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
    cli.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
    cli.add_argument("-v", "--verbose", action="store_true")
    return cli
def main() -> None:
    """CLI entrypoint: run the demo, stream its log lines, then print a summary.

    Any exception raised by the demo is printed as ``[error] ...`` and turned into
    exit status 1. With ``--verbose`` the command's stripped stdout is also printed.
    """
    args = _build_parser().parse_args()

    def _print_line(message: str) -> None:
        print(message, flush=True)

    try:
        result = run_ollama_tool_demo(
            base_url=args.base_url,
            model=args.model,
            verbose=args.verbose,
            log=_print_line,
        )
    except Exception as exc:  # noqa: BLE001
        _print_line(f"[error] {exc}")
        raise SystemExit(1) from exc
    exec_result = result["exec_result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("demo produced invalid execution result")
    summary_parts = (
        f"[summary] exit_code={int(exec_result.get('exit_code', -1))}",
        f"fallback_used={bool(result.get('fallback_used'))}",
        f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
    )
    print(" ".join(summary_parts), flush=True)
    if args.verbose:
        stdout_text = str(exec_result.get("stdout", "")).strip()
        print(f"[summary] stdout={stdout_text}", flush=True)