diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 97dcc42312f..bbc896ec681 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -82,6 +82,14 @@ if [[ $commands == *"pytest -v -s basic_correctness/test_basic_correctness.py"* commands=${commands//"pytest -v -s basic_correctness/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s basic_correctness/test_basic_correctness.py"} fi +if [[ $commands == *"pytest -v -s models/test_registry.py"* ]]; then + commands=${commands//"pytest -v -s models/test_registry.py"/"pytest -v -s models/test_registry.py -k 'not BambaForCausalLM and not GritLM and not Mamba2ForCausalLM and not Zamba2ForCausalLM'"} +fi + +if [[ $commands == *"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"* ]]; then + commands=${commands//"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2'"/"VLLM_USE_V1=0 pytest -v -s models/test_initialization.py -k 'not llama4 and not plamo2 and not BambaForCausalLM and not Gemma2ForCausalLM and not Grok1ModelForCausalLM and not Zamba2ForCausalLM and not Gemma2Model and not GritLM'"} +fi + if [[ $commands == *"pytest -v -s compile/test_basic_correctness.py"* ]]; then commands=${commands//"pytest -v -s compile/test_basic_correctness.py"/"VLLM_USE_TRITON_FLASH_ATTN=0 pytest -v -s compile/test_basic_correctness.py"} fi diff --git a/.buildkite/scripts/upload-wheels.sh b/.buildkite/scripts/upload-wheels.sh index 75e3ef26409..037897e53db 100644 --- a/.buildkite/scripts/upload-wheels.sh +++ b/.buildkite/scripts/upload-wheels.sh @@ -75,3 +75,4 @@ else fi aws s3 cp "$wheel" "s3://vllm-wheels/$version/" +aws s3 cp index.html "s3://vllm-wheels/$version/vllm/index.html" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 1040d1e1b80..461fb6d30c4 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -148,6 +148,8 @@ steps: # test with tp=2 and external_dp=2 - VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py - torchrun --nproc-per-node=4 distributed/test_torchrun_example.py + # test with tp=2 and pp=2 + - PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py # test with internal dp - python3 ../examples/offline_inference/data_parallel.py - TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py @@ -216,7 +218,6 @@ steps: - pytest -v -s v1/spec_decode - pytest -v -s v1/kv_connector/unit - pytest -v -s v1/test_serial_utils.py - - pytest -v -s v1/test_stats.py - pytest -v -s v1/test_utils.py - pytest -v -s v1/test_oracle.py # TODO: accuracy does not match, whether setting @@ -380,7 +381,7 @@ steps: - pytest -v -s kernels/mamba - label: Tensorizer Test # 11min - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] soft_fail: true source_file_dependencies: - vllm/model_executor/model_loader @@ -456,7 +457,7 @@ steps: ##### models test ##### - label: Basic Models Test # 24min - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] torch_nightly: true source_file_dependencies: - vllm/ @@ -528,7 +529,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model' - label: Multi-Modal Models Test (Extended) 3 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] optional: true 
source_file_dependencies: - vllm/ @@ -538,7 +539,7 @@ steps: - pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model' - label: Quantized Models Test - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] source_file_dependencies: - vllm/model_executor/layers/quantization - tests/models/quantization diff --git a/CMakeLists.txt b/CMakeLists.txt index fed6e11e5ef..a6c54be9530 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -301,7 +301,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Only build Marlin kernels if we are building for at least some compatible archs. # Keep building Marlin for 9.0 as there are some group sizes and shapes that # are not supported by Machete yet. - cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + # 9.0 for latest bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_ARCHS) # @@ -445,8 +446,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) # kernels for the remaining archs that are not already built for 3x. + # (Build 8.9 for FP8) cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + "7.5;8.0;8.9+PTX" "${CUDA_ARCHS}") # subtract out the archs that are already built for 3x list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) if (SCALED_MM_2X_ARCHS) @@ -675,7 +677,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") list(APPEND VLLM_MOE_EXT_SRC "${VLLM_MOE_WNA16_SRC}") - cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.7;8.9;9.0;10.0;10.1;12.0" "${CUDA_ARCHS}") + # 9.0 for latest bf16 atomicAdd PTX + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;9.0+PTX" "${CUDA_ARCHS}") if (MARLIN_MOE_ARCHS) # diff --git a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py index 504c5f5812e..08e93837f7d 100644 --- a/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py +++ b/benchmarks/cutlass_benchmarks/w8a8_benchmarks.py @@ -115,8 +115,16 @@ def bench_fp8( a_cont = a.contiguous() scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32) scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32) - block_scale_a = torch.rand((m, k // 128), device="cuda", dtype=torch.float32) - block_scale_b = torch.rand((k // 128, n // 128), device="cuda", dtype=torch.float32) + + def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + block_scale_a = torch.rand( + (m, ceil_div(k, 128)), device="cuda", dtype=torch.float32 + ) + block_scale_b = torch.rand( + ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32 + ) block_scale_a_M_major = block_scale_a.t().contiguous().t() block_scale_b_K_major = block_scale_b.t().contiguous().t() bias = torch.zeros((n,), device="cuda", dtype=torch.bfloat16) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index c9cd099b82a..12e4e39024f 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -228,11 +228,26 @@ macro(set_gencode_flags_for_srcs) "${multiValueArgs}" ${ARGN} ) foreach(_ARCH ${arg_CUDA_ARCHS}) - string(REPLACE "." "" _ARCH "${_ARCH}") - set_gencode_flag_for_srcs( - SRCS ${arg_SRCS} - ARCH "compute_${_ARCH}" - CODE "sm_${_ARCH}") + # handle +PTX suffix: generate both sm and ptx codes if requested + string(FIND "${_ARCH}" "+PTX" _HAS_PTX) + if(NOT _HAS_PTX EQUAL -1) + string(REPLACE "+PTX" "" _BASE_ARCH "${_ARCH}") + string(REPLACE "." 
"" _STRIPPED_ARCH "${_BASE_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "compute_${_STRIPPED_ARCH}") + else() + string(REPLACE "." "" _STRIPPED_ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_STRIPPED_ARCH}" + CODE "sm_${_STRIPPED_ARCH}") + endif() endforeach() if (${arg_BUILD_PTX_FOR_ARCH}) @@ -251,7 +266,10 @@ endmacro() # # For the given `SRC_CUDA_ARCHS` list of gencode versions in the form # `.[letter]` compute the "loose intersection" with the -# `TGT_CUDA_ARCHS` list of gencodes. +# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in +# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there +# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the +# architecture in `SRC_CUDA_ARCHS`. # The loose intersection is defined as: # { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } # where `<=` is the version comparison operator. @@ -268,44 +286,63 @@ endmacro() # cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) # OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" # +# Example With PTX: +# SRC_CUDA_ARCHS="8.0+PTX" +# TGT_CUDA_ARCHS="9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0+PTX" +# function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) - list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) - set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS}) + set(_SRC_CUDA_ARCHS "${SRC_CUDA_ARCHS}") + set(_TGT_CUDA_ARCHS ${TGT_CUDA_ARCHS}) + + # handle +PTX suffix: separate base arch for matching, record PTX requests + set(_PTX_ARCHS) + foreach(_arch ${_SRC_CUDA_ARCHS}) + if(_arch MATCHES "\\+PTX$") + string(REPLACE "+PTX" "" _base "${_arch}") + list(APPEND _PTX_ARCHS "${_base}") + list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}") + list(APPEND _SRC_CUDA_ARCHS "${_base}") + endif() + endforeach() + list(REMOVE_DUPLICATES _PTX_ARCHS) + list(REMOVE_DUPLICATES _SRC_CUDA_ARCHS) # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS set(_CUDA_ARCHS) - if ("9.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") - if ("9.0" IN_LIST TGT_CUDA_ARCHS_) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "9.0") + if ("9.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0") set(_CUDA_ARCHS "9.0a") endif() endif() - if ("10.0a" IN_LIST SRC_CUDA_ARCHS) - list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a") + if ("10.0a" IN_LIST _SRC_CUDA_ARCHS) + list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a") if ("10.0" IN_LIST TGT_CUDA_ARCHS) - list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0") + list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0") set(_CUDA_ARCHS "10.0a") endif() endif() - list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that # is less or equal to ARCH (but has the same major version since SASS binary # compatibility is only forward compatible within the same major version). 
- foreach(_ARCH ${TGT_CUDA_ARCHS_}) + foreach(_ARCH ${_TGT_CUDA_ARCHS}) set(_TMP_ARCH) # Extract the major version of the target arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" TGT_ARCH_MAJOR "${_ARCH}") - foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + foreach(_SRC_ARCH ${_SRC_CUDA_ARCHS}) # Extract the major version of the source arch string(REGEX REPLACE "^([0-9]+)\\..*$" "\\1" SRC_ARCH_MAJOR "${_SRC_ARCH}") - # Check major-version match AND version-less-or-equal + # Check version-less-or-equal, and allow PTX arches to match across majors if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) - if (SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) + if (_SRC_ARCH IN_LIST _PTX_ARCHS OR SRC_ARCH_MAJOR STREQUAL TGT_ARCH_MAJOR) set(_TMP_ARCH "${_SRC_ARCH}") endif() else() @@ -321,6 +358,18 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR endforeach() list(REMOVE_DUPLICATES _CUDA_ARCHS) + + # reapply +PTX suffix to architectures that requested PTX + set(_FINAL_ARCHS) + foreach(_arch ${_CUDA_ARCHS}) + if(_arch IN_LIST _PTX_ARCHS) + list(APPEND _FINAL_ARCHS "${_arch}+PTX") + else() + list(APPEND _FINAL_ARCHS "${_arch}") + endif() + endforeach() + set(_CUDA_ARCHS ${_FINAL_ARCHS}) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) endfunction() diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 88275dbdd83..55e65967970 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -70,6 +70,9 @@ __device__ __forceinline__ T gelu_tanh_kernel(const T& x) { int64_t num_tokens = input.numel() / input.size(-1); \ dim3 grid(num_tokens); \ dim3 block(std::min(d, 1024)); \ + if (num_tokens == 0) { \ + return; \ + } \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ VLLM_DISPATCH_FLOATING_TYPES( \ diff --git a/csrc/attention/attention_kernels.cuh b/csrc/attention/attention_kernels.cuh index eb216dc8baf..79a546554fa 100644 --- a/csrc/attention/attention_kernels.cuh +++ b/csrc/attention/attention_kernels.cuh @@ -172,7 +172,7 @@ __device__ void paged_attention_kernel( // Load the query to registers. // Each thread in a thread group has a different part of the query. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the query, and the second thread // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because // q is split from a qkv tensor, it may not be contiguous. @@ -259,7 +259,7 @@ __device__ void paged_attention_kernel( // Load a key to registers. // Each thread in a thread group has a different part of the key. - // For example, if the the thread group size is 4, then the first thread in + // For example, if the thread group size is 4, then the first thread in // the group has 0, 4, 8, ... th vectors of the key, and the second thread // has 1, 5, 9, ... th vectors of the key, and so on. for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { diff --git a/csrc/dispatch_utils.h b/csrc/dispatch_utils.h index dc6e0769b87..f7b75c48373 100644 --- a/csrc/dispatch_utils.h +++ b/csrc/dispatch_utils.h @@ -65,5 +65,19 @@ AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) +#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) 
\ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__) + #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__)) diff --git a/csrc/moe/moe_align_sum_kernels.cu b/csrc/moe/moe_align_sum_kernels.cu index d7be769458e..6b6a9d04a60 100644 --- a/csrc/moe/moe_align_sum_kernels.cu +++ b/csrc/moe/moe_align_sum_kernels.cu @@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, } if (use_global_memory) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] { // calc needed amount of shared mem for `tokens_cnts` and `cumsum` // tensors @@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, cumsum_buffer.data_ptr()); }); } else if (use_i16) { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { // set dynamic shared mem auto kernel = @@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, topk_ids.numel()); }); } else { - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { auto kernel = vllm::moe::moe_align_block_size_kernel; @@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, TORCH_CHECK(num_experts == 256, "sgl_moe_align_block_size kernel only supports deepseek v3."); - VLLM_DISPATCH_INTEGRAL_TYPES( + VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES( topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] { // calc needed amount of shared mem for `cumsum` tensors auto options_int = diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index de9747b6025..a9379032245 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__ } } -template -__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, - int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert) +template +__launch_bounds__(TPB) __global__ void moeTopK( + const float* inputs_after_softmax, + const bool* finished, + float* output, + IndType* indices, + int* source_rows, + const int num_experts, + const int k, + const int start_expert, + const int end_expert) { using cub_kvp = cub::KeyValuePair; @@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax 2) This implementation assumes k is small, but will work for any k. 
*/ -template +template __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__ - void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices, + void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices, int* source_rows, const int k, const int start_expert, const int end_expert) { // We begin by enforcing compile time assertions and setting up compile time constants. @@ -397,8 +405,8 @@ struct TopkConstants }; } // namespace detail -template -void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices, +template +void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices, int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream) { static constexpr std::size_t MAX_BYTES_PER_LDG = 16; @@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f token_expert_indices, num_tokens, topk, 0, num_experts, \ stream); +template void topkGatingSoftmaxKernelLauncher( const float* gating_output, float* topk_weights, - int* topk_indicies, + IndType* topk_indicies, int* token_expert_indices, float* softmax_workspace, const int num_tokens, @@ -493,14 +502,32 @@ void topk_softmax( const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); - vllm::moe::topkGatingSoftmaxKernelLauncher( - gating_output.data_ptr(), - topk_weights.data_ptr(), - topk_indices.data_ptr(), - token_expert_indices.data_ptr(), - softmax_workspace.data_ptr(), - num_tokens, - num_experts, - topk, - stream); + + if(topk_indices.scalar_type() == at::ScalarType::Int) + { + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } + else + { + assert(topk_indices.scalar_type() == at::ScalarType::UInt32); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/csrc/pos_encoding_kernels.cu b/csrc/pos_encoding_kernels.cu index ef6dd1c0978..266f2a0667a 100644 --- a/csrc/pos_encoding_kernels.cu +++ b/csrc/pos_encoding_kernels.cu @@ -44,7 +44,8 @@ inline __device__ void apply_rotary_embedding( // head_size] const scalar_t* cache_ptr, const int head_size, const int num_heads, const int num_kv_heads, const int rot_dim, const int token_idx, - const int64_t query_stride, const int64_t key_stride) { + const int64_t query_stride, const int64_t key_stride, + const int64_t head_stride) { const int embed_dim = rot_dim / 2; const scalar_t* cos_ptr = cache_ptr; const scalar_t* sin_ptr = cache_ptr + embed_dim; @@ -52,7 +53,8 @@ inline __device__ void apply_rotary_embedding( const int nq = num_heads * embed_dim; for (int i = threadIdx.x; i < nq; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int64_t token_head = + token_idx * query_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; 
apply_token_rotary_embedding( query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -62,7 +64,8 @@ inline __device__ void apply_rotary_embedding( const int nk = num_kv_heads * embed_dim; for (int i = threadIdx.x; i < nk; i += blockDim.x) { const int head_idx = i / embed_dim; - const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int64_t token_head = + token_idx * key_stride + head_idx * head_stride; const int rot_offset = i % embed_dim; apply_token_rotary_embedding( key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); @@ -84,7 +87,8 @@ __global__ void rotary_embedding_kernel( const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -92,7 +96,7 @@ __global__ void rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } template @@ -109,9 +113,9 @@ __global__ void batched_rotary_embedding_kernel( const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // // 2] const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] - // or [num_tokens] const int rot_dim, const int64_t query_stride, const int64_t key_stride, - const int num_heads, const int num_kv_heads, const int head_size) { + const int64_t head_stride, const int num_heads, const int num_kv_heads, + const int head_size) { // Each thread block is responsible for one token. const int token_idx = blockIdx.x; int64_t pos = positions[token_idx]; @@ -121,7 +125,7 @@ __global__ void batched_rotary_embedding_kernel( apply_rotary_embedding( query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, - token_idx, query_stride, key_stride); + token_idx, query_stride, key_stride, head_stride); } } // namespace vllm @@ -179,6 +183,12 @@ void rotary_embedding( int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -190,14 +200,14 @@ void rotary_embedding( positions.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), rot_dim, query_stride, key_stride, - num_heads, num_kv_heads, head_size); + head_stride, num_heads, num_kv_heads, head_size); } else { vllm::rotary_embedding_kernel <<>>( positions.data_ptr(), query.data_ptr(), key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } @@ -263,6 +273,12 @@ void batched_rotary_embedding( int seq_dim_idx = positions_ndim - 1; int64_t query_stride = query.stride(seq_dim_idx); int64_t key_stride = key.has_value() ? 
key->stride(seq_dim_idx) : 0; + // Determine head stride: for [*, heads, head_size] use stride of last dim; + // for flat [*, heads*head_size], heads blocks are contiguous of size + // head_size + int query_ndim = query.dim(); + int64_t head_stride = + (query_ndim == positions_ndim + 2) ? query.stride(-2) : head_size; dim3 grid(num_tokens); dim3 block(std::min(num_heads * rot_dim / 2, 512)); @@ -276,7 +292,7 @@ void batched_rotary_embedding( key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } else { vllm::batched_rotary_embedding_kernel <<>>( @@ -284,7 +300,7 @@ void batched_rotary_embedding( key.has_value() ? key->data_ptr() : nullptr, cos_sin_cache.data_ptr(), cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, - key_stride, num_heads, num_kv_heads, head_size); + key_stride, head_stride, num_heads, num_kv_heads, head_size); } }); } diff --git a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp index b589a479081..2ee6a19407f 100644 --- a/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp +++ b/csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp @@ -1,5 +1,6 @@ #include #include "cuda_utils.h" +#include "cutlass_extensions/common.hpp" template void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, @@ -28,29 +29,46 @@ void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, } } } else { - using GroupShape = std::array; - auto make_group_shape = [](torch::Tensor const& x, - torch::Tensor const& s) -> GroupShape { - TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); - return {cuda_utils::ceil_div(x.size(0), s.size(0)), - cuda_utils::ceil_div(x.size(1), s.size(1))}; - }; + TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + int32_t version_num = get_sm_version_num(); + if (version_num >= 100) { + TORCH_CHECK( + a.size(0) == a_scales.size(0) && + cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), + "a_scale_group_shape must be [1, 128]."); + TORCH_CHECK( + cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && + cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), + "b_scale_group_shape must be [128, 128]."); + } else { + // TODO: Remove this after using cutlass sm90 blockwise scaling gemm + // kernel, or introducing ceil_div to the load_init() of mainloop. 
+ using GroupShape = std::array; + auto make_group_shape = [](torch::Tensor const& x, + torch::Tensor const& s) -> GroupShape { + TORCH_CHECK(s.dim() == 2, "cutlass_scaled_mm group scales must be 2D"); + return {cuda_utils::ceil_div(x.size(0), s.size(0)), + cuda_utils::ceil_div(x.size(1), s.size(1))}; + }; + + GroupShape a_scale_group_shape = make_group_shape(a, a_scales); + GroupShape b_scale_group_shape = make_group_shape(b, b_scales); - GroupShape a_scale_group_shape = make_group_shape(a, a_scales); - GroupShape b_scale_group_shape = make_group_shape(b, b_scales); + // 1x128 per-token group scales for activations + // 128x128 blockwise scales for weights + TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && + b_scale_group_shape == GroupShape{128, 128} && + a.dtype() == torch::kFloat8_e4m3fn && + b.dtype() == torch::kFloat8_e4m3fn), + "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" + "a_scale_group_shape must be [1, 128]. Got: [", + a_scale_group_shape[0], ", ", a_scale_group_shape[1], + "]\n" + "b_scale_group_shape must be [128, 128]. Got: [", + b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); + } - // 1x128 per-token group scales for activations - // 128x128 blockwise scales for weights - TORCH_CHECK((a_scale_group_shape == GroupShape{1, 128} && - b_scale_group_shape == GroupShape{128, 128} && - a.dtype() == torch::kFloat8_e4m3fn && - b.dtype() == torch::kFloat8_e4m3fn), - "cutlass_scaled_mm only supports datatype float8_e4m3fn.\n" - "a_scale_group_shape must be [1, 128]. Got: [", - a_scale_group_shape[0], ", ", a_scale_group_shape[1], - "]\n" - "b_scale_group_shape must be [128, 128]. Got: [", - b_scale_group_shape[0], ", ", b_scale_group_shape[1], "]"); TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); blockwise_func(c, a, b, a_scales, b_scales); } diff --git a/docs/source/deployment/frameworks/index.md b/docs/source/deployment/frameworks/index.md index 6708f2c4135..9744f5f4d36 100644 --- a/docs/source/deployment/frameworks/index.md +++ b/docs/source/deployment/frameworks/index.md @@ -10,6 +10,7 @@ chatbox dify dstack helm +lobe-chat lws modal open-webui diff --git a/docs/source/deployment/frameworks/lobe-chat.md b/docs/source/deployment/frameworks/lobe-chat.md new file mode 100644 index 00000000000..6d86b7fa9cc --- /dev/null +++ b/docs/source/deployment/frameworks/lobe-chat.md @@ -0,0 +1,13 @@ +(deployment-lobe-chat)= + +# Lobe Chat + +[Lobe Chat](https://github.com/lobehub/lobe-chat) is an open-source, modern-design ChatGPT/LLMs UI/Framework. + +Supports speech-synthesis, multi-modal, and extensible (function call) plugin system. + +One-click FREE deployment of your private OpenAI ChatGPT/Claude/Gemini/Groq/Ollama chat application. + +It supports vLLM as a AI model provider to efficiently serve large language models. + +For details, see the tutorial [Using vLLM in LobeChat](https://lobehub.com/docs/usage/providers/vllm). diff --git a/docs/source/features/reasoning_outputs.md b/docs/source/features/reasoning_outputs.md index 4759d0c26c3..3c2571298e4 100644 --- a/docs/source/features/reasoning_outputs.md +++ b/docs/source/features/reasoning_outputs.md @@ -141,10 +141,10 @@ Remember to check whether the `reasoning_content` exists in the response before The reasoning content is also available in the structured output. The structured output engine like `xgrammar` will use the reasoning content to generate structured output. It is only supported in v0 engine now. 
```bash -VLLM_USE_V1=0 vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 +vllm serve deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B --reasoning-parser deepseek_r1 ``` -Please note that the `VLLM_USE_V1` environment variable must be set to `0` to use the v0 engine. +The following is an example client: ```python from openai import OpenAI diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 25189b006c2..298ba59f7d8 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -19,8 +19,8 @@ If you are using NVIDIA GPUs, you can install vLLM using [pip](https://pypi.org/ It's recommended to use [uv](https://docs.astral.sh/uv/), a very fast Python environment manager, to create and manage Python environments. Please follow the [documentation](https://docs.astral.sh/uv/#getting-started) to install `uv`. After installing `uv`, you can create a new Python environment and install vLLM using the following commands: ```console -uv venv myenv --python 3.12 --seed -source myenv/bin/activate +uv venv --python 3.12 --seed +source .venv/bin/activate uv pip install vllm ``` diff --git a/examples/offline_inference/chat_with_tools.py b/examples/offline_inference/chat_with_tools.py index 15519bfed9c..b532bf42adf 100644 --- a/examples/offline_inference/chat_with_tools.py +++ b/examples/offline_inference/chat_with_tools.py @@ -68,7 +68,7 @@ def get_current_weather(city: str, state: str, unit: 'str'): "partly cloudly, with highs in the 90's.") -tool_funtions = {"get_current_weather": get_current_weather} +tool_functions = {"get_current_weather": get_current_weather} tools = [{ "type": "function", @@ -122,7 +122,7 @@ def get_current_weather(city: str, state: str, unit: 'str'): # above defined function tool_calls = json.loads(output) tool_answers = [ - tool_funtions[call['name']](**call['arguments']) for call in tool_calls + tool_functions[call['name']](**call['arguments']) for call in tool_calls ] # append the answer as a tool message and let the LLM give you an answer diff --git a/examples/offline_inference/data_parallel.py b/examples/offline_inference/data_parallel.py index 965915beaf5..f636a08c0b0 100644 --- a/examples/offline_inference/data_parallel.py +++ b/examples/offline_inference/data_parallel.py @@ -65,11 +65,17 @@ def parse_args(): type=int, default=0, help="Master node port") + parser.add_argument("--enforce-eager", + action='store_true', + help="Enforce eager mode execution.") + parser.add_argument("--trust-remote-code", + action='store_true', + help="Trust remote code.") return parser.parse_args() def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, - dp_master_port, GPUs_per_dp_rank): + dp_master_port, GPUs_per_dp_rank, enforce_eager, trust_remote_code): os.environ["VLLM_DP_RANK"] = str(global_dp_rank) os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank) os.environ["VLLM_DP_SIZE"] = str(dp_size) @@ -109,10 +115,13 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, max_tokens=[16, 20][global_dp_rank % 2]) # Create an LLM. - llm = LLM(model=model, - tensor_parallel_size=GPUs_per_dp_rank, - enforce_eager=True, - enable_expert_parallel=True) + llm = LLM( + model=model, + tensor_parallel_size=GPUs_per_dp_rank, + enforce_eager=enforce_eager, + enable_expert_parallel=True, + trust_remote_code=trust_remote_code, + ) outputs = llm.generate(prompts, sampling_params) # Print the outputs. 
for i, output in enumerate(outputs): @@ -155,7 +164,8 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, proc = Process(target=main, args=(args.model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip, dp_master_port, - tp_size)) + tp_size, args.enforce_eager, + args.trust_remote_code)) proc.start() procs.append(proc) exit_code = 0 diff --git a/examples/offline_inference/disaggregated-prefill-v1/README.md b/examples/offline_inference/disaggregated-prefill-v1/README.md new file mode 100644 index 00000000000..f708eb25383 --- /dev/null +++ b/examples/offline_inference/disaggregated-prefill-v1/README.md @@ -0,0 +1,9 @@ +# Disaggregated Prefill V1 + +This example contains scripts that demonstrate disaggregated prefill in the offline setting of vLLM. + +## Files + +- `run.sh` - A helper script that will run `prefill_example.py` and `decode_example.py` sequentially. +- `prefill_example.py` - A script which performs prefill only, saving the KV state to the `local_storage` directory and the prompts to `output.txt`. +- `decode_example.py` - A script which performs decode only, loading the KV state from the `local_storage` directory and the prompts from `output.txt`. diff --git a/examples/offline_inference/eagle.py b/examples/offline_inference/eagle.py index 020521611f3..615f67e9f8d 100644 --- a/examples/offline_inference/eagle.py +++ b/examples/offline_inference/eagle.py @@ -105,6 +105,13 @@ def main(): outputs = llm.generate(prompt_token_ids=prompt_ids, sampling_params=sampling_params) + # print the generated text + for output in outputs: + print("-" * 50) + print(f"prompt: {output.prompt}") + print(f"generated text: {output.outputs[0].text}") + print("-" * 50) + if not hasattr(outputs, "metrics") or outputs.metrics is None: return diff --git a/examples/offline_inference/openai/openai_batch.md b/examples/offline_inference/openai_batch/README.md similarity index 94% rename from examples/offline_inference/openai/openai_batch.md rename to examples/offline_inference/openai_batch/README.md index d271573aa96..42a19f71e9d 100644 --- a/examples/offline_inference/openai/openai_batch.md +++ b/examples/offline_inference/openai_batch/README.md @@ -8,7 +8,7 @@ This is a guide to performing batch inference using the OpenAI batch file format The OpenAI batch file format consists of a series of json objects on new lines. -[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai/openai_example_batch.jsonl) +[See here for an example file.](https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl) Each line represents a separate request. See the [OpenAI package reference](https://platform.openai.com/docs/api-reference/batch/requestInput) for more details. @@ -30,13 +30,13 @@ We currently support `/v1/chat/completions`, `/v1/embeddings`, and `/v1/score` e To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
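If you prefer to generate the batch file programmatically instead of downloading it, the sketch below writes a file in the same format as the `cat` output shown in the next step (the model name and prompts are placeholders):

```python
# Write a small OpenAI-style batch file; mirrors the example lines shown below.
import json

questions = ["Hello world!", "What is vLLM?"]  # placeholder prompts
with open("openai_example_batch.jsonl", "w") as f:
    for i, question in enumerate(questions, start=1):
        request = {
            "custom_id": f"request-{i}",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "meta-llama/Meta-Llama-3-8B-Instruct",
                "messages": [
                    {"role": "system", "content": "You are a helpful assistant."},
                    {"role": "user", "content": question},
                ],
                "max_completion_tokens": 1000,
            },
        }
        f.write(json.dumps(request) + "\n")
```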
```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -48,7 +48,7 @@ The batch running tool is designed to be used from the command line. You can run the batch with the following command, which will write its results to a file called `results.jsonl` ```console -python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ### Step 3: Check your results @@ -65,10 +65,10 @@ $ cat results.jsonl The batch runner supports remote input and output urls that are accessible via http/https. -For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl`, you can run +For example, to run against our example input file located at `https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl`, you can run ```console -python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct +python -m vllm.entrypoints.openai.run_batch -i https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl -o results.jsonl --model meta-llama/Meta-Llama-3-8B-Instruct ``` ## Example 3: Integrating with AWS S3 @@ -89,13 +89,13 @@ To integrate with cloud blob storage, we recommend using presigned urls. To follow along with this example, you can download the example batch, or create your own batch file in your working directory. 
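The upload later in this step uses the `aws` CLI; as a rough boto3 equivalent (bucket and key names are placeholders, and the usual AWS credential configuration is assumed), you could also upload the file like this:

```python
# Rough boto3 equivalent of the `aws s3 cp` upload shown later in this step.
import boto3

s3 = boto3.client("s3")
s3.upload_file(
    Filename="offline_inference/openai_batch/openai_example_batch.jsonl",
    Bucket="MY_BUCKET",
    Key="MY_INPUT_FILE.jsonl",
)
```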
```console -wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai/openai_example_batch.jsonl +wget https://raw.githubusercontent.com/vllm-project/vllm/main/examples/offline_inference/openai_batch/openai_example_batch.jsonl ``` Once you've created your batch file it should look like this ```console -$ cat offline_inference/openai/openai_example_batch.jsonl +$ cat offline_inference/openai_batch/openai_example_batch.jsonl {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} {"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "meta-llama/Meta-Llama-3-8B-Instruct", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_completion_tokens": 1000}} ``` @@ -103,7 +103,7 @@ $ cat offline_inference/openai/openai_example_batch.jsonl Now upload your batch file to your S3 bucket. ```console -aws s3 cp offline_inference/openai/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl +aws s3 cp offline_inference/openai_batch/openai_example_batch.jsonl s3://MY_BUCKET/MY_INPUT_FILE.jsonl ``` ### Step 2: Generate your presigned urls diff --git a/examples/offline_inference/openai/openai_example_batch.jsonl b/examples/offline_inference/openai_batch/openai_example_batch.jsonl similarity index 100% rename from examples/offline_inference/openai/openai_example_batch.jsonl rename to examples/offline_inference/openai_batch/openai_example_batch.jsonl diff --git a/examples/offline_inference/torchrun_example.py b/examples/offline_inference/torchrun_example.py index c6d9e6b47e2..bb61a0a29e3 100644 --- a/examples/offline_inference/torchrun_example.py +++ b/examples/offline_inference/torchrun_example.py @@ -8,6 +8,8 @@ see `tests/distributed/test_torchrun_example.py` for the unit test. """ +import torch.distributed as dist + from vllm import LLM, SamplingParams # Create prompts, the same across all ranks @@ -27,23 +29,26 @@ # all ranks have the same random seed, so that sampling can be # deterministic across ranks. llm = LLM( - model="facebook/opt-125m", + model="meta-llama/Llama-3.1-8B", tensor_parallel_size=2, + pipeline_parallel_size=2, distributed_executor_backend="external_launcher", - seed=0, + max_model_len=32768, + seed=1, ) outputs = llm.generate(prompts, sampling_params) # all ranks will have the same outputs -print("-" * 50) -for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}\n" - f"Generated text: {generated_text!r}") +if dist.get_rank() == 0: print("-" * 50) -""" + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}\n" + f"Generated text: {generated_text!r}\n") + print("-" * 50) + """ Further tips: 1. to communicate control messages across all ranks, use the cpu group, diff --git a/examples/online_serving/disaggregated_serving/README.md b/examples/online_serving/disaggregated_serving/README.md new file mode 100644 index 00000000000..090afd7515e --- /dev/null +++ b/examples/online_serving/disaggregated_serving/README.md @@ -0,0 +1,8 @@ +# Disaggregated Serving + +This example contains scripts that demonstrate the disaggregated serving features of vLLM. 
+ +## Files + +- `disagg_proxy_demo.py` - Demonstrates XpYd (X prefill instances, Y decode instances). +- `kv_events.sh` - Demonstrates KV cache event publishing. diff --git a/examples/online_serving/disagg_examples/disagg_proxy_demo.py b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py similarity index 99% rename from examples/online_serving/disagg_examples/disagg_proxy_demo.py rename to examples/online_serving/disaggregated_serving/disagg_proxy_demo.py index a701636f357..1bf4d50e2c9 100644 --- a/examples/online_serving/disagg_examples/disagg_proxy_demo.py +++ b/examples/online_serving/disaggregated_serving/disagg_proxy_demo.py @@ -4,7 +4,7 @@ example usage of XpYd disaggregated prefilling. We can launch multiple vllm instances (2 for prefill and 2 for decode), and launch this proxy demo through: - python3 examples/online_serving/disagg_examples/disagg_proxy_demo.py \ + python3 examples/online_serving/disaggregated_serving/disagg_proxy_demo.py \ --model $model_name \ --prefill localhost:8100 localhost:8101 \ --decode localhost:8200 localhost:8201 \ diff --git a/examples/online_serving/kv_events.sh b/examples/online_serving/disaggregated_serving/kv_events.sh similarity index 100% rename from examples/online_serving/kv_events.sh rename to examples/online_serving/disaggregated_serving/kv_events.sh diff --git a/examples/online_serving/openai_chat_completion_client_for_multimodal.py b/examples/online_serving/openai_chat_completion_client_for_multimodal.py index cffd093c983..2707d46f46e 100644 --- a/examples/online_serving/openai_chat_completion_client_for_multimodal.py +++ b/examples/online_serving/openai_chat_completion_client_for_multimodal.py @@ -1,5 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 -"""An example showing how to use vLLM to serve multimodal models +"""An example showing how to use vLLM to serve multimodal models and run online serving with OpenAI client. 
Launch the vLLM server with the following command: @@ -12,12 +12,18 @@ --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}' (audio inference with Ultravox) -vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096 +vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b \ + --max-model-len 4096 --trust-remote-code + +run the script with +python openai_chat_completion_client_for_multimodal.py --chat-type audio """ + import base64 import requests from openai import OpenAI +from utils import get_first_model from vllm.utils import FlexibleArgumentParser @@ -31,9 +37,6 @@ base_url=openai_api_base, ) -models = client.models.list() -model = models.data[0].id - def encode_base64_content_from_url(content_url: str) -> str: """Encode a content retrieved from a remote url to base64 format.""" @@ -46,7 +49,7 @@ def encode_base64_content_from_url(content_url: str) -> str: # Text-only inference -def run_text_only() -> None: +def run_text_only(model: str) -> None: chat_completion = client.chat.completions.create( messages=[{ "role": "user", @@ -61,7 +64,7 @@ def run_text_only() -> None: # Single-image input inference -def run_single_image() -> None: +def run_single_image(model: str) -> None: ## Use image url in the payload image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" @@ -117,7 +120,7 @@ def run_single_image() -> None: # Multi-image input inference -def run_multi_image() -> None: +def run_multi_image(model: str) -> None: image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" chat_completion_from_url = client.chat.completions.create( @@ -152,7 +155,7 @@ def run_multi_image() -> None: # Video input inference -def run_video() -> None: +def run_video(model: str) -> None: video_url = "http://commondatastorage.googleapis.com/gtv-videos-bucket/sample/ForBiggerFun.mp4" video_base64 = encode_base64_content_from_url(video_url) @@ -208,7 +211,7 @@ def run_video() -> None: # Audio input inference -def run_audio() -> None: +def run_audio(model: str) -> None: from vllm.assets.audio import AudioAsset audio_url = AudioAsset("winning_call").url @@ -318,7 +321,8 @@ def parse_args(): def main(args) -> None: chat_type = args.chat_type - example_function_map[chat_type]() + model = get_first_model(client) + example_function_map[chat_type](model) if __name__ == "__main__": diff --git a/examples/online_serving/opentelemetry/Otel.md b/examples/online_serving/opentelemetry/README.md similarity index 100% rename from examples/online_serving/opentelemetry/Otel.md rename to examples/online_serving/opentelemetry/README.md diff --git a/examples/online_serving/utils.py b/examples/online_serving/utils.py new file mode 100644 index 00000000000..4826e8e2052 --- /dev/null +++ b/examples/online_serving/utils.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: Apache-2.0 +from openai import APIConnectionError, OpenAI +from openai.pagination import SyncPage +from openai.types.model import Model + + +def get_first_model(client: OpenAI) -> str: + """ + Get the first model from the vLLM server. 
+ """ + try: + models: SyncPage[Model] = client.models.list() + except APIConnectionError as e: + raise RuntimeError( + "Failed to get the list of models from the vLLM server at " + f"{client.base_url} with API key {client.api_key}. Check\n" + "1. the server is running\n" + "2. the server URL is correct\n" + "3. the API key is correct") from e + + if len(models.data) == 0: + raise RuntimeError( + f"No models found on the vLLM server at {client.base_url}") + + return models.data[0].id diff --git a/pyproject.toml b/pyproject.toml index 46cf7a801fd..0b803a26b65 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,16 +71,15 @@ exclude = [ "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Python 3.8 typing. TODO: Remove these excludes after v1.0.0 +# Python 3.8 typing - skip V0 code "vllm/attention/**/*.py" = ["UP006", "UP035"] "vllm/core/**/*.py" = ["UP006", "UP035"] "vllm/engine/**/*.py" = ["UP006", "UP035"] "vllm/executor/**/*.py" = ["UP006", "UP035"] -"vllm/model_executor/model_loader/**/*.py" = ["UP006", "UP035"] -"vllm/model_executor/models/**/*.py" = ["UP006", "UP035"] "vllm/prompt_adapter/**/*.py" = ["UP006", "UP035"] "vllm/spec_decode/**/*.py" = ["UP006", "UP035"] "vllm/worker/**/*.py" = ["UP006", "UP035"] +# Python 3.8 typing - skip utils for ROCm "vllm/utils.py" = ["UP006", "UP035"] [tool.ruff.lint] @@ -170,3 +169,9 @@ plugins.md013.enabled = false # line-length plugins.md041.enabled = false # first-line-h1 plugins.md033.enabled = false # inline-html plugins.md024.allow_different_nesting = true # no-duplicate-headers + +[tool.ty] +respect-ignore-files = true + +[tool.ty.environment] +python = "./.venv" diff --git a/requirements/rocm-test.txt b/requirements/rocm-test.txt index abd4212c6e3..25f950a99ec 100644 --- a/requirements/rocm-test.txt +++ b/requirements/rocm-test.txt @@ -22,4 +22,10 @@ decord==0.6.0 #sentence-transformers # required by entrypoints/openai/test_score.py sentence-transformers==3.4.1 +# Basic Models Test +matplotlib==3.10.3 + +# Multi-Modal Models Test (Extended) 3 +blobfile==3.0.0 + diff --git a/tests/distributed/test_torchrun_example.py b/tests/distributed/test_torchrun_example.py index 0420a6454d4..bb38e908b73 100644 --- a/tests/distributed/test_torchrun_example.py +++ b/tests/distributed/test_torchrun_example.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # unit test for `examples/offline_inference/torchrun_example.py` - +import os import random import torch.distributed as dist @@ -25,6 +25,7 @@ # to test if all ranks agree on the same kv cache configuration. llm = LLM(model="facebook/opt-125m", tensor_parallel_size=2, + pipeline_parallel_size=int(os.getenv("PP_SIZE", 1)), distributed_executor_backend="external_launcher", gpu_memory_utilization=random.uniform(0.7, 0.9), swap_space=random.randint(1, 4), diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index ce8873d58d4..05d9cfc7ab7 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -181,8 +181,8 @@ def test_get_kwargs(): # literals of literals should have merged choices assert kwargs["literal_literal"]["choices"] == [1, 2] # dict should have json tip in help - json_tip = "\n\nShould be a valid JSON string." 
- assert kwargs["json_tip"]["help"].endswith(json_tip) + json_tip = "Should either be a valid JSON string or JSON keys" + assert json_tip in kwargs["json_tip"]["help"] # nested config should should construct the nested config assert kwargs["nested_config"]["type"]('{"field": 2}') == NestedConfig(2) # from_cli configs should be constructed with the correct method diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 663b722426c..9773f3e45b9 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -145,6 +145,83 @@ async def test_tokenize_chat( } +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name,tokenizer_name", + [(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")], + indirect=["tokenizer_name"], +) +async def test_tokenize_chat_with_tools( + server: RemoteOpenAIServer, + model_name: str, + tokenizer_name: str, +): + tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, + tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": + "user", + "content": + "What's the weather like in Paris today?", + }] + + tools = [{ + "type": "function", + "function": { + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string" + } + }, + }, + }, + }] + + for continue_final in [False, True]: + if add_generation and continue_final: + continue + if continue_final: + conversation.append({ + "role": "assistant", + "content": "Sure," + }) + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + continue_final_message=continue_final, + conversation=conversation, + tools=tools, + tokenize=False, + ) + tokens = tokenizer.encode(prompt, + add_special_tokens=add_special) + + response = requests.post( + server.url_for("tokenize"), + json={ + "add_generation_prompt": add_generation, + "continue_final_message": continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name, + "tools": tools, + }, + ) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192, + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name,tokenizer_name", diff --git a/tests/kernels/test_triton_unified_attention.py b/tests/kernels/attention/test_triton_unified_attention.py similarity index 98% rename from tests/kernels/test_triton_unified_attention.py rename to tests/kernels/attention/test_triton_unified_attention.py index 50da8e5fd5c..4e15d00255a 100644 --- a/tests/kernels/test_triton_unified_attention.py +++ b/tests/kernels/attention/test_triton_unified_attention.py @@ -99,6 +99,9 @@ def test_triton_unified_attn( ) -> None: torch.set_default_device("cuda") + if q_dtype is not None and q_dtype.itemsize < 2 and block_size < 32: + pytest.skip("block size must be at least 32 for fp8") + current_platform.seed_everything(0) num_seqs = len(seq_lens) query_lens = [x[0] for x in seq_lens] diff --git a/tests/kernels/core/test_pos_encoding.py b/tests/kernels/core/test_pos_encoding.py index d81c7487b88..f327deb0e54 100644 --- a/tests/kernels/core/test_pos_encoding.py +++ b/tests/kernels/core/test_pos_encoding.py @@ -29,12 +29,20 @@ def _get_flat_tensor_shape(batch_size: int, seq_len: int, num_heads: int, return (batch_size, seq_len, num_heads * head_size) +# For testing sliced tensors +def _get_padded_tensor_shape(batch_size: int, seq_len: int, 
num_heads: int, + head_size: int) -> tuple[int, ...]: + return (batch_size, seq_len, num_heads, head_size + 64) + + def _get_batch_tensor_shape(batch_size: int, seq_len: int, num_heads: int, head_size: int) -> tuple[int, ...]: return (batch_size, seq_len, num_heads, head_size) -TENSORS_SHAPES_FN = [_get_batch_tensor_shape, _get_flat_tensor_shape] +TENSORS_SHAPES_FN = [ + _get_batch_tensor_shape, _get_flat_tensor_shape, _get_padded_tensor_shape +] @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) @@ -79,6 +87,10 @@ def test_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) @@ -140,6 +152,10 @@ def test_batched_rotary_embedding( query = torch.randn(query_shape, dtype=dtype) key = torch.randn_like(query) if use_key else None + # slice tensor if required, noop otherwise + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None + # NOTE(woosuk): The reference implementation should be executed first # because the custom kernel is in-place. ref_query, ref_key = rope.forward_native(positions, query, key) diff --git a/tests/kernels/core/test_rotary_embedding.py b/tests/kernels/core/test_rotary_embedding.py index 4e54861005f..8383f943b9f 100644 --- a/tests/kernels/core/test_rotary_embedding.py +++ b/tests/kernels/core/test_rotary_embedding.py @@ -38,9 +38,10 @@ def rotary_embedding_opcheck(rot, @pytest.mark.parametrize("head_size", [32, 108]) @pytest.mark.parametrize("seq_len", [11, 1024]) @pytest.mark.parametrize("use_key", [True, False]) +@pytest.mark.parametrize("head_stride_is_contingous", [True, False]) def test_rotary_embedding_opcheck(dist_init, device, max_position, is_neox_style, rotary_dim, head_size, - seq_len, use_key): + seq_len, use_key, head_stride_is_contingous): batch_size = 1 base = 10000 num_heads = 7 @@ -50,15 +51,27 @@ def test_rotary_embedding_opcheck(dist_init, device, max_position, positions = torch.randint(0, max_position, (batch_size, seq_len), device=device) + head_stride = head_size + (64 if head_stride_is_contingous else 0) + query = torch.randn(batch_size, seq_len, - num_heads * head_size, + num_heads, + head_stride, dtype=torch.float32, device=device) key = torch.randn_like(query) if use_key else None + query = query[..., :head_size] + key = key[..., :head_size] if use_key else None rotary_embedding_opcheck(rot, positions, query, key) offsets = torch.zeros(batch_size * seq_len, device=device, dtype=torch.long) rotary_embedding_opcheck(rot, positions, query, key, offsets) + + # if we have a contiguous head stride, test the alternate + # [..., num_heads * head_dim] shape/layout + if head_stride_is_contingous: + rotary_embedding_opcheck( + rot, positions, query.flatten(start_dim=-2), + key.flatten(start_dim=-2) if use_key else None) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py new file mode 100644 index 00000000000..7d369edfc86 --- /dev/null +++ b/tests/kernels/moe/test_batched_moe.py @@ -0,0 +1,114 @@ +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import dataclass + +import pytest +import torch +import triton.language as tl + +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + 
invoke_moe_batched_triton_kernel) + + +@dataclass +class BatchedMMConfig: + dtype: torch.dtype + num_experts: int + max_tokens_per_expert: int + K: int + N: int + + +@dataclass +class BatchedMMTensors: + A: torch.Tensor # [E, max_tokens, K] + B: torch.Tensor # [E, K, N] - column major + C: torch.Tensor # [E, max_tokens, N] + num_expert_tokens: torch.Tensor # [E] + + @staticmethod + def make_tensors(config: BatchedMMConfig): + A = torch.randn( + (config.num_experts, config.max_tokens_per_expert, config.K), + device="cuda", + dtype=config.dtype) / 10 + B = torch.randn((config.num_experts, config.N, config.K), + device="cuda", + dtype=config.dtype) + C = torch.zeros( + (config.num_experts, config.max_tokens_per_expert, config.N), + device="cuda", + dtype=config.dtype) + num_expert_tokens = torch.randint(low=0, + high=config.max_tokens_per_expert, + size=(config.num_experts, ), + device="cuda", + dtype=torch.int32) + return BatchedMMTensors(A, B, C, num_expert_tokens) + + +def ref_impl(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + num_expert_tokens: torch.Tensor) -> torch.Tensor: + + num_expert_tokens_cpu = num_expert_tokens.clone() + num_expert_tokens_cpu = num_expert_tokens_cpu.to(device="cpu") + num_experts = num_expert_tokens.size(0) + + for e in range(num_experts): + num_tokens = num_expert_tokens_cpu[e] + C[e, :num_tokens, :] = A[e, :num_tokens, :] @ B[e].transpose(0, 1) + + return C + + +@pytest.mark.parametrize("num_experts", [16, 32]) +@pytest.mark.parametrize("max_tokens_per_expert", + [32, 64, 128, 192, 224, 256, 512]) +@pytest.mark.parametrize("K", [128, 256, 1024]) +@pytest.mark.parametrize("N", [128, 256, 512, 1024]) +@pytest.mark.parametrize("dtype", + [torch.float32, torch.float16, torch.bfloat16]) +def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, + N: int, dtype: torch.dtype): + + config = BatchedMMConfig(dtype, num_experts, max_tokens_per_expert, K, N) + tensors = BatchedMMTensors.make_tensors(config) + + test_output = tensors.C + ref_output = test_output.clone() + + compute_tl_dtype = { + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, + torch.float32: tl.float32 + }[test_output.dtype] + invoke_moe_batched_triton_kernel( + tensors.A, + tensors.B, + test_output, + tensors.num_expert_tokens, + compute_tl_dtype, + # Quantization data + None, + None, + None, + # Quantization schemes + False, + False, + False, + config={ + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 16, + "BLOCK_SIZE_K": 16 + }) + + ref_output = ref_impl(tensors.A, tensors.B, ref_output, + tensors.num_expert_tokens) + + rtol, atol = { + torch.float16: (6e-2, 6e-2), + torch.bfloat16: (6e-2, 6e-2), + torch.float32: (1e-2, 1e-2), + }[test_output.dtype] + + torch.testing.assert_close(test_output, ref_output, atol=atol, rtol=rtol) diff --git a/tests/kernels/moe/test_cutlass_moe.py b/tests/kernels/moe/test_cutlass_moe.py index 975cd418a17..7db4fe0f46e 100644 --- a/tests/kernels/moe/test_cutlass_moe.py +++ b/tests/kernels/moe/test_cutlass_moe.py @@ -30,6 +30,11 @@ (224, 3072, 1536), ] +vllm_config = VllmConfig(parallel_config=ParallelConfig( + pipeline_parallel_size=1)) +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @dataclasses.dataclass class MOETensors: @@ -190,7 +195,7 @@ def run_8_bit(moe_tensors: MOETensors8Bit, 'w1_q': moe_tensors.w1_q.transpose(1, 2), # type: ignore[union-attr] 'w2_q': moe_tensors.w2_q.transpose(1, 2), # type: ignore[union-attr] 'topk_weights': topk_weights, - 'topk_ids_': topk_ids, + 'topk_ids': topk_ids, 
'ab_strides1': moe_tensors.ab_strides1, 'c_strides1': moe_tensors.c_strides1, 'ab_strides2': moe_tensors.ab_strides2, @@ -231,18 +236,15 @@ def test_cutlass_moe_8_bit_no_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -276,20 +278,17 @@ def test_cutlass_moe_8_bit_cuda_graph( per_out_ch: bool, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): dtype = torch.half mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_ch) score = torch.randn((m, e), device="cuda", dtype=dtype) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. @@ -334,18 +333,15 @@ def test_cutlass_moe_8_bit_EP( ep_size: int, ): current_platform.seed_everything(7) - with set_current_vllm_config( - VllmConfig(parallel_config=ParallelConfig( - pipeline_parallel_size=1))): - + with set_current_vllm_config(vllm_config): mt = MOETensors8Bit.make_moe_tensors_8bit(m, k, n, e, per_act_token, per_out_channel) score = torch.randn((m, e), device="cuda", dtype=torch.half) - topk_weights, topk_ids = fused_topk(mt.a, - score, - topk, - renormalize=False) + topk_weights, topk_ids, _ = fused_topk(mt.a, + score, + topk, + renormalize=False) # Note that we are using the dequantized versions of the tensors. # Using a, w1 and w2 directly results in minor output differences. 
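The edits above repeat a setup pattern that recurs throughout the MoE kernel tests in this patch: a module-level VllmConfig with explicit scheduler limits is activated around the kernel calls via set_current_vllm_config, and fused_topk is now unpacked into three values, with the trailing value discarded. A minimal sketch of that shared scaffolding, assuming only the imports shown in these hunks (the tensor shapes below are illustrative, not taken from any specific test):

# Sketch of the shared scaffolding used by the MoE kernel tests in this patch.
# Shapes and dtypes here are placeholders for illustration only.
import torch

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192

m, k, e, topk = 32, 128, 8, 2
a = torch.randn((m, k), device="cuda", dtype=torch.half)
score = torch.randn((m, e), device="cuda", dtype=torch.half)

with set_current_vllm_config(vllm_config):
    # fused_topk returns a third value that these tests do not need.
    topk_weights, topk_ids, _ = fused_topk(a, score, topk, renormalize=False)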
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 96b090136e3..43ddc79fcb8 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -12,6 +12,7 @@ import vllm.model_executor.layers.fused_moe # noqa from tests.kernels.utils import opcheck, stack_and_dev, torch_moe +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( @@ -32,6 +33,10 @@ EP_SIZE = [1, 4] TOP_KS = [2, 6] +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("n", [128, 1024, 2048]) @@ -70,31 +75,33 @@ def test_fused_moe( else: e_map = None - torch_output = torch_moe(a, w1, w2, score, topk, e_map) - iterative_output = iterative_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w1, w2, score, topk, e_map) + iterative_output = iterative_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) + + # Pad the weight if moe padding is enabled + if padding: + w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] + torch.cuda.empty_cache() + + triton_output = fused_moe(a, + w1, + w2, + score, + topk, + global_num_experts=e, + expert_map=e_map, + renormalize=False) - # Pad the weight if moe padding is enabled - if padding: - w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128] - torch.cuda.empty_cache() - - triton_output = fused_moe(a, - w1, - w2, - score, - topk, - global_num_experts=e, - expert_map=e_map, - renormalize=False) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(iterative_output, torch_output, @@ -115,7 +122,6 @@ def test_fused_moe( def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, ep_size: int, dtype: torch.dtype, group_size: int, has_zp: bool, weight_bits: int): - print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 @@ -194,22 +200,24 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int, else: e_map = None - triton_output = fused_moe(a, - w1_qweight, - w2_qweight, - score, - topk, - renormalize=False, - use_int4_w4a16=weight_bits == 4, - use_int8_w8a16=weight_bits == 8, - global_num_experts=e, - expert_map=e_map, - w1_scale=w1_scales, - w2_scale=w2_scales, - w1_zp=w1_qzeros if has_zp else None, - w2_zp=w2_qzeros if has_zp else None, - block_shape=[0, group_size]) - torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + with set_current_vllm_config(vllm_config): + triton_output = fused_moe(a, + w1_qweight, + w2_qweight, + score, + topk, + renormalize=False, + use_int4_w4a16=weight_bits == 4, + use_int8_w8a16=weight_bits == 8, + global_num_experts=e, + expert_map=e_map, + w1_scale=w1_scales, + w2_scale=w2_scales, + w1_zp=w1_qzeros if has_zp else None, + w2_zp=w2_qzeros if has_zp else None, + 
block_shape=[0, group_size]) + torch_output = torch_moe(a, w1_ref, w2_ref, score, topk, e_map) + torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) @@ -515,7 +523,8 @@ def test_fused_marlin_moe( topk_weights, topk_ids, _ = fused_topk(a, score, topk, False) - torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) + with set_current_vllm_config(vllm_config): + torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map) marlin_output = torch.ops.vllm.fused_marlin_moe( a, diff --git a/tests/kernels/moe/test_pplx_moe.py b/tests/kernels/moe/test_pplx_moe.py new file mode 100644 index 00000000000..8c4a2c3fa44 --- /dev/null +++ b/tests/kernels/moe/test_pplx_moe.py @@ -0,0 +1,691 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Tests for the MOE layers. + +Run `pytest tests/kernels/test_pplx_moe.py`. +""" +import dataclasses +import os +import traceback +from typing import Callable, Optional + +import pytest +import torch + +try: + from pplx_kernels import AllToAll + from pplx_kernels.nvshmem import (nvshmem_alloc_empty_unique_id, + nvshmem_finalize, nvshmem_get_unique_id, + nvshmem_init) + has_pplx = True +except ImportError: + has_pplx = False + +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import override_config +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedExperts, BatchedPrepareAndFinalize, BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.fused_moe import (fused_topk, + get_default_config) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.platforms import current_platform + +PPLX_PREPARE_COMBOS = [(4, 128, 128), (32, 1024, 512), (64, 1024, 512), + (222, 2048, 1024)] + +PPLX_MOE_COMBOS = [ + (1, 128, 128), + (2, 128, 512), + (3, 1024, 2048), + (32, 128, 1024), + (45, 512, 2048), + (64, 1024, 1024), + (222, 1024, 2048), +] + +NUM_EXPERTS = [8, 64] +EP_SIZE = [1, 4] +TOP_KS = [1, 2, 6] + +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + +P = ParamSpec("P") + +requires_pplx = pytest.mark.skipif( + not has_pplx, + reason="Requires PPLX kernels", +) + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + 
torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +def parallel_launch_from_env( + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + """ + Launches a worker function in parallel across all processes in the current + environment. The environment must have the following variables set: + - WORLD_SIZE: The total number of processes. + - WORLD_LOCAL_SIZE: The number of processes on the current node. + - NODE_RANK: The rank of the current + - MASTER_ADDR: The address of the master process. + - MASTER_PORT: The port of the master process. + """ + assert not kwargs + world_size = int(os.environ["WORLD_SIZE"]) + world_local_size = int(os.environ["WORLD_LOCAL_SIZE"]) + node_rank = int(os.environ["NODE_RANK"]) + assert "MASTER_ADDR" in os.environ + assert "MASTER_PORT" in os.environ + spawn( + _worker_parallel_launch, + args=( + world_size, + world_local_size, + node_rank, + "env://", + worker, + ) + args, + nprocs=world_local_size, + join=True, + ) + + +def torch_prepare( + a: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + max_num_tokens: Optional[int] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert topk_ids.dim() == 2 + assert topk_ids.shape[0] == a.shape[0] + + num_tokens, hidden_dim = a.shape + topk = topk_ids.shape[1] + + tokens_per_expert = torch.bincount(topk_ids.view(-1), + minlength=num_experts) + + assert tokens_per_expert.numel() == num_experts + + if max_num_tokens is None: + max_num_tokens = int(tokens_per_expert.max().item()) + + b_a = torch.zeros((num_experts, max_num_tokens, hidden_dim), + dtype=a.dtype, + device=a.device) + + token_counts = torch.zeros(num_experts, dtype=torch.int, device=a.device) + + for token in range(num_tokens): + for j in range(topk): + expert_id = topk_ids[token, j] + idx = token_counts[expert_id] + b_a[expert_id, idx:idx + 1, :] = a[token, :] + token_counts[expert_id] = token_counts[expert_id] + 1 + + return b_a, tokens_per_expert + + +def torch_finalize(b_out: torch.Tensor, topk_weight: torch.Tensor, + topk_ids: torch.Tensor) -> torch.Tensor: + num_tokens = topk_ids.shape[0] + num_experts = b_out.shape[0] + K = b_out.shape[-1] + out = torch.zeros((num_tokens, K), dtype=b_out.dtype, device=b_out.device) + expert_counts = torch.zeros(num_experts, + dtype=torch.int, + device=b_out.device) + for token in range(num_tokens): + expert_ids = topk_ids[token] + for i in range(expert_ids.numel()): + expert_id = expert_ids[i] + idx = expert_counts[expert_id] + out[token, :] = out[token, :] + b_out[expert_id, idx:idx + + 1, :] * topk_weight[token, i] + expert_counts[expert_id] = expert_counts[expert_id] + 1 + + return out + + +def torch_batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + b_a, tokens_per_expert = torch_prepare(a, topk_ids, num_experts) + assert b_a.dim() == 3 + num_tokens, topk = topk_ids.shape + _, max_num_tokens, K = b_a.shape + assert num_experts == b_a.shape[0] and w2.shape[1] == K + out = torch.zeros((num_experts, max_num_tokens, K), + dtype=b_a.dtype, + device=b_a.device) + tmp = torch.empty((max_num_tokens, 
w1.shape[1] // 2), + dtype=b_a.dtype, + device=b_a.device) + for expert in range(num_experts): + num = tokens_per_expert[expert] + if num > 0: + torch.ops._C.silu_and_mul( + tmp[:num], b_a[expert, :num, :] @ w1[expert].transpose(0, 1)) + out[expert, :num, :] = tmp[:num] @ w2[expert].transpose(0, 1) + + return torch_finalize(out, topk_weight, topk_ids) + + +def batched_moe( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + num_experts = w1.shape[0] + + fused_experts = FusedMoEModularKernel( + BatchedPrepareAndFinalize(a.shape[0], world_size=1, dp_size=1, rank=0), + BatchedExperts(max_num_tokens=a.shape[0], dp_size=1, world_size=1)) + + return fused_experts(a, w1, w2, topk_weight, topk_ids, num_experts) + + +# Note: same as torch_moe but with fused_topk factored out. +def torch_moe2( + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, +) -> torch.Tensor: + M, K = a.shape + topk = topk_ids.shape[1] + a = a.view(M, -1, K).repeat(1, topk, 1).reshape(-1, K) + out = torch.zeros(M * topk, w2.shape[1], dtype=a.dtype, device=a.device) + num_experts = w1.shape[0] + for i in range(num_experts): + mask = (topk_ids == i).view(-1) + if mask.sum(): + out[mask] = SiluAndMul()( + a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + + return (out.view(M, -1, w2.shape[1]) * + topk_weight.view(M, -1, 1).to(out.dtype)).sum(dim=1) + + +@pytest.mark.parametrize("m", [1, 33, 64, 222]) +@pytest.mark.parametrize("n", [128, 1024, 2048]) +@pytest.mark.parametrize("k", [128, 512, 1024]) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +def test_fused_moe_batched_experts( + m: int, + n: int, + k: int, + e: int, + topk: int, + dtype: torch.dtype, +): + current_platform.seed_everything(7) + + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + with set_current_vllm_config(vllm_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + baseline_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + torch_output = torch_batched_moe(a, w1, w2, topk_weight, topk_ids) + batched_output = batched_moe(a, w1, w2, topk_weight, topk_ids) + + torch.testing.assert_close(baseline_output, + torch_output, + atol=2e-2, + rtol=0) + torch.testing.assert_close(baseline_output, + batched_output, + atol=2e-2, + rtol=0) + + +def rank_chunk(num: int, r: int, w: int) -> int: + rem = num % w + return (num // w) + (1 if r < rem else 0) + + +def chunk_by_rank(t: torch.Tensor, r: int, w: int) -> torch.Tensor: + chunk = rank_chunk(t.shape[0], r, w) + return t[(r * chunk):(r + 1) * chunk] + + +def pplx_prepare_finalize(pgi: ProcessGroupInfo, dp_size: int, a: torch.Tensor, + topk_weight: torch.Tensor, topk_ids: torch.Tensor, + num_experts: int) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + assert torch.cuda.current_device() == pgi.local_rank + + topk = topk_ids.shape[1] + num_tokens, hidden_dim = a.shape + block_size = 128 + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(num_tokens, 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + 
num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + a.dtype, + ) + + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + b_a, b_a_scale, expert_num_tokens = prepare_finalize.prepare( + a_chunk, + None, + None, + chunk_topk_weight, + chunk_topk_ids, + num_experts, + None, + False, + ) + + b_a = b_a * 1.5 + + out = torch.full( + (max_num_tokens, hidden_dim), + torch.nan, + dtype=a.dtype, + device=device, + ) + + prepare_finalize.finalize( + out, + b_a, + chunk_topk_weight, + chunk_topk_ids, + False, + ) + + torch.cuda.synchronize() + + ata.destroy() + + num_tokens = a_chunk.shape[0] + + return out[:num_tokens] + + +def _pplx_prepare_finalize( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + score: torch.Tensor, + topk: torch.Tensor, + num_experts: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + device = pgi.device + + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + k = a.shape[1] + + a_rep = torch.repeat_interleave(a, topk, dim=0).to(device) + + torch_output = (a_rep.view(-1, topk, k) * 1.5 * + topk_weight.view(-1, topk, 1).to(device)).sum(dim=1).to( + a.dtype) + + pplx_output = pplx_prepare_finalize(pgi, dp_size, a, topk_weight, topk_ids, + num_experts) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +# TODO (bnell): this test point does not work for odd M due to how the test is +# written, not due to limitations of the pplx kernels. The pplx_moe +# test below is able to deal with odd M. 
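For reference, the chunking helpers defined earlier in this new test file split the token rows across ranks by giving each of the first num % world_size ranks one extra row; chunk_by_rank then slices with that per-rank size. A small worked example (the function body is copied from rank_chunk above; the concrete numbers are chosen only for illustration):

# Worked example of rank_chunk() from this file; inputs are arbitrary.
def rank_chunk(num: int, r: int, w: int) -> int:
    rem = num % w
    return (num // w) + (1 if r < rem else 0)

# 10 rows over 4 ranks -> per-rank sizes [3, 3, 2, 2]
assert [rank_chunk(10, r, 4) for r in range(4)] == [3, 3, 2, 2]
# 222 rows over 2 ranks (as in PPLX_PREPARE_COMBOS) -> [111, 111]
assert [rank_chunk(222, r, 2) for r in range(2)] == [111, 111]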
+@pytest.mark.parametrize("mnk", PPLX_PREPARE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_prepare_finalize( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + device = "cuda" + a = torch.randn((m, k), device=device, dtype=dtype) / 10 + score = torch.randn((m, e), device=device, dtype=dtype) + + parallel_launch(world_size, _pplx_prepare_finalize, dp_size, a, score, + topk, e) + + +def pplx_moe( + rank: int, + world_size: int, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weight: torch.Tensor, + topk_ids: torch.Tensor, + use_compile: bool = True, + use_cudagraphs: bool = True, +) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import ( + PplxPrepareAndFinalize) + + device = torch.device("cuda", rank) + hidden_dim = a.shape[1] + num_experts = w1.shape[0] + block_size = 128 + topk = topk_ids.shape[1] + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + ata = AllToAll.internode( + max_num_tokens=max_num_tokens, + num_experts=num_experts, + experts_per_token=topk, + rank=rank, + world_size=world_size, + dp_size=dp_size, + hidden_dim=hidden_dim, + hidden_dim_bytes=hidden_dim * a.dtype.itemsize, + hidden_dim_scale_bytes=(0 if a.dtype.itemsize != 1 else + ((hidden_dim + block_size - 1) // block_size * + torch.float32.itemsize)), + ) + + topk_ids = topk_ids.to(dtype=torch.uint32) + + prepare_finalize = PplxPrepareAndFinalize( + ata, + max_num_tokens, + world_size, + rank, + dp_size, + ) + + experts = BatchedTritonExperts(max_num_tokens=a.shape[0], + world_size=world_size, + dp_size=dp_size) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. 
+ a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + # Chunking weights like this only works for batched format + w1_chunk = chunk_by_rank(w1, rank, world_size).to(device) + w2_chunk = chunk_by_rank(w2, rank, world_size).to(device) + + if use_compile: + _fused_experts = torch.compile(fused_experts, + backend='inductor', + fullgraph=True) + else: + _fused_experts = fused_experts + + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + if use_cudagraphs: + out.fill_(0) + stream = torch.cuda.Stream() + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, stream=stream): + out = _fused_experts(a_chunk, + w1_chunk, + w2_chunk, + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + torch.cuda.synchronize() + graph.replay() + + torch.cuda.synchronize() + + ata.destroy() + + return out + + +def _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, topk_ids): + assert torch.cuda.current_device() == pgi.local_rank + + num_experts = w1.shape[0] + device = pgi.device + rank = pgi.rank + world_size = pgi.world_size + max_num_tokens = rank_chunk(a.shape[0], 0, world_size) + + prepare_finalize = BatchedPrepareAndFinalize( + max_num_tokens=max_num_tokens, + world_size=world_size, + dp_size=dp_size, + rank=rank, + ) + + experts = BatchedExperts(max_num_tokens=a.shape[0], + world_size=1, + dp_size=1) + + fused_experts = FusedMoEModularKernel( + prepare_finalize, + experts, + ) + + # Note: workers with the same dp_rank must use the exact same inputs. + a_chunk = chunk_by_rank(a, rank, world_size).to(device) + chunk_topk_weight = chunk_by_rank(topk_weight, rank, world_size).to(device) + chunk_topk_ids = chunk_by_rank(topk_ids, rank, world_size).to(device) + + out = fused_experts( + a_chunk, + # Chunking weights like this only works for batched format + chunk_by_rank(w1, rank, world_size).to(device), + chunk_by_rank(w2, rank, world_size).to(device), + chunk_topk_weight, + chunk_topk_ids, + global_num_experts=num_experts) + + return out + + +def _pplx_moe( + pgi: ProcessGroupInfo, + dp_size: int, + a: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + score: torch.Tensor, + topk: int, +): + uid = nvshmem_get_unique_id( + ) if pgi.rank == 0 else nvshmem_alloc_empty_unique_id() + torch.distributed.broadcast(uid, src=0) + nvshmem_init(uid, pgi.rank, pgi.world_size) + + m, k = a.shape + e, _, n = w2.shape + + moe_config = get_default_config(m, e, n, k, topk, a.dtype, False) + + with set_current_vllm_config(vllm_config), override_config(moe_config): + topk_weight, topk_ids, _ = fused_topk(a, score, topk, False) + torch_output = torch_moe2(a, w1, w2, topk_weight, topk_ids) + pplx_output = pplx_moe(pgi.rank, pgi.world_size, dp_size, a, w1, w2, + topk_weight, topk_ids) + # TODO (bnell): fix + re-enable + #batched_output = _batched_moe(pgi, dp_size, a, w1, w2, topk_weight, + # topk_ids) + + torch_output = chunk_by_rank(torch_output, pgi.rank, + pgi.world_size).to(pplx_output.device) + + torch.testing.assert_close(pplx_output, torch_output, atol=2e-2, rtol=0) + #torch.testing.assert_close(batched_output, torch_output, atol=2e-2, rtol=0) + + nvshmem_finalize() + + +@pytest.mark.parametrize("mnk", PPLX_MOE_COMBOS) +@pytest.mark.parametrize("e", NUM_EXPERTS) +@pytest.mark.parametrize("topk", TOP_KS) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) 
+@pytest.mark.parametrize("world_dp_size", [[2, 1]]) +@requires_pplx +def test_pplx_moe( + mnk: tuple[int, int, int], + e: int, + topk: int, + dtype: torch.dtype, + world_dp_size: tuple[int, int], +): + current_platform.seed_everything(7) + m, n, k = mnk + world_size, dp_size = world_dp_size + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + score = torch.randn((m, e), device="cuda", dtype=dtype) + + parallel_launch(world_size, _pplx_moe, dp_size, a, w1, w2, score, topk) diff --git a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py index 44734e9340a..3b5838a99fa 100644 --- a/tests/kernels/moe/test_triton_moe_ptpc_fp8.py +++ b/tests/kernels/moe/test_triton_moe_ptpc_fp8.py @@ -7,6 +7,7 @@ import torch from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.platforms import current_platform @@ -15,6 +16,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16): """Matrix multiplication function that supports per-token input @@ -137,20 +142,21 @@ def test_w8a8_fp8_fused_moe(M, N, K, E, topk, dtype, seed): w2_s = torch.rand(E, K, device=w2_fp32.device) * factor_for_scale score = torch.randn((M, E), dtype=dtype) - ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) - out = fused_moe( - a, - w1, - w2, - score, - topk, - renormalize=False, - use_fp8_w8a8=True, # using fp8 - per_channel_quant=True, - w1_scale=w1_s, - w2_scale=w2_s, - block_shape=None, # Not using block quantization - ) + with set_current_vllm_config(vllm_config): + ref_out = torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk) + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_fp8_w8a8=True, # using fp8 + per_channel_quant=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=None, # Not using block quantization + ) # Check results rel_diff = (torch.mean( diff --git a/tests/kernels/quantization/test_block_fp8.py b/tests/kernels/quantization/test_block_fp8.py index 38c7e461bb9..ef1d7e47ef8 100644 --- a/tests/kernels/quantization/test_block_fp8.py +++ b/tests/kernels/quantization/test_block_fp8.py @@ -11,7 +11,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( - deep_gemm_moe_fp8) + _valid_deep_gemm_shape, deep_gemm_moe_fp8) from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.moe_align_block_size import ( moe_align_block_size) @@ -30,6 +30,10 @@ pytest.skip("FP8 Triton requires CUDA 9.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # Test configurations DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] NUM_TOKENS = [7, 83, 2048] @@ -210,7 +214,6 @@ def test_w8a8_block_fp8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) 
# Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, @@ -258,6 +261,7 @@ def per_block_cast_to_fp8( @pytest.mark.parametrize( "M,N,K,block_size,out_dtype,seed", itertools.product(M, N, K, BLOCK_SIZE, OUT_DTYPES, SEEDS)) +@pytest.mark.skipif(not dg_available, reason="DeepGemm kernels not available.") @torch.inference_mode() def test_w8a8_block_fp8_deep_gemm_matmul(M, N, K, block_size, out_dtype, seed): # only aligned sizes @@ -381,15 +385,11 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed): block_size = [block_m, block_m] dtype = torch.bfloat16 - # only aligned sizes - if (N % block_m != 0 or K % block_m != 0 or topk > E): - pytest.skip( - f"Skipping test; bad size m={M}, n={N}, k={K}, topk={topk}, E={E}") - - if N <= 512: - pytest.skip("Skipping N <= 512 until performance issues solved.") + if topk > E: + pytest.skip(f"Skipping test: topk={topk} > E={E}") - vllm_config = VllmConfig() + if not _valid_deep_gemm_shape(M, N, K): + pytest.skip(f"Skipping test: invalid size m={M}, n={N}, k={K}") torch.manual_seed(seed) fp8_info = torch.finfo(torch.float8_e4m3fn) diff --git a/tests/kernels/quantization/test_block_int8.py b/tests/kernels/quantization/test_block_int8.py index 104f23fd7cd..a4e9f83f0ea 100644 --- a/tests/kernels/quantization/test_block_int8.py +++ b/tests/kernels/quantization/test_block_int8.py @@ -18,6 +18,10 @@ pytest.skip("INT8 Triton requires CUDA 7.0 or higher", allow_module_level=True) +vllm_config = VllmConfig() +vllm_config.scheduler_config.max_num_seqs = 128 +vllm_config.scheduler_config.max_model_len = 8192 + # For test def native_per_token_group_quant_int8(x, @@ -174,7 +178,6 @@ def test_w8a8_block_int8_fused_moe(M, N, K, E, topk, block_size, dtype, seed): score = torch.randn((M, E), dtype=dtype) # Set the context to avoid lots of warning spam. - vllm_config = VllmConfig() with set_current_vllm_config(vllm_config): out = fused_moe( a, diff --git a/tests/lora/test_lora_huggingface.py b/tests/lora/test_lora_huggingface.py index 0875128c4ff..90498c47fb1 100644 --- a/tests/lora/test_lora_huggingface.py +++ b/tests/lora/test_lora_huggingface.py @@ -30,7 +30,7 @@ def test_load_checkpoints_from_huggingface(lora_fixture_name, request): lora_path = get_adapter_absolute_path(lora_name) - # lora loading should work for either absolute path and hugggingface id. + # lora loading should work for either absolute path and huggingface id. 
peft_helper = PEFTHelper.from_local_dir(lora_path, 4096) lora_model = LoRAModel.from_local_checkpoint( lora_path, diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index 11dfe4d4995..bdaba22c3c7 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -20,11 +20,11 @@ def test_hf_transfer_auto_activation(): try: # enable hf hub transfer if available import hf_transfer # type: ignore # noqa - HF_TRANFER_ACTIVE = True + HF_TRANSFER_ACTIVE = True except ImportError: - HF_TRANFER_ACTIVE = False + HF_TRANSFER_ACTIVE = False assert (huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER == - HF_TRANFER_ACTIVE) + HF_TRANSFER_ACTIVE) def test_download_weights_from_hf(): diff --git a/tests/models/multimodal/generation/test_common.py b/tests/models/multimodal/generation/test_common.py index dead2edc4fa..d51a03dfea7 100644 --- a/tests/models/multimodal/generation/test_common.py +++ b/tests/models/multimodal/generation/test_common.py @@ -8,14 +8,14 @@ from pathlib import PosixPath import pytest -from transformers import (AutoModelForImageTextToText, +from transformers import (AutoModel, AutoModelForImageTextToText, AutoModelForTextToWaveform, AutoModelForVision2Seq) from vllm.platforms import current_platform from vllm.utils import identity -from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets, - VideoTestAssets, VllmRunner) +from ....conftest import (IMAGE_ASSETS, AudioTestAssets, HfRunner, + ImageTestAssets, VideoTestAssets, VllmRunner) from ....utils import (create_new_process_for_each_test, large_gpu_mark, multi_gpu_marks) from ...utils import check_outputs_equal @@ -158,6 +158,17 @@ image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], marks=[pytest.mark.core_model, pytest.mark.cpu_model], ), + "ultravox": VLMTestInfo( + models = ["fixie-ai/ultravox-v0_5-llama-3_2-1b"], + test_type=VLMTestType.AUDIO, + prompt_formatter=lambda audio_prompt: f"<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n{audio_prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", # noqa: E501 + audio_idx_to_prompt=lambda idx: "<|audio|>", + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModel, + hf_output_post_proc=model_utils.ultravox_trunc_hf_output, + marks=[pytest.mark.core_model, pytest.mark.cpu_model], + ), #### Extended model tests "aria": VLMTestInfo( models=["rhymes-ai/Aria"], @@ -393,7 +404,6 @@ formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 ), limit_mm_per_prompt={"video": 4}, - runner_mm_key="videos", )], ), "llava_next_video": VLMTestInfo( @@ -706,6 +716,7 @@ def _mark_splits( # - multi-image # - image embeddings # - video +# - audio # - custom inputs @pytest.mark.parametrize( "model_type,test_case", @@ -803,6 +814,28 @@ def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=False, + )) +def test_audio_models(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + 
audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( @@ -930,6 +963,29 @@ def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, ) +@pytest.mark.parametrize( + "model_type,test_case", + get_parametrized_options( + VLM_TEST_SETTINGS, + test_type=VLMTestType.AUDIO, + create_new_process_for_each_test=True, + )) +def test_audio_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs, + hf_runner: type[HfRunner], + vllm_runner: type[VllmRunner], + audio_assets: AudioTestAssets, monkeypatch): + if model_type in REQUIRES_V0_MODELS: + monkeypatch.setenv("VLLM_USE_V1", "0") + model_test_info = VLM_TEST_SETTINGS[model_type] + runners.run_audio_test( + model_test_info=model_test_info, + test_case=test_case, + hf_runner=hf_runner, + vllm_runner=vllm_runner, + audio_assets=audio_assets, + ) + + @pytest.mark.parametrize( "model_type,test_case", get_parametrized_options( diff --git a/tests/models/multimodal/generation/test_ultravox.py b/tests/models/multimodal/generation/test_ultravox.py index 322d886a593..2c8a06688ca 100644 --- a/tests/models/multimodal/generation/test_ultravox.py +++ b/tests/models/multimodal/generation/test_ultravox.py @@ -1,20 +1,16 @@ # SPDX-License-Identifier: Apache-2.0 import json -from typing import Any, Optional +from typing import Any import numpy as np import pytest import pytest_asyncio -from transformers import AutoModel, AutoTokenizer +from transformers import AutoTokenizer -from vllm.multimodal.audio import resample_audio_librosa -from vllm.sequence import SampleLogprobs - -from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner +from ....conftest import AUDIO_ASSETS, AudioTestAssets, VllmRunner from ....utils import RemoteOpenAIServer from ...registry import HF_EXAMPLE_MODELS -from ...utils import check_logprobs_close MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b" @@ -88,79 +84,6 @@ def _get_prompt(audio_count, question, placeholder): add_generation_prompt=True) -def vllm_to_hf_output(vllm_output: tuple[list[int], str, - Optional[SampleLogprobs]], - model: str): - """Sanitize vllm output to be comparable with hf output.""" - output_ids, output_str, out_logprobs = vllm_output - - tokenizer = AutoTokenizer.from_pretrained(model) - eos_token_id = tokenizer.eos_token_id - - hf_output_ids = output_ids[:] - hf_output_str = output_str - if hf_output_ids[-1] == eos_token_id: - hf_output_str = hf_output_str + tokenizer.decode(eos_token_id) - - return hf_output_ids, hf_output_str, out_logprobs - - -def run_test( - hf_runner: type[HfRunner], - vllm_runner: type[VllmRunner], - prompts_and_audios: list[tuple[str, str, AudioTuple]], - model: str, - *, - dtype: str, - max_tokens: int, - num_logprobs: int, - **kwargs, -): - """Inference result should be the same between hf and vllm.""" - model_info = HF_EXAMPLE_MODELS.find_hf_info(model) - model_info.check_available_online(on_fail="skip") - model_info.check_transformers_version(on_fail="skip") - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). 
- - with vllm_runner(model, dtype=dtype, enforce_eager=True, - **kwargs) as vllm_model: - vllm_outputs_per_audio = [ - vllm_model.generate_greedy_logprobs([vllm_prompt], - max_tokens, - num_logprobs=num_logprobs, - audios=[audio]) - for vllm_prompt, _, audio in prompts_and_audios - ] - - with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: - hf_outputs_per_audio = [ - hf_model.generate_greedy_logprobs_limit( - [hf_prompt], - max_tokens, - num_logprobs=num_logprobs, - audios=[(resample_audio_librosa(audio[0], - orig_sr=audio[1], - target_sr=16000), 16000)]) - for _, hf_prompt, audio in prompts_and_audios - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_audio, - vllm_outputs_per_audio): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - def run_multi_audio_test( vllm_runner: type[VllmRunner], prompts_and_audios: list[tuple[str, list[AudioTuple]]], @@ -194,35 +117,6 @@ def run_multi_audio_test( assert all(tokens for tokens, *_ in vllm_outputs) -@pytest.mark.core_model -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("vllm_kwargs", [ - pytest.param({}, marks=pytest.mark.cpu_model), - pytest.param(CHUNKED_PREFILL_KWARGS), -]) -def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets, - dtype: str, max_tokens: int, num_logprobs: int, - vllm_kwargs: dict) -> None: - audio_inputs = [( - _get_prompt(1, audio, VLLM_PLACEHOLDER), - _get_prompt(1, audio, HF_PLACEHOLDER), - audio.audio_and_sample_rate, - ) for audio in audio_assets] - - run_test( - hf_runner, - vllm_runner, - audio_inputs, - MODEL_NAME, - dtype=dtype, - max_tokens=max_tokens, - num_logprobs=num_logprobs, - **vllm_kwargs, - ) - - @pytest.mark.core_model @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) diff --git a/tests/models/multimodal/generation/vlm_utils/builders.py b/tests/models/multimodal/generation/vlm_utils/builders.py index e3ba955a96a..32117c8d8dc 100644 --- a/tests/models/multimodal/generation/vlm_utils/builders.py +++ b/tests/models/multimodal/generation/vlm_utils/builders.py @@ -7,18 +7,21 @@ import torch +from vllm.multimodal.audio import AudioResampler from vllm.multimodal.image import rescale_image_size from vllm.multimodal.video import (rescale_video_size, resize_video, sample_frames_from_video) -from .....conftest import ImageTestAssets, VideoTestAssets -from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER, +from .....conftest import AudioTestAssets, ImageTestAssets, VideoTestAssets +from .types import (SINGLE_AUDIO_BASE_PROMPT, SINGLE_IMAGE_BASE_PROMPTS, + TEST_AUDIO_PLACEHOLDER, TEST_IMG_PLACEHOLDER, TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT, - ImageSizeWrapper, SizeType, VLMTestInfo) + ImageSizeWrapper, PromptWithMultiModalInput, SizeType, + VLMTestInfo) -def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], - str], +def replace_test_placeholder(prompt: str, mm_idx_to_prompt: Callable[[int], + str], test_placeholder: str) -> str: """Given a prompt, replaces each test placeholder with the model-specific tag. 
@@ -26,7 +29,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], prompt_segments = prompt.split(test_placeholder) img_prompt = prompt_segments[0] for placeholder_idx, next_seg in enumerate(prompt_segments[1:], start=1): - img_prompt += img_idx_to_prompt(placeholder_idx) + img_prompt += mm_idx_to_prompt(placeholder_idx) img_prompt += next_seg return img_prompt @@ -34,6 +37,7 @@ def replace_test_placeholder(prompt: str, img_idx_to_prompt: Callable[[int], def get_model_prompts(base_prompts: Iterable[str], img_idx_to_prompt: Optional[Callable[[int], str]], video_idx_to_prompt: Optional[Callable[[int], str]], + audio_idx_to_prompt: Optional[Callable[[int], str]], prompt_formatter: Callable[[str], str]) -> list[str]: """Given a model-agnostic base prompt and test configuration for a model(s) to be tested, update the media placeholders and apply the prompt formatting @@ -60,6 +64,11 @@ def get_model_prompts(base_prompts: Iterable[str], video_idx_to_prompt, TEST_VIDEO_PLACEHOLDER) + if audio_idx_to_prompt: + base_prompt = replace_test_placeholder(base_prompt, + audio_idx_to_prompt, + TEST_AUDIO_PLACEHOLDER) + # Apply the prompt formatter to wrap the base prompt with # the correct media placeholders to get the model test prompt model_prompt = prompt_formatter(base_prompt) @@ -68,10 +77,11 @@ def get_model_prompts(base_prompts: Iterable[str], def build_single_image_inputs_from_test_info( - test_info: VLMTestInfo, - image_assets: ImageTestAssets, - size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None): + test_info: VLMTestInfo, + image_assets: ImageTestAssets, + size_wrapper: ImageSizeWrapper, + tmp_path: Optional[PosixPath] = None, +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError( "Prompt formatter must be set to build single image inputs") @@ -79,6 +89,7 @@ def build_single_image_inputs_from_test_info( model_prompts = get_model_prompts(test_info.single_image_prompts, test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter) # For models that require a local path / URL encoded in the image; export @@ -97,28 +108,32 @@ def build_single_image_inputs_from_test_info( return build_single_image_inputs(images, model_prompts, size_wrapper) -def build_single_image_inputs(images, model_prompts, - size_wrapper: ImageSizeWrapper): +def build_single_image_inputs( + images, model_prompts, + size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: # For every image / prompt pair, get a pair containing two lists of # length size_factors, where the first contains duplicates of the model # prompt [str], and the second contains copies of the image after being # scaled by one of the size factors. # # NOTE: rescaling preserves the image aspect ratio. 
- return [( - [prompt for _ in size_wrapper.data], - [ - apply_image_size_scaling(image, size, size_wrapper.type) - for size in size_wrapper.data - ], - ) for image, prompt in zip(images, model_prompts)] + return [ + PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + image_data=[ + apply_image_size_scaling(image, size, size_wrapper.type) + for size in size_wrapper.data + ], + ) for image, prompt in zip(images, model_prompts) + ] def build_multi_image_inputs_from_test_info( - test_info: VLMTestInfo, - image_assets: ImageTestAssets, - size_wrapper: ImageSizeWrapper, - tmp_path: Optional[PosixPath] = None): + test_info: VLMTestInfo, + image_assets: ImageTestAssets, + size_wrapper: ImageSizeWrapper, + tmp_path: Optional[PosixPath] = None, +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError( "Prompt formatter must be set to build multi image inputs") @@ -126,6 +141,7 @@ def build_multi_image_inputs_from_test_info( model_prompts = get_model_prompts([test_info.multi_image_prompt], test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter) if test_info.prompt_path_encoder is not None: @@ -146,15 +162,18 @@ def build_multi_image_inputs_from_test_info( ) -def build_multi_image_inputs(image_lists, model_prompts, - size_wrapper: ImageSizeWrapper): - return [( - [prompt for _ in size_wrapper.data], - [[ - apply_image_size_scaling(image, size, size_wrapper.type) - for image in images - ] for size in size_wrapper.data], - ) for images, prompt in zip(image_lists, model_prompts)] +def build_multi_image_inputs( + image_lists, model_prompts, + size_wrapper: ImageSizeWrapper) -> list[PromptWithMultiModalInput]: + return [ + PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + image_data=[[ + apply_image_size_scaling(image, size, size_wrapper.type) + for image in images + ] for size in size_wrapper.data], + ) for images, prompt in zip(image_lists, model_prompts) + ] def build_embedding_inputs_from_test_info( @@ -177,6 +196,7 @@ def build_embedding_inputs_from_test_info( SINGLE_IMAGE_BASE_PROMPTS, test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter, ) @@ -195,13 +215,14 @@ def build_video_inputs_from_test_info( video_assets: VideoTestAssets, size_wrapper: ImageSizeWrapper, num_frames: int, -): +) -> list[PromptWithMultiModalInput]: if test_info.prompt_formatter is None: raise ValueError("Prompt formatter must be set to build video inputs") model_prompts = get_model_prompts( [VIDEO_BASE_PROMPT], test_info.img_idx_to_prompt, test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, test_info.prompt_formatter, ) @@ -213,10 +234,14 @@ def build_video_inputs_from_test_info( video_scaler = (resize_video if size_wrapper.type == SizeType.FIXED_SIZE else rescale_video_size) - return [( - [prompt for _ in size_wrapper.data], - [video_scaler(video, size) for size in size_wrapper.data], - ) for video, prompt in zip(sampled_vids, model_prompts)] + return [ + PromptWithMultiModalInput( + prompts=[prompt for _ in size_wrapper.data], + video_data=[ + video_scaler(video, size) for size in size_wrapper.data + ], + ) for video, prompt in zip(sampled_vids, model_prompts) + ] def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], @@ -236,3 +261,37 @@ def apply_image_size_scaling(image, size: Union[float, tuple[int, int]], # We have a list of fixed sizes return image.resize(size) raise 
ValueError("ImageSizeWrapper type must be FIXED_SIZE or SIZE_FACTOR") + + +def build_audio_inputs_from_test_info( + test_info: VLMTestInfo, + audio_assets: AudioTestAssets, +) -> list[PromptWithMultiModalInput]: + if test_info.prompt_formatter is None: + raise ValueError("Prompt formatter must be set to build audio inputs") + model_prompts = get_model_prompts( + SINGLE_AUDIO_BASE_PROMPT, + test_info.img_idx_to_prompt, + test_info.video_idx_to_prompt, + test_info.audio_idx_to_prompt, + test_info.prompt_formatter, + ) + resampler = AudioResampler( + target_sr=16000, + method="librosa", + ) + audios = [asset.audio_and_sample_rate for asset in audio_assets] + resampled_audios = [( + resampler.resample( + audio, + orig_sr=sr, + ), + int(resampler.target_sr), + ) for audio, sr in audios] + + return [ + PromptWithMultiModalInput( + prompts=model_prompts, + audio_data=resampled_audios, + ) + ] diff --git a/tests/models/multimodal/generation/vlm_utils/case_filtering.py b/tests/models/multimodal/generation/vlm_utils/case_filtering.py index 8e825676b8f..a5077a090b5 100644 --- a/tests/models/multimodal/generation/vlm_utils/case_filtering.py +++ b/tests/models/multimodal/generation/vlm_utils/case_filtering.py @@ -83,7 +83,7 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): test_info.num_video_frames) # No sizes passed for custom inputs, since inputs are directly provided - if test_type != VLMTestType.CUSTOM_INPUTS: + if test_type not in (VLMTestType.CUSTOM_INPUTS, VLMTestType.AUDIO): wrapped_sizes = get_wrapped_test_sizes(test_info, test_type) if wrapped_sizes is None: raise ValueError( @@ -91,7 +91,7 @@ def get_model_type_cases(model_type: str, test_info: VLMTestInfo): iter_kwargs["size_wrapper"] = wrapped_sizes #Otherwise expand the custom test options instead - else: + elif test_type == VLMTestType.CUSTOM_INPUTS: if test_info.custom_test_opts is None: raise ValueError("Test has type CUSTOM_INPUTS, but none given") iter_kwargs["custom_test_opts"] = test_info.custom_test_opts @@ -136,8 +136,8 @@ def get_wrapped_test_sizes( ImageSizeWrapper(type=SizeType.SIZE_FACTOR, data=factor) for factor in EMBEDDING_SIZE_FACTORS ]) - # Custom inputs have preprocessed inputs - elif test_type == VLMTestType.CUSTOM_INPUTS: + # Audio and Custom inputs have preprocessed inputs + elif test_type in (VLMTestType.AUDIO, VLMTestType.CUSTOM_INPUTS): return tuple() size_factors = test_info.image_size_factors \ diff --git a/tests/models/multimodal/generation/vlm_utils/core.py b/tests/models/multimodal/generation/vlm_utils/core.py index c3d20f56855..ccd2799abd9 100644 --- a/tests/models/multimodal/generation/vlm_utils/core.py +++ b/tests/models/multimodal/generation/vlm_utils/core.py @@ -1,9 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 """Core test implementation to be shared across modalities.""" -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Optional import torch -from PIL.Image import Image from transformers.models.auto.auto_factory import _BaseAutoModelClass from vllm.config import TaskOption @@ -11,14 +10,14 @@ from .....conftest import HfRunner, VllmRunner from ....registry import HF_EXAMPLE_MODELS -from .types import RunnerOutput +from .types import PromptWithMultiModalInput, RunnerOutput def run_test( *, hf_runner: type[HfRunner], vllm_runner: type[VllmRunner], - inputs: list[tuple[list[str], list[Union[list[Image], Image]]]], + inputs: list[PromptWithMultiModalInput], model: str, dtype: str, max_tokens: int, @@ -38,7 +37,6 @@ def run_test( hf_model_kwargs: 
Optional[dict[str, Any]], patch_hf_runner: Optional[Callable[[HfRunner], HfRunner]], task: TaskOption = "auto", - runner_mm_key: str = "images", distributed_executor_backend: Optional[str] = None, tensor_parallel_size: int = 1, vllm_embeddings: Optional[torch.Tensor] = None, @@ -94,10 +92,16 @@ def run_test( if stop_str: vllm_kwargs["stop"] = stop_str - for prompts, media in vllm_inputs: - vllm_kwargs[runner_mm_key] = media + for prompts, image_data, video_data, audio_data in vllm_inputs: + mm_data = dict(images=image_data, + videos=video_data, + audios=audio_data) + vllm_kwargs_with_mm_data = vllm_kwargs | mm_data vllm_output = vllm_model.generate_greedy_logprobs( - prompts, max_tokens, num_logprobs=num_logprobs, **vllm_kwargs) + prompts, + max_tokens, + num_logprobs=num_logprobs, + **vllm_kwargs_with_mm_data) vllm_outputs_per_mm.append(vllm_output) hf_model = hf_runner(model, @@ -122,14 +126,17 @@ def run_test( if stop_str: hf_kwargs["stop_strings"] = stop_str - for prompts, media in inputs: - hf_kwargs[runner_mm_key] = media + for prompts, image_data, video_data, audio_data in inputs: + mm_data = dict(images=image_data, + videos=video_data, + audios=audio_data) + hf_kwargs_with_mm_data = hf_kwargs | mm_data hf_output = hf_model.generate_greedy_logprobs_limit( prompts, max_tokens, num_logprobs=num_logprobs, tokenizer=tokenizer, - **hf_kwargs) + **hf_kwargs_with_mm_data) hf_outputs_per_mm.append(hf_output) # Apply output processing / sanitation to the vLLM and HF runner results diff --git a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py index 235618ae547..cc104556113 100644 --- a/tests/models/multimodal/generation/vlm_utils/custom_inputs.py +++ b/tests/models/multimodal/generation/vlm_utils/custom_inputs.py @@ -12,7 +12,7 @@ from .....conftest import IMAGE_ASSETS, VIDEO_ASSETS from .builders import build_multi_image_inputs, build_single_image_inputs -from .types import ImageSizeWrapper, SizeType +from .types import ImageSizeWrapper, PromptWithMultiModalInput, SizeType def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): @@ -32,24 +32,28 @@ def multi_image_multi_aspect_ratio_inputs(formatter: Callable[[str], str]): "\nWhat is the season?", ] formatted_prompts = [formatter(prompt) for prompt in img_prompts] - - return [( - formatted_prompts, + aspect_ratio_images = [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], [ - [stop_sign, cherry_blossom], - # Images with different sizes and aspect-ratios - [ - rescale_image_size(stop_sign, 0.1), - stop_sign, - ], - [ - stop_sign, - rescale_image_size(stop_sign, 0.25), - cherry_blossom.resize((183, 488)), - cherry_blossom.resize((488, 183)) - ], - cherry_blossom, - ])] + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ] + + return [ + PromptWithMultiModalInput( + prompts=formatted_prompts, + image_data=aspect_ratio_images, + ) + ] def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], @@ -68,24 +72,28 @@ def multi_video_multi_aspect_ratio_inputs(formatter: Callable[[str], str], "