Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f1fdf22
add sleep.py
tgnassou May 28, 2026
1d9974e
WIP: update sleep dataset
tgnassou May 28, 2026
6150cbb
finish sleep
tgnassou May 28, 2026
2ddd308
add bci dataset
tgnassou May 28, 2026
56614cd
Add solver EEGNet
bryan29-ly May 28, 2026
eb15a49
fix the data shape
tgnassou May 28, 2026
5193dec
fix n_classes
tgnassou May 28, 2026
5fa1815
Fix Sleep
bryan29-ly May 28, 2026
7a4177b
add seed
tgnassou May 28, 2026
3553c8e
Add freq and ch_names metadata
bryan29-ly May 29, 2026
b9420fa
add hgd
tgnassou May 29, 2026
28fbabf
Merge branch 'eeg' of https://github.com/tgnassou/benchmark_tsfm into…
tgnassou May 29, 2026
517570a
add pip::
tgnassou May 29, 2026
5ba8df7
Merge branch 'main' into eeg
tgnassou May 29, 2026
f094e5d
Add REVE solver and add metadata to HGD
bryan29-ly May 29, 2026
0a06bac
Merge remote-tracking branch 'refs/remotes/origin/eeg' into eeg
bryan29-ly May 29, 2026
f78a177
Merge remote-tracking branch 'refs/remotes/origin/eeg' into eeg
bryan29-ly May 29, 2026
5f1d18c
Add REVE solver and meta data for HGD dataset
bryan29-ly May 29, 2026
acdd314
debug = True
tgnassou May 29, 2026
67ce8fe
fix conflict
tgnassou May 29, 2026
40b7ace
Add CBraMod and adapt the classification heads
bryan29-ly May 29, 2026
c9b57ce
Merge remote-tracking branch 'refs/remotes/origin/eeg' into eeg
bryan29-ly May 29, 2026
97e5531
Merge branch 'main' into eeg
tgnassou Jun 2, 2026
700b4aa
Merge branch 'eeg' of https://github.com/tgnassou/benchmark_tsfm into…
tgnassou Jun 2, 2026
03102ca
Adapt REVE & CBraMod solvers for linear probing adapter
bryan29-ly Jun 2, 2026
f858634
Fix HGD dependencies & EEGNet model instance creation
bryan29-ly Jun 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 126 additions & 0 deletions datasets/bnci2014_001.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import numpy as np

from braindecode.preprocessing.windowers import create_windows_from_events

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
exponential_moving_standardize,
preprocess,
Preprocessor,
)

from benchopt import BaseDataset


# All datasets must be named `Dataset` and inherit from `BaseDataset`
class Dataset(BaseDataset):

# Name to select the dataset in the CLI and to display the results.
name = "BNCI2014_001"

requirements = [
'pip::braindecode', 'pip::moabb',
]

parameters = {
'train_ratio': [0.8],
'debug': [False],
'seed': [42],
}

def get_data(self):
# The return arguments of this function are passed as keyword arguments
# to `Objective.set_data`. This defines the benchmark's
# API to pass data. It is customizable for each benchmark.

subjects = [1, 2] if self.debug else [1, 2, 3, 4, 5, 6, 7, 8, 9]
dataset = MOABBDataset(
dataset_name="BNCI2014_001", subject_ids=subjects,
)
low_cut_hz = 4.0 # low cut frequency for filtering
high_cut_hz = 40.0 # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000
# Factor to convert from V to uV
factor = 1e6

preprocessors = [
Preprocessor("pick_types", eeg=True, meg=False, stim=False),
Preprocessor(lambda data: np.multiply(data, factor)),
Preprocessor(
"filter", l_freq=low_cut_hz, h_freq=high_cut_hz
),
Preprocessor(
exponential_moving_standardize,
factor_new=factor_new,
init_block_size=init_block_size,
),
]

# Transform the data
preprocess(dataset, preprocessors)

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]

assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])

# Extract channels names
ch_names = dataset.datasets[0].raw.ch_names

# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

window_size_samples = None
window_stride_samples = None

windows_dataset = create_windows_from_events(
dataset,
trial_start_offset_samples=trial_start_offset_samples,
trial_stop_offset_samples=0,
preload=False,
window_size_samples=window_size_samples,
window_stride_samples=window_stride_samples,
)

splitted = windows_dataset.split("subject")
subjects = list(splitted.keys())

X_all = []
y_all = []
for sub in subjects:
n_runs = len(splitted[sub].datasets)
x = []
y = []
for run in range(n_runs):
x += [sample[0].T for sample in splitted[sub].datasets[run]]
y += [sample[1] for sample in splitted[sub].datasets[run]]
X_all.append(np.array(x))
y_all.append(np.array(y))
random_state = np.random.RandomState(seed=self.seed)
ids_train = random_state.choice(
len(X_all), size=int(len(X_all) * self.train_ratio),
replace=False
)
X_train = np.concatenate([X_all[i] for i in ids_train])
y_train = np.concatenate([y_all[i] for i in ids_train])
X_test = np.concatenate(
[X_all[i] for i in range(len(X_all)) if i not in ids_train]
)
y_test = np.concatenate(
[y_all[i] for i in range(len(y_all)) if i not in ids_train]
)
np.unique(y_train), np.unique(y_test) # sanity check
return dict(
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test,
task="classification",
metrics=["accuracy", "balanced_accuracy", "f1_weighted"],
n_classes=len(np.unique(y_train)),
freq=sfreq,
ch_names=ch_names
)
201 changes: 201 additions & 0 deletions datasets/hgd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""High-Gamma Dataset (HGD) — motor imagery EEG classification.

Wraps the High-Gamma Dataset from Schirrmeister et al. (2017):
"Deep learning with convolutional neural networks for EEG decoding
and visualization", Human Brain Mapping.
https://doi.org/10.1002/hbm.23730

128-electrode EEG (44 motor-cortex channels used) recorded from 14 healthy
subjects performing ~1000 four-second trials of executed movements across
13 runs. Data are downloaded automatically via braindecode / MOABB.

Labels (4 classes):
0: left_hand 1: right_hand 2: feet 3: rest

The dataset's own train/test split (runs 1-11 → train, runs 12-13 → test)
is preserved per subject; subjects are then pooled according to
`train_ratio`.

Data contract output
--------------------
X_train : np.ndarray (N, T, C) windows (n_times, n_channels=44)
y_train : np.ndarray (N,) int class labels 0-3
X_test : np.ndarray (M, T, C)
y_test : np.ndarray (M,) int
task : "classification"
metrics : ["accuracy", "balanced_accuracy", "f1_weighted"]
n_classes : 4
"""

import numpy as np
from braindecode.datasets import HGD
from braindecode.preprocessing.preprocess import preprocess, Preprocessor
from braindecode.preprocessing.windowers import create_windows_from_events
from sklearn.preprocessing import scale as standard_scale

from benchopt import BaseDataset


# ---------------------------------------------------------------------------
# Label mapping
# ---------------------------------------------------------------------------
LABEL_MAPPING = {
"left_hand": 0,
"right_hand": 1,
"feet": 2,
"rest": 3,
}


# ---------------------------------------------------------------------------
# Per-subject loader
# ---------------------------------------------------------------------------

def _load_subject(sub_id, preprocessors, window_size_samples):
"""Wrapper that falls back to a description-based split if metadata
item access is unavailable in the installed braindecode version."""
dataset = HGD(subject_ids=[sub_id])
preprocess(dataset, preprocessors)
raw = dataset.datasets[0].raw
sfreq = raw.info["sfreq"]
ch_names = raw.ch_names

windows_dataset = create_windows_from_events(
dataset,
trial_start_offset_samples=0,
trial_stop_offset_samples=0,
window_size_samples=window_size_samples,
window_stride_samples=window_size_samples,
preload=True,
mapping=LABEL_MAPPING,
)
preprocess(
windows_dataset, [Preprocessor(standard_scale, channel_wise=True)]
)

X_, y_ = [], []

# windows_dataset.description contains a 'split' column ('train'/'test')
# Each window is indexed into its base dataset; we replicate that mapping.
descriptions = windows_dataset.datasets # list of WindowsDataset

for ds in descriptions:
for x, label, _ in ds:
window = x.T # (T, C)
X_.append(window)
y_.append(label)

return X_, y_, sfreq, ch_names


# ---------------------------------------------------------------------------
# Benchopt Dataset
# ---------------------------------------------------------------------------

class Dataset(BaseDataset):
"""High-Gamma Dataset (HGD) — 4-class motor EEG classification.

Parameters
----------
resample_hz : float
Target sampling frequency. The raw data are recorded at 500 Hz;
default resamples to 250 Hz (window_size_samples=1000 → 4 s).
high_cut_hz : float
Cutoff for the low-pass filter applied to raw signals (Hz).
factor : float
Multiplicative scaling applied before filtering (e.g. V → µV).
window_size_samples : int
Length of each trial window in samples (after resampling).
train_ratio : float
Fraction of subjects whose trials go into the training pool.
The internal per-subject train/test split (runs 1-11 vs 12-13) is
always respected first; `train_ratio` then selects which subjects
contribute to the final train vs test arrays.
debug : bool
If True, load only the first 2 subjects for fast iteration.
seed : int
Random seed for the subject-level train/test split.
"""

name = "HGD"

requirements = ["pip::moabb", "pip::pandas", "pip::braindecode"]

parameters = {
"seed": [42],
"train_ratio": [0.8],
"resample_hz": [250],
"high_cut_hz": [40.0],
"factor": [1e6], # V → µV
"window_size_samples": [1000], # 4 s at 250 Hz
"debug": [False],
}

def get_data(self):
n_jobs = 1
preprocessors = [
# Convert V → µV
Preprocessor(lambda data: np.multiply(data, self.factor)),
# Resample to target frequency
Preprocessor("resample", sfreq=self.resample_hz),
# Low-pass filter
Preprocessor(
"filter", l_freq=None, h_freq=self.high_cut_hz, n_jobs=n_jobs
),
]

sub_ids = list(range(1, 15)) # 14 subjects
if self.debug:
sub_ids = sub_ids[:2]

# Collect per-subject (train, test) pairs
X_all, y_all = [], []
sfreq_ref, ch_names_ref = None, None

for sub_id in sub_ids:
X_, y_, sfreq, ch_names = _load_subject(
sub_id, preprocessors, self.window_size_samples
)
if sfreq_ref is None:
sfreq_ref, ch_names_ref = sfreq, ch_names
else:
assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}"
X_all.append(X_)
y_all.append(y_)

# ------------------------------------------------------------------
# Subject-level train / test split (same pattern as Sleep dataset)
# ------------------------------------------------------------------
random_state = np.random.RandomState(seed=self.seed)
ids_train = random_state.choice(
len(X_all),
size=int(len(X_all) * self.train_ratio),
replace=False,
)
ids_train_set = set(ids_train.tolist())

X_train = np.concatenate(
[X_all[i] for i in ids_train_set], axis=0
)
y_train = np.concatenate(
[y_all[i] for i in ids_train_set], axis=0
)
X_test = np.concatenate(
[X_all[i] for i in range(len(X_all))
if i not in ids_train_set], axis=0
)
y_test = np.concatenate(
[y_all[i] for i in range(len(y_all))
if i not in ids_train_set], axis=0
)
return dict(
X_train=X_train, # (N, window_size_samples, n_channels)
y_train=y_train, # (N,) int in {0, 1, 2, 3}
X_test=X_test, # (M, window_size_samples, n_channels)
y_test=y_test, # (M,)
task="classification",
metrics=["accuracy", "balanced_accuracy", "f1_weighted"],
n_classes=4,
freq=sfreq_ref,
ch_names=ch_names_ref
)
Loading
Loading