diff --git a/internal/daemon/daemon.go b/internal/daemon/daemon.go index 84325ed..4cff28b 100644 --- a/internal/daemon/daemon.go +++ b/internal/daemon/daemon.go @@ -310,17 +310,34 @@ func (d *Daemon) watchRequestDisconnect(conn net.Conn, reader *bufio.Reader, met } func (d *Daemon) dispatch(ctx context.Context, req rpc.Request) rpc.Response { + // Per-RPC correlation id is generated unconditionally — even + // errors that short-circuit before reaching a handler get one + // so the operator has a handle for every CLI failure. + // Generation can fail in theory (crypto/rand IO error) — + // degrade gracefully to a blank id rather than tearing down + // the request. + opID, _ := model.NewOpID() + if opID != "" { + ctx = WithOpID(ctx, opID) + } + stampOpID := func(resp rpc.Response) rpc.Response { + if !resp.OK && resp.Error != nil && resp.Error.OpID == "" && opID != "" { + resp.Error.OpID = opID + } + return resp + } + if req.Version != rpc.Version { - return rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version)) + return stampOpID(rpc.NewError("bad_version", fmt.Sprintf("unsupported version %d", req.Version))) } if d.requestHandler != nil { - return d.requestHandler(ctx, req) + return stampOpID(d.requestHandler(ctx, req)) } h, ok := rpcHandlers[req.Method] if !ok { - return rpc.NewError("unknown_method", req.Method) + return stampOpID(rpc.NewError("unknown_method", req.Method)) } - return h(ctx, d, req) + return stampOpID(h(ctx, d, req)) } func (d *Daemon) backgroundLoop() { @@ -346,7 +363,7 @@ func (d *Daemon) backgroundLoop() { } func (d *Daemon) reconcile(ctx context.Context) error { - op := d.beginOperation("daemon.reconcile") + op := d.beginOperation(ctx, "daemon.reconcile") vms, err := d.store.ListVMs(ctx) if err != nil { return op.fail(err) @@ -441,14 +458,12 @@ func wireServices(d *Daemon) { } if d.img == nil { d.img = newImageService(imageServiceDeps{ - runner: d.runner, - logger: d.logger, - config: d.config, - layout: d.layout, - 
store: d.store, - beginOperation: func(name string, attrs ...any) *operationLog { - return d.beginOperation(name, attrs...) - }, + runner: d.runner, + logger: d.logger, + config: d.config, + layout: d.layout, + store: d.store, + beginOperation: d.beginOperation, }) } if d.ws == nil { diff --git a/internal/daemon/dispatch_test.go b/internal/daemon/dispatch_test.go index 18e7bd8..73ea418 100644 --- a/internal/daemon/dispatch_test.go +++ b/internal/daemon/dispatch_test.go @@ -1,8 +1,12 @@ package daemon import ( + "context" "sort" + "strings" "testing" + + "banger/internal/rpc" ) // TestRPCHandlersMatchDocumentedMethods pins the surface of the RPC @@ -82,3 +86,55 @@ func TestRPCHandlersAllNonNil(t *testing.T) { } } } + +// TestDispatchStampsOpIDOnError pins the contract that every error +// response leaving dispatch carries an op_id, even on the +// short-circuit paths (bad_version, unknown_method) that never +// reach a handler. Operators rely on this id to correlate a CLI +// failure to a daemon log line. 
+func TestDispatchStampsOpIDOnError(t *testing.T) { + d := &Daemon{} + t.Run("unknown_method", func(t *testing.T) { + resp := d.dispatch(context.Background(), rpc.Request{Version: rpc.Version, Method: "no.such.method"}) + if resp.OK { + t.Fatalf("expected error response, got %+v", resp) + } + if resp.Error == nil || resp.Error.Code != "unknown_method" { + t.Fatalf("error = %+v, want unknown_method", resp.Error) + } + if !strings.HasPrefix(resp.Error.OpID, "op-") { + t.Fatalf("op_id = %q, want op-* prefix", resp.Error.OpID) + } + }) + t.Run("bad_version", func(t *testing.T) { + resp := d.dispatch(context.Background(), rpc.Request{Version: rpc.Version + 99, Method: "ping"}) + if resp.OK { + t.Fatalf("expected error response, got %+v", resp) + } + if resp.Error == nil || resp.Error.Code != "bad_version" { + t.Fatalf("error = %+v, want bad_version", resp.Error) + } + if !strings.HasPrefix(resp.Error.OpID, "op-") { + t.Fatalf("op_id = %q, want op-* prefix", resp.Error.OpID) + } + }) +} + +// TestDispatchPropagatesOpIDFromContext covers the case where a +// handler returns its own rpc.NewError with an empty op_id (most +// service errors do); the dispatch wrapper must stamp the +// dispatch-generated id on the way out. 
+func TestDispatchPropagatesOpIDFromContext(t *testing.T) { + d := &Daemon{ + requestHandler: func(_ context.Context, _ rpc.Request) rpc.Response { + return rpc.NewError("operation_failed", "deliberate test failure") + }, + } + resp := d.dispatch(context.Background(), rpc.Request{Version: rpc.Version, Method: "anything"}) + if resp.OK || resp.Error == nil { + t.Fatalf("expected error response, got %+v", resp) + } + if !strings.HasPrefix(resp.Error.OpID, "op-") { + t.Fatalf("dispatch did not stamp op_id: %+v", resp.Error) + } +} diff --git a/internal/daemon/image_service.go b/internal/daemon/image_service.go index ea0be21..c87893b 100644 --- a/internal/daemon/image_service.go +++ b/internal/daemon/image_service.go @@ -47,7 +47,7 @@ type ImageService struct { // beginOperation is a test seam used by a couple of image ops that // want structured operation logging. Nil → Daemon's beginOperation, // injected at construction. - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog } // imageServiceDeps names every handle ImageService needs from the @@ -59,7 +59,7 @@ type imageServiceDeps struct { config model.DaemonConfig layout paths.Layout store *store.Store - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog } func newImageService(deps imageServiceDeps) *ImageService { diff --git a/internal/daemon/images.go b/internal/daemon/images.go index b9e1332..1b100c3 100644 --- a/internal/daemon/images.go +++ b/internal/daemon/images.go @@ -98,7 +98,7 @@ func (s *ImageService) RegisterImage(ctx context.Context, params api.ImageRegist // imageOpsMu — only the find/rename/upsert commit atom holds the // lock. 
func (s *ImageService) PromoteImage(ctx context.Context, idOrName string) (image model.Image, err error) { - op := s.beginOperation("image.promote") + op := s.beginOperation(ctx, "image.promote") defer func() { if err != nil { op.fail(err, imageLogAttrs(image)...) diff --git a/internal/daemon/logger.go b/internal/daemon/logger.go index 8771609..cdc5fb7 100644 --- a/internal/daemon/logger.go +++ b/internal/daemon/logger.go @@ -9,6 +9,7 @@ import ( "time" "banger/internal/model" + "banger/internal/rpc" ) func newDaemonLogger(w io.Writer, rawLevel string) (*slog.Logger, string, error) { @@ -35,9 +36,37 @@ func parseLogLevel(raw string) (slog.Level, string, error) { } } -func (d *Daemon) beginOperation(name string, attrs ...any) *operationLog { +// WithOpID stores the per-RPC correlation id on ctx. Re-exported +// from rpc so daemon-side call sites don't have to import rpc just +// for context plumbing. The dispatch layer calls this on every +// incoming request; capability hooks, lifecycle steps, and the +// privileged-ops shim that crosses into the root helper all read +// the id back via OpIDFromContext so a single id stitches the +// whole chain together in journalctl. +func WithOpID(ctx context.Context, opID string) context.Context { + return rpc.WithOpID(ctx, opID) +} + +// OpIDFromContext returns the dispatch-assigned op id stored on +// ctx, or "" if none was set. +func OpIDFromContext(ctx context.Context) string { + return rpc.OpIDFromContext(ctx) +} + +// beginOperation starts a logged operation. When ctx carries a +// dispatch-assigned op id (see WithOpID) every log line emitted +// through the returned operationLog includes it as an "op_id" attr, +// so the daemon journal can be greppable by id from the user's CLI +// error all the way down through capability hooks and the root +// helper. +func (d *Daemon) beginOperation(ctx context.Context, name string, attrs ...any) *operationLog { + opID := OpIDFromContext(ctx) + allAttrs := append([]any(nil), attrs...) 
+ if opID != "" { + allAttrs = append([]any{"op_id", opID}, allAttrs...) + } if d.logger != nil { - d.logger.Info("operation started", append([]any{"operation", name}, attrs...)...) + d.logger.Info("operation started", append([]any{"operation", name}, allAttrs...)...) } now := time.Now() return &operationLog{ @@ -45,7 +74,8 @@ func (d *Daemon) beginOperation(name string, attrs ...any) *operationLog { name: name, started: now, last: now, - attrs: append([]any(nil), attrs...), + attrs: allAttrs, + opID: opID, } } @@ -55,6 +85,16 @@ type operationLog struct { started time.Time last time.Time attrs []any + opID string +} + +// OpID exposes the correlation id this operation was started with so +// dispatch can stamp it onto an outgoing error response. +func (o *operationLog) OpID() string { + if o == nil { + return "" + } + return o.opID } func (o *operationLog) stage(stage string, attrs ...any) { diff --git a/internal/daemon/stats_service.go b/internal/daemon/stats_service.go index 6f5e25f..a15495b 100644 --- a/internal/daemon/stats_service.go +++ b/internal/daemon/stats_service.go @@ -39,7 +39,7 @@ type StatsService struct { config model.DaemonConfig store *store.Store net *HostNetwork - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog // vmAlive / vmHandles are the minimum pair needed to answer "is // this VM actually running right now?" + "what PID is it?". 
@@ -68,7 +68,7 @@ type statsServiceDeps struct { config model.DaemonConfig store *store.Store net *HostNetwork - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog vmAlive func(vm model.VMRecord) bool vmHandles func(vmID string) model.VMHandles withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) @@ -189,7 +189,7 @@ func (s *StatsService) stopStaleVMs(ctx context.Context) (err error) { if s.config.AutoStopStaleAfter <= 0 { return nil } - op := s.beginOperation("vm.stop_stale") + op := s.beginOperation(ctx, "vm.stop_stale") defer func() { if err != nil { op.fail(err) diff --git a/internal/daemon/vm_create.go b/internal/daemon/vm_create.go index 8946228..1fd8277 100644 --- a/internal/daemon/vm_create.go +++ b/internal/daemon/vm_create.go @@ -28,7 +28,7 @@ import ( // 3. Boot. Only the per-VM lock is held — parallel creates against // different VMs fully overlap. func (s *VMService) CreateVM(ctx context.Context, params api.VMCreateParams) (vm model.VMRecord, err error) { - op := s.beginOperation("vm.create") + op := s.beginOperation(ctx, "vm.create") defer func() { if err != nil { op.fail(err) diff --git a/internal/daemon/vm_create_ops.go b/internal/daemon/vm_create_ops.go index 53d1e98..0c8afe5 100644 --- a/internal/daemon/vm_create_ops.go +++ b/internal/daemon/vm_create_ops.go @@ -24,10 +24,21 @@ type vmCreateOperationState struct { op api.VMCreateOperation } -func newVMCreateOperationState() (*vmCreateOperationState, error) { - id, err := model.NewID() - if err != nil { - return nil, err +// newVMCreateOperationState constructs the async-progress record for +// a vm.create.begin RPC. When the caller's context already carries a +// dispatch-assigned op id (the normal path), we reuse it so the +// operator-visible status id and the daemon-log op_id are the same +// string. 
Otherwise we mint a fresh op id — keeps the same shape on +// internal call sites that don't go through dispatch (tests, future +// background creators). +func newVMCreateOperationState(ctx context.Context) (*vmCreateOperationState, error) { + id := OpIDFromContext(ctx) + if id == "" { + var err error + id, err = model.NewOpID() + if err != nil { + return nil, err + } } now := model.Now() return &vmCreateOperationState{ @@ -146,12 +157,16 @@ func (op *vmCreateOperationState) cancelOperation() { } } -func (s *VMService) BeginVMCreate(_ context.Context, params api.VMCreateParams) (api.VMCreateOperation, error) { - op, err := newVMCreateOperationState() +func (s *VMService) BeginVMCreate(ctx context.Context, params api.VMCreateParams) (api.VMCreateOperation, error) { + op, err := newVMCreateOperationState(ctx) if err != nil { return api.VMCreateOperation{}, err } - createCtx, cancel := context.WithCancel(context.Background()) + // Detach from the caller's deadline (the begin RPC returns + // immediately) but preserve the op id so every log line emitted + // by the goroutine carries the same identifier the client just + // got back. + createCtx, cancel := context.WithCancel(WithOpID(context.Background(), op.op.ID)) op.setCancel(cancel) s.createOps.Insert(op) go s.runVMCreateOperation(withVMCreateProgress(createCtx, op), op, params) diff --git a/internal/daemon/vm_lifecycle.go b/internal/daemon/vm_lifecycle.go index cb4f3b0..17e83e8 100644 --- a/internal/daemon/vm_lifecycle.go +++ b/internal/daemon/vm_lifecycle.go @@ -30,7 +30,7 @@ func (s *VMService) StartVM(ctx context.Context, idOrName string) (model.VMRecor } func (s *VMService) startVMLocked(ctx context.Context, vm model.VMRecord, image model.Image) (_ model.VMRecord, err error) { - op := s.beginOperation("vm.start", append(vmLogAttrs(vm), imageLogAttrs(image)...)...) + op := s.beginOperation(ctx, "vm.start", append(vmLogAttrs(vm), imageLogAttrs(image)...)...) 
defer func() { if err != nil { err = annotateLogPath(err, vm.Runtime.LogPath) @@ -97,7 +97,7 @@ func (s *VMService) StopVM(ctx context.Context, idOrName string) (model.VMRecord func (s *VMService) stopVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) { vm = current - op := s.beginOperation("vm.stop", "vm_ref", vm.ID) + op := s.beginOperation(ctx, "vm.stop", "vm_ref", vm.ID) defer func() { if err != nil { op.fail(err, vmLogAttrs(vm)...) @@ -154,7 +154,7 @@ func (s *VMService) KillVM(ctx context.Context, params api.VMKillParams) (model. func (s *VMService) killVMLocked(ctx context.Context, current model.VMRecord, signalValue string) (vm model.VMRecord, err error) { vm = current - op := s.beginOperation("vm.kill", "vm_ref", vm.ID, "signal", signalValue) + op := s.beginOperation(ctx, "vm.kill", "vm_ref", vm.ID, "signal", signalValue) defer func() { if err != nil { op.fail(err, vmLogAttrs(vm)...) @@ -209,7 +209,7 @@ func (s *VMService) killVMLocked(ctx context.Context, current model.VMRecord, si } func (s *VMService) RestartVM(ctx context.Context, idOrName string) (vm model.VMRecord, err error) { - op := s.beginOperation("vm.restart", "vm_ref", idOrName) + op := s.beginOperation(ctx, "vm.restart", "vm_ref", idOrName) defer func() { if err != nil { op.fail(err, vmLogAttrs(vm)...) @@ -244,7 +244,7 @@ func (s *VMService) DeleteVM(ctx context.Context, idOrName string) (model.VMReco func (s *VMService) deleteVMLocked(ctx context.Context, current model.VMRecord) (vm model.VMRecord, err error) { vm = current - op := s.beginOperation("vm.delete", "vm_ref", vm.ID) + op := s.beginOperation(ctx, "vm.delete", "vm_ref", vm.ID) defer func() { if err != nil { op.fail(err, vmLogAttrs(vm)...) diff --git a/internal/daemon/vm_service.go b/internal/daemon/vm_service.go index d8db6a4..86908a6 100644 --- a/internal/daemon/vm_service.go +++ b/internal/daemon/vm_service.go @@ -76,7 +76,7 @@ type VMService struct { // VMService never reaches back to *Daemon. 
capHooks capabilityHooks - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog } // capabilityHooks bundles the capability-dispatch entry points that @@ -104,7 +104,7 @@ type vmServiceDeps struct { ws *WorkspaceService priv privilegedOps capHooks capabilityHooks - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog vsockHostDevice string } diff --git a/internal/daemon/vm_set.go b/internal/daemon/vm_set.go index fdbb864..0acf4c4 100644 --- a/internal/daemon/vm_set.go +++ b/internal/daemon/vm_set.go @@ -17,7 +17,7 @@ func (s *VMService) SetVM(ctx context.Context, params api.VMSetParams) (model.VM func (s *VMService) setVMLocked(ctx context.Context, current model.VMRecord, params api.VMSetParams) (vm model.VMRecord, err error) { vm = current - op := s.beginOperation("vm.set", "vm_ref", vm.ID) + op := s.beginOperation(ctx, "vm.set", "vm_ref", vm.ID) defer func() { if err != nil { op.fail(err, vmLogAttrs(vm)...) 
diff --git a/internal/daemon/workspace_service.go b/internal/daemon/workspace_service.go index 386b38b..864c293 100644 --- a/internal/daemon/workspace_service.go +++ b/internal/daemon/workspace_service.go @@ -43,7 +43,7 @@ type WorkspaceService struct { imageWorkSeed func(ctx context.Context, image model.Image, fingerprint string) error withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog // repoInspector is the Inspector used by the real InspectRepo / // ImportRepoToGuest fallbacks when the test seams below aren't @@ -71,7 +71,7 @@ type workspaceServiceDeps struct { imageResolver func(ctx context.Context, idOrName string) (model.Image, error) imageWorkSeed func(ctx context.Context, image model.Image, fingerprint string) error withVMLockByRef func(ctx context.Context, idOrName string, fn func(model.VMRecord) (model.VMRecord, error)) (model.VMRecord, error) - beginOperation func(name string, attrs ...any) *operationLog + beginOperation func(ctx context.Context, name string, attrs ...any) *operationLog } func newWorkspaceService(deps workspaceServiceDeps) *WorkspaceService { diff --git a/internal/model/types.go b/internal/model/types.go index c37b71a..d3a44fc 100644 --- a/internal/model/types.go +++ b/internal/model/types.go @@ -200,6 +200,21 @@ func NewID() (string, error) { return hex.EncodeToString(buf), nil } +// NewOpID returns a short identifier for tracing a single RPC +// operation across the daemon, the root helper, and the user-visible +// CLI error string. Format: "op-" + 12 hex chars (48 bits of entropy +// — collisions inside one daemon session are vanishingly unlikely +// and don't matter beyond it). Short enough to copy-paste from a +// CLI error into a journalctl --grep, long enough to actually +// disambiguate. 
+func NewOpID() (string, error) { + buf := make([]byte, 6) + if _, err := rand.Read(buf); err != nil { + return "", err + } + return "op-" + hex.EncodeToString(buf), nil +} + func ParseSize(raw string) (int64, error) { if raw == "" { return 0, errors.New("size is required") diff --git a/internal/roothelper/roothelper.go b/internal/roothelper/roothelper.go index 09bf4bd..ec3626f 100644 --- a/internal/roothelper/roothelper.go +++ b/internal/roothelper/roothelper.go @@ -285,7 +285,11 @@ func Open() (*Server, error) { return &Server{ meta: meta, runner: system.NewRunner(), - logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo})), + // JSON to match bangerd. Mixed text/JSON streams in the + // merged journalctl made the daemon side painful to grep; + // this aligns the helper so a single greppable shape spans + // both units. + logger: slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelInfo})), }, nil } @@ -352,7 +356,29 @@ func (s *Server) handleConn(conn net.Conn) { _ = json.NewEncoder(conn).Encode(rpc.NewError("bad_request", err.Error())) return } - resp := s.dispatch(context.Background(), req) + // Adopt the daemon's op id so a single greppable id covers the + // whole call chain (CLI → daemon → helper). Entry log at debug + // level keeps production quiet; the completion log fires at + // info-on-success / error-on-failure with duration so an + // operator can see at a glance how long each privileged op + // took. 
+ ctx := rpc.WithOpID(context.Background(), req.OpID) + start := time.Now() + if s.logger != nil { + s.logger.Debug("helper rpc", "method", req.Method, "op_id", req.OpID) + } + resp := s.dispatch(ctx, req) + if !resp.OK && resp.Error != nil && resp.Error.OpID == "" && req.OpID != "" { + resp.Error.OpID = req.OpID + } + if s.logger != nil { + duration := time.Since(start).Milliseconds() + if !resp.OK && resp.Error != nil { + s.logger.Error("helper rpc failed", "method", req.Method, "op_id", req.OpID, "duration_ms", duration, "code", resp.Error.Code, "message", resp.Error.Message) + } else { + s.logger.Info("helper rpc completed", "method", req.Method, "op_id", req.OpID, "duration_ms", duration) + } + } _ = json.NewEncoder(conn).Encode(resp) } diff --git a/internal/rpc/rpc.go b/internal/rpc/rpc.go index 3abfb59..00e1ec3 100644 --- a/internal/rpc/rpc.go +++ b/internal/rpc/rpc.go @@ -18,6 +18,40 @@ type Request struct { Version int `json:"version"` Method string `json:"method"` Params json.RawMessage `json:"params,omitempty"` + // OpID is the per-RPC correlation id. Optional on the wire so + // older clients (which don't set it) and older servers (which + // don't read it) keep interoperating. The daemon attaches it on + // every incoming request via dispatch; rpc.Call forwards + // whatever id is on ctx so a helper RPC carries the same id as + // the daemon RPC that triggered it. + OpID string `json:"op_id,omitempty"` +} + +// opIDKey is the context-value key for the per-RPC correlation id +// that flows from CLI → daemon → root helper. Lives in the rpc +// package because rpc.Call needs to read it without depending on +// the daemon package; daemon and roothelper both import it. +type opIDKey struct{} + +// WithOpID stores opID on ctx. Used by the daemon dispatch layer to +// inject the per-request id; rpc.Call picks it up automatically. 
+func WithOpID(ctx context.Context, opID string) context.Context { + if ctx == nil || opID == "" { + return ctx + } + return context.WithValue(ctx, opIDKey{}, opID) +} + +// OpIDFromContext returns the op id stored on ctx by WithOpID, or +// "" if none was set. +func OpIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + if id, _ := ctx.Value(opIDKey{}).(string); id != "" { + return id + } + return "" } type Response struct { @@ -29,6 +63,29 @@ type Response struct { type ErrorResponse struct { Code string `json:"code"` Message string `json:"message"` + // OpID is the daemon-assigned correlation id for the RPC that + // produced this error. Optional and may be empty (older daemons + // don't set it); when present the CLI surfaces it so an operator + // can grep journalctl by that id and find the full context. + OpID string `json:"op_id,omitempty"` +} + +// Error makes ErrorResponse satisfy the error interface so callers +// can errors.As it out of an rpc.Call return value and read the +// structured fields directly. The default string form is +// "code: message (op-id)" — the op id only appears when the daemon +// attached one. CLI code paths that want a translated, user-facing +// message render the typed fields themselves; this fallback is for +// log lines, fmt.Errorf %w wrappers, and any caller that hasn't +// bothered to errors.As yet. +func (e *ErrorResponse) Error() string { + if e == nil { + return "" + } + if e.OpID == "" { + return e.Code + ": " + e.Message + } + return e.Code + ": " + e.Message + " (" + e.OpID + ")" } func NewResult(v any) (Response, error) { @@ -43,6 +100,12 @@ func NewError(code, message string) Response { return Response{OK: false, Error: &ErrorResponse{Code: code, Message: message}} } +// NewErrorWithOpID is the variant for daemon dispatch sites that have +// resolved an op id by the time they encode the response. 
+func NewErrorWithOpID(code, message, opID string) Response { + return Response{OK: false, Error: &ErrorResponse{Code: code, Message: message, OpID: opID}} +} + func DecodeParams[T any](req Request) (T, error) { var zero T if len(req.Params) == 0 { @@ -78,7 +141,7 @@ func Call[T any](ctx context.Context, socketPath, method string, params any) (T, _ = conn.SetDeadline(deadline) } - request := Request{Version: Version, Method: method} + request := Request{Version: Version, Method: method, OpID: OpIDFromContext(ctx)} if params != nil { raw, err := json.Marshal(params) if err != nil { @@ -105,7 +168,10 @@ func Call[T any](ctx context.Context, socketPath, method string, params any) (T, if response.Error == nil { return zero, errors.New("rpc error") } - return zero, fmt.Errorf("%s: %s", response.Error.Code, response.Error.Message) + // Return the typed error directly so callers that need code + // or op_id can errors.As it out. err.Error() still begins with + // "code: message"; " (op-id)" is appended when an op id is set. 
+ return zero, response.Error } if len(response.Result) == 0 { return zero, nil diff --git a/internal/rpc/rpc_test.go b/internal/rpc/rpc_test.go index c59a8e9..10e64c2 100644 --- a/internal/rpc/rpc_test.go +++ b/internal/rpc/rpc_test.go @@ -92,6 +92,62 @@ func TestCallReturnsRemoteError(t *testing.T) { } } +func TestCallExposesTypedErrorWithOpID(t *testing.T) { + t.Parallel() + + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + var req Request + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&req); err != nil { + t.Fatalf("decode request: %v", err) + } + if err := json.NewEncoder(conn).Encode(NewErrorWithOpID("not_found", "vm \"foo\" not found", "op-deadbeef00ff")); err != nil { + t.Fatalf("encode error response: %v", err) + } + }) + defer cleanup() + + _, err := Call[map[string]string](context.Background(), socketPath, "vm.show", nil) + if err == nil { + t.Fatal("Call() returned nil error") + } + var rpcErr *ErrorResponse + if !errors.As(err, &rpcErr) { + t.Fatalf("Call() error %T (%v) is not *ErrorResponse — CLI cannot read the op_id", err, err) + } + if rpcErr.Code != "not_found" || rpcErr.OpID != "op-deadbeef00ff" { + t.Fatalf("typed error = %+v, want code=not_found op-deadbeef00ff", rpcErr) + } + // String form keeps the op_id in parens so callers that only + // log err.Error() still surface the id. 
+ if got := rpcErr.Error(); !strings.Contains(got, "(op-deadbeef00ff)") { + t.Fatalf("err.Error() = %q, want op-id suffix", got) + } +} + +func TestCallForwardsOpIDFromContext(t *testing.T) { + t.Parallel() + + var seenReq Request + socketPath, cleanup := serveRPCOnce(t, func(conn net.Conn) { + defer conn.Close() + if err := json.NewDecoder(bufio.NewReader(conn)).Decode(&seenReq); err != nil { + t.Fatalf("decode request: %v", err) + } + resp, _ := NewResult(map[string]string{"status": "ok"}) + _ = json.NewEncoder(conn).Encode(resp) + }) + defer cleanup() + + ctx := WithOpID(context.Background(), "op-cafef00d1234") + if _, err := Call[map[string]string](ctx, socketPath, "ping", nil); err != nil { + t.Fatalf("Call: %v", err) + } + if seenReq.OpID != "op-cafef00d1234" { + t.Fatalf("server saw op_id = %q, want op-cafef00d1234", seenReq.OpID) + } +} + func TestCallRejectsMalformedResponse(t *testing.T) { t.Parallel()