
Commit cfe8683

JaxChen29 authored and meta-codesync[bot] committed
embedding forward optimization for MI350 (#5064)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/2095

Optimization of the embedding forward pass for MI350:
1. Use vec4 accesses in the embedding VBE forward kernel instead of vec2.
2. Since a ROCm wavefront has 64 threads, optimize the subwarp path in the embedding forward v2 kernel when the embedding dim is from 32 to 64.

Pull Request resolved: #5064
Reviewed By: q10
Differential Revision: D85701691
Pulled By: spcyppt

fbshipit-source-id: 72f491414f50e53038a4b02f3d555967d34740a7
1 parent c5be0ac commit cfe8683


3 files changed: +16, -26 lines


fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_template.cu

Lines changed: 2 additions & 19 deletions
@@ -84,11 +84,7 @@ using namespace fbgemm_gpu;
 
 
     {#-/* Set the weights row accessor */#}
-    {%- if is_rocm %}
-    const auto weights_row = rocm::WeightRowAccessorVec2
-    {%- else %}
     const auto weights_row = WeightRowAccessor
-    {%- endif %}
        <
            {{ 'cache_t' if from_cache else 'emb_t' }},
            cache_t

@@ -182,11 +178,7 @@ using namespace fbgemm_gpu;
     {%- endif %}
 
     {#-/* Set the weights row accessor */#}
-    {%- if is_rocm %}
-    const auto weights_row = rocm::WeightRowAccessorVec2
-    {%- else %}
     const auto weights_row = WeightRowAccessor
-    {%- endif %}
        <
            {{ 'cache_t' if from_cache else 'emb_t' }},
            cache_t

@@ -319,7 +311,7 @@ using namespace fbgemm_gpu;
 
     {%- if is_rocm %}
     {%- if not nobag %}
-    rocm::Vec2T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
+    Vec4T<cache_t> vals[kManualUnrollLength * kMaxVecsPerThread];
     {%- endif %}
     // Iterate over kThreadGroupSize indices
     for (auto outer_j = 0; outer_j < kThreadGroupSize && l_start + outer_j < L - L % kManualUnrollLength; outer_j += kManualUnrollLength)

@@ -633,12 +625,7 @@ batch_index_select_dim0_codegen_forward_kernel(
 #endif
 
     // Elements are processed 4 at a time through fbgemm_gpu::Vec4 (CUDA float4, 16 bytes)
-    // for CUDA devices and 2 at a time for ROCm
-    {%- if is_rocm %}
-    constexpr int VEC_WIDTH = 2;
-    {%- else %}
     constexpr int VEC_WIDTH = 4;
-    {%- endif %}
     {%- if is_rocm %}
     // Unroll factor for ROCm devices
     constexpr int kManualUnrollLength = 4;

@@ -743,12 +730,8 @@ batch_index_select_dim0_codegen_forward_kernel(
     const float inv_L = (mean_pooling && L != 0) ? static_cast<float>(1.0) / L: static_cast<float>(1.0);
 
     // Set up the accumulator buffer
-    {%- if is_rocm %}
-    rocm::Vec2T<cache_t> accumulators[kMaxVecsPerThread];
-    {%- else %}
     Vec4T<cache_t> accumulators[kMaxVecsPerThread];
     {%- endif %}
-    {%- endif %}
 
     {%- if dense %}
     {{ embedding_pool_or_store("NULL") }}

@@ -930,7 +913,7 @@ batch_index_select_dim0_codegen_forward_kernel
 {%- endmacro %}
 
 {%- macro bulk_template_instantiations(use_cache, kMaxVecsPerThread, kThreadGroupSize) %}
-    {%- set max_vecs_per_thread = 2 * kMaxVecsPerThread if is_rocm else kMaxVecsPerThread %}
+    {%- set max_vecs_per_thread = kMaxVecsPerThread %}
     {%- for emb_type in (['float', 'at::Half'] + (['at::Float8_e4m3fnuz'] if is_rocm else ['at::Float8_e4m3fn'])) %}
     {%- for cache_type in ['float', 'at::Half'] %}
     {%- for output_type in ['float', 'at::Half', 'at::BFloat16'] %}
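Taken together, these hunks make the ROCm path use the same 4-element (16-byte) vector accesses as CUDA: the Vec2 weight-row accessor and accumulators are gone, VEC_WIDTH is always 4, and the ROCm-only doubling of kMaxVecsPerThread in the bulk instantiations is no longer needed because each thread now covers a row in half as many vector loads. The following standalone sketch illustrates that access pattern; it is plain CUDA/HIP with made-up names (accumulate_rows_vec4, rows, dim_vec4), not FBGEMM's templated kernel.

#include <cuda_runtime.h>

// Illustrative only: pool num_rows embedding rows element-wise, reading one
// float4 (16 bytes) per load, the width the template now uses on both CUDA
// and ROCm. With 4-wide vectors, a row of dimension D takes
// ceil(D / (4 * blockDim.x)) loads per thread, half of what the old 2-wide
// ROCm path needed, which is why the ROCm-only doubling of kMaxVecsPerThread
// could be dropped.
__global__ void accumulate_rows_vec4(
    const float4* __restrict__ rows,  // [num_rows, dim_vec4] in float4 units
    float4* __restrict__ out,         // [dim_vec4] pooled output row
    int num_rows,
    int dim_vec4) {                   // embedding dim D / 4
  for (int v = threadIdx.x; v < dim_vec4; v += blockDim.x) {
    float4 acc = make_float4(0.f, 0.f, 0.f, 0.f);
    for (int r = 0; r < num_rows; ++r) {
      const float4 val = rows[r * dim_vec4 + v];  // one 16-byte vector load
      acc.x += val.x;
      acc.y += val.y;
      acc.z += val.z;
      acc.w += val.w;
    }
    out[v] = acc;  // one 16-byte vector store of the pooled result
  }
}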

fbgemm_gpu/codegen/training/forward/embedding_forward_split_kernel_v2_template.cu

Lines changed: 7 additions & 0 deletions
@@ -975,6 +975,13 @@ __global__ void split_embedding_codegen_forward_{{ wdesc }}_v2_kernel(
     else if (tail_warp_size <= 16) {
       INVOKE_PROCESS_ALL_INDICES(large_Ls, 16, 0x55)
     }
+#if defined(USE_ROCM)
+    // not sure step mask value to set when group size is 32
+    // while use_lxu_cache is false step mask makes no sense
+    else if (tail_warp_size <= 32 && !use_lxu_cache) {
+      INVOKE_PROCESS_ALL_INDICES(large_Ls, 32, 0xf)
+    }
+#endif
     else {
       INVOKE_PROCESS_ALL_INDICES(large_Ls, kWarpSize, 0xf)
     }
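This hunk adds one more rung to the dispatch ladder that picks the thread-group (subwarp) size for processing the tail of the work: on ROCm the hardware wavefront is 64 lanes, so before this change any tail_warp_size above 16 occupied a full wavefront. Below is a minimal sketch of that selection logic; the function name, the omission of the step-mask argument, and the missing smaller rungs are simplifications for illustration, not FBGEMM code.

#include <cstdio>

// Illustrative sketch of the subwarp dispatch in the v2 forward kernel: pick
// the smallest thread-group size that still covers the tail, instead of always
// paying for a full 64-lane ROCm wavefront.
constexpr int kWarpSize = 64;  // wavefront width on ROCm CDNA GPUs

int pick_group_size(int tail_warp_size, bool use_lxu_cache) {
  if (tail_warp_size <= 16) {
    return 16;
  }
  // Rung added by this commit (ROCm-only in the real template): tails of up to
  // 32 lanes no longer fall through to the full wavefront. It is guarded on
  // !use_lxu_cache because, per the in-code comment, the step mask for a
  // 32-lane group with the LXU cache enabled has not been worked out.
  if (tail_warp_size <= 32 && !use_lxu_cache) {
    return 32;
  }
  return kWarpSize;  // fall back to the full 64-lane wavefront
}

int main() {
  // Before the change, a tail needing 24 lanes would have used all 64 lanes.
  printf("%d\n", pick_group_size(24, /*use_lxu_cache=*/false));  // prints 32
  printf("%d\n", pick_group_size(24, /*use_lxu_cache=*/true));   // prints 64
}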

fbgemm_gpu/codegen/training/forward/embedding_forward_split_template.cu

Lines changed: 7 additions & 7 deletions
@@ -720,12 +720,7 @@ batch_index_select_dim0_codegen_forward_cuda(
       // kFixedMaxVecsPerThread instead of kMaxVecsPerThread. But
       // kMaxVecsPerThread and kFixedMaxVecsPerThread are the same
       // forward
-      {%- if is_rocm %}
-      // Account for Vec2 load for ROCm
-      constexpr auto kMaxVecsPerThread = 2 * kFixedMaxVecsPerThread;
-      {%- else %}
       constexpr auto kMaxVecsPerThread = kFixedMaxVecsPerThread;
-      {%- endif %}
 
       const auto grid = min(
           div_round_up(total_B, kForwardMaxThreads / kThreadGroupSize),

@@ -799,9 +794,14 @@ batch_index_select_dim0_codegen_forward_cuda(
     // if (!is_experimental)
   } else {
     // Allocate num warps per table based on max_D
+
     const int num_warps_per_table = B * div_round_up(max_D, kWarpSize * 4);
-    const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize;
-
+#ifdef USE_ROCM
+    const uint32_t num_warps_per_threadblock = kForwardMaxThreads / (kWarpSize * 2);
+#else
+    const uint32_t num_warps_per_threadblock = kForwardMaxThreads / kWarpSize;
+#endif
+
     const auto kernel_func =
         (use_lxu_cache ? split_embedding_codegen_forward_{{ wdesc }}_v2_kernel<
                              emb_t, cache_t, output_t, index_t, true>
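The second hunk halves the number of wavefronts scheduled per threadblock for the v2 kernel on ROCm. A small worked example of that launch math follows, assuming kForwardMaxThreads = 512 (an illustrative value, not taken from this diff) and the usual warp widths of 32 on CUDA and 64 on ROCm.

#include <cstdint>
#include <cstdio>

// Worked example of the num_warps_per_threadblock computation above.
// kForwardMaxThreads = 512 is an assumption for illustration only; the real
// constant is defined elsewhere in FBGEMM.
constexpr uint32_t kForwardMaxThreads = 512;

int main() {
  constexpr uint32_t kWarpSizeCuda = 32;
  constexpr uint32_t kWarpSizeRocm = 64;

  // CUDA path (unchanged): 512 / 32 = 16 warps per threadblock.
  const uint32_t warps_per_block_cuda = kForwardMaxThreads / kWarpSizeCuda;

  // ROCm path after this commit: 512 / (64 * 2) = 4 wavefronts per
  // threadblock, half of the previous 512 / 64 = 8.
  const uint32_t warps_per_block_rocm = kForwardMaxThreads / (kWarpSizeRocm * 2);

  printf("CUDA: %u warps/block, ROCm: %u wavefronts/block\n",
         warps_per_block_cuda, warps_per_block_rocm);
  return 0;
}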
