pyro-mcp/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py
Thales Maciel 84a7e18d4d Add workspace export and baseline diff
Complete the 2.6.0 workspace milestone by adding explicit host-out export and immutable-baseline diff across the CLI, Python SDK, and MCP server.

Capture a baseline archive at workspace creation, export live /workspace paths through the guest agent, and compute structured whole-workspace diffs on the host without affecting command logs or shell state. The docs, roadmap, bundled guest agent, and workspace example now reflect the new create -> sync -> diff -> export workflow.

Validation: uv lock, UV_CACHE_DIR=.uv-cache make check, UV_CACHE_DIR=.uv-cache make dist-check, and a real guest-backed Firecracker smoke test covering workspace create, sync push, diff, export, and delete.
2026-03-12 03:15:45 -03:00

612 lines
22 KiB
Python

#!/usr/bin/env python3
"""Guest-side exec, workspace import, and interactive shell agent."""
from __future__ import annotations
import codecs
import fcntl
import io
import json
import os
import pty
import signal
import socket
import struct
import subprocess
import tarfile
import tempfile
import termios
import threading
import time
from pathlib import Path, PurePosixPath
from typing import Any
# vsock port the guest agent listens on; the host connects via the VM's CID.
PORT = 5005
# Chunk size in bytes for socket and file read/write loops.
BUFFER_SIZE = 65536
# All workspace operations are confined beneath this guest path.
WORKSPACE_ROOT = PurePosixPath("/workspace")
# Directory where per-shell metadata and output logs are persisted.
SHELL_ROOT = Path("/run/pyro-shells")
# Signals the host is allowed to deliver to an interactive shell session.
SHELL_SIGNAL_MAP = {
    "HUP": signal.SIGHUP,
    "INT": signal.SIGINT,
    "TERM": signal.SIGTERM,
    "KILL": signal.SIGKILL,
}
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)
# Registry of live interactive shell sessions, keyed by shell_id,
# guarded by _SHELLS_LOCK.
_SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
def _read_request(conn: socket.socket) -> dict[str, Any]:
chunks: list[bytes] = []
while True:
data = conn.recv(1)
if data == b"":
break
chunks.append(data)
if data == b"\n":
break
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
if not isinstance(payload, dict):
raise RuntimeError("request must be a JSON object")
return payload
def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly *size* bytes from *conn*.

    Raises RuntimeError if the peer closes the connection before the full
    payload has arrived.
    """
    buffer = bytearray()
    while len(buffer) < size:
        chunk = conn.recv(min(BUFFER_SIZE, size - len(buffer)))
        if not chunk:
            raise RuntimeError("unexpected EOF while reading archive payload")
        buffer.extend(chunk)
    return bytes(buffer)
def _normalize_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve *destination* to a normalized absolute path under /workspace.

    Relative inputs are anchored at WORKSPACE_ROOT. Returns the normalized
    pure path together with the concrete filesystem Path to operate on.

    Raises RuntimeError when the path is empty or would escape /workspace.
    """
    candidate = destination.strip()
    if candidate == "":
        raise RuntimeError("destination must not be empty")
    destination_path = PurePosixPath(candidate)
    if not destination_path.is_absolute():
        destination_path = WORKSPACE_ROOT / destination_path
    parts = [part for part in destination_path.parts if part not in {"", "."}]
    # FIX: reject ".." segments outright. Previously "/workspace/../etc"
    # passed the prefix check below (parts are compared textually, not
    # resolved) and produced a host path outside the workspace.
    if ".." in parts:
        raise RuntimeError("destination must stay inside /workspace")
    normalized = PurePosixPath("/") / PurePosixPath(*parts)
    if normalized == PurePosixPath("/"):
        raise RuntimeError("destination must stay inside /workspace")
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise RuntimeError("destination must stay inside /workspace")
    suffix = normalized.relative_to(WORKSPACE_ROOT)
    # Anchor at WORKSPACE_ROOT (not a duplicated literal) so the mount point
    # cannot drift from the validation logic above.
    host_path = Path(str(WORKSPACE_ROOT))
    if str(suffix) not in {"", "."}:
        host_path = host_path.joinpath(*suffix.parts)
    return normalized, host_path
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
    """Normalize a shell working directory to (text form, concrete Path)."""
    normalized_path, concrete_path = _normalize_destination(cwd)
    return str(normalized_path), concrete_path
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_path} has an empty target")
target_path = PurePosixPath(target)
if target_path.is_absolute():
raise RuntimeError(f"symlink {member_path} escapes the workspace")
combined = member_path.parent.joinpath(target_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_path} escapes the workspace")
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Safely extract a tar archive payload beneath /workspace.

    Only directories, regular files, and workspace-relative symlinks are
    accepted; hard links and special members are rejected, as is any member
    whose path or link target would land outside the destination tree.

    Returns a summary dict with the normalized destination, the number of
    members processed, and the total regular-file bytes written.
    """
    normalized_destination, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            member_name = _normalize_member_name(member.name)
            target_path = destination_root.joinpath(*member_name.parts)
            entry_count += 1
            # Never extract through a symlinked intermediate directory.
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                # FIX: check is_symlink() first, mirroring the directory
                # branch. A *dangling* symlink reports exists() == False, so
                # the old `exists() and (... is_symlink())` check let it
                # through and open("wb") would follow the link and create a
                # file at its target.
                if target_path.is_symlink() or (target_path.exists() and target_path.is_dir()):
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                with target_path.open("wb") as handle:
                    while chunk := source.read(BUFFER_SIZE):
                        handle.write(chunk)
                bytes_written += member.size
                continue
            if member.issym():
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": str(normalized_destination),
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
def _inspect_archive(archive_path: Path) -> tuple[int, int]:
entry_count = 0
bytes_written = 0
with tarfile.open(archive_path, "r:*") as archive:
for member in archive.getmembers():
entry_count += 1
if member.isfile():
bytes_written += member.size
return entry_count, bytes_written
def _prepare_export_archive(path: str) -> dict[str, Any]:
    """Package a workspace path into a temporary tar archive for export.

    The caller owns the returned ``archive_path`` and must unlink it once
    streamed; on failure the temporary file is removed here before the
    exception propagates.
    """
    normalized_path, source_path = _normalize_destination(path)
    if not source_path.exists() and not source_path.is_symlink():
        raise RuntimeError(f"workspace path does not exist: {normalized_path}")
    # Symlink first: is_file()/is_dir() would follow the link.
    if source_path.is_symlink():
        artifact_type = "symlink"
    elif source_path.is_file():
        artifact_type = "file"
    elif source_path.is_dir():
        artifact_type = "directory"
    else:
        raise RuntimeError(f"unsupported workspace path type: {normalized_path}")
    with tempfile.NamedTemporaryFile(prefix="pyro-export-", suffix=".tar", delete=False) as tmp:
        staging_path = Path(tmp.name)
    try:
        with tarfile.open(staging_path, "w") as bundle:
            bundle.dereference = False  # store symlinks as symlinks
            if artifact_type == "directory":
                # Deterministic member order: children sorted by name.
                for entry in sorted(source_path.iterdir(), key=lambda item: item.name):
                    bundle.add(entry, arcname=entry.name, recursive=True)
            else:
                bundle.add(source_path, arcname=source_path.name, recursive=False)
        member_total, file_bytes = _inspect_archive(staging_path)
        archive_size = staging_path.stat().st_size
    except Exception:
        staging_path.unlink(missing_ok=True)
        raise
    return {
        "workspace_path": str(normalized_path),
        "artifact_type": artifact_type,
        "archive_path": staging_path,
        "archive_size": archive_size,
        "entry_count": member_total,
        "bytes_written": file_bytes,
    }
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
started = time.monotonic()
try:
proc = subprocess.run(
["/bin/sh", "-lc", command],
text=True,
capture_output=True,
timeout=timeout_seconds,
check=False,
)
return {
"stdout": proc.stdout,
"stderr": proc.stderr,
"exit_code": proc.returncode,
"duration_ms": int((time.monotonic() - started) * 1000),
}
except subprocess.TimeoutExpired:
return {
"stdout": "",
"stderr": f"command timed out after {timeout_seconds}s",
"exit_code": 124,
"duration_ms": int((time.monotonic() - started) * 1000),
}
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
class GuestShellSession:
    """In-guest PTY-backed interactive shell session."""

    def __init__(
        self,
        *,
        shell_id: str,
        cwd: Path,
        cwd_text: str,
        cols: int,
        rows: int,
    ) -> None:
        """Spawn an interactive bash on a fresh PTY and start helper threads.

        *cwd* is the concrete directory the shell starts in and *cwd_text*
        its normalized /workspace form; *cols*/*rows* set the initial
        terminal size. A reader thread drains PTY output and a waiter
        thread records the exit status.
        """
        self.shell_id = shell_id
        self.cwd = cwd_text
        self.cols = cols
        self.rows = rows
        self.started_at = time.time()
        self.ended_at: float | None = None
        self.exit_code: int | None = None
        self.state = "running"
        # RLock: summary() is invoked from methods that already hold the lock.
        self._lock = threading.RLock()
        self._output = ""
        # Incremental decoder so multi-byte UTF-8 sequences split across PTY
        # reads decode correctly; undecodable bytes are replaced.
        self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
        self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
        self._log_path = SHELL_ROOT / f"{shell_id}.log"
        self._master_fd: int | None = None
        master_fd, slave_fd = pty.openpty()
        try:
            _set_pty_size(slave_fd, rows, cols)
            env = os.environ.copy()
            env.update(
                {
                    "TERM": env.get("TERM", "xterm-256color"),
                    "PS1": "pyro$ ",
                    "PROMPT_COMMAND": "",
                }
            )
            process = subprocess.Popen(  # noqa: S603
                ["/bin/bash", "--noprofile", "--norc", "-i"],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd=str(cwd),
                env=env,
                text=False,
                close_fds=True,
                # New session/process group so signals can target the whole
                # shell job via os.killpg.
                preexec_fn=os.setsid,
            )
        except Exception:
            os.close(master_fd)
            raise
        finally:
            # The child keeps its own copy of the slave end; close ours so
            # reads on the master see EOF when the shell exits.
            os.close(slave_fd)
        self._process = process
        self._master_fd = master_fd
        self._write_metadata()
        self._reader = threading.Thread(target=self._reader_loop, daemon=True)
        self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
        self._reader.start()
        self._waiter.start()

    def summary(self) -> dict[str, Any]:
        """Return a JSON-serializable snapshot of the session's state."""
        with self._lock:
            return {
                "shell_id": self.shell_id,
                "cwd": self.cwd,
                "cols": self.cols,
                "rows": self.rows,
                "state": self.state,
                "started_at": self.started_at,
                "ended_at": self.ended_at,
                "exit_code": self.exit_code,
            }

    def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
        """Return up to *max_chars* of buffered output starting at *cursor*.

        The cursor is clamped into the buffered range; ``truncated`` tells
        the caller that more output is already available past
        ``next_cursor``.
        """
        with self._lock:
            clamped_cursor = min(max(cursor, 0), len(self._output))
            output = self._output[clamped_cursor : clamped_cursor + max_chars]
            next_cursor = clamped_cursor + len(output)
            payload = self.summary()
            payload.update(
                {
                    "cursor": clamped_cursor,
                    "next_cursor": next_cursor,
                    "output": output,
                    "truncated": next_cursor < len(self._output),
                }
            )
            return payload

    def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
        """Send *text* (optionally newline-terminated) to the shell's stdin.

        Raises RuntimeError when the shell has stopped or the PTY master
        has already been closed.
        """
        # Fold in a just-exited process before checking the running state.
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            master_fd = self._master_fd
            if master_fd is None:
                raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
            payload = text + ("\n" if append_newline else "")
            try:
                os.write(master_fd, payload.encode("utf-8"))
            except OSError as exc:
                self._refresh_process_state()
                raise RuntimeError(f"failed to write shell input: {exc}") from exc
            response = self.summary()
            response.update({"input_length": len(text), "append_newline": append_newline})
            return response

    def send_signal(self, signal_name: str) -> dict[str, Any]:
        """Deliver a named signal to the shell's entire process group.

        Raises ValueError for unknown names and RuntimeError when the shell
        has already stopped.
        """
        signum = SHELL_SIGNAL_MAP.get(signal_name)
        if signum is None:
            raise ValueError(f"unsupported shell signal: {signal_name}")
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            pid = self._process.pid
            try:
                # Signal the whole group (shell + its children), see setsid.
                os.killpg(pid, signum)
            except ProcessLookupError as exc:
                self._refresh_process_state()
                raise RuntimeError(f"shell {self.shell_id} is not running") from exc
            response = self.summary()
            response["signal"] = signal_name
            return response

    def close(self) -> dict[str, Any]:
        """Terminate the shell (HUP, then KILL) and reap the helper threads."""
        if self._process.poll() is None:
            try:
                os.killpg(self._process.pid, signal.SIGHUP)
            except ProcessLookupError:
                pass
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Unresponsive to HUP: force-kill the whole group.
                try:
                    os.killpg(self._process.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
                self._process.wait(timeout=5)
        else:
            self._refresh_process_state()
        self._close_master_fd()
        if self._reader is not None:
            self._reader.join(timeout=1)
        if self._waiter is not None:
            self._waiter.join(timeout=1)
        response = self.summary()
        response["closed"] = True
        return response

    def _reader_loop(self) -> None:
        """Drain PTY output into the in-memory buffer and the log file."""
        master_fd = self._master_fd
        if master_fd is None:
            return
        while True:
            try:
                chunk = os.read(master_fd, BUFFER_SIZE)
            except OSError:
                # Master fd closed or PTY torn down: treat as EOF.
                break
            if chunk == b"":
                break
            decoded = self._decoder.decode(chunk)
            if decoded == "":
                continue
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)
        # Flush any bytes the incremental decoder is still holding.
        decoded = self._decoder.decode(b"", final=True)
        if decoded != "":
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)

    def _waiter_loop(self) -> None:
        """Block until the shell exits, then record its final state."""
        exit_code = self._process.wait()
        with self._lock:
            self.state = "stopped"
            self.exit_code = exit_code
            self.ended_at = time.time()
            self._write_metadata()

    def _refresh_process_state(self) -> None:
        """Sync session state with the OS when the process has exited."""
        exit_code = self._process.poll()
        if exit_code is None:
            return
        with self._lock:
            # Only transition once; the waiter thread may have beaten us here.
            if self.state == "running":
                self.state = "stopped"
                self.exit_code = exit_code
                self.ended_at = time.time()
                self._write_metadata()

    def _write_metadata(self) -> None:
        """Persist the current session summary for host-side inspection."""
        self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
        self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")

    def _close_master_fd(self) -> None:
        """Close and clear the PTY master fd exactly once."""
        with self._lock:
            master_fd = self._master_fd
            self._master_fd = None
        if master_fd is None:
            return
        try:
            os.close(master_fd)
        except OSError:
            pass
def _create_shell(
    *,
    shell_id: str,
    cwd_text: str,
    cols: int,
    rows: int,
) -> GuestShellSession:
    """Create a shell session and register it under a unique *shell_id*.

    The session is constructed while the registry lock is held so two
    concurrent opens with the same id cannot both spawn a shell.

    Raises RuntimeError when *shell_id* is already registered, and
    propagates whatever _normalize_shell_cwd raises for a bad cwd.
    """
    _, cwd_path = _normalize_shell_cwd(cwd_text)
    with _SHELLS_LOCK:
        if shell_id in _SHELLS:
            raise RuntimeError(f"shell {shell_id!r} already exists")
        session = GuestShellSession(
            shell_id=shell_id,
            cwd=cwd_path,
            cwd_text=cwd_text,
            cols=cols,
            rows=rows,
        )
        _SHELLS[shell_id] = session
        return session
def _get_shell(shell_id: str) -> GuestShellSession:
    """Look up a registered shell session, raising if it is unknown."""
    with _SHELLS_LOCK:
        try:
            session = _SHELLS[shell_id]
        except KeyError as exc:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
    return session
def _remove_shell(shell_id: str) -> GuestShellSession:
    """Deregister and return a shell session, raising if it is unknown."""
    with _SHELLS_LOCK:
        try:
            session = _SHELLS.pop(shell_id)
        except KeyError as exc:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
    return session
def _require_shell_id(request: dict[str, Any]) -> str:
    """Extract and validate the mandatory ``shell_id`` field of *request*."""
    shell_id = str(request.get("shell_id", "")).strip()
    if shell_id == "":
        raise RuntimeError("shell_id is required")
    return shell_id


def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
    """Route one request to its handler and return the JSON-able response.

    ``extract_archive`` additionally consumes ``archive_size`` raw bytes
    from *conn* after the JSON header. Unknown actions fall through to
    one-shot command execution (the historical default).
    """
    action = str(request.get("action", "exec"))
    if action == "extract_archive":
        archive_size = int(request.get("archive_size", 0))
        if archive_size < 0:
            raise RuntimeError("archive_size must not be negative")
        destination = str(request.get("destination", "/workspace"))
        payload = _read_exact(conn, archive_size)
        return _extract_archive(payload, destination)
    if action == "open_shell":
        # Validate shell_id before cwd so its error takes precedence.
        shell_id = _require_shell_id(request)
        cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
        session = _create_shell(
            shell_id=shell_id,
            cwd_text=cwd_text,
            cols=int(request.get("cols", 120)),
            rows=int(request.get("rows", 30)),
        )
        return session.summary()
    if action == "read_shell":
        return _get_shell(_require_shell_id(request)).read(
            cursor=int(request.get("cursor", 0)),
            max_chars=int(request.get("max_chars", 65536)),
        )
    if action == "write_shell":
        return _get_shell(_require_shell_id(request)).write(
            str(request.get("input", "")),
            append_newline=bool(request.get("append_newline", True)),
        )
    if action == "signal_shell":
        shell_id = _require_shell_id(request)
        signal_name = str(request.get("signal", "INT")).upper()
        if signal_name not in SHELL_SIGNAL_NAMES:
            raise RuntimeError(
                f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
            )
        return _get_shell(shell_id).send_signal(signal_name)
    if action == "close_shell":
        return _remove_shell(_require_shell_id(request)).close()
    # Default / "exec": run a one-shot shell command.
    command = str(request.get("command", ""))
    timeout_seconds = int(request.get("timeout_seconds", 30))
    return _run_command(command, timeout_seconds)
def main() -> None:
    """Serve guest-agent requests over AF_VSOCK, one connection at a time.

    Each connection carries a single newline-terminated JSON request.
    ``export_archive`` streams a raw tar body after its JSON header and
    skips the normal JSON reply; every other action produces one JSON
    response line. Handler failures become ``{"error": ...}`` replies
    rather than tearing down the accept loop.
    """
    SHELL_ROOT.mkdir(parents=True, exist_ok=True)
    family = getattr(socket, "AF_VSOCK", None)
    if family is None:
        raise SystemExit("AF_VSOCK is unavailable")
    with socket.socket(family, socket.SOCK_STREAM) as server:
        server.bind((socket.VMADDR_CID_ANY, PORT))
        server.listen(1)
        while True:
            conn, _ = server.accept()
            with conn:
                try:
                    request = _read_request(conn)
                    if str(request.get("action", "")) == "export_archive":
                        export = _prepare_export_archive(str(request.get("path", "/workspace")))
                        try:
                            # Header first, then the raw archive bytes; the
                            # host reads archive_size bytes after the newline.
                            header = {
                                "workspace_path": export["workspace_path"],
                                "artifact_type": export["artifact_type"],
                                "archive_size": export["archive_size"],
                                "entry_count": export["entry_count"],
                                "bytes_written": export["bytes_written"],
                            }
                            conn.sendall((json.dumps(header) + "\n").encode("utf-8"))
                            with Path(str(export["archive_path"])).open("rb") as handle:
                                while True:
                                    chunk = handle.read(BUFFER_SIZE)
                                    if chunk == b"":
                                        break
                                    conn.sendall(chunk)
                        finally:
                            # The staging archive is guest-side only; always
                            # remove it, even on a failed send.
                            Path(str(export["archive_path"])).unlink(missing_ok=True)
                        # Export streamed its own response; skip the JSON reply.
                        continue
                    response = _dispatch(request, conn)
                except Exception as exc:  # noqa: BLE001
                    response = {"error": str(exc)}
                conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
if __name__ == "__main__":
    # Script entry point: run the vsock request loop until the VM stops.
    main()