Skip to content

Commit 2302a2e

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix test_make_array_from_single_device_arrays in random_test.py for TPU 7x. This is because create_mesh return a different order of devices while arrays are being created with jax.devices() order.
PiperOrigin-RevId: 761975228
1 parent 5dba0cf commit 2302a2e

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/random_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -991,7 +991,7 @@ def callback(index):
991991
def test_make_array_from_single_device_arrays(self):
992992
devices = jax.devices()
993993
shape = (len(devices),)
994-
mesh = jtu.create_mesh((len(devices),), ('x',))
994+
mesh = jtu.create_mesh((len(devices),), ('x',), iota_order=True)
995995
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x'))
996996
keys = random.split(random.key(0), len(devices))
997997
arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)]

0 commit comments

Comments
 (0)