Skip to content

Commit 61a9bd2

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Allow eval_shape to propagate shardings if the aval has shardings in full explicit mode
PiperOrigin-RevId: 761708753
1 parent efc70a0 commit 61a9bd2

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

jax/_src/pjit.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,17 @@ def jit_lower(jit_func, *args, **kwargs):
350350
@api_boundary
351351
def jit_eval_shape(jit_func, *args, **kwargs):
352352
p, _ = _infer_params(jit_func._fun, jit_func._jit_info, args, kwargs)
353-
out_s = [None if isinstance(s, UnspecifiedValue) else s for s in p.params['out_shardings']]
354-
# TODO(yashkatariya): Add `Layout` to SDS.
355-
out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s,
356-
weak_type=x.weak_type)
357-
for x, s in zip(p.params['jaxpr'].out_avals, out_s)]
353+
out_shardings = [None if isinstance(s, UnspecifiedValue) else s
354+
for s in p.params['out_shardings']]
355+
out = []
356+
for a, out_s in zip(p.params['jaxpr'].out_avals, out_shardings):
357+
if out_s is None:
358+
s = a.sharding if a.sharding.mesh._are_all_axes_explicit else out_s
359+
else:
360+
s = out_s
361+
# TODO(yashkatariya): Add `Layout` to SDS.
362+
out.append(api.ShapeDtypeStruct(a.shape, a.dtype, sharding=s,
363+
weak_type=a.weak_type))
358364
return tree_unflatten(p.out_tree, out)
359365

360366
def jit_evict_fn(self):

tests/pjit_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7857,6 +7857,20 @@ def g(x, y):
78577857
core.ShardingTypeError, "lhs is unreduced while rhs is not"):
78587858
g.trace(x, y)
78597859

7860+
@jtu.with_explicit_mesh((2, 2), ('x', 'y'))
7861+
def test_eval_shape(self, mesh):
7862+
np_inp = np.arange(16).reshape(8, 2)
7863+
arr = jax.device_put(np_inp, P('x', 'y'))
7864+
7865+
@jax.jit
7866+
def f(x):
7867+
return x * 2
7868+
7869+
out = jax.eval_shape(f, arr)
7870+
self.assertIsInstance(out, jax.ShapeDtypeStruct)
7871+
self.assertEqual(out.sharding,
7872+
NamedSharding(mesh.abstract_mesh, P('x', 'y')))
7873+
78607874

78617875
@jtu.pytest_mark_if_available('multiaccelerator')
78627876
class PJitErrorTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)