 # limitations under the License.
 
 
+import contextlib
+import io
 from functools import partial
 
+import pytest
 import torch
 from _test_utils.import_helper import skip_if_no_megatron
 
 )
 from _test_utils.torch.misc import compare_outputs, set_seed
 from _test_utils.torch.nas_prune.minitron_common import prune_minitron
+from megatron.core.ssm.mamba_hybrid_layer_allocation import Symbols
 from megatron.core.ssm.mamba_layer import MambaLayer
 from megatron.core.transformer.identity_op import IdentityOp
 
     ImportanceEstimatorRegistry,
     _convert_model_to_dynamic_space,
     get_mcore_minitron_config,
+    get_mcore_param_count,
 )
 
 SEED = 1234
@@ -167,7 +172,7 @@ def _get_model(initialize_megatron=True):
     mamba_num_heads = mamba_layer.mixer.nheads
 
     def forward_loop(m):
-        for _ in range(5):
+        for _ in range(2):
             run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
 
     # Traditional GPT pruning parameters
@@ -191,9 +196,10 @@ def forward_loop(m):
191196 "moe_shared_expert_intermediate_size" : pruned_ffn_hidden_size ,
192197 "num_moe_experts" : pruned_num_moe_experts ,
193198 }
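+    # prune_minitron now takes a constraints dict wrapping the export_config instead of the bare export_config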
+    constraints = {"export_config": export_config}
     prune_minitron(
         model,
-        export_config,
+        constraints,
         {"forward_loop": forward_loop, "scores_path": ckpt_path},
         channel_divisor,
     )
@@ -225,7 +231,7 @@ def forward_loop(m):
 
     # Assert re-pruning from scores_path works without running the forward loop again
     model = _get_model(initialize_megatron=False)
-    prune_minitron(model, export_config, {"scores_path": ckpt_path}, channel_divisor)
+    prune_minitron(model, constraints, {"scores_path": ckpt_path}, channel_divisor)
 
 
 def test_mcore_mamba_hybrid_pruning(tmp_path):
@@ -234,3 +240,158 @@ def test_mcore_mamba_hybrid_pruning(tmp_path):
         job=partial(_test_mcore_mamba_hybrid_pruning, tmp_path / "modelopt_minitron_scores.pth"),
         backend="nccl",
     )
+
+
+def _test_mcore_mamba_hybrid_pruning_nas(ckpt_path, rank, size):
+    channel_divisor = 4
+
+    # TODO: MoE in MambaModel requires Mcore 0.16+
+    num_layers = 4  # At least one each of the "M", "*", "-", "E" block types (Mamba, attention, MLP, MoE)
+    hybrid_pattern = "M*-M"  # "ME*-"
+    hidden_size = 16
+    ffn_hidden_size = 32
+    num_attention_heads = 16
+    num_query_groups = 4
+    mamba_state_dim = 4
+    mamba_num_heads = 16
+    mamba_head_dim = 16
+    mamba_num_groups = 2
+    num_moe_experts = None
+    moe_ffn_hidden_size = None
+    moe_shared_expert_intermediate_size = None
+    # num_moe_experts = 8
+    # moe_ffn_hidden_size = 16
+    # moe_shared_expert_intermediate_size = 16
+    vocab_size = 32
+    batch_size = 2
+
+    model = get_mcore_mamba_hybrid_model(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=size,
+        initialize_megatron=True,
+        num_layers=num_layers,
+        hybrid_override_pattern=hybrid_pattern,
+        hidden_size=hidden_size,
+        num_attention_heads=num_attention_heads,
+        num_query_groups=num_query_groups,
+        ffn_hidden_size=ffn_hidden_size,
+        mamba_state_dim=mamba_state_dim,
+        mamba_num_heads=mamba_num_heads,
+        mamba_head_dim=mamba_head_dim,
+        mamba_num_groups=mamba_num_groups,
+        moe_ffn_hidden_size=moe_ffn_hidden_size,
+        moe_shared_expert_intermediate_size=moe_shared_expert_intermediate_size,
+        num_moe_experts=num_moe_experts,
+        vocab_size=vocab_size,
+    ).cuda()
+
+    param_count = get_mcore_param_count(model)
+    assert param_count == 31776.0, param_count
+
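+    # Short forward loop used to collect importance scores (saved to scores_path for later reuse)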
+    def forward_loop(m):
+        for _ in range(2):
+            run_mcore_inference_with_dummy_input(m, batch_size, hidden_size)
+
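+    # Dummy score function for ranking candidates: it simply sums the pruned architecture
+    # dimensions, so larger sub-networks always score higher.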
+    def score_func(m):
+        c = m.config
+        return (
+            c.num_layers
+            + c.hidden_size
+            + c.ffn_hidden_size
+            + c.mamba_num_heads
+            + c.mamba_head_dim
+            + c.num_attention_heads
+            # + c.num_moe_experts
+            # + c.moe_ffn_hidden_size
+            # + c.moe_shared_expert_intermediate_size
+        )
+
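+    # Search for a sub-network with at most ~70% of the original parameter count; cap width and
+    # depth pruning at 50%, skip pruning num_attention_heads, and keep the 10 best candidates.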
+    constraints = {"params": int(param_count * 0.7)}
+    config = {
+        "forward_loop": forward_loop,
+        "scores_path": ckpt_path,
+        "score_func": score_func,
+        "max_width_pruning": 0.5,
+        "max_depth_pruning": 0.5,
+        "hparams_to_skip": ["num_attention_heads"],
+        "top_k": 10,
+    }
+
+    # Capture stdout to assert search space output
+    stdout_capture = io.StringIO()
+    with contextlib.redirect_stdout(stdout_capture):
+        model, searcher_state = prune_minitron(model, constraints, config, channel_divisor)
+
+    # Assert expected search space output is present
+    captured_output = stdout_capture.getvalue()
+    print(captured_output)
+    if rank == 0:
+        assert "Search space for num_layers: [3, 4]" in captured_output
+        assert "Search space for hidden_size: [12, 16]" in captured_output
+        assert "Search space for mamba_num_heads: [10, 12, 14, 16]" in captured_output
+        assert "Search space for mamba_head_dim: [12, 16]" in captured_output
+        assert "Search space for ffn_hidden_size: [20, 24, 28, 32]" in captured_output
+        assert "Total search space in consideration: 128" in captured_output
+
+    # NOTE: Slight variation in layer ordering for Attention and MLP depending on PP configuration
+    # This affects param counts when num_layers is pruned
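+    # Order layer indices by descending importance score from the searcher state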
+    sorted_layers = [
+        layer
+        for layer, _ in sorted(
+            searcher_state["layer_scores"].items(), key=lambda x: x[1], reverse=True
+        )
+    ]
+    # fmt: off
+    if sorted_layers == [1, 4, 2, 3]:
+        expected_top_k = [
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 10, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21180.0, 94.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21140.0, 81.0],  # noqa: E501
+        ]
+    elif sorted_layers == [1, 4, 3, 2]:
+        expected_top_k = [
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 22196.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 28}, 22068.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 24}, 21940.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21916.0, 94.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 28}, 21820.0, 90.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 20}, 21812.0, 82.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 24}, 21724.0, 86.0],  # noqa: E501
+            [{"num_layers": 4, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 20}, 21628.0, 82.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 16, "mamba_num_heads": 14, "mamba_head_dim": 12, "ffn_hidden_size": 32}, 21524.0, 93.0],  # noqa: E501
+            [{"num_layers": 3, "hidden_size": 12, "mamba_num_heads": 14, "mamba_head_dim": 16, "ffn_hidden_size": 32}, 21412.0, 93.0],  # noqa: E501
+        ]
+    else:
+        raise RuntimeError(f"FIXME: Non-deterministic test, assertions may fail: {sorted_layers}")
+    # fmt: on
+
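+    # The exported model should correspond to the best candidate's parameter count (22196 in both orderings)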
+    assert get_mcore_param_count(model) == 22196.0
+
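+    # Top-k candidates recorded for the params constraint should match the expected ranking exactly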
+    top_k = searcher_state["top_k_candidates_per_constraint"][constraints["params"]]
+    assert len(top_k) == 10
+    for actual, (ss_config, params, score) in zip(top_k, expected_top_k):
+        assert actual.ss_config == ss_config, (actual.ss_config, ss_config)
+        assert actual.params == params, (actual.params, params)
+        assert actual.score == score, (actual.score, score)
+
+
+def test_mcore_mamba_hybrid_pruning_nas(tmp_path):
+    set_seed(SEED)
+    if torch.cuda.device_count() > 4:
+        pytest.skip("Skipping test for more than 4 GPUs")
+    if "E" in Symbols.VALID:
+        pytest.skip("TODO: Update test for MoE in Mamba (Mcore 0.16+)")
+    spawn_multiprocess_job(
+        size=torch.cuda.device_count(),
+        job=partial(
+            _test_mcore_mamba_hybrid_pruning_nas, tmp_path / "modelopt_minitron_scores.pth"
+        ),
+        backend="nccl",
+    )