@@ -51,6 +51,7 @@ namespace kernels
51
51
{
52
52
53
53
using dpctl::tensor::ssize_t ;
54
+ namespace su_ns = dpctl::tensor::sycl_utils;
54
55
namespace tu_ns = dpctl::tensor::type_utils;
55
56
namespace exprm_ns = sycl::ext::oneapi::experimental;
56
57
@@ -101,7 +102,7 @@ void scale_gemm_nm_parameters(const std::size_t &local_mem_size,
101
102
}
102
103
} // namespace gemm_detail
103
104
104
- using dpctl::tensor::sycl_utils ::choose_workgroup_size;
105
+ using su_ns ::choose_workgroup_size;
105
106
106
107
template <typename T1, typename T2, typename T3, typename T4, typename T5>
107
108
class gemm_seq_reduction_krn ;
@@ -2367,12 +2368,12 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
2367
2368
depends);
2368
2369
}
2369
2370
else {
2370
- using ReductionOpT =
2371
- typename std::conditional<std:: is_same_v<resTy, bool >,
2372
- sycl::logical_or <resTy>,
2373
- sycl::plus<resTy>>::type ;
2371
+ using ReductionOpT = std:: conditional_t <
2372
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
2373
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
2374
+ sycl::plus<resTy>>> ;
2374
2375
constexpr resTy identity_val =
2375
- sycl::known_identity <ReductionOpT, resTy>::value;
2376
+ su_ns::Identity <ReductionOpT, resTy>::value;
2376
2377
2377
2378
std::size_t iter_nelems = batch_nelems * n * m;
2378
2379
std::size_t reduction_nelems =
@@ -2663,12 +2664,12 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
2663
2664
lhs_indexer, rhs_indexer, res_indexer, depends);
2664
2665
}
2665
2666
else {
2666
- using ReductionOpT =
2667
- typename std::conditional<std:: is_same_v<resTy, bool >,
2668
- sycl::logical_or <resTy>,
2669
- sycl::plus<resTy>>::type ;
2667
+ using ReductionOpT = std:: conditional_t <
2668
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
2669
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
2670
+ sycl::plus<resTy>>> ;
2670
2671
constexpr resTy identity_val =
2671
- sycl::known_identity <ReductionOpT, resTy>::value;
2672
+ su_ns::Identity <ReductionOpT, resTy>::value;
2672
2673
std::size_t iter_nelems = batch_nelems * n * m;
2673
2674
std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
2674
2675
@@ -3034,12 +3035,12 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
3034
3035
depends);
3035
3036
}
3036
3037
else {
3037
- using ReductionOpT =
3038
- typename std::conditional<std:: is_same_v<resTy, bool >,
3039
- sycl::logical_or <resTy>,
3040
- sycl::plus<resTy>>::type ;
3038
+ using ReductionOpT = std:: conditional_t <
3039
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3040
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3041
+ sycl::plus<resTy>>> ;
3041
3042
constexpr resTy identity_val =
3042
- sycl::known_identity <ReductionOpT, resTy>::value;
3043
+ su_ns::Identity <ReductionOpT, resTy>::value;
3043
3044
3044
3045
std::size_t iter_nelems = batch_nelems * n * m;
3045
3046
std::size_t reduction_nelems =
@@ -3222,12 +3223,12 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q,
3222
3223
lhs_indexer, rhs_indexer, res_indexer, depends);
3223
3224
}
3224
3225
else {
3225
- using ReductionOpT =
3226
- typename std::conditional<std:: is_same_v<resTy, bool >,
3227
- sycl::logical_or <resTy>,
3228
- sycl::plus<resTy>>::type ;
3226
+ using ReductionOpT = std:: conditional_t <
3227
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3228
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3229
+ sycl::plus<resTy>>> ;
3229
3230
constexpr resTy identity_val =
3230
- sycl::known_identity <ReductionOpT, resTy>::value;
3231
+ su_ns::Identity <ReductionOpT, resTy>::value;
3231
3232
std::size_t iter_nelems = batch_nelems * n * m;
3232
3233
std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
3233
3234
@@ -3591,12 +3592,12 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
3591
3592
res_indexer, depends);
3592
3593
}
3593
3594
else {
3594
- using ReductionOpT =
3595
- typename std::conditional<std:: is_same_v<resTy, bool >,
3596
- sycl::logical_or <resTy>,
3597
- sycl::plus<resTy>>::type ;
3595
+ using ReductionOpT = std:: conditional_t <
3596
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3597
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3598
+ sycl::plus<resTy>>> ;
3598
3599
constexpr resTy identity_val =
3599
- sycl::known_identity <ReductionOpT, resTy>::value;
3600
+ su_ns::Identity <ReductionOpT, resTy>::value;
3600
3601
3601
3602
std::size_t iter_nelems = n * m;
3602
3603
std::size_t reduction_nelems =
@@ -3745,12 +3746,12 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q,
3745
3746
lhs_indexer, rhs_indexer, res_indexer, depends);
3746
3747
}
3747
3748
else {
3748
- using ReductionOpT =
3749
- typename std::conditional<std:: is_same_v<resTy, bool >,
3750
- sycl::logical_or <resTy>,
3751
- sycl::plus<resTy>>::type ;
3749
+ using ReductionOpT = std:: conditional_t <
3750
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3751
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3752
+ sycl::plus<resTy>>> ;
3752
3753
constexpr resTy identity_val =
3753
- sycl::known_identity <ReductionOpT, resTy>::value;
3754
+ su_ns::Identity <ReductionOpT, resTy>::value;
3754
3755
3755
3756
std::size_t iter_nelems = n * m;
3756
3757
std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
@@ -3979,12 +3980,12 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
3979
3980
res_indexer, depends);
3980
3981
}
3981
3982
else {
3982
- using ReductionOpT =
3983
- typename std::conditional<std:: is_same_v<resTy, bool >,
3984
- sycl::logical_or <resTy>,
3985
- sycl::plus<resTy>>::type ;
3983
+ using ReductionOpT = std:: conditional_t <
3984
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
3985
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
3986
+ sycl::plus<resTy>>> ;
3986
3987
constexpr resTy identity_val =
3987
- sycl::known_identity <ReductionOpT, resTy>::value;
3988
+ su_ns::Identity <ReductionOpT, resTy>::value;
3988
3989
3989
3990
std::size_t iter_nelems = n * m;
3990
3991
std::size_t reduction_nelems =
@@ -4118,12 +4119,12 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q,
4118
4119
lhs_indexer, rhs_indexer, res_indexer, depends);
4119
4120
}
4120
4121
else {
4121
- using ReductionOpT =
4122
- typename std::conditional<std:: is_same_v<resTy, bool >,
4123
- sycl::logical_or <resTy>,
4124
- sycl::plus<resTy>>::type ;
4122
+ using ReductionOpT = std:: conditional_t <
4123
+ std::is_same_v<resTy, bool >, sycl::logical_or<resTy >,
4124
+ std:: conditional_t <tu_ns::is_complex_v<resTy>, su_ns::Plus <resTy>,
4125
+ sycl::plus<resTy>>> ;
4125
4126
constexpr resTy identity_val =
4126
- sycl::known_identity <ReductionOpT, resTy>::value;
4127
+ su_ns::Identity <ReductionOpT, resTy>::value;
4127
4128
4128
4129
std::size_t iter_nelems = n * m;
4129
4130
std::size_t reduction_nelems = (k + wi_delta_k - 1 ) / wi_delta_k;
0 commit comments