Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 65 additions & 8 deletions store/gaskv/store.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gaskv

import (
"fmt"
"io"

"cosmossdk.io/store/types"
Expand Down Expand Up @@ -37,9 +38,21 @@ func (gs *Store) Get(key []byte) (value []byte) {
gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostFlat, types.GasReadCostFlatDesc)
value = gs.parent.Get(key)

// TODO overflow-safe math?
gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostPerByte*types.Gas(len(key)), types.GasReadPerByteDesc)
gs.gasMeter.ConsumeGas(gs.gasConfig.ReadCostPerByte*types.Gas(len(value)), types.GasReadPerByteDesc)
// Safe gas calculation for key length
if gasCost, err := SafeMul(gs.gasConfig.ReadCostPerByte, len(key)); err == nil {
gs.gasMeter.ConsumeGas(gasCost, types.GasReadPerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gs.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasReadPerByteDesc)
}

// Safe gas calculation for value length
if gasCost, err := SafeMul(gs.gasConfig.ReadCostPerByte, len(value)); err == nil {
gs.gasMeter.ConsumeGas(gasCost, types.GasReadPerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gs.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasReadPerByteDesc)
}

return value
}
Expand All @@ -49,9 +62,23 @@ func (gs *Store) Set(key, value []byte) {
types.AssertValidKey(key)
types.AssertValidValue(value)
gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostFlat, types.GasWriteCostFlatDesc)
// TODO overflow-safe math?
gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostPerByte*types.Gas(len(key)), types.GasWritePerByteDesc)
gs.gasMeter.ConsumeGas(gs.gasConfig.WriteCostPerByte*types.Gas(len(value)), types.GasWritePerByteDesc)

// Safe gas calculation for key length
if gasCost, err := SafeMul(gs.gasConfig.WriteCostPerByte, len(key)); err == nil {
gs.gasMeter.ConsumeGas(gasCost, types.GasWritePerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gs.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasWritePerByteDesc)
}

// Safe gas calculation for value length
if gasCost, err := SafeMul(gs.gasConfig.WriteCostPerByte, len(value)); err == nil {
gs.gasMeter.ConsumeGas(gasCost, types.GasWritePerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gs.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasWritePerByteDesc)
}

gs.parent.Set(key, value)
}

Expand Down Expand Up @@ -170,8 +197,38 @@ func (gi *gasIterator) consumeSeekGas() {
key := gi.Key()
value := gi.Value()

gi.gasMeter.ConsumeGas(gi.gasConfig.ReadCostPerByte*types.Gas(len(key)), types.GasValuePerByteDesc)
gi.gasMeter.ConsumeGas(gi.gasConfig.ReadCostPerByte*types.Gas(len(value)), types.GasValuePerByteDesc)
// Safe gas calculation for key length
if gasCost, err := SafeMul(gi.gasConfig.ReadCostPerByte, len(key)); err == nil {
gi.gasMeter.ConsumeGas(gasCost, types.GasValuePerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gi.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasValuePerByteDesc)
}

// Safe gas calculation for value length
if gasCost, err := SafeMul(gi.gasConfig.ReadCostPerByte, len(value)); err == nil {
gi.gasMeter.ConsumeGas(gasCost, types.GasValuePerByteDesc)
} else {
// If overflow occurs, consume maximum gas as a safety measure
gi.gasMeter.ConsumeGas(types.Gas(^uint64(0)), types.GasValuePerByteDesc)
}
}
gi.gasMeter.ConsumeGas(gi.gasConfig.IterNextCostFlat, types.GasIterNextCostFlatDesc)
}

// SafeMul performs safe multiplication of gas cost and length to prevent overflow
func SafeMul(cost types.Gas, length int) (types.Gas, error) {
if length < 0 {
return 0, fmt.Errorf("negative length: %d", length)
}
if cost == 0 {
return 0, nil
}

// Check for overflow: if cost * uint64(length) would overflow uint64
if uint64(length) > 0 && cost > types.Gas(^uint64(0))/types.Gas(length) {
return 0, fmt.Errorf("gas calculation overflow: cost=%d, length=%d", cost, length)
}

return cost * types.Gas(length), nil
}
142 changes: 142 additions & 0 deletions store/gaskv/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,148 @@ func bz(s string) []byte { return []byte(s) }
func keyFmt(i int) []byte { return bz(fmt.Sprintf("key%0.8d", i)) }
func valFmt(i int) []byte { return bz(fmt.Sprintf("value%0.8d", i)) }

// TestSafeMul tests the safeMul function for various scenarios
func TestSafeMul(t *testing.T) {
// Test normal cases
t.Run("normal cases", func(t *testing.T) {
// Test basic multiplication
result, err := gaskv.SafeMul(10, 5)
require.NoError(t, err)
require.Equal(t, types.Gas(50), result)

// Test with zero cost
result, err = gaskv.SafeMul(0, 1000)
require.NoError(t, err)
require.Equal(t, types.Gas(0), result)

// Test with zero length
result, err = gaskv.SafeMul(1000, 0)
require.NoError(t, err)
require.Equal(t, types.Gas(0), result)

// Test with both zero
result, err = gaskv.SafeMul(0, 0)
require.NoError(t, err)
require.Equal(t, types.Gas(0), result)

// Test large but safe values
result, err = gaskv.SafeMul(1000000, 1000000)
require.NoError(t, err)
require.Equal(t, types.Gas(1000000000000), result)
})

// Test edge cases
t.Run("edge cases", func(t *testing.T) {
// Test maximum uint64 values that don't overflow
maxUint64 := types.Gas(^uint64(0))
result, err := gaskv.SafeMul(maxUint64, 1)
require.NoError(t, err)
require.Equal(t, maxUint64, result)

// Test with 1 and a large but safe value
// Use a value that's safe to convert to int
safeLargeValue := 1000000000 // 1 billion, safe for int
result, err = gaskv.SafeMul(1, safeLargeValue)
require.NoError(t, err)
require.Equal(t, types.Gas(safeLargeValue), result)
})

// Test overflow cases
t.Run("overflow cases", func(t *testing.T) {
maxUint64 := types.Gas(^uint64(0))

// Test overflow: maxUint64 * 2 should overflow
result, err := gaskv.SafeMul(maxUint64, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "gas calculation overflow")
require.Equal(t, types.Gas(0), result)

// Test overflow: large values that multiply to overflow
result, err = gaskv.SafeMul(maxUint64/2+1, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "gas calculation overflow")
require.Equal(t, types.Gas(0), result)

// Test overflow: very large length
result, err = gaskv.SafeMul(1000, int(maxUint64/1000+1))
require.Error(t, err)
require.Contains(t, err.Error(), "gas calculation overflow")
require.Equal(t, types.Gas(0), result)
})

// Test negative length
t.Run("negative length", func(t *testing.T) {
result, err := gaskv.SafeMul(100, -1)
require.Error(t, err)
require.Contains(t, err.Error(), "negative length")
require.Equal(t, types.Gas(0), result)

result, err = gaskv.SafeMul(0, -100)
require.Error(t, err)
require.Contains(t, err.Error(), "negative length")
require.Equal(t, types.Gas(0), result)
})

// Test boundary cases
t.Run("boundary cases", func(t *testing.T) {
maxUint64 := types.Gas(^uint64(0))

// Test exactly at the boundary (should not overflow)
// Find a value that when multiplied by 2 equals maxUint64
boundaryValue := maxUint64 / 2
result, err := gaskv.SafeMul(boundaryValue, 2)
require.NoError(t, err)
// The issue is that maxUint64 is odd, so dividing by 2 loses 1
// We need to handle this case properly
if maxUint64%2 == 1 {
// If maxUint64 is odd, boundaryValue * 2 will be maxUint64 - 1
require.Equal(t, maxUint64-1, result)
} else {
require.Equal(t, maxUint64, result)
}

// Test just over the boundary (should overflow)
// Use a value that's guaranteed to overflow when multiplied by 2
overflowValue := maxUint64/2 + 1
result, err = gaskv.SafeMul(overflowValue, 2)
require.Error(t, err)
require.Contains(t, err.Error(), "gas calculation overflow")
require.Equal(t, types.Gas(0), result)
})
}

// TestSafeMulIntegration tests that safeMul works correctly in actual gas calculations
func TestSafeMulIntegration(t *testing.T) {
mem := dbadapter.Store{DB: dbm.NewMemDB()}
meter := types.NewGasMeter(1000000)
st := gaskv.NewStore(mem, meter, types.KVGasConfig())

// Test with normal sized data
normalKey := []byte("normal_key")
normalValue := []byte("normal_value")
st.Set(normalKey, normalValue)
value := st.Get(normalKey)
require.Equal(t, normalValue, value)

// Test with large data (but not too large to avoid key size limits)
largeKey := make([]byte, 10000) // 10KB key
largeValue := make([]byte, 10000) // 10KB value
for i := range largeKey {
largeKey[i] = byte(i % 256)
largeValue[i] = byte(i % 256)
}

// This should work without overflow
st.Set(largeKey, largeValue)
retrievedValue := st.Get(largeKey)
require.Equal(t, largeValue, retrievedValue)

// Verify gas was consumed (should be a large amount but not overflow)
gasConsumed := meter.GasConsumed()
require.Greater(t, gasConsumed, types.Gas(0))
require.Less(t, gasConsumed, types.Gas(^uint64(0))) // Should not be max uint64
}

func TestGasKVStoreBasic(t *testing.T) {
mem := dbadapter.Store{DB: dbm.NewMemDB()}
meter := types.NewGasMeter(10000)
Expand Down