@@ -188,7 +188,7 @@ def translation_rule(
188188 ]
189189
190190
191- def points_jvp (type_ , dim , prim , dpoints , source , * points , output_shape , iflag , eps ):
191+ def jvp (type_ , prim , args , tangents , * , output_shape , iflag , eps ):
192192 # Type 1:
193193 # f_k = sum_j c_j * exp(iflag * i * k * x_j)
194194 # df_k/dx_j = iflag * i * k * c_j * exp(iflag * i * k * x_j)
@@ -197,61 +197,63 @@ def points_jvp(type_, dim, prim, dpoints, source, *points, output_shape, iflag,
197197 # c_j = sum_k f_k * exp(iflag * i * k * x_j)
198198 # dc_j/dx_j = sum_k iflag * i * k * f_k * exp(iflag * i * k * x_j)
199199
200- ndim = len (points )
201- n = output_shape [dim ] if type_ == 1 else source .shape [- ndim + dim ]
202-
203- shape = np .ones (ndim , dtype = int )
204- shape [dim ] = - 1
205- k = np .arange (- np .floor (n / 2 ), np .floor ((n - 1 ) / 2 + 1 ))
206- k = k .reshape (shape )
207- factor = 1j * iflag * k
208-
209- if type_ == 1 :
210- return factor * prim .bind (
211- source * dpoints ,
212- * points ,
213- output_shape = output_shape ,
214- iflag = iflag ,
215- eps = eps ,
216- )
217- return dpoints * prim .bind (
218- factor * source ,
219- * points ,
220- output_shape = output_shape ,
221- iflag = iflag ,
222- eps = eps ,
223- )
224-
225-
226- def jvp (type_ , prim , args , tangents , * , output_shape , iflag , eps ):
227- # TODO: We could maybe speed this up by concatenating all the source terms and
228- # then executing a single NUFFT since they all use the same NU points. The
229- # bookkeeping might get a little ugly.
230-
231200 source , * points = args
232201 dsource , * dpoints = tangents
233202 output = prim .bind (source , * points , output_shape = output_shape , iflag = iflag , eps = eps )
234203
204+ # The JVP op can be written as a single transform of the same type with
235205 output_tangents = []
206+ ndim = len (points )
207+ scales = []
208+ arguments = []
236209 if type (dsource ) is not ad .Zero :
237- output_tangents .append (
238- prim .bind (dsource , * points , output_shape = output_shape , iflag = iflag , eps = eps )
239- )
210+ if type_ == 1 :
211+ scales .append (jnp .ones_like (output ))
212+ arguments .append (dsource )
213+ else :
214+ output_tangents .append (
215+ prim .bind (
216+ dsource , * points , output_shape = output_shape , iflag = iflag , eps = eps
217+ )
218+ )
240219
241- output_tangents += [
242- points_jvp (
243- type_ ,
244- dim ,
245- prim ,
246- dx ,
247- source ,
220+ for dim , dx in enumerate (dpoints ):
221+ if type (dx ) is ad .Zero :
222+ continue
223+
224+ n = output_shape [dim ] if type_ == 1 else source .shape [- ndim + dim ]
225+ shape = np .ones (ndim , dtype = int )
226+ shape [dim ] = - 1
227+ k = np .arange (- np .floor (n / 2 ), np .floor ((n - 1 ) / 2 + 1 ))
228+ k = k .reshape (shape )
229+ factor = 1j * iflag * k
230+
231+ if type_ == 1 :
232+ scales .append (factor )
233+ arguments .append (dx * source )
234+ else :
235+ scales .append (dx )
236+ arguments .append (factor * source )
237+
238+ if len (scales ):
239+ axis = - 2 if type_ == 1 else - ndim - 1
240+ output_tangent = prim .bind (
241+ jnp .concatenate (arguments , axis = axis ),
248242 * points ,
249243 output_shape = output_shape ,
250244 iflag = iflag ,
251245 eps = eps ,
252246 )
253- for dim , dx in enumerate (dpoints )
254- ]
247+
248+ axis = - 2 if type_ == 2 else - ndim - 1
249+ output_tangent *= jnp .concatenate (jnp .broadcast_arrays (* scales ), axis = axis )
250+
251+ expand_shape = (
252+ output .shape [: axis + 1 ] + (len (scales ),) + output .shape [axis + 1 :]
253+ )
254+ output_tangents .append (
255+ jnp .sum (jnp .reshape (output_tangent , expand_shape ), axis = axis )
256+ )
255257
256258 return output , reduce (ad .add_tangents , output_tangents , ad .Zero .from_value (output ))
257259
@@ -272,12 +274,15 @@ def transpose(type_, doutput, source, *points, output_shape, eps, iflag):
272274 return (result ,) + tuple (None for _ in range (len (points )))
273275
274276
275- def batch (prim , args , axes ):
276- # We can't batch over the last two dimensions of source
277- mx = args [0 ].ndim - 2
277+ def batch (type_ , prim , args , axes , ** kwargs ):
278+ ndim = len (args ) - 1
279+ if type_ == 1 :
280+ mx = args [0 ].ndim - 2
281+ else :
282+ mx = args [0 ].ndim - ndim - 1
278283 assert all (a < mx for a in axes )
279284 assert all (a == axes [0 ] for a in axes [1 :])
280- return prim .bind (* args ), axes [0 ]
285+ return prim .bind (* args , ** kwargs ), axes [0 ]
281286
282287
283288def pad_shapes (output_dim , source , * points ):
@@ -308,7 +313,7 @@ def pad_shapes(output_dim, source, *points):
308313xla .register_translation (nufft1_p , partial (translation_rule , 1 ), platform = "cpu" )
309314ad .primitive_jvps [nufft1_p ] = partial (jvp , 1 , nufft1_p )
310315ad .primitive_transposes [nufft1_p ] = partial (transpose , 1 )
311- batching .primitive_batchers [nufft1_p ] = partial (batch , nufft1_p )
316+ batching .primitive_batchers [nufft1_p ] = partial (batch , 1 , nufft1_p )
312317
313318
314319nufft2_p = core .Primitive ("nufft2" )
@@ -317,4 +322,4 @@ def pad_shapes(output_dim, source, *points):
317322xla .register_translation (nufft2_p , partial (translation_rule , 2 ), platform = "cpu" )
318323ad .primitive_jvps [nufft2_p ] = partial (jvp , 2 , nufft2_p )
319324ad .primitive_transposes [nufft2_p ] = partial (transpose , 2 )
320- batching .primitive_batchers [nufft2_p ] = partial (batch , nufft2_p )
325+ batching .primitive_batchers [nufft2_p ] = partial (batch , 2 , nufft2_p )
0 commit comments