Skip to content

Commit 09494bb

Browse files
author
FelixAbrahamsson
committed
feature: seeded random sampler
1 parent 7f19810 commit 09494bb

File tree

2 files changed

+46
-2
lines changed

2 files changed

+46
-2
lines changed

datastream/datastream.py

+30
Original file line numberDiff line numberDiff line change
@@ -584,3 +584,33 @@ def test_last_batch():
584584
SequentialSampler(3),
585585
)
586586
assert list(map(len, datastream.data_loader(batch_size=2))) == [2, 1]
587+
588+
589+
def test_seeded_random_sampler():
590+
import pandas as pd
591+
592+
dataset = Dataset.from_dataframe(pd.DataFrame(dict(a=np.arange(100))))
593+
datastream = Datastream(dataset, sampler=StandardSampler(len(dataset), seed=1))
594+
595+
loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
596+
batches1 = [batch for batch in loader]
597+
batches2 = [batch for batch in loader]
598+
assert all(
599+
batch1[0]['a'] == batch2[0]['a']
600+
for batch1, batch2 in zip(batches1, batches2)
601+
)
602+
603+
604+
def test_unseeded_random_sampler():
605+
import pandas as pd
606+
607+
dataset = Dataset.from_dataframe(pd.DataFrame(dict(a=np.arange(100))))
608+
datastream = Datastream(dataset, sampler=StandardSampler(len(dataset)))
609+
610+
loader = datastream.data_loader(batch_size=1, collate_fn=tuple)
611+
batches1 = [batch for batch in loader]
612+
batches2 = [batch for batch in loader]
613+
assert any(
614+
batch1[0]['a'] != batch2[0]['a']
615+
for batch1, batch2 in zip(batches1, batches2)
616+
)

datastream/samplers/standard_sampler.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,26 @@
11
from __future__ import annotations
22
from pydantic import BaseModel
3+
from typing import Optional
34
import torch
45

56

67
class StandardSampler(BaseModel, torch.utils.data.Sampler):
78
proportion: float
89
replacement: bool
910
sampler: torch.utils.data.WeightedRandomSampler
11+
seed: Optional[int]
12+
generator: Optional[torch.Generator]
1013

1114
class Config:
1215
arbitrary_types_allowed = True
1316
allow_mutation = False
1417

15-
def __init__(self, length, proportion=1.0, replacement=False):
18+
def __init__(self, length, proportion=1.0, replacement=False, seed=None):
19+
if seed is not None:
20+
generator = torch.Generator()
21+
generator.manual_seed(seed)
22+
else:
23+
generator = None
1624
BaseModel.__init__(
1725
self,
1826
proportion=proportion,
@@ -21,13 +29,18 @@ def __init__(self, length, proportion=1.0, replacement=False):
2129
torch.ones(length).double(),
2230
num_samples=int(max(1, min(length, length * proportion))),
2331
replacement=replacement,
24-
)
32+
generator=generator,
33+
),
34+
seed=seed,
35+
generator=generator,
2536
)
2637

2738
def __len__(self):
2839
return len(self.sampler)
2940

3041
def __iter__(self):
42+
if self.generator is not None:
43+
self.generator.manual_seed(self.seed)
3144
return iter(self.sampler)
3245

3346
@property
@@ -51,6 +64,7 @@ def sample_proportion(self, proportion):
5164
len(self),
5265
proportion,
5366
self.replacement,
67+
self.seed,
5468
)
5569
sampler.sampler.weights = self.sampler.weights
5670
return sampler

0 commit comments

Comments
 (0)