pyro-mcp/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py
Thales Maciel aa886b346e Add seeded task workspace creation
Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces.

Expose source_path on task create across the CLI, SDK, and MCP, add safe archive upload and extraction support for guest and host-compat backends, persist workspace_seed metadata, and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post-`--` command reconstruction to shlex.join() so documented sh -lc task examples preserve argument boundaries.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
2026-03-11 21:45:38 -03:00

212 lines
8.2 KiB
Python

#!/usr/bin/env python3
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
from __future__ import annotations
import io
import json
import os
import socket
import subprocess
import tarfile
import time
from pathlib import Path, PurePosixPath
from typing import Any
# vsock port the agent binds to and accepts requests on.
PORT = 5005
# Upper bound, in bytes, for a single socket recv() or archive read() call.
BUFFER_SIZE = 65536
# Guest directory that every archive destination must resolve inside.
WORKSPACE_ROOT = PurePosixPath("/workspace")
def _read_request(conn: socket.socket) -> dict[str, Any]:
    """Read one newline-terminated JSON object from ``conn``.

    Bytes are pulled one at a time so that nothing past the terminating
    newline is consumed; any data that follows the request line stays unread
    on the socket. Reading also stops if the peer closes the connection
    before a newline arrives.

    Args:
        conn: Connected stream socket to read from.

    Returns:
        The decoded request object.

    Raises:
        RuntimeError: If the decoded JSON value is not an object.
    """
    line = bytearray()
    while True:
        byte = conn.recv(1)
        if not byte:
            # Peer closed the connection; parse whatever was received.
            break
        line += byte
        if byte == b"\n":
            break
    request = json.loads(line.decode("utf-8").strip())
    if isinstance(request, dict):
        return request
    raise RuntimeError("request must be a JSON object")
def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``conn``.

    Data is requested in pieces of at most ``BUFFER_SIZE`` bytes until the
    requested amount has been collected. A ``size`` of zero or less returns
    an empty bytes object without touching the socket.

    Args:
        conn: Connected stream socket to read from.
        size: Number of bytes to read.

    Returns:
        The bytes received, ``size`` bytes long when ``size`` is positive.

    Raises:
        RuntimeError: If the peer closes the connection before ``size`` bytes
            have arrived.
    """
    received = bytearray()
    while len(received) < size:
        piece = conn.recv(min(BUFFER_SIZE, size - len(received)))
        if not piece:
            raise RuntimeError("unexpected EOF while reading archive payload")
        received += piece
    return bytes(received)
def _normalize_member_name(name: str) -> PurePosixPath:
    """Return a cleaned, relative POSIX path for an archive member name.

    Surrounding whitespace is stripped and empty or ``.`` components are
    dropped. Names that are empty, absolute, contain a ``..`` component, or
    reduce to nothing after cleaning are rejected.

    Args:
        name: Member name as stored in the archive.

    Returns:
        The normalized relative path.

    Raises:
        RuntimeError: If the name is empty, absolute, or otherwise unsafe.
    """
    stripped = name.strip()
    if not stripped:
        raise RuntimeError("archive member path is empty")
    raw_path = PurePosixPath(stripped)
    if raw_path.is_absolute():
        raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
    kept = [segment for segment in raw_path.parts if segment not in ("", ".")]
    cleaned = PurePosixPath(*kept)
    # Both a parent reference and a path that collapses to "." are unsafe.
    if ".." in kept or str(cleaned) in ("", "."):
        raise RuntimeError(f"unsafe archive member path: {name}")
    return cleaned
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve an extraction destination to a path inside ``/workspace``.

    Relative destinations are interpreted relative to ``WORKSPACE_ROOT``.
    Empty and ``.`` components are dropped. Any ``..`` component is rejected
    outright: a destination such as ``/workspace/../etc`` would otherwise
    pass the prefix check below while pointing outside the workspace.

    Args:
        destination: Absolute or workspace-relative destination directory.

    Returns:
        A ``(normalized, host_path)`` pair: the normalized absolute POSIX
        path and the same location as a concrete filesystem ``Path``.

    Raises:
        RuntimeError: If the destination is empty or does not stay inside
            ``/workspace``.
    """
    candidate = destination.strip()
    if candidate == "":
        raise RuntimeError("destination must not be empty")
    destination_path = PurePosixPath(candidate)
    if not destination_path.is_absolute():
        destination_path = WORKSPACE_ROOT / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    if ".." in parts:
        # The prefix comparison below is purely lexical, so parent references
        # must be refused before it can be trusted.
        raise RuntimeError("destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    root_parts = WORKSPACE_ROOT.parts
    if normalized.parts[: len(root_parts)] != root_parts:
        raise RuntimeError("destination must stay inside /workspace")
    return normalized, Path(str(normalized))
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
    """Reject symlink targets that are empty, absolute, or use ``..``.

    The target is joined onto the directory containing the link, and the
    result is refused if any component is ``..``.

    Args:
        member_path: Normalized archive path of the symlink itself.
        link_target: Target string stored in the archive for the symlink.

    Raises:
        RuntimeError: If the target is empty, absolute, or the joined path
            contains a ``..`` component.
    """
    stripped = link_target.strip()
    if not stripped:
        raise RuntimeError(f"symlink {member_path} has an empty target")
    link_path = PurePosixPath(stripped)
    joined_parts = member_path.parent.joinpath(link_path).parts
    if link_path.is_absolute() or ".." in joined_parts:
        raise RuntimeError(f"symlink {member_path} escapes the workspace")
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
    """Refuse targets whose intermediate directories include a symlink.

    Every path strictly between ``root`` and ``target_path`` is checked;
    ``root`` itself and the final component of ``target_path`` are not.

    Args:
        root: Directory the archive is being extracted into.
        target_path: Full path the member would be written to, under ``root``.
        member_name: Original member name, used only in the error message.

    Raises:
        RuntimeError: If any intermediate path is a symlink.
        ValueError: If ``target_path`` is not located under ``root``.
    """
    segments = target_path.relative_to(root).parts
    ancestors = (root.joinpath(*segments[:depth]) for depth in range(1, len(segments)))
    if any(ancestor.is_symlink() for ancestor in ancestors):
        raise RuntimeError(
            f"archive member would traverse through a symlinked path: {member_name}"
        )
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Unpack an in-memory tar archive into the resolved destination directory.

    The destination is resolved with ``_normalize_destination`` and created
    if missing. Each member name is normalized, checked for symlinked parent
    directories, and then handled by type: directories are created, regular
    files are written in ``BUFFER_SIZE`` chunks, and symlinks are created
    after their target passes ``_validate_symlink_target``. Hard links and
    every other member type are rejected.

    Existing paths are tested with ``is_symlink()`` as well as ``exists()``
    because ``exists()`` follows links and reports ``False`` for a dangling
    symlink; relying on it alone would let a regular file be written through
    such a link.

    Args:
        payload: Raw bytes of the tar archive (compression is auto-detected).
        destination: Absolute or workspace-relative destination directory.

    Returns:
        A dict with the ``destination`` as given, the ``entry_count`` of
        members processed, and ``bytes_written`` (sum of regular-file sizes).

    Raises:
        RuntimeError: On unsafe member names or link targets, conflicts with
            existing paths, unreadable members, hard links, or unsupported
            member types.
    """
    _, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root / str(member_name)
            entry_count += 1
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                # A dangling symlink occupies the path even though exists() is False.
                occupied = target_path.exists() or target_path.is_symlink()
                if occupied and not target_path.is_dir():
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # Check is_symlink() unconditionally so that a broken link is
                # rejected instead of being opened and written through.
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with source, target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    # Replace an existing link rather than following it.
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
    """Run ``command`` through ``/bin/sh -lc`` and capture its result.

    Output is decoded as text with undecodable bytes replaced, so a command
    that emits binary or mis-encoded data still yields a result instead of
    raising ``UnicodeDecodeError``.

    Args:
        command: Shell command line to execute.
        timeout_seconds: Maximum run time before the command is abandoned.

    Returns:
        A dict with ``stdout``, ``stderr``, ``exit_code`` and ``duration_ms``.
        On timeout, ``stdout`` is empty, ``stderr`` carries a timeout message
        and ``exit_code`` is 124.
    """
    started = time.monotonic()
    try:
        proc = subprocess.run(
            ["/bin/sh", "-lc", command],
            text=True,
            errors="replace",
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
    except subprocess.TimeoutExpired:
        return {
            "stdout": "",
            "stderr": f"command timed out after {timeout_seconds}s",
            "exit_code": 124,
            "duration_ms": int((time.monotonic() - started) * 1000),
        }
    return {
        "stdout": proc.stdout,
        "stderr": proc.stderr,
        "exit_code": proc.returncode,
        "duration_ms": int((time.monotonic() - started) * 1000),
    }
def main() -> None:
    """Serve exec and archive-extraction requests over vsock forever.

    Each accepted connection carries one request: a JSON object on a single
    line, optionally followed by raw archive bytes. The ``action`` field
    selects the handler and defaults to ``"exec"``:

    * ``"extract_archive"`` reads ``archive_size`` bytes from the connection
      and unpacks them into ``destination`` (default ``/workspace``).
    * anything else runs ``command`` with ``timeout_seconds`` (default 30).

    The handler's result is sent back as one JSON line. If handling a
    connection fails, the traceback is written to stderr and the connection
    is closed without a response; the agent keeps accepting new connections
    rather than exiting on a single bad request.

    Raises:
        SystemExit: If the ``socket`` module has no ``AF_VSOCK`` support.
    """
    import traceback

    family = getattr(socket, "AF_VSOCK", None)
    if family is None:
        raise SystemExit("AF_VSOCK is unavailable")
    with socket.socket(family, socket.SOCK_STREAM) as server:
        server.bind((socket.VMADDR_CID_ANY, PORT))
        server.listen(1)
        while True:
            conn, _ = server.accept()
            with conn:
                try:
                    request = _read_request(conn)
                    action = str(request.get("action", "exec"))
                    if action == "extract_archive":
                        archive_size = int(request.get("archive_size", 0))
                        if archive_size < 0:
                            raise RuntimeError("archive_size must not be negative")
                        destination = str(request.get("destination", "/workspace"))
                        payload = _read_exact(conn, archive_size)
                        response = _extract_archive(payload, destination)
                    else:
                        command = str(request.get("command", ""))
                        timeout_seconds = int(request.get("timeout_seconds", 30))
                        response = _run_command(command, timeout_seconds)
                    conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
                except Exception:
                    # One malformed request, rejected archive, or dropped peer
                    # must not terminate the agent; log it and keep serving.
                    traceback.print_exc()


if __name__ == "__main__":
    main()