"""Ollama chat-completions demo that triggers `hello_static` tool usage.""" from __future__ import annotations import argparse import asyncio import json import urllib.error import urllib.request from typing import Any, Final, cast from pyro_mcp.demo import run_demo DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1" DEFAULT_OLLAMA_MODEL: Final[str] = "llama:3.2-3b" TOOL_NAME: Final[str] = "hello_static" TOOL_SPEC: Final[dict[str, Any]] = { "type": "function", "function": { "name": TOOL_NAME, "description": "Returns a deterministic static payload from pyro_mcp.", "parameters": { "type": "object", "properties": {}, "additionalProperties": False, }, }, } def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: endpoint = f"{base_url.rstrip('/')}/chat/completions" body = json.dumps(payload).encode("utf-8") request = urllib.request.Request( endpoint, data=body, headers={"Content-Type": "application/json"}, method="POST", ) try: with urllib.request.urlopen(request, timeout=60) as response: response_text = response.read().decode("utf-8") except urllib.error.URLError as exc: raise RuntimeError( "failed to call Ollama. Ensure `ollama serve` is running and the model is available." ) from exc parsed = json.loads(response_text) if not isinstance(parsed, dict): raise TypeError("unexpected Ollama response shape") return cast(dict[str, Any], parsed) def _extract_message(response: dict[str, Any]) -> dict[str, Any]: choices = response.get("choices") if not isinstance(choices, list) or not choices: raise RuntimeError("Ollama response did not contain completion choices") first = choices[0] if not isinstance(first, dict): raise RuntimeError("unexpected completion choice format") message = first.get("message") if not isinstance(message, dict): raise RuntimeError("completion choice did not contain a message") return cast(dict[str, Any], message) def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]: if raw_arguments is None: return {} if isinstance(raw_arguments, dict): return cast(dict[str, Any], raw_arguments) if isinstance(raw_arguments, str): if raw_arguments.strip() == "": return {} parsed = json.loads(raw_arguments) if not isinstance(parsed, dict): raise TypeError("tool arguments must decode to an object") return cast(dict[str, Any], parsed) raise TypeError("tool arguments must be a dictionary or JSON object string") def run_ollama_tool_demo( base_url: str = DEFAULT_OLLAMA_BASE_URL, model: str = DEFAULT_OLLAMA_MODEL, ) -> dict[str, Any]: """Ask Ollama to call the static tool, execute it, and return final model output.""" messages: list[dict[str, Any]] = [ { "role": "user", "content": ( "Use the hello_static tool and then summarize its payload in one short sentence." ), } ] first_payload: dict[str, Any] = { "model": model, "messages": messages, "tools": [TOOL_SPEC], "tool_choice": "auto", "temperature": 0, } first_response = _post_chat_completion(base_url, first_payload) assistant_message = _extract_message(first_response) tool_calls = assistant_message.get("tool_calls") if not isinstance(tool_calls, list) or not tool_calls: raise RuntimeError("model did not trigger any tool call") messages.append( { "role": "assistant", "content": str(assistant_message.get("content") or ""), "tool_calls": tool_calls, } ) tool_payload: dict[str, str] | None = None for tool_call in tool_calls: if not isinstance(tool_call, dict): raise RuntimeError("invalid tool call entry returned by model") function = tool_call.get("function") if not isinstance(function, dict): raise RuntimeError("tool call did not include function metadata") name = function.get("name") if name != TOOL_NAME: raise RuntimeError(f"unexpected tool requested by model: {name!r}") arguments = _parse_tool_arguments(function.get("arguments")) if arguments: raise RuntimeError("hello_static does not accept arguments") call_id = tool_call.get("id") if not isinstance(call_id, str) or call_id == "": raise RuntimeError("tool call did not provide a valid call id") tool_payload = asyncio.run(run_demo()) messages.append( { "role": "tool", "tool_call_id": call_id, "name": TOOL_NAME, "content": json.dumps(tool_payload, sort_keys=True), } ) if tool_payload is None: raise RuntimeError("tool payload was not generated") second_payload: dict[str, Any] = { "model": model, "messages": messages, "temperature": 0, } second_response = _post_chat_completion(base_url, second_payload) final_message = _extract_message(second_response) return { "model": model, "tool_name": TOOL_NAME, "tool_payload": tool_payload, "final_response": str(final_message.get("content") or ""), } def _build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for pyro_mcp.") parser.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL) parser.add_argument("--model", default=DEFAULT_OLLAMA_MODEL) return parser def main() -> None: """CLI entrypoint for Ollama tool-calling demo.""" args = _build_parser().parse_args() result = run_ollama_tool_demo(base_url=args.base_url, model=args.model) print(json.dumps(result, indent=2, sort_keys=True))