diff --git a/CHANGELOG.md b/CHANGELOG.md index 6cfcb11..217b6e6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ All notable user-visible changes to `pyro-mcp` are documented here. +## 2.1.0 + +- Added a first alpha of persistent task workspaces across the CLI, Python SDK, and MCP server. +- Shipped `task create`, `task exec`, `task status`, `task logs`, and `task delete` as an additive + surface alongside the existing one-shot VM contract. +- Made task workspaces persistent across separate CLI/SDK/MCP processes by storing task records on + disk under the runtime base directory. +- Added per-task command journaling so repeated workspace commands can be inspected through + `pyro task logs` or the matching SDK/MCP methods. + ## 2.0.1 - Fixed the default `pyro env pull` path so empty local profile directories no longer produce diff --git a/README.md b/README.md index 11f1e58..b4ba65e 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # pyro-mcp -`pyro-mcp` runs commands inside ephemeral Firecracker microVMs using curated Linux environments such as `debian:12`. +`pyro-mcp` runs one-shot commands and persistent task workspaces inside ephemeral Firecracker microVMs using curated Linux environments such as `debian:12`. 
[![PyPI version](https://img.shields.io/pypi/v/pyro-mcp.svg)](https://pypi.org/project/pyro-mcp/) @@ -18,7 +18,7 @@ It exposes the same runtime in three public forms: - First run transcript: [docs/first-run.md](docs/first-run.md) - Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif) - PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/) -- What's new in 2.0.1: [CHANGELOG.md#201](CHANGELOG.md#201) +- What's new in 2.1.0: [CHANGELOG.md#210](CHANGELOG.md#210) - Host requirements: [docs/host-requirements.md](docs/host-requirements.md) - Integration targets: [docs/integrations.md](docs/integrations.md) - Public contract: [docs/public-contract.md](docs/public-contract.md) @@ -55,7 +55,7 @@ What success looks like: ```bash Platform: linux-x86_64 Runtime: PASS -Catalog version: 2.0.0 +Catalog version: 2.1.0 ... [pull] phase=install environment=debian:12 [pull] phase=ready environment=debian:12 @@ -74,6 +74,7 @@ access to `registry-1.docker.io`, and needs local cache space for the guest imag After the quickstart works: - prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo` +- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12` - move to Python or MCP via [docs/integrations.md](docs/integrations.md) ## Supported Hosts @@ -127,7 +128,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.0.0 +Catalog version: 2.1.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -191,11 +192,27 @@ When you are done evaluating and want to remove stale cached environments, run ` If you prefer a fuller copy-pasteable transcript, see [docs/first-run.md](docs/first-run.md). 
The walkthrough GIF above was rendered from [docs/assets/first-run.tape](docs/assets/first-run.tape) using [scripts/render_tape.sh](scripts/render_tape.sh). +## Persistent Tasks + +Use `pyro run` for one-shot commands. Use `pyro task ...` when you need repeated commands in one +workspace without recreating the sandbox every time. + +```bash +pyro task create debian:12 +pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt' +pyro task exec TASK_ID -- cat note.txt +pyro task logs TASK_ID +pyro task delete TASK_ID +``` + +Task workspaces start in `/workspace` and keep command history until you delete them. For machine +consumption, add `--json` and read the returned `task_id`. + ## Public Interfaces The public user-facing interface is `pyro` and `Pyro`. After the CLI validation path works, you can choose one of three surfaces: -- `pyro` for direct CLI usage +- `pyro` for direct CLI usage, including one-shot `run` and persistent `task` workflows - `from pyro_mcp import Pyro` for Python orchestration - `pyro mcp serve` for MCP clients @@ -325,6 +342,22 @@ print(pyro.list_environments()) print(pyro.inspect_environment("debian:12")) ``` +For repeated commands in one workspace: + +```python +from pyro_mcp import Pyro + +pyro = Pyro() +task = pyro.create_task(environment="debian:12") +task_id = task["task_id"] +try: + pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt") + result = pyro.exec_task(task_id, command="cat note.txt") + print(result["stdout"], end="") +finally: + pyro.delete_task(task_id) +``` + ## MCP Tools Primary agent-facing tool: @@ -343,10 +376,19 @@ Advanced lifecycle tools: - `vm_network_info(vm_id)` - `vm_reap_expired()` +Persistent workspace tools: + +- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)` +- `task_exec(task_id, command, timeout_seconds=30)` +- `task_status(task_id)` +- `task_logs(task_id)` +- `task_delete(task_id)` + ## Integration Examples 
- Python one-shot SDK example: [examples/python_run.py](examples/python_run.py) - Python lifecycle example: [examples/python_lifecycle.py](examples/python_lifecycle.py) +- Python task workspace example: [examples/python_task.py](examples/python_task.py) - MCP client config example: [examples/mcp_client_config.md](examples/mcp_client_config.md) - Claude Desktop MCP config: [examples/claude_desktop_mcp_config.json](examples/claude_desktop_mcp_config.json) - Cursor MCP config: [examples/cursor_mcp_config.json](examples/cursor_mcp_config.json) diff --git a/docs/first-run.md b/docs/first-run.md index f76167d..5b699bf 100644 --- a/docs/first-run.md +++ b/docs/first-run.md @@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes ```bash $ uvx --from pyro-mcp pyro env list -Catalog version: 2.0.0 +Catalog version: 2.1.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -70,11 +70,32 @@ deterministic structured result. ```bash $ uvx --from pyro-mcp pyro demo +$ uvx --from pyro-mcp pyro task create debian:12 $ uvx --from pyro-mcp pyro mcp serve ``` `pyro demo` proves the one-shot create/start/exec/delete VM lifecycle works end to end. +When you need repeated commands in one sandbox, switch to `pyro task ...`: + +```bash +$ uvx --from pyro-mcp pyro task create debian:12 +Task: ... +Environment: debian:12 +State: started +Workspace: /workspace +Execution mode: guest_vsock +Resources: 1 vCPU / 1024 MiB +Command count: 0 + +$ uvx --from pyro-mcp pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt' +[task-exec] task_id=... sequence=1 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=... 
+ +$ uvx --from pyro-mcp pyro task exec TASK_ID -- cat note.txt +hello from task +[task-exec] task_id=... sequence=2 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=... +``` + Example output: ```json diff --git a/docs/install.md b/docs/install.md index 977fc56..6332804 100644 --- a/docs/install.md +++ b/docs/install.md @@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.0.0 +Catalog version: 2.1.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -174,10 +174,26 @@ pyro run debian:12 -- git --version After the CLI path works, you can move on to: +- persistent workspaces: `pyro task create debian:12` - MCP: `pyro mcp serve` - Python SDK: `from pyro_mcp import Pyro` - Demos: `pyro demo` or `pyro demo --network` +## Persistent Task Workspace + +Use `pyro task ...` when you need repeated commands in one sandbox instead of one-shot `pyro run`. + +```bash +pyro task create debian:12 +pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt' +pyro task exec TASK_ID -- cat note.txt +pyro task logs TASK_ID +pyro task delete TASK_ID +``` + +Task commands default to the persistent `/workspace` directory inside the guest. If you need the +task identifier programmatically, use `--json` and read the `task_id` field. + ## Contributor Clone ```bash diff --git a/docs/integrations.md b/docs/integrations.md index 5501508..242907c 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -7,7 +7,7 @@ CLI path in [install.md](install.md) or [first-run.md](first-run.md). ## Recommended Default -Use `vm_run` first. +Use `vm_run` first for one-shot commands. 
That keeps the model-facing contract small: @@ -16,7 +16,8 @@ That keeps the model-facing contract small: - one ephemeral VM - automatic cleanup -Only move to lifecycle tools when the agent truly needs VM state across multiple calls. +Move to `task_*` only when the agent truly needs repeated commands in one workspace across +multiple calls. ## OpenAI Responses API @@ -29,6 +30,7 @@ Best when: Recommended surface: - `vm_run` +- `task_create` + `task_exec` when the agent needs persistent workspace state Canonical example: @@ -63,17 +65,20 @@ Best when: Recommended default: - `Pyro.run_in_vm(...)` +- `Pyro.create_task(...)` + `Pyro.exec_task(...)` when repeated workspace commands are required Lifecycle note: - `Pyro.exec_vm(...)` runs one command and auto-cleans the VM afterward - use `create_vm(...)` + `start_vm(...)` only when you need pre-exec inspection or status before that final exec +- use `create_task(...)` when the agent needs repeated commands in one persistent `/workspace` Examples: - [examples/python_run.py](../examples/python_run.py) - [examples/python_lifecycle.py](../examples/python_lifecycle.py) +- [examples/python_task.py](../examples/python_task.py) ## Agent Framework Wrappers @@ -91,8 +96,8 @@ Best when: Recommended pattern: - keep the framework wrapper thin -- map framework tool input directly onto `vm_run` -- avoid exposing lifecycle tools unless the framework truly needs them +- map one-shot framework tool input directly onto `vm_run` +- expose `task_*` only when the framework truly needs repeated commands in one workspace Concrete example: diff --git a/docs/public-contract.md b/docs/public-contract.md index b0150b1..3ecfe9b 100644 --- a/docs/public-contract.md +++ b/docs/public-contract.md @@ -19,6 +19,11 @@ Top-level commands: - `pyro env prune` - `pyro mcp serve` - `pyro run` +- `pyro task create` +- `pyro task exec` +- `pyro task status` +- `pyro task logs` +- `pyro task delete` - `pyro doctor` - `pyro demo` - `pyro demo ollama` @@ -40,6 
+45,9 @@ Behavioral guarantees: - `pyro run` fails if guest boot or guest exec is unavailable unless `--allow-host-compat` is set. - `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`. - `pyro demo ollama` prints log lines plus a final summary line. +- `pyro task create` auto-starts a persistent workspace. +- `pyro task exec` runs in the persistent `/workspace` for that task and does not auto-clean. +- `pyro task logs` returns persisted command history for that task until `pyro task delete`. ## Python SDK Contract @@ -56,11 +64,16 @@ Supported public entrypoints: - `Pyro.inspect_environment(environment)` - `Pyro.prune_environments()` - `Pyro.create_vm(...)` +- `Pyro.create_task(...)` - `Pyro.start_vm(vm_id)` - `Pyro.exec_vm(vm_id, *, command, timeout_seconds=30)` +- `Pyro.exec_task(task_id, *, command, timeout_seconds=30)` - `Pyro.stop_vm(vm_id)` - `Pyro.delete_vm(vm_id)` +- `Pyro.delete_task(task_id)` - `Pyro.status_vm(vm_id)` +- `Pyro.status_task(task_id)` +- `Pyro.logs_task(task_id)` - `Pyro.network_info_vm(vm_id)` - `Pyro.reap_expired()` - `Pyro.run_in_vm(...)` @@ -73,11 +86,16 @@ Stable public method names: - `inspect_environment(environment)` - `prune_environments()` - `create_vm(...)` +- `create_task(...)` - `start_vm(vm_id)` - `exec_vm(vm_id, *, command, timeout_seconds=30)` +- `exec_task(task_id, *, command, timeout_seconds=30)` - `stop_vm(vm_id)` - `delete_vm(vm_id)` +- `delete_task(task_id)` - `status_vm(vm_id)` +- `status_task(task_id)` +- `logs_task(task_id)` - `network_info_vm(vm_id)` - `reap_expired()` - `run_in_vm(...)` @@ -85,8 +103,11 @@ Stable public method names: Behavioral defaults: - `Pyro.create_vm(...)` and `Pyro.run_in_vm(...)` default to `vcpu_count=1` and `mem_mib=1024`. +- `Pyro.create_task(...)` defaults to `vcpu_count=1` and `mem_mib=1024`. 
- `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`. +- `allow_host_compat` defaults to `False` on `create_task(...)`. - `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes. +- `Pyro.exec_task(...)` runs one command in the persistent task workspace and leaves the task alive. ## MCP Contract @@ -106,11 +127,22 @@ Advanced lifecycle tools: - `vm_network_info` - `vm_reap_expired` +Task workspace tools: + +- `task_create` +- `task_exec` +- `task_status` +- `task_logs` +- `task_delete` + Behavioral defaults: - `vm_run` and `vm_create` default to `vcpu_count=1` and `mem_mib=1024`. +- `task_create` defaults to `vcpu_count=1` and `mem_mib=1024`. - `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`. +- `task_create` exposes `allow_host_compat`, which defaults to `false`. - `vm_exec` runs one command and auto-cleans that VM after the exec completes. +- `task_exec` runs one command in a persistent `/workspace` and leaves the task alive. 
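The persisted command history described above can be rendered much like the CLI's human output. A minimal sketch against a synthetic payload, assuming the entry fields used by `_print_task_logs_human` in this diff (`sequence`, `exit_code`, `duration_ms`, `cwd`, `command`, `stdout`); the values are illustrative and `render` is a simplified stand-in, not the shipped function:

```python
# Synthetic `task logs` entries mirroring the field names in the diff.
entries = [
    {"sequence": 1, "exit_code": 0, "duration_ms": 12, "cwd": "/workspace",
     "command": "printf 'hello from task\\n' > note.txt", "stdout": "", "stderr": ""},
    {"sequence": 2, "exit_code": 0, "duration_ms": 8, "cwd": "/workspace",
     "command": "cat note.txt", "stdout": "hello from task\n", "stderr": ""},
]


def render(entries: list[dict]) -> str:
    # One header line per command, then the command itself, then any stdout.
    lines = []
    for entry in entries:
        lines.append(
            f"#{entry['sequence']} exit_code={entry['exit_code']} "
            f"duration_ms={entry['duration_ms']} cwd={entry['cwd']}"
        )
        lines.append(f"$ {entry['command']}")
        if entry["stdout"]:
            lines.append(entry["stdout"].rstrip("\n"))
    return "\n".join(lines)


print(render(entries))
```

Because entries carry a monotonically increasing `sequence` plus exit codes and durations, a journal like this is enough to replay or audit a workspace session without the VM still being alive.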
## Versioning Rule diff --git a/examples/python_task.py b/examples/python_task.py new file mode 100644 index 0000000..c1c7b98 --- /dev/null +++ b/examples/python_task.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from pyro_mcp import Pyro + + +def main() -> None: + pyro = Pyro() + created = pyro.create_task(environment="debian:12") + task_id = str(created["task_id"]) + try: + pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt") + result = pyro.exec_task(task_id, command="cat note.txt") + print(result["stdout"], end="") + logs = pyro.logs_task(task_id) + print(f"task_id={task_id} command_count={logs['count']}") + finally: + pyro.delete_task(task_id) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index c13a4e7..6a50302 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pyro-mcp" -version = "2.0.1" +version = "2.1.0" description = "Curated Linux environments for ephemeral Firecracker-backed VM execution." 
readme = "README.md" license = { file = "LICENSE" } diff --git a/src/pyro_mcp/api.py b/src/pyro_mcp/api.py index e2a0b1d..82f3abf 100644 --- a/src/pyro_mcp/api.py +++ b/src/pyro_mcp/api.py @@ -77,6 +77,43 @@ class Pyro: def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]: return self._manager.exec_vm(vm_id, command=command, timeout_seconds=timeout_seconds) + def create_task( + self, + *, + environment: str, + vcpu_count: int = DEFAULT_VCPU_COUNT, + mem_mib: int = DEFAULT_MEM_MIB, + ttl_seconds: int = DEFAULT_TTL_SECONDS, + network: bool = False, + allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + ) -> dict[str, Any]: + return self._manager.create_task( + environment=environment, + vcpu_count=vcpu_count, + mem_mib=mem_mib, + ttl_seconds=ttl_seconds, + network=network, + allow_host_compat=allow_host_compat, + ) + + def exec_task( + self, + task_id: str, + *, + command: str, + timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, + ) -> dict[str, Any]: + return self._manager.exec_task(task_id, command=command, timeout_seconds=timeout_seconds) + + def status_task(self, task_id: str) -> dict[str, Any]: + return self._manager.status_task(task_id) + + def logs_task(self, task_id: str) -> dict[str, Any]: + return self._manager.logs_task(task_id) + + def delete_task(self, task_id: str) -> dict[str, Any]: + return self._manager.delete_task(task_id) + def stop_vm(self, vm_id: str) -> dict[str, Any]: return self._manager.stop_vm(vm_id) @@ -200,4 +237,47 @@ class Pyro: """Delete VMs whose TTL has expired.""" return self.reap_expired() + @server.tool() + async def task_create( + environment: str, + vcpu_count: int = DEFAULT_VCPU_COUNT, + mem_mib: int = DEFAULT_MEM_MIB, + ttl_seconds: int = DEFAULT_TTL_SECONDS, + network: bool = False, + allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + ) -> dict[str, Any]: + """Create and start a persistent task workspace.""" + return self.create_task( + environment=environment, + 
vcpu_count=vcpu_count, + mem_mib=mem_mib, + ttl_seconds=ttl_seconds, + network=network, + allow_host_compat=allow_host_compat, + ) + + @server.tool() + async def task_exec( + task_id: str, + command: str, + timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS, + ) -> dict[str, Any]: + """Run one command inside an existing task workspace.""" + return self.exec_task(task_id, command=command, timeout_seconds=timeout_seconds) + + @server.tool() + async def task_status(task_id: str) -> dict[str, Any]: + """Inspect task state and latest command metadata.""" + return self.status_task(task_id) + + @server.tool() + async def task_logs(task_id: str) -> dict[str, Any]: + """Return persisted command history for one task.""" + return self.logs_task(task_id) + + @server.tool() + async def task_delete(task_id: str) -> dict[str, Any]: + """Delete a task workspace and its backing sandbox.""" + return self.delete_task(task_id) + return server diff --git a/src/pyro_mcp/cli.py b/src/pyro_mcp/cli.py index 39c2434..16c056c 100644 --- a/src/pyro_mcp/cli.py +++ b/src/pyro_mcp/cli.py @@ -17,6 +17,7 @@ from pyro_mcp.vm_environments import DEFAULT_CATALOG_VERSION from pyro_mcp.vm_manager import ( DEFAULT_MEM_MIB, DEFAULT_VCPU_COUNT, + TASK_WORKSPACE_GUEST_PATH, ) @@ -149,6 +150,67 @@ def _print_doctor_human(payload: dict[str, Any]) -> None: print(f"- {issue}") +def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None: + print(f"{action}: {str(payload.get('task_id', 'unknown'))}") + print(f"Environment: {str(payload.get('environment', 'unknown'))}") + print(f"State: {str(payload.get('state', 'unknown'))}") + print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}") + print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}") + print( + f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / " + f"{int(payload.get('mem_mib', 0))} MiB" + ) + print(f"Command count: {int(payload.get('command_count', 0))}") + last_command = payload.get("last_command") + 
if isinstance(last_command, dict): + print( + "Last command: " + f"{str(last_command.get('command', 'unknown'))} " + f"(exit_code={int(last_command.get('exit_code', -1))})" + ) + + +def _print_task_exec_human(payload: dict[str, Any]) -> None: + stdout = str(payload.get("stdout", "")) + stderr = str(payload.get("stderr", "")) + _write_stream(stdout, stream=sys.stdout) + _write_stream(stderr, stream=sys.stderr) + print( + "[task-exec] " + f"task_id={str(payload.get('task_id', 'unknown'))} " + f"sequence={int(payload.get('sequence', 0))} " + f"cwd={str(payload.get('cwd', TASK_WORKSPACE_GUEST_PATH))} " + f"execution_mode={str(payload.get('execution_mode', 'unknown'))} " + f"exit_code={int(payload.get('exit_code', 1))} " + f"duration_ms={int(payload.get('duration_ms', 0))}", + file=sys.stderr, + flush=True, + ) + + +def _print_task_logs_human(payload: dict[str, Any]) -> None: + entries = payload.get("entries") + if not isinstance(entries, list) or not entries: + print("No task logs found.") + return + for entry in entries: + if not isinstance(entry, dict): + continue + print( + f"#{int(entry.get('sequence', 0))} " + f"exit_code={int(entry.get('exit_code', -1))} " + f"duration_ms={int(entry.get('duration_ms', 0))} " + f"cwd={str(entry.get('cwd', TASK_WORKSPACE_GUEST_PATH))}" + ) + print(f"$ {str(entry.get('command', ''))}") + stdout = str(entry.get("stdout", "")) + stderr = str(entry.get("stderr", "")) + if stdout != "": + print(stdout, end="" if stdout.endswith("\n") else "\n") + if stderr != "": + print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr) + + class _HelpFormatter( argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter, @@ -178,6 +240,9 @@ def _build_parser() -> argparse.ArgumentParser: pyro env pull debian:12 pyro run debian:12 -- git --version + Need repeated commands in one workspace after that? + pyro task create debian:12 + Use `pyro mcp serve` only after the CLI validation path works. 
""" ), @@ -371,6 +436,152 @@ def _build_parser() -> argparse.ArgumentParser: ), ) + task_parser = subparsers.add_parser( + "task", + help="Manage persistent task workspaces.", + description=( + "Create a persistent workspace when you need repeated commands in one " + "sandbox instead of one-shot `pyro run`." + ), + epilog=dedent( + """ + Examples: + pyro task create debian:12 + pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt' + pyro task logs TASK_ID + """ + ), + formatter_class=_HelpFormatter, + ) + task_subparsers = task_parser.add_subparsers(dest="task_command", required=True, metavar="TASK") + task_create_parser = task_subparsers.add_parser( + "create", + help="Create and start a persistent task workspace.", + description="Create a task workspace that stays alive across repeated exec calls.", + epilog="Example:\n pyro task create debian:12", + formatter_class=_HelpFormatter, + ) + task_create_parser.add_argument( + "environment", + metavar="ENVIRONMENT", + help="Curated environment to boot, for example `debian:12`.", + ) + task_create_parser.add_argument( + "--vcpu-count", + type=int, + default=DEFAULT_VCPU_COUNT, + help="Number of virtual CPUs to allocate to the task guest.", + ) + task_create_parser.add_argument( + "--mem-mib", + type=int, + default=DEFAULT_MEM_MIB, + help="Guest memory allocation in MiB.", + ) + task_create_parser.add_argument( + "--ttl-seconds", + type=int, + default=600, + help="Time-to-live for the task before automatic cleanup.", + ) + task_create_parser.add_argument( + "--network", + action="store_true", + help="Enable outbound guest networking for the task guest.", + ) + task_create_parser.add_argument( + "--allow-host-compat", + action="store_true", + help=( + "Opt into host-side compatibility execution if guest boot or guest exec " + "is unavailable." 
+ ), + ) + task_create_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + task_exec_parser = task_subparsers.add_parser( + "exec", + help="Run one command inside an existing task workspace.", + description="Run one non-interactive command in the persistent `/workspace` for a task.", + epilog="Example:\n pyro task exec TASK_ID -- cat note.txt", + formatter_class=_HelpFormatter, + ) + task_exec_parser.add_argument("task_id", metavar="TASK_ID", help="Persistent task identifier.") + task_exec_parser.add_argument( + "--timeout-seconds", + type=int, + default=30, + help="Maximum time allowed for the task command.", + ) + task_exec_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + task_exec_parser.add_argument( + "command_args", + nargs="*", + metavar="ARG", + help=( + "Command and arguments to run inside the task workspace. Prefix them with `--`, " + "for example `pyro task exec TASK_ID -- cat note.txt`." 
+ ), + ) + task_status_parser = task_subparsers.add_parser( + "status", + help="Inspect one task workspace.", + description="Show task state, sizing, workspace path, and latest command metadata.", + epilog="Example:\n pyro task status TASK_ID", + formatter_class=_HelpFormatter, + ) + task_status_parser.add_argument( + "task_id", + metavar="TASK_ID", + help="Persistent task identifier.", + ) + task_status_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + task_logs_parser = task_subparsers.add_parser( + "logs", + help="Show command history for one task.", + description="Show persisted command history, including stdout and stderr, for one task.", + epilog="Example:\n pyro task logs TASK_ID", + formatter_class=_HelpFormatter, + ) + task_logs_parser.add_argument( + "task_id", + metavar="TASK_ID", + help="Persistent task identifier.", + ) + task_logs_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + task_delete_parser = task_subparsers.add_parser( + "delete", + help="Delete one task workspace.", + description="Stop the backing sandbox if needed and remove the task workspace.", + epilog="Example:\n pyro task delete TASK_ID", + formatter_class=_HelpFormatter, + ) + task_delete_parser.add_argument( + "task_id", + metavar="TASK_ID", + help="Persistent task identifier.", + ) + task_delete_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + doctor_parser = subparsers.add_parser( "doctor", help="Inspect runtime and host diagnostics.", @@ -451,7 +662,7 @@ def _require_command(command_args: list[str]) -> str: if command_args and command_args[0] == "--": command_args = command_args[1:] if not command_args: - raise ValueError("command is required after `pyro run --`") + raise ValueError("command is required after `--`") return " ".join(command_args) @@ 
-544,6 +755,70 @@ def main() -> None: if exit_code != 0: raise SystemExit(exit_code) return + if args.command == "task": + if args.task_command == "create": + payload = pyro.create_task( + environment=args.environment, + vcpu_count=args.vcpu_count, + mem_mib=args.mem_mib, + ttl_seconds=args.ttl_seconds, + network=args.network, + allow_host_compat=args.allow_host_compat, + ) + if bool(args.json): + _print_json(payload) + else: + _print_task_summary_human(payload, action="Task") + return + if args.task_command == "exec": + command = _require_command(args.command_args) + if bool(args.json): + try: + payload = pyro.exec_task( + args.task_id, + command=command, + timeout_seconds=args.timeout_seconds, + ) + except Exception as exc: # noqa: BLE001 + _print_json({"ok": False, "error": str(exc)}) + raise SystemExit(1) from exc + _print_json(payload) + else: + try: + payload = pyro.exec_task( + args.task_id, + command=command, + timeout_seconds=args.timeout_seconds, + ) + except Exception as exc: # noqa: BLE001 + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + _print_task_exec_human(payload) + exit_code = int(payload.get("exit_code", 1)) + if exit_code != 0: + raise SystemExit(exit_code) + return + if args.task_command == "status": + payload = pyro.status_task(args.task_id) + if bool(args.json): + _print_json(payload) + else: + _print_task_summary_human(payload, action="Task") + return + if args.task_command == "logs": + payload = pyro.logs_task(args.task_id) + if bool(args.json): + _print_json(payload) + else: + _print_task_logs_human(payload) + return + if args.task_command == "delete": + payload = pyro.delete_task(args.task_id) + if bool(args.json): + _print_json(payload) + else: + print(f"Deleted task: {str(payload.get('task_id', 'unknown'))}") + return if args.command == "doctor": payload = doctor_report(platform=args.platform) if bool(args.json): diff --git a/src/pyro_mcp/contract.py b/src/pyro_mcp/contract.py index 
131907d..4e6866e 100644 --- a/src/pyro_mcp/contract.py +++ b/src/pyro_mcp/contract.py @@ -2,9 +2,10 @@ from __future__ import annotations -PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run") +PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task") PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",) PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune") +PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status") PUBLIC_CLI_RUN_FLAGS = ( "--vcpu-count", "--mem-mib", @@ -17,17 +18,22 @@ PUBLIC_CLI_RUN_FLAGS = ( PUBLIC_SDK_METHODS = ( "create_server", + "create_task", "create_vm", + "delete_task", "delete_vm", + "exec_task", "exec_vm", "inspect_environment", "list_environments", + "logs_task", "network_info_vm", "prune_environments", "pull_environment", "reap_expired", "run_in_vm", "start_vm", + "status_task", "status_vm", "stop_vm", ) @@ -43,4 +49,9 @@ PUBLIC_MCP_TOOLS = ( "vm_start", "vm_status", "vm_stop", + "task_create", + "task_delete", + "task_exec", + "task_logs", + "task_status", ) diff --git a/src/pyro_mcp/vm_environments.py b/src/pyro_mcp/vm_environments.py index 801789a..4067fe0 100644 --- a/src/pyro_mcp/vm_environments.py +++ b/src/pyro_mcp/vm_environments.py @@ -19,7 +19,7 @@ from typing import Any from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths DEFAULT_ENVIRONMENT_VERSION = "1.0.0" -DEFAULT_CATALOG_VERSION = "2.0.0" +DEFAULT_CATALOG_VERSION = "2.1.0" OCI_MANIFEST_ACCEPT = ", ".join( ( "application/vnd.oci.image.index.v1+json", diff --git a/src/pyro_mcp/vm_manager.py b/src/pyro_mcp/vm_manager.py index 9d005de..e93b4c8 100644 --- a/src/pyro_mcp/vm_manager.py +++ b/src/pyro_mcp/vm_manager.py @@ -1,8 +1,10 @@ -"""Lifecycle manager for ephemeral VM environments.""" +"""Lifecycle manager for ephemeral VM environments and persistent tasks.""" from __future__ import annotations +import json import os +import shlex import shutil import signal import subprocess @@ -11,7 +13,7 @@ import time import uuid from 
dataclasses import dataclass, field from pathlib import Path -from typing import Any, Literal +from typing import Any, Literal, cast from pyro_mcp.runtime import ( RuntimeCapabilities, @@ -32,6 +34,12 @@ DEFAULT_TIMEOUT_SECONDS = 30 DEFAULT_TTL_SECONDS = 600 DEFAULT_ALLOW_HOST_COMPAT = False +TASK_LAYOUT_VERSION = 1 +TASK_WORKSPACE_DIRNAME = "workspace" +TASK_COMMANDS_DIRNAME = "commands" +TASK_RUNTIME_DIRNAME = "runtime" +TASK_WORKSPACE_GUEST_PATH = "/workspace" + @dataclass class VmInstance: @@ -54,6 +62,116 @@ class VmInstance: network: NetworkConfig | None = None +@dataclass +class TaskRecord: + """Persistent task metadata stored on disk.""" + + task_id: str + environment: str + vcpu_count: int + mem_mib: int + ttl_seconds: int + created_at: float + expires_at: float + state: VmState + network_requested: bool + allow_host_compat: bool + firecracker_pid: int | None = None + last_error: str | None = None + metadata: dict[str, str] = field(default_factory=dict) + network: NetworkConfig | None = None + command_count: int = 0 + last_command: dict[str, Any] | None = None + + @classmethod + def from_instance( + cls, + instance: VmInstance, + *, + command_count: int = 0, + last_command: dict[str, Any] | None = None, + ) -> TaskRecord: + return cls( + task_id=instance.vm_id, + environment=instance.environment, + vcpu_count=instance.vcpu_count, + mem_mib=instance.mem_mib, + ttl_seconds=instance.ttl_seconds, + created_at=instance.created_at, + expires_at=instance.expires_at, + state=instance.state, + network_requested=instance.network_requested, + allow_host_compat=instance.allow_host_compat, + firecracker_pid=instance.firecracker_pid, + last_error=instance.last_error, + metadata=dict(instance.metadata), + network=instance.network, + command_count=command_count, + last_command=last_command, + ) + + def to_instance(self, *, workdir: Path) -> VmInstance: + return VmInstance( + vm_id=self.task_id, + environment=self.environment, + vcpu_count=self.vcpu_count, + 
mem_mib=self.mem_mib, + ttl_seconds=self.ttl_seconds, + created_at=self.created_at, + expires_at=self.expires_at, + workdir=workdir, + state=self.state, + network_requested=self.network_requested, + allow_host_compat=self.allow_host_compat, + firecracker_pid=self.firecracker_pid, + last_error=self.last_error, + metadata=dict(self.metadata), + network=self.network, + ) + + def to_payload(self) -> dict[str, Any]: + return { + "layout_version": TASK_LAYOUT_VERSION, + "task_id": self.task_id, + "environment": self.environment, + "vcpu_count": self.vcpu_count, + "mem_mib": self.mem_mib, + "ttl_seconds": self.ttl_seconds, + "created_at": self.created_at, + "expires_at": self.expires_at, + "state": self.state, + "network_requested": self.network_requested, + "allow_host_compat": self.allow_host_compat, + "firecracker_pid": self.firecracker_pid, + "last_error": self.last_error, + "metadata": self.metadata, + "network": _serialize_network(self.network), + "command_count": self.command_count, + "last_command": self.last_command, + } + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> TaskRecord: + return cls( + task_id=str(payload["task_id"]), + environment=str(payload["environment"]), + vcpu_count=int(payload["vcpu_count"]), + mem_mib=int(payload["mem_mib"]), + ttl_seconds=int(payload["ttl_seconds"]), + created_at=float(payload["created_at"]), + expires_at=float(payload["expires_at"]), + state=cast(VmState, str(payload.get("state", "stopped"))), + network_requested=bool(payload.get("network_requested", False)), + allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)), + firecracker_pid=_optional_int(payload.get("firecracker_pid")), + last_error=_optional_str(payload.get("last_error")), + metadata=_string_dict(payload.get("metadata")), + network=_deserialize_network(payload.get("network")), + command_count=int(payload.get("command_count", 0)), + last_command=_optional_dict(payload.get("last_command")), + ) + + 
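The cross-process persistence claimed in the changelog reduces to the `to_payload`/`from_payload` JSON round trip above. A trimmed standalone sketch of that pattern, using a hypothetical `MiniTaskRecord` with only a subset of the real fields:

```python
import json
import tempfile
from dataclasses import dataclass
from pathlib import Path


@dataclass
class MiniTaskRecord:
    # Illustrative subset of the TaskRecord fields in the diff.
    task_id: str
    environment: str
    command_count: int = 0

    def to_payload(self) -> dict:
        return {
            "layout_version": 1,
            "task_id": self.task_id,
            "environment": self.environment,
            "command_count": self.command_count,
        }

    @classmethod
    def from_payload(cls, payload: dict) -> "MiniTaskRecord":
        # Coerce types defensively, as the real from_payload does.
        return cls(
            task_id=str(payload["task_id"]),
            environment=str(payload["environment"]),
            command_count=int(payload.get("command_count", 0)),
        )


# Write in one "process", read back in another: the record survives
# because it lives on disk, not in interpreter memory.
record_path = Path(tempfile.mkdtemp()) / "task.json"
record = MiniTaskRecord(task_id="t-123", environment="debian:12", command_count=2)
record_path.write_text(json.dumps(record.to_payload()), encoding="utf-8")

restored = MiniTaskRecord.from_payload(json.loads(record_path.read_text(encoding="utf-8")))
print(restored == record)  # dataclass equality: True
```

Versioning the payload with `layout_version` up front, as the diff does, is what leaves room to migrate old on-disk records when the schema changes later.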
@dataclass(frozen=True) class VmExecResult: """Command execution output.""" @@ -64,6 +182,72 @@ class VmExecResult: duration_ms: int +def _optional_int(value: object) -> int | None: + if value is None: + return None + if isinstance(value, bool): + return int(value) + if isinstance(value, int): + return value + if isinstance(value, float): + return int(value) + if isinstance(value, str): + return int(value) + raise TypeError("expected integer-compatible payload") + + +def _optional_str(value: object) -> str | None: + if value is None: + return None + return str(value) + + +def _optional_dict(value: object) -> dict[str, Any] | None: + if value is None: + return None + if not isinstance(value, dict): + raise TypeError("expected dictionary payload") + return dict(value) + + +def _string_dict(value: object) -> dict[str, str]: + if not isinstance(value, dict): + return {} + return {str(key): str(item) for key, item in value.items()} + + +def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None: + if network is None: + return None + return { + "vm_id": network.vm_id, + "tap_name": network.tap_name, + "guest_ip": network.guest_ip, + "gateway_ip": network.gateway_ip, + "subnet_cidr": network.subnet_cidr, + "mac_address": network.mac_address, + "dns_servers": list(network.dns_servers), + } + + +def _deserialize_network(payload: object) -> NetworkConfig | None: + if payload is None: + return None + if not isinstance(payload, dict): + raise TypeError("expected dictionary payload") + dns_servers = payload.get("dns_servers", []) + dns_values = tuple(str(item) for item in dns_servers) if isinstance(dns_servers, list) else () + return NetworkConfig( + vm_id=str(payload["vm_id"]), + tap_name=str(payload["tap_name"]), + guest_ip=str(payload["guest_ip"]), + gateway_ip=str(payload["gateway_ip"]), + subnet_cidr=str(payload["subnet_cidr"]), + mac_address=str(payload["mac_address"]), + dns_servers=dns_values, + ) + + def _run_host_command(workdir: Path, command: 
str, timeout_seconds: int) -> VmExecResult: started = time.monotonic() env = {"PATH": os.environ.get("PATH", ""), "HOME": str(workdir)} @@ -109,6 +293,25 @@ def _copy_rootfs(source: Path, dest: Path) -> str: return "copy2" +def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str: + if cwd is None: + return command + quoted_cwd = shlex.quote(cwd) + return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}" + + +def _pid_is_running(pid: int | None) -> bool: + if pid is None: + return False + try: + os.kill(pid, 0) + except ProcessLookupError: + return False + except PermissionError: + return True + return True + + class VmBackend: """Backend interface for lifecycle operations.""" @@ -119,7 +322,12 @@ class VmBackend: raise NotImplementedError def exec( # pragma: no cover - self, instance: VmInstance, command: str, timeout_seconds: int + self, + instance: VmInstance, + command: str, + timeout_seconds: int, + *, + workdir: Path | None = None, ) -> VmExecResult: raise NotImplementedError @@ -140,8 +348,15 @@ class MockBackend(VmBackend): marker_path = instance.workdir / ".started" marker_path.write_text("started\n", encoding="utf-8") - def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult: - return _run_host_command(instance.workdir, command, timeout_seconds) + def exec( + self, + instance: VmInstance, + command: str, + timeout_seconds: int, + *, + workdir: Path | None = None, + ) -> VmExecResult: + return _run_host_command(workdir or instance.workdir, command, timeout_seconds) def stop(self, instance: VmInstance) -> None: marker_path = instance.workdir / ".stopped" @@ -256,6 +471,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover stdout=serial_fp, stderr=subprocess.STDOUT, text=True, + start_new_session=True, ) self._processes[instance.vm_id] = process time.sleep(2) @@ -273,7 +489,14 @@ class FirecrackerBackend(VmBackend): # pragma: no cover ) instance.metadata["boot_mode"] = "native" - def exec(self, 
instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult: + def exec( + self, + instance: VmInstance, + command: str, + timeout_seconds: int, + *, + workdir: Path | None = None, + ) -> VmExecResult: if self._runtime_capabilities.supports_guest_exec: guest_cid = int(instance.metadata["guest_cid"]) port = int(instance.metadata["guest_exec_port"]) @@ -302,7 +525,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover duration_ms=response.duration_ms, ) instance.metadata["execution_mode"] = "host_compat" - return _run_host_command(instance.workdir, command, timeout_seconds) + return _run_host_command(workdir or instance.workdir, command, timeout_seconds) def stop(self, instance: VmInstance) -> None: process = self._processes.pop(instance.vm_id, None) @@ -341,7 +564,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover class VmManager: - """In-process lifecycle manager for ephemeral VM environments.""" + """In-process lifecycle manager for ephemeral VM environments and tasks.""" MIN_VCPUS = 1 MAX_VCPUS = 8 @@ -367,6 +590,7 @@ class VmManager: ) -> None: self._backend_name = backend_name or "firecracker" self._base_dir = base_dir or Path("/tmp/pyro-mcp") + self._tasks_dir = self._base_dir / "tasks" resolved_cache_dir = cache_dir or default_cache_dir() self._runtime_paths = runtime_paths if self._backend_name == "firecracker": @@ -399,6 +623,7 @@ class VmManager: self._lock = threading.Lock() self._instances: dict[str, VmInstance] = {} self._base_dir.mkdir(parents=True, exist_ok=True) + self._tasks_dir.mkdir(parents=True, exist_ok=True) self._backend = self._build_backend() def _build_backend(self) -> VmBackend: @@ -443,7 +668,8 @@ class VmManager: now = time.time() with self._lock: self._reap_expired_locked(now) - active_count = len(self._instances) + self._reap_expired_tasks_locked(now) + active_count = len(self._instances) + self._count_tasks_locked() if active_count >= self._max_active_vms: raise RuntimeError( f"max active VMs reached 
({self._max_active_vms}); delete old VMs first" @@ -501,36 +727,24 @@ class VmManager: with self._lock: instance = self._get_instance_locked(vm_id) self._ensure_not_expired_locked(instance, time.time()) - if instance.state not in {"created", "stopped"}: - raise RuntimeError(f"vm {vm_id} cannot be started from state {instance.state!r}") - self._require_guest_boot_or_opt_in(instance) - if not self._runtime_capabilities.supports_vm_boot: - instance.metadata["execution_mode"] = "host_compat" - instance.metadata["boot_mode"] = "compat" - if self._runtime_capabilities.reason is not None: - instance.metadata["runtime_reason"] = self._runtime_capabilities.reason - self._backend.start(instance) - instance.state = "started" + self._start_instance_locked(instance) return self._serialize(instance) def exec_vm(self, vm_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]: - if timeout_seconds <= 0: - raise ValueError("timeout_seconds must be positive") with self._lock: instance = self._get_instance_locked(vm_id) self._ensure_not_expired_locked(instance, time.time()) - if instance.state != "started": - raise RuntimeError(f"vm {vm_id} must be in 'started' state before vm_exec") - self._require_guest_exec_or_opt_in(instance) - if not self._runtime_capabilities.supports_guest_exec: - instance.metadata["execution_mode"] = "host_compat" - exec_result = self._backend.exec(instance, command, timeout_seconds) - execution_mode = instance.metadata.get("execution_mode", "unknown") + exec_instance = instance + exec_result, execution_mode = self._exec_instance( + exec_instance, + command=command, + timeout_seconds=timeout_seconds, + ) cleanup = self.delete_vm(vm_id, reason="post_exec_cleanup") return { "vm_id": vm_id, - "environment": instance.environment, - "environment_version": instance.metadata.get("environment_version"), + "environment": exec_instance.environment, + "environment_version": exec_instance.metadata.get("environment_version"), "command": command, "stdout": 
exec_result.stdout, "stderr": exec_result.stderr, @@ -591,6 +805,154 @@ class VmManager: del self._instances[vm_id] return {"deleted_vm_ids": expired_vm_ids, "count": len(expired_vm_ids)} + def create_task( + self, + *, + environment: str, + vcpu_count: int = DEFAULT_VCPU_COUNT, + mem_mib: int = DEFAULT_MEM_MIB, + ttl_seconds: int = DEFAULT_TTL_SECONDS, + network: bool = False, + allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + ) -> dict[str, Any]: + self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds) + get_environment(environment, runtime_paths=self._runtime_paths) + now = time.time() + task_id = uuid.uuid4().hex[:12] + task_dir = self._task_dir(task_id) + runtime_dir = self._task_runtime_dir(task_id) + workspace_dir = self._task_workspace_dir(task_id) + commands_dir = self._task_commands_dir(task_id) + task_dir.mkdir(parents=True, exist_ok=False) + workspace_dir.mkdir(parents=True, exist_ok=True) + commands_dir.mkdir(parents=True, exist_ok=True) + instance = VmInstance( + vm_id=task_id, + environment=environment, + vcpu_count=vcpu_count, + mem_mib=mem_mib, + ttl_seconds=ttl_seconds, + created_at=now, + expires_at=now + ttl_seconds, + workdir=runtime_dir, + network_requested=network, + allow_host_compat=allow_host_compat, + ) + instance.metadata["allow_host_compat"] = str(allow_host_compat).lower() + instance.metadata["workspace_path"] = TASK_WORKSPACE_GUEST_PATH + instance.metadata["workspace_host_dir"] = str(workspace_dir) + try: + with self._lock: + self._reap_expired_locked(now) + self._reap_expired_tasks_locked(now) + active_count = len(self._instances) + self._count_tasks_locked() + if active_count >= self._max_active_vms: + raise RuntimeError( + f"max active VMs reached ({self._max_active_vms}); delete old VMs first" + ) + self._backend.create(instance) + with self._lock: + self._start_instance_locked(instance) + self._require_guest_exec_or_opt_in(instance) + if self._runtime_capabilities.supports_guest_exec: + 
self._backend.exec( + instance, + f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}", + 10, + ) + else: + instance.metadata["execution_mode"] = "host_compat" + task = TaskRecord.from_instance(instance) + self._save_task_locked(task) + return self._serialize_task(task) + except Exception: + if runtime_dir.exists(): + try: + if instance.state == "started": + self._backend.stop(instance) + instance.state = "stopped" + except Exception: + pass + try: + self._backend.delete(instance) + except Exception: + pass + shutil.rmtree(task_dir, ignore_errors=True) + raise + + def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]: + if timeout_seconds <= 0: + raise ValueError("timeout_seconds must be positive") + with self._lock: + task = self._load_task_locked(task_id) + self._ensure_task_not_expired_locked(task, time.time()) + self._refresh_task_liveness_locked(task) + if task.state != "started": + raise RuntimeError(f"task {task_id} must be in 'started' state before task_exec") + instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id)) + exec_result, execution_mode = self._exec_instance( + instance, + command=command, + timeout_seconds=timeout_seconds, + host_workdir=self._task_workspace_dir(task.task_id), + guest_cwd=TASK_WORKSPACE_GUEST_PATH, + ) + with self._lock: + task = self._load_task_locked(task_id) + task.state = instance.state + task.firecracker_pid = instance.firecracker_pid + task.last_error = instance.last_error + task.metadata = dict(instance.metadata) + entry = self._record_task_command_locked( + task, + command=command, + exec_result=exec_result, + execution_mode=execution_mode, + cwd=TASK_WORKSPACE_GUEST_PATH, + ) + self._save_task_locked(task) + return { + "task_id": task_id, + "environment": task.environment, + "environment_version": task.metadata.get("environment_version"), + "command": command, + "stdout": exec_result.stdout, + "stderr": exec_result.stderr, + "exit_code": exec_result.exit_code, + 
"duration_ms": exec_result.duration_ms, + "execution_mode": execution_mode, + "sequence": entry["sequence"], + "cwd": TASK_WORKSPACE_GUEST_PATH, + } + + def status_task(self, task_id: str) -> dict[str, Any]: + with self._lock: + task = self._load_task_locked(task_id) + self._ensure_task_not_expired_locked(task, time.time()) + self._refresh_task_liveness_locked(task) + self._save_task_locked(task) + return self._serialize_task(task) + + def logs_task(self, task_id: str) -> dict[str, Any]: + with self._lock: + task = self._load_task_locked(task_id) + self._ensure_task_not_expired_locked(task, time.time()) + self._refresh_task_liveness_locked(task) + self._save_task_locked(task) + entries = self._read_task_logs_locked(task.task_id) + return {"task_id": task.task_id, "count": len(entries), "entries": entries} + + def delete_task(self, task_id: str, *, reason: str = "explicit_delete") -> dict[str, Any]: + with self._lock: + task = self._load_task_locked(task_id) + instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id)) + if task.state == "started": + self._backend.stop(instance) + task.state = "stopped" + self._backend.delete(instance) + shutil.rmtree(self._task_dir(task_id), ignore_errors=True) + return {"task_id": task_id, "deleted": True, "reason": reason} + def _validate_limits(self, *, vcpu_count: int, mem_mib: int, ttl_seconds: int) -> None: if not self.MIN_VCPUS <= vcpu_count <= self.MAX_VCPUS: raise ValueError(f"vcpu_count must be between {self.MIN_VCPUS} and {self.MAX_VCPUS}") @@ -620,6 +982,28 @@ class VmManager: "metadata": instance.metadata, } + def _serialize_task(self, task: TaskRecord) -> dict[str, Any]: + return { + "task_id": task.task_id, + "environment": task.environment, + "environment_version": task.metadata.get("environment_version"), + "vcpu_count": task.vcpu_count, + "mem_mib": task.mem_mib, + "ttl_seconds": task.ttl_seconds, + "created_at": task.created_at, + "expires_at": task.expires_at, + "state": task.state, + 
"network_enabled": task.network is not None, + "allow_host_compat": task.allow_host_compat, + "guest_ip": task.network.guest_ip if task.network is not None else None, + "tap_name": task.network.tap_name if task.network is not None else None, + "execution_mode": task.metadata.get("execution_mode", "pending"), + "workspace_path": TASK_WORKSPACE_GUEST_PATH, + "command_count": task.command_count, + "last_command": task.last_command, + "metadata": task.metadata, + } + def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None: if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat: return @@ -665,3 +1049,184 @@ class VmManager: vm_id = instance.vm_id self._reap_expired_locked(now) raise RuntimeError(f"vm {vm_id!r} expired and was automatically deleted") + + def _start_instance_locked(self, instance: VmInstance) -> None: + if instance.state not in {"created", "stopped"}: + raise RuntimeError( + f"vm {instance.vm_id} cannot be started from state {instance.state!r}" + ) + self._require_guest_boot_or_opt_in(instance) + if not self._runtime_capabilities.supports_vm_boot: + instance.metadata["execution_mode"] = "host_compat" + instance.metadata["boot_mode"] = "compat" + if self._runtime_capabilities.reason is not None: + instance.metadata["runtime_reason"] = self._runtime_capabilities.reason + self._backend.start(instance) + instance.state = "started" + + def _exec_instance( + self, + instance: VmInstance, + *, + command: str, + timeout_seconds: int, + host_workdir: Path | None = None, + guest_cwd: str | None = None, + ) -> tuple[VmExecResult, str]: + if timeout_seconds <= 0: + raise ValueError("timeout_seconds must be positive") + if instance.state != "started": + raise RuntimeError(f"vm {instance.vm_id} must be in 'started' state before execution") + self._require_guest_exec_or_opt_in(instance) + prepared_command = command + if self._runtime_capabilities.supports_guest_exec: + prepared_command = _wrap_guest_command(command, cwd=guest_cwd) + 
workdir = None + else: + instance.metadata["execution_mode"] = "host_compat" + workdir = host_workdir + exec_result = self._backend.exec( + instance, + prepared_command, + timeout_seconds, + workdir=workdir, + ) + execution_mode = instance.metadata.get("execution_mode", "unknown") + return exec_result, execution_mode + + def _task_dir(self, task_id: str) -> Path: + return self._tasks_dir / task_id + + def _task_runtime_dir(self, task_id: str) -> Path: + return self._task_dir(task_id) / TASK_RUNTIME_DIRNAME + + def _task_workspace_dir(self, task_id: str) -> Path: + return self._task_dir(task_id) / TASK_WORKSPACE_DIRNAME + + def _task_commands_dir(self, task_id: str) -> Path: + return self._task_dir(task_id) / TASK_COMMANDS_DIRNAME + + def _task_metadata_path(self, task_id: str) -> Path: + return self._task_dir(task_id) / "task.json" + + def _count_tasks_locked(self) -> int: + return sum(1 for _ in self._tasks_dir.glob("*/task.json")) + + def _load_task_locked(self, task_id: str) -> TaskRecord: + metadata_path = self._task_metadata_path(task_id) + if not metadata_path.exists(): + raise ValueError(f"task {task_id!r} does not exist") + payload = json.loads(metadata_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise RuntimeError(f"task record at {metadata_path} is invalid") + return TaskRecord.from_payload(payload) + + def _save_task_locked(self, task: TaskRecord) -> None: + metadata_path = self._task_metadata_path(task.task_id) + metadata_path.parent.mkdir(parents=True, exist_ok=True) + metadata_path.write_text( + json.dumps(task.to_payload(), indent=2, sort_keys=True), + encoding="utf-8", + ) + + def _reap_expired_tasks_locked(self, now: float) -> None: + for metadata_path in list(self._tasks_dir.glob("*/task.json")): + payload = json.loads(metadata_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + shutil.rmtree(metadata_path.parent, ignore_errors=True) + continue + task = TaskRecord.from_payload(payload) + if 
task.expires_at > now: + continue + instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id)) + if task.state == "started": + self._backend.stop(instance) + task.state = "stopped" + self._backend.delete(instance) + shutil.rmtree(self._task_dir(task.task_id), ignore_errors=True) + + def _ensure_task_not_expired_locked(self, task: TaskRecord, now: float) -> None: + if task.expires_at <= now: + task_id = task.task_id + self._reap_expired_tasks_locked(now) + raise RuntimeError(f"task {task_id!r} expired and was automatically deleted") + + def _refresh_task_liveness_locked(self, task: TaskRecord) -> None: + if task.state != "started": + return + execution_mode = task.metadata.get("execution_mode") + if execution_mode == "host_compat": + return + if _pid_is_running(task.firecracker_pid): + return + task.state = "stopped" + task.firecracker_pid = None + task.last_error = "backing guest process is no longer running" + + def _record_task_command_locked( + self, + task: TaskRecord, + *, + command: str, + exec_result: VmExecResult, + execution_mode: str, + cwd: str, + ) -> dict[str, Any]: + sequence = task.command_count + 1 + commands_dir = self._task_commands_dir(task.task_id) + commands_dir.mkdir(parents=True, exist_ok=True) + base_name = f"{sequence:06d}" + stdout_path = commands_dir / f"{base_name}.stdout" + stderr_path = commands_dir / f"{base_name}.stderr" + record_path = commands_dir / f"{base_name}.json" + stdout_path.write_text(exec_result.stdout, encoding="utf-8") + stderr_path.write_text(exec_result.stderr, encoding="utf-8") + entry: dict[str, Any] = { + "sequence": sequence, + "command": command, + "cwd": cwd, + "exit_code": exec_result.exit_code, + "duration_ms": exec_result.duration_ms, + "execution_mode": execution_mode, + "stdout_file": stdout_path.name, + "stderr_file": stderr_path.name, + "recorded_at": time.time(), + } + record_path.write_text(json.dumps(entry, indent=2, sort_keys=True), encoding="utf-8") + task.command_count = sequence + 
task.last_command = { + "sequence": sequence, + "command": command, + "cwd": cwd, + "exit_code": exec_result.exit_code, + "duration_ms": exec_result.duration_ms, + "execution_mode": execution_mode, + } + return entry + + def _read_task_logs_locked(self, task_id: str) -> list[dict[str, Any]]: + entries: list[dict[str, Any]] = [] + commands_dir = self._task_commands_dir(task_id) + if not commands_dir.exists(): + return entries + for record_path in sorted(commands_dir.glob("*.json")): + payload = json.loads(record_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + continue + stdout_name = str(payload.get("stdout_file", "")) + stderr_name = str(payload.get("stderr_file", "")) + stdout = "" + stderr = "" + if stdout_name != "": + stdout_path = commands_dir / stdout_name + if stdout_path.exists(): + stdout = stdout_path.read_text(encoding="utf-8") + if stderr_name != "": + stderr_path = commands_dir / stderr_name + if stderr_path.exists(): + stderr = stderr_path.read_text(encoding="utf-8") + entry = dict(payload) + entry["stdout"] = stdout + entry["stderr"] = stderr + entries.append(entry) + return entries diff --git a/tests/test_api.py b/tests/test_api.py index b282378..7b79121 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -48,6 +48,7 @@ def test_pyro_create_server_registers_vm_run(tmp_path: Path) -> None: tool_names = asyncio.run(_run()) assert "vm_run" in tool_names assert "vm_create" in tool_names + assert "task_create" in tool_names def test_pyro_vm_run_tool_executes(tmp_path: Path) -> None: @@ -102,3 +103,25 @@ def test_pyro_create_vm_defaults_sizing_and_host_compat(tmp_path: Path) -> None: assert created["vcpu_count"] == 1 assert created["mem_mib"] == 1024 assert created["allow_host_compat"] is True + + +def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None: + pyro = Pyro( + manager=VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + ) + + created 
= pyro.create_task(environment="debian:12-base", allow_host_compat=True) + task_id = str(created["task_id"]) + executed = pyro.exec_task(task_id, command="printf 'ok\\n'") + status = pyro.status_task(task_id) + logs = pyro.logs_task(task_id) + deleted = pyro.delete_task(task_id) + + assert executed["stdout"] == "ok\n" + assert status["command_count"] == 1 + assert logs["count"] == 1 + assert deleted["deleted"] is True diff --git a/tests/test_cli.py b/tests/test_cli.py index 69643d6..67859f7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -59,6 +59,14 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None: assert "Expose pyro tools over stdio for an MCP client." in mcp_help assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help + task_help = _subparser_choice(parser, "task").format_help() + assert "pyro task create debian:12" in task_help + assert "pyro task exec TASK_ID" in task_help + + task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help() + assert "persistent `/workspace`" in task_exec_help + assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help + def test_cli_run_prints_json( monkeypatch: pytest.MonkeyPatch, @@ -318,6 +326,243 @@ def test_cli_requires_run_command() -> None: cli._require_command([]) +def test_cli_task_create_prints_json( + monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +) -> None: + class StubPyro: + def create_task(self, **kwargs: Any) -> dict[str, Any]: + assert kwargs["environment"] == "debian:12" + return {"task_id": "task-123", "state": "started"} + + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="create", + environment="debian:12", + vcpu_count=1, + mem_mib=1024, + ttl_seconds=600, + network=False, + allow_host_compat=False, + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) + 
monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + output = json.loads(capsys.readouterr().out) + assert output["task_id"] == "task-123" + + +def test_cli_task_create_prints_human( + monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +) -> None: + class StubPyro: + def create_task(self, **kwargs: Any) -> dict[str, Any]: + del kwargs + return { + "task_id": "task-123", + "environment": "debian:12", + "state": "started", + "workspace_path": "/workspace", + "execution_mode": "guest_vsock", + "vcpu_count": 1, + "mem_mib": 1024, + "command_count": 0, + "last_command": None, + } + + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="create", + environment="debian:12", + vcpu_count=1, + mem_mib=1024, + ttl_seconds=600, + network=False, + allow_host_compat=False, + json=False, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + output = capsys.readouterr().out + assert "Task: task-123" in output + assert "Workspace: /workspace" in output + + +def test_cli_task_exec_prints_human_output( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubPyro: + def exec_task(self, task_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]: + assert task_id == "task-123" + assert command == "cat note.txt" + assert timeout_seconds == 30 + return { + "task_id": task_id, + "sequence": 2, + "cwd": "/workspace", + "execution_mode": "guest_vsock", + "exit_code": 0, + "duration_ms": 4, + "stdout": "hello\n", + "stderr": "", + } + + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="exec", + task_id="task-123", + timeout_seconds=30, + json=False, + command_args=["--", "cat", "note.txt"], + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) + monkeypatch.setattr(cli, 
"Pyro", StubPyro) + cli.main() + captured = capsys.readouterr() + assert captured.out == "hello\n" + assert "[task-exec] task_id=task-123 sequence=2 cwd=/workspace" in captured.err + + +def test_cli_task_logs_and_delete_print_human( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubPyro: + def logs_task(self, task_id: str) -> dict[str, Any]: + assert task_id == "task-123" + return { + "task_id": task_id, + "count": 1, + "entries": [ + { + "sequence": 1, + "exit_code": 0, + "duration_ms": 2, + "cwd": "/workspace", + "command": "printf 'ok\\n'", + "stdout": "ok\n", + "stderr": "", + } + ], + } + + def delete_task(self, task_id: str) -> dict[str, Any]: + assert task_id == "task-123" + return {"task_id": task_id, "deleted": True} + + class LogsParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="logs", + task_id="task-123", + json=False, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: LogsParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + + class DeleteParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="delete", + task_id="task-123", + json=False, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: DeleteParser()) + cli.main() + + output = capsys.readouterr().out + assert "#1 exit_code=0 duration_ms=2 cwd=/workspace" in output + assert "Deleted task: task-123" in output + + +def test_cli_task_status_and_delete_print_json( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubPyro: + def status_task(self, task_id: str) -> dict[str, Any]: + assert task_id == "task-123" + return {"task_id": task_id, "state": "started"} + + def delete_task(self, task_id: str) -> dict[str, Any]: + assert task_id == "task-123" + return {"task_id": task_id, "deleted": True} + + class StatusParser: + def parse_args(self) -> argparse.Namespace: 
+ return argparse.Namespace( + command="task", + task_command="status", + task_id="task-123", + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: StatusParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + status = json.loads(capsys.readouterr().out) + assert status["state"] == "started" + + class DeleteParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="delete", + task_id="task-123", + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: DeleteParser()) + cli.main() + deleted = json.loads(capsys.readouterr().out) + assert deleted["deleted"] is True + + +def test_cli_task_exec_json_error_exits_nonzero( + monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] +) -> None: + class StubPyro: + def exec_task(self, task_id: str, *, command: str, timeout_seconds: int) -> dict[str, Any]: + del task_id, command, timeout_seconds + raise RuntimeError("task is unavailable") + + class StubParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="task", + task_command="exec", + task_id="task-123", + timeout_seconds=30, + json=True, + command_args=["--", "true"], + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: StubParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + + with pytest.raises(SystemExit, match="1"): + cli.main() + + payload = json.loads(capsys.readouterr().out) + assert payload["ok"] is False + + def test_print_env_helpers_render_human_output(capsys: pytest.CaptureFixture[str]) -> None: cli._print_env_list_human( { diff --git a/tests/test_public_contract.py b/tests/test_public_contract.py index 97aaa16..bf07ad4 100644 --- a/tests/test_public_contract.py +++ b/tests/test_public_contract.py @@ -17,6 +17,7 @@ from pyro_mcp.contract import ( PUBLIC_CLI_DEMO_SUBCOMMANDS, PUBLIC_CLI_ENV_SUBCOMMANDS, PUBLIC_CLI_RUN_FLAGS, + PUBLIC_CLI_TASK_SUBCOMMANDS, PUBLIC_MCP_TOOLS, PUBLIC_SDK_METHODS, ) @@ -63,6 
+64,10 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None: for subcommand_name in PUBLIC_CLI_ENV_SUBCOMMANDS: assert subcommand_name in env_help_text + task_help_text = _subparser_choice(parser, "task").format_help() + for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS: + assert subcommand_name in task_help_text + demo_help_text = _subparser_choice(parser, "demo").format_help() for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS: assert subcommand_name in demo_help_text diff --git a/tests/test_server.py b/tests/test_server.py index 17a358c..be8e1db 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -31,6 +31,8 @@ def test_create_server_registers_vm_tools(tmp_path: Path) -> None: assert "vm_network_info" in tool_names assert "vm_run" in tool_names assert "vm_status" in tool_names + assert "task_create" in tool_names + assert "task_logs" in tool_names def test_vm_run_round_trip(tmp_path: Path) -> None: @@ -161,3 +163,50 @@ def test_server_main_runs_stdio_transport(monkeypatch: pytest.MonkeyPatch) -> No monkeypatch.setattr(server_module, "create_server", lambda: StubServer()) server_module.main() assert called == {"transport": "stdio"} + + +def test_task_tools_round_trip(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + def _extract_structured(raw_result: object) -> dict[str, Any]: + if not isinstance(raw_result, tuple) or len(raw_result) != 2: + raise TypeError("unexpected call_tool result shape") + _, structured = raw_result + if not isinstance(structured, dict): + raise TypeError("expected structured dictionary result") + return cast(dict[str, Any], structured) + + async def _run() -> tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]: + server = create_server(manager=manager) + created = _extract_structured( + await server.call_tool( + "task_create", + { + "environment": "debian:12-base", + 
"allow_host_compat": True, + }, + ) + ) + task_id = str(created["task_id"]) + executed = _extract_structured( + await server.call_tool( + "task_exec", + { + "task_id": task_id, + "command": "printf 'ok\\n'", + }, + ) + ) + logs = _extract_structured(await server.call_tool("task_logs", {"task_id": task_id})) + deleted = _extract_structured(await server.call_tool("task_delete", {"task_id": task_id})) + return created, executed, logs, deleted + + created, executed, logs, deleted = asyncio.run(_run()) + assert created["state"] == "started" + assert executed["stdout"] == "ok\n" + assert logs["count"] == 1 + assert deleted["deleted"] is True diff --git a/tests/test_vm_manager.py b/tests/test_vm_manager.py index e307688..560ba95 100644 --- a/tests/test_vm_manager.py +++ b/tests/test_vm_manager.py @@ -1,14 +1,17 @@ from __future__ import annotations +import json +import subprocess +import time from pathlib import Path from typing import Any import pytest import pyro_mcp.vm_manager as vm_manager_module -from pyro_mcp.runtime import resolve_runtime_paths +from pyro_mcp.runtime import RuntimeCapabilities, resolve_runtime_paths from pyro_mcp.vm_manager import VmManager -from pyro_mcp.vm_network import TapNetworkManager +from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None: @@ -262,6 +265,95 @@ def test_vm_manager_run_vm(tmp_path: Path) -> None: assert str(result["stdout"]) == "ok\n" +def test_task_lifecycle_and_logs(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + created = manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + ) + task_id = str(created["task_id"]) + assert created["state"] == "started" + assert created["workspace_path"] == "/workspace" + + first = manager.exec_task( + task_id, + command="printf 'hello\\n' > note.txt", + timeout_seconds=30, + ) 
+ second = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) + + assert first["exit_code"] == 0 + assert second["stdout"] == "hello\n" + + status = manager.status_task(task_id) + assert status["command_count"] == 2 + assert status["last_command"] is not None + + logs = manager.logs_task(task_id) + assert logs["count"] == 2 + entries = logs["entries"] + assert isinstance(entries, list) + assert entries[1]["stdout"] == "hello\n" + + deleted = manager.delete_task(task_id) + assert deleted["deleted"] is True + with pytest.raises(ValueError, match="does not exist"): + manager.status_task(task_id) + + +def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None: + base_dir = tmp_path / "vms" + manager = VmManager( + backend_name="mock", + base_dir=base_dir, + network_manager=TapNetworkManager(enabled=False), + ) + task_id = str( + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + )["task_id"] + ) + + other = VmManager( + backend_name="mock", + base_dir=base_dir, + network_manager=TapNetworkManager(enabled=False), + ) + executed = other.exec_task(task_id, command="printf 'ok\\n'", timeout_seconds=30) + assert executed["exit_code"] == 0 + assert executed["stdout"] == "ok\n" + + logs = other.logs_task(task_id) + assert logs["count"] == 1 + + +def test_task_requires_started_state(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + task_id = str( + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + )["task_id"] + ) + task_dir = tmp_path / "vms" / "tasks" / task_id / "task.json" + payload = json.loads(task_dir.read_text(encoding="utf-8")) + payload["state"] = "stopped" + task_dir.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") + + with pytest.raises(RuntimeError, match="must be in 'started' state"): + manager.exec_task(task_id, command="true", 
timeout_seconds=30) + + def test_vm_manager_firecracker_backend_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: @@ -334,3 +426,193 @@ def test_vm_manager_uses_canonical_default_cache_dir( ) assert manager._environment_store.cache_dir == tmp_path / "cache" # noqa: SLF001 + + +def test_vm_manager_helper_round_trips() -> None: + network = NetworkConfig( + vm_id="abc123", + tap_name="tap0", + guest_ip="172.29.1.2", + gateway_ip="172.29.1.1", + subnet_cidr="172.29.1.0/24", + mac_address="06:00:aa:bb:cc:dd", + dns_servers=("1.1.1.1", "8.8.8.8"), + ) + + assert vm_manager_module._optional_int(None) is None # noqa: SLF001 + assert vm_manager_module._optional_int(True) == 1 # noqa: SLF001 + assert vm_manager_module._optional_int(7) == 7 # noqa: SLF001 + assert vm_manager_module._optional_int(7.2) == 7 # noqa: SLF001 + assert vm_manager_module._optional_int("9") == 9 # noqa: SLF001 + with pytest.raises(TypeError, match="integer-compatible"): + vm_manager_module._optional_int(object()) # noqa: SLF001 + + assert vm_manager_module._optional_str(None) is None # noqa: SLF001 + assert vm_manager_module._optional_str(1) == "1" # noqa: SLF001 + assert vm_manager_module._optional_dict(None) is None # noqa: SLF001 + assert vm_manager_module._optional_dict({"x": 1}) == {"x": 1} # noqa: SLF001 + with pytest.raises(TypeError, match="dictionary payload"): + vm_manager_module._optional_dict("bad") # noqa: SLF001 + assert vm_manager_module._string_dict({"x": 1}) == {"x": "1"} # noqa: SLF001 + assert vm_manager_module._string_dict("bad") == {} # noqa: SLF001 + + serialized = vm_manager_module._serialize_network(network) # noqa: SLF001 + assert serialized is not None + restored = vm_manager_module._deserialize_network(serialized) # noqa: SLF001 + assert restored == network + assert vm_manager_module._deserialize_network(None) is None # noqa: SLF001 + with pytest.raises(TypeError, match="dictionary payload"): + vm_manager_module._deserialize_network("bad") # noqa: SLF001 + + 
assert vm_manager_module._wrap_guest_command("echo hi") == "echo hi" # noqa: SLF001 + wrapped = vm_manager_module._wrap_guest_command("echo hi", cwd="/workspace") # noqa: SLF001 + assert "cd /workspace" in wrapped + assert vm_manager_module._pid_is_running(None) is False # noqa: SLF001 + + +def test_copy_rootfs_falls_back_to_copy2( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + source = tmp_path / "rootfs.ext4" + source.write_text("payload", encoding="utf-8") + dest = tmp_path / "dest" / "rootfs.ext4" + + def _raise_oserror(*args: Any, **kwargs: Any) -> Any: + del args, kwargs + raise OSError("no cp") + + monkeypatch.setattr(subprocess, "run", _raise_oserror) + + clone_mode = vm_manager_module._copy_rootfs(source, dest) # noqa: SLF001 + assert clone_mode == "copy2" + assert dest.read_text(encoding="utf-8") == "payload" + + +def test_task_create_cleans_up_on_start_failure( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + def _boom(instance: Any) -> None: + del instance + raise RuntimeError("boom") + + monkeypatch.setattr(manager._backend, "start", _boom) # noqa: SLF001 + + with pytest.raises(RuntimeError, match="boom"): + manager.create_task(environment="debian:12-base", allow_host_compat=True) + + assert list((tmp_path / "vms" / "tasks").iterdir()) == [] + + +def test_exec_instance_wraps_guest_workspace_command(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001 + supports_vm_boot=True, + supports_guest_exec=True, + supports_guest_network=False, + reason=None, + ) + captured: dict[str, Any] = {} + + class StubBackend: + def exec( + self, + instance: Any, + command: str, + timeout_seconds: int, + *, + workdir: Path | None = None, + ) 
-> vm_manager_module.VmExecResult: + del instance, timeout_seconds + captured["command"] = command + captured["workdir"] = workdir + return vm_manager_module.VmExecResult( + stdout="", + stderr="", + exit_code=0, + duration_ms=1, + ) + + manager._backend = StubBackend() # type: ignore[assignment] # noqa: SLF001 + instance = vm_manager_module.VmInstance( # noqa: SLF001 + vm_id="vm-123", + environment="debian:12-base", + vcpu_count=1, + mem_mib=512, + ttl_seconds=600, + created_at=time.time(), + expires_at=time.time() + 600, + workdir=tmp_path / "runtime", + state="started", + ) + result, execution_mode = manager._exec_instance( # noqa: SLF001 + instance, + command="echo hi", + timeout_seconds=30, + guest_cwd="/workspace", + ) + assert result.exit_code == 0 + assert execution_mode == "unknown" + assert "cd /workspace" in str(captured["command"]) + assert captured["workdir"] is None + + +def test_status_task_marks_dead_backing_process_stopped(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + task_id = str( + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + )["task_id"] + ) + task_path = tmp_path / "vms" / "tasks" / task_id / "task.json" + payload = json.loads(task_path.read_text(encoding="utf-8")) + payload["metadata"]["execution_mode"] = "guest_vsock" + payload["firecracker_pid"] = 999999 + task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") + + status = manager.status_task(task_id) + assert status["state"] == "stopped" + updated_payload = json.loads(task_path.read_text(encoding="utf-8")) + assert "backing guest process" in str(updated_payload.get("last_error", "")) + + +def test_reap_expired_tasks_removes_invalid_and_expired_records(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + 
invalid_dir = tmp_path / "vms" / "tasks" / "invalid" + invalid_dir.mkdir(parents=True) + (invalid_dir / "task.json").write_text("[]", encoding="utf-8") + + task_id = str( + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + )["task_id"] + ) + task_path = tmp_path / "vms" / "tasks" / task_id / "task.json" + payload = json.loads(task_path.read_text(encoding="utf-8")) + payload["expires_at"] = 0.0 + task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") + + with manager._lock: # noqa: SLF001 + manager._reap_expired_tasks_locked(time.time()) # noqa: SLF001 + + assert not invalid_dir.exists() + assert not (tmp_path / "vms" / "tasks" / task_id).exists() diff --git a/uv.lock b/uv.lock index 754b94e..f97ee73 100644 --- a/uv.lock +++ b/uv.lock @@ -706,7 +706,7 @@ crypto = [ [[package]] name = "pyro-mcp" -version = "2.0.1" +version = "2.1.0" source = { editable = "." } dependencies = [ { name = "mcp" },