Skip to content

Commit fa232de

Browse files
authored
Merge pull request #687 from mfinean/refactor/decompose-bench-class
Decompose Bench class into focused helper classes
2 parents 4d70f85 + 685cb99 commit fa232de

8 files changed

Lines changed: 1523 additions & 459 deletions

bencher/bencher.py

Lines changed: 66 additions & 458 deletions
Large diffs are not rendered by default.

bencher/result_collector.py

Lines changed: 359 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,359 @@
1+
"""Result collection and storage for benchmarking.
2+
3+
This module provides the ResultCollector class for managing benchmark results,
4+
including xarray dataset operations, caching, and metadata management.
5+
"""
6+
7+
import logging
8+
from datetime import datetime
9+
from itertools import product
10+
from typing import Any, List, Tuple
11+
12+
import numpy as np
13+
import xarray as xr
14+
from diskcache import Cache
15+
16+
from bencher.bench_cfg import BenchCfg, BenchRunCfg, DimsCfg
17+
from bencher.results.bench_result import BenchResult
18+
from bencher.variables.inputs import IntSweep
19+
from bencher.variables.time import TimeSnapshot, TimeEvent
20+
from bencher.variables.results import (
21+
XARRAY_MULTIDIM_RESULT_TYPES,
22+
ResultVar,
23+
ResultBool,
24+
ResultVec,
25+
ResultPath,
26+
ResultVideo,
27+
ResultImage,
28+
ResultString,
29+
ResultContainer,
30+
ResultReference,
31+
ResultDataSet,
32+
)
33+
from bencher.worker_job import WorkerJob
34+
from bencher.job import JobFuture
35+
36+
# Default cache size for benchmark results (100 GB)
37+
DEFAULT_CACHE_SIZE_BYTES = int(100e9)
38+
39+
logger = logging.getLogger(__name__)
40+
41+
42+
def set_xarray_multidim(
    data_array: xr.DataArray, index_tuple: Tuple[int, ...], value: Any
) -> xr.DataArray:
    """Assign *value* into *data_array* at the positional index *index_tuple*.

    Because the index is supplied as a single tuple, one assignment statement
    covers an array of any rank — callers never need per-dimension indexing
    code. The array is mutated in place; the same object is also returned for
    convenience.

    Args:
        data_array (xr.DataArray): Array to write into (modified in place)
        index_tuple (Tuple[int, ...]): One positional (integer) index per dimension
        value (Any): Value written at that position

    Returns:
        xr.DataArray: The same array object, after mutation
    """
    data_array[index_tuple] = value
    return data_array
60+
61+
62+
class ResultCollector:
    """Manages benchmark result collection, storage, and caching.

    This class handles the initialization of xarray datasets for storing benchmark
    results, storing results from worker jobs, managing caches, and adding metadata.

    Attributes:
        cache_size (int): Maximum size of the cache in bytes
        ds_dynamic (dict): Dictionary for storing unstructured vector datasets
    """

    def __init__(self, cache_size: int = DEFAULT_CACHE_SIZE_BYTES) -> None:
        """Initialize a new ResultCollector.

        Args:
            cache_size (int): Maximum cache size in bytes. Defaults to 100 GB.
        """
        self.cache_size = cache_size
        # Shared by reference with every BenchResult built in setup_dataset(),
        # so all results created by this collector see the same dynamic store.
        self.ds_dynamic: dict = {}

    def setup_dataset(
        self, bench_cfg: BenchCfg, time_src: datetime | str
    ) -> Tuple[BenchResult, List[Tuple], List[str]]:
        """Initialize an n-dimensional xarray dataset from benchmark configuration parameters.

        This function creates the data structures needed to store benchmark results based on
        the provided configuration. It sets up the xarray dimensions, coordinates, and variables
        based on input variables and result variables.

        Args:
            bench_cfg (BenchCfg): Configuration defining the benchmark parameters, inputs, and
                results
            time_src (datetime | str): Timestamp or event name for the benchmark run
                (None is also accepted and is replaced with the current time)

        Returns:
            Tuple[BenchResult, List[Tuple], List[str]]:
                - A BenchResult object with the initialized dataset
                - A list of function input tuples (index, value pairs)
                - A list of dimension names for the dataset
        """
        # None is tolerated despite the annotation: default to "now".
        if time_src is None:
            time_src = datetime.now()
        bench_cfg.meta_vars = self.define_extra_vars(bench_cfg, bench_cfg.repeats, time_src)

        # Worker inputs plus meta vars (repeat / over_time) together define the
        # full set of dataset dimensions.
        bench_cfg.all_vars = bench_cfg.input_vars + bench_cfg.meta_vars

        for i in bench_cfg.all_vars:
            logger.info(i.sampling_str())

        dims_cfg = DimsCfg(bench_cfg)
        # Pair each N-dim index tuple with its corresponding value tuple; this
        # is the full cartesian product of every sweep dimension.
        function_inputs = list(
            zip(product(*dims_cfg.dim_ranges_index), product(*dims_cfg.dim_ranges))
        )
        # xarray stores K N-dimensional arrays of data.
        # Each array is named and in this case we have an ND array for each result variable
        data_vars = {}
        dataset_list = []

        for rv in bench_cfg.result_vars:
            # Scalar results: float array initialized to NaN ("not yet run").
            if isinstance(rv, (ResultVar, ResultBool)):
                result_data = np.full(dims_cfg.dims_size, np.nan, dtype=float)
                data_vars[rv.name] = (dims_cfg.dims_name, result_data)
            # Reference/dataset results store an integer index into a side
            # list; -1 marks "no entry".
            if isinstance(rv, (ResultReference, ResultDataSet)):
                result_data = np.full(dims_cfg.dims_size, -1, dtype=int)
                data_vars[rv.name] = (dims_cfg.dims_name, result_data)
            # Media/string results: object array with the "NAN" string sentinel.
            if isinstance(
                rv, (ResultPath, ResultVideo, ResultImage, ResultString, ResultContainer)
            ):
                result_data = np.full(dims_cfg.dims_size, "NAN", dtype=object)
                data_vars[rv.name] = (dims_cfg.dims_name, result_data)

            # NOTE(review): this `elif` pairs with the media-type `if` directly
            # above it, not with the earlier checks — the ResultVec branch is
            # only reached when the media check fails. Works as long as no type
            # matches both, but the chain is fragile; consider if/elif throughout.
            elif type(rv) is ResultVec:
                # Vectors are flattened into one float array per component.
                for i in range(rv.size):
                    result_data = np.full(dims_cfg.dims_size, np.nan)
                    data_vars[rv.index_name(i)] = (dims_cfg.dims_name, result_data)

        bench_res = BenchResult(bench_cfg)
        bench_res.ds = xr.Dataset(data_vars=data_vars, coords=dims_cfg.coords)
        # Shared (not copied) so later dynamic datasets are visible everywhere.
        bench_res.ds_dynamic = self.ds_dynamic
        bench_res.dataset_list = dataset_list
        bench_res.setup_object_index()

        return bench_res, function_inputs, dims_cfg.dims_name

    def define_extra_vars(
        self, bench_cfg: BenchCfg, repeats: int, time_src: datetime | str
    ) -> List[IntSweep]:
        """Define extra meta variables for tracking benchmark execution details.

        This function creates variables that aren't passed to the worker function but are stored
        in the n-dimensional array to provide context about the benchmark, such as the number of
        repeat measurements and timestamps.

        Args:
            bench_cfg (BenchCfg): The benchmark configuration to add variables to
            repeats (int): The number of times each sample point should be measured
            time_src (datetime | str): Either a timestamp or a string event name for temporal
                tracking

        Returns:
            List[IntSweep]: A list of additional parameter variables to include in the benchmark
        """
        # The repeat axis is always present, even for a single repeat.
        bench_cfg.iv_repeat = IntSweep(
            default=repeats,
            bounds=[1, repeats],
            samples=repeats,
            units="repeats",
            doc="The number of times a sample was measured",
        )
        bench_cfg.iv_repeat.name = "repeat"
        extra_vars = [bench_cfg.iv_repeat]

        if bench_cfg.over_time:
            # A string time_src names a discrete event; a datetime takes a
            # point-in-time snapshot.
            if isinstance(time_src, str):
                iv_over_time = TimeEvent(time_src)
            else:
                iv_over_time = TimeSnapshot(time_src)
            iv_over_time.name = "over_time"
            extra_vars.append(iv_over_time)
            bench_cfg.iv_time = [iv_over_time]
        return extra_vars

    def store_results(
        self,
        job_result: JobFuture,
        bench_res: BenchResult,
        worker_job: WorkerJob,
        bench_run_cfg: BenchRunCfg,
    ) -> None:
        """Store the results from a benchmark worker job into the benchmark result dataset.

        This method handles unpacking the results from worker jobs and placing them
        in the correct locations in the n-dimensional result dataset. It supports different
        types of result variables including scalars, vectors, references, and media.

        Args:
            job_result (JobFuture): The future containing the worker function result
            bench_res (BenchResult): The benchmark result object to store results in
            worker_job (WorkerJob): The job metadata needed to index the result
            bench_run_cfg (BenchRunCfg): Configuration for how results should be handled

        Raises:
            RuntimeError: If an unsupported result variable type is encountered
        """
        # May block until the worker finishes; None results are skipped entirely.
        result = job_result.result()
        if result is not None:
            logger.info(f"{job_result.job.job_id}:")
            if bench_res.bench_cfg.print_bench_inputs:
                for k, v in worker_job.function_input.items():
                    logger.info(f"\t {k}:{v}")

            # Workers may return either a plain dict or a parametrized object;
            # param.values() normalizes the latter to a name->value mapping.
            result_dict = result if isinstance(result, dict) else result.param.values()

            for rv in bench_res.bench_cfg.result_vars:
                result_value = result_dict[rv.name]
                if bench_run_cfg.print_bench_results:
                    logger.info(f"{rv.name}: {result_value}")

                if isinstance(rv, XARRAY_MULTIDIM_RESULT_TYPES):
                    # Scalar-like results go straight into the ND array.
                    set_xarray_multidim(bench_res.ds[rv.name], worker_job.index_tuple, result_value)
                elif isinstance(rv, ResultDataSet):
                    # Store the object in a side list; the array holds its index.
                    bench_res.dataset_list.append(result_value)
                    set_xarray_multidim(
                        bench_res.ds[rv.name],
                        worker_job.index_tuple,
                        len(bench_res.dataset_list) - 1,
                    )
                elif isinstance(rv, ResultReference):
                    # Same indirection scheme, but via the object index list.
                    bench_res.object_index.append(result_value)
                    set_xarray_multidim(
                        bench_res.ds[rv.name],
                        worker_job.index_tuple,
                        len(bench_res.object_index) - 1,
                    )

                elif isinstance(rv, ResultVec):
                    # NOTE(review): a vector whose length != rv.size, or that is
                    # not a list/ndarray, is silently skipped (cells stay NaN) —
                    # confirm this is intentional rather than raising.
                    if isinstance(result_value, (list, np.ndarray)):
                        if len(result_value) == rv.size:
                            for i in range(rv.size):
                                set_xarray_multidim(
                                    bench_res.ds[rv.index_name(i)],
                                    worker_job.index_tuple,
                                    result_value[i],
                                )

                else:
                    raise RuntimeError("Unsupported result type")
            # Holomap results are keyed by the job's canonical input tuple.
            for rv in bench_res.result_hmaps:
                bench_res.hmaps[rv.name][worker_job.canonical_input] = result_dict[rv.name]

    def cache_results(
        self, bench_res: BenchResult, bench_cfg_hash: str, bench_cfg_hashes: List[str]
    ) -> None:
        """Cache benchmark results for future retrieval.

        This method stores benchmark results in the disk cache using the benchmark
        configuration hash as the key. It temporarily removes non-pickleable objects
        from the benchmark result before caching.

        Args:
            bench_res (BenchResult): The benchmark result to cache
            bench_cfg_hash (str): The hash value to use as the cache key
            bench_cfg_hashes (List[str]): List to append the hash to (modified in place)
        """
        with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c:
            logger.info(f"saving results with key: {bench_cfg_hash}")
            bench_cfg_hashes.append(bench_cfg_hash)
            # object index may not be pickleable so remove before caching
            obj_index_tmp = bench_res.object_index
            bench_res.object_index = []

            c[bench_cfg_hash] = bench_res

            # restore object index
            bench_res.object_index = obj_index_tmp

            # Also persist the per-benchmark list of hashes so every run of this
            # named benchmark can be found again under one key.
            logger.info(f"saving benchmark: {bench_res.bench_cfg.bench_name}")
            c[bench_res.bench_cfg.bench_name] = bench_cfg_hashes

    def load_history_cache(
        self, dataset: xr.Dataset, bench_cfg_hash: str, clear_history: bool
    ) -> xr.Dataset:
        """Load historical data from a cache if over_time is enabled.

        This method is used to retrieve and concatenate historical benchmark data from the cache
        when tracking performance over time. If clear_history is True, it will clear any existing
        historical data instead of loading it.

        Args:
            dataset (xr.Dataset): Freshly calculated benchmark data for the current run
            bench_cfg_hash (str): Hash of the input variables used to identify cached data
            clear_history (bool): If True, clears historical data instead of loading it

        Returns:
            xr.Dataset: Combined dataset with both historical and current benchmark data,
                or just the current data if no history exists or history is cleared
        """
        with Cache("cachedir/history", size_limit=self.cache_size) as c:
            if clear_history:
                # "Clearing" means skipping the load; the overwrite below then
                # replaces whatever history was stored under this hash.
                logger.info("clearing history")
            else:
                logger.info(f"checking historical key: {bench_cfg_hash}")
                if bench_cfg_hash in c:
                    logger.info("loading historical data from cache")
                    ds_old = c[bench_cfg_hash]
                    # Old data first so the over_time axis stays chronological.
                    dataset = xr.concat([ds_old, dataset], "over_time")
                else:
                    logger.info("did not detect any historical data")

            logger.info("saving data to history cache")
            c[bench_cfg_hash] = dataset
        return dataset

    def add_metadata_to_dataset(self, bench_res: BenchResult, input_var: Any) -> None:
        """Add variable metadata to the xarray dataset for improved visualization.

        This method adds metadata like units, long names, and descriptions to the xarray dataset
        attributes, which helps visualization tools properly label axes and tooltips.

        Args:
            bench_res (BenchResult): The benchmark result object containing the dataset to display
            input_var: The variable to extract metadata from
        """
        for rv in bench_res.bench_cfg.result_vars:
            if type(rv) is ResultVar:
                bench_res.ds[rv.name].attrs["units"] = rv.units
                bench_res.ds[rv.name].attrs["long_name"] = rv.name
            elif type(rv) is ResultVec:
                # Each vector component array gets the same units/name metadata.
                for i in range(rv.size):
                    bench_res.ds[rv.index_name(i)].attrs["units"] = rv.units
                    bench_res.ds[rv.index_name(i)].attrs["long_name"] = rv.name
            else:
                pass  # todo

        # Annotate the input dimension itself; docstrings become descriptions.
        dsvar = bench_res.ds[input_var.name]
        dsvar.attrs["long_name"] = input_var.name
        if input_var.units is not None:
            dsvar.attrs["units"] = input_var.units
        if input_var.__doc__ is not None:
            dsvar.attrs["description"] = input_var.__doc__

    def report_results(
        self, bench_res: BenchResult, print_xarray: bool, print_pandas: bool
    ) -> None:
        """Display the calculated benchmark data in various formats.

        This method provides options to display the benchmark results as xarray data structures
        or pandas DataFrames for debugging and inspection.

        Args:
            bench_res (BenchResult): The benchmark result containing the dataset to display
            print_xarray (bool): If True, log the raw xarray Dataset structure
            print_pandas (bool): If True, log the dataset converted to a pandas DataFrame
        """
        if print_xarray:
            logger.info(bench_res.ds)
        if print_pandas:
            logger.info(bench_res.ds.to_dataframe())

0 commit comments

Comments
 (0)