
Commit f2ee949

Add unit test for Hybrid NAS-based auto pruning
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent: f18f8e2

6 files changed, 179 insertions(+), 15 deletions(-)


modelopt/torch/nas/search_space.py

Lines changed: 1 addition & 3 deletions
@@ -135,9 +135,7 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = False
             hps_to_sort: A set of hparam names to sort. If not provided or empty, all hparams will be sorted.
             verbose: Whether to print the search space and hparam importances.
         """
-        print_rank_0("Sorting parameters...")
-        if verbose:
-            self.print_summary()
+        print_rank_0("\nSorting parameters...")

         # get config and set to max
         config = self.config()

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 3 additions & 2 deletions
@@ -214,7 +214,8 @@ def default_state_dict(self) -> SearchStateDict:
     def sanitize_search_config(self, config: SearchConfig | None) -> SearchConfig:
         """Sanitize the search config dict."""
         config = super().sanitize_search_config(config)
-        config["checkpoint"] = config["scores_path"]
+        if config["scores_path"]:
+            config["checkpoint"] = config["scores_path"]
         config["verbose"] = True  # Print for all ranks
         return config

@@ -457,7 +458,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
                 start_layer_number += 1
             self.unwrapped_model.decoder.layers = all_layers
             print_rank_0(
-                f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score"
+                f"\t{candidate.ss_config} -> {num2hrb(candidate.params)} params, {candidate.score:.4f} score\n"
             )

         dist.barrier()
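
In effect, scores_path now doubles as the searcher checkpoint only when one is actually provided; a falsy value (e.g. None) no longer forces a "checkpoint" entry into the config. A minimal sketch of just that guard, with the class context and the super() call stripped out (not the real method signature):

# Minimal sketch of the guarded assignment above.
def sanitize_search_config(config: dict) -> dict:
    if config["scores_path"]:  # None or "" -> leave "checkpoint" untouched
        config["checkpoint"] = config["scores_path"]
    config["verbose"] = True  # Print for all ranks
    return config

assert "checkpoint" not in sanitize_search_config({"scores_path": None})
assert sanitize_search_config({"scores_path": "scores.pth"})["checkpoint"] == "scores.pth"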

tests/_test_utils/torch/megatron/models.py

Lines changed: 3 additions & 1 deletion
@@ -314,6 +314,7 @@ def get_mcore_mamba_hybrid_model(
     sequence_parallel: bool = False,
     # Mamba-specific parameters
     mamba_state_dim: int = 32,
+    mamba_num_heads: int | None = None,
     mamba_head_dim: int = 16,
     mamba_num_groups: int = 2,
     # MoE-specific parameters
@@ -347,6 +348,7 @@ def get_mcore_mamba_hybrid_model(
         num_query_groups=num_query_groups,
         ffn_hidden_size=ffn_hidden_size,
         mamba_state_dim=mamba_state_dim,
+        mamba_num_heads=mamba_num_heads,
         mamba_head_dim=mamba_head_dim,
         mamba_num_groups=mamba_num_groups,
         num_moe_experts=num_moe_experts,
@@ -358,7 +360,7 @@ def get_mcore_mamba_hybrid_model(
         **config_kwargs,
     )

-    if not (skip_moe or "E" in Symbols.VALID):
+    if not (skip_moe or "E" in Symbols.VALID):  # Mcore 0.16+ has MoE support
         warn("MoE blocks are not supported in current MambaModel. Skipping MoE blocks.")
         skip_moe = True

tests/_test_utils/torch/nas_prune/minitron_common.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@
 import modelopt.torch.prune as mtp


-def prune_minitron(model, export_config, config, channel_divisor=64):
+def prune_minitron(model, constraints, config, channel_divisor=64):
     return mtp.prune(
         model,
         mode=[
@@ -31,7 +31,7 @@ def prune_minitron(model, export_config, config, channel_divisor=64):
                 ),
             )
         ],
-        constraints={"export_config": export_config},
+        constraints=constraints,
         dummy_input=None,  # Not used
         config=config,
     )
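
After this rename, the same test helper covers both pruning flows exercised in this commit: classic Minitron pruning to an explicit export_config, and the new Hybrid NAS-based auto pruning under a parameter-count budget. A hedged sketch of the two call patterns; model, forward_loop, score_func, ckpt_path, and original_param_count are placeholders standing in for objects built in the test files below, and the target values are illustrative only:

from _test_utils.torch.nas_prune.minitron_common import prune_minitron

# (1) Minitron pruning: prune to an explicit target architecture.
constraints = {"export_config": {"hidden_size": 12, "num_layers": 3}}
model, pruning_scores = prune_minitron(
    model, constraints, {"forward_loop": forward_loop, "scores_path": ckpt_path}, channel_divisor=4
)

# (2) Hybrid NAS-based auto pruning: search for the best architecture under a params budget.
constraints = {"params": int(0.7 * original_param_count)}
config = {"forward_loop": forward_loop, "scores_path": ckpt_path, "score_func": score_func}
model, searcher_state = prune_minitron(model, constraints, config, channel_divisor=4)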

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py

Lines changed: 6 additions & 4 deletions
@@ -202,6 +202,7 @@ def forward_loop(m):
         export_config["hidden_size"] = pruned_hidden_size
     if pruned_num_layers_div != 1:
         export_config["num_layers"] = pruned_num_layers
+    constraints = {"export_config": export_config}

     config = {
         "scores_path": ckpt_path,
@@ -211,7 +212,7 @@ def forward_loop(m):
         assert ckpt_path is None
     else:
         config["forward_loop"] = forward_loop
-    model, pruning_scores = prune_minitron(model, export_config, config, channel_divisor)
+    model, pruning_scores = prune_minitron(model, constraints, config, channel_divisor)
     if not skip_sorting:
         assert pruning_scores["layer_scores"]
         assert pruning_scores["activations_per_rank"]
@@ -248,7 +249,7 @@ def forward_loop(m):
     model_rerun = _get_model(initialize_megatron=False)
     model_rerun.load_state_dict(sd)
     model_rerun, pruning_scores = prune_minitron(
-        model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor
+        model_rerun, constraints, {"scores_path": ckpt_path}, channel_divisor
     )

     output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
@@ -450,10 +451,11 @@ def forward_loop(m):
         "moe_shared_expert_intermediate_size": pruned_moe_shared_ffn,
         "num_moe_experts": pruned_num_moe_experts,
     }
+    constraints = {"export_config": export_config}

     prune_minitron(
         model,
-        export_config,
+        constraints,
         {"scores_path": ckpt_path, "forward_loop": forward_loop},
         channel_divisor,
     )
@@ -491,7 +493,7 @@ def forward_loop(m):
     # Assert re-pruning from scores_path works without running the forward loop again
     model_rerun = _get_model(initialize_megatron=False)
     model_rerun.load_state_dict(sd)
-    prune_minitron(model_rerun, export_config, {"scores_path": ckpt_path}, channel_divisor)
+    prune_minitron(model_rerun, constraints, {"scores_path": ckpt_path}, channel_divisor)

     output_rerun = run_mcore_inference(model_rerun, prompt_tokens, pruned_hidden_size)
     assert torch.allclose(output, output_rerun, atol=1e-5)

tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py

Lines changed: 164 additions & 3 deletions
@@ -14,8 +14,11 @@
 # limitations under the License.


+import contextlib
+import io
 from functools import partial

+import pytest
 import torch
 from _test_utils.import_helper import skip_if_no_megatron

@@ -29,6 +32,7 @@
 )
 from _test_utils.torch.misc import compare_outputs, set_seed
 from _test_utils.torch.nas_prune.minitron_common import prune_minitron
+from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
 from megatron.core.ssm.mamba_layer import MambaLayer
 from megatron.core.transformer.identity_op import IdentityOp

@@ -37,6 +41,7 @@
     ImportanceEstimatorRegistry,
     _convert_model_to_dynamic_space,
     get_mcore_minitron_config,
+    get_mcore_param_count,
 )

 SEED = 1234
@@ -167,7 +172,7 @@ def _get_model(initialize_megatron=True):
     mamba_num_heads = mamba_layer.mixer.nheads

     def forward_loop(m):
-        for _ in range(5):
+        for _ in range(2):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)

     # Traditional GPT pruning parameters
@@ -191,9 +196,10 @@ def forward_loop(m):
         "moe_shared_expert_intermediate_size": pruned_ffn_hidden_size,
         "num_moe_experts": pruned_num_moe_experts,
     }
+    constraints = {"export_config": export_config}
     prune_minitron(
         model,
-        export_config,
+        constraints,
         {"forward_loop": forward_loop, "scores_path": ckpt_path},
         channel_divisor,
     )
@@ -225,7 +231,7 @@ def forward_loop(m):

     # Assert re-pruning from scores_path works without running the forward loop again
     model = _get_model(initialize_megatron=False)
-    prune_minitron(model, export_config, {"scores_path": ckpt_path}, channel_divisor)
+    prune_minitron(model, constraints, {"scores_path": ckpt_path}, channel_divisor)


 def test_mcore_mamba_hybrid_pruning(tmp_path):
@@ -234,3 +240,158 @@ def test_mcore_mamba_hybrid_pruning(tmp_path):
         job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"),
         backend="nccl",
     )
+
+
+def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size):
+    channel_divisor = 4
+
+    # TODO: MoE in MambaModel requires Mcore 0.16+
+    num_layers = 4  # At least one of "M, *, -, E" blocks
+    hybrid_pattern = "M*-M"  # "ME*-"
+    hidden_size = 16
+    ffn_hidden_size = 32
+    num_attention_heads = 16
+    num_query_groups = 4
+    mamba_state_dim = 4
+    mamba_num_heads = 16
+    mamba_head_dim = 16
+    mamba_num_groups = 2
+    num_moe_experts = None
+    moe_ffn_hidden_size = None
+    moe_shared_expert_intermediate_size = None
+    # num_moe_experts = 8
+    # moe_ffn_hidden_size = 16
+    # moe_shared_expert_intermediate_size = 16
+    vocab_size = 32
+    batch_size = 2
+
+    model = get_mcore_mamba_hybrid_model(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=size,
+        initialize_megatron=True,
+        num_layers=num_layers,
+        hybrid_override_pattern=hybrid_pattern,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        mamba_state_dim=mamba_state_dim,
+        mamba_num_heads=mamba_num_heads,
+        mamba_head_dim=mamba_head_dim,
+        mamba_num_groups=mamba_num_groups,
+        moe_ffn_hidden_size=moe_ffn_hidden_size,
+        moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
+        num_moe_experts=num_moe_experts,
+        vocab_size=vocab_size,
+    ).cuda()
+
+    param_count = get_mcore_param_count(model)
+    assert param_count == 31776.0, param_count
+
+    def forward_loop(m):
+        for _ in range(2):
+            run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
+
+    def score_func(m):
+        c = m.config
+        return (
+            c.num_layers
+            + c.hidden_size
+            + c.ffn_hidden_size
+            + c.mamba_num_heads
+            + c.mamba_head_dim
+            + c.num_attention_heads
+            # + c.num_moe_experts
+            # + c.moe_ffn_hidden_size
+            # + c.moe_shared_expert_intermediate_size
+        )
+
+    constraints = {"params": int(param_count * 0.7)}
+    config = {
+        "forward_loop": forward_loop,
+        "scores_path": ckpt_path,
+        "score_func": score_func,
+        "max_width_pruning": 0.5,
+        "max_depth_pruning": 0.5,
+        "hparams_to_skip": ["num_attention_heads"],
+        "top_k": 10,
+    }
+
+    # Capture stdout to assert search space output
+    stdout_capture = io.StringIO()
+    with contextlib.redirect_stdout(stdout_capture):
+        model, searcher_state = prune_minitron(model, constraints, config, channel_divisor)
+
+    # Assert expected search space output is present
+    captured_output = stdout_capture.getvalue()
+    print(captured_output)
+    if rank == 0:
+        assert "Search space for num_layers: [3, 4]" in captured_output
+        assert "Search space for hidden_size: [12, 16]" in captured_output
+        assert "Search space for mamba_num_heads: [10, 12, 14, 16]" in captured_output
+        assert "Search space for mamba_head_dim: [12, 16]" in captured_output
+        assert "Search space for ffn_hidden_size: [20, 24, 28, 32]" in captured_output
+        assert "Total search space in consideration: 128" in captured_output

+    # NOTE: Slight variation in layer ordering for Attention and MLP depending on PP configuration
+    # This affects param counts when num_layers is pruned
+    sorted_layers = [
+        layer
+        for layer, _ in sorted(
+            searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True
+        )
+    ]
+    # fmt: off
+    if sorted_layers == [1, 4, 2, 3]:
+        expected_top_k = [
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 10, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21180.0, 94.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21140.0, 81.0],  # noqa: E501
+        ]
+    elif sorted_layers == [1, 4, 3, 2]:
+        expected_top_k = [
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 21524.0, 93.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21412.0, 93.0],  # noqa: E501
+        ]
+    else:
+        raise RuntimeError(f"FIXME: Non-deterministic test, assertions may fail: {sorted_layers}")
+    # fmt: on
+
+    assert get_mcore_param_count(model) == 22196.0
+
+    top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]]
+    assert len(top_k) == 10
+    for actual, (ss_config, params, score) in zip(top_k, expected_top_k):
+        assert actual.ss_config == ss_config, (actual.ss_config, ss_config)
+        assert actual.params == params, (actual.params, params)
+        assert actual.score == score, (actual.score, score)
+
+
+def test_mcore_mamba_hybrid_pruning_nas(tmp_path):
+    set_seed(SEED)
+    if torch.cuda.device_count() > 4:
+        pytest.skip("Skipping test for more than 4 GPUs")
+    if "E" in Symbols.VALID:
+        pytest.skip("TODO: Update test for MoE in Mamba (Mcore 0.16+)")
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(
+            _test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth"
+        ),
+        backend="nccl",
+    )
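
For readers skimming the new test: it consumes the state returned by prune_minitron roughly as sketched below. The key names (layer_scores, top_k_candidates_per_constraint) and the candidate attributes (ss_config, params, score) are exactly those asserted above; searcher_state and constraints are the objects created in the test body.

# Layers ranked by importance score, highest first.
sorted_layers = [
    layer
    for layer, _ in sorted(searcher_state["layer_scores"].items(), key=lambda kv: kv[1], reverse=True)
]

# Top-k candidate architectures found under the params budget.
for cand in searcher_state["top_k_candidates_per_constraint"][constraints["params"]]:
    print(cand.ss_config, cand.params, cand.score)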
