diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..bd2a5c76e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ea69378f4 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=128,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 2, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No 
newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..b9a717cee --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 2}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json index 5c0dab42b..37ba845fa 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 2}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, 
"BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..286de4928 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=256,N=7168,expert_num=257,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 2}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 
index 000000000..cd4b2b79e --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=4096,N=192,expert_num=128,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 5}, "8192": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..fe56e1c44 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=512,N=7168,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..25333e743 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..ed56a6fc7 --- /dev/null +++ 
b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=1024,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..5e2f44cb0 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..bc763e8bc --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": 
{"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..457d72dc8 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json index 394ce3193..a4f26860b 100644 --- a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -1 +1 @@ -{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, 
"BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 8, "num_stages": 5}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 8, "num_stages": 4}} \ No newline at end of file +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json new file mode 100644 index 000000000..f1a0658ba --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=7168,N=512,expert_num=257,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=true}_NVIDIA_H200.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 5}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 2, "num_warps": 4, "num_stages": 5}, "512": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1, "num_warps": 4, "num_stages": 3}, "4096": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}, "8192": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, 
"GROUP_SIZE_M": 4, "num_warps": 4, "num_stages": 3}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 000000000..217515264 --- /dev/null +++ b/lightllm/common/all_kernel_configs/grouped_moe_gemm_kernel/{K=96,N=4096,expert_num=128,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1 @@ +{"1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 3}, "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 1, "num_warps": 2, "num_stages": 5}, "256": {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 4, "num_warps": 2, "num_stages": 4}, "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "1024": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "4096": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}, "8192": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "num_warps": 4, "num_stages": 4}} \ No newline at end of file diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py index 131e65f54..5a4e84f82 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py @@ -16,6 +16,7 @@ def __init__( e_score_correction_bias_name: str, weight_prefix: str, n_routed_experts: int, + num_fused_shared_experts: int, split_inter_size: int, data_type: torch.dtype, network_config: Dict[str, Any], @@ -34,7 +35,10 @@ def __init__( self.e_score_correction_bias_name = e_score_correction_bias_name self.weight_prefix = weight_prefix - self.n_routed_experts = n_routed_experts + assert num_fused_shared_experts in [0, 1], "num_fused_shared_experts can only support 0 or 1 now." 
+ self.n_routed_experts = n_routed_experts + num_fused_shared_experts + self.num_fused_shared_experts = num_fused_shared_experts + self.routed_scaling_factor = network_config.get("routed_scaling_factor", 1.0) self.split_inter_size = split_inter_size self.data_type_ = data_type self.tp_rank_ = get_current_rank_in_dp() @@ -63,7 +67,11 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=self.scoring_func, + num_fused_shared_experts=self.num_fused_shared_experts, ) + if self.num_fused_shared_experts > 0: + topk_ids[:, -1] = self.n_routed_experts - 1 + topk_weights[:, -1] = 1.0 / self.routed_scaling_factor w1, w1_scale = self.w1 w2, w2_scale = self.w2 use_fp8_w8a8 = self.quant_method is not None @@ -93,16 +101,18 @@ def _fuse(self): and None not in self.experts_gate_projs and None not in self.w2_list ): - w1_list = [] + gate_out_dim, gate_in_dim = self.experts_gate_projs[0].shape + up_out_dim, up_in_dim = self.experts_up_projs[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_projs[0].dtype + total_expert_num = self.n_routed_experts + + w1 = torch.empty((total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu") + for i_experts in range(self.n_routed_experts): - expert_gate_up_proj = torch.cat( - [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0 - ) - expert_gate_up_proj = expert_gate_up_proj - w1_list.append(expert_gate_up_proj) - - inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1] - w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size) + w1[i_experts, 0:gate_out_dim:, :] = self.experts_gate_projs[i_experts] + w1[i_experts, gate_out_dim:, :] = self.experts_up_projs[i_experts] + inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1] w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size) if not self.quantized_weight and self.quant_method is not None: @@ -123,17 +133,19 @@ def _fuse_weight_scale(self): and None not in self.experts_gate_proj_scales and None not in self.w2_scale_list ): - w1_scale_list = [] - for i_experts in range(self.n_routed_experts): - expert_gate_up_proj_scale = torch.cat( - [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0 - ) - w1_scale_list.append(expert_gate_up_proj_scale) - - inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1] - w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view( - len(w1_scale_list), inter_shape, hidden_size + gate_out_dim, gate_in_dim = self.experts_gate_proj_scales[0].shape + up_out_dim, up_in_dim = self.experts_up_proj_scales[0].shape + assert gate_in_dim == up_in_dim + dtype = self.experts_gate_proj_scales[0].dtype + total_expert_num = self.n_routed_experts + + w1_scale = torch.empty( + (total_expert_num, gate_out_dim + up_out_dim, gate_in_dim), dtype=dtype, device="cpu" ) + + for i_experts in range(self.n_routed_experts): + w1_scale[i_experts, 0:gate_out_dim:, :] = self.experts_gate_proj_scales[i_experts] + w1_scale[i_experts, gate_out_dim:, :] = self.experts_up_proj_scales[i_experts] inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1] w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view( len(self.w2_scale_list), inter_shape, hidden_size diff --git a/lightllm/common/fused_moe/grouped_topk.py 
b/lightllm/common/fused_moe/grouped_topk.py index e8eae1b15..b0e7f51a5 100644 --- a/lightllm/common/fused_moe/grouped_topk.py +++ b/lightllm/common/fused_moe/grouped_topk.py @@ -208,6 +208,7 @@ def triton_grouped_topk( topk_group: int = 0, scoring_func: str = "softmax", group_score_used_topk_num=2, + num_fused_shared_experts: int = 0, ): if correction_bias is not None: @@ -222,8 +223,8 @@ def triton_grouped_topk( dtype = torch.float32 scores_buffer = torch.empty((token_num, total_expert_num), dtype=dtype, device="cuda") - out_topk_weights = torch.empty((token_num, topk), dtype=torch.float32, device="cuda") - out_topk_ids = torch.empty((token_num, topk), dtype=torch.long, device="cuda") + out_topk_weights = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.float32, device="cuda") + out_topk_ids = torch.empty((token_num, topk + num_fused_shared_experts), dtype=torch.long, device="cuda") assert total_expert_num % num_expert_group == 0 diff --git a/lightllm/common/fused_moe/moe_kernel_configs.py b/lightllm/common/fused_moe/moe_kernel_configs.py index 0b107ede3..3b47b14c3 100644 --- a/lightllm/common/fused_moe/moe_kernel_configs.py +++ b/lightllm/common/fused_moe/moe_kernel_configs.py @@ -42,12 +42,12 @@ def try_to_get_best_config( else: if M <= expert_num: config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 32, "num_warps": 4, - "num_stages": 1, + "num_stages": 3, } else: config = { diff --git a/lightllm/common/fused_moe/moe_silu_and_mul.py b/lightllm/common/fused_moe/moe_silu_and_mul.py index 3f6bdb44f..5c62dbc90 100644 --- a/lightllm/common/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/fused_moe/moe_silu_and_mul.py @@ -54,6 +54,54 @@ def _silu_and_mul_kernel( ) +@triton.jit +def _silu_and_mul_kernel_fast( + input_ptr, + output_ptr, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + BLOCK_N: tl.constexpr, + NEED_MASK: tl.constexpr, +): + stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) + stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) + + cur_batch = tl.program_id(0) + pid = tl.program_id(1) + n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N) + + up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n) + gate_offsets = cur_batch * stride_input_m + n_offsets[None, :] + res_offsets = cur_batch * stride_output_m + n_offsets[None, :] + if NEED_MASK: + mask = n_offsets[None, :] < size_n + else: + mask = True + + up = tl.load( + input_ptr + up_offsets, + mask=mask, + other=0.0, + ) + gate = tl.load( + input_ptr + gate_offsets, + mask=mask, + other=0.0, + ).to(tl.float32) + + gate = gate / (1 + tl.exp(-gate)) + gate = gate.to(input_ptr.dtype.element_ty) + + tl.store( + output_ptr + res_offsets, + up * gate, + mask=mask, + ) + + def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): assert input.is_contiguous() assert output.is_contiguous() @@ -68,6 +116,26 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config): if not run_config: run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype)) + if size_m <= 4096: + BLOCK_N = run_config["BLOCK_N"] + grid = ( + size_m, + triton.cdiv(size_n, BLOCK_N), + ) + NEED_MASK = size_n % BLOCK_N != 0 + _silu_and_mul_kernel_fast[grid]( + input, + output, + stride_input_m, + stride_input_n, + stride_output_m, + stride_output_n, + size_n, + 
BLOCK_N=BLOCK_N, + NEED_MASK=NEED_MASK, + ) + return + BLOCK_M = run_config["BLOCK_M"] BLOCK_N = run_config["BLOCK_N"] num_warps = run_config["num_warps"] diff --git a/lightllm/common/fused_moe/topk_select.py b/lightllm/common/fused_moe/topk_select.py index 4e50576a9..6ae44e295 100644 --- a/lightllm/common/fused_moe/topk_select.py +++ b/lightllm/common/fused_moe/topk_select.py @@ -175,6 +175,7 @@ def select_experts( num_expert_group: Optional[int] = None, scoring_func: str = "softmax", custom_routing_function: Optional[Callable] = None, + num_fused_shared_experts: int = 0, ): from lightllm.common.fused_moe.topk_select import fused_topk from lightllm.common.fused_moe.grouped_topk import triton_grouped_topk @@ -210,6 +211,7 @@ def select_experts( topk_group=topk_group, scoring_func=scoring_func, group_score_used_topk_num=group_score_topk_num, + num_fused_shared_experts=num_fused_shared_experts, ) elif custom_routing_function is None: diff --git a/lightllm/common/quantization/deepgemm_quant.py b/lightllm/common/quantization/deepgemm_quant.py index 8d14805ad..7b595b6b4 100644 --- a/lightllm/common/quantization/deepgemm_quant.py +++ b/lightllm/common/quantization/deepgemm_quant.py @@ -54,12 +54,12 @@ def apply(self, input_tensor, weights, bias=None, out=None, workspace=None, use_ m, k = input_tensor.shape n = weights[0].shape[1] if input_scale is None: - input_scale = torch.empty((m, k // self.block_size), dtype=torch.float32, device=input_tensor.device) qinput_tensor = self.cache_manager.alloc_tensor( (m, k), qweight.dtype, device=qweight.device, is_graph_out=False ) - per_token_group_quant_fp8(input_tensor, self.block_size, qinput_tensor, input_scale) - input_scale = tma_align_input_scale(input_scale) + _, input_scale = per_token_group_quant_fp8( + input_tensor, self.block_size, qinput_tensor, column_major_scales=True, scale_tma_aligned=True + ) if out is None: if use_custom_tensor_mananger: diff --git a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py index 120880143..f03cb2f4e 100644 --- a/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py +++ b/lightllm/common/quantization/triton_quant/fp8/fp8act_quant_kernel.py @@ -6,7 +6,7 @@ from lightllm.utils.sgl_utils import HAS_SGL_KERNEL, sgl_ops from frozendict import frozendict from functools import lru_cache -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple try: from deep_gemm import ceil_div @@ -109,17 +109,46 @@ def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, x_q: torch.Tensor, - x_s: torch.Tensor, + x_s: torch.Tensor = None, eps: float = 1e-10, dtype: torch.dtype = torch.float8_e4m3fn, + column_major_scales: bool = False, + scale_tma_aligned: bool = False, + alloc_func: Callable = torch.empty, ): + # Adapted from + # https://github.com/sgl-project/sglang/blob/7e257cd666c0d639626487987ea8e590da1e9395/python/sglang/srt/layers/quantization/fp8_kernel.py#L290 if HAS_SGL_KERNEL: finfo = torch.finfo(dtype) fp8_max, fp8_min = finfo.max, finfo.min + if column_major_scales: + if scale_tma_aligned: + # aligned to 4 * sizeof(float) + aligned_size = (x.shape[-2] + 3) // 4 * 4 + x_s = alloc_func( + x.shape[:-2] + (x.shape[-1] // group_size, aligned_size), + device=x.device, + dtype=torch.float32, + ).permute(-1, -2)[: x.shape[-2], :] + else: + x_s = alloc_func( + (x.shape[-1] // group_size,) + x.shape[:-1], + device=x.device, + dtype=torch.float32, + ).permute(-1, 
-2) + else: + if x_s is None: + x_s = alloc_func( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) sgl_ops.sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, 1e-10, fp8_min, fp8_max, False) else: lightllm_per_token_group_quant_fp8(x, group_size, x_q, x_s, eps=1e-10, dtype=torch.float8_e4m3fn) + return x_q, x_s + # copy from # https://github.com/deepseek-ai/DeepGEMM/blob/bd2a77552886b98c205af12f8d7d2d61247c4b27/deep_gemm/jit_kernels/utils.py#L58 @@ -229,8 +258,8 @@ def test_per_token_group_quant_fp8(): x_q = torch.randn((1024, 8192)).cuda().to(torch.float8_e4m3fn) # x_s = torch.randn((1024, 8192 // group_size), dtype=torch.float32).cuda() - x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() - per_token_group_quant_fp8(x, group_size, x_q, x_s) + # x_s = torch.randn((8192 // group_size, 1024 + 10), dtype=torch.float32).cuda().t() + _, x_s = per_token_group_quant_fp8(x, group_size, x_q, None, column_major_scales=True) x_s = x_s[:1024] th_x_q, th_x_s = torch_quant(x, group_size) print("th_x_s - x_s", torch.abs(th_x_s - x_s.reshape(-1)).max()) @@ -238,4 +267,5 @@ def test_per_token_group_quant_fp8(): if __name__ == "__main__": - test_tma_align() + test_per_token_group_quant_fp8() + # test_tma_align() diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ba752a4e8..5d5cbc55f 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -19,7 +19,6 @@ from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding_fp8 import gqa_token_decode_attention_flash_decoding_fp8 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward -from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd from lightllm.models.deepseek2.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.models.deepseek2.flashinfer_struct import Deepseek2FlashInferStateInfo @@ -666,7 +665,8 @@ def _moe_ffn( hidden_states = input.view(-1, self.embed_dim_) num_tokens, hidden_dim = hidden_states.shape - if self.n_shared_experts is not None: + # if fused_shared_experts is not enabled, compute shared_output + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) router_logits = layer_weight.moe_gate.mm(hidden_states) @@ -682,7 +682,7 @@ def _moe_ffn( hidden_states.mul_(self.routed_scaling_factor) - if self.n_shared_experts is not None: + if self.n_shared_experts is not None and layer_weight.num_fused_shared_experts == 0: hidden_states.add_(shared_output) return hidden_states.view(num_tokens, hidden_dim) diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py index 7a9f3c150..dc2b1e285 100644 --- a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py @@ -3,7 +3,7 @@ import math import numpy as np from lightllm.common.basemodel import TransformerLayerWeight -from lightllm.utils.envs_utils import enable_env_vars +from lightllm.utils.envs_utils import 
enable_env_vars, get_env_start_args from lightllm.common.basemodel.layer_weights.meta_weights import ( ROWMMWeight, MultiROWMMWeight, @@ -39,6 +39,9 @@ def _parse_config(self): self.v_head_dim = self.network_config_["v_head_dim"] self.num_attention_heads = self.network_config_["num_attention_heads"] self.kv_lora_rank = self.network_config_["kv_lora_rank"] + self.num_fused_shared_experts = 0 + if get_env_start_args().enable_fused_shared_experts and self.is_moe: + self.num_fused_shared_experts = self.network_config_.get("n_shared_experts", 0) def _init_weight_names(self): if self.q_lora_rank is None: @@ -96,8 +99,25 @@ def _load_vb_scale(self, kv_b_proj_scale_, block_size): )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1) return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype) + def _rename_shared_experts(self, weights, weight_scale_suffix): + old_prefix = f"model.layers.{self.layer_num_}.mlp.shared_experts" + new_prefix = f"model.layers.{self.layer_num_}.mlp.experts" + proj_names = ["gate_proj", "down_proj", "up_proj"] + for i in range(self.num_fused_shared_experts): + expert_id = self.n_routed_experts + i + for proj in proj_names: + weight_tensor = weights.get(f"{old_prefix}.{proj}.weight") + if weight_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}.weight"] = weight_tensor + if self.quant_cfg.quantized_weight: + scale_tensor = weights.get(f"{old_prefix}.{proj}." + weight_scale_suffix) + if scale_tensor is not None: + weights[f"{new_prefix}.{expert_id}.{proj}." + weight_scale_suffix] = scale_tensor + def load_hf_weights(self, weights): kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj") + if self.quant_cfg.quantized_weight: + weight_scale_suffix = kv_b_quant_method.weight_scale_suffix if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights: kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"] @@ -105,29 +125,27 @@ def load_hf_weights(self, weights): if self.quant_cfg.quantized_weight: kv_b_proj_ = weight_dequant( kv_b_proj_.cuda(), - weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ].cuda(), + weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix].cuda(), ).cpu() weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_) weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_) if ( self.quant_cfg.quantized_weight - and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - in weights + and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix in weights ): - kv_b_proj_scale_ = weights[ - f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + kv_b_quant_method.weight_scale_suffix - ] + kv_b_proj_scale_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + weight_scale_suffix] block_size = 128 - weights[ - f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_kb_scale(kv_b_proj_scale_, block_size) - weights[ - f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + kv_b_quant_method.weight_scale_suffix - ] = self._load_vb_scale(kv_b_proj_scale_, block_size) + weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj." + weight_scale_suffix] = self._load_kb_scale( + kv_b_proj_scale_, block_size + ) + weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj." 
+ weight_scale_suffix] = self._load_vb_scale( + kv_b_proj_scale_, block_size + ) + # rename the shared experts weight + if self.num_fused_shared_experts > 0: + self._rename_shared_experts(weights, weight_scale_suffix) return super().load_hf_weights(weights) def _init_qkvo(self): @@ -198,6 +216,8 @@ def _init_qkvo(self): ) def _load_mlp(self, mlp_prefix): + if self.num_fused_shared_experts > 0: + return self.gate_up_proj = MultiROWMMWeight( weight_names=[f"{mlp_prefix}.gate_proj.weight", f"{mlp_prefix}.up_proj.weight"], data_type=self.data_type_, @@ -235,6 +255,7 @@ def _init_moe(self): e_score_correction_bias_name=self.e_score_correction_bias_name, weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts", n_routed_experts=self.n_routed_experts, + num_fused_shared_experts=self.num_fused_shared_experts, split_inter_size=moe_intermediate_size // self.tp_world_size_, data_type=self.data_type_, network_config=self.network_config_, diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py index 93ff323f3..6f2e333db 100644 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py +++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py @@ -5,59 +5,52 @@ @triton.jit -def _rotary_kernel( +def _rotary_kernel_q( Q, - K, Cos, Sin, stride_qbs, stride_qh, stride_qd, - stride_kbs, - stride_kh, - stride_kd, stride_cosbs, stride_cosd, stride_sinbs, stride_sind, max_total_len, HEAD_Q, - HEAD_K, # N_CTX 代表要计算的上下文长度 BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): cur_head_index = tl.program_id(0) + if cur_head_index >= HEAD_Q: + return cur_seq_index = tl.program_id(1) - cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 dim_range1 = dim_range0 + 1 off_q0 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range0[None, None, :] * stride_qd + cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range0[None, None, :] * stride_qd ) off_q1 = ( - cur_seq_range[:, None, None] * stride_qbs - + cur_head_range[None, :, None] * stride_qh - + dim_range1[None, None, :] * stride_qd + cur_seq_range[:, None, None] * stride_qbs + cur_head_index * stride_qh + dim_range1[None, None, :] * stride_qd ) + mask = cur_seq_range[:, None, None] < max_total_len cos_range = tl.arange(0, BLOCK_DMODEL // 2) off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd q0 = tl.load( Q + off_q0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + mask=mask, other=0.0, ) q1 = tl.load( Q + off_q1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q), + mask=mask, other=0.0, ) @@ -67,34 +60,51 @@ def _rotary_kernel( out0 = q0 * cos - q1 * sin out1 = q0 * sin + q1 * cos - tl.store( - Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) - tl.store( - Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) - ) + tl.store(Q + off_q0, out0, mask=mask) + tl.store(Q + off_q1, out1, mask=mask) + return - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] 
* stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) + +@triton.jit +def _rotary_kernel_k( + K, + Cos, + Sin, + stride_kbs, + stride_kh, + stride_kd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + HEAD_K, # HEAD_K is 1. + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_seq_index = tl.program_id(0) + + cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ) + + dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + dim_range1 = dim_range0 + 1 + + cos_range = tl.arange(0, BLOCK_DMODEL // 2) + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd + + off_k0 = cur_seq_range[:, None, None] * stride_kbs + dim_range0[None, None, :] * stride_kd + off_k1 = cur_seq_range[:, None, None] * stride_kbs + dim_range1[None, None, :] * stride_kd off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + cos_range[None, None, :] * stride_cosd k0 = tl.load( K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), other=0.0, ) k1 = tl.load( K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), other=0.0, ) @@ -107,12 +117,12 @@ def _rotary_kernel( tl.store( K + off_k0, out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), ) tl.store( K + off_k1, out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + mask=(cur_seq_range[:, None, None] < max_total_len), ) return @@ -126,21 +136,36 @@ def rotary_emb_fwd(q, k, cos, sin): assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" BLOCK_SEQ = 16 - BLOCK_HEAD = 4 + BLOCK_HEAD = 2 if head_dim >= 128: num_warps = 8 else: num_warps = 4 - - grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) - _rotary_kernel[grid]( + grid = (triton.next_power_of_2(head_num_q), triton.cdiv(total_len, BLOCK_SEQ)) + _rotary_kernel_q[grid]( q, - k, cos, sin, q.stride(0), q.stride(1), q.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num_q, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=head_dim, + num_warps=num_warps, + num_stages=1, + ) + grid = (triton.cdiv(total_len, BLOCK_SEQ),) + _rotary_kernel_k[grid]( + k, + cos, + sin, k.stride(0), k.stride(1), k.stride(2), @@ -149,9 +174,7 @@ def rotary_emb_fwd(q, k, cos, sin): sin.stride(0), sin.stride(1), total_len, - head_num_q, head_num_k, - BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, num_warps=num_warps, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index 4b06a75c3..b00215cff 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -16,7 +16,7 @@ from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd -from lightllm.models.llama.triton_kernel.silu_and_mul import 
silu_and_mul_fwd +from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.models.llama.flashattention_infer_struct import FlashAttentionStateInfo diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index 4c8bdd6f2..6a7622d9a 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -111,14 +111,13 @@ def _moe_ffn_edp( ep_output = layer_weight.experts.experts( hidden_states, router_logits=router_logits, - top_k=self.num_experts_per_tok, + top_k=8, renormalize=self.norm_topk_prob, use_grouped_topk=False, topk_group=None, num_expert_group=None, is_prefill=infer_state.is_prefill, ) - ep_output = ep_output.view(token_num, hidden_dim) return ep_output diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 465f9cc92..ab5218d3e 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -442,6 +442,11 @@ def make_argument_parser() -> argparse.ArgumentParser: action="store_true", help="""Whether to update the redundant expert for deepseekv3 model by online expert used counter.""", ) + parser.add_argument( + "--enable_fused_shared_experts", + action="store_true", + help="""Whether to enable fused shared experts for deepseekv3 model.""", + ) parser.add_argument( "--mtp_mode", choices=["deepseekv3", None], diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index b78784d82..b601202bf 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -26,7 +26,9 @@ def get_unique_server_name(): def set_cuda_arch(args): if not torch.cuda.is_available(): return - if args.enable_flashinfer_prefill or args.enable_flashinfer_decode: + from lightllm.utils.sgl_utils import HAS_FLASHINFER + + if HAS_FLASHINFER: capability = torch.cuda.get_device_capability() arch = f"{capability[0]}.{capability[1]}" os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}" diff --git a/lightllm/utils/sgl_utils.py b/lightllm/utils/sgl_utils.py index b48a62506..3a183c47e 100644 --- a/lightllm/utils/sgl_utils.py +++ b/lightllm/utils/sgl_utils.py @@ -30,3 +30,15 @@ "sgl_kernel is not installed, or the installed version did not support fa3. \ Try to upgrade it." ) + +try: + import flashinfer + from flashinfer.norm import fused_add_rmsnorm, rmsnorm + + HAS_FLASHINFER = True +except: + HAS_FLASHINFER = False + logger.warning( + "flashinfer is not installed, you can't use the api of it. \ + You can solve it by running `pip install flashinfer`." 
+ ) diff --git a/test/kernel/fuse_moe_tuning.py b/test/kernel/fuse_moe_tuning.py index 6e971573a..c15129d97 100644 --- a/test/kernel/fuse_moe_tuning.py +++ b/test/kernel/fuse_moe_tuning.py @@ -7,6 +7,7 @@ from typing import List from lightllm.utils.log_utils import init_logger from transformers import AutoConfig +import torch.nn.functional as F logger = init_logger(__name__) @@ -61,6 +62,7 @@ def test_kernel( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, **config, ): set_seed() @@ -68,6 +70,8 @@ def test_kernel( a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1_scale = w2_scale = None + if num_fused_experts > 0: + expert_num += num_fused_experts if use_fp8_w8a8: init_dtype = dtype @@ -91,19 +95,21 @@ def test_kernel( w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() - rnd_logics = torch.randn(m, expert_num, device="cuda") + rnd_logics = torch.randn(m, expert_num - num_fused_experts, device="cuda") topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1) - topk_weights = torch.randn((m, topk), device="cuda", dtype=dtype) / 10 + topk_weights = torch.randn((m, topk + num_fused_experts), device="cuda", dtype=dtype) / 10 + if num_fused_experts > 0: + topk_ids = F.pad(topk_ids, (0, 1), mode="constant", value=expert_num) - expert_to_tokens = torch.empty((expert_num, topk * m), dtype=torch.int32, device="cuda") - expert_to_weights = torch.empty((expert_num, topk * m), dtype=torch.float32, device="cuda") + expert_to_tokens = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.int32, device="cuda") + expert_to_weights = torch.empty((expert_num, (topk + num_fused_experts) * m), dtype=torch.float32, device="cuda") moe_align(topk_ids=topk_ids, out=expert_to_tokens) expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk) + moe_align1(expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_experts) - out1 = torch.zeros((m * topk, 2 * n), dtype=torch.bfloat16, device="cuda") - down_in = torch.zeros((m * topk, n), dtype=torch.bfloat16, device="cuda") - out2 = torch.zeros((m * topk, k), dtype=torch.bfloat16, device="cuda") + out1 = torch.zeros((m * (topk + 1), 2 * n), dtype=torch.bfloat16, device="cuda") + down_in = torch.zeros((m * (topk + 1), n), dtype=torch.bfloat16, device="cuda") + out2 = torch.zeros((m * (topk + 1), k), dtype=torch.bfloat16, device="cuda") for _ in range(test_count): input_tuples.append( @@ -219,6 +225,7 @@ def worker( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, test_configs, queue, ): @@ -235,6 +242,7 @@ def worker( use_fp8_w8a8=use_fp8_w8a8, is_up=is_up, block_shape=block_shape, + num_fused_experts=num_fused_experts, **test_configs[index], ) queue.put(cost_time) # Put result in queue @@ -302,6 +310,7 @@ def tuning_configs( use_fp8_w8a8: bool, is_up: bool, block_shape, + num_fused_experts: int, ): os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) best_config, best_cost_time = None, 10000000 @@ -325,6 +334,7 @@ def tuning_configs( use_fp8_w8a8, is_up, block_shape, + num_fused_experts, test_configs, queue, ), @@ -359,6 +369,7 @@ def tuning_configs( use_fp8_w8a8, is_up, block_shape, + num_fused_experts, test_configs, queue, ), @@ -393,15 +404,16 @@ def main(args): if config.architectures[0] == "Qwen3MoeForCausalLM": expert_num = config.num_experts topk_num = 
config.num_experts_per_tok - n = 2 * config.moe_intermediate_size // args.tp + n = config.moe_intermediate_size // args.tp elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: expert_num = config.n_routed_experts topk_num = config.num_experts_per_tok - n = 2 * config.moe_intermediate_size // args.tp + n = config.moe_intermediate_size // args.tp else: pass hidden_dim = getattr(config, "hidden_size", None) or config.text_config.hidden_size + print(n, hidden_dim) use_fp8_w8a8 = args.use_fp8_w8a8 block_shape = None if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config: @@ -424,6 +436,7 @@ def main(args): "use_fp8_w8a8": use_fp8_w8a8, "is_up": True, "block_shape": block_shape, + "num_fused_experts": args.num_fused_experts, }, ) up_dict[m] = ans @@ -431,7 +444,7 @@ def main(args): N=n * 2, K=hidden_dim, topk_num=topk_num, - expert_num=expert_num, + expert_num=expert_num + 1, mul_routed_weight=False, use_fp8_w8a8=use_fp8_w8a8, out_dtype=str(torch.bfloat16), @@ -453,6 +466,7 @@ def main(args): "use_fp8_w8a8": use_fp8_w8a8, "is_up": False, "block_shape": block_shape, + "num_fused_experts": args.num_fused_experts, }, ) down_dict[m] = ans @@ -461,7 +475,7 @@ def main(args): N=hidden_dim, K=n, topk_num=1, - expert_num=expert_num, + expert_num=expert_num + 1, mul_routed_weight=True, use_fp8_w8a8=use_fp8_w8a8, out_dtype=str(torch.bfloat16), @@ -474,5 +488,6 @@ def main(args): parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1") parser.add_argument("--tp", type=int, default=8) parser.add_argument("--use_fp8_w8a8", action="store_true") + parser.add_argument("--num_fused_experts", type=int, default=0) args = parser.parse_args() main(args)
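A short PyTorch reference (a sketch, not the repository's API) for the _silu_and_mul_kernel_fast path added in moe_silu_and_mul.py earlier in this diff: the kernel reads the gate projection from the first size_n columns of the input, the up projection from the next size_n columns, and writes up * silu(gate). It can serve as a correctness check when benchmarking the fast path.

import torch
import torch.nn.functional as F

def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # x has shape (M, 2 * N): gate in x[:, :N], up in x[:, N:], matching the Triton kernel's layout.
    gate, up = x.chunk(2, dim=-1)
    return up * F.silu(gate.float()).to(x.dtype)

# Usage sketch: compare against silu_and_mul_fwd on small inputs, e.g.
#   x = torch.randn(128, 2 * 512, device="cuda", dtype=torch.bfloat16)
#   out = torch.empty(128, 512, device="cuda", dtype=torch.bfloat16)
#   silu_and_mul_fwd(x, out)
#   torch.testing.assert_close(out, ref_silu_and_mul(x), rtol=2e-2, atol=2e-2)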