Commit 2e9d9c3

grouped topk from IPEX (#33)
* add total fused grouped topk
  Signed-off-by: mayuyuace <[email protected]>
* fix warnings
  Signed-off-by: mayuyuace <[email protected]>
* format
  Signed-off-by: mayuyuace <[email protected]>

---------

Signed-off-by: mayuyuace <[email protected]>
1 parent 5d12f87 commit 2e9d9c3
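For context, fused_grouped_topk collapses the multi-step, group-limited expert selection used by MoE routers into a single kernel. Below is a minimal, unfused PyTorch sketch of that selection, written against the parameter names that appear in the benchmark diff further down. It illustrates the general algorithm only; it is not the exact semantics of the new csrc/moe/fused_grouped_topk.cpp kernel (in particular, how e_score_correction_bias interacts with the returned weights is an assumption here).

    import torch

    def grouped_topk_reference(hidden_states, gating_output, topk, renormalize,
                               num_expert_group, topk_group,
                               scoring_func="softmax",
                               routed_scaling_factor=1.0,
                               e_score_correction_bias=None):
        # hidden_states is accepted only to mirror the fused op's signature;
        # routing itself needs only the router logits (gating_output).
        num_tokens, num_experts = gating_output.shape

        # 1. Turn router logits into scores.
        if scoring_func == "softmax":
            scores = gating_output.softmax(dim=-1)
        else:  # "sigmoid"
            scores = gating_output.sigmoid()
        if e_score_correction_bias is not None:
            scores = scores + e_score_correction_bias

        # 2. Score each expert group by its best expert and keep the top
        #    `topk_group` groups per token.
        group_scores = scores.view(num_tokens, num_expert_group, -1).max(dim=-1).values
        group_idx = group_scores.topk(topk_group, dim=-1).indices
        group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
        score_mask = group_mask.unsqueeze(-1).expand(
            num_tokens, num_expert_group,
            num_experts // num_expert_group).reshape(num_tokens, num_experts)

        # 3. Top-k experts within the surviving groups only.
        masked_scores = scores.masked_fill(score_mask == 0, float("-inf"))
        topk_weights, topk_ids = masked_scores.topk(topk, dim=-1)

        # 4. Optional renormalization and scaling of the routing weights.
        if renormalize:
            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
        return topk_weights * routed_scaling_factor, topk_ids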

File tree: 13 files changed, +552 −29 lines

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ define_gpu_extension_target(
 set(VLLM_MOE_EXT_SRC
     "csrc/moe/torch_bindings.cpp"
     "csrc/moe/grouped_topk.cpp"
+    "csrc/moe/fused_grouped_topk.cpp"
     "csrc/moe/moe_align_sum_kernels.cpp")
 
 message(STATUS "Enabling moe extension.")

benchmark/benchmark_grouped_topk.py

Lines changed: 19 additions & 4 deletions
@@ -8,7 +8,8 @@
 import torch
 import triton
 
-from tests.ops.grouped_topk_op import fused_grouped_topk, grouped_topk
+from tests.ops.grouped_topk_op import (fused_grouped_topk,
+                                       fused_grouped_topk_sycl, grouped_topk)
 
 
 @torch.compile
@@ -95,8 +96,8 @@ def get_benchmark():
             ],
             x_vals=[tuple(_) for _ in configs],
             line_arg="provider",
-            line_vals=["vllm", "native", "compile"],
-            line_names=["vllm", "native", "compile"],
+            line_vals=["vllm", "native", "compile", "sycl"],
+            line_names=["vllm", "native", "compile", "sycl"],
             styles=[("blue", "-"), ("green", "-"), ("orange", "-"),
                     ("red", "-")],
             ylabel="us",
@@ -156,7 +157,7 @@ def benchmark(
                 e_score_correction_bias=e_score_correction_bias),
             quantiles=quantiles,
         )
-    else:
+    elif provider == "compile":
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: grouped_topk_compile(
                 hidden_states=hidden_states,
@@ -170,6 +171,20 @@ def benchmark(
                 e_score_correction_bias=e_score_correction_bias),
             quantiles=quantiles,
         )
+    elif provider == "sycl":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: fused_grouped_topk_sycl(
+                hidden_states=hidden_states,
+                gating_output=gating_output,
+                topk=topk,
+                renormalize=renormalize,
+                num_expert_group=num_expert_group,
+                topk_group=topk_group,
+                scoring_func=scoring_func,
+                routed_scaling_factor=routed_scaling_factor,
+                e_score_correction_bias=e_score_correction_bias),
+            quantiles=quantiles,
+        )
 
     return 1000 * ms, 1000 * max_ms, 1000 * min_ms
 
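A quick way to validate the new provider before benchmarking it is to compare its output against the existing reference path. The sketch below reuses the keyword arguments from the benchmark above, but its shapes, dtypes, the "xpu" device string, and the assumption that both ops return a (topk_weights, topk_ids) pair are all hypothetical; the actual helpers in tests/ops/grouped_topk_op.py may differ.

    import torch
    from tests.ops.grouped_topk_op import fused_grouped_topk_sycl, grouped_topk

    def check_sycl_parity(num_tokens=64, num_experts=256, hidden_size=7168,
                          topk=8, num_expert_group=8, topk_group=4):
        # Assumed shapes, dtypes, and device; adjust to the benchmark configs.
        hidden_states = torch.randn(num_tokens, hidden_size,
                                    device="xpu", dtype=torch.bfloat16)
        gating_output = torch.randn(num_tokens, num_experts,
                                    device="xpu", dtype=torch.float32)
        kwargs = dict(hidden_states=hidden_states, gating_output=gating_output,
                      topk=topk, renormalize=True,
                      num_expert_group=num_expert_group, topk_group=topk_group,
                      scoring_func="sigmoid", routed_scaling_factor=1.0,
                      e_score_correction_bias=None)
        ref_weights, ref_ids = grouped_topk(**kwargs)                # reference path
        sycl_weights, sycl_ids = fused_grouped_topk_sycl(**kwargs)   # new SYCL kernel

        # Align both results by expert id before comparing, since the two
        # implementations may emit the top-k entries in different orders.
        ref_order = ref_ids.argsort(dim=-1)
        sycl_order = sycl_ids.argsort(dim=-1)
        assert torch.equal(ref_ids.gather(-1, ref_order),
                           sycl_ids.gather(-1, sycl_order))
        torch.testing.assert_close(sycl_weights.gather(-1, sycl_order).float(),
                                   ref_weights.gather(-1, ref_order).float(),
                                   atol=2e-2, rtol=2e-2)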

csrc/activation.cpp

Lines changed: 2 additions & 2 deletions
@@ -73,7 +73,7 @@ class act_kernel {
               const int d)
       : out_(out), input_(input), d_(d) {}
 
-  void operator() [[intel::reqd_sub_group_size(32)]] (
+  void operator() [[sycl::reqd_sub_group_size(32)]] (
       const sycl::nd_item<3>& item_ct1) const {
     const int64_t token_idx = item_ct1.get_group(2);
     for (int64_t idx = item_ct1.get_local_id(2); idx < d_;
@@ -98,7 +98,7 @@ class act_and_mul_kernel {
               const int d)
       : out_(out), input_(input), d_(d) {}
 
-  void operator() [[intel::reqd_sub_group_size(32)]] (
+  void operator() [[sycl::reqd_sub_group_size(32)]] (
      const sycl::nd_item<3>& item_ct1) const {
     const int64_t token_idx = item_ct1.get_group(2);
     for (int64_t idx = item_ct1.get_local_id(2); idx < d_;

csrc/layernorm.cpp

Lines changed: 8 additions & 6 deletions
@@ -23,9 +23,10 @@ class rms_norm_kernel {
         hidden_size(hidden_size_),
         s_variance(s_variance_) {}
 
-  void operator() [[intel::reqd_sub_group_size(32)]] (
+  void operator() [[sycl::reqd_sub_group_size(32)]] (
       const sycl::nd_item<3>& item_ct1) const {
-    float* s_variance_ptr = s_variance.get_pointer();
+    float* s_variance_ptr =
+        s_variance.template get_multi_ptr<sycl::access::decorated::no>().get();
     float variance = 0.0f;
 
     for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
@@ -65,7 +66,7 @@ class rms_norm_kernel {
 template <typename scalar_t>
 void call_rms_norm_kernel(torch::Tensor& out, torch::Tensor& input,
                           torch::Tensor& weight, float epsilon) {
-  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;
+  using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
   int64_t input_stride = input.stride(-2);
@@ -104,9 +105,10 @@ class fused_add_rms_norm_kernel {
         hidden_size(hidden_size_),
         s_variance(s_variance_) {}
 
-  void operator() [[intel::reqd_sub_group_size(32)]] (
+  void operator() [[sycl::reqd_sub_group_size(32)]] (
       const sycl::nd_item<3>& item_ct1) const {
-    float* s_variance_ptr = s_variance.get_pointer();
+    float* s_variance_ptr =
+        s_variance.template get_multi_ptr<sycl::access::decorated::no>().get();
     float variance = 0.0f;
 
     for (int idx = item_ct1.get_local_id(2); idx < hidden_size;
@@ -150,7 +152,7 @@ template <typename scalar_t>
 void call_fused_add_rms_norm_kernel(torch::Tensor& input,
                                     torch::Tensor& residual,
                                     torch::Tensor& weight, float epsilon) {
-  using sycl_t = vllm::xpu::SyclTypeTrait<scalar_t>::Type;
+  using sycl_t = typename vllm::xpu::SyclTypeTrait<scalar_t>::Type;
   int hidden_size = input.size(-1);
   int num_tokens = input.numel() / hidden_size;
   auto input_ptr = input.data_ptr<scalar_t>();
