Propagate RPC cancellation to daemon requests
Stop long-running daemon operations from running under context.Background by\nthreading a request-scoped context from handleConn into dispatch. The daemon\nnow cancels in-flight handlers when the client socket goes away, and the RPC\nclient closes its Unix connection when the caller context is canceled so that\ninterrupts actually reach the daemon boundary.\n\nAdd regression coverage for both sides of the path: canceled dispatch calls,\nclient disconnects during handleConn, watcher EOF cancellation, and context\ncancellation without an RPC deadline.\n\nValidated with GOCACHE=/tmp/banger-gocache go test ./... and\nGOCACHE=/tmp/banger-gocache make build.
This commit is contained in:
parent
ebb68c3126
commit
ccba07ec68
4 changed files with 220 additions and 16 deletions
|
|
@ -25,16 +25,17 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type Daemon struct {
|
type Daemon struct {
|
||||||
layout paths.Layout
|
layout paths.Layout
|
||||||
config model.DaemonConfig
|
config model.DaemonConfig
|
||||||
store *store.Store
|
store *store.Store
|
||||||
runner system.CommandRunner
|
runner system.CommandRunner
|
||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
closing chan struct{}
|
closing chan struct{}
|
||||||
once sync.Once
|
once sync.Once
|
||||||
pid int
|
pid int
|
||||||
listener net.Listener
|
listener net.Listener
|
||||||
|
requestHandler func(context.Context, rpc.Request) rpc.Response
|
||||||
}
|
}
|
||||||
|
|
||||||
func Open(ctx context.Context) (*Daemon, error) {
|
func Open(ctx context.Context) (*Daemon, error) {
|
||||||
|
|
@ -143,23 +144,73 @@ func (d *Daemon) Serve(ctx context.Context) error {
|
||||||
|
|
||||||
func (d *Daemon) handleConn(conn net.Conn) {
|
func (d *Daemon) handleConn(conn net.Conn) {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
var req rpc.Request
|
var req rpc.Request
|
||||||
if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
|
if err := json.NewDecoder(reader).Decode(&req); err != nil {
|
||||||
if d.logger != nil {
|
if d.logger != nil {
|
||||||
d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
|
d.logger.Warn("daemon request decode failed", "remote", conn.RemoteAddr().String(), "error", err.Error())
|
||||||
}
|
}
|
||||||
_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
|
_ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp := d.dispatch(req)
|
reqCtx, cancel := context.WithCancel(context.Background())
|
||||||
_ = json.NewEncoder(conn).Encode(resp)
|
defer cancel()
|
||||||
|
stopWatch := d.watchRequestDisconnect(conn, reader, req.Method, cancel)
|
||||||
|
defer stopWatch()
|
||||||
|
resp := d.dispatch(reqCtx, req)
|
||||||
|
if reqCtx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := json.NewEncoder(conn).Encode(resp); err != nil && d.logger != nil {
|
||||||
|
d.logger.Warn("daemon response encode failed", "method", req.Method, "remote", conn.RemoteAddr().String(), "error", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Daemon) dispatch(req rpc.Request) rpc.Response {
|
func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, method string, cancel context.CancelFunc) func() {
|
||||||
|
if conn == nil || reader == nil {
|
||||||
|
return func() {}
|
||||||
|
}
|
||||||
|
done := make(chan struct{})
|
||||||
|
var once sync.Once
|
||||||
|
go func() {
|
||||||
|
go func() {
|
||||||
|
<-done
|
||||||
|
if deadlineSetter, ok := conn.(interface{ SetReadDeadline(time.Time) error }); ok {
|
||||||
|
_ = deadlineSetter.SetReadDeadline(time.Now())
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
var buf [1]byte
|
||||||
|
for {
|
||||||
|
_, err := reader.Read(buf[:])
|
||||||
|
if err == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if d.logger != nil {
|
||||||
|
d.logger.Info("daemon request canceled", "method", method, "remote", conn.RemoteAddr().String(), "error", err.Error())
|
||||||
|
}
|
||||||
|
cancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return func() {
|
||||||
|
once.Do(func() {
|
||||||
|
close(done)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response {
|
||||||
if req.Version != rpc.Version {
|
if req.Version != rpc.Version {
|
||||||
return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
|
return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))
|
||||||
}
|
}
|
||||||
ctx := context.Background()
|
if d.requestHandler != nil {
|
||||||
|
return d.requestHandler(ctx, req)
|
||||||
|
}
|
||||||
switch req.Method {
|
switch req.Method {
|
||||||
case "ping":
|
case "ping":
|
||||||
result, _ := rpc.NewResult(api.PingResult{Status: "ok", PID: d.pid})
|
result, _ := rpc.NewResult(api.PingResult{Status: "ok", PID: d.pid})
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
package daemon
|
package daemon
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"banger/internal/api"
|
||||||
"banger/internal/model"
|
"banger/internal/model"
|
||||||
|
"banger/internal/rpc"
|
||||||
"banger/internal/store"
|
"banger/internal/store"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -330,3 +336,107 @@ func TestSetDNSUsesMapDNSDefaultsWhenDataFileUnset(t *testing.T) {
|
||||||
}
|
}
|
||||||
runner.assertExhausted()
|
runner.assertExhausted()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDispatchUsesPassedContext(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db := openDefaultImageStore(t, t.TempDir())
|
||||||
|
d := &Daemon{store: db}
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
resp := d.dispatch(ctx, rpc.Request{
|
||||||
|
Version: rpc.Version,
|
||||||
|
Method: "vm.list",
|
||||||
|
Params: mustJSON(t, api.Empty{}),
|
||||||
|
})
|
||||||
|
|
||||||
|
if resp.OK {
|
||||||
|
t.Fatal("dispatch() succeeded with canceled context")
|
||||||
|
}
|
||||||
|
if resp.Error == nil || !strings.Contains(resp.Error.Message, context.Canceled.Error()) {
|
||||||
|
t.Fatalf("dispatch() error = %+v, want context canceled", resp.Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleConnCancelsRequestWhenClientDisconnects(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server, client := net.Pipe()
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
requestCanceled := make(chan struct{})
|
||||||
|
done := make(chan struct{})
|
||||||
|
d := &Daemon{
|
||||||
|
closing: make(chan struct{}),
|
||||||
|
requestHandler: func(ctx context.Context, req rpc.Request) rpc.Response {
|
||||||
|
if req.Method != "block" {
|
||||||
|
t.Errorf("request method = %q, want block", req.Method)
|
||||||
|
}
|
||||||
|
<-ctx.Done()
|
||||||
|
close(requestCanceled)
|
||||||
|
return rpc.NewError("operation_failed", ctx.Err().Error())
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
d.handleConn(server)
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := json.NewEncoder(client).Encode(rpc.Request{Version: rpc.Version, Method: "block"}); err != nil {
|
||||||
|
t.Fatalf("encode request: %v", err)
|
||||||
|
}
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
t.Fatalf("close client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-requestCanceled:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("request context was not canceled after client disconnect")
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("handleConn did not return after client disconnect")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWatchRequestDisconnectCancelsContextOnEOF(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
server, client := net.Pipe()
|
||||||
|
defer server.Close()
|
||||||
|
|
||||||
|
reader := bufio.NewReader(server)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
d := &Daemon{closing: make(chan struct{})}
|
||||||
|
stop := d.watchRequestDisconnect(server, reader, "block", cancel)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
if err := client.Close(); err != nil {
|
||||||
|
t.Fatalf("close client: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if !strings.Contains(ctx.Err().Error(), context.Canceled.Error()) {
|
||||||
|
t.Fatalf("ctx.Err() = %v, want canceled", ctx.Err())
|
||||||
|
}
|
||||||
|
case <-time.After(2 * time.Second):
|
||||||
|
t.Fatal("watchRequestDisconnect did not cancel context")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustJSON(t *testing.T, v any) []byte {
|
||||||
|
t.Helper()
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("json.Marshal(%T): %v", v, err)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -57,11 +57,22 @@ func DecodeParams[T any](req Request) (T, error) {
|
||||||
|
|
||||||
func Call[T any](ctx context.Context, socketPath, method string, params any) (T, error) {
|
func Call[T any](ctx context.Context, socketPath, method string, params any) (T, error) {
|
||||||
var zero T
|
var zero T
|
||||||
conn, err := net.DialTimeout("unix", socketPath, 2*time.Second)
|
dialer := &net.Dialer{Timeout: 2 * time.Second}
|
||||||
|
conn, err := dialer.DialContext(ctx, "unix", socketPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
done := make(chan struct{})
|
||||||
|
defer close(done)
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
_ = conn.SetDeadline(time.Now())
|
||||||
|
_ = conn.Close()
|
||||||
|
case <-done:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
_ = conn.SetDeadline(deadline)
|
_ = conn.SetDeadline(deadline)
|
||||||
|
|
@ -77,11 +88,17 @@ func Call[T any](ctx context.Context, socketPath, method string, params any) (T,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(conn).Encode(request); err != nil {
|
if err := json.NewEncoder(conn).Encode(request); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return zero, ctx.Err()
|
||||||
|
}
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var response Response
|
var response Response
|
||||||
if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&response); err != nil {
|
if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&response); err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return zero, ctx.Err()
|
||||||
|
}
|
||||||
return zero, err
|
return zero, err
|
||||||
}
|
}
|
||||||
if !response.OK {
|
if !response.OK {
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,32 @@ func TestCallHonorsContextDeadline(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCallHonorsContextCancellationWithoutDeadline(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) {
|
||||||
|
defer conn.Close()
|
||||||
|
var req Request
|
||||||
|
if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil {
|
||||||
|
t.Fatalf("decode request: %v", err)
|
||||||
|
}
|
||||||
|
var buf [1]byte
|
||||||
|
_, _ = conn.Read(buf[:])
|
||||||
|
})
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
go func() {
|
||||||
|
time.Sleep(20 * time.Millisecond)
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
_, err := Call[map[string]string](ctx, socketPath, "ping", nil)
|
||||||
|
if !errors.Is(err, context.Canceled) {
|
||||||
|
t.Fatalf("Call() error = %v, want context canceled", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWaitForSocket(t *testing.T) {
|
func TestWaitForSocket(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue