@@ -261,27 +261,47 @@ class MixedDataLoader(cebra_data.Loader):
261
261
1. Positive pairs always share their discrete variable.
262
262
2. Positive pairs are drawn only based on their conditional,
263
263
not discrete variable.
264
+
265
+ Args:
266
+ conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
267
+ time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
268
+ positive_sampling (str): either "discrete_variable" (default) or "conditional"
269
+ discrete_sampling_prior (str): either "empirical" (default) or "uniform"
264
270
"""
265
271
266
272
conditional : str = dataclasses .field (default = "time_delta" )
267
273
time_offset : int = dataclasses .field (default = 10 )
274
+ positive_sampling : str = dataclasses .field (default = "discrete_variable" )
275
+ discrete_sampling_prior : str = dataclasses .field (default = "uniform" )
268
276
269
277
@property
270
- def dindex (self ):
271
- # TODO(stes) rename to discrete_index
278
+ def discrete_index (self ):
272
279
return self .dataset .discrete_index
273
280
274
281
@property
275
- def cindex (self ):
276
- # TODO(stes) rename to continuous_index
282
+ def continuous_index (self ):
277
283
return self .dataset .continuous_index
278
284
279
285
def __post_init__ (self ):
280
286
super ().__post_init__ ()
281
- self .distribution = cebra .distributions .MixedTimeDeltaDistribution (
282
- discrete = self .dindex ,
283
- continuous = self .cindex ,
284
- time_delta = self .time_offset )
287
+ if self .positive_sampling == "conditional" :
288
+ self .distribution = cebra .distributions .MixedTimeDeltaDistribution (
289
+ discrete = self .discrete_index ,
290
+ continuous = self .continuous_index ,
291
+ time_delta = self .time_offset )
292
+ elif self .positive_sampling == "discrete_variable" and self .discrete_sampling_prior == "empirical" :
293
+ self .distribution = cebra .distributions .DiscreteEmpirical (self .discrete_index )
294
+ elif self .positive_sampling == "discrete_variable" and self .discrete_sampling_prior == "uniform" :
295
+ self .distribution = cebra .distributions .DiscreteUniform (self .discrete_index )
296
+ elif self .positive_sampling == "discrete_variable" and self .discrete_sampling_prior not in ["empirical" , "uniform" ]:
297
+ raise ValueError (
298
+ f"Invalid choice of prior distribution. Got '{ self .discrete_sampling_prior } ', but "
299
+ f"only accept 'uniform' or 'empirical' as potential values." )
300
+ else :
301
+ raise ValueError (
302
+ f"Invalid positive sampling mode: "
303
+ f"{ self .positive_sampling } valid options are "
304
+ f"'conditional' or 'discrete_variable'." )
285
305
286
306
def get_indices (self , num_samples : int ) -> BatchIndex :
287
307
"""Samples indices for reference, positive and negative examples.
@@ -306,12 +326,23 @@ def get_indices(self, num_samples: int) -> BatchIndex:
306
326
class.
307
327
- Sample the negatives with matching discrete variable
308
328
"""
309
- reference_idx = self .distribution .sample_prior (num_samples )
310
- return BatchIndex (
311
- reference = reference_idx ,
312
- negative = self .distribution .sample_prior (num_samples ),
313
- positive = self .distribution .sample_conditional (reference_idx ),
314
- )
329
+ if self .positive_sampling == "conditional" :
330
+ reference_idx = self .distribution .sample_prior (num_samples )
331
+ return BatchIndex (
332
+ reference = reference_idx ,
333
+ negative = self .distribution .sample_prior (num_samples ),
334
+ positive = self .distribution .sample_conditional (reference_idx ),
335
+ )
336
+ else :
337
+ # taken from the DiscreteDataLoader get_indices function
338
+ reference_idx = self .distribution .sample_prior (num_samples * 2 )
339
+ negative_idx = reference_idx [num_samples :]
340
+ reference_idx = reference_idx [:num_samples ]
341
+ reference = self .discrete_index [reference_idx ]
342
+ positive_idx = self .distribution .sample_conditional (reference )
343
+ return BatchIndex (reference = reference_idx ,
344
+ positive = positive_idx ,
345
+ negative = negative_idx )
315
346
316
347
317
348
@dataclasses .dataclass
0 commit comments