@@ -173,11 +173,12 @@ class CandidateSubnet:
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm.
 
-    Available additional config options:
-    - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.5).
+    Available additional config options (used when `params` constraint is provided):
+    - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.40).
         Only top (1 - max_width_pruning) choices will be considered.
-    - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.2).
+    - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.20).
         Only top (1 - max_depth_pruning) choices will be considered.
+    - `hparams_to_skip`: List of hparams to skip during the search (default: None).
     - `top_k`: Number of candidates to consider for score_func validation (default: 10).
     """
 
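For orientation, a minimal sketch of how these options could be supplied alongside a `params` constraint. Only the four option names come from this diff; the `mtp.prune` entrypoint, the concrete values, and the surrounding setup are assumptions for illustration:

```python
import modelopt.torch.prune as mtp  # assumed entrypoint, as in the existing pruning examples

# Hypothetical search config; the keys mirror the options documented in the docstring above.
search_config = {
    "max_width_pruning": 0.40,   # consider only the top 60% of each width hparam's choices
    "max_depth_pruning": 0.20,   # consider only the top 80% of num_layers choices
    "hparams_to_skip": ["num_moe_experts"],  # illustrative: keep this hparam at its max value
    "top_k": 10,                 # validate the 10 best candidates with score_func
}

# `model` is a Megatron-Core GPTModel/MambaModel built elsewhere; a forward_loop/score_func
# entry would normally be added to `search_config` for candidate validation.
model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"params": 6e9},  # illustrative target of ~6B parameters
    dummy_input=None,
    config=search_config,
)
```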
@@ -195,8 +196,9 @@ def default_search_config(self) -> SearchConfig:
             "skip_sorting": False,
             "scores_path": None,
             # Additional search config for parameter-based pruning
-            "max_width_pruning": 0.5,
-            "max_depth_pruning": 0.25,
+            "max_width_pruning": 0.40,
+            "max_depth_pruning": 0.20,
+            "hparams_to_skip": None,
             "top_k": 10,
         }
 
@@ -378,6 +380,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
         max_params = float(self.constraints["params"])  # type: ignore[arg-type]
         max_width_pruning = self.config["max_width_pruning"]
         max_depth_pruning = self.config["max_depth_pruning"]
+        hparams_to_skip = self.config["hparams_to_skip"]
         top_k = self.config["top_k"]
         print_rank_0(
             f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..."
@@ -401,6 +404,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
             hp_choices,  # type: ignore[arg-type]
             max_width_pruning,
             max_depth_pruning,
+            hparams_to_skip,
         )
         sample(self.model, sample_func=max)  # reset to max subnet (for sanity)
         selected = []
@@ -466,18 +470,20 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
     @staticmethod
     def _generate_search_space_combos(
         search_space: dict[str, list],
-        max_width_pruning: float = 0.5,
-        max_depth_pruning: float = 0.2,
+        max_width_pruning: float = 0.40,
+        max_depth_pruning: float = 0.20,
+        hparams_to_skip: list[str] | None = None,
     ) -> list[dict[str, Any]]:
         """Generate all possible combinations of hyperparameters from the search space.
 
         Args:
             search_space: Dictionary mapping hyperparameter names to their possible sorted choices.
                 Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]}
-            max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5).
+            max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.40).
                 Only top (1 - max_width_pruning) choices will be considered.
-            max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.2).
+            max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.20).
                 Only top (1 - max_depth_pruning) choices will be considered.
+            hparams_to_skip: List of hparams to skip during the search (default: None).
 
         Returns:
             List of configuration dictionaries, where each dictionary maps hyperparameter
@@ -494,11 +500,22 @@ def _generate_search_space_combos(
             f"{max_depth_pruning * 100:.0f}% for depth pruning hparams"
         )
 
+        if hparams_to_skip:
+            print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...")
+            for hparam in hparams_to_skip:
+                if hparam in search_space:
+                    search_space.pop(hparam)
+                else:
+                    warn(f"Hparam {hparam} not found in search space! Skipping...")
+
         filtered_ss = {
-            k: sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
-            if k == "num_layers"
-            else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
+            k: (
+                sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
+                if k == "num_layers"
+                else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
+            )
             for k, v in search_space.items()
+            if len(v) > 1
         }
 
         ss_size = 1
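To make the trimming concrete, here is a standalone sketch of the same comprehension applied to the docstring's example search space (the `mamba_head_dim` entry is added here purely to illustrate the new `len(v) > 1` guard):

```python
# Standalone sketch of the filtering above, using the docstring's example search space.
max_width_pruning, max_depth_pruning = 0.40, 0.20

search_space = {
    "hidden_size": [1024, 2048, 3072, 4096],  # width hparam: 4 choices
    "num_layers": list(range(1, 33)),         # depth hparam: 32 choices
    "mamba_head_dim": [64],                   # single choice: dropped by the `len(v) > 1` guard
}

filtered_ss = {
    k: (
        sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
        if k == "num_layers"
        else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
    )
    for k, v in search_space.items()
    if len(v) > 1
}

# hidden_size: int(0.60 * 4) = 2  -> keep the top 2 choices [3072, 4096]
# num_layers:  int(0.80 * 32) = 25 -> keep the top 7 choices [26, ..., 32]
print(filtered_ss)
```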
@@ -586,15 +603,15 @@ def get_param_count(mod, name) -> int:
 default_rules = {
     "megatron.core.models.gpt.GPTModel": {
         "hidden_size_divisor": 256,
-        "ffn_hidden_size_divisor": 256,
+        "ffn_hidden_size_divisor": 512,
         "num_moe_experts_divisor": 8,
         "num_layers_divisor": 2,
     },
     **(
         {
             "megatron.core.models.mamba.MambaModel": {
                 "hidden_size_divisor": 256,
-                "ffn_hidden_size_divisor": 256,
+                "ffn_hidden_size_divisor": 512,
                 "mamba_head_dim_divisor": 8,
                 "num_moe_experts_divisor": 8,
                 "num_layers_divisor": 2,
@@ -611,20 +628,23 @@ def get_param_count(mod, name) -> int:
 
 def get_mcore_minitron_config(
     *,
-    channel_divisor: int = 256,
+    hidden_size_divisor: int = 256,
+    ffn_hidden_size_divisor: int = 512,
     mamba_head_dim_divisor: int = 8,
     num_moe_experts_divisor: int = 8,
     num_layers_divisor: int = 2,
 ) -> ModeloptBaseConfig:
-    """Get a MCoreMinitronConfig with the given channel divisor instead of default."""
+    """Get a MCoreMinitronConfig with the given divisors instead of default."""
     config = MCoreMinitronConfig()
 
     def _set_divisors(c):
         for k, v in c.items():
             if isinstance(v, dict):
                 _set_divisors(v)
-            elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]:
-                c[k] = channel_divisor
+            elif k == "hidden_size_divisor":
+                c[k] = hidden_size_divisor
+            elif k == "ffn_hidden_size_divisor":
+                c[k] = ffn_hidden_size_divisor
             elif k == "mamba_head_dim_divisor":
                 c[k] = mamba_head_dim_divisor
             elif k == "num_moe_experts_divisor":
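A brief usage sketch for the updated helper. The call below uses only the keyword names visible in this diff; the override value and how the returned config is consumed downstream are assumptions, since the call site is not part of the diff:

```python
# Hypothetical usage: override only the FFN granularity, keep the other divisors at their defaults.
config = get_mcore_minitron_config(
    hidden_size_divisor=256,
    ffn_hidden_size_divisor=1024,  # illustrative: search ffn_hidden_size in multiples of 1024
    num_layers_divisor=2,
)
# `config` is a MCoreMinitronConfig whose divisor entries were rewritten recursively by
# _set_divisors; it could then be passed wherever a mode config is expected (assumption).
```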