Skip to content

Commit 11bc408

Browse files
Add score calculation logic
Signed-off-by: Keval Morabia <28916987+kevalmorabia97@users.noreply.github.com>
1 parent b858631 commit 11bc408

File tree

1 file changed

+33
-7
lines changed

1 file changed

+33
-7
lines changed

modelopt/torch/prune/plugins/mcore_minitron.py

Lines changed: 33 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626

2727
from collections.abc import Callable
28+
from dataclasses import dataclass
2829
from functools import partial
2930
from itertools import product
3031
from typing import Any
@@ -169,6 +170,13 @@ def drop_mcore_language_model_layers(model: nn.Module, *, layers_to_drop: list[i
169170
model.config.num_layers = new_num_layers
170171

171172

173+
@dataclass
174+
class CandidateSubnet:
175+
ss_config: dict
176+
params: float
177+
score: float | None
178+
179+
172180
class MCoreMinitronSearcher(BaseSearcher):
173181
"""Searcher for Minitron pruning algorithm.
174182
@@ -182,7 +190,8 @@ class MCoreMinitronSearcher(BaseSearcher):
182190

183191
activations_per_rank: list[dict[str, torch.Tensor]]
184192
layer_scores: dict[int, torch.Tensor]
185-
top_k_candidates_per_constraint: dict[float, list[tuple[dict, float]]]
193+
# Dict from params constraint to list of CandidateSubnet (ss_config, params, score)
194+
top_k_candidates_per_constraint: dict[float, list[CandidateSubnet]]
186195

187196
@property
188197
def default_search_config(self) -> SearchConfig:
@@ -400,7 +409,7 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
400409
max_depth_pruning,
401410
)
402411
sample(self.model, sample_func=max) # reset to max subnet (for sanity)
403-
selected: list[tuple[dict, float]] = []
412+
selected = []
404413
for ss_config in tqdm(
405414
search_space_configs,
406415
desc=f"Finding top {top_k} candidates fitting the constraints...",
@@ -415,23 +424,40 @@ def search_best_arch_by_params(self, sorted_layers: list[int]) -> dict:
415424
layer_ids = sorted_layers[: ss_config["num_layers"]]
416425
candidate_params = _param_num_dynamic(self.model, layer_numbers_to_count=layer_ids)
417426
if candidate_params <= max_params:
418-
selected.append((ss_config, candidate_params))
427+
selected.append(CandidateSubnet(ss_config, candidate_params, None))
419428
sample(self.model, sample_func=max) # reset to max subnet
420429
assert len(selected) > 0, "No subnets found fitting the constraints!"
421430
print_rank_0(f"Found {len(selected)} candidates fitting the constraints!")
422431
self.top_k_candidates_per_constraint[max_params] = sorted(
423-
selected, key=lambda x: x[1], reverse=True
432+
selected, key=lambda x: x.params, reverse=True
424433
)[:top_k]
425434
self.save_search_checkpoint(verbose=True)
426435
else:
427436
print_rank_0(f"Using top {top_k} candidates from checkpoint")
428437
top_k_candidates = self.top_k_candidates_per_constraint[max_params]
429438

439+
print_rank_0(f"\n====================\nTop {top_k} candidates:")
440+
for candidate in top_k_candidates:
441+
print_rank_0(f"\t{num2hrb(candidate.params)} params: {candidate.ss_config}")
442+
print_rank_0("====================\n")
443+
430444
# 3. Validate top-k candidates using the score_func and return the best subnet
431-
# TODO: update this
432-
best = top_k_candidates[0][0]
445+
for candidate in tqdm(
446+
top_k_candidates,
447+
desc=f"Validating top {top_k} candidates on given score_func...",
448+
disable=not dist.is_master(),
449+
):
450+
if candidate.score is None: # not restored from checkpoint
451+
self._prune(candidate.ss_config, prune_depth=False, update_config=False)
452+
candidate.score = self.eval_score(silent=True)
453+
sample(self.model, sample_func=max) # reset to max subnet
454+
self.save_search_checkpoint(verbose=False)
455+
print_rank_0(
456+
f"\t{num2hrb(candidate.params)} params: {candidate.ss_config} -> {candidate.score:.4f} score"
457+
)
433458

434-
return best
459+
best = max(top_k_candidates, key=lambda x: x.score) # type: ignore[arg-type, return-value]
460+
return best.ss_config
435461

436462
@staticmethod
437463
def _generate_search_space_combos(

0 commit comments

Comments
 (0)