Add workspace network policy and published ports
Replace the workspace-level boolean network toggle with explicit network policies and attach localhost TCP publication to workspace services. Persist network_policy in workspace records, validate --publish requests, and run host-side proxy helpers that follow the service lifecycle so published ports are cleaned up on failure, stop, reset, and delete. Update the CLI, SDK, MCP contract, docs, roadmap, and examples for the new policy model, add coverage for the proxy and manager edge cases, and validate with uv lock, UV_CACHE_DIR=.uv-cache make check, UV_CACHE_DIR=.uv-cache make dist-check, and a real guest-backed published-port probe smoke.
This commit is contained in:
parent
fc72fcd3a1
commit
c82f4629b2
21 changed files with 1944 additions and 49 deletions
|
|
@ -123,6 +123,74 @@ def test_pyro_create_vm_defaults_sizing_and_host_compat(tmp_path: Path) -> None:
|
|||
assert created["allow_host_compat"] is True
|
||||
|
||||
|
||||
def test_pyro_workspace_network_policy_and_published_ports_delegate() -> None:
|
||||
calls: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
class StubManager:
|
||||
def create_workspace(self, **kwargs: Any) -> dict[str, Any]:
|
||||
calls.append(("create_workspace", kwargs))
|
||||
return {"workspace_id": "workspace-123"}
|
||||
|
||||
def start_service(
|
||||
self,
|
||||
workspace_id: str,
|
||||
service_name: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
calls.append(
|
||||
(
|
||||
"start_service",
|
||||
{
|
||||
"workspace_id": workspace_id,
|
||||
"service_name": service_name,
|
||||
**kwargs,
|
||||
},
|
||||
)
|
||||
)
|
||||
return {"workspace_id": workspace_id, "service_name": service_name, "state": "running"}
|
||||
|
||||
pyro = Pyro(manager=cast(Any, StubManager()))
|
||||
|
||||
pyro.create_workspace(
|
||||
environment="debian:12",
|
||||
network_policy="egress+published-ports",
|
||||
)
|
||||
pyro.start_service(
|
||||
"workspace-123",
|
||||
"web",
|
||||
command="python3 -m http.server 8080",
|
||||
published_ports=[{"guest_port": 8080, "host_port": 18080}],
|
||||
)
|
||||
|
||||
assert calls[0] == (
|
||||
"create_workspace",
|
||||
{
|
||||
"environment": "debian:12",
|
||||
"vcpu_count": 1,
|
||||
"mem_mib": 1024,
|
||||
"ttl_seconds": 600,
|
||||
"network_policy": "egress+published-ports",
|
||||
"allow_host_compat": False,
|
||||
"seed_path": None,
|
||||
"secrets": None,
|
||||
},
|
||||
)
|
||||
assert calls[1] == (
|
||||
"start_service",
|
||||
{
|
||||
"workspace_id": "workspace-123",
|
||||
"service_name": "web",
|
||||
"command": "python3 -m http.server 8080",
|
||||
"cwd": "/workspace",
|
||||
"readiness": None,
|
||||
"ready_timeout_seconds": 30,
|
||||
"ready_interval_ms": 500,
|
||||
"secret_env": None,
|
||||
"published_ports": [{"guest_port": 8080, "host_port": 18080}],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_pyro_workspace_methods_delegate_to_manager(tmp_path: Path) -> None:
|
||||
pyro = Pyro(
|
||||
manager=VmManager(
|
||||
|
|
|
|||
|
|
@ -472,6 +472,7 @@ def test_cli_workspace_create_prints_json(
|
|||
def create_workspace(self, **kwargs: Any) -> dict[str, Any]:
|
||||
assert kwargs["environment"] == "debian:12"
|
||||
assert kwargs["seed_path"] == "./repo"
|
||||
assert kwargs["network_policy"] == "egress"
|
||||
return {"workspace_id": "workspace-123", "state": "started"}
|
||||
|
||||
class StubParser:
|
||||
|
|
@ -483,7 +484,7 @@ def test_cli_workspace_create_prints_json(
|
|||
vcpu_count=1,
|
||||
mem_mib=1024,
|
||||
ttl_seconds=600,
|
||||
network=False,
|
||||
network_policy="egress",
|
||||
allow_host_compat=False,
|
||||
seed_path="./repo",
|
||||
json=True,
|
||||
|
|
@ -506,6 +507,7 @@ def test_cli_workspace_create_prints_human(
|
|||
"workspace_id": "workspace-123",
|
||||
"environment": "debian:12",
|
||||
"state": "started",
|
||||
"network_policy": "off",
|
||||
"workspace_path": "/workspace",
|
||||
"workspace_seed": {
|
||||
"mode": "directory",
|
||||
|
|
@ -530,7 +532,7 @@ def test_cli_workspace_create_prints_human(
|
|||
vcpu_count=1,
|
||||
mem_mib=1024,
|
||||
ttl_seconds=600,
|
||||
network=False,
|
||||
network_policy="off",
|
||||
allow_host_compat=False,
|
||||
seed_path="/tmp/repo",
|
||||
json=False,
|
||||
|
|
@ -2047,12 +2049,21 @@ def test_cli_workspace_service_start_prints_json(
|
|||
assert service_name == "app"
|
||||
assert kwargs["command"] == "sh -lc 'touch .ready && while true; do sleep 60; done'"
|
||||
assert kwargs["readiness"] == {"type": "file", "path": ".ready"}
|
||||
assert kwargs["published_ports"] == [{"host_port": 18080, "guest_port": 8080}]
|
||||
return {
|
||||
"workspace_id": workspace_id,
|
||||
"service_name": service_name,
|
||||
"state": "running",
|
||||
"cwd": "/workspace",
|
||||
"execution_mode": "guest_vsock",
|
||||
"published_ports": [
|
||||
{
|
||||
"host": "127.0.0.1",
|
||||
"host_port": 18080,
|
||||
"guest_port": 8080,
|
||||
"protocol": "tcp",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
class StartParser:
|
||||
|
|
@ -2070,6 +2081,7 @@ def test_cli_workspace_service_start_prints_json(
|
|||
ready_command=None,
|
||||
ready_timeout_seconds=30,
|
||||
ready_interval_ms=500,
|
||||
publish=["18080:8080"],
|
||||
json=True,
|
||||
command_args=["--", "sh", "-lc", "touch .ready && while true; do sleep 60; done"],
|
||||
)
|
||||
|
|
@ -2149,6 +2161,14 @@ def test_cli_workspace_service_list_prints_human(
|
|||
"state": "running",
|
||||
"cwd": "/workspace",
|
||||
"execution_mode": "guest_vsock",
|
||||
"published_ports": [
|
||||
{
|
||||
"host": "127.0.0.1",
|
||||
"host_port": 18080,
|
||||
"guest_port": 8080,
|
||||
"protocol": "tcp",
|
||||
}
|
||||
],
|
||||
"readiness": {"type": "file", "path": "/workspace/.ready"},
|
||||
},
|
||||
{
|
||||
|
|
@ -2176,7 +2196,7 @@ def test_cli_workspace_service_list_prints_human(
|
|||
monkeypatch.setattr(cli, "Pyro", StubPyro)
|
||||
cli.main()
|
||||
captured = capsys.readouterr()
|
||||
assert "app [running] cwd=/workspace" in captured.out
|
||||
assert "app [running] cwd=/workspace published=127.0.0.1:18080->8080/tcp" in captured.out
|
||||
assert "worker [stopped] cwd=/workspace" in captured.out
|
||||
|
||||
|
||||
|
|
@ -3006,6 +3026,110 @@ def test_cli_workspace_secret_parsers_validate_syntax(tmp_path: Path) -> None:
|
|||
cli._parse_workspace_secret_env_options(["TOKEN", "TOKEN=API_TOKEN"]) # noqa: SLF001
|
||||
|
||||
|
||||
def test_cli_workspace_publish_parser_validates_syntax() -> None:
|
||||
assert cli._parse_workspace_publish_options(["8080"]) == [ # noqa: SLF001
|
||||
{"host_port": None, "guest_port": 8080}
|
||||
]
|
||||
assert cli._parse_workspace_publish_options(["18080:8080"]) == [ # noqa: SLF001
|
||||
{"host_port": 18080, "guest_port": 8080}
|
||||
]
|
||||
|
||||
with pytest.raises(ValueError, match="must not be empty"):
|
||||
cli._parse_workspace_publish_options([" "]) # noqa: SLF001
|
||||
with pytest.raises(ValueError, match="must use GUEST_PORT or HOST_PORT:GUEST_PORT"):
|
||||
cli._parse_workspace_publish_options(["bad"]) # noqa: SLF001
|
||||
with pytest.raises(ValueError, match="must use GUEST_PORT or HOST_PORT:GUEST_PORT"):
|
||||
cli._parse_workspace_publish_options(["bad:8080"]) # noqa: SLF001
|
||||
|
||||
|
||||
def test_cli_workspace_service_start_rejects_multiple_readiness_flags_json(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
class StubPyro:
|
||||
def start_service(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
|
||||
raise AssertionError("start_service should not be called")
|
||||
|
||||
class StartParser:
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return argparse.Namespace(
|
||||
command="workspace",
|
||||
workspace_command="service",
|
||||
workspace_service_command="start",
|
||||
workspace_id="workspace-123",
|
||||
service_name="app",
|
||||
cwd="/workspace",
|
||||
ready_file=".ready",
|
||||
ready_tcp=None,
|
||||
ready_http="http://127.0.0.1:8080/",
|
||||
ready_command=None,
|
||||
ready_timeout_seconds=30,
|
||||
ready_interval_ms=500,
|
||||
publish=[],
|
||||
json=True,
|
||||
command_args=["--", "sh", "-lc", "touch .ready && while true; do sleep 60; done"],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(cli, "_build_parser", lambda: StartParser())
|
||||
monkeypatch.setattr(cli, "Pyro", StubPyro)
|
||||
with pytest.raises(SystemExit, match="1"):
|
||||
cli.main()
|
||||
payload = json.loads(capsys.readouterr().out)
|
||||
assert "choose at most one" in payload["error"]
|
||||
|
||||
|
||||
def test_cli_workspace_service_start_prints_human_with_ready_http(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
class StubPyro:
|
||||
def start_service(
|
||||
self,
|
||||
workspace_id: str,
|
||||
service_name: str,
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
assert workspace_id == "workspace-123"
|
||||
assert service_name == "app"
|
||||
assert kwargs["readiness"] == {"type": "http", "url": "http://127.0.0.1:8080/ready"}
|
||||
return {
|
||||
"workspace_id": workspace_id,
|
||||
"service_name": service_name,
|
||||
"state": "running",
|
||||
"cwd": "/workspace",
|
||||
"execution_mode": "guest_vsock",
|
||||
"readiness": kwargs["readiness"],
|
||||
}
|
||||
|
||||
class StartParser:
|
||||
def parse_args(self) -> argparse.Namespace:
|
||||
return argparse.Namespace(
|
||||
command="workspace",
|
||||
workspace_command="service",
|
||||
workspace_service_command="start",
|
||||
workspace_id="workspace-123",
|
||||
service_name="app",
|
||||
cwd="/workspace",
|
||||
ready_file=None,
|
||||
ready_tcp=None,
|
||||
ready_http="http://127.0.0.1:8080/ready",
|
||||
ready_command=None,
|
||||
ready_timeout_seconds=30,
|
||||
ready_interval_ms=500,
|
||||
publish=[],
|
||||
secret_env=[],
|
||||
json=False,
|
||||
command_args=["--", "sh", "-lc", "while true; do sleep 60; done"],
|
||||
)
|
||||
|
||||
monkeypatch.setattr(cli, "_build_parser", lambda: StartParser())
|
||||
monkeypatch.setattr(cli, "Pyro", StubPyro)
|
||||
cli.main()
|
||||
captured = capsys.readouterr()
|
||||
assert "workspace-service-start" in captured.err
|
||||
assert "service_name=app" in captured.err
|
||||
|
||||
|
||||
def test_print_workspace_summary_human_includes_secret_metadata(
|
||||
capsys: pytest.CaptureFixture[str],
|
||||
) -> None:
|
||||
|
|
|
|||
|
|
@ -1393,6 +1393,775 @@ def test_workspace_service_lifecycle_and_status_counts(tmp_path: Path) -> None:
|
|||
assert deleted["deleted"] is True
|
||||
|
||||
|
||||
def test_workspace_create_serializes_network_policy(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
|
||||
supports_vm_boot=True,
|
||||
supports_guest_exec=True,
|
||||
supports_guest_network=True,
|
||||
)
|
||||
manager._ensure_workspace_guest_bootstrap_support = lambda instance: None # type: ignore[method-assign] # noqa: SLF001
|
||||
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12-base",
|
||||
network_policy="egress",
|
||||
)
|
||||
|
||||
assert created["network_policy"] == "egress"
|
||||
workspace_id = str(created["workspace_id"])
|
||||
workspace_path = tmp_path / "vms" / "workspaces" / workspace_id / "workspace.json"
|
||||
payload = json.loads(workspace_path.read_text(encoding="utf-8"))
|
||||
assert payload["network_policy"] == "egress"
|
||||
|
||||
|
||||
def test_workspace_service_start_serializes_published_ports(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
|
||||
supports_vm_boot=True,
|
||||
supports_guest_exec=True,
|
||||
supports_guest_network=True,
|
||||
)
|
||||
manager._ensure_workspace_guest_bootstrap_support = lambda instance: None # type: ignore[method-assign] # noqa: SLF001
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12-base",
|
||||
network_policy="egress+published-ports",
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace_id = str(created["workspace_id"])
|
||||
|
||||
workspace_path = tmp_path / "vms" / "workspaces" / workspace_id / "workspace.json"
|
||||
payload = json.loads(workspace_path.read_text(encoding="utf-8"))
|
||||
payload["network"] = {
|
||||
"vm_id": workspace_id,
|
||||
"tap_name": "tap-test0",
|
||||
"guest_ip": "172.29.1.2",
|
||||
"gateway_ip": "172.29.1.1",
|
||||
"subnet_cidr": "172.29.1.0/30",
|
||||
"mac_address": "06:00:ac:1d:01:02",
|
||||
"dns_servers": ["1.1.1.1"],
|
||||
}
|
||||
workspace_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"_start_workspace_service_published_ports",
|
||||
lambda **kwargs: [
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
host="127.0.0.1",
|
||||
protocol="tcp",
|
||||
proxy_pid=9999,
|
||||
)
|
||||
],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"_refresh_workspace_liveness_locked",
|
||||
lambda workspace: None,
|
||||
)
|
||||
|
||||
started = manager.start_service(
|
||||
workspace_id,
|
||||
"web",
|
||||
command="sh -lc 'touch .ready && while true; do sleep 60; done'",
|
||||
readiness={"type": "file", "path": ".ready"},
|
||||
published_ports=[{"guest_port": 8080, "host_port": 18080}],
|
||||
)
|
||||
|
||||
assert started["published_ports"] == [
|
||||
{
|
||||
"host": "127.0.0.1",
|
||||
"host_port": 18080,
|
||||
"guest_port": 8080,
|
||||
"protocol": "tcp",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_workspace_service_start_rejects_published_ports_without_network_policy(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
workspace_id = str(
|
||||
manager.create_workspace(
|
||||
environment="debian:12-base",
|
||||
allow_host_compat=True,
|
||||
)["workspace_id"]
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
RuntimeError,
|
||||
match="published ports require workspace network_policy 'egress\\+published-ports'",
|
||||
):
|
||||
manager.start_service(
|
||||
workspace_id,
|
||||
"web",
|
||||
command="sh -lc 'touch .ready && while true; do sleep 60; done'",
|
||||
readiness={"type": "file", "path": ".ready"},
|
||||
published_ports=[{"guest_port": 8080}],
|
||||
)
|
||||
|
||||
|
||||
def test_workspace_service_start_rejects_published_ports_without_active_network(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
|
||||
supports_vm_boot=True,
|
||||
supports_guest_exec=True,
|
||||
supports_guest_network=True,
|
||||
)
|
||||
manager._ensure_workspace_guest_bootstrap_support = lambda instance: None # type: ignore[method-assign] # noqa: SLF001
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"_refresh_workspace_liveness_locked",
|
||||
lambda workspace: None,
|
||||
)
|
||||
workspace_id = str(
|
||||
manager.create_workspace(
|
||||
environment="debian:12-base",
|
||||
network_policy="egress+published-ports",
|
||||
allow_host_compat=True,
|
||||
)["workspace_id"]
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="published ports require an active guest network"):
|
||||
manager.start_service(
|
||||
workspace_id,
|
||||
"web",
|
||||
command="sh -lc 'touch .ready && while true; do sleep 60; done'",
|
||||
readiness={"type": "file", "path": ".ready"},
|
||||
published_ports=[{"guest_port": 8080}],
|
||||
)
|
||||
|
||||
|
||||
def test_workspace_service_start_published_port_failure_marks_service_failed(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
|
||||
supports_vm_boot=True,
|
||||
supports_guest_exec=True,
|
||||
supports_guest_network=True,
|
||||
)
|
||||
manager._ensure_workspace_guest_bootstrap_support = lambda instance: None # type: ignore[method-assign] # noqa: SLF001
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"_refresh_workspace_liveness_locked",
|
||||
lambda workspace: None,
|
||||
)
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12-base",
|
||||
network_policy="egress+published-ports",
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace_id = str(created["workspace_id"])
|
||||
|
||||
workspace_path = tmp_path / "vms" / "workspaces" / workspace_id / "workspace.json"
|
||||
payload = json.loads(workspace_path.read_text(encoding="utf-8"))
|
||||
payload["network"] = {
|
||||
"vm_id": workspace_id,
|
||||
"tap_name": "tap-test0",
|
||||
"guest_ip": "172.29.1.2",
|
||||
"gateway_ip": "172.29.1.1",
|
||||
"subnet_cidr": "172.29.1.0/30",
|
||||
"mac_address": "06:00:ac:1d:01:02",
|
||||
"dns_servers": ["1.1.1.1"],
|
||||
}
|
||||
workspace_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
||||
|
||||
def _raise_proxy_failure(
|
||||
**kwargs: object,
|
||||
) -> list[vm_manager_module.WorkspacePublishedPortRecord]:
|
||||
del kwargs
|
||||
raise RuntimeError("proxy boom")
|
||||
|
||||
monkeypatch.setattr(manager, "_start_workspace_service_published_ports", _raise_proxy_failure)
|
||||
|
||||
started = manager.start_service(
|
||||
workspace_id,
|
||||
"web",
|
||||
command="sh -lc 'touch .ready && while true; do sleep 60; done'",
|
||||
readiness={"type": "file", "path": ".ready"},
|
||||
published_ports=[{"guest_port": 8080, "host_port": 18080}],
|
||||
)
|
||||
|
||||
assert started["state"] == "failed"
|
||||
assert started["stop_reason"] == "published_port_failed"
|
||||
assert started["published_ports"] == []
|
||||
|
||||
|
||||
def test_workspace_service_cleanup_stops_published_port_proxies(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
workspace_id = "workspace-cleanup"
|
||||
service = vm_manager_module.WorkspaceServiceRecord(
|
||||
workspace_id=workspace_id,
|
||||
service_name="web",
|
||||
command="sleep 60",
|
||||
cwd="/workspace",
|
||||
state="running",
|
||||
pid=1234,
|
||||
started_at=time.time(),
|
||||
ended_at=None,
|
||||
exit_code=None,
|
||||
execution_mode="host_compat",
|
||||
readiness=None,
|
||||
ready_at=None,
|
||||
stop_reason=None,
|
||||
published_ports=[
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=9999,
|
||||
)
|
||||
],
|
||||
)
|
||||
manager._save_workspace_service_locked(service) # noqa: SLF001
|
||||
stopped: list[int | None] = []
|
||||
monkeypatch.setattr(
|
||||
vm_manager_module,
|
||||
"_stop_workspace_published_port_proxy",
|
||||
lambda published_port: stopped.append(published_port.proxy_pid),
|
||||
)
|
||||
|
||||
manager._delete_workspace_service_artifacts_locked(workspace_id, "web") # noqa: SLF001
|
||||
|
||||
assert stopped == [9999]
|
||||
assert not manager._workspace_service_record_path(workspace_id, "web").exists() # noqa: SLF001
|
||||
|
||||
|
||||
def test_workspace_refresh_workspace_service_counts_stops_published_ports_when_stopped(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
workspace = vm_manager_module.WorkspaceRecord(
|
||||
workspace_id="workspace-counts",
|
||||
environment="debian:12-base",
|
||||
vcpu_count=1,
|
||||
mem_mib=1024,
|
||||
ttl_seconds=600,
|
||||
created_at=time.time(),
|
||||
expires_at=time.time() + 600,
|
||||
state="stopped",
|
||||
firecracker_pid=None,
|
||||
last_error=None,
|
||||
allow_host_compat=True,
|
||||
network_policy="off",
|
||||
metadata={},
|
||||
command_count=0,
|
||||
last_command=None,
|
||||
workspace_seed={
|
||||
"mode": "empty",
|
||||
"seed_path": None,
|
||||
"destination": "/workspace",
|
||||
"entry_count": 0,
|
||||
"bytes_written": 0,
|
||||
},
|
||||
secrets=[],
|
||||
reset_count=0,
|
||||
last_reset_at=None,
|
||||
)
|
||||
service = vm_manager_module.WorkspaceServiceRecord(
|
||||
workspace_id=workspace.workspace_id,
|
||||
service_name="web",
|
||||
command="sleep 60",
|
||||
cwd="/workspace",
|
||||
state="running",
|
||||
pid=1234,
|
||||
started_at=time.time(),
|
||||
ended_at=None,
|
||||
exit_code=None,
|
||||
execution_mode="host_compat",
|
||||
readiness=None,
|
||||
ready_at=None,
|
||||
stop_reason=None,
|
||||
published_ports=[
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=9999,
|
||||
)
|
||||
],
|
||||
)
|
||||
manager._save_workspace_service_locked(service) # noqa: SLF001
|
||||
stopped: list[int | None] = []
|
||||
monkeypatch.setattr(
|
||||
vm_manager_module,
|
||||
"_stop_workspace_published_port_proxy",
|
||||
lambda published_port: stopped.append(published_port.proxy_pid),
|
||||
)
|
||||
|
||||
manager._refresh_workspace_service_counts_locked(workspace) # noqa: SLF001
|
||||
|
||||
assert stopped == [9999]
|
||||
refreshed = manager._load_workspace_service_locked(workspace.workspace_id, "web") # noqa: SLF001
|
||||
assert refreshed.state == "stopped"
|
||||
assert refreshed.stop_reason == "workspace_stopped"
|
||||
|
||||
|
||||
def test_workspace_published_port_proxy_helpers(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
services_dir = tmp_path / "services"
|
||||
services_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
class StubProcess:
|
||||
def __init__(self, pid: int, *, exited: bool = False) -> None:
|
||||
self.pid = pid
|
||||
self._exited = exited
|
||||
|
||||
def poll(self) -> int | None:
|
||||
return 1 if self._exited else None
|
||||
|
||||
def _fake_popen(command: list[str], **kwargs: object) -> StubProcess:
|
||||
del kwargs
|
||||
ready_file = Path(command[command.index("--ready-file") + 1])
|
||||
ready_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"host": "127.0.0.1",
|
||||
"host_port": 18080,
|
||||
"target_host": "172.29.1.2",
|
||||
"target_port": 8080,
|
||||
"protocol": "tcp",
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return StubProcess(4242)
|
||||
|
||||
monkeypatch.setattr(subprocess, "Popen", _fake_popen)
|
||||
|
||||
record = vm_manager_module._start_workspace_published_port_proxy( # noqa: SLF001
|
||||
services_dir=services_dir,
|
||||
service_name="web",
|
||||
workspace_id="workspace-proxy",
|
||||
guest_ip="172.29.1.2",
|
||||
spec=vm_manager_module.WorkspacePublishedPortSpec(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
),
|
||||
)
|
||||
|
||||
assert record.guest_port == 8080
|
||||
assert record.host_port == 18080
|
||||
assert record.proxy_pid == 4242
|
||||
|
||||
|
||||
def test_workspace_published_port_proxy_timeout_and_stop(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
services_dir = tmp_path / "services"
|
||||
services_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
class StubProcess:
|
||||
pid = 4242
|
||||
|
||||
def poll(self) -> int | None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(subprocess, "Popen", lambda *args, **kwargs: StubProcess())
|
||||
monotonic_values = iter([0.0, 0.0, 5.1])
|
||||
monkeypatch.setattr(time, "monotonic", lambda: next(monotonic_values))
|
||||
monkeypatch.setattr(time, "sleep", lambda _: None)
|
||||
stopped: list[int | None] = []
|
||||
monkeypatch.setattr(
|
||||
vm_manager_module,
|
||||
"_stop_workspace_published_port_proxy",
|
||||
lambda published_port: stopped.append(published_port.proxy_pid),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="timed out waiting for published port proxy readiness"):
|
||||
vm_manager_module._start_workspace_published_port_proxy( # noqa: SLF001
|
||||
services_dir=services_dir,
|
||||
service_name="web",
|
||||
workspace_id="workspace-proxy",
|
||||
guest_ip="172.29.1.2",
|
||||
spec=vm_manager_module.WorkspacePublishedPortSpec(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
),
|
||||
)
|
||||
|
||||
assert stopped == [4242]
|
||||
|
||||
|
||||
def test_workspace_published_port_validation_and_stop_helper(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
spec = vm_manager_module._normalize_workspace_published_port( # noqa: SLF001
|
||||
guest_port="8080",
|
||||
host_port="18080",
|
||||
)
|
||||
assert spec.guest_port == 8080
|
||||
assert spec.host_port == 18080
|
||||
with pytest.raises(ValueError, match="published guest_port must be an integer"):
|
||||
vm_manager_module._normalize_workspace_published_port(guest_port=object()) # noqa: SLF001
|
||||
with pytest.raises(ValueError, match="published host_port must be between 1025 and 65535"):
|
||||
vm_manager_module._normalize_workspace_published_port( # noqa: SLF001
|
||||
guest_port=8080,
|
||||
host_port=80,
|
||||
)
|
||||
|
||||
signals: list[int] = []
|
||||
monkeypatch.setattr(os, "killpg", lambda pid, sig: signals.append(sig))
|
||||
running = iter([True, False])
|
||||
monkeypatch.setattr(vm_manager_module, "_pid_is_running", lambda pid: next(running))
|
||||
monkeypatch.setattr(time, "sleep", lambda _: None)
|
||||
|
||||
vm_manager_module._stop_workspace_published_port_proxy( # noqa: SLF001
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=9999,
|
||||
)
|
||||
)
|
||||
|
||||
assert signals == [signal.SIGTERM]
|
||||
|
||||
|
||||
def test_workspace_network_policy_requires_guest_network_support(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="firecracker",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
manager._runtime_capabilities = RuntimeCapabilities( # noqa: SLF001
|
||||
supports_vm_boot=False,
|
||||
supports_guest_exec=False,
|
||||
supports_guest_network=False,
|
||||
reason="no guest network",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="workspace network_policy requires guest networking"):
|
||||
manager._require_workspace_network_policy_support( # noqa: SLF001
|
||||
network_policy="egress"
|
||||
)
|
||||
|
||||
|
||||
def test_prepare_workspace_seed_rejects_missing_and_invalid_paths(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
manager._prepare_workspace_seed(tmp_path / "missing") # noqa: SLF001
|
||||
|
||||
invalid_source = tmp_path / "seed.txt"
|
||||
invalid_source.write_text("seed", encoding="utf-8")
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="seed_path must be a directory or a .tar/.tar.gz/.tgz archive",
|
||||
):
|
||||
manager._prepare_workspace_seed(invalid_source) # noqa: SLF001
|
||||
|
||||
|
||||
def test_workspace_baseline_snapshot_requires_archive(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12",
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
ttl_seconds=600,
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace_id = str(created["workspace_id"])
|
||||
baseline_path = tmp_path / "vms" / "workspaces" / workspace_id / "baseline" / "workspace.tar"
|
||||
baseline_path.unlink()
|
||||
workspace = manager._load_workspace_locked(workspace_id) # noqa: SLF001
|
||||
|
||||
with pytest.raises(RuntimeError, match="baseline snapshot"):
|
||||
manager._workspace_baseline_snapshot_locked(workspace) # noqa: SLF001
|
||||
|
||||
|
||||
def test_workspace_snapshot_and_service_loaders_handle_invalid_payloads(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
workspace_id = "workspace-invalid"
|
||||
services_dir = tmp_path / "vms" / "workspaces" / workspace_id / "services"
|
||||
snapshots_dir = tmp_path / "vms" / "workspaces" / workspace_id / "snapshots"
|
||||
services_dir.mkdir(parents=True, exist_ok=True)
|
||||
snapshots_dir.mkdir(parents=True, exist_ok=True)
|
||||
(services_dir / "svc.json").write_text("[]", encoding="utf-8")
|
||||
(snapshots_dir / "snap.json").write_text("[]", encoding="utf-8")
|
||||
|
||||
with pytest.raises(RuntimeError, match="service record"):
|
||||
manager._load_workspace_service_locked(workspace_id, "svc") # noqa: SLF001
|
||||
with pytest.raises(RuntimeError, match="snapshot record"):
|
||||
manager._load_workspace_snapshot_locked(workspace_id, "snap") # noqa: SLF001
|
||||
with pytest.raises(RuntimeError, match="snapshot record"):
|
||||
manager._load_workspace_snapshot_locked_optional(workspace_id, "snap") # noqa: SLF001
|
||||
assert manager._load_workspace_snapshot_locked_optional(workspace_id, "missing") is None # noqa: SLF001
|
||||
|
||||
|
||||
def test_workspace_shell_helpers_handle_missing_invalid_and_close_errors(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12",
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
ttl_seconds=600,
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace_id = str(created["workspace_id"])
|
||||
|
||||
assert manager._list_workspace_shells_locked(workspace_id) == [] # noqa: SLF001
|
||||
|
||||
shells_dir = tmp_path / "vms" / "workspaces" / workspace_id / "shells"
|
||||
shells_dir.mkdir(parents=True, exist_ok=True)
|
||||
(shells_dir / "invalid.json").write_text("[]", encoding="utf-8")
|
||||
assert manager._list_workspace_shells_locked(workspace_id) == [] # noqa: SLF001
|
||||
|
||||
shell = vm_manager_module.WorkspaceShellRecord(
|
||||
workspace_id=workspace_id,
|
||||
shell_id="shell-1",
|
||||
cwd="/workspace",
|
||||
cols=120,
|
||||
rows=30,
|
||||
state="running",
|
||||
started_at=time.time(),
|
||||
)
|
||||
manager._save_workspace_shell_locked(shell) # noqa: SLF001
|
||||
workspace = manager._load_workspace_locked(workspace_id) # noqa: SLF001
|
||||
instance = workspace.to_instance(
|
||||
workdir=tmp_path / "vms" / "workspaces" / workspace_id / "runtime"
|
||||
)
|
||||
|
||||
def _raise_close(**kwargs: object) -> dict[str, object]:
|
||||
del kwargs
|
||||
raise RuntimeError("shell close boom")
|
||||
|
||||
monkeypatch.setattr(manager._backend, "close_shell", _raise_close)
|
||||
manager._close_workspace_shells_locked(workspace, instance) # noqa: SLF001
|
||||
assert manager._list_workspace_shells_locked(workspace_id) == [] # noqa: SLF001
|
||||
|
||||
|
||||
def test_workspace_refresh_service_helpers_cover_exit_and_started_refresh(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12",
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
ttl_seconds=600,
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace_id = str(created["workspace_id"])
|
||||
workspace_path = tmp_path / "vms" / "workspaces" / workspace_id / "workspace.json"
|
||||
payload = json.loads(workspace_path.read_text(encoding="utf-8"))
|
||||
payload["state"] = "started"
|
||||
payload["network"] = {
|
||||
"vm_id": workspace_id,
|
||||
"tap_name": "tap-test0",
|
||||
"guest_ip": "172.29.1.2",
|
||||
"gateway_ip": "172.29.1.1",
|
||||
"subnet_cidr": "172.29.1.0/30",
|
||||
"mac_address": "06:00:ac:1d:01:02",
|
||||
"dns_servers": ["1.1.1.1"],
|
||||
}
|
||||
workspace_path.write_text(json.dumps(payload, indent=2, sort_keys=True), encoding="utf-8")
|
||||
workspace = manager._load_workspace_locked(workspace_id) # noqa: SLF001
|
||||
instance = workspace.to_instance(
|
||||
workdir=tmp_path / "vms" / "workspaces" / workspace_id / "runtime"
|
||||
)
|
||||
|
||||
service = vm_manager_module.WorkspaceServiceRecord(
|
||||
workspace_id=workspace_id,
|
||||
service_name="web",
|
||||
command="sleep 60",
|
||||
cwd="/workspace",
|
||||
state="running",
|
||||
started_at=time.time(),
|
||||
execution_mode="guest_vsock",
|
||||
published_ports=[
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=9999,
|
||||
)
|
||||
],
|
||||
)
|
||||
manager._save_workspace_service_locked(service) # noqa: SLF001
|
||||
stopped: list[int | None] = []
|
||||
monkeypatch.setattr(
|
||||
vm_manager_module,
|
||||
"_stop_workspace_published_port_proxy",
|
||||
lambda published_port: stopped.append(published_port.proxy_pid),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
manager._backend,
|
||||
"status_service",
|
||||
lambda *args, **kwargs: {
|
||||
"service_name": "web",
|
||||
"command": "sleep 60",
|
||||
"cwd": "/workspace",
|
||||
"state": "exited",
|
||||
"started_at": service.started_at,
|
||||
"ended_at": service.started_at + 1,
|
||||
"exit_code": 0,
|
||||
"execution_mode": "guest_vsock",
|
||||
},
|
||||
)
|
||||
|
||||
refreshed = manager._refresh_workspace_service_locked( # noqa: SLF001
|
||||
workspace,
|
||||
instance,
|
||||
service,
|
||||
)
|
||||
assert refreshed.state == "exited"
|
||||
assert refreshed.published_ports == [
|
||||
vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=None,
|
||||
)
|
||||
]
|
||||
assert stopped == [9999]
|
||||
|
||||
manager._save_workspace_service_locked(service) # noqa: SLF001
|
||||
refreshed_calls: list[str] = []
|
||||
monkeypatch.setattr(manager, "_require_workspace_service_support", lambda instance: None)
|
||||
|
||||
def _refresh_services(
|
||||
workspace: vm_manager_module.WorkspaceRecord,
|
||||
instance: vm_manager_module.VmInstance,
|
||||
) -> list[vm_manager_module.WorkspaceServiceRecord]:
|
||||
del instance
|
||||
refreshed_calls.append(workspace.workspace_id)
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
manager,
|
||||
"_refresh_workspace_services_locked",
|
||||
_refresh_services,
|
||||
)
|
||||
manager._refresh_workspace_service_counts_locked(workspace) # noqa: SLF001
|
||||
assert refreshed_calls == [workspace_id]
|
||||
|
||||
|
||||
def test_workspace_start_published_ports_cleans_up_partial_failure(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
base_dir=tmp_path / "vms",
|
||||
network_manager=TapNetworkManager(enabled=False),
|
||||
)
|
||||
created = manager.create_workspace(
|
||||
environment="debian:12",
|
||||
vcpu_count=1,
|
||||
mem_mib=512,
|
||||
ttl_seconds=600,
|
||||
allow_host_compat=True,
|
||||
)
|
||||
workspace = manager._load_workspace_locked(str(created["workspace_id"])) # noqa: SLF001
|
||||
service = vm_manager_module.WorkspaceServiceRecord(
|
||||
workspace_id=workspace.workspace_id,
|
||||
service_name="web",
|
||||
command="sleep 60",
|
||||
cwd="/workspace",
|
||||
state="running",
|
||||
started_at=time.time(),
|
||||
execution_mode="guest_vsock",
|
||||
)
|
||||
started_record = vm_manager_module.WorkspacePublishedPortRecord(
|
||||
guest_port=8080,
|
||||
host_port=18080,
|
||||
proxy_pid=9999,
|
||||
)
|
||||
calls: list[int] = []
|
||||
|
||||
def _start_proxy(**kwargs: object) -> vm_manager_module.WorkspacePublishedPortRecord:
|
||||
spec = cast(vm_manager_module.WorkspacePublishedPortSpec, kwargs["spec"])
|
||||
if spec.guest_port == 8080:
|
||||
return started_record
|
||||
raise RuntimeError("proxy boom")
|
||||
|
||||
monkeypatch.setattr(vm_manager_module, "_start_workspace_published_port_proxy", _start_proxy)
|
||||
monkeypatch.setattr(
|
||||
vm_manager_module,
|
||||
"_stop_workspace_published_port_proxy",
|
||||
lambda published_port: calls.append(published_port.proxy_pid or -1),
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError, match="proxy boom"):
|
||||
manager._start_workspace_service_published_ports( # noqa: SLF001
|
||||
workspace=workspace,
|
||||
service=service,
|
||||
guest_ip="172.29.1.2",
|
||||
published_ports=[
|
||||
vm_manager_module.WorkspacePublishedPortSpec(guest_port=8080),
|
||||
vm_manager_module.WorkspacePublishedPortSpec(guest_port=9090),
|
||||
],
|
||||
)
|
||||
|
||||
assert calls == [9999]
|
||||
|
||||
|
||||
def test_workspace_service_start_replaces_non_running_record(tmp_path: Path) -> None:
|
||||
manager = VmManager(
|
||||
backend_name="mock",
|
||||
|
|
|
|||
289
tests/test_workspace_ports.py
Normal file
289
tests/test_workspace_ports.py
Normal file
|
|
@ -0,0 +1,289 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import selectors
|
||||
import signal
|
||||
import socket
|
||||
import socketserver
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
|
||||
import pytest
|
||||
|
||||
from pyro_mcp import workspace_ports
|
||||
|
||||
|
||||
class _EchoHandler(socketserver.BaseRequestHandler):
|
||||
def handle(self) -> None:
|
||||
data = self.request.recv(65536)
|
||||
if data:
|
||||
self.request.sendall(data)
|
||||
|
||||
|
||||
def test_workspace_port_proxy_handler_rejects_invalid_server() -> None:
|
||||
handler = workspace_ports._ProxyHandler.__new__(workspace_ports._ProxyHandler) # noqa: SLF001
|
||||
handler.server = cast(Any, object())
|
||||
handler.request = object()
|
||||
|
||||
with pytest.raises(RuntimeError, match="proxy server is invalid"):
|
||||
handler.handle()
|
||||
|
||||
|
||||
def test_workspace_port_proxy_handler_ignores_upstream_connect_failure(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
handler = workspace_ports._ProxyHandler.__new__(workspace_ports._ProxyHandler) # noqa: SLF001
|
||||
server = workspace_ports._ProxyServer.__new__(workspace_ports._ProxyServer) # noqa: SLF001
|
||||
server.target_address = ("127.0.0.1", 12345)
|
||||
handler.server = server
|
||||
handler.request = object()
|
||||
|
||||
def _raise_connect(*args: Any, **kwargs: Any) -> socket.socket:
|
||||
del args, kwargs
|
||||
raise OSError("boom")
|
||||
|
||||
monkeypatch.setattr(socket, "create_connection", _raise_connect)
|
||||
|
||||
handler.handle()
|
||||
|
||||
|
||||
def test_workspace_port_proxy_forwards_tcp_traffic() -> None:
|
||||
upstream = socketserver.ThreadingTCPServer(
|
||||
(workspace_ports.DEFAULT_PUBLISHED_PORT_HOST, 0),
|
||||
_EchoHandler,
|
||||
)
|
||||
upstream_thread = threading.Thread(target=upstream.serve_forever, daemon=True)
|
||||
upstream_thread.start()
|
||||
upstream_host = str(upstream.server_address[0])
|
||||
upstream_port = int(upstream.server_address[1])
|
||||
proxy = workspace_ports._ProxyServer( # noqa: SLF001
|
||||
(workspace_ports.DEFAULT_PUBLISHED_PORT_HOST, 0),
|
||||
(upstream_host, upstream_port),
|
||||
)
|
||||
proxy_thread = threading.Thread(target=proxy.serve_forever, daemon=True)
|
||||
proxy_thread.start()
|
||||
try:
|
||||
proxy_host = str(proxy.server_address[0])
|
||||
proxy_port = int(proxy.server_address[1])
|
||||
with socket.create_connection((proxy_host, proxy_port), timeout=5) as client:
|
||||
client.sendall(b"hello")
|
||||
received = client.recv(65536)
|
||||
assert received == b"hello"
|
||||
finally:
|
||||
proxy.shutdown()
|
||||
proxy.server_close()
|
||||
upstream.shutdown()
|
||||
upstream.server_close()
|
||||
|
||||
|
||||
def test_workspace_ports_main_writes_ready_file(
|
||||
tmp_path: Path,
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
ready_file = tmp_path / "proxy.ready.json"
|
||||
signals: list[int] = []
|
||||
|
||||
class StubProxyServer:
|
||||
def __init__(
|
||||
self,
|
||||
server_address: tuple[str, int],
|
||||
target_address: tuple[str, int],
|
||||
) -> None:
|
||||
self.server_address = (server_address[0], 18080)
|
||||
self.target_address = target_address
|
||||
|
||||
def serve_forever(self, poll_interval: float = 0.2) -> None:
|
||||
assert poll_interval == 0.2
|
||||
|
||||
def shutdown(self) -> None:
|
||||
return None
|
||||
|
||||
def server_close(self) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(workspace_ports, "_ProxyServer", StubProxyServer)
|
||||
monkeypatch.setattr(
|
||||
signal,
|
||||
"signal",
|
||||
lambda signum, handler: signals.append(signum),
|
||||
)
|
||||
|
||||
result = workspace_ports.main(
|
||||
[
|
||||
"--listen-host",
|
||||
"127.0.0.1",
|
||||
"--listen-port",
|
||||
"0",
|
||||
"--target-host",
|
||||
"172.29.1.2",
|
||||
"--target-port",
|
||||
"8080",
|
||||
"--ready-file",
|
||||
str(ready_file),
|
||||
]
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
payload = json.loads(ready_file.read_text(encoding="utf-8"))
|
||||
assert payload == {
|
||||
"host": "127.0.0.1",
|
||||
"host_port": 18080,
|
||||
"protocol": "tcp",
|
||||
"target_host": "172.29.1.2",
|
||||
"target_port": 8080,
|
||||
}
|
||||
assert signals == [signal.SIGTERM, signal.SIGINT]
|
||||
|
||||
|
||||
def test_workspace_ports_main_shutdown_handler_stops_server(
|
||||
tmp_path: Path,
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
ready_file = tmp_path / "proxy.ready.json"
|
||||
shutdown_called: list[bool] = []
|
||||
handlers: dict[int, Any] = {}
|
||||
|
||||
class StubProxyServer:
|
||||
def __init__(
|
||||
self,
|
||||
server_address: tuple[str, int],
|
||||
target_address: tuple[str, int],
|
||||
) -> None:
|
||||
self.server_address = server_address
|
||||
self.target_address = target_address
|
||||
|
||||
def serve_forever(self, poll_interval: float = 0.2) -> None:
|
||||
handlers[signal.SIGTERM](signal.SIGTERM, None)
|
||||
assert poll_interval == 0.2
|
||||
|
||||
def shutdown(self) -> None:
|
||||
shutdown_called.append(True)
|
||||
|
||||
def server_close(self) -> None:
|
||||
return None
|
||||
|
||||
class ImmediateThread:
|
||||
def __init__(self, *, target: Any, daemon: bool) -> None:
|
||||
self._target = target
|
||||
assert daemon is True
|
||||
|
||||
def start(self) -> None:
|
||||
self._target()
|
||||
|
||||
monkeypatch.setattr(workspace_ports, "_ProxyServer", StubProxyServer)
|
||||
monkeypatch.setattr(
|
||||
signal,
|
||||
"signal",
|
||||
lambda signum, handler: handlers.__setitem__(signum, handler),
|
||||
)
|
||||
monkeypatch.setattr(threading, "Thread", ImmediateThread)
|
||||
|
||||
result = workspace_ports.main(
|
||||
[
|
||||
"--listen-host",
|
||||
"127.0.0.1",
|
||||
"--listen-port",
|
||||
"18080",
|
||||
"--target-host",
|
||||
"172.29.1.2",
|
||||
"--target-port",
|
||||
"8080",
|
||||
"--ready-file",
|
||||
str(ready_file),
|
||||
]
|
||||
)
|
||||
|
||||
assert result == 0
|
||||
assert shutdown_called == [True]
|
||||
|
||||
|
||||
def test_workspace_port_proxy_handler_handles_empty_and_invalid_selector_events(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
source, source_peer = socket.socketpair()
|
||||
upstream, upstream_peer = socket.socketpair()
|
||||
source_peer.close()
|
||||
|
||||
class FakeSelector:
|
||||
def __init__(self) -> None:
|
||||
self._events = iter(
|
||||
[
|
||||
[],
|
||||
[(SimpleNamespace(fileobj=object(), data=object()), None)],
|
||||
[(SimpleNamespace(fileobj=source, data=upstream), None)],
|
||||
]
|
||||
)
|
||||
|
||||
def register(self, *_args: Any, **_kwargs: Any) -> None:
|
||||
return None
|
||||
|
||||
def select(self) -> list[tuple[SimpleNamespace, None]]:
|
||||
return next(self._events)
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
handler = workspace_ports._ProxyHandler.__new__(workspace_ports._ProxyHandler) # noqa: SLF001
|
||||
server = workspace_ports._ProxyServer.__new__(workspace_ports._ProxyServer) # noqa: SLF001
|
||||
server.target_address = ("127.0.0.1", 12345)
|
||||
handler.server = server
|
||||
handler.request = source
|
||||
|
||||
monkeypatch.setattr(socket, "create_connection", lambda *args, **kwargs: upstream)
|
||||
monkeypatch.setattr(selectors, "DefaultSelector", FakeSelector)
|
||||
|
||||
try:
|
||||
handler.handle()
|
||||
finally:
|
||||
source.close()
|
||||
upstream.close()
|
||||
upstream_peer.close()
|
||||
|
||||
|
||||
def test_workspace_port_proxy_handler_handles_recv_and_send_errors(
|
||||
monkeypatch: Any,
|
||||
) -> None:
|
||||
def _run_once(*, close_source: bool) -> None:
|
||||
source, source_peer = socket.socketpair()
|
||||
upstream, upstream_peer = socket.socketpair()
|
||||
if not close_source:
|
||||
source_peer.sendall(b"hello")
|
||||
|
||||
class FakeSelector:
|
||||
def register(self, *_args: Any, **_kwargs: Any) -> None:
|
||||
return None
|
||||
|
||||
def select(self) -> list[tuple[SimpleNamespace, None]]:
|
||||
if close_source:
|
||||
source.close()
|
||||
else:
|
||||
upstream.close()
|
||||
return [(SimpleNamespace(fileobj=source, data=upstream), None)]
|
||||
|
||||
def close(self) -> None:
|
||||
return None
|
||||
|
||||
handler = workspace_ports._ProxyHandler.__new__(workspace_ports._ProxyHandler) # noqa: SLF001
|
||||
server = workspace_ports._ProxyServer.__new__(workspace_ports._ProxyServer) # noqa: SLF001
|
||||
server.target_address = ("127.0.0.1", 12345)
|
||||
handler.server = server
|
||||
handler.request = source
|
||||
|
||||
monkeypatch.setattr(socket, "create_connection", lambda *args, **kwargs: upstream)
|
||||
monkeypatch.setattr(selectors, "DefaultSelector", FakeSelector)
|
||||
|
||||
try:
|
||||
handler.handle()
|
||||
finally:
|
||||
source_peer.close()
|
||||
if close_source:
|
||||
upstream.close()
|
||||
upstream_peer.close()
|
||||
else:
|
||||
source.close()
|
||||
upstream_peer.close()
|
||||
|
||||
_run_once(close_source=True)
|
||||
_run_once(close_source=False)
|
||||
Loading…
Add table
Add a link
Reference in a new issue