From eefd1c96611ff3b2fc8a7169b517cd7bbdfa14bf Mon Sep 17 00:00:00 2001 From: Chaofan Lin Date: Wed, 29 May 2024 18:13:01 +0800 Subject: [PATCH] [Op] Move SharedAttention to parrot.op (#3) This PR move our shared kernel from 3rdparty `vLLM` to the namespace `parrot.op`. And it also adds some tests and benchmark on the shared kernel. --------- Co-authored-by: Chengruidong Zhang --- .env | 3 - 3rdparty/vllm/pyproject.toml | 2 +- INSTALL.md | 8 +- README.md | 2 +- benchmark/bench_kernel.py | 82 +++ csrc/attention.cpp | 62 ++ csrc/attention/attention_dtypes.h | 6 + csrc/attention/attention_generic.cuh | 64 ++ csrc/attention/attention_kernels.cu | 574 ++++++++++++++++++ csrc/attention/attention_post_kernels.cu | 531 ++++++++++++++++ csrc/attention/attention_prev_kernels.cu | 528 ++++++++++++++++ csrc/attention/attention_utils.cuh | 55 ++ csrc/attention/dtype_bfloat16.cuh | 451 ++++++++++++++ csrc/attention/dtype_float16.cuh | 444 ++++++++++++++ csrc/attention/dtype_float32.cuh | 268 ++++++++ .../builtin/kernels/shared_flash_decoding.py | 4 +- parrot/engine/config.py | 5 + requirements.txt | 6 +- setup.py | 125 +++- tests/kernel/test_shared_kernel.py | 132 ++++ 20 files changed, 3328 insertions(+), 24 deletions(-) create mode 100644 benchmark/bench_kernel.py create mode 100644 csrc/attention.cpp create mode 100644 csrc/attention/attention_dtypes.h create mode 100644 csrc/attention/attention_generic.cuh create mode 100644 csrc/attention/attention_kernels.cu create mode 100644 csrc/attention/attention_post_kernels.cu create mode 100644 csrc/attention/attention_prev_kernels.cu create mode 100644 csrc/attention/attention_utils.cuh create mode 100644 csrc/attention/dtype_bfloat16.cuh create mode 100644 csrc/attention/dtype_float16.cuh create mode 100644 csrc/attention/dtype_float32.cuh create mode 100644 tests/kernel/test_shared_kernel.py diff --git a/.env b/.env index db9aef5..445c285 100644 --- a/.env +++ b/.env @@ -9,7 +9,4 @@ export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH export SIMULATE_NETWORK_LATENCY_PRT=1 # 0 off, 1 on export SIMULATE_NETWORK_LATENCY_FS=1 # 0 off, 1 on -# export FS_MAX_GEN_LENGTH=20 -# export FS_MAX_GEN_LENGTH=50 - # CUDA_LAUNCH_BLOCKING=1 \ No newline at end of file diff --git a/3rdparty/vllm/pyproject.toml b/3rdparty/vllm/pyproject.toml index 2645664..fcaeb09 100644 --- a/3rdparty/vllm/pyproject.toml +++ b/3rdparty/vllm/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - "torch >= 2.0.0", + "torch == 2.1.0", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/INSTALL.md b/INSTALL.md index 9e69dfa..df6913c 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -15,7 +15,13 @@ pip install torch==2.1.0 --upgrade --index-url https://download.pytorch.org/whl/ ### Clone the Project ```bash -git clone --recursive https://github.com/SiriusNEO/LLMOS-Parrot.git +git clone --recursive https://github.com/microsoft/ParrotServe.git +``` + +### Configure the Environment + +```bash +source .env ``` ### Install dependencies diff --git a/README.md b/README.md index cb28359..01d7095 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Parrot: Efficient Serving of LLM-based Application with Semantic Variables +# Parrot: Efficient Serving of LLM-based Application with Semantic Variable This project is a research prototype for now. Being eargerly iterated. diff --git a/benchmark/bench_kernel.py b/benchmark/bench_kernel.py new file mode 100644 index 0000000..8961f82 --- /dev/null +++ b/benchmark/bench_kernel.py @@ -0,0 +1,82 @@ +from transformers import AutoTokenizer +import torch +import json + +from parrot.engine.builtin.builtin_runner import BuiltinRunner +from parrot.engine.config import BuiltinConfig +from parrot.engine.primitive_job import Fill, Generate +from parrot.sampling_config import SamplingConfig + + +def bench_decode( + attn_func: str, batch_size: int, shared_len: int, diverged_len: int, output_len: int +): + config = BuiltinConfig( + num_kv_cache_blocks=2000, + attn_func=attn_func, + block_size=16, + max_seq_len=65536, + ) + sampling_config = SamplingConfig( + max_gen_length=output_len, + ignore_tokenizer_eos=True, + ) + + runner = BuiltinRunner("lmsys/vicuna-13b-v1.3", config=config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + + context_len = shared_len + diverged_len + + prompt_token_ids = [[100] * context_len for _ in range(batch_size)] + + shared_fill = Fill( + session_id=0, + task_id=0, + context_id=0, + parent_context_id=-1, + token_ids=prompt_token_ids[0][:shared_len], + ) + diverged_fills = [ + Fill( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + token_ids=prompt[shared_len:], + ) + for i, prompt in enumerate(prompt_token_ids) + ] + gens = [ + Generate( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + sampling_config=sampling_config, + ) + for i, prompt in enumerate(prompt_token_ids) + ] + + runner.run_iter([shared_fill]) + runner.run_iter(diverged_fills) + for _ in range(output_len): + runner.run_iter(gens) + + del runner + + +if __name__ == "__main__": + # bench_decode( + # attn_func="xformers_fill_vllm_paged_attention_generate", + # batch_size=64, + # shared_len=8192, + # diverged_len=10, + # output_len=10, + # ) + bench_decode( + attn_func="xformers_fill_shared_prompts_generate", + batch_size=64, + shared_len=8192, + diverged_len=10, + output_len=10, + ) diff --git a/csrc/attention.cpp b/csrc/attention.cpp new file mode 100644 index 0000000..06cc78a --- /dev/null +++ b/csrc/attention.cpp @@ -0,0 +1,62 @@ +#include +#include + +void single_query_cached_kv_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& block_lens, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& block_nums, // [num_seqs] + torch::Tensor& context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + +void single_query_cached_kv_prev_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& qk_maxs, // [num_seqs] + torch::Tensor& exp_sums, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + +void single_query_cached_kv_post_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& prev_qk_maxs, // [num_seqs] + torch::Tensor& prev_exp_sums, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional& alibi_slopes); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def( + "single_query_cached_kv_attention", + &single_query_cached_kv_attention, + "Compute the attention between an input query and the cached key/value tensors"); + m.def( + "single_query_cached_kv_prev_attention", + &single_query_cached_kv_prev_attention, + "Compute the attention between an input query and the cached key/value tensors and log middle results"); + m.def( + "single_query_cached_kv_post_attention", + &single_query_cached_kv_post_attention, + "Compute the attention between an input query and the cached key/value tensors based on previous results"); +} diff --git a/csrc/attention/attention_dtypes.h b/csrc/attention/attention_dtypes.h new file mode 100644 index 0000000..88b4edd --- /dev/null +++ b/csrc/attention/attention_dtypes.h @@ -0,0 +1,6 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float16.cuh" +#include "dtype_float32.cuh" +#include "dtype_bfloat16.cuh" diff --git a/csrc/attention/attention_generic.cuh b/csrc/attention/attention_generic.cuh new file mode 100644 index 0000000..31fb401 --- /dev/null +++ b/csrc/attention/attention_generic.cuh @@ -0,0 +1,64 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace vllm { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu new file mode 100644 index 0000000..917e0c4 --- /dev/null +++ b/csrc/attention/attention_kernels.cu @@ -0,0 +1,574 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#include + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm +{ + + // Utility function for attention softmax. + template + inline __device__ float block_sum(float *red_smem, float sum) + { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) + { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) + { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) + { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) + { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); + } + + // Grid: (num_heads, num_seqs). + template < + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> + __global__ void single_query_cached_kv_attention_kernel( + scalar_t *__restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t *__restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t *__restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int *__restrict__ head_mapping, // [num_heads] + const float scale, + const int *__restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ block_lens, // [num_seqs, max_num_blocks_per_seq] + const int *__restrict__ block_nums, // [num_seqs] + const int *__restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float *__restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) + { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t *q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) + { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float *logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + + const int *block_table = block_tables + seq_idx * max_num_blocks_per_seq; + // const int* block_len = block_lens + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + // std::printf("[%d] %d\n", seq_idx, context_len); + const int num_blocks = block_nums[seq_idx]; + // const int context_len = num_blocks * BLOCK_SIZE; + // int valid_token_num; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) + { + const int physical_block_number = block_table[block_idx]; + // valid_token_num = block_len[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) + { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) + { + const scalar_t *k_ptr = k_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + // TODO: support AliBi: token_idx -> token_cnt + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; + + if (thread_group_offset == 0) + { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + // const bool mask = physical_block_offset >= valid_token_num; + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) + { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) + { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) + { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) + { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) + { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + accs[i] = 0.f; + } + + // scalar_t zero_value; + // zero(zero_value); + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) + { + const int physical_block_number = block_table[block_idx]; + // valid_token_num = block_len[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + // scalar_t *logits_vec_ptr = reinterpret_cast(&logits_vec); + // #pragma unroll + // for (int j = 0; j < V_VEC_SIZE; j++) + // { + // logits_vec_ptr[j] = (physical_block_offset + j < valid_token_num) ? logits_vec_ptr[j] : zero_value; + // } + const scalar_t *v_ptr = v_cache + physical_block_number * kv_block_stride + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) + { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) + { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float *out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) + { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) + { + float *dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) + { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) + { + const float *src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) + { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) + { + scalar_t *out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) + { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) + { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } + } + +} // namespace vllm + +#define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + vllm::single_query_cached_kv_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + block_lens_ptr, \ + block_nums_ptr, \ + context_lens_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +// TODO(woosuk): Tune NUM_THREADS. +template < + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_attention_launcher( + torch::Tensor &out, + torch::Tensor &query, + torch::Tensor &key_cache, + torch::Tensor &value_cache, + torch::Tensor &head_mapping, + float scale, + torch::Tensor &block_tables, + torch::Tensor &block_lens, + torch::Tensor &block_nums, + torch::Tensor &context_lens, + int max_context_len, + const c10::optional &alibi_slopes) +{ + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float *alibi_slopes_ptr = alibi_slopes ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T *out_ptr = reinterpret_cast(out.data_ptr()); + T *query_ptr = reinterpret_cast(query.data_ptr()); + T *key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T *value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int *head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int *block_tables_ptr = block_tables.data_ptr(); + int *block_lens_ptr = block_lens.data_ptr(); + int *block_nums_ptr = block_nums.data_ptr(); + int *context_lens_ptr = context_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) + { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; + case 64: + LAUNCH_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + // case 160: + // LAUNCH_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + case 256: + LAUNCH_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + block_lens, \ + block_nums, \ + context_lens, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) \ + { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void single_query_cached_kv_attention( + torch::Tensor &out, // [num_seqs, num_heads, head_size] + torch::Tensor &query, // [num_seqs, num_heads, head_size] + torch::Tensor &key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor &value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor &head_mapping, // [num_heads] + float scale, + torch::Tensor &block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor &block_lens, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor &block_nums, // [num_seqs] + torch::Tensor &context_lens, // [num_seqs] + int block_size, + int max_context_len, + const c10::optional &alibi_slopes) +{ + if (query.dtype() == at::ScalarType::Float) + { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } + else if (query.dtype() == at::ScalarType::Half) + { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } + else if (query.dtype() == at::ScalarType::BFloat16) + { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } + else + { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN \ No newline at end of file diff --git a/csrc/attention/attention_post_kernels.cu b/csrc/attention/attention_post_kernels.cu new file mode 100644 index 0000000..563d0b6 --- /dev/null +++ b/csrc/attention/attention_post_kernels.cu @@ -0,0 +1,531 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#include + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_post_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const float* __restrict__ prev_qk_maxs, // [num_seqs, num_heads] + const float* __restrict__ prev_exp_sums, // [num_seqs, num_heads] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const int num_seqs = gridDim.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + float qk_scale = scale * 1.44269504; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = qk_scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = exp2f(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + float prev_qk_max = prev_qk_maxs[seq_idx * num_heads + head_idx]; + float prev_exp_sum = prev_exp_sums[seq_idx * num_heads + head_idx]; + float comb_qk_max = max(prev_qk_max, qk_max); + float prev_factor = exp2f(prev_qk_max - comb_qk_max); + float factor = exp2f(qk_max - comb_qk_max); + float comb_exp_sum = prev_factor * prev_exp_sum + factor * exp_sum; + prev_factor *= prev_exp_sum / comb_exp_sum; + factor *= exp_sum / comb_exp_sum; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + float prev_acc = to_float(*(out_ptr + row_idx)); + accs[i] = prev_factor * prev_acc + factor * accs[i]; + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +} // namespace vllm + +#define LAUNCH_POST_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + vllm::single_query_cached_kv_post_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + prev_qk_maxs_ptr, \ + prev_exp_sums_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_post_attention_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + torch::Tensor& prev_qk_maxs, + torch::Tensor& prev_exp_sums, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + float* prev_qk_maxs_ptr = prev_qk_maxs.data_ptr(); + float* prev_exp_sums_ptr = prev_exp_sums.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_POST_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; + case 64: + LAUNCH_POST_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_POST_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_POST_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_POST_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_POST_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + // case 160: + // LAUNCH_POST_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_POST_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + case 256: + LAUNCH_POST_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_post_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + prev_qk_maxs, \ + prev_exp_sums, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void single_query_cached_kv_post_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& prev_qk_maxs, + torch::Tensor& prev_exp_sums, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN diff --git a/csrc/attention/attention_prev_kernels.cu b/csrc/attention/attention_prev_kernels.cu new file mode 100644 index 0000000..ee94a28 --- /dev/null +++ b/csrc/attention/attention_prev_kernels.cu @@ -0,0 +1,528 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#include + +#define WARP_SIZE 32 +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +// Grid: (num_heads, num_seqs). +template< + typename scalar_t, + int HEAD_SIZE, + int BLOCK_SIZE, + int NUM_THREADS> +__global__ void single_query_cached_kv_prev_attention_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x] + const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size] + const int* __restrict__ head_mapping, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + float* __restrict__ qk_maxs, // [num_seqs, num_heads] + float* __restrict__ exp_sums, // [num_seqs, num_heads] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, + const int kv_block_stride, + const int kv_head_stride) { + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = (BLOCK_SIZE + WARP_SIZE - 1) / WARP_SIZE; + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int kv_head_idx = head_mapping[head_idx]; + const int seq_idx = blockIdx.y; + const int num_seqs = gridDim.y; + const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread group + // fetch or compute 16 bytes at a time. + // For example, if the size of a thread group is 4 and the data type is half, + // then the vector size is 16 / (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the query, and the second thread has 1, 5, 9, ... + // th vectors of the query, and so on. + // NOTE(woosuk): Because q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(scalar_t); + float qk_max = -FLT_MAX; + float qk_scale = scale * 1.44269504; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + const int context_len = context_lens[seq_idx]; + const int num_blocks = (context_len + BLOCK_SIZE - 1) / BLOCK_SIZE; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in the group + // has 0, 4, 8, ... th vectors of the key, and the second thread has 1, 5, 9, ... th + // vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + k_vecs[j] = *reinterpret_cast(k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = qk_scale * Qk_dot::dot(q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= context_len; + logits[token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + float val = exp2f(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < context_len; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = (HEAD_SIZE + NUM_ROWS_PER_ITER - 1) / NUM_ROWS_PER_ITER; + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + for (int block_idx = warp_idx; block_idx < num_blocks; block_idx += NUM_WARPS) { + const int physical_block_number = block_table[block_idx]; + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx)); + + const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec = *reinterpret_cast(v_ptr + offset); + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += __shfl_xor_sync(uint32_t(-1), acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for logits + // is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } + + if (thread_idx == blockDim.x - 1) { + qk_maxs[seq_idx * num_heads + head_idx] = qk_max; + } + if (thread_idx == blockDim.x - 2) { + exp_sums[seq_idx * num_heads + head_idx] = exp_sum; + } +} + +} // namespace vllm + +#define LAUNCH_PREV_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \ + vllm::single_query_cached_kv_prev_attention_kernel \ + <<>>( \ + out_ptr, \ + query_ptr, \ + key_cache_ptr, \ + value_cache_ptr, \ + head_mapping_ptr, \ + scale, \ + block_tables_ptr, \ + context_lens_ptr, \ + qk_maxs_ptr, \ + exp_sums_ptr, \ + max_num_blocks_per_seq, \ + alibi_slopes_ptr, \ + q_stride, \ + kv_block_stride, \ + kv_head_stride); + +// TODO(woosuk): Tune NUM_THREADS. +template< + typename T, + int BLOCK_SIZE, + int NUM_THREADS = 128> +void single_query_cached_kv_prev_attention_launcher( + torch::Tensor& out, + torch::Tensor& query, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& head_mapping, + float scale, + torch::Tensor& block_tables, + torch::Tensor& context_lens, + torch::Tensor& qk_maxs, + torch::Tensor& exp_sums, + int max_context_len, + const c10::optional& alibi_slopes) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = alibi_slopes ? + reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* head_mapping_ptr = reinterpret_cast(head_mapping.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + float* qk_maxs_ptr = qk_maxs.data_ptr(); + float* exp_sums_ptr = exp_sums.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_context_len = ((max_context_len + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_context_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs); + dim3 block(NUM_THREADS); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we omitted head sizes + // 32, 160, 192. + // case 32: + // LAUNCH_PREV_ATTENTION_KERNEL(T, 32, BLOCK_SIZE, NUM_THREADS); + // break; + case 64: + LAUNCH_PREV_ATTENTION_KERNEL(T, 64, BLOCK_SIZE, NUM_THREADS); + break; + case 80: + LAUNCH_PREV_ATTENTION_KERNEL(T, 80, BLOCK_SIZE, NUM_THREADS); + break; + case 96: + LAUNCH_PREV_ATTENTION_KERNEL(T, 96, BLOCK_SIZE, NUM_THREADS); + break; + case 112: + LAUNCH_PREV_ATTENTION_KERNEL(T, 112, BLOCK_SIZE, NUM_THREADS); + break; + case 128: + LAUNCH_PREV_ATTENTION_KERNEL(T, 128, BLOCK_SIZE, NUM_THREADS); + break; + // case 160: + // LAUNCH_PREV_ATTENTION_KERNEL(T, 160, BLOCK_SIZE, NUM_THREADS); + // break; + // case 192: + // LAUNCH_PREV_ATTENTION_KERNEL(T, 192, BLOCK_SIZE, NUM_THREADS); + // break; + case 256: + LAUNCH_PREV_ATTENTION_KERNEL(T, 256, BLOCK_SIZE, NUM_THREADS); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_KERNEL_LAUNCHER(T, BLOCK_SIZE) \ + single_query_cached_kv_prev_attention_launcher( \ + out, \ + query, \ + key_cache, \ + value_cache, \ + head_mapping, \ + scale, \ + block_tables, \ + context_lens, \ + qk_maxs, \ + exp_sums, \ + max_context_len, \ + alibi_slopes); + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_KERNEL_LAUNCHER_BLOCK_SIZE(T) \ + switch (block_size) { \ + /* case 1: */ \ + /* CALL_KERNEL_LAUNCHER(T, 1); */ \ + /* break; */ \ + /* case 2: */ \ + /* CALL_KERNEL_LAUNCHER(T, 2); */ \ + /* break; */ \ + /* case 4: */ \ + /* CALL_KERNEL_LAUNCHER(T, 4); */ \ + /* break; */ \ + case 8: \ + CALL_KERNEL_LAUNCHER(T, 8); \ + break; \ + case 16: \ + CALL_KERNEL_LAUNCHER(T, 16); \ + break; \ + case 32: \ + CALL_KERNEL_LAUNCHER(T, 32); \ + break; \ + /* case 64: */ \ + /* CALL_KERNEL_LAUNCHER(T, 64); */ \ + /* break; */ \ + /* case 128: */ \ + /* CALL_KERNEL_LAUNCHER(T, 128); */ \ + /* break; */ \ + /* case 256: */ \ + /* CALL_KERNEL_LAUNCHER(T, 256); */ \ + /* break; */ \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void single_query_cached_kv_prev_attention( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& head_mapping, // [num_heads] + float scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + torch::Tensor& qk_maxs, + torch::Tensor& exp_sums, + int block_size, + int max_context_len, + const c10::optional& alibi_slopes) { + if (query.dtype() == at::ScalarType::Float) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(float); + } else if (query.dtype() == at::ScalarType::Half) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(uint16_t); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_KERNEL_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN diff --git a/csrc/attention/attention_utils.cuh b/csrc/attention/attention_utils.cuh new file mode 100644 index 0000000..bb7df25 --- /dev/null +++ b/csrc/attention/attention_utils.cuh @@ -0,0 +1,55 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_dtypes.h" + +#include +#include + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm diff --git a/csrc/attention/dtype_bfloat16.cuh b/csrc/attention/dtype_bfloat16.cuh new file mode 100644 index 0000000..dad57d6 --- /dev/null +++ b/csrc/attention/dtype_bfloat16.cuh @@ -0,0 +1,451 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#include +#include +#include + +namespace vllm { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template<> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template<> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template<> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(val); +#endif +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return a + b; +#endif +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hadd2(a, b); +#endif +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul(a, b); +#endif +} + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul2(a, b); +#endif +} + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(a, b, c); +#endif +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(bf162bf162(a), b, c); +#endif +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template<> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template<> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template<> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template<> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. +inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst = __float22bfloat162_rn(src); +#endif +} + +inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#endif +} + +inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#endif +} + +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +inline __device__ float2 to_float(__nv_bfloat162 u) { + float2 tmp; + tmp.x = __bfloat162float(u.x); + tmp.y = __bfloat162float(u.y); + return tmp; +} + +inline __device__ Float4_ to_float(bf16_4_t u) { + Float4_ tmp; + tmp.x = to_float(u.x); + tmp.y = to_float(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(bf16_8_t u) { + Float8_ tmp; + tmp.x = to_float(u.x); + tmp.y = to_float(u.y); + tmp.z = to_float(u.z); + tmp.w = to_float(u.w); + return tmp; +} + +} // namespace vllm diff --git a/csrc/attention/dtype_float16.cuh b/csrc/attention/dtype_float16.cuh new file mode 100644 index 0000000..6ffc30c --- /dev/null +++ b/csrc/attention/dtype_float16.cuh @@ -0,0 +1,444 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#include + +namespace vllm { + +// FP16 vector types for Q, K, V. +template<> +struct Vec { + using Type = uint16_t; +}; +template<> +struct Vec { + using Type = uint32_t; +}; +template<> +struct Vec { + using Type = uint2; +}; +template<> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = Float4_; +}; +template<> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +inline __device__ float half_to_float(uint16_t h) { + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template<> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template<> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template<> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template<> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// Zero-out a vector. +inline __device__ void zero(uint16_t& dst) { + dst = uint16_t(0); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. +inline __device__ float to_float(uint16_t u) { + return half_to_float(u); +} + +inline __device__ float2 to_float(uint32_t u) { + return half2_to_float2(u); +} + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +} // namespace vllm diff --git a/csrc/attention/dtype_float32.cuh b/csrc/attention/dtype_float32.cuh new file mode 100644 index 0000000..960cf48 --- /dev/null +++ b/csrc/attention/dtype_float32.cuh @@ -0,0 +1,268 @@ +/* + * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template<> +struct Vec { + using Type = float; +}; +template<> +struct Vec { + using Type = float2; +}; +template<> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template<> +struct FloatVec { + using Type = float; +}; +template<> +struct FloatVec { + using Type = float2; +}; +template<> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { + return a + b; +} + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template<> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template<> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template<> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template<> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template<> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { + return a * b + c; +} + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template<> +inline __device__ float sum(float v) { + return v; +} + +template<> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template<> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template<> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template<> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { + return a * b; +} + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { + dst = src; +} + +inline __device__ void from_float(float2& dst, float2 src) { + dst = src; +} + +inline __device__ void from_float(float4& dst, float4 src) { + dst = src; +} + +// From float to float. +inline __device__ float to_float(float u) { + return u; +} + +inline __device__ float2 to_float(float2 u) { + return u; +} + +inline __device__ float4 to_float(float4 u) { + return u; +} + +inline __device__ Float4_ to_float(Float4_ u) { + return u; +} + +inline __device__ Float8_ to_float(Float8_ u) { + return u; +} + +} // namespace vllm diff --git a/parrot/engine/builtin/kernels/shared_flash_decoding.py b/parrot/engine/builtin/kernels/shared_flash_decoding.py index a172ab0..732957a 100644 --- a/parrot/engine/builtin/kernels/shared_flash_decoding.py +++ b/parrot/engine/builtin/kernels/shared_flash_decoding.py @@ -23,7 +23,9 @@ import triton import triton.language as tl -from vllm import attention_ops, cache_ops +from vllm import cache_ops + +from parrot import attention_ops ### Paged Flash Attention Begin ### diff --git a/parrot/engine/config.py b/parrot/engine/config.py index 01e9cd5..b8e5ca1 100644 --- a/parrot/engine/config.py +++ b/parrot/engine/config.py @@ -45,6 +45,11 @@ def __post_init__(self): self.device = torch.device(self.device) # Replace attn func + if self.attn_func not in ATTN_FUNC_LAYOUT_MAP: + raise ValueError( + f"Unknown attn func name: {self.attn_func}. " + f"Supported attn func names: {list(ATTN_FUNC_LAYOUT_MAP.keys())}" + ) self.mem_layout = ATTN_FUNC_LAYOUT_MAP[self.attn_func] # Set mem layout self.attn_func_name = self.attn_func self.attn_func = self._get_attn_func(self.attn_func) diff --git a/requirements.txt b/requirements.txt index a3bf77e..75f0913 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,9 +13,9 @@ transformers openai # Machine Learning packages -torch >= 2.1.0 -triton >= 2.1.0 -xformers >= 0.0.22.post7 +torch == 2.1.0 +triton == 2.1.0 +xformers == 0.0.22.post7 # Misc regex diff --git a/setup.py b/setup.py index 4687896..e8be984 100644 --- a/setup.py +++ b/setup.py @@ -4,25 +4,122 @@ """Setup scripts.""" -import pathlib -import sys +import os +import subprocess +from typing import Set +from packaging.version import parse, Version from setuptools import find_packages, setup -if len(sys.argv) <= 1: - sys.argv += ["install", "--user"] +import torch +from torch.utils.cpp_extension import BuildExtension, CUDAExtension # , CUDA_HOME -root_path = pathlib.Path(__file__).parent.absolute() +# https://github.com/pytorch/pytorch/issues/22844 +# HACK(chaofan): Sometimes this method fails to detect correct CUDA version. +# We use environment variable CUDA_HOME instead. +CUDA_HOME = os.getenv("CUDA_HOME") -def install(): - setup( - name="parrot", - version="0.1", - author="Chaofan Lin", - package_dir={"": "."}, - packages=find_packages("."), +ROOT_DIR = os.path.dirname(__file__) + +# Compiler flags. +CXX_FLAGS = ["-g", "-O2", "-std=c++17"] +NVCC_FLAGS = ["-O2", "-std=c++17"] + +ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0 +CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] +NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"] + +if CUDA_HOME is None: + raise RuntimeError( + f"Cannot find CUDA_HOME. CUDA must be available to build the package." + ) + + +def get_nvcc_cuda_version(cuda_dir: str) -> Version: + """Get the CUDA version from nvcc. + + Adapted from https://github.com/NVIDIA/apex/blob/8b7a1ff183741dd8f9b87e7bafd04cfde99cea28/setup.py + """ + nvcc_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +# Collect the compute capabilities of all available GPUs. +device_count = torch.cuda.device_count() +compute_capabilities: Set[int] = set() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 7: + raise RuntimeError( + "GPUs with compute capability less than 7.0 are not supported." + ) + compute_capabilities.add(major * 10 + minor) + +# Validate the NVCC CUDA version. +nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) +if nvcc_cuda_version < Version("11.0"): + raise RuntimeError("CUDA 11.0 or higher is required to build the package.") +if 86 in compute_capabilities and nvcc_cuda_version < Version("11.1"): + raise RuntimeError( + "CUDA 11.1 or higher is required for GPUs with compute capability 8.6." + ) +if 89 in compute_capabilities and nvcc_cuda_version < Version("11.8"): + # CUDA 11.8 is required to generate the code targeting compute capability 8.9. + # However, GPUs with compute capability 8.9 can also run the code generated by + # the previous versions of CUDA 11 and targeting compute capability 8.0. + # Therefore, if CUDA 11.8 is not available, we target compute capability 8.0 + # instead of 8.9. + compute_capabilities.remove(89) + compute_capabilities.add(80) +if 90 in compute_capabilities and nvcc_cuda_version < Version("11.8"): + raise RuntimeError( + "CUDA 11.8 or higher is required for GPUs with compute capability 9.0." ) +# If no GPU is available, add all supported compute capabilities. +if not compute_capabilities: + compute_capabilities = {70, 75, 80} + if nvcc_cuda_version >= Version("11.1"): + compute_capabilities.add(86) + if nvcc_cuda_version >= Version("11.8"): + compute_capabilities.add(89) + compute_capabilities.add(90) + +# Add target compute capabilities to NVCC flags. +for capability in compute_capabilities: + NVCC_FLAGS += ["-gencode", f"arch=compute_{capability},code=sm_{capability}"] + +# Use NVCC threads to parallelize the build. +if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] + +ext_modules = [] + +# Attention kernels. +attention_extension = CUDAExtension( + name="parrot.attention_ops", + sources=[ + "csrc/attention.cpp", + "csrc/attention/attention_kernels.cu", + "csrc/attention/attention_prev_kernels.cu", + "csrc/attention/attention_post_kernels.cu", + ], + extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS}, +) +ext_modules.append(attention_extension) -print("Installing Parrot ...") -install() +setup( + name="parrot", + version="0.1", + author="Chaofan Lin", + package_dir={"": "."}, + packages=find_packages(exclude=("csrc")), + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension}, +) diff --git a/tests/kernel/test_shared_kernel.py b/tests/kernel/test_shared_kernel.py new file mode 100644 index 0000000..d0aa03a --- /dev/null +++ b/tests/kernel/test_shared_kernel.py @@ -0,0 +1,132 @@ +from transformers import AutoTokenizer +import torch +import json + +from parrot.engine.builtin.builtin_runner import BuiltinRunner +from parrot.engine.config import BuiltinConfig +from parrot.engine.primitive_job import Fill, Generate +from parrot.sampling_config import SamplingConfig + + +def test_shared_decode(): + config = BuiltinConfig( + num_kv_cache_blocks=2048, + # attn_func="xformers_fill_shared_prompts_generate", + attn_func="xformers_fill_vllm_paged_attention_generate", + block_size=16, + max_seq_len=16384, + ) + sampling_config = SamplingConfig( + max_gen_length=200, + ignore_tokenizer_eos=True, + ) + + runner = BuiltinRunner("lmsys/vicuna-7b-v1.3", config=config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + + # bs=2 + # shared len = 1712 + # diverged len = 3 + prompt_token_ids = [ + [100] * 1712 + [200, 300, 400], + [100] * 1712 + [300, 400, 500], + ] + num_seqs = len(prompt_token_ids) + + shared_ids = 0 + while len(set([prompt[shared_ids] for prompt in prompt_token_ids])) == 1: + shared_ids += 1 + print(shared_ids) + + shared_fill = Fill( + session_id=0, + task_id=0, + context_id=0, + parent_context_id=-1, + token_ids=prompt_token_ids[0][:shared_ids], + ) + diverged_fills = [ + Fill( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + token_ids=prompt[shared_ids:], + ) + for i, prompt in enumerate(prompt_token_ids) + ] + gens = [ + Generate( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + sampling_config=sampling_config, + ) + for i, prompt in enumerate(prompt_token_ids) + ] + + runner.run_iter([shared_fill]) + runner.run_iter(diverged_fills) + for _ in range(10): + runner.run_iter(gens) + + +def test_masked_attention(): + config = BuiltinConfig( + num_kv_cache_blocks=2048, + attn_func="xformers_fill_shared_prompts_generate", + block_size=16, + max_seq_len=16384, + ) + sampling_config = SamplingConfig( + max_gen_length=200, + ignore_tokenizer_eos=True, + ) + + runner = BuiltinRunner("lmsys/vicuna-7b-v1.3", config=config) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer") + prompt = "Hi, my name is John. I'm a research scientist at a AI lab. I'm working on a project to develop a new AI model. I'm looking for a collaborator" + prompt_token_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].tolist() + + shared_ids = 20 # a number that is not divisible by 16 + + shared_fill = Fill( + session_id=0, + task_id=0, + context_id=0, + parent_context_id=-1, + token_ids=prompt_token_ids[0][:shared_ids], + ) + diverged_fills = [ + Fill( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + token_ids=prompt[shared_ids:], + ) + for i, prompt in enumerate(prompt_token_ids) + ] + gens = [ + Generate( + session_id=0, + task_id=0, + context_id=i + 1, + parent_context_id=0, + sampling_config=sampling_config, + ) + for i, prompt in enumerate(prompt_token_ids) + ] + + runner.run_iter([shared_fill]) + runner.run_iter(diverged_fills) + for _ in range(10): + runner.run_iter(gens) + + print(tokenizer.decode(gens[0].context.token_ids)) + + +if __name__ == "__main__": + # test_shared_decode() + test_masked_attention()