Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces. Expose source_path on task create across the CLI, SDK, and MCP; add safe archive upload and extraction support for guest and host-compat backends; persist workspace_seed metadata; and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post-`--` command reconstruction to shlex.join() so documented `sh -lc` task examples preserve argument boundaries.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
212 lines
8.2 KiB
Python
#!/usr/bin/env python3
|
|
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import io
|
|
import json
|
|
import os
|
|
import socket
|
|
import subprocess
|
|
import tarfile
|
|
import time
|
|
from pathlib import Path, PurePosixPath
|
|
from typing import Any
|
|
|
|
# vsock port the agent listens on for exec and extract requests.
PORT = 5005
# Chunk size for socket reads and archive file streaming.
BUFFER_SIZE = 65536
# All workspace imports are rooted under this guest-side directory.
WORKSPACE_ROOT = PurePosixPath("/workspace")
|
|
|
|
|
|
def _read_request(conn: socket.socket) -> dict[str, Any]:
|
|
chunks: list[bytes] = []
|
|
while True:
|
|
data = conn.recv(1)
|
|
if data == b"":
|
|
break
|
|
chunks.append(data)
|
|
if data == b"\n":
|
|
break
|
|
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
|
|
if not isinstance(payload, dict):
|
|
raise RuntimeError("request must be a JSON object")
|
|
return payload
|
|
|
|
|
|
def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Receive exactly *size* bytes from *conn*.

    Raises RuntimeError if the peer closes the connection before the full
    payload has arrived.
    """
    buffer = bytearray()
    while len(buffer) < size:
        data = conn.recv(min(BUFFER_SIZE, size - len(buffer)))
        if data == b"":
            raise RuntimeError("unexpected EOF while reading archive payload")
        buffer.extend(data)
    return bytes(buffer)
|
|
|
|
|
|
def _normalize_member_name(name: str) -> PurePosixPath:
|
|
candidate = name.strip()
|
|
if candidate == "":
|
|
raise RuntimeError("archive member path is empty")
|
|
member_path = PurePosixPath(candidate)
|
|
if member_path.is_absolute():
|
|
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
|
|
parts = [part for part in member_path.parts if part not in {"", "."}]
|
|
if any(part == ".." for part in parts):
|
|
raise RuntimeError(f"unsafe archive member path: {name}")
|
|
normalized = PurePosixPath(*parts)
|
|
if str(normalized) in {"", "."}:
|
|
raise RuntimeError(f"unsafe archive member path: {name}")
|
|
return normalized
|
|
|
|
|
|
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve *destination* to a guest path confined under /workspace.

    Relative destinations are interpreted against /workspace. Returns the
    normalized guest path plus the concrete filesystem Path to write into.

    Raises RuntimeError when the destination is empty, contains ".."
    components, or does not stay inside /workspace.
    """
    candidate = destination.strip()
    if candidate == "":
        raise RuntimeError("destination must not be empty")
    destination_path = PurePosixPath(candidate)
    if not destination_path.is_absolute():
        destination_path = WORKSPACE_ROOT / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # Reject parent traversal outright: a path such as "/workspace/../etc"
    # keeps its ".." in parts, so it would pass the /workspace prefix check
    # below while the derived host path escapes the workspace on disk.
    if any(part == ".." for part in parts):
        raise RuntimeError("destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise RuntimeError("destination must stay inside /workspace")
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise RuntimeError("destination must stay inside /workspace")
    suffix = normalized.relative_to(WORKSPACE_ROOT)
    host_path = Path("/workspace")
    if str(suffix) not in {"", "."}:
        host_path = host_path / str(suffix)
    return normalized, host_path
|
|
|
|
|
|
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
|
target = link_target.strip()
|
|
if target == "":
|
|
raise RuntimeError(f"symlink {member_path} has an empty target")
|
|
target_path = PurePosixPath(target)
|
|
if target_path.is_absolute():
|
|
raise RuntimeError(f"symlink {member_path} escapes the workspace")
|
|
combined = member_path.parent.joinpath(target_path)
|
|
parts = [part for part in combined.parts if part not in {"", "."}]
|
|
if any(part == ".." for part in parts):
|
|
raise RuntimeError(f"symlink {member_path} escapes the workspace")
|
|
|
|
|
|
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
|
|
relative_path = target_path.relative_to(root)
|
|
current = root
|
|
for part in relative_path.parts[:-1]:
|
|
current = current / part
|
|
if current.is_symlink():
|
|
raise RuntimeError(
|
|
f"archive member would traverse through a symlinked path: {member_name}"
|
|
)
|
|
|
|
|
|
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Extract a tar archive *payload* into *destination* under /workspace.

    Only regular files, directories, and in-tree relative symlinks are
    accepted; hard links and other member types (devices, FIFOs) are
    rejected. Returns a summary dict with the destination, entry count,
    and bytes written.

    Raises RuntimeError for unsafe member paths, symlink escapes, or
    conflicts with existing filesystem entries.
    """
    _, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root / str(member_name)
            entry_count += 1
            # Never write through a directory component that is a symlink.
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                if target_path.exists() and not target_path.is_dir():
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # is_symlink() is checked independently of exists() because
                # exists() follows links: a dangling symlink would otherwise
                # pass this guard and open("wb") would write through it,
                # creating a file at the link's target.
                if target_path.is_symlink() or (target_path.exists() and target_path.is_dir()):
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                # Stream in chunks so large files never load fully into memory.
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
|
|
|
|
|
|
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
|
|
started = time.monotonic()
|
|
try:
|
|
proc = subprocess.run(
|
|
["/bin/sh", "-lc", command],
|
|
text=True,
|
|
capture_output=True,
|
|
timeout=timeout_seconds,
|
|
check=False,
|
|
)
|
|
return {
|
|
"stdout": proc.stdout,
|
|
"stderr": proc.stderr,
|
|
"exit_code": proc.returncode,
|
|
"duration_ms": int((time.monotonic() - started) * 1000),
|
|
}
|
|
except subprocess.TimeoutExpired:
|
|
return {
|
|
"stdout": "",
|
|
"stderr": f"command timed out after {timeout_seconds}s",
|
|
"exit_code": 124,
|
|
"duration_ms": int((time.monotonic() - started) * 1000),
|
|
}
|
|
|
|
|
|
def _handle_connection(conn: socket.socket) -> dict[str, Any]:
    """Parse one request from *conn* and dispatch it to the matching handler."""
    request = _read_request(conn)
    action = str(request.get("action", "exec"))
    if action == "extract_archive":
        archive_size = int(request.get("archive_size", 0))
        if archive_size < 0:
            raise RuntimeError("archive_size must not be negative")
        destination = str(request.get("destination", "/workspace"))
        payload = _read_exact(conn, archive_size)
        return _extract_archive(payload, destination)
    # Any other action falls through to exec, matching the original protocol.
    command = str(request.get("command", ""))
    timeout_seconds = int(request.get("timeout_seconds", 30))
    return _run_command(command, timeout_seconds)


def main() -> None:
    """Serve exec and extract_archive requests over a vsock socket forever.

    Each connection carries one newline-terminated JSON request header,
    optionally followed by a raw archive payload, and receives one JSON
    response line back. Per-connection failures (malformed JSON, unsafe
    archives) are reported to the client instead of crashing the agent.

    Raises SystemExit when the platform has no AF_VSOCK support.
    """
    family = getattr(socket, "AF_VSOCK", None)
    if family is None:
        raise SystemExit("AF_VSOCK is unavailable")
    with socket.socket(family, socket.SOCK_STREAM) as server:
        server.bind((socket.VMADDR_CID_ANY, PORT))
        server.listen(1)
        while True:
            conn, _ = server.accept()
            with conn:
                try:
                    response = _handle_connection(conn)
                except Exception as exc:  # noqa: BLE001 - agent must outlive bad requests
                    # Previously any handler error killed the whole loop;
                    # report it to the client and keep serving instead.
                    response = {"error": f"{type(exc).__name__}: {exc}", "exit_code": 1}
                try:
                    conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
                except OSError:
                    # Client disconnected mid-response; keep serving others.
                    continue
|
|
|
|
|
|
# Script entry point: start the vsock agent loop when executed directly.
if __name__ == "__main__":
    main()
|