#!/usr/bin/env python3 """Guest-side exec, workspace import, and interactive shell agent.""" from __future__ import annotations import codecs import fcntl import io import json import os import pty import signal import socket import struct import subprocess import tarfile import termios import threading import time from pathlib import Path, PurePosixPath from typing import Any PORT = 5005 BUFFER_SIZE = 65536 WORKSPACE_ROOT = PurePosixPath("/workspace") SHELL_ROOT = Path("/run/pyro-shells") SHELL_SIGNAL_MAP = { "HUP": signal.SIGHUP, "INT": signal.SIGINT, "TERM": signal.SIGTERM, "KILL": signal.SIGKILL, } SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP) _SHELLS: dict[str, "GuestShellSession"] = {} _SHELLS_LOCK = threading.Lock() def _read_request(conn: socket.socket) -> dict[str, Any]: chunks: list[bytes] = [] while True: data = conn.recv(1) if data == b"": break chunks.append(data) if data == b"\n": break payload = json.loads(b"".join(chunks).decode("utf-8").strip()) if not isinstance(payload, dict): raise RuntimeError("request must be a JSON object") return payload def _read_exact(conn: socket.socket, size: int) -> bytes: remaining = size chunks: list[bytes] = [] while remaining > 0: data = conn.recv(min(BUFFER_SIZE, remaining)) if data == b"": raise RuntimeError("unexpected EOF while reading archive payload") chunks.append(data) remaining -= len(data) return b"".join(chunks) def _normalize_member_name(name: str) -> PurePosixPath: candidate = name.strip() if candidate == "": raise RuntimeError("archive member path is empty") member_path = PurePosixPath(candidate) if member_path.is_absolute(): raise RuntimeError(f"absolute archive member paths are not allowed: {name}") parts = [part for part in member_path.parts if part not in {"", "."}] if any(part == ".." for part in parts): raise RuntimeError(f"unsafe archive member path: {name}") normalized = PurePosixPath(*parts) if str(normalized) in {"", "."}: raise RuntimeError(f"unsafe archive member path: {name}") return normalized def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]: candidate = destination.strip() if candidate == "": raise RuntimeError("destination must not be empty") destination_path = PurePosixPath(candidate) if not destination_path.is_absolute(): destination_path = WORKSPACE_ROOT / destination_path parts = [part for part in destination_path.parts if part not in {"", "."}] normalized = PurePosixPath("/") / PurePosixPath(*parts) if normalized == PurePosixPath("/"): raise RuntimeError("destination must stay inside /workspace") if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts: raise RuntimeError("destination must stay inside /workspace") suffix = normalized.relative_to(WORKSPACE_ROOT) host_path = Path("/workspace") if str(suffix) not in {"", "."}: host_path = host_path.joinpath(*suffix.parts) return normalized, host_path def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]: normalized, host_path = _normalize_destination(cwd) return str(normalized), host_path def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None: target = link_target.strip() if target == "": raise RuntimeError(f"symlink {member_path} has an empty target") target_path = PurePosixPath(target) if target_path.is_absolute(): raise RuntimeError(f"symlink {member_path} escapes the workspace") combined = member_path.parent.joinpath(target_path) parts = [part for part in combined.parts if part not in {"", "."}] if any(part == ".." for part in parts): raise RuntimeError(f"symlink {member_path} escapes the workspace") def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None: relative_path = target_path.relative_to(root) current = root for part in relative_path.parts[:-1]: current = current / part if current.is_symlink(): raise RuntimeError( f"archive member would traverse through a symlinked path: {member_name}" ) def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: normalized_destination, destination_root = _normalize_destination(destination) destination_root.mkdir(parents=True, exist_ok=True) bytes_written = 0 entry_count = 0 with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive: for member in archive.getmembers(): member_name = _normalize_member_name(member.name) target_path = destination_root.joinpath(*member_name.parts) entry_count += 1 _ensure_no_symlink_parents(destination_root, target_path, member.name) if member.isdir(): if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()): raise RuntimeError(f"directory conflicts with existing path: {member.name}") target_path.mkdir(parents=True, exist_ok=True) continue if member.isfile(): target_path.parent.mkdir(parents=True, exist_ok=True) if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()): raise RuntimeError(f"file conflicts with existing path: {member.name}") source = archive.extractfile(member) if source is None: raise RuntimeError(f"failed to read archive member: {member.name}") with target_path.open("wb") as handle: while True: chunk = source.read(BUFFER_SIZE) if chunk == b"": break handle.write(chunk) bytes_written += member.size continue if member.issym(): _validate_symlink_target(member_name, member.linkname) target_path.parent.mkdir(parents=True, exist_ok=True) if target_path.exists() and not target_path.is_symlink(): raise RuntimeError(f"symlink conflicts with existing path: {member.name}") if target_path.is_symlink(): target_path.unlink() os.symlink(member.linkname, target_path) continue if member.islnk(): raise RuntimeError( f"hard links are not allowed in workspace archives: {member.name}" ) raise RuntimeError(f"unsupported archive member type: {member.name}") return { "destination": str(normalized_destination), "entry_count": entry_count, "bytes_written": bytes_written, } def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]: started = time.monotonic() try: proc = subprocess.run( ["/bin/sh", "-lc", command], text=True, capture_output=True, timeout=timeout_seconds, check=False, ) return { "stdout": proc.stdout, "stderr": proc.stderr, "exit_code": proc.returncode, "duration_ms": int((time.monotonic() - started) * 1000), } except subprocess.TimeoutExpired: return { "stdout": "", "stderr": f"command timed out after {timeout_seconds}s", "exit_code": 124, "duration_ms": int((time.monotonic() - started) * 1000), } def _set_pty_size(fd: int, rows: int, cols: int) -> None: winsize = struct.pack("HHHH", rows, cols, 0, 0) fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize) class GuestShellSession: """In-guest PTY-backed interactive shell session.""" def __init__( self, *, shell_id: str, cwd: Path, cwd_text: str, cols: int, rows: int, ) -> None: self.shell_id = shell_id self.cwd = cwd_text self.cols = cols self.rows = rows self.started_at = time.time() self.ended_at: float | None = None self.exit_code: int | None = None self.state = "running" self._lock = threading.RLock() self._output = "" self._decoder = codecs.getincrementaldecoder("utf-8")("replace") self._metadata_path = SHELL_ROOT / f"{shell_id}.json" self._log_path = SHELL_ROOT / f"{shell_id}.log" self._master_fd: int | None = None master_fd, slave_fd = pty.openpty() try: _set_pty_size(slave_fd, rows, cols) env = os.environ.copy() env.update( { "TERM": env.get("TERM", "xterm-256color"), "PS1": "pyro$ ", "PROMPT_COMMAND": "", } ) process = subprocess.Popen( # noqa: S603 ["/bin/bash", "--noprofile", "--norc", "-i"], stdin=slave_fd, stdout=slave_fd, stderr=slave_fd, cwd=str(cwd), env=env, text=False, close_fds=True, preexec_fn=os.setsid, ) except Exception: os.close(master_fd) raise finally: os.close(slave_fd) self._process = process self._master_fd = master_fd self._write_metadata() self._reader = threading.Thread(target=self._reader_loop, daemon=True) self._waiter = threading.Thread(target=self._waiter_loop, daemon=True) self._reader.start() self._waiter.start() def summary(self) -> dict[str, Any]: with self._lock: return { "shell_id": self.shell_id, "cwd": self.cwd, "cols": self.cols, "rows": self.rows, "state": self.state, "started_at": self.started_at, "ended_at": self.ended_at, "exit_code": self.exit_code, } def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]: with self._lock: clamped_cursor = min(max(cursor, 0), len(self._output)) output = self._output[clamped_cursor : clamped_cursor + max_chars] next_cursor = clamped_cursor + len(output) payload = self.summary() payload.update( { "cursor": clamped_cursor, "next_cursor": next_cursor, "output": output, "truncated": next_cursor < len(self._output), } ) return payload def write(self, text: str, *, append_newline: bool) -> dict[str, Any]: if self._process.poll() is not None: self._refresh_process_state() with self._lock: if self.state != "running": raise RuntimeError(f"shell {self.shell_id} is not running") master_fd = self._master_fd if master_fd is None: raise RuntimeError(f"shell {self.shell_id} transport is unavailable") payload = text + ("\n" if append_newline else "") try: os.write(master_fd, payload.encode("utf-8")) except OSError as exc: self._refresh_process_state() raise RuntimeError(f"failed to write shell input: {exc}") from exc response = self.summary() response.update({"input_length": len(text), "append_newline": append_newline}) return response def send_signal(self, signal_name: str) -> dict[str, Any]: signum = SHELL_SIGNAL_MAP.get(signal_name) if signum is None: raise ValueError(f"unsupported shell signal: {signal_name}") if self._process.poll() is not None: self._refresh_process_state() with self._lock: if self.state != "running": raise RuntimeError(f"shell {self.shell_id} is not running") pid = self._process.pid try: os.killpg(pid, signum) except ProcessLookupError as exc: self._refresh_process_state() raise RuntimeError(f"shell {self.shell_id} is not running") from exc response = self.summary() response["signal"] = signal_name return response def close(self) -> dict[str, Any]: if self._process.poll() is None: try: os.killpg(self._process.pid, signal.SIGHUP) except ProcessLookupError: pass try: self._process.wait(timeout=5) except subprocess.TimeoutExpired: try: os.killpg(self._process.pid, signal.SIGKILL) except ProcessLookupError: pass self._process.wait(timeout=5) else: self._refresh_process_state() self._close_master_fd() if self._reader is not None: self._reader.join(timeout=1) if self._waiter is not None: self._waiter.join(timeout=1) response = self.summary() response["closed"] = True return response def _reader_loop(self) -> None: master_fd = self._master_fd if master_fd is None: return while True: try: chunk = os.read(master_fd, BUFFER_SIZE) except OSError: break if chunk == b"": break decoded = self._decoder.decode(chunk) if decoded == "": continue with self._lock: self._output += decoded with self._log_path.open("a", encoding="utf-8") as handle: handle.write(decoded) decoded = self._decoder.decode(b"", final=True) if decoded != "": with self._lock: self._output += decoded with self._log_path.open("a", encoding="utf-8") as handle: handle.write(decoded) def _waiter_loop(self) -> None: exit_code = self._process.wait() with self._lock: self.state = "stopped" self.exit_code = exit_code self.ended_at = time.time() self._write_metadata() def _refresh_process_state(self) -> None: exit_code = self._process.poll() if exit_code is None: return with self._lock: if self.state == "running": self.state = "stopped" self.exit_code = exit_code self.ended_at = time.time() self._write_metadata() def _write_metadata(self) -> None: self._metadata_path.parent.mkdir(parents=True, exist_ok=True) self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8") def _close_master_fd(self) -> None: with self._lock: master_fd = self._master_fd self._master_fd = None if master_fd is None: return try: os.close(master_fd) except OSError: pass def _create_shell( *, shell_id: str, cwd_text: str, cols: int, rows: int, ) -> GuestShellSession: _, cwd_path = _normalize_shell_cwd(cwd_text) with _SHELLS_LOCK: if shell_id in _SHELLS: raise RuntimeError(f"shell {shell_id!r} already exists") session = GuestShellSession( shell_id=shell_id, cwd=cwd_path, cwd_text=cwd_text, cols=cols, rows=rows, ) _SHELLS[shell_id] = session return session def _get_shell(shell_id: str) -> GuestShellSession: with _SHELLS_LOCK: try: return _SHELLS[shell_id] except KeyError as exc: raise RuntimeError(f"shell {shell_id!r} does not exist") from exc def _remove_shell(shell_id: str) -> GuestShellSession: with _SHELLS_LOCK: try: return _SHELLS.pop(shell_id) except KeyError as exc: raise RuntimeError(f"shell {shell_id!r} does not exist") from exc def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]: action = str(request.get("action", "exec")) if action == "extract_archive": archive_size = int(request.get("archive_size", 0)) if archive_size < 0: raise RuntimeError("archive_size must not be negative") destination = str(request.get("destination", "/workspace")) payload = _read_exact(conn, archive_size) return _extract_archive(payload, destination) if action == "open_shell": shell_id = str(request.get("shell_id", "")).strip() if shell_id == "": raise RuntimeError("shell_id is required") cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace"))) session = _create_shell( shell_id=shell_id, cwd_text=cwd_text, cols=int(request.get("cols", 120)), rows=int(request.get("rows", 30)), ) return session.summary() if action == "read_shell": shell_id = str(request.get("shell_id", "")).strip() if shell_id == "": raise RuntimeError("shell_id is required") return _get_shell(shell_id).read( cursor=int(request.get("cursor", 0)), max_chars=int(request.get("max_chars", 65536)), ) if action == "write_shell": shell_id = str(request.get("shell_id", "")).strip() if shell_id == "": raise RuntimeError("shell_id is required") return _get_shell(shell_id).write( str(request.get("input", "")), append_newline=bool(request.get("append_newline", True)), ) if action == "signal_shell": shell_id = str(request.get("shell_id", "")).strip() if shell_id == "": raise RuntimeError("shell_id is required") signal_name = str(request.get("signal", "INT")).upper() if signal_name not in SHELL_SIGNAL_NAMES: raise RuntimeError( f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}" ) return _get_shell(shell_id).send_signal(signal_name) if action == "close_shell": shell_id = str(request.get("shell_id", "")).strip() if shell_id == "": raise RuntimeError("shell_id is required") return _remove_shell(shell_id).close() command = str(request.get("command", "")) timeout_seconds = int(request.get("timeout_seconds", 30)) return _run_command(command, timeout_seconds) def main() -> None: SHELL_ROOT.mkdir(parents=True, exist_ok=True) family = getattr(socket, "AF_VSOCK", None) if family is None: raise SystemExit("AF_VSOCK is unavailable") with socket.socket(family, socket.SOCK_STREAM) as server: server.bind((socket.VMADDR_CID_ANY, PORT)) server.listen(1) while True: conn, _ = server.accept() with conn: try: request = _read_request(conn) response = _dispatch(request, conn) except Exception as exc: # noqa: BLE001 response = {"error": str(exc)} conn.sendall((json.dumps(response) + "\n").encode("utf-8")) if __name__ == "__main__": main()