Skip to content

Commit

Permalink
Add support for NatSpec comments in proofs (#320)
Browse files Browse the repository at this point in the history
* Add support for NatSpec comments and devdocs

* Set Version: 0.1.134

* move extra-output to toml file; update expected output

* review suggestions

* Set Version: 0.1.135

* add review suggestions

* Set Version: 0.1.136

* Set Version: 0.1.136

---------

Co-authored-by: devops <[email protected]>
  • Loading branch information
anvacaru and devops authored Feb 2, 2024
1 parent 8fab824 commit d91bc90
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 36 deletions.
2 changes: 1 addition & 1 deletion package/test-project/foundry.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[profile.default]
src = 'src'
out = 'out'
extra_output = ['storageLayout', 'abi', 'evm.methodIdentifiers', 'evm.deployedBytecode.object']
extra_output = ['storageLayout', 'abi', 'evm.methodIdentifiers', 'evm.deployedBytecode.object', 'devdoc']
2 changes: 1 addition & 1 deletion package/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.135
0.1.136
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "kontrol"
version = "0.1.135"
version = "0.1.136"
description = "Foundry integration for KEVM"
authors = [
"Runtime Verification, Inc. <[email protected]>",
Expand Down
2 changes: 1 addition & 1 deletion src/kontrol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
if TYPE_CHECKING:
from typing import Final

VERSION: Final = '0.1.135'
VERSION: Final = '0.1.136'
149 changes: 124 additions & 25 deletions src/kontrol/solc_to_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,35 @@ class Input:
type: str
components: tuple[Input, ...] = ()
idx: int = 0
array_lengths: tuple[int, ...] | None = None
dynamic_type_length: int | None = None

@staticmethod
def from_dict(input: dict, idx: int = 0) -> Input:
name = input['name']
type = input['type']
if input.get('components') is not None and input['type'] != 'tuple[]':
return Input(name, type, tuple(Input._unwrap_components(input['components'], idx)), idx)
def from_dict(input: dict, idx: int = 0, natspec_lengths: dict | None = None) -> Input:
"""
Creates an Input instance from a dictionary.
If the optional natspec_lengths is provided, it is used for calculating array and dynamic type lengths.
For tuples, the function handles nested 'components' recursively.
"""
name = input.get('name')
type = input.get('type')
if name is None or type is None:
raise ValueError("ABI dictionary must contain 'name' and 'type' keys.", input)
array_lengths, dynamic_type_length = (
process_length_equals(input, natspec_lengths) if natspec_lengths is not None else (None, None)
)
if input.get('components') is not None:
return Input(
name,
type,
tuple(Input._unwrap_components(input['components'], idx, natspec_lengths)),
idx,
array_lengths,
dynamic_type_length,
)
else:
return Input(name, type, idx=idx)
return Input(name, type, idx=idx, array_lengths=array_lengths, dynamic_type_length=dynamic_type_length)

@staticmethod
def arg_name(input: Input) -> str:
Expand Down Expand Up @@ -108,21 +128,25 @@ def _make_complex_type(components: Iterable[Input]) -> KApply:
return KEVM.abi_tuple(abi_types)

@staticmethod
def _unwrap_components(components: dict, i: int = 0) -> list[Input]:
def _unwrap_components(components: list[dict], idx: int = 0, natspec_lengths: dict | None = None) -> list[Input]:
"""
recursively unwrap components in arguments of complex types
Recursively unwrap components of a complex type to create a list of Input instances.
:param components: A list of dictionaries representing component structures
:param idx: Starting index for components, defaults to 0
:param natspec_lengths: Optional dictionary for calculating array and dynamic type lengths
:return: A list of Input instances for each component, including nested components
"""
comps = []
for comp in components:
_name = comp['name']
_type = comp['type']
if comp.get('components') is not None and type != 'tuple[]':
new_comps = Input._unwrap_components(comp['components'], i)
else:
new_comps = []
comps.append(Input(_name, _type, tuple(new_comps), i))
i += 1
return comps
return [
Input(
component['name'],
component['type'],
tuple(Input._unwrap_components(component.get('components', []), idx, natspec_lengths)),
idx,
*process_length_equals(component, natspec_lengths) if natspec_lengths else (None, None),
)
for idx, component in enumerate(components, start=idx)
]

def to_abi(self) -> KApply:
if self.type == 'tuple':
Expand All @@ -138,16 +162,76 @@ def flattened(self) -> list[Input]:
return [self]


def inputs_from_abi(abi_inputs: Iterable[dict]) -> list[Input]:
def inputs_from_abi(abi_inputs: Iterable[dict], natspec_lengths: dict | None) -> list[Input]:
inputs = []
i = 0
index = 0
for input in abi_inputs:
cur_input = Input.from_dict(input, i)
cur_input = Input.from_dict(input, index, natspec_lengths)
inputs.append(cur_input)
i += len(cur_input.flattened())
index += len(cur_input.flattened())
return inputs


def process_length_equals(input_dict: dict, lengths: dict) -> tuple[tuple[int, ...] | None, int | None]:
    """
    Look up the NatSpec length bounds that apply to a single ABI input.

    The ABI type is peeled one `[]` suffix at a time. For every array dimension a bound is read
    from `lengths`, and the lookup key grows by one `[]` per dimension. For `bytes[][] _b` the
    keys consulted are therefore:
        _b:     length of the upper most array
        _b[]:   length of the inner array
        _b[][]: length of the `bytes` elements in the inner array
    A missing array bound defaults to `2` to avoid generating symbolic variables.
    The bound of a dynamic type (`bytes` or `string`) is optional; omitting it may cause
    branching during symbolic execution.

    :param input_dict: ABI entry of the input; its 'name' and 'type' keys are read
    :param lengths: Mapping from variable name (with `[]` suffixes) to its length bound
    :return: A pair `(array_lengths, dynamic_type_length)`; `array_lengths` is a tuple with one
        bound per array dimension, or None when the type is not an array, and
        `dynamic_type_length` is the bound registered for a `bytes`/`string` base type, or None
    """
    key: str = input_dict['name']
    base_type: str = input_dict['type']
    bounds: list[int] = []

    # Strip one array dimension per iteration, extending the lookup key in step with the type.
    while base_type.endswith('[]'):
        bounds.append(lengths.get(key, 2))
        base_type = base_type[:-2]
        key += '[]'

    # Only dynamic base types carry a length of their own; there is no default for it.
    if base_type in ('bytes', 'string'):
        dynamic_type_length = lengths.get(key)
    else:
        dynamic_type_length = None

    if not bounds:
        return (None, dynamic_type_length)
    return (tuple(bounds), dynamic_type_length)


def parse_devdoc(tag: str, devdoc: dict | None) -> dict:
    """
    Parse developer documentation (devdoc) to extract the lengths registered under a given tag.

    The value stored under `tag` is read as a comma-separated list of `name: length` entries.
    Empty entries (for instance the one produced by a trailing comma) are ignored. Entries that
    do not consist of exactly one `:` separating a name from an integer are skipped with a warning.

    Example:
        If devdoc contains
            { 'custom:kontrol-length-equals': '_withdrawalProof: 10,_withdrawalProof[]: 600,_l2OutputIndex: 4,' }
        and the function is called with tag='custom:kontrol-length-equals', it returns:
            { '_withdrawalProof': 10, '_withdrawalProof[]': 600, '_l2OutputIndex': 4 }

    :param tag: Key to look up in `devdoc`
    :param devdoc: Devdoc dictionary of a single method, or None when the method has none
    :return: Mapping from variable name to its integer length; empty when `devdoc` is None or
        does not contain `tag`
    """
    if devdoc is None or tag not in devdoc:
        return {}

    natspecs: dict[str, int] = {}

    for part in devdoc[tag].split(','):
        # Trim whitespace and skip if empty
        part = part.strip()
        if not part:
            continue

        # Split each part into variable and length; both a wrong number of `:` separators and a
        # non-integer length surface as ValueError and cause the entry to be skipped.
        try:
            key, value_str = part.split(':')
            natspecs[key.strip()] = int(value_str.strip())
        except ValueError:
            _LOGGER.warning(f'Skipping invalid format {part} in {tag}')
    return natspecs


@dataclass
class Contract:
@dataclass
Expand Down Expand Up @@ -234,6 +318,7 @@ class Method:
payable: bool
signature: str
ast: dict | None
natspec_values: dict | None

def __init__(
self,
Expand All @@ -245,18 +330,20 @@ def __init__(
contract_digest: str,
contract_storage_digest: str,
sort: KSort,
devdoc: dict | None,
) -> None:
self.signature = msig
self.name = abi['name']
self.id = id
self.inputs = tuple(inputs_from_abi(abi['inputs']))
self.contract_name = contract_name
self.contract_digest = contract_digest
self.contract_storage_digest = contract_storage_digest
self.sort = sort
# TODO: Check that we're handling all state mutability cases
self.payable = abi['stateMutability'] == 'payable'
self.ast = ast
self.natspec_values = parse_devdoc('custom:kontrol-length-equals', devdoc)
self.inputs = tuple(inputs_from_abi(abi['inputs'], self.natspec_values))

@property
def klabel(self) -> KLabel:
Expand Down Expand Up @@ -447,14 +534,26 @@ def __init__(self, contract_name: str, contract_json: dict, foundry: bool = Fals
}

_methods = []
metadata = self.contract_json.get('metadata', {})
devdoc = metadata.get('output', {}).get('devdoc', {}).get('methods', {})

for method in contract_json['abi']:
if method['type'] == 'function':
msig = method_sig_from_abi(method)
method_selector: str = str(evm['methodIdentifiers'][msig])
mid = int(method_selector, 16)
method_ast = function_asts[method_selector] if method_selector in function_asts else None
method_devdoc = devdoc.get(msig)
_m = Contract.Method(
msig, mid, method, method_ast, self._name, self.digest, self.storage_digest, self.sort_method
msig,
mid,
method,
method_ast,
self._name,
self.digest,
self.storage_digest,
self.sort_method,
method_devdoc,
)
_methods.append(_m)
if method['type'] == 'constructor':
Expand Down
12 changes: 6 additions & 6 deletions src/tests/integration/test-data/contracts.k.expected
Original file line number Diff line number Diff line change
Expand Up @@ -6673,13 +6673,13 @@ module S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3-CONTRACT

syntax Bytes ::= S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Contract "." S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method [function(), symbol(), klabel(method_lib%forge-std%src%interfaces%IMulticall3)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate" "(" K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate" "(" Int ":" "address" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate_address_bytes)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate3" "(" K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate3_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate3" "(" Int ":" "address" "," Int ":" "bool" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate3_address_bool_bytes)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate3Value" "(" K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate3Value_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2Kaggregate3Value" "(" Int ":" "address" "," Int ":" "bool" "," Int ":" "uint256" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2Kaggregate3Value_address_bool_uint256_bytes)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KblockAndAggregate" "(" K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2KblockAndAggregate_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KblockAndAggregate" "(" Int ":" "address" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2KblockAndAggregate_address_bytes)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KgetBasefee" "(" ")" [symbol(), klabel(method_IMulticall3_S2KgetBasefee_)]

Expand All @@ -6701,9 +6701,9 @@ module S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3-CONTRACT

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KgetLastBlockHash" "(" ")" [symbol(), klabel(method_IMulticall3_S2KgetLastBlockHash_)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KtryAggregate" "(" Int ":" "bool" "," K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2KtryAggregate_bool_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KtryAggregate" "(" Int ":" "bool" "," Int ":" "address" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2KtryAggregate_bool_address_bytes)]

syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KtryBlockAndAggregate" "(" Int ":" "bool" "," K ":" "tuple[]" ")" [symbol(), klabel(method_IMulticall3_S2KtryBlockAndAggregate_bool_tuple[])]
syntax S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3Method ::= "S2KtryBlockAndAggregate" "(" Int ":" "bool" "," Int ":" "address" "," Bytes ":" "bytes" ")" [symbol(), klabel(method_IMulticall3_S2KtryBlockAndAggregate_bool_address_bytes)]

rule ( S2KlibZModforgeZSubstdZModsrcZModinterfacesZModIMulticall3 . S2KgetBasefee ( ) => #abiCallData ( "getBasefee" , .TypedArgs ) )

Expand Down
31 changes: 30 additions & 1 deletion src/tests/unit/test_solc_to_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kevm_pyk.kevm import KEVM
from pyk.kast.inner import KApply, KToken, KVariable

from kontrol.solc_to_k import Contract, Input, _range_predicates
from kontrol.solc_to_k import Contract, Input, _range_predicates, process_length_equals

from .utils import TEST_DATA_DIR

Expand Down Expand Up @@ -172,6 +172,17 @@ def test_escaping(test_id: str, prefix: str, input: str, output: str) -> None:
),
]

# Cases for `test_process_length_equals`. Each tuple holds, in order:
#   test id, NatSpec lengths mapping, ABI input dict, expected array lengths, expected dynamic type length.
DEVDOCS_DATA: list[tuple[str, dict, dict, tuple[int, ...] | None, int | None]] = [
# Array of a dynamic type: one array bound plus the element length stored under the `[]`-suffixed name.
# The 'data' entry does not match the input name.
(
'test_1',
{'_withdrawalProof': 10, '_withdrawalProof[]': 600, 'data': 600},
{'name': '_withdrawalProof', 'type': 'bytes[]'},
(10,),
600,
),
# Non-array dynamic type with no registered length: neither value is set.
('test_2', {}, {'name': '_a', 'type': 'bytes'}, None, None),
]


@pytest.mark.parametrize('test_id,input,expected', INPUT_DATA, ids=[test_id for test_id, *_ in INPUT_DATA])
def test_input_to_abi(test_id: str, input: Input, expected: KApply) -> None:
Expand All @@ -180,3 +191,21 @@ def test_input_to_abi(test_id: str, input: Input, expected: KApply) -> None:

# Then
assert abi == expected


@pytest.mark.parametrize(
    'test_id,devdocs,input_dict,expected_array_length,expected_dynamic_type_length',
    DEVDOCS_DATA,
    ids=[test_id for test_id, *_ in DEVDOCS_DATA],
)
def test_process_length_equals(
    test_id: str,
    devdocs: dict,
    input_dict: dict,
    expected_array_length: tuple[int, ...] | None,
    expected_dynamic_type_length: int | None,
) -> None:
    """
    Check the array and dynamic type lengths that `process_length_equals` derives for one ABI input.

    :param test_id: Identifier of the case, used only for the pytest id
    :param devdocs: Mapping from variable name to length bound, passed as the lengths argument
    :param input_dict: ABI input dictionary with 'name' and 'type' keys
    :param expected_array_length: Expected tuple of array bounds, or None for non-array types
    :param expected_dynamic_type_length: Expected `bytes`/`string` length bound, or None
    """
    # When
    array_lengths, dyn_len = process_length_equals(input_dict, devdocs)

    # Then
    assert array_lengths == expected_array_length
    assert dyn_len == expected_dynamic_type_length

0 comments on commit d91bc90

Please sign in to comment.