diff --git a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py index 76a090d6e..28936f116 100644 --- a/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py +++ b/model_compression_toolkit/core/common/mixed_precision/solution_refinement_procedure.py @@ -139,8 +139,9 @@ def _get_valid_candidates_indices(node_candidates: List[CandidateNodeQuantizatio activation_num_bits = current_candidate.activation_quantization_cfg.activation_n_bits # Filter candidates that have higher bit-width for both weights and activations (except for the current index). + # TODO: activation bits comparison: should be >= if ACTIVATION or TOTAL ru is used. else should be ==. return [i for i, c in enumerate(node_candidates) if - c.activation_quantization_cfg.activation_n_bits >= activation_num_bits + c.activation_quantization_cfg.activation_n_bits == activation_num_bits and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits >= weights_num_bits and not (c.activation_quantization_cfg.activation_n_bits == activation_num_bits and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_num_bits)]