diff --git a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
index 7ac78911b..d0a39ab2f 100644
--- a/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
+++ b/src/ATen/native/xpu/sycl/AdaptiveAveragePooling2dKernels.cpp
@@ -1,35 +1,40 @@
 #include
 #include
+#include
 #include
 #include
+#include
 #include
 #include
 #include
 #include
 
+#define START_IND(a, b, c) ((int64_t)((a / b) * c + ((a % b) * c) / b))
+#define END_IND(a, b, c) (1 + ((int64_t)(a + 1) * c - 1) / b)
+
+#define START_IND_INT(a, b, c) ((a * c) / b)
+#define END_IND_INT(a, b, c) (((a + 1) * c + b - 1) / b)
+
+#define XPU_MAX_THREADS 1024
+#define GROUP_STRIDE 2 // increasing group_stride to lower # of groups launched
+
 namespace at::native::xpu {
 
 using namespace at::xpu;
 
-template <typename scalar_t, typename opmath_t, bool is_channels_last>
+template <typename scalar_t, typename opmath_t>
 struct AdaptiveAvgPool2dBwdKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
     int64_t gi = item.get_global_linear_id();
 
     for (int64_t i = gi; i < numel_; i += global_range_) {
       int64_t _iw, _ih, _ic, _ib;
-      if constexpr (is_channels_last) {
-        _ic = i % ic_;
-        _iw = i / ic_ % iw_;
-        _ih = i / ic_ / iw_ % ih_;
-        _ib = i / ic_ / iw_ / ih_;
-      } else {
-        _iw = i % iw_;
-        _ih = i / iw_ % ih_;
-        _ic = i / iw_ / ih_ % ic_;
-        _ib = i / iw_ / ih_ / ic_;
-      }
+
+      _iw = i % iw_;
+      _ih = i / iw_ % ih_;
+      _ic = i / iw_ / ih_ % ic_;
+      _ib = i / iw_ / ih_ / ic_;
 
       int64_t _oh0 = native::start_index(_ih, ih_, oh_);
       int64_t _oh1 = native::end_index(_ih, ih_, oh_);
@@ -101,7 +106,7 @@ struct AdaptiveAvgPool2dBwdKernelFunctor {
   PackedTensorAccessor64<scalar_t, 4> gxacc_;
 };
 
-template <typename scalar_t, typename opmath_t, bool is_channels_last>
+template <typename scalar_t, typename opmath_t>
 struct AdaptiveAvgPool2dBwdSLMKernelFunctor
     : public __SYCL_KER_CONFIG_CONVENTION__ {
   void operator()(sycl::nd_item<1> item) const {
@@ -133,17 +138,11 @@ struct AdaptiveAvgPool2dBwdSLMKernelFunctor
     for (int64_t i = gi; i < numel_; i += global_range_) {
       int64_t _iw, _ih, _ic, _ib;
-      if constexpr (is_channels_last) {
-        _ic = i % ic_;
-        _iw = i / ic_ % iw_;
-        _ih = i / ic_ / iw_ % ih_;
-        _ib = i / ic_ / iw_ / ih_;
-      } else {
-        _iw = i % iw_;
-        _ih = i / iw_ % ih_;
-        _ic = i / iw_ / ih_ % ic_;
-        _ib = i / iw_ / ih_ / ic_;
-      }
+
+      _iw = i % iw_;
+      _ih = i / iw_ % ih_;
+      _ic = i / iw_ / ih_ % ic_;
+      _ib = i / iw_ / ih_ / ic_;
 
       int64_t _oh0, _oh1, _ow0, _ow1;
       _oh0 = _oh0_cached_[_ih];
@@ -230,6 +229,175 @@ struct AdaptiveAvgPool2dBwdSLMKernelFunctor
   sycl_local_acc_t<opmath_t> _ikw_cached_;
 };
 
+template <typename index_t, typename scalar_t>
+struct AdaptiveAvgPool2dBwdSLMChannelsLastKernelFunctor
+    : public __SYCL_KER_CONFIG_CONVENTION__ {
+  void operator()(sycl::nd_item<3> item) const {
+    scalar_t* out_cached =
+        (scalar_t*)out_cached_
+            .template get_multi_ptr<sycl::access::decorated::no>()
+            .get();
+    // flattening cta for pre-computation & smem initialization;
+    int thread_id = item.get_local_id(2) +
+        item.get_local_range(2) *
+            (item.get_local_id(1) +
+             item.get_local_range(1) * item.get_local_id(0));
+    // Precompute output start/end index per input index on width dimension;
+    // Not doing this for height dimension, as that's our outer-most loop.
+    for (index_t i = thread_id; i < isizeW_; i += group_size_) {
+      ostartW_cached_[i] = START_IND_INT(i, isizeW_, osizeW_);
+      oendW_cached_[i] = END_IND_INT(i, isizeW_, osizeW_);
+    }
+
+    // Precompute pooling height/width factor for each output element;
+    // this is used to weight the output gradient when accumulating it on the
+    // input gradient.
+    for (index_t i = thread_id; i < osizeH_; i += group_size_) {
+      r_kH_cached_[i] = scalar_t(1.0) /
+          (END_IND_INT(i, osizeH_, isizeH_) -
+           START_IND_INT(i, osizeH_, isizeH_));
+    }
+    for (index_t i = thread_id; i < osizeW_; i += group_size_) {
+      r_kW_cached_[i] = scalar_t(1.0) /
+          (END_IND_INT(i, osizeW_, isizeW_) -
+           START_IND_INT(i, osizeW_, isizeW_));
+    }
+
+    // each cta handles a portion of a single slice on batch dimension;
+    // we use get_group_range(2) to handle striding on C as well.
+    int batch_id = item.get_group(2) % sizeB_;
+    int channel_id = item.get_group(2) / sizeB_;
+    int channel_offset =
+        item.get_local_id(2) + channel_id * item.get_local_range(2);
+
+    // use shared memory to store temporary output value. This is simply to
+    // reduce register usage.
+    for (index_t i = thread_id; i < kernel_size_C_ * item.get_local_range(2) *
+             item.get_local_range(1) * item.get_local_range(0);
+         i += group_size_) {
+      out_cached[i] = scalar_t(0.0);
+    }
+
+    item.barrier(sycl_local_fence);
+
+    auto gradInput = gradInput_ + batch_id * isizeH_ * isizeW_ * sizeC_;
+    auto gradOutput = gradOutput_ + batch_id * ostrideB_;
+
+    // split out_cached and assign it exclusively to each thread;
+    out_cached = &out_cached
+        [(item.get_local_id(0) * item.get_local_range(1) +
+          item.get_local_id(1)) *
+         item.get_local_range(2) * kernel_size_C_];
+
+    // iterate on input H & W.
+    // each cta handles a consecutive H & W section (TILE); do NOT stride the
+    // cta on tiles so there's a better chance to hit L1 cache.
+    index_t iH =
+        (isizeH_ + item.get_group_range(0) - 1) / item.get_group_range(0);
+    index_t iW =
+        (isizeW_ + item.get_group_range(1) - 1) / item.get_group_range(1);
+    index_t istartH = item.get_local_id(0) + item.get_group(0) * iH;
+    index_t iendH = std::min(istartH + iH, isizeH_);
+    index_t istartW = item.get_local_id(1) + item.get_group(1) * iW;
+    index_t iendW = std::min(istartW + iW, isizeW_);
+
+    // stride the threads so each subgroup can reuse L1 as it goes, which
+    // theoretically gives a better chance to survive cache eviction.
+    for (index_t ih = istartH; ih < iendH; ih += item.get_local_range(0)) {
+      index_t ostartH = START_IND_INT(ih, isizeH_, osizeH_);
+      index_t oendH = END_IND_INT(ih, isizeH_, osizeH_);
+      for (index_t iw = istartW; iw < iendW; iw += item.get_local_range(1)) {
+        // loop on output: hierarchy h->w->c, so we can reuse the weight
+        // factor f, because it remains the same for a given oh & ow
+        for (index_t oh = ostartH; oh < oendH; ++oh) {
+          for (index_t ow = ostartW_cached_[iw]; ow < oendW_cached_[iw]; ++ow) {
+            scalar_t f = r_kW_cached_[ow] * r_kH_cached_[oh];
+            const scalar_t* ptr_gradOutput =
+                gradOutput + oh * ostrideH_ + ow * ostrideW_;
+            int cached_index = item.get_local_id(2);
+            for (index_t c = channel_offset; c < sizeC_;
+                 c += item.get_local_range(2) * kernel_stride_C_) {
+              out_cached[cached_index] += ptr_gradOutput[c * ostrideC_] * f;
+              cached_index += item.get_local_range(2);
+            }
+          }
+        }
+        scalar_t* ptr_gradInput = gradInput + (ih * isizeW_ + iw) * sizeC_;
+        int cached_index = item.get_local_id(2);
+        // write accumulated gradInput to global memory;
+        for (index_t c = channel_offset; c < sizeC_;
+             c += item.get_local_range(2) * kernel_stride_C_) {
+          ptr_gradInput[c] = out_cached[cached_index];
+          out_cached[cached_index] = scalar_t(0.0);
+          cached_index += item.get_local_range(2);
+        }
+      }
+    }
+  }
+
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    ostartW_cached_ = sycl_local_acc_t<index_t>(isizeW_, cgh);
+    oendW_cached_ = sycl_local_acc_t<index_t>(isizeW_, cgh);
+    r_kW_cached_ = sycl_local_acc_t<scalar_t>(osizeW_, cgh);
+    r_kH_cached_ = sycl_local_acc_t<scalar_t>(osizeH_, cgh);
+    out_cached_ = sycl_local_acc_t<scalar_t>(kernel_size_C_ * group_size_, cgh);
+  }
+
+  AdaptiveAvgPool2dBwdSLMChannelsLastKernelFunctor(
+      scalar_t* gradInput,
+      const scalar_t* gradOutput,
+      int sizeB,
+      int sizeC,
+      int isizeH,
+      int isizeW,
+      int osizeH,
+      int osizeW,
+      int kernel_stride_C,
+      int kernel_size_C,
+      index_t ostrideB,
+      index_t ostrideC,
+      index_t ostrideH,
+      index_t ostrideW,
+      size_t group_size)
+      : gradInput_(gradInput),
+        gradOutput_(gradOutput),
+        sizeB_(sizeB),
+        sizeC_(sizeC),
+        isizeH_(isizeH),
+        isizeW_(isizeW),
+        osizeH_(osizeH),
+        osizeW_(osizeW),
+        kernel_stride_C_(kernel_stride_C),
+        kernel_size_C_(kernel_size_C),
+        ostrideB_(ostrideB),
+        ostrideC_(ostrideC),
+        ostrideH_(ostrideH),
+        ostrideW_(ostrideW),
+        group_size_(group_size) {}
+
+ private:
+  scalar_t* gradInput_;
+  const scalar_t* gradOutput_;
+  int sizeB_;
+  int sizeC_;
+  int isizeH_;
+  int isizeW_;
+  int osizeH_;
+  int osizeW_;
+  int kernel_stride_C_;
+  int kernel_size_C_;
+  index_t ostrideB_;
+  index_t ostrideC_;
+  index_t ostrideH_;
+  index_t ostrideW_;
+  size_t shmem_size_;
+  size_t group_size_;
+  sycl_local_acc_t<index_t> ostartW_cached_;
+  sycl_local_acc_t<index_t> oendW_cached_;
+  sycl_local_acc_t<scalar_t> r_kW_cached_;
+  sycl_local_acc_t<scalar_t> r_kH_cached_;
+  sycl_local_acc_t<scalar_t> out_cached_;
+};
+
 void adaptive_avg_pool2d_backward_kernel(
     Tensor& grad_input,
     const Tensor& grad_output_,
@@ -246,79 +414,137 @@ void adaptive_avg_pool2d_backward_kernel(
     grad_input = at::empty_like(input_, smf);
   }
 
-  auto outputHeight = grad_output.size(-2);
-  auto outputWidth = grad_output.size(-1);
+  int osizeH = grad_output.size(-2);
+  int osizeW = grad_output.size(-1);
 
-  const auto nInputPlane = input.size(-3);
-  const auto inputHeight = input.size(-2);
-  const auto inputWidth = input.size(-1);
-
-  int dH = std::floor((float)2 * inputHeight / outputHeight) -
-      (inputHeight / outputHeight);
-  int dW = std::floor((float)2 * inputWidth / outputWidth) -
-      (inputWidth / outputWidth);
-  std::vector stride_vec = {dH, dW};
-
-  int kH = std::ceil((float)2 * inputHeight / outputHeight) -
-      (inputHeight / outputHeight);
-  int kW = std::ceil((float)2 * inputWidth / outputWidth) -
-      (inputWidth / outputWidth);
-  std::vector kernel_size_vec = {kH, kW};
-
-  int padH = (dH * (outputHeight - 1) + kH - inputHeight) / 2;
-  int padW = (dW * (outputWidth - 1) + kW - inputWidth) / 2;
-  std::vector padding_vec = {padH, padW};
+  int sizeC = input.size(-3);
+  int isizeH = input.size(-2);
+  int isizeW = input.size(-1);
 
   bool is_3d = grad_output.ndimension() == 3;
   if (is_3d) {
-    grad_output.resize_({1, nInputPlane, outputHeight, outputWidth});
-    grad_input.resize_({1, nInputPlane, inputHeight, inputWidth});
+    grad_output.resize_({1, sizeC, osizeH, osizeW});
+    grad_input.resize_({1, sizeC, isizeH, isizeW});
   }
 
-  AT_DISPATCH_FLOATING_TYPES_AND2(
-      at::ScalarType::BFloat16,
-      at::ScalarType::Half,
-      grad_output.scalar_type(),
-      "adaptive_avg_pool2d_backward_xpu",
-      [&]() {
-        using opmath_t = at::opmath_type<scalar_t>;
-        auto gyacc = grad_output.packed_accessor64<scalar_t, 4>();
-        auto gxacc = grad_input.packed_accessor64<scalar_t, 4>();
-
-        int64_t ohw01_shared_size =
-            ((inputHeight + inputWidth) * 2) * sizeof(int);
-        int64_t ikhw_shared_size =
-            (outputHeight + outputWidth) * sizeof(opmath_t);
-        bool using_shared =
-            syclLocalMemSize() >= ohw01_shared_size + ikhw_shared_size;
-
-        auto& q = getCurrentSYCLQueue();
-        if (is_smf_channels_last(grad_output)) {
-          if (using_shared) {
-            AdaptiveAvgPool2dBwdSLMKernelFunctor<scalar_t, opmath_t, true> kfn(
-                gyacc, gxacc);
-            sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn);
-          } else {
-            AdaptiveAvgPool2dBwdKernelFunctor<scalar_t, opmath_t, true> kfn(
-                gyacc, gxacc);
-            sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn);
-          }
-        } else {
+  int sizeB = input.size(0);
+
+  int64_t ostrideB = grad_output.stride(0);
+  int64_t ostrideC = grad_output.stride(1);
+  int64_t ostrideH = grad_output.stride(2);
+  int64_t ostrideW = grad_output.stride(3);
+
+  if (is_smf_channels_last(grad_output)) {
+    // preserve channels_last stride on input tensor;
+    if (!grad_input.is_contiguous(at::MemoryFormat::ChannelsLast)) {
+      grad_input.as_strided_(
+          {sizeB, sizeC, isizeH, isizeW},
+          {sizeC * isizeH * isizeW, 1, isizeW * sizeC, sizeC});
+    }
+
+    int max_threads =
+        std::min(syclMaxWorkItemsPerSubSlice(), XPU_MAX_THREADS);
+    size_t sharedMemPerGroup = syclLocalMemSize();
+
+    bool done = false;
+    do {
+      int group_x = std::max(
+          std::min(lastPow2(sizeC), syclMaxSubGroupSize()), 1);
+      int group_y = std::max(
+          std::min(lastPow2(isizeW), max_threads / group_x), 1);
+      int group_z = std::max(
+          std::min(lastPow2(isizeH), max_threads / group_x / group_y), 1);
+      group_x = std::max(
+          std::min(lastPow2(sizeC), max_threads / group_y / group_z), 1);
+      sycl::range<3> local_range{
+          (size_t)group_z, (size_t)group_y, (size_t)group_x};
+
+      int kernel_stride_C = ceil_div(sizeC, group_x * 4);
+      int kernel_size_C = ceil_div(sizeC, group_x * kernel_stride_C);
+
+      int range_x = sizeB * kernel_stride_C;
+
+      int range_y = ceil_div(isizeW, group_y * GROUP_STRIDE);
+      int range_z = ceil_div(isizeH, group_z * GROUP_STRIDE);
+
+      sycl::range<3> global_range{
+          (size_t)range_z * group_z,
+          (size_t)range_y * group_y,
+          (size_t)range_x * group_x};
+
+      AT_ASSERT(input.numel() < std::numeric_limits<int32_t>::max());
+
+      AT_DISPATCH_FLOATING_TYPES_AND2(
+          kHalf,
+          kBFloat16,
+          input.scalar_type(),
+          "adaptive_avg_pool2d_backward_nhwc_xpu",
+          [&] {
+            size_t shmem_size = (kernel_size_C * group_x * group_y * group_z +
+                                 osizeH + osizeW) *
+                    sizeof(scalar_t) +
+                2 * isizeW * sizeof(int32_t);
+            if (shmem_size <= sharedMemPerGroup) {
+              AdaptiveAvgPool2dBwdSLMChannelsLastKernelFunctor<
+                  int32_t,
+                  scalar_t>
+                  kfn(grad_input.mutable_data_ptr<scalar_t>(),
+                      grad_output.const_data_ptr<scalar_t>(),
+                      sizeB,
+                      sizeC,
+                      isizeH,
+                      isizeW,
+                      osizeH,
+                      osizeW,
+                      kernel_stride_C,
+                      kernel_size_C,
+                      ostrideB,
+                      ostrideC,
+                      ostrideH,
+                      ostrideW,
+                      group_x * group_y * group_z);
+              sycl_kernel_submit(
+                  global_range, local_range, getCurrentSYCLQueue(), kfn);
+              done = true;
+            } else {
+              TORCH_WARN_ONCE(
+                  "Requested shmem_size exceeds sharedMemPerGroup "
+                  "limit! Reducing max_threads...");
+              max_threads /= 2;
+            }
+          });
+    } while (!done && max_threads);
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND2(
+        at::ScalarType::BFloat16,
+        at::ScalarType::Half,
+        grad_output.scalar_type(),
+        "adaptive_avg_pool2d_backward_xpu",
+        [&]() {
+          using opmath_t = at::opmath_type<scalar_t>;
+          auto gyacc = grad_output.packed_accessor64<scalar_t, 4>();
+          auto gxacc = grad_input.packed_accessor64<scalar_t, 4>();
+
+          int64_t ohw01_shared_size = ((isizeH + isizeW) * 2) * sizeof(int);
+          int64_t ikhw_shared_size = (osizeH + osizeW) * sizeof(opmath_t);
+          bool using_shared =
+              syclLocalMemSize() >= ohw01_shared_size + ikhw_shared_size;
+
+          auto& q = getCurrentSYCLQueue();
           if (using_shared) {
-            AdaptiveAvgPool2dBwdSLMKernelFunctor<scalar_t, opmath_t, false> kfn(
+            AdaptiveAvgPool2dBwdSLMKernelFunctor<scalar_t, opmath_t> kfn(
                 gyacc, gxacc);
             sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn);
           } else {
-            AdaptiveAvgPool2dBwdKernelFunctor<scalar_t, opmath_t, false> kfn(
+            AdaptiveAvgPool2dBwdKernelFunctor<scalar_t, opmath_t> kfn(
                 gyacc, gxacc);
             sycl_kernel_submit(kfn.glb_range(), kfn.loc_range(), q, kfn);
           }
-        }
-      });
-
+        });
+  }
 
   if (is_3d) {
-    grad_output.resize_({nInputPlane, outputHeight, outputWidth});
-    grad_input.resize_({nInputPlane, inputHeight, inputWidth});
+    grad_output.resize_({sizeC, osizeH, osizeW});
+    grad_input.resize_({sizeC, isizeH, isizeW});
  }
 }
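
Reviewer note: the START_IND_INT/END_IND_INT mapping the new channels-last kernel relies on is easy to sanity-check in isolation. Below is a minimal standalone sketch, not part of the patch; the main() driver and the 7-to-3 example sizes are invented for illustration. For each input column iw it prints the half-open range of output bins over which the backward kernel accumulates gradOutput, which is exactly what ostartW_cached_/oendW_cached_ hold.

// Minimal standalone check of the adaptive-pooling index math used by the
// new channels-last backward kernel. The two macros are copied from the
// patch; main() and the example sizes are illustrative only.
#include <cstdio>

#define START_IND_INT(a, b, c) ((a * c) / b)
#define END_IND_INT(a, b, c) (((a + 1) * c + b - 1) / b)

int main() {
  const int isizeW = 7; // example input width
  const int osizeW = 3; // example output width
  for (int iw = 0; iw < isizeW; ++iw) {
    // [ostart, oend) is the range of output columns whose pooling windows
    // cover input column iw; the kernel caches exactly these two values in
    // ostartW_cached_/oendW_cached_.
    int ostart = START_IND_INT(iw, isizeW, osizeW);
    int oend = END_IND_INT(iw, isizeW, osizeW);
    printf("iw=%d -> output bins [%d, %d)\n", iw, ostart, oend);
  }
  return 0;
}

For non-negative sizes these values agree with the native::start_index/end_index pair used by the existing non-channels-last kernels, so both paths partition the gradient identically.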
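The work-group shape and the kernel_stride_C/kernel_size_C split chosen in the dispatch loop can likewise be dry-run on the host. This is a hedged sketch only: lastPow2() and ceil_div() below are local stand-ins assumed to behave like the repo helpers (largest power of two not exceeding n, and rounding-up integer division), and the tensor shape and device limits are made-up example values.

// Host-side dry run of the launch-geometry computation in the channels-last
// dispatch path. lastPow2/ceil_div are stand-ins with assumed semantics; the
// shape and device limits below are example values only.
#include <algorithm>
#include <cstdio>

static int lastPow2(int n) { // assumed: largest power of two <= n (n >= 1)
  int p = 1;
  while (p * 2 <= n)
    p *= 2;
  return p;
}
static int ceil_div(int a, int b) { // assumed: rounding-up integer division
  return (a + b - 1) / b;
}

int main() {
  const int sizeB = 4, sizeC = 96, isizeH = 32, isizeW = 32; // example shape
  const int max_threads = 1024, sub_group_size = 32; // example device limits
  const int GROUP_STRIDE = 2;

  // Same ordering as the patch: channels vary fastest (x), then W (y), then H (z).
  int group_x = std::max(std::min(lastPow2(sizeC), sub_group_size), 1);
  int group_y = std::max(std::min(lastPow2(isizeW), max_threads / group_x), 1);
  int group_z =
      std::max(std::min(lastPow2(isizeH), max_threads / group_x / group_y), 1);
  group_x =
      std::max(std::min(lastPow2(sizeC), max_threads / group_y / group_z), 1);

  // Each work-item strides over C; kernel_size_C is the number of channel
  // slots it keeps in its slice of shared local memory.
  int kernel_stride_C = ceil_div(sizeC, group_x * 4);
  int kernel_size_C = ceil_div(sizeC, group_x * kernel_stride_C);

  printf("local_range = {%d, %d, %d}\n", group_z, group_y, group_x);
  printf("groups      = {%d, %d, %d}\n",
         ceil_div(isizeH, group_z * GROUP_STRIDE),
         ceil_div(isizeW, group_y * GROUP_STRIDE),
         sizeB * kernel_stride_C);
  printf("kernel_stride_C = %d, kernel_size_C = %d\n",
         kernel_stride_C, kernel_size_C);
  return 0;
}

With the example values this yields a {1, 32, 32} local range and kernel_size_C = 3, which is the amount of shared local memory each work-item reserves for its per-channel partial sums.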