Add seeded task workspace creation

Persistent tasks previously started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces.

Expose `source_path` on task create across the CLI, SDK, and MCP; add safe archive upload and extraction support for the guest and host-compat backends; persist `workspace_seed` metadata; and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch reconstruction of the command after the `--` separator to `shlex.join()` so documented `sh -lc` task examples preserve argument boundaries.
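The `shlex.join()` change is easiest to see on an argv whose final element is a whole shell script containing spaces and quotes, as in the documented `sh -lc` task examples; a minimal stdlib sketch:

```python
import shlex

# The documented task example: the final argument is a full shell
# script that must survive as a single argv element.
args = ["sh", "-lc", 'printf "hello from task\\n" > note.txt']

naive = " ".join(args)   # collapses argument boundaries
safe = shlex.join(args)  # quotes each element individually

# Round-tripping through shlex.split recovers the original argv only
# for the shlex.join form.
assert shlex.split(safe) == args
assert shlex.split(naive) != args
print(safe)  # → sh -lc 'printf "hello from task\n" > note.txt'
```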

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
Thales Maciel 2026-03-11 21:45:38 -03:00
parent 58df176148
commit aa886b346e
25 changed files with 1076 additions and 75 deletions


@ -2,6 +2,15 @@
All notable user-visible changes to `pyro-mcp` are documented here.
## 2.2.0
- Added seeded task creation across the CLI, Python SDK, and MCP server with an optional
`source_path` for host directories and `.tar` / `.tar.gz` / `.tgz` archives.
- Seeded task workspaces now persist `workspace_seed` metadata so later status calls report how
`/workspace` was initialized.
- Reused the task workspace model from `2.1.0` while adding the first explicit host-to-task
content import path for repeated command workflows.
## 2.1.0
- Added the first persistent task workspace alpha across the CLI, Python SDK, and MCP server.


@ -18,7 +18,7 @@ It exposes the same runtime in three public forms:
- First run transcript: [docs/first-run.md](docs/first-run.md)
- Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif)
- PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/)
- What's new in 2.1.0: [CHANGELOG.md#210](CHANGELOG.md#210)
- What's new in 2.2.0: [CHANGELOG.md#220](CHANGELOG.md#220)
- Host requirements: [docs/host-requirements.md](docs/host-requirements.md)
- Integration targets: [docs/integrations.md](docs/integrations.md)
- Public contract: [docs/public-contract.md](docs/public-contract.md)
@ -55,7 +55,7 @@ What success looks like:
```bash
Platform: linux-x86_64
Runtime: PASS
Catalog version: 2.1.0
Catalog version: 2.2.0
...
[pull] phase=install environment=debian:12
[pull] phase=ready environment=debian:12
@ -74,7 +74,7 @@ access to `registry-1.docker.io`, and needs local cache space for the guest imag
After the quickstart works:
- prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo`
- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12`
- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo`
- move to Python or MCP via [docs/integrations.md](docs/integrations.md)
## Supported Hosts
@ -128,7 +128,7 @@ uvx --from pyro-mcp pyro env list
Expected output:
```bash
Catalog version: 2.1.0
Catalog version: 2.2.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@ -198,7 +198,7 @@ Use `pyro run` for one-shot commands. Use `pyro task ...` when you need repeated
workspace without recreating the sandbox every time.
```bash
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
pyro task exec TASK_ID -- cat note.txt
pyro task logs TASK_ID
@ -206,7 +206,9 @@ pyro task delete TASK_ID
```
Task workspaces start in `/workspace` and keep command history until you delete them. For machine
consumption, add `--json` and read the returned `task_id`.
consumption, add `--json` and read the returned `task_id`. Use `--source-path` when you want the
task to start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive instead of an
empty workspace.
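A seed archive is an ordinary tarball with workspace-relative member names; a minimal stdlib sketch (paths and file contents are illustrative) of producing a `.tar.gz` in one of the accepted formats:

```python
import tarfile
import tempfile
from pathlib import Path

# Illustrative host content to seed from; real callers would point at
# an existing checkout instead.
src = Path(tempfile.mkdtemp()) / "repo"
src.mkdir()
(src / "note.txt").write_text("hello from host\n")

# Package it as one of the accepted formats (.tar/.tar.gz/.tgz) with
# workspace-relative member names: no leading "/" and no "..".
archive = src.parent / "repo.tar.gz"
with tarfile.open(archive, "w:gz") as tar:
    for path in sorted(src.rglob("*")):
        tar.add(path, arcname=str(path.relative_to(src)))

with tarfile.open(archive, "r:*") as tar:
    print(tar.getnames())  # → ['note.txt']
```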
## Public Interfaces
@ -348,7 +350,7 @@ For repeated commands in one workspace:
from pyro_mcp import Pyro
pyro = Pyro()
task = pyro.create_task(environment="debian:12")
task = pyro.create_task(environment="debian:12", source_path="./repo")
task_id = task["task_id"]
try:
pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt")
@ -378,7 +380,7 @@ Advanced lifecycle tools:
Persistent workspace tools:
- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)`
- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false, source_path=null)`
- `task_exec(task_id, command, timeout_seconds=30)`
- `task_status(task_id)`
- `task_logs(task_id)`


@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes
```bash
$ uvx --from pyro-mcp pyro env list
Catalog version: 2.1.0
Catalog version: 2.2.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@ -70,7 +70,7 @@ deterministic structured result.
```bash
$ uvx --from pyro-mcp pyro demo
$ uvx --from pyro-mcp pyro task create debian:12
$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo
$ uvx --from pyro-mcp pyro mcp serve
```
@ -79,11 +79,12 @@ $ uvx --from pyro-mcp pyro mcp serve
When you need repeated commands in one sandbox, switch to `pyro task ...`:
```bash
$ uvx --from pyro-mcp pyro task create debian:12
$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo
Task: ...
Environment: debian:12
State: started
Workspace: /workspace
Workspace seed: directory from ...
Execution mode: guest_vsock
Resources: 1 vCPU / 1024 MiB
Command count: 0
@ -96,6 +97,9 @@ hello from task
[task-exec] task_id=... sequence=2 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=...
```
Use `--source-path` when the task should start from a host directory or a local
`.tar` / `.tar.gz` / `.tgz` archive instead of an empty `/workspace`.
Example output:
```json


@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list
Expected output:
```bash
Catalog version: 2.1.0
Catalog version: 2.2.0
debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@ -174,7 +174,7 @@ pyro run debian:12 -- git --version
After the CLI path works, you can move on to:
- persistent workspaces: `pyro task create debian:12`
- persistent workspaces: `pyro task create debian:12 --source-path ./repo`
- MCP: `pyro mcp serve`
- Python SDK: `from pyro_mcp import Pyro`
- Demos: `pyro demo` or `pyro demo --network`
@ -184,7 +184,7 @@ After the CLI path works, you can move on to:
Use `pyro task ...` when you need repeated commands in one sandbox instead of one-shot `pyro run`.
```bash
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
pyro task exec TASK_ID -- cat note.txt
pyro task logs TASK_ID
@ -192,7 +192,8 @@ pyro task delete TASK_ID
```
Task commands default to the persistent `/workspace` directory inside the guest. If you need the
task identifier programmatically, use `--json` and read the `task_id` field.
task identifier programmatically, use `--json` and read the `task_id` field. Use `--source-path`
when the task should start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive.
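A caller that wants to know up front whether a `--source-path` argument will be treated as an archive can check the documented suffixes; `is_seed_archive` below is an illustrative helper, not part of the CLI:

```python
from pathlib import Path

# The archive formats documented for --source-path; any other path is
# treated as a directory here. Illustrative helper, not CLI code.
ARCHIVE_SUFFIXES = (".tar", ".tar.gz", ".tgz")


def is_seed_archive(source_path: str) -> bool:
    return Path(source_path).name.endswith(ARCHIVE_SUFFIXES)


for candidate in ("./repo", "repo.tar.gz", "repo.tgz", "repo.zip"):
    kind = "archive" if is_seed_archive(candidate) else "directory"
    print(f"{candidate}: {kind}")
```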
## Contributor Clone


@ -30,7 +30,7 @@ Best when:
Recommended surface:
- `vm_run`
- `task_create` + `task_exec` when the agent needs persistent workspace state
- `task_create(source_path=...)` + `task_exec` when the agent needs persistent workspace state
Canonical example:
@ -65,14 +65,15 @@ Best when:
Recommended default:
- `Pyro.run_in_vm(...)`
- `Pyro.create_task(...)` + `Pyro.exec_task(...)` when repeated workspace commands are required
- `Pyro.create_task(source_path=...)` + `Pyro.exec_task(...)` when repeated workspace commands are required
Lifecycle note:
- `Pyro.exec_vm(...)` runs one command and auto-cleans the VM afterward
- use `create_vm(...)` + `start_vm(...)` only when you need pre-exec inspection or status before
that final exec
- use `create_task(...)` when the agent needs repeated commands in one persistent `/workspace`
- use `create_task(source_path=...)` when the agent needs repeated commands in one persistent
`/workspace` that starts from host content
Examples:


@ -46,8 +46,12 @@ Behavioral guarantees:
- `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`.
- `pyro demo ollama` prints log lines plus a final summary line.
- `pyro task create` auto-starts a persistent workspace.
- `pyro task create --source-path PATH` seeds `/workspace` from a host directory or a local
`.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
- `pyro task exec` runs in the persistent `/workspace` for that task and does not auto-clean.
- `pyro task logs` returns persisted command history for that task until `pyro task delete`.
- Task create/status results expose `workspace_seed` metadata describing how `/workspace` was
initialized.
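The CLI's human output renders that metadata as a single `Workspace seed:` line; a slightly simplified sketch of the same rendering over the documented payload shape (`mode` plus optional `source_path`):

```python
from typing import Any


def seed_summary(payload: dict[str, Any]) -> str:
    """Render the `Workspace seed:` line for a task create/status payload."""
    seed = payload.get("workspace_seed")
    if not isinstance(seed, dict):
        # Simplification: the real CLI omits the line when metadata is absent.
        return "Workspace seed: empty"
    mode = str(seed.get("mode", "empty"))
    source_path = seed.get("source_path")
    if isinstance(source_path, str) and source_path:
        return f"Workspace seed: {mode} from {source_path}"
    return f"Workspace seed: {mode}"


print(seed_summary({"workspace_seed": {"mode": "directory", "source_path": "./repo"}}))
# → Workspace seed: directory from ./repo
```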
## Python SDK Contract
@ -106,6 +110,8 @@ Behavioral defaults:
- `Pyro.create_task(...)` defaults to `vcpu_count=1` and `mem_mib=1024`.
- `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`.
- `allow_host_compat` defaults to `False` on `create_task(...)`.
- `Pyro.create_task(..., source_path=...)` seeds `/workspace` from a host directory or a local
`.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
- `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes.
- `Pyro.exec_task(...)` runs one command in the persistent task workspace and leaves the task alive.
@ -141,6 +147,8 @@ Behavioral defaults:
- `task_create` defaults to `vcpu_count=1` and `mem_mib=1024`.
- `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`.
- `task_create` exposes `allow_host_compat`, which defaults to `false`.
- `task_create` accepts optional `source_path` and seeds `/workspace` from a host directory or a
local `.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
- `vm_exec` runs one command and auto-cleans that VM after the exec completes.
- `task_exec` runs one command in a persistent `/workspace` and leaves the task alive.


@ -1,6 +1,6 @@
[project]
name = "pyro-mcp"
version = "2.1.0"
version = "2.2.0"
description = "Curated Linux environments for ephemeral Firecracker-backed VM execution."
readme = "README.md"
license = { file = "LICENSE" }


@ -34,7 +34,7 @@ Kernel build note:
Current status:
1. Firecracker and Jailer are materialized from pinned official release artifacts.
2. The kernel and rootfs images are built from pinned inputs into `build/runtime_sources/`.
3. The guest agent is installed into each rootfs and used for vsock exec.
3. The guest agent is installed into each rootfs and used for vsock exec plus workspace archive imports.
4. `runtime.lock.json` now advertises real guest capabilities.
Safety rule:


@ -1,26 +1,31 @@
#!/usr/bin/env python3
"""Minimal guest-side exec agent for pyro runtime bundles."""
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
from __future__ import annotations
import io
import json
import os
import socket
import subprocess
import tarfile
import time
from pathlib import Path, PurePosixPath
from typing import Any
PORT = 5005
BUFFER_SIZE = 65536
WORKSPACE_ROOT = PurePosixPath("/workspace")
def _read_request(conn: socket.socket) -> dict[str, Any]:
chunks: list[bytes] = []
while True:
data = conn.recv(BUFFER_SIZE)
data = conn.recv(1)
if data == b"":
break
chunks.append(data)
if b"\n" in data:
if data == b"\n":
break
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
if not isinstance(payload, dict):
@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]:
return payload
def _read_exact(conn: socket.socket, size: int) -> bytes:
remaining = size
chunks: list[bytes] = []
while remaining > 0:
data = conn.recv(min(BUFFER_SIZE, remaining))
if data == b"":
raise RuntimeError("unexpected EOF while reading archive payload")
chunks.append(data)
remaining -= len(data)
return b"".join(chunks)
def _normalize_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
candidate = destination.strip()
if candidate == "":
raise RuntimeError("destination must not be empty")
destination_path = PurePosixPath(candidate)
if not destination_path.is_absolute():
destination_path = WORKSPACE_ROOT / destination_path
parts = [part for part in destination_path.parts if part not in {"", "."}]
normalized = PurePosixPath("/") / PurePosixPath(*parts)
if normalized == PurePosixPath("/"):
raise RuntimeError("destination must stay inside /workspace")
if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
raise RuntimeError("destination must stay inside /workspace")
suffix = normalized.relative_to(WORKSPACE_ROOT)
host_path = Path("/workspace")
if str(suffix) not in {"", "."}:
host_path = host_path / str(suffix)
return normalized, host_path
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_path} has an empty target")
target_path = PurePosixPath(target)
if target_path.is_absolute():
raise RuntimeError(f"symlink {member_path} escapes the workspace")
combined = member_path.parent.joinpath(target_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_path} escapes the workspace")
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
_, destination_root = _normalize_destination(destination)
destination_root.mkdir(parents=True, exist_ok=True)
bytes_written = 0
entry_count = 0
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_member_name(member.name)
target_path = destination_root / str(member_name)
entry_count += 1
_ensure_no_symlink_parents(destination_root, target_path, member.name)
if member.isdir():
if target_path.exists() and not target_path.is_dir():
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
target_path.mkdir(parents=True, exist_ok=True)
continue
if member.isfile():
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()):
raise RuntimeError(f"file conflicts with existing path: {member.name}")
source = archive.extractfile(member)
if source is None:
raise RuntimeError(f"failed to read archive member: {member.name}")
with target_path.open("wb") as handle:
while True:
chunk = source.read(BUFFER_SIZE)
if chunk == b"":
break
handle.write(chunk)
bytes_written += member.size
continue
if member.issym():
_validate_symlink_target(member_name, member.linkname)
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.exists() and not target_path.is_symlink():
raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
if target_path.is_symlink():
target_path.unlink()
os.symlink(member.linkname, target_path)
continue
if member.islnk():
raise RuntimeError(
f"hard links are not allowed in workspace archives: {member.name}"
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return {
"destination": destination,
"entry_count": entry_count,
"bytes_written": bytes_written,
}
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
started = time.monotonic()
try:
@ -64,9 +193,18 @@ def main() -> None:
conn, _ = server.accept()
with conn:
request = _read_request(conn)
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
response = _extract_archive(payload, destination)
else:
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
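The switch to one-byte header reads matters because an `extract_archive` request is followed on the same stream by raw archive bytes; reading the header in large chunks could swallow the start of the payload. A self-contained sketch of that framing over a `socketpair` (stand-in for the vsock connection):

```python
import json
import socket


def read_line(conn: socket.socket) -> bytes:
    """Read the newline-terminated JSON header one byte at a time so no
    bytes of the binary payload that follows are consumed."""
    chunks: list[bytes] = []
    while True:
        data = conn.recv(1)
        if data in (b"", b"\n"):
            break
        chunks.append(data)
    return b"".join(chunks)


def read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly `size` payload bytes after the header."""
    chunks: list[bytes] = []
    remaining = size
    while remaining > 0:
        data = conn.recv(min(65536, remaining))
        if data == b"":
            raise RuntimeError("unexpected EOF while reading payload")
        chunks.append(data)
        remaining -= len(data)
    return b"".join(chunks)


client, server = socket.socketpair()
payload = b"fake-archive-bytes"
header = {"action": "extract_archive", "archive_size": len(payload)}
client.sendall(json.dumps(header).encode("utf-8") + b"\n" + payload)

request = json.loads(read_line(server))
body = read_exact(server, int(request["archive_size"]))
client.close()
server.close()
print(request["action"], len(body))  # → extract_archive 18
```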


@ -5,7 +5,7 @@
"firecracker": "1.12.1",
"jailer": "1.12.1",
"kernel": "5.10.210",
"guest_agent": "0.1.0-dev",
"guest_agent": "0.2.0-dev",
"base_distro": "debian-bookworm-20250210"
},
"capabilities": {


@ -86,6 +86,7 @@ class Pyro:
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
source_path: str | Path | None = None,
) -> dict[str, Any]:
return self._manager.create_task(
environment=environment,
@ -94,6 +95,7 @@ class Pyro:
ttl_seconds=ttl_seconds,
network=network,
allow_host_compat=allow_host_compat,
source_path=source_path,
)
def exec_task(
@ -245,6 +247,7 @@ class Pyro:
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
source_path: str | None = None,
) -> dict[str, Any]:
"""Create and start a persistent task workspace."""
return self.create_task(
@ -254,6 +257,7 @@ class Pyro:
ttl_seconds=ttl_seconds,
network=network,
allow_host_compat=allow_host_compat,
source_path=source_path,
)
@server.tool()


@ -4,6 +4,7 @@ from __future__ import annotations
import argparse
import json
import shlex
import sys
from textwrap import dedent
from typing import Any
@ -155,6 +156,14 @@ def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None:
print(f"Environment: {str(payload.get('environment', 'unknown'))}")
print(f"State: {str(payload.get('state', 'unknown'))}")
print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
workspace_seed = payload.get("workspace_seed")
if isinstance(workspace_seed, dict):
mode = str(workspace_seed.get("mode", "empty"))
source_path = workspace_seed.get("source_path")
if isinstance(source_path, str) and source_path != "":
print(f"Workspace seed: {mode} from {source_path}")
else:
print(f"Workspace seed: {mode}")
print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}")
print(
f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / "
@ -446,7 +455,7 @@ def _build_parser() -> argparse.ArgumentParser:
epilog=dedent(
"""
Examples:
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
pyro task logs TASK_ID
"""
@ -458,7 +467,13 @@ def _build_parser() -> argparse.ArgumentParser:
"create",
help="Create and start a persistent task workspace.",
description="Create a task workspace that stays alive across repeated exec calls.",
epilog="Example:\n pyro task create debian:12",
epilog=dedent(
"""
Examples:
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
"""
),
formatter_class=_HelpFormatter,
)
task_create_parser.add_argument(
@ -497,6 +512,13 @@ def _build_parser() -> argparse.ArgumentParser:
"is unavailable."
),
)
task_create_parser.add_argument(
"--source-path",
help=(
"Optional host directory or .tar/.tar.gz/.tgz archive to seed into `/workspace` "
"before the task is returned."
),
)
task_create_parser.add_argument(
"--json",
action="store_true",
@ -663,7 +685,7 @@ def _require_command(command_args: list[str]) -> str:
command_args = command_args[1:]
if not command_args:
raise ValueError("command is required after `--`")
return " ".join(command_args)
return shlex.join(command_args)
def main() -> None:
@ -764,6 +786,7 @@ def main() -> None:
ttl_seconds=args.ttl_seconds,
network=args.network,
allow_host_compat=args.allow_host_compat,
source_path=args.source_path,
)
if bool(args.json):
_print_json(payload)


@ -6,6 +6,15 @@ PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status")
PUBLIC_CLI_TASK_CREATE_FLAGS = (
"--vcpu-count",
"--mem-mib",
"--ttl-seconds",
"--network",
"--allow-host-compat",
"--source-path",
"--json",
)
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
"--mem-mib",


@ -1,26 +1,31 @@
#!/usr/bin/env python3
"""Minimal guest-side exec agent for pyro runtime bundles."""
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
from __future__ import annotations
import io
import json
import os
import socket
import subprocess
import tarfile
import time
from pathlib import Path, PurePosixPath
from typing import Any
PORT = 5005
BUFFER_SIZE = 65536
WORKSPACE_ROOT = PurePosixPath("/workspace")
def _read_request(conn: socket.socket) -> dict[str, Any]:
chunks: list[bytes] = []
while True:
data = conn.recv(BUFFER_SIZE)
data = conn.recv(1)
if data == b"":
break
chunks.append(data)
if b"\n" in data:
if data == b"\n":
break
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
if not isinstance(payload, dict):
@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]:
return payload
def _read_exact(conn: socket.socket, size: int) -> bytes:
remaining = size
chunks: list[bytes] = []
while remaining > 0:
data = conn.recv(min(BUFFER_SIZE, remaining))
if data == b"":
raise RuntimeError("unexpected EOF while reading archive payload")
chunks.append(data)
remaining -= len(data)
return b"".join(chunks)
def _normalize_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
candidate = destination.strip()
if candidate == "":
raise RuntimeError("destination must not be empty")
destination_path = PurePosixPath(candidate)
if not destination_path.is_absolute():
destination_path = WORKSPACE_ROOT / destination_path
parts = [part for part in destination_path.parts if part not in {"", "."}]
normalized = PurePosixPath("/") / PurePosixPath(*parts)
if normalized == PurePosixPath("/"):
raise RuntimeError("destination must stay inside /workspace")
if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
raise RuntimeError("destination must stay inside /workspace")
suffix = normalized.relative_to(WORKSPACE_ROOT)
host_path = Path("/workspace")
if str(suffix) not in {"", "."}:
host_path = host_path / str(suffix)
return normalized, host_path
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_path} has an empty target")
target_path = PurePosixPath(target)
if target_path.is_absolute():
raise RuntimeError(f"symlink {member_path} escapes the workspace")
combined = member_path.parent.joinpath(target_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_path} escapes the workspace")
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
_, destination_root = _normalize_destination(destination)
destination_root.mkdir(parents=True, exist_ok=True)
bytes_written = 0
entry_count = 0
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_member_name(member.name)
target_path = destination_root / str(member_name)
entry_count += 1
_ensure_no_symlink_parents(destination_root, target_path, member.name)
if member.isdir():
if target_path.exists() and not target_path.is_dir():
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
target_path.mkdir(parents=True, exist_ok=True)
continue
if member.isfile():
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()):
raise RuntimeError(f"file conflicts with existing path: {member.name}")
source = archive.extractfile(member)
if source is None:
raise RuntimeError(f"failed to read archive member: {member.name}")
with target_path.open("wb") as handle:
while True:
chunk = source.read(BUFFER_SIZE)
if chunk == b"":
break
handle.write(chunk)
bytes_written += member.size
continue
if member.issym():
_validate_symlink_target(member_name, member.linkname)
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.exists() and not target_path.is_symlink():
raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
if target_path.is_symlink():
target_path.unlink()
os.symlink(member.linkname, target_path)
continue
if member.islnk():
raise RuntimeError(
f"hard links are not allowed in workspace archives: {member.name}"
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return {
"destination": destination,
"entry_count": entry_count,
"bytes_written": bytes_written,
}
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
started = time.monotonic()
try:
@ -64,9 +193,18 @@ def main() -> None:
conn, _ = server.accept()
with conn:
request = _read_request(conn)
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
response = _extract_archive(payload, destination)
else:
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))


@ -18,14 +18,14 @@
"component_versions": {
"base_distro": "debian-bookworm-20250210",
"firecracker": "1.12.1",
"guest_agent": "0.1.0-dev",
"guest_agent": "0.2.0-dev",
"jailer": "1.12.1",
"kernel": "5.10.210"
},
"guest": {
"agent": {
"path": "guest/pyro_guest_agent.py",
"sha256": "65bf8a9a57ffd7321463537e598c4b30f0a13046cbd4538f1b65bc351da5d3c0"
"sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d"
}
},
"platform": "linux-x86_64",


@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.1.0"
DEFAULT_CATALOG_VERSION = "2.2.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",


@ -5,6 +5,7 @@ from __future__ import annotations
import json
import socket
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Protocol
@ -31,6 +32,13 @@ class GuestExecResponse:
duration_ms: int
@dataclass(frozen=True)
class GuestArchiveResponse:
destination: str
entry_count: int
bytes_written: int
class VsockExecClient:
"""Minimal JSON-over-stream client for a guest exec agent."""
@ -50,6 +58,63 @@ class VsockExecClient:
"command": command,
"timeout_seconds": timeout_seconds,
}
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
try:
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
payload = self._recv_json_payload(sock)
finally:
sock.close()
if not isinstance(payload, dict):
raise RuntimeError("guest exec response must be a JSON object")
return GuestExecResponse(
stdout=str(payload.get("stdout", "")),
stderr=str(payload.get("stderr", "")),
exit_code=int(payload.get("exit_code", -1)),
duration_ms=int(payload.get("duration_ms", 0)),
)
def upload_archive(
self,
guest_cid: int,
port: int,
archive_path: Path,
*,
destination: str,
timeout_seconds: int = 60,
uds_path: str | None = None,
) -> GuestArchiveResponse:
request = {
"action": "extract_archive",
"destination": destination,
"archive_size": archive_path.stat().st_size,
}
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
try:
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
with archive_path.open("rb") as handle:
for chunk in iter(lambda: handle.read(65536), b""):
sock.sendall(chunk)
payload = self._recv_json_payload(sock)
finally:
sock.close()
if not isinstance(payload, dict):
raise RuntimeError("guest archive response must be a JSON object")
return GuestArchiveResponse(
destination=str(payload.get("destination", destination)),
entry_count=int(payload.get("entry_count", 0)),
bytes_written=int(payload.get("bytes_written", 0)),
)
def _connect(
self,
guest_cid: int,
port: int,
timeout_seconds: int,
*,
uds_path: str | None,
) -> SocketLike:
family = getattr(socket, "AF_VSOCK", None)
if family is not None:
sock = self._socket_factory(family, socket.SOCK_STREAM)
@ -59,33 +124,15 @@ class VsockExecClient:
connect_address = uds_path
else:
raise RuntimeError("vsock sockets are not supported on this host Python runtime")
try:
sock.settimeout(timeout_seconds)
sock.connect(connect_address)
if family is None:
sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
status = self._recv_line(sock)
if not status.startswith("OK "):
raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
chunks: list[bytes] = []
while True:
data = sock.recv(65536)
if data == b"":
break
chunks.append(data)
finally:
sock.close()
payload = json.loads(b"".join(chunks).decode("utf-8"))
if not isinstance(payload, dict):
raise RuntimeError("guest exec response must be a JSON object")
return GuestExecResponse(
stdout=str(payload.get("stdout", "")),
stderr=str(payload.get("stderr", "")),
exit_code=int(payload.get("exit_code", -1)),
duration_ms=int(payload.get("duration_ms", 0)),
)
sock.settimeout(timeout_seconds)
sock.connect(connect_address)
if family is None:
sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
status = self._recv_line(sock)
if not status.startswith("OK "):
sock.close()
raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
return sock
@staticmethod
def _recv_line(sock: SocketLike) -> str:
@@ -98,3 +145,13 @@ class VsockExecClient:
if data == b"\n":
break
return b"".join(chunks).decode("utf-8", errors="replace")
@staticmethod
def _recv_json_payload(sock: SocketLike) -> Any:
chunks: list[bytes] = []
while True:
data = sock.recv(65536)
if data == b"":
break
chunks.append(data)
return json.loads(b"".join(chunks).decode("utf-8"))
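The client above frames an `extract_archive` request as one newline-terminated JSON header followed by exactly `archive_size` raw bytes. The guest agent itself is not part of this diff, but the receiving side of that framing can be sketched roughly like this (a hypothetical helper, assuming only what `upload_archive` sends):

```python
import io
import json

def read_archive_frame(stream: io.BufferedIOBase) -> tuple[dict, bytes]:
    # Read the JSON header one byte at a time until the newline terminator.
    header_bytes = bytearray()
    while True:
        byte = stream.read(1)
        if byte in (b"", b"\n"):
            break
        header_bytes += byte
    header = json.loads(header_bytes.decode("utf-8"))
    # Then read exactly archive_size bytes of raw tar data.
    remaining = int(header["archive_size"])
    chunks: list[bytes] = []
    while remaining > 0:
        chunk = stream.read(min(65536, remaining))
        if chunk == b"":
            raise RuntimeError("archive stream ended before archive_size bytes")
        chunks.append(chunk)
        remaining -= len(chunk)
    return header, b"".join(chunks)
```

Reading the payload by size rather than until EOF matters here because the JSON response travels back over the same socket.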


@@ -8,11 +8,13 @@ import shlex
import shutil
import signal
import subprocess
import tarfile
import tempfile
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from pathlib import Path, PurePosixPath
from typing import Any, Literal, cast
from pyro_mcp.runtime import (
@@ -34,11 +36,15 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
TASK_LAYOUT_VERSION = 1
TASK_LAYOUT_VERSION = 2
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"
TASK_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
@dataclass
@@ -82,6 +88,7 @@ class TaskRecord:
network: NetworkConfig | None = None
command_count: int = 0
last_command: dict[str, Any] | None = None
workspace_seed: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_instance(
@@ -90,6 +97,7 @@ class TaskRecord:
*,
command_count: int = 0,
last_command: dict[str, Any] | None = None,
workspace_seed: dict[str, Any] | None = None,
) -> TaskRecord:
return cls(
task_id=instance.vm_id,
@@ -108,6 +116,7 @@ class TaskRecord:
network=instance.network,
command_count=command_count,
last_command=last_command,
workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()),
)
def to_instance(self, *, workdir: Path) -> VmInstance:
@@ -148,6 +157,7 @@ class TaskRecord:
"network": _serialize_network(self.network),
"command_count": self.command_count,
"last_command": self.last_command,
"workspace_seed": self.workspace_seed,
}
@classmethod
@@ -169,9 +179,35 @@ class TaskRecord:
network=_deserialize_network(payload.get("network")),
command_count=int(payload.get("command_count", 0)),
last_command=_optional_dict(payload.get("last_command")),
workspace_seed=_task_workspace_seed_dict(payload.get("workspace_seed")),
)
@dataclass(frozen=True)
class PreparedWorkspaceSeed:
"""Prepared host-side seed archive plus metadata."""
mode: WorkspaceSeedMode
source_path: str | None
archive_path: Path | None = None
entry_count: int = 0
bytes_written: int = 0
cleanup_dir: Path | None = None
def to_payload(self) -> dict[str, Any]:
return {
"mode": self.mode,
"source_path": self.source_path,
"destination": TASK_WORKSPACE_GUEST_PATH,
"entry_count": self.entry_count,
"bytes_written": self.bytes_written,
}
def cleanup(self) -> None:
if self.cleanup_dir is not None:
shutil.rmtree(self.cleanup_dir, ignore_errors=True)
@dataclass(frozen=True)
class VmExecResult:
"""Command execution output."""
@@ -216,6 +252,32 @@ def _string_dict(value: object) -> dict[str, str]:
return {str(key): str(item) for key, item in value.items()}
def _empty_workspace_seed_payload() -> dict[str, Any]:
return {
"mode": "empty",
"source_path": None,
"destination": TASK_WORKSPACE_GUEST_PATH,
"entry_count": 0,
"bytes_written": 0,
}
def _task_workspace_seed_dict(value: object) -> dict[str, Any]:
if not isinstance(value, dict):
return _empty_workspace_seed_payload()
payload = _empty_workspace_seed_payload()
payload.update(
{
"mode": str(value.get("mode", payload["mode"])),
"source_path": _optional_str(value.get("source_path")),
"destination": str(value.get("destination", payload["destination"])),
"entry_count": int(value.get("entry_count", payload["entry_count"])),
"bytes_written": int(value.get("bytes_written", payload["bytes_written"])),
}
)
return payload
def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
if network is None:
return None
@@ -300,6 +362,201 @@ def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"
def _is_supported_seed_archive(path: Path) -> bool:
name = path.name.lower()
return name.endswith(".tar") or name.endswith(".tar.gz") or name.endswith(".tgz")
def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPath]:
candidate = destination.strip()
if candidate == "":
raise ValueError("workspace destination must not be empty")
destination_path = PurePosixPath(candidate)
workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH)
if not destination_path.is_absolute():
destination_path = workspace_root / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise ValueError("workspace destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
if normalized == PurePosixPath("/"):
raise ValueError("workspace destination must stay inside /workspace")
if normalized.parts[: len(workspace_root.parts)] != workspace_root.parts:
raise ValueError("workspace destination must stay inside /workspace")
suffix = normalized.relative_to(workspace_root)
return str(normalized), suffix
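The contract of `_normalize_workspace_destination` can be rehearsed standalone (a sketch with a hypothetical helper name; it collapses the same rules into one function and rejects `..` components outright, since `PurePosixPath` never resolves them):

```python
from pathlib import PurePosixPath

WORKSPACE_ROOT = PurePosixPath("/workspace")

def normalize_destination(destination: str) -> str:
    # Relative destinations are anchored under /workspace.
    path = PurePosixPath(destination.strip())
    if not path.is_absolute():
        path = WORKSPACE_ROOT / path
    parts = [part for part in path.parts if part not in {"", "."}]
    # ".." is never resolved by PurePosixPath, so refuse it instead.
    if any(part == ".." for part in parts):
        raise ValueError("destination must stay inside /workspace")
    normalized = PurePosixPath(*parts)  # parts[0] is "/" here
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise ValueError("destination must stay inside /workspace")
    return str(normalized)
```

Relative inputs land under `/workspace`; absolute inputs outside the root are refused.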
def _workspace_host_destination(workspace_dir: Path, destination: str) -> Path:
_, suffix = _normalize_workspace_destination(destination)
if str(suffix) in {"", "."}:
return workspace_dir
return workspace_dir.joinpath(*suffix.parts)
def _normalize_archive_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _validate_archive_symlink_target(member_name: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_name} has an empty target")
link_path = PurePosixPath(target)
if link_path.is_absolute():
raise RuntimeError(f"symlink {member_name} escapes the workspace")
combined = member_name.parent.joinpath(link_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_name} escapes the workspace")
def _inspect_seed_archive(archive_path: Path) -> tuple[int, int]:
entry_count = 0
bytes_written = 0
with tarfile.open(archive_path, "r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_archive_member_name(member.name)
entry_count += 1
if member.isdir():
continue
if member.isfile():
bytes_written += member.size
continue
if member.issym():
_validate_archive_symlink_target(member_name, member.linkname)
continue
if member.islnk():
raise RuntimeError(
f"hard links are not allowed in workspace archives: {member.name}"
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return entry_count, bytes_written
def _write_directory_seed_archive(source_dir: Path, archive_path: Path) -> None:
archive_path.parent.mkdir(parents=True, exist_ok=True)
with tarfile.open(archive_path, "w") as archive:
for child in sorted(source_dir.iterdir(), key=lambda item: item.name):
archive.add(child, arcname=child.name, recursive=True)
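Taken together, the two helpers above implement a simple pack-then-inspect round trip. A minimal standalone rehearsal (temporary paths only, no project imports):

```python
import tarfile
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    source_dir = Path(tmp) / "seed"
    source_dir.mkdir()
    (source_dir / "note.txt").write_text("hello\n", encoding="utf-8")

    # Pack top-level children in sorted order, as _write_directory_seed_archive does.
    archive_path = Path(tmp) / "workspace-seed.tar"
    with tarfile.open(archive_path, "w") as archive:
        for child in sorted(source_dir.iterdir(), key=lambda item: item.name):
            archive.add(child, arcname=child.name, recursive=True)

    # Inspect it the way _inspect_seed_archive does: count members, sum file sizes.
    with tarfile.open(archive_path, "r:*") as archive:
        members = archive.getmembers()
    entry_count = len(members)
    bytes_written = sum(member.size for member in members if member.isfile())
```

For this layout the inspection reports one entry and six bytes, which is exactly what the `workspace_seed` payload surfaces later.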
def _extract_seed_archive_to_host_workspace(
archive_path: Path,
*,
workspace_dir: Path,
destination: str,
) -> dict[str, Any]:
normalized_destination, _ = _normalize_workspace_destination(destination)
destination_root = _workspace_host_destination(workspace_dir, normalized_destination)
destination_root.mkdir(parents=True, exist_ok=True)
entry_count = 0
bytes_written = 0
with tarfile.open(archive_path, "r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_archive_member_name(member.name)
target_path = destination_root.joinpath(*member_name.parts)
entry_count += 1
_ensure_no_symlink_parents(workspace_dir, target_path, member.name)
if member.isdir():
if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
target_path.mkdir(parents=True, exist_ok=True)
continue
if member.isfile():
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.is_symlink() or target_path.is_dir():
raise RuntimeError(f"file conflicts with existing path: {member.name}")
source = archive.extractfile(member)
if source is None:
raise RuntimeError(f"failed to read archive member: {member.name}")
with target_path.open("wb") as handle:
shutil.copyfileobj(source, handle)
bytes_written += member.size
continue
if member.issym():
_validate_archive_symlink_target(member_name, member.linkname)
target_path.parent.mkdir(parents=True, exist_ok=True)
if target_path.exists() and not target_path.is_symlink():
raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
if target_path.is_symlink():
target_path.unlink()
os.symlink(member.linkname, target_path)
continue
if member.islnk():
raise RuntimeError(
f"hard links are not allowed in workspace archives: {member.name}"
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return {
"destination": normalized_destination,
"entry_count": entry_count,
"bytes_written": bytes_written,
}
def _instance_workspace_host_dir(instance: VmInstance) -> Path:
raw_value = instance.metadata.get("workspace_host_dir")
if raw_value is None or raw_value == "":
raise RuntimeError("task workspace host directory is unavailable")
return Path(raw_value)
def _patch_rootfs_guest_agent(rootfs_image: Path, guest_agent_path: Path) -> None:
debugfs_path = shutil.which("debugfs")
if debugfs_path is None:
raise RuntimeError(
"debugfs is required to seed task workspaces on guest-backed runtimes"
)
with tempfile.TemporaryDirectory(prefix="pyro-guest-agent-") as temp_dir:
staged_agent_path = Path(temp_dir) / "pyro_guest_agent.py"
shutil.copy2(guest_agent_path, staged_agent_path)
subprocess.run( # noqa: S603
[debugfs_path, "-w", "-R", f"rm {TASK_GUEST_AGENT_PATH}", str(rootfs_image)],
text=True,
capture_output=True,
check=False,
)
proc = subprocess.run( # noqa: S603
[
debugfs_path,
"-w",
"-R",
f"write {staged_agent_path} {TASK_GUEST_AGENT_PATH}",
str(rootfs_image),
],
text=True,
capture_output=True,
check=False,
)
if proc.returncode != 0:
raise RuntimeError(
"failed to patch guest agent into task rootfs: "
f"{proc.stderr.strip() or proc.stdout.strip()}"
)
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _pid_is_running(pid: int | None) -> bool:
if pid is None:
return False
@@ -337,6 +594,15 @@ class VmBackend:
def delete(self, instance: VmInstance) -> None: # pragma: no cover
raise NotImplementedError
def import_archive( # pragma: no cover
self,
instance: VmInstance,
*,
archive_path: Path,
destination: str,
) -> dict[str, Any]:
raise NotImplementedError
class MockBackend(VmBackend):
"""Host-process backend used for development and testability."""
@@ -365,6 +631,19 @@ class MockBackend(VmBackend):
def delete(self, instance: VmInstance) -> None:
shutil.rmtree(instance.workdir, ignore_errors=True)
def import_archive(
self,
instance: VmInstance,
*,
archive_path: Path,
destination: str,
) -> dict[str, Any]:
return _extract_seed_archive_to_host_workspace(
archive_path,
workspace_dir=_instance_workspace_host_dir(instance),
destination=destination,
)
class FirecrackerBackend(VmBackend): # pragma: no cover
"""Host-gated backend that validates Firecracker prerequisites."""
@@ -562,6 +841,46 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
self._network_manager.cleanup(instance.network)
shutil.rmtree(instance.workdir, ignore_errors=True)
def import_archive(
self,
instance: VmInstance,
*,
archive_path: Path,
destination: str,
) -> dict[str, Any]:
if self._runtime_capabilities.supports_guest_exec:
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
deadline = time.monotonic() + 10
while True:
try:
response = self._guest_exec_client.upload_archive(
guest_cid,
port,
archive_path,
destination=destination,
timeout_seconds=TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS,
uds_path=uds_path,
)
return {
"destination": response.destination,
"entry_count": response.entry_count,
"bytes_written": response.bytes_written,
}
except (OSError, RuntimeError) as exc:
if time.monotonic() >= deadline:
raise RuntimeError(
f"guest archive transport did not become ready: {exc}"
) from exc
time.sleep(0.2)
instance.metadata["execution_mode"] = "host_compat"
return _extract_seed_archive_to_host_workspace(
archive_path,
workspace_dir=_instance_workspace_host_dir(instance),
destination=destination,
)
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments and tasks."""
@@ -814,9 +1133,11 @@ class VmManager:
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
source_path: str | Path | None = None,
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_environment(environment, runtime_paths=self._runtime_paths)
prepared_seed = self._prepare_workspace_seed(source_path)
now = time.time()
task_id = uuid.uuid4().hex[:12]
task_dir = self._task_dir(task_id)
@@ -851,10 +1172,25 @@ class VmManager:
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
self._backend.create(instance)
if (
prepared_seed.archive_path is not None
and self._runtime_capabilities.supports_guest_exec
):
self._ensure_task_guest_seed_support(instance)
with self._lock:
self._start_instance_locked(instance)
self._require_guest_exec_or_opt_in(instance)
if self._runtime_capabilities.supports_guest_exec:
workspace_seed = prepared_seed.to_payload()
if prepared_seed.archive_path is not None:
import_summary = self._backend.import_archive(
instance,
archive_path=prepared_seed.archive_path,
destination=TASK_WORKSPACE_GUEST_PATH,
)
workspace_seed["entry_count"] = int(import_summary["entry_count"])
workspace_seed["bytes_written"] = int(import_summary["bytes_written"])
workspace_seed["destination"] = str(import_summary["destination"])
elif self._runtime_capabilities.supports_guest_exec:
self._backend.exec(
instance,
f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
@@ -862,7 +1198,7 @@ class VmManager:
)
else:
instance.metadata["execution_mode"] = "host_compat"
task = TaskRecord.from_instance(instance)
task = TaskRecord.from_instance(instance, workspace_seed=workspace_seed)
self._save_task_locked(task)
return self._serialize_task(task)
except Exception:
@@ -879,6 +1215,8 @@ class VmManager:
pass
shutil.rmtree(task_dir, ignore_errors=True)
raise
finally:
prepared_seed.cleanup()
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
if timeout_seconds <= 0:
@@ -999,6 +1337,7 @@ class VmManager:
"tap_name": task.network.tap_name if task.network is not None else None,
"execution_mode": task.metadata.get("execution_mode", "pending"),
"workspace_path": TASK_WORKSPACE_GUEST_PATH,
"workspace_seed": _task_workspace_seed_dict(task.workspace_seed),
"command_count": task.command_count,
"last_command": task.last_command,
"metadata": task.metadata,
@@ -1094,6 +1433,53 @@ class VmManager:
execution_mode = instance.metadata.get("execution_mode", "unknown")
return exec_result, execution_mode
def _prepare_workspace_seed(self, source_path: str | Path | None) -> PreparedWorkspaceSeed:
if source_path is None:
return PreparedWorkspaceSeed(mode="empty", source_path=None)
resolved_source_path = Path(source_path).expanduser().resolve()
if not resolved_source_path.exists():
raise ValueError(f"source_path {resolved_source_path} does not exist")
if resolved_source_path.is_dir():
cleanup_dir = Path(tempfile.mkdtemp(prefix="pyro-task-seed-"))
archive_path = cleanup_dir / "workspace-seed.tar"
try:
_write_directory_seed_archive(resolved_source_path, archive_path)
entry_count, bytes_written = _inspect_seed_archive(archive_path)
except Exception:
shutil.rmtree(cleanup_dir, ignore_errors=True)
raise
return PreparedWorkspaceSeed(
mode="directory",
source_path=str(resolved_source_path),
archive_path=archive_path,
entry_count=entry_count,
bytes_written=bytes_written,
cleanup_dir=cleanup_dir,
)
if (
not resolved_source_path.is_file()
or not _is_supported_seed_archive(resolved_source_path)
):
raise ValueError(
"source_path must be a directory or a .tar/.tar.gz/.tgz archive"
)
entry_count, bytes_written = _inspect_seed_archive(resolved_source_path)
return PreparedWorkspaceSeed(
mode="tar_archive",
source_path=str(resolved_source_path),
archive_path=resolved_source_path,
entry_count=entry_count,
bytes_written=bytes_written,
)
def _ensure_task_guest_seed_support(self, instance: VmInstance) -> None:
if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None:
raise RuntimeError("runtime bundle does not provide a guest agent for task seeding")
rootfs_image = instance.metadata.get("rootfs_image")
if rootfs_image is None or rootfs_image == "":
raise RuntimeError("task rootfs image is unavailable for guest workspace seeding")
_patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path)
def _task_dir(self, task_id: str) -> Path:
return self._tasks_dir / task_id


@@ -114,14 +114,23 @@ def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None:
)
)
created = pyro.create_task(environment="debian:12-base", allow_host_compat=True)
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
created = pyro.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=source_dir,
)
task_id = str(created["task_id"])
executed = pyro.exec_task(task_id, command="printf 'ok\\n'")
executed = pyro.exec_task(task_id, command="cat note.txt")
status = pyro.status_task(task_id)
logs = pyro.logs_task(task_id)
deleted = pyro.delete_task(task_id)
assert executed["stdout"] == "ok\n"
assert created["workspace_seed"]["mode"] == "directory"
assert status["command_count"] == 1
assert logs["count"] == 1
assert deleted["deleted"] is True


@@ -60,9 +60,13 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None:
assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help
task_help = _subparser_choice(parser, "task").format_help()
assert "pyro task create debian:12" in task_help
assert "pyro task create debian:12 --source-path ./repo" in task_help
assert "pyro task exec TASK_ID" in task_help
task_create_help = _subparser_choice(_subparser_choice(parser, "task"), "create").format_help()
assert "--source-path" in task_create_help
assert "seed into `/workspace`" in task_create_help
task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help()
assert "persistent `/workspace`" in task_exec_help
assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help
@@ -326,12 +330,20 @@ def test_cli_requires_run_command() -> None:
cli._require_command([])
def test_cli_requires_command_preserves_shell_argument_boundaries() -> None:
command = cli._require_command(
["--", "sh", "-lc", 'printf "hello from task\\n" > note.txt']
)
assert command == 'sh -lc \'printf "hello from task\\n" > note.txt\''
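The boundary preservation asserted above is exactly what `shlex.join` guarantees over `shlex.split`; a standalone comparison against naive whitespace joining:

```python
import shlex

argv = ["sh", "-lc", 'printf "hello from task\\n" > note.txt']

# shlex.join quotes each argument, so splitting round-trips exactly.
joined = shlex.join(argv)
assert shlex.split(joined) == argv

# A plain " ".join loses the boundaries: the quoted script and the
# redirection split into separate tokens.
naive = " ".join(argv)
assert shlex.split(naive) != argv
```

This is why the documented `sh -lc` task examples survive the post-`--` reconstruction intact.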
def test_cli_task_create_prints_json(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def create_task(self, **kwargs: Any) -> dict[str, Any]:
assert kwargs["environment"] == "debian:12"
assert kwargs["source_path"] == "./repo"
return {"task_id": "task-123", "state": "started"}
class StubParser:
@@ -345,6 +357,7 @@ def test_cli_task_create_prints_json(
ttl_seconds=600,
network=False,
allow_host_compat=False,
source_path="./repo",
json=True,
)
@@ -366,6 +379,13 @@ def test_cli_task_create_prints_human(
"environment": "debian:12",
"state": "started",
"workspace_path": "/workspace",
"workspace_seed": {
"mode": "directory",
"source_path": "/tmp/repo",
"destination": "/workspace",
"entry_count": 1,
"bytes_written": 6,
},
"execution_mode": "guest_vsock",
"vcpu_count": 1,
"mem_mib": 1024,
@@ -384,6 +404,7 @@ def test_cli_task_create_prints_human(
ttl_seconds=600,
network=False,
allow_host_compat=False,
source_path="/tmp/repo",
json=False,
)
@@ -393,6 +414,7 @@ def test_cli_task_create_prints_human(
output = capsys.readouterr().out
assert "Task: task-123" in output
assert "Workspace: /workspace" in output
assert "Workspace seed: directory from /tmp/repo" in output
def test_cli_task_exec_prints_human_output(


@@ -17,6 +17,7 @@ from pyro_mcp.contract import (
PUBLIC_CLI_DEMO_SUBCOMMANDS,
PUBLIC_CLI_ENV_SUBCOMMANDS,
PUBLIC_CLI_RUN_FLAGS,
PUBLIC_CLI_TASK_CREATE_FLAGS,
PUBLIC_CLI_TASK_SUBCOMMANDS,
PUBLIC_MCP_TOOLS,
PUBLIC_SDK_METHODS,
@@ -67,6 +68,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
task_help_text = _subparser_choice(parser, "task").format_help()
for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS:
assert subcommand_name in task_help_text
task_create_help_text = _subparser_choice(
_subparser_choice(parser, "task"), "create"
).format_help()
for flag in PUBLIC_CLI_TASK_CREATE_FLAGS:
assert flag in task_create_help_text
demo_help_text = _subparser_choice(parser, "demo").format_help()
for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:


@@ -171,6 +171,9 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
@@ -188,6 +191,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
{
"environment": "debian:12-base",
"allow_host_compat": True,
"source_path": str(source_dir),
},
)
)
@@ -197,7 +201,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
"task_exec",
{
"task_id": task_id,
"command": "printf 'ok\\n'",
"command": "cat note.txt",
},
)
)
@@ -207,6 +211,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
created, executed, logs, deleted = asyncio.run(_run())
assert created["state"] == "started"
assert created["workspace_seed"]["mode"] == "directory"
assert executed["stdout"] == "ok\n"
assert logs["count"] == 1
assert deleted["deleted"] is True


@@ -1,6 +1,10 @@
from __future__ import annotations
import io
import json
import socket
import tarfile
from pathlib import Path
import pytest
@@ -62,6 +66,45 @@ def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
assert stub.closed is True
def test_vsock_exec_client_upload_archive_round_trip(
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
archive_path = tmp_path / "seed.tgz"
with tarfile.open(archive_path, "w:gz") as archive:
payload = b"hello\n"
info = tarfile.TarInfo(name="note.txt")
info.size = len(payload)
archive.addfile(info, io.BytesIO(payload))
stub = StubSocket(
b'{"destination":"/workspace","entry_count":1,"bytes_written":6}'
)
def socket_factory(family: int, sock_type: int) -> StubSocket:
assert family == socket.AF_VSOCK
assert sock_type == socket.SOCK_STREAM
return stub
client = VsockExecClient(socket_factory=socket_factory)
response = client.upload_archive(
1234,
5005,
archive_path,
destination="/workspace",
timeout_seconds=60,
)
request_payload, archive_payload = stub.sent.split(b"\n", 1)
request = json.loads(request_payload.decode("utf-8"))
assert request["action"] == "extract_archive"
assert request["destination"] == "/workspace"
assert int(request["archive_size"]) == archive_path.stat().st_size
assert archive_payload == archive_path.read_bytes()
assert response.entry_count == 1
assert response.bytes_written == 6
assert stub.closed is True
def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
stub = StubSocket(b"[]")


@@ -1,7 +1,9 @@
from __future__ import annotations
import io
import json
import subprocess
import tarfile
import time
from pathlib import Path
from typing import Any
@@ -306,6 +308,140 @@ def test_task_lifecycle_and_logs(tmp_path: Path) -> None:
manager.status_task(task_id)
def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None:
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
created = manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=source_dir,
)
task_id = str(created["task_id"])
workspace_seed = created["workspace_seed"]
assert workspace_seed["mode"] == "directory"
assert workspace_seed["source_path"] == str(source_dir.resolve())
executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
assert executed["stdout"] == "hello\n"
status = manager.status_task(task_id)
assert status["workspace_seed"]["mode"] == "directory"
assert status["workspace_seed"]["source_path"] == str(source_dir.resolve())
def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None:
archive_path = tmp_path / "seed.tgz"
nested_dir = tmp_path / "src"
nested_dir.mkdir()
(nested_dir / "note.txt").write_text("archive\n", encoding="utf-8")
with tarfile.open(archive_path, "w:gz") as archive:
archive.add(nested_dir / "note.txt", arcname="note.txt")
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
created = manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=archive_path,
)
task_id = str(created["task_id"])
assert created["workspace_seed"]["mode"] == "tar_archive"
executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
assert executed["stdout"] == "archive\n"
def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None:
archive_path = tmp_path / "bad.tgz"
with tarfile.open(archive_path, "w:gz") as archive:
payload = b"bad\n"
info = tarfile.TarInfo(name="../escape.txt")
info.size = len(payload)
archive.addfile(info, io.BytesIO(payload))
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
with pytest.raises(RuntimeError, match="unsafe archive member path"):
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=archive_path,
)
assert list((tmp_path / "vms" / "tasks").iterdir()) == []
def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None:
archive_path = tmp_path / "bad-symlink.tgz"
with tarfile.open(archive_path, "w:gz") as archive:
symlink_info = tarfile.TarInfo(name="linked")
symlink_info.type = tarfile.SYMTYPE
symlink_info.linkname = "outside"
archive.addfile(symlink_info)
payload = b"bad\n"
file_info = tarfile.TarInfo(name="linked/note.txt")
file_info.size = len(payload)
archive.addfile(file_info, io.BytesIO(payload))
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
with pytest.raises(RuntimeError, match="traverse through a symlinked path"):
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=archive_path,
)
def test_task_create_cleans_up_on_seed_failure(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
manager = VmManager(
backend_name="mock",
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]:
del args, kwargs
raise RuntimeError("seed import failed")
monkeypatch.setattr(manager._backend, "import_archive", _boom) # noqa: SLF001
with pytest.raises(RuntimeError, match="seed import failed"):
manager.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=source_dir,
)
assert list((tmp_path / "vms" / "tasks").iterdir()) == []
def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None:
base_dir = tmp_path / "vms"
manager = VmManager(

uv.lock generated

@@ -706,7 +706,7 @@ crypto = [
[[package]]
name = "pyro-mcp"
version = "2.1.0"
version = "2.2.0"
source = { editable = "." }
dependencies = [
{ name = "mcp" },