Commit cf7acdd

[TRTLLM-8129][feat] Apply AutoTuner to AllReduce Op for strategy tuning.

Signed-off-by: Yukun He <[email protected]>
Parent: c72f6d1

5 files changed: +285, -35 lines

tensorrt_llm/_torch/autotuner.py (11 additions, 2 deletions)

```diff
@@ -15,6 +15,7 @@
 import torch

 import tensorrt_llm
+from tensorrt_llm._utils import mpi_barrier
 from tensorrt_llm.bindings.internal.runtime import delay_kernel
 from tensorrt_llm.logger import logger
@@ -534,8 +535,6 @@ def __init__(self, warmup=3, repeat=10, stream_delay_micro_secs=1000):
         # Add statistics tracking
         self.stats = AutoTunerStatistics()

-        self.profiling_debug = True
-
     @classmethod
     def get(cls):
         if cls._instance is None:
@@ -745,6 +744,10 @@ def _profile_single_kernel(
         are used to ensure accurate timing.
         """
         stream = torch.cuda.current_stream()
+
+        if self._is_sync_op(runner):
+            mpi_barrier()
+
         # warm up, no timing
         for _ in range(self.warmup):
             runner(inputs, tactic=tactic, **kwargs)
@@ -757,6 +760,9 @@
         start = torch.cuda.Event(enable_timing=True)
         end = torch.cuda.Event(enable_timing=True)

+        if self._is_sync_op(runner):
+            mpi_barrier()
+
         start.record(stream=stream)
         for _ in range(self.repeat):
             runner(inputs, tactic=tactic, **kwargs)
@@ -938,6 +944,9 @@ def _prepare_input_tensors(
             tensors.append(tensor)
         return tensors

+    def _is_sync_op(self, runner: TunableRunner) -> bool:
+        return runner.__class__.__name__ in ["AllReduceRunner"]
+
     def clear_cache(self) -> None:
         """Clear the profiling cache."""
         self.profiling_cache.clear()
```
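
The barriers around the timed region matter because all-reduce is a synchronous collective: a rank that arrives late measures its peers' wait time instead of kernel time, and ranks sweeping tactics out of lockstep can hang each other. Below is a minimal sketch of the barrier-then-time pattern, using mpi4py's `Barrier` directly in place of the internal `mpi_barrier` helper; the `run_collective` callable and function name are illustrative, not part of the library's API.

```python
import torch
from mpi4py import MPI


def time_collective(run_collective, warmup: int = 3, repeat: int = 10) -> float:
    """Time a collective kernel, aligning all ranks before each phase."""
    comm = MPI.COMM_WORLD
    stream = torch.cuda.current_stream()

    comm.Barrier()  # all ranks enter warmup together
    for _ in range(warmup):
        run_collective()

    comm.Barrier()  # re-align so the timed region starts together on every rank
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record(stream)
    for _ in range(repeat):
        run_collective()
    end.record(stream)
    end.synchronize()
    return start.elapsed_time(end) / repeat  # mean milliseconds per call
```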

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py (167 additions, 0 deletions)

```diff
@@ -8,6 +8,7 @@
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy

 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
@@ -1036,6 +1037,172 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)


+class AllReduceRunner(TunableRunner):
+    all_support_ops = {
+        AllReduceFusionOp.NONE.value,
+        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
+    }
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0,
+            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self._group = group
+        self._eps = eps
+        self._trigger_completion_at_end = trigger_completion_at_end
+
+    def __hash__(self):
+        return hash((self.tp_size, self.op))
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_tactics = [
+            AllReduceStrategy.NCCL.value,
+            AllReduceStrategy.ONESHOT.value,
+        ]
+        # TWOSHOT splits the token dimension across ranks, so it needs
+        # at least one token per rank.
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_tactics.append(AllReduceStrategy.TWOSHOT.value)
+        return valid_tactics
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> List[torch.Tensor]:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self._group,
+            tactic,
+            self.op,
+            self._eps,
+            self._trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        tp_size,
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    if best_tactic == -1:
+        best_tactic = AllReduceStrategy.NCCL.value
+
+    return torch.ops.trtllm.allreduce(
+        input,
+        residual,
+        norm_weight,
+        scale,
+        bias,
+        workspace,
+        group,
+        best_tactic,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
```
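
The `tuning_config` above buckets the token dimension (dim 0 of the input) down to the nearest entry in the power-of-two list via `last_positive_power_of_2`, and the `ConstraintSpec` pins the residual's dim 0 to the input's, so one cache entry covers a whole range of batch sizes. A sketch of the bucketing this implies, with a hypothetical reimplementation of the helper (the real one is defined elsewhere in the codebase):

```python
def last_positive_power_of_2(x: int) -> int:
    """Hypothetical stand-in: round x down to the nearest power of two,
    never below 1, matching the bucket list in tuning_config."""
    if x < 1:
        return 1
    return 1 << (x.bit_length() - 1)


# A 100-token batch profiles under, and later hits, the 64-token bucket;
# exact powers of two map to themselves.
assert last_positive_power_of_2(100) == 64
assert last_positive_power_of_2(8192) == 8192
```

Note that `get_valid_tactics` sees these bucketed shapes during profiling, which is why TWOSHOT is only offered once the bucketed token count reaches `tp_size`.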

tensorrt_llm/_torch/distributed/ops.py (33 additions, 16 deletions)

```diff
@@ -505,7 +505,6 @@ def __init__(self,
         self._disable_mpi = mpi_disabled()

         self.all_reduce_op = torch.ops.trtllm.allreduce_pg if self._disable_mpi else torch.ops.trtllm.allreduce
-
         if self.mapping.tp_size > 1:
             # When Strategy is UB, it is guaranteed that the workspace is not used.
             if self.strategy != AllReduceStrategy.UB:
@@ -574,6 +573,7 @@ def forward(
         input = input.contiguous()  # Underlying op requires contiguous input

         allreduce_strategy = self.strategy
+
         if all_reduce_params is None:
             all_reduce_params = AllReduceParams()

@@ -598,21 +598,38 @@
             "pg": pg.boxed(),
         }

-        output = self.all_reduce_op(
-            input=input,
-            residual=all_reduce_params.residual,
-            norm_weight=all_reduce_params.norm_weight,
-            scale=all_reduce_params.scale,
-            bias=all_reduce_params.bias,
-            workspace=self.workspace,
-            group=self.mapping.tp_group,
-            strategy=allreduce_strategy,
-            op=all_reduce_params.fusion_op,
-            eps=all_reduce_params.eps,
-            trigger_completion_at_end=all_reduce_params.
-            trigger_completion_at_end,
-            **additional_args,
-        )
+        if self.strategy == AllReduceStrategy.AUTOTUNE:
+            output = torch.ops.trtllm.tunable_allreduce(
+                input=input,
+                residual=all_reduce_params.residual,
+                norm_weight=all_reduce_params.norm_weight,
+                scale=all_reduce_params.scale,
+                bias=all_reduce_params.bias,
+                workspace=self.workspace,
+                group=self.mapping.tp_group,
+                strategy=allreduce_strategy,
+                op=all_reduce_params.fusion_op,
+                eps=all_reduce_params.eps,
+                tp_size=self.mapping.tp_size,
+                trigger_completion_at_end=all_reduce_params.
+                trigger_completion_at_end,
+            )
+        else:
+            output = self.all_reduce_op(
+                input=input,
+                residual=all_reduce_params.residual,
+                norm_weight=all_reduce_params.norm_weight,
+                scale=all_reduce_params.scale,
+                bias=all_reduce_params.bias,
+                workspace=self.workspace,
+                group=self.mapping.tp_group,
+                strategy=allreduce_strategy,
+                op=all_reduce_params.fusion_op,
+                eps=all_reduce_params.eps,
+                trigger_completion_at_end=all_reduce_params.
+                trigger_completion_at_end,
+                **additional_args,
+            )

         return output if len(output) > 1 else output[0]
```
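
Opting in is then a constructor-level choice. A minimal sketch, assuming the usual `Mapping` setup and that `AllReduce` is importable from `tensorrt_llm._torch.distributed`; constructor arguments beyond `mapping` and `strategy` may differ:

```python
import torch

from tensorrt_llm._torch.distributed import AllReduce  # assumed export path
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

# One module per rank in a 4-way TP group; rank is per-process.
mapping = Mapping(world_size=4, tp_size=4, rank=0)
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.AUTOTUNE)

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")
y = allreduce(x)  # forward() now routes through trtllm::tunable_allreduce
```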

tensorrt_llm/functional.py (1 addition, 0 deletions)

```diff
@@ -3883,6 +3883,7 @@ class AllReduceStrategy(IntEnum):
     LOWPRECISION = 6
     MNNVL = 7
     NCCL_SYMMETRIC = 8
+    AUTOTUNE = 9


 class AllReduceFusionOp(IntEnum):
```
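
Because `AllReduceStrategy` is an `IntEnum`, the new member resolves by name or by value, which is how a string-typed config would typically reach it:

```python
from tensorrt_llm.functional import AllReduceStrategy

assert AllReduceStrategy["AUTOTUNE"] is AllReduceStrategy.AUTOTUNE  # by name
assert AllReduceStrategy(9) is AllReduceStrategy.AUTOTUNE           # by value
assert AllReduceStrategy.AUTOTUNE == 9                              # IntEnum compares to int
```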
