@@ -22,10 +22,10 @@ inline int max(int a, int b) {
22
22
return a >= b ? a : b;
23
23
}
24
24
25
- template <typename scalar_t , typename accscalar_t >
25
+ template <typename scalar_t , typename accscalar_t , typename index_t >
26
26
struct AvgPool2dKernelFunctor {
27
27
void operator ()(sycl::nd_item<1 > item) const {
28
- auto index = item.get_global_linear_id ();
28
+ index_t index = item.get_global_linear_id ();
29
29
30
30
if (index < total_elements_) {
31
31
const int pw = index % pooled_width_;
@@ -73,10 +73,10 @@ struct AvgPool2dKernelFunctor {
73
73
AvgPool2dKernelFunctor (
74
74
scalar_t * top_data,
75
75
const scalar_t * bottom_data,
76
- int64_t total_elements,
77
- int64_t channels,
78
- int64_t height,
79
- int64_t width,
76
+ index_t total_elements,
77
+ index_t channels,
78
+ index_t height,
79
+ index_t width,
80
80
int pooled_height,
81
81
int pooled_width,
82
82
int kernel_h,
@@ -109,10 +109,10 @@ struct AvgPool2dKernelFunctor {
109
109
private:
110
110
scalar_t * top_data_;
111
111
const scalar_t * bottom_data_;
112
- int64_t total_elements_;
113
- int64_t channels_;
114
- int64_t height_;
115
- int64_t width_;
112
+ index_t total_elements_;
113
+ index_t channels_;
114
+ index_t height_;
115
+ index_t width_;
116
116
int pooled_height_;
117
117
int pooled_width_;
118
118
int kernel_h_;
@@ -126,10 +126,10 @@ struct AvgPool2dKernelFunctor {
126
126
bool use_divisor_;
127
127
};
128
128
129
- template <typename scalar_t , typename accscalar_t >
129
+ template <typename scalar_t , typename accscalar_t , typename index_t >
130
130
struct AvgPool2dChannelsLastKernelFunctor {
131
131
void operator ()(sycl::nd_item<1 > item) const {
132
- auto index = item.get_global_linear_id ();
132
+ index_t index = item.get_global_linear_id ();
133
133
134
134
if (index < total_elements_) {
135
135
const int c = index % channels_;
@@ -175,10 +175,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
175
175
AvgPool2dChannelsLastKernelFunctor (
176
176
scalar_t * top_data,
177
177
const scalar_t * bottom_data,
178
- int64_t total_elements,
179
- int64_t channels,
180
- int64_t height,
181
- int64_t width,
178
+ index_t total_elements,
179
+ index_t channels,
180
+ index_t height,
181
+ index_t width,
182
182
int pooled_height,
183
183
int pooled_width,
184
184
int kernel_h,
@@ -211,10 +211,10 @@ struct AvgPool2dChannelsLastKernelFunctor {
211
211
private:
212
212
scalar_t * top_data_;
213
213
const scalar_t * bottom_data_;
214
- int64_t total_elements_;
215
- int64_t channels_;
216
- int64_t height_;
217
- int64_t width_;
214
+ index_t total_elements_;
215
+ index_t channels_;
216
+ index_t height_;
217
+ index_t width_;
218
218
int pooled_height_;
219
219
int pooled_width_;
220
220
int kernel_h_;
@@ -228,13 +228,13 @@ struct AvgPool2dChannelsLastKernelFunctor {
228
228
bool use_divisor_;
229
229
};
230
230
231
- template <typename scalar_t , typename accscalar_t >
231
+ template <typename scalar_t , typename accscalar_t , typename index_t >
232
232
void launch_avg_pool2d_channels_last_kernel (
233
233
const int total_elements,
234
234
const Tensor& input,
235
- const int64_t channels,
236
- const int64_t height,
237
- const int64_t width,
235
+ const index_t channels,
236
+ const index_t height,
237
+ const index_t width,
238
238
const int pooled_height,
239
239
const int pooled_width,
240
240
const int kernel_h,
@@ -255,7 +255,7 @@ void launch_avg_pool2d_channels_last_kernel(
255
255
const uint32_t global_range =
256
256
ceil_div<uint32_t >(total_elements, group_size) * group_size;
257
257
258
- auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t , accscalar_t >(
258
+ auto kfn = AvgPool2dChannelsLastKernelFunctor<scalar_t , accscalar_t , index_t >(
259
259
top_data,
260
260
bottom_data,
261
261
total_elements,
@@ -276,13 +276,13 @@ void launch_avg_pool2d_channels_last_kernel(
276
276
sycl_kernel_submit (global_range, group_size, queue, kfn);
277
277
}
278
278
279
- template <typename scalar_t , typename accscalar_t >
279
+ template <typename scalar_t , typename accscalar_t , typename index_t >
280
280
void launch_avg_pool2d_kernel (
281
281
const int total_elements,
282
282
const Tensor& input,
283
- const int64_t channels,
284
- const int64_t height,
285
- const int64_t width,
283
+ const index_t channels,
284
+ const index_t height,
285
+ const index_t width,
286
286
const int pooled_height,
287
287
const int pooled_width,
288
288
const int kernel_h,
@@ -303,7 +303,7 @@ void launch_avg_pool2d_kernel(
303
303
const uint32_t global_range =
304
304
ceil_div<uint32_t >(total_elements, group_size) * group_size;
305
305
306
- auto kfn = AvgPool2dKernelFunctor<scalar_t , accscalar_t >(
306
+ auto kfn = AvgPool2dKernelFunctor<scalar_t , accscalar_t , index_t >(
307
307
top_data,
308
308
bottom_data,
309
309
total_elements,
@@ -664,58 +664,67 @@ void avg_pool2d_kernel(
664
664
AT_DISPATCH_FLOATING_TYPES_AND2 (
665
665
kHalf , kBFloat16 , input.scalar_type (), " avg_pool2d_xpu" , [&] {
666
666
using accscalar_t = acc_type_device<scalar_t , kXPU >;
667
-
668
- switch (memory_format) {
669
- case MemoryFormat::ChannelsLast: {
670
- output.unsafeGetTensorImpl ()->empty_tensor_restride (
671
- MemoryFormat::ChannelsLast);
672
- launch_avg_pool2d_channels_last_kernel<scalar_t , accscalar_t >(
673
- count,
674
- input,
675
- nInputPlane,
676
- inputHeight,
677
- inputWidth,
678
- outputHeight,
679
- outputWidth,
680
- kH_ ,
681
- kW_ ,
682
- dH_,
683
- dW_,
684
- padH_,
685
- padW_,
686
- output,
687
- divisor_override_value,
688
- count_include_pad,
689
- use_divisor);
690
- break ;
691
- }
692
- case MemoryFormat::Contiguous: {
693
- launch_avg_pool2d_kernel<scalar_t , accscalar_t >(
694
- count,
695
- input,
696
- nInputPlane,
697
- inputHeight,
698
- inputWidth,
699
- outputHeight,
700
- outputWidth,
701
- kH_ ,
702
- kW_ ,
703
- dH_,
704
- dW_,
705
- padH_,
706
- padW_,
707
- output,
708
- divisor_override_value,
709
- count_include_pad,
710
- use_divisor);
711
- break ;
712
- }
713
- default :
714
- TORCH_CHECK (
715
- false ,
716
- " Unsupported memory format. Supports only "
717
- " ChannelsLast, Contiguous" );
718
- }
667
+ AT_DISPATCH_INDEX_TYPES (
668
+ at::native::canUse32BitIndexMath (output, INT_MAX)
669
+ ? ScalarType::Int
670
+ : ScalarType::Long,
671
+ " avg_pool2d_xpu" ,
672
+ [&] {
673
+ switch (memory_format) {
674
+ case MemoryFormat::ChannelsLast: {
675
+ output.unsafeGetTensorImpl ()->empty_tensor_restride (
676
+ MemoryFormat::ChannelsLast);
677
+ launch_avg_pool2d_channels_last_kernel<
678
+ scalar_t ,
679
+ accscalar_t ,
680
+ index_t >(
681
+ count,
682
+ input,
683
+ nInputPlane,
684
+ inputHeight,
685
+ inputWidth,
686
+ outputHeight,
687
+ outputWidth,
688
+ kH_ ,
689
+ kW_ ,
690
+ dH_,
691
+ dW_,
692
+ padH_,
693
+ padW_,
694
+ output,
695
+ divisor_override_value,
696
+ count_include_pad,
697
+ use_divisor);
698
+ break ;
699
+ }
700
+ case MemoryFormat::Contiguous: {
701
+ launch_avg_pool2d_kernel<scalar_t , accscalar_t , index_t >(
702
+ count,
703
+ input,
704
+ nInputPlane,
705
+ inputHeight,
706
+ inputWidth,
707
+ outputHeight,
708
+ outputWidth,
709
+ kH_ ,
710
+ kW_ ,
711
+ dH_,
712
+ dW_,
713
+ padH_,
714
+ padW_,
715
+ output,
716
+ divisor_override_value,
717
+ count_include_pad,
718
+ use_divisor);
719
+ break ;
720
+ }
721
+ default :
722
+ TORCH_CHECK (
723
+ false ,
724
+ " Unsupported memory format. Supports only "
725
+ " ChannelsLast, Contiguous" );
726
+ }
727
+ });
719
728
});
720
729
}
721
730
}
0 commit comments