70 lines
2.3 KiB
Python
70 lines
2.3 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib.util
|
|
import sys
|
|
from pathlib import Path
|
|
from types import ModuleType, SimpleNamespace
|
|
from typing import Any, cast
|
|
|
|
import pytest
|
|
|
|
|
|
def _load_openai_example_module() -> ModuleType:
    """Import ``examples/openai_responses_vm_run.py`` straight from its file.

    The path is relative, so it is resolved against the current working
    directory. Each call executes the file again and returns a new module
    object named ``openai_responses_vm_run``; this function does not add it
    to ``sys.modules``.

    Raises:
        AssertionError: if no import spec, or no loader, is available for
            the file.
    """
    example_path = Path("examples/openai_responses_vm_run.py")
    module_spec = importlib.util.spec_from_file_location(
        "openai_responses_vm_run", example_path
    )
    # A missing spec and a missing loader are reported with the same message.
    if module_spec is None:
        raise AssertionError("failed to load OpenAI example module")
    loader = module_spec.loader
    if loader is None:
        raise AssertionError("failed to load OpenAI example module")

    example_module = importlib.util.module_from_spec(module_spec)
    loader.exec_module(example_module)
    return example_module
|
|
|
|
|
|
def test_openai_example_tool_targets_vm_run() -> None:
    """Check the name, type, and strict flag of ``OPENAI_VM_RUN_TOOL``."""
    tool_definition = _load_openai_example_module().OPENAI_VM_RUN_TOOL

    assert tool_definition["name"] == "vm_run"
    assert tool_definition["type"] == "function"
    # Identity check: a merely truthy value must not pass.
    assert tool_definition["strict"] is True
|
|
|
|
|
|
def test_openai_example_runs_function_call_loop(monkeypatch: pytest.MonkeyPatch) -> None:
    """Drive the example end to end against a stubbed ``openai`` module.

    The fake client replies with one ``vm_run`` function call and then with
    plain text. The test checks that the example returns that text, offers
    the ``vm_run`` tool on the first request, and ends the second request's
    input with a ``function_call_output`` item carrying the same call id.
    """
    example = _load_openai_example_module()

    # Keyword arguments of every FakeResponses.create() call, in call order.
    recorded_calls: list[dict[str, Any]] = []

    raw_arguments = (
        '{"environment":"debian:12","command":"git --version",'
        '"vcpu_count":1,"mem_mib":1024}'
    )
    function_call = SimpleNamespace(
        type="function_call",
        name="vm_run",
        call_id="call_123",
        arguments=raw_arguments,
    )
    # Replies are handed out front to back, one per create() call.
    scripted_responses = [
        SimpleNamespace(output=[function_call], output_text=""),
        SimpleNamespace(output=[], output_text="git version 2.40.1"),
    ]

    class FakeResponses:
        def create(self, **kwargs: Any) -> Any:
            recorded_calls.append(kwargs)
            return scripted_responses.pop(0)

    class FakeOpenAI:
        def __init__(self) -> None:
            self.responses = FakeResponses()

    def fake_call_vm_run(arguments: Any) -> dict[str, Any]:
        return {"exit_code": 0, "stdout": f"ran {arguments['command']}"}

    # ModuleType declares no OpenAI attribute, so assign through Any.
    openai_stub = ModuleType("openai")
    cast(Any, openai_stub).OpenAI = FakeOpenAI
    monkeypatch.setitem(sys.modules, "openai", openai_stub)
    monkeypatch.setattr(example, "call_vm_run", fake_call_vm_run)

    final_text = example.run_openai_vm_run_example(prompt="run git --version")

    assert final_text == "git version 2.40.1"

    first_request = recorded_calls[0]
    assert first_request["tools"][0]["name"] == "vm_run"

    last_input_item = recorded_calls[1]["input"][-1]
    assert last_input_item["type"] == "function_call_output"
    assert last_input_item["call_id"] == "call_123"
|