-
Notifications
You must be signed in to change notification settings - Fork 88
Add parameters for MixedDataLoader #101
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
248b962
dc1c77c
6763dc1
9835d45
8dee8a0
0326fb9
545c9a0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -265,30 +265,69 @@ class MixedDataLoader(cebra_data.Loader): | |||
|
||||
Sampling can be configured in different modes: | ||||
|
||||
1. Positive pairs always share their discrete variable. | ||||
1. Positive pairs always share their discrete variable (positive_sampling = "discrete_variable"). | ||||
2. Positive pairs are drawn only based on their conditional, | ||||
not discrete variable. | ||||
not discrete variable (positive_sampling = "conditional"). | ||||
|
||||
When using the discrete variable, the prior distribution can either be uniform | ||||
(discrete_sampling_prior = "uniform") or empirical (discrete_sampling_prior = "empirical"). | ||||
|
||||
Based on the selection of those parameters, the :py:class:`cebra.distributions.mixed.MixedTimeDeltaDistribution`, | ||||
:py:class:`cebra.distributions.discrete.DiscreteEmpirical`, or :py:class:`cebra.distributions.discrete.DiscreteUniform` | ||||
distributions are used for sampling. | ||||
|
||||
Args: | ||||
conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional` | ||||
time_offset (int): :py:attr:`cebra.CEBRA.time_offsets` | ||||
positive_sampling (str): either "discrete_variable" (default) or "conditional" | ||||
discrete_sampling_prior (str): either "empirical" or "uniform" (default) | ||||
""" | ||||
|
||||
conditional: str = dataclasses.field(default="time_delta") | ||||
time_offset: int = dataclasses.field(default=10) | ||||
positive_sampling: str = dataclasses.field(default="discrete_variable") | ||||
discrete_sampling_prior: str = dataclasses.field(default="uniform") | ||||
|
||||
@property | ||||
def dindex(self): | ||||
# TODO(stes) rename to discrete_index | ||||
warnings.warn("dindex is deprecated. Use discrete_index instead.", | ||||
DeprecationWarning) | ||||
return self.dataset.discrete_index | ||||
|
||||
@property | ||||
def discrete_index(self): | ||||
return self.dataset.discrete_index | ||||
|
||||
@property | ||||
def cindex(self): | ||||
# TODO(stes) rename to continuous_index | ||||
warnings.warn("cindex is deprecated. Use continuous_index instead.", | ||||
DeprecationWarning) | ||||
return self.dataset.continuous_index | ||||
|
||||
@property | ||||
def continuous_index(self): | ||||
return self.dataset.continuous_index | ||||
|
||||
def __post_init__(self): | ||||
super().__post_init__() | ||||
self.distribution = cebra.distributions.MixedTimeDeltaDistribution( | ||||
discrete=self.dindex, | ||||
continuous=self.cindex, | ||||
time_delta=self.time_offset) | ||||
if self.positive_sampling == "conditional": | ||||
self.distribution = cebra.distributions.MixedTimeDeltaDistribution( | ||||
discrete=self.discrete_index, | ||||
continuous=self.continuous_index, | ||||
time_delta=self.time_offset) | ||||
Comment on lines
+313
to
+317
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This needs to be the default behavior; that was how the class used to behave. |
||||
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical": | ||||
self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index) | ||||
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform": | ||||
self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index) | ||||
Comment on lines
+318
to
+321
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are these modes different from going for the empirical discrete / uniform discrete distribution in the first place? I think what we rather want is specify an option to the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree, but I understood that the current docstring of CEBRA/cebra/data/single_session.py Line 268 in 9898850
Even though I agree that it wouldn't make sense in this case to call the |
||||
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]: | ||||
raise ValueError( | ||||
f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but " | ||||
f"only accept 'uniform' or 'empirical' as potential values.") | ||||
else: | ||||
raise ValueError( | ||||
f"Invalid positive sampling mode: " | ||||
f"{self.positive_sampling} valid options are " | ||||
f"'conditional' or 'discrete_variable'.") | ||||
|
||||
def get_indices(self, num_samples: int) -> BatchIndex: | ||||
"""Samples indices for reference, positive and negative examples. | ||||
|
@@ -313,12 +352,15 @@ def get_indices(self, num_samples: int) -> BatchIndex: | |||
class. | ||||
- Sample the negatives with matching discrete variable | ||||
""" | ||||
reference_idx = self.distribution.sample_prior(num_samples) | ||||
return BatchIndex( | ||||
reference=reference_idx, | ||||
negative=self.distribution.sample_prior(num_samples), | ||||
positive=self.distribution.sample_conditional(reference_idx), | ||||
) | ||||
if self.positive_sampling == "conditional": | ||||
reference_idx = self.distribution.sample_prior(num_samples) | ||||
return BatchIndex( | ||||
reference=reference_idx, | ||||
negative=self.distribution.sample_prior(num_samples), | ||||
positive=self.distribution.sample_conditional(reference_idx), | ||||
) | ||||
else: | ||||
return self.distribution.get_indices(num_samples) | ||||
|
||||
|
||||
@dataclasses.dataclass | ||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -186,6 +186,37 @@ def test_continuous(conditional, device, benchmark): | |
benchmark(load_speed) | ||
|
||
|
||
@parametrize_device | ||
@pytest.mark.parametrize( | ||
"conditional, positive_sampling, discrete_sampling_prior", | ||
[ | ||
("time", "discrete_variable", "empirical"), | ||
("time", "conditional", "empirical"), | ||
("time", "discrete_variable", "uniform"), | ||
("time", "conditional", "uniform"), | ||
("time_delta", "discrete_variable", "empirical"), | ||
("time_delta", "conditional", "empirical"), | ||
("time_delta", "discrete_variable", "uniform"), | ||
("time_delta", "conditional", "uniform"), | ||
], | ||
) | ||
def test_mixed( | ||
conditional, positive_sampling, discrete_sampling_prior, device, benchmark | ||
): | ||
dataset = RandomDataset(N=100, d=5, device=device) | ||
loader = cebra.data.MixedDataLoader( | ||
dataset=dataset, | ||
num_steps=10, | ||
batch_size=8, | ||
conditional=conditional, | ||
positive_sampling=positive_sampling, | ||
discrete_sampling_prior=discrete_sampling_prior, | ||
) | ||
_assert_dataset_on_correct_device(loader, device) | ||
load_speed = LoadSpeed(loader) | ||
benchmark(load_speed) | ||
Comment on lines
+206
to
+217
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should extend the test to check the properties of the positive and negative samples (e.g., check if the discrete labels match and so forth, as expected for each setting of parameters). |
||
|
||
|
||
def _check_attributes(obj, is_list=False): | ||
if is_list: | ||
for obj_ in obj: | ||
|
Uh oh!
There was an error while loading. Please reload this page.