Add seeded task workspace creation
Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces. Expose source_path on task create across the CLI, SDK, and MCP, add safe archive upload and extraction support for guest and host-compat backends, persist workspace_seed metadata, and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post-`--` command reconstruction to shlex.join() so documented sh -lc task examples preserve argument boundaries. Validation: - uv lock - UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py - UV_CACHE_DIR=.uv-cache make check - UV_CACHE_DIR=.uv-cache make dist-check - real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
This commit is contained in:
parent
58df176148
commit
aa886b346e
25 changed files with 1076 additions and 75 deletions
|
|
@ -86,6 +86,7 @@ class Pyro:
|
|||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
source_path: str | Path | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.create_task(
|
||||
environment=environment,
|
||||
|
|
@ -94,6 +95,7 @@ class Pyro:
|
|||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
allow_host_compat=allow_host_compat,
|
||||
source_path=source_path,
|
||||
)
|
||||
|
||||
def exec_task(
|
||||
|
|
@ -245,6 +247,7 @@ class Pyro:
|
|||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
source_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Create and start a persistent task workspace."""
|
||||
return self.create_task(
|
||||
|
|
@ -254,6 +257,7 @@ class Pyro:
|
|||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
allow_host_compat=allow_host_compat,
|
||||
source_path=source_path,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import argparse
|
||||
import json
|
||||
import shlex
|
||||
import sys
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
|
@ -155,6 +156,14 @@ def _print_task_summary_human(payload: dict[str, Any], *, action: str) -> None:
|
|||
print(f"Environment: {str(payload.get('environment', 'unknown'))}")
|
||||
print(f"State: {str(payload.get('state', 'unknown'))}")
|
||||
print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
|
||||
workspace_seed = payload.get("workspace_seed")
|
||||
if isinstance(workspace_seed, dict):
|
||||
mode = str(workspace_seed.get("mode", "empty"))
|
||||
source_path = workspace_seed.get("source_path")
|
||||
if isinstance(source_path, str) and source_path != "":
|
||||
print(f"Workspace seed: {mode} from {source_path}")
|
||||
else:
|
||||
print(f"Workspace seed: {mode}")
|
||||
print(f"Execution mode: {str(payload.get('execution_mode', 'pending'))}")
|
||||
print(
|
||||
f"Resources: {int(payload.get('vcpu_count', 0))} vCPU / "
|
||||
|
|
@ -446,7 +455,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
epilog=dedent(
|
||||
"""
|
||||
Examples:
|
||||
pyro task create debian:12
|
||||
pyro task create debian:12 --source-path ./repo
|
||||
pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
|
||||
pyro task logs TASK_ID
|
||||
"""
|
||||
|
|
@ -458,7 +467,13 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
"create",
|
||||
help="Create and start a persistent task workspace.",
|
||||
description="Create a task workspace that stays alive across repeated exec calls.",
|
||||
epilog="Example:\n pyro task create debian:12",
|
||||
epilog=dedent(
|
||||
"""
|
||||
Examples:
|
||||
pyro task create debian:12
|
||||
pyro task create debian:12 --source-path ./repo
|
||||
"""
|
||||
),
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
task_create_parser.add_argument(
|
||||
|
|
@ -497,6 +512,13 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
"is unavailable."
|
||||
),
|
||||
)
|
||||
task_create_parser.add_argument(
|
||||
"--source-path",
|
||||
help=(
|
||||
"Optional host directory or .tar/.tar.gz/.tgz archive to seed into `/workspace` "
|
||||
"before the task is returned."
|
||||
),
|
||||
)
|
||||
task_create_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
|
|
@ -663,7 +685,7 @@ def _require_command(command_args: list[str]) -> str:
|
|||
command_args = command_args[1:]
|
||||
if not command_args:
|
||||
raise ValueError("command is required after `--`")
|
||||
return " ".join(command_args)
|
||||
return shlex.join(command_args)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
|
|
@ -764,6 +786,7 @@ def main() -> None:
|
|||
ttl_seconds=args.ttl_seconds,
|
||||
network=args.network,
|
||||
allow_host_compat=args.allow_host_compat,
|
||||
source_path=args.source_path,
|
||||
)
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,15 @@ PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
|
|||
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
|
||||
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
|
||||
PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status")
|
||||
PUBLIC_CLI_TASK_CREATE_FLAGS = (
|
||||
"--vcpu-count",
|
||||
"--mem-mib",
|
||||
"--ttl-seconds",
|
||||
"--network",
|
||||
"--allow-host-compat",
|
||||
"--source-path",
|
||||
"--json",
|
||||
)
|
||||
PUBLIC_CLI_RUN_FLAGS = (
|
||||
"--vcpu-count",
|
||||
"--mem-mib",
|
||||
|
|
|
|||
|
|
@ -1,26 +1,31 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Minimal guest-side exec agent for pyro runtime bundles."""
|
||||
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
PORT = 5005
|
||||
BUFFER_SIZE = 65536
|
||||
WORKSPACE_ROOT = PurePosixPath("/workspace")
|
||||
|
||||
|
||||
def _read_request(conn: socket.socket) -> dict[str, Any]:
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = conn.recv(BUFFER_SIZE)
|
||||
data = conn.recv(1)
|
||||
if data == b"":
|
||||
break
|
||||
chunks.append(data)
|
||||
if b"\n" in data:
|
||||
if data == b"\n":
|
||||
break
|
||||
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
|
||||
if not isinstance(payload, dict):
|
||||
|
|
@ -28,6 +33,130 @@ def _read_request(conn: socket.socket) -> dict[str, Any]:
|
|||
return payload
|
||||
|
||||
|
||||
def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly *size* bytes from *conn*; raise on premature EOF."""
    buffer = bytearray()
    while len(buffer) < size:
        received = conn.recv(min(BUFFER_SIZE, size - len(buffer)))
        if received == b"":
            raise RuntimeError("unexpected EOF while reading archive payload")
        buffer.extend(received)
    return bytes(buffer)
|
||||
|
||||
|
||||
def _normalize_member_name(name: str) -> PurePosixPath:
|
||||
candidate = name.strip()
|
||||
if candidate == "":
|
||||
raise RuntimeError("archive member path is empty")
|
||||
member_path = PurePosixPath(candidate)
|
||||
if member_path.is_absolute():
|
||||
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
|
||||
parts = [part for part in member_path.parts if part not in {"", "."}]
|
||||
if any(part == ".." for part in parts):
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
normalized = PurePosixPath(*parts)
|
||||
if str(normalized) in {"", "."}:
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve *destination* to (normalized guest path, host-visible path).

    Relative destinations are anchored under /workspace.  Raises RuntimeError
    when the destination is empty, contains ``..`` segments, or resolves
    outside /workspace.
    """
    candidate = destination.strip()
    if candidate == "":
        raise RuntimeError("destination must not be empty")
    destination_path = PurePosixPath(candidate)
    if not destination_path.is_absolute():
        destination_path = WORKSPACE_ROOT / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # PurePosixPath never resolves "..", so "/workspace/../etc" would keep the
    # "/workspace" prefix and pass the check below while escaping /workspace on
    # the real filesystem.  Reject parent-traversal segments outright.
    if any(part == ".." for part in parts):
        raise RuntimeError("destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise RuntimeError("destination must stay inside /workspace")
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise RuntimeError("destination must stay inside /workspace")
    suffix = normalized.relative_to(WORKSPACE_ROOT)
    host_path = Path("/workspace")
    if str(suffix) not in {"", "."}:
        host_path = host_path / str(suffix)
    return normalized, host_path
|
||||
|
||||
|
||||
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
raise RuntimeError(f"symlink {member_path} has an empty target")
|
||||
target_path = PurePosixPath(target)
|
||||
if target_path.is_absolute():
|
||||
raise RuntimeError(f"symlink {member_path} escapes the workspace")
|
||||
combined = member_path.parent.joinpath(target_path)
|
||||
parts = [part for part in combined.parts if part not in {"", "."}]
|
||||
if any(part == ".." for part in parts):
|
||||
raise RuntimeError(f"symlink {member_path} escapes the workspace")
|
||||
|
||||
|
||||
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
|
||||
relative_path = target_path.relative_to(root)
|
||||
current = root
|
||||
for part in relative_path.parts[:-1]:
|
||||
current = current / part
|
||||
if current.is_symlink():
|
||||
raise RuntimeError(
|
||||
f"archive member would traverse through a symlinked path: {member_name}"
|
||||
)
|
||||
|
||||
|
||||
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Safely extract an uploaded tar archive into the guest workspace.

    Members are validated one at a time (never tarfile.extractall) so that
    absolute paths, ``..`` traversal, escaping symlinks, and hard links are all
    rejected before anything touches the filesystem.

    Returns a summary dict with destination, entry_count, and bytes_written.
    """
    _, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root / str(member_name)
            entry_count += 1
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                # Check is_symlink() first: exists()/is_dir() follow links, so a
                # pre-existing symlink to a directory would otherwise be accepted
                # as the directory itself (matches the host-compat extractor).
                if target_path.is_symlink() or (
                    target_path.exists() and not target_path.is_dir()
                ):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # Check is_symlink() without gating on exists(): a dangling
                # symlink reports exists() == False, yet open("wb") would follow
                # it and create a file outside the workspace.
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
|
||||
|
||||
|
||||
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
|
||||
started = time.monotonic()
|
||||
try:
|
||||
|
|
@ -64,9 +193,18 @@ def main() -> None:
|
|||
conn, _ = server.accept()
|
||||
with conn:
|
||||
request = _read_request(conn)
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
response = _run_command(command, timeout_seconds)
|
||||
action = str(request.get("action", "exec"))
|
||||
if action == "extract_archive":
|
||||
archive_size = int(request.get("archive_size", 0))
|
||||
if archive_size < 0:
|
||||
raise RuntimeError("archive_size must not be negative")
|
||||
destination = str(request.get("destination", "/workspace"))
|
||||
payload = _read_exact(conn, archive_size)
|
||||
response = _extract_archive(payload, destination)
|
||||
else:
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
response = _run_command(command, timeout_seconds)
|
||||
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,14 +18,14 @@
|
|||
"component_versions": {
|
||||
"base_distro": "debian-bookworm-20250210",
|
||||
"firecracker": "1.12.1",
|
||||
"guest_agent": "0.1.0-dev",
|
||||
"guest_agent": "0.2.0-dev",
|
||||
"jailer": "1.12.1",
|
||||
"kernel": "5.10.210"
|
||||
},
|
||||
"guest": {
|
||||
"agent": {
|
||||
"path": "guest/pyro_guest_agent.py",
|
||||
"sha256": "65bf8a9a57ffd7321463537e598c4b30f0a13046cbd4538f1b65bc351da5d3c0"
|
||||
"sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d"
|
||||
}
|
||||
},
|
||||
"platform": "linux-x86_64",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from typing import Any
|
|||
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
|
||||
|
||||
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.1.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.2.0"
|
||||
OCI_MANIFEST_ACCEPT = ", ".join(
|
||||
(
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from __future__ import annotations
|
|||
import json
|
||||
import socket
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
|
||||
|
|
@ -31,6 +32,13 @@ class GuestExecResponse:
|
|||
duration_ms: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GuestArchiveResponse:
    """Summary returned by the guest agent after a workspace archive import."""

    # Guest path the archive was extracted into (e.g. /workspace).
    destination: str
    # Number of archive members processed.
    entry_count: int
    # Total bytes of regular-file content written.
    bytes_written: int
|
||||
|
||||
|
||||
class VsockExecClient:
|
||||
"""Minimal JSON-over-stream client for a guest exec agent."""
|
||||
|
||||
|
|
@ -50,6 +58,63 @@ class VsockExecClient:
|
|||
"command": command,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
}
|
||||
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
|
||||
try:
|
||||
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
|
||||
payload = self._recv_json_payload(sock)
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("guest exec response must be a JSON object")
|
||||
return GuestExecResponse(
|
||||
stdout=str(payload.get("stdout", "")),
|
||||
stderr=str(payload.get("stderr", "")),
|
||||
exit_code=int(payload.get("exit_code", -1)),
|
||||
duration_ms=int(payload.get("duration_ms", 0)),
|
||||
)
|
||||
|
||||
def upload_archive(
    self,
    guest_cid: int,
    port: int,
    archive_path: Path,
    *,
    destination: str,
    timeout_seconds: int = 60,
    uds_path: str | None = None,
) -> GuestArchiveResponse:
    """Stream a tar archive to the guest agent for extraction under *destination*.

    Sends a one-line JSON header (action/destination/archive_size) followed by
    the raw archive bytes, then reads the agent's JSON summary until EOF.

    Raises RuntimeError when the agent's response is not a JSON object.
    """
    request = {
        "action": "extract_archive",
        "destination": destination,
        # The agent reads exactly this many bytes after the header newline.
        "archive_size": archive_path.stat().st_size,
    }
    sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
    try:
        sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
        # Stream the archive in fixed-size chunks to avoid loading it in memory.
        with archive_path.open("rb") as handle:
            for chunk in iter(lambda: handle.read(65536), b""):
                sock.sendall(chunk)
        payload = self._recv_json_payload(sock)
    finally:
        sock.close()

    if not isinstance(payload, dict):
        raise RuntimeError("guest archive response must be a JSON object")
    return GuestArchiveResponse(
        destination=str(payload.get("destination", destination)),
        entry_count=int(payload.get("entry_count", 0)),
        bytes_written=int(payload.get("bytes_written", 0)),
    )
|
||||
|
||||
def _connect(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
timeout_seconds: int,
|
||||
*,
|
||||
uds_path: str | None,
|
||||
) -> SocketLike:
|
||||
family = getattr(socket, "AF_VSOCK", None)
|
||||
if family is not None:
|
||||
sock = self._socket_factory(family, socket.SOCK_STREAM)
|
||||
|
|
@ -59,33 +124,15 @@ class VsockExecClient:
|
|||
connect_address = uds_path
|
||||
else:
|
||||
raise RuntimeError("vsock sockets are not supported on this host Python runtime")
|
||||
try:
|
||||
sock.settimeout(timeout_seconds)
|
||||
sock.connect(connect_address)
|
||||
if family is None:
|
||||
sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
|
||||
status = self._recv_line(sock)
|
||||
if not status.startswith("OK "):
|
||||
raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
|
||||
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
|
||||
chunks: list[bytes] = []
|
||||
while True:
|
||||
data = sock.recv(65536)
|
||||
if data == b"":
|
||||
break
|
||||
chunks.append(data)
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
payload = json.loads(b"".join(chunks).decode("utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("guest exec response must be a JSON object")
|
||||
return GuestExecResponse(
|
||||
stdout=str(payload.get("stdout", "")),
|
||||
stderr=str(payload.get("stderr", "")),
|
||||
exit_code=int(payload.get("exit_code", -1)),
|
||||
duration_ms=int(payload.get("duration_ms", 0)),
|
||||
)
|
||||
sock.settimeout(timeout_seconds)
|
||||
sock.connect(connect_address)
|
||||
if family is None:
|
||||
sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
|
||||
status = self._recv_line(sock)
|
||||
if not status.startswith("OK "):
|
||||
sock.close()
|
||||
raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
|
||||
return sock
|
||||
|
||||
@staticmethod
|
||||
def _recv_line(sock: SocketLike) -> str:
|
||||
|
|
@ -98,3 +145,13 @@ class VsockExecClient:
|
|||
if data == b"\n":
|
||||
break
|
||||
return b"".join(chunks).decode("utf-8", errors="replace")
|
||||
|
||||
@staticmethod
def _recv_json_payload(sock: SocketLike) -> Any:
    """Read from *sock* until EOF and decode the accumulated bytes as JSON."""
    chunks: list[bytes] = []
    while True:
        data = sock.recv(65536)
        if data == b"":
            # Peer closed the connection: the full response has been sent.
            break
        chunks.append(data)
    return json.loads(b"".join(chunks).decode("utf-8"))
|
||||
|
|
|
|||
|
|
@ -8,11 +8,13 @@ import shlex
|
|||
import shutil
|
||||
import signal
|
||||
import subprocess
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from pyro_mcp.runtime import (
|
||||
|
|
@ -34,11 +36,15 @@ DEFAULT_TIMEOUT_SECONDS = 30
|
|||
DEFAULT_TTL_SECONDS = 600
|
||||
DEFAULT_ALLOW_HOST_COMPAT = False
|
||||
|
||||
TASK_LAYOUT_VERSION = 1
|
||||
TASK_LAYOUT_VERSION = 2
|
||||
TASK_WORKSPACE_DIRNAME = "workspace"
|
||||
TASK_COMMANDS_DIRNAME = "commands"
|
||||
TASK_RUNTIME_DIRNAME = "runtime"
|
||||
TASK_WORKSPACE_GUEST_PATH = "/workspace"
|
||||
TASK_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
|
||||
TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60
|
||||
|
||||
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -82,6 +88,7 @@ class TaskRecord:
|
|||
network: NetworkConfig | None = None
|
||||
command_count: int = 0
|
||||
last_command: dict[str, Any] | None = None
|
||||
workspace_seed: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
@classmethod
|
||||
def from_instance(
|
||||
|
|
@ -90,6 +97,7 @@ class TaskRecord:
|
|||
*,
|
||||
command_count: int = 0,
|
||||
last_command: dict[str, Any] | None = None,
|
||||
workspace_seed: dict[str, Any] | None = None,
|
||||
) -> TaskRecord:
|
||||
return cls(
|
||||
task_id=instance.vm_id,
|
||||
|
|
@ -108,6 +116,7 @@ class TaskRecord:
|
|||
network=instance.network,
|
||||
command_count=command_count,
|
||||
last_command=last_command,
|
||||
workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()),
|
||||
)
|
||||
|
||||
def to_instance(self, *, workdir: Path) -> VmInstance:
|
||||
|
|
@ -148,6 +157,7 @@ class TaskRecord:
|
|||
"network": _serialize_network(self.network),
|
||||
"command_count": self.command_count,
|
||||
"last_command": self.last_command,
|
||||
"workspace_seed": self.workspace_seed,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
|
@ -169,9 +179,35 @@ class TaskRecord:
|
|||
network=_deserialize_network(payload.get("network")),
|
||||
command_count=int(payload.get("command_count", 0)),
|
||||
last_command=_optional_dict(payload.get("last_command")),
|
||||
workspace_seed=_task_workspace_seed_dict(payload.get("workspace_seed")),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class PreparedWorkspaceSeed:
    """Prepared host-side seed archive plus metadata."""

    # How the seed was produced: "empty", "directory", or "tar_archive".
    mode: WorkspaceSeedMode
    # Original host path the seed came from, or None for an empty seed.
    source_path: str | None
    # Tar archive staged for upload/extraction; None when there is nothing to import.
    archive_path: Path | None = None
    # Member count observed while inspecting the archive.
    entry_count: int = 0
    # Total regular-file bytes observed while inspecting the archive.
    bytes_written: int = 0
    # Temporary staging directory to remove once the seed has been consumed.
    cleanup_dir: Path | None = None

    def to_payload(self) -> dict[str, Any]:
        """Serialize seed metadata for the task record's workspace_seed field."""
        return {
            "mode": self.mode,
            "source_path": self.source_path,
            "destination": TASK_WORKSPACE_GUEST_PATH,
            "entry_count": self.entry_count,
            "bytes_written": self.bytes_written,
        }

    def cleanup(self) -> None:
        """Best-effort removal of the staging directory, if one was created."""
        if self.cleanup_dir is not None:
            shutil.rmtree(self.cleanup_dir, ignore_errors=True)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class VmExecResult:
|
||||
"""Command execution output."""
|
||||
|
|
@ -216,6 +252,32 @@ def _string_dict(value: object) -> dict[str, str]:
|
|||
return {str(key): str(item) for key, item in value.items()}
|
||||
|
||||
|
||||
def _empty_workspace_seed_payload() -> dict[str, Any]:
    """Return the metadata payload describing an unseeded /workspace."""
    return dict(
        mode="empty",
        source_path=None,
        destination=TASK_WORKSPACE_GUEST_PATH,
        entry_count=0,
        bytes_written=0,
    )
|
||||
|
||||
|
||||
def _task_workspace_seed_dict(value: object) -> dict[str, Any]:
    """Coerce persisted workspace-seed metadata into a fully-populated dict."""
    defaults = _empty_workspace_seed_payload()
    if not isinstance(value, dict):
        return defaults
    return {
        "mode": str(value.get("mode", defaults["mode"])),
        "source_path": _optional_str(value.get("source_path")),
        "destination": str(value.get("destination", defaults["destination"])),
        "entry_count": int(value.get("entry_count", defaults["entry_count"])),
        "bytes_written": int(value.get("bytes_written", defaults["bytes_written"])),
    }
|
||||
|
||||
|
||||
def _serialize_network(network: NetworkConfig | None) -> dict[str, Any] | None:
|
||||
if network is None:
|
||||
return None
|
||||
|
|
@ -300,6 +362,201 @@ def _wrap_guest_command(command: str, *, cwd: str | None = None) -> str:
|
|||
return f"mkdir -p {quoted_cwd} && cd {quoted_cwd} && {command}"
|
||||
|
||||
|
||||
def _is_supported_seed_archive(path: Path) -> bool:
|
||||
name = path.name.lower()
|
||||
return name.endswith(".tar") or name.endswith(".tar.gz") or name.endswith(".tgz")
|
||||
|
||||
|
||||
def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPath]:
    """Normalize a seed destination to an absolute guest path under /workspace.

    Returns (normalized absolute path as str, suffix relative to /workspace).
    Raises ValueError for empty destinations, ``..`` segments, or paths that
    resolve outside /workspace.
    """
    candidate = destination.strip()
    if candidate == "":
        raise ValueError("workspace destination must not be empty")
    destination_path = PurePosixPath(candidate)
    workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH)
    if not destination_path.is_absolute():
        destination_path = workspace_root / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # PurePosixPath never resolves "..", so "/workspace/../etc" would keep the
    # "/workspace" prefix and pass the check below while escaping /workspace on
    # the real filesystem.  Reject parent-traversal segments outright.
    if any(part == ".." for part in parts):
        raise ValueError("workspace destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise ValueError("workspace destination must stay inside /workspace")
    if normalized.parts[: len(workspace_root.parts)] != workspace_root.parts:
        raise ValueError("workspace destination must stay inside /workspace")
    suffix = normalized.relative_to(workspace_root)
    return str(normalized), suffix
|
||||
|
||||
|
||||
def _workspace_host_destination(workspace_dir: Path, destination: str) -> Path:
    """Map a guest workspace destination onto the host-side workspace directory."""
    _, suffix = _normalize_workspace_destination(destination)
    if str(suffix) in {"", "."}:
        # Destination is /workspace itself.
        return workspace_dir
    return workspace_dir.joinpath(*suffix.parts)
|
||||
|
||||
|
||||
def _normalize_archive_member_name(name: str) -> PurePosixPath:
|
||||
candidate = name.strip()
|
||||
if candidate == "":
|
||||
raise RuntimeError("archive member path is empty")
|
||||
member_path = PurePosixPath(candidate)
|
||||
if member_path.is_absolute():
|
||||
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
|
||||
parts = [part for part in member_path.parts if part not in {"", "."}]
|
||||
if any(part == ".." for part in parts):
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
normalized = PurePosixPath(*parts)
|
||||
if str(normalized) in {"", "."}:
|
||||
raise RuntimeError(f"unsafe archive member path: {name}")
|
||||
return normalized
|
||||
|
||||
|
||||
def _validate_archive_symlink_target(member_name: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
raise RuntimeError(f"symlink {member_name} has an empty target")
|
||||
link_path = PurePosixPath(target)
|
||||
if link_path.is_absolute():
|
||||
raise RuntimeError(f"symlink {member_name} escapes the workspace")
|
||||
combined = member_name.parent.joinpath(link_path)
|
||||
parts = [part for part in combined.parts if part not in {"", "."}]
|
||||
if any(part == ".." for part in parts):
|
||||
raise RuntimeError(f"symlink {member_name} escapes the workspace")
|
||||
|
||||
|
||||
def _inspect_seed_archive(archive_path: Path) -> tuple[int, int]:
    """Validate every member of a seed archive and return (entry_count, file_bytes).

    Raises RuntimeError for unsafe member names, escaping symlinks, hard links,
    and unsupported member types (devices, FIFOs, ...).
    """
    entries = 0
    total_bytes = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            safe_name = _normalize_archive_member_name(member.name)
            entries += 1
            if member.isdir():
                continue
            elif member.isfile():
                total_bytes += member.size
            elif member.issym():
                _validate_archive_symlink_target(safe_name, member.linkname)
            elif member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            else:
                raise RuntimeError(f"unsupported archive member type: {member.name}")
    return entries, total_bytes
|
||||
|
||||
|
||||
def _write_directory_seed_archive(source_dir: Path, archive_path: Path) -> None:
|
||||
archive_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with tarfile.open(archive_path, "w") as archive:
|
||||
for child in sorted(source_dir.iterdir(), key=lambda item: item.name):
|
||||
archive.add(child, arcname=child.name, recursive=True)
|
||||
|
||||
|
||||
def _extract_seed_archive_to_host_workspace(
    archive_path: Path,
    *,
    workspace_dir: Path,
    destination: str,
) -> dict[str, Any]:
    """Safely extract a seed tar into the host-compat workspace directory.

    Members are handled one at a time (never tarfile.extractall) so absolute
    paths, ``..`` traversal, escaping symlinks, and hard links are rejected
    before anything is written.  Returns a summary dict with destination,
    entry_count, and bytes_written.
    """
    normalized_destination, _ = _normalize_workspace_destination(destination)
    destination_root = _workspace_host_destination(workspace_dir, normalized_destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    entry_count = 0
    bytes_written = 0
    with tarfile.open(archive_path, "r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_archive_member_name(member.name)
            target_path = destination_root.joinpath(*member_name.parts)
            entry_count += 1
            # No intermediate path component may be a symlink, or a write could
            # be redirected outside the workspace.
            _ensure_no_symlink_parents(workspace_dir, target_path, member.name)
            if member.isdir():
                # is_symlink() checked explicitly: is_dir() follows links, so a
                # symlink to a directory must not be accepted as the directory.
                if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # is_symlink() checked without an exists() gate: a dangling
                # symlink reports exists() == False, yet opening it for write
                # would follow the link.
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    shutil.copyfileobj(source, handle)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_archive_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    # Replace a previous symlink rather than failing.
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": normalized_destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
|
||||
|
||||
|
||||
def _instance_workspace_host_dir(instance: VmInstance) -> Path:
|
||||
raw_value = instance.metadata.get("workspace_host_dir")
|
||||
if raw_value is None or raw_value == "":
|
||||
raise RuntimeError("task workspace host directory is unavailable")
|
||||
return Path(raw_value)
|
||||
|
||||
|
||||
def _patch_rootfs_guest_agent(rootfs_image: Path, guest_agent_path: Path) -> None:
    """Install the bundled guest agent into an ext filesystem image via debugfs.

    Writes *guest_agent_path* to TASK_GUEST_AGENT_PATH inside *rootfs_image*
    without mounting it, so no elevated privileges are needed.

    Raises RuntimeError when debugfs is missing or the write command fails.
    """
    debugfs_path = shutil.which("debugfs")
    if debugfs_path is None:
        raise RuntimeError(
            "debugfs is required to seed task workspaces on guest-backed runtimes"
        )
    with tempfile.TemporaryDirectory(prefix="pyro-guest-agent-") as temp_dir:
        # Stage a private copy so debugfs reads a stable path.
        staged_agent_path = Path(temp_dir) / "pyro_guest_agent.py"
        shutil.copy2(guest_agent_path, staged_agent_path)
        # Best-effort removal of any stale agent; debugfs `write` does not
        # overwrite existing files, so ignore failures (file may not exist).
        subprocess.run(  # noqa: S603
            [debugfs_path, "-w", "-R", f"rm {TASK_GUEST_AGENT_PATH}", str(rootfs_image)],
            text=True,
            capture_output=True,
            check=False,
        )
        proc = subprocess.run(  # noqa: S603
            [
                debugfs_path,
                "-w",
                "-R",
                f"write {staged_agent_path} {TASK_GUEST_AGENT_PATH}",
                str(rootfs_image),
            ],
            text=True,
            capture_output=True,
            check=False,
        )
        if proc.returncode != 0:
            raise RuntimeError(
                "failed to patch guest agent into task rootfs: "
                f"{proc.stderr.strip() or proc.stdout.strip()}"
            )
|
||||
|
||||
|
||||
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
|
||||
relative_path = target_path.relative_to(root)
|
||||
current = root
|
||||
for part in relative_path.parts[:-1]:
|
||||
current = current / part
|
||||
if current.is_symlink():
|
||||
raise RuntimeError(
|
||||
f"archive member would traverse through a symlinked path: {member_name}"
|
||||
)
|
||||
|
||||
|
||||
def _pid_is_running(pid: int | None) -> bool:
|
||||
if pid is None:
|
||||
return False
|
||||
|
|
@ -337,6 +594,15 @@ class VmBackend:
|
|||
def delete(self, instance: VmInstance) -> None:  # pragma: no cover
    """Tear down *instance* and release its resources; concrete backends implement this."""
    raise NotImplementedError
|
||||
|
||||
def import_archive(  # pragma: no cover
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Extract *archive_path* into *instance*'s workspace at *destination*.

    Concrete backends return a summary mapping with ``destination``,
    ``entry_count``, and ``bytes_written`` keys.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
class MockBackend(VmBackend):
|
||||
"""Host-process backend used for development and testability."""
|
||||
|
|
@ -365,6 +631,19 @@ class MockBackend(VmBackend):
|
|||
def delete(self, instance: VmInstance) -> None:
    """Remove the instance's working directory; already-missing paths are ignored."""
    shutil.rmtree(instance.workdir, ignore_errors=True)
|
||||
|
||||
def import_archive(
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Seed the host-backed workspace by extracting *archive_path* in place.

    Returns the extraction summary produced by the shared host-workspace
    extraction helper.
    """
    workspace_dir = _instance_workspace_host_dir(instance)
    return _extract_seed_archive_to_host_workspace(
        archive_path,
        workspace_dir=workspace_dir,
        destination=destination,
    )
|
||||
|
||||
|
||||
class FirecrackerBackend(VmBackend): # pragma: no cover
|
||||
"""Host-gated backend that validates Firecracker prerequisites."""
|
||||
|
|
@ -562,6 +841,46 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
|
|||
self._network_manager.cleanup(instance.network)
|
||||
shutil.rmtree(instance.workdir, ignore_errors=True)
|
||||
|
||||
def import_archive(
    self,
    instance: VmInstance,
    *,
    archive_path: Path,
    destination: str,
) -> dict[str, Any]:
    """Push *archive_path* into the task workspace, preferring the guest agent.

    When guest exec is supported, the archive is uploaded over the vsock
    transport, retrying until the guest agent becomes reachable or a 10 second
    readiness deadline expires. Otherwise the instance is marked
    ``host_compat`` and the archive is extracted on the host side.

    Raises:
        RuntimeError: when the guest transport never becomes ready.
    """
    if not self._runtime_capabilities.supports_guest_exec:
        # No guest transport: fall back to host-side workspace extraction.
        instance.metadata["execution_mode"] = "host_compat"
        return _extract_seed_archive_to_host_workspace(
            archive_path,
            workspace_dir=_instance_workspace_host_dir(instance),
            destination=destination,
        )

    guest_cid = int(instance.metadata["guest_cid"])
    port = int(instance.metadata["guest_exec_port"])
    uds_path = instance.metadata.get("guest_exec_uds_path")
    deadline = time.monotonic() + 10
    while True:
        try:
            response = self._guest_exec_client.upload_archive(
                guest_cid,
                port,
                archive_path,
                destination=destination,
                timeout_seconds=TASK_ARCHIVE_UPLOAD_TIMEOUT_SECONDS,
                uds_path=uds_path,
            )
        except (OSError, RuntimeError) as exc:
            # The guest agent may still be booting; keep retrying until the
            # readiness deadline, then surface the last transport error.
            if time.monotonic() >= deadline:
                raise RuntimeError(
                    f"guest archive transport did not become ready: {exc}"
                ) from exc
            time.sleep(0.2)
            continue
        return {
            "destination": response.destination,
            "entry_count": response.entry_count,
            "bytes_written": response.bytes_written,
        }
|
||||
|
||||
|
||||
class VmManager:
|
||||
"""In-process lifecycle manager for ephemeral VM environments and tasks."""
|
||||
|
|
@ -814,9 +1133,11 @@ class VmManager:
|
|||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
source_path: str | Path | None = None,
|
||||
) -> dict[str, Any]:
|
||||
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
|
||||
get_environment(environment, runtime_paths=self._runtime_paths)
|
||||
prepared_seed = self._prepare_workspace_seed(source_path)
|
||||
now = time.time()
|
||||
task_id = uuid.uuid4().hex[:12]
|
||||
task_dir = self._task_dir(task_id)
|
||||
|
|
@ -851,10 +1172,25 @@ class VmManager:
|
|||
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
|
||||
)
|
||||
self._backend.create(instance)
|
||||
if (
|
||||
prepared_seed.archive_path is not None
|
||||
and self._runtime_capabilities.supports_guest_exec
|
||||
):
|
||||
self._ensure_task_guest_seed_support(instance)
|
||||
with self._lock:
|
||||
self._start_instance_locked(instance)
|
||||
self._require_guest_exec_or_opt_in(instance)
|
||||
if self._runtime_capabilities.supports_guest_exec:
|
||||
workspace_seed = prepared_seed.to_payload()
|
||||
if prepared_seed.archive_path is not None:
|
||||
import_summary = self._backend.import_archive(
|
||||
instance,
|
||||
archive_path=prepared_seed.archive_path,
|
||||
destination=TASK_WORKSPACE_GUEST_PATH,
|
||||
)
|
||||
workspace_seed["entry_count"] = int(import_summary["entry_count"])
|
||||
workspace_seed["bytes_written"] = int(import_summary["bytes_written"])
|
||||
workspace_seed["destination"] = str(import_summary["destination"])
|
||||
elif self._runtime_capabilities.supports_guest_exec:
|
||||
self._backend.exec(
|
||||
instance,
|
||||
f"mkdir -p {shlex.quote(TASK_WORKSPACE_GUEST_PATH)}",
|
||||
|
|
@ -862,7 +1198,7 @@ class VmManager:
|
|||
)
|
||||
else:
|
||||
instance.metadata["execution_mode"] = "host_compat"
|
||||
task = TaskRecord.from_instance(instance)
|
||||
task = TaskRecord.from_instance(instance, workspace_seed=workspace_seed)
|
||||
self._save_task_locked(task)
|
||||
return self._serialize_task(task)
|
||||
except Exception:
|
||||
|
|
@ -879,6 +1215,8 @@ class VmManager:
|
|||
pass
|
||||
shutil.rmtree(task_dir, ignore_errors=True)
|
||||
raise
|
||||
finally:
|
||||
prepared_seed.cleanup()
|
||||
|
||||
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
|
||||
if timeout_seconds <= 0:
|
||||
|
|
@ -999,6 +1337,7 @@ class VmManager:
|
|||
"tap_name": task.network.tap_name if task.network is not None else None,
|
||||
"execution_mode": task.metadata.get("execution_mode", "pending"),
|
||||
"workspace_path": TASK_WORKSPACE_GUEST_PATH,
|
||||
"workspace_seed": _task_workspace_seed_dict(task.workspace_seed),
|
||||
"command_count": task.command_count,
|
||||
"last_command": task.last_command,
|
||||
"metadata": task.metadata,
|
||||
|
|
@ -1094,6 +1433,53 @@ class VmManager:
|
|||
execution_mode = instance.metadata.get("execution_mode", "unknown")
|
||||
return exec_result, execution_mode
|
||||
|
||||
def _prepare_workspace_seed(self, source_path: str | Path | None) -> PreparedWorkspaceSeed:
    """Normalize *source_path* into a PreparedWorkspaceSeed ready for upload.

    ``None`` yields an empty seed. A directory is packed into a tar archive
    inside a freshly created staging directory (recorded for later cleanup).
    An existing ``.tar``/``.tar.gz``/``.tgz`` file is used as-is.

    Raises:
        ValueError: when the path does not exist, or is neither a directory
            nor a supported archive file.
    """
    if source_path is None:
        return PreparedWorkspaceSeed(mode="empty", source_path=None)

    resolved = Path(source_path).expanduser().resolve()
    if not resolved.exists():
        raise ValueError(f"source_path {resolved} does not exist")

    if resolved.is_dir():
        staging_dir = Path(tempfile.mkdtemp(prefix="pyro-task-seed-"))
        archive_path = staging_dir / "workspace-seed.tar"
        try:
            _write_directory_seed_archive(resolved, archive_path)
            entry_count, bytes_written = _inspect_seed_archive(archive_path)
        except Exception:
            # Do not leak the staging directory when archiving fails.
            shutil.rmtree(staging_dir, ignore_errors=True)
            raise
        return PreparedWorkspaceSeed(
            mode="directory",
            source_path=str(resolved),
            archive_path=archive_path,
            entry_count=entry_count,
            bytes_written=bytes_written,
            cleanup_dir=staging_dir,
        )

    if not (resolved.is_file() and _is_supported_seed_archive(resolved)):
        raise ValueError(
            "source_path must be a directory or a .tar/.tar.gz/.tgz archive"
        )
    entry_count, bytes_written = _inspect_seed_archive(resolved)
    return PreparedWorkspaceSeed(
        mode="tar_archive",
        source_path=str(resolved),
        archive_path=resolved,
        entry_count=entry_count,
        bytes_written=bytes_written,
    )
|
||||
|
||||
def _ensure_task_guest_seed_support(self, instance: VmInstance) -> None:
    """Ensure the per-task rootfs carries the bundled guest agent before boot.

    Raises:
        RuntimeError: when the runtime bundle lacks a guest agent, or the
            instance metadata does not expose a rootfs image to patch.
    """
    runtime_paths = self._runtime_paths
    if runtime_paths is None or runtime_paths.guest_agent_path is None:
        raise RuntimeError("runtime bundle does not provide a guest agent for task seeding")
    rootfs_image = instance.metadata.get("rootfs_image")
    if rootfs_image in (None, ""):
        raise RuntimeError("task rootfs image is unavailable for guest workspace seeding")
    _patch_rootfs_guest_agent(Path(rootfs_image), runtime_paths.guest_agent_path)
|
||||
|
||||
def _task_dir(self, task_id: str) -> Path:
    """Return the per-task state directory under the manager's tasks root."""
    return self._tasks_dir / task_id
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue