Skip to content

Commit

Permalink
Merge pull request #318 from lidofinance/fix/balance-check-fixture
Browse files Browse the repository at this point in the history
fix: tune-up tests after the vote #180
  • Loading branch information
iamnp authored Nov 15, 2024
2 parents de875a7 + 1d111db commit ae58f35
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 28 deletions.
38 changes: 29 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from utils.config import *
from utils.txs.deploy import deploy_from_prepared_tx
from utils.test.helpers import ETH
from utils.balance import set_balance
from utils.balance import set_balance, set_balance_in_wei
from functools import wraps

ENV_OMNIBUS_BYPASS_EVENTS_DECODING = "OMNIBUS_BYPASS_EVENTS_DECODING"
Expand All @@ -27,10 +27,12 @@
def shared_setup(fn_isolation):
pass


@pytest.fixture(scope="session", autouse=True)
def network_gas_price():
network.gas_price("2 gwei")


@pytest.fixture(scope="function")
def deployer():
return accounts[0]
Expand Down Expand Up @@ -63,10 +65,12 @@ def delegate1():
def delegate2():
return set_balance("0x100b896F2Dd8c4Ca619db86BCDDb7E085143C1C5", 100000)


@pytest.fixture(scope="module")
def trp_recipient(accounts):
return set_balance("0x228cCaFeA1fa21B74257Af975A9D84d87188c61B", 100000)


@pytest.fixture(scope="module")
def eth_whale(accounts):
if network_name() in ("goerli", "goerli-fork"):
Expand Down Expand Up @@ -253,24 +257,40 @@ def parse_events_from_local_abi():
# Added contract will resolve from address during state._find_contract without a request to Etherscan
state._add_contract(contract)


@pytest.fixture(scope="session", autouse=True)
def add_balance_check_middleware():
web3.middleware_onion.add(balance_check_middleware, name='balance_check')
web3.middleware_onion.add(balance_check_middleware, name="balance_check")


# TODO: Such implicit manipulation of the balances may lead to hard-debugging errors in the future.
# Better to return back balance after request is done.
def ensure_balance(address):
if web3.eth.get_balance(address) < ETH(999):
set_balance(address, 1000000)
def ensure_balance(address) -> int:
old_balance = web3.eth.get_balance(address)
if old_balance < ETH(999):
set_balance_in_wei(address, ETH(1000000))
return web3.eth.get_balance(address) - old_balance


def balance_check_middleware(make_request, web3):
@wraps(make_request)
def middleware(method, params):
from_address = None
result = None
balance_diff = 0

if method in ["eth_sendTransaction", "eth_sendRawTransaction"]:
transaction = params[0]
from_address = transaction.get('from')
from_address = transaction.get("from")
if from_address:
ensure_balance(from_address)
balance_diff = ensure_balance(from_address)

try:
result = make_request(method, params)
finally:
if balance_diff > 0:
new_balance = max(0, web3.eth.get_balance(from_address) - balance_diff)
set_balance_in_wei(from_address, new_balance)

return result

return make_request(method, params)
return middleware
50 changes: 41 additions & 9 deletions tests/regression/test_neg_rebase_sanity_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
INACTIVITY_PENALTIES_AMOUNT_PWEI = 101
ONE_PWEI = ETH(0.001)


@pytest.fixture(scope="module")
def oracle_report_sanity_checker() -> Contract:
return contracts.oracle_report_sanity_checker
Expand All @@ -28,31 +29,39 @@ def test_negative_rebase_correct_exited_validators_count_pos_rebase(oracle_repor
reported_validators = exited_validators_count()

reported_validators_values = [value + 2 for value in reported_validators.values()]
oracle_report(cl_diff=ETH(300), stakingModuleIdsWithNewlyExitedValidators=list(reported_validators.keys()),
numExitedValidatorsByStakingModule=reported_validators_values)
oracle_report(
cl_diff=ETH(300),
stakingModuleIdsWithNewlyExitedValidators=list(reported_validators.keys()),
numExitedValidatorsByStakingModule=reported_validators_values,
)

count = oracle_report_sanity_checker.getReportDataCount()
assert count > 0
(_, stored_exited_validators, _) = oracle_report_sanity_checker.reportData(count - 1)

assert stored_exited_validators == sum(reported_validators_values)


def test_negative_rebase_correct_exited_validators_count_neg_rebase(oracle_report_sanity_checker):
locator = contracts.lido_locator
assert oracle_report_sanity_checker.address == locator.oracleReportSanityChecker()

reported_validators = exited_validators_count()

reported_validators_values = [value + 3 for value in reported_validators.values()]
oracle_report(cl_diff=-ETH(40000), stakingModuleIdsWithNewlyExitedValidators=list(reported_validators.keys()),
numExitedValidatorsByStakingModule=reported_validators_values)
oracle_report(
cl_diff=-ETH(40000),
stakingModuleIdsWithNewlyExitedValidators=list(reported_validators.keys()),
numExitedValidatorsByStakingModule=reported_validators_values,
)

count = oracle_report_sanity_checker.getReportDataCount()
assert count > 0
(_, stored_exited_validators, _) = oracle_report_sanity_checker.reportData(count - 1)

assert stored_exited_validators == sum(reported_validators_values)


def test_negative_rebase_correct_balance_neg_rebase(oracle_report_sanity_checker):
locator = contracts.lido_locator
assert oracle_report_sanity_checker.address == locator.oracleReportSanityChecker()
Expand All @@ -78,12 +87,32 @@ def test_blocked_huge_negative_rebase(oracle_report_sanity_checker):
locator = contracts.lido_locator
assert oracle_report_sanity_checker.address == locator.oracleReportSanityChecker()

# Advance the chain 60 days more without accounting oracle reports
# The idea is to simplify the calculation of the exited validators for 18 and 54 days ago
chain.sleep(60 * 24 * 60 * 60)
chain.mine(1)

(_, cl_validators, cl_balance) = contracts.lido.getBeaconStat()
count = oracle_report_sanity_checker.getReportDataCount()
assert count > 0
(_, stored_exited_validators, _) = oracle_report_sanity_checker.reportData(count - 1)

max_cl_balance = (INITIAL_SLASHING_AMOUNT_PWEI + INACTIVITY_PENALTIES_AMOUNT_PWEI) * ONE_PWEI * cl_validators
error_cl_decrease = cl_balance // 10 # 10% of current balance will lead to error
max_cl_balance = (
(INITIAL_SLASHING_AMOUNT_PWEI + INACTIVITY_PENALTIES_AMOUNT_PWEI)
* ONE_PWEI
* (cl_validators - stored_exited_validators)
)
error_cl_decrease = cl_balance // 10 # 10% of current balance will lead to error

print(encode_error("IncorrectCLBalanceDecrease(uint256, uint256)", [error_cl_decrease, max_cl_balance]))
with reverts(encode_error("IncorrectCLBalanceDecrease(uint256, uint256)", [error_cl_decrease, max_cl_balance])):
oracle_report(cl_diff=-error_cl_decrease, exclude_vaults_balances=True, silent=True)
oracle_report(
cl_diff=-error_cl_decrease,
exclude_vaults_balances=True,
simulation_block_identifier=chain.height,
silent=True,
)


def test_negative_rebase_more_than_54_reports(oracle_report_sanity_checker):
locator = contracts.lido_locator
Expand All @@ -92,8 +121,11 @@ def test_negative_rebase_more_than_54_reports(oracle_report_sanity_checker):
reported_validators_values = exited_validators_count().values()
for _ in range(58):
reported_validators_values = [value + 3 for value in reported_validators_values]
oracle_report(cl_diff=-ETH(400), stakingModuleIdsWithNewlyExitedValidators=exited_validators_count().keys(),
numExitedValidatorsByStakingModule=reported_validators_values)
oracle_report(
cl_diff=-ETH(400),
stakingModuleIdsWithNewlyExitedValidators=exited_validators_count().keys(),
numExitedValidatorsByStakingModule=reported_validators_values,
)

count = oracle_report_sanity_checker.getReportDataCount()
assert count > 0
Expand Down
4 changes: 4 additions & 0 deletions tests/regression/test_staking_limits.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""
Tests for lido staking limits
"""

import pytest
import eth_abi

from brownie import web3, convert, reverts, ZERO_ADDRESS, chain
from utils.config import contracts
from utils.test.helpers import ONE_ETH
from utils.balance import set_balance


@pytest.fixture(scope="module")
Expand Down Expand Up @@ -88,6 +90,8 @@ def test_staking_limit_initial_not_zero():
[(10**6, 10**4), (10**12, 10**10), (10**18, 10**16)],
)
def test_staking_limit_updates_per_block_correctly(voting, stranger, limit_max, limit_per_block):
set_balance(stranger.address, 1000000)

# Should update staking limits after submit
contracts.lido.setStakingLimit(limit_max, limit_per_block, {"from": voting})
staking_limit_before = contracts.lido.getCurrentStakeLimit()
Expand Down
4 changes: 3 additions & 1 deletion utils/balance.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from brownie import accounts, web3
from utils.test.helpers import ETH


def set_balance_in_wei(address, balance):
account = accounts.at(address, force=True)
providers = ["evm_setAccountBalance", "hardhat_setBalance", "anvil_setBalance"]
Expand All @@ -15,9 +16,10 @@ def set_balance_in_wei(address, balance):
if e.args[0].get("message") != f"Method {provider} is not supported":
raise e

assert account.balance() == balance, f"Failed to set balance for account: {address}"
assert account.balance() == balance, f"Failed to set balance {balance} for account: {address}"
return account


def set_balance(address, balanceInEth):
balance = ETH(balanceInEth)

Expand Down
59 changes: 50 additions & 9 deletions utils/test/oracle_report_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from eth_abi.abi import encode
from hexbytes import HexBytes

from utils.config import (contracts, AO_CONSENSUS_VERSION)
from utils.config import contracts, AO_CONSENSUS_VERSION
from utils.test.exit_bus_data import encode_data
from utils.test.helpers import ETH, GWEI, eth_balance
from utils.test.merkle_tree import Tree

ZERO_HASH = bytes([0] * 32)
ZERO_BYTES32 = HexBytes(ZERO_HASH)
ONE_DAY = 1 * 24 * 60 * 60
SHARE_RATE_PRECISION = 10 ** 27
SHARE_RATE_PRECISION = 10**27
EXTRA_DATA_FORMAT_EMPTY = 0
EXTRA_DATA_FORMAT_LIST = 1

Expand Down Expand Up @@ -113,7 +113,7 @@ def prepare_csm_report(node_operators_rewards: dict, ref_slot):
shares = node_operators_rewards.copy()
if len(shares) < 2:
# put a stone
shares[2 ** 64 - 1] = 0
shares[2**64 - 1] = 0

tree = Tree.new(tuple((no_id, amount) for (no_id, amount) in shares.items()))
# semi-random values
Expand Down Expand Up @@ -143,8 +143,20 @@ def encode_data_from_abi(data, abi, func_name):
def get_finalization_batches(
share_rate: int, limited_withdrawal_vault_balance, limited_el_rewards_vault_balance
) -> list[int]:
(_, _, _, _, _, _, _, requestTimestampMargin, _, _, _,
_) = contracts.oracle_report_sanity_checker.getOracleReportLimits()
(
_,
_,
_,
_,
_,
_,
_,
requestTimestampMargin,
_,
_,
_,
_,
) = contracts.oracle_report_sanity_checker.getOracleReportLimits()
buffered_ether = contracts.lido.getBufferedEther()
unfinalized_steth = contracts.withdrawal_queue.unfinalizedStETH()
reserved_buffer = min(buffered_ether, unfinalized_steth)
Expand Down Expand Up @@ -219,7 +231,7 @@ def push_oracle_report(
extraDataItemsCount=extraDataItemsCount,
)
submitter = reach_consensus(refSlot, hash, consensusVersion, contracts.hash_consensus_for_accounting_oracle, silent)
accounts[0].transfer(submitter, 10 ** 19)
accounts[0].transfer(submitter, 10**19)
# print(contracts.oracle_report_sanity_checker.getOracleReportLimits())
report_tx = contracts.accounting_oracle.submitReportData(items, oracleVersion, {"from": submitter})
if not silent:
Expand All @@ -230,8 +242,9 @@ def push_oracle_report(
if not silent:
print("Submitted empty extra data report")
else:
extra_report_tx_list = [contracts.accounting_oracle.submitReportExtraDataList(data, {"from": submitter}) for
data in extraDataList]
extra_report_tx_list = [
contracts.accounting_oracle.submitReportExtraDataList(data, {"from": submitter}) for data in extraDataList
]
if not silent:
print("Submitted NOT empty extra data report")

Expand Down Expand Up @@ -263,6 +276,22 @@ def simulate_report(
):
(_, SECONDS_PER_SLOT, GENESIS_TIME) = contracts.hash_consensus_for_accounting_oracle.getChainConfig()
reportTime = GENESIS_TIME + refSlot * SECONDS_PER_SLOT

override_slot = web3.keccak(text="lido.BaseOracle.lastProcessingRefSlot").hex()
state_override = {
contracts.accounting_oracle.address: {
# Fix: Sanity checker uses `lastProcessingRefSlot` from AccountingOracle to
# properly process negative rebase sanity checks. Since current simulation skips call to AO,
# setting up `lastProcessingRefSlot` directly.
#
# The code is taken from the current production `lido-oracle` implementation
# source: https://github.com/lidofinance/lido-oracle/blob/da393bf06250344a4d06dce6d1ac6a3ddcb9c7a3/src/providers/execution/contracts/lido.py#L93-L95
"stateDiff": {
override_slot: refSlot,
},
},
}

try:
return contracts.lido.handleOracleReport.call(
reportTime,
Expand All @@ -276,9 +305,19 @@ def simulate_report(
0,
{"from": contracts.accounting_oracle.address},
block_identifier=block_identifier,
override=state_override,
)
except VirtualMachineError:
# workaround for empty revert message from ganache on eth_call

# override storage value of the processing reference slot to make the simulation sound
# Since it's not possible to pass an override as a part of the state-changing transaction
web3.provider.make_request(
# can assume ganache only here
"evm_setAccountStorageAt",
[contracts.accounting_oracle.address, override_slot, refSlot],
)

contracts.lido.handleOracleReport(
reportTime,
ONE_DAY,
Expand All @@ -301,7 +340,9 @@ def wait_to_next_available_report_time(consensus_contract):
except VirtualMachineError as e:
if "InitialEpochIsYetToArrive" in str(e):
frame_config = consensus_contract.getFrameConfig()
chain.sleep(GENESIS_TIME + 1 + (frame_config["initialEpoch"] * SLOTS_PER_EPOCH * SECONDS_PER_SLOT) - chain.time())
chain.sleep(
GENESIS_TIME + 1 + (frame_config["initialEpoch"] * SLOTS_PER_EPOCH * SECONDS_PER_SLOT) - chain.time()
)
chain.mine(1)
(refSlot, _) = consensus_contract.getCurrentFrame()
else:
Expand Down

0 comments on commit ae58f35

Please sign in to comment.