pyro-mcp/runtime_sources/linux-x86_64/guest/pyro_guest_agent.py
Thales Maciel ab02ae46c7 Add model-native workspace file operations
Remove shell-escaped file mutation from the stable workspace flow by adding explicit file and patch tools across the CLI, SDK, and MCP surfaces.

This adds workspace file list/read/write plus unified text patch application, backed by new guest and manager file primitives that stay scoped to started workspaces and /workspace only. Patch application is preflighted on the host, file writes stay text-only and bounded, and the existing diff/export/reset semantics remain intact.

The milestone also updates the 3.2.0 roadmap, public contract, docs, examples, and versioning, and includes focused coverage for the new helper module and dispatch paths.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke for workspace file read, patch apply, exec, export, and delete
2026-03-12 22:03:25 -03:00

1247 lines
46 KiB
Python

#!/usr/bin/env python3
"""Guest-side exec, workspace import, and interactive shell agent."""
from __future__ import annotations
import base64
import codecs
import fcntl
import io
import json
import os
import re
import shlex
import shutil
import signal
import socket
import struct
import subprocess
import tarfile
import tempfile
import termios
import threading
import time
import urllib.error
import urllib.request
from pathlib import Path, PurePosixPath
from typing import Any
# TCP port the guest agent listens on for manager connections.
PORT = 5005
# Chunk size for socket reads and archive streaming.
BUFFER_SIZE = 65536
# All workspace file operations are confined beneath this guest path.
WORKSPACE_ROOT = PurePosixPath("/workspace")
# Runtime state directories for shells, services, and installed secrets.
SHELL_ROOT = Path("/run/pyro-shells")
SERVICE_ROOT = Path("/run/pyro-services")
SECRET_ROOT = Path("/run/pyro-secrets")
# Upper bound for workspace file reads and writes (1 MiB).
WORKSPACE_FILE_MAX_BYTES = 1024 * 1024
# Service names: alphanumeric first character, then up to 63 of [A-Za-z0-9._-].
SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
# Signals that may be delivered to an interactive shell's process group.
SHELL_SIGNAL_MAP = {
    "HUP": signal.SIGHUP,
    "INT": signal.SIGINT,
    "TERM": signal.SIGTERM,
    "KILL": signal.SIGKILL,
}
SHELL_SIGNAL_NAMES = tuple(SHELL_SIGNAL_MAP)
# Registry of live shell sessions, guarded by _SHELLS_LOCK.
_SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
def _redact_text(text: str, redact_values: list[str]) -> str:
redacted = text
for secret_value in sorted(
{item for item in redact_values if item != ""},
key=len,
reverse=True,
):
redacted = redacted.replace(secret_value, "[REDACTED]")
return redacted
def _read_request(conn: socket.socket) -> dict[str, Any]:
chunks: list[bytes] = []
while True:
data = conn.recv(1)
if data == b"":
break
chunks.append(data)
if data == b"\n":
break
payload = json.loads(b"".join(chunks).decode("utf-8").strip())
if not isinstance(payload, dict):
raise RuntimeError("request must be a JSON object")
return payload
def _read_exact(conn: socket.socket, size: int) -> bytes:
    """Read exactly *size* bytes from *conn*, raising on premature EOF."""
    buffer = bytearray()
    while len(buffer) < size:
        chunk = conn.recv(min(BUFFER_SIZE, size - len(buffer)))
        if chunk == b"":
            raise RuntimeError("unexpected EOF while reading archive payload")
        buffer.extend(chunk)
    return bytes(buffer)
def _normalize_member_name(name: str) -> PurePosixPath:
candidate = name.strip()
if candidate == "":
raise RuntimeError("archive member path is empty")
member_path = PurePosixPath(candidate)
if member_path.is_absolute():
raise RuntimeError(f"absolute archive member paths are not allowed: {name}")
parts = [part for part in member_path.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"unsafe archive member path: {name}")
normalized = PurePosixPath(*parts)
if str(normalized) in {"", "."}:
raise RuntimeError(f"unsafe archive member path: {name}")
return normalized
def _normalize_destination(destination: str) -> tuple[PurePosixPath, Path]:
    """Resolve *destination* to a normalized absolute path under /workspace.

    Relative paths are interpreted relative to /workspace. Returns both
    the normalized PurePosixPath and the matching concrete Path.

    Raises:
        RuntimeError: when the path is empty or escapes /workspace.
    """
    raw = destination.strip()
    if not raw:
        raise RuntimeError("destination must not be empty")
    requested = PurePosixPath(raw)
    if not requested.is_absolute():
        requested = WORKSPACE_ROOT / requested
    segments = [segment for segment in requested.parts if segment not in {"", "."}]
    normalized = PurePosixPath("/").joinpath(*segments)
    if normalized == PurePosixPath("/"):
        raise RuntimeError("destination must stay inside /workspace")
    if normalized.parts[: len(WORKSPACE_ROOT.parts)] != WORKSPACE_ROOT.parts:
        raise RuntimeError("destination must stay inside /workspace")
    relative = normalized.relative_to(WORKSPACE_ROOT)
    local_path = Path("/workspace")
    if str(relative) not in {"", "."}:
        local_path = local_path.joinpath(*relative.parts)
    return normalized, local_path
def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
    """Normalize a shell working directory into (text form, concrete Path)."""
    normalized_path, local_path = _normalize_destination(cwd)
    return str(normalized_path), local_path
def _normalize_service_name(service_name: str) -> str:
    """Strip and validate a service name against the allowed pattern."""
    candidate = service_name.strip()
    if not candidate:
        raise RuntimeError("service_name is required")
    if SERVICE_NAME_RE.fullmatch(candidate) is None:
        raise RuntimeError("service_name is invalid")
    return candidate
def _service_stdout_path(service_name: str) -> Path:
    """Location of the captured stdout log for a managed service."""
    return SERVICE_ROOT / (service_name + ".stdout")
def _service_stderr_path(service_name: str) -> Path:
    """Location of the captured stderr log for a managed service."""
    return SERVICE_ROOT / (service_name + ".stderr")
def _service_status_path(service_name: str) -> Path:
    """Location of the recorded exit-status file for a managed service."""
    return SERVICE_ROOT / (service_name + ".status")
def _service_runner_path(service_name: str) -> Path:
    """Location of the generated runner script for a managed service."""
    return SERVICE_ROOT / (service_name + ".runner.sh")
def _service_metadata_path(service_name: str) -> Path:
    """Location of the persisted JSON record for a managed service."""
    return SERVICE_ROOT / (service_name + ".json")
def _normalize_secret_name(secret_name: str) -> str:
normalized = secret_name.strip()
if normalized == "":
raise RuntimeError("secret name is required")
if re.fullmatch(r"^[A-Za-z_][A-Za-z0-9_]{0,63}$", normalized) is None:
raise RuntimeError("secret name is invalid")
return normalized
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
raise RuntimeError(f"symlink {member_path} has an empty target")
target_path = PurePosixPath(target)
if target_path.is_absolute():
raise RuntimeError(f"symlink {member_path} escapes the workspace")
combined = member_path.parent.joinpath(target_path)
parts = [part for part in combined.parts if part not in {"", "."}]
if any(part == ".." for part in parts):
raise RuntimeError(f"symlink {member_path} escapes the workspace")
def _ensure_no_symlink_parents(root: Path, target_path: Path, member_name: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"archive member would traverse through a symlinked path: {member_name}"
)
def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
    """Extract an uploaded tar archive (as bytes) beneath /workspace.

    Member paths are validated against absolute/``..`` escapes, extraction
    never traverses a symlinked parent directory, and only directories,
    regular files, and workspace-relative symlinks are accepted; hard
    links and special members are rejected.

    Returns:
        Summary dict with the normalized destination, the member count,
        and the total declared size of the regular files written.

    Raises:
        RuntimeError: on unsafe member paths, conflicts with existing
            paths of a different type, or unsupported member types.
    """
    normalized_destination, destination_root = _normalize_destination(destination)
    destination_root.mkdir(parents=True, exist_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            # Rejects empty/absolute/parent-escaping member names up front.
            member_name = _normalize_member_name(member.name)
            target_path = destination_root.joinpath(*member_name.parts)
            entry_count += 1
            # Never write through a symlinked intermediate directory.
            _ensure_no_symlink_parents(destination_root, target_path, member.name)
            if member.isdir():
                if target_path.is_symlink() or (target_path.exists() and not target_path.is_dir()):
                    raise RuntimeError(f"directory conflicts with existing path: {member.name}")
                target_path.mkdir(parents=True, exist_ok=True)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and (target_path.is_dir() or target_path.is_symlink()):
                    raise RuntimeError(f"file conflicts with existing path: {member.name}")
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read archive member: {member.name}")
                # Stream the member in chunks rather than loading it whole.
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                # Counts the member's declared size, not bytes actually copied.
                bytes_written += member.size
                continue
            if member.issym():
                # Symlink targets must stay inside the extraction root.
                _validate_symlink_target(member_name, member.linkname)
                target_path.parent.mkdir(parents=True, exist_ok=True)
                if target_path.exists() and not target_path.is_symlink():
                    raise RuntimeError(f"symlink conflicts with existing path: {member.name}")
                if target_path.is_symlink():
                    target_path.unlink()
                os.symlink(member.linkname, target_path)
                continue
            if member.islnk():
                raise RuntimeError(
                    f"hard links are not allowed in workspace archives: {member.name}"
                )
            raise RuntimeError(f"unsupported archive member type: {member.name}")
    return {
        "destination": str(normalized_destination),
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
def _install_secrets_archive(payload: bytes) -> dict[str, Any]:
    """Replace the contents of the secret root with an uploaded tar archive.

    Existing entries under SECRET_ROOT are removed first so deleted
    secrets do not linger. Only directories and regular files are
    accepted (links are rejected), and permissions are tightened to
    0700 for directories / 0600 for files.

    Returns:
        Summary dict with the secret root, the member count, and the
        total declared size of the files written.

    Raises:
        RuntimeError: on unsafe member names, unreadable members, or
            link/unsupported member types.
    """
    SECRET_ROOT.mkdir(parents=True, exist_ok=True)
    # Wipe the previous secret set before installing the new one.
    for existing in SECRET_ROOT.iterdir():
        if existing.is_dir() and not existing.is_symlink():
            shutil.rmtree(existing, ignore_errors=True)
        else:
            existing.unlink(missing_ok=True)
    bytes_written = 0
    entry_count = 0
    with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
        for member in archive.getmembers():
            # Rejects empty/absolute/parent-escaping member names.
            member_name = _normalize_member_name(member.name)
            target_path = SECRET_ROOT.joinpath(*member_name.parts)
            entry_count += 1
            if member.isdir():
                target_path.mkdir(parents=True, exist_ok=True)
                target_path.chmod(0o700)
                continue
            if member.isfile():
                target_path.parent.mkdir(parents=True, exist_ok=True)
                target_path.parent.chmod(0o700)
                source = archive.extractfile(member)
                if source is None:
                    raise RuntimeError(f"failed to read secret archive member: {member.name}")
                with target_path.open("wb") as handle:
                    while True:
                        chunk = source.read(BUFFER_SIZE)
                        if chunk == b"":
                            break
                        handle.write(chunk)
                # Secrets must never be group- or world-readable.
                target_path.chmod(0o600)
                bytes_written += member.size
                continue
            if member.issym() or member.islnk():
                raise RuntimeError(f"secret archive may not contain links: {member.name}")
            raise RuntimeError(f"unsupported secret archive member type: {member.name}")
    return {
        "destination": str(SECRET_ROOT),
        "entry_count": entry_count,
        "bytes_written": bytes_written,
    }
def _inspect_archive(archive_path: Path) -> tuple[int, int]:
entry_count = 0
bytes_written = 0
with tarfile.open(archive_path, "r:*") as archive:
for member in archive.getmembers():
entry_count += 1
if member.isfile():
bytes_written += member.size
return entry_count, bytes_written
def _prepare_export_archive(path: str) -> dict[str, Any]:
    """Package a workspace file, directory, or symlink into a temp tar archive.

    For a directory the archive contains its children at the top level;
    for a file or symlink it contains the single entry. The caller owns
    the returned temporary archive and must delete it after streaming.

    Raises:
        RuntimeError: when the path does not exist or has an unsupported type.
    """
    normalized_path, source_path = _normalize_destination(path)
    if not source_path.exists() and not source_path.is_symlink():
        raise RuntimeError(f"workspace path does not exist: {normalized_path}")
    if source_path.is_symlink():
        artifact_type = "symlink"
    elif source_path.is_file():
        artifact_type = "file"
    elif source_path.is_dir():
        artifact_type = "directory"
    else:
        raise RuntimeError(f"unsupported workspace path type: {normalized_path}")
    with tempfile.NamedTemporaryFile(prefix="pyro-export-", suffix=".tar", delete=False) as handle:
        archive_path = Path(handle.name)
    try:
        with tarfile.open(archive_path, "w") as archive:
            # Preserve symlinks as symlinks instead of following them.
            archive.dereference = False
            if artifact_type == "directory":
                children = sorted(source_path.iterdir(), key=lambda child: child.name)
                for child in children:
                    archive.add(child, arcname=child.name, recursive=True)
            else:
                archive.add(source_path, arcname=source_path.name, recursive=False)
        entry_count, bytes_written = _inspect_archive(archive_path)
        return {
            "workspace_path": str(normalized_path),
            "artifact_type": artifact_type,
            "archive_path": archive_path,
            "archive_size": archive_path.stat().st_size,
            "entry_count": entry_count,
            "bytes_written": bytes_written,
        }
    except Exception:
        # Never leave the temporary archive behind on failure.
        archive_path.unlink(missing_ok=True)
        raise
def _workspace_entry(path_text: str, host_path: Path) -> dict[str, Any]:
try:
stat_result = os.lstat(host_path)
except FileNotFoundError as exc:
raise RuntimeError(f"workspace path does not exist: {path_text}") from exc
if host_path.is_symlink():
return {
"path": path_text,
"artifact_type": "symlink",
"size_bytes": stat_result.st_size,
"link_target": os.readlink(host_path),
}
if host_path.is_dir():
return {
"path": path_text,
"artifact_type": "directory",
"size_bytes": 0,
"link_target": None,
}
if host_path.is_file():
return {
"path": path_text,
"artifact_type": "file",
"size_bytes": stat_result.st_size,
"link_target": None,
}
raise RuntimeError(f"unsupported workspace path type: {path_text}")
def _join_workspace_path(base: str, child_name: str) -> str:
base_path = PurePosixPath(base)
return str(base_path / child_name) if str(base_path) != "/" else f"/{child_name}"
def _list_workspace(path: str, *, recursive: bool) -> dict[str, Any]:
    """List a workspace path.

    A non-directory yields a single-entry response. A directory is walked
    depth-first: children are sorted by path and, when *recursive* is set,
    each directory entry is immediately followed by its own children.
    """
    normalized_path, host_path = _normalize_destination(path)
    root_entry = _workspace_entry(str(normalized_path), host_path)
    if root_entry["artifact_type"] != "directory":
        return {
            "path": str(normalized_path),
            "artifact_type": root_entry["artifact_type"],
            "entries": [root_entry],
        }
    collected: list[dict[str, Any]] = []

    def _descend(parent_text: str, parent_host: Path) -> None:
        level: list[tuple[dict[str, Any], Path]] = []
        with os.scandir(parent_host) as children:
            for child in children:
                child_host = Path(child.path)
                child_text = _join_workspace_path(parent_text, child.name)
                level.append((_workspace_entry(child_text, child_host), child_host))
        level.sort(key=lambda pair: str(pair[0]["path"]))
        for child_entry, child_host in level:
            collected.append(child_entry)
            if recursive and child_entry["artifact_type"] == "directory":
                _descend(str(child_entry["path"]), child_host)

    _descend(str(normalized_path), host_path)
    return {
        "path": str(normalized_path),
        "artifact_type": "directory",
        "entries": collected,
    }
def _read_workspace_file(path: str, *, max_bytes: int) -> dict[str, Any]:
    """Read a regular workspace file and return it base64-encoded.

    Args:
        path: Workspace path (relative paths resolve under /workspace).
        max_bytes: Positive read budget, capped at WORKSPACE_FILE_MAX_BYTES.

    Raises:
        RuntimeError: on invalid bounds, non-file targets, or oversize files.
    """
    if max_bytes <= 0:
        raise RuntimeError("max_bytes must be positive")
    if max_bytes > WORKSPACE_FILE_MAX_BYTES:
        raise RuntimeError(
            f"max_bytes must be at most {WORKSPACE_FILE_MAX_BYTES} bytes"
        )
    normalized_path, host_path = _normalize_destination(path)
    entry = _workspace_entry(str(normalized_path), host_path)
    if entry["artifact_type"] != "file":
        raise RuntimeError("workspace file read only supports regular files")
    # Check the stat size first so an oversize file is rejected without
    # loading its full contents into memory.
    if int(entry["size_bytes"]) > max_bytes:
        raise RuntimeError(
            f"workspace file exceeds the maximum supported size of {max_bytes} bytes"
        )
    raw_bytes = host_path.read_bytes()
    # Re-check after reading in case the file grew between stat and read.
    if len(raw_bytes) > max_bytes:
        raise RuntimeError(
            f"workspace file exceeds the maximum supported size of {max_bytes} bytes"
        )
    return {
        "path": str(normalized_path),
        "size_bytes": len(raw_bytes),
        "content_b64": base64.b64encode(raw_bytes).decode("ascii"),
    }
def _ensure_no_symlink_parents_for_write(root: Path, target_path: Path, path_text: str) -> None:
relative_path = target_path.relative_to(root)
current = root
for part in relative_path.parts[:-1]:
current = current / part
if current.is_symlink():
raise RuntimeError(
f"workspace path would traverse through a symlinked parent: {path_text}"
)
def _write_workspace_file(path: str, *, text: str) -> dict[str, Any]:
    """Atomically write UTF-8 text to a regular file under /workspace.

    The content lands in a sibling temporary file first and is then moved
    into place with os.replace so readers never observe a partial write.

    Raises:
        RuntimeError: when the encoded text exceeds the size bound, the
            path escapes /workspace, a parent is a symlink, or the target
            exists and is not a regular file.
    """
    raw_bytes = text.encode("utf-8")
    if len(raw_bytes) > WORKSPACE_FILE_MAX_BYTES:
        raise RuntimeError(
            f"text must be at most {WORKSPACE_FILE_MAX_BYTES} bytes when encoded as UTF-8"
        )
    normalized_path, host_path = _normalize_destination(path)
    _ensure_no_symlink_parents_for_write(Path("/workspace"), host_path, str(normalized_path))
    if host_path.exists() or host_path.is_symlink():
        entry = _workspace_entry(str(normalized_path), host_path)
        if entry["artifact_type"] != "file":
            raise RuntimeError("workspace file write only supports regular file targets")
    host_path.parent.mkdir(parents=True, exist_ok=True)
    temp_path: Path | None = None
    try:
        with tempfile.NamedTemporaryFile(
            prefix=".pyro-workspace-write-",
            dir=host_path.parent,
            delete=False,
        ) as handle:
            temp_path = Path(handle.name)
            handle.write(raw_bytes)
        os.replace(temp_path, host_path)
    except Exception:
        # Bug fix: do not leak the temporary file when the write or the
        # rename fails (delete=False means nothing else cleans it up).
        if temp_path is not None:
            temp_path.unlink(missing_ok=True)
        raise
    return {
        "path": str(normalized_path),
        "size_bytes": len(raw_bytes),
        "bytes_written": len(raw_bytes),
    }
def _delete_workspace_path(path: str) -> dict[str, Any]:
    """Delete a single workspace file or symlink; directories are rejected."""
    normalized_path, host_path = _normalize_destination(path)
    entry = _workspace_entry(str(normalized_path), host_path)
    if entry["artifact_type"] == "directory":
        raise RuntimeError("workspace file delete does not support directories")
    host_path.unlink(missing_ok=False)
    return {"path": str(normalized_path), "deleted": True}
def _run_command(
command: str,
timeout_seconds: int,
*,
env: dict[str, str] | None = None,
) -> dict[str, Any]:
started = time.monotonic()
command_env = os.environ.copy()
if env is not None:
command_env.update(env)
try:
proc = subprocess.run(
["/bin/sh", "-lc", command],
text=True,
capture_output=True,
env=command_env,
timeout=timeout_seconds,
check=False,
)
return {
"stdout": proc.stdout,
"stderr": proc.stderr,
"exit_code": proc.returncode,
"duration_ms": int((time.monotonic() - started) * 1000),
}
except subprocess.TimeoutExpired:
return {
"stdout": "",
"stderr": f"command timed out after {timeout_seconds}s",
"exit_code": 124,
"duration_ms": int((time.monotonic() - started) * 1000),
}
def _set_pty_size(fd: int, rows: int, cols: int) -> None:
winsize = struct.pack("HHHH", rows, cols, 0, 0)
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
def _shell_argv(*, interactive: bool) -> list[str]:
shell_program = shutil.which("bash") or "/bin/sh"
argv = [shell_program]
if shell_program.endswith("bash"):
argv.extend(["--noprofile", "--norc"])
if interactive:
argv.append("-i")
return argv
class GuestShellSession:
    """In-guest PTY-backed interactive shell session."""
    def __init__(
        self,
        *,
        shell_id: str,
        cwd: Path,
        cwd_text: str,
        cols: int,
        rows: int,
        env_overrides: dict[str, str] | None = None,
        redact_values: list[str] | None = None,
    ) -> None:
        """Spawn an interactive shell on a new PTY in its own session.

        Args:
            shell_id: Unique identifier; also names the metadata/log files.
            cwd: Filesystem working directory for the shell process.
            cwd_text: Normalized workspace-path form of *cwd* for reporting.
            cols: Initial PTY width in columns.
            rows: Initial PTY height in rows.
            env_overrides: Extra environment entries layered over os.environ.
            redact_values: Secret strings masked in output returned by read().
        """
        self.shell_id = shell_id
        self.cwd = cwd_text
        self.cols = cols
        self.rows = rows
        self.started_at = time.time()
        self.ended_at: float | None = None
        self.exit_code: int | None = None
        self.state = "running"
        self._lock = threading.RLock()
        self._output = ""
        # Incremental decoder so multi-byte UTF-8 sequences split across
        # PTY reads decode correctly ("replace" avoids hard failures).
        self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
        self._redact_values = list(redact_values or [])
        self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
        self._log_path = SHELL_ROOT / f"{shell_id}.log"
        self._master_fd: int | None = None
        master_fd, slave_fd = os.openpty()
        try:
            _set_pty_size(slave_fd, rows, cols)
            env = os.environ.copy()
            env.update(
                {
                    "TERM": env.get("TERM", "xterm-256color"),
                    "PS1": "pyro$ ",
                    "PROMPT_COMMAND": "",
                }
            )
            if env_overrides is not None:
                env.update(env_overrides)
            process = subprocess.Popen(  # noqa: S603
                _shell_argv(interactive=True),
                stdin=slave_fd,
                stdout=slave_fd,
                stderr=slave_fd,
                cwd=str(cwd),
                env=env,
                text=False,
                close_fds=True,
                # New session/process group so signals can target the whole tree.
                preexec_fn=os.setsid,
            )
        except Exception:
            os.close(master_fd)
            raise
        finally:
            # The child holds its own copy; the parent's slave fd is not needed.
            os.close(slave_fd)
        self._process = process
        self._master_fd = master_fd
        self._write_metadata()
        self._reader = threading.Thread(target=self._reader_loop, daemon=True)
        self._waiter = threading.Thread(target=self._waiter_loop, daemon=True)
        self._reader.start()
        self._waiter.start()
    def summary(self) -> dict[str, Any]:
        """Return a snapshot of the session's public state (taken under the lock)."""
        with self._lock:
            return {
                "shell_id": self.shell_id,
                "cwd": self.cwd,
                "cols": self.cols,
                "rows": self.rows,
                "state": self.state,
                "started_at": self.started_at,
                "ended_at": self.ended_at,
                "exit_code": self.exit_code,
            }
    def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
        """Return up to *max_chars* of redacted output starting at *cursor*.

        The cursor is clamped into range; ``next_cursor`` lets the caller
        resume, and ``truncated`` signals that more output is available.
        """
        with self._lock:
            redacted_output = _redact_text(self._output, self._redact_values)
            clamped_cursor = min(max(cursor, 0), len(redacted_output))
            output = redacted_output[clamped_cursor : clamped_cursor + max_chars]
            next_cursor = clamped_cursor + len(output)
            # RLock allows re-acquisition inside summary().
            payload = self.summary()
            payload.update(
                {
                    "cursor": clamped_cursor,
                    "next_cursor": next_cursor,
                    "output": output,
                    "truncated": next_cursor < len(redacted_output),
                }
            )
            return payload
    def write(self, text: str, *, append_newline: bool) -> dict[str, Any]:
        """Send *text* (optionally newline-terminated) to the shell's stdin."""
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            master_fd = self._master_fd
            if master_fd is None:
                raise RuntimeError(f"shell {self.shell_id} transport is unavailable")
        payload = text + ("\n" if append_newline else "")
        try:
            os.write(master_fd, payload.encode("utf-8"))
        except OSError as exc:
            self._refresh_process_state()
            raise RuntimeError(f"failed to write shell input: {exc}") from exc
        response = self.summary()
        response.update({"input_length": len(text), "append_newline": append_newline})
        return response
    def send_signal(self, signal_name: str) -> dict[str, Any]:
        """Deliver a named signal (HUP/INT/TERM/KILL) to the shell's process group."""
        signum = SHELL_SIGNAL_MAP.get(signal_name)
        if signum is None:
            raise ValueError(f"unsupported shell signal: {signal_name}")
        if self._process.poll() is not None:
            self._refresh_process_state()
        with self._lock:
            if self.state != "running":
                raise RuntimeError(f"shell {self.shell_id} is not running")
            pid = self._process.pid
        try:
            os.killpg(pid, signum)
        except ProcessLookupError as exc:
            self._refresh_process_state()
            raise RuntimeError(f"shell {self.shell_id} is not running") from exc
        response = self.summary()
        response["signal"] = signal_name
        return response
    def close(self) -> dict[str, Any]:
        """Terminate the shell (SIGHUP, escalating to SIGKILL) and join helpers."""
        if self._process.poll() is None:
            try:
                os.killpg(self._process.pid, signal.SIGHUP)
            except ProcessLookupError:
                pass
            try:
                self._process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                # Graceful shutdown failed; force-kill the whole group.
                try:
                    os.killpg(self._process.pid, signal.SIGKILL)
                except ProcessLookupError:
                    pass
                self._process.wait(timeout=5)
        else:
            self._refresh_process_state()
        self._close_master_fd()
        if self._reader is not None:
            self._reader.join(timeout=1)
        if self._waiter is not None:
            self._waiter.join(timeout=1)
        response = self.summary()
        response["closed"] = True
        return response
    def _reader_loop(self) -> None:
        """Drain the PTY master, appending decoded output to the buffer and log."""
        master_fd = self._master_fd
        if master_fd is None:
            return
        while True:
            try:
                chunk = os.read(master_fd, BUFFER_SIZE)
            except OSError:
                # The master fd was closed or the slave side went away.
                break
            if chunk == b"":
                break
            decoded = self._decoder.decode(chunk)
            if decoded == "":
                continue
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)
        # Flush any bytes still buffered in the incremental decoder.
        decoded = self._decoder.decode(b"", final=True)
        if decoded != "":
            with self._lock:
                self._output += decoded
                with self._log_path.open("a", encoding="utf-8") as handle:
                    handle.write(decoded)
    def _waiter_loop(self) -> None:
        """Block until the shell exits, then record its final state."""
        exit_code = self._process.wait()
        with self._lock:
            self.state = "stopped"
            self.exit_code = exit_code
            self.ended_at = time.time()
            self._write_metadata()
    def _refresh_process_state(self) -> None:
        """Opportunistically mark the session stopped if the process has exited."""
        exit_code = self._process.poll()
        if exit_code is None:
            return
        with self._lock:
            if self.state == "running":
                self.state = "stopped"
                self.exit_code = exit_code
                self.ended_at = time.time()
                self._write_metadata()
    def _write_metadata(self) -> None:
        """Persist the current summary as JSON next to the shell log."""
        self._metadata_path.parent.mkdir(parents=True, exist_ok=True)
        self._metadata_path.write_text(json.dumps(self.summary(), indent=2), encoding="utf-8")
    def _close_master_fd(self) -> None:
        """Close and forget the PTY master fd exactly once."""
        with self._lock:
            master_fd = self._master_fd
            self._master_fd = None
        if master_fd is None:
            return
        try:
            os.close(master_fd)
        except OSError:
            pass
def _create_shell(
    *,
    shell_id: str,
    cwd_text: str,
    cols: int,
    rows: int,
    env_overrides: dict[str, str] | None = None,
    redact_values: list[str] | None = None,
) -> GuestShellSession:
    """Create and register a new shell session under a unique *shell_id*."""
    _, cwd_path = _normalize_shell_cwd(cwd_text)
    with _SHELLS_LOCK:
        if shell_id in _SHELLS:
            raise RuntimeError(f"shell {shell_id!r} already exists")
        new_session = GuestShellSession(
            shell_id=shell_id,
            cwd=cwd_path,
            cwd_text=cwd_text,
            cols=cols,
            rows=rows,
            env_overrides=env_overrides,
            redact_values=redact_values,
        )
        _SHELLS[shell_id] = new_session
        return new_session
def _get_shell(shell_id: str) -> GuestShellSession:
    """Return the registered shell session for *shell_id*, or raise."""
    with _SHELLS_LOCK:
        try:
            return _SHELLS[shell_id]
        except KeyError as missing:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from missing
def _remove_shell(shell_id: str) -> GuestShellSession:
    """Unregister and return the shell session for *shell_id*, or raise."""
    with _SHELLS_LOCK:
        try:
            return _SHELLS.pop(shell_id)
        except KeyError as missing:
            raise RuntimeError(f"shell {shell_id!r} does not exist") from missing
def _read_service_metadata(service_name: str) -> dict[str, Any]:
    """Load and validate the persisted JSON record for a service."""
    record_path = _service_metadata_path(service_name)
    if not record_path.exists():
        raise RuntimeError(f"service {service_name!r} does not exist")
    record = json.loads(record_path.read_text(encoding="utf-8"))
    if not isinstance(record, dict):
        raise RuntimeError(f"service record for {service_name!r} is invalid")
    return record
def _write_service_metadata(service_name: str, payload: dict[str, Any]) -> None:
    """Persist a service record as pretty-printed, key-sorted JSON."""
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    _service_metadata_path(service_name).write_text(serialized, encoding="utf-8")
def _service_exit_code(service_name: str) -> int | None:
    """Read a service's recorded exit status, or None when absent or empty."""
    status_path = _service_status_path(service_name)
    if not status_path.exists():
        return None
    recorded = status_path.read_text(encoding="utf-8", errors="ignore").strip()
    return int(recorded) if recorded else None
def _service_pid_running(pid: int | None) -> bool:
if pid is None:
return False
try:
os.kill(pid, 0)
except ProcessLookupError:
return False
except PermissionError:
return True
return True
def _tail_service_text(path: Path, *, tail_lines: int | None) -> tuple[str, bool]:
if not path.exists():
return "", False
text = path.read_text(encoding="utf-8", errors="replace")
if tail_lines is None:
return text, False
lines = text.splitlines(keepends=True)
if len(lines) <= tail_lines:
return text, False
return "".join(lines[-tail_lines:]), True
def _stop_service_process(pid: int) -> tuple[bool, bool]:
    """Stop a service process group: SIGTERM, wait up to 5s, then SIGKILL.

    Returns:
        (signalled, killed): whether any signal was delivered, and whether
        escalation to SIGKILL was required.
    """
    try:
        os.killpg(pid, signal.SIGTERM)
    except ProcessLookupError:
        return False, False

    def _wait_for_exit() -> bool:
        deadline = time.monotonic() + 5
        while time.monotonic() < deadline:
            if not _service_pid_running(pid):
                return True
            time.sleep(0.1)
        return False

    if _wait_for_exit():
        return True, False
    try:
        os.killpg(pid, signal.SIGKILL)
    except ProcessLookupError:
        return True, False
    _wait_for_exit()
    return True, True
def _refresh_service_payload(service_name: str, payload: dict[str, Any]) -> dict[str, Any]:
    """Reconcile a 'running' record with reality; mark it 'exited' if the pid died."""
    if str(payload.get("state", "stopped")) != "running":
        return payload
    recorded_pid = payload.get("pid")
    if _service_pid_running(None if recorded_pid is None else int(recorded_pid)):
        return payload
    updated = dict(payload)
    updated["state"] = "exited"
    updated["ended_at"] = updated.get("ended_at") or time.time()
    updated["exit_code"] = _service_exit_code(service_name)
    _write_service_metadata(service_name, updated)
    return updated
def _run_readiness_probe(
    readiness: dict[str, Any] | None,
    *,
    cwd: Path,
    env: dict[str, str] | None = None,
) -> bool:
    """Evaluate one readiness probe; a missing probe counts as ready.

    Supported probe types: "file" (workspace path exists), "tcp" (socket
    connects), "http" (GET returns 2xx/3xx), and "command" (exit code 0).

    Raises:
        RuntimeError: when the probe type is unknown.
    """
    if readiness is None:
        return True
    probe_type = str(readiness["type"])
    if probe_type == "file":
        _, probe_path = _normalize_destination(str(readiness["path"]))
        return probe_path.exists()
    if probe_type == "tcp":
        host, raw_port = str(readiness["address"]).rsplit(":", 1)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as probe_socket:
            probe_socket.settimeout(1)
            try:
                probe_socket.connect((host, int(raw_port)))
            except OSError:
                return False
            return True
    if probe_type == "http":
        probe_request = urllib.request.Request(str(readiness["url"]), method="GET")
        try:
            with urllib.request.urlopen(probe_request, timeout=2) as response:  # noqa: S310
                return 200 <= int(response.status) < 400
        except (urllib.error.URLError, TimeoutError, ValueError):
            return False
    if probe_type == "command":
        probe_env = os.environ.copy()
        if env is not None:
            probe_env.update(env)
        completed = subprocess.run(  # noqa: S603
            ["/bin/sh", "-lc", str(readiness["command"])],
            cwd=str(cwd),
            text=True,
            capture_output=True,
            env=probe_env,
            timeout=10,
            check=False,
        )
        return completed.returncode == 0
    raise RuntimeError(f"unsupported readiness type: {probe_type}")
def _start_service(
    *,
    service_name: str,
    command: str,
    cwd_text: str,
    readiness: dict[str, Any] | None,
    ready_timeout_seconds: int,
    ready_interval_ms: int,
    env: dict[str, str] | None = None,
) -> dict[str, Any]:
    """Start a background service and wait for its readiness probe.

    A small /bin/sh runner script redirects the service's stdout/stderr to
    per-service log files and records its exit status. This call blocks
    until the readiness probe passes, the process exits early, or the
    timeout elapses; the persisted record is updated and returned in every
    case (state "running", or "failed" with a stop_reason).

    Raises:
        RuntimeError: for invalid names/paths or when the service is
            already running.
    """
    normalized_service_name = _normalize_service_name(service_name)
    normalized_cwd, cwd_path = _normalize_shell_cwd(cwd_text)
    existing = None
    metadata_path = _service_metadata_path(normalized_service_name)
    if metadata_path.exists():
        # Refresh first so a stale "running" record of a dead pid does not
        # block a restart.
        existing = _refresh_service_payload(
            normalized_service_name,
            _read_service_metadata(normalized_service_name),
        )
    if existing is not None and str(existing.get("state", "stopped")) == "running":
        raise RuntimeError(f"service {normalized_service_name!r} is already running")
    SERVICE_ROOT.mkdir(parents=True, exist_ok=True)
    stdout_path = _service_stdout_path(normalized_service_name)
    stderr_path = _service_stderr_path(normalized_service_name)
    status_path = _service_status_path(normalized_service_name)
    runner_path = _service_runner_path(normalized_service_name)
    stdout_path.write_text("", encoding="utf-8")
    stderr_path.write_text("", encoding="utf-8")
    status_path.unlink(missing_ok=True)
    # The runner script captures output and persists the exit status so the
    # agent can report it even after this process stops tracking the child.
    runner_path.write_text(
        "\n".join(
            [
                "#!/bin/sh",
                "set +e",
                f"cd {shlex.quote(str(cwd_path))}",
                (
                    f"/bin/sh -lc {shlex.quote(command)}"
                    f" >> {shlex.quote(str(stdout_path))}"
                    f" 2>> {shlex.quote(str(stderr_path))}"
                ),
                "status=$?",
                f"printf '%s' \"$status\" > {shlex.quote(str(status_path))}",
                "exit \"$status\"",
            ]
        )
        + "\n",
        encoding="utf-8",
    )
    runner_path.chmod(0o700)
    service_env = os.environ.copy()
    if env is not None:
        service_env.update(env)
    process = subprocess.Popen(  # noqa: S603
        [str(runner_path)],
        cwd=str(cwd_path),
        env=service_env,
        text=True,
        # New session so _stop_service_process can signal the whole group.
        start_new_session=True,
    )
    payload: dict[str, Any] = {
        "service_name": normalized_service_name,
        "command": command,
        "cwd": normalized_cwd,
        "state": "running",
        "started_at": time.time(),
        "readiness": readiness,
        "ready_at": None,
        "ended_at": None,
        "exit_code": None,
        "pid": process.pid,
        "stop_reason": None,
    }
    _write_service_metadata(normalized_service_name, payload)
    deadline = time.monotonic() + ready_timeout_seconds
    while True:
        payload = _refresh_service_payload(normalized_service_name, payload)
        if str(payload.get("state", "stopped")) != "running":
            # The process died before ever becoming ready.
            payload["state"] = "failed"
            payload["stop_reason"] = "process_exited_before_ready"
            payload["ended_at"] = payload.get("ended_at") or time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        if _run_readiness_probe(readiness, cwd=cwd_path, env=env):
            payload["ready_at"] = time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        if time.monotonic() >= deadline:
            # Gave up waiting: stop the process group, record the failure.
            _stop_service_process(process.pid)
            payload = _refresh_service_payload(normalized_service_name, payload)
            payload["state"] = "failed"
            payload["stop_reason"] = "readiness_timeout"
            payload["ended_at"] = payload.get("ended_at") or time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        time.sleep(max(ready_interval_ms, 1) / 1000)
def _status_service(service_name: str) -> dict[str, Any]:
    """Load a service record and refresh its liveness state before returning it."""
    normalized = _normalize_service_name(service_name)
    record = _read_service_metadata(normalized)
    return _refresh_service_payload(normalized, record)
def _logs_service(service_name: str, *, tail_lines: int | None) -> dict[str, Any]:
    """Return service status plus its captured stdout/stderr, optionally tailed."""
    normalized = _normalize_service_name(service_name)
    status = _status_service(normalized)
    stdout_text, stdout_cut = _tail_service_text(
        _service_stdout_path(normalized),
        tail_lines=tail_lines,
    )
    stderr_text, stderr_cut = _tail_service_text(
        _service_stderr_path(normalized),
        tail_lines=tail_lines,
    )
    status["stdout"] = stdout_text
    status["stderr"] = stderr_text
    status["tail_lines"] = tail_lines
    status["truncated"] = stdout_cut or stderr_cut
    return status
def _stop_service(service_name: str) -> dict[str, Any]:
    """Stop a managed service and persist the final record.

    SIGTERM is delivered to the service's process group, escalating to
    SIGKILL after a grace period. A record without a pid is returned as-is.

    Raises:
        RuntimeError: when the service record does not exist or is invalid.
    """
    normalized_service_name = _normalize_service_name(service_name)
    payload = _status_service(normalized_service_name)
    pid = payload.get("pid")
    if pid is None:
        return payload
    # Bug fix: `killed` was previously bound only inside the running branch,
    # raising NameError for a pid-bearing record that had already exited.
    killed = False
    if str(payload.get("state", "stopped")) == "running":
        _, killed = _stop_service_process(int(pid))
        payload = _status_service(normalized_service_name)
    payload["state"] = "stopped"
    payload["stop_reason"] = "sigkill" if killed else "sigterm"
    payload["ended_at"] = payload.get("ended_at") or time.time()
    _write_service_metadata(normalized_service_name, payload)
    return payload
def _required_shell_id(request: dict[str, Any]) -> str:
    """Return the request's non-empty, stripped shell_id or raise RuntimeError."""
    shell_id = str(request.get("shell_id", "")).strip()
    if shell_id == "":
        raise RuntimeError("shell_id is required")
    return shell_id


def _normalized_env(request: dict[str, Any], label: str) -> dict[str, str] | None:
    """Return the request's optional env mapping with normalized variable names.

    Returns None when no env was supplied. Raises RuntimeError when env is
    present but not a JSON object; *label* names the request family
    ("shell", "service", or "exec") in the error message.
    """
    env_payload = request.get("env")
    if env_payload is None:
        return None
    if not isinstance(env_payload, dict):
        raise RuntimeError(f"{label} env must be a JSON object")
    return {
        _normalize_secret_name(str(key)): str(value)
        for key, value in env_payload.items()
    }


def _read_archive_payload(request: dict[str, Any], conn: socket.socket) -> bytes:
    """Read the binary archive body that follows the JSON request header."""
    archive_size = int(request.get("archive_size", 0))
    if archive_size < 0:
        raise RuntimeError("archive_size must not be negative")
    return _read_exact(conn, archive_size)


def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
    """Route one decoded control request to the matching guest operation.

    Archive-bearing actions read their binary payload from *conn* after the
    JSON header. Any unrecognized action falls through to plain command
    execution. Raises RuntimeError on malformed requests; the caller
    serializes that into an error response.
    """
    action = str(request.get("action", "exec"))
    if action == "extract_archive":
        payload = _read_archive_payload(request, conn)
        return _extract_archive(payload, str(request.get("destination", "/workspace")))
    if action == "install_secrets":
        return _install_secrets_archive(_read_archive_payload(request, conn))
    if action == "list_workspace":
        return _list_workspace(
            str(request.get("path", "/workspace")),
            recursive=bool(request.get("recursive", False)),
        )
    if action == "read_workspace_file":
        return _read_workspace_file(
            str(request.get("path", "/workspace")),
            max_bytes=int(request.get("max_bytes", WORKSPACE_FILE_MAX_BYTES)),
        )
    if action == "write_workspace_file":
        return _write_workspace_file(
            str(request.get("path", "/workspace")),
            text=str(request.get("text", "")),
        )
    if action == "delete_workspace_path":
        return _delete_workspace_path(str(request.get("path", "/workspace")))
    if action == "open_shell":
        shell_id = _required_shell_id(request)
        cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
        env_overrides = _normalized_env(request, "shell")
        redact_values_payload = request.get("redact_values")
        redact_values: list[str] | None = None
        if redact_values_payload is not None:
            if not isinstance(redact_values_payload, list):
                raise RuntimeError("redact_values must be a list")
            redact_values = [str(item) for item in redact_values_payload]
        session = _create_shell(
            shell_id=shell_id,
            cwd_text=cwd_text,
            cols=int(request.get("cols", 120)),
            rows=int(request.get("rows", 30)),
            env_overrides=env_overrides,
            redact_values=redact_values,
        )
        return session.summary()
    if action == "read_shell":
        return _get_shell(_required_shell_id(request)).read(
            cursor=int(request.get("cursor", 0)),
            max_chars=int(request.get("max_chars", 65536)),
        )
    if action == "write_shell":
        return _get_shell(_required_shell_id(request)).write(
            str(request.get("input", "")),
            append_newline=bool(request.get("append_newline", True)),
        )
    if action == "signal_shell":
        shell_id = _required_shell_id(request)
        signal_name = str(request.get("signal", "INT")).upper()
        if signal_name not in SHELL_SIGNAL_NAMES:
            raise RuntimeError(
                f"signal must be one of: {', '.join(SHELL_SIGNAL_NAMES)}"
            )
        return _get_shell(shell_id).send_signal(signal_name)
    if action == "close_shell":
        return _remove_shell(_required_shell_id(request)).close()
    if action == "start_service":
        readiness = request.get("readiness")
        # Validate env before readiness timing fields, matching historical
        # error precedence for malformed requests.
        env = _normalized_env(request, "service")
        return _start_service(
            service_name=str(request.get("service_name", "")).strip(),
            command=str(request.get("command", "")),
            cwd_text=str(request.get("cwd", "/workspace")),
            readiness=dict(readiness) if isinstance(readiness, dict) else None,
            ready_timeout_seconds=int(request.get("ready_timeout_seconds", 30)),
            ready_interval_ms=int(request.get("ready_interval_ms", 500)),
            env=env,
        )
    if action == "status_service":
        return _status_service(str(request.get("service_name", "")).strip())
    if action == "logs_service":
        tail_lines = request.get("tail_lines")
        return _logs_service(
            str(request.get("service_name", "")).strip(),
            tail_lines=None if tail_lines is None else int(tail_lines),
        )
    if action == "stop_service":
        return _stop_service(str(request.get("service_name", "")).strip())
    # Default action: bounded command execution.
    command = str(request.get("command", ""))
    timeout_seconds = int(request.get("timeout_seconds", 30))
    return _run_command(command, timeout_seconds, env=_normalized_env(request, "exec"))
def main() -> None:
    """Bind the guest vsock control port and serve requests forever.

    Each accepted connection carries one JSON request. ``export_archive``
    streams a header line followed by raw archive bytes; every other
    action (and any failure) answers with a single JSON line.
    """
    for runtime_dir in (SHELL_ROOT, SERVICE_ROOT, SECRET_ROOT):
        runtime_dir.mkdir(parents=True, exist_ok=True)
    vsock_family = getattr(socket, "AF_VSOCK", None)
    if vsock_family is None:
        raise SystemExit("AF_VSOCK is unavailable")
    with socket.socket(vsock_family, socket.SOCK_STREAM) as server:
        server.bind((socket.VMADDR_CID_ANY, PORT))
        server.listen(1)
        while True:
            conn, _peer = server.accept()
            with conn:
                try:
                    request = _read_request(conn)
                    if str(request.get("action", "")) == "export_archive":
                        export = _prepare_export_archive(str(request.get("path", "/workspace")))
                        try:
                            header = {
                                field: export[field]
                                for field in (
                                    "workspace_path",
                                    "artifact_type",
                                    "archive_size",
                                    "entry_count",
                                    "bytes_written",
                                )
                            }
                            conn.sendall((json.dumps(header) + "\n").encode("utf-8"))
                            with Path(str(export["archive_path"])).open("rb") as handle:
                                # Stream the archive in fixed-size chunks.
                                while (chunk := handle.read(BUFFER_SIZE)) != b"":
                                    conn.sendall(chunk)
                        finally:
                            # The staged archive is single-use; always clean it up.
                            Path(str(export["archive_path"])).unlink(missing_ok=True)
                        continue
                    response = _dispatch(request, conn)
                except Exception as exc:  # noqa: BLE001
                    response = {"error": str(exc)}
                conn.sendall((json.dumps(response) + "\n").encode("utf-8"))
# Run the agent's accept loop when executed as a script (not on import).
if __name__ == "__main__":
    main()