-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathrun.py
55 lines (42 loc) · 1.17 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os, sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from model.model_builder import init_model
from model import *
from init_config import *
from easydict import EasyDict as edict
import sys
from trainer.damnet_trainer import Trainer
import copy
import numpy as np
import random
import argparse
# +
def parse_args():
    """Build the command-line parser and return the parsed namespace.

    The only option is ``--config``, stored on the namespace as
    ``config_file``. It defaults to ``None`` so the caller decides how to
    react when no configuration file is supplied.
    """
    cli = argparse.ArgumentParser(description='')
    cli.add_argument(
        '--config',
        dest='config_file',
        type=str,
        default=None,
        help='configuration filename',
    )
    namespace = cli.parse_args()
    return namespace
def main():
    """Seed RNGs, load the configuration named by ``--config``, and train.

    Steps: parse command-line arguments, seed the Python / NumPy / torch
    RNGs, configure cuDNN for deterministic execution, build the config and
    writer via ``init_config``, build the model via ``init_model``, then run
    ``Trainer.train()``.

    Raises:
        Exception: if no ``--config`` argument was given.
    """
    # Validate arguments first so a missing config fails fast, before any
    # RNG or cuDNN global state is modified.
    args = parse_args()
    if args.config_file is None:
        raise Exception('no configuration file')

    # Seed every RNG the training stack may draw from.
    seed = 1234
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU; torch.cuda.manual_seed only
    # seeds the current device.
    torch.cuda.manual_seed_all(seed)

    # Deterministic cuDNN: benchmark mode autotunes per input shape and may
    # select a different algorithm on each run, which would undermine both
    # deterministic=True and the seeding above. It must therefore be off.
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True

    config, writer = init_config(args.config_file, sys.argv)
    # NOTE(review): this hard-coded value overrides whatever the config file
    # sets; 19 presumably corresponds to the target dataset's label set
    # (e.g. Cityscapes) — confirm before reusing with other datasets.
    config.num_classes = 19
    model = init_model(config)
    trainer = Trainer(model, config, writer)
    trainer.train()
# -
if __name__ == "__main__":
    # Script entry point: start training only when this file is executed
    # directly, never as a side effect of importing it.
    main()