diff --git a/store/CHANGELOG.md b/store/CHANGELOG.md index 0724a1e8ad1c..d176a31cf34d 100644 --- a/store/CHANGELOG.md +++ b/store/CHANGELOG.md @@ -25,6 +25,10 @@ Ref: https://keepachangelog.com/en/1.0.0/ ## [Unreleased] +### Improvements + +* [#25195](https://github.com/cosmos/cosmos-sdk/pull/25195) Improve overflow handling in gaskv store by using deterministic gas consumption instead of maximum uint64 values. + ### Bug Fixes * [#20425](https://github.com/cosmos/cosmos-sdk/pull/20425) Fix nil pointer panic when querying historical state where a new store does not exist. diff --git a/store/gaskv/store.go b/store/gaskv/store.go index 75d379a03e2f..f6fd852cf9c1 100644 --- a/store/gaskv/store.go +++ b/store/gaskv/store.go @@ -1,6 +1,7 @@ package gaskv import ( + "fmt" "io" "cosmossdk.io/store/types" @@ -37,9 +38,23 @@ func (gs *Store) Get(key []byte) (value []byte) { gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostFlat, types.GasReadCostFlatDesc) value = gs.parent.Get(key) - // TODO overflow-safe math? - gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostPerByte*types.Gas(len(key)), types.GasReadPerByteDesc) - gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostPerByte*types.Gas(len(value)), types.GasReadPerByteDesc) + // Safe gas calculation for key length + if gasCost, err := SafeMul(gs.gasConfig.ReadCostPerByte, len(key)); err == nil { + gs.gasMeter.ConsumeGas(gasCost, types.GasReadPerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gs.gasMeter.Limit() - gs.gasMeter.GasConsumed() + gs.gasMeter.ConsumeGas(remaining+1, types.GasReadPerByteDesc) + } + + // Safe gas calculation for value length + if gasCost, err := SafeMul(gs.gasConfig.ReadCostPerByte, len(value)); err == nil { + gs.gasMeter.ConsumeGas(gasCost, types.GasReadPerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gs.gasMeter.Limit() - gs.gasMeter.GasConsumed() + gs.gasMeter.ConsumeGas(remaining+1, types.GasReadPerByteDesc) + } return value } @@ -49,9 +64,25 @@ func (gs *Store) Set(key, value []byte) { types.AssertValidKey(key) types.AssertValidValue(value) gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostFlat, types.GasWriteCostFlatDesc) - // TODO overflow-safe math? - gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostPerByte*types.Gas(len(key)), types.GasWritePerByteDesc) - gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostPerByte*types.Gas(len(value)), types.GasWritePerByteDesc) + + // Safe gas calculation for key length + if gasCost, err := SafeMul(gs.gasConfig.WriteCostPerByte, len(key)); err == nil { + gs.gasMeter.ConsumeGas(gasCost, types.GasWritePerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gs.gasMeter.Limit() - gs.gasMeter.GasConsumed() + gs.gasMeter.ConsumeGas(remaining+1, types.GasWritePerByteDesc) + } + + // Safe gas calculation for value length + if gasCost, err := SafeMul(gs.gasConfig.WriteCostPerByte, len(value)); err == nil { + gs.gasMeter.ConsumeGas(gasCost, types.GasWritePerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gs.gasMeter.Limit() - gs.gasMeter.GasConsumed() + gs.gasMeter.ConsumeGas(remaining+1, types.GasWritePerByteDesc) + } + gs.parent.Set(key, value) } @@ -170,8 +201,40 @@ func (gi *gasIterator) consumeSeekGas() { key := gi.Key() value := gi.Value() - gi.gasMeter.ConsumeGas(gi.gasConfig.ReadCostPerByte*types.Gas(len(key)), types.GasValuePerByteDesc) - gi.gasMeter.ConsumeGas(gi.gasConfig.ReadCostPerByte*types.Gas(len(value)), types.GasValuePerByteDesc) + // Safe gas calculation for key length + if gasCost, err := SafeMul(gi.gasConfig.ReadCostPerByte, len(key)); err == nil { + gi.gasMeter.ConsumeGas(gasCost, types.GasValuePerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gi.gasMeter.Limit() - gi.gasMeter.GasConsumed() + gi.gasMeter.ConsumeGas(remaining+1, types.GasValuePerByteDesc) + } + + // Safe gas calculation for value length + if gasCost, err := SafeMul(gi.gasConfig.ReadCostPerByte, len(value)); err == nil { + gi.gasMeter.ConsumeGas(gasCost, types.GasValuePerByteDesc) + } else { + // If overflow occurs, trigger out-of-gas panic deterministically + remaining := gi.gasMeter.Limit() - gi.gasMeter.GasConsumed() + gi.gasMeter.ConsumeGas(remaining+1, types.GasValuePerByteDesc) + } } gi.gasMeter.ConsumeGas(gi.gasConfig.IterNextCostFlat, types.GasIterNextCostFlatDesc) } + +// SafeMul performs safe multiplication of gas cost and length to prevent overflow +func SafeMul(cost types.Gas, length int) (types.Gas, error) { + if length < 0 { + return 0, fmt.Errorf("negative length: %d", length) + } + if cost == 0 { + return 0, nil + } + + // Check for overflow: if cost * uint64(length) would overflow uint64 + if uint64(length) > 0 && cost > ^types.Gas(0)/types.Gas(length) { + return 0, fmt.Errorf("gas calculation overflow: cost=%d, length=%d", cost, length) + } + + return cost * types.Gas(length), nil +} diff --git a/store/gaskv/store_test.go b/store/gaskv/store_test.go index 354832d17c40..6289e9914829 100644 --- a/store/gaskv/store_test.go +++ b/store/gaskv/store_test.go @@ -17,6 +17,152 @@ func bz(s string) []byte { return []byte(s) } func keyFmt(i int) []byte { return bz(fmt.Sprintf("key%0.8d", i)) } func valFmt(i int) []byte { return bz(fmt.Sprintf("value%0.8d", i)) } +// TestSafeMul tests the safeMul function for various scenarios +func TestSafeMul(t *testing.T) { + // Test normal cases + t.Run("normal cases", func(t *testing.T) { + // Test basic multiplication + result, err := gaskv.SafeMul(10, 5) + require.NoError(t, err) + require.Equal(t, types.Gas(50), result) + + // Test with zero cost + result, err = gaskv.SafeMul(0, 1000) + require.NoError(t, err) + require.Equal(t, types.Gas(0), result) + + // Test with zero length + result, err = gaskv.SafeMul(1000, 0) + require.NoError(t, err) + require.Equal(t, types.Gas(0), result) + + // Test with both zero + result, err = gaskv.SafeMul(0, 0) + require.NoError(t, err) + require.Equal(t, types.Gas(0), result) + + // Test large but safe values + result, err = gaskv.SafeMul(1000000, 1000000) + require.NoError(t, err) + require.Equal(t, types.Gas(1000000000000), result) + }) + + // Test edge cases + t.Run("edge cases", func(t *testing.T) { + // Test maximum uint64 values that don't overflow + maxUint64 := ^types.Gas(0) + result, err := gaskv.SafeMul(maxUint64, 1) + require.NoError(t, err) + require.Equal(t, maxUint64, result) + + // Test with 1 and a large but safe value + // Use a value that's safe to convert to int + safeLargeValue := 1000000000 // 1 billion, safe for int + result, err = gaskv.SafeMul(1, safeLargeValue) + require.NoError(t, err) + require.Equal(t, types.Gas(safeLargeValue), result) + }) + + // Test overflow cases + t.Run("overflow cases", func(t *testing.T) { + maxUint64 := ^types.Gas(0) + + // Test overflow: maxUint64 * 2 should overflow + result, err := gaskv.SafeMul(maxUint64, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "gas calculation overflow") + require.Equal(t, types.Gas(0), result) + + // Test overflow: large values that multiply to overflow + result, err = gaskv.SafeMul(maxUint64/2+1, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "gas calculation overflow") + require.Equal(t, types.Gas(0), result) + + // Test overflow: choose length that is safely representable as int on 32/64-bit, + // and a cost that guarantees overflow when multiplied by length. + // length = 1<<30 is safe for 32-bit; cost = floor(MaxUint64/length) + 1 ensures overflow. + length := 1 << 30 + overflowCost := (^types.Gas(0))/types.Gas(length) + 1 + result, err = gaskv.SafeMul(overflowCost, length) + require.Error(t, err) + require.Contains(t, err.Error(), "gas calculation overflow") + require.Equal(t, types.Gas(0), result) + }) + + // Test negative length + t.Run("negative length", func(t *testing.T) { + result, err := gaskv.SafeMul(100, -1) + require.Error(t, err) + require.Contains(t, err.Error(), "negative length") + require.Equal(t, types.Gas(0), result) + + result, err = gaskv.SafeMul(0, -100) + require.Error(t, err) + require.Contains(t, err.Error(), "negative length") + require.Equal(t, types.Gas(0), result) + }) + + // Test boundary cases + t.Run("boundary cases", func(t *testing.T) { + maxUint64 := ^types.Gas(0) + + // Test exactly at the boundary (should not overflow) + // Find a value that when multiplied by 2 equals maxUint64 + boundaryValue := maxUint64 / 2 + result, err := gaskv.SafeMul(boundaryValue, 2) + require.NoError(t, err) + // The issue is that maxUint64 is odd, so dividing by 2 loses 1 + // We need to handle this case properly + if maxUint64%2 == 1 { + // If maxUint64 is odd, boundaryValue * 2 will be maxUint64 - 1 + require.Equal(t, maxUint64-1, result) + } else { + require.Equal(t, maxUint64, result) + } + + // Test just over the boundary (should overflow) + // Use a value that's guaranteed to overflow when multiplied by 2 + overflowValue := maxUint64/2 + 1 + result, err = gaskv.SafeMul(overflowValue, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "gas calculation overflow") + require.Equal(t, types.Gas(0), result) + }) +} + +// TestSafeMulIntegration tests that safeMul works correctly in actual gas calculations +func TestSafeMulIntegration(t *testing.T) { + mem := dbadapter.Store{DB: dbm.NewMemDB()} + meter := types.NewGasMeter(1000000) + st := gaskv.NewStore(mem, meter, types.KVGasConfig()) + + // Test with normal sized data + normalKey := []byte("normal_key") + normalValue := []byte("normal_value") + st.Set(normalKey, normalValue) + value := st.Get(normalKey) + require.Equal(t, normalValue, value) + + // Test with large data (but not too large to avoid key size limits) + largeKey := make([]byte, 10000) // 10KB key + largeValue := make([]byte, 10000) // 10KB value + for i := range largeKey { + largeKey[i] = byte(i % 256) + largeValue[i] = byte(i % 256) + } + + // This should work without overflow + st.Set(largeKey, largeValue) + retrievedValue := st.Get(largeKey) + require.Equal(t, largeValue, retrievedValue) + + // Verify gas was consumed (should be a large amount but not overflow) + gasConsumed := meter.GasConsumed() + require.Greater(t, gasConsumed, types.Gas(0)) + require.Less(t, gasConsumed, ^types.Gas(0)) // Should not be max uint64 +} + func TestGasKVStoreBasic(t *testing.T) { mem := dbadapter.Store{DB: dbm.NewMemDB()} meter := types.NewGasMeter(10000)