Skip to content

Commit

Permalink
Implement kevm.forgetBranch cheatcode (#899)
Browse files Browse the repository at this point in the history
* draft FOUNDRYSemantics

* forgetBranch

* add mlEqualsTrue

* minor corrections

* add simplification step

* formatting

* add back not equal

* rename FOUNDRYSemantics to KontrolSemantics

* checking for negation as well

* correcting indentation

* expanding functionality

* heuristic simplifications

* further refinement

* refactoring _exec_forget_custom_step

* add show test

* fix test

* update expected output

---------

Co-authored-by: Petar Maksimovic <[email protected]>
Co-authored-by: Palina <[email protected]>
Co-authored-by: Petar Maksimović <[email protected]>
  • Loading branch information
4 people authored Jan 23, 2025
1 parent e35b4a5 commit 7d496d3
Show file tree
Hide file tree
Showing 8 changed files with 1,555 additions and 36 deletions.
143 changes: 138 additions & 5 deletions src/kontrol/foundry.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
collect,
extract_lhs,
flatten_label,
free_vars,
minimize_term,
set_cell,
top_down,
Expand All @@ -34,9 +35,10 @@
from pyk.kcfg import KCFG
from pyk.kcfg.kcfg import Step
from pyk.kcfg.minimize import KCFGMinimizer
from pyk.kdist import kdist
from pyk.prelude.bytes import bytesToken
from pyk.prelude.collections import map_empty
from pyk.prelude.k import DOTS
from pyk.prelude.k import DOTS, GENERATED_TOP_CELL
from pyk.prelude.kbool import notBool
from pyk.prelude.kint import INT, intToken
from pyk.prelude.ml import mlEqualsFalse, mlEqualsTrue
Expand Down Expand Up @@ -103,7 +105,7 @@ def cut_point_rules(
break_on_basic_blocks: bool,
break_on_load_program: bool,
) -> list[str]:
return ['FOUNDRY-CHEAT-CODES.rename'] + KEVMSemantics.cut_point_rules(
return ['FOUNDRY-CHEAT-CODES.rename', 'FOUNDRY-ACCOUNTS.forget'] + KEVMSemantics.cut_point_rules(
break_on_jumpi,
break_on_jump,
break_on_calls,
Expand Down Expand Up @@ -153,14 +155,145 @@ def _exec_rename_custom_step(self, cterm: CTerm) -> KCFGExtendResult | None:
_LOGGER.info(f'Renaming {target_var.name} to {name}')
return Step(CTerm(new_cterm.config, constraints), 1, (), ['foundry_rename'], cut=True)

def custom_step(self, cterm: CTerm, _cterm_symbolic: CTermSymbolic) -> KCFGExtendResult | None:
def _check_forget_pattern(self, cterm: CTerm) -> bool:
"""Given a CTerm, check if the rule 'FOUNDRY-ACCOUNTS.forget' is at the top of the K_CELL.
This method checks if the 'FOUNDRY-ACCOUNTS.forget' rule is at the top of the `K_CELL` in the given `cterm`.
If the rule matches, the resulting substitution is cached in `_cached_subst` for later use in `custom_step`
:param cterm: The CTerm representing the current state of the proof node.
:return: `True` if the pattern matches and a custom step can be made; `False` otherwise.
"""
abstract_pattern = KSequence(
[
KApply('cheatcode_forget', [KVariable('###TERM1'), KVariable('###OPERATOR'), KVariable('###TERM2')]),
KVariable('###CONTINUATION'),
]
)
self._cached_subst = abstract_pattern.match(cterm.cell('K_CELL'))
return self._cached_subst is not None

def _exec_forget_custom_step(self, cterm: CTerm, cterm_symbolic: CTermSymbolic) -> KCFGExtendResult | None:
"""Remove the constraint at the top of K_CELL of a given CTerm from its path constraints,
as part of the 'FOUNDRY-ACCOUNTS.forget' cut-rule.
:param cterm: CTerm representing a proof node
:param cterm_symbolic: CTermSymbolic instance
:return: A Step of depth 1 carrying a new configuration in which the constraint is consumed from the top
of the K cell and is removed from the initial path constraints if it existed, together with
information that the `cheatcode_forget` rule has been applied.
"""

def _find_constraints_to_keep(cterm: CTerm, constraint_vars: frozenset[str]) -> set[KInner]:
range_patterns: list[KInner] = [
mlEqualsTrue(KApply('_<Int_', KVariable('###VARL', INT), KVariable('###VARR', INT))),
mlEqualsTrue(KApply('_<=Int_', KVariable('###VARL', INT), KVariable('###VARR', INT))),
mlEqualsTrue(notBool(KApply('_==Int_', KVariable('###VARL', INT), KVariable('###VARR', INT)))),
]
constraints_to_keep: set[KInner] = set()
for constraint in cterm.constraints:
for pattern in range_patterns:
subst_rcp = pattern.match(constraint)
if subst_rcp is not None and (
(
type(subst_rcp['###VARL']) is KVariable
and subst_rcp['###VARL'].name in constraint_vars
and type(subst_rcp['###VARR']) is KToken
)
or (
type(subst_rcp['###VARR']) is KVariable
and subst_rcp['###VARR'].name in constraint_vars
and type(subst_rcp['###VARL']) is KToken
)
):
constraints_to_keep.add(constraint)
break
return constraints_to_keep

def _filter_constraints_by_simplification(
cterm_symbolic: CTermSymbolic,
initial_cterm: CTerm,
constraints_to_remove: list[KInner],
constraints_to_keep: set[KInner],
constraints: set[KInner],
empty_config: CTerm,
) -> set[KInner]:
for constraint_variant in constraints_to_remove:
simplification_cterm = initial_cterm.add_constraint(constraint_variant)
result_cterm, _ = cterm_symbolic.simplify(simplification_cterm)
# Extract constraints that appear after simplification but are not in the 'to keep' set
result_constraints = set(result_cterm.constraints).difference(constraints_to_keep)

if len(result_constraints) == 1:
target_constraint = single(result_constraints)
if target_constraint in constraints:
_LOGGER.info(f'forgetBranch: removing constraint: {target_constraint}')
constraints.remove(target_constraint)
break
else:
_LOGGER.info(f'forgetBranch: constraint: {target_constraint} not found in current constraints')
else:
# If no constraints or multiple constraints appear, log this scenario.
if len(result_constraints) == 0:
_LOGGER.info(f'forgetBranch: constraint {constraint_variant} entailed by remaining constraints')
result_cterm, _ = cterm_symbolic.simplify(CTerm(empty_config.config, [constraint_variant]))
if len(result_cterm.constraints) == 1:
to_remove = single(result_cterm.constraints)
if to_remove in constraints:
_LOGGER.info(f'forgetBranch: removing constraint: {to_remove}')
constraints.remove(to_remove)
else:
_LOGGER.info(
f'forgetBranch: more than one constraint found after simplification and removal:\n{result_constraints}'
)
return constraints

_operators = ['_==Int_', '_=/=Int_', '_<=Int_', '_<Int_', '_>=Int_', '_>Int_']
subst = self._cached_subst
assert subst is not None
# Extract the terms and operator from the substitution
fst_term = subst['###TERM1']
snd_term = subst['###TERM2']
operator = subst['###OPERATOR']
assert isinstance(operator, KToken)
# Construct the positive and negative constraints
pos_constraint = mlEqualsTrue(KApply(_operators[int(operator.token)], fst_term, snd_term))
neg_constraint = mlEqualsTrue(notBool(KApply(_operators[int(operator.token)], fst_term, snd_term)))
# To be able to better simplify, we maintain range constraints on the variables present in the constraint
constraint_vars: frozenset[str] = free_vars(fst_term).union(free_vars(snd_term))
constraints_to_keep: set[KInner] = _find_constraints_to_keep(cterm, constraint_vars)

# Set up initial configuration for constraint simplification, and simplify it to get all
# of the kept constraints in the form in which they will appear after constraint simplification
kevm = KEVM(kdist.get('kontrol.foundry'))
empty_config: CTerm = CTerm.from_kast(kevm.definition.empty_config(GENERATED_TOP_CELL))
initial_cterm, _ = cterm_symbolic.simplify(CTerm(empty_config.config, constraints_to_keep))
constraints_to_keep = set(initial_cterm.constraints)

# Simplify in the presence of constraints to keep, then remove the constraints to keep to
# reveal simplified constraint, then remove if present in original constraints
new_constraints: set[KInner] = _filter_constraints_by_simplification(
cterm_symbolic=cterm_symbolic,
initial_cterm=initial_cterm,
constraints_to_remove=[pos_constraint, neg_constraint],
constraints_to_keep=constraints_to_keep,
constraints=set(cterm.constraints),
empty_config=empty_config,
)

# Update the K_CELL with the continuation
new_cterm = CTerm.from_kast(set_cell(cterm.kast, 'K_CELL', KSequence(subst['###CONTINUATION'])))
return Step(CTerm(new_cterm.config, new_constraints), 1, (), ['cheatcode_forget'], cut=True)

def custom_step(self, cterm: CTerm, cterm_symbolic: CTermSymbolic) -> KCFGExtendResult | None:
if self._check_rename_pattern(cterm):
return self._exec_rename_custom_step(cterm)
elif self._check_forget_pattern(cterm):
return self._exec_forget_custom_step(cterm, cterm_symbolic)
else:
return super().custom_step(cterm, _cterm_symbolic)
return super().custom_step(cterm, cterm_symbolic)

def can_make_custom_step(self, cterm: CTerm) -> bool:
return self._check_rename_pattern(cterm) or super().can_make_custom_step(cterm)
return any(
[self._check_rename_pattern(cterm), self._check_forget_pattern(cterm), super().can_make_custom_step(cterm)]
)


class FoundryKEVM(KEVM):
Expand Down
21 changes: 21 additions & 0 deletions src/kontrol/kdist/cheatcodes.md
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,26 @@ Mock functions
[priority(30)]
```


Abstraction functions
---------------------

#### `forgetBranch` - removes a given path constraint.

```
function forgetBranch(uint256 op1, ComparisonOperator op, uint256 op2) external;
```

```k
rule [cheatcode.call.abstract]:
<k> #cheatcode_call SELECTOR ARGS
=> #forget ( #asWord(#range(ARGS,0,32)), #asWord(#range(ARGS,32,32)), #asWord(#range(ARGS,64,32)))
...
</k>
requires SELECTOR ==Int selector ( "forgetBranch(uint256,uint8,uint256)" )
```

Utils
-----

Expand Down Expand Up @@ -1751,6 +1771,7 @@ Selectors for **implemented** cheat code functions.
rule ( selector ( "mockCall(address,bytes,bytes)" ) => 3110212580 )
rule ( selector ( "mockFunction(address,address,bytes)" ) => 2918731041 )
rule ( selector ( "copyStorage(address,address)" ) => 540912653 )
rule ( selector ( "forgetBranch(uint256,uint8,uint256)" ) => 1720990067 )
```

Selectors for **unimplemented** cheat code functions.
Expand Down
4 changes: 4 additions & 0 deletions src/kontrol/kdist/foundry.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ Then, we define helpers in K which can:
<storage> STORAGE => STORAGE [ #loc(FoundryCheat . Failed) <- 1 ] </storage>
...
</account>
syntax KItem ::= #forget ( Int , Int , Int ) [symbol(cheatcode_forget)]
// -----------------------------------------------------------------------
rule [forget]: <k> #forget(_C1,_OP,_C2) => .K ... </k>
```

#### Structure of execution
Expand Down
29 changes: 15 additions & 14 deletions src/tests/integration/test-data/end-to-end-prove-all
Original file line number Diff line number Diff line change
@@ -1,14 +1,25 @@
CounterTest.test_Increment()
ForgetBranchTest.test_forgetBranch(uint256)
RandomVarTest.test_custom_names()
RandomVarTest.test_randomBool()
RandomVarTest.test_randomAddress()
RandomVarTest.test_randomUint()
RandomVarTest.test_randomUint_192()
RandomVarTest.test_randomUint_Range(uint256,uint256)
RandomVarTest.test_randomBool()
RandomVarTest.test_randomBytes_length(uint256)
RandomVarTest.test_randomBytes4_length()
RandomVarTest.test_randomBytes8_length()
RandomVarTest.test_randomUint_192()
RandomVarTest.test_randomUint_Range(uint256,uint256)
RandomVarTest.test_randomUint()
TransientStorageTest.testTransientStoreLoad(uint256,uint256)
UnitTest.test_assert_eq_address_darray(address[])
UnitTest.test_assert_eq_bool_darray(bool[])
UnitTest.test_assert_eq_bytes32_darray(bytes32[])
UnitTest.test_assert_eq_int256_darray(int256[])
UnitTest.test_assert_eq_uint256_darray(uint256[])
UnitTest.test_assertApproxEqAbs_int_same_sign(uint256,uint256,uint256)
UnitTest.test_assertApproxEqAbs_uint(uint256,uint256,uint256)
UnitTest.test_assertApproxEqRel_int_same_sign_unit()
UnitTest.test_assertApproxEqRel_int_zero_cases_unit()
UnitTest.test_assertApproxEqRel_uint_unit()
UnitTest.test_assertEq_address_err()
UnitTest.test_assertEq_bool_err()
UnitTest.test_assertEq_bytes32_err()
Expand Down Expand Up @@ -39,13 +50,3 @@ UnitTest.test_assertNotEq(bytes32,bytes32)
UnitTest.test_assertNotEq(int256,int256)
UnitTest.test_assertTrue_err()
UnitTest.test_assertTrue(bool)
UnitTest.test_assertApproxEqAbs_uint(uint256,uint256,uint256)
UnitTest.test_assertApproxEqAbs_int_same_sign(uint256,uint256,uint256)
UnitTest.test_assertApproxEqRel_uint_unit()
UnitTest.test_assertApproxEqRel_int_same_sign_unit()
UnitTest.test_assertApproxEqRel_int_zero_cases_unit()
UnitTest.test_assert_eq_bytes32_darray(bytes32[])
UnitTest.test_assert_eq_bool_darray(bool[])
UnitTest.test_assert_eq_int256_darray(int256[])
UnitTest.test_assert_eq_uint256_darray(uint256[])
UnitTest.test_assert_eq_address_darray(address[])
1 change: 1 addition & 0 deletions src/tests/integration/test-data/end-to-end-prove-show
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
ForgetBranchTest.test_forgetBranch(uint256)
RandomVarTest.test_custom_names()
34 changes: 17 additions & 17 deletions src/tests/integration/test-data/end-to-end-prove-skip
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
UnitTest.test_assert_eq_address_darray_err()
UnitTest.test_assert_eq_bool_darray_err()
UnitTest.test_assert_eq_bytes32_darray_err()
UnitTest.test_assert_eq_int256_darray_err()
UnitTest.test_assert_eq_uint256_darray_err()
UnitTest.test_assertApproxEqAbs_int_opp_sign_err()
UnitTest.test_assertApproxEqAbs_int_opp_sign(uint256,uint256,uint256)
UnitTest.test_assertApproxEqAbs_int_same_sign_err()
UnitTest.test_assertApproxEqAbs_int_zero_cases_err()
UnitTest.test_assertApproxEqAbs_int_zero_cases(uint256,uint256)
UnitTest.test_assertApproxEqAbs_uint_err()
UnitTest.test_assertApproxEqRel_int_opp_sign_err()
UnitTest.test_assertApproxEqRel_int_opp_sign_unit()
UnitTest.test_assertApproxEqRel_int_same_sign_err()
UnitTest.test_assertApproxEqRel_int_zero_cases_err()
UnitTest.test_assertApproxEqRel_uint_err()
UnitTest.test_assertEq_address_err()
UnitTest.test_assertEq_bool_err()
UnitTest.test_assertEq_bytes32_err()
Expand All @@ -13,20 +29,4 @@ UnitTest.test_assertNotEq_bool_err()
UnitTest.test_assertNotEq_bytes32_err()
UnitTest.test_assertNotEq_err()
UnitTest.test_assertNotEq_int256_err()
UnitTest.test_assertTrue_err()
UnitTest.test_assertApproxEqAbs_uint_err()
UnitTest.test_assertApproxEqAbs_int_same_sign_err()
UnitTest.test_assertApproxEqAbs_int_opp_sign(uint256,uint256,uint256)
UnitTest.test_assertApproxEqAbs_int_opp_sign_err()
UnitTest.test_assertApproxEqAbs_int_zero_cases(uint256,uint256)
UnitTest.test_assertApproxEqAbs_int_zero_cases_err()
UnitTest.test_assertApproxEqRel_uint_err()
UnitTest.test_assertApproxEqRel_int_same_sign_err()
UnitTest.test_assertApproxEqRel_int_opp_sign_unit()
UnitTest.test_assertApproxEqRel_int_opp_sign_err()
UnitTest.test_assertApproxEqRel_int_zero_cases_err()
UnitTest.test_assert_eq_bytes32_darray_err()
UnitTest.test_assert_eq_bool_darray_err()
UnitTest.test_assert_eq_int256_darray_err()
UnitTest.test_assert_eq_address_darray_err()
UnitTest.test_assert_eq_uint256_darray_err()
UnitTest.test_assertTrue_err()
Loading

0 comments on commit 7d496d3

Please sign in to comment.