From c3748370013925cf5d4a913f9ff81b0a5b153250 Mon Sep 17 00:00:00 2001 From: elad-c Date: Wed, 15 Jan 2025 18:23:06 +0200 Subject: [PATCH] fix PT comments --- .../memory_graph/compute_graph_max_cut.py | 22 +++++++++---------- .../graph/memory_graph/max_cut_astar.py | 7 +++++- requirements.txt | 3 +-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py index 437919d1e..162c3b890 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/compute_graph_max_cut.py @@ -14,7 +14,6 @@ # ============================================================================== from collections import namedtuple from typing import Tuple, List -import timeout_decorator from model_compression_toolkit.logger import Logger from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING @@ -55,17 +54,16 @@ def solver_wrapper(*args, **kwargs): while it < n_iter: estimate = (u_bound + l_bound) / 2 - if it == 0: - schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter) - else: - try: - schedule, max_cut_size, cuts = solver_wrapper(estimate=estimate, iter_limit=astar_n_iter) - except timeout_decorator.TimeoutError: - if last_result[0] is None: - Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover - else: - Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.") - return last_result + # Add a timeout of 5 minutes to the solver from the 2nd iteration. + try: + schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter, + time_limit=None if it == 0 else 300) + except TimeoutError: + if last_result[0] is None: + Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover + else: + Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.") + return last_result if schedule is None: l_bound = estimate diff --git a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py index 6e651d4db..fc9e05d08 100644 --- a/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py +++ b/model_compression_toolkit/core/common/graph/memory_graph/max_cut_astar.py @@ -14,6 +14,7 @@ # ============================================================================== import copy from typing import List, Tuple, Dict, Set +from time import time from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.constants import DUMMY_TENSOR, DUMMY_NODE @@ -122,7 +123,7 @@ def __init__(self, memory_graph: MemoryGraph): self.target_cut = Cut([], set(), MemoryElements(elements={target_dummy_b, target_dummy_b2}, total_size=0)) - def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], float, List[Cut]]: + def solve(self, estimate: float, iter_limit: int = 500, time_limit: int = None) -> Tuple[List[BaseNode], float, List[Cut]]: """ The AStar solver function. This method runs an AStar-like search on the memory graph, using the given estimate as a heuristic gap for solutions to consider. @@ -131,6 +132,7 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], estimate: Cut size estimation to consider larger size of nodes in each expansion step, in order to fasten the algorithm divergence towards a solution. iter_limit: An upper limit for the number of expansion steps that the algorithm preforms. + time_limit: Optional time limit to the solver. Defaults to None which means no limit. Returns: A solution (if found within the steps limit) which contains: - A schedule for computation of the model (List of nodes). @@ -146,7 +148,10 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], expansion_count = 0 + t1 = time() while expansion_count < iter_limit and len(open_list) > 0: + if time_limit is not None and time() - t1 > time_limit: + raise TimeoutError # Choose next node to expand next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate) diff --git a/requirements.txt b/requirements.txt index fabd795df..4c68dd252 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,5 +11,4 @@ matplotlib<3.10.0 scipy protobuf mct-quantizers==1.5.2 -pydantic<2.0 -timeout-decorator \ No newline at end of file +pydantic<2.0 \ No newline at end of file