Add unified encoder PyTorch implementation #251

Open · wants to merge 128 commits into base: main

Conversation

Member

@CeliaBenquet CeliaBenquet commented May 1, 2025

This PR adds a PyTorch implementation of a unified CEBRA encoder, which is composed of:

  • A new sampling scheme that samples across all sessions so that they can be aligned on the neuron axis to train a single encoder.
  • A unified Dataset and Loader, adapted to the new sampling scheme.
  • A unified Solver that considers multiple sessions to be aligned at inference.
  • A new masked modeling training option, with different types of masking.

🚧 A preprint is pending: "Unified CEBRA Encoders for Integrating Neural Recordings via Behavioral Alignment" by Célia Benquet, Hossein Mirzaei, Steffen Schneider, Mackenzie W. Mathis.

@cla-bot cla-bot bot added the CLA signed label May 1, 2025
@CeliaBenquet CeliaBenquet requested review from stes and MMathisLab May 20, 2025 10:51
@MMathisLab MMathisLab changed the base branch from batched-inference-and-padding to main May 23, 2025 13:39
positive=self[index.positive],
negative=self[index.negative],
reference=self[index.reference],
positive=self.apply_mask(self[index.positive]),

Member

Quick sanity check: is this backwards compatible? @CeliaBenquet

Member Author

Backward compatible, yes. I added a check for the case where the function doesn't exist, for people who might want to use the adapt functionality on an older model. Good catch.

@@ -97,6 +97,8 @@ def get_datapath(path: str = None) -> str:
from cebra.datasets.hippocampus import *
from cebra.datasets.monkey_reaching import *
from cebra.datasets.synthetic_data import *
from cebra.datasets.perich import *

Member

do we need to package this, or is this downloaded from DANDI?

Member Author

That's a leftover line. I no longer add the dataset files for Perich (and NLB), because they would require changing the packages installed with CEBRA.

For NLB, it requires the data to already be downloaded, plus the nlb_tools package installed.
For Perich, it also requires the data to already be downloaded, but the issue is that it depends on POYO code that has since been removed, so we would need to go back to a previous commit, etc.

Let me know if it's necessary to have these. In that case, I suppose we would also need the S1/M1 dataset class, if you need one.

@MMathisLab MMathisLab left a comment (Member)

Thanks @CeliaBenquet! I went through and left comments for discussion.

Member

Should we put this into integrations rather than models? To me, models means encoders only. cc @stes

Member

We currently have some decoders here, although these are sklearn specific.

I think this module here is fine; at least right now I don't see a better place in the codebase to put them. An argument for leaving them here is that they are an "extension" of the encoders we train, plus they are "raw" torch objects, which we currently collect in cebra.models.

I don't have a strong opinion, I just don't see where they would fit better. In integrations, we currently have only "standalone" helper functions, which these aren't.

@CeliaBenquet Where are these decoders used around the codebase, and how are they trained?

@@ -0,0 +1,38 @@
import torch.nn as nn

Member

Why not have the decoders somewhere like integrations? To me, models means encoders only. cc @stes

@stes stes left a comment (Member)

Looks good overall; left some comments!

  • Implementation of the Mixin class for the masking: if I understood correctly, the only change is that this apply_mask function is applied after loading a batch. This seems to be a change that could be applied in a minimally invasive way not in the dataset, but in the data loader (see the sketch after this list). Is there a good case why the datasets themselves need to be modified?
  • Discussion on where to place the decoders: currently in cebra.models.decoders; are the decoders useful as "standalone" models? Where are they currently used? Based on that, we could decide whether to move them, e.g. as standalone models, to integrations.
  • See the other comments; mostly on class design, removing duplicated code, etc.
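
A rough sketch of the loader-side alternative raised in the first bullet, for illustration only: it assumes the base cebra.data.Loader yields Batch objects from __iter__ and that apply_mask would then live on the loader side; the mixin name and example composition are hypothetical.

import cebra.data


class MaskedLoaderMixin:
    """Masks each batch after it is loaded, leaving the datasets untouched."""

    def __iter__(self):
        # Wrap the plain loader iteration and mask every field of the batch.
        for batch in super().__iter__():
            yield cebra.data.Batch(
                reference=self.apply_mask(batch.reference),
                positive=self.apply_mask(batch.positive),
                negative=self.apply_mask(batch.negative),
            )


# e.g. class MaskedContinuousDataLoader(MaskedLoaderMixin,
#                                       cebra.data.ContinuousDataLoader): ...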

Member

I wouldn't add two new modules. Use cebra.data.mask or cebra.data.masking, and keep the code together, I'd say.

Comment on lines +100 to +123
        if hasattr(self, "apply_mask"):
            batch = [
                cebra_data.Batch(
                    reference=self.apply_mask(
                        session[index.reference[session_id]]),
                    positive=self.apply_mask(
                        session[index.positive[session_id]]),
                    negative=self.apply_mask(
                        session[index.negative[session_id]]),
                    index=index.index,
                    index_reversed=index.index_reversed,
                ) for session_id, session in enumerate(self.iter_sessions())
            ]
        else:
            batch = [
                cebra_data.Batch(
                    reference=session[index.reference[session_id]],
                    positive=session[index.positive[session_id]],
                    negative=session[index.negative[session_id]],
                    index=index.index,
                    index_reversed=index.index_reversed,
                ) for session_id, session in enumerate(self.iter_sessions())
            ]
        return batch

Member

Can we convert this if/else statement into a subclass?
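
For illustration, one possible shape of that subclass, assuming the existing multisession dataset keeps the plain load_batch and the masked variant wraps it (the class name is hypothetical, not the final code):

class MaskedMultiSessionDataset(MultiSessionDataset):
    """Masked variant that reuses the unmasked per-session batch loading."""

    def load_batch(self, index):
        # Fetch the batches as usual, then mask every field per session,
        # so the if/else disappears from the base implementation.
        return [
            cebra_data.Batch(
                reference=self.apply_mask(batch.reference),
                positive=self.apply_mask(batch.positive),
                negative=self.apply_mask(batch.negative),
                index=batch.index,
                index_reversed=batch.index_reversed,
            ) for batch in super().load_batch(index)
        ]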

)
session.configure_for(model[i])
else:
session.configure_for(model)

Member

I would restructure. Make one common base class, then the "old" multisession class is one subclass of it, and the unified one is another subclass. That removes all if statements and cleanly separates the logic.
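
A rough sketch of that split (all names below are hypothetical and only meant to show the structure): the shared logic lives in a common base class, and each variant only decides which model a session is configured with, so the if statements disappear.

class _MultiSessionBase:

    def _model_for_session(self, model, session_id):
        raise NotImplementedError

    def _configure_sessions(self, dataset, model):
        # Shared logic: every session is configured exactly once.
        for i, session in enumerate(dataset.iter_sessions()):
            session.configure_for(self._model_for_session(model, i))


class MultiSessionVariant(_MultiSessionBase):
    # "old" behaviour: one model per session
    def _model_for_session(self, model, session_id):
        return model[session_id]


class UnifiedVariant(_MultiSessionBase):
    # unified encoder: a single model shared across sessions
    def _model_for_session(self, model, session_id):
        return model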

Comment on lines +67 to +80
        if hasattr(self, "apply_mask"):
            # If the dataset has a mask, apply it to the data.
            batch = Batch(
                positive=self.apply_mask(self[index.positive]),
                negative=self.apply_mask(self[index.negative]),
                reference=self.apply_mask(self[index.reference]),
            )
        else:
            batch = Batch(
                positive=self[index.positive],
                negative=self[index.negative],
                reference=self[index.reference],
            )
        return batch

Member

See above; a better way to implement this is to have the masking simply override the load_batch function, rather than introducing this if/else logic.
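
A minimal sketch of that suggestion, assuming a dataset with the plain load_batch shown above and putting the masking into a mixin that overrides it (the composition below is illustrative, not the final implementation):

class MaskedMixin:
    """Masking that wraps the dataset's own load_batch via the MRO."""

    def apply_mask(self, data):
        ...  # apply the configured masking scheme to the samples

    def load_batch(self, index):
        batch = super().load_batch(index)  # the unmasked implementation
        return Batch(
            reference=self.apply_mask(batch.reference),
            positive=self.apply_mask(batch.positive),
            negative=self.apply_mask(batch.negative),
        )


# Put the mixin first so its load_batch wraps the dataset's, e.g.
# class MaskedTensorDataset(MaskedMixin, cebra.data.TensorDataset): ...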

@@ -33,6 +33,8 @@
from cebra.datasets import register

_DEFAULT_NUM_TIMEPOINTS = 1_000
NUMS_NEURAL = [3, 4, 5]

Member

Suggested change:
-NUMS_NEURAL = [3, 4, 5]
+_NUMS_NEURAL = [3, 4, 5]

Not public (and adapt below).

Comment on lines +40 to +45
Masking helpers
----------------

.. automodule:: cebra.data.masking
:members:
:show-inheritance:

Member

As written above, I would only add a single module dedicated to masking, rather than splitting this up further.

Comment on lines +24 to +25
#### Tests for Mask class ####

Member

Suggested change:
-#### Tests for Mask class ####
(the suggestion removes this comment line)
assert emb.shape == (loader.dataset.num_timepoints, 3)

emb = solver.transform(data, labels, session_id=i, batch_size=300)
assert emb.shape == (loader.dataset.num_timepoints, 3)

Member

Nit: it looks like pre-commit was not run on this PR.

The `set_masks` method should be called to set the masking types
and their corresponding probabilities.
"""
masks = [] # a list of Mask instances

Member

does this need to be public?

@@ -36,7 +37,7 @@
__all__ = ["Dataset", "Loader"]


class Dataset(abc.ABC, cebra.io.HasDevice):
class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin):

Member

Is this Mixin used anywhere else in the codebase?
