-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
79 lines (60 loc) · 1.85 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import argparse
import torch
from data import load_data
from rolling import Rolling
from utils import save_or_append
# Pick the compute device once at import time: GPU when CUDA is available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("It is using %s device" % device)

# CSV file that main() hands to save_or_append for the run's results.
path_results = 'results.csv'

# Data-generation options passed to Rolling. The key names look like
# torch DataLoader kwargs -- assumption, confirm in rolling.py.
params_data_generation = dict(
    batch_size=4000,
    shuffle=False,
    num_workers=0,
)

# Dataset options passed to Rolling. 'target' is presumably the column of the
# load_data() frame to forecast (TODO confirm). main() adds the per-run keys
# 'n_past', 'time_shift' and 'df'.
params_dataset = dict(
    target='T (degC)',
    device=device,
)

# Rolling-evaluation settings passed to Rolling. Units of the len_* values
# (rows / time steps?) are not visible here -- confirm in rolling.py.
rolling_dict = dict(
    len_buffer=100,
    len_training=30000,
    len_test=5000,
    len_val=5000,
    n_fold=5,
)

# Model / optimiser hyper-parameters passed to Rolling. main() supplies
# 'n_hidden', 'model_name' and 'name' at runtime from the CLI arguments and
# the loaded data's columns.
params_model = dict(
    num_layers=2,
    dropout=0.15,
    learning_rate=1e-3,
    adam_eps=1e-8,
    n_epoch=30,
    device=device,
    verbose=True,
    n_epochs_stop=30,
)
def main(argv=None):
    """Train one forecasting model with the Rolling driver and save its results.

    Parses the command line, loads the data with ``load_data()``, then runs
    ``Rolling.rolling_training()`` followed by ``Rolling.compute_loss()``.
    Whatever ``Rolling.return_df()`` returns is written to ``path_results``
    via ``save_or_append``.

    Args:
        argv: Optional list of command-line arguments, without the program
            name. ``None`` (the default) makes argparse read ``sys.argv[1:]``,
            exactly as before.
    """
    ap = argparse.ArgumentParser("Weather Forecast")
    # --model, --npast and --timeshift are mandatory. Previously they silently
    # defaulted to None, and that None was fed into the dataset/model configs.
    ap.add_argument(
        "--model",
        required=True,
        choices=["DA-RNN", "LSTM", "naive_last_step", "naive_rolling_average", "conv"],
        help="choose which model to train",
    )
    # 64 matches the n_hidden value that was hard-coded (then commented out)
    # in params_model.
    ap.add_argument(
        "--nhidden",
        type=int,
        default=64,
        help="dimension of the hidden layer (default: 64)",
    )
    ap.add_argument("--npast", type=int, required=True, help="number of past step as inputs")
    ap.add_argument("--timeshift", type=int, required=True, help="prediction horizon")
    args = ap.parse_args(argv)

    print("Loading Data")
    df = load_data()
    print("Data Loaded")

    # Build per-run copies instead of mutating the module-level templates,
    # so repeated calls (notebook, parameter sweep) don't leak state.
    dataset_params = {
        **params_dataset,
        'n_past': args.npast,
        'time_shift': args.timeshift,
        'df': df,
    }
    model_params = {
        **params_model,
        'n_hidden': args.nhidden,
        'model_name': args.model,
        'name': df.columns,
    }

    roll = Rolling(rolling_dict, dataset_params, params_data_generation, model_params)
    roll.rolling_training()
    roll.compute_loss()
    results = roll.return_df()
    save_or_append(results, path_results)
if __name__ == "__main__":
    # Only launch training when run as a script, never on import.
    main()