File tree 1 file changed +12
-1
lines changed
jax/experimental/mosaic/gpu
1 file changed +12
-1
lines changed Original file line number Diff line number Diff line change @@ -677,14 +677,25 @@ def as_gpu_kernel(
677
677
if launch_ctx .is_device_collective and not supports_cross_device_collectives ():
678
678
raise RuntimeError ("Kernel is a cross-device collective but no support is available." )
679
679
680
- expected_arg_treedef = jax .tree .structure (in_shape )
680
+ expected_arg_tys , expected_arg_treedef = jax .tree .flatten (in_shape )
681
681
def _check_args (* args ):
682
682
arg_treedef = jax .tree .structure (args )
683
683
if arg_treedef != expected_arg_treedef :
684
684
raise ValueError (
685
685
f"Invalid argument structure: expected { expected_arg_treedef } , got"
686
686
f" { arg_treedef } , ({ args = } )"
687
687
)
688
+ for arg , expected_ty in zip (args , expected_arg_tys ):
689
+ if arg .shape != expected_ty .shape :
690
+ raise ValueError (
691
+ f"Argument shape mismatch: expected { expected_ty .shape } , got"
692
+ f" { arg .shape } "
693
+ )
694
+ if arg .dtype != expected_ty .dtype :
695
+ raise ValueError (
696
+ f"Argument dtype mismatch: expected { expected_ty .dtype } , got"
697
+ f" { arg .dtype } "
698
+ )
688
699
689
700
def bind (* args ) -> Any :
690
701
return mosaic_gpu_p .bind (* args , module = module , out_types = out_shape )
You can’t perform that action at this time.
0 commit comments