Skip to content

Commit

Permalink
fix PT comments
Browse files Browse the repository at this point in the history
  • Loading branch information
elad-c committed Jan 15, 2025
1 parent 7a234c4 commit c374837
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# ==============================================================================
from collections import namedtuple
from typing import Tuple, List
import timeout_decorator

from model_compression_toolkit.logger import Logger
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, MAX_CUT, CUTS, FUSED_NODES_MAPPING
Expand Down Expand Up @@ -55,17 +54,16 @@ def solver_wrapper(*args, **kwargs):

while it < n_iter:
estimate = (u_bound + l_bound) / 2
if it == 0:
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter)
else:
try:
schedule, max_cut_size, cuts = solver_wrapper(estimate=estimate, iter_limit=astar_n_iter)
except timeout_decorator.TimeoutError:
if last_result[0] is None:
Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover
else:
Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.")
return last_result
# Add a timeout of 5 minutes to the solver from the 2nd iteration.
try:
schedule, max_cut_size, cuts = max_cut_astar.solve(estimate=estimate, iter_limit=astar_n_iter,
time_limit=None if it == 0 else 300)
except TimeoutError:
if last_result[0] is None:
Logger.critical(f"Max-cut solver stopped on timeout in iteration {it} before finding a solution.") # pragma: no cover
else:
Logger.warning(f"Max-cut solver stopped on timeout in iteration {it}.")
return last_result

if schedule is None:
l_bound = estimate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# ==============================================================================
import copy
from typing import List, Tuple, Dict, Set
from time import time

from model_compression_toolkit.core.common import BaseNode
from model_compression_toolkit.constants import DUMMY_TENSOR, DUMMY_NODE
Expand Down Expand Up @@ -122,7 +123,7 @@ def __init__(self, memory_graph: MemoryGraph):
self.target_cut = Cut([], set(), MemoryElements(elements={target_dummy_b, target_dummy_b2},
total_size=0))

def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode], float, List[Cut]]:
def solve(self, estimate: float, iter_limit: int = 500, time_limit: int = None) -> Tuple[List[BaseNode], float, List[Cut]]:
"""
The AStar solver function. This method runs an AStar-like search on the memory graph,
using the given estimate as a heuristic gap for solutions to consider.
Expand All @@ -131,6 +132,7 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode],
estimate: Cut size estimation to consider larger size of nodes in each
expansion step, in order to fasten the algorithm divergence towards a solution.
iter_limit: An upper limit for the number of expansion steps that the algorithm preforms.
time_limit: Optional time limit to the solver. Defaults to None which means no limit.
Returns: A solution (if found within the steps limit) which contains:
- A schedule for computation of the model (List of nodes).
Expand All @@ -146,7 +148,10 @@ def solve(self, estimate: float, iter_limit: int = 500) -> Tuple[List[BaseNode],

expansion_count = 0

t1 = time()
while expansion_count < iter_limit and len(open_list) > 0:
if time_limit is not None and time() - t1 > time_limit:
raise TimeoutError
# Choose next node to expand
next_cut = self._get_cut_to_expand(open_list, costs, routes, estimate)

Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ matplotlib<3.10.0
scipy
protobuf
mct-quantizers==1.5.2
pydantic<2.0
timeout-decorator
pydantic<2.0

0 comments on commit c374837

Please sign in to comment.