diff --git a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py index e190af5c5d..01b46952bf 100644 --- a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py +++ b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py @@ -71,7 +71,6 @@ def generic_gemm_fl( assert not gelu and gelu_in is None, "Triton-Based General Gemm do not support gelu now" assert quantizer is None, "Triton-Based General Gemm do not support quantization now" - assert bias is None, "Triton-Based General Gemm do not support bias now" alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) @@ -95,7 +94,18 @@ def generic_gemm_fl( A_comp = A.T if transA else A B_comp = B.T if transB else B - out1 = flag_gems.mm(B_comp, A_comp) + bias_grad = None + if grad: + out1 = flag_gems.mm(B_comp, A_comp) + if bias is not None: + bias_grad = flag_gems.sum_dim(B, dim=[0]) + else: + # NOTE(wqq) flag_gems.addmm uses beta for bias scaling (Y = alpha * WX + beta * bias), + # unlike the beta here (for scaling D). Always set to 1. + if bias is not None: + out1 = flag_gems.addmm(bias, B_comp, A_comp, beta=1, alpha=alpha) + else: + out1 = flag_gems.mm(B_comp, A_comp) if shape_b_changed: out1 = out1.view(s, b, -1) @@ -104,7 +114,6 @@ def generic_gemm_fl( if torch_out_dtype is not None and out1.dtype != torch_out_dtype: out1 = out1.to(torch_out_dtype) - bias_grad = None gelu_input = None extra_output_ret = None