This repository has been archived by the owner on Sep 13, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path neural_network_functions.py
76 lines (69 loc) · 3.04 KB
/
neural_network_functions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
__author__ = 'luke'
# Backward-compatible alias for the original (misspelled) module attribute name.
_author_ = __author__
import time
import matplotlib.pyplot as plt
from matplotlib import font_manager
import matplotlib
from xkcd import xkcdify
def trainNetwork(trainer, runs, verbose):
    """
    Trains the network for the given number of runs and returns statistics on the training
    :param trainer: the neural network trainer to train the network on; its train() method
        is called once per run and its return value is recorded
    :param runs: the number of times to train the network (values <= 0 perform no training)
    :param verbose: boolean value to indicate verbose output (progress is printed after each run)
    :return totalTime: the total amount of time, in seconds, it took to train the network
    :return averageTimePerEpoch: the average time, in seconds, of one training run
        (0.0 when no runs were performed)
    :return trainerErrorValues: the raw values returned by trainer.train(), one per run
    :return epochTimes: list of amount of time, in seconds, each training run took
    """
    epochTimes = []
    trainerErrorValues = []
    globalStart = time.time()
    # Iterate 1..runs inclusive so exactly `runs` training passes happen
    # (range(1, runs) would silently drop the last one).
    for i in range(1, runs + 1):
        localStart = time.time()
        trainerErrorValues.append(trainer.train())
        localEnd = time.time()
        epochTimes.append(localEnd - localStart)
        if verbose:
            # Report after the run finishes so the percentage reflects completed work
            # and the final message reads 100%.
            print(str((i / (runs * 1.0)) * 100) + '% complete')
    globalEnd = time.time()
    totalTime = globalEnd - globalStart
    # Avoid ZeroDivisionError when no runs were performed.
    averageTimePerEpoch = sum(epochTimes) / len(epochTimes) if epochTimes else 0.0
    return totalTime, averageTimePerEpoch, trainerErrorValues, epochTimes
def graphOutput(xTrain, yTrain, xTest, yTest, futurePredictions, trainingPredictions, ticker):
    """
    Graphs the data set and the predictions, styles the graph with xkcdify, and saves the graph
    to the working directory as '<ticker>NN.png'.

    The top panel plots the testing data against futurePredictions, with a legend to the right;
    the bottom panel plots the training data against trainingPredictions.

    :param xTrain: training data set of time values
    :param yTrain: training data set of price values
    :param xTest: testing data set of time values
    :param yTest: testing data set of price values
    :param futurePredictions: data set containing the predictions for the testing data
    :param trainingPredictions: data set containing the predictions for the training data
    :param ticker: the stock that the graphs are referencing; used as the output file name prefix
    :return: none
    """
    # Keep using figure 1 (it stays the current figure after this call), but clear it first:
    # otherwise plt.subplot() hands back the axes from a previous call and this ticker's
    # lines get drawn on top of the last ticker's.
    plt.figure(1)
    plt.clf()
    # NOTE(review): the font path is resolved relative to the current working directory,
    # and the rcParams change is global, so it persists after this function returns.
    prop = font_manager.FontProperties(fname='Humor-Sans-1.0.ttf')
    matplotlib.rcParams['font.family'] = prop.get_name()

    # Top panel: testing data vs. predictions for it.
    plt.subplot(2, 1, 1)
    # NOTE(review): tight_layout() runs here before anything is drawn, and the second call
    # below re-lays out both subplots, which likely undoes the 80% width shrink applied to
    # this axes. The legend is still kept in the saved image via bbox_extra_artists.
    # Confirm the intended layout.
    plt.tight_layout()
    l1, = plt.plot(xTest, yTest, 'w-', label='line1')
    l2, = plt.plot(xTest, futurePredictions, 'w--', label='line2')
    plt.xlabel('Time (days)')
    plt.ylabel('Price (USD)')
    ax = plt.gca()
    box = ax.get_position()
    # Narrow the axes to 80% of its width to leave room for the legend on its right.
    ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
    leg = plt.legend([l1, l2], ['Actual Values', 'Predictions'], framealpha=0,
                     loc='center left', bbox_to_anchor=(1, 0.5), borderaxespad=0.)
    for text in leg.get_texts():
        text.set_color('#91A2C4')
    xkcdify(plt)

    # Bottom panel: training data vs. predictions for it.
    plt.subplot(2, 1, 2)
    plt.tight_layout()
    plt.plot(xTrain, yTrain, 'w-')
    plt.plot(xTrain, trainingPredictions, 'w--')
    plt.xlabel('Time (days)')
    plt.ylabel('Price (USD)')
    xkcdify(plt)

    # bbox_extra_artists makes the tight bounding box include the legend, which sits
    # outside the axes.
    plt.savefig(ticker + 'NN.png', transparent=True, bbox_extra_artists=(leg,),
                bbox_inches='tight', dpi=600)