Skip to content

Commit

Permalink
record previous ATXs in posts table
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Aug 8, 2024
1 parent 6ff9fe5 commit 9a834fa
Show file tree
Hide file tree
Showing 13 changed files with 138 additions and 136 deletions.
3 changes: 2 additions & 1 deletion activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ func publishAtxV1(
func(_ context.Context, _ string, got []byte) error {
return codec.Decode(got, &watx)
})
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob(), watx.PrevATXID))
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob()))
require.NoError(tb, atxs.SetPost(tab.db, watx.ID(), watx.PrevATXID, watx.SmesherID, watx.NumUnits))
tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false)
return &watx
}
Expand Down
4 changes: 2 additions & 2 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,11 +520,11 @@ func (h *HandlerV1) storeAtx(
return fmt.Errorf("check malicious: %w", err)
}

err = atxs.Add(tx, atx, watx.Blob(), watx.PrevATXID)
err = atxs.Add(tx, atx, watx.Blob())
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
err = atxs.SetUnits(tx, atx.ID(), atx.SmesherID, watx.NumUnits)
err = atxs.SetPost(tx, atx.ID(), watx.PrevATXID, atx.SmesherID, watx.NumUnits)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("set atx units: %w", err)
}
Expand Down
96 changes: 46 additions & 50 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,19 +111,20 @@ func (h *HandlerV2) processATX(
return fmt.Errorf("%w: validating marriages: %w", pubsub.ErrValidationReject, err)
}

parts, err := h.syntacticallyValidateDeps(ctx, watx)
atxData, err := h.syntacticallyValidateDeps(ctx, watx)
if err != nil {
return fmt.Errorf("%w: validating atx %s (deps): %w", pubsub.ErrValidationReject, watx.ID(), err)
}
atxData.marriages = marrying

atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
MarriageATX: watx.MarriageATX,
Coinbase: watx.Coinbase,
BaseTickHeight: baseTickHeight,
NumUnits: parts.effectiveUnits,
TickCount: parts.ticks,
Weight: parts.weight,
NumUnits: atxData.effectiveUnits,
TickCount: atxData.ticks,
Weight: atxData.weight,
VRFNonce: types.VRFPostIndex(watx.VRFNonce),
SmesherID: watx.SmesherID,
}
Expand All @@ -138,7 +139,7 @@ func (h *HandlerV2) processATX(
atx.SetID(watx.ID())
atx.SetReceived(received)

if err := h.storeAtx(ctx, atx, watx, marrying, parts.units); err != nil {
if err := h.storeAtx(ctx, atx, atxData); err != nil {
return fmt.Errorf("cannot store atx %s: %w", atx.ShortString(), err)
}

Expand Down Expand Up @@ -431,11 +432,18 @@ func (h *HandlerV2) equivocationSet(atx *wire.ActivationTxV2) ([]types.NodeID, e
return identities.EquivocationSetByMarriageATX(h.cdb, *atx.MarriageATX)
}

type atxParts struct {
type idData struct {
previous types.ATXID
units uint32
}

type activationTx struct {
*wire.ActivationTxV2
ticks uint64
weight uint64
effectiveUnits uint32
units map[types.NodeID]uint32
ids map[types.NodeID]idData
marriages []marriage
}

type nipostSize struct {
Expand Down Expand Up @@ -493,9 +501,10 @@ func (h *HandlerV2) verifyIncludedIDsUniqueness(atx *wire.ActivationTxV2) error
func (h *HandlerV2) syntacticallyValidateDeps(
ctx context.Context,
atx *wire.ActivationTxV2,
) (*atxParts, error) {
parts := atxParts{
units: make(map[types.NodeID]uint32),
) (*activationTx, error) {
result := activationTx{
ActivationTxV2: atx,
ids: make(map[types.NodeID]idData),
}
if atx.Initial != nil {
if err := h.validateCommitmentAtx(h.goldenATXID, atx.Initial.CommitmentATX, atx.PublishEpoch); err != nil {
Expand Down Expand Up @@ -583,7 +592,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
nipostSizes[i].ticks = leaves / h.tickSize
}

parts.effectiveUnits, parts.weight, err = nipostSizes.sumUp()
result.effectiveUnits, result.weight, err = nipostSizes.sumUp()
if err != nil {
return nil, err
}
Expand All @@ -594,6 +603,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
for _, post := range niposts.Posts {
id := equivocationSet[post.MarriageIndex]
var commitment types.ATXID
var previous types.ATXID
if atx.Initial != nil {
commitment = atx.Initial.CommitmentATX
} else {
Expand All @@ -605,6 +615,7 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if id == atx.SmesherID {
smesherCommitment = &commitment
}
previous = previousAtxs[post.PrevATXIndex].ID()
}

err := h.nipostValidator.PostV2(
Expand Down Expand Up @@ -632,7 +643,10 @@ func (h *HandlerV2) syntacticallyValidateDeps(
if err != nil {
return nil, fmt.Errorf("validating post for ID %s: %w", id.ShortString(), err)
}
parts.units[id] = post.NumUnits
result.ids[id] = idData{
previous: previous,
units: post.NumUnits,
}
}
}

Expand All @@ -646,42 +660,36 @@ func (h *HandlerV2) syntacticallyValidateDeps(
}
}

parts.ticks = nipostSizes.minTicks()
return &parts, nil
result.ticks = nipostSizes.minTicks()
return &result, nil
}

func (h *HandlerV2) checkMalicious(
ctx context.Context,
tx *sql.Tx,
watx *wire.ActivationTxV2,
marrying []marriage,
ids []types.NodeID,
) error {
malicious, err := identities.IsMalicious(tx, watx.SmesherID)
func (h *HandlerV2) checkMalicious(ctx context.Context, tx *sql.Tx, atx *activationTx) error {
malicious, err := identities.IsMalicious(tx, atx.SmesherID)
if err != nil {
return fmt.Errorf("checking if node is malicious: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMarry(ctx, tx, watx, marrying)
malicious, err = h.checkDoubleMarry(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double marry: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoublePost(ctx, tx, watx, ids)
malicious, err = h.checkDoublePost(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double post: %w", err)
}
if malicious {
return nil
}

malicious, err = h.checkDoubleMerge(ctx, tx, watx)
malicious, err = h.checkDoubleMerge(ctx, tx, atx)
if err != nil {
return fmt.Errorf("checking double merge: %w", err)
}
Expand All @@ -697,13 +705,8 @@ func (h *HandlerV2) checkMalicious(
return nil
}

func (h *HandlerV2) checkDoubleMarry(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
marrying []marriage,
) (bool, error) {
for _, m := range marrying {
func (h *HandlerV2) checkDoubleMarry(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for _, m := range atx.marriages {
mATX, err := identities.MarriageATX(tx, m.id)
if err != nil {
return false, fmt.Errorf("checking if ID is married: %w", err)
Expand All @@ -722,7 +725,7 @@ func (h *HandlerV2) checkDoubleMarry(
var otherAtx wire.ActivationTxV2
codec.MustDecode(blob.Bytes, &otherAtx)

proof, err := wire.NewDoubleMarryProof(tx, atx, &otherAtx, m.id)
proof, err := wire.NewDoubleMarryProof(tx, atx.ActivationTxV2, &otherAtx, m.id)
if err != nil {
return true, fmt.Errorf("creating double marry proof: %w", err)
}
Expand All @@ -732,13 +735,8 @@ func (h *HandlerV2) checkDoubleMarry(
return false, nil
}

func (h *HandlerV2) checkDoublePost(
ctx context.Context,
tx *sql.Tx,
atx *wire.ActivationTxV2,
ids []types.NodeID,
) (bool, error) {
for _, id := range ids {
func (h *HandlerV2) checkDoublePost(ctx context.Context, tx *sql.Tx, atx *activationTx) (bool, error) {
for id := range atx.ids {
atxids, err := atxs.FindDoublePublish(tx, id, atx.PublishEpoch)
switch {
case errors.Is(err, sql.ErrNotFound):
Expand All @@ -762,7 +760,7 @@ func (h *HandlerV2) checkDoublePost(
return false, nil
}

func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire.ActivationTxV2) (bool, error) {
func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *activationTx) (bool, error) {
if watx.MarriageATX == nil {
return false, nil
}
Expand Down Expand Up @@ -791,17 +789,15 @@ func (h *HandlerV2) checkDoubleMerge(ctx context.Context, tx *sql.Tx, watx *wire
func (h *HandlerV2) storeAtx(
ctx context.Context,
atx *types.ActivationTx,
watx *wire.ActivationTxV2,
marrying []marriage,
units map[types.NodeID]uint32,
watx *activationTx,
) error {
if err := h.cdb.WithTx(ctx, func(tx *sql.Tx) error {
if len(marrying) != 0 {
if len(watx.marriages) != 0 {
marriageData := identities.MarriageData{
ATX: atx.ID(),
Target: atx.SmesherID,
}
for i, m := range marrying {
for i, m := range watx.marriages {
marriageData.Signature = m.signature
marriageData.Index = i
if err := identities.SetMarriage(tx, m.id, &marriageData); err != nil {
Expand All @@ -810,12 +806,12 @@ func (h *HandlerV2) storeAtx(
}
}

err := atxs.Add(tx, atx, watx.Blob(), watx.PreviousATXs...)
err := atxs.Add(tx, atx, watx.Blob())
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
for id, units := range units {
err = atxs.SetUnits(tx, atx.ID(), id, units)
for id, post := range watx.ids {
err = atxs.SetPost(tx, atx.ID(), post.previous, id, post.units)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("setting atx units for ID %s: %w", id, err)
}
Expand All @@ -834,7 +830,7 @@ func (h *HandlerV2) storeAtx(
// TODO(mafa): don't store own ATX if it would mark the node as malicious
// this probably needs to be done by validating and storing own ATXs eagerly and skipping validation in
// the gossip handler (not sync!)
err := h.checkMalicious(ctx, tx, watx, marrying, maps.Keys(units))
err := h.checkMalicious(ctx, tx, watx)
if err != nil {
return fmt.Errorf("check malicious: %w", err)
}
Expand Down
8 changes: 4 additions & 4 deletions activation/handler_v2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
t.Parallel()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), types.RandomNodeID(), 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, types.RandomNodeID(), 13))

_, err := atxHandler.validatePreviousAtx(types.RandomNodeID(), &wire.SubPostV2{}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand All @@ -1446,8 +1446,8 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), id, 7))
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, id, 7))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, other, 13))

units, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.NoError(t, err)
Expand All @@ -1467,7 +1467,7 @@ func Test_ValidatePreviousATX(t *testing.T) {
other := types.RandomNodeID()
prev := &types.ActivationTx{}
prev.SetID(types.RandomATXID())
require.NoError(t, atxs.SetUnits(atxHandler.cdb, prev.ID(), other, 13))
require.NoError(t, atxs.SetPost(atxHandler.cdb, prev.ID(), types.EmptyATXID, other, 13))

_, err := atxHandler.validatePreviousAtx(id, &wire.SubPostV2{NumUnits: 100}, []*types.ActivationTx{prev})
require.Error(t, err)
Expand Down
6 changes: 3 additions & 3 deletions activation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func TestVerifyChainDeps(t *testing.T) {
atx := newChainedActivationTxV1(t, invalidAtx, goldenATXID)
atx.Sign(signer)
vAtx := toAtx(t, atx)
require.NoError(t, atxs.Add(db, vAtx, atx.Blob(), invalidAtx.ID()))
require.NoError(t, atxs.Add(db, vAtx, atx.Blob()))

ctrl := gomock.NewController(t)
v := NewMockPostVerifier(ctrl)
Expand Down Expand Up @@ -606,7 +606,7 @@ func TestVerifyChainDeps(t *testing.T) {
SmesherID: watx.SmesherID,
}
atx.SetID(watx.ID())
require.NoError(t, atxs.Add(db, atx, watx.Blob(), initialAtx.ID()))
require.NoError(t, atxs.Add(db, atx, watx.Blob()))

v := NewMockPostVerifier(gomock.NewController(t))
expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post))
Expand Down Expand Up @@ -646,7 +646,7 @@ func TestVerifyChainDeps(t *testing.T) {
SmesherID: watx.SmesherID,
}
atx.SetID(watx.ID())
require.NoError(t, atxs.Add(db, atx, watx.Blob(), initialAtx.ID()))
require.NoError(t, atxs.Add(db, atx, watx.Blob()))

v := NewMockPostVerifier(gomock.NewController(t))
expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post))
Expand Down
9 changes: 6 additions & 3 deletions activation/verify_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,17 @@ func Test_CheckPrevATXs(t *testing.T) {
})
atx1.Sign(sig)
vAtx1 := toAtx(t, atx1)
require.NoError(t, atxs.Add(db, vAtx1, atx1.Blob(), atx1.PrevATXID))
require.NoError(t, atxs.Add(db, vAtx1, atx1.Blob()))
require.NoError(t, atxs.SetPost(db, atx1.ID(), atx1.PrevATXID, sig.NodeID(), 1))

atx2 := newInitialATXv1(t, goldenATXID, func(atx *wire.ActivationTxV1) {
atx.PrevATXID = prevATXID
atx.PublishEpoch = 3
})
atx2.Sign(sig)
vAtx2 := toAtx(t, atx2)
require.NoError(t, atxs.Add(db, vAtx2, atx2.Blob(), atx2.PrevATXID))
require.NoError(t, atxs.Add(db, vAtx2, atx2.Blob()))
require.NoError(t, atxs.SetPost(db, atx2.ID(), atx2.PrevATXID, sig.NodeID(), 1))

// create 100 random ATXs that are not malicious
for i := 0; i < 100; i++ {
Expand All @@ -55,7 +57,8 @@ func Test_CheckPrevATXs(t *testing.T) {
})
atx.Sign(otherSig)
vAtx := toAtx(t, atx)
require.NoError(t, atxs.Add(db, vAtx, atx.Blob(), atx.PrevATXID))
require.NoError(t, atxs.Add(db, vAtx, atx.Blob()))
require.NoError(t, atxs.SetPost(db, atx.ID(), atx.PrevATXID, otherSig.NodeID(), 1))
}

// Act
Expand Down
2 changes: 1 addition & 1 deletion api/grpcserver/admin_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func newAtx(tb testing.TB, db *sql.Database) {
atx.SmesherID = types.BytesToNodeID(types.RandomBytes(20))
atx.SetReceived(time.Now().Local())
require.NoError(tb, atxs.Add(db, atx, types.AtxBlob{}))
require.NoError(tb, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits))
require.NoError(tb, atxs.SetPost(db, atx.ID(), types.EmptyATXID, atx.SmesherID, atx.NumUnits))
}

func createMesh(tb testing.TB, db *sql.Database) {
Expand Down
6 changes: 3 additions & 3 deletions checkpoint/runner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ func createMesh(t testing.TB, db *sql.Database, miners []miner, accts []*types.A
t.Helper()
for _, miner := range miners {
for _, atx := range miner.atxs {
require.NoError(t, atxs.Add(db, atx.ActivationTx, types.AtxBlob{}, atx.previous))
require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits))
require.NoError(t, atxs.Add(db, atx.ActivationTx, types.AtxBlob{}))
require.NoError(t, atxs.SetPost(db, atx.ID(), atx.previous, atx.SmesherID, atx.NumUnits))
}
if proof := miner.malfeasanceProof; len(proof) > 0 {
require.NoError(t, identities.SetMalicious(db, miner.atxs[0].SmesherID, proof, time.Now()))
Expand Down Expand Up @@ -399,7 +399,7 @@ func TestRunner_Generate_PreservesMarriageATX(t *testing.T) {
}
atx.SetID(types.RandomATXID())
require.NoError(t, atxs.Add(db, atx, types.AtxBlob{}))
require.NoError(t, atxs.SetUnits(db, atx.ID(), atx.SmesherID, atx.NumUnits))
require.NoError(t, atxs.SetPost(db, atx.ID(), types.EmptyATXID, atx.SmesherID, atx.NumUnits))

fs := afero.NewMemMapFs()
dir, err := afero.TempDir(fs, "", "Generate")
Expand Down
Loading

0 comments on commit 9a834fa

Please sign in to comment.