Skip to content

Commit 29a2c5d

Browse files
charchit7a-r-r-o-w
andauthored
Resolves [BUG] 'GatheredParameters' object is not callable (#9614)
* gatherparams bug * calling context lib object * fix --------- Co-authored-by: Aryan <[email protected]>
1 parent 0d935df commit 29a2c5d

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/diffusers/training_utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@
2424
if is_transformers_available():
2525
import transformers
2626

27+
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
28+
import deepspeed
29+
2730
if is_peft_available():
2831
from peft import set_peft_model_state_dict
2932

@@ -442,15 +445,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
442445
self.cur_decay_value = decay
443446
one_minus_decay = 1 - decay
444447

445-
context_manager = contextlib.nullcontext
446-
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
447-
import deepspeed
448+
context_manager = contextlib.nullcontext()
448449

449450
if self.foreach:
450451
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
451452
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
452453

453-
with context_manager():
454+
with context_manager:
454455
params_grad = [param for param in parameters if param.requires_grad]
455456
s_params_grad = [
456457
s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -472,7 +473,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
472473
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
473474
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
474475

475-
with context_manager():
476+
with context_manager:
476477
if param.requires_grad:
477478
s_param.sub_(one_minus_decay * (s_param - param))
478479
else:

0 commit comments

Comments
 (0)