Add seeded task workspace creation
Current persistent tasks started with an empty workspace, which blocked the first useful host-to-task workflow in the task roadmap. This change lets task creation start from a host directory or tar archive without changing the one-shot VM surfaces. Expose source_path on task create across the CLI, SDK, and MCP, add safe archive upload and extraction support for guest and host-compat backends, persist workspace_seed metadata, and patch the per-task rootfs with the bundled guest agent before boot so seeded guest tasks work without republishing environments. Also switch post--- command reconstruction to shlex.join() so documented sh -lc task examples preserve argument boundaries. Validation: - uv lock - UV_CACHE_DIR=.uv-cache uv run pytest --no-cov tests/test_vm_guest.py tests/test_vm_manager.py tests/test_cli.py tests/test_api.py tests/test_server.py tests/test_public_contract.py - UV_CACHE_DIR=.uv-cache make check - UV_CACHE_DIR=.uv-cache make dist-check - real guest-backed smoke: task create --source-path, task exec -- cat note.txt, task delete
This commit is contained in:
parent
58df176148
commit
aa886b346e
25 changed files with 1076 additions and 75 deletions
|
|
@ -114,14 +114,23 @@ def test_pyro_task_methods_delegate_to_manager(tmp_path: Path) -> None:
|
|||
)
|
||||
)
|
||||
|
||||
created = pyro.create_task(environment="debian:12-base", allow_host_compat=True)
|
||||
source_dir = tmp_path / "seed"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
|
||||
|
||||
created = pyro.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=source_dir,
|
||||
)
|
||||
task_id = str(created["task_id"])
|
||||
executed = pyro.exec_task(task_id, command="printf 'ok\\n'")
|
||||
executed = pyro.exec_task(task_id, command="cat note.txt")
|
||||
status = pyro.status_task(task_id)
|
||||
logs = pyro.logs_task(task_id)
|
||||
deleted = pyro.delete_task(task_id)
|
||||
|
||||
assert executed["stdout"] == "ok\n"
|
||||
assert created["workspace_seed"]["mode"] == "directory"
|
||||
assert status["command_count"] == 1
|
||||
assert logs["count"] == 1
|
||||
assert deleted["deleted"] is True
|
||||
|
|
|
|||
|
|
@ -60,9 +60,13 @@ def test_cli_subcommand_help_includes_examples_and_guidance() -> None:
|
|||
assert "Use this from an MCP client config after the CLI evaluation path works." in mcp_help
|
||||
|
||||
task_help = _subparser_choice(parser, "task").format_help()
|
||||
assert "pyro task create debian:12" in task_help
|
||||
assert "pyro task create debian:12 --source-path ./repo" in task_help
|
||||
assert "pyro task exec TASK_ID" in task_help
|
||||
|
||||
task_create_help = _subparser_choice(_subparser_choice(parser, "task"), "create").format_help()
|
||||
assert "--source-path" in task_create_help
|
||||
assert "seed into `/workspace`" in task_create_help
|
||||
|
||||
task_exec_help = _subparser_choice(_subparser_choice(parser, "task"), "exec").format_help()
|
||||
assert "persistent `/workspace`" in task_exec_help
|
||||
assert "pyro task exec TASK_ID -- cat note.txt" in task_exec_help
|
||||
|
|
@ -326,12 +330,20 @@ def test_cli_requires_run_command() -> None:
|
|||
cli._require_command([])
|
||||
|
||||
|
||||
def test_cli_requires_command_preserves_shell_argument_boundaries() -> None:
|
||||
command = cli._require_command(
|
||||
["--", "sh", "-lc", 'printf "hello from task\\n" > note.txt']
|
||||
)
|
||||
assert command == 'sh -lc \'printf "hello from task\\n" > note.txt\''
|
||||
|
||||
|
||||
def test_cli_task_create_prints_json(
|
||||
monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]
|
||||
) -> None:
|
||||
class StubPyro:
|
||||
def create_task(self, **kwargs: Any) -> dict[str, Any]:
|
||||
assert kwargs["environment"] == "debian:12"
|
||||
assert kwargs["source_path"] == "./repo"
|
||||
return {"task_id": "task-123", "state": "started"}
|
||||
|
||||
class StubParser:
|
||||
|
|
@ -345,6 +357,7 @@ def test_cli_task_create_prints_json(
|
|||
ttl_seconds=600,
|
||||
network=False,
|
||||
allow_host_compat=False,
|
||||
source_path="./repo",
|
||||
json=True,
|
||||
)
|
||||
|
||||
|
|
@ -366,6 +379,13 @@ def test_cli_task_create_prints_human(
|
|||
"environment": "debian:12",
|
||||
"state": "started",
|
||||
"workspace_path": "/workspace",
|
||||
"workspace_seed": {
|
||||
"mode": "directory",
|
||||
"source_path": "/tmp/repo",
|
||||
"destination": "/workspace",
|
||||
"entry_count": 1,
|
||||
"bytes_written": 6,
|
||||
},
|
||||
"execution_mode": "guest_vsock",
|
||||
"vcpu_count": 1,
|
||||
"mem_mib": 1024,
|
||||
|
|
@ -384,6 +404,7 @@ def test_cli_task_create_prints_human(
|
|||
ttl_seconds=600,
|
||||
network=False,
|
||||
allow_host_compat=False,
|
||||
source_path="/tmp/repo",
|
||||
json=False,
|
||||
)
|
||||
|
||||
|
|
@ -393,6 +414,7 @@ def test_cli_task_create_prints_human(
|
|||
output = capsys.readouterr().out
|
||||
assert "Task: task-123" in output
|
||||
assert "Workspace: /workspace" in output
|
||||
assert "Workspace seed: directory from /tmp/repo" in output
|
||||
|
||||
|
||||
def test_cli_task_exec_prints_human_output(
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from pyro_mcp.contract import (
|
|||
PUBLIC_CLI_DEMO_SUBCOMMANDS,
|
||||
PUBLIC_CLI_ENV_SUBCOMMANDS,
|
||||
PUBLIC_CLI_RUN_FLAGS,
|
||||
PUBLIC_CLI_TASK_CREATE_FLAGS,
|
||||
PUBLIC_CLI_TASK_SUBCOMMANDS,
|
||||
PUBLIC_MCP_TOOLS,
|
||||
PUBLIC_SDK_METHODS,
|
||||
|
|
@ -67,6 +68,11 @@ def test_public_cli_help_lists_commands_and_run_flags() -> None:
|
|||
task_help_text = _subparser_choice(parser, "task").format_help()
|
||||
for subcommand_name in PUBLIC_CLI_TASK_SUBCOMMANDS:
|
||||
assert subcommand_name in task_help_text
|
||||
task_create_help_text = _subparser_choice(
|
||||
_subparser_choice(parser, "task"), "create"
|
||||
).format_help()
|
||||
for flag in PUBLIC_CLI_TASK_CREATE_FLAGS:
|
||||
assert flag in task_create_help_text
|
||||
|
||||
demo_help_text = _subparser_choice(parser, "demo").format_help()
|
||||
for subcommand_name in PUBLIC_CLI_DEMO_SUBCOMMANDS:
|
||||
|
|
|
|||
|
|
@ -171,6 +171,9 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
|
|||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
source_dir = tmp_path / "seed"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "note.txt").write_text("ok\n", encoding="utf-8")
|
||||
|
||||
def _extract_structured(raw_result: object) -> dict[str, Any]:
|
||||
if not isinstance(raw_result, tuple) or len(raw_result) != 2:
|
||||
|
|
@ -188,6 +191,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
|
|||
{
|
||||
"environment": "debian:12-base",
|
||||
"allow_host_compat": True,
|
||||
"source_path": str(source_dir),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -197,7 +201,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
|
|||
"task_exec",
|
||||
{
|
||||
"task_id": task_id,
|
||||
"command": "printf 'ok\\n'",
|
||||
"command": "cat note.txt",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
|
@ -207,6 +211,7 @@ def test_task_tools_round_trip(tmp_path: Path) -> None:
|
|||
|
||||
created, executed, logs, deleted = asyncio.run(_run())
|
||||
assert created["state"] == "started"
|
||||
assert created["workspace_seed"]["mode"] == "directory"
|
||||
assert executed["stdout"] == "ok\n"
|
||||
assert logs["count"] == 1
|
||||
assert deleted["deleted"] is True
|
||||
|
|
|
|||
|
|
@ -1,6 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import socket
|
||||
import tarfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -62,6 +66,45 @@ def test_vsock_exec_client_round_trip(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||
assert stub.closed is True
|
||||
|
||||
|
||||
def test_vsock_exec_client_upload_archive_round_trip(
|
||||
monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
|
||||
archive_path = tmp_path / "seed.tgz"
|
||||
with tarfile.open(archive_path, "w:gz") as archive:
|
||||
payload = b"hello\n"
|
||||
info = tarfile.TarInfo(name="note.txt")
|
||||
info.size = len(payload)
|
||||
archive.addfile(info, io.BytesIO(payload))
|
||||
stub = StubSocket(
|
||||
b'{"destination":"/workspace","entry_count":1,"bytes_written":6}'
|
||||
)
|
||||
|
||||
def socket_factory(family: int, sock_type: int) -> StubSocket:
|
||||
assert family == socket.AF_VSOCK
|
||||
assert sock_type == socket.SOCK_STREAM
|
||||
return stub
|
||||
|
||||
client = VsockExecClient(socket_factory=socket_factory)
|
||||
response = client.upload_archive(
|
||||
1234,
|
||||
5005,
|
||||
archive_path,
|
||||
destination="/workspace",
|
||||
timeout_seconds=60,
|
||||
)
|
||||
|
||||
request_payload, archive_payload = stub.sent.split(b"\n", 1)
|
||||
request = json.loads(request_payload.decode("utf-8"))
|
||||
assert request["action"] == "extract_archive"
|
||||
assert request["destination"] == "/workspace"
|
||||
assert int(request["archive_size"]) == archive_path.stat().st_size
|
||||
assert archive_payload == archive_path.read_bytes()
|
||||
assert response.entry_count == 1
|
||||
assert response.bytes_written == 6
|
||||
assert stub.closed is True
|
||||
|
||||
|
||||
def test_vsock_exec_client_rejects_bad_json(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(socket, "AF_VSOCK", 40, raising=False)
|
||||
stub = StubSocket(b"[]")
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import subprocess
|
||||
import tarfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
|
@ -306,6 +308,140 @@ def test_task_lifecycle_and_logs(tmp_path: Path) -> None:
|
|||
manager.status_task(task_id)
|
||||
|
||||
|
||||
def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None:
|
||||
source_dir = tmp_path / "seed"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
|
||||
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
created = manager.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=source_dir,
|
||||
)
|
||||
task_id = str(created["task_id"])
|
||||
|
||||
workspace_seed = created["workspace_seed"]
|
||||
assert workspace_seed["mode"] == "directory"
|
||||
assert workspace_seed["source_path"] == str(source_dir.resolve())
|
||||
executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
|
||||
assert executed["stdout"] == "hello\n"
|
||||
|
||||
status = manager.status_task(task_id)
|
||||
assert status["workspace_seed"]["mode"] == "directory"
|
||||
assert status["workspace_seed"]["source_path"] == str(source_dir.resolve())
|
||||
|
||||
|
||||
def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None:
|
||||
archive_path = tmp_path / "seed.tgz"
|
||||
nested_dir = tmp_path / "src"
|
||||
nested_dir.mkdir()
|
||||
(nested_dir / "note.txt").write_text("archive\n", encoding="utf-8")
|
||||
with tarfile.open(archive_path, "w:gz") as archive:
|
||||
archive.add(nested_dir / "note.txt", arcname="note.txt")
|
||||
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
created = manager.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=archive_path,
|
||||
)
|
||||
task_id = str(created["task_id"])
|
||||
|
||||
assert created["workspace_seed"]["mode"] == "tar_archive"
|
||||
executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30)
|
||||
assert executed["stdout"] == "archive\n"
|
||||
|
||||
|
||||
def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None:
|
||||
archive_path = tmp_path / "bad.tgz"
|
||||
with tarfile.open(archive_path, "w:gz") as archive:
|
||||
payload = b"bad\n"
|
||||
info = tarfile.TarInfo(name="../escape.txt")
|
||||
info.size = len(payload)
|
||||
archive.addfile(info, io.BytesIO(payload))
|
||||
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="unsafe archive member path"):
|
||||
manager.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=archive_path,
|
||||
)
|
||||
assert list((tmp_path / "vms" / "tasks").iterdir()) == []
|
||||
|
||||
|
||||
def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None:
|
||||
archive_path = tmp_path / "bad-symlink.tgz"
|
||||
with tarfile.open(archive_path, "w:gz") as archive:
|
||||
symlink_info = tarfile.TarInfo(name="linked")
|
||||
symlink_info.type = tarfile.SYMTYPE
|
||||
symlink_info.linkname = "outside"
|
||||
archive.addfile(symlink_info)
|
||||
|
||||
payload = b"bad\n"
|
||||
file_info = tarfile.TarInfo(name="linked/note.txt")
|
||||
file_info.size = len(payload)
|
||||
archive.addfile(file_info, io.BytesIO(payload))
|
||||
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="traverse through a symlinked path"):
|
||||
manager.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=archive_path,
|
||||
)
|
||||
|
||||
|
||||
def test_task_create_cleans_up_on_seed_failure(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
source_dir = tmp_path / "seed"
|
||||
source_dir.mkdir()
|
||||
(source_dir / "note.txt").write_text("hello\n", encoding="utf-8")
|
||||
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
del args, kwargs
|
||||
raise RuntimeError("seed import failed")
|
||||
|
||||
monkeypatch.setattr(manager._backend, "import_archive", _boom) # noqa: SLF001
|
||||
|
||||
with pytest.raises(RuntimeError, match="seed import failed"):
|
||||
manager.create_task(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
source_path=source_dir,
|
||||
)
|
||||
|
||||
assert list((tmp_path / "vms" / "tasks").iterdir()) == []
|
||||
|
||||
|
||||
def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None:
|
||||
base_dir = tmp_path / "vms"
|
||||
manager = VmManager(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue