|
| 1 | +import numpy as np |
| 2 | +import pytest |
| 3 | +from ezmsg.util.messages.axisarray import AxisArray |
| 4 | + |
| 5 | +from ezmsg.sigproc.gaussiansmoothing import ( |
| 6 | + gaussian_smoothing_filter_design, |
| 7 | + GaussianSmoothingSettings, |
| 8 | + GaussianSmoothingFilterTransformer, |
| 9 | +) |
| 10 | + |
| 11 | + |
| 12 | +@pytest.mark.parametrize( |
| 13 | + "axis,sigma,width,kernel_size", |
| 14 | + [ |
| 15 | + ("time", 1.5, 5, None), |
| 16 | + ("time", 2.0, 4, 21), |
| 17 | + ], |
| 18 | +) |
| 19 | +def test_gaussian_smoothing_filter_function(axis, sigma, width, kernel_size): |
| 20 | + """Test the gaussian_smoothing_filter convenience function.""" |
| 21 | + transformer = GaussianSmoothingFilterTransformer( |
| 22 | + axis=axis, |
| 23 | + sigma=sigma, |
| 24 | + width=width, |
| 25 | + kernel_size=kernel_size, |
| 26 | + ) |
| 27 | + |
| 28 | + assert isinstance(transformer, GaussianSmoothingFilterTransformer) |
| 29 | + assert transformer.settings.axis == axis |
| 30 | + assert transformer.settings.sigma == sigma |
| 31 | + assert transformer.settings.width == width |
| 32 | + assert transformer.settings.kernel_size == kernel_size |
| 33 | + |
| 34 | + |
| 35 | +def test_gaussian_smoothing_settings_defaults(): |
| 36 | + """Test the GaussianSmoothingSettings class with default values.""" |
| 37 | + settings = GaussianSmoothingSettings() |
| 38 | + assert settings.sigma == 1.0 |
| 39 | + assert settings.width == 4 |
| 40 | + assert settings.kernel_size is None |
| 41 | + |
| 42 | + |
| 43 | +def test_gaussian_smoothing_settings_custom(): |
| 44 | + """Test the GaussianSmoothingSettings class with custom values.""" |
| 45 | + settings = GaussianSmoothingSettings( |
| 46 | + sigma=2.5, |
| 47 | + width=6, |
| 48 | + kernel_size=21, |
| 49 | + ) |
| 50 | + assert settings.sigma == 2.5 |
| 51 | + assert settings.width == 6 |
| 52 | + assert settings.kernel_size == 21 |
| 53 | + |
| 54 | + |
| 55 | +@pytest.mark.parametrize( |
| 56 | + "sigma,width,kernel_size", |
| 57 | + [ |
| 58 | + (1.0, 4, None), |
| 59 | + (2.0, 6, None), |
| 60 | + (1.5, 5, 11), |
| 61 | + (1.5, 5, 17), # Fixed kernel_size to be >= expected |
| 62 | + ], |
| 63 | +) |
| 64 | +def test_gaussian_smoothing_filter_design_parameters(sigma, width, kernel_size): |
| 65 | + """Test gaussian smoothing filter design across multiple parameter configurations.""" |
| 66 | + coefs = gaussian_smoothing_filter_design( |
| 67 | + sigma=sigma, |
| 68 | + width=width, |
| 69 | + kernel_size=kernel_size, |
| 70 | + ) |
| 71 | + assert coefs is not None |
| 72 | + assert isinstance(coefs, tuple) |
| 73 | + assert len(coefs) == 2 # b and a coefficients |
| 74 | + |
| 75 | + b, a = coefs |
| 76 | + assert b is not None and a is not None |
| 77 | + assert isinstance(b, np.ndarray) and isinstance(a, np.ndarray) |
| 78 | + assert np.all(b >= 0) # positive |
| 79 | + assert np.allclose(b, b[::-1]) # symmetric |
| 80 | + assert np.isclose(np.sum(b), 1.0) # normalized |
| 81 | + assert b[len(b) // 2] == np.max(b) # center of kernel is peak |
| 82 | + assert len(a) == 1 and a[0] == 1.0 # default for gaussian window |
| 83 | + |
| 84 | + expected_kernel_size = ( |
| 85 | + int(2 * width * sigma + 1) if kernel_size is None else kernel_size |
| 86 | + ) |
| 87 | + assert len(b) == expected_kernel_size |
| 88 | + |
| 89 | + |
| 90 | +def test_gaussian_smoothing_kernel_properties(): |
| 91 | + """Test that larger sigma creates wider kernel""" |
| 92 | + coefs_small = gaussian_smoothing_filter_design(sigma=1.0) |
| 93 | + b_small, _ = coefs_small |
| 94 | + coefs_large = gaussian_smoothing_filter_design(sigma=3.0) |
| 95 | + b_large, _ = coefs_large |
| 96 | + |
| 97 | + assert len(b_large) >= len(b_small) # wider kernel |
| 98 | + assert b_large[len(b_large) // 2] < b_small[len(b_small) // 2] # lower peak |
| 99 | + |
| 100 | + |
| 101 | +@pytest.mark.parametrize("sigma", [0.0, -1.0]) |
| 102 | +@pytest.mark.parametrize("width", [0.0, -1.0]) |
| 103 | +@pytest.mark.parametrize("kernel_size", [0, -1]) |
| 104 | +def test_gaussian_smoothing_filter_design_invalid_inputs(sigma, width, kernel_size): |
| 105 | + """Test the gaussian smoothing filter design function with invalid inputs.""" |
| 106 | + with pytest.raises(ValueError): |
| 107 | + gaussian_smoothing_filter_design(sigma=sigma) |
| 108 | + with pytest.raises(ValueError): |
| 109 | + gaussian_smoothing_filter_design(width=width) |
| 110 | + with pytest.raises(ValueError): |
| 111 | + gaussian_smoothing_filter_design(kernel_size=kernel_size) |
| 112 | + |
| 113 | + |
| 114 | +@pytest.mark.parametrize("data_shape", [(100,), (1, 100), (100, 2), (100, 2, 3)]) |
| 115 | +def test_gaussian_smoothing_filter_process(data_shape): |
| 116 | + """Test gaussian smoothing filter with different data shapes.""" |
| 117 | + # Create test data |
| 118 | + data = np.arange(np.prod(data_shape)).reshape(data_shape) |
| 119 | + |
| 120 | + # Create appropriate dims and axes based on shape |
| 121 | + if len(data_shape) == 1: |
| 122 | + dims = ["time"] |
| 123 | + axes = {"time": AxisArray.TimeAxis(fs=100.0, offset=0)} |
| 124 | + elif len(data_shape) == 2: |
| 125 | + dims = ["time", "ch"] |
| 126 | + axes = { |
| 127 | + "time": AxisArray.TimeAxis(fs=100.0, offset=0), |
| 128 | + "ch": AxisArray.CoordinateAxis( |
| 129 | + data=np.arange(data_shape[1]).astype(str), dims=["ch"] |
| 130 | + ), |
| 131 | + } |
| 132 | + else: |
| 133 | + dims = ["freq", "time", "ch"] |
| 134 | + axes = { |
| 135 | + "freq": AxisArray.LinearAxis(unit="Hz", offset=0.0, gain=1.0), |
| 136 | + "time": AxisArray.TimeAxis(fs=100.0, offset=0), |
| 137 | + "ch": AxisArray.CoordinateAxis( |
| 138 | + data=np.arange(data_shape[2]).astype(str), dims=["ch"] |
| 139 | + ), |
| 140 | + } |
| 141 | + |
| 142 | + msg = AxisArray(data=data, dims=dims, axes=axes, key="test_gaussian_smoothing") |
| 143 | + |
| 144 | + # Instantiate transformer (not unit) |
| 145 | + transformer = GaussianSmoothingFilterTransformer( |
| 146 | + settings=GaussianSmoothingSettings(axis="time", sigma=2.0, width=4) |
| 147 | + ) |
| 148 | + |
| 149 | + # Process message using __call__ method |
| 150 | + output_msg = transformer(msg) |
| 151 | + |
| 152 | + # Assertions |
| 153 | + assert isinstance(output_msg, AxisArray) |
| 154 | + assert output_msg.data.shape == data.shape |
| 155 | + assert np.isfinite(output_msg.data).all() |
| 156 | + |
| 157 | + |
| 158 | +def test_gaussian_smoothing_edge_cases(): |
| 159 | + """Test edge cases for gaussian smoothing filter.""" |
| 160 | + # Test with very small sigma |
| 161 | + coefs_small = gaussian_smoothing_filter_design(sigma=0.01) |
| 162 | + b_small, a_small = coefs_small |
| 163 | + assert len(b_small) > 0 |
| 164 | + assert np.isclose(np.sum(b_small), 1.0) |
| 165 | + |
| 166 | + # Test with very large sigma |
| 167 | + coefs_large = gaussian_smoothing_filter_design(sigma=100.0) |
| 168 | + b_large, a_large = coefs_large |
| 169 | + assert len(b_large) > 0 |
| 170 | + assert np.isclose(np.sum(b_large), 1.0) |
| 171 | + |
| 172 | + # Test with very small width |
| 173 | + coefs_narrow = gaussian_smoothing_filter_design(width=1) |
| 174 | + b_narrow, a_narrow = coefs_narrow |
| 175 | + assert len(b_narrow) > 0 |
| 176 | + |
| 177 | + # Test with very large width |
| 178 | + coefs_wide = gaussian_smoothing_filter_design(width=100) |
| 179 | + b_wide, a_wide = coefs_wide |
| 180 | + assert len(b_wide) > len(b_narrow) # Wider kernel |
| 181 | + |
| 182 | + |
| 183 | +def test_gaussian_smoothing_update_settings(): |
| 184 | + """Test the update_settings functionality of the Gaussian smoothing filter.""" |
| 185 | + # Setup parameters |
| 186 | + fs = 200.0 |
| 187 | + dur = 1.0 |
| 188 | + n_times = int(dur * fs) |
| 189 | + n_chans = 2 |
| 190 | + |
| 191 | + # Create input data with high frequency noise |
| 192 | + t = np.arange(n_times) / fs |
| 193 | + # Create a signal with both low and high frequency components |
| 194 | + signal = np.sin(2 * np.pi * 5 * t) + 0.5 * np.sin(2 * np.pi * 50 * t) |
| 195 | + in_dat = np.vstack([signal, signal + np.random.randn(n_times) * 0.1]).T |
| 196 | + |
| 197 | + # Create message |
| 198 | + msg_in = AxisArray( |
| 199 | + data=in_dat, |
| 200 | + dims=["time", "ch"], |
| 201 | + axes={ |
| 202 | + "time": AxisArray.TimeAxis(fs=fs, offset=0), |
| 203 | + "ch": AxisArray.CoordinateAxis( |
| 204 | + data=np.arange(n_chans).astype(str), dims=["ch"] |
| 205 | + ), |
| 206 | + }, |
| 207 | + key="test_gaussian_smoothing_update_settings", |
| 208 | + ) |
| 209 | + |
| 210 | + def _calc_smoothing_effect(msg): |
| 211 | + """Calculate the smoothing effect by comparing variance.""" |
| 212 | + return np.var(msg.data, axis=0) |
| 213 | + |
| 214 | + original_variance = _calc_smoothing_effect(msg_in) |
| 215 | + |
| 216 | + # Initialize filter with small sigma (minimal smoothing) |
| 217 | + proc = GaussianSmoothingFilterTransformer( |
| 218 | + axis="time", |
| 219 | + sigma=0.5, |
| 220 | + width=4, |
| 221 | + coef_type="ba", |
| 222 | + ) |
| 223 | + |
| 224 | + # Process first message |
| 225 | + result1 = proc(msg_in) |
| 226 | + variance1 = _calc_smoothing_effect(result1) |
| 227 | + |
| 228 | + # Small sigma should have minimal effect |
| 229 | + assert np.allclose(variance1, original_variance, rtol=0.1) |
| 230 | + |
| 231 | + # Update settings - change to larger sigma (more smoothing) |
| 232 | + proc.update_settings(sigma=3.0) |
| 233 | + |
| 234 | + # Process the same message with new settings |
| 235 | + result2 = proc(msg_in) |
| 236 | + variance2 = _calc_smoothing_effect(result2) |
| 237 | + |
| 238 | + # Larger sigma should reduce variance (more smoothing) |
| 239 | + assert np.all(variance2 < variance1) |
| 240 | + |
| 241 | + # Test update_settings with complete new settings object |
| 242 | + new_settings = GaussianSmoothingSettings( |
| 243 | + axis="time", |
| 244 | + sigma=5.0, # Even larger sigma |
| 245 | + width=6, |
| 246 | + kernel_size=None, |
| 247 | + coef_type="ba", |
| 248 | + ) |
| 249 | + |
| 250 | + proc.update_settings(new_settings=new_settings) |
| 251 | + result3 = proc(msg_in) |
| 252 | + variance3 = _calc_smoothing_effect(result3) |
| 253 | + |
| 254 | + # Even larger sigma should reduce variance further |
| 255 | + assert np.all(variance3 < variance2) |
0 commit comments