diff --git a/consensus/consortium/v2/snapshot.go b/consensus/consortium/v2/snapshot.go index 664e69bba..610233d2d 100644 --- a/consensus/consortium/v2/snapshot.go +++ b/consensus/consortium/v2/snapshot.go @@ -16,6 +16,7 @@ import ( blsCommon "github.com/ethereum/go-ethereum/crypto/bls/common" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/internal/ethapi" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/params" "github.com/hashicorp/golang-lru/arc/v2" ) @@ -243,7 +244,7 @@ func (s *Snapshot) apply(headers []*types.Header, chain consensus.ChainHeaderRea // Change the validator set base on the size of the validators set if number > 0 && number%s.config.EpochV2 == uint64(len(snap.validators())/2) { // Get the most recent checkpoint header - checkpointHeader := FindAncientHeader(header, uint64(len(snap.validators())/2), chain, parents) + checkpointHeader := findAncestorHeader(header, number-uint64(len(snap.validators())/2), chain, parents) if checkpointHeader == nil { return nil, consensus.ErrUnknownAncestor } @@ -420,35 +421,58 @@ func (s *Snapshot) IsRecentlySigned(validator common.Address) bool { return false } -// FindAncientHeader finds the most recent checkpoint header -// Travel through the candidateParents to find the ancient header. -// If all headers in candidateParents have the number is larger than the header number, -// the search function will return the index, but it is not valid if we check with the -// header since the number and hash is not equals. The candidateParents is -// only available when it downloads blocks from the network. -// Otherwise, the candidateParents is nil, and it will be found by header hash and number. -func FindAncientHeader(header *types.Header, ite uint64, chain consensus.ChainHeaderReader, candidateParents []*types.Header) *types.Header { - ancient := header - for i := uint64(1); i <= ite; i++ { - parentHash := ancient.ParentHash - parentHeight := ancient.Number.Uint64() - 1 - found := false - if len(candidateParents) > 0 { - index := sort.Search(len(candidateParents), func(i int) bool { - return candidateParents[i].Number.Uint64() >= parentHeight - }) - if index < len(candidateParents) && candidateParents[index].Number.Uint64() == parentHeight { - ancient = candidateParents[index] - found = true - } - } - if !found { - ancient = chain.GetHeader(parentHash, parentHeight) - found = true +// findAncestorHeader traverses back to look for the requested ancestor header +// in parents list or in chaindata +// +// parents are guaranteed to be ordered and linked by the check when InsertChain +// +// There are 2 possible cases: +// Case 1: ancestor header is in parents list +// <- parents -> +// [ ancestorHeader ] +// +// Case 2: ancestor header's height is lower than parents list +// <- parents -> +// ancestorHeader ... [ ] + +func findAncestorHeader( + currentHeader *types.Header, + ancestorBlockNumber uint64, + chain consensus.ChainHeaderReader, + parents []*types.Header, +) *types.Header { + // Find the first header in parents list that is higher or equal to checkpoint block + index := sort.Search(len(parents), func(i int) bool { + return parents[i].Number.Uint64() >= ancestorBlockNumber + }) + + // This must not happen, checkpoint header's height cannot be higher the parents list + if len(parents) != 0 && index >= len(parents) { + log.Warn( + "Checkpoint header's height is higher than parents list", + "checkpointNumber", ancestorBlockNumber, + "last parent", parents[len(parents)-1].Number, + ) + return nil + } + + if len(parents) != 0 && parents[index].Number.Uint64() == ancestorBlockNumber { + // Case 1: checkpoint header is in parents list + return parents[index] + } else { + // Case 2: checkpoint header's height is lower than parents list + var headerIterator *types.Header + if len(parents) != 0 { + headerIterator = parents[0] + } else { + headerIterator = currentHeader } - if ancient == nil || !found { - return nil + for headerIterator.Number.Uint64() != ancestorBlockNumber { + headerIterator = chain.GetHeader(headerIterator.ParentHash, headerIterator.Number.Uint64()-1) + if headerIterator == nil { + return nil + } } + return headerIterator } - return ancient } diff --git a/consensus/consortium/v2/snapshot_test.go b/consensus/consortium/v2/snapshot_test.go index 2de1ab932..0732fa3e7 100644 --- a/consensus/consortium/v2/snapshot_test.go +++ b/consensus/consortium/v2/snapshot_test.go @@ -29,13 +29,14 @@ func (chainReader *mockChainReader) OpEvents() []*vm.PublishEvent func TestFindCheckpointHeader(t *testing.T) { // Case 1: checkpoint header is at block 5 (in parent list) + // parent list ranges from [0, 10) parents := make([]*types.Header, 10) for i := range parents { parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} } currentHeader := &types.Header{Number: big.NewInt(10)} - checkpointHeader := FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, nil, parents) + checkpointHeader := findAncestorHeader(currentHeader, 5, nil, parents) if checkpointHeader.Number.Cmp(big.NewInt(5)) != 0 && checkpointHeader.Coinbase != common.BigToAddress(big.NewInt(5)) { t.Fatalf("Expect checkpoint header number: %d, got: %d", 5, checkpointHeader.Number.Int64()) } @@ -66,7 +67,7 @@ func TestFindCheckpointHeader(t *testing.T) { currentHeader = &types.Header{ParentHash: common.BigToHash(big.NewInt(19)), Number: big.NewInt(20)} // Must traverse and get the correct header in chain 2 - checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, parents) + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, parents) if checkpointHeader == nil { t.Fatal("Failed to find checkpoint header") } @@ -79,7 +80,7 @@ func TestFindCheckpointHeader(t *testing.T) { // Case 3: find checkpoint header with nil parent list currentHeader = &types.Header{Number: big.NewInt(10), ParentHash: common.BigToHash(big.NewInt(109))} - checkpointHeader = FindAncientHeader(currentHeader, currentHeader.Number.Uint64()-5, &mockChain, nil) + checkpointHeader = findAncestorHeader(currentHeader, 5, &mockChain, nil) // Must traverse and get the correct header in chain 1 if checkpointHeader == nil { t.Fatal("Failed to find checkpoint header") @@ -90,4 +91,16 @@ func TestFindCheckpointHeader(t *testing.T) { checkpointHeader.Number.Int64(), checkpointHeader.ParentHash, ) } + + // Case 4: checkpoint header is higher than parent list, this must not happen + // but the function must not crash in this case + // parent list ranges from [0, 10) + parents = make([]*types.Header, 10) + for i := range parents { + parents[i] = &types.Header{Number: big.NewInt(int64(i)), Coinbase: common.BigToAddress(big.NewInt(int64(i)))} + } + checkpointHeader = findAncestorHeader(nil, 10, nil, parents) + if checkpointHeader != nil { + t.Fatalf("Expect %v checkpoint header, got %v", nil, checkpointHeader) + } }