Skip to content

Commit

Permalink
previous atx malfeasance check for ATX V2
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Jul 30, 2024
1 parent 627de2d commit e11433c
Show file tree
Hide file tree
Showing 4 changed files with 192 additions and 46 deletions.
127 changes: 86 additions & 41 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,11 @@ func (h *HandlerV2) processATX(
return nil, fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, proof, err := h.syntacticallyValidateDeps(ctx, watx)
atxData, proof, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return nil, fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
atxData.marriages = marrying

if proof != nil {
return proof, err
Expand All @@ -125,9 +126,9 @@ func (h *HandlerV2) processATX(
MarriageATX: watx.MarriageATX,
Coinbase: watx.Coinbase,
BaseTickHeight: baseTickHeight,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
NumUnits: atxData.effectiveUnits,
TickCount: atxData.ticks,
Weight: atxData.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
}
Expand All @@ -145,7 +146,7 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

proof, err = h.storeAtx(ctx, atx, watx, marrying, parts.units)
proof, err = h.storeAtx(ctx, atx, atxData)
if err != nil {
return nil, fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}
Expand Down Expand Up @@ -439,11 +440,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX)
}

type atxParts struct {
type idData struct {
previous types.ATXID
units uint32
}

type activationTx struct {
*wire.ActivationTxV2
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
ids map[types.NodeID]idData
marriages []marriage
}

type nipostSize struct {
Expand Down Expand Up @@ -501,9 +509,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error
func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, *mwire.MalfeasanceProof, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
) (*activationTx, *mwire.MalfeasanceProof, error) {
result := activationTx{
ActivationTxV2: atx,
ids: make(map[types.NodeID]idData),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
Expand Down Expand Up @@ -591,17 +600,18 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
result.effectiveUnits, result.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, nil, err
}

// validate all niposts
var smesherCommitment *types.ATXID
for _, niposts := range atx.NiPosts {
for _, post := range niposts.Posts {
id := equivocationSet[post.MarriageIndex]
for _, p := range niposts.Posts {
id := equivocationSet[p.MarriageIndex]
var commitment types.ATXID
var previous types.ATXID
if atx.Initial != nil {
commitment = atx.Initial.CommitmentATX
} else {
Expand All @@ -613,15 +623,16 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if id == atx.SmesherID {
smesherCommitment = &commitment
}
previous = previousAtxs[p.PrevATXIndex].ID()
}

err := h.nipostValidator.PostV2(
ctx,
id,
commitment,
wire.PostFromWireV1(&post.Post),
wire.PostFromWireV1(&p.Post),
niposts.Challenge[:],
post.NumUnits,
p.NumUnits,
PostSubset([]byte(h.local)),
)
var invalidIdx *verifying.ErrInvalidIndex
Expand All @@ -636,7 +647,10 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err)
}
parts.units[id] = post.NumUnits
result.ids[id] = idData{
previous: previous,
units: p.NumUnits,
}
}
}

Expand All @@ -650,40 +664,44 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

parts.ticks = nipostSizes.minTicks()
result.ticks = nipostSizes.minTicks()

return &parts, nil, nil
return &result, nil, nil
}

func (h *HandlerV2) checkMalicious(
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []marriage,
) (bool, *mwire.MalfeasanceProof, error) {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
func (h *HandlerV2) checkMalicious(tx *sql.Tx, atx *activationTx) (bool, *mwire.MalfeasanceProof, error) {
malicious, err := identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return false, nil, fmt.Errorf("checking if node is malicious: %w", err)
}
if malicious {
return true, nil, nil
}

proof, err := h.checkDoubleMarry(tx, marrying)
proof, err := h.checkDoubleMarry(tx, atx.marriages)
if err != nil {
return false, nil, fmt.Errorf("checking double marry: %w", err)
}
if proof != nil {
return true, proof, nil
}

proof, err = h.checkDoubleMerge(tx, watx)
proof, err = h.checkDoubleMerge(tx, atx)
if err != nil {
return false, nil, fmt.Errorf("checking double merge: %w", err)
}
if proof != nil {
return true, proof, nil
}

proof, err = h.checkPrevAtx(tx, atx)
if err != nil {
return false, nil, fmt.Errorf("checking previous ATX: %w", err)

Check warning on line 699 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L699

Added line #L699 was not covered by tests
}
if proof != nil {
return true, proof, nil
}

// TODO: contextual validation:
// 1. check double-publish
// 2. check previous ATX
Expand Down Expand Up @@ -713,11 +731,11 @@ func (h *HandlerV2) checkDoubleMarry(tx *sql.Tx, marrying []marriage) (*mwire.Ma
return nil, nil
}

func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mwire.MalfeasanceProof, error) {
if watx.MarriageATX == nil {
func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, atx *activationTx) (*mwire.MalfeasanceProof, error) {
if atx.MarriageATX == nil {
return nil, nil
}
id, err := atxs.AtxWithMarriage(tx, *watx.MarriageATX, watx.PublishEpoch)
id, err := atxs.AtxWithMarriage(tx, *atx.MarriageATX, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
return nil, nil
Expand All @@ -726,10 +744,10 @@ func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mw
}

h.logger.Debug("second merged ATX for single marriage - creating malfeasance proof",
zap.Stringer("marriage_atx", *watx.MarriageATX),
zap.Stringer("atx", watx.ID()),
zap.Stringer("marriage_atx", *atx.MarriageATX),
zap.Stringer("atx", atx.ID()),
zap.Stringer("other atx", id),
zap.Stringer("smesher_id", watx.SmesherID),
zap.Stringer("smesher_id", atx.SmesherID),
)

// FIXME: implement the proof
Expand All @@ -742,32 +760,59 @@ func (h *HandlerV2) checkDoubleMerge(tx *sql.Tx, watx *wire.ActivationTxV2) (*mw
return proof, nil
}

func (h *HandlerV2) checkPrevAtx(tx *sql.Tx, atx *activationTx) (*mwire.MalfeasanceProof, error) {
for id, data := range atx.ids {
prevID, err := atxs.PrevIDByNodeID(tx, id, atx.PublishEpoch)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
return nil, fmt.Errorf("get last atx by node id: %w", err)

Check warning on line 767 in activation/handler_v2.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v2.go#L767

Added line #L767 was not covered by tests
}
if prevID == data.previous {
continue
}

h.logger.Debug("atx references a wrong previous ATX",
log.ZShortStringer("smesherID", id),
log.ZShortStringer("actual", data.previous),
log.ZShortStringer("expected", prevID),
)

// FIXME: implement the proof
proof := &mwire.MalfeasanceProof{
Proof: mwire.Proof{
Type: mwire.DoubleMarry,
Data: &mwire.DoubleMarryProof{},
},
}
return proof, nil

}
return nil, nil
}

// Store an ATX in the DB.
// TODO: detect malfeasance and create proofs.
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []marriage,
units map[types.NodeID]uint32,
watx *activationTx,
) (*mwire.MalfeasanceProof, error) {
var (
malicious bool
proof *mwire.MalfeasanceProof
)
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
var err error
malicious, proof, err = h.checkMalicious(tx, watx, marrying)
malicious, proof, err = h.checkMalicious(tx, watx)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}

if len(marrying) != 0 {
if len(watx.marriages) != 0 {
marriageData := identities.MarriageData{
ATX: atx.ID(),
Target: atx.SmesherID,
}
for i, m := range marrying {
for i, m := range watx.marriages {
marriageData.Signature = m.signature
marriageData.Index = i
if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil {
Expand All @@ -787,8 +832,8 @@ func (h *HandlerV2) storeAtx(
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
for id, units := range units {
err = atxs.SetUnits(tx, atx.ID(), id, units)
for id, p := range watx.ids {
err = atxs.SetUnits(tx, atx.ID(), id, p.units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("setting atx units for ID %s: %w", id, err)
}
Expand All @@ -812,7 +857,7 @@ func (h *HandlerV2) storeAtx(
for _, id := range set {
allMalicious[id] = struct{}{}
}
for _, m := range marrying {
for _, m := range watx.marriages {
allMalicious[m.id] = struct{}{}
}
}
Expand Down
47 changes: 47 additions & 0 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1865,6 +1865,53 @@ func Test_CalculatingUnits(t *testing.T) {
})
}

func TestContextual_PreviousATX(t *testing.T) {
golden := types.RandomATXID()
atxHndlr := newV2TestHandler(t, golden)
var (
signers []*signing.EdSigner
eqSet []types.NodeID
)
for range 3 {
sig, err := signing.NewEdSigner()
require.NoError(t, err)
signers = append(signers, sig)
eqSet = append(eqSet, sig.NodeID())
}

mATX, otherAtxs := marryIDs(t, atxHndlr, signers, golden)

// signer 1 creates a solo ATX
soloAtx := newSoloATXv2(t, mATX.PublishEpoch+1, otherAtxs[0].ID(), mATX.ID())
soloAtx.Sign(signers[1])
atxHndlr.expectAtxV2(soloAtx)
_, err := atxHndlr.processATX(context.Background(), "", soloAtx, time.Now())
require.NoError(t, err)

// create a MergedATX for all IDs
merged := newSoloATXv2(t, mATX.PublishEpoch+2, mATX.ID(), mATX.ID())
post := wire.SubPostV2{
MarriageIndex: 1,
PrevATXIndex: 1,
NumUnits: soloAtx.TotalNumUnits(),
}
merged.NiPosts[0].Posts = append(merged.NiPosts[0].Posts, post)
// Pass a wrong previous ATX for signer 1. It's already been used for soloATX
// (which should be used for the previous ATX for signer 1).
merged.PreviousATXs = append(merged.PreviousATXs, otherAtxs[0].ID())
matxID := mATX.ID()
merged.MarriageATX = &matxID
merged.Sign(signers[0])

atxHndlr.expectMergedAtxV2(merged, eqSet, []uint64{100})
for _, sig := range signers {
atxHndlr.mtortoise.EXPECT().OnMalfeasance(sig.NodeID())
}
p, err := atxHndlr.processATX(context.Background(), "", merged, time.Now())
require.NoError(t, err)
require.NotNil(t, p)
}

func Test_CalculatingWeight(t *testing.T) {
t.Parallel()
t.Run("total weight must not overflow uint64", func(t *testing.T) {
Expand Down
11 changes: 6 additions & 5 deletions sql/atxs/atxs.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ func GetLastIDByNodeID(db sql.Executor, nodeID types.NodeID) (id types.ATXID, er
}

// PrevIDByNodeID returns the previous ATX ID for a given node ID and public epoch.
// It returns the newest ATX ID that was published before the given public epoch.
// It returns the newest ATX ID containing PoST of the given node ID
// that was published before the given public epoch.
func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID) (id types.ATXID, err error) {
enc := func(stmt *sql.Statement) {
stmt.BindBytes(1, nodeID.Bytes())
Expand All @@ -260,10 +261,10 @@ func PrevIDByNodeID(db sql.Executor, nodeID types.NodeID, pubEpoch types.EpochID
}

if rows, err := db.Exec(`
select id from atxs
where pubkey = ?1 and epoch < ?2
order by epoch desc
limit 1;`, enc, dec); err != nil {
SELECT posts.atxid FROM posts JOIN atxs ON posts.atxid = atxs.id
WHERE posts.pubkey = ?1 AND atxs.epoch < ?2
ORDER BY atxs.epoch DESC
LIMIT 1;`, enc, dec); err != nil {
return types.EmptyATXID, fmt.Errorf("exec nodeID %v, epoch %d: %w", nodeID, pubEpoch, err)
} else if rows == 0 {
return types.EmptyATXID, fmt.Errorf("exec nodeID %s, epoch %d: %w", nodeID, pubEpoch, sql.ErrNotFound)
Expand Down
Loading

0 comments on commit e11433c

Please sign in to comment.