Add persistent workspace shell sessions

Let agents inhabit a workspace across separate calls instead of only submitting one-shot execs.

Add workspace shell open/read/write/signal/close across the CLI, Python SDK, and MCP server, with persisted shell records, a local PTY-backed mock implementation, and guest-agent support for real Firecracker workspaces.

Mark the 2.5.0 roadmap milestone done, refresh docs/examples and the release metadata, and verify with uv lock, UV_CACHE_DIR=.uv-cache make check, and UV_CACHE_DIR=.uv-cache make dist-check.
This commit is contained in:
Thales Maciel 2026-03-12 02:31:57 -03:00
parent 2de31306b6
commit 3f8293ad24
28 changed files with 3265 additions and 81 deletions

View file

@ -130,6 +130,67 @@ class Pyro:
def logs_workspace(self, workspace_id: str) -> dict[str, Any]:
return self._manager.logs_workspace(workspace_id)
def open_shell(
self,
workspace_id: str,
*,
cwd: str = "/workspace",
cols: int = 120,
rows: int = 30,
) -> dict[str, Any]:
return self._manager.open_shell(
workspace_id,
cwd=cwd,
cols=cols,
rows=rows,
)
def read_shell(
self,
workspace_id: str,
shell_id: str,
*,
cursor: int = 0,
max_chars: int = 65536,
) -> dict[str, Any]:
return self._manager.read_shell(
workspace_id,
shell_id,
cursor=cursor,
max_chars=max_chars,
)
def write_shell(
self,
workspace_id: str,
shell_id: str,
*,
input: str,
append_newline: bool = True,
) -> dict[str, Any]:
return self._manager.write_shell(
workspace_id,
shell_id,
input_text=input,
append_newline=append_newline,
)
def signal_shell(
self,
workspace_id: str,
shell_id: str,
*,
signal_name: str = "INT",
) -> dict[str, Any]:
return self._manager.signal_shell(
workspace_id,
shell_id,
signal_name=signal_name,
)
def close_shell(self, workspace_id: str, shell_id: str) -> dict[str, Any]:
return self._manager.close_shell(workspace_id, shell_id)
def delete_workspace(self, workspace_id: str) -> dict[str, Any]:
return self._manager.delete_workspace(workspace_id)
@ -309,6 +370,64 @@ class Pyro:
"""Return persisted command history for one workspace."""
return self.logs_workspace(workspace_id)
@server.tool()
async def shell_open(
workspace_id: str,
cwd: str = "/workspace",
cols: int = 120,
rows: int = 30,
) -> dict[str, Any]:
"""Open a persistent interactive shell inside one workspace."""
return self.open_shell(workspace_id, cwd=cwd, cols=cols, rows=rows)
@server.tool()
async def shell_read(
workspace_id: str,
shell_id: str,
cursor: int = 0,
max_chars: int = 65536,
) -> dict[str, Any]:
"""Read merged PTY output from a workspace shell."""
return self.read_shell(
workspace_id,
shell_id,
cursor=cursor,
max_chars=max_chars,
)
@server.tool()
async def shell_write(
workspace_id: str,
shell_id: str,
input: str,
append_newline: bool = True,
) -> dict[str, Any]:
"""Write text input to a persistent workspace shell."""
return self.write_shell(
workspace_id,
shell_id,
input=input,
append_newline=append_newline,
)
@server.tool()
async def shell_signal(
workspace_id: str,
shell_id: str,
signal_name: str = "INT",
) -> dict[str, Any]:
"""Send a signal to the shell process group."""
return self.signal_shell(
workspace_id,
shell_id,
signal_name=signal_name,
)
@server.tool()
async def shell_close(workspace_id: str, shell_id: str) -> dict[str, Any]:
"""Close a persistent workspace shell."""
return self.close_shell(workspace_id, shell_id)
@server.tool()
async def workspace_delete(workspace_id: str) -> dict[str, Any]:
"""Delete a persistent workspace and its backing sandbox."""

View file

@ -19,6 +19,7 @@ from pyro_mcp.vm_manager import (
DEFAULT_MEM_MIB,
DEFAULT_VCPU_COUNT,
WORKSPACE_GUEST_PATH,
WORKSPACE_SHELL_SIGNAL_NAMES,
)
@ -237,6 +238,37 @@ def _print_workspace_logs_human(payload: dict[str, Any]) -> None:
print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr)
def _print_workspace_shell_summary_human(payload: dict[str, Any], *, prefix: str) -> None:
print(
f"[{prefix}] "
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
f"shell_id={str(payload.get('shell_id', 'unknown'))} "
f"state={str(payload.get('state', 'unknown'))} "
f"cwd={str(payload.get('cwd', WORKSPACE_GUEST_PATH))} "
f"cols={int(payload.get('cols', 0))} "
f"rows={int(payload.get('rows', 0))} "
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
file=sys.stderr,
flush=True,
)
def _print_workspace_shell_read_human(payload: dict[str, Any]) -> None:
_write_stream(str(payload.get("output", "")), stream=sys.stdout)
print(
"[workspace-shell-read] "
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
f"shell_id={str(payload.get('shell_id', 'unknown'))} "
f"state={str(payload.get('state', 'unknown'))} "
f"cursor={int(payload.get('cursor', 0))} "
f"next_cursor={int(payload.get('next_cursor', 0))} "
f"truncated={bool(payload.get('truncated', False))} "
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
file=sys.stderr,
flush=True,
)
class _HelpFormatter(
argparse.RawDescriptionHelpFormatter,
argparse.ArgumentDefaultsHelpFormatter,
@ -269,6 +301,7 @@ def _build_parser() -> argparse.ArgumentParser:
Need repeated commands in one workspace after that?
pyro workspace create debian:12 --seed-path ./repo
pyro workspace sync push WORKSPACE_ID ./changes
pyro workspace shell open WORKSPACE_ID
Use `pyro mcp serve` only after the CLI validation path works.
"""
@ -476,6 +509,7 @@ def _build_parser() -> argparse.ArgumentParser:
pyro workspace create debian:12 --seed-path ./repo
pyro workspace sync push WORKSPACE_ID ./repo --dest src
pyro workspace exec WORKSPACE_ID -- sh -lc 'printf "hello\\n" > note.txt'
pyro workspace shell open WORKSPACE_ID
pyro workspace logs WORKSPACE_ID
"""
),
@ -633,6 +667,191 @@ def _build_parser() -> argparse.ArgumentParser:
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_parser = workspace_subparsers.add_parser(
"shell",
help="Open and manage persistent interactive shells.",
description=(
"Open one or more persistent interactive PTY shell sessions inside a started "
"workspace."
),
epilog=dedent(
"""
Examples:
pyro workspace shell open WORKSPACE_ID
pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd'
pyro workspace shell read WORKSPACE_ID SHELL_ID
pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT
pyro workspace shell close WORKSPACE_ID SHELL_ID
Use `workspace exec` for one-shot commands. Use `workspace shell` when you need
an interactive process that keeps its state between calls.
"""
),
formatter_class=_HelpFormatter,
)
workspace_shell_subparsers = workspace_shell_parser.add_subparsers(
dest="workspace_shell_command",
required=True,
metavar="SHELL",
)
workspace_shell_open_parser = workspace_shell_subparsers.add_parser(
"open",
help="Open a persistent interactive shell.",
description="Open a new PTY shell inside a started workspace.",
epilog="Example:\n pyro workspace shell open WORKSPACE_ID --cwd src",
formatter_class=_HelpFormatter,
)
workspace_shell_open_parser.add_argument(
"workspace_id",
metavar="WORKSPACE_ID",
help="Persistent workspace identifier.",
)
workspace_shell_open_parser.add_argument(
"--cwd",
default=WORKSPACE_GUEST_PATH,
help="Shell working directory. Relative values resolve inside `/workspace`.",
)
workspace_shell_open_parser.add_argument(
"--cols",
type=int,
default=120,
help="Shell terminal width in columns.",
)
workspace_shell_open_parser.add_argument(
"--rows",
type=int,
default=30,
help="Shell terminal height in rows.",
)
workspace_shell_open_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_read_parser = workspace_shell_subparsers.add_parser(
"read",
help="Read merged PTY output from a shell.",
description="Read merged text output from a persistent workspace shell.",
epilog=dedent(
"""
Example:
pyro workspace shell read WORKSPACE_ID SHELL_ID --cursor 0
Shell output is written to stdout. The read summary is written to stderr.
Use --json for a deterministic structured response.
"""
),
formatter_class=_HelpFormatter,
)
workspace_shell_read_parser.add_argument(
"workspace_id",
metavar="WORKSPACE_ID",
help="Persistent workspace identifier.",
)
workspace_shell_read_parser.add_argument(
"shell_id",
metavar="SHELL_ID",
help="Persistent shell identifier returned by `workspace shell open`.",
)
workspace_shell_read_parser.add_argument(
"--cursor",
type=int,
default=0,
help="Character offset into the merged shell output buffer.",
)
workspace_shell_read_parser.add_argument(
"--max-chars",
type=int,
default=65536,
help="Maximum number of characters to return from the current cursor position.",
)
workspace_shell_read_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_write_parser = workspace_shell_subparsers.add_parser(
"write",
help="Write text input into a shell.",
description="Write text input into a persistent workspace shell.",
epilog="Example:\n pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd'",
formatter_class=_HelpFormatter,
)
workspace_shell_write_parser.add_argument(
"workspace_id",
metavar="WORKSPACE_ID",
help="Persistent workspace identifier.",
)
workspace_shell_write_parser.add_argument(
"shell_id",
metavar="SHELL_ID",
help="Persistent shell identifier returned by `workspace shell open`.",
)
workspace_shell_write_parser.add_argument(
"--input",
required=True,
help="Text to send to the shell.",
)
workspace_shell_write_parser.add_argument(
"--no-newline",
action="store_true",
help="Do not append a trailing newline after the provided input.",
)
workspace_shell_write_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_signal_parser = workspace_shell_subparsers.add_parser(
"signal",
help="Send a signal to a shell process group.",
description="Send a control signal to a persistent workspace shell.",
epilog="Example:\n pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT",
formatter_class=_HelpFormatter,
)
workspace_shell_signal_parser.add_argument(
"workspace_id",
metavar="WORKSPACE_ID",
help="Persistent workspace identifier.",
)
workspace_shell_signal_parser.add_argument(
"shell_id",
metavar="SHELL_ID",
help="Persistent shell identifier returned by `workspace shell open`.",
)
workspace_shell_signal_parser.add_argument(
"--signal",
default="INT",
choices=WORKSPACE_SHELL_SIGNAL_NAMES,
help="Signal name to send to the shell process group.",
)
workspace_shell_signal_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_close_parser = workspace_shell_subparsers.add_parser(
"close",
help="Close a persistent shell.",
description="Close a persistent workspace shell and release its PTY state.",
epilog="Example:\n pyro workspace shell close WORKSPACE_ID SHELL_ID",
formatter_class=_HelpFormatter,
)
workspace_shell_close_parser.add_argument(
"workspace_id",
metavar="WORKSPACE_ID",
help="Persistent workspace identifier.",
)
workspace_shell_close_parser.add_argument(
"shell_id",
metavar="SHELL_ID",
help="Persistent shell identifier returned by `workspace shell open`.",
)
workspace_shell_close_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_status_parser = workspace_subparsers.add_parser(
"status",
help="Inspect one workspace.",
@ -929,6 +1148,99 @@ def main() -> None:
raise SystemExit(1) from exc
_print_workspace_sync_human(payload)
return
if args.workspace_command == "shell":
if args.workspace_shell_command == "open":
try:
payload = pyro.open_shell(
args.workspace_id,
cwd=args.cwd,
cols=args.cols,
rows=args.rows,
)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-open")
return
if args.workspace_shell_command == "read":
try:
payload = pyro.read_shell(
args.workspace_id,
args.shell_id,
cursor=args.cursor,
max_chars=args.max_chars,
)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_shell_read_human(payload)
return
if args.workspace_shell_command == "write":
try:
payload = pyro.write_shell(
args.workspace_id,
args.shell_id,
input=args.input,
append_newline=not bool(args.no_newline),
)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-write")
return
if args.workspace_shell_command == "signal":
try:
payload = pyro.signal_shell(
args.workspace_id,
args.shell_id,
signal_name=args.signal,
)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_shell_summary_human(
payload,
prefix="workspace-shell-signal",
)
return
if args.workspace_shell_command == "close":
try:
payload = pyro.close_shell(args.workspace_id, args.shell_id)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-close")
return
if args.workspace_command == "status":
payload = pyro.status_workspace(args.workspace_id)
if bool(args.json):

View file

@ -5,7 +5,16 @@ from __future__ import annotations
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "workspace")
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status", "sync")
PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = (
"create",
"delete",
"exec",
"logs",
"shell",
"status",
"sync",
)
PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS = ("close", "open", "read", "signal", "write")
PUBLIC_CLI_WORKSPACE_SYNC_SUBCOMMANDS = ("push",)
PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
"--vcpu-count",
@ -16,6 +25,11 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
"--seed-path",
"--json",
)
PUBLIC_CLI_WORKSPACE_SHELL_OPEN_FLAGS = ("--cwd", "--cols", "--rows", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS = ("--cursor", "--max-chars", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS = ("--input", "--no-newline", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS = ("--signal", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS = ("--dest", "--json")
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
@ -28,6 +42,7 @@ PUBLIC_CLI_RUN_FLAGS = (
)
PUBLIC_SDK_METHODS = (
"close_shell",
"create_server",
"create_vm",
"create_workspace",
@ -39,18 +54,27 @@ PUBLIC_SDK_METHODS = (
"list_environments",
"logs_workspace",
"network_info_vm",
"open_shell",
"prune_environments",
"pull_environment",
"push_workspace_sync",
"read_shell",
"reap_expired",
"run_in_vm",
"signal_shell",
"start_vm",
"status_vm",
"status_workspace",
"stop_vm",
"write_shell",
)
PUBLIC_MCP_TOOLS = (
"shell_close",
"shell_open",
"shell_read",
"shell_signal",
"shell_write",
"vm_create",
"vm_delete",
"vm_exec",

View file

@ -1,14 +1,21 @@
#!/usr/bin/env python3
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
"""Guest-side exec, workspace import, and interactive shell agent."""
from __future__ import annotations
import codecs
import fcntl
import io
import json
import os
import pty
import signal
import socket
import struct
import subprocess
import tarfile
import termios
import threading
import time
from pathlib import Path, PurePosixPath
from typing import Any
@ -16,6 +23,17 @@ from typing import Any
PORT = 5005
BUFFER_SIZE = 65536
WORKSPACE_ROOT = PurePosixPath("/workspace")
SHELL_ROOT = Path("/run/pyro-shells")
SHELL_SIGNAL_MAP = {
"HUP": signal.SIGHUP,
"INT": signal.SIGINT,
"TERM": signal.SIGTERM,
"KILL": signal.SIGKILL,
}
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)
_SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
def _read_request(conn: socket.socket) -> dict[str, Any]:
@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
suffix = normalized.relative_to(WORKSPACE_ROOT)
host_path = Path("/workspace")
if str(suffix) not in {"", "."}:
host_path = host_path / str(suffix)
host_path = host_path.joinpath(*suffix.parts)
return normalized, host_path
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
normalized, host_path = _normalize_destination(cwd)
return str(normalized), host_path
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
_, destination_root = _normalize_destination(destination)
normalized_destination, destination_root = _normalize_destination(destination)
destination_root.mkdir(parents=True, exist_ok=True)
bytes_written = 0
entry_count = 0
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_member_name(member.name)
target_path = destination_root / str(member_name)
target_path = destination_root.joinpath(*member_name.parts)
entry_count += 1
_ensure_no_symlink_parents(destination_root, target_path, member.name)
if member.isdir():
if target_path.exists() and not target_path.is_dir():
if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
target_path.mkdir(parents=True, exist_ok=True)
continue
@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
)
raise RuntimeError(f"unsupported archive member type: {member.name}")
return {
"destination": destination,
"destination": str(normalized_destination),
"entry_count": entry_count,
"bytes_written": bytes_written,
}
@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
}
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
class GuestShellSession:
"""In-guest PTY-backed interactive shell session."""
def __init__(
self,
*,
shell_id: str,
cwd: Path,
cwd_text: str,
cols: int,
rows: int,
) -> None:
self.shell_id = shell_id
self.cwd = cwd_text
self.cols = cols
self.rows = rows
self.started_at = time.time()
self.ended_at: float | None = None
self.exit_code: int | None = None
self.state = "running"
self._lock = threading.RLock()
self._output = ""
self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
self._log_path = SHELL_ROOT / f"{shell_id}.log"
self._master_fd: int | None = None
master_fd, slave_fd = pty.openpty()
try:
_set_pty_size(slave_fd, rows, cols)
env = os.environ.copy()
env.update(
{
"TERM": env.get("TERM", "xterm-256color"),
"PS1": "pyro$ ",
"PROMPT_COMMAND": "",
}
)
process = subprocess.Popen( # noqa: S603
["/bin/bash", "--noprofile", "--norc", "-i"],
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
cwd=str(cwd),
env=env,
text=False,
close_fds=True,
preexec_fn=os.setsid,
)
except Exception:
os.close(master_fd)
raise
finally:
os.close(slave_fd)
self._process = process
self._master_fd = master_fd
self._write_metadata()
self._reader = threading.Thread(target=self._reader_loop, daemon=True)
self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
self._reader.start()
self._waiter.start()
def summary(self) -> dict[str, Any]:
with self._lock:
return {
"shell_id": self.shell_id,
"cwd": self.cwd,
"cols": self.cols,
"rows": self.rows,
"state": self.state,
"started_at": self.started_at,
"ended_at": self.ended_at,
"exit_code": self.exit_code,
}
def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
with self._lock:
clamped_cursor = min(max(cursor, 0), len(self._output))
output = self._output[clamped_cursor : clamped_cursor + max_chars]
next_cursor = clamped_cursor + len(output)
payload = self.summary()
payload.update(
{
"cursor": clamped_cursor,
"next_cursor": next_cursor,
"output": output,
"truncated": next_cursor < len(self._output),
}
)
return payload
def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
if self._process.poll() is not None:
self._refresh_process_state()
with self._lock:
if self.state != "running":
raise RuntimeError(f"shell {self.shell_id} is not running")
master_fd = self._master_fd
if master_fd is None:
raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
payload = text + ("\n" if append_newline else "")
try:
os.write(master_fd, payload.encode("utf-8"))
except OSError as exc:
self._refresh_process_state()
raise RuntimeError(f"failed to write shell input: {exc}") from exc
response = self.summary()
response.update({"input_length": len(text), "append_newline": append_newline})
return response
def send_signal(self, signal_name: str) -> dict[str, Any]:
signum = SHELL_SIGNAL_MAP.get(signal_name)
if signum is None:
raise ValueError(f"unsupported shell signal: {signal_name}")
if self._process.poll() is not None:
self._refresh_process_state()
with self._lock:
if self.state != "running":
raise RuntimeError(f"shell {self.shell_id} is not running")
pid = self._process.pid
try:
os.killpg(pid, signum)
except ProcessLookupError as exc:
self._refresh_process_state()
raise RuntimeError(f"shell {self.shell_id} is not running") from exc
response = self.summary()
response["signal"] = signal_name
return response
def close(self) -> dict[str, Any]:
if self._process.poll() is None:
try:
os.killpg(self._process.pid, signal.SIGHUP)
except ProcessLookupError:
pass
try:
self._process.wait(timeout=5)
except subprocess.TimeoutExpired:
try:
os.killpg(self._process.pid, signal.SIGKILL)
except ProcessLookupError:
pass
self._process.wait(timeout=5)
else:
self._refresh_process_state()
self._close_master_fd()
if self._reader is not None:
self._reader.join(timeout=1)
if self._waiter is not None:
self._waiter.join(timeout=1)
response = self.summary()
response["closed"] = True
return response
def _reader_loop(self) -> None:
master_fd = self._master_fd
if master_fd is None:
return
while True:
try:
chunk = os.read(master_fd, BUFFER_SIZE)
except OSError:
break
if chunk == b"":
break
decoded = self._decoder.decode(chunk)
if decoded == "":
continue
with self._lock:
self._output += decoded
with self._log_path.open("a", encoding="utf-8") as handle:
handle.write(decoded)
decoded = self._decoder.decode(b"", final=True)
if decoded != "":
with self._lock:
self._output += decoded
with self._log_path.open("a", encoding="utf-8") as handle:
handle.write(decoded)
def _waiter_loop(self) -> None:
exit_code = self._process.wait()
with self._lock:
self.state = "stopped"
self.exit_code = exit_code
self.ended_at = time.time()
self._write_metadata()
def _refresh_process_state(self) -> None:
exit_code = self._process.poll()
if exit_code is None:
return
with self._lock:
if self.state == "running":
self.state = "stopped"
self.exit_code = exit_code
self.ended_at = time.time()
self._write_metadata()
def _write_metadata(self) -> None:
self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")
def _close_master_fd(self) -> None:
with self._lock:
master_fd = self._master_fd
self._master_fd = None
if master_fd is None:
return
try:
os.close(master_fd)
except OSError:
pass
def _create_shell(
*,
shell_id: str,
cwd_text: str,
cols: int,
rows: int,
) -> GuestShellSession:
_, cwd_path = _normalize_shell_cwd(cwd_text)
with _SHELLS_LOCK:
if shell_id in _SHELLS:
raise RuntimeError(f"shell {shell_id!r} already exists")
session = GuestShellSession(
shell_id=shell_id,
cwd=cwd_path,
cwd_text=cwd_text,
cols=cols,
rows=rows,
)
_SHELLS[shell_id] = session
return session
def _get_shell(shell_id: str) -> GuestShellSession:
with _SHELLS_LOCK:
try:
return _SHELLS[shell_id]
except KeyError as exc:
raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
def _remove_shell(shell_id: str) -> GuestShellSession:
with _SHELLS_LOCK:
try:
return _SHELLS.pop(shell_id)
except KeyError as exc:
raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
return _extract_archive(payload, destination)
if action == "open_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
session = _create_shell(
shell_id=shell_id,
cwd_text=cwd_text,
cols=int(request.get("cols", 120)),
rows=int(request.get("rows", 30)),
)
return session.summary()
if action == "read_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _get_shell(shell_id).read(
cursor=int(request.get("cursor", 0)),
max_chars=int(request.get("max_chars", 65536)),
)
if action == "write_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _get_shell(shell_id).write(
str(request.get("input", "")),
append_newline=bool(request.get("append_newline", True)),
)
if action == "signal_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
signal_name = str(request.get("signal", "INT")).upper()
if signal_name not in SHELL_SIGNAL_NAMES:
raise RuntimeError(
f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
)
return _get_shell(shell_id).send_signal(signal_name)
if action == "close_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
return _remove_shell(shell_id).close()
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
return _run_command(command, timeout_seconds)
def main() -> None:
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
family = getattr(socket, "AF_VSOCK", None)
if family is None:
raise SystemExit("AF_VSOCK is unavailable")
@ -192,19 +531,11 @@ def main() -> None:
while True:
conn, _ = server.accept()
with conn:
request = _read_request(conn)
action = str(request.get("action", "exec"))
if action == "extract_archive":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
response = _extract_archive(payload, destination)
else:
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
response = _run_command(command, timeout_seconds)
try:
request = _read_request(conn)
response = _dispatch(request, conn)
except Exception as exc: # noqa: BLE001
response = {"error": str(exc)}
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))

View file

@ -25,7 +25,7 @@
"guest": {
"agent": {
"path": "guest/pyro_guest_agent.py",
"sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d"
"sha256": "07adf6269551447dbea8c236f91499ea1479212a3f084c5402a656f5f5cc5892"
}
},
"platform": "linux-x86_64",

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.4.0"
DEFAULT_CATALOG_VERSION = "2.5.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -39,6 +39,26 @@ class GuestArchiveResponse:
bytes_written: int
@dataclass(frozen=True)
class GuestShellSummary:
shell_id: str
cwd: str
cols: int
rows: int
state: str
started_at: float
ended_at: float | None
exit_code: int | None
@dataclass(frozen=True)
class GuestShellReadResponse(GuestShellSummary):
cursor: int
next_cursor: int
output: str
truncated: bool
class VsockExecClient:
"""Minimal JSON-over-stream client for a guest exec agent."""
@ -54,19 +74,17 @@ class VsockExecClient:
*,
uds_path: str | None = None,
) -> GuestExecResponse:
request = {
"command": command,
"timeout_seconds": timeout_seconds,
}
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
try:
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
payload = self._recv_json_payload(sock)
finally:
sock.close()
if not isinstance(payload, dict):
raise RuntimeError("guest exec response must be a JSON object")
payload = self._request_json(
guest_cid,
port,
{
"command": command,
"timeout_seconds": timeout_seconds,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest exec response must be a JSON object",
)
return GuestExecResponse(
stdout=str(payload.get("stdout", "")),
stderr=str(payload.get("stderr", "")),
@ -101,12 +119,198 @@ class VsockExecClient:
if not isinstance(payload, dict):
raise RuntimeError("guest archive response must be a JSON object")
error = payload.get("error")
if error is not None:
raise RuntimeError(str(error))
return GuestArchiveResponse(
destination=str(payload.get("destination", destination)),
entry_count=int(payload.get("entry_count", 0)),
bytes_written=int(payload.get("bytes_written", 0)),
)
def open_shell(
self,
guest_cid: int,
port: int,
*,
shell_id: str,
cwd: str,
cols: int,
rows: int,
timeout_seconds: int = 30,
uds_path: str | None = None,
) -> GuestShellSummary:
payload = self._request_json(
guest_cid,
port,
{
"action": "open_shell",
"shell_id": shell_id,
"cwd": cwd,
"cols": cols,
"rows": rows,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest shell open response must be a JSON object",
)
return self._shell_summary_from_payload(payload)
def read_shell(
self,
guest_cid: int,
port: int,
*,
shell_id: str,
cursor: int,
max_chars: int,
timeout_seconds: int = 30,
uds_path: str | None = None,
) -> GuestShellReadResponse:
payload = self._request_json(
guest_cid,
port,
{
"action": "read_shell",
"shell_id": shell_id,
"cursor": cursor,
"max_chars": max_chars,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest shell read response must be a JSON object",
)
summary = self._shell_summary_from_payload(payload)
return GuestShellReadResponse(
shell_id=summary.shell_id,
cwd=summary.cwd,
cols=summary.cols,
rows=summary.rows,
state=summary.state,
started_at=summary.started_at,
ended_at=summary.ended_at,
exit_code=summary.exit_code,
cursor=int(payload.get("cursor", cursor)),
next_cursor=int(payload.get("next_cursor", cursor)),
output=str(payload.get("output", "")),
truncated=bool(payload.get("truncated", False)),
)
def write_shell(
self,
guest_cid: int,
port: int,
*,
shell_id: str,
input_text: str,
append_newline: bool,
timeout_seconds: int = 30,
uds_path: str | None = None,
) -> dict[str, Any]:
payload = self._request_json(
guest_cid,
port,
{
"action": "write_shell",
"shell_id": shell_id,
"input": input_text,
"append_newline": append_newline,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest shell write response must be a JSON object",
)
self._shell_summary_from_payload(payload)
return payload
def signal_shell(
self,
guest_cid: int,
port: int,
*,
shell_id: str,
signal_name: str,
timeout_seconds: int = 30,
uds_path: str | None = None,
) -> dict[str, Any]:
payload = self._request_json(
guest_cid,
port,
{
"action": "signal_shell",
"shell_id": shell_id,
"signal": signal_name,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest shell signal response must be a JSON object",
)
self._shell_summary_from_payload(payload)
return payload
def close_shell(
self,
guest_cid: int,
port: int,
*,
shell_id: str,
timeout_seconds: int = 30,
uds_path: str | None = None,
) -> dict[str, Any]:
payload = self._request_json(
guest_cid,
port,
{
"action": "close_shell",
"shell_id": shell_id,
},
timeout_seconds=timeout_seconds,
uds_path=uds_path,
error_message="guest shell close response must be a JSON object",
)
self._shell_summary_from_payload(payload)
return payload
def _request_json(
self,
guest_cid: int,
port: int,
request: dict[str, Any],
*,
timeout_seconds: int,
uds_path: str | None,
error_message: str,
) -> dict[str, Any]:
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
try:
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
payload = self._recv_json_payload(sock)
finally:
sock.close()
if not isinstance(payload, dict):
raise RuntimeError(error_message)
error = payload.get("error")
if error is not None:
raise RuntimeError(str(error))
return payload
@staticmethod
def _shell_summary_from_payload(payload: dict[str, Any]) -> GuestShellSummary:
return GuestShellSummary(
shell_id=str(payload.get("shell_id", "")),
cwd=str(payload.get("cwd", "/workspace")),
cols=int(payload.get("cols", 0)),
rows=int(payload.get("rows", 0)),
state=str(payload.get("state", "stopped")),
started_at=float(payload.get("started_at", 0.0)),
ended_at=(
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
),
exit_code=(
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
),
)
def _connect(
self,
guest_cid: int,

View file

@ -27,8 +27,15 @@ from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_en
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
from pyro_mcp.workspace_shells import (
create_local_shell,
get_local_shell,
remove_local_shell,
shell_signal_names,
)
VmState = Literal["created", "started", "stopped"]
WorkspaceShellState = Literal["running", "stopped"]
DEFAULT_VCPU_COUNT = 1
DEFAULT_MEM_MIB = 1024
@ -36,13 +43,18 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
WORKSPACE_LAYOUT_VERSION = 2
WORKSPACE_LAYOUT_VERSION = 3
WORKSPACE_DIRNAME = "workspace"
WORKSPACE_COMMANDS_DIRNAME = "commands"
WORKSPACE_SHELLS_DIRNAME = "shells"
WORKSPACE_RUNTIME_DIRNAME = "runtime"
WORKSPACE_GUEST_PATH = "/workspace"
WORKSPACE_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
WORKSPACE_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60
DEFAULT_SHELL_COLS = 120
DEFAULT_SHELL_ROWS = 30
DEFAULT_SHELL_MAX_CHARS = 65536
WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names()
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
@ -183,6 +195,58 @@ class WorkspaceRecord:
)
@dataclass
class WorkspaceShellRecord:
"""Persistent shell metadata stored on disk per workspace."""
workspace_id: str
shell_id: str
cwd: str
cols: int
rows: int
state: WorkspaceShellState
started_at: float
ended_at: float | None = None
exit_code: int | None = None
execution_mode: str = "pending"
metadata: dict[str, str] = field(default_factory=dict)
def to_payload(self) -> dict[str, Any]:
return {
"layout_version": WORKSPACE_LAYOUT_VERSION,
"workspace_id": self.workspace_id,
"shell_id": self.shell_id,
"cwd": self.cwd,
"cols": self.cols,
"rows": self.rows,
"state": self.state,
"started_at": self.started_at,
"ended_at": self.ended_at,
"exit_code": self.exit_code,
"execution_mode": self.execution_mode,
"metadata": dict(self.metadata),
}
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> WorkspaceShellRecord:
return cls(
workspace_id=str(payload["workspace_id"]),
shell_id=str(payload["shell_id"]),
cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)),
cols=int(payload.get("cols", DEFAULT_SHELL_COLS)),
rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)),
state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))),
started_at=float(payload.get("started_at", 0.0)),
ended_at=(
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
),
exit_code=(
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
),
execution_mode=str(payload.get("execution_mode", "pending")),
metadata=_string_dict(payload.get("metadata")),
)
@dataclass(frozen=True)
class PreparedWorkspaceSeed:
"""Prepared host-side seed archive plus metadata."""
@ -610,6 +674,59 @@ class VmBackend:
) -> dict[str, Any]:
raise NotImplementedError
def open_shell( # pragma: no cover
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cwd: str,
cols: int,
rows: int,
) -> dict[str, Any]:
raise NotImplementedError
def read_shell( # pragma: no cover
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cursor: int,
max_chars: int,
) -> dict[str, Any]:
raise NotImplementedError
def write_shell( # pragma: no cover
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
input_text: str,
append_newline: bool,
) -> dict[str, Any]:
raise NotImplementedError
def signal_shell( # pragma: no cover
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
signal_name: str,
) -> dict[str, Any]:
raise NotImplementedError
def close_shell( # pragma: no cover
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
) -> dict[str, Any]:
raise NotImplementedError
class MockBackend(VmBackend):
"""Host-process backend used for development and testability."""
@ -651,6 +768,87 @@ class MockBackend(VmBackend):
destination=destination,
)
def open_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cwd: str,
cols: int,
rows: int,
) -> dict[str, Any]:
session = create_local_shell(
workspace_id=workspace_id,
shell_id=shell_id,
cwd=_workspace_host_destination(_instance_workspace_host_dir(instance), cwd),
display_cwd=cwd,
cols=cols,
rows=rows,
)
summary = session.summary()
summary["execution_mode"] = "host_compat"
return summary
def read_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cursor: int,
max_chars: int,
) -> dict[str, Any]:
del instance
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
payload = session.read(cursor=cursor, max_chars=max_chars)
payload["execution_mode"] = "host_compat"
return payload
def write_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
input_text: str,
append_newline: bool,
) -> dict[str, Any]:
del instance
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
payload = session.write(input_text, append_newline=append_newline)
payload["execution_mode"] = "host_compat"
return payload
def signal_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
signal_name: str,
) -> dict[str, Any]:
del instance
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
payload = session.send_signal(signal_name)
payload["execution_mode"] = "host_compat"
return payload
def close_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
) -> dict[str, Any]:
del instance
session = remove_local_shell(workspace_id=workspace_id, shell_id=shell_id)
if session is None:
raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}")
payload = session.close()
payload["execution_mode"] = "host_compat"
return payload
class FirecrackerBackend(VmBackend): # pragma: no cover
"""Host-gated backend that validates Firecracker prerequisites."""
@ -888,6 +1086,144 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
destination=destination,
)
def open_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cwd: str,
cols: int,
rows: int,
) -> dict[str, Any]:
del workspace_id
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
response = self._guest_exec_client.open_shell(
guest_cid,
port,
shell_id=shell_id,
cwd=cwd,
cols=cols,
rows=rows,
uds_path=uds_path,
)
return {
"shell_id": response.shell_id or shell_id,
"cwd": response.cwd,
"cols": response.cols,
"rows": response.rows,
"state": response.state,
"started_at": response.started_at,
"ended_at": response.ended_at,
"exit_code": response.exit_code,
"execution_mode": instance.metadata.get("execution_mode", "pending"),
}
def read_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
cursor: int,
max_chars: int,
) -> dict[str, Any]:
del workspace_id
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
response = self._guest_exec_client.read_shell(
guest_cid,
port,
shell_id=shell_id,
cursor=cursor,
max_chars=max_chars,
uds_path=uds_path,
)
return {
"shell_id": response.shell_id,
"cwd": response.cwd,
"cols": response.cols,
"rows": response.rows,
"state": response.state,
"started_at": response.started_at,
"ended_at": response.ended_at,
"exit_code": response.exit_code,
"cursor": response.cursor,
"next_cursor": response.next_cursor,
"output": response.output,
"truncated": response.truncated,
"execution_mode": instance.metadata.get("execution_mode", "pending"),
}
def write_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
input_text: str,
append_newline: bool,
) -> dict[str, Any]:
del workspace_id
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
payload = self._guest_exec_client.write_shell(
guest_cid,
port,
shell_id=shell_id,
input_text=input_text,
append_newline=append_newline,
uds_path=uds_path,
)
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
return payload
def signal_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
signal_name: str,
) -> dict[str, Any]:
del workspace_id
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
payload = self._guest_exec_client.signal_shell(
guest_cid,
port,
shell_id=shell_id,
signal_name=signal_name,
uds_path=uds_path,
)
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
return payload
def close_shell(
self,
instance: VmInstance,
*,
workspace_id: str,
shell_id: str,
) -> dict[str, Any]:
del workspace_id
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
uds_path = instance.metadata.get("guest_exec_uds_path")
payload = self._guest_exec_client.close_shell(
guest_cid,
port,
shell_id=shell_id,
uds_path=uds_path,
)
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
return payload
class VmManager:
"""In-process lifecycle manager for ephemeral VM environments and workspaces."""
@ -1151,9 +1487,11 @@ class VmManager:
runtime_dir = self._workspace_runtime_dir(workspace_id)
host_workspace_dir = self._workspace_host_dir(workspace_id)
commands_dir = self._workspace_commands_dir(workspace_id)
shells_dir = self._workspace_shells_dir(workspace_id)
workspace_dir.mkdir(parents=True, exist_ok=False)
host_workspace_dir.mkdir(parents=True, exist_ok=True)
commands_dir.mkdir(parents=True, exist_ok=True)
shells_dir.mkdir(parents=True, exist_ok=True)
instance = VmInstance(
vm_id=workspace_id,
environment=environment,
@ -1179,11 +1517,8 @@ class VmManager:
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
self._backend.create(instance)
if (
prepared_seed.archive_path is not None
and self._runtime_capabilities.supports_guest_exec
):
self._ensure_workspace_guest_seed_support(instance)
if self._runtime_capabilities.supports_guest_exec:
self._ensure_workspace_guest_agent_support(instance)
with self._lock:
self._start_instance_locked(instance)
self._require_guest_exec_or_opt_in(instance)
@ -1332,6 +1667,208 @@ class VmManager:
"cwd": WORKSPACE_GUEST_PATH,
}
def open_shell(
self,
workspace_id: str,
*,
cwd: str = WORKSPACE_GUEST_PATH,
cols: int = DEFAULT_SHELL_COLS,
rows: int = DEFAULT_SHELL_ROWS,
) -> dict[str, Any]:
if cols <= 0:
raise ValueError("cols must be positive")
if rows <= 0:
raise ValueError("rows must be positive")
normalized_cwd, _ = _normalize_workspace_destination(cwd)
shell_id = uuid.uuid4().hex[:12]
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_shell_locked(workspace)
payload = self._backend.open_shell(
instance,
workspace_id=workspace_id,
shell_id=shell_id,
cwd=normalized_cwd,
cols=cols,
rows=rows,
)
shell = self._workspace_shell_record_from_payload(
workspace_id=workspace_id,
shell_id=shell_id,
payload=payload,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._save_workspace_shell_locked(shell)
return self._serialize_workspace_shell(shell)
def read_shell(
self,
workspace_id: str,
shell_id: str,
*,
cursor: int = 0,
max_chars: int = DEFAULT_SHELL_MAX_CHARS,
) -> dict[str, Any]:
if cursor < 0:
raise ValueError("cursor must not be negative")
if max_chars <= 0:
raise ValueError("max_chars must be positive")
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_shell_locked(workspace)
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
payload = self._backend.read_shell(
instance,
workspace_id=workspace_id,
shell_id=shell_id,
cursor=cursor,
max_chars=max_chars,
)
updated_shell = self._workspace_shell_record_from_payload(
workspace_id=workspace_id,
shell_id=shell_id,
payload=payload,
metadata=shell.metadata,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._save_workspace_shell_locked(updated_shell)
response = self._serialize_workspace_shell(updated_shell)
response.update(
{
"cursor": int(payload.get("cursor", cursor)),
"next_cursor": int(payload.get("next_cursor", cursor)),
"output": str(payload.get("output", "")),
"truncated": bool(payload.get("truncated", False)),
}
)
return response
def write_shell(
self,
workspace_id: str,
shell_id: str,
*,
input_text: str,
append_newline: bool = True,
) -> dict[str, Any]:
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_shell_locked(workspace)
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
payload = self._backend.write_shell(
instance,
workspace_id=workspace_id,
shell_id=shell_id,
input_text=input_text,
append_newline=append_newline,
)
updated_shell = self._workspace_shell_record_from_payload(
workspace_id=workspace_id,
shell_id=shell_id,
payload=payload,
metadata=shell.metadata,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._save_workspace_shell_locked(updated_shell)
response = self._serialize_workspace_shell(updated_shell)
response.update(
{
"input_length": int(payload.get("input_length", len(input_text))),
"append_newline": bool(payload.get("append_newline", append_newline)),
}
)
return response
def signal_shell(
self,
workspace_id: str,
shell_id: str,
*,
signal_name: str = "INT",
) -> dict[str, Any]:
normalized_signal = signal_name.upper()
if normalized_signal not in WORKSPACE_SHELL_SIGNAL_NAMES:
raise ValueError(
f"signal_name must be one of: {', '.join(WORKSPACE_SHELL_SIGNAL_NAMES)}"
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_shell_locked(workspace)
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
payload = self._backend.signal_shell(
instance,
workspace_id=workspace_id,
shell_id=shell_id,
signal_name=normalized_signal,
)
updated_shell = self._workspace_shell_record_from_payload(
workspace_id=workspace_id,
shell_id=shell_id,
payload=payload,
metadata=shell.metadata,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._save_workspace_shell_locked(updated_shell)
response = self._serialize_workspace_shell(updated_shell)
response["signal"] = str(payload.get("signal", normalized_signal))
return response
def close_shell(
self,
workspace_id: str,
shell_id: str,
) -> dict[str, Any]:
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_shell_locked(workspace)
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
payload = self._backend.close_shell(
instance,
workspace_id=workspace_id,
shell_id=shell_id,
)
closed_shell = self._workspace_shell_record_from_payload(
workspace_id=workspace_id,
shell_id=shell_id,
payload=payload,
metadata=shell.metadata,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._delete_workspace_shell_locked(workspace_id, shell_id)
response = self._serialize_workspace_shell(closed_shell)
response["closed"] = bool(payload.get("closed", True))
return response
def status_workspace(self, workspace_id: str) -> dict[str, Any]:
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
@ -1364,6 +1901,7 @@ class VmManager:
instance = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._close_workspace_shells_locked(workspace, instance)
if workspace.state == "started":
self._backend.stop(instance)
workspace.state = "stopped"
@ -1423,6 +1961,20 @@ class VmManager:
"metadata": workspace.metadata,
}
def _serialize_workspace_shell(self, shell: WorkspaceShellRecord) -> dict[str, Any]:
return {
"workspace_id": shell.workspace_id,
"shell_id": shell.shell_id,
"cwd": shell.cwd,
"cols": shell.cols,
"rows": shell.rows,
"state": shell.state,
"started_at": shell.started_at,
"ended_at": shell.ended_at,
"exit_code": shell.exit_code,
"execution_mode": shell.execution_mode,
}
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
return
@ -1445,6 +1997,19 @@ class VmManager:
"host execution."
)
def _require_workspace_shell_support(self, instance: VmInstance) -> None:
if self._backend_name == "mock":
return
if self._runtime_capabilities.supports_guest_exec:
return
reason = self._runtime_capabilities.reason or (
"runtime does not support guest interactive shell sessions"
)
raise RuntimeError(
"interactive shells require guest execution and are unavailable for this "
f"workspace: {reason}"
)
def _get_instance_locked(self, vm_id: str) -> VmInstance:
try:
return self._instances[vm_id]
@ -1552,14 +2117,14 @@ class VmManager:
bytes_written=bytes_written,
)
def _ensure_workspace_guest_seed_support(self, instance: VmInstance) -> None:
def _ensure_workspace_guest_agent_support(self, instance: VmInstance) -> None:
if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None:
raise RuntimeError(
"runtime bundle does not provide a guest agent for workspace seeding"
"runtime bundle does not provide a guest agent for workspace operations"
)
rootfs_image = instance.metadata.get("rootfs_image")
if rootfs_image is None or rootfs_image == "":
raise RuntimeError("workspace rootfs image is unavailable for guest seeding")
raise RuntimeError("workspace rootfs image is unavailable for guest operations")
_patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path)
def _workspace_dir(self, workspace_id: str) -> Path:
@ -1574,9 +2139,15 @@ class VmManager:
def _workspace_commands_dir(self, workspace_id: str) -> Path:
return self._workspace_dir(workspace_id) / WORKSPACE_COMMANDS_DIRNAME
def _workspace_shells_dir(self, workspace_id: str) -> Path:
return self._workspace_dir(workspace_id) / WORKSPACE_SHELLS_DIRNAME
def _workspace_metadata_path(self, workspace_id: str) -> Path:
return self._workspace_dir(workspace_id) / "workspace.json"
def _workspace_shell_record_path(self, workspace_id: str, shell_id: str) -> Path:
return self._workspace_shells_dir(workspace_id) / f"{shell_id}.json"
def _count_workspaces_locked(self) -> int:
return sum(1 for _ in self._workspaces_dir.glob("*/workspace.json"))
@ -1609,6 +2180,7 @@ class VmManager:
instance = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._close_workspace_shells_locked(workspace, instance)
if workspace.state == "started":
self._backend.stop(instance)
workspace.state = "stopped"
@ -1704,3 +2276,97 @@ class VmManager:
entry["stderr"] = stderr
entries.append(entry)
return entries
def _workspace_instance_for_live_shell_locked(self, workspace: WorkspaceRecord) -> VmInstance:
self._ensure_workspace_not_expired_locked(workspace, time.time())
self._refresh_workspace_liveness_locked(workspace)
if workspace.state != "started":
raise RuntimeError(
"workspace "
f"{workspace.workspace_id} must be in 'started' state before shell operations"
)
instance = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._require_workspace_shell_support(instance)
return instance
def _workspace_shell_record_from_payload(
self,
*,
workspace_id: str,
shell_id: str,
payload: dict[str, Any],
metadata: dict[str, str] | None = None,
) -> WorkspaceShellRecord:
return WorkspaceShellRecord(
workspace_id=workspace_id,
shell_id=str(payload.get("shell_id", shell_id)),
cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)),
cols=int(payload.get("cols", DEFAULT_SHELL_COLS)),
rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)),
state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))),
started_at=float(payload.get("started_at", time.time())),
ended_at=(
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
),
exit_code=(
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
),
execution_mode=str(payload.get("execution_mode", "pending")),
metadata=dict(metadata or {}),
)
def _load_workspace_shell_locked(
self,
workspace_id: str,
shell_id: str,
) -> WorkspaceShellRecord:
record_path = self._workspace_shell_record_path(workspace_id, shell_id)
if not record_path.exists():
raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}")
payload = json.loads(record_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise RuntimeError(f"shell record at {record_path} is invalid")
return WorkspaceShellRecord.from_payload(payload)
def _save_workspace_shell_locked(self, shell: WorkspaceShellRecord) -> None:
record_path = self._workspace_shell_record_path(shell.workspace_id, shell.shell_id)
record_path.parent.mkdir(parents=True, exist_ok=True)
record_path.write_text(
json.dumps(shell.to_payload(), indent=2, sort_keys=True),
encoding="utf-8",
)
def _delete_workspace_shell_locked(self, workspace_id: str, shell_id: str) -> None:
record_path = self._workspace_shell_record_path(workspace_id, shell_id)
if record_path.exists():
record_path.unlink()
def _list_workspace_shells_locked(self, workspace_id: str) -> list[WorkspaceShellRecord]:
shells_dir = self._workspace_shells_dir(workspace_id)
if not shells_dir.exists():
return []
shells: list[WorkspaceShellRecord] = []
for record_path in sorted(shells_dir.glob("*.json")):
payload = json.loads(record_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
continue
shells.append(WorkspaceShellRecord.from_payload(payload))
return shells
def _close_workspace_shells_locked(
self,
workspace: WorkspaceRecord,
instance: VmInstance,
) -> None:
for shell in self._list_workspace_shells_locked(workspace.workspace_id):
try:
self._backend.close_shell(
instance,
workspace_id=workspace.workspace_id,
shell_id=shell.shell_id,
)
except Exception:
pass
self._delete_workspace_shell_locked(workspace.workspace_id, shell.shell_id)

View file

@ -0,0 +1,291 @@
"""Local PTY-backed shell sessions for the mock workspace backend."""
from __future__ import annotations
import codecs
import fcntl
import os
import pty
import shlex
import signal
import struct
import subprocess
import termios
import threading
import time
from pathlib import Path
from typing import Literal
ShellState = Literal["running", "stopped"]
SHELL_SIGNAL_NAMES = ("HUP", "INT", "TERM", "KILL")
_SHELL_SIGNAL_MAP = {
"HUP": signal.SIGHUP,
"INT": signal.SIGINT,
"TERM": signal.SIGTERM,
"KILL": signal.SIGKILL,
}
_LOCAL_SHELLS: dict[str, "LocalShellSession"] = {}
_LOCAL_SHELLS_LOCK = threading.Lock()
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
class LocalShellSession:
"""Host-local interactive shell used by the mock backend."""
def __init__(
self,
*,
shell_id: str,
cwd: Path,
display_cwd: str,
cols: int,
rows: int,
) -> None:
self.shell_id = shell_id
self.cwd = display_cwd
self.cols = cols
self.rows = rows
self.started_at = time.time()
self.ended_at: float | None = None
self.exit_code: int | None = None
self.state: ShellState = "running"
self.pid: int | None = None
self._lock = threading.RLock()
self._output = ""
self._master_fd: int | None = None
self._reader: threading.Thread | None = None
self._waiter: threading.Thread | None = None
self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
master_fd, slave_fd = pty.openpty()
try:
_set_pty_size(slave_fd, rows, cols)
env = os.environ.copy()
env.update(
{
"TERM": env.get("TERM", "xterm-256color"),
"PS1": "pyro$ ",
"PROMPT_COMMAND": "",
}
)
process = subprocess.Popen( # noqa: S603
["/bin/bash", "--noprofile", "--norc", "-i"],
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
cwd=str(cwd),
env=env,
text=False,
close_fds=True,
preexec_fn=os.setsid,
)
except Exception:
os.close(master_fd)
raise
finally:
os.close(slave_fd)
self._process = process
self.pid = process.pid
self._master_fd = master_fd
self._reader = threading.Thread(target=self._reader_loop, daemon=True)
self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
self._reader.start()
self._waiter.start()
def summary(self) -> dict[str, object]:
with self._lock:
return {
"shell_id": self.shell_id,
"cwd": self.cwd,
"cols": self.cols,
"rows": self.rows,
"state": self.state,
"started_at": self.started_at,
"ended_at": self.ended_at,
"exit_code": self.exit_code,
"pid": self.pid,
}
def read(self, *, cursor: int, max_chars: int) -> dict[str, object]:
with self._lock:
clamped_cursor = min(max(cursor, 0), len(self._output))
output = self._output[clamped_cursor : clamped_cursor + max_chars]
next_cursor = clamped_cursor + len(output)
payload = self.summary()
payload.update(
{
"cursor": clamped_cursor,
"next_cursor": next_cursor,
"output": output,
"truncated": next_cursor < len(self._output),
}
)
return payload
def write(self, text: str, *, append_newline: bool) -> dict[str, object]:
if self._process.poll() is not None:
self._refresh_process_state()
with self._lock:
if self.state != "running":
raise RuntimeError(f"shell {self.shell_id} is not running")
master_fd = self._master_fd
if master_fd is None:
raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
payload = text + ("\n" if append_newline else "")
try:
os.write(master_fd, payload.encode("utf-8"))
except OSError as exc:
self._refresh_process_state()
raise RuntimeError(f"failed to write to shell {self.shell_id}: {exc}") from exc
result = self.summary()
result.update({"input_length": len(text), "append_newline": append_newline})
return result
def send_signal(self, signal_name: str) -> dict[str, object]:
signal_name = signal_name.upper()
signum = _SHELL_SIGNAL_MAP.get(signal_name)
if signum is None:
raise ValueError(f"unsupported shell signal: {signal_name}")
if self._process.poll() is not None:
self._refresh_process_state()
with self._lock:
if self.state != "running" or self.pid is None:
raise RuntimeError(f"shell {self.shell_id} is not running")
pid = self.pid
try:
os.killpg(pid, signum)
except ProcessLookupError as exc:
self._refresh_process_state()
raise RuntimeError(f"shell {self.shell_id} is not running") from exc
result = self.summary()
result["signal"] = signal_name
return result
def close(self) -> dict[str, object]:
if self._process.poll() is None and self.pid is not None:
try:
os.killpg(self.pid, signal.SIGHUP)
except ProcessLookupError:
pass
try:
self._process.wait(timeout=5)
except subprocess.TimeoutExpired:
try:
os.killpg(self.pid, signal.SIGKILL)
except ProcessLookupError:
pass
self._process.wait(timeout=5)
else:
self._refresh_process_state()
self._close_master_fd()
if self._reader is not None:
self._reader.join(timeout=1)
if self._waiter is not None:
self._waiter.join(timeout=1)
result = self.summary()
result["closed"] = True
return result
def _reader_loop(self) -> None:
master_fd = self._master_fd
if master_fd is None:
return
while True:
try:
chunk = os.read(master_fd, 65536)
except OSError:
break
if chunk == b"":
break
decoded = self._decoder.decode(chunk)
if decoded:
with self._lock:
self._output += decoded
decoded = self._decoder.decode(b"", final=True)
if decoded:
with self._lock:
self._output += decoded
def _waiter_loop(self) -> None:
exit_code = self._process.wait()
with self._lock:
self.state = "stopped"
self.exit_code = exit_code
self.ended_at = time.time()
def _refresh_process_state(self) -> None:
exit_code = self._process.poll()
if exit_code is None:
return
with self._lock:
if self.state == "running":
self.state = "stopped"
self.exit_code = exit_code
self.ended_at = time.time()
def _close_master_fd(self) -> None:
with self._lock:
master_fd = self._master_fd
self._master_fd = None
if master_fd is None:
return
try:
os.close(master_fd)
except OSError:
pass
def create_local_shell(
*,
workspace_id: str,
shell_id: str,
cwd: Path,
display_cwd: str,
cols: int,
rows: int,
) -> LocalShellSession:
session_key = f"{workspace_id}:{shell_id}"
with _LOCAL_SHELLS_LOCK:
if session_key in _LOCAL_SHELLS:
raise RuntimeError(f"shell {shell_id} already exists in workspace {workspace_id}")
session = LocalShellSession(
shell_id=shell_id,
cwd=cwd,
display_cwd=display_cwd,
cols=cols,
rows=rows,
)
_LOCAL_SHELLS[session_key] = session
return session
def get_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession:
session_key = f"{workspace_id}:{shell_id}"
with _LOCAL_SHELLS_LOCK:
try:
return _LOCAL_SHELLS[session_key]
except KeyError as exc:
raise ValueError(
f"shell {shell_id!r} does not exist in workspace {workspace_id!r}"
) from exc
def remove_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession | None:
session_key = f"{workspace_id}:{shell_id}"
with _LOCAL_SHELLS_LOCK:
return _LOCAL_SHELLS.pop(session_key, None)
def shell_signal_names() -> tuple[str, ...]:
return SHELL_SIGNAL_NAMES
def shell_signal_arg_help() -> str:
return ", ".join(shlex.quote(name) for name in SHELL_SIGNAL_NAMES)