Add persistent workspace shell sessions
Let agents inhabit a workspace across separate calls instead of only submitting one-shot execs. Add workspace shell open/read/write/signal/close across the CLI, Python SDK, and MCP server, with persisted shell records, a local PTY-backed mock implementation, and guest-agent support for real Firecracker workspaces. Mark the 2.5.0 roadmap milestone done, refresh docs/examples and the release metadata, and verify with uv lock, UV_CACHE_DIR=.uv-cache make check, and UV_CACHE_DIR=.uv-cache make dist-check.
This commit is contained in:
parent
2de31306b6
commit
3f8293ad24
28 changed files with 3265 additions and 81 deletions
|
|
@ -130,6 +130,67 @@ class Pyro:
|
|||
def logs_workspace(self, workspace_id: str) -> dict[str, Any]:
|
||||
return self._manager.logs_workspace(workspace_id)
|
||||
|
||||
def open_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
*,
|
||||
cwd: str = "/workspace",
|
||||
cols: int = 120,
|
||||
rows: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.open_shell(
|
||||
workspace_id,
|
||||
cwd=cwd,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
)
|
||||
|
||||
def read_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
cursor: int = 0,
|
||||
max_chars: int = 65536,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.read_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
cursor=cursor,
|
||||
max_chars=max_chars,
|
||||
)
|
||||
|
||||
def write_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
input: str,
|
||||
append_newline: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.write_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
input_text=input,
|
||||
append_newline=append_newline,
|
||||
)
|
||||
|
||||
def signal_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
signal_name: str = "INT",
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.signal_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
signal_name=signal_name,
|
||||
)
|
||||
|
||||
def close_shell(self, workspace_id: str, shell_id: str) -> dict[str, Any]:
|
||||
return self._manager.close_shell(workspace_id, shell_id)
|
||||
|
||||
def delete_workspace(self, workspace_id: str) -> dict[str, Any]:
|
||||
return self._manager.delete_workspace(workspace_id)
|
||||
|
||||
|
|
@ -309,6 +370,64 @@ class Pyro:
|
|||
"""Return persisted command history for one workspace."""
|
||||
return self.logs_workspace(workspace_id)
|
||||
|
||||
@server.tool()
|
||||
async def shell_open(
|
||||
workspace_id: str,
|
||||
cwd: str = "/workspace",
|
||||
cols: int = 120,
|
||||
rows: int = 30,
|
||||
) -> dict[str, Any]:
|
||||
"""Open a persistent interactive shell inside one workspace."""
|
||||
return self.open_shell(workspace_id, cwd=cwd, cols=cols, rows=rows)
|
||||
|
||||
@server.tool()
|
||||
async def shell_read(
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cursor: int = 0,
|
||||
max_chars: int = 65536,
|
||||
) -> dict[str, Any]:
|
||||
"""Read merged PTY output from a workspace shell."""
|
||||
return self.read_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
cursor=cursor,
|
||||
max_chars=max_chars,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def shell_write(
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
input: str,
|
||||
append_newline: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Write text input to a persistent workspace shell."""
|
||||
return self.write_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
input=input,
|
||||
append_newline=append_newline,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def shell_signal(
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
signal_name: str = "INT",
|
||||
) -> dict[str, Any]:
|
||||
"""Send a signal to the shell process group."""
|
||||
return self.signal_shell(
|
||||
workspace_id,
|
||||
shell_id,
|
||||
signal_name=signal_name,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
async def shell_close(workspace_id: str, shell_id: str) -> dict[str, Any]:
|
||||
"""Close a persistent workspace shell."""
|
||||
return self.close_shell(workspace_id, shell_id)
|
||||
|
||||
@server.tool()
|
||||
async def workspace_delete(workspace_id: str) -> dict[str, Any]:
|
||||
"""Delete a persistent workspace and its backing sandbox."""
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from pyro_mcp.vm_manager import (
|
|||
DEFAULT_MEM_MIB,
|
||||
DEFAULT_VCPU_COUNT,
|
||||
WORKSPACE_GUEST_PATH,
|
||||
WORKSPACE_SHELL_SIGNAL_NAMES,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -237,6 +238,37 @@ def _print_workspace_logs_human(payload: dict[str, Any]) -> None:
|
|||
print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr)
|
||||
|
||||
|
||||
def _print_workspace_shell_summary_human(payload: dict[str, Any], *, prefix: str) -> None:
|
||||
print(
|
||||
f"[{prefix}] "
|
||||
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
|
||||
f"shell_id={str(payload.get('shell_id', 'unknown'))} "
|
||||
f"state={str(payload.get('state', 'unknown'))} "
|
||||
f"cwd={str(payload.get('cwd', WORKSPACE_GUEST_PATH))} "
|
||||
f"cols={int(payload.get('cols', 0))} "
|
||||
f"rows={int(payload.get('rows', 0))} "
|
||||
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
def _print_workspace_shell_read_human(payload: dict[str, Any]) -> None:
|
||||
_write_stream(str(payload.get("output", "")), stream=sys.stdout)
|
||||
print(
|
||||
"[workspace-shell-read] "
|
||||
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
|
||||
f"shell_id={str(payload.get('shell_id', 'unknown'))} "
|
||||
f"state={str(payload.get('state', 'unknown'))} "
|
||||
f"cursor={int(payload.get('cursor', 0))} "
|
||||
f"next_cursor={int(payload.get('next_cursor', 0))} "
|
||||
f"truncated={bool(payload.get('truncated', False))} "
|
||||
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
||||
|
||||
class _HelpFormatter(
|
||||
argparse.RawDescriptionHelpFormatter,
|
||||
argparse.ArgumentDefaultsHelpFormatter,
|
||||
|
|
@ -269,6 +301,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
Need repeated commands in one workspace after that?
|
||||
pyro workspace create debian:12 --seed-path ./repo
|
||||
pyro workspace sync push WORKSPACE_ID ./changes
|
||||
pyro workspace shell open WORKSPACE_ID
|
||||
|
||||
Use `pyro mcp serve` only after the CLI validation path works.
|
||||
"""
|
||||
|
|
@ -476,6 +509,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
pyro workspace create debian:12 --seed-path ./repo
|
||||
pyro workspace sync push WORKSPACE_ID ./repo --dest src
|
||||
pyro workspace exec WORKSPACE_ID -- sh -lc 'printf "hello\\n" > note.txt'
|
||||
pyro workspace shell open WORKSPACE_ID
|
||||
pyro workspace logs WORKSPACE_ID
|
||||
"""
|
||||
),
|
||||
|
|
@ -633,6 +667,191 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_shell_parser = workspace_subparsers.add_parser(
|
||||
"shell",
|
||||
help="Open and manage persistent interactive shells.",
|
||||
description=(
|
||||
"Open one or more persistent interactive PTY shell sessions inside a started "
|
||||
"workspace."
|
||||
),
|
||||
epilog=dedent(
|
||||
"""
|
||||
Examples:
|
||||
pyro workspace shell open WORKSPACE_ID
|
||||
pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd'
|
||||
pyro workspace shell read WORKSPACE_ID SHELL_ID
|
||||
pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT
|
||||
pyro workspace shell close WORKSPACE_ID SHELL_ID
|
||||
|
||||
Use `workspace exec` for one-shot commands. Use `workspace shell` when you need
|
||||
an interactive process that keeps its state between calls.
|
||||
"""
|
||||
),
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_subparsers = workspace_shell_parser.add_subparsers(
|
||||
dest="workspace_shell_command",
|
||||
required=True,
|
||||
metavar="SHELL",
|
||||
)
|
||||
workspace_shell_open_parser = workspace_shell_subparsers.add_parser(
|
||||
"open",
|
||||
help="Open a persistent interactive shell.",
|
||||
description="Open a new PTY shell inside a started workspace.",
|
||||
epilog="Example:\n pyro workspace shell open WORKSPACE_ID --cwd src",
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_open_parser.add_argument(
|
||||
"workspace_id",
|
||||
metavar="WORKSPACE_ID",
|
||||
help="Persistent workspace identifier.",
|
||||
)
|
||||
workspace_shell_open_parser.add_argument(
|
||||
"--cwd",
|
||||
default=WORKSPACE_GUEST_PATH,
|
||||
help="Shell working directory. Relative values resolve inside `/workspace`.",
|
||||
)
|
||||
workspace_shell_open_parser.add_argument(
|
||||
"--cols",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Shell terminal width in columns.",
|
||||
)
|
||||
workspace_shell_open_parser.add_argument(
|
||||
"--rows",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Shell terminal height in rows.",
|
||||
)
|
||||
workspace_shell_open_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_shell_read_parser = workspace_shell_subparsers.add_parser(
|
||||
"read",
|
||||
help="Read merged PTY output from a shell.",
|
||||
description="Read merged text output from a persistent workspace shell.",
|
||||
epilog=dedent(
|
||||
"""
|
||||
Example:
|
||||
pyro workspace shell read WORKSPACE_ID SHELL_ID --cursor 0
|
||||
|
||||
Shell output is written to stdout. The read summary is written to stderr.
|
||||
Use --json for a deterministic structured response.
|
||||
"""
|
||||
),
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_read_parser.add_argument(
|
||||
"workspace_id",
|
||||
metavar="WORKSPACE_ID",
|
||||
help="Persistent workspace identifier.",
|
||||
)
|
||||
workspace_shell_read_parser.add_argument(
|
||||
"shell_id",
|
||||
metavar="SHELL_ID",
|
||||
help="Persistent shell identifier returned by `workspace shell open`.",
|
||||
)
|
||||
workspace_shell_read_parser.add_argument(
|
||||
"--cursor",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Character offset into the merged shell output buffer.",
|
||||
)
|
||||
workspace_shell_read_parser.add_argument(
|
||||
"--max-chars",
|
||||
type=int,
|
||||
default=65536,
|
||||
help="Maximum number of characters to return from the current cursor position.",
|
||||
)
|
||||
workspace_shell_read_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_shell_write_parser = workspace_shell_subparsers.add_parser(
|
||||
"write",
|
||||
help="Write text input into a shell.",
|
||||
description="Write text input into a persistent workspace shell.",
|
||||
epilog="Example:\n pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd'",
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_write_parser.add_argument(
|
||||
"workspace_id",
|
||||
metavar="WORKSPACE_ID",
|
||||
help="Persistent workspace identifier.",
|
||||
)
|
||||
workspace_shell_write_parser.add_argument(
|
||||
"shell_id",
|
||||
metavar="SHELL_ID",
|
||||
help="Persistent shell identifier returned by `workspace shell open`.",
|
||||
)
|
||||
workspace_shell_write_parser.add_argument(
|
||||
"--input",
|
||||
required=True,
|
||||
help="Text to send to the shell.",
|
||||
)
|
||||
workspace_shell_write_parser.add_argument(
|
||||
"--no-newline",
|
||||
action="store_true",
|
||||
help="Do not append a trailing newline after the provided input.",
|
||||
)
|
||||
workspace_shell_write_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_shell_signal_parser = workspace_shell_subparsers.add_parser(
|
||||
"signal",
|
||||
help="Send a signal to a shell process group.",
|
||||
description="Send a control signal to a persistent workspace shell.",
|
||||
epilog="Example:\n pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT",
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_signal_parser.add_argument(
|
||||
"workspace_id",
|
||||
metavar="WORKSPACE_ID",
|
||||
help="Persistent workspace identifier.",
|
||||
)
|
||||
workspace_shell_signal_parser.add_argument(
|
||||
"shell_id",
|
||||
metavar="SHELL_ID",
|
||||
help="Persistent shell identifier returned by `workspace shell open`.",
|
||||
)
|
||||
workspace_shell_signal_parser.add_argument(
|
||||
"--signal",
|
||||
default="INT",
|
||||
choices=WORKSPACE_SHELL_SIGNAL_NAMES,
|
||||
help="Signal name to send to the shell process group.",
|
||||
)
|
||||
workspace_shell_signal_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_shell_close_parser = workspace_shell_subparsers.add_parser(
|
||||
"close",
|
||||
help="Close a persistent shell.",
|
||||
description="Close a persistent workspace shell and release its PTY state.",
|
||||
epilog="Example:\n pyro workspace shell close WORKSPACE_ID SHELL_ID",
|
||||
formatter_class=_HelpFormatter,
|
||||
)
|
||||
workspace_shell_close_parser.add_argument(
|
||||
"workspace_id",
|
||||
metavar="WORKSPACE_ID",
|
||||
help="Persistent workspace identifier.",
|
||||
)
|
||||
workspace_shell_close_parser.add_argument(
|
||||
"shell_id",
|
||||
metavar="SHELL_ID",
|
||||
help="Persistent shell identifier returned by `workspace shell open`.",
|
||||
)
|
||||
workspace_shell_close_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
help="Print structured JSON instead of human-readable output.",
|
||||
)
|
||||
workspace_status_parser = workspace_subparsers.add_parser(
|
||||
"status",
|
||||
help="Inspect one workspace.",
|
||||
|
|
@ -929,6 +1148,99 @@ def main() -> None:
|
|||
raise SystemExit(1) from exc
|
||||
_print_workspace_sync_human(payload)
|
||||
return
|
||||
if args.workspace_command == "shell":
|
||||
if args.workspace_shell_command == "open":
|
||||
try:
|
||||
payload = pyro.open_shell(
|
||||
args.workspace_id,
|
||||
cwd=args.cwd,
|
||||
cols=args.cols,
|
||||
rows=args.rows,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
_print_json({"ok": False, "error": str(exc)})
|
||||
else:
|
||||
print(f"[error] {exc}", file=sys.stderr, flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
else:
|
||||
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-open")
|
||||
return
|
||||
if args.workspace_shell_command == "read":
|
||||
try:
|
||||
payload = pyro.read_shell(
|
||||
args.workspace_id,
|
||||
args.shell_id,
|
||||
cursor=args.cursor,
|
||||
max_chars=args.max_chars,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
_print_json({"ok": False, "error": str(exc)})
|
||||
else:
|
||||
print(f"[error] {exc}", file=sys.stderr, flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
else:
|
||||
_print_workspace_shell_read_human(payload)
|
||||
return
|
||||
if args.workspace_shell_command == "write":
|
||||
try:
|
||||
payload = pyro.write_shell(
|
||||
args.workspace_id,
|
||||
args.shell_id,
|
||||
input=args.input,
|
||||
append_newline=not bool(args.no_newline),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
_print_json({"ok": False, "error": str(exc)})
|
||||
else:
|
||||
print(f"[error] {exc}", file=sys.stderr, flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
else:
|
||||
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-write")
|
||||
return
|
||||
if args.workspace_shell_command == "signal":
|
||||
try:
|
||||
payload = pyro.signal_shell(
|
||||
args.workspace_id,
|
||||
args.shell_id,
|
||||
signal_name=args.signal,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
_print_json({"ok": False, "error": str(exc)})
|
||||
else:
|
||||
print(f"[error] {exc}", file=sys.stderr, flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
else:
|
||||
_print_workspace_shell_summary_human(
|
||||
payload,
|
||||
prefix="workspace-shell-signal",
|
||||
)
|
||||
return
|
||||
if args.workspace_shell_command == "close":
|
||||
try:
|
||||
payload = pyro.close_shell(args.workspace_id, args.shell_id)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
_print_json({"ok": False, "error": str(exc)})
|
||||
else:
|
||||
print(f"[error] {exc}", file=sys.stderr, flush=True)
|
||||
raise SystemExit(1) from exc
|
||||
if bool(args.json):
|
||||
_print_json(payload)
|
||||
else:
|
||||
_print_workspace_shell_summary_human(payload, prefix="workspace-shell-close")
|
||||
return
|
||||
if args.workspace_command == "status":
|
||||
payload = pyro.status_workspace(args.workspace_id)
|
||||
if bool(args.json):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,16 @@ from __future__ import annotations
|
|||
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "workspace")
|
||||
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
|
||||
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
|
||||
PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status", "sync")
|
||||
PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = (
|
||||
"create",
|
||||
"delete",
|
||||
"exec",
|
||||
"logs",
|
||||
"shell",
|
||||
"status",
|
||||
"sync",
|
||||
)
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS = ("close", "open", "read", "signal", "write")
|
||||
PUBLIC_CLI_WORKSPACE_SYNC_SUBCOMMANDS = ("push",)
|
||||
PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
|
||||
"--vcpu-count",
|
||||
|
|
@ -16,6 +25,11 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
|
|||
"--seed-path",
|
||||
"--json",
|
||||
)
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_OPEN_FLAGS = ("--cwd", "--cols", "--rows", "--json")
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS = ("--cursor", "--max-chars", "--json")
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS = ("--input", "--no-newline", "--json")
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS = ("--signal", "--json")
|
||||
PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS = ("--json",)
|
||||
PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS = ("--dest", "--json")
|
||||
PUBLIC_CLI_RUN_FLAGS = (
|
||||
"--vcpu-count",
|
||||
|
|
@ -28,6 +42,7 @@ PUBLIC_CLI_RUN_FLAGS = (
|
|||
)
|
||||
|
||||
PUBLIC_SDK_METHODS = (
|
||||
"close_shell",
|
||||
"create_server",
|
||||
"create_vm",
|
||||
"create_workspace",
|
||||
|
|
@ -39,18 +54,27 @@ PUBLIC_SDK_METHODS = (
|
|||
"list_environments",
|
||||
"logs_workspace",
|
||||
"network_info_vm",
|
||||
"open_shell",
|
||||
"prune_environments",
|
||||
"pull_environment",
|
||||
"push_workspace_sync",
|
||||
"read_shell",
|
||||
"reap_expired",
|
||||
"run_in_vm",
|
||||
"signal_shell",
|
||||
"start_vm",
|
||||
"status_vm",
|
||||
"status_workspace",
|
||||
"stop_vm",
|
||||
"write_shell",
|
||||
)
|
||||
|
||||
PUBLIC_MCP_TOOLS = (
|
||||
"shell_close",
|
||||
"shell_open",
|
||||
"shell_read",
|
||||
"shell_signal",
|
||||
"shell_write",
|
||||
"vm_create",
|
||||
"vm_delete",
|
||||
"vm_exec",
|
||||
|
|
|
|||
|
|
@ -1,14 +1,21 @@
|
|||
#!/usr/bin/env python3
|
||||
"""Minimal guest-side exec and workspace import agent for pyro runtime bundles."""
|
||||
"""Guest-side exec, workspace import, and interactive shell agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import fcntl
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import pty
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
import subprocess
|
||||
import tarfile
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
|
@ -16,6 +23,17 @@ from typing import Any
|
|||
PORT = 5005
|
||||
BUFFER_SIZE = 65536
|
||||
WORKSPACE_ROOT = PurePosixPath("/workspace")
|
||||
SHELL_ROOT = Path("/run/pyro-shells")
|
||||
SHELL_SIGNAL_MAP = {
|
||||
"HUP": signal.SIGHUP,
|
||||
"INT": signal.SIGINT,
|
||||
"TERM": signal.SIGTERM,
|
||||
"KILL": signal.SIGKILL,
|
||||
}
|
||||
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)
|
||||
|
||||
_SHELLS: dict[str, "GuestShellSession"] = {}
|
||||
_SHELLS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _read_request(conn: socket.socket) -> dict[str, Any]:
|
||||
|
|
@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
|
|||
suffix = normalized.relative_to(WORKSPACE_ROOT)
|
||||
host_path = Path("/workspace")
|
||||
if str(suffix) not in {"", "."}:
|
||||
host_path = host_path / str(suffix)
|
||||
host_path = host_path.joinpath(*suffix.parts)
|
||||
return normalized, host_path
|
||||
|
||||
|
||||
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
|
||||
normalized, host_path = _normalize_destination(cwd)
|
||||
return str(normalized), host_path
|
||||
|
||||
|
||||
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
|
|
@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str)
|
|||
|
||||
|
||||
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
|
||||
_, destination_root = _normalize_destination(destination)
|
||||
normalized_destination, destination_root = _normalize_destination(destination)
|
||||
destination_root.mkdir(parents=True, exist_ok=True)
|
||||
bytes_written = 0
|
||||
entry_count = 0
|
||||
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
|
||||
for member in archive.getmembers():
|
||||
member_name = _normalize_member_name(member.name)
|
||||
target_path = destination_root / str(member_name)
|
||||
target_path = destination_root.joinpath(*member_name.parts)
|
||||
entry_count += 1
|
||||
_ensure_no_symlink_parents(destination_root, target_path, member.name)
|
||||
if member.isdir():
|
||||
if target_path.exists() and not target_path.is_dir():
|
||||
if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
|
||||
raise RuntimeError(f"directory conflicts with existing path: {member.name}")
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
|
|
@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
|
|||
)
|
||||
raise RuntimeError(f"unsupported archive member type: {member.name}")
|
||||
return {
|
||||
"destination": destination,
|
||||
"destination": str(normalized_destination),
|
||||
"entry_count": entry_count,
|
||||
"bytes_written": bytes_written,
|
||||
}
|
||||
|
|
@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
|
|||
}
|
||||
|
||||
|
||||
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
|
||||
winsize = struct.pack("HHHH", rows, cols, 0, 0)
|
||||
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
|
||||
|
||||
|
||||
class GuestShellSession:
|
||||
"""In-guest PTY-backed interactive shell session."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
shell_id: str,
|
||||
cwd: Path,
|
||||
cwd_text: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> None:
|
||||
self.shell_id = shell_id
|
||||
self.cwd = cwd_text
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
self.started_at = time.time()
|
||||
self.ended_at: float | None = None
|
||||
self.exit_code: int | None = None
|
||||
self.state = "running"
|
||||
self._lock = threading.RLock()
|
||||
self._output = ""
|
||||
self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
|
||||
self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
|
||||
self._log_path = SHELL_ROOT / f"{shell_id}.log"
|
||||
self._master_fd: int | None = None
|
||||
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
try:
|
||||
_set_pty_size(slave_fd, rows, cols)
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"TERM": env.get("TERM", "xterm-256color"),
|
||||
"PS1": "pyro$ ",
|
||||
"PROMPT_COMMAND": "",
|
||||
}
|
||||
)
|
||||
process = subprocess.Popen( # noqa: S603
|
||||
["/bin/bash", "--noprofile", "--norc", "-i"],
|
||||
stdin=slave_fd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
text=False,
|
||||
close_fds=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
except Exception:
|
||||
os.close(master_fd)
|
||||
raise
|
||||
finally:
|
||||
os.close(slave_fd)
|
||||
|
||||
self._process = process
|
||||
self._master_fd = master_fd
|
||||
self._write_metadata()
|
||||
self._reader = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
|
||||
self._reader.start()
|
||||
self._waiter.start()
|
||||
|
||||
def summary(self) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
return {
|
||||
"shell_id": self.shell_id,
|
||||
"cwd": self.cwd,
|
||||
"cols": self.cols,
|
||||
"rows": self.rows,
|
||||
"state": self.state,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"exit_code": self.exit_code,
|
||||
}
|
||||
|
||||
def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
clamped_cursor = min(max(cursor, 0), len(self._output))
|
||||
output = self._output[clamped_cursor : clamped_cursor + max_chars]
|
||||
next_cursor = clamped_cursor + len(output)
|
||||
payload = self.summary()
|
||||
payload.update(
|
||||
{
|
||||
"cursor": clamped_cursor,
|
||||
"next_cursor": next_cursor,
|
||||
"output": output,
|
||||
"truncated": next_cursor < len(self._output),
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
|
||||
if self._process.poll() is not None:
|
||||
self._refresh_process_state()
|
||||
with self._lock:
|
||||
if self.state != "running":
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running")
|
||||
master_fd = self._master_fd
|
||||
if master_fd is None:
|
||||
raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
|
||||
payload = text + ("\n" if append_newline else "")
|
||||
try:
|
||||
os.write(master_fd, payload.encode("utf-8"))
|
||||
except OSError as exc:
|
||||
self._refresh_process_state()
|
||||
raise RuntimeError(f"failed to write shell input: {exc}") from exc
|
||||
response = self.summary()
|
||||
response.update({"input_length": len(text), "append_newline": append_newline})
|
||||
return response
|
||||
|
||||
def send_signal(self, signal_name: str) -> dict[str, Any]:
|
||||
signum = SHELL_SIGNAL_MAP.get(signal_name)
|
||||
if signum is None:
|
||||
raise ValueError(f"unsupported shell signal: {signal_name}")
|
||||
if self._process.poll() is not None:
|
||||
self._refresh_process_state()
|
||||
with self._lock:
|
||||
if self.state != "running":
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running")
|
||||
pid = self._process.pid
|
||||
try:
|
||||
os.killpg(pid, signum)
|
||||
except ProcessLookupError as exc:
|
||||
self._refresh_process_state()
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running") from exc
|
||||
response = self.summary()
|
||||
response["signal"] = signal_name
|
||||
return response
|
||||
|
||||
def close(self) -> dict[str, Any]:
|
||||
if self._process.poll() is None:
|
||||
try:
|
||||
os.killpg(self._process.pid, signal.SIGHUP)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
try:
|
||||
self._process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
os.killpg(self._process.pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
self._process.wait(timeout=5)
|
||||
else:
|
||||
self._refresh_process_state()
|
||||
self._close_master_fd()
|
||||
if self._reader is not None:
|
||||
self._reader.join(timeout=1)
|
||||
if self._waiter is not None:
|
||||
self._waiter.join(timeout=1)
|
||||
response = self.summary()
|
||||
response["closed"] = True
|
||||
return response
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
master_fd = self._master_fd
|
||||
if master_fd is None:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
chunk = os.read(master_fd, BUFFER_SIZE)
|
||||
except OSError:
|
||||
break
|
||||
if chunk == b"":
|
||||
break
|
||||
decoded = self._decoder.decode(chunk)
|
||||
if decoded == "":
|
||||
continue
|
||||
with self._lock:
|
||||
self._output += decoded
|
||||
with self._log_path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(decoded)
|
||||
decoded = self._decoder.decode(b"", final=True)
|
||||
if decoded != "":
|
||||
with self._lock:
|
||||
self._output += decoded
|
||||
with self._log_path.open("a", encoding="utf-8") as handle:
|
||||
handle.write(decoded)
|
||||
|
||||
def _waiter_loop(self) -> None:
|
||||
exit_code = self._process.wait()
|
||||
with self._lock:
|
||||
self.state = "stopped"
|
||||
self.exit_code = exit_code
|
||||
self.ended_at = time.time()
|
||||
self._write_metadata()
|
||||
|
||||
def _refresh_process_state(self) -> None:
|
||||
exit_code = self._process.poll()
|
||||
if exit_code is None:
|
||||
return
|
||||
with self._lock:
|
||||
if self.state == "running":
|
||||
self.state = "stopped"
|
||||
self.exit_code = exit_code
|
||||
self.ended_at = time.time()
|
||||
self._write_metadata()
|
||||
|
||||
def _write_metadata(self) -> None:
|
||||
self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")
|
||||
|
||||
def _close_master_fd(self) -> None:
|
||||
with self._lock:
|
||||
master_fd = self._master_fd
|
||||
self._master_fd = None
|
||||
if master_fd is None:
|
||||
return
|
||||
try:
|
||||
os.close(master_fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def _create_shell(
|
||||
*,
|
||||
shell_id: str,
|
||||
cwd_text: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> GuestShellSession:
|
||||
_, cwd_path = _normalize_shell_cwd(cwd_text)
|
||||
with _SHELLS_LOCK:
|
||||
if shell_id in _SHELLS:
|
||||
raise RuntimeError(f"shell {shell_id!r} already exists")
|
||||
session = GuestShellSession(
|
||||
shell_id=shell_id,
|
||||
cwd=cwd_path,
|
||||
cwd_text=cwd_text,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
)
|
||||
_SHELLS[shell_id] = session
|
||||
return session
|
||||
|
||||
|
||||
def _get_shell(shell_id: str) -> GuestShellSession:
|
||||
with _SHELLS_LOCK:
|
||||
try:
|
||||
return _SHELLS[shell_id]
|
||||
except KeyError as exc:
|
||||
raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
|
||||
|
||||
|
||||
def _remove_shell(shell_id: str) -> GuestShellSession:
|
||||
with _SHELLS_LOCK:
|
||||
try:
|
||||
return _SHELLS.pop(shell_id)
|
||||
except KeyError as exc:
|
||||
raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
|
||||
|
||||
|
||||
def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
|
||||
action = str(request.get("action", "exec"))
|
||||
if action == "extract_archive":
|
||||
archive_size = int(request.get("archive_size", 0))
|
||||
if archive_size < 0:
|
||||
raise RuntimeError("archive_size must not be negative")
|
||||
destination = str(request.get("destination", "/workspace"))
|
||||
payload = _read_exact(conn, archive_size)
|
||||
return _extract_archive(payload, destination)
|
||||
if action == "open_shell":
|
||||
shell_id = str(request.get("shell_id", "")).strip()
|
||||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
|
||||
session = _create_shell(
|
||||
shell_id=shell_id,
|
||||
cwd_text=cwd_text,
|
||||
cols=int(request.get("cols", 120)),
|
||||
rows=int(request.get("rows", 30)),
|
||||
)
|
||||
return session.summary()
|
||||
if action == "read_shell":
|
||||
shell_id = str(request.get("shell_id", "")).strip()
|
||||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
return _get_shell(shell_id).read(
|
||||
cursor=int(request.get("cursor", 0)),
|
||||
max_chars=int(request.get("max_chars", 65536)),
|
||||
)
|
||||
if action == "write_shell":
|
||||
shell_id = str(request.get("shell_id", "")).strip()
|
||||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
return _get_shell(shell_id).write(
|
||||
str(request.get("input", "")),
|
||||
append_newline=bool(request.get("append_newline", True)),
|
||||
)
|
||||
if action == "signal_shell":
|
||||
shell_id = str(request.get("shell_id", "")).strip()
|
||||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
signal_name = str(request.get("signal", "INT")).upper()
|
||||
if signal_name not in SHELL_SIGNAL_NAMES:
|
||||
raise RuntimeError(
|
||||
f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
|
||||
)
|
||||
return _get_shell(shell_id).send_signal(signal_name)
|
||||
if action == "close_shell":
|
||||
shell_id = str(request.get("shell_id", "")).strip()
|
||||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
return _remove_shell(shell_id).close()
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
return _run_command(command, timeout_seconds)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
family = getattr(socket, "AF_VSOCK", None)
|
||||
if family is None:
|
||||
raise SystemExit("AF_VSOCK is unavailable")
|
||||
|
|
@ -192,19 +531,11 @@ def main() -> None:
|
|||
while True:
|
||||
conn, _ = server.accept()
|
||||
with conn:
|
||||
request = _read_request(conn)
|
||||
action = str(request.get("action", "exec"))
|
||||
if action == "extract_archive":
|
||||
archive_size = int(request.get("archive_size", 0))
|
||||
if archive_size < 0:
|
||||
raise RuntimeError("archive_size must not be negative")
|
||||
destination = str(request.get("destination", "/workspace"))
|
||||
payload = _read_exact(conn, archive_size)
|
||||
response = _extract_archive(payload, destination)
|
||||
else:
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
response = _run_command(command, timeout_seconds)
|
||||
try:
|
||||
request = _read_request(conn)
|
||||
response = _dispatch(request, conn)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
response = {"error": str(exc)}
|
||||
conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
"guest": {
|
||||
"agent": {
|
||||
"path": "guest/pyro_guest_agent.py",
|
||||
"sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d"
|
||||
"sha256": "07adf6269551447dbea8c236f91499ea1479212a3f084c5402a656f5f5cc5892"
|
||||
}
|
||||
},
|
||||
"platform": "linux-x86_64",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from typing import Any
|
|||
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
|
||||
|
||||
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.4.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.5.0"
|
||||
OCI_MANIFEST_ACCEPT = ", ".join(
|
||||
(
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
|
|
|
|||
|
|
@ -39,6 +39,26 @@ class GuestArchiveResponse:
|
|||
bytes_written: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GuestShellSummary:
|
||||
shell_id: str
|
||||
cwd: str
|
||||
cols: int
|
||||
rows: int
|
||||
state: str
|
||||
started_at: float
|
||||
ended_at: float | None
|
||||
exit_code: int | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GuestShellReadResponse(GuestShellSummary):
|
||||
cursor: int
|
||||
next_cursor: int
|
||||
output: str
|
||||
truncated: bool
|
||||
|
||||
|
||||
class VsockExecClient:
|
||||
"""Minimal JSON-over-stream client for a guest exec agent."""
|
||||
|
||||
|
|
@ -54,19 +74,17 @@ class VsockExecClient:
|
|||
*,
|
||||
uds_path: str | None = None,
|
||||
) -> GuestExecResponse:
|
||||
request = {
|
||||
"command": command,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
}
|
||||
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
|
||||
try:
|
||||
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
|
||||
payload = self._recv_json_payload(sock)
|
||||
finally:
|
||||
sock.close()
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("guest exec response must be a JSON object")
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"command": command,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest exec response must be a JSON object",
|
||||
)
|
||||
return GuestExecResponse(
|
||||
stdout=str(payload.get("stdout", "")),
|
||||
stderr=str(payload.get("stderr", "")),
|
||||
|
|
@ -101,12 +119,198 @@ class VsockExecClient:
|
|||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("guest archive response must be a JSON object")
|
||||
error = payload.get("error")
|
||||
if error is not None:
|
||||
raise RuntimeError(str(error))
|
||||
return GuestArchiveResponse(
|
||||
destination=str(payload.get("destination", destination)),
|
||||
entry_count=int(payload.get("entry_count", 0)),
|
||||
bytes_written=int(payload.get("bytes_written", 0)),
|
||||
)
|
||||
|
||||
def open_shell(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
*,
|
||||
shell_id: str,
|
||||
cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
timeout_seconds: int = 30,
|
||||
uds_path: str | None = None,
|
||||
) -> GuestShellSummary:
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"action": "open_shell",
|
||||
"shell_id": shell_id,
|
||||
"cwd": cwd,
|
||||
"cols": cols,
|
||||
"rows": rows,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest shell open response must be a JSON object",
|
||||
)
|
||||
return self._shell_summary_from_payload(payload)
|
||||
|
||||
def read_shell(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
*,
|
||||
shell_id: str,
|
||||
cursor: int,
|
||||
max_chars: int,
|
||||
timeout_seconds: int = 30,
|
||||
uds_path: str | None = None,
|
||||
) -> GuestShellReadResponse:
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"action": "read_shell",
|
||||
"shell_id": shell_id,
|
||||
"cursor": cursor,
|
||||
"max_chars": max_chars,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest shell read response must be a JSON object",
|
||||
)
|
||||
summary = self._shell_summary_from_payload(payload)
|
||||
return GuestShellReadResponse(
|
||||
shell_id=summary.shell_id,
|
||||
cwd=summary.cwd,
|
||||
cols=summary.cols,
|
||||
rows=summary.rows,
|
||||
state=summary.state,
|
||||
started_at=summary.started_at,
|
||||
ended_at=summary.ended_at,
|
||||
exit_code=summary.exit_code,
|
||||
cursor=int(payload.get("cursor", cursor)),
|
||||
next_cursor=int(payload.get("next_cursor", cursor)),
|
||||
output=str(payload.get("output", "")),
|
||||
truncated=bool(payload.get("truncated", False)),
|
||||
)
|
||||
|
||||
def write_shell(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
*,
|
||||
shell_id: str,
|
||||
input_text: str,
|
||||
append_newline: bool,
|
||||
timeout_seconds: int = 30,
|
||||
uds_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"action": "write_shell",
|
||||
"shell_id": shell_id,
|
||||
"input": input_text,
|
||||
"append_newline": append_newline,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest shell write response must be a JSON object",
|
||||
)
|
||||
self._shell_summary_from_payload(payload)
|
||||
return payload
|
||||
|
||||
def signal_shell(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
*,
|
||||
shell_id: str,
|
||||
signal_name: str,
|
||||
timeout_seconds: int = 30,
|
||||
uds_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"action": "signal_shell",
|
||||
"shell_id": shell_id,
|
||||
"signal": signal_name,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest shell signal response must be a JSON object",
|
||||
)
|
||||
self._shell_summary_from_payload(payload)
|
||||
return payload
|
||||
|
||||
def close_shell(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
*,
|
||||
shell_id: str,
|
||||
timeout_seconds: int = 30,
|
||||
uds_path: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
payload = self._request_json(
|
||||
guest_cid,
|
||||
port,
|
||||
{
|
||||
"action": "close_shell",
|
||||
"shell_id": shell_id,
|
||||
},
|
||||
timeout_seconds=timeout_seconds,
|
||||
uds_path=uds_path,
|
||||
error_message="guest shell close response must be a JSON object",
|
||||
)
|
||||
self._shell_summary_from_payload(payload)
|
||||
return payload
|
||||
|
||||
def _request_json(
|
||||
self,
|
||||
guest_cid: int,
|
||||
port: int,
|
||||
request: dict[str, Any],
|
||||
*,
|
||||
timeout_seconds: int,
|
||||
uds_path: str | None,
|
||||
error_message: str,
|
||||
) -> dict[str, Any]:
|
||||
sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path)
|
||||
try:
|
||||
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
|
||||
payload = self._recv_json_payload(sock)
|
||||
finally:
|
||||
sock.close()
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(error_message)
|
||||
error = payload.get("error")
|
||||
if error is not None:
|
||||
raise RuntimeError(str(error))
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _shell_summary_from_payload(payload: dict[str, Any]) -> GuestShellSummary:
|
||||
return GuestShellSummary(
|
||||
shell_id=str(payload.get("shell_id", "")),
|
||||
cwd=str(payload.get("cwd", "/workspace")),
|
||||
cols=int(payload.get("cols", 0)),
|
||||
rows=int(payload.get("rows", 0)),
|
||||
state=str(payload.get("state", "stopped")),
|
||||
started_at=float(payload.get("started_at", 0.0)),
|
||||
ended_at=(
|
||||
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
|
||||
),
|
||||
exit_code=(
|
||||
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
|
||||
),
|
||||
)
|
||||
|
||||
def _connect(
|
||||
self,
|
||||
guest_cid: int,
|
||||
|
|
|
|||
|
|
@ -27,8 +27,15 @@ from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_en
|
|||
from pyro_mcp.vm_firecracker import build_launch_plan
|
||||
from pyro_mcp.vm_guest import VsockExecClient
|
||||
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
|
||||
from pyro_mcp.workspace_shells import (
|
||||
create_local_shell,
|
||||
get_local_shell,
|
||||
remove_local_shell,
|
||||
shell_signal_names,
|
||||
)
|
||||
|
||||
VmState = Literal["created", "started", "stopped"]
|
||||
WorkspaceShellState = Literal["running", "stopped"]
|
||||
|
||||
DEFAULT_VCPU_COUNT = 1
|
||||
DEFAULT_MEM_MIB = 1024
|
||||
|
|
@ -36,13 +43,18 @@ DEFAULT_TIMEOUT_SECONDS = 30
|
|||
DEFAULT_TTL_SECONDS = 600
|
||||
DEFAULT_ALLOW_HOST_COMPAT = False
|
||||
|
||||
WORKSPACE_LAYOUT_VERSION = 2
|
||||
WORKSPACE_LAYOUT_VERSION = 3
|
||||
WORKSPACE_DIRNAME = "workspace"
|
||||
WORKSPACE_COMMANDS_DIRNAME = "commands"
|
||||
WORKSPACE_SHELLS_DIRNAME = "shells"
|
||||
WORKSPACE_RUNTIME_DIRNAME = "runtime"
|
||||
WORKSPACE_GUEST_PATH = "/workspace"
|
||||
WORKSPACE_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py"
|
||||
WORKSPACE_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60
|
||||
DEFAULT_SHELL_COLS = 120
|
||||
DEFAULT_SHELL_ROWS = 30
|
||||
DEFAULT_SHELL_MAX_CHARS = 65536
|
||||
WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names()
|
||||
|
||||
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
|
||||
|
||||
|
|
@ -183,6 +195,58 @@ class WorkspaceRecord:
|
|||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WorkspaceShellRecord:
|
||||
"""Persistent shell metadata stored on disk per workspace."""
|
||||
|
||||
workspace_id: str
|
||||
shell_id: str
|
||||
cwd: str
|
||||
cols: int
|
||||
rows: int
|
||||
state: WorkspaceShellState
|
||||
started_at: float
|
||||
ended_at: float | None = None
|
||||
exit_code: int | None = None
|
||||
execution_mode: str = "pending"
|
||||
metadata: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"layout_version": WORKSPACE_LAYOUT_VERSION,
|
||||
"workspace_id": self.workspace_id,
|
||||
"shell_id": self.shell_id,
|
||||
"cwd": self.cwd,
|
||||
"cols": self.cols,
|
||||
"rows": self.rows,
|
||||
"state": self.state,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"exit_code": self.exit_code,
|
||||
"execution_mode": self.execution_mode,
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> WorkspaceShellRecord:
|
||||
return cls(
|
||||
workspace_id=str(payload["workspace_id"]),
|
||||
shell_id=str(payload["shell_id"]),
|
||||
cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)),
|
||||
cols=int(payload.get("cols", DEFAULT_SHELL_COLS)),
|
||||
rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)),
|
||||
state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))),
|
||||
started_at=float(payload.get("started_at", 0.0)),
|
||||
ended_at=(
|
||||
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
|
||||
),
|
||||
exit_code=(
|
||||
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
|
||||
),
|
||||
execution_mode=str(payload.get("execution_mode", "pending")),
|
||||
metadata=_string_dict(payload.get("metadata")),
|
||||
)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PreparedWorkspaceSeed:
|
||||
"""Prepared host-side seed archive plus metadata."""
|
||||
|
|
@ -610,6 +674,59 @@ class VmBackend:
|
|||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def open_shell( # pragma: no cover
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def read_shell( # pragma: no cover
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cursor: int,
|
||||
max_chars: int,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def write_shell( # pragma: no cover
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
input_text: str,
|
||||
append_newline: bool,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def signal_shell( # pragma: no cover
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
signal_name: str,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
def close_shell( # pragma: no cover
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MockBackend(VmBackend):
|
||||
"""Host-process backend used for development and testability."""
|
||||
|
|
@ -651,6 +768,87 @@ class MockBackend(VmBackend):
|
|||
destination=destination,
|
||||
)
|
||||
|
||||
def open_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> dict[str, Any]:
|
||||
session = create_local_shell(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
cwd=_workspace_host_destination(_instance_workspace_host_dir(instance), cwd),
|
||||
display_cwd=cwd,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
)
|
||||
summary = session.summary()
|
||||
summary["execution_mode"] = "host_compat"
|
||||
return summary
|
||||
|
||||
def read_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cursor: int,
|
||||
max_chars: int,
|
||||
) -> dict[str, Any]:
|
||||
del instance
|
||||
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
|
||||
payload = session.read(cursor=cursor, max_chars=max_chars)
|
||||
payload["execution_mode"] = "host_compat"
|
||||
return payload
|
||||
|
||||
def write_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
input_text: str,
|
||||
append_newline: bool,
|
||||
) -> dict[str, Any]:
|
||||
del instance
|
||||
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
|
||||
payload = session.write(input_text, append_newline=append_newline)
|
||||
payload["execution_mode"] = "host_compat"
|
||||
return payload
|
||||
|
||||
def signal_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
signal_name: str,
|
||||
) -> dict[str, Any]:
|
||||
del instance
|
||||
session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id)
|
||||
payload = session.send_signal(signal_name)
|
||||
payload["execution_mode"] = "host_compat"
|
||||
return payload
|
||||
|
||||
def close_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
) -> dict[str, Any]:
|
||||
del instance
|
||||
session = remove_local_shell(workspace_id=workspace_id, shell_id=shell_id)
|
||||
if session is None:
|
||||
raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}")
|
||||
payload = session.close()
|
||||
payload["execution_mode"] = "host_compat"
|
||||
return payload
|
||||
|
||||
|
||||
class FirecrackerBackend(VmBackend): # pragma: no cover
|
||||
"""Host-gated backend that validates Firecracker prerequisites."""
|
||||
|
|
@ -888,6 +1086,144 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
|
|||
destination=destination,
|
||||
)
|
||||
|
||||
def open_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> dict[str, Any]:
|
||||
del workspace_id
|
||||
guest_cid = int(instance.metadata["guest_cid"])
|
||||
port = int(instance.metadata["guest_exec_port"])
|
||||
uds_path = instance.metadata.get("guest_exec_uds_path")
|
||||
response = self._guest_exec_client.open_shell(
|
||||
guest_cid,
|
||||
port,
|
||||
shell_id=shell_id,
|
||||
cwd=cwd,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
uds_path=uds_path,
|
||||
)
|
||||
return {
|
||||
"shell_id": response.shell_id or shell_id,
|
||||
"cwd": response.cwd,
|
||||
"cols": response.cols,
|
||||
"rows": response.rows,
|
||||
"state": response.state,
|
||||
"started_at": response.started_at,
|
||||
"ended_at": response.ended_at,
|
||||
"exit_code": response.exit_code,
|
||||
"execution_mode": instance.metadata.get("execution_mode", "pending"),
|
||||
}
|
||||
|
||||
def read_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cursor: int,
|
||||
max_chars: int,
|
||||
) -> dict[str, Any]:
|
||||
del workspace_id
|
||||
guest_cid = int(instance.metadata["guest_cid"])
|
||||
port = int(instance.metadata["guest_exec_port"])
|
||||
uds_path = instance.metadata.get("guest_exec_uds_path")
|
||||
response = self._guest_exec_client.read_shell(
|
||||
guest_cid,
|
||||
port,
|
||||
shell_id=shell_id,
|
||||
cursor=cursor,
|
||||
max_chars=max_chars,
|
||||
uds_path=uds_path,
|
||||
)
|
||||
return {
|
||||
"shell_id": response.shell_id,
|
||||
"cwd": response.cwd,
|
||||
"cols": response.cols,
|
||||
"rows": response.rows,
|
||||
"state": response.state,
|
||||
"started_at": response.started_at,
|
||||
"ended_at": response.ended_at,
|
||||
"exit_code": response.exit_code,
|
||||
"cursor": response.cursor,
|
||||
"next_cursor": response.next_cursor,
|
||||
"output": response.output,
|
||||
"truncated": response.truncated,
|
||||
"execution_mode": instance.metadata.get("execution_mode", "pending"),
|
||||
}
|
||||
|
||||
def write_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
input_text: str,
|
||||
append_newline: bool,
|
||||
) -> dict[str, Any]:
|
||||
del workspace_id
|
||||
guest_cid = int(instance.metadata["guest_cid"])
|
||||
port = int(instance.metadata["guest_exec_port"])
|
||||
uds_path = instance.metadata.get("guest_exec_uds_path")
|
||||
payload = self._guest_exec_client.write_shell(
|
||||
guest_cid,
|
||||
port,
|
||||
shell_id=shell_id,
|
||||
input_text=input_text,
|
||||
append_newline=append_newline,
|
||||
uds_path=uds_path,
|
||||
)
|
||||
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
|
||||
return payload
|
||||
|
||||
def signal_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
signal_name: str,
|
||||
) -> dict[str, Any]:
|
||||
del workspace_id
|
||||
guest_cid = int(instance.metadata["guest_cid"])
|
||||
port = int(instance.metadata["guest_exec_port"])
|
||||
uds_path = instance.metadata.get("guest_exec_uds_path")
|
||||
payload = self._guest_exec_client.signal_shell(
|
||||
guest_cid,
|
||||
port,
|
||||
shell_id=shell_id,
|
||||
signal_name=signal_name,
|
||||
uds_path=uds_path,
|
||||
)
|
||||
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
|
||||
return payload
|
||||
|
||||
def close_shell(
|
||||
self,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
) -> dict[str, Any]:
|
||||
del workspace_id
|
||||
guest_cid = int(instance.metadata["guest_cid"])
|
||||
port = int(instance.metadata["guest_exec_port"])
|
||||
uds_path = instance.metadata.get("guest_exec_uds_path")
|
||||
payload = self._guest_exec_client.close_shell(
|
||||
guest_cid,
|
||||
port,
|
||||
shell_id=shell_id,
|
||||
uds_path=uds_path,
|
||||
)
|
||||
payload["execution_mode"] = instance.metadata.get("execution_mode", "pending")
|
||||
return payload
|
||||
|
||||
|
||||
class VmManager:
|
||||
"""In-process lifecycle manager for ephemeral VM environments and workspaces."""
|
||||
|
|
@ -1151,9 +1487,11 @@ class VmManager:
|
|||
runtime_dir = self._workspace_runtime_dir(workspace_id)
|
||||
host_workspace_dir = self._workspace_host_dir(workspace_id)
|
||||
commands_dir = self._workspace_commands_dir(workspace_id)
|
||||
shells_dir = self._workspace_shells_dir(workspace_id)
|
||||
workspace_dir.mkdir(parents=True, exist_ok=False)
|
||||
host_workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||
commands_dir.mkdir(parents=True, exist_ok=True)
|
||||
shells_dir.mkdir(parents=True, exist_ok=True)
|
||||
instance = VmInstance(
|
||||
vm_id=workspace_id,
|
||||
environment=environment,
|
||||
|
|
@ -1179,11 +1517,8 @@ class VmManager:
|
|||
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
|
||||
)
|
||||
self._backend.create(instance)
|
||||
if (
|
||||
prepared_seed.archive_path is not None
|
||||
and self._runtime_capabilities.supports_guest_exec
|
||||
):
|
||||
self._ensure_workspace_guest_seed_support(instance)
|
||||
if self._runtime_capabilities.supports_guest_exec:
|
||||
self._ensure_workspace_guest_agent_support(instance)
|
||||
with self._lock:
|
||||
self._start_instance_locked(instance)
|
||||
self._require_guest_exec_or_opt_in(instance)
|
||||
|
|
@ -1332,6 +1667,208 @@ class VmManager:
|
|||
"cwd": WORKSPACE_GUEST_PATH,
|
||||
}
|
||||
|
||||
def open_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
*,
|
||||
cwd: str = WORKSPACE_GUEST_PATH,
|
||||
cols: int = DEFAULT_SHELL_COLS,
|
||||
rows: int = DEFAULT_SHELL_ROWS,
|
||||
) -> dict[str, Any]:
|
||||
if cols <= 0:
|
||||
raise ValueError("cols must be positive")
|
||||
if rows <= 0:
|
||||
raise ValueError("rows must be positive")
|
||||
normalized_cwd, _ = _normalize_workspace_destination(cwd)
|
||||
shell_id = uuid.uuid4().hex[:12]
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_shell_locked(workspace)
|
||||
payload = self._backend.open_shell(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
cwd=normalized_cwd,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
)
|
||||
shell = self._workspace_shell_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
payload=payload,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
workspace.firecracker_pid = instance.firecracker_pid
|
||||
workspace.last_error = instance.last_error
|
||||
workspace.metadata = dict(instance.metadata)
|
||||
self._save_workspace_locked(workspace)
|
||||
self._save_workspace_shell_locked(shell)
|
||||
return self._serialize_workspace_shell(shell)
|
||||
|
||||
def read_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
cursor: int = 0,
|
||||
max_chars: int = DEFAULT_SHELL_MAX_CHARS,
|
||||
) -> dict[str, Any]:
|
||||
if cursor < 0:
|
||||
raise ValueError("cursor must not be negative")
|
||||
if max_chars <= 0:
|
||||
raise ValueError("max_chars must be positive")
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_shell_locked(workspace)
|
||||
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
|
||||
payload = self._backend.read_shell(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
cursor=cursor,
|
||||
max_chars=max_chars,
|
||||
)
|
||||
updated_shell = self._workspace_shell_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
payload=payload,
|
||||
metadata=shell.metadata,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
workspace.firecracker_pid = instance.firecracker_pid
|
||||
workspace.last_error = instance.last_error
|
||||
workspace.metadata = dict(instance.metadata)
|
||||
self._save_workspace_locked(workspace)
|
||||
self._save_workspace_shell_locked(updated_shell)
|
||||
response = self._serialize_workspace_shell(updated_shell)
|
||||
response.update(
|
||||
{
|
||||
"cursor": int(payload.get("cursor", cursor)),
|
||||
"next_cursor": int(payload.get("next_cursor", cursor)),
|
||||
"output": str(payload.get("output", "")),
|
||||
"truncated": bool(payload.get("truncated", False)),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
def write_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
input_text: str,
|
||||
append_newline: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_shell_locked(workspace)
|
||||
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
|
||||
payload = self._backend.write_shell(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
input_text=input_text,
|
||||
append_newline=append_newline,
|
||||
)
|
||||
updated_shell = self._workspace_shell_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
payload=payload,
|
||||
metadata=shell.metadata,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
workspace.firecracker_pid = instance.firecracker_pid
|
||||
workspace.last_error = instance.last_error
|
||||
workspace.metadata = dict(instance.metadata)
|
||||
self._save_workspace_locked(workspace)
|
||||
self._save_workspace_shell_locked(updated_shell)
|
||||
response = self._serialize_workspace_shell(updated_shell)
|
||||
response.update(
|
||||
{
|
||||
"input_length": int(payload.get("input_length", len(input_text))),
|
||||
"append_newline": bool(payload.get("append_newline", append_newline)),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
def signal_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
*,
|
||||
signal_name: str = "INT",
|
||||
) -> dict[str, Any]:
|
||||
normalized_signal = signal_name.upper()
|
||||
if normalized_signal not in WORKSPACE_SHELL_SIGNAL_NAMES:
|
||||
raise ValueError(
|
||||
f"signal_name must be one of: {', '.join(WORKSPACE_SHELL_SIGNAL_NAMES)}"
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_shell_locked(workspace)
|
||||
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
|
||||
payload = self._backend.signal_shell(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
signal_name=normalized_signal,
|
||||
)
|
||||
updated_shell = self._workspace_shell_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
payload=payload,
|
||||
metadata=shell.metadata,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
workspace.firecracker_pid = instance.firecracker_pid
|
||||
workspace.last_error = instance.last_error
|
||||
workspace.metadata = dict(instance.metadata)
|
||||
self._save_workspace_locked(workspace)
|
||||
self._save_workspace_shell_locked(updated_shell)
|
||||
response = self._serialize_workspace_shell(updated_shell)
|
||||
response["signal"] = str(payload.get("signal", normalized_signal))
|
||||
return response
|
||||
|
||||
def close_shell(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_shell_locked(workspace)
|
||||
shell = self._load_workspace_shell_locked(workspace_id, shell_id)
|
||||
payload = self._backend.close_shell(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
)
|
||||
closed_shell = self._workspace_shell_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=shell_id,
|
||||
payload=payload,
|
||||
metadata=shell.metadata,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
workspace.firecracker_pid = instance.firecracker_pid
|
||||
workspace.last_error = instance.last_error
|
||||
workspace.metadata = dict(instance.metadata)
|
||||
self._save_workspace_locked(workspace)
|
||||
self._delete_workspace_shell_locked(workspace_id, shell_id)
|
||||
response = self._serialize_workspace_shell(closed_shell)
|
||||
response["closed"] = bool(payload.get("closed", True))
|
||||
return response
|
||||
|
||||
def status_workspace(self, workspace_id: str) -> dict[str, Any]:
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
|
|
@ -1364,6 +1901,7 @@ class VmManager:
|
|||
instance = workspace.to_instance(
|
||||
workdir=self._workspace_runtime_dir(workspace.workspace_id)
|
||||
)
|
||||
self._close_workspace_shells_locked(workspace, instance)
|
||||
if workspace.state == "started":
|
||||
self._backend.stop(instance)
|
||||
workspace.state = "stopped"
|
||||
|
|
@ -1423,6 +1961,20 @@ class VmManager:
|
|||
"metadata": workspace.metadata,
|
||||
}
|
||||
|
||||
def _serialize_workspace_shell(self, shell: WorkspaceShellRecord) -> dict[str, Any]:
|
||||
return {
|
||||
"workspace_id": shell.workspace_id,
|
||||
"shell_id": shell.shell_id,
|
||||
"cwd": shell.cwd,
|
||||
"cols": shell.cols,
|
||||
"rows": shell.rows,
|
||||
"state": shell.state,
|
||||
"started_at": shell.started_at,
|
||||
"ended_at": shell.ended_at,
|
||||
"exit_code": shell.exit_code,
|
||||
"execution_mode": shell.execution_mode,
|
||||
}
|
||||
|
||||
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
|
||||
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
|
||||
return
|
||||
|
|
@ -1445,6 +1997,19 @@ class VmManager:
|
|||
"host execution."
|
||||
)
|
||||
|
||||
def _require_workspace_shell_support(self, instance: VmInstance) -> None:
|
||||
if self._backend_name == "mock":
|
||||
return
|
||||
if self._runtime_capabilities.supports_guest_exec:
|
||||
return
|
||||
reason = self._runtime_capabilities.reason or (
|
||||
"runtime does not support guest interactive shell sessions"
|
||||
)
|
||||
raise RuntimeError(
|
||||
"interactive shells require guest execution and are unavailable for this "
|
||||
f"workspace: {reason}"
|
||||
)
|
||||
|
||||
def _get_instance_locked(self, vm_id: str) -> VmInstance:
|
||||
try:
|
||||
return self._instances[vm_id]
|
||||
|
|
@ -1552,14 +2117,14 @@ class VmManager:
|
|||
bytes_written=bytes_written,
|
||||
)
|
||||
|
||||
def _ensure_workspace_guest_seed_support(self, instance: VmInstance) -> None:
|
||||
def _ensure_workspace_guest_agent_support(self, instance: VmInstance) -> None:
|
||||
if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None:
|
||||
raise RuntimeError(
|
||||
"runtime bundle does not provide a guest agent for workspace seeding"
|
||||
"runtime bundle does not provide a guest agent for workspace operations"
|
||||
)
|
||||
rootfs_image = instance.metadata.get("rootfs_image")
|
||||
if rootfs_image is None or rootfs_image == "":
|
||||
raise RuntimeError("workspace rootfs image is unavailable for guest seeding")
|
||||
raise RuntimeError("workspace rootfs image is unavailable for guest operations")
|
||||
_patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path)
|
||||
|
||||
def _workspace_dir(self, workspace_id: str) -> Path:
|
||||
|
|
@ -1574,9 +2139,15 @@ class VmManager:
|
|||
def _workspace_commands_dir(self, workspace_id: str) -> Path:
|
||||
return self._workspace_dir(workspace_id) / WORKSPACE_COMMANDS_DIRNAME
|
||||
|
||||
def _workspace_shells_dir(self, workspace_id: str) -> Path:
|
||||
return self._workspace_dir(workspace_id) / WORKSPACE_SHELLS_DIRNAME
|
||||
|
||||
def _workspace_metadata_path(self, workspace_id: str) -> Path:
|
||||
return self._workspace_dir(workspace_id) / "workspace.json"
|
||||
|
||||
def _workspace_shell_record_path(self, workspace_id: str, shell_id: str) -> Path:
|
||||
return self._workspace_shells_dir(workspace_id) / f"{shell_id}.json"
|
||||
|
||||
def _count_workspaces_locked(self) -> int:
|
||||
return sum(1 for _ in self._workspaces_dir.glob("*/workspace.json"))
|
||||
|
||||
|
|
@ -1609,6 +2180,7 @@ class VmManager:
|
|||
instance = workspace.to_instance(
|
||||
workdir=self._workspace_runtime_dir(workspace.workspace_id)
|
||||
)
|
||||
self._close_workspace_shells_locked(workspace, instance)
|
||||
if workspace.state == "started":
|
||||
self._backend.stop(instance)
|
||||
workspace.state = "stopped"
|
||||
|
|
@ -1704,3 +2276,97 @@ class VmManager:
|
|||
entry["stderr"] = stderr
|
||||
entries.append(entry)
|
||||
return entries
|
||||
|
||||
def _workspace_instance_for_live_shell_locked(self, workspace: WorkspaceRecord) -> VmInstance:
|
||||
self._ensure_workspace_not_expired_locked(workspace, time.time())
|
||||
self._refresh_workspace_liveness_locked(workspace)
|
||||
if workspace.state != "started":
|
||||
raise RuntimeError(
|
||||
"workspace "
|
||||
f"{workspace.workspace_id} must be in 'started' state before shell operations"
|
||||
)
|
||||
instance = workspace.to_instance(
|
||||
workdir=self._workspace_runtime_dir(workspace.workspace_id)
|
||||
)
|
||||
self._require_workspace_shell_support(instance)
|
||||
return instance
|
||||
|
||||
def _workspace_shell_record_from_payload(
|
||||
self,
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
payload: dict[str, Any],
|
||||
metadata: dict[str, str] | None = None,
|
||||
) -> WorkspaceShellRecord:
|
||||
return WorkspaceShellRecord(
|
||||
workspace_id=workspace_id,
|
||||
shell_id=str(payload.get("shell_id", shell_id)),
|
||||
cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)),
|
||||
cols=int(payload.get("cols", DEFAULT_SHELL_COLS)),
|
||||
rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)),
|
||||
state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))),
|
||||
started_at=float(payload.get("started_at", time.time())),
|
||||
ended_at=(
|
||||
None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0))
|
||||
),
|
||||
exit_code=(
|
||||
None if payload.get("exit_code") is None else int(payload.get("exit_code", 0))
|
||||
),
|
||||
execution_mode=str(payload.get("execution_mode", "pending")),
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
def _load_workspace_shell_locked(
|
||||
self,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
) -> WorkspaceShellRecord:
|
||||
record_path = self._workspace_shell_record_path(workspace_id, shell_id)
|
||||
if not record_path.exists():
|
||||
raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}")
|
||||
payload = json.loads(record_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError(f"shell record at {record_path} is invalid")
|
||||
return WorkspaceShellRecord.from_payload(payload)
|
||||
|
||||
def _save_workspace_shell_locked(self, shell: WorkspaceShellRecord) -> None:
|
||||
record_path = self._workspace_shell_record_path(shell.workspace_id, shell.shell_id)
|
||||
record_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
record_path.write_text(
|
||||
json.dumps(shell.to_payload(), indent=2, sort_keys=True),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _delete_workspace_shell_locked(self, workspace_id: str, shell_id: str) -> None:
|
||||
record_path = self._workspace_shell_record_path(workspace_id, shell_id)
|
||||
if record_path.exists():
|
||||
record_path.unlink()
|
||||
|
||||
def _list_workspace_shells_locked(self, workspace_id: str) -> list[WorkspaceShellRecord]:
|
||||
shells_dir = self._workspace_shells_dir(workspace_id)
|
||||
if not shells_dir.exists():
|
||||
return []
|
||||
shells: list[WorkspaceShellRecord] = []
|
||||
for record_path in sorted(shells_dir.glob("*.json")):
|
||||
payload = json.loads(record_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
continue
|
||||
shells.append(WorkspaceShellRecord.from_payload(payload))
|
||||
return shells
|
||||
|
||||
def _close_workspace_shells_locked(
|
||||
self,
|
||||
workspace: WorkspaceRecord,
|
||||
instance: VmInstance,
|
||||
) -> None:
|
||||
for shell in self._list_workspace_shells_locked(workspace.workspace_id):
|
||||
try:
|
||||
self._backend.close_shell(
|
||||
instance,
|
||||
workspace_id=workspace.workspace_id,
|
||||
shell_id=shell.shell_id,
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
self._delete_workspace_shell_locked(workspace.workspace_id, shell.shell_id)
|
||||
|
|
|
|||
291
src/pyro_mcp/workspace_shells.py
Normal file
291
src/pyro_mcp/workspace_shells.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
"""Local PTY-backed shell sessions for the mock workspace backend."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import codecs
|
||||
import fcntl
|
||||
import os
|
||||
import pty
|
||||
import shlex
|
||||
import signal
|
||||
import struct
|
||||
import subprocess
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
ShellState = Literal["running", "stopped"]
|
||||
|
||||
SHELL_SIGNAL_NAMES = ("HUP", "INT", "TERM", "KILL")
|
||||
_SHELL_SIGNAL_MAP = {
|
||||
"HUP": signal.SIGHUP,
|
||||
"INT": signal.SIGINT,
|
||||
"TERM": signal.SIGTERM,
|
||||
"KILL": signal.SIGKILL,
|
||||
}
|
||||
|
||||
_LOCAL_SHELLS: dict[str, "LocalShellSession"] = {}
|
||||
_LOCAL_SHELLS_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
|
||||
winsize = struct.pack("HHHH", rows, cols, 0, 0)
|
||||
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
|
||||
|
||||
|
||||
class LocalShellSession:
|
||||
"""Host-local interactive shell used by the mock backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
shell_id: str,
|
||||
cwd: Path,
|
||||
display_cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> None:
|
||||
self.shell_id = shell_id
|
||||
self.cwd = display_cwd
|
||||
self.cols = cols
|
||||
self.rows = rows
|
||||
self.started_at = time.time()
|
||||
self.ended_at: float | None = None
|
||||
self.exit_code: int | None = None
|
||||
self.state: ShellState = "running"
|
||||
self.pid: int | None = None
|
||||
self._lock = threading.RLock()
|
||||
self._output = ""
|
||||
self._master_fd: int | None = None
|
||||
self._reader: threading.Thread | None = None
|
||||
self._waiter: threading.Thread | None = None
|
||||
self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
|
||||
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
try:
|
||||
_set_pty_size(slave_fd, rows, cols)
|
||||
env = os.environ.copy()
|
||||
env.update(
|
||||
{
|
||||
"TERM": env.get("TERM", "xterm-256color"),
|
||||
"PS1": "pyro$ ",
|
||||
"PROMPT_COMMAND": "",
|
||||
}
|
||||
)
|
||||
process = subprocess.Popen( # noqa: S603
|
||||
["/bin/bash", "--noprofile", "--norc", "-i"],
|
||||
stdin=slave_fd,
|
||||
stdout=slave_fd,
|
||||
stderr=slave_fd,
|
||||
cwd=str(cwd),
|
||||
env=env,
|
||||
text=False,
|
||||
close_fds=True,
|
||||
preexec_fn=os.setsid,
|
||||
)
|
||||
except Exception:
|
||||
os.close(master_fd)
|
||||
raise
|
||||
finally:
|
||||
os.close(slave_fd)
|
||||
|
||||
self._process = process
|
||||
self.pid = process.pid
|
||||
self._master_fd = master_fd
|
||||
self._reader = threading.Thread(target=self._reader_loop, daemon=True)
|
||||
self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
|
||||
self._reader.start()
|
||||
self._waiter.start()
|
||||
|
||||
def summary(self) -> dict[str, object]:
|
||||
with self._lock:
|
||||
return {
|
||||
"shell_id": self.shell_id,
|
||||
"cwd": self.cwd,
|
||||
"cols": self.cols,
|
||||
"rows": self.rows,
|
||||
"state": self.state,
|
||||
"started_at": self.started_at,
|
||||
"ended_at": self.ended_at,
|
||||
"exit_code": self.exit_code,
|
||||
"pid": self.pid,
|
||||
}
|
||||
|
||||
def read(self, *, cursor: int, max_chars: int) -> dict[str, object]:
|
||||
with self._lock:
|
||||
clamped_cursor = min(max(cursor, 0), len(self._output))
|
||||
output = self._output[clamped_cursor : clamped_cursor + max_chars]
|
||||
next_cursor = clamped_cursor + len(output)
|
||||
payload = self.summary()
|
||||
payload.update(
|
||||
{
|
||||
"cursor": clamped_cursor,
|
||||
"next_cursor": next_cursor,
|
||||
"output": output,
|
||||
"truncated": next_cursor < len(self._output),
|
||||
}
|
||||
)
|
||||
return payload
|
||||
|
||||
def write(self, text: str, *, append_newline: bool) -> dict[str, object]:
|
||||
if self._process.poll() is not None:
|
||||
self._refresh_process_state()
|
||||
with self._lock:
|
||||
if self.state != "running":
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running")
|
||||
master_fd = self._master_fd
|
||||
if master_fd is None:
|
||||
raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
|
||||
payload = text + ("\n" if append_newline else "")
|
||||
try:
|
||||
os.write(master_fd, payload.encode("utf-8"))
|
||||
except OSError as exc:
|
||||
self._refresh_process_state()
|
||||
raise RuntimeError(f"failed to write to shell {self.shell_id}: {exc}") from exc
|
||||
result = self.summary()
|
||||
result.update({"input_length": len(text), "append_newline": append_newline})
|
||||
return result
|
||||
|
||||
def send_signal(self, signal_name: str) -> dict[str, object]:
|
||||
signal_name = signal_name.upper()
|
||||
signum = _SHELL_SIGNAL_MAP.get(signal_name)
|
||||
if signum is None:
|
||||
raise ValueError(f"unsupported shell signal: {signal_name}")
|
||||
if self._process.poll() is not None:
|
||||
self._refresh_process_state()
|
||||
with self._lock:
|
||||
if self.state != "running" or self.pid is None:
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running")
|
||||
pid = self.pid
|
||||
try:
|
||||
os.killpg(pid, signum)
|
||||
except ProcessLookupError as exc:
|
||||
self._refresh_process_state()
|
||||
raise RuntimeError(f"shell {self.shell_id} is not running") from exc
|
||||
result = self.summary()
|
||||
result["signal"] = signal_name
|
||||
return result
|
||||
|
||||
def close(self) -> dict[str, object]:
|
||||
if self._process.poll() is None and self.pid is not None:
|
||||
try:
|
||||
os.killpg(self.pid, signal.SIGHUP)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
try:
|
||||
self._process.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
try:
|
||||
os.killpg(self.pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
self._process.wait(timeout=5)
|
||||
else:
|
||||
self._refresh_process_state()
|
||||
self._close_master_fd()
|
||||
if self._reader is not None:
|
||||
self._reader.join(timeout=1)
|
||||
if self._waiter is not None:
|
||||
self._waiter.join(timeout=1)
|
||||
result = self.summary()
|
||||
result["closed"] = True
|
||||
return result
|
||||
|
||||
def _reader_loop(self) -> None:
|
||||
master_fd = self._master_fd
|
||||
if master_fd is None:
|
||||
return
|
||||
while True:
|
||||
try:
|
||||
chunk = os.read(master_fd, 65536)
|
||||
except OSError:
|
||||
break
|
||||
if chunk == b"":
|
||||
break
|
||||
decoded = self._decoder.decode(chunk)
|
||||
if decoded:
|
||||
with self._lock:
|
||||
self._output += decoded
|
||||
decoded = self._decoder.decode(b"", final=True)
|
||||
if decoded:
|
||||
with self._lock:
|
||||
self._output += decoded
|
||||
|
||||
def _waiter_loop(self) -> None:
|
||||
exit_code = self._process.wait()
|
||||
with self._lock:
|
||||
self.state = "stopped"
|
||||
self.exit_code = exit_code
|
||||
self.ended_at = time.time()
|
||||
|
||||
def _refresh_process_state(self) -> None:
|
||||
exit_code = self._process.poll()
|
||||
if exit_code is None:
|
||||
return
|
||||
with self._lock:
|
||||
if self.state == "running":
|
||||
self.state = "stopped"
|
||||
self.exit_code = exit_code
|
||||
self.ended_at = time.time()
|
||||
|
||||
def _close_master_fd(self) -> None:
|
||||
with self._lock:
|
||||
master_fd = self._master_fd
|
||||
self._master_fd = None
|
||||
if master_fd is None:
|
||||
return
|
||||
try:
|
||||
os.close(master_fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def create_local_shell(
|
||||
*,
|
||||
workspace_id: str,
|
||||
shell_id: str,
|
||||
cwd: Path,
|
||||
display_cwd: str,
|
||||
cols: int,
|
||||
rows: int,
|
||||
) -> LocalShellSession:
|
||||
session_key = f"{workspace_id}:{shell_id}"
|
||||
with _LOCAL_SHELLS_LOCK:
|
||||
if session_key in _LOCAL_SHELLS:
|
||||
raise RuntimeError(f"shell {shell_id} already exists in workspace {workspace_id}")
|
||||
session = LocalShellSession(
|
||||
shell_id=shell_id,
|
||||
cwd=cwd,
|
||||
display_cwd=display_cwd,
|
||||
cols=cols,
|
||||
rows=rows,
|
||||
)
|
||||
_LOCAL_SHELLS[session_key] = session
|
||||
return session
|
||||
|
||||
|
||||
def get_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession:
|
||||
session_key = f"{workspace_id}:{shell_id}"
|
||||
with _LOCAL_SHELLS_LOCK:
|
||||
try:
|
||||
return _LOCAL_SHELLS[session_key]
|
||||
except KeyError as exc:
|
||||
raise ValueError(
|
||||
f"shell {shell_id!r} does not exist in workspace {workspace_id!r}"
|
||||
) from exc
|
||||
|
||||
|
||||
def remove_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession | None:
|
||||
session_key = f"{workspace_id}:{shell_id}"
|
||||
with _LOCAL_SHELLS_LOCK:
|
||||
return _LOCAL_SHELLS.pop(session_key, None)
|
||||
|
||||
|
||||
def shell_signal_names() -> tuple[str, ...]:
|
||||
return SHELL_SIGNAL_NAMES
|
||||
|
||||
|
||||
def shell_signal_arg_help() -> str:
|
||||
return ", ".join(shlex.quote(name) for name in SHELL_SIGNAL_NAMES)
|
||||
Loading…
Add table
Add a link
Reference in a new issue