Skip to content

Commit 3bf7545

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic GPU] Add checks for argument shapes and types
Apparently we never checked it and it's been quite easy to get this wrong. PiperOrigin-RevId: 761977944
1 parent 44d2cc9 commit 3bf7545

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

jax/experimental/mosaic/gpu/core.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,14 +677,25 @@ def as_gpu_kernel(
677677
if launch_ctx.is_device_collective and not supports_cross_device_collectives():
678678
raise RuntimeError("Kernel is a cross-device collective but no support is available.")
679679

680-
expected_arg_treedef = jax.tree.structure(in_shape)
680+
expected_arg_tys, expected_arg_treedef = jax.tree.flatten(in_shape)
681681
def _check_args(*args):
682682
arg_treedef = jax.tree.structure(args)
683683
if arg_treedef != expected_arg_treedef:
684684
raise ValueError(
685685
f"Invalid argument structure: expected {expected_arg_treedef}, got"
686686
f" {arg_treedef}, ({args=})"
687687
)
688+
for arg, expected_ty in zip(args, expected_arg_tys):
689+
if arg.shape != expected_ty.shape:
690+
raise ValueError(
691+
f"Argument shape mismatch: expected {expected_ty.shape}, got"
692+
f" {arg.shape}"
693+
)
694+
if arg.dtype != expected_ty.dtype:
695+
raise ValueError(
696+
f"Argument dtype mismatch: expected {expected_ty.dtype}, got"
697+
f" {arg.dtype}"
698+
)
688699

689700
def bind(*args) -> Any:
690701
return mosaic_gpu_p.bind(*args, module=module, out_types=out_shape)

0 commit comments

Comments
 (0)