From 815a92e411649cb74023087d38bafffed1841003 Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Thu, 30 Apr 2020 14:49:33 -0700 Subject: [PATCH] Remove assert from ShardedDeviceArray staging. (#2908) This would erroneously fail on Cloud TPU because the TPU client has its own buffer type. --- jax/interpreters/pxla.py | 1 - 1 file changed, 1 deletion(-) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 9c4d1f01fd87..76179c3f02cc 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -459,7 +459,6 @@ def __init__(self, # providing sharding_spec. It assumes that any pre-existing callers are # creating pmap-style ShardedDeviceArrays. if device_buffers is None: - assert isinstance(sharding_spec[0], xb.xla_client._xla.PyLocalBuffer) device_buffers = sharding_spec sharding_spec = _pmap_sharding_spec(aval.shape[0], aval.shape[0], aval.shape[1:])