-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
112 lines (97 loc) · 3.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gin
import argparse
from src import (
train,
cache,
analyze,
instantiate,
evaluate,
helpers,
mix_with_scaper,
make_scaper_datasets
)
import nussl
import subprocess
from src.helpers import build_logger
from src.debug import DebugDataset
import os
import copy
def edit(experiment_config):
    """Open an experiment config file in vim for interactive editing.

    Args:
        experiment_config: Path of the config file to open.

    Raises:
        ValueError: If no path was supplied (``experiment_config`` is None).
    """
    if experiment_config is None:
        # Without this guard the editor would be opened on a file literally
        # named "None" when the caller forgot to pass a config path.
        raise ValueError("No experiment config given to edit!")
    # Pass the path as its own argv entry rather than interpolating it into
    # a shell command line: paths containing spaces or shell metacharacters
    # are then handed to vim verbatim instead of being re-parsed by a shell.
    subprocess.run(['vim', experiment_config])
def parse_arguments():
    """Parse this script's command-line arguments.

    Returns:
        argparse.Namespace with the positional ``func`` plus the optional
        string attributes ``experiment_config``, ``data_config``,
        ``environment_config`` and ``output_folder`` (each None when the
        corresponding flag is not given).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('func', type=str)

    # (short flag, long flag) pairs; all are optional string options that
    # default to None.
    optional_flags = (
        ('-exp', '--experiment_config'),
        ('-dat', '--data_config'),
        ('-env', '--environment_config'),
        ('-out', '--output_folder'),
    )
    for short_flag, long_flag in optional_flags:
        cli.add_argument(short_flag, long_flag, default=None, type=str)

    return cli.parse_args()
if __name__ == "__main__":
    # Script entry point: pick the requested command, load the gin config
    # files, then dispatch to the matching function.
    args = parse_arguments()
    # Commands handled inline below instead of being looked up by name.
    special_commands = ['all', 'debug']

    if args.func not in special_commands:
        func = globals().get(args.func)
        # Module globals also hold imported modules (os, gin, copy, ...), so
        # membership alone is not enough: require a callable so a bad command
        # name fails here with a clear message rather than with a TypeError
        # at call time.
        if not callable(func):
            raise ValueError(f"No matching function named {args.func}!")

    # Config files are parsed in this order: environment, data, experiment.
    # Any that were not supplied on the command line are skipped.
    _configs = [
        args.environment_config,
        args.data_config,
        args.experiment_config
    ]
    for _config in _configs:
        if _config is not None:
            gin.parse_config_file(_config)

    build_logger()

    if args.func == 'debug':
        # Debug mode: overfit to a single batch for a given length, save the
        # model, then evaluate it on that same sample. This is done by
        # binding parameters in the gin config and then running the regular
        # 'all' pipeline.
        debug_output_folder = os.path.join(
            helpers.output_folder(), 'debug')
        gin.bind_parameter(
            'output_folder._output_folder',
            debug_output_folder
        )
        with gin.config_scope('train'):
            train_dataset = helpers.build_dataset()

        # The test copy is taken before the training dataset is wrapped, and
        # has its transform and cache flag reset.
        # NOTE(review): presumably so evaluation reads raw, uncached items;
        # confirm against the dataset class.
        test_dataset = copy.deepcopy(train_dataset)
        test_dataset.transform = None
        test_dataset.cache_populated = False

        train_dataset = DebugDataset(train_dataset)
        val_dataset = copy.deepcopy(train_dataset)
        val_dataset.dataset_length = 1

        test_dataset = DebugDataset(test_dataset)
        test_dataset.dataset_length = 1
        # Share the training wrapper's idx with the test wrapper.
        # NOTE(review): looks like this selects the same sample for
        # evaluation; verify in DebugDataset.
        test_dataset.idx = train_dataset.idx

        gin.bind_parameter('train/build_dataset.dataset_class', train_dataset)
        gin.bind_parameter('val/build_dataset.dataset_class', val_dataset)
        gin.bind_parameter('test/build_dataset.dataset_class', test_dataset)
        gin.bind_parameter('train.num_epochs', 1)
        # Fall through to the full pipeline below.
        args.func = 'all'

    if args.func == 'all':
        train()
        evaluate()
        analyze()
    elif args.func == 'instantiate':
        func(args.output_folder)
    elif args.func == 'edit':
        func(args.experiment_config)
    elif args.func == 'cache':
        def _setup_for_cache(scope):
            """Build the dataset for ``scope`` with its cache flag cleared
            and bind it as that scope's ``build_dataset.dataset_class``."""
            with gin.config_scope(scope):
                _dataset = helpers.build_dataset()
            _dataset.cache_populated = False
            gin.bind_parameter(
                f'{scope}/build_dataset.dataset_class',
                _dataset
            )

        for scope in ['train', 'val', 'test']:
            _setup_for_cache(scope)
        cache()
    else:
        func()