# bnn.py
import numpy as np
import tensorflow as tf
class BNN:
    """Fully connected network whose parameters are passed in at call time.

    The object stores only the architecture (number of layers and the
    activation). The two callables it builds take the flat parameter list
    ``W + b`` as an explicit argument:

    * ``bnn_fn(x, variables)`` evaluates one parameter realization (used for
      MCMC);
    * ``bnn_infer_fn(x, variables)`` evaluates a batch of parameter
      realizations at once (used for inference).

    Args:
    -----
        layers: list of int, layer widths ``[input_dim, ..., output_dim]``;
            must contain at least two entries.
        activation: element-wise callable applied after every hidden layer
            (not after the output layer). Defaults to ``tf.tanh``.

    Raises:
    -------
        ValueError: if ``layers`` has fewer than two entries.
    """

    def __init__(self, layers, activation=tf.tanh):
        if len(layers) < 2:
            raise ValueError(
                "layers must contain at least an input and an output width, "
                "got {!r}".format(layers)
            )
        # Number of affine (weight, bias) layers.
        self.L = len(layers) - 1
        # Set before the builders run so no closure can ever observe it missing.
        self.activation = activation
        self.variables = self.init_network(layers)
        self.bnn_fn = self.build_bnn()
        self.bnn_infer_fn = self.build_infer()

    def init_network(self, layers, init=tf.zeros):
        """Create the initial flat parameter list ``W + b``.

        Args:
        -----
            layers: list of int, layer widths (see the class docstring).
            init: callable accepting ``shape`` and ``dtype`` keyword
                arguments, used to create the weight matrices. Defaults to
                ``tf.zeros``; an initializer object such as
                ``tf.keras.initializers.glorot_normal()`` may be passed instead.

        Returns:
        --------
            list of float32 tensors: the ``self.L`` weight matrices (shape
            ``[layers[i], layers[i + 1]]``) followed by the ``self.L`` bias
            rows (shape ``[1, layers[i + 1]]``, always zeros).
        """
        W = [
            init(shape=[layers[i], layers[i + 1]], dtype=tf.float32)
            for i in range(self.L)
        ]
        b = [
            tf.zeros(shape=[1, layers[i + 1]], dtype=tf.float32)
            for i in range(self.L)
        ]
        return W + b

    @staticmethod
    def _split_variables(variables):
        """Split the flat ``W + b`` list into ``(weights, biases)`` halves."""
        half = len(variables) // 2
        return variables[:half], variables[half:]

    def build_bnn(self):
        """Return a function evaluating one realization of the network."""

        def _fn(x, variables):
            """
            BNN function, for one realization of the neural network, used for MCMC
            Args:
            -----
                x: input,
                    tensor, with shape [None, input_dim]
                variables: weights and bias, ordered as ``W + b``,
                    list of tensors, each one of which has dimension [:, :]
            Returns:
            --------
                y: output,
                    tensor, with shape [None, output_dim]
            """
            W, b = self._split_variables(variables)
            y = x
            # Hidden layers: affine map followed by the activation.
            for i in range(self.L - 1):
                y = self.activation(tf.matmul(y, W[i]) + b[i])
            # Output layer is purely affine.
            return tf.matmul(y, W[-1]) + b[-1]

        return _fn

    def build_infer(self):
        """Return a function evaluating a batch of network realizations."""

        def _fn(x, variables):
            """
            BNN function, for batch of realizations of the neural network, used for inference
            Args:
            -----
                x: input,
                    tensor, with shape [None, input_dim]
                variables: weights and bias, ordered as ``W + b``,
                    list of tensors, each one of which has dimension [batch_size, :, :]
                    (biases presumably [batch_size, 1, out] so they broadcast
                    over the input rows -- TODO confirm against callers)
            Returns:
            --------
                y: output,
                    tensor, with shape [batch_size, None, output_dim]
            """
            W, b = self._split_variables(variables)
            # Use the dynamic shape. The static ``W[0].shape[0]`` is None when the
            # number of realizations is unknown at trace time (e.g. inside
            # tf.function), and None cannot be used as a tile multiple.
            batch_size = tf.shape(W[0])[0]
            # Replicate the inputs once per parameter realization.
            y = tf.tile(x[None, :, :], tf.stack([batch_size, 1, 1]))
            for i in range(self.L - 1):
                y = self.activation(tf.einsum("Nij,Njk->Nik", y, W[i]) + b[i])
            return tf.einsum("Nij,Njk->Nik", y, W[-1]) + b[-1]

        return _fn