Skip to content

Commit

Permalink
implement finding previous ATX collision
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Aug 12, 2024
1 parent f009de0 commit 6fa08cd
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 55 deletions.
20 changes: 17 additions & 3 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -797,18 +797,32 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activ

func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id, data := range atx.ids {
prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, fmt.Errorf("get last atx by node id: %w", err)

Check warning on line 802 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L802

Added line #L802 was not covered by tests
}
if prevID == data.previous {
if expectedPrevID == data.previous {
continue
}

h.logger.Debug("atx references a wrong previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("actual", data.previous),
log.ZShortStringer("expected", prevID),
log.ZShortStringer("expected", expectedPrevID),
)

atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id)
switch {
case errors.Is(err, sql.ErrNotFound):
continue
case err != nil:
return false, fmt.Errorf("checking for previous ATX collision: %w", err)

Check warning on line 819 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L816-L819

Added lines #L816 - L819 were not covered by tests
}

h.logger.Debug("creating a malfeasance proof for invalid previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("atx1", atx1),
log.ZShortStringer("atx2", atx2),
)

// TODO(mafa): finish proof
Expand Down
54 changes: 17 additions & 37 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -874,46 +874,26 @@ func IterateAtxIdsWithMalfeasance(
return err
}

type PrevATXCollision struct {
NodeID1 types.NodeID
ATX1 types.ATXID

NodeID2 types.NodeID
ATX2 types.ATXID
}

func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) {
var result []PrevATXCollision

func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) {
var atxs []types.ATXID
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, prev[:])
stmt.BindBytes(2, id[:])
}
dec := func(stmt *sql.Statement) bool {
var nodeID1, nodeID2 types.NodeID
stmt.ColumnBytes(0, nodeID1[:])
stmt.ColumnBytes(1, nodeID2[:])

var id1, id2 types.ATXID
stmt.ColumnBytes(2, id1[:])
stmt.ColumnBytes(3, id2[:])

result = append(result, PrevATXCollision{
NodeID1: nodeID1,
ATX1: id1,

NodeID2: nodeID2,
ATX2: id2,
})
return true
var id types.ATXID
stmt.ColumnBytes(0, id[:])
atxs = append(atxs, id)
return len(atxs) < 2
}
// we are joining the table with itself to find ATXs with the same prevATX
// the WHERE clause ensures that we only get the pairs once
if _, err := db.Exec(`
SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid
FROM posts p1
INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid
WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil {
return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err)
_, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec)
if err != nil {
return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err)

Check warning on line 891 in sql/atxs/atxs.go

View check run for this annotation

Codecov / codecov/patch

sql/atxs/atxs.go#L891

Added line #L891 was not covered by tests
}

return result, nil
if len(atxs) != 2 {
return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound
}
return atxs[0], atxs[1], nil
}

func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) {
Expand Down
30 changes: 15 additions & 15 deletions sql/atxs/atxs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) {
}
}

func Test_PrevATXCollisions(t *testing.T) {
func Test_PrevATXCollision(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)
Expand All @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, atx2, got2)

// add 10 valid ATXs by 10 other smeshers
// add 10 valid ATXs by 10 other smeshers, using the same previous but no collision
var otherIds []types.NodeID
for i := 2; i < 6; i++ {
otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
otherIds = append(otherIds, otherSig.NodeID())

atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i)))
require.NoError(t, atxs.Add(db, atx, blob))

atx2, blob2 := newAtx(t, otherSig,
withPublishEpoch(types.EpochID(i+1)),
)
atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1)))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10))
require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10))
}

// get the collisions
got, err := atxs.PrevATXCollisions(db)
collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID())
require.NoError(t, err)
require.Len(t, got, 1)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2})

_, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID())
require.ErrorIs(t, err, sql.ErrNotFound)

require.Equal(t, sig.NodeID(), got[0].NodeID1)
require.Equal(t, sig.NodeID(), got[0].NodeID2)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2})
for _, id := range append(otherIds, types.RandomNodeID()) {
_, _, err := atxs.PrevATXCollision(db, prevATXID, id)
require.ErrorIs(t, err, sql.ErrNotFound)
}
}

func TestCoinbase(t *testing.T) {
Expand Down

0 comments on commit 6fa08cd

Please sign in to comment.