#include <sycl/sycl.hpp>
#include "utils.h"
#include "dispatch_utils.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <stdexcept>
#include <tuple>
#include <c10/macros/Macros.h>

namespace vllm {

template <typename T, int64_t rotary_dim, bool is_neox>
class deepseek_scaling_rope_kernel {
 public:
  static constexpr int sg_size = 16;
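
  // The *_d constructor parameters are the batch / head strides of the query
  // and key tensors, in elements, taken from tensor.strides() on the host side.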
  deepseek_scaling_rope_kernel(
      const int64_t* positions, const T* query, const T* key,
      const int64_t* offsets, const T* cos_sin_cache, T* query_out, T* key_out,
      const int64_t batch, const int64_t q_num_head, const int64_t k_num_head,
      const int64_t head_size, const int64_t q_num_head_d,
      const int64_t q_batch_d, const int64_t k_num_head_d,
      const int64_t k_batch_d)
      : positions(positions),
        query(query),
        key(key),
        offsets(offsets),
        cos_sin_cache(cos_sin_cache),
        query_out(query_out),
        key_out(key_out),
        batch(batch),
        q_num_head(q_num_head),
        k_num_head(k_num_head),
        head_size(head_size),
        q_num_head_d(q_num_head_d),
        q_batch_d(q_batch_d),
        k_num_head_d(k_num_head_d),
        k_batch_d(k_batch_d) {}

  // Rotate one head of rotary_dim elements using the cached cos/sin values for
  // `position`. The cache is laid out as [position][rotary_dim], with the cos
  // values in the first half of each row and the sin values in the second half.
  void rotary_embedding_kernel(const int64_t position, const T* pe,
                               const T* cos_sin_cache, T* res) const {
    constexpr int64_t half_rotary_dim = rotary_dim / 2;
    constexpr int64_t vec_2_len = 2;
    using v2_type = sycl::vec<T, vec_2_len>;
    const int64_t cache_idx = position * rotary_dim;
    const T* cos_cache_offset = &cos_sin_cache[cache_idx];
    const T* sin_cache_offset = cos_cache_offset + half_rotary_dim;
    if constexpr (is_neox) {
      // NeoX (rotate-half) style: element i is paired with element
      // i + half_rotary_dim.
      for (int64_t i = 0; i < half_rotary_dim; ++i) {
        int64_t j = i + half_rotary_dim;
        T cv = cos_cache_offset[i];
        T sv = sin_cache_offset[i];
        res[i] = pe[i] * cv - pe[j] * sv;
        res[j] = pe[j] * cv + pe[i] * sv;
      }
    } else {
      // Interleaved (GPT-J) style: adjacent elements (2h, 2h + 1) are rotated
      // together as a 2-vector. No software prefetch is issued here, as SYCL
      // offers none for this pattern.
      const v2_type* pe_2 = reinterpret_cast<const v2_type*>(pe);
      v2_type* res_2 = reinterpret_cast<v2_type*>(res);
      for (int64_t h = 0; h < half_rotary_dim; ++h) {
        T c = cos_cache_offset[h];
        T s = sin_cache_offset[h];
        v2_type c2 = {c, c};
        v2_type s2 = {s, s};
        v2_type t = pe_2[h];
        v2_type* dst = &res_2[h];
        v2_type tr = {-t[1], t[0]};
        *dst = t * c2 + tr * s2;
      }
    }
  }
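
  // Illustration (comment only, not part of the kernel). Assuming
  // rotary_dim = 4, an input x = [x0, x1, x2, x3] and cached values
  // cos = [c0, c1], sin = [s0, s1], the two conventions above compute:
  //   NeoX / rotate-half:  res = [x0*c0 - x2*s0, x1*c1 - x3*s1,
  //                               x2*c0 + x0*s0, x3*c1 + x1*s1]
  //   Interleaved (GPT-J): res = [x0*c0 - x1*s0, x1*c0 + x0*s0,
  //                               x2*c1 - x3*s1, x3*c1 + x2*s1]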

  [[sycl::reqd_sub_group_size(sg_size)]] void operator()(
      sycl::nd_item<3> idx) const {
    // One work-item per head: dimension 0 walks the batch, while dimensions 1
    // and 2 together enumerate the q heads followed by the k heads.
    int64_t batch_idx = idx.get_global_id(0);
    int64_t sg_idx = idx.get_local_id(1);
    int64_t local_id = idx.get_global_id(2);
    int64_t head_idx = sg_idx * sg_size + local_id;
    int64_t qo_idx = batch_idx * q_num_head * head_size + head_idx * head_size;
    int64_t ko_idx = batch_idx * k_num_head * head_size +
                     (head_idx - q_num_head) * head_size;
    int64_t qi_idx = batch_idx * q_batch_d + head_idx * q_num_head_d;
    int64_t ki_idx =
        batch_idx * k_batch_d + (head_idx - q_num_head) * k_num_head_d;
    if (head_idx < q_num_head) {
      rotary_embedding_kernel(positions[batch_idx], &query[qi_idx],
                              cos_sin_cache, &query_out[qo_idx]);
    } else if (head_idx < q_num_head + k_num_head) {
      rotary_embedding_kernel(positions[batch_idx], &key[ki_idx], cos_sin_cache,
                              &key_out[ko_idx]);
    }
    // Work-items beyond q_num_head + k_num_head (padding from the sub-group
    // rounding) fall through and do nothing.
  }

 private:
  const int64_t* positions;
  const T* query;
  const T* key;
  const int64_t* offsets;  // currently not read by the kernel
  const T* cos_sin_cache;
  T* query_out;
  T* key_out;
  const int64_t batch;
  const int64_t q_num_head;
  const int64_t k_num_head;
  const int64_t head_size;
  const int64_t q_num_head_d;
  const int64_t q_batch_d;
  const int64_t k_num_head_d;
  const int64_t k_batch_d;
};

} // namespace vllm

template <typename T>
void call_deepseek_scaling_rope(const int64_t* positions, const T* query,
                                const T* key, const int64_t* offsets,
                                const T* cos_sin_cache, T* query_out,
                                T* key_out, int64_t batch, int64_t q_num_head,
                                int64_t k_num_head, int64_t head_size,
                                int64_t rotary_dim, bool is_neox,
                                int64_t q_num_head_d, int64_t q_batch_d,
                                int64_t k_num_head_d, int64_t k_batch_d) {
  static constexpr std::array<int, 5> allowed_dims = {32, 64, 96, 128, 256};
  auto it = std::find(allowed_dims.begin(), allowed_dims.end(), rotary_dim);

  TORCH_CHECK(it != allowed_dims.end(), "Invalid rotary_dim (", rotary_dim,
              "). Supported: 32, 64, 96, 128, 256");
  TORCH_CHECK(rotary_dim == head_size, "rotary_dim (", rotary_dim,
              ") must equal head_size (", head_size, ")");

  const int rot_idx = std::distance(allowed_dims.begin(), it);
  const int neox_idx = is_neox ? 1 : 0;
  const int func_idx = neox_idx * allowed_dims.size() + rot_idx;
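  // Example: rotary_dim = 64 with is_neox = true gives
  // func_idx = 1 * 5 + 1 = 6, i.e. the seventh entry of `table` below.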

  using LaunchFn =
      void (*)(sycl::queue&, const int64_t*, const T*, const T*, const int64_t*,
               const T*, T*, T*, int64_t, int64_t, int64_t, int64_t, int64_t,
               int64_t, int64_t, int64_t);

// Table builder macro: each entry launches the kernel specialized for one
// (rotary_dim, is_neox) combination.
#define REGISTER_CASE(dim, neox)                                             \
  [](sycl::queue& q, const int64_t* pos, const T* q_in, const T* k_in,       \
     const int64_t* off, const T* cache, T* q_out, T* k_out, int64_t b,      \
     int64_t qh, int64_t kh, int64_t hs, int64_t qhd, int64_t qbd,           \
     int64_t khd, int64_t kbd) {                                             \
    constexpr int64_t sg_size = 16;                                          \
    int64_t sg_per_heads = (qh + kh + sg_size - 1) / sg_size;                \
    sycl::range<3> local(1, sg_per_heads, sg_size);                          \
    sycl::range<3> global(b, sg_per_heads, sg_size);                         \
    at::DeviceGuard dg(at::Device(at::kXPU, at::xpu::current_device()));     \
    q.submit([&](sycl::handler& cgh) {                                       \
      cgh.parallel_for(sycl::nd_range<3>(global, local),                     \
                       vllm::deepseek_scaling_rope_kernel<T, dim, neox>{     \
                           pos, q_in, k_in, off, cache, q_out, k_out, b, qh, \
                           kh, hs, qhd, qbd, khd, kbd});                     \
    });                                                                      \
  }

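  // The entries below must stay in the same order as allowed_dims, with all
  // non-NeoX cases first, so that func_idx selects the matching launcher.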
  static constexpr std::array<LaunchFn, allowed_dims.size() * 2> table = {
      REGISTER_CASE(32, false),  REGISTER_CASE(64, false),
      REGISTER_CASE(96, false),  REGISTER_CASE(128, false),
      REGISTER_CASE(256, false), REGISTER_CASE(32, true),
      REGISTER_CASE(64, true),   REGISTER_CASE(96, true),
      REGISTER_CASE(128, true),  REGISTER_CASE(256, true),
  };

  auto& queue = vllm::xpu::vllmGetQueue();
  table[func_idx](queue, positions, query, key, offsets, cos_sin_cache,
                  query_out, key_out, batch, q_num_head, k_num_head, head_size,
                  q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);

#undef REGISTER_CASE
}

/**
 * @brief Apply DeepSeek rotary position embedding to query and key.
 * @param positions position index for each batch element [batch]
 * @param query query tensor to be rotated [batch, num_head, head_size]
 * @param key key tensor to be rotated [batch, num_head, head_size]
 * @param offsets optional per-batch offsets for the positions (not applied by
 *        the current kernel)
 * @param cos_sin_cache cached cos/sin values laid out as
 *        [max_position, rotary_dim], cos in the first half of each row and
 *        sin in the second half
 * @param rotary_dim number of dimensions to rotate; must equal head_size
 * @param is_neox true selects the NeoX (rotate-half) convention, false the
 *        interleaved (GPT-J) convention
 * @return A tuple of tensors (query_out, key_out).
 */
std::tuple<torch::Tensor, torch::Tensor> deepseek_scaling_rope(
    const torch::Tensor& positions, const torch::Tensor& query,
    const torch::Tensor& key, const c10::optional<torch::Tensor>& offsets_opt,
    const torch::Tensor& cos_sin_cache, int64_t rotary_dim, bool is_neox) {
  auto query_out = at::empty_like(query);
  auto key_out = at::empty_like(key);

  auto q_shape = query.sizes();
  auto q_stride = query.strides();
  int64_t head_size = q_shape[2];
  int64_t q_num_head = q_shape[1];
  int64_t batch = q_shape[0];
  int64_t q_num_head_d = q_stride[1];
  int64_t q_batch_d = q_stride[0];
  auto k_shape = key.sizes();
  auto k_stride = key.strides();
  int64_t k_num_head = k_shape[1];
  int64_t k_num_head_d = k_stride[1];
  int64_t k_batch_d = k_stride[0];
  if (is_neox) {
    query_out = query_out.reshape({1, batch, q_num_head, head_size});
    key_out = key_out.reshape({1, batch, k_num_head, head_size});
  }
  TORCH_CHECK(cos_sin_cache.sizes()[1] == head_size,
              "Rotary dim doesn't match query head_size");
  TORCH_CHECK(cos_sin_cache.sizes()[1] == k_shape[2],
              "Rotary dim doesn't match key head_size");
  const c10::MaybeOwned<torch::Tensor> offsets_maybe_owned =
      at::borrow_from_optional_tensor(offsets_opt);
  const torch::Tensor& offsets = *offsets_maybe_owned;
  auto offsets_ptr = offsets.defined() ? offsets.data_ptr() : nullptr;
  switch (query.scalar_type()) {
    case torch::kFloat:
      call_deepseek_scaling_rope<float>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<float*>(query.data_ptr()),
          reinterpret_cast<float*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<float*>(cos_sin_cache.data_ptr()),
          reinterpret_cast<float*>(query_out.data_ptr()),
          reinterpret_cast<float*>(key_out.data_ptr()), batch, q_num_head,
          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
          k_num_head_d, k_batch_d);
      break;
    case torch::kFloat16:
      call_deepseek_scaling_rope<sycl::half>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<sycl::half*>(query.data_ptr()),
          reinterpret_cast<sycl::half*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<sycl::half*>(cos_sin_cache.data_ptr()),
          reinterpret_cast<sycl::half*>(query_out.data_ptr()),
          reinterpret_cast<sycl::half*>(key_out.data_ptr()), batch, q_num_head,
          k_num_head, head_size, rotary_dim, is_neox, q_num_head_d, q_batch_d,
          k_num_head_d, k_batch_d);
      break;
    case torch::kBFloat16:
      call_deepseek_scaling_rope<sycl::ext::oneapi::bfloat16>(
          reinterpret_cast<int64_t*>(positions.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key.data_ptr()),
          reinterpret_cast<int64_t*>(offsets_ptr),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(
              cos_sin_cache.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(query_out.data_ptr()),
          reinterpret_cast<sycl::ext::oneapi::bfloat16*>(key_out.data_ptr()),
          batch, q_num_head, k_num_head, head_size, rotary_dim, is_neox,
          q_num_head_d, q_batch_d, k_num_head_d, k_batch_d);
      break;
    default:
      throw std::invalid_argument(
          "Invalid dtype, only supports float32, float16, and bfloat16");
      break;
  }
  return {query_out, key_out};
}
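
// Usage sketch (illustrative only; the tensor names and setup below are
// assumptions, not taken from this file). Assuming contiguous XPU tensors
// `q` and `k` of shape [batch, num_heads, head_size] and a cache of shape
// [max_position, head_size] with cos in the first half of each row and sin in
// the second half:
//
//   auto positions = torch::arange(
//       batch, torch::dtype(torch::kInt64).device(at::kXPU));
//   auto [q_rot, k_rot] = deepseek_scaling_rope(
//       positions, q, k, /*offsets_opt=*/c10::nullopt, cos_sin_cache,
//       /*rotary_dim=*/head_size, /*is_neox=*/true);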