diff --git a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp index 735c9615e..e6fc6a96d 100644 --- a/src/ATen/native/xpu/sycl/BatchNormKernels.cpp +++ b/src/ATen/native/xpu/sycl/BatchNormKernels.cpp @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include #include #include @@ -1072,93 +1072,57 @@ void batch_norm_stats_channels_last_template( at::Tensor staging_data; at::Tensor semaphores; - using VecKernel = WelfordBatchNormStatChannelsLastVecKernelFunctor< - VarTransform, - scalar_t, - accscalar_t, - PREFERRED_VEC_SIZE>; auto input_ptr = input.const_data_ptr(); auto out_mean_ptr = out_mean.mutable_data_ptr(); auto out_invstd_ptr = out_invstd.mutable_data_ptr(); - bool use_vec_kernel = false; - - if (VecKernel::valid( - reduction_size, stride, input_ptr, out_mean_ptr, out_invstd_ptr)) { - auto kfn = VecKernel( - input_ptr, - out_mean_ptr, - out_invstd_ptr, - reduction_size, - stride, - nullptr, - nullptr, - epsilon); - kfn.init(); - - staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options()); - semaphores = at::zeros( - {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); - accscalar_t* staging_data_ptr = kfn.num_cooperative_groups() > 1 - ? staging_data.mutable_data_ptr() - : nullptr; - int* semaphores_ptr = kfn.num_cooperative_groups() > 1 - ? semaphores.mutable_data_ptr() - : nullptr; - - use_vec_kernel = kfn.set_staging_data_check(staging_data_ptr); - - if (use_vec_kernel) { - kfn.set_semaphores(semaphores_ptr); - sycl_kernel_submit( - kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); - return; - } - } - if (!use_vec_kernel) { - using KernelT = BatchNormCollectStatisticsChannelsLastKernelFunctor< - VarTransform, - scalar_t, - accscalar_t, - ELEMENTS_PER_ITER>; - - auto config = get_adaptive_launch_config( - syclMaxWorkGroupSize(), - reduction_size, - stride, - true, - ELEMENTS_PER_WORK_ITEM); - auto global_range = std::get<0>(config); - auto local_range = std::get<1>(config); - - auto wg_size_y = local_range[0]; - auto wg_size_x = local_range[1]; - auto nwg_y = global_range[0] / wg_size_y; - auto nwg_x = global_range[1] / wg_size_x; - if (nwg_y > 1) { - staging_data = - at::empty({(long)(4 * stride * nwg_y)}, out_mean.options()); - semaphores = at::zeros({(long)nwg_x}, input.options().dtype(at::kInt)); - } - accscalar_t* staging_data_ptr = - nwg_y > 1 ? staging_data.mutable_data_ptr() : nullptr; - int* semaphores_ptr = - nwg_y > 1 ? semaphores.mutable_data_ptr() : nullptr; - - auto kfn = KernelT( - input_ptr, - out_mean_ptr, - out_invstd_ptr, - staging_data_ptr, - semaphores_ptr, - reduction_size, - stride, - epsilon, - wg_size_y * wg_size_x); + int vec_size = welford_norm_pf_kernel_vec_size( + stride, input_ptr, out_mean_ptr, out_invstd_ptr); + +#define DISPATCH_VEC(VEC_SIZE) \ + { \ + using KernelT = \ + WelfordNormPFKernel; \ + auto kfn = KernelT( \ + input_ptr, \ + stride, \ + reduction_size, \ + epsilon, \ + out_mean_ptr, \ + out_invstd_ptr); \ + kfn.init(); \ + staging_data = \ + at::empty({(long)(kfn.staging_size())}, out_mean.options()); \ + semaphores = at::zeros( \ + {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \ + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \ + ? staging_data.mutable_data_ptr() \ + : nullptr; \ + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \ + ? 
semaphores.mutable_data_ptr() \ + : nullptr; \ + TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \ + kfn.set_semaphores(semaphores_ptr); \ + sycl_kernel_submit( \ + kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \ + } - sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), kfn); + switch (vec_size) { + case 8: + DISPATCH_VEC(8) + break; + case 4: + DISPATCH_VEC(4) + break; + case 2: + DISPATCH_VEC(2) + break; + default: + DISPATCH_VEC(1) + break; } } +#undef DISPATCH_VEC std::tuple batch_norm_stats_kernel( const Tensor& self, @@ -4040,6 +4004,104 @@ void batch_norm_mean_var( } } +void batch_norm_mean_var_fused_cnl( + const Tensor& input, + Tensor& out_mean, + Tensor& out_invstd, + Tensor& running_mean, + Tensor& running_var, + double momentum = 0, + double dummy_epsilon = 1e-5) { + // NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored. + AT_DISPATCH_FLOATING_TYPES_AND2( + kHalf, + kBFloat16, + input.scalar_type(), + "batch_norm_mean_var_fused_cnl", + [&] { + using accscalar_t = acc_type_device; + + const auto stride = input.sizes()[1]; + const auto reduction_size = input.numel() / stride; + + at::native::resize_output(out_mean, {stride}); + at::native::resize_output(out_invstd, {stride}); + TORCH_INTERNAL_ASSERT( + out_invstd.dim() == 1 && out_invstd.is_contiguous() && + out_invstd.sizes()[0]); + TORCH_INTERNAL_ASSERT( + out_mean.dim() == 1 && out_mean.is_contiguous() && + out_mean.sizes()[0]); + + at::Tensor staging_data; + at::Tensor semaphores; + + auto input_ptr = input.const_data_ptr(); + auto out_mean_ptr = out_mean.mutable_data_ptr(); + auto out_invstd_ptr = out_invstd.mutable_data_ptr(); + + auto running_mean_ptr = running_mean.defined() + ? running_mean.data_ptr() + : nullptr; + auto running_var_ptr = running_var.defined() + ? running_var.data_ptr() + : nullptr; + + int vec_size = welford_norm_pf_kernel_vec_size( + stride, + input_ptr, + out_mean_ptr, + out_invstd_ptr, + running_mean_ptr, + running_var_ptr); + +#define DISPATCH_VEC(VEC_SIZE) \ + { \ + using KernelT = \ + WelfordNormPFKernel; \ + auto kfn = KernelT( \ + input_ptr, \ + stride, \ + reduction_size, \ + (accscalar_t)dummy_epsilon, \ + out_mean_ptr, \ + out_invstd_ptr); \ + kfn.init(); \ + staging_data = \ + at::empty({(long)(kfn.staging_size())}, out_mean.options()); \ + semaphores = at::zeros( \ + {(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \ + accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \ + ? staging_data.mutable_data_ptr() \ + : nullptr; \ + int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \ + ? 
semaphores.mutable_data_ptr() \ + : nullptr; \ + TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \ + kfn.set_semaphores(semaphores_ptr); \ + kfn.set_running_mean_var( \ + running_mean_ptr, running_var_ptr, (accscalar_t)momentum); \ + sycl_kernel_submit( \ + kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \ + } + switch (vec_size) { + case 8: + DISPATCH_VEC(8) + break; + case 4: + DISPATCH_VEC(4) + break; + case 2: + DISPATCH_VEC(2) + break; + default: + DISPATCH_VEC(1) + break; + } + }); +} +#undef DISPATCH_VEC + std::tuple batch_norm_update_stats_kernel( const Tensor& self, const std::optional& running_mean_opt, @@ -4303,20 +4365,42 @@ std::tuple batch_norm_kernel( (running_var_opt.has_value() && running_var_opt->defined()); TORCH_CHECK(has_running_mean == has_running_var); + bool can_use_fused_kernel = + batch_norm_choose_impl(self) == Impl::ChannelsLast && + (!save_mean.defined() || save_mean.is_contiguous()) && + (!save_invstd.defined() || save_invstd.is_contiguous()); + if (train) { - batch_norm_mean_var(self, save_mean, save_invstd); - if (has_running_mean) { - const int64_t N = self.numel() / save_mean.numel(); - batch_norm_update_stats_and_invert( + if (can_use_fused_kernel && !has_running_mean) { + Tensor a; + batch_norm_mean_var_fused_cnl( + self, save_mean, save_invstd, a, a, momentum, epsilon); + } else if ( + can_use_fused_kernel && + (*running_mean_opt).dtype() == save_mean.dtype()) { + batch_norm_mean_var_fused_cnl( + self, save_mean, save_invstd, - *running_mean_opt, - *running_var_opt, + const_cast(*running_mean_opt), + const_cast(*running_var_opt), momentum, - epsilon, - N); + epsilon); } else { - batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); + batch_norm_mean_var(self, save_mean, save_invstd); + if (has_running_mean) { + const int64_t N = self.numel() / save_mean.numel(); + batch_norm_update_stats_and_invert( + save_mean, + save_invstd, + *running_mean_opt, + *running_var_opt, + momentum, + epsilon, + N); + } else { + batch_norm_calc_invstd(save_invstd, save_invstd, epsilon); + } } } else { TORCH_CHECK(has_running_mean); diff --git a/src/ATen/native/xpu/sycl/WelfordNorm.h b/src/ATen/native/xpu/sycl/WelfordNorm.h deleted file mode 100644 index d9d7720f6..000000000 --- a/src/ATen/native/xpu/sycl/WelfordNorm.h +++ /dev/null @@ -1,327 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -namespace at::native::xpu { - -template -inline T divup(T a, T b) { - return (a + b - 1) / b; -} - -std::tuple get_adaptive_config( - const int reduction, - const int n_channels, - const int vec_size, - int max_wg_size, - int loops_per_item = 8) { - loops_per_item /= vec_size; - int group_size_x = std::min(last_pow2(n_channels / vec_size), 32); - int group_size_y = std::min( - last_pow2(divup(reduction, loops_per_item)), max_wg_size / group_size_x); - if (group_size_x * group_size_y != max_wg_size) { - group_size_x = - std::min(last_pow2(n_channels / vec_size), max_wg_size / group_size_y); - } - - int nwg_x = divup(n_channels, group_size_x * vec_size); - int nwg_y = std::min( - divup(reduction, group_size_y * loops_per_item), - int(syclMaxWorkItemsPerTile()) / (nwg_x * group_size_x) / (group_size_y)); - nwg_y = std::max(nwg_y, 1); - - // it's not worth having reduction between work groups if the reduction - // dimension is not big enough - nwg_y = nwg_y < 4 ? 
1 : nwg_y; - - return std::make_tuple(group_size_y, group_size_x, nwg_y, nwg_x); -} - -template -inline void welford_merge( - C& count, - T& mean, - T& m2n, - const C& count_new, - const T& mean_new, - const T& m2n_new) { - T factor = T(1.0) / std::max(1, (count + count_new)); - T delta0 = mean - mean_new; - mean = (mean_new * count_new + mean * count) * factor; - m2n += m2n_new + delta0 * delta0 * count_new * count * factor; - count += count_new; -} - -template -inline void welford_vertical_merge( - sycl::nd_item<2>& item, - C& count, - T& mean, - T& m2n, - CACC& shmem_count, - TACC& shmem_mean, - TACC& shmem_m2n) { - // write to shared memory - auto address_base = item.get_local_linear_id(); -#pragma unroll - for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { - if (item.get_local_id(0) < offset * 2) { - shmem_mean[address_base] = mean; - shmem_m2n[address_base] = m2n; - shmem_count[address_base] = count; - } - item.barrier(sycl_local_fence); - if (item.get_local_id(0) < offset && - item.get_local_id(0) + offset < item.get_local_range(0)) { - auto address = address_base + offset * item.get_local_range(1); - // read shared memory back to register for reduction - auto count_new = shmem_count[address]; - auto mean_new = shmem_mean[address]; - auto m2n_new = shmem_m2n[address]; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - welford_merge( - count[v], mean[v], m2n[v], count_new[v], mean_new[v], m2n_new[v]); - } - } - } -} - -template < - typename VarTransform, - typename scalar_t, - typename acc_t, - int VEC_SIZE = 2> -struct WelfordBatchNormStatChannelsLastVecKernelFunctor - : public __SYCL_KER_CONFIG_CONVENTION__ { - using vec_t = memory::aligned_vector; - using acc_vec_t = memory::aligned_vector; - using int_vec_t = memory::aligned_vector; - - void operator()(sycl::nd_item<2> item) const { - // init private counter - acc_vec_t mean; - acc_vec_t m2n; - int_vec_t count; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - mean[v] = acc_t(0); - m2n[v] = acc_t(0); - count[v] = int(0); - } - - int gy = item.get_group(0); - int gx = item.get_group(1); - int c_vec_offset = item.get_global_id(1) * VEC_SIZE; - int num_cooperative_groups = item.get_group_range(0); - int inner_loop_stride = item.get_local_range(0) * num_cooperative_groups; - - for (int m_offset = item.get_global_id(0); m_offset < reduction_size_; - m_offset += inner_loop_stride) { - if (c_vec_offset < n_channels_) { - int address_vec_base = m_offset * n_channels_ + c_vec_offset; - auto input_vec = *reinterpret_cast( - const_cast(&input_[address_vec_base])); -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - auto x = input_vec[v]; - count[v]++; - acc_t delta0 = x - mean[v]; - mean[v] += delta0 / count[v]; - acc_t delta1 = x - mean[v]; - m2n[v] += delta0 * delta1; - } - } - } - - welford_vertical_merge( - item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); - - // welford vertical merge - if (num_cooperative_groups > 1) { - acc_t* staging_mean = staging_data_; - acc_t* staging_m2n = &staging_data_[n_channels_ * num_cooperative_groups]; - int* staging_count = reinterpret_cast( - &staging_m2n[n_channels_ * num_cooperative_groups]); - int address_vec_base = c_vec_offset + gy * n_channels_; - - // write data to staging_data; - if (item.get_local_id(0) == 0 && c_vec_offset < n_channels_) { - *reinterpret_cast(&staging_mean[address_vec_base]) = mean; - *reinterpret_cast(&staging_m2n[address_vec_base]) = m2n; - *reinterpret_cast(&staging_count[address_vec_base]) = count; - } - 
item.barrier(sycl_local_fence); - - // mark group done - if (item.get_local_linear_id() == 0) { - sycl_atomic_ref_rlx_dev_global_t atomic_count(semaphores_[gx]); - int old = atomic_count.fetch_add( - 1, sycl_mem_odr_acq_rel - /* , default memory scope is device */); - is_last_group_done_[0] = (old == (num_cooperative_groups - 1)); - } - item.barrier(sycl_local_fence); - - // check that all data is now available in global memory - if (is_last_group_done_[0]) { -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - mean[v] = acc_t(0); - m2n[v] = acc_t(0); - count[v] = int(0); - } - - for (int y = item.get_local_id(0); y < num_cooperative_groups; - y += item.get_local_range(0)) { - if (c_vec_offset < n_channels_) { - address_vec_base = y * n_channels_ + c_vec_offset; - auto mean_new = - *reinterpret_cast(&staging_mean[address_vec_base]); - auto m2n_new = - *reinterpret_cast(&staging_m2n[address_vec_base]); - auto count_new = - *reinterpret_cast(&staging_count[address_vec_base]); -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - welford_merge( - count[v], - mean[v], - m2n[v], - count_new[v], - mean_new[v], - m2n_new[v]); - } - } - } - welford_vertical_merge( - item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); - } - } - - if (item.get_local_id(0) == 0 && - (num_cooperative_groups == 1 || is_last_group_done_[0]) && - c_vec_offset < n_channels_) { - acc_vec_t invstd_vec; -#pragma unroll - for (int v = 0; v < VEC_SIZE; ++v) { - invstd_vec[v] = VarTransform{}(m2n[v] / count[v], epsilon_); - } - - *reinterpret_cast(&save_mean_[c_vec_offset]) = mean; - *reinterpret_cast(&save_invstd_[c_vec_offset]) = invstd_vec; - } - } - - void sycl_ker_config_convention(sycl::handler& cgh) { - auto local_size = group_size_x_ * group_size_y_; - shmem_mean_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); - shmem_m2n_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); - shmem_count_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); - is_last_group_done_ = sycl_local_acc_t(sycl::range<1>(1), cgh); - } - - WelfordBatchNormStatChannelsLastVecKernelFunctor( - const scalar_t* input, - acc_t* save_mean, - acc_t* save_invstd, - int reduction_size, - int n_channels, - acc_t* staging_data, - int* semaphores, - double epsilon) - : input_(input), - save_mean_(save_mean), - save_invstd_(save_invstd), - reduction_size_(reduction_size), - n_channels_(n_channels), - staging_data_(staging_data), - semaphores_(semaphores), - epsilon_(epsilon) {} - - void init() { - using KernelT = WelfordBatchNormStatChannelsLastVecKernelFunctor< - VarTransform, - scalar_t, - acc_t, - VEC_SIZE>; - auto max_group_size = syclMaxWorkGroupSize(); - std::tie(group_size_y_, group_size_x_, ngroups_y_, ngroups_x_) = - get_adaptive_config( - reduction_size_, n_channels_, VEC_SIZE, max_group_size); - } - - static bool valid( - int reduction_size, - int n_channels, - const scalar_t* input, - acc_t* save_mean, - acc_t* save_invstd) { - bool valid = sizeof(scalar_t) <= 2; - valid = valid && (n_channels % VEC_SIZE == 0); - valid = valid && - (memory::can_vectorize_up_to((char*)input) >= VEC_SIZE); - valid = valid && - (memory::can_vectorize_up_to((char*)save_mean) >= VEC_SIZE); - valid = valid && - (memory::can_vectorize_up_to((char*)save_invstd) >= VEC_SIZE); - return valid; - } - - sycl::range<2> local_range() const { - return sycl::range<2>(group_size_y_, group_size_x_); - } - - sycl::range<2> global_range() const { - return sycl::range<2>( - group_size_y_ * ngroups_y_, group_size_x_ * ngroups_x_); - } - - int staging_size() 
const { - return ngroups_y_ * n_channels_ * 4; - } - - int semaphores_size() const { - return ngroups_x_; - } - - bool set_staging_data_check(acc_t* staging_data) { - staging_data_ = staging_data; - return ( - (staging_data == nullptr) || - (memory::can_vectorize_up_to((char*)staging_data) >= VEC_SIZE)); - } - - void set_semaphores(int* semaphores) { - semaphores_ = semaphores; - } - - int num_cooperative_groups() const { - return ngroups_y_; - } - - private: - const scalar_t* input_; - acc_t* save_mean_; - acc_t* save_invstd_; - int reduction_size_; - int n_channels_; - acc_t* staging_data_; - int* semaphores_; - double epsilon_; - - size_t group_size_y_; - size_t group_size_x_; - size_t ngroups_y_; - size_t ngroups_x_; - - sycl_local_acc_t shmem_mean_; - sycl_local_acc_t shmem_m2n_; - sycl_local_acc_t shmem_count_; - sycl_local_acc_t is_last_group_done_; -}; - -} // namespace at::native::xpu diff --git a/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h new file mode 100644 index 000000000..1c29f4f4f --- /dev/null +++ b/src/ATen/native/xpu/sycl/WelfordNormPFImpl.h @@ -0,0 +1,420 @@ +#pragma once + +#include +#include +#include +#include + +namespace at::native::xpu { + +namespace impl { + +// returns floor(log2(n)) +inline int last_pow2(int n) { + n |= (n >> 1); + n |= (n >> 2); + n |= (n >> 4); + n |= (n >> 8); + n |= (n >> 16); + return std::max(1, n - (n >> 1)); +} + +template +inline T divup(T a, T b) { + return (a + b - 1) / b; +} + +std::tuple get_adaptive_config( + const int problem_size, + const int batch_size, + const int vec_size, + const int max_block_size, + int loops_per_thread = 8, + int coop_th = 16) { + loops_per_thread /= + vec_size; // Ensure the number of instructions is normalized + int threads_along_batch = last_pow2(batch_size / vec_size); + int threads_along_problem = last_pow2(divup(problem_size, loops_per_thread)); + + int block_size_x = std::min(threads_along_batch, 32); + int block_size_y = + std::min(threads_along_problem, max_block_size / block_size_x); + if (block_size_x * block_size_y != max_block_size) { + block_size_x = std::min(threads_along_batch, max_block_size / block_size_y); + } + + int max_threads_gpu = syclMaxDSSNum() * syclMaxWorkItemsPerSubSlice(); + int nblock_x = divup(batch_size, block_size_x * vec_size); + int nblock_y = std::min( + divup(problem_size, block_size_y * loops_per_thread), + max_threads_gpu / (nblock_x * block_size_x) / (block_size_y)); + nblock_y = std::max(nblock_y, 1); + + // it's not worth having reduction between blocks if the reduction + // dimension is not big enough + coop_th /= vec_size; + nblock_y = nblock_y < coop_th ? 
1 : nblock_y; + + return std::make_tuple(block_size_y, block_size_x, nblock_y, nblock_x); +} + +template +inline void welford_merge( + C& count, + T& mean, + T& m2n, + const C& count_new, + const T& mean_new, + const T& m2n_new) { + T factor = T(1.0) / std::max(1, (count + count_new)); + T delta0 = mean - mean_new; + mean = (mean_new * count_new + mean * count) * factor; + m2n += m2n_new + delta0 * delta0 * count_new * count * factor; + count += count_new; +} + +} // namespace impl + +template +int welford_norm_pf_kernel_vec_size( + int batch_size, + const scalar_t* input, + acc_t* save_mean, + acc_t* save_invstd, + running_t* running_mean = nullptr, + running_t* running_var = nullptr, + int max_vec_bytes = 8) { + if (sizeof(scalar_t) >= max_vec_bytes) + return 1; + int vec_size = max_vec_bytes / sizeof(scalar_t); + + auto input_vec_size = memory::can_vectorize_up_to((char*)input); + auto save_mean_vec_size = + memory::can_vectorize_up_to((char*)save_mean); + auto save_invstd_vec_size = + memory::can_vectorize_up_to((char*)save_invstd); + + while ( + !(batch_size % vec_size == 0 && input_vec_size >= vec_size && + save_mean_vec_size >= vec_size && save_invstd_vec_size >= vec_size)) { + vec_size >>= 1; + } + if (running_mean != nullptr) { + vec_size = std::min( + memory::can_vectorize_up_to((char*)running_mean), vec_size); + } + if (running_var != nullptr) { + vec_size = std::min( + memory::can_vectorize_up_to((char*)running_var), vec_size); + } + return vec_size; +} + +template < + typename VarTransform, + typename scalar_t, + typename acc_t, + int VEC_SIZE, + typename running_mean_t = acc_t> +struct WelfordNormPFKernel : public __SYCL_KER_CONFIG_CONVENTION__ { + using vec_t = memory::aligned_vector; + using acc_vec_t = memory::aligned_vector; + using int_vec_t = memory::aligned_vector; + using running_mean_vec_t = memory::aligned_vector; + + void operator()(sycl::nd_item<2> item) const { + // init welford counters + acc_vec_t mean; + acc_vec_t m2n; + int_vec_t count; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + mean[v] = acc_t(0); + m2n[v] = acc_t(0); + count[v] = int(0); + } + + int bx = item.get_group(1); // along batch dim + int by = item.get_group(0); // along problem dim + int batch_vec_offset = item.get_global_id(1) * VEC_SIZE; + int num_cooperative_blocks = item.get_group_range(0); + int inner_loop_stride = item.get_local_range(0) * num_cooperative_blocks; + + if (batch_vec_offset < batch_size_) { + for (int p_offset = item.get_global_id(0); p_offset < problem_size_; + p_offset += inner_loop_stride) { + int address_vec_base = p_offset * batch_size_ + batch_vec_offset; + auto input_vec = *reinterpret_cast( + const_cast(&input_[address_vec_base])); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + auto x = input_vec[v]; + count[v]++; + acc_t delta0 = x - mean[v]; + mean[v] += delta0 / count[v]; + acc_t delta1 = x - mean[v]; + m2n[v] += delta0 * delta1; + } + } + } + + welford_vertical_merge( + item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); + + // welford vertical merge + if (num_cooperative_blocks > 1) { + acc_t* staging_mean = staging_data_; + acc_t* staging_m2n = &staging_data_[batch_size_ * num_cooperative_blocks]; + int* staging_count = reinterpret_cast( + &staging_m2n[batch_size_ * num_cooperative_blocks]); + int address_vec_base = batch_vec_offset + by * batch_size_; + + // write data to staging_data; + if (item.get_local_id(0) == 0 && batch_vec_offset < batch_size_) { + *reinterpret_cast(&staging_mean[address_vec_base]) = mean; + 
*reinterpret_cast(&staging_m2n[address_vec_base]) = m2n; + *reinterpret_cast(&staging_count[address_vec_base]) = count; + } + item.barrier(sycl_local_fence); + + // mark block done + if (item.get_local_linear_id() == 0) { + sycl_atomic_ref_rlx_dev_global_t atomic_count(semaphores_[bx]); + int old = atomic_count.fetch_add( + 1, sycl_mem_odr_acq_rel + /* , default memory scope is device */); + is_last_block_done_[0] = (old == (num_cooperative_blocks - 1)); + } + item.barrier(sycl_local_fence); + + // check that all data is now available in global memory + if (is_last_block_done_[0]) { +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + mean[v] = acc_t(0); + m2n[v] = acc_t(0); + count[v] = int(0); + } + + for (int y = item.get_local_id(0); y < num_cooperative_blocks; + y += item.get_local_range(0)) { + if (batch_vec_offset < batch_size_) { + address_vec_base = y * batch_size_ + batch_vec_offset; + auto mean_new = + *reinterpret_cast(&staging_mean[address_vec_base]); + auto m2n_new = + *reinterpret_cast(&staging_m2n[address_vec_base]); + auto count_new = + *reinterpret_cast(&staging_count[address_vec_base]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + impl::welford_merge( + count[v], + mean[v], + m2n[v], + count_new[v], + mean_new[v], + m2n_new[v]); + } + } + } + welford_vertical_merge( + item, count, mean, m2n, shmem_count_, shmem_mean_, shmem_m2n_); + } + } + + if (item.get_local_id(0) == 0 && + (num_cooperative_blocks == 1 || is_last_block_done_[0]) && + batch_vec_offset < batch_size_) { + acc_vec_t invstd_vec; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + invstd_vec[v] = VarTransform{}(m2n[v] / count[v], epsilon_); + } + *reinterpret_cast(&save_mean_[batch_vec_offset]) = mean; + *reinterpret_cast(&save_invstd_[batch_vec_offset]) = + invstd_vec; + + if (running_mean_ != nullptr) { + auto running_mean_vec = *reinterpret_cast( + &running_mean_[batch_vec_offset]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + running_mean_vec[v] = + mean[v] * momentum_ + (1 - momentum_) * running_mean_vec[v]; + } + *reinterpret_cast( + &running_mean_[batch_vec_offset]) = running_mean_vec; + } + + if (running_var_ != nullptr) { + auto running_var_vec = *reinterpret_cast( + &running_var_[batch_vec_offset]); +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + auto unbiased_var = m2n[v] / (count[v] - 1); + running_var_vec[v] = + unbiased_var * momentum_ + (1 - momentum_) * running_var_vec[v]; + } + *reinterpret_cast( + &running_var_[batch_vec_offset]) = running_var_vec; + } + } + } + + template + inline void welford_vertical_merge( + sycl::nd_item<2>& item, + int_vec_t& count, + acc_vec_t& mean, + acc_vec_t& m2n, + CACC& shmem_count, + TACC& shmem_mean, + TACC& shmem_m2n) const { + // write to shared memory + auto address_base = item.get_local_linear_id(); +#pragma unroll + for (int offset = item.get_local_range(0) / 2; offset > 0; offset >>= 1) { + if (item.get_local_id(0) < offset * 2) { + shmem_mean[address_base] = mean; + shmem_m2n[address_base] = m2n; + shmem_count[address_base] = count; + } + item.barrier(sycl_local_fence); + if (item.get_local_id(0) < offset && + item.get_local_id(0) + offset < item.get_local_range(0)) { + auto address = address_base + offset * item.get_local_range(1); + // read shared memory back to register for reduction + auto count_new = shmem_count[address]; + auto mean_new = shmem_mean[address]; + auto m2n_new = shmem_m2n[address]; +#pragma unroll + for (int v = 0; v < VEC_SIZE; ++v) { + impl::welford_merge( + count[v], mean[v], m2n[v], 
count_new[v], mean_new[v], m2n_new[v]); + } + } + } + } + + void sycl_ker_config_convention(sycl::handler& cgh) { + auto local_size = block_size_x_ * block_size_y_; + shmem_mean_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); + shmem_m2n_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); + shmem_count_ = sycl_local_acc_t(sycl::range<1>(local_size), cgh); + is_last_block_done_ = sycl_local_acc_t(sycl::range<1>(1), cgh); + } + + void init() { + using KernelT = + WelfordNormPFKernel; + auto max_group_size = syclMaxWorkGroupSize(); + std::tie(block_size_y_, block_size_x_, nblocks_y_, nblocks_x_) = + impl::get_adaptive_config( + problem_size_, batch_size_, VEC_SIZE, max_group_size); + } + + static bool valid( + int batch_size, + int problem_size, + const scalar_t* input, + acc_t* save_mean, + acc_t* save_invstd) { + if (VEC_SIZE <= 1) + return true; + bool valid = sizeof(scalar_t) <= 4; + valid = valid && (batch_size % VEC_SIZE == 0); + valid = valid && + (memory::can_vectorize_up_to((char*)input) >= VEC_SIZE); + valid = valid && + (memory::can_vectorize_up_to((char*)save_mean) >= VEC_SIZE); + valid = valid && + (memory::can_vectorize_up_to((char*)save_invstd) >= VEC_SIZE); + return valid; + } + + sycl::range<2> local_range() const { + return sycl::range<2>(block_size_y_, block_size_x_); + } + + sycl::range<2> global_range() const { + return sycl::range<2>( + block_size_y_ * nblocks_y_, block_size_x_ * nblocks_x_); + } + + int staging_size() const { + return nblocks_y_ * batch_size_ * 4; + } + + int semaphores_size() const { + return nblocks_x_; + } + + bool set_staging_data_check(acc_t* staging_data) { + staging_data_ = staging_data; + return ( + (staging_data == nullptr) || + (memory::can_vectorize_up_to((char*)staging_data) >= VEC_SIZE)); + } + + void set_semaphores(int* semaphores) { + semaphores_ = semaphores; + } + + void set_running_mean_var( + running_mean_t* running_mean, + running_mean_t* running_var, + acc_t momentum) { + running_mean_ = running_mean; + running_var_ = running_var; + momentum_ = momentum; + } + + int num_cooperative_blocks() const { + return nblocks_y_; + } + + WelfordNormPFKernel( + const scalar_t* input, + int batch_size, + int problem_size, + acc_t epsilon, + acc_t* save_mean, + acc_t* save_invstd) + : input_(input), + batch_size_(batch_size), + problem_size_(problem_size), + epsilon_(epsilon), + save_mean_(save_mean), + save_invstd_(save_invstd), + staging_data_(nullptr), + semaphores_(nullptr), + running_mean_(nullptr), + running_var_(nullptr) {} + + private: + const scalar_t* input_; + int batch_size_; + int problem_size_; + acc_t epsilon_; + acc_t* save_mean_; + acc_t* save_invstd_; + acc_t* staging_data_; + int* semaphores_; + + running_mean_t* running_mean_; + running_mean_t* running_var_; + acc_t momentum_; + + size_t block_size_y_; + size_t block_size_x_; + size_t nblocks_y_; + size_t nblocks_x_; + + sycl_local_acc_t shmem_mean_; + sycl_local_acc_t shmem_m2n_; + sycl_local_acc_t shmem_count_; + sycl_local_acc_t is_last_block_done_; +}; + +} // namespace at::native::xpu
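
Reviewer note: the merge rule in impl::welford_merge and the running-stat update in WelfordNormPFKernel can be checked on the host without any SYCL machinery. Below is a minimal, self-contained sketch (plain C++, hypothetical names; the InvStd transform 1/sqrt(var + eps) is assumed from the usual batch_norm_stats usage, not taken from this patch) that reproduces the same algebra: per-element Welford accumulation, pairwise merge of two partial accumulators, and the momentum update applied to running_mean / running_var.

// Host-side sketch of the Welford merge used by impl::welford_merge and the
// running-stat update in WelfordNormPFKernel. Names here are illustrative.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct WelfordAcc {
  int count = 0;
  double mean = 0.0;
  double m2n = 0.0;  // sum of squared deviations from the current mean

  // Classic single-pass Welford update, mirroring the per-vector-lane loop
  // in WelfordNormPFKernel::operator().
  void push(double x) {
    ++count;
    double delta0 = x - mean;
    mean += delta0 / count;
    m2n += delta0 * (x - mean);
  }
};

// Same algebra as impl::welford_merge in WelfordNormPFImpl.h.
void welford_merge(WelfordAcc& a, const WelfordAcc& b) {
  double factor = 1.0 / std::max(1, a.count + b.count);
  double delta0 = a.mean - b.mean;
  a.mean = (b.mean * b.count + a.mean * a.count) * factor;
  a.m2n += b.m2n + delta0 * delta0 * b.count * a.count * factor;
  a.count += b.count;
}

int main() {
  std::vector<double> xs = {1.0, 2.0, 4.0, 8.0, 16.0, 32.0};

  // Two partial accumulators, as if two cooperating blocks each reduced
  // half of the problem dimension before the cross-block merge.
  WelfordAcc a, b;
  for (size_t i = 0; i < xs.size() / 2; ++i) a.push(xs[i]);
  for (size_t i = xs.size() / 2; i < xs.size(); ++i) b.push(xs[i]);
  welford_merge(a, b);

  const double eps = 1e-5;
  double biased_var = a.m2n / a.count;          // feeds save_invstd
  double invstd = 1.0 / std::sqrt(biased_var + eps);  // assumed InvStd transform
  double unbiased_var = a.m2n / (a.count - 1);  // feeds running_var

  // Momentum update as in set_running_mean_var / the running_* writes.
  double momentum = 0.1, running_mean = 0.0, running_var = 1.0;
  running_mean = a.mean * momentum + (1 - momentum) * running_mean;
  running_var = unbiased_var * momentum + (1 - momentum) * running_var;

  std::printf("mean=%f invstd=%f running_mean=%f running_var=%f\n",
              a.mean, invstd, running_mean, running_var);
  return 0;
}

The merged (count, mean, m2n) triple is identical to what a single accumulator over all elements would produce, which is why the kernel can reduce within a work-group, stage partial triples in global memory, and have the last block fold them together without changing the result.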