Skip to content

Commit c4eda1a

Browse files
authored
[BACKPORT]: Add missing overloads for thrust::pow (#1223)
* Add missing overloads for thrust::pow Also add proper type checks for all of those overloads so that we can ensure that we are * Properly constraint the pow overloads and stop pulling in `cuda::std::pow`
1 parent b28d445 commit c4eda1a

File tree

2 files changed

+48
-7
lines changed

2 files changed

+48
-7
lines changed

thrust/testing/complex.cu

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,17 +449,18 @@ struct TestComplexBasicArithmetic
449449
// Test the basic arithmetic functions against std
450450

451451
ASSERT_ALMOST_EQUAL(thrust::abs(a), std::abs(b));
452-
453452
ASSERT_ALMOST_EQUAL(thrust::arg(a), std::arg(b));
454-
455453
ASSERT_ALMOST_EQUAL(thrust::norm(a), std::norm(b));
456454

457455
ASSERT_EQUAL(thrust::conj(a), std::conj(b));
456+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::conj(a))>::value, "");
458457

459458
ASSERT_ALMOST_EQUAL(thrust::polar(data[0], data[1]), std::polar(data[0], data[1]));
459+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::polar(data[0], data[1]))>::value, "");
460460

461461
// random_samples does not seem to produce infinities so proj(z) == z
462462
ASSERT_EQUAL(thrust::proj(a), a);
463+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::proj(a))>::value, "");
463464
}
464465
};
465466
SimpleUnitTest<TestComplexBasicArithmetic, FloatingPointTypes> TestComplexBasicArithmeticInstance;
@@ -556,6 +557,9 @@ struct TestComplexExponentialFunctions
556557
ASSERT_ALMOST_EQUAL(thrust::exp(a), std::exp(b));
557558
ASSERT_ALMOST_EQUAL(thrust::log(a), std::log(b));
558559
ASSERT_ALMOST_EQUAL(thrust::log10(a), std::log10(b));
560+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::exp(a))>::value, "");
561+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log(a))>::value, "");
562+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::log10(a))>::value, "");
559563
}
560564
};
561565
SimpleUnitTest<TestComplexExponentialFunctions, FloatingPointTypes>
@@ -575,16 +579,24 @@ struct TestComplexPowerFunctions
575579
const std::complex<T> b_std(b_thrust);
576580

577581
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
582+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
578583
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
584+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
579585
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
586+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");
587+
588+
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, 4), std::pow(a_std, 4));
589+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::pow(a_thrust, 4))>::value, "");
580590

581591
ASSERT_ALMOST_EQUAL(thrust::sqrt(a_thrust), std::sqrt(a_std));
592+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sqrt(a_thrust))>::value, "");
582593
}
583594

584595
// Test power functions with promoted types.
585596
{
586597
using T0 = T;
587598
using T1 = other_floating_point_type_t<T0>;
599+
using promoted = typename thrust::detail::promoted_numerical_type<T0, T1>::type;
588600

589601
thrust::host_vector<T0> data = unittest::random_samples<T0>(4);
590602

@@ -594,11 +606,17 @@ struct TestComplexPowerFunctions
594606
const std::complex<T0> b_std(data[2], data[3]);
595607

596608
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust), std::pow(a_std, b_std));
609+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust))>::value, "");
597610
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust), std::pow(b_std, a_std));
611+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust))>::value, "");
598612
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust, b_thrust.real()), std::pow(a_std, b_std.real()));
613+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust, b_thrust.real()))>::value, "");
599614
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust, a_thrust.real()), std::pow(b_std, a_std.real()));
615+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust, a_thrust.real()))>::value, "");
600616
ASSERT_ALMOST_EQUAL(thrust::pow(a_thrust.real(), b_thrust), std::pow(a_std.real(), b_std));
617+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(a_thrust.real(), b_thrust))>::value, "");
601618
ASSERT_ALMOST_EQUAL(thrust::pow(b_thrust.real(), a_thrust), std::pow(b_std.real(), a_std));
619+
static_assert(cuda::std::is_same<thrust::complex<promoted>, decltype(thrust::pow(b_thrust.real(), a_thrust))>::value, "");
602620
}
603621
}
604622
};
@@ -617,20 +635,32 @@ struct TestComplexTrigonometricFunctions
617635
ASSERT_ALMOST_EQUAL(thrust::cos(a), std::cos(c));
618636
ASSERT_ALMOST_EQUAL(thrust::sin(a), std::sin(c));
619637
ASSERT_ALMOST_EQUAL(thrust::tan(a), std::tan(c));
638+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cos(a))>::value, "");
639+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sin(a))>::value, "");
640+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tan(a))>::value, "");
620641

621642
ASSERT_ALMOST_EQUAL(thrust::cosh(a), std::cosh(c));
622643
ASSERT_ALMOST_EQUAL(thrust::sinh(a), std::sinh(c));
623644
ASSERT_ALMOST_EQUAL(thrust::tanh(a), std::tanh(c));
645+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::cosh(a))>::value, "");
646+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::sinh(a))>::value, "");
647+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::tanh(a))>::value, "");
624648

625649
#if THRUST_CPP_DIALECT >= 2011
626650

627651
ASSERT_ALMOST_EQUAL(thrust::acos(a), std::acos(c));
628652
ASSERT_ALMOST_EQUAL(thrust::asin(a), std::asin(c));
629653
ASSERT_ALMOST_EQUAL(thrust::atan(a), std::atan(c));
654+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acos(a))>::value, "");
655+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asin(a))>::value, "");
656+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atan(a))>::value, "");
630657

631658
ASSERT_ALMOST_EQUAL(thrust::acosh(a), std::acosh(c));
632659
ASSERT_ALMOST_EQUAL(thrust::asinh(a), std::asinh(c));
633660
ASSERT_ALMOST_EQUAL(thrust::atanh(a), std::atanh(c));
661+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::acosh(a))>::value, "");
662+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::asinh(a))>::value, "");
663+
static_assert(cuda::std::is_same<thrust::complex<T>, decltype(thrust::atanh(a))>::value, "");
634664

635665
#endif
636666
}

thrust/thrust/complex.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,8 @@ using ::cuda::std::proj;
474474
using ::cuda::std::exp;
475475
using ::cuda::std::log;
476476
using ::cuda::std::log10;
477-
using ::cuda::std::pow;
477+
// pow always returns a complex.
478+
// using ::cuda::std::pow;
478479
using ::cuda::std::sqrt;
479480

480481
using ::cuda::std::acos;
@@ -516,15 +517,25 @@ template<class T>
516517
__host__ __device__ complex<T> log10(const complex<T>& c) {
517518
return static_cast<complex<T>>(::cuda::std::log10(c));
518519
}
519-
template<class T>
520-
__host__ __device__ complex<T> pow(const complex<T>& c) {
521-
return static_cast<complex<T>>(::cuda::std::pow(c));
520+
template<class T0, class T1>
521+
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
522+
pow(const complex<T0>& x, const complex<T1>& y) {
523+
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
524+
}
525+
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T1>::value, int> = 0>
526+
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
527+
pow(const complex<T0>& x, const T1& y) {
528+
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
529+
}
530+
template<class T0, class T1, ::cuda::std::__enable_if_t<::cuda::std::is_arithmetic<T0>::value, int> = 0>
531+
__host__ __device__ complex<typename detail::promoted_numerical_type<T0, T1>::type>
532+
pow(const T0& x, const complex<T1>& y) {
533+
return static_cast<complex<typename detail::promoted_numerical_type<T0, T1>::type>>(::cuda::std::pow(x, y));
522534
}
523535
template<class T>
524536
__host__ __device__ complex<T> sqrt(const complex<T>& c) {
525537
return static_cast<complex<T>>(::cuda::std::sqrt(c));
526538
}
527-
528539
template<class T>
529540
__host__ __device__ complex<T> acos(const complex<T>& c) {
530541
return static_cast<complex<T>>(::cuda::std::acos(c));

0 commit comments

Comments
 (0)