package vsockagent import ( "bytes" "context" "encoding/json" "errors" "net" "net/http" "path/filepath" "reflect" "strconv" "strings" "testing" "time" ) func TestNewHandlerHealthz(t *testing.T) { t.Parallel() req, err := http.NewRequest(http.MethodGet, HealthPath, nil) if err != nil { t.Fatalf("NewRequest: %v", err) } rr := newTestResponseRecorder() NewHandler().ServeHTTP(rr, req) if rr.status != http.StatusOK { t.Fatalf("status = %d, want %d", rr.status, http.StatusOK) } if got := rr.headers.Get("Content-Type"); got != "application/json" { t.Fatalf("content-type = %q", got) } var payload HealthResponse if err := json.Unmarshal(rr.body.Bytes(), &payload); err != nil { t.Fatalf("Unmarshal: %v", err) } if payload.Status != HealthyStatus { t.Fatalf("status = %q, want %q", payload.Status, HealthyStatus) } } func TestNewHandlerPorts(t *testing.T) { origCollector := portCollector t.Cleanup(func() { portCollector = origCollector }) portCollector = func(context.Context) ([]PortListener, error) { return []PortListener{{ Proto: "tcp", BindAddress: "0.0.0.0", Port: 8080, PID: 42, Process: "python3", Command: "python3 -m http.server 8080", }}, nil } req, err := http.NewRequest(http.MethodGet, PortsPath, nil) if err != nil { t.Fatalf("NewRequest: %v", err) } rr := newTestResponseRecorder() NewHandler().ServeHTTP(rr, req) if rr.status != http.StatusOK { t.Fatalf("status = %d, want %d", rr.status, http.StatusOK) } var payload PortsResponse if err := json.Unmarshal(rr.body.Bytes(), &payload); err != nil { t.Fatalf("Unmarshal: %v", err) } if len(payload.Listeners) != 1 { t.Fatalf("listeners = %d, want 1", len(payload.Listeners)) } if got := payload.Listeners[0]; got.Command != "python3 -m http.server 8080" { t.Fatalf("listener = %+v, want command", got) } } func TestHealth(t *testing.T) { t.Parallel() dir := t.TempDir() socketPath := filepath.Join(dir, "fc.vsock") listener, err := net.Listen("unix", socketPath) if err != nil { t.Fatalf("Listen: %v", err) } defer listener.Close() done := make(chan error, 1) go func() { conn, err := listener.Accept() if err != nil { done <- err return } defer conn.Close() buf := make([]byte, 0, 256) tmp := make([]byte, 256) for { n, err := conn.Read(tmp) if err != nil { done <- err return } buf = append(buf, tmp[:n]...) if strings.Contains(string(buf), "\n") { break } } if got := string(buf); got != "CONNECT 42070\n" { done <- unexpectedStringError(got) return } if _, err := conn.Write([]byte("OK 55\n")); err != nil { done <- err return } buf = buf[:0] for { n, err := conn.Read(tmp) if err != nil { done <- err return } buf = append(buf, tmp[:n]...) if strings.Contains(string(buf), "\r\n\r\n") { break } } req := string(buf) if !strings.Contains(req, "GET /healthz HTTP/1.1\r\n") { done <- unexpectedStringError(req) return } _, err = conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: 15\r\n\r\n{\"status\":\"ok\"}")) done <- err }() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() if err := Health(ctx, nil, socketPath); err != nil { t.Fatalf("Health: %v", err) } if err := <-done; err != nil { t.Fatalf("server: %v", err) } } func TestPorts(t *testing.T) { t.Parallel() dir := t.TempDir() socketPath := filepath.Join(dir, "fc.vsock") listener, err := net.Listen("unix", socketPath) if err != nil { t.Fatalf("Listen: %v", err) } defer listener.Close() done := make(chan error, 1) go func() { conn, err := listener.Accept() if err != nil { done <- err return } defer conn.Close() buf := make([]byte, 0, 256) tmp := make([]byte, 256) for { n, err := conn.Read(tmp) if err != nil { done <- err return } buf = append(buf, tmp[:n]...) if strings.Contains(string(buf), "\n") { break } } if got := string(buf); got != "CONNECT 42070\n" { done <- unexpectedStringError(got) return } if _, err := conn.Write([]byte("OK 55\n")); err != nil { done <- err return } buf = buf[:0] for { n, err := conn.Read(tmp) if err != nil { done <- err return } buf = append(buf, tmp[:n]...) if strings.Contains(string(buf), "\r\n\r\n") { break } } req := string(buf) if !strings.Contains(req, "GET /ports HTTP/1.1\r\n") { done <- unexpectedStringError(req) return } body := `{"listeners":[{"proto":"tcp","bind_address":"0.0.0.0","port":8080,"pid":42,"process":"python3","command":"python3 -m http.server 8080"}]}` resp := "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nContent-Length: " + strconv.Itoa(len(body)) + "\r\n\r\n" + body _, err = conn.Write([]byte(resp)) done <- err }() ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() listeners, err := Ports(ctx, nil, socketPath) if err != nil { t.Fatalf("Ports: %v", err) } if len(listeners) != 1 || listeners[0].Port != 8080 || listeners[0].Command != "python3 -m http.server 8080" { t.Fatalf("listeners = %+v, want parsed port listener", listeners) } if err := <-done; err != nil { t.Fatalf("server: %v", err) } } func TestParsePortListenersFiltersLoopbackAndDedupesWildcards(t *testing.T) { t.Parallel() raw := strings.Join([]string{ `tcp LISTEN 0 4096 0.0.0.0:22 0.0.0.0:* users:(("sshd",pid=12,fd=3))`, `tcp LISTEN 0 4096 [::]:22 [::]:* users:(("sshd",pid=12,fd=4))`, `udp UNCONN 0 0 127.0.0.53%lo:53 0.0.0.0:* users:(("stubby",pid=99,fd=3))`, `tcp LISTEN 0 4096 172.16.0.2:8080 0.0.0.0:* users:(("python3",pid=44,fd=6))`, }, "\n") readCmdline := func(pid int) string { switch pid { case 12: return "/usr/sbin/sshd -D" case 44: return "python3 -m http.server 8080" default: return "" } } listeners, err := parsePortListeners([]byte(raw), readCmdline) if err != nil { t.Fatalf("parsePortListeners: %v", err) } want := []PortListener{ { Proto: "tcp", BindAddress: "0.0.0.0", Port: 22, PID: 12, Process: "sshd", Command: "/usr/sbin/sshd -D", }, { Proto: "tcp", BindAddress: "172.16.0.2", Port: 8080, PID: 44, Process: "python3", Command: "python3 -m http.server 8080", }, } if !reflect.DeepEqual(listeners, want) { t.Fatalf("listeners = %#v, want %#v", listeners, want) } } func TestParsePortListenersFallsBackToProcessName(t *testing.T) { t.Parallel() raw := `tcp LISTEN 0 128 0.0.0.0:5432 0.0.0.0:* users:(("postgres",pid=77,fd=5))` listeners, err := parsePortListeners([]byte(raw), func(int) string { return "" }) if err != nil { t.Fatalf("parsePortListeners: %v", err) } if len(listeners) != 1 { t.Fatalf("listeners = %d, want 1", len(listeners)) } if listeners[0].Command != "postgres" { t.Fatalf("command = %q, want process fallback", listeners[0].Command) } } func TestNewHandlerPortsReturnsServerErrorOnCollectorFailure(t *testing.T) { origCollector := portCollector t.Cleanup(func() { portCollector = origCollector }) portCollector = func(context.Context) ([]PortListener, error) { return nil, errors.New("ss missing") } req, err := http.NewRequest(http.MethodGet, PortsPath, nil) if err != nil { t.Fatalf("NewRequest: %v", err) } rr := newTestResponseRecorder() NewHandler().ServeHTTP(rr, req) if rr.status != http.StatusInternalServerError { t.Fatalf("status = %d, want %d", rr.status, http.StatusInternalServerError) } } type testResponseRecorder struct { headers http.Header body bytes.Buffer status int } func newTestResponseRecorder() *testResponseRecorder { return &testResponseRecorder{headers: make(http.Header), status: http.StatusOK} } func (r *testResponseRecorder) Header() http.Header { return r.headers } func (r *testResponseRecorder) Write(data []byte) (int, error) { return r.body.Write(data) } func (r *testResponseRecorder) WriteHeader(status int) { r.status = status } type unexpectedStringError string func (e unexpectedStringError) Error() string { return "unexpected string: " + string(e) }