diff --git a/datastream/datastream.py b/datastream/datastream.py
index c57895c..0729df4 100644
--- a/datastream/datastream.py
+++ b/datastream/datastream.py
@@ -584,3 +584,29 @@ def test_last_batch():
         SequentialSampler(3),
     )
     assert list(map(len, datastream.data_loader(batch_size=2))) == [2, 1]
+
+
+def test_seeded_random_sampler():
+    dataset = Dataset.from_subscriptable(np.arange(100))
+    datastream = Datastream(dataset, sampler=StandardSampler(len(dataset), seed=1))
+
+    loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
+    batches1 = [batch for batch in loader]
+    batches2 = [batch for batch in loader]
+    assert all(
+        batch1[0] == batch2[0]
+        for batch1, batch2 in zip(batches1, batches2)
+    )
+
+
+def test_unseeded_random_sampler():
+    dataset = Dataset.from_subscriptable(np.arange(100))
+    datastream = Datastream(dataset, sampler=StandardSampler(len(dataset)))
+
+    loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
+    batches1 = [batch for batch in loader]
+    batches2 = [batch for batch in loader]
+    assert any(
+        batch1[0] != batch2[0]
+        for batch1, batch2 in zip(batches1, batches2)
+    )
diff --git a/datastream/samplers/standard_sampler.py b/datastream/samplers/standard_sampler.py
index 113366d..302fb70 100644
--- a/datastream/samplers/standard_sampler.py
+++ b/datastream/samplers/standard_sampler.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from pydantic import BaseModel
+from typing import Optional
 import torch
 
 
@@ -7,12 +8,19 @@ class StandardSampler(BaseModel, torch.utils.data.Sampler):
     proportion: float
     replacement: bool
     sampler: torch.utils.data.WeightedRandomSampler
+    seed: Optional[int]
+    generator: Optional[torch.Generator]
 
     class Config:
         arbitrary_types_allowed = True
         allow_mutation = False
 
-    def __init__(self, length, proportion=1.0, replacement=False):
+    def __init__(self, length, proportion=1.0, replacement=False, seed=None):
+        if seed is not None:
+            generator = torch.Generator()
+            generator.manual_seed(seed)
+        else:
+            generator = None
         BaseModel.__init__(
             self,
             proportion=proportion,
@@ -21,13 +29,18 @@ def __init__(self, length, proportion=1.0, replacement=False):
                 torch.ones(length).double(),
                 num_samples=int(max(1, min(length, length * proportion))),
                 replacement=replacement,
-            )
+                generator=generator,
+            ),
+            seed=seed,
+            generator=generator,
         )
 
     def __len__(self):
         return len(self.sampler)
 
     def __iter__(self):
+        if self.generator is not None:
+            self.generator.manual_seed(self.seed)
         return iter(self.sampler)
 
     @property
@@ -51,6 +64,7 @@ def sample_proportion(self, proportion):
             len(self),
             proportion,
             self.replacement,
+            self.seed,
         )
         sampler.sampler.weights = self.sampler.weights
         return sampler