33 changes: 7 additions & 26 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -988,33 +988,8 @@ class AllreduceOp
bool ifFallbackToNCCL(size_t seq_len, size_t message_size_bytes, size_t max_workspace_size, bool is_auto)
{
// Fall back to NCCL if the message does not fit in maxWorkspaceSize, or if P2P / NVLINK is unavailable, regardless of the fusion type.
if (message_size_bytes > max_workspace_size)
if (message_size_bytes > max_workspace_size || !mIsP2PSupported || !mIsNVLINKSupported)
{
if (!is_auto)
{
TLLM_LOG_WARNING(
"Since messageSize is greater than maxWorkspaceSize, fallback to AllReduceStrategy: NCCL");
}
return true;
}

// If Peer to Peer is not supported, fallback to NCCL.
if (!mIsP2PSupported)
{
if (!is_auto)
{
TLLM_LOG_WARNING("Since Peer to Peer not supported, fallback to AllReduceStrategy: NCCL");
}
return true;
}

// If NVLINK is not supported, fallback to NCCL.
if (!mIsNVLINKSupported)
{
if (!is_auto)
{
TLLM_LOG_WARNING("Since NVLINK not supported, fallback to AllReduceStrategy: NCCL");
}
return true;
}
return false;
@@ -1055,6 +1030,12 @@ class AllreduceOp
// Otherwise, MIN_LATENCY strategy will be directly returned due to more fusions it can support.
// TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
// This should be compared with MIN_LATENCY fused kernels to determine the best strategy.

if (!is_auto)
{
return mStrategy;
}

switch (mOp)
{
case AllReduceFusionOp::NONE:
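To make the intent of the two C++ changes above easier to scan, here is a rough Python paraphrase of the resulting control flow (function and variable names are illustrative, not the actual TensorRT-LLM API):

def should_fallback_to_nccl(message_size_bytes: int, max_workspace_size: int,
                            p2p_supported: bool, nvlink_supported: bool,
                            is_auto: bool) -> bool:
    # The three separate checks are now a single predicate: the custom kernels
    # need the message to fit in the workspace and require P2P plus NVLINK.
    if (message_size_bytes > max_workspace_size
            or not p2p_supported or not nvlink_supported):
        if not is_auto:
            print("warning: falling back to AllReduceStrategy: NCCL")
        return True
    return False


def select_strategy(configured_strategy: str, is_auto: bool) -> str:
    # New early return: an explicitly configured strategy (e.g. the new
    # AUTOTUNE value) is returned as-is; only AUTO continues into the
    # fusion-op heuristic.
    if not is_auto:
        return configured_strategy
    # ... fusion-op switch (NONE, RESIDUAL_RMS_NORM, ...) as in the C++ code
    return "MIN_LATENCY"
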
27 changes: 25 additions & 2 deletions tensorrt_llm/_torch/autotuner.py
@@ -15,6 +15,7 @@
import torch

import tensorrt_llm
from tensorrt_llm._utils import mpi_barrier, mpi_broadcast
from tensorrt_llm.bindings.internal.runtime import delay_kernel
from tensorrt_llm.logger import logger

@@ -534,8 +535,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
# Add statistics tracking
self.stats = AutoTunerStatistics()

self.profiling_debug = True

@classmethod
def get(cls):
if cls._instance is None:
@@ -660,6 +659,13 @@ def _profile_runners(
tuning_config: TuningConfig,
**kwargs,
) -> float:
"""Profile runners and select the best tactic.

For multi-rank profiling, only rank 0 performs the actual profiling
to avoid sync issues when different ranks select different tactics.
The results are then broadcast to all other ranks.
"""

min_time = float('inf')
has_tuning_failure_occured = False
best_runner_id, best_tactic = None, None
@@ -710,6 +716,13 @@ def _profile_runners(
min_time = time_measured
best_runner_id, best_tactic = runner_id, tac

if self._is_sync_op(runner):
profiling_results = (best_runner_id, best_tactic, min_time,
has_tuning_failure_occured)
# Broadcast profiling results from rank 0 to all other ranks
profiling_results = mpi_broadcast(profiling_results, root=0)
best_runner_id, best_tactic, min_time, has_tuning_failure_occured = profiling_results

return best_runner_id, best_tactic, min_time, has_tuning_failure_occured

def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:
@@ -745,6 +758,10 @@ def _profile_single_kernel(
are used to ensure accurate timing.
"""
stream = torch.cuda.current_stream()

if self._is_sync_op(runner):
mpi_barrier()

# warm up, no timing
for _ in range(self.warmup):
runner(inputs, tactic=tactic, **kwargs)
@@ -757,6 +774,9 @@
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

if self._is_sync_op(runner):
mpi_barrier()

start.record(stream=stream)
for _ in range(self.repeat):
runner(inputs, tactic=tactic, **kwargs)
@@ -939,6 +959,9 @@ def _prepare_input_tensors(
tensors.append(tensor)
return tensors

def _is_sync_op(self, runner: TunableRunner) -> bool:
return runner.__class__.__name__ in ["AllReduceRunner"]

def clear_cache(self) -> None:
"""Clear the profiling cache."""
self.profiling_cache.clear()
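The docstring and the new _is_sync_op hook above encode a simple contract for collective ("sync") ops: every rank must enter each measured run together (hence the mpi_barrier calls around warm-up and timing), and only rank 0's timings decide the winner, which is then shared via mpi_broadcast so all ranks execute the same tactic afterwards. A self-contained sketch of that pattern — the helper signatures here are simplified stand-ins, not the exact tensorrt_llm._utils API:

from typing import Any, Callable, List, Tuple

def profile_sync_op(run_tactic: Callable[[int], float],
                    tactics: List[int],
                    rank: int,
                    barrier: Callable[[], None],
                    broadcast: Callable[[Any, int], Any]) -> Tuple[int, float]:
    """Rank 0 times each tactic; the winning tactic is broadcast to all ranks."""
    best_tactic, best_time = None, float("inf")
    for tactic in tactics:
        barrier()                     # all ranks enter the collective together
        elapsed = run_tactic(tactic)  # every rank participates in the collective
        if rank == 0 and elapsed < best_time:
            best_tactic, best_time = tactic, elapsed
    # Without this step, ranks could disagree on the tactic and deadlock the
    # next collective; rank 0's decision becomes the global one.
    best_tactic, best_time = broadcast((best_tactic, best_time), 0)
    return best_tactic, best_time
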
156 changes: 156 additions & 0 deletions tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -8,6 +8,7 @@
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy

from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
@@ -1139,6 +1140,161 @@ def _(
return x.new_empty((b, d), dtype=o_dtype)


class AllReduceRunner(TunableRunner):
all_support_ops = {
AllReduceFusionOp.NONE.value,
AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
}

tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0,
(8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
)

def __init__(
self,
tp_size: int,
group: List[int],
op: int,
eps: float,
trigger_completion_at_end: bool,
):
self.tp_size = tp_size
self.op = op
self._group = group
self._eps = eps
self._trigger_completion_at_end = trigger_completion_at_end

def __hash__(self):
return hash((self.tp_size, self.op))

Comment on lines +1143 to +1173
⚠️ Potential issue | 🔴 Critical

Fix: Support MPI-disabled PG path in AllReduceRunner.

AllReduceRunner.forward always uses torch.ops.trtllm.allreduce. When MPI is disabled you must call allreduce_pg and pass rank/pg. Suggested patch:

@@
-from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
+from tensorrt_llm._utils import mpi_disabled
+from typing import Optional, Any, ClassVar
@@
-class AllReduceRunner(TunableRunner):
-    all_support_ops = {
+class AllReduceRunner(TunableRunner):
+    all_support_ops: ClassVar[set[int]] = {
         AllReduceFusionOp.NONE.value,
         AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
     }
@@
-    def __init__(
+    def __init__(
         self,
         tp_size: int,
         group: List[int],
         op: int,
         eps: float,
         trigger_completion_at_end: bool,
+        *,
+        rank: Optional[int] = None,
+        pg: Optional[Any] = None,
     ):
         self.tp_size = tp_size
         self.op = op
         self._group = group
         self._eps = eps
         self._trigger_completion_at_end = trigger_completion_at_end
+        self._rank = rank
+        self._pg = pg
@@
-        return torch.ops.trtllm.allreduce(
-            input,
-            residual,
-            norm_weight,
-            scale,
-            bias,
-            workspace,
-            self._group,
-            tactic,
-            self.op,
-            self._eps,
-            self._trigger_completion_at_end,
-        )
+        if mpi_disabled() and self._pg is not None and self._rank is not None:
+            return torch.ops.trtllm.allreduce_pg(
+                input,
+                residual,
+                norm_weight,
+                scale,
+                bias,
+                workspace,
+                self._group,
+                self._pg,
+                tactic,
+                self.op,
+                self._eps,
+                self._trigger_completion_at_end,
+            )
+        else:
+            return torch.ops.trtllm.allreduce(
+                input,
+                residual,
+                norm_weight,
+                scale,
+                bias,
+                workspace,
+                self._group,
+                tactic,
+                self.op,
+                self._eps,
+                self._trigger_completion_at_end,
+            )

Also applies to: 1188-1210

🧰 Tools
🪛 Ruff (0.14.1)

1144-1147: Mutable class attributes should be annotated with typing.ClassVar

(RUF012)

def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
valid_tactics = [
AllReduceStrategy.NCCL.value,
AllReduceStrategy.ONESHOT.value,
]
if inputs[0].shape[0] >= self.tp_size:
valid_tactics.append(AllReduceStrategy.TWOSHOT.value)
return valid_tactics

def forward(
self,
inputs: List[torch.Tensor],
tactic: int = -1,
) -> torch.Tensor:
input, residual, norm_weight, scale, bias, workspace = inputs
if tactic == -1:
tactic = AllReduceStrategy.NCCL.value

return torch.ops.trtllm.allreduce(
input,
residual,
norm_weight,
scale,
bias,
workspace,
self._group,
tactic,
self.op,
self._eps,
self._trigger_completion_at_end,
)


@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
def tunable_allreduce(
input: torch.Tensor,
residual: Optional[torch.Tensor],
norm_weight: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
workspace: Optional[torch.Tensor],
group: List[int],
strategy: int,
op: int,
eps: float,
tp_size: int,
trigger_completion_at_end: bool,
) -> List[torch.Tensor]:

tuner = AutoTuner.get()

allreduce_runner = AllReduceRunner(
tp_size,
group,
op,
eps,
trigger_completion_at_end,
)

_, best_tactic = tuner.choose_one(
"trtllm::tunable_allreduce::allreduce",
[allreduce_runner],
AllReduceRunner.tuning_config,
[input, residual, norm_weight, scale, bias, workspace],
)

return allreduce_runner(
[input, residual, norm_weight, scale, bias, workspace],
tactic=best_tactic,
)

Comment on lines +1212 to +1249
⚠️ Potential issue | 🔴 Critical

Fix: Wire PG args through tunable_allreduce and correct fake signature.

  • Missing tp_size in @register_fake signature causes schema mismatch.
  • Need to accept optional rank/pg and pass to runner for MPI-disabled case.

Patch:

@@
-@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
 def tunable_allreduce(
     input: torch.Tensor,
     residual: Optional[torch.Tensor],
     norm_weight: Optional[torch.Tensor],
     scale: Optional[torch.Tensor],
     bias: Optional[torch.Tensor],
     workspace: Optional[torch.Tensor],
     group: List[int],
     strategy: int,
     op: int,
     eps: float,
     tp_size: int,
     trigger_completion_at_end: bool,
+    rank: Optional[int] = None,
+    pg: Optional[object] = None,
 ) -> List[torch.Tensor]:
@@
-    allreduce_runner = AllReduceRunner(
+    allreduce_runner = AllReduceRunner(
         tp_size,
         group,
         op,
         eps,
         trigger_completion_at_end,
+        rank=rank,
+        pg=pg,
     )
@@
-@tunable_allreduce.register_fake
+@tunable_allreduce.register_fake
 def _(
     input: torch.Tensor,
     residual: Optional[torch.Tensor],
     norm_weight: Optional[torch.Tensor],
     scale: Optional[torch.Tensor],
     bias: Optional[torch.Tensor],
     workspace: Optional[torch.Tensor],
     group: List[int],
     strategy: int,
     op: int,
     eps: float,
+    tp_size: int,
     trigger_completion_at_end: bool,
+    rank: Optional[int] = None,
+    pg: Optional[object] = None,
 ) -> torch.Tensor:

This aligns the fake schema with the real op and unblocks torch.compile/meta.

Also applies to: 1251-1295

🧰 Tools
🪛 Ruff (0.14.1)

1221-1221: Unused function argument: strategy

(ARG001)


@tunable_allreduce.register_fake
def _(
input: torch.Tensor,
residual: Optional[torch.Tensor],
norm_weight: Optional[torch.Tensor],
scale: Optional[torch.Tensor],
bias: Optional[torch.Tensor],
workspace: Optional[torch.Tensor],
group: List[int],
strategy: int,
op: int,
eps: float,
tp_size: int,
trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
if op == int(AllReduceFusionOp.NONE):
return [torch.empty_like(input)]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
norm_out = torch.empty_like(input)
residual_out = torch.empty_like(input)
return [norm_out, residual_out]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
residual_out = torch.empty_like(input)
return [quant_out, residual_out]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
norm_out = torch.empty_like(input)
quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
residual_out = torch.empty_like(input)
return [norm_out, quant_out, residual_out]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
residual_out = torch.empty_like(input)
return [quant_fp4, scale_fp4, residual_out]
elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
norm_out = torch.empty_like(input)
residual_out = torch.empty_like(input)
return [norm_out, quant_fp4, scale_fp4, residual_out]
else:
return [torch.empty_like(input)]


def get_event(event_idx: int):
from ..utils import get_model_extra_attrs
extra_attrs = get_model_extra_attrs()
49 changes: 33 additions & 16 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -505,7 +505,6 @@ def __init__(self,
self._disable_mpi = mpi_disabled()

self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce

if self.mapping.tp_size > 1:
# When Strategy is UB, it is guaranteed that the workspace is not used.
if self.strategy != AllReduceStrategy.UB:
@@ -574,6 +573,7 @@ def forward(
input = input.contiguous() # Underlying op requires contiguous input

allreduce_strategy = self.strategy

if all_reduce_params is None:
all_reduce_params = AllReduceParams()

@@ -598,21 +598,38 @@ def forward(
"pg": pg.boxed(),
}

output = self.all_reduce_op(
input=input,
residual=all_reduce_params.residual,
norm_weight=all_reduce_params.norm_weight,
scale=all_reduce_params.scale,
bias=all_reduce_params.bias,
workspace=self.workspace,
group=self.mapping.tp_group,
strategy=allreduce_strategy,
op=all_reduce_params.fusion_op,
eps=all_reduce_params.eps,
trigger_completion_at_end=all_reduce_params.
trigger_completion_at_end,
**additional_args,
)
if self.strategy == AllReduceStrategy.AUTOTUNE:
output = torch.ops.trtllm.tunable_allreduce(
input=input,
residual=all_reduce_params.residual,
norm_weight=all_reduce_params.norm_weight,
scale=all_reduce_params.scale,
bias=all_reduce_params.bias,
workspace=self.workspace,
group=self.mapping.tp_group,
strategy=allreduce_strategy,
op=all_reduce_params.fusion_op,
eps=all_reduce_params.eps,
tp_size=self.mapping.tp_size,
trigger_completion_at_end=all_reduce_params.
trigger_completion_at_end,
)
else:
Comment on lines +601 to +617
⚠️ Potential issue | 🔴 Critical

MPI-disabled bug: AUTOTUNE path ignores ProcessGroup and rank; uses non-PG op.

In the AUTOTUNE branch you call trtllm.tunable_allreduce without passing rank/pg, and AllReduceRunner always dispatches to allreduce (non-PG). When mpi_disabled(), this will fail or hang.

Pass rank and pg through and make the runner use allreduce_pg when MPI is disabled. Suggested patch (ops side):

@@
-        if self.strategy == AllReduceStrategy.AUTOTUNE:
-            output = torch.ops.trtllm.tunable_allreduce(
+        if self.strategy == AllReduceStrategy.AUTOTUNE:
+            output = torch.ops.trtllm.tunable_allreduce(
                 input=input,
                 residual=all_reduce_params.residual,
                 norm_weight=all_reduce_params.norm_weight,
                 scale=all_reduce_params.scale,
                 bias=all_reduce_params.bias,
                 workspace=self.workspace,
                 group=self.mapping.tp_group,
                 strategy=allreduce_strategy,
                 op=all_reduce_params.fusion_op,
                 eps=all_reduce_params.eps,
                 tp_size=self.mapping.tp_size,
-                trigger_completion_at_end=all_reduce_params.
-                trigger_completion_at_end,
+                trigger_completion_at_end=all_reduce_params.trigger_completion_at_end,
+                # Wire PG when MPI is disabled
+                **({"rank": torch.distributed.get_rank(),
+                    "pg": self.mapping.tp_group_pg.boxed()} if self._disable_mpi else {}),
             )

Apply the complementary changes in tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (see my other comment) to accept rank/pg and call allreduce_pg accordingly.

Committable suggestion skipped: line range outside the PR's diff.

output = self.all_reduce_op(
input=input,
residual=all_reduce_params.residual,
norm_weight=all_reduce_params.norm_weight,
scale=all_reduce_params.scale,
bias=all_reduce_params.bias,
workspace=self.workspace,
group=self.mapping.tp_group,
strategy=allreduce_strategy,
op=all_reduce_params.fusion_op,
eps=all_reduce_params.eps,
trigger_completion_at_end=all_reduce_params.
trigger_completion_at_end,
**additional_args,
)

return output if len(output) > 1 else output[0]

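For completeness, enabling the new path from user code only requires constructing the module with the added strategy; the branch above then routes calls through trtllm::tunable_allreduce. A hedged usage sketch — the classes exist in TensorRT-LLM, but the exact constructor arguments and the multi-rank launch details are illustrative and assume a real tensor-parallel job (one process per GPU):

import torch
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.distributed.ops import AllReduce

# Illustrative 4-way tensor-parallel setup; in practice the Mapping comes from
# the model config and this runs once per rank under mpirun / torchrun.
mapping = Mapping(world_size=4, tp_size=4, rank=0)
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.AUTOTUNE)

hidden = torch.randn(8, 4096, dtype=torch.bfloat16, device="cuda")

# During autotuner warm-up the call profiles NCCL / ONESHOT (and TWOSHOT when
# num_tokens >= tp_size) and caches the fastest tactic per bucketed shape;
# subsequent calls reuse the cached tactic.
out = allreduce(hidden)
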
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/model_config.py
@@ -191,7 +191,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
"TWOSHOT": AllReduceStrategy.TWOSHOT,
"LOWPRECISION": AllReduceStrategy.LOWPRECISION,
"MNNVL": AllReduceStrategy.MNNVL,
"NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
"NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
"AUTOTUNE": AllReduceStrategy.AUTOTUNE,
}
key = strategy.upper()
return maps[key] if key in maps else AllReduceStrategy.AUTO
3 changes: 2 additions & 1 deletion tensorrt_llm/_torch/models/modeling_llama.py
@@ -650,7 +650,8 @@ def __init__(
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.all_reduce = AllReduce(mapping=model_config.mapping)
self.all_reduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)

self.next_layer_layernorm: RMSNorm = None
self.next_attn: LlamaAttention = None
1 change: 1 addition & 0 deletions tensorrt_llm/functional.py
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
LOWPRECISION = 6
MNNVL = 7
NCCL_SYMMETRIC = 8
AUTOTUNE = 9


class AllReduceFusionOp(IntEnum):