
Commit 6d0a140

remove state dict compress and disk decompress
Signed-off-by: Kyle Sayers <[email protected]>
1 parent df6fd15 commit 6d0a140

File tree

4 files changed: +19 additions, -439 deletions


src/compressed_tensors/compressors/model_compressors/model_compressor.py

Lines changed: 7 additions & 237 deletions
@@ -13,16 +13,16 @@
 # limitations under the License.
 
 import json
-import logging
-import operator
 import os
+<<<<<<< HEAD
 import re
+=======
+>>>>>>> 24f6104 (remove state dict compress and disk decompress)
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, TypeVar, Union
 
 import compressed_tensors
 import torch
-import transformers
 from compressed_tensors.base import (
     COMPRESSION_VERSION_NAME,
     QUANTIZATION_CONFIG_NAME,
@@ -41,29 +41,21 @@
     QuantizationConfig,
     QuantizationScheme,
     QuantizationStatus,
-    apply_quantization_config,
-    load_pretrained_quantization_parameters,
 )
 from compressed_tensors.transform import TransformConfig
 from compressed_tensors.utils import (
     align_module_device,
     delete_offload_parameter,
     get_execution_device,
     get_offloaded_device,
-    get_safetensors_folder,
-    has_offloaded_params,
-    merge_names,
-    patch_attr,
     register_offload_parameter,
-    update_parameter_data,
 )
 from compressed_tensors.utils.helpers import (
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
 from compressed_tensors.utils.match import match_named_modules
 from loguru import logger
-from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
 from transformers import AutoConfig
@@ -76,8 +68,6 @@
 
 __all__ = ["ModelCompressor", "map_module_to_scheme"]
 
-_LOGGER: logging.Logger = logging.getLogger(__name__)
-
 
 if TYPE_CHECKING:
     # dummy type if not available from transformers
@@ -616,156 +606,6 @@ def decompress_model(self, model: Module):
 
                 module.quantization_status = QuantizationStatus.FROZEN
 
-    # ----- state dict compression pathways ----- #
-
-    def compress(
-        self,
-        model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
-        show_progress: bool = False,
-    ) -> Dict[str, Tensor]:
-        """
-        Compresses a dense state dict or model with sparsity and/or quantization
-
-        :param model: uncompressed model to compress
-        :param state_dict: optional uncompressed state_dict to insert into model
-        :return: compressed state dict
-        """
-
-        if state_dict is None:
-            state_dict = model.state_dict()
-
-        if self.quantization_compressor is not None:
-            module_to_scheme = map_module_to_scheme(model)
-            # Note - compress only supports one compression format atm
-            quant_compressor = next(iter(self.quantization_compressor.values()))
-            state_dict = quant_compressor.compress(
-                state_dict,
-                names_to_scheme=module_to_scheme,
-                show_progress=show_progress,
-            )
-
-            # TODO: consider sparse compression to also be compression
-            if self.quantization_config.format != CompressionFormat.dense.value:
-                self.quantization_config.quantization_status = (
-                    QuantizationStatus.COMPRESSED
-                )
-
-        if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = {
-                module_name
-                for module_name, _module in match_named_modules(
-                    model=model,
-                    targets=self.sparsity_config.targets,
-                    ignore=self.sparsity_config.ignore,
-                )
-            }
-            state_dict = self.sparsity_compressor.compress(
-                state_dict,
-                compression_targets=sparse_compression_targets,
-                show_progress=show_progress,
-            )
-
-        # HACK: Override the dtype_byte_size function in transformers to
-        # support float8 types. Fix is posted upstream
-        # https://github.com/huggingface/transformers/pull/30488
-        transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size
-
-        return state_dict
-
-    # ----- disk decompression pathways ----- #
-
-    def decompress(self, model_path: str, model: Module):
-        """
-        Overwrites the weights in model with weights decompressed from model_path
-
-        :param model_path: path to compressed weights
-        :param model: pytorch model to load decompressed weights into
-
-        Note: decompress makes use of both _replace_sparsity_weights and
-        _replace_weights. The variations in these methods are a result of the subtle
-        variations between the sparsity and quantization compressors. Specifically,
-        quantization compressors return not just the decompressed weight, but the
-        quantization parameters (e.g scales, zero_point) whereas sparsity compressors
-        only return the decompressed weight.
-
-        """
-        model_path = get_safetensors_folder(model_path)
-        sparse_decompressed = False
-        quant_compressor = (
-            next(iter(self.quantization_compressor.values()))
-            if self.quantization_compressor is not None
-            else None
-        )
-
-        if (
-            self.sparsity_compressor is not None
-            and self.sparsity_config.format != CompressionFormat.dense.value
-        ):
-            # note - decompress only supports one compressor atm
-            params_to_ignore = None
-            if quant_compressor is not None:
-                params_to_ignore = quant_compressor.compression_param_names
-            # Sparse decompression is applied on the model_path
-            # The compressor will try and load any quantization parameters as well
-            # params_to_skip_load will skip over quantization params from being loaded
-            dense_gen = self.sparsity_compressor.decompress(
-                model_path, params_to_skip_load=params_to_ignore
-            )
-            self._replace_sparsity_weights(dense_gen, model)
-            setattr(model, SPARSITY_CONFIG_NAME, self.sparsity_compressor.config)
-            sparse_decompressed = True
-
-        if quant_compressor is not None:
-            # Temporarily set quantization status to FROZEN to prevent
-            # quantization during apply_quantization_config. This ensures
-            # that the dtypes of the weights are not unintentionally updated.
-            # The status is restored after quantization params are loaded.
-
-            with patch_attr(
-                self.quantization_config,
-                "quantization_status",
-                QuantizationStatus.FROZEN,
-            ):
-                apply_quantization_config(model, self.quantization_config)
-            names_to_scheme: Set[QuantizationScheme] = {
-                name: getattr(module, "quantization_scheme")
-                for name, module in model.named_modules()
-                if getattr(module, "quantization_scheme", None) is not None
-            }
-            # Load activation scales/zp or any other quantization parameters
-            # Conditionally load the weight quantization parameters if we have a
-            # dense compressor or if a sparsity compressor has already been applied
-            load_weight_qparams = sparse_decompressed or isinstance(
-                quant_compressor, DenseCompressor
-            )
-            load_pretrained_quantization_parameters(
-                model,
-                model_path,
-                # TODO: all weight quantization params will be moved to the
-                # compressor in a follow-up including initialization
-                load_weight_qparams=load_weight_qparams,
-            )
-
-            model_path_or_state_dict = (
-                model.state_dict() if sparse_decompressed else model_path
-            )
-
-            dense_gen = quant_compressor.decompress(
-                model_path_or_state_dict, names_to_scheme=names_to_scheme
-            )
-            # TODO: all weight quantization params will be moved to the compressor
-            # to prevent duplicate parameter updates in update_parameter_data
-            self._replace_weights(
-                dense_gen, model, load_weight_qparams=not load_weight_qparams
-            )
-
-            def freeze_quantization_status(module):
-                module.quantization_status = QuantizationStatus.FROZEN
-
-            model.apply(freeze_quantization_status)
-            setattr(model, QUANTIZATION_CONFIG_NAME, self.quantization_config)
-
     def update_config(self, save_directory: str):
         """
         Update the model config located at save_directory with compression configs
@@ -819,79 +659,6 @@ def update_config(self, save_directory: str):
         with open(config_file_path, "w") as config_file:
             json.dump(config_data, config_file, indent=2, sort_keys=True)
 
-    def _replace_sparsity_weights(self, dense_weight_generator, model: Module):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-        for name, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            split_name = name.split(".")
-            prefix, param_name = ".".join(split_name[:-1]), split_name[-1]
-            module = operator.attrgetter(prefix)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-            delattr(module, param_name)
-            requires_grad = data.dtype in (torch.float16, torch.float32, torch.bfloat16)
-            param = torch.nn.Parameter(data.to(device), requires_grad=requires_grad)
-            register_offload_parameter(module, param_name, param)
-
-    def _replace_weights(
-        self, dense_weight_generator, model: Module, load_weight_qparams: bool = True
-    ):
-        """
-        Replace the weights of the model with the
-        provided dense weights.
-
-        This method iterates over the dense_weight_generator and
-        updates the corresponding weights in the model. If a parameter
-        name does not exist in the model, it will be skipped.
-
-        :param dense_weight_generator (generator): A generator that yields
-            tuples of (name, data), where 'name' is the parameter name and
-            'data' is the updated param data
-        :param model: The model whose weights are to be updated.
-        """
-
-        for mod_path, data in tqdm(dense_weight_generator, desc="Decompressing model"):
-            module = operator.attrgetter(mod_path)(model)
-
-            params_device = next(module.parameters()).device
-            device = "cpu" if has_offloaded_params(module) else params_device
-
-            for param_name, param_data in data.items():
-                if hasattr(module, param_name):
-                    # If compressed, will have an incorrect dtype for transformers >4.49
-                    # TODO: we can also just skip initialization of scales/zp if in
-                    # decompression in init to be consistent with loading which happens
-                    # later as well however, update_data does a good shape check -
-                    # should be moved to the compressor
-
-                    if param_name == "weight":
-                        delattr(module, param_name)
-                        requires_grad = param_data.dtype in (
-                            torch.float16,
-                            torch.float32,
-                            torch.bfloat16,
-                        )
-                        param = torch.nn.Parameter(
-                            param_data.to(device), requires_grad=requires_grad
-                        )
-                        register_offload_parameter(module, param_name, param)
-                    elif load_weight_qparams:
-                        # Should already be registered to the correct device for
-                        # for scales/zero-points
-                        update_parameter_data(module, param_data, param_name)
-
 
 def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
     """
@@ -906,6 +673,7 @@ def map_module_to_scheme(model: Module) -> Dict[str, QuantizationScheme]:
             and module.quantization_scheme.weights is not None
         )
     }
+<<<<<<< HEAD
 
 
 # HACK: Override the dtype_byte_size function in transformers to support float8 types
@@ -918,3 +686,5 @@ def new_dtype_byte_size(dtype):
         raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
     bit_size = int(bit_search.groups()[0])
     return bit_size // 8
+=======
+>>>>>>> 24f6104 (remove state dict compress and disk decompress)
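For reference, the two methods deleted above were the older state-dict pathways: compress() turned a dense state dict into its compressed form, and decompress() rebuilt dense weights from a compressed checkpoint on disk. Below is a minimal, hedged sketch of how callers used them, written against the signatures visible in the removed code; the checkpoint paths and the ModelCompressor.from_pretrained construction are assumptions, not part of this commit.

    # Hedged sketch only -- these calls target the *removed* methods, so they will
    # not run against this commit.  Paths and the from_pretrained helper are assumed.
    from transformers import AutoModelForCausalLM
    from compressed_tensors.compressors import ModelCompressor  # assumed import path

    model = AutoModelForCausalLM.from_pretrained("path/to/model")              # assumed path
    compressor = ModelCompressor.from_pretrained("path/to/compressed/model")   # assumed helper

    # Removed state-dict pathway: returned a compressed state dict
    compressed_state_dict = compressor.compress(model, show_progress=True)

    # Removed disk pathway: overwrote model weights with decompressed values
    compressor.decompress("path/to/compressed/model", model)

    # In-place pathway that remains, visible in the surrounding context lines
    compressor.decompress_model(model)

After this commit, the in-place model pathway (decompress_model, shown in the context lines above) is the remaining decompression route in this file.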

src/compressed_tensors/quantization/lifecycle/apply.py

Lines changed: 0 additions & 58 deletions
@@ -37,76 +37,18 @@
 from compressed_tensors.utils.helpers import replace_module
 from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
-from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
 from torch.nn import Module
 
 
 __all__ = [
-    "load_pretrained_quantization_parameters",
     "apply_quantization_config",
 ]
 
-from compressed_tensors.quantization.utils.helpers import is_module_quantized
-from compressed_tensors.utils.safetensors_load import (
-    get_quantization_parameter_to_path_mapping,
-)
-
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def load_pretrained_quantization_parameters(
-    model: Module,
-    model_name_or_path: Optional[str] = None,
-    load_weight_qparams: Optional[bool] = False,
-):
-    """
-    Loads the quantization parameters (scale and zero point) from model_name_or_path to
-    a model that has already been initialized with a quantization config.
-
-    NOTE: Will always load inputs/output parameters. Will conditioanlly load weight
-    parameters, if load_weight_qparams is set to True.
-
-    :param model: model to load pretrained quantization parameters to
-    :param model_name_or_path: Hugging Face stub or local folder containing a quantized
-        model, which is used to load quantization parameters
-    :param load_weight_qparams: whether or not the weight quantization parameters
-        should be loaded
-    """
-    model_path = get_safetensors_folder(model_name_or_path)
-    mapping = get_quantization_parameter_to_path_mapping(model_path)
-
-    for name, submodule in model.named_modules():
-        if not is_module_quantized(submodule):
-            continue
-        if submodule.quantization_scheme.input_activations is not None:
-            base_name = "input"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-        if submodule.quantization_scheme.output_activations is not None:
-            base_name = "output"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-
-        if load_weight_qparams and submodule.quantization_scheme.weights:
-            base_name = "weight"
-            _load_quant_args_from_mapping(
-                base_name=base_name,
-                module_name=name,
-                module=submodule,
-                mapping=mapping,
-            )
-
-
 def apply_quantization_config(
     model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False
 ):
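The deleted load_pretrained_quantization_parameters helper read scale and zero-point tensors out of a safetensors checkpoint and attached them to a model that already had a quantization config applied. A hedged usage sketch follows, based only on the docstring and signature in the removed lines; the model construction, the config object, and the paths are placeholders.

    # Hedged sketch of the removed helper; it no longer exists after this commit.
    # The checkpoint path and the pre-built quantization config are placeholders.
    from transformers import AutoModelForCausalLM
    from compressed_tensors.quantization import (
        apply_quantization_config,
        load_pretrained_quantization_parameters,  # deleted by this commit
    )

    model = AutoModelForCausalLM.from_pretrained("path/to/quantized/model")  # placeholder path
    quantization_config = ...  # assumed: a QuantizationConfig built or loaded elsewhere
    apply_quantization_config(model, quantization_config)

    # Always loaded input/output activation scales and zero points; weight qparams
    # were loaded only when load_weight_qparams=True.
    load_pretrained_quantization_parameters(
        model,
        model_name_or_path="path/to/quantized/model",
        load_weight_qparams=True,
    )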

Comments (0)