banger/internal/rpc/rpc.go
Thales Maciel ccba07ec68
Propagate RPC cancellation to daemon requests
Stop long-running daemon operations from running under context.Background by
threading a request-scoped context from handleConn into dispatch. The daemon
now cancels in-flight handlers when the client socket goes away, and the RPC
client closes its Unix connection when the caller context is canceled so that
interrupts actually reach the daemon boundary.

Add regression coverage for both sides of the path: canceled dispatch calls,
client disconnects during handleConn, watcher EOF cancellation, and context
cancellation without an RPC deadline.

Validated with GOCACHE=/tmp/banger-gocache go test ./... and
GOCACHE=/tmp/banger-gocache make build.
2026-03-16 18:28:33 -03:00

145 lines
3.1 KiB
Go

package rpc
import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
"os"
"time"
)
// Version is the wire-protocol version that Call stamps into the Version
// field of every Request it sends.
const Version = 1
// Request is the JSON message a client writes to the daemon socket. Call
// sends exactly one Request per connection and then reads one Response.
type Request struct {
// Version carries the protocol version; Call sets it to the Version constant.
Version int `json:"version"`
// Method is the RPC method name passed to Call.
Method string `json:"method"`
// Params holds the raw JSON parameters. It is omitted from the encoding when
// empty, and DecodeParams treats an empty value as T's zero value.
Params json.RawMessage `json:"params,omitempty"`
}
// Response is the JSON message answering a Request. NewResult builds the
// success form and NewError builds the failure form.
type Response struct {
// OK reports success; when false, Call converts Error into a Go error.
OK bool `json:"ok"`
// Result is the JSON-encoded result value, omitted when empty. Call decodes
// it into its type parameter and returns the zero value when it is empty.
Result json.RawMessage `json:"result,omitempty"`
// Error describes the failure when OK is false; omitted when nil.
Error *ErrorResponse `json:"error,omitempty"`
}
// ErrorResponse carries the details of a failed call. Call renders it as
// "code: message" when returning it to the caller.
type ErrorResponse struct {
// Code is an error code string; its vocabulary is defined by the handlers
// that call NewError (NOTE(review): not visible here).
Code string `json:"code"`
// Message is the human-readable error text.
Message string `json:"message"`
}
// NewResult builds a successful Response whose Result field holds the JSON
// encoding of v. If v cannot be marshaled, it returns the zero Response along
// with the marshaling error.
func NewResult(v any) (Response, error) {
	encoded, marshalErr := json.Marshal(v)
	if marshalErr != nil {
		return Response{}, marshalErr
	}
	resp := Response{OK: true}
	resp.Result = encoded
	return resp, nil
}
// NewError builds a failed Response (OK left false) whose Error field carries
// the given code and human-readable message.
func NewError(code, message string) Response {
	detail := &ErrorResponse{Code: code, Message: message}
	return Response{Error: detail}
}
// DecodeParams unmarshals the JSON in req.Params into a new value of type T.
//
// A request without params (an empty Params field) yields T's zero value and
// a nil error. If unmarshaling fails, the json error is returned together
// with a fresh zero T, never a partially populated one.
func DecodeParams[T any](req Request) (T, error) {
	var decoded T
	if len(req.Params) > 0 {
		if err := json.Unmarshal(req.Params, &decoded); err != nil {
			var empty T
			return empty, err
		}
	}
	return decoded, nil
}
// Call performs a single request/response exchange with the daemon listening
// on the Unix socket at socketPath. It sends one JSON-encoded Request for
// method, reads one JSON Response, and decodes Response.Result into T. When
// params is non-nil it is marshaled into Request.Params; this happens before
// dialing, so a marshaling failure never opens a connection.
//
// Cancellation: if ctx is done while the call is in flight, the connection
// is closed so blocked reads/writes return promptly and the daemon observes
// the disconnect. A ctx deadline is also applied to the socket. Whenever ctx
// is the reason the call failed, Call returns ctx.Err() (or
// context.DeadlineExceeded when the socket deadline fires first) instead of
// the underlying transport error, so callers can match it reliably.
//
// Other transport and decoding failures are wrapped as
// "rpc <method>: <step>: <cause>" and remain inspectable with errors.Is/As.
// A response with OK == false is reported as "code: message", or
// "rpc error" when the daemon sent no details. An OK response with an empty
// Result yields T's zero value.
func Call[T any](ctx context.Context, socketPath, method string, params any) (T, error) {
	var zero T

	request := Request{Version: Version, Method: method}
	if params != nil {
		raw, err := json.Marshal(params)
		if err != nil {
			return zero, fmt.Errorf("rpc %s: encode params: %w", method, err)
		}
		request.Params = raw
	}

	dialer := &net.Dialer{Timeout: 2 * time.Second}
	conn, err := dialer.DialContext(ctx, "unix", socketPath)
	if err != nil {
		// Report cancellation as ctx.Err(), consistent with the I/O paths
		// below, rather than the net package's own wrapping of it.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return zero, ctxErr
		}
		return zero, err
	}
	defer conn.Close()

	// transportErr converts a socket error into the error Call returns. When
	// ctx caused the failure, the context error wins.
	transportErr := func(step string, err error) error {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		// The socket deadline mirrors ctx's deadline, but the network poller
		// may fire before the context's own timer has marked ctx as done.
		// That is still the caller's deadline expiring.
		if _, ok := ctx.Deadline(); ok && errors.Is(err, os.ErrDeadlineExceeded) {
			return context.DeadlineExceeded
		}
		return fmt.Errorf("rpc %s: %s: %w", method, step, err)
	}

	// Apply the ctx deadline before the watcher starts so the watcher's
	// SetDeadline(now) below can never be overwritten by this call.
	if deadline, ok := ctx.Deadline(); ok {
		_ = conn.SetDeadline(deadline) // a broken conn will fail the I/O below anyway
	}

	// Close the connection as soon as ctx is done so blocked I/O returns
	// immediately and the daemon sees the client go away. Closing done stops
	// the watcher when Call returns; that deferred close runs before the
	// deferred conn.Close above.
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.SetDeadline(time.Now())
			_ = conn.Close()
		case <-done:
		}
	}()

	if err := json.NewEncoder(conn).Encode(request); err != nil {
		return zero, transportErr("send request", err)
	}

	var response Response
	if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&response); err != nil {
		return zero, transportErr("read response", err)
	}

	if !response.OK {
		if response.Error == nil {
			return zero, errors.New("rpc error")
		}
		return zero, fmt.Errorf("%s: %s", response.Error.Code, response.Error.Message)
	}
	if len(response.Result) == 0 {
		return zero, nil
	}
	var result T
	if err := json.Unmarshal(response.Result, &result); err != nil {
		return zero, fmt.Errorf("rpc %s: decode result: %w", method, err)
	}
	return result, nil
}
// WaitForSocket polls until a Unix socket exists at path and accepts a
// connection, or until timeout elapses. Each attempt stats the path and, if
// it exists, dials it with a 500ms timeout. Attempts are spaced 100ms apart,
// and at least one attempt is always made, even for a zero or negative
// timeout.
//
// On timeout the returned error starts with "socket <path> not ready" and
// wraps the most recent failure. Callers can therefore tell a missing socket
// file (errors.Is(err, os.ErrNotExist)) from one that exists but refuses
// connections.
func WaitForSocket(path string, timeout time.Duration) error {
	const (
		pollInterval = 100 * time.Millisecond
		dialTimeout  = 500 * time.Millisecond
	)
	deadline := time.Now().Add(timeout)
	for {
		_, lastErr := os.Stat(path)
		if lastErr == nil {
			var conn net.Conn
			conn, lastErr = net.DialTimeout("unix", path, dialTimeout)
			if lastErr == nil {
				_ = conn.Close() // probe connection only; nothing to report
				return nil
			}
		}
		if time.Now().After(deadline) {
			return fmt.Errorf("socket %s not ready: %w", path, lastErr)
		}
		time.Sleep(pollInterval)
	}
}
// NewUnixHTTPClient returns an *http.Client whose transport sends every
// request over the Unix domain socket at socketPath. The network and address
// derived from the request URL are ignored. Dialing honours the request
// context. No client-level timeout is configured, so callers bound requests
// through their context.
func NewUnixHTTPClient(socketPath string) *http.Client {
	var dialer net.Dialer
	transport := &http.Transport{
		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			return dialer.DialContext(ctx, "unix", socketPath)
		},
	}
	return &http.Client{Transport: transport}
}