Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
2631a83
added _cohy to spectral.time
seqasim Jun 28, 2024
388ab8c
list cohy method in spectral.time
seqasim Jun 28, 2024
c38374a
Mimic conditional statements for coh for cohy in _pairwise_con and _…
seqasim Jun 28, 2024
b44da50
initial add of phase_slope_index_time to effective.py
seqasim Jun 28, 2024
75a72b5
remove unused code from phase_slope_index_time
seqasim Jun 28, 2024
6cef3cb
fixing documentation for phase_slope_index_time
seqasim Jun 28, 2024
dffda01
added test_psi_time to test_effective
seqasim Jun 28, 2024
ef0a484
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 28, 2024
be0c3cd
address pull request comments
seqasim Oct 31, 2024
ce68935
address pull review comments
seqasim Oct 31, 2024
401a6a8
fixed typo in effective.py
seqasim Oct 31, 2024
573211d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 31, 2024
a96db00
removing 'names' from phase_slope_index_time args
seqasim Nov 12, 2024
ac393fe
explicitly provide freqs to test_psi_time
seqasim Nov 12, 2024
b5ab63b
fix freq dim in phase_slope_index_time
seqasim Nov 12, 2024
3c93427
fix shape req in test_psi_time
seqasim Nov 12, 2024
5dad254
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 12, 2024
8820b6f
Merge branch 'main' into main
seqasim Nov 12, 2024
fd5738e
Update for review
tsbinns Nov 16, 2024
48e81a8
Fix too long line
tsbinns Nov 16, 2024
fece4ae
Address reviewer feedback: add average parameter and smoothing option…
seqasim Nov 15, 2025
0412757
Merge remote-tracking branch 'upstream/main'
seqasim Nov 15, 2025
3b4fd07
Merge remote-tracking branch 'origin/main'
seqasim Nov 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ on numpy array inputs.

envelope_correlation
phase_slope_index
phase_slope_index_time
vector_auto_regression
spectral_connectivity_epochs
spectral_connectivity_time
Expand Down
Empty file added file_paths.txt
Empty file.
2 changes: 1 addition & 1 deletion mne_connectivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .datasets import make_signals_in_freq_bands, make_surrogate_data
from .decoding import CoherencyDecomposition
from .effective import phase_slope_index
from .effective import phase_slope_index, phase_slope_index_time
from .envelope import envelope_correlation, symmetric_orth
from .io import read_connectivity
from .spectral import spectral_connectivity_epochs, spectral_connectivity_time
Expand Down
247 changes: 245 additions & 2 deletions mne_connectivity/effective.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import numpy as np
from mne.utils import logger, verbose, warn

from .base import SpectralConnectivity, SpectroTemporalConnectivity
from .spectral import spectral_connectivity_epochs
from .base import (
EpochSpectralConnectivity,
SpectralConnectivity,
SpectroTemporalConnectivity,
)
from .spectral import spectral_connectivity_epochs, spectral_connectivity_time
from .utils import fill_doc


Expand Down Expand Up @@ -252,3 +256,242 @@ def phase_slope_index(
)

return conn


@verbose
def phase_slope_index_time(data,
indices=None,
sfreq=2 * np.pi,
mode="multitaper",
fmin=None,
fmax=np.inf,
mt_bandwidth=None,
freqs=None,
n_cycles=7,
padding=0,
average=False,
sm_times=0,
sm_freqs=1,
sm_kernel="hanning",
n_jobs=1,
verbose=None,
):
"""Compute the Phase Slope Index (PSI) connectivity measure across time.

This function computes PSI over time from epoched data. The data may consist of a
single epoch.

The PSI is an effective connectivity measure, i.e., a measure which can give an
indication of the direction of the information flow (causality). For two time
series, one computes the PSI between the first and the second time series as
follows: ::

indices = (np.array([0]), np.array([1]))
psi = phase_slope_index(data, indices=indices, ...)

A positive value means that time series 0 is ahead of time series 1 and a negative
value means the opposite.

The PSI is computed from the coherency (see :func:`spectral_connectivity_time`),
details can be found in :footcite:`NolteEtAl2008`.

Parameters
----------
data : array-like, shape (n_epochs, n_signals, n_times) | Epochs
The data from which to compute connectivity.
freqs : array-like
Array of frequencies of interest for time-frequency decomposition. Only the
frequencies within the range specified by ``fmin`` and ``fmax`` are used.
indices : tuple of array | None
Two arrays with indices of connections for which to compute connectivity. If
`None`, all connections are computed.
sfreq : float
The sampling frequency. Required if data is not :class:`~mne.Epochs`.
mode : str
Time-frequency decomposition method. Can be either: 'multitaper' or
'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and
:func:`mne.time_frequency.tfr_array_morlet` for reference.
fmin : float | tuple of float | None
The lower frequency of interest. Multiple bands are defined using a tuple, e.g.,
``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower bounds. If `None`, the
lowest frequency in ``freqs`` is used.
fmax : float | tuple of float | None
The upper frequency of interest. Multiple bands are defined using a tuple, e.g.
``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, the
highest frequency in ``freqs`` is used.
padding : float
Amount of time to consider as padding at the beginning and end of each epoch in
seconds. See Notes of :func:`spectral_connectivity_time` for more information.
mt_bandwidth : float | None
The bandwidth of the multitaper windowing function in Hz.
Only used in 'multitaper' mode.
freqs : array
Array of frequencies of interest. Only used in 'cwt_morlet' mode.
n_cycles : float | array of float
Number of cycles. Fixed number or one per frequency. Only used in
'cwt_morlet' mode.
padding : float
Amount of time to consider as padding at the beginning and end of each
epoch in seconds. See Notes for more information.
average : bool
Average connectivity scores over epochs. If ``True``, output will be
an instance of :class:`SpectralConnectivity`, otherwise
:class:`EpochSpectralConnectivity`.
sm_times : float
Amount of time to consider for the temporal smoothing in seconds.
If zero, no temporal smoothing is applied.
sm_freqs : int
Number of points for frequency smoothing. By default, 1 is used which
is equivalent to no smoothing.
sm_kernel : {'square', 'hanning'}
Smoothing kernel type. Choose either 'square' or 'hanning'.
n_jobs : int
Number of connections to compute in parallel. Memory mapping must be activated.
Please see the Notes section of :func:`spectral_connectivity_time` for details.
%(verbose)s

Returns
-------
conn : instance of Connectivity
Computed connectivity measure(s). Either a
:class:`SpectralConnectivity` or :class:`EpochSpectralConnectivity`
container depending on the ``average`` parameter.
The shape of each array is
(n_signals ** 2, n_bands, n_epochs) or (n_signals ** 2, n_bands)
when "indices" is None, or
(n_con, n_bands, n_epochs) or (n_con, n_bands)
when "indices" is specified and "n_con = len(indices[0])".
The epoch dimension is present when ``average=False`` and absent when
``average=True``.

See Also
--------
mne_connectivity.SpectralConnectivity
mne_connectivity.EpochSpectralConnectivity
mne_connectivity.spectral_connectivity_time

References
----------
.. footbibliography::
"""
logger.info("Estimating phase slope index (PSI) across time")

# estimate the coherency

# Always compute coherency without averaging first, so we can compute PSI
# for each epoch, then average PSI if requested (consistent with spec_conn_time)
cohy = spectral_connectivity_time(
data,
freqs=freqs,
method="cohy",
average=False,
indices=indices,
sfreq=sfreq,
fmin=fmin,
fmax=fmax,
fskip=0,
faverage=False,
sm_times=sm_times,
sm_freqs=sm_freqs,
sm_kernel=sm_kernel,
padding=padding,
mode=mode,
mt_bandwidth=mt_bandwidth,
n_cycles=n_cycles,
decim=1,
n_jobs=n_jobs,
verbose=verbose,
)

freqs_ = np.array(cohy.freqs)
names = cohy.names
n_tapers = cohy.attrs.get("n_tapers")
n_nodes = cohy.n_nodes
n_epochs_used = cohy.n_epochs
metadata = cohy.metadata
events = cohy.events
event_id = cohy.event_id

logger.info(f"Computing PSI from estimated Coherency: {cohy}")
# compute PSI in the requested bands
if fmin is None:
fmin = -np.inf
if fmax is None:
fmax = np.inf

bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel()))
n_bands = len(bands)

freq_dim = -1

# allocate space for output
out_shape = list(cohy.shape)
out_shape[freq_dim] = n_bands
psi = np.zeros(out_shape, dtype=np.float64)

# allocate accumulator
acc_shape = copy.copy(out_shape)
acc_shape.pop(freq_dim)
acc = np.empty(acc_shape, dtype=np.complex128)

# create list for frequencies used and frequency bands
# of resulting connectivity data
freqs = list()
freq_bands = list()
idx_fi = [slice(None)] * len(out_shape)
idx_fj = [slice(None)] * len(out_shape)
for band_idx, band in enumerate(bands):
freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0]
freqs.append(freqs_[freq_idx])
freq_bands.append(np.mean(freqs_[freq_idx]))

acc.fill(0.0)
for fi, fj in zip(freq_idx, freq_idx[1:]):
idx_fi[freq_dim] = fi
idx_fj[freq_dim] = fj
acc += (
np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)]
)

idx_fi[freq_dim] = band_idx
psi[tuple(idx_fi)] = np.imag(acc)
logger.info("[PSI Estimation Done]")

# create a connectivity container
# When average=True, average PSI over epochs (consistent with spec_conn_time behavior)
# When average=False, keep epoch dimension
if average:
# Average over epochs
psi = np.mean(psi, axis=0)
conn = SpectralConnectivity(
data=psi,
names=names,
freqs=freq_bands,
n_nodes=n_nodes,
method="phase-slope-index",
spec_method=mode,
indices=indices,
freqs_computed=freqs,
n_epochs_used=n_epochs_used,
n_tapers=n_tapers,
metadata=metadata,
events=events,
event_id=event_id,
)
else:
conn = EpochSpectralConnectivity(
data=psi,
names=names,
freqs=freq_bands,
n_nodes=n_nodes,
method="phase-slope-index",
spec_method=mode,
indices=indices,
freqs_computed=freqs,
n_tapers=n_tapers,
metadata=metadata,
events=events,
event_id=event_id,
)

return conn
43 changes: 38 additions & 5 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,8 @@ def spectral_connectivity_time(
conn = dict()
conn_patterns = dict()
for m in method:
# CaCoh complex-valued, all other methods real-valued
if m == "cacoh":
# Cohy and CaCoh complex-valued, all other methods real-valued
if m in ["cacoh", "cohy"]:
con_scores_dtype = np.complex128
else:
con_scores_dtype = np.float64
Expand Down Expand Up @@ -1035,7 +1035,7 @@ def _parallel_con(
output is a tuple of lists containing arrays for the connectivity scores and
patterns, respectively.
"""
if "coh" in method:
if ("coh" in method) or ("cohy" in method):
# psd
if weights is not None:
psd = weights * w
Expand Down Expand Up @@ -1127,9 +1127,16 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights):
s_xy = np.squeeze(s_xy, axis=0)
s_xy = _smooth_spectra(s_xy, kernel)
out = []
conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh}
conn_func = {
"plv": _plv,
"ciplv": _ciplv,
"pli": _pli,
"wpli": _wpli,
"coh": _coh,
"cohy": _cohy,
}
for m in method:
if m == "coh":
if m in ["coh", "cohy"]:
s_xx = psd[x]
s_yy = psd[y]
out.append(conn_func[m](s_xx, s_yy, s_xy))
Expand Down Expand Up @@ -1372,6 +1379,32 @@ def _coh(s_xx, s_yy, s_xy):
return coh


def _cohy(s_xx, s_yy, s_xy):
"""Compute coherencey given the cross spectral density and PSD.

Parameters
----------
s_xx : array-like, shape (n_freqs, n_times)
The PSD of channel 'x'.
s_yy : array-like, shape (n_freqs, n_times)
The PSD of channel 'y'.
s_xy : array-like, shape (n_freqs, n_times)
The cross PSD between channel 'x' and channel 'y' across
frequency and time points.

Returns
-------
cohy : array-like, shape (n_freqs, n_times)
The estimated COHY.
"""
con_num = s_xy.mean(axis=-1, keepdims=True)
con_den = np.sqrt(
s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True)
)
cohy = con_num / con_den
return cohy


def _compute_csd(x, y, weights):
"""Compute cross spectral density between signals x and y."""
if weights is not None:
Expand Down
Loading