From 819d2eeac4442c15bfd23136c4b473a0063cf307 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 12:28:16 -0800 Subject: [PATCH 01/16] refactor: response handling to use Result[T] type for type safety The Result[T] type replaces the non-generic, unexported response type. The reason to export the type is to make it easier to implement interceptors that need to use this type. Note that Result[T] is currently only used as Result[proto.Message] instances (in the current channel and as called from the existing calltypes (like quorumcall, multicast etc). In a future commit Result[T] will also be instansiated as Result[Resp], or Result[any], which can be used at the interceptor level. --- async.go | 12 ++++---- channel.go | 73 ++++++++++++++++++++++++------------------------- channel_test.go | 53 +++++++++++++++++------------------ correctable.go | 16 +++++------ multicast.go | 7 +++-- quorumcall.go | 10 +++---- rpc.go | 6 ++-- unicast.go | 8 +++--- 8 files changed, 92 insertions(+), 93 deletions(-) diff --git a/async.go b/async.go index 0d16c982..0f454dcc 100644 --- a/async.go +++ b/async.go @@ -37,7 +37,7 @@ func (f *Async) Done() bool { type asyncCallState struct { md *ordering.Metadata data QuorumCallData - replyChan <-chan response + replyChan <-chan Result[proto.Message] expectedReplies int } @@ -47,7 +47,7 @@ type asyncCallState struct { func (c RawConfiguration) AsyncCall(ctx context.Context, d QuorumCallData) *Async { expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan response, expectedReplies) + replyChan := make(chan Result[proto.Message], expectedReplies) for _, n := range c { msg := d.Message @@ -58,7 +58,7 @@ func (c RawConfiguration) AsyncCall(ctx context.Context, d QuorumCallData) *Asyn continue // don't send if no msg } } - n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg)}, replyChan) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), 
responseChan: replyChan}) } fut := &Async{c: make(chan struct{}, 1)} @@ -86,11 +86,11 @@ func (RawConfiguration) handleAsyncCall(ctx context.Context, fut *Async, state a for { select { case r := <-state.replyChan: - if r.err != nil { - errs = append(errs, nodeError{nodeID: r.nid, cause: r.err}) + if r.Err != nil { + errs = append(errs, nodeError{nodeID: r.NodeID, cause: r.Err}) break } - replies[r.nid] = r.msg + replies[r.NodeID] = r.Value if resp, quorum = state.data.QuorumFunction(state.data.Message, replies); quorum { fut.reply, fut.err = resp, nil return diff --git a/channel.go b/channel.go index 6c6b54a3..9f63a497 100644 --- a/channel.go +++ b/channel.go @@ -17,21 +17,18 @@ import ( var streamDownErr = status.Error(codes.Unavailable, "stream is down") type request struct { - ctx context.Context - msg *Message - opts callOptions - streaming bool + ctx context.Context + msg *Message + opts callOptions + streaming bool + responseChan chan<- Result[proto.Message] } -type response struct { - nid uint32 - msg proto.Message - err error -} - -type responseRouter struct { - c chan<- response - streaming bool +// Result wraps a response value with its associated error and node ID. +type Result[T any] struct { + NodeID uint32 + Value T + Err error } type channel struct { @@ -50,8 +47,10 @@ type channel struct { streamCtx context.Context cancelStream context.CancelFunc - // Response routing - responseRouters map[uint64]responseRouter + // Response routing; the map holds pending requests waiting for responses. + // The request contains the responseChan on which to send the response + // to the caller. 
+ responseRouters map[uint64]request responseMut sync.Mutex // Lifecycle management: node close() cancels this context @@ -70,7 +69,7 @@ func newChannel(n *RawNode) *channel { sendQ: make(chan request, n.mgr.opts.sendBuffer), node: n, latency: -1 * time.Second, - responseRouters: make(map[uint64]responseRouter), + responseRouters: make(map[uint64]request), } // parentCtx controls the channel and is used to shut it down c.parentCtx = n.newContext() @@ -119,15 +118,13 @@ func (c *channel) isConnected() bool { return c.node.conn.GetState() == connectivity.Ready && c.getStream() != nil } -// enqueue adds the request to the send queue and sets up a response router if needed. +// enqueue adds the request to the send queue and sets up response routing if needed. // If the node is closed, it responds with an error instead. -func (c *channel) enqueue(req request, responseChan chan<- response) { +func (c *channel) enqueue(req request) { msgID := req.msg.GetMessageID() - if responseChan != nil { - // allocate before critical section - router := responseRouter{responseChan, req.streaming} + if req.responseChan != nil { c.responseMut.Lock() - c.responseRouters[msgID] = router + c.responseRouters[msgID] = req c.responseMut.Unlock() } // either enqueue the request on the sendQ or respond @@ -135,36 +132,36 @@ func (c *channel) enqueue(req request, responseChan chan<- response) { select { case <-c.parentCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(msgID, response{nid: c.node.ID(), err: fmt.Errorf("node closed")}) + c.routeResponse(msgID, Result[proto.Message]{NodeID: c.node.ID(), Err: fmt.Errorf("node closed")}) return case c.sendQ <- req: // enqueued successfully } } -// routeResponse routes the response to the appropriate response router based on msgID. -// If no router is found, the response is discarded. 
-func (c *channel) routeResponse(msgID uint64, resp response) { +// routeResponse routes the response to the appropriate response channel based on msgID. +// If no matching request is found, the response is discarded. +func (c *channel) routeResponse(msgID uint64, resp Result[proto.Message]) { c.responseMut.Lock() defer c.responseMut.Unlock() - if router, ok := c.responseRouters[msgID]; ok { - router.c <- resp + if req, ok := c.responseRouters[msgID]; ok { + req.responseChan <- resp // delete the router if we are only expecting a single reply message - if !router.streaming { + if !req.streaming { delete(c.responseRouters, msgID) } } } -// cancelPendingMsgs cancels all pending messages by sending an error response to each router. -// This is typically called when the stream goes down to notify all waiting calls. +// cancelPendingMsgs cancels all pending messages by sending an error response to each +// associated request. This is called when the stream goes down to notify all waiting calls. 
func (c *channel) cancelPendingMsgs() { c.responseMut.Lock() defer c.responseMut.Unlock() - for msgID, router := range c.responseRouters { - router.c <- response{nid: c.node.ID(), err: streamDownErr} + for msgID, req := range c.responseRouters { + req.responseChan <- Result[proto.Message]{NodeID: c.node.ID(), Err: streamDownErr} // delete the router if we are only expecting a single reply message - if !router.streaming { + if !req.streaming { delete(c.responseRouters, msgID) } } @@ -191,11 +188,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageID(), response{nid: c.node.ID(), err: err}) + c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageID(), response{nid: c.node.ID(), err: err}) + c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Err: err}) } } } @@ -220,7 +217,7 @@ func (c *channel) receiver() { c.clearStream() } else { err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), response{nid: c.node.ID(), msg: resp.GetProtoMessage(), err: err}) + c.routeResponse(resp.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Value: resp.GetProtoMessage(), Err: err}) } select { @@ -250,7 +247,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so mustWaitSendDone() returns false for them. 
if req.opts.mustWaitSendDone() && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageID(), response{}) + c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{}) } }() diff --git a/channel_test.go b/channel_test.go index 6ecf2a62..fb1ec2fe 100644 --- a/channel_test.go +++ b/channel_test.go @@ -55,27 +55,28 @@ func newNodeWithStoppableServer(t testing.TB, delay time.Duration) (*RawNode, fu return NewNode(t, addrs[0]), teardown } -func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) response { +func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) Result[proto.Message] { t.Helper() if req.ctx == nil { req.ctx = t.Context() } req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.TestMethod), nil) - replyChan := make(chan response, 1) - node.channel.enqueue(req, replyChan) + replyChan := make(chan Result[proto.Message], 1) + req.responseChan = replyChan + node.channel.enqueue(req) select { case resp := <-replyChan: return resp case <-time.After(defaultTestTimeout): t.Fatalf("timeout waiting for response to message %d", msgID) - return response{} + return Result[proto.Message]{} } } type msgResponse struct { msgID uint64 - resp response + resp Result[proto.Message] } func send(t testing.TB, results chan<- msgResponse, node *RawNode, goroutineID, msgsToSend int, req request) { @@ -104,7 +105,7 @@ func TestChannelCreation(t *testing.T) { // send message when server is down resp := sendRequest(t, node, request{opts: waitSendDone}, 1) - if resp.err == nil { + if resp.Err == nil { t.Error("response err: got , want error") } } @@ -122,8 +123,8 @@ func TestChannelShutdown(t *testing.T) { for i := range numMessages { wg.Go(func() { resp := sendRequest(t, node, request{}, uint64(i)) - if resp.err != nil { - t.Errorf("unexpected error for message %d, got error: %v", i, resp.err) + if resp.Err != nil { + t.Errorf("unexpected error for message %d, 
got error: %v", i, resp.Err) } }) } @@ -136,10 +137,10 @@ func TestChannelShutdown(t *testing.T) { // try to send a message after node closure resp := sendRequest(t, node, request{}, 999) - if resp.err == nil { + if resp.Err == nil { t.Error("expected error when sending to closed channel") - } else if resp.err.Error() != "node closed" { - t.Errorf("expected 'node closed' error, got: %v", resp.err) + } else if resp.Err.Error() != "node closed" { + t.Errorf("expected 'node closed' error, got: %v", resp.Err) } if node.channel.isConnected() { @@ -163,8 +164,8 @@ func TestChannelSendCompletionWaiting(t *testing.T) { start := time.Now() resp := sendRequest(t, node, request{opts: tt.opts}, uint64(i)) elapsed := time.Since(start) - if resp.err != nil { - t.Errorf("unexpected error: %v", resp.err) + if resp.Err != nil { + t.Errorf("unexpected error: %v", resp.Err) } t.Logf("response received in %v", elapsed) }) @@ -214,8 +215,8 @@ func TestChannelErrors(t *testing.T) { setup: func(t *testing.T) *RawNode { node, stopServer := newNodeWithStoppableServer(t, 0) resp := sendRequest(t, node, request{opts: waitSendDone}, 1) - if resp.err != nil { - t.Errorf("first message should succeed, got error: %v", resp.err) + if resp.Err != nil { + t.Errorf("first message should succeed, got error: %v", resp.Err) } stopServer() return node @@ -230,10 +231,10 @@ func TestChannelErrors(t *testing.T) { // Send message and verify error resp := sendRequest(t, node, request{opts: waitSendDone}, uint64(i)) - if resp.err == nil { + if resp.Err == nil { t.Errorf("expected error '%s' but got nil", tt.wantErr) - } else if !strings.Contains(resp.err.Error(), tt.wantErr) { - t.Errorf("expected error '%s', got: %v", tt.wantErr, resp.err) + } else if !strings.Contains(resp.Err.Error(), tt.wantErr) { + t.Errorf("expected error '%s', got: %v", tt.wantErr, resp.Err) } }) } @@ -401,8 +402,8 @@ func TestChannelConcurrentSends(t *testing.T) { var errs []error for range numMessages { res := <-results - if 
res.resp.err != nil { - errs = append(errs, res.resp.err) + if res.resp.Err != nil { + errs = append(errs, res.resp.Err) } } @@ -477,8 +478,8 @@ func TestChannelContext(t *testing.T) { node := newNodeWithServer(t, tt.serverDelay) resp := sendRequest(t, node, request{ctx: ctx, opts: tt.callOpts}, uint64(i)) - if !errors.Is(resp.err, tt.wantErr) { - t.Errorf("expected %v, got: %v", tt.wantErr, resp.err) + if !errors.Is(resp.Err, tt.wantErr) { + t.Errorf("expected %v, got: %v", tt.wantErr, resp.Err) } }) } @@ -578,8 +579,8 @@ func TestChannelRouterLifecycle(t *testing.T) { t.Run(name, func(t *testing.T) { msgID := uint64(i) resp := sendRequest(t, node, request{opts: tt.opts, streaming: tt.streaming}, msgID) - if resp.err != nil { - t.Errorf("unexpected error: %v", resp.err) + if resp.Err != nil { + t.Errorf("unexpected error: %v", resp.Err) } if exists := routerExists(node, msgID); exists != tt.wantRouter { t.Errorf("router exists = %v, want %v", exists, tt.wantRouter) @@ -605,8 +606,8 @@ func TestChannelResponseRouting(t *testing.T) { received := make(map[uint64]bool) for range numMessages { result := <-results - if result.resp.err != nil { - t.Errorf("message %d got error: %v", result.msgID, result.resp.err) + if result.resp.Err != nil { + t.Errorf("message %d got error: %v", result.msgID, result.resp.Err) } if received[result.msgID] { t.Errorf("message %d received twice", result.msgID) diff --git a/correctable.go b/correctable.go index edbecc11..9ecdb2f2 100644 --- a/correctable.go +++ b/correctable.go @@ -93,7 +93,7 @@ type CorrectableCallData struct { type correctableCallState struct { md *ordering.Metadata data CorrectableCallData - replyChan <-chan response + replyChan <-chan Result[proto.Message] expectedReplies int } @@ -104,7 +104,7 @@ func (c RawConfiguration) CorrectableCall(ctx context.Context, d CorrectableCall expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan response, expectedReplies) + 
replyChan := make(chan Result[proto.Message], expectedReplies) for _, n := range c { msg := d.Message if d.PerNodeArgFn != nil { @@ -114,7 +114,7 @@ func (c RawConfiguration) CorrectableCall(ctx context.Context, d CorrectableCall continue // don't send if no msg } } - n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), streaming: d.ServerStream}, replyChan) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), streaming: d.ServerStream, responseChan: replyChan}) } corr := &Correctable{donech: make(chan struct{}, 1)} @@ -148,19 +148,19 @@ func (c RawConfiguration) handleCorrectableCall(ctx context.Context, corr *Corre for { select { case r := <-state.replyChan: - if r.err != nil { - errs = append(errs, nodeError{nodeID: r.nid, cause: r.err}) + if r.Err != nil { + errs = append(errs, nodeError{nodeID: r.NodeID, cause: r.Err}) break } - replies[r.nid] = r.msg + replies[r.NodeID] = r.Value if resp, rlevel, quorum = state.data.QuorumFunction(state.data.Message, replies); quorum { if quorum { - corr.set(r.msg, rlevel, nil, true) + corr.set(r.Value, rlevel, nil, true) return } if rlevel > clevel { clevel = rlevel - corr.set(r.msg, rlevel, nil, false) + corr.set(r.Value, rlevel, nil, false) } } case <-ctx.Done(): diff --git a/multicast.go b/multicast.go index d563f532..c2856aa8 100644 --- a/multicast.go +++ b/multicast.go @@ -4,6 +4,7 @@ import ( "context" "github.com/relab/gorums/ordering" + "google.golang.org/protobuf/proto" ) // Multicast is a one-way call; no replies are returned to the client. 
@@ -21,9 +22,9 @@ func (c RawConfiguration) Multicast(ctx context.Context, d QuorumCallData, opts md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) sentMsgs := 0 - var replyChan chan response + var replyChan chan Result[proto.Message] if o.waitSendDone { - replyChan = make(chan response, len(c)) + replyChan = make(chan Result[proto.Message], len(c)) } for _, n := range c { msg := d.Message @@ -33,7 +34,7 @@ func (c RawConfiguration) Multicast(ctx context.Context, d QuorumCallData, opts continue // don't send if no msg } } - n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), opts: o}, replyChan) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), opts: o, responseChan: replyChan}) sentMsgs++ } diff --git a/quorumcall.go b/quorumcall.go index 58a815fd..08e63d23 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -26,7 +26,7 @@ func (c RawConfiguration) QuorumCall(ctx context.Context, d QuorumCallData) (res expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan response, expectedReplies) + replyChan := make(chan Result[proto.Message], expectedReplies) for _, n := range c { msg := d.Message if d.PerNodeArgFn != nil { @@ -36,7 +36,7 @@ func (c RawConfiguration) QuorumCall(ctx context.Context, d QuorumCallData) (res continue // don't send if no msg } } - n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg)}, replyChan) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan}) } var ( @@ -48,11 +48,11 @@ func (c RawConfiguration) QuorumCall(ctx context.Context, d QuorumCallData) (res for { select { case r := <-replyChan: - if r.err != nil { - errs = append(errs, nodeError{nodeID: r.nid, cause: r.err}) + if r.Err != nil { + errs = append(errs, nodeError{nodeID: r.NodeID, cause: r.Err}) break } - replies[r.nid] = r.msg + replies[r.NodeID] = r.Value if resp, quorum = d.QuorumFunction(d.Message, replies); quorum { 
return resp, nil } diff --git a/rpc.go b/rpc.go index 446370fa..d5031cea 100644 --- a/rpc.go +++ b/rpc.go @@ -20,12 +20,12 @@ type CallData struct { // This method should be used by generated code only. func (n *RawNode) RPCCall(ctx context.Context, d CallData) (proto.Message, error) { md := ordering.NewGorumsMetadata(ctx, n.mgr.getMsgID(), d.Method) - replyChan := make(chan response, 1) - n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message)}, replyChan) + replyChan := make(chan Result[proto.Message], 1) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message), responseChan: replyChan}) select { case r := <-replyChan: - return r.msg, r.err + return r.Value, r.Err case <-ctx.Done(): return nil, ctx.Err() } diff --git a/unicast.go b/unicast.go index b84aa6d2..ad7ffb59 100644 --- a/unicast.go +++ b/unicast.go @@ -4,6 +4,7 @@ import ( "context" "github.com/relab/gorums/ordering" + "google.golang.org/protobuf/proto" ) // Unicast is a one-way call; no replies are returned to the client. 
@@ -20,15 +21,14 @@ func (n *RawNode) Unicast(ctx context.Context, d CallData, opts ...CallOption) { o := getCallOptions(E_Unicast, opts) md := ordering.NewGorumsMetadata(ctx, n.mgr.getMsgID(), d.Method) - req := request{ctx: ctx, msg: NewRequestMessage(md, d.Message), opts: o} if !o.waitSendDone { - n.channel.enqueue(req, nil) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message), opts: o}) return // fire-and-forget: don't wait for send completion } // Default: block until send completes - replyChan := make(chan response, 1) - n.channel.enqueue(req, replyChan) + replyChan := make(chan Result[proto.Message], 1) + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message), opts: o, responseChan: replyChan}) <-replyChan } From 2322fffd52abba689f1686cdb5c959b3750a1d4a Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 12:38:07 -0800 Subject: [PATCH 02/16] feat: add echoSrv and update testSrv to support GetValue This updates the testSrv implmentation to support the GetValue method. It also adds the echoSrv implmentation. These will be used for testing client-side interceptors. 
--- testing_gorums.go | 38 +++++++++++++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 5 deletions(-) diff --git a/testing_gorums.go b/testing_gorums.go index a28251fe..0743a12a 100644 --- a/testing_gorums.go +++ b/testing_gorums.go @@ -91,7 +91,7 @@ func TestSetup(t testing.TB, numServers int, srvFn func(i int) ServerIface) ([]s if srvFn != nil { srv = srvFn(i) } else { - srv = initServer() + srv = initServer(i) } // listen on any available port lis, err := net.Listen("tcp", "127.0.0.1:0") @@ -122,17 +122,45 @@ func TestSetup(t testing.TB, numServers int, srvFn func(i int) ServerIface) ([]s return addrs, stopFn } -func initServer() *Server { +func initServer(i int) *Server { srv := NewServer() + ts := testSrv{val: int32((i + 1) * 10)} srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { - resp, err := (&testSrv{}).Test(ctx, in.GetProtoMessage()) + resp, err := ts.Test(ctx, in.GetProtoMessage()) + return NewResponseMessage(in.GetMetadata(), resp), err + }) + srv.RegisterHandler(mock.GetValueMethod, func(ctx ServerCtx, in *Message) (*Message, error) { + resp, err := ts.GetValue(ctx, in.GetProtoMessage()) return NewResponseMessage(in.GetMetadata(), resp), err }) return srv } -type testSrv struct{} +type testSrv struct { + val int32 +} -func (testSrv) Test(_ ServerCtx, _ proto.Message) (proto.Message, error) { +func (_ testSrv) Test(_ ServerCtx, _ proto.Message) (proto.Message, error) { return pb.String(""), nil } + +func (ts testSrv) GetValue(_ ServerCtx, _ proto.Message) (proto.Message, error) { + return pb.Int32(ts.val), nil +} + +func echoServerFn(_ int) ServerIface { + srv := NewServer() + srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { + resp, err := echoSrv{}.Test(ctx, in.GetProtoMessage()) + return NewResponseMessage(in.GetMetadata(), resp), err + }) + + return srv +} + +// echoSrv implements a simple echo server handler for testing +type echoSrv struct{} + +func 
(echoSrv) Test(_ ServerCtx, req proto.Message) (proto.Message, error) { + return pb.String("echo: " + mock.GetVal(req)), nil +} From c615dba79acc756830fe2dafbc2c67cedf42e099 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 12:43:57 -0800 Subject: [PATCH 03/16] fix: rename Incomplete error to ErrIncomplete for consistency Linters have been complaining about this style issue for a while. --- async.go | 2 +- correctable.go | 2 +- errors.go | 7 +++++-- errors_test.go | 12 ++++++------ quorumcall.go | 2 +- 5 files changed, 14 insertions(+), 11 deletions(-) diff --git a/async.go b/async.go index 0f454dcc..68cf6569 100644 --- a/async.go +++ b/async.go @@ -100,7 +100,7 @@ func (RawConfiguration) handleAsyncCall(ctx context.Context, fut *Async, state a return } if len(errs)+len(replies) == state.expectedReplies { - fut.reply, fut.err = resp, QuorumCallError{cause: Incomplete, errors: errs, replies: len(replies)} + fut.reply, fut.err = resp, QuorumCallError{cause: ErrIncomplete, errors: errs, replies: len(replies)} return } } diff --git a/correctable.go b/correctable.go index 9ecdb2f2..5d2584f7 100644 --- a/correctable.go +++ b/correctable.go @@ -169,7 +169,7 @@ func (c RawConfiguration) handleCorrectableCall(ctx context.Context, corr *Corre } if (state.data.ServerStream && len(errs) == state.expectedReplies) || (!state.data.ServerStream && len(errs)+len(replies) == state.expectedReplies) { - corr.set(resp, clevel, QuorumCallError{cause: Incomplete, errors: errs, replies: len(replies)}, true) + corr.set(resp, clevel, QuorumCallError{cause: ErrIncomplete, errors: errs, replies: len(replies)}, true) return } } diff --git a/errors.go b/errors.go index 2ac77f9b..6173b105 100644 --- a/errors.go +++ b/errors.go @@ -6,9 +6,12 @@ import ( "strings" ) -// Incomplete is the error returned by a quorum call when the call cannot completed +// ErrIncomplete is the error returned by a quorum call when the call cannot completed // due insufficient non-error replies to form 
a quorum according to the quorum function. -var Incomplete = errors.New("incomplete call") +var ErrIncomplete = errors.New("incomplete call") + +// ErrTypeMismatch is returned when a response cannot be cast to the expected type. +var ErrTypeMismatch = errors.New("response type mismatch") // QuorumCallError reports on a failed quorum call. type QuorumCallError struct { diff --git a/errors_test.go b/errors_test.go index 8dceb6dd..ef8aa375 100644 --- a/errors_test.go +++ b/errors_test.go @@ -15,25 +15,25 @@ func TestQuorumCallErrorIs(t *testing.T) { }{ { name: "SameCauseError", - err: QuorumCallError{cause: Incomplete}, - target: Incomplete, + err: QuorumCallError{cause: ErrIncomplete}, + target: ErrIncomplete, want: true, }, { name: "SameCauseQCError", - err: QuorumCallError{cause: Incomplete}, - target: QuorumCallError{cause: Incomplete}, + err: QuorumCallError{cause: ErrIncomplete}, + target: QuorumCallError{cause: ErrIncomplete}, want: true, }, { name: "DifferentError", - err: QuorumCallError{cause: Incomplete}, + err: QuorumCallError{cause: ErrIncomplete}, target: errors.New("incomplete call"), want: false, }, { name: "DifferentQCError", - err: QuorumCallError{cause: Incomplete}, + err: QuorumCallError{cause: ErrIncomplete}, target: QuorumCallError{cause: errors.New("incomplete call")}, want: false, }, diff --git a/quorumcall.go b/quorumcall.go index 08e63d23..99d91c3c 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -60,7 +60,7 @@ func (c RawConfiguration) QuorumCall(ctx context.Context, d QuorumCallData) (res return resp, QuorumCallError{cause: ctx.Err(), errors: errs, replies: len(replies)} } if len(errs)+len(replies) == expectedReplies { - return resp, QuorumCallError{cause: Incomplete, errors: errs, replies: len(replies)} + return resp, QuorumCallError{cause: ErrIncomplete, errors: errs, replies: len(replies)} } } } From 219a9348d465a5e1d0b4d000f55dcda4ab156592 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 12:48:09 -0800 Subject: [PATCH 
04/16] feat: add client-side interceptor-based quorum call architecture MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement a flexible, composable interceptor architecture for quorum calls that provides better type safety, modularity, and extensibility compared to the legacy QuorumCall approach. Core Architecture: - QuorumInterceptor: Generic type for interceptor functions that wrap QuorumFunc - QuorumFunc: Function signature for processing quorum calls and returning results - ClientCtx: Context object providing access to request, config, and response iterator - Chain: Utility function to compose interceptors around a base handler Key Features: 1. Full type safety with generics (Req, Resp, Out type parameters) 2. Support for custom return types via Out parameter 3. Lazy message sending - transforms applied before dispatch 4. Iterator-based response handling with early termination support 5. Composable middleware pattern for building complex quorum logic Base Quorum Functions (Aggregators): - MajorityQuorum: Returns first response after ⌈(n+1)/2⌉ successful replies - FirstResponse: Returns first successful response (read-any pattern) - AllResponses: Waits for all nodes to respond (write-all pattern) - ThresholdQuorum: Generic threshold-based quorum with configurable count - CollectAllResponses: Returns map of all successful responses by node ID Interceptors (Middleware): - PerNodeTransform: Applies per-node request transformations with skip support - QuorumSpecInterceptor: Adapter for legacy QuorumSpec-style functions Iterator Helpers: - IgnoreErrors: Filters iterator to yield only successful responses - CollectN: Collects up to n responses into a map - CollectAll: Collects all responses into a map Implementation Details: - Lazy sending via sync.OnceFunc ensures transforms registered before dispatch - RegisterTransformFunc allows chaining multiple request transformations - applyTransforms applies registered transforms 
in order, skips invalid results - Responses() iterator yields node responses as they arrive - Type-safe conversion from Result[proto.Message] to Result[Resp] Backward Compatibility: - Legacy QuorumCall remains unchanged - QuorumSpecInterceptor bridges old and new approaches - No breaking changes to existing code Testing: - 17 comprehensive test functions covering unit and integration scenarios - Tests for iterator utilities, interceptor chaining, custom aggregation - Tests for per-node transformation with node skipping - Tests for all base quorum functions with various error conditions - Integration tests with real gRPC servers - Helper functions for consistent test patterns (testContext, checkError, etc.) This architecture enables gradual migration from the legacy approach and provides a foundation for future code generation template updates. Files changed: - client_interceptor.go (new, 446 lines) - client_interceptor_test.go (new, 794 lines) --- client_interceptor.go | 445 +++++++++++++++++++++ client_interceptor_test.go | 794 +++++++++++++++++++++++++++++++++++++ 2 files changed, 1239 insertions(+) create mode 100644 client_interceptor.go create mode 100644 client_interceptor_test.go diff --git a/client_interceptor.go b/client_interceptor.go new file mode 100644 index 00000000..8a2dcc93 --- /dev/null +++ b/client_interceptor.go @@ -0,0 +1,445 @@ +package gorums + +import ( + "context" + "iter" + "maps" + "sync" + + "github.com/relab/gorums/ordering" + "google.golang.org/protobuf/proto" +) + +// msg is a type alias for proto.Message intended to be used as a type parameter. +type msg = proto.Message + +// QuorumInterceptor intercepts and processes quorum calls, allowing modification of +// requests, responses, and aggregation logic. Interceptors can be chained together, +// with each interceptor wrapping a QuorumFunc. 
+// +// Type parameters: +// - Req: The request message type sent to nodes +// - Resp: The response message type from individual nodes +// - Out: The final output type returned by the interceptor chain +// +// The Out type parameter allows interceptors to transform responses to a type different +// from the original response type specified in the service method's definition. +type QuorumInterceptor[Req, Resp msg, Out any] func(QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] + +// QuorumFunc processes a quorum call and returns the aggregated result. +// This is the function type that interceptors call to continue the chain. +// +// In a chain of interceptors, the final QuorumFunc is the base quorum function +// (e.g. MajorityQuorum) that actually collects and aggregates responses. +type QuorumFunc[Req, Resp msg, Out any] func(*ClientCtx[Req, Resp]) (Out, error) + +// ClientCtx provides context and access to the quorum call state for interceptors. +// It exposes the request, configuration, and an iterator over node responses. +type ClientCtx[Req, Resp any] struct { + context.Context + config RawConfiguration + request Req + method string + replyChan <-chan Result[msg] + reqTransforms []func(Req, *RawNode) Req + + // sendOnce is called lazily on the first call to Responses(). + sendOnce func() +} + +// newClientCtx creates a new ClientCtx for a quorum call. +func newClientCtx[Req, Resp any]( + ctx context.Context, + config RawConfiguration, + req Req, + method string, + replyChan <-chan Result[msg], +) *ClientCtx[Req, Resp] { + return &ClientCtx[Req, Resp]{ + Context: ctx, + config: config, + request: req, + method: method, + replyChan: replyChan, + reqTransforms: nil, + } +} + +// ------------------------------------------------------------------------- +// ClientCtx Methods +// ------------------------------------------------------------------------- + +// Request returns the original request message for this quorum call. 
+func (c ClientCtx[Req, Resp]) Request() Req { + return c.request +} + +// Config returns the configuration (set of nodes) for this quorum call. +func (c ClientCtx[Req, Resp]) Config() RawConfiguration { + return c.config +} + +// Method returns the name of the RPC method being called. +func (c ClientCtx[Req, Resp]) Method() string { + return c.method +} + +// Nodes returns the slice of nodes in this configuration. +func (c ClientCtx[Req, Resp]) Nodes() []*RawNode { + return c.config.Nodes() +} + +// Size returns the number of nodes in this configuration. +func (c ClientCtx[Req, Resp]) Size() int { + return c.config.Size() +} + +// RegisterTransformFunc registers a transformation function to modify the request message +// for a specific node. Transformation functions are applied in the order they are registered. +// +// This method is intended to be used by interceptors to modify requests before they are sent. +// It must be called before the first call to Responses(), which triggers the actual sending. +func (c *ClientCtx[Req, Resp]) RegisterTransformFunc(fn func(Req, *RawNode) Req) { + c.reqTransforms = append(c.reqTransforms, fn) +} + +// applyTransforms returns the transformed request as a proto.Message, or nil if the result is +// invalid or the node should be skipped. It applies the registered transformation functions to +// the given request for the specified node. Transformation functions are applied in the order +// they were registered. +func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *RawNode) proto.Message { + result := req + for _, transform := range c.reqTransforms { + result = transform(result, node) + } + if protoMsg, ok := any(result).(proto.Message); ok { + if protoMsg.ProtoReflect().IsValid() { + return protoMsg + } + } + return nil +} + +// Responses returns an iterator that yields node responses as they arrive. +// It returns a single-use iterator. 
+// +// Messages are not sent to nodes before ctx.Responses() is called, applying any +// registered request transformations. This lazy sending is necessary to allow +// interceptors to register transformations prior to dispatch. +// +// The iterator will: +// - Yield responses as they arrive from nodes +// - Continue until the context is canceled or all expected responses are received +// - Allow early termination by breaking from the range loop +// +// Example usage: +// +// for nodeID, result := range ctx.Responses() { +// if result.Err != nil { +// // Handle node error +// continue +// } +// // Process result.Value +// if haveQuorum { +// break // Early termination +// } +// } +func (c *ClientCtx[Req, Resp]) Responses() iter.Seq2[uint32, Result[Resp]] { + // Trigger lazy sending + if c.sendOnce != nil { + c.sendOnce() + } + return func(yield func(uint32, Result[Resp]) bool) { + // Wait for at most c.Size() responses + for range c.Size() { + select { + case r := <-c.replyChan: + // We get a Result[proto.Message] from the channel layer's + // response router; however, we convert it to Result[Resp] + // here to match the calltype's expected response type. + res := Result[Resp]{ + NodeID: r.NodeID, + Err: r.Err, + } + if r.Err == nil { + if val, ok := r.Value.(Resp); ok { + res.Value = val + } else { + res.Err = ErrTypeMismatch + } + } + if !yield(res.NodeID, res) { + return // Consumer stopped iteration + } + case <-c.Done(): + return // Context canceled + } + } + } +} + +// ------------------------------------------------------------------------- +// Iterator Helpers +// ------------------------------------------------------------------------- + +// IgnoreErrors filters an iterator to only yield successful responses, discarding errors. +// This is useful when you want to process only valid responses from nodes. 
+// +// Example: +// +// for nodeID, resp := range IgnoreErrors(ctx.Responses()) { +// // resp is guaranteed to be a successful response +// process(resp) +// } +func IgnoreErrors[Resp any](seq iter.Seq2[uint32, Result[Resp]]) iter.Seq2[uint32, Resp] { + return func(yield func(uint32, Resp) bool) { + for nodeID, result := range seq { + if result.Err == nil { + if !yield(nodeID, result.Value) { + return + } + } + } + } +} + +// CollectN collects up to n successful responses from the iterator into a map. +// It returns early if n responses are collected or the iterator is exhausted. +func CollectN[Resp any](seq iter.Seq2[uint32, Resp], n int) map[uint32]Resp { + replies := make(map[uint32]Resp, n) + for nodeID, resp := range seq { + replies[nodeID] = resp + if len(replies) >= n { + break + } + } + return replies +} + +// CollectAll collects all responses from the iterator into a map. +// This is a convenience wrapper around maps.Collect. +func CollectAll[Resp any](seq iter.Seq2[uint32, Resp]) map[uint32]Resp { + return maps.Collect(seq) +} + +// ------------------------------------------------------------------------- +// Interceptors (Middleware) +// ------------------------------------------------------------------------- + +// PerNodeTransform returns an interceptor that applies per-node request transformations. +// +// The transform function receives the original request and a node, and returns the transformed +// request to send to that node. If the function returns an invalid message or nil, the request to +// that node is skipped. +// +// Multiple PerNodeTransform interceptors can be chained together, with transforms applied in order. 
+// +// Example: +// +// interceptor := PerNodeTransform(func(req *Request, node *gorums.RawNode) *Request { +// // Send different shard to each node +// return &Request{Shard: int(node.ID())} +// }) +func PerNodeTransform[Req, Resp msg, Out any](transform func(Req, *RawNode) Req) QuorumInterceptor[Req, Resp, Out] { + return func(next QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] { + return func(ctx *ClientCtx[Req, Resp]) (Out, error) { + ctx.RegisterTransformFunc(transform) + return next(ctx) + } + } +} + +// QuorumSpecInterceptor returns an interceptor that wraps a legacy QuorumSpec-style quorum function. +// This adapter allows gradual migration from the legacy QuorumCall approach to interceptors. +// +// The quorum function receives the original request and a map of replies, and returns +// the aggregated result and a boolean indicating whether quorum was reached. +// +// Example: +// +// // Legacy QuorumSpec function +// qf := func(req *Request, replies map[uint32]*Response) (*Result, bool) { +// if len(replies) > len(config)/2 { +// return replies[0], true +// } +// return nil, false +// } +// +// // Convert to interceptor +// interceptor := QuorumSpecInterceptor(qf) +func QuorumSpecInterceptor[Req, Resp msg, Out any]( + qf func(Req, map[uint32]Resp) (Out, bool), +) QuorumInterceptor[Req, Resp, Out] { + return func(next QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] { + return func(ctx *ClientCtx[Req, Resp]) (Out, error) { + replies := CollectAll(IgnoreErrors(ctx.Responses())) + resp, ok := qf(ctx.Request(), replies) + if !ok { + var zero Out + return zero, QuorumCallError{cause: ErrIncomplete, replies: len(replies)} + } + return resp, nil + } + } +} + +// ------------------------------------------------------------------------- +// Base Quorum Functions (Aggregators) +// ------------------------------------------------------------------------- + +// ThresholdQuorum returns a QuorumFunc that waits for a threshold number of +// successful 
responses. It returns the first response once the threshold is reached. +// +// This is a base quorum function that terminates the interceptor chain. +// +// Example: +// +// // Create a quorum that needs 2 out of 3 responses +// qf := ThresholdQuorum[*Request, *Response](2) +func ThresholdQuorum[Req, Resp msg](threshold int) QuorumFunc[Req, Resp, Resp] { + return func(ctx *ClientCtx[Req, Resp]) (Resp, error) { + var ( + firstResp Resp + found bool + count int + errs []nodeError + ) + + for nodeID, result := range ctx.Responses() { + if result.Err != nil { + errs = append(errs, nodeError{nodeID: nodeID, cause: result.Err}) + continue + } + + count++ + if !found { + firstResp = result.Value + found = true + } + + // Check if we have reached the threshold + if count >= threshold { + return firstResp, nil + } + } + + var zero Resp + return zero, QuorumCallError{cause: ErrIncomplete, errors: errs, replies: count} + } +} + +// MajorityQuorum returns the first response once a simple majority (⌈(n+1)/2⌉) +// of successful responses are received. +// +// This is a base quorum function that terminates the interceptor chain. +func MajorityQuorum[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (Resp, error) { + quorumSize := (ctx.Size() + 1) / 2 + return ThresholdQuorum[Req, Resp](quorumSize)(ctx) +} + +// FirstResponse returns the first successful response received from any node. +// This is useful for read-any patterns where any single response is sufficient. +// +// This is a base quorum function that terminates the interceptor chain. +func FirstResponse[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (Resp, error) { + return ThresholdQuorum[Req, Resp](1)(ctx) +} + +// AllResponses returns the first response once all nodes have responded successfully. +// If any node fails, it returns an error. This is useful for write-all patterns. +// +// This is a base quorum function that terminates the interceptor chain. 
+func AllResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (Resp, error) { + return ThresholdQuorum[Req, Resp](ctx.Size())(ctx) +} + +// CollectAllResponses returns a map of all successful responses indexed by node ID. +// +// This is a base quorum function that terminates the interceptor chain. +func CollectAllResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (map[uint32]Resp, error) { + return maps.Collect(IgnoreErrors(ctx.Responses())), nil +} + +// ------------------------------------------------------------------------- +// Chain +// ------------------------------------------------------------------------- + +// Chain returns a QuorumFunc that composes the provided interceptors around the base function. +// The interceptors are executed in the order provided, wrapping the base QuorumFunc. +// +// The base QuorumFunc is the terminal handler that actually processes the responses. +// Interceptors can wrap this handler to add behavior before or after the base handler. +// +// Execution order: +// 1. interceptors[0] (outermost wrapper) +// 2. interceptors[1] +// ... +// 3. base (innermost handler, e.g. aggregation) +func Chain[Req, Resp msg, Out any]( + base QuorumFunc[Req, Resp, Out], + interceptors ...QuorumInterceptor[Req, Resp, Out], +) QuorumFunc[Req, Resp, Out] { + handler := base + for i := len(interceptors) - 1; i >= 0; i-- { + handler = interceptors[i](handler) + } + return handler +} + +// ------------------------------------------------------------------------- +// QuorumCallWithInterceptor +// ------------------------------------------------------------------------- + +// QuorumCallWithInterceptor performs a quorum call using an interceptor-based approach. +// +// Type parameters: +// - Req: The request message type +// - Resp: The response message type from individual nodes +// - Out: The final output type returned by the interceptor chain +// +// The base parameter is the terminal handler that processes responses (e.g., MajorityQuorum). 
+// The interceptors parameter accepts one or more interceptors that wrap the base handler. +// +// Execution order: +// 1. interceptors[0] (outermost wrapper) +// 2. interceptors[1] +// ... +// 3. base (innermost handler, e.g. aggregation) +// +// Note: Messages are not sent to nodes before ctx.Responses() is called, applying any +// registered request transformations. This lazy sending is necessary to allow interceptors +// to register transformations prior to dispatch. +// +// This function should be used by generated code only. +func QuorumCallWithInterceptor[Req, Resp msg, Out any]( + ctx context.Context, + config RawConfiguration, + req Req, + method string, + base QuorumFunc[Req, Resp, Out], + interceptors ...QuorumInterceptor[Req, Resp, Out], +) (Out, error) { + md := ordering.NewGorumsMetadata(ctx, config.getMsgID(), method) + replyChan := make(chan Result[msg], len(config)) + + // Create ClientCtx first so sendOnce can access it + clientCtx := newClientCtx[Req, Resp](ctx, config, req, method, replyChan) + + // Create sendOnce function that will be called lazily on first Responses() call + sendOnce := func() { + for _, n := range config { + // Apply registered request transformations (if any) + msg := clientCtx.applyTransforms(req, n) + if msg == nil { + continue // Skip node if transformation function returns nil + } + n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan}) + } + } + + // Wrap sendOnce with sync.OnceFunc to ensure it's only called once + clientCtx.sendOnce = sync.OnceFunc(sendOnce) + + handler := Chain(base, interceptors...) 
+ return handler(clientCtx) +} diff --git a/client_interceptor_test.go b/client_interceptor_test.go new file mode 100644 index 00000000..3c83f652 --- /dev/null +++ b/client_interceptor_test.go @@ -0,0 +1,794 @@ +package gorums + +import ( + "context" + "errors" + "slices" + "strconv" + "sync" + "testing" + "time" + + "github.com/relab/gorums/internal/testutils/mock" + "google.golang.org/protobuf/proto" + pb "google.golang.org/protobuf/types/known/wrapperspb" +) + +// Test helper types and functions + +// testContext creates a context with timeout for testing. +// It uses t.Context() as the parent and automatically cancels on cleanup. +func testContext(t *testing.T, timeout time.Duration) context.Context { + t.Helper() + ctx, cancel := context.WithTimeout(t.Context(), timeout) + t.Cleanup(cancel) + return ctx +} + +// checkError is a helper to validate error expectations in tests. +func checkError(t *testing.T, wantErr bool, err, wantErrType error) bool { + t.Helper() + if wantErr { + if err == nil { + t.Error("Expected error, got nil") + return false + } + if wantErrType != nil && !errors.Is(err, wantErrType) { + t.Errorf("Expected error type %v, got %v", wantErrType, err) + return false + } + return true + } + if err != nil { + t.Errorf("Expected no error, got %v", err) + return false + } + return true +} + +// executionTracker tracks interceptor execution order for testing. +type executionTracker struct { + mu sync.Mutex + log []string +} + +func (et *executionTracker) append(entry string) { + et.mu.Lock() + defer et.mu.Unlock() + et.log = append(et.log, entry) +} + +func (et *executionTracker) get() []string { + et.mu.Lock() + defer et.mu.Unlock() + return append([]string(nil), et.log...) 
+} + +func (et *executionTracker) check(t *testing.T, want []string) { + t.Helper() + got := et.get() + if len(got) != len(want) { + t.Errorf("Expected %d log entries, got %d: %v", len(want), len(got), got) + return + } + for i, wantEntry := range want { + if i >= len(got) || got[i] != wantEntry { + t.Errorf("log[%d] = %v, want %s", i, got, wantEntry) + } + } +} + +// loggingInterceptor creates a reusable logging interceptor for testing. +// It logs "before" and "after" entries to the provided tracker. +func loggingInterceptor[Req, Resp proto.Message](tracker *executionTracker) QuorumInterceptor[Req, Resp, Resp] { + return func(next QuorumFunc[Req, Resp, Resp]) QuorumFunc[Req, Resp, Resp] { + return func(ctx *ClientCtx[Req, Resp]) (Resp, error) { + tracker.append("logging-before") + result, err := next(ctx) + tracker.append("logging-after") + return result, err + } + } +} + +// makeClientCtx is a helper to create a ClientCtx with mock responses for unit tests. +// It creates a channel with the provided responses and returns a ClientCtx with a short timeout. 
+func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []Result[proto.Message]) *ClientCtx[Req, Resp] { + t.Helper() + + resultChan := make(chan Result[proto.Message], len(responses)) + for _, r := range responses { + resultChan <- r + } + close(resultChan) + + config := make(RawConfiguration, numNodes) + for i := range numNodes { + config[i] = &RawNode{id: uint32(i + 1)} + } + + return &ClientCtx[Req, Resp]{ + Context: testContext(t, 100*time.Millisecond), + config: config, + replyChan: resultChan, + } +} + +// Iterator Utility Tests + +// TestIteratorUtilities tests the iterator helper functions +func TestIteratorUtilities(t *testing.T) { + tests := []struct { + name string + responses []Result[proto.Message] + operation string // "ignoreErrors", "collectN", "collectAll" + collectN int + wantCount int + wantFilteredIDs []uint32 + }{ + { + name: "IgnoreErrors", + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("node error")}, + {NodeID: 3, Value: pb.String("response3"), Err: nil}, + {NodeID: 4, Value: nil, Err: errors.New("another error")}, + {NodeID: 5, Value: pb.String("response5"), Err: nil}, + }, + operation: "ignoreErrors", + wantCount: 3, + wantFilteredIDs: []uint32{1, 3, 5}, + }, + { + name: "CollectN", + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response"), Err: nil}, + {NodeID: 2, Value: pb.String("response"), Err: nil}, + {NodeID: 3, Value: pb.String("response"), Err: nil}, + {NodeID: 4, Value: pb.String("response"), Err: nil}, + {NodeID: 5, Value: pb.String("response"), Err: nil}, + }, + operation: "collectN", + collectN: 3, + wantCount: 3, + }, + { + name: "CollectAll", + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response"), Err: nil}, + {NodeID: 2, Value: pb.String("response"), Err: nil}, + {NodeID: 3, Value: pb.String("response"), Err: nil}, + {NodeID: 4, Value: pb.String("response"), 
Err: nil}, + {NodeID: 5, Value: pb.String("response"), Err: nil}, + }, + operation: "collectAll", + wantCount: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, len(tt.responses), tt.responses) + + switch tt.operation { + case "ignoreErrors": + count := 0 + for nodeID, resp := range IgnoreErrors(clientCtx.Responses()) { + t.Logf("Node %d: %v", nodeID, resp.GetValue()) + count++ + if !slices.Contains(tt.wantFilteredIDs, nodeID) { + t.Errorf("Node %d should have been filtered out", nodeID) + } + } + if count != tt.wantCount { + t.Errorf("Expected %d successful responses, got %d", tt.wantCount, count) + } + + case "collectN": + replies := CollectN(IgnoreErrors(clientCtx.Responses()), tt.collectN) + if len(replies) != tt.wantCount { + t.Errorf("Expected %d responses, got %d", tt.wantCount, len(replies)) + } + + case "collectAll": + replies := CollectAll(IgnoreErrors(clientCtx.Responses())) + if len(replies) != tt.wantCount { + t.Errorf("Expected %d responses, got %d", tt.wantCount, len(replies)) + } + } + }) + } +} + +// Interceptor Unit Tests + +// TestInterceptorChaining tests composing multiple interceptors +func TestInterceptorChaining(t *testing.T) { + t.Run("ChainLoggingAndQuorum", func(t *testing.T) { + // Track interceptor execution order + tracker := &executionTracker{} + + // Create mock responses + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: pb.String("response2"), Err: nil}, + {NodeID: 3, Value: pb.String("response3"), Err: nil}, + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses) + + handler := Chain( + MajorityQuorum[*pb.StringValue, *pb.StringValue], + loggingInterceptor[*pb.StringValue, *pb.StringValue](tracker), + ) + result, err := handler(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result.GetValue() != "response1" { 
+ t.Errorf("Expected 'response1', got '%s'", result.GetValue()) + } + + // Check execution order + tracker.check(t, []string{"logging-before", "logging-after"}) + }) +} + +// sumInterceptor returns the sum of the responses from all nodes. +// It demonstrates a custom aggregation interceptor; it is used in several tests. +var sumInterceptor = func(ctx *ClientCtx[*pb.Int32Value, *pb.Int32Value]) (*pb.Int32Value, error) { + var sum int32 + for _, result := range IgnoreErrors(ctx.Responses()) { + sum += result.GetValue() + } + return pb.Int32(sum), nil +} + +// TestInterceptorCustomAggregation demonstrates custom interceptor for aggregation +func TestInterceptorCustomAggregation(t *testing.T) { + t.Run("SumAggregation", func(t *testing.T) { + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.Int32(10), Err: nil}, + {NodeID: 2, Value: pb.Int32(20), Err: nil}, + {NodeID: 3, Value: pb.Int32(30), Err: nil}, + } + clientCtx := makeClientCtx[*pb.Int32Value, *pb.Int32Value](t, 3, responses) + + // Custom aggregation interceptors typically don't call next. 
+ result, err := sumInterceptor(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if got, want := result.GetValue(), int32(60); got != want { + t.Errorf("Sum = %d, want %d", got, want) + } + }) +} + +// TestInterceptorCustomReturnType demonstrates using interceptors to return a different type +func TestInterceptorCustomReturnType(t *testing.T) { + type CustomResult struct { + Total int + Count int + } + + t.Run("ConvertToCustomType", func(t *testing.T) { + customReturnInterceptor := func(ctx *ClientCtx[*pb.Int32Value, *pb.Int32Value]) (*CustomResult, error) { + var total int32 + var count int + for _, result := range IgnoreErrors(ctx.Responses()) { + total += result.GetValue() + count++ + } + return &CustomResult{Total: int(total), Count: count}, nil + } + + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.Int32(10), Err: nil}, + {NodeID: 2, Value: pb.Int32(20), Err: nil}, + {NodeID: 3, Value: pb.Int32(30), Err: nil}, + } + clientCtx := makeClientCtx[*pb.Int32Value, *pb.Int32Value](t, 3, responses) + + result, err := customReturnInterceptor(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if got, want := result.Total, 60; got != want { + t.Errorf("Total = %d, want %d", got, want) + } + + if got, want := result.Count, 3; got != want { + t.Errorf("Count = %d, want %d", got, want) + } + }) +} + +// Interceptor Integration Tests with Real Servers + +// TestInterceptorIntegration_MajorityQuorum tests the complete flow with real servers +func TestInterceptorIntegration_MajorityQuorum(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, echoServerFn) + t.Cleanup(closeServers) + + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + NewConfig(t, addrs), + pb.String("test"), + mock.TestMethod, + MajorityQuorum[*pb.StringValue, *pb.StringValue], + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + if got, want := result.GetValue(), "echo: test"; got != 
want { + t.Errorf("Response = %q, want %q", got, want) + } +} + +// TestInterceptorIntegration_CustomAggregation tests custom aggregation with real servers +func TestInterceptorIntegration_CustomAggregation(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, nil) + t.Cleanup(closeServers) + + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + NewConfig(t, addrs), + pb.Int32(0), + mock.GetValueMethod, + sumInterceptor, + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + + // Expected: 10 + 20 + 30 = 60 + if result.GetValue() != 60 { + t.Errorf("Expected sum of 60, got %d", result.GetValue()) + } +} + +// TestInterceptorIntegration_Chaining tests chained interceptors with real servers +func TestInterceptorIntegration_Chaining(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, echoServerFn) + t.Cleanup(closeServers) + + // Track interceptor execution + tracker := &executionTracker{} + + // Chain interceptors + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + NewConfig(t, addrs), + pb.String("test"), + mock.TestMethod, + MajorityQuorum[*pb.StringValue, *pb.StringValue], // Base + loggingInterceptor[*pb.StringValue, *pb.StringValue](tracker), // Interceptor + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + + if result.GetValue() != "echo: test" { + t.Errorf("Expected 'echo: test', got '%s'", result.GetValue()) + } + + // Verify interceptor execution order + tracker.check(t, []string{"logging-before", "logging-after"}) +} + +// TestInterceptorCollectAllResponses tests the CollectAllResponses interceptor +func TestInterceptorCollectAllResponses(t *testing.T) { + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: pb.String("response2"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses) + + result, err := 
CollectAllResponses(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + if len(result) != 2 { + t.Errorf("Expected 2 responses, got %d", len(result)) + } + + if _, ok := result[1]; !ok { + t.Error("Expected response from node 1") + } + if _, ok := result[2]; !ok { + t.Error("Expected response from node 2") + } + if _, ok := result[3]; ok { + t.Error("Did not expect response from node 3 (had error)") + } +} + +// TestInterceptorIntegration_CollectAll tests CollectAllResponses with real servers +func TestInterceptorIntegration_CollectAll(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, func(i int) ServerIface { + srv := NewServer() + srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { + req := AsProto[*pb.StringValue](in) + resp := pb.String(req.GetValue() + "-node-" + strconv.Itoa(i)) + return NewResponseMessage(in.GetMetadata(), resp), nil + }) + return srv + }) + t.Cleanup(closeServers) + + config := NewConfig(t, addrs) + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + config, + pb.String("test"), + mock.TestMethod, + CollectAllResponses[*pb.StringValue, *pb.StringValue], + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + + if len(result) != 3 { + t.Errorf("Expected 3 responses, got %d", len(result)) + } + + // Verify we got responses from all nodes in the configuration + nodes := config.Nodes() + for _, node := range nodes { + if _, ok := result[node.ID()]; !ok { + t.Errorf("Missing response from node %d", node.ID()) + } + } +} + +// TestInterceptorIntegration_PerNodeTransform tests per-node transformation with real servers +func TestInterceptorIntegration_PerNodeTransform(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, echoServerFn) + t.Cleanup(closeServers) + + // Create a transform that sends different values to each node + transformInterceptor := PerNodeTransform[*pb.StringValue, *pb.StringValue, map[uint32]*pb.StringValue]( + 
func(req *pb.StringValue, node *RawNode) *pb.StringValue { + return pb.String(req.GetValue() + "-node-" + strconv.Itoa(int(node.ID()))) + }, + ) + + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + NewConfig(t, addrs), + pb.String("test"), + mock.TestMethod, + CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base + transformInterceptor, // Interceptor + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + + if len(result) != 3 { + t.Errorf("Expected 3 responses, got %d", len(result)) + } + + // Verify each node received the transformed request + for nodeID, resp := range result { + expected := "echo: test-node-" + strconv.Itoa(int(nodeID)) + if resp.GetValue() != expected { + t.Errorf("Node %d: expected %q, got %q", nodeID, expected, resp.GetValue()) + } + } +} + +// TestInterceptorIntegration_PerNodeTransformSkip tests skipping nodes in per-node transformation +func TestInterceptorIntegration_PerNodeTransformSkip(t *testing.T) { + addrs, closeServers := TestSetup(t, 3, echoServerFn) + t.Cleanup(closeServers) + + config := NewConfig(t, addrs) + nodes := config.Nodes() + + // Skip the second node (index 1) + skipNodeID := nodes[1].ID() + + // Create a transform that skips one node by returning an invalid message + transformInterceptor := PerNodeTransform[*pb.StringValue, *pb.StringValue, map[uint32]*pb.StringValue]( + func(req *pb.StringValue, node *RawNode) *pb.StringValue { + if node.ID() == skipNodeID { + return nil // Skip this node + } + return pb.String(req.GetValue() + "-node-" + strconv.Itoa(int(node.ID()))) + }, + ) + + result, err := QuorumCallWithInterceptor( + testContext(t, 2*time.Second), + config, + pb.String("test"), + mock.TestMethod, + CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base + transformInterceptor, // Interceptor + ) + if err != nil { + t.Fatalf("QuorumCall failed: %v", err) + } + + if len(result) != 2 { + t.Errorf("Expected 2 responses (one node skipped), got %d", 
len(result)) + } + + // Verify we got responses from nodes 0 and 2, but not 1 + if _, ok := result[nodes[0].ID()]; !ok { + t.Errorf("Expected response from node %d", nodes[0].ID()) + } + if _, ok := result[skipNodeID]; ok { + t.Errorf("Did not expect response from skipped node %d", skipNodeID) + } + if _, ok := result[nodes[2].ID()]; !ok { + t.Errorf("Expected response from node %d", nodes[2].ID()) + } +} + +// TestBaseQuorumFunctions tests all base quorum functions (FirstResponse, AllResponses, ThresholdQuorum, MajorityQuorum) +func TestBaseQuorumFunctions(t *testing.T) { + tests := []struct { + name string + quorumFunc QuorumFunc[*pb.StringValue, *pb.StringValue, *pb.StringValue] + numNodes int + responses []Result[proto.Message] + wantErr bool + wantErrType error + wantValue string + }{ + // FirstResponse tests + { + name: "FirstResponse_Success", + quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: pb.String("third"), Err: nil}, + }, + wantErr: false, + wantValue: "first", + }, + { + name: "FirstResponse_AfterErrors", + quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], + responses: []Result[proto.Message]{ + {NodeID: 1, Value: nil, Err: errors.New("error1")}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: pb.String("third"), Err: nil}, + }, + wantErr: false, + wantValue: "second", + }, + { + name: "FirstResponse_AllErrors", + quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], + responses: []Result[proto.Message]{ + {NodeID: 1, Value: nil, Err: errors.New("error1")}, + {NodeID: 2, Value: nil, Err: errors.New("error2")}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + }, + wantErr: true, + wantErrType: ErrIncomplete, + }, + { + name: "FirstResponse_NoResponses", + quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], + responses: 
[]Result[proto.Message]{}, + wantErr: true, + wantErrType: ErrIncomplete, + }, + + // AllResponses tests + { + name: "AllResponses_AllSuccess", + quorumFunc: AllResponses[*pb.StringValue, *pb.StringValue], + numNodes: 3, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: pb.String("third"), Err: nil}, + }, + wantErr: false, + wantValue: "first", + }, + { + name: "AllResponses_OneError", + quorumFunc: AllResponses[*pb.StringValue, *pb.StringValue], + numNodes: 3, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("error2")}, + {NodeID: 3, Value: pb.String("third"), Err: nil}, + }, + wantErr: true, + wantErrType: ErrIncomplete, + }, + + // MajorityQuorum tests + { + name: "MajorityQuorum_Success", + quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], + numNodes: 5, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: pb.String("response2"), Err: nil}, + {NodeID: 3, Value: pb.String("response3"), Err: nil}, + {NodeID: 4, Value: nil, Err: errors.New("error4")}, + {NodeID: 5, Value: nil, Err: errors.New("error5")}, + }, + wantErr: false, + wantValue: "response1", + }, + { + name: "MajorityQuorum_Insufficient", + quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], + numNodes: 5, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: pb.String("response2"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + {NodeID: 4, Value: nil, Err: errors.New("error4")}, + {NodeID: 5, Value: nil, Err: errors.New("error5")}, + }, + wantErr: true, + wantErrType: ErrIncomplete, + }, + { + name: "MajorityQuorum_Exact", + quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], + numNodes: 3, + responses: []Result[proto.Message]{ + 
{NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error")}, + }, + wantErr: false, + wantValue: "first", + }, + + // ThresholdQuorum tests + { + name: "ThresholdQuorum_Met", + quorumFunc: ThresholdQuorum[*pb.StringValue, *pb.StringValue](2), + numNodes: 3, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + }, + wantErr: false, + wantValue: "first", + }, + { + name: "ThresholdQuorum_NotMet", + quorumFunc: ThresholdQuorum[*pb.StringValue, *pb.StringValue](3), + numNodes: 3, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("error2")}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + }, + wantErr: true, + wantErrType: ErrIncomplete, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + numNodes := tt.numNodes + if numNodes == 0 { + numNodes = len(tt.responses) + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, numNodes, tt.responses) + + result, err := tt.quorumFunc(clientCtx) + + if !checkError(t, tt.wantErr, err, tt.wantErrType) { + return + } + + if !tt.wantErr && result.GetValue() != tt.wantValue { + t.Errorf("Expected '%s', got '%s'", tt.wantValue, result.GetValue()) + } + }) + } +} + +// TestInterceptorQuorumSpecAdapter tests the QuorumSpecInterceptor adapter +func TestInterceptorQuorumSpecAdapter(t *testing.T) { + t.Run("AdapterWithMajorityQuorum", func(t *testing.T) { + // Define a legacy-style quorum function + qf := func(req *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { + quorumSize := 2 // Majority of 3 + if len(replies) >= quorumSize { + // Return first reply + for _, v := range replies { + return v, true + } + } + return nil, false + } + + responses := 
[]Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses) + + // Convert to interceptor + interceptor := QuorumSpecInterceptor(qf) + result, err := interceptor(MajorityQuorum[*pb.StringValue, *pb.StringValue])(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + + // Should get one of the successful responses + if result.GetValue() != "first" && result.GetValue() != "second" { + t.Errorf("Expected 'first' or 'second', got '%s'", result.GetValue()) + } + }) + + t.Run("AdapterQuorumNotReached", func(t *testing.T) { + // Define a quorum function that needs all 3 responses + qf := func(req *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { + if len(replies) == 3 { + return pb.String("success"), true + } + return nil, false + } + + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("error2")}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses) + + interceptor := QuorumSpecInterceptor(qf) + + _, err := interceptor(MajorityQuorum[*pb.StringValue, *pb.StringValue])(clientCtx) + if err == nil { + t.Error("Expected error when quorum not reached") + } + }) +} + +// TestInterceptorTerminalHandlerError tests that ErrNoTerminalHandler +// is returned when no interceptor in the chain completes the call. 
+// TestInterceptorUsage demonstrates correct usage of interceptors +func TestInterceptorUsage(t *testing.T) { + t.Run("CorrectUsage", func(t *testing.T) { + // Demonstrate correct usage: transform followed by aggregator + transform := func(req *pb.StringValue, node *RawNode) *pb.StringValue { + return pb.String(req.GetValue() + "-transformed") + } + + responses := []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("response1"), Err: nil}, + {NodeID: 2, Value: pb.String("response2"), Err: nil}, + {NodeID: 3, Value: pb.String("response3"), Err: nil}, + } + clientCtx := makeClientCtx[*pb.StringValue, *pb.StringValue](t, 3, responses) + + // Correct chain: transform -> aggregator (MajorityQuorum completes the call) + handler := Chain( + MajorityQuorum[*pb.StringValue, *pb.StringValue], + PerNodeTransform[*pb.StringValue, *pb.StringValue, *pb.StringValue](transform), + ) + + result, err := handler(clientCtx) + if err != nil { + t.Errorf("Expected no error, got %v", err) + } + if result == nil || result.GetValue() != "response1" { + t.Errorf("Expected 'response1', got %v", result) + } + }) +} From d2e0f504ff61160c2e2cbd30201fc05316c94d83 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 13:11:12 -0800 Subject: [PATCH 05/16] fix: underscore unused parameters and remove unused receiver parameter This fixes lint issues raised by deepsource and golangci-lint. 
--- client_interceptor.go | 2 +- client_interceptor_test.go | 8 ++++---- testing_gorums.go | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 8a2dcc93..280a393d 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -270,7 +270,7 @@ func PerNodeTransform[Req, Resp msg, Out any](transform func(Req, *RawNode) Req) func QuorumSpecInterceptor[Req, Resp msg, Out any]( qf func(Req, map[uint32]Resp) (Out, bool), ) QuorumInterceptor[Req, Resp, Out] { - return func(next QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] { + return func(_ QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] { return func(ctx *ClientCtx[Req, Resp]) (Out, error) { replies := CollectAll(IgnoreErrors(ctx.Responses())) resp, ok := qf(ctx.Request(), replies) diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 3c83f652..4d135fd1 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -411,7 +411,7 @@ func TestInterceptorCollectAllResponses(t *testing.T) { func TestInterceptorIntegration_CollectAll(t *testing.T) { addrs, closeServers := TestSetup(t, 3, func(i int) ServerIface { srv := NewServer() - srv.RegisterHandler(mock.TestMethod, func(ctx ServerCtx, in *Message) (*Message, error) { + srv.RegisterHandler(mock.TestMethod, func(_ ServerCtx, in *Message) (*Message, error) { req := AsProto[*pb.StringValue](in) resp := pb.String(req.GetValue() + "-node-" + strconv.Itoa(i)) return NewResponseMessage(in.GetMetadata(), resp), nil @@ -704,7 +704,7 @@ func TestBaseQuorumFunctions(t *testing.T) { func TestInterceptorQuorumSpecAdapter(t *testing.T) { t.Run("AdapterWithMajorityQuorum", func(t *testing.T) { // Define a legacy-style quorum function - qf := func(req *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { + qf := func(_ *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { quorumSize := 2 // Majority of 3 if len(replies) 
>= quorumSize { // Return first reply @@ -737,7 +737,7 @@ func TestInterceptorQuorumSpecAdapter(t *testing.T) { t.Run("AdapterQuorumNotReached", func(t *testing.T) { // Define a quorum function that needs all 3 responses - qf := func(req *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { + qf := func(_ *pb.StringValue, replies map[uint32]*pb.StringValue) (*pb.StringValue, bool) { if len(replies) == 3 { return pb.String("success"), true } @@ -766,7 +766,7 @@ func TestInterceptorQuorumSpecAdapter(t *testing.T) { func TestInterceptorUsage(t *testing.T) { t.Run("CorrectUsage", func(t *testing.T) { // Demonstrate correct usage: transform followed by aggregator - transform := func(req *pb.StringValue, node *RawNode) *pb.StringValue { + transform := func(req *pb.StringValue, _ *RawNode) *pb.StringValue { return pb.String(req.GetValue() + "-transformed") } diff --git a/testing_gorums.go b/testing_gorums.go index 0743a12a..88068013 100644 --- a/testing_gorums.go +++ b/testing_gorums.go @@ -140,7 +140,7 @@ type testSrv struct { val int32 } -func (_ testSrv) Test(_ ServerCtx, _ proto.Message) (proto.Message, error) { +func (testSrv) Test(_ ServerCtx, _ proto.Message) (proto.Message, error) { return pb.String(""), nil } From c161818f0f9587b14ae41fd2fa566198eef00352 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 13:30:09 -0800 Subject: [PATCH 06/16] fix: quorum size calculation and add tests for even node scenarios --- client_interceptor.go | 2 +- client_interceptor_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/client_interceptor.go b/client_interceptor.go index 280a393d..054c4ac3 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -333,7 +333,7 @@ func ThresholdQuorum[Req, Resp msg](threshold int) QuorumFunc[Req, Resp, Resp] { // // This is a base quorum function that terminates the interceptor chain. 
func MajorityQuorum[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (Resp, error) { - quorumSize := (ctx.Size() + 1) / 2 + quorumSize := ctx.Size()/2 + 1 return ThresholdQuorum[Req, Resp](quorumSize)(ctx) } diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 4d135fd1..dbfb79c5 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -651,6 +651,32 @@ func TestBaseQuorumFunctions(t *testing.T) { wantErr: false, wantValue: "first", }, + { + name: "MajorityQuorum_Even_Success", + quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], + numNodes: 4, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: pb.String("third"), Err: nil}, + {NodeID: 4, Value: nil, Err: errors.New("error4")}, + }, + wantErr: false, + wantValue: "first", + }, + { + name: "MajorityQuorum_Even_Insufficient", + quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], + numNodes: 4, + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("first"), Err: nil}, + {NodeID: 2, Value: pb.String("second"), Err: nil}, + {NodeID: 3, Value: nil, Err: errors.New("error3")}, + {NodeID: 4, Value: nil, Err: errors.New("error4")}, + }, + wantErr: true, + wantErrType: ErrIncomplete, + }, // ThresholdQuorum tests { From 376f0bd2e00bf0c66f77ac3bf9bb50fb8b926cb3 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 13:34:31 -0800 Subject: [PATCH 07/16] fix: QuorumSpecInterceptor doc comment --- client_interceptor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/client_interceptor.go b/client_interceptor.go index 054c4ac3..89ac679f 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -255,6 +255,9 @@ func PerNodeTransform[Req, Resp msg, Out any](transform func(Req, *RawNode) Req) // The quorum function receives the original request and a map of replies, and returns // the aggregated result and a boolean indicating 
whether quorum was reached. // +// Note: This is a terminal handler that collects all responses itself. Any base quorum function +// passed when using this interceptor will be ignored." +// // Example: // // // Legacy QuorumSpec function From e5b480d1610558e5c63db867e5136e2b1b33265e Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 13:34:50 -0800 Subject: [PATCH 08/16] fix: typos in ErrIncomplete doc comment --- errors.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/errors.go b/errors.go index 6173b105..1f4a8a99 100644 --- a/errors.go +++ b/errors.go @@ -6,8 +6,8 @@ import ( "strings" ) -// ErrIncomplete is the error returned by a quorum call when the call cannot completed -// due insufficient non-error replies to form a quorum according to the quorum function. +// ErrIncomplete is the error returned by a quorum call when the call cannot be completed +// due to insufficient non-error replies to form a quorum according to the quorum function. var ErrIncomplete = errors.New("incomplete call") // ErrTypeMismatch is returned when a response cannot be cast to the expected type. From 0d3c93753cc4cac03e4844293e5a565e2cc7d549 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Wed, 26 Nov 2025 13:46:49 -0800 Subject: [PATCH 09/16] chore: remove extra " typo --- client_interceptor.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client_interceptor.go b/client_interceptor.go index 89ac679f..62a32e7b 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -256,7 +256,7 @@ func PerNodeTransform[Req, Resp msg, Out any](transform func(Req, *RawNode) Req) // the aggregated result and a boolean indicating whether quorum was reached. // // Note: This is a terminal handler that collects all responses itself. Any base quorum function -// passed when using this interceptor will be ignored." +// passed when using this interceptor will be ignored. 
// // Example: // From 424f760751702b8f951b6c0b5f30e8b313479b18 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 12:07:56 -0800 Subject: [PATCH 10/16] refactor: iterator API to use single-value Results[T] type as receiver MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This changes the iterator API from iter.Seq2[uint32, Result[T]] to a cleaner iter.Seq[Result[T]] pattern, and a type alias Results[T] which can serve as receiver for methods on said iterator. This simplifies the API by consolidating node ID and result information into a single Result[T] value, making iteration more ergonomic, despite not following Go's function-based iterator patterns. Key changes: - Introduce Results[T] type alias for iter.Seq[Result[T]] - Change ClientCtx.Responses() return type from iter.Seq2 to Results[T] - Update iterator helper methods to be methods on Results[T]: * IgnoreErrors() now returns Results[T] instead of iter.Seq2[uint32, T] * Add Filter() method for generic result filtering * CollectN() and CollectAll() now methods on Results[T] - Update all iterator consumers to use single-value iteration pattern - Constrain ClientCtx type parameters to msg (proto.Message) type Benefits: - Simpler iteration: `for result := range ctx.Responses()` vs `for nodeID, result := range ctx.Responses()` - More composable: method chaining like `ctx.Responses().IgnoreErrors().CollectAll()` - Consistent: Result[T] already contains NodeID, no need to pass separately - Cleaner: Filter() operates on complete Result[T] values This borrows from Asbjørn Salhus's design in PR #230, which I now agree is better than Go's function-based iterator pattern because of its significantly better composability. That is, you avoid composing with functions that would look like: gorums.IgnoreErrors(ctx.Responses())... and even worse when there are many iterators being composed. 
--- client_interceptor.go | 78 ++++++++++++++++++++++++-------------- client_interceptor_test.go | 63 ++++++++++++++++++++---------- 2 files changed, 92 insertions(+), 49 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 62a32e7b..557124f2 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -3,7 +3,6 @@ package gorums import ( "context" "iter" - "maps" "sync" "github.com/relab/gorums/ordering" @@ -33,9 +32,12 @@ type QuorumInterceptor[Req, Resp msg, Out any] func(QuorumFunc[Req, Resp, Out]) // (e.g. MajorityQuorum) that actually collects and aggregates responses. type QuorumFunc[Req, Resp msg, Out any] func(*ClientCtx[Req, Resp]) (Out, error) +// Results is an iterator that yields Result[T] values from a quorum call. +type Results[T msg] iter.Seq[Result[T]] + // ClientCtx provides context and access to the quorum call state for interceptors. // It exposes the request, configuration, and an iterator over node responses. -type ClientCtx[Req, Resp any] struct { +type ClientCtx[Req, Resp msg] struct { context.Context config RawConfiguration request Req @@ -48,7 +50,7 @@ type ClientCtx[Req, Resp any] struct { } // newClientCtx creates a new ClientCtx for a quorum call. 
-func newClientCtx[Req, Resp any]( +func newClientCtx[Req, Resp msg]( ctx context.Context, config RawConfiguration, req Req, @@ -134,22 +136,19 @@ func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *RawNode) proto.Mes // // Example usage: // -// for nodeID, result := range ctx.Responses() { +// for result := range ctx.Responses() { // if result.Err != nil { // // Handle node error // continue // } // // Process result.Value -// if haveQuorum { -// break // Early termination -// } // } -func (c *ClientCtx[Req, Resp]) Responses() iter.Seq2[uint32, Result[Resp]] { +func (c *ClientCtx[Req, Resp]) Responses() Results[Resp] { // Trigger lazy sending if c.sendOnce != nil { c.sendOnce() } - return func(yield func(uint32, Result[Resp]) bool) { + return func(yield func(Result[Resp]) bool) { // Wait for at most c.Size() responses for range c.Size() { select { @@ -168,7 +167,7 @@ func (c *ClientCtx[Req, Resp]) Responses() iter.Seq2[uint32, Result[Resp]] { res.Err = ErrTypeMismatch } } - if !yield(res.NodeID, res) { + if !yield(res) { return // Consumer stopped iteration } case <-c.Done(): @@ -182,20 +181,21 @@ func (c *ClientCtx[Req, Resp]) Responses() iter.Seq2[uint32, Result[Resp]] { // Iterator Helpers // ------------------------------------------------------------------------- -// IgnoreErrors filters an iterator to only yield successful responses, discarding errors. -// This is useful when you want to process only valid responses from nodes. +// IgnoreErrors returns an iterator that yields only successful responses, +// discarding any responses with errors. This is useful when you want to process +// only valid responses from nodes. 
// // Example: // -// for nodeID, resp := range IgnoreErrors(ctx.Responses()) { +// for resp := range ctx.Responses().IgnoreErrors() { // // resp is guaranteed to be a successful response // process(resp) // } -func IgnoreErrors[Resp any](seq iter.Seq2[uint32, Result[Resp]]) iter.Seq2[uint32, Resp] { - return func(yield func(uint32, Resp) bool) { - for nodeID, result := range seq { +func (seq Results[Resp]) IgnoreErrors() Results[Resp] { + return func(yield func(Result[Resp]) bool) { + for result := range seq { if result.Err == nil { - if !yield(nodeID, result.Value) { + if !yield(result) { return } } @@ -203,12 +203,28 @@ func IgnoreErrors[Resp any](seq iter.Seq2[uint32, Result[Resp]]) iter.Seq2[uint3 } } -// CollectN collects up to n successful responses from the iterator into a map. +// Filter returns an iterator that yields only the responses for which the +// provided keep function returns true. This is useful for verifying or filtering +// responses from servers before further processing. +func (seq Results[Resp]) Filter(keep func(Result[Resp]) bool) Results[Resp] { + return func(yield func(Result[Resp]) bool) { + for result := range seq { + if keep(result) { + if !yield(result) { + return + } + } + } + } +} + +// CollectN collects up to n responses from the iterator into a map by node ID. +// It includes both successful and error responses. // It returns early if n responses are collected or the iterator is exhausted. -func CollectN[Resp any](seq iter.Seq2[uint32, Resp], n int) map[uint32]Resp { +func (seq Results[Resp]) CollectN(n int) map[uint32]Resp { replies := make(map[uint32]Resp, n) - for nodeID, resp := range seq { - replies[nodeID] = resp + for result := range seq { + replies[result.NodeID] = result.Value if len(replies) >= n { break } @@ -216,10 +232,14 @@ func CollectN[Resp any](seq iter.Seq2[uint32, Resp], n int) map[uint32]Resp { return replies } -// CollectAll collects all responses from the iterator into a map. 
-// This is a convenience wrapper around maps.Collect. -func CollectAll[Resp any](seq iter.Seq2[uint32, Resp]) map[uint32]Resp { - return maps.Collect(seq) +// CollectAll collects all responses from the iterator into a map by node ID. +// It includes both successful and error responses. +func (seq Results[Resp]) CollectAll() map[uint32]Resp { + replies := make(map[uint32]Resp) + for result := range seq { + replies[result.NodeID] = result.Value + } + return replies } // ------------------------------------------------------------------------- @@ -275,7 +295,7 @@ func QuorumSpecInterceptor[Req, Resp msg, Out any]( ) QuorumInterceptor[Req, Resp, Out] { return func(_ QuorumFunc[Req, Resp, Out]) QuorumFunc[Req, Resp, Out] { return func(ctx *ClientCtx[Req, Resp]) (Out, error) { - replies := CollectAll(IgnoreErrors(ctx.Responses())) + replies := ctx.Responses().IgnoreErrors().CollectAll() resp, ok := qf(ctx.Request(), replies) if !ok { var zero Out @@ -308,9 +328,9 @@ func ThresholdQuorum[Req, Resp msg](threshold int) QuorumFunc[Req, Resp, Resp] { errs []nodeError ) - for nodeID, result := range ctx.Responses() { + for result := range ctx.Responses() { if result.Err != nil { - errs = append(errs, nodeError{nodeID: nodeID, cause: result.Err}) + errs = append(errs, nodeError{nodeID: result.NodeID, cause: result.Err}) continue } @@ -360,7 +380,7 @@ func AllResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (Resp, error) { // // This is a base quorum function that terminates the interceptor chain. 
func CollectAllResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) (map[uint32]Resp, error) { - return maps.Collect(IgnoreErrors(ctx.Responses())), nil + return ctx.Responses().CollectAll(), nil } // ------------------------------------------------------------------------- diff --git a/client_interceptor_test.go b/client_interceptor_test.go index dbfb79c5..b24783ac 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -121,7 +121,7 @@ func TestIteratorUtilities(t *testing.T) { tests := []struct { name string responses []Result[proto.Message] - operation string // "ignoreErrors", "collectN", "collectAll" + operation string // "ignoreErrors", "collectN", "collectAll", "filter" collectN int wantCount int wantFilteredIDs []uint32 @@ -143,7 +143,7 @@ func TestIteratorUtilities(t *testing.T) { name: "CollectN", responses: []Result[proto.Message]{ {NodeID: 1, Value: pb.String("response"), Err: nil}, - {NodeID: 2, Value: pb.String("response"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("error")}, {NodeID: 3, Value: pb.String("response"), Err: nil}, {NodeID: 4, Value: pb.String("response"), Err: nil}, {NodeID: 5, Value: pb.String("response"), Err: nil}, @@ -156,13 +156,22 @@ func TestIteratorUtilities(t *testing.T) { name: "CollectAll", responses: []Result[proto.Message]{ {NodeID: 1, Value: pb.String("response"), Err: nil}, - {NodeID: 2, Value: pb.String("response"), Err: nil}, + {NodeID: 2, Value: nil, Err: errors.New("error")}, {NodeID: 3, Value: pb.String("response"), Err: nil}, - {NodeID: 4, Value: pb.String("response"), Err: nil}, - {NodeID: 5, Value: pb.String("response"), Err: nil}, }, operation: "collectAll", - wantCount: 5, + wantCount: 3, + }, + { + name: "Filter", + responses: []Result[proto.Message]{ + {NodeID: 1, Value: pb.String("keep"), Err: nil}, + {NodeID: 2, Value: pb.String("drop"), Err: nil}, + {NodeID: 3, Value: pb.String("keep"), Err: nil}, + }, + operation: "filter", + wantCount: 2, + wantFilteredIDs: []uint32{1, 3}, }, 
} @@ -173,11 +182,11 @@ func TestIteratorUtilities(t *testing.T) { switch tt.operation { case "ignoreErrors": count := 0 - for nodeID, resp := range IgnoreErrors(clientCtx.Responses()) { - t.Logf("Node %d: %v", nodeID, resp.GetValue()) + for resp := range clientCtx.Responses().IgnoreErrors() { + t.Logf("Node %d: %v", resp.NodeID, resp.Value.GetValue()) count++ - if !slices.Contains(tt.wantFilteredIDs, nodeID) { - t.Errorf("Node %d should have been filtered out", nodeID) + if !slices.Contains(tt.wantFilteredIDs, resp.NodeID) { + t.Errorf("Node %d should have been filtered out", resp.NodeID) } } if count != tt.wantCount { @@ -185,16 +194,30 @@ func TestIteratorUtilities(t *testing.T) { } case "collectN": - replies := CollectN(IgnoreErrors(clientCtx.Responses()), tt.collectN) + replies := clientCtx.Responses().CollectN(tt.collectN) if len(replies) != tt.wantCount { t.Errorf("Expected %d responses, got %d", tt.wantCount, len(replies)) } case "collectAll": - replies := CollectAll(IgnoreErrors(clientCtx.Responses())) + replies := clientCtx.Responses().CollectAll() if len(replies) != tt.wantCount { t.Errorf("Expected %d responses, got %d", tt.wantCount, len(replies)) } + + case "filter": + count := 0 + for resp := range clientCtx.Responses().Filter(func(r Result[*pb.StringValue]) bool { + return r.Value.GetValue() == "keep" + }) { + count++ + if !slices.Contains(tt.wantFilteredIDs, resp.NodeID) { + t.Errorf("Node %d should have been filtered out", resp.NodeID) + } + } + if count != tt.wantCount { + t.Errorf("Expected %d responses, got %d", tt.wantCount, count) + } } }) } @@ -237,8 +260,8 @@ func TestInterceptorChaining(t *testing.T) { // It demonstrates a custom aggregation interceptor; it is used in several tests. 
var sumInterceptor = func(ctx *ClientCtx[*pb.Int32Value, *pb.Int32Value]) (*pb.Int32Value, error) { var sum int32 - for _, result := range IgnoreErrors(ctx.Responses()) { - sum += result.GetValue() + for result := range ctx.Responses().IgnoreErrors() { + sum += result.Value.GetValue() } return pb.Int32(sum), nil } @@ -276,8 +299,8 @@ func TestInterceptorCustomReturnType(t *testing.T) { customReturnInterceptor := func(ctx *ClientCtx[*pb.Int32Value, *pb.Int32Value]) (*CustomResult, error) { var total int32 var count int - for _, result := range IgnoreErrors(ctx.Responses()) { - total += result.GetValue() + for result := range ctx.Responses().IgnoreErrors() { + total += result.Value.GetValue() count++ } return &CustomResult{Total: int(total), Count: count}, nil @@ -392,8 +415,8 @@ func TestInterceptorCollectAllResponses(t *testing.T) { t.Errorf("Expected no error, got %v", err) } - if len(result) != 2 { - t.Errorf("Expected 2 responses, got %d", len(result)) + if len(result) != 3 { + t.Errorf("Expected 3 responses, got %d", len(result)) } if _, ok := result[1]; !ok { @@ -402,8 +425,8 @@ func TestInterceptorCollectAllResponses(t *testing.T) { if _, ok := result[2]; !ok { t.Error("Expected response from node 2") } - if _, ok := result[3]; ok { - t.Error("Did not expect response from node 3 (had error)") + if _, ok := result[3]; !ok { + t.Error("Expected response from node 3 (even if error)") } } From 176976e2c5fb463edf823ebc79b2dc9fac4b7837 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:00:08 -0800 Subject: [PATCH 11/16] fix: check if a quorumcall failed with deadline exceeded Tests should fail with a deadline exceeded if they block; this is an indication of a deadlock issue that needs to be investigated. 
--- client_interceptor_test.go | 74 +++++++++++++++++++++++--------------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/client_interceptor_test.go b/client_interceptor_test.go index b24783ac..f027759a 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -16,6 +16,10 @@ import ( // Test helper types and functions +// ctxTimeout is the timeout for test contexts. If this is exceeded, +// the test will fail, indicating a bug in the test or the code under test. +const ctxTimeout = 2 * time.Second + // testContext creates a context with timeout for testing. // It uses t.Context() as the parent and automatically cancels on cleanup. func testContext(t *testing.T, timeout time.Duration) context.Context { @@ -25,7 +29,22 @@ func testContext(t *testing.T, timeout time.Duration) context.Context { return ctx } -// checkError is a helper to validate error expectations in tests. +// checkQuorumCall returns true if the quorum call was successful. +// It returns false if an error occurred or the context timed out. +func checkQuorumCall(t *testing.T, ctx context.Context, err error) bool { + t.Helper() + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + t.Error(ctx.Err()) + return false + } + if err != nil { + t.Errorf("QuorumCall failed: %v", err) + return false + } + return true +} + +// checkError returns true if the error matches the expected error. 
func checkError(t *testing.T, wantErr bool, err, wantErrType error) bool { t.Helper() if wantErr { @@ -335,15 +354,16 @@ func TestInterceptorIntegration_MajorityQuorum(t *testing.T) { addrs, closeServers := TestSetup(t, 3, echoServerFn) t.Cleanup(closeServers) + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, NewConfig(t, addrs), pb.String("test"), mock.TestMethod, MajorityQuorum[*pb.StringValue, *pb.StringValue], ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } if got, want := result.GetValue(), "echo: test"; got != want { t.Errorf("Response = %q, want %q", got, want) @@ -355,17 +375,17 @@ func TestInterceptorIntegration_CustomAggregation(t *testing.T) { addrs, closeServers := TestSetup(t, 3, nil) t.Cleanup(closeServers) + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, NewConfig(t, addrs), pb.Int32(0), mock.GetValueMethod, sumInterceptor, ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } - // Expected: 10 + 20 + 30 = 60 if result.GetValue() != 60 { t.Errorf("Expected sum of 60, got %d", result.GetValue()) @@ -380,19 +400,18 @@ func TestInterceptorIntegration_Chaining(t *testing.T) { // Track interceptor execution tracker := &executionTracker{} - // Chain interceptors + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, NewConfig(t, addrs), pb.String("test"), mock.TestMethod, - MajorityQuorum[*pb.StringValue, *pb.StringValue], // Base + MajorityQuorum[*pb.StringValue, *pb.StringValue], // Base loggingInterceptor[*pb.StringValue, *pb.StringValue](tracker), // Interceptor ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } - if result.GetValue() != "echo: test" { t.Errorf("Expected 'echo: test', 
got '%s'", result.GetValue()) } @@ -444,24 +463,23 @@ func TestInterceptorIntegration_CollectAll(t *testing.T) { t.Cleanup(closeServers) config := NewConfig(t, addrs) + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, config, pb.String("test"), mock.TestMethod, CollectAllResponses[*pb.StringValue, *pb.StringValue], ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } - if len(result) != 3 { t.Errorf("Expected 3 responses, got %d", len(result)) } // Verify we got responses from all nodes in the configuration - nodes := config.Nodes() - for _, node := range nodes { + for _, node := range config.Nodes() { if _, ok := result[node.ID()]; !ok { t.Errorf("Missing response from node %d", node.ID()) } @@ -480,18 +498,18 @@ func TestInterceptorIntegration_PerNodeTransform(t *testing.T) { }, ) + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, NewConfig(t, addrs), pb.String("test"), mock.TestMethod, CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base transformInterceptor, // Interceptor ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } - if len(result) != 3 { t.Errorf("Expected 3 responses, got %d", len(result)) } @@ -526,18 +544,18 @@ func TestInterceptorIntegration_PerNodeTransformSkip(t *testing.T) { }, ) + ctx := testContext(t, ctxTimeout) result, err := QuorumCallWithInterceptor( - testContext(t, 2*time.Second), + ctx, config, pb.String("test"), mock.TestMethod, CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base transformInterceptor, // Interceptor ) - if err != nil { - t.Fatalf("QuorumCall failed: %v", err) + if !checkQuorumCall(t, ctx, err) { + return } - if len(result) != 2 { t.Errorf("Expected 2 responses (one node skipped), got %d", len(result)) } From bdee346b85196df37f90586f3901b0ae4d23927b Mon 
Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:14:46 -0800 Subject: [PATCH 12/16] fix: iterate at most expectedReplies times in Responses() --- client_interceptor.go | 24 ++++++++++++++++-------- client_interceptor_test.go | 9 +++++---- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/client_interceptor.go b/client_interceptor.go index 557124f2..ce158241 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -45,6 +45,10 @@ type ClientCtx[Req, Resp msg] struct { replyChan <-chan Result[msg] reqTransforms []func(Req, *RawNode) Req + // expectedReplies is the number of responses we expect to receive. + // It defaults to the configuration size but may be lower if nodes are skipped. + expectedReplies int + // sendOnce is called lazily on the first call to Responses(). sendOnce func() } @@ -58,12 +62,13 @@ func newClientCtx[Req, Resp msg]( replyChan <-chan Result[msg], ) *ClientCtx[Req, Resp] { return &ClientCtx[Req, Resp]{ - Context: ctx, - config: config, - request: req, - method: method, - replyChan: replyChan, - reqTransforms: nil, + Context: ctx, + config: config, + request: req, + method: method, + replyChan: replyChan, + reqTransforms: nil, + expectedReplies: config.Size(), } } @@ -149,8 +154,8 @@ func (c *ClientCtx[Req, Resp]) Responses() Results[Resp] { c.sendOnce() } return func(yield func(Result[Resp]) bool) { - // Wait for at most c.Size() responses - for range c.Size() { + // Wait for at most c.expectedReplies + for range c.expectedReplies { select { case r := <-c.replyChan: // We get a Result[proto.Message] from the channel layer's @@ -450,14 +455,17 @@ func QuorumCallWithInterceptor[Req, Resp msg, Out any]( // Create sendOnce function that will be called lazily on first Responses() call sendOnce := func() { + var expected int for _, n := range config { // Apply registered request transformations (if any) msg := clientCtx.applyTransforms(req, n) if msg == nil { continue // Skip node if transformation function
returns nil } + expected++ n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan}) } + clientCtx.expectedReplies = expected } // Wrap sendOnce with sync.OnceFunc to ensure it's only called once diff --git a/client_interceptor_test.go b/client_interceptor_test.go index f027759a..47e44c30 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -111,7 +111,7 @@ func loggingInterceptor[Req, Resp proto.Message](tracker *executionTracker) Quor } // makeClientCtx is a helper to create a ClientCtx with mock responses for unit tests. -// It creates a channel with the provided responses and returns a ClientCtx with a short timeout. +// It creates a channel with the provided responses and returns a ClientCtx. func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []Result[proto.Message]) *ClientCtx[Req, Resp] { t.Helper() @@ -127,9 +127,10 @@ func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, response } return &ClientCtx[Req, Resp]{ - Context: testContext(t, 100*time.Millisecond), - config: config, - replyChan: resultChan, + Context: t.Context(), + config: config, + replyChan: resultChan, + expectedReplies: numNodes, } } From d1c45c52b524e6d5a889279ec679aae79c8bbf56 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:17:32 -0800 Subject: [PATCH 13/16] doc: fixed test doc comment issue raised by copilot in code review --- client_interceptor_test.go | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 47e44c30..bb2805c7 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -828,9 +828,7 @@ func TestInterceptorQuorumSpecAdapter(t *testing.T) { }) } -// TestInterceptorTerminalHandlerError tests that ErrNoTerminalHandler -// is returned when no interceptor in the chain completes the call. 
-// TestInterceptorUsage demonstrates correct usage of interceptors +// TestInterceptorUsage demonstrates correct usage of interceptor chaining. func TestInterceptorUsage(t *testing.T) { t.Run("CorrectUsage", func(t *testing.T) { // Demonstrate correct usage: transform followed by aggregator From 24af70c57040b204eda6d78cc4c25684d81ba2c6 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:27:01 -0800 Subject: [PATCH 14/16] fix: checkQuorumCall: avoid passing in ctx (deepsource) Deepsource wants ctx to be the first argument, even in test helpers. --- client_interceptor_test.go | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/client_interceptor_test.go b/client_interceptor_test.go index bb2805c7..f3e5e790 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -31,10 +31,10 @@ func testContext(t *testing.T, timeout time.Duration) context.Context { // checkQuorumCall returns true if the quorum call was successful. // It returns false if an error occurred or the context timed out.
-func checkQuorumCall(t *testing.T, ctx context.Context, err error) bool { +func checkQuorumCall(t *testing.T, ctxErr, err error) bool { t.Helper() - if errors.Is(ctx.Err(), context.DeadlineExceeded) { - t.Error(ctx.Err()) + if errors.Is(ctxErr, context.DeadlineExceeded) { + t.Error(ctxErr) return false } if err != nil { @@ -363,7 +363,7 @@ func TestInterceptorIntegration_MajorityQuorum(t *testing.T) { mock.TestMethod, MajorityQuorum[*pb.StringValue, *pb.StringValue], ) - if !checkQuorumCall(t, ctx, err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } if got, want := result.GetValue(), "echo: test"; got != want { @@ -384,7 +384,7 @@ func TestInterceptorIntegration_CustomAggregation(t *testing.T) { mock.GetValueMethod, sumInterceptor, ) - if !checkQuorumCall(t, ctx, err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } // Expected: 10 + 20 + 30 = 60 @@ -410,7 +410,7 @@ func TestInterceptorIntegration_Chaining(t *testing.T) { MajorityQuorum[*pb.StringValue, *pb.StringValue], // Base loggingInterceptor[*pb.StringValue, *pb.StringValue](tracker), // Interceptor ) - if !checkQuorumCall(t, ctx, err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } if result.GetValue() != "echo: test" { @@ -472,7 +472,7 @@ func TestInterceptorIntegration_CollectAll(t *testing.T) { mock.TestMethod, CollectAllResponses[*pb.StringValue, *pb.StringValue], ) - if !checkQuorumCall(t, ctx, err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } if len(result) != 3 { @@ -508,7 +508,7 @@ func TestInterceptorIntegration_PerNodeTransform(t *testing.T) { CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base transformInterceptor, // Interceptor ) - if !checkQuorumCall(t, ctx, err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } if len(result) != 3 { @@ -554,7 +554,7 @@ func TestInterceptorIntegration_PerNodeTransformSkip(t *testing.T) { CollectAllResponses[*pb.StringValue, *pb.StringValue], // Base transformInterceptor, // Interceptor ) - if !checkQuorumCall(t, ctx, 
err) { + if !checkQuorumCall(t, ctx.Err(), err) { return } if len(result) != 2 { From 407b2d73adac6f04ab554ee977335e28d98a6a36 Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:50:38 -0800 Subject: [PATCH 15/16] doc: updated doc comments mentioned by copilot in code review --- channel.go | 14 +++++++------- client_interceptor.go | 10 +++++----- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/channel.go b/channel.go index 9f63a497..2bf4501b 100644 --- a/channel.go +++ b/channel.go @@ -14,6 +14,13 @@ import ( "google.golang.org/protobuf/proto" ) +// Result wraps a response value from node ID, and an error if any. +type Result[T any] struct { + NodeID uint32 + Value T + Err error +} + var streamDownErr = status.Error(codes.Unavailable, "stream is down") type request struct { @@ -24,13 +31,6 @@ type request struct { responseChan chan<- Result[proto.Message] } -// Result wraps a response value with its associated error and node ID. -type Result[T any] struct { - NodeID uint32 - Value T - Err error -} - type channel struct { sendQ chan request node *RawNode diff --git a/client_interceptor.go b/client_interceptor.go index ce158241..a8f8aa5d 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -223,9 +223,9 @@ func (seq Results[Resp]) Filter(keep func(Result[Resp]) bool) Results[Resp] { } } -// CollectN collects up to n responses from the iterator into a map by node ID. -// It includes both successful and error responses. -// It returns early if n responses are collected or the iterator is exhausted. +// CollectN collects up to n responses, including errors, from the iterator +// into a map by node ID. It returns early if n responses are collected or +// the iterator is exhausted. 
func (seq Results[Resp]) CollectN(n int) map[uint32]Resp { replies := make(map[uint32]Resp, n) for result := range seq { @@ -237,8 +237,8 @@ func (seq Results[Resp]) CollectN(n int) map[uint32]Resp { return replies } -// CollectAll collects all responses from the iterator into a map by node ID. -// It includes both successful and error responses. +// CollectAll collects all responses, including errors, from the iterator +// into a map by node ID. func (seq Results[Resp]) CollectAll() map[uint32]Resp { replies := make(map[uint32]Resp) for result := range seq { From 0d4f85eee2f589ca64b7563844e8c67a01811a1b Mon Sep 17 00:00:00 2001 From: Hein Meling Date: Thu, 27 Nov 2025 13:55:31 -0800 Subject: [PATCH 16/16] refactor: rename Result to NodeResponse and Results to Responses This helps to more clearly distinguish the difference between an individual NodeResponse (previously Result) from the complete set of Responses (previously Results). --- async.go | 4 +-- channel.go | 20 ++++++------- channel_test.go | 8 +++--- client_interceptor.go | 28 +++++++++--------- client_interceptor_test.go | 58 +++++++++++++++++++------------------- correctable.go | 4 +-- multicast.go | 4 +-- quorumcall.go | 2 +- rpc.go | 2 +- unicast.go | 2 +- 10 files changed, 66 insertions(+), 66 deletions(-) diff --git a/async.go b/async.go index 68cf6569..e3b7a4b3 100644 --- a/async.go +++ b/async.go @@ -37,7 +37,7 @@ func (f *Async) Done() bool { type asyncCallState struct { md *ordering.Metadata data QuorumCallData - replyChan <-chan Result[proto.Message] + replyChan <-chan NodeResponse[proto.Message] expectedReplies int } @@ -47,7 +47,7 @@ type asyncCallState struct { func (c RawConfiguration) AsyncCall(ctx context.Context, d QuorumCallData) *Async { expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan Result[proto.Message], expectedReplies) + replyChan := make(chan NodeResponse[proto.Message], expectedReplies) for _, n := range c { msg := 
d.Message diff --git a/channel.go b/channel.go index 2bf4501b..885649b2 100644 --- a/channel.go +++ b/channel.go @@ -14,8 +14,8 @@ import ( "google.golang.org/protobuf/proto" ) -// Result wraps a response value from node ID, and an error if any. -type Result[T any] struct { +// NodeResponse wraps a response value from node ID, and an error if any. +type NodeResponse[T any] struct { NodeID uint32 Value T Err error @@ -28,7 +28,7 @@ type request struct { msg *Message opts callOptions streaming bool - responseChan chan<- Result[proto.Message] + responseChan chan<- NodeResponse[proto.Message] } type channel struct { @@ -132,7 +132,7 @@ func (c *channel) enqueue(req request) { select { case <-c.parentCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(msgID, Result[proto.Message]{NodeID: c.node.ID(), Err: fmt.Errorf("node closed")}) + c.routeResponse(msgID, NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: fmt.Errorf("node closed")}) return case c.sendQ <- req: // enqueued successfully @@ -141,7 +141,7 @@ func (c *channel) enqueue(req request) { // routeResponse routes the response to the appropriate response channel based on msgID. // If no matching request is found, the response is discarded. 
-func (c *channel) routeResponse(msgID uint64, resp Result[proto.Message]) { +func (c *channel) routeResponse(msgID uint64, resp NodeResponse[proto.Message]) { c.responseMut.Lock() defer c.responseMut.Unlock() if req, ok := c.responseRouters[msgID]; ok { @@ -159,7 +159,7 @@ func (c *channel) cancelPendingMsgs() { c.responseMut.Lock() defer c.responseMut.Unlock() for msgID, req := range c.responseRouters { - req.responseChan <- Result[proto.Message]{NodeID: c.node.ID(), Err: streamDownErr} + req.responseChan <- NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: streamDownErr} // delete the router if we are only expecting a single reply message if !req.streaming { delete(c.responseRouters, msgID) @@ -188,11 +188,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Err: err}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Err: err}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: err}) } } } @@ -217,7 +217,7 @@ func (c *channel) receiver() { c.clearStream() } else { err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), Result[proto.Message]{NodeID: c.node.ID(), Value: resp.GetProtoMessage(), Err: err}) + c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Value: resp.GetProtoMessage(), Err: err}) } select { @@ -247,7 +247,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so mustWaitSendDone() returns false for them. 
if req.opts.mustWaitSendDone() && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageID(), Result[proto.Message]{}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{}) } }() diff --git a/channel_test.go b/channel_test.go index fb1ec2fe..6156794e 100644 --- a/channel_test.go +++ b/channel_test.go @@ -55,13 +55,13 @@ func newNodeWithStoppableServer(t testing.TB, delay time.Duration) (*RawNode, fu return NewNode(t, addrs[0]), teardown } -func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) Result[proto.Message] { +func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) NodeResponse[proto.Message] { t.Helper() if req.ctx == nil { req.ctx = t.Context() } req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.TestMethod), nil) - replyChan := make(chan Result[proto.Message], 1) + replyChan := make(chan NodeResponse[proto.Message], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -70,13 +70,13 @@ func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) Result[ return resp case <-time.After(defaultTestTimeout): t.Fatalf("timeout waiting for response to message %d", msgID) - return Result[proto.Message]{} + return NodeResponse[proto.Message]{} } } type msgResponse struct { msgID uint64 - resp Result[proto.Message] + resp NodeResponse[proto.Message] } func send(t testing.TB, results chan<- msgResponse, node *RawNode, goroutineID, msgsToSend int, req request) { diff --git a/client_interceptor.go b/client_interceptor.go index a8f8aa5d..88975ff1 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -32,8 +32,8 @@ type QuorumInterceptor[Req, Resp msg, Out any] func(QuorumFunc[Req, Resp, Out]) // (e.g. MajorityQuorum) that actually collects and aggregates responses. 
type QuorumFunc[Req, Resp msg, Out any] func(*ClientCtx[Req, Resp]) (Out, error) -// Results is an iterator that yields Result[T] values from a quorum call. -type Results[T msg] iter.Seq[Result[T]] +// Responses is an iterator that yields Result[T] values from a quorum call. +type Responses[T msg] iter.Seq[NodeResponse[T]] // ClientCtx provides context and access to the quorum call state for interceptors. // It exposes the request, configuration, and an iterator over node responses. @@ -42,7 +42,7 @@ type ClientCtx[Req, Resp msg] struct { config RawConfiguration request Req method string - replyChan <-chan Result[msg] + replyChan <-chan NodeResponse[msg] reqTransforms []func(Req, *RawNode) Req // expectedReplies is the number of responses we expect to receive. @@ -59,7 +59,7 @@ func newClientCtx[Req, Resp msg]( config RawConfiguration, req Req, method string, - replyChan <-chan Result[msg], + replyChan <-chan NodeResponse[msg], ) *ClientCtx[Req, Resp] { return &ClientCtx[Req, Resp]{ Context: ctx, @@ -148,12 +148,12 @@ func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *RawNode) proto.Mes // } // // Process result.Value // } -func (c *ClientCtx[Req, Resp]) Responses() Results[Resp] { +func (c *ClientCtx[Req, Resp]) Responses() Responses[Resp] { // Trigger lazy sending if c.sendOnce != nil { c.sendOnce() } - return func(yield func(Result[Resp]) bool) { + return func(yield func(NodeResponse[Resp]) bool) { // Wait for at most c.expectedReplies for range c.expectedReplies { select { @@ -161,7 +161,7 @@ func (c *ClientCtx[Req, Resp]) Responses() Results[Resp] { // We get a Result[proto.Message] from the channel layer's // response router; however, we convert it to Result[Resp] // here to match the calltype's expected response type. 
- res := Result[Resp]{ + res := NodeResponse[Resp]{ NodeID: r.NodeID, Err: r.Err, } @@ -196,8 +196,8 @@ func (c *ClientCtx[Req, Resp]) Responses() Results[Resp] { // // resp is guaranteed to be a successful response // process(resp) // } -func (seq Results[Resp]) IgnoreErrors() Results[Resp] { - return func(yield func(Result[Resp]) bool) { +func (seq Responses[Resp]) IgnoreErrors() Responses[Resp] { + return func(yield func(NodeResponse[Resp]) bool) { for result := range seq { if result.Err == nil { if !yield(result) { @@ -211,8 +211,8 @@ func (seq Results[Resp]) IgnoreErrors() Results[Resp] { // Filter returns an iterator that yields only the responses for which the // provided keep function returns true. This is useful for verifying or filtering // responses from servers before further processing. -func (seq Results[Resp]) Filter(keep func(Result[Resp]) bool) Results[Resp] { - return func(yield func(Result[Resp]) bool) { +func (seq Responses[Resp]) Filter(keep func(NodeResponse[Resp]) bool) Responses[Resp] { + return func(yield func(NodeResponse[Resp]) bool) { for result := range seq { if keep(result) { if !yield(result) { @@ -226,7 +226,7 @@ func (seq Results[Resp]) Filter(keep func(Result[Resp]) bool) Results[Resp] { // CollectN collects up to n responses, including errors, from the iterator // into a map by node ID. It returns early if n responses are collected or // the iterator is exhausted. -func (seq Results[Resp]) CollectN(n int) map[uint32]Resp { +func (seq Responses[Resp]) CollectN(n int) map[uint32]Resp { replies := make(map[uint32]Resp, n) for result := range seq { replies[result.NodeID] = result.Value @@ -239,7 +239,7 @@ func (seq Results[Resp]) CollectN(n int) map[uint32]Resp { // CollectAll collects all responses, including errors, from the iterator // into a map by node ID. 
-func (seq Results[Resp]) CollectAll() map[uint32]Resp { +func (seq Responses[Resp]) CollectAll() map[uint32]Resp { replies := make(map[uint32]Resp) for result := range seq { replies[result.NodeID] = result.Value @@ -448,7 +448,7 @@ func QuorumCallWithInterceptor[Req, Resp msg, Out any]( interceptors ...QuorumInterceptor[Req, Resp, Out], ) (Out, error) { md := ordering.NewGorumsMetadata(ctx, config.getMsgID(), method) - replyChan := make(chan Result[msg], len(config)) + replyChan := make(chan NodeResponse[msg], len(config)) // Create ClientCtx first so sendOnce can access it clientCtx := newClientCtx[Req, Resp](ctx, config, req, method, replyChan) diff --git a/client_interceptor_test.go b/client_interceptor_test.go index f3e5e790..ea60458b 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -112,10 +112,10 @@ func loggingInterceptor[Req, Resp proto.Message](tracker *executionTracker) Quor // makeClientCtx is a helper to create a ClientCtx with mock responses for unit tests. // It creates a channel with the provided responses and returns a ClientCtx. 
-func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []Result[proto.Message]) *ClientCtx[Req, Resp] { +func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []NodeResponse[proto.Message]) *ClientCtx[Req, Resp] { t.Helper() - resultChan := make(chan Result[proto.Message], len(responses)) + resultChan := make(chan NodeResponse[proto.Message], len(responses)) for _, r := range responses { resultChan <- r } @@ -140,7 +140,7 @@ func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, response func TestIteratorUtilities(t *testing.T) { tests := []struct { name string - responses []Result[proto.Message] + responses []NodeResponse[proto.Message] operation string // "ignoreErrors", "collectN", "collectAll", "filter" collectN int wantCount int @@ -148,7 +148,7 @@ func TestIteratorUtilities(t *testing.T) { }{ { name: "IgnoreErrors", - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("node error")}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -161,7 +161,7 @@ func TestIteratorUtilities(t *testing.T) { }, { name: "CollectN", - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("error")}, {NodeID: 3, Value: pb.String("response"), Err: nil}, @@ -174,7 +174,7 @@ func TestIteratorUtilities(t *testing.T) { }, { name: "CollectAll", - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("error")}, {NodeID: 3, Value: pb.String("response"), Err: nil}, @@ -184,7 +184,7 @@ func TestIteratorUtilities(t *testing.T) { }, { name: "Filter", - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: 
pb.String("keep"), Err: nil}, {NodeID: 2, Value: pb.String("drop"), Err: nil}, {NodeID: 3, Value: pb.String("keep"), Err: nil}, @@ -227,7 +227,7 @@ func TestIteratorUtilities(t *testing.T) { case "filter": count := 0 - for resp := range clientCtx.Responses().Filter(func(r Result[*pb.StringValue]) bool { + for resp := range clientCtx.Responses().Filter(func(r NodeResponse[*pb.StringValue]) bool { return r.Value.GetValue() == "keep" }) { count++ @@ -252,7 +252,7 @@ func TestInterceptorChaining(t *testing.T) { tracker := &executionTracker{} // Create mock responses - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -289,7 +289,7 @@ var sumInterceptor = func(ctx *ClientCtx[*pb.Int32Value, *pb.Int32Value]) (*pb.I // TestInterceptorCustomAggregation demonstrates custom interceptor for aggregation func TestInterceptorCustomAggregation(t *testing.T) { t.Run("SumAggregation", func(t *testing.T) { - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.Int32(10), Err: nil}, {NodeID: 2, Value: pb.Int32(20), Err: nil}, {NodeID: 3, Value: pb.Int32(30), Err: nil}, @@ -326,7 +326,7 @@ func TestInterceptorCustomReturnType(t *testing.T) { return &CustomResult{Total: int(total), Count: count}, nil } - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.Int32(10), Err: nil}, {NodeID: 2, Value: pb.Int32(20), Err: nil}, {NodeID: 3, Value: pb.Int32(30), Err: nil}, @@ -423,7 +423,7 @@ func TestInterceptorIntegration_Chaining(t *testing.T) { // TestInterceptorCollectAllResponses tests the CollectAllResponses interceptor func TestInterceptorCollectAllResponses(t *testing.T) { - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: 
pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -579,7 +579,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name string quorumFunc QuorumFunc[*pb.StringValue, *pb.StringValue, *pb.StringValue] numNodes int - responses []Result[proto.Message] + responses []NodeResponse[proto.Message] wantErr bool wantErrType error wantValue string @@ -588,7 +588,7 @@ func TestBaseQuorumFunctions(t *testing.T) { { name: "FirstResponse_Success", quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: pb.String("third"), Err: nil}, @@ -599,7 +599,7 @@ func TestBaseQuorumFunctions(t *testing.T) { { name: "FirstResponse_AfterErrors", quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: nil, Err: errors.New("error1")}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: pb.String("third"), Err: nil}, @@ -610,7 +610,7 @@ func TestBaseQuorumFunctions(t *testing.T) { { name: "FirstResponse_AllErrors", quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: nil, Err: errors.New("error1")}, {NodeID: 2, Value: nil, Err: errors.New("error2")}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -621,7 +621,7 @@ func TestBaseQuorumFunctions(t *testing.T) { { name: "FirstResponse_NoResponses", quorumFunc: FirstResponse[*pb.StringValue, *pb.StringValue], - responses: []Result[proto.Message]{}, + responses: []NodeResponse[proto.Message]{}, wantErr: true, wantErrType: ErrIncomplete, }, @@ -631,7 +631,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: 
"AllResponses_AllSuccess", quorumFunc: AllResponses[*pb.StringValue, *pb.StringValue], numNodes: 3, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: pb.String("third"), Err: nil}, @@ -643,7 +643,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "AllResponses_OneError", quorumFunc: AllResponses[*pb.StringValue, *pb.StringValue], numNodes: 3, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("error2")}, {NodeID: 3, Value: pb.String("third"), Err: nil}, @@ -657,7 +657,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "MajorityQuorum_Success", quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], numNodes: 5, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -671,7 +671,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "MajorityQuorum_Insufficient", quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], numNodes: 5, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -685,7 +685,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "MajorityQuorum_Exact", quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], numNodes: 3, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error")}, @@ -697,7 +697,7 @@ func 
TestBaseQuorumFunctions(t *testing.T) { name: "MajorityQuorum_Even_Success", quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], numNodes: 4, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: pb.String("third"), Err: nil}, @@ -710,7 +710,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "MajorityQuorum_Even_Insufficient", quorumFunc: MajorityQuorum[*pb.StringValue, *pb.StringValue], numNodes: 4, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -725,7 +725,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "ThresholdQuorum_Met", quorumFunc: ThresholdQuorum[*pb.StringValue, *pb.StringValue](2), numNodes: 3, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -737,7 +737,7 @@ func TestBaseQuorumFunctions(t *testing.T) { name: "ThresholdQuorum_NotMet", quorumFunc: ThresholdQuorum[*pb.StringValue, *pb.StringValue](3), numNodes: 3, - responses: []Result[proto.Message]{ + responses: []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("error2")}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -783,7 +783,7 @@ func TestInterceptorQuorumSpecAdapter(t *testing.T) { return nil, false } - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: pb.String("second"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -812,7 +812,7 @@ func 
TestInterceptorQuorumSpecAdapter(t *testing.T) { return nil, false } - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("first"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("error2")}, {NodeID: 3, Value: nil, Err: errors.New("error3")}, @@ -836,7 +836,7 @@ func TestInterceptorUsage(t *testing.T) { return pb.String(req.GetValue() + "-transformed") } - responses := []Result[proto.Message]{ + responses := []NodeResponse[proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, diff --git a/correctable.go b/correctable.go index 5d2584f7..07212d8d 100644 --- a/correctable.go +++ b/correctable.go @@ -93,7 +93,7 @@ type CorrectableCallData struct { type correctableCallState struct { md *ordering.Metadata data CorrectableCallData - replyChan <-chan Result[proto.Message] + replyChan <-chan NodeResponse[proto.Message] expectedReplies int } @@ -104,7 +104,7 @@ func (c RawConfiguration) CorrectableCall(ctx context.Context, d CorrectableCall expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan Result[proto.Message], expectedReplies) + replyChan := make(chan NodeResponse[proto.Message], expectedReplies) for _, n := range c { msg := d.Message if d.PerNodeArgFn != nil { diff --git a/multicast.go b/multicast.go index c2856aa8..ab4cdfbd 100644 --- a/multicast.go +++ b/multicast.go @@ -22,9 +22,9 @@ func (c RawConfiguration) Multicast(ctx context.Context, d QuorumCallData, opts md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) sentMsgs := 0 - var replyChan chan Result[proto.Message] + var replyChan chan NodeResponse[proto.Message] if o.waitSendDone { - replyChan = make(chan Result[proto.Message], len(c)) + replyChan = make(chan NodeResponse[proto.Message], len(c)) } for _, n := range c { msg := d.Message diff --git 
a/quorumcall.go b/quorumcall.go index 99d91c3c..e208a8f8 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -26,7 +26,7 @@ func (c RawConfiguration) QuorumCall(ctx context.Context, d QuorumCallData) (res expectedReplies := len(c) md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method) - replyChan := make(chan Result[proto.Message], expectedReplies) + replyChan := make(chan NodeResponse[proto.Message], expectedReplies) for _, n := range c { msg := d.Message if d.PerNodeArgFn != nil { diff --git a/rpc.go b/rpc.go index d5031cea..ddc0428f 100644 --- a/rpc.go +++ b/rpc.go @@ -20,7 +20,7 @@ type CallData struct { // This method should be used by generated code only. func (n *RawNode) RPCCall(ctx context.Context, d CallData) (proto.Message, error) { md := ordering.NewGorumsMetadata(ctx, n.mgr.getMsgID(), d.Method) - replyChan := make(chan Result[proto.Message], 1) + replyChan := make(chan NodeResponse[proto.Message], 1) n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message), responseChan: replyChan}) select { diff --git a/unicast.go b/unicast.go index ad7ffb59..cdf5c3c2 100644 --- a/unicast.go +++ b/unicast.go @@ -28,7 +28,7 @@ func (n *RawNode) Unicast(ctx context.Context, d CallData, opts ...CallOption) { } // Default: block until send completes - replyChan := make(chan Result[proto.Message], 1) + replyChan := make(chan NodeResponse[proto.Message], 1) n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, d.Message), opts: o, responseChan: replyChan}) <-replyChan }