```diff
@@ -60,6 +60,7 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
 
 from .activations import get_activation
 from .generation.utils import GenerationMixin
```
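For context, `no_init_parameters` is MindSpore's lazy-initialization context manager: parameters created inside it record their initializers without executing them. A minimal sketch of the mechanics, assuming the delayed-initialization semantics of recent MindSpore releases, where `init_data()` materializes a deferred parameter:

```python
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters

# Parameters built under the context skip their initializers, so even a
# large layer is constructed without filling its weight memory.
with no_init_parameters():
    layer = nn.Dense(4096, 4096)

# Materialize the deferred parameters later, e.g. for any weights that no
# checkpoint entry will overwrite; init_data() runs the stored initializer.
for param in layer.get_parameters():
    param.init_data()
```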
```diff
@@ -349,7 +350,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(v.to(local_state[k].dtype), name=k)
         else:
             pass  # unexpected key keeps its original dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
```
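The dtype alignment no longer mutates the loaded value in place with `set_dtype`; it rebuilds the entry as a fresh `Parameter` cast to the model's dtype, presumably so the checkpoint value cleanly replaces a lazily created parameter rather than patching it. A standalone sketch of the same pattern, where `model_dtypes` is an illustrative stand-in for the model's `local_state`:

```python
import numpy as np
import mindspore as ms

# Stand-ins: the dtype the model expects, and a checkpoint entry in fp32.
model_dtypes = {"dense.weight": ms.float16}
state_dict = {"dense.weight": ms.Tensor(np.ones((2, 2), np.float32))}

for k, v in state_dict.items():
    if k in model_dtypes:
        # Rebuild as a named Parameter carrying the model's dtype instead
        # of casting the loaded tensor in place.
        state_dict[k] = ms.Parameter(v.to(model_dtypes[k]), name=k)

print(state_dict["dense.weight"].dtype)  # Float16
```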
```diff
@@ -977,8 +978,8 @@ def _from_config(cls, config, **kwargs):
             use_flash_attention_2=use_flash_attention_2,
             mindspore_dtype=mindspore_dtype,
         )
-
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            model = cls(config, **kwargs)
 
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
```
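Both construction sites now follow the same build-lazily-then-cast pattern. A hedged sketch with a toy cell standing in for `cls(config, **kwargs)`; that `set_dtype` may be called before a deferred parameter is materialized is an assumption here, not something the diff states:

```python
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters


class TinyNet(nn.Cell):
    """Illustrative stand-in for the real model class."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Dense(8, 8)


# Skip the throwaway random initialization during construction...
with no_init_parameters():
    model = TinyNet()

# ...and, mirroring the comment in the hunk above, cast weights only
# after the model object exists.
for p in model.get_parameters():
    p.set_dtype(ms.float16)
```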
```diff
@@ -2348,7 +2349,8 @@ def from_pretrained(
             config, use_flash_attention_2=use_flash_attention_2, mindspore_dtype=mindspore_dtype
         )
 
-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            model = cls(config, *model_args, **model_kwargs)
 
         # Make sure to tie the weights correctly
         model.tie_weights()
```
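End to end, the lazily created parameters are simply overwritten when the checkpoint is loaded, so the random initialization `from_pretrained` used to pay for never runs. A self-contained sketch using a bare `nn.Dense` and a hand-built parameter dict in place of a real checkpoint:

```python
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore.nn.utils import no_init_parameters

# Build the skeleton without initializing weights.
with no_init_parameters():
    net = nn.Dense(4, 4)

# Stand-in for a loaded checkpoint: nn.Dense names its parameters
# "weight" (out_channels x in_channels) and "bias".
param_dict = {
    "weight": ms.Parameter(ms.Tensor(np.zeros((4, 4), np.float32)), name="weight"),
    "bias": ms.Parameter(ms.Tensor(np.zeros((4,), np.float32)), name="bias"),
}

# load_param_into_net overwrites the deferred parameters in place.
ms.load_param_into_net(net, param_dict)
```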