Add guest-only workspace secrets

Add explicit workspace secrets across the CLI, SDK, and MCP, with create-time secret definitions and per-call secret-to-env mapping for exec, shell open, and service start. Persist only safe secret metadata in workspace records, materialize secret files under /run/pyro-secrets, and redact secret values from exec output, shell reads, service logs, and surfaced errors.

Fix the remaining real-guest shell gap by shipping bundled guest init alongside the guest agent and patching both into guest-backed workspace rootfs images before boot. The new init mounts devpts so PTY shells work on Firecracker guests, while reset continues to recreate the sandbox and re-materialize secrets from stored task-local secret material.

Validation: uv lock; UV_CACHE_DIR=.uv-cache make check; UV_CACHE_DIR=.uv-cache make dist-check; and a real guest-backed Firecracker smoke covering workspace create with secrets, secret-backed exec, shell, service, reset, and delete.
This commit is contained in:
Thales Maciel 2026-03-12 15:43:34 -03:00
parent 18b8fd2a7d
commit fc72fcd3a1
32 changed files with 1980 additions and 181 deletions

View file

@ -10,6 +10,7 @@ import json
import os
import re
import shlex
import shutil
import signal
import socket
import struct
@ -29,6 +30,7 @@ BUFFER_SIZE = 65536
WORKSPACE_ROOT = PurePosixPath("/workspace")
SHELL_ROOT = Path("/run/pyro-shells")
SERVICE_ROOT = Path("/run/pyro-services")
SECRET_ROOT = Path("/run/pyro-secrets")
SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
SHELL_SIGNAL_MAP = {
"HUP": signal.SIGHUP,
@ -42,6 +44,17 @@ _SHELLS: dict[str, "GuestShellSession"] = {}
_SHELLS_LOCK = threading.Lock()
def _redact_text(text: str, redact_values: list[str]) -> str:
redacted = text
for secret_value in sorted(
{item for item in redact_values if item != ""},
key=len,
reverse=True,
):
redacted = redacted.replace(secret_value, "[REDACTED]")
return redacted
def _read_request(conn: socket.socket) -> dict[str, Any]:
chunks: list[bytes] = []
while True:
@ -139,6 +152,15 @@ def _service_metadata_path(service_name: str) -> Path:
return SERVICE_ROOT / f"{service_name}.json"
def _normalize_secret_name(secret_name: str) -> str:
normalized = secret_name.strip()
if normalized == "":
raise RuntimeError("secret name is required")
if re.fullmatch(r"^[A-Za-z_][A-Za-z0-9_]{0,63}$", normalized) is None:
raise RuntimeError("secret name is invalid")
return normalized
def _validate_symlink_target(member_path: PurePosixPath, link_target: str) -> None:
target = link_target.strip()
if target == "":
@ -215,6 +237,49 @@ def _extract_archive(payload: bytes, destination: str) -> dict[str, Any]:
}
def _install_secrets_archive(payload: bytes) -> dict[str, Any]:
SECRET_ROOT.mkdir(parents=True, exist_ok=True)
for existing in SECRET_ROOT.iterdir():
if existing.is_dir() and not existing.is_symlink():
shutil.rmtree(existing, ignore_errors=True)
else:
existing.unlink(missing_ok=True)
bytes_written = 0
entry_count = 0
with tarfile.open(fileobj=io.BytesIO(payload), mode="r:*") as archive:
for member in archive.getmembers():
member_name = _normalize_member_name(member.name)
target_path = SECRET_ROOT.joinpath(*member_name.parts)
entry_count += 1
if member.isdir():
target_path.mkdir(parents=True, exist_ok=True)
target_path.chmod(0o700)
continue
if member.isfile():
target_path.parent.mkdir(parents=True, exist_ok=True)
target_path.parent.chmod(0o700)
source = archive.extractfile(member)
if source is None:
raise RuntimeError(f"failed to read secret archive member: {member.name}")
with target_path.open("wb") as handle:
while True:
chunk = source.read(BUFFER_SIZE)
if chunk == b"":
break
handle.write(chunk)
target_path.chmod(0o600)
bytes_written += member.size
continue
if member.issym() or member.islnk():
raise RuntimeError(f"secret archive may not contain links: {member.name}")
raise RuntimeError(f"unsupported secret archive member type: {member.name}")
return {
"destination": str(SECRET_ROOT),
"entry_count": entry_count,
"bytes_written": bytes_written,
}
def _inspect_archive(archive_path: Path) -> tuple[int, int]:
entry_count = 0
bytes_written = 0
@ -263,13 +328,22 @@ def _prepare_export_archive(path: str) -> dict[str, Any]:
raise
def _run_command(command: str, timeout_seconds: int) -> dict[str, Any]:
def _run_command(
command: str,
timeout_seconds: int,
*,
env: dict[str, str] | None = None,
) -> dict[str, Any]:
started = time.monotonic()
command_env = os.environ.copy()
if env is not None:
command_env.update(env)
try:
proc = subprocess.run(
["/bin/sh", "-lc", command],
text=True,
capture_output=True,
env=command_env,
timeout=timeout_seconds,
check=False,
)
@ -293,6 +367,16 @@ def _set_pty_size(fd: int, rows: int, cols: int) -> None:
fcntl.ioctl(fd, termios.TIOCSWINSZ, winsize)
def _shell_argv(*, interactive: bool) -> list[str]:
shell_program = shutil.which("bash") or "/bin/sh"
argv = [shell_program]
if shell_program.endswith("bash"):
argv.extend(["--noprofile", "--norc"])
if interactive:
argv.append("-i")
return argv
class GuestShellSession:
"""In-guest PTY-backed interactive shell session."""
@ -304,6 +388,8 @@ class GuestShellSession:
cwd_text: str,
cols: int,
rows: int,
env_overrides: dict[str, str] | None = None,
redact_values: list[str] | None = None,
) -> None:
self.shell_id = shell_id
self.cwd = cwd_text
@ -316,6 +402,7 @@ class GuestShellSession:
self._lock = threading.RLock()
self._output = ""
self._decoder = codecs.getincrementaldecoder("utf-8")("replace")
self._redact_values = list(redact_values or [])
self._metadata_path = SHELL_ROOT / f"{shell_id}.json"
self._log_path = SHELL_ROOT / f"{shell_id}.log"
self._master_fd: int | None = None
@ -331,8 +418,10 @@ class GuestShellSession:
"PROMPT_COMMAND": "",
}
)
if env_overrides is not None:
env.update(env_overrides)
process = subprocess.Popen( # noqa: S603
["/bin/bash", "--noprofile", "--norc", "-i"],
_shell_argv(interactive=True),
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
@ -371,8 +460,9 @@ class GuestShellSession:
def read(self, *, cursor: int, max_chars: int) -> dict[str, Any]:
with self._lock:
clamped_cursor = min(max(cursor, 0), len(self._output))
output = self._output[clamped_cursor : clamped_cursor + max_chars]
redacted_output = _redact_text(self._output, self._redact_values)
clamped_cursor = min(max(cursor, 0), len(redacted_output))
output = redacted_output[clamped_cursor : clamped_cursor + max_chars]
next_cursor = clamped_cursor + len(output)
payload = self.summary()
payload.update(
@ -380,7 +470,7 @@ class GuestShellSession:
"cursor": clamped_cursor,
"next_cursor": next_cursor,
"output": output,
"truncated": next_cursor < len(self._output),
"truncated": next_cursor < len(redacted_output),
}
)
return payload
@ -514,6 +604,8 @@ def _create_shell(
cwd_text: str,
cols: int,
rows: int,
env_overrides: dict[str, str] | None = None,
redact_values: list[str] | None = None,
) -> GuestShellSession:
_, cwd_path = _normalize_shell_cwd(cwd_text)
with _SHELLS_LOCK:
@ -525,6 +617,8 @@ def _create_shell(
cwd_text=cwd_text,
cols=cols,
rows=rows,
env_overrides=env_overrides,
redact_values=redact_values,
)
_SHELLS[shell_id] = session
return session
@ -634,7 +728,12 @@ def _refresh_service_payload(service_name: str, payload: dict[str, Any]) -> dict
return refreshed
def _run_readiness_probe(readiness: dict[str, Any] | None, *, cwd: Path) -> bool:
def _run_readiness_probe(
readiness: dict[str, Any] | None,
*,
cwd: Path,
env: dict[str, str] | None = None,
) -> bool:
if readiness is None:
return True
readiness_type = str(readiness["type"])
@ -658,11 +757,15 @@ def _run_readiness_probe(readiness: dict[str, Any] | None, *, cwd: Path) -> bool
except (urllib.error.URLError, TimeoutError, ValueError):
return False
if readiness_type == "command":
command_env = os.environ.copy()
if env is not None:
command_env.update(env)
proc = subprocess.run( # noqa: S603
["/bin/sh", "-lc", str(readiness["command"])],
cwd=str(cwd),
text=True,
capture_output=True,
env=command_env,
timeout=10,
check=False,
)
@ -678,6 +781,7 @@ def _start_service(
readiness: dict[str, Any] | None,
ready_timeout_seconds: int,
ready_interval_ms: int,
env: dict[str, str] | None = None,
) -> dict[str, Any]:
normalized_service_name = _normalize_service_name(service_name)
normalized_cwd, cwd_path = _normalize_shell_cwd(cwd_text)
@ -718,9 +822,13 @@ def _start_service(
encoding="utf-8",
)
runner_path.chmod(0o700)
service_env = os.environ.copy()
if env is not None:
service_env.update(env)
process = subprocess.Popen( # noqa: S603
[str(runner_path)],
cwd=str(cwd_path),
env=service_env,
text=True,
start_new_session=True,
)
@ -747,7 +855,7 @@ def _start_service(
payload["ended_at"] = payload.get("ended_at") or time.time()
_write_service_metadata(normalized_service_name, payload)
return payload
if _run_readiness_probe(readiness, cwd=cwd_path):
if _run_readiness_probe(readiness, cwd=cwd_path, env=env):
payload["ready_at"] = time.time()
_write_service_metadata(normalized_service_name, payload)
return payload
@ -817,16 +925,38 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
destination = str(request.get("destination", "/workspace"))
payload = _read_exact(conn, archive_size)
return _extract_archive(payload, destination)
if action == "install_secrets":
archive_size = int(request.get("archive_size", 0))
if archive_size < 0:
raise RuntimeError("archive_size must not be negative")
payload = _read_exact(conn, archive_size)
return _install_secrets_archive(payload)
if action == "open_shell":
shell_id = str(request.get("shell_id", "")).strip()
if shell_id == "":
raise RuntimeError("shell_id is required")
cwd_text, _ = _normalize_shell_cwd(str(request.get("cwd", "/workspace")))
env_payload = request.get("env")
env_overrides = None
if env_payload is not None:
if not isinstance(env_payload, dict):
raise RuntimeError("shell env must be a JSON object")
env_overrides = {
_normalize_secret_name(str(key)): str(value) for key, value in env_payload.items()
}
redact_values_payload = request.get("redact_values")
redact_values: list[str] | None = None
if redact_values_payload is not None:
if not isinstance(redact_values_payload, list):
raise RuntimeError("redact_values must be a list")
redact_values = [str(item) for item in redact_values_payload]
session = _create_shell(
shell_id=shell_id,
cwd_text=cwd_text,
cols=int(request.get("cols", 120)),
rows=int(request.get("rows", 30)),
env_overrides=env_overrides,
redact_values=redact_values,
)
return session.summary()
if action == "read_shell":
@ -866,6 +996,15 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
cwd_text = str(request.get("cwd", "/workspace"))
readiness = request.get("readiness")
readiness_payload = dict(readiness) if isinstance(readiness, dict) else None
env_payload = request.get("env")
env = None
if env_payload is not None:
if not isinstance(env_payload, dict):
raise RuntimeError("service env must be a JSON object")
env = {
_normalize_secret_name(str(key)): str(value)
for key, value in env_payload.items()
}
return _start_service(
service_name=service_name,
command=command,
@ -873,6 +1012,7 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
readiness=readiness_payload,
ready_timeout_seconds=int(request.get("ready_timeout_seconds", 30)),
ready_interval_ms=int(request.get("ready_interval_ms", 500)),
env=env,
)
if action == "status_service":
service_name = str(request.get("service_name", "")).strip()
@ -887,12 +1027,19 @@ def _dispatch(request: dict[str, Any], conn: socket.socket) -> dict[str, Any]:
return _stop_service(service_name)
command = str(request.get("command", ""))
timeout_seconds = int(request.get("timeout_seconds", 30))
return _run_command(command, timeout_seconds)
env_payload = request.get("env")
env = None
if env_payload is not None:
if not isinstance(env_payload, dict):
raise RuntimeError("exec env must be a JSON object")
env = {_normalize_secret_name(str(key)): str(value) for key, value in env_payload.items()}
return _run_command(command, timeout_seconds, env=env)
def main() -> None:
SHELL_ROOT.mkdir(parents=True, exist_ok=True)
SERVICE_ROOT.mkdir(parents=True, exist_ok=True)
SECRET_ROOT.mkdir(parents=True, exist_ok=True)
family = getattr(socket, "AF_VSOCK", None)
if family is None:
raise SystemExit("AF_VSOCK is unavailable")