Skip to content

Commit 7a05e46

Browse files
committed
improve: stop using deprecated functionality
1 parent b6b326a commit 7a05e46

10 files changed

+432
-394
lines changed

datastream/dataset.py

+22-19
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,28 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
2+
3+
import inspect
4+
import random
5+
import string
6+
import textwrap
7+
from functools import lru_cache
8+
from pathlib import Path
39
from typing import (
4-
Tuple,
510
Callable,
6-
Union,
7-
List,
8-
TypeVar,
9-
Generic,
1011
Dict,
11-
Optional,
12+
Generic,
1213
Iterable,
14+
List,
15+
Optional,
16+
Tuple,
17+
TypeVar,
18+
Union,
1319
)
14-
from pathlib import Path
15-
from functools import lru_cache
16-
import string
17-
import random
18-
import textwrap
19-
import inspect
20+
2021
import numpy as np
2122
import pandas as pd
22-
from datastream import tools
23+
from pydantic import BaseModel, ConfigDict
2324

25+
from datastream import tools
2426

2527
T = TypeVar("T")
2628
R = TypeVar("R")
@@ -53,9 +55,10 @@ class Dataset(BaseModel, Generic[T]):
5355
length: int
5456
get_item: Callable[[pd.DataFrame, int], T]
5557

56-
class Config:
57-
arbitrary_types_allowed = True
58-
allow_mutation = False
58+
model_config = ConfigDict(
59+
arbitrary_types_allowed=True,
60+
frozen=True,
61+
)
5962

6063
@staticmethod
6164
def from_subscriptable(subscriptable) -> Dataset:
@@ -96,7 +99,7 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
9699

97100
@staticmethod
98101
def from_paths(paths: Iterable[str, Path], pattern: str) -> Dataset[pd.Series]:
99-
"""
102+
r"""
100103
Create ``Dataset`` from paths using regex pattern that extracts information
101104
from the path itself.
102105
:func:`Dataset.__getitem__` will return a row from the dataframe and
@@ -154,7 +157,7 @@ def __eq__(self: Dataset[T], other: Dataset[R]) -> bool:
154157
return True
155158

156159
def replace(self, **kwargs):
157-
new_dict = self.dict()
160+
new_dict = self.model_dump()
158161
new_dict.update(**kwargs)
159162
return type(self)(**new_dict)
160163

datastream/datastream.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66
import torch
7-
from pydantic import BaseModel, PositiveInt
7+
from pydantic import BaseModel, ConfigDict, PositiveInt
88

99
from datastream import Dataset
1010
from datastream.samplers import (
@@ -40,9 +40,10 @@ class Datastream(BaseModel, Generic[T]):
4040
dataset: Dataset
4141
sampler: Optional[torch.utils.data.Sampler]
4242

43-
class Config:
44-
arbitrary_types_allowed = True
45-
allow_mutation = False
43+
model_config = ConfigDict(
44+
arbitrary_types_allowed=True,
45+
frozen=True,
46+
)
4647

4748
def __init__(self, dataset: Dataset[T], sampler: torch.utils.data.Sampler = None):
4849
if len(dataset) == 0:

datastream/samplers/merge_sampler.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
from typing import Callable, Iterable, Tuple
66

77
import torch
8-
from pydantic import BaseModel
8+
import torch.utils.data
9+
from pydantic import BaseModel, ConfigDict
910

1011
from datastream import Dataset
1112
from datastream.tools import repeat_map_chain
@@ -19,9 +20,10 @@ class MergeSampler(BaseModel, torch.utils.data.Sampler):
1920
from_mapping: Callable[[int], Tuple[int, int]]
2021
merged_samplers: Iterable
2122

22-
class Config:
23-
arbitrary_types_allowed = True
24-
allow_mutation = False
23+
model_config = ConfigDict(
24+
arbitrary_types_allowed=True,
25+
frozen=True,
26+
)
2527

2628
def __init__(self, samplers, datasets, ns):
2729
BaseModel.__init__(

datastream/samplers/multi_sampler.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
3-
from typing import Tuple, Iterable
2+
43
from itertools import chain, islice
4+
from typing import Iterable, Tuple
5+
56
import torch
6-
from datastream.tools import repeat_map_chain
7-
from datastream.samplers import StandardSampler
7+
from pydantic import BaseModel, ConfigDict
8+
89
from datastream import Dataset
10+
from datastream.samplers import StandardSampler
11+
from datastream.tools import repeat_map_chain
912

1013

1114
# TODO: write custom sampler that avoid replacement between samplers
@@ -15,9 +18,10 @@ class MultiSampler(BaseModel, torch.utils.data.Sampler):
1518
length: int
1619
merged_samplers: Iterable
1720

18-
class Config:
19-
arbitrary_types_allowed = True
20-
allow_mutation = False
21+
model_config = ConfigDict(
22+
arbitrary_types_allowed=True,
23+
frozen=True,
24+
)
2125

2226
def __init__(self, samplers, dataset):
2327
BaseModel.__init__(

datastream/samplers/repeat_sampler.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
2+
33
from typing import Iterable
4+
45
import torch
6+
import torch.utils.data
7+
from pydantic import BaseModel, ConfigDict
58

69

710
class RepeatSampler(BaseModel, torch.utils.data.Sampler):
@@ -10,8 +13,7 @@ class RepeatSampler(BaseModel, torch.utils.data.Sampler):
1013
epoch_bound: bool = False
1114
queue: Iterable
1215

13-
class Config:
14-
arbitrary_types_allowed = True
16+
model_config = ConfigDict(arbitrary_types_allowed=True)
1517

1618
def __init__(self, sampler, length, epoch_bound=False):
1719
"""

datastream/samplers/sequential_sampler.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
2+
33
import torch
4+
import torch.utils.data
5+
from pydantic import BaseModel, ConfigDict
46

57

68
class SequentialSampler(BaseModel, torch.utils.data.Sampler):
79
sampler: torch.utils.data.SequentialSampler
810

9-
class Config:
10-
arbitrary_types_allowed = True
11-
allow_mutation = False
11+
model_config = ConfigDict(
12+
arbitrary_types_allowed=True,
13+
frozen=True,
14+
)
1215

1316
def __init__(self, length):
1417
BaseModel.__init__(

datastream/samplers/standard_sampler.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
2+
33
import torch
4+
import torch.utils.data
5+
from pydantic import BaseModel, ConfigDict
46

57

68
class StandardSampler(BaseModel, torch.utils.data.Sampler):
79
proportion: float
810
replacement: bool
911
sampler: torch.utils.data.WeightedRandomSampler
1012

11-
class Config:
12-
arbitrary_types_allowed = True
13-
allow_mutation = False
13+
model_config = ConfigDict(
14+
arbitrary_types_allowed=True,
15+
frozen=True,
16+
)
1417

1518
def __init__(self, length, proportion=1.0, replacement=False):
1619
BaseModel.__init__(

datastream/samplers/zip_sampler.py

+11-6
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
from __future__ import annotations
2-
from pydantic import BaseModel
3-
from typing import Tuple, Callable, Iterable
2+
43
from functools import partial
54
from itertools import islice
5+
from typing import Callable, Iterable, Tuple
6+
67
import torch
7-
from datastream.tools import starcompose, repeat_map_chain
8+
import torch.utils.data
9+
from pydantic import BaseModel, ConfigDict
10+
811
from datastream import Dataset
12+
from datastream.tools import repeat_map_chain, starcompose
913

1014

1115
class ZipSampler(BaseModel, torch.utils.data.Sampler):
@@ -15,9 +19,10 @@ class ZipSampler(BaseModel, torch.utils.data.Sampler):
1519
from_mapping: Callable[[int], Tuple[int, ...]]
1620
zipped_samplers: Iterable
1721

18-
class Config:
19-
arbitrary_types_allowed = True
20-
allow_mutation = False
22+
model_config = ConfigDict(
23+
arbitrary_types_allowed=True,
24+
frozen=True,
25+
)
2126

2227
def __init__(self, samplers, datasets):
2328
BaseModel.__init__(

datastream/tools/verify_split.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import json
22
from pathlib import Path
3-
from pydantic import validate_arguments
43

4+
from pydantic import validate_call
55

6-
@validate_arguments
6+
7+
@validate_call
78
def verify_split(old_path: Path, new_path: Path):
89
"""
910
Verify that no keys from an old split are present in a different new split.

0 commit comments

Comments
 (0)