Add seeded task workspace creation

Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces.

Expose source_path on task create across the CLI, SDK, and MCP; add safe archive upload and extraction support for guest and host-compat backends; persist workspace_seed metadata; and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post-`--` command reconstruction to shlex.join() so the documented sh -lc task examples preserve argument boundaries.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
This commit is contained in:
Thales Maciel 2026-03-11 21:45:38 -03:00
parent 58df176148
commit aa886b346e
25 changed files with 1076 additions and 75 deletions

View file

@ -8,11 +8,13 @@ import shlex
import shutil
import signal
import subprocess
import tarfile
import tempfile
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from pathlib import Path, PurePosixPath
from typing import Any, Literal, cast
from pyro_mcp.runtime import (
@ -34,11 +36,15 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
TASK_LAYOUT_VERSION = 1
TASK_LAYOUT_VERSION = 2
TASK_WORKSPACE_DIRNAME = "workspace"
TASK_COMMANDS_DIRNAME = "commands"
TASK_RUNTIME_DIRNAME = "runtime"
TASK_WORKSPACE_GUEST_PATH = "/workspace"
TASK_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
@dataclass
@ -82,6 +88,7 @@ class TaskRecord:
network: NetworkConfig | None = None
command_count: int = 0
last_command: dict[str, Any] | None = None
workspace_seed: dict[str, Any] = field(default_factory=dict)
@classmethod
def from_instance(
@ -90,6 +97,7 @@ class TaskRecord:
*,
command_count: int = 0,
last_command: dict[str, Any] | None = None,
workspace_seed: dict[str, Any] | None = None,
) -> TaskRecord:
return cls(
task_id=instance.vm_id,
@ -108,6 +116,7 @@ class TaskRecord:
network=instance.network,
command_count=command_count,
last_command=last_command,
workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()),
)
def to_instance(self, *, workdir: Path) -> VmInstance:
@ -148,6 +157,7 @@ class TaskRecord:
"network": _serialize_network(self.network),
"command_count": self.command_count,
"last_command": self.last_command,
"workspace_seed": self.workspace_seed,
}
@classmethod
@ -169,9 +179,35 @@ class TaskRecord:
network=_deserialize_network(payload.get("network")),
command_count=int(payload.get("command_count", 0)),
last_command=_optional_dict(payload.get("last_command")),
workspace_seed=_task_workspace_seed_dict(payload.get("workspace_seed")),
)
@dataclass(frozen=True)
class PreparedWorkspaceSeed:
    """Host-side seed archive staged for upload, plus its bookkeeping metadata."""

    mode: WorkspaceSeedMode
    source_path: str | None
    archive_path: Path | None = None
    entry_count: int = 0
    bytes_written: int = 0
    cleanup_dir: Path | None = None

    def to_payload(self) -> dict[str, Any]:
        """Serialize the seed metadata for task records and API responses."""
        return {
            "mode": self.mode,
            "source_path": self.source_path,
            "destination": TASK_WORKSPACE_GUEST_PATH,
            "entry_count": self.entry_count,
            "bytes_written": self.bytes_written,
        }

    def cleanup(self) -> None:
        """Remove the temporary staging directory, if one was created."""
        if self.cleanup_dir is None:
            return
        shutil.rmtree(self.cleanup_dir, ignore_errors=True)
@dataclass(frozen=True)
class VmExecResult:
"""Command execution output."""
@ -216,6 +252,32 @@ def _string_dict(value: object) -> dict[str, str]:
return {str(key): str(item) for key, item in value.items()}
def _empty_workspace_seed_payload() -> dict[str, Any]:
    """Return the workspace-seed metadata for a task created with no source."""
    return dict(
        mode="empty",
        source_path=None,
        destination=TASK_WORKSPACE_GUEST_PATH,
        entry_count=0,
        bytes_written=0,
    )
def _task_workspace_seed_dict(value: object) -> dict[str, Any]:
    """Coerce persisted workspace-seed data into the canonical payload shape.

    Unknown or non-dict input falls back to the empty-seed defaults; known
    keys are normalized to their expected types.
    """
    payload = _empty_workspace_seed_payload()
    if isinstance(value, dict):
        payload["mode"] = str(value.get("mode", payload["mode"]))
        payload["source_path"] = _optional_str(value.get("source_path"))
        payload["destination"] = str(value.get("destination", payload["destination"]))
        payload["entry_count"] = int(value.get("entry_count", payload["entry_count"]))
        payload["bytes_written"] = int(value.get("bytes_written", payload["bytes_written"]))
    return payload
def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
if network is None:
return None
@ -300,6 +362,201 @@ def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"
def _is_supported_seed_archive(path: Path) -> bool:
name = path.name.lower()
return name.endswith(".tar") or name.endswith(".tar.gz") or name.endswith(".tgz")
def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPath]:
    """Normalize *destination* to an absolute guest path under /workspace.

    Relative destinations are anchored at the workspace root. Returns the
    normalized absolute path string and its suffix relative to the workspace
    root. Raises ValueError when the destination is empty or would resolve
    outside /workspace.
    """
    candidate = destination.strip()
    if candidate == "":
        raise ValueError("workspace destination must not be empty")
    destination_path = PurePosixPath(candidate)
    workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH)
    if not destination_path.is_absolute():
        destination_path = workspace_root / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # BUG FIX: ".." components previously survived normalization, so a
    # destination like "/workspace/../etc" passed the prefix check below and
    # its relative suffix ("../etc") escaped the workspace when joined onto
    # the host-side directory. Reject any parent traversal outright.
    if any(part == ".." for part in parts):
        raise ValueError("workspace destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise ValueError("workspace destination must stay inside /workspace")
    if normalized.parts[: len(workspace_root.parts)] != workspace_root.parts:
        raise ValueError("workspace destination must stay inside /workspace")
    suffix = normalized.relative_to(workspace_root)
    return str(normalized), suffix
def _workspace_host_destination(workspace_dir: Path, destination: str) -> Path:
    """Map a guest workspace destination onto the host-side workspace dir."""
    suffix = _normalize_workspace_destination(destination)[1]
    # relative_to() yields "." (empty parts) when the destination is the
    # workspace root itself.
    if not suffix.parts:
        return workspace_dir
    return workspace_dir.joinpath(*suffix.parts)
def _normalize_archive_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _validate_archive_symlink_target(member_name: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_name} has an empty target")
link_path = PurePosixPath(target)
if link_path.is_absolute():
raise RuntimeError(f"symlink {member_name} escapes the workspace")
combined = member_name.parent.joinpath(link_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_name} escapes the workspace")
def _inspect_seed_archive(archive_path: Path) -> tuple[int, int]:
    """Scan a seed archive and return (entry_count, total regular-file bytes).

    Every member name and symlink target is validated up front so unsafe
    archives are rejected before anything is uploaded or extracted. Hard
    links and exotic member types (devices, FIFOs) raise RuntimeError.
    """
    entries = 0
    total_bytes = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            safe_name = _normalize_archive_member_name(member.name)
            entries += 1
            if member.isfile():
                total_bytes += member.size
            elif member.issym():
                _validate_archive_symlink_target(safe_name, member.linkname)
            elif member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            elif not member.isdir():
                raise RuntimeError(f"unsupported archive member type: {member.name}")
    return entries, total_bytes
def _write_directory_seed_archive(source_dir: Path, archive_path: Path) -> None:
archive_path.parent.mkdir(parents=True, exist_ok=True)
with tarfile.open(archive_path, "w") as archive:
for child in sorted(source_dir.iterdir(), key=lambda item: item.name):
archive.add(child, arcname=child.name, recursive=True)
def _extract_seed_archive_to_host_workspace(
    archive_path: Path,
    *,
    workspace_dir: Path,
    destination: str,
) -> dict[str, Any]:
    """Safely extract a seed archive into the host-compat workspace directory.

    Used when no guest-exec transport is available: archive members are
    written directly under *workspace_dir* at the normalized *destination*.
    Member names and symlink targets are validated, hard links and other
    exotic member types are rejected, and extraction refuses to write
    through symlinked parent directories. Returns a summary dict with
    "destination", "entry_count" and "bytes_written" keys.
    """
    normalized_destination, _ = _normalize_workspace_destination(destination)
    destination_root = _workspace_host_destination(workspace_dir, normalized_destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    entry_count = 0
    bytes_written = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            # Raises on empty/absolute/".."-containing names before any I/O.
            member_name = _normalize_archive_member_name(member.name)
            target_path = destination_root.joinpath(*member_name.parts)
            entry_count += 1
            # Never write through a symlinked intermediate directory.
            _ensure_no_symlink_parents(workspace_dir, target_path, member.name)
            if member.isdir():
                if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                # Stream the member body via extractfile() + copyfileobj()
                # rather than TarFile.extract(), so only file contents are
                # written — tar mode/owner/mtime metadata is never applied
                # to the host filesystem.
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    shutil.copyfileobj(source, handle)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_archive_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    # Replace an existing symlink so re-seeding is idempotent.
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": normalized_destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
def _instance_workspace_host_dir(instance: VmInstance) -> Path:
    """Resolve the host-side workspace directory recorded in instance metadata.

    Raises RuntimeError when the "workspace_host_dir" key is absent or empty.
    """
    raw_value = instance.metadata.get("workspace_host_dir")
    if raw_value is None or raw_value == "":
        raise RuntimeError("task workspace host directory is unavailable")
    return Path(raw_value)
def _patch_rootfs_guest_agent(rootfs_image: Path, guest_agent_path: Path) -> None:
    """Install the bundled guest agent into a task rootfs image via debugfs.

    The agent is staged into a temporary directory, any stale copy at
    TASK_GUEST_AGENT_PATH is removed best-effort, then the staged file is
    written into the ext filesystem image. Raises RuntimeError when debugfs
    is not installed or the write command fails.
    """
    debugfs = shutil.which("debugfs")
    if debugfs is None:
        raise RuntimeError(
            "debugfs is required to seed task workspaces on guest-backed runtimes"
        )
    with tempfile.TemporaryDirectory(prefix="pyro-guest-agent-") as staging:
        staged = Path(staging) / "pyro_guest_agent.py"
        shutil.copy2(guest_agent_path, staged)

        def run_debugfs(request: str) -> subprocess.CompletedProcess[str]:
            # debugfs takes a single -R request per invocation.
            return subprocess.run(  # noqa: S603
                [debugfs, "-w", "-R", request, str(rootfs_image)],
                text=True,
                capture_output=True,
                check=False,
            )

        # Best-effort removal of a previously patched agent; a failure here
        # (e.g. the file does not exist yet) is deliberately ignored.
        run_debugfs(f"rm {TASK_GUEST_AGENT_PATH}")
        result = run_debugfs(f"write {staged} {TASK_GUEST_AGENT_PATH}")
        if result.returncode != 0:
            raise RuntimeError(
                "failed to patch guest agent into task rootfs: "
                f"{result.stderr.strip() or result.stdout.strip()}"
            )
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _pid_is_running(pid: int | None) -> bool:
if pid is None:
return False
@ -337,6 +594,15 @@ class VmBackend:
def delete(self, instance: VmInstance) -> None:  # pragma: no cover
    """Tear down *instance* and release its resources; subclasses must override."""
    raise NotImplementedError
def import_archive(  # pragma: no cover
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Upload and extract a workspace seed archive into *instance*.

    Implementations return a summary dict with "destination", "entry_count"
    and "bytes_written" keys. This base implementation is abstract.
    """
    raise NotImplementedError
class MockBackend(VmBackend):
"""Host-process backend used for development and testability."""
@ -365,6 +631,19 @@ class MockBackend(VmBackend):
def delete(self, instance: VmInstance) -> None:
    """Tear down the mock VM by removing its working directory."""
    shutil.rmtree(instance.workdir, ignore_errors=True)
def import_archive(
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Seed the workspace by extracting directly into the host-backed dir."""
    workspace_dir = _instance_workspace_host_dir(instance)
    return _extract_seed_archive_to_host_workspace(
        archive_path,
        workspace_dir=workspace_dir,
        destination=destination,
    )
class FirecrackerBackend(VmBackend): # pragma: no cover
"""Host-gated backend that validates Firecracker prerequisites."""
@ -562,6 +841,46 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
self._network_manager.cleanup(instance.network)
shutil.rmtree(instance.workdir, ignore_errors=True)
def import_archive(
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Seed the task workspace from *archive_path*.

    With guest-exec support, the archive is uploaded through the guest
    agent transport, retrying for up to ~10 seconds while the agent comes
    up after boot. Without guest-exec, the instance is marked host_compat
    and the archive is extracted into the host-side workspace directory.
    Returns a summary dict with "destination", "entry_count" and
    "bytes_written" keys.
    """
    if self._runtime_capabilities.supports_guest_exec:
        # Addressing recorded in metadata at create time; uds_path is
        # optional (presumably for UDS-bridged vsock transports — confirm
        # against the guest-exec client).
        guest_cid = int(instance.metadata["guest_cid"])
        port = int(instance.metadata["guest_exec_port"])
        uds_path = instance.metadata.get("guest_exec_uds_path")
        deadline = time.monotonic() + 10
        while True:
            try:
                response = self._guest_exec_client.upload_archive(
                    guest_cid,
                    port,
                    archive_path,
                    destination=destination,
                    timeout_seconds=TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS,
                    uds_path=uds_path,
                )
                return {
                    "destination": response.destination,
                    "entry_count": response.entry_count,
                    "bytes_written": response.bytes_written,
                }
            except (OSError, RuntimeError) as exc:
                # The guest agent may not be listening yet right after
                # boot; keep retrying until the deadline, then surface the
                # last error with its cause chained.
                if time.monotonic() >= deadline:
                    raise RuntimeError(
                        f"guest archive transport did not become ready: {exc}"
                    ) from exc
                time.sleep(0.2)
    # No guest transport: record host_compat mode and extract directly into
    # the host-backed workspace directory.
    instance.metadata["execution_mode"] = "host_compat"
    return _extract_seed_archive_to_host_workspace(
        archive_path,
        workspace_dir=_instance_workspace_host_dir(instance),
        destination=destination,
    )
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments and tasks."""
@ -814,9 +1133,11 @@ class VmManager:
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
source_path: str | Path | None = None,
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_environment(environment, runtime_paths=self._runtime_paths)
prepared_seed = self._prepare_workspace_seed(source_path)
now = time.time()
task_id = uuid.uuid4().hex[:12]
task_dir = self._task_dir(task_id)
@ -851,10 +1172,25 @@ class VmManager:
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
self._backend.create(instance)
if (
prepared_seed.archive_path is not None
and self._runtime_capabilities.supports_guest_exec
):
self._ensure_task_guest_seed_support(instance)
with self._lock:
self._start_instance_locked(instance)
self._require_guest_exec_or_opt_in(instance)
if self._runtime_capabilities.supports_guest_exec:
workspace_seed = prepared_seed.to_payload()
if prepared_seed.archive_path is not None:
import_summary = self._backend.import_archive(
instance,
archive_path=prepared_seed.archive_path,
destination=TASK_WORKSPACE_GUEST_PATH,
)
workspace_seed["entry_count"] = int(import_summary["entry_count"])
workspace_seed["bytes_written"] = int(import_summary["bytes_written"])
workspace_seed["destination"] = str(import_summary["destination"])
elif self._runtime_capabilities.supports_guest_exec:
self._backend.exec(
instance,
f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
@ -862,7 +1198,7 @@ class VmManager:
)
else:
instance.metadata["execution_mode"] = "host_compat"
task = TaskRecord.from_instance(instance)
task = TaskRecord.from_instance(instance, workspace_seed=workspace_seed)
self._save_task_locked(task)
return self._serialize_task(task)
except Exception:
@ -879,6 +1215,8 @@ class VmManager:
pass
shutil.rmtree(task_dir, ignore_errors=True)
raise
finally:
prepared_seed.cleanup()
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
if timeout_seconds <= 0:
@ -999,6 +1337,7 @@ class VmManager:
"tap_name": task.network.tap_name if task.network is not None else None,
"execution_mode": task.metadata.get("execution_mode", "pending"),
"workspace_path": TASK_WORKSPACE_GUEST_PATH,
"workspace_seed": _task_workspace_seed_dict(task.workspace_seed),
"command_count": task.command_count,
"last_command": task.last_command,
"metadata": task.metadata,
@ -1094,6 +1433,53 @@ class VmManager:
execution_mode = instance.metadata.get("execution_mode", "unknown")
return exec_result, execution_mode
def _prepare_workspace_seed(self, source_path: str | Path | None) -> PreparedWorkspaceSeed:
    """Stage a host-side seed source (directory or tar archive) for a new task.

    Returns an "empty" seed when no source is given. A directory is packed
    into a temporary tar archive (the caller must invoke cleanup() on the
    returned seed); a supported archive file is validated and used in place.
    Raises ValueError for missing or unsupported sources.
    """
    if source_path is None:
        return PreparedWorkspaceSeed(mode="empty", source_path=None)
    resolved = Path(source_path).expanduser().resolve()
    if not resolved.exists():
        raise ValueError(f"source_path {resolved} does not exist")
    if resolved.is_dir():
        staging_dir = Path(tempfile.mkdtemp(prefix="pyro-task-seed-"))
        staged_archive = staging_dir / "workspace-seed.tar"
        try:
            _write_directory_seed_archive(resolved, staged_archive)
            entry_count, bytes_written = _inspect_seed_archive(staged_archive)
        except Exception:
            # Never leak the staging directory when packing or validation fails.
            shutil.rmtree(staging_dir, ignore_errors=True)
            raise
        return PreparedWorkspaceSeed(
            mode="directory",
            source_path=str(resolved),
            archive_path=staged_archive,
            entry_count=entry_count,
            bytes_written=bytes_written,
            cleanup_dir=staging_dir,
        )
    if resolved.is_file() and _is_supported_seed_archive(resolved):
        entry_count, bytes_written = _inspect_seed_archive(resolved)
        return PreparedWorkspaceSeed(
            mode="tar_archive",
            source_path=str(resolved),
            archive_path=resolved,
            entry_count=entry_count,
            bytes_written=bytes_written,
        )
    raise ValueError(
        "source_path must be a directory or a .tar/.tar.gz/.tgz archive"
    )
def _ensure_task_guest_seed_support(self, instance: VmInstance) -> None:
    """Patch the per-task rootfs with the bundled guest agent before boot.

    Raises RuntimeError when the runtime bundle ships no guest agent or the
    instance metadata does not record a rootfs image path.
    """
    runtime_paths = self._runtime_paths
    if runtime_paths is None or runtime_paths.guest_agent_path is None:
        raise RuntimeError("runtime bundle does not provide a guest agent for task seeding")
    rootfs_image = instance.metadata.get("rootfs_image")
    if rootfs_image is None or rootfs_image == "":
        raise RuntimeError("task rootfs image is unavailable for guest workspace seeding")
    _patch_rootfs_guest_agent(Path(rootfs_image), runtime_paths.guest_agent_path)
def _task_dir(self, task_id: str) -> Path:
    """Return the on-disk directory that stores state for *task_id*."""
    return self._tasks_dir.joinpath(task_id)