banger/internal/rpc/rpc_test.go
Thales Maciel ccba07ec68
Propagate RPC cancellation to daemon requests
Stop long-running daemon operations from running under context.Background by\nthreading a request-scoped context from handleConn into dispatch. The daemon\nnow cancels in-flight handlers when the client socket goes away, and the RPC\nclient closes its Unix connection when the caller context is canceled so that\ninterrupts actually reach the daemon boundary.\n\nAdd regression coverage for both sides of the path: canceled dispatch calls,\nclient disconnects during handleConn, watcher EOF cancellation, and context\ncancellation without an RPC deadline.\n\nValidated with GOCACHE=/tmp/banger-gocache go test ./... and\nGOCACHE=/tmp/banger-gocache make build.
2026-03-16 18:28:33 -03:00

201 lines
4.9 KiB
Go

package rpc
import (
"bufio"
"context"
"encoding/json"
"errors"
"net"
"path/filepath"
"strings"
"testing"
"time"
)
// TestDecodeParams covers the shapes of Request.Params that DecodeParams must
// handle: absent params (zero value, no error), well-formed JSON (fields are
// populated), and truncated JSON (an error is returned).
func TestDecodeParams(t *testing.T) {
	t.Parallel()

	type payload struct {
		Name string `json:"name"`
	}

	// Absent params decode to the zero value rather than failing.
	got, err := DecodeParams[payload](Request{})
	if err != nil {
		t.Fatalf("DecodeParams(empty): %v", err)
	}
	if got.Name != "" {
		t.Fatalf("DecodeParams(empty) = %+v, want zero value", got)
	}

	// Well-formed params populate the target struct.
	got, err = DecodeParams[payload](Request{Params: json.RawMessage(`{"name":"devbox"}`)})
	if err != nil {
		t.Fatalf("DecodeParams(valid): %v", err)
	}
	if got.Name != "devbox" {
		t.Fatalf("DecodeParams(valid) = %+v, want name=devbox", got)
	}

	// Truncated JSON must surface an error instead of a partial value.
	if _, err := DecodeParams[payload](Request{Params: json.RawMessage(`{"name":`)}); err == nil {
		t.Fatal("DecodeParams(malformed) returned nil error")
	}
}
// TestCallRoundTripSuccess checks that Call sends a request carrying the
// protocol Version, the given method, and the given params, and that it decodes
// the result payload the server sends back.
func TestCallRoundTripSuccess(t *testing.T) {
	t.Parallel()
	// The handler runs on serveRPCOnce's goroutine, so failures are reported with
	// t.Errorf and an early return: t.Fatalf (FailNow) may only be called from
	// the goroutine running the test function.
	socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
		defer conn.Close()
		var req Request
		if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		if req.Version != Version || req.Method != "ping" {
			t.Errorf("unexpected request: %+v", req)
			return
		}
		var params map[string]string
		if err := json.Unmarshal(req.Params, &params); err != nil {
			t.Errorf("unmarshal params: %v", err)
			return
		}
		if params["name"] != "devbox" {
			t.Errorf("params = %v, want name=devbox", params)
			return
		}
		resp, err := NewResult(map[string]string{"status": "ok"})
		if err != nil {
			t.Errorf("NewResult: %v", err)
			return
		}
		if err := json.NewEncoder(conn).Encode(resp); err != nil {
			t.Errorf("encode response: %v", err)
		}
	})
	defer cleanup()
	result, err := Call[map[string]string](context.Background(), socketPath, "ping", map[string]string{"name": "devbox"})
	if err != nil {
		t.Fatalf("Call: %v", err)
	}
	if result["status"] != "ok" {
		t.Fatalf("Call() result = %v, want status=ok", result)
	}
}
// TestCallReturnsRemoteError checks that an error response from the server is
// surfaced by Call as an error containing the remote code and message.
func TestCallReturnsRemoteError(t *testing.T) {
	t.Parallel()
	// Runs off the test goroutine: report with t.Errorf, never t.Fatalf.
	socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
		defer conn.Close()
		var req Request
		if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		if err := json.NewEncoder(conn).Encode(NewError("operation_failed", "boom")); err != nil {
			t.Errorf("encode error response: %v", err)
		}
	})
	defer cleanup()
	_, err := Call[map[string]string](context.Background(), socketPath, "ping", nil)
	if err == nil || !strings.Contains(err.Error(), "operation_failed: boom") {
		t.Fatalf("Call() error = %v, want remote error", err)
	}
}
// TestCallRejectsMalformedResponse checks that Call returns an error when the
// server replies with bytes that are not valid JSON.
func TestCallRejectsMalformedResponse(t *testing.T) {
	t.Parallel()
	socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
		defer conn.Close()
		// Consume the request before replying. Otherwise the server may close
		// before the client has written it, and Call would fail on the write.
		// The test would then pass without ever exercising response parsing.
		var req Request
		if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		if _, err := conn.Write([]byte("{not-json}\n")); err != nil {
			t.Errorf("write malformed response: %v", err)
		}
	})
	defer cleanup()
	_, err := Call[map[string]string](context.Background(), socketPath, "ping", nil)
	if err == nil {
		t.Fatal("Call() returned nil error for malformed response")
	}
}
// TestCallHonorsContextDeadline checks that Call gives up once the caller's
// context deadline passes, even though the server never replies.
func TestCallHonorsContextDeadline(t *testing.T) {
	t.Parallel()
	// holdOpen bounds how long the server keeps the connection open without
	// replying. A Call that honors the 20ms deadline returns long before this.
	// A Call that ignores it only returns once the server hangs up, which the
	// elapsed-time check below detects. The bound also keeps such a regression
	// from hanging the test binary.
	const holdOpen = 2 * time.Second
	socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
		defer conn.Close()
		var req Request
		if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		// Never reply; wait until the client closes its end or holdOpen elapses.
		_ = conn.SetReadDeadline(time.Now().Add(holdOpen))
		var buf [1]byte
		_, _ = conn.Read(buf[:])
	})
	defer cleanup()
	ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
	defer cancel()
	start := time.Now()
	_, err := Call[map[string]string](ctx, socketPath, "ping", nil)
	if err == nil {
		t.Fatal("Call() returned nil error for deadline")
	}
	if elapsed := time.Since(start); elapsed >= holdOpen {
		t.Fatalf("Call() returned after %v, want it to stop at the 20ms context deadline", elapsed)
	}
}
// TestCallHonorsContextCancellationWithoutDeadline checks that canceling a
// context that carries no deadline still aborts an in-flight Call. The
// returned error must wrap context.Canceled.
func TestCallHonorsContextCancellationWithoutDeadline(t *testing.T) {
	t.Parallel()
	// Upper bound on how long the server holds the connection without replying.
	// It only matters if Call ignores cancellation. Call then returns whatever
	// error the server's hang-up produces, which fails the errors.Is check
	// instead of hanging until the go test timeout.
	const holdOpen = 2 * time.Second
	socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
		defer conn.Close()
		var req Request
		if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
			t.Errorf("decode request: %v", err)
			return
		}
		// Never reply; wait for the client to close its end.
		_ = conn.SetReadDeadline(time.Now().Add(holdOpen))
		var buf [1]byte
		_, _ = conn.Read(buf[:])
	})
	defer cleanup()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	// Cancel shortly after Call is in flight. Stopping the timer on exit means
	// no background work outlives the test.
	timer := time.AfterFunc(20*time.Millisecond, cancel)
	defer timer.Stop()
	_, err := Call[map[string]string](ctx, socketPath, "ping", nil)
	if !errors.Is(err, context.Canceled) {
		t.Fatalf("Call() error = %v, want context canceled", err)
	}
}
// TestWaitForSocket verifies two outcomes of WaitForSocket. It succeeds when a
// listener is already serving the path. It returns an error mentioning "not
// ready" when the path never appears within the timeout.
func TestWaitForSocket(t *testing.T) {
	t.Parallel()

	readyPath, stop := serveRPCOnce(t, func(c net.Conn) {
		_ = c.Close()
	})
	defer stop()

	if waitErr := WaitForSocket(readyPath, 2*time.Second); waitErr != nil {
		t.Fatalf("WaitForSocket(success): %v", waitErr)
	}

	missingPath := filepath.Join(t.TempDir(), "missing.sock")
	timeoutErr := WaitForSocket(missingPath, 50*time.Millisecond)
	if timeoutErr == nil {
		t.Fatalf("WaitForSocket(timeout) error = %v, want timeout", timeoutErr)
	}
	if !strings.Contains(timeoutErr.Error(), "not ready") {
		t.Fatalf("WaitForSocket(timeout) error = %v, want timeout", timeoutErr)
	}
}
// serveRPCOnce listens on a fresh Unix socket named rpc.sock inside t.TempDir().
// It accepts at most one connection on a background goroutine and passes it to
// handler. It returns the socket path and a cleanup func.
//
// Cleanup closes the listener and then blocks until the background goroutine
// has exited, including any running handler. Callers should defer it so the
// handler finishes (and stops reporting through t) before the test returns. An
// Accept that fails because the listener was closed is expected, not an error.
func serveRPCOnce(t *testing.T, handler func(net.Conn)) (string, func()) {
	t.Helper()
	path := filepath.Join(t.TempDir(), "rpc.sock")
	ln, listenErr := net.Listen("unix", path)
	if listenErr != nil {
		t.Fatalf("Listen: %v", listenErr)
	}

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		conn, acceptErr := ln.Accept()
		switch {
		case acceptErr == nil:
			handler(conn)
		case !errors.Is(acceptErr, net.ErrClosed):
			t.Errorf("Accept: %v", acceptErr)
		}
	}()

	stop := func() {
		_ = ln.Close()
		<-finished
	}
	return path, stop
}