Skip to content

Commit 7ca5084

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix an edge-case in reshape sharding rule where the last splitting/merging dim was 1.
PiperOrigin-RevId: 741740811
1 parent ebd90e0 commit 7ca5084

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

jax/_src/lax/lax.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6977,11 +6977,16 @@ def _split_on_one_axis(op_shape, new_sizes, name):
69776977
' the sharding of the output via the `sharding` argument of'
69786978
f' jax.lax.reshape. Got operand.shape={op_shape} and {new_sizes=}')
69796979
temp = [new_sizes[j]]
6980-
while math.prod(temp) != op_shape[i]:
6980+
next_j = j + 1
6981+
while (math.prod(temp) != op_shape[i] or
6982+
(next_j < len(new_sizes) and new_sizes[next_j] == 1)):
69816983
if math.prod(temp) > op_shape[i]:
69826984
return False, []
69836985
j += 1
6986+
if j >= len(new_sizes):
6987+
return False, []
69846988
temp.append(new_sizes[j])
6989+
next_j += 1
69856990
out.append(temp)
69866991
i += 1
69876992
j += 1

tests/pjit_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5508,6 +5508,18 @@ def h2(x, y):
55085508
('4', (1, 4, 1, 6, 1), (1, 4, 6),
55095509
P(None, 'x', None, None, None), P(None, 'x', None), False),
55105510
('5', (4, 6), (4, 6), P(None, 'x'), P(None, 'x'), False),
5511+
('6', (1024, 4096), (1024, 2048, 2, 1, 1, 1, 1),
5512+
P('x', None), P('x', None, None, None, None, None, None), False),
5513+
('7', (1024, 4096, 32), (1024, 2048, 2, 1, 1, 32),
5514+
P('x', None, None), P('x', None, None, None, None, None), False),
5515+
('8', (1024, 4096), (1024, 1, 1, 4096),
5516+
P('x', None), P('x', None, None, None), False),
5517+
('9', (1024, 4096), (1024, 1, 1, 4096),
5518+
P(None, 'x'), P(None, None, None, 'x'), False),
5519+
('10', (1024, 2048, 2, 1, 1, 1), (1024, 4096),
5520+
P('x', None, None, None, None, None), P('x', None), False),
5521+
('11', (1024, 2048, 2, 1, 1, 1), (1024, 4096),
5522+
P(None, 'x', None, None, None, None), P(None, 'x'), False),
55115523
)
55125524
@jtu.with_user_mesh((2,), ('x',))
55135525
def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
@@ -5519,6 +5531,8 @@ def test_reshape(self, src_shape, dst_shape, src_spec, dst_spec,
55195531
@partial(jax.jit, static_argnums=1)
55205532
def f(x, new_sharding):
55215533
y = lax.reshape(x, dst_shape, out_sharding=new_sharding)
5534+
self.assertEqual(y.aval.sharding.spec, dst_spec)
5535+
self.assertEqual(y.shape, dst_shape)
55225536
y = y * 2
55235537
self.assertEqual(y.aval.sharding.spec, dst_spec)
55245538
return y

0 commit comments

Comments
 (0)