diff --git a/comm.go b/comm.go
index af0ba207..47e43912 100644
--- a/comm.go
+++ b/comm.go
@@ -165,7 +165,7 @@ func (p *PubSub) handlePeerDead(s network.Stream) {
 }
 
 func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) {
-	writeRpc := func(rpc *RPC) error {
+	writeRpc := func(rpc *pb.RPC) error {
 		size := uint64(rpc.Size())
 
 		buf := pool.Get(varint.UvarintSize(size) + int(size))
@@ -193,8 +193,11 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
 			p.logger.Debug("error popping message from the queue to send to peer", "peer", s.Conn().RemotePeer(), "err", err)
 			return
 		}
+		if rpc.Size() == 0 {
+			continue
+		}
 
-		err = writeRpc(rpc)
+		err = writeRpc(&rpc.RPC)
 		if err != nil {
 			s.Reset()
 			p.logger.Debug("error writing message to peer", "peer", s.Conn().RemotePeer(), "err", err)
@@ -202,45 +205,3 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, ou
 		}
 	}
 }
-
-func rpcWithSubs(subs ...*pb.RPC_SubOpts) *RPC {
-	return &RPC{
-		RPC: pb.RPC{
-			Subscriptions: subs,
-		},
-	}
-}
-
-func rpcWithMessages(msgs ...*pb.Message) *RPC {
-	return &RPC{RPC: pb.RPC{Publish: msgs}}
-}
-
-func rpcWithControl(msgs []*pb.Message,
-	ihave []*pb.ControlIHave,
-	iwant []*pb.ControlIWant,
-	graft []*pb.ControlGraft,
-	prune []*pb.ControlPrune,
-	idontwant []*pb.ControlIDontWant) *RPC {
-	return &RPC{
-		RPC: pb.RPC{
-			Publish: msgs,
-			Control: &pb.ControlMessage{
-				Ihave:     ihave,
-				Iwant:     iwant,
-				Graft:     graft,
-				Prune:     prune,
-				Idontwant: idontwant,
-			},
-		},
-	}
-}
-
-func copyRPC(rpc *RPC) *RPC {
-	res := new(RPC)
-	*res = *rpc
-	if rpc.Control != nil {
-		res.Control = new(pb.ControlMessage)
-		*res.Control = *rpc.Control
-	}
-	return res
-}
diff --git a/gossipsub.go b/gossipsub.go
index c492ded9..fdfa94d6 100644
--- a/gossipsub.go
+++ b/gossipsub.go
@@ -1163,6 +1163,9 @@ func (gs *GossipSubRouter) handleIDontWant(p peer.ID, ctl *pb.ControlMessage) {
 	gs.peerdontwant[p]++
 
 	totalUnwantedIds := 0
+	// Collect message IDs for cancellation
+	var msgIDsToCancel []string
+
 	// Remember all the unwanted message ids
 mainIDWLoop:
 	for _, idontwant := range ctl.GetIdontwant() {
@@ -1175,8 +1178,14 @@ mainIDWLoop:
 
 			totalUnwantedIds++
 			gs.unwanted[p][computeChecksum(mid)] = gs.params.IDontWantMessageTTL
+			msgIDsToCancel = append(msgIDsToCancel, mid)
 		}
 	}
+
+	// Cancel these messages in the RPC queue if it exists
+	if queue, ok := gs.p.peers[p]; ok && len(msgIDsToCancel) > 0 {
+		queue.CancelMessages(msgIDsToCancel)
+	}
 }
 
 func (gs *GossipSubRouter) addBackoff(p peer.ID, topic string, isUnsubscribe bool) {
@@ -1316,6 +1325,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
 			return
 		}
 
+		msgID := gs.p.idGen.ID(msg)
 		if gs.floodPublish && from == gs.p.host.ID() {
 			for p := range tmap {
 				_, direct := gs.direct[p]
@@ -1359,7 +1369,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
 				gs.lastpub[topic] = time.Now().UnixNano()
 			}
 
-			csum := computeChecksum(gs.p.idGen.ID(msg))
+			csum := computeChecksum(msgID)
 			for p := range gmap {
 				// Check if it has already received an IDONTWANT for the message.
 				// If so, don't send it to the peer
@@ -1370,7 +1380,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
 			}
 		}
 
-		out := rpcWithMessages(msg.Message)
+		out := rpcWithMessageAndMsgID(msg.Message, msgID)
 		for pid := range tosend {
 			if pid == from || pid == peer.ID(msg.GetFrom()) {
 				continue
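Taken together, the comm.go and gossipsub.go hunks above establish the cancellation flow: the IDs named by an incoming IDONTWANT are collected into a single batch and handed to that peer's outgoing queue via CancelMessages, and because cancellation can leave a queued RPC with nothing left to send, the writer loop now skips RPCs whose Size() is zero instead of writing an empty frame. A self-contained sketch of the batching shape follows; rpcQueueStub and the peer map are hypothetical stand-ins, not the library's types.

package main

import "fmt"

// rpcQueueStub is a hypothetical stand-in for the per-peer rpcQueue.
type rpcQueueStub struct{ cancelled []string }

func (q *rpcQueueStub) CancelMessages(ids []string) {
	q.cancelled = append(q.cancelled, ids...)
}

func main() {
	queues := map[string]*rpcQueueStub{"peerA": {}}
	idontwant := []string{"id1", "", "id2"} // IDs named by a control message

	// Gather every usable ID first, then cancel in one batch, so the
	// queue's lock is taken once per IDONTWANT rather than once per ID.
	var msgIDsToCancel []string
	for _, mid := range idontwant {
		if mid == "" {
			continue
		}
		msgIDsToCancel = append(msgIDsToCancel, mid)
	}
	if q, ok := queues["peerA"]; ok && len(msgIDsToCancel) > 0 {
		q.CancelMessages(msgIDsToCancel)
	}
	fmt.Println(queues["peerA"].cancelled) // [id1 id2]
}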
diff --git a/gossipsub_spam_test.go b/gossipsub_spam_test.go
index e42b8859..b6356d5d 100644
--- a/gossipsub_spam_test.go
+++ b/gossipsub_spam_test.go
@@ -930,20 +930,25 @@ func TestGossipsubHandleIDontwantSpam(t *testing.T) {
 	rPid := hosts[1].ID()
 	ctrlMessage := &pb.ControlMessage{Idontwant: []*pb.ControlIDontWant{{MessageIDs: idwIds}}}
 	grt := psubs[0].rt.(*GossipSubRouter)
-	grt.handleIDontWant(rPid, ctrlMessage)
+	completed := make(chan struct{})
+	psubs[0].eval <- func() {
+		grt.handleIDontWant(rPid, ctrlMessage)
 
-	if grt.peerdontwant[rPid] != 1 {
-		t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
-	}
-	mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
-	if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
-		t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
-	}
+		if grt.peerdontwant[rPid] != 1 {
+			t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
+		}
+		mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
+		if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
+			t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
+		}
 
-	mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
-	if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
-		t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
+		mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
+		if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
+			t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
+		}
+		close(completed)
 	}
+	<-completed
 }
 
 type mockGSOnRead func(writeMsg func(*pb.RPC), irpc *pb.RPC)
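The rewritten test above is not just cosmetic: handleIDontWant now reaches into the peer's rpcQueue through gs.p.peers, state owned by the pubsub event loop, so calling it directly from the test goroutine would race. Sending the call and its assertions as a closure over psubs[0].eval serializes them with the loop. A minimal, runnable sketch of that pattern, with the channel and loop as stand-ins for the library's internals:

package main

import "fmt"

func main() {
	eval := make(chan func()) // stand-in for the pubsub event loop's eval channel
	go func() {
		for fn := range eval {
			fn() // every closure runs on the goroutine that owns router state
		}
	}()

	completed := make(chan struct{})
	eval <- func() {
		// read or assert on loop-owned state here, race-free
		fmt.Println("assertions ran on the loop goroutine")
		close(completed)
	}
	<-completed // don't let the test return before the assertions have run
}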
diff --git a/pubsub.go b/pubsub.go
index 03ada8ef..d0a59971 100644
--- a/pubsub.go
+++ b/pubsub.go
@@ -257,6 +257,8 @@ func (m *Message) GetFrom() peer.ID {
 
 type RPC struct {
 	pb.RPC
+	// MessageIDs are the ids of the messages in the rpc. MessageIDs[i] = id(rpc.Publish[i])
+	MessageIDs []string
 
 	// unexported on purpose, not sending this over the wire
 	from peer.ID
@@ -274,6 +276,12 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 
 		messagesInNextRPC := 0
 		messageSlice := rpc.Publish
+		// This slice may be smaller than messageSlice.
+		// Only use len(messageIDSlice) elements.
+		//
+		// In most cases it is either len(rpc.Publish) long or empty, in which
+		// case the split RPCs inherit the same property.
+		messageIDSlice := rpc.MessageIDs
 
 		// Merge/Append publish messages. This pattern is optimized compared the
 		// the patterns for other fields because this is the common cause for
@@ -286,6 +294,14 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 				// into this RPC, yield it, then make a new one
 				nextRPC.Publish = messageSlice[:messagesInNextRPC]
 				messageSlice = messageSlice[messagesInNextRPC:]
+				if len(messageIDSlice) >= messagesInNextRPC {
+					nextRPC.MessageIDs = messageIDSlice[:messagesInNextRPC]
+					messageIDSlice = messageIDSlice[messagesInNextRPC:]
+				} else {
+					nextRPC.MessageIDs = messageIDSlice
+					messageIDSlice = messageIDSlice[len(messageIDSlice):]
+				}
+
 				if !yield(nextRPC) {
 					return
 				}
@@ -303,6 +319,11 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 			// packing this RPC, but we avoid successively calling .Size()
 			// on the messages for the next parts.
 			nextRPC.Publish = messageSlice[:messagesInNextRPC]
+			if len(messageIDSlice) >= messagesInNextRPC {
+				nextRPC.MessageIDs = messageIDSlice[:messagesInNextRPC]
+			} else {
+				nextRPC.MessageIDs = messageIDSlice
+			}
 			if !yield(nextRPC) {
 				return
 			}
@@ -438,6 +459,52 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
 	}
 }
 
+func rpcWithSubs(subs ...*pb.RPC_SubOpts) *RPC {
+	return &RPC{
+		RPC: pb.RPC{
+			Subscriptions: subs,
+		},
+	}
+}
+
+func rpcWithMessages(msgs ...*pb.Message) *RPC {
+	return &RPC{RPC: pb.RPC{Publish: msgs}}
+}
+
+func rpcWithMessageAndMsgID(msg *pb.Message, msgID string) *RPC {
+	return &RPC{RPC: pb.RPC{Publish: []*pb.Message{msg}}, MessageIDs: []string{msgID}}
+}
+
+func rpcWithControl(msgs []*pb.Message,
+	ihave []*pb.ControlIHave,
+	iwant []*pb.ControlIWant,
+	graft []*pb.ControlGraft,
+	prune []*pb.ControlPrune,
+	idontwant []*pb.ControlIDontWant) *RPC {
+	return &RPC{
+		RPC: pb.RPC{
+			Publish: msgs,
+			Control: &pb.ControlMessage{
+				Ihave:     ihave,
+				Iwant:     iwant,
+				Graft:     graft,
+				Prune:     prune,
+				Idontwant: idontwant,
+			},
+		},
+	}
+}
+
+func copyRPC(rpc *RPC) *RPC {
+	res := new(RPC)
+	*res = *rpc
+	if rpc.Control != nil {
+		res.Control = new(pb.ControlMessage)
+		*res.Control = *rpc.Control
+	}
+	return res
+}
+
 // pbFieldNumberLT15Size is the number of bytes required to encode a protobuf
 // field number less than or equal to 15 along with its wire type. This is 1
 // byte because the protobuf encoding of field numbers is a varint encoding of:
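The split changes above maintain one invariant: MessageIDs parallels Publish index-for-index but may legitimately be shorter (in practice, full-length or empty), so each chunk takes at most as many IDs as it takes messages. A self-contained sketch of that lockstep split, with plain strings standing in for messages; splitParallel is illustrative, not the library's code:

package main

import "fmt"

// splitParallel chunks msgs n at a time while taking at most as many ids as
// msgs from the parallel slice, so ids may run out without ever misaligning.
func splitParallel(msgs, ids []string, n int) [][2][]string {
	var chunks [][2][]string
	for len(msgs) > 0 {
		take := min(n, len(msgs))
		idTake := min(take, len(ids))
		chunks = append(chunks, [2][]string{msgs[:take], ids[:idTake]})
		msgs, ids = msgs[take:], ids[idTake:]
	}
	return chunks
}

func main() {
	msgs := []string{"m1", "m2", "m3"}
	ids := []string{"id1", "id2", "id3"}
	fmt.Println(splitParallel(msgs, ids, 2)) // [[[m1 m2] [id1 id2]] [[m3] [id3]]]
	fmt.Println(splitParallel(msgs, nil, 2)) // every chunk gets an empty ID slice
}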
diff --git a/rpc_queue.go b/rpc_queue.go
index e5c22935..e434abb5 100644
--- a/rpc_queue.go
+++ b/rpc_queue.go
@@ -3,6 +3,7 @@ package pubsub
 import (
 	"context"
 	"errors"
+	"slices"
 	"sync"
 )
 
@@ -50,20 +51,41 @@ type rpcQueue struct {
 	dataAvailable  sync.Cond
 	spaceAvailable sync.Cond
 	// Mutex used to access queue
-	queueMu sync.Mutex
-	queue   priorityQueue
+	queueMu             sync.Mutex
+	queue               priorityQueue
+	queuedMessageIDs    map[string]struct{} // messageids in queue
+	cancelledMessageIDs map[string]struct{} // messageids that'll be dropped before sending
 
 	closed  bool
 	maxSize int
 }
 
 func newRpcQueue(maxSize int) *rpcQueue {
-	q := &rpcQueue{maxSize: maxSize}
+	q := &rpcQueue{
+		maxSize:             maxSize,
+		queuedMessageIDs:    make(map[string]struct{}),
+		cancelledMessageIDs: make(map[string]struct{}),
+	}
 	q.dataAvailable.L = &q.queueMu
 	q.spaceAvailable.L = &q.queueMu
 	return q
 }
 
+// CancelMessages marks the given message IDs for cancellation, but only if they are already in the queue.
+func (q *rpcQueue) CancelMessages(msgIDs []string) {
+	q.queueMu.Lock()
+	defer q.queueMu.Unlock()
+
+	for _, id := range msgIDs {
+		if id != "" {
+			// Only cancel messages that are actually in the queue
+			if _, ok := q.queuedMessageIDs[id]; ok {
+				q.cancelledMessageIDs[id] = struct{}{}
+			}
+		}
+	}
+}
+
 func (q *rpcQueue) Push(rpc *RPC, block bool) error {
 	return q.push(rpc, false, block)
 }
@@ -91,11 +113,17 @@ func (q *rpcQueue) push(rpc *RPC, urgent bool, block bool) error {
 			return ErrQueueFull
 		}
 	}
+
 	if urgent {
 		q.queue.PriorityPush(rpc)
 	} else {
 		q.queue.NormalPush(rpc)
 	}
+	for _, id := range rpc.MessageIDs {
+		if id != "" {
+			q.queuedMessageIDs[id] = struct{}{}
+		}
+	}
 
 	q.dataAvailable.Signal()
 	return nil
@@ -133,10 +161,53 @@ func (q *rpcQueue) Pop(ctx context.Context) (*RPC, error) {
 		}
 	}
 	rpc := q.queue.Pop()
+	rpc = q.handleCancellations(rpc)
 
 	q.spaceAvailable.Signal()
 	return rpc, nil
 }
 
+func (q *rpcQueue) handleCancellations(rpc *RPC) *RPC {
+	hasCancellations := false
+	for _, msgID := range rpc.MessageIDs {
+		delete(q.queuedMessageIDs, msgID)
+		if _, ok := q.cancelledMessageIDs[msgID]; ok {
+			hasCancellations = true
+		}
+	}
+	if hasCancellations {
+		// clone the RPC parts that we'll modify. It may be shared with other queues.
+		newRPC := *rpc
+		newRPC.RPC.Publish = slices.Clone(rpc.RPC.Publish)
+		newRPC.MessageIDs = slices.Clone(rpc.MessageIDs)
+		rpc = &newRPC
+		// Iterate over MessageIDs here. They may not be present (in that case we
+		// wouldn't be in this branch), but don't risk it.
+		for i, msgID := range newRPC.MessageIDs {
+			if msgID == "" {
+				continue
+			}
+			_, ok := q.cancelledMessageIDs[msgID]
+			if !ok {
+				continue
+			}
+			delete(q.cancelledMessageIDs, msgID)
+			rpc.RPC.Publish[i] = nil
+			rpc.MessageIDs[i] = ""
+		}
+		nextEmpty := 0
+		for i := 0; i < len(rpc.MessageIDs); i++ {
+			if rpc.RPC.Publish[i] != nil {
+				rpc.RPC.Publish[nextEmpty] = rpc.RPC.Publish[i]
+				rpc.MessageIDs[nextEmpty] = rpc.MessageIDs[i]
+				nextEmpty++
+			}
+		}
+		rpc.RPC.Publish = rpc.RPC.Publish[:nextEmpty]
+		rpc.MessageIDs = rpc.MessageIDs[:nextEmpty]
+	}
+	return rpc
+}
+
 func (q *rpcQueue) Close() {
 	q.queueMu.Lock()
 	defer q.queueMu.Unlock()
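handleCancellations above is a mark-and-sweep over two parallel slices: cancelled slots are first nil-ed out, then a single write cursor slides the survivors down and both slices are resliced. A runnable sketch of just the compaction step, with illustrative names:

package main

import "fmt"

// compact mirrors the mark-and-sweep above on plain strings: empty out the
// cancelled slots, then slide survivors down with one write cursor.
func compact(msgs, ids []string, cancelled map[string]bool) ([]string, []string) {
	// mark: blank out cancelled entries so the indexes stay aligned
	for i, id := range ids {
		if cancelled[id] {
			msgs[i], ids[i] = "", ""
		}
	}
	// sweep: two-pointer compaction over both parallel slices at once
	next := 0
	for i := range ids {
		if msgs[i] != "" {
			msgs[next], ids[next] = msgs[i], ids[i]
			next++
		}
	}
	return msgs[:next], ids[:next]
}

func main() {
	msgs := []string{"a", "b", "c", "d"}
	ids := []string{"1", "2", "3", "4"}
	m, i := compact(msgs, ids, map[string]bool{"2": true, "3": true})
	fmt.Println(m, i) // [a d] [1 4]
}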
diff --git a/rpc_queue_test.go b/rpc_queue_test.go
index 6e92ee56..eae0a30c 100644
--- a/rpc_queue_test.go
+++ b/rpc_queue_test.go
@@ -2,8 +2,12 @@ package pubsub
 
 import (
 	"context"
+	"fmt"
+	"slices"
 	"testing"
 	"time"
+
+	pb "github.com/libp2p/go-libp2p-pubsub/pb"
 )
 
 func TestNewRpcQueue(t *testing.T) {
@@ -227,3 +231,112 @@ func TestRpcQueueCancelPop(t *testing.T) {
 		t.Fatalf("rpc queue Pop returns wrong error when it's cancelled")
 	}
 }
+
+func TestRPCQueueCancellations(t *testing.T) {
+	maxSize := 32
+	q := newRpcQueue(maxSize)
+
+	getMessages := func(n int) ([]*pb.Message, []string) {
+		msgs := make([]*pb.Message, n)
+		msgIDs := make([]string, n)
+		for i := range msgs {
+			msgs[i] = &pb.Message{Data: []byte(fmt.Sprintf("message%d", i+1))}
+			msgIDs[i] = fmt.Sprintf("msg%d", i+1)
+		}
+		return msgs, msgIDs
+	}
+
+	t.Run("cancel all", func(t *testing.T) {
+		msgs, msgIDs := getMessages(10)
+		rpc := &RPC{
+			RPC:        pb.RPC{Publish: slices.Clone(msgs)},
+			MessageIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		q.CancelMessages(msgIDs)
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if len(popped.Publish) != 0 {
+			t.Fatalf("expected popped.Publish to be empty, got %v", popped.Publish)
+		}
+		if len(popped.MessageIDs) != 0 {
+			t.Fatalf("expected popped.MessageIDs to be empty, got %v", popped.MessageIDs)
+		}
+		if len(q.queuedMessageIDs) != 0 {
+			t.Fatalf("expected q.queuedMessageIDs to be empty, got %v", q.queuedMessageIDs)
+		}
+		if len(q.cancelledMessageIDs) != 0 {
+			t.Fatalf("expected q.cancelledMessageIDs to be empty, got %v", q.cancelledMessageIDs)
+		}
+	})
+
+	t.Run("cancel some", func(t *testing.T) {
+		msgs, msgIDs := getMessages(10)
+		rpc := &RPC{
+			RPC:        pb.RPC{Publish: slices.Clone(msgs)},
+			MessageIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		q.CancelMessages(msgIDs[:3])
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs[3:], popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs[3:], popped.Publish)
+		}
+		if !slices.Equal(msgIDs[3:], popped.MessageIDs) {
+			t.Fatalf("expected popped.MessageIDs to be %v, got %v", msgIDs[3:], popped.MessageIDs)
+		}
+		if len(q.queuedMessageIDs) != 0 {
+			t.Fatalf("expected q.queuedMessageIDs to be empty, got %v", q.queuedMessageIDs)
+		}
+		if len(q.cancelledMessageIDs) != 0 {
+			t.Fatalf("expected q.cancelledMessageIDs to be empty, got %v", q.cancelledMessageIDs)
+		}
+	})
+
+	t.Run("only one rpc cancelled", func(t *testing.T) {
+		msgs, msgIDs := getMessages(10)
+		rpc := &RPC{
+			RPC:        pb.RPC{Publish: slices.Clone(msgs)},
+			MessageIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc, true)
+		rpc2 := &RPC{
+			RPC:        pb.RPC{Publish: slices.Clone(msgs)},
+			MessageIDs: slices.Clone(msgIDs),
+		}
+		q.Push(rpc2, true)
+		q.CancelMessages(msgIDs[:3])
+		popped, err := q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs[3:], popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs[3:], popped.Publish)
+		}
+		if !slices.Equal(msgIDs[3:], popped.MessageIDs) {
+			t.Fatalf("expected popped.MessageIDs to be %v, got %v", msgIDs[3:], popped.MessageIDs)
+		}
+
+		popped, err = q.Pop(context.Background())
+		if err != nil {
+			t.Fatalf("failed to pop RPC: %v", err)
+		}
+		if !slices.Equal(msgs, popped.Publish) {
+			t.Fatalf("expected popped.Publish to be %v, got %v", msgs, popped.Publish)
+		}
+		if !slices.Equal(msgIDs, popped.MessageIDs) {
+			t.Fatalf("expected popped.MessageIDs to be %v, got %v", msgIDs, popped.MessageIDs)
+		}
+		if len(q.queuedMessageIDs) != 0 {
+			t.Fatalf("expected q.queuedMessageIDs to be empty, got %v", q.queuedMessageIDs)
+		}
+		if len(q.cancelledMessageIDs) != 0 {
+			t.Fatalf("expected q.cancelledMessageIDs to be empty, got %v", q.cancelledMessageIDs)
+		}
+	})
+}
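One behavioral edge the tests above don't pin down: CancelMessages only marks IDs that are already in queuedMessageIDs, so an IDONTWANT that races ahead of the corresponding Push is a no-op and the message still goes out on the next Pop. A sketch of a test for that ordering, which would sit in package pubsub next to the cases above (the test name and body are hypothetical):

package pubsub

import (
	"context"
	"testing"

	pb "github.com/libp2p/go-libp2p-pubsub/pb"
)

func TestRPCQueueCancelBeforePushIsNoop(t *testing.T) {
	q := newRpcQueue(32)
	q.CancelMessages([]string{"msg1"}) // nothing queued yet, so nothing is marked

	rpc := &RPC{
		RPC:        pb.RPC{Publish: []*pb.Message{{Data: []byte("hello")}}},
		MessageIDs: []string{"msg1"},
	}
	if err := q.Push(rpc, true); err != nil {
		t.Fatalf("failed to push RPC: %v", err)
	}

	popped, err := q.Pop(context.Background())
	if err != nil {
		t.Fatalf("failed to pop RPC: %v", err)
	}
	if len(popped.Publish) != 1 {
		t.Fatalf("pre-push cancel dropped the message; expected it to survive")
	}
}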