|
15 | 15 | import torch |
16 | 16 |
|
17 | 17 | import tensorrt_llm |
18 | | -from tensorrt_llm._utils import mpi_barrier |
| 18 | +from tensorrt_llm._utils import mpi_barrier, mpi_broadcast |
19 | 19 | from tensorrt_llm.bindings.internal.runtime import delay_kernel |
20 | 20 | from tensorrt_llm.logger import logger |
21 | 21 |
|
@@ -659,6 +659,13 @@ def _profile_runners( |
659 | 659 | tuning_config: TuningConfig, |
660 | 660 | **kwargs, |
661 | 661 | ) -> float: |
| 662 | + """Profile runners and select the best tactic. |
| 663 | +
|
| 664 | + For multi-rank profiling, only rank 0 performs the actual profiling |
| 665 | + to avoid sync issues when different ranks select different tactics. |
| 666 | + The results are then broadcast to all other ranks. |
| 667 | + """ |
| 668 | + |
662 | 669 | min_time = float('inf') |
663 | 670 | has_tuning_failure_occured = False |
664 | 671 | best_runner_id, best_tactic = None, None |
@@ -709,6 +716,13 @@ def _profile_runners( |
709 | 716 | min_time = time_measured |
710 | 717 | best_runner_id, best_tactic = runner_id, tac |
711 | 718 |
|
| 719 | + if self._is_sync_op(runner): |
| 720 | + profiling_results = (best_runner_id, best_tactic, min_time, |
| 721 | + has_tuning_failure_occured) |
| 722 | + # Broadcast profiling results from rank 0 to all other ranks |
| 723 | + profiling_results = mpi_broadcast(profiling_results, root=0) |
| 724 | + best_runner_id, best_tactic, min_time, has_tuning_failure_occured = profiling_results |
| 725 | + |
712 | 726 | return best_runner_id, best_tactic, min_time, has_tuning_failure_occured |
713 | 727 |
|
714 | 728 | def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]: |
|
0 commit comments