Skip to content

Commit 15f8bb0

Browse files
committed
refactor: Use generics
1 parent a8e643c commit 15f8bb0

File tree

6 files changed

+68
-52
lines changed

6 files changed

+68
-52
lines changed

packages/testing/src/execution_testing/client_clis/cli_types.py

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
from dataclasses import dataclass
55
from pathlib import Path
6-
from typing import Annotated, Any, Dict, List, Self
6+
from typing import Annotated, Any, Dict, Generic, List, Self, TypeVar
77

88
from pydantic import Field, PlainSerializer, PlainValidator
99

@@ -287,43 +287,53 @@ class Result(CamelModel):
287287
opcode_count: OpcodeCount | None = None
288288

289289

290+
JSONDict = Dict[str, Any]
291+
292+
TRaw = TypeVar("TRaw", str, JSONDict)
290293

291294

292295
@dataclass(kw_only=True)
293-
class LazyAlloc:
296+
class LazyAlloc(Generic[TRaw]):
294297
"""
295298
Allocation that is lazily loaded from a JSON file.
296299
"""
297300

298-
json_contents: Dict | None = None
299-
str_contents: str | None = None
301+
raw: TRaw
300302
alloc: Alloc | None = None
301303

302-
def str_contents_from_cache(self) -> str:
303-
"""
304-
Return the cached, unparsed, alloc from the last t8n execution.
305-
"""
306-
assert self.str_contents is not None, "No string cache found"
307-
return self.str_contents
308-
309-
def json_contents_from_cache(self) -> Dict:
310-
"""
311-
Return the cached, unparsed, json dict from the last t8n execution.
312-
"""
313-
assert self.json_contents is not None, "No json cache found"
314-
return self.json_contents
304+
def validate(self) -> Alloc:
305+
"""Validate the alloc."""
306+
raise NotImplementedError("validate method not implemented.")
315307

316-
def get_alloc(self) -> Alloc:
308+
def get(self) -> Alloc:
317309
"""Model validate the allocation and return it."""
318-
if self.alloc is not None:
319-
return self.alloc
320-
elif self.json_contents is not None:
321-
self.alloc = Alloc.model_validate(self.json_contents)
322-
return self.alloc
323-
elif self.str_contents is not None:
324-
self.alloc = Alloc.model_validate_json(self.str_contents)
325-
return self.alloc
326-
raise ValueError("No alloc found")
310+
if self.alloc is None:
311+
self.alloc = self.validate()
312+
return self.alloc
313+
314+
315+
class LazyAllocJson(LazyAlloc[JSONDict]):
316+
"""
317+
Lazy allocation backed by a JSON dict cache.
318+
319+
Uses Alloc.model_validate on the dict.
320+
"""
321+
322+
def validate(self) -> Alloc:
323+
"""Validate the alloc."""
324+
return Alloc.model_validate(self.raw)
325+
326+
327+
class LazyAllocStr(LazyAlloc[str]):
328+
"""
329+
Lazy allocation backed by a JSON dict cache.
330+
331+
Uses Alloc.model_validate on the dict.
332+
"""
333+
334+
def validate(self) -> Alloc:
335+
"""Validate the alloc."""
336+
return Alloc.model_validate_json(self.raw)
327337

328338

329339
@dataclass
@@ -342,10 +352,12 @@ def to_files(
342352
Prepare the input in a directory path in the file system for
343353
consumption by the t8n tool.
344354
"""
345-
if isinstance(self.alloc, LazyAlloc):
346-
alloc_contents = self.alloc.str_contents_from_cache()
347-
else:
355+
if isinstance(self.alloc, Alloc):
348356
alloc_contents = self.alloc.model_dump_json(**model_dump_config)
357+
elif isinstance(self.alloc, LazyAllocStr):
358+
alloc_contents = self.alloc.raw
359+
else:
360+
raise Exception(f"Invalid alloc type: {type(self.alloc)}")
349361

350362
env_contents = self.env.model_dump_json(**model_dump_config)
351363
txs_contents = (
@@ -375,10 +387,12 @@ def to_files(
375387

376388
def model_dump_json(self, **model_dump_config: Any) -> str:
377389
"""Dump the model in string JSON format."""
378-
if isinstance(self.alloc, LazyAlloc):
379-
alloc_contents = self.alloc.str_contents_from_cache()
380-
else:
390+
if isinstance(self.alloc, Alloc):
381391
alloc_contents = self.alloc.model_dump_json(**model_dump_config)
392+
elif isinstance(self.alloc, LazyAllocStr):
393+
alloc_contents = self.alloc.raw
394+
else:
395+
raise Exception(f"Invalid alloc type: {type(self.alloc)}")
382396

383397
env_contents = self.env.model_dump_json(**model_dump_config)
384398
txs_contents = (
@@ -405,12 +419,14 @@ def model_dump_json(self, **model_dump_config: Any) -> str:
405419
def model_dump(self, mode: str, **model_dump_config: Any) -> Any:
406420
"""Return the validated model."""
407421
assert mode == "json", f"Mode {mode} not supported."
408-
if isinstance(self.alloc, LazyAlloc):
409-
alloc_contents = self.alloc.json_contents_from_cache()
410-
else:
422+
if isinstance(self.alloc, Alloc):
411423
alloc_contents = self.alloc.model_dump(
412424
mode=mode, **model_dump_config
413425
)
426+
elif isinstance(self.alloc, LazyAllocJson):
427+
alloc_contents = self.alloc.raw
428+
else:
429+
raise Exception(f"Invalid alloc type: {type(self.alloc)}")
414430

415431
env_contents = self.env.model_dump(mode=mode, **model_dump_config)
416432
txs_contents = [
@@ -450,7 +466,7 @@ def model_validate_files(
450466
result = Result.model_validate_json(
451467
json_data=result_data, context=context
452468
)
453-
alloc = LazyAlloc(str_contents=alloc_data)
469+
alloc = LazyAllocStr(raw=alloc_data)
454470
output = cls(result=result, alloc=alloc)
455471
return output
456472

@@ -465,7 +481,7 @@ def model_validate(
465481
result = Result.model_validate(
466482
obj=response_json["result"], context=context
467483
)
468-
alloc = LazyAlloc(json_contents=response_json["alloc"])
484+
alloc = LazyAllocJson(raw=response_json["alloc"])
469485
output = cls(result=result, alloc=alloc)
470486
return output
471487

@@ -482,7 +498,7 @@ def model_validate_json(
482498
result = Result.model_validate(
483499
obj=parsed_json["result"], context=context
484500
)
485-
alloc = LazyAlloc(str_contents=json.dumps(parsed_json["alloc"]))
501+
alloc = LazyAllocStr(raw=json.dumps(parsed_json["alloc"]))
486502
output = cls(result=result, alloc=alloc)
487503
return output
488504

packages/testing/src/execution_testing/client_clis/clis/besu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def evaluate(
191191
dump_files_to_directory(
192192
debug_output_path,
193193
{
194-
"output/alloc.json": output.alloc.json_contents_from_cache(),
194+
"output/alloc.json": output.alloc.raw,
195195
"output/result.json": output.result.model_dump(
196196
mode="json", **model_dump_config
197197
),

packages/testing/src/execution_testing/client_clis/tests/test_execution_specs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def test_evm_t8n(
182182
blob_schedule=Berlin.blob_schedule(),
183183
),
184184
)
185-
assert to_json(t8n_output.alloc.get_alloc()) == expected.get("alloc")
185+
assert to_json(t8n_output.alloc.get()) == expected.get("alloc")
186186
if isinstance(default_t8n, ExecutionSpecsTransitionTool):
187187
# The expected output was generated with geth, instead of deleting
188188
# any info from this expected output, the fields not returned by

packages/testing/src/execution_testing/client_clis/transition_tool.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ def _evaluate_server(
526526
dump_files_to_directory(
527527
debug_output_path,
528528
{
529-
"input/alloc.json": request_data.input.alloc.json_contents_from_cache()
529+
"input/alloc.json": request_data.input.alloc.raw
530530
if isinstance(request_data.input.alloc, LazyAlloc)
531531
else request_data.input.alloc.model_dump(
532532
mode="json", **model_dump_config
@@ -571,7 +571,7 @@ def _evaluate_server(
571571
dump_files_to_directory(
572572
debug_output_path,
573573
{
574-
"output/alloc.json": output.alloc.json_contents_from_cache(),
574+
"output/alloc.json": output.alloc.raw,
575575
"output/result.json": output.result,
576576
"output/txs.rlp": str(output.body),
577577
"response_info.txt": response_info,
@@ -626,7 +626,7 @@ def _evaluate_stream(
626626
dump_files_to_directory(
627627
debug_output_path,
628628
{
629-
"output/alloc.json": output.alloc.str_contents_from_cache(),
629+
"output/alloc.json": output.alloc.raw,
630630
"output/result.json": output.result,
631631
"output/txs.rlp": str(output.body),
632632
},

packages/testing/src/execution_testing/specs/blockchain.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def generate_block_data(
741741
print_traces(t8n.get_traces())
742742
pprint(transition_tool_output.result)
743743
pprint(previous_alloc)
744-
pprint(transition_tool_output.alloc.get_alloc())
744+
pprint(transition_tool_output.alloc.get())
745745
raise e
746746

747747
if len(rejected_txs) > 0 and block.exception is None:
@@ -814,13 +814,13 @@ def make_fixture(
814814
if block.expected_post_state:
815815
self.verify_post_state(
816816
t8n,
817-
t8n_state=alloc.get_alloc()
817+
t8n_state=alloc.get()
818818
if isinstance(alloc, LazyAlloc)
819819
else alloc,
820820
expected_state=block.expected_post_state,
821821
)
822822
self.check_exception_test(exception=invalid_blocks > 0)
823-
alloc = alloc.get_alloc() if isinstance(alloc, LazyAlloc) else alloc
823+
alloc = alloc.get() if isinstance(alloc, LazyAlloc) else alloc
824824
self.verify_post_state(t8n, t8n_state=alloc)
825825
info = {}
826826
if self._opcode_count is not None:
@@ -891,7 +891,7 @@ def make_hive_fixture(
891891
if block.expected_post_state:
892892
self.verify_post_state(
893893
t8n,
894-
t8n_state=alloc.get_alloc()
894+
t8n_state=alloc.get()
895895
if isinstance(alloc, LazyAlloc)
896896
else alloc,
897897
expected_state=block.expected_post_state,
@@ -906,7 +906,7 @@ def make_hive_fixture(
906906
" The framework should never try to execute this test case."
907907
)
908908

909-
alloc = alloc.get_alloc() if isinstance(alloc, LazyAlloc) else alloc
909+
alloc = alloc.get() if isinstance(alloc, LazyAlloc) else alloc
910910
self.verify_post_state(t8n, t8n_state=alloc)
911911

912912
# Create base fixture data, common to all fixture formats

packages/testing/src/execution_testing/specs/state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def verify_modified_gas_limit(
167167
f"Traces are not equivalent (gas_limit={current_gas_limit})"
168168
)
169169
return False
170-
modified_tool_alloc = modified_tool_output.alloc.get_alloc()
170+
modified_tool_alloc = modified_tool_output.alloc.get()
171171
try:
172172
self.post.verify_post_alloc(modified_tool_alloc)
173173
except Exception as e:
@@ -363,7 +363,7 @@ def make_state_test_fixture(
363363
debug_output_path=self.get_next_transition_tool_output_path(),
364364
slow_request=self.is_tx_gas_heavy_test(),
365365
)
366-
output_alloc = transition_tool_output.alloc.get_alloc()
366+
output_alloc = transition_tool_output.alloc.get()
367367

368368
try:
369369
self.post.verify_post_alloc(output_alloc)
@@ -391,7 +391,7 @@ def make_state_test_fixture(
391391
self._operation_mode == OpMode.OPTIMIZE_GAS_POST_PROCESSING
392392
)
393393
base_tool_output = transition_tool_output
394-
base_tool_alloc = base_tool_output.alloc.get_alloc()
394+
base_tool_alloc = base_tool_output.alloc.get()
395395
base_tool_result = base_tool_output.result
396396

397397
assert base_tool_result.traces is not None, "Traces not found."

0 commit comments

Comments
 (0)