"""Ollama demo that drives VM lifecycle tools to run an ephemeral command."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import urllib.error
|
|
import urllib.request
|
|
from collections.abc import Callable
|
|
from typing import Any, Final, cast
|
|
|
|
from pyro_mcp.api import Pyro
|
|
from pyro_mcp.vm_manager import (
|
|
DEFAULT_ALLOW_HOST_COMPAT,
|
|
DEFAULT_MEM_MIB,
|
|
DEFAULT_TIMEOUT_SECONDS,
|
|
DEFAULT_TTL_SECONDS,
|
|
DEFAULT_VCPU_COUNT,
|
|
)
|
|
|
|
# Public API of this module.
__all__ = ["Pyro", "run_ollama_tool_demo"]

# OpenAI-compatible endpoint exposed by a local `ollama serve`.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# Default model name; it must already be pulled locally.
DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b"
# Hard cap on assistant/tool round-trips so a confused model cannot loop forever.
MAX_TOOL_ROUNDS: Final[int] = 12
# Command executed inside the VM to prove outbound network access.
# NOTE(review): it prints the HTTP status code of the request (e.g. `200`),
# not a boolean — any downstream success check must agree with that output.
NETWORK_PROOF_COMMAND: Final[str] = (
    'python3 -c "import urllib.request as u; '
    "print(u.urlopen('https://example.com').status)"
    '"'
)
|
|
|
|
# OpenAI-style function-calling schemas advertised to the model.
# These only steer the model toward well-formed calls; the arguments are
# re-validated server-side in `_dispatch_tool_call` before anything runs.
TOOL_SPECS: Final[list[dict[str, Any]]] = [
    # One-shot helper: create + start + exec + clean up in a single call.
    {
        "type": "function",
        "function": {
            "name": "vm_run",
            "description": "Create, start, execute, and clean up an ephemeral VM in one call.",
            "parameters": {
                "type": "object",
                "properties": {
                    "environment": {"type": "string"},
                    "command": {"type": "string"},
                    "vcpu_count": {"type": "integer"},
                    "mem_mib": {"type": "integer"},
                    "timeout_seconds": {"type": "integer"},
                    "ttl_seconds": {"type": "integer"},
                    "network": {"type": "boolean"},
                    "allow_host_compat": {"type": "boolean"},
                },
                "required": ["environment", "command"],
                "additionalProperties": False,
            },
        },
    },
    # Discovery helper: enumerate the curated guest environments.
    {
        "type": "function",
        "function": {
            "name": "vm_list_environments",
            "description": "List curated Linux environments and installation status.",
            "parameters": {
                "type": "object",
                "properties": {},
                "additionalProperties": False,
            },
        },
    },
    # Lifecycle step 1: allocate a VM (does not start it).
    {
        "type": "function",
        "function": {
            "name": "vm_create",
            "description": "Create an ephemeral VM with optional resource sizing.",
            "parameters": {
                "type": "object",
                "properties": {
                    "environment": {"type": "string"},
                    "vcpu_count": {"type": "integer"},
                    "mem_mib": {"type": "integer"},
                    "ttl_seconds": {"type": "integer"},
                    "network": {"type": "boolean"},
                    "allow_host_compat": {"type": "boolean"},
                },
                "required": ["environment"],
                "additionalProperties": False,
            },
        },
    },
    # Lifecycle step 2: boot a previously created VM.
    {
        "type": "function",
        "function": {
            "name": "vm_start",
            "description": "Start a VM before command execution.",
            "parameters": {
                "type": "object",
                "properties": {"vm_id": {"type": "string"}},
                "required": ["vm_id"],
                "additionalProperties": False,
            },
        },
    },
    # Lifecycle step 3: run one command; the VM is cleaned up afterwards.
    {
        "type": "function",
        "function": {
            "name": "vm_exec",
            "description": "Run one non-interactive command inside the VM and auto-clean it.",
            "parameters": {
                "type": "object",
                "properties": {
                    "vm_id": {"type": "string"},
                    "command": {"type": "string"},
                    "timeout_seconds": {"type": "integer"},
                },
                "required": ["vm_id", "command"],
                "additionalProperties": False,
            },
        },
    },
    # Read-only inspection of a VM's state and metadata.
    {
        "type": "function",
        "function": {
            "name": "vm_status",
            "description": "Read current VM status and metadata.",
            "parameters": {
                "type": "object",
                "properties": {"vm_id": {"type": "string"}},
                "required": ["vm_id"],
                "additionalProperties": False,
            },
        },
    },
]
|
|
|
|
|
|
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST *payload* to the chat-completions endpoint and return the decoded JSON.

    Raises RuntimeError when the HTTP call fails (e.g. Ollama not running)
    and TypeError when the response body is not a JSON object.
    """
    url = base_url.rstrip("/") + "/chat/completions"
    encoded = json.dumps(payload).encode("utf-8")
    http_request = urllib.request.Request(
        url,
        data=encoded,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(http_request, timeout=90) as reply:
            raw_text = reply.read().decode("utf-8")
    except urllib.error.URLError as exc:
        # URLError also covers HTTPError, so any transport-level failure lands here.
        raise RuntimeError(
            "failed to call Ollama. Ensure `ollama serve` is running and the model is available."
        ) from exc
    decoded = json.loads(raw_text)
    if isinstance(decoded, dict):
        return cast(dict[str, Any], decoded)
    raise TypeError("unexpected Ollama response shape")
|
|
|
|
|
|
def _extract_message(response: dict[str, Any]) -> dict[str, Any]:
|
|
choices = response.get("choices")
|
|
if not isinstance(choices, list) or not choices:
|
|
raise RuntimeError("Ollama response did not contain completion choices")
|
|
first = choices[0]
|
|
if not isinstance(first, dict):
|
|
raise RuntimeError("unexpected completion choice format")
|
|
message = first.get("message")
|
|
if not isinstance(message, dict):
|
|
raise RuntimeError("completion choice did not contain a message")
|
|
return cast(dict[str, Any], message)
|
|
|
|
|
|
def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
|
|
if raw_arguments is None:
|
|
return {}
|
|
if isinstance(raw_arguments, dict):
|
|
return cast(dict[str, Any], raw_arguments)
|
|
if isinstance(raw_arguments, str):
|
|
if raw_arguments.strip() == "":
|
|
return {}
|
|
parsed = json.loads(raw_arguments)
|
|
if not isinstance(parsed, dict):
|
|
raise TypeError("tool arguments must decode to an object")
|
|
return cast(dict[str, Any], parsed)
|
|
raise TypeError("tool arguments must be a dictionary or JSON object string")
|
|
|
|
|
|
def _require_str(arguments: dict[str, Any], key: str) -> str:
|
|
value = arguments.get(key)
|
|
if not isinstance(value, str) or value == "":
|
|
raise ValueError(f"{key} must be a non-empty string")
|
|
return value
|
|
|
|
|
|
def _require_int(arguments: dict[str, Any], key: str) -> int:
|
|
value = arguments.get(key)
|
|
if isinstance(value, bool):
|
|
raise ValueError(f"{key} must be an integer")
|
|
if isinstance(value, int):
|
|
return value
|
|
if isinstance(value, str):
|
|
normalized = value.strip()
|
|
if normalized.isdigit():
|
|
return int(normalized)
|
|
raise ValueError(f"{key} must be an integer")
|
|
|
|
|
|
def _optional_int(arguments: dict[str, Any], key: str, *, default: int) -> int:
    """Return the validated int at *key*, or *default* when the key is absent."""
    return _require_int(arguments, key) if key in arguments else default
|
|
|
|
|
|
def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool:
|
|
value = arguments.get(key, default)
|
|
if isinstance(value, bool):
|
|
return value
|
|
if isinstance(value, int) and value in {0, 1}:
|
|
return bool(value)
|
|
if isinstance(value, str):
|
|
normalized = value.strip().lower()
|
|
if normalized in {"true", "1", "yes", "on"}:
|
|
return True
|
|
if normalized in {"false", "0", "no", "off"}:
|
|
return False
|
|
raise ValueError(f"{key} must be a boolean")
|
|
|
|
|
|
def _dispatch_tool_call(
    pyro: Pyro, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]:
    """Route a model-requested tool call to the matching Pyro API method.

    Model-supplied arguments are validated/coerced via the `_require_*` /
    `_optional_int` helpers rather than trusted. The original built one-key
    temp dicts (`_require_int({"ttl_seconds": ttl}, "ttl_seconds")`) to apply
    defaults; that is exactly what `_optional_int` already does, so the
    convoluted pattern is replaced with it (behavior is identical).

    Raises:
        RuntimeError: for a tool name not in the advertised TOOL_SPECS.
        ValueError: for malformed argument values.
    """
    if tool_name == "vm_run":
        return pyro.run_in_vm(
            environment=_require_str(arguments, "environment"),
            command=_require_str(arguments, "command"),
            vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
            mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
            timeout_seconds=_optional_int(
                arguments, "timeout_seconds", default=DEFAULT_TIMEOUT_SECONDS
            ),
            ttl_seconds=_optional_int(arguments, "ttl_seconds", default=DEFAULT_TTL_SECONDS),
            network=_require_bool(arguments, "network", default=False),
            allow_host_compat=_require_bool(
                arguments,
                "allow_host_compat",
                default=DEFAULT_ALLOW_HOST_COMPAT,
            ),
        )
    if tool_name == "vm_list_environments":
        return {"environments": pyro.list_environments()}
    if tool_name == "vm_create":
        return pyro.create_vm(
            environment=_require_str(arguments, "environment"),
            vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT),
            mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB),
            ttl_seconds=_optional_int(arguments, "ttl_seconds", default=DEFAULT_TTL_SECONDS),
            network=_require_bool(arguments, "network", default=False),
            allow_host_compat=_require_bool(
                arguments,
                "allow_host_compat",
                default=DEFAULT_ALLOW_HOST_COMPAT,
            ),
        )
    if tool_name == "vm_start":
        return pyro.start_vm(_require_str(arguments, "vm_id"))
    if tool_name == "vm_exec":
        vm_id = _require_str(arguments, "vm_id")
        # Models frequently skip vm_start; auto-start a not-yet-running VM.
        state = pyro.status_vm(vm_id).get("state")
        if state in {"created", "stopped"}:
            pyro.start_vm(vm_id)
        return pyro.exec_vm(
            vm_id,
            command=_require_str(arguments, "command"),
            timeout_seconds=_optional_int(arguments, "timeout_seconds", default=30),
        )
    if tool_name == "vm_status":
        return pyro.status_vm(_require_str(arguments, "vm_id"))
    raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
|
|
|
|
|
|
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
|
|
payload = {
|
|
"ok": False,
|
|
"tool_name": tool_name,
|
|
"arguments": arguments,
|
|
"error_type": exc.__class__.__name__,
|
|
"error": str(exc),
|
|
}
|
|
error_text = str(exc)
|
|
if "must be a boolean" in error_text:
|
|
payload["hint"] = "Use JSON booleans true or false, not quoted strings."
|
|
if (
|
|
"environment must be a non-empty string" in error_text
|
|
and isinstance(arguments.get("profile"), str)
|
|
):
|
|
payload["hint"] = "Use `environment` instead of `profile`."
|
|
return payload
|
|
|
|
|
|
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
    """Run the network-proof command directly when the model never managed to."""
    fallback_kwargs: dict[str, Any] = {
        "environment": "debian:12",
        "command": NETWORK_PROOF_COMMAND,
        "vcpu_count": DEFAULT_VCPU_COUNT,
        "mem_mib": DEFAULT_MEM_MIB,
        "timeout_seconds": 60,
        "ttl_seconds": DEFAULT_TTL_SECONDS,
        "network": True,
    }
    return pyro.run_in_vm(**fallback_kwargs)
|
|
|
|
|
|
def _is_vm_id_placeholder(value: str) -> bool:
|
|
normalized = value.strip().lower()
|
|
if normalized in {
|
|
"vm_id_returned_by_vm_create",
|
|
"<vm_id_returned_by_vm_create>",
|
|
"{vm_id_returned_by_vm_create}",
|
|
}:
|
|
return True
|
|
return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized
|
|
|
|
|
|
def _normalize_tool_arguments(
    tool_name: str,
    arguments: dict[str, Any],
    *,
    last_created_vm_id: str | None,
) -> tuple[dict[str, Any], str | None]:
    """Repair two recurring model mistakes before dispatch.

    1. Rename the legacy ``profile`` key to ``environment`` for
       vm_run/vm_create calls (only when ``environment`` is absent).
    2. Substitute the most recently created vm_id when the model passed a
       placeholder string instead of the real id.

    Returns the (possibly rewritten) arguments and the substituted vm_id,
    or None when no substitution happened. The input dict is not mutated.
    """
    fixed = dict(arguments)
    substituted_vm_id: str | None = None
    if tool_name in ("vm_run", "vm_create"):
        profile_value = fixed.get("profile")
        if isinstance(profile_value, str) and "environment" not in fixed:
            fixed["environment"] = profile_value
            fixed.pop("profile", None)
    if last_created_vm_id is not None and tool_name in ("vm_start", "vm_exec", "vm_status"):
        candidate = fixed.get("vm_id")
        if isinstance(candidate, str) and _is_vm_id_placeholder(candidate):
            fixed["vm_id"] = last_created_vm_id
            substituted_vm_id = last_created_vm_id
    return fixed, substituted_vm_id
|
|
|
|
|
|
def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
|
|
role = str(message.get("role", "unknown"))
|
|
if not verbose:
|
|
return role
|
|
content = str(message.get("content") or "").strip()
|
|
if content == "":
|
|
return f"{role}: <empty>"
|
|
return f"{role}: {content}"
|
|
|
|
|
|
def _serialize_log_value(value: Any) -> str:
|
|
return json.dumps(value, sort_keys=True, separators=(",", ":"))
|
|
|
|
|
|
def run_ollama_tool_demo(
    base_url: str = DEFAULT_OLLAMA_BASE_URL,
    model: str = DEFAULT_OLLAMA_MODEL,
    *,
    pyro: Pyro | None = None,
    strict: bool = True,
    verbose: bool = False,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Drive an Ollama model through the VM lifecycle tools to prove network access.

    The model is asked to run ``NETWORK_PROOF_COMMAND`` — which prints the
    HTTP status of a request to example.com — inside an ephemeral VM, using
    up to ``MAX_TOOL_ROUNDS`` assistant/tool round-trips.

    Args:
        base_url: OpenAI-compatible Ollama endpoint.
        model: Model name to drive.
        pyro: Optional pre-built Pyro client; a fresh one is created otherwise.
        strict: When True, raise if the model never completes a successful
            vm_run/vm_exec; when False, fall back to a direct lifecycle call.
        verbose: Include message and argument contents in log output.
        log: Optional sink for progress lines (defaults to discarding them).

    Returns:
        Dict with the model name, the command, the execution result, all tool
        events, the model's final text response, and whether the fallback ran.

    Raises:
        RuntimeError: on malformed model responses, exhausted rounds, or a
            failed/unconvincing command execution.
    """
    emit = log or (lambda _: None)
    emit(f"[ollama] starting tool demo with model={model}")
    pyro_client = pyro or Pyro()
    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                "Use the VM tools to prove outbound internet access in an ephemeral VM.\n"
                "Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n"
                "Use environment `debian:12`, choose adequate vCPU/memory, "
                "and set `network` to true.\n"
                f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n"
                # BUGFIX: the proof command prints an HTTP status code, not
                # `true`; the old wording (and the success check below) were
                # leftovers from a git-clone variant of this demo.
                "Success means the command prints an HTTP status such as `200`.\n"
                "If a tool returns an error, fix arguments and retry."
            ),
        }
    ]
    tool_events: list[dict[str, Any]] = []
    final_response = ""
    last_created_vm_id: str | None = None

    for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
        emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}")
        response = _post_chat_completion(
            base_url,
            {
                "model": model,
                "messages": messages,
                "tools": TOOL_SPECS,
                "tool_choice": "auto",
                "temperature": 0,
            },
        )
        assistant_message = _extract_message(response)
        emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}")
        tool_calls = assistant_message.get("tool_calls")
        if not isinstance(tool_calls, list) or not tool_calls:
            # No tool calls means the model considers the task finished.
            final_response = str(assistant_message.get("content") or "")
            emit("[ollama] no tool calls returned; stopping loop")
            break

        messages.append(
            {
                "role": "assistant",
                "content": str(assistant_message.get("content") or ""),
                "tool_calls": tool_calls,
            }
        )

        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                raise RuntimeError("invalid tool call entry returned by model")
            call_id = tool_call.get("id")
            if not isinstance(call_id, str) or call_id == "":
                raise RuntimeError("tool call did not provide a valid call id")
            function = tool_call.get("function")
            if not isinstance(function, dict):
                raise RuntimeError("tool call did not include function metadata")
            tool_name = function.get("name")
            if not isinstance(tool_name, str):
                raise RuntimeError("tool call function name is invalid")
            arguments = _parse_tool_arguments(function.get("arguments"))
            if verbose:
                emit(f"[model] tool_call {tool_name} args={arguments}")
            else:
                emit(f"[model] tool_call {tool_name}")
            arguments, normalized_vm_id = _normalize_tool_arguments(
                tool_name,
                arguments,
                last_created_vm_id=last_created_vm_id,
            )
            if normalized_vm_id is not None:
                if verbose:
                    emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
                else:
                    emit("[tool] resolved vm_id placeholder")
            if verbose:
                emit(f"[tool] calling {tool_name} with args={arguments}")
            else:
                emit(f"[tool] calling {tool_name}")
            try:
                result = _dispatch_tool_call(pyro_client, tool_name, arguments)
                success = True
                emit(f"[tool] {tool_name} succeeded")
                if tool_name == "vm_create":
                    # Remember the real vm_id so later placeholder args can be fixed.
                    created_vm_id = result.get("vm_id")
                    if isinstance(created_vm_id, str) and created_vm_id != "":
                        last_created_vm_id = created_vm_id
            except Exception as exc:  # noqa: BLE001 - errors are surfaced to the model
                result = _format_tool_error(tool_name, arguments, exc)
                success = False
                emit(
                    f"[tool] {tool_name} failed: {exc} "
                    f"args={_serialize_log_value(arguments)}"
                )
            if verbose:
                emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
            else:
                emit(f"[tool] result {tool_name}")
            tool_events.append(
                {
                    "tool_name": tool_name,
                    "arguments": arguments,
                    "result": result,
                    "success": success,
                }
            )
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "content": json.dumps(result, sort_keys=True),
                }
            )
    else:
        raise RuntimeError("tool-calling loop exceeded maximum rounds")

    # The most recent successful execution event decides the outcome.
    exec_event = next(
        (
            event
            for event in reversed(tool_events)
            if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success"))
        ),
        None,
    )
    fallback_used = False
    if exec_event is None:
        if strict:
            raise RuntimeError("demo did not execute a successful vm_run or vm_exec")
        emit("[fallback] model did not complete vm_run; running direct lifecycle command")
        exec_result = _run_direct_lifecycle_fallback(pyro_client)
        fallback_used = True
        tool_events.append(
            {
                "tool_name": "vm_run_fallback",
                "arguments": {"command": NETWORK_PROOF_COMMAND},
                "result": exec_result,
                "success": True,
            }
        )
    else:
        exec_result = exec_event["result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("vm_exec result shape is invalid")
    if int(exec_result.get("exit_code", -1)) != 0:
        raise RuntimeError("vm_exec failed; expected exit_code=0")
    # BUGFIX: NETWORK_PROOF_COMMAND prints `urlopen(...).status`, i.e. a
    # numeric HTTP status. The old check compared stdout against the literal
    # "true" and therefore rejected every successful run.
    stdout_text = str(exec_result.get("stdout", "")).strip()
    if not (stdout_text.isdigit() and 200 <= int(stdout_text) < 300):
        raise RuntimeError("vm_exec output did not report a 2xx HTTP status")
    emit("[done] command execution succeeded")

    return {
        "model": model,
        "command": NETWORK_PROOF_COMMAND,
        "exec_result": exec_result,
        "tool_events": tool_events,
        "final_response": final_response,
        "fallback_used": fallback_used,
    }
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Assemble the CLI argument parser for the demo entrypoint."""
    parser = argparse.ArgumentParser(
        description="Run Ollama tool-calling demo for ephemeral VMs."
    )
    # Plain string options with module-level defaults.
    for flag, default in (
        ("--base-url", DEFAULT_OLLAMA_BASE_URL),
        ("--model", DEFAULT_OLLAMA_MODEL),
    ):
        parser.add_argument(flag, default=default)
    parser.add_argument("-v", "--verbose", action="store_true")
    return parser
|
|
|
|
|
|
def main() -> None:
    """CLI entrypoint for Ollama tool-calling demo."""
    args = _build_parser().parse_args()
    try:
        result = run_ollama_tool_demo(
            base_url=args.base_url,
            model=args.model,
            verbose=args.verbose,
            log=lambda message: print(message, flush=True),
        )
    except Exception as exc:  # noqa: BLE001
        # Surface any failure as a single [error] line and a non-zero exit.
        print(f"[error] {exc}", flush=True)
        raise SystemExit(1) from exc
    exec_result = result["exec_result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("demo produced invalid execution result")
    exit_code = int(exec_result.get("exit_code", -1))
    used_fallback = bool(result.get("fallback_used"))
    mode = str(exec_result.get("execution_mode", "unknown"))
    summary = (
        f"[summary] exit_code={exit_code} "
        f"fallback_used={used_fallback} "
        f"execution_mode={mode}"
    )
    print(summary, flush=True)
    if args.verbose:
        stdout_text = str(exec_result.get("stdout", "")).strip()
        print(f"[summary] stdout={stdout_text}", flush=True)
|