from __future__ import annotations import io import json import subprocess import tarfile import time from pathlib import Path from typing import Any import pytest import pyro_mcp.vm_manager as vm_manager_module from pyro_mcp.runtime import RuntimeCapabilities, resolve_runtime_paths from pyro_mcp.vm_manager import VmManager from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager def test_vm_manager_lifecycle_and_auto_cleanup(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) created = manager.create_vm( environment="debian:12", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, ) vm_id = str(created["vm_id"]) started = manager.start_vm(vm_id) assert started["state"] == "started" executed = manager.exec_vm(vm_id, command="printf 'git version 2.43.0\\n'", timeout_seconds=30) assert executed["exit_code"] == 0 assert executed["execution_mode"] == "host_compat" assert "git version" in str(executed["stdout"]) with pytest.raises(ValueError, match="does not exist"): manager.status_vm(vm_id) def test_vm_manager_exec_timeout(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, )["vm_id"] ) manager.start_vm(vm_id) result = manager.exec_vm(vm_id, command="sleep 2", timeout_seconds=1) assert result["exit_code"] == 124 assert "timed out" in str(result["stderr"]) def test_vm_manager_stop_and_delete(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, )["vm_id"] ) manager.start_vm(vm_id) stopped = manager.stop_vm(vm_id) assert stopped["state"] == "stopped" deleted = manager.delete_vm(vm_id) assert deleted["deleted"] is True def test_vm_manager_reaps_expired(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=1, allow_host_compat=True, )["vm_id"] ) instance = manager._instances[vm_id] # noqa: SLF001 instance.expires_at = 0.0 result = manager.reap_expired() assert result["count"] == 1 with pytest.raises(ValueError): manager.status_vm(vm_id) def test_vm_manager_reaps_started_vm(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=1, allow_host_compat=True, )["vm_id"] ) manager.start_vm(vm_id) manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001 result = manager.reap_expired() assert result["count"] == 1 @pytest.mark.parametrize( ("kwargs", "msg"), [ ({"vcpu_count": 0, "mem_mib": 512, "ttl_seconds": 600}, "vcpu_count must be between"), ({"vcpu_count": 1, "mem_mib": 64, "ttl_seconds": 600}, "mem_mib must be between"), ({"vcpu_count": 1, "mem_mib": 512, "ttl_seconds": 30}, "ttl_seconds must be between"), ], ) def test_vm_manager_validates_limits(tmp_path: Path, kwargs: dict[str, Any], msg: str) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) with pytest.raises(ValueError, match=msg): manager.create_vm(environment="debian:12-base", **kwargs) def test_vm_manager_max_active_limit(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", max_active_vms=1, network_manager=TapNetworkManager(enabled=False), ) manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, ) with pytest.raises(RuntimeError, match="max active VMs reached"): manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, ) def test_vm_manager_state_validation(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, )["vm_id"] ) with pytest.raises(RuntimeError, match="must be in 'started' state"): manager.exec_vm(vm_id, command="echo hi", timeout_seconds=30) with pytest.raises(ValueError, match="must be positive"): manager.exec_vm(vm_id, command="echo hi", timeout_seconds=0) manager.start_vm(vm_id) with pytest.raises(RuntimeError, match="cannot be started from state"): manager.start_vm(vm_id) def test_vm_manager_status_expired_raises(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) manager.MIN_TTL_SECONDS = 1 vm_id = str( manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=1, allow_host_compat=True, )["vm_id"] ) manager._instances[vm_id].expires_at = 0.0 # noqa: SLF001 with pytest.raises(RuntimeError, match="expired and was automatically deleted"): manager.status_vm(vm_id) def test_vm_manager_invalid_backend(tmp_path: Path) -> None: with pytest.raises(ValueError, match="invalid backend"): VmManager( backend_name="nope", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) def test_vm_manager_network_info(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) created = manager.create_vm( environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, allow_host_compat=True, ) vm_id = str(created["vm_id"]) status = manager.status_vm(vm_id) info = manager.network_info_vm(vm_id) assert status["network_enabled"] is False assert status["guest_ip"] is None assert info["network_enabled"] is False def test_vm_manager_run_vm(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) result = manager.run_vm( environment="debian:12-base", command="printf 'ok\\n'", vcpu_count=1, mem_mib=512, timeout_seconds=30, ttl_seconds=600, network=False, allow_host_compat=True, ) assert int(result["exit_code"]) == 0 assert str(result["stdout"]) == "ok\n" def test_task_lifecycle_and_logs(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) created = manager.create_task( environment="debian:12-base", allow_host_compat=True, ) task_id = str(created["task_id"]) assert created["state"] == "started" assert created["workspace_path"] == "/workspace" first = manager.exec_task( task_id, command="printf 'hello\\n' > note.txt", timeout_seconds=30, ) second = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) assert first["exit_code"] == 0 assert second["stdout"] == "hello\n" status = manager.status_task(task_id) assert status["command_count"] == 2 assert status["last_command"] is not None logs = manager.logs_task(task_id) assert logs["count"] == 2 entries = logs["entries"] assert isinstance(entries, list) assert entries[1]["stdout"] == "hello\n" deleted = manager.delete_task(task_id) assert deleted["deleted"] is True with pytest.raises(ValueError, match="does not exist"): manager.status_task(task_id) def test_task_create_seeds_directory_source_into_workspace(tmp_path: Path) -> None: source_dir = tmp_path / "seed" source_dir.mkdir() (source_dir / "note.txt").write_text("hello\n", encoding="utf-8") manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) created = manager.create_task( environment="debian:12-base", allow_host_compat=True, source_path=source_dir, ) task_id = str(created["task_id"]) workspace_seed = created["workspace_seed"] assert workspace_seed["mode"] == "directory" assert workspace_seed["source_path"] == str(source_dir.resolve()) executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) assert executed["stdout"] == "hello\n" status = manager.status_task(task_id) assert status["workspace_seed"]["mode"] == "directory" assert status["workspace_seed"]["source_path"] == str(source_dir.resolve()) def test_task_create_seeds_tar_archive_into_workspace(tmp_path: Path) -> None: archive_path = tmp_path / "seed.tgz" nested_dir = tmp_path / "src" nested_dir.mkdir() (nested_dir / "note.txt").write_text("archive\n", encoding="utf-8") with tarfile.open(archive_path, "w:gz") as archive: archive.add(nested_dir / "note.txt", arcname="note.txt") manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) created = manager.create_task( environment="debian:12-base", allow_host_compat=True, source_path=archive_path, ) task_id = str(created["task_id"]) assert created["workspace_seed"]["mode"] == "tar_archive" executed = manager.exec_task(task_id, command="cat note.txt", timeout_seconds=30) assert executed["stdout"] == "archive\n" def test_task_create_rejects_unsafe_seed_archive(tmp_path: Path) -> None: archive_path = tmp_path / "bad.tgz" with tarfile.open(archive_path, "w:gz") as archive: payload = b"bad\n" info = tarfile.TarInfo(name="../escape.txt") info.size = len(payload) archive.addfile(info, io.BytesIO(payload)) manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) with pytest.raises(RuntimeError, match="unsafe archive member path"): manager.create_task( environment="debian:12-base", allow_host_compat=True, source_path=archive_path, ) assert list((tmp_path / "vms" / "tasks").iterdir()) == [] def test_task_create_rejects_archive_that_writes_through_symlink(tmp_path: Path) -> None: archive_path = tmp_path / "bad-symlink.tgz" with tarfile.open(archive_path, "w:gz") as archive: symlink_info = tarfile.TarInfo(name="linked") symlink_info.type = tarfile.SYMTYPE symlink_info.linkname = "outside" archive.addfile(symlink_info) payload = b"bad\n" file_info = tarfile.TarInfo(name="linked/note.txt") file_info.size = len(payload) archive.addfile(file_info, io.BytesIO(payload)) manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) with pytest.raises(RuntimeError, match="traverse through a symlinked path"): manager.create_task( environment="debian:12-base", allow_host_compat=True, source_path=archive_path, ) def test_task_create_cleans_up_on_seed_failure( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: source_dir = tmp_path / "seed" source_dir.mkdir() (source_dir / "note.txt").write_text("hello\n", encoding="utf-8") manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) def _boom(*args: Any, **kwargs: Any) -> dict[str, Any]: del args, kwargs raise RuntimeError("seed import failed") monkeypatch.setattr(manager._backend, "import_archive", _boom) # noqa: SLF001 with pytest.raises(RuntimeError, match="seed import failed"): manager.create_task( environment="debian:12-base", allow_host_compat=True, source_path=source_dir, ) assert list((tmp_path / "vms" / "tasks").iterdir()) == [] def test_task_rehydrates_across_manager_processes(tmp_path: Path) -> None: base_dir = tmp_path / "vms" manager = VmManager( backend_name="mock", base_dir=base_dir, network_manager=TapNetworkManager(enabled=False), ) task_id = str( manager.create_task( environment="debian:12-base", allow_host_compat=True, )["task_id"] ) other = VmManager( backend_name="mock", base_dir=base_dir, network_manager=TapNetworkManager(enabled=False), ) executed = other.exec_task(task_id, command="printf 'ok\\n'", timeout_seconds=30) assert executed["exit_code"] == 0 assert executed["stdout"] == "ok\n" logs = other.logs_task(task_id) assert logs["count"] == 1 def test_task_requires_started_state(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) task_id = str( manager.create_task( environment="debian:12-base", allow_host_compat=True, )["task_id"] ) task_dir = tmp_path / "vms" / "tasks" / task_id / "task.json" payload = json.loads(task_dir.read_text(encoding="utf-8")) payload["state"] = "stopped" task_dir.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") with pytest.raises(RuntimeError, match="must be in 'started' state"): manager.exec_task(task_id, command="true", timeout_seconds=30) def test_vm_manager_firecracker_backend_path( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: class StubFirecrackerBackend: def __init__( self, environment_store: Any, firecracker_bin: Path, jailer_bin: Path, runtime_capabilities: Any, network_manager: TapNetworkManager, ) -> None: self.environment_store = environment_store self.firecracker_bin = firecracker_bin self.jailer_bin = jailer_bin self.runtime_capabilities = runtime_capabilities self.network_manager = network_manager def create(self, instance: Any) -> None: del instance def start(self, instance: Any) -> None: del instance def exec(self, instance: Any, command: str, timeout_seconds: int) -> Any: del instance, command, timeout_seconds return None def stop(self, instance: Any) -> None: del instance def delete(self, instance: Any) -> None: del instance monkeypatch.setattr(vm_manager_module, "FirecrackerBackend", StubFirecrackerBackend) manager = VmManager( backend_name="firecracker", base_dir=tmp_path / "vms", runtime_paths=resolve_runtime_paths(), network_manager=TapNetworkManager(enabled=False), ) assert manager._backend_name == "firecracker" # noqa: SLF001 def test_vm_manager_fails_closed_without_host_compat_opt_in(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) vm_id = str( manager.create_vm( environment="debian:12-base", ttl_seconds=600, )["vm_id"] ) with pytest.raises(RuntimeError, match="guest boot is unavailable"): manager.start_vm(vm_id) def test_vm_manager_uses_canonical_default_cache_dir( monkeypatch: pytest.MonkeyPatch, tmp_path: Path ) -> None: monkeypatch.setenv("PYRO_ENVIRONMENT_CACHE_DIR", str(tmp_path / "cache")) manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) assert manager._environment_store.cache_dir == tmp_path / "cache" # noqa: SLF001 def test_vm_manager_helper_round_trips() -> None: network = NetworkConfig( vm_id="abc123", tap_name="tap0", guest_ip="172.29.1.2", gateway_ip="172.29.1.1", subnet_cidr="172.29.1.0/24", mac_address="06:00:aa:bb:cc:dd", dns_servers=("1.1.1.1", "8.8.8.8"), ) assert vm_manager_module._optional_int(None) is None # noqa: SLF001 assert vm_manager_module._optional_int(True) == 1 # noqa: SLF001 assert vm_manager_module._optional_int(7) == 7 # noqa: SLF001 assert vm_manager_module._optional_int(7.2) == 7 # noqa: SLF001 assert vm_manager_module._optional_int("9") == 9 # noqa: SLF001 with pytest.raises(TypeError, match="integer-compatible"): vm_manager_module._optional_int(object()) # noqa: SLF001 assert vm_manager_module._optional_str(None) is None # noqa: SLF001 assert vm_manager_module._optional_str(1) == "1" # noqa: SLF001 assert vm_manager_module._optional_dict(None) is None # noqa: SLF001 assert vm_manager_module._optional_dict({"x": 1}) == {"x": 1} # noqa: SLF001 with pytest.raises(TypeError, match="dictionary payload"): vm_manager_module._optional_dict("bad") # noqa: SLF001 assert vm_manager_module._string_dict({"x": 1}) == {"x": "1"} # noqa: SLF001 assert vm_manager_module._string_dict("bad") == {} # noqa: SLF001 serialized = vm_manager_module._serialize_network(network) # noqa: SLF001 assert serialized is not None restored = vm_manager_module._deserialize_network(serialized) # noqa: SLF001 assert restored == network assert vm_manager_module._deserialize_network(None) is None # noqa: SLF001 with pytest.raises(TypeError, match="dictionary payload"): vm_manager_module._deserialize_network("bad") # noqa: SLF001 assert vm_manager_module._wrap_guest_command("echo hi") == "echo hi" # noqa: SLF001 wrapped = vm_manager_module._wrap_guest_command("echo hi", cwd="/workspace") # noqa: SLF001 assert "cd /workspace" in wrapped assert vm_manager_module._pid_is_running(None) is False # noqa: SLF001 def test_copy_rootfs_falls_back_to_copy2( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: source = tmp_path / "rootfs.ext4" source.write_text("payload", encoding="utf-8") dest = tmp_path / "dest" / "rootfs.ext4" def _raise_oserror(*args: Any, **kwargs: Any) -> Any: del args, kwargs raise OSError("no cp") monkeypatch.setattr(subprocess, "run", _raise_oserror) clone_mode = vm_manager_module._copy_rootfs(source, dest) # noqa: SLF001 assert clone_mode == "copy2" assert dest.read_text(encoding="utf-8") == "payload" def test_task_create_cleans_up_on_start_failure( tmp_path: Path, monkeypatch: pytest.MonkeyPatch ) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) def _boom(instance: Any) -> None: del instance raise RuntimeError("boom") monkeypatch.setattr(manager._backend, "start", _boom) # noqa: SLF001 with pytest.raises(RuntimeError, match="boom"): manager.create_task(environment="debian:12-base", allow_host_compat=True) assert list((tmp_path / "vms" / "tasks").iterdir()) == [] def test_exec_instance_wraps_guest_workspace_command(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001 supports_vm_boot=True, supports_guest_exec=True, supports_guest_network=False, reason=None, ) captured: dict[str, Any] = {} class StubBackend: def exec( self, instance: Any, command: str, timeout_seconds: int, *, workdir: Path | None = None, ) -> vm_manager_module.VmExecResult: del instance, timeout_seconds captured["command"] = command captured["workdir"] = workdir return vm_manager_module.VmExecResult( stdout="", stderr="", exit_code=0, duration_ms=1, ) manager._backend = StubBackend() # type: ignore[assignment] # noqa: SLF001 instance = vm_manager_module.VmInstance( # noqa: SLF001 vm_id="vm-123", environment="debian:12-base", vcpu_count=1, mem_mib=512, ttl_seconds=600, created_at=time.time(), expires_at=time.time() + 600, workdir=tmp_path / "runtime", state="started", ) result, execution_mode = manager._exec_instance( # noqa: SLF001 instance, command="echo hi", timeout_seconds=30, guest_cwd="/workspace", ) assert result.exit_code == 0 assert execution_mode == "unknown" assert "cd /workspace" in str(captured["command"]) assert captured["workdir"] is None def test_status_task_marks_dead_backing_process_stopped(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) task_id = str( manager.create_task( environment="debian:12-base", allow_host_compat=True, )["task_id"] ) task_path = tmp_path / "vms" / "tasks" / task_id / "task.json" payload = json.loads(task_path.read_text(encoding="utf-8")) payload["metadata"]["execution_mode"] = "guest_vsock" payload["firecracker_pid"] = 999999 task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") status = manager.status_task(task_id) assert status["state"] == "stopped" updated_payload = json.loads(task_path.read_text(encoding="utf-8")) assert "backing guest process" in str(updated_payload.get("last_error", "")) def test_reap_expired_tasks_removes_invalid_and_expired_records(tmp_path: Path) -> None: manager = VmManager( backend_name="mock", base_dir=tmp_path / "vms", network_manager=TapNetworkManager(enabled=False), ) invalid_dir = tmp_path / "vms" / "tasks" / "invalid" invalid_dir.mkdir(parents=True) (invalid_dir / "task.json").write_text("[]", encoding="utf-8") task_id = str( manager.create_task( environment="debian:12-base", allow_host_compat=True, )["task_id"] ) task_path = tmp_path / "vms" / "tasks" / task_id / "task.json" payload = json.loads(task_path.read_text(encoding="utf-8")) payload["expires_at"] = 0.0 task_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8") with manager._lock: # noqa: SLF001 manager._reap_expired_tasks_locked(time.time()) # noqa: SLF001 assert not invalid_dir.exists() assert not (tmp_path / "vms" / "tasks" / task_id).exists()