Skip to content

Commit

Permalink
previous atx malfeasance check for ATX V2
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Aug 6, 2024
1 parent dc5fcaf commit 5c75d66
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 53 deletions.
134 changes: 86 additions & 48 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ func (h *HandlerV2) processATX(
return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, err := h.syntacticallyValidateDeps(ctx, watx)
atxData, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
atxData.marriages = marrying

atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
MarriageATX: watx.MarriageATX,
Coinbase: watx.Coinbase,
BaseTickHeight: baseTickHeight,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
NumUnits: atxData.effectiveUnits,
TickCount: atxData.ticks,
Weight: atxData.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
}
Expand All @@ -140,7 +141,7 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil {
if err := h.storeAtx(ctx, atx, atxData); err != nil {
return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}

Expand Down Expand Up @@ -433,11 +434,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX)
}

type atxParts struct {
type idData struct {
previous types.ATXID
units uint32
}

type activationTx struct {
*wire.ActivationTxV2
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
ids map[types.NodeID]idData
marriages []marriage
}

type nipostSize struct {
Expand Down Expand Up @@ -495,9 +503,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error
func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
) (*activationTx, error) {
result := activationTx{
ActivationTxV2: atx,
ids: make(map[types.NodeID]idData),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
Expand Down Expand Up @@ -585,17 +594,18 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
result.effectiveUnits, result.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, err
}

// validate all niposts
var smesherCommitment *types.ATXID
for _, niposts := range atx.NiPosts {
for _, post := range niposts.Posts {
id := equivocationSet[post.MarriageIndex]
for _, p := range niposts.Posts {
id := equivocationSet[p.MarriageIndex]
var commitment types.ATXID
var previous types.ATXID
if atx.Initial != nil {
commitment = atx.Initial.CommitmentATX
} else {
Expand All @@ -607,15 +617,16 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if id == atx.SmesherID {
smesherCommitment = &commitment
}
previous = previousAtxs[p.PrevATXIndex].ID()
}

err := h.nipostValidator.PostV2(
ctx,
id,
commitment,
wire.PostFromWireV1(&post.Post),
wire.PostFromWireV1(&p.Post),
niposts.Challenge[:],
post.NumUnits,
p.NumUnits,
PostSubset([]byte(h.local)),
)
invalidIdx := &verifying.ErrInvalidIndex{}
Expand All @@ -636,7 +647,10 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err)
}
parts.units[id] = post.NumUnits
result.ids[id] = idData{
previous: previous,
units: p.NumUnits,
}
}
}

Expand All @@ -650,49 +664,51 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

parts.ticks = nipostSizes.minTicks()
return &parts, nil
result.ticks = nipostSizes.minTicks()
return &result, nil
}

func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []marriage,
ids []types.NodeID,
) error {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error {
malicious, err := identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return fmt.Errorf("checking if node is malicious: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMarry(ctx, tx, watx.ID(), marrying)
malicious, err = h.checkDoubleMarry(ctx, tx, atx.ID(), atx.marriages)
if err != nil {
return fmt.Errorf("checking double marry: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoublePost(ctx, tx, watx, ids)
malicious, err = h.checkDoublePost(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double post: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMerge(ctx, tx, watx)
malicious, err = h.checkDoubleMerge(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double merge: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkPrevAtx(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking previous ATX: %w", err)

Check warning on line 706 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L706

Added line #L706 was not covered by tests
}
if malicious {
return nil
}

// TODO(mafa): contextual validation:
// 1. check double-publish = ID contributed post to two ATXs in the same epoch
// 2. check previous ATX
Expand Down Expand Up @@ -726,10 +742,9 @@ func (h *HandlerV2) checkDoubleMarry(
func (h *HandlerV2) checkDoublePost(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
ids []types.NodeID,
atx *activationTx,
) (bool, error) {
for _, id := range ids {
for id := range atx.ids {
atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -755,49 +770,72 @@ func (h *HandlerV2) checkDoublePost(
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) {
if watx.MarriageATX == nil {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
if atx.MarriageATX == nil {
return false, nil
}
ids, err := atxs.MergeConflict(tx, *watx.MarriageATX, watx.PublishEpoch)
ids, err := atxs.MergeConflict(tx, *atx.MarriageATX, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
return false, nil
case err != nil:
return false, fmt.Errorf("searching for ATXs with the same marriage ATX: %w", err)
}
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != watx.ID() })
otherIndex := slices.IndexFunc(ids, func(id types.ATXID) bool { return id != atx.ID() })
other := ids[otherIndex]

h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof",
zap.Stringer("marriage_atx", *watx.MarriageATX),
zap.Stringer("atx", watx.ID()),
zap.Stringer("marriage_atx", *atx.MarriageATX),
zap.Stringer("atx", atx.ID()),
zap.Stringer("other_atx", other),
zap.Stringer("smesher_id", watx.SmesherID),
zap.Stringer("smesher_id", atx.SmesherID),
)

// TODO(mafa): finish proof
proof := &wire.ATXProof{
ProofType: wire.DoubleMerge,
}
return true, h.malPublisher.Publish(ctx, watx.SmesherID, proof)
return true, h.malPublisher.Publish(ctx, atx.SmesherID, proof)
}

func (h *HandlerV2) checkPrevAtx(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id, data := range atx.ids {
prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return false, fmt.Errorf("get last atx by node id: %w", err)

Check warning on line 805 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L805

Added line #L805 was not covered by tests
}
if prevID == data.previous {
continue
}

h.logger.Debug("atx references a wrong previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("actual", data.previous),
log.ZShortStringer("expected", prevID),
)

// TODO(mafa): finish proof
proof := &wire.ATXProof{
ProofType: wire.InvalidPrevious,
}
return true, h.malPublisher.Publish(ctx, id, proof)
}
return false, nil
}

// Store an ATX in the DB.
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []marriage,
units map[types.NodeID]uint32,
watx *activationTx,
) error {
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
if len(marrying) != 0 {
if len(watx.marriages) != 0 {
marriageData := identities.MarriageData{
ATX: atx.ID(),
Target: atx.SmesherID,
}
for i, m := range marrying {
for i, m := range watx.marriages {
marriageData.Signature = m.signature
marriageData.Index = i
if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil {
Expand All @@ -810,8 +848,8 @@ func (h *HandlerV2) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
for id, units := range units {
err = atxs.SetUnits(tx, atx.ID(), id, units)
for id, p := range watx.ids {
err = atxs.SetUnits(tx, atx.ID(), id, p.units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("setting atx units for ID %s: %w", id, err)
}
Expand All @@ -830,7 +868,7 @@ func (h *HandlerV2) storeAtx(
// TODO(mafa): don't store own ATX if it would mark the node as malicious
// this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in
// the gossip handler (not sync!)
err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units))
err := h.checkMalicious(ctx, tx, watx)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down
46 changes: 46 additions & 0 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1898,6 +1898,52 @@ func Test_CalculatingUnits(t *testing.T) {
})
}

func TestContextual_PreviousATX(t *testing.T) {
golden := types.RandomATXID()
atxHndlr := newV2TestHandler(t, golden)
var (
signers []*signing.EdSigner
eqSet []types.NodeID
)
for range 3 {
sig, err := signing.NewEdSigner()
require.NoError(t, err)
signers = append(signers, sig)
eqSet = append(eqSet, sig.NodeID())
}

mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden)

// signer 1 creates a solo ATX
soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID())
soloAtx.Sign(signers[1])
atxHndlr.expectAtxV2(soloAtx)
err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now())
require.NoError(t, err)

// create a MergedATX for all IDs
merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID())
post := wire.SubPostV2{
MarriageIndex: 1,
PrevATXIndex: 1,
NumUnits: soloAtx.TotalNumUnits(),
}
merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post)
// Pass a wrong previous ATX for signer 1. It's already been used for soloATX
// (which should be used for the previous ATX for signer 1).
merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID())
matxID := mATX.ID()
merged.MarriageATX = &matxID
merged.Sign(signers[0])

atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100})
atxHndlr.mMalPublish.EXPECT().Publish(gomock.Any(), signers[1].NodeID(), gomock.Cond(func(data any) bool {
return data.(*wire.ATXProof).ProofType == wire.InvalidPrevious
}))
err = atxHndlr.processATX(context.Background(), "", merged, time.Now())
require.NoError(t, err)
}

func Test_CalculatingWeight(t *testing.T) {
t.Parallel()
t.Run("total weight must not overflow uint64", func(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions activation/wire/malfeasance.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ const (
DoublePublish ProofType = iota + 1
DoubleMarry
DoubleMerge
InvalidPrevious
InvalidPost
)

Expand Down
11 changes: 6 additions & 5 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er
}

// PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch.
// It returns the newest ATX ID that was published before the given public epoch.
// It returns the newest ATX ID containing PoST of the given node ID
// that was published before the given public epoch.
func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) {
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, nodeID.Bytes())
Expand All @@ -260,10 +261,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID
}

if rows, err := db.Exec(`
select id from atxs
where pubkey = ?1 and epoch < ?2
order by epoch desc
limit 1;`, enc, dec); err != nil {
SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id
WHERE posts.pubkey = ?1 AND atxs.epoch < ?2
ORDER BY atxs.epoch DESC
LIMIT 1;`, enc, dec); err != nil {
return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err)
} else if rows == 0 {
return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound)
Expand Down
Loading

0 comments on commit 5c75d66

Please sign in to comment.