diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 794c238c0fa7..e586d9e770d4 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -4659,21 +4659,16 @@ def _fix_state_dict_keys_on_save(self, state_dict):
     @classmethod
     def _load_pretrained_model(
-        cls,
-        model: "PreTrainedModel",
-        state_dict: Optional[dict],
-        checkpoint_files: Optional[list[str]],
-        pretrained_model_name_or_path: Optional[str],
-        ignore_mismatched_sizes: bool = False,
-        sharded_metadata: Optional[dict] = None,
-        device_map: Optional[dict] = None,
-        disk_offload_folder: Optional[str] = None,
-        dtype: Optional[torch.dtype] = None,
-        hf_quantizer: Optional[HfQuantizer] = None,
-        device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
-        key_mapping: Optional[dict[str, str]] = None,
-        weights_only: bool = True,
+        model,
+        state_dict,
+        loaded_keys,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+        sharded_metadata=None,
+        _fast_init=True,
     ):
+        # TODO: we should only be calling hf_quantizer.skip_placement or something like that
         is_quantized = hf_quantizer is not None
         is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
@@ -5141,14 +5136,17 @@ def set_is_initialized_for_modules(module):
         if is_deepspeed_zero3_enabled() and not is_quantized:
             import deepspeed
 
-            # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
-            not_initialized_parameters = list(
-                {v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
-            )
-            with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
-                self.initialize_weights()
-        else:
-            self.initialize_weights()
+            # keep_vars=True as we need the original tensors, so that the "_is_hf_initialized" is present on them
+            not_initialized_parameters = list(
+                {v for v in self.state_dict(keep_vars=True).values() if not getattr(v, "_is_hf_initialized", False)}
+            )
+            with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
+                self.initialize_weights()
+        else:
+            # Skip reinitialization for quantized (int8) models
+            if not is_quantized:
+                self.initialize_weights()
+
 
     def _adjust_missing_and_unexpected_keys(
         self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
diff --git a/src/transformers/utils/constants.py b/src/transformers/utils/constants.py
index fefd1b4601da..94e0352df567 100644
--- a/src/transformers/utils/constants.py
+++ b/src/transformers/utils/constants.py
@@ -4,3 +4,13 @@
 IMAGENET_STANDARD_STD = [0.5, 0.5, 0.5]
 OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
 OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]
+__all__ = [
+    "IMAGENET_DEFAULT_MEAN",
+    "IMAGENET_DEFAULT_STD",
+    "IMAGENET_STANDARD_MEAN",
+    "IMAGENET_STANDARD_STD",
+    "OPENAI_CLIP_MEAN",
+    "OPENAI_CLIP_STD",
+    "SAFE_WEIGHTS_INDEX_NAME",
+]
+
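
Note on the constants.py hunk: a minimal sketch follows (not part of the patch; it only assumes the module is importable as transformers.utils.constants) of how the added __all__ behaves. __all__ only affects "from transformers.utils.constants import *", and every listed name must actually be defined in the module; SAFE_WEIGHTS_INDEX_NAME appears in the list but is not defined in constants.py as shown, so a star-import would raise AttributeError unless that name is added to the file.

# Sketch only: report any name exported via __all__ that the module does not define.
from transformers.utils import constants

for name in getattr(constants, "__all__", []):
    if not hasattr(constants, name):
        # With the patch exactly as written, SAFE_WEIGHTS_INDEX_NAME is reported here,
        # and "from transformers.utils.constants import *" would raise AttributeError.
        print(f"{name} is listed in __all__ but not defined in constants.py")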