Skip to content

Commit

Permalink
Allow using MPT
Browse files Browse the repository at this point in the history
  • Loading branch information
omerfirmak committed Jan 24, 2025
1 parent 447fed2 commit 4d0e694
Show file tree
Hide file tree
Showing 38 changed files with 931 additions and 83 deletions.
1 change: 1 addition & 0 deletions cmd/geth/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ var (
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
utils.ScrollMPTFlag,
utils.VMEnableDebugFlag,
utils.NetworkIdFlag,
utils.EthStatsURLFlag,
Expand Down
1 change: 1 addition & 0 deletions cmd/geth/usage.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ var AppHelpFlagGroups = []flags.FlagGroup{
utils.ScrollAlphaFlag,
utils.ScrollSepoliaFlag,
utils.ScrollFlag,
utils.ScrollMPTFlag,
utils.SyncModeFlag,
utils.ExitWhenSyncedFlag,
utils.GCModeFlag,
Expand Down
30 changes: 20 additions & 10 deletions cmd/utils/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ var (
Name: "scroll",
Usage: "Scroll mainnet",
}
ScrollMPTFlag = cli.BoolFlag{
Name: "scroll-mpt",
Usage: "Use MPT trie for state storage",
}
DeveloperFlag = cli.BoolFlag{
Name: "dev",
Usage: "Ephemeral proof-of-authority network with a pre-funded developer account, mining enabled",
Expand Down Expand Up @@ -1879,12 +1883,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "4038000")
stack.Config().L1DeploymentBlock = 4038000
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
if cfg.Genesis.Config.Scroll.UseZktrie {
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
}
log.Info("Pruning disabled")
cfg.NoPruning = true
}
log.Info("Pruning disabled")
cfg.NoPruning = true
case ctx.GlobalBool(ScrollFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 534352
Expand All @@ -1895,12 +1902,15 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *ethconfig.Config) {
stack.Config().L1Confirmations = rpc.FinalizedBlockNumber
log.Info("Setting flag", "--l1.sync.startblock", "18306000")
stack.Config().L1DeploymentBlock = 18306000
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
cfg.Genesis.Config.Scroll.UseZktrie = !ctx.GlobalBool(ScrollMPTFlag.Name)
if cfg.Genesis.Config.Scroll.UseZktrie {
// disable pruning
if ctx.GlobalString(GCModeFlag.Name) != GCModeArchive {
log.Crit("Must use --gcmode=archive")
}
log.Info("Pruning disabled")
cfg.NoPruning = true
}
log.Info("Pruning disabled")
cfg.NoPruning = true
case ctx.GlobalBool(DeveloperFlag.Name):
if !ctx.GlobalIsSet(NetworkIdFlag.Name) {
cfg.NetworkId = 1337
Expand Down
3 changes: 2 additions & 1 deletion core/block_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ func (v *BlockValidator) ValidateState(block *types.Block, statedb *state.StateD
}
// Validate the state root against the received state root and throw
// an error if they don't match.
if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); header.Root != root {
shouldValidateStateRoot := v.config.Scroll.UseZktrie != v.config.IsEuclid(header.Time)
if root := statedb.IntermediateRoot(v.config.IsEIP158(header.Number)); shouldValidateStateRoot && header.Root != root {
return fmt.Errorf("invalid merkle root (remote: %x local: %x)", header.Root, root)
}
return nil
Expand Down
7 changes: 5 additions & 2 deletions core/blockchain.go
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,9 @@ func (bc *BlockChain) writeBlockWithState(block *types.Block, receipts []*types.
return NonStatTy, err
}
triedb := bc.stateCache.TrieDB()
if block.Root() != root {
rawdb.WriteDiskStateRoot(bc.db, block.Root(), root)
}

// If we're running an archive node, always flush
if bc.cacheConfig.TrieDirtyDisabled {
Expand Down Expand Up @@ -1677,7 +1680,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
}

// Enable prefetching to pull in trie node paths while processing transactions
statedb.StartPrefetcher("chain")
statedb.StartPrefetcher("chain", nil)
activeState = statedb

// If we have a followup block, run that against the current state to pre-cache
Expand Down Expand Up @@ -1814,7 +1817,7 @@ func (bc *BlockChain) BuildAndWriteBlock(parentBlock *types.Block, header *types
return NonStatTy, err
}

statedb.StartPrefetcher("l1sync")
statedb.StartPrefetcher("l1sync", nil)
defer statedb.StopPrefetcher()

header.ParentHash = parentBlock.Hash()
Expand Down
20 changes: 11 additions & 9 deletions core/blockchain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3032,15 +3032,16 @@ func TestPoseidonCodeHash(t *testing.T) {
var callCreate2Code = common.Hex2Bytes("f4754f660000000000000000000000000000000000000000000000000000000000000020000000000000000000000000000000000000000000000000000000000000005c6080604052348015600f57600080fd5b50603f80601d6000396000f3fe6080604052600080fdfea2646970667358221220707985753fcb6578098bb16f3709cf6d012993cba6dd3712661cf8f57bbc0d4d64736f6c6343000807003300000000")

var (
key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
addr1 = crypto.PubkeyToAddress(key1.PublicKey)
db = rawdb.NewMemoryDatabase()
gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}}
genesis = gspec.MustCommit(db)
signer = types.LatestSigner(gspec.Config)
engine = ethash.NewFaker()
blockchain, _ = NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)
key1, _ = crypto.HexToECDSA("b71c71a67e1177ad4e901695e1b4b9ee17ae16c6668d313eac2f96dbcda3f291")
addr1 = crypto.PubkeyToAddress(key1.PublicKey)
db = rawdb.NewMemoryDatabase()
gspec = &Genesis{Config: params.TestChainConfig, Alloc: GenesisAlloc{addr1: {Balance: big.NewInt(10000000000000000)}}}
signer = types.LatestSigner(gspec.Config)
engine = ethash.NewFaker()
)
gspec.Config.Scroll.UseZktrie = true
genesis := gspec.MustCommit(db)
blockchain, _ := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{}, nil, nil)

defer blockchain.Stop()

Expand Down Expand Up @@ -3724,6 +3725,7 @@ func TestCurieTransition(t *testing.T) {
config.CurieBlock = big.NewInt(2)
config.DarwinTime = nil
config.DarwinV2Time = nil
config.Scroll.UseZktrie = true

var (
db = rawdb.NewMemoryDatabase()
Expand All @@ -3748,7 +3750,7 @@ func TestCurieTransition(t *testing.T) {
number := block.Number().Uint64()
baseFee := block.BaseFee()

statedb, _ := state.New(block.Root(), state.NewDatabase(db), nil)
statedb, _ := state.New(block.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: gspec.Config.Scroll.UseZktrie}), nil)

code := statedb.GetCode(rcfg.L1GasPriceOracleAddress)
codeSize := statedb.GetCodeSize(rcfg.L1GasPriceOracleAddress)
Expand Down
3 changes: 2 additions & 1 deletion core/chain_makers.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/scroll-tech/go-ethereum/ethdb"
"github.com/scroll-tech/go-ethereum/params"
"github.com/scroll-tech/go-ethereum/rollup/fees"
"github.com/scroll-tech/go-ethereum/trie"
)

// BlockGen creates blocks for testing.
Expand Down Expand Up @@ -264,7 +265,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
return nil, nil
}
for i := 0; i < n; i++ {
statedb, err := state.New(parent.Root(), state.NewDatabase(db), nil)
statedb, err := state.New(parent.Root(), state.NewDatabaseWithConfig(db, &trie.Config{Zktrie: config.Scroll.ZktrieEnabled()}), nil)
if err != nil {
panic(err)
}
Expand Down
5 changes: 4 additions & 1 deletion core/genesis.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,10 @@ func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
}
statedb.Commit(false)
statedb.Database().TrieDB().Commit(root, true, nil)

if g.Config != nil && g.Config.Scroll.GenesisStateRoot != nil {
head.Root = *g.Config.Scroll.GenesisStateRoot
rawdb.WriteDiskStateRoot(db, head.Root, root)
}
return types.NewBlock(head, nil, nil, nil, trie.NewStackTrie(nil))
}

Expand Down
14 changes: 14 additions & 0 deletions core/rawdb/accessors_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,17 @@ func DeleteTrieNode(db ethdb.KeyValueWriter, hash common.Hash) {
log.Crit("Failed to delete trie node", "err", err)
}
}

func WriteDiskStateRoot(db ethdb.KeyValueWriter, headerRoot, diskRoot common.Hash) {
if err := db.Put(diskStateRootKey(headerRoot), diskRoot.Bytes()); err != nil {
log.Crit("Failed to store disk state root", "err", err)
}
}

func ReadDiskStateRoot(db ethdb.KeyValueReader, headerRoot common.Hash) (common.Hash, error) {
data, err := db.Get(diskStateRootKey(headerRoot))
if err != nil {
return common.Hash{}, err
}
return common.BytesToHash(data), nil
}
6 changes: 6 additions & 0 deletions core/rawdb/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ var (

// Scroll da syncer store
daSyncedL1BlockNumberKey = []byte("LastDASyncedL1BlockNumber")

diskStateRootPrefix = []byte("disk-state-root")
)

// Use the updated "L1" prefix on all new networks
Expand Down Expand Up @@ -312,3 +314,7 @@ func batchMetaKey(batchIndex uint64) []byte {
func committedBatchMetaKey(batchIndex uint64) []byte {
return append(committedBatchMetaPrefix, encodeBigEndian(batchIndex)...)
}

func diskStateRootKey(headerRoot common.Hash) []byte {
return append(diskStateRootPrefix, headerRoot.Bytes()...)
}
6 changes: 6 additions & 0 deletions core/state/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ type Trie interface {
// nodes of the longest existing prefix of the key (at least the root), ending
// with the node that proves the absence of the key.
Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) error

// Witness returns a set containing all trie nodes that have been accessed.
Witness() map[string]struct{}
}

// NewDatabase creates a backing store for state. The returned database is safe for
Expand Down Expand Up @@ -136,6 +139,9 @@ type cachingDB struct {

// OpenTrie opens the main account trie at a specific root hash.
func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
if diskRoot, err := rawdb.ReadDiskStateRoot(db.db.DiskDB(), root); err == nil {
root = diskRoot
}
if db.zktrie {
tr, err := trie.NewZkTrie(root, trie.NewZktrieDatabaseFromTriedb(db.db))
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions core/state/snapshot/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,8 +618,8 @@ func (dl *diskLayer) generate(stats *generatorStats) {
Balance *big.Int
Root common.Hash
KeccakCodeHash []byte
PoseidonCodeHash []byte
CodeSize uint64
PoseidonCodeHash []byte `rlp:"-"`
CodeSize uint64 `rlp:"-"`
}
if err := rlp.DecodeBytes(val, &acc); err != nil {
log.Crit("Invalid account encountered during snapshot creation", "err", err)
Expand Down
17 changes: 15 additions & 2 deletions core/state/state_object.go
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,18 @@ func (s *stateObject) Code(db Database) []byte {
// CodeSize returns the size of the contract code associated with this object,
// or zero if none. This method is an almost mirror of Code, but uses a cache
// inside the database to avoid loading codes seen recently.
func (s *stateObject) CodeSize() uint64 {
return s.data.CodeSize
func (s *stateObject) CodeSize(db Database) uint64 {
if s.code != nil {
return uint64(len(s.code))
}
if bytes.Equal(s.KeccakCodeHash(), emptyKeccakCodeHash) {
return 0
}
size, err := db.ContractCodeSize(s.addrHash, common.BytesToHash(s.KeccakCodeHash()))
if err != nil {
s.setError(fmt.Errorf("can't load code size %x: %v", s.KeccakCodeHash(), err))
}
return uint64(size)
}

func (s *stateObject) SetCode(code []byte) {
Expand Down Expand Up @@ -534,6 +544,9 @@ func (s *stateObject) setNonce(nonce uint64) {
}

func (s *stateObject) PoseidonCodeHash() []byte {
if !s.db.IsZktrie() {
panic("PoseidonCodeHash is only available in zktrie mode")
}
return s.data.PoseidonCodeHash
}

Expand Down
11 changes: 6 additions & 5 deletions core/state/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ func TestSnapshotEmpty(t *testing.T) {
}

func TestSnapshot2(t *testing.T) {
state, _ := New(common.Hash{}, NewDatabase(rawdb.NewMemoryDatabase()), nil)
stateDb := NewDatabase(rawdb.NewMemoryDatabase())
state, _ := New(common.Hash{}, stateDb, nil)

stateobjaddr0 := common.BytesToAddress([]byte("so0"))
stateobjaddr1 := common.BytesToAddress([]byte("so1"))
Expand Down Expand Up @@ -201,7 +202,7 @@ func TestSnapshot2(t *testing.T) {
so0Restored.GetState(state.db, storageaddr)
so0Restored.Code(state.db)
// non-deleted is equal (restored)
compareStateObjects(so0Restored, so0, t)
compareStateObjects(so0Restored, so0, stateDb, t)

// deleted should be nil, both before and after restore of state copy
so1Restored := state.getStateObject(stateobjaddr1)
Expand All @@ -210,7 +211,7 @@ func TestSnapshot2(t *testing.T) {
}
}

func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
func compareStateObjects(so0, so1 *stateObject, db Database, t *testing.T) {
if so0.Address() != so1.Address() {
t.Fatalf("Address mismatch: have %v, want %v", so0.address, so1.address)
}
Expand All @@ -229,8 +230,8 @@ func compareStateObjects(so0, so1 *stateObject, t *testing.T) {
if !bytes.Equal(so0.PoseidonCodeHash(), so1.PoseidonCodeHash()) {
t.Fatalf("PoseidonCodeHash mismatch: have %v, want %v", so0.PoseidonCodeHash(), so1.PoseidonCodeHash())
}
if so0.CodeSize() != so1.CodeSize() {
t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(), so1.CodeSize())
if so0.CodeSize(db) != so1.CodeSize(db) {
t.Fatalf("CodeSize mismatch: have %v, want %v", so0.CodeSize(db), so1.CodeSize(db))
}
if !bytes.Equal(so0.code, so1.code) {
t.Fatalf("Code mismatch: have %v, want %v", so0.code, so1.code)
Expand Down
Loading

0 comments on commit 4d0e694

Please sign in to comment.