
Commit a299b79

Enhance Performance by Adding index_t Template Parameter (#1610)
Select `int64_t` or `int` as the index data type at the start of GPU kernel computation, and pass the chosen type into the kernels as a template parameter. This optimization improves performance for smaller shapes.

#### Performance Improvement Reasons

1. **Better Hardware Support**: GPUs handle 32-bit integer operations more efficiently than 64-bit ones.
2. **Reduced Memory Usage**: 32-bit integers use less memory, improving bandwidth and cache efficiency.
3. **Faster Instructions**: 32-bit operations require fewer instructions, speeding up computation.

Co-authored-by: Yutao Xu <[email protected]>
1 parent 6c833e1 commit a299b79
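The pattern this commit introduces can be illustrated outside of ATen. Below is a minimal, standalone C++ sketch of the same idea: pick a 32-bit index type when every index fits in `int32_t`, otherwise fall back to `int64_t`, and forward the choice to the kernel as a template parameter. The names `copy_kernel` and `launch_with_index_type` are illustrative only and are not part of the patch; the actual commit routes the choice through PyTorch's `AT_DISPATCH_INDEX_TYPES` macro and `at::native::canUse32BitIndexMath`, as shown in the diff below.

```cpp
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

// Stand-in for the kernel functors in the diff: the index type is a template
// parameter, so small shapes can do all index arithmetic in 32 bits.
template <typename scalar_t, typename index_t>
void copy_kernel(scalar_t* out, const scalar_t* in, index_t total_elements) {
  for (index_t i = 0; i < total_elements; ++i) {
    out[i] = in[i]; // placeholder for the real per-element pooling math
  }
}

// Choose int32_t when every index fits, otherwise int64_t, and forward the
// choice through the template parameter (the same idea the commit applies
// via AT_DISPATCH_INDEX_TYPES over canUse32BitIndexMath).
template <typename scalar_t>
void launch_with_index_type(scalar_t* out, const scalar_t* in, int64_t numel) {
  if (numel <= std::numeric_limits<int32_t>::max()) {
    copy_kernel<scalar_t, int32_t>(out, in, static_cast<int32_t>(numel));
  } else {
    copy_kernel<scalar_t, int64_t>(out, in, numel);
  }
}

int main() {
  std::vector<float> in(8, 1.0f), out(8, 0.0f);
  launch_with_index_type(out.data(), in.data(), static_cast<int64_t>(in.size()));
  std::cout << out[0] << "\n"; // prints 1
}
```

In the real kernels the gain comes from the index arithmetic (the modulo and division used to recover `pw`, `ph`, and channel indices) running on 32-bit values whenever the tensor is small enough.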

File tree: 1 file changed (+91 / -82 lines)

src/ATen/native/xpu/sycl/AveragePool2dKernels.cpp (91 additions, 82 deletions)
@@ -22,10 +22,10 @@ inline int max(int a, int b) {
   return a >= b ? a : b;
 }
 
-template <typename scalar_t, typename accscalar_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
 struct AvgPool2dKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
-    auto index = item.get_global_linear_id();
+    index_t index = item.get_global_linear_id();
 
     if (index < total_elements_) {
       const int pw = index % pooled_width_;
@@ -73,10 +73,10 @@ struct AvgPool2dKernelFunctor {
   AvgPool2dKernelFunctor(
       scalar_t* top_data,
       const scalar_t* bottom_data,
-      int64_t total_elements,
-      int64_t channels,
-      int64_t height,
-      int64_t width,
+      index_t total_elements,
+      index_t channels,
+      index_t height,
+      index_t width,
       int pooled_height,
       int pooled_width,
       int kernel_h,
@@ -109,10 +109,10 @@ struct AvgPool2dKernelFunctor {
  private:
   scalar_t* top_data_;
   const scalar_t* bottom_data_;
-  int64_t total_elements_;
-  int64_t channels_;
-  int64_t height_;
-  int64_t width_;
+  index_t total_elements_;
+  index_t channels_;
+  index_t height_;
+  index_t width_;
   int pooled_height_;
   int pooled_width_;
   int kernel_h_;
@@ -126,10 +126,10 @@ struct AvgPool2dKernelFunctor {
   bool use_divisor_;
 };
 
-template <typename scalar_t, typename accscalar_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
 struct AvgPool2dChannelsLastKernelFunctor {
   void operator()(sycl::nd_item<1> item) const {
-    auto index = item.get_global_linear_id();
+    index_t index = item.get_global_linear_id();
 
     if (index < total_elements_) {
       const int c = index % channels_;
@@ -175,10 +175,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
   AvgPool2dChannelsLastKernelFunctor(
       scalar_t* top_data,
       const scalar_t* bottom_data,
-      int64_t total_elements,
-      int64_t channels,
-      int64_t height,
-      int64_t width,
+      index_t total_elements,
+      index_t channels,
+      index_t height,
+      index_t width,
       int pooled_height,
       int pooled_width,
       int kernel_h,
@@ -211,10 +211,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
  private:
   scalar_t* top_data_;
   const scalar_t* bottom_data_;
-  int64_t total_elements_;
-  int64_t channels_;
-  int64_t height_;
-  int64_t width_;
+  index_t total_elements_;
+  index_t channels_;
+  index_t height_;
+  index_t width_;
   int pooled_height_;
   int pooled_width_;
   int kernel_h_;
@@ -228,13 +228,13 @@ struct AvgPool2dChannelsLastKernelFunctor {
   bool use_divisor_;
 };
 
-template <typename scalar_t, typename accscalar_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
 void launch_avg_pool2d_channels_last_kernel(
     const int total_elements,
     const Tensor& input,
-    const int64_t channels,
-    const int64_t height,
-    const int64_t width,
+    const index_t channels,
+    const index_t height,
+    const index_t width,
     const int pooled_height,
     const int pooled_width,
     const int kernel_h,
@@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel(
   const uint32_t global_range =
       ceil_div<uint32_t>(total_elements, group_size) * group_size;
 
-  auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t, accscalar_t>(
+  auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t, accscalar_t, index_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -276,13 +276,13 @@ void launch_avg_pool2d_channels_last_kernel(
   sycl_kernel_submit(global_range, group_size, queue, kfn);
 }
 
-template <typename scalar_t, typename accscalar_t>
+template <typename scalar_t, typename accscalar_t, typename index_t>
 void launch_avg_pool2d_kernel(
     const int total_elements,
     const Tensor& input,
-    const int64_t channels,
-    const int64_t height,
-    const int64_t width,
+    const index_t channels,
+    const index_t height,
+    const index_t width,
     const int pooled_height,
     const int pooled_width,
     const int kernel_h,
@@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel(
   const uint32_t global_range =
       ceil_div<uint32_t>(total_elements, group_size) * group_size;
 
-  auto kfn = AvgPool2dKernelFunctor<scalar_t, accscalar_t>(
+  auto kfn = AvgPool2dKernelFunctor<scalar_t, accscalar_t, index_t>(
       top_data,
       bottom_data,
       total_elements,
@@ -664,58 +664,67 @@ void avg_pool2d_kernel(
   AT_DISPATCH_FLOATING_TYPES_AND2(
       kHalf, kBFloat16, input.scalar_type(), "avg_pool2d_xpu", [&] {
         using accscalar_t = acc_type_device<scalar_t, kXPU>;
-
-        switch (memory_format) {
-          case MemoryFormat::ChannelsLast: {
-            output.unsafeGetTensorImpl()->empty_tensor_restride(
-                MemoryFormat::ChannelsLast);
-            launch_avg_pool2d_channels_last_kernel<scalar_t, accscalar_t>(
-                count,
-                input,
-                nInputPlane,
-                inputHeight,
-                inputWidth,
-                outputHeight,
-                outputWidth,
-                kH_,
-                kW_,
-                dH_,
-                dW_,
-                padH_,
-                padW_,
-                output,
-                divisor_override_value,
-                count_include_pad,
-                use_divisor);
-            break;
-          }
-          case MemoryFormat::Contiguous: {
-            launch_avg_pool2d_kernel<scalar_t, accscalar_t>(
-                count,
-                input,
-                nInputPlane,
-                inputHeight,
-                inputWidth,
-                outputHeight,
-                outputWidth,
-                kH_,
-                kW_,
-                dH_,
-                dW_,
-                padH_,
-                padW_,
-                output,
-                divisor_override_value,
-                count_include_pad,
-                use_divisor);
-            break;
-          }
-          default:
-            TORCH_CHECK(
-                false,
-                "Unsupported memory format. Supports only "
-                "ChannelsLast, Contiguous");
-        }
+        AT_DISPATCH_INDEX_TYPES(
+            at::native::canUse32BitIndexMath(output, INT_MAX)
+                ? ScalarType::Int
+                : ScalarType::Long,
+            "avg_pool2d_xpu",
+            [&] {
+              switch (memory_format) {
+                case MemoryFormat::ChannelsLast: {
+                  output.unsafeGetTensorImpl()->empty_tensor_restride(
+                      MemoryFormat::ChannelsLast);
+                  launch_avg_pool2d_channels_last_kernel<
+                      scalar_t,
+                      accscalar_t,
+                      index_t>(
+                      count,
+                      input,
+                      nInputPlane,
+                      inputHeight,
+                      inputWidth,
+                      outputHeight,
+                      outputWidth,
+                      kH_,
+                      kW_,
+                      dH_,
+                      dW_,
+                      padH_,
+                      padW_,
+                      output,
+                      divisor_override_value,
+                      count_include_pad,
+                      use_divisor);
+                  break;
+                }
+                case MemoryFormat::Contiguous: {
+                  launch_avg_pool2d_kernel<scalar_t, accscalar_t, index_t>(
+                      count,
+                      input,
+                      nInputPlane,
+                      inputHeight,
+                      inputWidth,
+                      outputHeight,
+                      outputWidth,
+                      kH_,
+                      kW_,
+                      dH_,
+                      dW_,
+                      padH_,
+                      padW_,
+                      output,
+                      divisor_override_value,
+                      count_include_pad,
+                      use_divisor);
+                  break;
+                }
+                default:
+                  TORCH_CHECK(
+                      false,
+                      "Unsupported memory format. Supports only "
+                      "ChannelsLast, Contiguous");
+              }
+            });
       });
 }
 }
