diff --git a/internal/daemon/vm.go b/internal/daemon/vm.go index 251f039..3585db2 100644 --- a/internal/daemon/vm.go +++ b/internal/daemon/vm.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "log/slog" "net" "os" "path/filepath" @@ -25,6 +26,8 @@ import ( var ( errWaitForExitTimeout = errors.New("timed out waiting for VM to exit") gracefulShutdownWait = 10 * time.Second + vsockReadyWait = 30 * time.Second + vsockReadyPoll = 200 * time.Millisecond ) func (d *Daemon) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) { @@ -314,10 +317,13 @@ func (d *Daemon) startVMLocked(ctx context.Context, vm model.VMRecord, image mod return cleanupOnErr(err) } op.stage("vsock_access", "vsock_path", vm.Runtime.VSockPath, "vsock_cid", vm.Runtime.VSockCID) - vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") if err := d.ensureSocketAccess(ctx, vm.Runtime.VSockPath, "firecracker vsock socket"); err != nil { return cleanupOnErr(err) } + vmCreateStage(ctx, "wait_vsock_agent", "waiting for guest vsock agent") + if err := waitForGuestVSockAgent(ctx, d.logger, vm.Runtime.VSockPath, vsockReadyWait); err != nil { + return cleanupOnErr(err) + } op.stage("post_start_features") vmCreateStage(ctx, "wait_guest_ready", "waiting for guest services") if err := d.postStartCapabilities(ctx, vm, image); err != nil { @@ -1150,6 +1156,38 @@ func waitForPath(ctx context.Context, path string, timeout time.Duration, label } } +func waitForGuestVSockAgent(ctx context.Context, logger *slog.Logger, socketPath string, timeout time.Duration) error { + if strings.TrimSpace(socketPath) == "" { + return errors.New("vsock path is required") + } + + waitCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + ticker := time.NewTicker(vsockReadyPoll) + defer ticker.Stop() + + var lastErr error + for { + pingCtx, pingCancel := context.WithTimeout(waitCtx, 3*time.Second) + err := vsockagent.Health(pingCtx, logger, socketPath) + pingCancel() + if err == nil { + return nil + } + lastErr = err + + select { + case <-waitCtx.Done(): + if lastErr != nil { + return fmt.Errorf("guest vsock agent not ready: %w", lastErr) + } + return errors.New("guest vsock agent not ready before timeout") + case <-ticker.C: + } + } +} + func (d *Daemon) setDNS(ctx context.Context, vmName, guestIP string) error { if d.vmDNS == nil { return nil diff --git a/internal/daemon/vm_test.go b/internal/daemon/vm_test.go index 74c1881..0dfb223 100644 --- a/internal/daemon/vm_test.go +++ b/internal/daemon/vm_test.go @@ -415,6 +415,86 @@ func TestPingVMAliasReturnsAliveForHealthyVM(t *testing.T) { } } +func TestWaitForGuestVSockAgentRetriesUntilHealthy(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "fc.vsock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + skipIfSocketRestricted(t, err) + t.Fatalf("listen vsock: %v", err) + } + t.Cleanup(func() { + _ = listener.Close() + _ = os.Remove(socketPath) + }) + + serverDone := make(chan error, 1) + go func() { + for attempt := 0; attempt < 2; attempt++ { + conn, err := listener.Accept() + if err != nil { + serverDone <- err + return + } + + buf := make([]byte, 512) + n, err := conn.Read(buf) + if err != nil { + _ = conn.Close() + serverDone <- err + return + } + if got := string(buf[:n]); got != "CONNECT 42070\n" { + _ = conn.Close() + serverDone <- fmt.Errorf("unexpected connect message %q", got) + return + } + if _, err := conn.Write([]byte("OK 1\n")); err != nil { + _ = conn.Close() + serverDone <- err + return + } + + if attempt == 0 { + _ = conn.Close() + continue + } + + reqBuf := make([]byte, 0, 512) + for { + n, err = conn.Read(buf) + if err != nil { + _ = conn.Close() + serverDone <- err + return + } + reqBuf = append(reqBuf, buf[:n]...) + if strings.Contains(string(reqBuf), "\r\n\r\n") { + break + } + } + if got := string(reqBuf); !strings.Contains(got, "GET /healthz HTTP/1.1\r\n") { + _ = conn.Close() + serverDone <- fmt.Errorf("unexpected health payload %q", got) + return + } + _, err = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 15\r\n\r\n{\"status\":\"ok\"}")) + _ = conn.Close() + serverDone <- err + return + } + serverDone <- errors.New("health probe did not retry") + }() + + if err := waitForGuestVSockAgent(context.Background(), nil, socketPath, time.Second); err != nil { + t.Fatalf("waitForGuestVSockAgent: %v", err) + } + if err := <-serverDone; err != nil { + t.Fatalf("server: %v", err) + } +} + func TestHealthVMReturnsFalseForStoppedVM(t *testing.T) { t.Parallel() diff --git a/internal/opencode/opencode.go b/internal/opencode/opencode.go index 01d989f..7a2af47 100644 --- a/internal/opencode/opencode.go +++ b/internal/opencode/opencode.go @@ -87,7 +87,7 @@ func waitReady(ctx context.Context, logger *slog.Logger, socketPath string, time lastErr = fmt.Errorf("guest port %d is not listening yet", Port) } else { if report != nil { - report("wait_vsock_agent", "waiting for guest vsock agent") + report("wait_guest_ready", "waiting for guest services") } lastErr = err } diff --git a/internal/opencode/opencode_test.go b/internal/opencode/opencode_test.go index e0ecdb9..8855960 100644 --- a/internal/opencode/opencode_test.go +++ b/internal/opencode/opencode_test.go @@ -57,6 +57,7 @@ func TestWaitReadyReturnsWhenPortIsListening(t *testing.T) { socketPath := filepath.Join(t.TempDir(), "opencode.vsock") listener, err := net.Listen("unix", socketPath) if err != nil { + skipIfSocketRestricted(t, err) t.Fatalf("listen: %v", err) } t.Cleanup(func() { @@ -114,3 +115,37 @@ func TestWaitReadyReturnsWhenPortIsListening(t *testing.T) { t.Fatalf("server: %v", err) } } + +func TestWaitReadyReportsGuestServicesWhenPortsUnavailable(t *testing.T) { + t.Parallel() + + var reports []string + err := waitReady( + context.Background(), + nil, + filepath.Join(t.TempDir(), "missing.vsock"), + 50*time.Millisecond, + func(stage, detail string) { + reports = append(reports, stage+":"+detail) + }, + ) + if err == nil { + t.Fatal("waitReady() error = nil, want timeout") + } + if len(reports) == 0 { + t.Fatal("waitReady() did not report progress") + } + if got := reports[0]; got != "wait_guest_ready:waiting for guest services" { + t.Fatalf("first report = %q, want guest services wait", got) + } +} + +func skipIfSocketRestricted(t *testing.T, err error) { + t.Helper() + if err == nil { + return + } + if strings.Contains(strings.ToLower(err.Error()), "operation not permitted") { + t.Skipf("socket creation is restricted in this environment: %v", err) + } +} diff --git a/scripts/make-rootfs-alpine.sh b/scripts/make-rootfs-alpine.sh index 48fa803..a09d907 100755 --- a/scripts/make-rootfs-alpine.sh +++ b/scripts/make-rootfs-alpine.sh @@ -225,7 +225,7 @@ description="Banger guest network bootstrap" depend() { need localmount - before sshd docker banger-vsock-agent banger-opencode + before sshd docker banger-opencode provide net } @@ -265,7 +265,7 @@ command="/usr/local/bin/banger-vsock-agent" depend() { need localmount - after banger-network + before banger-network sshd docker banger-opencode } start_pre() {