Skip to content

Commit b6b326a

Browse files
committed
improve: support pydantic 2
1 parent 5f0f4f0 commit b6b326a

File tree

5 files changed

+520
-562
lines changed

5 files changed

+520
-562
lines changed

datastream/__init__.py

-7
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,2 @@
11
from datastream.dataset import Dataset
22
from datastream.datastream import Datastream
3-
4-
from pkg_resources import get_distribution, DistributionNotFound
5-
6-
try:
7-
__version__ = get_distribution("pytorch-datastream").version
8-
except DistributionNotFound:
9-
pass

datastream/datastream.py

+7-26
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,20 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel, PositiveInt
3-
from typing import (
4-
Tuple,
5-
Dict,
6-
List,
7-
Callable,
8-
Optional,
9-
TypeVar,
10-
Generic,
11-
Union,
12-
)
2+
3+
from typing import Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
4+
135
import numpy as np
146
import torch
15-
from pathlib import Path
7+
from pydantic import BaseModel, PositiveInt
168

179
from datastream import Dataset
1810
from datastream.samplers import (
19-
StandardSampler,
2011
MergeSampler,
21-
ZipSampler,
2212
MultiSampler,
2313
RepeatSampler,
14+
StandardSampler,
15+
ZipSampler,
2416
)
2517

26-
2718
T = TypeVar("T")
2819
R = TypeVar("R")
2920

@@ -46,7 +37,7 @@ class Datastream(BaseModel, Generic[T]):
4637
16
4738
"""
4839

49-
dataset: Dataset[T]
40+
dataset: Dataset
5041
sampler: Optional[torch.utils.data.Sampler]
5142

5243
class Config:
@@ -286,29 +277,25 @@ def cache(
286277

287278

288279
def test_infinite():
289-
290280
datastream = Datastream(Dataset.from_subscriptable(list("abc")))
291281
it = iter(datastream.data_loader(batch_size=8, n_batches_per_epoch=10))
292282
for _ in range(10):
293283
batch = next(it)
294284

295285

296286
def test_iter():
297-
298287
datastream = Datastream(Dataset.from_subscriptable(list("abc")))
299288
assert len(list(datastream)) == 3
300289

301290

302291
def test_empty():
303-
304292
import pytest
305293

306294
with pytest.raises(ValueError):
307295
Datastream(Dataset.from_subscriptable(list()))
308296

309297

310298
def test_datastream_merge():
311-
312299
datastream = Datastream.merge(
313300
[
314301
Datastream(Dataset.from_subscriptable(list("abc"))),
@@ -328,7 +315,6 @@ def test_datastream_merge():
328315

329316

330317
def test_datastream_zip():
331-
332318
datasets = [
333319
Dataset.from_subscriptable([1, 2]),
334320
Dataset.from_subscriptable([3, 4, 5]),
@@ -384,7 +370,6 @@ def ZippedMergedDatastream():
384370

385371

386372
def test_datastream_simple_weights():
387-
388373
dataset = Dataset.from_subscriptable([1, 2, 3, 4])
389374
datastream = (
390375
Datastream(dataset)
@@ -412,7 +397,6 @@ def test_datastream_simple_weights():
412397

413398

414399
def test_merge_datastream_weights():
415-
416400
datasets = [
417401
Dataset.from_subscriptable([1, 2]),
418402
Dataset.from_subscriptable([3, 4, 5]),
@@ -441,7 +425,6 @@ def test_merge_datastream_weights():
441425

442426

443427
def test_multi_sample():
444-
445428
data = [1, 2, 4]
446429
n_multi_sample = 2
447430

@@ -475,7 +458,6 @@ def test_multi_sample():
475458

476459

477460
def test_take():
478-
479461
import pytest
480462

481463
datastream = Datastream(Dataset.from_subscriptable(list("abc"))).take(2)
@@ -494,7 +476,6 @@ def test_take():
494476

495477

496478
def test_sequential_sampler():
497-
498479
from datastream.samplers import SequentialSampler
499480

500481
dataset = Dataset.from_subscriptable(list("abc"))

datastream/samplers/merge_sampler.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
3-
from typing import Tuple, Callable, Iterable
2+
43
from functools import partial
54
from itertools import chain, islice
5+
from typing import Callable, Iterable, Tuple
6+
67
import torch
7-
from datastream.tools import repeat_map_chain
8+
from pydantic import BaseModel
9+
810
from datastream import Dataset
11+
from datastream.tools import repeat_map_chain
912

1013

1114
class MergeSampler(BaseModel, torch.utils.data.Sampler):
@@ -39,7 +42,9 @@ def __iter__(self):
3942

4043
@staticmethod
4144
def merged_samplers_length(samplers, ns):
42-
return min([len(sampler) / n for sampler, n in zip(samplers, ns)]) * sum(ns)
45+
return int(
46+
min([len(sampler) / n for sampler, n in zip(samplers, ns)]) * sum(ns)
47+
)
4348

4449
@staticmethod
4550
def merge_samplers(samplers, datasets, ns):

0 commit comments

Comments
 (0)