diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 9004e123fc..b4b58b8eb8 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -797,18 +797,32 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activ func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { for id, data := range atx.ids { - prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) if err != nil && !errors.Is(err, sql.ErrNotFound) { return false, fmt.Errorf("get last atx by node id: %w", err) } - if prevID == data.previous { + if expectedPrevID == data.previous { continue } h.logger.Debug("atx references a wrong previous ATX", log.ZShortStringer("smesherID", id), log.ZShortStringer("actual", data.previous), - log.ZShortStringer("expected", prevID), + log.ZShortStringer("expected", expectedPrevID), + ) + + atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id) + switch { + case errors.Is(err, sql.ErrNotFound): + continue + case err != nil: + return false, fmt.Errorf("checking for previous ATX collision: %w", err) + } + + h.logger.Debug("creating a malfeasance proof for invalid previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("atx1", atx1), + log.ZShortStringer("atx2", atx2), ) // TODO(mafa): finish proof diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index fd0c8bc02a..6cfd2c51d9 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -874,46 +874,26 @@ func IterateAtxIdsWithMalfeasance( return err } -type PrevATXCollision struct { - NodeID1 types.NodeID - ATX1 types.ATXID - - NodeID2 types.NodeID - ATX2 types.ATXID -} - -func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) { - var result []PrevATXCollision - +func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) { + var atxs []types.ATXID + enc := func(stmt *sql.Statement) { + stmt.BindBytes(1, prev[:]) + stmt.BindBytes(2, id[:]) + } dec := func(stmt *sql.Statement) bool { - var nodeID1, nodeID2 types.NodeID - stmt.ColumnBytes(0, nodeID1[:]) - stmt.ColumnBytes(1, nodeID2[:]) - - var id1, id2 types.ATXID - stmt.ColumnBytes(2, id1[:]) - stmt.ColumnBytes(3, id2[:]) - - result = append(result, PrevATXCollision{ - NodeID1: nodeID1, - ATX1: id1, - - NodeID2: nodeID2, - ATX2: id2, - }) - return true + var id types.ATXID + stmt.ColumnBytes(0, id[:]) + atxs = append(atxs, id) + return len(atxs) < 2 } - // we are joining the table with itself to find ATXs with the same prevATX - // the WHERE clause ensures that we only get the pairs once - if _, err := db.Exec(` - SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid - FROM posts p1 - INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid - WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil { - return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err) + _, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec) + if err != nil { + return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err) } - - return result, nil + if len(atxs) != 2 { + return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound + } + return atxs[0], atxs[1], nil } func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) { diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 77833a9f68..aaebce0755 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) { } } -func Test_PrevATXCollisions(t *testing.T) { +func Test_PrevATXCollision(t *testing.T) { db := sql.InMemory() sig, err := signing.NewEdSigner() require.NoError(t, err) @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) { require.NoError(t, err) require.Equal(t, atx2, got2) - // add 10 valid ATXs by 10 other smeshers + // add 10 valid ATXs by 10 other smeshers, using the same previous but no collision + var otherIds []types.NodeID for i := 2; i < 6; i++ { otherSig, err := signing.NewEdSigner() require.NoError(t, err) + otherIds = append(otherIds, otherSig.NodeID()) - atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i))) - require.NoError(t, atxs.Add(db, atx, blob)) - - atx2, blob2 := newAtx(t, otherSig, - withPublishEpoch(types.EpochID(i+1)), - ) + atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1))) require.NoError(t, atxs.Add(db, atx2, blob2)) - require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10)) + require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10)) } - // get the collisions - got, err := atxs.PrevATXCollisions(db) + collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID()) require.NoError(t, err) - require.Len(t, got, 1) + require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2}) + + _, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID()) + require.ErrorIs(t, err, sql.ErrNotFound) - require.Equal(t, sig.NodeID(), got[0].NodeID1) - require.Equal(t, sig.NodeID(), got[0].NodeID2) - require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2}) + for _, id := range append(otherIds, types.RandomNodeID()) { + _, _, err := atxs.PrevATXCollision(db, prevATXID, id) + require.ErrorIs(t, err, sql.ErrNotFound) + } } func TestCoinbase(t *testing.T) {