Skip to content

Commit 60e83fc

Browse files
authored
Merge pull request #97 from ezmsg-org/BlackrockNeurotech-main
Gaussian Smoothing Filter (redo of #95)
2 parents c04335f + a771de6 commit 60e83fc

File tree

2 files changed

+348
-0
lines changed

2 files changed

+348
-0
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Callable
2+
import warnings
3+
4+
import numpy as np
5+
6+
from .filter import (
7+
FilterBaseSettings,
8+
BACoeffs,
9+
FilterByDesignTransformer,
10+
BaseFilterByDesignTransformerUnit,
11+
)
12+
13+
14+
class GaussianSmoothingSettings(FilterBaseSettings):
15+
sigma: float | None = 1.0
16+
"""
17+
sigma : float
18+
Standard deviation of the Gaussian kernel.
19+
"""
20+
21+
width: int | None = 4
22+
"""
23+
width : int
24+
Number of standard deviations covered by the kernel window if kernel_size is not provided.
25+
"""
26+
27+
kernel_size: int | None = None
28+
"""
29+
kernel_size : int | None
30+
Length of the kernel in samples. If provided, overrides automatic calculation.
31+
"""
32+
33+
34+
def gaussian_smoothing_filter_design(
35+
sigma: float = 1.0,
36+
width: int = 4,
37+
kernel_size: int | None = None,
38+
) -> BACoeffs | None:
39+
# Parameter checks
40+
if sigma <= 0:
41+
raise ValueError(f"sigma must be positive. Received: {sigma}")
42+
43+
if width <= 0:
44+
raise ValueError(f"width must be positive. Received: {width}")
45+
46+
if kernel_size is not None:
47+
if kernel_size < 1:
48+
raise ValueError(f"kernel_size must be >= 1. Received: {kernel_size}")
49+
else:
50+
kernel_size = int(2 * width * sigma + 1)
51+
52+
# Warn if kernel_size is smaller than recommended but don't fail
53+
expected_kernel_size = int(2 * width * sigma + 1)
54+
if kernel_size < expected_kernel_size:
55+
## TODO: Either add a warning or determine appropriate kernel size and raise an error
56+
warnings.warn(
57+
f"Provided kernel_size {kernel_size} is smaller than recommended "
58+
f"size {expected_kernel_size} for sigma={sigma} and width={width}. "
59+
"The kernel may be truncated."
60+
)
61+
62+
from scipy.signal.windows import gaussian
63+
64+
b = gaussian(kernel_size, std=sigma)
65+
b /= np.sum(b) # Ensure normalization
66+
a = np.array([1.0])
67+
68+
return b, a
69+
70+
71+
class GaussianSmoothingFilterTransformer(
72+
FilterByDesignTransformer[GaussianSmoothingSettings, BACoeffs]
73+
):
74+
def get_design_function(
75+
self,
76+
) -> Callable[[float], BACoeffs]:
77+
# Create a wrapper function that ignores fs parameter since gaussian smoothing doesn't need it
78+
def design_wrapper(fs: float) -> BACoeffs:
79+
return gaussian_smoothing_filter_design(
80+
sigma=self.settings.sigma,
81+
width=self.settings.width,
82+
kernel_size=self.settings.kernel_size,
83+
)
84+
85+
return design_wrapper
86+
87+
88+
class GaussianSmoothingFilter(
89+
BaseFilterByDesignTransformerUnit[
90+
GaussianSmoothingSettings, GaussianSmoothingFilterTransformer
91+
]
92+
):
93+
SETTINGS = GaussianSmoothingSettings
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
import numpy as np
2+
import pytest
3+
from ezmsg.util.messages.axisarray import AxisArray
4+
5+
from ezmsg.sigproc.gaussiansmoothing import (
6+
gaussian_smoothing_filter_design,
7+
GaussianSmoothingSettings,
8+
GaussianSmoothingFilterTransformer,
9+
)
10+
11+
12+
@pytest.mark.parametrize(
13+
"axis,sigma,width,kernel_size",
14+
[
15+
("time", 1.5, 5, None),
16+
("time", 2.0, 4, 21),
17+
],
18+
)
19+
def test_gaussian_smoothing_filter_function(axis, sigma, width, kernel_size):
20+
"""Test the gaussian_smoothing_filter convenience function."""
21+
transformer = GaussianSmoothingFilterTransformer(
22+
axis=axis,
23+
sigma=sigma,
24+
width=width,
25+
kernel_size=kernel_size,
26+
)
27+
28+
assert isinstance(transformer, GaussianSmoothingFilterTransformer)
29+
assert transformer.settings.axis == axis
30+
assert transformer.settings.sigma == sigma
31+
assert transformer.settings.width == width
32+
assert transformer.settings.kernel_size == kernel_size
33+
34+
35+
def test_gaussian_smoothing_settings_defaults():
36+
"""Test the GaussianSmoothingSettings class with default values."""
37+
settings = GaussianSmoothingSettings()
38+
assert settings.sigma == 1.0
39+
assert settings.width == 4
40+
assert settings.kernel_size is None
41+
42+
43+
def test_gaussian_smoothing_settings_custom():
44+
"""Test the GaussianSmoothingSettings class with custom values."""
45+
settings = GaussianSmoothingSettings(
46+
sigma=2.5,
47+
width=6,
48+
kernel_size=21,
49+
)
50+
assert settings.sigma == 2.5
51+
assert settings.width == 6
52+
assert settings.kernel_size == 21
53+
54+
55+
@pytest.mark.parametrize(
56+
"sigma,width,kernel_size",
57+
[
58+
(1.0, 4, None),
59+
(2.0, 6, None),
60+
(1.5, 5, 11),
61+
(1.5, 5, 17), # Fixed kernel_size to be >= expected
62+
],
63+
)
64+
def test_gaussian_smoothing_filter_design_parameters(sigma, width, kernel_size):
65+
"""Test gaussian smoothing filter design across multiple parameter configurations."""
66+
coefs = gaussian_smoothing_filter_design(
67+
sigma=sigma,
68+
width=width,
69+
kernel_size=kernel_size,
70+
)
71+
assert coefs is not None
72+
assert isinstance(coefs, tuple)
73+
assert len(coefs) == 2 # b and a coefficients
74+
75+
b, a = coefs
76+
assert b is not None and a is not None
77+
assert isinstance(b, np.ndarray) and isinstance(a, np.ndarray)
78+
assert np.all(b >= 0) # positive
79+
assert np.allclose(b, b[::-1]) # symmetric
80+
assert np.isclose(np.sum(b), 1.0) # normalized
81+
assert b[len(b) // 2] == np.max(b) # center of kernel is peak
82+
assert len(a) == 1 and a[0] == 1.0 # default for gaussian window
83+
84+
expected_kernel_size = (
85+
int(2 * width * sigma + 1) if kernel_size is None else kernel_size
86+
)
87+
assert len(b) == expected_kernel_size
88+
89+
90+
def test_gaussian_smoothing_kernel_properties():
91+
"""Test that larger sigma creates wider kernel"""
92+
coefs_small = gaussian_smoothing_filter_design(sigma=1.0)
93+
b_small, _ = coefs_small
94+
coefs_large = gaussian_smoothing_filter_design(sigma=3.0)
95+
b_large, _ = coefs_large
96+
97+
assert len(b_large) >= len(b_small) # wider kernel
98+
assert b_large[len(b_large) // 2] < b_small[len(b_small) // 2] # lower peak
99+
100+
101+
@pytest.mark.parametrize("sigma", [0.0, -1.0])
102+
@pytest.mark.parametrize("width", [0.0, -1.0])
103+
@pytest.mark.parametrize("kernel_size", [0, -1])
104+
def test_gaussian_smoothing_filter_design_invalid_inputs(sigma, width, kernel_size):
105+
"""Test the gaussian smoothing filter design function with invalid inputs."""
106+
with pytest.raises(ValueError):
107+
gaussian_smoothing_filter_design(sigma=sigma)
108+
with pytest.raises(ValueError):
109+
gaussian_smoothing_filter_design(width=width)
110+
with pytest.raises(ValueError):
111+
gaussian_smoothing_filter_design(kernel_size=kernel_size)
112+
113+
114+
@pytest.mark.parametrize("data_shape", [(100,), (1, 100), (100, 2), (100, 2, 3)])
115+
def test_gaussian_smoothing_filter_process(data_shape):
116+
"""Test gaussian smoothing filter with different data shapes."""
117+
# Create test data
118+
data = np.arange(np.prod(data_shape)).reshape(data_shape)
119+
120+
# Create appropriate dims and axes based on shape
121+
if len(data_shape) == 1:
122+
dims = ["time"]
123+
axes = {"time": AxisArray.TimeAxis(fs=100.0, offset=0)}
124+
elif len(data_shape) == 2:
125+
dims = ["time", "ch"]
126+
axes = {
127+
"time": AxisArray.TimeAxis(fs=100.0, offset=0),
128+
"ch": AxisArray.CoordinateAxis(
129+
data=np.arange(data_shape[1]).astype(str), dims=["ch"]
130+
),
131+
}
132+
else:
133+
dims = ["freq", "time", "ch"]
134+
axes = {
135+
"freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0),
136+
"time": AxisArray.TimeAxis(fs=100.0, offset=0),
137+
"ch": AxisArray.CoordinateAxis(
138+
data=np.arange(data_shape[2]).astype(str), dims=["ch"]
139+
),
140+
}
141+
142+
msg = AxisArray(data=data, dims=dims, axes=axes, key="test_gaussian_smoothing")
143+
144+
# Instantiate transformer (not unit)
145+
transformer = GaussianSmoothingFilterTransformer(
146+
settings=GaussianSmoothingSettings(axis="time", sigma=2.0, width=4)
147+
)
148+
149+
# Process message using __call__ method
150+
output_msg = transformer(msg)
151+
152+
# Assertions
153+
assert isinstance(output_msg, AxisArray)
154+
assert output_msg.data.shape == data.shape
155+
assert np.isfinite(output_msg.data).all()
156+
157+
158+
def test_gaussian_smoothing_edge_cases():
159+
"""Test edge cases for gaussian smoothing filter."""
160+
# Test with very small sigma
161+
coefs_small = gaussian_smoothing_filter_design(sigma=0.01)
162+
b_small, a_small = coefs_small
163+
assert len(b_small) > 0
164+
assert np.isclose(np.sum(b_small), 1.0)
165+
166+
# Test with very large sigma
167+
coefs_large = gaussian_smoothing_filter_design(sigma=100.0)
168+
b_large, a_large = coefs_large
169+
assert len(b_large) > 0
170+
assert np.isclose(np.sum(b_large), 1.0)
171+
172+
# Test with very small width
173+
coefs_narrow = gaussian_smoothing_filter_design(width=1)
174+
b_narrow, a_narrow = coefs_narrow
175+
assert len(b_narrow) > 0
176+
177+
# Test with very large width
178+
coefs_wide = gaussian_smoothing_filter_design(width=100)
179+
b_wide, a_wide = coefs_wide
180+
assert len(b_wide) > len(b_narrow) # Wider kernel
181+
182+
183+
def test_gaussian_smoothing_update_settings():
184+
"""Test the update_settings functionality of the Gaussian smoothing filter."""
185+
# Setup parameters
186+
fs = 200.0
187+
dur = 1.0
188+
n_times = int(dur * fs)
189+
n_chans = 2
190+
191+
# Create input data with high frequency noise
192+
t = np.arange(n_times) / fs
193+
# Create a signal with both low and high frequency components
194+
signal = np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 50 * t)
195+
in_dat = np.vstack([signal, signal + np.random.randn(n_times) * 0.1]).T
196+
197+
# Create message
198+
msg_in = AxisArray(
199+
data=in_dat,
200+
dims=["time", "ch"],
201+
axes={
202+
"time": AxisArray.TimeAxis(fs=fs, offset=0),
203+
"ch": AxisArray.CoordinateAxis(
204+
data=np.arange(n_chans).astype(str), dims=["ch"]
205+
),
206+
},
207+
key="test_gaussian_smoothing_update_settings",
208+
)
209+
210+
def _calc_smoothing_effect(msg):
211+
"""Calculate the smoothing effect by comparing variance."""
212+
return np.var(msg.data, axis=0)
213+
214+
original_variance = _calc_smoothing_effect(msg_in)
215+
216+
# Initialize filter with small sigma (minimal smoothing)
217+
proc = GaussianSmoothingFilterTransformer(
218+
axis="time",
219+
sigma=0.5,
220+
width=4,
221+
coef_type="ba",
222+
)
223+
224+
# Process first message
225+
result1 = proc(msg_in)
226+
variance1 = _calc_smoothing_effect(result1)
227+
228+
# Small sigma should have minimal effect
229+
assert np.allclose(variance1, original_variance, rtol=0.1)
230+
231+
# Update settings - change to larger sigma (more smoothing)
232+
proc.update_settings(sigma=3.0)
233+
234+
# Process the same message with new settings
235+
result2 = proc(msg_in)
236+
variance2 = _calc_smoothing_effect(result2)
237+
238+
# Larger sigma should reduce variance (more smoothing)
239+
assert np.all(variance2 < variance1)
240+
241+
# Test update_settings with complete new settings object
242+
new_settings = GaussianSmoothingSettings(
243+
axis="time",
244+
sigma=5.0, # Even larger sigma
245+
width=6,
246+
kernel_size=None,
247+
coef_type="ba",
248+
)
249+
250+
proc.update_settings(new_settings=new_settings)
251+
result3 = proc(msg_in)
252+
variance3 = _calc_smoothing_effect(result3)
253+
254+
# Even larger sigma should reduce variance further
255+
assert np.all(variance3 < variance2)

0 commit comments

Comments
 (0)