From 3f8293ad241badc37e3c30754f7a27b8585d529d Mon Sep 17 00:00:00 2001 From: Thales Maciel Date: Thu, 12 Mar 2026 02:31:57 -0300 Subject: [PATCH] Add persistent workspace shell sessions Let agents inhabit a workspace across separate calls instead of only submitting one-shot execs. Add workspace shell open/read/write/signal/close across the CLI, Python SDK, and MCP server, with persisted shell records, a local PTY-backed mock implementation, and guest-agent support for real Firecracker workspaces. Mark the 2.5.0 roadmap milestone done, refresh docs/examples and the release metadata, and verify with uv lock, UV_CACHE_DIR=.uv-cache make check, and UV_CACHE_DIR=.uv-cache make dist-check. --- CHANGELOG.md | 10 + README.md | 17 +- docs/first-run.md | 18 +- docs/install.md | 12 +- docs/integrations.md | 4 + docs/public-contract.md | 25 + docs/roadmap/task-workspace-ga.md | 7 +- .../2.5.0-pty-shell-sessions.md | 2 + examples/python_shell.py | 34 + pyproject.toml | 2 +- .../linux-x86_64/guest/pyro_guest_agent.py | 369 +++++++++- src/pyro_mcp/api.py | 119 +++ src/pyro_mcp/cli.py | 312 ++++++++ src/pyro_mcp/contract.py | 26 +- .../linux-x86_64/guest/pyro_guest_agent.py | 369 +++++++++- .../runtime_bundle/linux-x86_64/manifest.json | 2 +- src/pyro_mcp/vm_environments.py | 2 +- src/pyro_mcp/vm_guest.py | 230 +++++- src/pyro_mcp/vm_manager.py | 684 +++++++++++++++++- src/pyro_mcp/workspace_shells.py | 291 ++++++++ tests/test_api.py | 20 + tests/test_cli.py | 256 +++++++ tests/test_public_contract.py | 37 + tests/test_server.py | 85 ++- tests/test_vm_guest.py | 124 ++++ tests/test_vm_manager.py | 67 ++ tests/test_workspace_shells.py | 220 ++++++ uv.lock | 2 +- 28 files changed, 3265 insertions(+), 81 deletions(-) create mode 100644 examples/python_shell.py create mode 100644 src/pyro_mcp/workspace_shells.py create mode 100644 tests/test_workspace_shells.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 5311308..c0269c6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,16 @@ All notable user-visible changes to `pyro-mcp` are documented here. +## 2.5.0 + +- Added persistent PTY shell sessions across the CLI, Python SDK, and MCP server with + `pyro workspace shell *`, `Pyro.open_shell()` / `read_shell()` / `write_shell()` / + `signal_shell()` / `close_shell()`, and `shell_*` MCP tools. +- Kept interactive shells separate from `workspace exec`, with cursor-based merged output reads + and explicit close/signal operations for long-lived workspace sessions. +- Updated the bundled guest agent and mock backend so shell sessions persist across separate + calls and are cleaned up automatically by `workspace delete`. + ## 2.4.0 - Replaced the public persistent-workspace surface from `task_*` to `workspace_*` across the CLI, diff --git a/README.md b/README.md index 86b5a91..11fac75 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ It exposes the same runtime in three public forms: - First run transcript: [docs/first-run.md](docs/first-run.md) - Terminal walkthrough GIF: [docs/assets/first-run.gif](docs/assets/first-run.gif) - PyPI package: [pypi.org/project/pyro-mcp](https://pypi.org/project/pyro-mcp/) -- What's new in 2.4.0: [CHANGELOG.md#240](CHANGELOG.md#240) +- What's new in 2.5.0: [CHANGELOG.md#250](CHANGELOG.md#250) - Host requirements: [docs/host-requirements.md](docs/host-requirements.md) - Integration targets: [docs/integrations.md](docs/integrations.md) - Public contract: [docs/public-contract.md](docs/public-contract.md) @@ -57,7 +57,7 @@ What success looks like: ```bash Platform: linux-x86_64 Runtime: PASS -Catalog version: 2.4.0 +Catalog version: 2.5.0 ... [pull] phase=install environment=debian:12 [pull] phase=ready environment=debian:12 @@ -78,6 +78,7 @@ After the quickstart works: - prove the full one-shot lifecycle with `uvx --from pyro-mcp pyro demo` - create a persistent workspace with `uvx --from pyro-mcp pyro workspace create debian:12 --seed-path ./repo` - update a live workspace from the host with `uvx --from pyro-mcp pyro workspace sync push WORKSPACE_ID ./changes` +- open a persistent interactive shell with `uvx --from pyro-mcp pyro workspace shell open WORKSPACE_ID` - move to Python or MCP via [docs/integrations.md](docs/integrations.md) ## Supported Hosts @@ -131,7 +132,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.4.0 +Catalog version: 2.5.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -209,6 +210,10 @@ longer-term interaction model. pyro workspace create debian:12 --seed-path ./repo pyro workspace sync push WORKSPACE_ID ./changes --dest src pyro workspace exec WORKSPACE_ID -- cat src/note.txt +pyro workspace shell open WORKSPACE_ID +pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd' +pyro workspace shell read WORKSPACE_ID SHELL_ID +pyro workspace shell close WORKSPACE_ID SHELL_ID pyro workspace logs WORKSPACE_ID pyro workspace delete WORKSPACE_ID ``` @@ -217,8 +222,10 @@ Persistent workspaces start in `/workspace` and keep command history until you d machine consumption, add `--json` and read the returned `workspace_id`. Use `--seed-path` when you want the workspace to start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive instead of an empty workspace. Use `pyro workspace sync push` when you want to import -later host-side changes into a started workspace. Sync is non-atomic in `2.4.0`; if it fails -partway through, delete and recreate the workspace from its seed. +later host-side changes into a started workspace. Sync is non-atomic in `2.5.0`; if it fails +partway through, delete and recreate the workspace from its seed. Use `pyro workspace exec` for +one-shot non-interactive commands inside a live workspace, and `pyro workspace shell *` when you +need a persistent PTY session that keeps interactive shell state between calls. ## Public Interfaces diff --git a/docs/first-run.md b/docs/first-run.md index ce3adcd..b0d0e73 100644 --- a/docs/first-run.md +++ b/docs/first-run.md @@ -22,7 +22,7 @@ Networking: tun=yes ip_forward=yes ```bash $ uvx --from pyro-mcp pyro env list -Catalog version: 2.4.0 +Catalog version: 2.5.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -72,6 +72,7 @@ deterministic structured result. $ uvx --from pyro-mcp pyro demo $ uvx --from pyro-mcp pyro workspace create debian:12 --seed-path ./repo $ uvx --from pyro-mcp pyro workspace sync push WORKSPACE_ID ./changes +$ uvx --from pyro-mcp pyro workspace shell open WORKSPACE_ID $ uvx --from pyro-mcp pyro mcp serve ``` @@ -96,13 +97,24 @@ $ uvx --from pyro-mcp pyro workspace sync push WORKSPACE_ID ./changes --dest src $ uvx --from pyro-mcp pyro workspace exec WORKSPACE_ID -- cat src/note.txt hello from synced workspace [workspace-exec] workspace_id=... sequence=1 cwd=/workspace execution_mode=guest_vsock exit_code=0 duration_ms=... + +$ uvx --from pyro-mcp pyro workspace shell open WORKSPACE_ID +[workspace-shell-open] workspace_id=... shell_id=... state=running cwd=/workspace cols=120 rows=30 execution_mode=guest_vsock + +$ uvx --from pyro-mcp pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd' +[workspace-shell-write] workspace_id=... shell_id=... state=running cwd=/workspace cols=120 rows=30 execution_mode=guest_vsock + +$ uvx --from pyro-mcp pyro workspace shell read WORKSPACE_ID SHELL_ID +/workspace +[workspace-shell-read] workspace_id=... shell_id=... state=running cursor=0 next_cursor=... truncated=False execution_mode=guest_vsock ``` Use `--seed-path` when the workspace should start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive instead of an empty `/workspace`. Use `pyro workspace sync push` when you need to import later host-side changes into a started -workspace. Sync is non-atomic in `2.4.0`; if it fails partway through, delete and recreate the -workspace. +workspace. Sync is non-atomic in `2.5.0`; if it fails partway through, delete and recreate the +workspace. Use `pyro workspace exec` for one-shot commands and `pyro workspace shell *` when you +need a persistent interactive PTY session in that same workspace. Example output: diff --git a/docs/install.md b/docs/install.md index 69f18ba..4e7300b 100644 --- a/docs/install.md +++ b/docs/install.md @@ -83,7 +83,7 @@ uvx --from pyro-mcp pyro env list Expected output: ```bash -Catalog version: 2.4.0 +Catalog version: 2.5.0 debian:12 [installed|not installed] Debian 12 environment with Git preinstalled for common agent workflows. debian:12-base [installed|not installed] Minimal Debian 12 environment for shell and core Unix tooling. debian:12-build [installed|not installed] Debian 12 environment with Git and common build tools preinstalled. @@ -176,6 +176,7 @@ After the CLI path works, you can move on to: - persistent workspaces: `pyro workspace create debian:12 --seed-path ./repo` - live workspace updates: `pyro workspace sync push WORKSPACE_ID ./changes` +- interactive shells: `pyro workspace shell open WORKSPACE_ID` - MCP: `pyro mcp serve` - Python SDK: `from pyro_mcp import Pyro` - Demos: `pyro demo` or `pyro demo --network` @@ -188,6 +189,10 @@ Use `pyro workspace ...` when you need repeated commands in one sandbox instead pyro workspace create debian:12 --seed-path ./repo pyro workspace sync push WORKSPACE_ID ./changes --dest src pyro workspace exec WORKSPACE_ID -- cat src/note.txt +pyro workspace shell open WORKSPACE_ID +pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd' +pyro workspace shell read WORKSPACE_ID SHELL_ID +pyro workspace shell close WORKSPACE_ID SHELL_ID pyro workspace logs WORKSPACE_ID pyro workspace delete WORKSPACE_ID ``` @@ -196,8 +201,9 @@ Workspace commands default to the persistent `/workspace` directory inside the g the identifier programmatically, use `--json` and read the `workspace_id` field. Use `--seed-path` when the workspace should start from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive. Use `pyro workspace sync push` for later host-side changes to a started workspace. Sync -is non-atomic in `2.4.0`; if it fails partway through, delete and recreate the workspace from its -seed. +is non-atomic in `2.5.0`; if it fails partway through, delete and recreate the workspace from its +seed. Use `pyro workspace exec` for one-shot commands and `pyro workspace shell *` when you need +an interactive PTY that survives across separate calls. ## Contributor Clone diff --git a/docs/integrations.md b/docs/integrations.md index 0061ebb..9be184d 100644 --- a/docs/integrations.md +++ b/docs/integrations.md @@ -31,6 +31,7 @@ Recommended surface: - `vm_run` - `workspace_create(seed_path=...)` + `workspace_sync_push` + `workspace_exec` when the agent needs persistent workspace state +- `open_shell` / `read_shell` / `write_shell` when the agent needs an interactive PTY inside that workspace Canonical example: @@ -66,6 +67,7 @@ Recommended default: - `Pyro.run_in_vm(...)` - `Pyro.create_workspace(seed_path=...)` + `Pyro.push_workspace_sync(...)` + `Pyro.exec_workspace(...)` when repeated workspace commands are required +- `Pyro.open_shell(...)` + `Pyro.write_shell(...)` + `Pyro.read_shell(...)` when the agent needs an interactive PTY inside the workspace Lifecycle note: @@ -76,12 +78,14 @@ Lifecycle note: `/workspace` that starts from host content - use `push_workspace_sync(...)` when later host-side changes need to be imported into that running workspace without recreating it +- use `open_shell(...)` when the agent needs interactive shell state instead of one-shot execs Examples: - [examples/python_run.py](../examples/python_run.py) - [examples/python_lifecycle.py](../examples/python_lifecycle.py) - [examples/python_workspace.py](../examples/python_workspace.py) +- [examples/python_shell.py](../examples/python_shell.py) ## Agent Framework Wrappers diff --git a/docs/public-contract.md b/docs/public-contract.md index 49d1746..65e5b74 100644 --- a/docs/public-contract.md +++ b/docs/public-contract.md @@ -22,6 +22,11 @@ Top-level commands: - `pyro workspace create` - `pyro workspace sync push` - `pyro workspace exec` +- `pyro workspace shell open` +- `pyro workspace shell read` +- `pyro workspace shell write` +- `pyro workspace shell signal` +- `pyro workspace shell close` - `pyro workspace status` - `pyro workspace logs` - `pyro workspace delete` @@ -50,6 +55,7 @@ Behavioral guarantees: - `pyro workspace create --seed-path PATH` seeds `/workspace` from a host directory or a local `.tar` / `.tar.gz` / `.tgz` archive before the workspace is returned. - `pyro workspace sync push WORKSPACE_ID SOURCE_PATH [--dest WORKSPACE_PATH]` imports later host-side directory or archive content into a started workspace. - `pyro workspace exec` runs in the persistent `/workspace` for that workspace and does not auto-clean. +- `pyro workspace shell *` manages persistent PTY sessions inside a started workspace. - `pyro workspace logs` returns persisted command history for that workspace until `pyro workspace delete`. - Workspace create/status results expose `workspace_seed` metadata describing how `/workspace` was initialized. @@ -70,6 +76,11 @@ Supported public entrypoints: - `Pyro.create_vm(...)` - `Pyro.create_workspace(...)` - `Pyro.push_workspace_sync(workspace_id, source_path, *, dest="/workspace")` +- `Pyro.open_shell(workspace_id, *, cwd="/workspace", cols=120, rows=30)` +- `Pyro.read_shell(workspace_id, shell_id, *, cursor=0, max_chars=65536)` +- `Pyro.write_shell(workspace_id, shell_id, *, input, append_newline=True)` +- `Pyro.signal_shell(workspace_id, shell_id, *, signal_name="INT")` +- `Pyro.close_shell(workspace_id, shell_id)` - `Pyro.start_vm(vm_id)` - `Pyro.exec_vm(vm_id, *, command, timeout_seconds=30)` - `Pyro.exec_workspace(workspace_id, *, command, timeout_seconds=30)` @@ -93,6 +104,11 @@ Stable public method names: - `create_vm(...)` - `create_workspace(...)` - `push_workspace_sync(workspace_id, source_path, *, dest="/workspace")` +- `open_shell(workspace_id, *, cwd="/workspace", cols=120, rows=30)` +- `read_shell(workspace_id, shell_id, *, cursor=0, max_chars=65536)` +- `write_shell(workspace_id, shell_id, *, input, append_newline=True)` +- `signal_shell(workspace_id, shell_id, *, signal_name="INT")` +- `close_shell(workspace_id, shell_id)` - `start_vm(vm_id)` - `exec_vm(vm_id, *, command, timeout_seconds=30)` - `exec_workspace(workspace_id, *, command, timeout_seconds=30)` @@ -116,6 +132,9 @@ Behavioral defaults: - `Pyro.push_workspace_sync(...)` imports later host-side directory or archive content into a started workspace. - `Pyro.exec_vm(...)` runs one command and auto-cleans that VM after the exec completes. - `Pyro.exec_workspace(...)` runs one command in the persistent workspace and leaves it alive. +- `Pyro.open_shell(...)` opens a persistent PTY shell attached to one started workspace. +- `Pyro.read_shell(...)` reads merged text output from that shell by cursor. +- `Pyro.write_shell(...)`, `Pyro.signal_shell(...)`, and `Pyro.close_shell(...)` operate on that persistent shell session. ## MCP Contract @@ -140,6 +159,11 @@ Persistent workspace tools: - `workspace_create` - `workspace_sync_push` - `workspace_exec` +- `shell_open` +- `shell_read` +- `shell_write` +- `shell_signal` +- `shell_close` - `workspace_status` - `workspace_logs` - `workspace_delete` @@ -154,6 +178,7 @@ Behavioral defaults: - `workspace_sync_push` imports later host-side directory or archive content into a started workspace, with an optional `dest` under `/workspace`. - `vm_exec` runs one command and auto-cleans that VM after the exec completes. - `workspace_exec` runs one command in a persistent `/workspace` and leaves the workspace alive. +- `shell_open`, `shell_read`, `shell_write`, `shell_signal`, and `shell_close` manage persistent PTY shells inside a started workspace. ## Versioning Rule diff --git a/docs/roadmap/task-workspace-ga.md b/docs/roadmap/task-workspace-ga.md index b39ed02..e878c10 100644 --- a/docs/roadmap/task-workspace-ga.md +++ b/docs/roadmap/task-workspace-ga.md @@ -2,11 +2,12 @@ This roadmap turns the agent-workspace vision into release-sized milestones. -Current baseline is `2.4.0`: +Current baseline is `2.5.0`: - workspace persistence exists and the public surface is now workspace-first - host crossing currently covers create-time seeding and later sync push -- no shell, export, diff, service, snapshot, reset, or secrets contract exists yet +- persistent PTY shell sessions exist alongside one-shot `workspace exec` +- no export, diff, service, snapshot, reset, or secrets contract exists yet Locked roadmap decisions: @@ -26,7 +27,7 @@ also expected to update: ## Milestones 1. [`2.4.0` Workspace Contract Pivot](task-workspace-ga/2.4.0-workspace-contract-pivot.md) - Done -2. [`2.5.0` PTY Shell Sessions](task-workspace-ga/2.5.0-pty-shell-sessions.md) +2. [`2.5.0` PTY Shell Sessions](task-workspace-ga/2.5.0-pty-shell-sessions.md) - Done 3. [`2.6.0` Structured Export And Baseline Diff](task-workspace-ga/2.6.0-structured-export-and-baseline-diff.md) 4. [`2.7.0` Service Lifecycle And Typed Readiness](task-workspace-ga/2.7.0-service-lifecycle-and-typed-readiness.md) 5. [`2.8.0` Named Snapshots And Reset](task-workspace-ga/2.8.0-named-snapshots-and-reset.md) diff --git a/docs/roadmap/task-workspace-ga/2.5.0-pty-shell-sessions.md b/docs/roadmap/task-workspace-ga/2.5.0-pty-shell-sessions.md index ccd3a30..91bcf2d 100644 --- a/docs/roadmap/task-workspace-ga/2.5.0-pty-shell-sessions.md +++ b/docs/roadmap/task-workspace-ga/2.5.0-pty-shell-sessions.md @@ -1,5 +1,7 @@ # `2.5.0` PTY Shell Sessions +Status: Done + ## Goal Add persistent interactive shells so an agent can inhabit a workspace instead diff --git a/examples/python_shell.py b/examples/python_shell.py new file mode 100644 index 0000000..85887f7 --- /dev/null +++ b/examples/python_shell.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import tempfile +import time +from pathlib import Path + +from pyro_mcp import Pyro + + +def main() -> None: + pyro = Pyro() + with tempfile.TemporaryDirectory(prefix="pyro-workspace-seed-") as seed_dir: + Path(seed_dir, "note.txt").write_text("hello from shell\n", encoding="utf-8") + created = pyro.create_workspace(environment="debian:12", seed_path=seed_dir) + workspace_id = str(created["workspace_id"]) + try: + opened = pyro.open_shell(workspace_id) + shell_id = str(opened["shell_id"]) + pyro.write_shell(workspace_id, shell_id, input="pwd") + deadline = time.time() + 5 + while True: + read = pyro.read_shell(workspace_id, shell_id, cursor=0) + output = str(read["output"]) + if "/workspace" in output or time.time() >= deadline: + print(output, end="") + break + time.sleep(0.1) + pyro.close_shell(workspace_id, shell_id) + finally: + pyro.delete_workspace(workspace_id) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 05d4be3..2ad5fb0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pyro-mcp" -version = "2.4.0" +version = "2.5.0" description = "Ephemeral Firecracker sandboxes with curated environments, persistent workspaces, and MCP tools." readme = "README.md" license = { file = "LICENSE" } diff --git a/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py b/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py index d914fc1..03ff492 100644 --- a/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py +++ b/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py @@ -1,14 +1,21 @@ #!/usr/bin/env python3 -"""Minimal guest-side exec and workspace import agent for pyro runtime bundles.""" +"""Guest-side exec, workspace import, and interactive shell agent.""" from __future__ import annotations +import codecs +import fcntl import io import json import os +import pty +import signal import socket +import struct import subprocess import tarfile +import termios +import threading import time from pathlib import Path, PurePosixPath from typing import Any @@ -16,6 +23,17 @@ from typing import Any PORT = 5005 BUFFER_SIZE = 65536 WORKSPACE_ROOT = PurePosixPath("/workspace") +SHELL_ROOT = Path("/run/pyro-shells") +SHELL_SIGNAL_MAP = { + "HUP": signal.SIGHUP, + "INT": signal.SIGINT, + "TERM": signal.SIGTERM, + "KILL": signal.SIGKILL, +} +SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP) + +_SHELLS: dict[str, "GuestShellSession"] = {} +_SHELLS_LOCK = threading.Lock() def _read_request(conn: socket.socket) -> dict[str, Any]: @@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]: suffix = normalized.relative_to(WORKSPACE_ROOT) host_path = Path("/workspace") if str(suffix) not in {"", "."}: - host_path = host_path / str(suffix) + host_path = host_path.joinpath(*suffix.parts) return normalized, host_path +def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]: + normalized, host_path = _normalize_destination(cwd) + return str(normalized), host_path + + def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None: target = link_target.strip() if target == "": @@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: - _, destination_root = _normalize_destination(destination) + normalized_destination, destination_root = _normalize_destination(destination) destination_root.mkdir(parents=True, exist_ok=True) bytes_written = 0 entry_count = 0 with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive: for member in archive.getmembers(): member_name = _normalize_member_name(member.name) - target_path = destination_root / str(member_name) + target_path = destination_root.joinpath(*member_name.parts) entry_count += 1 _ensure_no_symlink_parents(destination_root, target_path, member.name) if member.isdir(): - if target_path.exists() and not target_path.is_dir(): + if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()): raise RuntimeError(f"directory conflicts with existing path: {member.name}") target_path.mkdir(parents=True, exist_ok=True) continue @@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: ) raise RuntimeError(f"unsupported archive member type: {member.name}") return { - "destination": destination, + "destination": str(normalized_destination), "entry_count": entry_count, "bytes_written": bytes_written, } @@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]: } +def _set_pty_size(fd: int, rows: int, cols: int) -> None: + winsize = struct.pack("HHHH", rows, cols, 0, 0) + fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize) + + +class GuestShellSession: + """In-guest PTY-backed interactive shell session.""" + + def __init__( + self, + *, + shell_id: str, + cwd: Path, + cwd_text: str, + cols: int, + rows: int, + ) -> None: + self.shell_id = shell_id + self.cwd = cwd_text + self.cols = cols + self.rows = rows + self.started_at = time.time() + self.ended_at: float | None = None + self.exit_code: int | None = None + self.state = "running" + self._lock = threading.RLock() + self._output = "" + self._decoder = codecs.getincrementaldecoder("utf-8")("replace") + self._metadata_path = SHELL_ROOT / f"{shell_id}.json" + self._log_path = SHELL_ROOT / f"{shell_id}.log" + self._master_fd: int | None = None + + master_fd, slave_fd = pty.openpty() + try: + _set_pty_size(slave_fd, rows, cols) + env = os.environ.copy() + env.update( + { + "TERM": env.get("TERM", "xterm-256color"), + "PS1": "pyro$ ", + "PROMPT_COMMAND": "", + } + ) + process = subprocess.Popen( # noqa: S603 + ["/bin/bash", "--noprofile", "--norc", "-i"], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + cwd=str(cwd), + env=env, + text=False, + close_fds=True, + preexec_fn=os.setsid, + ) + except Exception: + os.close(master_fd) + raise + finally: + os.close(slave_fd) + + self._process = process + self._master_fd = master_fd + self._write_metadata() + self._reader = threading.Thread(target=self._reader_loop, daemon=True) + self._waiter = threading.Thread(target=self._waiter_loop, daemon=True) + self._reader.start() + self._waiter.start() + + def summary(self) -> dict[str, Any]: + with self._lock: + return { + "shell_id": self.shell_id, + "cwd": self.cwd, + "cols": self.cols, + "rows": self.rows, + "state": self.state, + "started_at": self.started_at, + "ended_at": self.ended_at, + "exit_code": self.exit_code, + } + + def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]: + with self._lock: + clamped_cursor = min(max(cursor, 0), len(self._output)) + output = self._output[clamped_cursor : clamped_cursor + max_chars] + next_cursor = clamped_cursor + len(output) + payload = self.summary() + payload.update( + { + "cursor": clamped_cursor, + "next_cursor": next_cursor, + "output": output, + "truncated": next_cursor < len(self._output), + } + ) + return payload + + def write(self, text: str, *, append_newline: bool) -> dict[str, Any]: + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running": + raise RuntimeError(f"shell {self.shell_id} is not running") + master_fd = self._master_fd + if master_fd is None: + raise RuntimeError(f"shell {self.shell_id} transport is unavailable") + payload = text + ("\n" if append_newline else "") + try: + os.write(master_fd, payload.encode("utf-8")) + except OSError as exc: + self._refresh_process_state() + raise RuntimeError(f"failed to write shell input: {exc}") from exc + response = self.summary() + response.update({"input_length": len(text), "append_newline": append_newline}) + return response + + def send_signal(self, signal_name: str) -> dict[str, Any]: + signum = SHELL_SIGNAL_MAP.get(signal_name) + if signum is None: + raise ValueError(f"unsupported shell signal: {signal_name}") + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running": + raise RuntimeError(f"shell {self.shell_id} is not running") + pid = self._process.pid + try: + os.killpg(pid, signum) + except ProcessLookupError as exc: + self._refresh_process_state() + raise RuntimeError(f"shell {self.shell_id} is not running") from exc + response = self.summary() + response["signal"] = signal_name + return response + + def close(self) -> dict[str, Any]: + if self._process.poll() is None: + try: + os.killpg(self._process.pid, signal.SIGHUP) + except ProcessLookupError: + pass + try: + self._process.wait(timeout=5) + except subprocess.TimeoutExpired: + try: + os.killpg(self._process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + self._process.wait(timeout=5) + else: + self._refresh_process_state() + self._close_master_fd() + if self._reader is not None: + self._reader.join(timeout=1) + if self._waiter is not None: + self._waiter.join(timeout=1) + response = self.summary() + response["closed"] = True + return response + + def _reader_loop(self) -> None: + master_fd = self._master_fd + if master_fd is None: + return + while True: + try: + chunk = os.read(master_fd, BUFFER_SIZE) + except OSError: + break + if chunk == b"": + break + decoded = self._decoder.decode(chunk) + if decoded == "": + continue + with self._lock: + self._output += decoded + with self._log_path.open("a", encoding="utf-8") as handle: + handle.write(decoded) + decoded = self._decoder.decode(b"", final=True) + if decoded != "": + with self._lock: + self._output += decoded + with self._log_path.open("a", encoding="utf-8") as handle: + handle.write(decoded) + + def _waiter_loop(self) -> None: + exit_code = self._process.wait() + with self._lock: + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + self._write_metadata() + + def _refresh_process_state(self) -> None: + exit_code = self._process.poll() + if exit_code is None: + return + with self._lock: + if self.state == "running": + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + self._write_metadata() + + def _write_metadata(self) -> None: + self._metadata_path.parent.mkdir(parents=True, exist_ok=True) + self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8") + + def _close_master_fd(self) -> None: + with self._lock: + master_fd = self._master_fd + self._master_fd = None + if master_fd is None: + return + try: + os.close(master_fd) + except OSError: + pass + + +def _create_shell( + *, + shell_id: str, + cwd_text: str, + cols: int, + rows: int, +) -> GuestShellSession: + _, cwd_path = _normalize_shell_cwd(cwd_text) + with _SHELLS_LOCK: + if shell_id in _SHELLS: + raise RuntimeError(f"shell {shell_id!r} already exists") + session = GuestShellSession( + shell_id=shell_id, + cwd=cwd_path, + cwd_text=cwd_text, + cols=cols, + rows=rows, + ) + _SHELLS[shell_id] = session + return session + + +def _get_shell(shell_id: str) -> GuestShellSession: + with _SHELLS_LOCK: + try: + return _SHELLS[shell_id] + except KeyError as exc: + raise RuntimeError(f"shell {shell_id!r} does not exist") from exc + + +def _remove_shell(shell_id: str) -> GuestShellSession: + with _SHELLS_LOCK: + try: + return _SHELLS.pop(shell_id) + except KeyError as exc: + raise RuntimeError(f"shell {shell_id!r} does not exist") from exc + + +def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]: + action = str(request.get("action", "exec")) + if action == "extract_archive": + archive_size = int(request.get("archive_size", 0)) + if archive_size < 0: + raise RuntimeError("archive_size must not be negative") + destination = str(request.get("destination", "/workspace")) + payload = _read_exact(conn, archive_size) + return _extract_archive(payload, destination) + if action == "open_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace"))) + session = _create_shell( + shell_id=shell_id, + cwd_text=cwd_text, + cols=int(request.get("cols", 120)), + rows=int(request.get("rows", 30)), + ) + return session.summary() + if action == "read_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _get_shell(shell_id).read( + cursor=int(request.get("cursor", 0)), + max_chars=int(request.get("max_chars", 65536)), + ) + if action == "write_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _get_shell(shell_id).write( + str(request.get("input", "")), + append_newline=bool(request.get("append_newline", True)), + ) + if action == "signal_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + signal_name = str(request.get("signal", "INT")).upper() + if signal_name not in SHELL_SIGNAL_NAMES: + raise RuntimeError( + f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}" + ) + return _get_shell(shell_id).send_signal(signal_name) + if action == "close_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _remove_shell(shell_id).close() + command = str(request.get("command", "")) + timeout_seconds = int(request.get("timeout_seconds", 30)) + return _run_command(command, timeout_seconds) + + def main() -> None: + SHELL_ROOT.mkdir(parents=True, exist_ok=True) family = getattr(socket, "AF_VSOCK", None) if family is None: raise SystemExit("AF_VSOCK is unavailable") @@ -192,19 +531,11 @@ def main() -> None: while True: conn, _ = server.accept() with conn: - request = _read_request(conn) - action = str(request.get("action", "exec")) - if action == "extract_archive": - archive_size = int(request.get("archive_size", 0)) - if archive_size < 0: - raise RuntimeError("archive_size must not be negative") - destination = str(request.get("destination", "/workspace")) - payload = _read_exact(conn, archive_size) - response = _extract_archive(payload, destination) - else: - command = str(request.get("command", "")) - timeout_seconds = int(request.get("timeout_seconds", 30)) - response = _run_command(command, timeout_seconds) + try: + request = _read_request(conn) + response = _dispatch(request, conn) + except Exception as exc: # noqa: BLE001 + response = {"error": str(exc)} conn.sendall((json.dumps(response) + "\n").encode("utf-8")) diff --git a/src/pyro_mcp/api.py b/src/pyro_mcp/api.py index ac8c0bf..568d089 100644 --- a/src/pyro_mcp/api.py +++ b/src/pyro_mcp/api.py @@ -130,6 +130,67 @@ class Pyro: def logs_workspace(self, workspace_id: str) -> dict[str, Any]: return self._manager.logs_workspace(workspace_id) + def open_shell( + self, + workspace_id: str, + *, + cwd: str = "/workspace", + cols: int = 120, + rows: int = 30, + ) -> dict[str, Any]: + return self._manager.open_shell( + workspace_id, + cwd=cwd, + cols=cols, + rows=rows, + ) + + def read_shell( + self, + workspace_id: str, + shell_id: str, + *, + cursor: int = 0, + max_chars: int = 65536, + ) -> dict[str, Any]: + return self._manager.read_shell( + workspace_id, + shell_id, + cursor=cursor, + max_chars=max_chars, + ) + + def write_shell( + self, + workspace_id: str, + shell_id: str, + *, + input: str, + append_newline: bool = True, + ) -> dict[str, Any]: + return self._manager.write_shell( + workspace_id, + shell_id, + input_text=input, + append_newline=append_newline, + ) + + def signal_shell( + self, + workspace_id: str, + shell_id: str, + *, + signal_name: str = "INT", + ) -> dict[str, Any]: + return self._manager.signal_shell( + workspace_id, + shell_id, + signal_name=signal_name, + ) + + def close_shell(self, workspace_id: str, shell_id: str) -> dict[str, Any]: + return self._manager.close_shell(workspace_id, shell_id) + def delete_workspace(self, workspace_id: str) -> dict[str, Any]: return self._manager.delete_workspace(workspace_id) @@ -309,6 +370,64 @@ class Pyro: """Return persisted command history for one workspace.""" return self.logs_workspace(workspace_id) + @server.tool() + async def shell_open( + workspace_id: str, + cwd: str = "/workspace", + cols: int = 120, + rows: int = 30, + ) -> dict[str, Any]: + """Open a persistent interactive shell inside one workspace.""" + return self.open_shell(workspace_id, cwd=cwd, cols=cols, rows=rows) + + @server.tool() + async def shell_read( + workspace_id: str, + shell_id: str, + cursor: int = 0, + max_chars: int = 65536, + ) -> dict[str, Any]: + """Read merged PTY output from a workspace shell.""" + return self.read_shell( + workspace_id, + shell_id, + cursor=cursor, + max_chars=max_chars, + ) + + @server.tool() + async def shell_write( + workspace_id: str, + shell_id: str, + input: str, + append_newline: bool = True, + ) -> dict[str, Any]: + """Write text input to a persistent workspace shell.""" + return self.write_shell( + workspace_id, + shell_id, + input=input, + append_newline=append_newline, + ) + + @server.tool() + async def shell_signal( + workspace_id: str, + shell_id: str, + signal_name: str = "INT", + ) -> dict[str, Any]: + """Send a signal to the shell process group.""" + return self.signal_shell( + workspace_id, + shell_id, + signal_name=signal_name, + ) + + @server.tool() + async def shell_close(workspace_id: str, shell_id: str) -> dict[str, Any]: + """Close a persistent workspace shell.""" + return self.close_shell(workspace_id, shell_id) + @server.tool() async def workspace_delete(workspace_id: str) -> dict[str, Any]: """Delete a persistent workspace and its backing sandbox.""" diff --git a/src/pyro_mcp/cli.py b/src/pyro_mcp/cli.py index 3fd4eaf..38a592b 100644 --- a/src/pyro_mcp/cli.py +++ b/src/pyro_mcp/cli.py @@ -19,6 +19,7 @@ from pyro_mcp.vm_manager import ( DEFAULT_MEM_MIB, DEFAULT_VCPU_COUNT, WORKSPACE_GUEST_PATH, + WORKSPACE_SHELL_SIGNAL_NAMES, ) @@ -237,6 +238,37 @@ def _print_workspace_logs_human(payload: dict[str, Any]) -> None: print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr) +def _print_workspace_shell_summary_human(payload: dict[str, Any], *, prefix: str) -> None: + print( + f"[{prefix}] " + f"workspace_id={str(payload.get('workspace_id', 'unknown'))} " + f"shell_id={str(payload.get('shell_id', 'unknown'))} " + f"state={str(payload.get('state', 'unknown'))} " + f"cwd={str(payload.get('cwd', WORKSPACE_GUEST_PATH))} " + f"cols={int(payload.get('cols', 0))} " + f"rows={int(payload.get('rows', 0))} " + f"execution_mode={str(payload.get('execution_mode', 'unknown'))}", + file=sys.stderr, + flush=True, + ) + + +def _print_workspace_shell_read_human(payload: dict[str, Any]) -> None: + _write_stream(str(payload.get("output", "")), stream=sys.stdout) + print( + "[workspace-shell-read] " + f"workspace_id={str(payload.get('workspace_id', 'unknown'))} " + f"shell_id={str(payload.get('shell_id', 'unknown'))} " + f"state={str(payload.get('state', 'unknown'))} " + f"cursor={int(payload.get('cursor', 0))} " + f"next_cursor={int(payload.get('next_cursor', 0))} " + f"truncated={bool(payload.get('truncated', False))} " + f"execution_mode={str(payload.get('execution_mode', 'unknown'))}", + file=sys.stderr, + flush=True, + ) + + class _HelpFormatter( argparse.RawDescriptionHelpFormatter, argparse.ArgumentDefaultsHelpFormatter, @@ -269,6 +301,7 @@ def _build_parser() -> argparse.ArgumentParser: Need repeated commands in one workspace after that? pyro workspace create debian:12 --seed-path ./repo pyro workspace sync push WORKSPACE_ID ./changes + pyro workspace shell open WORKSPACE_ID Use `pyro mcp serve` only after the CLI validation path works. """ @@ -476,6 +509,7 @@ def _build_parser() -> argparse.ArgumentParser: pyro workspace create debian:12 --seed-path ./repo pyro workspace sync push WORKSPACE_ID ./repo --dest src pyro workspace exec WORKSPACE_ID -- sh -lc 'printf "hello\\n" > note.txt' + pyro workspace shell open WORKSPACE_ID pyro workspace logs WORKSPACE_ID """ ), @@ -633,6 +667,191 @@ def _build_parser() -> argparse.ArgumentParser: action="store_true", help="Print structured JSON instead of human-readable output.", ) + workspace_shell_parser = workspace_subparsers.add_parser( + "shell", + help="Open and manage persistent interactive shells.", + description=( + "Open one or more persistent interactive PTY shell sessions inside a started " + "workspace." + ), + epilog=dedent( + """ + Examples: + pyro workspace shell open WORKSPACE_ID + pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd' + pyro workspace shell read WORKSPACE_ID SHELL_ID + pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT + pyro workspace shell close WORKSPACE_ID SHELL_ID + + Use `workspace exec` for one-shot commands. Use `workspace shell` when you need + an interactive process that keeps its state between calls. + """ + ), + formatter_class=_HelpFormatter, + ) + workspace_shell_subparsers = workspace_shell_parser.add_subparsers( + dest="workspace_shell_command", + required=True, + metavar="SHELL", + ) + workspace_shell_open_parser = workspace_shell_subparsers.add_parser( + "open", + help="Open a persistent interactive shell.", + description="Open a new PTY shell inside a started workspace.", + epilog="Example:\n pyro workspace shell open WORKSPACE_ID --cwd src", + formatter_class=_HelpFormatter, + ) + workspace_shell_open_parser.add_argument( + "workspace_id", + metavar="WORKSPACE_ID", + help="Persistent workspace identifier.", + ) + workspace_shell_open_parser.add_argument( + "--cwd", + default=WORKSPACE_GUEST_PATH, + help="Shell working directory. Relative values resolve inside `/workspace`.", + ) + workspace_shell_open_parser.add_argument( + "--cols", + type=int, + default=120, + help="Shell terminal width in columns.", + ) + workspace_shell_open_parser.add_argument( + "--rows", + type=int, + default=30, + help="Shell terminal height in rows.", + ) + workspace_shell_open_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + workspace_shell_read_parser = workspace_shell_subparsers.add_parser( + "read", + help="Read merged PTY output from a shell.", + description="Read merged text output from a persistent workspace shell.", + epilog=dedent( + """ + Example: + pyro workspace shell read WORKSPACE_ID SHELL_ID --cursor 0 + + Shell output is written to stdout. The read summary is written to stderr. + Use --json for a deterministic structured response. + """ + ), + formatter_class=_HelpFormatter, + ) + workspace_shell_read_parser.add_argument( + "workspace_id", + metavar="WORKSPACE_ID", + help="Persistent workspace identifier.", + ) + workspace_shell_read_parser.add_argument( + "shell_id", + metavar="SHELL_ID", + help="Persistent shell identifier returned by `workspace shell open`.", + ) + workspace_shell_read_parser.add_argument( + "--cursor", + type=int, + default=0, + help="Character offset into the merged shell output buffer.", + ) + workspace_shell_read_parser.add_argument( + "--max-chars", + type=int, + default=65536, + help="Maximum number of characters to return from the current cursor position.", + ) + workspace_shell_read_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + workspace_shell_write_parser = workspace_shell_subparsers.add_parser( + "write", + help="Write text input into a shell.", + description="Write text input into a persistent workspace shell.", + epilog="Example:\n pyro workspace shell write WORKSPACE_ID SHELL_ID --input 'pwd'", + formatter_class=_HelpFormatter, + ) + workspace_shell_write_parser.add_argument( + "workspace_id", + metavar="WORKSPACE_ID", + help="Persistent workspace identifier.", + ) + workspace_shell_write_parser.add_argument( + "shell_id", + metavar="SHELL_ID", + help="Persistent shell identifier returned by `workspace shell open`.", + ) + workspace_shell_write_parser.add_argument( + "--input", + required=True, + help="Text to send to the shell.", + ) + workspace_shell_write_parser.add_argument( + "--no-newline", + action="store_true", + help="Do not append a trailing newline after the provided input.", + ) + workspace_shell_write_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + workspace_shell_signal_parser = workspace_shell_subparsers.add_parser( + "signal", + help="Send a signal to a shell process group.", + description="Send a control signal to a persistent workspace shell.", + epilog="Example:\n pyro workspace shell signal WORKSPACE_ID SHELL_ID --signal INT", + formatter_class=_HelpFormatter, + ) + workspace_shell_signal_parser.add_argument( + "workspace_id", + metavar="WORKSPACE_ID", + help="Persistent workspace identifier.", + ) + workspace_shell_signal_parser.add_argument( + "shell_id", + metavar="SHELL_ID", + help="Persistent shell identifier returned by `workspace shell open`.", + ) + workspace_shell_signal_parser.add_argument( + "--signal", + default="INT", + choices=WORKSPACE_SHELL_SIGNAL_NAMES, + help="Signal name to send to the shell process group.", + ) + workspace_shell_signal_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) + workspace_shell_close_parser = workspace_shell_subparsers.add_parser( + "close", + help="Close a persistent shell.", + description="Close a persistent workspace shell and release its PTY state.", + epilog="Example:\n pyro workspace shell close WORKSPACE_ID SHELL_ID", + formatter_class=_HelpFormatter, + ) + workspace_shell_close_parser.add_argument( + "workspace_id", + metavar="WORKSPACE_ID", + help="Persistent workspace identifier.", + ) + workspace_shell_close_parser.add_argument( + "shell_id", + metavar="SHELL_ID", + help="Persistent shell identifier returned by `workspace shell open`.", + ) + workspace_shell_close_parser.add_argument( + "--json", + action="store_true", + help="Print structured JSON instead of human-readable output.", + ) workspace_status_parser = workspace_subparsers.add_parser( "status", help="Inspect one workspace.", @@ -929,6 +1148,99 @@ def main() -> None: raise SystemExit(1) from exc _print_workspace_sync_human(payload) return + if args.workspace_command == "shell": + if args.workspace_shell_command == "open": + try: + payload = pyro.open_shell( + args.workspace_id, + cwd=args.cwd, + cols=args.cols, + rows=args.rows, + ) + except Exception as exc: # noqa: BLE001 + if bool(args.json): + _print_json({"ok": False, "error": str(exc)}) + else: + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + if bool(args.json): + _print_json(payload) + else: + _print_workspace_shell_summary_human(payload, prefix="workspace-shell-open") + return + if args.workspace_shell_command == "read": + try: + payload = pyro.read_shell( + args.workspace_id, + args.shell_id, + cursor=args.cursor, + max_chars=args.max_chars, + ) + except Exception as exc: # noqa: BLE001 + if bool(args.json): + _print_json({"ok": False, "error": str(exc)}) + else: + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + if bool(args.json): + _print_json(payload) + else: + _print_workspace_shell_read_human(payload) + return + if args.workspace_shell_command == "write": + try: + payload = pyro.write_shell( + args.workspace_id, + args.shell_id, + input=args.input, + append_newline=not bool(args.no_newline), + ) + except Exception as exc: # noqa: BLE001 + if bool(args.json): + _print_json({"ok": False, "error": str(exc)}) + else: + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + if bool(args.json): + _print_json(payload) + else: + _print_workspace_shell_summary_human(payload, prefix="workspace-shell-write") + return + if args.workspace_shell_command == "signal": + try: + payload = pyro.signal_shell( + args.workspace_id, + args.shell_id, + signal_name=args.signal, + ) + except Exception as exc: # noqa: BLE001 + if bool(args.json): + _print_json({"ok": False, "error": str(exc)}) + else: + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + if bool(args.json): + _print_json(payload) + else: + _print_workspace_shell_summary_human( + payload, + prefix="workspace-shell-signal", + ) + return + if args.workspace_shell_command == "close": + try: + payload = pyro.close_shell(args.workspace_id, args.shell_id) + except Exception as exc: # noqa: BLE001 + if bool(args.json): + _print_json({"ok": False, "error": str(exc)}) + else: + print(f"[error] {exc}", file=sys.stderr, flush=True) + raise SystemExit(1) from exc + if bool(args.json): + _print_json(payload) + else: + _print_workspace_shell_summary_human(payload, prefix="workspace-shell-close") + return if args.workspace_command == "status": payload = pyro.status_workspace(args.workspace_id) if bool(args.json): diff --git a/src/pyro_mcp/contract.py b/src/pyro_mcp/contract.py index 61335f0..ae2a05c 100644 --- a/src/pyro_mcp/contract.py +++ b/src/pyro_mcp/contract.py @@ -5,7 +5,16 @@ from __future__ import annotations PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "workspace") PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",) PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune") -PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status", "sync") +PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = ( + "create", + "delete", + "exec", + "logs", + "shell", + "status", + "sync", +) +PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS = ("close", "open", "read", "signal", "write") PUBLIC_CLI_WORKSPACE_SYNC_SUBCOMMANDS = ("push",) PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = ( "--vcpu-count", @@ -16,6 +25,11 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = ( "--seed-path", "--json", ) +PUBLIC_CLI_WORKSPACE_SHELL_OPEN_FLAGS = ("--cwd", "--cols", "--rows", "--json") +PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS = ("--cursor", "--max-chars", "--json") +PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS = ("--input", "--no-newline", "--json") +PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS = ("--signal", "--json") +PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS = ("--json",) PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS = ("--dest", "--json") PUBLIC_CLI_RUN_FLAGS = ( "--vcpu-count", @@ -28,6 +42,7 @@ PUBLIC_CLI_RUN_FLAGS = ( ) PUBLIC_SDK_METHODS = ( + "close_shell", "create_server", "create_vm", "create_workspace", @@ -39,18 +54,27 @@ PUBLIC_SDK_METHODS = ( "list_environments", "logs_workspace", "network_info_vm", + "open_shell", "prune_environments", "pull_environment", "push_workspace_sync", + "read_shell", "reap_expired", "run_in_vm", + "signal_shell", "start_vm", "status_vm", "status_workspace", "stop_vm", + "write_shell", ) PUBLIC_MCP_TOOLS = ( + "shell_close", + "shell_open", + "shell_read", + "shell_signal", + "shell_write", "vm_create", "vm_delete", "vm_exec", diff --git a/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py b/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py index d914fc1..03ff492 100755 --- a/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py +++ b/src/pyro_mcp/runtime_bundle/linux-x86_64/guest/pyro_guest_agent.py @@ -1,14 +1,21 @@ #!/usr/bin/env python3 -"""Minimal guest-side exec and workspace import agent for pyro runtime bundles.""" +"""Guest-side exec, workspace import, and interactive shell agent.""" from __future__ import annotations +import codecs +import fcntl import io import json import os +import pty +import signal import socket +import struct import subprocess import tarfile +import termios +import threading import time from pathlib import Path, PurePosixPath from typing import Any @@ -16,6 +23,17 @@ from typing import Any PORT = 5005 BUFFER_SIZE = 65536 WORKSPACE_ROOT = PurePosixPath("/workspace") +SHELL_ROOT = Path("/run/pyro-shells") +SHELL_SIGNAL_MAP = { + "HUP": signal.SIGHUP, + "INT": signal.SIGINT, + "TERM": signal.SIGTERM, + "KILL": signal.SIGKILL, +} +SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP) + +_SHELLS: dict[str, "GuestShellSession"] = {} +_SHELLS_LOCK = threading.Lock() def _read_request(conn: socket.socket) -> dict[str, Any]: @@ -77,10 +95,15 @@ def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]: suffix = normalized.relative_to(WORKSPACE_ROOT) host_path = Path("/workspace") if str(suffix) not in {"", "."}: - host_path = host_path / str(suffix) + host_path = host_path.joinpath(*suffix.parts) return normalized, host_path +def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]: + normalized, host_path = _normalize_destination(cwd) + return str(normalized), host_path + + def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None: target = link_target.strip() if target == "": @@ -106,18 +129,18 @@ def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: - _, destination_root = _normalize_destination(destination) + normalized_destination, destination_root = _normalize_destination(destination) destination_root.mkdir(parents=True, exist_ok=True) bytes_written = 0 entry_count = 0 with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive: for member in archive.getmembers(): member_name = _normalize_member_name(member.name) - target_path = destination_root / str(member_name) + target_path = destination_root.joinpath(*member_name.parts) entry_count += 1 _ensure_no_symlink_parents(destination_root, target_path, member.name) if member.isdir(): - if target_path.exists() and not target_path.is_dir(): + if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()): raise RuntimeError(f"directory conflicts with existing path: {member.name}") target_path.mkdir(parents=True, exist_ok=True) continue @@ -151,7 +174,7 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]: ) raise RuntimeError(f"unsupported archive member type: {member.name}") return { - "destination": destination, + "destination": str(normalized_destination), "entry_count": entry_count, "bytes_written": bytes_written, } @@ -182,7 +205,323 @@ def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]: } +def _set_pty_size(fd: int, rows: int, cols: int) -> None: + winsize = struct.pack("HHHH", rows, cols, 0, 0) + fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize) + + +class GuestShellSession: + """In-guest PTY-backed interactive shell session.""" + + def __init__( + self, + *, + shell_id: str, + cwd: Path, + cwd_text: str, + cols: int, + rows: int, + ) -> None: + self.shell_id = shell_id + self.cwd = cwd_text + self.cols = cols + self.rows = rows + self.started_at = time.time() + self.ended_at: float | None = None + self.exit_code: int | None = None + self.state = "running" + self._lock = threading.RLock() + self._output = "" + self._decoder = codecs.getincrementaldecoder("utf-8")("replace") + self._metadata_path = SHELL_ROOT / f"{shell_id}.json" + self._log_path = SHELL_ROOT / f"{shell_id}.log" + self._master_fd: int | None = None + + master_fd, slave_fd = pty.openpty() + try: + _set_pty_size(slave_fd, rows, cols) + env = os.environ.copy() + env.update( + { + "TERM": env.get("TERM", "xterm-256color"), + "PS1": "pyro$ ", + "PROMPT_COMMAND": "", + } + ) + process = subprocess.Popen( # noqa: S603 + ["/bin/bash", "--noprofile", "--norc", "-i"], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + cwd=str(cwd), + env=env, + text=False, + close_fds=True, + preexec_fn=os.setsid, + ) + except Exception: + os.close(master_fd) + raise + finally: + os.close(slave_fd) + + self._process = process + self._master_fd = master_fd + self._write_metadata() + self._reader = threading.Thread(target=self._reader_loop, daemon=True) + self._waiter = threading.Thread(target=self._waiter_loop, daemon=True) + self._reader.start() + self._waiter.start() + + def summary(self) -> dict[str, Any]: + with self._lock: + return { + "shell_id": self.shell_id, + "cwd": self.cwd, + "cols": self.cols, + "rows": self.rows, + "state": self.state, + "started_at": self.started_at, + "ended_at": self.ended_at, + "exit_code": self.exit_code, + } + + def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]: + with self._lock: + clamped_cursor = min(max(cursor, 0), len(self._output)) + output = self._output[clamped_cursor : clamped_cursor + max_chars] + next_cursor = clamped_cursor + len(output) + payload = self.summary() + payload.update( + { + "cursor": clamped_cursor, + "next_cursor": next_cursor, + "output": output, + "truncated": next_cursor < len(self._output), + } + ) + return payload + + def write(self, text: str, *, append_newline: bool) -> dict[str, Any]: + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running": + raise RuntimeError(f"shell {self.shell_id} is not running") + master_fd = self._master_fd + if master_fd is None: + raise RuntimeError(f"shell {self.shell_id} transport is unavailable") + payload = text + ("\n" if append_newline else "") + try: + os.write(master_fd, payload.encode("utf-8")) + except OSError as exc: + self._refresh_process_state() + raise RuntimeError(f"failed to write shell input: {exc}") from exc + response = self.summary() + response.update({"input_length": len(text), "append_newline": append_newline}) + return response + + def send_signal(self, signal_name: str) -> dict[str, Any]: + signum = SHELL_SIGNAL_MAP.get(signal_name) + if signum is None: + raise ValueError(f"unsupported shell signal: {signal_name}") + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running": + raise RuntimeError(f"shell {self.shell_id} is not running") + pid = self._process.pid + try: + os.killpg(pid, signum) + except ProcessLookupError as exc: + self._refresh_process_state() + raise RuntimeError(f"shell {self.shell_id} is not running") from exc + response = self.summary() + response["signal"] = signal_name + return response + + def close(self) -> dict[str, Any]: + if self._process.poll() is None: + try: + os.killpg(self._process.pid, signal.SIGHUP) + except ProcessLookupError: + pass + try: + self._process.wait(timeout=5) + except subprocess.TimeoutExpired: + try: + os.killpg(self._process.pid, signal.SIGKILL) + except ProcessLookupError: + pass + self._process.wait(timeout=5) + else: + self._refresh_process_state() + self._close_master_fd() + if self._reader is not None: + self._reader.join(timeout=1) + if self._waiter is not None: + self._waiter.join(timeout=1) + response = self.summary() + response["closed"] = True + return response + + def _reader_loop(self) -> None: + master_fd = self._master_fd + if master_fd is None: + return + while True: + try: + chunk = os.read(master_fd, BUFFER_SIZE) + except OSError: + break + if chunk == b"": + break + decoded = self._decoder.decode(chunk) + if decoded == "": + continue + with self._lock: + self._output += decoded + with self._log_path.open("a", encoding="utf-8") as handle: + handle.write(decoded) + decoded = self._decoder.decode(b"", final=True) + if decoded != "": + with self._lock: + self._output += decoded + with self._log_path.open("a", encoding="utf-8") as handle: + handle.write(decoded) + + def _waiter_loop(self) -> None: + exit_code = self._process.wait() + with self._lock: + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + self._write_metadata() + + def _refresh_process_state(self) -> None: + exit_code = self._process.poll() + if exit_code is None: + return + with self._lock: + if self.state == "running": + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + self._write_metadata() + + def _write_metadata(self) -> None: + self._metadata_path.parent.mkdir(parents=True, exist_ok=True) + self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8") + + def _close_master_fd(self) -> None: + with self._lock: + master_fd = self._master_fd + self._master_fd = None + if master_fd is None: + return + try: + os.close(master_fd) + except OSError: + pass + + +def _create_shell( + *, + shell_id: str, + cwd_text: str, + cols: int, + rows: int, +) -> GuestShellSession: + _, cwd_path = _normalize_shell_cwd(cwd_text) + with _SHELLS_LOCK: + if shell_id in _SHELLS: + raise RuntimeError(f"shell {shell_id!r} already exists") + session = GuestShellSession( + shell_id=shell_id, + cwd=cwd_path, + cwd_text=cwd_text, + cols=cols, + rows=rows, + ) + _SHELLS[shell_id] = session + return session + + +def _get_shell(shell_id: str) -> GuestShellSession: + with _SHELLS_LOCK: + try: + return _SHELLS[shell_id] + except KeyError as exc: + raise RuntimeError(f"shell {shell_id!r} does not exist") from exc + + +def _remove_shell(shell_id: str) -> GuestShellSession: + with _SHELLS_LOCK: + try: + return _SHELLS.pop(shell_id) + except KeyError as exc: + raise RuntimeError(f"shell {shell_id!r} does not exist") from exc + + +def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]: + action = str(request.get("action", "exec")) + if action == "extract_archive": + archive_size = int(request.get("archive_size", 0)) + if archive_size < 0: + raise RuntimeError("archive_size must not be negative") + destination = str(request.get("destination", "/workspace")) + payload = _read_exact(conn, archive_size) + return _extract_archive(payload, destination) + if action == "open_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace"))) + session = _create_shell( + shell_id=shell_id, + cwd_text=cwd_text, + cols=int(request.get("cols", 120)), + rows=int(request.get("rows", 30)), + ) + return session.summary() + if action == "read_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _get_shell(shell_id).read( + cursor=int(request.get("cursor", 0)), + max_chars=int(request.get("max_chars", 65536)), + ) + if action == "write_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _get_shell(shell_id).write( + str(request.get("input", "")), + append_newline=bool(request.get("append_newline", True)), + ) + if action == "signal_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + signal_name = str(request.get("signal", "INT")).upper() + if signal_name not in SHELL_SIGNAL_NAMES: + raise RuntimeError( + f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}" + ) + return _get_shell(shell_id).send_signal(signal_name) + if action == "close_shell": + shell_id = str(request.get("shell_id", "")).strip() + if shell_id == "": + raise RuntimeError("shell_id is required") + return _remove_shell(shell_id).close() + command = str(request.get("command", "")) + timeout_seconds = int(request.get("timeout_seconds", 30)) + return _run_command(command, timeout_seconds) + + def main() -> None: + SHELL_ROOT.mkdir(parents=True, exist_ok=True) family = getattr(socket, "AF_VSOCK", None) if family is None: raise SystemExit("AF_VSOCK is unavailable") @@ -192,19 +531,11 @@ def main() -> None: while True: conn, _ = server.accept() with conn: - request = _read_request(conn) - action = str(request.get("action", "exec")) - if action == "extract_archive": - archive_size = int(request.get("archive_size", 0)) - if archive_size < 0: - raise RuntimeError("archive_size must not be negative") - destination = str(request.get("destination", "/workspace")) - payload = _read_exact(conn, archive_size) - response = _extract_archive(payload, destination) - else: - command = str(request.get("command", "")) - timeout_seconds = int(request.get("timeout_seconds", 30)) - response = _run_command(command, timeout_seconds) + try: + request = _read_request(conn) + response = _dispatch(request, conn) + except Exception as exc: # noqa: BLE001 + response = {"error": str(exc)} conn.sendall((json.dumps(response) + "\n").encode("utf-8")) diff --git a/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json b/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json index 948953a..52d1a5d 100644 --- a/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json +++ b/src/pyro_mcp/runtime_bundle/linux-x86_64/manifest.json @@ -25,7 +25,7 @@ "guest": { "agent": { "path": "guest/pyro_guest_agent.py", - "sha256": "3b684b1b07745fc7788e560b0bdd0c0535c31c395ff87474ae9e114f4489726d" + "sha256": "07adf6269551447dbea8c236f91499ea1479212a3f084c5402a656f5f5cc5892" } }, "platform": "linux-x86_64", diff --git a/src/pyro_mcp/vm_environments.py b/src/pyro_mcp/vm_environments.py index 5a0bf99..3e7eddf 100644 --- a/src/pyro_mcp/vm_environments.py +++ b/src/pyro_mcp/vm_environments.py @@ -19,7 +19,7 @@ from typing import Any from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths DEFAULT_ENVIRONMENT_VERSION = "1.0.0" -DEFAULT_CATALOG_VERSION = "2.4.0" +DEFAULT_CATALOG_VERSION = "2.5.0" OCI_MANIFEST_ACCEPT = ", ".join( ( "application/vnd.oci.image.index.v1+json", diff --git a/src/pyro_mcp/vm_guest.py b/src/pyro_mcp/vm_guest.py index d6acd3a..1988269 100644 --- a/src/pyro_mcp/vm_guest.py +++ b/src/pyro_mcp/vm_guest.py @@ -39,6 +39,26 @@ class GuestArchiveResponse: bytes_written: int +@dataclass(frozen=True) +class GuestShellSummary: + shell_id: str + cwd: str + cols: int + rows: int + state: str + started_at: float + ended_at: float | None + exit_code: int | None + + +@dataclass(frozen=True) +class GuestShellReadResponse(GuestShellSummary): + cursor: int + next_cursor: int + output: str + truncated: bool + + class VsockExecClient: """Minimal JSON-over-stream client for a guest exec agent.""" @@ -54,19 +74,17 @@ class VsockExecClient: *, uds_path: str | None = None, ) -> GuestExecResponse: - request = { - "command": command, - "timeout_seconds": timeout_seconds, - } - sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path) - try: - sock.sendall((json.dumps(request) + "\n").encode("utf-8")) - payload = self._recv_json_payload(sock) - finally: - sock.close() - - if not isinstance(payload, dict): - raise RuntimeError("guest exec response must be a JSON object") + payload = self._request_json( + guest_cid, + port, + { + "command": command, + "timeout_seconds": timeout_seconds, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest exec response must be a JSON object", + ) return GuestExecResponse( stdout=str(payload.get("stdout", "")), stderr=str(payload.get("stderr", "")), @@ -101,12 +119,198 @@ class VsockExecClient: if not isinstance(payload, dict): raise RuntimeError("guest archive response must be a JSON object") + error = payload.get("error") + if error is not None: + raise RuntimeError(str(error)) return GuestArchiveResponse( destination=str(payload.get("destination", destination)), entry_count=int(payload.get("entry_count", 0)), bytes_written=int(payload.get("bytes_written", 0)), ) + def open_shell( + self, + guest_cid: int, + port: int, + *, + shell_id: str, + cwd: str, + cols: int, + rows: int, + timeout_seconds: int = 30, + uds_path: str | None = None, + ) -> GuestShellSummary: + payload = self._request_json( + guest_cid, + port, + { + "action": "open_shell", + "shell_id": shell_id, + "cwd": cwd, + "cols": cols, + "rows": rows, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest shell open response must be a JSON object", + ) + return self._shell_summary_from_payload(payload) + + def read_shell( + self, + guest_cid: int, + port: int, + *, + shell_id: str, + cursor: int, + max_chars: int, + timeout_seconds: int = 30, + uds_path: str | None = None, + ) -> GuestShellReadResponse: + payload = self._request_json( + guest_cid, + port, + { + "action": "read_shell", + "shell_id": shell_id, + "cursor": cursor, + "max_chars": max_chars, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest shell read response must be a JSON object", + ) + summary = self._shell_summary_from_payload(payload) + return GuestShellReadResponse( + shell_id=summary.shell_id, + cwd=summary.cwd, + cols=summary.cols, + rows=summary.rows, + state=summary.state, + started_at=summary.started_at, + ended_at=summary.ended_at, + exit_code=summary.exit_code, + cursor=int(payload.get("cursor", cursor)), + next_cursor=int(payload.get("next_cursor", cursor)), + output=str(payload.get("output", "")), + truncated=bool(payload.get("truncated", False)), + ) + + def write_shell( + self, + guest_cid: int, + port: int, + *, + shell_id: str, + input_text: str, + append_newline: bool, + timeout_seconds: int = 30, + uds_path: str | None = None, + ) -> dict[str, Any]: + payload = self._request_json( + guest_cid, + port, + { + "action": "write_shell", + "shell_id": shell_id, + "input": input_text, + "append_newline": append_newline, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest shell write response must be a JSON object", + ) + self._shell_summary_from_payload(payload) + return payload + + def signal_shell( + self, + guest_cid: int, + port: int, + *, + shell_id: str, + signal_name: str, + timeout_seconds: int = 30, + uds_path: str | None = None, + ) -> dict[str, Any]: + payload = self._request_json( + guest_cid, + port, + { + "action": "signal_shell", + "shell_id": shell_id, + "signal": signal_name, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest shell signal response must be a JSON object", + ) + self._shell_summary_from_payload(payload) + return payload + + def close_shell( + self, + guest_cid: int, + port: int, + *, + shell_id: str, + timeout_seconds: int = 30, + uds_path: str | None = None, + ) -> dict[str, Any]: + payload = self._request_json( + guest_cid, + port, + { + "action": "close_shell", + "shell_id": shell_id, + }, + timeout_seconds=timeout_seconds, + uds_path=uds_path, + error_message="guest shell close response must be a JSON object", + ) + self._shell_summary_from_payload(payload) + return payload + + def _request_json( + self, + guest_cid: int, + port: int, + request: dict[str, Any], + *, + timeout_seconds: int, + uds_path: str | None, + error_message: str, + ) -> dict[str, Any]: + sock = self._connect(guest_cid, port, timeout_seconds, uds_path=uds_path) + try: + sock.sendall((json.dumps(request) + "\n").encode("utf-8")) + payload = self._recv_json_payload(sock) + finally: + sock.close() + if not isinstance(payload, dict): + raise RuntimeError(error_message) + error = payload.get("error") + if error is not None: + raise RuntimeError(str(error)) + return payload + + @staticmethod + def _shell_summary_from_payload(payload: dict[str, Any]) -> GuestShellSummary: + return GuestShellSummary( + shell_id=str(payload.get("shell_id", "")), + cwd=str(payload.get("cwd", "/workspace")), + cols=int(payload.get("cols", 0)), + rows=int(payload.get("rows", 0)), + state=str(payload.get("state", "stopped")), + started_at=float(payload.get("started_at", 0.0)), + ended_at=( + None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0)) + ), + exit_code=( + None if payload.get("exit_code") is None else int(payload.get("exit_code", 0)) + ), + ) + def _connect( self, guest_cid: int, diff --git a/src/pyro_mcp/vm_manager.py b/src/pyro_mcp/vm_manager.py index 34de2e7..ad221dc 100644 --- a/src/pyro_mcp/vm_manager.py +++ b/src/pyro_mcp/vm_manager.py @@ -27,8 +27,15 @@ from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_en from pyro_mcp.vm_firecracker import build_launch_plan from pyro_mcp.vm_guest import VsockExecClient from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager +from pyro_mcp.workspace_shells import ( + create_local_shell, + get_local_shell, + remove_local_shell, + shell_signal_names, +) VmState = Literal["created", "started", "stopped"] +WorkspaceShellState = Literal["running", "stopped"] DEFAULT_VCPU_COUNT = 1 DEFAULT_MEM_MIB = 1024 @@ -36,13 +43,18 @@ DEFAULT_TIMEOUT_SECONDS = 30 DEFAULT_TTL_SECONDS = 600 DEFAULT_ALLOW_HOST_COMPAT = False -WORKSPACE_LAYOUT_VERSION = 2 +WORKSPACE_LAYOUT_VERSION = 3 WORKSPACE_DIRNAME = "workspace" WORKSPACE_COMMANDS_DIRNAME = "commands" +WORKSPACE_SHELLS_DIRNAME = "shells" WORKSPACE_RUNTIME_DIRNAME = "runtime" WORKSPACE_GUEST_PATH = "/workspace" WORKSPACE_GUEST_AGENT_PATH = "/opt/pyro/bin/pyro_guest_agent.py" WORKSPACE_ARCHIVE_UPLOAD_TIMEOUT_SECONDS = 60 +DEFAULT_SHELL_COLS = 120 +DEFAULT_SHELL_ROWS = 30 +DEFAULT_SHELL_MAX_CHARS = 65536 +WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names() WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"] @@ -183,6 +195,58 @@ class WorkspaceRecord: ) +@dataclass +class WorkspaceShellRecord: + """Persistent shell metadata stored on disk per workspace.""" + + workspace_id: str + shell_id: str + cwd: str + cols: int + rows: int + state: WorkspaceShellState + started_at: float + ended_at: float | None = None + exit_code: int | None = None + execution_mode: str = "pending" + metadata: dict[str, str] = field(default_factory=dict) + + def to_payload(self) -> dict[str, Any]: + return { + "layout_version": WORKSPACE_LAYOUT_VERSION, + "workspace_id": self.workspace_id, + "shell_id": self.shell_id, + "cwd": self.cwd, + "cols": self.cols, + "rows": self.rows, + "state": self.state, + "started_at": self.started_at, + "ended_at": self.ended_at, + "exit_code": self.exit_code, + "execution_mode": self.execution_mode, + "metadata": dict(self.metadata), + } + + @classmethod + def from_payload(cls, payload: dict[str, Any]) -> WorkspaceShellRecord: + return cls( + workspace_id=str(payload["workspace_id"]), + shell_id=str(payload["shell_id"]), + cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)), + cols=int(payload.get("cols", DEFAULT_SHELL_COLS)), + rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)), + state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))), + started_at=float(payload.get("started_at", 0.0)), + ended_at=( + None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0)) + ), + exit_code=( + None if payload.get("exit_code") is None else int(payload.get("exit_code", 0)) + ), + execution_mode=str(payload.get("execution_mode", "pending")), + metadata=_string_dict(payload.get("metadata")), + ) + @dataclass(frozen=True) class PreparedWorkspaceSeed: """Prepared host-side seed archive plus metadata.""" @@ -610,6 +674,59 @@ class VmBackend: ) -> dict[str, Any]: raise NotImplementedError + def open_shell( # pragma: no cover + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cwd: str, + cols: int, + rows: int, + ) -> dict[str, Any]: + raise NotImplementedError + + def read_shell( # pragma: no cover + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cursor: int, + max_chars: int, + ) -> dict[str, Any]: + raise NotImplementedError + + def write_shell( # pragma: no cover + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + input_text: str, + append_newline: bool, + ) -> dict[str, Any]: + raise NotImplementedError + + def signal_shell( # pragma: no cover + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + signal_name: str, + ) -> dict[str, Any]: + raise NotImplementedError + + def close_shell( # pragma: no cover + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + ) -> dict[str, Any]: + raise NotImplementedError + class MockBackend(VmBackend): """Host-process backend used for development and testability.""" @@ -651,6 +768,87 @@ class MockBackend(VmBackend): destination=destination, ) + def open_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cwd: str, + cols: int, + rows: int, + ) -> dict[str, Any]: + session = create_local_shell( + workspace_id=workspace_id, + shell_id=shell_id, + cwd=_workspace_host_destination(_instance_workspace_host_dir(instance), cwd), + display_cwd=cwd, + cols=cols, + rows=rows, + ) + summary = session.summary() + summary["execution_mode"] = "host_compat" + return summary + + def read_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cursor: int, + max_chars: int, + ) -> dict[str, Any]: + del instance + session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id) + payload = session.read(cursor=cursor, max_chars=max_chars) + payload["execution_mode"] = "host_compat" + return payload + + def write_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + input_text: str, + append_newline: bool, + ) -> dict[str, Any]: + del instance + session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id) + payload = session.write(input_text, append_newline=append_newline) + payload["execution_mode"] = "host_compat" + return payload + + def signal_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + signal_name: str, + ) -> dict[str, Any]: + del instance + session = get_local_shell(workspace_id=workspace_id, shell_id=shell_id) + payload = session.send_signal(signal_name) + payload["execution_mode"] = "host_compat" + return payload + + def close_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + ) -> dict[str, Any]: + del instance + session = remove_local_shell(workspace_id=workspace_id, shell_id=shell_id) + if session is None: + raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}") + payload = session.close() + payload["execution_mode"] = "host_compat" + return payload + class FirecrackerBackend(VmBackend): # pragma: no cover """Host-gated backend that validates Firecracker prerequisites.""" @@ -888,6 +1086,144 @@ class FirecrackerBackend(VmBackend): # pragma: no cover destination=destination, ) + def open_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cwd: str, + cols: int, + rows: int, + ) -> dict[str, Any]: + del workspace_id + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + response = self._guest_exec_client.open_shell( + guest_cid, + port, + shell_id=shell_id, + cwd=cwd, + cols=cols, + rows=rows, + uds_path=uds_path, + ) + return { + "shell_id": response.shell_id or shell_id, + "cwd": response.cwd, + "cols": response.cols, + "rows": response.rows, + "state": response.state, + "started_at": response.started_at, + "ended_at": response.ended_at, + "exit_code": response.exit_code, + "execution_mode": instance.metadata.get("execution_mode", "pending"), + } + + def read_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + cursor: int, + max_chars: int, + ) -> dict[str, Any]: + del workspace_id + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + response = self._guest_exec_client.read_shell( + guest_cid, + port, + shell_id=shell_id, + cursor=cursor, + max_chars=max_chars, + uds_path=uds_path, + ) + return { + "shell_id": response.shell_id, + "cwd": response.cwd, + "cols": response.cols, + "rows": response.rows, + "state": response.state, + "started_at": response.started_at, + "ended_at": response.ended_at, + "exit_code": response.exit_code, + "cursor": response.cursor, + "next_cursor": response.next_cursor, + "output": response.output, + "truncated": response.truncated, + "execution_mode": instance.metadata.get("execution_mode", "pending"), + } + + def write_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + input_text: str, + append_newline: bool, + ) -> dict[str, Any]: + del workspace_id + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + payload = self._guest_exec_client.write_shell( + guest_cid, + port, + shell_id=shell_id, + input_text=input_text, + append_newline=append_newline, + uds_path=uds_path, + ) + payload["execution_mode"] = instance.metadata.get("execution_mode", "pending") + return payload + + def signal_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + signal_name: str, + ) -> dict[str, Any]: + del workspace_id + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + payload = self._guest_exec_client.signal_shell( + guest_cid, + port, + shell_id=shell_id, + signal_name=signal_name, + uds_path=uds_path, + ) + payload["execution_mode"] = instance.metadata.get("execution_mode", "pending") + return payload + + def close_shell( + self, + instance: VmInstance, + *, + workspace_id: str, + shell_id: str, + ) -> dict[str, Any]: + del workspace_id + guest_cid = int(instance.metadata["guest_cid"]) + port = int(instance.metadata["guest_exec_port"]) + uds_path = instance.metadata.get("guest_exec_uds_path") + payload = self._guest_exec_client.close_shell( + guest_cid, + port, + shell_id=shell_id, + uds_path=uds_path, + ) + payload["execution_mode"] = instance.metadata.get("execution_mode", "pending") + return payload + class VmManager: """In-process lifecycle manager for ephemeral VM environments and workspaces.""" @@ -1151,9 +1487,11 @@ class VmManager: runtime_dir = self._workspace_runtime_dir(workspace_id) host_workspace_dir = self._workspace_host_dir(workspace_id) commands_dir = self._workspace_commands_dir(workspace_id) + shells_dir = self._workspace_shells_dir(workspace_id) workspace_dir.mkdir(parents=True, exist_ok=False) host_workspace_dir.mkdir(parents=True, exist_ok=True) commands_dir.mkdir(parents=True, exist_ok=True) + shells_dir.mkdir(parents=True, exist_ok=True) instance = VmInstance( vm_id=workspace_id, environment=environment, @@ -1179,11 +1517,8 @@ class VmManager: f"max active VMs reached ({self._max_active_vms}); delete old VMs first" ) self._backend.create(instance) - if ( - prepared_seed.archive_path is not None - and self._runtime_capabilities.supports_guest_exec - ): - self._ensure_workspace_guest_seed_support(instance) + if self._runtime_capabilities.supports_guest_exec: + self._ensure_workspace_guest_agent_support(instance) with self._lock: self._start_instance_locked(instance) self._require_guest_exec_or_opt_in(instance) @@ -1332,6 +1667,208 @@ class VmManager: "cwd": WORKSPACE_GUEST_PATH, } + def open_shell( + self, + workspace_id: str, + *, + cwd: str = WORKSPACE_GUEST_PATH, + cols: int = DEFAULT_SHELL_COLS, + rows: int = DEFAULT_SHELL_ROWS, + ) -> dict[str, Any]: + if cols <= 0: + raise ValueError("cols must be positive") + if rows <= 0: + raise ValueError("rows must be positive") + normalized_cwd, _ = _normalize_workspace_destination(cwd) + shell_id = uuid.uuid4().hex[:12] + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + instance = self._workspace_instance_for_live_shell_locked(workspace) + payload = self._backend.open_shell( + instance, + workspace_id=workspace_id, + shell_id=shell_id, + cwd=normalized_cwd, + cols=cols, + rows=rows, + ) + shell = self._workspace_shell_record_from_payload( + workspace_id=workspace_id, + shell_id=shell_id, + payload=payload, + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + workspace.state = instance.state + workspace.firecracker_pid = instance.firecracker_pid + workspace.last_error = instance.last_error + workspace.metadata = dict(instance.metadata) + self._save_workspace_locked(workspace) + self._save_workspace_shell_locked(shell) + return self._serialize_workspace_shell(shell) + + def read_shell( + self, + workspace_id: str, + shell_id: str, + *, + cursor: int = 0, + max_chars: int = DEFAULT_SHELL_MAX_CHARS, + ) -> dict[str, Any]: + if cursor < 0: + raise ValueError("cursor must not be negative") + if max_chars <= 0: + raise ValueError("max_chars must be positive") + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + instance = self._workspace_instance_for_live_shell_locked(workspace) + shell = self._load_workspace_shell_locked(workspace_id, shell_id) + payload = self._backend.read_shell( + instance, + workspace_id=workspace_id, + shell_id=shell_id, + cursor=cursor, + max_chars=max_chars, + ) + updated_shell = self._workspace_shell_record_from_payload( + workspace_id=workspace_id, + shell_id=shell_id, + payload=payload, + metadata=shell.metadata, + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + workspace.state = instance.state + workspace.firecracker_pid = instance.firecracker_pid + workspace.last_error = instance.last_error + workspace.metadata = dict(instance.metadata) + self._save_workspace_locked(workspace) + self._save_workspace_shell_locked(updated_shell) + response = self._serialize_workspace_shell(updated_shell) + response.update( + { + "cursor": int(payload.get("cursor", cursor)), + "next_cursor": int(payload.get("next_cursor", cursor)), + "output": str(payload.get("output", "")), + "truncated": bool(payload.get("truncated", False)), + } + ) + return response + + def write_shell( + self, + workspace_id: str, + shell_id: str, + *, + input_text: str, + append_newline: bool = True, + ) -> dict[str, Any]: + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + instance = self._workspace_instance_for_live_shell_locked(workspace) + shell = self._load_workspace_shell_locked(workspace_id, shell_id) + payload = self._backend.write_shell( + instance, + workspace_id=workspace_id, + shell_id=shell_id, + input_text=input_text, + append_newline=append_newline, + ) + updated_shell = self._workspace_shell_record_from_payload( + workspace_id=workspace_id, + shell_id=shell_id, + payload=payload, + metadata=shell.metadata, + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + workspace.state = instance.state + workspace.firecracker_pid = instance.firecracker_pid + workspace.last_error = instance.last_error + workspace.metadata = dict(instance.metadata) + self._save_workspace_locked(workspace) + self._save_workspace_shell_locked(updated_shell) + response = self._serialize_workspace_shell(updated_shell) + response.update( + { + "input_length": int(payload.get("input_length", len(input_text))), + "append_newline": bool(payload.get("append_newline", append_newline)), + } + ) + return response + + def signal_shell( + self, + workspace_id: str, + shell_id: str, + *, + signal_name: str = "INT", + ) -> dict[str, Any]: + normalized_signal = signal_name.upper() + if normalized_signal not in WORKSPACE_SHELL_SIGNAL_NAMES: + raise ValueError( + f"signal_name must be one of: {', '.join(WORKSPACE_SHELL_SIGNAL_NAMES)}" + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + instance = self._workspace_instance_for_live_shell_locked(workspace) + shell = self._load_workspace_shell_locked(workspace_id, shell_id) + payload = self._backend.signal_shell( + instance, + workspace_id=workspace_id, + shell_id=shell_id, + signal_name=normalized_signal, + ) + updated_shell = self._workspace_shell_record_from_payload( + workspace_id=workspace_id, + shell_id=shell_id, + payload=payload, + metadata=shell.metadata, + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + workspace.state = instance.state + workspace.firecracker_pid = instance.firecracker_pid + workspace.last_error = instance.last_error + workspace.metadata = dict(instance.metadata) + self._save_workspace_locked(workspace) + self._save_workspace_shell_locked(updated_shell) + response = self._serialize_workspace_shell(updated_shell) + response["signal"] = str(payload.get("signal", normalized_signal)) + return response + + def close_shell( + self, + workspace_id: str, + shell_id: str, + ) -> dict[str, Any]: + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + instance = self._workspace_instance_for_live_shell_locked(workspace) + shell = self._load_workspace_shell_locked(workspace_id, shell_id) + payload = self._backend.close_shell( + instance, + workspace_id=workspace_id, + shell_id=shell_id, + ) + closed_shell = self._workspace_shell_record_from_payload( + workspace_id=workspace_id, + shell_id=shell_id, + payload=payload, + metadata=shell.metadata, + ) + with self._lock: + workspace = self._load_workspace_locked(workspace_id) + workspace.state = instance.state + workspace.firecracker_pid = instance.firecracker_pid + workspace.last_error = instance.last_error + workspace.metadata = dict(instance.metadata) + self._save_workspace_locked(workspace) + self._delete_workspace_shell_locked(workspace_id, shell_id) + response = self._serialize_workspace_shell(closed_shell) + response["closed"] = bool(payload.get("closed", True)) + return response + def status_workspace(self, workspace_id: str) -> dict[str, Any]: with self._lock: workspace = self._load_workspace_locked(workspace_id) @@ -1364,6 +1901,7 @@ class VmManager: instance = workspace.to_instance( workdir=self._workspace_runtime_dir(workspace.workspace_id) ) + self._close_workspace_shells_locked(workspace, instance) if workspace.state == "started": self._backend.stop(instance) workspace.state = "stopped" @@ -1423,6 +1961,20 @@ class VmManager: "metadata": workspace.metadata, } + def _serialize_workspace_shell(self, shell: WorkspaceShellRecord) -> dict[str, Any]: + return { + "workspace_id": shell.workspace_id, + "shell_id": shell.shell_id, + "cwd": shell.cwd, + "cols": shell.cols, + "rows": shell.rows, + "state": shell.state, + "started_at": shell.started_at, + "ended_at": shell.ended_at, + "exit_code": shell.exit_code, + "execution_mode": shell.execution_mode, + } + def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None: if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat: return @@ -1445,6 +1997,19 @@ class VmManager: "host execution." ) + def _require_workspace_shell_support(self, instance: VmInstance) -> None: + if self._backend_name == "mock": + return + if self._runtime_capabilities.supports_guest_exec: + return + reason = self._runtime_capabilities.reason or ( + "runtime does not support guest interactive shell sessions" + ) + raise RuntimeError( + "interactive shells require guest execution and are unavailable for this " + f"workspace: {reason}" + ) + def _get_instance_locked(self, vm_id: str) -> VmInstance: try: return self._instances[vm_id] @@ -1552,14 +2117,14 @@ class VmManager: bytes_written=bytes_written, ) - def _ensure_workspace_guest_seed_support(self, instance: VmInstance) -> None: + def _ensure_workspace_guest_agent_support(self, instance: VmInstance) -> None: if self._runtime_paths is None or self._runtime_paths.guest_agent_path is None: raise RuntimeError( - "runtime bundle does not provide a guest agent for workspace seeding" + "runtime bundle does not provide a guest agent for workspace operations" ) rootfs_image = instance.metadata.get("rootfs_image") if rootfs_image is None or rootfs_image == "": - raise RuntimeError("workspace rootfs image is unavailable for guest seeding") + raise RuntimeError("workspace rootfs image is unavailable for guest operations") _patch_rootfs_guest_agent(Path(rootfs_image), self._runtime_paths.guest_agent_path) def _workspace_dir(self, workspace_id: str) -> Path: @@ -1574,9 +2139,15 @@ class VmManager: def _workspace_commands_dir(self, workspace_id: str) -> Path: return self._workspace_dir(workspace_id) / WORKSPACE_COMMANDS_DIRNAME + def _workspace_shells_dir(self, workspace_id: str) -> Path: + return self._workspace_dir(workspace_id) / WORKSPACE_SHELLS_DIRNAME + def _workspace_metadata_path(self, workspace_id: str) -> Path: return self._workspace_dir(workspace_id) / "workspace.json" + def _workspace_shell_record_path(self, workspace_id: str, shell_id: str) -> Path: + return self._workspace_shells_dir(workspace_id) / f"{shell_id}.json" + def _count_workspaces_locked(self) -> int: return sum(1 for _ in self._workspaces_dir.glob("*/workspace.json")) @@ -1609,6 +2180,7 @@ class VmManager: instance = workspace.to_instance( workdir=self._workspace_runtime_dir(workspace.workspace_id) ) + self._close_workspace_shells_locked(workspace, instance) if workspace.state == "started": self._backend.stop(instance) workspace.state = "stopped" @@ -1704,3 +2276,97 @@ class VmManager: entry["stderr"] = stderr entries.append(entry) return entries + + def _workspace_instance_for_live_shell_locked(self, workspace: WorkspaceRecord) -> VmInstance: + self._ensure_workspace_not_expired_locked(workspace, time.time()) + self._refresh_workspace_liveness_locked(workspace) + if workspace.state != "started": + raise RuntimeError( + "workspace " + f"{workspace.workspace_id} must be in 'started' state before shell operations" + ) + instance = workspace.to_instance( + workdir=self._workspace_runtime_dir(workspace.workspace_id) + ) + self._require_workspace_shell_support(instance) + return instance + + def _workspace_shell_record_from_payload( + self, + *, + workspace_id: str, + shell_id: str, + payload: dict[str, Any], + metadata: dict[str, str] | None = None, + ) -> WorkspaceShellRecord: + return WorkspaceShellRecord( + workspace_id=workspace_id, + shell_id=str(payload.get("shell_id", shell_id)), + cwd=str(payload.get("cwd", WORKSPACE_GUEST_PATH)), + cols=int(payload.get("cols", DEFAULT_SHELL_COLS)), + rows=int(payload.get("rows", DEFAULT_SHELL_ROWS)), + state=cast(WorkspaceShellState, str(payload.get("state", "stopped"))), + started_at=float(payload.get("started_at", time.time())), + ended_at=( + None if payload.get("ended_at") is None else float(payload.get("ended_at", 0.0)) + ), + exit_code=( + None if payload.get("exit_code") is None else int(payload.get("exit_code", 0)) + ), + execution_mode=str(payload.get("execution_mode", "pending")), + metadata=dict(metadata or {}), + ) + + def _load_workspace_shell_locked( + self, + workspace_id: str, + shell_id: str, + ) -> WorkspaceShellRecord: + record_path = self._workspace_shell_record_path(workspace_id, shell_id) + if not record_path.exists(): + raise ValueError(f"shell {shell_id!r} does not exist in workspace {workspace_id!r}") + payload = json.loads(record_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + raise RuntimeError(f"shell record at {record_path} is invalid") + return WorkspaceShellRecord.from_payload(payload) + + def _save_workspace_shell_locked(self, shell: WorkspaceShellRecord) -> None: + record_path = self._workspace_shell_record_path(shell.workspace_id, shell.shell_id) + record_path.parent.mkdir(parents=True, exist_ok=True) + record_path.write_text( + json.dumps(shell.to_payload(), indent=2, sort_keys=True), + encoding="utf-8", + ) + + def _delete_workspace_shell_locked(self, workspace_id: str, shell_id: str) -> None: + record_path = self._workspace_shell_record_path(workspace_id, shell_id) + if record_path.exists(): + record_path.unlink() + + def _list_workspace_shells_locked(self, workspace_id: str) -> list[WorkspaceShellRecord]: + shells_dir = self._workspace_shells_dir(workspace_id) + if not shells_dir.exists(): + return [] + shells: list[WorkspaceShellRecord] = [] + for record_path in sorted(shells_dir.glob("*.json")): + payload = json.loads(record_path.read_text(encoding="utf-8")) + if not isinstance(payload, dict): + continue + shells.append(WorkspaceShellRecord.from_payload(payload)) + return shells + + def _close_workspace_shells_locked( + self, + workspace: WorkspaceRecord, + instance: VmInstance, + ) -> None: + for shell in self._list_workspace_shells_locked(workspace.workspace_id): + try: + self._backend.close_shell( + instance, + workspace_id=workspace.workspace_id, + shell_id=shell.shell_id, + ) + except Exception: + pass + self._delete_workspace_shell_locked(workspace.workspace_id, shell.shell_id) diff --git a/src/pyro_mcp/workspace_shells.py b/src/pyro_mcp/workspace_shells.py new file mode 100644 index 0000000..0a30d06 --- /dev/null +++ b/src/pyro_mcp/workspace_shells.py @@ -0,0 +1,291 @@ +"""Local PTY-backed shell sessions for the mock workspace backend.""" + +from __future__ import annotations + +import codecs +import fcntl +import os +import pty +import shlex +import signal +import struct +import subprocess +import termios +import threading +import time +from pathlib import Path +from typing import Literal + +ShellState = Literal["running", "stopped"] + +SHELL_SIGNAL_NAMES = ("HUP", "INT", "TERM", "KILL") +_SHELL_SIGNAL_MAP = { + "HUP": signal.SIGHUP, + "INT": signal.SIGINT, + "TERM": signal.SIGTERM, + "KILL": signal.SIGKILL, +} + +_LOCAL_SHELLS: dict[str, "LocalShellSession"] = {} +_LOCAL_SHELLS_LOCK = threading.Lock() + + +def _set_pty_size(fd: int, rows: int, cols: int) -> None: + winsize = struct.pack("HHHH", rows, cols, 0, 0) + fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize) + + +class LocalShellSession: + """Host-local interactive shell used by the mock backend.""" + + def __init__( + self, + *, + shell_id: str, + cwd: Path, + display_cwd: str, + cols: int, + rows: int, + ) -> None: + self.shell_id = shell_id + self.cwd = display_cwd + self.cols = cols + self.rows = rows + self.started_at = time.time() + self.ended_at: float | None = None + self.exit_code: int | None = None + self.state: ShellState = "running" + self.pid: int | None = None + self._lock = threading.RLock() + self._output = "" + self._master_fd: int | None = None + self._reader: threading.Thread | None = None + self._waiter: threading.Thread | None = None + self._decoder = codecs.getincrementaldecoder("utf-8")("replace") + + master_fd, slave_fd = pty.openpty() + try: + _set_pty_size(slave_fd, rows, cols) + env = os.environ.copy() + env.update( + { + "TERM": env.get("TERM", "xterm-256color"), + "PS1": "pyro$ ", + "PROMPT_COMMAND": "", + } + ) + process = subprocess.Popen( # noqa: S603 + ["/bin/bash", "--noprofile", "--norc", "-i"], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + cwd=str(cwd), + env=env, + text=False, + close_fds=True, + preexec_fn=os.setsid, + ) + except Exception: + os.close(master_fd) + raise + finally: + os.close(slave_fd) + + self._process = process + self.pid = process.pid + self._master_fd = master_fd + self._reader = threading.Thread(target=self._reader_loop, daemon=True) + self._waiter = threading.Thread(target=self._waiter_loop, daemon=True) + self._reader.start() + self._waiter.start() + + def summary(self) -> dict[str, object]: + with self._lock: + return { + "shell_id": self.shell_id, + "cwd": self.cwd, + "cols": self.cols, + "rows": self.rows, + "state": self.state, + "started_at": self.started_at, + "ended_at": self.ended_at, + "exit_code": self.exit_code, + "pid": self.pid, + } + + def read(self, *, cursor: int, max_chars: int) -> dict[str, object]: + with self._lock: + clamped_cursor = min(max(cursor, 0), len(self._output)) + output = self._output[clamped_cursor : clamped_cursor + max_chars] + next_cursor = clamped_cursor + len(output) + payload = self.summary() + payload.update( + { + "cursor": clamped_cursor, + "next_cursor": next_cursor, + "output": output, + "truncated": next_cursor < len(self._output), + } + ) + return payload + + def write(self, text: str, *, append_newline: bool) -> dict[str, object]: + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running": + raise RuntimeError(f"shell {self.shell_id} is not running") + master_fd = self._master_fd + if master_fd is None: + raise RuntimeError(f"shell {self.shell_id} transport is unavailable") + payload = text + ("\n" if append_newline else "") + try: + os.write(master_fd, payload.encode("utf-8")) + except OSError as exc: + self._refresh_process_state() + raise RuntimeError(f"failed to write to shell {self.shell_id}: {exc}") from exc + result = self.summary() + result.update({"input_length": len(text), "append_newline": append_newline}) + return result + + def send_signal(self, signal_name: str) -> dict[str, object]: + signal_name = signal_name.upper() + signum = _SHELL_SIGNAL_MAP.get(signal_name) + if signum is None: + raise ValueError(f"unsupported shell signal: {signal_name}") + if self._process.poll() is not None: + self._refresh_process_state() + with self._lock: + if self.state != "running" or self.pid is None: + raise RuntimeError(f"shell {self.shell_id} is not running") + pid = self.pid + try: + os.killpg(pid, signum) + except ProcessLookupError as exc: + self._refresh_process_state() + raise RuntimeError(f"shell {self.shell_id} is not running") from exc + result = self.summary() + result["signal"] = signal_name + return result + + def close(self) -> dict[str, object]: + if self._process.poll() is None and self.pid is not None: + try: + os.killpg(self.pid, signal.SIGHUP) + except ProcessLookupError: + pass + try: + self._process.wait(timeout=5) + except subprocess.TimeoutExpired: + try: + os.killpg(self.pid, signal.SIGKILL) + except ProcessLookupError: + pass + self._process.wait(timeout=5) + else: + self._refresh_process_state() + self._close_master_fd() + if self._reader is not None: + self._reader.join(timeout=1) + if self._waiter is not None: + self._waiter.join(timeout=1) + result = self.summary() + result["closed"] = True + return result + + def _reader_loop(self) -> None: + master_fd = self._master_fd + if master_fd is None: + return + while True: + try: + chunk = os.read(master_fd, 65536) + except OSError: + break + if chunk == b"": + break + decoded = self._decoder.decode(chunk) + if decoded: + with self._lock: + self._output += decoded + decoded = self._decoder.decode(b"", final=True) + if decoded: + with self._lock: + self._output += decoded + + def _waiter_loop(self) -> None: + exit_code = self._process.wait() + with self._lock: + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + + def _refresh_process_state(self) -> None: + exit_code = self._process.poll() + if exit_code is None: + return + with self._lock: + if self.state == "running": + self.state = "stopped" + self.exit_code = exit_code + self.ended_at = time.time() + + def _close_master_fd(self) -> None: + with self._lock: + master_fd = self._master_fd + self._master_fd = None + if master_fd is None: + return + try: + os.close(master_fd) + except OSError: + pass + + +def create_local_shell( + *, + workspace_id: str, + shell_id: str, + cwd: Path, + display_cwd: str, + cols: int, + rows: int, +) -> LocalShellSession: + session_key = f"{workspace_id}:{shell_id}" + with _LOCAL_SHELLS_LOCK: + if session_key in _LOCAL_SHELLS: + raise RuntimeError(f"shell {shell_id} already exists in workspace {workspace_id}") + session = LocalShellSession( + shell_id=shell_id, + cwd=cwd, + display_cwd=display_cwd, + cols=cols, + rows=rows, + ) + _LOCAL_SHELLS[session_key] = session + return session + + +def get_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession: + session_key = f"{workspace_id}:{shell_id}" + with _LOCAL_SHELLS_LOCK: + try: + return _LOCAL_SHELLS[session_key] + except KeyError as exc: + raise ValueError( + f"shell {shell_id!r} does not exist in workspace {workspace_id!r}" + ) from exc + + +def remove_local_shell(*, workspace_id: str, shell_id: str) -> LocalShellSession | None: + session_key = f"{workspace_id}:{shell_id}" + with _LOCAL_SHELLS_LOCK: + return _LOCAL_SHELLS.pop(session_key, None) + + +def shell_signal_names() -> tuple[str, ...]: + return SHELL_SIGNAL_NAMES + + +def shell_signal_arg_help() -> str: + return ", ".join(shlex.quote(name) for name in SHELL_SIGNAL_NAMES) diff --git a/tests/test_api.py b/tests/test_api.py index 836f64c..cba5d1f 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import time from pathlib import Path from typing import Any, cast @@ -50,6 +51,11 @@ def test_pyro_create_server_registers_vm_run(tmp_path: Path) -> None: assert "vm_create" in tool_names assert "workspace_create" in tool_names assert "workspace_sync_push" in tool_names + assert "shell_open" in tool_names + assert "shell_read" in tool_names + assert "shell_write" in tool_names + assert "shell_signal" in tool_names + assert "shell_close" in tool_names def test_pyro_vm_run_tool_executes(tmp_path: Path) -> None: @@ -130,6 +136,16 @@ def test_pyro_workspace_methods_delegate_to_manager(tmp_path: Path) -> None: (updated_dir / "more.txt").write_text("more\n", encoding="utf-8") synced = pyro.push_workspace_sync(workspace_id, updated_dir, dest="subdir") executed = pyro.exec_workspace(workspace_id, command="cat note.txt") + opened = pyro.open_shell(workspace_id) + shell_id = str(opened["shell_id"]) + written = pyro.write_shell(workspace_id, shell_id, input="pwd") + read = pyro.read_shell(workspace_id, shell_id) + deadline = time.time() + 5 + while "/workspace" not in str(read["output"]) and time.time() < deadline: + read = pyro.read_shell(workspace_id, shell_id, cursor=0) + time.sleep(0.05) + signaled = pyro.signal_shell(workspace_id, shell_id) + closed = pyro.close_shell(workspace_id, shell_id) status = pyro.status_workspace(workspace_id) logs = pyro.logs_workspace(workspace_id) deleted = pyro.delete_workspace(workspace_id) @@ -137,6 +153,10 @@ def test_pyro_workspace_methods_delegate_to_manager(tmp_path: Path) -> None: assert executed["stdout"] == "ok\n" assert created["workspace_seed"]["mode"] == "directory" assert synced["workspace_sync"]["destination"] == "/workspace/subdir" + assert written["input_length"] == 3 + assert "/workspace" in read["output"] + assert signaled["signal"] == "INT" + assert closed["closed"] is True assert status["command_count"] == 1 assert logs["count"] == 1 assert deleted["deleted"] is True diff --git a/tests/test_cli.py b/tests/test_cli.py index b3914a6..da03832 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -64,6 +64,7 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None: assert "pyro workspace create debian:12 --seed-path ./repo" in workspace_help assert "pyro workspace sync push WORKSPACE_ID ./repo --dest src" in workspace_help assert "pyro workspace exec WORKSPACE_ID" in workspace_help + assert "pyro workspace shell open WORKSPACE_ID" in workspace_help workspace_create_help = _subparser_choice( _subparser_choice(parser, "workspace"), @@ -92,6 +93,41 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None: assert "--dest" in workspace_sync_push_help assert "Import host content into `/workspace`" in workspace_sync_push_help + workspace_shell_help = _subparser_choice( + _subparser_choice(parser, "workspace"), + "shell", + ).format_help() + assert "pyro workspace shell open WORKSPACE_ID" in workspace_shell_help + assert "Use `workspace exec` for one-shot commands." in workspace_shell_help + + workspace_shell_open_help = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "open" + ).format_help() + assert "--cwd" in workspace_shell_open_help + assert "--cols" in workspace_shell_open_help + assert "--rows" in workspace_shell_open_help + + workspace_shell_read_help = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "read" + ).format_help() + assert "Shell output is written to stdout." in workspace_shell_read_help + + workspace_shell_write_help = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "write" + ).format_help() + assert "--input" in workspace_shell_write_help + assert "--no-newline" in workspace_shell_write_help + + workspace_shell_signal_help = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "signal" + ).format_help() + assert "--signal" in workspace_shell_signal_help + + workspace_shell_close_help = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "close" + ).format_help() + assert "Close a persistent workspace shell" in workspace_shell_close_help + def test_cli_run_prints_json( monkeypatch: pytest.MonkeyPatch, @@ -681,6 +717,226 @@ def test_cli_workspace_status_and_delete_print_json( assert deleted["deleted"] is True +def test_cli_workspace_shell_open_and_read_human( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubPyro: + def open_shell( + self, + workspace_id: str, + *, + cwd: str, + cols: int, + rows: int, + ) -> dict[str, Any]: + assert workspace_id == "workspace-123" + assert cwd == "/workspace" + assert cols == 120 + assert rows == 30 + return { + "workspace_id": workspace_id, + "shell_id": "shell-123", + "state": "running", + "cwd": cwd, + "cols": cols, + "rows": rows, + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "execution_mode": "guest_vsock", + } + + def read_shell( + self, + workspace_id: str, + shell_id: str, + *, + cursor: int, + max_chars: int, + ) -> dict[str, Any]: + assert workspace_id == "workspace-123" + assert shell_id == "shell-123" + assert cursor == 0 + assert max_chars == 1024 + return { + "workspace_id": workspace_id, + "shell_id": shell_id, + "state": "running", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "execution_mode": "guest_vsock", + "cursor": 0, + "next_cursor": 14, + "output": "pyro$ pwd\n", + "truncated": False, + } + + class OpenParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="workspace", + workspace_command="shell", + workspace_shell_command="open", + workspace_id="workspace-123", + cwd="/workspace", + cols=120, + rows=30, + json=False, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: OpenParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + + class ReadParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="workspace", + workspace_command="shell", + workspace_shell_command="read", + workspace_id="workspace-123", + shell_id="shell-123", + cursor=0, + max_chars=1024, + json=False, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: ReadParser()) + cli.main() + + captured = capsys.readouterr() + assert "pyro$ pwd\n" in captured.out + assert "[workspace-shell-open] workspace_id=workspace-123 shell_id=shell-123" in captured.err + assert "[workspace-shell-read] workspace_id=workspace-123 shell_id=shell-123" in captured.err + + +def test_cli_workspace_shell_write_signal_close_json( + monkeypatch: pytest.MonkeyPatch, + capsys: pytest.CaptureFixture[str], +) -> None: + class StubPyro: + def write_shell( + self, + workspace_id: str, + shell_id: str, + *, + input: str, + append_newline: bool, + ) -> dict[str, Any]: + assert workspace_id == "workspace-123" + assert shell_id == "shell-123" + assert input == "pwd" + assert append_newline is False + return { + "workspace_id": workspace_id, + "shell_id": shell_id, + "state": "running", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "execution_mode": "guest_vsock", + "input_length": 3, + "append_newline": False, + } + + def signal_shell( + self, + workspace_id: str, + shell_id: str, + *, + signal_name: str, + ) -> dict[str, Any]: + assert signal_name == "INT" + return { + "workspace_id": workspace_id, + "shell_id": shell_id, + "state": "running", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "execution_mode": "guest_vsock", + "signal": signal_name, + } + + def close_shell(self, workspace_id: str, shell_id: str) -> dict[str, Any]: + return { + "workspace_id": workspace_id, + "shell_id": shell_id, + "state": "stopped", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "started_at": 1.0, + "ended_at": 2.0, + "exit_code": 0, + "execution_mode": "guest_vsock", + "closed": True, + } + + class WriteParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="workspace", + workspace_command="shell", + workspace_shell_command="write", + workspace_id="workspace-123", + shell_id="shell-123", + input="pwd", + no_newline=True, + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: WriteParser()) + monkeypatch.setattr(cli, "Pyro", StubPyro) + cli.main() + written = json.loads(capsys.readouterr().out) + assert written["append_newline"] is False + + class SignalParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="workspace", + workspace_command="shell", + workspace_shell_command="signal", + workspace_id="workspace-123", + shell_id="shell-123", + signal="INT", + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: SignalParser()) + cli.main() + signaled = json.loads(capsys.readouterr().out) + assert signaled["signal"] == "INT" + + class CloseParser: + def parse_args(self) -> argparse.Namespace: + return argparse.Namespace( + command="workspace", + workspace_command="shell", + workspace_shell_command="close", + workspace_id="workspace-123", + shell_id="shell-123", + json=True, + ) + + monkeypatch.setattr(cli, "_build_parser", lambda: CloseParser()) + cli.main() + closed = json.loads(capsys.readouterr().out) + assert closed["closed"] is True + + def test_cli_workspace_exec_json_error_exits_nonzero( monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str] ) -> None: diff --git a/tests/test_public_contract.py b/tests/test_public_contract.py index eecb8ea..5db18e7 100644 --- a/tests/test_public_contract.py +++ b/tests/test_public_contract.py @@ -18,6 +18,12 @@ from pyro_mcp.contract import ( PUBLIC_CLI_ENV_SUBCOMMANDS, PUBLIC_CLI_RUN_FLAGS, PUBLIC_CLI_WORKSPACE_CREATE_FLAGS, + PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS, + PUBLIC_CLI_WORKSPACE_SHELL_OPEN_FLAGS, + PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS, + PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS, + PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS, + PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS, PUBLIC_CLI_WORKSPACE_SUBCOMMANDS, PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS, PUBLIC_CLI_WORKSPACE_SYNC_SUBCOMMANDS, @@ -86,6 +92,37 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None: ).format_help() for flag in PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS: assert flag in workspace_sync_push_help_text + workspace_shell_help_text = _subparser_choice( + _subparser_choice(parser, "workspace"), + "shell", + ).format_help() + for subcommand_name in PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS: + assert subcommand_name in workspace_shell_help_text + workspace_shell_open_help_text = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "open" + ).format_help() + for flag in PUBLIC_CLI_WORKSPACE_SHELL_OPEN_FLAGS: + assert flag in workspace_shell_open_help_text + workspace_shell_read_help_text = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "read" + ).format_help() + for flag in PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS: + assert flag in workspace_shell_read_help_text + workspace_shell_write_help_text = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "write" + ).format_help() + for flag in PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS: + assert flag in workspace_shell_write_help_text + workspace_shell_signal_help_text = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "signal" + ).format_help() + for flag in PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS: + assert flag in workspace_shell_signal_help_text + workspace_shell_close_help_text = _subparser_choice( + _subparser_choice(_subparser_choice(parser, "workspace"), "shell"), "close" + ).format_help() + for flag in PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS: + assert flag in workspace_shell_close_help_text demo_help_text = _subparser_choice(parser, "demo").format_help() for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS: diff --git a/tests/test_server.py b/tests/test_server.py index ce34b4d..5525e8c 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import time from pathlib import Path from typing import Any, cast @@ -34,6 +35,11 @@ def test_create_server_registers_vm_tools(tmp_path: Path) -> None: assert "workspace_create" in tool_names assert "workspace_logs" in tool_names assert "workspace_sync_push" in tool_names + assert "shell_open" in tool_names + assert "shell_read" in tool_names + assert "shell_write" in tool_names + assert "shell_signal" in tool_names + assert "shell_close" in tool_names def test_vm_run_round_trip(tmp_path: Path) -> None: @@ -190,6 +196,11 @@ def test_workspace_tools_round_trip(tmp_path: Path) -> None: dict[str, Any], dict[str, Any], dict[str, Any], + dict[str, Any], + dict[str, Any], + dict[str, Any], + dict[str, Any], + dict[str, Any], ]: server = create_server(manager=manager) created = _extract_structured( @@ -225,18 +236,88 @@ def test_workspace_tools_round_trip(tmp_path: Path) -> None: }, ) ) + opened = _extract_structured( + await server.call_tool("shell_open", {"workspace_id": workspace_id}) + ) + shell_id = str(opened["shell_id"]) + written = _extract_structured( + await server.call_tool( + "shell_write", + { + "workspace_id": workspace_id, + "shell_id": shell_id, + "input": "pwd", + }, + ) + ) + read = _extract_structured( + await server.call_tool( + "shell_read", + { + "workspace_id": workspace_id, + "shell_id": shell_id, + }, + ) + ) + deadline = time.time() + 5 + while "/workspace" not in str(read["output"]) and time.time() < deadline: + read = _extract_structured( + await server.call_tool( + "shell_read", + { + "workspace_id": workspace_id, + "shell_id": shell_id, + "cursor": 0, + }, + ) + ) + await asyncio.sleep(0.05) + signaled = _extract_structured( + await server.call_tool( + "shell_signal", + { + "workspace_id": workspace_id, + "shell_id": shell_id, + }, + ) + ) + closed = _extract_structured( + await server.call_tool( + "shell_close", + { + "workspace_id": workspace_id, + "shell_id": shell_id, + }, + ) + ) logs = _extract_structured( await server.call_tool("workspace_logs", {"workspace_id": workspace_id}) ) deleted = _extract_structured( await server.call_tool("workspace_delete", {"workspace_id": workspace_id}) ) - return created, synced, executed, logs, deleted + return created, synced, executed, opened, written, read, signaled, closed, logs, deleted - created, synced, executed, logs, deleted = asyncio.run(_run()) + ( + created, + synced, + executed, + opened, + written, + read, + signaled, + closed, + logs, + deleted, + ) = asyncio.run(_run()) assert created["state"] == "started" assert created["workspace_seed"]["mode"] == "directory" assert synced["workspace_sync"]["destination"] == "/workspace/subdir" assert executed["stdout"] == "more\n" + assert opened["state"] == "running" + assert written["input_length"] == 3 + assert "/workspace" in read["output"] + assert signaled["signal"] == "INT" + assert closed["closed"] is True assert logs["count"] == 1 assert deleted["deleted"] is True diff --git a/tests/test_vm_guest.py b/tests/test_vm_guest.py index ee4773e..cc35209 100644 --- a/tests/test_vm_guest.py +++ b/tests/test_vm_guest.py @@ -105,6 +105,130 @@ def test_vsock_exec_client_upload_archive_round_trip( assert stub.closed is True +def test_vsock_exec_client_shell_round_trip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) + responses = [ + json.dumps( + { + "shell_id": "shell-1", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "state": "running", + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + } + ).encode("utf-8"), + json.dumps( + { + "shell_id": "shell-1", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "state": "running", + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "cursor": 0, + "next_cursor": 12, + "output": "pyro$ pwd\n", + "truncated": False, + } + ).encode("utf-8"), + json.dumps( + { + "shell_id": "shell-1", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "state": "running", + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "input_length": 3, + "append_newline": True, + } + ).encode("utf-8"), + json.dumps( + { + "shell_id": "shell-1", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "state": "running", + "started_at": 1.0, + "ended_at": None, + "exit_code": None, + "signal": "INT", + } + ).encode("utf-8"), + json.dumps( + { + "shell_id": "shell-1", + "cwd": "/workspace", + "cols": 120, + "rows": 30, + "state": "stopped", + "started_at": 1.0, + "ended_at": 2.0, + "exit_code": 0, + "closed": True, + } + ).encode("utf-8"), + ] + stubs = [StubSocket(response) for response in responses] + remaining = list(stubs) + + def socket_factory(family: int, sock_type: int) -> StubSocket: + assert family == socket.AF_VSOCK + assert sock_type == socket.SOCK_STREAM + return remaining.pop(0) + + client = VsockExecClient(socket_factory=socket_factory) + opened = client.open_shell( + 1234, + 5005, + shell_id="shell-1", + cwd="/workspace", + cols=120, + rows=30, + ) + assert opened.shell_id == "shell-1" + read = client.read_shell(1234, 5005, shell_id="shell-1", cursor=0, max_chars=1024) + assert read.output == "pyro$ pwd\n" + write = client.write_shell( + 1234, + 5005, + shell_id="shell-1", + input_text="pwd", + append_newline=True, + ) + assert write["input_length"] == 3 + signaled = client.signal_shell(1234, 5005, shell_id="shell-1", signal_name="INT") + assert signaled["signal"] == "INT" + closed = client.close_shell(1234, 5005, shell_id="shell-1") + assert closed["closed"] is True + open_request = json.loads(stubs[0].sent.decode("utf-8").strip()) + assert open_request["action"] == "open_shell" + assert open_request["shell_id"] == "shell-1" + + +def test_vsock_exec_client_raises_agent_error(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) + stub = StubSocket(b'{"error":"shell is unavailable"}') + client = VsockExecClient(socket_factory=lambda family, sock_type: stub) + with pytest.raises(RuntimeError, match="shell is unavailable"): + client.open_shell( + 1234, + 5005, + shell_id="shell-1", + cwd="/workspace", + cols=120, + rows=30, + ) + + def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False) stub = StubSocket(b"[]") diff --git a/tests/test_vm_manager.py b/tests/test_vm_manager.py index 6bd059a..ca65ad2 100644 --- a/tests/test_vm_manager.py +++ b/tests/test_vm_manager.py @@ -454,6 +454,73 @@ def test_workspace_sync_push_rejects_destination_outside_workspace(tmp_path: Pat manager.push_workspace_sync(workspace_id, source_path=source_dir, dest="../escape") +def test_workspace_shell_lifecycle_and_rehydration(tmp_path: Path) -> None: + manager = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + created = manager.create_workspace( + environment="debian:12-base", + allow_host_compat=True, + ) + workspace_id = str(created["workspace_id"]) + + opened = manager.open_shell(workspace_id) + shell_id = str(opened["shell_id"]) + assert opened["state"] == "running" + + manager.write_shell(workspace_id, shell_id, input_text="pwd") + + output = "" + deadline = time.time() + 5 + while time.time() < deadline: + read = manager.read_shell(workspace_id, shell_id, cursor=0, max_chars=65536) + output = str(read["output"]) + if "/workspace" in output: + break + time.sleep(0.05) + assert "/workspace" in output + + manager_rehydrated = VmManager( + backend_name="mock", + base_dir=tmp_path / "vms", + network_manager=TapNetworkManager(enabled=False), + ) + second_opened = manager_rehydrated.open_shell(workspace_id) + second_shell_id = str(second_opened["shell_id"]) + assert second_shell_id != shell_id + + manager_rehydrated.write_shell(workspace_id, second_shell_id, input_text="printf 'ok\\n'") + second_output = "" + deadline = time.time() + 5 + while time.time() < deadline: + read = manager_rehydrated.read_shell( + workspace_id, + second_shell_id, + cursor=0, + max_chars=65536, + ) + second_output = str(read["output"]) + if "ok" in second_output: + break + time.sleep(0.05) + assert "ok" in second_output + + logs = manager.logs_workspace(workspace_id) + assert logs["count"] == 0 + + closed = manager.close_shell(workspace_id, shell_id) + assert closed["closed"] is True + with pytest.raises(ValueError, match="does not exist"): + manager.read_shell(workspace_id, shell_id) + + deleted = manager.delete_workspace(workspace_id) + assert deleted["deleted"] is True + with pytest.raises(ValueError, match="does not exist"): + manager_rehydrated.read_shell(workspace_id, second_shell_id) + + def test_workspace_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None: archive_path = tmp_path / "bad.tgz" with tarfile.open(archive_path, "w:gz") as archive: diff --git a/tests/test_workspace_shells.py b/tests/test_workspace_shells.py new file mode 100644 index 0000000..7b56cf4 --- /dev/null +++ b/tests/test_workspace_shells.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import os +import subprocess +import time +from pathlib import Path +from typing import cast + +import pytest + +from pyro_mcp.workspace_shells import ( + create_local_shell, + get_local_shell, + remove_local_shell, + shell_signal_arg_help, + shell_signal_names, +) + + +def _read_until( + workspace_id: str, + shell_id: str, + text: str, + *, + timeout_seconds: float = 5.0, +) -> dict[str, object]: + deadline = time.time() + timeout_seconds + payload = get_local_shell(workspace_id=workspace_id, shell_id=shell_id).read( + cursor=0, + max_chars=65536, + ) + while text not in str(payload["output"]) and time.time() < deadline: + time.sleep(0.05) + payload = get_local_shell(workspace_id=workspace_id, shell_id=shell_id).read( + cursor=0, + max_chars=65536, + ) + return payload + + +def test_workspace_shells_round_trip(tmp_path: Path) -> None: + session = create_local_shell( + workspace_id="workspace-1", + shell_id="shell-1", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + try: + assert session.summary()["state"] == "running" + write = session.write("printf 'hello\\n'", append_newline=True) + assert write["input_length"] == 16 + payload = _read_until("workspace-1", "shell-1", "hello") + assert "hello" in str(payload["output"]) + assert cast(int, payload["next_cursor"]) >= cast(int, payload["cursor"]) + assert isinstance(payload["truncated"], bool) + session.write("sleep 60", append_newline=True) + signaled = session.send_signal("INT") + assert signaled["signal"] == "INT" + finally: + closed = session.close() + assert closed["closed"] is True + + +def test_workspace_shell_registry_helpers(tmp_path: Path) -> None: + session = create_local_shell( + workspace_id="workspace-2", + shell_id="shell-2", + cwd=tmp_path, + display_cwd="/workspace/subdir", + cols=80, + rows=24, + ) + assert get_local_shell(workspace_id="workspace-2", shell_id="shell-2") is session + assert shell_signal_names() == ("HUP", "INT", "TERM", "KILL") + assert "HUP" in shell_signal_arg_help() + with pytest.raises(RuntimeError, match="already exists"): + create_local_shell( + workspace_id="workspace-2", + shell_id="shell-2", + cwd=tmp_path, + display_cwd="/workspace/subdir", + cols=80, + rows=24, + ) + removed = remove_local_shell(workspace_id="workspace-2", shell_id="shell-2") + assert removed is session + assert remove_local_shell(workspace_id="workspace-2", shell_id="shell-2") is None + with pytest.raises(ValueError, match="does not exist"): + get_local_shell(workspace_id="workspace-2", shell_id="shell-2") + closed = session.close() + assert closed["closed"] is True + + +def test_workspace_shells_error_after_exit(tmp_path: Path) -> None: + session = create_local_shell( + workspace_id="workspace-3", + shell_id="shell-3", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + session.write("exit", append_newline=True) + deadline = time.time() + 5 + while session.summary()["state"] != "stopped" and time.time() < deadline: + time.sleep(0.05) + assert session.summary()["state"] == "stopped" + with pytest.raises(RuntimeError, match="not running"): + session.write("pwd", append_newline=True) + with pytest.raises(RuntimeError, match="not running"): + session.send_signal("INT") + closed = session.close() + assert closed["closed"] is True + + +def test_workspace_shells_reject_invalid_signal(tmp_path: Path) -> None: + session = create_local_shell( + workspace_id="workspace-4", + shell_id="shell-4", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + try: + with pytest.raises(ValueError, match="unsupported shell signal"): + session.send_signal("BOGUS") + finally: + session.close() + + +def test_workspace_shells_init_failure_closes_ptys( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + def _boom(*args: object, **kwargs: object) -> object: + raise RuntimeError("boom") + + monkeypatch.setattr(subprocess, "Popen", _boom) + with pytest.raises(RuntimeError, match="boom"): + create_local_shell( + workspace_id="workspace-5", + shell_id="shell-5", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + + +def test_workspace_shells_write_and_signal_runtime_errors( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + session = create_local_shell( + workspace_id="workspace-6", + shell_id="shell-6", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + try: + with session._lock: # noqa: SLF001 + session._master_fd = None # noqa: SLF001 + with pytest.raises(RuntimeError, match="transport is unavailable"): + session.write("pwd", append_newline=True) + + with session._lock: # noqa: SLF001 + master_fd, slave_fd = os.pipe() + os.close(slave_fd) + session._master_fd = master_fd # noqa: SLF001 + + def _raise_write(fd: int, data: bytes) -> int: + del fd, data + raise OSError("broken") + + monkeypatch.setattr("pyro_mcp.workspace_shells.os.write", _raise_write) + with pytest.raises(RuntimeError, match="failed to write"): + session.write("pwd", append_newline=True) + + def _raise_killpg(pid: int, signum: int) -> None: + del pid, signum + raise ProcessLookupError() + + monkeypatch.setattr("pyro_mcp.workspace_shells.os.killpg", _raise_killpg) + with pytest.raises(RuntimeError, match="not running"): + session.send_signal("INT") + finally: + try: + session.close() + except Exception: + pass + + +def test_workspace_shells_refresh_process_state_updates_exit_code(tmp_path: Path) -> None: + session = create_local_shell( + workspace_id="workspace-7", + shell_id="shell-7", + cwd=tmp_path, + display_cwd="/workspace", + cols=120, + rows=30, + ) + try: + class StubProcess: + def poll(self) -> int: + return 7 + + session._process = StubProcess() # type: ignore[assignment] # noqa: SLF001 + session._refresh_process_state() # noqa: SLF001 + assert session.summary()["state"] == "stopped" + assert session.summary()["exit_code"] == 7 + finally: + try: + session.close() + except Exception: + pass diff --git a/uv.lock b/uv.lock index 4fd3492..0297f19 100644 --- a/uv.lock +++ b/uv.lock @@ -706,7 +706,7 @@ crypto = [ [[package]] name = "pyro-mcp" -version = "2.4.0" +version = "2.5.0" source = { editable = "." } dependencies = [ { name = "mcp" },