94 lines
3.1 KiB
Python
94 lines
3.1 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import io
|
|
import tomllib
|
|
from contextlib import redirect_stdout
|
|
from pathlib import Path
|
|
from typing import Any, cast
|
|
|
|
import pytest
|
|
|
|
from pyro_mcp import Pyro, __version__
|
|
from pyro_mcp.cli import _build_parser
|
|
from pyro_mcp.contract import (
|
|
PUBLIC_CLI_COMMANDS,
|
|
PUBLIC_CLI_DEMO_SUBCOMMANDS,
|
|
PUBLIC_CLI_RUN_FLAGS,
|
|
PUBLIC_MCP_TOOLS,
|
|
PUBLIC_SDK_METHODS,
|
|
)
|
|
from pyro_mcp.vm_manager import VmManager
|
|
from pyro_mcp.vm_network import TapNetworkManager
|
|
|
|
|
|
def _subparser_choice(parser: argparse.ArgumentParser, name: str) -> argparse.ArgumentParser:
|
|
subparsers = getattr(parser, "_subparsers", None)
|
|
if subparsers is None:
|
|
raise AssertionError("parser does not define subparsers")
|
|
group_actions = cast(list[Any], subparsers._group_actions) # noqa: SLF001
|
|
if not group_actions:
|
|
raise AssertionError("parser subparsers are empty")
|
|
choices = cast(dict[str, argparse.ArgumentParser], group_actions[0].choices)
|
|
return choices[name]
|
|
|
|
|
|
def test_public_sdk_methods_exist() -> None:
    """The SDK method contract is kept sorted and every entry exists on ``Pyro``."""
    canonical_order = tuple(sorted(PUBLIC_SDK_METHODS))
    assert canonical_order == PUBLIC_SDK_METHODS

    # Stop at the first contract name that Pyro does not expose and report it.
    first_missing = next(
        (name for name in PUBLIC_SDK_METHODS if not hasattr(Pyro, name)),
        None,
    )
    assert first_missing is None, first_missing
|
|
|
|
|
|
def test_public_cli_help_lists_commands_and_run_flags() -> None:
    """The CLI help surfaces every public command, run flag and demo subcommand.

    A single parser instance is used both for parsing and for rendering help,
    because ``parse_args`` does not reconfigure the parser.
    """
    parser = _build_parser()

    help_text = parser.format_help()
    assert "--version" in help_text
    for command_name in PUBLIC_CLI_COMMANDS:
        assert command_name in help_text

    # Sanity-check that a representative ``run`` invocation parses and is
    # dispatched to the ``run`` command.
    run_args = parser.parse_args(
        ["run", "--profile", "debian-base", "--vcpu-count", "1", "--mem-mib", "512", "--", "true"]
    )
    assert run_args.command == "run"

    run_help_text = _subparser_choice(parser, "run").format_help()
    for flag in PUBLIC_CLI_RUN_FLAGS:
        assert flag in run_help_text

    demo_help_text = _subparser_choice(parser, "demo").format_help()
    for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:
        assert subcommand_name in demo_help_text
|
|
|
|
|
|
def test_public_cli_version_matches_package_version() -> None:
    """``--version`` exits with status 0 and prints the package ``__version__``."""
    parser = _build_parser()
    stdout = io.StringIO()
    with pytest.raises(SystemExit) as excinfo, redirect_stdout(stdout):
        parser.parse_args(["--version"])
    # Check the exit code directly. ``match="0"`` is a regex search over
    # str(exc), which would also accept exit codes such as 10 or 20.
    assert excinfo.value.code == 0
    assert stdout.getvalue().strip().endswith(f" {__version__}")
|
|
|
|
|
|
def test_public_mcp_tools_match_contract(tmp_path: Path) -> None:
    """The MCP server registers exactly the tools named in the public contract.

    Uses the mock VM backend with networking disabled, storing VM state under
    the per-test ``tmp_path``.
    """
    manager = VmManager(
        backend_name="mock",
        base_dir=tmp_path / "vms",
        network_manager=TapNetworkManager(enabled=False),
    )
    pyro = Pyro(manager=manager)

    async def _collect_tool_names() -> tuple[str, ...]:
        server = pyro.create_server()
        listed = await server.list_tools()
        names = sorted(tool.name for tool in listed)
        return tuple(names)

    expected = tuple(sorted(PUBLIC_MCP_TOOLS))
    assert asyncio.run(_collect_tool_names()) == expected
|
|
|
|
|
|
def test_pyproject_exposes_single_public_cli_script() -> None:
    """The project declares exactly one console script, ``pyro``.

    ``pyproject.toml`` is found by walking up from this test file, so the test
    does not depend on the directory pytest is launched from. The
    CWD-relative path is used only as a fallback when no ancestor has one.
    """
    pyproject_path = next(
        (
            directory / "pyproject.toml"
            for directory in Path(__file__).resolve().parents
            if (directory / "pyproject.toml").is_file()
        ),
        Path("pyproject.toml"),
    )
    pyproject = tomllib.loads(pyproject_path.read_text(encoding="utf-8"))
    scripts = pyproject["project"]["scripts"]
    assert scripts == {"pyro": "pyro_mcp.cli:main"}
|