Add workspace network policy and published ports
Replace the workspace-level boolean network toggle with explicit network policies and attach localhost TCP publication to workspace services. Persist network_policy in workspace records, validate --publish requests, and run host-side proxy helpers that follow the service lifecycle so published ports are cleaned up on failure, stop, reset, and delete. Update the CLI, SDK, MCP contract, docs, roadmap, and examples for the new policy model, add coverage for the proxy and manager edge cases, and validate with uv lock, UV_CACHE_DIR=.uv-cache make check, UV_CACHE_DIR=.uv-cache make dist-check, and a real guest-backed published-port probe smoke.
This commit is contained in:
parent
fc72fcd3a1
commit
c82f4629b2
21 changed files with 1944 additions and 49 deletions
|
|
@ -13,6 +13,7 @@ from pyro_mcp.vm_manager import (
|
|||
DEFAULT_TIMEOUT_SECONDS,
|
||||
DEFAULT_TTL_SECONDS,
|
||||
DEFAULT_VCPU_COUNT,
|
||||
DEFAULT_WORKSPACE_NETWORK_POLICY,
|
||||
VmManager,
|
||||
)
|
||||
|
||||
|
|
@ -84,7 +85,7 @@ class Pyro:
|
|||
vcpu_count: int = DEFAULT_VCPU_COUNT,
|
||||
mem_mib: int = DEFAULT_MEM_MIB,
|
||||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
network_policy: str = DEFAULT_WORKSPACE_NETWORK_POLICY,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
seed_path: str | Path | None = None,
|
||||
secrets: list[dict[str, str]] | None = None,
|
||||
|
|
@ -94,7 +95,7 @@ class Pyro:
|
|||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
network_policy=network_policy,
|
||||
allow_host_compat=allow_host_compat,
|
||||
seed_path=seed_path,
|
||||
secrets=secrets,
|
||||
|
|
@ -241,6 +242,7 @@ class Pyro:
|
|||
ready_timeout_seconds: int = 30,
|
||||
ready_interval_ms: int = 500,
|
||||
secret_env: dict[str, str] | None = None,
|
||||
published_ports: list[dict[str, int | None]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
return self._manager.start_service(
|
||||
workspace_id,
|
||||
|
|
@ -251,6 +253,7 @@ class Pyro:
|
|||
ready_timeout_seconds=ready_timeout_seconds,
|
||||
ready_interval_ms=ready_interval_ms,
|
||||
secret_env=secret_env,
|
||||
published_ports=published_ports,
|
||||
)
|
||||
|
||||
def list_services(self, workspace_id: str) -> dict[str, Any]:
|
||||
|
|
@ -408,7 +411,7 @@ class Pyro:
|
|||
vcpu_count: int = DEFAULT_VCPU_COUNT,
|
||||
mem_mib: int = DEFAULT_MEM_MIB,
|
||||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
network_policy: str = DEFAULT_WORKSPACE_NETWORK_POLICY,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
seed_path: str | None = None,
|
||||
secrets: list[dict[str, str]] | None = None,
|
||||
|
|
@ -419,7 +422,7 @@ class Pyro:
|
|||
vcpu_count=vcpu_count,
|
||||
mem_mib=mem_mib,
|
||||
ttl_seconds=ttl_seconds,
|
||||
network=network,
|
||||
network_policy=network_policy,
|
||||
allow_host_compat=allow_host_compat,
|
||||
seed_path=seed_path,
|
||||
secrets=secrets,
|
||||
|
|
@ -574,6 +577,7 @@ class Pyro:
|
|||
ready_timeout_seconds: int = 30,
|
||||
ready_interval_ms: int = 500,
|
||||
secret_env: dict[str, str] | None = None,
|
||||
published_ports: list[dict[str, int | None]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Start a named long-running service inside a workspace."""
|
||||
readiness: dict[str, Any] | None = None
|
||||
|
|
@ -594,6 +598,7 @@ class Pyro:
|
|||
ready_timeout_seconds=ready_timeout_seconds,
|
||||
ready_interval_ms=ready_interval_ms,
|
||||
secret_env=secret_env,
|
||||
published_ports=published_ports,
|
||||
)
|
||||
|
||||
@server.tool()
|
||||
|
|
|
|||
|
|
@ -160,6 +160,7 @@ def _print_workspace_summary_human(payload: dict[str, Any], *, action: str) -> N
|
|||
print(f"Environment: {str(payload.get('environment', 'unknown'))}")
|
||||
print(f"State: {str(payload.get('state', 'unknown'))}")
|
||||
print(f"Workspace: {str(payload.get('workspace_path', '/workspace'))}")
|
||||
print(f"Network policy: {str(payload.get('network_policy', 'off'))}")
|
||||
workspace_seed = payload.get("workspace_seed")
|
||||
if isinstance(workspace_seed, dict):
|
||||
mode = str(workspace_seed.get("mode", "empty"))
|
||||
|
|
@ -378,13 +379,27 @@ def _print_workspace_shell_read_human(payload: dict[str, Any]) -> None:
|
|||
|
||||
|
||||
def _print_workspace_service_summary_human(payload: dict[str, Any], *, prefix: str) -> None:
|
||||
published_ports = payload.get("published_ports")
|
||||
published_text = ""
|
||||
if isinstance(published_ports, list) and published_ports:
|
||||
parts = []
|
||||
for item in published_ports:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
parts.append(
|
||||
f"{str(item.get('host', '127.0.0.1'))}:{int(item.get('host_port', 0))}"
|
||||
f"->{int(item.get('guest_port', 0))}/{str(item.get('protocol', 'tcp'))}"
|
||||
)
|
||||
if parts:
|
||||
published_text = " published_ports=" + ",".join(parts)
|
||||
print(
|
||||
f"[{prefix}] "
|
||||
f"workspace_id={str(payload.get('workspace_id', 'unknown'))} "
|
||||
f"service_name={str(payload.get('service_name', 'unknown'))} "
|
||||
f"state={str(payload.get('state', 'unknown'))} "
|
||||
f"cwd={str(payload.get('cwd', WORKSPACE_GUEST_PATH))} "
|
||||
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}",
|
||||
f"execution_mode={str(payload.get('execution_mode', 'unknown'))}"
|
||||
f"{published_text}",
|
||||
file=sys.stderr,
|
||||
flush=True,
|
||||
)
|
||||
|
|
@ -402,6 +417,18 @@ def _print_workspace_service_list_human(payload: dict[str, Any]) -> None:
|
|||
f"{str(service.get('service_name', 'unknown'))} "
|
||||
f"[{str(service.get('state', 'unknown'))}] "
|
||||
f"cwd={str(service.get('cwd', WORKSPACE_GUEST_PATH))}"
|
||||
+ (
|
||||
" published="
|
||||
+ ",".join(
|
||||
f"{str(item.get('host', '127.0.0.1'))}:{int(item.get('host_port', 0))}"
|
||||
f"->{int(item.get('guest_port', 0))}/{str(item.get('protocol', 'tcp'))}"
|
||||
for item in service.get("published_ports", [])
|
||||
if isinstance(item, dict)
|
||||
)
|
||||
if isinstance(service.get("published_ports"), list)
|
||||
and service.get("published_ports")
|
||||
else ""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -683,6 +710,7 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
Examples:
|
||||
pyro workspace create debian:12
|
||||
pyro workspace create debian:12 --seed-path ./repo
|
||||
pyro workspace create debian:12 --network-policy egress
|
||||
pyro workspace create debian:12 --secret API_TOKEN=expected
|
||||
pyro workspace sync push WORKSPACE_ID ./changes
|
||||
pyro workspace snapshot create WORKSPACE_ID checkpoint
|
||||
|
|
@ -718,9 +746,10 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
help="Time-to-live for the workspace before automatic cleanup.",
|
||||
)
|
||||
workspace_create_parser.add_argument(
|
||||
"--network",
|
||||
action="store_true",
|
||||
help="Enable outbound guest networking for the workspace guest.",
|
||||
"--network-policy",
|
||||
choices=("off", "egress", "egress+published-ports"),
|
||||
default="off",
|
||||
help="Workspace network policy.",
|
||||
)
|
||||
workspace_create_parser.add_argument(
|
||||
"--allow-host-compat",
|
||||
|
|
@ -1204,6 +1233,8 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
Examples:
|
||||
pyro workspace service start WORKSPACE_ID app --ready-file .ready -- \
|
||||
sh -lc 'touch .ready && while true; do sleep 60; done'
|
||||
pyro workspace service start WORKSPACE_ID app --ready-file .ready --publish 8080 -- \
|
||||
sh -lc 'touch .ready && python3 -m http.server 8080'
|
||||
pyro workspace service list WORKSPACE_ID
|
||||
pyro workspace service status WORKSPACE_ID app
|
||||
pyro workspace service logs WORKSPACE_ID app --tail-lines 50
|
||||
|
|
@ -1229,6 +1260,9 @@ def _build_parser() -> argparse.ArgumentParser:
|
|||
Examples:
|
||||
pyro workspace service start WORKSPACE_ID app --ready-file .ready -- \
|
||||
sh -lc 'touch .ready && while true; do sleep 60; done'
|
||||
pyro workspace service start WORKSPACE_ID app \
|
||||
--ready-file .ready --publish 18080:8080 -- \
|
||||
sh -lc 'touch .ready && python3 -m http.server 8080'
|
||||
pyro workspace service start WORKSPACE_ID app --secret-env API_TOKEN -- \
|
||||
sh -lc 'test \"$API_TOKEN\" = \"expected\"; touch .ready; \
|
||||
while true; do sleep 60; done'
|
||||
|
|
@ -1280,6 +1314,16 @@ while true; do sleep 60; done'
|
|||
metavar="SECRET[=ENV_VAR]",
|
||||
help="Expose one persisted workspace secret as an environment variable for this service.",
|
||||
)
|
||||
workspace_service_start_parser.add_argument(
|
||||
"--publish",
|
||||
action="append",
|
||||
default=[],
|
||||
metavar="GUEST_PORT|HOST_PORT:GUEST_PORT",
|
||||
help=(
|
||||
"Publish one guest TCP port on 127.0.0.1. Requires workspace network policy "
|
||||
"`egress+published-ports`."
|
||||
),
|
||||
)
|
||||
workspace_service_start_parser.add_argument(
|
||||
"--json",
|
||||
action="store_true",
|
||||
|
|
@ -1528,6 +1572,33 @@ def _parse_workspace_secret_env_options(values: list[str]) -> dict[str, str]:
|
|||
return parsed
|
||||
|
||||
|
||||
def _parse_workspace_publish_options(values: list[str]) -> list[dict[str, int | None]]:
|
||||
parsed: list[dict[str, int | None]] = []
|
||||
for raw_value in values:
|
||||
candidate = raw_value.strip()
|
||||
if candidate == "":
|
||||
raise ValueError("published ports must not be empty")
|
||||
if ":" in candidate:
|
||||
raw_host_port, raw_guest_port = candidate.split(":", 1)
|
||||
try:
|
||||
host_port = int(raw_host_port)
|
||||
guest_port = int(raw_guest_port)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"published ports must use GUEST_PORT or HOST_PORT:GUEST_PORT"
|
||||
) from exc
|
||||
parsed.append({"host_port": host_port, "guest_port": guest_port})
|
||||
else:
|
||||
try:
|
||||
guest_port = int(candidate)
|
||||
except ValueError as exc:
|
||||
raise ValueError(
|
||||
"published ports must use GUEST_PORT or HOST_PORT:GUEST_PORT"
|
||||
) from exc
|
||||
parsed.append({"host_port": None, "guest_port": guest_port})
|
||||
return parsed
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _build_parser().parse_args()
|
||||
pyro = Pyro()
|
||||
|
|
@ -1634,7 +1705,7 @@ def main() -> None:
|
|||
vcpu_count=args.vcpu_count,
|
||||
mem_mib=args.mem_mib,
|
||||
ttl_seconds=args.ttl_seconds,
|
||||
network=args.network,
|
||||
network_policy=getattr(args, "network_policy", "off"),
|
||||
allow_host_compat=args.allow_host_compat,
|
||||
seed_path=args.seed_path,
|
||||
secrets=secrets or None,
|
||||
|
|
@ -1932,6 +2003,7 @@ def main() -> None:
|
|||
readiness = {"type": "command", "command": args.ready_command}
|
||||
command = _require_command(args.command_args)
|
||||
secret_env = _parse_workspace_secret_env_options(getattr(args, "secret_env", []))
|
||||
published_ports = _parse_workspace_publish_options(getattr(args, "publish", []))
|
||||
try:
|
||||
payload = pyro.start_service(
|
||||
args.workspace_id,
|
||||
|
|
@ -1942,6 +2014,7 @@ def main() -> None:
|
|||
ready_timeout_seconds=args.ready_timeout_seconds,
|
||||
ready_interval_ms=args.ready_interval_ms,
|
||||
secret_env=secret_env or None,
|
||||
published_ports=published_ports or None,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
if bool(args.json):
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ PUBLIC_CLI_WORKSPACE_CREATE_FLAGS = (
|
|||
"--vcpu-count",
|
||||
"--mem-mib",
|
||||
"--ttl-seconds",
|
||||
"--network",
|
||||
"--network-policy",
|
||||
"--allow-host-compat",
|
||||
"--seed-path",
|
||||
"--secret",
|
||||
|
|
@ -49,6 +49,7 @@ PUBLIC_CLI_WORKSPACE_SERVICE_START_FLAGS = (
|
|||
"--ready-timeout-seconds",
|
||||
"--ready-interval-ms",
|
||||
"--secret-env",
|
||||
"--publish",
|
||||
"--json",
|
||||
)
|
||||
PUBLIC_CLI_WORKSPACE_SERVICE_STATUS_FLAGS = ("--json",)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from typing import Any
|
|||
from pyro_mcp.runtime import DEFAULT_PLATFORM, RuntimePaths
|
||||
|
||||
DEFAULT_ENVIRONMENT_VERSION = "1.0.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.9.0"
|
||||
DEFAULT_CATALOG_VERSION = "2.10.0"
|
||||
OCI_MANIFEST_ACCEPT = ", ".join(
|
||||
(
|
||||
"application/vnd.oci.image.index.v1+json",
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ import shutil
|
|||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tarfile
|
||||
import tempfile
|
||||
import threading
|
||||
|
|
@ -33,6 +34,7 @@ from pyro_mcp.vm_environments import EnvironmentStore, default_cache_dir, get_en
|
|||
from pyro_mcp.vm_firecracker import build_launch_plan
|
||||
from pyro_mcp.vm_guest import VsockExecClient
|
||||
from pyro_mcp.vm_network import NetworkConfig, TapNetworkManager
|
||||
from pyro_mcp.workspace_ports import DEFAULT_PUBLISHED_PORT_HOST
|
||||
from pyro_mcp.workspace_shells import (
|
||||
create_local_shell,
|
||||
get_local_shell,
|
||||
|
|
@ -43,6 +45,7 @@ from pyro_mcp.workspace_shells import (
|
|||
VmState = Literal["created", "started", "stopped"]
|
||||
WorkspaceShellState = Literal["running", "stopped"]
|
||||
WorkspaceServiceState = Literal["running", "exited", "stopped", "failed"]
|
||||
WorkspaceNetworkPolicy = Literal["off", "egress", "egress+published-ports"]
|
||||
|
||||
DEFAULT_VCPU_COUNT = 1
|
||||
DEFAULT_MEM_MIB = 1024
|
||||
|
|
@ -50,7 +53,7 @@ DEFAULT_TIMEOUT_SECONDS = 30
|
|||
DEFAULT_TTL_SECONDS = 600
|
||||
DEFAULT_ALLOW_HOST_COMPAT = False
|
||||
|
||||
WORKSPACE_LAYOUT_VERSION = 6
|
||||
WORKSPACE_LAYOUT_VERSION = 7
|
||||
WORKSPACE_BASELINE_DIRNAME = "baseline"
|
||||
WORKSPACE_BASELINE_ARCHIVE_NAME = "workspace.tar"
|
||||
WORKSPACE_SNAPSHOTS_DIRNAME = "snapshots"
|
||||
|
|
@ -72,6 +75,7 @@ DEFAULT_SHELL_MAX_CHARS = 65536
|
|||
DEFAULT_SERVICE_READY_TIMEOUT_SECONDS = 30
|
||||
DEFAULT_SERVICE_READY_INTERVAL_MS = 500
|
||||
DEFAULT_SERVICE_LOG_TAIL_LINES = 200
|
||||
DEFAULT_WORKSPACE_NETWORK_POLICY: WorkspaceNetworkPolicy = "off"
|
||||
WORKSPACE_SHELL_SIGNAL_NAMES = shell_signal_names()
|
||||
WORKSPACE_SERVICE_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
|
||||
WORKSPACE_SNAPSHOT_NAME_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$")
|
||||
|
|
@ -117,7 +121,7 @@ class WorkspaceRecord:
|
|||
created_at: float
|
||||
expires_at: float
|
||||
state: VmState
|
||||
network_requested: bool
|
||||
network_policy: WorkspaceNetworkPolicy
|
||||
allow_host_compat: bool
|
||||
firecracker_pid: int | None = None
|
||||
last_error: str | None = None
|
||||
|
|
@ -135,6 +139,7 @@ class WorkspaceRecord:
|
|||
cls,
|
||||
instance: VmInstance,
|
||||
*,
|
||||
network_policy: WorkspaceNetworkPolicy = DEFAULT_WORKSPACE_NETWORK_POLICY,
|
||||
command_count: int = 0,
|
||||
last_command: dict[str, Any] | None = None,
|
||||
workspace_seed: dict[str, Any] | None = None,
|
||||
|
|
@ -149,7 +154,7 @@ class WorkspaceRecord:
|
|||
created_at=instance.created_at,
|
||||
expires_at=instance.expires_at,
|
||||
state=instance.state,
|
||||
network_requested=instance.network_requested,
|
||||
network_policy=network_policy,
|
||||
allow_host_compat=instance.allow_host_compat,
|
||||
firecracker_pid=instance.firecracker_pid,
|
||||
last_error=instance.last_error,
|
||||
|
|
@ -174,7 +179,7 @@ class WorkspaceRecord:
|
|||
expires_at=self.expires_at,
|
||||
workdir=workdir,
|
||||
state=self.state,
|
||||
network_requested=self.network_requested,
|
||||
network_requested=self.network_policy != "off",
|
||||
allow_host_compat=self.allow_host_compat,
|
||||
firecracker_pid=self.firecracker_pid,
|
||||
last_error=self.last_error,
|
||||
|
|
@ -193,7 +198,7 @@ class WorkspaceRecord:
|
|||
"created_at": self.created_at,
|
||||
"expires_at": self.expires_at,
|
||||
"state": self.state,
|
||||
"network_requested": self.network_requested,
|
||||
"network_policy": self.network_policy,
|
||||
"allow_host_compat": self.allow_host_compat,
|
||||
"firecracker_pid": self.firecracker_pid,
|
||||
"last_error": self.last_error,
|
||||
|
|
@ -218,7 +223,7 @@ class WorkspaceRecord:
|
|||
created_at=float(payload["created_at"]),
|
||||
expires_at=float(payload["expires_at"]),
|
||||
state=cast(VmState, str(payload.get("state", "stopped"))),
|
||||
network_requested=bool(payload.get("network_requested", False)),
|
||||
network_policy=_workspace_network_policy_from_payload(payload),
|
||||
allow_host_compat=bool(payload.get("allow_host_compat", DEFAULT_ALLOW_HOST_COMPAT)),
|
||||
firecracker_pid=_optional_int(payload.get("firecracker_pid")),
|
||||
last_error=_optional_str(payload.get("last_error")),
|
||||
|
|
@ -365,6 +370,7 @@ class WorkspaceServiceRecord:
|
|||
pid: int | None = None
|
||||
execution_mode: str = "pending"
|
||||
stop_reason: str | None = None
|
||||
published_ports: list[WorkspacePublishedPortRecord] = field(default_factory=list)
|
||||
metadata: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
|
|
@ -383,6 +389,9 @@ class WorkspaceServiceRecord:
|
|||
"pid": self.pid,
|
||||
"execution_mode": self.execution_mode,
|
||||
"stop_reason": self.stop_reason,
|
||||
"published_ports": [
|
||||
published_port.to_payload() for published_port in self.published_ports
|
||||
],
|
||||
"metadata": dict(self.metadata),
|
||||
}
|
||||
|
||||
|
|
@ -412,10 +421,53 @@ class WorkspaceServiceRecord:
|
|||
pid=None if payload.get("pid") is None else int(payload.get("pid", 0)),
|
||||
execution_mode=str(payload.get("execution_mode", "pending")),
|
||||
stop_reason=_optional_str(payload.get("stop_reason")),
|
||||
published_ports=_workspace_published_port_records(payload.get("published_ports")),
|
||||
metadata=_string_dict(payload.get("metadata")),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkspacePublishedPortRecord:
|
||||
"""Persisted localhost published-port metadata for one service."""
|
||||
|
||||
guest_port: int
|
||||
host_port: int
|
||||
host: str = DEFAULT_PUBLISHED_PORT_HOST
|
||||
protocol: str = "tcp"
|
||||
proxy_pid: int | None = None
|
||||
|
||||
def to_payload(self) -> dict[str, Any]:
|
||||
return {
|
||||
"guest_port": self.guest_port,
|
||||
"host_port": self.host_port,
|
||||
"host": self.host,
|
||||
"protocol": self.protocol,
|
||||
"proxy_pid": self.proxy_pid,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_payload(cls, payload: dict[str, Any]) -> WorkspacePublishedPortRecord:
|
||||
return cls(
|
||||
guest_port=int(payload["guest_port"]),
|
||||
host_port=int(payload["host_port"]),
|
||||
host=str(payload.get("host", DEFAULT_PUBLISHED_PORT_HOST)),
|
||||
protocol=str(payload.get("protocol", "tcp")),
|
||||
proxy_pid=(
|
||||
None
|
||||
if payload.get("proxy_pid") is None
|
||||
else int(payload.get("proxy_pid", 0))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkspacePublishedPortSpec:
|
||||
"""Requested published-port configuration for one service."""
|
||||
|
||||
guest_port: int
|
||||
host_port: int | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PreparedWorkspaceSeed:
|
||||
"""Prepared host-side seed archive plus metadata."""
|
||||
|
|
@ -534,6 +586,49 @@ def _workspace_seed_dict(value: object) -> dict[str, Any]:
|
|||
return payload
|
||||
|
||||
|
||||
def _normalize_workspace_network_policy(policy: str) -> WorkspaceNetworkPolicy:
|
||||
normalized = policy.strip().lower()
|
||||
if normalized not in {"off", "egress", "egress+published-ports"}:
|
||||
raise ValueError("network_policy must be one of: off, egress, egress+published-ports")
|
||||
return cast(WorkspaceNetworkPolicy, normalized)
|
||||
|
||||
|
||||
def _workspace_network_policy_from_payload(payload: dict[str, Any]) -> WorkspaceNetworkPolicy:
|
||||
raw_policy = payload.get("network_policy")
|
||||
if raw_policy is not None:
|
||||
return _normalize_workspace_network_policy(str(raw_policy))
|
||||
raw_network_requested = payload.get("network_requested", False)
|
||||
if isinstance(raw_network_requested, str):
|
||||
network_requested = raw_network_requested.strip().lower() in {"1", "true", "yes", "on"}
|
||||
else:
|
||||
network_requested = bool(raw_network_requested)
|
||||
if network_requested:
|
||||
return "egress"
|
||||
return DEFAULT_WORKSPACE_NETWORK_POLICY
|
||||
|
||||
|
||||
def _serialize_workspace_published_port_public(
|
||||
published_port: WorkspacePublishedPortRecord,
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"host": published_port.host,
|
||||
"host_port": published_port.host_port,
|
||||
"guest_port": published_port.guest_port,
|
||||
"protocol": published_port.protocol,
|
||||
}
|
||||
|
||||
|
||||
def _workspace_published_port_records(value: object) -> list[WorkspacePublishedPortRecord]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
records: list[WorkspacePublishedPortRecord] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
records.append(WorkspacePublishedPortRecord.from_payload(item))
|
||||
return records
|
||||
|
||||
|
||||
def _workspace_secret_records(value: object) -> list[WorkspaceSecretRecord]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
|
@ -1159,6 +1254,59 @@ def _normalize_workspace_secret_env_mapping(
|
|||
return normalized
|
||||
|
||||
|
||||
def _normalize_workspace_published_port(
|
||||
*,
|
||||
guest_port: object,
|
||||
host_port: object | None = None,
|
||||
) -> WorkspacePublishedPortSpec:
|
||||
if isinstance(guest_port, bool) or not isinstance(guest_port, int | str):
|
||||
raise ValueError("published guest_port must be an integer")
|
||||
try:
|
||||
normalized_guest_port = int(guest_port)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError("published guest_port must be an integer") from exc
|
||||
if normalized_guest_port <= 0 or normalized_guest_port > 65535:
|
||||
raise ValueError("published guest_port must be between 1 and 65535")
|
||||
normalized_host_port: int | None = None
|
||||
if host_port is not None:
|
||||
if isinstance(host_port, bool) or not isinstance(host_port, int | str):
|
||||
raise ValueError("published host_port must be an integer")
|
||||
try:
|
||||
normalized_host_port = int(host_port)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError("published host_port must be an integer") from exc
|
||||
if normalized_host_port <= 1024 or normalized_host_port > 65535:
|
||||
raise ValueError("published host_port must be between 1025 and 65535")
|
||||
return WorkspacePublishedPortSpec(
|
||||
guest_port=normalized_guest_port,
|
||||
host_port=normalized_host_port,
|
||||
)
|
||||
|
||||
|
||||
def _normalize_workspace_published_port_specs(
|
||||
published_ports: list[dict[str, Any]] | None,
|
||||
) -> list[WorkspacePublishedPortSpec]:
|
||||
if not published_ports:
|
||||
return []
|
||||
normalized: list[WorkspacePublishedPortSpec] = []
|
||||
seen_guest_ports: set[tuple[int | None, int]] = set()
|
||||
for index, item in enumerate(published_ports, start=1):
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"published port #{index} must be a dictionary")
|
||||
spec = _normalize_workspace_published_port(
|
||||
guest_port=item.get("guest_port"),
|
||||
host_port=item.get("host_port"),
|
||||
)
|
||||
dedupe_key = (spec.host_port, spec.guest_port)
|
||||
if dedupe_key in seen_guest_ports:
|
||||
raise ValueError(
|
||||
"published ports must not repeat the same host/guest port mapping"
|
||||
)
|
||||
seen_guest_ports.add(dedupe_key)
|
||||
normalized.append(spec)
|
||||
return normalized
|
||||
|
||||
|
||||
def _normalize_workspace_service_readiness(
|
||||
readiness: dict[str, Any] | None,
|
||||
) -> dict[str, Any] | None:
|
||||
|
|
@ -1215,6 +1363,15 @@ def _workspace_service_runner_path(services_dir: Path, service_name: str) -> Pat
|
|||
return services_dir / f"{service_name}.runner.sh"
|
||||
|
||||
|
||||
def _workspace_service_port_ready_path(
|
||||
services_dir: Path,
|
||||
service_name: str,
|
||||
host_port: int,
|
||||
guest_port: int,
|
||||
) -> Path:
|
||||
return services_dir / f"{service_name}.port-{host_port}-to-{guest_port}.ready.json"
|
||||
|
||||
|
||||
def _read_service_exit_code(status_path: Path) -> int | None:
|
||||
if not status_path.exists():
|
||||
return None
|
||||
|
|
@ -1348,6 +1505,7 @@ def _refresh_local_service_record(
|
|||
pid=service.pid,
|
||||
execution_mode=service.execution_mode,
|
||||
stop_reason=service.stop_reason,
|
||||
published_ports=list(service.published_ports),
|
||||
metadata=dict(service.metadata),
|
||||
)
|
||||
return refreshed
|
||||
|
|
@ -1466,6 +1624,95 @@ def _stop_local_service(
|
|||
return refreshed
|
||||
|
||||
|
||||
def _start_workspace_published_port_proxy(
|
||||
*,
|
||||
services_dir: Path,
|
||||
service_name: str,
|
||||
workspace_id: str,
|
||||
guest_ip: str,
|
||||
spec: WorkspacePublishedPortSpec,
|
||||
) -> WorkspacePublishedPortRecord:
|
||||
ready_path = _workspace_service_port_ready_path(
|
||||
services_dir,
|
||||
service_name,
|
||||
spec.host_port or 0,
|
||||
spec.guest_port,
|
||||
)
|
||||
ready_path.unlink(missing_ok=True)
|
||||
command = [
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pyro_mcp.workspace_ports",
|
||||
"--listen-host",
|
||||
DEFAULT_PUBLISHED_PORT_HOST,
|
||||
"--listen-port",
|
||||
str(spec.host_port or 0),
|
||||
"--target-host",
|
||||
guest_ip,
|
||||
"--target-port",
|
||||
str(spec.guest_port),
|
||||
"--ready-file",
|
||||
str(ready_path),
|
||||
]
|
||||
process = subprocess.Popen( # noqa: S603
|
||||
command,
|
||||
text=True,
|
||||
stdout=subprocess.DEVNULL,
|
||||
stderr=subprocess.DEVNULL,
|
||||
start_new_session=True,
|
||||
)
|
||||
deadline = time.monotonic() + 5
|
||||
while time.monotonic() < deadline:
|
||||
if ready_path.exists():
|
||||
payload = json.loads(ready_path.read_text(encoding="utf-8"))
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("published port proxy ready payload is invalid")
|
||||
ready_path.unlink(missing_ok=True)
|
||||
return WorkspacePublishedPortRecord(
|
||||
guest_port=int(payload.get("target_port", spec.guest_port)),
|
||||
host_port=int(payload["host_port"]),
|
||||
host=str(payload.get("host", DEFAULT_PUBLISHED_PORT_HOST)),
|
||||
protocol=str(payload.get("protocol", "tcp")),
|
||||
proxy_pid=process.pid,
|
||||
)
|
||||
if process.poll() is not None:
|
||||
raise RuntimeError(
|
||||
"failed to start published port proxy for "
|
||||
f"service {service_name!r} in workspace {workspace_id!r}"
|
||||
)
|
||||
time.sleep(0.05)
|
||||
_stop_workspace_published_port_proxy(
|
||||
WorkspacePublishedPortRecord(
|
||||
guest_port=spec.guest_port,
|
||||
host_port=spec.host_port or 0,
|
||||
proxy_pid=process.pid,
|
||||
)
|
||||
)
|
||||
ready_path.unlink(missing_ok=True)
|
||||
raise RuntimeError(
|
||||
"timed out waiting for published port proxy readiness for "
|
||||
f"service {service_name!r} in workspace {workspace_id!r}"
|
||||
)
|
||||
|
||||
|
||||
def _stop_workspace_published_port_proxy(published_port: WorkspacePublishedPortRecord) -> None:
|
||||
if published_port.proxy_pid is None:
|
||||
return
|
||||
try:
|
||||
os.killpg(published_port.proxy_pid, signal.SIGTERM)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
deadline = time.monotonic() + 5
|
||||
while time.monotonic() < deadline:
|
||||
if not _pid_is_running(published_port.proxy_pid):
|
||||
return
|
||||
time.sleep(0.05)
|
||||
try:
|
||||
os.killpg(published_port.proxy_pid, signal.SIGKILL)
|
||||
except ProcessLookupError:
|
||||
return
|
||||
|
||||
|
||||
def _instance_workspace_host_dir(instance: VmInstance) -> Path:
|
||||
raw_value = instance.metadata.get("workspace_host_dir")
|
||||
if raw_value is None or raw_value == "":
|
||||
|
|
@ -3057,13 +3304,14 @@ class VmManager:
|
|||
vcpu_count: int = DEFAULT_VCPU_COUNT,
|
||||
mem_mib: int = DEFAULT_MEM_MIB,
|
||||
ttl_seconds: int = DEFAULT_TTL_SECONDS,
|
||||
network: bool = False,
|
||||
network_policy: WorkspaceNetworkPolicy | str = DEFAULT_WORKSPACE_NETWORK_POLICY,
|
||||
allow_host_compat: bool = DEFAULT_ALLOW_HOST_COMPAT,
|
||||
seed_path: str | Path | None = None,
|
||||
secrets: list[dict[str, str]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
self._validate_limits(vcpu_count=vcpu_count, mem_mib=mem_mib, ttl_seconds=ttl_seconds)
|
||||
get_environment(environment, runtime_paths=self._runtime_paths)
|
||||
normalized_network_policy = _normalize_workspace_network_policy(str(network_policy))
|
||||
prepared_seed = self._prepare_workspace_seed(seed_path)
|
||||
now = time.time()
|
||||
workspace_id = uuid.uuid4().hex[:12]
|
||||
|
|
@ -3097,12 +3345,13 @@ class VmManager:
|
|||
created_at=now,
|
||||
expires_at=now + ttl_seconds,
|
||||
workdir=runtime_dir,
|
||||
network_requested=network,
|
||||
network_requested=normalized_network_policy != "off",
|
||||
allow_host_compat=allow_host_compat,
|
||||
)
|
||||
instance.metadata["allow_host_compat"] = str(allow_host_compat).lower()
|
||||
instance.metadata["workspace_path"] = WORKSPACE_GUEST_PATH
|
||||
instance.metadata["workspace_host_dir"] = str(host_workspace_dir)
|
||||
instance.metadata["network_policy"] = normalized_network_policy
|
||||
try:
|
||||
with self._lock:
|
||||
self._reap_expired_locked(now)
|
||||
|
|
@ -3112,6 +3361,9 @@ class VmManager:
|
|||
raise RuntimeError(
|
||||
f"max active VMs reached ({self._max_active_vms}); delete old VMs first"
|
||||
)
|
||||
self._require_workspace_network_policy_support(
|
||||
network_policy=normalized_network_policy
|
||||
)
|
||||
self._backend.create(instance)
|
||||
if self._runtime_capabilities.supports_guest_exec:
|
||||
self._ensure_workspace_guest_bootstrap_support(instance)
|
||||
|
|
@ -3119,6 +3371,7 @@ class VmManager:
|
|||
self._start_instance_locked(instance)
|
||||
workspace = WorkspaceRecord.from_instance(
|
||||
instance,
|
||||
network_policy=normalized_network_policy,
|
||||
workspace_seed=prepared_seed.to_payload(),
|
||||
secrets=secret_records,
|
||||
)
|
||||
|
|
@ -3435,6 +3688,9 @@ class VmManager:
|
|||
recreated = workspace.to_instance(
|
||||
workdir=self._workspace_runtime_dir(workspace.workspace_id)
|
||||
)
|
||||
self._require_workspace_network_policy_support(
|
||||
network_policy=workspace.network_policy
|
||||
)
|
||||
self._backend.create(recreated)
|
||||
if self._runtime_capabilities.supports_guest_exec:
|
||||
self._ensure_workspace_guest_bootstrap_support(recreated)
|
||||
|
|
@ -3798,11 +4054,13 @@ class VmManager:
|
|||
ready_timeout_seconds: int = DEFAULT_SERVICE_READY_TIMEOUT_SECONDS,
|
||||
ready_interval_ms: int = DEFAULT_SERVICE_READY_INTERVAL_MS,
|
||||
secret_env: dict[str, str] | None = None,
|
||||
published_ports: list[dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
normalized_service_name = _normalize_workspace_service_name(service_name)
|
||||
normalized_cwd, _ = _normalize_workspace_destination(cwd)
|
||||
normalized_readiness = _normalize_workspace_service_readiness(readiness)
|
||||
normalized_secret_env = _normalize_workspace_secret_env_mapping(secret_env)
|
||||
normalized_published_ports = _normalize_workspace_published_port_specs(published_ports)
|
||||
if ready_timeout_seconds <= 0:
|
||||
raise ValueError("ready_timeout_seconds must be positive")
|
||||
if ready_interval_ms <= 0:
|
||||
|
|
@ -3810,6 +4068,16 @@ class VmManager:
|
|||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
instance = self._workspace_instance_for_live_service_locked(workspace)
|
||||
if normalized_published_ports:
|
||||
if workspace.network_policy != "egress+published-ports":
|
||||
raise RuntimeError(
|
||||
"published ports require workspace network_policy "
|
||||
"'egress+published-ports'"
|
||||
)
|
||||
if instance.network is None:
|
||||
raise RuntimeError(
|
||||
"published ports require an active guest network configuration"
|
||||
)
|
||||
redact_values = self._workspace_secret_redact_values_locked(workspace)
|
||||
env_values = self._workspace_secret_env_values_locked(workspace, normalized_secret_env)
|
||||
if workspace.secrets and normalized_secret_env:
|
||||
|
|
@ -3852,6 +4120,36 @@ class VmManager:
|
|||
service_name=normalized_service_name,
|
||||
payload=payload,
|
||||
)
|
||||
if normalized_published_ports:
|
||||
assert instance.network is not None # guarded above
|
||||
try:
|
||||
service.published_ports = self._start_workspace_service_published_ports(
|
||||
workspace=workspace,
|
||||
service=service,
|
||||
guest_ip=instance.network.guest_ip,
|
||||
published_ports=normalized_published_ports,
|
||||
)
|
||||
except Exception:
|
||||
try:
|
||||
failed_payload = self._backend.stop_service(
|
||||
instance,
|
||||
workspace_id=workspace_id,
|
||||
service_name=normalized_service_name,
|
||||
)
|
||||
service = self._workspace_service_record_from_payload(
|
||||
workspace_id=workspace_id,
|
||||
service_name=normalized_service_name,
|
||||
payload=failed_payload,
|
||||
published_ports=[],
|
||||
)
|
||||
except Exception:
|
||||
service.state = "failed"
|
||||
service.stop_reason = "published_port_failed"
|
||||
service.ended_at = service.ended_at or time.time()
|
||||
else:
|
||||
service.state = "failed"
|
||||
service.stop_reason = "published_port_failed"
|
||||
service.ended_at = service.ended_at or time.time()
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
|
|
@ -3916,7 +4214,21 @@ class VmManager:
|
|||
service_name=normalized_service_name,
|
||||
payload=payload,
|
||||
metadata=service.metadata,
|
||||
published_ports=service.published_ports,
|
||||
)
|
||||
if service.published_ports:
|
||||
for published_port in service.published_ports:
|
||||
_stop_workspace_published_port_proxy(published_port)
|
||||
service.published_ports = [
|
||||
WorkspacePublishedPortRecord(
|
||||
guest_port=published_port.guest_port,
|
||||
host_port=published_port.host_port,
|
||||
host=published_port.host,
|
||||
protocol=published_port.protocol,
|
||||
proxy_pid=None,
|
||||
)
|
||||
for published_port in service.published_ports
|
||||
]
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
workspace.state = instance.state
|
||||
|
|
@ -3956,6 +4268,7 @@ class VmManager:
|
|||
service_name=normalized_service_name,
|
||||
payload=payload,
|
||||
metadata=service.metadata,
|
||||
published_ports=service.published_ports,
|
||||
)
|
||||
with self._lock:
|
||||
workspace = self._load_workspace_locked(workspace_id)
|
||||
|
|
@ -4059,6 +4372,7 @@ class VmManager:
|
|||
"created_at": workspace.created_at,
|
||||
"expires_at": workspace.expires_at,
|
||||
"state": workspace.state,
|
||||
"network_policy": workspace.network_policy,
|
||||
"network_enabled": workspace.network is not None,
|
||||
"allow_host_compat": workspace.allow_host_compat,
|
||||
"guest_ip": workspace.network.guest_ip if workspace.network is not None else None,
|
||||
|
|
@ -4107,6 +4421,10 @@ class VmManager:
|
|||
"readiness": dict(service.readiness) if service.readiness is not None else None,
|
||||
"ready_at": service.ready_at,
|
||||
"stop_reason": service.stop_reason,
|
||||
"published_ports": [
|
||||
_serialize_workspace_published_port_public(published_port)
|
||||
for published_port in service.published_ports
|
||||
],
|
||||
}
|
||||
|
||||
def _serialize_workspace_snapshot(self, snapshot: WorkspaceSnapshotRecord) -> dict[str, Any]:
|
||||
|
|
@ -4182,6 +4500,23 @@ class VmManager:
|
|||
f"workspace: {reason}"
|
||||
)
|
||||
|
||||
def _require_workspace_network_policy_support(
|
||||
self,
|
||||
*,
|
||||
network_policy: WorkspaceNetworkPolicy,
|
||||
) -> None:
|
||||
if network_policy == "off":
|
||||
return
|
||||
if self._runtime_capabilities.supports_guest_network:
|
||||
return
|
||||
reason = self._runtime_capabilities.reason or (
|
||||
"runtime does not support guest-backed workspace networking"
|
||||
)
|
||||
raise RuntimeError(
|
||||
"workspace network_policy requires guest networking and is unavailable for this "
|
||||
f"workspace: {reason}"
|
||||
)
|
||||
|
||||
def _workspace_secret_values_locked(self, workspace: WorkspaceRecord) -> dict[str, str]:
|
||||
return _load_workspace_secret_values(
|
||||
workspace_dir=self._workspace_dir(workspace.workspace_id),
|
||||
|
|
@ -4609,9 +4944,15 @@ class VmManager:
|
|||
service_name: str,
|
||||
payload: dict[str, Any],
|
||||
metadata: dict[str, str] | None = None,
|
||||
published_ports: list[WorkspacePublishedPortRecord] | None = None,
|
||||
) -> WorkspaceServiceRecord:
|
||||
readiness_payload = payload.get("readiness")
|
||||
readiness = dict(readiness_payload) if isinstance(readiness_payload, dict) else None
|
||||
normalized_published_ports = _workspace_published_port_records(
|
||||
payload.get("published_ports")
|
||||
)
|
||||
if not normalized_published_ports and published_ports is not None:
|
||||
normalized_published_ports = list(published_ports)
|
||||
return WorkspaceServiceRecord(
|
||||
workspace_id=workspace_id,
|
||||
service_name=str(payload.get("service_name", service_name)),
|
||||
|
|
@ -4632,6 +4973,7 @@ class VmManager:
|
|||
pid=None if payload.get("pid") is None else int(payload.get("pid", 0)),
|
||||
execution_mode=str(payload.get("execution_mode", "pending")),
|
||||
stop_reason=_optional_str(payload.get("stop_reason")),
|
||||
published_ports=normalized_published_ports,
|
||||
metadata=dict(metadata or {}),
|
||||
)
|
||||
|
||||
|
|
@ -4652,6 +4994,33 @@ class VmManager:
|
|||
services = self._list_workspace_services_locked(workspace_id)
|
||||
return len(services), sum(1 for service in services if service.state == "running")
|
||||
|
||||
def _start_workspace_service_published_ports(
|
||||
self,
|
||||
*,
|
||||
workspace: WorkspaceRecord,
|
||||
service: WorkspaceServiceRecord,
|
||||
guest_ip: str,
|
||||
published_ports: list[WorkspacePublishedPortSpec],
|
||||
) -> list[WorkspacePublishedPortRecord]:
|
||||
services_dir = self._workspace_services_dir(workspace.workspace_id)
|
||||
started: list[WorkspacePublishedPortRecord] = []
|
||||
try:
|
||||
for spec in published_ports:
|
||||
started.append(
|
||||
_start_workspace_published_port_proxy(
|
||||
services_dir=services_dir,
|
||||
service_name=service.service_name,
|
||||
workspace_id=workspace.workspace_id,
|
||||
guest_ip=guest_ip,
|
||||
spec=spec,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
for published_port in started:
|
||||
_stop_workspace_published_port_proxy(published_port)
|
||||
raise
|
||||
return started
|
||||
|
||||
def _workspace_baseline_snapshot_locked(
|
||||
self,
|
||||
workspace: WorkspaceRecord,
|
||||
|
|
@ -4770,12 +5139,18 @@ class VmManager:
|
|||
workspace_id: str,
|
||||
service_name: str,
|
||||
) -> None:
|
||||
existing = self._load_workspace_service_locked_optional(workspace_id, service_name)
|
||||
if existing is not None:
|
||||
for published_port in existing.published_ports:
|
||||
_stop_workspace_published_port_proxy(published_port)
|
||||
self._workspace_service_record_path(workspace_id, service_name).unlink(missing_ok=True)
|
||||
services_dir = self._workspace_services_dir(workspace_id)
|
||||
_workspace_service_stdout_path(services_dir, service_name).unlink(missing_ok=True)
|
||||
_workspace_service_stderr_path(services_dir, service_name).unlink(missing_ok=True)
|
||||
_workspace_service_status_path(services_dir, service_name).unlink(missing_ok=True)
|
||||
_workspace_service_runner_path(services_dir, service_name).unlink(missing_ok=True)
|
||||
for ready_path in services_dir.glob(f"{service_name}.port-*.ready.json"):
|
||||
ready_path.unlink(missing_ok=True)
|
||||
|
||||
def _delete_workspace_snapshot_locked(self, workspace_id: str, snapshot_name: str) -> None:
|
||||
self._workspace_snapshot_metadata_path(workspace_id, snapshot_name).unlink(missing_ok=True)
|
||||
|
|
@ -4881,7 +5256,21 @@ class VmManager:
|
|||
service_name=service.service_name,
|
||||
payload=payload,
|
||||
metadata=service.metadata,
|
||||
published_ports=service.published_ports,
|
||||
)
|
||||
if refreshed.state != "running" and refreshed.published_ports:
|
||||
refreshed.published_ports = [
|
||||
WorkspacePublishedPortRecord(
|
||||
guest_port=published_port.guest_port,
|
||||
host_port=published_port.host_port,
|
||||
host=published_port.host,
|
||||
protocol=published_port.protocol,
|
||||
proxy_pid=None,
|
||||
)
|
||||
for published_port in refreshed.published_ports
|
||||
]
|
||||
for published_port in service.published_ports:
|
||||
_stop_workspace_published_port_proxy(published_port)
|
||||
self._save_workspace_service_locked(refreshed)
|
||||
return refreshed
|
||||
|
||||
|
|
@ -4904,6 +5293,8 @@ class VmManager:
|
|||
changed = False
|
||||
for service in services:
|
||||
if service.state == "running":
|
||||
for published_port in service.published_ports:
|
||||
_stop_workspace_published_port_proxy(published_port)
|
||||
service.state = "stopped"
|
||||
service.stop_reason = "workspace_stopped"
|
||||
service.ended_at = service.ended_at or time.time()
|
||||
|
|
@ -4936,6 +5327,7 @@ class VmManager:
|
|||
service_name=service.service_name,
|
||||
payload=payload,
|
||||
metadata=service.metadata,
|
||||
published_ports=service.published_ports,
|
||||
)
|
||||
self._save_workspace_service_locked(stopped)
|
||||
except Exception:
|
||||
|
|
|
|||
116
src/pyro_mcp/workspace_ports.py
Normal file
116
src/pyro_mcp/workspace_ports.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""Localhost-only TCP port proxy for published workspace services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import selectors
|
||||
import signal
|
||||
import socket
|
||||
import socketserver
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
DEFAULT_PUBLISHED_PORT_HOST = "127.0.0.1"
|
||||
|
||||
|
||||
class _ProxyServer(socketserver.ThreadingMixIn, socketserver.TCPServer):
|
||||
allow_reuse_address = False
|
||||
daemon_threads = True
|
||||
|
||||
def __init__(self, server_address: tuple[str, int], target_address: tuple[str, int]) -> None:
|
||||
super().__init__(server_address, _ProxyHandler)
|
||||
self.target_address = target_address
|
||||
|
||||
|
||||
class _ProxyHandler(socketserver.BaseRequestHandler):
|
||||
def handle(self) -> None:
|
||||
server = self.server
|
||||
if not isinstance(server, _ProxyServer):
|
||||
raise RuntimeError("proxy server is invalid")
|
||||
try:
|
||||
upstream = socket.create_connection(server.target_address, timeout=5)
|
||||
except OSError:
|
||||
return
|
||||
with upstream:
|
||||
self.request.setblocking(False)
|
||||
upstream.setblocking(False)
|
||||
selector = selectors.DefaultSelector()
|
||||
try:
|
||||
selector.register(self.request, selectors.EVENT_READ, upstream)
|
||||
selector.register(upstream, selectors.EVENT_READ, self.request)
|
||||
while True:
|
||||
events = selector.select()
|
||||
if not events:
|
||||
continue
|
||||
for key, _ in events:
|
||||
source = key.fileobj
|
||||
target = key.data
|
||||
if not isinstance(source, socket.socket) or not isinstance(
|
||||
target, socket.socket
|
||||
):
|
||||
continue
|
||||
try:
|
||||
chunk = source.recv(65536)
|
||||
except OSError:
|
||||
return
|
||||
if not chunk:
|
||||
return
|
||||
try:
|
||||
target.sendall(chunk)
|
||||
except OSError:
|
||||
return
|
||||
finally:
|
||||
selector.close()
|
||||
|
||||
|
||||
def _build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(description="Run a localhost-only TCP port proxy.")
|
||||
parser.add_argument("--listen-host", required=True)
|
||||
parser.add_argument("--listen-port", type=int, required=True)
|
||||
parser.add_argument("--target-host", required=True)
|
||||
parser.add_argument("--target-port", type=int, required=True)
|
||||
parser.add_argument("--ready-file", required=True)
|
||||
return parser
|
||||
|
||||
|
||||
def main(argv: list[str] | None = None) -> int:
|
||||
args = _build_parser().parse_args(argv)
|
||||
ready_file = Path(args.ready_file)
|
||||
ready_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
server = _ProxyServer(
|
||||
(str(args.listen_host), int(args.listen_port)),
|
||||
(str(args.target_host), int(args.target_port)),
|
||||
)
|
||||
actual_host = str(server.server_address[0])
|
||||
actual_port = int(server.server_address[1])
|
||||
ready_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"host": actual_host,
|
||||
"host_port": actual_port,
|
||||
"target_host": args.target_host,
|
||||
"target_port": int(args.target_port),
|
||||
"protocol": "tcp",
|
||||
},
|
||||
indent=2,
|
||||
sort_keys=True,
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
def _shutdown(_: int, __: object) -> None:
|
||||
threading.Thread(target=server.shutdown, daemon=True).start()
|
||||
|
||||
signal.signal(signal.SIGTERM, _shutdown)
|
||||
signal.signal(signal.SIGINT, _shutdown)
|
||||
try:
|
||||
server.serve_forever(poll_interval=0.2)
|
||||
finally:
|
||||
server.server_close()
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(sys.argv[1:]))
|
||||
Loading…
Add table
Add a link
Reference in a new issue