From 2ab41b8ad8c6fa57075652cce410e7d16e5a1598 Mon Sep 17 00:00:00 2001 From: wenqingqian Date: Wed, 27 May 2026 15:03:24 +0800 Subject: [PATCH 1/2] bias linear --- .../plugin/core/backends/flagos/impl/gemm.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py index e190af5c5d..77576d13f5 100644 --- a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py +++ b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py @@ -71,7 +71,7 @@ def generic_gemm_fl( assert not gelu and gelu_in is None, "Triton-Based General Gemm do not support gelu now" assert quantizer is None, "Triton-Based General Gemm do not support quantization now" - assert bias is None, "Triton-Based General Gemm do not support bias now" + alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate) @@ -95,7 +95,18 @@ def generic_gemm_fl( A_comp = A.T if transA else A B_comp = B.T if transB else B - out1 = flag_gems.mm(B_comp, A_comp) + bias_grad = None + if grad: + out1 = flag_gems.mm(B_comp, A_comp) + if bias is not None: + bias_grad = flag_gems.sum_dim(B, dim=[0]) + else: + # NOTE(wqq): in flag_gems.addmm, beta scales the bias (Y = alpha * WX + beta * bias), + # whereas the beta in this function scales D, so always pass beta=1 to addmm.
+ if bias is not None: + out1 = flag_gems.addmm(bias, B_comp, A_comp, beta=1, alpha=alpha) + else: + out1 = flag_gems.mm(B_comp, A_comp) if shape_b_changed: out1 = out1.view(s, b, -1) @@ -104,7 +115,6 @@ def generic_gemm_fl( if torch_out_dtype is not None and out1.dtype != torch_out_dtype: out1 = out1.to(torch_out_dtype) - bias_grad = None gelu_input = None extra_output_ret = None From 4677c76f3a7a39493379778c4018e55dc867adab Mon Sep 17 00:00:00 2001 From: wenqingqian Date: Wed, 27 May 2026 15:22:08 +0800 Subject: [PATCH 2/2] minor --- transformer_engine/plugin/core/backends/flagos/impl/gemm.py | 1 - 1 file changed, 1 deletion(-) diff --git a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py index 77576d13f5..01b46952bf 100644 --- a/transformer_engine/plugin/core/backends/flagos/impl/gemm.py +++ b/transformer_engine/plugin/core/backends/flagos/impl/gemm.py @@ -72,7 +72,6 @@ def generic_gemm_fl( assert not gelu and gelu_in is None, "Triton-Based General Gemm do not support gelu now" assert quantizer is None, "Triton-Based General Gemm do not support quantization now" - alpha = validate_gemm_scale(alpha, True) beta = validate_gemm_scale(beta, accumulate)