"""Ollama chat-completions demo that triggers `hello_static` tool usage."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import urllib.error
|
|
import urllib.request
|
|
from typing import Any, Final, cast
|
|
|
|
from pyro_mcp.demo import run_demo
|
|
|
|
# OpenAI-compatible endpoint exposed by a local `ollama serve` instance.
DEFAULT_OLLAMA_BASE_URL: Final[str] = "http://localhost:11434/v1"
# Ollama model tags use the form "<name>:<size>". The previous default,
# "llama:3.2-3b", is not a valid tag and would fail model lookup; Llama 3.2 3B
# is published in the Ollama library as "llama3.2:3b".
DEFAULT_OLLAMA_MODEL: Final[str] = "llama3.2:3b"
# Name of the pyro_mcp tool the model is expected to invoke.
TOOL_NAME: Final[str] = "hello_static"
|
|
|
|
# OpenAI-style function-tool specification advertised to the model.
# hello_static takes no arguments, hence the empty `properties` object with
# `additionalProperties: False` to reject any argument the model invents.
TOOL_SPEC: Final[dict[str, Any]] = {
    "type": "function",
    "function": {
        "name": TOOL_NAME,
        "description": "Returns a deterministic static payload from pyro_mcp.",
        "parameters": {
            "type": "object",
            "properties": {},
            "additionalProperties": False,
        },
    },
}
|
|
|
|
|
|
def _post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
    """POST *payload* to the chat-completions endpoint and return the parsed JSON body.

    Raises RuntimeError when the HTTP call fails (e.g. Ollama is not running)
    and TypeError when the response decodes to something other than an object.
    """
    url = f"{base_url.rstrip('/')}/chat/completions"
    request = urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=60) as response:
            raw_text = response.read().decode("utf-8")
    except urllib.error.URLError as exc:
        # Surface a single actionable message instead of the raw socket error.
        raise RuntimeError(
            "failed to call Ollama. Ensure `ollama serve` is running and the model is available."
        ) from exc

    decoded = json.loads(raw_text)
    if not isinstance(decoded, dict):
        raise TypeError("unexpected Ollama response shape")
    return cast(dict[str, Any], decoded)
|
|
|
|
|
|
def _extract_message(response: dict[str, Any]) -> dict[str, Any]:
|
|
choices = response.get("choices")
|
|
if not isinstance(choices, list) or not choices:
|
|
raise RuntimeError("Ollama response did not contain completion choices")
|
|
first = choices[0]
|
|
if not isinstance(first, dict):
|
|
raise RuntimeError("unexpected completion choice format")
|
|
message = first.get("message")
|
|
if not isinstance(message, dict):
|
|
raise RuntimeError("completion choice did not contain a message")
|
|
return cast(dict[str, Any], message)
|
|
|
|
|
|
def _parse_tool_arguments(raw_arguments: Any) -> dict[str, Any]:
|
|
if raw_arguments is None:
|
|
return {}
|
|
if isinstance(raw_arguments, dict):
|
|
return cast(dict[str, Any], raw_arguments)
|
|
if isinstance(raw_arguments, str):
|
|
if raw_arguments.strip() == "":
|
|
return {}
|
|
parsed = json.loads(raw_arguments)
|
|
if not isinstance(parsed, dict):
|
|
raise TypeError("tool arguments must decode to an object")
|
|
return cast(dict[str, Any], parsed)
|
|
raise TypeError("tool arguments must be a dictionary or JSON object string")
|
|
|
|
|
|
def run_ollama_tool_demo(
    base_url: str = DEFAULT_OLLAMA_BASE_URL,
    model: str = DEFAULT_OLLAMA_MODEL,
) -> dict[str, Any]:
    """Ask Ollama to call the static tool, execute it, and return final model output."""
    # Round one: offer the tool and prompt the model to use it.
    messages: list[dict[str, Any]] = [
        {
            "role": "user",
            "content": (
                "Use the hello_static tool and then summarize its payload in one short sentence."
            ),
        }
    ]
    initial_request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "tools": [TOOL_SPEC],
        "tool_choice": "auto",
        "temperature": 0,
    }
    assistant_message = _extract_message(_post_chat_completion(base_url, initial_request))

    tool_calls = assistant_message.get("tool_calls")
    if not isinstance(tool_calls, list) or not tool_calls:
        raise RuntimeError("model did not trigger any tool call")

    # Echo the assistant turn (with its tool calls) back into the transcript so
    # the follow-up request has the full conversation.
    messages.append(
        {
            "role": "assistant",
            "content": str(assistant_message.get("content") or ""),
            "tool_calls": tool_calls,
        }
    )

    tool_payload: dict[str, str] | None = None
    for call in tool_calls:
        if not isinstance(call, dict):
            raise RuntimeError("invalid tool call entry returned by model")
        function = call.get("function")
        if not isinstance(function, dict):
            raise RuntimeError("tool call did not include function metadata")
        requested_name = function.get("name")
        if requested_name != TOOL_NAME:
            raise RuntimeError(f"unexpected tool requested by model: {requested_name!r}")
        if _parse_tool_arguments(function.get("arguments")):
            # hello_static is declared with an empty, closed schema.
            raise RuntimeError("hello_static does not accept arguments")
        call_id = call.get("id")
        if not isinstance(call_id, str) or call_id == "":
            raise RuntimeError("tool call did not provide a valid call id")

        # Execute the actual tool and feed its output back as a tool message.
        tool_payload = asyncio.run(run_demo())
        messages.append(
            {
                "role": "tool",
                "tool_call_id": call_id,
                "name": TOOL_NAME,
                "content": json.dumps(tool_payload, sort_keys=True),
            }
        )

    if tool_payload is None:
        raise RuntimeError("tool payload was not generated")

    # Round two: let the model summarize the tool output.
    followup_request: dict[str, Any] = {
        "model": model,
        "messages": messages,
        "temperature": 0,
    }
    final_message = _extract_message(_post_chat_completion(base_url, followup_request))

    return {
        "model": model,
        "tool_name": TOOL_NAME,
        "tool_payload": tool_payload,
        "final_response": str(final_message.get("content") or ""),
    }
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI parser with base-url and model overrides."""
    parser = argparse.ArgumentParser(description="Run Ollama tool-calling demo for pyro_mcp.")
    for flag, default in (
        ("--base-url", DEFAULT_OLLAMA_BASE_URL),
        ("--model", DEFAULT_OLLAMA_MODEL),
    ):
        parser.add_argument(flag, default=default)
    return parser
|
|
|
|
|
|
def main() -> None:
    """CLI entrypoint for Ollama tool-calling demo."""
    namespace = _build_parser().parse_args()
    demo_result = run_ollama_tool_demo(base_url=namespace.base_url, model=namespace.model)
    # Stable key order keeps the output diff-friendly across runs.
    print(json.dumps(demo_result, indent=2, sort_keys=True))
|