-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathxor.cpp
60 lines (49 loc) · 2.13 KB
/
xor.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <utility>
#include <vector>

#include "src/nn.cpp"
/// XOR demo driver.
///
/// Builds a 2-input / 3-hidden / 1-output sigmoid network and tries to load
/// previously saved weights from "xor-model.txt". If loading fails, it trains
/// on the XOR truth table and saves the result to the same file. It then
/// prints the accuracy on the four XOR cases, followed by one colour-coded line
/// per case (green = correct, red = wrong).
int main(void) {
    NeuralNetworkConfig config;
    config.inputSize = 2;
    config.hiddenSize = 3;
    config.outputSize = 1;
    config.learningRate = 0.1;
    NeuralNetwork neuralNetwork(config, SIGMOID);

    // XOR truth table as (input, expected output) pairs. It is defined once
    // and shared by training and evaluation so the labels cannot drift apart.
    std::vector<std::pair<std::vector<double>, std::vector<double>>> xorData = {
        {{0.0, 0.0}, {0.0}},
        {{0.0, 1.0}, {1.0}},
        {{1.0, 0.0}, {1.0}},
        {{1.0, 1.0}, {0.0}}
    };

    // NOTE(review): loadModel's result is used as a boolean here (non-zero
    // means loaded); confirm that contract in src/nn.cpp.
    int modelLoaded = neuralNetwork.loadModel("xor-model.txt");
    if (!modelLoaded) {
        // std::endl flushes deliberately, so the message is visible before the
        // long training run starts.
        std::cout << "Training..." << std::endl;
        // NOTE(review): the same data is passed for both dataset arguments of
        // train() (presumably training and validation sets); verify against
        // src/nn.cpp.
        neuralNetwork.train(xorData, xorData, 10000000);
        std::cout << "Done!" << std::endl;
        // NOTE(review): any status returned by saveModel is ignored here.
        neuralNetwork.saveModel("xor-model.txt");
    }

    // A prediction counts as correct only when the output lies strictly on the
    // expected side of the 0.5 threshold. An output of exactly 0.5 is always
    // counted as wrong.
    auto isCorrect = [](double expected, double actual) {
        return expected > 0.5 ? actual > 0.5 : actual < 0.5;
    };

    // Run every input through the network exactly once. Keep the output for
    // the per-case report and tally the correct predictions.
    std::vector<double> predictions;
    predictions.reserve(xorData.size());
    int correctPredictions = 0;
    for (auto& sample : xorData) {
        auto output = neuralNetwork.feedforward(sample.first);
        predictions.push_back(output[0]);
        if (isCorrect(sample.second[0], output[0])) {
            correctPredictions++;
        }
    }
    std::cout << "Accuracy: " << (correctPredictions / static_cast<double>(xorData.size())) * 100 << "%" << std::endl;

    for (std::size_t i = 0; i < xorData.size(); ++i) {
        const std::vector<double>& input = xorData[i].first;
        const bool correct = isCorrect(xorData[i].second[0], predictions[i]);
        // Reset the terminal attributes ("\033[0m") at the end of every line,
        // so the colour/bold does not leak past the program's output.
        std::cout << "\033[0mInput: \033[1m" << input[0] << " " << input[1] << "\033[0m Output: "
                  << (correct ? "\033[1;32m" : "\033[1;31m") << predictions[i] << "\033[0m" << '\n';
    }
    return EXIT_SUCCESS;
}