- 
                Notifications
    You must be signed in to change notification settings 
- Fork 1.8k
[TRTLLM-8821][feat] Apply AutoTuner to AllReduce Op for strategy tuning. #8531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -8,6 +8,7 @@ | |
| import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils | ||
| from tensorrt_llm import deep_gemm | ||
| from tensorrt_llm._utils import get_sm_version | ||
| from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy | ||
|  | ||
| from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, | ||
| OptimizationProfile, TunableRunner, TuningConfig) | ||
|  | @@ -1139,6 +1140,161 @@ def _( | |
| return x.new_empty((b, d), dtype=o_dtype) | ||
|  | ||
|  | ||
class AllReduceRunner(TunableRunner):
    """Tunable runner that lets the AutoTuner benchmark AllReduce strategies
    (NCCL / ONESHOT / conditionally TWOSHOT) and cache the fastest one per
    input shape and fusion op.
    """

    # Fusion ops this runner supports tuning for.
    # NOTE(review): mutable class attribute — ruff RUF012 recommends a
    # typing.ClassVar annotation; left untouched to avoid a new import here.
    all_support_ops = {
        AllReduceFusionOp.NONE.value,
        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
    }

    # Tune over power-of-2 token counts on dim 0 of the first input; the
    # residual (input 1) is constrained to share the input's first dim.
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0,
            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
            last_positive_power_of_2), ),
        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
    )

    def __init__(
        self,
        tp_size: int,
        group: List[int],
        op: int,
        eps: float,
        trigger_completion_at_end: bool,
    ):
        # tp_size and op participate in the tuning-cache key (see __hash__);
        # the remaining fields are fixed per-instance arguments forwarded to
        # the underlying allreduce op.
        self.tp_size = tp_size
        self.op = op
        self._group = group
        self._eps = eps
        self._trigger_completion_at_end = trigger_completion_at_end

    def __hash__(self):
        # Only tp_size and op distinguish tuning profiles; group/eps/flags
        # are assumed not to change which strategy is fastest for a shape.
        return hash((self.tp_size, self.op))

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[int]:
        """Return the AllReduce strategies worth benchmarking for `inputs`.

        TWOSHOT is only offered when the token count (dim 0 of the first
        input) is at least tp_size — presumably it needs one token per rank;
        confirm against the kernel implementation.
        """
        candidate_strategies = [
            AllReduceStrategy.NCCL.value,
            AllReduceStrategy.ONESHOT.value,
        ]
        if inputs[0].shape[0] >= self.tp_size:
            candidate_strategies.append(AllReduceStrategy.TWOSHOT.value)
        return candidate_strategies

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = -1,
    ) -> torch.Tensor:
        """Run the allreduce with the given tactic; -1 falls back to NCCL."""
        # Renamed the first unpack target from `input` to avoid shadowing
        # the builtin.
        inp, residual, norm_weight, scale, bias, workspace = inputs
        if tactic == -1:
            # Default strategy when the tuner has no cached choice yet.
            tactic = AllReduceStrategy.NCCL.value

        return torch.ops.trtllm.allreduce(
            inp,
            residual,
            norm_weight,
            scale,
            bias,
            workspace,
            self._group,
            tactic,
            self.op,
            self._eps,
            self._trigger_completion_at_end,
        )
|  | ||
|  | ||
@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
def tunable_allreduce(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """AllReduce whose strategy is selected by the AutoTuner per input shape.

    `strategy` is intentionally unused (ruff ARG001): the tuner overrides it
    with the profiled best tactic. It is kept in the signature for schema
    parity with the non-tunable allreduce op, so callers can switch between
    the two without changing their argument list.
    """
    tuner = AutoTuner.get()

    allreduce_runner = AllReduceRunner(
        tp_size,
        group,
        op,
        eps,
        trigger_completion_at_end,
    )

    # choose_one profiles the runner's valid tactics on first encounter and
    # returns the cached best strategy for this shape/op combination.
    _, best_tactic = tuner.choose_one(
        "trtllm::tunable_allreduce::allreduce",
        [allreduce_runner],
        AllReduceRunner.tuning_config,
        [input, residual, norm_weight, scale, bias, workspace],
    )

    return allreduce_runner(
        [input, residual, norm_weight, scale, bias, workspace],
        tactic=best_tactic,
    )
|  | ||
| 
      Comment on lines
    
      +1212
     to 
      +1249
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fix: Wire PG args through tunable_allreduce and correct fake signature. 
 Patch: @@
-@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
 def tunable_allreduce(
     input: torch.Tensor,
     residual: Optional[torch.Tensor],
     norm_weight: Optional[torch.Tensor],
     scale: Optional[torch.Tensor],
     bias: Optional[torch.Tensor],
     workspace: Optional[torch.Tensor],
     group: List[int],
     strategy: int,
     op: int,
     eps: float,
     tp_size: int,
     trigger_completion_at_end: bool,
+    rank: Optional[int] = None,
+    pg: Optional[object] = None,
 ) -> List[torch.Tensor]:
@@
-    allreduce_runner = AllReduceRunner(
+    allreduce_runner = AllReduceRunner(
         tp_size,
         group,
         op,
         eps,
         trigger_completion_at_end,
+        rank=rank,
+        pg=pg,
     )
@@
-@tunable_allreduce.register_fake
+@tunable_allreduce.register_fake
 def _(
     input: torch.Tensor,
     residual: Optional[torch.Tensor],
     norm_weight: Optional[torch.Tensor],
     scale: Optional[torch.Tensor],
     bias: Optional[torch.Tensor],
     workspace: Optional[torch.Tensor],
     group: List[int],
     strategy: int,
     op: int,
     eps: float,
+    tp_size: int,
     trigger_completion_at_end: bool,
+    rank: Optional[int] = None,
+    pg: Optional[object] = None,
 ) -> torch.Tensor:This aligns the fake schema with the real op and unblocks torch.compile/meta. Also applies to: 1251-1295 🧰 Tools🪛 Ruff (0.14.1)1221-1221: Unused function argument:  (ARG001) | ||
|  | ||
@tunable_allreduce.register_fake
def _(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """Fake (meta) implementation of trtllm::tunable_allreduce.

    Returns empty tensors with the shapes/dtypes the real op would produce
    for each fusion `op`, so tracing/torch.compile can infer output metadata
    without running the kernel. Fixed: the return annotation previously said
    `torch.Tensor` although every branch returns a list, mismatching the real
    op's `List[torch.Tensor]` schema. Only `input` and `op` affect the
    output metadata; the remaining parameters exist for schema parity.
    """
    if op == int(AllReduceFusionOp.NONE):
        return [torch.empty_like(input)]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
        norm_out = torch.empty_like(input)
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        residual_out = torch.empty_like(input)
        return [quant_fp4, scale_fp4, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_fp4, scale_fp4, residual_out]
    else:
        # Unknown fusion op: fall back to a single output matching the input.
        return [torch.empty_like(input)]
|  | ||
|  | ||
| def get_event(event_idx: int): | ||
| from ..utils import get_model_extra_attrs | ||
| extra_attrs = get_model_extra_attrs() | ||
|  | ||
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -505,7 +505,6 @@ def __init__(self, | |
| self._disable_mpi = mpi_disabled() | ||
|  | ||
| self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce | ||
|  | ||
| if self.mapping.tp_size > 1: | ||
| # When Strategy is UB, it is guaranteed that the workspace is not used. | ||
| if self.strategy != AllReduceStrategy.UB: | ||
|  | @@ -574,6 +573,7 @@ def forward( | |
| input = input.contiguous() # Underlying op requires contiguous input | ||
|  | ||
| allreduce_strategy = self.strategy | ||
|  | ||
| if all_reduce_params is None: | ||
| all_reduce_params = AllReduceParams() | ||
|  | ||
|  | @@ -598,21 +598,38 @@ def forward( | |
| "pg": pg.boxed(), | ||
| } | ||
|  | ||
| output = self.all_reduce_op( | ||
| input=input, | ||
| residual=all_reduce_params.residual, | ||
| norm_weight=all_reduce_params.norm_weight, | ||
| scale=all_reduce_params.scale, | ||
| bias=all_reduce_params.bias, | ||
| workspace=self.workspace, | ||
| group=self.mapping.tp_group, | ||
| strategy=allreduce_strategy, | ||
| op=all_reduce_params.fusion_op, | ||
| eps=all_reduce_params.eps, | ||
| trigger_completion_at_end=all_reduce_params. | ||
| trigger_completion_at_end, | ||
| **additional_args, | ||
| ) | ||
| if self.strategy == AllReduceStrategy.AUTOTUNE: | ||
| output = torch.ops.trtllm.tunable_allreduce( | ||
| input=input, | ||
| residual=all_reduce_params.residual, | ||
| norm_weight=all_reduce_params.norm_weight, | ||
| scale=all_reduce_params.scale, | ||
| bias=all_reduce_params.bias, | ||
| workspace=self.workspace, | ||
| group=self.mapping.tp_group, | ||
| strategy=allreduce_strategy, | ||
| op=all_reduce_params.fusion_op, | ||
| eps=all_reduce_params.eps, | ||
| tp_size=self.mapping.tp_size, | ||
| trigger_completion_at_end=all_reduce_params. | ||
| trigger_completion_at_end, | ||
| ) | ||
| else: | ||
| 
      Comment on lines
    
      +601
     to 
      +617
    
   There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. MPI-disabled bug: AUTOTUNE path ignores ProcessGroup and rank; uses non-PG op. In the AUTOTUNE branch you call trtllm.tunable_allreduce without passing rank/pg, and AllReduceRunner always dispatches to allreduce (non-PG). When mpi_disabled(), this will fail or hang. Pass rank and pg through and make the runner use allreduce_pg when MPI is disabled. Suggested patch (ops side): @@
-        if self.strategy == AllReduceStrategy.AUTOTUNE:
-            output = torch.ops.trtllm.tunable_allreduce(
+        if self.strategy == AllReduceStrategy.AUTOTUNE:
+            output = torch.ops.trtllm.tunable_allreduce(
                 input=input,
                 residual=all_reduce_params.residual,
                 norm_weight=all_reduce_params.norm_weight,
                 scale=all_reduce_params.scale,
                 bias=all_reduce_params.bias,
                 workspace=self.workspace,
                 group=self.mapping.tp_group,
                 strategy=allreduce_strategy,
                 op=all_reduce_params.fusion_op,
                 eps=all_reduce_params.eps,
                 tp_size=self.mapping.tp_size,
-                trigger_completion_at_end=all_reduce_params.
-                trigger_completion_at_end,
+                trigger_completion_at_end=all_reduce_params.trigger_completion_at_end,
+                # Wire PG when MPI is disabled
+                **({"rank": torch.distributed.get_rank(),
+                    "pg": self.mapping.tp_group_pg.boxed()} if self._disable_mpi else {}),
             )Apply the complementary changes in tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (see my other comment) to accept rank/pg and call allreduce_pg accordingly. 
 | ||
| output = self.all_reduce_op( | ||
| input=input, | ||
| residual=all_reduce_params.residual, | ||
| norm_weight=all_reduce_params.norm_weight, | ||
| scale=all_reduce_params.scale, | ||
| bias=all_reduce_params.bias, | ||
| workspace=self.workspace, | ||
| group=self.mapping.tp_group, | ||
| strategy=allreduce_strategy, | ||
| op=all_reduce_params.fusion_op, | ||
| eps=all_reduce_params.eps, | ||
| trigger_completion_at_end=all_reduce_params. | ||
| trigger_completion_at_end, | ||
| **additional_args, | ||
| ) | ||
|  | ||
| return output if len(output) > 1 else output[0] | ||
|  | ||
|  | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix: Support MPI-disabled PG path in AllReduceRunner.
AllReduceRunner.forward always uses torch.ops.trtllm.allreduce. When MPI is disabled you must call allreduce_pg and pass rank/pg. Suggested patch:
Also applies to: 1188-1210
🧰 Tools
🪛 Ruff (0.14.1)
1144-1147: Mutable class attributes should be annotated with `typing.ClassVar` (RUF012)