diff --git a/CMakeLists.txt b/CMakeLists.txt index f11d28590b2..afaed7cd182 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -543,8 +543,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and only works - # on Hopper). get_cutlass_moe_mm_data should only be compiled if it's possible - # to compile MoE kernels that use its output. + # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled + # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu" diff --git a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py index 1be83b84e95..acabe6c1ddb 100644 --- a/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py +++ b/benchmarks/kernels/benchmark_grouped_gemm_cutlass.py @@ -7,8 +7,8 @@ from vllm import _custom_ops as ops from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp8 from vllm.model_executor.layers.fused_moe.fused_moe import ( - cutlass_moe_fp8, fused_experts, fused_topk, ) @@ -70,18 +70,9 @@ def bench_run( w1_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) w2_scale = torch.empty((num_experts, 1, 1), device="cuda", dtype=torch.float32) - ab_strides1 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - c_strides1 = torch.full((num_experts,), 2 * n, device="cuda", dtype=torch.int64) - ab_strides2 = torch.full((num_experts,), n, device="cuda", dtype=torch.int64) - c_strides2 = torch.full((num_experts,), k, device="cuda", dtype=torch.int64) - for expert in range(num_experts): w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(w1[expert]) w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(w2[expert]) - w1_q_notransp = w1_q.clone() - w2_q_notransp = w2_q.clone() - w1_q = w1_q.transpose(1, 2) - w2_q = w2_q.transpose(1, 2) score = torch.randn((m, num_experts), device="cuda", dtype=dtype) @@ -122,10 +113,6 @@ def run_cutlass_moe( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, num_repeats: int, ): for _ in range(num_repeats): @@ -133,14 +120,10 @@ def run_cutlass_moe( a, w1, w2, - w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, + w1_scale, + w2_scale, a1_scale=a_scale, ) @@ -153,10 +136,6 @@ def run_cutlass_from_graph( w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, ): with set_current_vllm_config( VllmConfig(parallel_config=ParallelConfig(pipeline_parallel_size=1)) @@ -165,14 +144,10 @@ def run_cutlass_from_graph( a, w1_q, w2_q, - w1_scale, - w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, + w1_scale, + w2_scale, a1_scale=a_scale, ) @@ -218,10 +193,6 @@ def replay_graph(graph, num_repeats): w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, ) torch.cuda.synchronize() @@ -230,8 +201,8 @@ def replay_graph(graph, num_repeats): with torch.cuda.graph(triton_graph, 
stream=triton_stream): run_triton_from_graph( a, - w1_q_notransp, - w2_q_notransp, + w1_q, + w2_q, topk_weights, topk_ids, w1_scale, @@ -250,18 +221,12 @@ def replay_graph(graph, num_repeats): "w2": w2, "score": score, "topk": topk, - "w1_q_notransp": w1_q_notransp, - "w2_q_notransp": w2_q_notransp, # Cutlass params "a_scale": a_scale, "w1_q": w1_q, "w2_q": w2_q, "w1_scale": w1_scale, "w2_scale": w2_scale, - "ab_strides1": ab_strides1, - "c_strides1": c_strides1, - "ab_strides2": ab_strides2, - "c_strides2": c_strides2, # cuda graph params "cutlass_graph": cutlass_graph, "triton_graph": triton_graph, @@ -279,8 +244,8 @@ def replay_graph(graph, num_repeats): # Warmup run_triton_moe( a, - w1_q_notransp, - w2_q_notransp, + w1_q, + w2_q, topk_weights, topk_ids, w1_scale, @@ -291,7 +256,7 @@ def replay_graph(graph, num_repeats): results.append( benchmark.Timer( - stmt="run_triton_moe(a, w1_q_notransp, w2_q_notransp, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 + stmt="run_triton_moe(a, w1_q, w2_q, topk_weights, topk_ids, w1_scale, w2_scale, a_scale, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, @@ -322,16 +287,12 @@ def replay_graph(graph, num_repeats): w2_scale, topk_weights, topk_ids, - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, num_warmup, ) results.append( benchmark.Timer( - stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, ab_strides1, c_strides1, ab_strides2, c_strides2, num_runs)", # noqa: E501 + stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501 globals=globals, label=label, sub_label=sub_label, diff --git a/csrc/ops.h b/csrc/ops.h index 6905ef6e591..f02f5083ac1 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -236,7 +236,8 @@ void cutlass_moe_mm( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch); void cutlass_fp4_group_mm( torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b, @@ -251,6 +252,14 @@ void get_cutlass_moe_mm_data( const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); +void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, const int64_t n, + const int64_t k); + void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu index 2b8bc3fb0b2..c88e134ae40 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu @@ -84,7 +84,8 @@ void run_cutlass_moe_mm_sm90( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool 
per_act_token, bool per_out_ch) { TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided."); TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided."); TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided."); @@ -113,19 +114,23 @@ void run_cutlass_moe_mm_sm90( if (n >= 8192) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } else if (k >= 8192) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } else if (m <= 16) { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } else { cutlass_group_gemm_caller( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } } @@ -134,15 +139,18 @@ void dispatch_moe_mm_sm90( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { if (out_tensors.dtype() == torch::kBFloat16) { run_cutlass_moe_mm_sm90( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } else { run_cutlass_moe_mm_sm90( out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, - problem_sizes, a_strides, b_strides, c_strides); + problem_sizes, a_strides, b_strides, c_strides, per_act_token, + per_out_ch); } } @@ -153,8 +161,9 @@ void cutlass_moe_mm_sm90( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { dispatch_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); + c_strides, per_act_token, per_out_ch); } diff --git a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh index db827b7c5e1..bbd82d72e95 100644 --- a/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh +++ b/csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cuh @@ -76,7 +76,8 @@ void cutlass_group_gemm_caller( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { using ElementAB = 
typename Gemm::ElementAB; using ElementD = typename Gemm::ElementD; @@ -84,9 +85,6 @@ void cutlass_group_gemm_caller( int k_size = a_tensors.size(1); int n_size = out_tensors.size(1); - bool per_act_token = a_scales.numel() != 1; - bool per_out_ch = b_scales.numel() != num_experts; - auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index()); auto options_int = diff --git a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu index ac414e1bc0c..32254641cc3 100644 --- a/csrc/quantization/cutlass_w8a8/moe/moe_data.cu +++ b/csrc/quantization/cutlass_w8a8/moe/moe_data.cu @@ -7,7 +7,7 @@ constexpr uint64_t THREADS_PER_EXPERT = 512; -__global__ void compute_problem_sizes(const int* __restrict__ topk_ids, +__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids, int32_t* problem_sizes1, int32_t* problem_sizes2, int32_t* atomic_buffer, @@ -62,7 +62,7 @@ __global__ void compute_expert_blockscale_offsets( } } -__global__ void compute_arg_sorts(const int* __restrict__ topk_ids, +__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids, const int32_t* __restrict__ expert_offsets, int32_t* input_permutation, int32_t* output_permutation, @@ -103,7 +103,7 @@ void get_cutlass_moe_mm_data_caller( int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel()); compute_problem_sizes<<>>( - static_cast(topk_ids.data_ptr()), + static_cast(topk_ids.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), n, k); @@ -120,10 +120,44 @@ void get_cutlass_moe_mm_data_caller( static_cast(atomic_buffer.data_ptr()), num_experts); } compute_arg_sorts<<>>( - static_cast(topk_ids.data_ptr()), + static_cast(topk_ids.data_ptr()), static_cast(expert_offsets.data_ptr()), static_cast(input_permutation.data_ptr()), static_cast(output_permutation.data_ptr()), static_cast(atomic_buffer.data_ptr()), topk_ids.numel(), topk_ids.size(1)); } + +__global__ void compute_pplx_data(int32_t* expert_offsets, + int32_t* problem_sizes1, + int32_t* problem_sizes2, + const int32_t* __restrict__ expert_num_tokens, + const int padded_m, const int n, + const int k) { + int expert_idx = threadIdx.x; + + expert_offsets[expert_idx] = expert_idx * padded_m; + problem_sizes1[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes1[expert_idx * 3 + 1] = 2 * n; + problem_sizes1[expert_idx * 3 + 2] = k; + problem_sizes2[expert_idx * 3] = expert_num_tokens[expert_idx]; + problem_sizes2[expert_idx * 3 + 1] = k; + problem_sizes2[expert_idx * 3 + 2] = n; +} + +void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, + const int64_t n, const int64_t k) { + auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); + + compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + static_cast(expert_offsets.data_ptr()), + static_cast(problem_sizes1.data_ptr()), + static_cast(problem_sizes2.data_ptr()), + static_cast(expert_num_tokens.data_ptr()), padded_m, n, + k); +} diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index ee93440b575..34852581081 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -36,7 +36,8 @@ void cutlass_moe_mm_sm90( 
torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides); + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch); #endif @@ -56,6 +57,14 @@ void get_cutlass_moe_mm_data_caller( torch::Tensor& input_permutation, torch::Tensor& output_permutation, const int64_t num_experts, const int64_t n, const int64_t k, const std::optional& blockscale_offsets); + +void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, + const int64_t n, const int64_t k); #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -207,12 +216,13 @@ void cutlass_moe_mm( torch::Tensor const& b_tensors, torch::Tensor const& a_scales, torch::Tensor const& b_scales, torch::Tensor const& expert_offsets, torch::Tensor const& problem_sizes, torch::Tensor const& a_strides, - torch::Tensor const& b_strides, torch::Tensor const& c_strides) { + torch::Tensor const& b_strides, torch::Tensor const& c_strides, + bool per_act_token, bool per_out_ch) { int32_t version_num = get_sm_version_num(); #if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, - c_strides); + c_strides, per_act_token, per_out_ch); return; #endif TORCH_CHECK_NOT_IMPLEMENTED( @@ -245,6 +255,29 @@ void get_cutlass_moe_mm_data( version_num, ". Required capability: 90"); } +void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, const int64_t n, + const int64_t k) { + // This function currently gets compiled only if we have a valid cutlass moe + // mm to run it for. + int32_t version_num = get_sm_version_num(); +#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90 + get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, + problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k); + return; +#endif + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " + "for CUDA device capability: ", + version_num, ". Required capability: 90"); +} + void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 93916b7f94b..1a1896b4c1e 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -435,7 +435,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "cutlass_moe_mm(Tensor! 
out_tensors, Tensor a_tensors, Tensor b_tensors, " " Tensor a_scales, Tensor b_scales, Tensor expert_offsets, " " Tensor problem_sizes, Tensor a_strides, " - " Tensor b_strides, Tensor c_strides) -> ()", + " Tensor b_strides, Tensor c_strides, bool per_act_token, " + " bool per_out_ch) -> ()", {stride_tag}); ops.impl("cutlass_moe_mm", torch::kCUDA, &cutlass_moe_mm); @@ -454,6 +455,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { {stride_tag}); ops.impl("get_cutlass_moe_mm_data", torch::kCUDA, &get_cutlass_moe_mm_data); + // A function that computes data required to run fused MoE with w8a8 grouped + // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs + // as an input, and computes expert_offsets (token start indices of each + // expert). In addition to this, it computes problem sizes for each expert's + // multiplication used by the two mms called from fused MoE operation. + ops.def( + "get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, " + " Tensor! problem_sizes1, " + " Tensor! problem_sizes2, " + " Tensor expert_num_tokens, " + " int num_local_experts, int padded_m, " + " int n, int k) -> ()", + {stride_tag}); + ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA, + &get_cutlass_pplx_moe_mm_data); + // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) ops.def( "cutlass_scaled_mm_supports_block_fp8(int cuda_device_capability) -> " diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 558288ba44d..474745f9481 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -193,14 +193,10 @@ def run_8_bit(moe_tensors: MOETensors8Bit, kwargs = { 'a': moe_tensors.a, - 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] - 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] + 'w1_q': moe_tensors.w1_q, # type: ignore[union-attr] + 'w2_q': moe_tensors.w2_q, # type: ignore[union-attr] 'topk_weights': topk_weights, 'topk_ids': topk_ids, - 'ab_strides1': moe_tensors.ab_strides1, - 'c_strides1': moe_tensors.c_strides1, - 'ab_strides2': moe_tensors.ab_strides2, - 'c_strides2': moe_tensors.c_strides2, 'w1_scale': moe_tensors.w1_scale, 'w2_scale': moe_tensors.w2_scale, 'a1_scale': moe_tensors.a_scale diff --git a/tests/kernels/moe/test_pplx_cutlass_moe.py b/tests/kernels/moe/test_pplx_cutlass_moe.py new file mode 100644 index 00000000000..ef3e6adcfa3 --- /dev/null +++ b/tests/kernels/moe/test_pplx_cutlass_moe.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch + +from tests.pplx_utils import ProcessGroupInfo, parallel_launch +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8 +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.platforms import current_platform + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = True +except ImportError: + has_pplx = False + +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + +NUM_EXPERTS = [40, 64] +TOP_KS = [6, 8] + + +def 
rank_chunk(num, r, w): + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t, r, w): + num = t.shape[0] + chunk = rank_chunk(num, r, w) + rem = num % w + if rem == 0 or r < rem: + return t[(r * chunk):(r + 1) * chunk].contiguous() + else: + long_chunks = (num // w + 1) * rem + short_chunks = (r - rem) * chunk + start = long_chunks + short_chunks + return t[start:start + chunk].contiguous() + + +def pplx_cutlass_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a1_scale: torch.Tensor, + out_dtype, + per_act_token: bool, + per_out_ch: bool, +): + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + assert torch.cuda.current_device() == pgi.local_rank + + num_tokens, hidden_dim = a.shape + num_experts = w1.shape[0] + block_size = hidden_dim # TODO support more cases + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + rank_num_tokens = rank_chunk(num_tokens, rank, world_size) + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + topk = topk_ids.shape[1] + + if block_size == hidden_dim: + scale_elems = 4 # hack to circumvent pplx data format requirements + else: + scale_elems = (hidden_dim + block_size - 1) // block_size + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=pgi.world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim, # because a.dtype.itemsize == 1 + hidden_dim_scale_bytes=scale_elems * torch.float32.itemsize, + ) + + w1 = w1.to(device) + w2 = w2.to(device) + w1_scale = w1_scale.to(device) + w2_scale = w2_scale.to(device) + a1_scale = a1_scale.to(device) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + pgi.world_size, + rank, + dp_size, + quant_dtype=torch.float8_e4m3fn, + per_act_token=per_act_token, + ) + + experts = CutlassExpertsFp8((num_experts + world_size - 1) // world_size, + out_dtype, per_act_token, per_out_ch) + + fused_cutlass_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weights, rank, + world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, + world_size).to(torch.uint32).to(device) + + out = fused_cutlass_experts( + a_chunk, + chunk_by_rank(w1, rank, world_size), + chunk_by_rank(w2, rank, world_size), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts, + expert_map=None, #TODO + w1_scale=chunk_by_rank(w1_scale, rank, world_size), + w2_scale=chunk_by_rank(w2_scale, rank, world_size), + a1_scale=chunk_by_rank(a1_scale, rank, world_size) + if per_act_token else a1_scale[rank]) + + torch.cuda.synchronize() + + ata.destroy() + + return out[:rank_num_tokens] + + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + + +def torch_moe2(a, w1, w2, topk_weight, topk_ids): + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) 
+ + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + a1_scale: torch.Tensor, + out_dtype, + a_full: torch.Tensor, + w1_full: torch.Tensor, + w2_full: torch.Tensor, + per_act_token: bool, + per_out_ch: bool, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + with set_current_vllm_config(vllm_config): + torch_output = torch_moe2(a_full, w1_full, w2_full, topk_weights, + topk_ids) + pplx_output = pplx_cutlass_moe(pgi, dp_size, a, w1, w2, w1_scale, + w2_scale, topk_weights, topk_ids, + a1_scale, out_dtype, per_act_token, + per_out_ch) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + # Uncomment if more debugging is needed + # print("PPLX OUT:", pplx_output) + # print("TORCH OUT:", torch_output) + + torch.testing.assert_close(pplx_output, torch_output, atol=0.05, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("m", [2, 224]) +@pytest.mark.parametrize("n", [3072]) +@pytest.mark.parametrize("k", [1536]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("per_act_token", [True, False]) +@pytest.mark.parametrize("per_out_ch", [True, False]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) #, [4, 2]]) +@pytest.mark.skipif( + (lambda x: x is None or not ops.cutlass_group_gemm_supported(x.to_int()))( + current_platform.get_device_capability()), + reason="Grouped gemm is not supported on this GPU type.") +@requires_pplx +def test_cutlass_moe_pplx( + m: int, + n: int, + k: int, + e: int, + topk: int, + per_act_token: bool, + per_out_ch: bool, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + + with set_current_vllm_config(vllm_config): + + dtype = torch.half + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10.0 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10.0 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10.0 + + n_b_scales = 2 * n if per_out_ch else 1 + k_b_scales = k if per_out_ch else 1 + + w1_q = torch.empty((e, 2 * n, k), + device="cuda", + dtype=torch.float8_e4m3fn) + w2_q = torch.empty((e, k, n), device="cuda", dtype=torch.float8_e4m3fn) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=per_out_ch) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=per_out_ch) + + w1_d = torch.empty_like(w1) + w2_d = torch.empty_like(w2) + for expert in range(e): + w1_d[expert] = (w1_q[expert].float() * w1_scale[expert]).half() + w2_d[expert] = (w2_q[expert].float() * w2_scale[expert]).half() + + score = torch.randn((m, e), device="cuda", dtype=dtype) + topk_weights, topk_ids, _ = fused_topk(a, + score, + topk, + renormalize=False) + + world_size, dp_size = world_dp_size + a_scale1 = torch.randn( + (m if per_act_token else 1, 1), device="cuda", + dtype=torch.float32) / 10.0 + if not per_act_token: + a_scale1 = 
a_scale1.repeat(world_size, 1) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1_q, w2_q, + w1_scale, w2_scale, topk_weights, topk_ids, a_scale1, + dtype, a, w1_d, w2_d, per_act_token, per_out_ch) diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py index 95c10037b23..bbfe31d0e65 100644 --- a/tests/kernels/moe/test_pplx_moe.py +++ b/tests/kernels/moe/test_pplx_moe.py @@ -4,10 +4,7 @@ Run `pytest tests/kernels/test_pplx_moe.py`. """ -import dataclasses -import os -import traceback -from typing import Callable, Optional +from typing import Optional import pytest import torch @@ -21,10 +18,7 @@ except ImportError: has_pplx = False -from torch.multiprocessing import ( - spawn) # pyright: ignore[reportPrivateImportUsage] -from typing_extensions import Concatenate, ParamSpec - +from tests.pplx_utils import ProcessGroupInfo, parallel_launch from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import override_config @@ -36,6 +30,11 @@ FusedMoEModularKernel) from vllm.platforms import current_platform +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), (222, 2048, 1024)] @@ -57,122 +56,6 @@ vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_model_len = 8192 -P = ParamSpec("P") - -requires_pplx = pytest.mark.skipif( - not has_pplx, - reason="Requires PPLX kernels", -) - - -@dataclasses.dataclass -class ProcessGroupInfo: - world_size: int - world_local_size: int - rank: int - node_rank: int - local_rank: int - device: torch.device - - -def _worker_parallel_launch( - local_rank: int, - world_size: int, - world_local_size: int, - node_rank: int, - init_method: str, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - rank = node_rank * world_local_size + local_rank - torch.cuda.set_device(local_rank) - device = torch.device("cuda", local_rank) - torch.distributed.init_process_group( - backend="cpu:gloo,cuda:nccl", - init_method=init_method, - rank=rank, - world_size=world_size, - device_id=device, - ) - barrier = torch.tensor([rank], device=device) - torch.distributed.all_reduce(barrier) - - try: - worker( - ProcessGroupInfo( - world_size=world_size, - world_local_size=world_local_size, - rank=rank, - node_rank=node_rank, - local_rank=local_rank, - device=device, - ), - *args, - **kwargs, - ) - except Exception as ex: - print(ex) - traceback.print_exc() - raise - finally: - torch.distributed.destroy_process_group() - - -def parallel_launch( - world_size: int, - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - assert not kwargs - spawn( - _worker_parallel_launch, - args=( - world_size, - world_size, - 0, - "tcp://localhost:29500", - worker, - ) + args, - nprocs=world_size, - join=True, - ) - - -def parallel_launch_from_env( - worker: Callable[Concatenate[ProcessGroupInfo, P], None], - *args: P.args, - **kwargs: P.kwargs, -) -> None: - """ - Launches a worker function in parallel across all processes in the current - environment. The environment must have the following variables set: - - WORLD_SIZE: The total number of processes. - - WORLD_LOCAL_SIZE: The number of processes on the current node. - - NODE_RANK: The rank of the current - - MASTER_ADDR: The address of the master process. 
- - MASTER_PORT: The port of the master process. - """ - assert not kwargs - world_size = int(os.environ["WORLD_SIZE"]) - world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) - node_rank = int(os.environ["NODE_RANK"]) - assert "MASTER_ADDR" in os.environ - assert "MASTER_PORT" in os.environ - spawn( - _worker_parallel_launch, - args=( - world_size, - world_local_size, - node_rank, - "env://", - worker, - ) + args, - nprocs=world_local_size, - join=True, - ) - def torch_prepare( a: torch.Tensor, diff --git a/tests/kernels/quantization/test_cutlass_scaled_mm.py b/tests/kernels/quantization/test_cutlass_scaled_mm.py index 51bb29df054..c4d349f1a5a 100644 --- a/tests/kernels/quantization/test_cutlass_scaled_mm.py +++ b/tests/kernels/quantization/test_cutlass_scaled_mm.py @@ -632,7 +632,8 @@ def test_cutlass_fp8_group_gemm(num_experts: int, per_act_token: bool, ops.cutlass_moe_mm(out_tensors_stacked, a_tensors_stacked, b_tensors_stacked, a_scales_tensors_stacked, b_scales_tensors_stacked, expert_offsets[:-1], - problem_sizes, ab_strides, ab_strides, c_strides) + problem_sizes, ab_strides, ab_strides, c_strides, + per_act_token, per_out_ch) # Validate each group's result against the baseline for g in range(num_experts): diff --git a/tests/pplx_utils.py b/tests/pplx_utils.py new file mode 100644 index 00000000000..2d5d5be80c3 --- /dev/null +++ b/tests/pplx_utils.py @@ -0,0 +1,123 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses +import os +import traceback +from typing import Callable + +import torch +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. 
The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 14404cd735b..92de1f5efa8 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -899,11 +899,36 @@ def shuffle_rows(input_tensor: torch.Tensor, dst2src_map: torch.Tensor): return output_tensor +def get_cutlass_pplx_moe_mm_data(expert_offsets: torch.Tensor, + problem_sizes1: torch.Tensor, + problem_sizes2: torch.Tensor, + expert_num_tokens: torch.Tensor, + num_local_experts: int, padded_m: int, n: int, + k: int): + """ + Prepare data necessary to perform CUTLASS grouped matrix multiplications + used in CUTLASS-based fused MoE. + + The function takes in expert_num_tokens (token count per expert) and + non_zero_expert_idxs (consecutive indices of experts with non-zero token + counts) and uses them to compute: + - expert_offsets: Indices that mark at which token index each expert begins + its computation. + - problem_sizes1, problem_sizes2: MxNxK sizes of each expert's + multiplication in two grouped MMs used in + the fused MoE operation. + """ + return torch.ops._C.get_cutlass_pplx_moe_mm_data( + expert_offsets, problem_sizes1, problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k) + + def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, b_tensors: torch.Tensor, a_scales: torch.Tensor, b_scales: torch.Tensor, expert_offsets: torch.Tensor, problem_sizes: torch.Tensor, a_strides: torch.Tensor, - b_strides: torch.Tensor, c_strides: torch.Tensor): + b_strides: torch.Tensor, c_strides: torch.Tensor, + per_act_token: bool, per_out_ch: bool): """ A single grouped matrix multiplication used in CUTLASS-based fused MoE. The function executes fp8-quantized OUT = AB matrix multiplication. 
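A reference-only Python sketch of what the new get_cutlass_pplx_moe_mm_data op computes per local expert, mirroring the compute_pplx_data kernel added in moe_data.cu; the helper name below is hypothetical and exists only for illustration.

import torch

def _pplx_moe_mm_data_reference(expert_num_tokens: torch.Tensor,
                                num_local_experts: int, padded_m: int,
                                n: int, k: int):
    # Hypothetical reference helper; the real work happens in the CUDA kernel.
    e = num_local_experts
    device = expert_num_tokens.device
    int32 = torch.int32
    # Each expert's tokens start at a fixed, padded offset.
    expert_offsets = torch.arange(e, dtype=int32, device=device) * padded_m
    # MxNxK problem sizes for the two grouped GEMMs of the fused MoE op.
    m_sizes = expert_num_tokens.to(int32)
    problem_sizes1 = torch.stack([
        m_sizes,
        torch.full((e, ), 2 * n, dtype=int32, device=device),
        torch.full((e, ), k, dtype=int32, device=device),
    ], dim=1)
    problem_sizes2 = torch.stack([
        m_sizes,
        torch.full((e, ), k, dtype=int32, device=device),
        torch.full((e, ), n, dtype=int32, device=device),
    ], dim=1)
    return expert_offsets, problem_sizes1, problem_sizes2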
@@ -918,7 +943,7 @@ def cutlass_moe_mm(out_tensors: torch.Tensor, a_tensors: torch.Tensor, return torch.ops._C.cutlass_moe_mm(out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets, problem_sizes, a_strides, b_strides, - c_strides) + c_strides, per_act_token, per_out_ch) def cutlass_fp4_moe_mm(a_tensors: torch.Tensor, b_tensors: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index a541d46b14a..76d71ca0885 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -39,6 +39,7 @@ def __init__(self, max_num_tokens: int, world_size: int, dp_size: int, def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index 4db6b84e9d5..d62d519af8d 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -67,6 +67,7 @@ def __init__(self, def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, @@ -78,11 +79,11 @@ def workspace_shapes( # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and self.batched_deep_gemm_experts is not None: return self.batched_deep_gemm_experts.workspace_shapes( - a, M, N, K, topk, num_experts) + a, aq, M, N, K, topk, num_experts) else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, M, N, K, topk, num_experts) + a, aq, M, N, K, topk, num_experts) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index e9446bc5fd2..6e7b1a4f2b6 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """ CUTLASS based Fused MoE kernels.""" -from typing import Optional +from typing import Callable, Optional import torch @@ -13,110 +13,109 @@ from vllm.scalar_type import scalar_types -class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): - - def __init__( - self, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, - out_dtype: torch.dtype, - ): - super().__init__() - self.ab_strides1 = ab_strides1 - self.c_strides1 = c_strides1 - self.ab_strides2 = ab_strides2 - self.c_strides2 = c_strides2 - self.out_dtype = out_dtype - - def workspace_shapes( - self, - a: torch.Tensor, - M: int, - N: int, - K: int, - topk: int, - num_experts: int, - ) -> tuple[int, int, torch.dtype]: - # Note that K, N are transposed - N, K = K, N - workspace1 = M * topk * max(2 * N, K) - workspace2 = M * topk * N - return (workspace1, workspace2, self.out_dtype) - - def apply( - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - expert_map: Optional[torch.Tensor], - w1_scale: Optional[torch.Tensor], - w2_scale: Optional[torch.Tensor], - w1_zp: Optional[torch.Tensor], - w2_zp: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: 
Optional[torch.Tensor], - workspace13: torch.Tensor, - workspace2: torch.Tensor, - expert_num_tokens: Optional[torch.Tensor], - ) -> torch.Tensor: - a1q = hidden_states - - assert w1_scale is not None - assert w2_scale is not None - assert w1.dtype == torch.float8_e4m3fn - assert w2.dtype == torch.float8_e4m3fn - assert a1q.shape[1] == w1.shape[1], "Hidden size mismatch w1" - assert w1.shape[2] == w2.shape[1] * 2, "Hidden size mismatch w2" - assert w1.shape[0] == w2.shape[0], "Expert number mismatch" - assert a1q_scale is None or a1q_scale.dim( - ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ - 0], "Input scale shape mismatch" - assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ - 1] == w1.shape[2], "W1 scale shape mismatch" - assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ - 1] == w2.shape[2], "W2 scale shape mismatch" - assert w1.shape[0] == w2.shape[0], "Weights expert number mismatch" - assert w1.shape[0] == w1_scale.shape[ - 0], "w1 scales expert number mismatch" - assert w1.shape[0] == w2_scale.shape[ - 0], "w2 scales expert number mismatch" - assert a2_scale is None or a1q_scale is None or a2_scale.shape == a1q_scale.shape, "Intermediate scale shape mismatch" # noqa: E501 - assert self.ab_strides1.shape[0] == w1.shape[ - 0], "AB Strides 1 expert number mismatch" - assert self.c_strides1.shape[0] == w1.shape[ - 0], "C Strides 1 expert number mismatch" - assert self.ab_strides2.shape[0] == w2.shape[ - 0], "AB Strides 2 expert number mismatch" - assert self.c_strides2.shape[0] == w2.shape[ - 0], "C Strides 2 expert number mismatch" - assert self.out_dtype in [torch.half, - torch.bfloat16], "Invalid output dtype" - - M = a1q.shape[0] - _, N, K = w2.shape # because w1 + w2 are transposed - device = a1q.device - - assert w1.shape[1] == K - assert global_num_experts != -1 - assert a1q_scale is not None - - if expert_map is not None: - "Translate info from expert_map to topk_ids" - local_topk_ids = torch.where(expert_map[topk_ids] != -1, - expert_map[topk_ids], -1) - else: - local_topk_ids = topk_ids +def run_cutlass_moe_fp8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation_callable: Callable, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + out_dtype: torch.dtype, + per_act_token: bool, + per_out_ch: bool, +) -> torch.Tensor: + a1q = hidden_states + + assert w1_scale is not None + assert w2_scale is not None + assert w1.dtype == torch.float8_e4m3fn + assert w2.dtype == torch.float8_e4m3fn + if expert_num_tokens is None: + assert a1q.shape[1] == w1.shape[2], "Hidden size mismatch w1" + else: + assert a1q.shape[2] == w1.shape[2], "Hidden size mismatch w1" + assert w1.shape[1] == w2.shape[2] * 2, "Hidden size mismatch w2" + assert w1_scale.dim() == 1 or w1_scale.shape[1] == 1 or w1_scale.shape[ + 1] == w1.shape[1], "W1 scale shape mismatch" + assert w2_scale.dim() == 1 or w2_scale.shape[1] == 1 or w2_scale.shape[ + 1] == w2.shape[1], "W2 scale shape mismatch" + assert w1.shape[0] == w2.shape[0], "Expert number mismatch" + assert a1q_scale is None or a1q_scale.dim( + ) == 0 or a1q_scale.shape[0] == 1 or a1q_scale.shape[0] == a1q.shape[ + 0], "Input scale shape mismatch" + assert w1.shape[0] == w2.shape[0], 
"Weights expert number mismatch" + assert w1.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch" + assert w1.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch" + assert a2_scale is None or a2_scale.dim( + ) == 0 or a2_scale.shape[0] == 1 or a2_scale.shape[0] == a1q.shape[ + 0], "Intermediate scale shape mismatch" + assert out_dtype in [torch.half, torch.bfloat16], "Invalid output dtype" + if expert_map is not None: + assert expert_num_tokens is None + + # We have two modes: PPLX and non-PPLX. We differentiate them by checking + # if expert_num_tokens is None (expert_num_tokens is a tensor which PPLX + # uses to track the number of tokens per expert). + # In the non-PPLX mode, the input tokens are not padded: thus, the shape + # of the input is [total_num_tokens, hidden_size]. The input and output + # require shuffling by a_map and c_map such that the tokens assigned to + # each expert are contiguous. + # In the PPLX mode, the input tokens are padded per expert to ensure that + # the PPLX dispatch and combine functions work correctly: thus, the shape + # of the input is [num_experts, max_num_tokens_per_expert, hidden_size]. + # The PPLX input and output require no shuffling by a_map and c_map since + # their tokens are already contiguous for each expert as a result of + # the dispatch function. + is_pplx = expert_num_tokens is not None + + M = a1q.shape[0] # no pplx + padded_M = a1q.shape[1] # pplx + _, K, N = w2.shape + device = a1q.device + + assert w1.shape[2] == K + assert global_num_experts != -1 + assert a1q_scale is not None + + if expert_map is not None: + "Translate info from expert_map to topk_ids" + local_topk_ids = torch.where(expert_map[topk_ids] != -1, + expert_map[topk_ids], -1) + else: + local_topk_ids = topk_ids + + topk = local_topk_ids.shape[1] + local_E = w1.shape[0] + + if is_pplx: + expert_offsets = torch.empty((local_E), + dtype=torch.int32, + device=device) + problem_sizes1 = torch.empty((local_E, 3), + dtype=torch.int32, + device=device) + problem_sizes2 = torch.empty((local_E, 3), + dtype=torch.int32, + device=device) - topk = local_topk_ids.shape[1] + ops.get_cutlass_pplx_moe_mm_data(expert_offsets, problem_sizes1, + problem_sizes2, expert_num_tokens, + local_E, padded_M, N, K) - per_act_token = a1q_scale.numel() != 1 if a1q_scale is not None else ( - a2_scale.numel() != 1 if a2_scale is not None else False) + w1_scale = w1_scale.reshape(w1_scale.shape[0], -1) + w2_scale = w2_scale.reshape(w2_scale.shape[0], -1) + a1q = a1q.reshape(-1, a1q.shape[2]) + a1q_scale = a1q_scale.reshape(-1, a1q_scale.shape[2]).contiguous() + else: expert_offsets = torch.empty((global_num_experts + 1), dtype=torch.int32, device=device) @@ -149,50 +148,130 @@ def apply( a1q = _fp8_perm(a1q, a_map) a1q_scale = a1q_scale[a_map] if per_act_token else a1q_scale - + expert_offsets = expert_offsets[:-1] + + ab_strides1 = torch.full((w1.shape[0], ), + K, + device=device, + dtype=torch.int64) + c_strides1 = torch.full((w1.shape[0], ), + 2 * N, + device=device, + dtype=torch.int64) + ab_strides2 = torch.full((w1.shape[0], ), + N, + device=device, + dtype=torch.int64) + c_strides2 = torch.full((w1.shape[0], ), + K, + device=device, + dtype=torch.int64) + + if is_pplx: + c1 = _resize_cache(workspace13, (local_E * padded_M, N * 2)) + c2 = _resize_cache(workspace2, (local_E * padded_M, N)) + c3 = _resize_cache(workspace13, (local_E * padded_M, K)) + else: c1 = _resize_cache(workspace13, (M * topk, N * 2)) c2 = _resize_cache(workspace2, (M * topk, N)) c3 = 
_resize_cache(workspace13, (M * topk, K)) - ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, - expert_offsets[:-1], problem_sizes1, - self.ab_strides1, self.ab_strides1, self.c_strides1) + ops.cutlass_moe_mm(c1, a1q, w1, a1q_scale, w1_scale, expert_offsets, + problem_sizes1, ab_strides1, ab_strides1, c_strides1, + per_act_token, per_out_ch) - self.activation(activation, c2, c1) + activation_callable(c2, c1) - a2q, a2q_scale = ops.scaled_fp8_quant( - c2, a2_scale, use_per_token_if_dynamic=per_act_token) + a2q, a2q_scale = ops.scaled_fp8_quant( + c2, a2_scale, use_per_token_if_dynamic=per_act_token) - if expert_map is not None: - c3.fill_(0) + if expert_map is not None: + c3.fill_(0) + + ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, expert_offsets, + problem_sizes2, ab_strides2, ab_strides2, c_strides2, + per_act_token, per_out_ch) - ops.cutlass_moe_mm(c3, a2q, w2, a2q_scale, w2_scale, - expert_offsets[:-1], problem_sizes2, - self.ab_strides2, self.ab_strides2, self.c_strides2) + if is_pplx: + return c3.reshape(local_E, padded_M, K) + else: + return c3[c_map].view(M, topk, K) - c3 = c3[c_map] - return c3 +class CutlassExpertsFp8(mk.FusedMoEPermuteExpertsUnpermute): + + def __init__( + self, + max_experts_per_worker: int, + out_dtype: torch.dtype, + per_act_token: bool, + per_out_ch: bool, + ): + super().__init__() + self.max_experts_per_worker = max_experts_per_worker + self.out_dtype = out_dtype + self.per_act_token = per_act_token + self.per_out_ch = per_out_ch + + def workspace_shapes( + self, + a: torch.Tensor, + aq: torch.Tensor, + M: int, + N: int, + K: int, + topk: int, + num_experts: int, + ) -> tuple[int, int, torch.dtype]: + padded_M = aq.shape[1] + workspace1 = self.max_experts_per_worker * padded_M * max(N, K) + workspace2 = self.max_experts_per_worker * padded_M * (N // 2) + return (workspace1, workspace2, self.out_dtype) + + def apply( + self, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + workspace13: torch.Tensor, + workspace2: torch.Tensor, + expert_num_tokens: Optional[torch.Tensor], + ) -> torch.Tensor: + assert w1_zp is None, "w1_zp is not supported in CUTLASS MoE" + assert w2_zp is None, "w2_zp is not supported in CUTLASS MoE" + activation_callable = lambda i, o: self.activation(activation, i, o) + return run_cutlass_moe_fp8(hidden_states, w1, w2, topk_ids, + activation_callable, global_num_experts, + expert_map, w1_scale, w2_scale, a1q_scale, + a2_scale, workspace13, workspace2, + expert_num_tokens, self.out_dtype, + self.per_act_token, self.per_out_ch) -#TODO make the grouped gemm kernel consistent with scaled gemm kernel def cutlass_moe_fp8( a: torch.Tensor, w1_q: torch.Tensor, w2_q: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, - ab_strides1: torch.Tensor, - c_strides1: torch.Tensor, - ab_strides2: torch.Tensor, - c_strides2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + activation: str = "silu", a1_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None, - out_dtype: torch.dtype = torch.half, expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, + global_num_experts: 
int = -1, ) -> torch.Tensor: """ This function computes a a8w8-quantized Mixture of Experts (MoE) layer @@ -207,25 +286,17 @@ def cutlass_moe_fp8( Shape: [num_experts, K, 2N] (the weights are passed transposed) - w2_q (torch.Tensor): The second set of fp8-quantized expert weights. Shape: [num_experts, N, K] (the weights are passed transposed) + - topk_weights (torch.Tensor): The weights of each token->expert mapping. + - topk_ids (torch.Tensor): The token->expert mappings. - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q. Shape: [num_experts] or [num_experts, 2N] - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q. Shape: [num_experts] or [num_experts, K] - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk_weights (torch.Tensor): The weights of each token->expert mapping. - - ab_strides1 (torch.Tensor): The input and weights strides of the first - grouped gemm. - - c_strides1 (torch.Tensor): The output strides of the first grouped gemm. - - ab_strides2 (torch.Tensor): The input and weights strides of the second - grouped gemm. - - c_strides2 (torch.Tensor): The output strides of the second grouped gemm. - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a. Shape: scalar or [M] - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize the intermediate result between the gemms. Shape: scalar or [M] - - out_dtype (torch.dtype): The output tensor type. - expert_map (Optional[torch.Tensor]): In the case of Expert parallel, every Rank is responsible for a subset of experts. expert_map is a mapping from global expert-id to local expert-id. When expert_map[i] @@ -233,24 +304,27 @@ def cutlass_moe_fp8( expert-id i. - apply_router_weight_on_input (bool): When true, the topk weights are applied directly on the inputs. This is only applicable when topk is 1. + - global_num_experts (int): The total number of experts. Returns: - torch.Tensor: The fp16 output tensor after applying the MoE layer. 
""" per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( a2_scale.numel() != 1 if a2_scale is not None else False) + per_out_ch = w1_scale.numel() != w1_q.shape[0] + + out_dtype = a.dtype fn = mk.FusedMoEModularKernel( MoEPrepareAndFinalizeNoEP( - per_channel_quant=per_act_token, quant_dtype=torch.float8_e4m3fn, + per_channel_quant=per_act_token, ), CutlassExpertsFp8( - ab_strides1, - c_strides1, - ab_strides2, - c_strides2, - out_dtype, + max_experts_per_worker=global_num_experts, + out_dtype=out_dtype, + per_act_token=per_act_token, + per_out_ch=per_out_ch, ), ) @@ -260,9 +334,12 @@ def cutlass_moe_fp8( w2_q, topk_weights, topk_ids, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, + False, + activation, + global_num_experts if global_num_experts != -1 else w1_q.size(0), + expert_map, + w1_scale, + w2_scale, a1_scale=a1_scale, a2_scale=a2_scale, apply_router_weight_on_input=apply_router_weight_on_input, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 97b4a49c064..c00e849b4eb 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -73,6 +73,7 @@ def __init__(self): def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 7490a192df9..68a3485ff1f 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -521,6 +521,7 @@ def __init__( def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, @@ -632,6 +633,7 @@ def __init__( def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index de7a9a8d0b3..ba1498e6531 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1545,6 +1545,7 @@ def __init__( def workspace_shapes( self, a: torch.Tensor, + aq: torch.Tensor, M: int, N: int, K: int, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1812f3b6759..cf8e4ee6509 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -9,6 +9,9 @@ import torch import torch.nn.functional as F +from compressed_tensors.quantization import (QuantizationArgs, + QuantizationStrategy, + QuantizationType) from torch.nn.parameter import UninitializedParameter import vllm.envs as envs @@ -210,6 +213,7 @@ class MoEConfig: moe_parallel_config: FusedMoEParallelConfig in_dtype: torch.dtype # The activation type. + quant_dtype: torch.dtype = None # TODO: add more quantization params, blocked, per-token, etc. 
block_size: int = 128 @@ -264,8 +268,22 @@ class FusedMoeWeightScaleSupported(Enum): BLOCK = "block" +def get_quant_config_input_activations( + quant_config: Optional[QuantizationConfig] +) -> Optional[QuantizationArgs]: + if (quant_config is not None and hasattr(quant_config, 'target_scheme_map') + and "Linear" in quant_config.target_scheme_map and + "input_activations" in quant_config.target_scheme_map["Linear"]): + return quant_config.target_scheme_map["Linear"].get( + "input_activations") + else: + return None + + class FusedMoEMethodBase(QuantizeMethodBase): + moe: MoEConfig + @abstractmethod def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, @@ -277,6 +295,7 @@ def init_prepare_finalize(self, moe: MoEConfig, all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None + self.moe = moe quant_dtype = None act_quant_block_size = None from vllm.model_executor.layers.quantization.fp8 import Fp8Config @@ -297,13 +316,14 @@ def init_prepare_finalize(self, moe: MoEConfig, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, hidden_dim=moe.hidden_dim, - hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize, + hidden_dim_bytes=moe.hidden_dim * moe.quant_dtype.itemsize, # For blocked per token: set to # ceil_div(hidden_dim, block_size) * sizeof(float32) # For per-token: set to sizeof(float32) - hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else ( - (moe.hidden_dim + moe.block_size - 1) // moe.block_size * - torch.float32.itemsize)), + hidden_dim_scale_bytes=( + 0 if moe.quant_dtype.itemsize != 1 else + ((moe.hidden_dim + moe.block_size - 1) // moe.block_size * + torch.float32.itemsize)), ) # Intranode pplx a2a takes a group name while internode does not. 
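A small sketch, assuming a hypothetical helper name, of how the hidden_dim_bytes and hidden_dim_scale_bytes values passed to the PPLX all-to-all follow from the MoE config once quant_dtype is threaded through.

import torch

def _pplx_hidden_dim_bytes(hidden_dim: int, quant_dtype: torch.dtype,
                           block_size: int = 128) -> tuple[int, int]:
    # Hypothetical helper that only spells out the arithmetic used above.
    hidden_dim_bytes = hidden_dim * quant_dtype.itemsize
    if quant_dtype.itemsize != 1:
        # Unquantized (16/32-bit) activations carry no per-token scales.
        hidden_dim_scale_bytes = 0
    else:
        # fp8 activations: ceil_div(hidden_dim, block_size) float32 scales.
        hidden_dim_scale_bytes = ((hidden_dim + block_size - 1) // block_size *
                                  torch.float32.itemsize)
    return hidden_dim_bytes, hidden_dim_scale_bytes

# e.g. hidden_dim=4096 with torch.float8_e4m3fn gives (4096, 32 * 4).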
@@ -313,6 +333,9 @@ def init_prepare_finalize(self, moe: MoEConfig,

             handle = all2all_manager.get_handle(all_to_all_args)

+            input_activations = get_quant_config_input_activations(
+                quant_config)
+
             prepare_finalize = PplxPrepareAndFinalize(
                 handle,
                 max_num_tokens=moe.max_num_tokens,
@@ -320,7 +343,10 @@ def init_prepare_finalize(self, moe: MoEConfig,
                 rank=all2all_manager.rank,
                 # dp_size actually means tp_size, bug in pplx kernels
                 dp_size=all2all_manager.tp_group.world_size,
-                quant_dtype=moe.in_dtype,
+                quant_dtype=moe.quant_dtype,
+                per_act_token=(input_activations.strategy
+                               == QuantizationStrategy.TOKEN
+                               if input_activations is not None else False),
             )
         elif moe.use_deepep_ht_kernels:
             assert moe.dp_size == all2all_manager.dp_world_size
@@ -365,15 +391,15 @@ def init_prepare_finalize(self, moe: MoEConfig,
         self.topk_indices_dtype = None
         if prepare_finalize is not None:
             self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
-            experts = self.select_gemm_impl(prepare_finalize)
+            experts = self.select_gemm_impl(prepare_finalize, moe)
             self.fused_experts = FusedMoEModularKernel(
                 prepare_finalize,
                 experts,
             )

     def select_gemm_impl(
-            self, prepare_finalize: FusedMoEPrepareAndFinalize
-    ) -> FusedMoEPermuteExpertsUnpermute:
+            self, prepare_finalize: FusedMoEPrepareAndFinalize,
+            moe: Optional[MoEConfig]) -> FusedMoEPermuteExpertsUnpermute:
         # based on the all2all implementation, select the appropriate
         # gemm implementation
         raise NotImplementedError(
@@ -419,7 +445,8 @@ def __init__(self, moe: MoEConfig):
         else:
             self.rocm_aiter_fused_experts = None  # type: ignore

-    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize):
+    def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize,
+                         moe: Optional[MoEConfig]):

         assert self.fused_experts == fused_experts

@@ -809,7 +836,6 @@ def __init__(
         activation: str = "silu",
     ):
         super().__init__()
-
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()
         self.params_dtype = params_dtype
@@ -869,14 +895,24 @@ def __init__(
             from vllm_hpu_extension.ops import DynamicFusedMOE
             self.hpu_fused_moe = DynamicFusedMOE(self.global_num_experts)

+        # Only support float8 for now.
+        quant_dtype = params_dtype
+        if quant_config is not None:
+            input_activations = get_quant_config_input_activations(
+                quant_config)
+            if (input_activations is not None
+                    and input_activations.num_bits == 8
+                    and input_activations.type == QuantizationType.FLOAT):
+                quant_dtype = torch.float8_e4m3fn
+
         moe = MoEConfig(
             num_experts=self.global_num_experts,
             experts_per_token=top_k,
             hidden_dim=hidden_size,
             num_local_experts=self.local_num_experts,
             moe_parallel_config=self.moe_parallel_config,
-            # TODO (bnell): this needs to be fixed for quantized types.
             in_dtype=params_dtype,
+            quant_dtype=quant_dtype,
             max_num_tokens=MOE_DP_CHUNK_SIZE,
         )
         self.moe_config = moe
diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py
index 2c27d31eb6e..e7aaf62fb34 100644
--- a/vllm/model_executor/layers/fused_moe/modular_kernel.py
+++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -175,6 +175,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -309,7 +310,7 @@ def _do_fused_experts(

         # Use a1 here to decipher the correct workspace datatype
         workspace13_shape, workspace2_shape, workspace_dtype = (
-            self.fused_experts.workspace_shapes(a1, M, N, K, top_k,
+            self.fused_experts.workspace_shapes(a1, a1q, M, N, K, top_k,
                                                 global_num_experts))

         # We can reuse the memory between cache1 and cache3 because by the time
diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
index 1170a16f3de..5bc01dbf202 100644
--- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
+++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py
@@ -21,7 +21,8 @@ def __init__(self,
                  rank: int,
                  dp_size: int,
                  quant_dtype: Optional[torch.dtype] = None,
-                 block_shape: Optional[list[int]] = None):
+                 block_shape: Optional[list[int]] = None,
+                 per_act_token: bool = False):
         super().__init__()
         assert max_num_tokens > 0
         self.a2a = a2a
@@ -31,6 +32,7 @@ def __init__(self,
         self.rank = rank
         self.dp_size = dp_size
         self.quant_dtype = quant_dtype
+        self.per_act_token = per_act_token

     def max_num_tokens_per_rank(self) -> Optional[int]:
         return self.max_num_tokens
@@ -66,13 +68,14 @@ def prepare(
                 "apply_router_weight_on_input is only implemented for topk=1")
             a1 = a1 * rank_topk_weights.to(a1.dtype)

-        per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
-            a2_scale.numel() != 1 if a2_scale is not None else False)
+        repeat_cols = 4
+        repeat_rows = 1 if self.per_act_token else a1.shape[0]
+        a1q, a1q_scale = moe_kernel_quantize_input(
+            a1, (None if self.per_act_token else a1_scale), self.quant_dtype,
+            self.per_act_token, self.block_shape)

-        a1q, a1q_scale = moe_kernel_quantize_input(a1, a1_scale,
-                                                   self.quant_dtype,
-                                                   per_act_token,
-                                                   self.block_shape)
+        if a1q_scale is not None:
+            a1q_scale = a1q_scale.repeat(repeat_rows, repeat_cols)

         # rem_experts need to be 0 for pplx to work properly.
         rem_experts = num_experts % self.world_size
@@ -100,7 +103,7 @@ def prepare(
                           else 1) * float32_size
             expert_x_scale = torch.empty(
                 (
-                    num_experts,
+                    num_local_experts,
                     expert_x.size(1),
                     (expert_x.size(2) + block_size - 1) // block_size,
                 ),
@@ -121,6 +124,8 @@ def prepare(
             indices=rank_topk_ids,
             bound_m=bound_m,
         )
+        if expert_x_scale is not None:
+            expert_x_scale = expert_x_scale[:, :, 0:1]

         return expert_x, expert_x_scale, expert_num_tokens, None, None
diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
index 920931a93d3..87de29444c0 100644
--- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py
@@ -37,6 +37,7 @@ def __init__(self,
     def workspace_shapes(
         self,
         a: torch.Tensor,
+        aq: torch.Tensor,
         M: int,
         N: int,
         K: int,
@@ -49,9 +50,9 @@ def workspace_shapes(
         if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K):
             assert self.deep_gemm_expert is not None
             return self.deep_gemm_expert.workspace_shapes(
-                a, M, N, K, topk, num_experts)
+                a, aq, M, N, K, topk, num_experts)
         else:
-            return self.triton_expert.workspace_shapes(a, M, N, K, topk,
+            return self.triton_expert.workspace_shapes(a, aq, M, N, K, topk,
                                                        num_experts)

     def apply(
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index ebb029572a1..bc9d399cf13 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -2,6 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import enum
+import importlib
 from enum import Enum
 from typing import Callable, Optional
@@ -11,7 +12,6 @@
                                               QuantizationStrategy)

 import vllm.envs as envs
-import vllm.model_executor.layers.fused_moe  # noqa
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
@@ -30,6 +30,15 @@
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types

+has_pplx = importlib.util.find_spec("pplx_kernels") is not None
+
+if current_platform.is_cuda_alike():
+    from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+        BatchedPrepareAndFinalize)
+    if has_pplx:
+        from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
+            PplxPrepareAndFinalize)
+
 logger = init_logger(__name__)
@@ -77,8 +86,7 @@ def get_moe_method(
         else:
             logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
             return CompressedTensorsWNA16MarlinMoEMethod(quant_config)
-    elif (quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant)
-          and layer.activation == "silu"):
+    elif quant_config._is_fp8_w8a8_sm90(weight_quant, input_quant):
         return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
     elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
         return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -421,6 +429,11 @@ def __init__(
                 "For FP8 Fused MoE layer, we require either per tensor or "
                 "channelwise, dynamic per token quantization.")

+        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+            cutlass_moe_fp8)
+        self.fused_experts = cutlass_moe_fp8  # type: ignore
+        self.disable_expert_map = False
+
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int,
                        intermediate_size_per_partition: int,
                        params_dtype: torch.dtype, **extra_weight_attrs):
@@ -499,25 +512,6 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         layer.w13_input_scale = None
         layer.w2_input_scale = None

-        device = w13_weight.device
-        # TODO strides can be shared across multiple layers
-        self.ab_strides1 = torch.full((num_experts, ),
-                                      hidden_size,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides1 = torch.full((num_experts, ),
-                                     2 * intermediate_size_per_partition,
-                                     device=device,
-                                     dtype=torch.int64)
-        self.ab_strides2 = torch.full((num_experts, ),
-                                      intermediate_size_per_partition,
-                                      device=device,
-                                      dtype=torch.int64)
-        self.c_strides2 = torch.full((num_experts, ),
-                                     hidden_size,
-                                     device=device,
-                                     dtype=torch.int64)
-
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
@@ -558,6 +552,27 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                     requires_grad=False)

+    def select_gemm_impl(self, prepare_finalize, moe):
+        from vllm.model_executor.layers.fused_moe.cutlass_moe import (
+            CutlassExpertsFp8)
+
+        assert moe is not None
+
+        max_experts_per_worker = (
+            (moe.num_experts + prepare_finalize.world_size - 1) //
+            prepare_finalize.world_size)
+        experts = CutlassExpertsFp8(
+            max_experts_per_worker, moe.in_dtype,
+            self.input_quant.strategy == QuantizationStrategy.TOKEN,
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL)
+
+        if has_pplx and isinstance(
+                prepare_finalize,
+                (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
+            # no expert_map support in this case
+            self.disable_expert_map = True
+        return experts
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -577,9 +592,6 @@ def apply(
         activation: str = "silu",
     ) -> torch.Tensor:

-        assert activation == "silu", (
-            f"{activation} not supported for Cutlass MoE.")
-
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
@@ -590,27 +602,22 @@ def apply(
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias)
-
-        from vllm.model_executor.layers.fused_moe import cutlass_moe_fp8
+            e_score_correction_bias=e_score_correction_bias,
+            indices_type=torch.uint32)

-        return cutlass_moe_fp8(
+        return self.fused_experts(
             x,
-            layer.w13_weight.transpose(1, 2),
-            layer.w2_weight.transpose(1, 2),
-            layer.w13_weight_scale,
-            layer.w2_weight_scale,
+            layer.w13_weight,
+            layer.w2_weight,
             topk_weights,
             topk_ids,
-            self.ab_strides1,
-            self.c_strides1,
-            self.ab_strides2,
-            self.c_strides2,
+            activation=activation,
+            global_num_experts=global_num_experts,
+            expert_map=None if self.disable_expert_map else expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
-            out_dtype=x.dtype,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
         )
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 5ac22b6a0ae..c785e0d1674 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -769,7 +769,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
             del layer.w13_input_scale
             del layer.w2_input_scale

-    def select_gemm_impl(self, prepare_finalize):
+    def select_gemm_impl(self, prepare_finalize, moe):
         from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
             BatchedTritonOrDeepGemmExperts)
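
For reference, a small standalone sketch of how the quantization flags and expert sharding introduced by this patch are derived in cutlass_moe_fp8 and select_gemm_impl; the tensor shapes and world_size below are invented for illustration:

    import torch

    # Invented example shapes.
    num_experts, num_tokens = 8, 16
    w1_scale = torch.ones(num_experts, 1, 1)   # per-tensor weight scales
    a1_scale = torch.ones(num_tokens, 1)       # per-token activation scales

    # Mirrors cutlass_moe_fp8 above: per-token activation quantization iff the
    # activation scale has more than one element; per-output-channel weight
    # quantization iff the weight scale has more elements than experts.
    per_act_token = a1_scale.numel() != 1
    per_out_ch = w1_scale.numel() != num_experts

    # Mirrors select_gemm_impl above: ceil-divide the global expert count by
    # the all-to-all world size to size CutlassExpertsFp8 per worker.
    world_size = 4
    max_experts_per_worker = (num_experts + world_size - 1) // world_size

    assert (per_act_token, per_out_ch, max_experts_per_worker) == (True, False, 2)
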