Commit d30e6f2

feat(deepseek_rope): add deepseek_scaling_rope (#34)
Signed-off-by: Double Young <[email protected]>
1 parent 2e9d9c3 commit d30e6f2

6 files changed: +401 −1 lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -198,6 +198,7 @@ if(ONEDNN_FOUND)
   file(GLOB _ONEDNN_SRC csrc/xpu/onednn/*.cpp)
   list(APPEND VLLM_EXT_XPU_SRC
     ${_ONEDNN_SRC}
+    "csrc/xpu/sycl/deepseek_scaling_rope.cpp"
   )
   include_directories(${ONEDNN_INCLUDE_DIR})
   link_libraries(${ONEDNN_LIBRARY})

csrc/xpu/ops.h

Lines changed: 6 additions & 1 deletion

@@ -5,4 +5,9 @@
 torch::Tensor fp8_gemm_w8a16(const torch::Tensor& A, const torch::Tensor& B,
                              bool trans_B,
                              const std::optional<torch::Tensor>& B_scale_,
-                             const std::optional<torch::Tensor>& bias_);
+                             const std::optional<torch::Tensor>& bias_);
+
+std::tuple<at::Tensor, at::Tensor> deepseek_scaling_rope(
+    const at::Tensor& positions, const at::Tensor& query, const at::Tensor& key,
+    const c10::optional<at::Tensor>& offsets_opt,
+    const at::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox);
csrc/xpu/sycl/deepseek_scaling_rope.cpp

Lines changed: 257 additions & 0 deletions (new file)

#include <sycl/sycl.hpp>
#include "utils.h"
#include "dispatch_utils.h"
#include <cmath>
#include <c10/macros/Macros.h>

namespace vllm {

template <typename T, int64_t rotary_dim, bool is_neox>
class deepseek_scaling_rope_kernel {
 public:
  static constexpr int sg_size = 16;
  deepseek_scaling_rope_kernel(
      const int64_t* positions, const T* query, const T* key,
      const int64_t* offsets, const T* cos_sin_cache, T* query_out, T* key_out,
      const int64_t batch, const int64_t q_num_head, const int64_t k_num_head,
      const int64_t head_size, const int64_t q_num_head_d,
      const int64_t q_batch_d, const int64_t k_num_head_d,
      const int64_t k_batch_d)
      : positions(positions),
        query(query),
        key(key),
        offsets(offsets),
        cos_sin_cache(cos_sin_cache),
        query_out(query_out),
        key_out(key_out),
        batch(batch),
        q_num_head(q_num_head),
        k_num_head(k_num_head),
        head_size(head_size),
        q_num_head_d(q_num_head_d),
        q_batch_d(q_batch_d),
        k_num_head_d(k_num_head_d),
        k_batch_d(k_batch_d) {}

  void rotary_embedding_kernel(const int64_t position, const T* pe,
                               const T* cos_sin_cache, T* res) const {
    constexpr int64_t half_rotary_dim = rotary_dim / 2;
    constexpr int64_t vec_2_len = 2;
    using v2_type = sycl::vec<T, vec_2_len>;
    const int64_t cache_idx = position * rotary_dim;
    const T* cos_cache_offset = &cos_sin_cache[cache_idx];
    const T* sin_cache_offset = cos_cache_offset + half_rotary_dim;
    if constexpr (is_neox) {
      // neox/half layout: element i is rotated with its partner i + half
      for (int64_t i = 0; i < half_rotary_dim; ++i) {
        int64_t j = i + half_rotary_dim;
        T cv = cos_cache_offset[i];
        T sv = sin_cache_offset[i];
        res[i] = pe[i] * cv - pe[j] * sv;
        res[j] = pe[j] * cv + pe[i] * sv;
      }
    } else {
      // interleaved layout: rotate adjacent pairs (2h, 2h+1) as 2-vectors;
      // unfortunately no prefetch in SYCL
      const v2_type* pe_2 = reinterpret_cast<const v2_type*>(pe);
      v2_type* res_2 = reinterpret_cast<v2_type*>(res);
      for (int64_t h = 0; h < half_rotary_dim; ++h) {
        T c = cos_cache_offset[h];
        T s = sin_cache_offset[h];
        v2_type c2 = {c, c};
        v2_type s2 = {s, s};
        v2_type t = pe_2[h];
        v2_type* dst = &res_2[h];
        v2_type tr = {-t[1], t[0]};
        *dst = t * c2 + tr * s2;
      }
    }
  }

  [[sycl::reqd_sub_group_size(sg_size)]] void operator()(
      sycl::nd_item<3> idx) const {
    // One work-item per head; heads are numbered q heads first, then k heads.
    int64_t batch_idx = idx.get_global_id(0);
    int64_t sg_idx = idx.get_local_id(1);
    int64_t local_id = idx.get_global_id(2);
    int64_t head_idx = sg_idx * sg_size + local_id;
    int64_t qo_idx = batch_idx * q_num_head * head_size + head_idx * head_size;
    int64_t ko_idx = batch_idx * k_num_head * head_size +
                     (head_idx - q_num_head) * head_size;
    int64_t qi_idx = batch_idx * q_batch_d + head_idx * q_num_head_d;
    int64_t ki_idx =
        batch_idx * k_batch_d + (head_idx - q_num_head) * k_num_head_d;
    if (head_idx < q_num_head) {
      rotary_embedding_kernel(positions[batch_idx], &query[qi_idx],
                              cos_sin_cache, &query_out[qo_idx]);
    } else if (head_idx < q_num_head + k_num_head) {
      rotary_embedding_kernel(positions[batch_idx], &key[ki_idx], cos_sin_cache,
                              &key_out[ko_idx]);
    }
  }

 private:
  const int64_t* positions;
  const T* query;
  const T* key;
  const int64_t* offsets;
  const T* cos_sin_cache;
  T* query_out;
  T* key_out;
  const int64_t batch;
  const int64_t q_num_head;
  const int64_t k_num_head;
  const int64_t head_size;
  const int64_t q_num_head_d;
  const int64_t q_batch_d;
  const int64_t k_num_head_d;
  const int64_t k_batch_d;
};

}  // namespace vllm

template <typename T>
void call_deepseek_scaling_rope(const int64_t* positions, const T* query,
                                const T* key, const int64_t* offsets,
                                const T* cos_sin_cache, T* query_out,
                                T* key_out, int64_t batch, int64_t q_num_head,
                                int64_t k_num_head, int64_t head_size,
                                int64_t rotary_dim, bool is_neox,
                                int64_t q_num_head_d, int64_t q_batch_d,
                                int64_t k_num_head_d, int64_t k_batch_d) {
  static constexpr std::array<int, 5> allowed_dims = {32, 64, 96, 128, 256};
  auto it = std::find(allowed_dims.begin(), allowed_dims.end(), rotary_dim);

  TORCH_CHECK(it != allowed_dims.end(), "Invalid rotary_dim (", rotary_dim,
              "). Supported: 32,64,96,128,256");
  TORCH_CHECK(rotary_dim == head_size, "rotary_dim (", rotary_dim,
              ") must equal head_size (", head_size, ")");

  const int rot_idx = std::distance(allowed_dims.begin(), it);
  const int neox_idx = is_neox ? 1 : 0;
  const int func_idx = neox_idx * allowed_dims.size() + rot_idx;

  using LaunchFn =
      void (*)(sycl::queue&, const int64_t*, const T*, const T*, const int64_t*,
               const T*, T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t,
               int64_t, int64_t, int64_t);

  // Table builder macro: one launch lambda per (rotary_dim, is_neox) pair
#define REGISTER_CASE(dim, neox)                                             \
  [](sycl::queue& q, const int64_t* pos, const T* q_in, const T* k_in,       \
     const int64_t* off, const T* cache, T* q_out, T* k_out, int64_t b,      \
     int64_t qh, int64_t kh, int64_t hs, int64_t qhd, int64_t qbd,           \
     int64_t khd, int64_t kbd) {                                             \
    constexpr int64_t sg_size = 16;                                          \
    int64_t sg_per_heads = (qh + kh + sg_size - 1) / sg_size;                \
    sycl::range<3> local(1, sg_per_heads, sg_size);                          \
    sycl::range<3> global(b, sg_per_heads, sg_size);                         \
    at::DeviceGuard dg(at::Device(at::kXPU, at::xpu::current_device()));     \
    q.submit([&](sycl::handler& cgh) {                                       \
      cgh.parallel_for(sycl::nd_range<3>(global, local),                     \
                       vllm::deepseek_scaling_rope_kernel<T, dim, neox>{     \
                           pos, q_in, k_in, off, cache, q_out, k_out, b, qh, \
                           kh, hs, qhd, qbd, khd, kbd});                     \
    });                                                                      \
  }

  static constexpr std::array<LaunchFn, allowed_dims.size() * 2> table = {
      REGISTER_CASE(32, false),  REGISTER_CASE(64, false),
      REGISTER_CASE(96, false),  REGISTER_CASE(128, false),
      REGISTER_CASE(256, false), REGISTER_CASE(32, true),
      REGISTER_CASE(64, true),   REGISTER_CASE(96, true),
      REGISTER_CASE(128, true),  REGISTER_CASE(256, true),
  };

  auto& queue = vllm::xpu::vllmGetQueue();
  table[func_idx](queue, positions, query, key, offsets, cos_sin_cache,
                  query_out, key_out, batch, q_num_head, k_num_head, head_size,
                  q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);

#undef REGISTER_CASE
}

/**
 * @brief Perform DeepSeek rotary embedding on query & key.
 * @param positions position indices used to index the cache [batch]
 * @param query query to be processed [batch, num_head, head_dim]
 * @param key key to be processed [batch, num_head, head_dim]
 * @param offsets optional per-batch offsets for positions
 * @param cos_sin_cache shared cache of cos/sin values, one row per position
 * @param is_neox true for neox-style (half) rotation, false for interleaved
 * @return A tuple of tensors (query_out, key_out).
 */
std::tuple<torch::Tensor, torch::Tensor> deepseek_scaling_rope(
    const torch::Tensor& positions, const torch::Tensor& query,
    const torch::Tensor& key, const c10::optional<torch::Tensor>& offsets_opt,
    const torch::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox) {
  auto query_out = at::empty_like(query);
  auto key_out = at::empty_like(key);

  auto q_shape = query.sizes();
  auto q_stride = query.strides();
  int64_t head_size = q_shape[2];
  int64_t q_num_head = q_shape[1];
  int64_t batch = q_shape[0];
  int64_t q_num_head_d = q_stride[1];
  int64_t q_batch_d = q_stride[0];
  auto k_shape = key.sizes();
  auto k_stride = key.strides();
  int64_t k_num_head = k_shape[1];
  int64_t k_num_head_d = k_stride[1];
  int64_t k_batch_d = k_stride[0];
  if (is_neox) {
    query_out = query_out.reshape({1, batch, q_num_head, head_size});
    key_out = key_out.reshape({1, batch, k_num_head, head_size});
  }
  TORCH_CHECK(cos_sin_cache.sizes()[1] == head_size,
              "Rotary dim doesn't match query head_size");
  TORCH_CHECK(cos_sin_cache.sizes()[1] == k_shape[2],
              "Rotary dim doesn't match key head_size");
  const c10::MaybeOwned<torch::Tensor> offsets_maybe_owned =
      at::borrow_from_optional_tensor(offsets_opt);
  const torch::Tensor& offsets = *offsets_maybe_owned;
  auto offsets_ptr = offsets.defined() ? offsets.data_ptr() : nullptr;
  switch (query.scalar_type()) {
    case torch::kFloat:
      call_deepseek_scaling_rope<float>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<float*>(query.data_ptr()),
          reinterpret_cast<float*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<float*>(cos_sin_cache.data_ptr()),
          reinterpret_cast<float*>(query_out.data_ptr()),
          reinterpret_cast<float*>(key_out.data_ptr()), batch, q_num_head,
          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
          k_num_head_d, k_batch_d);
      break;
    case torch::kFloat16:
      call_deepseek_scaling_rope<sycl::half>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<sycl::half*>(query.data_ptr()),
          reinterpret_cast<sycl::half*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<sycl::half*>(cos_sin_cache.data_ptr()),
          reinterpret_cast<sycl::half*>(query_out.data_ptr()),
          reinterpret_cast<sycl::half*>(key_out.data_ptr()), batch, q_num_head,
          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
          k_num_head_d, k_batch_d);
      break;
    case torch::kBFloat16:
      call_deepseek_scaling_rope<sycl::ext::oneapi::bfloat16>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(
              cos_sin_cache.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query_out.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key_out.data_ptr()),
          batch, q_num_head, k_num_head, head_size, rotary_dim, is_neox,
          q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);
      break;
    default:
      throw std::invalid_argument(
          "Invalid dtype, only supports float32, float16, and bfloat16");
      break;
  }
  return {query_out, key_out};
}
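For intuition, the per-head rotation performed by rotary_embedding_kernel above can be sketched in PyTorch. This reference is not part of the commit; it assumes each cos_sin_cache row holds rotary_dim/2 cos values followed by rotary_dim/2 sin values, matching the cos_cache_offset/sin_cache_offset split in the kernel:

import torch

def rotary_reference(pe: torch.Tensor, cache_row: torch.Tensor,
                     is_neox: bool) -> torch.Tensor:
    """Single-head reference for rotary_embedding_kernel.

    pe:        [rotary_dim] slice of one q or k head.
    cache_row: [rotary_dim] cache row for this position; first half cos,
               second half sin.
    """
    half = pe.shape[0] // 2
    cos, sin = cache_row[:half], cache_row[half:]
    out = torch.empty_like(pe)
    if is_neox:
        # neox/half layout: pe[i] pairs with pe[i + half]
        out[:half] = pe[:half] * cos - pe[half:] * sin
        out[half:] = pe[half:] * cos + pe[:half] * sin
    else:
        # interleaved layout: adjacent pairs (pe[2h], pe[2h+1]) rotate together
        x, y = pe[0::2], pe[1::2]
        out[0::2] = x * cos - y * sin
        out[1::2] = y * cos + x * sin
    return out

On the launch side, the REGISTER_CASE table bakes the five supported rotary_dim values and both layout styles into ten template instantiations, so the host code selects a fully specialized kernel by indexing with func_idx instead of branching at runtime.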

csrc/xpu/torch_bindings.cpp

Lines changed: 7 additions & 0 deletions

@@ -11,6 +11,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, xpu_ops) {
       "fp8_gemm_w8a16(Tensor! A, Tensor! B, bool trans_B, Tensor? B_scale_, "
       "Tensor? bias_) -> Tensor");
   xpu_ops.impl("fp8_gemm_w8a16", torch::kXPU, &fp8_gemm_w8a16);
+
+  xpu_ops.def(
+      "deepseek_scaling_rope(Tensor! positions, Tensor! query, Tensor! key, "
+      "Tensor? offsets_opt, Tensor! cos_sin_cache, int rotary_dim, bool "
+      "is_neox_style) "
+      "-> (Tensor, Tensor)");
+  xpu_ops.impl("deepseek_scaling_rope", torch::kXPU, &deepseek_scaling_rope);
 }

 REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
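Once the extension is built, the schema registered here becomes reachable through the torch.ops namespace; a minimal sanity check, assuming the module is importable as vllm_xpu_kernels._xpu_C as the test below suggests:

import torch
import vllm_xpu_kernels._xpu_C  # noqa: F401  # importing registers the ops

# Prints the op overload packet if registration succeeded.
print(torch.ops._xpu_C.deepseek_scaling_rope)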

tests/register_ops.py

Lines changed: 15 additions & 0 deletions

@@ -5,6 +5,7 @@
 from typing import Optional
 import vllm_xpu_kernels._C  # noqa: F401
 import vllm_xpu_kernels._moe_C  # noqa: F401
+import vllm_xpu_kernels._xpu_C  # noqa: F401


 # layer norm ops
@@ -61,6 +62,20 @@ def rotary_embedding(
                      cos_sin_cache, is_neox)


+def deepseek_scaling_rope(
+    positions: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    offsets_opt: Optional[torch.Tensor],
+    cos_sin_cache: Optional[torch.Tensor],
+    rotary_dim: int,
+    is_neox_style: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    return torch.ops._xpu_C.deepseek_scaling_rope(positions, query, key,
+                                                  offsets_opt, cos_sin_cache,
+                                                  rotary_dim, is_neox_style)
+
+
 def reshape_and_cache(
     key: torch.Tensor,
     value: torch.Tensor,
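An end-to-end sketch of calling the new op (not part of the commit): shapes, dtype, and the cache construction are illustrative, and an XPU device with a built vllm_xpu_kernels wheel is assumed. Note the host code requires rotary_dim == head_size with head_size in {32, 64, 96, 128, 256}:

import torch
import vllm_xpu_kernels._xpu_C  # noqa: F401

batch, q_heads, k_heads, head_size = 4, 8, 1, 64  # illustrative shapes
max_pos, device, dtype = 1024, "xpu", torch.float16

positions = torch.randint(0, max_pos, (batch,), device=device)  # int64
query = torch.randn(batch, q_heads, head_size, device=device, dtype=dtype)
key = torch.randn(batch, k_heads, head_size, device=device, dtype=dtype)

# Cache row layout assumed by the kernel: cos half followed by sin half.
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_size, 2) / head_size))
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1).to(device, dtype)

q_out, k_out = torch.ops._xpu_C.deepseek_scaling_rope(
    positions, query, key, None, cos_sin_cache, head_size, False)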
