From 467b15b3ac41232f7c724ef34268cc299ff4a760 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 16:00:08 +0800 Subject: [PATCH 01/12] Add WelfordNormPFImpl.h --- src/ATen/native/xpu/sycl/WelfordNorm.h | 327 ---------------- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 378 +++++++++++++++++++ 2 files changed, 378 insertions(+), 327 deletions(-) delete mode 100644 src/ATen/native/xpu/sycl/WelfordNorm.h create mode 100644 src/ATen/native/xpu/sycl/WelfordNormPFImpl.h diff --git a/src/ATen/native/xpu/sycl/WelfordNorm.h b/src/ATen/native/xpu/sycl/WelfordNorm.h deleted file mode 100644 index d9d7720f6..000000000 --- a/src/ATen/native/xpu/sycl/WelfordNorm.h +++ /dev/null @@ -1,327 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace at::native::xpu { - -template -inline T divup(T a, T b) { - return (a + b - 1) / b; -} - -std::tuple get_adaptive_config( - const int reduction, - const int n_channels, - const int vec_size, - int max_wg_size, - int loops_per_item = 8) { - loops_per_item /= vec_size; - int group_size_x = std::min(last_pow2(n_channels / vec_size), 32); - int group_size_y = std::min( - last_pow2(divup(reduction, loops_per_item)), max_wg_size / group_size_x); - if (group_size_x * group_size_y != max_wg_size) { - group_size_x = - std::min(last_pow2(n_channels / vec_size), max_wg_size / group_size_y); - } - - int nwg_x = divup(n_channels, group_size_x * vec_size); - int nwg_y = std::min( - divup(reduction, group_size_y * loops_per_item), - int(syclMaxWorkItemsPerTile()) / (nwg_x * group_size_x) / (group_size_y)); - nwg_y = std::max(nwg_y, 1); - - // it's not worth having reduction between work groups if the reduction - // dimension is not big enough - nwg_y = nwg_y < 4 ? 1 : nwg_y; - - return std::make_tuple(group_size_y, group_size_x, nwg_y, nwg_x); -} - -template -inline void welford_merge( - C& count, - T& mean, - T& m2n, - const C& count_new, - const T& mean_new, - const T& m2n_new) { - T factor = T(1.0) / std::max(1, (count + count_new)); - T delta0 = mean - mean_new; - mean = (mean_new * count_new + mean * count) * factor; - m2n += m2n_new + delta0 * delta0 * count_new * count * factor; - count += count_new; -} - -template -inline void welford_vertical_merge( - sycl::nd_item<2>& item, - C& count, - T& mean, - T& m2n, - CACC& shmem_count, - TACC& shmem_mean, - TACC& shmem_m2n) { - // write to shared memory - auto address_base = item.get_local_linear_id(); -#pragma unroll - for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { - if (item.get_local_id(0) < offset * 2) { - shmem_mean[address_base] = mean; - shmem_m2n[address_base] = m2n; - shmem_count[address_base] = count; - } - item.barrier(sycl_local_fence); - if (item.get_local_id(0) < offset && - item.get_local_id(0) + offset < item.get_local_range(0)) { - auto address = address_base + offset * item.get_local_range(1); - // read shared memory back to register for reduction - auto count_new = shmem_count[address]; - auto mean_new = shmem_mean[address]; - auto m2n_new = shmem_m2n[address]; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - welford_merge( - count[v], mean[v], m2n[v], count_new[v], mean_new[v], m2n_new[v]); - } - } - } -} - -template < - typename VarTransform, - typename scalar_t, - typename acc_t, - int VEC_SIZE = 2> -struct WelfordBatchNormStatChannelsLastVecKernelFunctor - : public __SYCL_KER_CONFIG_CONVENTION__ { - using vec_t = memory::aligned_vector; - using acc_vec_t = memory::aligned_vector; - using int_vec_t = memory::aligned_vector; - - void 
operator()(sycl::nd_item<2> item) const { - // init private counter - acc_vec_t mean; - acc_vec_t m2n; - int_vec_t count; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - mean[v] = acc_t(0); - m2n[v] = acc_t(0); - count[v] = int(0); - } - - int gy = item.get_group(0); - int gx = item.get_group(1); - int c_vec_offset = item.get_global_id(1) * VEC_SIZE; - int num_cooperative_groups = item.get_group_range(0); - int inner_loop_stride = item.get_local_range(0) * num_cooperative_groups; - - for (int m_offset = item.get_global_id(0); m_offset < reduction_size_; - m_offset += inner_loop_stride) { - if (c_vec_offset < n_channels_) { - int address_vec_base = m_offset * n_channels_ + c_vec_offset; - auto input_vec = *reinterpret_cast( - const_cast(&input_[address_vec_base])); -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - auto x = input_vec[v]; - count[v]++; - acc_t delta0 = x - mean[v]; - mean[v] += delta0 / count[v]; - acc_t delta1 = x - mean[v]; - m2n[v] += delta0 * delta1; - } - } - } - - welford_vertical_merge( - item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); - - // welford vertical merge - if (num_cooperative_groups > 1) { - acc_t* staging_mean = staging_data_; - acc_t* staging_m2n = &staging_data_[n_channels_ * num_cooperative_groups]; - int* staging_count = reinterpret_cast( - &staging_m2n[n_channels_ * num_cooperative_groups]); - int address_vec_base = c_vec_offset + gy * n_channels_; - - // write data to staging_data; - if (item.get_local_id(0) == 0 && c_vec_offset < n_channels_) { - *reinterpret_cast(&staging_mean[address_vec_base]) = mean; - *reinterpret_cast(&staging_m2n[address_vec_base]) = m2n; - *reinterpret_cast(&staging_count[address_vec_base]) = count; - } - item.barrier(sycl_local_fence); - - // mark group done - if (item.get_local_linear_id() == 0) { - sycl_atomic_ref_rlx_dev_global_t atomic_count(semaphores_[gx]); - int old = atomic_count.fetch_add( - 1, sycl_mem_odr_acq_rel - /* , default memory scope is device */); - is_last_group_done_[0] = (old == (num_cooperative_groups - 1)); - } - item.barrier(sycl_local_fence); - - // check that all data is now available in global memory - if (is_last_group_done_[0]) { -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - mean[v] = acc_t(0); - m2n[v] = acc_t(0); - count[v] = int(0); - } - - for (int y = item.get_local_id(0); y < num_cooperative_groups; - y += item.get_local_range(0)) { - if (c_vec_offset < n_channels_) { - address_vec_base = y * n_channels_ + c_vec_offset; - auto mean_new = - *reinterpret_cast(&staging_mean[address_vec_base]); - auto m2n_new = - *reinterpret_cast(&staging_m2n[address_vec_base]); - auto count_new = - *reinterpret_cast(&staging_count[address_vec_base]); -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - welford_merge( - count[v], - mean[v], - m2n[v], - count_new[v], - mean_new[v], - m2n_new[v]); - } - } - } - welford_vertical_merge( - item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); - } - } - - if (item.get_local_id(0) == 0 && - (num_cooperative_groups == 1 || is_last_group_done_[0]) && - c_vec_offset < n_channels_) { - acc_vec_t invstd_vec; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - invstd_vec[v] = VarTransform{}(m2n[v] / count[v], epsilon_); - } - - *reinterpret_cast(&save_mean_[c_vec_offset]) = mean; - *reinterpret_cast(&save_invstd_[c_vec_offset]) = invstd_vec; - } - } - - void sycl_ker_config_convention(sycl::handler& cgh) { - auto local_size = group_size_x_ * group_size_y_; - shmem_mean_ = 
sycl_local_acc_t(sycl::range<1>(local_size), cgh); - shmem_m2n_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); - shmem_count_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); - is_last_group_done_ = sycl_local_acc_t(sycl::range<1>(1), cgh); - } - - WelfordBatchNormStatChannelsLastVecKernelFunctor( - const scalar_t* input, - acc_t* save_mean, - acc_t* save_invstd, - int reduction_size, - int n_channels, - acc_t* staging_data, - int* semaphores, - double epsilon) - : input_(input), - save_mean_(save_mean), - save_invstd_(save_invstd), - reduction_size_(reduction_size), - n_channels_(n_channels), - staging_data_(staging_data), - semaphores_(semaphores), - epsilon_(epsilon) {} - - void init() { - using KernelT = WelfordBatchNormStatChannelsLastVecKernelFunctor< - VarTransform, - scalar_t, - acc_t, - VEC_SIZE>; - auto max_group_size = syclMaxWorkGroupSize(); - std::tie(group_size_y_, group_size_x_, ngroups_y_, ngroups_x_) = - get_adaptive_config( - reduction_size_, n_channels_, VEC_SIZE, max_group_size); - } - - static bool valid( - int reduction_size, - int n_channels, - const scalar_t* input, - acc_t* save_mean, - acc_t* save_invstd) { - bool valid = sizeof(scalar_t) <= 2; - valid = valid && (n_channels % VEC_SIZE == 0); - valid = valid && - (memory::can_vectorize_up_to((char*)input) >= VEC_SIZE); - valid = valid && - (memory::can_vectorize_up_to((char*)save_mean) >= VEC_SIZE); - valid = valid && - (memory::can_vectorize_up_to((char*)save_invstd) >= VEC_SIZE); - return valid; - } - - sycl::range<2> local_range() const { - return sycl::range<2>(group_size_y_, group_size_x_); - } - - sycl::range<2> global_range() const { - return sycl::range<2>( - group_size_y_ * ngroups_y_, group_size_x_ * ngroups_x_); - } - - int staging_size() const { - return ngroups_y_ * n_channels_ * 4; - } - - int semaphores_size() const { - return ngroups_x_; - } - - bool set_staging_data_check(acc_t* staging_data) { - staging_data_ = staging_data; - return ( - (staging_data == nullptr) || - (memory::can_vectorize_up_to((char*)staging_data) >= VEC_SIZE)); - } - - void set_semaphores(int* semaphores) { - semaphores_ = semaphores; - } - - int num_cooperative_groups() const { - return ngroups_y_; - } - - private: - const scalar_t* input_; - acc_t* save_mean_; - acc_t* save_invstd_; - int reduction_size_; - int n_channels_; - acc_t* staging_data_; - int* semaphores_; - double epsilon_; - - size_t group_size_y_; - size_t group_size_x_; - size_t ngroups_y_; - size_t ngroups_x_; - - sycl_local_acc_t shmem_mean_; - sycl_local_acc_t shmem_m2n_; - sycl_local_acc_t shmem_count_; - sycl_local_acc_t is_last_group_done_; -}; - -} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h new file mode 100644 index 000000000..0221e5859 --- /dev/null +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -0,0 +1,378 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::xpu { + +namespace impl { + +// returns floor(log2(n)) +inline int last_pow2(int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +template +inline T divup(T a, T b) { + return (a + b - 1) / b; +} + +std::tuple get_adaptive_config( + const int problem_size, + const int batch_size, + const int vec_size, + const int max_block_size, + int loops_per_thread = 8, + int coop_th = 8) { + loops_per_thread /= + vec_size; // Ensure the number of instructions is normalized + int 
threads_along_batch = last_pow2(batch_size / vec_size); + int threads_along_problem = last_pow2(divup(problem_size, loops_per_thread)); + + int block_size_x = std::min(threads_along_batch, 32); + int block_size_y = + std::min(threads_along_problem, max_block_size / block_size_x); + if (block_size_x * block_size_y != max_block_size) { + block_size_x = std::min(threads_along_batch, max_block_size / block_size_y); + } + + int max_threads_gpu = syclMaxWorkItemsPerTile(); + int nblock_x = divup(batch_size, block_size_x * vec_size); + int nblock_y = std::min( + divup(problem_size, block_size_y * loops_per_thread), + max_threads_gpu / (nblock_x * block_size_x) / (block_size_y)); + nblock_y = std::max(nblock_y, 1); + + // it's not worth having reduction between blocks if the reduction + // dimension is not big enough + coop_th /= vec_size; + nblock_y = nblock_y < coop_th ? 1 : nblock_y; + + return std::make_tuple(block_size_y, block_size_x, nblock_y, nblock_x); +} + +template +inline void welford_merge( + C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / std::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +} // namespace impl + +template +struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { + using vec_t = memory::aligned_vector; + using acc_vec_t = memory::aligned_vector; + using int_vec_t = memory::aligned_vector; + + void operator()(sycl::nd_item<2> item) const { + // init welford counters + acc_vec_t mean; + acc_vec_t m2n; + int_vec_t count; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + mean[v] = acc_t(0); + m2n[v] = acc_t(0); + count[v] = int(0); + } + + int bx = item.get_group(1); // along batch dim + int by = item.get_group(0); // along problem dim + int batch_vec_offset = item.get_global_id(1) * VEC_SIZE; + int num_cooperative_blocks = item.get_group_range(0); + int inner_loop_stride = item.get_local_range(0) * num_cooperative_blocks; + + if (batch_vec_offset < batch_size_) { + for (int p_offset = item.get_global_id(0); p_offset < problem_size_; + p_offset += inner_loop_stride) { + int address_vec_base = p_offset * batch_size_ + batch_vec_offset; + auto input_vec = *reinterpret_cast( + const_cast(&input_[address_vec_base])); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + auto x = input_vec[v]; + count[v]++; + acc_t delta0 = x - mean[v]; + mean[v] += delta0 / count[v]; + acc_t delta1 = x - mean[v]; + m2n[v] += delta0 * delta1; + } + } + } + + welford_vertical_merge( + item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); + + // welford vertical merge + if (num_cooperative_blocks > 1) { + acc_t* staging_mean = staging_data_; + acc_t* staging_m2n = &staging_data_[batch_size_ * num_cooperative_blocks]; + int* staging_count = reinterpret_cast( + &staging_m2n[batch_size_ * num_cooperative_blocks]); + int address_vec_base = batch_vec_offset + by * batch_size_; + + // write data to staging_data; + if (item.get_local_id(0) == 0 && batch_vec_offset < batch_size_) { + *reinterpret_cast(&staging_mean[address_vec_base]) = mean; + *reinterpret_cast(&staging_m2n[address_vec_base]) = m2n; + *reinterpret_cast(&staging_count[address_vec_base]) = count; + } + item.barrier(sycl_local_fence); + + // mark block done + if (item.get_local_linear_id() == 0) { + sycl_atomic_ref_rlx_dev_global_t atomic_count(semaphores_[bx]); + int old = 
atomic_count.fetch_add( + 1, sycl_mem_odr_acq_rel + /* , default memory scope is device */); + is_last_block_done_[0] = (old == (num_cooperative_blocks - 1)); + } + item.barrier(sycl_local_fence); + + // check that all data is now available in global memory + if (is_last_block_done_[0]) { +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + mean[v] = acc_t(0); + m2n[v] = acc_t(0); + count[v] = int(0); + } + + for (int y = item.get_local_id(0); y < num_cooperative_blocks; + y += item.get_local_range(0)) { + if (batch_vec_offset < batch_size_) { + address_vec_base = y * batch_size_ + batch_vec_offset; + auto mean_new = + *reinterpret_cast(&staging_mean[address_vec_base]); + auto m2n_new = + *reinterpret_cast(&staging_m2n[address_vec_base]); + auto count_new = + *reinterpret_cast(&staging_count[address_vec_base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + impl::welford_merge( + count[v], + mean[v], + m2n[v], + count_new[v], + mean_new[v], + m2n_new[v]); + } + } + } + welford_vertical_merge( + item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); + } + } + + if (item.get_local_id(0) == 0 && + (num_cooperative_blocks == 1 || is_last_block_done_[0]) && + batch_vec_offset < batch_size_) { + acc_vec_t invstd_vec; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + invstd_vec[v] = (acc_t)1.0 / std::sqrt(m2n[v] / count[v] + epsilon_); + } + *reinterpret_cast(&save_mean_[batch_vec_offset]) = mean; + *reinterpret_cast(&save_invstd_[batch_vec_offset]) = + invstd_vec; + + if (running_mean_ != nullptr) { + auto running_mean_vec = + *reinterpret_cast(&running_mean_[batch_vec_offset]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + running_mean_vec[v] = + mean[v] * momentum_ + (1 - momentum_) * running_mean_vec[v]; + } + *reinterpret_cast(&running_mean_[batch_vec_offset]) = + running_mean_vec; + } + + if (running_var_ != nullptr) { + auto running_var_vec = + *reinterpret_cast(&running_var_[batch_vec_offset]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + auto unbiased_var = m2n[v] / (count[v] - 1); + running_var_vec[v] = + unbiased_var * momentum_ + (1 - momentum_) * running_var_vec[v]; + } + *reinterpret_cast(&running_var_[batch_vec_offset]) = + running_var_vec; + } + } + } + + template + inline void welford_vertical_merge( + sycl::nd_item<2>& item, + int_vec_t& count, + acc_vec_t& mean, + acc_vec_t& m2n, + CACC& shmem_count, + TACC& shmem_mean, + TACC& shmem_m2n) const { + // write to shared memory + auto address_base = item.get_local_linear_id(); +#pragma unroll + for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { + if (item.get_local_id(0) < offset * 2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + item.barrier(sycl_local_fence); + if (item.get_local_id(0) < offset && + item.get_local_id(0) + offset < item.get_local_range(0)) { + auto address = address_base + offset * item.get_local_range(1); + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + impl::welford_merge( + count[v], mean[v], m2n[v], count_new[v], mean_new[v], m2n_new[v]); + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + auto local_size = block_size_x_ * block_size_y_; + shmem_mean_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); + shmem_m2n_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); 
+ shmem_count_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); + is_last_block_done_ = sycl_local_acc_t(sycl::range<1>(1), cgh); + } + + void init() { + using KernelT = WelfordNormPFKernel; + auto max_group_size = syclMaxWorkGroupSize(); + std::tie(block_size_y_, block_size_x_, nblocks_y_, nblocks_x_) = + impl::get_adaptive_config( + problem_size_, batch_size_, VEC_SIZE, max_group_size); + } + + static bool valid( + int batch_size, + int problem_size, + const scalar_t* input, + acc_t* save_mean, + acc_t* save_invstd) { + if (VEC_SIZE <= 1) + return true; + bool valid = sizeof(scalar_t) <= 4; + valid = valid && (batch_size % VEC_SIZE == 0); + valid = valid && + (memory::can_vectorize_up_to((char*)input) >= VEC_SIZE); + valid = valid && + (memory::can_vectorize_up_to((char*)save_mean) >= VEC_SIZE); + valid = valid && + (memory::can_vectorize_up_to((char*)save_invstd) >= VEC_SIZE); + return valid; + } + + sycl::range<2> local_range() const { + return sycl::range<2>(block_size_y_, block_size_x_); + } + + sycl::range<2> global_range() const { + return sycl::range<2>( + block_size_y_ * nblocks_y_, block_size_x_ * nblocks_x_); + } + + int staging_size() const { + return nblocks_y_ * batch_size_ * 4; + } + + int semaphores_size() const { + return nblocks_x_; + } + + bool set_staging_data_check(acc_t* staging_data) { + staging_data_ = staging_data; + return ( + (staging_data == nullptr) || + (memory::can_vectorize_up_to((char*)staging_data) >= VEC_SIZE)); + } + + void set_semaphores(int* semaphores) { + semaphores_ = semaphores; + } + + void set_running_mean_var( + scalar_t* running_mean, + scalar_t* running_var, + acc_t momentum) { + running_mean_ = running_mean; + running_var_ = running_var; + momentum_ = momentum; + } + + int num_cooperative_blocks() const { + return nblocks_y_; + } + + WelfordNormPFKernel( + const scalar_t* input, + int batch_size, + int problem_size, + acc_t epsilon, + acc_t* save_mean, + acc_t* save_invstd) + : input_(input), + batch_size_(batch_size), + problem_size_(problem_size), + epsilon_(epsilon), + save_mean_(save_mean), + save_invstd_(save_invstd), + staging_data_(nullptr), + semaphores_(nullptr), + running_mean_(nullptr), + running_var_(nullptr) {} + + private: + const scalar_t* input_; + int batch_size_; + int problem_size_; + acc_t epsilon_; + acc_t* save_mean_; + acc_t* save_invstd_; + acc_t* staging_data_; + int* semaphores_; + + scalar_t* running_mean_; + scalar_t* running_var_; + acc_t momentum_; + + size_t block_size_y_; + size_t block_size_x_; + size_t nblocks_y_; + size_t nblocks_x_; + + sycl_local_acc_t shmem_mean_; + sycl_local_acc_t shmem_m2n_; + sycl_local_acc_t shmem_count_; + sycl_local_acc_t is_last_block_done_; +}; + +} // namespace at::native::xpu From b0c4f90a561422534191c78369429df90b5b32eb Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 16:00:43 +0800 Subject: [PATCH 02/12] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 23 ++++++++----------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index b83b32bb5..3ede46662 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -1072,11 +1072,8 @@ void batch_norm_stats_channels_last_template( at::Tensor staging_data; at::Tensor semaphores; - using VecKernel = 
WelfordBatchNormStatChannelsLastVecKernelFunctor< - VarTransform, - scalar_t, - accscalar_t, - PREFERRED_VEC_SIZE>; + using VecKernel = + WelfordNormPFKernel; auto input_ptr = input.const_data_ptr(); auto out_mean_ptr = out_mean.mutable_data_ptr(); auto out_invstd_ptr = out_invstd.mutable_data_ptr(); @@ -1086,22 +1083,20 @@ void batch_norm_stats_channels_last_template( reduction_size, stride, input_ptr, out_mean_ptr, out_invstd_ptr)) { auto kfn = VecKernel( input_ptr, - out_mean_ptr, - out_invstd_ptr, - reduction_size, stride, - nullptr, - nullptr, - epsilon); + reduction_size, + epsilon, + out_mean_ptr, + out_invstd_ptr); kfn.init(); staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options()); semaphores = at::zeros( {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); - accscalar_t* staging_data_ptr = kfn.num_cooperative_groups() > 1 + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 ? staging_data.mutable_data_ptr() : nullptr; - int* semaphores_ptr = kfn.num_cooperative_groups() > 1 + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 ? semaphores.mutable_data_ptr() : nullptr; From d36939d565156e412944a86b0e6eb8d19295608e Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 16:51:30 +0800 Subject: [PATCH 03/12] Add VarTransform --- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h index 0221e5859..d2827a794 100644 --- a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -75,7 +75,11 @@ inline void welford_merge( } // namespace impl -template +template < + typename VarTransform, + typename scalar_t, + typename acc_t, + int VEC_SIZE> struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { using vec_t = memory::aligned_vector; using acc_vec_t = memory::aligned_vector; @@ -188,7 +192,7 @@ struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { acc_vec_t invstd_vec; #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { - invstd_vec[v] = (acc_t)1.0 / std::sqrt(m2n[v] / count[v] + epsilon_); + invstd_vec[v] = VarTransform{}(m2n[v] / count[v], epsilon_); } *reinterpret_cast(&save_mean_[batch_vec_offset]) = mean; *reinterpret_cast(&save_invstd_[batch_vec_offset]) = @@ -265,7 +269,8 @@ struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { } void init() { - using KernelT = WelfordNormPFKernel; + using KernelT = + WelfordNormPFKernel; auto max_group_size = syclMaxWorkGroupSize(); std::tie(block_size_y_, block_size_x_, nblocks_y_, nblocks_x_) = impl::get_adaptive_config( From 3d160150727cb6ccb496390cbe4d3b99f0e30acc Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 16:52:16 +0800 Subject: [PATCH 04/12] refine code path --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 94 ++++--------------- 1 file changed, 19 insertions(+), 75 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 3ede46662..f98c77292 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1072,87 +1072,31 @@ void batch_norm_stats_channels_last_template( at::Tensor staging_data; at::Tensor semaphores; - using VecKernel = - WelfordNormPFKernel; auto input_ptr = input.const_data_ptr(); auto out_mean_ptr = out_mean.mutable_data_ptr(); auto out_invstd_ptr = out_invstd.mutable_data_ptr(); - 
bool use_vec_kernel = false; - - if (VecKernel::valid( - reduction_size, stride, input_ptr, out_mean_ptr, out_invstd_ptr)) { - auto kfn = VecKernel( - input_ptr, - stride, - reduction_size, - epsilon, - out_mean_ptr, - out_invstd_ptr); - kfn.init(); - - staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options()); - semaphores = at::zeros( - {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); - accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 - ? staging_data.mutable_data_ptr() - : nullptr; - int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 - ? semaphores.mutable_data_ptr() - : nullptr; - - use_vec_kernel = kfn.set_staging_data_check(staging_data_ptr); - - if (use_vec_kernel) { - kfn.set_semaphores(semaphores_ptr); - sycl_kernel_submit( - kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); - return; - } - } - if (!use_vec_kernel) { - using KernelT = BatchNormCollectStatisticsChannelsLastKernelFunctor< - VarTransform, - scalar_t, - accscalar_t, - ELEMENTS_PER_ITER>; + using KernelT = WelfordNormPFKernel; - auto config = get_adaptive_launch_config( - syclMaxWorkGroupSize(), - reduction_size, - stride, - true, - ELEMENTS_PER_WORK_ITEM); - auto global_range = std::get<0>(config); - auto local_range = std::get<1>(config); - - auto wg_size_y = local_range[0]; - auto wg_size_x = local_range[1]; - auto nwg_y = global_range[0] / wg_size_y; - auto nwg_x = global_range[1] / wg_size_x; - if (nwg_y > 1) { - staging_data = - at::empty({(long)(4 * stride * nwg_y)}, out_mean.options()); - semaphores = at::zeros({(long)nwg_x}, input.options().dtype(at::kInt)); - } - accscalar_t* staging_data_ptr = - nwg_y > 1 ? staging_data.mutable_data_ptr() : nullptr; - int* semaphores_ptr = - nwg_y > 1 ? semaphores.mutable_data_ptr() : nullptr; - - auto kfn = KernelT( - input_ptr, - out_mean_ptr, - out_invstd_ptr, - staging_data_ptr, - semaphores_ptr, - reduction_size, - stride, - epsilon, - wg_size_y * wg_size_x); + auto kfn = KernelT( + input_ptr, stride, reduction_size, epsilon, out_mean_ptr, out_invstd_ptr); - sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), kfn); - } + kfn.init(); + staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options()); + semaphores = at::zeros( + {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); + + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 + ? staging_data.mutable_data_ptr() + : nullptr; + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 + ? 
semaphores.mutable_data_ptr() + : nullptr; + + TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); + kfn.set_semaphores(semaphores_ptr); + sycl_kernel_submit( + kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); } std::tuple batch_norm_stats_kernel( From 58c3763e9fc418c3ea88a0bcedc121a2479bb9a8 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 17:13:33 +0800 Subject: [PATCH 05/12] add welford_norm_pf_kernel_vec_size --- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h index d2827a794..db8af4cf2 100644 --- a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -75,6 +75,25 @@ inline void welford_merge( } // namespace impl +template +int welford_norm_pf_kernel_vec_size( + int batch_size, + const scalar_t* input, + acc_t* save_mean, + acc_t* save_invstd, + int max_vec_bytes = 8) { + if (sizeof(scalar_t) >= max_vec_bytes) + return 1; + int vec_size = max_vec_bytes / sizeof(scalar_t); + while ((vec_size >= 1) && (batch_size % vec_size != 0) && + (memory::can_vectorize_up_to((char*)input) >= vec_size) && + (memory::can_vectorize_up_to((char*)save_mean) >= vec_size) && + (memory::can_vectorize_up_to((char*)save_invstd) >= vec_size)) { + vec_size >>= 1; + } + return vec_size; +} + template < typename VarTransform, typename scalar_t, From 5f580b216ccf42b0014fb2c8aacb025ac07f25c4 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 17:14:38 +0800 Subject: [PATCH 06/12] Using dynamic vec size --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 68 +++++++++++++------ 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index f98c77292..91526e021 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1076,27 +1076,51 @@ void batch_norm_stats_channels_last_template( auto out_mean_ptr = out_mean.mutable_data_ptr(); auto out_invstd_ptr = out_invstd.mutable_data_ptr(); - using KernelT = WelfordNormPFKernel; - - auto kfn = KernelT( - input_ptr, stride, reduction_size, epsilon, out_mean_ptr, out_invstd_ptr); - - kfn.init(); - staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options()); - semaphores = at::zeros( - {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); - - accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 - ? staging_data.mutable_data_ptr() - : nullptr; - int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 - ? semaphores.mutable_data_ptr() - : nullptr; - - TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); - kfn.set_semaphores(semaphores_ptr); - sycl_kernel_submit( - kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); + int vec_size = welford_norm_pf_kernel_vec_size( + stride, input_ptr, out_mean_ptr, out_invstd_ptr); + +#define DISPATCH_VEC(VEC_SIZE) \ + { \ + using KernelT = \ + WelfordNormPFKernel; \ + auto kfn = KernelT( \ + input_ptr, \ + stride, \ + reduction_size, \ + epsilon, \ + out_mean_ptr, \ + out_invstd_ptr); \ + kfn.init(); \ + staging_data = \ + at::empty({(long)(kfn.staging_size())}, out_mean.options()); \ + semaphores = at::zeros( \ + {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \ + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \ + ? 
staging_data.mutable_data_ptr() \ + : nullptr; \ + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \ + ? semaphores.mutable_data_ptr() \ + : nullptr; \ + TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \ + kfn.set_semaphores(semaphores_ptr); \ + sycl_kernel_submit( \ + kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \ + } + + switch (vec_size) { + case 8: + DISPATCH_VEC(8) + break; + case 4: + DISPATCH_VEC(4) + break; + case 2: + DISPATCH_VEC(2) + break; + default: + DISPATCH_VEC(1) + break; + } } std::tuple batch_norm_stats_kernel( @@ -5430,6 +5454,8 @@ std::tuple batch_norm_gather_stats_kernel( counts_); } +#undef DISPATCH_VEC + } // namespace xpu } // namespace native } // namespace at From 6b5d6c6e1c43523fa3340f0a489b0a51c7119f40 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Mon, 12 May 2025 17:19:37 +0800 Subject: [PATCH 07/12] Update WelfordNormPFImpl.h --- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h index db8af4cf2..fe4d7c5e6 100644 --- a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -85,10 +85,11 @@ int welford_norm_pf_kernel_vec_size( if (sizeof(scalar_t) >= max_vec_bytes) return 1; int vec_size = max_vec_bytes / sizeof(scalar_t); - while ((vec_size >= 1) && (batch_size % vec_size != 0) && - (memory::can_vectorize_up_to((char*)input) >= vec_size) && - (memory::can_vectorize_up_to((char*)save_mean) >= vec_size) && - (memory::can_vectorize_up_to((char*)save_invstd) >= vec_size)) { + while ( + !((batch_size % vec_size == 0) && + (memory::can_vectorize_up_to((char*)input) >= vec_size) && + (memory::can_vectorize_up_to((char*)save_mean) >= vec_size) && + (memory::can_vectorize_up_to((char*)save_invstd) >= vec_size))) { vec_size >>= 1; } return vec_size; From 344780d6c4b3cc101edff909407b523b06583864 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 15 May 2025 17:20:18 +0800 Subject: [PATCH 08/12] Update WelfordNormPFImpl.h --- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 28 +++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h index fe4d7c5e6..c4d15cc2d 100644 --- a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -99,11 +99,13 @@ template < typename VarTransform, typename scalar_t, typename acc_t, - int VEC_SIZE> + int VEC_SIZE, + typename running_mean_t = acc_t> struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { using vec_t = memory::aligned_vector; using acc_vec_t = memory::aligned_vector; using int_vec_t = memory::aligned_vector; + using running_mean_vec_t = memory::aligned_vector; void operator()(sycl::nd_item<2> item) const { // init welford counters @@ -219,28 +221,28 @@ struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { invstd_vec; if (running_mean_ != nullptr) { - auto running_mean_vec = - *reinterpret_cast(&running_mean_[batch_vec_offset]); + auto running_mean_vec = *reinterpret_cast( + &running_mean_[batch_vec_offset]); #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { running_mean_vec[v] = mean[v] * momentum_ + (1 - momentum_) * running_mean_vec[v]; } - *reinterpret_cast(&running_mean_[batch_vec_offset]) = - running_mean_vec; + *reinterpret_cast( + &running_mean_[batch_vec_offset]) = running_mean_vec; } if 
(running_var_ != nullptr) { - auto running_var_vec = - *reinterpret_cast(&running_var_[batch_vec_offset]); + auto running_var_vec = *reinterpret_cast( + &running_var_[batch_vec_offset]); #pragma unroll for (int v = 0; v < VEC_SIZE; ++v) { auto unbiased_var = m2n[v] / (count[v] - 1); running_var_vec[v] = unbiased_var * momentum_ + (1 - momentum_) * running_var_vec[v]; } - *reinterpret_cast(&running_var_[batch_vec_offset]) = - running_var_vec; + *reinterpret_cast( + &running_var_[batch_vec_offset]) = running_var_vec; } } } @@ -345,8 +347,8 @@ struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { } void set_running_mean_var( - scalar_t* running_mean, - scalar_t* running_var, + running_mean_t* running_mean, + running_mean_t* running_var, acc_t momentum) { running_mean_ = running_mean; running_var_ = running_var; @@ -385,8 +387,8 @@ struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { acc_t* staging_data_; int* semaphores_; - scalar_t* running_mean_; - scalar_t* running_var_; + running_mean_t* running_mean_; + running_mean_t* running_var_; acc_t momentum_; size_t block_size_y_; From b6dbf8fa41b88c67c7b6ec1f23a304c75cecb188 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 15 May 2025 17:24:18 +0800 Subject: [PATCH 09/12] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 140 ++++++++++++++++-- 1 file changed, 129 insertions(+), 11 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 91526e021..e17a6bd05 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -1122,6 +1122,7 @@ void batch_norm_stats_channels_last_template( break; } } +#undef DISPATCH_VEC std::tuple batch_norm_stats_kernel( const Tensor& self, @@ -4003,6 +4004,103 @@ void batch_norm_mean_var( } } +void batch_norm_mean_var_fused_cnl( + const Tensor& input, + Tensor& out_mean, + Tensor& out_invstd, + Tensor& running_mean, + Tensor& running_var, + double momentum = 0, + double dummy_epsilon = 1e-5) { + // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batch_norm_mean_var_fused_cnl", + [&] { + std::cout << "into batch_norm_mean_var_fused_cnl\n"; + using accscalar_t = acc_type_device; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::native::resize_output(out_mean, {stride}); + at::native::resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT( + out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT( + out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + at::Tensor staging_data; + at::Tensor semaphores; + + auto input_ptr = input.const_data_ptr(); + auto out_mean_ptr = out_mean.mutable_data_ptr(); + auto out_invstd_ptr = out_invstd.mutable_data_ptr(); + + int vec_size = welford_norm_pf_kernel_vec_size( + stride, input_ptr, out_mean_ptr, out_invstd_ptr); // need check running mean & var + std::cout << "vec_size:" << vec_size << "\n"; + + auto running_mean_ptr = running_mean.defined() + ? running_mean.data_ptr() + : nullptr; + auto running_var_ptr = running_var.defined() + ? 
running_var.data_ptr() + : nullptr; + + std::cout << "got running_var_ptr\n"; + +#define DISPATCH_VEC(VEC_SIZE) \ + { \ + using KernelT = \ + WelfordNormPFKernel; \ + auto kfn = KernelT( \ + input_ptr, \ + stride, \ + reduction_size, \ + (accscalar_t)dummy_epsilon, \ + out_mean_ptr, \ + out_invstd_ptr); \ + kfn.init(); \ + staging_data = \ + at::empty({(long)(kfn.staging_size())}, out_mean.options()); \ + semaphores = at::zeros( \ + {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \ + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \ + ? staging_data.mutable_data_ptr() \ + : nullptr; \ + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \ + ? semaphores.mutable_data_ptr() \ + : nullptr; \ + TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \ + kfn.set_semaphores(semaphores_ptr); \ + kfn.set_running_mean_var( \ + running_mean_ptr, running_var_ptr, (accscalar_t)momentum); \ + sycl_kernel_submit( \ + kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \ + } + switch (vec_size) { + case 8: + DISPATCH_VEC(8) + break; + case 4: + DISPATCH_VEC(4) + break; + case 2: + DISPATCH_VEC(2) + break; + default: + DISPATCH_VEC(1) + break; + } + }); +} +#undef DISPATCH_VEC + std::tuple batch_norm_update_stats_kernel( const Tensor& self, const std::optional& running_mean_opt, @@ -4266,20 +4364,42 @@ std::tuple batch_norm_kernel( (running_var_opt.has_value() && running_var_opt->defined()); TORCH_CHECK(has_running_mean == has_running_var); + bool can_use_fused_kernel = + batch_norm_choose_impl(self) == Impl::ChannelsLast && + (!save_mean.defined() || save_mean.is_contiguous()) && + (!save_invstd.defined() || save_invstd.is_contiguous()); + if (train) { - batch_norm_mean_var(self, save_mean, save_invstd); - if (has_running_mean) { - const int64_t N = self.numel() / save_mean.numel(); - batch_norm_update_stats_and_invert( + if (can_use_fused_kernel && !has_running_mean) { + Tensor a; + batch_norm_mean_var_fused_cnl( + self, save_mean, save_invstd, a, a, momentum, epsilon); + } else if ( + can_use_fused_kernel && + (*running_mean_opt).dtype() == save_mean.dtype()) { + batch_norm_mean_var_fused_cnl( + self, save_mean, save_invstd, - *running_mean_opt, - *running_var_opt, + const_cast(*running_mean_opt), + const_cast(*running_var_opt), momentum, - epsilon, - N); + epsilon); } else { - batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); + batch_norm_mean_var(self, save_mean, save_invstd); + if (has_running_mean) { + const int64_t N = self.numel() / save_mean.numel(); + batch_norm_update_stats_and_invert( + save_mean, + save_invstd, + *running_mean_opt, + *running_var_opt, + momentum, + epsilon, + N); + } else { + batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); + } } } else { TORCH_CHECK(has_running_mean); @@ -5454,8 +5574,6 @@ std::tuple batch_norm_gather_stats_kernel( counts_); } -#undef DISPATCH_VEC - } // namespace xpu } // namespace native } // namespace at From da0657f487da943ed048a9c926ac99c827efca58 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 15 May 2025 17:36:49 +0800 Subject: [PATCH 10/12] remove cout --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index e17a6bd05..640b3e975 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -4019,7 +4019,6 @@ void batch_norm_mean_var_fused_cnl( input.scalar_type(), 
"batch_norm_mean_var_fused_cnl", [&] { - std::cout << "into batch_norm_mean_var_fused_cnl\n"; using accscalar_t = acc_type_device; const auto stride = input.sizes()[1]; @@ -4043,7 +4042,6 @@ void batch_norm_mean_var_fused_cnl( int vec_size = welford_norm_pf_kernel_vec_size( stride, input_ptr, out_mean_ptr, out_invstd_ptr); // need check running mean & var - std::cout << "vec_size:" << vec_size << "\n"; auto running_mean_ptr = running_mean.defined() ? running_mean.data_ptr() @@ -4052,8 +4050,6 @@ void batch_norm_mean_var_fused_cnl( ? running_var.data_ptr() : nullptr; - std::cout << "got running_var_ptr\n"; - #define DISPATCH_VEC(VEC_SIZE) \ { \ using KernelT = \ From 2478eb7b0b46f5ec2253f85be3b631fb8c405491 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 15 May 2025 22:58:12 +0800 Subject: [PATCH 11/12] Update BatchNormKernels.cpp --- src/ATen/native/xpu/sycl/BatchNormKernels.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 640b3e975..fe319824f 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -4040,9 +4040,6 @@ void batch_norm_mean_var_fused_cnl( auto out_mean_ptr = out_mean.mutable_data_ptr(); auto out_invstd_ptr = out_invstd.mutable_data_ptr(); - int vec_size = welford_norm_pf_kernel_vec_size( - stride, input_ptr, out_mean_ptr, out_invstd_ptr); // need check running mean & var - auto running_mean_ptr = running_mean.defined() ? running_mean.data_ptr() : nullptr; @@ -4050,6 +4047,14 @@ void batch_norm_mean_var_fused_cnl( ? running_var.data_ptr() : nullptr; + int vec_size = welford_norm_pf_kernel_vec_size( + stride, + input_ptr, + out_mean_ptr, + out_invstd_ptr, + running_mean_ptr, + running_var_ptr); + #define DISPATCH_VEC(VEC_SIZE) \ { \ using KernelT = \ From b6960d725565b70f67d44a69b9ed2e77add03015 Mon Sep 17 00:00:00 2001 From: Yutao Xu Date: Thu, 15 May 2025 22:59:04 +0800 Subject: [PATCH 12/12] Update WelfordNormPFImpl.h --- src/ATen/native/xpu/sycl/WelfordNormPFImpl.h | 29 +++++++++++++++----- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h index c4d15cc2d..1c29f4f4f 100644 --- a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -30,7 +30,7 @@ std::tuple get_adaptive_config( const int vec_size, const int max_block_size, int loops_per_thread = 8, - int coop_th = 8) { + int coop_th = 16) { loops_per_thread /= vec_size; // Ensure the number of instructions is normalized int threads_along_batch = last_pow2(batch_size / vec_size); @@ -43,7 +43,7 @@ std::tuple get_adaptive_config( block_size_x = std::min(threads_along_batch, max_block_size / block_size_y); } - int max_threads_gpu = syclMaxWorkItemsPerTile(); + int max_threads_gpu = syclMaxDSSNum() * syclMaxWorkItemsPerSubSlice(); int nblock_x = divup(batch_size, block_size_x * vec_size); int nblock_y = std::min( divup(problem_size, block_size_y * loops_per_thread), @@ -75,23 +75,38 @@ inline void welford_merge( } // namespace impl -template +template int welford_norm_pf_kernel_vec_size( int batch_size, const scalar_t* input, acc_t* save_mean, acc_t* save_invstd, + running_t* running_mean = nullptr, + running_t* running_var = nullptr, int max_vec_bytes = 8) { if (sizeof(scalar_t) >= max_vec_bytes) return 1; int vec_size = max_vec_bytes / sizeof(scalar_t); + + auto input_vec_size = 
memory::can_vectorize_up_to<scalar_t>((char*)input); + auto save_mean_vec_size = + memory::can_vectorize_up_to<acc_t>((char*)save_mean); + auto save_invstd_vec_size = + memory::can_vectorize_up_to<acc_t>((char*)save_invstd); + while ( - !((batch_size % vec_size == 0) && - (memory::can_vectorize_up_to<scalar_t>((char*)input) >= vec_size) && - (memory::can_vectorize_up_to<acc_t>((char*)save_mean) >= vec_size) && - (memory::can_vectorize_up_to<acc_t>((char*)save_invstd) >= vec_size))) { + !(batch_size % vec_size == 0 && input_vec_size >= vec_size && + save_mean_vec_size >= vec_size && save_invstd_vec_size >= vec_size)) { vec_size >>= 1; } + if (running_mean != nullptr) { + vec_size = std::min( + memory::can_vectorize_up_to<running_t>((char*)running_mean), vec_size); + } + if (running_var != nullptr) { + vec_size = std::min( + memory::can_vectorize_up_to<running_t>((char*)running_var), vec_size); + } return vec_size; }
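
Reference sketch (appended for review; not part of the patch series): the kernels above compute per-channel mean and inverse standard deviation in a single pass with Welford's algorithm. Each work-item folds samples into a private (count, mean, m2n) accumulator in the vectorized inner loop, partial accumulators are merged through shared local memory by welford_vertical_merge, and, when more than one cooperative block participates, per-block partials are merged again through the staging_data_/semaphores_ buffers. The plain host-side C++ below restates that arithmetic as a reference. It is a sketch under assumptions: the names WelfordAcc, welford_update, and welford_combine are illustrative and do not exist in the patched sources, and the inverse-std transform shown assumes the 1/sqrt(var + eps) variant that batch_norm_stats_channels_last_template instantiates via VarTransform.

// Host-side reference for the single-pass statistics used by this series.
// Assumed/illustrative names: WelfordAcc, welford_update, welford_combine.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct WelfordAcc {
  double mean = 0.0; // running mean
  double m2n = 0.0;  // sum of squared deviations from the current mean
  int count = 0;     // samples folded in so far
};

// One-sample update: same recurrence as the kernel's inner loop over input_vec[v].
inline void welford_update(WelfordAcc& a, double x) {
  a.count += 1;
  double delta0 = x - a.mean;
  a.mean += delta0 / a.count;
  double delta1 = x - a.mean;
  a.m2n += delta0 * delta1;
}

// Pairwise merge: mirrors impl::welford_merge, used both for the shared-memory
// tree reduction and for combining per-block partials from the staging buffer.
inline void welford_combine(WelfordAcc& a, const WelfordAcc& b) {
  double factor = 1.0 / std::max(1, a.count + b.count);
  double delta0 = a.mean - b.mean;
  a.mean = (b.mean * b.count + a.mean * a.count) * factor;
  a.m2n += b.m2n + delta0 * delta0 * b.count * a.count * factor;
  a.count += b.count;
}

int main() {
  // Split one channel's samples across two "blocks", reduce each partially,
  // then merge -- the same shape as the cooperative-block path.
  std::vector<double> part0 = {0.5, 1.5, 2.0};
  std::vector<double> part1 = {3.0, 4.5};
  WelfordAcc a, b;
  for (double x : part0) welford_update(a, x);
  for (double x : part1) welford_update(b, x);
  welford_combine(a, b);

  double eps = 1e-5;
  double invstd = 1.0 / std::sqrt(a.m2n / a.count + eps); // save_invstd uses N
  double unbiased_var = a.m2n / (a.count - 1);            // running_var uses N-1
  std::printf("mean=%f invstd=%f unbiased_var=%f\n", a.mean, invstd, unbiased_var);
  return 0;
}

The merge rule is what makes the two-level reduction in the kernels possible: because (count, mean, m2n) accumulators over disjoint sample sets combine exactly, several blocks can reduce different slices of the problem dimension in parallel, and the last block to finish (tracked by the semaphore) folds the staged partials into the final mean and invstd.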