Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align const quant with converter #1099

Merged
merged 12 commits into from
Jun 9, 2024
5 changes: 1 addition & 4 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,10 +240,7 @@ def insert_positional_weights_to_input_list(self, input_tensors: List) -> List:
if isinstance(pos, int)):
if pos > len(input_tensors):
Logger.critical("The positional weight index cannot exceed the number of input tensors to the node.") # pragma: no cover
# Insert only positional weights that are not subject to quantization. If the positional weight is
# subject to quantization, the quantization wrapper inserts the positional weight into the node.
if not self.is_weights_quantization_enabled(pos):
input_tensors.insert(pos, weight)
input_tensors.insert(pos, weight)

return input_tensors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,17 @@ def calculate_and_set_weights_params(self, tensor_data: np.ndarray, min_threshol

"""
assert self.enable_weights_quantization
assert not (self.weights_per_channel_threshold and self.weights_channels_axis is None), \
"Trying to calculate threshold per channel, channel axis in None."
if self.weights_quantization_params_fn is not None:
self.set_weights_quantization_param(self.weights_quantization_params_fn(tensor_data,
p=self.l_p_value,
n_bits=self.weights_n_bits,
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
channel_axis=self.weights_channels_axis[0], # output channel axis
min_threshold=min_threshold))
self.set_weights_quantization_param(
self.weights_quantization_params_fn(tensor_data,
p=self.l_p_value,
n_bits=self.weights_n_bits,
per_channel=self.weights_per_channel_threshold and self.weights_channels_axis is not None,
channel_axis=self.weights_channels_axis[0], # output channel axis
min_threshold=min_threshold)[0] # Take only first output, the q-params, as axis is already chosen.
)
else:
self.set_weights_quantization_param({})

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
# ==============================================================================

from typing import Dict
from typing import Dict, Tuple
import numpy as np
from sklearn.cluster import KMeans

Expand Down Expand Up @@ -42,7 +42,8 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
is_symmetric: bool = False,
node=None,
hessian_info_service: HessianInfoService = None,
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES) -> Dict:
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
) -> Tuple[Dict[str, np.ndarray], int]:
"""
The quantizer first finds the closest max value per channel of tensor_data.
Now, we divide tensor_data with the threshold vector per channel. In addition, we scale the result to the range
Expand Down Expand Up @@ -70,27 +71,34 @@ def lut_kmeans_tensor(tensor_data: np.ndarray,
if n_bits >= LUT_VALUES_BITWIDTH:
Logger.critical(f'Look-Up-Table (LUT) bit configuration exceeds maximum: {n_bits} bits provided, must be less than {LUT_VALUES_BITWIDTH} bits.') # pragma: no cover
# TODO: need to set this externally
n_data_points = len(np.unique(tensor_data.flatten()))
if len(np.unique(tensor_data.flatten())) < 2 ** n_bits:
n_clusters = len(np.unique(tensor_data.flatten()))
n_clusters = n_data_points
else:
n_clusters = 2 ** n_bits
kmeans = KMeans(n_clusters=n_clusters, n_init=10)

threshold_selection_tensor = symmetric_selection_tensor if is_symmetric else power_of_two_selection_tensor
thresholds_per_channel = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
channel_axis, n_iter, min_threshold,
qc.QuantizationErrorMethod.NOCLIPPING)[THRESHOLD]

_params, channel_axis = threshold_selection_tensor(tensor_data, p, n_bits, per_channel,
channel_axis, n_iter, min_threshold,
qc.QuantizationErrorMethod.NOCLIPPING)
thresholds_per_channel = _params[THRESHOLD]

tensor_for_kmeans = int_quantization_with_threshold(tensor_data, thresholds_per_channel, LUT_VALUES_BITWIDTH)
kmeans.fit(tensor_for_kmeans.reshape(-1, 1))

# Add 0 to the LUT
cc = np.round(kmeans.cluster_centers_)
if n_data_points < 2 ** n_bits and np.all(cc != 0):
# In case there are fewer data points than potential clusters, we can add the cluster 0.0
# to the original clusters array to improve quantization (i.e. no need to zero one of the clusters).
cc = np.concatenate([np.zeros([1, 1], dtype=cc.dtype), cc])
closest2zero_idx = (np.abs(cc - 0)).argmin()
cc[closest2zero_idx] = 0.0

return {LUT_VALUES: cc,
SCALE_PER_CHANNEL: thresholds_per_channel}
SCALE_PER_CHANNEL: thresholds_per_channel}, channel_axis


def lut_kmeans_histogram(bins: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
# ==============================================================================
import numpy as np
from typing import Union, Tuple, Dict

import model_compression_toolkit.core.common.quantization.quantization_config as qc
from model_compression_toolkit.constants import MIN_THRESHOLD, THRESHOLD, NUM_QPARAM_HESSIAN_SAMPLES
Expand All @@ -23,20 +24,22 @@
from model_compression_toolkit.core.common.quantization.quantization_params_generation.error_functions import \
get_threshold_selection_tensor_error_function, get_threshold_selection_histogram_error_function
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.core.common.similarity_analyzer import compute_mse
from model_compression_toolkit.core.common.quantization.quantizers.quantizers_helpers import quantize_tensor


def power_of_two_selection_tensor(tensor_data: np.ndarray,
p: int,
n_bits: int,
per_channel: bool = False,
channel_axis: int = 1,
channel_axis: Union[int, None] = 1,
n_iter: int = 10,
min_threshold: float = MIN_THRESHOLD,
quant_error_method: qc.QuantizationErrorMethod = qc.QuantizationErrorMethod.MSE,
node=None,
hessian_info_service: HessianInfoService = None,
num_hessian_samples: int = NUM_QPARAM_HESSIAN_SAMPLES,
) -> dict:
) -> Tuple[Dict[str, np.ndarray], int]:
"""
Compute the power of two threshold based on the provided QuantizationErrorMethod to quantize the tensor.
Different search is applied, depends on the value of the selected QuantizationErrorMethod.
Expand All @@ -46,7 +49,7 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
p: p-norm to use for the Lp-norm distance.
n_bits: Number of bits to quantize the tensor.
per_channel: Whether the quantization should be per-channel or not.
channel_axis: Output channel index.
channel_axis: Output channel index. if None, search for best axis.
n_iter: Number of iterations to search for the optimal threshold (not used for this method).
min_threshold: Minimal threshold to use if threshold is too small (not used for this method).
quant_error_method: an error function to optimize the parameters' selection accordingly.
Expand All @@ -56,11 +59,24 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,

Returns:
Power of two threshold to quantize the tensor in a power of 2 manner.
Selected quantization channel axis.
"""

if quant_error_method == qc.QuantizationErrorMethod.NOCLIPPING:
tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
threshold = max_power_of_two(tensor_max, min_threshold)
if channel_axis is None and per_channel:
total_error_list = []
th_list = []
for _axis in range(len(tensor_data.shape)):
tensor_max = get_tensor_max(tensor_data, per_channel, _axis, n_bits)
threshold = max_power_of_two(tensor_max, min_threshold)
q_tensor_data = quantize_tensor(tensor_data, threshold, n_bits, True)
total_error_list.append(compute_mse(tensor_data, q_tensor_data, norm=True))
th_list.append(threshold)
channel_axis = np.argmin(total_error_list)
threshold = th_list[channel_axis]
else:
tensor_max = get_tensor_max(tensor_data, per_channel, channel_axis, n_bits)
threshold = max_power_of_two(tensor_max, min_threshold)
else:
signed = True # weights are always signed
axis = -1 if per_channel else None
Expand All @@ -69,15 +85,15 @@ def power_of_two_selection_tensor(tensor_data: np.ndarray,
n_bits=n_bits, signed=signed, node=node,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
threshold = qparams_selection_tensor_search(error_function,
tensor_data,
n_bits,
per_channel=per_channel,
channel_axis=channel_axis,
n_iter=n_iter,
min_threshold=min_threshold,
signed=signed)
return {THRESHOLD: threshold}
threshold, channel_axis = qparams_selection_tensor_search(error_function,
tensor_data,
n_bits,
per_channel=per_channel,
channel_axis=channel_axis,
n_iter=n_iter,
min_threshold=min_threshold,
signed=signed)
return {THRESHOLD: threshold}, channel_axis


def power_of_two_selection_histogram(bins: np.ndarray,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,13 +84,14 @@ def calculate_quantization_params(graph: Graph,
mod_attr_cfg = copy.deepcopy(attr_cfg)
mod_attr_cfg.weights_error_method = QuantizationErrorMethod.MSE

weights_params = get_weights_qparams(n.get_weights_by_keys(attr),
candidate_qc.weights_quantization_cfg,
mod_attr_cfg,
output_channels_axis,
node=n,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
weights_params, output_channels_axis = get_weights_qparams(n.get_weights_by_keys(attr),
candidate_qc.weights_quantization_cfg,
mod_attr_cfg,
output_channels_axis,
node=n,
hessian_info_service=hessian_info_service,
num_hessian_samples=num_hessian_samples)
attr_cfg.weights_channels_axis = (output_channels_axis, attr_cfg.weights_channels_axis[1])
attr_cfg.set_weights_quantization_param(weights_params)

if n.is_activation_quantization_enabled():
Expand Down
Loading
Loading