Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions mne_bids_pipeline/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,12 @@
i.e. the average across all channels. If a string, must be the name of a single
channel. To use multiple channels as reference, set to a list of channel names.

!!! warning
If this option is set to `average`, AND `add_online_reference_channel` is `False`
(AND the online reference channel is not already included as flat channel), you
will likely get incorrect results.
See https://github.com/mne-tools/mne-python/issues/13618

???+ example "Example"
Use the average reference:
```python
Expand All @@ -357,6 +363,28 @@
```
"""

add_online_reference_channel: bool = True
"""
Whether the online reference channel should be added to the data (as flat channel),
if it is not included in the data already.
"""

eeg_online_reference_channel: str | None = None
"""
Specify the EEG channel that was used as reference channel during the recording.

???+ example "Example"
```python
eeg_online_reference_channel = "Cz"
```
"""

drop_channel_after_rereference: bool = True
"""
Whether the reconstructed online reference channel should be dropped again after applying
an average reference.
"""

eeg_template_montage: str | DigMontageType | None = None
"""
In situations where you wish to process EEG data and no individual
Expand Down
41 changes: 41 additions & 0 deletions mne_bids_pipeline/_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""EEG reference utilities."""

from ._logging import gen_log_kwargs, logger
from types import SimpleNamespace
import mne

def set_initial_average_reference(inst, cfg: SimpleNamespace):
"""Set an average EEG reference with the option to add the online flat channel before re-referencing.

If `cfg.add_online_reference_channel` is True and the specified `cfg.eeg_online_reference_channel` is not yet present,
it is added as a flat reference channel, increasing the number of channels in `inst` by one.

Note:
- The average reference is added as a projection and not yet applied.
- If you want to drop the online reference channel after re-referencing,
first apply the projection and then drop the channel.
"""

assert_msg = "An average reference projection has already been applied to the data. You cannot add the online reference as a flat channel anymore. Given this function is rather used internally, you might want to raise an issue on GitHub."
assert not mne._fiff.proj._has_eeg_average_ref_proj(inst.info), assert_msg

if cfg.add_online_reference_channel:
assert cfg.eeg_online_reference_channel is not None, "To add the online reference channel as flat channel before re-referencing, `eeg_online_reference_channel` must be provided."

if cfg.eeg_online_reference_channel in inst.ch_names:
msg = f"Specified online reference channel {cfg.eeg_online_reference_channel} exists already. Double-check if this is really the reference channel name, if it is, consider setting `add_online_reference_channel` to `False` in case it is indeed a flat channel."
logger.warning(**gen_log_kwargs(message=msg))
else:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be added as flat channel."
logger.info(**gen_log_kwargs(message=msg))

mne.add_reference_channels(inst, ref_channels=[cfg.eeg_online_reference_channel], copy=False)

else:
msg = "Re-referencing to average reference assuming the online reference channel was already added as a flat channel."
logger.info(**gen_log_kwargs(message=msg))

# We use this instead of projection=False to later being able to check if the average projection was already applied
inst.set_eeg_reference("average", projection=True)#.apply_proj()

return inst
19 changes: 15 additions & 4 deletions mne_bids_pipeline/steps/preprocessing/_06a1_fit_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from mne_bids_pipeline._import_data import annotations_to_events, make_epochs
from mne_bids_pipeline._logging import gen_log_kwargs, logger
from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func
from mne_bids_pipeline._reference import set_initial_average_reference
from mne_bids_pipeline._reject import _get_reject
from mne_bids_pipeline._report import _open_report
from mne_bids_pipeline._run import (
Expand Down Expand Up @@ -207,10 +208,17 @@ def run_ica(

# Set an EEG reference
if "eeg" in cfg.ch_types:
projection = True if cfg.eeg_reference == "average" else False
epochs.set_eeg_reference(cfg.eeg_reference, projection=projection)
if cfg.ica_use_icalabel:
epochs.apply_proj() # Apply the reference projection
if cfg.eeg_reference == "average":
set_initial_average_reference(epochs, cfg)
if cfg.ica_use_icalabel:
epochs.apply_proj() # Apply the reference projection
if cfg.drop_channel_after_rereference:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be dropped again."
logger.info(**gen_log_kwargs(message=msg))
epochs.drop_channels(cfg.eeg_online_reference_channel)

else:
epochs.set_eeg_reference(cfg.eeg_reference, projection=False)

ar_reject_log = ar_n_interpolate_ = None
if cfg.ica_reject == "autoreject_local":
Expand Down Expand Up @@ -405,6 +413,9 @@ def get_config(
epochs_metadata_keep_last=config.epochs_metadata_keep_last,
epochs_metadata_query=config.epochs_metadata_query,
eeg_reference=get_eeg_reference(config),
eeg_online_reference_channel=config.eeg_online_reference_channel,
add_online_reference_channel=config.add_online_reference_channel,
drop_channel_after_rereference=config.drop_channel_after_rereference,
rest_epochs_duration=config.rest_epochs_duration,
rest_epochs_overlap=config.rest_epochs_overlap,
processing="filt" if config.regress_artifact is None else "regress",
Expand Down
17 changes: 15 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_06a2_find_ica_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)
from mne_bids_pipeline._logging import gen_log_kwargs, logger
from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func
from mne_bids_pipeline._reference import set_initial_average_reference
from mne_bids_pipeline._report import _open_report
from mne_bids_pipeline._run import (
_prep_out_files,
Expand Down Expand Up @@ -170,7 +171,11 @@ def find_ica_artifacts(
# Have the channels needed to make ECG epochs
raw = mne.io.read_raw(raw_fname, preload=False)
if cfg.ica_use_icalabel:
raw.set_eeg_reference("average", projection=True).apply_proj()
set_initial_average_reference(raw, cfg).apply_proj()
if cfg.drop_channel_after_rereference:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be dropped again."
logger.info(**gen_log_kwargs(message=msg))
raw.drop_channels(cfg.eeg_online_reference_channel)
# ECG epochs
if not (
"ecg" in raw.get_channel_types()
Expand Down Expand Up @@ -242,7 +247,12 @@ def find_ica_artifacts(
for ri, raw_fname in enumerate(raw_fnames):
raw = mne.io.read_raw_fif(raw_fname, preload=True)
if cfg.ica_use_icalabel:
raw.set_eeg_reference("average", projection=True).apply_proj()
set_initial_average_reference(raw, cfg).apply_proj()
if cfg.drop_channel_after_rereference:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be dropped again."
logger.info(**gen_log_kwargs(message=msg))
raw.drop_channels(cfg.eeg_online_reference_channel)

if eog_chs_subj_sess is not None: # explicit None-check to allow []
ch_names = eog_chs_subj_sess
assert all([ch_name in raw.ch_names for ch_name in ch_names])
Expand Down Expand Up @@ -592,6 +602,9 @@ def get_config(
ica_class_thresholds=config.ica_class_thresholds,
ch_types=config.ch_types,
eeg_reference=get_eeg_reference(config),
eeg_online_reference_channel=config.eeg_online_reference_channel,
add_online_reference_channel=config.add_online_reference_channel,
drop_channel_after_rereference=config.drop_channel_after_rereference,
eog_channels=config.eog_channels,
processing="filt" if config.regress_artifact is None else "regress",
**_bids_kwargs(config=config),
Expand Down
10 changes: 8 additions & 2 deletions mne_bids_pipeline/steps/preprocessing/_07_make_epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from mne_bids_pipeline._logging import gen_log_kwargs, logger
from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func
from mne_bids_pipeline._report import _get_prefix_tags, _open_report
from mne_bids_pipeline._reference import set_initial_average_reference
from mne_bids_pipeline._run import (
_prep_out_files,
_sanitize_callable,
Expand Down Expand Up @@ -206,8 +207,10 @@ def run_epochs(

# Set an EEG reference
if "eeg" in cfg.ch_types:
projection = True if cfg.eeg_reference == "average" else False
epochs.set_eeg_reference(cfg.eeg_reference, projection=projection)
if cfg.eeg_reference == "average":
set_initial_average_reference(epochs, cfg)
else:
epochs.set_eeg_reference(cfg.eeg_reference, projection=False)

assert isinstance(epochs.drop_log, tuple)
n_epochs_before_metadata_query = len(epochs.drop_log)
Expand Down Expand Up @@ -347,6 +350,9 @@ def get_config(
ch_types=config.ch_types,
noise_cov=_sanitize_callable(config.noise_cov),
eeg_reference=get_eeg_reference(config),
eeg_online_reference_channel=config.eeg_online_reference_channel,
add_online_reference_channel=config.add_online_reference_channel,
drop_channel_after_rereference=config.drop_channel_after_rereference,
rest_epochs_duration=config.rest_epochs_duration,
rest_epochs_overlap=config.rest_epochs_overlap,
_epochs_split_size=config._epochs_split_size,
Expand Down
48 changes: 43 additions & 5 deletions mne_bids_pipeline/steps/preprocessing/_08a_apply_ica.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
from mne.preprocessing import read_ica
from mne_bids import BIDSPath

from mne_bids_pipeline._config_utils import _get_ssrt, _get_sst, _limit_which_clean
from mne_bids_pipeline._config_utils import _get_ss, _get_ssrt, _limit_which_clean, get_eeg_reference
from mne_bids_pipeline._import_data import _get_run_rest_noise_path, _import_data_kwargs
from mne_bids_pipeline._logging import gen_log_kwargs, logger
from mne_bids_pipeline._parallel import get_parallel_backend, parallel_func
from mne_bids_pipeline._reference import set_initial_average_reference
from mne_bids_pipeline._report import _add_raw, _open_report
from mne_bids_pipeline._run import (
_prep_out_files,
Expand Down Expand Up @@ -142,8 +143,14 @@ def apply_ica_epochs(
logger.info(**gen_log_kwargs(message=msg))

epochs = mne.read_epochs(in_files.pop("epochs"), preload=True)

# Average reference should have already been set in _07_make_epochs.py, therefore we only apply it here
if cfg.ica_use_icalabel:
epochs.set_eeg_reference("average", projection=True).apply_proj()
epochs.apply_proj()
if cfg.drop_channel_after_rereference:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be dropped again."
logger.info(**gen_log_kwargs(message=msg))
epochs.drop_channels(cfg.eeg_online_reference_channel)

# Now actually reject the components.
msg = (
Expand All @@ -153,6 +160,13 @@ def apply_ica_epochs(
logger.info(**gen_log_kwargs(message=msg))
epochs_cleaned = ica.apply(epochs.copy()) # Copy b/c works in-place!

# Re-reference data (again) after applying ICA
# TODO: We decided against applying the reference again after the ICA because it's not yet clear whether that's necessary or desired
# if cfg.eeg_reference == "average":
# epochs_cleaned.set_eeg_reference("average", projection=True).apply_proj()
# else:
# epochs_cleaned.set_eeg_reference(cfg.eeg_reference)

msg = f"Saving {len(epochs)} reconstructed epochs after ICA."
logger.info(**gen_log_kwargs(message=msg))
epochs_cleaned.save(
Expand Down Expand Up @@ -218,11 +232,31 @@ def apply_ica_raw(
msg = f"Writing {out_files[in_key].basename} …"
logger.info(**gen_log_kwargs(message=msg))
raw = mne.io.read_raw_fif(raw_fname, preload=True)
if cfg.ica_use_icalabel:
raw.set_eeg_reference("average", projection=True).apply_proj()

if "eeg" in cfg.ch_types:
if cfg.eeg_reference == "average":
set_initial_average_reference(raw, cfg)
if cfg.ica_use_icalabel:
raw.apply_proj()
if cfg.drop_channel_after_rereference:
msg = f"Online reference channel {cfg.eeg_online_reference_channel} will be dropped again."
logger.info(**gen_log_kwargs(message=msg))
raw.drop_channels(cfg.eeg_online_reference_channel)
else:
raw.set_eeg_reference(cfg.eeg_reference, projection=False)

ica.apply(raw)

# Re-reference data (again) after applying ICA
# TODO: We decided against applying the reference again after the ICA because it's not yet clear to us whether that's necessary or desired
# if cfg.eeg_reference == "average":
# raw.set_eeg_reference("average", projection=True).apply_proj()
# else:
# raw.set_eeg_reference(cfg.eeg_reference)

raw.save(out_files[in_key], overwrite=True, split_size=cfg._raw_split_size)
_update_for_splits(out_files, in_key)

# Report
with _open_report(
cfg=cfg,
Expand Down Expand Up @@ -253,7 +287,11 @@ def get_config(
) -> SimpleNamespace:
cfg = SimpleNamespace(
ica_use_icalabel=config.ica_use_icalabel,
processing="filt" if config.regress_artifact is None else "regress",
eeg_reference=get_eeg_reference(config),
eeg_online_reference_channel=config.eeg_online_reference_channel,
add_online_reference_channel=config.add_online_reference_channel,
drop_channel_after_rereference=config.drop_channel_after_rereference,
processing="eyelink" if config.sync_eyelink else "filt" if config.regress_artifact is None else "regress",
_epochs_split_size=config._epochs_split_size,
**_import_data_kwargs(config=config, subject=subject, session=session),
)
Expand Down
2 changes: 2 additions & 0 deletions mne_bids_pipeline/tests/configs/config_ds001971.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
runs = ["01"]
epochs_decim = 5 # to 100 Hz

eeg_online_reference_channel = "dummy_ref" # Should be changed to real online reference

# This is mostly for testing purposes!
decode = True
decoding_time_generalization = True
Expand Down
Loading