Skip to content

Commit

Permalink
Fix MP refinement procedure (#1338)
Browse files Browse the repository at this point in the history
Fix MP refinement procedure
  • Loading branch information
elad-c authored Jan 21, 2025
1 parent 8f0d5a9 commit bc788eb
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,8 @@ def solve(self, estimate: float, iter_limit: int = 500, time_limit: int = None)
t1 = time()
while expansion_count < iter_limit and len(open_list) > 0:
if time_limit is not None and time() - t1 > time_limit:
raise TimeoutError
# TODO: add test for this.
raise TimeoutError # pragma: no cover
# Choose next node to expand
next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def greedy_solution_refinement_procedure(mp_solution: List[int],
node_candidates = current_node.candidates_quantization_cfg

# only weights kernel attribute is quantized with weights mixed precision
kernel_attr = search_manager.fw_info.get_kernel_op_attributes(current_node)
kernel_attr = search_manager.fw_info.get_kernel_op_attributes(current_node.type)
kernel_attr = None if kernel_attr is None else kernel_attr[0]
valid_candidates = _get_valid_candidates_indices(node_candidates, new_solution[node_idx], kernel_attr)

Expand Down Expand Up @@ -139,8 +139,9 @@ def _get_valid_candidates_indices(node_candidates: List[CandidateNodeQuantizatio
activation_num_bits = current_candidate.activation_quantization_cfg.activation_n_bits

# Filter candidates that have higher bit-width for both weights and activations (except for the current index).
# TODO: activation bits comparison: should be >= if ACTIVATION or TOTAL ru is used. else should be ==.
return [i for i, c in enumerate(node_candidates) if
c.activation_quantization_cfg.activation_n_bits >= activation_num_bits
c.activation_quantization_cfg.activation_n_bits == activation_num_bits
and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits >= weights_num_bits
and not (c.activation_quantization_cfg.activation_n_bits == activation_num_bits
and c.weights_quantization_cfg.get_attr_config(kernel_attr).weights_n_bits == weights_num_bits)]

0 comments on commit bc788eb

Please sign in to comment.