diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 61a28bf..f7b1918 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -25,16 +25,17 @@ import ( ) type Daemon struct { - layout paths.Layout - config model.DaemonConfig - store *store.Store - runner system.CommandRunner - logger *slog.Logger - mu sync.Mutex - closing chan struct{} - once sync.Once - pid int - listener net.Listener + layout paths.Layout + config model.DaemonConfig + store *store.Store + runner system.CommandRunner + logger *slog.Logger + mu sync.Mutex + closing chan struct{} + once sync.Once + pid int + listener net.Listener + requestHandler func(context.Context, rpc.Request) rpc.Response } func Open(ctx context.Context) (*Daemon, error) { @@ -143,23 +144,73 @@ func (d *Daemon) Serve(ctx context.Context) error { func (d *Daemon) handleConn(conn net.Conn) { defer conn.Close() + reader := bufio.NewReader(conn) var req rpc.Request - if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + if err := json.NewDecoder(reader).Decode(&req); err != nil { if d.logger != nil { d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error()) } _ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error())) return } - resp := d.dispatch(req) - _ = json.NewEncoder(conn).Encode(resp) + reqCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel) + defer stopWatch() + resp := d.dispatch(reqCtx, req) + if reqCtx.Err() != nil { + return + } + if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil { + d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error()) + } } -func (d *Daemon) dispatch(req rpc.Request) rpc.Response { +func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() { + if conn == nil || reader == nil { + return func() {} + } + done := make(chan struct{}) + var once sync.Once + go func() { + go func() { + <-done + if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok { + _ = deadlineSetter.SetReadDeadline(time.Now()) + } + }() + var buf [1]byte + for { + _, err := reader.Read(buf[:]) + if err == nil { + continue + } + select { + case <-done: + return + default: + } + if d.logger != nil { + d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error()) + } + cancel() + return + } + }() + return func() { + once.Do(func() { + close(done) + }) + } +} + +func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { if req.Version != rpc.Version { return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version)) } - ctx := context.Background() + if d.requestHandler != nil { + return d.requestHandler(ctx, req) + } switch req.Method { case "ping": result, _ := rpc.NewResult(api.PingResult{Status: "ok", PID: d.pid}) diff --git a/internal/daemon/daemon_test.go b/internal/daemon/daemon_test.go index eba22b2..80211c0 100644 --- a/internal/daemon/daemon_test.go +++ b/internal/daemon/daemon_test.go @@ -1,13 +1,19 @@ package daemon import ( + "bufio" "context" + "encoding/json" + "net" "os" "path/filepath" + "strings" "testing" "time" + "banger/internal/api" "banger/internal/model" + "banger/internal/rpc" "banger/internal/store" ) @@ -330,3 +336,107 @@ func TestSetDNSUsesMapDNSDefaultsWhenDataFileUnset(t *testing.T) { } runner.assertExhausted() } + +func TestDispatchUsesPassedContext(t *testing.T) { + t.Parallel() + + db := openDefaultImageStore(t, t.TempDir()) + d := &Daemon{store: db} + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + resp := d.dispatch(ctx, rpc.Request{ + Version: rpc.Version, + Method: "vm.list", + Params: mustJSON(t, api.Empty{}), + }) + + if resp.OK { + t.Fatal("dispatch() succeeded with canceled context") + } + if resp.Error == nil || !strings.Contains(resp.Error.Message, context.Canceled.Error()) { + t.Fatalf("dispatch() error = %+v, want context canceled", resp.Error) + } +} + +func TestHandleConnCancelsRequestWhenClientDisconnects(t *testing.T) { + t.Parallel() + + server, client := net.Pipe() + defer client.Close() + + requestCanceled := make(chan struct{}) + done := make(chan struct{}) + d := &Daemon{ + closing: make(chan struct{}), + requestHandler: func(ctx context.Context, req rpc.Request) rpc.Response { + if req.Method != "block" { + t.Errorf("request method = %q, want block", req.Method) + } + <-ctx.Done() + close(requestCanceled) + return rpc.NewError("operation_failed", ctx.Err().Error()) + }, + } + + go func() { + d.handleConn(server) + close(done) + }() + + if err := json.NewEncoder(client).Encode(rpc.Request{Version: rpc.Version, Method: "block"}); err != nil { + t.Fatalf("encode request: %v", err) + } + if err := client.Close(); err != nil { + t.Fatalf("close client: %v", err) + } + + select { + case <-requestCanceled: + case <-time.After(2 * time.Second): + t.Fatal("request context was not canceled after client disconnect") + } + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("handleConn did not return after client disconnect") + } +} + +func TestWatchRequestDisconnectCancelsContextOnEOF(t *testing.T) { + t.Parallel() + + server, client := net.Pipe() + defer server.Close() + + reader := bufio.NewReader(server) + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + d := &Daemon{closing: make(chan struct{})} + stop := d.watchRequestDisconnect(server, reader, "block", cancel) + defer stop() + + if err := client.Close(); err != nil { + t.Fatalf("close client: %v", err) + } + + select { + case <-ctx.Done(): + if !strings.Contains(ctx.Err().Error(), context.Canceled.Error()) { + t.Fatalf("ctx.Err() = %v, want canceled", ctx.Err()) + } + case <-time.After(2 * time.Second): + t.Fatal("watchRequestDisconnect did not cancel context") + } +} + +func mustJSON(t *testing.T, v any) []byte { + t.Helper() + data, err := json.Marshal(v) + if err != nil { + t.Fatalf("json.Marshal(%T): %v", v, err) + } + return data +} diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index f9f77d8..3abfb59 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -57,11 +57,22 @@ func DecodeParams[T any](req Request) (T, error) { func Call[T any](ctx context.Context, socketPath, method string, params any) (T, error) { var zero T - conn, err := net.DialTimeout("unix", socketPath, 2*time.Second) + dialer := &net.Dialer{Timeout: 2 * time.Second} + conn, err := dialer.DialContext(ctx, "unix", socketPath) if err != nil { return zero, err } defer conn.Close() + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + _ = conn.SetDeadline(time.Now()) + _ = conn.Close() + case <-done: + } + }() if deadline, ok := ctx.Deadline(); ok { _ = conn.SetDeadline(deadline) @@ -77,11 +88,17 @@ func Call[T any](ctx context.Context, socketPath, method string, params any) (T, } if err := json.NewEncoder(conn).Encode(request); err != nil { + if ctx.Err() != nil { + return zero, ctx.Err() + } return zero, err } var response Response if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&response); err != nil { + if ctx.Err() != nil { + return zero, ctx.Err() + } return zero, err } if !response.OK { diff --git a/internal/rpc/rpc_test.go b/internal/rpc/rpc_test.go index b469fcd..c59a8e9 100644 --- a/internal/rpc/rpc_test.go +++ b/internal/rpc/rpc_test.go @@ -128,6 +128,32 @@ func TestCallHonorsContextDeadline(t *testing.T) { } } +func TestCallHonorsContextCancellationWithoutDeadline(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + var req Request + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + var buf [1]byte + _, _ = conn.Read(buf[:]) + }) + defer cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + _, err := Call[map[string]string](ctx, socketPath, "ping", nil) + if !errors.Is(err, context.Canceled) { + t.Fatalf("Call() error = %v, want context canceled", err) + } +} + func TestWaitForSocket(t *testing.T) { t.Parallel()