-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
40 lines (28 loc) · 1.15 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import sys
import pytorch_lightning as pl
from glom import GlomReconstruction, GlomClassification
from datasets import get_dataloaders
import utils
def main(args):
    """Build the dataloaders, GLOM model, logger and trainer, then train.

    Args:
        args: Parsed command-line arguments (as produced by
            ``utils.parse_args``). Attributes read directly here are
            ``task``, ``patch_size``, ``epochs`` and ``gpus``; the whole
            object is also forwarded to ``get_dataloaders``, the model
            constructor and ``utils.get_logger``.

    Raises:
        ValueError: If ``args.task`` is neither ``'classification'`` nor
            ``'reconstruction'``.
    """
    # Validate the task before doing any work so an unsupported value fails
    # immediately instead of after (potentially slow) dataset loading.
    if args.task not in ('classification', 'reconstruction'):
        raise ValueError(f'Error. Task "{args.task}" is not supported.')

    # Dataset and dataloader. dl_dict is indexed with 'train' and 'val' below;
    # ds_info is indexed with 'img_size', 'in_chans' and, for classification,
    # 'num_classes'.
    dl_dict, ds_info = get_dataloaders(args)

    # Model: pick the GLOM head that matches the requested task.
    if args.task == 'classification':
        model = GlomClassification(args, img_size=ds_info['img_size'], patch_size=args.patch_size,
                                   num_classes=ds_info['num_classes'], in_chans=ds_info['in_chans'])
    else:  # 'reconstruction' (guaranteed by the validation above)
        model = GlomReconstruction(args, img_size=ds_info['img_size'], patch_size=args.patch_size,
                                   in_chans=ds_info['in_chans'])

    # Logger
    logger = utils.get_logger(args)

    # Create trainer.
    # NOTE(review): the `gpus=` argument is deprecated in newer PyTorch
    # Lightning releases (replaced by `accelerator=`/`devices=`) and removed
    # in 2.x — confirm the pinned Lightning version before upgrading.
    trainer = pl.Trainer(max_epochs=args.epochs, gpus=args.gpus, logger=logger)

    # Fit
    trainer.fit(model, train_dataloaders=dl_dict['train'], val_dataloaders=dl_dict['val'])
if __name__ == '__main__':
    # Script entry point: parse the command line (excluding the program
    # name in argv[0]) and hand the resulting arguments straight to main().
    main(utils.parse_args(sys.argv[1:]))