Harden Ollama demo tool-call handling and logging
This commit is contained in:
parent
65f7c0d262
commit
fb8b985049
2 changed files with 219 additions and 22 deletions
|
|
@ -153,9 +153,15 @@ def _require_str(arguments: dict[str, Any], key: str) -> str:
|
||||||
|
|
||||||
def _require_int(arguments: dict[str, Any], key: str) -> int:
|
def _require_int(arguments: dict[str, Any], key: str) -> int:
|
||||||
value = arguments.get(key)
|
value = arguments.get(key)
|
||||||
if not isinstance(value, int):
|
if isinstance(value, bool):
|
||||||
raise ValueError(f"{key} must be an integer")
|
raise ValueError(f"{key} must be an integer")
|
||||||
|
if isinstance(value, int):
|
||||||
return value
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
normalized = value.strip()
|
||||||
|
if normalized.isdigit():
|
||||||
|
return int(normalized)
|
||||||
|
raise ValueError(f"{key} must be an integer")
|
||||||
|
|
||||||
|
|
||||||
def _dispatch_tool_call(
|
def _dispatch_tool_call(
|
||||||
|
|
@ -164,23 +170,26 @@ def _dispatch_tool_call(
|
||||||
if tool_name == "vm_list_profiles":
|
if tool_name == "vm_list_profiles":
|
||||||
return {"profiles": manager.list_profiles()}
|
return {"profiles": manager.list_profiles()}
|
||||||
if tool_name == "vm_create":
|
if tool_name == "vm_create":
|
||||||
|
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||||
return manager.create_vm(
|
return manager.create_vm(
|
||||||
profile=_require_str(arguments, "profile"),
|
profile=_require_str(arguments, "profile"),
|
||||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||||
mem_mib=_require_int(arguments, "mem_mib"),
|
mem_mib=_require_int(arguments, "mem_mib"),
|
||||||
ttl_seconds=arguments.get("ttl_seconds", 600)
|
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||||
if isinstance(arguments.get("ttl_seconds"), int)
|
|
||||||
else 600,
|
|
||||||
)
|
)
|
||||||
if tool_name == "vm_start":
|
if tool_name == "vm_start":
|
||||||
return manager.start_vm(_require_str(arguments, "vm_id"))
|
return manager.start_vm(_require_str(arguments, "vm_id"))
|
||||||
if tool_name == "vm_exec":
|
if tool_name == "vm_exec":
|
||||||
|
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||||
|
vm_id = _require_str(arguments, "vm_id")
|
||||||
|
status = manager.status_vm(vm_id)
|
||||||
|
state = status.get("state")
|
||||||
|
if state in {"created", "stopped"}:
|
||||||
|
manager.start_vm(vm_id)
|
||||||
return manager.exec_vm(
|
return manager.exec_vm(
|
||||||
_require_str(arguments, "vm_id"),
|
vm_id,
|
||||||
command=_require_str(arguments, "command"),
|
command=_require_str(arguments, "command"),
|
||||||
timeout_seconds=arguments.get("timeout_seconds", 30)
|
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||||
if isinstance(arguments.get("timeout_seconds"), int)
|
|
||||||
else 30,
|
|
||||||
)
|
)
|
||||||
if tool_name == "vm_status":
|
if tool_name == "vm_status":
|
||||||
return manager.status_vm(_require_str(arguments, "vm_id"))
|
return manager.status_vm(_require_str(arguments, "vm_id"))
|
||||||
|
|
@ -204,6 +213,45 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
||||||
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_vm_id_placeholder(value: str) -> bool:
|
||||||
|
normalized = value.strip().lower()
|
||||||
|
if normalized in {
|
||||||
|
"vm_id_returned_by_vm_create",
|
||||||
|
"<vm_id_returned_by_vm_create>",
|
||||||
|
"{vm_id_returned_by_vm_create}",
|
||||||
|
}:
|
||||||
|
return True
|
||||||
|
return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_tool_arguments(
|
||||||
|
tool_name: str,
|
||||||
|
arguments: dict[str, Any],
|
||||||
|
*,
|
||||||
|
last_created_vm_id: str | None,
|
||||||
|
) -> tuple[dict[str, Any], str | None]:
|
||||||
|
if tool_name not in {"vm_start", "vm_exec", "vm_status"} or last_created_vm_id is None:
|
||||||
|
return arguments, None
|
||||||
|
vm_id = arguments.get("vm_id")
|
||||||
|
if not isinstance(vm_id, str) or not _is_vm_id_placeholder(vm_id):
|
||||||
|
return arguments, None
|
||||||
|
normalized_arguments = dict(arguments)
|
||||||
|
normalized_arguments["vm_id"] = last_created_vm_id
|
||||||
|
return normalized_arguments, last_created_vm_id
|
||||||
|
|
||||||
|
|
||||||
|
def _summarize_message_for_log(message: dict[str, Any]) -> str:
|
||||||
|
role = str(message.get("role", "unknown"))
|
||||||
|
content = str(message.get("content") or "").strip()
|
||||||
|
if content == "":
|
||||||
|
return f"{role}: <empty>"
|
||||||
|
return f"{role}: {content}"
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_log_value(value: Any) -> str:
|
||||||
|
return json.dumps(value, sort_keys=True, separators=(",", ":"))
|
||||||
|
|
||||||
|
|
||||||
def run_ollama_tool_demo(
|
def run_ollama_tool_demo(
|
||||||
base_url: str = DEFAULT_OLLAMA_BASE_URL,
|
base_url: str = DEFAULT_OLLAMA_BASE_URL,
|
||||||
model: str = DEFAULT_OLLAMA_MODEL,
|
model: str = DEFAULT_OLLAMA_MODEL,
|
||||||
|
|
@ -229,9 +277,10 @@ def run_ollama_tool_demo(
|
||||||
]
|
]
|
||||||
tool_events: list[dict[str, Any]] = []
|
tool_events: list[dict[str, Any]] = []
|
||||||
final_response = ""
|
final_response = ""
|
||||||
|
last_created_vm_id: str | None = None
|
||||||
|
|
||||||
for round_index in range(1, MAX_TOOL_ROUNDS + 1):
|
for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
|
||||||
emit(f"[ollama] round {round_index}: requesting completion")
|
emit(f"[model] input {_summarize_message_for_log(messages[-1])}")
|
||||||
response = _post_chat_completion(
|
response = _post_chat_completion(
|
||||||
base_url,
|
base_url,
|
||||||
{
|
{
|
||||||
|
|
@ -243,6 +292,7 @@ def run_ollama_tool_demo(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
assistant_message = _extract_message(response)
|
assistant_message = _extract_message(response)
|
||||||
|
emit(f"[model] output {_summarize_message_for_log(assistant_message)}")
|
||||||
tool_calls = assistant_message.get("tool_calls")
|
tool_calls = assistant_message.get("tool_calls")
|
||||||
if not isinstance(tool_calls, list) or not tool_calls:
|
if not isinstance(tool_calls, list) or not tool_calls:
|
||||||
final_response = str(assistant_message.get("content") or "")
|
final_response = str(assistant_message.get("content") or "")
|
||||||
|
|
@ -270,15 +320,28 @@ def run_ollama_tool_demo(
|
||||||
if not isinstance(tool_name, str):
|
if not isinstance(tool_name, str):
|
||||||
raise RuntimeError("tool call function name is invalid")
|
raise RuntimeError("tool call function name is invalid")
|
||||||
arguments = _parse_tool_arguments(function.get("arguments"))
|
arguments = _parse_tool_arguments(function.get("arguments"))
|
||||||
|
emit(f"[model] tool_call {tool_name} args={arguments}")
|
||||||
|
arguments, normalized_vm_id = _normalize_tool_arguments(
|
||||||
|
tool_name,
|
||||||
|
arguments,
|
||||||
|
last_created_vm_id=last_created_vm_id,
|
||||||
|
)
|
||||||
|
if normalized_vm_id is not None:
|
||||||
|
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
|
||||||
emit(f"[tool] calling {tool_name} with args={arguments}")
|
emit(f"[tool] calling {tool_name} with args={arguments}")
|
||||||
try:
|
try:
|
||||||
result = _dispatch_tool_call(manager, tool_name, arguments)
|
result = _dispatch_tool_call(manager, tool_name, arguments)
|
||||||
success = True
|
success = True
|
||||||
emit(f"[tool] {tool_name} succeeded")
|
emit(f"[tool] {tool_name} succeeded")
|
||||||
|
if tool_name == "vm_create":
|
||||||
|
created_vm_id = result.get("vm_id")
|
||||||
|
if isinstance(created_vm_id, str) and created_vm_id != "":
|
||||||
|
last_created_vm_id = created_vm_id
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
result = _format_tool_error(tool_name, arguments, exc)
|
result = _format_tool_error(tool_name, arguments, exc)
|
||||||
success = False
|
success = False
|
||||||
emit(f"[tool] {tool_name} failed: {exc}")
|
emit(f"[tool] {tool_name} failed: {exc}")
|
||||||
|
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
|
||||||
tool_events.append(
|
tool_events.append(
|
||||||
{
|
{
|
||||||
"tool_name": tool_name,
|
"tool_name": tool_name,
|
||||||
|
|
@ -351,12 +414,15 @@ def _build_parser() -> argparse.ArgumentParser:
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
"""CLI entrypoint for Ollama tool-calling demo."""
|
"""CLI entrypoint for Ollama tool-calling demo."""
|
||||||
args = _build_parser().parse_args()
|
args = _build_parser().parse_args()
|
||||||
|
try:
|
||||||
result = run_ollama_tool_demo(
|
result = run_ollama_tool_demo(
|
||||||
base_url=args.base_url,
|
base_url=args.base_url,
|
||||||
model=args.model,
|
model=args.model,
|
||||||
strict=False,
|
|
||||||
log=lambda message: print(message, flush=True),
|
log=lambda message: print(message, flush=True),
|
||||||
)
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
print(f"[error] {exc}", flush=True)
|
||||||
|
raise SystemExit(1) from exc
|
||||||
exec_result = result["exec_result"]
|
exec_result = result["exec_result"]
|
||||||
if not isinstance(exec_result, dict):
|
if not isinstance(exec_result, dict):
|
||||||
raise RuntimeError("demo produced invalid execution result")
|
raise RuntimeError("demo produced invalid execution result")
|
||||||
|
|
|
||||||
|
|
@ -128,7 +128,11 @@ def test_run_ollama_tool_demo_happy_path(monkeypatch: pytest.MonkeyPatch) -> Non
|
||||||
assert "git version" in str(result["exec_result"]["stdout"])
|
assert "git version" in str(result["exec_result"]["stdout"])
|
||||||
assert result["final_response"] == "Executed git command in ephemeral VM."
|
assert result["final_response"] == "Executed git command in ephemeral VM."
|
||||||
assert len(result["tool_events"]) == 4
|
assert len(result["tool_events"]) == 4
|
||||||
|
assert any("[model] input user:" in line for line in logs)
|
||||||
|
assert any("[model] output assistant:" in line for line in logs)
|
||||||
|
assert any("[model] tool_call vm_exec" in line for line in logs)
|
||||||
assert any("[tool] calling vm_exec" in line for line in logs)
|
assert any("[tool] calling vm_exec" in line for line in logs)
|
||||||
|
assert any("[tool] result vm_exec " in line for line in logs)
|
||||||
|
|
||||||
|
|
||||||
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
|
def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
|
||||||
|
|
@ -176,6 +180,95 @@ def test_run_ollama_tool_demo_recovers_from_bad_vm_id(
|
||||||
assert int(result["exec_result"]["exit_code"]) == 0
|
assert int(result["exec_result"]["exit_code"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_ollama_tool_demo_resolves_vm_id_placeholder(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A placeholder vm_id in a vm_exec call is replaced by the created VM's id."""
    requests: list[dict[str, Any]] = []

    def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
        assert base_url == "http://localhost:11434/v1"
        requests.append(payload)
        if len(requests) != 1:
            # Every later round ends the conversation with a plain answer.
            closing_message = {
                "role": "assistant",
                "content": "Executed git command in ephemeral VM.",
            }
            return {"choices": [{"message": closing_message}]}

        # First round: the model batches every tool call and refers to the
        # not-yet-known VM id through a placeholder; numbers arrive as strings.
        create_arguments = {
            "profile": "debian-git",
            "vcpu_count": "2",
            "mem_mib": "2048",
        }
        exec_arguments = {
            "vm_id": "<vm_id_returned_by_vm_create>",
            "command": "printf 'git version 2.44.0\\n'",
            "timeout_seconds": "300",
        }
        tool_calls = [
            {"id": "1", "function": {"name": "vm_list_profiles"}},
            {"id": "2", "function": {"name": "vm_list_profiles"}},
            {
                "id": "3",
                "function": {
                    "name": "vm_create",
                    "arguments": json.dumps(create_arguments),
                },
            },
            {
                "id": "4",
                "function": {
                    "name": "vm_exec",
                    "arguments": json.dumps(exec_arguments),
                },
            },
        ]
        opening_message = {
            "role": "assistant",
            "content": "",
            "tool_calls": tool_calls,
        }
        return {"choices": [{"message": opening_message}]}

    monkeypatch.setattr(ollama_demo, "_post_chat_completion", fake_post_chat_completion)

    logs: list[str] = []
    result = ollama_demo.run_ollama_tool_demo(log=logs.append)

    assert result["fallback_used"] is False
    assert int(result["exec_result"]["exit_code"]) == 0
    assert any("resolved vm_id placeholder" in entry for entry in logs)
    exec_event = next(
        recorded for recorded in result["tool_events"] if recorded["tool_name"] == "vm_exec"
    )
    assert exec_event["success"] is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_dispatch_tool_call_vm_exec_autostarts_created_vm(tmp_path: Path) -> None:
    """Dispatching vm_exec against a created-but-never-started VM succeeds."""
    manager = RealVmManager(backend_name="mock", base_dir=tmp_path / "vms")
    vm_id = str(
        manager.create_vm(profile="debian-base", vcpu_count=1, mem_mib=512, ttl_seconds=60)[
            "vm_id"
        ]
    )
    # No explicit vm_start call here: the dispatcher is expected to cope.
    exec_arguments = {
        "vm_id": vm_id,
        "command": "printf 'git version\\n'",
        "timeout_seconds": "30",
    }

    executed = ollama_demo._dispatch_tool_call(manager, "vm_exec", exec_arguments)

    assert int(executed["exit_code"]) == 0
|
||||||
|
|
||||||
|
|
||||||
def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_run_ollama_tool_demo_raises_without_vm_exec(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
|
def fake_post_chat_completion(base_url: str, payload: dict[str, Any]) -> dict[str, Any]:
|
||||||
del base_url, payload
|
del base_url, payload
|
||||||
|
|
@ -203,6 +296,7 @@ def test_run_ollama_tool_demo_uses_fallback_when_not_strict(
|
||||||
result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append)
|
result = ollama_demo.run_ollama_tool_demo(strict=False, log=logs.append)
|
||||||
assert result["fallback_used"] is True
|
assert result["fallback_used"] is True
|
||||||
assert int(result["exec_result"]["exit_code"]) == 0
|
assert int(result["exec_result"]["exit_code"]) == 0
|
||||||
|
assert any("[model] output assistant: No tools" in line for line in logs)
|
||||||
assert any("[fallback]" in line for line in logs)
|
assert any("[fallback]" in line for line in logs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -300,7 +394,7 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
|
||||||
created = ollama_demo._dispatch_tool_call(
|
created = ollama_demo._dispatch_tool_call(
|
||||||
manager,
|
manager,
|
||||||
"vm_create",
|
"vm_create",
|
||||||
{"profile": "debian-base", "vcpu_count": 1, "mem_mib": 512},
|
{"profile": "debian-base", "vcpu_count": "1", "mem_mib": "512", "ttl_seconds": "60"},
|
||||||
)
|
)
|
||||||
vm_id = str(created["vm_id"])
|
vm_id = str(created["vm_id"])
|
||||||
started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id})
|
started = ollama_demo._dispatch_tool_call(manager, "vm_start", {"vm_id": vm_id})
|
||||||
|
|
@ -308,7 +402,9 @@ def test_dispatch_tool_call_coverage(tmp_path: Path) -> None:
|
||||||
status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id})
|
status = ollama_demo._dispatch_tool_call(manager, "vm_status", {"vm_id": vm_id})
|
||||||
assert status["vm_id"] == vm_id
|
assert status["vm_id"] == vm_id
|
||||||
executed = ollama_demo._dispatch_tool_call(
|
executed = ollama_demo._dispatch_tool_call(
|
||||||
manager, "vm_exec", {"vm_id": vm_id, "command": "printf 'git version\\n'"}
|
manager,
|
||||||
|
"vm_exec",
|
||||||
|
{"vm_id": vm_id, "command": "printf 'git version\\n'", "timeout_seconds": "30"},
|
||||||
)
|
)
|
||||||
assert int(executed["exit_code"]) == 0
|
assert int(executed["exit_code"]) == 0
|
||||||
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
|
with pytest.raises(RuntimeError, match="unexpected tool requested by model"):
|
||||||
|
|
@ -334,10 +430,22 @@ def test_require_str(arguments: dict[str, Any], error: str) -> None:
|
||||||
ollama_demo._require_str(arguments, "k")
|
ollama_demo._require_str(arguments, "k")
|
||||||
|
|
||||||
|
|
||||||
def test_require_int_validation() -> None:
|
@pytest.mark.parametrize(
    ("arguments", "expected"),
    [
        ({"k": 1}, 1),
        ({"k": "1"}, 1),
        ({"k": " 42 "}, 42),
    ],
)
def test_require_int_accepts_numeric_strings(arguments: dict[str, Any], expected: int) -> None:
    """Real ints and digit strings (optionally padded with spaces) are accepted."""
    parsed = ollama_demo._require_int(arguments, "k")
    assert parsed == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", ["", "abc", "2.5", True])
def test_require_int_validation(value: Any) -> None:
    """Empty, non-numeric and fractional strings, and booleans, are rejected."""
    rejected_arguments = {"k": value}
    with pytest.raises(ValueError, match="must be an integer"):
        ollama_demo._require_int(rejected_arguments, "k")
|
||||||
assert ollama_demo._require_int({"k": 1}, "k") == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
def test_post_chat_completion_success(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||||
|
|
@ -444,3 +552,26 @@ def test_main_uses_parser_and_prints_logs(
|
||||||
output = capsys.readouterr().out
|
output = capsys.readouterr().out
|
||||||
assert "[summary] exit_code=0 fallback_used=False" in output
|
assert "[summary] exit_code=0 fallback_used=False" in output
|
||||||
assert "[summary] stdout=git version 2.44.0" in output
|
assert "[summary] stdout=git version 2.44.0" in output
|
||||||
|
|
||||||
|
|
||||||
|
def test_main_logs_error_and_exits_nonzero(
    monkeypatch: pytest.MonkeyPatch,
    capsys: pytest.CaptureFixture[str],
) -> None:
    """A failing demo run makes main() print an [error] line and exit with status 1."""

    class StubParser:
        def parse_args(self) -> argparse.Namespace:
            return argparse.Namespace(base_url="http://x", model="m")

    def fake_run(base_url: str, model: str, strict: bool = True, log: Any = None) -> dict[str, Any]:
        del base_url, model, strict, log
        raise RuntimeError("demo did not execute a successful vm_exec")

    # Calling the class yields a fresh stub, so it can stand in for the factory.
    monkeypatch.setattr(ollama_demo, "_build_parser", StubParser)
    monkeypatch.setattr(ollama_demo, "run_ollama_tool_demo", fake_run)

    with pytest.raises(SystemExit, match="1"):
        ollama_demo.main()

    captured_stdout = capsys.readouterr().out
    assert "[error] demo did not execute a successful vm_exec" in captured_stdout
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue