mnist_datamodule.py
import lightning as L
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    """LightningDataModule that downloads MNIST and serves train/val/test loaders."""
    def __init__(
        self,
        data_dir: str = "data/",
        batch_size: int = 64,
        num_workers: int = 0,
        pin_memory: bool = False,
    ):
        super().__init__()
        # Store the constructor arguments so they are available via self.hparams.
        self.save_hyperparameters()
        self.transforms = transforms.ToTensor()
        self.data = {}
    def prepare_data(self) -> None:
        # Download only. Lightning calls this on a single process per node,
        # so no state should be assigned here.
        MNIST(self.hparams.data_dir, train=True, download=True)
        MNIST(self.hparams.data_dir, train=False, download=True)
    def setup(self, stage: str | None = None) -> None:
        if not self.data:
            dataset = MNIST(
                self.hparams.data_dir, train=True, transform=self.transforms
            )
            # Hold out 5,000 of the 60,000 training images for validation.
            self.data["train"], self.data["val"] = random_split(dataset, [55000, 5000])
            self.data["test"] = MNIST(
                self.hparams.data_dir, train=False, transform=self.transforms
            )
    def train_dataloader(self) -> TRAIN_DATALOADERS:
        return DataLoader(
            dataset=self.data["train"],
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            # Keep worker processes alive between epochs; only valid with workers.
            persistent_workers=self.hparams.num_workers > 0,
            shuffle=True,
        )

    def val_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(
            dataset=self.data["val"],
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.num_workers > 0,
            shuffle=False,
        )

    def test_dataloader(self) -> EVAL_DATALOADERS:
        return DataLoader(
            dataset=self.data["test"],
            batch_size=self.hparams.batch_size,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            persistent_workers=self.hparams.num_workers > 0,
            shuffle=False,
        )
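

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): instantiate the
    # datamodule, run the download/split hooks manually, and pull one batch
    # as a smoke test. In a real project these hooks would be invoked by a
    # Lightning Trainer via trainer.fit(model, datamodule=datamodule).
    datamodule = MNISTDataModule(batch_size=128, num_workers=2)
    datamodule.prepare_data()
    datamodule.setup()
    images, labels = next(iter(datamodule.train_dataloader()))
    # MNIST with ToTensor yields 1x28x28 float tensors in [0, 1].
    print(images.shape, labels.shape)  # torch.Size([128, 1, 28, 28]) torch.Size([128])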