@@ -152,7 +152,7 @@ def body(start_k, carry):
       # Apply mask to qk.
       qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
 
-    m_curr = qk.max(axis=-1)
+    m_curr = jnp.max(qk, axis=-1)
     m_next = jnp.maximum(m_prev, m_curr)
     correction = jnp.exp2(m_prev - m_next)
     l_prev_corr = correction * l_prev
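The hunk above only swaps `qk.max(...)` for the equivalent `jnp.max(qk, ...)` inside the online-softmax update. For readers following the kernel, here is a minimal, illustrative restatement of that update in plain JAX. It is not the kernel itself: it uses natural `exp`, while the kernel uses `exp2`, which computes the same softmax once the logits are scaled by log2(e). The helper name and default block size are made up.

```python
import jax.numpy as jnp

def blockwise_softmax_matmul(qk, v, block_k=128):
  """Computes softmax(qk, axis=-1) @ v one key block at a time."""
  q_len, k_len = qk.shape
  o = jnp.zeros((q_len, v.shape[-1]), jnp.float32)  # unnormalized running output
  m = jnp.full((q_len,), -jnp.inf, jnp.float32)     # running row max
  l = jnp.zeros((q_len,), jnp.float32)              # running softmax normalizer
  for start in range(0, k_len, block_k):
    qk_blk = qk[:, start:start + block_k]
    m_curr = jnp.max(qk_blk, axis=-1)               # the line the hunk touches
    m_next = jnp.maximum(m, m_curr)
    correction = jnp.exp(m - m_next)                # rescale earlier blocks
    p = jnp.exp(qk_blk - m_next[:, None])
    l = correction * l + p.sum(axis=-1)
    o = correction[:, None] * o + p @ v[start:start + block_k]
    m = m_next
  return o / l[:, None]
```

For any `qk` and `v`, the result matches `jax.nn.softmax(qk, axis=-1) @ v` up to floating-point error; the `m`/`l`/`correction` bookkeeping is exactly what the kernel carries across key blocks.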
@@ -201,7 +201,7 @@ def segment_mask(
 
 
 @functools.partial(
-    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12]
+    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
 )
 @functools.partial(
     jax.jit,
@@ -215,6 +215,7 @@ def segment_mask(
         "grid",
         "interpret",
         "debug",
+        "return_residuals",
     ],
 )
 def mha(
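The two hunks above register the new flag with both decorators stacked on `mha`: it becomes index 13 in `custom_vjp`'s `nondiff_argnums` and a `static_argname` for `jit`, because it selects the output pytree and is branched on in Python, so it must be known at trace time. Below is a toy sketch of the same pattern; `double` and `return_aux` are invented names, and only the public `jax.custom_vjp` / `jax.jit` APIs are assumed.

```python
import functools
import jax
import jax.numpy as jnp

# A flag that changes the output structure must be static / non-differentiable,
# just like return_residuals (argument index 13 of mha).
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
@functools.partial(jax.jit, static_argnames=["return_aux"])
def double(x, return_aux=False):
  y = 2.0 * x
  return (y, jnp.max(x)) if return_aux else y

def _double_fwd(x, return_aux):
  # custom_vjp fwd rules return (primal_output, residuals); calling the
  # primal again avoids duplicating its body.
  return double(x, return_aux), None

def _double_bwd(return_aux, _residuals, g):
  if return_aux:
    raise ValueError("differentiation is not supported with return_aux=True")
  return (2.0 * g,)  # cotangent for x only; non-diff args get none

double.defvjp(_double_fwd, _double_bwd)

x = jnp.arange(4.0)
y, aux = double(x, True)                            # primal call with the flag on
dx = jax.grad(lambda x: double(x, False).sum())(x)  # [2., 2., 2., 2.]
```

As in `_mha_forward` below, the fwd rule can simply call the primal again instead of duplicating its body.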
@@ -231,6 +232,7 @@ def mha(
     grid: tuple[int, ...] | None = None,
     interpret: bool = False,
     debug: bool = False,
+    return_residuals: bool = False,
 ):
   del backward_pass_impl
   batch_size, q_seq_len, num_heads, head_dim = q.shape
@@ -273,21 +275,27 @@ def mha(
       if segment_ids is None
       else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
   )
-  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
-  return pl.pallas_call(
+  out_shape = [q]
+  out_specs = [pl.BlockSpec((None, block_q, None, head_dim_padded),
+                            lambda i, j, k: (j, i, k, 0))]
+  if return_residuals:
+    out_shape.append(jax.ShapeDtypeStruct(
+        shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32))  # lse
+    out_specs.append(
+        pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)))  # lse
+  out = pl.pallas_call(
       kernel,
       grid=grid_,
       in_specs=in_specs,
-      out_specs=pl.BlockSpec(
-          (None, block_q, None, head_dim_padded), lambda i, j, k: (j, i, k, 0)
-      ),
+      out_specs=out_specs,
       compiler_params=plgpu.TritonCompilerParams(
           num_warps=num_warps_, num_stages=num_stages),
       out_shape=out_shape,
       debug=debug,
       interpret=interpret,
       name="mha_forward",
   )(q, k, v, segment_ids)
+  return out if return_residuals else out[0]
 
 
 def _mha_forward(
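With this hunk, `mha` can hand back the log-sum-exp residuals directly. A hedged usage sketch follows: the import path is assumed from the JAX repository layout (`jax/experimental/pallas/ops/gpu/attention.py`), the shapes are illustrative, and running the kernel needs a GPU (or `interpret=True`).

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import mha  # path assumed from the JAX repo

batch, q_len, kv_len, heads, head_dim = 2, 1024, 1024, 8, 64
key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, q_len, heads, head_dim), jnp.float16)
k = jax.random.normal(key_k, (batch, kv_len, heads, head_dim), jnp.float16)
v = jax.random.normal(key_v, (batch, kv_len, heads, head_dim), jnp.float16)

out = mha(q, k, v, None)                              # attention output only
out, lse = mha(q, k, v, None, return_residuals=True)  # output plus log-sum-exp
assert lse.shape == (batch, heads, q_len) and lse.dtype == jnp.float32
```

The `(batch_size, num_heads, q_seq_len)` float32 `lse` matches the residual `out_shape` appended in the hunk above.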
@@ -304,71 +312,17 @@ def _mha_forward(
     grid: Any,
     interpret: bool,
     debug: bool,
+    return_residuals: bool,
 ):
-  del backward_pass_impl
-  batch_size, q_seq_len, num_heads, head_dim = q.shape
-  kv_seq_len = k.shape[1]
-  block_q = min(block_sizes.block_q, q_seq_len)
-  block_k = min(block_sizes.block_k, kv_seq_len)
-  if (q.shape[-1] != k.shape[-1]) or (q.shape[-1] != v.shape[-1]):
-    raise ValueError(
-        f"This kernel expects q, k, and v to have the same head dimension, but"
-        f" found {q.shape=}, {k.shape=}, {v.shape=}."
-    )
-  if q_seq_len % block_q != 0:
-    raise ValueError(f"{q_seq_len=} must be a multiple of {block_q=}")
-  if kv_seq_len % block_k != 0:
-    raise ValueError(f"{kv_seq_len=} must be a multiple of {block_k=}")
-  head_dim_padded = pl.next_power_of_2(head_dim)
-
-  # Heuristics.
-  grid_ = grid
-  if grid_ is None:
-    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
-
-  num_warps_ = num_warps
-  if num_warps_ is None:
-    num_warps_ = 4 if head_dim <= 64 else 8
-  kernel = functools.partial(mha_forward_kernel, sm_scale=sm_scale,
-                             causal=causal, block_q=block_q, block_k=block_k,
-                             head_dim=head_dim)
-  out_shape = [
-      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
-      jax.ShapeDtypeStruct(
-          shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32  # lse
-      ),
-  ]
-  in_specs = [
-      pl.BlockSpec((None, block_q, None, head_dim_padded),
-                   lambda i, j, k: (j, i, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-      pl.BlockSpec((None, kv_seq_len, None, head_dim_padded),
-                   lambda _, j, k: (j, 0, k, 0)),
-  ]
-  in_specs.append(
-      None  # type: ignore[arg-type]
-      if segment_ids is None
-      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
-  )
-  out, lse = pl.pallas_call(
-      kernel,
-      grid=grid_,
-      in_specs=in_specs,
-      out_specs=[
-          pl.BlockSpec((None, block_q, None, head_dim_padded),
-                       lambda i, j, k: (j, i, k, 0)),
-          pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)),
-      ],
-      compiler_params=plgpu.TritonCompilerParams(
-          num_warps=num_warps_, num_stages=num_stages
-      ),
-      out_shape=out_shape,
-      debug=debug,
-      interpret=interpret,
-      name="mha_forward",
-  )(q, k, v, segment_ids)
-  return out, (q, k, v, segment_ids, out, lse)
+  out, lse = mha(q, k, v, segment_ids=segment_ids, sm_scale=sm_scale,
+                 causal=causal, block_sizes=block_sizes,
+                 backward_pass_impl=backward_pass_impl,
+                 num_warps=num_warps, num_stages=num_stages,
+                 grid=grid, interpret=interpret, debug=debug,
+                 return_residuals=True)
+  residuals = (q, k, v, segment_ids, out, lse)
+  ret = (out, lse) if return_residuals else out
+  return ret, residuals
 
 
 def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref, head_dim: int):
@@ -576,9 +530,12 @@ def inner_loop_dq(start_k, dq):
 def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
                   backward_pass_impl: str, num_warps: int | None,
                   num_stages: int, grid: Any, interpret: bool,
-                  debug: bool, res, do):
-  del num_stages, grid
+                  debug: bool, return_residuals: bool, res, do):
+  if return_residuals:
+    raise ValueError(
+        "Kernel differentiation is not supported if return_residuals is True.")
   q, k, v, segment_ids, out, lse = res
+  del num_stages, grid, return_residuals
 
   if backward_pass_impl == "xla":
     return jax.vjp(
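The backward hunk makes `return_residuals` and differentiation mutually exclusive: gradients go through the Pallas backward kernel only for the default `return_residuals=False`. A hedged sketch of the caller-side behaviour, with the same assumed import path as above and illustrative shapes (running it needs a GPU, or pass `interpret=True` to use the Pallas interpreter):

```python
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import mha  # path assumed from the JAX repo

rng = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(rng, (1, 256, 2, 64), jnp.float16)

# Differentiation goes through the custom VJP (_mha_forward/_mha_backward).
dq = jax.grad(lambda q: mha(q, k, v, None, causal=True).sum())(q)

# Requesting residuals from a call that is being differentiated hits the
# ValueError added in the hunk above:
# jax.grad(lambda q: mha(q, k, v, None, return_residuals=True)[0].sum())(q)
```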