Skip to content

Commit 186a748

Browse files
authored
Merge pull request #652 from blooop/feature/tabulator_result
Feature/tabulator result
2 parents 2cd02a0 + fd57cec commit 186a748

7 files changed

Lines changed: 164 additions & 73 deletions

File tree

bencher/example/inputs_0_float/example_3_cat_in_2_out.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,15 @@ def example_3_cat_in_2_out(
103103
- Note that variance in the results simulates real-world measurement fluctuations
104104
""",
105105
)
106+
107+
res = bench.get_result()
108+
109+
bench.report.append(res.to(bch.BarResult, agg_over_dims=["data_structure", "data_size"]))
110+
bench.report.append(res.to(bch.BarResult, agg_over_dims=["data_structure"]))
111+
bench.report.append(res.to(bch.BarResult, agg_over_dims=["data_size"]))
106112
return bench
107113

108114

109115
if __name__ == "__main__":
110116
br = bch.BenchRunner()
111-
br.add(example_3_cat_in_2_out).run(repeats=5, show=True)
117+
br.add(example_3_cat_in_2_out).run(repeats=1, show=True)

bencher/example/inputs_2_float/example_2_float_1_cat_in_2_out.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def example_2_float_1_cat_in_2_out(
9191
"""
9292
if run_cfg is None:
9393
run_cfg = bch.BenchRunCfg()
94-
run_cfg.repeats = 3 # Fewer repeats for a quicker benchmark
94+
run_cfg.repeats = 1 # Fewer repeats for a quicker benchmark
9595

9696
bench = Pattern1CatBenchmark().to_bench(run_cfg, report)
9797
bench.plot_sweep(
@@ -108,6 +108,20 @@ def example_2_float_1_cat_in_2_out(
108108
res = bench.get_result()
109109

110110
bench.report.append(res.to(bch.HeatmapResult, agg_over_dims=["pattern_type"]))
111+
# bench.report.append(res.to(bch.TabulatorResult, agg_over_dims=["pattern_type", "x_value"]))
112+
# bench.report.append(res.to(bch.TableResult, agg_over_dims=["pattern_type", "x_value", "y_value"]))
113+
114+
bench.report.append(
115+
res.to(bch.TabulatorResult, agg_over_dims=["pattern_type", "x_value", "y_value"])
116+
)
117+
bench.report.append(
118+
res.to(
119+
bch.TabulatorResult,
120+
agg_over_dims=["pattern_type", "x_value"],
121+
)
122+
)
123+
bench.report.append(res.to(bch.TabulatorResult, agg_over_dims=["pattern_type"]))
124+
bench.report.append(res.to(bch.TabulatorResult))
111125

112126
return bench
113127

bencher/results/holoview_results/bar_result.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,25 @@ def to_bar_ds(self, dataset: xr.Dataset, result_var: Parameter = None, **kwargs)
9090
Returns:
9191
hvplot.element.Bars: A bar chart visualization of the benchmark data.
9292
"""
93-
by = None
94-
if self.plt_cnt_cfg.cat_cnt >= 2 and self.plt_cnt_cfg.cat_vars[1].name in dataset.dims:
95-
by = self.plt_cnt_cfg.cat_vars[1].name
96-
93+
# Determine grouping ('by') dynamically based on dims that still exist
9794
da = dataset[result_var.name]
95+
96+
# Allow explicit override via kwargs
97+
by = kwargs.pop("by", None)
98+
if by is None:
99+
# Candidate categorical dims from the original config, filtered to those still present
100+
cat_dim_names = [cv.name for cv in self.plt_cnt_cfg.cat_vars]
101+
dims_present = [d for d in da.dims if d not in ("repeat", "over_time")]
102+
# Prefer categorical dims that are not the primary x-axis
103+
candidates = [d for d in dims_present if d != da.dims[0] and d in cat_dim_names]
104+
105+
if len(candidates) == 1:
106+
by = candidates[0]
107+
elif len(candidates) > 1:
108+
# Preserve multi-level grouping when multiple categorical dims remain
109+
by = candidates
110+
else:
111+
by = None
98112
title = self.title_from_ds(da, result_var, **kwargs)
99113
time_args = self.time_widget(title)
100114

bencher/results/holoview_results/tabulator_result.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
from bencher.results.holoview_results.holoview_result import HoloviewResult
55

6+
from param import Parameter
7+
import hvplot.xarray # noqa pylint: disable=duplicate-code,unused-import
8+
import xarray as xr
9+
import pandas as pd
10+
611

712
class TabulatorResult(HoloviewResult):
813
def to_plot(self, **kwargs) -> pn.widgets.Tabulator: # pylint:disable=unused-argument
@@ -17,4 +22,54 @@ def to_plot(self, **kwargs) -> pn.widgets.Tabulator: # pylint:disable=unused-ar
1722
Returns:
1823
pn.widgets.Tabulator: An interactive table widget.
1924
"""
20-
return pn.widgets.Tabulator(self.to_pandas())
25+
return self.to_tabulator(**kwargs)
26+
27+
def to_tabulator(self, result_var: Parameter = None, **kwargs) -> pn.widgets.Tabulator:
28+
"""Generates a Tabulator widget from benchmark data.
29+
30+
This is a convenience method that calls to_tabulator_ds() with the same parameters.
31+
32+
Args:
33+
result_var (Parameter, optional): The result variable to include in the table. If None, uses the default.
34+
**kwargs: Additional keyword arguments passed to the Tabulator constructor.
35+
36+
Returns:
37+
pn.widgets.Tabulator: An interactive table widget.
38+
"""
39+
return self.filter(
40+
self.to_tabulator_ds,
41+
result_var=result_var,
42+
**kwargs,
43+
)
44+
45+
def to_tabulator_ds(
46+
self, dataset: xr.Dataset, result_var: Parameter, **kwargs
47+
) -> pn.widgets.Tabulator:
48+
"""Creates a Tabulator widget from the provided dataset.
49+
50+
Given a filtered dataset, this method generates an interactive table visualization.
51+
52+
Args:
53+
dataset (xr.Dataset): The filtered dataset to visualize.
54+
result_var (Parameter): The result variable to include in the table.
55+
**kwargs: Additional keyword arguments passed to the Tabulator constructor.
56+
57+
Returns:
58+
pn.widgets.Tabulator: An interactive table widget.
59+
"""
60+
61+
# Assume input is an xarray.Dataset. Keep Dataset throughout.
62+
ds: xr.Dataset = dataset if isinstance(dataset, xr.Dataset) else xr.Dataset(dataset)
63+
64+
# Step 1: If a result variable is specified, select it and keep as Dataset
65+
if result_var is not None and result_var.name in ds.data_vars:
66+
ds = ds[[result_var.name]]
67+
68+
# Step 2: Build a flat pandas DataFrame
69+
if len(ds.dims) == 0:
70+
df = pd.DataFrame({name: [da.values.item()] for name, da in ds.data_vars.items()})
71+
else:
72+
# N-D: to DataFrame and reset the index so coordinates become columns
73+
df = ds.to_dataframe().reset_index()
74+
75+
return pn.widgets.Tabulator(df, **kwargs)

0 commit comments

Comments
 (0)