@@ -8,6 +8,7 @@
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
 from tensorrt_llm import deep_gemm
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy
 
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
@@ -1036,6 +1037,173 @@ def _(
     return x.new_empty((b, d), dtype=o_dtype)
 
 
+class AllReduceRunner(TunableRunner):
+    all_support_ops = {
+        AllReduceFusionOp.NONE.value,
+        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
+    }
+
+    tuning_config = TuningConfig(
+        dynamic_tensor_specs=(DynamicTensorSpec(
+            0, 0,
+            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
+            last_positive_power_of_2), ),
+        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
+    )
+
+    def __init__(
+        self,
+        tp_size: int,
+        group: List[int],
+        op: int,
+        eps: float,
+        trigger_completion_at_end: bool,
+    ):
+        self.tp_size = tp_size
+        self.op = op
+        self._group = group
+        self._eps = eps
+        self._trigger_completion_at_end = trigger_completion_at_end
+
+    def __hash__(self):
+        return hash((self.tp_size, self.op))
+
+    def get_valid_tactics(
+        self,
+        inputs: List[torch.Tensor],
+        profile: OptimizationProfile,
+        **kwargs,
+    ) -> List[int]:
+        valid_tactics = [
+            AllReduceStrategy.NCCL.value,
+            AllReduceStrategy.ONESHOT.value,
+        ]
+        if inputs[0].shape[0] >= self.tp_size:
+            valid_tactics.append(AllReduceStrategy.TWOSHOT.value)
+        return valid_tactics
+
+    def forward(
+        self,
+        inputs: List[torch.Tensor],
+        tactic: int = -1,
+    ) -> torch.Tensor:
+        input, residual, norm_weight, scale, bias, workspace = inputs
+        if tactic == -1:
+            tactic = AllReduceStrategy.NCCL.value
+
+        return torch.ops.trtllm.allreduce(
+            input,
+            residual,
+            norm_weight,
+            scale,
+            bias,
+            workspace,
+            self._group,
+            tactic,
+            self.op,
+            self._eps,
+            self._trigger_completion_at_end,
+        )
+
+
+@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
+def tunable_allreduce(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+
+    tuner = AutoTuner.get()
+
+    allreduce_runner = AllReduceRunner(
+        tp_size,
+        group,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+    _, best_tactic = tuner.choose_one(
+        "trtllm::tunable_allreduce::allreduce",
+        [allreduce_runner],
+        AllReduceRunner.tuning_config,
+        [input, residual, norm_weight, scale, bias, workspace],
+    )
+
+    if best_tactic == -1:
+        best_tactic = AllReduceStrategy.NCCL.value
+
+    return torch.ops.trtllm.allreduce(
+        input,
+        residual,
+        norm_weight,
+        scale,
+        bias,
+        workspace,
+        group,
+        best_tactic,
+        op,
+        eps,
+        trigger_completion_at_end,
+    )
+
+
+@tunable_allreduce.register_fake
+def _(
+    input: torch.Tensor,
+    residual: Optional[torch.Tensor],
+    norm_weight: Optional[torch.Tensor],
+    scale: Optional[torch.Tensor],
+    bias: Optional[torch.Tensor],
+    workspace: Optional[torch.Tensor],
+    group: List[int],
+    strategy: int,
+    op: int,
+    eps: float,
+    tp_size: int,
+    trigger_completion_at_end: bool,
+) -> List[torch.Tensor]:
+    if op == int(AllReduceFusionOp.NONE):
+        return [torch.empty_like(input)]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
+        norm_out = torch.empty_like(input)
+        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_out, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        residual_out = torch.empty_like(input)
+        return [quant_fp4, scale_fp4, residual_out]
+    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
+        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
+        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
+        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
+        norm_out = torch.empty_like(input)
+        residual_out = torch.empty_like(input)
+        return [norm_out, quant_fp4, scale_fp4, residual_out]
+    else:
+        return [torch.empty_like(input)]
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
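
For context, a minimal usage sketch of the new op (not part of the diff above). It assumes a 2-rank tensor-parallel job whose ranks are [0, 1], that autotune is the tuning context manager from tensorrt_llm._torch.autotuner, and that workspace=None is acceptable for the non-fused path; the RESIDUAL_RMS_NORM fusions need the IPC workspace set up elsewhere in TensorRT-LLM.

import torch

from tensorrt_llm._torch.autotuner import autotune  # assumed location of the tuning context
from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy

hidden = torch.randn(32, 4096, dtype=torch.bfloat16, device="cuda")

# Inside the autotune() context the first call profiles the valid strategies
# (NCCL / ONESHOT / TWOSHOT) for this token-count bucket and caches the winner;
# later calls with similar shapes replay the cached choice.
with autotune():
    outputs = torch.ops.trtllm.tunable_allreduce(
        hidden,                        # input
        None,                          # residual (unused for FusionOp.NONE)
        None,                          # norm_weight
        None,                          # scale
        None,                          # bias
        None,                          # workspace (fused paths need the IPC workspace)
        [0, 1],                        # group: participating ranks
        AllReduceStrategy.NCCL.value,  # strategy hint; the tuned choice is what runs
        AllReduceFusionOp.NONE.value,  # op: plain all-reduce, no fusion
        1e-6,                          # eps (only used by the RMS-norm fusions)
        2,                             # tp_size
        False,                         # trigger_completion_at_end
    )
reduced = outputs[0]

Note that the strategy argument is accepted for schema compatibility but is not consumed by tunable_allreduce itself; the tactic returned by the tuner is what gets passed to the underlying allreduce.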
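
One detail worth calling out from the tuning_config above: the token dimension (dim 0 of the first input) is bucketed with last_positive_power_of_2, and the residual's dim 0 is constrained to match, so a single profiled tactic covers a whole range of batch sizes. A rough illustration of that rounding, assuming last_positive_power_of_2 simply returns the largest power of two not exceeding its argument (an illustrative re-implementation, not the library's):

def last_positive_power_of_2(x: int) -> int:
    # Assumed behavior: the largest power of two that is <= x (for x >= 1).
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

# Token counts are snapped down to the nearest power of two, so e.g. a batch of
# 300 tokens reuses the tactic profiled for the 256-token bucket.
assert last_positive_power_of_2(300) == 256
assert last_positive_power_of_2(1024) == 1024
assert last_positive_power_of_2(8192) == 8192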