Skip to content

Commit 7e17351

Browse files
committed
adding tests for vmap and batching gradients
1 parent dc0f88f commit 7e17351

File tree

2 files changed

+118
-50
lines changed

2 files changed

+118
-50
lines changed

src/jax_finufft/ops.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def translation_rule(
188188
]
189189

190190

191-
def points_jvp(type_, dim, prim, dpoints, source, *points, output_shape, iflag, eps):
191+
def jvp(type_, prim, args, tangents, *, output_shape, iflag, eps):
192192
# Type 1:
193193
# f_k = sum_j c_j * exp(iflag * i * k * x_j)
194194
# df_k/dx_j = iflag * i * k * c_j * exp(iflag * i * k * x_j)
@@ -197,61 +197,63 @@ def points_jvp(type_, dim, prim, dpoints, source, *points, output_shape, iflag,
197197
# c_j = sum_k f_k * exp(iflag * i * k * x_j)
198198
# dc_j/dx_j = sum_k iflag * i * k * f_k * exp(iflag * i * k * x_j)
199199

200-
ndim = len(points)
201-
n = output_shape[dim] if type_ == 1 else source.shape[-ndim + dim]
202-
203-
shape = np.ones(ndim, dtype=int)
204-
shape[dim] = -1
205-
k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1))
206-
k = k.reshape(shape)
207-
factor = 1j * iflag * k
208-
209-
if type_ == 1:
210-
return factor * prim.bind(
211-
source * dpoints,
212-
*points,
213-
output_shape=output_shape,
214-
iflag=iflag,
215-
eps=eps,
216-
)
217-
return dpoints * prim.bind(
218-
factor * source,
219-
*points,
220-
output_shape=output_shape,
221-
iflag=iflag,
222-
eps=eps,
223-
)
224-
225-
226-
def jvp(type_, prim, args, tangents, *, output_shape, iflag, eps):
227-
# TODO: We could maybe speed this up by concatenating all the source terms and
228-
# then executing a single NUFFT since they all use the same NU points. The
229-
# bookkeeping might get a little ugly.
230-
231200
source, *points = args
232201
dsource, *dpoints = tangents
233202
output = prim.bind(source, *points, output_shape=output_shape, iflag=iflag, eps=eps)
234203

204+
# The JVP op can be written as a single transform of the same type with
235205
output_tangents = []
206+
ndim = len(points)
207+
scales = []
208+
arguments = []
236209
if type(dsource) is not ad.Zero:
237-
output_tangents.append(
238-
prim.bind(dsource, *points, output_shape=output_shape, iflag=iflag, eps=eps)
239-
)
210+
if type_ == 1:
211+
scales.append(jnp.ones_like(output))
212+
arguments.append(dsource)
213+
else:
214+
output_tangents.append(
215+
prim.bind(
216+
dsource, *points, output_shape=output_shape, iflag=iflag, eps=eps
217+
)
218+
)
240219

241-
output_tangents += [
242-
points_jvp(
243-
type_,
244-
dim,
245-
prim,
246-
dx,
247-
source,
220+
for dim, dx in enumerate(dpoints):
221+
if type(dx) is ad.Zero:
222+
continue
223+
224+
n = output_shape[dim] if type_ == 1 else source.shape[-ndim + dim]
225+
shape = np.ones(ndim, dtype=int)
226+
shape[dim] = -1
227+
k = np.arange(-np.floor(n / 2), np.floor((n - 1) / 2 + 1))
228+
k = k.reshape(shape)
229+
factor = 1j * iflag * k
230+
231+
if type_ == 1:
232+
scales.append(factor)
233+
arguments.append(dx * source)
234+
else:
235+
scales.append(dx)
236+
arguments.append(factor * source)
237+
238+
if len(scales):
239+
axis = -2 if type_ == 1 else -ndim - 1
240+
output_tangent = prim.bind(
241+
jnp.concatenate(arguments, axis=axis),
248242
*points,
249243
output_shape=output_shape,
250244
iflag=iflag,
251245
eps=eps,
252246
)
253-
for dim, dx in enumerate(dpoints)
254-
]
247+
248+
axis = -2 if type_ == 2 else -ndim - 1
249+
output_tangent *= jnp.concatenate(jnp.broadcast_arrays(*scales), axis=axis)
250+
251+
expand_shape = (
252+
output.shape[: axis + 1] + (len(scales),) + output.shape[axis + 1 :]
253+
)
254+
output_tangents.append(
255+
jnp.sum(jnp.reshape(output_tangent, expand_shape), axis=axis)
256+
)
255257

256258
return output, reduce(ad.add_tangents, output_tangents, ad.Zero.from_value(output))
257259

@@ -272,12 +274,15 @@ def transpose(type_, doutput, source, *points, output_shape, eps, iflag):
272274
return (result,) + tuple(None for _ in range(len(points)))
273275

274276

275-
def batch(prim, args, axes):
276-
# We can't batch over the last two dimensions of source
277-
mx = args[0].ndim - 2
277+
def batch(type_, prim, args, axes, **kwargs):
278+
ndim = len(args) - 1
279+
if type_ == 1:
280+
mx = args[0].ndim - 2
281+
else:
282+
mx = args[0].ndim - ndim - 1
278283
assert all(a < mx for a in axes)
279284
assert all(a == axes[0] for a in axes[1:])
280-
return prim.bind(*args), axes[0]
285+
return prim.bind(*args, **kwargs), axes[0]
281286

282287

283288
def pad_shapes(output_dim, source, *points):
@@ -308,7 +313,7 @@ def pad_shapes(output_dim, source, *points):
308313
xla.register_translation(nufft1_p, partial(translation_rule, 1), platform="cpu")
309314
ad.primitive_jvps[nufft1_p] = partial(jvp, 1, nufft1_p)
310315
ad.primitive_transposes[nufft1_p] = partial(transpose, 1)
311-
batching.primitive_batchers[nufft1_p] = partial(batch, nufft1_p)
316+
batching.primitive_batchers[nufft1_p] = partial(batch, 1, nufft1_p)
312317

313318

314319
nufft2_p = core.Primitive("nufft2")
@@ -317,4 +322,4 @@ def pad_shapes(output_dim, source, *points):
317322
xla.register_translation(nufft2_p, partial(translation_rule, 2), platform="cpu")
318323
ad.primitive_jvps[nufft2_p] = partial(jvp, 2, nufft2_p)
319324
ad.primitive_transposes[nufft2_p] = partial(transpose, 2)
320-
batching.primitive_batchers[nufft2_p] = partial(batch, nufft2_p)
325+
batching.primitive_batchers[nufft2_p] = partial(batch, 2, nufft2_p)

tests/ops_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from itertools import product
33

44
import jax
5+
import jax.numpy as jnp
56
import numpy as np
67
import pytest
78
from jax.test_util import check_grads
@@ -132,3 +133,65 @@ def test_nufft2_grad(ndim, num_nonnuniform, num_uniform, iflag):
132133
with jax.experimental.enable_x64():
133134
func = partial(nufft2, eps=eps, iflag=iflag)
134135
check_grads(func, (f, *x), 1, modes=("fwd", "rev"))
136+
137+
138+
@pytest.mark.parametrize(
139+
"ndim, num_nonnuniform, num_uniform, iflag",
140+
product([1, 2, 3], [50], [35], [-1, 1]),
141+
)
142+
def test_nufft1_vmap(ndim, num_nonnuniform, num_uniform, iflag):
143+
random = np.random.default_rng(657)
144+
145+
eps = 1e-10
146+
dtype = np.double
147+
cdtype = np.cdouble
148+
149+
num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
150+
151+
x = [
152+
random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
153+
for _ in range(ndim)
154+
]
155+
c = random.normal(size=num_nonnuniform) + 1j * random.normal(size=num_nonnuniform)
156+
c = c.astype(cdtype)
157+
158+
num = 5
159+
xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x]
160+
cs = jnp.repeat(c[None], num, axis=0)
161+
162+
func = partial(nufft1, num_uniform, eps=eps, iflag=iflag)
163+
calc = jax.vmap(func)(cs, *xs)
164+
expect = func(c, *x)
165+
for n in range(num):
166+
np.testing.assert_allclose(calc[n], expect)
167+
168+
169+
@pytest.mark.parametrize(
170+
"ndim, num_nonnuniform, num_uniform, iflag",
171+
product([1, 2, 3], [50], [35], [-1, 1]),
172+
)
173+
def test_nufft2_vmap(ndim, num_nonnuniform, num_uniform, iflag):
174+
random = np.random.default_rng(657)
175+
176+
eps = 1e-10
177+
dtype = np.double
178+
cdtype = np.cdouble
179+
180+
num_uniform = tuple(num_uniform // ndim + 5 * np.arange(ndim))
181+
182+
x = [
183+
random.uniform(-np.pi, np.pi, size=num_nonnuniform).astype(dtype)
184+
for _ in range(ndim)
185+
]
186+
f = random.normal(size=num_uniform) + 1j * random.normal(size=num_uniform)
187+
f = f.astype(cdtype)
188+
189+
num = 5
190+
xs = [jnp.repeat(x_[None], num, axis=0) for x_ in x]
191+
fs = jnp.repeat(f[None], num, axis=0)
192+
193+
func = partial(nufft2, eps=eps, iflag=iflag)
194+
calc = jax.vmap(func)(fs, *xs)
195+
expect = func(f, *x)
196+
for n in range(num):
197+
np.testing.assert_allclose(calc[n], expect)

0 commit comments

Comments
 (0)