diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 9c4d1f01fd87..76179c3f02cc 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -459,7 +459,6 @@ def __init__(self, # providing sharding_spec. It assumes that any pre-existing callers are # creating pmap-style ShardedDeviceArrays. if device_buffers is None: - assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer) device_buffers = sharding_spec sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], aval.shape[1:])