255 lines
8.4 KiB
Python
255 lines
8.4 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import urllib.error
|
|
import urllib.request
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
import pyro_mcp.ollama_demo as ollama_demo
|
|
from pyro_mcp.server import HELLO_STATIC_PAYLOAD
|
|
|
|
|
|
def test_run_ollama_tool_demo_triggers_tool_and_returns_final_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """The demo calls ``hello_static`` once, then returns the model's final answer.

    The fake chat endpoint answers the first request with a single
    ``hello_static`` tool call and every later request with plain text, so a
    correct run issues exactly two completions.
    """
    captured: list[dict[str, Any]] = []

    tool_call_reply: dict[str, Any] = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "",
                    "tool_calls": [
                        {
                            "id": "call_1",
                            "type": "function",
                            "function": {"name": "hello_static", "arguments": "{}"},
                        }
                    ],
                }
            }
        ]
    }
    final_reply: dict[str, Any] = {
        "choices": [
            {
                "message": {
                    "role": "assistant",
                    "content": "Tool says hello from pyro_mcp.",
                }
            }
        ]
    }

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        assert base_url == "http://localhost:11434/v1"
        captured.append(payload)
        # First round-trip requests the tool; any later one finishes the conversation.
        return tool_call_reply if len(captured) == 1 else final_reply

    async def fake_run_demo() -> dict[str, str]:
        return HELLO_STATIC_PAYLOAD

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
    monkeypatch.setattr(ollama_demo, "run_demo", fake_run_demo)

    result = ollama_demo.run_ollama_tool_demo()

    assert result["tool_payload"] == HELLO_STATIC_PAYLOAD
    assert result["final_response"] == "Tool says hello from pyro_mcp."
    assert len(captured) == 2

    first_request, second_request = captured
    assert first_request["tools"][0]["function"]["name"] == "hello_static"
    # The follow-up request must end with the tool result tied to the call id.
    last_message = second_request["messages"][-1]
    assert last_message["role"] == "tool"
    assert last_message["tool_call_id"] == "call_1"
|
|
|
|
|
|
def test_run_ollama_tool_demo_raises_when_model_does_not_call_tool(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A model reply without any ``tool_calls`` surfaces as a ``RuntimeError``."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload  # the fake ignores what it is sent
        plain_message = {"role": "assistant", "content": "No tool call."}
        return {"choices": [{"message": plain_message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match="model did not trigger any tool call"):
        ollama_demo.run_ollama_tool_demo()
|
|
|
|
|
|
def test_run_ollama_tool_demo_raises_on_unexpected_tool(monkeypatch: pytest.MonkeyPatch) -> None:
    """A tool call naming anything other than ``hello_static`` is rejected."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload  # the fake ignores what it is sent
        rogue_call = {
            "id": "call_1",
            "type": "function",
            "function": {"name": "unexpected_tool", "arguments": "{}"},
        }
        message = {"role": "assistant", "content": "", "tool_calls": [rogue_call]}
        return {"choices": [{"message": message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
        ollama_demo.run_ollama_tool_demo()
|
|
|
|
|
|
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
    """``_post_chat_completion`` POSTs the JSON payload and decodes the JSON reply.

    The fake ``urlopen`` checks the URL, the timeout, the HTTP method, and the
    request body. A regression that drops or mangles the payload therefore
    fails here instead of only against a live Ollama server.
    """

    class StubResponse:
        """Minimal context-manager stand-in for the object ``urlopen`` returns."""

        def __enter__(self) -> StubResponse:
            return self

        def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
            del exc_type, exc, tb

        def read(self) -> bytes:
            return b'{"ok": true}'

    def fake_urlopen(request: Any, timeout: int) -> StubResponse:
        assert timeout == 60
        assert request.full_url == "http://localhost:11434/v1/chat/completions"
        # The payload must actually be sent: a Request with a body reports POST,
        # and that body must round-trip back to the dict that was passed in.
        assert request.get_method() == "POST"
        assert json.loads(request.data) == {"x": 1}
        return StubResponse()

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    result = ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
    assert result == {"ok": True}
|
|
|
|
|
|
def test_post_chat_completion_raises_for_ollama_connection_error(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A ``URLError`` from ``urlopen`` surfaces as a ``RuntimeError``."""

    def unreachable_urlopen(request: Any, timeout: int) -> Any:
        del request, timeout  # simulate a server that cannot be reached
        raise urllib.error.URLError("boom")

    monkeypatch.setattr(urllib.request, "urlopen", unreachable_urlopen)

    with pytest.raises(RuntimeError, match="failed to call Ollama"):
        ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
|
|
|
|
|
|
def test_post_chat_completion_raises_for_non_object_response(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A JSON reply that is not an object (here, an array) raises ``TypeError``."""
    array_body = b'["not-an-object"]'

    class ArrayBodyResponse:
        """Context-manager stub whose ``read`` yields a JSON array."""

        def __enter__(self) -> ArrayBodyResponse:
            return self

        def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
            del exc_type, exc, tb

        def read(self) -> bytes:
            return array_body

    def fake_urlopen(request: Any, timeout: int) -> ArrayBodyResponse:
        del request, timeout  # the request itself is not under test here
        return ArrayBodyResponse()

    monkeypatch.setattr(urllib.request, "urlopen", fake_urlopen)

    with pytest.raises(TypeError, match="unexpected Ollama response shape"):
        ollama_demo._post_chat_completion("http://localhost:11434/v1", {"x": 1})
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("response", "expected_error"),
    [
        pytest.param({}, "did not contain completion choices"),
        pytest.param({"choices": [1]}, "unexpected completion choice format"),
        pytest.param({"choices": [{"message": "bad"}]}, "did not contain a message"),
    ],
)
def test_extract_message_validation_errors(
    response: dict[str, Any],
    expected_error: str,
) -> None:
    """Malformed completion responses raise ``RuntimeError`` with a specific message."""
    extract = ollama_demo._extract_message
    with pytest.raises(RuntimeError, match=expected_error):
        extract(response)
|
|
|
|
|
|
def test_parse_tool_arguments_variants() -> None:
    """``None``, ``{}``, and ``""`` parse to ``{}``; a JSON object string parses to its dict."""
    cases: list[tuple[Any, dict[str, Any]]] = [
        (None, {}),
        ({}, {}),
        ("", {}),
        ('{"a": 1}', {"a": 1}),
    ]
    for raw, expected in cases:
        assert ollama_demo._parse_tool_arguments(raw) == expected, raw
|
|
|
|
|
|
def test_parse_tool_arguments_rejects_invalid_types() -> None:
    """A non-object JSON string, or a value that is neither str nor dict, raises ``TypeError``."""
    rejected: list[tuple[Any, str]] = [
        ("[]", "must decode to an object"),
        (123, "must be a dictionary or JSON object string"),
    ]
    for bad_value, message in rejected:
        with pytest.raises(TypeError, match=message):
            ollama_demo._parse_tool_arguments(bad_value)
|
|
|
|
|
|
@pytest.mark.parametrize(
    ("tool_call", "expected_error"),
    [
        pytest.param(1, "invalid tool call entry"),
        pytest.param({"id": "c1"}, "did not include function metadata"),
        pytest.param(
            {"id": "c1", "function": {"name": "hello_static", "arguments": '{"x": 1}'}},
            "does not accept arguments",
        ),
        pytest.param(
            {"id": "", "function": {"name": "hello_static", "arguments": "{}"}},
            "did not provide a valid call id",
        ),
    ],
)
def test_run_ollama_tool_demo_validation_branches(
    monkeypatch: pytest.MonkeyPatch,
    tool_call: Any,
    expected_error: str,
) -> None:
    """Each malformed tool-call entry in the model reply raises a specific ``RuntimeError``."""

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        del base_url, payload  # the fake ignores what it is sent
        message = {"role": "assistant", "content": "", "tool_calls": [tool_call]}
        return {"choices": [{"message": message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    with pytest.raises(RuntimeError, match=expected_error):
        ollama_demo.run_ollama_tool_demo()
|
|
|
|
|
|
def test_main_uses_parser_and_prints_json(
    monkeypatch: pytest.MonkeyPatch,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """``main`` passes the parsed CLI options to the demo and prints its result as JSON."""
    parsed_args = argparse.Namespace(base_url="http://x", model="m")

    class StubParser:
        def parse_args(self) -> argparse.Namespace:
            return parsed_args

    def fake_run_ollama_tool_demo(base_url: str, model: str) -> dict[str, str]:
        # Echo the inputs back so the printed output reveals what main forwarded.
        return {"base_url": base_url, "model": model}

    # The class itself serves as the zero-argument parser factory.
    monkeypatch.setattr(ollama_demo, "_build_parser", StubParser)
    monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run_ollama_tool_demo)

    ollama_demo.main()

    printed = capsys.readouterr().out
    assert json.loads(printed) == {"base_url": "http://x", "model": "m"}
|