Skip to content

Commit cedec84

Browse files
committed
Refactor operators when dispatching to tree reductions
1 parent fe12ba1 commit cedec84

File tree

4 files changed

+53
-68
lines changed

4 files changed

+53
-68
lines changed

dpctl/tensor/libtensor/include/kernels/linalg_functions/dot_product.hpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,17 @@ namespace su_ns = dpctl::tensor::sycl_utils;
5555
namespace tu_ns = dpctl::tensor::type_utils;
5656
namespace exprm_ns = sycl::ext::oneapi::experimental;
5757

58+
namespace detail
59+
{
60+
61+
// Selects the binary operator type used to combine values of type T:
//   - T is bool                      -> sycl::logical_or<T>
//   - tu_ns::is_complex_v<T> is true -> su_ns::Plus<T>
//   - otherwise                      -> sycl::plus<T>
// Used as ReductionOpT in dot_product_tree_impl and
// dot_product_contig_tree_impl, replacing the nested std::conditional_t
// that was previously spelled out at each use.
// NOTE(review): gemm.hpp (gemm_detail) and sum.cpp declare an alias with the
// same selection logic; the three copies must be kept in sync by hand.
template <typename T>
62+
using SumTempsOpT = std::conditional_t<
63+
std::is_same_v<T, bool>,
64+
sycl::logical_or<T>,
65+
std::conditional_t<tu_ns::is_complex_v<T>, su_ns::Plus<T>, sycl::plus<T>>>;
66+
67+
} // namespace detail
68+
5869
template <typename lhsT,
5970
typename rhsT,
6071
typename outT,
@@ -758,7 +769,7 @@ struct DotProductNoAtomicFunctor
758769
using RedOpT = std::conditional_t<
759770
std::is_same_v<outT, bool>, sycl::logical_or<outT>,
760771
std::conditional_t<tu_ns::is_complex_v<outT>, su_ns::Plus<outT>,
761-
sycl::plus<outT>>>;
772+
sycl::plus<outT>>>;
762773
outT red_val_over_wg = sycl::reduce_over_group(
763774
work_group, local_red_val, outT(0), RedOpT());
764775

@@ -1010,10 +1021,7 @@ sycl::event dot_product_tree_impl(sycl::queue &exec_q,
10101021
// prevents running out of resources on CPU
10111022
std::size_t max_wg = reduction_detail::get_work_group_size(d);
10121023

1013-
using ReductionOpT = std::conditional_t<
1014-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
1015-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
1016-
sycl::plus<resTy>>>;
1024+
using ReductionOpT = detail::SumTempsOpT<resTy>;
10171025

10181026
std::size_t reductions_per_wi(preferred_reductions_per_wi);
10191027
if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {
@@ -1254,10 +1262,7 @@ dot_product_contig_tree_impl(sycl::queue &exec_q,
12541262
// prevents running out of resources on CPU
12551263
std::size_t max_wg = reduction_detail::get_work_group_size(d);
12561264

1257-
using ReductionOpT = std::conditional_t<
1258-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
1259-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
1260-
sycl::plus<resTy>>>;
1265+
using ReductionOpT = detail::SumTempsOpT<resTy>;
12611266

12621267
std::size_t reductions_per_wi(preferred_reductions_per_wi);
12631268
if (reduction_nelems <= preferred_reductions_per_wi * max_wg) {

dpctl/tensor/libtensor/include/kernels/linalg_functions/gemm.hpp

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1795,6 +1795,17 @@ sycl::event gemm_batch_contig_impl(sycl::queue &exec_q,
17951795

17961796
// ========== Gemm Tree
17971797

1798+
namespace gemm_detail
1799+
{
1800+
1801+
// Selects the binary operator type used to combine values of type T:
//   - T is bool                      -> sycl::logical_or<T>
//   - tu_ns::is_complex_v<T> is true -> su_ns::Plus<T>
//   - otherwise                      -> sycl::plus<T>
// Used as ReductionOpT in the else-branches of the gemm tree implementations
// (gemm_*tree_k_impl / gemm_*tree_nm_impl), where it is also passed to
// su_ns::Identity<ReductionOpT, resTy> to obtain identity_val.
// NOTE(review): dot_product.hpp (namespace detail) and sum.cpp declare an
// alias with the same selection logic; the copies must be kept in sync by hand.
template <typename T>
1802+
using SumTempsOpT = std::conditional_t<
1803+
std::is_same_v<T, bool>,
1804+
sycl::logical_or<T>,
1805+
std::conditional_t<tu_ns::is_complex_v<T>, su_ns::Plus<T>, sycl::plus<T>>>;
1806+
1807+
} // namespace gemm_detail
1808+
17981809
template <typename lhsT,
17991810
typename rhsT,
18001811
typename resT,
@@ -2368,10 +2379,7 @@ gemm_batch_tree_k_impl(sycl::queue &exec_q,
23682379
depends);
23692380
}
23702381
else {
2371-
using ReductionOpT = std::conditional_t<
2372-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
2373-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
2374-
sycl::plus<resTy>>>;
2382+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
23752383
constexpr resTy identity_val =
23762384
su_ns::Identity<ReductionOpT, resTy>::value;
23772385

@@ -2664,10 +2672,7 @@ gemm_batch_tree_nm_impl(sycl::queue &exec_q,
26642672
lhs_indexer, rhs_indexer, res_indexer, depends);
26652673
}
26662674
else {
2667-
using ReductionOpT = std::conditional_t<
2668-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
2669-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
2670-
sycl::plus<resTy>>>;
2675+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
26712676
constexpr resTy identity_val =
26722677
su_ns::Identity<ReductionOpT, resTy>::value;
26732678
std::size_t iter_nelems = batch_nelems * n * m;
@@ -3035,10 +3040,7 @@ gemm_batch_contig_tree_k_impl(sycl::queue &exec_q,
30353040
depends);
30363041
}
30373042
else {
3038-
using ReductionOpT = std::conditional_t<
3039-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3040-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3041-
sycl::plus<resTy>>>;
3043+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
30423044
constexpr resTy identity_val =
30433045
su_ns::Identity<ReductionOpT, resTy>::value;
30443046

@@ -3223,10 +3225,7 @@ gemm_batch_contig_tree_nm_impl(sycl::queue &exec_q,
32233225
lhs_indexer, rhs_indexer, res_indexer, depends);
32243226
}
32253227
else {
3226-
using ReductionOpT = std::conditional_t<
3227-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3228-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3229-
sycl::plus<resTy>>>;
3228+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
32303229
constexpr resTy identity_val =
32313230
su_ns::Identity<ReductionOpT, resTy>::value;
32323231
std::size_t iter_nelems = batch_nelems * n * m;
@@ -3592,10 +3591,7 @@ sycl::event gemm_tree_k_impl(sycl::queue &exec_q,
35923591
res_indexer, depends);
35933592
}
35943593
else {
3595-
using ReductionOpT = std::conditional_t<
3596-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3597-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3598-
sycl::plus<resTy>>>;
3594+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
35993595
constexpr resTy identity_val =
36003596
su_ns::Identity<ReductionOpT, resTy>::value;
36013597

@@ -3746,10 +3742,7 @@ sycl::event gemm_tree_nm_impl(sycl::queue &exec_q,
37463742
lhs_indexer, rhs_indexer, res_indexer, depends);
37473743
}
37483744
else {
3749-
using ReductionOpT = std::conditional_t<
3750-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3751-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3752-
sycl::plus<resTy>>>;
3745+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
37533746
constexpr resTy identity_val =
37543747
su_ns::Identity<ReductionOpT, resTy>::value;
37553748

@@ -3980,10 +3973,7 @@ sycl::event gemm_contig_tree_k_impl(sycl::queue &exec_q,
39803973
res_indexer, depends);
39813974
}
39823975
else {
3983-
using ReductionOpT = std::conditional_t<
3984-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
3985-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
3986-
sycl::plus<resTy>>>;
3976+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
39873977
constexpr resTy identity_val =
39883978
su_ns::Identity<ReductionOpT, resTy>::value;
39893979

@@ -4119,10 +4109,7 @@ sycl::event gemm_contig_tree_nm_impl(sycl::queue &exec_q,
41194109
lhs_indexer, rhs_indexer, res_indexer, depends);
41204110
}
41214111
else {
4122-
using ReductionOpT = std::conditional_t<
4123-
std::is_same_v<resTy, bool>, sycl::logical_or<resTy>,
4124-
std::conditional_t<tu_ns::is_complex_v<resTy>, su_ns::Plus<resTy>,
4125-
sycl::plus<resTy>>>;
4112+
using ReductionOpT = gemm_detail::SumTempsOpT<resTy>;
41264113
constexpr resTy identity_val =
41274114
su_ns::Identity<ReductionOpT, resTy>::value;
41284115

dpctl/tensor/libtensor/source/reductions/prod.cpp

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,14 @@ struct TypePairSupportDataForProductReductionTemps
233233
td_ns::NotDefinedEntry>::is_defined;
234234
};
235235

236+
// Selects the binary operator type used to combine values of type T in the
// product reductions that go through the "temps" implementations:
//   - T is bool                      -> sycl::logical_and<T>
//   - tu_ns::is_complex_v<T> is true -> su_ns::Multiplies<T>
//   - otherwise                      -> sycl::multiplies<T>
// Used as ReductionOpT by ProductOverAxisTempsStridedFactory,
// ProductOverAxis1TempsContigFactory and ProductOverAxis0TempsContigFactory.
template <typename T>
237+
using ProdTempsOpT =
238+
std::conditional_t<std::is_same_v<T, bool>,
239+
sycl::logical_and<T>,
240+
std::conditional_t<tu_ns::is_complex_v<T>,
241+
su_ns::Multiplies<T>,
242+
sycl::multiplies<T>>>;
243+
236244
template <typename fnT, typename srcTy, typename dstTy>
237245
struct ProductOverAxisAtomicStridedFactory
238246
{
@@ -260,11 +268,7 @@ struct ProductOverAxisTempsStridedFactory
260268
if constexpr (TypePairSupportDataForProductReductionTemps<
261269
srcTy, dstTy>::is_defined)
262270
{
263-
using ReductionOpT = std::conditional_t<
264-
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
265-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
266-
su_ns::Multiplies<dstTy>,
267-
sycl::multiplies<dstTy>>>;
271+
using ReductionOpT = ProdTempsOpT<dstTy>;
268272
return dpctl::tensor::kernels::
269273
reduction_over_group_temps_strided_impl<srcTy, dstTy,
270274
ReductionOpT>;
@@ -321,11 +325,7 @@ struct ProductOverAxis1TempsContigFactory
321325
if constexpr (TypePairSupportDataForProductReductionTemps<
322326
srcTy, dstTy>::is_defined)
323327
{
324-
using ReductionOpT = std::conditional_t<
325-
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
326-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
327-
su_ns::Multiplies<dstTy>,
328-
sycl::multiplies<dstTy>>>;
328+
using ReductionOpT = ProdTempsOpT<dstTy>;
329329
return dpctl::tensor::kernels::
330330
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
331331
ReductionOpT>;
@@ -344,11 +344,7 @@ struct ProductOverAxis0TempsContigFactory
344344
if constexpr (TypePairSupportDataForProductReductionTemps<
345345
srcTy, dstTy>::is_defined)
346346
{
347-
using ReductionOpT = std::conditional_t<
348-
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
349-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
350-
su_ns::Multiplies<dstTy>,
351-
sycl::multiplies<dstTy>>>;
347+
using ReductionOpT = ProdTempsOpT<dstTy>;
352348
return dpctl::tensor::kernels::
353349
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
354350
ReductionOpT>;

dpctl/tensor/libtensor/source/reductions/sum.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,12 @@ struct TypePairSupportDataForSumReductionTemps
233233
td_ns::NotDefinedEntry>::is_defined;
234234
};
235235

236+
// Selects the binary operator type used to combine values of type T in the
// sum reductions that go through the "temps" implementations:
//   - T is bool                      -> sycl::logical_or<T>
//   - tu_ns::is_complex_v<T> is true -> su_ns::Plus<T>
//   - otherwise                      -> sycl::plus<T>
// Used as ReductionOpT by SumOverAxisTempsStridedFactory,
// SumOverAxis1TempsContigFactory and SumOverAxis0TempsContigFactory.
// NOTE(review): dot_product.hpp and gemm.hpp declare aliases with the same
// selection logic in their own detail namespaces; keep them in sync.
template <typename T>
237+
using SumTempsOpT = std::conditional_t<
238+
std::is_same_v<T, bool>,
239+
sycl::logical_or<T>,
240+
std::conditional_t<tu_ns::is_complex_v<T>, su_ns::Plus<T>, sycl::plus<T>>>;
241+
236242
template <typename fnT, typename srcTy, typename dstTy>
237243
struct SumOverAxisAtomicStridedFactory
238244
{
@@ -260,10 +266,7 @@ struct SumOverAxisTempsStridedFactory
260266
if constexpr (TypePairSupportDataForSumReductionTemps<
261267
srcTy, dstTy>::is_defined)
262268
{
263-
using ReductionOpT = std::conditional_t<
264-
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
265-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
266-
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
269+
using ReductionOpT = SumTempsOpT<dstTy>;
267270
return dpctl::tensor::kernels::
268271
reduction_over_group_temps_strided_impl<srcTy, dstTy,
269272
ReductionOpT>;
@@ -320,10 +323,7 @@ struct SumOverAxis1TempsContigFactory
320323
if constexpr (TypePairSupportDataForSumReductionTemps<
321324
srcTy, dstTy>::is_defined)
322325
{
323-
using ReductionOpT = std::conditional_t<
324-
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
325-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
326-
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
326+
using ReductionOpT = SumTempsOpT<dstTy>;
327327
return dpctl::tensor::kernels::
328328
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
329329
ReductionOpT>;
@@ -342,10 +342,7 @@ struct SumOverAxis0TempsContigFactory
342342
if constexpr (TypePairSupportDataForSumReductionTemps<
343343
srcTy, dstTy>::is_defined)
344344
{
345-
using ReductionOpT = std::conditional_t<
346-
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
347-
std::conditional_t<tu_ns::is_complex_v<dstTy>,
348-
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
345+
using ReductionOpT = SumTempsOpT<dstTy>;
349346
return dpctl::tensor::kernels::
350347
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
351348
ReductionOpT>;

0 commit comments

Comments
 (0)