diff --git a/bencher/bencher.py b/bencher/bencher.py index cf4b0a8f..e17c5bf4 100644 --- a/bencher/bencher.py +++ b/bencher/bencher.py @@ -5,7 +5,6 @@ from param import Parameter from typing import Callable, List, Optional, Tuple, Any from copy import deepcopy -import numpy as np import param import xarray as xr from diskcache import Cache @@ -15,32 +14,23 @@ from bencher.worker_job import WorkerJob -from bencher.bench_cfg import BenchCfg, BenchRunCfg, DimsCfg +from bencher.bench_cfg import BenchCfg, BenchRunCfg from bencher.bench_plot_server import BenchPlotServer from bencher.bench_report import BenchReport from bencher.variables.inputs import IntSweep -from bencher.variables.time import TimeSnapshot, TimeEvent -from bencher.variables.results import ( - XARRAY_MULTIDIM_RESULT_TYPES, - ResultVar, - ResultBool, - ResultVec, - ResultHmap, - ResultPath, - ResultVideo, - ResultImage, - ResultString, - ResultContainer, - ResultReference, - ResultDataSet, -) +from bencher.variables.results import ResultHmap from bencher.results.bench_result import BenchResult from bencher.variables.parametrised_sweep import ParametrizedSweep from bencher.job import Job, FutureCache, JobFuture, Executors from bencher.utils import params_to_str from bencher.sample_order import SampleOrder +# Import helper classes +from bencher.worker_manager import WorkerManager +from bencher.result_collector import ResultCollector +from bencher.sweep_executor import SweepExecutor, worker_kwargs_wrapper + # Default cache size for benchmark results (100 GB) DEFAULT_CACHE_SIZE_BYTES = int(100e9) @@ -53,83 +43,6 @@ handler.setFormatter(formatter) -def set_xarray_multidim( - data_array: xr.DataArray, index_tuple: Tuple[int, ...], value: Any -) -> xr.DataArray: - """Set a value in a multi-dimensional xarray at the specified index position. - - This function sets a value in an N-dimensional xarray using dynamic indexing - that works for any number of dimensions. - - Args: - data_array (xr.DataArray): The data array to modify - index_tuple (Tuple[int, ...]): The index coordinates as a tuple - value (Any): The value to set at the specified position - - Returns: - xr.DataArray: The modified data array - """ - data_array[index_tuple] = value - return data_array - - -def kwargs_to_input_cfg(worker_input_cfg: ParametrizedSweep, **kwargs) -> ParametrizedSweep: - """Create a configured instance of a ParametrizedSweep with the provided keyword arguments. - - Args: - worker_input_cfg (ParametrizedSweep): The ParametrizedSweep class to instantiate - **kwargs: Keyword arguments to update the configuration with - - Returns: - ParametrizedSweep: A configured instance of the worker_input_cfg class - """ - input_cfg = worker_input_cfg() - input_cfg.param.update(kwargs) - return input_cfg - - -def worker_cfg_wrapper(worker: Callable, worker_input_cfg: ParametrizedSweep, **kwargs) -> dict: - """Wrap a worker function to accept keyword arguments instead of a config object. - - This wrapper creates an instance of the worker_input_cfg class, updates it with the - provided keyword arguments, and passes it to the worker function. 
- - Args: - worker (Callable): The worker function that expects a config object - worker_input_cfg (ParametrizedSweep): The class defining the configuration - **kwargs: Keyword arguments to update the configuration with - - Returns: - dict: The result of calling the worker function with the configured input - """ - input_cfg = kwargs_to_input_cfg(worker_input_cfg, **kwargs) - return worker(input_cfg) - - -def worker_kwargs_wrapper(worker: Callable, bench_cfg: BenchCfg, **kwargs) -> dict: - """Prepare keyword arguments and pass them to a worker function. - - This wrapper helps filter out metadata parameters that should not be passed - to the worker function (like 'repeat', 'over_time', and 'time_event'). - - Args: - worker (Callable): The worker function to call - bench_cfg (BenchCfg): Benchmark configuration with parameters like pass_repeat - **kwargs: The keyword arguments to filter and pass to the worker - - Returns: - dict: The result from the worker function - """ - function_input_deep = deepcopy(kwargs) - if not bench_cfg.pass_repeat: - function_input_deep.pop("repeat") - if "over_time" in function_input_deep: - function_input_deep.pop("over_time") - if "time_event" in function_input_deep: - function_input_deep.pop("time_event") - return worker(**function_input_deep) - - class Bench(BenchPlotServer): def __init__( self, @@ -166,10 +79,16 @@ def __init__( if not isinstance(bench_name, str): raise TypeError(f"bench_name must be a string, got {type(bench_name).__name__}") self.bench_name = bench_name - self.worker = None - self.worker_class_instance = None - self.worker_input_cfg = None + + # Initialize helper classes + self.cache_size = DEFAULT_CACHE_SIZE_BYTES + self._worker_mgr = WorkerManager() + self._executor = SweepExecutor(cache_size=self.cache_size) + self._collector = ResultCollector(cache_size=self.cache_size) + + # Set worker using the manager self.set_worker(worker, worker_input_cfg) + self.run_cfg = run_cfg if report is None: self.report = BenchReport(self.bench_name) @@ -181,12 +100,6 @@ def __init__( self.bench_cfg_hashes = [] # a list of hashes that point to benchmark results self.last_run_cfg = None # cached run_cfg used to pass to the plotting function - self.sample_cache = None # store the results of each benchmark function call in a cache - self.ds_dynamic = {} # A dictionary to store unstructured vector datasets - - self.cache_size = DEFAULT_CACHE_SIZE_BYTES - - # self.bench_cfg = BenchCfg() # Maybe put this in SweepCfg self.input_vars = None @@ -195,6 +108,26 @@ def __init__( self.plot_callbacks = [] self.plot = True + @property + def sample_cache(self): + """Access the sample cache from the executor (for backward compatibility).""" + return self._executor.sample_cache + + @sample_cache.setter + def sample_cache(self, value): + """Set the sample cache on the executor (for backward compatibility).""" + self._executor.sample_cache = value + + @property + def ds_dynamic(self): + """Access the dynamic dataset from the collector (for backward compatibility).""" + return self._collector.ds_dynamic + + @ds_dynamic.setter + def ds_dynamic(self, value): + """Set the dynamic dataset on the collector (for backward compatibility).""" + self._collector.ds_dynamic = value + def add_plot_callback(self, callback: Callable[[BenchResult], pn.panel], **kwargs) -> None: """Add a plotting callback to be called on benchmark results. @@ -232,20 +165,11 @@ def set_worker( Raises: RuntimeError: If worker is a class type instead of an instance. 
""" - if isinstance(worker, ParametrizedSweep): - self.worker_class_instance = worker - # self.worker_class_type = type(worker) - self.worker = self.worker_class_instance.__call__ - logging.info("setting worker from bench class.__call__") - else: - if isinstance(worker, type): - raise RuntimeError("This should be a class instance, not a class") - if worker_input_cfg is None: - self.worker = worker - else: - self.worker = partial(worker_cfg_wrapper, worker, worker_input_cfg) - logging.info(f"setting worker {worker}") - self.worker_input_cfg = worker_input_cfg + self._worker_mgr.set_worker(worker, worker_input_cfg) + # Expose worker attributes for backward compatibility + self.worker = self._worker_mgr.worker + self.worker_class_instance = self._worker_mgr.worker_class_instance + self.worker_input_cfg = self._worker_mgr.worker_input_cfg def sweep_sequential( self, @@ -598,78 +522,21 @@ def run_sweep( self.results.append(bench_res) return bench_res + # TODO: Remove thin wrapper methods in major version bump - callers can use helpers directly def convert_vars_to_params( self, variable: param.Parameter | str | dict | tuple, var_type: str, run_cfg: Optional[BenchRunCfg], ) -> param.Parameter: - """Convert various input formats to param.Parameter objects. - - This method handles different ways of specifying variables in benchmark sweeps, - including direct param.Parameter objects, string names of parameters, or dictionaries - with parameter configuration details. It ensures all inputs are properly converted - to param.Parameter objects with the correct configuration. - - Args: - variable (param.Parameter | str | dict | tuple): The variable to convert, can be: - - param.Parameter: Already a parameter object - - str: Name of a parameter in the worker_class_instance - - dict: Configuration with 'name' and optional 'values', 'samples', 'max_level' - - tuple: Tuple that can be converted to a parameter - var_type (str): Type of variable ('input', 'result', or 'const') for error messages - run_cfg (Optional[BenchRunCfg]): Run configuration for level settings - - Returns: - param.Parameter: The converted parameter object - - Raises: - TypeError: If the variable cannot be converted to a param.Parameter - """ - if isinstance(variable, str): - variable = self.worker_class_instance.param.objects(instance=False)[variable] - if isinstance(variable, dict): - param_var = self.worker_class_instance.param.objects(instance=False)[variable["name"]] - if variable.get("values"): - param_var = param_var.with_sample_values(variable["values"]) - - if variable.get("samples"): - param_var = param_var.with_samples(variable["samples"]) - if variable.get("max_level"): - if run_cfg is not None: - param_var = param_var.with_level(run_cfg.level, variable["max_level"]) - variable = param_var - if not isinstance(variable, param.Parameter): - raise TypeError( - f"You need to use {var_type}_vars =[{self.worker_input_cfg}.param.your_variable], instead of {var_type}_vars =[{self.worker_input_cfg}.your_variable]" - ) - return variable + """Convert various input formats (str, dict, tuple) to param.Parameter objects.""" + return self._executor.convert_vars_to_params( + variable, var_type, run_cfg, self.worker_class_instance, self.worker_input_cfg + ) def cache_results(self, bench_res: BenchResult, bench_cfg_hash: str) -> None: - """Cache benchmark results for future retrieval. - - This method stores benchmark results in the disk cache using the benchmark - configuration hash as the key. 
It temporarily removes non-pickleable objects - from the benchmark result before caching. - - Args: - bench_res (BenchResult): The benchmark result to cache - bench_cfg_hash (str): The hash value to use as the cache key - """ - with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c: - logging.info(f"saving results with key: {bench_cfg_hash}") - self.bench_cfg_hashes.append(bench_cfg_hash) - # object index may not be pickleable so remove before caching - obj_index_tmp = bench_res.object_index - bench_res.object_index = [] - - c[bench_cfg_hash] = bench_res - - # restore object index - bench_res.object_index = obj_index_tmp - - logging.info(f"saving benchmark: {self.bench_name}") - c[self.bench_name] = self.bench_cfg_hashes + """Cache benchmark results to disk using the config hash as key.""" + self._collector.cache_results(bench_res, bench_cfg_hash, self.bench_cfg_hashes) # def show(self, run_cfg: BenchRunCfg = None, pane: pn.panel = None) -> None: # """Launch a web server with plots of the benchmark results. @@ -698,159 +565,24 @@ def cache_results(self, bench_res: BenchResult, bench_cfg_hash: str) -> None: def load_history_cache( self, dataset: xr.Dataset, bench_cfg_hash: str, clear_history: bool ) -> xr.Dataset: - """Load historical data from a cache if over_time is enabled. - - This method is used to retrieve and concatenate historical benchmark data from the cache - when tracking performance over time. If clear_history is True, it will clear any existing - historical data instead of loading it. - - Args: - dataset (xr.Dataset): Freshly calculated benchmark data for the current run - bench_cfg_hash (str): Hash of the input variables used to identify cached data - clear_history (bool): If True, clears historical data instead of loading it - - Returns: - xr.Dataset: Combined dataset with both historical and current benchmark data, - or just the current data if no history exists or history is cleared - """ - with Cache("cachedir/history", size_limit=self.cache_size) as c: - if clear_history: - logging.info("clearing history") - else: - logging.info(f"checking historical key: {bench_cfg_hash}") - if bench_cfg_hash in c: - logging.info("loading historical data from cache") - ds_old = c[bench_cfg_hash] - dataset = xr.concat([ds_old, dataset], "over_time") - else: - logging.info("did not detect any historical data") - - logging.info("saving data to history cache") - c[bench_cfg_hash] = dataset - return dataset + """Load and concatenate historical benchmark data from cache.""" + return self._collector.load_history_cache(dataset, bench_cfg_hash, clear_history) def setup_dataset( self, bench_cfg: BenchCfg, time_src: datetime | str ) -> tuple[BenchResult, List[tuple], List[str]]: - """Initialize an n-dimensional xarray dataset from benchmark configuration parameters. - - This function creates the data structures needed to store benchmark results based on - the provided configuration. It sets up the xarray dimensions, coordinates, and variables - based on input variables and result variables. 
- - Args: - bench_cfg (BenchCfg): Configuration defining the benchmark parameters, inputs, and results - time_src (datetime | str): Timestamp or event name for the benchmark run - - Returns: - tuple[BenchResult, List[tuple], List[str]]: - - A BenchResult object with the initialized dataset - - A list of function input tuples (index, value pairs) - - A list of dimension names for the dataset - """ - if time_src is None: - time_src = datetime.now() - bench_cfg.meta_vars = self.define_extra_vars(bench_cfg, bench_cfg.repeats, time_src) - - bench_cfg.all_vars = bench_cfg.input_vars + bench_cfg.meta_vars - # bench_cfg.all_vars = bench_cfg.iv_time + bench_cfg.input_vars +[ bench_cfg.iv_repeat] - # bench_cfg.all_vars = [ bench_cfg.iv_repeat] +bench_cfg.input_vars + bench_cfg.iv_time - - for i in bench_cfg.all_vars: - logging.info(i.sampling_str()) - - dims_cfg = DimsCfg(bench_cfg) - function_inputs = list( - zip(product(*dims_cfg.dim_ranges_index), product(*dims_cfg.dim_ranges)) - ) - # xarray stores K N-dimensional arrays of data. Each array is named and in this case we have an ND array for each result variable - data_vars = {} - dataset_list = [] - - for rv in bench_cfg.result_vars: - if isinstance(rv, (ResultVar, ResultBool)): - result_data = np.full(dims_cfg.dims_size, np.nan, dtype=float) - data_vars[rv.name] = (dims_cfg.dims_name, result_data) - if isinstance(rv, (ResultReference, ResultDataSet)): - result_data = np.full(dims_cfg.dims_size, -1, dtype=int) - data_vars[rv.name] = (dims_cfg.dims_name, result_data) - if isinstance( - rv, (ResultPath, ResultVideo, ResultImage, ResultString, ResultContainer) - ): - result_data = np.full(dims_cfg.dims_size, "NAN", dtype=object) - data_vars[rv.name] = (dims_cfg.dims_name, result_data) - - elif type(rv) is ResultVec: - for i in range(rv.size): - result_data = np.full(dims_cfg.dims_size, np.nan) - data_vars[rv.index_name(i)] = (dims_cfg.dims_name, result_data) - - bench_res = BenchResult(bench_cfg) - bench_res.ds = xr.Dataset(data_vars=data_vars, coords=dims_cfg.coords) - bench_res.ds_dynamic = self.ds_dynamic - bench_res.dataset_list = dataset_list - bench_res.setup_object_index() - - return bench_res, function_inputs, dims_cfg.dims_name + """Initialize n-dimensional xarray dataset for storing benchmark results.""" + return self._collector.setup_dataset(bench_cfg, time_src) def define_const_inputs(self, const_vars: List[Tuple[param.Parameter, Any]]) -> Optional[dict]: - """Convert constant variable tuples into a dictionary of name-value pairs. - - Args: - const_vars (List[Tuple[param.Parameter, Any]]): List of (parameter, value) tuples - representing constant parameters and their values - - Returns: - Optional[dict]: Dictionary mapping parameter names to their constant values, - or None if const_vars is None - """ - constant_inputs = None - if const_vars is not None: - const_vars, constant_values = [ - [i for i, j in const_vars], - [j for i, j in const_vars], - ] - - constant_names = [i.name for i in const_vars] - constant_inputs = dict(zip(constant_names, constant_values)) - return constant_inputs + """Convert constant variable tuples into a name-value dictionary.""" + return self._executor.define_const_inputs(const_vars) def define_extra_vars( self, bench_cfg: BenchCfg, repeats: int, time_src: datetime | str ) -> List[IntSweep]: - """Define extra meta variables for tracking benchmark execution details. 
- - This function creates variables that aren't passed to the worker function but are stored - in the n-dimensional array to provide context about the benchmark, such as the number of - repeat measurements and timestamps. - - Args: - bench_cfg (BenchCfg): The benchmark configuration to add variables to - repeats (int): The number of times each sample point should be measured - time_src (datetime | str): Either a timestamp or a string event name for temporal tracking - - Returns: - List[IntSweep]: A list of additional parameter variables to include in the benchmark - """ - bench_cfg.iv_repeat = IntSweep( - default=repeats, - bounds=[1, repeats], - samples=repeats, - units="repeats", - doc="The number of times a sample was measured", - ) - bench_cfg.iv_repeat.name = "repeat" - extra_vars = [bench_cfg.iv_repeat] - - if bench_cfg.over_time: - if isinstance(time_src, str): - iv_over_time = TimeEvent(time_src) - else: - iv_over_time = TimeSnapshot(time_src) - iv_over_time.name = "over_time" - extra_vars.append(iv_over_time) - bench_cfg.iv_time = [iv_over_time] - return extra_vars + """Define meta variables (repeat count, timestamps) for benchmark tracking.""" + return self._collector.define_extra_vars(bench_cfg, repeats, time_src) def calculate_benchmark_results( self, @@ -953,150 +685,26 @@ def store_results( worker_job: WorkerJob, bench_run_cfg: BenchRunCfg, ) -> None: - """Store the results from a benchmark worker job into the benchmark result dataset. - - This method handles unpacking the results from worker jobs and placing them - in the correct locations in the n-dimensional result dataset. It supports different - types of result variables including scalars, vectors, references, and media. - - Args: - job_result (JobFuture): The future containing the worker function result - bench_res (BenchResult): The benchmark result object to store results in - worker_job (WorkerJob): The job metadata needed to index the result - bench_run_cfg (BenchRunCfg): Configuration for how results should be handled - - Raises: - RuntimeError: If an unsupported result variable type is encountered - """ - result = job_result.result() - if result is not None: - logging.info(f"{job_result.job.job_id}:") - if bench_res.bench_cfg.print_bench_inputs: - for k, v in worker_job.function_input.items(): - logging.info(f"\t {k}:{v}") - - result_dict = result if isinstance(result, dict) else result.param.values() - - for rv in bench_res.bench_cfg.result_vars: - result_value = result_dict[rv.name] - if bench_run_cfg.print_bench_results: - logging.info(f"{rv.name}: {result_value}") - - if isinstance(rv, XARRAY_MULTIDIM_RESULT_TYPES): - set_xarray_multidim(bench_res.ds[rv.name], worker_job.index_tuple, result_value) - elif isinstance(rv, ResultDataSet): - bench_res.dataset_list.append(result_value) - set_xarray_multidim( - bench_res.ds[rv.name], - worker_job.index_tuple, - len(bench_res.dataset_list) - 1, - ) - elif isinstance(rv, ResultReference): - bench_res.object_index.append(result_value) - set_xarray_multidim( - bench_res.ds[rv.name], - worker_job.index_tuple, - len(bench_res.object_index) - 1, - ) - - elif isinstance(rv, ResultVec): - if isinstance(result_value, (list, np.ndarray)): - if len(result_value) == rv.size: - for i in range(rv.size): - set_xarray_multidim( - bench_res.ds[rv.index_name(i)], - worker_job.index_tuple, - result_value[i], - ) - - else: - raise RuntimeError("Unsupported result type") - for rv in bench_res.result_hmaps: - bench_res.hmaps[rv.name][worker_job.canonical_input] = result_dict[rv.name] - - # 
bench_cfg.hmap = bench_cfg.hmaps[bench_cfg.result_hmaps[0].name] + """Store worker job results into the n-dimensional result dataset.""" + self._collector.store_results(job_result, bench_res, worker_job, bench_run_cfg) def init_sample_cache(self, run_cfg: BenchRunCfg) -> FutureCache: - """Initialize the sample cache for storing benchmark function results. - - This method creates a FutureCache for storing and retrieving benchmark results - based on the run configuration settings. - - Args: - run_cfg (BenchRunCfg): Configuration with cache settings such as overwrite policy, - executor type, and whether to cache results - - Returns: - FutureCache: A configured cache for storing benchmark results - """ - return FutureCache( - overwrite=run_cfg.overwrite_sample_cache, - executor=run_cfg.executor, - cache_name="sample_cache", - tag_index=True, - size_limit=self.cache_size, - cache_results=run_cfg.cache_samples, - ) + """Initialize the FutureCache for storing benchmark function results.""" + return self._executor.init_sample_cache(run_cfg) def clear_tag_from_sample_cache(self, tag: str, run_cfg: BenchRunCfg) -> None: - """Clear all samples from the cache that match a specific tag. - - This method is useful when you want to rerun a benchmark with the same tag - but want fresh results instead of using cached data. - - Args: - tag (str): The tag identifying samples to clear from the cache - run_cfg (BenchRunCfg): Run configuration used to initialize the sample cache if needed - """ - if self.sample_cache is None: - self.sample_cache = self.init_sample_cache(run_cfg) - self.sample_cache.clear_tag(tag) + """Clear all cached samples matching a specific tag.""" + self._executor.clear_tag_from_sample_cache(tag, run_cfg) def add_metadata_to_dataset(self, bench_res: BenchResult, input_var: ParametrizedSweep) -> None: - """Add variable metadata to the xarray dataset for improved visualization. - - This method adds metadata like units, long names, and descriptions to the xarray dataset - attributes, which helps visualization tools properly label axes and tooltips. - - Args: - bench_res (BenchResult): The benchmark result object containing the dataset to display - input_var (ParametrizedSweep): The variable to extract metadata from - """ - for rv in bench_res.bench_cfg.result_vars: - if type(rv) is ResultVar: - bench_res.ds[rv.name].attrs["units"] = rv.units - bench_res.ds[rv.name].attrs["long_name"] = rv.name - elif type(rv) is ResultVec: - for i in range(rv.size): - bench_res.ds[rv.index_name(i)].attrs["units"] = rv.units - bench_res.ds[rv.index_name(i)].attrs["long_name"] = rv.name - else: - pass # todo - - dsvar = bench_res.ds[input_var.name] - dsvar.attrs["long_name"] = input_var.name - if input_var.units is not None: - dsvar.attrs["units"] = input_var.units - if input_var.__doc__ is not None: - dsvar.attrs["description"] = input_var.__doc__ + """Add units, long names, and descriptions to xarray dataset attributes.""" + self._collector.add_metadata_to_dataset(bench_res, input_var) def report_results( self, bench_res: BenchResult, print_xarray: bool, print_pandas: bool ) -> None: - """Display the calculated benchmark data in various formats. - - This method provides options to display the benchmark results as xarray data structures - or pandas DataFrames for debugging and inspection. 
- - Args: - bench_res (BenchResult): The benchmark result containing the dataset to display - print_xarray (bool): If True, log the raw xarray Dataset structure - print_pandas (bool): If True, log the dataset converted to a pandas DataFrame - """ - if print_xarray: - logging.info(bench_res.ds) - if print_pandas: - logging.info(bench_res.ds.to_dataframe()) + """Log benchmark results as xarray or pandas DataFrame.""" + self._collector.report_results(bench_res, print_xarray, print_pandas) def clear_call_counts(self) -> None: """Clear the worker and cache call counts, to help debug and assert caching is happening properly""" diff --git a/bencher/result_collector.py b/bencher/result_collector.py new file mode 100644 index 00000000..2786b0d6 --- /dev/null +++ b/bencher/result_collector.py @@ -0,0 +1,359 @@ +"""Result collection and storage for benchmarking. + +This module provides the ResultCollector class for managing benchmark results, +including xarray dataset operations, caching, and metadata management. +""" + +import logging +from datetime import datetime +from itertools import product +from typing import Any, List, Tuple + +import numpy as np +import xarray as xr +from diskcache import Cache + +from bencher.bench_cfg import BenchCfg, BenchRunCfg, DimsCfg +from bencher.results.bench_result import BenchResult +from bencher.variables.inputs import IntSweep +from bencher.variables.time import TimeSnapshot, TimeEvent +from bencher.variables.results import ( + XARRAY_MULTIDIM_RESULT_TYPES, + ResultVar, + ResultBool, + ResultVec, + ResultPath, + ResultVideo, + ResultImage, + ResultString, + ResultContainer, + ResultReference, + ResultDataSet, +) +from bencher.worker_job import WorkerJob +from bencher.job import JobFuture + +# Default cache size for benchmark results (100 GB) +DEFAULT_CACHE_SIZE_BYTES = int(100e9) + +logger = logging.getLogger(__name__) + + +def set_xarray_multidim( + data_array: xr.DataArray, index_tuple: Tuple[int, ...], value: Any +) -> xr.DataArray: + """Set a value in a multi-dimensional xarray at the specified index position. + + This function sets a value in an N-dimensional xarray using dynamic indexing + that works for any number of dimensions. + + Args: + data_array (xr.DataArray): The data array to modify + index_tuple (Tuple[int, ...]): The index coordinates as a tuple + value (Any): The value to set at the specified position + + Returns: + xr.DataArray: The modified data array + """ + data_array[index_tuple] = value + return data_array + + +class ResultCollector: + """Manages benchmark result collection, storage, and caching. + + This class handles the initialization of xarray datasets for storing benchmark + results, storing results from worker jobs, managing caches, and adding metadata. + + Attributes: + cache_size (int): Maximum size of the cache in bytes + ds_dynamic (dict): Dictionary for storing unstructured vector datasets + """ + + def __init__(self, cache_size: int = DEFAULT_CACHE_SIZE_BYTES) -> None: + """Initialize a new ResultCollector. + + Args: + cache_size (int): Maximum cache size in bytes. Defaults to 100 GB. + """ + self.cache_size = cache_size + self.ds_dynamic: dict = {} + + def setup_dataset( + self, bench_cfg: BenchCfg, time_src: datetime | str + ) -> Tuple[BenchResult, List[Tuple], List[str]]: + """Initialize an n-dimensional xarray dataset from benchmark configuration parameters. + + This function creates the data structures needed to store benchmark results based on + the provided configuration. 
It sets up the xarray dimensions, coordinates, and variables + based on input variables and result variables. + + Args: + bench_cfg (BenchCfg): Configuration defining the benchmark parameters, inputs, and + results + time_src (datetime | str): Timestamp or event name for the benchmark run + + Returns: + Tuple[BenchResult, List[Tuple], List[str]]: + - A BenchResult object with the initialized dataset + - A list of function input tuples (index, value pairs) + - A list of dimension names for the dataset + """ + if time_src is None: + time_src = datetime.now() + bench_cfg.meta_vars = self.define_extra_vars(bench_cfg, bench_cfg.repeats, time_src) + + bench_cfg.all_vars = bench_cfg.input_vars + bench_cfg.meta_vars + + for i in bench_cfg.all_vars: + logger.info(i.sampling_str()) + + dims_cfg = DimsCfg(bench_cfg) + function_inputs = list( + zip(product(*dims_cfg.dim_ranges_index), product(*dims_cfg.dim_ranges)) + ) + # xarray stores K N-dimensional arrays of data. + # Each array is named and in this case we have an ND array for each result variable + data_vars = {} + dataset_list = [] + + for rv in bench_cfg.result_vars: + if isinstance(rv, (ResultVar, ResultBool)): + result_data = np.full(dims_cfg.dims_size, np.nan, dtype=float) + data_vars[rv.name] = (dims_cfg.dims_name, result_data) + if isinstance(rv, (ResultReference, ResultDataSet)): + result_data = np.full(dims_cfg.dims_size, -1, dtype=int) + data_vars[rv.name] = (dims_cfg.dims_name, result_data) + if isinstance( + rv, (ResultPath, ResultVideo, ResultImage, ResultString, ResultContainer) + ): + result_data = np.full(dims_cfg.dims_size, "NAN", dtype=object) + data_vars[rv.name] = (dims_cfg.dims_name, result_data) + + elif type(rv) is ResultVec: + for i in range(rv.size): + result_data = np.full(dims_cfg.dims_size, np.nan) + data_vars[rv.index_name(i)] = (dims_cfg.dims_name, result_data) + + bench_res = BenchResult(bench_cfg) + bench_res.ds = xr.Dataset(data_vars=data_vars, coords=dims_cfg.coords) + bench_res.ds_dynamic = self.ds_dynamic + bench_res.dataset_list = dataset_list + bench_res.setup_object_index() + + return bench_res, function_inputs, dims_cfg.dims_name + + def define_extra_vars( + self, bench_cfg: BenchCfg, repeats: int, time_src: datetime | str + ) -> List[IntSweep]: + """Define extra meta variables for tracking benchmark execution details. + + This function creates variables that aren't passed to the worker function but are stored + in the n-dimensional array to provide context about the benchmark, such as the number of + repeat measurements and timestamps. 
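+
+        For example (an illustrative sketch; the ``BenchCfg`` fields mirror the
+        ones used in the unit tests)::
+
+            collector = ResultCollector()
+            cfg = BenchCfg(input_vars=[], result_vars=[], const_vars=[],
+                           bench_name="demo", title="demo", repeats=3, over_time=True)
+            extra = collector.define_extra_vars(cfg, 3, datetime(2024, 1, 1))
+            # extra[0].name == "repeat", extra[1].name == "over_time"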
+ + Args: + bench_cfg (BenchCfg): The benchmark configuration to add variables to + repeats (int): The number of times each sample point should be measured + time_src (datetime | str): Either a timestamp or a string event name for temporal + tracking + + Returns: + List[IntSweep]: A list of additional parameter variables to include in the benchmark + """ + bench_cfg.iv_repeat = IntSweep( + default=repeats, + bounds=[1, repeats], + samples=repeats, + units="repeats", + doc="The number of times a sample was measured", + ) + bench_cfg.iv_repeat.name = "repeat" + extra_vars = [bench_cfg.iv_repeat] + + if bench_cfg.over_time: + if isinstance(time_src, str): + iv_over_time = TimeEvent(time_src) + else: + iv_over_time = TimeSnapshot(time_src) + iv_over_time.name = "over_time" + extra_vars.append(iv_over_time) + bench_cfg.iv_time = [iv_over_time] + return extra_vars + + def store_results( + self, + job_result: JobFuture, + bench_res: BenchResult, + worker_job: WorkerJob, + bench_run_cfg: BenchRunCfg, + ) -> None: + """Store the results from a benchmark worker job into the benchmark result dataset. + + This method handles unpacking the results from worker jobs and placing them + in the correct locations in the n-dimensional result dataset. It supports different + types of result variables including scalars, vectors, references, and media. + + Args: + job_result (JobFuture): The future containing the worker function result + bench_res (BenchResult): The benchmark result object to store results in + worker_job (WorkerJob): The job metadata needed to index the result + bench_run_cfg (BenchRunCfg): Configuration for how results should be handled + + Raises: + RuntimeError: If an unsupported result variable type is encountered + """ + result = job_result.result() + if result is not None: + logger.info(f"{job_result.job.job_id}:") + if bench_res.bench_cfg.print_bench_inputs: + for k, v in worker_job.function_input.items(): + logger.info(f"\t {k}:{v}") + + result_dict = result if isinstance(result, dict) else result.param.values() + + for rv in bench_res.bench_cfg.result_vars: + result_value = result_dict[rv.name] + if bench_run_cfg.print_bench_results: + logger.info(f"{rv.name}: {result_value}") + + if isinstance(rv, XARRAY_MULTIDIM_RESULT_TYPES): + set_xarray_multidim(bench_res.ds[rv.name], worker_job.index_tuple, result_value) + elif isinstance(rv, ResultDataSet): + bench_res.dataset_list.append(result_value) + set_xarray_multidim( + bench_res.ds[rv.name], + worker_job.index_tuple, + len(bench_res.dataset_list) - 1, + ) + elif isinstance(rv, ResultReference): + bench_res.object_index.append(result_value) + set_xarray_multidim( + bench_res.ds[rv.name], + worker_job.index_tuple, + len(bench_res.object_index) - 1, + ) + + elif isinstance(rv, ResultVec): + if isinstance(result_value, (list, np.ndarray)): + if len(result_value) == rv.size: + for i in range(rv.size): + set_xarray_multidim( + bench_res.ds[rv.index_name(i)], + worker_job.index_tuple, + result_value[i], + ) + + else: + raise RuntimeError("Unsupported result type") + for rv in bench_res.result_hmaps: + bench_res.hmaps[rv.name][worker_job.canonical_input] = result_dict[rv.name] + + def cache_results( + self, bench_res: BenchResult, bench_cfg_hash: str, bench_cfg_hashes: List[str] + ) -> None: + """Cache benchmark results for future retrieval. + + This method stores benchmark results in the disk cache using the benchmark + configuration hash as the key. It temporarily removes non-pickleable objects + from the benchmark result before caching. 
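+
+        For example (an illustrative sketch; ``bench_res`` is a result created
+        with :meth:`setup_dataset`)::
+
+            hashes = []
+            collector.cache_results(bench_res, "some-config-hash", hashes)
+            # hashes == ["some-config-hash"]; bench_res.object_index is restored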
+ + Args: + bench_res (BenchResult): The benchmark result to cache + bench_cfg_hash (str): The hash value to use as the cache key + bench_cfg_hashes (List[str]): List to append the hash to (modified in place) + """ + with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c: + logger.info(f"saving results with key: {bench_cfg_hash}") + bench_cfg_hashes.append(bench_cfg_hash) + # object index may not be pickleable so remove before caching + obj_index_tmp = bench_res.object_index + bench_res.object_index = [] + + c[bench_cfg_hash] = bench_res + + # restore object index + bench_res.object_index = obj_index_tmp + + logger.info(f"saving benchmark: {bench_res.bench_cfg.bench_name}") + c[bench_res.bench_cfg.bench_name] = bench_cfg_hashes + + def load_history_cache( + self, dataset: xr.Dataset, bench_cfg_hash: str, clear_history: bool + ) -> xr.Dataset: + """Load historical data from a cache if over_time is enabled. + + This method is used to retrieve and concatenate historical benchmark data from the cache + when tracking performance over time. If clear_history is True, it will clear any existing + historical data instead of loading it. + + Args: + dataset (xr.Dataset): Freshly calculated benchmark data for the current run + bench_cfg_hash (str): Hash of the input variables used to identify cached data + clear_history (bool): If True, clears historical data instead of loading it + + Returns: + xr.Dataset: Combined dataset with both historical and current benchmark data, + or just the current data if no history exists or history is cleared + """ + with Cache("cachedir/history", size_limit=self.cache_size) as c: + if clear_history: + logger.info("clearing history") + else: + logger.info(f"checking historical key: {bench_cfg_hash}") + if bench_cfg_hash in c: + logger.info("loading historical data from cache") + ds_old = c[bench_cfg_hash] + dataset = xr.concat([ds_old, dataset], "over_time") + else: + logger.info("did not detect any historical data") + + logger.info("saving data to history cache") + c[bench_cfg_hash] = dataset + return dataset + + def add_metadata_to_dataset(self, bench_res: BenchResult, input_var: Any) -> None: + """Add variable metadata to the xarray dataset for improved visualization. + + This method adds metadata like units, long names, and descriptions to the xarray dataset + attributes, which helps visualization tools properly label axes and tooltips. + + Args: + bench_res (BenchResult): The benchmark result object containing the dataset to display + input_var: The variable to extract metadata from + """ + for rv in bench_res.bench_cfg.result_vars: + if type(rv) is ResultVar: + bench_res.ds[rv.name].attrs["units"] = rv.units + bench_res.ds[rv.name].attrs["long_name"] = rv.name + elif type(rv) is ResultVec: + for i in range(rv.size): + bench_res.ds[rv.index_name(i)].attrs["units"] = rv.units + bench_res.ds[rv.index_name(i)].attrs["long_name"] = rv.name + else: + pass # todo + + dsvar = bench_res.ds[input_var.name] + dsvar.attrs["long_name"] = input_var.name + if input_var.units is not None: + dsvar.attrs["units"] = input_var.units + if input_var.__doc__ is not None: + dsvar.attrs["description"] = input_var.__doc__ + + def report_results( + self, bench_res: BenchResult, print_xarray: bool, print_pandas: bool + ) -> None: + """Display the calculated benchmark data in various formats. + + This method provides options to display the benchmark results as xarray data structures + or pandas DataFrames for debugging and inspection. 
+ + Args: + bench_res (BenchResult): The benchmark result containing the dataset to display + print_xarray (bool): If True, log the raw xarray Dataset structure + print_pandas (bool): If True, log the dataset converted to a pandas DataFrame + """ + if print_xarray: + logger.info(bench_res.ds) + if print_pandas: + logger.info(bench_res.ds.to_dataframe()) diff --git a/bencher/sweep_executor.py b/bencher/sweep_executor.py new file mode 100644 index 00000000..ad273e1d --- /dev/null +++ b/bencher/sweep_executor.py @@ -0,0 +1,206 @@ +"""Sweep execution for benchmarking. + +This module provides the SweepExecutor class for managing parameter sweep execution, +job creation, and cache management in benchmark runs. +""" + +import logging +from copy import deepcopy +from typing import Any, Callable, List, Optional, Tuple + +import param + +from bencher.bench_cfg import BenchCfg, BenchRunCfg +from bencher.job import FutureCache +from bencher.variables.parametrised_sweep import ParametrizedSweep + +# Default cache size for benchmark results (100 GB) +DEFAULT_CACHE_SIZE_BYTES = int(100e9) + +logger = logging.getLogger(__name__) + + +def worker_kwargs_wrapper(worker: Callable, bench_cfg: BenchCfg, **kwargs) -> dict: + """Prepare keyword arguments and pass them to a worker function. + + This wrapper helps filter out metadata parameters that should not be passed + to the worker function (like 'repeat', 'over_time', and 'time_event'). + + Args: + worker (Callable): The worker function to call + bench_cfg (BenchCfg): Benchmark configuration with parameters like pass_repeat + **kwargs: The keyword arguments to filter and pass to the worker + + Returns: + dict: The result from the worker function + """ + function_input_deep = deepcopy(kwargs) + if not bench_cfg.pass_repeat: + function_input_deep.pop("repeat") + if "over_time" in function_input_deep: + function_input_deep.pop("over_time") + if "time_event" in function_input_deep: + function_input_deep.pop("time_event") + return worker(**function_input_deep) + + +class SweepExecutor: + """Manages parameter sweep execution, job creation, and caching. + + This class handles the conversion of variables to parameters, initialization + of sample caches, and management of cache entries. + + Attributes: + cache_size (int): Maximum size of the cache in bytes + sample_cache (FutureCache): Cache for storing sample results + """ + + def __init__(self, cache_size: int = DEFAULT_CACHE_SIZE_BYTES) -> None: + """Initialize a new SweepExecutor. + + Args: + cache_size (int): Maximum cache size in bytes. Defaults to 100 GB. + """ + self.cache_size = cache_size + self.sample_cache: Optional[FutureCache] = None + + def convert_vars_to_params( + self, + variable: param.Parameter | str | dict | tuple, + var_type: str, + run_cfg: Optional[BenchRunCfg], + worker_class_instance: Optional[ParametrizedSweep] = None, + worker_input_cfg: Optional[ParametrizedSweep] = None, + ) -> param.Parameter: + """Convert various input formats to param.Parameter objects. + + This method handles different ways of specifying variables in benchmark sweeps, + including direct param.Parameter objects, string names of parameters, or dictionaries + with parameter configuration details. It ensures all inputs are properly converted + to param.Parameter objects with the correct configuration. 
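+
+        For example (an illustrative sketch; ``ExampleBenchCfg`` is the demo sweep
+        class used in the test suite)::
+
+            executor = SweepExecutor()
+            cfg = ExampleBenchCfg()
+            theta = executor.convert_vars_to_params(
+                "theta", "input", None,
+                worker_class_instance=cfg, worker_input_cfg=ExampleBenchCfg,
+            )
+            # a dict form such as {"name": "theta", "samples": 5} is also accepted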
+ + Args: + variable (param.Parameter | str | dict | tuple): The variable to convert, can be: + - param.Parameter: Already a parameter object + - str: Name of a parameter in the worker_class_instance + - dict: Configuration with 'name' and optional 'values', 'samples', 'max_level' + - tuple: Tuple that can be converted to a parameter + var_type (str): Type of variable ('input', 'result', or 'const') for error messages + run_cfg (Optional[BenchRunCfg]): Run configuration for level settings + worker_class_instance (Optional[ParametrizedSweep]): The worker class instance for + looking up parameters by name + worker_input_cfg (Optional[ParametrizedSweep]): The worker input configuration class + + Returns: + param.Parameter: The converted parameter object + + Raises: + TypeError: If the variable cannot be converted to a param.Parameter + """ + if isinstance(variable, (str, dict)): + if worker_class_instance is None: + raise TypeError( + f"Cannot convert {var_type}_vars from string/dict without a worker class instance. " + f"Use param.Parameter objects directly or provide a ParametrizedSweep worker." + ) + if isinstance(variable, str): + variable = worker_class_instance.param.objects(instance=False)[variable] + if isinstance(variable, dict): + param_var = worker_class_instance.param.objects(instance=False)[variable["name"]] + if variable.get("values"): + param_var = param_var.with_sample_values(variable["values"]) + + if variable.get("samples"): + param_var = param_var.with_samples(variable["samples"]) + if variable.get("max_level"): + if run_cfg is not None: + param_var = param_var.with_level(run_cfg.level, variable["max_level"]) + variable = param_var + if not isinstance(variable, param.Parameter): + raise TypeError( + f"You need to use {var_type}_vars =[{worker_input_cfg}.param.your_variable], " + f"instead of {var_type}_vars =[{worker_input_cfg}.your_variable]" + ) + return variable + + def define_const_inputs(self, const_vars: List[Tuple[param.Parameter, Any]]) -> Optional[dict]: + """Convert constant variable tuples into a dictionary of name-value pairs. + + Args: + const_vars (List[Tuple[param.Parameter, Any]]): List of (parameter, value) tuples + representing constant parameters and their values + + Returns: + Optional[dict]: Dictionary mapping parameter names to their constant values, + or None if const_vars is None + """ + constant_inputs = None + if const_vars is not None: + const_vars_list, constant_values = [ + [i for i, j in const_vars], + [j for i, j in const_vars], + ] + + constant_names = [i.name for i in const_vars_list] + constant_inputs = dict(zip(constant_names, constant_values)) + return constant_inputs + + def init_sample_cache(self, run_cfg: BenchRunCfg) -> FutureCache: + """Initialize the sample cache for storing benchmark function results. + + This method creates a FutureCache for storing and retrieving benchmark results + based on the run configuration settings. 
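+
+        For example (an illustrative sketch; ``Executors`` comes from ``bencher.job``)::
+
+            run_cfg = BenchRunCfg()
+            run_cfg.cache_samples = True
+            run_cfg.executor = Executors.SERIAL
+            cache = executor.init_sample_cache(run_cfg)
+            # the cache is also stored on executor.sample_cache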
+ + Args: + run_cfg (BenchRunCfg): Configuration with cache settings such as overwrite policy, + executor type, and whether to cache results + + Returns: + FutureCache: A configured cache for storing benchmark results + """ + self.sample_cache = FutureCache( + overwrite=run_cfg.overwrite_sample_cache, + executor=run_cfg.executor, + cache_name="sample_cache", + tag_index=True, + size_limit=self.cache_size, + cache_results=run_cfg.cache_samples, + ) + return self.sample_cache + + def clear_tag_from_sample_cache(self, tag: str, run_cfg: BenchRunCfg) -> None: + """Clear all samples from the cache that match a specific tag. + + This method is useful when you want to rerun a benchmark with the same tag + but want fresh results instead of using cached data. + + Args: + tag (str): The tag identifying samples to clear from the cache + run_cfg (BenchRunCfg): Run configuration used to initialize the sample cache if needed + """ + if self.sample_cache is None: + self.sample_cache = self.init_sample_cache(run_cfg) + self.sample_cache.clear_tag(tag) + + def clear_call_counts(self) -> None: + """Clear the worker and cache call counts. + + This helps debug and assert caching is happening properly. + """ + if self.sample_cache is not None: + self.sample_cache.clear_call_counts() + + def close_cache(self) -> None: + """Close the sample cache if it exists.""" + if self.sample_cache is not None: + self.sample_cache.close() + + def get_cache_stats(self) -> str: + """Get statistics about cache usage. + + Returns: + str: A string with cache statistics + """ + if self.sample_cache is not None: + return self.sample_cache.stats() + return "" diff --git a/bencher/worker_manager.py b/bencher/worker_manager.py new file mode 100644 index 00000000..3345e85a --- /dev/null +++ b/bencher/worker_manager.py @@ -0,0 +1,148 @@ +"""Worker management for benchmarking. + +This module provides the WorkerManager class for handling worker function +configuration and validation in benchmark runs. +""" + +import logging +from functools import partial +from typing import Callable, List, Optional + +from bencher.variables.parametrised_sweep import ParametrizedSweep + +logger = logging.getLogger(__name__) + + +def kwargs_to_input_cfg(worker_input_cfg: ParametrizedSweep, **kwargs) -> ParametrizedSweep: + """Create a configured instance of a ParametrizedSweep with the provided keyword arguments. + + Args: + worker_input_cfg (ParametrizedSweep): The ParametrizedSweep class to instantiate + **kwargs: Keyword arguments to update the configuration with + + Returns: + ParametrizedSweep: A configured instance of the worker_input_cfg class + """ + input_cfg = worker_input_cfg() + input_cfg.param.update(kwargs) + return input_cfg + + +def worker_cfg_wrapper(worker: Callable, worker_input_cfg: ParametrizedSweep, **kwargs) -> dict: + """Wrap a worker function to accept keyword arguments instead of a config object. + + This wrapper creates an instance of the worker_input_cfg class, updates it with the + provided keyword arguments, and passes it to the worker function. 
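+
+    For example (an illustrative sketch; ``my_worker`` and ``MySweepCfg`` are
+    hypothetical placeholders)::
+
+        runnable = partial(worker_cfg_wrapper, my_worker, MySweepCfg)
+        result = runnable(theta=0.5)  # passes a MySweepCfg configured with theta=0.5 to my_worker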
+ + Args: + worker (Callable): The worker function that expects a config object + worker_input_cfg (ParametrizedSweep): The class defining the configuration + **kwargs: Keyword arguments to update the configuration with + + Returns: + dict: The result of calling the worker function with the configured input + """ + input_cfg = kwargs_to_input_cfg(worker_input_cfg, **kwargs) + return worker(input_cfg) + + +class WorkerManager: + """Manages worker function configuration and validation for benchmarks. + + This class handles the setup and management of worker functions used in benchmarking, + including support for both callable functions and ParametrizedSweep instances. + + Attributes: + worker (Callable): The configured worker function + worker_class_instance (ParametrizedSweep): The worker class instance if provided + worker_input_cfg (ParametrizedSweep): The input configuration class + """ + + def __init__(self) -> None: + """Initialize a new WorkerManager.""" + self.worker: Optional[Callable] = None + self.worker_class_instance: Optional[ParametrizedSweep] = None + self.worker_input_cfg: Optional[ParametrizedSweep] = None + + def set_worker( + self, + worker: Callable | ParametrizedSweep, + worker_input_cfg: Optional[ParametrizedSweep] = None, + ) -> None: + """Set the benchmark worker function and its input configuration. + + This method sets up the worker function to be benchmarked. The worker can be either a + callable function that takes a ParametrizedSweep instance or a ParametrizedSweep + instance with a __call__ method. In the latter case, worker_input_cfg is not needed. + + Args: + worker (Callable | ParametrizedSweep): Either a function that will be benchmarked or a + ParametrizedSweep instance with a __call__ method. When a ParametrizedSweep is + provided, its __call__ method becomes the worker function. + worker_input_cfg (ParametrizedSweep, optional): The class defining the input parameters + for the worker function. Only needed if worker is a function rather than a + ParametrizedSweep instance. Defaults to None. + + Raises: + RuntimeError: If worker is a class type instead of an instance. + """ + if isinstance(worker, ParametrizedSweep): + self.worker_class_instance = worker + self.worker = self.worker_class_instance.__call__ + logger.info("setting worker from bench class.__call__") + else: + if isinstance(worker, type): + raise RuntimeError("This should be a class instance, not a class") + if worker_input_cfg is None: + self.worker = worker + else: + self.worker = partial(worker_cfg_wrapper, worker, worker_input_cfg) + logger.info(f"setting worker {worker}") + self.worker_input_cfg = worker_input_cfg + + def get_result_vars(self, as_str: bool = True) -> List[str | ParametrizedSweep]: + """Retrieve the result variables from the worker class instance. + + Args: + as_str (bool): If True, the result variables are returned as strings. + If False, they are returned in their original form. + Default is True. + + Returns: + List[str | ParametrizedSweep]: A list of result variables, either as strings + or in their original form. + + Raises: + RuntimeError: If the worker class instance is not set. + """ + if self.worker_class_instance is not None: + if as_str: + return [i.name for i in self.worker_class_instance.get_results_only()] + return self.worker_class_instance.get_results_only() + raise RuntimeError("Worker class instance not set") + + def get_inputs_only(self) -> List[ParametrizedSweep]: + """Retrieve the input variables from the worker class instance. 
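+
+        For example (an illustrative sketch; ``ExampleBenchCfg`` is the demo sweep
+        class used in the test suite)::
+
+            mgr = WorkerManager()
+            mgr.set_worker(ExampleBenchCfg())
+            inputs = mgr.get_inputs_only()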
+ + Returns: + List[ParametrizedSweep]: A list of input variables. + + Raises: + RuntimeError: If the worker class instance is not set. + """ + if self.worker_class_instance is not None: + return self.worker_class_instance.get_inputs_only() + raise RuntimeError("Worker class instance not set") + + def get_input_defaults(self) -> List: + """Retrieve the default input values from the worker class instance. + + Returns: + List: A list of default input values as (parameter, value) tuples. + + Raises: + RuntimeError: If the worker class instance is not set. + """ + if self.worker_class_instance is not None: + return self.worker_class_instance.get_input_defaults() + raise RuntimeError("Worker class instance not set") diff --git a/test/test_result_collector.py b/test/test_result_collector.py new file mode 100644 index 00000000..34a1fea9 --- /dev/null +++ b/test/test_result_collector.py @@ -0,0 +1,319 @@ +"""Tests for ResultCollector extracted from Bench.""" + +import shutil +import tempfile +import unittest +import uuid +from datetime import datetime +from unittest import mock + +import numpy as np +import xarray as xr +from hypothesis import given, settings, strategies as st + +from bencher.example.benchmark_data import ExampleBenchCfg +from bencher.result_collector import ResultCollector, set_xarray_multidim +from bencher.bench_cfg import BenchCfg + + +class TestResultCollector(unittest.TestCase): + """Tests for ResultCollector extracted from Bench.""" + + def setUp(self): + self.collector = ResultCollector() + + def test_init_default_cache_size(self): + """Test default cache size is set.""" + collector = ResultCollector() + self.assertEqual(collector.cache_size, int(100e9)) + + def test_init_custom_cache_size(self): + """Test custom cache size is set.""" + collector = ResultCollector(cache_size=int(50e9)) + self.assertEqual(collector.cache_size, int(50e9)) + + def test_setup_dataset_creates_bench_result(self): + """Test xarray dataset has correct structure.""" + instance = ExampleBenchCfg() + bench_cfg = BenchCfg( + input_vars=[instance.param.theta], + result_vars=[instance.param.out_sin], + const_vars=[], + bench_name="test", + title="test", + repeats=2, + ) + + bench_res, _, dims_name = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1)) + + self.assertIsNotNone(bench_res) + self.assertIsNotNone(bench_res.ds) + self.assertIn("theta", dims_name) + self.assertIn("repeat", dims_name) + + def test_setup_dataset_result_vars_scalar(self): + """Test ResultVar creates float data_vars.""" + instance = ExampleBenchCfg() + instance.param.theta.samples = 3 + + bench_cfg = BenchCfg( + input_vars=[instance.param.theta], + result_vars=[instance.param.out_sin], + const_vars=[], + bench_name="test", + title="test", + repeats=1, + ) + + bench_res, _, _ = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1)) + + self.assertIn("out_sin", bench_res.ds.data_vars) + self.assertEqual(bench_res.ds["out_sin"].dtype, np.float64) + + def test_define_extra_vars_repeat(self): + """Test repeat meta variable creation.""" + bench_cfg = BenchCfg( + input_vars=[], + result_vars=[], + const_vars=[], + bench_name="test", + title="test", + repeats=5, + over_time=False, + ) + + extra_vars = self.collector.define_extra_vars(bench_cfg, 5, datetime(2024, 1, 1)) + + self.assertEqual(len(extra_vars), 1) + self.assertEqual(extra_vars[0].name, "repeat") + self.assertEqual(len(extra_vars[0].values()), 5) + + def test_define_extra_vars_time_snapshot(self): + """Test TimeSnapshot creation for datetime.""" + bench_cfg = 
BenchCfg( + input_vars=[], + result_vars=[], + const_vars=[], + bench_name="test", + title="test", + repeats=1, + over_time=True, + ) + + extra_vars = self.collector.define_extra_vars(bench_cfg, 1, datetime(2024, 1, 1)) + + self.assertEqual(len(extra_vars), 2) + self.assertEqual(extra_vars[1].name, "over_time") + + def test_define_extra_vars_time_event(self): + """Test TimeEvent creation for string.""" + bench_cfg = BenchCfg( + input_vars=[], + result_vars=[], + const_vars=[], + bench_name="test", + title="test", + repeats=1, + over_time=True, + ) + + extra_vars = self.collector.define_extra_vars(bench_cfg, 1, "event_123") + + self.assertEqual(len(extra_vars), 2) + self.assertEqual(extra_vars[1].name, "over_time") + + def test_report_results_no_print(self): + """Test report_results with printing disabled.""" + instance = ExampleBenchCfg() + bench_cfg = BenchCfg( + input_vars=[instance.param.theta], + result_vars=[instance.param.out_sin], + const_vars=[], + bench_name="test", + title="test", + repeats=1, + ) + + bench_res, _, _ = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1)) + + # Should not raise any errors + self.collector.report_results(bench_res, print_xarray=False, print_pandas=False) + + # Hypothesis property-based tests + @settings(deadline=10000) + @given( + repeats=st.integers(min_value=1, max_value=5), + over_time=st.booleans(), + ) + def test_define_extra_vars_combinations(self, repeats, over_time): + """Property: extra vars created correctly for all combinations.""" + bench_cfg = BenchCfg( + input_vars=[], + result_vars=[], + const_vars=[], + bench_name="test", + title="test", + repeats=repeats, + over_time=over_time, + ) + + extra_vars = self.collector.define_extra_vars(bench_cfg, repeats, datetime(2024, 1, 1)) + + # Always has repeat + self.assertGreaterEqual(len(extra_vars), 1) + self.assertEqual(extra_vars[0].name, "repeat") + + # Has time if over_time + if over_time: + self.assertEqual(len(extra_vars), 2) + self.assertEqual(extra_vars[1].name, "over_time") + else: + self.assertEqual(len(extra_vars), 1) + + +class TestCacheOperations(unittest.TestCase): + """Tests for cache_results and load_history_cache.""" + + def setUp(self): + self.collector = ResultCollector() + # Use temp directories for cache to avoid polluting real cache + self.temp_dir = tempfile.mkdtemp() + self.original_cache_dir = "cachedir/benchmark_inputs" + self.original_history_dir = "cachedir/history" + + def tearDown(self): + # Clean up temp directory + shutil.rmtree(self.temp_dir, ignore_errors=True) + + def test_cache_results_appends_hash_to_list(self): + """cache_results should append hash to the provided list.""" + instance = ExampleBenchCfg() + bench_cfg = BenchCfg( + input_vars=[instance.param.theta], + result_vars=[instance.param.out_sin], + const_vars=[], + bench_name="test_cache_append", + title="test", + repeats=1, + ) + + bench_res, _, _ = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1)) + + # Start with empty list + bench_cfg_hashes = [] + + # Cache first result + self.collector.cache_results(bench_res, "hash-1", bench_cfg_hashes) + self.assertEqual(bench_cfg_hashes, ["hash-1"]) + + # Cache second result - should append, not replace + self.collector.cache_results(bench_res, "hash-2", bench_cfg_hashes) + self.assertEqual(bench_cfg_hashes, ["hash-1", "hash-2"]) + + def test_cache_results_preserves_object_index_in_memory(self): + """cache_results should restore object_index after caching.""" + instance = ExampleBenchCfg() + bench_cfg = BenchCfg( + 
+            input_vars=[instance.param.theta],
+            result_vars=[instance.param.out_sin],
+            const_vars=[],
+            bench_name="test_cache_object_index",
+            title="test",
+            repeats=1,
+        )
+
+        bench_res, _, _ = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1))
+
+        # Set object_index to something non-empty
+        bench_res.object_index = ["obj-1", "obj-2"]
+
+        bench_cfg_hashes = []
+        self.collector.cache_results(bench_res, "hash-1", bench_cfg_hashes)
+
+        # object_index should be restored in memory
+        self.assertEqual(bench_res.object_index, ["obj-1", "obj-2"])
+
+    def test_load_history_cache_no_existing_history(self):
+        """load_history_cache should return dataset unchanged when no history exists."""
+        # Create a simple dataset
+        dataset = xr.Dataset({"var": (["x"], [1, 2, 3])})
+
+        # Use a truly unique hash that won't exist in any cache
+        unique_hash = f"nonexistent-hash-{uuid.uuid4()}"
+        result = self.collector.load_history_cache(dataset, unique_hash, False)
+
+        # Should return the same dataset (no concat)
+        self.assertTrue(result.equals(dataset))
+
+    def test_load_history_cache_clear_history_flag(self):
+        """load_history_cache with clear_history=True should not concat."""
+        dataset = xr.Dataset({"var": (["x"], [1, 2, 3])})
+
+        # Even with existing history, clear_history=True should skip concat
+        with mock.patch.object(xr, "concat") as mock_concat:
+            self.collector.load_history_cache(dataset, "some-hash", clear_history=True)
+            mock_concat.assert_not_called()
+
+    def test_add_metadata_to_dataset_scalar_result(self):
+        """add_metadata_to_dataset should set attrs for scalar ResultVar."""
+        instance = ExampleBenchCfg()
+        bench_cfg = BenchCfg(
+            input_vars=[instance.param.theta],
+            result_vars=[instance.param.out_sin],
+            const_vars=[],
+            bench_name="test_metadata",
+            title="test",
+            repeats=1,
+        )
+
+        bench_res, _, _ = self.collector.setup_dataset(bench_cfg, datetime(2024, 1, 1))
+
+        # Add metadata for theta input
+        self.collector.add_metadata_to_dataset(bench_res, instance.param.theta)
+
+        # Check result var has units and long_name
+        self.assertEqual(bench_res.ds["out_sin"].attrs.get("units"), "v")
+        self.assertEqual(bench_res.ds["out_sin"].attrs.get("long_name"), "out_sin")
+
+        # Check input var coordinate has metadata
+        self.assertEqual(bench_res.ds["theta"].attrs.get("long_name"), "theta")
+        self.assertEqual(bench_res.ds["theta"].attrs.get("units"), "rad")
+
+
+class TestSetXarrayMultidim(unittest.TestCase):
+    """Tests for set_xarray_multidim utility function."""
+
+    def test_set_value_2d(self):
+        """Test setting value in 2D array."""
+        data = xr.DataArray(np.zeros((3, 3)), dims=["x", "y"])
+        set_xarray_multidim(data, (1, 2), 5.0)
+        self.assertEqual(data[1, 2].item(), 5.0)
+
+    def test_set_value_3d(self):
+        """Test setting value in 3D array."""
+        data = xr.DataArray(np.zeros((2, 2, 2)), dims=["x", "y", "z"])
+        set_xarray_multidim(data, (1, 0, 1), 7.5)
+        self.assertEqual(data[1, 0, 1].item(), 7.5)
+
+    def test_set_value_preserves_other_values(self):
+        """Test setting one value doesn't affect others."""
+        data = xr.DataArray(np.ones((3, 3)), dims=["x", "y"])
+        set_xarray_multidim(data, (0, 0), 99.0)
+
+        self.assertEqual(data[0, 0].item(), 99.0)
+        self.assertEqual(data[1, 1].item(), 1.0)
+        self.assertEqual(data[2, 2].item(), 1.0)
+
+    @settings(deadline=10000)
+    @given(
+        value=st.floats(allow_nan=False, allow_infinity=False, min_value=-1e10, max_value=1e10),
+    )
+    def test_set_xarray_multidim_values(self, value):
+        """Property: set_xarray_multidim works for various values."""
+        data = xr.DataArray(np.zeros((3, 3)), dims=["x", "y"])
+        set_xarray_multidim(data, (1, 1), value)
+        self.assertEqual(data[1, 1].item(), value)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_set_xarray_multidim.py b/test/test_set_xarray_multidim.py
index cec211ce..13a06ddc 100644
--- a/test/test_set_xarray_multidim.py
+++ b/test/test_set_xarray_multidim.py
@@ -1,6 +1,6 @@
 import numpy as np
 import xarray as xr
-from bencher.bencher import set_xarray_multidim
+from bencher.result_collector import set_xarray_multidim
 
 
 class TestSetXarrayMultidim:
diff --git a/test/test_sweep_executor.py b/test/test_sweep_executor.py
new file mode 100644
index 00000000..c22ca7e2
--- /dev/null
+++ b/test/test_sweep_executor.py
@@ -0,0 +1,287 @@
+"""Tests for SweepExecutor extracted from Bench."""
+
+import unittest
+
+from hypothesis import given, settings, strategies as st
+
+from bencher.example.benchmark_data import ExampleBenchCfg
+from bencher.sweep_executor import SweepExecutor, worker_kwargs_wrapper
+from bencher.bench_cfg import BenchCfg, BenchRunCfg
+from bencher.job import Executors
+
+
+class TestSweepExecutor(unittest.TestCase):
+    """Tests for SweepExecutor extracted from Bench."""
+
+    def setUp(self):
+        self.executor = SweepExecutor()
+        self.worker_instance = ExampleBenchCfg()
+
+    def test_init_default_cache_size(self):
+        """Test default cache size is set."""
+        executor = SweepExecutor()
+        self.assertEqual(executor.cache_size, int(100e9))
+
+    def test_init_custom_cache_size(self):
+        """Test custom cache size is set."""
+        executor = SweepExecutor(cache_size=int(50e9))
+        self.assertEqual(executor.cache_size, int(50e9))
+
+    def test_convert_vars_to_params_from_string(self):
+        """Test converting string variable names to params."""
+        result = self.executor.convert_vars_to_params(
+            "theta",
+            "input",
+            None,
+            worker_class_instance=self.worker_instance,
+            worker_input_cfg=ExampleBenchCfg,
+        )
+        self.assertEqual(result.name, "theta")
+
+    def test_convert_vars_to_params_from_dict(self):
+        """Test converting dict config to params."""
+        result = self.executor.convert_vars_to_params(
+            {"name": "theta", "samples": 5},
+            "input",
+            None,
+            worker_class_instance=self.worker_instance,
+            worker_input_cfg=ExampleBenchCfg,
+        )
+        self.assertEqual(result.name, "theta")
+
+    def test_convert_vars_to_params_from_param(self):
+        """Test passing param.Parameter directly."""
+        result = self.executor.convert_vars_to_params(
+            self.worker_instance.param.theta,
+            "input",
+            None,
+            worker_class_instance=self.worker_instance,
+            worker_input_cfg=ExampleBenchCfg,
+        )
+        self.assertEqual(result.name, "theta")
+
+    def test_convert_vars_to_params_type_error(self):
+        """Test proper error for invalid variable types."""
+        with self.assertRaises(TypeError):
+            self.executor.convert_vars_to_params(
+                12345,  # Invalid type
+                "input",
+                None,
+                worker_class_instance=self.worker_instance,
+                worker_input_cfg=ExampleBenchCfg,
+            )
+
+    def test_define_const_inputs(self):
+        """Test converting const tuples to dict."""
+        const_vars = [
+            (self.worker_instance.param.theta, 1.5),
+            (self.worker_instance.param.offset, 0.1),
+        ]
+        result = self.executor.define_const_inputs(const_vars)
+
+        self.assertEqual(result["theta"], 1.5)
+        self.assertEqual(result["offset"], 0.1)
+
+    def test_define_const_inputs_none(self):
+        """Test None input returns None."""
+        result = self.executor.define_const_inputs(None)
+        self.assertIsNone(result)
+
+    def test_init_sample_cache(self):
+        """Test FutureCache initialization with config."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = True
+        run_cfg.executor = Executors.SERIAL
+
+        cache = self.executor.init_sample_cache(run_cfg)
+
+        self.assertIsNotNone(cache)
+        self.assertEqual(self.executor.sample_cache, cache)
+
+    def test_init_sample_cache_with_caching_disabled(self):
+        """Test FutureCache when cache_samples=False."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = False
+        run_cfg.executor = Executors.SERIAL
+
+        cache = self.executor.init_sample_cache(run_cfg)
+
+        self.assertIsNotNone(cache)
+        # When cache_samples=False, cache.cache should be None
+        self.assertIsNone(cache.cache)
+
+    def test_clear_call_counts(self):
+        """Test clearing call counts."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = True
+        run_cfg.executor = Executors.SERIAL
+
+        self.executor.init_sample_cache(run_cfg)
+        self.executor.sample_cache.worker_wrapper_call_count = 5
+
+        self.executor.clear_call_counts()
+
+        self.assertEqual(self.executor.sample_cache.worker_wrapper_call_count, 0)
+
+    def test_clear_call_counts_no_cache(self):
+        """Test clearing call counts when no cache exists."""
+        # Should not raise
+        self.executor.clear_call_counts()
+
+    def test_close_cache(self):
+        """Test closing the cache."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = True
+        run_cfg.executor = Executors.SERIAL
+
+        self.executor.init_sample_cache(run_cfg)
+        self.executor.close_cache()
+
+    def test_close_cache_no_cache(self):
+        """Test closing cache when none exists."""
+        # Should not raise
+        self.executor.close_cache()
+
+    def test_get_cache_stats_no_cache(self):
+        """Test getting stats when no cache exists."""
+        result = self.executor.get_cache_stats()
+        self.assertEqual(result, "")
+
+    def test_get_cache_stats_with_cache(self):
+        """Test getting stats when cache is present."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = True
+        run_cfg.executor = Executors.SERIAL
+
+        self.executor.init_sample_cache(run_cfg)
+        result = self.executor.get_cache_stats()
+
+        # Should return non-empty stats string
+        self.assertIsInstance(result, str)
+
+    def test_convert_vars_to_params_with_max_level(self):
+        """Test max_level handling when run_cfg.level is set."""
+        run_cfg = BenchRunCfg()
+        run_cfg.level = 2
+
+        result = self.executor.convert_vars_to_params(
+            {"name": "theta", "max_level": 3},
+            "input",
+            run_cfg,
+            worker_class_instance=self.worker_instance,
+            worker_input_cfg=ExampleBenchCfg,
+        )
+
+        self.assertEqual(result.name, "theta")
+        # The parameter should have been processed with level adjustment
+
+    def test_clear_tag_from_sample_cache_lazy_init(self):
+        """Test clear_tag_from_sample_cache initializes cache if None."""
+        # sample_cache should be None initially
+        self.assertIsNone(self.executor.sample_cache)
+
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = True
+        run_cfg.executor = Executors.SERIAL
+
+        # This should initialize the cache lazily
+        self.executor.clear_tag_from_sample_cache("test_tag", run_cfg)
+
+        # Cache should now be initialized
+        self.assertIsNotNone(self.executor.sample_cache)
+
+    # Hypothesis property-based tests
+    @settings(deadline=10000)
+    @given(
+        cache_samples=st.booleans(),
+    )
+    def test_init_sample_cache_configs(self, cache_samples):
+        """Property: cache initializes correctly with various configs."""
+        run_cfg = BenchRunCfg()
+        run_cfg.cache_samples = cache_samples
+        run_cfg.executor = Executors.SERIAL
+
+        cache = self.executor.init_sample_cache(run_cfg)
+
+        self.assertIsNotNone(cache)
+        if cache_samples:
+            self.assertIsNotNone(cache.cache)
+        else:
+            self.assertIsNone(cache.cache)
+
+
+class TestWorkerKwargsWrapper(unittest.TestCase):
+    """Tests for worker_kwargs_wrapper function."""
+
+    def test_filters_repeat_when_pass_repeat_false(self):
+        """Test repeat is filtered when pass_repeat=False."""
+        call_log = []
+
+        def my_worker(**kwargs):
+            call_log.append(kwargs)
+            return {"result": 1}
+
+        bench_cfg = BenchCfg(
+            input_vars=[],
+            result_vars=[],
+            const_vars=[],
+            bench_name="test",
+            title="test",
+            pass_repeat=False,
+        )
+
+        worker_kwargs_wrapper(my_worker, bench_cfg, theta=1.0, repeat=1)
+
+        self.assertNotIn("repeat", call_log[0])
+        self.assertIn("theta", call_log[0])
+
+    def test_passes_repeat_when_pass_repeat_true(self):
+        """Test repeat is passed when pass_repeat=True."""
+        call_log = []
+
+        def my_worker(**kwargs):
+            call_log.append(kwargs)
+            return {"result": 1}
+
+        bench_cfg = BenchCfg(
+            input_vars=[],
+            result_vars=[],
+            const_vars=[],
+            bench_name="test",
+            title="test",
+            pass_repeat=True,
+        )
+
+        worker_kwargs_wrapper(my_worker, bench_cfg, theta=1.0, repeat=1)
+
+        self.assertIn("repeat", call_log[0])
+        self.assertIn("theta", call_log[0])
+
+    def test_filters_meta_vars(self):
+        """Test over_time and time_event are always filtered."""
+        call_log = []
+
+        def my_worker(**kwargs):
+            call_log.append(kwargs)
+            return {"result": 1}
+
+        bench_cfg = BenchCfg(
+            input_vars=[],
+            result_vars=[],
+            const_vars=[],
+            bench_name="test",
+            title="test",
+            pass_repeat=True,
+        )
+
+        worker_kwargs_wrapper(
+            my_worker, bench_cfg, theta=1.0, repeat=1, over_time="2024-01-01", time_event="ev1"
+        )
+
+        self.assertNotIn("over_time", call_log[0])
+        self.assertNotIn("time_event", call_log[0])
+        self.assertIn("theta", call_log[0])
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/test_worker_manager.py b/test/test_worker_manager.py
new file mode 100644
index 00000000..f7e070f4
--- /dev/null
+++ b/test/test_worker_manager.py
@@ -0,0 +1,137 @@
+"""Tests for WorkerManager extracted from Bench."""
+
+import unittest
+from hypothesis import given, settings, strategies as st
+
+from bencher.example.benchmark_data import ExampleBenchCfg
+from bencher.worker_manager import WorkerManager, worker_cfg_wrapper, kwargs_to_input_cfg
+
+
+class TestWorkerManager(unittest.TestCase):
+    """Tests for WorkerManager extracted from Bench."""
+
+    def setUp(self):
+        self.manager = WorkerManager()
+
+    def test_set_worker_from_parametrized_sweep(self):
+        """Test setting worker from ParametrizedSweep instance."""
+        instance = ExampleBenchCfg()
+        self.manager.set_worker(instance)
+        self.assertEqual(self.manager.worker, instance.__call__)
+        self.assertEqual(self.manager.worker_class_instance, instance)
+
+    def test_set_worker_from_callable(self):
+        """Test setting worker from function."""
+
+        def my_worker(**_kwargs):
+            return {"result": 1}
+
+        self.manager.set_worker(my_worker)
+        self.assertEqual(self.manager.worker, my_worker)
+        self.assertIsNone(self.manager.worker_class_instance)
+
+    def test_set_worker_with_input_cfg(self):
+        """Test setting worker with separate config."""
+
+        def my_worker(cfg):
+            return {"result": cfg.theta}
+
+        self.manager.set_worker(my_worker, ExampleBenchCfg)
+        # Worker should be wrapped with config - it's now a partial
+        self.assertIsNotNone(self.manager.worker)
+        self.assertEqual(self.manager.worker_input_cfg, ExampleBenchCfg)
+
+    def test_set_worker_class_type_error(self):
+        """Test error when class type passed instead of instance."""
+        with self.assertRaises(RuntimeError):
+            self.manager.set_worker(ExampleBenchCfg)  # Class, not instance
+
+    def test_get_result_vars_as_str(self):
+        """Test getting result var names as strings."""
+        self.manager.set_worker(ExampleBenchCfg())
+        result_vars = self.manager.get_result_vars(as_str=True)
+        self.assertIsInstance(result_vars[0], str)
+        self.assertIn("out_sin", result_vars)
+
+    def test_get_result_vars_as_params(self):
+        """Test getting result vars as Parameter objects."""
+        self.manager.set_worker(ExampleBenchCfg())
+        result_vars = self.manager.get_result_vars(as_str=False)
+        self.assertTrue(hasattr(result_vars[0], "name"))
+
+    def test_get_result_vars_no_instance_error(self):
+        """Test error when worker instance not set."""
+        with self.assertRaises(RuntimeError):
+            self.manager.get_result_vars()
+
+    def test_get_inputs_only(self):
+        """Test getting input variables."""
+        self.manager.set_worker(ExampleBenchCfg())
+        inputs = self.manager.get_inputs_only()
+        self.assertIsInstance(inputs, list)
+        self.assertGreater(len(inputs), 0)
+
+    def test_get_inputs_only_no_instance_error(self):
+        """Test error when worker instance not set for get_inputs_only."""
+        with self.assertRaises(RuntimeError):
+            self.manager.get_inputs_only()
+
+    def test_get_input_defaults(self):
+        """Test getting default input values."""
+        self.manager.set_worker(ExampleBenchCfg())
+        defaults = self.manager.get_input_defaults()
+        self.assertIsInstance(defaults, list)
+
+    def test_get_input_defaults_no_instance_error(self):
+        """Test error when worker instance not set for get_input_defaults."""
+        with self.assertRaises(RuntimeError):
+            self.manager.get_input_defaults()
+
+    # Hypothesis property-based tests
+    @settings(deadline=10000)
+    @given(as_str=st.booleans())
+    def test_get_result_vars_return_type(self, as_str):
+        """Property: return type matches as_str parameter."""
+        self.manager.set_worker(ExampleBenchCfg())
+        result_vars = self.manager.get_result_vars(as_str=as_str)
+        if as_str:
+            self.assertTrue(all(isinstance(v, str) for v in result_vars))
+        else:
+            self.assertTrue(all(hasattr(v, "name") for v in result_vars))
+
+
+class TestKwargsToInputCfg(unittest.TestCase):
+    """Tests for kwargs_to_input_cfg function."""
+
+    def test_creates_instance(self):
+        """Test that it creates an instance of the config class."""
+        cfg = kwargs_to_input_cfg(ExampleBenchCfg)
+        self.assertIsInstance(cfg, ExampleBenchCfg)
+
+    def test_updates_with_kwargs(self):
+        """Test that kwargs are applied to the config."""
+        cfg = kwargs_to_input_cfg(ExampleBenchCfg, theta=1.5)
+        self.assertEqual(cfg.theta, 1.5)
+
+
+class TestWorkerCfgWrapper(unittest.TestCase):
+    """Tests for worker_cfg_wrapper function."""
+
+    def test_wrapper_calls_worker_with_config(self):
+        """Test wrapper creates config instance correctly."""
+        call_log = []
+
+        def my_worker(cfg):
+            call_log.append(cfg)
+            return {"result": cfg.theta}
+
+        result = worker_cfg_wrapper(my_worker, ExampleBenchCfg, theta=2.0)
+
+        self.assertEqual(len(call_log), 1)
+        self.assertIsInstance(call_log[0], ExampleBenchCfg)
+        self.assertEqual(call_log[0].theta, 2.0)
+        self.assertEqual(result, {"result": 2.0})
+
+
+if __name__ == "__main__":
+    unittest.main()
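For reference, the relocated helper keeps the in-place call pattern the tests above exercise; a minimal sketch, assuming bencher.result_collector exports set_xarray_multidim exactly as the updated import in test/test_set_xarray_multidim.py shows:

    import numpy as np
    import xarray as xr

    from bencher.result_collector import set_xarray_multidim

    # Build a small 2D array and set a single cell, mirroring test_set_value_2d
    data = xr.DataArray(np.zeros((3, 3)), dims=["x", "y"])
    set_xarray_multidim(data, (1, 2), 5.0)  # modifies data in place
    assert data[1, 2].item() == 5.0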