
Commit 8e30117

[TRTLLM-8129][feat] Apply AutoTuner to AllReduce Op for strategy tuning.
* Added an AUTOTUNE strategy for AllReduce operations that automatically
  selects and applies the best reduction tactic based on tensor
  characteristics.
* Enhanced the AutoTuner's distributed synchronization with improved MPI
  coordination across multiple ranks, so profiling and operation execution
  stay reliable in multi-rank runs. This makes the AutoTuner usable for
  distributed operations.

Signed-off-by: Yukun He <[email protected]>
1 parent f57dc01 commit 8e30117
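
The new strategy is opt-in. Below is a minimal sketch of enabling it through
the PyTorch-backend LLM API; the model path and parallel size are
placeholders, allreduce_strategy is the TorchLlmArgs field extended later in
this commit, and tuning only matters when tensor_parallel_size > 1:

    from tensorrt_llm import LLM

    # Sketch only: placeholder checkpoint path and tensor-parallel size.
    llm = LLM(model="/path/to/model",
              tensor_parallel_size=4,
              allreduce_strategy="AUTOTUNE")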

File tree: 8 files changed, +312 −60 lines

tensorrt_llm/_torch/autotuner.py

Lines changed: 25 additions & 2 deletions

@@ -15,6 +15,7 @@
 import torch

 import tensorrt_llm
+from tensorrt_llm._utils import mpi_barrier, mpi_broadcast
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger

@@ -534,8 +535,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
         # Add statistics tracking
         self.stats = AutoTunerStatistics()

-        self.profiling_debug = True
-
         # Current captured choose_one() contexts
         self._active_capture: Optional['AutoTuner.TacticsCapture'] = None
         # Last captured choose_one() contexts

@@ -786,6 +785,13 @@ def _profile_runners(
         tuning_config: TuningConfig,
         **kwargs,
     ) -> float:
+        """Profile runners and select the best tactic.
+
+        For multi-rank profiling, only rank 0 performs the actual profiling
+        to avoid sync issues when different ranks select different tactics.
+        The results are then broadcasted to all other ranks.
+        """
+
         min_time = float('inf')
         has_tuning_failure_occured = False
         best_runner_id, best_tactic = None, None

@@ -836,6 +842,13 @@ def _profile_runners(
                     min_time = time_measured
                     best_runner_id, best_tactic = runner_id, tac

+        if self._is_sync_op(runner):
+            profiling_results = (best_runner_id, best_tactic, min_time,
+                                 has_tuning_failure_occured)
+            # Broadcast profiling results from rank 0 to all other ranks
+            profiling_results = mpi_broadcast(profiling_results, root=0)
+            best_runner_id, best_tactic, min_time, has_tuning_failure_occured = profiling_results
+
         return best_runner_id, best_tactic, min_time, has_tuning_failure_occured

@@ -871,6 +884,10 @@ def _profile_single_kernel(
         are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
+
+        if self._is_sync_op(runner):
+            mpi_barrier()
+
         # warm up, no timing
         for _ in range(self.warmup):
             runner(inputs, tactic=tactic, **kwargs)

@@ -883,6 +900,9 @@ def _profile_single_kernel(
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)

+        if self._is_sync_op(runner):
+            mpi_barrier()
+
         start.record(stream=stream)
         for _ in range(self.repeat):
             runner(inputs, tactic=tactic, **kwargs)

@@ -1065,6 +1085,9 @@ def _prepare_input_tensors(
             tensors.append(tensor)
         return tensors

+    def _is_sync_op(self, runner: TunableRunner) -> bool:
+        return runner.__class__.__name__ in ["AllReduceRunner"]
+
     def clear_cache(self) -> None:
         """Clear the profiling cache."""
         self.profiling_cache.clear()
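
The synchronization contract behind these hooks is easier to see in
isolation: every rank must execute each candidate collective together (a
collective op deadlocks otherwise), but only rank 0's timing is
authoritative, and its verdict is broadcast to the rest. A standalone
sketch, assuming the mpi_barrier/mpi_broadcast/mpi_rank helpers in
tensorrt_llm._utils behave as thin wrappers over COMM_WORLD:

    from tensorrt_llm._utils import mpi_barrier, mpi_broadcast, mpi_rank

    def choose_collective_tactic(tactics, measure):
        # Sketch of the rank-0-decides pattern; `measure` is a hypothetical
        # callable that runs and times one candidate collective op.
        best_tactic, best_time = None, float("inf")
        for tactic in tactics:
            mpi_barrier()              # align ranks before the collective runs
            elapsed = measure(tactic)  # all ranks run it; rank 0's time decides
            if mpi_rank() == 0 and elapsed < best_time:
                best_tactic, best_time = tactic, elapsed
        # Every rank adopts rank 0's choice, so later calls stay in lockstep.
        return mpi_broadcast((best_tactic, best_time), root=0)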

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 156 additions & 0 deletions

@@ -8,6 +8,7 @@
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy

 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)

@@ -1122,6 +1123,161 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)


+class AllReduceRunner(TunableRunner):
+    all_support_ops = {
+        AllReduceFusionOp.NONE.value,
+        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
+    }
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0,
+            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self._group = group
+        self._eps = eps
+        self._trigger_completion_at_end = trigger_completion_at_end
+
+    def __hash__(self):
+        return hash((self.tp_size, self.op))
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_tactics = [
+            AllReduceStrategy.NCCL.value,
+            AllReduceStrategy.ONESHOT.value,
+        ]
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_tactics.append(AllReduceStrategy.TWOSHOT.value)
+        return valid_tactics
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self._group,
+            tactic,
+            self.op,
+            self._eps,
+            self._trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        tp_size,
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    return allreduce_runner(
+        [input, residual, norm_weight, scale, bias, workspace],
+        tactic=best_tactic,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> torch.Tensor:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
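
The tuning_config above profiles a fixed ladder of token counts and relies
on last_positive_power_of_2 to collapse arbitrary runtime shapes onto that
ladder. A sketch of the bucketing, under the assumption that the helper
rounds a dimension down to the nearest power of two (the real helper ships
with the autotuner; this stand-in is illustrative only):

    def round_down_pow2(x: int) -> int:
        # Map a runtime first-dimension size onto the profiled
        # ladder 8192, 4096, ..., 2, 1 by rounding down to a power of two.
        p = 1
        while p * 2 <= x:
            p *= 2
        return p

    assert round_down_pow2(1000) == 512    # falls into the 512 bucket
    assert round_down_pow2(8192) == 8192   # ladder endpoint, exact match

The ConstraintSpec(1, 0, lambda shapes: shapes[0][0]) entry then pins the
residual's first dimension to whatever bucket the input landed in, so both
tensors are profiled with consistent shapes.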

tensorrt_llm/_torch/distributed/ops.py

Lines changed: 38 additions & 16 deletions

@@ -505,7 +505,6 @@ def __init__(self,
         self._disable_mpi = mpi_disabled()

         self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
-
         if self.mapping.tp_size > 1:
             # When Strategy is UB, it is guaranteed that the workspace is not used.
             if self.strategy != AllReduceStrategy.UB:

@@ -574,6 +573,7 @@ def forward(
         input = input.contiguous()  # Underlying op requires contiguous input

         allreduce_strategy = self.strategy
+
         if all_reduce_params is None:
             all_reduce_params = AllReduceParams()

@@ -598,21 +598,43 @@
                 "pg": pg.boxed(),
             }

-        output = self.all_reduce_op(
-            input=input,
-            residual=all_reduce_params.residual,
-            norm_weight=all_reduce_params.norm_weight,
-            scale=all_reduce_params.scale,
-            bias=all_reduce_params.bias,
-            workspace=self.workspace,
-            group=self.mapping.tp_group,
-            strategy=allreduce_strategy,
-            op=all_reduce_params.fusion_op,
-            eps=all_reduce_params.eps,
-            trigger_completion_at_end=all_reduce_params.
-            trigger_completion_at_end,
-            **additional_args,
-        )
+        # TODO: args for non mpi version are not supported by Python side custom op
+        # for now, we just fallback to AUTO
+        if self.strategy == AllReduceStrategy.AUTOTUNE:
+            self.strategy = AllReduceStrategy.AUTO
+
+        if self.strategy == AllReduceStrategy.AUTOTUNE:
+            output = torch.ops.trtllm.tunable_allreduce(
+                input=input,
+                residual=all_reduce_params.residual,
+                norm_weight=all_reduce_params.norm_weight,
+                scale=all_reduce_params.scale,
+                bias=all_reduce_params.bias,
+                workspace=self.workspace,
+                group=self.mapping.tp_group,
+                strategy=allreduce_strategy,
+                op=all_reduce_params.fusion_op,
+                eps=all_reduce_params.eps,
+                tp_size=self.mapping.tp_size,
+                trigger_completion_at_end=all_reduce_params.
+                trigger_completion_at_end,
+            )
+        else:
+            output = self.all_reduce_op(
+                input=input,
+                residual=all_reduce_params.residual,
+                norm_weight=all_reduce_params.norm_weight,
+                scale=all_reduce_params.scale,
+                bias=all_reduce_params.bias,
+                workspace=self.workspace,
+                group=self.mapping.tp_group,
+                strategy=allreduce_strategy,
+                op=all_reduce_params.fusion_op,
+                eps=all_reduce_params.eps,
+                trigger_completion_at_end=all_reduce_params.
+                trigger_completion_at_end,
+                **additional_args,
+            )

         return output if len(output) > 1 else output[0]
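
For module authors, opting in is a one-argument change, mirroring the
modeling_llama hookup below. A sketch, assuming AllReduce is importable from
tensorrt_llm._torch.distributed and that mapping and hidden_states already
exist in the caller:

    from tensorrt_llm._torch.distributed import AllReduce
    from tensorrt_llm.functional import AllReduceStrategy

    all_reduce = AllReduce(mapping=mapping,  # a Mapping with tp_size > 1
                           strategy=AllReduceStrategy.AUTOTUNE)
    # forward() routes through trtllm::tunable_allreduce when the strategy
    # is AUTOTUNE (subject to the TODO fallback above).
    output = all_reduce(hidden_states)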

tensorrt_llm/_torch/model_config.py

Lines changed: 2 additions & 1 deletion

@@ -193,7 +193,8 @@ def get_all_reduce_strategy(strategy: str = "AUTO"):
         "TWOSHOT": AllReduceStrategy.TWOSHOT,
         "LOWPRECISION": AllReduceStrategy.LOWPRECISION,
         "MNNVL": AllReduceStrategy.MNNVL,
-        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC
+        "NCCL_SYMMETRIC": AllReduceStrategy.NCCL_SYMMETRIC,
+        "AUTOTUNE": AllReduceStrategy.AUTOTUNE,
     }
     key = strategy.upper()
     return maps[key] if key in maps else AllReduceStrategy.AUTO
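
The lookup is case-insensitive and falls back to AUTO for unrecognized
names, so for example:

    get_all_reduce_strategy("autotune")  # -> AllReduceStrategy.AUTOTUNE
    get_all_reduce_strategy("AUTOTUNE")  # -> AllReduceStrategy.AUTOTUNE
    get_all_reduce_strategy("bogus")     # -> AllReduceStrategy.AUTO (fallback)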

tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 2 additions & 1 deletion

@@ -650,7 +650,8 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)

-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        self.all_reduce = AllReduce(mapping=model_config.mapping,
+                                    strategy=model_config.allreduce_strategy)

         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None

tensorrt_llm/functional.py

Lines changed: 1 addition & 0 deletions

@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
     LOWPRECISION = 6
     MNNVL = 7
     NCCL_SYMMETRIC = 8
+    AUTOTUNE = 9


 class AllReduceFusionOp(IntEnum):
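
Because AllReduceStrategy is an IntEnum, the new member compares equal to
its integer value, which is the form a chosen strategy takes when it crosses
the custom-op boundary as a plain int tactic:

    from tensorrt_llm.functional import AllReduceStrategy

    assert AllReduceStrategy.AUTOTUNE == 9
    assert AllReduceStrategy(9) is AllReduceStrategy.AUTOTUNE
    assert AllReduceStrategy.AUTOTUNE.value == 9  # passed as a tactic id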

tensorrt_llm/llmapi/llm_args.py

Lines changed: 6 additions & 6 deletions

@@ -2535,12 +2535,12 @@ class TorchLlmArgs(BaseLlmArgs):
         status="prototype",
     )

-    allreduce_strategy: Optional[Literal[
-        'AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
-        'LOWPRECISION', 'MNNVL',
-        'NCCL_SYMMETRIC']] = Field(default='AUTO',
-                                   description="Allreduce strategy to use.",
-                                   status="beta")
+    allreduce_strategy: Optional[
+        Literal['AUTO', 'NCCL', 'UB', 'MINLATENCY', 'ONESHOT', 'TWOSHOT',
+                'LOWPRECISION', 'MNNVL', 'NCCL_SYMMETRIC',
+                'AUTOTUNE']] = Field(default='AUTO',
+                                     description="Allreduce strategy to use.",
+                                     status="beta")
     checkpoint_loader: Optional[object] = Field(
         default=None,
         description=
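
Because allreduce_strategy is a validated Literal field, a misspelled value
now fails at argument construction rather than deep inside the runtime. A
hedged sketch, assuming the field is validated by pydantic like its
neighbors and that direct TorchLlmArgs construction is permitted (the model
path is a placeholder):

    import pydantic
    from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

    args = TorchLlmArgs(model="/path/to/model",
                        allreduce_strategy="AUTOTUNE")  # accepted

    try:
        TorchLlmArgs(model="/path/to/model",
                     allreduce_strategy="AUTOTUNED")    # not in the Literal
    except pydantic.ValidationError:
        print("rejected: not one of the allowed strategy names")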
