diff --git a/datasets/bnci2014_001.py b/datasets/bnci2014_001.py new file mode 100644 index 0000000..be61666 --- /dev/null +++ b/datasets/bnci2014_001.py @@ -0,0 +1,126 @@ +import numpy as np + +from braindecode.preprocessing.windowers import create_windows_from_events + +from braindecode.datasets import MOABBDataset +from braindecode.preprocessing import ( + exponential_moving_standardize, + preprocess, + Preprocessor, +) + +from benchopt import BaseDataset + + +# All datasets must be named `Dataset` and inherit from `BaseDataset` +class Dataset(BaseDataset): + + # Name to select the dataset in the CLI and to display the results. + name = "BNCI2014_001" + + requirements = [ + 'pip::braindecode', 'pip::moabb', + ] + + parameters = { + 'train_ratio': [0.8], + 'debug': [False], + 'seed': [42], + } + + def get_data(self): + # The return arguments of this function are passed as keyword arguments + # to `Objective.set_data`. This defines the benchmark's + # API to pass data. It is customizable for each benchmark. + + subjects = [1, 2] if self.debug else [1, 2, 3, 4, 5, 6, 7, 8, 9] + dataset = MOABBDataset( + dataset_name="BNCI2014_001", subject_ids=subjects, + ) + low_cut_hz = 4.0 # low cut frequency for filtering + high_cut_hz = 40.0 # high cut frequency for filtering + # Parameters for exponential moving standardization + factor_new = 1e-3 + init_block_size = 1000 + # Factor to convert from V to uV + factor = 1e6 + + preprocessors = [ + Preprocessor("pick_types", eeg=True, meg=False, stim=False), + Preprocessor(lambda data: np.multiply(data, factor)), + Preprocessor( + "filter", l_freq=low_cut_hz, h_freq=high_cut_hz + ), + Preprocessor( + exponential_moving_standardize, + factor_new=factor_new, + init_block_size=init_block_size, + ), + ] + + # Transform the data + preprocess(dataset, preprocessors) + + trial_start_offset_seconds = -0.5 + # Extract sampling frequency, check that they are same in all datasets + sfreq = dataset.datasets[0].raw.info["sfreq"] + + assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets]) + + # Extract channels names + ch_names = dataset.datasets[0].raw.ch_names + + # Calculate the trial start offset in samples. + trial_start_offset_samples = int(trial_start_offset_seconds * sfreq) + + window_size_samples = None + window_stride_samples = None + + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=trial_start_offset_samples, + trial_stop_offset_samples=0, + preload=False, + window_size_samples=window_size_samples, + window_stride_samples=window_stride_samples, + ) + + splitted = windows_dataset.split("subject") + subjects = list(splitted.keys()) + + X_all = [] + y_all = [] + for sub in subjects: + n_runs = len(splitted[sub].datasets) + x = [] + y = [] + for run in range(n_runs): + x += [sample[0].T for sample in splitted[sub].datasets[run]] + y += [sample[1] for sample in splitted[sub].datasets[run]] + X_all.append(np.array(x)) + y_all.append(np.array(y)) + random_state = np.random.RandomState(seed=self.seed) + ids_train = random_state.choice( + len(X_all), size=int(len(X_all) * self.train_ratio), + replace=False + ) + X_train = np.concatenate([X_all[i] for i in ids_train]) + y_train = np.concatenate([y_all[i] for i in ids_train]) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) if i not in ids_train] + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) if i not in ids_train] + ) + np.unique(y_train), np.unique(y_test) # sanity check + return dict( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=len(np.unique(y_train)), + freq=sfreq, + ch_names=ch_names + ) diff --git a/datasets/hgd.py b/datasets/hgd.py new file mode 100644 index 0000000..30cc497 --- /dev/null +++ b/datasets/hgd.py @@ -0,0 +1,201 @@ +"""High-Gamma Dataset (HGD) — motor imagery EEG classification. + +Wraps the High-Gamma Dataset from Schirrmeister et al. (2017): + "Deep learning with convolutional neural networks for EEG decoding + and visualization", Human Brain Mapping. + https://doi.org/10.1002/hbm.23730 + +128-electrode EEG (44 motor-cortex channels used) recorded from 14 healthy +subjects performing ~1000 four-second trials of executed movements across +13 runs. Data are downloaded automatically via braindecode / MOABB. + +Labels (4 classes): + 0: left_hand 1: right_hand 2: feet 3: rest + +The dataset's own train/test split (runs 1-11 → train, runs 12-13 → test) +is preserved per subject; subjects are then pooled according to +`train_ratio`. + +Data contract output +-------------------- +X_train : np.ndarray (N, T, C) windows (n_times, n_channels=44) +y_train : np.ndarray (N,) int class labels 0-3 +X_test : np.ndarray (M, T, C) +y_test : np.ndarray (M,) int +task : "classification" +metrics : ["accuracy", "balanced_accuracy", "f1_weighted"] +n_classes : 4 +""" + +import numpy as np +from braindecode.datasets import HGD +from braindecode.preprocessing.preprocess import preprocess, Preprocessor +from braindecode.preprocessing.windowers import create_windows_from_events +from sklearn.preprocessing import scale as standard_scale + +from benchopt import BaseDataset + + +# --------------------------------------------------------------------------- +# Label mapping +# --------------------------------------------------------------------------- +LABEL_MAPPING = { + "left_hand": 0, + "right_hand": 1, + "feet": 2, + "rest": 3, +} + + +# --------------------------------------------------------------------------- +# Per-subject loader +# --------------------------------------------------------------------------- + +def _load_subject(sub_id, preprocessors, window_size_samples): + """Wrapper that falls back to a description-based split if metadata + item access is unavailable in the installed braindecode version.""" + dataset = HGD(subject_ids=[sub_id]) + preprocess(dataset, preprocessors) + raw = dataset.datasets[0].raw + sfreq = raw.info["sfreq"] + ch_names = raw.ch_names + + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=0, + trial_stop_offset_samples=0, + window_size_samples=window_size_samples, + window_stride_samples=window_size_samples, + preload=True, + mapping=LABEL_MAPPING, + ) + preprocess( + windows_dataset, [Preprocessor(standard_scale, channel_wise=True)] + ) + + X_, y_ = [], [] + + # windows_dataset.description contains a 'split' column ('train'/'test') + # Each window is indexed into its base dataset; we replicate that mapping. + descriptions = windows_dataset.datasets # list of WindowsDataset + + for ds in descriptions: + for x, label, _ in ds: + window = x.T # (T, C) + X_.append(window) + y_.append(label) + + return X_, y_, sfreq, ch_names + + +# --------------------------------------------------------------------------- +# Benchopt Dataset +# --------------------------------------------------------------------------- + +class Dataset(BaseDataset): + """High-Gamma Dataset (HGD) — 4-class motor EEG classification. + + Parameters + ---------- + resample_hz : float + Target sampling frequency. The raw data are recorded at 500 Hz; + default resamples to 250 Hz (window_size_samples=1000 → 4 s). + high_cut_hz : float + Cutoff for the low-pass filter applied to raw signals (Hz). + factor : float + Multiplicative scaling applied before filtering (e.g. V → µV). + window_size_samples : int + Length of each trial window in samples (after resampling). + train_ratio : float + Fraction of subjects whose trials go into the training pool. + The internal per-subject train/test split (runs 1-11 vs 12-13) is + always respected first; `train_ratio` then selects which subjects + contribute to the final train vs test arrays. + debug : bool + If True, load only the first 2 subjects for fast iteration. + seed : int + Random seed for the subject-level train/test split. + """ + + name = "HGD" + + requirements = ["pip::moabb", "pip::pandas", "pip::braindecode"] + + parameters = { + "seed": [42], + "train_ratio": [0.8], + "resample_hz": [250], + "high_cut_hz": [40.0], + "factor": [1e6], # V → µV + "window_size_samples": [1000], # 4 s at 250 Hz + "debug": [False], + } + + def get_data(self): + n_jobs = 1 + preprocessors = [ + # Convert V → µV + Preprocessor(lambda data: np.multiply(data, self.factor)), + # Resample to target frequency + Preprocessor("resample", sfreq=self.resample_hz), + # Low-pass filter + Preprocessor( + "filter", l_freq=None, h_freq=self.high_cut_hz, n_jobs=n_jobs + ), + ] + + sub_ids = list(range(1, 15)) # 14 subjects + if self.debug: + sub_ids = sub_ids[:2] + + # Collect per-subject (train, test) pairs + X_all, y_all = [], [] + sfreq_ref, ch_names_ref = None, None + + for sub_id in sub_ids: + X_, y_, sfreq, ch_names = _load_subject( + sub_id, preprocessors, self.window_size_samples + ) + if sfreq_ref is None: + sfreq_ref, ch_names_ref = sfreq, ch_names + else: + assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}" + X_all.append(X_) + y_all.append(y_) + + # ------------------------------------------------------------------ + # Subject-level train / test split (same pattern as Sleep dataset) + # ------------------------------------------------------------------ + random_state = np.random.RandomState(seed=self.seed) + ids_train = random_state.choice( + len(X_all), + size=int(len(X_all) * self.train_ratio), + replace=False, + ) + ids_train_set = set(ids_train.tolist()) + + X_train = np.concatenate( + [X_all[i] for i in ids_train_set], axis=0 + ) + y_train = np.concatenate( + [y_all[i] for i in ids_train_set], axis=0 + ) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) + if i not in ids_train_set], axis=0 + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) + if i not in ids_train_set], axis=0 + ) + return dict( + X_train=X_train, # (N, window_size_samples, n_channels) + y_train=y_train, # (N,) int in {0, 1, 2, 3} + X_test=X_test, # (M, window_size_samples, n_channels) + y_test=y_test, # (M,) + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=4, + freq=sfreq_ref, + ch_names=ch_names_ref + ) diff --git a/datasets/sleepphysionet.py b/datasets/sleepphysionet.py new file mode 100644 index 0000000..8fe1ed8 --- /dev/null +++ b/datasets/sleepphysionet.py @@ -0,0 +1,176 @@ +"""Sleep classification dataset from Sleep Physionet. + +Wraps the sleep recordings from the Sleep Physionet. +Each recording is split into a training portion (first 10 %) and a test +portion. +Labels are from 0 to 4, corresponding to the sleep stages W, N1, N2, N3, REM. + +Data contract output +-------------------- +X_train : List[np.ndarray (T, C)] one array per training sample +y_train : np.ndarray (N,) int class labels +X_test : List[np.ndarray (T, C)] +y_test : np.ndarray (M,) int +task : "classification" +metrics : ["accuracy", "balanced_accuracy", "f1_weighted"] +n_classes : int +""" + +import numpy as np +from braindecode.datasets import SleepPhysionet +from braindecode.preprocessing.preprocess import preprocess, Preprocessor + +from braindecode.preprocessing.windowers import create_windows_from_events +from sklearn.preprocessing import scale as standard_scale + +from benchopt import BaseDataset + + +def _load_subject( + sub_id, preprocessors, mapping=None, window_size_samples=3000 +): + dataset = SleepPhysionet(subject_ids=[sub_id], crop_wake_mins=30) + + preprocess(dataset, preprocessors) + + # Extract the frequency and channels names + raw = dataset.datasets[0].raw + sfreq = raw.info["sfreq"] + ch_names = raw.ch_names + + windows_dataset = create_windows_from_events( + dataset, + trial_start_offset_samples=0, + trial_stop_offset_samples=0, + window_size_samples=window_size_samples, + window_stride_samples=window_size_samples, + preload=True, + mapping=mapping, + ) + + preprocess( + windows_dataset, [Preprocessor(standard_scale, channel_wise=True)] + ) + all_labels = [] + all_data = [] + for i, x in enumerate(windows_dataset): + label = x[1] + all_labels.append(label) + + data = x[0].T + all_data.append(data) + return all_data, all_labels, sfreq, ch_names + + +class Dataset(BaseDataset): + """Sleep classification dataset (TSB-UAD). + + Parameters + ---------- + window_size_samples : int + Length of the windows to split the recordings into. + sub_ids : List[int] + Subject IDs to include (from 1 to 82, excluding 39, 68, 69, 78, 79). + mapping : dict + Mapping from the original sleep stage labels to integers. + We merge stages 3 and 4 following AASM standards. + n_jobs : int + Number of parallel jobs to use for preprocessing. + debug : bool + If True, keep only the first 5000 samples + of each recording for fast testing. + high_cut_hz : float + If not None, apply a low-pass filter with this cutoff frequency (in Hz) + to the raw signals. + factor : float + Factor to multiply the raw signals by (e.g. to convert from V to uV + train_ratio : float + Fraction of each recording used as the training (normal) portion. + """ + + name = "SleepPhysionet" + + requirements = ["pip::pooch", "pip::pandas", "pip::braindecode"] + + parameters = { + "seed": [42], + "train_ratio": [0.8], + "debug": [False], + } + + def prepare(self): + # Allow reuse of the download helper from benchmark_ad if present, + # otherwise fall back to the data path directly. + sub_ids = range(1, 83) + window_size_samples = 3000 + mapping = { # We merge stages 3 and 4 following AASM standards. + "Sleep stage W": 0, + "Sleep stage 1": 1, + "Sleep stage 2": 2, + "Sleep stage 3": 3, + "Sleep stage 4": 3, + "Sleep stage R": 4, + } + n_jobs = 1 + high_cut_hz = 40.0 + factor = 1e6 # Factor to convert from V to uV + + preprocessors = [ + Preprocessor(lambda data: np.multiply(data, factor)), + Preprocessor( + "filter", l_freq=None, + h_freq=high_cut_hz, n_jobs=n_jobs + ), + ] + + X_all, y_all = [], [] + sub_ids = sub_ids[:2] if self.debug else sub_ids + sfreq_ref, ch_names_ref = None, None + for sub_id in sub_ids: + if sub_id in [39, 68, 69, 78, 79]: + continue + X_, y_, sfreq, ch_names = _load_subject( + sub_id, preprocessors, mapping, window_size_samples + ) + if sfreq_ref is None: + sfreq_ref, ch_names_ref = sfreq, ch_names + else: + assert sfreq == sfreq_ref and ch_names == ch_names_ref, f"Inconsistent meta for sub {sub_id}" + if self.debug: + X_ = X_[:5000] + y_ = y_[:5000] + X_all.append(X_) + y_all.append(y_) + + random_state = np.random.RandomState(seed=self.seed) + ids_train = random_state.choice( + len(X_all), size=int(len(X_all) * self.train_ratio), + replace=False + ) + + X_train = np.concatenate([X_all[i] for i in ids_train]) + y_train = np.concatenate([y_all[i] for i in ids_train]) + X_test = np.concatenate( + [X_all[i] for i in range(len(X_all)) if i not in ids_train] + ) + y_test = np.concatenate( + [y_all[i] for i in range(len(y_all)) if i not in ids_train] + ) + + return X_train, y_train, X_test, y_test, sfreq_ref, ch_names_ref + + def get_data(self): + + X_train, y_train, X_test, y_test, sfreq, ch_names = self.prepare() + + return dict( + X_train=X_train, + y_train=y_train, + X_test=X_test, + y_test=y_test, + task="classification", + metrics=["accuracy", "balanced_accuracy", "f1_weighted"], + n_classes=5, + freq=sfreq, + ch_names=ch_names + ) diff --git a/solvers/cbramod.py b/solvers/cbramod.py new file mode 100644 index 0000000..f098eec --- /dev/null +++ b/solvers/cbramod.py @@ -0,0 +1,159 @@ +"""CBraMod solver — frozen foundation backbone + linear probe. + +CBraMod (Wang et al., ICLR 2025) is pretrained on TUEG with 1-second +patches at 200 Hz. Inputs at any other rate are resampled via a +polyphase (anti-aliased) filter. Channel-agnostic montage thanks to +its Asymmetric Conditional Positional Encoding. + +References: + https://arxiv.org/abs/2412.07236 + https://braindecode.org/stable/generated/braindecode.models.CBraMod.html +""" + +from math import gcd + +import numpy as np +import torch +from benchopt import BaseSolver +from scipy.signal import resample_poly + +from benchmark_utils.adapters.linear_probe import LinearProbeAdapter + +SUPPORTED_TASKS = {"classification"} +CBRAMOD_SFREQ = 200 +CBRAMOD_PATCH_SIZE = 200 # 1-second patches at 200 Hz; T must be a multiple + + +class Solver(BaseSolver): + """CBraMod foundation model + linear probe.""" + + name = "CBraMod" + + requirements = [ + "pip::braindecode", + "pip::torch", + "pip::scipy", + ] + + parameters = { + "checkpoint": ["braindecode/cbramod-pretrained"], + "batch_size": [16], + "n_estimators": [100], + "classifier": ["log_reg"], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"CBraMod solver does not support task={task!r}" + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + freq = meta.get("freq") or CBRAMOD_SFREQ + self._src_sfreq = float(freq) + + should_reload = ( + not hasattr(self, "_network") + or getattr(self, "_loaded_checkpoint", None) != self.checkpoint + ) + if should_reload: + try: + from braindecode.models import CBraMod + + # return_encoder_output=True bypasses the randomly-init + # classification head and exposes encoder features. + network = CBraMod.from_pretrained( + self.checkpoint, + return_encoder_output=True, + ) + self._network = network.to(device).eval() + self._loaded_checkpoint = self.checkpoint + print( + f"✓ CBraMod checkpoint loaded: {self.checkpoint} " + f"on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to load CBraMod checkpoint " + f"'{self.checkpoint}': {e}" + ) + + self._device = device + + def run(self, _): + self._adapter = LinearProbeAdapter( + encoder=self, + task=self.task, + classifier=self.classifier, + n_estimators=self.n_estimators, + ) + self._adapter.fit(self.X_train, self.y_train) + + def encode(self, X): + """LinearProbeAdapter entry point — returns (N, embed_dim).""" + return self._extract_embeddings(X) + + def _extract_embeddings(self, X): + batch_size = self.batch_size + n_samples = len(X) + all_embeddings = [] + + for batch_idx in range(0, n_samples, batch_size): + batch_end = min(batch_idx + batch_size, n_samples) + X_batch = np.asarray(X[batch_idx:batch_end], dtype=np.float32) + X_batch_processed = self._prepare_inputs(X_batch) + + x_t = torch.tensor( + X_batch_processed, dtype=torch.float32, device=self._device + ) + + with torch.no_grad(): + emb = self._network(x_t) + + # Encoder returns (B, n_chans, n_patches, D) — flatten to + # one vector per sample. + if emb.ndim > 2: + emb = emb.flatten(start_dim=1) + + all_embeddings.append(emb.cpu().numpy()) + + return np.vstack(all_embeddings) + + def _prepare_inputs(self, X_batch): + """Reshape (N, T, C) → (N, C, T), resample to 200 Hz, truncate. + + Polyphase resampling is anti-aliased so it works for both up- + and downsampling. CBraMod patches at 200 samples per token, so + ``T`` must be a multiple of ``CBRAMOD_PATCH_SIZE`` — we truncate + any trailing partial patch rather than zero-pad (cleaner than + injecting fake samples). + """ + X_in = X_batch.transpose(0, 2, 1) + + if int(self._src_sfreq) != CBRAMOD_SFREQ: + src = int(self._src_sfreq) + g = gcd(src, CBRAMOD_SFREQ) + X_in = resample_poly( + X_in, up=CBRAMOD_SFREQ // g, down=src // g, axis=-1 + ) + + T = X_in.shape[-1] + remainder = T % CBRAMOD_PATCH_SIZE + if remainder: + X_in = X_in[..., : T - remainder] + + return np.ascontiguousarray(X_in, dtype=np.float32) + + def get_result(self): + return {"model": self._adapter} diff --git a/solvers/eegnet.py b/solvers/eegnet.py new file mode 100644 index 0000000..c68b929 --- /dev/null +++ b/solvers/eegnet.py @@ -0,0 +1,161 @@ +"""EEGNet solver for time series classification. + +Uses the ``braindecode`` implementation of EEGNetv4 as a classifier on +multivariate time series of shape (N, T, C). + +References: + https://braindecode.org/ + https://arxiv.org/abs/1611.08024 +""" + +import numpy as np +import torch +from benchopt import BaseSolver + +SUPPORTED_TASKS = {"classification"} + + +class Solver(BaseSolver): + """EEGNet time series classification solver. + + The model is built once in ``set_objective`` (not timed). During + ``run`` the network is trained on the training set. + """ + + name = "EEGNet" + + requirements = [ + "pip::braindecode", + "pip::torch", + ] + + parameters = { + "n_epochs": [50], + "batch_size": [32], + "lr": [1e-3], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"EEGNet solver does not support task={task!r}" + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + """Prepare the solver for a given dataset configuration. + + Model construction is done here (not inside ``run``) so that + the build time is excluded from the benchmark timing. + """ + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + # Infer input dimensions directly from the training data. + # Within a dataset all series share the same (T, C) shape. + X0 = np.asarray(X_train[0]) + n_times = X0.shape[0] + n_channels = X0.shape[1] if X0.ndim == 2 else 1 + n_classes = int(meta.get("n_classes", len(np.unique(y_train)))) + + # Build the network + + try: + from braindecode.models import EEGNet + + network = EEGNet( + n_chans=n_channels, + n_outputs=n_classes, + n_times=n_times, + ) + self._network = network.to(device) + print( + f"✓ EEGNet built: C={n_channels}, T={n_times}, " + f"n_classes={n_classes} on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to build EEGNet: {e}. Make sure braindecode " + "is installed." + ) + + self._device = device + self._optimizer = torch.optim.Adam( + self._network.parameters(), lr=self.lr + ) + self._criterion = torch.nn.CrossEntropyLoss() + + def run(self, _): + """Fit the model on the training data.""" + X = self._prepare_inputs(np.asarray(self.X_train, dtype=np.float32)) + y = np.asarray(self.y_train, dtype=np.int64) + + X_t = torch.tensor(X, dtype=torch.float32, device=self._device) + y_t = torch.tensor(y, dtype=torch.long, device=self._device) + + dataset = torch.utils.data.TensorDataset(X_t, y_t) + loader = torch.utils.data.DataLoader( + dataset, batch_size=self.batch_size, shuffle=True + ) + + self._network.train() + for _ in range(self.n_epochs): + for xb, yb in loader: + self._optimizer.zero_grad() + logits = self._network(xb) + loss = self._criterion(logits, yb) + loss.backward() + self._optimizer.step() + + def _prepare_inputs(self, X_batch): + """Reshape inputs to EEGNet's expected layout. + + EEGNet expects arrays of shape (N, C, T). Inputs arrive as + (N, T, C); within a dataset all series share the same length, + so no interpolate needed. + """ + return X_batch.transpose(0, 2, 1) + + def get_result(self): + """Return a predictor exposing ``predict(X_test) -> (N,)``.""" + return { + "model": EEGNetAdapter( + self._network, + self._device, + self._prepare_inputs, + self.batch_size, + ) + } + + +class EEGNetAdapter: + """Wraps a trained EEGNet to expose ``predict(X) -> (N,)`` class indices. + + ``X`` arrives as a list / array of (T, C) series, matching the + classification data contract in ``objective.py``. + """ + + def __init__(self, network, device, prepare_inputs, batch_size): + self._network = network + self._device = device + self._prepare_inputs = prepare_inputs + self._batch_size = batch_size + + def predict(self, X): + X = self._prepare_inputs(np.asarray(X, dtype=np.float32)) + X_t = torch.tensor(X, dtype=torch.float32, device=self._device) + + self._network.eval() + preds = [] + with torch.no_grad(): + for i in range(0, len(X_t), self._batch_size): + logits = self._network(X_t[i:i + self._batch_size]) + preds.append(logits.argmax(dim=1).cpu().numpy()) + return np.concatenate(preds) diff --git a/solvers/reve.py b/solvers/reve.py new file mode 100644 index 0000000..9e50673 --- /dev/null +++ b/solvers/reve.py @@ -0,0 +1,187 @@ +"""REVE solver — frozen foundation backbone + linear probe. + +REVE expects EEG sampled at 200 Hz with monopolar 10-20 channel +names. Inputs at other rates are resampled via polyphase filtering; +datasets with bipolar or prefixed channel names are skipped (the +position bank cannot resolve them to single 3D coordinates). + +References: + https://huggingface.co/brain-bzh/reve-large + https://huggingface.co/brain-bzh/reve-positions +""" + +from math import gcd + +import numpy as np +import torch +from benchopt import BaseSolver +from scipy.signal import resample_poly + +from benchmark_utils.adapters.linear_probe import LinearProbeAdapter + +SUPPORTED_TASKS = {"classification"} +REVE_SFREQ = 200 + + +class Solver(BaseSolver): + """REVE foundation model + linear probe.""" + + name = "REVE" + + requirements = [ + "pip::braindecode", + "pip::transformers", + "pip::torch", + "pip::scipy", + ] + + parameters = { + "checkpoint": ["brain-bzh/reve-large"], + "pos_bank_checkpoint": ["brain-bzh/reve-positions"], + "batch_size": [16], + "n_estimators": [100], + "classifier": ["log_reg"], + } + + def skip(self, task, **kwargs): + if task not in SUPPORTED_TASKS: + return True, f"REVE solver does not support task={task!r}" + + # REVE was trained on monopolar 10-20 montages — bipolar + # derivations ("EEG Fpz-Cz") have no single 3D position. + ch_names = kwargs.get("ch_names") + if ch_names is None: + return True, "REVE requires `ch_names` in dataset meta." + bad = [ + n for n in ch_names + if "-" in n or any( + n.startswith(p) for p in ("EEG ", "EOG ", "EMG ") + ) + ] + if bad: + return True, ( + f"REVE expects monopolar 10-20 names; got {bad[0]!r}." + ) + + # Gated repo: skip cleanly when access/auth is missing. + try: + from huggingface_hub import HfApi + HfApi().model_info(self.checkpoint) + except Exception as e: + return True, ( + f"REVE checkpoint '{self.checkpoint}' gated or unreachable. " + f"Request access and run `huggingface-cli login`. " + f"({type(e).__name__})" + ) + return False, None + + def set_objective(self, task, X_train, y_train, **meta): + self.task = task + self.X_train = X_train + self.y_train = y_train + self.meta = meta + + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" + + freq = meta.get("freq") or REVE_SFREQ + self._src_sfreq = float(freq) + self._ch_names = meta["ch_names"] + + should_reload = ( + not hasattr(self, "_network") + or getattr(self, "_loaded_checkpoint", None) != self.checkpoint + ) + if should_reload: + try: + from transformers import AutoModel + + # trust_remote_code: REVE ships custom modeling code on + # the Hub. Mandatory to avoid an interactive prompt that + # would break batch/CI runs. + network = AutoModel.from_pretrained( + self.checkpoint, trust_remote_code=True + ) + pos_bank = AutoModel.from_pretrained( + self.pos_bank_checkpoint, trust_remote_code=True + ) + + self._network = network.to(device).eval() + self._pos_bank = pos_bank.to(device).eval() + self._loaded_checkpoint = self.checkpoint + print( + f"✓ REVE checkpoint loaded: {self.checkpoint} " + f"on device: {device}" + ) + except Exception as e: + raise RuntimeError( + f"Failed to load REVE checkpoint '{self.checkpoint}': {e}" + ) + + # Positions (C, 3) recomputed per dataset since ch_names varies. + with torch.no_grad(): + self._positions = self._pos_bank(self._ch_names) + + self._device = device + + def run(self, _): + self._adapter = LinearProbeAdapter( + encoder=self, + task=self.task, + classifier=self.classifier, + n_estimators=self.n_estimators, + ) + self._adapter.fit(self.X_train, self.y_train) + + def encode(self, X): + """LinearProbeAdapter entry point — returns (N, embed_dim).""" + return self._extract_embeddings(X) + + def _extract_embeddings(self, X): + batch_size = self.batch_size + n_samples = len(X) + all_embeddings = [] + + for batch_idx in range(0, n_samples, batch_size): + batch_end = min(batch_idx + batch_size, n_samples) + X_batch = np.asarray(X[batch_idx:batch_end], dtype=np.float32) + X_batch_processed = self._prepare_inputs(X_batch) + + x_t = torch.tensor( + X_batch_processed, dtype=torch.float32, device=self._device + ) + pos = self._positions.unsqueeze(0).expand(x_t.size(0), -1, -1) + + with torch.no_grad(): + emb = self._network(x_t, pos) + + if emb.ndim > 2: + emb = emb.flatten(start_dim=1) + + all_embeddings.append(emb.cpu().numpy()) + + return np.vstack(all_embeddings) + + def _prepare_inputs(self, X_batch): + """Reshape (N, T, C) → (N, C, T) and resample to 200 Hz. + + Polyphase (anti-aliased) — REVE is frequency-aware so naive + ``F.interpolate`` would corrupt the bands it was trained on. + """ + X_in = X_batch.transpose(0, 2, 1) + + if int(self._src_sfreq) != REVE_SFREQ: + src = int(self._src_sfreq) + g = gcd(src, REVE_SFREQ) + X_in = resample_poly( + X_in, up=REVE_SFREQ // g, down=src // g, axis=-1 + ) + + return np.ascontiguousarray(X_in, dtype=np.float32) + + def get_result(self): + return {"model": self._adapter}