-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathesn_experiment.py
74 lines (59 loc) · 2.38 KB
/
esn_experiment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import argparse
import json
import logging
import numpy as np
import os
from scoop import futures
import esnet
# Module-level logger: timestamped records at INFO level and above are
# written through a StreamHandler (stderr by default).
formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.addHandler(handler)
###############################################################################
# The setup below deliberately lives at module (global) scope: all workers
# need access to these variables. Passing them as arguments to the
# evaluation function caused pickling problems -- even a functools.partial
# wrapper could not be pickled, although that should be supported.
###############################################################################

############################################################################
# Parse input arguments
############################################################################
parser = argparse.ArgumentParser()
parser.add_argument("data", help="path to data file", type=str)
parser.add_argument("esnconfig", help="path to ESN config file", type=str)
parser.add_argument("nexp", help="number of runs", type=int)
args = parser.parse_args()

############################################################################
# Read config file
############################################################################
# '.json' is appended here, so `esnconfig` must be given without the
# extension. The context manager closes the file deterministically instead
# of leaving the handle open until garbage collection.
with open(args.esnconfig + '.json', 'r') as config_file:
    config = json.load(config_file)

############################################################################
# Load data
############################################################################
# If `data` is a directory, load the already-split sets from it; otherwise
# load a single file and split it. Both loaders yield six values of which
# the middle pair is discarded here (presumably a validation split --
# NOTE(review): confirm against esnet).
if os.path.isdir(args.data):
    Xtr, Ytr, _, _, Xte, Yte = esnet.load_from_dir(args.data)
else:
    X, Y = esnet.load_from_text(args.data)
    # Construct training/test sets
    Xtr, Ytr, _, _, Xte, Yte = esnet.generate_datasets(X, Y)
def single_run(dummy):
    """
    Run one ESN experiment; this is the function executed by the workers.

    The network is built and evaluated from the module-level ``config``
    using the module-level training/test sets (kept global to avoid the
    pickling problems described above).

    Parameters
    ----------
    dummy :
        Ignored. Present only so the function can be mapped over
        ``range(nexp)``.

    Returns
    -------
    The second element of the pair returned by ``esnet.run_from_config``
    (the error of this run).
    """
    result_pair = esnet.run_from_config(Xtr, Ytr, Xte, Yte, config)
    _, run_error = result_pair
    return run_error
def main():
    """
    Run ``args.nexp`` experiments and print a summary of their errors.

    Prints, each under its own heading, the float array of per-run errors,
    their mean, and their standard deviation (``np.std`` default, ddof=0).
    """
    # NOTE(review): this uses the builtin ``map`` (via the comprehension), so
    # the runs execute serially in this process even though
    # ``scoop.futures`` is imported. Presumably ``futures.map`` (launched with
    # ``python -m scoop``) was intended for parallel execution -- confirm.
    per_run = [single_run(run_idx) for run_idx in range(args.nexp)]
    errors = np.array(per_run, dtype=float)

    # Statistics are computed lazily so each is evaluated right after its
    # heading is printed, matching the original output/evaluation order.
    report = (
        ("Errors:", lambda: errors),
        ("Mean:", lambda: np.mean(errors)),
        ("Std:", lambda: np.std(errors)),
    )
    for heading, compute in report:
        print(heading)
        print(compute())


if __name__ == "__main__":
    main()