diff --git a/src/pyro_mcp/ollama_demo.py b/src/pyro_mcp/ollama_demo.py index 938e75a..badaca6 100644 --- a/src/pyro_mcp/ollama_demo.py +++ b/src/pyro_mcp/ollama_demo.py @@ -196,6 +196,14 @@ def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False) value = arguments.get(key, default) if isinstance(value, bool): return value + if isinstance(value, int) and value in {0, 1}: + return bool(value) + if isinstance(value, str): + normalized = value.strip().lower() + if normalized in {"true", "1", "yes", "on"}: + return True + if normalized in {"false", "0", "no", "off"}: + return False raise ValueError(f"{key} must be a boolean") @@ -245,13 +253,22 @@ def _dispatch_tool_call( def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]: - return { + payload = { "ok": False, "tool_name": tool_name, "arguments": arguments, "error_type": exc.__class__.__name__, "error": str(exc), } + error_text = str(exc) + if "must be a boolean" in error_text: + payload["hint"] = "Use JSON booleans true or false, not quoted strings." + if ( + "environment must be a non-empty string" in error_text + and isinstance(arguments.get("profile"), str) + ): + payload["hint"] = "Use `environment` instead of `profile`." + return payload def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]: @@ -283,14 +300,19 @@ def _normalize_tool_arguments( *, last_created_vm_id: str | None, ) -> tuple[dict[str, Any], str | None]: - if tool_name not in {"vm_start", "vm_exec", "vm_status"} or last_created_vm_id is None: - return arguments, None - vm_id = arguments.get("vm_id") - if not isinstance(vm_id, str) or not _is_vm_id_placeholder(vm_id): - return arguments, None normalized_arguments = dict(arguments) - normalized_arguments["vm_id"] = last_created_vm_id - return normalized_arguments, last_created_vm_id + normalized_vm_id: str | None = None + if tool_name in {"vm_run", "vm_create"}: + legacy_profile = normalized_arguments.get("profile") + if "environment" not in normalized_arguments and isinstance(legacy_profile, str): + normalized_arguments["environment"] = legacy_profile + normalized_arguments.pop("profile", None) + if tool_name in {"vm_start", "vm_exec", "vm_status"} and last_created_vm_id is not None: + vm_id = normalized_arguments.get("vm_id") + if isinstance(vm_id, str) and _is_vm_id_placeholder(vm_id): + normalized_arguments["vm_id"] = last_created_vm_id + normalized_vm_id = last_created_vm_id + return normalized_arguments, normalized_vm_id def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str: @@ -408,7 +430,10 @@ def run_ollama_tool_demo( except Exception as exc: # noqa: BLE001 result = _format_tool_error(tool_name, arguments, exc) success = False - emit(f"[tool] {tool_name} failed: {exc}") + emit( + f"[tool] {tool_name} failed: {exc} " + f"args={_serialize_log_value(arguments)}" + ) if verbose: emit(f"[tool] result {tool_name} {_serialize_log_value(result)}") else: diff --git a/tests/test_ollama_demo.py b/tests/test_ollama_demo.py index 713e5f2..e8a3d3d 100644 --- a/tests/test_ollama_demo.py +++ b/tests/test_ollama_demo.py @@ -95,6 +95,64 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non assert any(line == "[tool] result vm_run" for line in logs) +def test_run_ollama_tool_demo_accepts_legacy_profile_and_string_network( + monkeypatch: pytest.MonkeyPatch, +) -> None: + requests: list[dict[str, Any]] = [] + + def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]: + assert base_url == "http://localhost:11434/v1" + requests.append(payload) + if len(requests) == 1: + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "1", + "function": { + "name": "vm_run", + "arguments": json.dumps( + { + "profile": "debian:12", + "command": "printf 'true\\n'", + "vcpu_count": 1, + "mem_mib": 512, + "network": "true", + } + ), + }, + } + ], + } + } + ] + } + return { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Executed git command in ephemeral VM.", + } + } + ] + } + + monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion) + + result = ollama_demo.run_ollama_tool_demo() + + assert result["fallback_used"] is False + assert int(result["exec_result"]["exit_code"]) == 0 + assert result["tool_events"][0]["success"] is True + assert result["tool_events"][0]["arguments"]["environment"] == "debian:12" + assert "profile" not in result["tool_events"][0]["arguments"] + + def test_run_ollama_tool_demo_recovers_from_bad_vm_id( monkeypatch: pytest.MonkeyPatch, ) -> None: @@ -467,14 +525,25 @@ def test_require_int_validation(value: Any) -> None: ollama_demo._require_int({"k": value}, "k") -@pytest.mark.parametrize(("arguments", "expected"), [({}, False), ({"k": True}, True)]) +@pytest.mark.parametrize( + ("arguments", "expected"), + [ + ({}, False), + ({"k": True}, True), + ({"k": "true"}, True), + ({"k": " false "}, False), + ({"k": 1}, True), + ({"k": 0}, False), + ], +) def test_require_bool(arguments: dict[str, Any], expected: bool) -> None: assert ollama_demo._require_bool(arguments, "k", default=False) is expected -def test_require_bool_validation() -> None: +@pytest.mark.parametrize("value", ["", "maybe", 2]) +def test_require_bool_validation(value: Any) -> None: with pytest.raises(ValueError, match="must be a boolean"): - ollama_demo._require_bool({"k": "true"}, "k") + ollama_demo._require_bool({"k": value}, "k") def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None: