
Commit 700dc82

change the way of loading checkpoints
1 parent 28f3b18 commit 700dc82

File tree: 2 files changed, 10 additions & 8 deletions

mindone/diffusers/models/model_loading_utils.py

Lines changed: 4 additions & 4 deletions
@@ -97,7 +97,7 @@ def load_state_dict(
         if disable_mmap:
             return safe_load(open(checkpoint_file, "rb").read())
         else:
-            return safe_load_file(checkpoint_file)
+            return ms.load_checkpoint(checkpoint_file, format="safetensors")
     else:
         raise NotImplementedError(
             f"Only supports deserialization of weights file in safetensors format, but got {checkpoint_file}"
@@ -140,11 +140,11 @@ def _load_state_dict_into_model(
                     and any(module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules)
                     and dtype == ms.float16
                 ):
-                    v.set_dtype(ms.float32)
+                    state_dict[k] = ms.Parameter(v.to(ms.float32), name=k)
                 else:
-                    v.set_dtype(local_state[k].dtype)
+                    state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
             else:
-                v.set_dtype(local_state[k].dtype)
+                state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
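The dtype handling above switches from mutating the loaded parameter in place (`set_dtype`) to rebuilding the state-dict entry around a cast tensor. A standalone sketch of that pattern with an illustrative key name: `Tensor.to` returns a new tensor in the target dtype, and re-wrapping it as `ms.Parameter` under the same name keeps it matchable when the dict is later loaded into the network.

```python
import mindspore as ms

# Illustrative (k, v) pair standing in for one loaded state-dict entry.
k = "blocks.0.weight"
v = ms.Parameter(ms.Tensor([1.0, 2.0], ms.float16), name=k)

# Old pattern mutated in place: v.set_dtype(ms.float32)
# New pattern builds a fresh Parameter around the cast tensor:
v_cast = ms.Parameter(v.to(ms.float32), name=k)
print(v_cast.dtype)  # Float32
```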

mindone/transformers/modeling_utils.py

Lines changed: 6 additions & 4 deletions
@@ -60,6 +60,7 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
 
 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -349,7 +350,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -977,8 +978,8 @@ def _from_config(cls, config, **kwargs):
             use_flash_attention_2=use_flash_attention_2,
             mindspore_dtype=mindspore_dtype,
         )
-
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            model = cls(config, **kwargs)
 
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2348,7 +2349,8 @@ def from_pretrained(
             config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
         )
 
-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            model = cls(config, *model_args, **model_kwargs)
 
         # Make sure to tie the weights correctly
         model.tie_weights()
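Taken together, the two files move to a build-then-load flow: the model is constructed with parameter initialization deferred, and real values arrive from the safetensors checkpoint. A sketch of that flow under the same assumptions, with a toy layer standing in for `cls(config, **kwargs)` and a placeholder checkpoint path whose keys are presumed to match the network's parameter names:

```python
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters

with no_init_parameters():
    # Parameter init is skipped, so building a large model is cheap;
    # nn.Dense here is only a toy stand-in for cls(config, **kwargs).
    net = nn.Dense(1024, 1024)

# Weights come from the checkpoint instead of random initialization.
params = ms.load_checkpoint("weights.safetensors", format="safetensors")
ms.load_param_into_net(net, params)
```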
