Harden Ollama demo tool-call handling and logging
This commit is contained in:
parent
65f7c0d262
commit
fb8b985049
2 changed files with 219 additions and 22 deletions
|
|
@ -128,7 +128,11 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non
|
|||
assert "git version" in str(result["exec_result"]["stdout"])
|
||||
assert result["final_response"] == "Executed git command in ephemeral VM."
|
||||
assert len(result["tool_events"]) == 4
|
||||
assert any("[model] input user:" in line for line in logs)
|
||||
assert any("[model] output assistant:" in line for line in logs)
|
||||
assert any("[model] tool_call vm_exec" in line for line in logs)
|
||||
assert any("[tool] calling vm_exec" in line for line in logs)
|
||||
assert any("[tool] result vm_exec " in line for line in logs)
|
||||
|
||||
|
||||
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
|
||||
|
|
@ -176,6 +180,95 @@ def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
|
|||
assert int(result["exec_result"]["exit_code"]) == 0
|
||||
|
||||
|
||||
def test_run_ollama_tool_demo_resolves_vm_id_placeholder(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
requests: list[dict[str, Any]] = []
|
||||
|
||||
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
assert base_url == "http://localhost:11434/v1"
|
||||
requests.append(payload)
|
||||
step = len(requests)
|
||||
if step == 1:
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{"id": "1", "function": {"name": "vm_list_profiles"}},
|
||||
{"id": "2", "function": {"name": "vm_list_profiles"}},
|
||||
{
|
||||
"id": "3",
|
||||
"function": {
|
||||
"name": "vm_create",
|
||||
"arguments": json.dumps(
|
||||
{
|
||||
"profile": "debian-git",
|
||||
"vcpu_count": "2",
|
||||
"mem_mib": "2048",
|
||||
}
|
||||
),
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"function": {
|
||||
"name": "vm_exec",
|
||||
"arguments": json.dumps(
|
||||
{
|
||||
"vm_id": "<vm_id_returned_by_vm_create>",
|
||||
"command": "printf 'git version 2.44.0\\n'",
|
||||
"timeout_seconds": "300",
|
||||
}
|
||||
),
|
||||
},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
return {
|
||||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Executed git command in ephemeral VM.",
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
|
||||
|
||||
logs: list[str] = []
|
||||
result = ollama_demo.run_ollama_tool_demo(log=logs.append)
|
||||
|
||||
assert result["fallback_used"] is False
|
||||
assert int(result["exec_result"]["exit_code"]) == 0
|
||||
assert any("resolved vm_id placeholder" in line for line in logs)
|
||||
exec_event = next(
|
||||
event for event in result["tool_events"] if event["tool_name"] == "vm_exec"
|
||||
)
|
||||
assert exec_event["success"] is True
|
||||
|
||||
|
||||
def test_dispatch_tool_call_vm_exec_autostarts_created_vm(tmp_path: Path) -> None:
|
||||
manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms")
|
||||
created = manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=60)
|
||||
vm_id = str(created["vm_id"])
|
||||
|
||||
executed = ollama_demo._dispatch_tool_call(
|
||||
manager,
|
||||
"vm_exec",
|
||||
{"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"},
|
||||
)
|
||||
|
||||
assert int(executed["exit_code"]) == 0
|
||||
|
||||
|
||||
def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||
del base_url, payload
|
||||
|
|
@ -203,6 +296,7 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict(
|
|||
result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append)
|
||||
assert result["fallback_used"] is True
|
||||
assert int(result["exec_result"]["exit_code"]) == 0
|
||||
assert any("[model] output assistant: No tools" in line for line in logs)
|
||||
assert any("[fallback]" in line for line in logs)
|
||||
|
||||
|
||||
|
|
@ -300,7 +394,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
|
|||
created = ollama_demo._dispatch_tool_call(
|
||||
manager,
|
||||
"vm_create",
|
||||
{"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512},
|
||||
{"profile": "debian-base", "vcpu_count": "1", "mem_mib": "512", "ttl_seconds": "60"},
|
||||
)
|
||||
vm_id = str(created["vm_id"])
|
||||
started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id})
|
||||
|
|
@ -308,7 +402,9 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
|
|||
status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id})
|
||||
assert status["vm_id"] == vm_id
|
||||
executed = ollama_demo._dispatch_tool_call(
|
||||
manager, "vm_exec", {"vm_id": vm_id, "command": "printf 'git version\\n'"}
|
||||
manager,
|
||||
"vm_exec",
|
||||
{"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"},
|
||||
)
|
||||
assert int(executed["exit_code"]) == 0
|
||||
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
|
||||
|
|
@ -334,10 +430,22 @@ def test_require_str(arguments: dict[str, Any], error: str) -> None:
|
|||
ollama_demo._require_str(arguments, "k")
|
||||
|
||||
|
||||
def test_require_int_validation() -> None:
|
||||
@pytest.mark.parametrize(
|
||||
("arguments", "expected"),
|
||||
[
|
||||
({"k": 1}, 1),
|
||||
({"k": "1"}, 1),
|
||||
({"k": " 42 "}, 42),
|
||||
],
|
||||
)
|
||||
def test_require_int_accepts_numeric_strings(arguments: dict[str, Any], expected: int) -> None:
|
||||
assert ollama_demo._require_int(arguments, "k") == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("value", ["", "abc", "2.5", True])
|
||||
def test_require_int_validation(value: Any) -> None:
|
||||
with pytest.raises(ValueError, match="must be an integer"):
|
||||
ollama_demo._require_int({"k": "1"}, "k")
|
||||
assert ollama_demo._require_int({"k": 1}, "k") == 1
|
||||
ollama_demo._require_int({"k": value}, "k")
|
||||
|
||||
|
||||
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
|
|
@ -444,3 +552,26 @@ def test_main_uses_parser_and_prints_logs(
|
|||
output = capsys.readouterr().out
|
||||
assert "[summary] exit_code=0 fallback_used=False" in output
|
||||
assert "[summary] stdout=git version 2.44.0" in output
|
||||
|
||||
|
||||
def test_main_logs_error_and_exits_nonzero(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
class StubParser:
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return argparse.Namespace(base_url="http://x", model="m")
|
||||
|
||||
monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser())
|
||||
|
||||
def fake_run(base_url: str, model: str, strict: bool = True, log: Any = None) -> dict[str, Any]:
|
||||
del base_url, model, strict, log
|
||||
raise RuntimeError("demo did not execute a successful vm_exec")
|
||||
|
||||
monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run)
|
||||
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
ollama_demo.main()
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "[error] demo did not execute a successful vm_exec" in output
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue