# classification_main.py
# Command-line entry point for training an image classification model.
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset.classificatoin_dataset import ClassificationDataset
from models.shufflenetv2 import ShuffleNetV2
from models.mobilenetv3 import MobileNetV3
from train.classification_train import train_step
def get_args_parser():
    """Build the base argument parser for classification training.

    The parser is created with ``add_help=False`` so that it can be handed to
    another ``ArgumentParser`` through ``parents=[...]`` without producing a
    duplicate ``-h/--help`` option.

    Returns:
        argparse.ArgumentParser: parser defining ``--path``, ``--img_size``,
        ``--device``, ``--epoch``, ``--batch_size``, ``--lr_scheduler``,
        ``--check_point`` and ``--early_stop``.
    """
    parser = argparse.ArgumentParser(
        description='Set classification training', add_help=False)
    # The keyword is `default`; the original `defaults=` made add_argument
    # raise TypeError as soon as the parser was built.
    parser.add_argument('--path', default='/dataset/', type=str,
                        help='Path of data')
    parser.add_argument('--img_size', default=256, type=int,
                        help='Input size of the model')
    # Fall back to CPU when no CUDA device is usable
    # (the original misspelled `is_available`).
    parser.add_argument('--device',
                        default='cuda' if torch.cuda.is_available() else 'cpu',
                        type=str, help='Set device')
    parser.add_argument('--epoch', default=100, type=int)
    parser.add_argument('--batch_size', default=16, type=int)
    # main() reads args.lr_scheduler / args.check_point / args.early_stop and
    # forwards them to train_step, so they must exist on the namespace.
    # NOTE(review): on/off flags are assumed here -- confirm the expected
    # types against train_step's signature.
    parser.add_argument('--lr_scheduler', action='store_true',
                        help='Enable the learning rate scheduler')
    parser.add_argument('--check_point', action='store_true',
                        help='Enable model checkpointing')
    parser.add_argument('--early_stop', action='store_true',
                        help='Enable early stopping')
    return parser
def main(args):
    """Train a ShuffleNetV2 classifier from parsed command-line options.

    Args:
        args: parsed options namespace. Reads ``device``, ``path``,
            ``img_size``, ``batch_size``, ``epoch``, ``lr_scheduler``,
            ``check_point`` and ``early_stop``.

    Returns:
        The value returned by ``train_step`` (previously computed and then
        discarded).
    """
    device = torch.device(args.device)

    train_loader = DataLoader(
        ClassificationDataset(
            path=args.path,
            subset='train',
            img_size=args.img_size,
            transforms_=True,
        ),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=True,
    )

    # Validation keeps every sample in a fixed order: dropping the last
    # partial batch would leave samples out of the evaluation and yields zero
    # batches when the validation set is smaller than one batch.
    # NOTE(review): transforms_=True is also passed for the validation subset;
    # what it enables depends on ClassificationDataset -- confirm.
    valid_loader = DataLoader(
        ClassificationDataset(
            path=args.path,
            subset='valid',
            img_size=args.img_size,
            transforms_=True,
        ),
        batch_size=args.batch_size,
        shuffle=False,
        drop_last=False,
    )

    model = ShuffleNetV2().to(device)

    # Hand the result back to the caller instead of dropping it.
    return train_step(
        model,
        train_data=train_loader,
        validation_data=valid_loader,
        n_epochs=args.epoch,
        learning_rate_scheduler=args.lr_scheduler,
        check_point=args.check_point,
        early_stop=args.early_stop,
    )
if __name__ == '__main__':
    # Wrap the shared option parser (built without its own help option) so
    # this script exposes a single -h/--help, then run training.
    cli = argparse.ArgumentParser(
        'Classification Model Training',
        parents=[get_args_parser()],
    )
    main(cli.parse_args())