Skip to content

Commit 378d669

Browse files
authored
Merge pull request #94 from ezmsg-org/dev
Add trapezoid to list of aggregation functions and expose function choice in BandPower
2 parents 4f26871 + 031d325 commit 378d669

File tree

4 files changed

+55
-7
lines changed

4 files changed

+55
-7
lines changed

src/ezmsg/sigproc/aggregate.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ class AggregationFunction(OptionsEnum):
3636
NANSUM = "nansum"
3737
ARGMIN = "argmin"
3838
ARGMAX = "argmax"
39+
TRAPEZOID = "trapezoid"
3940

4041

4142
AGGREGATORS = {
@@ -54,6 +55,9 @@ class AggregationFunction(OptionsEnum):
5455
AggregationFunction.NANSUM: np.nansum,
5556
AggregationFunction.ARGMIN: np.argmin,
5657
AggregationFunction.ARGMAX: np.argmax,
58+
# Note: Some methods require x-coordinates and
59+
# are handled specially in `_process`.
60+
AggregationFunction.TRAPEZOID: np.trapezoid,
5761
}
5862

5963

@@ -144,10 +148,23 @@ def _process(self, message: AxisArray) -> AxisArray:
144148
ax_idx = message.get_axis_idx(axis)
145149
agg_func = AGGREGATORS[self.settings.operation]
146150

147-
out_data = [
148-
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
149-
for sl in self._state.slices
150-
]
151+
if self.settings.operation in [
152+
AggregationFunction.TRAPEZOID,
153+
]:
154+
# Special handling for methods that require x-coordinates.
155+
out_data = [
156+
agg_func(
157+
slice_along_axis(message.data, sl, axis=ax_idx),
158+
x=self._state.ax_vec[sl],
159+
axis=ax_idx,
160+
)
161+
for sl in self._state.slices
162+
]
163+
else:
164+
out_data = [
165+
agg_func(slice_along_axis(message.data, sl, axis=ax_idx), axis=ax_idx)
166+
for sl in self._state.slices
167+
]
151168

152169
msg_out = replace(
153170
message,

src/ezmsg/sigproc/bandpower.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class BandPowerSettings(ez.Settings):
3636
(min, max) tuples of band limits in Hz.
3737
"""
3838

39+
aggregation: AggregationFunction = AggregationFunction.MEAN
40+
""":obj:`AggregationFunction` to apply to each band."""
41+
3942

4043
class BandPowerTransformer(CompositeProcessor[BandPowerSettings, AxisArray, AxisArray]):
4144
@staticmethod
@@ -50,7 +53,7 @@ def _initialize_processors(
5053
settings=RangedAggregateSettings(
5154
axis="freq",
5255
bands=settings.bands,
53-
operation=AggregationFunction.MEAN,
56+
operation=settings.aggregation,
5457
)
5558
),
5659
}
@@ -68,6 +71,7 @@ def bandpower(
6871
(17, 30),
6972
(70, 170),
7073
],
74+
aggregation: AggregationFunction = AggregationFunction.MEAN,
7175
) -> BandPowerTransformer:
7276
"""
7377
Calculate the average spectral power in each band.
@@ -77,6 +81,8 @@ def bandpower(
7781
"""
7882
return BandPowerTransformer(
7983
settings=BandPowerSettings(
80-
spectrogram_settings=spectrogram_settings, bands=bands
84+
spectrogram_settings=spectrogram_settings,
85+
bands=bands,
86+
aggregation=aggregation,
8187
)
8288
)

tests/unit/test_aggregate.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,30 @@ def test_arg_aggregate(agg_func: AggregationFunction):
110110
assert np.array_equal(out_dat, expected_dat)
111111

112112

113+
def test_trapezoid():
114+
bands = [(5.0, 20.0), (30.0, 50.0)]
115+
in_msgs = [_ for _ in get_msg_gen()]
116+
gen = ranged_aggregate(
117+
axis="freq", bands=bands, operation=AggregationFunction.TRAPEZOID
118+
)
119+
out_msgs = [gen.send(_) for _ in in_msgs]
120+
121+
out_dat = AxisArray.concatenate(*out_msgs, dim="time").data
122+
123+
# Calculate expected data using trapezoidal integration
124+
in_data = AxisArray.concatenate(*in_msgs, dim="time").data
125+
targ_ax = in_msgs[0].axes["freq"]
126+
targ_ax_vec = targ_ax.value(np.arange(in_data.shape[-1]))
127+
expected = []
128+
for start, stop in bands:
129+
inds = np.logical_and(targ_ax_vec >= start, targ_ax_vec <= stop)
130+
expected.append(np.trapezoid(in_data[..., inds], x=targ_ax_vec[inds], axis=-1))
131+
expected = np.stack(expected, axis=-1)
132+
133+
assert out_dat.shape == expected.shape
134+
assert np.allclose(out_dat, expected)
135+
136+
113137
@pytest.mark.parametrize("change_ax", ["ch", "freq"])
114138
def test_aggregate_handle_change(change_ax: str):
115139
"""

tests/unit/test_bandpower.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44
from ezmsg.util.messages.axisarray import AxisArray
5-
from ezmsg.sigproc.bandpower import bandpower, SpectrogramSettings
5+
from ezmsg.sigproc.bandpower import bandpower, SpectrogramSettings, AggregationFunction
66

77
from tests.helpers.util import (
88
create_messages_with_periodic_signal,
@@ -45,6 +45,7 @@ def test_bandpower():
4545
window_shift=0.1,
4646
),
4747
bands=bands,
48+
aggregation=AggregationFunction.MEAN,
4849
)
4950
results = [gen.send(_) for _ in messages]
5051

0 commit comments

Comments
 (0)