2525"""
2626
2727from collections .abc import Callable
28+ from dataclasses import dataclass
2829from functools import partial
2930from itertools import product
3031from typing import Any
@@ -169,6 +170,13 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
169170 model .config .num_layers = new_num_layers
170171
171172
173+ @dataclass
174+ class CandidateSubnet :
175+ ss_config : dict
176+ params : float
177+ score : float | None
178+
179+
172180class MCoreMinitronSearcher (BaseSearcher ):
173181 """Searcher for Minitron pruning algorithm.
174182
@@ -182,7 +190,8 @@ class MCoreMinitronSearcher(BaseSearcher):
182190
183191 activations_per_rank : list [dict [str , torch .Tensor ]]
184192 layer_scores : dict [int , torch .Tensor ]
185- top_k_candidates_per_constraint : dict [float , list [tuple [dict , float ]]]
193+ # Dict from params constraint to list of tuples (ss_config, params, score)
194+ top_k_candidates_per_constraint : dict [float , list [CandidateSubnet ]]
186195
187196 @property
188197 def default_search_config (self ) -> SearchConfig :
@@ -400,7 +409,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
400409 max_depth_pruning ,
401410 )
402411 sample (self .model , sample_func = max ) # reset to max subnet (for sanity)
403- selected : list [ tuple [ dict , float ]] = []
412+ selected = []
404413 for ss_config in tqdm (
405414 search_space_configs ,
406415 desc = f"Finding top { top_k } candidates fitting the constraints..." ,
@@ -415,23 +424,40 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
415424 layer_ids = sorted_layers [: ss_config ["num_layers" ]]
416425 candidate_params = _param_num_dynamic (self .model , layer_numbers_to_count = layer_ids )
417426 if candidate_params <= max_params :
418- selected .append ((ss_config , candidate_params ))
427+ selected .append (CandidateSubnet (ss_config , candidate_params , None ))
419428 sample (self .model , sample_func = max ) # reset to max subnet
420429 assert len (selected ) > 0 , "No subnets found fitting the constraints!"
421430 print_rank_0 (f"Found { len (selected )} candidates fitting the constraints!" )
422431 self .top_k_candidates_per_constraint [max_params ] = sorted (
423- selected , key = lambda x : x [ 1 ] , reverse = True
432+ selected , key = lambda x : x . params , reverse = True
424433 )[:top_k ]
425434 self .save_search_checkpoint (verbose = True )
426435 else :
427436 print_rank_0 (f"Using top { top_k } candidates from checkpoint" )
428437 top_k_candidates = self .top_k_candidates_per_constraint [max_params ]
429438
439+ print_rank_0 (f"\n ====================\n Top { top_k } candidates:" )
440+ for candidate in top_k_candidates :
441+ print_rank_0 (f"\t { num2hrb (candidate .params )} params: { candidate .ss_config } " )
442+ print_rank_0 ("====================\n " )
443+
430444 # 3. Validate top-k candidates using the score_func and return the best subnet
431- # TODO: update this
432- best = top_k_candidates [0 ][0 ]
445+ for candidate in tqdm (
446+ top_k_candidates ,
447+ desc = f"Validating top { top_k } candidates on given score_func..." ,
448+ disable = not dist .is_master (),
449+ ):
450+ if candidate .score is None : # not restored from checkpoint
451+ self ._prune (candidate .ss_config , prune_depth = False , update_config = False )
452+ candidate .score = self .eval_score (silent = True )
453+ sample (self .model , sample_func = max ) # reset to max subnet
454+ self .save_search_checkpoint (verbose = False )
455+ print_rank_0 (
456+ f"\t { num2hrb (candidate .params )} params: { candidate .ss_config } -> { candidate .score :.4f} score"
457+ )
433458
434- return best
459+ best = max (top_k_candidates , key = lambda x : x .score ) # type: ignore[arg-type, return-value]
460+ return best .ss_config
435461
436462 @staticmethod
437463 def _generate_search_space_combos (
0 commit comments