diff --git a/mindone/diffusers/models/model_loading_utils.py b/mindone/diffusers/models/model_loading_utils.py
index b480225d93..19677ce212 100644
--- a/mindone/diffusers/models/model_loading_utils.py
+++ b/mindone/diffusers/models/model_loading_utils.py
@@ -35,9 +35,9 @@
 
 import mindspore as ms
 from mindspore import nn, ops
+from mindspore.ops import Cast
 
 from ...safetensors.mindspore import load as safe_load
-from ...safetensors.mindspore import load_file as safe_load_file
 from ..utils import (
     CKPT_FILE_EXTENSION,
     DEFAULT_HF_PARALLEL_LOADING_WORKERS,
@@ -51,6 +51,7 @@
 )
 
 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")
 
 _CLASS_REMAPPING_DICT = {
     "Transformer2DModel": {
@@ -101,7 +102,7 @@ def load_state_dict(
         if disable_mmap:
             return safe_load(open(checkpoint_file, "rb").read())
         else:
-            return safe_load_file(checkpoint_file)
+            return ms.load_checkpoint(checkpoint_file, format="safetensors")
     # support loading checkpoint file in mindspore format
     elif file_extension == CKPT_FILE_EXTENSION:
         return ms.load_checkpoint(checkpoint_file)
@@ -145,11 +146,11 @@ def _load_state_dict_into_model(
                 if keep_in_fp32_modules is not None and any(
                     module_to_keep_in_fp32 in k.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
                 ):
-                    v.set_dtype(ms.float32)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, ms.float32), name=k)
                 else:
-                    v.set_dtype(local_state[k].dtype)
+                    state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
             else:
-                v.set_dtype(local_state[k].dtype)
+                state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
diff --git a/mindone/diffusers/models/modeling_patch.py b/mindone/diffusers/models/modeling_patch.py
new file mode 100644
index 0000000000..cd9b97c2d4
--- /dev/null
+++ b/mindone/diffusers/models/modeling_patch.py
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
diff --git a/mindone/diffusers/models/modeling_utils.py b/mindone/diffusers/models/modeling_utils.py
index fabe98c449..37ad458c45 100644
--- a/mindone/diffusers/models/modeling_utils.py
+++ b/mindone/diffusers/models/modeling_utils.py
@@ -61,6 +61,7 @@
     load_state_dict,
     split_torch_state_dict_into_shards,
 )
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 
 
 class ContextManagers:
@@ -853,7 +854,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             )
 
         with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
             model = cls.from_config(config, **unused_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         state_dict = None
         if not is_sharded:
@@ -909,7 +914,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
 
     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self
 
     def half(self):
diff --git a/mindone/transformers/modeling_patch.py b/mindone/transformers/modeling_patch.py
new file mode 100644
index 0000000000..cd9b97c2d4
--- /dev/null
+++ b/mindone/transformers/modeling_patch.py
@@ -0,0 +1,49 @@
+import inspect
+from functools import wraps
+
+import mindspore as ms
+from mindspore import mint, nn
+
+SKIP_CLASSES = {nn.Dropout}
+# Store original __init__ for manual restore
+_ORIG_INITS = {}
+
+
+def patch_nn_default_dtype(dtype=ms.float32, force=False):
+    """
+    Iterate over all Cells under nn and mint.nn,
+    automatically set or force the default dtype in __init__ if supported.
+
+    Args:
+        dtype (mindspore.dtype): target dtype to enforce
+        force (bool): if True, even when user passes dtype explicitly, override it
+    """
+    for module in [ms.nn, mint.nn]:
+        for name in dir(module):
+            attr = getattr(module, name)
+            if inspect.isclass(attr) and issubclass(attr, nn.Cell):
+                if attr in SKIP_CLASSES:
+                    continue  # skip specified classes
+                sig = inspect.signature(attr.__init__)
+                if "dtype" in sig.parameters:
+                    if attr not in _ORIG_INITS:
+                        _ORIG_INITS[attr] = attr.__init__
+
+                    _orig_init = attr.__init__
+
+                    @wraps(_orig_init)
+                    def _new_init(self, *args, _orig_init=_orig_init, **kwargs):
+                        if force or "dtype" not in kwargs:
+                            kwargs["dtype"] = dtype
+                        return _orig_init(self, *args, **kwargs)
+
+                    setattr(attr, "__init__", _new_init)
+
+
+def restore_nn_default_dtype():
+    """
+    Manually restore the original __init__ of all patched nn / mint.nn Cells.
+    """
+    for cls, orig_init in _ORIG_INITS.items():
+        cls.__init__ = orig_init
+    _ORIG_INITS.clear()
diff --git a/mindone/transformers/modeling_utils.py b/mindone/transformers/modeling_utils.py
index 7fc96f444f..3b39436a32 100644
--- a/mindone/transformers/modeling_utils.py
+++ b/mindone/transformers/modeling_utils.py
@@ -61,6 +61,8 @@
 import mindspore as ms
 from mindspore import Parameter, Tensor, mint, nn, ops
 from mindspore.nn import CrossEntropyLoss, Identity
+from mindspore.nn.utils import no_init_parameters
+from mindspore.ops import Cast
 
 from .activations import get_activation
 from .generation.utils import GenerationMixin
@@ -79,6 +81,7 @@
     prune_linear_layer,
 )
 from .modeling_attn_mask_utils import dtype_to_min
+from .modeling_patch import patch_nn_default_dtype, restore_nn_default_dtype
 from .utils.generic import _CAN_RECORD_REGISTRY, OutputRecorder
 from .utils.import_utils import is_sdpa_available
 
@@ -110,6 +113,7 @@
 ]
 
 logger = logging.get_logger(__name__)
+cpu_cast = Cast().set_device("CPU")
 
 
 _init_weights = True
@@ -373,7 +377,7 @@ def _load_state_dict_into_model(model_to_load, state_dict, start_prefix, is_shar
     local_state = {v.name: v for k, v in model_to_load.parameters_and_names()}
     for k, v in state_dict.items():
         if k in local_state:
-            v.set_dtype(local_state[k].dtype)
+            state_dict[k] = ms.Parameter(cpu_cast(v.data, local_state[k].dtype), name=k)
         else:
             pass  # unexpect key keeps origin dtype
     cm = silence_mindspore_logger() if is_sharded else nullcontext()
@@ -509,7 +513,8 @@ def _get_name(self):
 
     def to(self, dtype: Optional[ms.Type] = None):
         for p in self.get_parameters():
-            p.set_dtype(dtype)
+            if p.dtype != dtype:
+                p.set_dtype(dtype)
         return self
 
     def float(self):
@@ -1157,7 +1162,12 @@ def _from_config(cls, config, **kwargs):
         if "attn_implementation" in kwargs:
             config._attn_implementation = kwargs.pop("attn_implementation")
 
-        model = cls(config, **kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, **kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         # We cannot set default mindspore dtype. So we need to cast model weights after creating.
         if mindspore_dtype is not None:
@@ -2753,7 +2763,12 @@ def from_pretrained(
 
             config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
 
-        model = cls(config, *model_args, **model_kwargs)
+        with no_init_parameters():
+            if mindspore_dtype is not None:
+                patch_nn_default_dtype(dtype=mindspore_dtype, force=True)
+            model = cls(config, *model_args, **model_kwargs)
+            if mindspore_dtype is not None:
+                restore_nn_default_dtype()
 
         # Make sure to tie the weights correctly
         model.tie_weights()
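Reviewer note (not part of the patch): the sketch below shows how the new patch_nn_default_dtype / restore_nn_default_dtype helpers behave when called directly, mirroring what from_pretrained now does around model construction under no_init_parameters(). It is a minimal, illustrative example, assuming a MindSpore build in which mint.nn.Linear accepts a dtype keyword (which is the very property the patch relies on); dtype, layer sizes, and the try/finally wrapper are illustrative choices, not prescribed by the diff.

# Illustrative usage of the dtype-patching helpers introduced in this diff.
# Assumption: mint.nn.Linear exposes a `dtype` kwarg on the installed MindSpore version.
import mindspore as ms
from mindspore import mint

from mindone.diffusers.models.modeling_patch import (
    patch_nn_default_dtype,
    restore_nn_default_dtype,
)

# While the patch is active, every nn / mint.nn Cell whose __init__ accepts
# `dtype` is constructed directly in float16, so no post-hoc full-model cast
# (and no temporary fp32 copy of the weights) is needed.
patch_nn_default_dtype(dtype=ms.float16, force=True)
try:
    layer = mint.nn.Linear(4, 4)
    print(layer.weight.dtype)  # expected: Float16
finally:
    # Undo the monkey patch so unrelated code sees the original __init__ again.
    restore_nn_default_dtype()

print(mint.nn.Linear(4, 4).weight.dtype)  # back to the framework default (Float32)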