File tree 2 files changed +16
-2
lines changed
jax/experimental/mosaic/gpu
2 files changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -677,14 +677,28 @@ def as_gpu_kernel(
677
677
if launch_ctx .is_device_collective and not supports_cross_device_collectives ():
678
678
raise RuntimeError ("Kernel is a cross-device collective but no support is available." )
679
679
680
- expected_arg_treedef = jax .tree .structure (in_shape )
680
+ expected_arg_tys , expected_arg_treedef = jax .tree .flatten (in_shape )
681
681
def _check_args (* args ):
682
682
arg_treedef = jax .tree .structure (args )
683
683
if arg_treedef != expected_arg_treedef :
684
684
raise ValueError (
685
685
f"Invalid argument structure: expected { expected_arg_treedef } , got"
686
686
f" { arg_treedef } , ({ args = } )"
687
687
)
688
+ for arg , expected_ty in zip (args , expected_arg_tys ):
689
+ if arg .shape != expected_ty .shape :
690
+ raise ValueError (
691
+ f"Argument shape mismatch: expected { expected_ty .shape } , got"
692
+ f" { arg .shape } "
693
+ )
694
+ if arg .dtype != expected_ty .dtype :
695
+ hint = ""
696
+ if not arg .shape :
697
+ hint = f". Hint: cast the scalar to { expected_ty .dtype } explicitly."
698
+ raise ValueError (
699
+ f"Argument dtype mismatch: expected { expected_ty .dtype } , got"
700
+ f" { arg .dtype } { hint } "
701
+ )
688
702
689
703
def bind (* args ) -> Any :
690
704
return mosaic_gpu_p .bind (* args , module = module , out_types = out_shape )
Original file line number Diff line number Diff line change @@ -460,7 +460,7 @@ def test_scalar_argument(self, dtype):
460
460
" values read from the 32-bit input buffer to sometimes"
461
461
" (nondeterministically) contain garbage." )
462
462
463
- scalar = 42
463
+ scalar = dtype ( 42 )
464
464
expected = np .full ((128 , 128 ), scalar , dtype = dtype )
465
465
466
466
def kernel (ctx , inp , out , _ ):
You can’t perform that action at this time.
0 commit comments