103 lines
4.1 KiB
Python
103 lines
4.1 KiB
Python
"""Public CLI for pyro-mcp."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
from typing import Any
|
|
|
|
from pyro_mcp.api import Pyro
|
|
from pyro_mcp.demo import run_demo
|
|
from pyro_mcp.ollama_demo import DEFAULT_OLLAMA_BASE_URL, DEFAULT_OLLAMA_MODEL, run_ollama_tool_demo
|
|
from pyro_mcp.runtime import DEFAULT_PLATFORM, doctor_report
|
|
|
|
|
|
def _print_json(payload: dict[str, Any]) -> None:
    """Write *payload* to stdout as JSON, indented by two spaces with sorted keys."""
    rendered = json.dumps(payload, sort_keys=True, indent=2)
    print(rendered)
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the ``pyro`` command.

    Sub-commands:

    * ``mcp serve`` -- run the MCP server over stdio.
    * ``run`` -- run one command inside an ephemeral VM; ``--profile``,
      ``--vcpu-count`` and ``--mem-mib`` are required, and all remaining
      tokens are collected into ``command_args``.
    * ``doctor`` -- inspect runtime and host diagnostics (``--platform``).
    * ``demo`` -- run a built-in demo; the optional ``ollama`` sub-command
      selects the Ollama MCP demo. ``--network`` belongs to ``demo`` itself,
      not to ``ollama``.
    """
    root = argparse.ArgumentParser(description="pyro CLI for ephemeral Firecracker VMs.")
    commands = root.add_subparsers(dest="command", required=True)

    # pyro mcp serve
    mcp_cmd = commands.add_parser("mcp", help="Run the MCP server.")
    mcp_actions = mcp_cmd.add_subparsers(dest="mcp_command", required=True)
    mcp_actions.add_parser("serve", help="Run the MCP server over stdio.")

    # pyro run --profile P --vcpu-count N --mem-mib M [...] CMD...
    run_cmd = commands.add_parser("run", help="Run one command inside an ephemeral VM.")
    run_cmd.add_argument("--profile", required=True)
    for required_int in ("--vcpu-count", "--mem-mib"):
        run_cmd.add_argument(required_int, type=int, required=True)
    for optional_int, fallback in (("--timeout-seconds", 30), ("--ttl-seconds", 600)):
        run_cmd.add_argument(optional_int, type=int, default=fallback)
    run_cmd.add_argument("--network", action="store_true")
    # Remaining tokens are gathered verbatim; a leading "--" separator may be
    # among them and is stripped later by _require_command.
    run_cmd.add_argument("command_args", nargs=argparse.REMAINDER)

    # pyro doctor [--platform P]
    doctor_cmd = commands.add_parser("doctor", help="Inspect runtime and host diagnostics.")
    doctor_cmd.add_argument("--platform", default=DEFAULT_PLATFORM)

    # pyro demo [--network] [ollama [--base-url U] [--model M] [-v]]
    demo_cmd = commands.add_parser("demo", help="Run built-in demos.")
    demos = demo_cmd.add_subparsers(dest="demo_command")
    demo_cmd.add_argument("--network", action="store_true")
    ollama_cmd = demos.add_parser("ollama", help="Run the Ollama MCP demo.")
    ollama_cmd.add_argument("--base-url", default=DEFAULT_OLLAMA_BASE_URL)
    ollama_cmd.add_argument("--model", default=DEFAULT_OLLAMA_MODEL)
    ollama_cmd.add_argument("-v", "--verbose", action="store_true")

    return root
|
|
|
|
|
|
def _require_command(command_args: list[str]) -> str:
    """Turn the trailing tokens of ``pyro run`` into one command string.

    A leading ``--`` separator may reach this function as the first token; only
    that first one is dropped, because later ``--`` tokens belong to the guest
    command.

    Args:
        command_args: Raw tokens collected after the ``run`` options.

    Returns:
        The tokens joined by single spaces. Original shell quoting is not
        preserved; a single token containing spaces is passed through as-is.

    Raises:
        ValueError: If no command remains, or the command consists only of
            empty/whitespace tokens.
    """
    if command_args and command_args[0] == "--":
        command_args = command_args[1:]
    command = " ".join(command_args)
    # Checking the joined string (not just list emptiness) also rejects inputs
    # such as `pyro run -- ""`, which would otherwise yield a blank command.
    if not command.strip():
        raise ValueError("command is required after `pyro run --`")
    return command
|
|
|
|
|
|
def main() -> None:
    """Entry point for the ``pyro`` command-line interface.

    Dispatches on the parsed sub-command:

    * ``mcp serve`` -- run the MCP server over stdio.
    * ``run`` -- execute one command in an ephemeral VM and print the result
      as JSON.
    * ``doctor`` -- print the runtime/host diagnostics report as JSON.
    * ``demo ollama`` -- run the Ollama tool demo, streaming its log lines and
      then printing a one-line ``[summary]``.
    * ``demo`` (no demo sub-command) -- run the built-in demo and print its
      result as JSON.

    Raises:
        ValueError: ``pyro run`` was invoked without a command.
        SystemExit: With status 1 when the Ollama demo raises.
        RuntimeError: The Ollama demo result lacks an ``exec_result`` dict.
    """
    args = _build_parser().parse_args()
    if args.command == "mcp":
        # "serve" is the only (and required) mcp sub-command.
        Pyro().create_server().run(transport="stdio")
        return
    if args.command == "run":
        command = _require_command(args.command_args)
        result = Pyro().run_in_vm(
            profile=args.profile,
            command=command,
            vcpu_count=args.vcpu_count,
            mem_mib=args.mem_mib,
            timeout_seconds=args.timeout_seconds,
            ttl_seconds=args.ttl_seconds,
            network=args.network,
        )
        _print_json(result)
        return
    if args.command == "doctor":
        _print_json(doctor_report(platform=args.platform))
        return
    if args.command == "demo" and args.demo_command == "ollama":
        try:
            result = run_ollama_tool_demo(
                base_url=args.base_url,
                model=args.model,
                verbose=args.verbose,
                log=lambda message: print(message, flush=True),
            )
        except Exception as exc:  # noqa: BLE001
            # Top-level CLI boundary: report any demo failure as one line and
            # exit with status 1 instead of dumping a traceback.
            print(f"[error] {exc}", flush=True)
            raise SystemExit(1) from exc
        # .get() so a missing key is reported by the same validation error as a
        # malformed value, rather than escaping as a bare KeyError.
        exec_result = result.get("exec_result")
        if not isinstance(exec_result, dict):
            raise RuntimeError("demo produced invalid execution result")
        print(
            f"[summary] exit_code={int(exec_result.get('exit_code', -1))} "
            f"fallback_used={bool(result.get('fallback_used'))} "
            f"execution_mode={str(exec_result.get('execution_mode', 'unknown'))}",
            flush=True,
        )
        if args.verbose:
            print(f"[summary] stdout={str(exec_result.get('stdout', '')).strip()}", flush=True)
        return
    # Remaining case: `pyro demo` without a demo sub-command.
    result = run_demo(network=bool(args.network))
    _print_json(result)
|