#!/usr/bin/env python3
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""

from __future__ import annotations

import io
import json
import os
import socket
import subprocess
import sys
import tarfile
import time
from pathlib import Path, PurePosixPath
from typing import Any

PORT = 5005
BUFFER_SIZE = 65536
WORKSPACE_ROOT = PurePosixPath("/workspace")


def _read_request(conn: socket.socket) -> dict[str, Any]:
    """Read one newline-terminated JSON object from ``conn``.

    Reading stops at the first newline or at EOF, whichever comes first.

    Raises:
        RuntimeError: if the decoded JSON value is not an object.
        ValueError: if the bytes are not valid UTF-8 or not valid JSON
            (raised by ``bytes.decode`` / ``json.loads``).
    """
    chunks: list[bytes] = []
    while True:
        # One byte at a time on purpose: for "extract_archive" requests the raw
        # archive payload follows the newline on the same connection, so the
        # header read must not consume anything past it.
        data = conn.recv(1)
        if data == b"":
            break
        chunks.append(data)
        if data == b"\n":
            break
    payload = json.loads(b"".join(chunks).decode("utf-8").strip())
    if not isinstance(payload, dict):
        raise RuntimeError("request must be a JSON object")
    return payload


def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly ``size`` bytes from ``conn``.

    Raises:
        RuntimeError: if the peer closes the connection before ``size`` bytes
            have been received.
    """
    remaining = size
    chunks: list[bytes] = []
    while remaining > 0:
        data = conn.recv(min(BUFFER_SIZE, remaining))
        if data == b"":
            raise RuntimeError("unexpected EOF while reading archive payload")
        chunks.append(data)
        remaining -= len(data)
    return b"".join(chunks)


def _normalize_member_name(name: str) -> PurePosixPath:
    """Return ``name`` as a relative path with empty and ``.`` parts removed.

    Raises:
        RuntimeError: if the name is empty, absolute, contains a ``..``
            component, or normalizes to nothing.
    """
    candidate = name.strip()
    if candidate == "":
        raise RuntimeError("archive member path is empty")
    member_path = PurePosixPath(candidate)
    if member_path.is_absolute():
        raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
    parts = [part for part in member_path.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise RuntimeError(f"unsafe archive member path: {name}")
    normalized = PurePosixPath(*parts)
    if str(normalized) in {"", "."}:
        raise RuntimeError(f"unsafe archive member path: {name}")
    return normalized


def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve ``destination`` to a location under ``/workspace``.

    Relative destinations are interpreted relative to ``WORKSPACE_ROOT``.

    Returns:
        A ``(normalized, host_path)`` pair: the normalized absolute POSIX path
        and the matching ``Path`` rooted at ``/workspace``.

    Raises:
        RuntimeError: if the destination is empty, contains a ``..``
            component, or does not lie under ``/workspace``.
    """
    candidate = destination.strip()
    if candidate == "":
        raise RuntimeError("destination must not be empty")
    destination_path = PurePosixPath(candidate)
    if not destination_path.is_absolute():
        destination_path = WORKSPACE_ROOT / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # PurePosixPath never collapses "..", so "/workspace/../etc" would pass the
    # prefix check below and still point outside the workspace on disk. Reject
    # every ".." component, as is done for member names and symlink targets.
    if ".." in parts:
        raise RuntimeError("destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise RuntimeError("destination must stay inside /workspace")
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise RuntimeError("destination must stay inside /workspace")
    suffix = normalized.relative_to(WORKSPACE_ROOT)
    host_path = Path("/workspace")
    if str(suffix) not in {"", "."}:
        host_path = host_path / str(suffix)
    return normalized, host_path


def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
    """Reject symlink targets that are empty, absolute, or contain ``..``.

    The check is purely lexical: the target is joined onto the parent of
    ``member_path`` and any ``..`` component in the result is refused, even
    when the link would still resolve inside the workspace.

    Raises:
        RuntimeError: if the target fails any of the checks above.
    """
    target = link_target.strip()
    if target == "":
        raise RuntimeError(f"symlink {member_path} has an empty target")
    target_path = PurePosixPath(target)
    if target_path.is_absolute():
        raise RuntimeError(f"symlink {member_path} escapes the workspace")
    combined = member_path.parent.joinpath(target_path)
    parts = [part for part in combined.parts if part not in {"", "."}]
    if any(part == ".." for part in parts):
        raise RuntimeError(f"symlink {member_path} escapes the workspace")


def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
    """Refuse to extract through a symlinked directory below ``root``.

    Every path component between ``root`` and the parent of ``target_path`` is
    checked; ``root`` itself and the final component are not.

    Raises:
        RuntimeError: if any checked component is a symlink.
        ValueError: if ``target_path`` is not located under ``root`` (raised
            by ``Path.relative_to``).
    """
    relative_path = target_path.relative_to(root)
    current = root
    for part in relative_path.parts[:-1]:
        current = current / part
        if current.is_symlink():
            raise RuntimeError(
                f"archive member would traverse through a symlinked path: {member_name}"
            )


def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Extract the tar archive in ``payload`` into ``destination``.

    Only directories, regular files and symlinks are accepted. The destination
    directory is created if it does not exist.

    Returns:
        A dict with ``destination`` (as passed in), ``entry_count`` (members
        processed) and ``bytes_written`` (sum of the sizes of the regular-file
        members).

    Raises:
        RuntimeError: for an invalid destination, an unsafe member path or
            symlink target, a conflict with an existing path, a hard link, or
            any other unsupported member type.
    """
    _, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root / str(member_name)
            entry_count += 1
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                if target_path.exists() and not target_path.is_dir():
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # Test is_symlink() directly instead of gating on exists():
                # exists() follows links, so a dangling symlink would look
                # "absent" and open("wb") would then create a file wherever
                # the link points.
                if target_path.is_symlink() or target_path.is_dir():
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                # An existing symlink at the target is replaced, not followed.
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": destination,
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }


def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
    """Run ``command`` through ``/bin/sh -lc`` and capture its output.

    Returns:
        A dict with ``stdout``, ``stderr``, ``exit_code`` and ``duration_ms``.
        If the command exceeds ``timeout_seconds``, ``stdout`` is empty,
        ``stderr`` holds a timeout message and ``exit_code`` is 124.
    """
    started = time.monotonic()
    try:
        proc = subprocess.run(
            ["/bin/sh", "-lc", command],
            text=True,
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
        return {
            "stdout": proc.stdout,
            "stderr": proc.stderr,
            "exit_code": proc.returncode,
            "duration_ms": int((time.monotonic() - started) * 1000),
        }
    except subprocess.TimeoutExpired:
        return {
            "stdout": "",
            "stderr": f"command timed out after {timeout_seconds}s",
            "exit_code": 124,
            "duration_ms": int((time.monotonic() - started) * 1000),
        }


def _handle_connection(conn: socket.socket) -> None:
    """Serve one request on ``conn`` and send back one JSON response line.

    ``action == "extract_archive"`` reads ``archive_size`` bytes of payload
    and extracts them into ``destination``; any other action (default
    ``"exec"``) runs ``command`` with ``timeout_seconds``.
    """
    request = _read_request(conn)
    action = str(request.get("action", "exec"))
    if action == "extract_archive":
        archive_size = int(request.get("archive_size", 0))
        if archive_size < 0:
            raise RuntimeError("archive_size must not be negative")
        destination = str(request.get("destination", "/workspace"))
        payload = _read_exact(conn, archive_size)
        response = _extract_archive(payload, destination)
    else:
        command = str(request.get("command", ""))
        timeout_seconds = int(request.get("timeout_seconds", 30))
        response = _run_command(command, timeout_seconds)
    conn.sendall((json.dumps(response) + "\n").encode("utf-8"))


def main() -> None:
    """Listen on the vsock port and serve requests one connection at a time.

    A failure while serving a connection is logged to stderr and the
    connection is closed without a response; the agent keeps accepting new
    connections.

    Raises:
        SystemExit: if the ``socket`` module has no ``AF_VSOCK`` support.
    """
    family = getattr(socket, "AF_VSOCK", None)
    if family is None:
        raise SystemExit("AF_VSOCK is unavailable")
    with socket.socket(family, socket.SOCK_STREAM) as server:
        server.bind((socket.VMADDR_CID_ANY, PORT))
        server.listen(1)
        while True:
            conn, _ = server.accept()
            with conn:
                try:
                    _handle_connection(conn)
                except Exception as exc:  # top-level boundary: one bad request must not kill the agent
                    print(f"request failed: {exc!r}", file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()