Commit fba39e3

Minor fix on profiling cache sync between different GPUs.
Signed-off-by: Yukun He <[email protected]>
1 parent 872bb3c

File tree: 6 files changed, +29 -48 lines

cpp/tensorrt_llm/thop/allreduceOp.cpp

Lines changed: 0 additions & 23 deletions
```diff
@@ -1048,29 +1048,6 @@ class AllreduceOp
             return AllReduceStrategyType::NCCL;
         }
 
-        // This rule based heuristic only chooses between NCCL and MIN_LATENCY strategies.
-
-        // Heurisitic will only be applied on NONE and RESIDUAL_RMS_NORM fusion types.
-        // Because NCCL might be faster on some large messageSize cases.
-        // Otherwise, MIN_LATENCY strategy will be directly returned due to more fusions it can support.
-        // TODO: NCCL AllReduce + subsequent quantization ops (as fallback) can also support the fusion types.
-        // This should be compared with MIN_LATENCY fused kernels to determine the best strategy.
-        switch (mOp)
-        {
-        case AllReduceFusionOp::NONE:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM: break;
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_FP8:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_QUANT_NVFP4:
-        case AllReduceFusionOp::RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: return AllReduceStrategyType::MIN_LATENCY;
-        // Suppose NCCL has fallback implementations for all fusion types.
-        default: return AllReduceStrategyType::NCCL;
-        }
-
-        // Check mOp to be supported by the heuristic.
-        TORCH_CHECK(mOp == AllReduceFusionOp::NONE || mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM,
-            "Only NONE and RESIDUAL_RMS_NORM are supported for NCCL/MIN_LATENCY heuristic.");
-
         // Default to NCCL.
         AllReduceStrategyType strategy = AllReduceStrategyType::NCCL;
```

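Read end to end, the deleted block hard-coded the strategy choice per fusion op: the quantizing fusions always took MIN_LATENCY (only the fused kernels support them), unrecognized fusion ops fell back to NCCL, and only NONE and RESIDUAL_RMS_NORM fell through to the message-size heuristic after the hunk. For comparison with the autotuner path that replaces it, here is a compact Python restatement of the removed logic; the enum and function names are illustrative stand-ins for the C++ types, not part of the codebase:

```python
from enum import Enum, auto

class FusionOp(Enum):
    # Illustrative mirror of the C++ AllReduceFusionOp values in the diff.
    NONE = auto()
    RESIDUAL_RMS_NORM = auto()
    RESIDUAL_RMS_NORM_QUANT_FP8 = auto()
    RESIDUAL_RMS_NORM_OUT_QUANT_FP8 = auto()
    RESIDUAL_RMS_NORM_QUANT_NVFP4 = auto()
    RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4 = auto()

QUANT_FUSIONS = {
    FusionOp.RESIDUAL_RMS_NORM_QUANT_FP8,
    FusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8,
    FusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4,
    FusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4,
}

def removed_heuristic(op: FusionOp) -> str | None:
    """Return a forced strategy, or None to defer to the message-size check."""
    if op in (FusionOp.NONE, FusionOp.RESIDUAL_RMS_NORM):
        return None            # decided later: NCCL vs. MIN_LATENCY by size
    if op in QUANT_FUSIONS:
        return "MIN_LATENCY"   # only the fused kernels support these ops
    return "NCCL"              # assume NCCL has fallbacks for everything else
```
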
tensorrt_llm/_torch/autotuner.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -15,7 +15,7 @@
 import torch
 
 import tensorrt_llm
-from tensorrt_llm._utils import mpi_barrier
+from tensorrt_llm._utils import mpi_barrier, mpi_broadcast
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
 
@@ -659,6 +659,13 @@ def _profile_runners(
         tuning_config: TuningConfig,
         **kwargs,
     ) -> float:
+        """Profile runners and select the best tactic.
+
+        For multi-rank profiling, only rank 0 performs the actual profiling
+        to avoid sync issues when different ranks select different tactics.
+        The results are then broadcasted to all other ranks.
+        """
+
         min_time = float('inf')
         has_tuning_failure_occured = False
         best_runner_id, best_tactic = None, None
@@ -709,6 +716,13 @@ def _profile_runners(
                    min_time = time_measured
                    best_runner_id, best_tactic = runner_id, tac
 
+        if self._is_sync_op(runner):
+            profiling_results = (best_runner_id, best_tactic, min_time,
+                                 has_tuning_failure_occured)
+            # Broadcast profiling results from rank 0 to all other ranks
+            profiling_results = mpi_broadcast(profiling_results, root=0)
+            best_runner_id, best_tactic, min_time, has_tuning_failure_occured = profiling_results
+
         return best_runner_id, best_tactic, min_time, has_tuning_failure_occured
 
     def _get_input_sizes(self, inputs: List[torch.Tensor]) -> List[torch.Size]:
```
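
The docstring in the hunk above states the invariant this change enforces: when the op being tuned synchronizes across ranks, every rank must adopt the same tactic, otherwise one rank can launch a collective kernel its peers never enter and profiling hangs. A minimal sketch of the pattern, assuming `mpi_broadcast(obj, root=0)` ships an arbitrary picklable object from the root to all ranks (the signature matches the diff; `mpi_rank` comes from the same `tensorrt_llm._utils` module, and `profile_locally` is a hypothetical stand-in for the measurement loop):

```python
from tensorrt_llm._utils import mpi_broadcast, mpi_rank

def select_tactic_consistently(profile_locally):
    """Rank 0 measures; every rank adopts rank 0's winner.

    `profile_locally` is a hypothetical callable returning
    (best_runner_id, best_tactic, min_time, has_tuning_failure_occured).
    """
    results = profile_locally() if mpi_rank() == 0 else None
    # The broadcast doubles as a barrier: no rank can run ahead with a
    # locally faster (but globally inconsistent) tactic.
    return mpi_broadcast(results, root=0)
```

In the actual diff the broadcast is gated on `self._is_sync_op(runner)`, so ops that do not synchronize across ranks skip the collective entirely.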

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 4 additions & 16 deletions
```diff
@@ -1194,7 +1194,7 @@ def forward(
         if tactic == -1:
             tactic = AllReduceStrategy.NCCL.value
 
-        torch.ops.trtllm.allreduce(
+        return torch.ops.trtllm.allreduce(
             input,
             residual,
             norm_weight,
@@ -1242,21 +1242,9 @@ def tunable_allreduce(
         [input, residual, norm_weight, scale, bias, workspace],
     )
 
-    if best_tactic == -1:
-        best_tactic = AllReduceStrategy.NCCL.value
-
-    return torch.ops.trtllm.allreduce(
-        input,
-        residual,
-        norm_weight,
-        scale,
-        bias,
-        workspace,
-        group,
-        best_tactic,
-        op,
-        eps,
-        trigger_completion_at_end,
+    return allreduce_runner(
+        [input, residual, norm_weight, scale, bias, workspace],
+        tactic=best_tactic,
     )
 
 
```

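Two fixes land in this file. First, `forward` previously called `torch.ops.trtllm.allreduce` and discarded its output; adding the `return` lets callers see the result. Second, `tunable_allreduce` no longer replays the raw op's full positional argument list: it re-invokes the same `allreduce_runner` used during profiling with the winning tactic, so the argument plumbing and the `tactic == -1` NCCL fallback live in exactly one place. A minimal sketch of that single-path shape, with hypothetical names standing in for the surrounding autotuner API:

```python
def run_allreduce_op(inputs, tactic):
    # Stand-in for torch.ops.trtllm.allreduce; just echoes its arguments.
    return inputs, tactic

class AllreduceRunner:
    """Hypothetical mirror of the runner class in this file."""

    def forward(self, inputs, tactic: int = -1):
        if tactic == -1:
            tactic = 0  # stands in for AllReduceStrategy.NCCL.value
        return run_allreduce_op(inputs, tactic)

def tunable_op(runner, inputs, best_tactic):
    # Execution re-uses the runner that was profiled, so an untuned
    # tactic of -1 hits the identical fallback in both paths.
    return runner.forward(inputs, tactic=best_tactic)

print(tunable_op(AllreduceRunner(), ["x"], best_tactic=-1))  # -> (['x'], 0)
```
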
tensorrt_llm/_torch/model_config.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -191,7 +191,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
         "TWOSHOT": AllReduceStrategy.TWOSHOT,
         "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
         "MNNVL": AllReduceStrategy.MNNVL,
-        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
+        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
+        "AUTOTUNE": AllReduceStrategy.AUTOTUNE,
     }
     key = strategy.upper()
     return maps[key] if key in maps else AllReduceStrategy.AUTO
```
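
Because the lookup upper-cases the key and falls back to AUTO for anything unrecognized, the new entry can be exercised directly. A sketch; the import paths are assumed from the files shown in this commit, and `AllReduceStrategy.AUTOTUNE` must exist in the enum, as the mapping implies:

```python
from tensorrt_llm._torch.model_config import get_all_reduce_strategy
from tensorrt_llm.functional import AllReduceStrategy  # assumed import path

assert get_all_reduce_strategy("autotune") is AllReduceStrategy.AUTOTUNE
assert get_all_reduce_strategy("AUTOTUNE") is AllReduceStrategy.AUTOTUNE
# Unknown strings do not raise -- they silently fall back to AUTO:
assert get_all_reduce_strategy("NOT_A_STRATEGY") is AllReduceStrategy.AUTO
```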

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -650,7 +650,8 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        self.all_reduce = AllReduce(mapping=model_config.mapping,
+                                    strategy=model_config.allreduce_strategy)
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
```

tensorrt_llm/llmapi/llm_args.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -2492,12 +2492,12 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )
 
-    allreduce_strategy: Optional[Literal[
-        'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
-        'LOWPRECISION', 'MNNVL',
-        'NCCL_SYMMETRIC']] = Field(default='AUTO',
-                                   description="Allreduce strategy to use.",
-                                   status="beta")
+    allreduce_strategy: Optional[
+        Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
+                'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC',
+                'AUTOTUNE']] = Field(default='AUTO',
+                                     description="Allreduce strategy to use.",
+                                     status="beta")
     checkpoint_loader: Optional[object] = Field(
         default=None,
         description=
```
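
With `'AUTOTUNE'` accepted by the field's `Literal`, the new mode is reachable from the public API like any other strategy. A sketch, assuming the llmapi `LLM` constructor forwards `TorchLlmArgs` fields as keyword arguments (as it does for other Torch-backend knobs); the model path is a placeholder:

```python
from tensorrt_llm import LLM

# 'AUTOTUNE' defers the allreduce strategy choice to the autotuner,
# whose profiling results this commit keeps in sync across ranks.
llm = LLM(model="/path/to/model",
          tensor_parallel_size=2,
          allreduce_strategy="AUTOTUNE")
```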
