From f1fdf223b9b879fa6c7790e7198cc69dc2099f3e Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 14:29:16 +0200 Subject: [PATCH 01/18] add sleep.py --- datasets/sleep.py | 110 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 datasets/sleep.py diff --git a/datasets/sleep.py b/datasets/sleep.py new file mode 100644 index 0000000..d8c1619 --- /dev/null +++ b/datasets/sleep.py @@ -0,0 +1,110 @@ +"""ECG anomaly detection dataset from TSB-UAD. + +Wraps the ECG recordings from the MIT-BIH / TSB-UAD benchmark. +Each recording is split into a training portion (first 10 %) and a test +portion. Labels are point-level binary anomaly indicators. + +Data contract output +-------------------- +X_train : List[np.ndarray (T_i, C)] training portions (C == 1) +y_train : None unsupervised task +X_test : List[np.ndarray (T_j, C)] test portions +y_test : List[np.ndarray (T_j,)] point-level binary labels +task : "anomaly_detection" +metrics : ["auc_roc", "auc_pr", "f1_pa"] +""" + +import numpy as np +import pandas as pd +from pathlib import Path + +from benchopt import BaseDataset +from benchopt.config import get_data_path +from benchmark_utils.download import fetch_tsb_uad + + +def _load_records(db_path, record_ids, number): + db_path = Path(db_path) + if record_ids in (None, "all", ["all"]): + record_ids = [f.stem for f in db_path.glob("*.out") + if f.stem != "MBA_ECG14046_data"] + if number > 0: + record_ids = record_ids[:number] + + X_list, y_list = [], [] + for rid in record_ids: + path = db_path / f"{rid}.out" + if not path.exists(): + continue + data = pd.read_csv(path, header=None).dropna().to_numpy() + if data.shape[1] < 2: + continue + X_list.append(data[:, 0].astype(np.float32)) + y_list.append(data[:, 1].astype(np.int32)) + return X_list, y_list + + +class Dataset(BaseDataset): + """ECG anomaly detection dataset (TSB-UAD). + + Parameters + ---------- + record_ids : list of str or "all" + Which ECG recordings to include. + debug : bool + If True, truncate each recording to 5000 timesteps for fast iteration. + number : int + Maximum number of recordings to load (-1 = all). + train_ratio : float + Fraction of each recording used as the training (normal) portion. + """ + + name = "ECG" + + requirements = ["pip::pooch", "pandas"] + + parameters = { + "record_ids": [ + ["MBA_ECG14046_data_1", "MBA_ECG14046_data_2"], + ], + "debug": [False], + "number": [-1], + "train_ratio": [0.1], + } + + def get_data(self): + + # Allow reuse of the download helper from benchmark_ad if present, + # otherwise fall back to the data path directly. + try: + path = fetch_tsb_uad("ECG") + except ImportError: + path = get_data_path("ECG") + + record_ids = self.record_ids + X_raw, y_raw = _load_records(path, record_ids, self.number) + + if not X_raw: + raise ValueError("No valid ECG records found.") + + X_train, X_test, y_test = [], [], [] + for x, y in zip(X_raw, y_raw): + if self.debug: + x = x[:5000] + y = y[:5000] + + split = max(1, int(len(x) * self.train_ratio)) + + # Reshape to (T, C=1) + X_train.append(x[:split].reshape(-1, 1)) + X_test.append(x[split:].reshape(-1, 1)) + y_test.append(y[split:]) + + return dict( + X_train=X_train, + y_train=None, + X_test=X_test, + y_test=y_test, + task="anomaly_detection", + metrics=["auc_roc", "auc_pr", "f1_pa"], + ) From 1d9974e3818c05df8cd35cb5bd74ca75c98fb4da Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 15:31:17 +0200 Subject: [PATCH 02/18] WIP: update sleep dataset --- datasets/sleep.py | 187 +++++++++++++++++++++++++++++----------------- 1 file changed, 117 insertions(+), 70 deletions(-) diff --git a/datasets/sleep.py b/datasets/sleep.py index d8c1619..db23f40 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -1,110 +1,157 @@ -"""ECG anomaly detection dataset from TSB-UAD. +"""Sleep classification dataset from Sleep Physionet. -Wraps the ECG recordings from the MIT-BIH / TSB-UAD benchmark. +Wraps the sleep recordings from the Sleep Physionet. Each recording is split into a training portion (first 10 %) and a test -portion. Labels are point-level binary anomaly indicators. +portion. +Labels are from 0 to 4, corresponding to the sleep stages W, N1, N2, N3, REM. Data contract output -------------------- -X_train : List[np.ndarray (T_i, C)] training portions (C == 1) -y_train : None unsupervised task -X_test : List[np.ndarray (T_j, C)] test portions -y_test : List[np.ndarray (T_j,)] point-level binary labels -task : "anomaly_detection" -metrics : ["auc_roc", "auc_pr", "f1_pa"] +X_train : List[np.ndarray (T, C)] one array per training sample +y_train : np.ndarray (N,) int class labels +X_test : List[np.ndarray (T, C)] +y_test : np.ndarray (M,) int +task : "classification" +metrics : ["accuracy", "balanced_accuracy", "f1_weighted"] +n_classes : int """ import numpy as np -import pandas as pd -from pathlib import Path +from braindecode.datasets import SleepPhysionet +from braindecode.preprocessing.preprocess import preprocess, Preprocessor + +from braindecode.preprocessing.windowers import create_windows_from_events +from sklearn.preprocessing import scale as standard_scale from benchopt import BaseDataset -from benchopt.config import get_data_path -from benchmark_utils.download import fetch_tsb_uad - - -def _load_records(db_path, record_ids, number): - db_path = Path(db_path) - if record_ids in (None, "all", ["all"]): - record_ids = [f.stem for f in db_path.glob("*.out") - if f.stem != "MBA_ECG14046_data"] - if number > 0: - record_ids = record_ids[:number] - - X_list, y_list = [], [] - for rid in record_ids: - path = db_path / f"{rid}.out" - if not path.exists(): - continue - data = pd.read_csv(path, header=None).dropna().to_numpy() - if data.shape[1] < 2: - continue - X_list.append(data[:, 0].astype(np.float32)) - y_list.append(data[:, 1].astype(np.int32)) - return X_list, y_list + + +def _load_subject( + sub_id, preprocessors, mapping=None, window_size_samples=3000 +): + + dataset = SleepPhysionet(subject_ids=[sub_id], crop_wake_mins=30) + + preprocess(dataset, preprocessors) + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=0, + trial_stop_offset_samples=0, + window_size_samples=window_size_samples, + window_stride_samples=window_size_samples, + preload=True, + mapping=mapping, + ) + + preprocess( + windows_dataset, [Preprocessor(standard_scale, channel_wise=True)] + ) + all_labels = [] + all_data = [] + for i, x in enumerate(windows_dataset): + label = x[1] + all_labels.append(label) + + data = x[0] + all_data.append(data) + return np.concatenate(all_data), all_labels class Dataset(BaseDataset): - """ECG anomaly detection dataset (TSB-UAD). + """Sleep classification dataset (TSB-UAD). Parameters ---------- - record_ids : list of str or "all" - Which ECG recordings to include. + window_size_samples : int + Length of the windows to split the recordings into. + sub_ids : List[int] + Subject IDs to include (from 1 to 82, excluding 39, 68, 69, 78, 79). + mapping : dict + Mapping from the original sleep stage labels to integers. + We merge stages 3 and 4 following AASM standards. + n_jobs : int + Number of parallel jobs to use for preprocessing. debug : bool - If True, truncate each recording to 5000 timesteps for fast iteration. - number : int - Maximum number of recordings to load (-1 = all). + If True, keep only the first 5000 samples + of each recording for fast testing. + high_cut_hz : float + If not None, apply a low-pass filter with this cutoff frequency (in Hz) + to the raw signals. + factor : float + Factor to multiply the raw signals by (e.g. to convert from V to uV train_ratio : float Fraction of each recording used as the training (normal) portion. """ - name = "ECG" + name = "Sleep" requirements = ["pip::pooch", "pandas"] parameters = { - "record_ids": [ - ["MBA_ECG14046_data_1", "MBA_ECG14046_data_2"], - ], - "debug": [False], - "number": [-1], - "train_ratio": [0.1], + "window_size_samples": [3000], + "sub_ids": range(1, 83), + "mapping": { # We merge stages 3 and 4 following AASM standards. + "Sleep stage W": 0, + "Sleep stage 1": 1, + "Sleep stage 2": 2, + "Sleep stage 3": 3, + "Sleep stage 4": 3, + "Sleep stage R": 4, + }, + "train_ratio": [0.8], + "n_jobs": [1], + "debug": [True], + "high_cut_hz": [30], + "factor": [1e6], } def get_data(self): # Allow reuse of the download helper from benchmark_ad if present, # otherwise fall back to the data path directly. - try: - path = fetch_tsb_uad("ECG") - except ImportError: - path = get_data_path("ECG") - record_ids = self.record_ids - X_raw, y_raw = _load_records(path, record_ids, self.number) - - if not X_raw: - raise ValueError("No valid ECG records found.") - - X_train, X_test, y_test = [], [], [] - for x, y in zip(X_raw, y_raw): + preprocessors = [ + Preprocessor(lambda data: np.multiply(data, self.factor)), + Preprocessor( + "filter", l_freq=None, + h_freq=self.high_cut_hz, n_jobs=self.n_jobs + ), + ] + + X_all, y_all = [], [] + sub_ids = self.sub_ids[:2] if self.debug else self.sub_ids + for sub_id in sub_ids: + if sub_id in [39, 68, 69, 78, 79]: + continue + X_, y_ = _load_subject( + sub_id, preprocessors, self.mapping, self.window_size_samples + ) if self.debug: - x = x[:5000] - y = y[:5000] - - split = max(1, int(len(x) * self.train_ratio)) + X_ = X_[:5000] + y_ = y_[:5000] + X_all.append(X_) + y_all.append(y_) + + ids_train = np.random.Random(seed=42).choice( + len(self.sub_ids), size=int(len(self.sub_ids) * self.train_ratio), + replace=False + ) - # Reshape to (T, C=1) - X_train.append(x[:split].reshape(-1, 1)) - X_test.append(x[split:].reshape(-1, 1)) - y_test.append(y[split:]) + X_train = np.concatenate([X_all[i] for i in ids_train]) + y_train = np.concatenate([y_all[i] for i in ids_train]) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) if i not in ids_train] + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) if i not in ids_train] + ) return dict( X_train=X_train, - y_train=None, + y_train=y_train, X_test=X_test, y_test=y_test, - task="anomaly_detection", - metrics=["auc_roc", "auc_pr", "f1_pa"], + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=5, ) From 6150cbb41749dd5d36604a50b00f05024f172dc9 Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 15:58:51 +0200 Subject: [PATCH 03/18] finish sleep --- datasets/sleep.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/datasets/sleep.py b/datasets/sleep.py index db23f40..1497773 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -29,7 +29,6 @@ def _load_subject( sub_id, preprocessors, mapping=None, window_size_samples=3000 ): - dataset = SleepPhysionet(subject_ids=[sub_id], crop_wake_mins=30) preprocess(dataset, preprocessors) @@ -54,7 +53,7 @@ def _load_subject( data = x[0] all_data.append(data) - return np.concatenate(all_data), all_labels + return all_data, all_labels class Dataset(BaseDataset): @@ -85,7 +84,7 @@ class Dataset(BaseDataset): name = "Sleep" - requirements = ["pip::pooch", "pandas"] + requirements = ["pip::pooch", "pandas", 'braindecode==1.5.1'] parameters = { "window_size_samples": [3000], @@ -119,7 +118,7 @@ def get_data(self): ] X_all, y_all = [], [] - sub_ids = self.sub_ids[:2] if self.debug else self.sub_ids + sub_ids = self.sub_ids[:1] if self.debug else self.sub_ids for sub_id in sub_ids: if sub_id in [39, 68, 69, 78, 79]: continue @@ -132,8 +131,9 @@ def get_data(self): X_all.append(X_) y_all.append(y_) - ids_train = np.random.Random(seed=42).choice( - len(self.sub_ids), size=int(len(self.sub_ids) * self.train_ratio), + random_state = np.random.RandomState(seed=42) + ids_train = random_state.choice( + len(X_all), size=int(len(X_all) * self.train_ratio), replace=False ) @@ -145,7 +145,6 @@ def get_data(self): y_test = np.concatenate( [y_all[i] for i in range(len(y_all)) if i not in ids_train] ) - return dict( X_train=X_train, y_train=y_train, From 2ddd30867d37cf0ba776929ed0148e93b93eb92a Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 16:16:30 +0200 Subject: [PATCH 04/18] add bci dataset --- datasets/bci.py | 119 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 119 insertions(+) create mode 100644 datasets/bci.py diff --git a/datasets/bci.py b/datasets/bci.py new file mode 100644 index 0000000..8e3fa95 --- /dev/null +++ b/datasets/bci.py @@ -0,0 +1,119 @@ +import numpy as np + +from braindecode.preprocessing.windowers import create_windows_from_events + +from braindecode.datasets import MOABBDataset +from braindecode.preprocessing import ( + exponential_moving_standardize, + preprocess, + Preprocessor, +) + +from benchopt import BaseDataset + + +# All datasets must be named `Dataset` and inherit from `BaseDataset` +class Dataset(BaseDataset): + + # Name to select the dataset in the CLI and to display the results. + name = "bci" + + requirements = [ + 'braindecode==1.5.1', 'moabb==1.5.0', + ] + + parameters = { + 'train_ratio': [0.8], + 'debug': [True], + } + + def get_data(self): + # The return arguments of this function are passed as keyword arguments + # to `Objective.set_data`. This defines the benchmark's + # API to pass data. It is customizable for each benchmark. + + subjects = [1, 2] if self.debug else [1, 2, 3, 4, 5, 6, 7, 8, 9] + dataset = MOABBDataset( + dataset_name="BNCI2014_001", subject_ids=subjects, + ) + low_cut_hz = 4.0 # low cut frequency for filtering + high_cut_hz = 40.0 # high cut frequency for filtering + # Parameters for exponential moving standardization + factor_new = 1e-3 + init_block_size = 1000 + # Factor to convert from V to uV + factor = 1e6 + + preprocessors = [ + Preprocessor("pick_types", eeg=True, meg=False, stim=False), + Preprocessor(lambda data: np.multiply(data, factor)), + Preprocessor( + "filter", l_freq=low_cut_hz, h_freq=high_cut_hz + ), + Preprocessor( + exponential_moving_standardize, + factor_new=factor_new, + init_block_size=init_block_size, + ), + ] + + # Transform the data + preprocess(dataset, preprocessors) + + trial_start_offset_seconds = -0.5 + # Extract sampling frequency, check that they are same in all datasets + sfreq = dataset.datasets[0].raw.info["sfreq"] + + assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets]) + # Calculate the trial start offset in samples. + trial_start_offset_samples = int(trial_start_offset_seconds * sfreq) + + window_size_samples = None + window_stride_samples = None + + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=trial_start_offset_samples, + trial_stop_offset_samples=0, + preload=False, + window_size_samples=window_size_samples, + window_stride_samples=window_stride_samples, + ) + + splitted = windows_dataset.split("subject") + subjects = list(splitted.keys()) + + X_all = [] + y_all = [] + for sub in subjects: + n_runs = len(splitted[sub].datasets) + x = [] + y = [] + for run in range(n_runs): + x += [sample[0] for sample in splitted[sub].datasets[run]] + y += [sample[1] for sample in splitted[sub].datasets[run]] + X_all.append(np.array(x)) + y_all.append(np.array(y)) + random_state = np.random.RandomState(seed=42) + ids_train = random_state.choice( + len(X_all), size=int(len(X_all) * self.train_ratio), + replace=False + ) + X_train = np.concatenate([X_all[i] for i in ids_train]) + y_train = np.concatenate([y_all[i] for i in ids_train]) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) if i not in ids_train] + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) if i not in ids_train] + ) + + return dict( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=5, + ) \ No newline at end of file From 56614cd917290cf901f4c59b606b542198b524fc Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Thu, 28 May 2026 16:17:42 +0200 Subject: [PATCH 05/18] Add solver EEGNet --- solvers/eegnet.py | 172 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) create mode 100644 solvers/eegnet.py diff --git a/solvers/eegnet.py b/solvers/eegnet.py new file mode 100644 index 0000000..c2f6865 --- /dev/null +++ b/solvers/eegnet.py @@ -0,0 +1,172 @@ +"""EEGNet solver for time series classification. + +Uses the ``braindecode`` implementation of EEGNetv4 as a classifier on +multivariate time series of shape (N, T, C). + +References: + https://braindecode.org/ + https://arxiv.org/abs/1611.08024 +""" + +import numpy as np +import torch +from benchopt import BaseSolver + +SUPPORTED_TASKS = {"classification"} + + +class Solver(BaseSolver): + """EEGNet time series classification solver. + + The model is built once in ``set_objective`` (not timed). During + ``run`` the network is trained on the training set. + """ + + name = "EEGNet" + + requirements = [ + "pip::braindecode", + "pip::torch", + ] + + parameters = { + "n_epochs": [50], + "batch_size": [32], + "lr": [1e-3], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"EEGNet solver does not support task={task!r}" + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + """Prepare the solver for a given dataset configuration. + + Model construction is done here (not inside ``run``) so that + the build time is excluded from the benchmark timing. + """ + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # Infer input dimensions directly from the training data. + # Within a dataset all series share the same (T, C) shape. + X0 = np.asarray(X_train[0]) + n_times = X0.shape[0] + n_channels = X0.shape[1] if X0.ndim == 2 else 1 + n_classes = int(meta.get("n_classes", len(np.unique(y_train)))) + + # Build the network once per dataset configuration. + should_reload = ( + not hasattr(self, "_network") + or getattr(self, "_n_channels", None) != n_channels + or getattr(self, "_n_classes", None) != n_classes + or getattr(self, "_n_times", None) != n_times + ) + if should_reload: + try: + from braindecode.models import EEGNet + + network = EEGNet( + n_chans=n_channels, + n_outputs=n_classes, + n_times=n_times, + ) + network = network.to(device) + + self._network = network + self._n_channels = n_channels + self._n_classes = n_classes + self._n_times = n_times + print( + f"✓ EEGNet built: C={n_channels}, T={n_times}, " + f"n_classes={n_classes} on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to build EEGNet: {e}. Make sure braindecode " + "is installed." + ) + + self._device = device + self._optimizer = torch.optim.Adam( + self._network.parameters(), lr=self.lr + ) + self._criterion = torch.nn.CrossEntropyLoss() + + def run(self, _): + """Fit the model on the training data.""" + X = self._prepare_inputs(np.asarray(self.X_train, dtype=np.float32)) + y = np.asarray(self.y_train, dtype=np.int64) + + X_t = torch.tensor(X, dtype=torch.float32, device=self._device) + y_t = torch.tensor(y, dtype=torch.long, device=self._device) + + dataset = torch.utils.data.TensorDataset(X_t, y_t) + loader = torch.utils.data.DataLoader( + dataset, batch_size=self.batch_size, shuffle=True + ) + + self._network.train() + for _ in range(self.n_epochs): + for xb, yb in loader: + self._optimizer.zero_grad() + logits = self._network(xb) + loss = self._criterion(logits, yb) + loss.backward() + self._optimizer.step() + + def _prepare_inputs(self, X_batch): + """Reshape inputs to EEGNet's expected layout. + + EEGNet expects arrays of shape (N, C, T). Inputs arrive as + (N, T, C); within a dataset all series share the same length, + so no interpolate needed. + """ + return X_batch.transpose(0, 2, 1) + + def get_result(self): + """Return a predictor exposing ``predict(X_test) -> (N,)``.""" + return { + "model": EEGNetAdapter( + self._network, + self._device, + self._prepare_inputs, + self.batch_size, + ) + } + + +class EEGNetAdapter: + """Wraps a trained EEGNet to expose ``predict(X) -> (N,)`` class indices. + + ``X`` arrives as a list / array of (T, C) series, matching the + classification data contract in ``objective.py``. + """ + + def __init__(self, network, device, prepare_inputs, batch_size): + self._network = network + self._device = device + self._prepare_inputs = prepare_inputs + self._batch_size = batch_size + + def predict(self, X): + X = self._prepare_inputs(np.asarray(X, dtype=np.float32)) + X_t = torch.tensor(X, dtype=torch.float32, device=self._device) + + self._network.eval() + preds = [] + with torch.no_grad(): + for i in range(0, len(X_t), self._batch_size): + logits = self._network(X_t[i:i + self._batch_size]) + preds.append(logits.argmax(dim=1).cpu().numpy()) + return np.concatenate(preds) From eb15a49ca4bfbe9518e4ddf64eb1ade8a03baca3 Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 16:37:08 +0200 Subject: [PATCH 06/18] fix the data shape --- datasets/bci.py | 5 ++--- datasets/sleep.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/datasets/bci.py b/datasets/bci.py index 8e3fa95..dee5fb7 100644 --- a/datasets/bci.py +++ b/datasets/bci.py @@ -90,7 +90,7 @@ def get_data(self): x = [] y = [] for run in range(n_runs): - x += [sample[0] for sample in splitted[sub].datasets[run]] + x += [sample[0].T for sample in splitted[sub].datasets[run]] y += [sample[1] for sample in splitted[sub].datasets[run]] X_all.append(np.array(x)) y_all.append(np.array(y)) @@ -107,7 +107,6 @@ def get_data(self): y_test = np.concatenate( [y_all[i] for i in range(len(y_all)) if i not in ids_train] ) - return dict( X_train=X_train, y_train=y_train, @@ -116,4 +115,4 @@ def get_data(self): task="classification", metrics=["accuracy", "balanced_accuracy", "f1_weighted"], n_classes=5, - ) \ No newline at end of file + ) diff --git a/datasets/sleep.py b/datasets/sleep.py index 1497773..1156d0b 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -51,7 +51,7 @@ def _load_subject( label = x[1] all_labels.append(label) - data = x[0] + data = x[0].T all_data.append(data) return all_data, all_labels @@ -118,7 +118,7 @@ def get_data(self): ] X_all, y_all = [], [] - sub_ids = self.sub_ids[:1] if self.debug else self.sub_ids + sub_ids = self.sub_ids[:2] if self.debug else self.sub_ids for sub_id in sub_ids: if sub_id in [39, 68, 69, 78, 79]: continue From 5193decf4a55f9c71dc3c1c3cdc16f96c9daf5da Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 16:50:43 +0200 Subject: [PATCH 07/18] fix n_classes --- datasets/bci.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datasets/bci.py b/datasets/bci.py index dee5fb7..1a75f73 100644 --- a/datasets/bci.py +++ b/datasets/bci.py @@ -107,6 +107,7 @@ def get_data(self): y_test = np.concatenate( [y_all[i] for i in range(len(y_all)) if i not in ids_train] ) + np.unique(y_train), np.unique(y_test) # sanity check return dict( X_train=X_train, y_train=y_train, @@ -114,5 +115,5 @@ def get_data(self): y_test=y_test, task="classification", metrics=["accuracy", "balanced_accuracy", "f1_weighted"], - n_classes=5, + n_classes=len(np.unique(y_train)), ) From 5fa181597f4ceeaa0c25346698502682cdbeac83 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Thu, 28 May 2026 16:59:33 +0200 Subject: [PATCH 08/18] Fix Sleep --- datasets/sleep.py | 5 ++--- solvers/mantis.py | 7 ++++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/datasets/sleep.py b/datasets/sleep.py index 1156d0b..def8e88 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -88,7 +88,6 @@ class Dataset(BaseDataset): parameters = { "window_size_samples": [3000], - "sub_ids": range(1, 83), "mapping": { # We merge stages 3 and 4 following AASM standards. "Sleep stage W": 0, "Sleep stage 1": 1, @@ -108,7 +107,7 @@ def get_data(self): # Allow reuse of the download helper from benchmark_ad if present, # otherwise fall back to the data path directly. - + sub_ids = range(1, 83) preprocessors = [ Preprocessor(lambda data: np.multiply(data, self.factor)), Preprocessor( @@ -118,7 +117,7 @@ def get_data(self): ] X_all, y_all = [], [] - sub_ids = self.sub_ids[:2] if self.debug else self.sub_ids + sub_ids = sub_ids[:2] if self.debug else self.sub_ids for sub_id in sub_ids: if sub_id in [39, 68, 69, 78, 79]: continue diff --git a/solvers/mantis.py b/solvers/mantis.py index ecbaed9..cd50c9d 100644 --- a/solvers/mantis.py +++ b/solvers/mantis.py @@ -60,7 +60,12 @@ def set_objective(self, task, X_train, y_train, **meta): self.y_train = y_train self.meta = meta - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" # Load the model only on the first call for this checkpoint. should_reload = ( From 7a4177b9e47cdc242a1642d09f12e000870181d6 Mon Sep 17 00:00:00 2001 From: tgnassou Date: Thu, 28 May 2026 17:05:02 +0200 Subject: [PATCH 09/18] add seed --- datasets/bci.py | 3 ++- datasets/sleep.py | 36 +++++++++++++++++++----------------- 2 files changed, 21 insertions(+), 18 deletions(-) diff --git a/datasets/bci.py b/datasets/bci.py index 1a75f73..af8716f 100644 --- a/datasets/bci.py +++ b/datasets/bci.py @@ -25,6 +25,7 @@ class Dataset(BaseDataset): parameters = { 'train_ratio': [0.8], 'debug': [True], + 'seed': [42], } def get_data(self): @@ -94,7 +95,7 @@ def get_data(self): y += [sample[1] for sample in splitted[sub].datasets[run]] X_all.append(np.array(x)) y_all.append(np.array(y)) - random_state = np.random.RandomState(seed=42) + random_state = np.random.RandomState(seed=self.seed) ids_train = random_state.choice( len(X_all), size=int(len(X_all) * self.train_ratio), replace=False diff --git a/datasets/sleep.py b/datasets/sleep.py index def8e88..4ce3087 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -87,20 +87,9 @@ class Dataset(BaseDataset): requirements = ["pip::pooch", "pandas", 'braindecode==1.5.1'] parameters = { - "window_size_samples": [3000], - "mapping": { # We merge stages 3 and 4 following AASM standards. - "Sleep stage W": 0, - "Sleep stage 1": 1, - "Sleep stage 2": 2, - "Sleep stage 3": 3, - "Sleep stage 4": 3, - "Sleep stage R": 4, - }, + "seed": [42], "train_ratio": [0.8], - "n_jobs": [1], "debug": [True], - "high_cut_hz": [30], - "factor": [1e6], } def get_data(self): @@ -108,21 +97,34 @@ def get_data(self): # Allow reuse of the download helper from benchmark_ad if present, # otherwise fall back to the data path directly. sub_ids = range(1, 83) + window_size_samples = 3000 + mapping = { # We merge stages 3 and 4 following AASM standards. + "Sleep stage W": 0, + "Sleep stage 1": 1, + "Sleep stage 2": 2, + "Sleep stage 3": 3, + "Sleep stage 4": 3, + "Sleep stage R": 4, + } + n_jobs = 1 + high_cut_hz = 40.0 + factor = 1e6 # Factor to convert from V to uV + preprocessors = [ - Preprocessor(lambda data: np.multiply(data, self.factor)), + Preprocessor(lambda data: np.multiply(data, factor)), Preprocessor( "filter", l_freq=None, - h_freq=self.high_cut_hz, n_jobs=self.n_jobs + h_freq=high_cut_hz, n_jobs=n_jobs ), ] X_all, y_all = [], [] - sub_ids = sub_ids[:2] if self.debug else self.sub_ids + sub_ids = sub_ids[:2] if self.debug else sub_ids for sub_id in sub_ids: if sub_id in [39, 68, 69, 78, 79]: continue X_, y_ = _load_subject( - sub_id, preprocessors, self.mapping, self.window_size_samples + sub_id, preprocessors, mapping, window_size_samples ) if self.debug: X_ = X_[:5000] @@ -130,7 +132,7 @@ def get_data(self): X_all.append(X_) y_all.append(y_) - random_state = np.random.RandomState(seed=42) + random_state = np.random.RandomState(seed=self.seed) ids_train = random_state.choice( len(X_all), size=int(len(X_all) * self.train_ratio), replace=False From 3553c8e8e999c14ecce85baa8b162148eb3c4068 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Fri, 29 May 2026 11:29:58 +0200 Subject: [PATCH 10/18] Add freq and ch_names metadata --- datasets/bci.py | 6 ++++++ datasets/sleep.py | 17 +++++++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/datasets/bci.py b/datasets/bci.py index af8716f..9e17259 100644 --- a/datasets/bci.py +++ b/datasets/bci.py @@ -66,6 +66,10 @@ def get_data(self): sfreq = dataset.datasets[0].raw.info["sfreq"] assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets]) + + # Extract channels names + ch_names = dataset.datasets[0].raw.ch_names + # Calculate the trial start offset in samples. trial_start_offset_samples = int(trial_start_offset_seconds * sfreq) @@ -117,4 +121,6 @@ def get_data(self): task="classification", metrics=["accuracy", "balanced_accuracy", "f1_weighted"], n_classes=len(np.unique(y_train)), + freq=sfreq, + ch_names=ch_names ) diff --git a/datasets/sleep.py b/datasets/sleep.py index 4ce3087..99ec5b2 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -32,6 +32,12 @@ def _load_subject( dataset = SleepPhysionet(subject_ids=[sub_id], crop_wake_mins=30) preprocess(dataset, preprocessors) + + # Extract the frequency and channels names + raw = dataset.datasets[0].raw + sfreq = raw.info["sfreq"] + ch_names = raw.ch_names + windows_dataset = create_windows_from_events( dataset, trial_start_offset_samples=0, @@ -53,7 +59,7 @@ def _load_subject( data = x[0].T all_data.append(data) - return all_data, all_labels + return all_data, all_labels, sfreq, ch_names class Dataset(BaseDataset): @@ -120,12 +126,17 @@ def get_data(self): X_all, y_all = [], [] sub_ids = sub_ids[:2] if self.debug else sub_ids + sfreq_ref, ch_names_ref = None, None for sub_id in sub_ids: if sub_id in [39, 68, 69, 78, 79]: continue - X_, y_ = _load_subject( + X_, y_, sfreq, ch_names = _load_subject( sub_id, preprocessors, mapping, window_size_samples ) + if sfreq_ref is None: + sfreq_ref, ch_names_ref = sfreq, ch_names + else: + assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}" if self.debug: X_ = X_[:5000] y_ = y_[:5000] @@ -154,4 +165,6 @@ def get_data(self): task="classification", metrics=["accuracy", "balanced_accuracy", "f1_weighted"], n_classes=5, + freq=sfreq, + ch_names=ch_names ) From b9420fa634c1864cd9852f4683fedcf1b0414cc6 Mon Sep 17 00:00:00 2001 From: tgnassou Date: Fri, 29 May 2026 14:13:39 +0200 Subject: [PATCH 11/18] add hgd --- datasets/{bci.py => bnci2014_001.py} | 4 +- datasets/hgd.py | 192 +++++++++++++++++++++++++++ datasets/sleep.py | 12 +- 3 files changed, 203 insertions(+), 5 deletions(-) rename datasets/{bci.py => bnci2014_001.py} (98%) create mode 100644 datasets/hgd.py diff --git a/datasets/bci.py b/datasets/bnci2014_001.py similarity index 98% rename from datasets/bci.py rename to datasets/bnci2014_001.py index af8716f..8db5308 100644 --- a/datasets/bci.py +++ b/datasets/bnci2014_001.py @@ -16,7 +16,7 @@ class Dataset(BaseDataset): # Name to select the dataset in the CLI and to display the results. - name = "bci" + name = "BNCI2014_001" requirements = [ 'braindecode==1.5.1', 'moabb==1.5.0', @@ -24,7 +24,7 @@ class Dataset(BaseDataset): parameters = { 'train_ratio': [0.8], - 'debug': [True], + 'debug': [False], 'seed': [42], } diff --git a/datasets/hgd.py b/datasets/hgd.py new file mode 100644 index 0000000..6d70bbc --- /dev/null +++ b/datasets/hgd.py @@ -0,0 +1,192 @@ +"""High-Gamma Dataset (HGD) — motor imagery EEG classification. + +Wraps the High-Gamma Dataset from Schirrmeister et al. (2017): + "Deep learning with convolutional neural networks for EEG decoding + and visualization", Human Brain Mapping. + https://doi.org/10.1002/hbm.23730 + +128-electrode EEG (44 motor-cortex channels used) recorded from 14 healthy +subjects performing ~1000 four-second trials of executed movements across +13 runs. Data are downloaded automatically via braindecode / MOABB. + +Labels (4 classes): + 0: left_hand 1: right_hand 2: feet 3: rest + +The dataset's own train/test split (runs 1-11 → train, runs 12-13 → test) +is preserved per subject; subjects are then pooled according to +`train_ratio`. + +Data contract output +-------------------- +X_train : np.ndarray (N, T, C) windows (n_times, n_channels=44) +y_train : np.ndarray (N,) int class labels 0-3 +X_test : np.ndarray (M, T, C) +y_test : np.ndarray (M,) int +task : "classification" +metrics : ["accuracy", "balanced_accuracy", "f1_weighted"] +n_classes : 4 +""" + +import numpy as np +from braindecode.datasets import HGD +from braindecode.preprocessing.preprocess import preprocess, Preprocessor +from braindecode.preprocessing.windowers import create_windows_from_events +from sklearn.preprocessing import scale as standard_scale + +from benchopt import BaseDataset + + +# --------------------------------------------------------------------------- +# Label mapping +# --------------------------------------------------------------------------- +LABEL_MAPPING = { + "left_hand": 0, + "right_hand": 1, + "feet": 2, + "rest": 3, +} + + +# --------------------------------------------------------------------------- +# Per-subject loader +# --------------------------------------------------------------------------- + +def _load_subject(sub_id, preprocessors, window_size_samples): + """Wrapper that falls back to a description-based split if metadata + item access is unavailable in the installed braindecode version.""" + dataset = HGD(subject_ids=[sub_id]) + preprocess(dataset, preprocessors) + + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=0, + trial_stop_offset_samples=0, + window_size_samples=window_size_samples, + window_stride_samples=window_size_samples, + preload=True, + mapping=LABEL_MAPPING, + ) + preprocess( + windows_dataset, [Preprocessor(standard_scale, channel_wise=True)] + ) + + X_, y_ = [], [] + + # windows_dataset.description contains a 'split' column ('train'/'test') + # Each window is indexed into its base dataset; we replicate that mapping. + descriptions = windows_dataset.datasets # list of WindowsDataset + + for ds in descriptions: + for x, label, _ in ds: + window = x.T # (T, C) + X_.append(window) + y_.append(label) + + return X_, y_ + + +# --------------------------------------------------------------------------- +# Benchopt Dataset +# --------------------------------------------------------------------------- + +class Dataset(BaseDataset): + """High-Gamma Dataset (HGD) — 4-class motor EEG classification. + + Parameters + ---------- + resample_hz : float + Target sampling frequency. The raw data are recorded at 500 Hz; + default resamples to 250 Hz (window_size_samples=1000 → 4 s). + high_cut_hz : float + Cutoff for the low-pass filter applied to raw signals (Hz). + factor : float + Multiplicative scaling applied before filtering (e.g. V → µV). + window_size_samples : int + Length of each trial window in samples (after resampling). + train_ratio : float + Fraction of subjects whose trials go into the training pool. + The internal per-subject train/test split (runs 1-11 vs 12-13) is + always respected first; `train_ratio` then selects which subjects + contribute to the final train vs test arrays. + debug : bool + If True, load only the first 2 subjects for fast iteration. + seed : int + Random seed for the subject-level train/test split. + """ + + name = "HGD" + + requirements = ["pip::moabb", "pandas", "braindecode==1.5.1"] + + parameters = { + "seed": [42], + "train_ratio": [0.8], + "resample_hz": [250], + "high_cut_hz": [40.0], + "factor": [1e6], # V → µV + "window_size_samples": [1000], # 4 s at 250 Hz + "debug": [True], + "sub_id": [list(range(1, 15))], # 14 subjects + } + + def get_data(self): + n_jobs = 1 + preprocessors = [ + # Convert V → µV + Preprocessor(lambda data: np.multiply(data, self.factor)), + # Resample to target frequency + Preprocessor("resample", sfreq=self.resample_hz), + # Low-pass filter + Preprocessor( + "filter", l_freq=None, h_freq=self.high_cut_hz, n_jobs=n_jobs + ), + ] + + sub_ids = list(range(1, 15)) # 14 subjects + if self.debug: + sub_ids = sub_ids[:2] + + # Collect per-subject (train, test) pairs + X_all, y_all = [], [] + + for sub_id in sub_ids: + X_, y_ = _load_subject( + sub_id, preprocessors, self.window_size_samples + ) + X_all.append(X_) + y_all.append(y_) + + # ------------------------------------------------------------------ + # Subject-level train / test split (same pattern as Sleep dataset) + # ------------------------------------------------------------------ + random_state = np.random.RandomState(seed=self.seed) + ids_train = random_state.choice( + len(X_all), + size=int(len(X_all) * self.train_ratio), + replace=False, + ) + ids_train_set = set(ids_train.tolist()) + + X_train = np.concatenate( + [X_all[i] for i in ids_train_set], axis=0 + ) + y_train = np.concatenate( + [y_all[i] for i in ids_train_set], axis=0 + ) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) + if i not in ids_train_set], axis=0 + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) + if i not in ids_train_set], axis=0 + ) + return dict( + X_train=X_train, # (N, window_size_samples, n_channels) + y_train=y_train, # (N,) int in {0, 1, 2, 3} + X_test=X_test, # (M, window_size_samples, n_channels) + y_test=y_test, # (M,) + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=4, + ) diff --git a/datasets/sleep.py b/datasets/sleep.py index 4ce3087..3a7f02e 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -89,11 +89,10 @@ class Dataset(BaseDataset): parameters = { "seed": [42], "train_ratio": [0.8], - "debug": [True], + "debug": [False], } - def get_data(self): - + def prepare(self): # Allow reuse of the download helper from benchmark_ad if present, # otherwise fall back to the data path directly. sub_ids = range(1, 83) @@ -146,6 +145,13 @@ def get_data(self): y_test = np.concatenate( [y_all[i] for i in range(len(y_all)) if i not in ids_train] ) + + return X_train, y_train, X_test, y_test + + def get_data(self): + + X_train, y_train, X_test, y_test = self.prepare() + return dict( X_train=X_train, y_train=y_train, From 517570a5852390d83828373bed527f682c648a3c Mon Sep 17 00:00:00 2001 From: tgnassou Date: Fri, 29 May 2026 14:41:27 +0200 Subject: [PATCH 12/18] add pip:: --- datasets/bnci2014_001.py | 2 +- datasets/hgd.py | 3 +-- datasets/{sleep.py => sleepphysionet.py} | 4 ++-- 3 files changed, 4 insertions(+), 5 deletions(-) rename datasets/{sleep.py => sleepphysionet.py} (98%) diff --git a/datasets/bnci2014_001.py b/datasets/bnci2014_001.py index 5ec002b..be61666 100644 --- a/datasets/bnci2014_001.py +++ b/datasets/bnci2014_001.py @@ -19,7 +19,7 @@ class Dataset(BaseDataset): name = "BNCI2014_001" requirements = [ - 'braindecode==1.5.1', 'moabb==1.5.0', + 'pip::braindecode', 'pip::moabb', ] parameters = { diff --git a/datasets/hgd.py b/datasets/hgd.py index 6d70bbc..6d249fd 100644 --- a/datasets/hgd.py +++ b/datasets/hgd.py @@ -116,7 +116,7 @@ class Dataset(BaseDataset): name = "HGD" - requirements = ["pip::moabb", "pandas", "braindecode==1.5.1"] + requirements = ["pip::moabb", "pip::pandas", "pip::braindecode"] parameters = { "seed": [42], @@ -126,7 +126,6 @@ class Dataset(BaseDataset): "factor": [1e6], # V → µV "window_size_samples": [1000], # 4 s at 250 Hz "debug": [True], - "sub_id": [list(range(1, 15))], # 14 subjects } def get_data(self): diff --git a/datasets/sleep.py b/datasets/sleepphysionet.py similarity index 98% rename from datasets/sleep.py rename to datasets/sleepphysionet.py index 178f1b6..71f310a 100644 --- a/datasets/sleep.py +++ b/datasets/sleepphysionet.py @@ -88,9 +88,9 @@ class Dataset(BaseDataset): Fraction of each recording used as the training (normal) portion. """ - name = "Sleep" + name = "SleepPhysionet" - requirements = ["pip::pooch", "pandas", 'braindecode==1.5.1'] + requirements = ["pip::pooch", "pip::pandas", "pip::braindecode"] parameters = { "seed": [42], From f094e5d50f7d1bcfa7032164230c072edac2561e Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Fri, 29 May 2026 15:08:31 +0200 Subject: [PATCH 13/18] Add REVE solver and add metadata to HGD --- datasets/hgd.py | 16 +++- datasets/sleep.py | 4 +- solvers/reve.py | 228 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 243 insertions(+), 5 deletions(-) create mode 100644 solvers/reve.py diff --git a/datasets/hgd.py b/datasets/hgd.py index 6d70bbc..fa6a00f 100644 --- a/datasets/hgd.py +++ b/datasets/hgd.py @@ -56,6 +56,9 @@ def _load_subject(sub_id, preprocessors, window_size_samples): item access is unavailable in the installed braindecode version.""" dataset = HGD(subject_ids=[sub_id]) preprocess(dataset, preprocessors) + raw = dataset.datasets[0].raw + sfreq = raw.info["sfreq"] + ch_names = raw.ch_names windows_dataset = create_windows_from_events( dataset, @@ -82,7 +85,7 @@ def _load_subject(sub_id, preprocessors, window_size_samples): X_.append(window) y_.append(label) - return X_, y_ + return X_, y_, sfreq, ch_names # --------------------------------------------------------------------------- @@ -116,7 +119,7 @@ class Dataset(BaseDataset): name = "HGD" - requirements = ["pip::moabb", "pandas", "braindecode==1.5.1"] + requirements = ["pip::moabb", "pandas", "pip::braindecode==1.5.1"] parameters = { "seed": [42], @@ -148,11 +151,16 @@ def get_data(self): # Collect per-subject (train, test) pairs X_all, y_all = [], [] + sfreq_ref, ch_names_ref = None, None for sub_id in sub_ids: - X_, y_ = _load_subject( + X_, y_, sfreq, ch_names = _load_subject( sub_id, preprocessors, self.window_size_samples ) + if sfreq_ref is None: + sfreq_ref, ch_names_ref = sfreq_ref, ch_names + else: + assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}" X_all.append(X_) y_all.append(y_) @@ -189,4 +197,6 @@ def get_data(self): task="classification", metrics=["accuracy", "balanced_accuracy", "f1_weighted"], n_classes=4, + freq=sfreq_ref, + ch_names=ch_names_ref ) diff --git a/datasets/sleep.py b/datasets/sleep.py index 178f1b6..284f984 100644 --- a/datasets/sleep.py +++ b/datasets/sleep.py @@ -157,11 +157,11 @@ def prepare(self): [y_all[i] for i in range(len(y_all)) if i not in ids_train] ) - return X_train, y_train, X_test, y_test + return X_train, y_train, X_test, y_test, sfreq_ref, ch_names_ref def get_data(self): - X_train, y_train, X_test, y_test = self.prepare() + X_train, y_train, X_test, y_test, sfreq, ch_names = self.prepare() return dict( X_train=X_train, diff --git a/solvers/reve.py b/solvers/reve.py new file mode 100644 index 0000000..a7d86e5 --- /dev/null +++ b/solvers/reve.py @@ -0,0 +1,228 @@ +"""REVE solver for EEG time series classification. + +Uses the REVE EEG foundation model (HuggingFace +``brain-bzh/reve-*``) to extract embeddings, then trains a Random +Forest classifier on top — no backbone fine-tuning, the foundation +model is kept frozen. + +REVE expects EEG sampled at **200 Hz**; inputs sampled at any other +rate are resampled with a polyphase (anti-aliased) filter before the +forward pass. + +References: + https://huggingface.co/brain-bzh/reve-large + https://huggingface.co/brain-bzh/reve-positions +""" + +import numpy as np +import torch +from benchopt import BaseSolver +from sklearn.pipeline import make_pipeline +from sklearn.ensemble import RandomForestClassifier +from sklearn.preprocessing import FunctionTransformer + +SUPPORTED_TASKS = {"classification"} +REVE_SFREQ = 200 # REVE is trained at 200 Hz + + +class Solver(BaseSolver): + """REVE foundation model + Random Forest classifier.""" + + name = "REVE-RandomForest" + + requirements = [ + "pip::braindecode", + "pip::transformers", + "pip::torch", + "pip::scipy", + ] + + parameters = { + "checkpoint": ["brain-bzh/reve-large"], + "pos_bank_checkpoint": ["brain-bzh/reve-positions"], + "batch_size": [16], + "n_estimators": [100], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"REVE solver does not support task={task!r}" + + # REVE expects monopolar 10-20 channel names ("Fpz", "C3", …). + # Datasets that ship bipolar derivations ("EEG Fpz-Cz") or + # prefixed names cannot be aligned with REVE's position bank — + # skip them rather than feed bogus electrode positions. + ch_names = kwargs.get("ch_names") + if ch_names is None: + return True, ( + "REVE requires `ch_names` in dataset meta to build " + "channel positions; dataset does not provide it." + ) + bad = [ + n for n in ch_names + if "-" in n or any( + n.startswith(prefix) for prefix in ("EEG ", "EOG ", "EMG ") + ) + ] + if bad: + return True, ( + f"REVE expects monopolar 10-20 channel names; dataset " + f"provides bipolar/prefixed names (e.g. {bad[0]!r}). " + f"Skipping to avoid silently producing biased embeddings." + ) + + # REVE checkpoints are gated on HuggingFace Hub. Skip cleanly + # (rather than crashing mid-benchmark) when the user has no + # access — either because they haven't requested it or because + # their machine isn't authenticated (`huggingface-cli login`). + try: + from huggingface_hub import HfApi + HfApi().model_info(self.checkpoint) + except Exception as e: + return True, ( + f"REVE checkpoint '{self.checkpoint}' is gated or " + f"unreachable. Request access at " + f"https://huggingface.co/{self.checkpoint} and run " + f"`huggingface-cli login`. ({type(e).__name__}: {e})" + ) + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + """Prepare the solver for a given dataset configuration. + + The foundation model and position bank are loaded here (not in + ``run``) so the download/load time is excluded from timing. + """ + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # Source sampling rate and electrode names must come from the + # dataset's meta — REVE cannot infer them. + self._src_sfreq = float(meta.get("freq", REVE_SFREQ)) + self._ch_names = meta.get("ch_names", None) + if self._ch_names is None: + raise ValueError( + "REVE requires `ch_names` in dataset meta to build " + "channel positions." + ) + + should_reload = ( + not hasattr(self, "_network") + or getattr(self, "_loaded_checkpoint", None) != self.checkpoint + ) + if should_reload: + try: + from transformers import AutoModel + + # REVE ships custom modeling code on the Hub + # (`modeling_reve.py`), so `trust_remote_code=True` is + # mandatory to avoid an interactive prompt that would + # break batch / CI runs of the benchmark. + network = AutoModel.from_pretrained( + self.checkpoint, trust_remote_code=True + ) + pos_bank = AutoModel.from_pretrained( + self.pos_bank_checkpoint, trust_remote_code=True + ) + + self._network = network.to(device).eval() + self._pos_bank = pos_bank.to(device).eval() + self._loaded_checkpoint = self.checkpoint + print( + f"✓ REVE checkpoint loaded: {self.checkpoint} " + f"on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to load REVE checkpoint '{self.checkpoint}': {e}" + ) + + # Pre-compute channel positions (C, 3) once per dataset. + with torch.no_grad(): + self._positions = self._pos_bank(self._ch_names) + + self._device = device + + self.model = make_pipeline( + FunctionTransformer(self._extract_embeddings), + RandomForestClassifier( + n_estimators=self.n_estimators, + n_jobs=-1, + random_state=42, + ), + ) + + def run(self, _): + """Fit the Random Forest on REVE embeddings.""" + self.model.fit(self.X_train, self.y_train) + + def _extract_embeddings(self, X): + """Forward batches through the frozen REVE backbone. + + Returns + ------- + np.ndarray of shape (N, embedding_dim) + """ + batch_size = self.batch_size + n_samples = len(X) + all_embeddings = [] + + for batch_idx in range(0, n_samples, batch_size): + batch_end = min(batch_idx + batch_size, n_samples) + X_batch = np.asarray(X[batch_idx:batch_end], dtype=np.float32) + X_batch_processed = self._prepare_inputs(X_batch) + + x_t = torch.tensor( + X_batch_processed, dtype=torch.float32, device=self._device + ) + # Broadcast positions (C, 3) → (B, C, 3) for this batch. + pos = self._positions.unsqueeze(0).expand(x_t.size(0), -1, -1) + + # The HF custom `Reve` class loaded via AutoModel is tagged + # "Feature Extraction" on the Hub: its forward already + # returns embeddings directly, no flag needed. + with torch.no_grad(): + emb = self._network(x_t, pos) + # If REVE returns a sequence/spatial map (n_chans × n_patches + # × D), flatten the trailing axes into one vector per sample. + if emb.ndim > 2: + emb = emb.flatten(start_dim=1) + + all_embeddings.append(emb.cpu().numpy()) + + return np.vstack(all_embeddings) + + def _prepare_inputs(self, X_batch): + """Reshape (N, T, C) → (N, C, T) and resample to 200 Hz. + + REVE is frequency-aware. A naive ``F.interpolate`` would alias when + downsampling and add spectral artefacts when upsampling, so we + use a polyphase resampler (``scipy.signal.resample_poly``) that + applies a anti-aliasing filter. + """ + X_in = X_batch.transpose(0, 2, 1) # (N, C, T) + + if int(self._src_sfreq) != REVE_SFREQ: + from math import gcd + from scipy.signal import resample_poly + + src = int(self._src_sfreq) + g = gcd(src, REVE_SFREQ) + up = REVE_SFREQ // g + down = src // g + X_in = resample_poly(X_in, up=up, down=down, axis=-1) + + return np.ascontiguousarray(X_in, dtype=np.float32) + + def get_result(self): + """Return the fitted pipeline.""" + return {"model": self.model} From 5f1d18c6c02c166a7a42b930abd2e3702b031f16 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Fri, 29 May 2026 15:22:47 +0200 Subject: [PATCH 14/18] Add REVE solver and meta data for HGD dataset --- datasets/hgd.py | 2 +- solvers/reve.py | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/datasets/hgd.py b/datasets/hgd.py index 36443a3..715c7d8 100644 --- a/datasets/hgd.py +++ b/datasets/hgd.py @@ -158,7 +158,7 @@ def get_data(self): sub_id, preprocessors, self.window_size_samples ) if sfreq_ref is None: - sfreq_ref, ch_names_ref = sfreq_ref, ch_names + sfreq_ref, ch_names_ref = sfreq, ch_names else: assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}" X_all.append(X_) diff --git a/solvers/reve.py b/solvers/reve.py index a7d86e5..6afddf9 100644 --- a/solvers/reve.py +++ b/solvers/reve.py @@ -106,8 +106,11 @@ def set_objective(self, task, X_train, y_train, **meta): device = "cpu" # Source sampling rate and electrode names must come from the - # dataset's meta — REVE cannot infer them. - self._src_sfreq = float(meta.get("freq", REVE_SFREQ)) + # dataset's meta — REVE cannot infer them. ``dict.get(k, default)`` + # only falls back when the key is absent, so we coalesce ``None`` + # values too (datasets may set ``freq=None`` to signal "unknown"). + freq = meta.get("freq") or REVE_SFREQ + self._src_sfreq = float(freq) self._ch_names = meta.get("ch_names", None) if self._ch_names is None: raise ValueError( @@ -193,7 +196,7 @@ def _extract_embeddings(self, X): with torch.no_grad(): emb = self._network(x_t, pos) # If REVE returns a sequence/spatial map (n_chans × n_patches - # × D), flatten the trailing axes into one vector per sample. + # × D), flatten the trailing axes into one vector per sample if emb.ndim > 2: emb = emb.flatten(start_dim=1) From acdd3140f7f78932cd14da77712fbc7024e63530 Mon Sep 17 00:00:00 2001 From: tgnassou Date: Fri, 29 May 2026 17:02:18 +0200 Subject: [PATCH 15/18] debug = True --- datasets/hgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datasets/hgd.py b/datasets/hgd.py index 715c7d8..b98f053 100644 --- a/datasets/hgd.py +++ b/datasets/hgd.py @@ -129,7 +129,7 @@ class Dataset(BaseDataset): "high_cut_hz": [40.0], "factor": [1e6], # V → µV "window_size_samples": [1000], # 4 s at 250 Hz - "debug": [True], + "debug": [False], } def get_data(self): From 40b7acebb188854507bfd38873f4ed08c7f4f048 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Fri, 29 May 2026 17:56:16 +0200 Subject: [PATCH 16/18] Add CBraMod and adapt the classification heads --- solvers/cbramod.py | 161 +++++++++++++++++++++++++++++++++++++++++++++ solvers/reve.py | 138 ++++++++++++++------------------------ 2 files changed, 209 insertions(+), 90 deletions(-) create mode 100644 solvers/cbramod.py diff --git a/solvers/cbramod.py b/solvers/cbramod.py new file mode 100644 index 0000000..99389c0 --- /dev/null +++ b/solvers/cbramod.py @@ -0,0 +1,161 @@ +"""CBraMod solver — frozen foundation backbone + linear probe. + +CBraMod (Wang et al., ICLR 2025) is pretrained on TUEG with 1-second +patches at 200 Hz. Inputs at any other rate are resampled via a +polyphase (anti-aliased) filter. Channel-agnostic montage thanks to +its Asymmetric Conditional Positional Encoding. + +References: + https://arxiv.org/abs/2412.07236 + https://braindecode.org/stable/generated/braindecode.models.CBraMod.html +""" + +from math import gcd + +import numpy as np +import torch +from benchopt import BaseSolver +from scipy.signal import resample_poly + +from benchmark_utils.adapters.linear_probe import LinearProbeAdapter + +SUPPORTED_TASKS = {"classification"} +CBRAMOD_SFREQ = 200 +CBRAMOD_PATCH_SIZE = 200 # 1-second patches at 200 Hz; T must be a multiple + + +class Solver(BaseSolver): + """CBraMod foundation model + linear probe.""" + + name = "CBraMod" + + requirements = [ + "pip::braindecode", + "pip::torch", + "pip::scipy", + ] + + parameters = { + "checkpoint": ["braindecode/cbramod-pretrained"], + "batch_size": [16], + "n_estimators": [100], + "max_iter": [1000], + "classifier": ["logistic_regression"], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"CBraMod solver does not support task={task!r}" + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + freq = meta.get("freq") or CBRAMOD_SFREQ + self._src_sfreq = float(freq) + + should_reload = ( + not hasattr(self, "_network") + or getattr(self, "_loaded_checkpoint", None) != self.checkpoint + ) + if should_reload: + try: + from braindecode.models import CBraMod + + # return_encoder_output=True bypasses the randomly-init + # classification head and exposes encoder features. + network = CBraMod.from_pretrained( + self.checkpoint, + return_encoder_output=True, + ) + self._network = network.to(device).eval() + self._loaded_checkpoint = self.checkpoint + print( + f"✓ CBraMod checkpoint loaded: {self.checkpoint} " + f"on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to load CBraMod checkpoint " + f"'{self.checkpoint}': {e}" + ) + + self._device = device + + def run(self, _): + self._adapter = LinearProbeAdapter( + encoder=self, + task=self.task, + classifier=self.classifier, + max_iter=self.max_iter, + n_estimators=self.n_estimators, + ) + self._adapter.fit(self.X_train, self.y_train) + + def encode(self, X): + """LinearProbeAdapter entry point — returns (N, embed_dim).""" + return self._extract_embeddings(X) + + def _extract_embeddings(self, X): + batch_size = self.batch_size + n_samples = len(X) + all_embeddings = [] + + for batch_idx in range(0, n_samples, batch_size): + batch_end = min(batch_idx + batch_size, n_samples) + X_batch = np.asarray(X[batch_idx:batch_end], dtype=np.float32) + X_batch_processed = self._prepare_inputs(X_batch) + + x_t = torch.tensor( + X_batch_processed, dtype=torch.float32, device=self._device + ) + + with torch.no_grad(): + emb = self._network(x_t) + + # Encoder returns (B, n_chans, n_patches, D) — flatten to + # one vector per sample. + if emb.ndim > 2: + emb = emb.flatten(start_dim=1) + + all_embeddings.append(emb.cpu().numpy()) + + return np.vstack(all_embeddings) + + def _prepare_inputs(self, X_batch): + """Reshape (N, T, C) → (N, C, T), resample to 200 Hz, truncate. + + Polyphase resampling is anti-aliased so it works for both up- + and downsampling. CBraMod patches at 200 samples per token, so + ``T`` must be a multiple of ``CBRAMOD_PATCH_SIZE`` — we truncate + any trailing partial patch rather than zero-pad (cleaner than + injecting fake samples). + """ + X_in = X_batch.transpose(0, 2, 1) + + if int(self._src_sfreq) != CBRAMOD_SFREQ: + src = int(self._src_sfreq) + g = gcd(src, CBRAMOD_SFREQ) + X_in = resample_poly( + X_in, up=CBRAMOD_SFREQ // g, down=src // g, axis=-1 + ) + + T = X_in.shape[-1] + remainder = T % CBRAMOD_PATCH_SIZE + if remainder: + X_in = X_in[..., : T - remainder] + + return np.ascontiguousarray(X_in, dtype=np.float32) + + def get_result(self): + return {"model": self._adapter} diff --git a/solvers/reve.py b/solvers/reve.py index 6afddf9..28b18e5 100644 --- a/solvers/reve.py +++ b/solvers/reve.py @@ -1,34 +1,32 @@ -"""REVE solver for EEG time series classification. +"""REVE solver — frozen foundation backbone + linear probe. -Uses the REVE EEG foundation model (HuggingFace -``brain-bzh/reve-*``) to extract embeddings, then trains a Random -Forest classifier on top — no backbone fine-tuning, the foundation -model is kept frozen. - -REVE expects EEG sampled at **200 Hz**; inputs sampled at any other -rate are resampled with a polyphase (anti-aliased) filter before the -forward pass. +REVE expects EEG sampled at 200 Hz with monopolar 10-20 channel +names. Inputs at other rates are resampled via polyphase filtering; +datasets with bipolar or prefixed channel names are skipped (the +position bank cannot resolve them to single 3D coordinates). References: https://huggingface.co/brain-bzh/reve-large https://huggingface.co/brain-bzh/reve-positions """ +from math import gcd + import numpy as np import torch from benchopt import BaseSolver -from sklearn.pipeline import make_pipeline -from sklearn.ensemble import RandomForestClassifier -from sklearn.preprocessing import FunctionTransformer +from scipy.signal import resample_poly + +from benchmark_utils.adapters.linear_probe import LinearProbeAdapter SUPPORTED_TASKS = {"classification"} -REVE_SFREQ = 200 # REVE is trained at 200 Hz +REVE_SFREQ = 200 class Solver(BaseSolver): - """REVE foundation model + Random Forest classifier.""" + """REVE foundation model + linear probe.""" - name = "REVE-RandomForest" + name = "REVE" requirements = [ "pip::braindecode", @@ -42,57 +40,43 @@ class Solver(BaseSolver): "pos_bank_checkpoint": ["brain-bzh/reve-positions"], "batch_size": [16], "n_estimators": [100], + "max_iter": [1000], + "classifier": ["logistic_regression"], } def skip(self, task, **kwargs): if task not in SUPPORTED_TASKS: return True, f"REVE solver does not support task={task!r}" - # REVE expects monopolar 10-20 channel names ("Fpz", "C3", …). - # Datasets that ship bipolar derivations ("EEG Fpz-Cz") or - # prefixed names cannot be aligned with REVE's position bank — - # skip them rather than feed bogus electrode positions. + # REVE was trained on monopolar 10-20 montages — bipolar + # derivations ("EEG Fpz-Cz") have no single 3D position. ch_names = kwargs.get("ch_names") if ch_names is None: - return True, ( - "REVE requires `ch_names` in dataset meta to build " - "channel positions; dataset does not provide it." - ) + return True, "REVE requires `ch_names` in dataset meta." bad = [ n for n in ch_names if "-" in n or any( - n.startswith(prefix) for prefix in ("EEG ", "EOG ", "EMG ") + n.startswith(p) for p in ("EEG ", "EOG ", "EMG ") ) ] if bad: return True, ( - f"REVE expects monopolar 10-20 channel names; dataset " - f"provides bipolar/prefixed names (e.g. {bad[0]!r}). " - f"Skipping to avoid silently producing biased embeddings." + f"REVE expects monopolar 10-20 names; got {bad[0]!r}." ) - # REVE checkpoints are gated on HuggingFace Hub. Skip cleanly - # (rather than crashing mid-benchmark) when the user has no - # access — either because they haven't requested it or because - # their machine isn't authenticated (`huggingface-cli login`). + # Gated repo: skip cleanly when access/auth is missing. try: from huggingface_hub import HfApi HfApi().model_info(self.checkpoint) except Exception as e: return True, ( - f"REVE checkpoint '{self.checkpoint}' is gated or " - f"unreachable. Request access at " - f"https://huggingface.co/{self.checkpoint} and run " - f"`huggingface-cli login`. ({type(e).__name__}: {e})" + f"REVE checkpoint '{self.checkpoint}' gated or unreachable. " + f"Request access and run `huggingface-cli login`. " + f"({type(e).__name__})" ) return False, None def set_objective(self, task, X_train, y_train, **meta): - """Prepare the solver for a given dataset configuration. - - The foundation model and position bank are loaded here (not in - ``run``) so the download/load time is excluded from timing. - """ self.task = task self.X_train = X_train self.y_train = y_train @@ -105,18 +89,9 @@ def set_objective(self, task, X_train, y_train, **meta): else: device = "cpu" - # Source sampling rate and electrode names must come from the - # dataset's meta — REVE cannot infer them. ``dict.get(k, default)`` - # only falls back when the key is absent, so we coalesce ``None`` - # values too (datasets may set ``freq=None`` to signal "unknown"). freq = meta.get("freq") or REVE_SFREQ self._src_sfreq = float(freq) - self._ch_names = meta.get("ch_names", None) - if self._ch_names is None: - raise ValueError( - "REVE requires `ch_names` in dataset meta to build " - "channel positions." - ) + self._ch_names = meta["ch_names"] should_reload = ( not hasattr(self, "_network") @@ -126,10 +101,9 @@ def set_objective(self, task, X_train, y_train, **meta): try: from transformers import AutoModel - # REVE ships custom modeling code on the Hub - # (`modeling_reve.py`), so `trust_remote_code=True` is - # mandatory to avoid an interactive prompt that would - # break batch / CI runs of the benchmark. + # trust_remote_code: REVE ships custom modeling code on + # the Hub. Mandatory to avoid an interactive prompt that + # would break batch/CI runs. network = AutoModel.from_pretrained( self.checkpoint, trust_remote_code=True ) @@ -149,32 +123,27 @@ def set_objective(self, task, X_train, y_train, **meta): f"Failed to load REVE checkpoint '{self.checkpoint}': {e}" ) - # Pre-compute channel positions (C, 3) once per dataset. + # Positions (C, 3) recomputed per dataset since ch_names varies. with torch.no_grad(): self._positions = self._pos_bank(self._ch_names) self._device = device - self.model = make_pipeline( - FunctionTransformer(self._extract_embeddings), - RandomForestClassifier( - n_estimators=self.n_estimators, - n_jobs=-1, - random_state=42, - ), + def run(self, _): + self._adapter = LinearProbeAdapter( + encoder=self, + task=self.task, + classifier=self.classifier, + max_iter=self.max_iter, + n_estimators=self.n_estimators, ) + self._adapter.fit(self.X_train, self.y_train) - def run(self, _): - """Fit the Random Forest on REVE embeddings.""" - self.model.fit(self.X_train, self.y_train) + def encode(self, X): + """LinearProbeAdapter entry point — returns (N, embed_dim).""" + return self._extract_embeddings(X) def _extract_embeddings(self, X): - """Forward batches through the frozen REVE backbone. - - Returns - ------- - np.ndarray of shape (N, embedding_dim) - """ batch_size = self.batch_size n_samples = len(X) all_embeddings = [] @@ -187,16 +156,11 @@ def _extract_embeddings(self, X): x_t = torch.tensor( X_batch_processed, dtype=torch.float32, device=self._device ) - # Broadcast positions (C, 3) → (B, C, 3) for this batch. pos = self._positions.unsqueeze(0).expand(x_t.size(0), -1, -1) - # The HF custom `Reve` class loaded via AutoModel is tagged - # "Feature Extraction" on the Hub: its forward already - # returns embeddings directly, no flag needed. with torch.no_grad(): emb = self._network(x_t, pos) - # If REVE returns a sequence/spatial map (n_chans × n_patches - # × D), flatten the trailing axes into one vector per sample + if emb.ndim > 2: emb = emb.flatten(start_dim=1) @@ -207,25 +171,19 @@ def _extract_embeddings(self, X): def _prepare_inputs(self, X_batch): """Reshape (N, T, C) → (N, C, T) and resample to 200 Hz. - REVE is frequency-aware. A naive ``F.interpolate`` would alias when - downsampling and add spectral artefacts when upsampling, so we - use a polyphase resampler (``scipy.signal.resample_poly``) that - applies a anti-aliasing filter. + Polyphase (anti-aliased) — REVE is frequency-aware so naive + ``F.interpolate`` would corrupt the bands it was trained on. """ - X_in = X_batch.transpose(0, 2, 1) # (N, C, T) + X_in = X_batch.transpose(0, 2, 1) if int(self._src_sfreq) != REVE_SFREQ: - from math import gcd - from scipy.signal import resample_poly - src = int(self._src_sfreq) g = gcd(src, REVE_SFREQ) - up = REVE_SFREQ // g - down = src // g - X_in = resample_poly(X_in, up=up, down=down, axis=-1) + X_in = resample_poly( + X_in, up=REVE_SFREQ // g, down=src // g, axis=-1 + ) return np.ascontiguousarray(X_in, dtype=np.float32) def get_result(self): - """Return the fitted pipeline.""" - return {"model": self.model} + return {"model": self._adapter} From 03102ca717dcd496941059cf5f01837b7de96632 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Tue, 2 Jun 2026 21:23:40 +0200 Subject: [PATCH 17/18] Adapt REVE & CBraMod solvers for linear probing adapter --- solvers/cbramod.py | 4 +--- solvers/reve.py | 4 +--- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/solvers/cbramod.py b/solvers/cbramod.py index 99389c0..f098eec 100644 --- a/solvers/cbramod.py +++ b/solvers/cbramod.py @@ -39,8 +39,7 @@ class Solver(BaseSolver): "checkpoint": ["braindecode/cbramod-pretrained"], "batch_size": [16], "n_estimators": [100], - "max_iter": [1000], - "classifier": ["logistic_regression"], + "classifier": ["log_reg"], } def skip(self, task, **kwargs): @@ -97,7 +96,6 @@ def run(self, _): encoder=self, task=self.task, classifier=self.classifier, - max_iter=self.max_iter, n_estimators=self.n_estimators, ) self._adapter.fit(self.X_train, self.y_train) diff --git a/solvers/reve.py b/solvers/reve.py index 28b18e5..9e50673 100644 --- a/solvers/reve.py +++ b/solvers/reve.py @@ -40,8 +40,7 @@ class Solver(BaseSolver): "pos_bank_checkpoint": ["brain-bzh/reve-positions"], "batch_size": [16], "n_estimators": [100], - "max_iter": [1000], - "classifier": ["logistic_regression"], + "classifier": ["log_reg"], } def skip(self, task, **kwargs): @@ -134,7 +133,6 @@ def run(self, _): encoder=self, task=self.task, classifier=self.classifier, - max_iter=self.max_iter, n_estimators=self.n_estimators, ) self._adapter.fit(self.X_train, self.y_train) From f8586344befd242eebc6e6641d76eab424549700 Mon Sep 17 00:00:00 2001 From: Bryan Ly Date: Tue, 2 Jun 2026 21:25:25 +0200 Subject: [PATCH 18/18] Fix HGD dependencies & EEGNet model instance creation --- datasets/hgd.py | 1 - solvers/eegnet.py | 51 +++++++++++++++++++---------------------------- 2 files changed, 20 insertions(+), 32 deletions(-) diff --git a/datasets/hgd.py b/datasets/hgd.py index b98f053..30cc497 100644 --- a/datasets/hgd.py +++ b/datasets/hgd.py @@ -119,7 +119,6 @@ class Dataset(BaseDataset): name = "HGD" - requirements = ["pip::moabb", "pandas", "pip::braindecode==1.5.1"] requirements = ["pip::moabb", "pip::pandas", "pip::braindecode"] parameters = { diff --git a/solvers/eegnet.py b/solvers/eegnet.py index c2f6865..c68b929 100644 --- a/solvers/eegnet.py +++ b/solvers/eegnet.py @@ -65,37 +65,26 @@ def set_objective(self, task, X_train, y_train, **meta): n_channels = X0.shape[1] if X0.ndim == 2 else 1 n_classes = int(meta.get("n_classes", len(np.unique(y_train)))) - # Build the network once per dataset configuration. - should_reload = ( - not hasattr(self, "_network") - or getattr(self, "_n_channels", None) != n_channels - or getattr(self, "_n_classes", None) != n_classes - or getattr(self, "_n_times", None) != n_times - ) - if should_reload: - try: - from braindecode.models import EEGNet - - network = EEGNet( - n_chans=n_channels, - n_outputs=n_classes, - n_times=n_times, - ) - network = network.to(device) - - self._network = network - self._n_channels = n_channels - self._n_classes = n_classes - self._n_times = n_times - print( - f"✓ EEGNet built: C={n_channels}, T={n_times}, " - f"n_classes={n_classes} on device: {device}" - ) - except Exception as e: - raise RuntimeError( - f"Failed to build EEGNet: {e}. Make sure braindecode " - "is installed." - ) + # Build the network + + try: + from braindecode.models import EEGNet + + network = EEGNet( + n_chans=n_channels, + n_outputs=n_classes, + n_times=n_times, + ) + self._network = network.to(device) + print( + f"✓ EEGNet built: C={n_channels}, T={n_times}, " + f"n_classes={n_classes} on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to build EEGNet: {e}. Make sure braindecode " + "is installed." + ) self._device = device self._optimizer = torch.optim.Adam(