diff --git a/beacon-chain/db/kv/BUILD.bazel b/beacon-chain/db/kv/BUILD.bazel index acbaa50fa2a5..759a6dc00462 100644 --- a/beacon-chain/db/kv/BUILD.bazel +++ b/beacon-chain/db/kv/BUILD.bazel @@ -27,6 +27,9 @@ go_library( "p2p.go", "schema.go", "state.go", + "state_diff.go", + "state_diff_cache.go", + "state_diff_helpers.go", "state_summary.go", "state_summary_cache.go", "utils.go", @@ -41,10 +44,12 @@ go_library( "//beacon-chain/db/iface:go_default_library", "//beacon-chain/state:go_default_library", "//beacon-chain/state/state-native:go_default_library", + "//cmd/beacon-chain/flags:go_default_library", "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", "//consensus-types/blocks:go_default_library", + "//consensus-types/hdiff:go_default_library", "//consensus-types/interfaces:go_default_library", "//consensus-types/light-client:go_default_library", "//consensus-types/primitives:go_default_library", @@ -53,6 +58,7 @@ go_library( "//encoding/ssz/detect:go_default_library", "//genesis:go_default_library", "//io/file:go_default_library", + "//math:go_default_library", "//monitoring/progress:go_default_library", "//monitoring/tracing:go_default_library", "//monitoring/tracing/trace:go_default_library", @@ -98,6 +104,7 @@ go_test( "migration_block_slot_index_test.go", "migration_state_validators_test.go", "p2p_test.go", + "state_diff_test.go", "state_summary_test.go", "state_test.go", "utils_test.go", @@ -111,6 +118,7 @@ go_test( "//beacon-chain/db/iface:go_default_library", "//beacon-chain/state:go_default_library", "//beacon-chain/state/state-native:go_default_library", + "//cmd/beacon-chain/flags:go_default_library", "//config/features:go_default_library", "//config/fieldparams:go_default_library", "//config/params:go_default_library", @@ -120,6 +128,7 @@ go_test( "//consensus-types/primitives:go_default_library", "//encoding/bytesutil:go_default_library", "//genesis:go_default_library", + "//math:go_default_library", "//proto/dbval:go_default_library", "//proto/engine/v1:go_default_library", "//proto/prysm/v1alpha1:go_default_library", diff --git a/beacon-chain/db/kv/kv.go b/beacon-chain/db/kv/kv.go index 193cadf48b25..e84b9d41a4ee 100644 --- a/beacon-chain/db/kv/kv.go +++ b/beacon-chain/db/kv/kv.go @@ -91,6 +91,7 @@ type Store struct { blockCache *ristretto.Cache[string, interfaces.ReadOnlySignedBeaconBlock] validatorEntryCache *ristretto.Cache[[]byte, *ethpb.Validator] stateSummaryCache *stateSummaryCache + stateDiffCache *stateDiffCache ctx context.Context } @@ -112,6 +113,7 @@ var Buckets = [][]byte{ lightClientUpdatesBucket, lightClientBootstrapBucket, lightClientSyncCommitteeBucket, + stateDiffBucket, // Indices buckets. blockSlotIndicesBucket, stateSlotIndicesBucket, @@ -201,6 +203,14 @@ func NewKVStore(ctx context.Context, dirPath string, opts ...KVStoreOption) (*St return nil, err } + if features.Get().EnableStateDiff { + sdCache, err := newStateDiffCache(kv) + if err != nil { + return nil, err + } + kv.stateDiffCache = sdCache + } + return kv, nil } diff --git a/beacon-chain/db/kv/schema.go b/beacon-chain/db/kv/schema.go index 111c75006034..80d67ba1ebae 100644 --- a/beacon-chain/db/kv/schema.go +++ b/beacon-chain/db/kv/schema.go @@ -16,6 +16,7 @@ var ( stateValidatorsBucket = []byte("state-validators") feeRecipientBucket = []byte("fee-recipient") registrationBucket = []byte("registration") + stateDiffBucket = []byte("state-diff") // Light Client Updates Bucket lightClientUpdatesBucket = []byte("light-client-updates") diff --git a/beacon-chain/db/kv/state_diff.go b/beacon-chain/db/kv/state_diff.go new file mode 100644 index 000000000000..e1e4d24f8741 --- /dev/null +++ b/beacon-chain/db/kv/state_diff.go @@ -0,0 +1,232 @@ +package kv + +import ( + "context" + + "github.com/OffchainLabs/prysm/v6/beacon-chain/state" + "github.com/OffchainLabs/prysm/v6/cmd/beacon-chain/flags" + "github.com/OffchainLabs/prysm/v6/consensus-types/hdiff" + "github.com/OffchainLabs/prysm/v6/consensus-types/primitives" + "github.com/OffchainLabs/prysm/v6/monitoring/tracing/trace" + "github.com/pkg/errors" + bolt "go.etcd.io/bbolt" +) + +const ( + stateSuffix = "_s" + validatorSuffix = "_v" + balancesSuffix = "_b" +) + +/* + We use a level-based approach to save state diffs. The levels are 0-6, where each level corresponds to an exponent of 2 (exponents[lvl]). + The data at level 0 is saved every 2**exponent[0] slots and always contains a full state snapshot that is used as a base for the delta saved at other levels. +*/ + +// saveStateByDiff takes a state and decides between saving a full state snapshot or a diff. +func (s *Store) saveStateByDiff(ctx context.Context, st state.ReadOnlyBeaconState) error { + _, span := trace.StartSpan(ctx, "BeaconDB.saveStateByDiff") + defer span.End() + + if st == nil { + return errors.New("state is nil") + } + + slot := st.Slot() + offset := s.getOffset() + if uint64(slot) < offset { + return ErrSlotBeforeOffset + } + + // Find the level to save the state. + lvl := computeLevel(offset, slot) + if lvl == -1 { + return nil + } + + // Save full state if level is 0. + if lvl == 0 { + return s.saveFullSnapshot(st) + } + + // Get anchor state to compute the diff from. + anchorState, err := s.getAnchorState(offset, lvl, slot) + if err != nil { + return err + } + + err = s.saveHdiff(lvl, anchorState, st) + if err != nil { + return err + } + + return nil +} + +// stateByDiff retrieves the full state for a given slot. +func (s *Store) stateByDiff(ctx context.Context, slot primitives.Slot) (state.BeaconState, error) { + offset := s.getOffset() + if uint64(slot) < offset { + return nil, ErrSlotBeforeOffset + } + + snapshot, diffChain, err := s.getBaseAndDiffChain(offset, slot) + if err != nil { + return nil, err + } + + for _, diff := range diffChain { + snapshot, err = hdiff.ApplyDiff(ctx, snapshot, diff) + if err != nil { + return nil, err + } + } + + return snapshot, nil +} + +// SaveHdiff computes the diff between the anchor state and the current state and saves it to the database. +func (s *Store) saveHdiff(lvl int, anchor, st state.ReadOnlyBeaconState) error { + slot := uint64(st.Slot()) + key := makeKey(lvl, slot) + + diff, err := hdiff.Diff(anchor, st) + if err != nil { + return err + } + + err = s.db.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bolt.ErrBucketNotFound + } + buf := append(key, stateSuffix...) + if err := bucket.Put(buf, diff.StateDiff); err != nil { + return err + } + buf = append(key, validatorSuffix...) + if err := bucket.Put(buf, diff.ValidatorDiffs); err != nil { + return err + } + buf = append(key, balancesSuffix...) + if err := bucket.Put(buf, diff.BalancesDiff); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + // Save the full state to the cache (if not the last level). + if lvl != len(flags.Get().StateDiffExponents)-1 { + err = s.stateDiffCache.setAnchor(lvl, st) + if err != nil { + return err + } + } + + return nil +} + +// SaveFullSnapshot saves the full level 0 state snapshot to the database. +func (s *Store) saveFullSnapshot(st state.ReadOnlyBeaconState) error { + slot := uint64(st.Slot()) + key := makeKey(0, slot) + stateBytes, err := st.MarshalSSZ() + if err != nil { + return err + } + // add version key to value + enc, err := addKey(st.Version(), stateBytes) + if err != nil { + return err + } + + err = s.db.Update(func(tx *bolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bolt.ErrBucketNotFound + } + + if err := bucket.Put(key, enc); err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + // Save the full state to the cache, and invalidate other levels. + s.stateDiffCache.clearAnchors() + err = s.stateDiffCache.setAnchor(0, st) + if err != nil { + return err + } + + return nil +} + +func (s *Store) getDiff(lvl int, slot uint64) (hdiff.HdiffBytes, error) { + key := makeKey(lvl, slot) + var stateDiff []byte + var validatorDiff []byte + var balancesDiff []byte + + err := s.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bolt.ErrBucketNotFound + } + buf := append(key, stateSuffix...) + stateDiff = bucket.Get(buf) + if stateDiff == nil { + return errors.New("state diff not found") + } + buf = append(key, validatorSuffix...) + validatorDiff = bucket.Get(buf) + if validatorDiff == nil { + return errors.New("validator diff not found") + } + buf = append(key, balancesSuffix...) + balancesDiff = bucket.Get(buf) + if balancesDiff == nil { + return errors.New("balances diff not found") + } + return nil + }) + + if err != nil { + return hdiff.HdiffBytes{}, err + } + + return hdiff.HdiffBytes{ + StateDiff: stateDiff, + ValidatorDiffs: validatorDiff, + BalancesDiff: balancesDiff, + }, nil +} + +func (s *Store) getFullSnapshot(slot uint64) (state.BeaconState, error) { + key := makeKey(0, slot) + var enc []byte + + err := s.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bolt.ErrBucketNotFound + } + enc = bucket.Get(key) + if enc == nil { + return errors.New("state not found") + } + return nil + }) + + if err != nil { + return nil, err + } + + return s.decodeStateSnapshot(enc) +} diff --git a/beacon-chain/db/kv/state_diff_cache.go b/beacon-chain/db/kv/state_diff_cache.go new file mode 100644 index 000000000000..76e07178de42 --- /dev/null +++ b/beacon-chain/db/kv/state_diff_cache.go @@ -0,0 +1,77 @@ +package kv + +import ( + "encoding/binary" + "errors" + "sync" + + "github.com/OffchainLabs/prysm/v6/beacon-chain/state" + "github.com/OffchainLabs/prysm/v6/cmd/beacon-chain/flags" + "go.etcd.io/bbolt" +) + +type stateDiffCache struct { + sync.RWMutex + anchors []state.ReadOnlyBeaconState + offset uint64 +} + +func newStateDiffCache(s *Store) (*stateDiffCache, error) { + var offset uint64 + + err := s.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + + offsetBytes := bucket.Get([]byte("offset")) + if offsetBytes == nil { + return errors.New("state diff cache: offset not found") + } + offset = binary.LittleEndian.Uint64(offsetBytes) + return nil + }) + if err != nil { + return nil, err + } + + return &stateDiffCache{ + anchors: make([]state.ReadOnlyBeaconState, len(flags.Get().StateDiffExponents)), + offset: offset, + }, nil +} + +func (c *stateDiffCache) getAnchor(level int) state.ReadOnlyBeaconState { + c.RLock() + defer c.RUnlock() + return c.anchors[level] +} + +func (c *stateDiffCache) setAnchor(level int, anchor state.ReadOnlyBeaconState) error { + c.Lock() + defer c.Unlock() + if level >= len(c.anchors) || level < 0 { + return errors.New("state diff cache: anchor level out of range") + } + c.anchors[level] = anchor + return nil +} + +func (c *stateDiffCache) getOffset() uint64 { + c.RLock() + defer c.RUnlock() + return c.offset +} + +func (c *stateDiffCache) setOffset(offset uint64) { + c.Lock() + defer c.Unlock() + c.offset = offset +} + +func (c *stateDiffCache) clearAnchors() { + c.Lock() + defer c.Unlock() + c.anchors = make([]state.ReadOnlyBeaconState, len(flags.Get().StateDiffExponents)) +} diff --git a/beacon-chain/db/kv/state_diff_helpers.go b/beacon-chain/db/kv/state_diff_helpers.go new file mode 100644 index 000000000000..db79005f25d7 --- /dev/null +++ b/beacon-chain/db/kv/state_diff_helpers.go @@ -0,0 +1,234 @@ +package kv + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + + "github.com/OffchainLabs/prysm/v6/beacon-chain/state" + statenative "github.com/OffchainLabs/prysm/v6/beacon-chain/state/state-native" + "github.com/OffchainLabs/prysm/v6/cmd/beacon-chain/flags" + "github.com/OffchainLabs/prysm/v6/consensus-types/hdiff" + "github.com/OffchainLabs/prysm/v6/consensus-types/primitives" + "github.com/OffchainLabs/prysm/v6/math" + ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1" + "github.com/OffchainLabs/prysm/v6/runtime/version" + "go.etcd.io/bbolt" +) + +var ( + offsetKey = []byte("offset") + ErrSlotBeforeOffset = errors.New("slot is before root offset") +) + +func makeKey(level int, slot uint64) []byte { + buf := make([]byte, 16) + buf[0] = byte(level) + binary.LittleEndian.PutUint64(buf[1:], slot) + return buf +} + +func (s *Store) getAnchorState(offset uint64, lvl int, slot primitives.Slot) (anchor state.ReadOnlyBeaconState, err error) { + if lvl <= 0 || lvl >= len(flags.Get().StateDiffExponents) { + return nil, errors.New("invalid value for level") + } + + relSlot := uint64(slot) - offset + prevExp := flags.Get().StateDiffExponents[lvl-1] + span := math.PowerOf2(uint64(prevExp)) + anchorSlot := primitives.Slot((relSlot / span * span) + offset) + + // anchorLvl can be [0, lvl-1] + anchorLvl := computeLevel(offset, anchorSlot) + if anchorLvl == -1 { + return nil, errors.New("could not compute anchor level") + } + + // Check if we have the anchor in cache. + anchor = s.stateDiffCache.getAnchor(anchorLvl) + if anchor != nil { + return anchor, nil + } + + // If not, load it from the database. + anchor, err = s.stateByDiff(context.Background(), anchorSlot) + if err != nil { + return nil, err + } + + // Save it in the cache. + err = s.stateDiffCache.setAnchor(anchorLvl, anchor) + if err != nil { + return nil, err + } + return anchor, nil +} + +// computeLevel computes the level in the diff tree. Returns -1 in case slot should not be in tree. +func computeLevel(offset uint64, slot primitives.Slot) int { + rel := uint64(slot) - offset + for i, exp := range flags.Get().StateDiffExponents { + span := math.PowerOf2(uint64(exp)) + if rel%span == 0 { + return i + } + } + // If rel isn’t on any of the boundaries, we should ignore saving it. + return -1 +} + +func (s *Store) setOffset(slot primitives.Slot) error { + err := s.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + + offsetBytes := bucket.Get(offsetKey) + if offsetBytes != nil { + return fmt.Errorf("offset already set to %d", binary.LittleEndian.Uint64(offsetBytes)) + } + + offsetBytes = make([]byte, 8) + binary.LittleEndian.PutUint64(offsetBytes, uint64(slot)) + if err := bucket.Put(offsetKey, offsetBytes); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + // Save the offset in the cache. + s.stateDiffCache.setOffset(uint64(slot)) + return nil +} + +func (s *Store) getOffset() uint64 { + return s.stateDiffCache.getOffset() +} + +func keyForSnapshot(v int) []byte { + switch v { + case version.Fulu: + return fuluKey + case version.Electra: + return ElectraKey + case version.Deneb: + return denebKey + case version.Capella: + return capellaKey + case version.Bellatrix: + return bellatrixKey + case version.Altair: + return altairKey + default: + // Phase0 + return []byte{} + } +} + +func addKey(v int, bytes []byte) ([]byte, error) { + key := keyForSnapshot(v) + enc := make([]byte, len(key)+len(bytes)) + copy(enc, key) + copy(enc[len(key):], bytes) + return enc, nil +} + +func (s *Store) decodeStateSnapshot(enc []byte) (state.BeaconState, error) { + switch { + case hasFuluKey(enc): + var fuluState ethpb.BeaconStateFulu + if err := fuluState.UnmarshalSSZ(enc[len(ElectraKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeFulu(&fuluState) + case HasElectraKey(enc): + var electraState ethpb.BeaconStateElectra + if err := electraState.UnmarshalSSZ(enc[len(ElectraKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeElectra(&electraState) + case hasDenebKey(enc): + var denebState ethpb.BeaconStateDeneb + if err := denebState.UnmarshalSSZ(enc[len(denebKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeDeneb(&denebState) + case hasCapellaKey(enc): + var capellaState ethpb.BeaconStateCapella + if err := capellaState.UnmarshalSSZ(enc[len(capellaKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeCapella(&capellaState) + case hasBellatrixKey(enc): + var bellatrixState ethpb.BeaconStateBellatrix + if err := bellatrixState.UnmarshalSSZ(enc[len(bellatrixKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeBellatrix(&bellatrixState) + case hasAltairKey(enc): + var altairState ethpb.BeaconStateAltair + if err := altairState.UnmarshalSSZ(enc[len(altairKey):]); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafeAltair(&altairState) + default: + var phase0State ethpb.BeaconState + if err := phase0State.UnmarshalSSZ(enc); err != nil { + return nil, err + } + return statenative.InitializeFromProtoUnsafePhase0(&phase0State) + } +} + +func (s *Store) getBaseAndDiffChain(offset uint64, slot primitives.Slot) (state.BeaconState, []hdiff.HdiffBytes, error) { + rel := uint64(slot) - offset + lvl := computeLevel(offset, slot) + if lvl == -1 { + return nil, nil, errors.New("slot not in tree") + } + + exponents := flags.Get().StateDiffExponents + + baseSpan := math.PowerOf2(uint64(exponents[0])) + baseAnchorSlot := (rel / baseSpan * baseSpan) + offset + + var diffChainIndices []uint64 + for i := 1; i <= lvl; i++ { + span := math.PowerOf2(uint64(exponents[i])) + diffSlot := rel / span * span + if diffSlot == baseAnchorSlot { + continue + } + diffChainIndices = appendUnique(diffChainIndices, diffSlot+offset) + } + + baseSnapshot, err := s.getFullSnapshot(baseAnchorSlot) + if err != nil { + return nil, nil, err + } + + diffChain := make([]hdiff.HdiffBytes, 0, len(diffChainIndices)) + for _, diffSlot := range diffChainIndices { + diff, err := s.getDiff(computeLevel(offset, primitives.Slot(diffSlot)), diffSlot) + if err != nil { + return nil, nil, err + } + diffChain = append(diffChain, diff) + } + + return baseSnapshot, diffChain, nil +} + +func appendUnique(s []uint64, v uint64) []uint64 { + for _, x := range s { + if x == v { + return s + } + } + return append(s, v) +} diff --git a/beacon-chain/db/kv/state_diff_test.go b/beacon-chain/db/kv/state_diff_test.go new file mode 100644 index 000000000000..ff8ef8798964 --- /dev/null +++ b/beacon-chain/db/kv/state_diff_test.go @@ -0,0 +1,602 @@ +package kv + +import ( + "context" + "encoding/binary" + "fmt" + "math/rand" + "testing" + + "github.com/OffchainLabs/prysm/v6/beacon-chain/state" + "github.com/OffchainLabs/prysm/v6/cmd/beacon-chain/flags" + "github.com/OffchainLabs/prysm/v6/config/params" + "github.com/OffchainLabs/prysm/v6/consensus-types/primitives" + "github.com/OffchainLabs/prysm/v6/math" + ethpb "github.com/OffchainLabs/prysm/v6/proto/prysm/v1alpha1" + "github.com/OffchainLabs/prysm/v6/runtime/version" + "github.com/OffchainLabs/prysm/v6/testing/require" + "github.com/OffchainLabs/prysm/v6/testing/util" + "go.etcd.io/bbolt" +) + +func TestStateDiff_LoadOrInitOffset(t *testing.T) { + db := setupDB(t) + err := setOffsetInDB(db, 10) + require.NoError(t, err) + offset := db.getOffset() + require.Equal(t, uint64(10), offset) + + err = db.setOffset(10) + require.ErrorContains(t, "offset already set", err) + offset = db.getOffset() + require.Equal(t, uint64(10), offset) +} + +func TestStateDiff_ComputeLevel(t *testing.T) { + db := setupDB(t) + setDefaultExponents() + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + offset := db.getOffset() + + // 2 ** 21 + lvl := computeLevel(offset, primitives.Slot(math.PowerOf2(21))) + require.Equal(t, 0, lvl) + + // 2 ** 21 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(21)*3)) + require.Equal(t, 0, lvl) + + // 2 ** 18 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(18))) + require.Equal(t, 1, lvl) + + // 2 ** 18 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(18)*3)) + require.Equal(t, 1, lvl) + + // 2 ** 16 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(16))) + require.Equal(t, 2, lvl) + + // 2 ** 16 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(16)*3)) + require.Equal(t, 2, lvl) + + // 2 ** 13 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(13))) + require.Equal(t, 3, lvl) + + // 2 ** 13 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(13)*3)) + require.Equal(t, 3, lvl) + + // 2 ** 11 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(11))) + require.Equal(t, 4, lvl) + + // 2 ** 11 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(11)*3)) + require.Equal(t, 4, lvl) + + // 2 ** 9 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(9))) + require.Equal(t, 5, lvl) + + // 2 ** 9 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(9)*3)) + require.Equal(t, 5, lvl) + + // 2 ** 5 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5))) + require.Equal(t, 6, lvl) + + // 2 ** 5 * 3 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)*3)) + require.Equal(t, 6, lvl) + + // 2 ** 7 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(7))) + require.Equal(t, 6, lvl) + + // 2 ** 5 + 1 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+1)) + require.Equal(t, -1, lvl) + + // 2 ** 5 + 16 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+16)) + require.Equal(t, -1, lvl) + + // 2 ** 5 + 32 + lvl = computeLevel(offset, primitives.Slot(math.PowerOf2(5)+32)) + require.Equal(t, 6, lvl) + +} + +func TestStateDiff_SaveFullSnapshot(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + // Create state with slot 0 + st, enc := createState(t, 0, v) + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + err = db.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + s := bucket.Get(makeKey(0, uint64(0))) + if s == nil { + return bbolt.ErrIncompatibleValue + } + require.DeepSSZEqual(t, enc, s) + return nil + }) + require.NoError(t, err) + }) + } +} + +func TestStateDiff_SaveAndReadFullSnapshot(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + st, _ := createState(t, 0, v) + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err := db.stateByDiff(context.Background(), 0) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } +} + +func TestStateDiff_SaveDiff(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + // Create state with slot 2**21 + slot := primitives.Slot(math.PowerOf2(21)) + st, enc := createState(t, slot, v) + + err := setOffsetInDB(db, uint64(slot)) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + err = db.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + s := bucket.Get(makeKey(0, uint64(slot))) + if s == nil { + return bbolt.ErrIncompatibleValue + } + require.DeepSSZEqual(t, enc, s) + return nil + }) + require.NoError(t, err) + + // create state with slot 2**18 (+2**21) + slot = primitives.Slot(math.PowerOf2(18) + math.PowerOf2(21)) + st, _ = createState(t, slot, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + key := makeKey(1, uint64(slot)) + err = db.db.View(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + buf := append(key, "_s"...) + s := bucket.Get(buf) + if s == nil { + return bbolt.ErrIncompatibleValue + } + buf = append(key, "_v"...) + v := bucket.Get(buf) + if v == nil { + return bbolt.ErrIncompatibleValue + } + buf = append(key, "_b"...) + b := bucket.Get(buf) + if b == nil { + return bbolt.ErrIncompatibleValue + } + return nil + }) + require.NoError(t, err) + }) + } +} + +func TestStateDiff_SaveAndReadDiff(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + st, _ := createState(t, 0, v) + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot := primitives.Slot(math.PowerOf2(5)) + st, _ = createState(t, slot, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err := db.stateByDiff(context.Background(), slot) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } +} + +func TestStateDiff_SaveAndReadDiff_MultipleLevels(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + st, _ := createState(t, 0, v) + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot := primitives.Slot(math.PowerOf2(11)) + st, _ = createState(t, slot, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err := db.stateByDiff(context.Background(), slot) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + + slot = primitives.Slot(math.PowerOf2(11) + math.PowerOf2(9)) + st, _ = createState(t, slot, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err = db.stateByDiff(context.Background(), slot) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err = st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err = readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + + slot = primitives.Slot(math.PowerOf2(11) + math.PowerOf2(9) + math.PowerOf2(5)) + st, _ = createState(t, slot, v) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err = db.stateByDiff(context.Background(), slot) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err = st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err = readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } +} + +func TestStateDiff_SaveAndReadDiffForkTransition(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 5; v++ { + t.Run(version.String(v), func(t *testing.T) { + db := setupDB(t) + + st, _ := createState(t, 0, v) + + err := setOffsetInDB(db, 0) + require.NoError(t, err) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + slot := primitives.Slot(math.PowerOf2(5)) + st, _ = createState(t, slot, v+1) + + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + readSt, err := db.stateByDiff(context.Background(), slot) + require.NoError(t, err) + require.NotNil(t, readSt) + + stSSZ, err := st.MarshalSSZ() + require.NoError(t, err) + readStSSZ, err := readSt.MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, stSSZ, readStSSZ) + }) + } +} + +func TestStateDiff_OffsetCache(t *testing.T) { + setDefaultExponents() + + // test for slot numbers 0 and 1 for every version + for slotNum := 0; slotNum < 2; slotNum++ { + // test for every version + for v := 0; v < 6; v++ { + t.Run(fmt.Sprintf("slotNum=%d,%s", slotNum, version.String(v)), func(t *testing.T) { + db := setupDB(t) + + slot := primitives.Slot(slotNum) + err := setOffsetInDB(db, uint64(slot)) + require.NoError(t, err) + st, _ := createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + + offset := db.stateDiffCache.getOffset() + require.Equal(t, uint64(slotNum), offset) + + slot2 := primitives.Slot(uint64(slotNum) + math.PowerOf2(uint64(flags.Get().StateDiffExponents[0]))) + st2, _ := createState(t, slot2, v) + err = db.saveStateByDiff(context.Background(), st2) + require.NoError(t, err) + + offset = db.stateDiffCache.getOffset() + require.Equal(t, uint64(slot), offset) + }) + } + } +} + +func TestStateDiff_AnchorCache(t *testing.T) { + setDefaultExponents() + + // test for every version + for v := 0; v < 6; v++ { + t.Run(version.String(v), func(t *testing.T) { + exponents := flags.Get().StateDiffExponents + localCache := make([]state.ReadOnlyBeaconState, len(exponents)-1) + db := setupDB(t) + err := setOffsetInDB(db, 0) // lvl 0 + require.NoError(t, err) + + // at first the cache should be empty + for i := 0; i < len(flags.Get().StateDiffExponents); i++ { + anchor := db.stateDiffCache.getAnchor(i) + require.IsNil(t, anchor) + } + + // add level 0 + slot := primitives.Slot(0) // offset 0 is already set + st, _ := createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + localCache[0] = st + + // level 0 should be the same + require.DeepEqual(t, localCache[0], db.stateDiffCache.getAnchor(0)) + + // rest of the cache should be nil + for i := 1; i < len(exponents)-1; i++ { + require.IsNil(t, db.stateDiffCache.getAnchor(i)) + } + + // skip last level as it does not get cached + for i := len(exponents) - 2; i > 0; i-- { + slot = primitives.Slot(math.PowerOf2(uint64(exponents[i]))) + st, _ := createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + localCache[i] = st + + // anchor cache must match local cache + for i := 0; i < len(exponents)-1; i++ { + if localCache[i] == nil { + require.IsNil(t, db.stateDiffCache.getAnchor(i)) + continue + } + localSSZ, err := localCache[i].MarshalSSZ() + require.NoError(t, err) + anchorSSZ, err := db.stateDiffCache.getAnchor(i).MarshalSSZ() + require.NoError(t, err) + require.DeepSSZEqual(t, localSSZ, anchorSSZ) + } + } + + // moving to a new tree should invalidate the cache except for level 0 + twoTo21 := math.PowerOf2(21) + slot = primitives.Slot(twoTo21) + st, _ = createState(t, slot, v) + err = db.saveStateByDiff(context.Background(), st) + require.NoError(t, err) + localCache = make([]state.ReadOnlyBeaconState, len(exponents)-1) + localCache[0] = st + + // level 0 should be the same + require.DeepEqual(t, localCache[0], db.stateDiffCache.getAnchor(0)) + + // rest of the cache should be nil + for i := 1; i < len(exponents)-1; i++ { + require.IsNil(t, db.stateDiffCache.getAnchor(i)) + } + }) + } +} + +func createState(t *testing.T, slot primitives.Slot, v int) (state.ReadOnlyBeaconState, []byte) { + p := params.BeaconConfig() + var st state.BeaconState + var err error + switch v { + case version.Altair: + st, err = util.NewBeaconStateAltair() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.GenesisForkVersion, + CurrentVersion: p.AltairForkVersion, + Epoch: p.AltairForkEpoch, + }) + require.NoError(t, err) + case version.Bellatrix: + st, err = util.NewBeaconStateBellatrix() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.AltairForkVersion, + CurrentVersion: p.BellatrixForkVersion, + Epoch: p.BellatrixForkEpoch, + }) + require.NoError(t, err) + case version.Capella: + st, err = util.NewBeaconStateCapella() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.BellatrixForkVersion, + CurrentVersion: p.CapellaForkVersion, + Epoch: p.CapellaForkEpoch, + }) + require.NoError(t, err) + case version.Deneb: + st, err = util.NewBeaconStateDeneb() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.CapellaForkVersion, + CurrentVersion: p.DenebForkVersion, + Epoch: p.DenebForkEpoch, + }) + require.NoError(t, err) + case version.Electra: + st, err = util.NewBeaconStateElectra() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.DenebForkVersion, + CurrentVersion: p.ElectraForkVersion, + Epoch: p.ElectraForkEpoch, + }) + require.NoError(t, err) + default: + st, err = util.NewBeaconState() + require.NoError(t, err) + err = st.SetFork(ðpb.Fork{ + PreviousVersion: p.GenesisForkVersion, + CurrentVersion: p.GenesisForkVersion, + Epoch: 0, + }) + require.NoError(t, err) + } + + err = st.SetSlot(slot) + require.NoError(t, err) + slashings := make([]uint64, 8192) + slashings[0] = uint64(rand.Intn(10)) + err = st.SetSlashings(slashings) + require.NoError(t, err) + stssz, err := st.MarshalSSZ() + require.NoError(t, err) + enc, err := addKey(v, stssz) + require.NoError(t, err) + return st, enc +} + +func setOffsetInDB(s *Store, offset uint64) error { + err := s.db.Update(func(tx *bbolt.Tx) error { + bucket := tx.Bucket(stateDiffBucket) + if bucket == nil { + return bbolt.ErrBucketNotFound + } + + offsetBytes := bucket.Get(offsetKey) + if offsetBytes != nil { + return fmt.Errorf("offset already set to %d", binary.LittleEndian.Uint64(offsetBytes)) + } + + offsetBytes = make([]byte, 8) + binary.LittleEndian.PutUint64(offsetBytes, offset) + if err := bucket.Put(offsetKey, offsetBytes); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + sdCache, err := newStateDiffCache(s) + if err != nil { + return err + } + s.stateDiffCache = sdCache + return nil +} + +func setDefaultExponents() { + globalFlags := flags.GlobalFlags{ + StateDiffExponents: []int{21, 18, 16, 13, 11, 9, 5}, + } + flags.Init(&globalFlags) +} diff --git a/beacon-chain/node/node.go b/beacon-chain/node/node.go index 6b305d8d4964..fb245c751893 100644 --- a/beacon-chain/node/node.go +++ b/beacon-chain/node/node.go @@ -277,7 +277,10 @@ func configureBeacon(cliCtx *cli.Context) error { return errors.Wrap(err, "could not configure beacon chain") } - flags.ConfigureGlobalFlags(cliCtx) + err := flags.ConfigureGlobalFlags(cliCtx) + if err != nil { + return errors.Wrap(err, "could not configure global flags") + } if err := configureChainConfig(cliCtx); err != nil { return errors.Wrap(err, "could not configure chain config") diff --git a/changelog/bastin_state-diff-configs.md b/changelog/bastin_state-diff-configs.md new file mode 100644 index 000000000000..d6cde5b94de7 --- /dev/null +++ b/changelog/bastin_state-diff-configs.md @@ -0,0 +1,4 @@ +### Added + +- Add initial configs for the state-diff feature. +- Add kv functions for the state-diff feature. \ No newline at end of file diff --git a/cmd/beacon-chain/flags/BUILD.bazel b/cmd/beacon-chain/flags/BUILD.bazel index 59617558015e..eca867c5aea9 100644 --- a/cmd/beacon-chain/flags/BUILD.bazel +++ b/cmd/beacon-chain/flags/BUILD.bazel @@ -18,7 +18,9 @@ go_library( ], deps = [ "//cmd:go_default_library", + "//config/features:go_default_library", "//config/params:go_default_library", + "@com_github_pkg_errors//:go_default_library", "@com_github_sirupsen_logrus//:go_default_library", "@com_github_urfave_cli_v2//:go_default_library", ], @@ -26,7 +28,13 @@ go_library( go_test( name = "go_default_test", - srcs = ["api_module_test.go"], + srcs = [ + "api_module_test.go", + "config_test.go", + ], embed = [":go_default_library"], - deps = ["//testing/assert:go_default_library"], + deps = [ + "//testing/assert:go_default_library", + "//testing/require:go_default_library", + ], ) diff --git a/cmd/beacon-chain/flags/base.go b/cmd/beacon-chain/flags/base.go index a338572bd1b5..b8461dbdbf0b 100644 --- a/cmd/beacon-chain/flags/base.go +++ b/cmd/beacon-chain/flags/base.go @@ -344,4 +344,10 @@ var ( Usage: "Maximum number of signatures to batch verify at once for beacon attestation p2p gossip.", Value: 1000, } + // StateDiffExponents defines the state diff tree hierarchy levels. + StateDiffExponents = &cli.IntSliceFlag{ + Name: "state-diff-exponents", + Usage: "A comma-separated list of exponents (of 2) in decreasing order, defining the state diff hierarchy levels. The last exponent must be greater than or equal to 5.", + Value: cli.NewIntSlice(21, 18, 16, 13, 11, 9, 5), + } ) diff --git a/cmd/beacon-chain/flags/config.go b/cmd/beacon-chain/flags/config.go index e36655182715..e861112771f9 100644 --- a/cmd/beacon-chain/flags/config.go +++ b/cmd/beacon-chain/flags/config.go @@ -1,7 +1,11 @@ package flags import ( + "math" + "github.com/OffchainLabs/prysm/v6/cmd" + "github.com/OffchainLabs/prysm/v6/config/features" + "github.com/pkg/errors" "github.com/urfave/cli/v2" ) @@ -19,6 +23,7 @@ type GlobalFlags struct { BlobBatchLimitBurstFactor int DataColumnBatchLimit int DataColumnBatchLimitBurstFactor int + StateDiffExponents []int } var globalConfig *GlobalFlags @@ -38,7 +43,7 @@ func Init(c *GlobalFlags) { // ConfigureGlobalFlags initializes the global config. // based on the provided cli context. -func ConfigureGlobalFlags(ctx *cli.Context) { +func ConfigureGlobalFlags(ctx *cli.Context) error { cfg := &GlobalFlags{} if ctx.Bool(SubscribeToAllSubnets.Name) { @@ -51,6 +56,18 @@ func ConfigureGlobalFlags(ctx *cli.Context) { cfg.SubscribeAllDataSubnets = true } + // State-diff-exponents + cfg.StateDiffExponents = ctx.IntSlice(StateDiffExponents.Name) + if features.Get().EnableStateDiff { + if err := validateStateDiffExponents(cfg.StateDiffExponents); err != nil { + return err + } + } else { + if ctx.IsSet(StateDiffExponents.Name) { + log.Warn("--state-diff-exponents is set but --enable-state-diff is not; the value will be ignored.") + } + } + cfg.BlockBatchLimit = ctx.Int(BlockBatchLimit.Name) cfg.BlockBatchLimitBurstFactor = ctx.Int(BlockBatchLimitBurstFactor.Name) cfg.BlobBatchLimit = ctx.Int(BlobBatchLimit.Name) @@ -63,6 +80,7 @@ func ConfigureGlobalFlags(ctx *cli.Context) { configureMinimumPeers(ctx, cfg) Init(cfg) + return nil } // MaxDialIsActive checks if the user has enabled the max dial flag. @@ -78,3 +96,21 @@ func configureMinimumPeers(ctx *cli.Context, cfg *GlobalFlags) { cfg.MinimumSyncPeers = maxPeers } } + +func validateStateDiffExponents(exponents []int) error { + length := len(exponents) + if length == 0 || length > 15 { + return errors.New("state diff exponents must contain between 1 and 15 values") + } + if exponents[length-1] < 5 { + return errors.New("the last state diff exponent must be at least 5") + } + prev := math.MaxInt + for _, exp := range exponents { + if exp >= prev { + return errors.New("state diff exponents must be in strictly decreasing order") + } + prev = exp + } + return nil +} diff --git a/cmd/beacon-chain/flags/config_test.go b/cmd/beacon-chain/flags/config_test.go new file mode 100644 index 000000000000..459ad724a2dd --- /dev/null +++ b/cmd/beacon-chain/flags/config_test.go @@ -0,0 +1,38 @@ +package flags + +import ( + "strconv" + "testing" + + "github.com/OffchainLabs/prysm/v6/testing/require" +) + +func TestValidateStateDiffExponents(t *testing.T) { + tests := []struct { + idx int + exponents []int + wantErr bool + errMsg string + }{ + {idx: 1, exponents: []int{0, 1, 2}, wantErr: true, errMsg: "at least 5"}, + {idx: 2, exponents: []int{1, 2, 3}, wantErr: true, errMsg: "at least 5"}, + {idx: 3, exponents: []int{9, 8, 4}, wantErr: true, errMsg: "at least 5"}, + {idx: 4, exponents: []int{3, 4, 5}, wantErr: true, errMsg: "decreasing"}, + {idx: 5, exponents: []int{15, 14, 14, 12, 11}, wantErr: true, errMsg: "decreasing"}, + {idx: 6, exponents: []int{15, 14, 13, 12, 11}, wantErr: false}, + {idx: 7, exponents: []int{21, 18, 16, 13, 11, 9, 5}, wantErr: false}, + {idx: 8, exponents: []int{30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 18, 16, 13, 11, 9, 5}, wantErr: true, errMsg: "between 1 and 15 values"}, + {idx: 9, exponents: []int{}, wantErr: true, errMsg: "between 1 and 15 values"}, + } + + for _, tt := range tests { + t.Run(strconv.Itoa(tt.idx), func(t *testing.T) { + err := validateStateDiffExponents(tt.exponents) + if tt.wantErr { + require.ErrorContains(t, tt.errMsg, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/config/features/config.go b/config/features/config.go index a9b5a08e4232..1232a69c327f 100644 --- a/config/features/config.go +++ b/config/features/config.go @@ -51,6 +51,8 @@ type Flags struct { EnableExperimentalAttestationPool bool // EnableExperimentalAttestationPool enables an experimental attestation pool design. DisableDutiesV2 bool // DisableDutiesV2 sets validator client to use the get Duties endpoint EnableWeb bool // EnableWeb enables the webui on the validator client + EnableStateDiff bool // EnableStateDiff enables the experimental state diff feature for the beacon node. + // Logging related toggles. DisableGRPCConnectionLogs bool // Disables logging when a new grpc client has connected. EnableFullSSZDataLogging bool // Enables logging for full ssz data on rejected gossip messages @@ -280,6 +282,16 @@ func ConfigureBeaconChain(ctx *cli.Context) error { cfg.BlacklistedRoots = parseBlacklistedRoots(ctx.StringSlice(blacklistRoots.Name)) } + if ctx.IsSet(enableStateDiff.Name) { + logEnabled(enableStateDiff) + cfg.EnableStateDiff = true + + if ctx.IsSet(enableHistoricalSpaceRepresentation.Name) { + log.Warn("--enable-state-diff is enabled, ignoring --enable-historical-space-representation flag.") + cfg.EnableHistoricalSpaceRepresentation = false + } + } + cfg.AggregateIntervals = [3]time.Duration{aggregateFirstInterval.Value, aggregateSecondInterval.Value, aggregateThirdInterval.Value} Init(cfg) return nil diff --git a/config/features/flags.go b/config/features/flags.go index 880092336ba5..f6c7f272a0d9 100644 --- a/config/features/flags.go +++ b/config/features/flags.go @@ -172,6 +172,10 @@ var ( Name: "enable-experimental-attestation-pool", Usage: "Enables an experimental attestation pool design.", } + enableStateDiff = &cli.BoolFlag{ + Name: "enable-state-diff", + Usage: "Enables the experimental state diff feature.", + } // forceHeadFlag is a flag to force the head of the beacon chain to a specific block. forceHeadFlag = &cli.StringFlag{ Name: "sync-from",