-
Notifications
You must be signed in to change notification settings - Fork 45
Expand file tree
/
Copy pathexample_conv.py
More file actions
29 lines (24 loc) · 743 Bytes
/
example_conv.py
File metadata and controls
29 lines (24 loc) · 743 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
from network import Network
from conv_layer import ConvLayer
from activation_layer import ActivationLayer
from activations import tanh, tanh_prime
from losses import mse, mse_prime
# Shape of one training sample as (height, width, channels).
INPUT_SHAPE = (10, 10, 1)
# Kernel (height, width) shared by every ConvLayer in this example.
KERNEL_SHAPE = (3, 3)
# Output depth of each ConvLayer, in order; the last entry fixes the target depth.
LAYER_DEPTHS = (1, 1, 2)


def conv_output_shape(input_shape, kernel_shape, layer_depth):
    """Return the output shape of one ConvLayer.

    Uses the "valid" (no padding, stride 1) rule implied by the shape chain
    this example originally hard-coded with 3x3 kernels:
    (10, 10, 1) -> (8, 8, 1) -> (6, 6, 1) -> (4, 4, 2).
    NOTE(review): assumes ConvLayer really computes a valid, stride-1
    convolution -- confirm against conv_layer.py.

    Args:
        input_shape: ``(height, width, channels)`` of the layer input.
        kernel_shape: ``(kernel_height, kernel_width)``.
        layer_depth: number of output channels of the layer.

    Returns:
        ``(height - kernel_height + 1, width - kernel_width + 1, layer_depth)``.

    Raises:
        ValueError: if the kernel does not fit inside the input.
    """
    height, width = input_shape[0], input_shape[1]
    k_height, k_width = kernel_shape
    out_height = height - k_height + 1
    out_width = width - k_width + 1
    if out_height < 1 or out_width < 1:
        raise ValueError(
            f"kernel {tuple(kernel_shape)} does not fit input {tuple(input_shape)}"
        )
    return (out_height, out_width, layer_depth)


def conv_shape_chain(input_shape, kernel_shape, layer_depths):
    """Return the input shape followed by the output shape of each ConvLayer.

    Entry ``i`` is the input shape of layer ``i``; the last entry is the
    shape of the final layer's output (and therefore of the training target).

    Args:
        input_shape: ``(height, width, channels)`` of the network input.
        kernel_shape: ``(kernel_height, kernel_width)`` used by every layer.
        layer_depths: iterable of output depths, one per ConvLayer.

    Returns:
        A list of ``len(layer_depths) + 1`` shape tuples.

    Raises:
        ValueError: if some layer's kernel does not fit its input.
    """
    shapes = [tuple(input_shape)]
    for depth in layer_depths:
        shapes.append(conv_output_shape(shapes[-1], kernel_shape, depth))
    return shapes


if __name__ == "__main__":
    # Guarded so that importing this module does not train a network.
    shapes = conv_shape_chain(INPUT_SHAPE, KERNEL_SHAPE, LAYER_DEPTHS)

    # training data: one random sample and a random target whose shape
    # matches the network's final output. np.random.rand draws from [0, 1),
    # which lies inside tanh's (-1, 1) output range.
    x_train = [np.random.rand(*INPUT_SHAPE)]
    y_train = [np.random.rand(*shapes[-1])]

    # network: each ConvLayer is followed by a tanh activation, and every
    # layer's input shape is the previous layer's output shape.
    net = Network()
    for layer_input_shape, depth in zip(shapes, LAYER_DEPTHS):
        net.add(ConvLayer(layer_input_shape, KERNEL_SHAPE, depth))
        net.add(ActivationLayer(tanh, tanh_prime))

    # train: register MSE and its derivative (presumably used by fit), then fit.
    net.use(mse, mse_prime)
    net.fit(x_train, y_train, epochs=1000, learning_rate=0.3)

    # test: sanity check on the training sample itself (no held-out data).
    out = net.predict(x_train)
    print("predicted = ", out)
    print("expected = ", y_train)