From fb8b98504952c8c948ae499af281c9bc8a7e0c7e Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 5 Mar 2026 21:37:02 -0300 Subject: [PATCH] Harden Ollama demo tool-call handling and logging --- src/pyro_mcp/ollama_demo.py | 100 ++++++++++++++++++++----- tests/test_ollama_demo.py | 141 ++++++++++++++++++++++++++++++++++-- 2 files changed, 219 insertions(+), 22 deletions(-) diff --git a/src/pyro_mcp/ollama_demo.py b/src/pyro_mcp/ollama_demo.py index ea3a267..0f6ba73 100644 --- a/src/pyro_mcp/ollama_demo.py +++ b/src/pyro_mcp/ollama_demo.py @@ -153,9 +153,15 @@ def _require_str(arguments: dict[str, Any], key: str) -> str: def _require_int(arguments: dict[str, Any], key: str) -> int: value = arguments.get(key) - if not isinstance(value, int): + if isinstance(value, bool): raise ValueError(f"{key} must be an integer") - return value + if isinstance(value, int): + return value + if isinstance(value, str): + normalized = value.strip() + if normalized.isdigit(): + return int(normalized) + raise ValueError(f"{key} must be an integer") def _dispatch_tool_call( @@ -164,23 +170,26 @@ def _dispatch_tool_call( if tool_name == "vm_list_profiles": return {"profiles": manager.list_profiles()} if tool_name == "vm_create": + ttl_seconds = arguments.get("ttl_seconds", 600) return manager.create_vm( profile=_require_str(arguments, "profile"), vcpu_count=_require_int(arguments, "vcpu_count"), mem_mib=_require_int(arguments, "mem_mib"), - ttl_seconds=arguments.get("ttl_seconds", 600) - if isinstance(arguments.get("ttl_seconds"), int) - else 600, + ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"), ) if tool_name == "vm_start": return manager.start_vm(_require_str(arguments, "vm_id")) if tool_name == "vm_exec": + timeout_seconds = arguments.get("timeout_seconds", 30) + vm_id = _require_str(arguments, "vm_id") + status = manager.status_vm(vm_id) + state = status.get("state") + if state in {"created", "stopped"}: + manager.start_vm(vm_id) return manager.exec_vm( - _require_str(arguments, "vm_id"), + vm_id, command=_require_str(arguments, "command"), - timeout_seconds=arguments.get("timeout_seconds", 30) - if isinstance(arguments.get("timeout_seconds"), int) - else 30, + timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"), ) if tool_name == "vm_status": return manager.status_vm(_require_str(arguments, "vm_id")) @@ -204,6 +213,45 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]: return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30) +def _is_vm_id_placeholder(value: str) -> bool: + normalized = value.strip().lower() + if normalized in { + "vm_id_returned_by_vm_create", + "", + "{vm_id_returned_by_vm_create}", + }: + return True + return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized + + +def _normalize_tool_arguments( + tool_name: str, + arguments: dict[str, Any], + *, + last_created_vm_id: str | None, +) -> tuple[dict[str, Any], str | None]: + if tool_name not in {"vm_start", "vm_exec", "vm_status"} or last_created_vm_id is None: + return arguments, None + vm_id = arguments.get("vm_id") + if not isinstance(vm_id, str) or not _is_vm_id_placeholder(vm_id): + return arguments, None + normalized_arguments = dict(arguments) + normalized_arguments["vm_id"] = last_created_vm_id + return normalized_arguments, last_created_vm_id + + +def _summarize_message_for_log(message: dict[str, Any]) -> str: + role = str(message.get("role", "unknown")) + content = str(message.get("content") or "").strip() + if content == "": + return f"{role}: " + return f"{role}: {content}" + + +def _serialize_log_value(value: Any) -> str: + return json.dumps(value, sort_keys=True, separators=(",", ":")) + + def run_ollama_tool_demo( base_url: str = DEFAULT_OLLAMA_BASE_URL, model: str = DEFAULT_OLLAMA_MODEL, @@ -229,9 +277,10 @@ def run_ollama_tool_demo( ] tool_events: list[dict[str, Any]] = [] final_response = "" + last_created_vm_id: str | None = None - for round_index in range(1, MAX_TOOL_ROUNDS + 1): - emit(f"[ollama] round {round_index}: requesting completion") + for _round_index in range(1, MAX_TOOL_ROUNDS + 1): + emit(f"[model] input {_summarize_message_for_log(messages[-1])}") response = _post_chat_completion( base_url, { @@ -243,6 +292,7 @@ def run_ollama_tool_demo( }, ) assistant_message = _extract_message(response) + emit(f"[model] output {_summarize_message_for_log(assistant_message)}") tool_calls = assistant_message.get("tool_calls") if not isinstance(tool_calls, list) or not tool_calls: final_response = str(assistant_message.get("content") or "") @@ -270,15 +320,28 @@ def run_ollama_tool_demo( if not isinstance(tool_name, str): raise RuntimeError("tool call function name is invalid") arguments = _parse_tool_arguments(function.get("arguments")) + emit(f"[model] tool_call {tool_name} args={arguments}") + arguments, normalized_vm_id = _normalize_tool_arguments( + tool_name, + arguments, + last_created_vm_id=last_created_vm_id, + ) + if normalized_vm_id is not None: + emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}") emit(f"[tool] calling {tool_name} with args={arguments}") try: result = _dispatch_tool_call(manager, tool_name, arguments) success = True emit(f"[tool] {tool_name} succeeded") + if tool_name == "vm_create": + created_vm_id = result.get("vm_id") + if isinstance(created_vm_id, str) and created_vm_id != "": + last_created_vm_id = created_vm_id except Exception as exc: # noqa: BLE001 result = _format_tool_error(tool_name, arguments, exc) success = False emit(f"[tool] {tool_name} failed: {exc}") + emit(f"[tool] result {tool_name} {_serialize_log_value(result)}") tool_events.append( { "tool_name": tool_name, @@ -351,12 +414,15 @@ def _build_parser() -> argparse.ArgumentParser: def main() -> None: """CLI entrypoint for Ollama tool-calling demo.""" args = _build_parser().parse_args() - result = run_ollama_tool_demo( - base_url=args.base_url, - model=args.model, - strict=False, - log=lambda message: print(message, flush=True), - ) + try: + result = run_ollama_tool_demo( + base_url=args.base_url, + model=args.model, + log=lambda message: print(message, flush=True), + ) + except Exception as exc: # noqa: BLE001 + print(f"[error] {exc}", flush=True) + raise SystemExit(1) from exc exec_result = result["exec_result"] if not isinstance(exec_result, dict): raise RuntimeError("demo produced invalid execution result") diff --git a/tests/test_ollama_demo.py b/tests/test_ollama_demo.py index f282003..1e61040 100644 --- a/tests/test_ollama_demo.py +++ b/tests/test_ollama_demo.py @@ -128,7 +128,11 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non assert "git version" in str(result["exec_result"]["stdout"]) assert result["final_response"] == "Executed git command in ephemeral VM." assert len(result["tool_events"]) == 4 + assert any("[model] input user:" in line for line in logs) + assert any("[model] output assistant:" in line for line in logs) + assert any("[model] tool_call vm_exec" in line for line in logs) assert any("[tool] calling vm_exec" in line for line in logs) + assert any("[tool] result vm_exec " in line for line in logs) def test_run_ollama_tool_demo_recovers_from_bad_vm_id( @@ -176,6 +180,95 @@ def test_run_ollama_tool_demo_recovers_from_bad_vm_id( assert int(result["exec_result"]["exit_code"]) == 0 +def test_run_ollama_tool_demo_resolves_vm_id_placeholder( + monkeypatch: pytest.MonkeyPatch, +) -> None: + requests: list[dict[str, Any]] = [] + + def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: + assert base_url == "http://localhost:11434/v1" + requests.append(payload) + step = len(requests) + if step == 1: + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + {"id": "1", "function": {"name": "vm_list_profiles"}}, + {"id": "2", "function": {"name": "vm_list_profiles"}}, + { + "id": "3", + "function": { + "name": "vm_create", + "arguments": json.dumps( + { + "profile": "debian-git", + "vcpu_count": "2", + "mem_mib": "2048", + } + ), + }, + }, + { + "id": "4", + "function": { + "name": "vm_exec", + "arguments": json.dumps( + { + "vm_id": "", + "command": "printf 'git version 2.44.0\\n'", + "timeout_seconds": "300", + } + ), + }, + }, + ], + } + } + ] + } + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Executed git command in ephemeral VM.", + } + } + ] + } + + monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) + + logs: list[str] = [] + result = ollama_demo.run_ollama_tool_demo(log=logs.append) + + assert result["fallback_used"] is False + assert int(result["exec_result"]["exit_code"]) == 0 + assert any("resolved vm_id placeholder" in line for line in logs) + exec_event = next( + event for event in result["tool_events"] if event["tool_name"] == "vm_exec" + ) + assert exec_event["success"] is True + + +def test_dispatch_tool_call_vm_exec_autostarts_created_vm(tmp_path: Path) -> None: + manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms") + created = manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=60) + vm_id = str(created["vm_id"]) + + executed = ollama_demo._dispatch_tool_call( + manager, + "vm_exec", + {"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"}, + ) + + assert int(executed["exit_code"]) == 0 + + def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyPatch) -> None: def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: del base_url, payload @@ -203,6 +296,7 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict( result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append) assert result["fallback_used"] is True assert int(result["exec_result"]["exit_code"]) == 0 + assert any("[model] output assistant: No tools" in line for line in logs) assert any("[fallback]" in line for line in logs) @@ -300,7 +394,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None: created = ollama_demo._dispatch_tool_call( manager, "vm_create", - {"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512}, + {"profile": "debian-base", "vcpu_count": "1", "mem_mib": "512", "ttl_seconds": "60"}, ) vm_id = str(created["vm_id"]) started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id}) @@ -308,7 +402,9 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None: status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id}) assert status["vm_id"] == vm_id executed = ollama_demo._dispatch_tool_call( - manager, "vm_exec", {"vm_id": vm_id, "command": "printf 'git version\\n'"} + manager, + "vm_exec", + {"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"}, ) assert int(executed["exit_code"]) == 0 with pytest.raises(RuntimeError, match="unexpected tool requested by model"): @@ -334,10 +430,22 @@ def test_require_str(arguments: dict[str, Any], error: str) -> None: ollama_demo._require_str(arguments, "k") -def test_require_int_validation() -> None: +@pytest.mark.parametrize( + ("arguments", "expected"), + [ + ({"k": 1}, 1), + ({"k": "1"}, 1), + ({"k": " 42 "}, 42), + ], +) +def test_require_int_accepts_numeric_strings(arguments: dict[str, Any], expected: int) -> None: + assert ollama_demo._require_int(arguments, "k") == expected + + +@pytest.mark.parametrize("value", ["", "abc", "2.5", True]) +def test_require_int_validation(value: Any) -> None: with pytest.raises(ValueError, match="must be an integer"): - ollama_demo._require_int({"k": "1"}, "k") - assert ollama_demo._require_int({"k": 1}, "k") == 1 + ollama_demo._require_int({"k": value}, "k") def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None: @@ -444,3 +552,26 @@ def test_main_uses_parser_and_prints_logs( output = capsys.readouterr().out assert "[summary] exit_code=0 fallback_used=False" in output assert "[summary] stdout=git version 2.44.0" in output + + +def test_main_logs_error_and_exits_nonzero( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace(base_url="http://x", model="m") + + monkeypatch.setattr(ollama_demo, "_build_parser", lambda: StubParser()) + + def fake_run(base_url: str, model: str, strict: bool = True, log: Any = None) -> dict[str, Any]: + del base_url, model, strict, log + raise RuntimeError("demo did not execute a successful vm_exec") + + monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run) + + with pytest.raises(SystemExit, match="1"): + ollama_demo.main() + + output = capsys.readouterr().out + assert "[error] demo did not execute a successful vm_exec" in output