Skip to content

Commit 5f959df

Browse files
842974287facebook-github-bot
authored andcommitted
fix fbgemm build issues after upgrading CK (#4517)
Summary: Pull Request resolved: #4517 X-link: facebookresearch/FBGEMM#1565 Some files are copied over from CK so need to update them. Differential Revision: D78455742
1 parent 71a0e90 commit 5f959df

File tree

5 files changed

+194
-87
lines changed

5 files changed

+194
-87
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ struct fused_moe_traits {
5858
bool local_expert_masking; // if mask experts as local expert
5959
};
6060

61+
// if return zero, no ws needed
62+
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
6163
float fused_moe(
6264
fused_moe_traits,
6365
fused_moe_args,

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moe_kernel.hip

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,14 @@ at::Tensor fused_moe_impl(
7777
auto prec_o = get_prec_str(output);
7878
auto prec_tkw = get_prec_str(topk_weights);
7979

80-
int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts);
80+
int workspace_size = ck_tile::moe_sorting_get_workspace_size(tokens, experts, topk);
8181
void *ws_ptr = nullptr;
8282
if (workspace_size > 0)
8383
{
8484
auto ws = at::zeros({workspace_size}, at::TensorOptions().dtype(topk_ids.dtype()).device(device_of(topk_ids)));
8585
ws_ptr = ws.data_ptr();
8686
}
87-
87+
8888

8989
// Set up traits structure
9090
fused_moe_traits traits{

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/fused_moesorting.hpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,8 @@ struct fused_moesorting_trait {
1515

1616
struct fused_moesorting_args : public ck_tile::MoeSortingHostArgs {};
1717

18+
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk);
1819
float fused_moesorting(
1920
fused_moesorting_trait t,
2021
fused_moesorting_args a,
2122
ck_tile::stream_config s);
22-
23-
int moe_sorting_get_workspace_size(int tokens, int num_experts);
24-
float moe_sorting_mp(
25-
fused_moesorting_trait t,
26-
fused_moesorting_args a,
27-
ck_tile::stream_config s);

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fused_moe/instances/fused_moe_api.hip

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@
33

44
#include "fused_moe.hpp"
55

6+
int fused_moe_get_workspace_size(int tokens, int num_experts, int topk)
7+
{
8+
return ck_tile::moe_sorting_get_workspace_size(tokens, num_experts, topk);
9+
}
10+
611
float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_config& s)
712
{
813
auto s_sub = ck_tile::stream_config{s.stream_id_, false, s.log_level_, 0, 1};
@@ -19,20 +24,21 @@ float fused_moe(fused_moe_traits t, fused_moe_args a, const ck_tile::stream_conf
1924

2025
auto t0 = fused_moesorting_trait{"int32", "fp32", t.local_expert_masking};
2126
auto a0 = fused_moesorting_args{
22-
a.topk_ids_ptr, // const void* p_topk_ids;
23-
a.topk_weight_ptr, // const void* p_weights;
24-
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
25-
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
26-
a.sorted_weight_ptr, // void* p_sorted_weights;
27-
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
28-
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
29-
a.o_ptr, // void* p_moe_buf;
30-
a.ws_ptr, // moe_sorting_ws
31-
a.num_tokens, // index_t tokens;
32-
a.block_m, // index_t unit_size;
33-
a.num_experts, // index_t num_experts;
34-
a.topk, // index_t topk;
35-
a.num_tokens * a.stride_token * o_data_bytes // index_t moe_buf_bytes;
27+
a.topk_ids_ptr, // const void* p_topk_ids;
28+
a.topk_weight_ptr, // const void* p_weights;
29+
a.local_expert_mask_ptr, // const void* p_local_expert_mask;
30+
a.sorted_token_ids_ptr, // void* p_sorted_token_ids;
31+
a.sorted_weight_ptr, // void* p_sorted_weights;
32+
a.sorted_expert_ids_ptr, // void* p_sorted_expert_ids;
33+
a.num_sorted_tiles_ptr, // void* p_total_tokens_post_pad;
34+
a.o_ptr, // void* p_moe_buf;
35+
a.ws_ptr, // void* p_ws;
36+
a.num_tokens, // index_t tokens;
37+
a.block_m, // index_t unit_size;
38+
a.num_experts, // index_t num_experts;
39+
a.topk, // index_t topk;
40+
static_cast<ck_tile::long_index_t>(a.num_tokens) * a.stride_token *
41+
o_data_bytes // index_t moe_buf_bytes;
3642
};
3743

3844
auto t1 = fused_moegemm_traits{t.prec_i,

0 commit comments

Comments
 (0)