-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtrain.py
More file actions
executable file
·282 lines (220 loc) · 11.2 KB
/
train.py
File metadata and controls
executable file
·282 lines (220 loc) · 11.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
"""
This script trains a convolutional neural network (CNN) to distinguish between images of corgi butts and loaves of bread.
It uses transfer learning with the ResNet-152 model pre-trained on the ImageNet dataset.
The script loads the dataset, preprocesses the images, initializes the model, trains the model, and saves the trained weights to a specified file path.
Usage:
python train.py --dataset-path [path to dataset] --model-path [path to save model] --epochs [number of epochs to train for]
Args:
- dataset_path (str): The path to the directory containing the dataset. The dataset should be organized into three subdirectories: 'train', 'valid', and 'test', each containing subdirectories for the two classes ('butt' and 'bread').
- model_path (str): The path to save the trained model's weights.
- epochs (int): The number of epochs to train the model for.
Returns:
The trained CNN model saved to the specified file path.
Example usage:
python train.py --dataset-path ./data --model-path ./models/butt_bread_model.pt --epochs 10
"""
import argparse
import os
import time
import torch
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
class ButtBreadModel:
    """
    A PyTorch model that predicts whether an image contains a corgi's butt or a loaf of bread.

    Attributes:
        model (torch.nn.Module): The PyTorch model; ``None`` until ``initialize()`` is called
            (or a module is assigned directly).
        device (torch.device): The device (CPU or GPU) on which to run the model.
        criterion (torch.nn.Module): The loss function used to train the model.
        optimizer (torch.optim.Optimizer): The optimizer used to update the model's parameters.

    Methods:
        initialize(): Loads a pre-trained ResNet-152, freezes it, and replaces the fully
            connected layer with a new head that outputs two classes.
        train(image_dataloaders, image_datasets, epochs=1): Trains the model for the given
            number of epochs. Returns the trained model.
        test(image_dataloaders): Evaluates the model on the test set and returns the accuracy.
        save(model_path): Saves the model weights to a file.
        load(model_path): Loads the model weights from a file.
    """

    def __init__(self, device):
        """Initializes the ButtBreadModel with the given device (CPU or GPU)."""
        self.model = None
        self.device = device
        self.criterion = None
        self.optimizer = None

    def _require_model(self):
        """Returns ``self.model``, raising a clear error if it has not been set yet."""
        if self.model is None:
            raise RuntimeError("Model is not initialized; call initialize() first.")
        return self.model

    def initialize(self):
        """
        Initializes the model's architecture by loading a pre-trained ResNet-152 model
        and replacing the fully connected layer with a new one
        that outputs two classes (corgi butt or loaf of bread).

        Also creates the loss function (cross entropy) and an Adam optimizer that
        updates only the parameters of the new head.
        """
        self.model = models.resnet152(weights="IMAGENET1K_V1").to(self.device)
        # Freeze every pre-trained parameter; only the new head created below is trained.
        for parameter in self.model.parameters():
            parameter.requires_grad = False
        # Read the head's input width from the layer being replaced instead of hard-coding it.
        in_features = self.model.fc.in_features
        self.model.fc = torch.nn.Sequential(
            torch.nn.Linear(in_features, 128),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(128, 2),
        ).to(self.device)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(self.model.fc.parameters())

    def train(self, image_dataloaders, image_datasets, epochs=1):
        """
        Trains the model on the given image datasets for the given number of epochs.

        Each epoch runs a "train" phase followed by a "valid" phase and prints the
        per-phase loss and accuracy plus the epoch runtime.

        Args:
            image_dataloaders (dict): A dictionary containing PyTorch DataLoader objects under
                the keys "train" and "valid".
            image_datasets (dict): A dictionary containing PyTorch Dataset objects under the
                keys "train" and "valid"; their lengths are used to average the metrics.
            epochs (int, optional): The number of epochs to train the model. Defaults to 1.

        Returns:
            The trained PyTorch model.

        Raises:
            RuntimeError: If the model has not been initialized.
        """
        model = self._require_model()
        for epoch in range(epochs):
            time_start = time.monotonic()
            print(f"Epoch {epoch + 1}/{epochs}")
            for phase in ["train", "valid"]:
                is_training = phase == "train"
                model.train(is_training)

                # Accumulate as tensors so no per-batch host sync is forced, and so the
                # final .item()/.float() calls also work when a loader yields no batches.
                running_loss = torch.zeros((), device=self.device)
                running_corrects = torch.zeros((), dtype=torch.long, device=self.device)

                for inputs, labels in tqdm(image_dataloaders[phase]):
                    inputs = inputs.to(self.device)
                    labels = labels.to(self.device)

                    # Only build the autograd graph when gradients are actually needed.
                    with torch.set_grad_enabled(is_training):
                        outputs = model(inputs)
                        loss = self.criterion(outputs, labels)

                    if is_training:
                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                    preds = outputs.argmax(dim=1)
                    # The criterion returns a batch mean; weight it by batch size so the
                    # epoch figure is a per-sample average.
                    running_loss += loss.detach() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels)

                num_samples = len(image_datasets[phase])
                epoch_loss = running_loss / num_samples
                epoch_accuracy = running_corrects.float() / num_samples
                print(f"{phase} loss: {epoch_loss.item():.4f}, acc: {epoch_accuracy.item():.4f}")
            print(f"Runtime: ({time.monotonic() - time_start:.2f} seconds)")
        return self.model

    def test(self, image_dataloaders):
        """
        Evaluate the model on the test set and return the accuracy.

        The accuracy is the number of correct predictions divided by the number of
        samples seen, so it is correct for any test batch size.

        Args:
            image_dataloaders (dict): A dictionary containing PyTorch DataLoader objects; only
                the "test" entry is used.

        Returns:
            float: The accuracy of the model on the test set (0.0 if the loader is empty).

        Raises:
            RuntimeError: If the model has not been initialized.
        """
        model = self._require_model()
        model.eval()
        correct_count = 0
        sample_count = 0
        with torch.no_grad():
            for test_images, test_labels in tqdm(image_dataloaders["test"]):
                test_labels = test_labels.to(self.device)
                test_outputs = model(test_images.to(self.device))
                prediction = test_outputs.argmax(dim=1)
                correct_count += torch.sum(prediction == test_labels).item()
                sample_count += test_labels.size(0)
        if sample_count == 0:
            return 0.0
        return correct_count / sample_count

    def save(self, model_path):
        """
        Save the model weights to a file.

        Args:
            model_path (str): The path to the file where the model weights should be saved.

        Raises:
            RuntimeError: If the model has not been initialized.
        """
        return torch.save(self._require_model().state_dict(), model_path)

    def load(self, model_path):
        """
        Load the model weights from a file and switch the model to evaluation mode.

        Args:
            model_path (str): The path to the file where the model weights are stored.

        Returns:
            The loaded model with the saved weights.

        Raises:
            RuntimeError: If the model has not been initialized.
        """
        model = self._require_model()
        state_dict = torch.load(model_path, map_location=self.device)
        # load_state_dict() returns a report of missing/unexpected keys, not the module,
        # so eval() has to be called on the model itself.
        model.load_state_dict(state_dict)
        return model.eval()
def get_dataset(dataset_path: str):
    """
    Build the image datasets and dataloaders for the "train", "valid" and "test" splits.

    Each split is read with torchvision's ImageFolder from the sub-directory of
    ``dataset_path`` that carries the split's name.

    Transformations:
        - train: resize to 224x224, random affine (shear 10, scale 0.8-1.2), random
          horizontal flip, conversion to tensor, normalization.
        - valid / test: resize to 224x224, conversion to tensor, normalization.

    Dataloaders:
        - train: batch size 32, shuffled.
        - valid: batch size 32, not shuffled.
        - test: batch size 1, not shuffled.
        All three use two worker subprocesses.

    Args:
        dataset_path (str): The path to the dataset directory.

    Returns:
        image_datasets (dict): Split name -> ImageFolder dataset.
        image_dataloaders (dict): Split name -> DataLoader over the matching dataset.
    """
    split_names = ("train", "valid", "test")

    def _normalize():
        # Per-channel normalization statistics shared by every split.
        return transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )

    def _plain_pipeline():
        # Deterministic preprocessing used for the evaluation splits.
        return transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                _normalize(),
            ]
        )

    augmented_pipeline = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomAffine(0, shear=10, scale=(0.8, 1.2)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            _normalize(),
        ]
    )
    pipelines = {
        "train": augmented_pipeline,
        "valid": _plain_pipeline(),
        "test": _plain_pipeline(),
    }

    image_datasets = {
        split: datasets.ImageFolder(os.path.join(dataset_path, split), pipelines[split])
        for split in split_names
    }

    loader_settings = {
        "train": {"batch_size": 32, "shuffle": True},
        "valid": {"batch_size": 32, "shuffle": False},
        "test": {"batch_size": 1, "shuffle": False},
    }
    image_dataloaders = {
        split: DataLoader(image_datasets[split], num_workers=2, **loader_settings[split])
        for split in split_names
    }

    return image_datasets, image_dataloaders
def main(opt):
    """
    Train and test the ButtBreadModel on the specified dataset, and save the trained model.

    The directory that will hold the saved weights is created before training starts,
    so a missing output folder cannot cause the save to fail after training has finished.

    Args:
        opt (argparse.Namespace): The command-line arguments; the attributes
            ``dataset_path``, ``model_path`` and ``epochs`` are read.

    Returns:
        None
    """
    dataset_path, model_path, epochs = opt.dataset_path, opt.model_path, opt.epochs

    # A bare file name has no directory component; in that case there is nothing to create.
    output_dir = os.path.dirname(model_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    image_datasets, image_dataloaders = get_dataset(dataset_path)

    butt_bread_obj = ButtBreadModel(device=device)
    butt_bread_obj.initialize()
    butt_bread_obj.train(
        image_dataloaders=image_dataloaders,
        image_datasets=image_datasets,
        epochs=epochs,
    )

    test_accuracy = butt_bread_obj.test(image_dataloaders=image_dataloaders)
    print(f"Test accuracy: {test_accuracy}")

    butt_bread_obj.save(model_path=model_path)
    print(f"Saved model at {model_path}")
def parse_args(argv=None):
    """
    Parse the command-line options of the training script.

    The dataset and model path options accept both the hyphenated spelling
    (``--dataset-path``) and the underscore spelling (``--dataset_path``); either way
    the value is stored under the underscore attribute name.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: Namespace with ``dataset_path``, ``model_path`` and ``epochs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataset-path",
        "--dataset_path",
        dest="dataset_path",
        type=str,
        default="datasets/",
        help="Dataset path",
    )
    parser.add_argument(
        "--model-path",
        "--model_path",
        dest="model_path",
        type=str,
        default="buttbread_resnet152_1.h5",
        help="Output model name",
    )
    parser.add_argument("--epochs", type=int, default=1, help="Number of epochs")
    return parser.parse_args(argv)


if __name__ == "__main__":
    main(parse_args())