From ef16cecc4be7d1324772ba9416a39752be3c8587 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bartosz=20R=C3=B3=C5=BCa=C5=84ski?= Date: Mon, 29 Jul 2024 15:09:19 +0200 Subject: [PATCH] previous atx malfeasance check for ATX V2 --- activation/handler_v2.go | 127 +++++++++++++++++++++++----------- activation/handler_v2_test.go | 77 +++++++++++++++++++++ sql/atxs/atxs.go | 11 +-- sql/atxs/atxs_test.go | 53 ++++++++++++++ 4 files changed, 222 insertions(+), 46 deletions(-) diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 13c6ed084a3..488de71ae0c 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -111,10 +111,11 @@ func (h *HandlerV2) processATX( return nil, fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err) } - parts, proof, err := h.syntacticallyValidateDeps(ctx, watx) + atxData, proof, err := h.syntacticallyValidateDeps(ctx, watx) if err != nil { return nil, fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err) } + atxData.marriages = marrying if proof != nil { return proof, err @@ -125,9 +126,9 @@ func (h *HandlerV2) processATX( MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, - NumUnits: parts.effectiveUnits, - TickCount: parts.ticks, - Weight: parts.weight, + NumUnits: atxData.effectiveUnits, + TickCount: atxData.ticks, + Weight: atxData.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, } @@ -145,7 +146,7 @@ func (h *HandlerV2) processATX( atx.SetID(watx.ID()) atx.SetReceived(received) - proof, err = h.storeAtx(ctx, atx, watx, marrying, parts.units) + proof, err = h.storeAtx(ctx, atx, atxData) if err != nil { return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } @@ -439,11 +440,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX) } -type atxParts struct { +type idData struct { + previous types.ATXID + units uint32 +} + +type activationTx struct { + *wire.ActivationTxV2 ticks uint64 weight uint64 effectiveUnits uint32 - units map[types.NodeID]uint32 + ids map[types.NodeID]idData + marriages []marriage } type nipostSize struct { @@ -501,9 +509,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error func (h *HandlerV2) syntacticallyValidateDeps( ctx context.Context, atx *wire.ActivationTxV2, -) (*atxParts, *mwire.MalfeasanceProof, error) { - parts := atxParts{ - units: make(map[types.NodeID]uint32), +) (*activationTx, *mwire.MalfeasanceProof, error) { + result := activationTx{ + ActivationTxV2: atx, + ids: make(map[types.NodeID]idData), } if atx.Initial != nil { if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil { @@ -591,7 +600,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( nipostSizes[i].ticks = leaves / h.tickSize } - parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp() + result.effectiveUnits, result.weight, err = nipostSizes.sumUp() if err != nil { return nil, nil, err } @@ -599,9 +608,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( // validate all niposts var smesherCommitment *types.ATXID for _, niposts := range atx.NiPosts { - for _, post := range niposts.Posts { - id := equivocationSet[post.MarriageIndex] + for _, p := range niposts.Posts { + id := equivocationSet[p.MarriageIndex] var commitment types.ATXID + var previous types.ATXID if atx.Initial != nil { commitment = atx.Initial.CommitmentATX } else { @@ -613,15 +623,16 @@ func (h *HandlerV2) syntacticallyValidateDeps( if id == atx.SmesherID { smesherCommitment = &commitment } + previous = previousAtxs[p.PrevATXIndex].ID() } err := h.nipostValidator.PostV2( ctx, id, commitment, - wire.PostFromWireV1(&post.Post), + wire.PostFromWireV1(&p.Post), niposts.Challenge[:], - post.NumUnits, + p.NumUnits, PostSubset([]byte(h.local)), ) var invalidIdx *verifying.ErrInvalidIndex @@ -636,7 +647,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) } - parts.units[id] = post.NumUnits + result.ids[id] = idData{ + previous: previous, + units: p.NumUnits, + } } } @@ -650,17 +664,13 @@ func (h *HandlerV2) syntacticallyValidateDeps( } } - parts.ticks = nipostSizes.minTicks() + result.ticks = nipostSizes.minTicks() - return &parts, nil, nil + return &result, nil, nil } -func (h *HandlerV2) checkMalicious( - tx *sql.Tx, - watx *wire.ActivationTxV2, - marrying []marriage, -) (bool, *mwire.MalfeasanceProof, error) { - malicious, err := identities.IsMalicious(tx, watx.SmesherID) +func (h *HandlerV2) checkMalicious(tx *sql.Tx, atx *activationTx) (bool, *mwire.MalfeasanceProof, error) { + malicious, err := identities.IsMalicious(tx, atx.SmesherID) if err != nil { return false, nil, fmt.Errorf("checking if node is malicious: %w", err) } @@ -668,7 +678,7 @@ func (h *HandlerV2) checkMalicious( return true, nil, nil } - proof, err := h.checkDoubleMarry(tx, marrying) + proof, err := h.checkDoubleMarry(tx, atx.marriages) if err != nil { return false, nil, fmt.Errorf("checking double marry: %w", err) } @@ -676,7 +686,7 @@ func (h *HandlerV2) checkMalicious( return true, proof, nil } - proof, err = h.checkDoubleMerge(tx, watx) + proof, err = h.checkDoubleMerge(tx, atx) if err != nil { return false, nil, fmt.Errorf("checking double merge: %w", err) } @@ -684,6 +694,14 @@ func (h *HandlerV2) checkMalicious( return true, proof, nil } + proof, err = h.checkPrevAtx(tx, atx) + if err != nil { + return false, nil, fmt.Errorf("checking previous ATX: %w", err) + } + if proof != nil { + return true, proof, nil + } + // TODO: contextual validation: // 1. check double-publish // 2. check previous ATX @@ -713,11 +731,11 @@ func (h *HandlerV2) checkDoubleMarry(tx *sql.Tx, marrying []marriage) (*mwire.Ma return nil, nil } -func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) { - if watx.MarriageATX == nil { +func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, atx *activationTx) (*mwire.MalfeasanceProof, error) { + if atx.MarriageATX == nil { return nil, nil } - id, err := atxs.AtxWithMarriage(tx, *watx.MarriageATX, watx.PublishEpoch) + id, err := atxs.AtxWithMarriage(tx, *atx.MarriageATX, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): return nil, nil @@ -726,10 +744,10 @@ func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mw } h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", - zap.Stringer("marriage_atx", *watx.MarriageATX), - zap.Stringer("atx", watx.ID()), + zap.Stringer("marriage_atx", *atx.MarriageATX), + zap.Stringer("atx", atx.ID()), zap.Stringer("other atx", id), - zap.Stringer("smesher_id", watx.SmesherID), + zap.Stringer("smesher_id", atx.SmesherID), ) // FIXME: implement the proof @@ -742,14 +760,41 @@ func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mw return proof, nil } +func (h *HandlerV2) checkPrevAtx(tx *sql.Tx, atx *activationTx) (*mwire.MalfeasanceProof, error) { + for id, data := range atx.ids { + prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return nil, fmt.Errorf("get last atx by node id: %w", err) + } + if prevID == data.previous { + continue + } + + h.logger.Debug("atx references a wrong previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("actual", data.previous), + log.ZShortStringer("expected", prevID), + ) + + // FIXME: implement the proof + proof := &mwire.MalfeasanceProof{ + Proof: mwire.Proof{ + Type: mwire.DoubleMarry, + Data: &mwire.DoubleMarryProof{}, + }, + } + return proof, nil + + } + return nil, nil +} + // Store an ATX in the DB. // TODO: detect malfeasance and create proofs. func (h *HandlerV2) storeAtx( ctx context.Context, atx *types.ActivationTx, - watx *wire.ActivationTxV2, - marrying []marriage, - units map[types.NodeID]uint32, + watx *activationTx, ) (*mwire.MalfeasanceProof, error) { var ( malicious bool @@ -757,17 +802,17 @@ func (h *HandlerV2) storeAtx( ) if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { var err error - malicious, proof, err = h.checkMalicious(tx, watx, marrying) + malicious, proof, err = h.checkMalicious(tx, watx) if err != nil { return fmt.Errorf("check malicious: %w", err) } - if len(marrying) != 0 { + if len(watx.marriages) != 0 { marriageData := identities.MarriageData{ ATX: atx.ID(), Target: atx.SmesherID, } - for i, m := range marrying { + for i, m := range watx.marriages { marriageData.Signature = m.signature marriageData.Index = i if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil { @@ -787,8 +832,8 @@ func (h *HandlerV2) storeAtx( if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } - for id, units := range units { - err = atxs.SetUnits(tx, atx.ID(), id, units) + for id, p := range watx.ids { + err = atxs.SetUnits(tx, atx.ID(), id, p.units) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("setting atx units for ID %s: %w", id, err) } @@ -812,7 +857,7 @@ func (h *HandlerV2) storeAtx( for _, id := range set { allMalicious[id] = struct{}{} } - for _, m := range marrying { + for _, m := range watx.marriages { allMalicious[m.id] = struct{}{} } } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index f67d414bf46..9bb5ba17a03 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -647,6 +647,36 @@ func marryIDs( return mATX, other } +func marryIDsV2( + t testing.TB, + atxHandler *v2TestHandler, + signers []*signing.EdSigner, + golden types.ATXID, +) (marriage *wire.ActivationTxV2, other []*wire.ActivationTxV2) { + sig := signers[0] + mATX := newInitialATXv2(t, golden) + mATX.Marriages = []wire.MarriageCertificate{{ + Signature: sig.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }} + + for _, signer := range signers[1:] { + atx := atxHandler.createAndProcessInitial(t, signer) + other = append(other, atx) + mATX.Marriages = append(mATX.Marriages, wire.MarriageCertificate{ + ReferenceAtx: atx.ID(), + Signature: signer.Sign(signing.MARRIAGE, sig.NodeID().Bytes()), + }) + } + + mATX.Sign(sig) + atxHandler.expectInitialAtxV2(mATX) + p, err := atxHandler.processATX(context.Background(), "", mATX, time.Now()) + require.NoError(t, err) + require.Nil(t, p) + + return mATX, other +} + func TestHandlerV2_ProcessMergedATX(t *testing.T) { t.Parallel() var ( @@ -1865,6 +1895,53 @@ func Test_CalculatingUnits(t *testing.T) { }) } +func TestContextual_PreviousATX(t *testing.T) { + golden := types.RandomATXID() + atxHndlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDsV2(t, atxHndlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHndlr.expectAtxV2(soloAtx) + _, err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + for _, sig := range signers { + atxHndlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID()) + } + p, err := atxHndlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) + require.NotNil(t, p) +} + func Test_CalculatingWeight(t *testing.T) { t.Parallel() t.Run("total weight must not overflow uint64", func(t *testing.T) { diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 6aaeb86cad6..3667b405050 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -248,7 +248,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID +// that was published before the given public epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -260,10 +261,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID } if rows, err := db.Exec(` - select id from atxs - where pubkey = ?1 and epoch < ?2 - order by epoch desc - limit 1;`, enc, dec); err != nil { + SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id + WHERE posts.pubkey = ?1 AND atxs.epoch < ?2 + ORDER BY atxs.epoch DESC + LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) } else if rows == 0 { return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index d3ac0dd3de1..95da9d0fc26 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1201,3 +1201,56 @@ func Test_AtxWithPrevious(t *testing.T) { require.Equal(t, atx2.ID(), id) }) } + +func TestPrevIDByNodeID(t *testing.T) { + t.Run("no previous ATXs", func(t *testing.T) { + db := sql.InMemory() + _, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("filters by epoch", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + require.ErrorIs(t, err, sql.ErrNotFound) + + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + }) + t.Run("the previous is merged and ID is not the signer", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + id := types.RandomNodeID() + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), id, 8)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + prevID, err := atxs.PrevIDByNodeID(db, id, 3) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + }) +}