Add workspace network policy and published ports

Replace the workspace-level boolean network toggle with explicit network policies and attach localhost TCP publication to workspace services.

Persist network_policy in workspace records, validate --publish requests, and run host-side proxy helpers that follow the service lifecycle so published ports are cleaned up on failure, stop, reset, and delete.

Update the CLI, SDK, MCP contract, docs, roadmap, and examples for the new policy model, add coverage for the proxy and manager edge cases, and validate with uv lock, UV_CACHE_DIR=.uv-cache make check, UV_CACHE_DIR=.uv-cache make dist-check, and a real guest-backed published-port probe smoke.
This commit is contained in:
Thales Maciel 2026-03-12 18:12:57 -03:00
parent fc72fcd3a1
commit c82f4629b2
21 changed files with 1944 additions and 49 deletions

View file

@ -13,6 +13,7 @@ from pyro_mcp.vm_manager import (
DEFAULT_TIMEOUT_SECONDS,
DEFAULT_TTL_SECONDS,
DEFAULT_VCPU_COUNT,
DEFAULT_WORKSPACE_NETWORK_POLICY,
VmManager,
)
@ -84,7 +85,7 @@ class Pyro:
vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
network_policy: str = DEFAULT_WORKSPACE_NETWORK_POLICY,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
seed_path: str | Path | None = None,
secrets: list[dict[str, str]] | None = None,
@ -94,7 +95,7 @@ class Pyro:
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
network=network,
network_policy=network_policy,
allow_host_compat=allow_host_compat,
seed_path=seed_path,
secrets=secrets,
@ -241,6 +242,7 @@ class Pyro:
ready_timeout_seconds: int = 30,
ready_interval_ms: int = 500,
secret_env: dict[str, str] | None = None,
published_ports: list[dict[str, int | None]] | None = None,
) -> dict[str, Any]:
return self._manager.start_service(
workspace_id,
@ -251,6 +253,7 @@ class Pyro:
ready_timeout_seconds=ready_timeout_seconds,
ready_interval_ms=ready_interval_ms,
secret_env=secret_env,
published_ports=published_ports,
)
def list_services(self, workspace_id: str) -> dict[str, Any]:
@ -408,7 +411,7 @@ class Pyro:
vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
network_policy: str = DEFAULT_WORKSPACE_NETWORK_POLICY,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
seed_path: str | None = None,
secrets: list[dict[str, str]] | None = None,
@ -419,7 +422,7 @@ class Pyro:
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
network=network,
network_policy=network_policy,
allow_host_compat=allow_host_compat,
seed_path=seed_path,
secrets=secrets,
@ -574,6 +577,7 @@ class Pyro:
ready_timeout_seconds: int = 30,
ready_interval_ms: int = 500,
secret_env: dict[str, str] | None = None,
published_ports: list[dict[str, int | None]] | None = None,
) -> dict[str, Any]:
"""Start a named long-running service inside a workspace."""
readiness: dict[str, Any] | None = None
@ -594,6 +598,7 @@ class Pyro:
ready_timeout_seconds=ready_timeout_seconds,
ready_interval_ms=ready_interval_ms,
secret_env=secret_env,
published_ports=published_ports,
)
@server.tool()

View file

@ -160,6 +160,7 @@ def _print_workspace_summary_human(payload: dict[str, Any], *, action: str) -> N
print(f"Environment: {str(payload.get('environment', 'unknown'))}")
print(f"State: {str(payload.get('state', 'unknown'))}")
print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
print(f"Network policy: {str(payload.get('network_policy', 'off'))}")
workspace_seed = payload.get("workspace_seed")
if isinstance(workspace_seed, dict):
mode = str(workspace_seed.get("mode", "empty"))
@ -378,13 +379,27 @@ def _print_workspace_shell_read_human(payload: dict[str, Any]) -> None:
def _print_workspace_service_summary_human(payload: dict[str, Any], *, prefix: str) -> None:
published_ports = payload.get("published_ports")
published_text = ""
if isinstance(published_ports, list) and published_ports:
parts = []
for item in published_ports:
if not isinstance(item, dict):
continue
parts.append(
f"{str(item.get('host', '127.0.0.1'))}:{int(item.get('host_port', 0))}"
f"->{int(item.get('guest_port', 0))}/{str(item.get('protocol', 'tcp'))}"
)
if parts:
published_text = " published_ports=" + ",".join(parts)
print(
f"[{prefix}] "
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
f"service_name={str(payload.get('service_name', 'unknown'))} "
f"state={str(payload.get('state', 'unknown'))} "
f"cwd={str(payload.get('cwd', WORKSPACE_GUEST_PATH))} "
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}"
f"{published_text}",
file=sys.stderr,
flush=True,
)
@ -402,6 +417,18 @@ def _print_workspace_service_list_human(payload: dict[str, Any]) -> None:
f"{str(service.get('service_name', 'unknown'))} "
f"[{str(service.get('state', 'unknown'))}] "
f"cwd={str(service.get('cwd', WORKSPACE_GUEST_PATH))}"
+ (
" published="
+ ",".join(
f"{str(item.get('host', '127.0.0.1'))}:{int(item.get('host_port', 0))}"
f"->{int(item.get('guest_port', 0))}/{str(item.get('protocol', 'tcp'))}"
for item in service.get("published_ports", [])
if isinstance(item, dict)
)
if isinstance(service.get("published_ports"), list)
and service.get("published_ports")
else ""
)
)
@ -683,6 +710,7 @@ def _build_parser() -> argparse.ArgumentParser:
Examples:
pyro workspace create debian:12
pyro workspace create debian:12 --seed-path ./repo
pyro workspace create debian:12 --network-policy egress
pyro workspace create debian:12 --secret API_TOKEN=expected
pyro workspace sync push WORKSPACE_ID ./changes
pyro workspace snapshot create WORKSPACE_ID checkpoint
@ -718,9 +746,10 @@ def _build_parser() -> argparse.ArgumentParser:
help="Time-to-live for the workspace before automatic cleanup.",
)
workspace_create_parser.add_argument(
"--network",
action="store_true",
help="Enable outbound guest networking for the workspace guest.",
"--network-policy",
choices=("off", "egress", "egress+published-ports"),
default="off",
help="Workspace network policy.",
)
workspace_create_parser.add_argument(
"--allow-host-compat",
@ -1204,6 +1233,8 @@ def _build_parser() -> argparse.ArgumentParser:
Examples:
pyro workspace service start WORKSPACE_ID app --ready-file .ready -- \
sh -lc 'touch .ready && while true; do sleep 60; done'
pyro workspace service start WORKSPACE_ID app --ready-file .ready --publish 8080 -- \
sh -lc 'touch .ready && python3 -m http.server 8080'
pyro workspace service list WORKSPACE_ID
pyro workspace service status WORKSPACE_ID app
pyro workspace service logs WORKSPACE_ID app --tail-lines 50
@ -1229,6 +1260,9 @@ def _build_parser() -> argparse.ArgumentParser:
Examples:
pyro workspace service start WORKSPACE_ID app --ready-file .ready -- \
sh -lc 'touch .ready && while true; do sleep 60; done'
pyro workspace service start WORKSPACE_ID app \
--ready-file .ready --publish 18080:8080 -- \
sh -lc 'touch .ready && python3 -m http.server 8080'
pyro workspace service start WORKSPACE_ID app --secret-env API_TOKEN -- \
sh -lc 'test \"$API_TOKEN\" = \"expected\"; touch .ready; \
while true; do sleep 60; done'
@ -1280,6 +1314,16 @@ while true; do sleep 60; done'
metavar="SECRET[=ENV_VAR]",
help="Expose one persisted workspace secret as an environment variable for this service.",
)
workspace_service_start_parser.add_argument(
"--publish",
action="append",
default=[],
metavar="GUEST_PORT|HOST_PORT:GUEST_PORT",
help=(
"Publish one guest TCP port on 127.0.0.1. Requires workspace network policy "
"`egress+published-ports`."
),
)
workspace_service_start_parser.add_argument(
"--json",
action="store_true",
@ -1528,6 +1572,33 @@ def _parse_workspace_secret_env_options(values: list[str]) -> dict[str, str]:
return parsed
def _parse_workspace_publish_options(values: list[str]) -> list[dict[str, int | None]]:
parsed: list[dict[str, int | None]] = []
for raw_value in values:
candidate = raw_value.strip()
if candidate == "":
raise ValueError("published ports must not be empty")
if ":" in candidate:
raw_host_port, raw_guest_port = candidate.split(":", 1)
try:
host_port = int(raw_host_port)
guest_port = int(raw_guest_port)
except ValueError as exc:
raise ValueError(
"published ports must use GUEST_PORT or HOST_PORT:GUEST_PORT"
) from exc
parsed.append({"host_port": host_port, "guest_port": guest_port})
else:
try:
guest_port = int(candidate)
except ValueError as exc:
raise ValueError(
"published ports must use GUEST_PORT or HOST_PORT:GUEST_PORT"
) from exc
parsed.append({"host_port": None, "guest_port": guest_port})
return parsed
def main() -> None:
args = _build_parser().parse_args()
pyro = Pyro()
@ -1634,7 +1705,7 @@ def main() -> None:
vcpu_count=args.vcpu_count,
mem_mib=args.mem_mib,
ttl_seconds=args.ttl_seconds,
network=args.network,
network_policy=getattr(args, "network_policy", "off"),
allow_host_compat=args.allow_host_compat,
seed_path=args.seed_path,
secrets=secrets or None,
@ -1932,6 +2003,7 @@ def main() -> None:
readiness = {"type": "command", "command": args.ready_command}
command = _require_command(args.command_args)
secret_env = _parse_workspace_secret_env_options(getattr(args, "secret_env", []))
published_ports = _parse_workspace_publish_options(getattr(args, "publish", []))
try:
payload = pyro.start_service(
args.workspace_id,
@ -1942,6 +2014,7 @@ def main() -> None:
ready_timeout_seconds=args.ready_timeout_seconds,
ready_interval_ms=args.ready_interval_ms,
secret_env=secret_env or None,
published_ports=published_ports or None,
)
except Exception as exc: # noqa: BLE001
if bool(args.json):

View file

@ -27,7 +27,7 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
"--vcpu-count",
"--mem-mib",
"--ttl-seconds",
"--network",
"--network-policy",
"--allow-host-compat",
"--seed-path",
"--secret",
@ -49,6 +49,7 @@ PUBLIC_CLI_WORKSPACE_SERVICE_START_FLAGS = (
"--ready-timeout-seconds",
"--ready-interval-ms",
"--secret-env",
"--publish",
"--json",
)
PUBLIC_CLI_WORKSPACE_SERVICE_STATUS_FLAGS = ("--json",)

View file

@ -19,7 +19,7 @@ from typing import Any
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
DEFAULT_CATALOG_VERSION = "2.9.0"
DEFAULT_CATALOG_VERSION = "2.10.0"
OCI_MANIFEST_ACCEPT = ", ".join(
(
"application/vnd.oci.image.index.v1+json",

View file

@ -12,6 +12,7 @@ import shutil
import signal
import socket
import subprocess
import sys
import tarfile
import tempfile
import threading
@ -33,6 +34,7 @@ from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_en
from pyro_mcp.vm_firecracker import build_launch_plan
from pyro_mcp.vm_guest import VsockExecClient
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
from pyro_mcp.workspace_ports import DEFAULT_PUBLISHED_PORT_HOST
from pyro_mcp.workspace_shells import (
create_local_shell,
get_local_shell,
@ -43,6 +45,7 @@ from pyro_mcp.workspace_shells import (
VmState = Literal["created", "started", "stopped"]
WorkspaceShellState = Literal["running", "stopped"]
WorkspaceServiceState = Literal["running", "exited", "stopped", "failed"]
WorkspaceNetworkPolicy = Literal["off", "egress", "egress+published-ports"]
DEFAULT_VCPU_COUNT = 1
DEFAULT_MEM_MIB = 1024
@ -50,7 +53,7 @@ DEFAULT_TIMEOUT_SECONDS = 30
DEFAULT_TTL_SECONDS = 600
DEFAULT_ALLOW_HOST_COMPAT = False
WORKSPACE_LAYOUT_VERSION = 6
WORKSPACE_LAYOUT_VERSION = 7
WORKSPACE_BASELINE_DIRNAME = "baseline"
WORKSPACE_BASELINE_ARCHIVE_NAME = "workspace.tar"
WORKSPACE_SNAPSHOTS_DIRNAME = "snapshots"
@ -72,6 +75,7 @@ DEFAULT_SHELL_MAX_CHARS = 65536
DEFAULT_SERVICE_READY_TIMEOUT_SECONDS = 30
DEFAULT_SERVICE_READY_INTERVAL_MS = 500
DEFAULT_SERVICE_LOG_TAIL_LINES = 200
DEFAULT_WORKSPACE_NETWORK_POLICY: WorkspaceNetworkPolicy = "off"
WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names()
WORKSPACE_SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
WORKSPACE_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
@ -117,7 +121,7 @@ class WorkspaceRecord:
created_at: float
expires_at: float
state: VmState
network_requested: bool
network_policy: WorkspaceNetworkPolicy
allow_host_compat: bool
firecracker_pid: int | None = None
last_error: str | None = None
@ -135,6 +139,7 @@ class WorkspaceRecord:
cls,
instance: VmInstance,
*,
network_policy: WorkspaceNetworkPolicy = DEFAULT_WORKSPACE_NETWORK_POLICY,
command_count: int = 0,
last_command: dict[str, Any] | None = None,
workspace_seed: dict[str, Any] | None = None,
@ -149,7 +154,7 @@ class WorkspaceRecord:
created_at=instance.created_at,
expires_at=instance.expires_at,
state=instance.state,
network_requested=instance.network_requested,
network_policy=network_policy,
allow_host_compat=instance.allow_host_compat,
firecracker_pid=instance.firecracker_pid,
last_error=instance.last_error,
@ -174,7 +179,7 @@ class WorkspaceRecord:
expires_at=self.expires_at,
workdir=workdir,
state=self.state,
network_requested=self.network_requested,
network_requested=self.network_policy != "off",
allow_host_compat=self.allow_host_compat,
firecracker_pid=self.firecracker_pid,
last_error=self.last_error,
@ -193,7 +198,7 @@ class WorkspaceRecord:
"created_at": self.created_at,
"expires_at": self.expires_at,
"state": self.state,
"network_requested": self.network_requested,
"network_policy": self.network_policy,
"allow_host_compat": self.allow_host_compat,
"firecracker_pid": self.firecracker_pid,
"last_error": self.last_error,
@ -218,7 +223,7 @@ class WorkspaceRecord:
created_at=float(payload["created_at"]),
expires_at=float(payload["expires_at"]),
state=cast(VmState, str(payload.get("state", "stopped"))),
network_requested=bool(payload.get("network_requested", False)),
network_policy=_workspace_network_policy_from_payload(payload),
allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
firecracker_pid=_optional_int(payload.get("firecracker_pid")),
last_error=_optional_str(payload.get("last_error")),
@ -365,6 +370,7 @@ class WorkspaceServiceRecord:
pid: int | None = None
execution_mode: str = "pending"
stop_reason: str | None = None
published_ports: list[WorkspacePublishedPortRecord] = field(default_factory=list)
metadata: dict[str, str] = field(default_factory=dict)
def to_payload(self) -> dict[str, Any]:
@ -383,6 +389,9 @@ class WorkspaceServiceRecord:
"pid": self.pid,
"execution_mode": self.execution_mode,
"stop_reason": self.stop_reason,
"published_ports": [
published_port.to_payload() for published_port in self.published_ports
],
"metadata": dict(self.metadata),
}
@ -412,10 +421,53 @@ class WorkspaceServiceRecord:
pid=None if payload.get("pid") is None else int(payload.get("pid", 0)),
execution_mode=str(payload.get("execution_mode", "pending")),
stop_reason=_optional_str(payload.get("stop_reason")),
published_ports=_workspace_published_port_records(payload.get("published_ports")),
metadata=_string_dict(payload.get("metadata")),
)
@dataclass(frozen=True)
class WorkspacePublishedPortRecord:
"""Persisted localhost published-port metadata for one service."""
guest_port: int
host_port: int
host: str = DEFAULT_PUBLISHED_PORT_HOST
protocol: str = "tcp"
proxy_pid: int | None = None
def to_payload(self) -> dict[str, Any]:
return {
"guest_port": self.guest_port,
"host_port": self.host_port,
"host": self.host,
"protocol": self.protocol,
"proxy_pid": self.proxy_pid,
}
@classmethod
def from_payload(cls, payload: dict[str, Any]) -> WorkspacePublishedPortRecord:
return cls(
guest_port=int(payload["guest_port"]),
host_port=int(payload["host_port"]),
host=str(payload.get("host", DEFAULT_PUBLISHED_PORT_HOST)),
protocol=str(payload.get("protocol", "tcp")),
proxy_pid=(
None
if payload.get("proxy_pid") is None
else int(payload.get("proxy_pid", 0))
),
)
@dataclass(frozen=True)
class WorkspacePublishedPortSpec:
"""Requested published-port configuration for one service."""
guest_port: int
host_port: int | None = None
@dataclass(frozen=True)
class PreparedWorkspaceSeed:
"""Prepared host-side seed archive plus metadata."""
@ -534,6 +586,49 @@ def _workspace_seed_dict(value: object) -> dict[str, Any]:
return payload
def _normalize_workspace_network_policy(policy: str) -> WorkspaceNetworkPolicy:
normalized = policy.strip().lower()
if normalized not in {"off", "egress", "egress+published-ports"}:
raise ValueError("network_policy must be one of: off, egress, egress+published-ports")
return cast(WorkspaceNetworkPolicy, normalized)
def _workspace_network_policy_from_payload(payload: dict[str, Any]) -> WorkspaceNetworkPolicy:
raw_policy = payload.get("network_policy")
if raw_policy is not None:
return _normalize_workspace_network_policy(str(raw_policy))
raw_network_requested = payload.get("network_requested", False)
if isinstance(raw_network_requested, str):
network_requested = raw_network_requested.strip().lower() in {"1", "true", "yes", "on"}
else:
network_requested = bool(raw_network_requested)
if network_requested:
return "egress"
return DEFAULT_WORKSPACE_NETWORK_POLICY
def _serialize_workspace_published_port_public(
published_port: WorkspacePublishedPortRecord,
) -> dict[str, Any]:
return {
"host": published_port.host,
"host_port": published_port.host_port,
"guest_port": published_port.guest_port,
"protocol": published_port.protocol,
}
def _workspace_published_port_records(value: object) -> list[WorkspacePublishedPortRecord]:
if not isinstance(value, list):
return []
records: list[WorkspacePublishedPortRecord] = []
for item in value:
if not isinstance(item, dict):
continue
records.append(WorkspacePublishedPortRecord.from_payload(item))
return records
def _workspace_secret_records(value: object) -> list[WorkspaceSecretRecord]:
if not isinstance(value, list):
return []
@ -1159,6 +1254,59 @@ def _normalize_workspace_secret_env_mapping(
return normalized
def _normalize_workspace_published_port(
*,
guest_port: object,
host_port: object | None = None,
) -> WorkspacePublishedPortSpec:
if isinstance(guest_port, bool) or not isinstance(guest_port, int | str):
raise ValueError("published guest_port must be an integer")
try:
normalized_guest_port = int(guest_port)
except (TypeError, ValueError) as exc:
raise ValueError("published guest_port must be an integer") from exc
if normalized_guest_port <= 0 or normalized_guest_port > 65535:
raise ValueError("published guest_port must be between 1 and 65535")
normalized_host_port: int | None = None
if host_port is not None:
if isinstance(host_port, bool) or not isinstance(host_port, int | str):
raise ValueError("published host_port must be an integer")
try:
normalized_host_port = int(host_port)
except (TypeError, ValueError) as exc:
raise ValueError("published host_port must be an integer") from exc
if normalized_host_port <= 1024 or normalized_host_port > 65535:
raise ValueError("published host_port must be between 1025 and 65535")
return WorkspacePublishedPortSpec(
guest_port=normalized_guest_port,
host_port=normalized_host_port,
)
def _normalize_workspace_published_port_specs(
published_ports: list[dict[str, Any]] | None,
) -> list[WorkspacePublishedPortSpec]:
if not published_ports:
return []
normalized: list[WorkspacePublishedPortSpec] = []
seen_guest_ports: set[tuple[int | None, int]] = set()
for index, item in enumerate(published_ports, start=1):
if not isinstance(item, dict):
raise ValueError(f"published port #{index} must be a dictionary")
spec = _normalize_workspace_published_port(
guest_port=item.get("guest_port"),
host_port=item.get("host_port"),
)
dedupe_key = (spec.host_port, spec.guest_port)
if dedupe_key in seen_guest_ports:
raise ValueError(
"published ports must not repeat the same host/guest port mapping"
)
seen_guest_ports.add(dedupe_key)
normalized.append(spec)
return normalized
def _normalize_workspace_service_readiness(
readiness: dict[str, Any] | None,
) -> dict[str, Any] | None:
@ -1215,6 +1363,15 @@ def _workspace_service_runner_path(services_dir: Path, service_name: str) -> Pat
return services_dir / f"{service_name}.runner.sh"
def _workspace_service_port_ready_path(
services_dir: Path,
service_name: str,
host_port: int,
guest_port: int,
) -> Path:
return services_dir / f"{service_name}.port-{host_port}-to-{guest_port}.ready.json"
def _read_service_exit_code(status_path: Path) -> int | None:
if not status_path.exists():
return None
@ -1348,6 +1505,7 @@ def _refresh_local_service_record(
pid=service.pid,
execution_mode=service.execution_mode,
stop_reason=service.stop_reason,
published_ports=list(service.published_ports),
metadata=dict(service.metadata),
)
return refreshed
@ -1466,6 +1624,95 @@ def _stop_local_service(
return refreshed
def _start_workspace_published_port_proxy(
*,
services_dir: Path,
service_name: str,
workspace_id: str,
guest_ip: str,
spec: WorkspacePublishedPortSpec,
) -> WorkspacePublishedPortRecord:
ready_path = _workspace_service_port_ready_path(
services_dir,
service_name,
spec.host_port or 0,
spec.guest_port,
)
ready_path.unlink(missing_ok=True)
command = [
sys.executable,
"-m",
"pyro_mcp.workspace_ports",
"--listen-host",
DEFAULT_PUBLISHED_PORT_HOST,
"--listen-port",
str(spec.host_port or 0),
"--target-host",
guest_ip,
"--target-port",
str(spec.guest_port),
"--ready-file",
str(ready_path),
]
process = subprocess.Popen( # noqa: S603
command,
text=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
start_new_session=True,
)
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
if ready_path.exists():
payload = json.loads(ready_path.read_text(encoding="utf-8"))
if not isinstance(payload, dict):
raise RuntimeError("published port proxy ready payload is invalid")
ready_path.unlink(missing_ok=True)
return WorkspacePublishedPortRecord(
guest_port=int(payload.get("target_port", spec.guest_port)),
host_port=int(payload["host_port"]),
host=str(payload.get("host", DEFAULT_PUBLISHED_PORT_HOST)),
protocol=str(payload.get("protocol", "tcp")),
proxy_pid=process.pid,
)
if process.poll() is not None:
raise RuntimeError(
"failed to start published port proxy for "
f"service {service_name!r} in workspace {workspace_id!r}"
)
time.sleep(0.05)
_stop_workspace_published_port_proxy(
WorkspacePublishedPortRecord(
guest_port=spec.guest_port,
host_port=spec.host_port or 0,
proxy_pid=process.pid,
)
)
ready_path.unlink(missing_ok=True)
raise RuntimeError(
"timed out waiting for published port proxy readiness for "
f"service {service_name!r} in workspace {workspace_id!r}"
)
def _stop_workspace_published_port_proxy(published_port: WorkspacePublishedPortRecord) -> None:
if published_port.proxy_pid is None:
return
try:
os.killpg(published_port.proxy_pid, signal.SIGTERM)
except ProcessLookupError:
return
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
if not _pid_is_running(published_port.proxy_pid):
return
time.sleep(0.05)
try:
os.killpg(published_port.proxy_pid, signal.SIGKILL)
except ProcessLookupError:
return
def _instance_workspace_host_dir(instance: VmInstance) -> Path:
raw_value = instance.metadata.get("workspace_host_dir")
if raw_value is None or raw_value == "":
@ -3057,13 +3304,14 @@ class VmManager:
vcpu_count: int = DEFAULT_VCPU_COUNT,
mem_mib: int = DEFAULT_MEM_MIB,
ttl_seconds: int = DEFAULT_TTL_SECONDS,
network: bool = False,
network_policy: WorkspaceNetworkPolicy | str = DEFAULT_WORKSPACE_NETWORK_POLICY,
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
seed_path: str | Path | None = None,
secrets: list[dict[str, str]] | None = None,
) -> dict[str, Any]:
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
get_environment(environment, runtime_paths=self._runtime_paths)
normalized_network_policy = _normalize_workspace_network_policy(str(network_policy))
prepared_seed = self._prepare_workspace_seed(seed_path)
now = time.time()
workspace_id = uuid.uuid4().hex[:12]
@ -3097,12 +3345,13 @@ class VmManager:
created_at=now,
expires_at=now + ttl_seconds,
workdir=runtime_dir,
network_requested=network,
network_requested=normalized_network_policy != "off",
allow_host_compat=allow_host_compat,
)
instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
instance.metadata["workspace_path"] = WORKSPACE_GUEST_PATH
instance.metadata["workspace_host_dir"] = str(host_workspace_dir)
instance.metadata["network_policy"] = normalized_network_policy
try:
with self._lock:
self._reap_expired_locked(now)
@ -3112,6 +3361,9 @@ class VmManager:
raise RuntimeError(
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
)
self._require_workspace_network_policy_support(
network_policy=normalized_network_policy
)
self._backend.create(instance)
if self._runtime_capabilities.supports_guest_exec:
self._ensure_workspace_guest_bootstrap_support(instance)
@ -3119,6 +3371,7 @@ class VmManager:
self._start_instance_locked(instance)
workspace = WorkspaceRecord.from_instance(
instance,
network_policy=normalized_network_policy,
workspace_seed=prepared_seed.to_payload(),
secrets=secret_records,
)
@ -3435,6 +3688,9 @@ class VmManager:
recreated = workspace.to_instance(
workdir=self._workspace_runtime_dir(workspace.workspace_id)
)
self._require_workspace_network_policy_support(
network_policy=workspace.network_policy
)
self._backend.create(recreated)
if self._runtime_capabilities.supports_guest_exec:
self._ensure_workspace_guest_bootstrap_support(recreated)
@ -3798,11 +4054,13 @@ class VmManager:
ready_timeout_seconds: int = DEFAULT_SERVICE_READY_TIMEOUT_SECONDS,
ready_interval_ms: int = DEFAULT_SERVICE_READY_INTERVAL_MS,
secret_env: dict[str, str] | None = None,
published_ports: list[dict[str, Any]] | None = None,
) -> dict[str, Any]:
normalized_service_name = _normalize_workspace_service_name(service_name)
normalized_cwd, _ = _normalize_workspace_destination(cwd)
normalized_readiness = _normalize_workspace_service_readiness(readiness)
normalized_secret_env = _normalize_workspace_secret_env_mapping(secret_env)
normalized_published_ports = _normalize_workspace_published_port_specs(published_ports)
if ready_timeout_seconds <= 0:
raise ValueError("ready_timeout_seconds must be positive")
if ready_interval_ms <= 0:
@ -3810,6 +4068,16 @@ class VmManager:
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
instance = self._workspace_instance_for_live_service_locked(workspace)
if normalized_published_ports:
if workspace.network_policy != "egress+published-ports":
raise RuntimeError(
"published ports require workspace network_policy "
"'egress+published-ports'"
)
if instance.network is None:
raise RuntimeError(
"published ports require an active guest network configuration"
)
redact_values = self._workspace_secret_redact_values_locked(workspace)
env_values = self._workspace_secret_env_values_locked(workspace, normalized_secret_env)
if workspace.secrets and normalized_secret_env:
@ -3852,6 +4120,36 @@ class VmManager:
service_name=normalized_service_name,
payload=payload,
)
if normalized_published_ports:
assert instance.network is not None # guarded above
try:
service.published_ports = self._start_workspace_service_published_ports(
workspace=workspace,
service=service,
guest_ip=instance.network.guest_ip,
published_ports=normalized_published_ports,
)
except Exception:
try:
failed_payload = self._backend.stop_service(
instance,
workspace_id=workspace_id,
service_name=normalized_service_name,
)
service = self._workspace_service_record_from_payload(
workspace_id=workspace_id,
service_name=normalized_service_name,
payload=failed_payload,
published_ports=[],
)
except Exception:
service.state = "failed"
service.stop_reason = "published_port_failed"
service.ended_at = service.ended_at or time.time()
else:
service.state = "failed"
service.stop_reason = "published_port_failed"
service.ended_at = service.ended_at or time.time()
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
@ -3916,7 +4214,21 @@ class VmManager:
service_name=normalized_service_name,
payload=payload,
metadata=service.metadata,
published_ports=service.published_ports,
)
if service.published_ports:
for published_port in service.published_ports:
_stop_workspace_published_port_proxy(published_port)
service.published_ports = [
WorkspacePublishedPortRecord(
guest_port=published_port.guest_port,
host_port=published_port.host_port,
host=published_port.host,
protocol=published_port.protocol,
proxy_pid=None,
)
for published_port in service.published_ports
]
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
workspace.state = instance.state
@ -3956,6 +4268,7 @@ class VmManager:
service_name=normalized_service_name,
payload=payload,
metadata=service.metadata,
published_ports=service.published_ports,
)
with self._lock:
workspace = self._load_workspace_locked(workspace_id)
@ -4059,6 +4372,7 @@ class VmManager:
"created_at": workspace.created_at,
"expires_at": workspace.expires_at,
"state": workspace.state,
"network_policy": workspace.network_policy,
"network_enabled": workspace.network is not None,
"allow_host_compat": workspace.allow_host_compat,
"guest_ip": workspace.network.guest_ip if workspace.network is not None else None,
@ -4107,6 +4421,10 @@ class VmManager:
"readiness": dict(service.readiness) if service.readiness is not None else None,
"ready_at": service.ready_at,
"stop_reason": service.stop_reason,
"published_ports": [
_serialize_workspace_published_port_public(published_port)
for published_port in service.published_ports
],
}
def _serialize_workspace_snapshot(self, snapshot: WorkspaceSnapshotRecord) -> dict[str, Any]:
@ -4182,6 +4500,23 @@ class VmManager:
f"workspace: {reason}"
)
def _require_workspace_network_policy_support(
self,
*,
network_policy: WorkspaceNetworkPolicy,
) -> None:
if network_policy == "off":
return
if self._runtime_capabilities.supports_guest_network:
return
reason = self._runtime_capabilities.reason or (
"runtime does not support guest-backed workspace networking"
)
raise RuntimeError(
"workspace network_policy requires guest networking and is unavailable for this "
f"workspace: {reason}"
)
def _workspace_secret_values_locked(self, workspace: WorkspaceRecord) -> dict[str, str]:
return _load_workspace_secret_values(
workspace_dir=self._workspace_dir(workspace.workspace_id),
@ -4609,9 +4944,15 @@ class VmManager:
service_name: str,
payload: dict[str, Any],
metadata: dict[str, str] | None = None,
published_ports: list[WorkspacePublishedPortRecord] | None = None,
) -> WorkspaceServiceRecord:
readiness_payload = payload.get("readiness")
readiness = dict(readiness_payload) if isinstance(readiness_payload, dict) else None
normalized_published_ports = _workspace_published_port_records(
payload.get("published_ports")
)
if not normalized_published_ports and published_ports is not None:
normalized_published_ports = list(published_ports)
return WorkspaceServiceRecord(
workspace_id=workspace_id,
service_name=str(payload.get("service_name", service_name)),
@ -4632,6 +4973,7 @@ class VmManager:
pid=None if payload.get("pid") is None else int(payload.get("pid", 0)),
execution_mode=str(payload.get("execution_mode", "pending")),
stop_reason=_optional_str(payload.get("stop_reason")),
published_ports=normalized_published_ports,
metadata=dict(metadata or {}),
)
@ -4652,6 +4994,33 @@ class VmManager:
services = self._list_workspace_services_locked(workspace_id)
return len(services), sum(1 for service in services if service.state == "running")
def _start_workspace_service_published_ports(
self,
*,
workspace: WorkspaceRecord,
service: WorkspaceServiceRecord,
guest_ip: str,
published_ports: list[WorkspacePublishedPortSpec],
) -> list[WorkspacePublishedPortRecord]:
services_dir = self._workspace_services_dir(workspace.workspace_id)
started: list[WorkspacePublishedPortRecord] = []
try:
for spec in published_ports:
started.append(
_start_workspace_published_port_proxy(
services_dir=services_dir,
service_name=service.service_name,
workspace_id=workspace.workspace_id,
guest_ip=guest_ip,
spec=spec,
)
)
except Exception:
for published_port in started:
_stop_workspace_published_port_proxy(published_port)
raise
return started
def _workspace_baseline_snapshot_locked(
self,
workspace: WorkspaceRecord,
@ -4770,12 +5139,18 @@ class VmManager:
workspace_id: str,
service_name: str,
) -> None:
existing = self._load_workspace_service_locked_optional(workspace_id, service_name)
if existing is not None:
for published_port in existing.published_ports:
_stop_workspace_published_port_proxy(published_port)
self._workspace_service_record_path(workspace_id, service_name).unlink(missing_ok=True)
services_dir = self._workspace_services_dir(workspace_id)
_workspace_service_stdout_path(services_dir, service_name).unlink(missing_ok=True)
_workspace_service_stderr_path(services_dir, service_name).unlink(missing_ok=True)
_workspace_service_status_path(services_dir, service_name).unlink(missing_ok=True)
_workspace_service_runner_path(services_dir, service_name).unlink(missing_ok=True)
for ready_path in services_dir.glob(f"{service_name}.port-*.ready.json"):
ready_path.unlink(missing_ok=True)
def _delete_workspace_snapshot_locked(self, workspace_id: str, snapshot_name: str) -> None:
self._workspace_snapshot_metadata_path(workspace_id, snapshot_name).unlink(missing_ok=True)
@ -4881,7 +5256,21 @@ class VmManager:
service_name=service.service_name,
payload=payload,
metadata=service.metadata,
published_ports=service.published_ports,
)
if refreshed.state != "running" and refreshed.published_ports:
refreshed.published_ports = [
WorkspacePublishedPortRecord(
guest_port=published_port.guest_port,
host_port=published_port.host_port,
host=published_port.host,
protocol=published_port.protocol,
proxy_pid=None,
)
for published_port in refreshed.published_ports
]
for published_port in service.published_ports:
_stop_workspace_published_port_proxy(published_port)
self._save_workspace_service_locked(refreshed)
return refreshed
@ -4904,6 +5293,8 @@ class VmManager:
changed = False
for service in services:
if service.state == "running":
for published_port in service.published_ports:
_stop_workspace_published_port_proxy(published_port)
service.state = "stopped"
service.stop_reason = "workspace_stopped"
service.ended_at = service.ended_at or time.time()
@ -4936,6 +5327,7 @@ class VmManager:
service_name=service.service_name,
payload=payload,
metadata=service.metadata,
published_ports=service.published_ports,
)
self._save_workspace_service_locked(stopped)
except Exception:

View file

@ -0,0 +1,116 @@
"""Localhost-only TCP port proxy for published workspace services."""
from __future__ import annotations
import argparse
import json
import selectors
import signal
import socket
import socketserver
import sys
import threading
from pathlib import Path
DEFAULT_PUBLISHED_PORT_HOST = "127.0.0.1"
class _ProxyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
allow_reuse_address = False
daemon_threads = True
def __init__(self, server_address: tuple[str, int], target_address: tuple[str, int]) -> None:
super().__init__(server_address, _ProxyHandler)
self.target_address = target_address
class _ProxyHandler(socketserver.BaseRequestHandler):
def handle(self) -> None:
server = self.server
if not isinstance(server, _ProxyServer):
raise RuntimeError("proxy server is invalid")
try:
upstream = socket.create_connection(server.target_address, timeout=5)
except OSError:
return
with upstream:
self.request.setblocking(False)
upstream.setblocking(False)
selector = selectors.DefaultSelector()
try:
selector.register(self.request, selectors.EVENT_READ, upstream)
selector.register(upstream, selectors.EVENT_READ, self.request)
while True:
events = selector.select()
if not events:
continue
for key, _ in events:
source = key.fileobj
target = key.data
if not isinstance(source, socket.socket) or not isinstance(
target, socket.socket
):
continue
try:
chunk = source.recv(65536)
except OSError:
return
if not chunk:
return
try:
target.sendall(chunk)
except OSError:
return
finally:
selector.close()
def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run a localhost-only TCP port proxy.")
parser.add_argument("--listen-host", required=True)
parser.add_argument("--listen-port", type=int, required=True)
parser.add_argument("--target-host", required=True)
parser.add_argument("--target-port", type=int, required=True)
parser.add_argument("--ready-file", required=True)
return parser
def main(argv: list[str] | None = None) -> int:
args = _build_parser().parse_args(argv)
ready_file = Path(args.ready_file)
ready_file.parent.mkdir(parents=True, exist_ok=True)
server = _ProxyServer(
(str(args.listen_host), int(args.listen_port)),
(str(args.target_host), int(args.target_port)),
)
actual_host = str(server.server_address[0])
actual_port = int(server.server_address[1])
ready_file.write_text(
json.dumps(
{
"host": actual_host,
"host_port": actual_port,
"target_host": args.target_host,
"target_port": int(args.target_port),
"protocol": "tcp",
},
indent=2,
sort_keys=True,
),
encoding="utf-8",
)
def _shutdown(_: int, __: object) -> None:
threading.Thread(target=server.shutdown, daemon=True).start()
signal.signal(signal.SIGTERM, _shutdown)
signal.signal(signal.SIGINT, _shutdown)
try:
server.serve_forever(poll_interval=0.2)
finally:
server.server_close()
return 0
if __name__ == "__main__":
raise SystemExit(main(sys.argv[1:]))