Skip to content

Commit 40e9b5f

Browse files
committed
Add the iter and get_full_pattern methods
and teh respective tests for the ShuffleDataloader
1 parent f0a7153 commit 40e9b5f

File tree

2 files changed

+160
-8
lines changed

2 files changed

+160
-8
lines changed

src/elise/data.py

Lines changed: 114 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -435,15 +435,32 @@ def get_full_pattern(self, dt):
435435
"""
436436

437437
full_pattern = []
438-
for _, pattern in self.iter(0, self.duration, dt):
438+
for _, pattern in self.iter(0, self.duration - 0.5 * dt, dt):
439439
full_pattern.append(pattern)
440440

441441
return np.array(full_pattern)
442442

443443

444444
class ShuffleDataloader(BaseDataloader):
445445
"""
446-
TODO.
446+
A dataloader that shuffles patterns and applies transformations.
447+
448+
This class extends `BaseDataloader` and is designed to load data patterns,
449+
shuffle them, and apply transformations both preemptively and online.
450+
451+
:param pattern: List of patterns to shuffle and iterate over.
452+
:type pattern: List[Pattern]
453+
:param t_max: Maximum time for the sequence.
454+
:type t_max: float
455+
:param t_start: Starting time for the sequence, defaults to 0.0.
456+
:type t_start: float, optional
457+
:param seed: Seed for random number generation, defaults to None.
458+
:type seed: Optional[int], optional
459+
:param pre_transforms: List of preprocessing transformations to apply once.
460+
:type pre_transforms: List[Callable]
461+
:param online_transforms: List of online transformations to apply during iteration,
462+
defaults to an empty list.
463+
:type online_transforms: List[Callable]
447464
"""
448465

449466
def __init__(
@@ -456,7 +473,23 @@ def __init__(
456473
online_transforms: List[Callable] = [],
457474
) -> None:
458475
"""
459-
TODO.
476+
Initialize the ShuffleDataloader instance.
477+
478+
Creates a shuffled sequence of patterns and applies pre-transforms.
479+
480+
:param pattern: List of patterns to shuffle and iterate over.
481+
:type pattern: List[Pattern]
482+
:param t_max: Maximum time for the sequence.
483+
:type t_max: float
484+
:param t_start: Starting time for the sequence, defaults to 0.0.
485+
:type t_start: float, optional
486+
:param seed: Seed for random number generation, defaults to None.
487+
:type seed: Optional[int], optional
488+
:param pre_transforms: List of preprocessing transformations to apply once.
489+
:type pre_transforms: List[Callable]
490+
:param online_transforms: List of online transformations to apply during
491+
iteration, defaults to an empty list.
492+
:type online_transforms: List[Callable]
460493
"""
461494
self.pattern = copy.deepcopy(pattern)
462495
self.num_pattern = len(pattern)
@@ -490,15 +523,40 @@ def __init__(
490523
)
491524

492525
def _apply_online_transforms(self, pattern_1d):
526+
"""
527+
Apply online transformations to a pattern.
528+
529+
This method applies all specified online transformations in sequence
530+
to the given 1-dimensional pattern.
531+
532+
:param pattern_1d: The input pattern to transform.
533+
:type pattern_1d: Any
534+
:return: Transformed pattern.
535+
:rtype: Any
536+
"""
493537
for transform in self.online_transforms:
494538
pattern_1d = transform(pattern_1d)
495539
return pattern_1d
496540

497541
def __call__(self, t: float, offset: float = 1e-6):
498542
"""
499-
TODO.
500-
"""
543+
Retrieve the pattern at a specific time `t`.
544+
545+
This method finds the appropriate pattern based on the given time `t`,
546+
applies online transformations, and returns the transformed result.
547+
548+
:param t: Time at which to retrieve the pattern.
549+
:type t: float
550+
:param offset: Small offset added to `t` for numerical stability,
551+
defaults to 1e-6.
552+
:type offset: float, optional
553+
:return: Transformed pattern at time `t`.
554+
:rtype: Any
501555
556+
:raises ValueError: If `t` exceeds `t_max`.
557+
"""
558+
if t > self.t_max:
559+
raise ValueError("t must be small smaller than t_max.")
502560
# find index of the pattern that is at t:
503561
t += offset
504562
idx_in_sequence = np.searchsorted(self._starting_times, t) - 1
@@ -514,10 +572,58 @@ def __call__(self, t: float, offset: float = 1e-6):
514572
return pattern_t
515573

516574
def iter(self, t_start, t_stop, dt):
517-
raise NotImplementedError()
575+
"""
576+
Use dataloader as an iterator/iterable.
518577
519-
def get_full_pattern(self, dt):
520-
raise NotImplementedError()
578+
Generates tuples of time and corresponding patterns between `t_start`
579+
and `t_stop` with a step size of `dt`.
580+
581+
:param t_start: Start time of iteration.
582+
:type t_start: float
583+
:param t_stop: Stop time of iteration.
584+
:type t_stop: float
585+
:param dt: Time step between iterations.
586+
:type dt: float
587+
:yield: Tuple containing time and corresponding pattern.
588+
:rtype: Iterator[Tuple[float, np.ndarray]]
589+
590+
Example:
591+
>>> for t, pattern in dataloader.iter(t_start=0.0, t_stop=10.0, dt=0.1):
592+
... print(t, pattern)
593+
...
594+
0.0 [pattern data]
595+
0.1 [pattern data]
596+
...
597+
9.9 [pattern data]
598+
"""
599+
t = t_start
600+
while t < t_stop:
601+
yield t, self.__call__(t, offset=dt * 0.01)
602+
t += dt
603+
604+
def get_full_pattern(self, dt, idx: Optional[List[int]] = None):
605+
"""
606+
Retrieve the full concatenated pattern sequence.
607+
608+
Combines all selected patterns into a single array after applying online
609+
transformations.
610+
611+
:param dt: Time step used for sampling patterns.
612+
:type dt: float
613+
:param idx: Indices of patterns to include in the concatenation. Defaults to
614+
all patterns if None.
615+
:type idx: Optional[List[int]]
616+
:return: The concatenated array of transformed patterns.
617+
:rtype: np.ndarray
618+
"""
619+
full_pattern = []
620+
if idx is None:
621+
idx = list(range(self.num_pattern))
622+
for i in idx:
623+
dl = Dataloader(self.pattern[i], online_transforms=self.online_transforms)
624+
full_pattern.append(dl.get_full_pattern(dt))
625+
626+
return np.concatenate(full_pattern, axis=0)
521627

522628

523629
class ContinuousDataloader(BaseDataloader):

tests/test_data.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,11 @@ def test_iter(self, dataloader_factory, base_sequence):
328328
t_exp += 0.1
329329
idx += 1
330330

331+
def test_get_full_pattern(self, dataloader_factory, base_sequence):
332+
dataloader = dataloader_factory()
333+
res = dataloader.get_full_pattern(DT)
334+
assert_allclose(res, base_sequence)
335+
331336
def test_Dataloader_factory_error(self, dataloader_factory):
332337
with pytest.raises(TypeError):
333338
dataloader_factory(pattern="Not a Pattern.")
@@ -563,3 +568,44 @@ def test_online_transforms(
563568

564569
expected = transform_mult(transform_plus(expected_pattern[idx]))
565570
assert_allclose(dataloader(t), expected)
571+
572+
def test_iter(
573+
self,
574+
multisequence,
575+
dataloader_factory,
576+
shuffled_idx_factory,
577+
multisequence_target_factory,
578+
shuffledataloader_factory,
579+
):
580+
dataloader = shuffledataloader_factory(
581+
pattern=multisequence, length=SHUFFLE_SIZE, seed=SHUFFLE_SEED
582+
)
583+
expected_random_idx = shuffled_idx_factory(
584+
num_pattern=len(multisequence), seed=SHUFFLE_SEED, length=SHUFFLE_SIZE
585+
)
586+
expected_pattern = multisequence_target_factory(
587+
multisequence, expected_random_idx
588+
)
589+
590+
res_iter = [x[1] for x in dataloader.iter(0.0, 1.0, DT)]
591+
592+
for res, exp in zip(res_iter, expected_pattern):
593+
assert_allclose(res, exp)
594+
595+
def test_get_full_pattern(
596+
self,
597+
multisequence,
598+
shuffledataloader_factory,
599+
base_sequence,
600+
base_sequence2,
601+
base_sequence3,
602+
):
603+
dataloader = shuffledataloader_factory(
604+
pattern=multisequence, length=SHUFFLE_SIZE, seed=SHUFFLE_SEED
605+
)
606+
607+
res = dataloader.get_full_pattern(dt=DT)
608+
expected = np.concatenate(
609+
(base_sequence, base_sequence2, base_sequence3), axis=0
610+
)
611+
assert_allclose(res, expected)

0 commit comments

Comments
 (0)