Skip to content

Commit c0ca6cb

Browse files
committed
Use specialized functor for multiplying or adding complex inputs
converts to experimental sycl complex values, then performs math operations
1 parent d855c66 commit c0ca6cb

File tree

6 files changed

+127
-58
lines changed

6 files changed

+127
-58
lines changed

dpctl/tensor/libtensor/include/utils/math_utils.hpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,20 @@ template <typename T> T logaddexp(T x, T y)
154154
}
155155
}
156156

157+
template <typename T> T plus_complex(const T &x1, const T &x2)
158+
{
159+
using realT = typename T::value_type;
160+
using sycl_complexT = exprm_ns::complex<realT>;
161+
return T(sycl_complexT(x1) + sycl_complexT(x2));
162+
}
163+
164+
template <typename T> T multiplies_complex(const T &x1, const T &x2)
165+
{
166+
using realT = typename T::value_type;
167+
using sycl_complexT = exprm_ns::complex<realT>;
168+
return T(sycl_complexT(x1) * sycl_complexT(x2));
169+
}
170+
157171
} // namespace math_utils
158172
} // namespace tensor
159173
} // namespace dpctl

dpctl/tensor/libtensor/include/utils/sycl_utils.hpp

Lines changed: 65 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,11 @@ T custom_inclusive_scan_over_group(GroupT &&wg,
298298
return scan_val;
299299
}
300300

301-
// Reduction functors
301+
// Define identities and operator checking structs
302+
303+
template <typename Op, typename T, typename = void> struct GetIdentity
304+
{
305+
};
302306

303307
// Maximum
304308

@@ -324,38 +328,6 @@ template <typename T> struct Maximum
324328
}
325329
};
326330

327-
// Minimum
328-
329-
template <typename T> struct Minimum
330-
{
331-
T operator()(const T &x, const T &y) const
332-
{
333-
if constexpr (detail::IsComplex<T>::value) {
334-
using dpctl::tensor::math_utils::min_complex;
335-
return min_complex<T>(x, y);
336-
}
337-
else if constexpr (std::is_floating_point_v<T> ||
338-
std::is_same_v<T, sycl::half>)
339-
{
340-
return (std::isnan(x) || x < y) ? x : y;
341-
}
342-
else if constexpr (std::is_same_v<T, bool>) {
343-
return x && y;
344-
}
345-
else {
346-
return (x < y) ? x : y;
347-
}
348-
}
349-
};
350-
351-
// Define identities and operator checking structs
352-
353-
template <typename Op, typename T, typename = void> struct GetIdentity
354-
{
355-
};
356-
357-
// Maximum
358-
359331
template <typename T, class Op>
360332
using IsMaximum = std::bool_constant<std::is_same_v<Op, sycl::maximum<T>> ||
361333
std::is_same_v<Op, Maximum<T>>>;
@@ -389,6 +361,28 @@ struct GetIdentity<Op,
389361

390362
// Minimum
391363

364+
template <typename T> struct Minimum
365+
{
366+
T operator()(const T &x, const T &y) const
367+
{
368+
if constexpr (detail::IsComplex<T>::value) {
369+
using dpctl::tensor::math_utils::min_complex;
370+
return min_complex<T>(x, y);
371+
}
372+
else if constexpr (std::is_floating_point_v<T> ||
373+
std::is_same_v<T, sycl::half>)
374+
{
375+
return (std::isnan(x) || x < y) ? x : y;
376+
}
377+
else if constexpr (std::is_same_v<T, bool>) {
378+
return x && y;
379+
}
380+
else {
381+
return (x < y) ? x : y;
382+
}
383+
}
384+
};
385+
392386
template <typename T, class Op>
393387
using IsMinimum = std::bool_constant<std::is_same_v<Op, sycl::minimum<T>> ||
394388
std::is_same_v<Op, Minimum<T>>>;
@@ -422,19 +416,55 @@ struct GetIdentity<Op,
422416

423417
// Plus
424418

419+
template <typename T> struct Plus
420+
{
421+
T operator()(const T &x, const T &y) const
422+
{
423+
if constexpr (detail::IsComplex<T>::value) {
424+
using dpctl::tensor::math_utils::plus_complex;
425+
return plus_complex<T>(x, y);
426+
}
427+
else {
428+
return sycl::plus<T>(x, y);
429+
}
430+
}
431+
};
432+
425433
template <typename T, class Op>
426434
using IsPlus = std::bool_constant<std::is_same_v<Op, sycl::plus<T>> ||
427-
std::is_same_v<Op, std::plus<T>>>;
435+
std::is_same_v<Op, std::plus<T>> ||
436+
std::is_same_v<Op, Plus<T>>>;
428437

429438
template <typename T, class Op>
430439
using IsSyclPlus = std::bool_constant<std::is_same_v<Op, sycl::plus<T>>>;
431440

441+
template <typename Op, typename T>
442+
struct GetIdentity<Op, T, std::enable_if_t<IsPlus<T, Op>::value>>
443+
{
444+
static constexpr T value = static_cast<T>(0);
445+
};
446+
432447
// Multiplies
433448

449+
template <typename T> struct Multiplies
450+
{
451+
T operator()(const T &x, const T &y) const
452+
{
453+
if constexpr (detail::IsComplex<T>::value) {
454+
using dpctl::tensor::math_utils::multiplies_complex;
455+
return multiplies_complex<T>(x, y);
456+
}
457+
else {
458+
return sycl::multiplies<T>(x, y);
459+
}
460+
}
461+
};
462+
434463
template <typename T, class Op>
435464
using IsMultiplies =
436465
std::bool_constant<std::is_same_v<Op, sycl::multiplies<T>> ||
437-
std::is_same_v<Op, std::multiplies<T>>>;
466+
std::is_same_v<Op, std::multiplies<T>> ||
467+
std::is_same_v<Op, Multiplies<T>>>;
438468

439469
template <typename T, class Op>
440470
using IsSyclMultiplies =

dpctl/tensor/libtensor/source/accumulators/cumulative_prod.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ namespace py_internal
4646

4747
namespace su_ns = dpctl::tensor::sycl_utils;
4848
namespace td_ns = dpctl::tensor::type_dispatch;
49+
namespace tu_ns = dpctl::tensor::type_utils;
4950

5051
namespace impl
5152
{
@@ -133,9 +134,12 @@ struct TypePairSupportDataForProdAccumulation
133134
};
134135

135136
template <typename T>
136-
using CumProdScanOpT = std::conditional_t<std::is_same_v<T, bool>,
137-
sycl::logical_and<T>,
138-
sycl::multiplies<T>>;
137+
using CumProdScanOpT =
138+
std::conditional_t<std::is_same_v<T, bool>,
139+
sycl::logical_and<T>,
140+
std::conditional_t<tu_ns::is_complex_v<T>,
141+
su_ns::Multiplies<T>,
142+
sycl::multiplies<T>>>;
139143

140144
template <typename fnT, typename srcTy, typename dstTy>
141145
struct CumProd1DContigFactory

dpctl/tensor/libtensor/source/accumulators/cumulative_sum.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
#include "kernels/accumulators.hpp"
3535
#include "utils/sycl_utils.hpp"
3636
#include "utils/type_dispatch_building.hpp"
37+
#include "utils/type_utils.hpp"
3738

3839
namespace py = pybind11;
3940

@@ -46,6 +47,7 @@ namespace py_internal
4647

4748
namespace su_ns = dpctl::tensor::sycl_utils;
4849
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
4951

5052
namespace impl
5153
{
@@ -133,8 +135,10 @@ struct TypePairSupportDataForSumAccumulation
133135
};
134136

135137
template <typename T>
136-
using CumSumScanOpT = std::
137-
conditional_t<std::is_same_v<T, bool>, sycl::logical_or<T>, sycl::plus<T>>;
138+
using CumSumScanOpT = std::conditional_t<
139+
std::is_same_v<T, bool>,
140+
sycl::logical_or<T>,
141+
std::conditional_t<tu_ns::is_complex_v<T>, su_ns::Plus<T>, sycl::plus<T>>>;
138142

139143
template <typename fnT, typename srcTy, typename dstTy>
140144
struct CumSum1DContigFactory

dpctl/tensor/libtensor/source/reductions/prod.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
#include <vector>
3232

3333
#include "kernels/reductions.hpp"
34+
#include "utils/sycl_utils.hpp"
3435
#include "utils/type_dispatch_building.hpp"
36+
#include "utils/type_utils.hpp"
3537

3638
#include "reduction_atomic_support.hpp"
3739
#include "reduction_over_axis.hpp"
@@ -45,7 +47,9 @@ namespace tensor
4547
namespace py_internal
4648
{
4749

50+
namespace su_ns = dpctl::tensor::sycl_utils;
4851
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace tu_ns = dpctl::tensor::type_utils;
4953

5054
namespace impl
5155
{
@@ -256,9 +260,11 @@ struct ProductOverAxisTempsStridedFactory
256260
if constexpr (TypePairSupportDataForProductReductionTemps<
257261
srcTy, dstTy>::is_defined)
258262
{
259-
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
260-
sycl::logical_and<dstTy>,
261-
sycl::multiplies<dstTy>>;
263+
using ReductionOpT = std::conditional_t<
264+
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
265+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
266+
su_ns::Multiplies<dstTy>,
267+
sycl::multiplies<dstTy>>>;
262268
return dpctl::tensor::kernels::
263269
reduction_over_group_temps_strided_impl<srcTy, dstTy,
264270
ReductionOpT>;
@@ -315,9 +321,11 @@ struct ProductOverAxis1TempsContigFactory
315321
if constexpr (TypePairSupportDataForProductReductionTemps<
316322
srcTy, dstTy>::is_defined)
317323
{
318-
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
319-
sycl::logical_and<dstTy>,
320-
sycl::multiplies<dstTy>>;
324+
using ReductionOpT = std::conditional_t<
325+
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
326+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
327+
su_ns::Multiplies<dstTy>,
328+
sycl::multiplies<dstTy>>>;
321329
return dpctl::tensor::kernels::
322330
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
323331
ReductionOpT>;
@@ -336,9 +344,11 @@ struct ProductOverAxis0TempsContigFactory
336344
if constexpr (TypePairSupportDataForProductReductionTemps<
337345
srcTy, dstTy>::is_defined)
338346
{
339-
using ReductionOpT = std::conditional_t<std::is_same_v<dstTy, bool>,
340-
sycl::logical_and<dstTy>,
341-
sycl::multiplies<dstTy>>;
347+
using ReductionOpT = std::conditional_t<
348+
std::is_same_v<dstTy, bool>, sycl::logical_and<dstTy>,
349+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
350+
su_ns::Multiplies<dstTy>,
351+
sycl::multiplies<dstTy>>>;
342352
return dpctl::tensor::kernels::
343353
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
344354
ReductionOpT>;

dpctl/tensor/libtensor/source/reductions/sum.cpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@
3131
#include <vector>
3232

3333
#include "kernels/reductions.hpp"
34+
#include "utils/sycl_utils.hpp"
3435
#include "utils/type_dispatch_building.hpp"
36+
#include "utils/type_utils.hpp"
3537

3638
#include "reduction_atomic_support.hpp"
3739
#include "reduction_over_axis.hpp"
@@ -45,7 +47,9 @@ namespace tensor
4547
namespace py_internal
4648
{
4749

50+
namespace su_ns = dpctl::tensor::sycl_utils;
4851
namespace td_ns = dpctl::tensor::type_dispatch;
52+
namespace tu_ns = dpctl::tensor::type_utils;
4953

5054
namespace impl
5155
{
@@ -256,9 +260,10 @@ struct SumOverAxisTempsStridedFactory
256260
if constexpr (TypePairSupportDataForSumReductionTemps<
257261
srcTy, dstTy>::is_defined)
258262
{
259-
using ReductionOpT =
260-
std::conditional_t<std::is_same_v<dstTy, bool>,
261-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
263+
using ReductionOpT = std::conditional_t<
264+
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
265+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
266+
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
262267
return dpctl::tensor::kernels::
263268
reduction_over_group_temps_strided_impl<srcTy, dstTy,
264269
ReductionOpT>;
@@ -315,9 +320,10 @@ struct SumOverAxis1TempsContigFactory
315320
if constexpr (TypePairSupportDataForSumReductionTemps<
316321
srcTy, dstTy>::is_defined)
317322
{
318-
using ReductionOpT =
319-
std::conditional_t<std::is_same_v<dstTy, bool>,
320-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
323+
using ReductionOpT = std::conditional_t<
324+
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
325+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
326+
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
321327
return dpctl::tensor::kernels::
322328
reduction_axis1_over_group_temps_contig_impl<srcTy, dstTy,
323329
ReductionOpT>;
@@ -336,9 +342,10 @@ struct SumOverAxis0TempsContigFactory
336342
if constexpr (TypePairSupportDataForSumReductionTemps<
337343
srcTy, dstTy>::is_defined)
338344
{
339-
using ReductionOpT =
340-
std::conditional_t<std::is_same_v<dstTy, bool>,
341-
sycl::logical_or<dstTy>, sycl::plus<dstTy>>;
345+
using ReductionOpT = std::conditional_t<
346+
std::is_same_v<dstTy, bool>, sycl::logical_or<dstTy>,
347+
std::conditional_t<tu_ns::is_complex_v<dstTy>,
348+
su_ns::Plus<dstTy>, sycl::plus<dstTy>>>;
342349
return dpctl::tensor::kernels::
343350
reduction_axis0_over_group_temps_contig_impl<srcTy, dstTy,
344351
ReductionOpT>;

0 commit comments

Comments
 (0)