pyro-mcp/tests/test_ollama_demo.py

255 lines
8.4 KiB
Python

from __future__ import annotations
import argparse
import json
import urllib.error
import urllib.request
from typing import Any
import pytest
import pyro_mcp.ollama_demo as ollama_demo
from pyro_mcp.server import HELLO_STATIC_PAYLOAD
def test_run_ollama_tool_demo_triggers_tool_and_returns_final_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Exercise the two-request flow: a tool call first, then the final answer.

    The chat-completion transport and ``run_demo`` are both stubbed, so the
    test only observes the payloads the demo sends and the result it returns.
    """
    sent_payloads: list[dict[str, Any]] = []

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        assert base_url == "http://localhost:11434/v1"
        sent_payloads.append(payload)
        if len(sent_payloads) != 1:
            # Every request after the first is answered with the plain-text reply.
            final_message = {
                "role": "assistant",
                "content": "Tool says hello from pyro_mcp.",
            }
            return {"choices": [{"message": final_message}]}
        # The first request is answered with a single hello_static tool call.
        tool_call = {
            "id": "call_1",
            "type": "function",
            "function": {"name": "hello_static", "arguments": "{}"},
        }
        tool_call_message = {
            "role": "assistant",
            "content": "",
            "tool_calls": [tool_call],
        }
        return {"choices": [{"message": tool_call_message}]}

    async def fake_run_demo() -> dict[str, str]:
        return HELLO_STATIC_PAYLOAD

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
    monkeypatch.setattr(ollama_demo, "run_demo", fake_run_demo)

    outcome = ollama_demo.run_ollama_tool_demo()

    assert outcome["tool_payload"] == HELLO_STATIC_PAYLOAD
    assert outcome["final_response"] == "Tool says hello from pyro_mcp."
    assert len(sent_payloads) == 2
    first_payload, second_payload = sent_payloads
    assert first_payload["tools"][0]["function"]["name"] == "hello_static"
    # The follow-up request must end with the tool result tied to the call id.
    last_message = second_payload["messages"][-1]
    assert last_message["role"] == "tool"
    assert last_message["tool_call_id"] == "call_1"
def test_run_ollama_tool_demo_raises_when_model_does_not_call_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A reply without ``tool_calls`` must surface as a ``RuntimeError``."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload
        plain_reply = {"role": "assistant", "content": "No tool call."}
        return {"choices": [{"message": plain_reply}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match="model did not trigger any tool call"):
        ollama_demo.run_ollama_tool_demo()
def test_run_ollama_tool_demo_raises_on_unexpected_tool(monkeypatch: pytest.MonkeyPatch) -> None:
    """A tool call naming something other than the demo tool must be rejected."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload
        unexpected_call = {
            "id": "call_1",
            "type": "function",
            "function": {"name": "unexpected_tool", "arguments": "{}"},
        }
        assistant_message = {
            "role": "assistant",
            "content": "",
            "tool_calls": [unexpected_call],
        }
        return {"choices": [{"message": assistant_message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
        ollama_demo.run_ollama_tool_demo()
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """A JSON-object body read from the stubbed ``urlopen`` is returned decoded."""

    class StubResponse:
        """Minimal context-manager stand-in for the object ``urlopen`` returns."""

        def __enter__(self) -> StubResponse:
            return self

        def __exit__(self, *exc_info: object) -> None:
            # Returning None lets any exception raised inside the ``with`` propagate.
            return None

        def read(self) -> bytes:
            return b'{"ok": true}'

    def fake_urlopen(request: Any, timeout: int) -> StubResponse:
        # Pin both the timeout and the endpoint derived from the base URL.
        assert timeout == 60
        assert request.full_url == "http://localhost:11434/v1/chat/completions"
        return StubResponse()

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    decoded = ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})

    assert decoded == {"ok": True}
def test_post_chat_completion_raises_for_ollama_connection_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``URLError`` raised by ``urlopen`` must be reported as a ``RuntimeError``."""

    def fake_urlopen(request: Any, timeout: int) -> Any:
        del request, timeout
        connection_failure = urllib.error.URLError("boom")
        raise connection_failure

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    with pytest.raises(RuntimeError, match="failed to call Ollama"):
        ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
def test_post_chat_completion_raises_for_non_object_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A body that decodes to a JSON array instead of an object raises ``TypeError``."""

    class StubResponse:
        """Minimal context-manager stand-in whose body is a JSON array."""

        def __enter__(self) -> StubResponse:
            return self

        def __exit__(self, *exc_info: object) -> None:
            # Returning None lets any exception raised inside the ``with`` propagate.
            return None

        def read(self) -> bytes:
            return b'["not-an-object"]'

    def fake_urlopen(request: Any, timeout: int) -> StubResponse:
        del request, timeout
        return StubResponse()

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    with pytest.raises(TypeError, match="unexpected Ollama response shape"):
        ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
@pytest.mark.parametrize(
    argnames=("response", "expected_error"),
    argvalues=[
        pytest.param({}, "did not contain completion choices"),
        pytest.param({"choices": [1]}, "unexpected completion choice format"),
        pytest.param({"choices": [{"message": "bad"}]}, "did not contain a message"),
    ],
)
def test_extract_message_validation_errors(
    response: dict[str, Any],
    expected_error: str,
) -> None:
    """Each malformed completion response is rejected with its own error message."""
    with pytest.raises(RuntimeError, match=expected_error):
        ollama_demo._extract_message(response)
def test_parse_tool_arguments_variants() -> None:
    """``None``, an empty dict and an empty string parse to ``{}``; JSON text is decoded."""
    accepted_inputs = [
        (None, {}),
        ({}, {}),
        ("", {}),
        ('{"a": 1}', {"a": 1}),
    ]
    for raw_arguments, expected in accepted_inputs:
        assert ollama_demo._parse_tool_arguments(raw_arguments) == expected
def test_parse_tool_arguments_rejects_invalid_types() -> None:
    """JSON text that is not an object, and non-dict/non-str values, raise ``TypeError``."""
    rejected_inputs = (
        ("[]", "must decode to an object"),
        (123, "must be a dictionary or JSON object string"),
    )
    for raw_arguments, expected_message in rejected_inputs:
        with pytest.raises(TypeError, match=expected_message):
            ollama_demo._parse_tool_arguments(raw_arguments)
@pytest.mark.parametrize(
    argnames=("tool_call", "expected_error"),
    argvalues=[
        pytest.param(1, "invalid tool call entry"),
        pytest.param({"id": "c1"}, "did not include function metadata"),
        pytest.param(
            {"id": "c1", "function": {"name": "hello_static", "arguments": '{"x": 1}'}},
            "does not accept arguments",
        ),
        pytest.param(
            {"id": "", "function": {"name": "hello_static", "arguments": "{}"}},
            "did not provide a valid call id",
        ),
    ],
)
def test_run_ollama_tool_demo_validation_branches(
    monkeypatch: pytest.MonkeyPatch,
    tool_call: Any,
    expected_error: str,
) -> None:
    """Each malformed tool-call entry makes the demo raise the matching ``RuntimeError``."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload
        # Wrap the parametrized entry in an otherwise well-formed assistant message.
        assistant_message = {
            "role": "assistant",
            "content": "",
            "tool_calls": [tool_call],
        }
        return {"choices": [{"message": assistant_message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match=expected_error):
        ollama_demo.run_ollama_tool_demo()
def test_main_uses_parser_and_prints_json(
    monkeypatch: pytest.MonkeyPatch,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """``main`` feeds the parsed arguments to the demo and prints the result as JSON."""

    class StubParser:
        """Parser stand-in that always yields the same namespace."""

        def parse_args(self) -> argparse.Namespace:
            return argparse.Namespace(base_url="http://x", model="m")

    def fake_run_ollama_tool_demo(base_url: Any, model: Any) -> dict[str, Any]:
        # Echo the arguments back so the printed JSON shows what main() passed in.
        return {"base_url": base_url, "model": model}

    # The class is itself a zero-argument factory, so it can replace _build_parser.
    monkeypatch.setattr(ollama_demo, "_build_parser", StubParser)
    monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run_ollama_tool_demo)

    ollama_demo.main()

    captured = capsys.readouterr()
    printed = json.loads(captured.out)
    assert printed == {"base_url": "http://x", "model": "m"}