from __future__ import annotations import argparse import json import urllib.error import urllib.request from typing import Any import pytest import pyro_mcp.ollama_demo as ollama_demo from pyro_mcp.server import HELLO_STATIC_PAYLOAD def test_run_ollama_tool_demo_triggers_tool_and_returns_final_response( monkeypatch: pytest.MonkeyPatch, ) -> None: requests: list[dict[str, Any]] = [] def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: assert base_url == "http://localhost:11434/v1" requests.append(payload) if len(requests) == 1: return { "choices": [ { "message": { "role": "assistant", "content": "", "tool_calls": [ { "id": "call_1", "type": "function", "function": {"name": "hello_static", "arguments": "{}"}, } ], } } ] } return { "choices": [ { "message": { "role": "assistant", "content": "Tool says hello from pyro_mcp.", } } ] } async def fake_run_demo() -> dict[str, str]: return HELLO_STATIC_PAYLOAD monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) monkeypatch.setattr(ollama_demo, "run_demo", fake_run_demo) result = ollama_demo.run_ollama_tool_demo() assert result["tool_payload"] == HELLO_STATIC_PAYLOAD assert result["final_response"] == "Tool says hello from pyro_mcp." assert len(requests) == 2 assert requests[0]["tools"][0]["function"]["name"] == "hello_static" tool_message = requests[1]["messages"][-1] assert tool_message["role"] == "tool" assert tool_message["tool_call_id"] == "call_1" def test_run_ollama_tool_demo_raises_when_model_does_not_call_tool( monkeypatch: pytest.MonkeyPatch, ) -> None: def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: del base_url, payload return {"choices": [{"message": {"role": "assistant", "content": "No tool call."}}]} monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) with pytest.raises(RuntimeError, match="model did not trigger any tool call"): ollama_demo.run_ollama_tool_demo() def test_run_ollama_tool_demo_raises_on_unexpected_tool(monkeypatch: pytest.MonkeyPatch) -> None: def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: del base_url, payload return { "choices": [ { "message": { "role": "assistant", "content": "", "tool_calls": [ { "id": "call_1", "type": "function", "function": {"name": "unexpected_tool", "arguments": "{}"}, } ], } } ] } monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) with pytest.raises(RuntimeError, match="unexpected tool requested by model"): ollama_demo.run_ollama_tool_demo() def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None: class StubResponse: def __enter__(self) -> StubResponse: return self def __exit__(self, exc_type: object, exc: object, tb: object) -> None: del exc_type, exc, tb def read(self) -> bytes: return b'{"ok": true}' def fake_urlopen(request: Any, timeout: int) -> StubResponse: assert timeout == 60 assert request.full_url == "http://localhost:11434/v1/chat/completions" return StubResponse() monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen) result = ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1}) assert result == {"ok": True} def test_post_chat_completion_raises_for_ollama_connection_error( monkeypatch: pytest.MonkeyPatch, ) -> None: def fake_urlopen(request: Any, timeout: int) -> Any: del request, timeout raise urllib.error.URLError("boom") monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen) with pytest.raises(RuntimeError, match="failed to call Ollama"): ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1}) def test_post_chat_completion_raises_for_non_object_response( monkeypatch: pytest.MonkeyPatch, ) -> None: class StubResponse: def __enter__(self) -> StubResponse: return self def __exit__(self, exc_type: object, exc: object, tb: object) -> None: del exc_type, exc, tb def read(self) -> bytes: return b'["not-an-object"]' def fake_urlopen(request: Any, timeout: int) -> StubResponse: del request, timeout return StubResponse() monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen) with pytest.raises(TypeError, match="unexpected Ollama response shape"): ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1}) @pytest.mark.parametrize( ("response", "expected_error"), [ ({}, "did not contain completion choices"), ({"choices": [1]}, "unexpected completion choice format"), ({"choices": [{"message": "bad"}]}, "did not contain a message"), ], ) def test_extract_message_validation_errors( response: dict[str, Any], expected_error: str, ) -> None: with pytest.raises(RuntimeError, match=expected_error): ollama_demo._extract_message(response) def test_parse_tool_arguments_variants() -> None: assert ollama_demo._parse_tool_arguments(None) == {} assert ollama_demo._parse_tool_arguments({}) == {} assert ollama_demo._parse_tool_arguments("") == {} assert ollama_demo._parse_tool_arguments('{"a": 1}') == {"a": 1} def test_parse_tool_arguments_rejects_invalid_types() -> None: with pytest.raises(TypeError, match="must decode to an object"): ollama_demo._parse_tool_arguments("[]") with pytest.raises(TypeError, match="must be a dictionary or JSON object string"): ollama_demo._parse_tool_arguments(123) @pytest.mark.parametrize( ("tool_call", "expected_error"), [ (1, "invalid tool call entry"), ({"id": "c1"}, "did not include function metadata"), ( {"id": "c1", "function": {"name": "hello_static", "arguments": '{"x": 1}'}}, "does not accept arguments", ), ( {"id": "", "function": {"name": "hello_static", "arguments": "{}"}}, "did not provide a valid call id", ), ], ) def test_run_ollama_tool_demo_validation_branches( monkeypatch: pytest.MonkeyPatch, tool_call: Any, expected_error: str, ) -> None: def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: del base_url, payload return { "choices": [ { "message": { "role": "assistant", "content": "", "tool_calls": [tool_call], } } ] } monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) with pytest.raises(RuntimeError, match=expected_error): ollama_demo.run_ollama_tool_demo() def test_main_uses_parser_and_prints_json( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str], ) -> None: class StubParser: def parse_args(self) -> argparse.Namespace: return argparse.Namespace(base_url="http://x", model="m") monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser()) monkeypatch.setattr( ollama_demo, "run_ollama_tool_demo", lambda base_url, model: {"base_url": base_url, "model": model}, ) ollama_demo.main() output = json.loads(capsys.readouterr().out) assert output == {"base_url": "http://x", "model": "m"}