File tree (1 file changed: +9 −2 lines)
lines changed Original file line number Diff line number Diff line change @@ -502,10 +502,17 @@ def jax_memory_cleanup(layer):
def jax_memory_cleanup(layer):
    """Free the backing JAX buffers of ``layer``'s weights.

    For jax, delete all previously allocated memory to avoid temporarily
    duplicating variable allocations. torch and tensorflow have stateful
    variable types and do not need this fix.

    Multi-device (sharded) arrays are left alone: they may be referenced
    in JAX's distributed computation graph, and deleting them can cause
    errors.

    Args:
        layer: A Keras layer whose weights' backing buffers should be
            released. Only has an effect when the active backend is jax.
    """
    if keras.config.backend() == "jax":
        for weight in layer.weights:
            # Defensive lookup: not every weight object is guaranteed to
            # expose a `_value` attribute.
            value = getattr(weight, "_value", None)
            if value is None:
                continue
            # NOTE(review): every modern `jax.Array` carries a `.sharding`
            # attribute (a SingleDeviceSharding for unsharded arrays), so a
            # bare `sharding is None` test would skip deletion for *every*
            # array and make this cleanup a no-op. Only skip arrays whose
            # sharding actually spans multiple devices, since those may be
            # referenced in JAX's distributed computation graph and deleting
            # them can cause errors.
            sharding = getattr(value, "sharding", None)
            device_set = getattr(sharding, "device_set", None)
            if device_set is not None and len(device_set) > 1:
                continue
            value.delete()
510517
511518def set_dtype_in_config (config , dtype = None ):
You can’t perform that action at this time.
0 commit comments