-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain_adv.py
119 lines (93 loc) · 3.81 KB
/
main_adv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
The main file, which exposes the robustness command-line tool, detailed in
:doc:`this walkthrough <../example_usage/cli_usage>`.
"""
from argparse import ArgumentParser
import os
import git
import torch as ch
import cox
import cox.utils
import cox.store
try:
    from robustness.model_utils import make_and_restore_model
    from robustness.datasets import DATASETS
    from robustness.train import train_model, eval_model
    from robustness.tools import constants, helpers
    from robustness import defaults, __version__
    from robustness.defaults import check_and_fill_args
except ImportError as err:
    # The message below assumes the usual cause of a failed import is running
    # this file directly rather than through ``python -m``. Catch only
    # ImportError (not a bare ``except:``) so unrelated failures such as
    # KeyboardInterrupt are not masked, and chain the original error so the
    # real cause stays visible in the traceback.
    raise ValueError("Make sure to run with python -m (see README.md)") from err
# Build the command-line parser by registering each argument group from
# robustness.defaults, in order: config, model loader, training, PGD.
parser = ArgumentParser()
for _arg_group in (defaults.CONFIG_ARGS,
                   defaults.MODEL_LOADER_ARGS,
                   defaults.TRAINING_ARGS,
                   defaults.PGD_ARGS):
    parser = defaults.add_args_to_parser(_arg_group, parser)
del _arg_group  # keep the loop variable out of the module namespace
def main(args, store=None):
    '''Given arguments from `setup_args` and a store from
    `setup_store_with_metadata`, trains a model. Check out the argparse object
    in this file for argument options.

    Args:
        args: filled-in arguments object (see `setup_args`).
        store: optional cox store, passed straight through to `train_model`
            or `eval_model`.

    Returns:
        Whatever `eval_model` returns when `args.eval_only` is set; otherwise
        the model returned by `train_model`.
    '''
    # MAKE DATASET AND LOADERS
    data_path = os.path.expandvars(args.data)  # allow $VARS in the data path
    dataset = DATASETS[args.dataset](data_path)
    train_loader, val_loader = dataset.make_loaders(
        args.workers, args.batch_size, data_aug=bool(args.data_aug))
    train_loader = helpers.DataPrefetcher(train_loader)
    val_loader = helpers.DataPrefetcher(val_loader)
    loaders = (train_loader, val_loader)

    # MAKE MODEL
    model, checkpoint = make_and_restore_model(arch=args.arch,
                                               dataset=dataset,
                                               resume_path=args.resume)
    # If the restored model is a wrapper exposing the underlying model as
    # `.module` (presumably a DataParallel-style container), unwrap it.
    if hasattr(model, 'module'):
        model = model.module

    print(args)

    if args.eval_only:
        return eval_model(args, model, val_loader, store=store)

    # Only pass the restored checkpoint on to train_model when the user asked
    # to resume the optimizer.
    if not args.resume_optimizer:
        checkpoint = None
    model = train_model(args, model, loaders, store=store,
                        checkpoint=checkpoint)
    return model
def setup_args(args):
    '''
    Fill the args object with reasonable defaults from
    :mod:`robustness.defaults`, and also perform a sanity check to make sure no
    args are missing.

    Args:
        args: parsed arguments object (the script entry point passes a
            ``cox.utils.Parameters``).

    Returns:
        The filled-in arguments object.

    Raises:
        ValueError: if ``args.eval_only`` is set but ``args.resume`` is None.
    '''
    # override non-None values with optional config_path
    if args.config_path:
        args = cox.utils.override_json(args, args.config_path)

    ds_class = DATASETS[args.dataset]
    args = check_and_fill_args(args, defaults.CONFIG_ARGS, ds_class)

    if not args.eval_only:
        args = check_and_fill_args(args, defaults.TRAINING_ARGS, ds_class)

    # PGD arguments are only filled in when adversarial training or
    # adversarial evaluation is enabled.
    if args.adv_train or args.adv_eval:
        args = check_and_fill_args(args, defaults.PGD_ARGS, ds_class)

    args = check_and_fill_args(args, defaults.MODEL_LOADER_ARGS, ds_class)

    # Explicit check instead of `assert`: assertions are stripped when Python
    # runs with -O, which would silently skip this validation.
    if args.eval_only and args.resume is None:
        raise ValueError("Must provide a resume path if only evaluating")
    return args
def setup_store_with_metadata(args):
    '''
    Sets up a store for training according to the arguments object. See the
    argparse object above for options.

    Sets ``args.version`` to the current git commit hash, falling back to the
    package ``__version__`` when no commit can be resolved. Then creates a cox
    store from ``args.out_dir`` and ``args.exp_name`` and appends the arguments
    as a row of a ``metadata`` table.

    Returns:
        The created ``cox.store.Store``.
    '''
    # Add git commit to args
    try:
        repo = git.Repo(path=os.path.dirname(os.path.realpath(__file__)),
                        search_parent_directories=True)
        version = repo.head.object.hexsha
    except (git.exc.InvalidGitRepositoryError, ValueError):
        # InvalidGitRepositoryError: this file is not inside a git checkout.
        # ValueError: GitPython raises this when HEAD does not resolve to a
        # commit (e.g. a repository with no commits yet).
        version = __version__
    args.version = version

    # Create the store
    store = cox.store.Store(args.out_dir, args.exp_name)
    # NOTE(review): when `args` is a cox.utils.Parameters (as in the script
    # entry point), `args.__dict__` may be the object's internal attribute
    # dict rather than the flat argument mapping -- confirm against the cox
    # version in use before relying on one column per argument.
    args_dict = args.__dict__
    schema = cox.store.schema_from_dict(args_dict)
    store.add_table('metadata', schema)
    store['metadata'].append_row(args_dict)
    return store
if __name__ == "__main__":
    # Only parse the command line and launch a run when executed as a script
    # (e.g. via ``python -m``); importing this module must not start training.
    args = parser.parse_args()
    # Wrap the argparse namespace in a cox Parameters object before filling
    # in defaults.
    args = cox.utils.Parameters(args.__dict__)
    args = setup_args(args)
    store = setup_store_with_metadata(args)
    final_model = main(args, store=store)