Add path to unified CEBRA demo icon #252

Closed

120 commits
283de06
first proposal for batching in tranform method
gonlairo Jun 21, 2023
202e379
first running version of padding with batched inference
gonlairo Jun 22, 2023
1f1989d
start tests
gonlairo Jun 23, 2023
8665660
add pad_before_transform to fit function and add support for convolut…
gonlairo Sep 27, 2023
8d5b114
remove print statements
gonlairo Sep 27, 2023
32c5ecd
first passing test
gonlairo Sep 27, 2023
9928f63
add support for hybrid models
gonlairo Sep 28, 2023
be5630a
rewrite transform in sklearn API
gonlairo Sep 28, 2023
1300b20
baseline version of a torch.Datset
gonlairo Oct 16, 2023
bc6af24
move batching logic outside solver
gonlairo Oct 20, 2023
ec377b9
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
6f9ca98
add test_select_model for single session
gonlairo Oct 27, 2023
fbe7eb4
add checks and test for _process_batch
gonlairo Oct 27, 2023
463b0f8
add test_select_model for multisession
gonlairo Oct 30, 2023
5219171
make self.num_sessions compatible with single session training
gonlairo Oct 31, 2023
f9bd1a6
improve test_batched_transform_singlesession
gonlairo Nov 1, 2023
e23a7ef
make it work with small batches
gonlairo Nov 7, 2023
19c3f87
make test with multisession work
gonlairo Nov 8, 2023
87bebac
change to torch padding
gonlairo Nov 9, 2023
f0303e0
add argument to sklearn api
gonlairo Nov 9, 2023
8c8be85
add torch padding to _transform
gonlairo Nov 9, 2023
59df402
convert to torch if numpy array as inputs
gonlairo Nov 9, 2023
1aadc8b
add distinction between pad with data and pad with zeros and modify t…
gonlairo Nov 15, 2023
bc8ee25
differentiate between data padding and zero padding
gonlairo Nov 17, 2023
5e7a14c
remove float16
gonlairo Nov 24, 2023
928d882
change argument position
gonlairo Nov 27, 2023
07bac1c
clean test
gonlairo Nov 27, 2023
0823b54
clean test
gonlairo Nov 27, 2023
9fe3af3
Fix warning
CeliaBenquet Mar 26, 2024
b417a23
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
83c1669
Add tests to solver
CeliaBenquet Aug 22, 2024
9c46eb9
Remove unused import in solver/utils
CeliaBenquet Aug 22, 2024
c845ec3
Fix test plot
CeliaBenquet Aug 22, 2024
9db3e37
Add some coverage
CeliaBenquet Aug 22, 2024
8e5f933
Fix save/load
CeliaBenquet Aug 22, 2024
d08e400
Remove duplicate configure_for in multi dataset
CeliaBenquet Aug 22, 2024
0c693dd
Make save/load cleaner
CeliaBenquet Aug 22, 2024
ae056b2
Merge branch 'main' into batched-inference-and-padding
CeliaBenquet Sep 18, 2024
794867b
Fix codespell errors
CeliaBenquet Sep 18, 2024
0bb6549
Fix docs compilation errors
CeliaBenquet Sep 18, 2024
04a102f
Fix formatting
CeliaBenquet Sep 18, 2024
7aab282
Fix extra docs errors
CeliaBenquet Sep 18, 2024
ffa66eb
Fix offset in docs
CeliaBenquet Sep 18, 2024
7f58607
Remove attribute ref
CeliaBenquet Sep 18, 2024
c2544c7
Add review updates
CeliaBenquet Sep 19, 2024
ad5da03
Merge branch 'main' into batched-inference-and-padding
stes Oct 20, 2024
f6aa2e6
Merge branch 'main' into batched-inference-and-padding
MMathisLab Oct 20, 2024
e1b7cc7
apply ruff auto-fixes
stes Oct 27, 2024
0eac868
Merge remote-tracking branch 'origin/main' into batched-inference-and…
stes Oct 27, 2024
81b964c
Concatenate last batches for batched inference (#200)
CeliaBenquet Jan 21, 2025
a09d123
Fix linting errors in tests (#188)
stes Oct 27, 2024
521f003
Fix `scikit-learn` reference in conda environment files (#195)
stes Nov 8, 2024
46610e3
Add support for new __sklearn_tags__ (#205)
stes Dec 16, 2024
e8004ba
Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)
stes Jan 21, 2025
ddc00f4
Fix deprecation warning force_all_finite -> ensure_all_finite for skl…
icarosadero Jan 22, 2025
7dc9f81
Add tests to check legacy model loading (#214)
stes Jan 29, 2025
a2a6c44
Add improved goodness of fit implementation (#190)
stes Feb 2, 2025
a3b143f
Support numpy 2, upgrade tests to support torch 2.6 (#221)
stes Feb 2, 2025
0d5d82a
Release 0.5.0rc1 (#189)
stes Feb 2, 2025
92fd9bc
Fix pypi action (#222)
stes Feb 3, 2025
69d91ef
Update base.py (#224)
icarosadero Feb 18, 2025
782b63a
Change max consistency value to 100 instead of 99 (#227)
CeliaBenquet Mar 1, 2025
d72b055
Update assets.py --> force check for parent dir (#230)
MMathisLab Mar 1, 2025
9fd91c3
User docs minor edit (#229)
MMathisLab Mar 1, 2025
8d636e9
General Doc refresher (#232)
MMathisLab Mar 3, 2025
36370be
render plotly in our docs, show code/doc version (#231)
MMathisLab Mar 4, 2025
f7f4d7f
Update layout.html (#233)
MMathisLab Mar 6, 2025
798f7b2
Update conf.py (#234)
MMathisLab Mar 6, 2025
4a2996d
Refactoring setup.cfg (#228)
MMathisLab Mar 15, 2025
7abd1b0
Home page landing update (#235)
MMathisLab Mar 15, 2025
673019a
v0.5.0 (#238)
MMathisLab Apr 17, 2025
9625680
Upgrade docs build (#241)
stes Apr 18, 2025
95e5296
Allow indexing of the cebra docs (#242)
stes Apr 20, 2025
20f5a77
Fix broken docs coverage workflows (#246)
stes Apr 23, 2025
0d85abb
Add xCEBRA implementation (AISTATS 2025) (#225)
gonlairo Apr 23, 2025
b19be59
start tests
gonlairo Jun 23, 2023
e908083
remove print statements
gonlairo Sep 27, 2023
3d2b1e3
first passing test
gonlairo Sep 27, 2023
3ef4bc1
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
ad56472
add test_select_model for multisession
gonlairo Oct 30, 2023
b73c123
remove float16
gonlairo Nov 24, 2023
d71ca8d
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
3e91459
Add tests to solver
CeliaBenquet Aug 22, 2024
c6179ad
Fix save/load
CeliaBenquet Aug 22, 2024
dafabe5
Fix extra docs errors
CeliaBenquet Sep 18, 2024
7b0cc68
Add review updates
CeliaBenquet Sep 19, 2024
7dfd4b9
apply ruff auto-fixes
stes Oct 27, 2024
3acbdf4
fix linting errors
stes Jan 21, 2025
5745449
Run isort, ruff, yapf
CeliaBenquet Apr 23, 2025
fa3cd3e
Merge remote-tracking branch 'upstream/main' into batched-inference-a…
CeliaBenquet Apr 23, 2025
1453885
Merge branch 'main' into batched-inference-and-padding
MMathisLab Apr 23, 2025
acd2111
Fix gaussian mixture dataset import
CeliaBenquet Apr 23, 2025
217a8a7
Fix all tests but xcebra tests
CeliaBenquet Apr 23, 2025
a1218aa
Fix pytorch API usage example
CeliaBenquet Apr 24, 2025
64d1db8
Make xCEBRA compatible with the batched inference & padding in solver
CeliaBenquet Apr 24, 2025
9875a38
Add some tests on transform() with xCEBRA
CeliaBenquet Apr 24, 2025
65fc455
Add some docstrings and typings and clean unnecessary changes
CeliaBenquet Apr 24, 2025
1d0c498
Implement review comments
CeliaBenquet Apr 24, 2025
4a25899
Fix sklearn test
CeliaBenquet Apr 25, 2025
b8945ae
Initial pass at integrating unifiedCEBRA
CeliaBenquet Apr 25, 2025
0d56e44
Add name in NOTE
CeliaBenquet Apr 25, 2025
c5dc011
Implement reviews on tests and typing
CeliaBenquet Apr 25, 2025
c9fa5c8
Fix import errors
CeliaBenquet Apr 28, 2025
9ba22bc
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet Apr 28, 2025
4632c04
Add select_model to aux solvers
CeliaBenquet Apr 28, 2025
a52f502
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet Apr 28, 2025
c22e40e
Fix tests
CeliaBenquet Apr 28, 2025
e8a1877
Add mask tests
CeliaBenquet Apr 28, 2025
22e3c47
Fix docs error
CeliaBenquet Apr 30, 2025
464f4aa
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet May 1, 2025
57c9494
Remove masking init()
CeliaBenquet May 1, 2025
0d953fc
Remove shuffled neurons in unified dataset
CeliaBenquet May 1, 2025
eba09b6
Remove extra datasets
CeliaBenquet May 1, 2025
cc8671c
Add tests on the private functions in base solver
CeliaBenquet May 2, 2025
b83421d
Update tests and duplicate code based on review
CeliaBenquet May 5, 2025
f2d1e3a
Fix quantized_embedding_norm undefined when `normalize=False` (#249)
CeliaBenquet May 5, 2025
619a662
Fix tests
CeliaBenquet Apr 28, 2025
32fae46
Adapt unified code to get_model method
CeliaBenquet May 20, 2025
2016566
Add the icon path
CeliaBenquet May 20, 2025
1053333
Merge branch 'AdaptiveMotorControlLab:main' into unified-cebra-demo
CeliaBenquet May 20, 2025
1 change: 1 addition & 0 deletions cebra/data/__init__.py
@@ -51,3 +51,4 @@
from cebra.data.multiobjective import *
from cebra.data.datasets import *
from cebra.data.helper import *
from cebra.data.masking import *
5 changes: 4 additions & 1 deletion cebra/data/base.py
@@ -27,6 +27,7 @@
import torch

import cebra.data.assets as cebra_data_assets
import cebra.data.masking as cebra_data_masking
import cebra.distributions
import cebra.io
from cebra.data.datatypes import Batch
@@ -36,7 +37,7 @@
__all__ = ["Dataset", "Loader"]


class Dataset(abc.ABC, cebra.io.HasDevice):
class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin):
"""Abstract base class for implementing a dataset.

The class attributes provide information about the shape of the data when
@@ -227,6 +228,8 @@ class Loader(abc.ABC, cebra.io.HasDevice):
doc="""A dataset instance specifying a ``__getitem__`` function.""",
)

time_offset: int = dataclasses.field(default=10)

num_steps: int = dataclasses.field(
default=None,
doc=
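The `base.py` hunk above mixes `cebra_data_masking.MaskedMixin` into the abstract `Dataset` base class, so every dataset inherits an `apply_mask` hook alongside the ABC contract. A minimal sketch of this pattern — the `MaskedMixin` internals here are assumptions for illustration, not the actual `cebra.data.masking` implementation:

```python
import abc

import numpy as np


class MaskedMixin:
    """Hypothetical stand-in for cebra.data.masking.MaskedMixin: applies an
    optional per-neuron mask, defaulting to a no-op when no mask is set."""

    _mask = None  # set via set_mask(); None means "no masking"

    def set_mask(self, mask: np.ndarray) -> None:
        self._mask = mask

    def apply_mask(self, batch: np.ndarray) -> np.ndarray:
        if self._mask is None:
            return batch
        # Zero out masked-off neurons (axis 1) without changing the shape.
        return batch * self._mask[None, :, None]


class Dataset(abc.ABC, MaskedMixin):
    """Mirrors the inheritance in the diff: the ABC contract and the masking
    behavior are combined through multiple inheritance."""

    @abc.abstractmethod
    def __getitem__(self, index):
        ...


class ToyDataset(Dataset):
    def __init__(self, data: np.ndarray):
        self.data = data

    def __getitem__(self, index):
        # Return a (batch, neurons, offset) slice with the mask applied.
        return self.apply_mask(self.data[index])


data = np.ones((4, 3, 2))               # 4 samples, 3 neurons, offset 2
ds = ToyDataset(data)
ds.set_mask(np.array([1.0, 0.0, 1.0]))  # silence the middle neuron
batch = ds[0:2]
print(batch.shape)                      # (2, 3, 2); neuron 1 is zeroed
```

Because the mixin defaults to a no-op, existing subclasses keep working unchanged; only datasets that explicitly set a mask see modified batches.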
71 changes: 70 additions & 1 deletion cebra/data/datasets.py
@@ -29,6 +29,7 @@
import torch

import cebra.data as cebra_data
import cebra.data.masking as cebra_data_masking
import cebra.helper as cebra_helper
import cebra.io as cebra_io
from cebra.data.datatypes import Batch
@@ -304,7 +305,7 @@ def _iter_property(self, attr):


# TODO(stes): This should be a single session dataset?
class DatasetxCEBRA(cebra_io.HasDevice):
class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin):
"""Dataset class for xCEBRA models.

This class handles neural data and associated labels for xCEBRA models, providing
@@ -435,3 +436,71 @@ def load_batch_contrastive(self, index: BatchIndex) -> Batch:
positive=[self[idx] for idx in index.positive],
negative=self[index.negative],
)


class UnifiedDataset(DatasetCollection):
"""Multi session dataset made up of a list of datasets, considered as a unique session.

Considering the sessions as a unique session, or pseudo-session, is used to later train a single
model for all the sessions, even if they originally contain a variable number of neurons.
To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis.

For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for
each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch`
is of shape ``(batch_size, total_num_neurons, offset)``, with ``total_num_neurons`` is the sum of all the
``num_neurons(session)``.
"""

def __init__(self, *datasets: cebra_data.SingleSessionDataset):
super().__init__(*datasets)

@property
def input_dimension(self) -> int:
"""Returns the sum of the input dimension for each session."""
return np.sum([
self.get_input_dimension(session_id)
for session_id in range(self.num_sessions)
])

def _get_batches(self, index):
"""Return the data at the specified index location."""
return [
cebra_data.Batch(
reference=self.get_session(session_id)[
index.reference[session_id]],
positive=self.get_session(session_id)[
index.positive[session_id]],
negative=self.get_session(session_id)[
index.negative[session_id]],
) for session_id in range(self.num_sessions)
]

def load_batch(self, index: BatchIndex) -> Batch:
"""Return the data at the specified index location.

Concatenate batches for each sessions on the number of neurons axis.

Args:
batches: List of :py:class:`cebra.data.datatypes.Batch` sampled for each session. An instance
:py:class:`cebra.data.datatypes.Batch` of the list is of shape ``(batch_size, num_neurons(session), offset)``.

Returns:
A :py:class:`cebra.data.datatypes.Batch`, of shape ``(batch_size, total_num_neurons, offset)``, where
``total_num_neurons`` is the sum of all the ``num_neurons(session)``
"""
batches = self._get_batches(index)

return cebra_data.Batch(
reference=self.apply_mask(
torch.cat([batch.reference for batch in batches], dim=1)),
positive=self.apply_mask(
torch.cat([batch.positive for batch in batches], dim=1)),
negative=self.apply_mask(
torch.cat([batch.negative for batch in batches], dim=1)),
)

def __getitem__(self, args) -> List[Batch]:
"""Return a set of samples from all sessions."""

session_id, index = args
return self.get_session(session_id).__getitem__(index)
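The `load_batch` logic in the diff above can be sketched with plain NumPy: per-session batches of shape `(batch_size, num_neurons(session), offset)` are concatenated on the neurons axis, mirroring the `torch.cat([...], dim=1)` calls. The session sizes below are made up for illustration:

```python
import numpy as np

batch_size, offset = 8, 10
session_neurons = [30, 45, 25]  # hypothetical per-session neuron counts

# One (batch_size, num_neurons, offset) reference batch per session,
# standing in for what _get_batches() returns for each session.
per_session_refs = [
    np.random.randn(batch_size, n, offset) for n in session_neurons
]

# Concatenating on axis 1 mirrors torch.cat([...], dim=1) in load_batch:
# the pseudo-session stacks every session's neurons into a single array.
unified_ref = np.concatenate(per_session_refs, axis=1)

print(unified_ref.shape)  # (8, 100, 10): total_num_neurons = 30 + 45 + 25
```

The same concatenation is applied to the reference, positive, and negative batches, which is what lets one model consume sessions with different neuron counts.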