Add task sync push milestone

Tasks could start from host content in 2.2.0, but there was still no post-create path to update a live workspace from the host. This change adds the next host-to-task step so repeated fix or review loops do not require recreating the task for every local change.

Add task sync push across the CLI, Python SDK, and MCP server, reusing the existing safe archive import path from seeded task creation instead of introducing a second transfer stack. The implementation keeps sync separate from workspace_seed metadata, validates destinations under /workspace, and documents the current non-atomic recovery path as delete-and-recreate.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_cli.py tests/test_vm_manager.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task sync push, task exec to verify both files, task delete
This commit is contained in:
Thales Maciel 2026-03-11 22:20:55 -03:00
parent aa886b346e
commit 9e11dcf9ab
19 changed files with 461 additions and 41 deletions

View file

@ -110,6 +110,15 @@ class Pyro:
def status_task(self, task_id: str) -> dict[str, Any]:
    """Return task state and latest command metadata (delegates to the manager)."""
    return self._manager.status_task(task_id)
def push_task_sync(
    self,
    task_id: str,
    source_path: str | Path,
    *,
    dest: str = "/workspace",
) -> dict[str, Any]:
    """Push host content from *source_path* into a started task's workspace.

    Thin SDK wrapper that delegates to the manager's ``push_task_sync``.
    ``dest`` is keyword-only and defaults to ``/workspace``; destination
    resolution/validation happens in the manager, not here.
    """
    return self._manager.push_task_sync(task_id, source_path=source_path, dest=dest)
def logs_task(self, task_id: str) -> dict[str, Any]:
    """Return the log payload for *task_id* (delegates to the manager)."""
    return self._manager.logs_task(task_id)
@ -269,6 +278,15 @@ class Pyro:
"""Run one command inside an existing task workspace."""
return self.exec_task(task_id, command=command, timeout_seconds=timeout_seconds)
@server.tool()
async def task_sync_push(
    task_id: str,
    source_path: str,
    dest: str = "/workspace",
) -> dict[str, Any]:
    """Push host content into the persistent workspace of a started task.

    MCP tool wrapper: forwards to ``self.push_task_sync`` via closure.
    *source_path* is a host path; *dest* defaults to the workspace root.
    """
    return self.push_task_sync(task_id, source_path=source_path, dest=dest)
@server.tool()
async def task_status(task_id: str) -> dict[str, Any]:
"""Inspect task state and latest command metadata."""

View file

@ -197,6 +197,23 @@ def _print_task_exec_human(payload: dict[str, Any]) -> None:
)
def _print_task_sync_human(payload: dict[str, Any]) -> None:
workspace_sync = payload.get("workspace_sync")
if not isinstance(workspace_sync, dict):
print(f"Synced task: {str(payload.get('task_id', 'unknown'))}")
return
print(
"[task-sync] "
f"task_id={str(payload.get('task_id', 'unknown'))} "
f"mode={str(workspace_sync.get('mode', 'unknown'))} "
f"source={str(workspace_sync.get('source_path', 'unknown'))} "
f"destination={str(workspace_sync.get('destination', TASK_WORKSPACE_GUEST_PATH))} "
f"entry_count={int(workspace_sync.get('entry_count', 0))} "
f"bytes_written={int(workspace_sync.get('bytes_written', 0))} "
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}"
)
def _print_task_logs_human(payload: dict[str, Any]) -> None:
entries = payload.get("entries")
if not isinstance(entries, list) or not entries:
@ -250,7 +267,8 @@ def _build_parser() -> argparse.ArgumentParser:
pyro run debian:12 -- git --version
Need repeated commands in one workspace after that?
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
pyro task sync push TASK_ID ./changes
Use `pyro mcp serve` only after the CLI validation path works.
"""
@ -456,6 +474,7 @@ def _build_parser() -> argparse.ArgumentParser:
"""
Examples:
pyro task create debian:12 --source-path ./repo
pyro task sync push TASK_ID ./repo --dest src
pyro task exec TASK_ID -- sh -lc 'printf "hello\\n" > note.txt'
pyro task logs TASK_ID
"""
@ -472,6 +491,7 @@ def _build_parser() -> argparse.ArgumentParser:
Examples:
pyro task create debian:12
pyro task create debian:12 --source-path ./repo
pyro task sync push TASK_ID ./changes
"""
),
formatter_class=_HelpFormatter,
@ -552,6 +572,56 @@ def _build_parser() -> argparse.ArgumentParser:
"for example `pyro task exec TASK_ID -- cat note.txt`."
),
)
task_sync_parser = task_subparsers.add_parser(
"sync",
help="Push host content into a started task workspace.",
description=(
"Push host directory or archive content into `/workspace` for an existing "
"started task."
),
epilog=dedent(
"""
Examples:
pyro task sync push TASK_ID ./repo
pyro task sync push TASK_ID ./patches --dest src
Sync is non-atomic. If a sync fails partway through, delete and recreate the task.
"""
),
formatter_class=_HelpFormatter,
)
task_sync_subparsers = task_sync_parser.add_subparsers(
dest="task_sync_command",
required=True,
metavar="SYNC",
)
task_sync_push_parser = task_sync_subparsers.add_parser(
"push",
help="Push one host directory or archive into a started task.",
description="Import host content into `/workspace` or a subdirectory of it.",
epilog="Example:\n pyro task sync push TASK_ID ./repo --dest src",
formatter_class=_HelpFormatter,
)
task_sync_push_parser.add_argument(
"task_id",
metavar="TASK_ID",
help="Persistent task identifier.",
)
task_sync_push_parser.add_argument(
"source_path",
metavar="SOURCE_PATH",
help="Host directory or .tar/.tar.gz/.tgz archive to push into the task workspace.",
)
task_sync_push_parser.add_argument(
"--dest",
default=TASK_WORKSPACE_GUEST_PATH,
help="Workspace destination path. Relative values resolve inside `/workspace`.",
)
task_sync_push_parser.add_argument(
"--json",
action="store_true",
help="Print structured JSON instead of human-readable output.",
)
task_status_parser = task_subparsers.add_parser(
"status",
help="Inspect one task workspace.",
@ -821,6 +891,30 @@ def main() -> None:
if exit_code != 0:
raise SystemExit(exit_code)
return
if args.task_command == "sync" and args.task_sync_command == "push":
if bool(args.json):
try:
payload = pyro.push_task_sync(
args.task_id,
args.source_path,
dest=args.dest,
)
except Exception as exc: # noqa: BLE001
_print_json({"ok": False, "error": str(exc)})
raise SystemExit(1) from exc
_print_json(payload)
else:
try:
payload = pyro.push_task_sync(
args.task_id,
args.source_path,
dest=args.dest,
)
except Exception as exc: # noqa: BLE001
print(f"[error] {exc}", file=sys.stderr, flush=True)
raise SystemExit(1) from exc
_print_task_sync_human(payload)
return
if args.task_command == "status":
payload = pyro.status_task(args.task_id)
if bool(args.json):

View file

@ -5,7 +5,8 @@ from __future__ import annotations
# Frozen public-contract tuples: contract tests compare the real CLI
# surface against these, so additions here are deliberate API commitments.
PUBLIC_CLI_COMMANDS = ("demo", "doctor", "env", "mcp", "run", "task")
PUBLIC_CLI_DEMO_SUBCOMMANDS = ("ollama",)
PUBLIC_CLI_ENV_SUBCOMMANDS = ("inspect", "list", "pull", "prune")
# "sync" joins the task lifecycle subcommands; the stale pre-change
# assignment that shadowed this tuple has been removed.
PUBLIC_CLI_TASK_SUBCOMMANDS = ("create", "delete", "exec", "logs", "status", "sync")
PUBLIC_CLI_TASK_SYNC_SUBCOMMANDS = ("push",)
PUBLIC_CLI_TASK_CREATE_FLAGS = (
"--vcpu-count",
"--mem-mib",
@ -15,6 +16,7 @@ PUBLIC_CLI_TASK_CREATE_FLAGS = (
"--source-path",
"--json",
)
PUBLIC_CLI_TASK_SYNC_PUSH_FLAGS = ("--dest", "--json")
PUBLIC_CLI_RUN_FLAGS = (
"--vcpu-count",
"--mem-mib",
@ -39,6 +41,7 @@ PUBLIC_SDK_METHODS = (
"network_info_vm",
"prune_environments",
"pull_environment",
"push_task_sync",
"reap_expired",
"run_in_vm",
"start_vm",
@ -63,4 +66,5 @@ PUBLIC_MCP_TOOLS = (
"task_exec",
"task_logs",
"task_status",
"task_sync_push",
)

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
# Default versions stamped into generated environment/catalog metadata.
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
# Bumped 2.2.0 -> 2.3.0 for the task sync push capability; the stale
# pre-change assignment has been removed.
DEFAULT_CATALOG_VERSION = "2.3.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -194,11 +194,11 @@ class PreparedWorkspaceSeed:
bytes_written: int = 0
cleanup_dir: Path | None = None
def to_payload(self, *, destination: str = TASK_WORKSPACE_GUEST_PATH) -> dict[str, Any]:
    """Return a JSON-safe summary of this prepared workspace seed.

    *destination* is keyword-only and defaults to the workspace root so the
    original seeded-create callers keep their payload shape unchanged; the
    sync path passes its normalized destination explicitly. The stale
    pre-change definition that duplicated this method's header and
    hard-coded the destination has been removed.
    """
    return {
        "mode": self.mode,
        "source_path": self.source_path,
        "destination": destination,
        "entry_count": self.entry_count,
        "bytes_written": self.bytes_written,
    }
@ -372,6 +372,8 @@ def _normalize_workspace_destination(destination: str) -> tuple[str, PurePosixPa
if candidate == "":
raise ValueError("workspace destination must not be empty")
destination_path = PurePosixPath(candidate)
if any(part == ".." for part in destination_path.parts):
raise ValueError("workspace destination must stay inside /workspace")
workspace_root = PurePosixPath(TASK_WORKSPACE_GUEST_PATH)
if not destination_path.is_absolute():
destination_path = workspace_root / destination_path
@ -1218,6 +1220,52 @@ class VmManager:
finally:
prepared_seed.cleanup()
def push_task_sync(
    self,
    task_id: str,
    *,
    source_path: str | Path,
    dest: str = TASK_WORKSPACE_GUEST_PATH,
) -> dict[str, Any]:
    """Push host content into the workspace of a started task.

    Reuses the seeded-create import path: *source_path* (host directory or
    supported archive) is packed by ``_prepare_workspace_seed`` and imported
    by the backend at *dest*, which ``_normalize_workspace_destination``
    validates to stay inside the workspace root.

    Raises ValueError for a missing source or invalid destination, and
    RuntimeError when the task is not in 'started' state. The transfer is
    non-atomic; the documented recovery for a partial failure is
    delete-and-recreate.
    """
    prepared_seed = self._prepare_workspace_seed(source_path)
    try:
        # Everything from validation through import is covered by the
        # finally below; previously only the import itself was, so a
        # bad destination or wrong task state leaked the temp archive.
        if prepared_seed.archive_path is None:
            raise ValueError("source_path is required")
        normalized_destination, _ = _normalize_workspace_destination(dest)
        with self._lock:
            task = self._load_task_locked(task_id)
            self._ensure_task_not_expired_locked(task, time.time())
            self._refresh_task_liveness_locked(task)
            if task.state != "started":
                raise RuntimeError(
                    f"task {task_id} must be in 'started' state before task_sync_push"
                )
            instance = task.to_instance(workdir=self._task_runtime_dir(task.task_id))
        # The archive import runs outside the lock (it can be slow); the
        # lock is re-taken below to persist the refreshed instance state.
        import_summary = self._backend.import_archive(
            instance,
            archive_path=prepared_seed.archive_path,
            destination=normalized_destination,
        )
    finally:
        prepared_seed.cleanup()
    # cleanup() only removes temp files; the seed's summary fields stay
    # readable. The backend's actual counts override the seed's estimates.
    workspace_sync = prepared_seed.to_payload(destination=normalized_destination)
    workspace_sync["entry_count"] = int(import_summary["entry_count"])
    workspace_sync["bytes_written"] = int(import_summary["bytes_written"])
    workspace_sync["destination"] = str(import_summary["destination"])
    with self._lock:
        task = self._load_task_locked(task_id)
        task.state = instance.state
        task.firecracker_pid = instance.firecracker_pid
        task.last_error = instance.last_error
        task.metadata = dict(instance.metadata)
        self._save_task_locked(task)
    return {
        "task_id": task_id,
        "execution_mode": instance.metadata.get("execution_mode", "pending"),
        "workspace_sync": workspace_sync,
    }
def exec_task(self, task_id: str, *, command: str, timeout_seconds: int = 30) -> dict[str, Any]:
if timeout_seconds <= 0:
raise ValueError("timeout_seconds must be positive")