diff --git a/docs/changes/devel/feature.rst b/docs/changes/devel/feature.rst new file mode 100644 index 00000000..0cb7a0ae --- /dev/null +++ b/docs/changes/devel/feature.rst @@ -0,0 +1 @@ +Add adaptive DSS with automatic segmentation (:class:`~mne_denoise.dss.utils.CovarianceSegmenter`, :class:`~mne_denoise.dss.utils.FixedWindowSegmenter`), component auto-selection (:func:`~mne_denoise.dss.utils.eigenvalue_ratio_selection`, :func:`~mne_denoise.dss.utils.max_gap_selection`), smoothing decomposition, L2-normalised patterns, and proportional Q-mode for the periodic denoiser, by `Scott Huberty`_. diff --git a/examples/dss/plot_01_dss_fundamentals.py b/examples/dss/plot_01_dss_fundamentals.py index fc419993..c81379ec 100644 --- a/examples/dss/plot_01_dss_fundamentals.py +++ b/examples/dss/plot_01_dss_fundamentals.py @@ -202,7 +202,7 @@ print("Reconstructing data from first component...") sources = dss_evoked.transform(epochs) # To reconstruct using only specific components, we zero out the others -sources[:, 1:, :] = 0 +sources[1:, :, :] = 0 epochs_denoised = dss_evoked.inverse_transform(sources) epochs_denoised = mne.EpochsArray(epochs_denoised, info) @@ -270,7 +270,7 @@ # We concatenate epochs for continuous reconstruction if desired, or keep as epochs # Here we keep as epochs to use plot_psd_comparison sources = dss_osc.transform(epochs) -sources[:, 1:, :] = 0 +sources[1:, :, :] = 0 epochs_osc = dss_osc.inverse_transform(sources) epochs_osc = mne.EpochsArray(epochs_osc, info) @@ -424,7 +424,7 @@ # Denoising Comparison print("Reconstructing M100 component...") sources = dss_m100.transform(epochs_real) -sources[:, 1:, :] = 0 +sources[1:, :, :] = 0 epochs_m100 = dss_m100.inverse_transform(sources) epochs_m100 = mne.EpochsArray(epochs_m100, epochs_real.info) diff --git a/examples/dss/plot_04_spectral_dss.py b/examples/dss/plot_04_spectral_dss.py index 81d5363c..257e7e88 100644 --- a/examples/dss/plot_04_spectral_dss.py +++ b/examples/dss/plot_04_spectral_dss.py @@ -220,7 
+220,7 @@ # We use method='iir' to replicate a traditional Notch filter approach notch_bias_60 = LineNoiseBias(freq=60, sfreq=sfreq, method="iir", bandwidth=2) -dss_notch = DSS(n_components=None, bias=notch_bias_60) +dss_notch = DSS(n_components=3, bias=notch_bias_60) dss_notch.fit(raw_noisy) print(f"\nDSS Eigenvalues: {dss_notch.eigenvalues_[:3]}") diff --git a/examples/zapline/plot_04_adaptive_mode.py b/examples/zapline/plot_04_adaptive_mode.py index 12e0d5ee..a751b665 100644 --- a/examples/zapline/plot_04_adaptive_mode.py +++ b/examples/zapline/plot_04_adaptive_mode.py @@ -17,12 +17,12 @@ from matplotlib.gridspec import GridSpec from scipy import signal +from mne_denoise.dss.utils.segmentation import CovarianceSegmenter from mne_denoise.viz import plot_psd_comparison from mne_denoise.zapline.adaptive import ( check_artifact_presence, find_fine_peak, find_noise_freqs, - segment_data, ) # Suppress warnings for cleaner output @@ -204,9 +204,12 @@ def generate_nonstationary_data( print("\n--- Step 2: Adaptive Segmentation ---") target_freq = detected_freqs[0] if detected_freqs else 50.0 -segments = segment_data( - data, sfreq, target_freq, min_chunk_len=PAPER_PARAMS["minChunkLength"] +segmenter = CovarianceSegmenter( + sfreq=sfreq, + min_chunk_len=PAPER_PARAMS["minChunkLength"], + bandpass=(target_freq - 3, target_freq + 3), ) +segments = segmenter.segment(data) print(f"Number of segments: {len(segments)}") for i, (start, end) in enumerate(segments): print( diff --git a/mne_denoise/dss/__init__.py b/mne_denoise/dss/__init__.py index 7e9abe36..d5aec5ab 100644 --- a/mne_denoise/dss/__init__.py +++ b/mne_denoise/dss/__init__.py @@ -40,6 +40,13 @@ # Utils (exposed for convenience if needed) from .utils import convergence, whitening +from .utils.segmentation import CovarianceSegmenter, FixedWindowSegmenter +from .utils.selection import ( + auto_select_components, + eigenvalue_ratio_selection, + iterative_outlier_removal, + max_gap_selection, +) # Variants (Modules) from 
.variants import narrowband, ssvep, tsr diff --git a/mne_denoise/dss/denoisers/periodic.py b/mne_denoise/dss/denoisers/periodic.py index 8ddd08ba..51ccc798 100644 --- a/mne_denoise/dss/denoisers/periodic.py +++ b/mne_denoise/dss/denoisers/periodic.py @@ -140,6 +140,12 @@ class CombFilterBias(LinearDenoiser): Default 3. q_factor : float Quality factor for each peak. Default 30. + q_mode : ``"fixed"`` | ``"proportional"`` + How Q scales across harmonics. ``"fixed"`` uses the same + ``q_factor`` for every harmonic (bandwidth widens as frequency + increases). ``"proportional"`` scales Q as ``q_factor * h`` for + the *h*-th harmonic, maintaining approximately constant absolute + bandwidth across all harmonics. Default ``"fixed"``. weights : array-like, optional Weights for each harmonic. If None, uses 1/harmonic_number weighting (decreasing importance of higher harmonics). @@ -154,6 +160,10 @@ class CombFilterBias(LinearDenoiser): >>> bias = CombFilterBias( ... fundamental_freq=12, sfreq=500, n_harmonics=4, weights=[1.0, 1.0, 1.0, 1.0] ... ) + >>> # Adaptive Q for constant bandwidth across harmonics + >>> bias = CombFilterBias( + ... fundamental_freq=50, sfreq=1000, n_harmonics=3, q_mode="proportional" + ... 
) See Also -------- @@ -177,6 +187,7 @@ def __init__( *, n_harmonics: int = 3, q_factor: float = 30.0, + q_mode: str = "fixed", weights: np.ndarray | None = None, ) -> None: self.fundamental_freq = fundamental_freq @@ -184,6 +195,14 @@ def __init__( self.n_harmonics = n_harmonics self.q_factor = q_factor + # Validate q_mode + allowed_q_modes = ("fixed", "proportional") + if q_mode not in allowed_q_modes: + raise ValueError( + f"q_mode must be one of {allowed_q_modes}, got {q_mode!r}" + ) + self.q_mode = q_mode + # Set up weights if weights is None: self.weights = np.array([1.0 / h for h in range(1, n_harmonics + 1)]) @@ -212,7 +231,11 @@ def _create_harmonic_filters(self) -> None: w0 = freq / nyq weight = self.weights[h - 1] - b, a = signal.iirpeak(w0, self.q_factor) + # Proportional Q scales linearly with harmonic number, + # maintaining constant absolute bandwidth across harmonics + q = self.q_factor * h if self.q_mode == "proportional" else self.q_factor + + b, a = signal.iirpeak(w0, q) sos = signal.tf2sos(b, a) self._peak_filters.append((sos, weight)) diff --git a/mne_denoise/dss/denoisers/temporal.py b/mne_denoise/dss/denoisers/temporal.py index 11c14d6c..7b5bf056 100644 --- a/mne_denoise/dss/denoisers/temporal.py +++ b/mne_denoise/dss/denoisers/temporal.py @@ -12,6 +12,10 @@ Journal of Neuroscience Methods, 189(1), 113-120. .. [2] de Cheveigné, A. & Simon, J.Z. (2008). Denoising based on spatial filtering. Journal of Neuroscience Methods, 171(2), 331-339. +.. [3] de Cheveigné, A. (2020). ZapLine: A simple and effective method to remove + power line artifacts. NeuroImage, 207, 116356. (Period-matched + smooth/residual decomposition: spatially clean only the residual branch + and add the smooth branch back.) """ from __future__ import annotations @@ -151,7 +155,12 @@ def _prediction_bias(self, data: np.ndarray) -> np.ndarray: class SmoothingBias(LinearDenoiser): """Unified temporal smoothing bias (Moving Average). 
- Uses a boxcar moving average filter to smooth the data." + Uses a boxcar moving average filter to smooth the data. When used to split + the signal into a smooth branch and a residual (``data - smooth``), fitting + DSS on the residual and adding the smooth branch back follows ZapLine's + period-matched decomposition (de Cheveigné, 2020 [3]_): with + ``window = round(sfreq / f_line)`` the smoother has zeros at ``f_line`` and + its harmonics, so the residual concentrates the narrowband artifact. Parameters ---------- diff --git a/mne_denoise/dss/linear.py b/mne_denoise/dss/linear.py index e930965c..756702dc 100644 --- a/mne_denoise/dss/linear.py +++ b/mne_denoise/dss/linear.py @@ -33,6 +33,7 @@ from ..utils import extract_data_from_mne, reconstruct_mne_object from .denoisers import LinearDenoiser from .utils import compute_covariance +from .utils.segmentation import CovarianceSegmenter, FixedWindowSegmenter logger = logging.getLogger(__name__) @@ -93,8 +94,6 @@ def compute_dss( dss_patterns : ndarray, shape (n_channels, n_components) DSS spatial patterns (mixing matrix). Corresponds to the projection matrix **P**. - Note: These are returned in original sensor units (not normalized), - satisfying the identity :math:`X_{rec} = Patterns \times Sources`. eigenvalues : ndarray, shape (n_components,) DSS eigenvalues (ratio of biased power to baseline power). 
@@ -207,13 +206,10 @@ def compute_dss( unmixing_matrix = eigenvectors_white @ W_white @ eigenvectors_biased # ========================================================================= - # STEP 5: Normalize so components have unit variance on baseline + # STEP 5: Normalize so components have unit variance # ========================================================================= norm_factor = np.diag(unmixing_matrix.T @ covariance_baseline @ unmixing_matrix) - # Use a relative threshold for robustness across physical units (MEG/EEG) - max_norm = np.max(norm_factor) - threshold = 1e-18 * max_norm if max_norm > 0 else 1e-30 - norm_factor = np.where(norm_factor > threshold, norm_factor, 1.0) + norm_factor = np.where(norm_factor > 1e-15, norm_factor, 1.0) unmixing_matrix = unmixing_matrix @ np.diag(1.0 / np.sqrt(norm_factor)) # ========================================================================= @@ -233,9 +229,11 @@ def compute_dss( # ========================================================================= dss_filters = unmixing_matrix.T - # DSS patterns (mixing matrix) - # Note: Patterns are in physical units. Use get_normalized_patterns() for visualization. + # DSS patterns: L2-normalized for topographic visualization (Haufe et al. 2014) dss_patterns = covariance_baseline @ unmixing_matrix + pattern_norms = np.sqrt(np.sum(dss_patterns**2, axis=0)) + pattern_norms = np.where(pattern_norms > 1e-15, pattern_norms, 1.0) + dss_patterns = dss_patterns / pattern_norms return dss_filters, dss_patterns, eigenvalues @@ -259,6 +257,34 @@ class DSS(BaseEstimator, TransformerMixin): Bias function to define the signal of interest. Must be an instance of `mne_denoise.dss.LinearDenoiser` (e.g. `BandpassBias`, `TrialAverageBias`) or a callable that takes data and returns biased data. + n_select : int | 'auto' | None, default=None + Number of significant components to auto-select after fitting. 
+ If ``'auto'``, uses the method specified by ``selection_method`` + to determine significant components. The result is stored + in :attr:`n_selected_`. + If ``int``, uses that exact number. + If ``None`` (default), no automatic selection is performed. + selection_method : {'combined', 'outlier', 'ratio', 'max_gap'}, default='combined' + Algorithm for automatic component selection when ``n_select='auto'``: + + - ``'outlier'``: Iterative outlier removal (mean + sigma × std). + Works best when eigenvalue contrast is high (e.g., ZapLine with + smoothing). Uses ``selection_threshold`` as the sigma parameter. + - ``'ratio'``: Eigenvalue ratio test (scree test). Finds the first + drop ≥ ``selection_threshold`` between consecutive eigenvalues. + Works well for moderate eigenvalue contrast. + - ``'max_gap'``: Maximum gap method. Finds the position of the + biggest drop in the eigenvalue spectrum and uses it as the + cutpoint. Most lenient method; works for weak artifacts. + - ``'combined'`` (default): Cascade of all methods — outlier first, + then ratio, then max_gap — returning the first non-zero result. + selection_threshold : float, default=3.0 + Threshold for automatic component selection. + For ``'outlier'`` method: sigma for outlier detection + (components with eigenvalue > mean + sigma × std). + For ``'ratio'`` method: minimum ratio between consecutive + eigenvalues (default 3.0 means a 3× drop). + For ``'combined'``: sigma for the outlier tier, capped at 2.0 for the ratio tier. rank : int or dict, optional Rank of the data for whitening. If None, rank is estimated automatically. reg : float @@ -275,6 +301,51 @@ class DSS(BaseEstimator, TransformerMixin): Additional keywords options for covariance estimation. For MNE objects, passed to `mne.compute_covariance` (e.g. `{'tstep': 0.1, 'rank': 'info'}`). For NumPy arrays, passed to `mne_denoise.utils.compute_covariance` (e.g. `{'shrinkage': 0.1}`). 
+ smooth : SmoothingBias | int | None, default=None + Optional smoothing decomposition before DSS, inspired by ZapLine. + When set, data is decomposed into ``smooth + residual`` and DSS + is fitted/applied on the **residual** only. This dramatically + increases eigenvalue contrast for narrowband artifacts because + DSS no longer competes against broadband EEG variance. + + - If ``SmoothingBias`` instance: used directly. + - If ``int``: interpreted as the smoothing window in samples + (e.g., ``int(sfreq / line_freq)`` for line noise). + - If ``None`` (default): no smoothing, DSS is applied to the + full data (original behavior). + segmented : bool, default=False + If ``True``, data is split into segments and DSS is fitted + independently per segment. This handles **non-stationary** + artifacts whose spatial or spectral profile changes over + time. Requires :meth:`fit_transform`; calling :meth:`fit` + alone raises an error. + segmenter : CovarianceSegmenter | FixedWindowSegmenter | None, default=None + Segmentation strategy. If ``None`` and ``segmented=True``, + a :class:`CovarianceSegmenter` is created automatically + (requires ``sfreq`` to be determinable from the input or + from the bias function). + crossfade : float, default=0.0 + Duration (in seconds) of the cross-fade at segment boundaries + when ``segmented=True``. Adjacent segments are extended by + this amount on each side, cleaned independently, then blended + using a raised-cosine (Hann) overlap-add window. This + eliminates discontinuities at segment boundaries. + If ``0.0`` (default), segments are hard-concatenated, matching + ZapLine-plus, which concatenates cleaned chunks directly + (Klug & Kloosterman, 2022); the cross-fade is an ``mne-denoise`` + addition for smoother boundaries. Typical values: ``0.5`` – ``2.0`` s. + max_prop_remove : float | None, default=None + Maximum proportion of channels that can be removed per segment. + E.g. ``0.2`` caps ``n_selected`` at ``int(n_channels × 0.2)``. 
+ Safety valve to prevent over-cleaning; mirrors ZapLine-plus, which + caps the automatic component count at one-fifth of the channels + (Klug & Kloosterman, 2022, §2.4). + min_select : int, default=0 + Minimum number of components removed from each segment; it is + applied after the ``max_prop_remove`` cap. Guarantees a floor on + cleaning strength. Only effective when ``segmented=True``. Mirrors + ZapLine-plus's fixed-removal floor (``fixedNremove``; Klug & + Kloosterman, 2022). return_type : {'sources', 'epochs', 'raw'} Type of object to return from `transform`. 'sources' returns a numpy array of DSS components. 'epochs'/'raw' returns the denoised input object. @@ -287,6 +358,14 @@ class DSS(BaseEstimator, TransformerMixin): The spatial patterns (mixing matrix). eigenvalues_ : array, shape (n_components,) The power of each component in the biased data (bias score). + n_selected_ : int | None + Number of significant components detected by automatic selection. + Only set when ``n_select`` is not ``None``. Use this to determine + how many components to remove/keep in downstream processing. + segment_results_ : list of dict | None + Per-segment metadata when ``segmented=True``. Each dict + contains ``'start'``, ``'end'``, ``'n_selected'``, + ``'eigenvalues'``, and ``'patterns'``. 
Examples -------- @@ -313,30 +392,92 @@ def __init__( self, bias: LinearDenoiser | Callable, n_components: int | None = None, + n_select: int | str | None = None, + selection_method: str = "combined", + selection_threshold: float = 3.0, rank: int | dict | None = None, reg: float = 1e-9, normalize_input: bool = True, cov_method: str = "empirical", cov_kws: dict | None = None, + smooth: LinearDenoiser | int | None = None, + segmented: bool = False, + segmenter: CovarianceSegmenter | FixedWindowSegmenter | None = None, + crossfade: float = 0.0, + max_prop_remove: float | None = None, + min_select: int = 0, return_type: str = "sources", ) -> None: self.n_components = n_components self.bias = bias + self.n_select = n_select + self.selection_method = selection_method + self.selection_threshold = selection_threshold self.rank = rank self.reg = reg self.normalize_input = normalize_input self.cov_method = cov_method self.cov_kws = cov_kws + self.smooth = smooth + self.segmented = segmented + self.segmenter = segmenter + self.crossfade = crossfade + self.max_prop_remove = max_prop_remove + self.min_select = min_select self.return_type = return_type - # Fitted attributes - self.filters_: np.ndarray | None = None - self.patterns_: np.ndarray | None = None - self.mixing_: np.ndarray | None = None - self.eigenvalues_: np.ndarray | None = None - self.explained_variance_: np.ndarray | None = None - self.channel_norms_: np.ndarray | None = None - self._mne_info = None + # Fit attributes + self.filters_ = None + self.patterns_ = None + self.mixing_ = None + self.eigenvalues_ = None + self.explained_variance_ = None + self.info_ = None + self.channel_norms_ = None + self.n_selected_ = None + self.segment_results_ = None + self._smoother = None # Resolved SmoothingBias instance + + def _resolve_smoother(self): + """Resolve the ``smooth`` parameter to a ``SmoothingBias`` instance.""" + from .denoisers.temporal import SmoothingBias + + if self.smooth is None: + self._smoother = None + 
elif isinstance(self.smooth, int): + self._smoother = SmoothingBias(window=self.smooth, iterations=1) + elif isinstance(self.smooth, SmoothingBias): + self._smoother = self.smooth + elif hasattr(self.smooth, "apply"): + # Duck-type: any LinearDenoiser with .apply() method + self._smoother = self.smooth + else: + raise TypeError( + f"smooth must be SmoothingBias, int, or None, " + f"got {type(self.smooth)}" + ) + + def _decompose_smooth(self, data: np.ndarray): + """Decompose data into smooth and residual components. + + Parameters + ---------- + data : ndarray, shape (n_channels, n_times) or (n_ch, n_times, n_ep) + Input data. + + Returns + ------- + data_smooth : ndarray + Smoothed (low-frequency / broadband) component. + data_residual : ndarray + Residual (narrowband / artifact) component. + """ + if self._smoother is None: + return None, data + + data_smooth = self._smoother.apply(data) + data_residual = data - data_smooth + return data_smooth, data_residual def fit( self, @@ -366,24 +507,129 @@ def fit( self : DSS The fitted transformer. """ + if self.segmented: + raise RuntimeError( + "Segmented mode requires simultaneous fit and transform. " + "Use fit_transform() instead." 
+ ) + if self.normalize_input: X_norm = self._normalize(X, fit=True) else: X_norm = X - if mne is not None and isinstance(X_norm, BaseRaw | BaseEpochs | Evoked): + # Resolve smoothing (if configured) + self._resolve_smoother() + + # If smoothing is enabled, decompose and fit on residual only + if self._smoother is not None: + data, _, mne_type, _, _, _ = extract_data_from_mne(X_norm) + if mne_type == "epochs": + data = np.transpose(data, (1, 2, 0)) + + _, data_residual = self._decompose_smooth(data) + # Fit DSS on residual (always numpy path) + self._fit_numpy(data_residual, weights=weights) + elif mne is not None and isinstance(X_norm, BaseRaw | BaseEpochs | Evoked): self._fit_mne(X_norm, weights=weights) elif isinstance(X_norm, np.ndarray): self._fit_numpy(X_norm, weights=weights) else: raise TypeError(f"Unsupported input type: {type(X_norm)}") - # Compute mixing matrix - # self.patterns_ from compute_dss already satisfy X = P @ S - self.mixing_ = self.patterns_ + # Compute mixing matrix (pseudoinverse of filters) + self.mixing_ = np.linalg.pinv(self.filters_) + + # Automatic component selection + if self.n_select is not None and self.eigenvalues_ is not None: + self.n_selected_ = self.auto_select() return self + def auto_select(self, threshold=None, method=None): + """Automatically determine the number of significant DSS components. + + Supports multiple selection strategies: + + - **outlier**: Iterative outlier removal (mean + sigma × std). + Best for high eigenvalue contrast (e.g., after smoothing). + - **ratio**: Eigenvalue ratio / scree test. Finds the first large + drop between consecutive eigenvalues. For moderate contrast. + - **max_gap**: Maximum gap method. Finds the *biggest* drop in + the eigenvalue spectrum. Most lenient; works for weak artifacts. + - **combined**: Cascade — outlier → ratio → max_gap — returns + the first non-zero result. + + This method is called automatically during :meth:`fit` when + ``n_select`` is set. 
It can also be called manually after fitting + with a different threshold or method. + + Parameters + ---------- + threshold : float | None + Override the threshold. If ``None``, uses + ``self.selection_threshold``. + method : {'outlier', 'ratio', 'max_gap', 'combined'} | None + Override the selection method. If ``None``, uses + ``self.selection_method``. + + Returns + ------- + n_selected : int + Number of significant components detected. + + Raises + ------ + RuntimeError + If the estimator has not been fitted yet. + + Examples + -------- + >>> dss = DSS(bias=my_bias, n_components=30) + >>> dss.fit(raw) + >>> n = dss.auto_select(threshold=2.5, method='outlier') + >>> print(f"{n} significant components at sigma=2.5") + """ + if self.eigenvalues_ is None: + raise RuntimeError("DSS not fitted. Call fit() first.") + + from .utils.selection import ( + eigenvalue_ratio_selection, + iterative_outlier_removal, + max_gap_selection, + ) + + threshold = threshold if threshold is not None else self.selection_threshold + method = method if method is not None else self.selection_method + + if isinstance(self.n_select, int): + return min(self.n_select, len(self.eigenvalues_)) + + if method == "outlier": + return iterative_outlier_removal(self.eigenvalues_, threshold) + elif method == "ratio": + return eigenvalue_ratio_selection(self.eigenvalues_, threshold) + elif method == "max_gap": + return max_gap_selection(self.eigenvalues_, min_ratio=min(threshold, 1.2)) + elif method == "combined": + # Tier 1: Outlier removal (strict — needs high contrast) + n = iterative_outlier_removal(self.eigenvalues_, threshold) + if n > 0: + return n + # Tier 2: Ratio test (moderate — needs a clear drop) + ratio_th = min(threshold, 2.0) + n = eigenvalue_ratio_selection(self.eigenvalues_, ratio_th) + if n > 0: + return n + # Tier 3: Max gap (lenient — finds the biggest drop wherever) + n = max_gap_selection(self.eigenvalues_, min_ratio=1.2) + return n + else: + raise ValueError( + f"Unknown 
selection method '{method}'. " + "Choose from 'outlier', 'ratio', 'max_gap', or 'combined'." + ) + def _normalize( self, X: BaseRaw | BaseEpochs | Evoked | np.ndarray, fit: bool = False ) -> BaseRaw | BaseEpochs | Evoked | np.ndarray: @@ -419,8 +665,8 @@ def _normalize( data_2d = data if fit: - # unique std per channel - self.channel_norms_ = np.std(data_2d, axis=1) + # unique norms per channel + self.channel_norms_ = np.linalg.norm(data_2d, axis=1) # Avoid division by zero self.channel_norms_ = np.where( self.channel_norms_ > 0, self.channel_norms_, 1.0 @@ -603,13 +849,20 @@ def transform( if mne_type == "epochs": data = np.transpose(data, (1, 2, 0)) + # If smoothing is enabled, project the residual (not full data) + if self._smoother is not None: + data_smooth, data_for_dss = self._decompose_smooth(data) + else: + data_smooth = None + data_for_dss = data + orig_shape = data.shape - if data.ndim == 3: - n_ch, n_times, n_epochs = data.shape - data_2d = data.reshape(n_ch, -1) + if data_for_dss.ndim == 3: + n_ch, n_times, n_epochs = data_for_dss.shape + data_2d = data_for_dss.reshape(n_ch, -1) else: - n_ch, n_times = data.shape - data_2d = data + n_ch, n_times = data_for_dss.shape + data_2d = data_for_dss # Center using mean on data_2d # DSS implies zero-mean assumption for correct projection @@ -634,6 +887,15 @@ def transform( rec = self.mixing_[:, :n_keep] @ sources[:n_keep] rec += mean_ + # Add back smooth component if it was separated + if data_smooth is not None: + smooth_2d = ( + data_smooth.reshape(data_smooth.shape[0], -1) + if data_smooth.ndim == 3 + else data_smooth + ) + rec = rec + smooth_2d + # Reshape to original if len(orig_shape) == 3: rec = rec.reshape(orig_shape) # (n_ch, n_times, n_epochs) @@ -653,6 +915,24 @@ def transform( rec, orig_inst, mne_type, picks=picks, verbose=False ) + def get_normalized_patterns(self) -> np.ndarray: + """Get L2-normalized spatial patterns for visualization. 
+ + Returns + ------- + patterns_norm : ndarray, shape (n_channels, n_components) + L2-normalized spatial patterns. + """ + if self.patterns_ is None: + raise RuntimeError("DSS not fitted. Call fit() first.") + + norms = np.linalg.norm(self.patterns_, axis=0) + # Use relative threshold for physical units + max_norm = np.max(norms) + threshold = 1e-15 * max_norm if max_norm > 0 else 1e-30 + norms = np.where(norms > threshold, norms, 1.0) + return self.patterns_ / norms + def inverse_transform( self, sources: np.ndarray, component_indices: np.ndarray | None = None ) -> np.ndarray: @@ -731,20 +1011,406 @@ def inverse_transform( return rec - def get_normalized_patterns(self) -> np.ndarray: - """Get L2-normalized spatial patterns for visualization. + # ----------------------------------------------------------------- + # Segmented mode + # ----------------------------------------------------------------- + + def fit_transform( + self, X, y=None, **fit_params + ): + """Fit and transform data in one step. + + In **segmented mode** (``segmented=True``), the data is split into + segments and each segment gets its own independent DSS fit + + cleaning pass. This is the only entry-point for segmented + processing because ``fit()`` alone is not meaningful when + filters differ per segment. + + In standard mode, this is equivalent to + ``self.fit(X).transform(X)``. + + Parameters + ---------- + X : Raw | Epochs | Evoked | ndarray + The data to process. + y : None + Ignored. + **fit_params + Additional keyword arguments forwarded to :meth:`fit`. Returns ------- - patterns_norm : ndarray, shape (n_channels, n_components) - L2-normalized spatial patterns. + X_out : ndarray | Raw | Epochs | Evoked + In segmented mode, returns cleaned data (same type as input). + In standard mode with ``return_type='sources'``, returns DSS + source time-series. 
With any other ``return_type``, returns + cleaned (denoised) data produced by subtracting the artifact + captured by the first ``n_selected_`` components. """ - if self.patterns_ is None: - raise RuntimeError("DSS not fitted. Call fit() first.") + if not self.segmented: + self.fit(X, **fit_params) + + if self.return_type == "sources": + return self.transform(X) + + # ── Denoise via artifact subtraction ── + data, _, mne_type, orig_inst, _, _ = extract_data_from_mne(X) + + n_remove = self.n_selected_ if self.n_selected_ is not None else 0 + if n_remove > 0: + # Temporarily switch to get source time-series + saved_rt = self.return_type + self.return_type = "sources" + try: + sources = self.transform(X) + finally: + self.return_type = saved_rt + + artifact = self.inverse_transform( + sources, component_indices=np.arange(n_remove) + ) + cleaned = data - artifact + else: + cleaned = data - norms = np.linalg.norm(self.patterns_, axis=0) - # Use relative threshold for physical units - max_norm = np.max(norms) - threshold = 1e-15 * max_norm if max_norm > 0 else 1e-30 - norms = np.where(norms > threshold, norms, 1.0) - return self.patterns_ / norms + return reconstruct_mne_object( + cleaned, orig_inst, mne_type, verbose=False + ) + + # --- segmented mode --- + data, extracted_sfreq, mne_type, orig_inst, _, _ = extract_data_from_mne(X) + + # Determine sfreq + sfreq = extracted_sfreq + if sfreq is None and hasattr(self.bias, "sfreq"): + sfreq = self.bias.sfreq + if sfreq is None: + raise ValueError( + "Cannot determine sfreq for segmented mode. " + "Pass an MNE object or use a bias with a .sfreq attribute." 
+ ) + + # Handle epochs: concatenate into continuous + is_epochs = False + if data.ndim == 3: + is_epochs = True + n_ep, n_ch, n_t = data.shape + data_cont = np.transpose(data, (1, 0, 2)).reshape(n_ch, -1) + else: + data_cont = data + + # Resolve smoother once + self._resolve_smoother() + + # Run segmented processing + cleaned = self._run_segmented(data_cont, sfreq) + + # Reshape back if epochs + if is_epochs: + cleaned = cleaned.reshape(n_ch, n_ep, n_t).transpose(1, 0, 2) + + return reconstruct_mne_object(cleaned, orig_inst, mne_type, verbose=False) + + def _resolve_segmenter(self, sfreq: float): + """Resolve the segmenter parameter. + + If ``self.segmenter`` is ``None``, creates a default + :class:`CovarianceSegmenter` with optional bandpass from the + bias function. + + Parameters + ---------- + sfreq : float + Sampling frequency in Hz. + + Returns + ------- + segmenter : CovarianceSegmenter | FixedWindowSegmenter + """ + if self.segmenter is not None: + return self.segmenter + + # Build a default CovarianceSegmenter + bandpass = None + # If the bias has a target frequency, focus segmentation around it + if hasattr(self.bias, "freq") and self.bias.freq is not None: + f = float(self.bias.freq) + bandpass = (max(1.0, f - 3), min(sfreq / 2 - 1, f + 3)) + + return CovarianceSegmenter( + sfreq=sfreq, + min_chunk_len=30.0, + bandpass=bandpass, + ) + + def _run_segmented(self, data: np.ndarray, sfreq: float) -> np.ndarray: + """Run segmented fit-transform on continuous data. + + Each segment gets an independent DSS fit and cleaning pass. + When :attr:`crossfade` is positive and there are multiple + segments, adjacent segments are extended by ``crossfade`` + seconds on each side and combined using raised-cosine (Hann) + overlap-add to eliminate boundary discontinuities. + + Parameters + ---------- + data : ndarray, shape (n_channels, n_times) + Continuous data. + sfreq : float + Sampling frequency. 
+ + Returns + ------- + cleaned : ndarray, shape (n_channels, n_times) + Cleaned data (segments blended via cross-fade or + concatenated). + """ + import logging + + logger = logging.getLogger(__name__) + + segmenter = self._resolve_segmenter(sfreq) + segments = segmenter.segment(data) + + if not segments: + raise ValueError( + "Segmenter returned no segments. Check segmenter settings " + "and data length." + ) + + logger.info( + f"Segmented DSS: {len(segments)} segment(s) " + f"over {data.shape[1] / sfreq:.1f}s" + ) + + # ------ cross-fade setup ------ + n_overlap = ( + int(self.crossfade * sfreq) if self.crossfade > 0 else 0 + ) + _n_ch, n_times = data.shape + use_crossfade = n_overlap > 0 and len(segments) > 1 + + if use_crossfade: + min_seg_len = min(end - start for start, end in segments) + if n_overlap > min_seg_len // 2: + n_overlap = max(1, min_seg_len // 2) + logger.warning( + f"Crossfade overlap clamped to {n_overlap} samples " + f"({n_overlap / sfreq:.2f}s) — half the smallest " + f"segment." 
def _crossfade_combine(
    self,
    shape: tuple[int, int],
    cleaned_chunks: list[dict],
    n_overlap: int,
) -> np.ndarray:
    """Blend overlapping cleaned chunks via weighted overlap-add.

    Every chunk is weighted by a window that equals one inside the
    chunk's own segment and follows a raised-cosine ramp over any
    samples it was extended by (before ``start`` / after ``end``).
    The weighted chunks are accumulated and divided by the accumulated
    weights.

    Parameters
    ----------
    shape : (int, int)
        ``(n_channels, n_times)`` of the array to produce.
    cleaned_chunks : list of dict
        One dict per chunk with keys ``'cleaned'`` (ndarray covering
        ``ext_start:ext_end``), ``'ext_start'``, ``'ext_end'``,
        ``'start'`` and ``'end'``.
    n_overlap : int
        Overlap length in samples. Not read here; the ramp lengths are
        derived from the per-chunk boundaries.

    Returns
    -------
    output : ndarray, shape (n_channels, n_times)
    """
    n_rows, n_cols = shape
    blended = np.zeros((n_rows, n_cols))
    weight_sum = np.zeros(n_cols)

    for entry in cleaned_chunks:
        segment_data = entry["cleaned"]
        lo = entry["ext_start"]
        hi = entry["ext_end"]
        seg_lo = entry["start"]
        seg_hi = entry["end"]

        span = hi - lo
        taper = np.ones(span)

        # Rising ramp over the samples borrowed from the previous segment.
        n_lead = seg_lo - lo
        if n_lead > 0:
            ramp = np.arange(n_lead, dtype=float)
            taper[:n_lead] = 0.5 * (1.0 - np.cos(np.pi * ramp / n_lead))

        # Falling ramp over the samples borrowed from the next segment.
        n_trail = hi - seg_hi
        if n_trail > 0:
            ramp = np.arange(n_trail, dtype=float)
            taper[span - n_trail:] = 0.5 * (
                1.0 + np.cos(np.pi * ramp / n_trail)
            )

        blended[:, lo:hi] += segment_data * taper[np.newaxis, :]
        weight_sum[lo:hi] += taper

    # Floor the weights so samples no chunk contributed to do not divide by 0.
    weight_sum = np.maximum(weight_sum, 1e-10)
    blended /= weight_sum[np.newaxis, :]
    return blended


def _process_segment(self, chunk: np.ndarray) -> dict:
    """Fit a fresh DSS on one segment, pick a component count and clean it.

    A new non-segmented ``DSS`` is configured from this instance's
    settings and fitted on ``chunk``. Its ``n_selected_`` (``None`` is
    treated as 0) is capped at ``int(n_channels * max_prop_remove)``
    when ``max_prop_remove`` is set, then raised to at least
    ``min_select``. Because the floor is applied last, ``min_select``
    wins over the cap.

    Parameters
    ----------
    chunk : ndarray, shape (n_channels, n_times)
        Data segment.

    Returns
    -------
    result : dict
        Keys ``'cleaned'``, ``'n_selected'``, ``'eigenvalues'``,
        ``'patterns'`` and ``'filters'``.
    """
    n_channels = chunk.shape[0]

    # Only an int or None rank is forwarded; anything else becomes None.
    segment_rank = (
        self.rank if isinstance(self.rank, int | type(None)) else None
    )
    segment_model = DSS(
        bias=self.bias,
        n_components=self.n_components,
        n_select=self.n_select,
        selection_method=self.selection_method,
        selection_threshold=self.selection_threshold,
        rank=segment_rank,
        reg=self.reg,
        normalize_input=self.normalize_input,
        cov_method=self.cov_method,
        cov_kws=self.cov_kws,
        smooth=self.smooth,
        segmented=False,  # the per-segment model must not segment again
    )
    segment_model.fit(chunk)

    n_keep = segment_model.n_selected_
    if n_keep is None:
        n_keep = 0

    if self.max_prop_remove is not None:
        ceiling = int(n_channels * self.max_prop_remove)
        n_keep = min(n_keep, ceiling)
    n_keep = max(n_keep, self.min_select)

    cleaned = self._clean_segment(chunk, segment_model, n_keep)

    return {
        "cleaned": cleaned,
        "n_selected": n_keep,
        "eigenvalues": segment_model.eigenvalues_,
        "patterns": segment_model.patterns_,
        "filters": segment_model.filters_,
    }


def _clean_segment(
    self, data: np.ndarray, fitted_dss: DSS, n_remove: int
) -> np.ndarray:
    """Remove the first *n_remove* DSS components from a segment.

    When the fitted model has a smoother, the data are first split into
    a smooth part and a residual; only the residual is projected. The
    residual is mean-centred per channel before being passed through
    the first ``n_remove`` filters, and the back-projected artifact is
    subtracted from the (uncentred) residual.

    Parameters
    ----------
    data : ndarray, shape (n_channels, n_times)
        Segment data.
    fitted_dss : DSS
        Fitted DSS instance; ``filters_``, ``mixing_``, ``_smoother``
        and (if smoothing) ``_decompose_smooth`` are used.
    n_remove : int
        Number of leading components to remove.

    Returns
    -------
    cleaned : ndarray, shape (n_channels, n_times)
        A copy of ``data`` when ``n_remove <= 0`` or no filters exist.
    """
    if n_remove <= 0 or fitted_dss.filters_ is None:
        return data.copy()

    if fitted_dss._smoother is None:
        smooth_part = np.zeros_like(data)
        residual = data
    else:
        smooth_part, residual = fitted_dss._decompose_smooth(data)

    # The filters were estimated on zero-mean data, so centre first.
    residual_zero_mean = residual - residual.mean(axis=1, keepdims=True)

    top_filters = fitted_dss.filters_[:n_remove]
    top_mixing = fitted_dss.mixing_[:, :n_remove]
    artifact = top_mixing @ (top_filters @ residual_zero_mean)

    return smooth_part + (residual - artifact)
@@ -375,9 +374,13 @@ def iterative_dss( # sensor_filter = whitened_filter @ whitener filters = filters_whitened @ whitener # (n_components, n_channels) - # patterns = C @ filters.T + # Compute patterns using covariance of CENTERED data + # patterns = C @ filters.T, normalized C = data_centered @ data_centered.T / n_samples patterns = C @ filters.T + pattern_norms = np.linalg.norm(patterns, axis=0) + pattern_norms = np.where(pattern_norms > 1e-12, pattern_norms, 1.0) + patterns = patterns / pattern_norms return filters, sources, patterns, convergence_info @@ -475,6 +478,7 @@ def _iterative_dss_deflation( # Orthogonalize against previous components (vectorized) if i > 0: W_prev = W[:i] # (i, n_whitened) + # w = w - sum( (w @ w_prev.T) * w_prev ) # Vectorized: w - W_prev.T @ (W_prev @ w) w = w - W_prev.T @ (W_prev @ w) norm = np.linalg.norm(w) @@ -728,7 +732,6 @@ def __init__( method: str = "deflation", rank: int | None = None, reg: float = 1e-9, - normalize_input: bool = True, max_iter: int = 100, tol: float = 1e-6, verbose: bool = False, @@ -742,7 +745,6 @@ def __init__( self.method = method self.rank = rank self.reg = reg - self.normalize_input = normalize_input self.max_iter = max_iter self.tol = tol self.verbose = verbose @@ -787,24 +789,6 @@ def fit(self, X) -> IterativeDSS: ): self._mne_info = mne_info.info - if self.normalize_input: - # Flatten for std calculation: (n_ch, n_times * n_epochs) - d_flat = ( - data.transpose(1, 0, 2).reshape(data.shape[1], -1) - if data.ndim == 3 - else data - ) - self.channel_norms_ = np.std(d_flat, axis=1) - self.channel_norms_ = np.where( - self.channel_norms_ > 0, self.channel_norms_, 1.0 - ) - - # Apply to data - if data.ndim == 3: - data = data / self.channel_norms_[np.newaxis, :, np.newaxis] - else: - data = data / self.channel_norms_[:, np.newaxis] - filters, sources, patterns, conv_info = iterative_dss( data, self.denoiser, @@ -857,12 +841,7 @@ def transform(self, X) -> np.ndarray: else: data_2d = data - if 
self.normalize_input: - if self.channel_norms_ is None: - raise RuntimeError( - "IterativeDSS not fitted with normalize_input=True. Call fit() first." - ) - data_2d = data_2d / self.channel_norms_[:, np.newaxis] + n_times = data_2d.shape[1] # Center data_centered = data_2d - data_2d.mean(axis=1, keepdims=True) @@ -898,46 +877,10 @@ def inverse_transform(self, sources: np.ndarray) -> np.ndarray: if self.patterns_ is None: raise RuntimeError("IterativeDSS not fitted. Call fit() first.") - n_comp_sources = sources.shape[1] if sources.ndim == 3 else sources.shape[0] - patterns = self.patterns_[:, :n_comp_sources] - - if sources.ndim == 3: - # Assume MNE format (n_epochs, n_comp, n_times) - rec = np.tensordot(sources, patterns, axes=(1, 1)).transpose(0, 2, 1) - if self.normalize_input: - if self.channel_norms_ is None: - raise RuntimeError( - "IterativeDSS not fitted with normalize_input=True. Call fit() first." - ) - rec *= self.channel_norms_[np.newaxis, :, np.newaxis] - else: - rec = patterns @ sources - if self.normalize_input: - if self.channel_norms_ is None: - raise RuntimeError( - "IterativeDSS not fitted with normalize_input=True. Call fit() first." - ) - rec *= self.channel_norms_[:, np.newaxis] - - return rec - - def get_normalized_patterns(self) -> np.ndarray: - """Get L2-normalized spatial patterns for visualization. - - Returns - ------- - patterns_norm : ndarray, shape (n_channels, n_components) - L2-normalized spatial patterns. - """ - if self.patterns_ is None: - raise RuntimeError("IterativeDSS not fitted. 
Call fit() first.") + n_sources = sources.shape[0] + patterns = self.patterns_[:, :n_sources] - norms = np.linalg.norm(self.patterns_, axis=0) - # Use relative threshold for physical units - max_norm = np.max(norms) - threshold = 1e-15 * max_norm if max_norm > 0 else 1e-30 - norms = np.where(norms > threshold, norms, 1.0) - return self.patterns_ / norms + return patterns @ sources def fit_transform(self, X) -> np.ndarray: """Fit and transform in one step. diff --git a/mne_denoise/dss/utils/__init__.py b/mne_denoise/dss/utils/__init__.py index 978cb508..f2f3ba1f 100644 --- a/mne_denoise/dss/utils/__init__.py +++ b/mne_denoise/dss/utils/__init__.py @@ -2,11 +2,14 @@ from .convergence import Gamma179, GammaPredictive from .covariance import compute_covariance +from .segmentation import CovarianceSegmenter, FixedWindowSegmenter from .selection import ( auto_select_components, auto_select_components_robust, detect_eigenvalue_knee, + eigenvalue_ratio_selection, iterative_outlier_removal, + max_gap_selection, ) from .whitening import compute_whitener, whiten_data @@ -18,6 +21,10 @@ "auto_select_components", "auto_select_components_robust", "detect_eigenvalue_knee", + "eigenvalue_ratio_selection", + "max_gap_selection", "Gamma179", "GammaPredictive", + "CovarianceSegmenter", + "FixedWindowSegmenter", ] diff --git a/mne_denoise/dss/utils/segmentation.py b/mne_denoise/dss/utils/segmentation.py new file mode 100644 index 00000000..4b762027 --- /dev/null +++ b/mne_denoise/dss/utils/segmentation.py @@ -0,0 +1,211 @@ +"""Data segmentation utilities for DSS. + +Provides strategies for splitting continuous data into segments based on +statistical properties. Used by :class:`~mne_denoise.dss.linear.DSS` in +segmented mode to handle non-stationary artifacts. + +Available segmenters +-------------------- +- :class:`CovarianceSegmenter` – splits at covariance-stationarity boundaries. +- :class:`FixedWindowSegmenter` – splits into fixed-length windows. 
class CovarianceSegmenter:
    """Segment data based on covariance stationarity.

    Identifies boundaries where the trace-normalised spatial covariance
    matrix changes markedly between consecutive analysis windows,
    indicating non-stationary noise characteristics. Based on the
    ZapLine-plus segmentation algorithm [1]_.

    Parameters
    ----------
    sfreq : float
        Sampling frequency in Hz.
    min_chunk_len : float, default=30.0
        Minimum segment length in seconds.
    cov_win_len : float, default=1.0
        Window length for covariance computation in seconds. Must span
        at least one sample (``cov_win_len * sfreq >= 1``).
    bandpass : tuple of (float, float) | None, default=None
        Bandpass filter range ``(f_low, f_high)`` in Hz to focus the
        stationarity analysis on a specific frequency band. Useful for
        frequency-specific artifacts. If ``None``, uses unfiltered data.

    References
    ----------
    .. [1] Klug, M., & Kloosterman, N. A. (2022). Zapline-plus: A Zapline
       extension for automatic and adaptive removal of frequency-specific
       noise artifacts in M/EEG. Human Brain Mapping, 43(9), 2743-2758.
    """

    def __init__(
        self,
        sfreq: float,
        min_chunk_len: float = 30.0,
        cov_win_len: float = 1.0,
        bandpass: tuple[float, float] | None = None,
    ) -> None:
        self.sfreq = float(sfreq)
        self.min_chunk_len = min_chunk_len
        self.cov_win_len = cov_win_len
        self.bandpass = bandpass

    def segment(self, data: np.ndarray) -> list[tuple[int, int]]:
        """Segment data into stationary chunks.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_times)
            Input data.

        Returns
        -------
        segments : list of (int, int)
            List of ``(start_sample, end_sample)`` tuples (plain Python
            ints) that tile ``[0, n_times)``. A single segment covering
            all samples is returned when the data are shorter than one
            covariance window or fewer than two windows fit.

        Raises
        ------
        ValueError
            If ``cov_win_len * sfreq`` is shorter than one sample.
        """
        _, n_times = data.shape

        n_win = int(self.cov_win_len * self.sfreq)
        if n_win < 1:
            # A zero-length window would otherwise surface as an opaque
            # ZeroDivisionError when counting windows below.
            raise ValueError(
                "cov_win_len * sfreq must span at least one sample, got "
                f"cov_win_len={self.cov_win_len!r} at sfreq={self.sfreq!r}."
            )
        # Decide this before filtering: nothing to analyse, and very short
        # inputs can be rejected by the zero-phase filter's edge padding.
        if n_win > n_times:
            return [(0, n_times)]

        # Optional bandpass filter to focus the analysis on one band.
        if self.bandpass is not None:
            f_low, f_high = self.bandpass
            sos = signal.butter(
                4, [f_low, f_high], btype="bandpass", fs=self.sfreq, output="sos"
            )
            data_filt = signal.sosfiltfilt(sos, data, axis=1)
        else:
            data_filt = data

        # Covariance of each non-overlapping window, scaled to unit trace so
        # that only the spatial structure (not overall power) is compared.
        covs = []
        for i in range(n_times // n_win):
            window = data_filt[:, i * n_win : (i + 1) * n_win]
            cov = compute_covariance(window)
            tr = np.trace(cov)
            if tr > 1e-20:
                cov = cov / tr
            covs.append(cov)
        covs = np.array(covs)

        # Frobenius distance between each pair of consecutive windows.
        dists = np.array(
            [
                np.linalg.norm(prev - nxt, ord="fro")
                for prev, nxt in zip(covs[:-1], covs[1:])
            ]
        )
        if len(dists) == 0:
            return [(0, n_times)]

        # Peaks in the distance series are boundary candidates.
        min_samples = int(self.min_chunk_len * self.sfreq)
        min_distance = max(1, int(self.min_chunk_len * self.sfreq / n_win))
        peak_indices, _ = find_peaks(
            dists, prominence=np.std(dists) * 0.5, distance=min_distance
        )

        # Keep only candidates that leave segments of the minimum length.
        boundaries = [0]
        for peak in peak_indices:
            # dists[k] compares windows k and k+1, so the change sits at the
            # start of window k+1. Cast so callers get plain Python ints.
            candidate = int((peak + 1) * n_win)
            if candidate - boundaries[-1] >= min_samples:
                boundaries.append(candidate)

        # Drop the last boundary if it would leave a too-short final segment.
        if n_times - boundaries[-1] < min_samples and len(boundaries) > 1:
            boundaries.pop()
        boundaries.append(n_times)

        return list(zip(boundaries[:-1], boundaries[1:]))
class FixedWindowSegmenter:
    """Segment data into fixed-length windows.

    Simple segmentation strategy that divides data into equal-length chunks.
    A trailing chunk shorter than half a window is merged into the
    previous one.

    Parameters
    ----------
    sfreq : float
        Sampling frequency in Hz.
    window_len : float, default=30.0
        Window length in seconds. Must span at least one sample
        (``window_len * sfreq >= 1``).
    """

    def __init__(self, sfreq: float, window_len: float = 30.0) -> None:
        self.sfreq = float(sfreq)
        self.window_len = window_len

    def segment(self, data: np.ndarray) -> list[tuple[int, int]]:
        """Segment data into fixed-length windows.

        Parameters
        ----------
        data : ndarray, shape (n_channels, n_times)
            Input data.

        Returns
        -------
        segments : list of (int, int)
            List of ``(start_sample, end_sample)`` tuples that tile
            ``[0, n_times)``. A single segment is returned when the
            window is at least as long as the data.

        Raises
        ------
        ValueError
            If ``window_len * sfreq`` is shorter than one sample.
        """
        n_times = data.shape[1]
        win_samples = int(self.window_len * self.sfreq)

        if win_samples < 1:
            # A zero (or negative) step would never advance through the
            # data, so reject it instead of looping forever.
            raise ValueError(
                "window_len * sfreq must span at least one sample, got "
                f"window_len={self.window_len!r} at sfreq={self.sfreq!r}."
            )
        if win_samples >= n_times:
            return [(0, n_times)]

        # win_samples < n_times here, so at least two windows are produced.
        segments = [
            (start, min(start + win_samples, n_times))
            for start in range(0, n_times, win_samples)
        ]

        # Only the final window can be short; fold a tiny tail (< 50 % of a
        # window) into its predecessor.
        tail_start, tail_end = segments[-1]
        if tail_end - tail_start < win_samples // 2:
            segments.pop()
            prev_start, _ = segments[-1]
            segments[-1] = (prev_start, tail_end)

        return segments
Authors: Sina Esmaeili (sina.esmaeili@umontreal.ca) Hamza Abdelhedi (hamza.abdelhedi@umontreal.ca) @@ -16,7 +17,18 @@ def iterative_outlier_removal(scores: np.ndarray, sigma: float = 3.0) -> int: This algorithm iteratively identifies values that exceed `mean + sigma * std`, removes them from consideration, and repeats until no more outliers are found. - This is equivalent to MATLAB's `iterative_outlier_removal` from NoiseTools. + It follows NoiseTools' ``nt_dss``-style outlier rule, which ZapLine-plus + (Klug & Kloosterman, 2022, §2.4 "Detection of noise components") adopts to + automatically choose how many spatial components to remove: outliers in the + component scores are flagged with a ``mean + sigma * SD`` threshold and + removed, the mean/SD are recomputed across the remaining components, and the + loop repeats until none are left; the count of removed outliers is taken as + the number of components to reject. ZapLine-plus uses a default of + ``sigma=3`` and reports this iterative mean/SD rule to be more robust than a + median-absolute-deviation rule in this setting. Callers cap the resulting + count via ``max_prop_remove`` (ZapLine-plus caps at one-fifth of the + channels) and floor it via ``min_select``; see + :class:`~mne_denoise.dss.linear.DSS`. Useful for automatic component selection in DSS applications, such as: - ZapLine: Selecting how many line-noise components to remove @@ -46,7 +58,11 @@ def iterative_outlier_removal(scores: np.ndarray, sigma: float = 3.0) -> int: References ---------- - NoiseTools: http://audition.ens.fr/adc/NoiseTools/ + .. [1] Klug, M., & Kloosterman, N. A. (2022). Zapline-plus: A Zapline + extension for automatic and adaptive removal of frequency-specific + noise artifacts in M/EEG. Human Brain Mapping, 43(9), 2743-2758. + (§2.4, "Detection of noise components".) + .. 
def eigenvalue_ratio_selection(
    eigenvalues: np.ndarray, ratio_threshold: float = 2.0
) -> int:
    """Select components using eigenvalue ratio (scree test).

    Identifies the first significant "drop" in the eigenvalue spectrum.
    A component is considered significant if the ratio
    ``eigenvalue[i] / eigenvalue[i+1]`` reaches ``ratio_threshold``.

    This method works well when eigenvalue contrast is moderate (e.g., direct
    DSS without smoothing), where the iterative outlier removal may fail to
    detect any significant components.

    .. note::
        Unlike :func:`iterative_outlier_removal` (the ZapLine-plus component
        detector), this scree-style ratio rule is a convenience heuristic
        specific to ``mne-denoise``: it is not part of ZapLine-plus or the DSS
        literature, and ``ratio_threshold=2.0`` is an empirical default, not a
        theory-derived constant.

    Parameters
    ----------
    eigenvalues : ndarray
        DSS eigenvalues, sorted in descending order. Negative values are
        clipped to zero.
    ratio_threshold : float
        Minimum ratio between consecutive eigenvalues to indicate a
        significant drop. Default 2.0 (i.e., a 2x drop).

    Returns
    -------
    n_components : int
        Number of significant components (before the first big drop).
        Returns 0 if no drop reaches the threshold, or if the spectrum is
        numerically zero at the point where a zero eigenvalue is met.
        With fewer than two eigenvalues there is nothing to compare and
        the number of eigenvalues is returned.

    Examples
    --------
    >>> eigenvalues = np.array([0.012, 0.005, 0.004, 0.0015, 0.0008])
    >>> n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0)
    >>> print(f"{n} component(s) before the first big drop")
    1 component(s) before the first big drop
    """
    tiny = 1e-15  # values below this are treated as zero
    eigenvalues = np.asarray(eigenvalues, dtype=float)

    if len(eigenvalues) < 2:
        return len(eigenvalues)

    # Guard against zero/negative eigenvalues
    eigenvalues = np.maximum(eigenvalues, 0.0)

    for i, (current, following) in enumerate(
        zip(eigenvalues[:-1], eigenvalues[1:])
    ):
        if following < tiny:
            # Denominator effectively zero: everything up to here is signal,
            # unless the numerator is zero as well (0/0 is not a drop).
            return i + 1 if current >= tiny else 0
        if current / following >= ratio_threshold:
            return i + 1

    return 0


def max_gap_selection(
    eigenvalues: np.ndarray, min_ratio: float = 1.2
) -> int:
    """Select components by finding the largest gap in the eigenvalue spectrum.

    Instead of requiring a fixed ratio threshold, this method finds the
    position of the maximum consecutive eigenvalue ratio (biggest "drop")
    and uses it as the cutpoint, provided that ratio reaches ``min_ratio``.

    This is the most lenient automatic method and works even when eigenvalue
    contrast is weak (e.g., direct DSS without smoothing on weak artifacts).

    .. note::
        Like :func:`eigenvalue_ratio_selection`, this largest-gap rule is a
        convenience heuristic specific to ``mne-denoise`` (not part of
        ZapLine-plus or the DSS literature); ``min_ratio=1.2`` is empirical.
        Prefer :func:`iterative_outlier_removal` for line-noise / ZapLine-style
        selection.

    Parameters
    ----------
    eigenvalues : ndarray
        DSS eigenvalues, sorted in descending order. Negative values are
        clipped to zero.
    min_ratio : float
        Minimum ratio at the largest gap to be considered meaningful.
        Default 1.2 (i.e., an eigenvalue at least 20% larger than the
        next one). Set lower for noisier data, higher for stricter
        selection.

    Returns
    -------
    n_components : int
        Number of significant components (before the biggest gap).
        Returns 0 if the largest ratio is below ``min_ratio``. With fewer
        than two eigenvalues the number of eigenvalues is returned.

    Examples
    --------
    >>> eigenvalues = np.array([0.0066, 0.0052, 0.0042, 0.0035, 0.003])
    >>> n = max_gap_selection(eigenvalues, min_ratio=1.2)
    >>> print(f"{n} component(s)")  # Finds the 1.27x drop at position 0
    1 component(s)
    """
    eigenvalues = np.asarray(eigenvalues, dtype=float)

    if len(eigenvalues) < 2:
        return len(eigenvalues)

    eigenvalues = np.maximum(eigenvalues, 0.0)

    # Consecutive ratios; the floor keeps zero denominators finite.
    denominators = np.maximum(eigenvalues[1:], 1e-15)
    ratios = eigenvalues[:-1] / denominators

    # Position of the largest gap
    max_idx = np.argmax(ratios)

    if ratios[max_idx] >= min_ratio:
        return int(max_idx + 1)

    return 0
Parameters ---------- diff --git a/mne_denoise/dss/variants/narrowband.py b/mne_denoise/dss/variants/narrowband.py index e8ab23dc..ad3777ae 100644 --- a/mne_denoise/dss/variants/narrowband.py +++ b/mne_denoise/dss/variants/narrowband.py @@ -115,6 +115,12 @@ def narrowband_scan( """ data = np.asarray(data) + if dss_kws.get("segmented", False): + raise ValueError( + "narrowband_scan does not support segmented=True. " + "Run narrowband_scan per segment manually if needed." + ) + nyquist = sfreq / 2 min_freq, max_freq = freq_range diff --git a/mne_denoise/zapline/adaptive.py b/mne_denoise/zapline/adaptive.py index bb32ff9a..15f5bb09 100644 --- a/mne_denoise/zapline/adaptive.py +++ b/mne_denoise/zapline/adaptive.py @@ -45,7 +45,7 @@ from scipy import signal from scipy.signal import find_peaks, welch -from ..dss.utils.covariance import compute_covariance +from ..dss.utils.segmentation import CovarianceSegmenter logger = logging.getLogger(__name__) @@ -283,6 +283,10 @@ def segment_data( Identifies boundaries where the noise characteristics change significantly by tracking changes in the spatial covariance matrix over time. + .. deprecated:: + Use :class:`~mne_denoise.dss.utils.segmentation.CovarianceSegmenter` + directly for new code. This wrapper is kept for backward compatibility. + Parameters ---------- data : ndarray, shape (n_channels, n_times) @@ -300,74 +304,19 @@ def segment_data( ------- segments : list of tuple List of ``(start_sample, end_sample)`` tuples defining segment boundaries. - """ - n_channels, n_times = data.shape - - # 1. Filter around target freq - f_low = target_freq - 3 - f_high = target_freq + 3 - - sos = signal.butter(4, [f_low, f_high], btype="bandpass", fs=sfreq, output="sos") - data_filt = signal.sosfiltfilt(sos, data, axis=1) - - # 2. 
Compute covariance series - n_win = int(cov_win_len * sfreq) - if n_win > n_times: - return [(0, n_times)] - - n_steps = n_times // n_win - - covs = [] - for i in range(n_steps): - start = i * n_win - end = start + n_win - chunk = data_filt[:, start:end] - cov = compute_covariance(chunk) - tr = np.trace(cov) - if tr > 1e-20: - cov = cov / tr - covs.append(cov) - - covs = np.array(covs) - # 3. Successive distances - dists = [] - for i in range(len(covs) - 1): - d = np.linalg.norm(covs[i] - covs[i + 1], ord="fro") - dists.append(d) - - dists = np.array(dists) - - if len(dists) == 0: - return [(0, n_times)] - - # 4. Detect peaks (boundaries) - use distance. - min_distance = max(1, int(min_chunk_len * sfreq / n_win)) - peak_indices, _ = find_peaks( - dists, prominence=np.std(dists) * 0.5, distance=min_distance + See Also + -------- + mne_denoise.dss.utils.segmentation.CovarianceSegmenter : + The shared implementation used internally. + """ + segmenter = CovarianceSegmenter( + sfreq=sfreq, + min_chunk_len=min_chunk_len, + cov_win_len=cov_win_len, + bandpass=(target_freq - 3, target_freq + 3), ) - boundary_indices = (peak_indices + 1) * n_win - - # 5. 
Enforce min length - valid_boundaries = [0] - last_boundary = 0 - min_samples = int(min_chunk_len * sfreq) - - for b in boundary_indices: - if (b - last_boundary) >= min_samples: - valid_boundaries.append(b) - last_boundary = b - - if (n_times - last_boundary) < min_samples and len(valid_boundaries) > 1: - valid_boundaries.pop() - - valid_boundaries.append(n_times) - - segments = [] - for i in range(len(valid_boundaries) - 1): - segments.append((valid_boundaries[i], valid_boundaries[i + 1])) - - return segments + return segmenter.segment(data) def find_fine_peak( diff --git a/mne_denoise/zapline/core.py b/mne_denoise/zapline/core.py index aabb49cc..757714a3 100644 --- a/mne_denoise/zapline/core.py +++ b/mne_denoise/zapline/core.py @@ -45,11 +45,9 @@ from ..dss.denoisers.spectral import LineNoiseBias from ..dss.denoisers.temporal import SmoothingBias from ..dss.linear import DSS +from ..dss.utils.segmentation import CovarianceSegmenter from ..dss.utils.selection import auto_select_components_robust -from ..utils import ( - extract_data_from_mne, - reconstruct_mne_object, -) +from ..utils import extract_data_from_mne, reconstruct_mne_object from .adaptive import ( apply_hybrid_cleanup, check_artifact_presence, @@ -57,7 +55,6 @@ detect_harmonics, find_fine_peak, find_noise_freqs, - segment_data, ) logger = logging.getLogger(__name__) @@ -731,12 +728,12 @@ def _run_adaptive(self, data: np.ndarray) -> dict: # Process each frequency sequentially for target_freq in all_freqs_to_process: - segments = segment_data( - current_data, - self.sfreq, - target_freq=target_freq, + segmenter = CovarianceSegmenter( + sfreq=self.sfreq, min_chunk_len=min_chunk_len, + bandpass=(target_freq - 3, target_freq + 3), ) + segments = segmenter.segment(current_data) # Process each segment cleaned_chunks = [] diff --git a/tests/denoisers/test_periodic.py b/tests/denoisers/test_periodic.py index 14ed0fb2..3ab94fa7 100644 --- a/tests/denoisers/test_periodic.py +++ 
def test_comb_filter_adaptive_q():
    """Check CombFilterBias in proportional-Q mode against fixed-Q mode.

    A three-harmonic sinusoid buried in Gaussian noise is filtered with
    both ``q_mode`` settings (same ``q_factor``). The outputs must keep
    the input shape, differ from one another, and each correlate with
    the noise-free signal above 0.8.
    """
    sfreq = 1000
    f0 = 10
    times = np.arange(2000) / sfreq

    # Fundamental plus 2nd and 3rd harmonics with decreasing amplitude.
    harmonic_gains = ((1, 1.0), (2, 0.5), (3, 0.3))
    signal_clean = sum(
        gain * np.sin(2 * np.pi * order * f0 * times)
        for order, gain in harmonic_gains
    )

    rng = np.random.default_rng(42)
    noise = rng.normal(0, 2, len(times))
    data = (signal_clean + noise)[np.newaxis, :]

    shared_kwargs = dict(
        fundamental_freq=f0, sfreq=sfreq, n_harmonics=3, q_factor=30.0
    )
    biased_fixed = CombFilterBias(q_mode="fixed", **shared_kwargs).apply(data)
    biased_prop = CombFilterBias(
        q_mode="proportional", **shared_kwargs
    ).apply(data)

    # Shape is preserved by both modes.
    for filtered in (biased_fixed, biased_prop):
        assert filtered.shape == data.shape

    # The two modes use different filter shapes, so the outputs differ.
    assert not np.allclose(biased_fixed, biased_prop, atol=1e-10)

    # Both modes should recover the clean waveform reasonably well.
    corr_fixed = np.corrcoef(biased_fixed[0], signal_clean)[0, 1]
    corr_prop = np.corrcoef(biased_prop[0], signal_clean)[0, 1]
    assert corr_fixed > 0.8, f"Fixed Q failed (corr={corr_fixed:.3f})"
    assert corr_prop > 0.8, f"Proportional Q failed (corr={corr_prop:.3f})"


def test_comb_filter_invalid_q_mode():
    """An unknown ``q_mode`` must be rejected at construction time."""
    expected_message = "q_mode must be one of"
    with pytest.raises(ValueError, match=expected_message):
        CombFilterBias(fundamental_freq=10, sfreq=250, q_mode="invalid")
np.random.default_rng(42) diff --git a/tests/test_linear_dss.py b/tests/test_linear_dss.py index 1800075f..189a6933 100644 --- a/tests/test_linear_dss.py +++ b/tests/test_linear_dss.py @@ -8,6 +8,8 @@ from numpy.testing import assert_allclose from mne_denoise.dss import DSS, compute_dss +from mne_denoise.dss.denoisers.spectral import LineNoiseBias +from mne_denoise.dss.utils.segmentation import CovarianceSegmenter, FixedWindowSegmenter # ============================================================================= # compute_dss - Core Algorithm Tests @@ -181,6 +183,14 @@ def test_compute_dss_error_no_variance(): compute_dss(c, c) +def test_compute_dss_error_no_variance(): + """compute_dss should raise error when covariance has no variance.""" + c = np.zeros((5, 5)) + + with pytest.raises(ValueError, match="no significant variance"): + compute_dss(c, c) + + def test_compute_dss_tiny_positive_covariance_is_scale_invariant(): """Tiny SI-unit covariances should not be treated as zero variance.""" cov = np.diag([5.0, 2.0, 1.0, 0.5, 0.25]) @@ -874,26 +884,6 @@ def test_dss_mne_epochs_inverse_transform_with_normalization(): assert reconstructed.shape == (n_epochs, n_channels, n_times) -def test_dss_full_rank_reconstruction_exact_match(): - """DSS with n_components=n_channels should reconstruct data exactly (minus mean).""" - rng = np.random.default_rng(42) - n_channels, n_samples = 5, 500 - data = rng.standard_normal((n_channels, n_samples)) * 1e-6 # uV scale - - # Use no-op bias - dss = DSS(bias=lambda x: x, n_components=n_channels, normalize_input=True) - dss.fit(data) - sources = dss.transform(data) - rec = dss.inverse_transform(sources) - - # Comparison against centered data - data_centered = data - data.mean(axis=1, keepdims=True) - - # Tolerances for floating point arithmetic - # Relative tolerance 1e-7 is reasonable for float64 - assert_allclose(rec, data_centered, rtol=1e-7, atol=1e-25) - - def test_dss_inverse_transform_mne_format_3d(): """inverse_transform 
def _make_nonstationary_line_noise(
    n_channels=16,
    sfreq=250.0,
    duration=120.0,
    freq=50.0,
    snr_first_half=0.8,
    snr_second_half=0.3,
    seed=42,
):
    """Build synthetic data with a sinusoid whose amplitude changes midway.

    A unit-norm random topography carries a ``freq`` Hz sine. The sine
    is scaled by ``snr_first_half`` over the first half of the samples
    and by ``snr_second_half`` over the rest, then added to Gaussian
    background noise (standard deviation 0.5).

    Returns
    -------
    data : ndarray, shape (n_channels, n_times)
        Background plus line noise, with ``n_times = int(sfreq * duration)``.
    sfreq : float
        The sampling frequency that was passed in.
    """
    rng = np.random.default_rng(seed)
    n_times = int(sfreq * duration)
    midpoint = n_times // 2
    sample_idx = np.arange(n_times)

    # Rank-1 spatial pattern of the line noise (drawn first from the RNG).
    topography = rng.standard_normal(n_channels)
    topography = topography / np.linalg.norm(topography)

    line_source = np.sin(2 * np.pi * freq * (sample_idx / sfreq))

    # Background activity (drawn second from the RNG).
    background = rng.standard_normal((n_channels, n_times)) * 0.5

    # Per-sample amplitude: one value before the midpoint, another after.
    amplitude = np.where(sample_idx < midpoint, snr_first_half, snr_second_half)
    line_noise = np.outer(topography, line_source) * amplitude

    return background + line_noise, sfreq
window_len=30.0), + n_components=2, + ) + dss.fit_transform(raw) + assert hasattr(dss, "segment_results_") + assert len(dss.segment_results_) >= 2 + + def test_n_selected_is_max(self): + """n_selected_ should be the max across segments, not the sum.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + n_components=4, + n_select="outlier", + ) + dss.fit_transform(raw) + per_seg_n = [r["n_selected"] for r in dss.segment_results_] + assert dss.n_selected_ == max(per_seg_n) + + def test_reduces_artifact(self): + """Segmented DSS should reduce line noise power.""" + from scipy.signal import welch + + data, sfreq = _make_nonstationary_line_noise(freq=50.0) + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + n_components=4, + n_select=1, + ) + cleaned = dss.fit_transform(raw) + cleaned_data = cleaned.get_data() + + # Compare average 50 Hz power before and after + def avg_power_at_freq(d, freq, sfreq): + f, psd = welch(d, fs=sfreq, nperseg=min(1024, d.shape[1])) + idx = np.argmin(np.abs(f - freq)) + return psd[:, idx].mean() + + pwr_before = avg_power_at_freq(data, 50.0, sfreq) + pwr_after = avg_power_at_freq(cleaned_data, 50.0, sfreq) + assert pwr_after < pwr_before + + def test_covariance_segmenter(self): + """CovarianceSegmenter as segmenter should work.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + 
segmented=True, + segmenter=CovarianceSegmenter(sfreq=sfreq, min_chunk_len=20.0), + n_components=2, + ) + result = dss.fit_transform(raw) + assert result is not None + + def test_epochs_3d(self): + """Segmented DSS should reject 3-D data (Epochs) gracefully or run.""" + rng = np.random.default_rng(42) + n_ch, n_times, sfreq = 8, 250, 250.0 + n_epochs = 10 + data_3d = rng.standard_normal((n_ch, n_times * n_epochs)) + info = mne.create_info(n_ch, sfreq, ch_types="eeg") + raw = mne.io.RawArray(data_3d, info, verbose=False) + events = np.column_stack( + [np.arange(0, n_times * n_epochs, n_times), np.zeros(n_epochs, int), + np.ones(n_epochs, int)] + ) + epochs = mne.Epochs(raw, events, tmin=0, tmax=(n_times - 1) / sfreq, + baseline=None, preload=True, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, segmented=True, n_components=2) + # Epochs should either work (segmentation on concatenated) or raise + try: + dss.fit_transform(epochs) + except (ValueError, RuntimeError): + pass # Acceptable: segmented mode may not support Epochs + + +class TestAutoSelect: + """Tests for automatic component selection (n_select='outlier', etc.).""" + + def test_outlier(self): + """n_select='outlier' should set n_selected_ >= 0.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select="outlier") + dss.fit(raw) + assert dss.n_selected_ is not None + assert dss.n_selected_ >= 0 + + def test_ratio(self): + """n_select='ratio' should set n_selected_ >= 0.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select="ratio") + dss.fit(raw) + assert dss.n_selected_ is not None + 
assert dss.n_selected_ >= 0 + + def test_max_gap(self): + """n_select='max_gap' should set n_selected_ >= 0.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select="max_gap") + dss.fit(raw) + assert dss.n_selected_ is not None + assert dss.n_selected_ >= 0 + + def test_combined(self): + """n_select='combined' should set n_selected_ >= 0.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select="combined") + dss.fit(raw) + assert dss.n_selected_ is not None + assert dss.n_selected_ >= 0 + + def test_int_passthrough(self): + """n_select=int should directly set n_selected_.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select=2) + dss.fit(raw) + assert dss.n_selected_ == 2 + + def test_invalid_method(self): + """Invalid selection_method string should raise ValueError.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select="outlier", + selection_method="nonexistent_method") + with pytest.raises((ValueError, KeyError)): + dss.fit(raw) + + def test_manual_override(self): + """n_select=None should leave n_selected_=None.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) 
+ bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=4, n_select=None) + dss.fit(raw) + assert dss.n_selected_ is None + + +class TestSmoothingDecomposition: + """Tests for smooth parameter (smoothing decomposition).""" + + def test_smooth_int_fit_transform(self): + """smooth=int should run fit_transform without error.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=2, smooth=5) + result = dss.fit_transform(raw) + assert result is not None + + def test_fit_then_transform_preserves_smooth(self): + """fit() then transform() should NOT lose the smooth component. + + Regression test: previously transform() discarded data_smooth. + """ + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS(bias, n_components=2, smooth=5, return_type="raw") + dss.fit(raw) + result = dss.transform(raw) + result_data = result.get_data() + # The result should have similar scale to the input (because smooth is + # added back). If smooth were lost, the result would be much smaller. 
+ ratio = np.std(result_data) / np.std(data) + assert ratio > 0.3, f"Smooth likely lost: ratio={ratio:.3f}" + + def test_smooth_segmented_cleans_artifact(self): + """Smooth + segmented should still reduce artifact power.""" + from scipy.signal import welch + + data, sfreq = _make_nonstationary_line_noise(freq=50.0) + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + n_components=4, + n_select=1, + smooth=5, + ) + cleaned = dss.fit_transform(raw) + cleaned_data = cleaned.get_data() + + # Compare 50 Hz power + def avg_power_at_freq(d, freq, sfreq): + f, psd = welch(d, fs=sfreq, nperseg=min(1024, d.shape[1])) + idx = np.argmin(np.abs(f - freq)) + return psd[:, idx].mean() + + pwr_before = avg_power_at_freq(data, 50.0, sfreq) + pwr_after = avg_power_at_freq(cleaned_data, 50.0, sfreq) + assert pwr_after < pwr_before + + +class TestCapAndFloor: + """Tests for max_prop_remove and min_select.""" + + def test_max_prop_remove_caps(self): + """max_prop_remove=0.1 on 32ch should cap at 3 components.""" + rng = np.random.default_rng(42) + n_ch = 32 + data = rng.standard_normal((n_ch, int(250 * 120))) + sfreq = 250.0 + # Inject strong line noise to get many components selected + t = np.arange(data.shape[1]) / sfreq + topo = rng.standard_normal(n_ch) + for h in range(1, 6): # 5 harmonics + data += np.outer(topo * h, np.sin(2 * np.pi * 50 * h * t)) + info = mne.create_info(n_ch, sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + n_components=10, + n_select="outlier", + max_prop_remove=0.1, + ) + dss.fit_transform(raw) + max_cap = int(n_ch * 0.1) + for r in dss.segment_results_: + assert 
r["n_selected"] <= max_cap + + def test_min_select_floor(self): + """min_select=2 should ensure at least 2 components removed.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + n_components=4, + n_select="outlier", + min_select=2, + ) + dss.fit_transform(raw) + for r in dss.segment_results_: + assert r["n_selected"] >= 2 + + +class TestCrossfade: + """Tests for cross-fade overlap-add blending at segment boundaries.""" + + def test_crossfade_output_shape(self): + """Output shape must match input shape regardless of crossfade.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + crossfade=1.0, + n_components=4, + n_select="auto", + ) + cleaned = dss.fit_transform(raw) + assert cleaned.get_data().shape == data.shape + + def test_crossfade_no_boundary_jump(self): + """Derivative at segment boundaries should not spike.""" + data, sfreq = _make_nonstationary_line_noise(duration=120.0) + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + + # Without cross-fade + dss_hard = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + crossfade=0.0, + n_components=4, + n_select="auto", + ) + cleaned_hard = dss_hard.fit_transform(raw).get_data() + + # With cross-fade + dss_xfade = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + crossfade=1.0, + n_components=4, 
+ n_select="auto", + ) + cleaned_xfade = dss_xfade.fit_transform(raw).get_data() + + # Check derivative at boundary (30s = 7500 samples) + boundary = int(30 * sfreq) + win = 10 # samples around boundary + far = 500 # samples away for reference + + diff_hard = np.abs(np.diff(cleaned_hard, axis=1)) + diff_xfade = np.abs(np.diff(cleaned_xfade, axis=1)) + + # Max derivative at boundary vs. reference region + bnd_hard = diff_hard[:, boundary - win : boundary + win].max() + ref_hard = np.median(diff_hard[:, boundary - far : boundary - 2 * win]) + bnd_xfade = diff_xfade[:, boundary - win : boundary + win].max() + ref_xfade = np.median(diff_xfade[:, boundary - far : boundary - 2 * win]) + + ratio_hard = bnd_hard / (ref_hard + 1e-12) + ratio_xfade = bnd_xfade / (ref_xfade + 1e-12) + + # Cross-fade boundary ratio should be no worse than hard boundary + # (and often much better) + assert ratio_xfade <= ratio_hard + 1.0 or ratio_xfade < 10.0 + + def test_crossfade_single_segment_matches_hard(self): + """With only one segment, crossfade has no effect.""" + # Short data → single segment (< min_chunk_len * 2) + rng = np.random.default_rng(99) + sfreq = 250.0 + data = rng.standard_normal((8, int(40 * sfreq))) + t = np.arange(data.shape[1]) / sfreq + data += 0.5 * np.sin(2 * np.pi * 50 * t)[np.newaxis, :] + + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + seg = FixedWindowSegmenter(sfreq=sfreq, window_len=60.0) + + dss_hard = DSS( + bias, segmented=True, segmenter=seg, + crossfade=0.0, n_components=4, n_select="auto", + ) + dss_xfade = DSS( + bias, segmented=True, segmenter=seg, + crossfade=1.0, n_components=4, n_select="auto", + ) + + out_hard = dss_hard.fit_transform(data) + out_xfade = dss_xfade.fit_transform(data) + np.testing.assert_array_equal(out_hard, out_xfade) + + def test_crossfade_zero_backward_compat(self): + """crossfade=0 must produce same result as before (hard concat).""" + data, sfreq = _make_nonstationary_line_noise() + bias = LineNoiseBias(freq=50.0, 
sfreq=sfreq) + seg = FixedWindowSegmenter(sfreq=sfreq, window_len=30.0) + + # Default crossfade (0.0) + dss = DSS( + bias, segmented=True, segmenter=seg, + n_components=4, n_select="auto", + ) + out_default = dss.fit_transform(data) + + # Explicit crossfade=0.0 + dss0 = DSS( + bias, segmented=True, segmenter=seg, + crossfade=0.0, n_components=4, n_select="auto", + ) + out_zero = dss0.fit_transform(data) + np.testing.assert_array_equal(out_default, out_zero) + + def test_crossfade_preserves_energy(self): + """Total signal power should not change dramatically with crossfade.""" + data, sfreq = _make_nonstationary_line_noise() + info = mne.create_info(data.shape[0], sfreq, ch_types="eeg") + raw = mne.io.RawArray(data, info, verbose=False) + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=30.0), + crossfade=1.0, + n_components=4, + n_select="auto", + ) + cleaned = dss.fit_transform(raw).get_data() + + # Power ratio should be between 0.3 and 1.5 + # (cleaning removes artifact, but broadband should be preserved) + power_in = np.mean(data ** 2) + power_out = np.mean(cleaned ** 2) + ratio = power_out / power_in + assert 0.3 < ratio < 1.5, f"Power ratio {ratio:.2f} out of range" + + def test_crossfade_overlap_clamped(self): + """When crossfade is longer than half a segment, it gets clamped.""" + rng = np.random.default_rng(42) + sfreq = 250.0 + # 20s data with 10s segments → crossfade=8s should get clamped + data = rng.standard_normal((8, int(20 * sfreq))) + t = np.arange(data.shape[1]) / sfreq + data += 0.3 * np.sin(2 * np.pi * 50 * t)[np.newaxis, :] + + bias = LineNoiseBias(freq=50.0, sfreq=sfreq) + dss = DSS( + bias, + segmented=True, + segmenter=FixedWindowSegmenter(sfreq=sfreq, window_len=10.0), + crossfade=8.0, # way too long + n_components=4, + n_select="auto", + ) + # Should not crash — overlap gets clamped internally + cleaned = dss.fit_transform(data) + assert 
cleaned.shape == data.shape + + +def test_narrowband_scan_rejects_segmented(): + """narrowband_scan should raise ValueError with segmented=True.""" + from mne_denoise.dss.variants.narrowband import narrowband_scan - assert_allclose(rms_orig, rms_rec, rtol=0.05) + rng = np.random.default_rng(42) + data = rng.standard_normal((8, 5000)) + with pytest.raises(ValueError, match="segmented"): + narrowband_scan(data, sfreq=250.0, segmented=True) def test_dss_get_normalized_patterns(): diff --git a/tests/test_nonlinear_dss.py b/tests/test_nonlinear_dss.py index dd76e64b..d882bf9c 100644 --- a/tests/test_nonlinear_dss.py +++ b/tests/test_nonlinear_dss.py @@ -2,8 +2,6 @@ from __future__ import annotations -from unittest.mock import patch - import mne import numpy as np import pytest @@ -380,62 +378,6 @@ def test_iterative_dss_class_inverse_transform(): assert reconstructed.shape == (8, 500) -def test_iterative_dss_class_inverse_transform_3d(): - """IterativeDSS inverse_transform should handle 3D data.""" - rng = np.random.default_rng(42) - n_epochs, n_ch, n_times = 5, 8, 100 - # MNE Standard: (n_epochs, n_channels, n_times) - data = rng.standard_normal((n_epochs, n_ch, n_times)) - - denoiser = KurtosisDenoiser() - it_dss = IterativeDSS(denoiser, n_components=3, max_iter=5) - - sources = it_dss.fit_transform(data) - reconstructed = it_dss.inverse_transform(sources) - - assert reconstructed.shape == (n_epochs, n_ch, n_times) - - -def test_iterative_dss_class_inverse_transform_normalized(): - """IterativeDSS inverse_transform should handle normalization correctly (2D and 3D).""" - rng = np.random.default_rng(42) - n_epochs, n_ch, n_times = 5, 4, 100 - - # 2D data with different channel scales - scales = np.array([1.0, 0.1, 0.01, 1e-3]) - data_2d = rng.standard_normal((n_ch, n_times * n_epochs)) * scales[:, np.newaxis] - data_3d = data_2d.reshape(n_ch, n_epochs, n_times).transpose(1, 0, 2) - - # Use identity denoiser to ensure perfect reconstruction (within iterative precision) - 
def identity_denoiser(data): - return data - - # Center data for exact comparison as inverse_transform reconstructs centered data - data_2d_centered = data_2d - data_2d.mean(axis=1, keepdims=True) - data_3d_centered = data_3d - data_3d.mean(axis=(0, 2), keepdims=True) - - # Test 2D - # Set reg to 1e-15 to ensure we don't truncate any components - it_dss_2d = IterativeDSS( - identity_denoiser, n_components=n_ch, normalize_input=True, reg=1e-15 - ) - sources_2d = it_dss_2d.fit_transform(data_2d) - reconstructed_2d = it_dss_2d.inverse_transform(sources_2d) - - # Reconstructed should match original centered data (full rank) - assert_allclose(data_2d_centered, reconstructed_2d, rtol=1e-3, atol=1e-12) - - # Test 3D - it_dss_3d = IterativeDSS( - identity_denoiser, n_components=n_ch, normalize_input=True, reg=1e-15 - ) - sources_3d = it_dss_3d.fit_transform(data_3d) - reconstructed_3d = it_dss_3d.inverse_transform(sources_3d) - - # Needs to match 3D centered data - assert_allclose(data_3d_centered, reconstructed_3d, rtol=1e-3, atol=1e-12) - - def test_iterative_dss_class_transform_before_fit(): """IterativeDSS should raise error when transform called before fit.""" denoiser = KurtosisDenoiser() @@ -798,7 +740,7 @@ def denoiser(x): return np.tanh(x) # Class - dss = IterativeDSS(denoiser, n_components=3, random_state=42, normalize_input=False) + dss = IterativeDSS(denoiser, n_components=3, random_state=42) dss.fit(data) res_class = dss.transform(data) @@ -809,160 +751,3 @@ def denoiser(x): # Should be identical (using same seed) assert_allclose(res_class, res_func) - - -def test_iterative_dss_preserves_scale(): - """IterativeDSS reconstruction should preserve physical signal scale.""" - sfreq = 1000 - n_channels = 10 - n_times = 2000 - t = np.arange(n_times) / sfreq - - signal_scale = 5e-6 - data = np.random.randn(n_channels, n_times) * 1e-7 - data[0:3, :] += signal_scale * np.sin(2 * np.pi * 10 * t) - - from mne_denoise.dss.denoisers import TanhMaskDenoiser - - idss = 
IterativeDSS( - denoiser=TanhMaskDenoiser(), n_components=n_channels, random_state=42 - ) - # Using inverse_transform directly here to verify patterns * sources - reconstructed = idss.fit(data).inverse_transform(idss.transform(data)) - - rms_orig = np.sqrt(np.mean(data**2)) - rms_rec = np.sqrt(np.mean(reconstructed**2)) - assert_allclose(rms_orig, rms_rec, rtol=0.05) - - -def test_iterative_dss_get_normalized_patterns(): - """Test the newly added get_normalized_patterns method in IterativeDSS.""" - from mne_denoise.dss.denoisers import TanhMaskDenoiser - - data = np.random.randn(10, 1000) - idss = IterativeDSS(denoiser=TanhMaskDenoiser(), n_components=2) - idss.fit(data) - norm_patterns = idss.get_normalized_patterns() - assert norm_patterns.shape == (10, 2) - assert_allclose(np.linalg.norm(norm_patterns, axis=0), 1.0) - - -def test_iterative_dss_full_rank_reconstruction_exact_match(): - """IterativeDSS with n_components=n_channels should reconstruct data exactly.""" - rng = np.random.default_rng(42) - n_channels, n_samples = 4, 1000 # Small enough for quick convergence - data = rng.standard_normal((n_channels, n_samples)) * 1e-6 # uV scale - - # Use Tanh denoiser - from mne_denoise.dss.denoisers import TanhMaskDenoiser - - # We need tight convergence for exact reconstruction check - dss = IterativeDSS( - denoiser=TanhMaskDenoiser(), - n_components=n_channels, - normalize_input=True, - max_iter=1000, - tol=1e-12, - random_state=42, - ) - dss.fit(data) - sources = dss.transform(data) - rec = dss.inverse_transform(sources) - - # Comparison against centered data - data_centered = data - data.mean(axis=1, keepdims=True) - - assert_allclose(rec, data_centered, rtol=1e-7, atol=1e-25) - - -def test_iterative_dss_mne_normalization(): - """IterativeDSS normalization should work with MNE objects.""" - rng = np.random.default_rng(42) - n_channels, n_samples = 4, 1000 - sfreq = 250.0 - - # Create data with different scales - data = rng.standard_normal((n_channels, n_samples)) 
- data[0] *= 1e-6 # Simulate gradiometer scale - data[1] *= 1e-12 # Simulate magnetometer scale - - info = mne.create_info( - ch_names=[f"EEG{i:03d}" for i in range(n_channels)], sfreq=sfreq, ch_types="eeg" - ) - raw = mne.io.RawArray(data, info, verbose=False) - - # IterativeDSS defaults to normalize_input=True - from mne_denoise.dss.denoisers import TanhMaskDenoiser - - dss = IterativeDSS(denoiser=TanhMaskDenoiser(), n_components=3) - sources = dss.fit_transform(raw) - - assert sources.shape == (3, n_samples) - assert dss.channel_norms_ is not None - assert dss.channel_norms_.shape == (n_channels,) - # Norms should reflect the scales - assert dss.channel_norms_[0] > dss.channel_norms_[1] - - -def test_iterative_dss_one_degenerate_signal(): - """iterative_dss_one should handle signal killing (norm < 1e-12).""" - rng = np.random.default_rng(42) - n_ch, n_times = 3, 100 - X = rng.standard_normal((n_ch, n_times)) - - # Stateful denoiser that kills the signal once then works - class FlakyDenoiser: - def __init__(self): - self.killed = False - - def __call__(self, data): - if not self.killed: - self.killed = True - return np.zeros_like(data) - return data # Identity map otherwise - - denoiser = FlakyDenoiser() - w_init = np.array([1.0, 0.0, 0.0]) - - w, source, n_iter, converged = iterative_dss_one( - X, denoiser, w_init=w_init, max_iter=10, random_state=rng - ) - - # It should have reinitialized w (randomly) and then converged - assert denoiser.killed - assert not np.allclose(w, w_init) # Should have changed - - -def test_iterative_dss_degenerate_orthogonalization(): - """iterative_dss should handle degenerate components during orthogonalization.""" - rng = np.random.default_rng(42) - n_samples = 100 - # Create rank-deficient data where components are identical - v = rng.standard_normal(n_samples) - X = np.vstack([v, v, v]) # Rank 1 data - - # Use a simple identity denoiser - def identity_denoiser(data): - return data - - # Mock whitening step to ensure it returns 2 
components despite rank 1 data - X_white_mock = np.zeros((2, n_samples)) - X_white_mock[0] = v - - # Initialize BOTH components to the SAME vector to force collapse after orthogonalization - w_init_force = np.array([[1.0, 0.0], [1.0, 0.0]]) - - with patch("mne_denoise.dss.nonlinear.whiten_data") as mock_whiten: - mock_whiten.return_value = ( - X_white_mock, - np.eye(2, 3), # Fake whitener - np.eye(3, 2), # Fake dewhitener - ) - - filters, _, _, _ = iterative_dss( - X, identity_denoiser, n_components=2, w_init=w_init_force, random_state=rng - ) - - # Should stay at 2 components and re-initialize the degenerate one - assert filters.shape == (2, 3) - assert not np.allclose(filters[1], 0) diff --git a/tests/test_zapline_adaptive.py b/tests/test_zapline_adaptive.py index daba1a51..c4906273 100644 --- a/tests/test_zapline_adaptive.py +++ b/tests/test_zapline_adaptive.py @@ -6,6 +6,7 @@ from scipy import signal from mne_denoise.zapline import ZapLine +from mne_denoise.dss.utils.segmentation import CovarianceSegmenter from mne_denoise.zapline.adaptive import ( apply_cleanline_notch, apply_hybrid_cleanup, @@ -677,3 +678,73 @@ def test_hybrid_cleanup_protection(): cleaned = apply_hybrid_cleanup(data, sfreq=1000, freq=50.0) assert np.array_equal(cleaned, data) + + +# ========================================================================= +# CovarianceSegmenter / segment_data unification tests +# ========================================================================= + + +def test_segment_data_wraps_covariance_segmenter(): + """segment_data wrapper produces identical results to CovarianceSegmenter.""" + rng = np.random.default_rng(42) + sfreq = 250 + n_times = int(120 * sfreq) # 2 minutes + data = rng.standard_normal((4, n_times)) + target_freq = 50.0 + + # Via the wrapper + wrapper_result = segment_data( + data, sfreq, target_freq=target_freq, min_chunk_len=30.0, cov_win_len=1.0 + ) + + # Via CovarianceSegmenter directly + seg = CovarianceSegmenter( + sfreq=sfreq, + 
min_chunk_len=30.0, + cov_win_len=1.0, + bandpass=(target_freq - 3, target_freq + 3), + ) + direct_result = seg.segment(data) + + assert wrapper_result == direct_result + + +def test_covariance_segmenter_in_zapline_adaptive(): + """ZapLine adaptive mode uses CovarianceSegmenter internally.""" + rng = np.random.default_rng(42) + sfreq = 250 + n_times = int(60 * sfreq) + times = np.arange(n_times) / sfreq + n_ch = 4 + + data = rng.standard_normal((n_ch, n_times)) * 0.1 + data += np.sin(2 * np.pi * 50.0 * times) * 5.0 + + zl = ZapLine( + sfreq=sfreq, + line_freq=50.0, + adaptive=True, + adaptive_params={"min_chunk_len": 10.0}, + ) + + # Should run without error — internally uses CovarianceSegmenter + cleaned = zl.fit_transform(data) + assert cleaned.shape == data.shape + + +def test_segment_data_deprecated_wrapper_still_works(): + """segment_data backward-compatible wrapper still returns valid segments.""" + rng = np.random.default_rng(42) + data = rng.standard_normal((4, 5000)) + sfreq = 250 + + segments = segment_data(data, sfreq, target_freq=50.0, min_chunk_len=5.0) + + # Basic validity checks + assert len(segments) >= 1 + assert segments[0][0] == 0 + assert segments[-1][1] == 5000 + # Segments should be contiguous + for i in range(len(segments) - 1): + assert segments[i][1] == segments[i + 1][0] diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/test_segmentation.py b/tests/utils/test_segmentation.py new file mode 100644 index 00000000..c2752b85 --- /dev/null +++ b/tests/utils/test_segmentation.py @@ -0,0 +1,157 @@ +"""Unit tests for segmentation utilities.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from mne_denoise.dss.utils.segmentation import ( + CovarianceSegmenter, + FixedWindowSegmenter, +) + + +# ============================================================================ +# FixedWindowSegmenter +# 
============================================================================ + + +class TestFixedWindowSegmenter: + """Tests for FixedWindowSegmenter.""" + + def test_basic_segmentation(self): + """Data spanning a whole number of windows should split into equal windows.""" + sfreq = 250.0 + segmenter = FixedWindowSegmenter(sfreq=sfreq, window_len=10.0) + data = np.random.default_rng(0).standard_normal((4, int(30 * sfreq))) + segments = segmenter.segment(data) + assert len(segments) == 3 + for s, e in segments: + assert e - s == int(10 * sfreq) + + def test_short_data_single_segment(self): + """Data shorter than window_len should produce one segment.""" + sfreq = 250.0 + segmenter = FixedWindowSegmenter(sfreq=sfreq, window_len=10.0) + data = np.random.default_rng(0).standard_normal((4, int(5 * sfreq))) + segments = segmenter.segment(data) + assert len(segments) == 1 + assert segments[0] == (0, data.shape[1]) + + def test_trailing_merge(self): + """Tiny trailing chunk (< 50% window) should be merged with previous.""" + sfreq = 250.0 + segmenter = FixedWindowSegmenter(sfreq=sfreq, window_len=10.0) + # 25.1 seconds → 2 windows of 10s, trailing 5.1s > 50% → 3 segments + # But 21 seconds → 2 windows of 10s, trailing 1s < 50% → merge + n_times = int(21 * sfreq) + data = np.random.default_rng(0).standard_normal((4, n_times)) + segments = segmenter.segment(data) + # Trailing 1s < 50% of 10s window → merged + assert len(segments) == 2 + assert segments[-1][1] == n_times + + def test_no_trailing_merge_when_large(self): + """Trailing chunk >= 50% of window should NOT be merged.""" + sfreq = 250.0 + segmenter = FixedWindowSegmenter(sfreq=sfreq, window_len=10.0) + # 26 seconds → 2 windows of 10s + trailing 6s (>= 50% of 10s) + n_times = int(26 * sfreq) + data = np.random.default_rng(0).standard_normal((4, n_times)) + segments = segmenter.segment(data) + assert len(segments) == 3 + assert segments[-1][1] == n_times + + def test_covers_all_samples(self): + """Segments should cover all samples 
without gaps or overlaps.""" + sfreq = 250.0 + segmenter = FixedWindowSegmenter(sfreq=sfreq, window_len=10.0) + n_times = int(35 * sfreq) + data = np.random.default_rng(0).standard_normal((4, n_times)) + segments = segmenter.segment(data) + assert segments[0][0] == 0 + assert segments[-1][1] == n_times + for i in range(len(segments) - 1): + assert segments[i][1] == segments[i + 1][0] + + +# ============================================================================ +# CovarianceSegmenter +# ============================================================================ + + +class TestCovarianceSegmenter: + """Tests for CovarianceSegmenter.""" + + def test_stationary_data_single_segment(self): + """Stationary data should yield at least one segment spanning all samples.""" + sfreq = 250.0 + rng = np.random.default_rng(42) + data = rng.standard_normal((8, int(60 * sfreq))) + segmenter = CovarianceSegmenter(sfreq=sfreq, min_chunk_len=10.0) + segments = segmenter.segment(data) + # Stationary noise → likely one segment, but only >= 1 is asserted + assert len(segments) >= 1 + assert segments[0][0] == 0 + assert segments[-1][1] == data.shape[1] + + def test_nonstationary_data_detects_boundary(self): + """Data with a clear stationarity break should produce >=2 segments.""" + sfreq = 250.0 + rng = np.random.default_rng(42) + n_ch = 8 + half = int(30 * sfreq) + # First half: small noise. 
Second half: noise * 20 + part1 = rng.standard_normal((n_ch, half)) * 0.1 + part2 = rng.standard_normal((n_ch, half)) * 20.0 + data = np.concatenate([part1, part2], axis=1) + segmenter = CovarianceSegmenter(sfreq=sfreq, min_chunk_len=5.0) + segments = segmenter.segment(data) + assert len(segments) >= 2 + + def test_min_chunk_length_enforced(self): + """No segment should fall below 80% of min_chunk_len (in samples).""" + sfreq = 250.0 + rng = np.random.default_rng(42) + n_ch = 8 + data = rng.standard_normal((n_ch, int(120 * sfreq))) + min_chunk = 10.0 + segmenter = CovarianceSegmenter(sfreq=sfreq, min_chunk_len=min_chunk) + segments = segmenter.segment(data) + min_samples = int(min_chunk * sfreq) + for s, e in segments: + # 20% slack is applied to every segment, not only the last one + assert (e - s) >= min_samples * 0.8 + + def test_covers_all_samples(self): + """Segments should cover all samples without gaps or overlaps.""" + sfreq = 250.0 + rng = np.random.default_rng(42) + data = rng.standard_normal((4, int(60 * sfreq))) + segmenter = CovarianceSegmenter(sfreq=sfreq, min_chunk_len=10.0) + segments = segmenter.segment(data) + assert segments[0][0] == 0 + assert segments[-1][1] == data.shape[1] + for i in range(len(segments) - 1): + assert segments[i][1] == segments[i + 1][0] + + def test_short_data_single_segment(self): + """Data shorter than min_chunk_len should produce one segment.""" + sfreq = 250.0 + rng = np.random.default_rng(42) + data = rng.standard_normal((4, int(5 * sfreq))) + segmenter = CovarianceSegmenter(sfreq=sfreq, min_chunk_len=30.0) + segments = segmenter.segment(data) + assert len(segments) == 1 + assert segments[0] == (0, data.shape[1]) + + def test_bandpass_parameter(self): + """Bandpass parameter should be accepted without error.""" + sfreq = 250.0 + rng = np.random.default_rng(42) + data = rng.standard_normal((4, int(60 * sfreq))) + segmenter = CovarianceSegmenter( + sfreq=sfreq, min_chunk_len=10.0, bandpass=(8.0, 12.0) + ) + segments = 
segmenter.segment(data) + assert len(segments) >= 1 diff --git a/tests/utils/test_selection.py b/tests/utils/test_selection.py index 1196ca07..f4d61414 100644 --- a/tests/utils/test_selection.py +++ b/tests/utils/test_selection.py @@ -1,186 +1,323 @@ -"""Unit tests for component-selection helpers. - -Covers :func:`iterative_outlier_removal`, :func:`detect_eigenvalue_knee`, -and :func:`auto_select_components_robust`. Regression cases include the -high-channel-count MEG eigenvalue pattern reported in Issue #34. -""" - -from __future__ import annotations - -import numpy as np - -from mne_denoise.dss.utils.selection import ( - auto_select_components, - auto_select_components_robust, - detect_eigenvalue_knee, - iterative_outlier_removal, -) - -# Eigenvalue spectrum from the user-reported CTF MEG case (Issue #34): -# 7 strong components corresponding to coherent line noise, followed by -# 8 near-zero components in the noise tail. -ISSUE_34_EIGENVALUES = np.array( - [ - 9.88999359e-01, - 9.68951301e-01, - 7.14728232e-01, - 6.76753765e-01, - 5.83699080e-01, - 4.22202798e-01, - 1.63730711e-01, - 5.00206326e-03, - 1.18624482e-03, - 1.48571576e-04, - 8.64564508e-05, - 7.59793074e-05, - 4.66650479e-05, - 3.74553067e-05, - 3.18800906e-07, - ] -) - - -# ----------------------------------------------------------------------------- -# detect_eigenvalue_knee -# ----------------------------------------------------------------------------- - - -def test_detect_knee_user_meg_case(): - """User's Issue #34 eigenvalues: 7 strong + 8 near-zero -> knee at 7.""" - assert detect_eigenvalue_knee(ISSUE_34_EIGENVALUES) == 7 - - -def test_detect_knee_single_dominant(): - """One dominant component followed by a noise floor.""" - evs = np.array([0.95, 0.05, 0.04, 0.03, 0.02, 0.01]) - assert detect_eigenvalue_knee(evs) == 1 - - -def test_detect_knee_two_dominant(): - """Two strong components, then a clear gap.""" - evs = np.array([0.9, 0.8, 0.1, 0.08, 0.05, 0.02]) - assert 
detect_eigenvalue_knee(evs) == 2 - - -def test_detect_knee_clean_monotonic_decay(): - """Smoothly-decaying spectrum with no clear gap returns 0.""" - evs = np.array([0.5, 0.47, 0.44, 0.40, 0.37, 0.33, 0.30]) - assert detect_eigenvalue_knee(evs) == 0 - - -def test_detect_knee_empty(): - """Empty array returns 0.""" - assert detect_eigenvalue_knee(np.array([])) == 0 - - -def test_detect_knee_single_value(): - """Single eigenvalue returns 1 (degenerate case).""" - assert detect_eigenvalue_knee(np.array([0.5])) == 1 - - -def test_detect_knee_all_zero(): - """All-zero eigenvalues return 0.""" - assert detect_eigenvalue_knee(np.zeros(5)) == 0 - - -def test_detect_knee_respects_min_ratio(): - """Knee gates on min_ratio: shallow drops are rejected.""" - # 2x drop between 0.9 and 0.45 -- below the default min_ratio=3 - evs = np.array([0.9, 0.45, 0.40, 0.35, 0.30]) - assert detect_eigenvalue_knee(evs, min_ratio=3.0) == 0 - # Same data but lower min_ratio accepts the drop - assert detect_eigenvalue_knee(evs, min_ratio=1.5) == 1 - - -def test_detect_knee_rel_floor_excludes_tail(): - """Anchors below rel_floor * max are excluded from knee selection. - - Without the rel_floor mask, the largest gap might be at the tail - (e.g., between 1e-5 and 1e-12); the floor ensures we anchor on a - meaningful eigenvalue. - """ - evs = np.array([0.9, 0.8, 0.7, 1e-5, 1e-10]) - # The largest log-drop is between 0.7 and 1e-5 (5+ decades), - # which IS what we want here -> returns 3. - assert detect_eigenvalue_knee(evs) == 3 - - -def test_detect_knee_rel_floor_excludes_all(): - """``rel_floor`` greater than 1.0 leaves no valid anchors and returns 0. - - Degenerate guard for the ``not np.any(valid)`` branch. 
- """ - evs = np.array([0.9, 0.8, 0.1]) - assert detect_eigenvalue_knee(evs, rel_floor=2.0) == 0 - - -# ----------------------------------------------------------------------------- -# auto_select_components_robust -# ----------------------------------------------------------------------------- - - -def test_robust_user_meg_case(): - """Issue #34 case: outlier returns 0, knee returns 7, robust returns 7.""" - n_outlier = iterative_outlier_removal(ISSUE_34_EIGENVALUES, sigma=3.0) - n_knee = detect_eigenvalue_knee(ISSUE_34_EIGENVALUES) - n_robust = auto_select_components_robust(ISSUE_34_EIGENVALUES) - assert n_outlier == 0 - assert n_knee == 7 - assert n_robust == 7 - - -def test_robust_combines_via_max(): - """When outlier and knee disagree, the larger count wins.""" - # Construct a case where outlier returns >= 1 (one extreme outlier) - # and knee may return 1 as well -> max == 1. - evs = np.array([100.0, 0.1, 0.09, 0.08, 0.07, 0.06]) - n_outlier = iterative_outlier_removal(evs, sigma=3.0) - n_knee = detect_eigenvalue_knee(evs) - assert auto_select_components_robust(evs) == max(n_outlier, n_knee) - - -def test_robust_clean_returns_zero(): - """Smoothly-decaying spectrum: both paths return 0.""" - evs = np.array([0.5, 0.47, 0.44, 0.40, 0.37, 0.33, 0.30]) - assert auto_select_components_robust(evs) == 0 - - -def test_robust_forwards_kwargs(): - """``sigma``, ``knee_rel_floor``, ``knee_min_ratio`` reach the underlying calls.""" - evs = np.array([0.9, 0.45, 0.40, 0.35, 0.30]) - # Strict knee gate rejects this drop - assert auto_select_components_robust(evs, knee_min_ratio=3.0) == 0 - # Relaxed knee gate accepts it - assert auto_select_components_robust(evs, knee_min_ratio=1.5) >= 1 - - -# ----------------------------------------------------------------------------- -# Backwards-compatibility (existing functions untouched) -# ----------------------------------------------------------------------------- - - -def test_iterative_outlier_removal_unchanged(): - """Iterative 
removal still picks up extreme outliers (regression guard). - - The original algorithm is conservative; this case has one massive outlier - that survives the iterative step. - """ - scores = np.array([1000.0, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]) - assert iterative_outlier_removal(scores, sigma=3.0) >= 1 - - -def test_iterative_outlier_removal_user_case_is_zero(): - """The user MEG eigenvalues legitimately return 0 from the outlier path. - - This locks in the diagnosis: the bug is not in iterative_outlier_removal, - it's in *relying solely* on it. - """ - assert iterative_outlier_removal(ISSUE_34_EIGENVALUES, sigma=3.0) == 0 - - -def test_auto_select_components_alias_unchanged(): - """``auto_select_components`` remains a thin wrapper over outlier removal.""" - scores = np.array([10.0, 0.5, 0.4, 0.3, 0.2]) - assert auto_select_components(scores, threshold=3.0) == iterative_outlier_removal( - scores, sigma=3.0 - ) +"""Unit tests for component-selection helpers. + +Covers :func:`iterative_outlier_removal`, :func:`detect_eigenvalue_knee`, +:func:`auto_select_components_robust`, :func:`eigenvalue_ratio_selection`, and +:func:`max_gap_selection`. Regression cases include the high-channel-count MEG +eigenvalue pattern reported in Issue #34. +""" + +from __future__ import annotations + +import numpy as np + +from mne_denoise.dss.utils.selection import ( + auto_select_components, + auto_select_components_robust, + detect_eigenvalue_knee, + eigenvalue_ratio_selection, + iterative_outlier_removal, + max_gap_selection, +) + +# Eigenvalue spectrum from the user-reported CTF MEG case (Issue #34): +# 7 strong components corresponding to coherent line noise, followed by +# 8 near-zero components in the noise tail. 
+ISSUE_34_EIGENVALUES = np.array( + [ + 9.88999359e-01, + 9.68951301e-01, + 7.14728232e-01, + 6.76753765e-01, + 5.83699080e-01, + 4.22202798e-01, + 1.63730711e-01, + 5.00206326e-03, + 1.18624482e-03, + 1.48571576e-04, + 8.64564508e-05, + 7.59793074e-05, + 4.66650479e-05, + 3.74553067e-05, + 3.18800906e-07, + ] +) + + +# ----------------------------------------------------------------------------- +# detect_eigenvalue_knee +# ----------------------------------------------------------------------------- + + +def test_detect_knee_user_meg_case(): + """User's Issue #34 eigenvalues: 7 strong + 8 near-zero -> knee at 7.""" + assert detect_eigenvalue_knee(ISSUE_34_EIGENVALUES) == 7 + + +def test_detect_knee_single_dominant(): + """One dominant component followed by a noise floor.""" + evs = np.array([0.95, 0.05, 0.04, 0.03, 0.02, 0.01]) + assert detect_eigenvalue_knee(evs) == 1 + + +def test_detect_knee_two_dominant(): + """Two strong components, then a clear gap.""" + evs = np.array([0.9, 0.8, 0.1, 0.08, 0.05, 0.02]) + assert detect_eigenvalue_knee(evs) == 2 + + +def test_detect_knee_clean_monotonic_decay(): + """Smoothly-decaying spectrum with no clear gap returns 0.""" + evs = np.array([0.5, 0.47, 0.44, 0.40, 0.37, 0.33, 0.30]) + assert detect_eigenvalue_knee(evs) == 0 + + +def test_detect_knee_empty(): + """Empty array returns 0.""" + assert detect_eigenvalue_knee(np.array([])) == 0 + + +def test_detect_knee_single_value(): + """Single eigenvalue returns 1 (degenerate case).""" + assert detect_eigenvalue_knee(np.array([0.5])) == 1 + + +def test_detect_knee_all_zero(): + """All-zero eigenvalues return 0.""" + assert detect_eigenvalue_knee(np.zeros(5)) == 0 + + +def test_detect_knee_respects_min_ratio(): + """Knee gates on min_ratio: shallow drops are rejected.""" + # 2x drop between 0.9 and 0.45 -- below the min_ratio=3.0 gate passed next + evs = np.array([0.9, 0.45, 0.40, 0.35, 0.30]) + assert detect_eigenvalue_knee(evs, min_ratio=3.0) == 0 + # Same data but lower min_ratio 
accepts the drop + assert detect_eigenvalue_knee(evs, min_ratio=1.5) == 1 + + +def test_detect_knee_rel_floor_excludes_tail(): + """Anchors below rel_floor * max are excluded from knee selection. + + Without the rel_floor mask, the largest gap might be at the tail + (here, between 1e-5 and 1e-10); the floor ensures we anchor on a + meaningful eigenvalue. + """ + evs = np.array([0.9, 0.8, 0.7, 1e-5, 1e-10]) + # The knee we want is the 0.7 -> 1e-5 drop (~4.8 decades); the slightly + # larger tail drop 1e-5 -> 1e-10 (5 decades) must not win -> returns 3. + assert detect_eigenvalue_knee(evs) == 3 + + +def test_detect_knee_rel_floor_excludes_all(): + """``rel_floor`` greater than 1.0 leaves no valid anchors and returns 0. + + Degenerate guard for the ``not np.any(valid)`` branch. + """ + evs = np.array([0.9, 0.8, 0.1]) + assert detect_eigenvalue_knee(evs, rel_floor=2.0) == 0 + + +# ----------------------------------------------------------------------------- +# auto_select_components_robust +# ----------------------------------------------------------------------------- + + +def test_robust_user_meg_case(): + """Issue #34 case: outlier returns 0, knee returns 7, robust returns 7.""" + n_outlier = iterative_outlier_removal(ISSUE_34_EIGENVALUES, sigma=3.0) + n_knee = detect_eigenvalue_knee(ISSUE_34_EIGENVALUES) + n_robust = auto_select_components_robust(ISSUE_34_EIGENVALUES) + assert n_outlier == 0 + assert n_knee == 7 + assert n_robust == 7 + + +def test_robust_combines_via_max(): + """When outlier and knee disagree, the larger count wins.""" + # Construct a case where outlier returns >= 1 (one extreme outlier) + # and knee may return 1 as well -> max == 1. 
+ evs = np.array([100.0, 0.1, 0.09, 0.08, 0.07, 0.06]) + n_outlier = iterative_outlier_removal(evs, sigma=3.0) + n_knee = detect_eigenvalue_knee(evs) + assert auto_select_components_robust(evs) == max(n_outlier, n_knee) + + +def test_robust_clean_returns_zero(): + """Smoothly-decaying spectrum: both paths return 0.""" + evs = np.array([0.5, 0.47, 0.44, 0.40, 0.37, 0.33, 0.30]) + assert auto_select_components_robust(evs) == 0 + + +def test_robust_forwards_kwargs(): + """``knee_min_ratio`` changes the result (other kwargs are not exercised here).""" + evs = np.array([0.9, 0.45, 0.40, 0.35, 0.30]) + # Strict knee gate rejects this drop + assert auto_select_components_robust(evs, knee_min_ratio=3.0) == 0 + # Relaxed knee gate accepts it + assert auto_select_components_robust(evs, knee_min_ratio=1.5) >= 1 + + +# ----------------------------------------------------------------------------- +# Backwards-compatibility (existing functions untouched) +# ----------------------------------------------------------------------------- + + +def test_iterative_outlier_removal_unchanged(): + """Iterative removal still picks up extreme outliers (regression guard). + + The original algorithm is conservative; this case has one massive outlier + that survives the iterative step. + """ + scores = np.array([1000.0, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]) + assert iterative_outlier_removal(scores, sigma=3.0) >= 1 + + +def test_iterative_outlier_removal_user_case_is_zero(): + """The user MEG eigenvalues legitimately return 0 from the outlier path. + + This locks in the diagnosis: the bug is not in iterative_outlier_removal, + it's in *relying solely* on it. 
+ """ + assert iterative_outlier_removal(ISSUE_34_EIGENVALUES, sigma=3.0) == 0 + + +def test_auto_select_components_alias_unchanged(): + """``auto_select_components`` remains a thin wrapper over outlier removal.""" + scores = np.array([10.0, 0.5, 0.4, 0.3, 0.2]) + assert auto_select_components(scores, threshold=3.0) == iterative_outlier_removal( + scores, sigma=3.0 + ) + + +class TestIterativeOutlierRemoval: + """Tests for iterative_outlier_removal.""" + + def test_clear_outliers(self): + """Smoke test: two high scores must yield a non-negative outlier count.""" + scores = np.array([0.9, 0.8, 0.15, 0.12, 0.1, 0.08, 0.07]) + n = iterative_outlier_removal(scores, sigma=2.0) + # The iterative method is conservative; it may or may not flag + # the top scores depending on the distribution shape + assert n >= 0 + + def test_no_outliers(self): + """Uniform scores should produce 0 outliers.""" + scores = np.array([0.5, 0.5, 0.5, 0.5, 0.5]) + n = iterative_outlier_removal(scores, sigma=3.0) + assert n == 0 + + def test_single_outlier(self): + """One extreme value among many similar should be detected.""" + scores = np.array([10.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + n = iterative_outlier_removal(scores, sigma=2.0) + assert n >= 1 + + def test_strict_threshold(self): + """A very high sigma should never flag more outliers than a low one.""" + scores = np.array([10.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]) + n_strict = iterative_outlier_removal(scores, sigma=10.0) + n_lenient = iterative_outlier_removal(scores, sigma=1.0) + assert n_strict <= n_lenient + + def test_two_elements(self): + """With only two elements, algorithm should still work.""" + scores = np.array([1.0, 0.1]) + n = iterative_outlier_removal(scores, sigma=2.0) + assert n == 0 # Not enough elements for iterative removal + + def test_empty_array(self): + """Empty array should return 0.""" + scores = np.array([]) + n = iterative_outlier_removal(scores, sigma=2.0) + assert n == 0 + + +# 
============================================================================ +# eigenvalue_ratio_selection +# ============================================================================ + + +class TestEigenvalueRatioSelection: + """Tests for eigenvalue_ratio_selection.""" + + def test_clear_drop(self): + """A clear 3x drop should be detected.""" + eigenvalues = np.array([0.9, 0.3, 0.29, 0.28]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 1 + + def test_no_clear_drop(self): + """Gentle decline should return 0.""" + eigenvalues = np.array([0.5, 0.48, 0.46, 0.44]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 0 + + def test_multiple_components(self): + """Two large eigenvalues followed by a drop should select 2.""" + eigenvalues = np.array([0.9, 0.8, 0.1, 0.09]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 2 + + def test_zero_eigenvalue(self): + """A zero eigenvalue should count as a drop.""" + eigenvalues = np.array([0.5, 0.0, 0.0]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 1 + + def test_single_eigenvalue(self): + """Single eigenvalue → select 1.""" + eigenvalues = np.array([0.5]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 1 + + def test_all_equal(self): + """All equal eigenvalues → no drop → 0.""" + eigenvalues = np.array([0.5, 0.5, 0.5, 0.5]) + n = eigenvalue_ratio_selection(eigenvalues, ratio_threshold=2.0) + assert n == 0 + + +# ============================================================================ +# max_gap_selection +# ============================================================================ + + +class TestMaxGapSelection: + """Tests for max_gap_selection.""" + + def test_finds_largest_gap(self): + """Should find the position of the biggest drop.""" + eigenvalues = np.array([0.9, 0.3, 0.28, 0.27]) + n = max_gap_selection(eigenvalues, min_ratio=1.2) + assert n == 1 # 0.9/0.3 = 3.0 is the 
biggest gap + + def test_all_equal_no_gap(self): + """All equal → no gap exceeding min_ratio → 0.""" + eigenvalues = np.array([0.5, 0.5, 0.5, 0.5]) + n = max_gap_selection(eigenvalues, min_ratio=1.2) + assert n == 0 + + def test_gentle_decline(self): + """Gentle decline may still find a gap if ratio > min_ratio.""" + eigenvalues = np.array([0.5, 0.4, 0.3, 0.2]) + n = max_gap_selection(eigenvalues, min_ratio=1.2) + # 0.3/0.2 = 1.5 is the biggest ratio (presumably n=3); only n >= 1 is asserted + assert n >= 1 + + def test_min_ratio_filtering(self): + """High min_ratio should reject small gaps.""" + eigenvalues = np.array([0.5, 0.45, 0.4, 0.35]) + n = max_gap_selection(eigenvalues, min_ratio=2.0) + assert n == 0 # All ratios < 2.0 + + def test_single_eigenvalue(self): + """Single eigenvalue → return 1.""" + eigenvalues = np.array([0.5]) + n = max_gap_selection(eigenvalues, min_ratio=1.2) + assert n == 1 + + def test_two_eigenvalues(self): + """Two eigenvalues with clear gap.""" + eigenvalues = np.array([0.9, 0.1]) + n = max_gap_selection(eigenvalues, min_ratio=1.2) + assert n == 1