Skip to content

Commit fc68336

Browse files
jkr26Google-ML-Automation
authored andcommitted
Raises an explicit error in reshard.
Hit this while working with sharding in types -- passing a sharding that had an empty mesh. (I think this was in a test). This failed trying to acces with `with_spec` attribute on None -- so just catching this case early. PiperOrigin-RevId: 762135310
1 parent e48080f commit fc68336

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

jax/_src/pjit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2987,6 +2987,11 @@ def reshard(xs, out_shardings):
29872987
out_flat = []
29882988
for x, x_aval, s in safe_zip(x_flat, x_avals_flat, shardings_flat):
29892989
ds = canonicalize_sharding(s, 'reshard')
2990+
if ds is None:
2991+
raise ValueError(
2992+
'Reshard should only be used with out_shardings which are non-None '
2993+
'and have a nonempty mesh. Got sharding {s}.'
2994+
)
29902995
ds = ds.with_spec(ds.spec._normalized_spec_for_aval(x_aval.ndim)) # pytype: disable=attribute-error
29912996
out_flat.append(reshard_p.bind(x, dst_sharding=ds))
29922997
return tree_unflatten(treedef, out_flat)

tests/pjit_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8781,6 +8781,16 @@ def f(x):
87818781
"to Shardy"):
87828782
jax.jit(f)(x)
87838783

8784+
def test_reshard_empty_mesh_error(self):
8785+
arr = jax.device_put(np.arange(8), jax.devices()[0])
8786+
with self.assertRaisesRegex(ValueError, "nonempty mesh"):
8787+
reshard(arr, NamedSharding(mesh_lib.empty_abstract_mesh, P(None)))
8788+
8789+
def test_reshard_none_sharding_error(self):
8790+
arr = jax.device_put(np.arange(8), jax.devices()[0])
8791+
with self.assertRaisesRegex(ValueError, "non-None"):
8792+
reshard(arr, None)
8793+
87848794

87858795
if __name__ == '__main__':
87868796
absltest.main(testLoader=jtu.JaxTestLoader())

0 commit comments

Comments
 (0)