Unify public UX around pyro CLI and Pyro facade

This commit is contained in:
Thales Maciel 2026-03-07 16:28:28 -03:00
parent d16aadd03f
commit 23a2dfb330
19 changed files with 936 additions and 407 deletions

View file

@ -10,16 +10,17 @@ from typing import Any
import pytest
import pyro_mcp.ollama_demo as ollama_demo
from pyro_mcp.api import Pyro as RealPyro
from pyro_mcp.vm_manager import VmManager as RealVmManager
# Autouse fixture: before each test, monkeypatches a test-double class onto the
# ollama_demo module so its constructor uses the "mock" backend under tmp_path / "vms".
# NOTE(review): this span is a flattened diff hunk. Two variants of each line sit
# back to back (TestVmManager/"VmManager" vs. TestPyro/"Pyro") and indentation is
# lost; presumably the VmManager lines are the removed side, per the commit title.
@pytest.fixture(autouse=True)
def _mock_vm_manager_for_tests(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
class TestVmManager(RealVmManager):
class TestPyro(RealPyro):
def __init__(self) -> None:
# VmManager variant: the subclass itself is constructed with the mock backend.
super().__init__(backend_name="mock", base_dir=tmp_path / "vms")
# Pyro variant: the facade is handed a mock-backend RealVmManager via manager=.
super().__init__(manager=RealVmManager(backend_name="mock", base_dir=tmp_path / "vms"))
# monkeypatch undoes the attribute swap automatically when the test finishes.
monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager)
monkeypatch.setattr(ollama_demo, "Pyro", TestPyro)
def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, Any]:
@ -46,55 +47,14 @@ def _stepwise_model_response(payload: dict[str, Any], step: int) -> dict[str, An
{
"id": "2",
"function": {
"name": "vm_create",
"arguments": json.dumps(
{"profile": "debian-git", "vcpu_count": 1, "mem_mib": 512}
),
},
}
],
}
}
]
}
if step == 3:
vm_id = json.loads(payload["messages"][-1]["content"])["vm_id"]
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "3",
"function": {
"name": "vm_start",
"arguments": json.dumps({"vm_id": vm_id}),
},
}
],
}
}
]
}
if step == 4:
vm_id = json.loads(payload["messages"][-1]["content"])["vm_id"]
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "4",
"function": {
"name": "vm_exec",
"name": "vm_run",
"arguments": json.dumps(
{
"vm_id": vm_id,
"profile": "debian-git",
"command": "printf 'true\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"network": True,
}
),
},
@ -127,12 +87,12 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non
assert result["fallback_used"] is False
assert str(result["exec_result"]["stdout"]).strip() == "true"
assert result["final_response"] == "Executed git command in ephemeral VM."
assert len(result["tool_events"]) == 4
assert len(result["tool_events"]) == 2
assert any(line == "[model] input user" for line in logs)
assert any(line == "[model] output assistant" for line in logs)
assert any("[model] tool_call vm_exec" in line for line in logs)
assert any(line == "[tool] calling vm_exec" for line in logs)
assert any(line == "[tool] result vm_exec" for line in logs)
assert any("[model] tool_call vm_run" in line for line in logs)
assert any(line == "[tool] calling vm_run" for line in logs)
assert any(line == "[tool] result vm_run" for line in logs)
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
@ -256,12 +216,12 @@ def test_run_ollama_tool_demo_resolves_vm_id_placeholder(
def test_dispatch_tool_call_vm_exec_autostarts_created_vm(tmp_path: Path) -> None:
manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms")
created = manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=60)
pyro = RealPyro(manager=RealVmManager(backend_name="mock", base_dir=tmp_path / "vms"))
created = pyro.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=60)
vm_id = str(created["vm_id"])
executed = ollama_demo._dispatch_tool_call(
manager,
pyro,
"vm_exec",
{"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"},
)
@ -275,7 +235,7 @@ def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyP
return {"choices": [{"message": {"role": "assistant", "content": "No tools"}}]}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
with pytest.raises(RuntimeError, match="did not execute a successful vm_exec"):
with pytest.raises(RuntimeError, match="did not execute a successful vm_run or vm_exec"):
ollama_demo.run_ollama_tool_demo()
@ -286,16 +246,16 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict(
del base_url, payload
return {"choices": [{"message": {"role": "assistant", "content": "No tools"}}]}
class TestVmManager(RealVmManager):
class TestPyro(RealPyro):
def __init__(self) -> None:
super().__init__(backend_name="mock", base_dir=tmp_path / "vms")
super().__init__(manager=RealVmManager(backend_name="mock", base_dir=tmp_path / "vms"))
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
monkeypatch.setattr(ollama_demo, "VmManager", TestVmManager)
monkeypatch.setattr(ollama_demo, "Pyro", TestPyro)
monkeypatch.setattr(
ollama_demo,
"_run_direct_lifecycle_fallback",
lambda manager: {
lambda pyro: {
"vm_id": "vm-1",
"command": ollama_demo.NETWORK_PROOF_COMMAND,
"stdout": "true\n",
@ -332,7 +292,7 @@ def test_run_ollama_tool_demo_verbose_logs_values(monkeypatch: pytest.MonkeyPatc
assert str(result["exec_result"]["stdout"]).strip() == "true"
assert any("[model] input user:" in line for line in logs)
assert any("[model] tool_call vm_list_profiles args={}" in line for line in logs)
assert any("[tool] result vm_exec " in line for line in logs)
assert any("[tool] result vm_run " in line for line in logs)
@pytest.mark.parametrize(
@ -397,7 +357,7 @@ def test_run_ollama_tool_demo_exec_result_validation(
"message": {
"role": "assistant",
"tool_calls": [
{"id": "1", "function": {"name": "vm_exec", "arguments": "{}"}}
{"id": "1", "function": {"name": "vm_run", "arguments": "{}"}}
],
}
}
@ -410,9 +370,9 @@ def test_run_ollama_tool_demo_exec_result_validation(
del base_url, payload
return responses.pop(0)
def fake_dispatch(manager: Any, tool_name: str, arguments: dict[str, Any]) -> Any:
del manager, arguments
if tool_name == "vm_exec":
def fake_dispatch(pyro: Any, tool_name: str, arguments: dict[str, Any]) -> Any:
del pyro, arguments
if tool_name == "vm_run":
return exec_result
return {"ok": True}
@ -423,27 +383,46 @@ def test_run_ollama_tool_demo_exec_result_validation(
# Walks ollama_demo._dispatch_tool_call through vm_list_profiles, vm_create,
# vm_start, vm_status, vm_exec and vm_run, then through an unknown tool name.
# NOTE(review): this span is a flattened diff hunk. The `manager` and `pyro`
# variants of each call appear back to back (even inside one argument list) and
# indentation is lost; presumably the `manager` lines are the removed side.
def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms")
profiles = ollama_demo._dispatch_tool_call(manager, "vm_list_profiles", {})
pyro = RealPyro(manager=RealVmManager(backend_name="mock", base_dir=tmp_path / "vms"))
profiles = ollama_demo._dispatch_tool_call(pyro, "vm_list_profiles", {})
assert "profiles" in profiles
# Numeric arguments are passed as strings ("1", "512", "60") — presumably to
# exercise coercion inside the dispatcher; confirm against _require_int.
created = ollama_demo._dispatch_tool_call(
manager,
pyro,
"vm_create",
{"profile": "debian-base", "vcpu_count": "1", "mem_mib": "512", "ttl_seconds": "60"},
{
"profile": "debian-base",
"vcpu_count": "1",
"mem_mib": "512",
"ttl_seconds": "60",
"network": False,
},
)
# The vm_id returned by vm_create is reused for the start/status/exec calls below.
vm_id = str(created["vm_id"])
started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id})
started = ollama_demo._dispatch_tool_call(pyro, "vm_start", {"vm_id": vm_id})
assert started["state"] == "started"
status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id})
status = ollama_demo._dispatch_tool_call(pyro, "vm_status", {"vm_id": vm_id})
assert status["vm_id"] == vm_id
executed = ollama_demo._dispatch_tool_call(
manager,
pyro,
"vm_exec",
{"vm_id": vm_id, "command": "printf 'true\\n'", "timeout_seconds": "30"},
)
assert int(executed["exit_code"]) == 0
# vm_run is called with profile and sizing arguments and no vm_id.
executed_run = ollama_demo._dispatch_tool_call(
pyro,
"vm_run",
{
"profile": "debian-base",
"command": "printf 'true\\n'",
"vcpu_count": "1",
"mem_mib": "512",
"timeout_seconds": "30",
"network": False,
},
)
assert int(executed_run["exit_code"]) == 0
# A tool name the dispatcher does not recognise must raise RuntimeError.
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
ollama_demo._dispatch_tool_call(manager, "nope", {})
ollama_demo._dispatch_tool_call(pyro, "nope", {})
def test_format_tool_error() -> None:
@ -483,6 +462,16 @@ def test_require_int_validation(value: Any) -> None:
ollama_demo._require_int({"k": value}, "k")
@pytest.mark.parametrize(
    ("arguments", "expected"),
    [
        ({}, False),
        ({"k": True}, True),
    ],
)
def test_require_bool(arguments: dict[str, Any], expected: bool) -> None:
    """Check `_require_bool` for a missing key (default applies) and a real bool.

    With ``default=False`` an absent ``"k"`` is expected to yield ``False``,
    and a stored ``True`` is expected to come back as ``True``.
    """
    outcome = ollama_demo._require_bool(arguments, "k", default=False)
    # Identity comparison: the result must be the actual bool singleton.
    assert outcome is expected
def test_require_bool_validation() -> None:
    """Check `_require_bool` rejects the string "true" with a ValueError."""
    string_flag = {"k": "true"}
    # No default is supplied; the error message must mention "must be a boolean".
    with pytest.raises(ValueError, match="must be a boolean"):
        ollama_demo._require_bool(string_flag, "k")
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
class StubResponse:
def __enter__(self) -> StubResponse: