Add persistent workspace shell sessions

Let agents inhabit a workspace across separate calls instead of only submitting one-shot execs.

Add workspace shell open/read/write/signal/close across the CLI, Python SDK, and MCP server, with persisted shell records, a local PTY-backed mock implementation, and guest-agent support for real Firecracker workspaces.

Mark the 2.5.0 roadmap milestone done, refresh docs/examples and the release metadata, and verify with uv lock, UV_CACHE_DIR=.uv-cache make check, and UV_CACHE_DIR=.uv-cache make dist-check.
This commit is contained in:
Thales Maciel 2026-03-12 02:31:57 -03:00
parent 2de31306b6
commit 3f8293ad24
28 changed files with 3265 additions and 81 deletions

View file

@ -1,14 +1,21 @@
#!/usr/bin/env python3
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
"""Guest-side exec, workspace import, and interactive shell agent."""
from __future__ import annotations
import codecs
import fcntl
import io
import json
import os
import pty
import signal
import socket
import struct
import subprocess
import tarfile
import termios
import threading
import time
from pathlib import Path, PurePosixPath
from typing import Any
@ -16,6 +23,17 @@ from typing import Any
# vsock port the guest agent listens on for host requests.
PORT = 5005
# Chunk size for socket reads and PTY output reads.
BUFFER_SIZE = 65536
# Every workspace-relative path is normalized under this guest-side root.
WORKSPACE_ROOT = PurePosixPath("/workspace")
# Per-shell metadata (<id>.json) and transcript (<id>.log) files live here.
SHELL_ROOT = Path("/run/pyro-shells")
# Signals the wire protocol allows clients to deliver to a shell's process group.
SHELL_SIGNAL_MAP = {
    "HUP": signal.SIGHUP,
    "INT": signal.SIGINT,
    "TERM": signal.SIGTERM,
    "KILL": signal.SIGKILL,
}
# Accepted signal names, in map insertion order (used in error messages).
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)
# Registry of live shell sessions; all access goes through _SHELLS_LOCK.
_SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
def _read_request(conn: socket.socket) -> dict[str, Any]:
@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
suffix = normalized.relative_to(WORKSPACE_ROOT)
host_path = Path("/workspace")
if str(suffix) not in {"", "."}:
host_path = host_path / str(suffix)
host_path = host_path.joinpath(*suffix.parts)
return normalized, host_path
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
    """Normalize a shell working directory to (guest path text, host ``Path``).

    Delegates validation/normalization to ``_normalize_destination`` and only
    reshapes the result, so the same workspace-root rules apply to shell cwds.
    """
    guest_path, host_path = _normalize_destination(cwd)
    return (str(guest_path), host_path)
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
_, destination_root = _normalize_destination(destination)
normalized_destination, destination_root = _normalize_destination(destination)
destination_root.mkdir(parents=True, exist_ok=True)
bytes_written = 0
entry_count = 0
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_member_name(member.name)
target_path = destination_root / str(member_name)
target_path = destination_root.joinpath(*member_name.parts)
entry_count += 1
_ensure_no_symlink_parents(destination_root, target_path, member.name)
if member.isdir():
if target_path.exists() and not target_path.is_dir():
if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
target_path.mkdir(parents=True, exist_ok=True)
continue
@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return {
"destination": destination,
"destination": str(normalized_destination),
"entry_count": entry_count,
"bytes_written": bytes_written,
}
@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
}
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
class GuestShellSession:
    """In-guest PTY-backed interactive shell session.

    Spawns ``/bin/bash`` attached to a pseudo-terminal and tracks its
    lifecycle.  Two daemon threads run per session: a reader that drains PTY
    output into an in-memory transcript (and an append-only log file), and a
    waiter that records the exit code when the process ends.  Shared mutable
    state (transcript, state/exit fields) is guarded by an RLock so methods
    may call ``summary()`` while already holding it.
    """

    def __init__(
        self,
        *,
        shell_id: str,
        cwd: Path,
        cwd_text: str,
        cols: int,
        rows: int,
    ) -> None:
        self.shell_id = shell_id
        # Guest-visible cwd text (what clients see), not the host filesystem path.
        self.cwd = cwd_text
        self.cols = cols
        self.rows = rows
        self.started_at = time.time()
        self.ended_at: float | None = None
        self.exit_code: int | None = None
        self.state = "running"
        # RLock: summary()/_write_metadata() are re-entered while locked.
        self._lock = threading.RLock()
        self._output = ""
        # Incremental decoder with "replace" so split multi-byte UTF-8
        # sequences across reads never raise.
        self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
        self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
        self._log_path = SHELL_ROOT / f"{shell_id}.log"
        self._master_fd: int | None = None
        master_fd, slave_fd = pty.openpty()
        try:
            _set_pty_size(slave_fd, rows, cols)
            env = os.environ.copy()
            env.update(
                {
                    "TERM": env.get("TERM", "xterm-256color"),
                    # Pin the prompt so transcript output is predictable.
                    "PS1": "pyro$ ",
                    "PROMPT_COMMAND": "",
                }
            )
            process = subprocess.Popen(  # noqa: S603
                ["/bin/bash", "--noprofile", "--norc", "-i"],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd=str(cwd),
                env=env,
                text=False,
                close_fds=True,
                # New session/process group so signals can target the whole
                # job via os.killpg without touching the agent itself.
                preexec_fn=os.setsid,
            )
        except Exception:
            # Spawn failed: the master fd would otherwise leak.
            os.close(master_fd)
            raise
        finally:
            # The child holds its own copies of the slave fd; ours is done.
            os.close(slave_fd)
        self._process = process
        self._master_fd = master_fd
        self._write_metadata()
        self._reader = threading.Thread(target=self._reader_loop, daemon=True)
        self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
        self._reader.start()
        self._waiter.start()

    def summary(self) -> dict[str, Any]:
        """Return a consistent snapshot of the session's public state."""
        with self._lock:
            return {
                "shell_id": self.shell_id,
                "cwd": self.cwd,
                "cols": self.cols,
                "rows": self.rows,
                "state": self.state,
                "started_at": self.started_at,
                "ended_at": self.ended_at,
                "exit_code": self.exit_code,
            }

    def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
        """Return up to *max_chars* of transcript starting at *cursor*.

        The response embeds the session summary plus the clamped cursor, the
        next cursor to resume from, and whether more output remains.
        """
        with self._lock:
            # Clamp caller-supplied cursor into the valid transcript range.
            clamped_cursor = min(max(cursor, 0), len(self._output))
            output = self._output[clamped_cursor : clamped_cursor + max_chars]
            next_cursor = clamped_cursor + len(output)
            payload = self.summary()
            payload.update(
                {
                    "cursor": clamped_cursor,
                    "next_cursor": next_cursor,
                    "output": output,
                    # True when max_chars stopped short of the transcript end.
                    "truncated": next_cursor < len(self._output),
                }
            )
            return payload

    def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
        """Feed *text* (optionally newline-terminated) to the shell's stdin.

        Raises ``RuntimeError`` if the shell is no longer running or the PTY
        master has been closed.
        """
        # Sync state first so a freshly-exited shell is reported as stopped.
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            # Snapshot the fd under the lock; the write happens outside it.
            master_fd = self._master_fd
        if master_fd is None:
            raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
        payload = text + ("\n" if append_newline else "")
        try:
            os.write(master_fd, payload.encode("utf-8"))
        except OSError as exc:
            self._refresh_process_state()
            raise RuntimeError(f"failed to write shell input: {exc}") from exc
        response = self.summary()
        response.update({"input_length": len(text), "append_newline": append_newline})
        return response

    def send_signal(self, signal_name: str) -> dict[str, Any]:
        """Deliver a named signal to the shell's entire process group.

        Raises ``ValueError`` for names outside ``SHELL_SIGNAL_MAP`` and
        ``RuntimeError`` when the shell is not running.
        """
        signum = SHELL_SIGNAL_MAP.get(signal_name)
        if signum is None:
            raise ValueError(f"unsupported shell signal: {signal_name}")
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            pid = self._process.pid
        try:
            # Target the group (setsid in __init__), so children get it too.
            os.killpg(pid, signum)
        except ProcessLookupError as exc:
            self._refresh_process_state()
            raise RuntimeError(f"shell {self.shell_id} is not running") from exc
        response = self.summary()
        response["signal"] = signal_name
        return response

    def close(self) -> dict[str, Any]:
        """Terminate the shell (HUP, then KILL after 5s) and release resources."""
        if self._process.poll() is None:
            try:
                os.killpg(self._process.pid, signal.SIGHUP)
            except ProcessLookupError:
                pass
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Graceful shutdown failed; escalate to SIGKILL.
                try:
                    os.killpg(self._process.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
                self._process.wait(timeout=5)
        else:
            self._refresh_process_state()
        # Closing the master fd unblocks the reader thread's os.read.
        self._close_master_fd()
        if self._reader is not None:
            self._reader.join(timeout=1)
        if self._waiter is not None:
            self._waiter.join(timeout=1)
        response = self.summary()
        response["closed"] = True
        return response

    def _reader_loop(self) -> None:
        """Drain PTY output into the transcript buffer and the log file."""
        master_fd = self._master_fd
        if master_fd is None:
            return
        while True:
            try:
                chunk = os.read(master_fd, BUFFER_SIZE)
            except OSError:
                # Typically raised when the slave side goes away; treat as EOF.
                break
            if chunk == b"":
                break
            decoded = self._decoder.decode(chunk)
            if decoded == "":
                # Decoder is buffering a partial multi-byte sequence.
                continue
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)
        # Flush whatever the incremental decoder is still holding.
        decoded = self._decoder.decode(b"", final=True)
        if decoded != "":
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)

    def _waiter_loop(self) -> None:
        """Block until the shell process exits, then record its exit."""
        exit_code = self._process.wait()
        with self._lock:
            self.state = "stopped"
            self.exit_code = exit_code
            self.ended_at = time.time()
            self._write_metadata()

    def _refresh_process_state(self) -> None:
        """If the process has already exited, record the transition once."""
        exit_code = self._process.poll()
        if exit_code is None:
            return
        with self._lock:
            # Guard so a concurrent _waiter_loop update is not overwritten.
            if self.state == "running":
                self.state = "stopped"
                self.exit_code = exit_code
                self.ended_at = time.time()
                self._write_metadata()

    def _write_metadata(self) -> None:
        """Persist the session summary as pretty-printed JSON under SHELL_ROOT."""
        self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
        self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")

    def _close_master_fd(self) -> None:
        """Close and forget the PTY master fd; safe to call more than once."""
        with self._lock:
            # Swap-and-clear under the lock so only one caller closes the fd.
            master_fd = self._master_fd
            self._master_fd = None
        if master_fd is None:
            return
        try:
            os.close(master_fd)
        except OSError:
            pass
def _create_shell(
    *,
    shell_id: str,
    cwd_text: str,
    cols: int,
    rows: int,
) -> GuestShellSession:
    """Start a new shell session and register it, rejecting duplicate ids.

    The session is constructed while the registry lock is held so two
    concurrent opens for the same id cannot both succeed.
    """
    host_cwd = _normalize_shell_cwd(cwd_text)[1]
    with _SHELLS_LOCK:
        if shell_id in _SHELLS:
            raise RuntimeError(f"shell {shell_id!r} already exists")
        created = GuestShellSession(
            shell_id=shell_id,
            cwd=host_cwd,
            cwd_text=cwd_text,
            cols=cols,
            rows=rows,
        )
        _SHELLS[shell_id] = created
    return created
def _get_shell(shell_id: str) -> GuestShellSession:
    """Return the live session registered under *shell_id*.

    Raises ``RuntimeError`` (chained from the ``KeyError``) when no such
    session exists.
    """
    with _SHELLS_LOCK:
        try:
            session = _SHELLS[shell_id]
        except KeyError as missing:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from missing
    return session
def _remove_shell(shell_id: str) -> GuestShellSession:
    """Unregister and return the session stored under *shell_id*.

    Raises ``RuntimeError`` (chained from the ``KeyError``) when no such
    session exists.  The caller is responsible for closing the session.
    """
    with _SHELLS_LOCK:
        try:
            session = _SHELLS.pop(shell_id)
        except KeyError as missing:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from missing
    return session
def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
return _extract_archive(payload, destination)
if action == "open_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
session = _create_shell(
shell_id=shell_id,
cwd_text=cwd_text,
cols=int(request.get("cols", 120)),
rows=int(request.get("rows", 30)),
)
return session.summary()
if action == "read_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _get_shell(shell_id).read(
cursor=int(request.get("cursor", 0)),
max_chars=int(request.get("max_chars", 65536)),
)
if action == "write_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _get_shell(shell_id).write(
str(request.get("input", "")),
append_newline=bool(request.get("append_newline", True)),
)
if action == "signal_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
signal_name = str(request.get("signal", "INT")).upper()
if signal_name not in SHELL_SIGNAL_NAMES:
raise RuntimeError(
f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
)
return _get_shell(shell_id).send_signal(signal_name)
if action == "close_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _remove_shell(shell_id).close()
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
return _run_command(command, timeout_seconds)
def main() -> None:
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
family = getattr(socket, "AF_VSOCK", None)
if family is None:
raise SystemExit("AF_VSOCK is unavailable")
@ -192,19 +531,11 @@ def main() -> None:
while True:
conn, _ = server.accept()
with conn:
request = _read_request(conn)
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
response = _extract_archive(payload, destination)
else:
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
try:
request = _read_request(conn)
response = _dispatch(request, conn)
except Exception as exc: # noqa: BLE001
response = {"error": str(exc)}
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))