Skip to content

Commit 12966b5

Browse files
tomnatan30Google-ML-Automation
authored andcommitted
Fix sharding rule in jax test
`ValueError: Sharding rule has 1 operands, but the operation has 2 operands` PiperOrigin-RevId: 762412744
1 parent c1c0c0f commit 12966b5

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/cache_key_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def _cp_add(x, y):
181181
_cp_add.def_partition(
182182
infer_sharding_from_operands=_infer_sharding_from_operands,
183183
partition=_partition,
184-
sharding_rule='i i -> i')
184+
sharding_rule='..., ... -> ...')
185185

186186
devices = np.asarray(jax.devices())
187187
with Mesh(devices, ('x',)) as m:

0 commit comments

Comments
 (0)