1
1
from __future__ import annotations
2
- from pydantic import BaseModel , PositiveInt
3
- from typing import (
4
- Tuple ,
5
- Dict ,
6
- List ,
7
- Callable ,
8
- Optional ,
9
- TypeVar ,
10
- Generic ,
11
- Union ,
12
- )
2
+
3
+ from typing import Callable , Dict , Generic , List , Optional , Tuple , TypeVar , Union
4
+
13
5
import numpy as np
14
6
import torch
15
- from pathlib import Path
7
+ from pydantic import BaseModel , PositiveInt
16
8
17
9
from datastream import Dataset
18
10
from datastream .samplers import (
19
- StandardSampler ,
20
11
MergeSampler ,
21
- ZipSampler ,
22
12
MultiSampler ,
23
13
RepeatSampler ,
14
+ StandardSampler ,
15
+ ZipSampler ,
24
16
)
25
17
26
-
27
18
T = TypeVar ("T" )
28
19
R = TypeVar ("R" )
29
20
@@ -46,7 +37,7 @@ class Datastream(BaseModel, Generic[T]):
46
37
16
47
38
"""
48
39
49
- dataset : Dataset [ T ]
40
+ dataset : Dataset
50
41
sampler : Optional [torch .utils .data .Sampler ]
51
42
52
43
class Config :
@@ -286,29 +277,25 @@ def cache(
286
277
287
278
288
279
def test_infinite():
    """Pull ten batches of size eight from a three-item datastream.

    The dataset holds only three items while every batch requests eight,
    so the loader presumably has to keep sampling past one pass over the
    data. The test asserts nothing about batch contents; it passes as long
    as ten consecutive ``next`` calls succeed.
    """
    letters = Dataset.from_subscriptable(list("abc"))
    loader = Datastream(letters).data_loader(batch_size=8, n_batches_per_epoch=10)
    batches = iter(loader)
    for _ in range(10):
        next(batches)
294
284
295
285
296
286
def test_iter():
    """Iterating a datastream built on three items yields three items."""
    letters = Dataset.from_subscriptable(list("abc"))
    items = list(Datastream(letters))
    assert len(items) == 3
300
289
301
290
302
291
def test_empty():
    """Building a datastream on an empty dataset raises ``ValueError``."""
    import pytest

    # The dataset is created inside the ``raises`` block, as in the original,
    # so a ValueError from either constructor satisfies the test.
    with pytest.raises(ValueError):
        empty_dataset = Dataset.from_subscriptable([])
        Datastream(empty_dataset)
308
296
309
297
310
298
def test_datastream_merge ():
311
-
312
299
datastream = Datastream .merge (
313
300
[
314
301
Datastream (Dataset .from_subscriptable (list ("abc" ))),
@@ -328,7 +315,6 @@ def test_datastream_merge():
328
315
329
316
330
317
def test_datastream_zip ():
331
-
332
318
datasets = [
333
319
Dataset .from_subscriptable ([1 , 2 ]),
334
320
Dataset .from_subscriptable ([3 , 4 , 5 ]),
@@ -384,7 +370,6 @@ def ZippedMergedDatastream():
384
370
385
371
386
372
def test_datastream_simple_weights ():
387
-
388
373
dataset = Dataset .from_subscriptable ([1 , 2 , 3 , 4 ])
389
374
datastream = (
390
375
Datastream (dataset )
@@ -412,7 +397,6 @@ def test_datastream_simple_weights():
412
397
413
398
414
399
def test_merge_datastream_weights ():
415
-
416
400
datasets = [
417
401
Dataset .from_subscriptable ([1 , 2 ]),
418
402
Dataset .from_subscriptable ([3 , 4 , 5 ]),
@@ -441,7 +425,6 @@ def test_merge_datastream_weights():
441
425
442
426
443
427
def test_multi_sample ():
444
-
445
428
data = [1 , 2 , 4 ]
446
429
n_multi_sample = 2
447
430
@@ -475,7 +458,6 @@ def test_multi_sample():
475
458
476
459
477
460
def test_take ():
478
-
479
461
import pytest
480
462
481
463
datastream = Datastream (Dataset .from_subscriptable (list ("abc" ))).take (2 )
@@ -494,7 +476,6 @@ def test_take():
494
476
495
477
496
478
def test_sequential_sampler ():
497
-
498
479
from datastream .samplers import SequentialSampler
499
480
500
481
dataset = Dataset .from_subscriptable (list ("abc" ))
0 commit comments