From 5c75d6676c742392f7620fa930626b140566a48d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 29 Jul 2024 15:09:19 +0200 Subject: [PATCH 1/2] previous atx malfeasance check for ATX V2 --- activation/handler_v2.go | 134 +++++++++++++++++++++------------ activation/handler_v2_test.go | 46 +++++++++++ activation/wire/malfeasance.go | 1 + sql/atxs/atxs.go | 11 +-- sql/atxs/atxs_test.go | 53 +++++++++++++ 5 files changed, 192 insertions(+), 53 deletions(-) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 155b5db6d1..51b5bfc805 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -110,19 +110,20 @@ func (h *HandlerV2) processATX( return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err) } - parts, err := h.syntacticallyValidateDeps(ctx, watx) + atxData, err := h.syntacticallyValidateDeps(ctx, watx) if err != nil { return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err) } + atxData.marriages = marrying atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, - NumUnits: parts.effectiveUnits, - TickCount: parts.ticks, - Weight: parts.weight, + NumUnits: atxData.effectiveUnits, + TickCount: atxData.ticks, + Weight: atxData.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, } @@ -140,7 +141,7 @@ func (h *HandlerV2) processATX( atx.SetID(watx.ID()) atx.SetReceived(received) - if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil { + if err := h.storeAtx(ctx, atx, atxData); err != nil { return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } @@ -433,11 +434,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX) } -type atxParts struct { +type idData struct { + previous types.ATXID + units uint32 +} + +type activationTx struct { + *wire.ActivationTxV2 ticks uint64 weight uint64 effectiveUnits uint32 - units map[types.NodeID]uint32 + ids map[types.NodeID]idData + marriages []marriage } type nipostSize struct { @@ -495,9 +503,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error func (h *HandlerV2) syntacticallyValidateDeps( ctx context.Context, atx *wire.ActivationTxV2, -) (*atxParts, error) { - parts := atxParts{ - units: make(map[types.NodeID]uint32), +) (*activationTx, error) { + result := activationTx{ + ActivationTxV2: atx, + ids: make(map[types.NodeID]idData), } if atx.Initial != nil { if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil { @@ -585,7 +594,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( nipostSizes[i].ticks = leaves / h.tickSize } - parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp() + result.effectiveUnits, result.weight, err = nipostSizes.sumUp() if err != nil { return nil, err } @@ -593,9 +602,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( // validate all niposts var smesherCommitment *types.ATXID for _, niposts := range atx.NiPosts { - for _, post := range niposts.Posts { - id := equivocationSet[post.MarriageIndex] + for _, p := range niposts.Posts { + id := equivocationSet[p.MarriageIndex] var commitment types.ATXID + var previous types.ATXID if atx.Initial != nil { commitment = atx.Initial.CommitmentATX } else { @@ -607,15 +617,16 @@ func (h *HandlerV2) syntacticallyValidateDeps( if id == atx.SmesherID { smesherCommitment = &commitment } + previous = previousAtxs[p.PrevATXIndex].ID() } err := h.nipostValidator.PostV2( ctx, id, commitment, - wire.PostFromWireV1(&post.Post), + wire.PostFromWireV1(&p.Post), niposts.Challenge[:], - post.NumUnits, + p.NumUnits, PostSubset([]byte(h.local)), ) invalidIdx := &verifying.ErrInvalidIndex{} @@ -636,7 +647,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) } - parts.units[id] = post.NumUnits + result.ids[id] = idData{ + previous: previous, + units: p.NumUnits, + } } } @@ -650,18 +664,12 @@ func (h *HandlerV2) syntacticallyValidateDeps( } } - parts.ticks = nipostSizes.minTicks() - return &parts, nil + result.ticks = nipostSizes.minTicks() + return &result, nil } -func (h *HandlerV2) checkMalicious( - ctx context.Context, - tx *sql.Tx, - watx *wire.ActivationTxV2, - marrying []marriage, - ids []types.NodeID, -) error { - malicious, err := identities.IsMalicious(tx, watx.SmesherID) +func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error { + malicious, err := identities.IsMalicious(tx, atx.SmesherID) if err != nil { return fmt.Errorf("checking if node is malicious: %w", err) } @@ -669,7 +677,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMarry(ctx, tx, watx.ID(), marrying) + malicious, err = h.checkDoubleMarry(ctx, tx, atx.ID(), atx.marriages) if err != nil { return fmt.Errorf("checking double marry: %w", err) } @@ -677,7 +685,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoublePost(ctx, tx, watx, ids) + malicious, err = h.checkDoublePost(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double post: %w", err) } @@ -685,7 +693,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMerge(ctx, tx, watx) + malicious, err = h.checkDoubleMerge(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double merge: %w", err) } @@ -693,6 +701,14 @@ func (h *HandlerV2) checkMalicious( return nil } + malicious, err = h.checkPrevAtx(ctx, tx, atx) + if err != nil { + return fmt.Errorf("checking previous ATX: %w", err) + } + if malicious { + return nil + } + // TODO(mafa): contextual validation: // 1. check double-publish = ID contributed post to two ATXs in the same epoch // 2. check previous ATX @@ -726,10 +742,9 @@ func (h *HandlerV2) checkDoubleMarry( func (h *HandlerV2) checkDoublePost( ctx context.Context, tx *sql.Tx, - atx *wire.ActivationTxV2, - ids []types.NodeID, + atx *activationTx, ) (bool, error) { - for _, id := range ids { + for id := range atx.ids { atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): @@ -755,49 +770,72 @@ func (h *HandlerV2) checkDoublePost( return false, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) { - if watx.MarriageATX == nil { +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + if atx.MarriageATX == nil { return false, nil } - ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch) + ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): return false, nil case err != nil: return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) } - otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() }) + otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() }) other := ids[otherIndex] h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", - zap.Stringer("marriage_atx", *watx.MarriageATX), - zap.Stringer("atx", watx.ID()), + zap.Stringer("marriage_atx", *atx.MarriageATX), + zap.Stringer("atx", atx.ID()), zap.Stringer("other_atx", other), - zap.Stringer("smesher_id", watx.SmesherID), + zap.Stringer("smesher_id", atx.SmesherID), ) // TODO(mafa): finish proof proof := &wire.ATXProof{ ProofType: wire.DoubleMerge, } - return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof) + return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) +} + +func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + for id, data := range atx.ids { + prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return false, fmt.Errorf("get last atx by node id: %w", err) + } + if prevID == data.previous { + continue + } + + h.logger.Debug("atx references a wrong previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("actual", data.previous), + log.ZShortStringer("expected", prevID), + ) + + // TODO(mafa): finish proof + proof := &wire.ATXProof{ + ProofType: wire.InvalidPrevious, + } + return true, h.malPublisher.Publish(ctx, id, proof) + } + return false, nil } // Store an ATX in the DB. func (h *HandlerV2) storeAtx( ctx context.Context, atx *types.ActivationTx, - watx *wire.ActivationTxV2, - marrying []marriage, - units map[types.NodeID]uint32, + watx *activationTx, ) error { if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { - if len(marrying) != 0 { + if len(watx.marriages) != 0 { marriageData := identities.MarriageData{ ATX: atx.ID(), Target: atx.SmesherID, } - for i, m := range marrying { + for i, m := range watx.marriages { marriageData.Signature = m.signature marriageData.Index = i if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil { @@ -810,8 +848,8 @@ func (h *HandlerV2) storeAtx( if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } - for id, units := range units { - err = atxs.SetUnits(tx, atx.ID(), id, units) + for id, p := range watx.ids { + err = atxs.SetUnits(tx, atx.ID(), id, p.units) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("setting atx units for ID %s: %w", id, err) } @@ -830,7 +868,7 @@ func (h *HandlerV2) storeAtx( // TODO(mafa): don't store own ATX if it would mark the node as malicious // this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in // the gossip handler (not sync!) - err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units)) + err := h.checkMalicious(ctx, tx, watx) if err != nil { return fmt.Errorf("check malicious: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 06f52f9b69..4f0bafaf3b 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1898,6 +1898,52 @@ func Test_CalculatingUnits(t *testing.T) { }) } +func TestContextual_PreviousATX(t *testing.T) { + golden := types.RandomATXID() + atxHndlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHndlr.expectAtxV2(soloAtx) + err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Cond(func(data any) bool { + return data.(*wire.ATXProof).ProofType == wire.InvalidPrevious + })) + err = atxHndlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) +} + func Test_CalculatingWeight(t *testing.T) { t.Parallel() t.Run("total weight must not overflow uint64", func(t *testing.T) { diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index d8e60a4127..29460ec485 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -13,6 +13,7 @@ const ( DoublePublish ProofType = iota + 1 DoubleMarry DoubleMerge + InvalidPrevious InvalidPost ) diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 5e14cddde1..e2f9fe6729 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -248,7 +248,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID +// that was published before the given public epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -260,10 +261,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID } if rows, err := db.Exec(` - select id from atxs - where pubkey = ?1 and epoch < ?2 - order by epoch desc - limit 1;`, enc, dec); err != nil { + SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id + WHERE posts.pubkey = ?1 AND atxs.epoch < ?2 + ORDER BY atxs.epoch DESC + LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) } else if rows == 0 { return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 1dd1914968..ecdd704cd5 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1330,3 +1330,56 @@ func Test_MergeConflict(t *testing.T) { require.Len(t, ids, 2) }) } + +func TestPrevIDByNodeID(t *testing.T) { + t.Run("no previous ATXs", func(t *testing.T) { + db := sql.InMemory() + _, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("filters by epoch", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + require.ErrorIs(t, err, sql.ErrNotFound) + + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + }) + t.Run("the previous is merged and ID is not the signer", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + id := types.RandomNodeID() + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), id, 8)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + prevID, err := atxs.PrevIDByNodeID(db, id, 3) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + }) +} From 6fa08cdc13021882ec59323fef901ae981a567a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 12 Aug 2024 14:22:09 +0200 Subject: [PATCH 2/2] implement finding previous ATX collision --- activation/handler_v2.go | 20 ++++++++++++--- sql/atxs/atxs.go | 54 +++++++++++++--------------------------- sql/atxs/atxs_test.go | 30 +++++++++++----------- 3 files changed, 49 insertions(+), 55 deletions(-) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 9004e123fc..b4b58b8eb8 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -797,18 +797,32 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { for id, data := range atx.ids { - prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) if err != nil && !errors.Is(err, sql.ErrNotFound) { return false, fmt.Errorf("get last atx by node id: %w", err) } - if prevID == data.previous { + if expectedPrevID == data.previous { continue } h.logger.Debug("atx references a wrong previous ATX", log.ZShortStringer("smesherID", id), log.ZShortStringer("actual", data.previous), - log.ZShortStringer("expected", prevID), + log.ZShortStringer("expected", expectedPrevID), + ) + + atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id) + switch { + case errors.Is(err, sql.ErrNotFound): + continue + case err != nil: + return false, fmt.Errorf("checking for previous ATX collision: %w", err) + } + + h.logger.Debug("creating a malfeasance proof for invalid previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("atx1", atx1), + log.ZShortStringer("atx2", atx2), ) // TODO(mafa): finish proof diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index fd0c8bc02a..6cfd2c51d9 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -874,46 +874,26 @@ func IterateAtxIdsWithMalfeasance( return err } -type PrevATXCollision struct { - NodeID1 types.NodeID - ATX1 types.ATXID - - NodeID2 types.NodeID - ATX2 types.ATXID -} - -func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) { - var result []PrevATXCollision - +func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) { + var atxs []types.ATXID + enc := func(stmt *sql.Statement) { + stmt.BindBytes(1, prev[:]) + stmt.BindBytes(2, id[:]) + } dec := func(stmt *sql.Statement) bool { - var nodeID1, nodeID2 types.NodeID - stmt.ColumnBytes(0, nodeID1[:]) - stmt.ColumnBytes(1, nodeID2[:]) - - var id1, id2 types.ATXID - stmt.ColumnBytes(2, id1[:]) - stmt.ColumnBytes(3, id2[:]) - - result = append(result, PrevATXCollision{ - NodeID1: nodeID1, - ATX1: id1, - - NodeID2: nodeID2, - ATX2: id2, - }) - return true + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + atxs = append(atxs, id) + return len(atxs) < 2 } - // we are joining the table with itself to find ATXs with the same prevATX - // the WHERE clause ensures that we only get the pairs once - if _, err := db.Exec(` - SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid - FROM posts p1 - INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid - WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil { - return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + _, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec) + if err != nil { + return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } - - return result, nil + if len(atxs) != 2 { + return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound + } + return atxs[0], atxs[1], nil } func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 77833a9f68..aaebce0755 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) { } } -func Test_PrevATXCollisions(t *testing.T) { +func Test_PrevATXCollision(t *testing.T) { db := sql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) { require.NoError(t, err) require.Equal(t, atx2, got2) - // add 10 valid ATXs by 10 other smeshers + // add 10 valid ATXs by 10 other smeshers, using the same previous but no collision + var otherIds []types.NodeID for i := 2; i < 6; i++ { otherSig, err := signing.NewEdSigner() require.NoError(t, err) + otherIds = append(otherIds, otherSig.NodeID()) - atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i))) - require.NoError(t, atxs.Add(db, atx, blob)) - - atx2, blob2 := newAtx(t, otherSig, - withPublishEpoch(types.EpochID(i+1)), - ) + atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1))) require.NoError(t, atxs.Add(db, atx2, blob2)) - require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10)) } - // get the collisions - got, err := atxs.PrevATXCollisions(db) + collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID()) require.NoError(t, err) - require.Len(t, got, 1) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2}) + + _, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) - require.Equal(t, sig.NodeID(), got[0].NodeID1) - require.Equal(t, sig.NodeID(), got[0].NodeID2) - require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2}) + for _, id := range append(otherIds, types.RandomNodeID()) { + _, _, err := atxs.PrevATXCollision(db, prevATXID, id) + require.ErrorIs(t, err, sql.ErrNotFound) + } } func TestCoinbase(t *testing.T) {