[Op] Move SharedAttention to parrot.op (#3)
This PR moves our shared kernel from the 3rdparty `vLLM` tree into the `parrot.op` namespace. It also adds tests and a benchmark for the shared kernel.

---------

Co-authored-by: Chengruidong Zhang <[email protected]>
SiriusNEO and Starmys authored May 29, 2024
1 parent aafe7ee commit eefd1c9
Showing 20 changed files with 3,328 additions and 24 deletions.
3 changes: 0 additions & 3 deletions .env
@@ -9,7 +9,4 @@ export LD_LIBRARY_PATH=/usr/local/cuda-12.1/lib64:$LD_LIBRARY_PATH
export SIMULATE_NETWORK_LATENCY_PRT=1 # 0 off, 1 on
export SIMULATE_NETWORK_LATENCY_FS=1 # 0 off, 1 on

# export FS_MAX_GEN_LENGTH=20
# export FS_MAX_GEN_LENGTH=50

# CUDA_LAUNCH_BLOCKING=1
2 changes: 1 addition & 1 deletion 3rdparty/vllm/pyproject.toml
@@ -3,7 +3,7 @@ requires = [
"ninja",
"packaging",
"setuptools",
"torch >= 2.0.0",
"torch == 2.1.0",
"wheel",
]
build-backend = "setuptools.build_meta"
8 changes: 7 additions & 1 deletion INSTALL.md
@@ -15,7 +15,13 @@ pip install torch==2.1.0 --upgrade --index-url https://download.pytorch.org/whl/
### Clone the Project

```bash
git clone --recursive https://github.com/SiriusNEO/LLMOS-Parrot.git
git clone --recursive https://github.com/microsoft/ParrotServe.git
```

### Configure the Environment

```bash
source .env
```

### Install dependencies
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# Parrot: Efficient Serving of LLM-based Application with Semantic Variables
# Parrot: Efficient Serving of LLM-based Application with Semantic Variable

This project is a research prototype for now, being eagerly iterated.

82 changes: 82 additions & 0 deletions benchmark/bench_kernel.py
@@ -0,0 +1,82 @@
from transformers import AutoTokenizer
import torch
import json

from parrot.engine.builtin.builtin_runner import BuiltinRunner
from parrot.engine.config import BuiltinConfig
from parrot.engine.primitive_job import Fill, Generate
from parrot.sampling_config import SamplingConfig


def bench_decode(
    attn_func: str, batch_size: int, shared_len: int, diverged_len: int, output_len: int
):
    config = BuiltinConfig(
        num_kv_cache_blocks=2000,
        attn_func=attn_func,
        block_size=16,
        max_seq_len=65536,
    )
    sampling_config = SamplingConfig(
        max_gen_length=output_len,
        ignore_tokenizer_eos=True,
    )

    runner = BuiltinRunner("lmsys/vicuna-13b-v1.3", config=config)
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    context_len = shared_len + diverged_len

    prompt_token_ids = [[100] * context_len for _ in range(batch_size)]

    shared_fill = Fill(
        session_id=0,
        task_id=0,
        context_id=0,
        parent_context_id=-1,
        token_ids=prompt_token_ids[0][:shared_len],
    )
    diverged_fills = [
        Fill(
            session_id=0,
            task_id=0,
            context_id=i + 1,
            parent_context_id=0,
            token_ids=prompt[shared_len:],
        )
        for i, prompt in enumerate(prompt_token_ids)
    ]
    gens = [
        Generate(
            session_id=0,
            task_id=0,
            context_id=i + 1,
            parent_context_id=0,
            sampling_config=sampling_config,
        )
        for i, prompt in enumerate(prompt_token_ids)
    ]

    runner.run_iter([shared_fill])
    runner.run_iter(diverged_fills)
    for _ in range(output_len):
        runner.run_iter(gens)

    del runner


if __name__ == "__main__":
    # bench_decode(
    #     attn_func="xformers_fill_vllm_paged_attention_generate",
    #     batch_size=64,
    #     shared_len=8192,
    #     diverged_len=10,
    #     output_len=10,
    # )
    bench_decode(
        attn_func="xformers_fill_shared_prompts_generate",
        batch_size=64,
        shared_len=8192,
        diverged_len=10,
        output_len=10,
    )
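The benchmark drives the runner through one shared-prefix `Fill` (context 0), a batch of diverged `Fill`s that fork from it, and then `output_len` batched `Generate` steps. Below is a hedged sketch of how one might sweep it over several shared-prefix lengths by reusing `bench_decode` from this file; the import path and parameter values are illustrative assumptions, and each call loads the full vicuna-13b model.

```python
# Illustrative sweep over bench_decode from benchmark/bench_kernel.py.
# Assumes the snippet runs from the repository root with the benchmark
# directory importable; the values below are examples, not tuned settings.
from benchmark.bench_kernel import bench_decode

for shared_len in (1024, 4096, 8192):
    for attn_func in (
        "xformers_fill_vllm_paged_attention_generate",  # paged-attention baseline
        "xformers_fill_shared_prompts_generate",        # shared-prompt kernel
    ):
        bench_decode(
            attn_func=attn_func,
            batch_size=64,
            shared_len=shared_len,
            diverged_len=10,
            output_len=10,
        )
```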
62 changes: 62 additions & 0 deletions csrc/attention.cpp
@@ -0,0 +1,62 @@
#include <torch/extension.h>
#include <c10/util/Optional.h>

void single_query_cached_kv_attention(
  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,  // [num_heads]
  float scale,
  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& block_lens,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& block_nums,    // [num_seqs]
  torch::Tensor& context_lens,  // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

void single_query_cached_kv_prev_attention(
  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,  // [num_heads]
  float scale,
  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,  // [num_seqs]
  torch::Tensor& qk_maxs,       // [num_seqs]
  torch::Tensor& exp_sums,      // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

void single_query_cached_kv_post_attention(
  torch::Tensor& out,           // [num_seqs, num_heads, head_size]
  torch::Tensor& query,         // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,  // [num_heads]
  float scale,
  torch::Tensor& block_tables,  // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,  // [num_seqs]
  torch::Tensor& prev_qk_maxs,  // [num_seqs]
  torch::Tensor& prev_exp_sums, // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "single_query_cached_kv_attention",
    &single_query_cached_kv_attention,
    "Compute the attention between an input query and the cached key/value tensors");
  m.def(
    "single_query_cached_kv_prev_attention",
    &single_query_cached_kv_prev_attention,
    "Compute the attention between an input query and the cached key/value tensors and log middle results");
  m.def(
    "single_query_cached_kv_post_attention",
    &single_query_cached_kv_post_attention,
    "Compute the attention between an input query and the cached key/value tensors based on previous results");
}
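The `prev`/`post` pair above splits paged attention for one query into two passes: the `prev` pass attends over the shared-prefix blocks and records per-sequence softmax statistics (`qk_maxs`, `exp_sums`), and the `post` pass attends over the diverged blocks and folds those statistics back in. The torch sketch below illustrates the standard online-softmax merge this implies; it is a reference-only illustration with made-up shapes (no block tables, head mapping, or ALiBi), not the CUDA kernel's actual implementation.

```python
# Reference sketch of merging two partial attention results, assuming the
# standard online-softmax combination; not the shared-attention CUDA kernel.
import torch


def partial_attention(q, k, v, scale):
    """Un-normalized attention over one segment plus its softmax statistics."""
    scores = (q @ k.transpose(-1, -2)) * scale    # [num_heads, 1, seg_len]
    qk_max = scores.amax(dim=-1, keepdim=True)    # per-head running max
    probs = torch.exp(scores - qk_max)
    exp_sum = probs.sum(dim=-1, keepdim=True)     # sum of exponentials
    acc = probs @ v                               # un-normalized output
    return acc, qk_max, exp_sum


def merge(acc_prev, max_prev, sum_prev, acc_post, max_post, sum_post):
    """Combine two partial attentions as if computed over the full context."""
    new_max = torch.maximum(max_prev, max_post)
    w_prev = torch.exp(max_prev - new_max)
    w_post = torch.exp(max_post - new_max)
    return (acc_prev * w_prev + acc_post * w_post) / (sum_prev * w_prev + sum_post * w_post)


# Sanity check against single-pass attention over the concatenated context.
num_heads, head_size, shared_len, diverged_len = 4, 8, 16, 5
scale = head_size ** -0.5
q = torch.randn(num_heads, 1, head_size)
k = torch.randn(num_heads, shared_len + diverged_len, head_size)
v = torch.randn(num_heads, shared_len + diverged_len, head_size)

prev = partial_attention(q, k[:, :shared_len], v[:, :shared_len], scale)
post = partial_attention(q, k[:, shared_len:], v[:, shared_len:], scale)
merged = merge(*prev, *post)

full = torch.softmax((q @ k.transpose(-1, -2)) * scale, dim=-1) @ v
assert torch.allclose(merged, full, atol=1e-5)
```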
6 changes: 6 additions & 0 deletions csrc/attention/attention_dtypes.h
@@ -0,0 +1,6 @@
#pragma once

#include "attention_generic.cuh"
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
64 changes: 64 additions & 0 deletions csrc/attention/attention_generic.cuh
@@ -0,0 +1,64 @@
/*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
 * Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <stdint.h>

namespace vllm {

// A vector type to store Q, K, V elements.
template<typename T, int VEC_SIZE>
struct Vec {};

// A vector type to store FP32 accumulators.
template<typename T>
struct FloatVec {};

// Template vector operations.
template<typename Acc, typename A, typename B>
inline __device__ Acc mul(A a, B b);

template<typename T>
inline __device__ float sum(T v);

template<typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<T, T, T>(a, b));
}

template<typename A, typename T>
inline __device__ float dot(T a, T b) {
  return sum(mul<A, T, T>(a, b));
}

template<typename T>
inline __device__ void zero(T& dst) {
  constexpr int WORDS = sizeof(T) / 4;
  union {
    T raw;
    uint32_t words[WORDS];
  } tmp;

#pragma unroll
  for (int ii = 0; ii < WORDS; ++ii) {
    tmp.words[ii] = 0u;
  }
  dst = tmp.raw;
}

} // namespace vllm
