49 changes: 5 additions & 44 deletions comm.go
@@ -165,7 +165,7 @@ func (p *PubSub) handlePeerDead(s network.Stream) {
}

func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) {
writeRpc := func(rpc *RPC) error {
writeRpc := func(rpc *pb.RPC) error {
size := uint64(rpc.Size())

buf := pool.Get(varint.UvarintSize(size) + int(size))
@@ -193,54 +193,15 @@ func (p *PubSub) handleSendingMessages(ctx context.Context, s network.Stream, outgoing *rpcQueue) {
p.logger.Debug("error popping message from the queue to send to peer", "peer", s.Conn().RemotePeer(), "err", err)
return
}
if rpc.Size() == 0 {
continue
}

err = writeRpc(rpc)
err = writeRpc(&rpc.RPC)
if err != nil {
s.Reset()
p.logger.Debug("error writing message to peer", "peer", s.Conn().RemotePeer(), "err", err)
return
}
}
}

func rpcWithSubs(subs ...*pb.RPC_SubOpts) *RPC {
return &RPC{
RPC: pb.RPC{
Subscriptions: subs,
},
}
}

func rpcWithMessages(msgs ...*pb.Message) *RPC {
return &RPC{RPC: pb.RPC{Publish: msgs}}
}

func rpcWithControl(msgs []*pb.Message,
ihave []*pb.ControlIHave,
iwant []*pb.ControlIWant,
graft []*pb.ControlGraft,
prune []*pb.ControlPrune,
idontwant []*pb.ControlIDontWant) *RPC {
return &RPC{
RPC: pb.RPC{
Publish: msgs,
Control: &pb.ControlMessage{
Ihave: ihave,
Iwant: iwant,
Graft: graft,
Prune: prune,
Idontwant: idontwant,
},
},
}
}

func copyRPC(rpc *RPC) *RPC {
res := new(RPC)
*res = *rpc
if rpc.Control != nil {
res.Control = new(pb.ControlMessage)
*res.Control = *rpc.Control
}
return res
}
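Note on the framing above: writeRpc sizes the RPC with pb.RPC.Size(), takes a pooled buffer large enough for an unsigned-varint length prefix plus the marshalled message, and writes both in one call. Below is a minimal sketch of that length-prefixed framing using only the Go standard library; writeFramed and its payload argument are illustrative stand-ins, not the library's API.

```go
// Minimal sketch of the length-prefixed framing used by writeRpc, using only
// the standard library. The real code sizes the message with pb.RPC.Size(),
// takes a buffer from a pool, and writes the varint prefix plus the marshalled
// protobuf in a single Write; writeFramed below is a simplified stand-in.
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

func writeFramed(w io.Writer, payload []byte) error {
	buf := binary.AppendUvarint(nil, uint64(len(payload))) // size prefix
	buf = append(buf, payload...)                          // then the message bytes
	_, err := w.Write(buf)
	return err
}

func main() {
	var stream bytes.Buffer
	if err := writeFramed(&stream, []byte("rpc bytes")); err != nil {
		panic(err)
	}
	size, _ := binary.ReadUvarint(&stream) // the receiver reads the prefix first
	fmt.Println(size, stream.String())
}
```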
14 changes: 12 additions & 2 deletions gossipsub.go
@@ -1163,6 +1163,9 @@ func (gs *GossipSubRouter) handleIDontWant(p peer.ID, ctl *pb.ControlMessage) {
gs.peerdontwant[p]++

totalUnwantedIds := 0
// Collect message IDs for cancellation
var msgIDsToCancel []string

// Remember all the unwanted message ids
mainIDWLoop:
for _, idontwant := range ctl.GetIdontwant() {
@@ -1175,8 +1178,14 @@ mainIDWLoop:

totalUnwantedIds++
gs.unwanted[p][computeChecksum(mid)] = gs.params.IDontWantMessageTTL
msgIDsToCancel = append(msgIDsToCancel, mid)
}
}

// Cancel these messages in the RPC queue if it exists
if queue, ok := gs.p.peers[p]; ok && len(msgIDsToCancel) > 0 {
queue.CancelMessages(msgIDsToCancel)
}
}

func (gs *GossipSubRouter) addBackoff(p peer.ID, topic string, isUnsubscribe bool) {
@@ -1316,6 +1325,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
return
}

msgID := gs.p.idGen.ID(msg)
if gs.floodPublish && from == gs.p.host.ID() {
for p := range tmap {
_, direct := gs.direct[p]
@@ -1359,7 +1369,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
gs.lastpub[topic] = time.Now().UnixNano()
}

csum := computeChecksum(gs.p.idGen.ID(msg))
csum := computeChecksum(msgID)
for p := range gmap {
// Check if it has already received an IDONTWANT for the message.
// If so, don't send it to the peer
@@ -1370,7 +1380,7 @@ func (gs *GossipSubRouter) rpcs(msg *Message) iter.Seq2[peer.ID, *RPC] {
}
}

out := rpcWithMessages(msg.Message)
out := rpcWithMessageAndMsgID(msg.Message, msgID)
for pid := range tosend {
if pid == from || pid == peer.ID(msg.GetFrom()) {
continue
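The gossipsub.go changes above do two things: handleIDontWant now collects the accepted message IDs and, when the sending peer has an RPC queue registered in gs.p.peers, calls queue.CancelMessages so still-unsent copies of those messages are dropped; and rpcs() computes the message ID once, reusing it for the IDONTWANT checksum and attaching it to the outgoing RPC via rpcWithMessageAndMsgID. A toy sketch of the cancellation handshake follows, with simplified types that are assumptions rather than the library's.

```go
// Toy sketch (assumed, simplified types) of the handshake added here: the
// router records unwanted message IDs and asks the peer's queue to drop any
// copies that have not been sent yet. Only IDs already queued are cancelled.
package main

import "fmt"

type queue struct {
	queued    map[string]struct{}
	cancelled map[string]struct{}
}

func (q *queue) CancelMessages(ids []string) {
	for _, id := range ids {
		if _, ok := q.queued[id]; ok { // ignore IDs we never queued
			q.cancelled[id] = struct{}{}
		}
	}
}

func main() {
	q := &queue{
		queued:    map[string]struct{}{"m1": {}, "m2": {}},
		cancelled: map[string]struct{}{},
	}
	// IDs named in an incoming IDONTWANT control message.
	idontwant := []string{"m2", "m3"}
	q.CancelMessages(idontwant)
	fmt.Println(q.cancelled) // map[m2:{}]; m3 was never queued, so it is ignored
}
```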
27 changes: 16 additions & 11 deletions gossipsub_spam_test.go
@@ -930,20 +930,25 @@ func TestGossipsubHandleIDontwantSpam(t *testing.T) {
rPid := hosts[1].ID()
ctrlMessage := &pb.ControlMessage{Idontwant: []*pb.ControlIDontWant{{MessageIDs: idwIds}}}
grt := psubs[0].rt.(*GossipSubRouter)
grt.handleIDontWant(rPid, ctrlMessage)
completed := make(chan struct{})
psubs[0].eval <- func() {
grt.handleIDontWant(rPid, ctrlMessage)

if grt.peerdontwant[rPid] != 1 {
t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
}
mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
}
if grt.peerdontwant[rPid] != 1 {
t.Errorf("Wanted message count of %d but received %d", 1, grt.peerdontwant[rPid])
}
mid := fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength-1)
if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; !ok {
t.Errorf("Desired message id was not stored in the unwanted map: %s", mid)
}

mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
mid = fmt.Sprintf("idontwant-%d", GossipSubMaxIDontWantLength)
if _, ok := grt.unwanted[rPid][computeChecksum(mid)]; ok {
t.Errorf("Unwanted message id was stored in the unwanted map: %s", mid)
}
close(completed)
}
<-completed
}

type mockGSOnRead func(writeMsg func(*pb.RPC), irpc *pb.RPC)
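The test now runs handleIDontWant and its assertions inside a closure sent to psubs[0].eval, so they execute on the pubsub event loop instead of racing with it, and then waits on a completed channel. A generic sketch of that synchronization pattern, with illustrative names rather than the library's:

```go
// Generic sketch of the eval-channel pattern the test now uses: work is sent
// to a single event-loop goroutine as a closure, and the caller waits on a
// done channel so the assertions run without a data race.
package main

import "fmt"

func main() {
	eval := make(chan func())
	go func() { // stand-in for the pubsub processing loop goroutine
		for f := range eval {
			f()
		}
	}()

	state := 0
	done := make(chan struct{})
	eval <- func() {
		state++ // mutate and inspect state only on the loop goroutine
		fmt.Println("state on loop:", state)
		close(done)
	}
	<-done
	close(eval)
}
```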
67 changes: 67 additions & 0 deletions pubsub.go
@@ -257,6 +257,8 @@ func (m *Message) GetFrom() peer.ID {

type RPC struct {
pb.RPC
// MessageIDs are the IDs of the messages in the RPC: MessageIDs[i] = id(rpc.Publish[i])
MessageIDs []string

// unexported on purpose, not sending this over the wire
from peer.ID
@@ -274,6 +276,12 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {

messagesInNextRPC := 0
messageSlice := rpc.Publish
// This slice may be shorter than messageSlice; only its first
// len(messageIDSlice) elements are valid.
//
// In most cases it is either the same length as rpc.Publish or empty, in
// which case each split RPC's MessageIDs likewise matches its Publish
// slice or is empty.
messageIDSlice := rpc.MessageIDs

// Merge/Append publish messages. This pattern is optimized compared to
// the patterns for other fields because this is the common cause for
@@ -286,6 +294,14 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
// into this RPC, yield it, then make a new one
nextRPC.Publish = messageSlice[:messagesInNextRPC]
messageSlice = messageSlice[messagesInNextRPC:]
if len(messageIDSlice) >= messagesInNextRPC {
nextRPC.MessageIDs = messageIDSlice[:messagesInNextRPC]
messageIDSlice = messageIDSlice[messagesInNextRPC:]
} else {
nextRPC.MessageIDs = messageIDSlice
messageIDSlice = messageIDSlice[len(messageIDSlice):]
}

if !yield(nextRPC) {
return
}
@@ -303,6 +319,11 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
// packing this RPC, but we avoid successively calling .Size()
// on the messages for the next parts.
nextRPC.Publish = messageSlice[:messagesInNextRPC]
if len(messageIDSlice) >= messagesInNextRPC {
nextRPC.MessageIDs = messageIDSlice[:messagesInNextRPC]
} else {
nextRPC.MessageIDs = messageIDSlice
}
if !yield(nextRPC) {
return
}
@@ -438,6 +459,52 @@ func (rpc *RPC) split(limit int) iter.Seq[RPC] {
}
}

func rpcWithSubs(subs ...*pb.RPC_SubOpts) *RPC {
return &RPC{
RPC: pb.RPC{
Subscriptions: subs,
},
}
}

func rpcWithMessages(msgs ...*pb.Message) *RPC {
return &RPC{RPC: pb.RPC{Publish: msgs}}
}

func rpcWithMessageAndMsgID(msg *pb.Message, msgID string) *RPC {
return &RPC{RPC: pb.RPC{Publish: []*pb.Message{msg}}, MessageIDs: []string{msgID}}
}

func rpcWithControl(msgs []*pb.Message,
ihave []*pb.ControlIHave,
iwant []*pb.ControlIWant,
graft []*pb.ControlGraft,
prune []*pb.ControlPrune,
idontwant []*pb.ControlIDontWant) *RPC {
return &RPC{
RPC: pb.RPC{
Publish: msgs,
Control: &pb.ControlMessage{
Ihave: ihave,
Iwant: iwant,
Graft: graft,
Prune: prune,
Idontwant: idontwant,
},
},
}
}

func copyRPC(rpc *RPC) *RPC {
res := new(RPC)
*res = *rpc
if rpc.Control != nil {
res.Control = new(pb.ControlMessage)
*res.Control = *rpc.Control
}
return res
}

// pbFieldNumberLT15Size is the number of bytes required to encode a protobuf
// field number less than or equal to 15 along with its wire type. This is 1
// byte because the protobuf encoding of field numbers is a varint encoding of:
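The split() changes above thread a MessageIDs slice through each emitted chunk so that MessageIDs[i] remains the ID of Publish[i], while a shorter or empty ID slice is consumed only as far as it goes. A small self-contained sketch of that alignment invariant; the helper name and fixed chunk size are illustrative, not the real size-based splitting:

```go
// Small sketch of the alignment invariant split() maintains: messages and
// their IDs are chunked with the same indices, and a shorter (possibly empty)
// ID slice is consumed only as far as it goes.
package main

import "fmt"

func splitAligned(msgs, ids []string, per int) (out [][2][]string) {
	for len(msgs) > 0 {
		n := per
		if n > len(msgs) {
			n = len(msgs)
		}
		chunkIDs := ids
		if len(ids) >= n {
			chunkIDs = ids[:n]
			ids = ids[n:]
		} else {
			ids = ids[len(ids):] // keep whatever prefix is available
		}
		out = append(out, [2][]string{msgs[:n], chunkIDs})
		msgs = msgs[n:]
	}
	return out
}

func main() {
	msgs := []string{"a", "b", "c"}
	ids := []string{"id-a", "id-b", "id-c"}
	fmt.Println(splitAligned(msgs, ids, 2)) // [[[a b] [id-a id-b]] [[c] [id-c]]]
}
```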
77 changes: 74 additions & 3 deletions rpc_queue.go
@@ -3,6 +3,7 @@ package pubsub
import (
"context"
"errors"
"slices"
"sync"
)

@@ -50,20 +51,41 @@ type rpcQueue struct {
dataAvailable sync.Cond
spaceAvailable sync.Cond
// Mutex used to access queue
queueMu sync.Mutex
queue priorityQueue
queueMu sync.Mutex
queue priorityQueue
queuedMessageIDs map[string]struct{} // messageids in queue
cancelledMessageIDs map[string]struct{} // messageids that'll be dropped before sending

closed bool
maxSize int
}

func newRpcQueue(maxSize int) *rpcQueue {
q := &rpcQueue{maxSize: maxSize}
q := &rpcQueue{
maxSize: maxSize,
queuedMessageIDs: make(map[string]struct{}),
cancelledMessageIDs: make(map[string]struct{}),
}
q.dataAvailable.L = &q.queueMu
q.spaceAvailable.L = &q.queueMu
return q
}

// CancelMessages marks the given message IDs for cancellation, but only if they are already in the queue.
func (q *rpcQueue) CancelMessages(msgIDs []string) {
q.queueMu.Lock()
defer q.queueMu.Unlock()

for _, id := range msgIDs {
if id != "" {
// Only cancel messages that are actually in the queue
if _, ok := q.queuedMessageIDs[id]; ok {
q.cancelledMessageIDs[id] = struct{}{}
}
}
}
}

func (q *rpcQueue) Push(rpc *RPC, block bool) error {
return q.push(rpc, false, block)
}
@@ -91,11 +113,17 @@ func (q *rpcQueue) push(rpc *RPC, urgent bool, block bool) error {
return ErrQueueFull
}
}

if urgent {
q.queue.PriorityPush(rpc)
} else {
q.queue.NormalPush(rpc)
}
for _, id := range rpc.MessageIDs {
if id != "" {
q.queuedMessageIDs[id] = struct{}{}
}
}

q.dataAvailable.Signal()
return nil
@@ -133,10 +161,53 @@ func (q *rpcQueue) Pop(ctx context.Context) (*RPC, error) {
}
}
rpc := q.queue.Pop()
rpc = q.handleCancellations(rpc)
q.spaceAvailable.Signal()
return rpc, nil
}

func (q *rpcQueue) handleCancellations(rpc *RPC) *RPC {
hasCancellations := false
for _, msgID := range rpc.MessageIDs {
delete(q.queuedMessageIDs, msgID)
if _, ok := q.cancelledMessageIDs[msgID]; ok {
hasCancellations = true
}
}
if hasCancellations {
// Clone the RPC parts we'll modify; the RPC may be shared with other queues.
newRPC := *rpc
newRPC.RPC.Publish = slices.Clone(rpc.RPC.Publish)
newRPC.MessageIDs = slices.Clone(rpc.MessageIDs)
rpc = &newRPC
// Iterate over MessageIDs rather than Publish: MessageIDs may be absent, and although
// we wouldn't be in this branch in that case, don't rely on it.
for i, msgID := range newRPC.MessageIDs {
if msgID == "" {
continue
}
_, ok := q.cancelledMessageIDs[msgID]
if !ok {
continue
}
delete(q.cancelledMessageIDs, msgID)
rpc.RPC.Publish[i] = nil
rpc.MessageIDs[i] = ""
}
nextEmpty := 0
for i := 0; i < len(rpc.MessageIDs); i++ {
if rpc.RPC.Publish[i] != nil {
rpc.RPC.Publish[nextEmpty] = rpc.RPC.Publish[i]
rpc.MessageIDs[nextEmpty] = rpc.MessageIDs[i]
nextEmpty++
}
}
rpc.RPC.Publish = rpc.RPC.Publish[:nextEmpty]
rpc.MessageIDs = rpc.MessageIDs[:nextEmpty]
}
return rpc
}

func (q *rpcQueue) Close() {
q.queueMu.Lock()
defer q.queueMu.Unlock()
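handleCancellations above filters cancelled messages at Pop time: because the RPC may be shared with other peers' queues, it clones the Publish and MessageIDs slices, nils out the cancelled entries, and compacts both slices in place. A self-contained sketch of that clone-and-compact step, with illustrative types standing in for *pb.Message and the real RPC struct:

```go
// Self-contained sketch of the Pop-time filtering done by handleCancellations:
// clone the parallel slices so shared backing arrays are not mutated, then
// drop every entry whose ID is in the cancelled set and compact what remains.
package main

import (
	"fmt"
	"slices"
)

func dropCancelled(publish, ids []string, cancelled map[string]struct{}) ([]string, []string) {
	publish = slices.Clone(publish) // don't mutate slices shared with other queues
	ids = slices.Clone(ids)
	next := 0
	for i, id := range ids {
		if _, cut := cancelled[id]; cut {
			continue // drop this message and its ID
		}
		publish[next] = publish[i]
		ids[next] = ids[i]
		next++
	}
	return publish[:next], ids[:next]
}

func main() {
	publish := []string{"msg-a", "msg-b", "msg-c"}
	ids := []string{"a", "b", "c"}
	cancelled := map[string]struct{}{"b": {}}
	p, i := dropCancelled(publish, ids, cancelled)
	fmt.Println(p, i) // [msg-a msg-c] [a c]
}
```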