async.go (14 changes: 7 additions & 7 deletions)
@@ -37,7 +37,7 @@ func (f *Async) Done() bool {
type asyncCallState struct {
md *ordering.Metadata
data QuorumCallData
replyChan <-chan response
replyChan <-chan NodeResponse[proto.Message]
expectedReplies int
}

@@ -47,7 +47,7 @@ type asyncCallState struct {
func (c RawConfiguration) AsyncCall(ctx context.Context, d QuorumCallData) *Async {
expectedReplies := len(c)
md := ordering.NewGorumsMetadata(ctx, c.getMsgID(), d.Method)
replyChan := make(chan response, expectedReplies)
replyChan := make(chan NodeResponse[proto.Message], expectedReplies)

for _, n := range c {
msg := d.Message
@@ -58,7 +58,7 @@ func (c RawConfiguration) AsyncCall(ctx context.Context, d QuorumCallData) *Asyn
continue // don't send if no msg
}
}
n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg)}, replyChan)
n.channel.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan})
}

fut := &Async{c: make(chan struct{}, 1)}
@@ -86,11 +86,11 @@ func (RawConfiguration) handleAsyncCall(ctx context.Context, fut *Async, state a
for {
select {
case r := <-state.replyChan:
if r.err != nil {
errs = append(errs, nodeError{nodeID: r.nid, cause: r.err})
if r.Err != nil {
errs = append(errs, nodeError{nodeID: r.NodeID, cause: r.Err})
break
}
replies[r.nid] = r.msg
replies[r.NodeID] = r.Value
if resp, quorum = state.data.QuorumFunction(state.data.Message, replies); quorum {
fut.reply, fut.err = resp, nil
return
@@ -100,7 +100,7 @@ func (RawConfiguration) handleAsyncCall(ctx context.Context, fut *Async, state a
return
}
if len(errs)+len(replies) == state.expectedReplies {
fut.reply, fut.err = resp, QuorumCallError{cause: Incomplete, errors: errs, replies: len(replies)}
fut.reply, fut.err = resp, QuorumCallError{cause: ErrIncomplete, errors: errs, replies: len(replies)}
return
}
}
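For context: the reply loop in handleAsyncCall now reads the exported NodeID, Value, and Err fields of the generic NodeResponse wrapper instead of the unexported nid/msg/err fields of the old response struct. Below is a minimal, self-contained sketch of that consumption pattern; the wrapper is redeclared locally and the payloads are plain strings standing in for proto.Message, so this illustrates the field access only, not the gorums implementation.

```go
package main

import "fmt"

// NodeResponse is a local stand-in for the generic wrapper added in this PR;
// the real type lives in the gorums package and carries proto.Message values.
type NodeResponse[T any] struct {
	NodeID uint32
	Value  T
	Err    error
}

func main() {
	replyChan := make(chan NodeResponse[string], 3)
	replyChan <- NodeResponse[string]{NodeID: 1, Value: "ok"}
	replyChan <- NodeResponse[string]{NodeID: 2, Err: fmt.Errorf("stream is down")}
	replyChan <- NodeResponse[string]{NodeID: 3, Value: "ok"}

	replies := make(map[uint32]string)
	var errs []error
	for range 3 {
		r := <-replyChan
		if r.Err != nil {
			// failures are reported via the exported Err field
			errs = append(errs, fmt.Errorf("node %d: %w", r.NodeID, r.Err))
			continue
		}
		replies[r.NodeID] = r.Value
	}
	fmt.Printf("replies=%d errs=%d\n", len(replies), len(errs))
}
```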
channel.go (77 changes: 37 additions & 40 deletions)
@@ -14,24 +14,21 @@ import (
"google.golang.org/protobuf/proto"
)

var streamDownErr = status.Error(codes.Unavailable, "stream is down")

type request struct {
ctx context.Context
msg *Message
opts callOptions
streaming bool
// NodeResponse wraps a response value from the node with the given ID, and an error, if any.
type NodeResponse[T any] struct {
NodeID uint32
Value T
Err error
}

type response struct {
nid uint32
msg proto.Message
err error
}
var streamDownErr = status.Error(codes.Unavailable, "stream is down")

type responseRouter struct {
c chan<- response
streaming bool
type request struct {
ctx context.Context
msg *Message
opts callOptions
streaming bool
responseChan chan<- NodeResponse[proto.Message]
}

type channel struct {
@@ -50,8 +47,10 @@ type channel struct {
streamCtx context.Context
cancelStream context.CancelFunc

// Response routing
responseRouters map[uint64]responseRouter
// Response routing; the map holds pending requests waiting for responses.
// The request contains the responseChan on which to send the response
// to the caller.
responseRouters map[uint64]request
responseMut sync.Mutex

// Lifecycle management: node close() cancels this context
@@ -70,7 +69,7 @@ func newChannel(n *RawNode) *channel {
sendQ: make(chan request, n.mgr.opts.sendBuffer),
node: n,
latency: -1 * time.Second,
responseRouters: make(map[uint64]responseRouter),
responseRouters: make(map[uint64]request),
}
// parentCtx controls the channel and is used to shut it down
c.parentCtx = n.newContext()
@@ -119,52 +118,50 @@ func (c *channel) isConnected() bool {
return c.node.conn.GetState() == connectivity.Ready && c.getStream() != nil
}

// enqueue adds the request to the send queue and sets up a response router if needed.
// enqueue adds the request to the send queue and sets up response routing if needed.
// If the node is closed, it responds with an error instead.
func (c *channel) enqueue(req request, responseChan chan<- response) {
func (c *channel) enqueue(req request) {
msgID := req.msg.GetMessageID()
if responseChan != nil {
// allocate before critical section
router := responseRouter{responseChan, req.streaming}
if req.responseChan != nil {
c.responseMut.Lock()
c.responseRouters[msgID] = router
c.responseRouters[msgID] = req
c.responseMut.Unlock()
}
// either enqueue the request on the sendQ or respond
// with error if the node is closed
select {
case <-c.parentCtx.Done():
// the node's close() method was called: respond with error instead of enqueueing
c.routeResponse(msgID, response{nid: c.node.ID(), err: fmt.Errorf("node closed")})
c.routeResponse(msgID, NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: fmt.Errorf("node closed")})
return
case c.sendQ <- req:
// enqueued successfully
}
}

// routeResponse routes the response to the appropriate response router based on msgID.
// If no router is found, the response is discarded.
func (c *channel) routeResponse(msgID uint64, resp response) {
// routeResponse routes the response to the appropriate response channel based on msgID.
// If no matching request is found, the response is discarded.
func (c *channel) routeResponse(msgID uint64, resp NodeResponse[proto.Message]) {
c.responseMut.Lock()
defer c.responseMut.Unlock()
if router, ok := c.responseRouters[msgID]; ok {
router.c <- resp
if req, ok := c.responseRouters[msgID]; ok {
req.responseChan <- resp
// delete the router if we are only expecting a single reply message
if !router.streaming {
if !req.streaming {
delete(c.responseRouters, msgID)
}
}
}

// cancelPendingMsgs cancels all pending messages by sending an error response to each router.
// This is typically called when the stream goes down to notify all waiting calls.
// cancelPendingMsgs cancels all pending messages by sending an error response to each
// associated request. This is called when the stream goes down to notify all waiting calls.
func (c *channel) cancelPendingMsgs() {
c.responseMut.Lock()
defer c.responseMut.Unlock()
for msgID, router := range c.responseRouters {
router.c <- response{nid: c.node.ID(), err: streamDownErr}
for msgID, req := range c.responseRouters {
req.responseChan <- NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: streamDownErr}
// delete the router if we are only expecting a single reply message
if !router.streaming {
if !req.streaming {
delete(c.responseRouters, msgID)
}
}
@@ -191,11 +188,11 @@ func (c *channel) sender() {
// take next request from sendQ
}
if err := c.ensureStream(); err != nil {
c.routeResponse(req.msg.GetMessageID(), response{nid: c.node.ID(), err: err})
c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: err})
continue
}
if err := c.sendMsg(req); err != nil {
c.routeResponse(req.msg.GetMessageID(), response{nid: c.node.ID(), err: err})
c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Err: err})
}
}
}
@@ -220,7 +217,7 @@ func (c *channel) receiver() {
c.clearStream()
} else {
err := resp.GetStatus().Err()
c.routeResponse(resp.GetMessageID(), response{nid: c.node.ID(), msg: resp.GetProtoMessage(), err: err})
c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.node.ID(), Value: resp.GetProtoMessage(), Err: err})
}

select {
@@ -250,7 +247,7 @@ func (c *channel) sendMsg(req request) (err error) {
// wait for actual server responses, so mustWaitSendDone() returns false for them.
if req.opts.mustWaitSendDone() && err == nil {
// Send succeeded: unblock the caller and clean up the responseRouter
c.routeResponse(req.msg.GetMessageID(), response{})
c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{})
}
}()

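For context: the separate responseRouter type is gone; the request now carries its own responseChan, and the responseRouters map stores the pending requests themselves, keyed by message ID. The sketch below illustrates that routing shape with simplified, illustrative types (string payloads, no send queue, stream handling, or node lifecycle); it is not the actual gorums code.

```go
package main

import (
	"fmt"
	"sync"
)

// Simplified stand-ins; names mirror the PR, but payloads are strings and
// everything unrelated to response routing is omitted.
type nodeResponse struct {
	NodeID uint32
	Value  string
	Err    error
}

type request struct {
	msgID        uint64
	streaming    bool
	responseChan chan<- nodeResponse
}

type channel struct {
	mu      sync.Mutex
	pending map[uint64]request // pending requests waiting for a response
}

// enqueue registers the request for response routing; the real code also
// pushes the request onto the send queue or fails if the node is closed.
func (c *channel) enqueue(req request) {
	if req.responseChan != nil {
		c.mu.Lock()
		c.pending[req.msgID] = req
		c.mu.Unlock()
	}
}

// routeResponse delivers a response to the channel stored in the pending
// request and removes the entry unless more (streaming) replies are expected.
func (c *channel) routeResponse(msgID uint64, resp nodeResponse) {
	c.mu.Lock()
	defer c.mu.Unlock()
	if req, ok := c.pending[msgID]; ok {
		req.responseChan <- resp
		if !req.streaming {
			delete(c.pending, msgID)
		}
	}
}

func main() {
	ch := &channel{pending: make(map[uint64]request)}
	replies := make(chan nodeResponse, 1)
	ch.enqueue(request{msgID: 1, responseChan: replies})
	ch.routeResponse(1, nodeResponse{NodeID: 7, Value: "pong"})
	fmt.Println((<-replies).Value) // pong
}
```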
channel_test.go (53 changes: 27 additions & 26 deletions)
@@ -55,27 +55,28 @@ func newNodeWithStoppableServer(t testing.TB, delay time.Duration) (*RawNode, fu
return NewNode(t, addrs[0]), teardown
}

func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) response {
func sendRequest(t testing.TB, node *RawNode, req request, msgID uint64) NodeResponse[proto.Message] {
t.Helper()
if req.ctx == nil {
req.ctx = t.Context()
}
req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.TestMethod), nil)
replyChan := make(chan response, 1)
node.channel.enqueue(req, replyChan)
replyChan := make(chan NodeResponse[proto.Message], 1)
req.responseChan = replyChan
node.channel.enqueue(req)

select {
case resp := <-replyChan:
return resp
case <-time.After(defaultTestTimeout):
t.Fatalf("timeout waiting for response to message %d", msgID)
return response{}
return NodeResponse[proto.Message]{}
}
}

type msgResponse struct {
msgID uint64
resp response
resp NodeResponse[proto.Message]
}

func send(t testing.TB, results chan<- msgResponse, node *RawNode, goroutineID, msgsToSend int, req request) {
@@ -104,7 +105,7 @@ func TestChannelCreation(t *testing.T) {

// send message when server is down
resp := sendRequest(t, node, request{opts: waitSendDone}, 1)
if resp.err == nil {
if resp.Err == nil {
t.Error("response err: got <nil>, want error")
}
}
@@ -122,8 +123,8 @@ func TestChannelShutdown(t *testing.T) {
for i := range numMessages {
wg.Go(func() {
resp := sendRequest(t, node, request{}, uint64(i))
if resp.err != nil {
t.Errorf("unexpected error for message %d, got error: %v", i, resp.err)
if resp.Err != nil {
t.Errorf("unexpected error for message %d, got error: %v", i, resp.Err)
}
})
}
@@ -136,10 +137,10 @@ func TestChannelShutdown(t *testing.T) {

// try to send a message after node closure
resp := sendRequest(t, node, request{}, 999)
if resp.err == nil {
if resp.Err == nil {
t.Error("expected error when sending to closed channel")
} else if resp.err.Error() != "node closed" {
t.Errorf("expected 'node closed' error, got: %v", resp.err)
} else if resp.Err.Error() != "node closed" {
t.Errorf("expected 'node closed' error, got: %v", resp.Err)
}

if node.channel.isConnected() {
@@ -163,8 +164,8 @@ func TestChannelSendCompletionWaiting(t *testing.T) {
start := time.Now()
resp := sendRequest(t, node, request{opts: tt.opts}, uint64(i))
elapsed := time.Since(start)
if resp.err != nil {
t.Errorf("unexpected error: %v", resp.err)
if resp.Err != nil {
t.Errorf("unexpected error: %v", resp.Err)
}
t.Logf("response received in %v", elapsed)
})
@@ -214,8 +215,8 @@ func TestChannelErrors(t *testing.T) {
setup: func(t *testing.T) *RawNode {
node, stopServer := newNodeWithStoppableServer(t, 0)
resp := sendRequest(t, node, request{opts: waitSendDone}, 1)
if resp.err != nil {
t.Errorf("first message should succeed, got error: %v", resp.err)
if resp.Err != nil {
t.Errorf("first message should succeed, got error: %v", resp.Err)
}
stopServer()
return node
@@ -230,10 +231,10 @@ func TestChannelErrors(t *testing.T) {

// Send message and verify error
resp := sendRequest(t, node, request{opts: waitSendDone}, uint64(i))
if resp.err == nil {
if resp.Err == nil {
t.Errorf("expected error '%s' but got nil", tt.wantErr)
} else if !strings.Contains(resp.err.Error(), tt.wantErr) {
t.Errorf("expected error '%s', got: %v", tt.wantErr, resp.err)
} else if !strings.Contains(resp.Err.Error(), tt.wantErr) {
t.Errorf("expected error '%s', got: %v", tt.wantErr, resp.Err)
}
})
}
@@ -401,8 +402,8 @@ func TestChannelConcurrentSends(t *testing.T) {
var errs []error
for range numMessages {
res := <-results
if res.resp.err != nil {
errs = append(errs, res.resp.err)
if res.resp.Err != nil {
errs = append(errs, res.resp.Err)
}
}

@@ -477,8 +478,8 @@ func TestChannelContext(t *testing.T) {

node := newNodeWithServer(t, tt.serverDelay)
resp := sendRequest(t, node, request{ctx: ctx, opts: tt.callOpts}, uint64(i))
if !errors.Is(resp.err, tt.wantErr) {
t.Errorf("expected %v, got: %v", tt.wantErr, resp.err)
if !errors.Is(resp.Err, tt.wantErr) {
t.Errorf("expected %v, got: %v", tt.wantErr, resp.Err)
}
})
}
@@ -578,8 +579,8 @@ func TestChannelRouterLifecycle(t *testing.T) {
t.Run(name, func(t *testing.T) {
msgID := uint64(i)
resp := sendRequest(t, node, request{opts: tt.opts, streaming: tt.streaming}, msgID)
if resp.err != nil {
t.Errorf("unexpected error: %v", resp.err)
if resp.Err != nil {
t.Errorf("unexpected error: %v", resp.Err)
}
if exists := routerExists(node, msgID); exists != tt.wantRouter {
t.Errorf("router exists = %v, want %v", exists, tt.wantRouter)
@@ -605,8 +606,8 @@ func TestChannelResponseRouting(t *testing.T) {
received := make(map[uint64]bool)
for range numMessages {
result := <-results
if result.resp.err != nil {
t.Errorf("message %d got error: %v", result.msgID, result.resp.err)
if result.resp.Err != nil {
t.Errorf("message %d got error: %v", result.msgID, result.resp.Err)
}
if received[result.msgID] {
t.Errorf("message %d received twice", result.msgID)
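For context: the test helper sendRequest now attaches the reply channel to the request (req.responseChan) before calling enqueue, and callers assert on the exported Err field of the result. A minimal sketch of that wait-with-timeout pattern is below; the wrapper is redeclared locally, the timeout value is illustrative, and the helper name awaitReply is a hypothetical stand-in for the package's test code.

```go
package main

import (
	"fmt"
	"time"
)

// NodeResponse is a local stand-in for the generic wrapper from the PR.
type NodeResponse[T any] struct {
	NodeID uint32
	Value  T
	Err    error
}

// awaitReply mirrors the helper's shape: read from the reply channel that was
// attached to the request, or give up after a timeout.
func awaitReply(replyChan <-chan NodeResponse[string], timeout time.Duration) (NodeResponse[string], error) {
	select {
	case resp := <-replyChan:
		return resp, nil
	case <-time.After(timeout):
		return NodeResponse[string]{}, fmt.Errorf("timeout waiting for response")
	}
}

func main() {
	replyChan := make(chan NodeResponse[string], 1)
	replyChan <- NodeResponse[string]{NodeID: 3, Value: "ok"}

	resp, err := awaitReply(replyChan, time.Second)
	if err != nil || resp.Err != nil {
		fmt.Println("error:", err, resp.Err)
		return
	}
	fmt.Printf("node %d replied: %s\n", resp.NodeID, resp.Value)
}
```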