@@ -86,28 +86,29 @@ def mha_forward_kernel(
86
86
segment_ids_ref : jax .Array | None , # segment_id arrays
87
87
o_ref : Any , # Output
88
88
* residual_refs : Any , # Residual outputs
89
- num_heads : int ,
90
89
sm_scale : float ,
91
90
causal : bool ,
92
91
block_q : int ,
93
- block_d : int ,
94
92
block_k : int ,
93
+ head_dim : int ,
95
94
):
96
95
seq_len = k_ref .shape [0 ]
97
96
start_q = pl .program_id (0 )
97
+ head_dim_padded = q_ref .shape [- 1 ]
98
98
99
99
# o is the buffer where we accumulate the output on sram.
100
100
# m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
101
101
m_i = jnp .zeros (block_q , dtype = jnp .float32 ) - float ('inf' )
102
102
l_i = jnp .zeros (block_q , dtype = jnp .float32 )
103
103
# o is the float32 buffer where we accumulate the output on sram.
104
- o = jnp .zeros ((block_q , block_d ), dtype = jnp .float32 )
104
+ o = jnp .zeros ((block_q , head_dim_padded ), dtype = jnp .float32 )
105
105
106
106
# Load q: it will stay in L1 throughout. Indices form a matrix because we
107
107
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
108
- # q tile has shape [block_q, block_d ], block_d = = head_dim.
108
+ # q tile has shape [block_q, head_dim_padded], head_dim_padded >= head_dim.
109
109
curr_q_slice = pl .dslice (start_q * block_q , block_q )
110
- q = q_ref [...]
110
+ head_mask = (jnp .arange (head_dim_padded ) < head_dim )[None , :]
111
+ q = pl .load (q_ref , (slice (None ), slice (None )), mask = head_mask , other = 0.0 )
111
112
q_segment_ids = (
112
113
None
113
114
if segment_ids_ref is None
@@ -121,7 +122,7 @@ def body(start_k, carry):
121
122
o_prev , m_prev , l_prev = carry
122
123
curr_k_slice = pl .dslice (start_k * block_k , block_k )
123
124
124
- k = pl .load (k_ref , (curr_k_slice , slice (None )))
125
+ k = pl .load (k_ref , (curr_k_slice , slice (None )), mask = head_mask , other = 0.0 )
125
126
qk = pl .dot (q , k .T ) # [block_q, block_k]
126
127
127
128
# Scale logits to convert from base-2 to the natural log domain.
@@ -161,7 +162,7 @@ def body(start_k, carry):
161
162
l_curr = s_curr .sum (axis = - 1 )
162
163
l_next = l_prev_corr + l_curr
163
164
o_prev_corr = correction [:, None ] * o_prev
164
- v = pl .load (v_ref , (curr_k_slice , pl . dslice ( block_d )) )
165
+ v = pl .load (v_ref , (curr_k_slice , slice ( None )), mask = head_mask )
165
166
o_curr = pl .dot (s_curr .astype (v .dtype ), v )
166
167
167
168
o_next = o_prev_corr + o_curr
@@ -182,7 +183,8 @@ def body(start_k, carry):
182
183
lse_ref = residual_refs [0 ]
183
184
lse_ref [...] = m_i + jnp .log2 (l_i )
184
185
# Write output to dram.
185
- o_ref [...] = o .astype (o_ref .dtype )
186
+ pl .store (o_ref , (slice (None ), slice (o .shape [- 1 ])), o .astype (o_ref .dtype ),
187
+ mask = head_mask )
186
188
187
189
def segment_mask (
188
190
q_segment_ids : jax .Array ,
@@ -235,6 +237,17 @@ def mha(
235
237
kv_seq_len = k .shape [1 ]
236
238
block_q = min (block_sizes .block_q , q_seq_len )
237
239
block_k = min (block_sizes .block_k , kv_seq_len )
240
+ head_dim_padded = pl .next_power_of_2 (head_dim )
241
+ if (q .shape [- 1 ] != k .shape [- 1 ]) or (q .shape [- 1 ] != v .shape [- 1 ]):
242
+ raise ValueError (
243
+ f"This kernel expects q, k, and v to have the same head dimension, but"
244
+ f" found { q .shape = } , { k .shape = } , { v .shape = } ."
245
+ )
246
+ if q_seq_len % block_q != 0 :
247
+ raise ValueError (f"{ q_seq_len = } must be a multiple of { block_q = } " )
248
+ if kv_seq_len % block_k != 0 :
249
+ raise ValueError (f"{ kv_seq_len = } must be a multiple of { block_k = } " )
250
+
238
251
# Heuristics.
239
252
grid_ = grid
240
253
if grid_ is None :
@@ -243,21 +256,17 @@ def mha(
243
256
num_warps_ = num_warps
244
257
if num_warps_ is None :
245
258
num_warps_ = 4 if head_dim <= 64 else 8
246
- kernel = functools .partial (mha_forward_kernel , num_heads = num_heads ,
247
- sm_scale = sm_scale , block_q = block_q ,
248
- block_k = block_k , block_d = head_dim ,
249
- causal = causal )
259
+ kernel = functools .partial (mha_forward_kernel , sm_scale = sm_scale ,
260
+ block_q = block_q , block_k = block_k ,
261
+ head_dim = head_dim , causal = causal )
250
262
251
263
in_specs = [
252
- pl .BlockSpec (
253
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
254
- ),
255
- pl .BlockSpec (
256
- (None , kv_seq_len , None , head_dim ), lambda _ , j , k : (j , 0 , k , 0 )
257
- ),
258
- pl .BlockSpec (
259
- (None , kv_seq_len , None , head_dim ), lambda _ , j , k : (j , 0 , k , 0 )
260
- ),
264
+ pl .BlockSpec ((None , block_q , None , head_dim_padded ),
265
+ lambda i , j , k : (j , i , k , 0 )),
266
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
267
+ lambda _ , j , k : (j , 0 , k , 0 )),
268
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
269
+ lambda _ , j , k : (j , 0 , k , 0 )),
261
270
]
262
271
in_specs .append (
263
272
None # type: ignore[arg-type]
@@ -270,7 +279,7 @@ def mha(
270
279
grid = grid_ ,
271
280
in_specs = in_specs ,
272
281
out_specs = pl .BlockSpec (
273
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
282
+ (None , block_q , None , head_dim_padded ), lambda i , j , k : (j , i , k , 0 )
274
283
),
275
284
compiler_params = plgpu .TritonCompilerParams (
276
285
num_warps = num_warps_ , num_stages = num_stages ),
@@ -301,6 +310,17 @@ def _mha_forward(
301
310
kv_seq_len = k .shape [1 ]
302
311
block_q = min (block_sizes .block_q , q_seq_len )
303
312
block_k = min (block_sizes .block_k , kv_seq_len )
313
+ if (q .shape [- 1 ] != k .shape [- 1 ]) or (q .shape [- 1 ] != v .shape [- 1 ]):
314
+ raise ValueError (
315
+ f"This kernel expects q, k, and v to have the same head dimension, but"
316
+ f" found { q .shape = } , { k .shape = } , { v .shape = } ."
317
+ )
318
+ if q_seq_len % block_q != 0 :
319
+ raise ValueError (f"{ q_seq_len = } must be a multiple of { block_q = } " )
320
+ if kv_seq_len % block_k != 0 :
321
+ raise ValueError (f"{ kv_seq_len = } must be a multiple of { block_k = } " )
322
+ head_dim_padded = pl .next_power_of_2 (head_dim )
323
+
304
324
# Heuristics.
305
325
grid_ = grid
306
326
if grid_ is None :
@@ -309,25 +329,22 @@ def _mha_forward(
309
329
num_warps_ = num_warps
310
330
if num_warps_ is None :
311
331
num_warps_ = 4 if head_dim <= 64 else 8
312
- kernel = functools .partial (mha_forward_kernel , num_heads = num_heads ,
313
- sm_scale = sm_scale , causal = causal , block_q = block_q ,
314
- block_k = block_k , block_d = head_dim )
332
+ kernel = functools .partial (mha_forward_kernel , sm_scale = sm_scale ,
333
+ causal = causal , block_q = block_q , block_k = block_k ,
334
+ head_dim = head_dim )
315
335
out_shape = [
316
336
jax .ShapeDtypeStruct (shape = q .shape , dtype = q .dtype ), # out
317
337
jax .ShapeDtypeStruct (
318
338
shape = (batch_size , num_heads , q_seq_len ), dtype = jnp .float32 # lse
319
339
),
320
340
]
321
341
in_specs = [
322
- pl .BlockSpec (
323
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
324
- ),
325
- pl .BlockSpec (
326
- (None , kv_seq_len , None , head_dim ), lambda _ , j , k : (j , 0 , k , 0 )
327
- ),
328
- pl .BlockSpec (
329
- (None , kv_seq_len , None , head_dim ), lambda _ , j , k : (j , 0 , k , 0 )
330
- ),
342
+ pl .BlockSpec ((None , block_q , None , head_dim_padded ),
343
+ lambda i , j , k : (j , i , k , 0 )),
344
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
345
+ lambda _ , j , k : (j , 0 , k , 0 )),
346
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
347
+ lambda _ , j , k : (j , 0 , k , 0 )),
331
348
]
332
349
in_specs .append (
333
350
None # type: ignore[arg-type]
@@ -339,9 +356,8 @@ def _mha_forward(
339
356
grid = grid_ ,
340
357
in_specs = in_specs ,
341
358
out_specs = [
342
- pl .BlockSpec (
343
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
344
- ),
359
+ pl .BlockSpec ((None , block_q , None , head_dim_padded ),
360
+ lambda i , j , k : (j , i , k , 0 )),
345
361
pl .BlockSpec ((None , None , block_q ), lambda i , j , k : (j , k , i )),
346
362
],
347
363
compiler_params = plgpu .TritonCompilerParams (
@@ -355,10 +371,11 @@ def _mha_forward(
355
371
return out , (q , k , v , segment_ids , out , lse )
356
372
357
373
358
- def _preprocess_backward_kernel (out_ref , dout_ref , delta_ref ):
374
+ def _preprocess_backward_kernel (out_ref , dout_ref , delta_ref , head_dim : int ):
359
375
# load
360
- o = out_ref [...].astype (jnp .float32 )
361
- do = dout_ref [...].astype (jnp .float32 )
376
+ head_mask = (jnp .arange (out_ref .shape [- 1 ]) < head_dim )[None , :]
377
+ o = pl .load (out_ref , (slice (None ), slice (None )), mask = head_mask , other = 0.0 )
378
+ do = pl .load (dout_ref , (slice (None ), slice (None )), mask = head_mask , other = 0.0 )
362
379
# compute
363
380
delta = jnp .sum (o * do , axis = 1 )
364
381
# write-back
@@ -368,17 +385,16 @@ def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref):
368
385
def _preprocess_backward (out , do , lse , block_q : int ,
369
386
debug : bool , interpret : bool ):
370
387
batch_size , seq_len , num_heads , head_dim = out .shape
388
+ head_dim_padded = pl .next_power_of_2 (head_dim )
371
389
out_shape = jax .ShapeDtypeStruct (lse .shape , lse .dtype )
372
390
delta = pl .pallas_call (
373
- _preprocess_backward_kernel ,
391
+ functools . partial ( _preprocess_backward_kernel , head_dim = head_dim ) ,
374
392
grid = (pl .cdiv (seq_len , block_q ), batch_size , num_heads ),
375
393
in_specs = [
376
- pl .BlockSpec (
377
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
378
- ),
379
- pl .BlockSpec (
380
- (None , block_q , None , head_dim ), lambda i , j , k : (j , i , k , 0 )
381
- ),
394
+ pl .BlockSpec ((None , block_q , None , head_dim_padded ),
395
+ lambda i , j , k : (j , i , k , 0 )),
396
+ pl .BlockSpec ((None , block_q , None , head_dim_padded ),
397
+ lambda i , j , k : (j , i , k , 0 )),
382
398
],
383
399
out_specs = pl .BlockSpec ((None , None , block_q ), lambda i , j , k : (j , k , i )),
384
400
compiler_params = plgpu .TritonCompilerParams (num_warps = 4 , num_stages = 3 ),
@@ -414,7 +430,7 @@ def mha_backward_kernel(
414
430
block_kv_dkv : int ,
415
431
block_q_dq : int ,
416
432
block_kv_dq : int ,
417
- block_d : int ,
433
+ head_dim : int ,
418
434
):
419
435
del out_ref # Not needed
420
436
q_seq_len = q_ref .shape [0 ]
@@ -427,11 +443,13 @@ def mha_backward_kernel(
427
443
start_k = pl .program_id (2 )
428
444
curr_k_slice = pl .dslice (start_k * block_kv_dkv , block_kv_dkv )
429
445
430
- dv = jnp .zeros ([block_kv_dkv , block_d ], dtype = jnp .float32 )
431
- dk = jnp .zeros ([block_kv_dkv , block_d ], dtype = jnp .float32 )
446
+ head_dim_padded = q_ref .shape [- 1 ]
447
+ dv = jnp .zeros ([block_kv_dkv , head_dim_padded ], dtype = jnp .float32 )
448
+ dk = jnp .zeros ([block_kv_dkv , head_dim_padded ], dtype = jnp .float32 )
432
449
433
- v = pl .load (v_ref , (curr_k_slice , slice (None )))
434
- k = pl .load (k_ref , (curr_k_slice , slice (None )))
450
+ head_mask = (jnp .arange (head_dim_padded ) < head_dim )[None , :]
451
+ v = pl .load (v_ref , (curr_k_slice , slice (None )), mask = head_mask , other = 0.0 )
452
+ k = pl .load (k_ref , (curr_k_slice , slice (None )), mask = head_mask , other = 0.0 )
435
453
span_k = start_k * block_kv_dkv + jnp .arange (block_kv_dkv )
436
454
kv_segment_ids = (
437
455
None
@@ -443,7 +461,7 @@ def inner_loop_dkdv(start_q, carry):
443
461
dv , dk = carry
444
462
curr_q_slice = pl .dslice (start_q * block_q_dkv , block_q_dkv )
445
463
446
- q = pl .load (q_ref , (curr_q_slice , slice (None )))
464
+ q = pl .load (q_ref , (curr_q_slice , slice (None )), mask = head_mask , other = 0.0 )
447
465
qk = pl .dot (q , k .T )
448
466
qk_scale = math .log2 (math .e )
449
467
if sm_scale != 1. :
@@ -466,7 +484,8 @@ def inner_loop_dkdv(start_q, carry):
466
484
467
485
lse = pl .load (lse_ref , (curr_q_slice ,))
468
486
di = pl .load (delta_ref , (curr_q_slice ,))
469
- do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )))
487
+ do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )), mask = head_mask ,
488
+ other = 0.0 )
470
489
471
490
p = jnp .exp2 (qk - lse [:, None ])
472
491
dv = dv + pl .dot (p .astype (do .dtype ).T , do )
@@ -483,8 +502,10 @@ def inner_loop_dkdv(start_q, carry):
483
502
dv , dk = lax .fori_loop (
484
503
lower_bound , pl .cdiv (q_seq_len , block_q_dkv ), inner_loop_dkdv , (dv , dk )
485
504
)
486
- dv_ref [...] = dv .astype (dv_ref .dtype )
487
- dk_ref [...] = dk .astype (dk_ref .dtype )
505
+ pl .store (dv_ref , (slice (None ), slice (dv .shape [- 1 ])), dv .astype (dv_ref .dtype ),
506
+ mask = head_mask )
507
+ pl .store (dk_ref , (slice (None ), slice (dk .shape [- 1 ])), dk .astype (dk_ref .dtype ),
508
+ mask = head_mask )
488
509
489
510
# Scan #2: dQ
490
511
# 1. Load a block of Q of size (block_q_dq, head_dim) in SMEM.
@@ -493,22 +514,23 @@ def inner_loop_dkdv(start_q, carry):
493
514
start_q = pl .program_id (2 )
494
515
curr_q_slice = pl .ds (start_q * block_q_dq , block_q_dq )
495
516
span_q = start_q * block_q_dq + jnp .arange (block_q_dq )
496
- dq = jnp .zeros ([block_q_dq , block_d ], dtype = jnp .float32 )
517
+ dq = jnp .zeros ([block_q_dq , head_dim_padded ], dtype = jnp .float32 )
497
518
498
- q = pl .load (q_ref , (curr_q_slice , slice (None )))
519
+ q = pl .load (q_ref , (curr_q_slice , slice (None )), mask = head_mask , other = 0.0 )
499
520
q_segment_ids = (
500
521
None
501
522
if segment_ids_ref is None
502
523
else pl .load (segment_ids_ref , (curr_q_slice ,))
503
524
)
504
525
lse = pl .load (lse_ref , (curr_q_slice ,))
505
- do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )))
526
+ do = pl .load (do_scaled_ref , (curr_q_slice , slice (None )), mask = head_mask ,
527
+ other = 0.0 )
506
528
di = pl .load (delta_ref , (curr_q_slice ,))
507
529
508
530
def inner_loop_dq (start_k , dq ):
509
531
curr_k_slice = pl .dslice (start_k * block_kv_dq , block_kv_dq )
510
- k = pl .load (k_ref , (curr_k_slice , slice (None )))
511
- v = pl .load (v_ref , (curr_k_slice , slice (None )))
532
+ k = pl .load (k_ref , (curr_k_slice , slice (None )), mask = head_mask , other = 0.0 )
533
+ v = pl .load (v_ref , (curr_k_slice , slice (None )), mask = head_mask , other = 0.0 )
512
534
513
535
qk = pl .dot (q , k .T )
514
536
qk_scale = math .log2 (math .e )
@@ -547,7 +569,8 @@ def inner_loop_dq(start_k, dq):
547
569
upper_bound = pl .cdiv (kv_seq_len , block_kv_dq )
548
570
549
571
dq = lax .fori_loop (0 , upper_bound , inner_loop_dq , (dq ))
550
- dq_ref [...] = dq .astype (dq_ref .dtype )
572
+ pl .store (dq_ref , (slice (None ), slice (dq .shape [- 1 ])), dq .astype (dq_ref .dtype ),
573
+ mask = head_mask )
551
574
552
575
553
576
def _mha_backward (sm_scale : float , causal : bool , block_sizes : BlockSizes ,
@@ -576,6 +599,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
576
599
block_kv_dkv = min (block_sizes .block_kv_dkv , kv_seq_len )
577
600
block_q_dq = min (block_sizes .block_q_dq , q_seq_len )
578
601
block_kv_dq = min (block_sizes .block_kv_dq , kv_seq_len )
602
+ head_dim_padded = pl .next_power_of_2 (head_dim )
579
603
580
604
if q_seq_len // block_q_dq != kv_seq_len // block_kv_dkv :
581
605
raise ValueError (
@@ -591,28 +615,24 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
591
615
]
592
616
593
617
in_specs = [
594
- pl .BlockSpec (
595
- (None , q_seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
596
- ),
597
- pl .BlockSpec (
598
- (None , kv_seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
599
- ),
600
- pl .BlockSpec (
601
- (None , kv_seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
602
- ),
603
- pl .BlockSpec (
604
- (None , q_seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
605
- ),
606
- pl .BlockSpec (
607
- (None , q_seq_len , None , head_dim ), lambda i , j , _ : (i , 0 , j , 0 )
608
- ),
618
+ pl .BlockSpec ((None , q_seq_len , None , head_dim_padded ),
619
+ lambda i , j , _ : (i , 0 , j , 0 )),
620
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
621
+ lambda i , j , _ : (i , 0 , j , 0 )),
622
+ pl .BlockSpec ((None , kv_seq_len , None , head_dim_padded ),
623
+ lambda i , j , _ : (i , 0 , j , 0 )),
624
+ pl .BlockSpec ((None , q_seq_len , None , head_dim_padded ),
625
+ lambda i , j , _ : (i , 0 , j , 0 )),
626
+ pl .BlockSpec ((None , q_seq_len , None , head_dim_padded ),
627
+ lambda i , j , _ : (i , 0 , j , 0 )),
609
628
pl .BlockSpec ((None , None , q_seq_len ), lambda i , j , _ : (i , j , 0 )),
610
629
pl .BlockSpec ((None , None , q_seq_len ), lambda i , j , _ : (i , j , 0 )),
611
630
]
612
631
if segment_ids is None :
613
632
in_specs .insert (3 , None ) # type: ignore[arg-type]
614
633
else :
615
- in_specs .insert (3 , pl .BlockSpec ((None , kv_seq_len ), lambda i , j , _ : (i , 0 )))
634
+ in_specs .insert (3 , pl .BlockSpec ((None , kv_seq_len ),
635
+ lambda i , j , _ : (i , 0 )))
616
636
617
637
grid = (batch_size , num_heads , pl .cdiv (kv_seq_len , block_kv_dkv ))
618
638
num_warps_ = num_warps
@@ -635,22 +655,22 @@ def _mha_backward(sm_scale: float, causal: bool, block_sizes: BlockSizes,
635
655
block_kv_dkv = block_kv_dkv ,
636
656
block_q_dq = block_q_dq ,
637
657
block_kv_dq = block_kv_dq ,
638
- block_d = head_dim ,
658
+ head_dim = head_dim ,
639
659
),
640
660
out_shape = out_shapes ,
641
661
in_specs = in_specs ,
642
662
grid = grid ,
643
663
out_specs = [
644
664
pl .BlockSpec (
645
- (None , block_q_dq , None , head_dim ),
665
+ (None , block_q_dq , None , head_dim_padded ),
646
666
lambda i , j , k : (i , k , j , 0 ), # dq
647
667
),
648
668
pl .BlockSpec (
649
- (None , block_kv_dkv , None , head_dim ),
669
+ (None , block_kv_dkv , None , head_dim_padded ),
650
670
lambda i , j , k : (i , k , j , 0 ), # dk
651
671
),
652
672
pl .BlockSpec (
653
- (None , block_kv_dkv , None , head_dim ),
673
+ (None , block_kv_dkv , None , head_dim_padded ),
654
674
lambda i , j , k : (i , k , j , 0 ), # dv
655
675
),
656
676
],
0 commit comments