Normalize Ollama demo tool arguments

This commit is contained in:
Thales Maciel 2026-03-08 17:16:28 -03:00
parent 75082467f9
commit f3e7d4aa3e
2 changed files with 106 additions and 12 deletions

View file

@ -196,6 +196,14 @@ def _require_bool(arguments: dict[str, Any], key: str, *, default: bool = False)
value = arguments.get(key, default)
if isinstance(value, bool):
return value
if isinstance(value, int) and value in {0, 1}:
return bool(value)
if isinstance(value, str):
normalized = value.strip().lower()
if normalized in {"true", "1", "yes", "on"}:
return True
if normalized in {"false", "0", "no", "off"}:
return False
raise ValueError(f"{key} must be a boolean")
@ -245,13 +253,22 @@ def _dispatch_tool_call(
def _format_tool_error(tool_name: str, arguments: dict[str, Any], exc: Exception) -> dict[str, Any]:
return {
payload = {
"ok": False,
"tool_name": tool_name,
"arguments": arguments,
"error_type": exc.__class__.__name__,
"error": str(exc),
}
error_text = str(exc)
if "must be a boolean" in error_text:
payload["hint"] = "Use JSON booleans true or false, not quoted strings."
if (
"environment must be a non-empty string" in error_text
and isinstance(arguments.get("profile"), str)
):
payload["hint"] = "Use `environment` instead of `profile`."
return payload
def _run_direct_lifecycle_fallback(pyro: Pyro) -> dict[str, Any]:
@ -283,14 +300,19 @@ def _normalize_tool_arguments(
*,
last_created_vm_id: str | None,
) -> tuple[dict[str, Any], str | None]:
if tool_name not in {"vm_start", "vm_exec", "vm_status"} or last_created_vm_id is None:
return arguments, None
vm_id = arguments.get("vm_id")
if not isinstance(vm_id, str) or not _is_vm_id_placeholder(vm_id):
return arguments, None
normalized_arguments = dict(arguments)
normalized_arguments["vm_id"] = last_created_vm_id
return normalized_arguments, last_created_vm_id
normalized_vm_id: str | None = None
if tool_name in {"vm_run", "vm_create"}:
legacy_profile = normalized_arguments.get("profile")
if "environment" not in normalized_arguments and isinstance(legacy_profile, str):
normalized_arguments["environment"] = legacy_profile
normalized_arguments.pop("profile", None)
if tool_name in {"vm_start", "vm_exec", "vm_status"} and last_created_vm_id is not None:
vm_id = normalized_arguments.get("vm_id")
if isinstance(vm_id, str) and _is_vm_id_placeholder(vm_id):
normalized_arguments["vm_id"] = last_created_vm_id
normalized_vm_id = last_created_vm_id
return normalized_arguments, normalized_vm_id
def _summarize_message_for_log(message: dict[str, Any], *, verbose: bool) -> str:
@ -408,7 +430,10 @@ def run_ollama_tool_demo(
except Exception as exc: # noqa: BLE001
result = _format_tool_error(tool_name, arguments, exc)
success = False
emit(f"[tool] {tool_name} failed: {exc}")
emit(
f"[tool] {tool_name} failed: {exc} "
f"args={_serialize_log_value(arguments)}"
)
if verbose:
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
else:

View file

@ -95,6 +95,64 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non
assert any(line == "[tool] result vm_run" for line in logs)
def test_run_ollama_tool_demo_accepts_legacy_profile_and_string_network(
monkeypatch: pytest.MonkeyPatch,
) -> None:
requests: list[dict[str, Any]] = []
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
assert base_url == "http://localhost:11434/v1"
requests.append(payload)
if len(requests) == 1:
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "1",
"function": {
"name": "vm_run",
"arguments": json.dumps(
{
"profile": "debian:12",
"command": "printf 'true\\n'",
"vcpu_count": 1,
"mem_mib": 512,
"network": "true",
}
),
},
}
],
}
}
]
}
return {
"choices": [
{
"message": {
"role": "assistant",
"content": "Executed git command in ephemeral VM.",
}
}
]
}
monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)
result = ollama_demo.run_ollama_tool_demo()
assert result["fallback_used"] is False
assert int(result["exec_result"]["exit_code"]) == 0
assert result["tool_events"][0]["success"] is True
assert result["tool_events"][0]["arguments"]["environment"] == "debian:12"
assert "profile" not in result["tool_events"][0]["arguments"]
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
monkeypatch: pytest.MonkeyPatch,
) -> None:
@ -467,14 +525,25 @@ def test_require_int_validation(value: Any) -> None:
ollama_demo._require_int({"k": value}, "k")
@pytest.mark.parametrize(("arguments", "expected"), [({}, False), ({"k": True}, True)])
@pytest.mark.parametrize(
("arguments", "expected"),
[
({}, False),
({"k": True}, True),
({"k": "true"}, True),
({"k": " false "}, False),
({"k": 1}, True),
({"k": 0}, False),
],
)
def test_require_bool(arguments: dict[str, Any], expected: bool) -> None:
assert ollama_demo._require_bool(arguments, "k", default=False) is expected
def test_require_bool_validation() -> None:
@pytest.mark.parametrize("value", ["", "maybe", 2])
def test_require_bool_validation(value: Any) -> None:
with pytest.raises(ValueError, match="must be a boolean"):
ollama_demo._require_bool({"k": "true"}, "k")
ollama_demo._require_bool({"k": value}, "k")
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None: