
Commit cd6e848

For sharded weights let's not delete explicitly (#2431)
* For sharded weights let's not delete explicitly
* removed some unnecessary conditions
* format corrected file
1 parent a309618 · commit cd6e848

1 file changed: +9 -2 lines changed

keras_hub/src/utils/preset_utils.py

Lines changed: 9 additions & 2 deletions
@@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
     # For jax, delete all previous allocated memory to avoid temporarily
     # duplicating variable allocations. torch and tensorflow have stateful
     # variable types and do not need this fix.
+    # Skip deletion for sharded arrays to avoid breaking references in
+    # distributed setups.
     if keras.config.backend() == "jax":
         for weight in layer.weights:
-            if getattr(weight, "_value", None) is not None:
-                weight._value.delete()
+            if weight._value is not None:
+                # Do not delete sharded arrays, as they may be referenced in
+                # JAX's distributed computation graph and deletion can cause
+                # errors.
+                sharding = getattr(weight._value, "sharding", None)
+                if sharding is None:
+                    weight._value.delete()
 
 
 def set_dtype_in_config(config, dtype=None):
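
For context, a minimal standalone sketch of the JAX behavior this change guards against (assumptions: a working JAX install; the array and variable names below are illustrative, not from keras_hub). A jax.Array exposes a sharding attribute describing its device placement, and delete() frees the backing device buffer, so any surviving reference to a deleted array becomes unusable:

import jax.numpy as jnp

x = jnp.ones((4, 4))

# A jax.Array reports a sharding describing its device placement; on a
# single device this is a SingleDeviceSharding.
print(x.sharding)

# delete() releases the device buffer immediately. Every other reference
# to this array is now invalid; touching one raises a RuntimeError.
x.delete()
try:
    _ = x + 1
except RuntimeError as err:
    print(err)  # e.g. "Array has been deleted."

In a distributed setup, other parts of the computation may still hold references to a sharded buffer, which is why the commit skips the explicit delete() whenever the value reports a sharding and lets normal garbage collection reclaim the memory instead.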
