Skip to content

Commit 210b5fc

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Error out if wsc(x, P()) is called in a Explicit mesh context
PiperOrigin-RevId: 762021159
1 parent 3622c92 commit 210b5fc

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

jax/_src/pjit.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2703,7 +2703,13 @@ def check_shardings_are_auto(shardings_flat):
27032703
raise ValueError(
27042704
'The spec of NamedSharding passed to with_sharding_constraint can'
27052705
f' only refer to Auto axes of the mesh. Got spec={s.spec} and'
2706-
f' mesh={mesh}')
2706+
f' mesh={mesh}. You probably meant to use `reshard` API?')
2707+
2708+
cur_mesh = mesh_lib.get_abstract_mesh()
2709+
if cur_mesh._are_all_axes_explicit:
2710+
raise ValueError(
2711+
'with_sharding_constraint cannot be used when all axes of the mesh are'
2712+
' of type `Explicit`. Please use the `reshard` API.')
27072713

27082714

27092715
def with_sharding_constraint(x, shardings):

tests/pjit_test.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7069,8 +7069,11 @@ def test_wsc_error(self, mesh):
70697069
"The spec of NamedSharding passed to with_sharding_constraint"):
70707070
jax.lax.with_sharding_constraint(np.arange(8).reshape(4, 2), s)
70717071

7072-
s = NamedSharding(mesh, P())
7073-
jax.lax.with_sharding_constraint(np.arange(8), s)
7072+
with self.assertRaisesRegex(
7073+
ValueError,
7074+
'with_sharding_constraint cannot be used when all axes of the mesh are'
7075+
' of type `Explicit`'):
7076+
jax.lax.with_sharding_constraint(np.arange(8), NamedSharding(mesh, P()))
70747077

70757078
s = NamedSharding(Mesh(mesh.devices, mesh.axis_names,
70767079
axis_types=(AxisType.Explicit, AxisType.Auto)),

0 commit comments

Comments
 (0)