
Commit e644e26

facebook-github-bot authored and committed
fix fbgemm build issues after upgrading CK (#4517)
Summary:
Pull Request resolved: #4517

X-link: facebookresearch/FBGEMM#1565

Some files are copied over from CK, so they need to be updated after the CK upgrade.

Reviewed By: q10

Differential Revision: D78455742
1 parent fa50579 commit e644e26

File tree

5 files changed: +285 −113 lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp

Lines changed: 3 additions & 0 deletions
@@ -16,6 +16,7 @@ struct fused_moe_args {
     const void*
         y_smooth_scale_ptr; // [e, 1, n], smooth-quant-scale for 2nd gemm input
     const void* local_expert_mask_ptr; // [e], local_expert_mask_ptr for EP
+    const void* local_tokens; // [1] if not nullptr, tokens read from here
     void* o_ptr;  // [m, k], output token (no need to do zeroing)
     void* ws_ptr; // size is moe_sorting_get_workspace_size()
                   // if return zero, then could be nullptr
@@ -58,6 +59,8 @@ struct fused_moe_traits {
     bool local_expert_masking; // if mask experts as local expert
 };

+// if return zero, no ws needed
+int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
 float fused_moe(
     fused_moe_traits,
     fused_moe_args,
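
For orientation, here is a minimal caller-side sketch of the new workspace helper. This is a sketch only: the function name setup_workspace and the variable names are illustrative, and it assumes the returned size counts elements of the topk_ids dtype, which is how the kernel caller in fused_moe_kernel.hip below allocates it.

#include <ATen/ATen.h>
#include "fused_moe.hpp"

// Sketch: allocate the optional sorting workspace for one launch.
// Returns the tensor so the buffer outlives the kernel; args.ws_ptr may
// legitimately stay nullptr ("if return zero, no ws needed").
at::Tensor setup_workspace(const at::Tensor& topk_ids,
                           int num_tokens, int num_experts, int topk,
                           fused_moe_args& args) {
    at::Tensor ws;
    args.ws_ptr = nullptr;
    const int workspace_size =
        fused_moe_get_workspace_size(num_tokens, num_experts, topk);
    if (workspace_size > 0) {
        // Same recipe as fused_moe_kernel.hip: workspace_size elements of
        // the topk_ids dtype, zero-initialized, on the same device.
        ws = at::zeros({workspace_size},
                       at::TensorOptions()
                           .dtype(topk_ids.dtype())
                           .device(device_of(topk_ids)));
        args.ws_ptr = ws.data_ptr();
    }
    return ws;
}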

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip

Lines changed: 4 additions & 3 deletions
@@ -77,14 +77,14 @@ at::Tensor fused_moe_impl(
   auto prec_o = get_prec_str(output);
   auto prec_tkw = get_prec_str(topk_weights);

-  int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
+  int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk, 0 /*dispatch policy*/);
   void *ws_ptr = nullptr;
   if (workspace_size > 0)
   {
     auto ws = at::zeros({workspace_size}, at::TensorOptions().dtype(topk_ids.dtype()).device(device_of(topk_ids)));
     ws_ptr = ws.data_ptr();
   }
-
+

   // Set up traits structure
   fused_moe_traits traits{
@@ -109,7 +109,8 @@ at::Tensor fused_moe_impl(
       gate_up_scales.has_value() ? gate_up_scales->data_ptr() : nullptr,
       down_scales.has_value() ? down_scales->data_ptr() : nullptr,
       smooth_scales.has_value() ? smooth_scales->data_ptr() : nullptr, // expert_mask
-      nullptr,
+      nullptr, // local_expert_mask_ptr
+      nullptr, // local_tokens
       output.data_ptr(),
       ws_ptr,
       topk_ids.data_ptr(),
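
Why the extra nullptr in this hunk: the argument struct is brace-initialized positionally, so the new local_tokens member (declared right after local_expert_mask_ptr) needs an explicit placeholder at every call site, or every later argument silently shifts by one slot. As an aside, a hedged sketch of a C++20 alternative that avoids this class of breakage (not what this commit does; field names taken from fused_moe.hpp, earlier members elided for brevity):

// Sketch (C++20): designated initializers must follow declaration order,
// and omitted members are value-initialized (nullptr for pointers), so a
// newly added member cannot silently misalign the rest of the list.
fused_moe_args args{
    // ... earlier input/scale pointer members would be designated here ...
    .local_expert_mask_ptr = nullptr,
    .local_tokens          = nullptr, // member added in this commit
    .o_ptr                 = output.data_ptr(),
    .ws_ptr                = ws_ptr,
    // ... remaining members ...
};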

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp

Lines changed: 1 addition & 6 deletions
@@ -15,13 +15,8 @@ struct fused_moesorting_trait {

 struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs {};

+int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
 float fused_moesorting(
     fused_moesorting_trait t,
     fused_moesorting_args a,
     ck_tile::stream_config s);
-
-int moe_sorting_get_workspace_size(int tokens, int num_experts);
-float moe_sorting_mp(
-    fused_moesorting_trait t,
-    fused_moesorting_args a,
-    ck_tile::stream_config s);

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip

Lines changed: 28 additions & 15 deletions
@@ -3,6 +3,12 @@

 #include "fused_moe.hpp"

+int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
+{
+    return ck_tile::moe_sorting_get_workspace_size(
+        tokens, num_experts, topk, 0 /*dispatch policy*/);
+}
+
 float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
 {
     auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
@@ -18,21 +24,28 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
     }();

     auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
-    auto a0 = fused_moesorting_args{
-        a.topk_ids_ptr,          // const void* p_topk_ids;
-        a.topk_weight_ptr,       // const void* p_weights;
-        a.local_expert_mask_ptr, // const void* p_local_expert_mask;
-        a.sorted_token_ids_ptr,  // void* p_sorted_token_ids;
-        a.sorted_weight_ptr,     // void* p_sorted_weights;
-        a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
-        a.num_sorted_tiles_ptr,  // void* p_total_tokens_post_pad;
-        a.o_ptr,                 // void* p_moe_buf;
-        a.ws_ptr,                // moe_sorting_ws
-        a.num_tokens,            // index_t tokens;
-        a.block_m,               // index_t unit_size;
-        a.num_experts,           // index_t num_experts;
-        a.topk,                  // index_t topk;
-        a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
+    auto a0 = fused_moesorting_args
+    {
+        a.topk_ids_ptr,          // const void* p_topk_ids;
+        a.topk_weight_ptr,       // const void* p_weights;
+        a.local_expert_mask_ptr, // const void* p_local_expert_mask;
+        a.local_tokens,
+        a.sorted_token_ids_ptr,  // void* p_sorted_token_ids;
+        a.sorted_weight_ptr,     // void* p_sorted_weights;
+        a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
+        a.num_sorted_tiles_ptr,  // void* p_total_tokens_post_pad;
+        a.o_ptr,                 // void* p_moe_buf;
+        a.ws_ptr,                // void* p_ws;
+        a.num_tokens,            // index_t tokens;
+        a.block_m,               // index_t unit_size;
+        a.num_experts,           // index_t num_experts;
+        a.topk,                  // index_t topk;
+#if MOE_SORTING_FMOE_2D_BUF
+        a.stride_token, o_data_bytes,
+#else
+        static_cast<ck_tile::long_index_t>(a.num_tokens) *
+            a.stride_token * o_data_bytes // index_t moe_buf_bytes;
+#endif
     };

     auto t1 = fused_moegemm_traits{t.prec_i,
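
One detail in the fallback branch above: the new static_cast<ck_tile::long_index_t> widens the moe_buf_bytes product before it can overflow. All three factors are int-sized, and for large batches their product exceeds 32-bit range. A minimal sketch of the hazard, with illustrative numbers; long long stands in for ck_tile::long_index_t, which is assumed here to be a 64-bit type:

#include <cstdio>

int main() {
    int num_tokens   = 70000; // illustrative
    int stride_token = 8192;  // elements between consecutive token rows
    int o_data_bytes = 4;     // e.g. fp32 output
    // 70000 * 8192 * 4 = 2,293,760,000 > INT32_MAX (2,147,483,647),
    // so the all-int product overflows (undefined behavior for signed int).
    long long moe_buf_bytes =
        static_cast<long long>(num_tokens) * stride_token * o_data_bytes;
    std::printf("%lld\n", moe_buf_bytes); // prints 2293760000
    return 0;
}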
