diff --git a/cebra/data/__init__.py b/cebra/data/__init__.py index 697801ed..145ff835 100644 --- a/cebra/data/__init__.py +++ b/cebra/data/__init__.py @@ -51,3 +51,4 @@ from cebra.data.multiobjective import * from cebra.data.datasets import * from cebra.data.helper import * +from cebra.data.masking import * diff --git a/cebra/data/base.py b/cebra/data/base.py index 4fa7ba6c..51199cec 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -27,6 +27,7 @@ import torch import cebra.data.assets as cebra_data_assets +import cebra.data.masking as cebra_data_masking import cebra.distributions import cebra.io from cebra.data.datatypes import Batch @@ -36,7 +37,7 @@ __all__ = ["Dataset", "Loader"] -class Dataset(abc.ABC, cebra.io.HasDevice): +class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin): """Abstract base class for implementing a dataset. The class attributes provide information about the shape of the data when @@ -227,6 +228,8 @@ class Loader(abc.ABC, cebra.io.HasDevice): doc="""A dataset instance specifying a ``__getitem__`` function.""", ) + time_offset: int = dataclasses.field(default=10) + num_steps: int = dataclasses.field( default=None, doc= diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 24735f47..59af8900 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -29,6 +29,7 @@ import torch import cebra.data as cebra_data +import cebra.data.masking as cebra_data_masking import cebra.helper as cebra_helper import cebra.io as cebra_io from cebra.data.datatypes import Batch @@ -304,7 +305,7 @@ def _iter_property(self, attr): # TODO(stes): This should be a single session dataset? -class DatasetxCEBRA(cebra_io.HasDevice): +class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin): """Dataset class for xCEBRA models. This class handles neural data and associated labels for xCEBRA models, providing @@ -435,3 +436,71 @@ def load_batch_contrastive(self, index: BatchIndex) -> Batch: positive=[self[idx] for idx in index.positive], negative=self[index.negative], ) + + +class UnifiedDataset(DatasetCollection): + """Multi session dataset made up of a list of datasets, considered as a unique session. + + Considering the sessions as a unique session, or pseudo-session, is used to later train a single + model for all the sessions, even if they originally contain a variable number of neurons. + To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis. + + For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for + each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch` + is of shape ``(batch_size, total_num_neurons, offset)``, with ``total_num_neurons`` is the sum of all the + ``num_neurons(session)``. 
+ """ + + def __init__(self, *datasets: cebra_data.SingleSessionDataset): + super().__init__(*datasets) + + @property + def input_dimension(self) -> int: + """Returns the sum of the input dimension for each session.""" + return np.sum([ + self.get_input_dimension(session_id) + for session_id in range(self.num_sessions) + ]) + + def _get_batches(self, index): + """Return the data at the specified index location.""" + return [ + cebra_data.Batch( + reference=self.get_session(session_id)[ + index.reference[session_id]], + positive=self.get_session(session_id)[ + index.positive[session_id]], + negative=self.get_session(session_id)[ + index.negative[session_id]], + ) for session_id in range(self.num_sessions) + ] + + def load_batch(self, index: BatchIndex) -> Batch: + """Return the data at the specified index location. + + Concatenate batches for each sessions on the number of neurons axis. + + Args: + batches: List of :py:class:`cebra.data.datatypes.Batch` sampled for each session. An instance + :py:class:`cebra.data.datatypes.Batch` of the list is of shape ``(batch_size, num_neurons(session), offset)``. + + Returns: + A :py:class:`cebra.data.datatypes.Batch`, of shape ``(batch_size, total_num_neurons, offset)``, where + ``total_num_neurons`` is the sum of all the ``num_neurons(session)`` + """ + batches = self._get_batches(index) + + return cebra_data.Batch( + reference=self.apply_mask( + torch.cat([batch.reference for batch in batches], dim=1)), + positive=self.apply_mask( + torch.cat([batch.positive for batch in batches], dim=1)), + negative=self.apply_mask( + torch.cat([batch.negative for batch in batches], dim=1)), + ) + + def __getitem__(self, args) -> List[Batch]: + """Return a set of samples from all sessions.""" + + session_id, index = args + return self.get_session(session_id).__getitem__(index) diff --git a/cebra/data/mask.py b/cebra/data/mask.py new file mode 100644 index 00000000..946d97a4 --- /dev/null +++ b/cebra/data/mask.py @@ -0,0 +1,327 @@ +import abc +import random +from typing import List, Tuple, Union + +import numpy as np +import torch + +__all__ = [ + "Mask", "RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask", + "TimeBlockMask" +] + + +class Mask: + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + self._check_masking_parameters(masking_value) + + @abc.abstractmethod + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abc.abstractmethod + def _select_masking_params(): + raise NotImplementedError + + def _check_masking_parameters(self, masking_value: Union[float, List[float], + Tuple[float]]): + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. 
+ """ + if isinstance(masking_value, float): + assert 0.0 < masking_value < 1.0, ( + f"Masking ratio {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + + elif isinstance(masking_value, list): + assert all(isinstance(ratio, float) for ratio in masking_value), ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert all(0.0 < ratio < 1.0 for ratio in masking_value), ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + + elif isinstance(masking_value, tuple): + assert len(masking_value) == 3, ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be a tuple of (min, max, step).") + assert 0.0 <= masking_value[0] < masking_value[1] <= 1.0, ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert masking_value[2] < masking_value[1] - masking_value[0], ( + f"Masking step {masking_value[2]} for {self.__name__()} " + "should be between smaller than the diff between min " + f"({masking_value[0]}) and max ({masking_value[1]}).") + + else: + raise ValueError( + f"Masking ratio {masking_value} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + +class RandomNeuronMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_ratio = masking_value + + def __name__(self): + return "RandomNeuronMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply random masking on the neuron dimension. + + Args: + data: batch of size (batch_size, n_neurons, offset). + mask_ratio: Proportion of neurons to mask. Default value 0.3 comes + from the MtM paper: https://arxiv.org/pdf/2407.14668v2 + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + """ + batch_size, n_neurons, offset_length = data.shape + mask_ratio = self._select_masking_params() + + # Random mask: shape [batch_size, n_neurons], different per batch and neurons + masked = torch.rand(batch_size, n_neurons, + device=data.device) < mask_ratio + return (~masked).int().unsqueeze(2).expand( + -1, -1, offset_length) # Expand to all timesteps + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + if isinstance(self.mask_ratio, float): + selected_value = self.mask_ratio + + elif isinstance(self.mask_ratio, list): + selected_value = random.choice(self.mask_ratio) + + elif isinstance(self.mask_ratio, tuple): + min_val, max_val, step_size = self.mask_ratio + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_ratio} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class RandomTimestepMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_ratio = masking_value + + def __name__(self): + return "RandomTimestepMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply random masking on the time dimension. 
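
Editor's note: the three accepted forms of ``masking_value`` are easiest to see in isolation. The following standalone sketch mirrors the ``_select_masking_params`` logic above; the helper name is illustrative only and not part of the package.

    import random
    import numpy as np

    def select_ratio(masking_value):
        # Mirrors Mask._select_masking_params: a float is used as-is, a list is
        # sampled from uniformly, and a (min, max, step) tuple is expanded into
        # a grid of candidate ratios before sampling.
        if isinstance(masking_value, float):
            return masking_value
        if isinstance(masking_value, list):
            return random.choice(masking_value)
        if isinstance(masking_value, tuple):
            min_val, max_val, step = masking_value
            return random.choice(np.arange(min_val, max_val + step, step).tolist())
        raise ValueError(f"Unsupported masking value: {masking_value}")

    select_ratio(0.3)              # always 0.3
    select_ratio([0.1, 0.2, 0.3])  # one of the listed ratios
    select_ratio((0.1, 0.5, 0.1))  # one of 0.1, 0.2, ..., 0.5
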
+ + Args: + data: batch of size (batch_idx, feature_dim, seq_len). With seq_len + corresponding to the offset. + mask_ratio: Proportion of timesteps masked. Not necessarly consecutive. + Default value 0.3 comes from the MtM paper: https://arxiv.org/pdf/2407.14668v2 + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + + """ + batch_idx, n_neurons, offset_length = data.shape + mask_ratio = self._select_masking_params() + + # Random mask: shape [batbatch_idxch_size, offset_length], different per batch and timestamp + masked = torch.rand(batch_idx, offset_length, + device=data.device) < mask_ratio + return (~masked).int().unsqueeze(1).expand(-1, n_neurons, + -1) # Expand to all neurons + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + if isinstance(self.mask_ratio, float): + selected_value = self.mask_ratio + + elif isinstance(self.mask_ratio, list): + selected_value = random.choice(self.mask_ratio) + + elif isinstance(self.mask_ratio, tuple): + min_val, max_val, step_size = self.mask_ratio + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_ratio} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class NeuronBlockMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_prop = masking_value + + def __name__(self): + return "NeuronBlockMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply masking to a contiguous block of neurons. + + Args: + data: batch of size (batch_size, n_neurons, offset). + self.mask_prop: Proportion of neurons to mask. The neurons are masked in a + contiguous block. + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + """ + batch_size, n_neurons, offset_length = data.shape + + mask_prop = self._select_masking_params() + num_mask = int(n_neurons * mask_prop) + mask = torch.ones((batch_size, n_neurons), + dtype=torch.int, + device=data.device) + + if num_mask == 0: + return mask.unsqueeze(2) + + for batch_idx in range(batch_size): # Create a mask for each batch + # Select random the start index for the block of neurons to mask + start_idx = torch.randint(0, n_neurons - num_mask + 1, (1,)).item() + end_idx = min(start_idx + num_mask, n_neurons) + mask[batch_idx, start_idx:end_idx] = 0 # set masked neurons to 0 + + return mask.unsqueeze(2).expand( + -1, -1, offset_length) # Expand to all timesteps + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. 
+ """ + if isinstance(self.mask_prop, float): + selected_value = self.mask_prop + + elif isinstance(self.mask_prop, list): + selected_value = random.choice(self.mask_prop) + + elif isinstance(self.mask_prop, tuple): + min_val, max_val, step_size = self.mask_prop + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_prop} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class TimeBlockMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.sampled_rate, self.masked_seq_len = masking_value + + def __name__(self): + return "TimeBlockMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply continguous block masking on the time dimension. + + When choosing which block of timesteps to mask, each timestep is considered + a candidate starting time-step with probability ``self.sampled_rate`` where + ``self.masked_seq_len`` is the length of each masked span starting from the respective + time step. Sampled starting time steps are expanded to length ``self.masked_seq_len`` + and spans can overlap. Inspirede by the wav2vec 2.0 masking strategy. + + Default values from the wav2vec paper: https://arxiv.org/abs/2006.11477. + + Args: + data (torch.Tensor): The input tensor of shape (batch_size, seq_len, feature_dim). + self.sampled_rate (float): The probability of each time-step being a candidate for masking. + self.masked_seq_len (int): The length of each masked span starting from the sampled time-step. + + Returns: + torch.Tensor: A boolean mask of shape (batch_size, seq_len) where True + indicates masked positions. + """ + batch_size, n_neurons, offset_length = data.shape + + sampled_rate, masked_seq_len = self._select_masking_params() + + num_masked_starting_points = int(offset_length * sampled_rate) + mask = torch.ones((batch_size, offset_length), + dtype=int, + device=data.device) + for batch_idx in range(batch_size): + # Sample starting points for masking in the current batch + start_indices = torch.randperm( + offset_length, device=data.device)[:num_masked_starting_points] + + # Apply masking spans + for start in start_indices: + end = min(start + masked_seq_len, offset_length) + mask[batch_idx, start:end] = 0 # set masked timesteps to 0 + + return mask.unsqueeze(1).expand(-1, n_neurons, + -1) # Expand to all neurons + + def _check_masking_parameters(self, masking_value: Union[float, List[float], + Tuple[float]]): + """ + The masking values are the parameters for the timeblock masking. + It needs to be a tuple of (sampled_rate, masked_seq_len) + sampled_rate: The probability of each time-step being a candidate for masking. + masked_seq_len: The length of each masked span starting from the sampled time-step. 
+ """ + assert isinstance(masking_value, tuple) and len(masking_value) == 2, ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be a tuple of (sampled_rate, masked_seq_len).") + assert 0.0 < masking_value[0] < 1.0 and isinstance( + masking_value[0], float), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert masking_value[1] > 0 and isinstance(masking_value[1], int), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be an integer greater than 0.") + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + return self.sampled_rate, self.masked_seq_len diff --git a/cebra/data/masking.py b/cebra/data/masking.py new file mode 100644 index 00000000..2a9a5977 --- /dev/null +++ b/cebra/data/masking.py @@ -0,0 +1,86 @@ +import random +from typing import Dict, Optional + +import torch + +import cebra.data.mask as mask + + +class MaskedMixin: + """A mixin class for applying masking to data. + + Note: + This class is designed to be used as a mixin for other classes. + It provides functionality to apply masking to data. + The `set_masks` method should be called to set the masking types + and their corresponding probabilities. + """ + masks = [] # a list of Mask instances + + def set_masks(self, masking: Optional[Dict[str, float]] = None) -> None: + """Set the mask type and probability for the dataset. + + Args: + masking (Dict[str, float]): A dictionary of masking types and their + corresponding required masking values. The keys are the names + of the Mask instances. + + Note: + By default, no masks are applied. + """ + if masking is not None: + for mask_key in masking: + if mask_key in mask.__all__: + cls = getattr(mask, mask_key) + self.masks = [ + m for m in self.masks if not isinstance(m, cls) + ] + self.masks.append(cls(masking[mask_key])) + else: + raise ValueError( + f"Mask type {mask_key} not supported. Supported types are {masking.keys()}" + ) + + def apply_mask(self, + data: torch.Tensor, + chunk_size: int = 1000) -> torch.Tensor: + """Apply masking to the input data. + + Note: + - By default, no masking. Else apply masking on the input data. + - Only one masking type can be applied at a time, but multiple + masking types can be set so that it alternates between them + across iterations. + - Masking is applied to the data in chunks to avoid memory issues. + + Args: + data (torch.Tensor): batch of size (batch_size, num_neurons, offset). + chunk_size (int): Number of rows to process at a time. + + Returns: + torch.Tensor: The masked data. 
+ """ + if data.dim() != 3: + raise ValueError( + f"Data must be a 3D tensor, but got {data.dim()}D tensor.") + if data.dtype != torch.float32: + raise ValueError( + f"Data must be a float32 tensor, but got {data.dtype}.") + + # If masks is empty, return the data as is + if not self.masks: + return data + + sampled_mask = random.choice(self.masks) + mask = sampled_mask.apply_mask(data) + + num_chunks = (data.shape[0] + chunk_size - + 1) // chunk_size # Compute number of chunks + + for i in range(num_chunks): + start, end = i * chunk_size, min((i + 1) * chunk_size, + data.shape[0]) + data[start:end].mul_( + mask[start:end]) # apply mask in-place to save memory + + return data diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index ddcc0fa8..49e9f894 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -26,9 +26,10 @@ import literate_dataclasses as dataclasses import torch +import torch.nn as nn import cebra.data as cebra_data -import cebra.distributions as cebra_distr +import cebra.distributions from cebra.data.datatypes import Batch from cebra.data.datatypes import BatchIndex @@ -38,6 +39,7 @@ "ContinuousMultiSessionDataLoader", "DiscreteMultiSessionDataLoader", "MixedMultiSessionDataLoader", + "UnifiedLoader", ] @@ -104,10 +106,24 @@ def load_batch(self, index: BatchIndex) -> List[Batch]: ) for session_id, session in enumerate(self.iter_sessions()) ] - def configure_for(self, model): - self.offset = model.get_offset() - for session in self.iter_sessions(): - session.configure_for(model) + def configure_for(self, model: "cebra.models.Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. This sets the + :py:attr:`~.Dataset.offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. + """ + for i, session in enumerate(self.iter_sessions()): + if isinstance(model, nn.ModuleList): + if len(model) != self.num_sessions: + raise ValueError( + f"The model must have {self.num_sessions} sessions, but got {len(model)}." + ) + session.configure_for(model[i]) + else: + session.configure_for(model) @dataclasses.dataclass @@ -119,12 +135,10 @@ class MultiSessionLoader(cebra_data.Loader): dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`. """ - time_offset: int = dataclasses.field(default=10) - def __post_init__(self): super().__post_init__() - self.sampler = cebra_distr.MultisessionSampler(self.dataset, - self.time_offset) + self.sampler = cebra.distributions.MultisessionSampler( + self.dataset, self.time_offset) def get_indices(self, num_samples: int) -> List[BatchIndex]: ref_idx = self.sampler.sample_prior(self.batch_size) @@ -149,7 +163,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader): """Contrastive learning conditioned on a continuous behavior variable.""" conditional: str = "time_delta" - time_offset: int = dataclasses.field(default=10) @property def index(self): @@ -163,7 +176,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader): # Overwrite sampler with the discrete implementation # Generalize MultisessionSampler to avoid doing this? 
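
Editor's note: a minimal usage sketch of the ``MaskedMixin`` introduced in this patch, assuming ``cebra.data.masking`` is importable as added above. Keys passed to ``set_masks`` must match the class names exported by ``cebra.data.mask``.

    import torch
    from cebra.data.masking import MaskedMixin

    mixin = MaskedMixin()
    # Keys are mask class names from cebra.data.mask; values are the masking
    # parameters: a ratio, a list/tuple of ratios, or (sampled_rate,
    # masked_seq_len) for TimeBlockMask.
    mixin.set_masks({"RandomNeuronMask": 0.3, "TimeBlockMask": (0.1, 3)})

    data = torch.rand(8, 120, 10)            # (batch_size, n_neurons, offset), float32
    masked = mixin.apply_mask(data.clone())  # one mask type is drawn per call;
                                             # note the in-place multiplication.
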
def __post_init__(self): - self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset) + self.sampler = cebra.distributions.DiscreteMultisessionSampler( + self.dataset) @property def index(self): @@ -173,3 +187,64 @@ def index(self): @dataclasses.dataclass class MixedMultiSessionDataLoader(MultiSessionLoader): pass + + +@dataclasses.dataclass +class UnifiedLoader(ContinuousMultiSessionDataLoader): + """Dataloader for multi-session datasets, considered as a single session. + + This class is used in pair with :py:class:`cebra.data.datasets.UnifiedDataset` + to sample from each session and train a single model on them, even if sessions have a + different number of neurons. + + To sample the reference and negative samples, a target session is randomly selected. Indexes + are unformly sampled in that first session. Then, indexes in the other sessions are samples + conditionally to the first session indexes, so that their corresponding auxiliary variables + are close. For the positive samples, they are sampled conditionally to the reference samples, + in their corresponding session only. + + Then, the ref/pos/neg samples are concatenated respectively, along the neurons axis (takes place + in the :py:class:`cebra.data.datasets.UnifiedDataset`). + + """ + + def __post_init__(self): + super().__post_init__() + self.sampler = cebra.distributions.UnifiedSampler( + self.dataset, self.time_offset) + + def get_indices(self, num_samples: int) -> BatchIndex: + """Sample and return the specified number of indices. + + The elements of the returned ``BatchIndex`` will be used to index the + ``dataset`` of this data loader. + + To sample the reference and negative samples, a target session is + randomly selected. Indexes are unformly sampled in that first + session. Then, indexes in the other sessions are samples conditionally + to the first session indexes, so that their corresponding auxiliary + variables are close. For the positive samples, they are sampled + conditionally to the reference samples, in their corresponding session + only. + + Args: + num_samples: The size of each of the reference, positive and + negative samples to sample. + + Returns: + Batch indices for the reference, positive and negative samples. + """ + ref_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.batch_size) + + pos_idx = self.sampler.sample_conditional(ref_idx) + + ref_idx = torch.from_numpy(ref_idx).to(self.device) + neg_idx = torch.from_numpy(neg_idx).to(self.device) + pos_idx = torch.from_numpy(pos_idx).to(self.device) + + return BatchIndex( + reference=ref_idx, + positive=pos_idx, + negative=neg_idx, + ) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index 31d9b9d7..6aaed3d2 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -64,9 +64,9 @@ def __len__(self): def load_batch(self, index: BatchIndex) -> Batch: """Return the data at the specified index location.""" return Batch( - positive=self[index.positive], - negative=self[index.negative], - reference=self[index.reference], + positive=self.apply_mask(self[index.positive]), + negative=self.apply_mask(self[index.negative]), + reference=self.apply_mask(self[index.reference]), ) @@ -189,7 +189,6 @@ class ContinuousDataLoader(cebra_data.Loader): and become equivalent to time contrastive learning. 
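
Editor's note: a rough end-to-end sketch of how ``UnifiedLoader`` pairs with ``UnifiedDataset``, mirroring the ``UnifiedSampler`` doctest further below; the ``num_steps``/``batch_size`` fields are assumed to follow the base ``Loader``.

    import torch
    import cebra.data
    import cebra.models
    import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset

    # Two sessions with 30 and 50 neurons, each with a continuous behavior index.
    dataset = cebra.data.UnifiedDataset(
        cebra_sklearn_dataset.SklearnDataset(torch.rand(100, 30), (torch.rand(100),)),
        cebra_sklearn_dataset.SklearnDataset(torch.rand(100, 50), (torch.rand(100),)),
    )
    model = cebra.models.init("offset1-model",
                              num_neurons=dataset.input_dimension,  # 30 + 50 = 80
                              num_units=32, num_output=8)
    dataset.configure_for(model)

    loader = cebra.data.UnifiedLoader(dataset, num_steps=10, batch_size=32)
    batch = dataset.load_batch(loader.get_indices(loader.batch_size))
    # batch.reference / .positive / .negative are concatenated along the neuron axis.
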
""", ) - time_offset: int = dataclasses.field(default=10) delta: float = dataclasses.field(default=0.1) def __post_init__(self): @@ -278,7 +277,6 @@ class MixedDataLoader(cebra_data.Loader): """ conditional: str = dataclasses.field(default="time_delta") - time_offset: int = dataclasses.field(default=10) @property def dindex(self): diff --git a/cebra/datasets/__init__.py b/cebra/datasets/__init__.py index 5716e399..7a187489 100644 --- a/cebra/datasets/__init__.py +++ b/cebra/datasets/__init__.py @@ -97,6 +97,8 @@ def get_datapath(path: str = None) -> str: from cebra.datasets.hippocampus import * from cebra.datasets.monkey_reaching import * from cebra.datasets.synthetic_data import * + from cebra.datasets.perich import * + from cebra.datasets.nlb import * except ModuleNotFoundError as e: warnings.warn(f"Could not initialize one or more datasets: {e}. " f"For using the datasets, consider installing the " diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index 90ba5367..df1f75fd 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -32,7 +32,8 @@ import cebra.io from cebra.datasets import register -_DEFAULT_NUM_TIMEPOINTS = 100000 +_DEFAULT_NUM_TIMEPOINTS = 1_000 +NUMS_NEURAL = [3, 4, 5] class DemoDataset(cebra.data.SingleSessionDataset): @@ -117,7 +118,7 @@ class MultiDiscrete(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=NUMS_NEURAL, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): super().__init__(*[ @@ -131,7 +132,7 @@ class MultiContinuous(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=NUMS_NEURAL, num_behavior=5, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): @@ -146,8 +147,26 @@ def __init__( # @register("demo-mixed-multisession") class MultiMixed(cebra.data.DatasetCollection): - def __init__(self, nums_neural=[3, 4, 5], num_behavior=5): + def __init__(self, nums_neural=NUMS_NEURAL, num_behavior=5): super().__init__(*[ DemoDatasetMixed(_DEFAULT_NUM_TIMEPOINTS, num_neural, num_behavior) for num_neural in nums_neural ]) + + +@register("demo-continuous-unified") +class DemoDatasetUnified(cebra.data.UnifiedDataset): + + def __init__( + self, + nums_neural=NUMS_NEURAL, + num_behavior=5, + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): + super().__init__(*[ + DemoDatasetContinuous(num_timepoints, num_neural, num_behavior) + for num_neural in nums_neural + ]) + + self.num_timepoints = num_timepoints + self.nums_neural = nums_neural diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 647044f2..1e0c48d4 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -21,7 +21,11 @@ # """Continuous variable multi-session sampling.""" +import random +from typing import Optional + import numpy as np +import numpy.typing as npt import torch import cebra.distributions as cebra_distr @@ -383,3 +387,202 @@ def __getitem__(self, pos_idx): for i in range(self.num_sessions): pos_samples[i] = self.data[i][pos_idx[i]] return pos_samples + + +class UnifiedSampler(MultisessionSampler): + """Multi-session sampling, considering them as a single session. + + Align embeddings across multiple sessions, using a set of + auxiliary variables, so that the samples in the different sessions + are sampled together based on how the corresponding auxiliary + variables are close from each other. + + Then, the reference, positive and negative can be concatenated on their + neurons axis to train a single model for all sessions. 
+ + Example: + >>> import cebra.distributions.multisession as cebra_distributions_multisession + >>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset + >>> import cebra.data + >>> import torch + >>> from torch import nn + >>> # Multisession training: one model per dataset (different input dimensions) + >>> session1 = torch.rand(100, 30) + >>> session2 = torch.rand(100, 50) + >>> index1 = torch.rand(100) + >>> index2 = torch.rand(100) + >>> num_features = 8 + >>> dataset = cebra.data.UnifiedDataset( + ... cebra_sklearn_dataset.SklearnDataset(session1, (index1, )), + ... cebra_sklearn_dataset.SklearnDataset(session2, (index2, ))) + >>> model = cebra.models.init( + ... name="offset1-model", + ... num_neurons=dataset.input_dimension, + ... num_units=32, + ... num_output=num_features, + ... ).to("cpu") + >>> sampler = cebra_distributions_multisession.UnifiedSampler(dataset, time_offset=10) + + >>> # ref and pos samples from all datasets + >>> ref = sampler.sample_prior(100) + >>> pos = sampler.sample_conditional(ref) + >>> ref = torch.LongTensor(ref) + >>> pos = torch.LongTensor(pos) + >>> loss = (ref - pos)**2 + + Note: + This function does currently not support explicitly selected + discrete indices. They should be added as dimensions to the + continuous index. More weight can be added to the discrete + dimensions by using larger values in one-hot coding. + + """ + + def sample_all_uniform_prior(self, + num_samples: int) -> npt.NDArray[np.int64]: + """Returns uniformly sampled index for all sessions of the dataset. + + Args: + num_samples: Number of samples to sample in each session. + + Returns: + ``(N, num_samples)`` with ``N`` the number of sessions. Array of + samples, uniformly picked for each session. + """ + return super().sample_prior(num_samples=num_samples) + + def sample_prior(self, + num_samples: int, + session_id: Optional[int] = None) -> npt.NDArray[np.int64]: + """Return uniformly sampled indices for all sessions. + + First, the reference indexes in a reference session are uniformly sampled. + Then the reference indexes for the other sessions are sampled so that their + corresponding auxiliary variables are close to the reference indexes of the + reference session. + + Args: + num_samples: Number of samples to pick. + session_id: ID of the session to use as the reference session. If ``None``, + the session is randomly selected. + + Returns: + A :py:func:`numpy.array` containing the idx of the reference samples for all + sessions. + """ + + # Randomly pick the reference session + if session_id is None: + session_id = random.choice(list(range(self.num_sessions))) + + # Sample prior for all sessions + idx = self.sample_all_uniform_prior(num_samples=num_samples) + # Keep the idx for the reference session only + idx = torch.from_numpy(idx[session_id]) + + # Sample the references indexes in other sessions, based on their distance to the + # reference idx in the reference session. + return self.sample_all_sessions(idx, session_id).cpu().numpy() + + def _get_query(self, + reference_idx: torch.Tensor, + session_id: int, + aligned: bool = False) -> torch.Tensor: + """ + + Args: + aligned: If True, no time difference is added to the query. 
+ """ + cum_idx = reference_idx + self.lengths[session_id] + if aligned: + query = self.all_data[cum_idx] + else: + diff_idx = torch.randint(len(self.time_difference), + (len(reference_idx),)) + query = self.all_data[cum_idx] + self.time_difference[diff_idx] + return torch.from_numpy(query).to(_device) + + def sample_all_sessions(self, ref_idx: torch.Tensor, + session_id: int) -> torch.Tensor: + """Sample sessions based on a reference session. + + Reference samples for the ``(session_id)``th session were first sampled uniformly, as in + the py:class:`~.MultisessionSampler`. Then, reference samples for the other sessions + are sampled so that they are as close as the corresponding auxiliary variables in + the reference session. + + Note: similar to ``sample_condiditonal`` but at the level of the sessions, sampling ref idx in each + session so that they are close to the ref idx in the reference session (``session_id``th session). + + Args: + ref_idx: Uniformly sampled ``idx`` for the reference session, ``(num_samples, )``, values + can be in ``[0, len(get_session[session_id])]``. + session_id: Session ID of the reference session, whose ``idx`` are present in ``ref_idx``. + + Returns: + The prior for all sessions, creating a "pseudo-animal", where ``idx`` sampled in different + sessions correspond to points in the recordings where the auxiliary variables are similar. + + """ + # Get the continuous data corresponding to the idx + # all_data: (sum(self.session_lengths), ) + # ref_idx: (num_samples, ), values in [O, len(get_session[session_id])] + # self.lengths: (num_sessions, ), cumsum of the length of each session, providing the first + # element of a session in self.all_data. + # cum_ref_idx: (num_samples, ), values of ref_idx, switched to correspond to the indexes in + # of session_id, in the flatten array self.all_data. + all_idx = torch.zeros(self.num_sessions, len(ref_idx), + device=_device).long() + query = self._get_query( + reference_idx=ref_idx, session_id=session_id, + aligned=True) # same query for all + no time diff added + + for i in range(self.num_sessions): + # except for the session_id provided + if i == session_id: + continue + # different query for each. more robust to variance. + #query = self._get_query(reference_idx=ref_idx, + # session_id=session_id, + # aligned=False) + + # get the idx of the datapoint that is the closest to the query + all_idx[i] = self.index[i].search( + query) # search in the whole dataset + + # all_idx[i] = self.index[i].search_or_mask( + # query, threshold=self.distance_threshold[i]) + + all_idx[session_id] = ref_idx + return all_idx + + def sample_conditional( + self, reference_idx: npt.NDArray[np.int64]) -> torch.Tensor: + """Sample from the conditional distribution. + + Contrary to the :py:class:`MultisessionSampler`, conditional distribution + is sampled so that the samples match the reference samples. They are sampled + from the same session as each reference idx only, rather than across all + sessions. + + Args: + reference_idx: Reference indices, with dimension ``(session, batch)``. + + Returns: + Positive indices, which will be grouped by + session and match the reference indices. + Returned shape is ``(session, batch)``. 
+ + """ + + cond_idx = torch.zeros((reference_idx.shape[0], reference_idx.shape[1]), + dtype=torch.int, + device=_device).long() + + for session_id in range(self.num_sessions): + query = self._get_query(reference_idx=reference_idx[session_id], + session_id=session_id) + + cond_idx[session_id] = self.index[session_id].search(query) + + return cond_idx.cpu().numpy() diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 5fb267ac..b44d7601 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -51,6 +51,7 @@ np.dtypes.Float64DType, np.dtypes.Int64DType ] + def check_version(estimator): # NOTE(stes): required as a check for the old way of specifying tags # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 @@ -76,7 +77,6 @@ def _safe_torch_load(filename, weights_only, **kwargs): return checkpoint - def _init_loader( is_cont: bool, is_disc: bool, @@ -129,7 +129,7 @@ def _init_loader( (not is_cont, not is_disc, is_multi), ] if any(all(combination) for combination in incompatible_combinations): - raise ValueError(f"Invalid index combination.\n" + raise ValueError("Invalid index combination.\n" f"Continuous: {is_cont},\n" f"Discrete: {is_disc},\n" f"Hybrid training: {is_hybrid},\n" @@ -293,7 +293,7 @@ def _require_arg(key): "single-session", ) - error_message = (f"Invalid index combination.\n" + error_message = ("Invalid index combination.\n" f"Continuous: {is_cont},\n" f"Discrete: {is_disc},\n" f"Hybrid training: {is_hybrid},\n" @@ -340,7 +340,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": if missing_keys: raise ValueError( f"Missing keys in data dictionary: {', '.join(missing_keys)}. " - f"You can try loading the CEBRA model with the torch backend.") + "You can try loading the CEBRA model with the torch backend.") args, state, state_dict = cebra_info['args'], cebra_info[ 'state'], cebra_info['state_dict'] @@ -497,6 +497,8 @@ class CEBRA(TransformerMixin, BaseEstimator): optimizer documentation in :py:mod:`torch.optim` for further information on how to format the arguments. |Default:| ``(('betas', (0.9, 0.999)), ('eps', 1e-08), ('weight_decay', 0), ('amsgrad', False))`` + masking_kwargs (dict): + TODO(celia) Example: @@ -570,6 +572,8 @@ def __init__( ("weight_decay", 0), ("amsgrad", False), ), + masking_kwargs: Dict[str, Union[float, List[float], Tuple[float, + ...]]] = None, ): self.__dict__.update(locals()) @@ -656,12 +660,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]): # TODO(celia): to make it work for multiple set of index. 
For now, y should be a tuple of one list only if isinstance(y, tuple) and len(y) > 1: raise NotImplementedError( - f"Support for multiple set of index is not implemented in multissesion training, " + "Support for multiple set of index is not implemented in multissesion training, " f"got {len(y)} sets of indexes.") if not _are_sessions_equal(X, y): raise ValueError( - f"Invalid number of sessions: number of sessions in X and y need to match, " + "Invalid number of sessions: number of sessions in X and y need to match, " f"got X:{len(X)} and y:{[len(y_i) for y_i in y]}.") for session in range(len(X)): @@ -685,8 +689,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]): else: if not _are_sessions_equal(X, y): raise ValueError( - f"Invalid number of samples or labels sessions: provide one session for single-session training, " - f"and make sure the number of samples in X and y need match, " + "Invalid number of samples or labels sessions: provide one session for single-session training, " + "and make sure the number of samples in X and y match, " f"got {len(X)} and {[len(y_i) for y_i in y]}.") is_multisession = False dataset = _get_dataset(X, y) @@ -813,8 +817,6 @@ def _configure_for_all( "receptive fields/offsets larger than 1 via the sklearn API. " "Please use a different model, or revert to the pytorch " "API for training.") - - d.configure_for(model[n]) else: if not isinstance(model, cebra.models.ConvolutionalModelMixin): if len(model.get_offset()) > 1: @@ -824,37 +826,13 @@ def _configure_for_all( "Please use a different model, or revert to the pytorch " "API for training.") - dataset.configure_for(model) + dataset.configure_for(model) def _select_model(self, X: Union[npt.NDArray, torch.Tensor], session_id: int): - # Choose the model and get its corresponding offset - if self.num_sessions is not None: # multisession implementation - if session_id is None: - raise RuntimeError( - "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." - ) - if session_id >= self.num_sessions or session_id < 0: - raise RuntimeError( - f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." - ) - if self.n_features_[session_id] != X.shape[1]: - raise ValueError( - f"Invalid input shape: model for session {session_id} requires an input of shape" - f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})." - ) - - model = self.model_[session_id] - model.to(self.device_) - else: # single session - if session_id is not None and session_id > 0: - raise RuntimeError( - f"Invalid session_id {session_id}: single session models only takes an optional null session_id." - ) - model = self.model_ - - offset = model.get_offset() - return model, offset + if isinstance(X, np.ndarray): + X = torch.from_numpy(X) + return self.solver_._select_model(X, session_id=session_id) def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): """Check that the input labels are compatible with the labels used to fit the model. 
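
Editor's note: since the ``masking_kwargs`` docstring above is still marked TODO, the following is only a tentative sketch of how the argument reaches ``dataset.set_masks()`` in ``_prepare_fit``; the dict format is assumed to follow ``MaskedMixin.set_masks``.

    import numpy as np
    import cebra

    cebra_model = cebra.CEBRA(
        model_architecture="offset10-model",
        batch_size=128,
        max_iterations=10,
        # Forwarded to dataset.set_masks(); keys are mask class names from
        # cebra.data.mask, values the corresponding masking parameters.
        masking_kwargs={"RandomNeuronMask": 0.3, "RandomTimestepMask": 0.3},
    )
    X = np.random.uniform(size=(1000, 20)).astype("float32")
    cebra_model.fit(X)
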
@@ -876,7 +854,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): # Check that same number of index if len(self.label_types_) != n_idx: raise ValueError( - f"Number of index invalid: labels must have the same number of index as for fitting," + "Number of index invalid: labels must have the same number of index as for fitting," f"expects {len(self.label_types_)}, got {n_idx} idx.") for i in range(len(self.label_types_)): # for each index @@ -889,12 +867,12 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): > 1): # is there more than one feature in the index if label_types_idx[1][1] != y[i].shape[1]: raise ValueError( - f"Labels invalid: must have the same number of features as the ones used for fitting," + "Labels invalid: must have the same number of features as the ones used for fitting," f"expects {label_types_idx[1]}, got {y[i].shape}.") if label_types_idx[0] != y[i].dtype: raise ValueError( - f"Labels invalid: must have the same type of features as the ones used for fitting," + "Labels invalid: must have the same type of features as the ones used for fitting," f"expects {label_types_idx[0]}, got {y[i].dtype}.") def _prepare_fit( @@ -922,6 +900,8 @@ def _prepare_fit( self.offset_ = self._compute_offset() dataset, is_multisession = self._prepare_data(X, y) + dataset.set_masks(self.masking_kwargs) + loader, solver_name = self._prepare_loader( dataset, max_iterations=self.max_iterations, @@ -1081,14 +1061,13 @@ def _partial_fit( # Save variables of interest as semi-private attributes self.model_ = model - self.n_features_ = ([ - loader.dataset.get_input_dimension(session_id) - for session_id in range(loader.dataset.num_sessions) - ] if is_multisession else loader.dataset.input_dimension) + + self.n_features_ = solver.n_features + self.num_sessions_ = solver.num_sessions if hasattr( + solver, "num_sessions") else None self.solver_ = solver self.n_features_in_ = ([model[n].num_input for n in range(len(model))] if is_multisession else model.num_input) - self.num_sessions_ = loader.dataset.num_sessions if is_multisession else None return self @@ -1236,11 +1215,13 @@ def fit( def transform(self, X: Union[npt.NDArray, torch.Tensor], + batch_size: Optional[int] = None, session_id: Optional[int] = None) -> npt.NDArray: """Transform an input sequence and return the embedding. Args: X: A numpy array or torch tensor of size ``time x dimension``. + batch_size: session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for multisession, set to ``None`` for single session. 
@@ -1255,37 +1236,28 @@ def transform(self, >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(dataset) CEBRA(max_iterations=10) - >>> embedding = cebra_model.transform(dataset) + >>> embedding = cebra_model.transform(dataset, batch_size=200) """ - sklearn_utils_validation.check_is_fitted(self, "n_features_") - model, offset = self._select_model(X, session_id) + self.solver_._check_is_session_id_valid(session_id=session_id) - # Input validation - X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) - input_dtype = X.dtype + if torch.is_tensor(X): + X = X.detach().cpu() - with torch.no_grad(): - model.eval() - - if self.pad_before_transform: - X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), - mode="edge") - X = torch.from_numpy(X).float().to(self.device_) + X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - X = X.transpose(1, 0).unsqueeze(0) - output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) - else: - # Standard evaluation, (T, C, dt) - output = model(X).cpu().numpy() + if isinstance(X, np.ndarray): + X = torch.from_numpy(X) - if input_dtype == "float64": - return output.astype(input_dtype) + with torch.no_grad(): + output = self.solver_.transform( + inputs=X, + pad_before_transform=self.pad_before_transform, + session_id=session_id, + batch_size=batch_size) - return output + return output.detach().cpu().numpy() def fit_transform( self, @@ -1501,6 +1473,11 @@ def load(cls, else: cebra_ = _check_type_checkpoint(checkpoint) + n_features = cebra_.n_features_ + cebra_.solver_.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) + return cebra_ def to(self, device: Union[str, torch.device]): diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index 0af44ecb..d8fd791d 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -83,7 +83,8 @@ def infonce_loss( f"got {len(y[0])} sessions.") model, _ = cebra_model._select_model( - X, session_id) # check session_id validity and corresponding model + X, session_id=session_id + ) # check session_id validity and corresponding model cebra_model._check_labels_types(y, session_id=session_id) dataset, is_multisession = cebra_model._prepare_data(X, y) # single session diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index d9bb3083..be6f54ce 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -92,7 +92,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: X, accept_sparse=False, accept_large_sparse=False, - dtype=("float16", "float32", "float64"), + # NOTE(celia): remove float16 because F.pad does not allow float16. + dtype=("float32", "float64"), order=None, copy=False, ensure_2d=True, diff --git a/cebra/models/decoders.py b/cebra/models/decoders.py new file mode 100644 index 00000000..ec7c3fca --- /dev/null +++ b/cebra/models/decoders.py @@ -0,0 +1,38 @@ +import torch.nn as nn + +from cebra.models import register + + +@register("one-layer-mlp-decoder") +class SingleLayerDecoder(nn.Module): + """Supervised module to predict behaviors. + + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). 
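
Editor's note: the decoders added in this new file (defined just below) are plain ``nn.Module``s; a small sketch of training one of them on frozen CEBRA embeddings to predict 2D velocity, assuming the module is importable as ``cebra.models.decoders``. All data here is synthetic.

    import torch
    from cebra.models.decoders import TwoLayersDecoder

    embedding = torch.rand(500, 8)   # e.g. output of a fitted CEBRA model
    velocity = torch.rand(500, 2)    # x/y velocity targets

    decoder = TwoLayersDecoder(input_dim=8, output_dim=2)
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-3)

    for _ in range(100):
        optimizer.zero_grad()
        loss = torch.nn.functional.mse_loss(decoder(embedding), velocity)
        loss.backward()
        optimizer.step()
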
+ """ + + def __init__(self, input_dim, output_dim=2): + super(SingleLayerDecoder, self).__init__() + self.fc = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.fc(x) + + +@register("two-layers-mlp-decoder") +class TwoLayersDecoder(nn.Module): + """Supervised module to predict behaviors. + + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). + """ + + def __init__(self, input_dim, output_dim=2): + super(TwoLayersDecoder, self).__init__() + self.fc = nn.Sequential(nn.Linear(input_dim, 32), nn.GELU(), + nn.Linear(32, output_dim)) + + def forward(self, x): + return self.fc(x) diff --git a/cebra/models/model.py b/cebra/models/model.py index a74b0229..33cd2782 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -32,6 +32,8 @@ from cebra.models import parametrize from cebra.models import register +DROPOUT = 0.1 + def _check_torch_version(raise_error=False): current_version = tuple( @@ -224,6 +226,12 @@ def __init__(self, # the self.net self.normalize = normalize + def _make_layers(self, num_units, num_layers, kernel_size=3): + return [ + cebra_layers._Skip(nn.Conv1d(num_units, num_units, kernel_size), + nn.GELU()) for _ in range(num_layers) + ] + def forward(self, inp): """Compute the embedding given the input signal. @@ -266,9 +274,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -529,9 +535,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): cebra_layers._MeanAndConv(num_neurons, num_units, 4, stride=2), nn.Conv1d(num_neurons + num_units, num_units, 3, stride=2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -676,22 +680,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 
3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -718,24 +707,9 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): ) super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=DROPOUT), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -769,9 +743,9 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): ) super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=DROPOUT), nn.GELU(), - *self._make_layers(num_units, 0.1, 16), + *self._make_layers(num_units, p=DROPOUT, n=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -783,6 +757,54 @@ def get_offset(self) -> cebra.data.datatypes.Offset: return cebra.data.Offset(18, 18) +@register("offset40-model") +class Offset40(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 40 samples receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 18), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(20, 20) + + +@register("offset50-model") +class Offset50(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 23), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(25, 25) + + @register("offset15-model") class Offset15Model(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 15 sample receptive field.""" @@ -795,12 +817,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=6), nn.Conv1d(num_units, num_output, 2), num_input=num_neurons, num_output=num_output, @@ -824,14 +841,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=8), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -855,9 +865,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=False): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), nn.Tanh(), # Added tanh activation function num_input=num_neurons, diff --git a/cebra/solver/__init__.py b/cebra/solver/__init__.py index 965c16c8..8bc63a42 100644 --- a/cebra/solver/__init__.py +++ b/cebra/solver/__init__.py @@ -42,5 +42,6 @@ from cebra.solver.schedulers import * from cebra.solver.single_session import * from cebra.solver.supervised import * +from cebra.solver.unified_session import * cebra.registry.add_docstring(__name__) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 992f4dae..1fa194b2 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -33,10 +33,13 @@ import abc import os import warnings -from typing import Callable, Dict, List, Literal, Optional +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import literate_dataclasses as dataclasses import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torch.utils.data import Dataset import cebra import cebra.data @@ -46,6 +49,239 @@ from cebra.solver.util import ProgressBar +def 
_check_indices(batch_start_idx: int, batch_end_idx: int, + offset: cebra.data.Offset, num_samples: int): + """Check that indices in a batch are in a correct range. + + First and last index must be positive integers, smaller than + the total length of inputs in the dataset, the first index + must be smaller than the last and the batch size cannot be + smaller than the offset of the model. + + Args: + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. + offset: Model offset. + num_samples: Total number of samples in the input. + """ + + if batch_start_idx < 0 or batch_end_idx < 0: + raise ValueError( + f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be positive integers." + ) + if batch_start_idx > batch_end_idx: + raise ValueError( + f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})." + ) + if batch_end_idx > num_samples: + raise ValueError( + f"batch_end_idx ({batch_end_idx}) cannot exceed the length of inputs ({num_samples})." + ) + + batch_size_length = batch_end_idx - batch_start_idx + if batch_size_length <= len(offset): + raise ValueError( + f"The batch has length {batch_size_length} which " + f"is smaller or equal than the required offset length {len(offset)}." + f"Either choose a model with smaller offset or the batch should contain 3 times more samples." + ) + + +def _add_batched_zero_padding(batched_data: torch.Tensor, + offset: cebra.data.Offset, batch_start_idx: int, + batch_end_idx: int, + num_samples: int) -> torch.Tensor: + """Add zero padding to the input data before inference. + + Args: + batched_data: Data to apply the inference on. + offset: Offset of the model to consider when padding. + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the first sample in the batch. + num_samples (int): Total number of samples in the data. + + Returns: + The padded batch. + """ + if batch_start_idx > batch_end_idx: + raise ValueError( + f"batch_start_idx ({batch_start_idx}) cannot be greater than batch_end_idx ({batch_end_idx})." + ) + if batch_start_idx < 0 or batch_end_idx < 0: + raise ValueError( + f"batch_start_idx ({batch_start_idx}) and batch_end_idx ({batch_end_idx}) must be positive integers." + ) + + reversed_dims = torch.arange(batched_data.ndim - 1, -1, -1) + + if batch_start_idx == 0: # First batch + batched_data = F.pad(batched_data.permute(*reversed_dims), + (offset.left, 0), + 'replicate').permute(*reversed_dims) + elif batch_end_idx == num_samples: # Last batch + batched_data = F.pad(batched_data.permute(*reversed_dims), + (0, offset.right - 1), + 'replicate').permute(*reversed_dims) + + return batched_data + + +def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset], + batch_start_idx: int, batch_end_idx: int, + pad_before_transform: bool) -> torch.Tensor: + """Get a batch of samples between the `batch_start_idx` and `batch_end_idx`. + + Args: + inputs: Input data. + offset: Model offset. + batch_start_idx: Index of the first sample in the batch. + batch_end_idx: Index of the last sample in the batch. + pad_before_transform: If True zero-pad the batched data. + + Returns: + The batch. 
+ """ + if offset is None: + raise ValueError("offset cannot be null.") + + if batch_start_idx == 0: # First batch + indices = batch_start_idx, (batch_end_idx + offset.right - 1) + elif batch_end_idx == len(inputs): # Last batch + indices = (batch_start_idx - offset.left), batch_end_idx + else: + indices = batch_start_idx - offset.left, batch_end_idx + offset.right - 1 + + _check_indices(indices[0], indices[1], offset, len(inputs)) + batched_data = inputs[slice(*indices)] + + if pad_before_transform: + batched_data = _add_batched_zero_padding( + batched_data=batched_data, + offset=offset, + batch_start_idx=batch_start_idx, + batch_end_idx=batch_end_idx, + num_samples=len(inputs)) + + return batched_data + + +def _inference_transform(model: cebra.models.Model, + inputs: torch.Tensor) -> torch.Tensor: + """Compute the embedding on the inputs using the model provided. + + Args: + model: Model to use for inference. + inputs: Data. + + Returns: + The embedding. + """ + inputs = inputs.float().to(next(model.parameters()).device) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + inputs = inputs.transpose(1, 0).unsqueeze(0) + output = model(inputs).squeeze(0).transpose(1, 0) + else: + output = model(inputs) + return output + + +def _not_batched_transform( + model: cebra.models.Model, + inputs: torch.Tensor, + pad_before_transform: bool, + offset: cebra.data.datatypes.Offset, +) -> torch.Tensor: + """Compute the embedding. + + Args: + model: The model to use for inference. + inputs: Input data. + pad_before_transform: If True, the input data is zero padded before inference. + offset: Model offset. + + Returns: + torch.Tensor: The (potentially) padded data. + + Raises: + ValueError: If add_padding is True and offset is not provided. + """ + if pad_before_transform: + inputs = F.pad(inputs.T, (offset.left, offset.right - 1), 'replicate').T + output = _inference_transform(model, inputs) + return output + + +def _batched_transform(model: cebra.models.Model, inputs: torch.Tensor, + batch_size: int, pad_before_transform: bool, + offset: cebra.data.datatypes.Offset) -> torch.Tensor: + """Compute the embedding on batched inputs. + + Args: + model: The model to use for inference. + inputs: Input data. + batch_size: Integer corresponding to the batch size. + pad_before_transform: If True, the input data is zero padded before inference. + offset: Model offset. + + Returns: + The embedding. + """ + + class IndexDataset(Dataset): + + def __init__(self, inputs): + self.inputs = inputs + + def __len__(self): + return len(self.inputs) + + def __getitem__(self, idx): + return idx + + index_dataset = IndexDataset(inputs) + index_dataloader = DataLoader(index_dataset, batch_size=batch_size) + + if len(index_dataloader) < 2: + raise ValueError( + f"Number of batches must be greater than 1, you can use transform " + f"without batching instead, got {len(index_dataloader)}.") + + output = [] + for batch_idx, index_batch in enumerate(index_dataloader): + # NOTE(celia): This is to prevent that adding the offset to the + # penultimate batch for larger offset make the batch_end_idx larger + # than the input length, while we also don't want to drop the last + # samples that do not fit in a complete batch. 
+ if batch_idx == (len(index_dataloader) - 2): + # penultimate batch, last complete batch + last_batch = index_batch + continue + if batch_idx == (len(index_dataloader) - 1): + # last batch, incomplete + index_batch = torch.cat((last_batch, index_batch), dim=0) + + if index_batch[-1] + 1 != len(inputs): + raise ValueError( + f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}." + ) + + # Batch start and end so that `batch_size` size with the last batch including 2 batches + batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1 + batched_data = _get_batch(inputs=inputs, + offset=offset, + batch_start_idx=batch_start_idx, + batch_end_idx=batch_end_idx, + pad_before_transform=pad_before_transform) + + output_batch = _inference_transform(model, batched_data) + output.append(output_batch) + + output = torch.cat(output, dim=0) + return output + + @dataclasses.dataclass class Solver(abc.ABC, cebra.io.HasDevice): """Solver base class. @@ -92,7 +328,7 @@ def state_dict(self) -> dict: the model was trained with. """ - return { + state_dict = { "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "loss": torch.tensor(self.history), @@ -102,6 +338,13 @@ def state_dict(self) -> dict: "log": self.log, } + if hasattr(self, "n_features"): + state_dict["n_features"] = self.n_features + if hasattr(self, "num_sessions"): + state_dict["num_sessions"] = self.num_sessions + + return state_dict + def load_state_dict(self, state_dict: dict, strict: bool = True): """Update the solver state with the given state_dict. @@ -139,18 +382,55 @@ def _get(key): if _contains("log"): self.log = _get("log") + # Not defined if the model was saved before being fitted. + if "n_features" in state_dict: + self.n_features = _get("n_features") + if "num_sessions" in state_dict: + self.num_sessions = _get("num_sessions") + @property def num_parameters(self) -> int: """Total number of parameters in the encoder and criterion.""" return sum(p.numel() for p in self.parameters()) - def parameters(self): - """Iterate over all parameters.""" - for parameter in self.model.parameters(): - yield parameter + @abc.abstractmethod + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters of the model. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. + """ + raise NotImplementedError + + def _compute_features( + self, + batch: cebra.data.Batch, + model: Optional[torch.nn.Module] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the features of the reference, positive and negative samples. + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + model: The model to use for inference. + If not provided, the model of the solver is used. + + Returns: + Tuple of reference, positive and negative features. 
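+        Example:
+            A minimal sketch, assuming ``solver`` is a fitted solver instance and
+            ``batch`` a :py:class:`cebra.data.Batch` (both names are placeholders):
+
+            >>> ref, pos, neg = solver._compute_features(batch)  # doctest: +SKIP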
+ """ + if model is None: + model = self.model - for parameter in self.criterion.parameters(): - yield parameter + batch.to(self.device) + ref = model(batch.reference) + pos = model(batch.positive) + neg = model(batch.negative) + return ref, pos, neg def _get_loader(self, loader): return ProgressBar( @@ -158,6 +438,16 @@ def _get_loader(self, loader): "tqdm" if self.tqdm_on else "off", ) + @abc.abstractmethod + def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + Args: + loader: Loader used to fit the solver. + """ + + raise NotImplementedError + def fit( self, loader: cebra.data.Loader, @@ -185,7 +475,7 @@ def fit( TODO: * Refine the API here. Drop the validation entirely, and implement this via a hook? """ - + self._set_fitted_params(loader) self.to(loader.device) iterator = self._get_loader(loader) @@ -250,13 +540,17 @@ def validation(self, Args: loader: Data loader, which is an iterator over `cebra.data.Batch` instances. Each batch contains reference, positive and negative input samples. - session_id: The session ID, an integer between 0 and the number of sessions in the - multisession model, set to None for single session. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. Returns: Loss averaged over iterations on data batch. """ - assert (session_id is None) or (session_id == 0) + if session_id is not None and session_id != 0: + raise ValueError( + f"session_id should be set to None or 0, got {session_id}") + iterator = self._get_loader(loader) total_loss = Meter() self.model.eval() @@ -285,32 +579,165 @@ def decoding(self, train_loader, valid_loader): ) return decode_metric + def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): + """Check that the inputs can be inferred using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + if isinstance(inputs, list): + raise ValueError( + "Inputs to transform() should be the data for a single session, but received a list." + ) + elif not isinstance(inputs, torch.Tensor): + raise ValueError( + f"Inputs should be a torch.Tensor, not {type(inputs)}.") + + @abc.abstractmethod + def _check_is_session_id_valid(self, session_id: Optional[int] = None): + """Check that the session ID provided is valid for the solver instance. + + Args: + session_id: The session ID to check. + """ + raise NotImplementedError + + def _select_model( + self, inputs: Union[torch.Tensor, + List[torch.Tensor]], session_id: Optional[int] + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). 
+ """ + model = self._get_model(session_id=session_id) + offset = model.get_offset() + + self._check_is_inputs_valid(inputs, session_id=session_id) + return model, offset + + @abc.abstractmethod + def _get_model(self, + session_id: Optional[int] = None) -> cebra.models.Model: + """Get the model to use for inference. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model. + """ + raise NotImplementedError + + def _check_is_fitted(self): + """Check if the model is fitted. + + If the model is fitted, the solver should have a `n_features` attribute. + + Raises: + ValueError: If the model is not fitted. + """ + if not hasattr(self, "n_features"): + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this estimator.") + @torch.no_grad() - def transform(self, inputs: torch.Tensor) -> torch.Tensor: + def transform(self, + inputs: torch.Tensor, + pad_before_transform: Optional[bool] = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = None) -> torch.Tensor: """Compute the embedding. This function by default only applies the ``forward`` function of the given model, after switching it into eval mode. Args: - inputs: The input signal + inputs: The input signal (T, N). + pad_before_transform: If ``False``, no padding is applied to the input + sequence and the output sequence will be smaller than the input + sequence due to the receptive field of the model. If the + input sequence is ``n`` steps long, and a model with receptive + field ``m`` is used, the output sequence would only be + ``n-m+1`` steps long. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + batch_size: If not None, batched inference will not be applied. Returns: The output embedding. - - TODO: - * Remove eval mode """ + self._check_is_fitted() + model, offset = self._select_model(inputs, session_id) - self.model.eval() - return self.model(inputs) + if len(offset) < 2 and pad_before_transform: + pad_before_transform = False + + model.eval() + return self._transform(model=model, + inputs=inputs, + pad_before_transform=pad_before_transform, + offset=offset, + batch_size=batch_size) + + @torch.no_grad() + def _transform(self, model: cebra.models.Model, inputs: torch.Tensor, + pad_before_transform: bool, + offset: cebra.data.datatypes.Offset, + batch_size: Optional[int]) -> torch.Tensor: + """Compute the embedding on the inputs using the model provided. + + Args: + model: Model to use for inference. + inputs: Data. + pad_before_transform: If True zero-pad the batched data. + offset: Offset of the model to consider when padding. + batch_size: If not None, batched inference will not be applied. + + Returns: + The embedding. + """ + if batch_size is not None and inputs.shape[0] > int( + batch_size * 2) and not (isinstance( + self._get_model(0), cebra.models.ResampleModelMixin)): + # NOTE(celia): resampling models are not supported for batched inference. 
+ output = _batched_transform( + model=model, + inputs=inputs, + offset=offset, + batch_size=batch_size, + pad_before_transform=pad_before_transform, + ) + else: + output = _not_batched_transform( + model=model, + inputs=inputs, + offset=offset, + pad_before_transform=pad_before_transform) + return output @abc.abstractmethod def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, return the model outputs. - TODO: make this a public function? - Args: batch: The input data, not necessarily aligned across the batch dimension. This means that ``batch.index`` specifies the map @@ -323,11 +750,12 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """ raise NotImplementedError - def load(self, logdir, filename="checkpoint.pth"): + def load(self, logdir: str, filename: str = "checkpoint.pth"): """Load the experiment from its checkpoint file. Args: - filename (str): Checkpoint name for loading the experiment. + logdir: Logging directory. + filename: Checkpoint name for loading the experiment. """ savepath = os.path.join(logdir, filename) @@ -337,7 +765,12 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) - def save(self, logdir, filename="checkpoint_last.pth"): + n_features = self.n_features + self.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) + + def save(self, logdir: str, filename: str = "checkpoint_last.pth"): """Save the model and optimizer params. Args: @@ -477,3 +910,37 @@ def step(self, batch: cebra.data.Batch) -> dict: time_neg=time_uniform.item(), time_total=time_loss.item(), ) + + +class AuxiliaryVariableSolver(Solver): + + @torch.no_grad() + def transform(self, + inputs: torch.Tensor, + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = None, + use_reference_model: bool = False) -> torch.Tensor: + """Compute the embedding. + This function by default use ``model`` that was trained to encode the positive + and negative samples. To use ``reference_model`` instead of ``model`` + ``use_reference_model`` should be equal ``True``. + Args: + inputs: The input signal + use_reference_model: Flag for using ``reference_model`` + Returns: + The output embedding. + """ + self._check_is_fitted() + model, offset = self._select_model( + inputs, session_id, use_reference_model=use_reference_model) + + if len(offset) < 2 and pad_before_transform: + pad_before_transform = False + + model.eval() + return self._transform(model=model, + inputs=inputs, + pad_before_transform=pad_before_transform, + offset=offset, + batch_size=batch_size) diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index eabce729..eed7fa6b 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -21,6 +21,7 @@ # """Solver implementations for multi-session datasetes.""" +import copy from typing import List, Optional import torch @@ -39,6 +40,25 @@ class MultiSessionSolver(abc_.Solver): _variant_name = "multi-session" + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. 
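+        Example:
+            A sketch, assuming ``solver`` is a fitted multisession solver
+            (placeholder name):
+
+            >>> n_params = sum(p.numel() for p in solver.parameters(session_id=0))  # doctest: +SKIP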
+ """ + self._check_is_session_id_valid(session_id=session_id) + + for parameter in self.model[session_id].parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + def _mix(self, array: torch.Tensor, idx: torch.Tensor) -> torch.Tensor: shape = array.shape n, m = shape[:2] @@ -60,10 +80,11 @@ def _single_model_inference(self, batch: cebra.data.Batch, across the sample dimensions, the output data should be aligned and ``batch.index`` should be set to ``None``. """ - batch.to(self.device) - ref = torch.stack([model(batch.reference)], dim=0) - pos = torch.stack([model(batch.positive)], dim=0) - neg = torch.stack([model(batch.negative)], dim=0) + ref, pos, neg = self._compute_features(batch, model) + + ref = ref.unsqueeze(0) + pos = pos.unsqueeze(0) + neg = neg.unsqueeze(0) pos = self._mix(pos, batch.index_reversed) @@ -112,6 +133,77 @@ def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: negative=neg.view(-1, num_features), ) + def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + In multi session solver, the number of session is set to the number of + sessions in the dataset of the loader and the number of + features is set as a list corresponding to the number of neurons in + each dataset. + + Args: + loader: Loader used to fit the solver. + """ + + self.num_sessions = loader.dataset.num_sessions + self.n_features = [ + loader.dataset.get_input_dimension(session_id) + for session_id in range(loader.dataset.num_sessions) + ] + + def _check_is_inputs_valid(self, inputs: torch.Tensor, + session_id: Optional[int]): + """Check that the inputs can be inferred using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + super()._check_is_inputs_valid(inputs, session_id=session_id) + if self.n_features[session_id] != inputs.shape[1]: + raise ValueError( + f"Invalid input shape: model for session {session_id} requires an input of shape" + f"(n_samples, {self.n_features[session_id]}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid(self, session_id: Optional[int]): + """Check that the session ID provided is valid for the solver instance. + + The session ID must be non-null and between 0 and the number session + in the dataset. + + Args: + session_id: The session ID to check. + """ + if session_id is None: + raise RuntimeError( + "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." + ) + if session_id >= self.num_sessions or session_id < 0: + raise RuntimeError( + f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." + ) + + def _get_model(self, session_id: Optional[int] = None): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. 
+ """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + return self.model[session_id] + def validation(self, loader, session_id: Optional[int] = None): """Compute score of the model on data. @@ -143,13 +235,36 @@ def validation(self, loader, session_id: Optional[int] = None): @register("multi-session-aux") -class MultiSessionAuxVariableSolver(abc_.Solver): +class MultiSessionAuxVariableSolver(MultiSessionSolver, + abc_.AuxiliaryVariableSolver): """Multi session training, contrasting neural data against behavior.""" _variant_name = "multi-session-aux" - reference_model: torch.nn.Module + reference_model: torch.nn.Module = None + + def __post_init__(self): + super().__post_init__() + if self.reference_model is None: + # NOTE(stes): This should work, according to this thread + # https://discuss.pytorch.org/t/can-i-deepcopy-a-model/52192/19 + # and create a true copy of the model. + self.reference_model = copy.deepcopy(self.model) + self.reference_model.to(self.device) + + def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: + """Given batches of input examples, computes the feature representations/embeddings. + + Args: + batches: A list of input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. - def _inference(self, batches): + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + + """ refs = [] poss = [] negs = [] @@ -169,3 +284,24 @@ def _inference(self, batches): positive=pos.view(-1, num_features), negative=neg.view(-1, num_features), ) + + def _get_model(self, + session_id: Optional[int] = None, + use_reference_model: bool = False): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. + """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + if use_reference_model: + model = self.reference_model[session_id] + else: + model = self.model[session_id] + return model diff --git a/cebra/solver/multiobjective.py b/cebra/solver/multiobjective.py index d4aa187d..98587bd7 100644 --- a/cebra/solver/multiobjective.py +++ b/cebra/solver/multiobjective.py @@ -53,6 +53,7 @@ import cebra.data import cebra.io import cebra.models +import cebra.solver.single_session as cebra_solver_single from cebra.solver import register from cebra.solver.base import Solver from cebra.solver.schedulers import Scheduler @@ -154,7 +155,7 @@ def finalize(self): if len(set(self.feature_ranges_tuple)) != len( self.feature_ranges_tuple): raise RuntimeError( - f"Feature ranges are not unique. Please check again and remove the duplicates. " + "Feature ranges are not unique. Please check again and remove the duplicates. 
" f"Feature ranges: {self.feature_ranges_tuple}") print("Creating MultiCriterion") @@ -187,7 +188,7 @@ def _process_info(self, info): @dataclasses.dataclass -class MultiobjectiveSolverBase(Solver): +class MultiobjectiveSolverBase(cebra_solver_single.SingleSessionSolver): feature_ranges: List[slice] = None renormalize: bool = None @@ -209,6 +210,13 @@ def __post_init__(self): renormalize=self.renormalize, ) + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters.""" + super().parameters(session_id=session_id) + + for parameter in self.regularizer.parameters(): + yield parameter + def fit(self, loader: cebra.data.Loader, valid_loader: cebra.data.Loader = None, @@ -241,6 +249,7 @@ def _run_validation(): save_hook(solver=self, step=num_steps) return stats_val + self._set_fitted_params(loader) self.to(loader.device) iterator = self._get_loader(loader, @@ -393,11 +402,14 @@ def validation( logger=None, weights_loss: Optional[List[float]] = None, ): + loader.dataset.configure_for(self.model) + iterator = self._get_loader(loader) + self.model.eval() total_loss = Meter() losses_dict = {} - for _, batch in enumerate(loader): + for _, batch in iterator: predictions = self._inference(batch) losses = self.criterion(predictions) @@ -444,37 +456,6 @@ def validation( self.log.setdefault(("sum_loss_val",), []).append(sum_loss_valid) return stats_val - @torch.no_grad() - def transform(self, inputs: torch.Tensor) -> torch.Tensor: - offset = self.model.get_offset() - self.model.eval() - X = inputs.cpu().numpy() - X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge") - X = torch.from_numpy(X).float().to(self.device) - - if isinstance(self.model.module, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - X = X.transpose(1, 0).unsqueeze(0) - outputs = self.model(X) - - # switch back from (1, C, T) -> (T, C) - if isinstance(outputs, torch.Tensor): - assert outputs.dim() == 3 and outputs.shape[0] == 1 - outputs = outputs.squeeze(0).transpose(1, 0) - elif isinstance(outputs, tuple): - assert all(tensor.dim() == 3 and tensor.shape[0] == 1 - for tensor in outputs) - outputs = ( - output.squeeze(0).transpose(1, 0) for output in outputs) - outputs = tuple(outputs) - else: - raise ValueError("Invalid condition in solver.transform") - else: - # Standard evaluation, (T, C, dt) - outputs = self.model(X) - - return outputs - @register("supervised-solver-xcebra") @dataclasses.dataclass diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index d172fadc..751e0e61 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -22,6 +22,7 @@ """Single session solvers embed a single pair of time series.""" import copy +from typing import Optional, Tuple import literate_dataclasses as dataclasses import torch @@ -38,11 +39,91 @@ class SingleSessionSolver(abc_.Solver): """Single session training with a symmetric encoder. This solver assumes that reference, positive and negative samples - are processed by the same features encoder. + are processed by the same features encoder and that a single session + is provided to that encoder. """ _variant_name = "single-session" + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. 
+ """ + # If session_id is invalid, it doesn't matter, since we are + # using a single session solver. + for parameter in self.model.parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + + def _set_fitted_params(self, loader: cebra.data.Loader): + """Set parameters once the solver is fitted. + + In single session solver, the number of session is set to None and the number of + features is set to the number of neurons in the dataset. + + Args: + loader: Loader used to fit the solver. + """ + #self.num_sessions = None + self.n_features = loader.dataset.input_dimension + + def _check_is_inputs_valid(self, inputs: torch.Tensor, session_id: int): + """Check that the inputs can be inferred using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + super()._check_is_inputs_valid(inputs, session_id=session_id) + if self.n_features != inputs.shape[1]: + raise ValueError( + f"Invalid input shape: model for session {session_id} requires an input of shape" + f"(n_samples, {self.n_features}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid(self, session_id: Optional[int] = None): + """Check that the session ID provided is valid for the solver instance. + + The session ID must be null or equal to 0. + + Args: + session_id: The session ID to check. + """ + + if session_id is not None and session_id > 0: + raise RuntimeError( + f"Invalid session_id {session_id}: single session models only takes an optional null session_id." + ) + + def _get_model(self, session_id: Optional[int] = None): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. + """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + return self.model + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, computes the feature representation/embedding. @@ -56,10 +137,7 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: across the sample dimensions, the output data should be aligned and ``batch.index`` should be set to ``None``. """ - batch.to(self.device) - ref = self.model(batch.reference) - pos = self.model(batch.positive) - neg = self.model(batch.negative) + ref, pos, neg = self._compute_features(batch) return cebra.data.Batch(ref, pos, neg) def get_embedding(self, data: torch.Tensor) -> torch.Tensor: @@ -90,7 +168,8 @@ def get_embedding(self, data: torch.Tensor) -> torch.Tensor: @register("single-session-aux") @dataclasses.dataclass -class SingleSessionAuxVariableSolver(abc_.Solver): +class SingleSessionAuxVariableSolver(SingleSessionSolver, + abc_.AuxiliaryVariableSolver): """Single session training for reference and positive/negative samples. 
This solver processes reference samples with a model different from @@ -117,7 +196,54 @@ def __post_init__(self): self.reference_model = copy.deepcopy(self.model) self.reference_model.to(self.model.device) - def _inference(self, batch): + def _get_model(self, + session_id: Optional[int] = None, + use_reference_model: bool = False): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. + """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + if use_reference_model: + model = self.reference_model[session_id] + else: + model = self.model[session_id] + return model + + def _compute_features( + self, + batch: cebra.data.Batch, + model: Optional[torch.nn.Module] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch.to(self.device) + ref = self.reference_model(batch.reference) + pos = self.model(batch.positive) + neg = self.model(batch.negative) + return ref, pos, neg + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + The reference samples are processed with a different model than the + positive and negative samples. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ batch.to(self.device) ref = self.reference_model(batch.reference) pos = self.model(batch.positive) @@ -127,12 +253,27 @@ def _inference(self, batch): @register("single-session-hybrid") @dataclasses.dataclass -class SingleSessionHybridSolver(abc_.MultiobjectiveSolver): +class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver): """Single session training, contrasting neural data against behavior.""" _variant_name = "single-session-hybrid" def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + The samples are processed with both a time-contrastive module and a + behavior-contrastive module, that are part of the same model. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ batch.to(self.device) behavior_ref = self.model(batch.reference)[0] behavior_pos = self.model(batch.positive[:int(len(batch.positive) // @@ -145,6 +286,21 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: behavior_neg), cebra.data.Batch( time_ref, time_pos, time_neg) + def _get_model(self, session_id: Optional[int] = None): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. 
+ """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + return self.model.module + @register("single-session-full") @dataclasses.dataclass @@ -199,6 +355,18 @@ def get_embedding(self, data): return self.model(data[0].T) def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ outputs = self.get_embedding(self.neural) idc = batch.positive - self.offset.left >= len(outputs) batch.positive[idc] = batch.reference[idc] diff --git a/cebra/solver/unified_session.py b/cebra/solver/unified_session.py new file mode 100644 index 00000000..f4ce138a --- /dev/null +++ b/cebra/solver/unified_session.py @@ -0,0 +1,445 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Solver implementations for unified-session datasets.""" + +from typing import List, Optional, Union + +import literate_dataclasses as dataclasses +import numpy as np +import torch + +import cebra +import cebra.data +import cebra.distributions +import cebra.models +import cebra.solver.base as abc_ +from cebra.solver import register + + +@register("unified-session") +@dataclasses.dataclass +class UnifiedSolver(abc_.Solver): + """Multi session training, considering a single model for all sessions.""" + + _variant_name = "unified-session" + + def parameters(self, session_id: Optional[int] = None): # same as single + """Iterate over all parameters.""" + for parameter in self.model.parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + + def _set_fitted_params(self, loader: cebra.data.Loader): # mix + """Set parameters once the solver is fitted. + + In single session solver, the number of session is set to None and the number of + features is set to the number of neurons in the dataset. + + Args: + loader: Loader used to fit the solver. + """ + self.num_sessions = loader.dataset.num_sessions + self.n_features = loader.dataset.input_dimension + + def _check_is_inputs_valid(self, inputs: Union[torch.Tensor, + List[torch.Tensor]], + session_id: int): + """Check that the inputs can be infered using the selected model. 
+ + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + + if isinstance(inputs, list): + inputs_shape = 0 + for i in range(len(inputs)): + inputs_shape += inputs[i].shape[1] + elif isinstance(inputs, + torch.Tensor): #NOTE(celia): flexible input at training + raise NotImplementedError + else: + raise NotImplementedError + + if self.n_features != inputs_shape: + raise ValueError( + f"Invalid input shape: model requires an input of shape" + f"(n_samples, {self.n_features}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid( + self, + session_id: Optional[int] = None, + ): # same as multi + """Check that the session ID provided is valid for the solver instance. + + The session ID must be non-null and between 0 and the number session in the dataset. + + Args: + session_id: The session ID to check. + """ + + if session_id is None: + raise RuntimeError( + "No session_id provided: unified model requires a session_id as the target session to use to align the sessions." + ) + if session_id >= self.num_sessions or session_id < 0: + raise RuntimeError( + f"Invalid session_id {session_id}: session_id for the current unified model must be between 0 and {self.num_sessions-1}." + ) + + def _get_model(self, session_id: Optional[int] = None): + """Get the model for the given session ID. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model for the given session ID. + """ + self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + return self.model + + def _single_model_inference(self, batch: cebra.data.Batch, + model: torch.nn.Module) -> cebra.data.Batch: + """Given a single batch of input examples, computes the feature representation/embedding. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + model: The model to use for inference. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ + ref, pos, neg = self._compute_features(batch, model) + ref = ref.unsqueeze(0) + pos = pos.unsqueeze(0) + neg = neg.unsqueeze(0) + + num_features = neg.shape[2] + + return cebra.data.Batch( + reference=ref.view(-1, num_features), + positive=pos.view(-1, num_features), + negative=neg.view(-1, num_features), + ) + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given batches of input examples, computes the feature representations/embeddings. + + Args: + batches: A list of input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. 
+ """ + return self._single_model_inference(batch, self.model) + + @torch.no_grad() + def transform(self, + inputs: List[torch.Tensor], + labels: List[torch.Tensor], + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = 512) -> torch.Tensor: + """Compute the embedding for the `session_id`th session of the dataset. + + Note: + Compared to the other :py:class:`cebra.solver.base.Solver`, we need all the sessions of + the dataset to transform the data, as the sampling is across all the sessions. + + Args: + inputs: The input signal for all sessions. + labels: The auxiliary variables to use for sampling. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1. + batch_size: If not None, batched inference will be applied. + + Note: + The ``session_id`` is needed in order to sample the corresponding number of samples and + return an embedding of the expected shape. + + Note: + The batched inference will be required in most cases. Default is set to ``100`` for that reason. + + Returns: + The output embedding for the session corresponding to the provided ID `session_id`. The shape + is (num_samples(session_id), output_dimension)``. + + """ + if not isinstance(inputs, list): + raise ValueError( + f"Inputs to transform() should be a list, not {type(inputs)}.") + + self._check_is_fitted() + + if session_id is None: + raise ValueError("Session ID is required for multi-session models.") + + # Sampling according to session_id required + dataset = cebra.data.UnifiedDataset( + cebra.data.TensorDataset( + inputs[i], continuous=labels[i], offset=cebra.data.Offset(0, 1)) + for i in range(len(inputs))).to(self.device) + + # Only used to sample the reference samples + loader = cebra.data.UnifiedLoader(dataset, num_steps=1) + + # Sampling in batch + refs_data_batch_embeddings = [] + batch_range = range(0, len(dataset.get_session(session_id)), batch_size) + if len(batch_range) < 2: + raise ValueError( + "Not enough data to perform the batched transform. Please provide a larger dataset or reduce the batch size." 
+ ) + for batch_start in batch_range: + batch_end = min(batch_start + batch_size, + len(dataset.get_session(session_id))) + + if batch_start == batch_range[-2]: # one before last batch + last_start = batch_start + continue + if batch_start == batch_range[-1]: # last batch, likely uncomplete + batch_start = last_start + batch_end = len(dataset.get_session(session_id)) + + refs_idx_batch = loader.sampler.sample_all_sessions( + ref_idx=torch.arange(batch_start, batch_end), + session_id=session_id).to(self.device) + + refs_data_batch = torch.cat([ + session[refs_idx_batch[session_id]] + for session_id, session in enumerate(dataset.iter_sessions()) + ], + dim=1).squeeze() + # refs_data_batch_embeddings.append(super().transform( + # torch.cat(refs_data_batch, dim=1).squeeze(), + # pad_before_transform=pad_before_transform)) + + if len(self.model.get_offset()) < 2 and pad_before_transform: + pad_before_transform = False + + self.model.eval() + refs_data_batch_embeddings.append( + self._transform(model=self.model, + inputs=refs_data_batch, + pad_before_transform=pad_before_transform, + offset=self.model.get_offset(), + batch_size=batch_size)) + + return torch.cat(refs_data_batch_embeddings, dim=0) + + @torch.no_grad() + def single_session_transform( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + session_id: Optional[int] = None, + pad_before_transform: bool = True, + padding_mode: str = "zero", + batch_size: Optional[int] = 100) -> torch.Tensor: + """Compute the embedding for the `session_id`th session of the dataset without labels alignement. + + By padding the channels that don't correspond to the {session_id}th session, we can + use a single session solver without behavioral alignment. + + Note: The embedding will not benefit from the behavioral alignment, and consequently + from the information contained in the other sessions. We expect single session encoder + behavioral decoding performances. + + Args: + inputs: The input signal for all sessions. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions. + pad_before_transform: If True, pads the input before applying the transform. + padding_mode: The mode to use for padding. Padding is done in the following + ways, either by padding all the other sessions to the length of the + {session_id}th session, or by resampling all sessions in a random way: + - `time`: pads the inputs that are not infered to the maximum length of + the session and then zeros so that the lenght is the same as the + {session_id}th session length. + - `zero`: pads the inputs that are not infered with zeros so that the + lenght is the same as the {session_id}th session length. + - `poisson`: pads the inputs that are not infered with a poisson distribution + so that the lenght is the same as the {session_id}th session length. + - `random`: pads all sessions with random values sampled from a normal + distribution. + - `random_poisson`: pads all sessions with random values sampled from a + poisson distribution. + + batch_size: If not None, batched inference will be applied. + + Returns: + The output embedding for the session corresponding to the provided ID `session_id`. The shape + is (num_samples(session_id), output_dimension)``. 
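+        Example:
+            A minimal sketch; the session tensors and sizes are hypothetical:
+
+            >>> sessions = [torch.randn(1000, 30), torch.randn(800, 25)]  # doctest: +SKIP
+            >>> emb = solver.single_session_transform(  # doctest: +SKIP
+            ...     inputs=sessions, session_id=0, padding_mode="zero", batch_size=100)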
+ """ + inputs = [session.to(self.device) for session in inputs] + + zero_shape = inputs[session_id].shape[0] + + if padding_mode == "time" or padding_mode == "zero" or padding_mode == "poisson": + for i in range(len(inputs)): + if i != session_id: + if padding_mode == "time": + if inputs[i].shape[0] >= zero_shape: + inputs[i] = inputs[i][:zero_shape] + else: + inputs[i] = torch.cat( + (inputs[i], + torch.zeros( + (zero_shape - inputs[i].shape[0], + inputs[i].shape[1])).to(self.device))) + if padding_mode == "poisson": + inputs[i] = torch.poisson( + torch.ones((zero_shape, inputs[i].shape[1]))) + if padding_mode == "zero": + inputs[i] = torch.zeros( + (zero_shape, inputs[i].shape[1])) + padded_inputs = torch.cat( + [session.to(self.device) for session in inputs], dim=1) + + elif padding_mode == "random_poisson": + padded_inputs = torch.poisson( + torch.ones((zero_shape, self.n_features))) + elif padding_mode == "random": + padded_inputs = torch.normal( + torch.zeros((zero_shape, self.n_features)), + torch.ones((zero_shape, self.n_features))) + + else: + raise ValueError( + f"Invalid padding mode: {padding_mode}. " + "Choose from 'time', 'zero', 'poisson', 'random', or 'random_poisson'." + ) + + # Single session solver transform call + return super().transform(inputs=padded_inputs, + pad_before_transform=pad_before_transform, + batch_size=batch_size) + + @torch.no_grad() + def decoding(self, + train_loader: cebra.data.Loader, + valid_loader: Optional[cebra.data.Loader] = None, + decode: str = "ridge", + max_sessions: int = 5, + max_timesteps: int = 512) -> float: + # Sample a fixed number of sessions to compute the decoding score + # Sample a fixed number of timesteps to compute the decoding score (always the first ones) + if train_loader.dataset.num_sessions > max_sessions: + sessions = np.random.choice(np.arange( + train_loader.dataset.num_sessions), + size=max_sessions, + replace=False) + else: + sessions = np.arange(train_loader.dataset.num_sessions) + + train_scores, valid_scores = [], [] + for i in sessions: + if train_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + train_end = max_timesteps + else: + train_end = -1 + train_x = self.transform([ + train_loader.dataset.get_session(j).neural[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], [ + train_loader.dataset.get_session(j).continuous_index[:train_end] + if train_loader.dataset.get_session(j).continuous_index + is not None else + train_loader.dataset.get_session(j).discrete_index[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + train_y = train_loader.dataset.get_session( + i + ).continuous_index[:train_end] if train_loader.dataset.get_session( + i + ).continuous_index is not None else train_loader.dataset.get_session( + i).discrete_index[:train_end] + + if valid_loader is not None: + if valid_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + valid_end = max_timesteps + else: + valid_end = -1 + valid_x = self.transform([ + valid_loader.dataset.get_session(j).neural[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], [ + valid_loader.dataset.get_session( + j).continuous_index[:valid_end] + if valid_loader.dataset.get_session(j).continuous_index + is not None else valid_loader.dataset.get_session( + j).discrete_index[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + valid_y = valid_loader.dataset.get_session( + i + ).continuous_index[:valid_end] if 
valid_loader.dataset.get_session( + i + ).continuous_index is not None else valid_loader.dataset.get_session( + i).discrete_index[:valid_end] + + if decode == "knn": + decoder = cebra.KNNDecoder() + elif decode == "ridge": + decoder = cebra.RidgeRegressor() + else: + raise NotImplementedError(f"Decoder {decode} not implemented.") + + decoder.fit(train_x.cpu().numpy(), train_y.cpu().numpy()) + train_scores.append( + decoder.score(train_x.cpu().numpy(), + train_y.cpu().numpy())) + + if valid_loader is not None: + valid_scores.append( + decoder.score(valid_x.cpu().numpy(), + valid_y.cpu().numpy())) + + if valid_loader is None: + return np.array(train_scores).mean() + else: + return np.array(train_scores).mean(), np.array(valid_scores).mean() diff --git a/docs/source/conf.py b/docs/source/conf.py index 4147e7c9..c760fa1c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -270,7 +270,9 @@ def get_years(start_year=2021): "demo_notebooks/Demo_dandi_NeuroDataReHack_2023": "_static/thumbnails/dandi_demo_monkey.png", "demo_notebooks/Demo_xCEBRA_RatInABox": - "_static/thumbnails/xCEBRA.png" + "_static/thumbnails/xCEBRA.png", + "demo_notebooks/Demo_hippocampus_unified": + "_static/thumbnails/UnifiedCEBRA.png", } rst_prolog = r""" diff --git a/docs/source/usage.rst b/docs/source/usage.rst index aaf09a25..82e45a0b 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1436,17 +1436,14 @@ gets initialized which also allows the `prior` to be directly parametrized. solver.fit(loader=loader) # 7. Transform Embedding - train_batches = np.lib.stride_tricks.sliding_window_view( - neural_data, neural_model.get_offset().__len__(), axis=0 - ) - x_train_emb = solver.transform( - torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device) - ).to(device) + torch.from_numpy(neural_data).type(torch.FloatTensor).to(device), + pad_before_transform=True, + batch_size=512).to(device) # 8. Plot Embedding cebra.plot_embedding( x_train_emb.cpu(), - discrete_label[neural_model.get_offset().__len__() - 1 :, 0], + discrete_label[:,0], markersize=10, ) diff --git a/tests/_util.py b/tests/_util.py index b4a0e07d..42dd54cb 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -74,3 +74,8 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments): slow_arg, generate_only=True))[0] for slow_arg in slow_arguments ] return parametrize_slow("estimator,check", fast_params, slow_params) + + +def parametrize_device(func): + _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) + return pytest.mark.parametrize("device", _devices)(func) diff --git a/tests/_utils_deprecated.py b/tests/_utils_deprecated.py new file mode 100644 index 00000000..5c533f26 --- /dev/null +++ b/tests/_utils_deprecated.py @@ -0,0 +1,130 @@ +import warnings +from typing import Optional, Union + +import numpy as np +import numpy.typing as npt +import sklearn.utils.validation as sklearn_utils_validation +import torch + +import cebra +import cebra.integrations.sklearn.utils as sklearn_utils +import cebra.models + + +#NOTE: Deprecated: transform is now handled in the solver but the original +# method is kept here for testing. +def cebra_transform_deprecated(cebra_model, + X: Union[npt.NDArray, torch.Tensor], + session_id: Optional[int] = None) -> npt.NDArray: + """Transform an input sequence and return the embedding. + + Args: + cebra_model: The CEBRA model to use for the transform. + X: A numpy array or torch tensor of size ``time x dimension``. 
+ session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for + multisession, set to ``None`` for single session. + + Returns: + A :py:func:`numpy.array` of size ``time x output_dimension``. + + Example: + + >>> import cebra + >>> import numpy as np + >>> dataset = np.random.uniform(0, 1, (1000, 30)) + >>> cebra_model = cebra.CEBRA(max_iterations=10) + >>> cebra_model.fit(dataset) + CEBRA(max_iterations=10) + >>> embedding = cebra_model.transform(dataset) + + """ + warnings.warn( + "The method is deprecated " + "but kept for testing puroposes." + "We recommend using `transform` instead.", + DeprecationWarning, + stacklevel=2) + + sklearn_utils_validation.check_is_fitted(cebra_model, "n_features_") + + if isinstance(X, np.ndarray): + X = torch.from_numpy(X) + + model, offset = cebra_model._select_model(X, session_id) + + # Input validation + X = sklearn_utils.check_input_array(X, min_samples=len(cebra_model.offset_)) + input_dtype = X.dtype + + with torch.no_grad(): + model.eval() + + if cebra_model.pad_before_transform: + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), + mode="edge") + X = torch.from_numpy(X).float().to(cebra_model.device_) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) + else: + # Standard evaluation, (T, C, dt) + output = model(X).cpu().numpy() + + if input_dtype == "float64": + return output.astype(input_dtype) + + return output + + +# NOTE: Deprecated: batched transform can now be performed (more memory efficient) +# using the transform method of the model, and handling padding is implemented +# directly in the base Solver. This method is kept for testing purposes. +@torch.no_grad() +def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver", + inputs: torch.Tensor) -> torch.Tensor: + """Transform the input data using the model. + + Args: + solver: The solver containing the model and device. + inputs: The input data to transform. + + Returns: + The transformed data. + """ + + warnings.warn( + "The method is deprecated " + "but kept for testing puroposes." 
+ "We recommend using `transform` instead.", + DeprecationWarning, + stacklevel=2) + + offset = solver.model.get_offset() + solver.model.eval() + X = inputs.cpu().numpy() + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge") + X = torch.from_numpy(X).float().to(solver.device) + + if isinstance(solver.model.module, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + outputs = solver.model(X) + + # switch back from (1, C, T) -> (T, C) + if isinstance(outputs, torch.Tensor): + assert outputs.dim() == 3 and outputs.shape[0] == 1 + outputs = outputs.squeeze(0).transpose(1, 0) + elif isinstance(outputs, tuple): + assert all(tensor.dim() == 3 and tensor.shape[0] == 1 + for tensor in outputs) + outputs = (output.squeeze(0).transpose(1, 0) for output in outputs) + outputs = tuple(outputs) + else: + raise ValueError("Invalid condition in solver.transform") + else: + # Standard evaluation, (T, C, dt) + outputs = solver.model(X) + + return outputs diff --git a/tests/test_data_masking.py b/tests/test_data_masking.py new file mode 100644 index 00000000..1b4976af --- /dev/null +++ b/tests/test_data_masking.py @@ -0,0 +1,206 @@ +import copy + +import pytest +import torch + +import cebra.data.mask +from cebra.data.masking import MaskedMixin + +#### Tests for Mask class #### + + +@pytest.mark.parametrize("mask", [ + cebra.data.mask.RandomNeuronMask, + cebra.data.mask.RandomTimestepMask, + cebra.data.mask.NeuronBlockMask, +]) +def test_random_mask(mask: cebra.data.mask.Mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = mask(masking_value=0.5) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +def test_timeblock_mask(): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = cebra.data.mask.TimeBlockMask(masking_value=(0.035, 10)) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +#### Tests for MaskedMixin class #### + + +def test_masked_mixin_no_masks(): + mixin = MaskedMixin() + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert torch.equal( + data, + masked_data), "Data should remain unchanged when no masks are applied" + + +@pytest.mark.parametrize( + "mask", ["RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask"]) +def test_masked_mixin_random_mask(mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + mixin = MaskedMixin() + assert mixin.masks == [], "Masks should be empty initially" + + mixin.set_masks({mask: 0.5}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask 
should be of type {mask}" + if isinstance(mixin.masks[0], cebra.data.mask.NeuronBlockMask): + assert mixin.masks[ + 0].mask_prop == 0.5, "Masking value should be set correctly" + else: + assert mixin.masks[ + 0].mask_ratio == 0.5, "Masking value should be set correctly" + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: [0.5, 0.1]}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: (0.3, 0.9, 0.05)}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_apply_mask_with_time_block_mask(): + mixin = MaskedMixin() + + with pytest.raises(AssertionError, match="sampled_rate.*masked_seq_len"): + mixin.set_masks({"TimeBlockMask": 0.2}) + + with pytest.raises(AssertionError, match="(sampled_rate.*masked_seq_len)"): + mixin.set_masks({"TimeBlockMask": [0.2, 10]}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (-2, 10)}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (2, 10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, -10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, 5.5)}) + + mixin.set_masks({"TimeBlockMask": (0.035, 10)}) # Correct usage + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_multiple_masks_mixin(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5, "RandomTimestepMask": 0.3}) + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified when multiple masks are applied" + + masked_data2 = mixin.apply_mask(copy.deepcopy(masked_data)) + assert masked_data2.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data2), "Data should be modified when multiple masks are applied" + assert not torch.equal( + masked_data, masked_data2 + ), "Masked data should be different for different iterations" + + +def test_single_dim_input(): + mixin = MaskedMixin() + 
mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10, 1, 30)) # Single neuron + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified even with a single neuron" + + mixin = MaskedMixin() + mixin.set_masks({"RandomTimestepMask": 0.5}) + data = torch.ones((10, 20, 1)) # Single timestep + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified even with a single timestep" + + +def test_apply_mask_with_invalid_input(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + + with pytest.raises(ValueError, match="Data must be a 3D tensor"): + data = torch.ones( + (10, 20)) # Invalid tensor shape (missing offset dimension) + mixin.apply_mask(data) + + with pytest.raises(ValueError, match="Data must be a float32 tensor"): + data = torch.ones((10, 20, 30), dtype=torch.int32) + mixin.apply_mask(data) + + +def test_apply_mask_with_chunk_size(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10000, 20, 30)) # Large tensor to test chunking + masked_data = mixin.apply_mask(copy.deepcopy(data), chunk_size=1000) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 2b704391..656559bb 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -411,3 +411,16 @@ def test_new_delta_normal_with_multidimensional_index(delta, numerical_check): pytest.skip( "multivariate delta distribution can not accurately sample with the " "given parameters. 
TODO: Add a warning message for these cases.") + + +@pytest.mark.parametrize("time_offset", [1, 5, 10]) +def test_unified_distribution(time_offset): + dataset = cebra_datasets.init("demo-continuous-unified") + sampler = cebra_distr.UnifiedSampler(dataset, time_offset=time_offset) + + num_samples = 5 + sample = sampler.sample_prior(num_samples) + assert sample.shape == (dataset.num_sessions, num_samples) + + positive = sampler.sample_conditional(sample) + assert positive.shape == (dataset.num_sessions, num_samples) diff --git a/tests/test_integration_xcebra.py b/tests/test_integration_xcebra.py index 4e647916..760e26ef 100644 --- a/tests/test_integration_xcebra.py +++ b/tests/test_integration_xcebra.py @@ -1,5 +1,7 @@ import pickle +import _utils_deprecated +import numpy as np import pytest import torch @@ -150,3 +152,39 @@ def test_synthetic_data_training(synthetic_data, device): assert Z2_hat.shape == Z2.shape, f"Incorrect Z2 embedding dimension: {Z2_hat.shape}" assert not torch.isnan(Z1_hat).any(), "NaN values in Z1 embedding" assert not torch.isnan(Z2_hat).any(), "NaN values in Z2 embedding" + + # Test the transform + solver.model.split_outputs = False + transform_embedding = solver.transform(data.neural.to(device)) + assert transform_embedding.shape[ + 1] == n_latents, "Incorrect embedding dimension" + assert not torch.isnan(transform_embedding).any(), "NaN values in embedding" + assert np.allclose(embedding, transform_embedding, rtol=1e-4, atol=1e-4) + + # Test the transform with batching + batched_embedding = solver.transform(data.neural.to(device), batch_size=512) + assert batched_embedding.shape[ + 1] == n_latents, "Incorrect embedding dimension" + assert not torch.isnan(batched_embedding).any(), "NaN values in embedding" + assert np.allclose(embedding, batched_embedding, rtol=1e-4, atol=1e-4) + + assert np.allclose(transform_embedding, + batched_embedding, + rtol=1e-4, + atol=1e-4) + + # Test and compare the previous transform (transform_deprecated) + deprecated_transform_embedding = _utils_deprecated.multiobjective_transform_deprecated( + solver, data.neural.to(device)) + assert np.allclose(embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) + assert np.allclose(transform_embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) + assert np.allclose(batched_embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a7..cb6be9a7 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -19,16 +19,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import _util +import numpy as np import pytest import torch import cebra.data import cebra.io - -def parametrize_device(func): - _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) - return pytest.mark.parametrize("device", _devices)(func) +BATCH_SIZE = 32 +NUMS_NEURAL = [3, 4, 5] class LoadSpeed: @@ -107,7 +107,11 @@ def _assert_dataset_on_correct_device(loader, device): assert hasattr(loader, "dataset") assert hasattr(loader, "device") assert isinstance(loader.dataset, cebra.io.HasDevice) - assert loader.dataset.neural.device.type == device + if isinstance(loader, cebra.data.SingleSessionDataset): + assert loader.dataset.neural.device.type == device + elif isinstance(loader, cebra.data.MultiSessionDataset): + for session in loader.dataset.iter_sessions(): + assert session.neural.device.type == device def test_demo_data(): @@ -130,13 +134,15 @@ def _to_str(val): assert _to_str(first) == _to_str(second) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ ("demo-discrete", cebra.data.DiscreteDataLoader), ("demo-continuous", cebra.data.ContinuousDataLoader), ("demo-mixed", cebra.data.MixedDataLoader), + ("demo-continuous-multisession", cebra.data.MultiSessionLoader), + ("demo-continuous-unified", cebra.data.UnifiedLoader), ], ) def test_device(data_name, loader_initfunc, device): @@ -147,7 +153,7 @@ def test_device(data_name, loader_initfunc, device): other_device = swap.get(device) dataset = RandomDataset(N=100, device=other_device) - loader = loader_initfunc(dataset, num_steps=10, batch_size=32) + loader = loader_initfunc(dataset, num_steps=10, batch_size=BATCH_SIZE) loader.to(device) assert loader.dataset == dataset _assert_device(loader.device, device) @@ -156,7 +162,7 @@ def test_device(data_name, loader_initfunc, device): _assert_device(loader.get_indices(10).reference.device, device) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("prior", ("uniform", "empirical")) def test_discrete(prior, device, benchmark): dataset = RandomDataset(N=100, device=device) @@ -171,7 +177,7 @@ def test_discrete(prior, device, benchmark): benchmark(load_speed) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("conditional", ("time", "time_delta")) def test_continuous(conditional, device, benchmark): dataset = RandomDataset(N=100, d=5, device=device) @@ -199,7 +205,7 @@ def _check_attributes(obj, is_list=False): raise TypeError() -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ @@ -211,7 +217,7 @@ def _check_attributes(obj, is_list=False): def test_singlesession_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=32) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) _assert_dataset_on_correct_device(loader, device) index = loader.get_indices(100) @@ -219,25 +225,33 @@ def test_singlesession_loader(data_name, loader_initfunc, device): for batch in loader: _check_attributes(batch) - assert len(batch.positive) == 32 + assert len(batch.positive) == BATCH_SIZE -def test_multisession_cont_loader(): - data = cebra.datasets.MultiContinuous(nums_neural=[3, 4, 5], - num_behavior=5, - num_timepoints=100) - loader = cebra.data.ContinuousMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) +@_util.parametrize_device +@pytest.mark.parametrize( + "data_name, loader_initfunc", + [ + 
("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader), + ], +) +def test_multisession_loader(data_name, loader_initfunc, device): + data = cebra.datasets.init(data_name) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) # Check the sampler assert hasattr(loader, "sampler") ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - for session in range(3): - assert ref_idx[session].max() < 100 + assert len(ref_idx) == len(NUMS_NEURAL) + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) assert pos_idx is not None @@ -245,6 +259,8 @@ def test_multisession_cont_loader(): assert idx_rev is not None batch = next(iter(loader)) + for i, n_neurons in enumerate(NUMS_NEURAL): + assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10) def _mix(array, idx): shape = array.shape @@ -259,82 +275,70 @@ def _process(batch, feature_dim=1): [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], dim=0).repeat(1, 1, feature_dim) - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) + assert dummy_prediction.shape == (3, BATCH_SIZE, 6) _mix(dummy_prediction, batch[0].index) + index = loader.get_indices(100) + #print(index[0]) + #print(type(index)) + _check_attributes(index, is_list=False) -def test_multisession_disc_loader(): - data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], - num_timepoints=100) - loader = cebra.data.DiscreteMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) - - # Check the sampler - assert hasattr(loader, "sampler") - ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - - # Check sample points are in session length range - for session in range(3): - assert ref_idx[session].max() < loader.sampler.session_lengths[session] - pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) - - assert pos_idx is not None - assert idx is not None - assert idx_rev is not None - - batch = next(iter(loader)) - - def _mix(array, idx): - shape = array.shape - n, m = shape[:2] - mixed = array.reshape(n * m, -1)[idx] - print(mixed.shape, array.shape, idx.shape) - return mixed.reshape(shape) - - def _process(batch, feature_dim=1): - """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor""" - return torch.stack( - [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], - dim=0).repeat(1, 1, feature_dim) - - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) - _mix(dummy_prediction, batch[0].index) + for batch in loader: + _check_attributes(batch, is_list=True) + for session_batch in batch: + assert len(session_batch.positive) == BATCH_SIZE -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", - [('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader), - ("demo-continuous-multisession", - 
cebra.data.ContinuousMultiSessionDataLoader)], + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ], ) -def test_multisession_loader(data_name, loader_initfunc, device): - # TODO change number of timepoints across the sessions - +def test_unified_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) - kwargs = dict(num_steps=10, batch_size=32) - loader = loader_initfunc(data, **kwargs) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) + + # Check the sampler + num_samples = 100 + assert hasattr(loader, "sampler") + ref_idx = loader.sampler.sample_all_uniform_prior(num_samples) + assert ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(ref_idx, np.ndarray) + + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + pos_idx = loader.sampler.sample_conditional(ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + for session in range(len(NUMS_NEURAL)): + ref_idx = torch.from_numpy( + loader.sampler.sample_all_uniform_prior( + num_samples=num_samples)[session]) + assert ref_idx.shape == (num_samples,) + all_ref_idx = loader.sampler.sample_all_sessions(ref_idx=ref_idx, + session_id=session) + assert all_ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(all_ref_idx, torch.Tensor) + for i in range(len(all_ref_idx)): + assert all_ref_idx[i].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + + for i in range(len(all_ref_idx)): + pos_idx = loader.sampler.sample_conditional(all_ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + # Check the batch + batch = next(iter(loader)) + assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) index = loader.get_indices(100) - print(index[0]) - print(type(index)) _check_attributes(index, is_list=False) - - for batch in loader: - _check_attributes(batch, is_list=True) - for session_batch in batch: - assert len(session_batch.positive) == 32 diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 3b9d309b..c3d2095c 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -24,6 +24,7 @@ import warnings import _util +import _utils_deprecated import numpy as np import pkg_resources import pytest @@ -231,7 +232,7 @@ def iterate_models(): ) in itertools.product( [ "offset10-model", "offset10-model-mse", "offset1-model", - "resample-model" + "offset40-model-4x-subsample" ], _DEVICES, ["euclidean", "cosine"], @@ -319,7 +320,7 @@ def test_sklearn(model_architecture, device): model_architecture=model_architecture, time_offsets=10, learning_rate=3e-4, - max_iterations=5, + max_iterations=2, device=device, output_dimension=output_dimension, batch_size=42, @@ -341,6 +342,20 @@ def test_sklearn(model_architecture, device): assert cebra_model.num_sessions is None embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) + + if model_architecture in [ + "offset36-model-cpu", "offset36-model-dropout-cpu", + "offset36-model-more-dropout-cpu", + "offset40-model-4x-subsample-cpu", + "offset20-model-4x-subsample-cpu", "offset36-model-cuda", + "offset36-model-dropout-cuda", "offset36-model-more-dropout-cuda", + 
"offset40-model-4x-subsample-cuda", + "offset20-model-4x-subsample-cuda" + ]: + with pytest.raises(ValueError, match="required.*offset.*length"): + embedding = cebra_model.transform(X, batch_size=10) # continuous behavior contrastive cebra_model.fit(X, y_c1, y_c2) @@ -352,9 +367,17 @@ def test_sklearn(model_architecture, device): assert isinstance(embedding, np.ndarray) embedding = cebra_model.transform(X, session_id=0) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, session_id=0, batch_size=50) + assert isinstance(embedding, np.ndarray) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=2) + with pytest.raises(ValueError, match="batch_size"): + embedding = cebra_model.transform(X, batch_size=0) + with pytest.raises(ValueError, match="batch_size"): + embedding = cebra_model.transform(X, batch_size=-10) with pytest.raises(ValueError, match="Invalid.*labels"): cebra_model.fit(X, [y_c1, y_c1_s2]) with pytest.raises(ValueError, match="Invalid.*samples"): @@ -367,11 +390,15 @@ def test_sklearn(model_architecture, device): cebra_model.fit(X, y_d) embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) # mixed cebra_model.fit(X, y_c1, y_c2, y_d) embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) # multi-session discrete behavior contrastive cebra_model.fit([X, X_s2], [y_d, y_d_s2]) @@ -385,12 +412,15 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X_s2, session_id=1) assert isinstance(embedding, np.ndarray) assert embedding.shape == (X_s2.shape[0], output_dimension) + embedding = cebra_model.transform(X_s2, session_id=1, batch_size=50) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X_s2.shape[0], output_dimension) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X_s2, session_id=0) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X, session_id=1) - with pytest.raises(RuntimeError, match="No.*session_id"): + with pytest.raises(RuntimeError, match="session_id.*provided"): embedding = cebra_model.transform(X) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=2) @@ -409,12 +439,15 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X_s2, session_id=1) assert isinstance(embedding, np.ndarray) assert embedding.shape == (X_s2.shape[0], output_dimension) + embedding = cebra_model.transform(X_s2, session_id=1, batch_size=50) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X_s2.shape[0], output_dimension) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X_s2, session_id=0) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X, session_id=1) - with pytest.raises(RuntimeError, match="No.*session_id"): + with pytest.raises(RuntimeError, match="session_id.*provided"): embedding = cebra_model.transform(X) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=2) @@ -440,6 +473,9 @@ def test_sklearn(model_architecture, device): embedding 
= cebra_model.transform(X, session_id=2) assert isinstance(embedding, np.ndarray) assert embedding.shape == (X.shape[0], output_dimension) + embedding = cebra_model.transform(X, session_id=2, batch_size=50) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X.shape[0], output_dimension) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X_s2, session_id=0) @@ -447,7 +483,7 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X_s2, session_id=2) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X, session_id=1) - with pytest.raises(RuntimeError, match="No.*session_id"): + with pytest.raises(RuntimeError, match="session_id.*provided"): embedding = cebra_model.transform(X) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=3) @@ -465,6 +501,9 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X, session_id=2) assert isinstance(embedding, np.ndarray) assert embedding.shape == (X.shape[0], output_dimension) + embedding = cebra_model.transform(X, session_id=2, batch_size=50) + assert isinstance(embedding, np.ndarray) + assert embedding.shape == (X.shape[0], output_dimension) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X_s2, session_id=0) @@ -472,7 +511,7 @@ def test_sklearn(model_architecture, device): embedding = cebra_model.transform(X_s2, session_id=2) with pytest.raises(ValueError, match="shape"): embedding = cebra_model.transform(X, session_id=1) - with pytest.raises(RuntimeError, match="No.*session_id"): + with pytest.raises(RuntimeError, match="session_id.*provided"): embedding = cebra_model.transform(X) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=3) @@ -709,6 +748,8 @@ def check_first_layer_dim(model, X): check_first_layer_dim(cebra_model, X_s2) embedding = cebra_model.transform(X_s2) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X_s2, batch_size=50) + assert isinstance(embedding, np.ndarray) cebra_model.fit(X, y_c1, y_c2, adapt=True) check_first_layer_dim(cebra_model, X) @@ -716,6 +757,8 @@ def check_first_layer_dim(model, X): assert isinstance(embedding, np.ndarray) embedding = cebra_model.transform(X, session_id=0) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, session_id=0, batch_size=50) + assert isinstance(embedding, np.ndarray) with pytest.raises(RuntimeError, match="Invalid.*session_id"): embedding = cebra_model.transform(X, session_id=2) @@ -728,11 +771,15 @@ def check_first_layer_dim(model, X): check_first_layer_dim(cebra_model, X_s2) embedding = cebra_model.transform(X_s2) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X_s2, batch_size=50) + assert isinstance(embedding, np.ndarray) cebra_model.fit(X, y_c1, y_c2, y_d, adapt=True) check_first_layer_dim(cebra_model, X) embedding = cebra_model.transform(X) assert isinstance(embedding, np.ndarray) + embedding = cebra_model.transform(X, batch_size=50) + assert isinstance(embedding, np.ndarray) with pytest.raises(NotImplementedError, match=".*multisession.*"): cebra_model.fit([X, X_s2], [y_c1, y_c1_s2], adapt=True) @@ -845,8 +892,8 @@ def test_sklearn_full(model_architecture, device, pad_before_transform): @pytest.mark.parametrize("model_architecture,device", - [("resample-model", "cpu"), - ("resample5-model", "cpu")]) + 
[("offset40-model-4x-subsample", "cpu"), + ("offset20-model-4x-subsample", "cpu")]) def test_sklearn_resampling_model(model_architecture, device): cebra_model = cebra_sklearn_cebra.CEBRA( model_architecture=model_architecture, @@ -866,10 +913,12 @@ def test_sklearn_resampling_model(model_architecture, device): cebra_model.fit(X, y_c1) output = cebra_model.transform(X) assert output.shape == (250, 4) + output = cebra_model.transform(X, batch_size=100) + assert output.shape == (250, 4) @pytest.mark.parametrize("model_architecture,device", - [("resample1-model", "cpu")]) + [("offset4-model-2x-subsample", "cpu")]) def test_sklearn_resampling_model_not_yet_supported(model_architecture, device): cebra_model = cebra_sklearn_cebra.CEBRA( model_architecture=model_architecture, max_iterations=5) @@ -1291,3 +1340,207 @@ def test_check_device(): torch.backends.mps.is_built = lambda: False with pytest.raises(ValueError): cebra_sklearn_utils.check_device(device) + + +@_util.parametrize_slow( + arg_names="model_architecture,device", + fast_arguments=list( + itertools.islice( + itertools.product( + cebra_sklearn_cebra.CEBRA.supported_model_architectures(), + _DEVICES), + 2, + )), + slow_arguments=list( + itertools.product( + cebra_sklearn_cebra.CEBRA.supported_model_architectures(), + _DEVICES)), +) +def test_new_transform(model_architecture, device): + """ + This is a test that the original sklearn transform returns the same output as + the new sklearn transform that uses the pytorch solver transform. + """ + output_dimension = 4 + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture=model_architecture, + time_offsets=10, + learning_rate=3e-4, + max_iterations=2, + device=device, + output_dimension=output_dimension, + batch_size=42, + verbose=True, + ) + + # example dataset + X = np.random.uniform(0, 1, (1000, 50)) + X_s2 = np.random.uniform(0, 1, (800, 30)) + y_c1 = np.random.uniform(0, 1, (1000, 5)) + y_c1_s2 = np.random.uniform(0, 1, (800, 5)) + y_c2 = np.random.uniform(0, 1, (1000, 2)) + y_d = np.random.randint(0, 10, (1000,)) + y_d_s2 = np.random.randint(0, 10, (800,)) + + # time contrastive + cebra_model.fit(X) + embedding1 = cebra_model.transform(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # continuous behavior contrastive + cebra_model.fit(X, y_c1, y_c2) + assert cebra_model.num_sessions is None + + embedding1 = cebra_model.transform(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(torch.Tensor(X)) + embedding2 = _utils_deprecated.cebra_transform_deprecated( + cebra_model, torch.Tensor(X)) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # tensor input + cebra_model.fit(torch.Tensor(X), torch.Tensor(y_c1), torch.Tensor(y_c2)) + + # discrete behavior contrastive + cebra_model.fit(X, y_d) + embedding1 = cebra_model.transform(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + 
atol=1e-8), "Arrays are not close enough" + + # mixed + cebra_model.fit(X, y_c1, y_c2, y_d) + embedding1 = cebra_model.transform(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # multi-session discrete behavior contrastive + cebra_model.fit([X, X_s2], [y_d, y_d_s2]) + + embedding1 = cebra_model.transform(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X_s2, session_id=1) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # multi-session continuous behavior contrastive + cebra_model.fit([X, X_s2], [y_c1, y_c1_s2]) + + embedding1 = cebra_model.transform(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X_s2, session_id=1) + embedding2 = cebra_model.transform(X_s2, session_id=1) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # multi-session tensor inputs + cebra_model.fit( + [torch.Tensor(X), torch.Tensor(X_s2)], + [torch.Tensor(y_c1), torch.Tensor(y_c1_s2)], + ) + + # multi-session discrete behavior contrastive, more than two sessions + cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d]) + + embedding1 = cebra_model.transform(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X_s2, session_id=1) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X, session_id=2) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=2) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + # multi-session continuous behavior contrastive, more than two sessions + cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1]) + + embedding1 = cebra_model.transform(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X_s2, session_id=1) + embedding2 = 
_utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + embedding1 = cebra_model.transform(X, session_id=2) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=2) + assert np.allclose(embedding1, embedding2, rtol=1e-5, + atol=1e-8), "Arrays are not close enough" + + +def test_last_incomplete_batch_smaller_than_offset(): + """ + When offset of the model is larger than the remaining samples in the + last batch, an error could happen. We merge the penultimate + and last batches together to avoid this. + """ + train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100), + continuous=np.random.rand(20111, 2)) + + model = cebra.CEBRA(max_iterations=2, + model_architecture="offset36-model-more-dropout", + device="cpu") + model.fit(train.neural, train.continuous) + + _ = model.transform(train.neural, batch_size=300) diff --git a/tests/test_solver.py b/tests/test_solver.py index 65f49f71..cf2c62ad 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -19,7 +19,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy +import tempfile +import numpy as np import pytest import torch from torch import nn @@ -31,41 +34,25 @@ device = "cpu" -single_session_tests = [] -for args in [ - ("demo-discrete", cebra.data.DiscreteDataLoader), - ("demo-continuous", cebra.data.ContinuousDataLoader), - ("demo-mixed", cebra.data.MixedDataLoader), -]: - single_session_tests.append((*args, cebra.solver.SingleSessionSolver)) - -single_session_hybrid_tests = [] -for args in [("demo-continuous", cebra.data.HybridDataLoader)]: - single_session_hybrid_tests.append( - (*args, cebra.solver.SingleSessionHybridSolver)) - -multi_session_tests = [] -for args in [("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader), - ("demo-discrete-multisession", - cebra.data.DiscreteMultiSessionDataLoader)]: - multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) - # multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) - def _get_loader(data_name, loader_initfunc): data = cebra.datasets.init(data_name) - kwargs = dict(num_steps=10, batch_size=32) + kwargs = dict(num_steps=2, batch_size=32) loader = loader_initfunc(data, **kwargs) - return loader + return loader, data + +OUTPUT_DIMENSION = 3 -def _make_model(dataset): + +def _make_model(dataset, model_architecture="offset10-model"): # TODO flexible input dimension - return nn.Sequential( - nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), - nn.Flatten(start_dim=1, end_dim=-1), - ) + # return nn.Sequential( + # nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), + # nn.Flatten(start_dim=1, end_dim=-1), + # ) + return cebra.models.init(model_architecture, dataset.input_dimension, 32, + OUTPUT_DIMENSION) def _make_behavior_model(dataset): @@ -76,29 +63,144 @@ def _make_behavior_model(dataset): ) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_tests) -def test_single_session(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) - model = _make_model(loader.dataset) +def _assert_same_state_dict(first, second): + assert first.keys() == second.keys() + for key in first: + if isinstance(first[key], torch.Tensor): + assert torch.allclose(first[key], second[key]), key + elif isinstance(first[key], 
dict): + _assert_same_state_dict(first[key], second[key]), key + else: + assert first[key] == second[key] + + +def check_if_fit(model): + """Check if a model was already fit. + + Args: + model: The model to check. + + Returns: + True if the model was already fit. + """ + return hasattr(model, "n_features_") + + +def _assert_equal(original_solver, loaded_solver): + for k in original_solver.model.state_dict(): + assert original_solver.model.state_dict()[k].all( + ) == loaded_solver.model.state_dict()[k].all() + assert check_if_fit(loaded_solver) == check_if_fit(original_solver) + + if check_if_fit(loaded_solver): + _assert_same_state_dict(original_solver.state_dict_, + loaded_solver.state_dict_) + X = np.random.normal(0, 1, (100, 1)) + + if loaded_solver.num_sessions is not None: + assert np.allclose(loaded_solver.transform(X, session_id=0), + original_solver.transform(X, session_id=0)) + else: + assert np.allclose(loaded_solver.transform(X), + original_solver.transform(X)) + + +def _test_single_session_transform(solver, X, offset): + embedding = solver.transform(X) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(torch.Tensor(X)) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, session_id=0) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == ( + (X.shape[0] - len(offset)) // solver.model.resample_factor + 1, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0] - len(offset) + 1, + OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + solver.transform(X.numpy()) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X, session_id=2) + + for param in solver.parameters(): + assert isinstance(param, torch.Tensor) + + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]]) +def test_single_session(data_name, loader_initfunc, model_architecture, + solver_initfunc): + loader, data = _get_loader(data_name, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() criterion = cebra.models.InfoNCE() optimizer = 
torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, criterion=criterion, - optimizer=optimizer) + optimizer=optimizer, + tqdm_on=False) batch = next(iter(loader)) - assert batch.reference.shape == (32, loader.dataset.input_dimension, 10) + assert batch.reference.shape[:2] == (32, loader.dataset.input_dimension) log = solver.step(batch) assert isinstance(log, dict) + X = loader.dataset.neural + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X) + solver.fit(loader) + assert not hasattr(solver, 'num_sessions') + assert solver.n_features == X.shape[1] + + _test_single_session_transform(solver, X, offset) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_tests) -def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc): + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]]) +def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, + solver_initfunc): pytest.skip("Not yet supported") @@ -122,13 +224,21 @@ def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc): solver.fit(loader) - -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - single_session_hybrid_tests) -def test_single_session_hybrid(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) - model = cebra.models.init("offset10-model", loader.dataset.input_dimension, - 32, 3) + assert not hasattr(solver, 'num_sessions') + assert solver.n_features == loader.dataset.neural.shape[1] + + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [("demo-continuous", model, cebra.data.HybridDataLoader, + cebra.solver.SingleSessionHybridSolver) + for model in ["offset1-model", "offset10-model"]]) +def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, + solver_initfunc): + loader, data = _get_loader(data_name, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, @@ -141,16 +251,39 @@ def test_single_session_hybrid(data_name, loader_initfunc, solver_initfunc): log = solver.step(batch) assert isinstance(log, dict) + X = loader.dataset.neural + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X) + solver.fit(loader) + assert not hasattr(solver, 'num_sessions') + assert solver.n_features == X.shape[1] + + _test_single_session_transform(solver, X, offset) + + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader), + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in ["offset1-model", "offset10-model"]]) +def test_multi_session(data_name, loader_initfunc, model_architecture, + solver_initfunc): + loader, data = _get_loader(data_name, 
loader_initfunc) + model = nn.ModuleList([ + _make_model(dataset, model_architecture) + for dataset in data.iter_sessions() + ]) + data.configure_for(model) + offset_length = len(model[0].get_offset()) -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - multi_session_tests) -def test_multi_session(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) criterion = cebra.models.InfoNCE() - model = nn.ModuleList( - [_make_model(dataset) for dataset in loader.dataset.iter_sessions()]) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, @@ -161,37 +294,110 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): for session_id, dataset in enumerate(loader.dataset.iter_sessions()): assert batch[session_id].reference.shape == (32, dataset.input_dimension, - 10) + offset_length) assert batch[session_id].index is not None log = solver.step(batch) assert isinstance(log, dict) + X = [ + loader.dataset.get_session(i).neural + for i in range(loader.dataset.num_sessions) + ] + with pytest.raises(ValueError, match="not.*fitted"): + solver.transform(X[0], session_id=0) + solver.fit(loader) + assert solver.num_sessions == 3 + assert solver.n_features == [X[i].shape[1] for i in range(len(X))] + + embedding = solver.transform(X[0], session_id=0) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[0].shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X[1], session_id=1) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[1].shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X[0], session_id=0, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + assert embedding.shape == (X[0].shape[0] - + len(solver.model[0].get_offset()) + 1, + OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + embedding = solver.transform(X[0].numpy(), session_id=0) + + with pytest.raises(ValueError, match="shape"): + embedding = solver.transform(X[1], session_id=0) + with pytest.raises(ValueError, match="shape"): + embedding = solver.transform(X[0], session_id=1) + + with pytest.raises(RuntimeError, match="No.*session_id"): + embedding = solver.transform(X[0]) + with pytest.raises(RuntimeError, match="session_id.*provided"): + embedding = solver.transform(X) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X[0], session_id=5) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X[0], session_id=-1) + + for param in solver.parameters(session_id=0): + assert isinstance(param, torch.Tensor) + + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + + +def _make_val_data(dataset): + if isinstance(dataset, cebra.datasets.demo.DemoDataset): + return dataset.neural + elif isinstance(dataset, cebra.datasets.demo.DemoDatasetUnified): + return [session.neural for session in dataset.iter_sessions()], [ + session.continuous_index for session in dataset.iter_sessions() + ] + + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.UnifiedSolver) + for dataset, loader in [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ] + for model in ["offset1-model", "offset10-model"]]) +def test_unified_session(data_name, 
model_architecture, loader_initfunc, + solver_initfunc): + loader, data = _get_loader(data_name, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() -@pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", - multi_session_tests) -def test_multi_session_2(data_name, loader_initfunc, solver_initfunc): - loader = _get_loader(data_name, loader_initfunc) criterion = cebra.models.InfoNCE() - model = nn.ModuleList( - [_make_model(dataset) for dataset in loader.dataset.iter_sessions()]) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) solver = solver_initfunc(model=model, criterion=criterion, - optimizer=optimizer, - tqdm_on=True) + optimizer=optimizer) batch = next(iter(loader)) - for session_id, dataset in enumerate(loader.dataset.iter_sessions()): - assert batch[session_id].reference.shape == (32, - dataset.input_dimension, - 10) - assert batch[session_id].index is not None + assert batch.reference.shape == (32, loader.dataset.input_dimension, + len(offset)) log = solver.step(batch) assert isinstance(log, dict) solver.fit(loader) + data, labels = _make_val_data(loader.dataset) + + assert solver.num_sessions == 3 + assert solver.n_features == sum( + [data[i].shape[1] for i in range(len(data))]) + + for i in range(loader.dataset.num_sessions): + emb = solver.transform(data, labels, session_id=i) + assert emb.shape == (loader.dataset.num_timepoints, 3) + + emb = solver.transform(data, labels, session_id=i, batch_size=300) + assert emb.shape == (loader.dataset.num_timepoints, 3) diff --git a/tests/test_solver_batched.py b/tests/test_solver_batched.py new file mode 100644 index 00000000..8d60e77d --- /dev/null +++ b/tests/test_solver_batched.py @@ -0,0 +1,486 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pytest +import torch +from torch import nn + +import cebra.data +import cebra.datasets +import cebra.models +import cebra.solver + +device = "cpu" + +NUM_STEPS = 2 +BATCHES = [250, 500, 750] +MODELS = ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + + +@pytest.mark.parametrize( + "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", + [ + # Test case 1: No padding + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch + (torch.tensor( + [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch + + # Test case 2: First batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 0, + 2, + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Test case 3: Last batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(1, 2), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + ), + + # Test case 4: Middle batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(1, 1), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(0, 1), + 2, + 4, + torch.tensor([[7, 8, 9], [10, 11, 12]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + # Padding without offset (should raise an error) + (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), + # Negative start_batch_idx or end_batch_idx (should raise an error) + (torch.tensor([[1, 2]]), False, cebra.data.Offset( + 0, 1), -1, 2, ValueError), + # out of bound indices because offset is too large + (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( + 5, 5), 1, 2, ValueError), + # Batch length is smaller than offset. 
+ (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( + 0, 1), 0, 1, ValueError), # first batch + ], +) +def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, + expected_output): + if expected_output == ValueError: + with pytest.raises(ValueError): + cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + else: + result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + assert torch.equal(result, expected_output) + + +def create_model(model_name, input_dimension): + return cebra.models.init(model_name, + num_neurons=input_dimension, + num_units=128, + num_output=3) + + +@pytest.mark.parametrize( + "data_name, model_name, session_id, loader_initfunc, solver_initfunc", + [(dataset, model, session_id, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 5]] + + [(dataset, model, session_id, loader, + cebra.solver.SingleSessionHybridSolver) + for dataset, loader in [ + ("demo-continuous", cebra.data.HybridDataLoader), + ] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 5]]) +def test_select_model_single_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + dataset.configure_for(model) + offset = model.get_offset() + solver = solver_initfunc(model=model, criterion=None, optimizer=None) + + with pytest.raises(ValueError): + solver.n_features = 1000 + solver._select_model(inputs=dataset.neural, session_id=0) + + solver.n_features = dataset.neural.shape[1] + if session_id is not None and session_id > 0: + with pytest.raises(RuntimeError): + solver._select_model(inputs=dataset.neural, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs=dataset.neural, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +@pytest.mark.parametrize( + "data_name, model_name, session_id, loader_initfunc, solver_initfunc", + [(dataset, model, session_id, loader, cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 1, 5, 2, 6, 4]]) +def test_select_model_multi_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + + dataset = cebra.datasets.init(data_name) + kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = loader_initfunc(dataset, **kwargs) + + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.configure_for(model) + + offset = model[0].get_offset() + solver = solver_initfunc(model=model, + criterion=cebra.models.InfoNCE(), + optimizer=torch.optim.Adam(model.parameters(), + lr=1e-3)) + + loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = cebra.data.ContinuousMultiSessionDataLoader( + dataset, **loader_kwargs) + solver.fit(loader) + + for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): + inputs = dataset_.neural + + if session_id is None or session_id >= dataset.num_sessions: 
+ with pytest.raises(RuntimeError): + solver._select_model(inputs, session_id=session_id) + elif i != session_id: + with pytest.raises(ValueError): + solver._select_model(inputs, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +@pytest.mark.parametrize( + "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc", + [(dataset, model, padding, batch_size, loader, + cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + for padding in [True, False] + for batch_size in BATCHES] + + [(dataset, model, padding, batch_size, loader, + cebra.solver.SingleSessionHybridSolver) + for dataset, loader in [ + ("demo-continuous", cebra.data.HybridDataLoader), + ] + for model in MODELS + for padding in [True, False] + for batch_size in BATCHES]) +def test_batched_transform_single_session( + data_name, + model_name, + padding, + batch_size_inference, + loader_initfunc, + solver_initfunc, +): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + dataset.configure_for(model) + loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = loader_initfunc(dataset, **loader_kwargs) + + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer) + solver.fit(loader) + + embedding_batched = solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size_inference, + pad_before_transform=padding) + + embedding = solver.transform(inputs=loader.dataset.neural, + pad_before_transform=padding) + + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize( + "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", + [(dataset, model, padding, batch_size, loader, + cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + for padding in [True, False] + for batch_size in BATCHES]) +def test_batched_transform_multi_session(data_name, model_name, padding, + batch_size_inference, loader_initfunc, + solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.configure_for(model) + + n_samples = dataset._datasets[0].neural.shape[0] + assert all( + d.neural.shape[0] == n_samples for d in dataset._datasets + ), "for this set all of the sessions need to have same number of samples." 
+    loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32)
+    loader = loader_initfunc(dataset, **loader_kwargs)
+
+    criterion = cebra.models.InfoNCE()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    solver = solver_initfunc(model=model,
+                             criterion=criterion,
+                             optimizer=optimizer)
+    solver.fit(loader)
+
+    # Transform each session with the right model, by providing
+    # the corresponding session ID.
+    for i, session in enumerate(dataset.iter_sessions()):
+        embedding = solver.transform(inputs=session.neural,
+                                     session_id=i,
+                                     pad_before_transform=padding)
+        embedding_batched = solver.transform(inputs=session.neural,
+                                             session_id=i,
+                                             pad_before_transform=padding,
+                                             batch_size=batch_size_inference)
+
+        assert embedding_batched.shape == embedding.shape
+        assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize(
+    "batch_start_idx, batch_end_idx, offset, num_samples, expected_exception",
+    [
+        # Valid indices
+        (0, 5, cebra.data.Offset(1, 1), 10, None),
+        (2, 8, cebra.data.Offset(2, 2), 10, None),
+        # Negative indices
+        (-1, 5, cebra.data.Offset(1, 1), 10, ValueError),
+        (0, -5, cebra.data.Offset(1, 1), 10, ValueError),
+        # Start index greater than end index
+        (5, 3, cebra.data.Offset(1, 1), 10, ValueError),
+        # End index out of bounds
+        (0, 11, cebra.data.Offset(1, 1), 10, ValueError),
+        # Batch size smaller than offset
+        (0, 2, cebra.data.Offset(3, 3), 10, ValueError),
+    ],
+)
+def test_check_indices(batch_start_idx, batch_end_idx, offset, num_samples,
+                       expected_exception):
+    if expected_exception:
+        with pytest.raises(expected_exception):
+            cebra.solver.base._check_indices(batch_start_idx, batch_end_idx,
+                                             offset, num_samples)
+    else:
+        cebra.solver.base._check_indices(batch_start_idx, batch_end_idx,
+                                         offset, num_samples)
+
+
+@pytest.mark.parametrize(
+    "batch_start_idx, batch_end_idx, num_samples, expected",
+    [
+        # First batch: expected number of rows after padding
+        (0, 6, 12, 8),
+        # Last batch: expected number of rows after padding
+        (6, 12, 12, 8),
+        # Middle batch: no padding added
+        (3, 9, 12, 6),
+        # Invalid start index
+        (-1, 3, 4, ValueError),
+        # Invalid end index
+        (3, -10, 4, ValueError),
+        # Start index greater than end index
+        (5, 3, 4, ValueError),
+    ],
+)
+def test_add_batched_zero_padding(batch_start_idx, batch_end_idx, num_samples,
+                                  expected):
+    batched_data = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
+                                 [9.0, 10.0], [1.0, 2.0]])
+
+    model = create_model(model_name="offset5-model",
+                         input_dimension=batched_data.shape[1])
+    offset = model.get_offset()
+
+    if expected == ValueError:
+        with pytest.raises(ValueError):
+            cebra.solver.base._add_batched_zero_padding(
+                batched_data, offset, batch_start_idx, batch_end_idx,
+                num_samples)
+    else:
+        result = cebra.solver.base._add_batched_zero_padding(
+            batched_data, offset, batch_start_idx, batch_end_idx, num_samples)
+        # For valid indices, ``expected`` is the number of rows after padding.
+        assert result.shape[0] == expected
+
+
+@pytest.mark.parametrize(
+    "pad_before_transform",
+    [
+        # Zero-padding applied before the transform
+        True,
+        # No padding
+        False,
+    ],
+)
+def test_transform(pad_before_transform):
+    inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
+                           [9.0, 10.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
+                           [7.0, 8.0], [9.0, 10.0], [1.0, 2.0], [3.0, 4.0],
+                           [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
+    model = create_model(model_name="offset5-model",
+                         input_dimension=inputs.shape[1])
+    offset = model.get_offset()
+
+    result = cebra.solver.base._not_batched_transform(
+        model=model,
+        inputs=inputs,
+        pad_before_transform=pad_before_transform,
+        offset=offset,
+    )
+    if pad_before_transform:
+        assert result.shape[0] == inputs.shape[0]
+    else:
+        assert result.shape[0] == inputs.shape[0] - len(offset) + 1
+
+
+@pytest.mark.parametrize(
+    "batch_size, pad_before_transform, expected_exception",
+    [
+        # Valid batched inputs
+        (6, True, None),
+        # Invalid batch size (too large)
+        (12, True, ValueError),
+        # Invalid batch size (too small)
+        (2, True, ValueError),
+        # Last batch is incomplete
+        (5, True, None),
+        # No padding
+        (6, False, None),
+    ],
+)
+def test_batched_transform(batch_size, pad_before_transform,
+                           expected_exception):
+    inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0],
+                           [9.0, 10.0], [1.0, 2.0], [3.0, 4.0], [5.0, 6.0],
+                           [7.0, 8.0], [9.0, 10.0], [1.0, 2.0], [3.0, 4.0],
+                           [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]])
+    model = create_model(model_name="offset5-model",
+                         input_dimension=inputs.shape[1])
+    offset = model.get_offset()
+
+    if expected_exception:
+        with pytest.raises(expected_exception):
+            cebra.solver.base._batched_transform(
+                model=model,
+                inputs=inputs,
+                batch_size=batch_size,
+                pad_before_transform=pad_before_transform,
+                offset=offset,
+            )
+    else:
+        result = cebra.solver.base._batched_transform(
+            model=model,
+            inputs=inputs,
+            batch_size=batch_size,
+            pad_before_transform=pad_before_transform,
+            offset=offset,
+        )
+        if pad_before_transform:
+            assert result.shape[0] == inputs.shape[0]
+        else:
+            assert result.shape[0] == inputs.shape[0] - len(offset) + 1