@@ -49,6 +49,7 @@
 import jax.numpy as jnp
 
 
+WARP_SIZE = 32
 WARPGROUP_SIZE = 128
 
 
@@ -464,7 +465,7 @@ def _copy_gmem_to_smem_lowering(
     dst_transforms_treedef,
     barrier_transforms_treedef,
     collective_axes,
-    warpgroup_sync: bool = True,
+    for_warpgroup: bool = True,
 ):
   flat_src_transforms, flat_dst_transforms, flat_barrier_transforms = (
       util.split_list(
@@ -505,15 +506,23 @@ def _copy_gmem_to_smem_lowering(
   if ctx.module_ctx.lowering_semantics == mgpu.LoweringSemantics.Lane:
     if bytes % WARPGROUP_SIZE:
       raise NotImplementedError("Only aligned copies are supported")
-    # We arrive uniformly from each thread in the WG, so we need to divide the
-    # number of bytes by the number of threads in the WG.
-    # TODO: apaszke - Relax this. We can just select the WG leader and have it
-    # arrive with the whole transfer size, while everyone else arrives with 0.
-    # But we should continue using this scheme as it's likely to be faster.
-    bytes //= WARPGROUP_SIZE
-    if warpgroup_sync:
+    if for_warpgroup:
+      # We arrive uniformly from each thread in the WG, so we need to divide the
+      # number of bytes by the number of threads in the WG.
+      # TODO: apaszke - Relax this. We can just select the WG leader and have it
+      # arrive with the whole transfer size, while everyone else arrives with 0.
+      # But we should continue using this scheme as it's likely to be faster.
+      bytes //= WARPGROUP_SIZE
       mgpu.warpgroup_barrier()  # Make sure all reads have completed.
-    barrier.arrive_expect_tx(bytes)
+      barrier.arrive_expect_tx(bytes)
+    else:
+      # In Warp-level lowering, we arrive on each CUDA thread in a warp, but
+      # the barrier still expects a full 128 arrivals so we arrive 4 times
+      # on each CUDA thread instead.
+      bytes //= WARP_SIZE
+      barrier.arrive(arrival_count=3, can_complete=False)
+      barrier.arrive_expect_tx(bytes)
+
   ctx.launch_ctx.async_copy(
       src_ref=src,
       dst_ref=dst,
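For context, a minimal sketch (not part of the change) of the arrival and byte accounting assumed by the two branches above: a warpgroup is 4 warps of 32 CUDA threads, so the warp-level path arrives 4 times per thread (3 plain arrivals plus the `arrive_expect_tx` arrival) and divides the bytes by 32 instead of 128. The `transfer_bytes` value below is made up for illustration.

```python
# Sketch only: checks the arrival/byte accounting used by the lowering above.
WARP_SIZE = 32
WARPGROUP_SIZE = 128
transfer_bytes = 4096  # hypothetical transfer size; must be a multiple of 128

# Warpgroup-level lowering: each of the 128 threads arrives once via
# arrive_expect_tx, each expecting an equal share of the bytes.
assert WARPGROUP_SIZE * (transfer_bytes // WARPGROUP_SIZE) == transfer_bytes

# Warp-level lowering: only 32 threads arrive, so each thread contributes
# 3 plain arrivals plus the arrive_expect_tx arrival (4 total) to still reach
# 128 arrivals, and expects a 1/32 share of the bytes.
arrivals_per_thread = 3 + 1
assert WARP_SIZE * arrivals_per_thread == WARPGROUP_SIZE
assert WARP_SIZE * (transfer_bytes // WARP_SIZE) == transfer_bytes
```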
@@ -549,7 +558,7 @@ def _copy_gmem_to_smem_lowering(
     copy_gmem_to_smem_p,
     mgpu.LoweringSemantics.Lane,
     primitive_semantics=gpu_core.PrimitiveSemantics.Warp,
-)(functools.partial(_copy_gmem_to_smem_lowering, warpgroup_sync=False))
+)(functools.partial(_copy_gmem_to_smem_lowering, for_warpgroup=False))
 
 
 def copy_gmem_to_smem(
@@ -713,6 +722,8 @@ def _barrier_wait_pp_eqn(
 
 
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane)
+@lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Lane,
+                                 gpu_core.PrimitiveSemantics.Warp)
 @lowering.register_lowering_rule(barrier_wait_p, mgpu.LoweringSemantics.Warpgroup)
 def _barrier_wait_lowering(
     ctx: lowering.LoweringRuleContext,