diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py
index 8868e942ce3d..bda80e2fe2b7 100644
--- a/src/diffusers/pipelines/pipeline_loading_utils.py
+++ b/src/diffusers/pipelines/pipeline_loading_utils.py
@@ -801,12 +801,6 @@ def load_sub_model(
     # add kwargs to loading method
     diffusers_module = importlib.import_module(__name__.split(".")[0])
     loading_kwargs = {}
-    if issubclass(class_obj, torch.nn.Module):
-        loading_kwargs["torch_dtype"] = torch_dtype
-    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
-        loading_kwargs["provider"] = provider
-        loading_kwargs["sess_options"] = sess_options
-        loading_kwargs["provider_options"] = provider_options

     is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

@@ -821,6 +815,17 @@ def load_sub_model(
         and transformers_version >= version.parse("4.20.0")
     )

+    # For transformers models >= 4.56.0, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings
+    if issubclass(class_obj, torch.nn.Module):
+        if is_transformers_model and transformers_version >= version.parse("4.56.0"):
+            loading_kwargs["dtype"] = torch_dtype
+        else:
+            loading_kwargs["torch_dtype"] = torch_dtype
+    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
+        loading_kwargs["provider"] = provider
+        loading_kwargs["sess_options"] = sess_options
+        loading_kwargs["provider_options"] = provider_options
+
     # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
     # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
     # This makes sure that the weights won't be initialized which significantly speeds up loading.
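
For reference, a minimal standalone sketch of the version gate this patch introduces. The helper name `select_dtype_kwarg` and the assertions are illustrative only, not part of the diffusers API; only the `4.56.0` threshold and the `dtype`/`torch_dtype` kwarg names come from the patch itself.

```python
# Minimal sketch of the version-gated kwarg selection added in the patch.
# `select_dtype_kwarg` is a hypothetical helper used here for illustration.
from packaging import version

import torch


def select_dtype_kwarg(
    is_transformers_model: bool,
    transformers_version: version.Version,
    torch_dtype: torch.dtype,
) -> dict:
    """Pick the dtype keyword the sub-model's from_pretrained expects."""
    if is_transformers_model and transformers_version >= version.parse("4.56.0"):
        # transformers >= 4.56.0 deprecates `torch_dtype` in favor of `dtype`
        return {"dtype": torch_dtype}
    # older transformers releases (and non-transformers torch modules)
    # still take `torch_dtype`
    return {"torch_dtype": torch_dtype}


# A transformers model at or above 4.56.0 gets the new kwarg name,
# while anything below keeps the legacy one.
assert select_dtype_kwarg(True, version.parse("4.56.0"), torch.float16) == {
    "dtype": torch.float16
}
assert select_dtype_kwarg(True, version.parse("4.55.2"), torch.float16) == {
    "torch_dtype": torch.float16
}
```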