File tree 1 file changed +6
-5
lines changed
1 file changed +6
-5
lines changed Original file line number Diff line number Diff line change 24
24
if is_transformers_available ():
25
25
import transformers
26
26
27
+ if transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
28
+ import deepspeed
29
+
27
30
if is_peft_available ():
28
31
from peft import set_peft_model_state_dict
29
32
@@ -442,15 +445,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
442
445
self .cur_decay_value = decay
443
446
one_minus_decay = 1 - decay
444
447
445
- context_manager = contextlib .nullcontext
446
- if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
447
- import deepspeed
448
+ context_manager = contextlib .nullcontext ()
448
449
449
450
if self .foreach :
450
451
if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
451
452
context_manager = deepspeed .zero .GatheredParameters (parameters , modifier_rank = None )
452
453
453
- with context_manager () :
454
+ with context_manager :
454
455
params_grad = [param for param in parameters if param .requires_grad ]
455
456
s_params_grad = [
456
457
s_param for s_param , param in zip (self .shadow_params , parameters ) if param .requires_grad
@@ -472,7 +473,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
472
473
if is_transformers_available () and transformers .integrations .deepspeed .is_deepspeed_zero3_enabled ():
473
474
context_manager = deepspeed .zero .GatheredParameters (param , modifier_rank = None )
474
475
475
- with context_manager () :
476
+ with context_manager :
476
477
if param .requires_grad :
477
478
s_param .sub_ (one_minus_decay * (s_param - param ))
478
479
else :
You can’t perform that action at this time.
0 commit comments