1
1
from __future__ import annotations
2
2
from pydantic import BaseModel
3
+ from typing import Optional
3
4
import torch
4
5
5
6
6
7
class StandardSampler (BaseModel , torch .utils .data .Sampler ):
7
8
proportion : float
8
9
replacement : bool
9
10
sampler : torch .utils .data .WeightedRandomSampler
11
+ seed : Optional [int ]
12
+ generator : Optional [torch .Generator ]
10
13
11
14
class Config :
12
15
arbitrary_types_allowed = True
13
16
allow_mutation = False
14
17
15
- def __init__ (self , length , proportion = 1.0 , replacement = False ):
18
+ def __init__ (self , length , proportion = 1.0 , replacement = False , seed = None ):
19
+ if seed is not None :
20
+ generator = torch .Generator ()
21
+ generator .manual_seed (seed )
22
+ else :
23
+ generator = None
16
24
BaseModel .__init__ (
17
25
self ,
18
26
proportion = proportion ,
@@ -21,13 +29,18 @@ def __init__(self, length, proportion=1.0, replacement=False):
21
29
torch .ones (length ).double (),
22
30
num_samples = int (max (1 , min (length , length * proportion ))),
23
31
replacement = replacement ,
24
- )
32
+ generator = generator ,
33
+ ),
34
+ seed = seed ,
35
+ generator = generator ,
25
36
)
26
37
27
38
def __len__ (self ):
28
39
return len (self .sampler )
29
40
30
41
def __iter__ (self ):
42
+ if self .generator is not None :
43
+ self .generator .manual_seed (self .seed )
31
44
return iter (self .sampler )
32
45
33
46
@property
@@ -51,6 +64,7 @@ def sample_proportion(self, proportion):
51
64
len (self ),
52
65
proportion ,
53
66
self .replacement ,
67
+ self .seed ,
54
68
)
55
69
sampler .sampler .weights = self .sampler .weights
56
70
return sampler
0 commit comments