Add seeded task workspace creation
Persistent tasks previously started with an empty workspace, which blocked the
first useful host-to-task workflow on the task roadmap. This change lets task
creation start from a host directory or tar archive without changing the
one-shot VM surfaces.

- Expose `source_path` on task create across the CLI, SDK, and MCP.
- Add safe archive upload and extraction support for the guest and host-compat
  backends.
- Persist `workspace_seed` metadata.
- Patch the per-task rootfs with the bundled guest agent before boot so seeded
  guest tasks work without republishing environments.
- Switch post-`--` command reconstruction to `shlex.join()` so documented
  `sh -lc` task examples preserve argument boundaries.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
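The archive upload described above frames each guest request as one newline-terminated JSON header followed by the raw tar bytes. A client-side sketch of that framing — `frame_extract_request` is an illustrative helper, not part of the package; only the field names (`action`, `archive_size`, `destination`) come from the guest agent's request handling in this commit:

```python
import io
import json
import tarfile


def frame_extract_request(tar_bytes: bytes, destination: str = "/workspace") -> bytes:
    # One newline-terminated JSON header, then exactly archive_size raw bytes.
    header = json.dumps(
        {
            "action": "extract_archive",
            "archive_size": len(tar_bytes),
            "destination": destination,
        }
    )
    return header.encode("utf-8") + b"\n" + tar_bytes


# Build a tiny in-memory tar archive to frame.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as archive:
    data = b"hello from task\n"
    info = tarfile.TarInfo(name="note.txt")
    info.size = len(data)
    archive.addfile(info, io.BytesIO(data))
payload = buffer.getvalue()

# The JSON header contains no newline, so the first b"\n" splits header from body.
request = frame_extract_request(payload)
header_line, _, body = request.partition(b"\n")
print(json.loads(header_line)["archive_size"] == len(body))  # True
```

Length-prefixing the body this way is what lets the agent call `_read_exact(conn, archive_size)` instead of scanning the tar bytes for a terminator.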
parent 58df176148
commit aa886b346e
25 changed files with 1076 additions and 75 deletions
@@ -2,6 +2,15 @@
 All notable user-visible changes to `pyro-mcp` are documented here.
 
+## 2.2.0
+
+- Added seeded task creation across the CLI, Python SDK, and MCP server with an optional
+  `source_path` for host directories and `.tar` / `.tar.gz` / `.tgz` archives.
+- Seeded task workspaces now persist `workspace_seed` metadata so later status calls report how
+  `/workspace` was initialized.
+- Reused the task workspace model from `2.1.0` while adding the first explicit host-to-task
+  content import path for repeated command workflows.
+
 ## 2.1.0
 
 - Added the first persistent task workspace alpha across the CLI, Python SDK, and MCP server.
README.md

@@ -18,7 +18,7 @@ It exposes the same runtime in three public forms:
 - First run transcript: [docs/first-run.md](docs/first-run.md)
 - Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif)
 - PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/)
-- What's new in 2.1.0: [CHANGELOG.md#210](CHANGELOG.md#210)
+- What's new in 2.2.0: [CHANGELOG.md#220](CHANGELOG.md#220)
 - Host requirements: [docs/host-requirements.md](docs/host-requirements.md)
 - Integration targets: [docs/integrations.md](docs/integrations.md)
 - Public contract: [docs/public-contract.md](docs/public-contract.md)
@@ -55,7 +55,7 @@ What success looks like:
 ```bash
 Platform: linux-x86_64
 Runtime: PASS
-Catalog version: 2.1.0
+Catalog version: 2.2.0
 ...
 [pull] phase=install environment=debian:12
 [pull] phase=ready environment=debian:12
@@ -74,7 +74,7 @@ access to `registry-1.docker.io`, and needs local cache space for the guest imag
 After the quickstart works:
 
 - prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo`
-- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12`
+- create a persistent workspace with `uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo`
 - move to Python or MCP via [docs/integrations.md](docs/integrations.md)
 
 ## Supported Hosts
@@ -128,7 +128,7 @@ uvx --from pyro-mcp pyro env list
 Expected output:
 
 ```bash
-Catalog version: 2.1.0
+Catalog version: 2.2.0
 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
 debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
 debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -198,7 +198,7 @@ Use `pyro run` for one-shot commands. Use `pyro task ...` when you need repeated
 workspace without recreating the sandbox every time.
 
 ```bash
-pyro task create debian:12
+pyro task create debian:12 --source-path ./repo
 pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
 pyro task exec TASK_ID -- cat note.txt
 pyro task logs TASK_ID
@@ -206,7 +206,9 @@ pyro task delete TASK_ID
 ```
 
 Task workspaces start in `/workspace` and keep command history until you delete them. For machine
-consumption, add `--json` and read the returned `task_id`.
+consumption, add `--json` and read the returned `task_id`. Use `--source-path` when you want the
+task to start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive instead of an
+empty workspace.
 
 ## Public Interfaces
 
@@ -348,7 +350,7 @@ For repeated commands in one workspace:
 from pyro_mcp import Pyro
 
 pyro = Pyro()
-task = pyro.create_task(environment="debian:12")
+task = pyro.create_task(environment="debian:12", source_path="./repo")
 task_id = task["task_id"]
 try:
     pyro.exec_task(task_id, command="printf 'hello from task\\n' > note.txt")
@@ -378,7 +380,7 @@ Advanced lifecycle tools:
 
 Persistent workspace tools:
 
-- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false)`
+- `task_create(environment, vcpu_count=1, mem_mib=1024, ttl_seconds=600, network=false, allow_host_compat=false, source_path=null)`
 - `task_exec(task_id, command, timeout_seconds=30)`
 - `task_status(task_id)`
 - `task_logs(task_id)`
@@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes
 
 ```bash
 $ uvx --from pyro-mcp pyro env list
-Catalog version: 2.1.0
+Catalog version: 2.2.0
 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
 debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
 debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -70,7 +70,7 @@ deterministic structured result.
 
 ```bash
 $ uvx --from pyro-mcp pyro demo
-$ uvx --from pyro-mcp pyro task create debian:12
+$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo
 $ uvx --from pyro-mcp pyro mcp serve
 ```
 
@@ -79,11 +79,12 @@ $ uvx --from pyro-mcp pyro mcp serve
 When you need repeated commands in one sandbox, switch to `pyro task ...`:
 
 ```bash
-$ uvx --from pyro-mcp pyro task create debian:12
+$ uvx --from pyro-mcp pyro task create debian:12 --source-path ./repo
 Task: ...
 Environment: debian:12
 State: started
 Workspace: /workspace
+Workspace seed: directory from ...
 Execution mode: guest_vsock
 Resources: 1 vCPU / 1024 MiB
 Command count: 0
@@ -96,6 +97,9 @@ hello from task
 [task-exec] task_id=... sequence=2 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=...
 ```
 
+Use `--source-path` when the task should start from a host directory or a local
+`.tar` / `.tar.gz` / `.tgz` archive instead of an empty `/workspace`.
+
 Example output:
 
 ```json
@@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list
 Expected output:
 
 ```bash
-Catalog version: 2.1.0
+Catalog version: 2.2.0
 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows.
 debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling.
 debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled.
@@ -174,7 +174,7 @@ pyro run debian:12 -- git --version
 
 After the CLI path works, you can move on to:
 
-- persistent workspaces: `pyro task create debian:12`
+- persistent workspaces: `pyro task create debian:12 --source-path ./repo`
 - MCP: `pyro mcp serve`
 - Python SDK: `from pyro_mcp import Pyro`
 - Demos: `pyro demo` or `pyro demo --network`
@@ -184,7 +184,7 @@ After the CLI path works, you can move on to:
 Use `pyro task ...` when you need repeated commands in one sandbox instead of one-shot `pyro run`.
 
 ```bash
-pyro task create debian:12
+pyro task create debian:12 --source-path ./repo
 pyro task exec TASK_ID -- sh -lc 'printf "hello from task\n" > note.txt'
 pyro task exec TASK_ID -- cat note.txt
 pyro task logs TASK_ID
@@ -192,7 +192,8 @@ pyro task delete TASK_ID
 ```
 
 Task commands default to the persistent `/workspace` directory inside the guest. If you need the
-task identifier programmatically, use `--json` and read the `task_id` field.
+task identifier programmatically, use `--json` and read the `task_id` field. Use `--source-path`
+when the task should start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive.
 
 ## Contributor Clone
 
@@ -30,7 +30,7 @@ Best when:
 Recommended surface:
 
 - `vm_run`
-- `task_create` + `task_exec` when the agent needs persistent workspace state
+- `task_create(source_path=...)` + `task_exec` when the agent needs persistent workspace state
 
 Canonical example:
 
@@ -65,14 +65,15 @@ Best when:
 Recommended default:
 
 - `Pyro.run_in_vm(...)`
-- `Pyro.create_task(...)` + `Pyro.exec_task(...)` when repeated workspace commands are required
+- `Pyro.create_task(source_path=...)` + `Pyro.exec_task(...)` when repeated workspace commands are required
 
 Lifecycle note:
 
 - `Pyro.exec_vm(...)` runs one command and auto-cleans the VM afterward
 - use `create_vm(...)` + `start_vm(...)` only when you need pre-exec inspection or status before
   that final exec
-- use `create_task(...)` when the agent needs repeated commands in one persistent `/workspace`
+- use `create_task(source_path=...)` when the agent needs repeated commands in one persistent
+  `/workspace` that starts from host content
 
 Examples:
 
@@ -46,8 +46,12 @@ Behavioral guarantees:
 - `pyro run`, `pyro env list`, `pyro env pull`, `pyro env inspect`, `pyro env prune`, and `pyro doctor` are human-readable by default and return structured JSON with `--json`.
 - `pyro demo ollama` prints log lines plus a final summary line.
 - `pyro task create` auto-starts a persistent workspace.
+- `pyro task create --source-path PATH` seeds `/workspace` from a host directory or a local
+  `.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
 - `pyro task exec` runs in the persistent `/workspace` for that task and does not auto-clean.
 - `pyro task logs` returns persisted command history for that task until `pyro task delete`.
+- Task create/status results expose `workspace_seed` metadata describing how `/workspace` was
+  initialized.
 
 ## Python SDK Contract
 
@@ -106,6 +110,8 @@ Behavioral defaults:
 - `Pyro.create_task(...)` defaults to `vcpu_count=1` and `mem_mib=1024`.
 - `allow_host_compat` defaults to `False` on `create_vm(...)` and `run_in_vm(...)`.
 - `allow_host_compat` defaults to `False` on `create_task(...)`.
+- `Pyro.create_task(..., source_path=...)` seeds `/workspace` from a host directory or a local
+  `.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
 - `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes.
 - `Pyro.exec_task(...)` runs one command in the persistent task workspace and leaves the task alive.
 
@@ -141,6 +147,8 @@ Behavioral defaults:
 - `task_create` defaults to `vcpu_count=1` and `mem_mib=1024`.
 - `vm_run` and `vm_create` expose `allow_host_compat`, which defaults to `false`.
 - `task_create` exposes `allow_host_compat`, which defaults to `false`.
+- `task_create` accepts optional `source_path` and seeds `/workspace` from a host directory or a
+  local `.tar` / `.tar.gz` / `.tgz` archive before the task is returned.
 - `vm_exec` runs one command and auto-cleans that VM after the exec completes.
 - `task_exec` runs one command in a persistent `/workspace` and leaves the task alive.
 
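The `workspace_seed` guarantee above can be consumed defensively on the client side. A hedged sketch that mirrors the CLI's human summary ("<mode> from <source_path>" or just "<mode>"); the payload values are invented, only the field names come from the contract:

```python
from typing import Any


def describe_workspace_seed(status: dict[str, Any]) -> str:
    # Fall back to "empty" when no seed metadata is present, matching the
    # documented default for tasks created without --source-path.
    seed = status.get("workspace_seed")
    if not isinstance(seed, dict):
        return "empty"
    mode = str(seed.get("mode", "empty"))
    source_path = seed.get("source_path")
    if isinstance(source_path, str) and source_path:
        return f"{mode} from {source_path}"
    return mode


# Invented example payloads shaped like a task create/status result.
seeded = {"workspace_seed": {"mode": "directory", "source_path": "./repo"}}
empty = {"workspace_seed": {"mode": "empty"}}
print(describe_workspace_seed(seeded))  # directory from ./repo
print(describe_workspace_seed(empty))   # empty
```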
pyproject.toml

@@ -1,6 +1,6 @@
 [project]
 name = "pyro-mcp"
-version = "2.1.0"
+version = "2.2.0"
 description = "Curated Linux environments for ephemeral Firecracker-backed VM execution."
 readme = "README.md"
 license = { file = "LICENSE" }
@@ -34,7 +34,7 @@ Kernel build note:
 Current status:
 
 1. Firecracker and Jailer are materialized from pinned official release artifacts.
 2. The kernel and rootfs images are built from pinned inputs into `build/runtime_sources/`.
-3. The guest agent is installed into each rootfs and used for vsock exec.
+3. The guest agent is installed into each rootfs and used for vsock exec plus workspace archive imports.
 4. `runtime.lock.json` now advertises real guest capabilities.
 
 Safety rule:
@@ -1,26 +1,31 @@
 #!/usr/bin/env python3
-"""Minimal guest-side exec agent for pyro runtime bundles."""
+"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
 
 from __future__ import annotations
 
+import io
 import json
+import os
 import socket
 import subprocess
+import tarfile
 import time
+from pathlib import Path, PurePosixPath
 from typing import Any
 
 PORT = 5005
 BUFFER_SIZE = 65536
+WORKSPACE_ROOT = PurePosixPath("/workspace")
 
 
 def _read_request(conn: socket.socket) -> dict[str, Any]:
     chunks: list[bytes] = []
     while True:
-        data = conn.recv(BUFFER_SIZE)
+        data = conn.recv(1)
         if data == b"":
             break
         chunks.append(data)
-        if b"\n" in data:
+        if data == b"\n":
             break
     payload = json.loads(b"".join(chunks).decode("utf-8").strip())
     if not isinstance(payload, dict):
@@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]:
     return payload
 
 
+def _read_exact(conn: socket.socket, size: int) -> bytes:
+    remaining = size
+    chunks: list[bytes] = []
+    while remaining > 0:
+        data = conn.recv(min(BUFFER_SIZE, remaining))
+        if data == b"":
+            raise RuntimeError("unexpected EOF while reading archive payload")
+        chunks.append(data)
+        remaining -= len(data)
+    return b"".join(chunks)
+
+
+def _normalize_member_name(name: str) -> PurePosixPath:
+    candidate = name.strip()
+    if candidate == "":
+        raise RuntimeError("archive member path is empty")
+    member_path = PurePosixPath(candidate)
+    if member_path.is_absolute():
+        raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
+    parts = [part for part in member_path.parts if part not in {"", "."}]
+    if any(part == ".." for part in parts):
+        raise RuntimeError(f"unsafe archive member path: {name}")
+    normalized = PurePosixPath(*parts)
+    if str(normalized) in {"", "."}:
+        raise RuntimeError(f"unsafe archive member path: {name}")
+    return normalized
+
+
+def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
+    candidate = destination.strip()
+    if candidate == "":
+        raise RuntimeError("destination must not be empty")
+    destination_path = PurePosixPath(candidate)
+    if not destination_path.is_absolute():
+        destination_path = WORKSPACE_ROOT / destination_path
+    parts = [part for part in destination_path.parts if part not in {"", "."}]
+    normalized = PurePosixPath("/") / PurePosixPath(*parts)
+    if normalized == PurePosixPath("/"):
+        raise RuntimeError("destination must stay inside /workspace")
+    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
+        raise RuntimeError("destination must stay inside /workspace")
+    suffix = normalized.relative_to(WORKSPACE_ROOT)
+    host_path = Path("/workspace")
+    if str(suffix) not in {"", "."}:
+        host_path = host_path / str(suffix)
+    return normalized, host_path
+
+
+def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
+    target = link_target.strip()
+    if target == "":
+        raise RuntimeError(f"symlink {member_path} has an empty target")
+    target_path = PurePosixPath(target)
+    if target_path.is_absolute():
+        raise RuntimeError(f"symlink {member_path} escapes the workspace")
+    combined = member_path.parent.joinpath(target_path)
+    parts = [part for part in combined.parts if part not in {"", "."}]
+    if any(part == ".." for part in parts):
+        raise RuntimeError(f"symlink {member_path} escapes the workspace")
+
+
+def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
+    relative_path = target_path.relative_to(root)
+    current = root
+    for part in relative_path.parts[:-1]:
+        current = current / part
+        if current.is_symlink():
+            raise RuntimeError(
+                f"archive member would traverse through a symlinked path: {member_name}"
+            )
+
+
+def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
+    _, destination_root = _normalize_destination(destination)
+    destination_root.mkdir(parents=True, exist_ok=True)
+    bytes_written = 0
+    entry_count = 0
+    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
+        for member in archive.getmembers():
+            member_name = _normalize_member_name(member.name)
+            target_path = destination_root / str(member_name)
+            entry_count += 1
+            _ensure_no_symlink_parents(destination_root, target_path, member.name)
+            if member.isdir():
+                if target_path.exists() and not target_path.is_dir():
+                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
+                target_path.mkdir(parents=True, exist_ok=True)
+                continue
+            if member.isfile():
+                target_path.parent.mkdir(parents=True, exist_ok=True)
+                if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()):
+                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
+                source = archive.extractfile(member)
+                if source is None:
+                    raise RuntimeError(f"failed to read archive member: {member.name}")
+                with target_path.open("wb") as handle:
+                    while True:
+                        chunk = source.read(BUFFER_SIZE)
+                        if chunk == b"":
+                            break
+                        handle.write(chunk)
+                bytes_written += member.size
+                continue
+            if member.issym():
+                _validate_symlink_target(member_name, member.linkname)
+                target_path.parent.mkdir(parents=True, exist_ok=True)
+                if target_path.exists() and not target_path.is_symlink():
+                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
+                if target_path.is_symlink():
+                    target_path.unlink()
+                os.symlink(member.linkname, target_path)
+                continue
+            if member.islnk():
+                raise RuntimeError(
+                    f"hard links are not allowed in workspace archives: {member.name}"
+                )
+            raise RuntimeError(f"unsupported archive member type: {member.name}")
+    return {
+        "destination": destination,
+        "entry_count": entry_count,
+        "bytes_written": bytes_written,
+    }
+
+
 def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
     started = time.monotonic()
     try:
@@ -64,9 +193,18 @@ def main() -> None:
         conn, _ = server.accept()
         with conn:
             request = _read_request(conn)
-            command = str(request.get("command", ""))
-            timeout_seconds = int(request.get("timeout_seconds", 30))
-            response = _run_command(command, timeout_seconds)
+            action = str(request.get("action", "exec"))
+            if action == "extract_archive":
+                archive_size = int(request.get("archive_size", 0))
+                if archive_size < 0:
+                    raise RuntimeError("archive_size must not be negative")
+                destination = str(request.get("destination", "/workspace"))
+                payload = _read_exact(conn, archive_size)
+                response = _extract_archive(payload, destination)
+            else:
+                command = str(request.get("command", ""))
+                timeout_seconds = int(request.get("timeout_seconds", 30))
+                response = _run_command(command, timeout_seconds)
             conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
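The member-path rules enforced by `_normalize_member_name` in the agent diff above can be exercised on their own. This is a standalone re-implementation sketch of the same rejection logic (not the agent's code): reject empty names, absolute paths, and any `..` component before extraction:

```python
from pathlib import PurePosixPath


def is_safe_member(name: str) -> bool:
    """Mirror the agent's member-name rules: non-empty, relative, no `..` parts."""
    candidate = name.strip()
    if candidate == "":
        return False
    member = PurePosixPath(candidate)
    if member.is_absolute():
        return False
    # Drop empty and "." components; what remains must be non-empty and clean.
    parts = [part for part in member.parts if part not in {"", "."}]
    if not parts or any(part == ".." for part in parts):
        return False
    return True


print(is_safe_member("src/app.py"))   # True
print(is_safe_member("/etc/passwd"))  # False: absolute path
print(is_safe_member("../escape"))    # False: parent traversal
```

Rejecting traversal at the name level, before any filesystem call, is what makes the later symlink-parent and conflict checks sufficient rather than a last line of defense.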
runtime.lock.json

@@ -5,7 +5,7 @@
     "firecracker": "1.12.1",
     "jailer": "1.12.1",
     "kernel": "5.10.210",
-    "guest_agent": "0.1.0-dev",
+    "guest_agent": "0.2.0-dev",
     "base_distro": "debian-bookworm-20250210"
   },
   "capabilities": {
@@ -86,6 +86,7 @@ class Pyro:
         ttl_seconds: int = DEFAULT_TTL_SECONDS,
         network: bool = False,
         allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
+        source_path: str | Path | None = None,
     ) -> dict[str, Any]:
         return self._manager.create_task(
             environment=environment,
@@ -94,6 +95,7 @@ class Pyro:
             ttl_seconds=ttl_seconds,
             network=network,
             allow_host_compat=allow_host_compat,
+            source_path=source_path,
         )
 
     def exec_task(
@@ -245,6 +247,7 @@ class Pyro:
         ttl_seconds: int = DEFAULT_TTL_SECONDS,
         network: bool = False,
         allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
+        source_path: str | None = None,
     ) -> dict[str, Any]:
         """Create and start a persistent task workspace."""
         return self.create_task(
@@ -254,6 +257,7 @@ class Pyro:
             ttl_seconds=ttl_seconds,
             network=network,
             allow_host_compat=allow_host_compat,
+            source_path=source_path,
         )
 
     @server.tool()
@@ -4,6 +4,7 @@ from __future__ import annotations
 
 import argparse
 import json
+import shlex
 import sys
 from textwrap import dedent
 from typing import Any
@@ -155,6 +156,14 @@ def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None:
     print(f"Environment: {str(payload.get('environment', 'unknown'))}")
     print(f"State: {str(payload.get('state', 'unknown'))}")
     print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
+    workspace_seed = payload.get("workspace_seed")
+    if isinstance(workspace_seed, dict):
+        mode = str(workspace_seed.get("mode", "empty"))
+        source_path = workspace_seed.get("source_path")
+        if isinstance(source_path, str) and source_path != "":
+            print(f"Workspace seed: {mode} from {source_path}")
+        else:
+            print(f"Workspace seed: {mode}")
     print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}")
     print(
         f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / "
@@ -446,7 +455,7 @@ def _build_parser() -> argparse.ArgumentParser:
         epilog=dedent(
             """
             Examples:
-              pyro task create debian:12
+              pyro task create debian:12 --source-path ./repo
               pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
               pyro task logs TASK_ID
             """
@@ -458,7 +467,13 @@ def _build_parser() -> argparse.ArgumentParser:
         "create",
         help="Create and start a persistent task workspace.",
         description="Create a task workspace that stays alive across repeated exec calls.",
-        epilog="Example:\n  pyro task create debian:12",
+        epilog=dedent(
+            """
+            Examples:
+              pyro task create debian:12
+              pyro task create debian:12 --source-path ./repo
+            """
+        ),
         formatter_class=_HelpFormatter,
     )
     task_create_parser.add_argument(
@@ -497,6 +512,13 @@ def _build_parser() -> argparse.ArgumentParser:
             "is unavailable."
         ),
     )
+    task_create_parser.add_argument(
+        "--source-path",
+        help=(
+            "Optional host directory or .tar/.tar.gz/.tgz archive to seed into `/workspace` "
+            "before the task is returned."
+        ),
+    )
     task_create_parser.add_argument(
         "--json",
         action="store_true",
@@ -663,7 +685,7 @@ def _require_command(command_args: list[str]) -> str:
         command_args = command_args[1:]
     if not command_args:
         raise ValueError("command is required after `--`")
-    return " ".join(command_args)
+    return shlex.join(command_args)
 
 
 def main() -> None:
@@ -764,6 +786,7 @@ def main() -> None:
             ttl_seconds=args.ttl_seconds,
            network=args.network,
            allow_host_compat=args.allow_host_compat,
+           source_path=args.source_path,
        )
        if bool(args.json):
            _print_json(payload)
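The `shlex.join()` switch in `_require_command` matters because a plain `" ".join()` destroys quoting in the documented `sh -lc` examples: the quoted script collapses into separate words when the guest shell re-parses the line. A minimal illustration:

```python
import shlex

# The argv a user passes after `--` in `pyro task exec TASK_ID -- ...`.
args = ["sh", "-lc", 'printf "hello" > note.txt']

naive = " ".join(args)   # quoting lost: sh -lc printf "hello" > note.txt
safe = shlex.join(args)  # quoting kept: sh -lc 'printf "hello" > note.txt'

# Only the shlex.join form round-trips the original argument boundaries.
print(shlex.split(naive) == args)  # False
print(shlex.split(safe) == args)   # True
```

The round-trip property (`shlex.split(shlex.join(args)) == args`) is exactly what the exec path needs when the reconstructed string is handed to a shell inside the guest.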
@@ -6,6 +6,15 @@ PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
 PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
 PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
 PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status")
+PUBLIC_CLI_TASK_CREATE_FLAGS = (
+    "--vcpu-count",
+    "--mem-mib",
+    "--ttl-seconds",
+    "--network",
+    "--allow-host-compat",
+    "--source-path",
+    "--json",
+)
 PUBLIC_CLI_RUN_FLAGS = (
     "--vcpu-count",
     "--mem-mib",
|
|
|
|||
|
|
@ -1,26 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Minimal guest-side exec agent for pyro runtime bundles."""
|
||||
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
PORT = 5005
|
||||
BUFFER_SIZE = 65536
|
||||
WORKSPACE_ROOT = PurePosixPath("/workspace")
|
||||
|
||||
|
||||
def _read_request(conn: socket.socket) -> dict[str, Any]:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = conn.recv(BUFFER_SIZE)
|
||||
data = conn.recv(1)
|
||||
if data == b"":
|
||||
break
|
||||
chunks.append(data)
|
||||
if b"\n" in data:
|
||||
if data == b"\n":
|
||||
break
|
||||
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
|
||||
if not isinstance(payload, dict):
|
||||
|
|
@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]:
|
|||
return payload
|
||||
|
||||
|
||||
def _read_exact(conn: socket.socket, size: int) -> bytes:
|
||||
remaining = size
|
||||
chunks: list[bytes] = []
|
||||
while remaining > 0:
|
||||
data = conn.recv(min(BUFFER_SIZE, remaining))
|
||||
if data == b"":
|
||||
raise RuntimeError("unexpected EOF while reading archive payload")
|
||||
chunks.append(data)
|
||||
remaining -= len(data)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
def _normalize_member_name(name: str) -> PurePosixPath:
|
||||
candidate = name.strip()
|
||||
if candidate == "":
|
||||
raise RuntimeError("archive member path is empty")
|
||||
member_path = PurePosixPath(candidate)
|
||||
if member_path.is_absolute():
|
||||
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
|
||||
parts = [part for part in member_path.parts if part not in {"", "."}]
|
||||
if any(part == ".." for part in parts):
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
normalized = PurePosixPath(*parts)
|
||||
if str(normalized) in {"", "."}:
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
|
||||
candidate = destination.strip()
|
||||
if candidate == "":
|
||||
raise RuntimeError("destination must not be empty")
|
||||
destination_path = PurePosixPath(candidate)
|
||||
if not destination_path.is_absolute():
|
||||
destination_path = WORKSPACE_ROOT / destination_path
|
||||
parts = [part for part in destination_path.parts if part not in {"", "."}]
|
||||
normalized = PurePosixPath("/") / PurePosixPath(*parts)
|
||||
if normalized == PurePosixPath("/"):
|
||||
raise RuntimeError("destination must stay inside /workspace")
|
||||
if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
|
||||
raise RuntimeError("destination must stay inside /workspace")
|
||||
suffix = normalized.relative_to(WORKSPACE_ROOT)
|
||||
host_path = Path("/workspace")
|
||||
if str(suffix) not in {"", "."}:
|
||||
host_path = host_path / str(suffix)
|
||||
return normalized, host_path
|
||||
|
||||
|
||||
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
raise RuntimeError(f"symlink {member_path} has an empty target")
|
||||
target_path = PurePosixPath(target)
|
||||
    if target_path.is_absolute():
        raise RuntimeError(f"symlink {member_path} escapes the workspace")
    combined = member_path.parent.joinpath(target_path)
    parts = [part for part in combined.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise RuntimeError(f"symlink {member_path} escapes the workspace")


def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
    relative_path = target_path.relative_to(root)
    current = root
    for part in relative_path.parts[:-1]:
        current = current / part
        if current.is_symlink():
            raise RuntimeError(
                f"archive member would traverse through a symlinked path: {member_name}"
            )


def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    _, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root / str(member_name)
            entry_count += 1
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                if target_path.exists() and not target_path.is_dir():
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()):
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }


def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
    started = time.monotonic()
    try:
@@ -64,9 +193,18 @@ def main() -> None:
        conn, _ = server.accept()
        with conn:
            request = _read_request(conn)
            command = str(request.get("command", ""))
            timeout_seconds = int(request.get("timeout_seconds", 30))
            response = _run_command(command, timeout_seconds)
            action = str(request.get("action", "exec"))
            if action == "extract_archive":
                archive_size = int(request.get("archive_size", 0))
                if archive_size < 0:
                    raise RuntimeError("archive_size must not be negative")
                destination = str(request.get("destination", "/workspace"))
                payload = _read_exact(conn, archive_size)
                response = _extract_archive(payload, destination)
            else:
                command = str(request.get("command", ""))
                timeout_seconds = int(request.get("timeout_seconds", 30))
                response = _run_command(command, timeout_seconds)
            conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
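The dispatch above implies a simple wire format: one newline-terminated JSON header, then exactly `archive_size` raw tar bytes, then one JSON response line. The sketch below reimplements that framing in miniature to show the round trip; the helpers here (`read_line`, `serve_once`) are illustrative stand-ins, not the agent's own `_read_request` / `_read_exact`:

```python
import io
import json
import socket
import tarfile
import threading


def read_line(conn: socket.socket) -> bytes:
    # Read one newline-terminated header, byte by byte, so no payload is consumed.
    chunks: list[bytes] = []
    while not chunks or chunks[-1] != b"\n":
        data = conn.recv(1)
        if data == b"":
            break
        chunks.append(data)
    return b"".join(chunks)


def serve_once(conn: socket.socket) -> None:
    # Toy dispatcher: JSON header, then `archive_size` raw bytes, then a JSON reply.
    request = json.loads(read_line(conn))
    payload = b""
    remaining = int(request.get("archive_size", 0))
    while remaining:
        data = conn.recv(remaining)
        payload += data
        remaining -= len(data)
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        names = archive.getnames()
    conn.sendall((json.dumps({"entry_count": len(names)}) + "\n").encode("utf-8"))


# Build a one-file tar in memory and drive the toy server over a socketpair.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w") as archive:
    info = tarfile.TarInfo("note.txt")
    info.size = 3
    archive.addfile(info, io.BytesIO(b"ok\n"))
payload = buffer.getvalue()

client, server = socket.socketpair()
worker = threading.Thread(target=serve_once, args=(server,))
worker.start()
header = {"action": "extract_archive", "archive_size": len(payload)}
client.sendall((json.dumps(header) + "\n").encode("utf-8"))
client.sendall(payload)
response = json.loads(client.makefile().readline())
worker.join()
print(response["entry_count"])  # 1 (just note.txt)
```

Because the header carries the exact payload length, the server never has to guess where the tar bytes end and the next request begins.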

@@ -18,14 +18,14 @@
  "component_versions": {
    "base_distro": "debian-bookworm-20250210",
    "firecracker": "1.12.1",
    "guest_agent": "0.1.0-dev",
    "guest_agent": "0.2.0-dev",
    "jailer": "1.12.1",
    "kernel": "5.10.210"
  },
  "guest": {
    "agent": {
      "path": "guest/pyro_guest_agent.py",
      "sha256": "65bf8a9a57ffd7321463537e598c4b30f0a13046cbd4538f1b65bc351da5d3c0"
      "sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d"
    }
  },
  "platform": "linux-x86_64",

@@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths

DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.1.0"
DEFAULT_CATALOG_VERSION = "2.2.0"
OCI_MANIFEST_ACCEPT = ", ".join(
    (
        "application/vnd.oci.image.index.v1+json",

@@ -5,6 +5,7 @@ from __future__ import annotations
import json
import socket
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Protocol

@@ -31,6 +32,13 @@ class GuestExecResponse:
    duration_ms: int


@dataclass(frozen=True)
class GuestArchiveResponse:
    destination: str
    entry_count: int
    bytes_written: int


class VsockExecClient:
    """Minimal JSON-over-stream client for a guest exec agent."""

@@ -50,6 +58,63 @@ class VsockExecClient:
            "command": command,
            "timeout_seconds": timeout_seconds,
        }
        sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
        try:
            sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
            payload = self._recv_json_payload(sock)
        finally:
            sock.close()

        if not isinstance(payload, dict):
            raise RuntimeError("guest exec response must be a JSON object")
        return GuestExecResponse(
            stdout=str(payload.get("stdout", "")),
            stderr=str(payload.get("stderr", "")),
            exit_code=int(payload.get("exit_code", -1)),
            duration_ms=int(payload.get("duration_ms", 0)),
        )

    def upload_archive(
        self,
        guest_cid: int,
        port: int,
        archive_path: Path,
        *,
        destination: str,
        timeout_seconds: int = 60,
        uds_path: str | None = None,
    ) -> GuestArchiveResponse:
        request = {
            "action": "extract_archive",
            "destination": destination,
            "archive_size": archive_path.stat().st_size,
        }
        sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
        try:
            sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
            with archive_path.open("rb") as handle:
                for chunk in iter(lambda: handle.read(65536), b""):
                    sock.sendall(chunk)
            payload = self._recv_json_payload(sock)
        finally:
            sock.close()

        if not isinstance(payload, dict):
            raise RuntimeError("guest archive response must be a JSON object")
        return GuestArchiveResponse(
            destination=str(payload.get("destination", destination)),
            entry_count=int(payload.get("entry_count", 0)),
            bytes_written=int(payload.get("bytes_written", 0)),
        )

    def _connect(
        self,
        guest_cid: int,
        port: int,
        timeout_seconds: int,
        *,
        uds_path: str | None,
    ) -> SocketLike:
        family = getattr(socket, "AF_VSOCK", None)
        if family is not None:
            sock = self._socket_factory(family, socket.SOCK_STREAM)

@@ -59,33 +124,15 @@ class VsockExecClient:
            connect_address = uds_path
        else:
            raise RuntimeError("vsock sockets are not supported on this host Python runtime")
        try:
            sock.settimeout(timeout_seconds)
            sock.connect(connect_address)
            if family is None:
                sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
                status = self._recv_line(sock)
                if not status.startswith("OK "):
                    raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
            sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
            chunks: list[bytes] = []
            while True:
                data = sock.recv(65536)
                if data == b"":
                    break
                chunks.append(data)
        finally:
            sock.close()

        payload = json.loads(b"".join(chunks).decode("utf-8"))
        if not isinstance(payload, dict):
            raise RuntimeError("guest exec response must be a JSON object")
        return GuestExecResponse(
            stdout=str(payload.get("stdout", "")),
            stderr=str(payload.get("stderr", "")),
            exit_code=int(payload.get("exit_code", -1)),
            duration_ms=int(payload.get("duration_ms", 0)),
        )
        sock.settimeout(timeout_seconds)
        sock.connect(connect_address)
        if family is None:
            sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
            status = self._recv_line(sock)
            if not status.startswith("OK "):
                sock.close()
                raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
        return sock

    @staticmethod
    def _recv_line(sock: SocketLike) -> str:

@@ -98,3 +145,13 @@ class VsockExecClient:
            if data == b"\n":
                break
        return b"".join(chunks).decode("utf-8", errors="replace")

    @staticmethod
    def _recv_json_payload(sock: SocketLike) -> Any:
        chunks: list[bytes] = []
        while True:
            data = sock.recv(65536)
            if data == b"":
                break
            chunks.append(data)
        return json.loads(b"".join(chunks).decode("utf-8"))
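When `socket.AF_VSOCK` is unavailable, `_connect` above falls back to a Unix-socket bridge and performs a small text handshake before any JSON flows: the client writes `CONNECT <port>\n` and expects an `OK <hostport>\n` reply. The sketch below simulates only that handshake over a socketpair; the bridge here is a stand-in, not a real Firecracker endpoint:

```python
import socket
import threading


def fake_bridge(conn: socket.socket) -> None:
    # Stand-in for the vsock-over-UDS bridge: acknowledge a CONNECT request.
    line = b""
    while not line.endswith(b"\n"):
        line += conn.recv(1)
    if line.startswith(b"CONNECT "):
        conn.sendall(b"OK 1024\n")
    else:
        conn.sendall(b"ERR\n")


client, bridge = socket.socketpair()
worker = threading.Thread(target=fake_bridge, args=(bridge,))
worker.start()

# Client side of the handshake, mirroring the fallback path in _connect.
client.sendall(b"CONNECT 5000\n")
status = b""
while not status.endswith(b"\n"):
    status += client.recv(1)
worker.join()
print(status.decode().strip())  # e.g. "OK 1024"
```

Only after the `OK` line arrives does the client reuse the same socket for the JSON request and, for uploads, the raw archive bytes.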

@@ -8,11 +8,13 @@ import shlex
import shutil
import signal
import subprocess
import tarfile
import tempfile
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from pathlib import Path, PurePosixPath
from typing import Any, Literal, cast

from pyro_mcp.runtime import (

@@ -34,11 +36,15 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False

TASK_LAYOUT_VERSION = 1
TASK_LAYOUT_VERSION = 2
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"
TASK_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60

WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]


@dataclass

@@ -82,6 +88,7 @@ class TaskRecord:
    network: NetworkConfig | None = None
    command_count: int = 0
    last_command: dict[str, Any] | None = None
    workspace_seed: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_instance(

@@ -90,6 +97,7 @@ class TaskRecord:
        *,
        command_count: int = 0,
        last_command: dict[str, Any] | None = None,
        workspace_seed: dict[str, Any] | None = None,
    ) -> TaskRecord:
        return cls(
            task_id=instance.vm_id,

@@ -108,6 +116,7 @@ class TaskRecord:
            network=instance.network,
            command_count=command_count,
            last_command=last_command,
            workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()),
        )

    def to_instance(self, *, workdir: Path) -> VmInstance:

@@ -148,6 +157,7 @@ class TaskRecord:
            "network": _serialize_network(self.network),
            "command_count": self.command_count,
            "last_command": self.last_command,
            "workspace_seed": self.workspace_seed,
        }

    @classmethod

@@ -169,9 +179,35 @@ class TaskRecord:
            network=_deserialize_network(payload.get("network")),
            command_count=int(payload.get("command_count", 0)),
            last_command=_optional_dict(payload.get("last_command")),
            workspace_seed=_task_workspace_seed_dict(payload.get("workspace_seed")),
        )


@dataclass(frozen=True)
class PreparedWorkspaceSeed:
    """Prepared host-side seed archive plus metadata."""

    mode: WorkspaceSeedMode
    source_path: str | None
    archive_path: Path | None = None
    entry_count: int = 0
    bytes_written: int = 0
    cleanup_dir: Path | None = None

    def to_payload(self) -> dict[str, Any]:
        return {
            "mode": self.mode,
            "source_path": self.source_path,
            "destination": TASK_WORKSPACE_GUEST_PATH,
            "entry_count": self.entry_count,
            "bytes_written": self.bytes_written,
        }

    def cleanup(self) -> None:
        if self.cleanup_dir is not None:
            shutil.rmtree(self.cleanup_dir, ignore_errors=True)


@dataclass(frozen=True)
class VmExecResult:
    """Command execution output."""

@@ -216,6 +252,32 @@ def _string_dict(value: object) -> dict[str, str]:
    return {str(key): str(item) for key, item in value.items()}


def _empty_workspace_seed_payload() -> dict[str, Any]:
    return {
        "mode": "empty",
        "source_path": None,
        "destination": TASK_WORKSPACE_GUEST_PATH,
        "entry_count": 0,
        "bytes_written": 0,
    }


def _task_workspace_seed_dict(value: object) -> dict[str, Any]:
    if not isinstance(value, dict):
        return _empty_workspace_seed_payload()
    payload = _empty_workspace_seed_payload()
    payload.update(
        {
            "mode": str(value.get("mode", payload["mode"])),
            "source_path": _optional_str(value.get("source_path")),
            "destination": str(value.get("destination", payload["destination"])),
            "entry_count": int(value.get("entry_count", payload["entry_count"])),
            "bytes_written": int(value.get("bytes_written", payload["bytes_written"])),
        }
    )
    return payload


def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
    if network is None:
        return None

@@ -300,6 +362,201 @@ def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
    return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"


def _is_supported_seed_archive(path: Path) -> bool:
    name = path.name.lower()
    return name.endswith(".tar") or name.endswith(".tar.gz") or name.endswith(".tgz")


def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPath]:
    candidate = destination.strip()
    if candidate == "":
        raise ValueError("workspace destination must not be empty")
    destination_path = PurePosixPath(candidate)
    workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH)
    if not destination_path.is_absolute():
        destination_path = workspace_root / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise ValueError("workspace destination must stay inside /workspace")
    if normalized.parts[: len(workspace_root.parts)] != workspace_root.parts:
        raise ValueError("workspace destination must stay inside /workspace")
    suffix = normalized.relative_to(workspace_root)
    return str(normalized), suffix


def _workspace_host_destination(workspace_dir: Path, destination: str) -> Path:
    _, suffix = _normalize_workspace_destination(destination)
    if str(suffix) in {"", "."}:
        return workspace_dir
    return workspace_dir.joinpath(*suffix.parts)


def _normalize_archive_member_name(name: str) -> PurePosixPath:
    candidate = name.strip()
    if candidate == "":
        raise RuntimeError("archive member path is empty")
    member_path = PurePosixPath(candidate)
    if member_path.is_absolute():
        raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
    parts = [part for part in member_path.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise RuntimeError(f"unsafe archive member path: {name}")
    normalized = PurePosixPath(*parts)
    if str(normalized) in {"", "."}:
        raise RuntimeError(f"unsafe archive member path: {name}")
    return normalized


def _validate_archive_symlink_target(member_name: PurePosixPath, link_target: str) -> None:
    target = link_target.strip()
    if target == "":
        raise RuntimeError(f"symlink {member_name} has an empty target")
    link_path = PurePosixPath(target)
    if link_path.is_absolute():
        raise RuntimeError(f"symlink {member_name} escapes the workspace")
    combined = member_name.parent.joinpath(link_path)
    parts = [part for part in combined.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise RuntimeError(f"symlink {member_name} escapes the workspace")


def _inspect_seed_archive(archive_path: Path) -> tuple[int, int]:
    entry_count = 0
    bytes_written = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_archive_member_name(member.name)
            entry_count += 1
            if member.isdir():
                continue
            if member.isfile():
                bytes_written += member.size
                continue
            if member.issym():
                _validate_archive_symlink_target(member_name, member.linkname)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return entry_count, bytes_written


def _write_directory_seed_archive(source_dir: Path, archive_path: Path) -> None:
    archive_path.parent.mkdir(parents=True, exist_ok=True)
    with tarfile.open(archive_path, "w") as archive:
        for child in sorted(source_dir.iterdir(), key=lambda item: item.name):
            archive.add(child, arcname=child.name, recursive=True)


def _extract_seed_archive_to_host_workspace(
    archive_path: Path,
    *,
    workspace_dir: Path,
    destination: str,
) -> dict[str, Any]:
    normalized_destination, _ = _normalize_workspace_destination(destination)
    destination_root = _workspace_host_destination(workspace_dir, normalized_destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    entry_count = 0
    bytes_written = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_archive_member_name(member.name)
            target_path = destination_root.joinpath(*member_name.parts)
            entry_count += 1
            _ensure_no_symlink_parents(workspace_dir, target_path, member.name)
            if member.isdir():
                if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    shutil.copyfileobj(source, handle)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_archive_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": normalized_destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }


def _instance_workspace_host_dir(instance: VmInstance) -> Path:
    raw_value = instance.metadata.get("workspace_host_dir")
    if raw_value is None or raw_value == "":
        raise RuntimeError("task workspace host directory is unavailable")
    return Path(raw_value)


def _patch_rootfs_guest_agent(rootfs_image: Path, guest_agent_path: Path) -> None:
    debugfs_path = shutil.which("debugfs")
    if debugfs_path is None:
        raise RuntimeError(
            "debugfs is required to seed task workspaces on guest-backed runtimes"
        )
    with tempfile.TemporaryDirectory(prefix="pyro-guest-agent-") as temp_dir:
        staged_agent_path = Path(temp_dir) / "pyro_guest_agent.py"
        shutil.copy2(guest_agent_path, staged_agent_path)
        subprocess.run(  # noqa: S603
            [debugfs_path, "-w", "-R", f"rm {TASK_GUEST_AGENT_PATH}", str(rootfs_image)],
            text=True,
            capture_output=True,
            check=False,
        )
        proc = subprocess.run(  # noqa: S603
            [
                debugfs_path,
                "-w",
                "-R",
                f"write {staged_agent_path} {TASK_GUEST_AGENT_PATH}",
                str(rootfs_image),
            ],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode != 0:
            raise RuntimeError(
                "failed to patch guest agent into task rootfs: "
                f"{proc.stderr.strip() or proc.stdout.strip()}"
            )


def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
    relative_path = target_path.relative_to(root)
    current = root
    for part in relative_path.parts[:-1]:
        current = current / part
        if current.is_symlink():
            raise RuntimeError(
                f"archive member would traverse through a symlinked path: {member_name}"
            )


def _pid_is_running(pid: int | None) -> bool:
    if pid is None:
        return False

@@ -337,6 +594,15 @@ class VmBackend:
    def delete(self, instance: VmInstance) -> None:  # pragma: no cover
        raise NotImplementedError

    def import_archive(  # pragma: no cover
        self,
        instance: VmInstance,
        *,
        archive_path: Path,
        destination: str,
    ) -> dict[str, Any]:
        raise NotImplementedError


class MockBackend(VmBackend):
    """Host-process backend used for development and testability."""

@@ -365,6 +631,19 @@ class MockBackend(VmBackend):
    def delete(self, instance: VmInstance) -> None:
        shutil.rmtree(instance.workdir, ignore_errors=True)

    def import_archive(
        self,
        instance: VmInstance,
        *,
        archive_path: Path,
        destination: str,
    ) -> dict[str, Any]:
        return _extract_seed_archive_to_host_workspace(
            archive_path,
            workspace_dir=_instance_workspace_host_dir(instance),
            destination=destination,
        )


class FirecrackerBackend(VmBackend):  # pragma: no cover
    """Host-gated backend that validates Firecracker prerequisites."""

@@ -562,6 +841,46 @@ class FirecrackerBackend(VmBackend):  # pragma: no cover
        self._network_manager.cleanup(instance.network)
        shutil.rmtree(instance.workdir, ignore_errors=True)

    def import_archive(
        self,
        instance: VmInstance,
        *,
        archive_path: Path,
        destination: str,
    ) -> dict[str, Any]:
        if self._runtime_capabilities.supports_guest_exec:
            guest_cid = int(instance.metadata["guest_cid"])
            port = int(instance.metadata["guest_exec_port"])
            uds_path = instance.metadata.get("guest_exec_uds_path")
            deadline = time.monotonic() + 10
            while True:
                try:
                    response = self._guest_exec_client.upload_archive(
                        guest_cid,
                        port,
                        archive_path,
                        destination=destination,
                        timeout_seconds=TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS,
                        uds_path=uds_path,
                    )
                    return {
                        "destination": response.destination,
                        "entry_count": response.entry_count,
                        "bytes_written": response.bytes_written,
                    }
                except (OSError, RuntimeError) as exc:
                    if time.monotonic() >= deadline:
                        raise RuntimeError(
                            f"guest archive transport did not become ready: {exc}"
                        ) from exc
                    time.sleep(0.2)
        instance.metadata["execution_mode"] = "host_compat"
        return _extract_seed_archive_to_host_workspace(
            archive_path,
            workspace_dir=_instance_workspace_host_dir(instance),
            destination=destination,
        )


class VmManager:
    """In-process lifecycle manager for ephemeral VM environments and tasks."""

@@ -814,9 +1133,11 @@ class VmManager:
        ttl_seconds: int = DEFAULT_TTL_SECONDS,
        network: bool = False,
        allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
        source_path: str | Path | None = None,
    ) -> dict[str, Any]:
        self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
        get_environment(environment, runtime_paths=self._runtime_paths)
        prepared_seed = self._prepare_workspace_seed(source_path)
        now = time.time()
        task_id = uuid.uuid4().hex[:12]
        task_dir = self._task_dir(task_id)

@@ -851,10 +1172,25 @@ class VmManager:
                    f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
                )
            self._backend.create(instance)
            if (
                prepared_seed.archive_path is not None
                and self._runtime_capabilities.supports_guest_exec
            ):
                self._ensure_task_guest_seed_support(instance)
            with self._lock:
                self._start_instance_locked(instance)
            self._require_guest_exec_or_opt_in(instance)
            if self._runtime_capabilities.supports_guest_exec:
            workspace_seed = prepared_seed.to_payload()
            if prepared_seed.archive_path is not None:
                import_summary = self._backend.import_archive(
                    instance,
                    archive_path=prepared_seed.archive_path,
                    destination=TASK_WORKSPACE_GUEST_PATH,
                )
                workspace_seed["entry_count"] = int(import_summary["entry_count"])
                workspace_seed["bytes_written"] = int(import_summary["bytes_written"])
                workspace_seed["destination"] = str(import_summary["destination"])
            elif self._runtime_capabilities.supports_guest_exec:
                self._backend.exec(
                    instance,
                    f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",

@@ -862,7 +1198,7 @@ class VmManager:
                )
            else:
                instance.metadata["execution_mode"] = "host_compat"
            task = TaskRecord.from_instance(instance)
            task = TaskRecord.from_instance(instance, workspace_seed=workspace_seed)
            self._save_task_locked(task)
            return self._serialize_task(task)
        except Exception:

@@ -879,6 +1215,8 @@ class VmManager:
                pass
            shutil.rmtree(task_dir, ignore_errors=True)
            raise
        finally:
            prepared_seed.cleanup()

    def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
        if timeout_seconds <= 0:

@@ -999,6 +1337,7 @@ class VmManager:
            "tap_name": task.network.tap_name if task.network is not None else None,
            "execution_mode": task.metadata.get("execution_mode", "pending"),
            "workspace_path": TASK_WORKSPACE_GUEST_PATH,
            "workspace_seed": _task_workspace_seed_dict(task.workspace_seed),
            "command_count": task.command_count,
            "last_command": task.last_command,
            "metadata": task.metadata,

@@ -1094,6 +1433,53 @@
        execution_mode = instance.metadata.get("execution_mode", "unknown")
        return exec_result, execution_mode

    def _prepare_workspace_seed(self, source_path: str | Path | None) -> PreparedWorkspaceSeed:
        if source_path is None:
            return PreparedWorkspaceSeed(mode="empty", source_path=None)
        resolved_source_path = Path(source_path).expanduser().resolve()
        if not resolved_source_path.exists():
            raise ValueError(f"source_path {resolved_source_path} does not exist")
        if resolved_source_path.is_dir():
            cleanup_dir = Path(tempfile.mkdtemp(prefix="pyro-task-seed-"))
            archive_path = cleanup_dir / "workspace-seed.tar"
            try:
                _write_directory_seed_archive(resolved_source_path, archive_path)
                entry_count, bytes_written = _inspect_seed_archive(archive_path)
            except Exception:
                shutil.rmtree(cleanup_dir, ignore_errors=True)
                raise
            return PreparedWorkspaceSeed(
                mode="directory",
                source_path=str(resolved_source_path),
                archive_path=archive_path,
                entry_count=entry_count,
                bytes_written=bytes_written,
                cleanup_dir=cleanup_dir,
            )
        if (
            not resolved_source_path.is_file()
            or not _is_supported_seed_archive(resolved_source_path)
        ):
            raise ValueError(
                "source_path must be a directory or a .tar/.tar.gz/.tgz archive"
            )
        entry_count, bytes_written = _inspect_seed_archive(resolved_source_path)
        return PreparedWorkspaceSeed(
            mode="tar_archive",
            source_path=str(resolved_source_path),
            archive_path=resolved_source_path,
            entry_count=entry_count,
            bytes_written=bytes_written,
        )

    def _ensure_task_guest_seed_support(self, instance: VmInstance) -> None:
        if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None:
            raise RuntimeError("runtime bundle does not provide a guest agent for task seeding")
        rootfs_image = instance.metadata.get("rootfs_image")
        if rootfs_image is None or rootfs_image == "":
            raise RuntimeError("task rootfs image is unavailable for guest workspace seeding")
        _patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path)

    def _task_dir(self, task_id: str) -> Path:
        return self._tasks_dir / task_id
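Both the host-compat and guest extraction paths hinge on the same member-name rule: no absolute paths, no `..` traversal, no empty names. A standalone restatement of that rule (re-implemented here for illustration, not imported from the module) makes the accepted and rejected shapes concrete:

```python
from pathlib import PurePosixPath


def is_safe_member(name: str) -> bool:
    # Mirrors the checks in _normalize_archive_member_name: reject empty
    # names, absolute paths, and any ".." component after dropping "." parts.
    candidate = name.strip()
    if candidate == "":
        return False
    member_path = PurePosixPath(candidate)
    if member_path.is_absolute():
        return False
    parts = [part for part in member_path.parts if part not in {"", "."}]
    return bool(parts) and all(part != ".." for part in parts)


print(is_safe_member("src/app.py"))     # True: stays inside the workspace
print(is_safe_member("./README.md"))    # True: leading "./" is stripped
print(is_safe_member("/etc/passwd"))    # False: absolute path
print(is_safe_member("../escape.txt"))  # False: parent traversal
```

Combined with the symlink-parent walk (`_ensure_no_symlink_parents`) and the symlink-target check, this keeps every extracted entry under `/workspace` even for adversarial archives.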

@@ -114,14 +114,23 @@ def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None:
        )
    )

    created = pyro.create_task(environment="debian:12-base", allow_host_compat=True)
    source_dir = tmp_path / "seed"
    source_dir.mkdir()
    (source_dir / "note.txt").write_text("ok\n", encoding="utf-8")

    created = pyro.create_task(
        environment="debian:12-base",
        allow_host_compat=True,
        source_path=source_dir,
    )
    task_id = str(created["task_id"])
    executed = pyro.exec_task(task_id, command="printf 'ok\\n'")
    executed = pyro.exec_task(task_id, command="cat note.txt")
    status = pyro.status_task(task_id)
    logs = pyro.logs_task(task_id)
    deleted = pyro.delete_task(task_id)

    assert executed["stdout"] == "ok\n"
    assert created["workspace_seed"]["mode"] == "directory"
    assert status["command_count"] == 1
    assert logs["count"] == 1
    assert deleted["deleted"] is True
@ -60,9 +60,13 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None:
|
|||
assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help
|
||||
|
||||
task_help = _subparser_choice(parser, "task").format_help()
|
||||
assert "pyro task create debian:12" in task_help
|
||||
assert "pyro task create debian:12 --source-path ./repo" in task_help
|
||||
assert "pyro task exec TASK_ID" in task_help
|
||||
|
||||
task_create_help = _subparser_choice(_subparser_choice(parser, "task"), "create").format_help()
|
||||
assert "--source-path" in task_create_help
|
||||
assert "seed into `/workspace`" in task_create_help
|
||||
|
||||
task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help()
|
||||
assert "persistent `/workspace`" in task_exec_help
|
||||
assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help
|
||||
|
|
@ -326,12 +330,20 @@ def test_cli_requires_run_command() -> None:
|
|||
cli._require_command([])
|
||||
|
||||
|
||||
def test_cli_requires_command_preserves_shell_argument_boundaries() -> None:
|
||||
command = cli._require_command(
|
||||
["--", "sh", "-lc", 'printf "hello from task\\n" > note.txt']
|
||||
)
|
||||
assert command == 'sh -lc \'printf "hello from task\\n" > note.txt\''
|
||||
|
||||
|
||||
def test_cli_task_create_prints_json(
|
||||
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
|
||||
) -> None:
|
||||
class StubPyro:
|
||||
def create_task(self, **kwargs: Any) -> dict[str, Any]:
|
||||
assert kwargs["environment"] == "debian:12"
|
||||
assert kwargs["source_path"] == "./repo"
|
||||
return {"task_id": "task-123", "state": "started"}
|
||||
|
||||
class StubParser:
|
||||
|
|
@@ -345,6 +357,7 @@ def test_cli_task_create_prints_json(
         ttl_seconds=600,
         network=False,
         allow_host_compat=False,
+        source_path="./repo",
         json=True,
     )
@@ -366,6 +379,13 @@ def test_cli_task_create_prints_human(
         "environment": "debian:12",
         "state": "started",
         "workspace_path": "/workspace",
+        "workspace_seed": {
+            "mode": "directory",
+            "source_path": "/tmp/repo",
+            "destination": "/workspace",
+            "entry_count": 1,
+            "bytes_written": 6,
+        },
         "execution_mode": "guest_vsock",
         "vcpu_count": 1,
         "mem_mib": 1024,
@@ -384,6 +404,7 @@ def test_cli_task_create_prints_human(
         ttl_seconds=600,
         network=False,
         allow_host_compat=False,
+        source_path="/tmp/repo",
         json=False,
     )
 
@@ -393,6 +414,7 @@ def test_cli_task_create_prints_human(
     output = capsys.readouterr().out
     assert "Task: task-123" in output
     assert "Workspace: /workspace" in output
+    assert "Workspace seed: directory from /tmp/repo" in output
 
 
 def test_cli_task_exec_prints_human_output(
@@ -17,6 +17,7 @@ from pyro_mcp.contract import (
     PUBLIC_CLI_DEMO_SUBCOMMANDS,
     PUBLIC_CLI_ENV_SUBCOMMANDS,
     PUBLIC_CLI_RUN_FLAGS,
+    PUBLIC_CLI_TASK_CREATE_FLAGS,
     PUBLIC_CLI_TASK_SUBCOMMANDS,
     PUBLIC_MCP_TOOLS,
     PUBLIC_SDK_METHODS,
@@ -67,6 +68,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
     task_help_text = _subparser_choice(parser, "task").format_help()
     for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS:
         assert subcommand_name in task_help_text
+    task_create_help_text = _subparser_choice(
+        _subparser_choice(parser, "task"), "create"
+    ).format_help()
+    for flag in PUBLIC_CLI_TASK_CREATE_FLAGS:
+        assert flag in task_create_help_text
 
     demo_help_text = _subparser_choice(parser, "demo").format_help()
     for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:
@@ -171,6 +171,9 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
         base_dir=tmp_path / "vms",
         network_manager=TapNetworkManager(enabled=False),
     )
+    source_dir = tmp_path / "seed"
+    source_dir.mkdir()
+    (source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
 
     def _extract_structured(raw_result: object) -> dict[str, Any]:
         if not isinstance(raw_result, tuple) or len(raw_result) != 2:
@@ -188,6 +191,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
             {
                 "environment": "debian:12-base",
                 "allow_host_compat": True,
+                "source_path": str(source_dir),
             },
         )
     )
@@ -197,7 +201,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
             "task_exec",
             {
                 "task_id": task_id,
-                "command": "printf 'ok\\n'",
+                "command": "cat note.txt",
             },
         )
     )
@@ -207,6 +211,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
 
     created, executed, logs, deleted = asyncio.run(_run())
     assert created["state"] == "started"
+    assert created["workspace_seed"]["mode"] == "directory"
     assert executed["stdout"] == "ok\n"
     assert logs["count"] == 1
     assert deleted["deleted"] is True
@@ -1,6 +1,10 @@
 from __future__ import annotations
 
+import io
 import json
 import socket
+import tarfile
+from pathlib import Path
 
 import pytest
@@ -62,6 +66,45 @@ def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
     assert stub.closed is True
 
 
+def test_vsock_exec_client_upload_archive_round_trip(
+    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
+) -> None:
+    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
+    archive_path = tmp_path / "seed.tgz"
+    with tarfile.open(archive_path, "w:gz") as archive:
+        payload = b"hello\n"
+        info = tarfile.TarInfo(name="note.txt")
+        info.size = len(payload)
+        archive.addfile(info, io.BytesIO(payload))
+    stub = StubSocket(
+        b'{"destination":"/workspace","entry_count":1,"bytes_written":6}'
+    )
+
+    def socket_factory(family: int, sock_type: int) -> StubSocket:
+        assert family == socket.AF_VSOCK
+        assert sock_type == socket.SOCK_STREAM
+        return stub
+
+    client = VsockExecClient(socket_factory=socket_factory)
+    response = client.upload_archive(
+        1234,
+        5005,
+        archive_path,
+        destination="/workspace",
+        timeout_seconds=60,
+    )
+
+    request_payload, archive_payload = stub.sent.split(b"\n", 1)
+    request = json.loads(request_payload.decode("utf-8"))
+    assert request["action"] == "extract_archive"
+    assert request["destination"] == "/workspace"
+    assert int(request["archive_size"]) == archive_path.stat().st_size
+    assert archive_payload == archive_path.read_bytes()
+    assert response.entry_count == 1
+    assert response.bytes_written == 6
+    assert stub.closed is True
+
+
 def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
     monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
     stub = StubSocket(b"[]")
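The stub assertions imply a simple wire format for `upload_archive`: one JSON request line terminated by `\n`, followed immediately by the raw archive bytes. The sketch below builds and parses such a frame under that assumption; the function names are illustrative, not the library's API:

```python
import io
import json
import tarfile


def build_upload_frame(archive_bytes: bytes, destination: str) -> bytes:
    # JSON header line first, then the raw archive body. The JSON text
    # contains no newline, so the first b"\n" unambiguously ends the header.
    header = {
        "action": "extract_archive",
        "destination": destination,
        "archive_size": len(archive_bytes),
    }
    return json.dumps(header).encode("utf-8") + b"\n" + archive_bytes


def parse_upload_frame(frame: bytes) -> tuple[dict, bytes]:
    # Mirror of the test: split only on the first newline, since the archive
    # body may itself contain b"\n".
    header_line, body = frame.split(b"\n", 1)
    return json.loads(header_line.decode("utf-8")), body


# Build a tiny in-memory .tar.gz payload to frame.
buffer = io.BytesIO()
with tarfile.open(fileobj=buffer, mode="w:gz") as archive:
    payload = b"hello\n"
    info = tarfile.TarInfo(name="note.txt")
    info.size = len(payload)
    archive.addfile(info, io.BytesIO(payload))

frame = build_upload_frame(buffer.getvalue(), "/workspace")
header, body = parse_upload_frame(frame)
assert header["action"] == "extract_archive"
assert header["archive_size"] == len(body)
```

Carrying `archive_size` in the header lets the guest-side reader consume exactly that many bytes after the newline instead of relying on connection close to delimit the payload.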
@@ -1,7 +1,9 @@
 from __future__ import annotations
 
+import io
 import json
 import subprocess
+import tarfile
 import time
 from pathlib import Path
 from typing import Any
@@ -306,6 +308,140 @@ def test_task_lifecycle_and_logs(tmp_path: Path) -> None:
         manager.status_task(task_id)
 
 
+def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None:
+    source_dir = tmp_path / "seed"
+    source_dir.mkdir()
+    (source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
+
+    manager = VmManager(
+        backend_name="mock",
+        base_dir=tmp_path / "vms",
+        network_manager=TapNetworkManager(enabled=False),
+    )
+
+    created = manager.create_task(
+        environment="debian:12-base",
+        allow_host_compat=True,
+        source_path=source_dir,
+    )
+    task_id = str(created["task_id"])
+
+    workspace_seed = created["workspace_seed"]
+    assert workspace_seed["mode"] == "directory"
+    assert workspace_seed["source_path"] == str(source_dir.resolve())
+    executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
+    assert executed["stdout"] == "hello\n"
+
+    status = manager.status_task(task_id)
+    assert status["workspace_seed"]["mode"] == "directory"
+    assert status["workspace_seed"]["source_path"] == str(source_dir.resolve())
+
+
+def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None:
+    archive_path = tmp_path / "seed.tgz"
+    nested_dir = tmp_path / "src"
+    nested_dir.mkdir()
+    (nested_dir / "note.txt").write_text("archive\n", encoding="utf-8")
+    with tarfile.open(archive_path, "w:gz") as archive:
+        archive.add(nested_dir / "note.txt", arcname="note.txt")
+
+    manager = VmManager(
+        backend_name="mock",
+        base_dir=tmp_path / "vms",
+        network_manager=TapNetworkManager(enabled=False),
+    )
+
+    created = manager.create_task(
+        environment="debian:12-base",
+        allow_host_compat=True,
+        source_path=archive_path,
+    )
+    task_id = str(created["task_id"])
+
+    assert created["workspace_seed"]["mode"] == "tar_archive"
+    executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
+    assert executed["stdout"] == "archive\n"
+
+
+def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None:
+    archive_path = tmp_path / "bad.tgz"
+    with tarfile.open(archive_path, "w:gz") as archive:
+        payload = b"bad\n"
+        info = tarfile.TarInfo(name="../escape.txt")
+        info.size = len(payload)
+        archive.addfile(info, io.BytesIO(payload))
+
+    manager = VmManager(
+        backend_name="mock",
+        base_dir=tmp_path / "vms",
+        network_manager=TapNetworkManager(enabled=False),
+    )
+
+    with pytest.raises(RuntimeError, match="unsafe archive member path"):
+        manager.create_task(
+            environment="debian:12-base",
+            allow_host_compat=True,
+            source_path=archive_path,
+        )
+    assert list((tmp_path / "vms" / "tasks").iterdir()) == []
+
+
+def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None:
+    archive_path = tmp_path / "bad-symlink.tgz"
+    with tarfile.open(archive_path, "w:gz") as archive:
+        symlink_info = tarfile.TarInfo(name="linked")
+        symlink_info.type = tarfile.SYMTYPE
+        symlink_info.linkname = "outside"
+        archive.addfile(symlink_info)
+
+        payload = b"bad\n"
+        file_info = tarfile.TarInfo(name="linked/note.txt")
+        file_info.size = len(payload)
+        archive.addfile(file_info, io.BytesIO(payload))
+
+    manager = VmManager(
+        backend_name="mock",
+        base_dir=tmp_path / "vms",
+        network_manager=TapNetworkManager(enabled=False),
+    )
+
+    with pytest.raises(RuntimeError, match="traverse through a symlinked path"):
+        manager.create_task(
+            environment="debian:12-base",
+            allow_host_compat=True,
+            source_path=archive_path,
+        )
+
+
+def test_task_create_cleans_up_on_seed_failure(
+    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
+) -> None:
+    source_dir = tmp_path / "seed"
+    source_dir.mkdir()
+    (source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
+
+    manager = VmManager(
+        backend_name="mock",
+        base_dir=tmp_path / "vms",
+        network_manager=TapNetworkManager(enabled=False),
+    )
+
+    def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]:
+        del args, kwargs
+        raise RuntimeError("seed import failed")
+
+    monkeypatch.setattr(manager._backend, "import_archive", _boom)  # noqa: SLF001
+
+    with pytest.raises(RuntimeError, match="seed import failed"):
+        manager.create_task(
+            environment="debian:12-base",
+            allow_host_compat=True,
+            source_path=source_dir,
+        )
+
+    assert list((tmp_path / "vms" / "tasks").iterdir()) == []
+
+
 def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None:
     base_dir = tmp_path / "vms"
     manager = VmManager(
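The rejection tests above pin down two unsafe archive shapes: members whose names escape the destination (`../escape.txt`) and members that would write through a symlink planted earlier in the same archive (`linked` -> `outside`, then `linked/note.txt`). A standalone sketch of extraction logic that rejects both, with the tests' error messages reproduced; the real backend method names are not shown in this diff, so `safe_extract` is a hypothetical helper:

```python
import os
import tarfile
from pathlib import Path


def safe_extract(archive_path: Path, destination: Path) -> tuple[int, int]:
    # Hypothetical helper sketching the validation the tests exercise.
    destination = destination.resolve()
    entry_count = 0
    bytes_written = 0
    with tarfile.open(archive_path) as archive:
        for member in archive.getmembers():
            target = (destination / member.name).resolve()
            # Reject absolute names and ".." escapes such as "../escape.txt".
            if member.name.startswith("/") or (
                os.path.commonpath([str(destination), str(target)]) != str(destination)
            ):
                raise RuntimeError(f"unsafe archive member path: {member.name}")
            # Reject members whose parent directories were materialized as
            # symlinks by earlier members of the same archive.
            parent = (destination / member.name).parent
            while parent != destination:
                if parent.is_symlink():
                    raise RuntimeError(
                        f"member would traverse through a symlinked path: {member.name}"
                    )
                parent = parent.parent
            archive.extract(member, destination)
            entry_count += 1
            if member.isfile():
                bytes_written += member.size
    return entry_count, bytes_written
```

On Python 3.12+, `tarfile`'s built-in extraction filters (`archive.extractall(path, filter="data")`) reject many of the same member shapes and are worth preferring where available; the sketch above is the kind of explicit check a backend needs when it must also report `entry_count` and `bytes_written`.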
uv.lock (generated, 2 changes)

@@ -706,7 +706,7 @@ crypto = [
 
 [[package]]
 name = "pyro-mcp"
-version = "2.1.0"
+version = "2.2.0"
 source = { editable = "." }
 dependencies = [
     { name = "mcp" },