Decompose Bench class into focused helper classes #687
Extract SweepExecutor, ResultCollector, and WorkerManager from the monolithic Bench class (~1200 lines) to improve maintainability.

- SweepExecutor: parameter conversion, cache management
- ResultCollector: xarray dataset operations, result storage
- WorkerManager: worker function setup and validation

Bench remains the coordinator, delegating to helpers while preserving the public API. Adds 47 new tests with Hypothesis property-based testing.
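A minimal sketch of the thin-delegation pattern described above. The helper classes are stubbed so the snippet runs standalone; the real classes and signatures live in the new bencher modules and may differ.

```python
# Sketch only: stubs stand in for bencher.worker_manager,
# bencher.sweep_executor and bencher.result_collector.

class WorkerManager:
    """Stub: owns the worker callable and its input config."""
    def set_worker(self, worker, worker_input_cfg=None):
        self.worker = worker
        self.worker_input_cfg = worker_input_cfg

class SweepExecutor:
    """Stub: owns the sample cache used during sweeps."""
    sample_cache = None

class ResultCollector:
    """Stub: owns xarray dataset operations and result storage."""
    def cache_results(self, bench_res, bench_cfg_hash):
        return [bench_cfg_hash]

class Bench:
    """Coordinator: the public API is preserved via thin delegation wrappers."""
    def __init__(self, bench_name):
        self.bench_name = bench_name
        self._worker_mgr = WorkerManager()
        self._executor = SweepExecutor()
        self._collector = ResultCollector()

    def set_worker(self, worker, worker_input_cfg=None):
        # Former Bench logic now lives in WorkerManager.
        self._worker_mgr.set_worker(worker, worker_input_cfg)

    def cache_results(self, bench_res, bench_cfg_hash):
        # Former Bench logic now lives in ResultCollector.
        return self._collector.cache_results(bench_res, bench_cfg_hash)

    @property
    def sample_cache(self):
        # Callers keep using bench.sample_cache as before.
        return self._executor.sample_cache
```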
Reviewer's Guide

Refactors the large Bench class by extracting worker management, sweep execution, and result collection into dedicated helper classes (WorkerManager, SweepExecutor, ResultCollector), while preserving Bench’s public API via thin delegation wrappers and adding focused unit tests for the new components.

Sequence diagram for Bench coordinating helpers during a sweep

```mermaid
sequenceDiagram
participant Bench
participant WorkerManager
participant SweepExecutor
participant ResultCollector
participant WorkerFunc
participant SampleCache
Bench->>WorkerManager: set_worker(worker, worker_input_cfg)
WorkerManager-->>Bench: worker, worker_class_instance, worker_input_cfg
Bench->>SweepExecutor: init_sample_cache(run_cfg)
SweepExecutor->>SweepExecutor: create FutureCache
SweepExecutor-->>Bench: sample_cache
Bench->>SampleCache: reference via sample_cache property
Bench->>ResultCollector: setup_dataset(bench_cfg, time_src)
ResultCollector-->>Bench: bench_res, function_inputs, dims_name
loop for each function_input
Bench->>WorkerFunc: worker_kwargs_wrapper(worker, bench_cfg, kwargs)
WorkerFunc-->>Bench: result
Bench->>ResultCollector: store_results(job_result, bench_res, worker_job, bench_run_cfg)
end
Bench->>ResultCollector: cache_results(bench_res, bench_cfg_hash)
ResultCollector-->>Bench: updated_bench_cfg_hashes
Bench->>ResultCollector: report_results(bench_res, print_xarray, print_pandas)
ResultCollector-->>Bench: logs
Bench->>SweepExecutor: clear_call_counts()
SweepExecutor-->>Bench: done
```

Class diagram for Bench decomposition into helper classes

```mermaid
classDiagram
class Bench {
+str bench_name
+BenchCfg bench_cfg
+BenchRunCfg run_cfg
+WorkerManager _worker_mgr
+SweepExecutor _executor
+ResultCollector _collector
+Callable worker
+ParametrizedSweep worker_class_instance
+ParametrizedSweep worker_input_cfg
+property sample_cache
+property ds_dynamic
+set_worker(worker, worker_input_cfg)
+convert_vars_to_params(variable, var_type, run_cfg)
+cache_results(bench_res, bench_cfg_hash)
+load_history_cache(dataset, bench_cfg_hash, clear_history)
+setup_dataset(bench_cfg, time_src)
+define_const_inputs(const_vars)
+define_extra_vars(bench_cfg, repeats, time_src)
+store_results(job_result, bench_res, worker_job, bench_run_cfg)
+init_sample_cache(run_cfg)
+clear_tag_from_sample_cache(tag, run_cfg)
+add_metadata_to_dataset(bench_res, input_var)
+report_results(bench_res, print_xarray, print_pandas)
}
class WorkerManager {
+Callable worker
+ParametrizedSweep worker_class_instance
+ParametrizedSweep worker_input_cfg
+set_worker(worker, worker_input_cfg)
+get_result_vars(as_str)
+get_inputs_only()
+get_input_defaults()
}
class SweepExecutor {
+int cache_size
+FutureCache sample_cache
+convert_vars_to_params(variable, var_type, run_cfg, worker_class_instance, worker_input_cfg)
+define_const_inputs(const_vars)
+init_sample_cache(run_cfg)
+clear_tag_from_sample_cache(tag, run_cfg)
+clear_call_counts()
+close_cache()
+get_cache_stats()
}
class ResultCollector {
+int cache_size
+dict ds_dynamic
+setup_dataset(bench_cfg, time_src)
+define_extra_vars(bench_cfg, repeats, time_src)
+store_results(job_result, bench_res, worker_job, bench_run_cfg)
+cache_results(bench_res, bench_cfg_hash)
+load_history_cache(dataset, bench_cfg_hash, clear_history)
+add_metadata_to_dataset(bench_res, input_var)
+report_results(bench_res, print_xarray, print_pandas)
}
Bench --> WorkerManager : uses
Bench --> SweepExecutor : uses
Bench --> ResultCollector : uses
```
Hey - I've found 13 issues and left some high-level feedback:
- ResultCollector.cache_results() now always creates a fresh bench_cfg_hashes list rather than updating an existing one, so multiple cache_results calls in a single process will only persist the last hash for a bench_name. If you rely on storing multiple config hashes per benchmark, you probably want to accept the existing list from Bench, or read and extend the list from the cache before writing it back.
- Each of result_collector.py, sweep_executor.py, and worker_manager.py calls logging.basicConfig(), which is normally undesirable in library code because it configures the root logger on import; consider removing these and leaving logging setup to the application.
## Individual Comments
### Comment 1
<location> `bencher/bencher.py:539` </location>
<code_context>
- logging.info(f"saving benchmark: {self.bench_name}")
- c[self.bench_name] = self.bench_cfg_hashes
+ """Cache benchmark results to disk using the config hash as key."""
+ self.bench_cfg_hashes = self._collector.cache_results(bench_res, bench_cfg_hash)
# def show(self, run_cfg: BenchRunCfg = None, pane: pn.panel = None) -> None:
</code_context>
<issue_to_address>
**issue (bug_risk):** bench_cfg_hashes is overwritten on each cache call, losing accumulated history
Before this change, `cache_results` appended to `self.bench_cfg_hashes`, keeping a growing history of config hashes. Now `ResultCollector.cache_results` builds a new list and you assign it directly to `self.bench_cfg_hashes`, so it only ever holds the latest hash. This changes observable behavior for any code that expects historical hashes. Either have `ResultCollector` extend the existing list, or have `Bench` append the returned hash instead of replacing the whole list.
</issue_to_address>
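A minimal sketch of the second option, with Bench appending the returned hash instead of replacing the list (this assumes `cache_results` returns the list of hashes it wrote, as in the diff above):

```python
# Hypothetical fix in bencher/bencher.py (names follow the diff above).
def cache_results(self, bench_res, bench_cfg_hash: str) -> None:
    """Cache benchmark results to disk using the config hash as key."""
    new_hashes = self._collector.cache_results(bench_res, bench_cfg_hash)
    # Extend the accumulated history rather than overwriting it.
    self.bench_cfg_hashes.extend(
        h for h in new_hashes if h not in self.bench_cfg_hashes
    )
```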
### Comment 2
<location> `bencher/result_collector.py:252-261` </location>
<code_context>
+ def cache_results(self, bench_res: BenchResult, bench_cfg_hash: str) -> List[str]:
</code_context>
<issue_to_address>
**issue (bug_risk):** ResultCollector.cache_results reinitializes the hash list instead of extending existing cache metadata
Within `cache_results`, `bench_cfg_hashes` is always a new list containing only the current `bench_cfg_hash`, and `c[bench_res.bench_cfg.bench_name]` is overwritten with it. This drops any previously cached hashes for that benchmark name. Consider reading the existing list from the cache (if any), appending the new hash, and writing it back to preserve the full history.
</issue_to_address>
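A sketch of the read/extend/write-back approach, assuming the `diskcache.Cache` usage and the `cachedir/benchmark_inputs` path shown in Comment 11 (the `object_index` strip/restore from the original method is omitted for brevity):

```python
from typing import List

from diskcache import Cache

def cache_results(self, bench_res, bench_cfg_hash: str) -> List[str]:
    """Hypothetical version that preserves previously cached hashes."""
    with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c:
        # Read any existing hash list for this benchmark name, then extend it.
        bench_cfg_hashes = c.get(bench_res.bench_cfg.bench_name, [])
        if bench_cfg_hash not in bench_cfg_hashes:
            bench_cfg_hashes.append(bench_cfg_hash)
        c[bench_cfg_hash] = bench_res
        c[bench_res.bench_cfg.bench_name] = bench_cfg_hashes
    return bench_cfg_hashes
```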
### Comment 3
<location> `bencher/result_collector.py:39` </location>
<code_context>
+# Default cache size for benchmark results (100 GB)
+DEFAULT_CACHE_SIZE_BYTES = int(100e9)
+
+logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
+
+
</code_context>
<issue_to_address>
**suggestion:** Calling logging.basicConfig in library modules can interfere with application-level logging configuration
This and the new executor/worker modules configure logging at import time with `basicConfig`, which can override the host application's logging setup. Instead, prefer a module-level logger via `logging.getLogger(__name__)` and let the embedding application (or an explicit helper) handle configuration.
Suggested implementation:
```python
import logging
from datetime import datetime

logger = logging.getLogger(__name__)
```
You should make similar updates in the new executor/worker modules that are also calling `logging.basicConfig` at import time:
1. Remove any `logging.basicConfig(...)` calls from those library modules.
2. Ensure each module defines a module-level logger via:
`logger = logging.getLogger(__name__)`
3. If you want default logging for CLI/tools, add configuration in the application entrypoint (e.g., `if __name__ == "__main__":`) or a dedicated helper, not in library code.
</issue_to_address>
### Comment 4
<location> `bencher/sweep_executor.py:67-76` </location>
<code_context>
+ def convert_vars_to_params(
</code_context>
<issue_to_address>
**issue (bug_risk):** convert_vars_to_params assumes worker_class_instance is set when variable is a str/dict without validating it
In the `str`/`dict` branches this unconditionally calls `worker_class_instance.param.objects(...)`, so a `None` `worker_class_instance` (e.g. plain function worker) will raise a confusing `AttributeError`. Please either validate `worker_class_instance` is not `None` with a clear error, or handle the `None` case explicitly.
</issue_to_address>
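A minimal guard sketch; the error message and exact placement are illustrative:

```python
# Hypothetical validation at the top of the str/dict branches in
# SweepExecutor.convert_vars_to_params.
if isinstance(variable, (str, dict)) and worker_class_instance is None:
    raise ValueError(
        f"Cannot resolve sweep variable {variable!r} by name: no worker class "
        "instance is available. Call set_worker() with a ParametrizedSweep "
        "instance before using string or dict sweep variables."
    )
```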
### Comment 5
<location> `bencher/worker_manager.py:68-77` </location>
<code_context>
+ def set_worker(
</code_context>
<issue_to_address>
**nitpick:** worker_input_cfg parameter is annotated as non-optional but used as optional
The signature uses `worker_input_cfg: ParametrizedSweep = None`, but the type hint doesn’t allow `None`. Consider updating it to `Optional[ParametrizedSweep]` (or making it required) so the annotation matches the actual usage and static type checkers behave correctly.
</issue_to_address>
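A signature sketch of the suggested annotation (`object` stands in for `ParametrizedSweep` so the snippet runs standalone):

```python
from typing import Callable, Optional

class WorkerManager:
    def set_worker(
        self,
        worker: Callable,
        # In the real code this would be Optional[ParametrizedSweep].
        worker_input_cfg: Optional[object] = None,
    ) -> None:
        ...
```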
### Comment 6
<location> `test/test_result_collector.py:15` </location>
<code_context>
+from bencher.bench_cfg import BenchCfg
+
+
+class TestResultCollector(unittest.TestCase):
+ """Tests for ResultCollector extracted from Bench."""
+
</code_context>
<issue_to_address>
**suggestion (testing):** Add tests for ResultCollector.store_results to cover all supported result variable types and error paths.
Current tests cover only dataset setup and meta-var creation, not `store_results`, where most of the behavior around `XARRAY_MULTIDIM_RESULT_TYPES`, `ResultVec`, `ResultReference`, `ResultDataSet`, and hmaps resides. Please add tests that:
- Use a small `BenchCfg` mixing `ResultVar`, `ResultVec`, `ResultReference`, `ResultDataSet`, and (if available) `ResultHmap`.
- Build a `BenchResult` via `setup_dataset` and exercise `store_results` with a fake `JobFuture` returning (a) a dict and (b) a param-based object.
- Verify that scalars/vectors land in the correct xarray indices, `dataset_list`/`object_index` are updated and indexed correctly, hmaps use the expected `canonical_input` keys, and unsupported result types raise `RuntimeError`.
This will validate that the extracted collector preserves the original indexing and error-handling semantics.
Suggested implementation:
```python
import unittest
from datetime import datetime
import numpy as np
import xarray as xr
from hypothesis import given, settings, strategies as st
from bencher.example.benchmark_data import ExampleBenchCfg
from bencher.result_collector import ResultCollector
from bencher.bench_cfg import (
BenchCfg,
ResultVar,
ResultVec,
ResultReference,
ResultDataSet,
)
try:
# ResultHmap may be optional in some configurations
from bencher.bench_cfg import ResultHmap
except ImportError: # pragma: no cover - only for older versions without ResultHmap
ResultHmap = None
class FakeJobFuture:
"""Minimal stand‑in for the JobFuture used by ResultCollector.store_results."""
def __init__(self, params, result):
# params: param-based object/dict describing the input point
# result: dict or object returned by the benchmark
self.params = params
self.result = result
class ParamObject:
"""Param-based object whose attributes are used as result fields."""
def __init__(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class TestResultCollector(unittest.TestCase):
"""Tests for ResultCollector extracted from Bench, including store_results."""
def setUp(self) -> None:
# A tiny BenchCfg mixing the supported result types
result_vars = [
ResultVar("scalar_metric", int, description="scalar result"),
ResultVec("vector_metric", float, length=3, description="vector result"),
ResultReference("ref_metric", description="reference/object-indexed result"),
ResultDataSet(
"dataset_metric",
description="dataset stored in dataset_list / object_index",
),
]
if ResultHmap is not None:
result_vars.append(
ResultHmap(
"hmap_metric",
description="hmap result stored by canonical_input",
)
)
# Minimal BenchCfg – use simple integer sweep so coordinates are predictable
self.cfg = BenchCfg(
name="store_results_test",
input_vars=[("x", [0, 1])],
result_vars=result_vars,
)
self.collector = ResultCollector()
# NOTE: setup_dataset's real signature is setup_dataset(bench_cfg, time_src);
# adjust this call to match your API (see the notes after these tests).
self.bench_result = self.collector.setup_dataset(self.cfg)
def _assert_common_scalar_vector(self, idx):
ds = self.bench_result.data
# Scalar is stored at the current index along the "x" dimension
self.assertIn("scalar_metric", ds)
self.assertEqual(ds["scalar_metric"].sel(x=idx).item(), idx)
# Vector is stored as an array at the same index
self.assertIn("vector_metric", ds)
np.testing.assert_allclose(
ds["vector_metric"].sel(x=idx).values,
np.array([idx, idx + 1, idx + 2], dtype=float),
)
def test_store_results_with_dict_future(self):
"""store_results should handle a dict result and update xarray + object index."""
idx = 0
coord_dict = {"x": idx}
result_payload = {
"scalar_metric": idx,
"vector_metric": [idx, idx + 1, idx + 2],
"ref_metric": f"ref-{idx}",
"dataset_metric": xr.Dataset(
{"payload": (("dim",), np.array([1.0, 2.0, 3.0]))}
),
}
if ResultHmap is not None:
# canonical_input-based hmap key should be derived from params
result_payload["hmap_metric"] = {"a": 1, "b": 2}
future = FakeJobFuture(params={"x": idx}, result=result_payload)
self.collector.store_results(self.bench_result, future, coord_dict)
# Scalars/vectors in xarray
self._assert_common_scalar_vector(idx)
# Reference & dataset use dataset_list/object_index mapping
object_index = self.bench_result.object_index
dataset_list = self.bench_result.dataset_list
# ref_metric should map to an object index pointing at the stored object
self.assertIn("ref_metric", object_index)
ref_index = object_index["ref_metric"].sel(x=idx).item()
self.assertIsInstance(ref_index, (int, np.integer))
self.assertEqual(dataset_list[ref_index], "ref-0")
# dataset_metric similarly stored in dataset_list
self.assertIn("dataset_metric", object_index)
ds_index = object_index["dataset_metric"].sel(x=idx).item()
stored_ds = dataset_list[ds_index]
self.assertIsInstance(stored_ds, xr.Dataset)
np.testing.assert_allclose(
stored_ds["payload"].values, np.array([1.0, 2.0, 3.0])
)
if ResultHmap is not None:
# hmap_metric should be indexed by canonical_input derived from params
self.assertIn("hmap_metric", self.bench_result.hmaps)
hmap = self.bench_result.hmaps["hmap_metric"]
# We expect a single entry keyed by the canonicalized params
self.assertEqual(len(hmap), 1)
# Value should match the dict we provided
self.assertEqual(next(iter(hmap.values())), {"a": 1, "b": 2})
def test_store_results_with_object_future(self):
"""store_results should handle an object result with attributes as fields."""
idx = 1
coord_dict = {"x": idx}
# Param-based object for results
result_obj = ParamObject(
scalar_metric=idx,
vector_metric=[idx, idx + 1, idx + 2],
ref_metric=f"ref-{idx}",
dataset_metric=xr.Dataset(
{"payload": (("dim",), np.array([4.0, 5.0]))}
),
)
if ResultHmap is not None:
setattr(result_obj, "hmap_metric", {"c": 3})
future = FakeJobFuture(params=ParamObject(x=idx), result=result_obj)
self.collector.store_results(self.bench_result, future, coord_dict)
# Scalars/vectors in xarray
self._assert_common_scalar_vector(idx)
object_index = self.bench_result.object_index
dataset_list = self.bench_result.dataset_list
# ref_metric uses object_index mapping
ref_index = object_index["ref_metric"].sel(x=idx).item()
self.assertEqual(dataset_list[ref_index], f"ref-{idx}")
# dataset_metric stored in dataset_list
ds_index = object_index["dataset_metric"].sel(x=idx).item()
stored_ds = dataset_list[ds_index]
np.testing.assert_allclose(
stored_ds["payload"].values, np.array([4.0, 5.0])
)
if ResultHmap is not None:
hmap = self.bench_result.hmaps["hmap_metric"]
self.assertEqual(len(hmap), 1)  # setUp rebuilds the collector, so only this entry
self.assertIn({"c": 3}, hmap.values())
def test_store_results_unsupported_type_raises(self):
"""Unsupported result variable types should raise RuntimeError."""
idx = 0
coord_dict = {"x": idx}
class Unsupported:
pass
result_payload = {"scalar_metric": Unsupported()}
future = FakeJobFuture(params={"x": idx}, result=result_payload)
with self.assertRaises(RuntimeError):
self.collector.store_results(self.bench_result, future, coord_dict)
```
These tests assume the following APIs, which you should verify and adjust to match your codebase:
1. `BenchCfg` accepts `input_vars` as a list of `(name, values)` pairs and `result_vars` as a list of `ResultVar`/`ResultVec`/`ResultReference`/`ResultDataSet`/`ResultHmap` instances. If your constructor differs, adapt the `self.cfg = BenchCfg(...)` call in `setUp`.
2. `ResultVar`, `ResultVec`, `ResultReference`, `ResultDataSet`, and `ResultHmap` are imported from `bencher.bench_cfg`. If they live in a different module, update the import section accordingly.
3. `ResultCollector.setup_dataset()` is assumed to return an object with `data` (an `xarray.Dataset`), `object_index` (xarray-based index for object-backed results), `dataset_list` (Python list backing those indices), and `hmaps` (dict storing hmap results by canonical input). If your `BenchResult` naming is different, adjust attribute access in the tests.
4. `ResultCollector.store_results(bench_result, future, coord_dict)` is assumed to:
- Read the raw result from `future.result` (either a dict or an object whose attributes are the result variables).
- Use `future.params` to derive the canonical input key for `ResultHmap`.
- Index into `bench_result.data` using `coord_dict` to place scalar/vector results.
- Update `bench_result.dataset_list` and `bench_result.object_index` for `ResultReference` and `ResultDataSet`.
- Raise `RuntimeError` on unsupported result types.
If your `JobFuture` exposes different attributes or the argument order for `store_results` differs, adjust the `FakeJobFuture` class and the calls in the three tests.
5. The expected shapes and coordinates for the xarray variables (`"x"` dimension, indexing via `.sel(x=idx)`) are based on a simple integer input variable; if your coordinate names or types differ, update the `coord_dict` and the `.sel(...)` accessors accordingly.
</issue_to_address>
### Comment 7
<location> `test/test_result_collector.py:31` </location>
<code_context>
+ collector = ResultCollector(cache_size=int(50e9))
+ self.assertEqual(collector.cache_size, int(50e9))
+
+ def test_setup_dataset_creates_bench_result(self):
+ """Test xarray dataset has correct structure."""
+ instance = ExampleBenchCfg()
</code_context>
<issue_to_address>
**suggestion (testing):** Add tests for cache_results and load_history_cache to validate caching semantics and over_time behavior.
These methods encapsulate important behavior (diskcache usage, `object_index` strip/restore, writing benchmark-hash lists, and `over_time` concatenation) but aren’t covered by tests.
Please add tests that:
- Call `cache_results` twice with different `bench_cfg_hash` values and verify via `Cache` that:
- `BenchResult` is stored without `object_index`, but remains present in-memory.
- The key for `bench_res.bench_cfg.bench_name` stores the expected list of hashes.
- Cover `load_history_cache` for:
- `clear_history=True` (no concat; previous history removed/overwritten).
- `clear_history=False` with and without an existing entry, ensuring `xr.concat` is only used when a prior dataset exists.
This will help ensure the extracted collector preserves the original persistence/history behavior of `Bench`.
Suggested implementation:
```python
def test_init_custom_cache_size(self):
"""Test custom cache size is set."""
collector = ResultCollector(cache_size=int(50e9))
self.assertEqual(collector.cache_size, int(50e9))
def test_cache_results_persists_without_object_index_and_updates_hash_list(self):
"""cache_results should strip object_index in cache, but keep it in memory and update hash list."""
# Arrange: create a simple bench result
instance = ExampleBenchCfg()
bench_cfg = BenchCfg(
input_vars=[instance.param.theta],
result_vars=[instance.param.out_sin],
const_vars=[],
bench_name="test_bench_cache",
title="test",
repeats=2,
)
bench_res, func_inputs, dims_name = self.collector.setup_dataset(
bench_cfg=bench_cfg,
func=instance.bench,
)
# Simulate an object_index present only in the in-memory bench_res
bench_res.object_index = ["obj-1", "obj-2"]
bench_cfg_hash_1 = "hash-1"
bench_cfg_hash_2 = "hash-2"
# Act: cache results twice with different hashes
self.collector.cache_results(bench_res, bench_cfg_hash_1)
self.collector.cache_results(bench_res, bench_cfg_hash_2)
cache = self.collector.cache
# Assert: in-memory bench_res still has object_index
self.assertEqual(bench_res.object_index, ["obj-1", "obj-2"])
# Assert: cached bench results do not retain object_index
stored_res_1 = cache[bench_cfg_hash_1]
stored_res_2 = cache[bench_cfg_hash_2]
# Stored results must not have object_index, or it must be cleared
self.assertTrue(
not hasattr(stored_res_1, "object_index")
or stored_res_1.object_index in (None, [], {})
)
self.assertTrue(
not hasattr(stored_res_2, "object_index")
or stored_res_2.object_index in (None, [], {})
)
# Assert: the bench_name key stores the expected list of hashes
bench_name_key = bench_res.bench_cfg.bench_name
self.assertIn(bench_name_key, cache)
self.assertEqual(cache[bench_name_key], [bench_cfg_hash_1, bench_cfg_hash_2])
def test_load_history_cache_clear_history_overwrites_without_concat(self):
"""load_history_cache with clear_history=True should not concat and should overwrite history."""
import sys
from unittest import mock
instance = ExampleBenchCfg()
bench_cfg = BenchCfg(
input_vars=[instance.param.theta],
result_vars=[instance.param.out_sin],
const_vars=[],
bench_name="test_bench_history_clear",
title="test",
repeats=2,
)
bench_res, func_inputs, dims_name = self.collector.setup_dataset(
bench_cfg=bench_cfg,
func=instance.bench,
)
bench_cfg_hash = "hash-clear"
# Simulate an existing history entry in cache for this bench
cache = self.collector.cache
existing_history_ds = bench_res.results # or similar dataset-like attribute
cache_key_history = bench_res.bench_cfg.bench_name
cache[cache_key_history] = [bench_cfg_hash]
cache[bench_cfg_hash] = existing_history_ds
# Patch xr.concat on the module that defines ResultCollector
rc_module = sys.modules[self.collector.__class__.__module__]
with mock.patch.object(rc_module.xr, "concat", wraps=rc_module.xr.concat) as mock_concat:
# Act
self.collector.load_history_cache(
bench_res=bench_res,
bench_cfg_hash=bench_cfg_hash,
clear_history=True,
)
# Assert: no concatenation should have been attempted
mock_concat.assert_not_called()
# Assert: cache history for this bench name should have been overwritten, not extended
hashes_after = cache[cache_key_history]
self.assertEqual(hashes_after, [bench_cfg_hash])
def test_load_history_cache_appends_and_concats_when_history_exists(self):
"""load_history_cache with clear_history=False should concat when previous history exists."""
import sys
from unittest import mock
instance = ExampleBenchCfg()
bench_cfg = BenchCfg(
input_vars=[instance.param.theta],
result_vars=[instance.param.out_sin],
const_vars=[],
bench_name="test_bench_history_concat",
title="test",
repeats=2,
)
bench_res, func_inputs, dims_name = self.collector.setup_dataset(
bench_cfg=bench_cfg,
func=instance.bench,
)
bench_cfg_hash_old = "hash-old"
bench_cfg_hash_new = "hash-new"
cache = self.collector.cache
cache_key_history = bench_res.bench_cfg.bench_name
# Simulate an existing dataset in cache and a previous hash list
old_ds = bench_res.results # or similar dataset-like attribute
cache[cache_key_history] = [bench_cfg_hash_old]
cache[bench_cfg_hash_old] = old_ds
# Patch xr.concat on the module that defines ResultCollector
rc_module = sys.modules[self.collector.__class__.__module__]
with mock.patch.object(rc_module.xr, "concat", wraps=rc_module.xr.concat) as mock_concat:
# Act
self.collector.load_history_cache(
bench_res=bench_res,
bench_cfg_hash=bench_cfg_hash_new,
clear_history=False,
)
# Assert: concat should have been called because there was an existing entry
mock_concat.assert_called()
# Assert: history list for this bench name should include both old and new hashes
hashes_after = cache[cache_key_history]
self.assertEqual(hashes_after, [bench_cfg_hash_old, bench_cfg_hash_new])
def test_load_history_cache_no_concat_when_no_existing_history(self):
"""load_history_cache with clear_history=False should not concat if there is no existing history."""
import sys
from unittest import mock
instance = ExampleBenchCfg()
bench_cfg = BenchCfg(
input_vars=[instance.param.theta],
result_vars=[instance.param.out_sin],
const_vars=[],
bench_name="test_bench_history_none",
title="test",
repeats=2,
)
bench_res, func_inputs, dims_name = self.collector.setup_dataset(
bench_cfg=bench_cfg,
func=instance.bench,
)
bench_cfg_hash = "hash-new-only"
cache = self.collector.cache
cache_key_history = bench_res.bench_cfg.bench_name
# Ensure no history exists for this bench
if cache_key_history in cache:
del cache[cache_key_history]
rc_module = sys.modules[self.collector.__class__.__module__]
with mock.patch.object(rc_module.xr, "concat", wraps=rc_module.xr.concat) as mock_concat:
# Act
self.collector.load_history_cache(
bench_res=bench_res,
bench_cfg_hash=bench_cfg_hash,
clear_history=False,
)
# Assert: concat should not be called when no previous history exists
mock_concat.assert_not_called()
# Assert: history list should contain only the new hash
self.assertIn(cache_key_history, cache)
self.assertEqual(cache[cache_key_history], [bench_cfg_hash])
def test_setup_dataset_creates_bench_result(self):
```
These tests assume the following about your implementation, and you may need to adjust them to match the concrete API:
1. **ResultCollector API**
- `self.collector` is available in the test class (likely set up in `setUp`) and exposes:
- `self.collector.cache` as a `diskcache.Cache`-like object (supporting `__getitem__`, `__setitem__`, membership via `in`, etc.).
- `self.collector.setup_dataset(bench_cfg, func)` returning a `bench_res` object and additional values.
- `self.collector.cache_results(bench_res, bench_cfg_hash)` to persist results.
- `self.collector.load_history_cache(bench_res, bench_cfg_hash, clear_history: bool)` to restore/append history.
2. **BenchResult / BenchCfg layout**
- `bench_res.bench_cfg.bench_name` exists and is the key used to store the list of hashes.
- `bench_res.object_index` exists and is the attribute that gets stripped before caching.
- `bench_res.results` (or equivalent) is the xarray `Dataset` that is persisted to / restored from cache. If you use a different attribute name (e.g. `bench_res.ds`, `bench_res.dataset`), update the tests accordingly.
3. **Cache keying**
- The tests assume that the `bench_res` (or its dataset) is stored directly under the key `bench_cfg_hash` and that the list of hashes is stored under the key `bench_res.bench_cfg.bench_name`. If your implementation uses namespaced keys (e.g. `"bench_result:{hash}"`, `"history:{bench_name}"`), adapt the key construction in the tests to match.
4. **xarray concat patching**
- The tests dynamically patch `xr.concat` on the module that defines `ResultCollector` using:
```python
rc_module = sys.modules[self.collector.__class__.__module__]
with mock.patch.object(rc_module.xr, "concat", wraps=rc_module.xr.concat) as mock_concat:
...
```
This assumes that the collector module does `import xarray as xr`. If your import is aliased differently or accessed via another name, adjust the patch target accordingly.
5. **Imports**
- Ensure at the top of `test/test_result_collector.py` you have:
```python
import sys
from unittest import mock
```
or remove the inline imports inside the test methods and centralize them at module scope, per your project’s style.
6. **Dimension / combining semantics**
- If `load_history_cache` concatenates along a specific dimension (e.g. `"over_time"` or `"history"`), and the concatenation behavior affects a particular attribute on `bench_res`, you may want to strengthen the assertions to inspect the resulting dataset (e.g. checking the length of the concatenated dimension) in addition to just verifying `xr.concat` is called. Adjust those assertions once you align the attribute and dimension names with your implementation.
</issue_to_address>
### Comment 8
<location> `test/test_result_collector.py:123` </location>
<code_context>
+ self.assertEqual(len(extra_vars), 2)
+ self.assertEqual(extra_vars[1].name, "over_time")
+
+ def test_report_results_no_print(self):
+ """Test report_results with printing disabled."""
+ instance = ExampleBenchCfg()
</code_context>
<issue_to_address>
**suggestion (testing):** Add tests for add_metadata_to_dataset to ensure xarray attributes are populated correctly for scalar and vector results.
`ResultCollector.add_metadata_to_dataset` isn’t covered by tests but controls the units/long_name/description attributes used by visualization and downstream consumers.
Please add tests that:
- Build a `BenchResult` (via `setup_dataset`) from a `BenchCfg` with one `ResultVar` and one `ResultVec`, then call `add_metadata_to_dataset` for a chosen input variable.
- Assert scalar result variables have the expected `units` and `long_name` in `ds[rv.name].attrs`.
- Assert each index of vector results (`rv.index_name(i)`) inherits the parent vector’s `units` and `long_name`.
- Verify the input variable coordinate gets `long_name`, `units`, and `description` (from `__doc__`).
Suggested implementation:
```python
def test_add_metadata_to_dataset_scalar_and_vector(self):
"""Ensure add_metadata_to_dataset populates xarray attrs for scalar and vector results."""
instance = ExampleBenchCfg()
# Bench configuration with one scalar result and one vector result
bench_cfg = BenchCfg(
input_vars=[instance.param.theta],
# NOTE: out_vec should be replaced with the actual vector ResultVar in ExampleBenchCfg
result_vars=[instance.param.out_sin, instance.param.out_vec],
const_vars=[],
bench_name="test_add_metadata",
title="test_add_metadata",
repeats=1,
)
bench_res, func_inputs, dims_name = self.collector.setup_dataset(
bench_cfg,
instance.run_bench,
)
# Call add_metadata_to_dataset for the chosen input variable
input_var = instance.param.theta
ds = bench_res.ds
# NOTE: adjust the call signature if add_metadata_to_dataset differs
self.collector.add_metadata_to_dataset(ds, bench_cfg, input_var)
# Scalar result attributes
scalar_rv = instance.param.out_sin
self.assertEqual(
ds[scalar_rv.name].attrs.get("units"),
scalar_rv.units,
)
self.assertEqual(
ds[scalar_rv.name].attrs.get("long_name"),
scalar_rv.long_name,
)
# Vector result attributes: each index inherits parent's metadata
vector_rv = instance.param.out_vec
# NOTE: adjust iteration according to your ResultVec API
for i in range(vector_rv.size):
var_name = vector_rv.index_name(i)
self.assertEqual(
ds[var_name].attrs.get("units"),
vector_rv.units,
)
self.assertEqual(
ds[var_name].attrs.get("long_name"),
vector_rv.long_name,
)
# Input coordinate attributes
coord_name = input_var.name
self.assertEqual(
ds[coord_name].attrs.get("long_name"),
input_var.long_name,
)
self.assertEqual(
ds[coord_name].attrs.get("units"),
input_var.units,
)
self.assertEqual(
ds[coord_name].attrs.get("description"),
input_var.__doc__,
)
```
The edit above shows the new test body, but you will need to integrate it correctly into the file:
1. **Place the new test at class level**
- Ensure `test_add_metadata_to_dataset_scalar_and_vector` is indented at the same level as the other `test_*` methods inside the `TestResultCollector` (or equivalent) test class.
- Do **not** insert it in the middle of an existing method or inside a function call; move it outside `test_report_results_no_print` and after that method’s closing line.
2. **Use the actual vector `ResultVar`**
- Replace `instance.param.out_vec` with the real vector result in `ExampleBenchCfg` (e.g., `instance.param.out_vec`, `instance.param.out_sin_vec`, etc.).
- Update the loop over the vector variable:
- If your vector type exposes length as `len(vector_rv)` or `vector_rv.n_indices`, change `range(vector_rv.size)` accordingly.
- Ensure `vector_rv.index_name(i)` matches the actual API that produces each component variable name in the dataset.
3. **Match `add_metadata_to_dataset`’s actual signature**
- Adjust the call `self.collector.add_metadata_to_dataset(ds, bench_cfg, input_var)` to match the real method signature.
- For example, if it expects `(bench_res, bench_cfg, input_var)` or `(ds, bench_cfg, input_var.name)`, pass those instead.
4. **Align attribute names with your data model**
- If your `ResultVar` / `ResultVec` classes use different attributes for metadata (e.g. `unit` instead of `units`, `label` instead of `long_name`), update the assertions:
- `scalar_rv.units` / `scalar_rv.long_name`
- `vector_rv.units` / `vector_rv.long_name`
- Likewise, if the input variable’s metadata is exposed differently from `input_var.long_name`, `input_var.units`, or its docstring, adapt the assertions so they correctly reflect how `add_metadata_to_dataset` maps metadata into `ds[...].attrs`.
5. **Ensure the coordinate name is correct**
- If the dataset uses a derived coordinate name instead of `input_var.name`, adjust `coord_name = input_var.name` to match whatever `setup_dataset` uses as the dimension/coordinate for that input variable.
Once these adjustments are made, the new test will validate that:
- Scalar result variables receive `units` and `long_name` attrs.
- Each component of a vector result inherits `units` and `long_name`.
- The chosen input coordinate gets `long_name`, `units`, and `description` (from the variable’s docstring or equivalent) populated via `add_metadata_to_dataset`.
</issue_to_address>
### Comment 9
<location> `test/test_sweep_executor.py:13` </location>
<code_context>
+from bencher.job import Executors
+
+
+class TestSweepExecutor(unittest.TestCase):
+ """Tests for SweepExecutor extracted from Bench."""
+
</code_context>
<issue_to_address>
**suggestion (testing):** Extend SweepExecutor tests to cover max_level handling, tag clearing, and cache stats.
Current tests don't exercise some important branches:
- `convert_vars_to_params` with a dict containing `max_level` and a non-`None` `run_cfg.level` to assert the `with_level` branch and correct level application.
- `clear_tag_from_sample_cache` when `sample_cache` is initially `None`, verifying it implicitly calls `init_sample_cache` and then removes entries with the tag (using a small fake `BenchRunCfg` and checking the underlying cache).
- `get_cache_stats` when a cache is present, confirming it returns a non-empty string and does not mutate the cache.
Covering these cases will better validate edge configurations and cache lifecycle handling.
Suggested implementation:
```python
import unittest
from hypothesis import given, settings, strategies as st
from bencher.example.benchmark_data import ExampleBenchCfg
from bencher.sweep_executor import SweepExecutor, worker_kwargs_wrapper
from bencher.bench_cfg import BenchCfg, BenchRunCfg
from bencher.job import Executors
class TestSweepExecutor(unittest.TestCase):
"""Tests for SweepExecutor extracted from Bench."""
def setUp(self) -> None:
# NOTE: Depending on the real ExampleBenchCfg / BenchCfg API this may need
# to be adapted (see the notes after this code block).
example_cfg = ExampleBenchCfg()
if isinstance(example_cfg, BenchCfg):
self.bench_cfg = example_cfg
elif hasattr(example_cfg, "bench_cfg"):
self.bench_cfg = example_cfg.bench_cfg
elif hasattr(example_cfg, "to_bench_cfg"):
self.bench_cfg = example_cfg.to_bench_cfg()
else:
# Fallback: allow the test to be adjusted to the actual API.
self.bench_cfg = BenchCfg()
# Prefer a simple/local executor to avoid external dependencies.
self.executor = SweepExecutor(self.bench_cfg, executor=Executors.LOCAL)
def test_convert_vars_to_params_respects_max_level_and_with_level(self) -> None:
"""convert_vars_to_params should apply max_level when with_level branch is used."""
# run_cfg has a non-None level to exercise the with_level branch.
# max_level in vars should cap or otherwise influence the applied level.
run_cfg = BenchRunCfg(level=3)
vars_dict = {"max_level": 1}
params = self.executor.convert_vars_to_params(vars_dict, run_cfg)
# The exact structure of params depends on SweepExecutor.
# We assert that level-related information reflects max_level logic.
# Adjust these assertions to match the real structure.
self.assertIsNotNone(params)
# Common patterns are either `params.level` or `params["level"]`.
level = getattr(params, "level", params.get("level", None) if isinstance(params, dict) else None)
self.assertIsNotNone(
level,
"convert_vars_to_params should propagate a level that can be inspected in tests",
)
# Whatever the underlying logic, max_level=1 should not result in a level > 1.
self.assertLessEqual(
level,
1,
"max_level in vars should cap the effective level used in parameters",
)
def test_clear_tag_from_sample_cache_initializes_and_clears_tag(self) -> None:
"""clear_tag_from_sample_cache should lazily initialize cache and remove entries for a tag."""
# Ensure we start from a None / empty cache state.
# Some implementations may use None, others an empty dict; normalize to None for the test.
if hasattr(self.executor, "sample_cache"):
self.executor.sample_cache = None
# Create a small BenchRunCfg with a tag to be cleared.
run_cfg = BenchRunCfg(tag="test_tag")
# Call the method; this should implicitly initialize the cache and then clear the tag.
self.executor.clear_tag_from_sample_cache("test_tag", run_cfg)
# After clearing, cache should exist.
self.assertIsNotNone(
getattr(self.executor, "sample_cache", None),
"clear_tag_from_sample_cache should initialize sample_cache when it is None",
)
cache = self.executor.sample_cache
# Depending on implementation, the cache might be a dict keyed by tag, or a structure
# with entries that have a `.tag` attribute. Here we check a couple of common patterns.
if isinstance(cache, dict):
# No entries should remain for the cleared tag.
for key in cache.keys():
self.assertNotEqual(
key,
"test_tag",
"Entries for the cleared tag should be removed from sample_cache",
)
elif isinstance(cache, (list, tuple, set)):
for entry in cache:
tag = getattr(entry, "tag", None)
self.assertNotEqual(
tag,
"test_tag",
"Entries for the cleared tag should be removed from sample_cache",
)
def test_get_cache_stats_returns_string_and_does_not_mutate_cache(self) -> None:
"""get_cache_stats should report on an existing cache without mutating it."""
# Ensure there is a cache present. If init_sample_cache exists, prefer it;
# otherwise rely on a side effect from another method.
if hasattr(self.executor, "init_sample_cache"):
self.executor.init_sample_cache()
elif getattr(self.executor, "sample_cache", None) is None:
# Use clear_tag to force lazy initialization when available.
run_cfg = BenchRunCfg(tag="init_tag")
self.executor.clear_tag_from_sample_cache("init_tag", run_cfg)
cache_before = getattr(self.executor, "sample_cache", None)
# Some caches are sized via len(), some via a dedicated size attribute.
if cache_before is not None:
if hasattr(cache_before, "__len__"):
size_before = len(cache_before)
else:
size_before = getattr(cache_before, "size", None)
else:
size_before = None
stats = self.executor.get_cache_stats()
self.assertIsInstance(stats, str)
self.assertNotEqual(
stats.strip(),
"",
"get_cache_stats should return a non-empty description when a cache is present",
)
cache_after = getattr(self.executor, "sample_cache", None)
# Cache object should not be replaced as a side effect.
self.assertIs(
cache_after,
cache_before,
"get_cache_stats should not replace the cache object",
)
# And its size should not be mutated by querying stats.
if cache_after is not None and hasattr(cache_after, "__len__") and size_before is not None:
self.assertEqual(
len(cache_after),
size_before,
"get_cache_stats should not mutate the cache size",
)
```
The above tests assume certain APIs and structures that may differ slightly from your actual codebase. To fully integrate them, you may need to:
1. **Instantiate `ExampleBenchCfg` / `BenchCfg` correctly**
- If `ExampleBenchCfg` is not itself a `BenchCfg`, but provides a factory (e.g. `ExampleBenchCfg().to_bench_cfg()` or a `bench_cfg` attribute), update the `setUp` method to use the correct pattern and remove the fallback `BenchCfg()` construction.
- If `BenchCfg` requires constructor arguments, provide valid ones according to your existing tests.
2. **Adjust `SweepExecutor` construction**
- Ensure that `SweepExecutor(self.bench_cfg, executor=Executors.LOCAL)` matches the real signature.
- If the executor enum or default executor differs, adapt `Executors.LOCAL` to the appropriate value used elsewhere in your tests.
3. **Align with the actual parameter structure of `convert_vars_to_params`**
- Inspect the return type of `convert_vars_to_params` and replace the generic level extraction logic with direct, type-safe assertions. For example:
- If it returns a dict: `self.assertEqual(params["level"], 1)`
- If it returns a dataclass: `self.assertEqual(params.level, 1)`
- Confirm that the `max_level` behavior is indeed capping the level and adjust the expected comparison (`<= 1`) to match the true semantics.
4. **Match your `sample_cache` structure for tag clearing**
- Check how `sample_cache` is structured and how tags are stored:
- If it is a dict keyed by tag, assert on `self.assertNotIn("test_tag", self.executor.sample_cache)`.
- If entries are objects with a `.tag` attribute, iterate accordingly and assert that no remaining entry has `tag == "test_tag"`.
- If `sample_cache` is not `None`-based but uses a different sentinel for uninitialized state, adapt the initial normalization in the test.
5. **Confirm cache initialization entry points**
- If there is a dedicated public method to initialize the cache (e.g. `init_sample_cache()` or part of running a sweep), prefer using that in both `test_clear_tag_from_sample_cache_initializes_and_clears_tag` and `test_get_cache_stats_returns_string_and_does_not_mutate_cache` to avoid depending on internal details.
6. **Refine `get_cache_stats` non-mutation checks**
- If the cache size is available via a specific API (e.g. `self.executor.sample_cache.size()` or `self.executor.get_cache_size()`), use that instead of `len(cache)` or `.size` attribute so the test aligns with your implementation.
</issue_to_address>
### Comment 10
<location> `test/test_worker_manager.py:10` </location>
<code_context>
+from bencher.worker_manager import WorkerManager, worker_cfg_wrapper, kwargs_to_input_cfg
+
+
+class TestWorkerManager(unittest.TestCase):
+ """Tests for WorkerManager extracted from Bench."""
+
</code_context>
<issue_to_address>
**suggestion (testing):** Consider adding a small integration-style test that verifies WorkerManager interacts correctly with SweepExecutor/ResultCollector via Bench.
Since this refactor moves Bench’s responsibilities into WorkerManager, SweepExecutor, and ResultCollector, it would be helpful to add a higher-level Bench test that:
- Builds a Bench with an ExampleBenchCfg worker
- Runs a minimal sweep
- Verifies key Bench attributes (`worker`, `worker_class_instance`, `worker_input_cfg`, `sample_cache`, `ds_dynamic`) and basic BenchResult invariants still match the previous behavior
This regression test would complement the existing WorkerManager unit tests and ensure the external Bench API remains intact after the decomposition.
Suggested implementation:
```python
"""Tests for WorkerManager extracted from Bench."""
import unittest
from hypothesis import given, settings, strategies as st
from bencher.example.benchmark_data import ExampleBenchCfg
from bencher.worker_manager import WorkerManager, worker_cfg_wrapper, kwargs_to_input_cfg
# Integration-style imports to exercise WorkerManager via Bench/SweepExecutor/ResultCollector.
# These module paths may need to be adjusted to match the actual project layout.
from bencher.bench import Bench
from bencher.sweep import SweepExecutor
from bencher.result import ResultCollector, BenchResult
class TestWorkerManager(unittest.TestCase):
"""Tests for WorkerManager extracted from Bench."""
class TestBenchWorkerManagerIntegration(unittest.TestCase):
"""Higher-level regression tests to ensure Bench still exposes the expected API
while delegating to WorkerManager, SweepExecutor, and ResultCollector.
"""
def test_bench_minimal_sweep_preserves_public_api(self) -> None:
"""Build a Bench with ExampleBenchCfg, run a minimal sweep, and verify
the core attributes and BenchResult invariants remain intact.
"""
# Arrange: create a worker config and Bench instance
worker_cfg = ExampleBenchCfg()
bench = Bench("worker_manager_integration", worker_cfg)
# Act: run a minimal sweep through Bench, which should internally use
# WorkerManager, SweepExecutor, and ResultCollector.
# The exact arguments may need to be adapted to the real Bench API.
result = bench.run(num_samples=1)
# Assert: external Bench-facing attributes are still present and coherent
# Worker-related attributes exposed by Bench
self.assertIs(bench.worker, worker_cfg)
self.assertIsNotNone(bench.worker_class_instance)
self.assertIsNotNone(bench.worker_input_cfg)
# Internal caching / dataset attributes expected on Bench
self.assertTrue(hasattr(bench, "sample_cache"))
self.assertIsNotNone(bench.sample_cache)
self.assertTrue(hasattr(bench, "ds_dynamic"))
self.assertIsNotNone(bench.ds_dynamic)
# Ensure Bench is wired to the right orchestration components
self.assertTrue(hasattr(bench, "worker_manager"))
self.assertIsInstance(bench.worker_manager, WorkerManager)
# If Bench exposes a sweep executor / result collector, validate them
if hasattr(bench, "sweep_executor"):
self.assertIsInstance(bench.sweep_executor, SweepExecutor)
if hasattr(bench, "result_collector"):
self.assertIsInstance(bench.result_collector, ResultCollector)
# Basic BenchResult invariants: type and non-empty content for a minimal run
self.assertIsInstance(result, BenchResult)
self.assertGreater(len(result.samples), 0, "BenchResult should contain at least one sample")
```
The integration test uses several assumptions about your public API which you may need to adjust:
1. **Bench import & constructor**
- Update `from bencher.bench import Bench` and `bench = Bench("worker_manager_integration", worker_cfg)` to match your actual Bench module path and constructor signature.
2. **SweepExecutor / ResultCollector / BenchResult imports**
- Change `from bencher.sweep import SweepExecutor` and `from bencher.result import ResultCollector, BenchResult` to the real module locations and class names.
3. **Bench.run semantics**
- Replace `bench.run(num_samples=1)` with the appropriate method and arguments that trigger a minimal sweep (e.g. `bench.run_sweep`, `bench.plot_sweep`, or a context-specific call).
4. **Bench attributes**
- The test currently checks for `worker`, `worker_class_instance`, `worker_input_cfg`, `sample_cache`, `ds_dynamic`, `worker_manager`, `sweep_executor`, and `result_collector`.
- If any of these are exposed via properties or differently named attributes, adjust the assertions accordingly (for example `bench.worker_cfg` instead of `bench.worker`, or `bench.dynamic_ds` instead of `ds_dynamic`).
5. **BenchResult structure**
- The test assumes `BenchResult` has a `samples` attribute that is iterable and non-empty after a minimal run. If your result object uses a different attribute (e.g. `data`, `records`, or `df`), change the final assertion to reflect that.
</issue_to_address>
### Comment 11
<location> `bencher/result_collector.py:42` </location>
<code_context>
handler.setFormatter(formatter)
-def set_xarray_multidim(
- data_array: xr.DataArray, index_tuple: Tuple[int, ...], value: Any
-) -> xr.DataArray:
</code_context>
<issue_to_address>
**issue (complexity):** Consider refactoring ResultCollector to keep helpers internal, avoid configuring global logging here, and extract cache-handling details into small private methods to reduce surface area and cognitive load without changing behavior.
You can reduce complexity and surface area without changing behavior by tightening a few specific spots:
### 1. Make `set_xarray_multidim` an internal method
This helper is only used inside `ResultCollector.store_results` and doesn’t need to be a global. Moving it to a private method reduces global API surface and keeps related logic together.
```python
class ResultCollector:
def __init__(self, cache_size: int = DEFAULT_CACHE_SIZE_BYTES) -> None:
self.cache_size = cache_size
self.ds_dynamic: dict = {}
def _set_xarray_multidim(
self, data_array: xr.DataArray, index_tuple: Tuple[int, ...], value: Any
) -> None:
data_array[index_tuple] = value
def store_results(
self,
job_result: JobFuture,
bench_res: BenchResult,
worker_job: WorkerJob,
bench_run_cfg: BenchRunCfg,
) -> None:
...
for rv in bench_res.bench_cfg.result_vars:
...
if isinstance(rv, XARRAY_MULTIDIM_RESULT_TYPES):
self._set_xarray_multidim(
bench_res.ds[rv.name], worker_job.index_tuple, result_value
)
elif isinstance(rv, ResultDataSet):
bench_res.dataset_list.append(result_value)
self._set_xarray_multidim(
bench_res.ds[rv.name],
worker_job.index_tuple,
len(bench_res.dataset_list) - 1,
)
elif isinstance(rv, ResultReference):
bench_res.object_index.append(result_value)
self._set_xarray_multidim(
bench_res.ds[rv.name],
worker_job.index_tuple,
len(bench_res.object_index) - 1,
)
...
```
You can then remove the top-level `set_xarray_multidim` function entirely.
---
### 2. Avoid configuring global logging in this module
`logging.basicConfig` is cross-cutting and not specific to result collection. It’s better left to the application entry point, so importing this module doesn’t mutate global logging state.
```python
# Remove this from the module
- logging.basicConfig(level="INFO", format="%(levelname)s %(message)s")
```
If you still want a default logger pattern, use a module-level logger instead:
```python
logger = logging.getLogger(__name__)
# And replace calls:
- logging.info(...)
+ logger.info(...)
```
This keeps the class focused on result handling and avoids surprising side effects for callers.
---
### 3. Isolate cache handling into focused helpers
`cache_results` and `load_history_cache` do both “what to cache” and “how to cache” in one method. Extract the diskcache plumbing into small private helpers so the public methods read more declaratively:
```python
class ResultCollector:
    ...

    def _with_cache(self, path: str) -> Cache:
        return Cache(path, size_limit=self.cache_size)

    def cache_results(self, bench_res: BenchResult, bench_cfg_hash: str) -> List[str]:
        bench_cfg_hashes: list[str] = []
        with self._with_cache("cachedir/benchmark_inputs") as c:
            ...
```
Or keep behavior identical but move the “object_index” juggling into a focused helper to reduce cognitive load:
```python
class ResultCollector:
    ...

    def _store_bench_result(self, cache: Cache, key: str, bench_res: BenchResult) -> None:
        obj_index_tmp = bench_res.object_index
        bench_res.object_index = []
        cache[key] = bench_res
        bench_res.object_index = obj_index_tmp

    def cache_results(self, bench_res: BenchResult, bench_cfg_hash: str) -> List[str]:
        bench_cfg_hashes: list[str] = []
        with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c:
            logger.info(f"saving results with key: {bench_cfg_hash}")
            bench_cfg_hashes.append(bench_cfg_hash)
            self._store_bench_result(c, bench_cfg_hash, bench_res)
            logger.info(f"saving benchmark: {bench_res.bench_cfg.bench_name}")
            c[bench_res.bench_cfg.bench_name] = bench_cfg_hashes
        return bench_cfg_hashes
```
This keeps all behavior the same but separates “what gets cached” from “how to make it pickleable”, which makes the class easier to reason about and extend.
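If you adopt `_with_cache`, `load_history_cache` can share it as well; a minimal sketch, assuming the history cache keeps its current path and merge semantics:
```python
def load_history_cache(
    self, dataset: xr.Dataset, bench_cfg_hash: str, clear_history: bool
) -> xr.Dataset:
    # "cachedir/history" is an assumption for illustration; the merge step
    # stays whatever it is today -- only the diskcache plumbing is shared.
    with self._with_cache("cachedir/history") as c:
        if not clear_history and bench_cfg_hash in c:
            ...  # merge cached history into `dataset` as before
        c[bench_cfg_hash] = dataset
    return dataset
```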
</issue_to_address>
### Comment 12
<location> `bencher/sweep_executor.py:20` </location>
<code_context>
+# Default cache size for benchmark results (100 GB)
+DEFAULT_CACHE_SIZE_BYTES = int(100e9)
+
+logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
+
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider refactoring logging setup, helper placement, and parameter handling helpers in SweepExecutor to reduce surface area and simplify the implementation without changing behavior.
You can keep all the new functionality but reduce indirection/surface area with a few small refactors:
1. **Avoid global `logging.basicConfig` here**
Configuring global logging in a helper module makes this module responsible for application‑wide side effects. Move this to the application entry point (or CLI) and just use the logger here.
```python
# sweep_executor.py
import logging

logger = logging.getLogger(__name__)
# remove logging.basicConfig(...)


def worker_kwargs_wrapper(...):
    logger.info("...")  # if you need logging
```
Then configure logging once in the top‑level script:
```python
# main.py / cli entrypoint
import logging
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")
```
2. **Make `worker_kwargs_wrapper` part of `SweepExecutor`**
Since this is sweep/bench specific, keeping it at module level increases surface area. Moving it into `SweepExecutor` (or making it a `@staticmethod`) keeps related logic together and reduces scattering across modules.
```python
class SweepExecutor:
    ...

    @staticmethod
    def worker_kwargs(worker: Callable, bench_cfg: BenchCfg, **kwargs) -> dict:
        function_input_deep = deepcopy(kwargs)
        if not bench_cfg.pass_repeat:
            function_input_deep.pop("repeat", None)
        function_input_deep.pop("over_time", None)
        function_input_deep.pop("time_event", None)
        return worker(**function_input_deep)
```
Call sites then use:
```python
executor = SweepExecutor()
result = executor.worker_kwargs(worker, bench_cfg, **kwargs)
```
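A quick unit test for the filtering behavior (a sketch; the stand-in cfg only needs the one attribute read here):
```python
from types import SimpleNamespace


def test_worker_kwargs_strips_bookkeeping_keys():
    cfg = SimpleNamespace(pass_repeat=False)  # stand-in for BenchCfg
    echoed = SweepExecutor.worker_kwargs(
        lambda **kw: kw, cfg, repeat=3, over_time=False, time_event=None, value=1
    )
    # repeat/over_time/time_event are stripped; user kwargs pass through.
    assert echoed == {"value": 1}
```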
3. **Tighten up `define_const_inputs` to reduce branching/temporary lists**
You can keep behavior identical while simplifying the implementation:
```python
def define_const_inputs(
    self, const_vars: Optional[List[Tuple[param.Parameter, Any]]]
) -> Optional[dict]:
    if not const_vars:
        return None
    return {p.name: value for p, value in const_vars}
```
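For concreteness, a tiny usage sketch of the simplified helper (the `Cfg` class is hypothetical):
```python
import param


class Cfg(param.Parameterized):
    theta = param.Number(0.0)


executor = SweepExecutor()
assert executor.define_const_inputs([(Cfg.param.theta, 1.5)]) == {"theta": 1.5}
assert executor.define_const_inputs(None) is None
assert executor.define_const_inputs([]) is None  # a falsy list also collapses to None
```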
4. **Isolate the dict handling in `convert_vars_to_params`**
Most of the complexity is in the dict case. Pull that into a small focused helper to make the main method easier to read and reuse:
```python
class SweepExecutor:
    ...

    def _build_param_from_dict(
        self,
        variable_cfg: dict,
        run_cfg: Optional[BenchRunCfg],
        worker_class_instance: ParametrizedSweep,
    ) -> param.Parameter:
        param_var = worker_class_instance.param.objects(instance=False)[variable_cfg["name"]]
        if "values" in variable_cfg:
            param_var = param_var.with_sample_values(variable_cfg["values"])
        if "samples" in variable_cfg:
            param_var = param_var.with_samples(variable_cfg["samples"])
        if "max_level" in variable_cfg and run_cfg is not None:
            param_var = param_var.with_level(run_cfg.level, variable_cfg["max_level"])
        return param_var

    def convert_vars_to_params(...):
        if isinstance(variable, str):
            variable = worker_class_instance.param.objects(instance=False)[variable]
        if isinstance(variable, dict):
            variable = self._build_param_from_dict(variable, run_cfg, worker_class_instance)
        if not isinstance(variable, param.Parameter):
            raise TypeError(
                f"You need to use {var_type}_vars =[{worker_input_cfg}.param.your_variable], "
                f"instead of {var_type}_vars =[{worker_input_cfg}.your_variable]"
            )
        return variable
```
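For reference, the dict shapes this path accepts (parameter names and values here are illustrative):
```python
input_var_specs = [
    {"name": "theta"},                        # bare lookup on the worker class
    {"name": "theta", "values": [0.0, 1.0]},  # explicit sample values
    {"name": "theta", "samples": 5},          # sample count
    {"name": "theta", "max_level": 3},        # level-capped sampling; needs run_cfg
]
```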
This keeps the new abstraction but makes `SweepExecutor`’s responsibilities more coherent and the module’s surface area smaller, without changing behavior.
</issue_to_address>
### Comment 13
<location> `bencher/worker_manager.py:17` </location>
<code_context>
- return data_array
-
-
-def kwargs_to_input_cfg(worker_input_cfg: ParametrizedSweep, **kwargs) -> ParametrizedSweep:
- """Create a configured instance of a ParametrizedSweep with the provided keyword arguments.
-
</code_context>
<issue_to_address>
**issue (complexity):** Consider moving the helper functions into `WorkerManager` rather than leaving them at module level.
Moves cache-result list handling into the `Bench` class to better manage the list of benchmark configuration hashes, appending to the list rather than overwriting it. Updates logging to use a logger object instead of `basicConfig`, avoiding conflicts with other modules that use logging. Adds a check that the worker class is an instance, not a type. Adds a test case verifying that cache results are appended and that the object index persists.
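A sketch of that appending behavior, assuming the same `cachedir/benchmark_inputs` layout used by `cache_results` above:
```python
with Cache("cachedir/benchmark_inputs", size_limit=self.cache_size) as c:
    # Extend any previously stored hash list instead of replacing it.
    bench_cfg_hashes = c.get(bench_res.bench_cfg.bench_name, [])
    if bench_cfg_hash not in bench_cfg_hashes:
        bench_cfg_hashes.append(bench_cfg_hash)
    c[bench_res.bench_cfg.bench_name] = bench_cfg_hashes
```
</issue_to_address>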
Codecov Report
❌ Patch coverage is
Additional details and impacted files:
```diff
@@            Coverage Diff             @@
##             main     #687      +/-   ##
==========================================
+ Coverage   89.10%   89.39%   +0.29%
==========================================
  Files          88       91       +3
  Lines        4323     4432     +109
==========================================
+ Hits         3852     3962     +110
+ Misses        471      470       -1
```
blooop left a comment:
Looks good, assuming the comments about potential cache bugs and correctness have been sufficiently resolved. The diff on the docs shows that all the examples still function, though it may not catch some of the subtler cache bugs.
Summary
The Bench class is quite hefty (~1200 lines) and handles many concerns: worker management, parameter sweeps, result storage, caching, and orchestration. This made it difficult to test individual pieces and understand the code flow.
This PR extracts three helper classes: `WorkerManager`, `SweepExecutor`, and `ResultCollector`.
Bench now acts as a thin coordinator (~800 lines), delegating to these helpers while keeping the public API unchanged. The helpers are internal implementation details and are not exported in `__init__.py`. I've added thin wrappers to maintain backward compatibility, with TODO notes marking what can be removed if you decide on a major version bump.
Added 47 new unit tests for the extracted classes, including Hypothesis property-based tests for edge cases.
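As a flavor of the property-based style (a hypothetical example, assuming the `ResultCollector` sketch above, not lifted from the PR's test suite):
```python
import numpy as np
import xarray as xr
from hypothesis import given, strategies as st


@given(st.tuples(st.integers(0, 2), st.integers(0, 3)), st.floats(allow_nan=False))
def test_set_xarray_multidim_writes_one_cell(idx, value):
    da = xr.DataArray(np.zeros((3, 4)))
    ResultCollector()._set_xarray_multidim(da, idx, value)
    assert da.values[idx] == value  # the targeted cell holds the new value
```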
Summary by Sourcery
Decompose the monolithic Bench implementation into smaller helper classes for worker management, sweep execution, and result collection while preserving the public Bench API.