Skip to content

Commit f1894a1

Browse files
committed
add positive sampling options for MixedDataLoader
1 parent 0378db0 commit f1894a1

File tree

1 file changed

+45
-14
lines changed

1 file changed

+45
-14
lines changed

cebra/data/single_session.py

Lines changed: 45 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -261,27 +261,47 @@ class MixedDataLoader(cebra_data.Loader):
261261
1. Positive pairs always share their discrete variable.
262262
2. Positive pairs are drawn only based on their conditional,
263263
not discrete variable.
264+
265+
Args:
266+
conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
267+
time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
268+
positive_sampling (str): either "discrete_variable" (default) or "conditional"
269+
discrete_sampling_prior (str): either "empirical" (default) or "uniform"
264270
"""
265271

266272
conditional: str = dataclasses.field(default="time_delta")
267273
time_offset: int = dataclasses.field(default=10)
274+
positive_sampling: str = dataclasses.field(default="discrete_variable")
275+
discrete_sampling_prior: str = dataclasses.field(default="uniform")
268276

269277
@property
270-
def dindex(self):
271-
# TODO(stes) rename to discrete_index
278+
def discrete_index(self):
272279
return self.dataset.discrete_index
273280

274281
@property
275-
def cindex(self):
276-
# TODO(stes) rename to continuous_index
282+
def continuous_index(self):
277283
return self.dataset.continuous_index
278284

279285
def __post_init__(self):
280286
super().__post_init__()
281-
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
282-
discrete=self.dindex,
283-
continuous=self.cindex,
284-
time_delta=self.time_offset)
287+
if self.positive_sampling == "conditional":
288+
self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
289+
discrete=self.discrete_index,
290+
continuous=self.continuous_index,
291+
time_delta=self.time_offset)
292+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
293+
self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
294+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
295+
self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
296+
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior not in ["empirical", "uniform"]:
297+
raise ValueError(
298+
f"Invalid choice of prior distribution. Got '{self.discrete_sampling_prior}', but "
299+
f"only accept 'uniform' or 'empirical' as potential values.")
300+
else:
301+
raise ValueError(
302+
f"Invalid positive sampling mode: "
303+
f"{self.positive_sampling} valid options are "
304+
f"'conditional' or 'discrete_variable'.")
285305

286306
def get_indices(self, num_samples: int) -> BatchIndex:
287307
"""Samples indices for reference, positive and negative examples.
@@ -306,12 +326,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
306326
class.
307327
- Sample the negatives with matching discrete variable
308328
"""
309-
reference_idx = self.distribution.sample_prior(num_samples)
310-
return BatchIndex(
311-
reference=reference_idx,
312-
negative=self.distribution.sample_prior(num_samples),
313-
positive=self.distribution.sample_conditional(reference_idx),
314-
)
329+
if self.positive_sampling == "conditional":
330+
reference_idx = self.distribution.sample_prior(num_samples)
331+
return BatchIndex(
332+
reference=reference_idx,
333+
negative=self.distribution.sample_prior(num_samples),
334+
positive=self.distribution.sample_conditional(reference_idx),
335+
)
336+
else:
337+
# taken from the DiscreteDataLoader get_indices function
338+
reference_idx = self.distribution.sample_prior(num_samples * 2)
339+
negative_idx = reference_idx[num_samples:]
340+
reference_idx = reference_idx[:num_samples]
341+
reference = self.discrete_index[reference_idx]
342+
positive_idx = self.distribution.sample_conditional(reference)
343+
return BatchIndex(reference=reference_idx,
344+
positive=positive_idx,
345+
negative=negative_idx)
315346

316347

317348
@dataclasses.dataclass

0 commit comments

Comments
 (0)