Bundle firecracker runtime and switch ollama demo to live logs

This commit is contained in:
Thales Maciel 2026-03-05 20:20:36 -03:00
parent ef0ddeaa11
commit 65f7c0d262
26 changed files with 1896 additions and 408 deletions

View file

@@ -1,32 +1,97 @@
"""Ollama chat-completions demo that triggers `hello_static` tool usage."""
"""Ollama demo that drives VM lifecycle tools to run an ephemeral command."""
from __future__ import annotations
import argparse
import asyncio
import json
import urllib.error
import urllib.request
from collections.abc import Callable
from typing import Any, Final, cast
from pyro_mcp.demo import run_demo
from pyro_mcp.vm_manager import VmManager
__all__ = ["VmManager", "run_ollama_tool_demo"]
# Base URL of the local Ollama server's OpenAI-compatible API surface.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# NOTE(review): this tag looks malformed — Ollama model tags are normally
# "<name>:<size>" (e.g. "llama3.2:3b"); confirm a model literally named
# "llama:3.2-3b" exists locally before relying on this default.
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
# Tool name used by the hello_static demo path elsewhere in this module.
TOOL_NAME: Final[str] = "hello_static"
# Cap on model<->tool round trips before the tool-calling loop gives up.
MAX_TOOL_ROUNDS: Final[int] = 12
TOOL_SPEC: Final[dict[str, Any]] = {
"type": "function",
"function": {
"name": TOOL_NAME,
"description": "Returns a deterministic static payload from pyro_mcp.",
"parameters": {
"type": "object",
"properties": {},
"additionalProperties": False,
TOOL_SPECS: Final[list[dict[str, Any]]] = [
{
"type": "function",
"function": {
"name": "vm_list_profiles",
"description": "List standard VM environment profiles.",
"parameters": {
"type": "object",
"properties": {},
"additionalProperties": False,
},
},
},
}
{
"type": "function",
"function": {
"name": "vm_create",
"description": "Create an ephemeral VM with explicit vCPU and memory sizing.",
"parameters": {
"type": "object",
"properties": {
"profile": {"type": "string"},
"vcpu_count": {"type": "integer"},
"mem_mib": {"type": "integer"},
"ttl_seconds": {"type": "integer"},
},
"required": ["profile", "vcpu_count", "mem_mib"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_start",
"description": "Start a VM before command execution.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_exec",
"description": "Run one non-interactive command inside the VM and auto-clean it.",
"parameters": {
"type": "object",
"properties": {
"vm_id": {"type": "string"},
"command": {"type": "string"},
"timeout_seconds": {"type": "integer"},
},
"required": ["vm_id", "command"],
"additionalProperties": False,
},
},
},
{
"type": "function",
"function": {
"name": "vm_status",
"description": "Read current VM status and metadata.",
"parameters": {
"type": "object",
"properties": {"vm_id": {"type": "string"}},
"required": ["vm_id"],
"additionalProperties": False,
},
},
},
]
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
@@ -39,13 +104,12 @@ def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, A
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=60) as response:
with urllib.request.urlopen(request, timeout=90) as response:
response_text = response.read().decode("utf-8")
except urllib.error.URLError as exc:
raise RuntimeError(
"failed to call Ollama. Ensure `ollama serve` is running and the model is available."
) from exc
parsed = json.loads(response_text)
if not isinstance(parsed, dict):
raise TypeError("unexpected Ollama response shape")
@@ -80,89 +144,205 @@ def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
raise TypeError("tool arguments must be a dictionary or JSON object string")
def _require_str(arguments: dict[str, Any], key: str) -> str:
value = arguments.get(key)
if not isinstance(value, str) or value == "":
raise ValueError(f"{key} must be a non-empty string")
return value
def _require_int(arguments: dict[str, Any], key: str) -> int:
value = arguments.get(key)
if not isinstance(value, int):
raise ValueError(f"{key} must be an integer")
return value
def _dispatch_tool_call(
manager: VmManager, tool_name: str, arguments: dict[str, Any]
) -> dict[str, Any]:
if tool_name == "vm_list_profiles":
return {"profiles": manager.list_profiles()}
if tool_name == "vm_create":
return manager.create_vm(
profile=_require_str(arguments, "profile"),
vcpu_count=_require_int(arguments, "vcpu_count"),
mem_mib=_require_int(arguments, "mem_mib"),
ttl_seconds=arguments.get("ttl_seconds", 600)
if isinstance(arguments.get("ttl_seconds"), int)
else 600,
)
if tool_name == "vm_start":
return manager.start_vm(_require_str(arguments, "vm_id"))
if tool_name == "vm_exec":
return manager.exec_vm(
_require_str(arguments, "vm_id"),
command=_require_str(arguments, "command"),
timeout_seconds=arguments.get("timeout_seconds", 30)
if isinstance(arguments.get("timeout_seconds"), int)
else 30,
)
if tool_name == "vm_status":
return manager.status_vm(_require_str(arguments, "vm_id"))
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
return {
"ok": False,
"tool_name": tool_name,
"arguments": arguments,
"error_type": exc.__class__.__name__,
"error": str(exc),
}
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
def run_ollama_tool_demo(
base_url: str = DEFAULT_OLLAMA_BASE_URL,
model: str = DEFAULT_OLLAMA_MODEL,
*,
strict: bool = True,
log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
"""Ask Ollama to call the static tool, execute it, and return final model output."""
"""Ask Ollama to run git version check in an ephemeral VM through lifecycle tools."""
emit = log or (lambda _: None)
emit(f"[ollama] starting tool demo with model={model}")
manager = VmManager()
messages: list[dict[str, Any]] = [
{
"role": "user",
"content": (
"Use the hello_static tool and then summarize its payload in one short sentence."
"Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
"Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
"Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
"returned by vm_create into vm_start/vm_exec.\n"
"If a tool returns an error, fix arguments and retry."
),
}
]
first_payload: dict[str, Any] = {
"model": model,
"messages": messages,
"tools": [TOOL_SPEC],
"tool_choice": "auto",
"temperature": 0,
}
first_response = _post_chat_completion(base_url, first_payload)
assistant_message = _extract_message(first_response)
tool_events: list[dict[str, Any]] = []
final_response = ""
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
raise RuntimeError("model did not trigger any tool call")
for round_index in range(1, MAX_TOOL_ROUNDS + 1):
emit(f"[ollama] round {round_index}: requesting completion")
response = _post_chat_completion(
base_url,
{
"model": model,
"messages": messages,
"tools": TOOL_SPECS,
"tool_choice": "auto",
"temperature": 0,
},
)
assistant_message = _extract_message(response)
tool_calls = assistant_message.get("tool_calls")
if not isinstance(tool_calls, list) or not tool_calls:
final_response = str(assistant_message.get("content") or "")
emit("[ollama] no tool calls returned; stopping loop")
break
messages.append(
{
"role": "assistant",
"content": str(assistant_message.get("content") or ""),
"tool_calls": tool_calls,
}
)
tool_payload: dict[str, str] | None = None
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
raise RuntimeError("invalid tool call entry returned by model")
function = tool_call.get("function")
if not isinstance(function, dict):
raise RuntimeError("tool call did not include function metadata")
name = function.get("name")
if name != TOOL_NAME:
raise RuntimeError(f"unexpected tool requested by model: {name!r}")
arguments = _parse_tool_arguments(function.get("arguments"))
if arguments:
raise RuntimeError("hello_static does not accept arguments")
call_id = tool_call.get("id")
if not isinstance(call_id, str) or call_id == "":
raise RuntimeError("tool call did not provide a valid call id")
tool_payload = asyncio.run(run_demo())
messages.append(
{
"role": "tool",
"tool_call_id": call_id,
"name": TOOL_NAME,
"content": json.dumps(tool_payload, sort_keys=True),
"role": "assistant",
"content": str(assistant_message.get("content") or ""),
"tool_calls": tool_calls,
}
)
if tool_payload is None:
raise RuntimeError("tool payload was not generated")
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
raise RuntimeError("invalid tool call entry returned by model")
call_id = tool_call.get("id")
if not isinstance(call_id, str) or call_id == "":
raise RuntimeError("tool call did not provide a valid call id")
function = tool_call.get("function")
if not isinstance(function, dict):
raise RuntimeError("tool call did not include function metadata")
tool_name = function.get("name")
if not isinstance(tool_name, str):
raise RuntimeError("tool call function name is invalid")
arguments = _parse_tool_arguments(function.get("arguments"))
emit(f"[tool] calling {tool_name} with args={arguments}")
try:
result = _dispatch_tool_call(manager, tool_name, arguments)
success = True
emit(f"[tool] {tool_name} succeeded")
except Exception as exc: # noqa: BLE001
result = _format_tool_error(tool_name, arguments, exc)
success = False
emit(f"[tool] {tool_name} failed: {exc}")
tool_events.append(
{
"tool_name": tool_name,
"arguments": arguments,
"result": result,
"success": success,
}
)
messages.append(
{
"role": "tool",
"tool_call_id": call_id,
"name": tool_name,
"content": json.dumps(result, sort_keys=True),
}
)
else:
raise RuntimeError("tool-calling loop exceeded maximum rounds")
second_payload: dict[str, Any] = {
"model": model,
"messages": messages,
"temperature": 0,
}
second_response = _post_chat_completion(base_url, second_payload)
final_message = _extract_message(second_response)
exec_event = next(
(
event
for event in reversed(tool_events)
if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
),
None,
)
fallback_used = False
if exec_event is None:
if strict:
raise RuntimeError("demo did not execute a successful vm_exec")
emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
exec_result = _run_direct_lifecycle_fallback(manager)
fallback_used = True
tool_events.append(
{
"tool_name": "vm_exec_fallback",
"arguments": {"command": "git --version"},
"result": exec_result,
"success": True,
}
)
else:
exec_result = exec_event["result"]
if not isinstance(exec_result, dict):
raise RuntimeError("vm_exec result shape is invalid")
if int(exec_result.get("exit_code", -1)) != 0:
raise RuntimeError("vm_exec failed; expected exit_code=0")
if "git version" not in str(exec_result.get("stdout", "")):
raise RuntimeError("vm_exec output did not contain `git version`")
emit("[done] command execution succeeded")
return {
"model": model,
"tool_name": TOOL_NAME,
"tool_payload": tool_payload,
"final_response": str(final_message.get("content") or ""),
"command": "git --version",
"exec_result": exec_result,
"tool_events": tool_events,
"final_response": final_response,
"fallback_used": fallback_used,
}
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the Ollama VM demo.

    NOTE(review): the original span carried two ``ArgumentParser(...)``
    assignments (stale diff residue); only the newer description is kept.
    """
    parser = argparse.ArgumentParser(
        description="Run Ollama tool-calling demo for ephemeral VMs."
    )
    parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
    parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
    return parser
@@ -171,5 +351,18 @@ def _build_parser() -> argparse.ArgumentParser:
def main() -> None:
    """CLI entrypoint for Ollama tool-calling demo.

    Runs in non-strict mode with live log lines flushed to stdout, then
    prints a two-line summary of the command execution result.

    NOTE(review): the original span interleaved the old entrypoint body
    (plain ``run_ollama_tool_demo`` call + JSON dump) with the new one,
    which would have invoked the demo twice; only the newer body is kept.
    """
    args = _build_parser().parse_args()
    result = run_ollama_tool_demo(
        base_url=args.base_url,
        model=args.model,
        strict=False,
        log=lambda message: print(message, flush=True),
    )
    exec_result = result["exec_result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("demo produced invalid execution result")
    print(
        f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
        f"fallback_used={bool(result.get('fallback_used'))}",
        flush=True,
    )
    print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)