
Commit 6a9e3bb

Fix issue #841: fix MCT so quantized models work with torch.save
elad cohen committed Nov 20, 2023
1 parent b2944fe commit 6a9e3bb
Showing 10 changed files with 22 additions and 23 deletions.
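
Background for the change (inferred from the commit message; the diff itself does not spell this out): torch.save serializes non-tensor Python objects through pickle, and pickle cannot serialize lambda functions. MCT's DefaultDict previously held its default as a lambda factory, so a quantized model whose quantizer configuration carried such an object could not be saved. A minimal, illustrative sketch of the failure mode (not MCT code; the actual change follows below):

    import pickle

    # Old style: the default is produced by a lambda -- pickle refuses to serialize it.
    old_style_default = lambda: (None, None)

    # New style: the default is stored as a plain value -- pickle handles it fine.
    new_style_default = (None, None)

    try:
        pickle.dumps(old_style_default)
    except (pickle.PicklingError, AttributeError) as err:
        print("lambda default cannot be pickled:", err)

    print("plain default pickles to", len(pickle.dumps(new_style_default)), "bytes")
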
19 changes: 9 additions & 10 deletions model_compression_toolkit/core/common/defaultdict.py
@@ -14,27 +14,28 @@
 # ==============================================================================
 
 
-from typing import Callable, Dict, Any
+from typing import Dict, Any
+from copy import deepcopy
 
 
-class DefaultDict(object):
+class DefaultDict:
     """
     Default dictionary. It wraps a dictionary given at initialization and return its
     values when requested. If the requested key is not presented at initial dictionary,
-    it returns the returned value a default factory (that is passed at initialization) generates.
+    it returns the returned value a default value (that is passed at initialization) generates.
     """
 
     def __init__(self,
                  known_dict: Dict[Any, Any],
-                 default_factory: Callable = None):
+                 default_value: Any = None):
        """
        Args:
            known_dict: Dictionary to wrap.
-           default_factory: Callable to get default values when requested key is not in known_dict.
+           default_value: default value when requested key is not in known_dict.
        """
 
-        self.default_factory = default_factory
+        self.default_value = default_value
        self.known_dict = known_dict
 
    def get(self, key: Any) -> Any:
@@ -51,11 +52,9 @@ def get(self, key: Any) -> Any:
        """
 
        if key in self.known_dict:
-            return self.known_dict.get(key)
+            return self.known_dict[key]
        else:
-            if self.default_factory is not None:
-                return self.default_factory()
-            return None
+            return deepcopy(self.default_value)
 
    def keys(self):
        """
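
Design note on the replacement (a reading of the hunk above, not a statement from the commit): with a plain value instead of a factory, get() returns deepcopy(self.default_value), so mutable defaults such as [None] are still handed out as independent copies, mirroring what the old callable provided. A condensed, self-contained illustration of the updated class:

    from copy import deepcopy
    from typing import Any, Dict


    class DefaultDict:
        """Condensed illustration of the updated class; only __init__ and get are shown."""

        def __init__(self, known_dict: Dict[Any, Any], default_value: Any = None):
            self.default_value = default_value
            self.known_dict = known_dict

        def get(self, key: Any) -> Any:
            if key in self.known_dict:
                return self.known_dict[key]
            # A deep copy keeps mutable defaults isolated, as the old factory did.
            return deepcopy(self.default_value)


    kernel_attrs = DefaultDict({}, [None])
    first = kernel_attrs.get("layer_a")
    second = kernel_attrs.get("layer_b")
    first.append("kernel")
    assert second == [None]       # each miss returns a fresh copy of the default
    assert first is not second
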
Changed file (path not shown in this view)

@@ -24,7 +24,7 @@
 
 # If the quantization config does not contain kernel channel mapping or the weights
 # quantization is not per-channel, we use a dummy channel mapping.
-dummy_channel_mapping = DefaultDict({}, lambda: (None, None))
+dummy_channel_mapping = DefaultDict({}, (None, None))
 
 
 def get_weights_qparams(kernel: np.ndarray,
Changed file (path not shown in this view)

@@ -39,7 +39,7 @@
 KERNEL_ATTRIBUTES = DefaultDict({Conv2D: [KERNEL],
                                  DepthwiseConv2D: [DEPTHWISE_KERNEL],
                                  Dense: [KERNEL],
-                                 Conv2DTranspose: [KERNEL]}, lambda: [None])
+                                 Conv2DTranspose: [KERNEL]}, [None])
 
 
 """
@@ -50,7 +50,7 @@
 DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2D: (3, 2),
                                          DepthwiseConv2D: (2, 2),
                                          Dense: (1, 0),
-                                         Conv2DTranspose: (2, 3)}, lambda: (None, None))
+                                         Conv2DTranspose: (2, 3)}, (None, None))
 
 
 """
@@ -61,7 +61,7 @@
                                              DepthwiseConv2D: -1,
                                              Dense: -1,
                                              Conv2DTranspose: -1},
-                                             lambda: -1)
+                                             -1)
 
 
 """
Changed file (path not shown in this view)

@@ -33,7 +33,7 @@
 KERNEL_ATTRIBUTES = DefaultDict({Conv2d: [KERNEL],
                                  ConvTranspose2d: [KERNEL],
                                  Linear: [KERNEL]},
-                                lambda: [None])
+                                [None])
 
 """
 Map a layer to its kernel's output and input channels indices.
@@ -43,7 +43,7 @@
 DEFAULT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: (0, 1),
                                          Linear: (0, 1),
                                          ConvTranspose2d: (1, 0)},
-                                         lambda: (None, None))
+                                         (None, None))
 
 """
 Map a layer to its output channel axis.
@@ -52,7 +52,7 @@
 DEFAULT_OUT_CHANNEL_AXIS_DICT = DefaultDict({Conv2d: 1,
                                              Linear: -1,
                                              ConvTranspose2d: 1},
-                                             lambda: 1)
+                                             1)
 
 
 """
Changed file (path not shown in this view)

@@ -77,7 +77,7 @@ class STEWeightGPTQQuantizer(BaseKerasGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
+                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
        """
        Initialize a STEWeightGPTQQuantizer object with parameters to use for the quantization.
Changed file (path not shown in this view)

@@ -84,7 +84,7 @@ class STEWeightGPTQQuantizer(BasePytorchGPTQTrainableQuantizer):
 
     def __init__(self,
                  quantization_config: TrainableQuantizerWeightsConfig,
-                 max_lsbs_change_map: dict = DefaultDict({}, lambda: 1)):
+                 max_lsbs_change_map: dict = DefaultDict({}, 1)):
        """
        Construct a Pytorch model that utilize a fake weight quantizer of STE (Straight Through Estimator) for symmetric quantizer.
Changed file (path not shown in this view)

@@ -72,7 +72,7 @@ def __init__(self, unit_test, quant_method=QuantizationMethod.SYMMETRIC, roundin
         if rounding_type == RoundingType.SoftQuantizer:
             self.override_params = {QUANT_PARAM_LEARNING_STR: quantization_parameter_learning}
         elif rounding_type == RoundingType.STE:
-            self.override_params = {MAX_LSB_STR: DefaultDict({}, lambda: 1)}
+            self.override_params = {MAX_LSB_STR: DefaultDict({}, 1)}
         else:
             self.override_params = None

2 changes: 1 addition & 1 deletion tests/keras_tests/function_tests/test_get_gptq_config.py
@@ -106,7 +106,7 @@ def setUp(self):
                                     train_bias=True,
                                     loss=multiple_tensors_mse_loss,
                                     rounding_type=RoundingType.STE,
-                                    gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict({}, lambda: 1)}),
+                                    gptq_quantizer_params_override={MAX_LSB_STR: DefaultDict({}, 1)}),
                      get_keras_gptq_config(n_epochs=1,
                                            optimizer=tf.keras.optimizers.Adam())]

2 changes: 1 addition & 1 deletion tests/pytorch_tests/function_tests/get_gptq_config_test.py
@@ -81,7 +81,7 @@ def run_test(self):
                 {QUANT_PARAM_LEARNING_STR: self.quantization_parameters_learning}
         elif self.rounding_type == RoundingType.STE:
             gptqv2_config.gptq_quantizer_params_override = \
-                {MAX_LSB_STR: DefaultDict({}, lambda: 1)}
+                {MAX_LSB_STR: DefaultDict({}, 1)}
         else:
             gptqv2_config.gptq_quantizer_params_override = None

Changed file (path not shown in this view)

@@ -63,7 +63,7 @@ def __init__(self, unit_test, experimental_exporter=True, weights_bits=8, weight
         self.log_norm_weights = log_norm_weights
         self.scaled_log_norm = scaled_log_norm
         self.override_params = {QUANT_PARAM_LEARNING_STR: params_learning} if \
-            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict({}, lambda: 1)} \
+            rounding_type == RoundingType.SoftQuantizer else {MAX_LSB_STR: DefaultDict({}, 1)} \
             if rounding_type == RoundingType.STE else None
 
     def get_quantization_config(self):
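
A quick round-trip check of the intended effect (hypothetical usage, not part of the MCT test suite; it assumes the import path shown in the first changed file and passes weights_only=False, which newer PyTorch releases require when loading arbitrary Python objects):

    import io

    import torch

    from model_compression_toolkit.core.common.defaultdict import DefaultDict

    # New call-site style from this commit: a plain value instead of a lambda.
    max_lsbs_change_map = DefaultDict({}, 1)

    buffer = io.BytesIO()
    torch.save({"max_lsbs_change_map": max_lsbs_change_map}, buffer)  # pickles cleanly now

    buffer.seek(0)
    restored = torch.load(buffer, weights_only=False)
    assert restored["max_lsbs_change_map"].get("unseen_layer") == 1
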
