Skip to content

Commit fe12ba1

Browse files
committed
Use custom plus operator in gemm and dot product tree reduction kernels
1 parent 0821c73 commit fe12ba1

File tree

2 files changed

+56
-52
lines changed

2 files changed

+56
-52
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -755,9 +755,10 @@ struct DotProductNoAtomicFunctor
755755

756756
auto work_group = it.get_group();
757757

758-
using RedOpT = typename std::conditional<std::is_same_v<outT, bool>,
759-
sycl::logical_or<outT>,
760-
sycl::plus<outT>>::type;
758+
using RedOpT = std::conditional_t<
759+
std::is_same_v<outT, bool>, sycl::logical_or<outT>,
760+
std::conditional_t<tu_ns::is_complex_v<outT>, su_ns::Plus<outT>,
761+
sycl::plus<outT>>>;
761762
outT red_val_over_wg = sycl::reduce_over_group(
762763
work_group, local_red_val, outT(0), RedOpT());
763764

@@ -1009,9 +1010,10 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
10091010
// prevents running out of resources on CPU
10101011
std::size_t max_wg = reduction_detail::get_work_group_size(d);
10111012

1012-
using ReductionOpT = typename std::conditional<std::is_same_v<resTy, bool>,
1013-
sycl::logical_or<resTy>,
1014-
sycl::plus<resTy>>::type;
1013+
using ReductionOpT = std::conditional_t<
1014+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
1015+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
1016+
sycl::plus<resTy>>>;
10151017

10161018
std::size_t reductions_per_wi(preferred_reductions_per_wi);
10171019
if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
@@ -1051,7 +1053,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
10511053
}
10521054
else {
10531055
constexpr resTy identity_val =
1054-
sycl::known_identity<ReductionOpT, resTy>::value;
1056+
su_ns::Identity<ReductionOpT, resTy>::value;
10551057

10561058
// more than one work-group is needed, which requires a temporary
10571059
std::size_t reduction_groups =
@@ -1252,9 +1254,10 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
12521254
// prevents running out of resources on CPU
12531255
std::size_t max_wg = reduction_detail::get_work_group_size(d);
12541256

1255-
using ReductionOpT = typename std::conditional<std::is_same_v<resTy, bool>,
1256-
sycl::logical_or<resTy>,
1257-
sycl::plus<resTy>>::type;
1257+
using ReductionOpT = std::conditional_t<
1258+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
1259+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
1260+
sycl::plus<resTy>>>;
12581261

12591262
std::size_t reductions_per_wi(preferred_reductions_per_wi);
12601263
if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
@@ -1298,7 +1301,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
12981301
}
12991302
else {
13001303
constexpr resTy identity_val =
1301-
sycl::known_identity<ReductionOpT, resTy>::value;
1304+
su_ns::Identity<ReductionOpT, resTy>::value;
13021305

13031306
// more than one work-group is needed, which requires a temporary
13041307
std::size_t reduction_groups =

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ namespace kernels
5151
{
5252

5353
using dpctl::tensor::ssize_t;
54+
namespace su_ns = dpctl::tensor::sycl_utils;
5455
namespace tu_ns = dpctl::tensor::type_utils;
5556
namespace exprm_ns = sycl::ext::oneapi::experimental;
5657

@@ -101,7 +102,7 @@ void scale_gemm_nm_parameters(const std::size_t &local_mem_size,
101102
}
102103
} // namespace gemm_detail
103104

104-
using dpctl::tensor::sycl_utils::choose_workgroup_size;
105+
using su_ns::choose_workgroup_size;
105106

106107
template <typename T1, typename T2, typename T3, typename T4, typename T5>
107108
class gemm_seq_reduction_krn;
@@ -2367,12 +2368,12 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
23672368
depends);
23682369
}
23692370
else {
2370-
using ReductionOpT =
2371-
typename std::conditional<std::is_same_v<resTy, bool>,
2372-
sycl::logical_or<resTy>,
2373-
sycl::plus<resTy>>::type;
2371+
using ReductionOpT = std::conditional_t<
2372+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
2373+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
2374+
sycl::plus<resTy>>>;
23742375
constexpr resTy identity_val =
2375-
sycl::known_identity<ReductionOpT, resTy>::value;
2376+
su_ns::Identity<ReductionOpT, resTy>::value;
23762377

23772378
std::size_t iter_nelems = batch_nelems * n * m;
23782379
std::size_t reduction_nelems =
@@ -2663,12 +2664,12 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
26632664
lhs_indexer, rhs_indexer, res_indexer, depends);
26642665
}
26652666
else {
2666-
using ReductionOpT =
2667-
typename std::conditional<std::is_same_v<resTy, bool>,
2668-
sycl::logical_or<resTy>,
2669-
sycl::plus<resTy>>::type;
2667+
using ReductionOpT = std::conditional_t<
2668+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
2669+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
2670+
sycl::plus<resTy>>>;
26702671
constexpr resTy identity_val =
2671-
sycl::known_identity<ReductionOpT, resTy>::value;
2672+
su_ns::Identity<ReductionOpT, resTy>::value;
26722673
std::size_t iter_nelems = batch_nelems * n * m;
26732674
std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k;
26742675

@@ -3034,12 +3035,12 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
30343035
depends);
30353036
}
30363037
else {
3037-
using ReductionOpT =
3038-
typename std::conditional<std::is_same_v<resTy, bool>,
3039-
sycl::logical_or<resTy>,
3040-
sycl::plus<resTy>>::type;
3038+
using ReductionOpT = std::conditional_t<
3039+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3040+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3041+
sycl::plus<resTy>>>;
30413042
constexpr resTy identity_val =
3042-
sycl::known_identity<ReductionOpT, resTy>::value;
3043+
su_ns::Identity<ReductionOpT, resTy>::value;
30433044

30443045
std::size_t iter_nelems = batch_nelems * n * m;
30453046
std::size_t reduction_nelems =
@@ -3222,12 +3223,12 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q,
32223223
lhs_indexer, rhs_indexer, res_indexer, depends);
32233224
}
32243225
else {
3225-
using ReductionOpT =
3226-
typename std::conditional<std::is_same_v<resTy, bool>,
3227-
sycl::logical_or<resTy>,
3228-
sycl::plus<resTy>>::type;
3226+
using ReductionOpT = std::conditional_t<
3227+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3228+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3229+
sycl::plus<resTy>>>;
32293230
constexpr resTy identity_val =
3230-
sycl::known_identity<ReductionOpT, resTy>::value;
3231+
su_ns::Identity<ReductionOpT, resTy>::value;
32313232
std::size_t iter_nelems = batch_nelems * n * m;
32323233
std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k;
32333234

@@ -3591,12 +3592,12 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
35913592
res_indexer, depends);
35923593
}
35933594
else {
3594-
using ReductionOpT =
3595-
typename std::conditional<std::is_same_v<resTy, bool>,
3596-
sycl::logical_or<resTy>,
3597-
sycl::plus<resTy>>::type;
3595+
using ReductionOpT = std::conditional_t<
3596+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3597+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3598+
sycl::plus<resTy>>>;
35983599
constexpr resTy identity_val =
3599-
sycl::known_identity<ReductionOpT, resTy>::value;
3600+
su_ns::Identity<ReductionOpT, resTy>::value;
36003601

36013602
std::size_t iter_nelems = n * m;
36023603
std::size_t reduction_nelems =
@@ -3745,12 +3746,12 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q,
37453746
lhs_indexer, rhs_indexer, res_indexer, depends);
37463747
}
37473748
else {
3748-
using ReductionOpT =
3749-
typename std::conditional<std::is_same_v<resTy, bool>,
3750-
sycl::logical_or<resTy>,
3751-
sycl::plus<resTy>>::type;
3749+
using ReductionOpT = std::conditional_t<
3750+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3751+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3752+
sycl::plus<resTy>>>;
37523753
constexpr resTy identity_val =
3753-
sycl::known_identity<ReductionOpT, resTy>::value;
3754+
su_ns::Identity<ReductionOpT, resTy>::value;
37543755

37553756
std::size_t iter_nelems = n * m;
37563757
std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k;
@@ -3979,12 +3980,12 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
39793980
res_indexer, depends);
39803981
}
39813982
else {
3982-
using ReductionOpT =
3983-
typename std::conditional<std::is_same_v<resTy, bool>,
3984-
sycl::logical_or<resTy>,
3985-
sycl::plus<resTy>>::type;
3983+
using ReductionOpT = std::conditional_t<
3984+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3985+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3986+
sycl::plus<resTy>>>;
39863987
constexpr resTy identity_val =
3987-
sycl::known_identity<ReductionOpT, resTy>::value;
3988+
su_ns::Identity<ReductionOpT, resTy>::value;
39883989

39893990
std::size_t iter_nelems = n * m;
39903991
std::size_t reduction_nelems =
@@ -4118,12 +4119,12 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q,
41184119
lhs_indexer, rhs_indexer, res_indexer, depends);
41194120
}
41204121
else {
4121-
using ReductionOpT =
4122-
typename std::conditional<std::is_same_v<resTy, bool>,
4123-
sycl::logical_or<resTy>,
4124-
sycl::plus<resTy>>::type;
4122+
using ReductionOpT = std::conditional_t<
4123+
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
4124+
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
4125+
sycl::plus<resTy>>>;
41254126
constexpr resTy identity_val =
4126-
sycl::known_identity<ReductionOpT, resTy>::value;
4127+
su_ns::Identity<ReductionOpT, resTy>::value;
41274128

41284129
std::size_t iter_nelems = n * m;
41294130
std::size_t reduction_nelems = (k + wi_delta_k - 1) / wi_delta_k;

0 commit comments

Comments
 (0)