Skip to content

Commit 6cc627a

Browse files
apaszke and Google-ML-Automation
authored and committed
[Mosaic GPU] Add checks for argument shapes and types
Apparently we never checked the shapes and dtypes of kernel arguments, and it has been quite easy to get them wrong.

PiperOrigin-RevId: 762394139
1 parent a0af34f commit 6cc627a

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -677,14 +677,28 @@ def as_gpu_kernel(
677677
if launch_ctx.is_device_collective and not supports_cross_device_collectives():
678678
raise RuntimeError("Kernel is a cross-device collective but no support is available.")
679679

680-
expected_arg_treedef = jax.tree.structure(in_shape)
680+
expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape)
681681
def _check_args(*args):
682682
arg_treedef = jax.tree.structure(args)
683683
if arg_treedef != expected_arg_treedef:
684684
raise ValueError(
685685
f"Invalid argument structure: expected {expected_arg_treedef}, got"
686686
f" {arg_treedef}, ({args=})"
687687
)
688+
for arg, expected_ty in zip(args, expected_arg_tys):
689+
if arg.shape != expected_ty.shape:
690+
raise ValueError(
691+
f"Argument shape mismatch: expected {expected_ty.shape}, got"
692+
f" {arg.shape}"
693+
)
694+
if arg.dtype != expected_ty.dtype:
695+
hint = ""
696+
if not arg.shape:
697+
hint = f". Hint: cast the scalar to {expected_ty.dtype} explicitly."
698+
raise ValueError(
699+
f"Argument dtype mismatch: expected {expected_ty.dtype}, got"
700+
f" {arg.dtype}{hint}"
701+
)
688702

689703
def bind(*args) -> Any:
690704
return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape)

tests/mosaic/gpu_test.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -460,7 +460,7 @@ def test_scalar_argument(self, dtype):
460460
" values read from the 32-bit input buffer to sometimes"
461461
" (nondeterministically) contain garbage.")
462462

463-
scalar = 42
463+
scalar = dtype(42)
464464
expected = np.full((128, 128), scalar, dtype=dtype)
465465

466466
def kernel(ctx, inp, out, _):

0 commit comments

Comments
 (0)