Enable real guest networking and make demos network-first

This commit is contained in:
Thales Maciel 2026-03-06 22:47:16 -03:00
parent c43c718c83
commit b01efa6452
14 changed files with 618 additions and 72 deletions

View file

@ -9,6 +9,7 @@ import shutil
import subprocess
import tarfile
import urllib.request
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Any
@ -307,6 +308,7 @@ def materialize_rootfs(
packages_path = paths.source_platform_root / raw_packages_path
output_path = paths.materialized_platform_root / profile["rootfs"]
output_path.parent.mkdir(parents=True, exist_ok=True)
profile_workdir = workdir / f"{profile_name}-{uuid.uuid4().hex[:8]}"
_run(
[
str(script_path),
@ -325,7 +327,7 @@ def materialize_rootfs(
"--agent-service",
str(service_path),
"--workdir",
str(workdir / profile_name),
str(profile_workdir),
"--output",
str(output_path),
]

View file

@ -0,0 +1,96 @@
"""Direct guest-network validation for a bundled runtime profile."""
from __future__ import annotations
import argparse
from dataclasses import dataclass
from pathlib import Path
from pyro_mcp.vm_manager import VmManager
from pyro_mcp.vm_network import TapNetworkManager
NETWORK_CHECK_COMMAND = (
"rm -rf hello-world "
"&& git clone --depth 1 https://github.com/octocat/Hello-World.git hello-world >/dev/null "
"&& git -C hello-world rev-parse --is-inside-work-tree"
)
@dataclass(frozen=True)
class NetworkCheckResult:
vm_id: str
execution_mode: str
network_enabled: bool
exit_code: int
stdout: str
stderr: str
cleanup: dict[str, object]
def run_network_check(
*,
profile: str = "debian-git",
vcpu_count: int = 1,
mem_mib: int = 1024,
ttl_seconds: int = 600,
timeout_seconds: int = 120,
base_dir: Path | None = None,
) -> NetworkCheckResult: # pragma: no cover - integration helper
manager = VmManager(
base_dir=base_dir,
network_manager=TapNetworkManager(enabled=True),
)
created = manager.create_vm(
profile=profile,
vcpu_count=vcpu_count,
mem_mib=mem_mib,
ttl_seconds=ttl_seconds,
)
vm_id = str(created["vm_id"])
manager.start_vm(vm_id)
status = manager.status_vm(vm_id)
executed = manager.exec_vm(
vm_id,
command=NETWORK_CHECK_COMMAND,
timeout_seconds=timeout_seconds,
)
return NetworkCheckResult(
vm_id=vm_id,
execution_mode=str(executed["execution_mode"]),
network_enabled=bool(status["network_enabled"]),
exit_code=int(executed["exit_code"]),
stdout=str(executed["stdout"]),
stderr=str(executed["stderr"]),
cleanup=dict(executed["cleanup"]),
)
def main() -> None: # pragma: no cover - CLI wiring
parser = argparse.ArgumentParser(description="Run a guest networking check.")
parser.add_argument("--profile", default="debian-git")
parser.add_argument("--vcpu-count", type=int, default=1)
parser.add_argument("--mem-mib", type=int, default=1024)
parser.add_argument("--ttl-seconds", type=int, default=600)
parser.add_argument("--timeout-seconds", type=int, default=120)
args = parser.parse_args()
result = run_network_check(
profile=args.profile,
vcpu_count=args.vcpu_count,
mem_mib=args.mem_mib,
ttl_seconds=args.ttl_seconds,
timeout_seconds=args.timeout_seconds,
)
print(f"[network] vm_id={result.vm_id}")
print(f"[network] execution_mode={result.execution_mode}")
print(f"[network] network_enabled={result.network_enabled}")
print(f"[network] exit_code={result.exit_code}")
if result.exit_code == 0 and result.stdout.strip() == "true":
print("[network] result=success")
return
print("[network] result=failure")
if result.stdout.strip():
print(f"[network] stdout={result.stdout.strip()}")
if result.stderr.strip():
print(f"[network] stderr={result.stderr.strip()}")
raise SystemExit(1)

View file

@ -5,13 +5,13 @@ from __future__ import annotations
import json
import socket
from dataclasses import dataclass
from typing import Callable, Protocol
from typing import Any, Callable, Protocol
class SocketLike(Protocol):
def settimeout(self, timeout: int) -> None: ...
def connect(self, address: tuple[int, int]) -> None: ...
def connect(self, address: Any) -> None: ...
def sendall(self, data: bytes) -> None: ...
@ -38,19 +38,35 @@ class VsockExecClient:
self._socket_factory = socket_factory or socket.socket
def exec(
self, guest_cid: int, port: int, command: str, timeout_seconds: int
self,
guest_cid: int,
port: int,
command: str,
timeout_seconds: int,
*,
uds_path: str | None = None,
) -> GuestExecResponse:
request = {
"command": command,
"timeout_seconds": timeout_seconds,
}
family = getattr(socket, "AF_VSOCK", None)
if family is None:
if family is not None:
sock = self._socket_factory(family, socket.SOCK_STREAM)
connect_address: Any = (guest_cid, port)
elif uds_path is not None:
sock = self._socket_factory(socket.AF_UNIX, socket.SOCK_STREAM)
connect_address = uds_path
else:
raise RuntimeError("vsock sockets are not supported on this host Python runtime")
sock = self._socket_factory(family, socket.SOCK_STREAM)
try:
sock.settimeout(timeout_seconds)
sock.connect((guest_cid, port))
sock.connect(connect_address)
if family is None:
sock.sendall(f"CONNECT {port}\n".encode("utf-8"))
status = self._recv_line(sock)
if not status.startswith("OK "):
raise RuntimeError(f"vsock unix bridge rejected port {port}: {status.strip()}")
sock.sendall((json.dumps(request) + "\n").encode("utf-8"))
chunks: list[bytes] = []
while True:
@ -70,3 +86,15 @@ class VsockExecClient:
exit_code=int(payload.get("exit_code", -1)),
duration_ms=int(payload.get("duration_ms", 0)),
)
@staticmethod
def _recv_line(sock: SocketLike) -> str:
chunks: list[bytes] = []
while True:
data = sock.recv(1)
if data == b"":
break
chunks.append(data)
if data == b"\n":
break
return b"".join(chunks).decode("utf-8", errors="replace")

View file

@ -4,6 +4,7 @@ from __future__ import annotations
import os
import shutil
import signal
import subprocess
import threading
import time
@ -143,6 +144,7 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
self._runtime_capabilities = runtime_capabilities
self._network_manager = network_manager or TapNetworkManager()
self._guest_exec_client = guest_exec_client or VsockExecClient()
self._processes: dict[str, subprocess.Popen[str]] = {}
if not self._firecracker_bin.exists():
raise RuntimeError(f"bundled firecracker binary not found at {self._firecracker_bin}")
if not self._jailer_bin.exists():
@ -160,10 +162,16 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
f"{artifacts.kernel_image} and {artifacts.rootfs_image}"
)
instance.metadata["kernel_image"] = str(artifacts.kernel_image)
instance.metadata["rootfs_image"] = str(artifacts.rootfs_image)
network = self._network_manager.allocate(instance.vm_id)
instance.network = network
instance.metadata.update(self._network_manager.to_metadata(network))
rootfs_copy = instance.workdir / "rootfs.ext4"
shutil.copy2(artifacts.rootfs_image, rootfs_copy)
instance.metadata["rootfs_image"] = str(rootfs_copy)
if self._network_manager.enabled:
network = self._network_manager.allocate(instance.vm_id)
instance.network = network
instance.metadata.update(self._network_manager.to_metadata(network))
else:
instance.network = None
instance.metadata["network_enabled"] = "false"
except Exception:
shutil.rmtree(instance.workdir, ignore_errors=True)
raise
@ -175,6 +183,12 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
instance.metadata["guest_exec_path"] = str(launch_plan.guest_exec_path)
instance.metadata["guest_cid"] = str(launch_plan.guest_cid)
instance.metadata["guest_exec_port"] = str(launch_plan.vsock_port)
instance.metadata["guest_exec_uds_path"] = str(instance.workdir / "vsock.sock")
serial_log_path = instance.workdir / "serial.log"
firecracker_log_path = instance.workdir / "firecracker.log"
firecracker_log_path.touch()
instance.metadata["serial_log_path"] = str(serial_log_path)
instance.metadata["firecracker_log_path"] = str(firecracker_log_path)
proc = subprocess.run( # noqa: S603
[str(self._firecracker_bin), "--version"],
text=True,
@ -191,14 +205,60 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
if self._runtime_capabilities.reason is not None:
instance.metadata["runtime_reason"] = self._runtime_capabilities.reason
return
instance.metadata["execution_mode"] = "guest_vsock"
with serial_log_path.open("w", encoding="utf-8") as serial_fp:
process = subprocess.Popen( # noqa: S603
[
str(self._firecracker_bin),
"--no-api",
"--config-file",
str(launch_plan.config_path),
"--log-path",
str(firecracker_log_path),
"--level",
"Info",
],
stdout=serial_fp,
stderr=subprocess.STDOUT,
text=True,
)
self._processes[instance.vm_id] = process
time.sleep(2)
if process.poll() is not None:
serial_log = serial_log_path.read_text(encoding="utf-8", errors="ignore")
firecracker_log = firecracker_log_path.read_text(encoding="utf-8", errors="ignore")
self._processes.pop(instance.vm_id, None)
raise RuntimeError(
"firecracker microVM exited during startup: "
f"{(serial_log or firecracker_log).strip()}"
)
instance.firecracker_pid = process.pid
instance.metadata["execution_mode"] = (
"guest_vsock" if self._runtime_capabilities.supports_guest_exec else "host_compat"
)
instance.metadata["boot_mode"] = "native"
def exec(self, instance: VmInstance, command: str, timeout_seconds: int) -> VmExecResult:
if self._runtime_capabilities.supports_guest_exec:
guest_cid = int(instance.metadata["guest_cid"])
port = int(instance.metadata["guest_exec_port"])
response = self._guest_exec_client.exec(guest_cid, port, command, timeout_seconds)
uds_path = instance.metadata.get("guest_exec_uds_path")
deadline = time.monotonic() + min(timeout_seconds, 10)
while True:
try:
response = self._guest_exec_client.exec(
guest_cid,
port,
command,
timeout_seconds,
uds_path=uds_path,
)
break
except (OSError, RuntimeError) as exc:
if time.monotonic() >= deadline:
raise RuntimeError(
f"guest exec transport did not become ready: {exc}"
) from exc
time.sleep(0.2)
return VmExecResult(
stdout=response.stdout,
stderr=response.stderr,
@ -209,9 +269,36 @@ class FirecrackerBackend(VmBackend): # pragma: no cover
return _run_host_command(instance.workdir, command, timeout_seconds)
def stop(self, instance: VmInstance) -> None:
del instance
process = self._processes.pop(instance.vm_id, None)
if process is not None:
process.terminate()
try:
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.kill()
process.wait(timeout=5)
instance.firecracker_pid = None
return
if instance.firecracker_pid is None:
return
try:
os.kill(instance.firecracker_pid, signal.SIGTERM)
except ProcessLookupError:
instance.firecracker_pid = None
return
deadline = time.monotonic() + 5
while time.monotonic() < deadline:
try:
os.kill(instance.firecracker_pid, 0)
except ProcessLookupError:
instance.firecracker_pid = None
return
time.sleep(0.1)
os.kill(instance.firecracker_pid, signal.SIGKILL)
instance.firecracker_pid = None
def delete(self, instance: VmInstance) -> None:
self._processes.pop(instance.vm_id, None)
if instance.network is not None:
self._network_manager.cleanup(instance.network)
shutil.rmtree(instance.workdir, ignore_errors=True)

View file

@ -34,6 +34,7 @@ class NetworkDiagnostics:
nft_binary: str | None
iptables_binary: str | None
ip_forward_enabled: bool
sudo_non_interactive: bool
class TapNetworkManager:
@ -44,11 +45,13 @@ class TapNetworkManager:
*,
enabled: bool | None = None,
runner: CommandRunner | None = None,
use_sudo: bool | None = None,
) -> None:
if enabled is None:
self._enabled = os.environ.get("PYRO_VM_ENABLE_NETWORK") == "1"
else:
self._enabled = enabled
self._use_sudo = self._detect_sudo() if use_sudo is None else use_sudo
self._runner = runner or self._run
@staticmethod
@ -67,6 +70,7 @@ class TapNetworkManager:
nft_binary=shutil.which("nft"),
iptables_binary=shutil.which("iptables"),
ip_forward_enabled=ip_forward,
sudo_non_interactive=TapNetworkManager._detect_sudo(),
)
@property
@ -90,7 +94,44 @@ class TapNetworkManager:
mac_address=mac_address,
)
if self._enabled:
self._ensure_host_network(config)
try:
self._ensure_host_network(config)
except Exception:
table_name = self._nft_table_name(config.vm_id)
self._run_ignore(["nft", "delete", "table", "ip", table_name])
self._run_ignore(
[
"iptables",
"-t",
"nat",
"-D",
"POSTROUTING",
"-s",
config.subnet_cidr,
"-j",
"MASQUERADE",
]
)
self._run_ignore(
["iptables", "-D", "FORWARD", "-i", config.tap_name, "-j", "ACCEPT"]
)
self._run_ignore(
[
"iptables",
"-D",
"FORWARD",
"-o",
config.tap_name,
"-m",
"conntrack",
"--ctstate",
"RELATED,ESTABLISHED",
"-j",
"ACCEPT",
]
)
self._run_ignore(["ip", "link", "del", config.tap_name])
raise
return config
def cleanup(self, config: NetworkConfig) -> None:
@ -98,6 +139,35 @@ class TapNetworkManager:
return
table_name = self._nft_table_name(config.vm_id)
self._run_ignore(["nft", "delete", "table", "ip", table_name])
self._run_ignore(
[
"iptables",
"-t",
"nat",
"-D",
"POSTROUTING",
"-s",
config.subnet_cidr,
"-j",
"MASQUERADE",
]
)
self._run_ignore(["iptables", "-D", "FORWARD", "-i", config.tap_name, "-j", "ACCEPT"])
self._run_ignore(
[
"iptables",
"-D",
"FORWARD",
"-o",
config.tap_name,
"-m",
"conntrack",
"--ctstate",
"RELATED,ESTABLISHED",
"-j",
"ACCEPT",
]
)
self._run_ignore(["ip", "link", "del", config.tap_name])
def to_metadata(self, config: NetworkConfig) -> dict[str, str]:
@ -129,15 +199,33 @@ class TapNetworkManager:
raise RuntimeError("/dev/net/tun is not available on this host")
if diagnostics.ip_binary is None:
raise RuntimeError("`ip` binary is required for TAP setup")
if diagnostics.nft_binary is None:
raise RuntimeError("`nft` binary is required for outbound NAT setup")
if diagnostics.nft_binary is None and diagnostics.iptables_binary is None:
raise RuntimeError("`nft` or `iptables` is required for outbound NAT setup")
if not diagnostics.ip_forward_enabled:
raise RuntimeError("IPv4 forwarding is disabled on this host")
self._runner(["ip", "tuntap", "add", "dev", config.tap_name, "mode", "tap"])
self._runner(
[
"ip",
"tuntap",
"add",
"dev",
config.tap_name,
"mode",
"tap",
"user",
str(os.getuid()),
]
)
self._runner(["ip", "addr", "add", f"{config.gateway_ip}/24", "dev", config.tap_name])
self._runner(["ip", "link", "set", config.tap_name, "up"])
if diagnostics.nft_binary is not None:
self._ensure_nft_network(config)
return
self._ensure_iptables_network(config)
def _ensure_nft_network(self, config: NetworkConfig) -> None:
table_name = self._nft_table_name(config.vm_id)
self._runner(["nft", "add", "table", "ip", table_name])
self._runner(
@ -159,18 +247,102 @@ class TapNetworkManager:
"}",
]
)
self._runner([
"nft",
"add",
"rule",
"ip",
table_name,
"postrouting",
"ip",
"saddr",
config.subnet_cidr,
"masquerade",
])
self._runner(
[
"nft",
"add",
"chain",
"ip",
table_name,
"forward",
"{",
"type",
"filter",
"hook",
"forward",
"priority",
"filter",
";",
"policy",
"accept",
";",
"}",
]
)
self._runner(
[
"nft",
"add",
"rule",
"ip",
table_name,
"postrouting",
"ip",
"saddr",
config.subnet_cidr,
"masquerade",
]
)
self._runner(
[
"nft",
"add",
"rule",
"ip",
table_name,
"forward",
"iifname",
config.tap_name,
"accept",
]
)
self._runner(
[
"nft",
"add",
"rule",
"ip",
table_name,
"forward",
"oifname",
config.tap_name,
"ct",
"state",
"related,established",
"accept",
]
)
def _ensure_iptables_network(self, config: NetworkConfig) -> None:
self._runner(
[
"iptables",
"-t",
"nat",
"-A",
"POSTROUTING",
"-s",
config.subnet_cidr,
"-j",
"MASQUERADE",
]
)
self._runner(["iptables", "-A", "FORWARD", "-i", config.tap_name, "-j", "ACCEPT"])
self._runner(
[
"iptables",
"-A",
"FORWARD",
"-o",
config.tap_name,
"-m",
"conntrack",
"--ctstate",
"RELATED,ESTABLISHED",
"-j",
"ACCEPT",
]
)
def _run_ignore(self, command: list[str]) -> None:
try:
@ -183,9 +355,30 @@ class TapNetworkManager:
return f"pyro_{vm_id[:12]}"
@staticmethod
def _run(command: list[str]) -> subprocess.CompletedProcess[str]:
completed = subprocess.run(command, text=True, capture_output=True, check=False)
def _detect_sudo() -> bool:
if os.geteuid() == 0:
return False
if shutil.which("sudo") is None:
return False
completed = subprocess.run(
["sudo", "-n", "true"],
text=True,
capture_output=True,
check=False,
)
return completed.returncode == 0
def _run(self, command: list[str]) -> subprocess.CompletedProcess[str]:
effective_command = command
if self._use_sudo:
effective_command = ["sudo", "-n", *command]
completed = subprocess.run(
effective_command,
text=True,
capture_output=True,
check=False,
)
if completed.returncode != 0:
stderr = completed.stderr.strip() or completed.stdout.strip()
raise RuntimeError(f"command {' '.join(command)!r} failed: {stderr}")
raise RuntimeError(f"command {' '.join(effective_command)!r} failed: {stderr}")
return completed