Skip to content

Commit

Permalink
benchmarks core: fix a few mypy type errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
tfogal committed Jan 28, 2025
1 parent 6a4f050 commit 1c1f1e6
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions benchmarks/python/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def with_executor(executor: str, fwd_fn: Callable, **kwargs) -> Callable:
)
if executor == "thunder-torchcompile":
return thunder.jit(fwd_fn, executors=["torchcompile"], **kwargs)
raise ValueError(f"Unknown executor '{executor}'")


def compute_total_iobytes(
Expand Down Expand Up @@ -164,6 +165,7 @@ def torchprofile_timer(self) -> float:

def fusionprofile_timer(self) -> float:
if not self.execution_start:
assert self.fd is not None
profile = self.fd.profile()
elapsed_host_time = profile.host_time_ms / 1e3
self._increment_global_time(elapsed_host_time)
Expand Down Expand Up @@ -210,7 +212,7 @@ def set_metrics(
self,
inputs: Union[torch.Tensor, List],
outputs: Union[torch.Tensor, List],
iobytes: int = None,
iobytes: int | None = None,
) -> None:
"""
Utility function to compute metrics for the target function.
Expand Down Expand Up @@ -256,9 +258,9 @@ def run_benchmark(
benchmark: pytest_benchmark.fixture.BenchmarkFixture,
benchmark_fn: Callable | None,
inputs: Union[torch.Tensor, List],
iobytes: int = None,
iobytes: int | None = None,
device: str = "cuda",
fusion_fn: Callable = None,
fusion_fn: Callable | None = None,
) -> Union[torch.Tensor, List]:
"""
Benchmarks the target function using torchprofiler and stores metrics as extra information.
Expand Down

0 comments on commit 1c1f1e6

Please sign in to comment.