Add seeded task workspace creation

Until now, persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces.

Expose source_path on task create across the CLI, SDK, and MCP, add safe archive upload and extraction support for guest and host-compat backends, persist workspace_seed metadata, and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post-`--` command reconstruction to shlex.join() so documented sh -lc task examples preserve argument boundaries.

Validation:
- uv lock
- UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py
- UV_CACHE_DIR=.uv-cache make check
- UV_CACHE_DIR=.uv-cache make dist-check
- real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
This commit is contained in:
Thales Maciel 2026-03-11 21:45:38 -03:00
parent 58df176148
commit aa886b346e
25 changed files with 1076 additions and 75 deletions

View file

@ -114,14 +114,23 @@ def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None:
)
)
created = pyro.create_task(environment="debian:12-base", allow_host_compat=True)
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
created = pyro.create_task(
environment="debian:12-base",
allow_host_compat=True,
source_path=source_dir,
)
task_id = str(created["task_id"])
executed = pyro.exec_task(task_id, command="printf 'ok\\n'")
executed = pyro.exec_task(task_id, command="cat note.txt")
status = pyro.status_task(task_id)
logs = pyro.logs_task(task_id)
deleted = pyro.delete_task(task_id)
assert executed["stdout"] == "ok\n"
assert created["workspace_seed"]["mode"] == "directory"
assert status["command_count"] == 1
assert logs["count"] == 1
assert deleted["deleted"] is True

View file

@ -60,9 +60,13 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None:
assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help
task_help = _subparser_choice(parser, "task").format_help()
assert "pyro task create debian:12" in task_help
assert "pyro task create debian:12 --source-path ./repo" in task_help
assert "pyro task exec TASK_ID" in task_help
task_create_help = _subparser_choice(_subparser_choice(parser, "task"), "create").format_help()
assert "--source-path" in task_create_help
assert "seed into `/workspace`" in task_create_help
task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help()
assert "persistent `/workspace`" in task_exec_help
assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help
@ -326,12 +330,20 @@ def test_cli_requires_run_command() -> None:
cli._require_command([])
def test_cli_requires_command_preserves_shell_argument_boundaries() -> None:
    """The `sh -lc` script argument must come back as one shell-quoted word."""
    argv = ["--", "sh", "-lc", 'printf "hello from task\\n" > note.txt']
    rebuilt = cli._require_command(argv)
    # The whole script is wrapped in single quotes so it stays a single argument.
    expected = 'sh -lc \'printf "hello from task\\n" > note.txt\''
    assert rebuilt == expected
def test_cli_task_create_prints_json(
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
) -> None:
class StubPyro:
def create_task(self, **kwargs: Any) -> dict[str, Any]:
assert kwargs["environment"] == "debian:12"
assert kwargs["source_path"] == "./repo"
return {"task_id": "task-123", "state": "started"}
class StubParser:
@ -345,6 +357,7 @@ def test_cli_task_create_prints_json(
ttl_seconds=600,
network=False,
allow_host_compat=False,
source_path="./repo",
json=True,
)
@ -366,6 +379,13 @@ def test_cli_task_create_prints_human(
"environment": "debian:12",
"state": "started",
"workspace_path": "/workspace",
"workspace_seed": {
"mode": "directory",
"source_path": "/tmp/repo",
"destination": "/workspace",
"entry_count": 1,
"bytes_written": 6,
},
"execution_mode": "guest_vsock",
"vcpu_count": 1,
"mem_mib": 1024,
@ -384,6 +404,7 @@ def test_cli_task_create_prints_human(
ttl_seconds=600,
network=False,
allow_host_compat=False,
source_path="/tmp/repo",
json=False,
)
@ -393,6 +414,7 @@ def test_cli_task_create_prints_human(
output = capsys.readouterr().out
assert "Task: task-123" in output
assert "Workspace: /workspace" in output
assert "Workspace seed: directory from /tmp/repo" in output
def test_cli_task_exec_prints_human_output(

View file

@ -17,6 +17,7 @@ from pyro_mcp.contract import (
PUBLIC_CLI_DEMO_SUBCOMMANDS,
PUBLIC_CLI_ENV_SUBCOMMANDS,
PUBLIC_CLI_RUN_FLAGS,
PUBLIC_CLI_TASK_CREATE_FLAGS,
PUBLIC_CLI_TASK_SUBCOMMANDS,
PUBLIC_MCP_TOOLS,
PUBLIC_SDK_METHODS,
@ -67,6 +68,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
task_help_text = _subparser_choice(parser, "task").format_help()
for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS:
assert subcommand_name in task_help_text
task_create_help_text = _subparser_choice(
_subparser_choice(parser, "task"), "create"
).format_help()
for flag in PUBLIC_CLI_TASK_CREATE_FLAGS:
assert flag in task_create_help_text
demo_help_text = _subparser_choice(parser, "demo").format_help()
for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:

View file

@ -171,6 +171,9 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
base_dir=tmp_path / "vms",
network_manager=TapNetworkManager(enabled=False),
)
source_dir = tmp_path / "seed"
source_dir.mkdir()
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
def _extract_structured(raw_result: object) -> dict[str, Any]:
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
@ -188,6 +191,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
{
"environment": "debian:12-base",
"allow_host_compat": True,
"source_path": str(source_dir),
},
)
)
@ -197,7 +201,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
"task_exec",
{
"task_id": task_id,
"command": "printf 'ok\\n'",
"command": "cat note.txt",
},
)
)
@ -207,6 +211,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
created, executed, logs, deleted = asyncio.run(_run())
assert created["state"] == "started"
assert created["workspace_seed"]["mode"] == "directory"
assert executed["stdout"] == "ok\n"
assert logs["count"] == 1
assert deleted["deleted"] is True

View file

@ -1,6 +1,10 @@
from __future__ import annotations
import io
import json
import socket
import tarfile
from pathlib import Path
import pytest
@ -62,6 +66,45 @@ def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
assert stub.closed is True
def test_vsock_exec_client_upload_archive_round_trip(
    monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """Upload a one-file archive and check the request framing and parsed reply."""
    # AF_VSOCK may be missing on the host platform; patch in a stand-in value.
    monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)

    archive_path = tmp_path / "seed.tgz"
    content = b"hello\n"
    member = tarfile.TarInfo(name="note.txt")
    member.size = len(content)
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.addfile(member, io.BytesIO(content))

    canned_reply = b'{"destination":"/workspace","entry_count":1,"bytes_written":6}'
    fake_socket = StubSocket(canned_reply)

    def socket_factory(family: int, sock_type: int) -> StubSocket:
        assert family == socket.AF_VSOCK
        assert sock_type == socket.SOCK_STREAM
        return fake_socket

    client = VsockExecClient(socket_factory=socket_factory)
    response = client.upload_archive(
        1234,
        5005,
        archive_path,
        destination="/workspace",
        timeout_seconds=60,
    )

    # Wire format under test: one JSON header line, then the raw archive bytes.
    header_bytes, body_bytes = fake_socket.sent.split(b"\n", 1)
    header = json.loads(header_bytes.decode("utf-8"))
    assert header["action"] == "extract_archive"
    assert header["destination"] == "/workspace"
    assert int(header["archive_size"]) == archive_path.stat().st_size
    assert body_bytes == archive_path.read_bytes()

    assert response.entry_count == 1
    assert response.bytes_written == 6
    assert fake_socket.closed is True
def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
stub = StubSocket(b"[]")

View file

@ -1,7 +1,9 @@
from __future__ import annotations
import io
import json
import subprocess
import tarfile
import time
from pathlib import Path
from typing import Any
@ -306,6 +308,140 @@ def test_task_lifecycle_and_logs(tmp_path: Path) -> None:
manager.status_task(task_id)
def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None:
    """Seeding from a host directory makes its files readable inside the task."""
    source_dir = tmp_path / "seed"
    source_dir.mkdir()
    (source_dir / "note.txt").write_text("hello\n", encoding="utf-8")

    manager = VmManager(
        backend_name="mock",
        base_dir=tmp_path / "vms",
        network_manager=TapNetworkManager(enabled=False),
    )
    created = manager.create_task(
        environment="debian:12-base",
        allow_host_compat=True,
        source_path=source_dir,
    )
    task_id = str(created["task_id"])

    # The create response reports the seed mode and the resolved source path.
    seed_at_create = created["workspace_seed"]
    assert seed_at_create["mode"] == "directory"
    assert seed_at_create["source_path"] == str(source_dir.resolve())

    result = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
    assert result["stdout"] == "hello\n"

    # A later status call must report the same seed metadata.
    seed_at_status = manager.status_task(task_id)["workspace_seed"]
    assert seed_at_status["mode"] == "directory"
    assert seed_at_status["source_path"] == str(source_dir.resolve())
def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None:
    """Seeding from a gzip tar archive unpacks its members into the task workspace."""
    archive_path = tmp_path / "seed.tgz"
    staging_dir = tmp_path / "src"
    staging_dir.mkdir()
    note_file = staging_dir / "note.txt"
    note_file.write_text("archive\n", encoding="utf-8")
    # Store the file at the archive root so it lands directly in the workspace.
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.add(note_file, arcname="note.txt")

    manager = VmManager(
        backend_name="mock",
        base_dir=tmp_path / "vms",
        network_manager=TapNetworkManager(enabled=False),
    )
    created = manager.create_task(
        environment="debian:12-base",
        allow_host_compat=True,
        source_path=archive_path,
    )
    task_id = str(created["task_id"])

    assert created["workspace_seed"]["mode"] == "tar_archive"
    result = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
    assert result["stdout"] == "archive\n"
def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None:
    """An archive member named with `..` is refused and no task directory remains."""
    archive_path = tmp_path / "bad.tgz"
    content = b"bad\n"
    # The member name points outside the extraction root.
    escaping_member = tarfile.TarInfo(name="../escape.txt")
    escaping_member.size = len(content)
    with tarfile.open(archive_path, "w:gz") as tar:
        tar.addfile(escaping_member, io.BytesIO(content))

    vms_root = tmp_path / "vms"
    manager = VmManager(
        backend_name="mock",
        base_dir=vms_root,
        network_manager=TapNetworkManager(enabled=False),
    )

    with pytest.raises(RuntimeError, match="unsafe archive member path"):
        manager.create_task(
            environment="debian:12-base",
            allow_host_compat=True,
            source_path=archive_path,
        )

    # A failed create must not leave a task directory behind.
    leftover_tasks = list((vms_root / "tasks").iterdir())
    assert leftover_tasks == []
def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None:
    """An archive that writes a file beneath one of its own symlinks is refused."""
    archive_path = tmp_path / "bad-symlink.tgz"

    # First member: a symlink called "linked" whose target is "outside".
    link_member = tarfile.TarInfo(name="linked")
    link_member.type = tarfile.SYMTYPE
    link_member.linkname = "outside"

    # Second member: a regular file whose path passes through that symlink.
    content = b"bad\n"
    nested_member = tarfile.TarInfo(name="linked/note.txt")
    nested_member.size = len(content)

    with tarfile.open(archive_path, "w:gz") as tar:
        tar.addfile(link_member)
        tar.addfile(nested_member, io.BytesIO(content))

    manager = VmManager(
        backend_name="mock",
        base_dir=tmp_path / "vms",
        network_manager=TapNetworkManager(enabled=False),
    )

    with pytest.raises(RuntimeError, match="traverse through a symlinked path"):
        manager.create_task(
            environment="debian:12-base",
            allow_host_compat=True,
            source_path=archive_path,
        )
def test_task_create_cleans_up_on_seed_failure(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """If the backend's archive import raises, the error propagates and no task dir is left."""
    source_dir = tmp_path / "seed"
    source_dir.mkdir()
    (source_dir / "note.txt").write_text("hello\n", encoding="utf-8")

    vms_root = tmp_path / "vms"
    manager = VmManager(
        backend_name="mock",
        base_dir=vms_root,
        network_manager=TapNetworkManager(enabled=False),
    )

    def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]:
        # Stand-in for the backend import step that always fails.
        del args, kwargs
        raise RuntimeError("seed import failed")

    monkeypatch.setattr(manager._backend, "import_archive", _boom)  # noqa: SLF001

    with pytest.raises(RuntimeError, match="seed import failed"):
        manager.create_task(
            environment="debian:12-base",
            allow_host_compat=True,
            source_path=source_dir,
        )

    # The failed create must not leave a task directory behind.
    leftover_tasks = list((vms_root / "tasks").iterdir())
    assert leftover_tasks == []
def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None:
base_dir = tmp_path / "vms"
manager = VmManager(