Skip to content

Commit

Permalink
Support multiple previous ATXs
Browse files Browse the repository at this point in the history
  • Loading branch information
poszu committed Jul 18, 2024
1 parent f9f8a53 commit 8211add
Show file tree
Hide file tree
Showing 24 changed files with 308 additions and 144 deletions.
2 changes: 1 addition & 1 deletion activation/activation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func publishAtxV1(
func(_ context.Context, _ string, got []byte) error {
return codec.Decode(got, &watx)
})
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob()))
require.NoError(tb, atxs.Add(tab.db, toAtx(tb, &watx), watx.Blob(), watx.PrevATXID))
tab.atxsdata.AddFromAtx(toAtx(tb, &watx), false)
return &watx
}
Expand Down
1 change: 0 additions & 1 deletion activation/e2e/atx_merge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,5 @@ func Test_MarryAndMerge(t *testing.T) {
require.Equal(t, units[i], atxFromDb.NumUnits)
require.Equal(t, signer.NodeID(), atxFromDb.SmesherID)
require.Equal(t, publish, atxFromDb.PublishEpoch)
require.Equal(t, mergedATX2.ID(), atxFromDb.PrevATXID)
}
}
28 changes: 18 additions & 10 deletions activation/handler_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,24 @@ func (h *HandlerV1) checkWrongPrevAtx(
return nil, fmt.Errorf("get prev atx id by node id: %w", err)
}

atx2, err := atxs.Get(tx, id)
if err != nil {
return nil, fmt.Errorf("get prev atx: %w", err)
}
if atx.ID() != atx2.ID() && atx.PrevATXID == atx2.PrevATXID {
// found an ATX that points to the same previous ATX
atx2ID = id
break
if atx.ID() != id {
prev, err := atxs.Previous(tx, id)
if err != nil {
return nil, fmt.Errorf("get prev atx: %w", err)

Check warning on line 465 in activation/handler_v1.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v1.go#L465

Added line #L465 was not covered by tests
}
if (atx.PrevATXID == types.EmptyATXID && len(prev) == 0) || atx.PrevATXID == prev[0] {
// found an ATX that points to the same previous ATX
atx2ID = id
break
}
atx2, err := atxs.Get(tx, id)
if err != nil {
return nil, fmt.Errorf("get atx: %w", err)

Check warning on line 474 in activation/handler_v1.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v1.go#L474

Added line #L474 was not covered by tests
}
pubEpoch = atx2.PublishEpoch
} else {
pubEpoch = atx.PublishEpoch

Check warning on line 478 in activation/handler_v1.go

View check run for this annotation

Codecov / codecov/patch

activation/handler_v1.go#L477-L478

Added lines #L477 - L478 were not covered by tests
}
pubEpoch = atx2.PublishEpoch
}
}

Expand Down Expand Up @@ -551,7 +559,7 @@ func (h *HandlerV1) storeAtx(
return fmt.Errorf("check malicious: %w", err)
}

err = atxs.Add(tx, atx, watx.Blob())
err = atxs.Add(tx, atx, watx.Blob(), watx.PrevATXID)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
Expand Down
7 changes: 2 additions & 5 deletions activation/handler_v2.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,7 @@ func (h *HandlerV2) processATX(
SmesherID: watx.SmesherID,
}

if watx.Initial == nil {
// FIXME: update to keep many previous ATXs to support merged ATXs
atx.PrevATXID = watx.PreviousATXs[0]
} else {
if watx.Initial != nil {
atx.CommitmentATX = &watx.Initial.CommitmentATX
}

Expand Down Expand Up @@ -745,7 +742,7 @@ func (h *HandlerV2) storeAtx(
}
}

err = atxs.Add(tx, atx, watx.Blob())
err = atxs.Add(tx, atx, watx.Blob(), watx.PreviousATXs...)
if err != nil && !errors.Is(err, sql.ErrObjectExists) {
return fmt.Errorf("add atx to db: %w", err)
}
Expand Down
78 changes: 40 additions & 38 deletions activation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,9 @@ func (v *Validator) VerifyChain(ctx context.Context, id, goldenATXID types.ATXID
}

type atxDeps struct {
nipost types.NIPost
niposts []types.NIPost
positioning types.ATXID
previous types.ATXID
previous []types.ATXID
commitment types.ATXID
}

Expand Down Expand Up @@ -455,14 +455,16 @@ func (v *Validator) getAtxDeps(ctx context.Context, id types.ATXID) (*atxDeps, e
}

deps := &atxDeps{
nipost: *wire.NiPostFromWireV1(atx.NIPost),
niposts: []types.NIPost{*wire.NiPostFromWireV1(atx.NIPost)},
positioning: atx.PositioningATXID,
previous: atx.PrevATXID,
commitment: commitment,
}
if atx.PrevATXID != types.EmptyATXID {
deps.previous = []types.ATXID{atx.PrevATXID}
}

return deps, nil
case types.AtxV2:
// TODO: support merged ATXs
var atx wire.ActivationTxV2
if err := codec.Decode(blob.Bytes, &atx); err != nil {
return nil, fmt.Errorf("decoding ATX blob: %w", err)
Expand All @@ -478,23 +480,23 @@ func (v *Validator) getAtxDeps(ctx context.Context, id types.ATXID) (*atxDeps, e
}
commitment = catx
}
var previous types.ATXID
if len(atx.PreviousATXs) != 0 {
previous = atx.PreviousATXs[0]
}

deps := &atxDeps{
nipost: types.NIPost{
Post: wire.PostFromWireV1(&atx.NiPosts[0].Posts[0].Post),
PostMetadata: &types.PostMetadata{
Challenge: atx.NiPosts[0].Challenge[:],
LabelsPerUnit: v.cfg.LabelsPerUnit,
},
},
positioning: atx.PositioningATX,
previous: previous,
previous: atx.PreviousATXs,
commitment: commitment,
}
for _, nipost := range atx.NiPosts {
for _, post := range nipost.Posts {
deps.niposts = append(deps.niposts, types.NIPost{
Post: wire.PostFromWireV1(&post.Post),
PostMetadata: &types.PostMetadata{
Challenge: nipost.Challenge[:],
LabelsPerUnit: v.cfg.LabelsPerUnit,
},
})
}
}
return deps, nil
}

Expand All @@ -511,12 +513,11 @@ func (v *Validator) verifyChainWithOpts(
if err != nil {
return fmt.Errorf("get atx: %w", err)
}
if atx.Golden() {
log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "golden"))
return nil
}

switch {
case atx.Golden():
log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "golden"))
return nil
case atx.Validity() == types.Valid:
log.Debug("not verifying ATX chain", zap.Stringer("atx_id", id), zap.String("reason", "already verified"))
return nil
Expand All @@ -542,20 +543,21 @@ func (v *Validator) verifyChainWithOpts(
if err != nil {
return fmt.Errorf("getting ATX dependencies: %w", err)
}

if err := v.Post(
ctx,
atx.SmesherID,
deps.commitment,
deps.nipost.Post,
deps.nipost.PostMetadata,
atx.NumUnits,
[]validatorOption{PrioritizeCall()}...,
); err != nil {
if err := atxs.SetValidity(v.db, id, types.Invalid); err != nil {
log.Warn("failed to persist atx validity", zap.Error(err), zap.Stringer("atx_id", id))
for _, nipost := range deps.niposts {
if err := v.Post(
ctx,
atx.SmesherID,
deps.commitment,
nipost.Post,
nipost.PostMetadata,
atx.NumUnits,
[]validatorOption{PrioritizeCall()}...,
); err != nil {
if err := atxs.SetValidity(v.db, id, types.Invalid); err != nil {
log.Warn("failed to persist atx validity", zap.Error(err), zap.Stringer("atx_id", id))

Check warning on line 557 in activation/validation.go

View check run for this annotation

Codecov / codecov/patch

activation/validation.go#L557

Added line #L557 was not covered by tests
}
return &InvalidChainError{ID: id, src: err}
}
return &InvalidChainError{ID: id, src: err}
}

err = v.verifyChainDeps(ctx, deps, goldenATXID, opts)
Expand All @@ -579,9 +581,9 @@ func (v *Validator) verifyChainDeps(
goldenATXID types.ATXID,
opts verifyChainOpts,
) error {
if deps.previous != types.EmptyATXID {
if err := v.verifyChainWithOpts(ctx, deps.previous, goldenATXID, opts); err != nil {
return fmt.Errorf("validating previous ATX %s chain: %w", deps.previous.ShortString(), err)
for _, prev := range deps.previous {
if err := v.verifyChainWithOpts(ctx, prev, goldenATXID, opts); err != nil {
return fmt.Errorf("validating previous ATX %s chain: %w", prev.ShortString(), err)
}
}
if deps.positioning != goldenATXID {
Expand All @@ -591,7 +593,7 @@ func (v *Validator) verifyChainDeps(
}
// verify commitment only if arrived at the first ATX in the chain
// to avoid verifying the same commitment ATX multiple times.
if deps.previous == types.EmptyATXID && deps.commitment != goldenATXID {
if len(deps.previous) == 0 && deps.commitment != goldenATXID {
if err := v.verifyChainWithOpts(ctx, deps.commitment, goldenATXID, opts); err != nil {
return fmt.Errorf("validating commitment ATX %s chain: %w", deps.commitment.ShortString(), err)
}
Expand Down
47 changes: 45 additions & 2 deletions activation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ func TestVerifyChainDeps(t *testing.T) {
atx := newChainedActivationTxV1(t, invalidAtx, goldenATXID)
atx.Sign(signer)
vAtx := toAtx(t, atx)
require.NoError(t, atxs.Add(db, vAtx, atx.Blob()))
require.NoError(t, atxs.Add(db, vAtx, atx.Blob(), invalidAtx.ID()))

ctrl := gomock.NewController(t)
v := NewMockPostVerifier(ctrl)
Expand Down Expand Up @@ -606,12 +606,55 @@ func TestVerifyChainDeps(t *testing.T) {
SmesherID: watx.SmesherID,
}
atx.SetID(watx.ID())
require.NoError(t, atxs.Add(db, atx, watx.Blob()))
require.NoError(t, atxs.Add(db, atx, watx.Blob(), initialAtx.ID()))

v := NewMockPostVerifier(gomock.NewController(t))
expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post))
v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx.NIPost.Post), gomock.Any(), gomock.Any())
v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any())
validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v)
err = validator.VerifyChain(ctx, watx.ID(), goldenATXID)
require.NoError(t, err)
})
t.Run("merged ATX", func(t *testing.T) {
initialAtx := newInitialATXv1(t, goldenATXID)
initialAtx.Sign(signer)
require.NoError(t, atxs.Add(db, toAtx(t, initialAtx), initialAtx.Blob()))

// second ID for the merged ATX
otherSig, err := signing.NewEdSigner()
require.NoError(t, err)
initialAtx2 := newInitialATXv1(t, goldenATXID)
initialAtx2.Sign(otherSig)
require.NoError(t, atxs.Add(db, toAtx(t, initialAtx2), initialAtx2.Blob()))

watx := newSoloATXv2(t, initialAtx.PublishEpoch+1, initialAtx.ID(), initialAtx.ID())
watx.NiPosts[0].Posts = append(watx.NiPosts[0].Posts, wire.SubPostV2{
MarriageIndex: 1,
PrevATXIndex: 1,
Post: wire.PostV1{
Nonce: 99,
Pow: 55,
Indices: types.RandomBytes(33),
},
NumUnits: 77,
})
watx.PreviousATXs = append(watx.PreviousATXs, initialAtx2.ID())
watx.Sign(signer)
atx := &types.ActivationTx{
PublishEpoch: watx.PublishEpoch,
SmesherID: watx.SmesherID,
}
atx.SetID(watx.ID())
require.NoError(t, atxs.Add(db, atx, watx.Blob(), initialAtx.ID()))

v := NewMockPostVerifier(gomock.NewController(t))
expectedPost := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[0].Post))
expectedPost2 := (*shared.Proof)(wire.PostFromWireV1(&watx.NiPosts[0].Posts[1].Post))
v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx.NIPost.Post), gomock.Any(), gomock.Any())
v.EXPECT().Verify(ctx, (*shared.Proof)(initialAtx2.NIPost.Post), gomock.Any(), gomock.Any())
v.EXPECT().Verify(ctx, expectedPost, gomock.Any(), gomock.Any())
v.EXPECT().Verify(ctx, expectedPost2, gomock.Any(), gomock.Any())
validator := NewValidator(db, nil, DefaultPostConfig(), config.ScryptParams{}, v)
err = validator.VerifyChain(ctx, watx.ID(), goldenATXID)
require.NoError(t, err)
Expand Down
6 changes: 3 additions & 3 deletions activation/verify_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ func Test_CheckPrevATXs(t *testing.T) {
})
atx1.Sign(sig)
vAtx1 := toAtx(t, atx1)
require.NoError(t, atxs.Add(db, vAtx1, atx1.Blob()))
require.NoError(t, atxs.Add(db, vAtx1, atx1.Blob(), atx1.PrevATXID))

atx2 := newInitialATXv1(t, goldenATXID, func(atx *wire.ActivationTxV1) {
atx.PrevATXID = prevATXID
atx.PublishEpoch = 3
})
atx2.Sign(sig)
vAtx2 := toAtx(t, atx2)
require.NoError(t, atxs.Add(db, vAtx2, atx2.Blob()))
require.NoError(t, atxs.Add(db, vAtx2, atx2.Blob(), atx2.PrevATXID))

// create 100 random ATXs that are not malicious
for i := 0; i < 100; i++ {
Expand All @@ -55,7 +55,7 @@ func Test_CheckPrevATXs(t *testing.T) {
})
atx.Sign(otherSig)
vAtx := toAtx(t, atx)
require.NoError(t, atxs.Add(db, vAtx, atx.Blob()))
require.NoError(t, atxs.Add(db, vAtx, atx.Blob(), atx.PrevATXID))
}

// Act
Expand Down
1 change: 0 additions & 1 deletion activation/wire/wire_v1.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ func ActivationTxFromWireV1(atx *ActivationTxV1) *types.ActivationTx {
result := &types.ActivationTx{
PublishEpoch: atx.PublishEpoch,
Sequence: atx.Sequence,
PrevATXID: atx.PrevATXID,
CommitmentATX: atx.CommitmentATXID,
Coinbase: atx.Coinbase,
NumUnits: atx.NumUnits,
Expand Down
22 changes: 20 additions & 2 deletions api/grpcserver/activation_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,15 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p
)
return nil, status.Error(codes.NotFound, "id was not found")
}
prev, err := s.atxProvider.Previous(atxId)
if err != nil {
ctxzap.Error(ctx, "failed to get previous ATX",
zap.Stringer("id", atxId),
zap.Error(err),
)
return nil, status.Error(codes.Internal, "couldn't get previous ATXs")
}

proof, err := s.atxProvider.GetMalfeasanceProof(atx.SmesherID)
if err != nil && !errors.Is(err, sql.ErrNotFound) {
ctxzap.Error(ctx, "failed to get malfeasance proof",
Expand All @@ -74,7 +83,7 @@ func (s *activationService) Get(ctx context.Context, request *pb.GetRequest) (*p
return nil, status.Error(codes.NotFound, "id was not found")
}
resp := &pb.GetResponse{
Atx: convertActivation(atx),
Atx: convertActivation(atx, prev),
}
if proof != nil {
resp.MalfeasanceProof = events.ToMalfeasancePB(atx.SmesherID, proof, false)
Expand All @@ -95,7 +104,16 @@ func (s *activationService) Highest(ctx context.Context, req *emptypb.Empty) (*p
if err != nil || atx == nil {
return nil, status.Error(codes.NotFound, fmt.Sprintf("atx id %v not found: %v", highest, err.Error()))
}
prev, err := s.atxProvider.Previous(highest)
if err != nil {
ctxzap.Error(ctx, "failed to get previous ATX",
zap.Stringer("id", highest),
zap.Error(err),
)
return nil, status.Error(codes.Internal, "couldn't get previous ATXs")

Check warning on line 113 in api/grpcserver/activation_service.go

View check run for this annotation

Codecov / codecov/patch

api/grpcserver/activation_service.go#L109-L113

Added lines #L109 - L113 were not covered by tests
}

return &pb.HighestResponse{
Atx: convertActivation(atx),
Atx: convertActivation(atx, prev),
}, nil
}
Loading

0 comments on commit 8211add

Please sign in to comment.