Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Merged by Bors] - detect invalid previous ATX for V2 ATXs #6189

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 53 additions & 8 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,14 @@
return nil
}

malicious, err = h.checkPrevAtx(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking previous ATX: %w", err)

Check warning on line 704 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L704

Added line #L704 was not covered by tests
}
if malicious {
return nil
}

// TODO(mafa): contextual validation:
// 1. check double-publish = ID contributed post to two ATXs in the same epoch
// 2. check previous ATX
Expand Down Expand Up @@ -762,29 +770,66 @@
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) {
if watx.MarriageATX == nil {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
if atx.MarriageATX == nil {
return false, nil
}
ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch)
ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
return false, nil
case err != nil:
return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err)
}
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() })
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() })
other := ids[otherIndex]

h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof",
zap.Stringer("marriage_atx", *watx.MarriageATX),
zap.Stringer("atx", watx.ID()),
zap.Stringer("marriage_atx", *atx.MarriageATX),
zap.Stringer("atx", atx.ID()),
zap.Stringer("other_atx", other),
zap.Stringer("smesher_id", watx.SmesherID),
zap.Stringer("smesher_id", atx.SmesherID),
)

var proof wire.Proof
return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof)
return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof)
}

func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id, data := range atx.ids {
expectedPrevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, fmt.Errorf("get last atx by node id: %w", err)

Check warning on line 802 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L802

Added line #L802 was not covered by tests
}
if expectedPrevID == data.previous {
continue
}

h.logger.Debug("atx references a wrong previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("actual", data.previous),
log.ZShortStringer("expected", expectedPrevID),
)

atx1, atx2, err := atxs.PrevATXCollision(tx, data.previous, id)
switch {
case errors.Is(err, sql.ErrNotFound):
continue
case err != nil:
return false, fmt.Errorf("checking for previous ATX collision: %w", err)

Check warning on line 819 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L816-L819

Added lines #L816 - L819 were not covered by tests
}

h.logger.Debug("creating a malfeasance proof for invalid previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("atx1", atx1),
log.ZShortStringer("atx2", atx2),
)

// TODO(mafa): finish proof
var proof wire.Proof
return true, h.malPublisher.Publish(ctx, id, proof)
}
return false, nil
}

// Store an ATX in the DB.
Expand Down
44 changes: 44 additions & 0 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1894,6 +1894,50 @@ func Test_CalculatingUnits(t *testing.T) {
})
}

func TestContextual_PreviousATX(t *testing.T) {
golden := types.RandomATXID()
atxHndlr := newV2TestHandler(t, golden)
var (
signers []*signing.EdSigner
eqSet []types.NodeID
)
for range 3 {
sig, err := signing.NewEdSigner()
require.NoError(t, err)
signers = append(signers, sig)
eqSet = append(eqSet, sig.NodeID())
}

mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden)

// signer 1 creates a solo ATX
soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID())
soloAtx.Sign(signers[1])
atxHndlr.expectAtxV2(soloAtx)
err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now())
require.NoError(t, err)

// create a MergedATX for all IDs
merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID())
post := wire.SubPostV2{
MarriageIndex: 1,
PrevATXIndex: 1,
NumUnits: soloAtx.TotalNumUnits(),
}
merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post)
// Pass a wrong previous ATX for signer 1. It's already been used for soloATX
// (which should be used for the previous ATX for signer 1).
merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID())
matxID := mATX.ID()
merged.MarriageATX = &matxID
merged.Sign(signers[0])

atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100})
atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Any())
err = atxHndlr.processATX(context.Background(), "", merged, time.Now())
require.NoError(t, err)
}

func Test_CalculatingWeight(t *testing.T) {
t.Parallel()
t.Run("total weight must not overflow uint64", func(t *testing.T) {
Expand Down
9 changes: 5 additions & 4 deletions activation/wire/malfeasance.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@ const (
LegacyInvalidPost ProofType = 0x01
LegacyInvalidPrevATX ProofType = 0x02

DoublePublish ProofType = 0x10
DoubleMarry ProofType = 0x11
DoubleMerge ProofType = 0x12
InvalidPost ProofType = 0x13
DoublePublish ProofType = 0x10
DoubleMarry ProofType = 0x11
DoubleMerge ProofType = 0x12
InvalidPost ProofType = 0x13
InvalidPrevious ProofType = 0x14
)

// ProofVersion is an identifier for the version of the proof that is encoded in the ATXProof.
Expand Down
65 changes: 23 additions & 42 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@
}

// PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch.
// It returns the newest ATX ID that was published before the given public epoch.
// It returns the newest ATX ID containing PoST of the given node ID
// that was published before the given public epoch.
func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) {
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, nodeID.Bytes())
Expand All @@ -257,10 +258,10 @@
}

if rows, err := db.Exec(`
select id from atxs
where pubkey = ?1 and epoch < ?2
order by epoch desc
limit 1;`, enc, dec); err != nil {
SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id
WHERE posts.pubkey = ?1 AND atxs.epoch < ?2
ORDER BY atxs.epoch DESC
LIMIT 1;`, enc, dec); err != nil {
return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err)
} else if rows == 0 {
return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound)
Expand Down Expand Up @@ -873,46 +874,26 @@
return err
}

type PrevATXCollision struct {
NodeID1 types.NodeID
ATX1 types.ATXID

NodeID2 types.NodeID
ATX2 types.ATXID
}

func PrevATXCollisions(db sql.Executor) ([]PrevATXCollision, error) {
var result []PrevATXCollision

func PrevATXCollision(db sql.Executor, prev types.ATXID, id types.NodeID) (types.ATXID, types.ATXID, error) {
var atxs []types.ATXID
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, prev[:])
stmt.BindBytes(2, id[:])
}
dec := func(stmt *sql.Statement) bool {
var nodeID1, nodeID2 types.NodeID
stmt.ColumnBytes(0, nodeID1[:])
stmt.ColumnBytes(1, nodeID2[:])

var id1, id2 types.ATXID
stmt.ColumnBytes(2, id1[:])
stmt.ColumnBytes(3, id2[:])

result = append(result, PrevATXCollision{
NodeID1: nodeID1,
ATX1: id1,

NodeID2: nodeID2,
ATX2: id2,
})
return true
var id types.ATXID
stmt.ColumnBytes(0, id[:])
atxs = append(atxs, id)
return len(atxs) < 2
}
// we are joining the table with itself to find ATXs with the same prevATX
// the WHERE clause ensures that we only get the pairs once
if _, err := db.Exec(`
SELECT p1.pubkey, p2.pubkey, p1.atxid, p2.atxid
FROM posts p1
INNER JOIN posts p2 ON p1.prev_atxid = p2.prev_atxid
WHERE p1.atxid < p2.atxid;`, nil, dec); err != nil {
return nil, fmt.Errorf("error getting ATXs with same prevATX: %w", err)
_, err := db.Exec("SELECT atxid FROM posts WHERE prev_atxid = ?1 AND pubkey = ?2;", enc, dec)
fasmat marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return types.EmptyATXID, types.EmptyATXID, fmt.Errorf("error getting ATXs with same prevATX: %w", err)

Check warning on line 891 in sql/atxs/atxs.go

View check run for this annotation

Codecov / codecov/patch

sql/atxs/atxs.go#L891

Added line #L891 was not covered by tests
}

return result, nil
if len(atxs) != 2 {
return types.EmptyATXID, types.EmptyATXID, sql.ErrNotFound
}
return atxs[0], atxs[1], nil
}

func Units(db sql.Executor, atxID types.ATXID, nodeID types.NodeID) (uint32, error) {
Expand Down
83 changes: 68 additions & 15 deletions sql/atxs/atxs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,7 @@ func TestLatest(t *testing.T) {
}
}

func Test_PrevATXCollisions(t *testing.T) {
func Test_PrevATXCollision(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)
Expand All @@ -1048,29 +1048,29 @@ func Test_PrevATXCollisions(t *testing.T) {
require.NoError(t, err)
require.Equal(t, atx2, got2)

// add 10 valid ATXs by 10 other smeshers
// add 10 valid ATXs by 10 other smeshers, using the same previous but no collision
var otherIds []types.NodeID
for i := 2; i < 6; i++ {
otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
otherIds = append(otherIds, otherSig.NodeID())

atx, blob := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i)))
require.NoError(t, atxs.Add(db, atx, blob))

atx2, blob2 := newAtx(t, otherSig,
withPublishEpoch(types.EpochID(i+1)),
)
atx2, blob2 := newAtx(t, otherSig, withPublishEpoch(types.EpochID(i+1)))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx.ID(), 0, sig.NodeID(), 10))
require.NoError(t, atxs.SetPost(db, atx2.ID(), prevATXID, 0, atx2.SmesherID, 10))
}

// get the collisions
got, err := atxs.PrevATXCollisions(db)
collision1, collision2, err := atxs.PrevATXCollision(db, prevATXID, sig.NodeID())
require.NoError(t, err)
require.Len(t, got, 1)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{collision1, collision2})

require.Equal(t, sig.NodeID(), got[0].NodeID1)
require.Equal(t, sig.NodeID(), got[0].NodeID2)
require.ElementsMatch(t, []types.ATXID{atx1.ID(), atx2.ID()}, []types.ATXID{got[0].ATX1, got[0].ATX2})
_, _, err = atxs.PrevATXCollision(db, types.RandomATXID(), sig.NodeID())
require.ErrorIs(t, err, sql.ErrNotFound)

for _, id := range append(otherIds, types.RandomNodeID()) {
_, _, err := atxs.PrevATXCollision(db, prevATXID, id)
require.ErrorIs(t, err, sql.ErrNotFound)
}
}

func TestCoinbase(t *testing.T) {
Expand Down Expand Up @@ -1362,3 +1362,56 @@ func Test_Previous(t *testing.T) {
require.Equal(t, previousAtxs, got)
})
}

func TestPrevIDByNodeID(t *testing.T) {
t.Run("no previous ATXs", func(t *testing.T) {
db := sql.InMemory()
_, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0)
require.ErrorIs(t, err, sql.ErrNotFound)
})
t.Run("filters by epoch", func(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)

atx1, blob1 := newAtx(t, sig, withPublishEpoch(1))
require.NoError(t, atxs.Add(db, atx1, blob1))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))

atx2, blob2 := newAtx(t, sig, withPublishEpoch(2))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))

_, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1)
require.ErrorIs(t, err, sql.ErrNotFound)

prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2)
require.NoError(t, err)
require.Equal(t, atx1.ID(), prevID)

prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3)
require.NoError(t, err)
require.Equal(t, atx2.ID(), prevID)
})
t.Run("the previous is merged and ID is not the signer", func(t *testing.T) {
db := sql.InMemory()
sig, err := signing.NewEdSigner()
require.NoError(t, err)
id := types.RandomNodeID()

atx1, blob1 := newAtx(t, sig, withPublishEpoch(1))
require.NoError(t, atxs.Add(db, atx1, blob1))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, sig.NodeID(), 4))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, id, 8))
require.NoError(t, atxs.SetPost(db, atx1.ID(), types.EmptyATXID, 0, types.RandomNodeID(), 12))

atx2, blob2 := newAtx(t, sig, withPublishEpoch(2))
require.NoError(t, atxs.Add(db, atx2, blob2))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, sig.NodeID(), 4))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx1.ID(), 0, types.RandomNodeID(), 12))

prevID, err := atxs.PrevIDByNodeID(db, id, 3)
require.NoError(t, err)
require.Equal(t, atx1.ID(), prevID)
})
}
Loading