Harden Ollama demo tool-call handling and logging
This commit is contained in:
parent
65f7c0d262
commit
fb8b985049
2 changed files with 219 additions and 22 deletions
|
|
@ -153,9 +153,15 @@ def _require_str(arguments: dict[str, Any], key: str) -> str:
|
|||
|
||||
def _require_int(arguments: dict[str, Any], key: str) -> int:
|
||||
value = arguments.get(key)
|
||||
if not isinstance(value, int):
|
||||
if isinstance(value, bool):
|
||||
raise ValueError(f"{key} must be an integer")
|
||||
return value
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
if normalized.isdigit():
|
||||
return int(normalized)
|
||||
raise ValueError(f"{key} must be an integer")
|
||||
|
||||
|
||||
def _dispatch_tool_call(
|
||||
|
|
@ -164,23 +170,26 @@ def _dispatch_tool_call(
|
|||
if tool_name == "vm_list_profiles":
|
||||
return {"profiles": manager.list_profiles()}
|
||||
if tool_name == "vm_create":
|
||||
ttl_seconds = arguments.get("ttl_seconds", 600)
|
||||
return manager.create_vm(
|
||||
profile=_require_str(arguments, "profile"),
|
||||
vcpu_count=_require_int(arguments, "vcpu_count"),
|
||||
mem_mib=_require_int(arguments, "mem_mib"),
|
||||
ttl_seconds=arguments.get("ttl_seconds", 600)
|
||||
if isinstance(arguments.get("ttl_seconds"), int)
|
||||
else 600,
|
||||
ttl_seconds=_require_int({"ttl_seconds": ttl_seconds}, "ttl_seconds"),
|
||||
)
|
||||
if tool_name == "vm_start":
|
||||
return manager.start_vm(_require_str(arguments, "vm_id"))
|
||||
if tool_name == "vm_exec":
|
||||
timeout_seconds = arguments.get("timeout_seconds", 30)
|
||||
vm_id = _require_str(arguments, "vm_id")
|
||||
status = manager.status_vm(vm_id)
|
||||
state = status.get("state")
|
||||
if state in {"created", "stopped"}:
|
||||
manager.start_vm(vm_id)
|
||||
return manager.exec_vm(
|
||||
_require_str(arguments, "vm_id"),
|
||||
vm_id,
|
||||
command=_require_str(arguments, "command"),
|
||||
timeout_seconds=arguments.get("timeout_seconds", 30)
|
||||
if isinstance(arguments.get("timeout_seconds"), int)
|
||||
else 30,
|
||||
timeout_seconds=_require_int({"timeout_seconds": timeout_seconds}, "timeout_seconds"),
|
||||
)
|
||||
if tool_name == "vm_status":
|
||||
return manager.status_vm(_require_str(arguments, "vm_id"))
|
||||
|
|
@ -204,6 +213,45 @@ def _run_direct_lifecycle_fallback(manager: VmManager) -> dict[str, Any]:
|
|||
return manager.exec_vm(vm_id, command="git --version", timeout_seconds=30)
|
||||
|
||||
|
||||
def _is_vm_id_placeholder(value: str) -> bool:
|
||||
normalized = value.strip().lower()
|
||||
if normalized in {
|
||||
"vm_id_returned_by_vm_create",
|
||||
"<vm_id_returned_by_vm_create>",
|
||||
"{vm_id_returned_by_vm_create}",
|
||||
}:
|
||||
return True
|
||||
return normalized.startswith("<") and normalized.endswith(">") and "vm_id" in normalized
|
||||
|
||||
|
||||
def _normalize_tool_arguments(
|
||||
tool_name: str,
|
||||
arguments: dict[str, Any],
|
||||
*,
|
||||
last_created_vm_id: str | None,
|
||||
) -> tuple[dict[str, Any], str | None]:
|
||||
if tool_name not in {"vm_start", "vm_exec", "vm_status"} or last_created_vm_id is None:
|
||||
return arguments, None
|
||||
vm_id = arguments.get("vm_id")
|
||||
if not isinstance(vm_id, str) or not _is_vm_id_placeholder(vm_id):
|
||||
return arguments, None
|
||||
normalized_arguments = dict(arguments)
|
||||
normalized_arguments["vm_id"] = last_created_vm_id
|
||||
return normalized_arguments, last_created_vm_id
|
||||
|
||||
|
||||
def _summarize_message_for_log(message: dict[str, Any]) -> str:
|
||||
role = str(message.get("role", "unknown"))
|
||||
content = str(message.get("content") or "").strip()
|
||||
if content == "":
|
||||
return f"{role}: <empty>"
|
||||
return f"{role}: {content}"
|
||||
|
||||
|
||||
def _serialize_log_value(value: Any) -> str:
|
||||
return json.dumps(value, sort_keys=True, separators=(",", ":"))
|
||||
|
||||
|
||||
def run_ollama_tool_demo(
|
||||
base_url: str = DEFAULT_OLLAMA_BASE_URL,
|
||||
model: str = DEFAULT_OLLAMA_MODEL,
|
||||
|
|
@ -229,9 +277,10 @@ def run_ollama_tool_demo(
|
|||
]
|
||||
tool_events: list[dict[str, Any]] = []
|
||||
final_response = ""
|
||||
last_created_vm_id: str | None = None
|
||||
|
||||
for round_index in range(1, MAX_TOOL_ROUNDS + 1):
|
||||
emit(f"[ollama] round {round_index}: requesting completion")
|
||||
for _round_index in range(1, MAX_TOOL_ROUNDS + 1):
|
||||
emit(f"[model] input {_summarize_message_for_log(messages[-1])}")
|
||||
response = _post_chat_completion(
|
||||
base_url,
|
||||
{
|
||||
|
|
@ -243,6 +292,7 @@ def run_ollama_tool_demo(
|
|||
},
|
||||
)
|
||||
assistant_message = _extract_message(response)
|
||||
emit(f"[model] output {_summarize_message_for_log(assistant_message)}")
|
||||
tool_calls = assistant_message.get("tool_calls")
|
||||
if not isinstance(tool_calls, list) or not tool_calls:
|
||||
final_response = str(assistant_message.get("content") or "")
|
||||
|
|
@ -270,15 +320,28 @@ def run_ollama_tool_demo(
|
|||
if not isinstance(tool_name, str):
|
||||
raise RuntimeError("tool call function name is invalid")
|
||||
arguments = _parse_tool_arguments(function.get("arguments"))
|
||||
emit(f"[model] tool_call {tool_name} args={arguments}")
|
||||
arguments, normalized_vm_id = _normalize_tool_arguments(
|
||||
tool_name,
|
||||
arguments,
|
||||
last_created_vm_id=last_created_vm_id,
|
||||
)
|
||||
if normalized_vm_id is not None:
|
||||
emit(f"[tool] resolved vm_id placeholder to {normalized_vm_id}")
|
||||
emit(f"[tool] calling {tool_name} with args={arguments}")
|
||||
try:
|
||||
result = _dispatch_tool_call(manager, tool_name, arguments)
|
||||
success = True
|
||||
emit(f"[tool] {tool_name} succeeded")
|
||||
if tool_name == "vm_create":
|
||||
created_vm_id = result.get("vm_id")
|
||||
if isinstance(created_vm_id, str) and created_vm_id != "":
|
||||
last_created_vm_id = created_vm_id
|
||||
except Exception as exc: # noqa: BLE001
|
||||
result = _format_tool_error(tool_name, arguments, exc)
|
||||
success = False
|
||||
emit(f"[tool] {tool_name} failed: {exc}")
|
||||
emit(f"[tool] result {tool_name} {_serialize_log_value(result)}")
|
||||
tool_events.append(
|
||||
{
|
||||
"tool_name": tool_name,
|
||||
|
|
@ -351,12 +414,15 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
def main() -> None:
|
||||
"""CLI entrypoint for Ollama tool-calling demo."""
|
||||
args = _build_parser().parse_args()
|
||||
result = run_ollama_tool_demo(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
strict=False,
|
||||
log=lambda message: print(message, flush=True),
|
||||
)
|
||||
try:
|
||||
result = run_ollama_tool_demo(
|
||||
base_url=args.base_url,
|
||||
model=args.model,
|
||||
log=lambda message: print(message, flush=True),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
print(f"[error] {exc}", flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
exec_result = result["exec_result"]
|
||||
if not isinstance(exec_result, dict):
|
||||
raise RuntimeError("demo produced invalid execution result")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue