 # limitations under the License.

 import json
-import logging
-import operator
 import os
+<<<<<<< HEAD
 import re
+=======
+>>>>>>> 24f6104 (remove state dict compress and disk decompress)
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union

 import compressed_tensors
 import torch
-import transformers
 from compressed_tensors.base import (
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
-    apply_quantization_config,
-    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
     get_offloaded_device,
-    get_safetensors_folder,
-    has_offloaded_params,
-    merge_names,
-    patch_attr,
     register_offload_parameter,
-    update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
 from compressed_tensors.utils.match import match_named_modules
 from loguru import logger
-from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig

 __all__ = ["ModelCompressor", "map_module_to_scheme"]

-_LOGGER: logging.Logger = logging.getLogger(__name__)
-

 if TYPE_CHECKING:
     # dummy type if not available from transformers
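The `logging` import and the module-level `_LOGGER` removed in this hunk were redundant with the `loguru` logger the file already imports. A minimal sketch of the replacement pattern (the message text is illustrative, not taken from this commit):

```python
from loguru import logger

# loguru needs no per-module getLogger() boilerplate; the shared `logger`
# object is imported once and called directly
logger.warning("sparsity config present but no sparsity compressor is set")
```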
@@ -616,156 +606,6 @@ def decompress_model(self, model: Module):

             module.quantization_status = QuantizationStatus.FROZEN

-    # ----- state dict compression pathways ----- #
-
-    def compress(
-        self,
-        model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
-        show_progress: bool = False,
-    ) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict or model with sparsity and/or quantization
-
-        :param model: uncompressed model to compress
-        :param state_dict: optional uncompressed state_dict to insert into model
-        :param show_progress: whether to display a progress bar during compression
-        :return: compressed state dict
-        """
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            # Note - compress only supports one compression format atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
-            state_dict = quant_compressor.compress(
-                state_dict,
-                names_to_scheme=module_to_scheme,
-                show_progress=show_progress,
-            )
-
-            # TODO: consider sparse compression to also be compression
-            if self.quantization_config.format != CompressionFormat.dense.value:
-                self.quantization_config.quantization_status = (
-                    QuantizationStatus.COMPRESSED
-                )
-
-        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = {
-                module_name
-                for module_name, _module in match_named_modules(
-                    model=model,
-                    targets=self.sparsity_config.targets,
-                    ignore=self.sparsity_config.ignore,
-                )
-            }
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
-                compression_targets=sparse_compression_targets,
-                show_progress=show_progress,
-            )
-
-        # HACK: Override the dtype_byte_size function in transformers to
-        # support float8 types. Fix is posted upstream
-        # https://github.com/huggingface/transformers/pull/30488
-        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-        return state_dict
-
-    # ----- disk decompression pathways ----- #
-
-    def decompress(self, model_path: str, model: Module):
-        """
-        Overwrites the weights in model with weights decompressed from model_path
-
-        :param model_path: path to compressed weights
-        :param model: PyTorch model to load decompressed weights into
-
-        Note: decompress makes use of both _replace_sparsity_weights and
-        _replace_weights. The variations in these methods result from subtle
-        differences between the sparsity and quantization compressors.
-        Specifically, quantization compressors return not just the decompressed
-        weight but also the quantization parameters (e.g. scales, zero_point),
-        whereas sparsity compressors only return the decompressed weight.
-        """
-        model_path = get_safetensors_folder(model_path)
-        sparse_decompressed = False
-        quant_compressor = (
-            next(iter(self.quantization_compressor.values()))
-            if self.quantization_compressor is not None
-            else None
-        )
-
-        if (
-            self.sparsity_compressor is not None
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            # note - decompress only supports one compressor atm
-            params_to_ignore = None
-            if quant_compressor is not None:
-                params_to_ignore = quant_compressor.compression_param_names
-            # Sparse decompression is applied on the model_path.
-            # The compressor will try to load any quantization parameters as
-            # well; params_to_skip_load prevents quantization params from
-            # being loaded here
-            dense_gen = self.sparsity_compressor.decompress(
-                model_path, params_to_skip_load=params_to_ignore
-            )
-            self._replace_sparsity_weights(dense_gen, model)
-            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
-            sparse_decompressed = True
-
-        if quant_compressor is not None:
-            # Temporarily set quantization status to FROZEN to prevent
-            # quantization during apply_quantization_config. This ensures
-            # that the dtypes of the weights are not unintentionally updated.
-            # The status is restored after quantization params are loaded.
-
-            with patch_attr(
-                self.quantization_config,
-                "quantization_status",
-                QuantizationStatus.FROZEN,
-            ):
-                apply_quantization_config(model, self.quantization_config)
-                names_to_scheme: Dict[str, QuantizationScheme] = {
-                    name: getattr(module, "quantization_scheme")
-                    for name, module in model.named_modules()
-                    if getattr(module, "quantization_scheme", None) is not None
-                }
-                # Load activation scales/zp or any other quantization parameters.
-                # Conditionally load the weight quantization parameters if we
-                # have a dense compressor or if a sparsity compressor has
-                # already been applied
-                load_weight_qparams = sparse_decompressed or isinstance(
-                    quant_compressor, DenseCompressor
-                )
-                load_pretrained_quantization_parameters(
-                    model,
-                    model_path,
-                    # TODO: all weight quantization params will be moved to the
-                    # compressor in a follow-up including initialization
-                    load_weight_qparams=load_weight_qparams,
-                )
-
-            model_path_or_state_dict = (
-                model.state_dict() if sparse_decompressed else model_path
-            )
-
-            dense_gen = quant_compressor.decompress(
-                model_path_or_state_dict, names_to_scheme=names_to_scheme
-            )
-            # TODO: all weight quantization params will be moved to the
-            # compressor to prevent duplicate parameter updates in
-            # update_parameter_data
-            self._replace_weights(
-                dense_gen, model, load_weight_qparams=not load_weight_qparams
-            )
-
-            def freeze_quantization_status(module):
-                module.quantization_status = QuantizationStatus.FROZEN
-
-            model.apply(freeze_quantization_status)
-            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
-
     def update_config(self, save_directory: str):
         """
         Update the model config located at save_directory with compression configs
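For context, the two pathways removed in the hunk above were typically driven back to back: `compress()` produced a compressed state dict for serialization, and `decompress()` later rebuilt dense weights from disk. The sketch below is hedged: `compressor` stands for an already-configured `ModelCompressor`, `model` for a `torch.nn.Module`, and saving through `safetensors` directly is illustrative rather than the library's own serialization path.

```python
from safetensors.torch import save_file


def round_trip_sketch(compressor, model, save_dir: str):
    # state-dict pathway: compress() returned a new, compressed state dict
    # without replacing the modules' parameters in place
    compressed_sd = compressor.compress(model, show_progress=True)
    save_file(compressed_sd, f"{save_dir}/model.safetensors")
    compressor.update_config(save_dir)  # persist compression configs

    # disk pathway: decompress() overwrote the weights of `model` in place,
    # running sparse decompression first, then quantization decompression
    compressor.decompress(save_dir, model)
```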
@@ -819,79 +659,6 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)

-    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
-        """
-        Replace the weights of the model with the provided dense weights.
-
-        This method iterates over the dense_weight_generator and updates the
-        corresponding weights in the model. If a parameter name does not
-        exist in the model, it will be skipped.
-
-        :param dense_weight_generator: generator that yields tuples of
-            (name, data), where 'name' is the parameter name and 'data' is
-            the updated param data
-        :param model: the model whose weights are to be updated
-        """
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            split_name = name.split(".")
-            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
-            module = operator.attrgetter(prefix)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-            delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
-            register_offload_parameter(module, param_name, param)
-
-    def _replace_weights(
-        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
-    ):
-        """
-        Replace the weights of the model with the provided dense weights.
-
-        This method iterates over the dense_weight_generator and updates the
-        corresponding weights in the model. If a parameter name does not
-        exist in the model, it will be skipped.
-
-        :param dense_weight_generator: generator that yields tuples of
-            (name, data), where 'name' is a module path and 'data' is a dict
-            mapping parameter names to updated param data
-        :param model: the model whose weights are to be updated
-        :param load_weight_qparams: whether to also update weight quantization
-            parameters (e.g. scales, zero points) alongside the weight
-        """
-        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(mod_path)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-
-            for param_name, param_data in data.items():
-                if hasattr(module, param_name):
-                    # If compressed, the param will have an incorrect dtype
-                    # for transformers >4.49
-                    # TODO: we could also skip initialization of scales/zp
-                    # during decompression to be consistent with loading,
-                    # which happens later as well; however,
-                    # update_parameter_data does a useful shape check that
-                    # should be moved to the compressor
-                    if param_name == "weight":
-                        delattr(module, param_name)
-                        requires_grad = param_data.dtype in (
-                            torch.float16,
-                            torch.float32,
-                            torch.bfloat16,
-                        )
-                        param = torch.nn.Parameter(
-                            param_data.to(device), requires_grad=requires_grad
-                        )
-                        register_offload_parameter(module, param_name, param)
-                    elif load_weight_qparams:
-                        # scales/zero-points should already be registered
-                        # to the correct device
-                        update_parameter_data(module, param_data, param_name)
-

 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
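The asymmetry between the two `_replace_*` helpers removed above follows from what each compressor's generator yields. Hypothetical items for a single linear layer (names and shapes are illustrative only):

```python
import torch

# sparsity compressors yield flat (fully qualified param name, tensor) pairs
sparse_item = ("model.layers.0.mlp.down_proj.weight", torch.zeros(1024, 4096))

# quantization compressors yield (module path, {param name: tensor}) pairs,
# carrying quantization parameters such as scales and zero points alongside
# the decompressed weight
quant_item = (
    "model.layers.0.mlp.down_proj",
    {
        "weight": torch.zeros(1024, 4096),
        "weight_scale": torch.ones(1024, 1),
        "weight_zero_point": torch.zeros(1024, 1, dtype=torch.int8),
    },
)

# _replace_sparsity_weights therefore split the flat name back into
# (owning module path, param name) before re-registering the parameter
name, _ = sparse_item
prefix, param_name = name.rsplit(".", 1)
assert (prefix, param_name) == ("model.layers.0.mlp.down_proj", "weight")
```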
@@ -906,6 +673,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
             and module.quantization_scheme.weights is not None
         )
     }
+<<<<<<< HEAD


 # HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -918,3 +686,5 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+=======
+>>>>>>> 24f6104 (remove state dict compress and disk decompress)
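A worked example of the `new_dtype_byte_size` override kept above, assuming the regex in the elided part of its body extracts the first bit-width token from the dtype name: float8 dtypes then report one byte per element, which is exactly the case the stock `transformers` helper could not handle.

```python
import torch

# assumes new_dtype_byte_size is imported from this module
assert new_dtype_byte_size(torch.float8_e4m3fn) == 1  # 8 bits  -> 1 byte
assert new_dtype_byte_size(torch.bfloat16) == 2       # 16 bits -> 2 bytes
assert new_dtype_byte_size(torch.float32) == 4        # 32 bits -> 4 bytes
```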