|
8 | 8 | import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils |
9 | 9 | from tensorrt_llm import deep_gemm |
10 | 10 | from tensorrt_llm._utils import get_sm_version |
| 11 | +from tensorrt_llm.functional import AllReduceFusionOp, AllReduceStrategy |
11 | 12 |
|
12 | 13 | from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec, |
13 | 14 | OptimizationProfile, TunableRunner, TuningConfig) |
@@ -1122,6 +1123,161 @@ def _( |
1122 | 1123 | return x.new_empty((b, d), dtype=o_dtype) |
1123 | 1124 |
|
1124 | 1125 |
|
class AllReduceRunner(TunableRunner):
    """Tunable runner that executes ``torch.ops.trtllm.allreduce`` so the
    AutoTuner can benchmark the available all-reduce strategies for a
    fixed fusion op and tensor-parallel group.
    """

    # Fusion patterns this runner is able to execute.
    all_support_ops = {
        AllReduceFusionOp.NONE.value,
        AllReduceFusionOp.RESIDUAL_RMS_NORM.value,
    }

    # Tune over the token dimension (input 0, dim 0) across powers of two;
    # the residual (input 1) must keep the same leading dim as the input.
    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0,
            (8192, 4096, 2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1),
            last_positive_power_of_2), ),
        constraint_specs=(ConstraintSpec(1, 0, lambda shapes: shapes[0][0]), ),
    )

    def __init__(
        self,
        tp_size: int,
        group: List[int],
        op: int,
        eps: float,
        trigger_completion_at_end: bool,
    ):
        # tp_size and op participate in the tuner cache key (see __hash__);
        # the remaining fields are plumbed through to the kernel call.
        self.tp_size = tp_size
        self.op = op
        self._group = group
        self._eps = eps
        self._trigger_completion_at_end = trigger_completion_at_end

    def __hash__(self):
        # NOTE(review): no matching __eq__ is defined; presumably the tuner
        # only needs a hash-based cache key — confirm against TunableRunner.
        return hash((self.tp_size, self.op))

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[int]:
        """Return the candidate strategies for the given input shapes."""
        tactics = [
            AllReduceStrategy.NCCL.value,
            AllReduceStrategy.ONESHOT.value,
        ]
        # TWOSHOT needs at least one token per rank to partition the work.
        num_tokens = inputs[0].shape[0]
        if num_tokens >= self.tp_size:
            tactics.append(AllReduceStrategy.TWOSHOT.value)
        return tactics

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic: int = -1,
    ) -> torch.Tensor:
        """Run the all-reduce with the chosen strategy (NCCL if untuned)."""
        strategy = AllReduceStrategy.NCCL.value if tactic == -1 else tactic
        input, residual, norm_weight, scale, bias, workspace = inputs
        return torch.ops.trtllm.allreduce(
            input,
            residual,
            norm_weight,
            scale,
            bias,
            workspace,
            self._group,
            strategy,
            self.op,
            self._eps,
            self._trigger_completion_at_end,
        )
| 1194 | + |
@torch.library.custom_op("trtllm::tunable_allreduce", mutates_args=())
def tunable_allreduce(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    """Auto-tuned all-reduce: picks the fastest strategy via the AutoTuner
    and runs ``trtllm.allreduce`` with it.

    NOTE(review): the ``strategy`` argument is part of the op schema but is
    not consulted here — the tuner's choice wins; confirm this is intended.
    """
    runner = AllReduceRunner(
        tp_size,
        group,
        op,
        eps,
        trigger_completion_at_end,
    )

    tensor_args = [input, residual, norm_weight, scale, bias, workspace]
    _, best_tactic = AutoTuner.get().choose_one(
        "trtllm::tunable_allreduce::allreduce",
        [runner],
        AllReduceRunner.tuning_config,
        tensor_args,
    )

    return runner(tensor_args, tactic=best_tactic)
| 1232 | + |
| 1233 | + |
@tunable_allreduce.register_fake
def _(
    input: torch.Tensor,
    residual: Optional[torch.Tensor],
    norm_weight: Optional[torch.Tensor],
    scale: Optional[torch.Tensor],
    bias: Optional[torch.Tensor],
    workspace: Optional[torch.Tensor],
    group: List[int],
    strategy: int,
    op: int,
    eps: float,
    tp_size: int,
    trigger_completion_at_end: bool,
) -> List[torch.Tensor]:
    # Fix: return annotation was `-> torch.Tensor`, but every branch returns
    # a list — it must match the real op's declared `List[torch.Tensor]` so
    # fake/meta tracing agrees with the registered schema.
    """Fake (meta) implementation of ``trtllm::tunable_allreduce``.

    Allocates empty tensors mirroring the output layout of each supported
    fusion op so shape propagation works without running the kernel:
    NONE -> [out]; RESIDUAL_RMS_NORM -> [norm, residual]; the FP8/NVFP4
    variants add quantized outputs (and scales for NVFP4).
    """
    if op == int(AllReduceFusionOp.NONE):
        return [torch.empty_like(input)]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM):
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_FP8):
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_FP8):
        norm_out = torch.empty_like(input)
        quant_out = torch.empty_like(input, dtype=torch.float8_e4m3fn)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_out, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_QUANT_NVFP4):
        # NOTE(review): only `fp8_utils` is visibly imported in this chunk;
        # `fp4_utils` must be imported earlier in the file — verify, else
        # this branch raises NameError at trace time.
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        residual_out = torch.empty_like(input)
        return [quant_fp4, scale_fp4, residual_out]
    elif op == int(AllReduceFusionOp.RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4):
        fp4_shape, scale_shape = fp4_utils.get_fp4_shape(input.shape, 16)
        quant_fp4 = input.new_empty(fp4_shape, dtype=torch.uint8)
        scale_fp4 = input.new_empty(scale_shape, dtype=torch.uint8)
        norm_out = torch.empty_like(input)
        residual_out = torch.empty_like(input)
        return [norm_out, quant_fp4, scale_fp4, residual_out]
    else:
        # Unknown fusion op: fall back to a single output-shaped tensor.
        return [torch.empty_like(input)]
| 1280 | + |
1125 | 1281 | def get_event(event_idx: int): |
1126 | 1282 | from ..utils import get_model_extra_attrs |
1127 | 1283 | extra_attrs = get_model_extra_attrs() |
|
0 commit comments