@@ -435,15 +435,32 @@ def get_full_pattern(self, dt):
435435 """
436436
437437 full_pattern = []
438- for _ , pattern in self .iter (0 , self .duration , dt ):
438+ for _ , pattern in self .iter (0 , self .duration - 0.5 * dt , dt ):
439439 full_pattern .append (pattern )
440440
441441 return np .array (full_pattern )
442442
443443
444444class ShuffleDataloader (BaseDataloader ):
445445 """
446- TODO.
446+ A dataloader that shuffles patterns and applies transformations.
447+
448+ This class extends `BaseDataloader` and is designed to load data patterns,
449+ shuffle them, and apply transformations both preemptively and online.
450+
451+ :param pattern: List of patterns to shuffle and iterate over.
452+ :type pattern: List[Pattern]
453+ :param t_max: Maximum time for the sequence.
454+ :type t_max: float
455+ :param t_start: Starting time for the sequence, defaults to 0.0.
456+ :type t_start: float, optional
457+ :param seed: Seed for random number generation, defaults to None.
458+ :type seed: Optional[int], optional
459+ :param pre_transforms: List of preprocessing transformations to apply once.
460+ :type pre_transforms: List[Callable]
461+ :param online_transforms: List of online transformations to apply during iteration,
462+ defaults to an empty list.
463+ :type online_transforms: List[Callable]
447464 """
448465
449466 def __init__ (
@@ -456,7 +473,23 @@ def __init__(
456473 online_transforms : List [Callable ] = [],
457474 ) -> None :
458475 """
459- TODO.
476+ Initialize the ShuffleDataloader instance.
477+
478+ Creates a shuffled sequence of patterns and applies pre-transforms.
479+
480+ :param pattern: List of patterns to shuffle and iterate over.
481+ :type pattern: List[Pattern]
482+ :param t_max: Maximum time for the sequence.
483+ :type t_max: float
484+ :param t_start: Starting time for the sequence, defaults to 0.0.
485+ :type t_start: float, optional
486+ :param seed: Seed for random number generation, defaults to None.
487+ :type seed: Optional[int], optional
488+ :param pre_transforms: List of preprocessing transformations to apply once.
489+ :type pre_transforms: List[Callable]
490+ :param online_transforms: List of online transformations to apply during
491+ iteration, defaults to an empty list.
492+ :type online_transforms: List[Callable]
460493 """
461494 self .pattern = copy .deepcopy (pattern )
462495 self .num_pattern = len (pattern )
@@ -490,15 +523,40 @@ def __init__(
490523 )
491524
492525 def _apply_online_transforms (self , pattern_1d ):
526+ """
527+ Apply online transformations to a pattern.
528+
529+ This method applies all specified online transformations in sequence
530+ to the given 1-dimensional pattern.
531+
532+ :param pattern_1d: The input pattern to transform.
533+ :type pattern_1d: Any
534+ :return: Transformed pattern.
535+ :rtype: Any
536+ """
493537 for transform in self .online_transforms :
494538 pattern_1d = transform (pattern_1d )
495539 return pattern_1d
496540
497541 def __call__ (self , t : float , offset : float = 1e-6 ):
498542 """
499- TODO.
500- """
543+ Retrieve the pattern at a specific time `t`.
544+
545+ This method finds the appropriate pattern based on the given time `t`,
546+ applies online transformations, and returns the transformed result.
547+
548+ :param t: Time at which to retrieve the pattern.
549+ :type t: float
550+ :param offset: Small offset added to `t` for numerical stability,
551+ defaults to 1e-6.
552+ :type offset: float, optional
553+ :return: Transformed pattern at time `t`.
554+ :rtype: Any
501555
556+ :raises ValueError: If `t` exceeds `t_max`.
557+ """
558+ if t > self .t_max :
559+ raise ValueError ("t must be small smaller than t_max." )
502560 # find index of the pattern that is at t:
503561 t += offset
504562 idx_in_sequence = np .searchsorted (self ._starting_times , t ) - 1
@@ -514,10 +572,58 @@ def __call__(self, t: float, offset: float = 1e-6):
514572 return pattern_t
515573
516574 def iter (self , t_start , t_stop , dt ):
517- raise NotImplementedError ()
575+ """
576+ Use dataloader as an iterator/iterable.
518577
519- def get_full_pattern (self , dt ):
520- raise NotImplementedError ()
578+ Generates tuples of time and corresponding patterns between `t_start`
579+ and `t_stop` with a step size of `dt`.
580+
581+ :param t_start: Start time of iteration.
582+ :type t_start: float
583+ :param t_stop: Stop time of iteration.
584+ :type t_stop: float
585+ :param dt: Time step between iterations.
586+ :type dt: float
587+ :yield: Tuple containing time and corresponding pattern.
588+ :rtype: Iterator[Tuple[float, np.ndarray]]
589+
590+ Example:
591+ >>> for t, pattern in dataloader.iter(t_start=0.0, t_stop=10.0, dt=0.1):
592+ ... print(t, pattern)
593+ ...
594+ 0.0 [pattern data]
595+ 0.1 [pattern data]
596+ ...
597+ 9.9 [pattern data]
598+ """
599+ t = t_start
600+ while t < t_stop :
601+ yield t , self .__call__ (t , offset = dt * 0.01 )
602+ t += dt
603+
604+ def get_full_pattern (self , dt , idx : Optional [List [int ]] = None ):
605+ """
606+ Retrieve the full concatenated pattern sequence.
607+
608+ Combines all selected patterns into a single array after applying online
609+ transformations.
610+
611+ :param dt: Time step used for sampling patterns.
612+ :type dt: float
613+ :param idx: Indices of patterns to include in the concatenation. Defaults to
614+ all patterns if None.
615+ :type idx: Optional[List[int]]
616+ :return: The concatenated array of transformed patterns.
617+ :rtype: np.ndarray
618+ """
619+ full_pattern = []
620+ if idx is None :
621+ idx = list (range (self .num_pattern ))
622+ for i in idx :
623+ dl = Dataloader (self .pattern [i ], online_transforms = self .online_transforms )
624+ full_pattern .append (dl .get_full_pattern (dt ))
625+
626+ return np .concatenate (full_pattern , axis = 0 )
521627
522628
523629class ContinuousDataloader (BaseDataloader ):
0 commit comments