Add workspace snapshots and full reset

Implement the 2.8.0 workspace milestone with named snapshots and full-sandbox reset across the CLI, Python SDK, and MCP server.

Persist the immutable baseline plus named snapshot archives under each workspace, add workspace reset metadata, and make reset recreate the sandbox while clearing command history, shells, and services without changing the workspace identity or diff baseline.

Refresh the 2.8.0 docs, roadmap, and Python example around reset-over-repair, then validate with uv lock, UV_CACHE_DIR=.uv-cache make check, UV_CACHE_DIR=.uv-cache make dist-check, and a real guest-backed create/snapshot/reset/diff smoke test outside the sandbox.
This commit is contained in:
Thales Maciel 2026-03-12 12:41:11 -03:00
parent f504f0a331
commit 18b8fd2a7d
20 changed files with 1429 additions and 29 deletions

View file

@ -146,6 +146,23 @@ class Pyro:
def diff_workspace(self, workspace_id: str) -> dict[str, Any]:
    """Forward the diff request for ``workspace_id`` to the manager and return its result."""
    manager = self._manager
    return manager.diff_workspace(workspace_id)
def create_snapshot(self, workspace_id: str, snapshot_name: str) -> dict[str, Any]:
    """Forward a snapshot-create request to the manager and return its result."""
    manager = self._manager
    return manager.create_snapshot(workspace_id, snapshot_name)
def list_snapshots(self, workspace_id: str) -> dict[str, Any]:
    """Forward a snapshot-list request to the manager and return its result."""
    manager = self._manager
    return manager.list_snapshots(workspace_id)
def delete_snapshot(self, workspace_id: str, snapshot_name: str) -> dict[str, Any]:
    """Forward a snapshot-delete request to the manager and return its result."""
    manager = self._manager
    return manager.delete_snapshot(workspace_id, snapshot_name)
def reset_workspace(
    self,
    workspace_id: str,
    *,
    snapshot: str = "baseline",
) -> dict[str, Any]:
    """Forward a reset request to the manager and return its result.

    ``snapshot`` is passed through as a keyword argument and defaults to
    ``"baseline"``.
    """
    reset_options = {"snapshot": snapshot}
    return self._manager.reset_workspace(workspace_id, **reset_options)
def open_shell(
self,
workspace_id: str,
@ -444,6 +461,29 @@ class Pyro:
"""Compare `/workspace` to the immutable create-time baseline."""
return self.diff_workspace(workspace_id)
# MCP tool wrapper: passes `workspace_id` and `snapshot_name` straight to
# `self.create_snapshot` and returns its result unchanged.
# NOTE(review): the docstring is presumably what `server.tool()` publishes as the
# tool description — confirm before rewording it.
@server.tool()
async def snapshot_create(workspace_id: str, snapshot_name: str) -> dict[str, Any]:
"""Create one named workspace snapshot from the current `/workspace` tree."""
return self.create_snapshot(workspace_id, snapshot_name)
# MCP tool wrapper: passes `workspace_id` straight to `self.list_snapshots` and
# returns its result unchanged.
# NOTE(review): the docstring is presumably what `server.tool()` publishes as the
# tool description — confirm before rewording it.
@server.tool()
async def snapshot_list(workspace_id: str) -> dict[str, Any]:
"""List the baseline plus named snapshots for one workspace."""
return self.list_snapshots(workspace_id)
# MCP tool wrapper: passes `workspace_id` and `snapshot_name` straight to
# `self.delete_snapshot` and returns its result unchanged.
# NOTE(review): the docstring is presumably what `server.tool()` publishes as the
# tool description — confirm before rewording it.
@server.tool()
async def snapshot_delete(workspace_id: str, snapshot_name: str) -> dict[str, Any]:
"""Delete one named snapshot from a workspace."""
return self.delete_snapshot(workspace_id, snapshot_name)
# MCP tool wrapper: passes `workspace_id` and the `snapshot` keyword (default
# "baseline") to `self.reset_workspace` and returns its result unchanged.
# NOTE(review): the docstring is presumably what `server.tool()` publishes as the
# tool description — confirm before rewording it.
@server.tool()
async def workspace_reset(
workspace_id: str,
snapshot: str = "baseline",
) -> dict[str, Any]:
"""Recreate a workspace and restore `/workspace` from baseline or one named snapshot."""
return self.reset_workspace(workspace_id, snapshot=snapshot)
@server.tool()
async def shell_open(
workspace_id: str,

View file

@ -174,6 +174,10 @@ def _print_workspace_summary_human(payload: dict[str, Any], *, action: str) -> N
f"{int(payload.get('mem_mib', 0))} MiB"
)
print(f"Command count: {int(payload.get('command_count', 0))}")
print(f"Reset count: {int(payload.get('reset_count', 0))}")
last_reset_at = payload.get("last_reset_at")
if last_reset_at is not None:
print(f"Last reset at: {last_reset_at}")
print(
"Services: "
f"{int(payload.get('running_service_count', 0))}/"
@ -281,6 +285,55 @@ def _print_workspace_logs_human(payload: dict[str, Any]) -> None:
print(stderr, end="" if stderr.endswith("\n") else "\n", file=sys.stderr)
def _print_workspace_snapshot_human(payload: dict[str, Any], *, prefix: str) -> None:
    """Print a one-line, human-readable summary of a snapshot payload.

    When ``payload["snapshot"]`` is not a dict, only the bracketed prefix and
    the workspace id are printed. Otherwise the line also carries the snapshot
    name, kind, entry count, and bytes written, each falling back to
    ``unknown`` / ``0`` when the key is absent.
    """
    details = payload.get("snapshot")
    if isinstance(details, dict):
        columns = [
            f"workspace_id={str(payload.get('workspace_id', 'unknown'))}",
            f"snapshot_name={str(details.get('snapshot_name', 'unknown'))}",
            f"kind={str(details.get('kind', 'unknown'))}",
            f"entry_count={int(details.get('entry_count', 0))}",
            f"bytes_written={int(details.get('bytes_written', 0))}",
        ]
        print(f"[{prefix}] " + " ".join(columns))
        return
    print(f"[{prefix}] workspace_id={str(payload.get('workspace_id', 'unknown'))}")
def _print_workspace_snapshot_list_human(payload: dict[str, Any]) -> None:
    """Print one line per snapshot entry found in ``payload["snapshots"]``.

    A placeholder message is printed when the value is missing, is not a list,
    or is empty. Entries that are not dicts are skipped silently.
    """
    entries = payload.get("snapshots")
    if isinstance(entries, list) and entries:
        for entry in entries:
            if isinstance(entry, dict):
                columns = [
                    f"{str(entry.get('snapshot_name', 'unknown'))}",
                    f"[{str(entry.get('kind', 'unknown'))}]",
                    f"entry_count={int(entry.get('entry_count', 0))}",
                    f"bytes_written={int(entry.get('bytes_written', 0))}",
                    f"deletable={'yes' if bool(entry.get('deletable', False)) else 'no'}",
                ]
                print(" ".join(columns))
        return
    print("No workspace snapshots found.")
def _print_workspace_reset_human(payload: dict[str, Any]) -> None:
    """Print the workspace summary followed by the reset source and restore details.

    The two reset lines are only printed when ``payload["workspace_reset"]`` is
    a dict; missing keys fall back to ``unknown``, ``WORKSPACE_GUEST_PATH``, or
    ``0``.
    """
    _print_workspace_summary_human(payload, action="Reset workspace")
    reset_info = payload.get("workspace_reset")
    if not isinstance(reset_info, dict):
        return

    source_name = str(reset_info.get("snapshot_name", "unknown"))
    source_kind = str(reset_info.get("kind", "unknown"))
    print(f"Reset source: {source_name} ({source_kind})")

    destination = str(reset_info.get("destination", WORKSPACE_GUEST_PATH))
    entry_count = int(reset_info.get("entry_count", 0))
    bytes_written = int(reset_info.get("bytes_written", 0))
    print(
        "Reset restore: "
        f"destination={destination} "
        f"entry_count={entry_count} "
        f"bytes_written={bytes_written}"
    )
def _print_workspace_shell_summary_human(payload: dict[str, Any], *, prefix: str) -> None:
print(
f"[{prefix}] "
@ -592,6 +645,8 @@ def _build_parser() -> argparse.ArgumentParser:
pyro workspace create debian:12 --seed-path ./repo
pyro workspace sync push WORKSPACE_ID ./repo --dest src
pyro workspace exec WORKSPACE_ID -- sh -lc 'printf "hello\\n" > note.txt'
pyro workspace snapshot create WORKSPACE_ID checkpoint
pyro workspace reset WORKSPACE_ID --snapshot checkpoint
pyro workspace diff WORKSPACE_ID
pyro workspace export WORKSPACE_ID src/note.txt --output ./note.txt
pyro workspace shell open WORKSPACE_ID
@ -617,6 +672,8 @@ def _build_parser() -> argparse.ArgumentParser:
pyro workspace create debian:12
pyro workspace create debian:12 --seed-path ./repo
pyro workspace sync push WORKSPACE_ID ./changes
pyro workspace snapshot create WORKSPACE_ID checkpoint
pyro workspace reset WORKSPACE_ID --snapshot checkpoint
pyro workspace diff WORKSPACE_ID
pyro workspace service start WORKSPACE_ID app --ready-file .ready -- \
sh -lc 'touch .ready && while true; do sleep 60; done'
@ -720,7 +777,8 @@ def _build_parser() -> argparse.ArgumentParser:
pyro workspace sync push WORKSPACE_ID ./repo
pyro workspace sync push WORKSPACE_ID ./patches --dest src
Sync is non-atomic. If a sync fails partway through, delete and recreate the workspace.
Sync is non-atomic. If a sync fails partway through, prefer reset over repair with
`pyro workspace reset WORKSPACE_ID`.
"""
),
formatter_class=_HelpFormatter,
@ -808,6 +866,100 @@ def _build_parser() -> argparse.ArgumentParser:
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_snapshot_parser = workspace_subparsers.add_parser(
"snapshot",
help="Create, list, and delete workspace snapshots.",
description=(
"Manage explicit named snapshots in addition to the implicit create-time baseline."
),
epilog=dedent(
"""
Examples:
pyro workspace snapshot create WORKSPACE_ID checkpoint
pyro workspace snapshot list WORKSPACE_ID
pyro workspace snapshot delete WORKSPACE_ID checkpoint
Use `workspace reset` to restore `/workspace` from `baseline` or one named snapshot.
"""
),
formatter_class=_HelpFormatter,
)
workspace_snapshot_subparsers = workspace_snapshot_parser.add_subparsers(
dest="workspace_snapshot_command",
required=True,
metavar="SNAPSHOT",
)
workspace_snapshot_create_parser = workspace_snapshot_subparsers.add_parser(
"create",
help="Create one named snapshot from the current workspace.",
description="Capture the current `/workspace` tree as one named snapshot.",
epilog="Example:\n pyro workspace snapshot create WORKSPACE_ID checkpoint",
formatter_class=_HelpFormatter,
)
workspace_snapshot_create_parser.add_argument("workspace_id", metavar="WORKSPACE_ID")
workspace_snapshot_create_parser.add_argument("snapshot_name", metavar="SNAPSHOT_NAME")
workspace_snapshot_create_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_snapshot_list_parser = workspace_snapshot_subparsers.add_parser(
"list",
help="List the baseline plus named snapshots.",
description="List the implicit baseline snapshot plus any named snapshots for a workspace.",
epilog="Example:\n pyro workspace snapshot list WORKSPACE_ID",
formatter_class=_HelpFormatter,
)
workspace_snapshot_list_parser.add_argument("workspace_id", metavar="WORKSPACE_ID")
workspace_snapshot_list_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_snapshot_delete_parser = workspace_snapshot_subparsers.add_parser(
"delete",
help="Delete one named snapshot.",
description="Delete one named snapshot while leaving the implicit baseline intact.",
epilog="Example:\n pyro workspace snapshot delete WORKSPACE_ID checkpoint",
formatter_class=_HelpFormatter,
)
workspace_snapshot_delete_parser.add_argument("workspace_id", metavar="WORKSPACE_ID")
workspace_snapshot_delete_parser.add_argument("snapshot_name", metavar="SNAPSHOT_NAME")
workspace_snapshot_delete_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_reset_parser = workspace_subparsers.add_parser(
"reset",
help="Recreate a workspace from baseline or one named snapshot.",
description=(
"Recreate the full sandbox and restore `/workspace` from the baseline "
"or one named snapshot."
),
epilog=dedent(
"""
Examples:
pyro workspace reset WORKSPACE_ID
pyro workspace reset WORKSPACE_ID --snapshot checkpoint
Prefer reset over repair: reset clears command history, shells, and services so the
workspace comes back clean from `baseline` or one named snapshot.
"""
),
formatter_class=_HelpFormatter,
)
workspace_reset_parser.add_argument("workspace_id", metavar="WORKSPACE_ID")
workspace_reset_parser.add_argument(
"--snapshot",
default="baseline",
help="Snapshot name to restore. Defaults to the implicit `baseline` snapshot.",
)
workspace_reset_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
workspace_shell_parser = workspace_subparsers.add_parser(
"shell",
help="Open and manage persistent interactive shells.",
@ -1483,6 +1635,72 @@ def main() -> None:
raise SystemExit(1) from exc
_print_workspace_diff_human(payload)
return
if args.workspace_command == "snapshot":
if args.workspace_snapshot_command == "create":
try:
payload = pyro.create_snapshot(args.workspace_id, args.snapshot_name)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_snapshot_human(
payload,
prefix="workspace-snapshot-create",
)
return
if args.workspace_snapshot_command == "list":
try:
payload = pyro.list_snapshots(args.workspace_id)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_snapshot_list_human(payload)
return
if args.workspace_snapshot_command == "delete":
try:
payload = pyro.delete_snapshot(args.workspace_id, args.snapshot_name)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
print(
"Deleted workspace snapshot: "
f"{str(payload.get('snapshot_name', 'unknown'))}"
)
return
if args.workspace_command == "reset":
try:
payload = pyro.reset_workspace(
args.workspace_id,
snapshot=args.snapshot,
)
except Exception as exc: # noqa: BLE001
if bool(args.json):
_print_json({"ok": False, "error": str(exc)})
else:
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
if bool(args.json):
_print_json(payload)
else:
_print_workspace_reset_human(payload)
return
if args.workspace_command == "shell":
if args.workspace_shell_command == "open":
try:

View file

@ -12,13 +12,16 @@ PUBLIC_CLI_WORKSPACE_SUBCOMMANDS = (
"exec",
"export",
"logs",
"reset",
"service",
"shell",
"snapshot",
"status",
"sync",
)
PUBLIC_CLI_WORKSPACE_SERVICE_SUBCOMMANDS = ("list", "logs", "start", "status", "stop")
PUBLIC_CLI_WORKSPACE_SHELL_SUBCOMMANDS = ("close", "open", "read", "signal", "write")
PUBLIC_CLI_WORKSPACE_SNAPSHOT_SUBCOMMANDS = ("create", "delete", "list")
PUBLIC_CLI_WORKSPACE_SYNC_SUBCOMMANDS = ("push",)
PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
"--vcpu-count",
@ -31,6 +34,7 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
)
PUBLIC_CLI_WORKSPACE_DIFF_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_EXPORT_FLAGS = ("--output", "--json")
PUBLIC_CLI_WORKSPACE_RESET_FLAGS = ("--snapshot", "--json")
PUBLIC_CLI_WORKSPACE_SERVICE_LIST_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SERVICE_LOGS_FLAGS = ("--tail-lines", "--all", "--json")
PUBLIC_CLI_WORKSPACE_SERVICE_START_FLAGS = (
@ -50,6 +54,9 @@ PUBLIC_CLI_WORKSPACE_SHELL_READ_FLAGS = ("--cursor", "--max-chars", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_WRITE_FLAGS = ("--input", "--no-newline", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_SIGNAL_FLAGS = ("--signal", "--json")
PUBLIC_CLI_WORKSPACE_SHELL_CLOSE_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SNAPSHOT_CREATE_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SNAPSHOT_DELETE_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SNAPSHOT_LIST_FLAGS = ("--json",)
PUBLIC_CLI_WORKSPACE_SYNC_PUSH_FLAGS = ("--dest", "--json")
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
@ -64,8 +71,10 @@ PUBLIC_CLI_RUN_FLAGS = (
PUBLIC_SDK_METHODS = (
"close_shell",
"create_server",
"create_snapshot",
"create_vm",
"create_workspace",
"delete_snapshot",
"delete_vm",
"delete_workspace",
"diff_workspace",
@ -75,6 +84,7 @@ PUBLIC_SDK_METHODS = (
"inspect_environment",
"list_environments",
"list_services",
"list_snapshots",
"logs_service",
"logs_workspace",
"network_info_vm",
@ -84,6 +94,7 @@ PUBLIC_SDK_METHODS = (
"push_workspace_sync",
"read_shell",
"reap_expired",
"reset_workspace",
"run_in_vm",
"signal_shell",
"start_service",
@ -107,6 +118,9 @@ PUBLIC_MCP_TOOLS = (
"shell_read",
"shell_signal",
"shell_write",
"snapshot_create",
"snapshot_delete",
"snapshot_list",
"vm_create",
"vm_delete",
"vm_exec",
@ -123,6 +137,7 @@ PUBLIC_MCP_TOOLS = (
"workspace_exec",
"workspace_export",
"workspace_logs",
"workspace_reset",
"workspace_status",
"workspace_sync_push",
)

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.7.0"
DEFAULT_CATALOG_VERSION = "2.8.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -49,9 +49,10 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
WORKSPACE_LAYOUT_VERSION = 4
WORKSPACE_LAYOUT_VERSION = 5
WORKSPACE_BASELINE_DIRNAME = "baseline"
WORKSPACE_BASELINE_ARCHIVE_NAME = "workspace.tar"
WORKSPACE_SNAPSHOTS_DIRNAME = "snapshots"
WORKSPACE_DIRNAME = "workspace"
WORKSPACE_COMMANDS_DIRNAME = "commands"
WORKSPACE_SHELLS_DIRNAME = "shells"
@ -68,10 +69,12 @@ DEFAULT_SERVICE_READY_INTERVAL_MS = 500
DEFAULT_SERVICE_LOG_TAIL_LINES = 200
WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names()
WORKSPACE_SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
WORKSPACE_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
WorkspaceSeedMode = Literal["empty", "directory", "tar_archive"]
WorkspaceArtifactType = Literal["file", "directory", "symlink"]
WorkspaceServiceReadinessType = Literal["file", "tcp", "http", "command"]
WorkspaceSnapshotKind = Literal["baseline", "named"]
@dataclass
@ -116,6 +119,8 @@ class WorkspaceRecord:
command_count: int = 0
last_command: dict[str, Any] | None = None
workspace_seed: dict[str, Any] = field(default_factory=dict)
reset_count: int = 0
last_reset_at: float | None = None
@classmethod
def from_instance(
@ -144,6 +149,8 @@ class WorkspaceRecord:
command_count=command_count,
last_command=last_command,
workspace_seed=dict(workspace_seed or _empty_workspace_seed_payload()),
reset_count=0,
last_reset_at=None,
)
def to_instance(self, *, workdir: Path) -> VmInstance:
@ -185,6 +192,8 @@ class WorkspaceRecord:
"command_count": self.command_count,
"last_command": self.last_command,
"workspace_seed": self.workspace_seed,
"reset_count": self.reset_count,
"last_reset_at": self.last_reset_at,
}
@classmethod
@ -207,6 +216,46 @@ class WorkspaceRecord:
command_count=int(payload.get("command_count", 0)),
last_command=_optional_dict(payload.get("last_command")),
workspace_seed=_workspace_seed_dict(payload.get("workspace_seed")),
reset_count=int(payload.get("reset_count", 0)),
last_reset_at=(
None
if payload.get("last_reset_at") is None
else float(payload.get("last_reset_at", 0.0))
),
)
@dataclass
class WorkspaceSnapshotRecord:
    """Snapshot metadata that is serialized to and from a JSON-compatible dict."""

    workspace_id: str
    snapshot_name: str
    kind: WorkspaceSnapshotKind
    created_at: float
    entry_count: int
    bytes_written: int

    def to_payload(self) -> dict[str, Any]:
        """Return the record as a dict, tagged with the current layout version."""
        payload: dict[str, Any] = {"layout_version": WORKSPACE_LAYOUT_VERSION}
        payload.update(
            workspace_id=self.workspace_id,
            snapshot_name=self.snapshot_name,
            kind=self.kind,
            created_at=self.created_at,
            entry_count=self.entry_count,
            bytes_written=self.bytes_written,
        )
        return payload

    @classmethod
    def from_payload(cls, payload: dict[str, Any]) -> WorkspaceSnapshotRecord:
        """Build a record from a dict produced by :meth:`to_payload`.

        ``workspace_id``, ``snapshot_name`` and ``created_at`` are required
        keys; ``kind`` defaults to ``"named"`` and the two counters to ``0``.
        """
        # The cast only informs the type checker; the value is not validated.
        kind = cast(WorkspaceSnapshotKind, str(payload.get("kind", "named")))
        return cls(
            workspace_id=str(payload["workspace_id"]),
            snapshot_name=str(payload["snapshot_name"]),
            kind=kind,
            created_at=float(payload["created_at"]),
            entry_count=int(payload.get("entry_count", 0)),
            bytes_written=int(payload.get("bytes_written", 0)),
        )
@ -864,6 +913,24 @@ def _normalize_workspace_service_name(service_name: str) -> str:
return normalized
def _normalize_workspace_snapshot_name(
    snapshot_name: str,
    *,
    allow_baseline: bool = False,
) -> str:
    """Strip and validate a snapshot name, returning the stripped value.

    Raises:
        ValueError: if the stripped name is empty, is ``"baseline"`` while
            ``allow_baseline`` is false, or does not fully match
            ``WORKSPACE_SNAPSHOT_NAME_RE``.
    """
    candidate = snapshot_name.strip()
    if not candidate:
        raise ValueError("snapshot_name must not be empty")
    if not allow_baseline and candidate == "baseline":
        raise ValueError("snapshot_name 'baseline' is reserved")
    if WORKSPACE_SNAPSHOT_NAME_RE.fullmatch(candidate):
        return candidate
    raise ValueError(
        "snapshot_name must match "
        r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$"
    )
def _normalize_workspace_service_readiness(
readiness: dict[str, Any] | None,
) -> dict[str, Any] | None:
@ -2646,12 +2713,14 @@ class VmManager:
commands_dir = self._workspace_commands_dir(workspace_id)
shells_dir = self._workspace_shells_dir(workspace_id)
services_dir = self._workspace_services_dir(workspace_id)
snapshots_dir = self._workspace_snapshots_dir(workspace_id)
baseline_archive_path = self._workspace_baseline_archive_path(workspace_id)
workspace_dir.mkdir(parents=True, exist_ok=False)
host_workspace_dir.mkdir(parents=True, exist_ok=True)
commands_dir.mkdir(parents=True, exist_ok=True)
shells_dir.mkdir(parents=True, exist_ok=True)
services_dir.mkdir(parents=True, exist_ok=True)
snapshots_dir.mkdir(parents=True, exist_ok=True)
_persist_workspace_baseline(
prepared_seed,
baseline_archive_path=baseline_archive_path,
@ -2859,6 +2928,192 @@ class VmManager:
diff_payload["workspace_id"] = workspace_id
return diff_payload
# Create one named snapshot for a workspace and return a summary payload.
#
# Flow, as visible below:
#   1. Normalize the name ("baseline" is rejected because allow_baseline is not set).
#   2. Under the lock: load the workspace, check expiry, require a baseline
#      archive, reject a duplicate snapshot name, and obtain the live instance.
#   3. Export WORKSPACE_GUEST_PATH from the backend into a tar inside a temp dir.
#   4. Under the lock again: re-load the workspace, repeat the expiry and
#      duplicate checks, copy the tar into the snapshots directory, persist the
#      instance state onto the workspace record, and save the snapshot metadata.
#
# Raises ValueError when the snapshot name already exists in the workspace.
#
# NOTE(review): indentation is not visible in this view, so the exact extent of
# each `with` block is an assumption — the second `with self._lock` must sit
# inside the TemporaryDirectory block for `temp_archive_path` to still exist
# when it is copied; confirm against the real file.
def create_snapshot(
self,
workspace_id: str,
snapshot_name: str,
) -> dict[str, Any]:
normalized_snapshot_name = _normalize_workspace_snapshot_name(snapshot_name)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
self._ensure_workspace_not_expired_locked(workspace, time.time())
# Return value is discarded: called only so a missing baseline archive raises.
self._workspace_baseline_snapshot_locked(workspace)
if (
self._load_workspace_snapshot_locked_optional(
workspace_id,
normalized_snapshot_name,
)
is not None
):
raise ValueError(
f"snapshot {normalized_snapshot_name!r} already exists in workspace "
f"{workspace_id!r}"
)
instance = self._workspace_instance_for_live_operation_locked(
workspace,
operation_name="workspace_snapshot_create",
)
with tempfile.TemporaryDirectory(prefix="pyro-workspace-snapshot-") as temp_dir:
temp_archive_path = Path(temp_dir) / f"{normalized_snapshot_name}.tar"
exported = self._backend.export_archive(
instance,
workspace_path=WORKSPACE_GUEST_PATH,
archive_path=temp_archive_path,
)
snapshot = WorkspaceSnapshotRecord(
workspace_id=workspace_id,
snapshot_name=normalized_snapshot_name,
kind="named",
created_at=time.time(),
entry_count=int(exported["entry_count"]),
bytes_written=int(exported["bytes_written"]),
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
self._ensure_workspace_not_expired_locked(workspace, time.time())
# Same duplicate check as above, repeated after the export step.
if (
self._load_workspace_snapshot_locked_optional(
workspace_id,
normalized_snapshot_name,
)
is not None
):
raise ValueError(
f"snapshot {normalized_snapshot_name!r} already exists in workspace "
f"{workspace_id!r}"
)
archive_path = self._workspace_snapshot_archive_path(
workspace_id,
normalized_snapshot_name,
)
archive_path.parent.mkdir(parents=True, exist_ok=True)
# The archive is written before its metadata record; NOTE(review): a failure
# between the two would leave a .tar without a .json — confirm this is acceptable.
shutil.copy2(temp_archive_path, archive_path)
workspace.state = instance.state
workspace.firecracker_pid = instance.firecracker_pid
workspace.last_error = instance.last_error
workspace.metadata = dict(instance.metadata)
self._save_workspace_locked(workspace)
self._save_workspace_snapshot_locked(snapshot)
return {
"workspace_id": workspace_id,
"snapshot": self._serialize_workspace_snapshot(snapshot),
"execution_mode": instance.metadata.get("execution_mode", "pending"),
}
def list_snapshots(self, workspace_id: str) -> dict[str, Any]:
    """Return the serialized snapshots of a workspace together with their count.

    The workspace is loaded and checked for expiry under the manager lock
    before its snapshots are listed.
    """
    with self._lock:
        record = self._load_workspace_locked(workspace_id)
        self._ensure_workspace_not_expired_locked(record, time.time())
        found = self._list_workspace_snapshots_locked(record)
        serialized = [self._serialize_workspace_snapshot(item) for item in found]
        return {
            "workspace_id": workspace_id,
            "count": len(found),
            "snapshots": serialized,
        }
def delete_snapshot(self, workspace_id: str, snapshot_name: str) -> dict[str, Any]:
    """Delete one named snapshot and return a confirmation payload.

    Raises:
        ValueError: if the name is invalid or refers to the baseline snapshot.
            Errors raised by the workspace and snapshot lookups propagate
            unchanged.
    """
    target_name = _normalize_workspace_snapshot_name(snapshot_name, allow_baseline=True)
    # "baseline" is accepted by the normalizer only so it can be rejected here
    # with a delete-specific message.
    if target_name == "baseline":
        raise ValueError("cannot delete the baseline snapshot")

    with self._lock:
        record = self._load_workspace_locked(workspace_id)
        self._ensure_workspace_not_expired_locked(record, time.time())
        # Both lookups are called for the errors they raise; results are unused.
        self._workspace_baseline_snapshot_locked(record)
        self._load_workspace_snapshot_locked(workspace_id, target_name)
        self._delete_workspace_snapshot_locked(workspace_id, target_name)

    return dict(
        workspace_id=workspace_id,
        snapshot_name=target_name,
        deleted=True,
    )
# Tear down a workspace's sandbox, recreate it, and restore WORKSPACE_GUEST_PATH
# from the selected snapshot archive ("baseline" by default).
#
# Flow, as visible below:
#   1. Under the lock: load the workspace, check expiry, refresh liveness,
#      resolve the snapshot and its archive path, stop services, close shells,
#      stop (if started) and delete the backend instance, clear the runtime
#      directories, and persist the workspace as "stopped".
#   2. Recreate the instance, start it, import the archive, then reset
#      command_count/last_command, bump reset_count, stamp last_reset_at, and
#      save the workspace.
#   3. On any exception: best-effort stop and delete of the recreated instance,
#      persist the workspace as "stopped", and re-raise.
#
# Returns the serialized workspace with an added "workspace_reset" entry.
#
# NOTE(review): indentation is not visible in this view, so the extent of each
# `with self._lock` block (in particular the one around `_start_instance_locked`)
# cannot be confirmed here — check the real file before restructuring.
def reset_workspace(
self,
workspace_id: str,
*,
snapshot: str = "baseline",
) -> dict[str, Any]:
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
self._ensure_workspace_not_expired_locked(workspace, time.time())
self._refresh_workspace_liveness_locked(workspace)
# The snapshot is resolved before anything is torn down, so an unknown or
# invalid snapshot name raises while the existing sandbox is still intact.
selected_snapshot, archive_path = self._resolve_workspace_snapshot_locked(
workspace,
snapshot,
)
instance = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._stop_workspace_services_locked(workspace, instance)
self._close_workspace_shells_locked(workspace, instance)
if workspace.state == "started":
self._backend.stop(instance)
workspace.state = "stopped"
self._backend.delete(instance)
workspace.state = "stopped"
workspace.firecracker_pid = None
workspace.last_error = None
self._reset_workspace_runtime_dirs(workspace_id)
self._save_workspace_locked(workspace)
# Stays None until the instance object is rebuilt, so the cleanup path below
# can tell whether there is anything to stop or delete.
recreated: VmInstance | None = None
try:
recreated = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._backend.create(recreated)
if self._runtime_capabilities.supports_guest_exec:
self._ensure_workspace_guest_agent_support(recreated)
with self._lock:
self._start_instance_locked(recreated)
self._require_guest_exec_or_opt_in(recreated)
reset_summary = self._backend.import_archive(
recreated,
archive_path=archive_path,
destination=WORKSPACE_GUEST_PATH,
)
workspace = self._load_workspace_locked(workspace_id)
workspace.state = recreated.state
workspace.firecracker_pid = recreated.firecracker_pid
workspace.last_error = recreated.last_error
workspace.metadata = dict(recreated.metadata)
workspace.command_count = 0
workspace.last_command = None
workspace.reset_count += 1
workspace.last_reset_at = time.time()
self._save_workspace_locked(workspace)
payload = self._serialize_workspace(workspace)
except Exception:
# Best-effort cleanup: errors from stop/delete are swallowed so the original
# exception is the one that reaches the caller.
try:
if recreated is not None and recreated.state == "started":
self._backend.stop(recreated)
except Exception:
pass
try:
if recreated is not None:
self._backend.delete(recreated)
except Exception:
pass
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = "stopped"
workspace.firecracker_pid = None
# NOTE(review): last_error is cleared here, so the cause of a failed reset is
# not persisted on the workspace record — confirm this is intended.
workspace.last_error = None
self._save_workspace_locked(workspace)
raise
payload["workspace_reset"] = {
"snapshot_name": selected_snapshot.snapshot_name,
"kind": selected_snapshot.kind,
"destination": str(reset_summary["destination"]),
"entry_count": int(reset_summary["entry_count"]),
"bytes_written": int(reset_summary["bytes_written"]),
}
return payload
def exec_workspace(
self,
workspace_id: str,
@ -3372,6 +3627,8 @@ class VmManager:
"workspace_seed": _workspace_seed_dict(workspace.workspace_seed),
"command_count": workspace.command_count,
"last_command": workspace.last_command,
"reset_count": workspace.reset_count,
"last_reset_at": workspace.last_reset_at,
"service_count": service_count,
"running_service_count": running_service_count,
"metadata": workspace.metadata,
@ -3408,6 +3665,17 @@ class VmManager:
"stop_reason": service.stop_reason,
}
def _serialize_workspace_snapshot(self, snapshot: WorkspaceSnapshotRecord) -> dict[str, Any]:
    """Return the public dict form of a snapshot record.

    ``deletable`` is true for every snapshot whose kind is not ``"baseline"``.
    """
    is_baseline = snapshot.kind == "baseline"
    return dict(
        workspace_id=snapshot.workspace_id,
        snapshot_name=snapshot.snapshot_name,
        kind=snapshot.kind,
        created_at=snapshot.created_at,
        entry_count=snapshot.entry_count,
        bytes_written=snapshot.bytes_written,
        deletable=not is_baseline,
    )
def _require_guest_boot_or_opt_in(self, instance: VmInstance) -> None:
if self._runtime_capabilities.supports_vm_boot or instance.allow_host_compat:
return
@ -3589,6 +3857,15 @@ class VmManager:
def _workspace_baseline_archive_path(self, workspace_id: str) -> Path:
    """Return the path of the baseline archive inside the workspace's baseline directory."""
    baseline_dir = self._workspace_baseline_dir(workspace_id)
    return baseline_dir / WORKSPACE_BASELINE_ARCHIVE_NAME
def _workspace_snapshots_dir(self, workspace_id: str) -> Path:
    """Return the snapshots directory located under the workspace directory."""
    workspace_root = self._workspace_dir(workspace_id)
    return workspace_root / WORKSPACE_SNAPSHOTS_DIRNAME
def _workspace_snapshot_archive_path(self, workspace_id: str, snapshot_name: str) -> Path:
    """Return ``<snapshots dir>/<snapshot_name>.tar`` for the given workspace."""
    archive_name = f"{snapshot_name}.tar"
    return self._workspace_snapshots_dir(workspace_id) / archive_name
def _workspace_snapshot_metadata_path(self, workspace_id: str, snapshot_name: str) -> Path:
    """Return ``<snapshots dir>/<snapshot_name>.json`` for the given workspace."""
    record_name = f"{snapshot_name}.json"
    return self._workspace_snapshots_dir(workspace_id) / record_name
def _workspace_commands_dir(self, workspace_id: str) -> Path:
    """Return the commands directory located under the workspace directory."""
    workspace_root = self._workspace_dir(workspace_id)
    return workspace_root / WORKSPACE_COMMANDS_DIRNAME
@ -3846,6 +4123,41 @@ class VmManager:
services = self._list_workspace_services_locked(workspace_id)
return len(services), sum(1 for service in services if service.state == "running")
def _workspace_baseline_snapshot_locked(
    self,
    workspace: WorkspaceRecord,
) -> WorkspaceSnapshotRecord:
    """Build the synthetic ``baseline`` snapshot record for a workspace.

    The entry count and byte size come from inspecting the baseline archive;
    the creation time is the workspace's own ``created_at``.

    Raises:
        RuntimeError: if the baseline archive file does not exist.
    """
    archive = self._workspace_baseline_archive_path(workspace.workspace_id)
    if archive.exists():
        entries, size = _inspect_seed_archive(archive)
        return WorkspaceSnapshotRecord(
            workspace_id=workspace.workspace_id,
            snapshot_name="baseline",
            kind="baseline",
            created_at=workspace.created_at,
            entry_count=entries,
            bytes_written=size,
        )
    raise RuntimeError(
        "workspace snapshots and reset require a baseline snapshot. "
        "Recreate the workspace to use snapshot/reset features."
    )
def _resolve_workspace_snapshot_locked(
    self,
    workspace: WorkspaceRecord,
    snapshot_name: str,
) -> tuple[WorkspaceSnapshotRecord, Path]:
    """Return the snapshot record and archive path for ``snapshot_name``.

    ``"baseline"`` resolves to the synthetic baseline record and the baseline
    archive; any other valid name is loaded from the snapshots directory.
    """
    wanted = _normalize_workspace_snapshot_name(snapshot_name, allow_baseline=True)
    if wanted != "baseline":
        record = self._load_workspace_snapshot_locked(workspace.workspace_id, wanted)
        archive = self._workspace_snapshot_archive_path(workspace.workspace_id, wanted)
        return (record, archive)
    baseline_record = self._workspace_baseline_snapshot_locked(workspace)
    return baseline_record, self._workspace_baseline_archive_path(workspace.workspace_id)
def _load_workspace_service_locked(
self,
workspace_id: str,
@ -3861,6 +4173,34 @@ class VmManager:
raise RuntimeError(f"service record at {record_path} is invalid")
return WorkspaceServiceRecord.from_payload(payload)
def _load_workspace_snapshot_locked(
    self,
    workspace_id: str,
    snapshot_name: str,
) -> WorkspaceSnapshotRecord:
    """Load the metadata record of one named snapshot.

    Args:
        workspace_id: Workspace that owns the snapshot.
        snapshot_name: Already-normalized snapshot name.

    Returns:
        The snapshot record parsed from its JSON metadata file.

    Raises:
        ValueError: if no metadata file exists for the snapshot.
        RuntimeError: if the metadata file is not valid UTF-8 JSON or does not
            contain a JSON object.
    """
    record_path = self._workspace_snapshot_metadata_path(workspace_id, snapshot_name)
    try:
        # Read directly instead of checking exists() first, so a file removed
        # between the check and the read cannot leak a raw FileNotFoundError.
        payload = json.loads(record_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError):
        raise ValueError(
            f"snapshot {snapshot_name!r} does not exist in workspace {workspace_id!r}"
        ) from None
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Both are ValueError subclasses; convert them so a corrupt record is not
        # mistaken for the "does not exist" ValueError above.
        raise RuntimeError(f"snapshot record at {record_path} is invalid") from exc
    if not isinstance(payload, dict):
        raise RuntimeError(f"snapshot record at {record_path} is invalid")
    return WorkspaceSnapshotRecord.from_payload(payload)
def _load_workspace_snapshot_locked_optional(
    self,
    workspace_id: str,
    snapshot_name: str,
) -> WorkspaceSnapshotRecord | None:
    """Load the metadata record of one named snapshot, or ``None`` if absent.

    Args:
        workspace_id: Workspace that owns the snapshot.
        snapshot_name: Already-normalized snapshot name.

    Returns:
        The parsed snapshot record, or ``None`` when no metadata file exists.

    Raises:
        RuntimeError: if the metadata file is not valid UTF-8 JSON or does not
            contain a JSON object.
    """
    record_path = self._workspace_snapshot_metadata_path(workspace_id, snapshot_name)
    try:
        # Read directly instead of checking exists() first, so a file removed
        # between the check and the read is treated as "absent".
        payload = json.loads(record_path.read_text(encoding="utf-8"))
    except (FileNotFoundError, NotADirectoryError):
        return None
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Report corrupt records the same way as non-object payloads below.
        raise RuntimeError(f"snapshot record at {record_path} is invalid") from exc
    if not isinstance(payload, dict):
        raise RuntimeError(f"snapshot record at {record_path} is invalid")
    return WorkspaceSnapshotRecord.from_payload(payload)
def _load_workspace_service_locked_optional(
self,
workspace_id: str,
@ -3885,6 +4225,17 @@ class VmManager:
encoding="utf-8",
)
def _save_workspace_snapshot_locked(self, snapshot: WorkspaceSnapshotRecord) -> None:
    """Write a snapshot's payload as indented, key-sorted JSON to its metadata file.

    The parent directory is created when it does not exist yet.
    """
    target = self._workspace_snapshot_metadata_path(
        snapshot.workspace_id,
        snapshot.snapshot_name,
    )
    target.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(snapshot.to_payload(), indent=2, sort_keys=True)
    target.write_text(serialized, encoding="utf-8")
def _delete_workspace_service_artifacts_locked(
self,
workspace_id: str,
@ -3897,6 +4248,10 @@ class VmManager:
_workspace_service_status_path(services_dir, service_name).unlink(missing_ok=True)
_workspace_service_runner_path(services_dir, service_name).unlink(missing_ok=True)
def _delete_workspace_snapshot_locked(self, workspace_id: str, snapshot_name: str) -> None:
    """Remove a snapshot's metadata file, then its archive; missing files are ignored."""
    locators = (
        self._workspace_snapshot_metadata_path,
        self._workspace_snapshot_archive_path,
    )
    for locate in locators:
        locate(workspace_id, snapshot_name).unlink(missing_ok=True)
def _list_workspace_services_locked(self, workspace_id: str) -> list[WorkspaceServiceRecord]:
services_dir = self._workspace_services_dir(workspace_id)
if not services_dir.exists():
@ -3909,6 +4264,26 @@ class VmManager:
services.append(WorkspaceServiceRecord.from_payload(payload))
return services
def _list_workspace_snapshots_locked(
    self,
    workspace: WorkspaceRecord,
) -> list[WorkspaceSnapshotRecord]:
    """Return the baseline record followed by all named snapshot records.

    Named snapshots are read from every ``*.json`` file in the snapshots
    directory, newest first, with ties broken by snapshot name. Files whose
    JSON content is not an object are skipped.
    """
    snapshots_dir = self._workspace_snapshots_dir(workspace.workspace_id)
    baseline = self._workspace_baseline_snapshot_locked(workspace)
    if not snapshots_dir.exists():
        return [baseline]

    # Lazy generator: each file is read and converted before the next is opened.
    payloads = (
        json.loads(record_path.read_text(encoding="utf-8"))
        for record_path in snapshots_dir.glob("*.json")
    )
    named = [
        WorkspaceSnapshotRecord.from_payload(item)
        for item in payloads
        if isinstance(item, dict)
    ]
    ordered = sorted(named, key=lambda item: (-item.created_at, item.snapshot_name))
    return [baseline, *ordered]
def _save_workspace_shell_locked(self, shell: WorkspaceShellRecord) -> None:
record_path = self._workspace_shell_record_path(shell.workspace_id, shell.shell_id)
record_path.parent.mkdir(parents=True, exist_ok=True)
@ -3950,6 +4325,17 @@ class VmManager:
pass
self._delete_workspace_shell_locked(workspace.workspace_id, shell.shell_id)
def _reset_workspace_runtime_dirs(self, workspace_id: str) -> None:
    """Remove a workspace's runtime, host, commands, shells and services directories.

    All five trees are removed first (removal errors are ignored); the host,
    commands, shells and services directories are then recreated empty. The
    runtime directory is not recreated here.
    """
    recreated_locators = (
        self._workspace_host_dir,
        self._workspace_commands_dir,
        self._workspace_shells_dir,
        self._workspace_services_dir,
    )
    for locate in (self._workspace_runtime_dir, *recreated_locators):
        shutil.rmtree(locate(workspace_id), ignore_errors=True)
    for locate in recreated_locators:
        locate(workspace_id).mkdir(parents=True, exist_ok=True)
def _refresh_workspace_service_locked(
self,
workspace: WorkspaceRecord,