Add persistent workspace shell sessions
Let agents inhabit a workspace across separate calls instead of only submitting one-shot execs. Add workspace shell open/read/write/signal/close across the CLI, Python SDK, and MCP server, with persisted shell records, a local PTY-backed mock implementation, and guest-agent support for real Firecracker workspaces. Mark the 2.5.0 roadmap milestone done, refresh docs/examples and the release metadata, and verify with uv lock, UV_CACHE_DIR=.uv-cache make check, and UV_CACHE_DIR=.uv-cache make dist-check.
This commit is contained in:
parent
2de31306b6
commit
3f8293ad24
28 changed files with 3265 additions and 81 deletions
|
|
@ -1,14 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
|
||||
"""Guest-side exec, workspace import, and interactive shell agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import fcntl
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pty
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import tarfile
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
|
@ -16,6 +23,17 @@ from typing import Any
|
|||
PORT = 5005  # vsock port the guest agent listens on
BUFFER_SIZE = 65536  # chunk size for socket and PTY reads
WORKSPACE_ROOT = PurePosixPath("/workspace")  # guest-visible workspace root all paths are confined to
SHELL_ROOT = Path("/run/pyro-shells")  # directory holding per-shell metadata (.json) and output logs (.log)
# Allow-list of signals that clients may deliver to a shell's process group.
SHELL_SIGNAL_MAP = {
    "HUP": signal.SIGHUP,
    "INT": signal.SIGINT,
    "TERM": signal.SIGTERM,
    "KILL": signal.SIGKILL,
}
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)  # stable ordering for validation error messages

# Registry of live shell sessions keyed by shell_id; all access goes through _SHELLS_LOCK.
_SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _read_request(conn: socket.socket) -> dict[str, Any]:
|
||||
|
|
@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
|
|||
suffix = normalized.relative_to(WORKSPACE_ROOT)
|
||||
host_path = Path("/workspace")
|
||||
if str(suffix) not in {"", "."}:
|
||||
host_path = host_path / str(suffix)
|
||||
host_path = host_path.joinpath(*suffix.parts)
|
||||
return normalized, host_path
|
||||
|
||||
|
||||
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
    """Validate a shell working directory and return both of its views.

    Delegates confinement/validation to ``_normalize_destination`` and
    returns the normalized workspace-relative text alongside the matching
    host-side ``Path``.
    """
    normalized_text, host_path = _normalize_destination(cwd)
    return str(normalized_text), host_path
|
||||
|
||||
|
||||
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
|
|
@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str)
|
|||
|
||||
|
||||
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
|
||||
_, destination_root = _normalize_destination(destination)
|
||||
normalized_destination, destination_root = _normalize_destination(destination)
|
||||
destination_root.mkdir(parents=True, exist_ok=True)
|
||||
bytes_written = 0
|
||||
entry_count = 0
|
||||
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
|
||||
for member in archive.getmembers():
|
||||
member_name = _normalize_member_name(member.name)
|
||||
target_path = destination_root / str(member_name)
|
||||
target_path = destination_root.joinpath(*member_name.parts)
|
||||
entry_count += 1
|
||||
_ensure_no_symlink_parents(destination_root, target_path, member.name)
|
||||
if member.isdir():
|
||||
if target_path.exists() and not target_path.is_dir():
|
||||
if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
|
||||
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
|
@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
|
|||
)
|
||||
raise RuntimeError(f"unsupported archive member type: {member.name}")
|
||||
return {
|
||||
"destination": destination,
|
||||
"destination": str(normalized_destination),
|
||||
"entry_count": entry_count,
|
||||
"bytes_written": bytes_written,
|
||||
}
|
||||
|
|
@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
|
||||
winsize = struct.pack("HHHH", rows, cols, 0, 0)
|
||||
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
|
||||
|
||||
|
||||
class GuestShellSession:
    """In-guest PTY-backed interactive shell session.

    Owns one interactive ``/bin/bash`` process attached to a fresh PTY plus
    two daemon threads: a reader that accumulates decoded PTY output and a
    waiter that records the final exit state.  All mutable state is guarded
    by an RLock (re-entrant because ``summary()`` is called both externally
    and while the lock is already held).
    """

    def __init__(
        self,
        *,
        shell_id: str,
        cwd: Path,
        cwd_text: str,
        cols: int,
        rows: int,
    ) -> None:
        """Spawn an interactive bash on a new PTY rooted at *cwd*.

        *cwd* is the host-side start directory; *cwd_text* is the
        workspace-relative text reported back to clients.
        """
        self.shell_id = shell_id
        self.cwd = cwd_text  # client-facing path; the host path is used only for spawning
        self.cols = cols
        self.rows = rows
        self.started_at = time.time()
        self.ended_at: float | None = None
        self.exit_code: int | None = None
        self.state = "running"
        self._lock = threading.RLock()
        self._output = ""  # full accumulated output; read() addresses it by character cursor
        # Incremental decoder so a multi-byte UTF-8 sequence split across reads survives.
        self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
        self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
        self._log_path = SHELL_ROOT / f"{shell_id}.log"
        self._master_fd: int | None = None

        master_fd, slave_fd = pty.openpty()
        try:
            _set_pty_size(slave_fd, rows, cols)
            env = os.environ.copy()
            env.update(
                {
                    "TERM": env.get("TERM", "xterm-256color"),
                    # Fixed, minimal prompt so output is predictable for agents.
                    "PS1": "pyro$ ",
                    "PROMPT_COMMAND": "",
                }
            )
            process = subprocess.Popen(  # noqa: S603
                ["/bin/bash", "--noprofile", "--norc", "-i"],
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd=str(cwd),
                env=env,
                text=False,
                close_fds=True,
                # New session/process group so killpg() hits the whole job, not the agent.
                preexec_fn=os.setsid,
            )
        except Exception:
            # Spawn failed: the master end would otherwise leak.
            os.close(master_fd)
            raise
        finally:
            # The child holds its own copy of the slave end; ours is no longer needed.
            os.close(slave_fd)

        self._process = process
        self._master_fd = master_fd
        self._write_metadata()
        self._reader = threading.Thread(target=self._reader_loop, daemon=True)
        self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
        self._reader.start()
        self._waiter.start()

    def summary(self) -> dict[str, Any]:
        """Return a JSON-safe snapshot of this session's lifecycle state."""
        with self._lock:
            return {
                "shell_id": self.shell_id,
                "cwd": self.cwd,
                "cols": self.cols,
                "rows": self.rows,
                "state": self.state,
                "started_at": self.started_at,
                "ended_at": self.ended_at,
                "exit_code": self.exit_code,
            }

    def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
        """Return up to *max_chars* of buffered output starting at *cursor*."""
        with self._lock:
            # Clamp so stale or negative cursors never raise.
            clamped_cursor = min(max(cursor, 0), len(self._output))
            output = self._output[clamped_cursor : clamped_cursor + max_chars]
            next_cursor = clamped_cursor + len(output)
            payload = self.summary()
            payload.update(
                {
                    "cursor": clamped_cursor,
                    "next_cursor": next_cursor,
                    "output": output,
                    # True when more output is already buffered past this page.
                    "truncated": next_cursor < len(self._output),
                }
            )
            return payload

    def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
        """Feed *text* to the shell's stdin, optionally terminating the line.

        Raises ``RuntimeError`` when the shell has exited or its PTY master
        is gone.
        """
        if self._process.poll() is not None:
            # Fold an already-exited process into state before rejecting the write.
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            master_fd = self._master_fd
            if master_fd is None:
                raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
            payload = text + ("\n" if append_newline else "")
            try:
                os.write(master_fd, payload.encode("utf-8"))
            except OSError as exc:
                self._refresh_process_state()
                raise RuntimeError(f"failed to write shell input: {exc}") from exc
            response = self.summary()
            response.update({"input_length": len(text), "append_newline": append_newline})
            return response

    def send_signal(self, signal_name: str) -> dict[str, Any]:
        """Deliver an allow-listed signal to the shell's process group.

        Raises ``ValueError`` for unknown names and ``RuntimeError`` when the
        shell is no longer running.
        """
        signum = SHELL_SIGNAL_MAP.get(signal_name)
        if signum is None:
            raise ValueError(f"unsupported shell signal: {signal_name}")
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            pid = self._process.pid
            try:
                # Target the group (see os.setsid in __init__) so children get it too.
                os.killpg(pid, signum)
            except ProcessLookupError as exc:
                self._refresh_process_state()
                raise RuntimeError(f"shell {self.shell_id} is not running") from exc
            response = self.summary()
            response["signal"] = signal_name
            return response

    def close(self) -> dict[str, Any]:
        """Terminate the shell (HUP, escalating to KILL) and reap its threads."""
        if self._process.poll() is None:
            try:
                os.killpg(self._process.pid, signal.SIGHUP)
            except ProcessLookupError:
                pass
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Shell ignored the hang-up: escalate to an unblockable kill.
                try:
                    os.killpg(self._process.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
                self._process.wait(timeout=5)
        else:
            self._refresh_process_state()
        # Closing the master unblocks the reader thread's os.read().
        self._close_master_fd()
        if self._reader is not None:
            self._reader.join(timeout=1)
        if self._waiter is not None:
            self._waiter.join(timeout=1)
        response = self.summary()
        response["closed"] = True
        return response

    def _reader_loop(self) -> None:
        """Drain PTY output into the in-memory buffer and the on-disk log."""
        master_fd = self._master_fd
        if master_fd is None:
            return
        while True:
            try:
                chunk = os.read(master_fd, BUFFER_SIZE)
            except OSError:
                # Raised when the slave side goes away or the master fd is closed.
                break
            if chunk == b"":
                break
            decoded = self._decoder.decode(chunk)
            if decoded == "":
                # Partial multi-byte sequence: wait for the remaining bytes.
                continue
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)
        # Flush whatever the incremental decoder still holds.
        decoded = self._decoder.decode(b"", final=True)
        if decoded != "":
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)

    def _waiter_loop(self) -> None:
        """Block until the shell exits, then record its final state."""
        exit_code = self._process.wait()
        with self._lock:
            self.state = "stopped"
            self.exit_code = exit_code
            self.ended_at = time.time()
            self._write_metadata()

    def _refresh_process_state(self) -> None:
        """Fold an already-exited process into this session's state, if any."""
        exit_code = self._process.poll()
        if exit_code is None:
            return
        with self._lock:
            # The waiter thread may have won the race; only transition once.
            if self.state == "running":
                self.state = "stopped"
                self.exit_code = exit_code
                self.ended_at = time.time()
                self._write_metadata()

    def _write_metadata(self) -> None:
        """Persist the current summary as JSON under SHELL_ROOT."""
        self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
        self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")

    def _close_master_fd(self) -> None:
        """Close and forget the PTY master, tolerating a double close."""
        with self._lock:
            master_fd = self._master_fd
            self._master_fd = None
        if master_fd is None:
            return
        try:
            os.close(master_fd)
        except OSError:
            pass
|
||||
|
||||
|
||||
def _create_shell(
    *,
    shell_id: str,
    cwd_text: str,
    cols: int,
    rows: int,
) -> GuestShellSession:
    """Start a new shell session and register it under *shell_id*.

    Raises ``RuntimeError`` when the id is already taken.  The registry lock
    is held across construction so two concurrent opens with the same id
    cannot both succeed.
    """
    _, cwd_path = _normalize_shell_cwd(cwd_text)
    with _SHELLS_LOCK:
        if shell_id in _SHELLS:
            raise RuntimeError(f"shell {shell_id!r} already exists")
        shell = GuestShellSession(
            shell_id=shell_id,
            cwd=cwd_path,
            cwd_text=cwd_text,
            cols=cols,
            rows=rows,
        )
        _SHELLS[shell_id] = shell
    return shell
|
||||
|
||||
|
||||
def _get_shell(shell_id: str) -> GuestShellSession:
    """Look up a registered shell session, failing loudly when absent."""
    try:
        with _SHELLS_LOCK:
            return _SHELLS[shell_id]
    except KeyError as exc:
        # Surface a uniform RuntimeError while keeping the original cause.
        raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
|
||||
|
||||
|
||||
def _remove_shell(shell_id: str) -> GuestShellSession:
    """Deregister and return a shell session, failing loudly when absent."""
    try:
        with _SHELLS_LOCK:
            return _SHELLS.pop(shell_id)
    except KeyError as exc:
        # Surface a uniform RuntimeError while keeping the original cause.
        raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
|
||||
|
||||
|
||||
def _require_shell_id(request: dict[str, Any]) -> str:
    """Return the mandatory, non-blank shell_id field of *request*.

    Raises ``RuntimeError`` when the field is missing or blank.  Shared by
    every shell-targeting action so the validation cannot drift.
    """
    shell_id = str(request.get("shell_id", "")).strip()
    if shell_id == "":
        raise RuntimeError("shell_id is required")
    return shell_id


def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
    """Route one agent request to its handler and return a JSON-safe response.

    *conn* is consulted only by ``extract_archive``, whose tar payload
    follows the JSON header on the wire.  Unknown actions fall through to
    the legacy one-shot ``exec`` behavior.  Handlers raise ``RuntimeError``
    on invalid input; the caller converts exceptions into error responses.
    """
    action = str(request.get("action", "exec"))
    if action == "extract_archive":
        archive_size = int(request.get("archive_size", 0))
        if archive_size < 0:
            raise RuntimeError("archive_size must not be negative")
        destination = str(request.get("destination", "/workspace"))
        payload = _read_exact(conn, archive_size)
        return _extract_archive(payload, destination)
    if action == "open_shell":
        # Validate shell_id before touching cwd so error precedence is stable.
        shell_id = _require_shell_id(request)
        cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
        session = _create_shell(
            shell_id=shell_id,
            cwd_text=cwd_text,
            cols=int(request.get("cols", 120)),
            rows=int(request.get("rows", 30)),
        )
        return session.summary()
    if action == "read_shell":
        return _get_shell(_require_shell_id(request)).read(
            cursor=int(request.get("cursor", 0)),
            max_chars=int(request.get("max_chars", 65536)),
        )
    if action == "write_shell":
        return _get_shell(_require_shell_id(request)).write(
            str(request.get("input", "")),
            append_newline=bool(request.get("append_newline", True)),
        )
    if action == "signal_shell":
        shell_id = _require_shell_id(request)
        signal_name = str(request.get("signal", "INT")).upper()
        if signal_name not in SHELL_SIGNAL_NAMES:
            raise RuntimeError(
                f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
            )
        return _get_shell(shell_id).send_signal(signal_name)
    if action == "close_shell":
        return _remove_shell(_require_shell_id(request)).close()
    # Default / "exec": run a one-shot command with a timeout.
    command = str(request.get("command", ""))
    timeout_seconds = int(request.get("timeout_seconds", 30))
    return _run_command(command, timeout_seconds)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
family = getattr(socket, "AF_VSOCK", None)
|
||||
if family is None:
|
||||
raise SystemExit("AF_VSOCK is unavailable")
|
||||
|
|
@ -192,19 +531,11 @@ def main() -> None:
|
|||
while True:
|
||||
conn, _ = server.accept()
|
||||
with conn:
|
||||
request = _read_request(conn)
|
||||
action = str(request.get("action", "exec"))
|
||||
if action == "extract_archive":
|
||||
archive_size = int(request.get("archive_size", 0))
|
||||
if archive_size < 0:
|
||||
raise RuntimeError("archive_size must not be negative")
|
||||
destination = str(request.get("destination", "/workspace"))
|
||||
payload = _read_exact(conn, archive_size)
|
||||
response = _extract_archive(payload, destination)
|
||||
else:
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
response = _run_command(command, timeout_seconds)
|
||||
try:
|
||||
request = _read_request(conn)
|
||||
response = _dispatch(request, conn)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
response = {"error": str(exc)}
|
||||
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue