Skip to content

Commit dc0cdf7

Browse files
[Mosaic GPU] Properly handle single-element outputs in inline assembly.
This issue was discovered while enabling the ragged dot example kernel to run using warpgroup semantics. The new test requires building with `-UNDEBUG` (i.e. with assertions enabled) to reproduce the failure.
PiperOrigin-RevId: 762365352
1 parent 199d9f7 commit dc0cdf7

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2591,17 +2591,28 @@ def _repack(regs_it, reg_ty):
25912591
all_reg_constraints = ",".join(
25922592
[*("=" + c for c in reg_constraints), *reg_constraints]
25932593
)
2594-
struct_ty = ir.Type.parse(
2595-
f"!llvm.struct<({','.join(map(str, reg_dtypes))})>"
2596-
)
2597-
result_struct = llvm.inline_asm(
2598-
struct_ty, regs, ptx, all_reg_constraints,
2599-
asm_dialect=0, has_side_effects=True,
2600-
)
2601-
regs = [
2602-
llvm.extractvalue(dtype, result_struct, [i])
2603-
for i, dtype in enumerate(reg_dtypes)
2604-
]
2594+
2595+
if len(reg_dtypes) == 1:
2596+
# The InlineAsm::verify() function doesn't allow a struct output when there
2597+
# is only one element (even though that seems to work for the case below).
2598+
result_elem = llvm.inline_asm(
2599+
reg_dtypes[0], regs, ptx, all_reg_constraints,
2600+
asm_dialect=0, has_side_effects=True,
2601+
)
2602+
regs = [result_elem]
2603+
else:
2604+
struct_ty = ir.Type.parse(
2605+
f"!llvm.struct<({','.join(map(str, reg_dtypes))})>"
2606+
)
2607+
result_struct = llvm.inline_asm(
2608+
struct_ty, regs, ptx, all_reg_constraints,
2609+
asm_dialect=0, has_side_effects=True,
2610+
)
2611+
regs = [
2612+
llvm.extractvalue(dtype, result_struct, [i])
2613+
for i, dtype in enumerate(reg_dtypes)
2614+
]
2615+
26052616
i32 = ir.IntegerType.get_signless(32)
26062617
results = []
26072618
regs_it = iter(regs)

tests/mosaic/gpu_test.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2451,6 +2451,21 @@ def kernel(ctx, inp, out, smem):
24512451
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, None)
24522452
np.testing.assert_array_equal(f(x), x * 3)
24532453

2454+
def test_optimization_barrier_with_single_value(self):
2455+
shape = (64, 64)
2456+
value = 5.0
2457+
dtype = jnp.float32
2458+
def kernel(ctx, out, smem):
2459+
del ctx, smem
2460+
mlir_type = utils.dtype_to_ir_type(dtype)
2461+
arr = mgpu.FragmentedArray.splat(c(value, mlir_type), shape)
2462+
arr = mgpu.optimization_barrier(arr)
2463+
arr.store_untiled(out)
2464+
2465+
out_shape = jax.ShapeDtypeStruct(shape, dtype)
2466+
f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), (), out_shape, ())
2467+
np.testing.assert_array_equal(f(), jnp.full(shape, value, dtype=dtype))
2468+
24542469
def test_convert_bool_to_u8(self):
24552470
m, n = 128, 128
24562471
def kernel(ctx, dst, _):

0 commit comments

Comments
 (0)