Bundle firecracker runtime and switch ollama demo to live logs
This commit is contained in:
parent
ef0ddeaa11
commit
65f7c0d262
26 changed files with 1896 additions and 408 deletions
|
|
@ -1,32 +1,97 @@
|
|||
"""Ollama chat-completions demo that triggers `hello_static` tool usage."""
|
||||
"""Ollama demo that drives VM lifecycle tools to run an ephemeral command."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Final, cast
|
||||
|
||||
from pyro_mcp.demo import run_demo
|
||||
from pyro_mcp.vm_manager import VmManager
|
||||
|
||||
__all__ = ["VmManager", "run_ollama_tool_demo"]
|
||||
|
||||
# Base URL of the chat-completions endpoint; the default targets a local server on port 11434.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# NOTE(review): Ollama library tags are usually spelled like `llama3.2:3b`;
# confirm that this exact tag exists on the target machine.
DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b"
# Name of the single static tool; still referenced by other lines in this view.
TOOL_NAME: Final[str] = "hello_static"
# Upper bound on completion/tool round-trips performed by run_ollama_tool_demo.
MAX_TOOL_ROUNDS: Final[int] = 12

# Single-tool spec for the static demo tool.
# NOTE(review): appears to belong to the previous revision of this module; kept
# because it is still referenced in this view — confirm whether it can be removed.
TOOL_SPEC: Final[dict[str, Any]] = {
    "type": "function",
    "function": {
        "name": TOOL_NAME,
        "description": "Returns a deterministic static payload from pyro_mcp.",
        "parameters": {
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        },
    },
}

# Function-tool specs advertised to the model, one entry per VM lifecycle tool
# handled by _dispatch_tool_call.
TOOL_SPECS: Final[list[dict[str, Any]]] = [
    {
        "type": "function",
        "function": {
            "name": "vm_list_profiles",
            "description": "List standard VM environment profiles.",
            "parameters": {
                "type": "object",
                "properties": {},
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "vm_create",
            "description": "Create an ephemeral VM with explicit vCPU and memory sizing.",
            "parameters": {
                "type": "object",
                "properties": {
                    "profile": {"type": "string"},
                    "vcpu_count": {"type": "integer"},
                    "mem_mib": {"type": "integer"},
                    "ttl_seconds": {"type": "integer"},
                },
                "required": ["profile", "vcpu_count", "mem_mib"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "vm_start",
            "description": "Start a VM before command execution.",
            "parameters": {
                "type": "object",
                "properties": {"vm_id": {"type": "string"}},
                "required": ["vm_id"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "vm_exec",
            "description": "Run one non-interactive command inside the VM and auto-clean it.",
            "parameters": {
                "type": "object",
                "properties": {
                    "vm_id": {"type": "string"},
                    "command": {"type": "string"},
                    "timeout_seconds": {"type": "integer"},
                },
                "required": ["vm_id", "command"],
                "additionalProperties": False,
            },
        },
    },
    {
        "type": "function",
        "function": {
            "name": "vm_status",
            "description": "Read current VM status and metadata.",
            "parameters": {
                "type": "object",
                "properties": {"vm_id": {"type": "string"}},
                "required": ["vm_id"],
                "additionalProperties": False,
            },
        },
    },
]
|
||||
|
||||
|
||||
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
|
|
@ -39,13 +104,12 @@ def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, A
|
|||
method="POST",
|
||||
)
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=60) as response:
|
||||
with urllib.request.urlopen(request, timeout=90) as response:
|
||||
response_text = response.read().decode("utf-8")
|
||||
except urllib.error.URLError as exc:
|
||||
raise RuntimeError(
|
||||
"failed to call Ollama. Ensure `ollama serve` is running and the model is available."
|
||||
) from exc
|
||||
|
||||
parsed = json.loads(response_text)
|
||||
if not isinstance(parsed, dict):
|
||||
raise TypeError("unexpected Ollama response shape")
|
||||
|
|
@ -80,89 +144,205 @@ def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
|
|||
raise TypeError("tool arguments must be a dictionary or JSON object string")
|
||||
|
||||
|
||||
def _require_str(arguments: dict[str, Any], key: str) -> str:
|
||||
value = arguments.get(key)
|
||||
if not isinstance(value, str) or value == "":
|
||||
raise ValueError(f"{key} must be a non-empty string")
|
||||
return value
|
||||
|
||||
|
||||
def _require_int(arguments: dict[str, Any], key: str) -> int:
|
||||
value = arguments.get(key)
|
||||
if not isinstance(value, int):
|
||||
raise ValueError(f"{key} must be an integer")
|
||||
return value
|
||||
|
||||
|
||||
def _dispatch_tool_call(
|
||||
manager: VmManager, tool_name: str, arguments: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if tool_name == "vm_list_profiles":
|
||||
return {"profiles": manager.list_profiles()}
|
||||
if tool_name == "vm_create":
|
||||
return manager.create_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
ttl_seconds=arguments.get("ttl_seconds", 600)
|
||||
if isinstance(arguments.get("ttl_seconds"), int)
|
||||
else 600,
|
||||
)
|
||||
if tool_name == "vm_start":
|
||||
return manager.start_vm(_require_str(arguments, "vm_id"))
|
||||
if tool_name == "vm_exec":
|
||||
return manager.exec_vm(
|
||||
_require_str(arguments, "vm_id"),
|
||||
command=_require_str(arguments, "command"),
|
||||
timeout_seconds=arguments.get("timeout_seconds", 30)
|
||||
if isinstance(arguments.get("timeout_seconds"), int)
|
||||
else 30,
|
||||
)
|
||||
if tool_name == "vm_status":
|
||||
return manager.status_vm(_require_str(arguments, "vm_id"))
|
||||
raise RuntimeError(f"unexpected tool requested by model: {tool_name!r}")
|
||||
|
||||
|
||||
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
|
||||
return {
|
||||
"ok": False,
|
||||
"tool_name": tool_name,
|
||||
"arguments": arguments,
|
||||
"error_type": exc.__class__.__name__,
|
||||
"error": str(exc),
|
||||
}
|
||||
|
||||
|
||||
def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
||||
created = manager.create_vm(profile="debian-git", vcpu_count=1, mem_mib=512, ttl_seconds=600)
|
||||
vm_id = str(created["vm_id"])
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
||||
|
||||
|
||||
def run_ollama_tool_demo(
    base_url: str = DEFAULT_OLLAMA_BASE_URL,
    model: str = DEFAULT_OLLAMA_MODEL,
    *,
    strict: bool = True,
    log: Callable[[str], None] | None = None,
) -> dict[str, Any]:
    """Ask Ollama to run git version check in an ephemeral VM through lifecycle tools.

    The model is prompted to call the lifecycle tools in order; each requested
    tool call is dispatched to a ``VmManager`` and its result (or a formatted
    error) is appended to the conversation, for at most ``MAX_TOOL_ROUNDS``
    rounds.

    Args:
        base_url: Base URL passed to ``_post_chat_completion``.
        model: Model name sent in every completion request.
        strict: When true, raise if the model never completes a successful
            ``vm_exec``. When false, run the command directly instead.
        log: Optional sink for progress messages; messages are dropped when omitted.

    Returns:
        A dict with keys ``model``, ``command``, ``exec_result``,
        ``tool_events``, ``final_response`` and ``fallback_used``.

    Raises:
        RuntimeError: If a tool call entry is malformed, the loop runs out of
            rounds, no successful ``vm_exec`` happened in strict mode, or the
            execution result is not a dict with ``exit_code`` 0 and a stdout
            containing ``git version``.
    """
    emit = log or (lambda _: None)
    emit(f"[ollama] starting tool demo with model={model}")
    manager = VmManager()
    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                "Use the lifecycle tools to run `git --version` in an ephemeral VM.\n"
                "Required order: vm_list_profiles -> vm_create -> vm_start -> vm_exec.\n"
                "Use profile `debian-git`, choose adequate vCPU/memory, and pass the `vm_id` "
                "returned by vm_create into vm_start/vm_exec.\n"
                "If a tool returns an error, fix arguments and retry."
            ),
        }
    ]
    tool_events: list[dict[str, Any]] = []
    final_response = ""

    for round_index in range(1, MAX_TOOL_ROUNDS + 1):
        emit(f"[ollama] round {round_index}: requesting completion")
        response = _post_chat_completion(
            base_url,
            {
                "model": model,
                "messages": messages,
                "tools": TOOL_SPECS,
                "tool_choice": "auto",
                "temperature": 0,
            },
        )
        assistant_message = _extract_message(response)
        tool_calls = assistant_message.get("tool_calls")
        if not isinstance(tool_calls, list) or not tool_calls:
            # A reply without tool calls is treated as the model's final answer.
            final_response = str(assistant_message.get("content") or "")
            emit("[ollama] no tool calls returned; stopping loop")
            break

        # Echo the assistant turn back so the tool results below have a parent message.
        messages.append(
            {
                "role": "assistant",
                "content": str(assistant_message.get("content") or ""),
                "tool_calls": tool_calls,
            }
        )

        for tool_call in tool_calls:
            if not isinstance(tool_call, dict):
                raise RuntimeError("invalid tool call entry returned by model")
            call_id = tool_call.get("id")
            if not isinstance(call_id, str) or call_id == "":
                raise RuntimeError("tool call did not provide a valid call id")
            function = tool_call.get("function")
            if not isinstance(function, dict):
                raise RuntimeError("tool call did not include function metadata")
            tool_name = function.get("name")
            if not isinstance(tool_name, str):
                raise RuntimeError("tool call function name is invalid")
            arguments = _parse_tool_arguments(function.get("arguments"))
            emit(f"[tool] calling {tool_name} with args={arguments}")
            try:
                result = _dispatch_tool_call(manager, tool_name, arguments)
                success = True
                emit(f"[tool] {tool_name} succeeded")
            except Exception as exc:  # noqa: BLE001
                # Tool failures are reported back to the model so it can retry
                # with corrected arguments instead of aborting the demo.
                result = _format_tool_error(tool_name, arguments, exc)
                success = False
                emit(f"[tool] {tool_name} failed: {exc}")
            tool_events.append(
                {
                    "tool_name": tool_name,
                    "arguments": arguments,
                    "result": result,
                    "success": success,
                }
            )
            messages.append(
                {
                    "role": "tool",
                    "tool_call_id": call_id,
                    "name": tool_name,
                    "content": json.dumps(result, sort_keys=True),
                }
            )
    else:
        # Only reached when every round produced tool calls and none ended the loop.
        raise RuntimeError("tool-calling loop exceeded maximum rounds")

    # Use the most recent successful vm_exec as the command result.
    exec_event = next(
        (
            event
            for event in reversed(tool_events)
            if event.get("tool_name") == "vm_exec" and bool(event.get("success"))
        ),
        None,
    )
    fallback_used = False
    if exec_event is None:
        if strict:
            raise RuntimeError("demo did not execute a successful vm_exec")
        emit("[fallback] model did not complete vm_exec; running direct lifecycle command")
        exec_result = _run_direct_lifecycle_fallback(manager)
        fallback_used = True
        tool_events.append(
            {
                "tool_name": "vm_exec_fallback",
                "arguments": {"command": "git --version"},
                "result": exec_result,
                "success": True,
            }
        )
    else:
        exec_result = exec_event["result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("vm_exec result shape is invalid")
    if int(exec_result.get("exit_code", -1)) != 0:
        raise RuntimeError("vm_exec failed; expected exit_code=0")
    if "git version" not in str(exec_result.get("stdout", "")):
        raise RuntimeError("vm_exec output did not contain `git version`")
    emit("[done] command execution succeeded")

    return {
        "model": model,
        "command": "git --version",
        "exec_result": exec_result,
        "tool_events": tool_events,
        "final_response": final_response,
        "fallback_used": fallback_used,
    }
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser with ``--base-url`` and ``--model`` options.

    Both options default to the module-level Ollama defaults.
    """
    parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for ephemeral VMs.")
    parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
    parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
    return parser
|
||||
|
|
@ -171,5 +351,18 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
def main() -> None:
    """CLI entrypoint for Ollama tool-calling demo.

    Runs the demo in non-strict mode with progress messages printed live, then
    prints a two-line summary of the command's exit code and stdout.

    Raises:
        RuntimeError: If the demo result does not carry a dict ``exec_result``.
    """
    args = _build_parser().parse_args()
    result = run_ollama_tool_demo(
        base_url=args.base_url,
        model=args.model,
        # Non-strict: fall back to a direct run if the model never completes vm_exec.
        strict=False,
        # Flush each message so progress is visible while the demo is running.
        log=lambda message: print(message, flush=True),
    )
    exec_result = result["exec_result"]
    if not isinstance(exec_result, dict):
        raise RuntimeError("demo produced invalid execution result")
    print(
        f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
        f"fallback_used={bool(result.get('fallback_used'))}",
        flush=True,
    )
    print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue