"""Ollama demo that drives VM lifecycle tools to run an ephemeral command.""" from __future__ import annotations import argparse import json import urllib.error import urllib.request from collections.abc import Callable from typing import Any, Final, cast from pyro_mcp.api import Pyro from pyro_mcp.vm_manager import ( DEFAULT_ALLOW_HOST_COMPAT, DEFAULT_MEM_MIB, DEFAULT_TIMEOUT_SECONDS, DEFAULT_TTL_SECONDS, DEFAULT_VCPU_COUNT, ) __all__ = ["Pyro", "run_ollama_tool_demo"] DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1" DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b" MAX_TOOL_ROUNDS: Final[int] = 12 NETWORK_PROOF_COMMAND: Final[str] = ( 'python3 -c "import urllib.request as u; ' "print(u.urlopen('https://example.com').status)" '"' ) TOOL_SPECS: Final[list[dict[str, Any]]] = [ { "type": "function", "function": { "name": "vm_run", "description": "Create, start, execute, and clean up an ephemeral VM in one call.", "parameters": { "type": "object", "properties": { "environment": {"type": "string"}, "command": {"type": "string"}, "vcpu_count": {"type": "integer"}, "mem_mib": {"type": "integer"}, "timeout_seconds": {"type": "integer"}, "ttl_seconds": {"type": "integer"}, "network": {"type": "boolean"}, "allow_host_compat": {"type": "boolean"}, }, "required": ["environment", "command"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "vm_list_environments", "description": "List curated Linux environments and installation status.", "parameters": { "type": "object", "properties": {}, "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "vm_create", "description": "Create an ephemeral VM with optional resource sizing.", "parameters": { "type": "object", "properties": { "environment": {"type": "string"}, "vcpu_count": {"type": "integer"}, "mem_mib": {"type": "integer"}, "ttl_seconds": {"type": "integer"}, "network": {"type": "boolean"}, "allow_host_compat": {"type": "boolean"}, }, "required": ["environment"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "vm_start", "description": "Start a VM before command execution.", "parameters": { "type": "object", "properties": {"vm_id": {"type": "string"}}, "required": ["vm_id"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "vm_exec", "description": "Run one non-interactive command inside the VM and auto-clean it.", "parameters": { "type": "object", "properties": { "vm_id": {"type": "string"}, "command": {"type": "string"}, "timeout_seconds": {"type": "integer"}, }, "required": ["vm_id", "command"], "additionalProperties": False, }, }, }, { "type": "function", "function": { "name": "vm_status", "description": "Read current VM status and metadata.", "parameters": { "type": "object", "properties": {"vm_id": {"type": "string"}}, "required": ["vm_id"], "additionalProperties": False, }, }, }, ] def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: endpoint = f"{base_url.rstrip('/')}/chat/completions" body = json.dumps(payload).encode("utf-8") request = urllib.request.Request( endpoint, data=body, headers={"Content-Type": "application/json"}, method="POST", ) try: with urllib.request.urlopen(request, timeout=90) as response: response_text = response.read().decode("utf-8") except urllib.error.URLError as exc: raise RuntimeError( "failed to call Ollama. Ensure `ollama serve` is running and the model is available." ) from exc parsed = json.loads(response_text) if not isinstance(parsed, dict): raise TypeError("unexpected Ollama response shape") return cast(dict[str, Any], parsed) def _extract_message(response: dict[str, Any]) -> dict[str, Any]: choices = response.get("choices") if not isinstance(choices, list) or not choices: raise RuntimeError("Ollama response did not contain completion choices") first = choices[0] if not isinstance(first, dict): raise RuntimeError("unexpected completion choice format") message = first.get("message") if not isinstance(message, dict): raise RuntimeError("completion choice did not contain a message") return cast(dict[str, Any], message) def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]: if raw_arguments is None: return {} if isinstance(raw_arguments, dict): return cast(dict[str, Any], raw_arguments) if isinstance(raw_arguments, str): if raw_arguments.strip() == "": return {} parsed = json.loads(raw_arguments) if not isinstance(parsed, dict): raise TypeError("tool arguments must decode to an object") return cast(dict[str, Any], parsed) raise TypeError("tool arguments must be a dictionary or JSON object string") def _require_str(arguments: dict[str, Any], key: str) -> str: value = arguments.get(key) if not isinstance(value, str) or value == "": raise ValueError(f"{key} must be a non-empty string") return value def _require_int(arguments: dict[str, Any], key: str) -> int: value = arguments.get(key) if isinstance(value, bool): raise ValueError(f"{key} must be an integer") if isinstance(value, int): return value if isinstance(value, str): normalized = value.strip() if normalized.isdigit(): return int(normalized) raise ValueError(f"{key} must be an integer") def _optional_int(arguments: dict[str, Any], key: str, *, default: int) -> int: if key not in arguments: return default return _require_int(arguments, key) def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) -> bool: value = arguments.get(key, default) if isinstance(value, bool): return value if isinstance(value, int) and value in {0, 1}: return bool(value) if isinstance(value, str): normalized = value.strip().lower() if normalized in {"true", "1", "yes", "on"}: return True if normalized in {"false", "0", "no", "off"}: return False raise ValueError(f"{key} must be a boolean") def _dispatch_tool_call( pyro: Pyro, tool_name: str, arguments: dict[str, Any] ) -> dict[str, Any]: if tool_name == "vm_run": ttl_seconds = arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS) timeout_seconds = arguments.get("timeout_seconds", DEFAULT_TIMEOUT_SECONDS) return pyro.run_in_vm( environment=_require_str(arguments, "environment"), command=_require_str(arguments, "command"), vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT), mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB), timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"), ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"), network=_require_bool(arguments, "network", default=False), allow_host_compat=_require_bool( arguments, "allow_host_compat", default=DEFAULT_ALLOW_HOST_COMPAT, ), ) if tool_name == "vm_list_environments": return {"environments": pyro.list_environments()} if tool_name == "vm_create": ttl_seconds = arguments.get("ttl_seconds", DEFAULT_TTL_SECONDS) return pyro.create_vm( environment=_require_str(arguments, "environment"), vcpu_count=_optional_int(arguments, "vcpu_count", default=DEFAULT_VCPU_COUNT), mem_mib=_optional_int(arguments, "mem_mib", default=DEFAULT_MEM_MIB), ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"), network=_require_bool(arguments, "network", default=False), allow_host_compat=_require_bool( arguments, "allow_host_compat", default=DEFAULT_ALLOW_HOST_COMPAT, ), ) if tool_name == "vm_start": return pyro.start_vm(_require_str(arguments, "vm_id")) if tool_name == "vm_exec": timeout_seconds = arguments.get("timeout_seconds", 30) vm_id = _require_str(arguments, "vm_id") status = pyro.status_vm(vm_id) state = status.get("state") if state in {"created", "stopped"}: pyro.start_vm(vm_id) return pyro.exec_vm( vm_id, command=_require_str(arguments, "command"), timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"), ) if tool_name == "vm_status": return pyro.status_vm(_require_str(arguments, "vm_id")) raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}") def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]: payload = { "ok": False, "tool_name": tool_name, "arguments": arguments, "error_type": exc.__class__.__name__, "error": str(exc), } error_text = str(exc) if "must be a boolean" in error_text: payload["hint"] = "Use JSON booleans true or false, not quoted strings." if ( "environment must be a non-empty string" in error_text and isinstance(arguments.get("profile"), str) ): payload["hint"] = "Use `environment` instead of `profile`." return payload def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]: return pyro.run_in_vm( environment="debian:12", command=NETWORK_PROOF_COMMAND, vcpu_count=DEFAULT_VCPU_COUNT, mem_mib=DEFAULT_MEM_MIB, timeout_seconds=60, ttl_seconds=DEFAULT_TTL_SECONDS, network=True, ) def _is_vm_id_placeholder(value: str) -> bool: normalized = value.strip().lower() if normalized in { "vm_id_returned_by_vm_create", "", "{vm_id_returned_by_vm_create}", }: return True return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized def _normalize_tool_arguments( tool_name: str, arguments: dict[str, Any], *, last_created_vm_id: str | None, ) -> tuple[dict[str, Any], str | None]: normalized_arguments = dict(arguments) normalized_vm_id: str | None = None if tool_name in {"vm_run", "vm_create"}: legacy_profile = normalized_arguments.get("profile") if "environment" not in normalized_arguments and isinstance(legacy_profile, str): normalized_arguments["environment"] = legacy_profile normalized_arguments.pop("profile", None) if tool_name in {"vm_start", "vm_exec", "vm_status"} and last_created_vm_id is not None: vm_id = normalized_arguments.get("vm_id") if isinstance(vm_id, str) and _is_vm_id_placeholder(vm_id): normalized_arguments["vm_id"] = last_created_vm_id normalized_vm_id = last_created_vm_id return normalized_arguments, normalized_vm_id def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str: role = str(message.get("role", "unknown")) if not verbose: return role content = str(message.get("content") or "").strip() if content == "": return f"{role}: " return f"{role}: {content}" def _serialize_log_value(value: Any) -> str: return json.dumps(value, sort_keys=True, separators=(",", ":")) def run_ollama_tool_demo( base_url: str = DEFAULT_OLLAMA_BASE_URL, model: str = DEFAULT_OLLAMA_MODEL, *, pyro: Pyro | None = None, strict: bool = True, verbose: bool = False, log: Callable[[str], None] | None = None, ) -> dict[str, Any]: """Ask Ollama to run git version check in an ephemeral VM through lifecycle tools.""" emit = log or (lambda _: None) emit(f"[ollama] starting tool demo with model={model}") pyro_client = pyro or Pyro() messages: list[dict[str, Any]] = [ { "role": "user", "content": ( "Use the VM tools to prove outbound internet access in an ephemeral VM.\n" "Prefer `vm_run` unless a lower-level lifecycle step is strictly necessary.\n" "Use environment `debian:12`, choose adequate vCPU/memory, " "and set `network` to true.\n" f"Run this exact command: `{NETWORK_PROOF_COMMAND}`.\n" f"Success means the clone completes and the command prints `true`.\n" "If a tool returns an error, fix arguments and retry." ), } ] tool_events: list[dict[str, Any]] = [] final_response = "" last_created_vm_id: str | None = None for _round_index in range(1, MAX_TOOL_ROUNDS + 1): emit(f"[model] input {_summarize_message_for_log(messages[-1], verbose=verbose)}") response = _post_chat_completion( base_url, { "model": model, "messages": messages, "tools": TOOL_SPECS, "tool_choice": "auto", "temperature": 0, }, ) assistant_message = _extract_message(response) emit(f"[model] output {_summarize_message_for_log(assistant_message, verbose=verbose)}") tool_calls = assistant_message.get("tool_calls") if not isinstance(tool_calls, list) or not tool_calls: final_response = str(assistant_message.get("content") or "") emit("[ollama] no tool calls returned; stopping loop") break messages.append( { "role": "assistant", "content": str(assistant_message.get("content") or ""), "tool_calls": tool_calls, } ) for tool_call in tool_calls: if not isinstance(tool_call, dict): raise RuntimeError("invalid tool call entry returned by model") call_id = tool_call.get("id") if not isinstance(call_id, str) or call_id == "": raise RuntimeError("tool call did not provide a valid call id") function = tool_call.get("function") if not isinstance(function, dict): raise RuntimeError("tool call did not include function metadata") tool_name = function.get("name") if not isinstance(tool_name, str): raise RuntimeError("tool call function name is invalid") arguments = _parse_tool_arguments(function.get("arguments")) if verbose: emit(f"[model] tool_call {tool_name} args={arguments}") else: emit(f"[model] tool_call {tool_name}") arguments, normalized_vm_id = _normalize_tool_arguments( tool_name, arguments, last_created_vm_id=last_created_vm_id, ) if normalized_vm_id is not None: if verbose: emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}") else: emit("[tool] resolved vm_id placeholder") if verbose: emit(f"[tool] calling {tool_name} with args={arguments}") else: emit(f"[tool] calling {tool_name}") try: result = _dispatch_tool_call(pyro_client, tool_name, arguments) success = True emit(f"[tool] {tool_name} succeeded") if tool_name == "vm_create": created_vm_id = result.get("vm_id") if isinstance(created_vm_id, str) and created_vm_id != "": last_created_vm_id = created_vm_id except Exception as exc: # noqa: BLE001 result = _format_tool_error(tool_name, arguments, exc) success = False emit( f"[tool] {tool_name} failed: {exc} " f"args={_serialize_log_value(arguments)}" ) if verbose: emit(f"[tool] result {tool_name} {_serialize_log_value(result)}") else: emit(f"[tool] result {tool_name}") tool_events.append( { "tool_name": tool_name, "arguments": arguments, "result": result, "success": success, } ) messages.append( { "role": "tool", "tool_call_id": call_id, "name": tool_name, "content": json.dumps(result, sort_keys=True), } ) else: raise RuntimeError("tool-calling loop exceeded maximum rounds") exec_event = next( ( event for event in reversed(tool_events) if event.get("tool_name") in {"vm_exec", "vm_run"} and bool(event.get("success")) ), None, ) fallback_used = False if exec_event is None: if strict: raise RuntimeError("demo did not execute a successful vm_run or vm_exec") emit("[fallback] model did not complete vm_run; running direct lifecycle command") exec_result = _run_direct_lifecycle_fallback(pyro_client) fallback_used = True tool_events.append( { "tool_name": "vm_run_fallback", "arguments": {"command": NETWORK_PROOF_COMMAND}, "result": exec_result, "success": True, } ) else: exec_result = exec_event["result"] if not isinstance(exec_result, dict): raise RuntimeError("vm_exec result shape is invalid") if int(exec_result.get("exit_code", -1)) != 0: raise RuntimeError("vm_exec failed; expected exit_code=0") if str(exec_result.get("stdout", "")).strip() != "true": raise RuntimeError("vm_exec output did not confirm repository clone success") emit("[done] command execution succeeded") return { "model": model, "command": NETWORK_PROOF_COMMAND, "exec_result": exec_result, "tool_events": tool_events, "final_response": final_response, "fallback_used": fallback_used, } def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.") parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL) parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL) parser.add_argument("-v", "--verbose", action="store_true") return parser def main() -> None: """CLI entrypoint for Ollama tool-calling demo.""" args = _build_parser().parse_args() try: result = run_ollama_tool_demo( base_url=args.base_url, model=args.model, verbose=args.verbose, log=lambda message: print(message, flush=True), ) except Exception as exc: # noqa: BLE001 print(f"[error] {exc}", flush=True) raise SystemExit(1) from exc exec_result = result["exec_result"] if not isinstance(exec_result, dict): raise RuntimeError("demo produced invalid execution result") print( f"[summary] exit_code={int(exec_result.get('exit_code', -1))} " f"fallback_used={bool(result.get('fallback_used'))} " f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}", flush=True, ) if args.verbose: print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)