Commit f18f8e2
Allow skipping some hparams in NAS and further restrict search space
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent ca95a63 commit f18f8e2

File tree: 6 files changed, +69 -26 lines changed

  modelopt/torch/prune/plugins/mcore_minitron.py
  tests/_test_utils/torch/nas_prune/minitron_common.py
  tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
  tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
  tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
  tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 38 additions & 18 deletions
@@ -173,11 +173,12 @@ class CandidateSubnet:
 class MCoreMinitronSearcher(BaseSearcher):
     """Searcher for Minitron pruning algorithm.

-    Available additional config options:
-        - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.5).
+    Available additional config options (used when `params` constraint is provided):
+        - `max_width_pruning`: Maximum fraction per width hyperparameter to prune (default: 0.40).
             Only top (1 - max_width_pruning) choices will be considered.
-        - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.2).
+        - `max_depth_pruning`: Maximum fraction per depth hyperparameter to prune (default: 0.20).
             Only top (1 - max_depth_pruning) choices will be considered.
+        - `hparams_to_skip`: List of hparams to skip during the search (default: None).
         - `top_k`: Number of candidates to consider for score_func validation (default: 10).
     """
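
These options are read from the search config when a `params` constraint is given. A sketch of how they might be passed, following the usual modelopt.torch.prune call pattern; the model, constraint value, and score_func below are placeholders, and the exact return-value unpacking is an assumption rather than something this diff shows:

import modelopt.torch.prune as mtp

# Hypothetical usage: prune a Megatron-Core model under a 6B-params constraint
# while forbidding the searcher from shrinking hidden_size.
pruned_model, _ = mtp.prune(
    model,
    mode="mcore_minitron",
    constraints={"params": 6e9},
    dummy_input=None,  # not used by Minitron pruning
    config={
        "max_width_pruning": 0.40,           # consider only the top 60% of width choices
        "max_depth_pruning": 0.20,           # consider only the top 80% of num_layers choices
        "hparams_to_skip": ["hidden_size"],  # leave hidden_size at its original value
        "top_k": 10,                         # validate the 10 best candidates with score_func
        "score_func": score_func,
    },
)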

@@ -195,8 +196,9 @@ def default_search_config(self) -> SearchConfig:
             "skip_sorting": False,
             "scores_path": None,
             # Additional search config for parameter-based pruning
-            "max_width_pruning": 0.5,
-            "max_depth_pruning": 0.25,
+            "max_width_pruning": 0.40,
+            "max_depth_pruning": 0.20,
+            "hparams_to_skip": None,
             "top_k": 10,
         }

@@ -378,6 +380,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
         max_params = float(self.constraints["params"])  # type: ignore[arg-type]
         max_width_pruning = self.config["max_width_pruning"]
         max_depth_pruning = self.config["max_depth_pruning"]
+        hparams_to_skip = self.config["hparams_to_skip"]
         top_k = self.config["top_k"]
         print_rank_0(
             f"\nSearching for the best pruned architecture under {num2hrb(max_params)} params constraints..."
@@ -401,6 +404,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
             hp_choices,  # type: ignore[arg-type]
             max_width_pruning,
             max_depth_pruning,
+            hparams_to_skip,
         )
         sample(self.model, sample_func=max)  # reset to max subnet (for sanity)
         selected = []
@@ -466,18 +470,20 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
     @staticmethod
     def _generate_search_space_combos(
         search_space: dict[str, list],
-        max_width_pruning: float = 0.5,
-        max_depth_pruning: float = 0.2,
+        max_width_pruning: float = 0.40,
+        max_depth_pruning: float = 0.20,
+        hparams_to_skip: list[str] | None = None,
     ) -> list[dict[str, Any]]:
         """Generate all possible combinations of hyperparameters from the search space.

         Args:
             search_space: Dictionary mapping hyperparameter names to their possible sorted choices.
                 Example: {"hidden_size": [1024, 2048, 3072, 4096], "num_layers": [1, 2, ..., 31, 32]}
-            max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.5).
+            max_width_pruning: Maximum fraction of width hyperparameters to prune (default: 0.40).
                 Only top (1 - max_width_pruning) choices will be considered.
-            max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.2).
+            max_depth_pruning: Maximum fraction of depth hyperparameters to prune (default: 0.20).
                 Only top (1 - max_depth_pruning) choices will be considered.
+            hparams_to_skip: List of hparams to skip during the search (default: None).

         Returns:
             List of configuration dictionaries, where each dictionary maps hyperparameter
@@ -494,11 +500,22 @@ def _generate_search_space_combos(
             f"{max_depth_pruning * 100:.0f}% for depth pruning hparams"
         )

+        if hparams_to_skip:
+            print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...")
+            for hparam in hparams_to_skip:
+                if hparam in search_space:
+                    search_space.pop(hparam)
+                else:
+                    warn(f"Hparam {hparam} not found in search space! Skipping...")
+
         filtered_ss = {
-            k: sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
-            if k == "num_layers"
-            else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
+            k: (
+                sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
+                if k == "num_layers"
+                else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
+            )
             for k, v in search_space.items()
+            if len(v) > 1
         }

         ss_size = 1
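
The skip-and-truncate logic is easiest to see on a toy search space. Below is a standalone sketch that mirrors the code above; the final itertools.product enumeration illustrates what "all combinations" means here, not necessarily how the method builds them, and note that the new len(v) > 1 guard also drops hparams with only a single choice. Since the searcher resets to the max subnet before applying each combo, a skipped hparam effectively stays at its largest value:

from itertools import product

search_space = {
    "num_attention_heads": [8, 16, 24, 32],  # width hparam: 4 choices
    "num_moe_experts": [8],                  # single choice: dropped by the len(v) > 1 guard
    "num_layers": [2, 4, 6, 8, 10],          # depth hparam: 5 choices
    "kv_channels": [64, 128],                # suppose the user skips this one
}
max_width_pruning, max_depth_pruning = 0.40, 0.20
hparams_to_skip = ["kv_channels"]

for hparam in hparams_to_skip:
    search_space.pop(hparam, None)  # the real method warns on unknown names instead

filtered_ss = {
    k: (
        sorted(v)[int((1 - max_depth_pruning) * len(v)) :]
        if k == "num_layers"
        else sorted(v)[int((1 - max_width_pruning) * len(v)) :]
    )
    for k, v in search_space.items()
    if len(v) > 1
}
# num_attention_heads: int(0.6 * 4) = 2 -> keep top 2 choices [24, 32]
# num_layers:          int(0.8 * 5) = 4 -> keep top 1 choice  [10]
print(filtered_ss)  # {'num_attention_heads': [24, 32], 'num_layers': [10]}

combos = [dict(zip(filtered_ss, vals)) for vals in product(*filtered_ss.values())]
print(len(combos))  # 2 * 1 = 2 candidate architectures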
@@ -586,15 +603,15 @@ def get_param_count(mod, name) -> int:
     default_rules={
         "megatron.core.models.gpt.GPTModel": {
             "hidden_size_divisor": 256,
-            "ffn_hidden_size_divisor": 256,
+            "ffn_hidden_size_divisor": 512,
             "num_moe_experts_divisor": 8,
             "num_layers_divisor": 2,
         },
         **(
             {
                 "megatron.core.models.mamba.MambaModel": {
                     "hidden_size_divisor": 256,
-                    "ffn_hidden_size_divisor": 256,
+                    "ffn_hidden_size_divisor": 512,
                     "mamba_head_dim_divisor": 8,
                     "num_moe_experts_divisor": 8,
                     "num_layers_divisor": 2,
@@ -611,20 +628,23 @@ def get_param_count(mod, name) -> int:

 def get_mcore_minitron_config(
     *,
-    channel_divisor: int = 256,
+    hidden_size_divisor: int = 256,
+    ffn_hidden_size_divisor: int = 512,
     mamba_head_dim_divisor: int = 8,
     num_moe_experts_divisor: int = 8,
     num_layers_divisor: int = 2,
 ) -> ModeloptBaseConfig:
-    """Get a MCoreMinitronConfig with the given channel divisor instead of default."""
+    """Get a MCoreMinitronConfig with the given divisors instead of default."""
     config = MCoreMinitronConfig()

     def _set_divisors(c):
         for k, v in c.items():
             if isinstance(v, dict):
                 _set_divisors(v)
-            elif k in ["hidden_size_divisor", "ffn_hidden_size_divisor"]:
-                c[k] = channel_divisor
+            elif k == "hidden_size_divisor":
+                c[k] = hidden_size_divisor
+            elif k == "ffn_hidden_size_divisor":
+                c[k] = ffn_hidden_size_divisor
             elif k == "mamba_head_dim_divisor":
                 c[k] = mamba_head_dim_divisor
             elif k == "num_moe_experts_divisor":
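
Callers migrate by replacing the single channel_divisor argument with the two independent width divisors, as every test update below does. The divisors set the step size of each hyperparameter's choice grid, so the new ffn_hidden_size default of 512 halves the number of FFN width choices relative to 256. A minimal migration sketch plus a toy illustration of the granularity effect (the import path is taken from the file above; choice_grid is a hypothetical helper, assuming choices are the positive multiples of the divisor up to the original size):

from modelopt.torch.prune.plugins.mcore_minitron import get_mcore_minitron_config

# Before this commit: get_mcore_minitron_config(channel_divisor=256) set both
# width divisors together. Now they are independent knobs:
config = get_mcore_minitron_config(
    hidden_size_divisor=256,
    ffn_hidden_size_divisor=512,  # coarser FFN grid, matching the new default
)

def choice_grid(orig_size: int, divisor: int) -> list[int]:
    # Hypothetical helper: positive multiples of `divisor` up to `orig_size`.
    return list(range(divisor, orig_size + 1, divisor))

print(len(choice_grid(14336, 256)))  # 56 FFN width choices under the old default
print(len(choice_grid(14336, 512)))  # 28 choices under the new default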

tests/_test_utils/torch/nas_prune/minitron_common.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def prune_minitron(model, export_config, config, channel_divisor=64):
     (
         "mcore_minitron",
         mtp.mcore_minitron.get_mcore_minitron_config(
-            channel_divisor=channel_divisor,
+            hidden_size_divisor=channel_divisor,
+            ffn_hidden_size_divisor=channel_divisor,
             mamba_head_dim_divisor=4,
             num_moe_experts_divisor=1,
             num_layers_divisor=1,

tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py

Lines changed: 9 additions & 2 deletions
@@ -88,7 +88,11 @@ def _test_gpt_search_space(
     [
         (
             "mcore_minitron",
-            get_mcore_minitron_config(channel_divisor=channel_divisor, num_layers_divisor=1),
+            get_mcore_minitron_config(
+                hidden_size_divisor=channel_divisor,
+                ffn_hidden_size_divisor=channel_divisor,
+                num_layers_divisor=1,
+            ),
         )
     ],
 )
@@ -267,7 +271,10 @@ def _test_gpt_moe_search_space(rank, size):
         (
             "mcore_minitron",
             get_mcore_minitron_config(
-                channel_divisor=channel_divisor, num_moe_experts_divisor=1, num_layers_divisor=1
+                hidden_size_divisor=channel_divisor,
+                ffn_hidden_size_divisor=channel_divisor,
+                num_moe_experts_divisor=1,
+                num_layers_divisor=1,
             ),
         )
     ],

tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py

Lines changed: 2 additions & 1 deletion
@@ -81,7 +81,8 @@ def _test_mamba_search_space(rank, size):
     (
         "mcore_minitron",
         get_mcore_minitron_config(
-            channel_divisor=channel_divisor,
+            hidden_size_divisor=channel_divisor,
+            ffn_hidden_size_divisor=channel_divisor,
             mamba_head_dim_divisor=mamba_head_dim_divisor,
             num_layers_divisor=1,
         ),

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 12 additions & 3 deletions
@@ -78,7 +78,10 @@ def _test_mcore_gpt_parameter_sorting(activation_func, rank, size):

     model.eval()
     dynamic_space = _convert_model_to_dynamic_space(
-        model, get_mcore_minitron_config(channel_divisor=channel_divisor)
+        model,
+        get_mcore_minitron_config(
+            hidden_size_divisor=channel_divisor, ffn_hidden_size_divisor=channel_divisor
+        ),
     )
     registry = ImportanceEstimatorRegistry(model)  # register imp estimators and forward hooks

@@ -355,7 +358,12 @@ def _test_mcore_gpt_moe_parameter_sorting(rank, size):

     model.eval()
     dynamic_space = _convert_model_to_dynamic_space(
-        model, get_mcore_minitron_config(channel_divisor=channel_divisor, num_moe_experts_divisor=1)
+        model,
+        get_mcore_minitron_config(
+            hidden_size_divisor=channel_divisor,
+            ffn_hidden_size_divisor=channel_divisor,
+            num_moe_experts_divisor=1,
+        ),
     )
     registry = ImportanceEstimatorRegistry(model)  # register imp estimators and forward hooks

@@ -500,11 +508,12 @@ def test_mcore_gpt_pruning_moe(tmp_path):
 def test_generate_search_space_combos():
     ss = {
         "hidden_size": [32, 64, 96, 128, 160],
+        "ffn_hidden_size": [128, 256, 384, 512, 640],
         "num_attention_heads": [8, 16, 24, 32],
         "num_layers": [1, 2, 3, 4, 5, 6, 7, 8],
     }
     ss_combos = MCoreMinitronSearcher._generate_search_space_combos(
-        ss, max_width_pruning=0.5, max_depth_pruning=0.25
+        ss, max_width_pruning=0.5, max_depth_pruning=0.25, hparams_to_skip=["ffn_hidden_size"]
     )
     assert len(ss_combos) == 3 * 2 * 2
     assert ss_combos == [
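
The asserted count follows directly from the truncation rule: ffn_hidden_size is skipped entirely; hidden_size keeps sorted(v)[int(0.5 * 5):] = [96, 128, 160] (3 choices); num_attention_heads keeps sorted(v)[int(0.5 * 4):] = [24, 32] (2 choices); and num_layers keeps sorted(v)[int(0.75 * 8):] = [7, 8] (2 choices), giving 3 * 2 * 2 = 12 combos.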

tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,12 @@ def _test_mcore_mamba_parameter_sorting(rank, size):
7878

7979
model.eval()
8080
dynamic_space = _convert_model_to_dynamic_space(
81-
model, get_mcore_minitron_config(channel_divisor=channel_divisor, mamba_head_dim_divisor=4)
81+
model,
82+
get_mcore_minitron_config(
83+
hidden_size_divisor=channel_divisor,
84+
ffn_hidden_size_divisor=channel_divisor,
85+
mamba_head_dim_divisor=4,
86+
),
8287
)
8388
registry = ImportanceEstimatorRegistry(model) # register imp estimators and forward hooks
8489
