# pyro-mcp/src/pyro_mcp/ollama_demo.py
"""Ollama chat-completions demo that triggers `hello_static` tool usage."""
from __future__ import annotations
import argparse
import asyncio
import json
import urllib.error
import urllib.request
from typing import Any, Final, cast
from pyro_mcp.demo import run_demo
# Ollama's OpenAI-compatible API root as served by `ollama serve`.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# Ollama model references use the `name:tag` form. The previous default
# "llama:3.2-3b" is not a valid library tag and would fail to resolve;
# the Llama 3.2 3B model is published as "llama3.2:3b".
DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b"
# Name of the MCP demo tool the model is asked to invoke.
TOOL_NAME: Final[str] = "hello_static"
# OpenAI-style function-tool specification advertised to the model.
# The tool takes no arguments, hence the empty, closed properties object.
TOOL_SPEC: Final[dict[str, Any]] = {
    "type": "function",
    "function": {
        "name": TOOL_NAME,
        "description": "Returns a deterministic static payload from pyro_mcp.",
        "parameters": {
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        },
    },
}
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
endpoint = f"{base_url.rstrip('/')}/chat/completions"
body = json.dumps(payload).encode("utf-8")
request = urllib.request.Request(
endpoint,
data=body,
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(request, timeout=60) as response:
response_text = response.read().decode("utf-8")
except urllib.error.URLError as exc:
raise RuntimeError(
"failed to call Ollama. Ensure `ollama serve` is running and the model is available."
) from exc
parsed = json.loads(response_text)
if not isinstance(parsed, dict):
raise TypeError("unexpected Ollama response shape")
return cast(dict[str, Any], parsed)
def _extract_message(response: dict[str, Any]) -> dict[str, Any]:
choices = response.get("choices")
if not isinstance(choices, list) or not choices:
raise RuntimeError("Ollama response did not contain completion choices")
first = choices[0]
if not isinstance(first, dict):
raise RuntimeError("unexpected completion choice format")
message = first.get("message")
if not isinstance(message, dict):
raise RuntimeError("completion choice did not contain a message")
return cast(dict[str, Any], message)
def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
if raw_arguments is None:
return {}
if isinstance(raw_arguments, dict):
return cast(dict[str, Any], raw_arguments)
if isinstance(raw_arguments, str):
if raw_arguments.strip() == "":
return {}
parsed = json.loads(raw_arguments)
if not isinstance(parsed, dict):
raise TypeError("tool arguments must decode to an object")
return cast(dict[str, Any], parsed)
raise TypeError("tool arguments must be a dictionary or JSON object string")
def run_ollama_tool_demo(
    base_url: str = DEFAULT_OLLAMA_BASE_URL,
    model: str = DEFAULT_OLLAMA_MODEL,
) -> dict[str, Any]:
    """Ask Ollama to call the static tool, execute it, and return final model output."""
    # Seed the conversation with a prompt that nudges the model toward the tool.
    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                "Use the hello_static tool and then summarize its payload in one short sentence."
            ),
        }
    ]
    # First round trip: advertise the tool and let the model decide to call it.
    first_response = _post_chat_completion(
        base_url,
        {
            "model": model,
            "messages": messages,
            "tools": [TOOL_SPEC],
            "tool_choice": "auto",
            "temperature": 0,
        },
    )
    assistant_message = _extract_message(first_response)
    tool_calls = assistant_message.get("tool_calls")
    if not (isinstance(tool_calls, list) and tool_calls):
        raise RuntimeError("model did not trigger any tool call")
    # Echo the assistant turn (including its tool calls) back into the history.
    messages.append(
        {
            "role": "assistant",
            "content": str(assistant_message.get("content") or ""),
            "tool_calls": tool_calls,
        }
    )
    tool_payload: dict[str, str] | None = None
    for call in tool_calls:
        if not isinstance(call, dict):
            raise RuntimeError("invalid tool call entry returned by model")
        function = call.get("function")
        if not isinstance(function, dict):
            raise RuntimeError("tool call did not include function metadata")
        requested = function.get("name")
        if requested != TOOL_NAME:
            raise RuntimeError(f"unexpected tool requested by model: {requested!r}")
        if _parse_tool_arguments(function.get("arguments")):
            raise RuntimeError("hello_static does not accept arguments")
        call_id = call.get("id")
        if not (isinstance(call_id, str) and call_id):
            raise RuntimeError("tool call did not provide a valid call id")
        # Run the MCP demo tool locally and feed its result back as a tool turn.
        tool_payload = asyncio.run(run_demo())
        messages.append(
            {
                "role": "tool",
                "tool_call_id": call_id,
                "name": TOOL_NAME,
                "content": json.dumps(tool_payload, sort_keys=True),
            }
        )
    if tool_payload is None:
        raise RuntimeError("tool payload was not generated")
    # Second round trip: let the model summarize the tool output.
    second_response = _post_chat_completion(
        base_url,
        {
            "model": model,
            "messages": messages,
            "temperature": 0,
        },
    )
    final_message = _extract_message(second_response)
    return {
        "model": model,
        "tool_name": TOOL_NAME,
        "tool_payload": tool_payload,
        "final_response": str(final_message.get("content") or ""),
    }
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the demo."""
    parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for pyro_mcp.")
    # Both flags fall back to the module-level defaults when omitted.
    for flag, default in (
        ("--base-url", DEFAULT_OLLAMA_BASE_URL),
        ("--model", DEFAULT_OLLAMA_MODEL),
    ):
        parser.add_argument(flag, default=default)
    return parser
def main() -> None:
    """CLI entrypoint for Ollama tool-calling demo."""
    namespace = _build_parser().parse_args()
    outcome = run_ollama_tool_demo(base_url=namespace.base_url, model=namespace.model)
    # Pretty-print the structured result for human inspection.
    print(json.dumps(outcome, indent=2, sort_keys=True))