-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcpu_vs_numba_vs_jax_ising.py
205 lines (164 loc) · 6.5 KB
/
cpu_vs_numba_vs_jax_ising.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Compare Gibbs sampling with CPU vs numba vs JAX for an Ising model on a 2D lattice
import numba as nb
import numpy as np
import time
from functools import partial
from jax.scipy.special import logsumexp
from jax import jit, random
from jax.lax import dynamic_slice, dynamic_update_slice, scan
from jax.nn import sigmoid
from jax import numpy as jnp
#####################################
############## Ising model ##########
#####################################
def ising_matrix(grid_side, rho=1):
    """Build the coupling matrix of an Ising model on a periodic 2D lattice.

    Site ``(row, col)`` of the ``grid_side x grid_side`` torus is flattened to
    index ``row * grid_side + col``. ``W[a, b]`` is set to ``rho`` whenever
    ``b`` is the up, down, left or right neighbour of ``a`` (with wrap-around);
    every other entry is 0. For ``grid_side <= 2`` some neighbours coincide,
    and for ``grid_side == 1`` the single site is its own neighbour, so the
    diagonal holds ``rho``.

    Returns:
        float64 array of shape ``(grid_side**2, grid_side**2)``.
    """
    n = grid_side
    W = np.zeros((n ** 2, n ** 2))
    rows, cols = np.meshgrid(np.arange(n), np.arange(n), indexing="ij")
    sites = (rows * n + cols).ravel()
    # (row offset, column offset) of the up, down, left and right neighbours.
    for d_row, d_col in ((-1, 0), (1, 0), (0, -1), (0, 1)):
        neighbours = (((rows + d_row) % n) * n + (cols + d_col) % n).ravel()
        W[sites, neighbours] = rho
    return W
def logZ_estimate(S, weights=None):
    """Ogata-Tanemura estimate of the log-partition function of the Ising model.

    The Gibbs samplers in this file flip spin j with probability
    sigmoid(-2 * x_j * g_j), where g = x W^T. That is the conditional of
    p(x) ∝ exp(x^T W x / 2) over x in {-1, +1}^d. For that density,
    E_p[exp(-x^T W x / 2)] = 2^d / Z, so

        log Z ≈ d log 2 + log n - logsumexp_i(-s_i^T W s_i / 2).

    Reference:
    http://www2.stat.duke.edu/~scs/Courses/Stat376/Papers/NormConstants/PotamianosGoutsiasIEEE1997.pdf

    Args:
        S: (n_samples, d) array of +/-1 spins drawn from the model.
        weights: (d, d) coupling matrix. If omitted, the module-level ``W``
            (defined by the script's ``__main__`` section) is used, as before.

    Returns:
        The estimate as a Python float.
    """
    if weights is None:
        # Backward compatibility: the original read the global ``W``.
        weights = W
    n_samples, d = S.shape
    # Half the quadratic form, i.e. the unnormalized log-density of each sample.
    half_quad = 0.5 * ((S @ weights) * S).sum(1)
    logZ = d * np.log(2) + np.log(n_samples) - logsumexp(-half_quad)
    return float(logZ)
#####################################
###### Gibbs sampling on CPU ########
#####################################
def sigmoid_cpu(x):
    """Return the logistic sigmoid 1 / (1 + exp(-x)) of a scalar ``x``.

    Numerically stable: ``exp`` is only evaluated at a non-positive argument,
    so it cannot overflow. The previous branch order computed
    exp(x) / (1 + exp(x)) for large positive x, which gives inf/inf = nan.
    """
    if x >= 0:
        return 1 / (1 + np.exp(-x))
    z = np.exp(x)
    return z / (1 + z)
def gibbs_ising_cpu(W, n_samples, n_steps=1000):
    """Run ``n_samples`` independent Gibbs chains for an Ising model in NumPy.

    Each of the ``n_steps`` sweeps visits every spin once, in a fresh random
    order. Spin j of all chains is flipped with probability
    sigmoid(-2 * x_j * g_j), where g = S @ W.T holds the local fields.

    Args:
        W: symmetric (d, d) coupling matrix with zero diagonal (asserted).
        n_samples: number of parallel chains.
        n_steps: number of sweeps.

    Returns:
        (n_samples, d) float64 array of +/-1 spins after the last sweep.
    """
    d = W.shape[0]
    # Independent uniform +/-1 initial spins.
    S = 2 * (np.random.rand(n_samples, d) < 0.5).astype(np.float64) - 1
    assert W.shape == (d, d)
    assert (np.diag(W) == 0).all()
    assert (W == W.T).all()
    # Local fields, g[i, j] = sum_k W[j, k] * S[i, k]; kept in sync with S.
    g = S.dot(W.T)
    # Build the element-wise sigmoid once rather than on every spin update.
    flip_probability = np.vectorize(sigmoid_cpu)
    for _ in range(n_steps):
        for j in np.random.permutation(d):
            # p(switch_j | x_{-j}) = sigmoid(-2 * x_j * g_j)
            delta = -2 * g[:, j:j + 1] * S[:, j:j + 1]
            threshold = flip_probability(delta)
            flip = (np.random.rand(n_samples, 1) < threshold).astype(np.float64)
            S[:, j:j + 1] = (1 - 2 * flip) * S[:, j:j + 1]
            # A flipped spin changed by 2 * (new S_j), which shifts every field
            # g_k by 2 * S_j * W[j, k] (W is symmetric).
            g += flip * 2 * S[:, j:j + 1] * W[j:j + 1]
    return S
#####################################
###### Gibbs sampling in numba ######
#####################################
@nb.vectorize()
def nb_sigmoid(x):
    """Element-wise logistic sigmoid 1 / (1 + exp(-x)), compiled as a numba ufunc.

    Numerically stable: ``exp`` is only evaluated at a non-positive argument.
    The previous branch order computed exp(x) / (1 + exp(x)) for large
    positive x, which overflows to inf/inf = nan.
    """
    if x >= 0:
        return 1 / (1 + np.exp(-x))
    else:
        z = np.exp(x)
        return z / (1 + z)
@nb.njit(cache=True)
def gibbs_ising_numba(W, n_samples, n_steps=1000):
    # numba-compiled counterpart of ``gibbs_ising_cpu``: the body is the same,
    # except that the element-wise sigmoid comes from the ``nb_sigmoid`` ufunc
    # instead of ``np.vectorize``.
    #
    # W: symmetric (d, d) coupling matrix with zero diagonal (asserted below).
    # n_samples: number of parallel chains; n_steps: number of sweeps.
    # Returns S, an (n_samples, d) float64 array of +/-1 spins. S is allocated
    # here and updated in place across the sweeps before being returned.
    d = W.shape[0]
    # Independent uniform +/-1 initial spins.
    S = 2 * (np.random.rand(n_samples, d) < 0.5).astype(np.float64) - 1
    assert W.shape == (d, d)
    assert (np.diag(W) == 0).all()
    assert (W == W.T).all()
    g = S.dot(W.T) # size N_samples x d, g_ij = x^{(i)}^T w_j
    for step in range(n_steps):
        # One sweep: visit every spin once, in a fresh random order.
        for j in np.random.permutation(d):
            delta = -2 * g[:, j : j + 1] * S[:, j : j + 1]
            threshold = nb_sigmoid(delta) # p(switch_j | x_{-j}) = sigmoid(- 2 * x_j * g_j)
            flip = (np.random.rand(n_samples, 1) < threshold).astype(np.float64)
            # Update S
            S[:, j : j + 1] = (1 - 2 * flip) * S[:, j :j + 1]
            # Update g: a flipped spin changed by 2 * (new S_j), shifting each
            # field g_k by 2 * S_j * W[j, k] (W is symmetric).
            g += flip * 2 * S[:, j : j + 1] * W[j :j + 1]
    return S
#####################################
####### Gibbs sampling in JAX #######
#####################################
@jit
def update_gibbs_j(gSrng, j, weights=None):
    """Resample spin ``j`` of every chain with one Gibbs step (JAX).

    Args:
        gSrng: carry tuple ``(g, S, rng)``: local fields ``g = S @ W.T``,
            +/-1 spins ``S`` of shape (n_samples, d), and a PRNG key.
        j: index of the spin to update (may be a traced integer).
        weights: symmetric (d, d) coupling matrix with zero diagonal. If
            omitted, the module-level ``W`` is used (the original behavior).

    Returns:
        ``((g, S, rng), None)``, so the function can serve as a ``lax.scan`` body.
    """
    if weights is None:
        weights = W
    g, S, rng = gSrng
    n_samples, d = S.shape
    S_j = dynamic_slice(S, (0, j), (n_samples, 1))
    g_j = dynamic_slice(g, (0, j), (n_samples, 1))
    # p(switch_j | x_{-j}) = sigmoid(-2 * x_j * g_j)
    threshold = sigmoid(-2 * S_j * g_j)
    rng, rng_input = random.split(rng)
    flip = random.bernoulli(rng_input, p=threshold, shape=(n_samples, 1))
    new_S_j = (1 - 2 * flip) * S_j
    S = dynamic_update_slice(S, new_S_j, (0, j))
    # A flipped spin changed by 2 * (new S_j), which shifts every field g_k
    # by 2 * S_j * W[j, k] (W is symmetric).
    W_j = dynamic_slice(weights, (j, 0), (1, d))
    g = g + flip * 2 * new_S_j * W_j
    return (g, S, rng), None


@jit
def update_gibbs(gSrng, _, weights=None):
    """Run one sweep: update every spin once, in a fresh random order.

    ``gSrng`` and ``weights`` are as in ``update_gibbs_j``. The second
    argument is the ignored ``lax.scan`` input. Returns ``((g, S, rng), None)``.
    """
    g, S, rng = gSrng
    n_samples, d = S.shape
    rng, rng_input = random.split(rng)
    order = random.permutation(rng_input, d)
    spin_update = partial(update_gibbs_j, weights=weights)
    g, S, rng = scan(spin_update, (g, S, rng), order)[0]
    return (g, S, rng), None


# n_samples and n_steps fix array shapes / the loop length, so they are static.
@partial(jit, static_argnums=(1, 2))
def gibbs_ising_jax(W, n_samples, n_steps=1000, rng=random.PRNGKey(42)):
    """Run ``n_samples`` parallel Gibbs chains for an Ising model with JAX.

    Args:
        W: symmetric (d, d) coupling matrix with zero diagonal. It is passed
            explicitly to every spin update rather than read from a global.
        n_samples: number of parallel chains (static).
        n_steps: number of sweeps (static).
        rng: PRNG key. The default is a fixed key, so repeated calls with the
            default give identical samples.

    Returns:
        (n_samples, d) array of +/-1 spins after the last sweep.
    """
    d = W.shape[0]
    rng, rng_input = random.split(rng)
    S = 2 * random.bernoulli(rng_input, p=0.5, shape=(n_samples, d)).astype(np.float32) - 1
    g = S @ W.T
    sweep = partial(update_gibbs, weights=W)
    g, S, rng = scan(sweep, (g, S, rng), jnp.arange(n_steps))[0]
    return S
if __name__ == "__main__":
    grid_side = 5
    n_samples = 1000
    n_steps = 1000
    # Simulate Ising model. W stays a module-level name because other
    # functions in this file fall back to it.
    W = ising_matrix(grid_side)

    # CPU
    start = time.perf_counter()
    S0 = gibbs_ising_cpu(W, n_samples=n_samples, n_steps=n_steps)
    t_cpu = time.perf_counter() - start
    print(f"Sampling time on CPU: {t_cpu:.3f}s")

    # Run numba and JAX once before timing so compilation is not measured.
    # Numba
    gibbs_ising_numba(W, n_samples=n_samples, n_steps=n_steps)
    start = time.perf_counter()
    S1 = gibbs_ising_numba(W, n_samples=n_samples, n_steps=n_steps)
    t_numba = time.perf_counter() - start
    print(f"Sampling time with numba: {t_numba:.3f}s")

    # JAX dispatch is asynchronous: block on the warm-up result as well,
    # otherwise its unfinished work is counted in the timed run below.
    gibbs_ising_jax(W, n_samples=n_samples, n_steps=n_steps).block_until_ready()
    start = time.perf_counter()
    S2 = gibbs_ising_jax(W, n_samples=n_samples, n_steps=n_steps)
    S2.block_until_ready()
    t_jax = time.perf_counter() - start
    print(f"Sampling time with JAX: {t_jax:.3f}s")

    print(f"\nNumba / CPU speed up: {t_cpu / t_numba:.3f}x")
    print(f"JAX / numba speed up: {t_numba / t_jax:.3f}x")

    # Check that all methods give a similar estimate of the log-partition function
    print(f"\nCPU log-partition function estimate: {logZ_estimate(S0):.2f}")
    print(f"Numba log-partition function estimate: {logZ_estimate(S1):.2f}")
    print(f"JAX log-partition function estimate: {logZ_estimate(S2):.2f}")