Rebase BN forward to capture more fusion patterns #1654

Open
wants to merge 18 commits into base: main
264 changes: 174 additions & 90 deletions src/ATen/native/xpu/sycl/BatchNormKernels.cpp
@@ -9,7 +9,7 @@
#include <ATen/native/xpu/sycl/Loops.h>
#include <ATen/native/xpu/sycl/Reduce.h>
#include <ATen/native/xpu/sycl/ResizeKernel.h>
#include <ATen/native/xpu/sycl/WelfordNorm.h>
#include <ATen/native/xpu/sycl/WelfordNormPFImpl.h>
#include <ATen/ops/from_blob.h>
#include <ATen/xpu/XPUContext.h>
#include <comm/SYCLContext.h>
@@ -1072,93 +1072,57 @@ void batch_norm_stats_channels_last_template(
at::Tensor staging_data;
at::Tensor semaphores;

using VecKernel = WelfordBatchNormStatChannelsLastVecKernelFunctor<
VarTransform,
scalar_t,
accscalar_t,
PREFERRED_VEC_SIZE>;
auto input_ptr = input.const_data_ptr<scalar_t>();
auto out_mean_ptr = out_mean.mutable_data_ptr<accscalar_t>();
auto out_invstd_ptr = out_invstd.mutable_data_ptr<accscalar_t>();
bool use_vec_kernel = false;

if (VecKernel::valid(
reduction_size, stride, input_ptr, out_mean_ptr, out_invstd_ptr)) {
auto kfn = VecKernel(
input_ptr,
out_mean_ptr,
out_invstd_ptr,
reduction_size,
stride,
nullptr,
nullptr,
epsilon);
kfn.init();

staging_data = at::empty({(long)(kfn.staging_size())}, out_mean.options());
semaphores = at::zeros(
{(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt));
accscalar_t* staging_data_ptr = kfn.num_cooperative_groups() > 1
? staging_data.mutable_data_ptr<accscalar_t>()
: nullptr;
int* semaphores_ptr = kfn.num_cooperative_groups() > 1
? semaphores.mutable_data_ptr<int>()
: nullptr;

use_vec_kernel = kfn.set_staging_data_check(staging_data_ptr);

if (use_vec_kernel) {
kfn.set_semaphores(semaphores_ptr);
sycl_kernel_submit(
kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn);
return;
}
}

if (!use_vec_kernel) {
using KernelT = BatchNormCollectStatisticsChannelsLastKernelFunctor<
VarTransform,
scalar_t,
accscalar_t,
ELEMENTS_PER_ITER>;

auto config = get_adaptive_launch_config(
syclMaxWorkGroupSize<KernelT>(),
reduction_size,
stride,
true,
ELEMENTS_PER_WORK_ITEM);
auto global_range = std::get<0>(config);
auto local_range = std::get<1>(config);

auto wg_size_y = local_range[0];
auto wg_size_x = local_range[1];
auto nwg_y = global_range[0] / wg_size_y;
auto nwg_x = global_range[1] / wg_size_x;
if (nwg_y > 1) {
staging_data =
at::empty({(long)(4 * stride * nwg_y)}, out_mean.options());
semaphores = at::zeros({(long)nwg_x}, input.options().dtype(at::kInt));
}
accscalar_t* staging_data_ptr =
nwg_y > 1 ? staging_data.mutable_data_ptr<accscalar_t>() : nullptr;
int* semaphores_ptr =
nwg_y > 1 ? semaphores.mutable_data_ptr<int>() : nullptr;

auto kfn = KernelT(
input_ptr,
out_mean_ptr,
out_invstd_ptr,
staging_data_ptr,
semaphores_ptr,
reduction_size,
stride,
epsilon,
wg_size_y * wg_size_x);
int vec_size = welford_norm_pf_kernel_vec_size<scalar_t, accscalar_t>(
stride, input_ptr, out_mean_ptr, out_invstd_ptr);

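// DISPATCH_VEC instantiates the fused Welford kernel for one compile-time
// vector width: it builds the kernel, allocates staging/semaphore buffers
// when more than one cooperative block is needed, and submits it to the
// current SYCL queue.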
#define DISPATCH_VEC(VEC_SIZE) \
{ \
using KernelT = \
WelfordNormPFKernel<VarTransform, scalar_t, accscalar_t, VEC_SIZE>; \
auto kfn = KernelT( \
input_ptr, \
stride, \
reduction_size, \
epsilon, \
out_mean_ptr, \
out_invstd_ptr); \
kfn.init(); \
staging_data = \
at::empty({(long)(kfn.staging_size())}, out_mean.options()); \
semaphores = at::zeros( \
{(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \
accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \
? staging_data.mutable_data_ptr<accscalar_t>() \
: nullptr; \
int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \
? semaphores.mutable_data_ptr<int>() \
: nullptr; \
TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \
kfn.set_semaphores(semaphores_ptr); \
sycl_kernel_submit( \
kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \
}

sycl_kernel_submit(global_range, local_range, getCurrentSYCLQueue(), kfn);
switch (vec_size) {
case 8:
DISPATCH_VEC(8)
break;
case 4:
DISPATCH_VEC(4)
break;
case 2:
DISPATCH_VEC(2)
break;
default:
DISPATCH_VEC(1)
break;
}
}
#undef DISPATCH_VEC

std::tuple<Tensor, Tensor> batch_norm_stats_kernel(
const Tensor& self,
@@ -4040,6 +4004,104 @@ void batch_norm_mean_var(
}
}

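// Fused channels-last statistics path: computes the batch mean and invstd in
// a single Welford pass and, when running_mean/running_var are defined, folds
// the batch statistics into them in the same kernel (see set_running_mean_var
// below).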
void batch_norm_mean_var_fused_cnl(
const Tensor& input,
Tensor& out_mean,
Tensor& out_invstd,
Tensor& running_mean,
Tensor& running_var,
double momentum = 0,
double dummy_epsilon = 1e-5) {
// NOTE: Epsilon is only used for InvStd, not Var. The value here is ignored.
AT_DISPATCH_FLOATING_TYPES_AND2(
kHalf,
kBFloat16,
input.scalar_type(),
"batch_norm_mean_var_fused_cnl",
[&] {
using accscalar_t = acc_type_device<scalar_t, kXPU>;

const auto stride = input.sizes()[1];
const auto reduction_size = input.numel() / stride;

at::native::resize_output(out_mean, {stride});
at::native::resize_output(out_invstd, {stride});
TORCH_INTERNAL_ASSERT(
out_invstd.dim() == 1 && out_invstd.is_contiguous() &&
out_invstd.sizes()[0]);
TORCH_INTERNAL_ASSERT(
out_mean.dim() == 1 && out_mean.is_contiguous() &&
out_mean.sizes()[0]);

at::Tensor staging_data;
at::Tensor semaphores;

auto input_ptr = input.const_data_ptr<scalar_t>();
auto out_mean_ptr = out_mean.mutable_data_ptr<accscalar_t>();
auto out_invstd_ptr = out_invstd.mutable_data_ptr<accscalar_t>();

auto running_mean_ptr = running_mean.defined()
? running_mean.data_ptr<accscalar_t>()
: nullptr;
auto running_var_ptr = running_var.defined()
? running_var.data_ptr<accscalar_t>()
: nullptr;

int vec_size = welford_norm_pf_kernel_vec_size<scalar_t, accscalar_t>(
stride,
input_ptr,
out_mean_ptr,
out_invstd_ptr,
running_mean_ptr,
running_var_ptr);

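// Same dispatch idiom as in batch_norm_stats_channels_last_template, with the
// addition of set_running_mean_var() so the kernel also updates
// running_mean/running_var using `momentum`.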
#define DISPATCH_VEC(VEC_SIZE) \
{ \
using KernelT = \
WelfordNormPFKernel<InvStd, scalar_t, accscalar_t, VEC_SIZE>; \
auto kfn = KernelT( \
input_ptr, \
stride, \
reduction_size, \
(accscalar_t)dummy_epsilon, \
out_mean_ptr, \
out_invstd_ptr); \
kfn.init(); \
staging_data = \
at::empty({(long)(kfn.staging_size())}, out_mean.options()); \
semaphores = at::zeros( \
{(long)(kfn.semaphores_size())}, input.options().dtype(at::kInt)); \
accscalar_t* staging_data_ptr = kfn.num_cooperative_blocks() > 1 \
? staging_data.mutable_data_ptr<accscalar_t>() \
: nullptr; \
int* semaphores_ptr = kfn.num_cooperative_blocks() > 1 \
? semaphores.mutable_data_ptr<int>() \
: nullptr; \
TORCH_CHECK(kfn.set_staging_data_check(staging_data_ptr)); \
kfn.set_semaphores(semaphores_ptr); \
kfn.set_running_mean_var( \
running_mean_ptr, running_var_ptr, (accscalar_t)momentum); \
sycl_kernel_submit( \
kfn.global_range(), kfn.local_range(), getCurrentSYCLQueue(), kfn); \
}
switch (vec_size) {
case 8:
DISPATCH_VEC(8)
break;
case 4:
DISPATCH_VEC(4)
break;
case 2:
DISPATCH_VEC(2)
break;
default:
DISPATCH_VEC(1)
break;
}
});
}
#undef DISPATCH_VEC

std::tuple<Tensor, Tensor> batch_norm_update_stats_kernel(
const Tensor& self,
const std::optional<Tensor>& running_mean_opt,
@@ -4303,20 +4365,42 @@ std::tuple<Tensor&, Tensor&, Tensor&> batch_norm_kernel(
(running_var_opt.has_value() && running_var_opt->defined());
TORCH_CHECK(has_running_mean == has_running_var);

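// The fused path is limited to the channels-last implementation and requires
// contiguous (or not-yet-allocated) save_mean/save_invstd buffers.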
bool can_use_fused_kernel =
batch_norm_choose_impl(self) == Impl::ChannelsLast &&
(!save_mean.defined() || save_mean.is_contiguous()) &&
(!save_invstd.defined() || save_invstd.is_contiguous());

if (train) {
batch_norm_mean_var(self, save_mean, save_invstd);
if (has_running_mean) {
const int64_t N = self.numel() / save_mean.numel();
batch_norm_update_stats_and_invert(
if (can_use_fused_kernel && !has_running_mean) {
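// No running stats to update: pass undefined placeholder tensors.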
Tensor a;
batch_norm_mean_var_fused_cnl(
self, save_mean, save_invstd, a, a, momentum, epsilon);
} else if (
can_use_fused_kernel &&
(*running_mean_opt).dtype() == save_mean.dtype()) {
batch_norm_mean_var_fused_cnl(
self,
save_mean,
save_invstd,
*running_mean_opt,
*running_var_opt,
const_cast<Tensor&>(*running_mean_opt),
const_cast<Tensor&>(*running_var_opt),
momentum,
epsilon,
N);
epsilon);
} else {
batch_norm_calc_invstd(save_invstd, save_invstd, epsilon);
batch_norm_mean_var(self, save_mean, save_invstd);
if (has_running_mean) {
const int64_t N = self.numel() / save_mean.numel();
batch_norm_update_stats_and_invert(
save_mean,
save_invstd,
*running_mean_opt,
*running_var_opt,
momentum,
epsilon,
N);
} else {
batch_norm_calc_invstd(save_invstd, save_invstd, epsilon);
}
}
} else {
TORCH_CHECK(has_running_mean);
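For reference, the macro-plus-switch idiom that appears twice in this diff can be summarized with a small self-contained sketch. Everything in it (WelfordLikeKernel, pick_vec_size, dispatch) is a hypothetical stand-in, not code from this PR: a runtime query chooses the widest usable vector width, and the switch instantiates the templated kernel with that width as a compile-time constant.

```cpp
// All names below are hypothetical stand-ins, not code from this PR.
#include <cstdio>
#include <initializer_list>

// Stand-in for WelfordNormPFKernel<...>: the vector width is a template
// parameter, so loads/stores inside the kernel can be vectorized statically.
template <int VEC_SIZE>
struct WelfordLikeKernel {
  void launch() const { std::printf("launch with VEC_SIZE=%d\n", VEC_SIZE); }
};

// Stand-in for welford_norm_pf_kernel_vec_size: pick the widest width that
// divides the channel stride (the real helper presumably also checks pointer
// alignment and dtype).
int pick_vec_size(long stride) {
  for (int v : {8, 4, 2})
    if (stride % v == 0) return v;
  return 1;
}

// Same macro-plus-switch idiom as DISPATCH_VEC above: one instantiation per
// supported width, selected at run time.
#define DISPATCH_VEC(VEC_SIZE)       \
  {                                  \
    WelfordLikeKernel<VEC_SIZE> kfn; \
    kfn.launch();                    \
  }

void dispatch(long stride) {
  switch (pick_vec_size(stride)) {
    case 8: DISPATCH_VEC(8) break;
    case 4: DISPATCH_VEC(4) break;
    case 2: DISPATCH_VEC(2) break;
    default: DISPATCH_VEC(1) break;
  }
}
#undef DISPATCH_VEC

int main() {
  dispatch(64); // widest width: 8
  dispatch(6);  // falls back to 2
  return 0;
}
```

The point of the pattern is that the kernel body can rely on a fixed vector width at compile time while the host side still adapts to arbitrary channel strides at run time.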