From aa886b346e374c428124eacde09ad8c4d504ba84 Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Wed, 11 Mar 2026 21:45:38 -0300 Subject: [PATCH] Add seeded task workspace creation Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces. Expose source_path on task create across the CLI, SDK, and MCP, add safe archive upload and extraction support for guest and host-compat backends, persist workspace_seed metadata, and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post--- command reconstruction to shlex.join() so documented sh -lc task examples preserve argument boundaries. Validation: - uv lock - UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py - UV_CACHE_DIR=.uv-cache make check - UV_CACHE_DIR=.uv-cache make dist-check - real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete --- CHANGELOG.md | 9 + README.md | 18 +- docs/first-run.md | 10 +- docs/install.md | 9 +- docs/integrations.md | 7 +- docs/public-contract.md | 8 + pyproject.toml | 2 +- runtime_sources/README.md | 2 +- .../linux-x86_64/guest/pyro_guest_agent.py | 150 ++++++- .../linux-x86_64/runtime.lock.json | 2 +- src/pyro_mcp/api.py | 4 + src/pyro_mcp/cli.py | 29 +- src/pyro_mcp/contract.py | 9 + .../linux-x86_64/guest/pyro_guest_agent.py | 150 ++++++- .../runtime_bundle/linux-x86_64/manifest.json | 4 +- src/pyro_mcp/vm_environments.py | 2 +- src/pyro_mcp/vm_guest.py | 111 +++-- src/pyro_mcp/vm_manager.py | 394 +++++++++++++++++- tests/test_api.py | 13 +- tests/test_cli.py | 24 +- tests/test_public_contract.py | 6 + tests/test_server.py | 7 +- tests/test_vm_guest.py | 43 ++ tests/test_vm_manager.py | 136 ++++++ uv.lock | 2 +- 25 files changed, 1076 insertions(+), 75 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 217b6e6..a20099d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,15 @@ All notable user-visible changes to `pyro-mcp` are documented here. +## 2.2.0 + +- Added seeded task creation across the CLI, Python SDK, and MCP server with an optional + `source_path` for host directories and `.tar` / `.tar.gz` / `.tgz` archives. +- Seeded task workspaces now persist `workspace_seed` metadata so later status calls report how + `/workspace` was initialized. +- Reused the task workspace model from `2.1.0` while adding the first explicit host-to-task + content import path for repeated command workflows. + ## 2.1.0 - Added the first persistent task workspace alpha across the CLI, Python SDK, and MCP server. diff --git a/README.md b/README.md index b4ba65e..a085622 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ It exposes the same runtime in three public forms: - First run transcript: [docs/first-run.md](docs/first-run.md) - Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif) - PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/) -- What's new in 2.1.0: [CHANGELOG.md#210](CHANGELOG.md#210) +- What's new in 2.2.0: [CHANGELOG.md#220](CHANGELOG.md#220) - Host requirements: [docs/host-requirements.md](docs/host-requirements.md) - Integration targets: [docs/integrations.md](docs/integrations.md) - Public contract: [docs/public-contract.md](docs/public-contract.md) @@ -55,7 +55,7 @@ What success looks like: ```bash Platform: linux-x86_64 Runtime: PASS -Catalog version: 2.1.0 +Catalog version: 2.2.0 ... [pull] phase=install environment=debian:12 [pull] phase=ready environment=debian:12 @@ -74,7 +74,7 @@ access to `registry-1.docker.io`, and needs local cache space for the guest imag After the quickstart works: - prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo` -- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12` +- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo` - move to Python or MCP via [docs/integrations.md](docs/integrations.md) ## Supported Hosts @@ -128,7 +128,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.1.0 +Catalog version: 2.2.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -198,7 +198,7 @@ Use `pyro run` for one-shot commands. Use `pyro task ...` when you need repeated workspace without recreating the sandbox every time. ```bash -pyro task create debian:12 +pyro task create debian:12 --source-path ./repo pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt' pyro task exec TASK_ID -- cat note.txt pyro task logs TASK_ID @@ -206,7 +206,9 @@ pyro task delete TASK_ID ``` Task workspaces start in `/workspace` and keep command history until you delete them. For machine -consumption, add `--json` and read the returned `task_id`. +consumption, add `--json` and read the returned `task_id`. Use `--source-path` when you want the +task to start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive instead of an +empty workspace. ## Public Interfaces @@ -348,7 +350,7 @@ For repeated commands in one workspace: from pyro_mcp import Pyro pyro = Pyro() -task = pyro.create_task(environment="debian:12") +task = pyro.create_task(environment="debian:12", source_path="./repo") task_id = task["task_id"] try: pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt") @@ -378,7 +380,7 @@ Advanced lifecycle tools: Persistent workspace tools: -- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)` +- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false, source_path=null)` - `task_exec(task_id, command, timeout_seconds=30)` - `task_status(task_id)` - `task_logs(task_id)` diff --git a/docs/first-run.md b/docs/first-run.md index 5b699bf..8a09b35 100644 --- a/docs/first-run.md +++ b/docs/first-run.md @@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes ```bash $ uvx --from pyro-mcp pyro env list -Catalog version: 2.1.0 +Catalog version: 2.2.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -70,7 +70,7 @@ deterministic structured result. ```bash $ uvx --from pyro-mcp pyro demo -$ uvx --from pyro-mcp pyro task create debian:12 +$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo $ uvx --from pyro-mcp pyro mcp serve ``` @@ -79,11 +79,12 @@ $ uvx --from pyro-mcp pyro mcp serve When you need repeated commands in one sandbox, switch to `pyro task ...`: ```bash -$ uvx --from pyro-mcp pyro task create debian:12 +$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo Task: ... Environment: debian:12 State: started Workspace: /workspace +Workspace seed: directory from ... Execution mode: guest_vsock Resources: 1 vCPU / 1024 MiB Command count: 0 @@ -96,6 +97,9 @@ hello from task [task-exec] task_id=... sequence=2 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=... ``` +Use `--source-path` when the task should start from a host directory or a local +`.tar` / `.tar.gz` / `.tgz` archive instead of an empty `/workspace`. + Example output: ```json diff --git a/docs/install.md b/docs/install.md index 6332804..9e234be 100644 --- a/docs/install.md +++ b/docs/install.md @@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.1.0 +Catalog version: 2.2.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -174,7 +174,7 @@ pyro run debian:12 -- git --version After the CLI path works, you can move on to: -- persistent workspaces: `pyro task create debian:12` +- persistent workspaces: `pyro task create debian:12 --source-path ./repo` - MCP: `pyro mcp serve` - Python SDK: `from pyro_mcp import Pyro` - Demos: `pyro demo` or `pyro demo --network` @@ -184,7 +184,7 @@ After the CLI path works, you can move on to: Use `pyro task ...` when you need repeated commands in one sandbox instead of one-shot `pyro run`. ```bash -pyro task create debian:12 +pyro task create debian:12 --source-path ./repo pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt' pyro task exec TASK_ID -- cat note.txt pyro task logs TASK_ID @@ -192,7 +192,8 @@ pyro task delete TASK_ID ``` Task commands default to the persistent `/workspace` directory inside the guest. If you need the -task identifier programmatically, use `--json` and read the `task_id` field. +task identifier programmatically, use `--json` and read the `task_id` field. Use `--source-path` +when the task should start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive. ## Contributor Clone diff --git a/docs/integrations.md b/docs/integrations.md index 242907c..9401071 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -30,7 +30,7 @@ Best when: Recommended surface: - `vm_run` -- `task_create` + `task_exec` when the agent needs persistent workspace state +- `task_create(source_path=...)` + `task_exec` when the agent needs persistent workspace state Canonical example: @@ -65,14 +65,15 @@ Best when: Recommended default: - `Pyro.run_in_vm(...)` -- `Pyro.create_task(...)` + `Pyro.exec_task(...)` when repeated workspace commands are required +- `Pyro.create_task(source_path=...)` + `Pyro.exec_task(...)` when repeated workspace commands are required Lifecycle note: - `Pyro.exec_vm(...)` runs one command and auto-cleans the VM afterward - use `create_vm(...)` + `start_vm(...)` only when you need pre-exec inspection or status before that final exec -- use `create_task(...)` when the agent needs repeated commands in one persistent `/workspace` +- use `create_task(source_path=...)` when the agent needs repeated commands in one persistent + `/workspace` that starts from host content Examples: diff --git a/docs/public-contract.md b/docs/public-contract.md index 3ecfe9b..2f505f6 100644 --- a/docs/public-contract.md +++ b/docs/public-contract.md @@ -46,8 +46,12 @@ Behavioral guarantees: - `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`. - `pyro demo ollama` prints log lines plus a final summary line. - `pyro task create` auto-starts a persistent workspace. +- `pyro task create --source-path PATH` seeds `/workspace` from a host directory or a local + `.tar` / `.tar.gz` / `.tgz` archive before the task is returned. - `pyro task exec` runs in the persistent `/workspace` for that task and does not auto-clean. - `pyro task logs` returns persisted command history for that task until `pyro task delete`. +- Task create/status results expose `workspace_seed` metadata describing how `/workspace` was + initialized. ## Python SDK Contract @@ -106,6 +110,8 @@ Behavioral defaults: - `Pyro.create_task(...)` defaults to `vcpu_count=1` and `mem_mib=1024`. - `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`. - `allow_host_compat` defaults to `False` on `create_task(...)`. +- `Pyro.create_task(..., source_path=...)` seeds `/workspace` from a host directory or a local + `.tar` / `.tar.gz` / `.tgz` archive before the task is returned. - `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes. - `Pyro.exec_task(...)` runs one command in the persistent task workspace and leaves the task alive. @@ -141,6 +147,8 @@ Behavioral defaults: - `task_create` defaults to `vcpu_count=1` and `mem_mib=1024`. - `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`. - `task_create` exposes `allow_host_compat`, which defaults to `false`. +- `task_create` accepts optional `source_path` and seeds `/workspace` from a host directory or a + local `.tar` / `.tar.gz` / `.tgz` archive before the task is returned. - `vm_exec` runs one command and auto-cleans that VM after the exec completes. - `task_exec` runs one command in a persistent `/workspace` and leaves the task alive. diff --git a/pyproject.toml b/pyproject.toml index 6a50302..f3118c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pyro-mcp" -version = "2.1.0" +version = "2.2.0" description = "Curated Linux environments for ephemeral Firecracker-backed VM execution." readme = "README.md" license = { file = "LICENSE" } diff --git a/runtime_sources/README.md b/runtime_sources/README.md index e304096..5184c0e 100644 --- a/runtime_sources/README.md +++ b/runtime_sources/README.md @@ -34,7 +34,7 @@ Kernel build note: Current status: 1. Firecracker and Jailer are materialized from pinned official release artifacts. 2. The kernel and rootfs images are built from pinned inputs into `build/runtime_sources/`. -3. The guest agent is installed into each rootfs and used for vsock exec. +3. The guest agent is installed into each rootfs and used for vsock exec plus workspace archive imports. 4. `runtime.lock.json` now advertises real guest capabilities. Safety rule: diff --git a/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py b/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py index ea9c2cf..d914fc1 100644 --- a/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py +++ b/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py @@ -1,26 +1,31 @@ #!/usr/bin/env python3 -"""Minimal guest-side exec agent for pyro runtime bundles.""" +"""Minimal guest-side exec and workspace import agent for pyro runtime bundles.""" from __future__ import annotations +import io import json +import os import socket import subprocess +import tarfile import time +from pathlib import Path, PurePosixPath from typing import Any PORT = 5005 BUFFER_SIZE = 65536 +WORKSPACE_ROOT = PurePosixPath("/workspace") def _read_request(conn: socket.socket) -> dict[str, Any]: chunks: list[bytes] = [] while True: - data = conn.recv(BUFFER_SIZE) + data = conn.recv(1) if data == b"": break chunks.append(data) - if b"\n" in data: + if data == b"\n": break payload = json.loads(b"".join(chunks).decode("utf-8").strip()) if not isinstance(payload, dict): @@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]: return payload +def _read_exact(conn: socket.socket, size: int) -> bytes: + remaining = size + chunks: list[bytes] = [] + while remaining > 0: + data = conn.recv(min(BUFFER_SIZE, remaining)) + if data == b"": + raise RuntimeError("unexpected EOF while reading archive payload") + chunks.append(data) + remaining -= len(data) + return b"".join(chunks) + + +def _normalize_member_name(name: str) -> PurePosixPath: + candidate = name.strip() + if candidate == "": + raise RuntimeError("archive member path is empty") + member_path = PurePosixPath(candidate) + if member_path.is_absolute(): + raise RuntimeError(f"absolute archive member paths are not allowed: {name}") + parts = [part for part in member_path.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"unsafe archive member path: {name}") + normalized = PurePosixPath(*parts) + if str(normalized) in {"", "."}: + raise RuntimeError(f"unsafe archive member path: {name}") + return normalized + + +def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]: + candidate = destination.strip() + if candidate == "": + raise RuntimeError("destination must not be empty") + destination_path = PurePosixPath(candidate) + if not destination_path.is_absolute(): + destination_path = WORKSPACE_ROOT / destination_path + parts = [part for part in destination_path.parts if part not in {"", "."}] + normalized = PurePosixPath("/") / PurePosixPath(*parts) + if normalized == PurePosixPath("/"): + raise RuntimeError("destination must stay inside /workspace") + if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts: + raise RuntimeError("destination must stay inside /workspace") + suffix = normalized.relative_to(WORKSPACE_ROOT) + host_path = Path("/workspace") + if str(suffix) not in {"", "."}: + host_path = host_path / str(suffix) + return normalized, host_path + + +def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None: + target = link_target.strip() + if target == "": + raise RuntimeError(f"symlink {member_path} has an empty target") + target_path = PurePosixPath(target) + if target_path.is_absolute(): + raise RuntimeError(f"symlink {member_path} escapes the workspace") + combined = member_path.parent.joinpath(target_path) + parts = [part for part in combined.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"symlink {member_path} escapes the workspace") + + +def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None: + relative_path = target_path.relative_to(root) + current = root + for part in relative_path.parts[:-1]: + current = current / part + if current.is_symlink(): + raise RuntimeError( + f"archive member would traverse through a symlinked path: {member_name}" + ) + + +def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: + _, destination_root = _normalize_destination(destination) + destination_root.mkdir(parents=True, exist_ok=True) + bytes_written = 0 + entry_count = 0 + with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive: + for member in archive.getmembers(): + member_name = _normalize_member_name(member.name) + target_path = destination_root / str(member_name) + entry_count += 1 + _ensure_no_symlink_parents(destination_root, target_path, member.name) + if member.isdir(): + if target_path.exists() and not target_path.is_dir(): + raise RuntimeError(f"directory conflicts with existing path: {member.name}") + target_path.mkdir(parents=True, exist_ok=True) + continue + if member.isfile(): + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()): + raise RuntimeError(f"file conflicts with existing path: {member.name}") + source = archive.extractfile(member) + if source is None: + raise RuntimeError(f"failed to read archive member: {member.name}") + with target_path.open("wb") as handle: + while True: + chunk = source.read(BUFFER_SIZE) + if chunk == b"": + break + handle.write(chunk) + bytes_written += member.size + continue + if member.issym(): + _validate_symlink_target(member_name, member.linkname) + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists() and not target_path.is_symlink(): + raise RuntimeError(f"symlink conflicts with existing path: {member.name}") + if target_path.is_symlink(): + target_path.unlink() + os.symlink(member.linkname, target_path) + continue + if member.islnk(): + raise RuntimeError( + f"hard links are not allowed in workspace archives: {member.name}" + ) + raise RuntimeError(f"unsupported archive member type: {member.name}") + return { + "destination": destination, + "entry_count": entry_count, + "bytes_written": bytes_written, + } + + def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]: started = time.monotonic() try: @@ -64,9 +193,18 @@ def main() -> None: conn, _ = server.accept() with conn: request = _read_request(conn) - command = str(request.get("command", "")) - timeout_seconds = int(request.get("timeout_seconds", 30)) - response = _run_command(command, timeout_seconds) + action = str(request.get("action", "exec")) + if action == "extract_archive": + archive_size = int(request.get("archive_size", 0)) + if archive_size < 0: + raise RuntimeError("archive_size must not be negative") + destination = str(request.get("destination", "/workspace")) + payload = _read_exact(conn, archive_size) + response = _extract_archive(payload, destination) + else: + command = str(request.get("command", "")) + timeout_seconds = int(request.get("timeout_seconds", 30)) + response = _run_command(command, timeout_seconds) conn.sendall((json.dumps(response) + "\n").encode("utf-8")) diff --git a/runtime_sources/linux-x86_64/runtime.lock.json b/runtime_sources/linux-x86_64/runtime.lock.json index b5eca96..72a4e2c 100644 --- a/runtime_sources/linux-x86_64/runtime.lock.json +++ b/runtime_sources/linux-x86_64/runtime.lock.json @@ -5,7 +5,7 @@ "firecracker": "1.12.1", "jailer": "1.12.1", "kernel": "5.10.210", - "guest_agent": "0.1.0-dev", + "guest_agent": "0.2.0-dev", "base_distro": "debian-bookworm-20250210" }, "capabilities": { diff --git a/src/pyro_mcp/api.py b/src/pyro_mcp/api.py index 82f3abf..196e814 100644 --- a/src/pyro_mcp/api.py +++ b/src/pyro_mcp/api.py @@ -86,6 +86,7 @@ class Pyro: ttl_seconds: int = DEFAULT_TTL_SECONDS, network: bool = False, allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + source_path: str | Path | None = None, ) -> dict[str, Any]: return self._manager.create_task( environment=environment, @@ -94,6 +95,7 @@ class Pyro: ttl_seconds=ttl_seconds, network=network, allow_host_compat=allow_host_compat, + source_path=source_path, ) def exec_task( @@ -245,6 +247,7 @@ class Pyro: ttl_seconds: int = DEFAULT_TTL_SECONDS, network: bool = False, allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + source_path: str | None = None, ) -> dict[str, Any]: """Create and start a persistent task workspace.""" return self.create_task( @@ -254,6 +257,7 @@ class Pyro: ttl_seconds=ttl_seconds, network=network, allow_host_compat=allow_host_compat, + source_path=source_path, ) @server.tool() diff --git a/src/pyro_mcp/cli.py b/src/pyro_mcp/cli.py index 16c056c..90aad34 100644 --- a/src/pyro_mcp/cli.py +++ b/src/pyro_mcp/cli.py @@ -4,6 +4,7 @@ from __future__ import annotations import argparse import json +import shlex import sys from textwrap import dedent from typing import Any @@ -155,6 +156,14 @@ def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None: print(f"Environment: {str(payload.get('environment', 'unknown'))}") print(f"State: {str(payload.get('state', 'unknown'))}") print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}") + workspace_seed = payload.get("workspace_seed") + if isinstance(workspace_seed, dict): + mode = str(workspace_seed.get("mode", "empty")) + source_path = workspace_seed.get("source_path") + if isinstance(source_path, str) and source_path != "": + print(f"Workspace seed: {mode} from {source_path}") + else: + print(f"Workspace seed: {mode}") print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}") print( f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / " @@ -446,7 +455,7 @@ def _build_parser() -> argparse.ArgumentParser: epilog=dedent( """ Examples: - pyro task create debian:12 + pyro task create debian:12 --source-path ./repo pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt' pyro task logs TASK_ID """ @@ -458,7 +467,13 @@ def _build_parser() -> argparse.ArgumentParser: "create", help="Create and start a persistent task workspace.", description="Create a task workspace that stays alive across repeated exec calls.", - epilog="Example:\n pyro task create debian:12", + epilog=dedent( + """ + Examples: + pyro task create debian:12 + pyro task create debian:12 --source-path ./repo + """ + ), formatter_class=_HelpFormatter, ) task_create_parser.add_argument( @@ -497,6 +512,13 @@ def _build_parser() -> argparse.ArgumentParser: "is unavailable." ), ) + task_create_parser.add_argument( + "--source-path", + help=( + "Optional host directory or .tar/.tar.gz/.tgz archive to seed into `/workspace` " + "before the task is returned." + ), + ) task_create_parser.add_argument( "--json", action="store_true", @@ -663,7 +685,7 @@ def _require_command(command_args: list[str]) -> str: command_args = command_args[1:] if not command_args: raise ValueError("command is required after `--`") - return " ".join(command_args) + return shlex.join(command_args) def main() -> None: @@ -764,6 +786,7 @@ def main() -> None: ttl_seconds=args.ttl_seconds, network=args.network, allow_host_compat=args.allow_host_compat, + source_path=args.source_path, ) if bool(args.json): _print_json(payload) diff --git a/src/pyro_mcp/contract.py b/src/pyro_mcp/contract.py index 4e6866e..f582609 100644 --- a/src/pyro_mcp/contract.py +++ b/src/pyro_mcp/contract.py @@ -6,6 +6,15 @@ PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task") PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",) PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune") PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status") +PUBLIC_CLI_TASK_CREATE_FLAGS = ( + "--vcpu-count", + "--mem-mib", + "--ttl-seconds", + "--network", + "--allow-host-compat", + "--source-path", + "--json", +) PUBLIC_CLI_RUN_FLAGS = ( "--vcpu-count", "--mem-mib", diff --git a/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py b/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py index ea9c2cf..d914fc1 100755 --- a/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py +++ b/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py @@ -1,26 +1,31 @@ #!/usr/bin/env python3 -"""Minimal guest-side exec agent for pyro runtime bundles.""" +"""Minimal guest-side exec and workspace import agent for pyro runtime bundles.""" from __future__ import annotations +import io import json +import os import socket import subprocess +import tarfile import time +from pathlib import Path, PurePosixPath from typing import Any PORT = 5005 BUFFER_SIZE = 65536 +WORKSPACE_ROOT = PurePosixPath("/workspace") def _read_request(conn: socket.socket) -> dict[str, Any]: chunks: list[bytes] = [] while True: - data = conn.recv(BUFFER_SIZE) + data = conn.recv(1) if data == b"": break chunks.append(data) - if b"\n" in data: + if data == b"\n": break payload = json.loads(b"".join(chunks).decode("utf-8").strip()) if not isinstance(payload, dict): @@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]: return payload +def _read_exact(conn: socket.socket, size: int) -> bytes: + remaining = size + chunks: list[bytes] = [] + while remaining > 0: + data = conn.recv(min(BUFFER_SIZE, remaining)) + if data == b"": + raise RuntimeError("unexpected EOF while reading archive payload") + chunks.append(data) + remaining -= len(data) + return b"".join(chunks) + + +def _normalize_member_name(name: str) -> PurePosixPath: + candidate = name.strip() + if candidate == "": + raise RuntimeError("archive member path is empty") + member_path = PurePosixPath(candidate) + if member_path.is_absolute(): + raise RuntimeError(f"absolute archive member paths are not allowed: {name}") + parts = [part for part in member_path.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"unsafe archive member path: {name}") + normalized = PurePosixPath(*parts) + if str(normalized) in {"", "."}: + raise RuntimeError(f"unsafe archive member path: {name}") + return normalized + + +def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]: + candidate = destination.strip() + if candidate == "": + raise RuntimeError("destination must not be empty") + destination_path = PurePosixPath(candidate) + if not destination_path.is_absolute(): + destination_path = WORKSPACE_ROOT / destination_path + parts = [part for part in destination_path.parts if part not in {"", "."}] + normalized = PurePosixPath("/") / PurePosixPath(*parts) + if normalized == PurePosixPath("/"): + raise RuntimeError("destination must stay inside /workspace") + if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts: + raise RuntimeError("destination must stay inside /workspace") + suffix = normalized.relative_to(WORKSPACE_ROOT) + host_path = Path("/workspace") + if str(suffix) not in {"", "."}: + host_path = host_path / str(suffix) + return normalized, host_path + + +def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None: + target = link_target.strip() + if target == "": + raise RuntimeError(f"symlink {member_path} has an empty target") + target_path = PurePosixPath(target) + if target_path.is_absolute(): + raise RuntimeError(f"symlink {member_path} escapes the workspace") + combined = member_path.parent.joinpath(target_path) + parts = [part for part in combined.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"symlink {member_path} escapes the workspace") + + +def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None: + relative_path = target_path.relative_to(root) + current = root + for part in relative_path.parts[:-1]: + current = current / part + if current.is_symlink(): + raise RuntimeError( + f"archive member would traverse through a symlinked path: {member_name}" + ) + + +def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: + _, destination_root = _normalize_destination(destination) + destination_root.mkdir(parents=True, exist_ok=True) + bytes_written = 0 + entry_count = 0 + with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive: + for member in archive.getmembers(): + member_name = _normalize_member_name(member.name) + target_path = destination_root / str(member_name) + entry_count += 1 + _ensure_no_symlink_parents(destination_root, target_path, member.name) + if member.isdir(): + if target_path.exists() and not target_path.is_dir(): + raise RuntimeError(f"directory conflicts with existing path: {member.name}") + target_path.mkdir(parents=True, exist_ok=True) + continue + if member.isfile(): + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()): + raise RuntimeError(f"file conflicts with existing path: {member.name}") + source = archive.extractfile(member) + if source is None: + raise RuntimeError(f"failed to read archive member: {member.name}") + with target_path.open("wb") as handle: + while True: + chunk = source.read(BUFFER_SIZE) + if chunk == b"": + break + handle.write(chunk) + bytes_written += member.size + continue + if member.issym(): + _validate_symlink_target(member_name, member.linkname) + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists() and not target_path.is_symlink(): + raise RuntimeError(f"symlink conflicts with existing path: {member.name}") + if target_path.is_symlink(): + target_path.unlink() + os.symlink(member.linkname, target_path) + continue + if member.islnk(): + raise RuntimeError( + f"hard links are not allowed in workspace archives: {member.name}" + ) + raise RuntimeError(f"unsupported archive member type: {member.name}") + return { + "destination": destination, + "entry_count": entry_count, + "bytes_written": bytes_written, + } + + def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]: started = time.monotonic() try: @@ -64,9 +193,18 @@ def main() -> None: conn, _ = server.accept() with conn: request = _read_request(conn) - command = str(request.get("command", "")) - timeout_seconds = int(request.get("timeout_seconds", 30)) - response = _run_command(command, timeout_seconds) + action = str(request.get("action", "exec")) + if action == "extract_archive": + archive_size = int(request.get("archive_size", 0)) + if archive_size < 0: + raise RuntimeError("archive_size must not be negative") + destination = str(request.get("destination", "/workspace")) + payload = _read_exact(conn, archive_size) + response = _extract_archive(payload, destination) + else: + command = str(request.get("command", "")) + timeout_seconds = int(request.get("timeout_seconds", 30)) + response = _run_command(command, timeout_seconds) conn.sendall((json.dumps(response) + "\n").encode("utf-8")) diff --git a/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json b/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json index 1e521e1..948953a 100644 --- a/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json +++ b/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json @@ -18,14 +18,14 @@ "component_versions": { "base_distro": "debian-bookworm-20250210", "firecracker": "1.12.1", - "guest_agent": "0.1.0-dev", + "guest_agent": "0.2.0-dev", "jailer": "1.12.1", "kernel": "5.10.210" }, "guest": { "agent": { "path": "guest/pyro_guest_agent.py", - "sha256": "65bf8a9a57ffd7321463537e598c4b30f0a13046cbd4538f1b65bc351da5d3c0" + "sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d" } }, "platform": "linux-x86_64", diff --git a/src/pyro_mcp/vm_environments.py b/src/pyro_mcp/vm_environments.py index 4067fe0..8f7cc21 100644 --- a/src/pyro_mcp/vm_environments.py +++ b/src/pyro_mcp/vm_environments.py @@ -19,7 +19,7 @@ from typing import Any from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths DEFAULT_ENVIRONMENT_VERSION = "1.0.0" -DEFAULT_CATALOG_VERSION = "2.1.0" +DEFAULT_CATALOG_VERSION = "2.2.0" OCI_MANIFEST_ACCEPT = ", ".join( ( "application/vnd.oci.image.index.v1+json", diff --git a/src/pyro_mcp/vm_guest.py b/src/pyro_mcp/vm_guest.py index 772f998..d6acd3a 100644 --- a/src/pyro_mcp/vm_guest.py +++ b/src/pyro_mcp/vm_guest.py @@ -5,6 +5,7 @@ from __future__ import annotations import json import socket from dataclasses import dataclass +from pathlib import Path from typing import Any, Callable, Protocol @@ -31,6 +32,13 @@ class GuestExecResponse: duration_ms: int +@dataclass(frozen=True) +class GuestArchiveResponse: + destination: str + entry_count: int + bytes_written: int + + class VsockExecClient: """Minimal JSON-over-stream client for a guest exec agent.""" @@ -50,6 +58,63 @@ class VsockExecClient: "command": command, "timeout_seconds": timeout_seconds, } + sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path) + try: + sock.sendall((json.dumps(request) + "\n").encode("utf-8")) + payload = self._recv_json_payload(sock) + finally: + sock.close() + + if not isinstance(payload, dict): + raise RuntimeError("guest exec response must be a JSON object") + return GuestExecResponse( + stdout=str(payload.get("stdout", "")), + stderr=str(payload.get("stderr", "")), + exit_code=int(payload.get("exit_code", -1)), + duration_ms=int(payload.get("duration_ms", 0)), + ) + + def upload_archive( + self, + guest_cid: int, + port: int, + archive_path: Path, + *, + destination: str, + timeout_seconds: int = 60, + uds_path: str | None = None, + ) -> GuestArchiveResponse: + request = { + "action": "extract_archive", + "destination": destination, + "archive_size": archive_path.stat().st_size, + } + sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path) + try: + sock.sendall((json.dumps(request) + "\n").encode("utf-8")) + with archive_path.open("rb") as handle: + for chunk in iter(lambda: handle.read(65536), b""): + sock.sendall(chunk) + payload = self._recv_json_payload(sock) + finally: + sock.close() + + if not isinstance(payload, dict): + raise RuntimeError("guest archive response must be a JSON object") + return GuestArchiveResponse( + destination=str(payload.get("destination", destination)), + entry_count=int(payload.get("entry_count", 0)), + bytes_written=int(payload.get("bytes_written", 0)), + ) + + def _connect( + self, + guest_cid: int, + port: int, + timeout_seconds: int, + *, + uds_path: str | None, + ) -> SocketLike: family = getattr(socket, "AF_VSOCK", None) if family is not None: sock = self._socket_factory(family, socket.SOCK_STREAM) @@ -59,33 +124,15 @@ class VsockExecClient: connect_address = uds_path else: raise RuntimeError("vsock sockets are not supported on this host Python runtime") - try: - sock.settimeout(timeout_seconds) - sock.connect(connect_address) - if family is None: - sock.sendall(f"CONNECT {port}\n".encode("utf-8")) - status = self._recv_line(sock) - if not status.startswith("OK "): - raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}") - sock.sendall((json.dumps(request) + "\n").encode("utf-8")) - chunks: list[bytes] = [] - while True: - data = sock.recv(65536) - if data == b"": - break - chunks.append(data) - finally: - sock.close() - - payload = json.loads(b"".join(chunks).decode("utf-8")) - if not isinstance(payload, dict): - raise RuntimeError("guest exec response must be a JSON object") - return GuestExecResponse( - stdout=str(payload.get("stdout", "")), - stderr=str(payload.get("stderr", "")), - exit_code=int(payload.get("exit_code", -1)), - duration_ms=int(payload.get("duration_ms", 0)), - ) + sock.settimeout(timeout_seconds) + sock.connect(connect_address) + if family is None: + sock.sendall(f"CONNECT {port}\n".encode("utf-8")) + status = self._recv_line(sock) + if not status.startswith("OK "): + sock.close() + raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}") + return sock @staticmethod def _recv_line(sock: SocketLike) -> str: @@ -98,3 +145,13 @@ class VsockExecClient: if data == b"\n": break return b"".join(chunks).decode("utf-8", errors="replace") + + @staticmethod + def _recv_json_payload(sock: SocketLike) -> Any: + chunks: list[bytes] = [] + while True: + data = sock.recv(65536) + if data == b"": + break + chunks.append(data) + return json.loads(b"".join(chunks).decode("utf-8")) diff --git a/src/pyro_mcp/vm_manager.py b/src/pyro_mcp/vm_manager.py index e93b4c8..398951a 100644 --- a/src/pyro_mcp/vm_manager.py +++ b/src/pyro_mcp/vm_manager.py @@ -8,11 +8,13 @@ import shlex import shutil import signal import subprocess +import tarfile +import tempfile import threading import time import uuid from dataclasses import dataclass, field -from pathlib import Path +from pathlib import Path, PurePosixPath from typing import Any, Literal, cast from pyro_mcp.runtime import ( @@ -34,11 +36,15 @@ DEFAULT_TIMEOUT_SECONDS = 30 DEFAULT_TTL_SECONDS = 600 DEFAULT_ALLOW_HOST_COMPAT = False -TASK_LAYOUT_VERSION = 1 +TASK_LAYOUT_VERSION = 2 TASK_WORKSPACE_DIRNAME = "workspace" TASK_COMMANDS_DIRNAME = "commands" TASK_RUNTIME_DIRNAME = "runtime" TASK_WORKSPACE_GUEST_PATH = "/workspace" +TASK_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py" +TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60 + +WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"] @dataclass @@ -82,6 +88,7 @@ class TaskRecord: network: NetworkConfig | None = None command_count: int = 0 last_command: dict[str, Any] | None = None + workspace_seed: dict[str, Any] = field(default_factory=dict) @classmethod def from_instance( @@ -90,6 +97,7 @@ class TaskRecord: *, command_count: int = 0, last_command: dict[str, Any] | None = None, + workspace_seed: dict[str, Any] | None = None, ) -> TaskRecord: return cls( task_id=instance.vm_id, @@ -108,6 +116,7 @@ class TaskRecord: network=instance.network, command_count=command_count, last_command=last_command, + workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()), ) def to_instance(self, *, workdir: Path) -> VmInstance: @@ -148,6 +157,7 @@ class TaskRecord: "network": _serialize_network(self.network), "command_count": self.command_count, "last_command": self.last_command, + "workspace_seed": self.workspace_seed, } @classmethod @@ -169,9 +179,35 @@ class TaskRecord: network=_deserialize_network(payload.get("network")), command_count=int(payload.get("command_count", 0)), last_command=_optional_dict(payload.get("last_command")), + workspace_seed=_task_workspace_seed_dict(payload.get("workspace_seed")), ) +@dataclass(frozen=True) +class PreparedWorkspaceSeed: + """Prepared host-side seed archive plus metadata.""" + + mode: WorkspaceSeedMode + source_path: str | None + archive_path: Path | None = None + entry_count: int = 0 + bytes_written: int = 0 + cleanup_dir: Path | None = None + + def to_payload(self) -> dict[str, Any]: + return { + "mode": self.mode, + "source_path": self.source_path, + "destination": TASK_WORKSPACE_GUEST_PATH, + "entry_count": self.entry_count, + "bytes_written": self.bytes_written, + } + + def cleanup(self) -> None: + if self.cleanup_dir is not None: + shutil.rmtree(self.cleanup_dir, ignore_errors=True) + + @dataclass(frozen=True) class VmExecResult: """Command execution output.""" @@ -216,6 +252,32 @@ def _string_dict(value: object) -> dict[str, str]: return {str(key): str(item) for key, item in value.items()} +def _empty_workspace_seed_payload() -> dict[str, Any]: + return { + "mode": "empty", + "source_path": None, + "destination": TASK_WORKSPACE_GUEST_PATH, + "entry_count": 0, + "bytes_written": 0, + } + + +def _task_workspace_seed_dict(value: object) -> dict[str, Any]: + if not isinstance(value, dict): + return _empty_workspace_seed_payload() + payload = _empty_workspace_seed_payload() + payload.update( + { + "mode": str(value.get("mode", payload["mode"])), + "source_path": _optional_str(value.get("source_path")), + "destination": str(value.get("destination", payload["destination"])), + "entry_count": int(value.get("entry_count", payload["entry_count"])), + "bytes_written": int(value.get("bytes_written", payload["bytes_written"])), + } + ) + return payload + + def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None: if network is None: return None @@ -300,6 +362,201 @@ def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str: return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}" +def _is_supported_seed_archive(path: Path) -> bool: + name = path.name.lower() + return name.endswith(".tar") or name.endswith(".tar.gz") or name.endswith(".tgz") + + +def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPath]: + candidate = destination.strip() + if candidate == "": + raise ValueError("workspace destination must not be empty") + destination_path = PurePosixPath(candidate) + workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH) + if not destination_path.is_absolute(): + destination_path = workspace_root / destination_path + parts = [part for part in destination_path.parts if part not in {"", "."}] + normalized = PurePosixPath("/") / PurePosixPath(*parts) + if normalized == PurePosixPath("/"): + raise ValueError("workspace destination must stay inside /workspace") + if normalized.parts[: len(workspace_root.parts)] != workspace_root.parts: + raise ValueError("workspace destination must stay inside /workspace") + suffix = normalized.relative_to(workspace_root) + return str(normalized), suffix + + +def _workspace_host_destination(workspace_dir: Path, destination: str) -> Path: + _, suffix = _normalize_workspace_destination(destination) + if str(suffix) in {"", "."}: + return workspace_dir + return workspace_dir.joinpath(*suffix.parts) + + +def _normalize_archive_member_name(name: str) -> PurePosixPath: + candidate = name.strip() + if candidate == "": + raise RuntimeError("archive member path is empty") + member_path = PurePosixPath(candidate) + if member_path.is_absolute(): + raise RuntimeError(f"absolute archive member paths are not allowed: {name}") + parts = [part for part in member_path.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"unsafe archive member path: {name}") + normalized = PurePosixPath(*parts) + if str(normalized) in {"", "."}: + raise RuntimeError(f"unsafe archive member path: {name}") + return normalized + + +def _validate_archive_symlink_target(member_name: PurePosixPath, link_target: str) -> None: + target = link_target.strip() + if target == "": + raise RuntimeError(f"symlink {member_name} has an empty target") + link_path = PurePosixPath(target) + if link_path.is_absolute(): + raise RuntimeError(f"symlink {member_name} escapes the workspace") + combined = member_name.parent.joinpath(link_path) + parts = [part for part in combined.parts if part not in {"", "."}] + if any(part == ".." for part in parts): + raise RuntimeError(f"symlink {member_name} escapes the workspace") + + +def _inspect_seed_archive(archive_path: Path) -> tuple[int, int]: + entry_count = 0 + bytes_written = 0 + with tarfile.open(archive_path, "r:*") as archive: + for member in archive.getmembers(): + member_name = _normalize_archive_member_name(member.name) + entry_count += 1 + if member.isdir(): + continue + if member.isfile(): + bytes_written += member.size + continue + if member.issym(): + _validate_archive_symlink_target(member_name, member.linkname) + continue + if member.islnk(): + raise RuntimeError( + f"hard links are not allowed in workspace archives: {member.name}" + ) + raise RuntimeError(f"unsupported archive member type: {member.name}") + return entry_count, bytes_written + + +def _write_directory_seed_archive(source_dir: Path, archive_path: Path) -> None: + archive_path.parent.mkdir(parents=True, exist_ok=True) + with tarfile.open(archive_path, "w") as archive: + for child in sorted(source_dir.iterdir(), key=lambda item: item.name): + archive.add(child, arcname=child.name, recursive=True) + + +def _extract_seed_archive_to_host_workspace( + archive_path: Path, + *, + workspace_dir: Path, + destination: str, +) -> dict[str, Any]: + normalized_destination, _ = _normalize_workspace_destination(destination) + destination_root = _workspace_host_destination(workspace_dir, normalized_destination) + destination_root.mkdir(parents=True, exist_ok=True) + entry_count = 0 + bytes_written = 0 + with tarfile.open(archive_path, "r:*") as archive: + for member in archive.getmembers(): + member_name = _normalize_archive_member_name(member.name) + target_path = destination_root.joinpath(*member_name.parts) + entry_count += 1 + _ensure_no_symlink_parents(workspace_dir, target_path, member.name) + if member.isdir(): + if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()): + raise RuntimeError(f"directory conflicts with existing path: {member.name}") + target_path.mkdir(parents=True, exist_ok=True) + continue + if member.isfile(): + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.is_symlink() or target_path.is_dir(): + raise RuntimeError(f"file conflicts with existing path: {member.name}") + source = archive.extractfile(member) + if source is None: + raise RuntimeError(f"failed to read archive member: {member.name}") + with target_path.open("wb") as handle: + shutil.copyfileobj(source, handle) + bytes_written += member.size + continue + if member.issym(): + _validate_archive_symlink_target(member_name, member.linkname) + target_path.parent.mkdir(parents=True, exist_ok=True) + if target_path.exists() and not target_path.is_symlink(): + raise RuntimeError(f"symlink conflicts with existing path: {member.name}") + if target_path.is_symlink(): + target_path.unlink() + os.symlink(member.linkname, target_path) + continue + if member.islnk(): + raise RuntimeError( + f"hard links are not allowed in workspace archives: {member.name}" + ) + raise RuntimeError(f"unsupported archive member type: {member.name}") + return { + "destination": normalized_destination, + "entry_count": entry_count, + "bytes_written": bytes_written, + } + + +def _instance_workspace_host_dir(instance: VmInstance) -> Path: + raw_value = instance.metadata.get("workspace_host_dir") + if raw_value is None or raw_value == "": + raise RuntimeError("task workspace host directory is unavailable") + return Path(raw_value) + + +def _patch_rootfs_guest_agent(rootfs_image: Path, guest_agent_path: Path) -> None: + debugfs_path = shutil.which("debugfs") + if debugfs_path is None: + raise RuntimeError( + "debugfs is required to seed task workspaces on guest-backed runtimes" + ) + with tempfile.TemporaryDirectory(prefix="pyro-guest-agent-") as temp_dir: + staged_agent_path = Path(temp_dir) / "pyro_guest_agent.py" + shutil.copy2(guest_agent_path, staged_agent_path) + subprocess.run( # noqa: S603 + [debugfs_path, "-w", "-R", f"rm {TASK_GUEST_AGENT_PATH}", str(rootfs_image)], + text=True, + capture_output=True, + check=False, + ) + proc = subprocess.run( # noqa: S603 + [ + debugfs_path, + "-w", + "-R", + f"write {staged_agent_path} {TASK_GUEST_AGENT_PATH}", + str(rootfs_image), + ], + text=True, + capture_output=True, + check=False, + ) + if proc.returncode != 0: + raise RuntimeError( + "failed to patch guest agent into task rootfs: " + f"{proc.stderr.strip() or proc.stdout.strip()}" + ) + + +def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None: + relative_path = target_path.relative_to(root) + current = root + for part in relative_path.parts[:-1]: + current = current / part + if current.is_symlink(): + raise RuntimeError( + f"archive member would traverse through a symlinked path: {member_name}" + ) + + def _pid_is_running(pid: int | None) -> bool: if pid is None: return False @@ -337,6 +594,15 @@ class VmBackend: def delete(self, instance: VmInstance) -> None: # pragma: no cover raise NotImplementedError + def import_archive( # pragma: no cover + self, + instance: VmInstance, + *, + archive_path: Path, + destination: str, + ) -> dict[str, Any]: + raise NotImplementedError + class MockBackend(VmBackend): """Host-process backend used for development and testability.""" @@ -365,6 +631,19 @@ class MockBackend(VmBackend): def delete(self, instance: VmInstance) -> None: shutil.rmtree(instance.workdir, ignore_errors=True) + def import_archive( + self, + instance: VmInstance, + *, + archive_path: Path, + destination: str, + ) -> dict[str, Any]: + return _extract_seed_archive_to_host_workspace( + archive_path, + workspace_dir=_instance_workspace_host_dir(instance), + destination=destination, + ) + class FirecrackerBackend(VmBackend): # pragma: no cover """Host-gated backend that validates Firecracker prerequisites.""" @@ -562,6 +841,46 @@ class FirecrackerBackend(VmBackend): # pragma: no cover self._network_manager.cleanup(instance.network) shutil.rmtree(instance.workdir, ignore_errors=True) + def import_archive( + self, + instance: VmInstance, + *, + archive_path: Path, + destination: str, + ) -> dict[str, Any]: + if self._runtime_capabilities.supports_guest_exec: + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + deadline = time.monotonic() + 10 + while True: + try: + response = self._guest_exec_client.upload_archive( + guest_cid, + port, + archive_path, + destination=destination, + timeout_seconds=TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS, + uds_path=uds_path, + ) + return { + "destination": response.destination, + "entry_count": response.entry_count, + "bytes_written": response.bytes_written, + } + except (OSError, RuntimeError) as exc: + if time.monotonic() >= deadline: + raise RuntimeError( + f"guest archive transport did not become ready: {exc}" + ) from exc + time.sleep(0.2) + instance.metadata["execution_mode"] = "host_compat" + return _extract_seed_archive_to_host_workspace( + archive_path, + workspace_dir=_instance_workspace_host_dir(instance), + destination=destination, + ) + class VmManager: """In-process lifecycle manager for ephemeral VM environments and tasks.""" @@ -814,9 +1133,11 @@ class VmManager: ttl_seconds: int = DEFAULT_TTL_SECONDS, network: bool = False, allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT, + source_path: str | Path | None = None, ) -> dict[str, Any]: self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds) get_environment(environment, runtime_paths=self._runtime_paths) + prepared_seed = self._prepare_workspace_seed(source_path) now = time.time() task_id = uuid.uuid4().hex[:12] task_dir = self._task_dir(task_id) @@ -851,10 +1172,25 @@ class VmManager: f"max active VMs reached ({self._max_active_vms}); delete old VMs first" ) self._backend.create(instance) + if ( + prepared_seed.archive_path is not None + and self._runtime_capabilities.supports_guest_exec + ): + self._ensure_task_guest_seed_support(instance) with self._lock: self._start_instance_locked(instance) self._require_guest_exec_or_opt_in(instance) - if self._runtime_capabilities.supports_guest_exec: + workspace_seed = prepared_seed.to_payload() + if prepared_seed.archive_path is not None: + import_summary = self._backend.import_archive( + instance, + archive_path=prepared_seed.archive_path, + destination=TASK_WORKSPACE_GUEST_PATH, + ) + workspace_seed["entry_count"] = int(import_summary["entry_count"]) + workspace_seed["bytes_written"] = int(import_summary["bytes_written"]) + workspace_seed["destination"] = str(import_summary["destination"]) + elif self._runtime_capabilities.supports_guest_exec: self._backend.exec( instance, f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}", @@ -862,7 +1198,7 @@ class VmManager: ) else: instance.metadata["execution_mode"] = "host_compat" - task = TaskRecord.from_instance(instance) + task = TaskRecord.from_instance(instance, workspace_seed=workspace_seed) self._save_task_locked(task) return self._serialize_task(task) except Exception: @@ -879,6 +1215,8 @@ class VmManager: pass shutil.rmtree(task_dir, ignore_errors=True) raise + finally: + prepared_seed.cleanup() def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]: if timeout_seconds <= 0: @@ -999,6 +1337,7 @@ class VmManager: "tap_name": task.network.tap_name if task.network is not None else None, "execution_mode": task.metadata.get("execution_mode", "pending"), "workspace_path": TASK_WORKSPACE_GUEST_PATH, + "workspace_seed": _task_workspace_seed_dict(task.workspace_seed), "command_count": task.command_count, "last_command": task.last_command, "metadata": task.metadata, @@ -1094,6 +1433,53 @@ class VmManager: execution_mode = instance.metadata.get("execution_mode", "unknown") return exec_result, execution_mode + def _prepare_workspace_seed(self, source_path: str | Path | None) -> PreparedWorkspaceSeed: + if source_path is None: + return PreparedWorkspaceSeed(mode="empty", source_path=None) + resolved_source_path = Path(source_path).expanduser().resolve() + if not resolved_source_path.exists(): + raise ValueError(f"source_path {resolved_source_path} does not exist") + if resolved_source_path.is_dir(): + cleanup_dir = Path(tempfile.mkdtemp(prefix="pyro-task-seed-")) + archive_path = cleanup_dir / "workspace-seed.tar" + try: + _write_directory_seed_archive(resolved_source_path, archive_path) + entry_count, bytes_written = _inspect_seed_archive(archive_path) + except Exception: + shutil.rmtree(cleanup_dir, ignore_errors=True) + raise + return PreparedWorkspaceSeed( + mode="directory", + source_path=str(resolved_source_path), + archive_path=archive_path, + entry_count=entry_count, + bytes_written=bytes_written, + cleanup_dir=cleanup_dir, + ) + if ( + not resolved_source_path.is_file() + or not _is_supported_seed_archive(resolved_source_path) + ): + raise ValueError( + "source_path must be a directory or a .tar/.tar.gz/.tgz archive" + ) + entry_count, bytes_written = _inspect_seed_archive(resolved_source_path) + return PreparedWorkspaceSeed( + mode="tar_archive", + source_path=str(resolved_source_path), + archive_path=resolved_source_path, + entry_count=entry_count, + bytes_written=bytes_written, + ) + + def _ensure_task_guest_seed_support(self, instance: VmInstance) -> None: + if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None: + raise RuntimeError("runtime bundle does not provide a guest agent for task seeding") + rootfs_image = instance.metadata.get("rootfs_image") + if rootfs_image is None or rootfs_image == "": + raise RuntimeError("task rootfs image is unavailable for guest workspace seeding") + _patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path) + def _task_dir(self, task_id: str) -> Path: return self._tasks_dir / task_id diff --git a/tests/test_api.py b/tests/test_api.py index 7b79121..ca0b2b1 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -114,14 +114,23 @@ def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None: ) ) - created = pyro.create_task(environment="debian:12-base", allow_host_compat=True) + source_dir = tmp_path / "seed" + source_dir.mkdir() + (source_dir / "note.txt").write_text("ok\n", encoding="utf-8") + + created = pyro.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=source_dir, + ) task_id = str(created["task_id"]) - executed = pyro.exec_task(task_id, command="printf 'ok\\n'") + executed = pyro.exec_task(task_id, command="cat note.txt") status = pyro.status_task(task_id) logs = pyro.logs_task(task_id) deleted = pyro.delete_task(task_id) assert executed["stdout"] == "ok\n" + assert created["workspace_seed"]["mode"] == "directory" assert status["command_count"] == 1 assert logs["count"] == 1 assert deleted["deleted"] is True diff --git a/tests/test_cli.py b/tests/test_cli.py index 67859f7..180fbb4 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -60,9 +60,13 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None: assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help task_help = _subparser_choice(parser, "task").format_help() - assert "pyro task create debian:12" in task_help + assert "pyro task create debian:12 --source-path ./repo" in task_help assert "pyro task exec TASK_ID" in task_help + task_create_help = _subparser_choice(_subparser_choice(parser, "task"), "create").format_help() + assert "--source-path" in task_create_help + assert "seed into `/workspace`" in task_create_help + task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help() assert "persistent `/workspace`" in task_exec_help assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help @@ -326,12 +330,20 @@ def test_cli_requires_run_command() -> None: cli._require_command([]) +def test_cli_requires_command_preserves_shell_argument_boundaries() -> None: + command = cli._require_command( + ["--", "sh", "-lc", 'printf "hello from task\\n" > note.txt'] + ) + assert command == 'sh -lc \'printf "hello from task\\n" > note.txt\'' + + def test_cli_task_create_prints_json( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] ) -> None: class StubPyro: def create_task(self, **kwargs: Any) -> dict[str, Any]: assert kwargs["environment"] == "debian:12" + assert kwargs["source_path"] == "./repo" return {"task_id": "task-123", "state": "started"} class StubParser: @@ -345,6 +357,7 @@ def test_cli_task_create_prints_json( ttl_seconds=600, network=False, allow_host_compat=False, + source_path="./repo", json=True, ) @@ -366,6 +379,13 @@ def test_cli_task_create_prints_human( "environment": "debian:12", "state": "started", "workspace_path": "/workspace", + "workspace_seed": { + "mode": "directory", + "source_path": "/tmp/repo", + "destination": "/workspace", + "entry_count": 1, + "bytes_written": 6, + }, "execution_mode": "guest_vsock", "vcpu_count": 1, "mem_mib": 1024, @@ -384,6 +404,7 @@ def test_cli_task_create_prints_human( ttl_seconds=600, network=False, allow_host_compat=False, + source_path="/tmp/repo", json=False, ) @@ -393,6 +414,7 @@ def test_cli_task_create_prints_human( output = capsys.readouterr().out assert "Task: task-123" in output assert "Workspace: /workspace" in output + assert "Workspace seed: directory from /tmp/repo" in output def test_cli_task_exec_prints_human_output( diff --git a/tests/test_public_contract.py b/tests/test_public_contract.py index bf07ad4..5f052f0 100644 --- a/tests/test_public_contract.py +++ b/tests/test_public_contract.py @@ -17,6 +17,7 @@ from pyro_mcp.contract import ( PUBLIC_CLI_DEMO_SUBCOMMANDS, PUBLIC_CLI_ENV_SUBCOMMANDS, PUBLIC_CLI_RUN_FLAGS, + PUBLIC_CLI_TASK_CREATE_FLAGS, PUBLIC_CLI_TASK_SUBCOMMANDS, PUBLIC_MCP_TOOLS, PUBLIC_SDK_METHODS, @@ -67,6 +68,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None: task_help_text = _subparser_choice(parser, "task").format_help() for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS: assert subcommand_name in task_help_text + task_create_help_text = _subparser_choice( + _subparser_choice(parser, "task"), "create" + ).format_help() + for flag in PUBLIC_CLI_TASK_CREATE_FLAGS: + assert flag in task_create_help_text demo_help_text = _subparser_choice(parser, "demo").format_help() for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS: diff --git a/tests/test_server.py b/tests/test_server.py index be8e1db..7ddfb60 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -171,6 +171,9 @@ def test_task_tools_round_trip(tmp_path: Path) -> None: base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) + source_dir = tmp_path / "seed" + source_dir.mkdir() + (source_dir / "note.txt").write_text("ok\n", encoding="utf-8") def _extract_structured(raw_result: object) -> dict[str, Any]: if not isinstance(raw_result, tuple) or len(raw_result) != 2: @@ -188,6 +191,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None: { "environment": "debian:12-base", "allow_host_compat": True, + "source_path": str(source_dir), }, ) ) @@ -197,7 +201,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None: "task_exec", { "task_id": task_id, - "command": "printf 'ok\\n'", + "command": "cat note.txt", }, ) ) @@ -207,6 +211,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None: created, executed, logs, deleted = asyncio.run(_run()) assert created["state"] == "started" + assert created["workspace_seed"]["mode"] == "directory" assert executed["stdout"] == "ok\n" assert logs["count"] == 1 assert deleted["deleted"] is True diff --git a/tests/test_vm_guest.py b/tests/test_vm_guest.py index fe51894..ee4773e 100644 --- a/tests/test_vm_guest.py +++ b/tests/test_vm_guest.py @@ -1,6 +1,10 @@ from __future__ import annotations +import io +import json import socket +import tarfile +from pathlib import Path import pytest @@ -62,6 +66,45 @@ def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None: assert stub.closed is True +def test_vsock_exec_client_upload_archive_round_trip( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) + archive_path = tmp_path / "seed.tgz" + with tarfile.open(archive_path, "w:gz") as archive: + payload = b"hello\n" + info = tarfile.TarInfo(name="note.txt") + info.size = len(payload) + archive.addfile(info, io.BytesIO(payload)) + stub = StubSocket( + b'{"destination":"/workspace","entry_count":1,"bytes_written":6}' + ) + + def socket_factory(family: int, sock_type: int) -> StubSocket: + assert family == socket.AF_VSOCK + assert sock_type == socket.SOCK_STREAM + return stub + + client = VsockExecClient(socket_factory=socket_factory) + response = client.upload_archive( + 1234, + 5005, + archive_path, + destination="/workspace", + timeout_seconds=60, + ) + + request_payload, archive_payload = stub.sent.split(b"\n", 1) + request = json.loads(request_payload.decode("utf-8")) + assert request["action"] == "extract_archive" + assert request["destination"] == "/workspace" + assert int(request["archive_size"]) == archive_path.stat().st_size + assert archive_payload == archive_path.read_bytes() + assert response.entry_count == 1 + assert response.bytes_written == 6 + assert stub.closed is True + + def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) stub = StubSocket(b"[]") diff --git a/tests/test_vm_manager.py b/tests/test_vm_manager.py index 560ba95..9810631 100644 --- a/tests/test_vm_manager.py +++ b/tests/test_vm_manager.py @@ -1,7 +1,9 @@ from __future__ import annotations +import io import json import subprocess +import tarfile import time from pathlib import Path from typing import Any @@ -306,6 +308,140 @@ def test_task_lifecycle_and_logs(tmp_path: Path) -> None: manager.status_task(task_id) +def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None: + source_dir = tmp_path / "seed" + source_dir.mkdir() + (source_dir / "note.txt").write_text("hello\n", encoding="utf-8") + + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + created = manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=source_dir, + ) + task_id = str(created["task_id"]) + + workspace_seed = created["workspace_seed"] + assert workspace_seed["mode"] == "directory" + assert workspace_seed["source_path"] == str(source_dir.resolve()) + executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) + assert executed["stdout"] == "hello\n" + + status = manager.status_task(task_id) + assert status["workspace_seed"]["mode"] == "directory" + assert status["workspace_seed"]["source_path"] == str(source_dir.resolve()) + + +def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None: + archive_path = tmp_path / "seed.tgz" + nested_dir = tmp_path / "src" + nested_dir.mkdir() + (nested_dir / "note.txt").write_text("archive\n", encoding="utf-8") + with tarfile.open(archive_path, "w:gz") as archive: + archive.add(nested_dir / "note.txt", arcname="note.txt") + + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + created = manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=archive_path, + ) + task_id = str(created["task_id"]) + + assert created["workspace_seed"]["mode"] == "tar_archive" + executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) + assert executed["stdout"] == "archive\n" + + +def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None: + archive_path = tmp_path / "bad.tgz" + with tarfile.open(archive_path, "w:gz") as archive: + payload = b"bad\n" + info = tarfile.TarInfo(name="../escape.txt") + info.size = len(payload) + archive.addfile(info, io.BytesIO(payload)) + + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + with pytest.raises(RuntimeError, match="unsafe archive member path"): + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=archive_path, + ) + assert list((tmp_path / "vms" / "tasks").iterdir()) == [] + + +def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None: + archive_path = tmp_path / "bad-symlink.tgz" + with tarfile.open(archive_path, "w:gz") as archive: + symlink_info = tarfile.TarInfo(name="linked") + symlink_info.type = tarfile.SYMTYPE + symlink_info.linkname = "outside" + archive.addfile(symlink_info) + + payload = b"bad\n" + file_info = tarfile.TarInfo(name="linked/note.txt") + file_info.size = len(payload) + archive.addfile(file_info, io.BytesIO(payload)) + + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + with pytest.raises(RuntimeError, match="traverse through a symlinked path"): + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=archive_path, + ) + + +def test_task_create_cleans_up_on_seed_failure( + tmp_path: Path, monkeypatch: pytest.MonkeyPatch +) -> None: + source_dir = tmp_path / "seed" + source_dir.mkdir() + (source_dir / "note.txt").write_text("hello\n", encoding="utf-8") + + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + + def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]: + del args, kwargs + raise RuntimeError("seed import failed") + + monkeypatch.setattr(manager._backend, "import_archive", _boom) # noqa: SLF001 + + with pytest.raises(RuntimeError, match="seed import failed"): + manager.create_task( + environment="debian:12-base", + allow_host_compat=True, + source_path=source_dir, + ) + + assert list((tmp_path / "vms" / "tasks").iterdir()) == [] + + def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None: base_dir = tmp_path / "vms" manager = VmManager( diff --git a/uv.lock b/uv.lock index f97ee73..cda5b9c 100644 --- a/uv.lock +++ b/uv.lock @@ -706,7 +706,7 @@ crypto = [ [[package]] name = "pyro-mcp" -version = "2.1.0" +version = "2.2.0" source = { editable = "." } dependencies = [ { name = "mcp" },