diff --git a/doc/api.rst b/doc/api.rst index 1a84a4ad..3df2c584 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -45,6 +45,7 @@ on numpy array inputs. envelope_correlation phase_slope_index + phase_slope_index_time vector_auto_regression spectral_connectivity_epochs spectral_connectivity_time diff --git a/file_paths.txt b/file_paths.txt new file mode 100644 index 00000000..e69de29b diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index 80c7fe4e..d55f1aee 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -25,7 +25,7 @@ ) from .datasets import make_signals_in_freq_bands, make_surrogate_data from .decoding import CoherencyDecomposition -from .effective import phase_slope_index +from .effective import phase_slope_index, phase_slope_index_time from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity from .spectral import spectral_connectivity_epochs, spectral_connectivity_time diff --git a/mne_connectivity/effective.py b/mne_connectivity/effective.py index 21ccf965..7a07eb30 100644 --- a/mne_connectivity/effective.py +++ b/mne_connectivity/effective.py @@ -7,8 +7,12 @@ import numpy as np from mne.utils import logger, verbose, warn -from .base import SpectralConnectivity, SpectroTemporalConnectivity -from .spectral import spectral_connectivity_epochs +from .base import ( + EpochSpectralConnectivity, + SpectralConnectivity, + SpectroTemporalConnectivity, +) +from .spectral import spectral_connectivity_epochs, spectral_connectivity_time from .utils import fill_doc @@ -252,3 +256,242 @@ def phase_slope_index( ) return conn + + +@verbose +def phase_slope_index_time(data, + indices=None, + sfreq=2 * np.pi, + mode="multitaper", + fmin=None, + fmax=np.inf, + mt_bandwidth=None, + freqs=None, + n_cycles=7, + padding=0, + average=False, + sm_times=0, + sm_freqs=1, + sm_kernel="hanning", + n_jobs=1, + verbose=None, + ): + """Compute the Phase Slope Index (PSI) connectivity measure across time. + + This function computes PSI over time from epoched data. The data may consist of a + single epoch. + + The PSI is an effective connectivity measure, i.e., a measure which can give an + indication of the direction of the information flow (causality). For two time + series, one computes the PSI between the first and the second time series as + follows: :: + + indices = (np.array([0]), np.array([1])) + psi = phase_slope_index(data, indices=indices, ...) + + A positive value means that time series 0 is ahead of time series 1 and a negative + value means the opposite. + + The PSI is computed from the coherency (see :func:`spectral_connectivity_time`), + details can be found in :footcite:`NolteEtAl2008`. + + Parameters + ---------- + data : array-like, shape (n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. + freqs : array-like + Array of frequencies of interest for time-frequency decomposition. Only the + frequencies within the range specified by ``fmin`` and ``fmax`` are used. + indices : tuple of array | None + Two arrays with indices of connections for which to compute connectivity. If + `None`, all connections are computed. + sfreq : float + The sampling frequency. Required if data is not :class:`~mne.Epochs`. + mode : str + Time-frequency decomposition method. Can be either: 'multitaper' or + 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. + fmin : float | tuple of float | None + The lower frequency of interest. Multiple bands are defined using a tuple, e.g., + ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower bounds. If `None`, the + lowest frequency in ``freqs`` is used. + fmax : float | tuple of float | None + The upper frequency of interest. Multiple bands are defined using a tuple, e.g. + ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, the + highest frequency in ``freqs`` is used. + padding : float + Amount of time to consider as padding at the beginning and end of each epoch in + seconds. See Notes of :func:`spectral_connectivity_time` for more information. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. + Only used in 'multitaper' mode. + freqs : array + Array of frequencies of interest. Only used in 'cwt_morlet' mode. + n_cycles : float | array of float + Number of cycles. Fixed number or one per frequency. Only used in + 'cwt_morlet' mode. + padding : float + Amount of time to consider as padding at the beginning and end of each + epoch in seconds. See Notes for more information. + average : bool + Average connectivity scores over epochs. If ``True``, output will be + an instance of :class:`SpectralConnectivity`, otherwise + :class:`EpochSpectralConnectivity`. + sm_times : float + Amount of time to consider for the temporal smoothing in seconds. + If zero, no temporal smoothing is applied. + sm_freqs : int + Number of points for frequency smoothing. By default, 1 is used which + is equivalent to no smoothing. + sm_kernel : {'square', 'hanning'} + Smoothing kernel type. Choose either 'square' or 'hanning'. + n_jobs : int + Number of connections to compute in parallel. Memory mapping must be activated. + Please see the Notes section of :func:`spectral_connectivity_time` for details. + %(verbose)s + + Returns + ------- + conn : instance of Connectivity + Computed connectivity measure(s). Either a + :class:`SpectralConnectivity` or :class:`EpochSpectralConnectivity` + container depending on the ``average`` parameter. + The shape of each array is + (n_signals ** 2, n_bands, n_epochs) or (n_signals ** 2, n_bands) + when "indices" is None, or + (n_con, n_bands, n_epochs) or (n_con, n_bands) + when "indices" is specified and "n_con = len(indices[0])". + The epoch dimension is present when ``average=False`` and absent when + ``average=True``. + + See Also + -------- + mne_connectivity.SpectralConnectivity + mne_connectivity.EpochSpectralConnectivity + mne_connectivity.spectral_connectivity_time + + References + ---------- + .. footbibliography:: + """ + logger.info("Estimating phase slope index (PSI) across time") + + # estimate the coherency + + # Always compute coherency without averaging first, so we can compute PSI + # for each epoch, then average PSI if requested (consistent with spec_conn_time) + cohy = spectral_connectivity_time( + data, + freqs=freqs, + method="cohy", + average=False, + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + fskip=0, + faverage=False, + sm_times=sm_times, + sm_freqs=sm_freqs, + sm_kernel=sm_kernel, + padding=padding, + mode=mode, + mt_bandwidth=mt_bandwidth, + n_cycles=n_cycles, + decim=1, + n_jobs=n_jobs, + verbose=verbose, + ) + + freqs_ = np.array(cohy.freqs) + names = cohy.names + n_tapers = cohy.attrs.get("n_tapers") + n_nodes = cohy.n_nodes + n_epochs_used = cohy.n_epochs + metadata = cohy.metadata + events = cohy.events + event_id = cohy.event_id + + logger.info(f"Computing PSI from estimated Coherency: {cohy}") + # compute PSI in the requested bands + if fmin is None: + fmin = -np.inf + if fmax is None: + fmax = np.inf + + bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel())) + n_bands = len(bands) + + freq_dim = -1 + + # allocate space for output + out_shape = list(cohy.shape) + out_shape[freq_dim] = n_bands + psi = np.zeros(out_shape, dtype=np.float64) + + # allocate accumulator + acc_shape = copy.copy(out_shape) + acc_shape.pop(freq_dim) + acc = np.empty(acc_shape, dtype=np.complex128) + + # create list for frequencies used and frequency bands + # of resulting connectivity data + freqs = list() + freq_bands = list() + idx_fi = [slice(None)] * len(out_shape) + idx_fj = [slice(None)] * len(out_shape) + for band_idx, band in enumerate(bands): + freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0] + freqs.append(freqs_[freq_idx]) + freq_bands.append(np.mean(freqs_[freq_idx])) + + acc.fill(0.0) + for fi, fj in zip(freq_idx, freq_idx[1:]): + idx_fi[freq_dim] = fi + idx_fj[freq_dim] = fj + acc += ( + np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)] + ) + + idx_fi[freq_dim] = band_idx + psi[tuple(idx_fi)] = np.imag(acc) + logger.info("[PSI Estimation Done]") + + # create a connectivity container + # When average=True, average PSI over epochs (consistent with spec_conn_time behavior) + # When average=False, keep epoch dimension + if average: + # Average over epochs + psi = np.mean(psi, axis=0) + conn = SpectralConnectivity( + data=psi, + names=names, + freqs=freq_bands, + n_nodes=n_nodes, + method="phase-slope-index", + spec_method=mode, + indices=indices, + freqs_computed=freqs, + n_epochs_used=n_epochs_used, + n_tapers=n_tapers, + metadata=metadata, + events=events, + event_id=event_id, + ) + else: + conn = EpochSpectralConnectivity( + data=psi, + names=names, + freqs=freq_bands, + n_nodes=n_nodes, + method="phase-slope-index", + spec_method=mode, + indices=indices, + freqs_computed=freqs, + n_tapers=n_tapers, + metadata=metadata, + events=events, + event_id=event_id, + ) + + return conn diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index ca2893f5..398badfc 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -638,8 +638,8 @@ def spectral_connectivity_time( conn = dict() conn_patterns = dict() for m in method: - # CaCoh complex-valued, all other methods real-valued - if m == "cacoh": + # Cohy and CaCoh complex-valued, all other methods real-valued + if m in ["cacoh", "cohy"]: con_scores_dtype = np.complex128 else: con_scores_dtype = np.float64 @@ -1035,7 +1035,7 @@ def _parallel_con( output is a tuple of lists containing arrays for the connectivity scores and patterns, respectively. """ - if "coh" in method: + if ("coh" in method) or ("cohy" in method): # psd if weights is not None: psd = weights * w @@ -1127,9 +1127,16 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights): s_xy = np.squeeze(s_xy, axis=0) s_xy = _smooth_spectra(s_xy, kernel) out = [] - conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh} + conn_func = { + "plv": _plv, + "ciplv": _ciplv, + "pli": _pli, + "wpli": _wpli, + "coh": _coh, + "cohy": _cohy, + } for m in method: - if m == "coh": + if m in ["coh", "cohy"]: s_xx = psd[x] s_yy = psd[y] out.append(conn_func[m](s_xx, s_yy, s_xy)) @@ -1372,6 +1379,32 @@ def _coh(s_xx, s_yy, s_xy): return coh +def _cohy(s_xx, s_yy, s_xy): + """Compute coherencey given the cross spectral density and PSD. + + Parameters + ---------- + s_xx : array-like, shape (n_freqs, n_times) + The PSD of channel 'x'. + s_yy : array-like, shape (n_freqs, n_times) + The PSD of channel 'y'. + s_xy : array-like, shape (n_freqs, n_times) + The cross PSD between channel 'x' and channel 'y' across + frequency and time points. + + Returns + ------- + cohy : array-like, shape (n_freqs, n_times) + The estimated COHY. + """ + con_num = s_xy.mean(axis=-1, keepdims=True) + con_den = np.sqrt( + s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True) + ) + cohy = con_num / con_den + return cohy + + def _compute_csd(x, y, weights): """Compute cross spectral density between signals x and y.""" if weights is not None: diff --git a/mne_connectivity/tests/test_effective.py b/mne_connectivity/tests/test_effective.py index 0ff0a8d6..9e65d609 100644 --- a/mne_connectivity/tests/test_effective.py +++ b/mne_connectivity/tests/test_effective.py @@ -1,7 +1,8 @@ import numpy as np from numpy.testing import assert_array_almost_equal -from mne_connectivity.effective import phase_slope_index +from mne_connectivity import EpochSpectralConnectivity, SpectralConnectivity +from mne_connectivity.effective import phase_slope_index, phase_slope_index_time def test_psi(): @@ -39,3 +40,63 @@ def test_psi(): assert np.all(conn_cwt.get_data() > 0) assert conn_cwt.shape[-1] == n_times + + +def test_psi_time(): + """Test Phase Slope Index (PSI) estimation across time.""" + sfreq = 50.0 + n_signals = 3 + n_epochs = 10 + n_times = 500 + rng = np.random.RandomState(42) + data = rng.randn(n_epochs, n_signals, n_times) + + # simulate time shifts + for i in range(n_epochs): + data[i, 1, 10:] = data[i, 0, :-10] # signal 0 is ahead + data[i, 2, :-10] = data[i, 0, 10:] # signal 2 is ahead + + # only compute for a subset of the indices + indices = (np.array([0]), np.array([1])) + + freqs = np.arange(5.0, 20, 0.5) + conn_cwt = phase_slope_index_time( + data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices + ) + + assert np.all(conn_cwt.get_data() > 0) + assert conn_cwt.shape[0] == n_epochs + assert isinstance(conn_cwt, EpochSpectralConnectivity) + + # Test with average=False (explicit) + conn_no_avg = phase_slope_index_time( + data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices, average=False + ) + assert isinstance(conn_no_avg, EpochSpectralConnectivity) + assert conn_no_avg.shape[0] == n_epochs + + # Test with average=True + conn_avg = phase_slope_index_time( + data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices, average=True + ) + assert isinstance(conn_avg, SpectralConnectivity) + # When averaged, epoch dimension should be removed + assert len(conn_avg.shape) == 2 # (n_con, n_bands) + assert conn_avg.shape[0] == len(indices[0]) + # Verify that averaged result matches manual average + assert_array_almost_equal(conn_avg.get_data(), np.mean(conn_no_avg.get_data(), axis=0)) + + # Test with single epoch (no epoch dimension in input) + single_epoch_data = data[0:1] # shape (1, n_signals, n_times) + conn_single = phase_slope_index_time( + single_epoch_data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices + ) + assert isinstance(conn_single, EpochSpectralConnectivity) + assert conn_single.shape[0] == 1 # single epoch + + # Test with single epoch and average=True + conn_single_avg = phase_slope_index_time( + single_epoch_data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices, average=True + ) + assert isinstance(conn_single_avg, SpectralConnectivity) + assert len(conn_single_avg.shape) == 2 # (n_con, n_bands)