Add workspace service lifecycle with typed readiness
Make persistent workspaces capable of running long-lived background processes instead of forcing everything through one-shot exec calls. Add workspace service start/list/status/logs/stop across the CLI, Python SDK, and MCP server, with multiple named services per workspace, typed readiness probes (file, tcp, http, and command), and aggregate service counts on workspace status. Keep service state and logs outside /workspace so diff and export semantics stay workspace-scoped, and extend the guest agent plus backends to persist service records and logs across separate calls. Update the 2.7.0 docs, examples, changelog, and roadmap milestone to reflect the shipped surface. Validation: uv lock; UV_CACHE_DIR=.uv-cache make check; UV_CACHE_DIR=.uv-cache make dist-check; real guest-backed Firecracker smoke for workspace create, two service starts, list/status/logs, diff unaffected, stop, and delete.
This commit is contained in:
parent
84a7e18d4d
commit
f504f0a331
28 changed files with 4098 additions and 124 deletions
|
|
@ -8,7 +8,8 @@ import fcntl
|
|||
import io
|
||||
import json
|
||||
import os
|
||||
import pty
|
||||
import re
|
||||
import shlex
|
||||
import signal
|
||||
import socket
|
||||
import struct
|
||||
|
|
@ -18,6 +19,8 @@ import tempfile
|
|||
import termios
|
||||
import threading
|
||||
import time
|
||||
import urllib.error
|
||||
import urllib.request
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -25,6 +28,8 @@ PORT = 5005
|
|||
BUFFER_SIZE = 65536
|
||||
WORKSPACE_ROOT = PurePosixPath("/workspace")
|
||||
SHELL_ROOT = Path("/run/pyro-shells")
|
||||
SERVICE_ROOT = Path("/run/pyro-services")
|
||||
SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
|
||||
SHELL_SIGNAL_MAP = {
|
||||
"HUP": signal.SIGHUP,
|
||||
"INT": signal.SIGINT,
|
||||
|
|
@ -105,6 +110,35 @@ def _normalize_shell_cwd(cwd: str) -> tuple[str, Path]:
|
|||
return str(normalized), host_path
|
||||
|
||||
|
||||
def _normalize_service_name(service_name: str) -> str:
    """Return ``service_name`` with surrounding whitespace removed, after validation.

    Raises:
        RuntimeError: If the trimmed name is empty, or if it does not fully
            match ``SERVICE_NAME_RE``.
    """
    candidate = service_name.strip()
    if not candidate:
        raise RuntimeError("service_name is required")
    if not SERVICE_NAME_RE.fullmatch(candidate):
        raise RuntimeError("service_name is invalid")
    return candidate
|
||||
|
||||
|
||||
def _service_stdout_path(service_name: str) -> Path:
    """Return the ``<service_name>.stdout`` path located under ``SERVICE_ROOT``."""
    return SERVICE_ROOT.joinpath(f"{service_name}.stdout")
|
||||
|
||||
|
||||
def _service_stderr_path(service_name: str) -> Path:
    """Return the ``<service_name>.stderr`` path located under ``SERVICE_ROOT``."""
    return SERVICE_ROOT.joinpath(f"{service_name}.stderr")
|
||||
|
||||
|
||||
def _service_status_path(service_name: str) -> Path:
    """Return the ``<service_name>.status`` path located under ``SERVICE_ROOT``."""
    return SERVICE_ROOT.joinpath(f"{service_name}.status")
|
||||
|
||||
|
||||
def _service_runner_path(service_name: str) -> Path:
    """Return the ``<service_name>.runner.sh`` path located under ``SERVICE_ROOT``."""
    return SERVICE_ROOT.joinpath(f"{service_name}.runner.sh")
|
||||
|
||||
|
||||
def _service_metadata_path(service_name: str) -> Path:
    """Return the ``<service_name>.json`` path located under ``SERVICE_ROOT``."""
    return SERVICE_ROOT.joinpath(f"{service_name}.json")
|
||||
|
||||
|
||||
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
|
||||
target = link_target.strip()
|
||||
if target == "":
|
||||
|
|
@ -286,7 +320,7 @@ class GuestShellSession:
|
|||
self._log_path = SHELL_ROOT / f"{shell_id}.log"
|
||||
self._master_fd: int | None = None
|
||||
|
||||
master_fd, slave_fd = pty.openpty()
|
||||
master_fd, slave_fd = os.openpty()
|
||||
try:
|
||||
_set_pty_size(slave_fd, rows, cols)
|
||||
env = os.environ.copy()
|
||||
|
|
@ -512,6 +546,268 @@ def _remove_shell(shell_id: str) -> GuestShellSession:
|
|||
raise RuntimeError(f"shell {shell_id!r} does not exist") from exc
|
||||
|
||||
|
||||
def _read_service_metadata(service_name: str) -> dict[str, Any]:
    """Load and return the JSON metadata record stored for ``service_name``.

    Raises:
        RuntimeError: If the metadata file does not exist, or if the decoded
            JSON document is not an object.
    """
    record_path = _service_metadata_path(service_name)
    if not record_path.exists():
        raise RuntimeError(f"service {service_name!r} does not exist")
    raw_record = record_path.read_text(encoding="utf-8")
    record = json.loads(raw_record)
    if isinstance(record, dict):
        return record
    raise RuntimeError(f"service record for {service_name!r} is invalid")
|
||||
|
||||
|
||||
def _write_service_metadata(service_name: str, payload: dict[str, Any]) -> None:
    """Write ``payload`` to the service's metadata file as indented, key-sorted JSON."""
    serialized = json.dumps(payload, indent=2, sort_keys=True)
    _service_metadata_path(service_name).write_text(serialized, encoding="utf-8")
|
||||
|
||||
|
||||
def _service_exit_code(service_name: str) -> int | None:
    """Return the integer stored in the service's status file.

    Returns ``None`` when the status file is missing or contains only
    whitespace.
    """
    exit_file = _service_status_path(service_name)
    if exit_file.exists():
        recorded = exit_file.read_text(encoding="utf-8", errors="ignore").strip()
        if recorded:
            return int(recorded)
    return None
|
||||
|
||||
|
||||
def _service_pid_running(pid: int | None) -> bool:
|
||||
if pid is None:
|
||||
return False
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except ProcessLookupError:
|
||||
return False
|
||||
except PermissionError:
|
||||
return True
|
||||
return True
|
||||
|
||||
|
||||
def _tail_service_text(path: Path, *, tail_lines: int | None) -> tuple[str, bool]:
|
||||
if not path.exists():
|
||||
return "", False
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
if tail_lines is None:
|
||||
return text, False
|
||||
lines = text.splitlines(keepends=True)
|
||||
if len(lines) <= tail_lines:
|
||||
return text, False
|
||||
return "".join(lines[-tail_lines:]), True
|
||||
|
||||
|
||||
def _stop_service_process(pid: int) -> tuple[bool, bool]:
    """Terminate the process group ``pid``, escalating from SIGTERM to SIGKILL.

    Sends SIGTERM to the group and polls ``_service_pid_running`` every 0.1s
    for up to 5 seconds. If the process is still reported as running, sends
    SIGKILL and polls for up to another 5 seconds.

    Returns:
        ``(False, False)`` when the group did not exist for the SIGTERM.
        ``(True, False)`` when the process went away after SIGTERM, or the
        group had vanished by the time SIGKILL was attempted.
        ``(True, True)`` when SIGKILL was delivered.
    """

    def _gone_within(seconds: float) -> bool:
        # Poll until the pid is no longer reported as running or time is up.
        limit = time.monotonic() + seconds
        while time.monotonic() < limit:
            if not _service_pid_running(pid):
                return True
            time.sleep(0.1)
        return False

    try:
        os.killpg(pid, signal.SIGTERM)
    except ProcessLookupError:
        return False, False
    if _gone_within(5):
        return True, False

    try:
        os.killpg(pid, signal.SIGKILL)
    except ProcessLookupError:
        return True, False
    # The result is the same whether or not the process is observed to exit
    # within the grace period; the wait only gives SIGKILL time to land.
    _gone_within(5)
    return True, True
|
||||
|
||||
|
||||
def _refresh_service_payload(service_name: str, payload: dict[str, Any]) -> dict[str, Any]:
    """Reconcile a service record marked "running" with the process's actual state.

    Records whose state is not "running", or whose pid is still reported as
    running, are returned unchanged. Otherwise a copy is made with state
    "exited", ``ended_at`` filled in if it was unset, and ``exit_code`` read
    from the status file; the copy is written to the metadata file and
    returned.
    """
    recorded_state = str(payload.get("state", "stopped"))
    if recorded_state != "running":
        return payload
    raw_pid = payload.get("pid")
    if _service_pid_running(int(raw_pid) if raw_pid is not None else None):
        return payload
    updated = {
        **payload,
        "state": "exited",
        "ended_at": payload.get("ended_at") or time.time(),
        "exit_code": _service_exit_code(service_name),
    }
    _write_service_metadata(service_name, updated)
    return updated
|
||||
|
||||
|
||||
def _run_readiness_probe(readiness: dict[str, Any] | None, *, cwd: Path) -> bool:
|
||||
if readiness is None:
|
||||
return True
|
||||
readiness_type = str(readiness["type"])
|
||||
if readiness_type == "file":
|
||||
_, ready_path = _normalize_destination(str(readiness["path"]))
|
||||
return ready_path.exists()
|
||||
if readiness_type == "tcp":
|
||||
host, raw_port = str(readiness["address"]).rsplit(":", 1)
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
|
||||
sock.settimeout(1)
|
||||
try:
|
||||
sock.connect((host, int(raw_port)))
|
||||
except OSError:
|
||||
return False
|
||||
return True
|
||||
if readiness_type == "http":
|
||||
request = urllib.request.Request(str(readiness["url"]), method="GET")
|
||||
try:
|
||||
with urllib.request.urlopen(request, timeout=2) as response: # noqa: S310
|
||||
return 200 <= int(response.status) < 400
|
||||
except (urllib.error.URLError, TimeoutError, ValueError):
|
||||
return False
|
||||
if readiness_type == "command":
|
||||
proc = subprocess.run( # noqa: S603
|
||||
["/bin/sh", "-lc", str(readiness["command"])],
|
||||
cwd=str(cwd),
|
||||
text=True,
|
||||
capture_output=True,
|
||||
timeout=10,
|
||||
check=False,
|
||||
)
|
||||
return proc.returncode == 0
|
||||
raise RuntimeError(f"unsupported readiness type: {readiness_type}")
|
||||
|
||||
|
||||
def _start_service(
    *,
    service_name: str,
    command: str,
    cwd_text: str,
    readiness: dict[str, Any] | None,
    ready_timeout_seconds: int,
    ready_interval_ms: int,
) -> dict[str, Any]:
    """Launch a named background service and wait until it is ready.

    The command is wrapped in a generated runner script that changes into the
    working directory, runs the command through ``/bin/sh -lc`` with stdout
    and stderr appended to the service's log files, and writes the command's
    exit status to the status file. The runner is started in a new session
    and a metadata record is written; the readiness probe is then polled
    until it succeeds, the process exits, or the timeout elapses.

    Args:
        service_name: Name of the service; validated and trimmed.
        command: Shell command line to run.
        cwd_text: Working directory, normalized via ``_normalize_shell_cwd``.
        readiness: Readiness probe description, or ``None`` for no probe.
        ready_timeout_seconds: How long to keep probing before giving up.
        ready_interval_ms: Delay between probes (at least 1 ms is used).

    Returns:
        The service record. ``state`` is "running" with ``ready_at`` set on
        success, or "failed" with ``stop_reason`` of
        "process_exited_before_ready" or "readiness_timeout".

    Raises:
        RuntimeError: If the name is invalid or the service is already running.
        Exception: Any error raised by the readiness probe itself is re-raised
            after the process has been stopped and the record marked "failed"
            with ``stop_reason`` "readiness_probe_error".
    """
    normalized_service_name = _normalize_service_name(service_name)
    normalized_cwd, cwd_path = _normalize_shell_cwd(cwd_text)

    # Refuse to start a second copy while a previous run is still alive.
    existing = None
    if _service_metadata_path(normalized_service_name).exists():
        existing = _refresh_service_payload(
            normalized_service_name,
            _read_service_metadata(normalized_service_name),
        )
    if existing is not None and str(existing.get("state", "stopped")) == "running":
        raise RuntimeError(f"service {normalized_service_name!r} is already running")

    SERVICE_ROOT.mkdir(parents=True, exist_ok=True)
    stdout_path = _service_stdout_path(normalized_service_name)
    stderr_path = _service_stderr_path(normalized_service_name)
    status_path = _service_status_path(normalized_service_name)
    runner_path = _service_runner_path(normalized_service_name)

    # Reset the artifacts of any previous run so logs and exit status belong
    # to this run only.
    stdout_path.write_text("", encoding="utf-8")
    stderr_path.write_text("", encoding="utf-8")
    status_path.unlink(missing_ok=True)

    runner_lines = [
        "#!/bin/sh",
        "set +e",
        f"cd {shlex.quote(str(cwd_path))}",
        (
            f"/bin/sh -lc {shlex.quote(command)}"
            f" >> {shlex.quote(str(stdout_path))}"
            f" 2>> {shlex.quote(str(stderr_path))}"
        ),
        "status=$?",
        f"printf '%s' \"$status\" > {shlex.quote(str(status_path))}",
        "exit \"$status\"",
    ]
    runner_path.write_text("\n".join(runner_lines) + "\n", encoding="utf-8")
    runner_path.chmod(0o700)

    # start_new_session makes the runner a process-group leader, so the pid
    # can later be passed to os.killpg to stop the whole service.
    process = subprocess.Popen(  # noqa: S603
        [str(runner_path)],
        cwd=str(cwd_path),
        text=True,
        start_new_session=True,
    )
    payload: dict[str, Any] = {
        "service_name": normalized_service_name,
        "command": command,
        "cwd": normalized_cwd,
        "state": "running",
        "started_at": time.time(),
        "readiness": readiness,
        "ready_at": None,
        "ended_at": None,
        "exit_code": None,
        "pid": process.pid,
        "stop_reason": None,
    }
    _write_service_metadata(normalized_service_name, payload)

    deadline = time.monotonic() + ready_timeout_seconds
    while True:
        payload = _refresh_service_payload(normalized_service_name, payload)
        if str(payload.get("state", "stopped")) != "running":
            payload["state"] = "failed"
            payload["stop_reason"] = "process_exited_before_ready"
            payload["ended_at"] = payload.get("ended_at") or time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        try:
            ready = _run_readiness_probe(readiness, cwd=cwd_path)
        except Exception:
            # The probe itself is broken (bad configuration, unexpected
            # error). Do not leave the just-spawned process running behind a
            # record that still says "running": stop it, record the failure,
            # then surface the original error to the caller.
            _stop_service_process(process.pid)
            payload = _refresh_service_payload(normalized_service_name, payload)
            payload["state"] = "failed"
            payload["stop_reason"] = "readiness_probe_error"
            payload["ended_at"] = payload.get("ended_at") or time.time()
            _write_service_metadata(normalized_service_name, payload)
            raise
        if ready:
            payload["ready_at"] = time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        if time.monotonic() >= deadline:
            _stop_service_process(process.pid)
            payload = _refresh_service_payload(normalized_service_name, payload)
            payload["state"] = "failed"
            payload["stop_reason"] = "readiness_timeout"
            payload["ended_at"] = payload.get("ended_at") or time.time()
            _write_service_metadata(normalized_service_name, payload)
            return payload
        time.sleep(max(ready_interval_ms, 1) / 1000)
|
||||
|
||||
|
||||
def _status_service(service_name: str) -> dict[str, Any]:
    """Return the stored record for ``service_name`` after refreshing its state.

    The name is validated first; the record is then read from the metadata
    file and passed through ``_refresh_service_payload``.
    """
    name = _normalize_service_name(service_name)
    stored = _read_service_metadata(name)
    return _refresh_service_payload(name, stored)
|
||||
|
||||
|
||||
def _logs_service(service_name: str, *, tail_lines: int | None) -> dict[str, Any]:
    """Return the service's current status record together with its captured output.

    The record gains ``stdout`` and ``stderr`` (each optionally limited to
    the last ``tail_lines`` lines), the requested ``tail_lines`` value, and
    ``truncated``, which is true when either stream was cut.
    """
    name = _normalize_service_name(service_name)
    report = _status_service(name)
    out_text, out_cut = _tail_service_text(
        _service_stdout_path(name),
        tail_lines=tail_lines,
    )
    err_text, err_cut = _tail_service_text(
        _service_stderr_path(name),
        tail_lines=tail_lines,
    )
    report["stdout"] = out_text
    report["stderr"] = err_text
    report["tail_lines"] = tail_lines
    report["truncated"] = out_cut or err_cut
    return report
|
||||
|
||||
|
||||
def _stop_service(service_name: str) -> dict[str, Any]:
    """Stop a running service and return its updated record.

    A record with no pid, or whose refreshed state is not "running", is
    returned as-is. Otherwise the process group is stopped, the record is
    re-read, marked "stopped" with ``stop_reason`` "sigkill" or "sigterm"
    depending on whether SIGKILL was needed, and written to the metadata
    file.
    """
    name = _normalize_service_name(service_name)
    record = _status_service(name)
    raw_pid = record.get("pid")
    if raw_pid is None or str(record.get("state", "stopped")) != "running":
        return record
    _, killed = _stop_service_process(int(raw_pid))
    record = _status_service(name)
    record["state"] = "stopped"
    record["stop_reason"] = "sigkill" if killed else "sigterm"
    record["ended_at"] = record.get("ended_at") or time.time()
    _write_service_metadata(name, record)
    return record
|
||||
|
||||
|
||||
def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
|
||||
action = str(request.get("action", "exec"))
|
||||
if action == "extract_archive":
|
||||
|
|
@ -564,6 +860,31 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
|
|||
if shell_id == "":
|
||||
raise RuntimeError("shell_id is required")
|
||||
return _remove_shell(shell_id).close()
|
||||
if action == "start_service":
|
||||
service_name = str(request.get("service_name", "")).strip()
|
||||
command = str(request.get("command", ""))
|
||||
cwd_text = str(request.get("cwd", "/workspace"))
|
||||
readiness = request.get("readiness")
|
||||
readiness_payload = dict(readiness) if isinstance(readiness, dict) else None
|
||||
return _start_service(
|
||||
service_name=service_name,
|
||||
command=command,
|
||||
cwd_text=cwd_text,
|
||||
readiness=readiness_payload,
|
||||
ready_timeout_seconds=int(request.get("ready_timeout_seconds", 30)),
|
||||
ready_interval_ms=int(request.get("ready_interval_ms", 500)),
|
||||
)
|
||||
if action == "status_service":
|
||||
service_name = str(request.get("service_name", "")).strip()
|
||||
return _status_service(service_name)
|
||||
if action == "logs_service":
|
||||
service_name = str(request.get("service_name", "")).strip()
|
||||
tail_lines = request.get("tail_lines")
|
||||
normalized_tail_lines = None if tail_lines is None else int(tail_lines)
|
||||
return _logs_service(service_name, tail_lines=normalized_tail_lines)
|
||||
if action == "stop_service":
|
||||
service_name = str(request.get("service_name", "")).strip()
|
||||
return _stop_service(service_name)
|
||||
command = str(request.get("command", ""))
|
||||
timeout_seconds = int(request.get("timeout_seconds", 30))
|
||||
return _run_command(command, timeout_seconds)
|
||||
|
|
@ -571,6 +892,7 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
|
|||
|
||||
def main() -> None:
|
||||
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
SERVICE_ROOT.mkdir(parents=True, exist_ok=True)
|
||||
family = getattr(socket, "AF_VSOCK", None)
|
||||
if family is None:
|
||||
raise SystemExit("AF_VSOCK is unavailable")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue