diff --git a/activation/handler_v2.go b/activation/handler_v2.go index 155b5db6d1..51b5bfc805 100644 --- a/activation/handler_v2.go +++ b/activation/handler_v2.go @@ -110,19 +110,20 @@ func (h *HandlerV2) processATX( return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err) } - parts, err := h.syntacticallyValidateDeps(ctx, watx) + atxData, err := h.syntacticallyValidateDeps(ctx, watx) if err != nil { return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err) } + atxData.marriages = marrying atx := &types.ActivationTx{ PublishEpoch: watx.PublishEpoch, MarriageATX: watx.MarriageATX, Coinbase: watx.Coinbase, BaseTickHeight: baseTickHeight, - NumUnits: parts.effectiveUnits, - TickCount: parts.ticks, - Weight: parts.weight, + NumUnits: atxData.effectiveUnits, + TickCount: atxData.ticks, + Weight: atxData.weight, VRFNonce: types.VRFPostIndex(watx.VRFNonce), SmesherID: watx.SmesherID, } @@ -140,7 +141,7 @@ func (h *HandlerV2) processATX( atx.SetID(watx.ID()) atx.SetReceived(received) - if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil { + if err := h.storeAtx(ctx, atx, atxData); err != nil { return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err) } @@ -433,11 +434,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX) } -type atxParts struct { +type idData struct { + previous types.ATXID + units uint32 +} + +type activationTx struct { + *wire.ActivationTxV2 ticks uint64 weight uint64 effectiveUnits uint32 - units map[types.NodeID]uint32 + ids map[types.NodeID]idData + marriages []marriage } type nipostSize struct { @@ -495,9 +503,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error func (h *HandlerV2) syntacticallyValidateDeps( ctx context.Context, atx *wire.ActivationTxV2, -) (*atxParts, error) { - parts := atxParts{ - units: make(map[types.NodeID]uint32), +) (*activationTx, error) { + result := activationTx{ + ActivationTxV2: atx, + ids: make(map[types.NodeID]idData), } if atx.Initial != nil { if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil { @@ -585,7 +594,7 @@ func (h *HandlerV2) syntacticallyValidateDeps( nipostSizes[i].ticks = leaves / h.tickSize } - parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp() + result.effectiveUnits, result.weight, err = nipostSizes.sumUp() if err != nil { return nil, err } @@ -593,9 +602,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( // validate all niposts var smesherCommitment *types.ATXID for _, niposts := range atx.NiPosts { - for _, post := range niposts.Posts { - id := equivocationSet[post.MarriageIndex] + for _, p := range niposts.Posts { + id := equivocationSet[p.MarriageIndex] var commitment types.ATXID + var previous types.ATXID if atx.Initial != nil { commitment = atx.Initial.CommitmentATX } else { @@ -607,15 +617,16 @@ func (h *HandlerV2) syntacticallyValidateDeps( if id == atx.SmesherID { smesherCommitment = &commitment } + previous = previousAtxs[p.PrevATXIndex].ID() } err := h.nipostValidator.PostV2( ctx, id, commitment, - wire.PostFromWireV1(&post.Post), + wire.PostFromWireV1(&p.Post), niposts.Challenge[:], - post.NumUnits, + p.NumUnits, PostSubset([]byte(h.local)), ) invalidIdx := &verifying.ErrInvalidIndex{} @@ -636,7 +647,10 @@ func (h *HandlerV2) syntacticallyValidateDeps( if err != nil { return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err) } - parts.units[id] = post.NumUnits + result.ids[id] = idData{ + previous: previous, + units: p.NumUnits, + } } } @@ -650,18 +664,12 @@ func (h *HandlerV2) syntacticallyValidateDeps( } } - parts.ticks = nipostSizes.minTicks() - return &parts, nil + result.ticks = nipostSizes.minTicks() + return &result, nil } -func (h *HandlerV2) checkMalicious( - ctx context.Context, - tx *sql.Tx, - watx *wire.ActivationTxV2, - marrying []marriage, - ids []types.NodeID, -) error { - malicious, err := identities.IsMalicious(tx, watx.SmesherID) +func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error { + malicious, err := identities.IsMalicious(tx, atx.SmesherID) if err != nil { return fmt.Errorf("checking if node is malicious: %w", err) } @@ -669,7 +677,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMarry(ctx, tx, watx.ID(), marrying) + malicious, err = h.checkDoubleMarry(ctx, tx, atx.ID(), atx.marriages) if err != nil { return fmt.Errorf("checking double marry: %w", err) } @@ -677,7 +685,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoublePost(ctx, tx, watx, ids) + malicious, err = h.checkDoublePost(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double post: %w", err) } @@ -685,7 +693,7 @@ func (h *HandlerV2) checkMalicious( return nil } - malicious, err = h.checkDoubleMerge(ctx, tx, watx) + malicious, err = h.checkDoubleMerge(ctx, tx, atx) if err != nil { return fmt.Errorf("checking double merge: %w", err) } @@ -693,6 +701,14 @@ func (h *HandlerV2) checkMalicious( return nil } + malicious, err = h.checkPrevAtx(ctx, tx, atx) + if err != nil { + return fmt.Errorf("checking previous ATX: %w", err) + } + if malicious { + return nil + } + // TODO(mafa): contextual validation: // 1. check double-publish = ID contributed post to two ATXs in the same epoch // 2. check previous ATX @@ -726,10 +742,9 @@ func (h *HandlerV2) checkDoubleMarry( func (h *HandlerV2) checkDoublePost( ctx context.Context, tx *sql.Tx, - atx *wire.ActivationTxV2, - ids []types.NodeID, + atx *activationTx, ) (bool, error) { - for _, id := range ids { + for id := range atx.ids { atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): @@ -755,49 +770,72 @@ func (h *HandlerV2) checkDoublePost( return false, nil } -func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) { - if watx.MarriageATX == nil { +func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + if atx.MarriageATX == nil { return false, nil } - ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch) + ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch) switch { case errors.Is(err, sql.ErrNotFound): return false, nil case err != nil: return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err) } - otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() }) + otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() }) other := ids[otherIndex] h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof", - zap.Stringer("marriage_atx", *watx.MarriageATX), - zap.Stringer("atx", watx.ID()), + zap.Stringer("marriage_atx", *atx.MarriageATX), + zap.Stringer("atx", atx.ID()), zap.Stringer("other_atx", other), - zap.Stringer("smesher_id", watx.SmesherID), + zap.Stringer("smesher_id", atx.SmesherID), ) // TODO(mafa): finish proof proof := &wire.ATXProof{ ProofType: wire.DoubleMerge, } - return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof) + return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof) +} + +func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) { + for id, data := range atx.ids { + prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch) + if err != nil && !errors.Is(err, sql.ErrNotFound) { + return false, fmt.Errorf("get last atx by node id: %w", err) + } + if prevID == data.previous { + continue + } + + h.logger.Debug("atx references a wrong previous ATX", + log.ZShortStringer("smesherID", id), + log.ZShortStringer("actual", data.previous), + log.ZShortStringer("expected", prevID), + ) + + // TODO(mafa): finish proof + proof := &wire.ATXProof{ + ProofType: wire.InvalidPrevious, + } + return true, h.malPublisher.Publish(ctx, id, proof) + } + return false, nil } // Store an ATX in the DB. func (h *HandlerV2) storeAtx( ctx context.Context, atx *types.ActivationTx, - watx *wire.ActivationTxV2, - marrying []marriage, - units map[types.NodeID]uint32, + watx *activationTx, ) error { if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error { - if len(marrying) != 0 { + if len(watx.marriages) != 0 { marriageData := identities.MarriageData{ ATX: atx.ID(), Target: atx.SmesherID, } - for i, m := range marrying { + for i, m := range watx.marriages { marriageData.Signature = m.signature marriageData.Index = i if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil { @@ -810,8 +848,8 @@ func (h *HandlerV2) storeAtx( if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("add atx to db: %w", err) } - for id, units := range units { - err = atxs.SetUnits(tx, atx.ID(), id, units) + for id, p := range watx.ids { + err = atxs.SetUnits(tx, atx.ID(), id, p.units) if err != nil && !errors.Is(err, sql.ErrObjectExists) { return fmt.Errorf("setting atx units for ID %s: %w", id, err) } @@ -830,7 +868,7 @@ func (h *HandlerV2) storeAtx( // TODO(mafa): don't store own ATX if it would mark the node as malicious // this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in // the gossip handler (not sync!) - err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units)) + err := h.checkMalicious(ctx, tx, watx) if err != nil { return fmt.Errorf("check malicious: %w", err) } diff --git a/activation/handler_v2_test.go b/activation/handler_v2_test.go index 06f52f9b69..4f0bafaf3b 100644 --- a/activation/handler_v2_test.go +++ b/activation/handler_v2_test.go @@ -1898,6 +1898,52 @@ func Test_CalculatingUnits(t *testing.T) { }) } +func TestContextual_PreviousATX(t *testing.T) { + golden := types.RandomATXID() + atxHndlr := newV2TestHandler(t, golden) + var ( + signers []*signing.EdSigner + eqSet []types.NodeID + ) + for range 3 { + sig, err := signing.NewEdSigner() + require.NoError(t, err) + signers = append(signers, sig) + eqSet = append(eqSet, sig.NodeID()) + } + + mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden) + + // signer 1 creates a solo ATX + soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID()) + soloAtx.Sign(signers[1]) + atxHndlr.expectAtxV2(soloAtx) + err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now()) + require.NoError(t, err) + + // create a MergedATX for all IDs + merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID()) + post := wire.SubPostV2{ + MarriageIndex: 1, + PrevATXIndex: 1, + NumUnits: soloAtx.TotalNumUnits(), + } + merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post) + // Pass a wrong previous ATX for signer 1. It's already been used for soloATX + // (which should be used for the previous ATX for signer 1). + merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID()) + matxID := mATX.ID() + merged.MarriageATX = &matxID + merged.Sign(signers[0]) + + atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100}) + atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Cond(func(data any) bool { + return data.(*wire.ATXProof).ProofType == wire.InvalidPrevious + })) + err = atxHndlr.processATX(context.Background(), "", merged, time.Now()) + require.NoError(t, err) +} + func Test_CalculatingWeight(t *testing.T) { t.Parallel() t.Run("total weight must not overflow uint64", func(t *testing.T) { diff --git a/activation/wire/malfeasance.go b/activation/wire/malfeasance.go index d8e60a4127..29460ec485 100644 --- a/activation/wire/malfeasance.go +++ b/activation/wire/malfeasance.go @@ -13,6 +13,7 @@ const ( DoublePublish ProofType = iota + 1 DoubleMarry DoubleMerge + InvalidPrevious InvalidPost ) diff --git a/sql/atxs/atxs.go b/sql/atxs/atxs.go index 5e14cddde1..e2f9fe6729 100644 --- a/sql/atxs/atxs.go +++ b/sql/atxs/atxs.go @@ -248,7 +248,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er } // PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch. -// It returns the newest ATX ID that was published before the given public epoch. +// It returns the newest ATX ID containing PoST of the given node ID +// that was published before the given public epoch. func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) { enc := func(stmt *sql.Statement) { stmt.BindBytes(1, nodeID.Bytes()) @@ -260,10 +261,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID } if rows, err := db.Exec(` - select id from atxs - where pubkey = ?1 and epoch < ?2 - order by epoch desc - limit 1;`, enc, dec); err != nil { + SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id + WHERE posts.pubkey = ?1 AND atxs.epoch < ?2 + ORDER BY atxs.epoch DESC + LIMIT 1;`, enc, dec); err != nil { return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err) } else if rows == 0 { return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound) diff --git a/sql/atxs/atxs_test.go b/sql/atxs/atxs_test.go index 1dd1914968..ecdd704cd5 100644 --- a/sql/atxs/atxs_test.go +++ b/sql/atxs/atxs_test.go @@ -1330,3 +1330,56 @@ func Test_MergeConflict(t *testing.T) { require.Len(t, ids, 2) }) } + +func TestPrevIDByNodeID(t *testing.T) { + t.Run("no previous ATXs", func(t *testing.T) { + db := sql.InMemory() + _, err := atxs.PrevIDByNodeID(db, types.RandomNodeID(), 0) + require.ErrorIs(t, err, sql.ErrNotFound) + }) + t.Run("filters by epoch", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + + _, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 1) + require.ErrorIs(t, err, sql.ErrNotFound) + + prevID, err := atxs.PrevIDByNodeID(db, sig.NodeID(), 2) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + + prevID, err = atxs.PrevIDByNodeID(db, sig.NodeID(), 3) + require.NoError(t, err) + require.Equal(t, atx2.ID(), prevID) + }) + t.Run("the previous is merged and ID is not the signer", func(t *testing.T) { + db := sql.InMemory() + sig, err := signing.NewEdSigner() + require.NoError(t, err) + id := types.RandomNodeID() + + atx1, blob1 := newAtx(t, sig, withPublishEpoch(1)) + require.NoError(t, atxs.Add(db, atx1, blob1)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), id, 8)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + atx2, blob2 := newAtx(t, sig, withPublishEpoch(2)) + require.NoError(t, atxs.Add(db, atx2, blob2)) + require.NoError(t, atxs.SetUnits(db, atx2.ID(), sig.NodeID(), 4)) + require.NoError(t, atxs.SetUnits(db, atx1.ID(), types.RandomNodeID(), 12)) + + prevID, err := atxs.PrevIDByNodeID(db, id, 3) + require.NoError(t, err) + require.Equal(t, atx1.ID(), prevID) + }) +}