File tree 2 files changed +12
-3
lines changed
2 files changed +12
-3
lines changed Original file line number Diff line number Diff line change @@ -2703,7 +2703,13 @@ def check_shardings_are_auto(shardings_flat):
2703
2703
raise ValueError (
2704
2704
'The spec of NamedSharding passed to with_sharding_constraint can'
2705
2705
f' only refer to Auto axes of the mesh. Got spec={ s .spec } and'
2706
- f' mesh={ mesh } ' )
2706
+ f' mesh={ mesh } . You probably meant to use `reshard` API?' )
2707
+
2708
+ cur_mesh = mesh_lib .get_abstract_mesh ()
2709
+ if cur_mesh ._are_all_axes_explicit :
2710
+ raise ValueError (
2711
+ 'with_sharding_constraint cannot be used when all axes of the mesh are'
2712
+ ' of type `Explicit`. Please use the `reshard` API.' )
2707
2713
2708
2714
2709
2715
def with_sharding_constraint (x , shardings ):
Original file line number Diff line number Diff line change @@ -7069,8 +7069,11 @@ def test_wsc_error(self, mesh):
7069
7069
"The spec of NamedSharding passed to with_sharding_constraint" ):
7070
7070
jax .lax .with_sharding_constraint (np .arange (8 ).reshape (4 , 2 ), s )
7071
7071
7072
- s = NamedSharding (mesh , P ())
7073
- jax .lax .with_sharding_constraint (np .arange (8 ), s )
7072
+ with self .assertRaisesRegex (
7073
+ ValueError ,
7074
+ 'with_sharding_constraint cannot be used when all axes of the mesh are'
7075
+ ' of type `Explicit`' ):
7076
+ jax .lax .with_sharding_constraint (np .arange (8 ), NamedSharding (mesh , P ()))
7074
7077
7075
7078
s = NamedSharding (Mesh (mesh .devices , mesh .axis_names ,
7076
7079
axis_types = (AxisType .Explicit , AxisType .Auto )),
You can’t perform that action at this time.
0 commit comments