Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support half precision R2C transforms #796

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions include/matx/transforms/fft/fft_common.h
Original file line number Diff line number Diff line change
@@ -197,16 +197,15 @@ namespace detail {
if constexpr (is_complex_half_v<T2>) {
return FFTType::C2C;
}
else if constexpr (is_half_v<T2>) {
else if constexpr (is_half_v<T2> || is_matx_half_v<T2>) {
return FFTType::R2C;
}
}
else if constexpr (is_half_v<T1> && is_complex_half_v<T2>) {
else if constexpr ((is_half_v<T1> || is_matx_half_v<T1>) && is_complex_half_v<T2>) {
return FFTType::C2R;
}
//else {
return FFTType::C2C;
//}

return FFTType::C2C;
}
}

8 changes: 3 additions & 5 deletions include/matx/transforms/fft/fft_cuda.h
Original file line number Diff line number Diff line change
@@ -169,9 +169,7 @@ template <typename OutTensorType, typename InTensorType> class matxCUDAFFTPlan_t

params.irank = i.Rank();
params.orank = o.Rank();

params.transform_type = DeduceFFTTransformType<OutTensorType, InTensorType>();

params.input_type = matxCUDAFFTPlan_t<OutTensorType, InTensorType>::GetInputType();
params.output_type = matxCUDAFFTPlan_t<OutTensorType, InTensorType>::GetOutputType();
params.exec_type = matxCUDAFFTPlan_t<OutTensorType, InTensorType>::GetExecType();
@@ -406,7 +404,7 @@ matxCUDAFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream
}
else if (this->params_.transform_type == FFTType::R2C ||
this->params_.transform_type == FFTType::D2Z) {
if (is_cuda_complex_v<T2> || !is_cuda_complex_v<T1>) {
if (is_complex_v<T2> || !is_complex_v<T1>) {
MATX_THROW(matxInvalidType, "FFT types inconsistent with R2C/D2Z transform");
}
if (this->params_.n[0] != i.Size(InTensorType::Rank()-1) ||
@@ -534,14 +532,14 @@ class matxCUDAFFTPlan2D_t : public matxCUDAFFTPlan_t<OutTensorType, InTensorType
this->params_.transform_type == FFTType::Z2D) {
MATX_ASSERT((o.Size(RANK-2) * (o.Size(RANK-1) / 2 + 1)) == i.Size(RANK-1) * i.Size(RANK-2),
matxInvalidSize);
MATX_ASSERT(!is_cuda_complex_v<T1> && is_cuda_complex_v<T2>,
MATX_ASSERT(!is_complex_v<T1> && is_complex_v<T2>,
matxInvalidType);
}
else if (this->params_.transform_type == FFTType::R2C ||
this->params_.transform_type == FFTType::D2Z) {
MATX_ASSERT(o.Size(RANK-1) * o.Size(RANK-2) == (i.Size(RANK-2) * (i.Size(RANK-1) / 2 + 1)),
matxInvalidSize);
MATX_ASSERT(!is_cuda_complex_v<T2> && is_cuda_complex_v<T1>,
MATX_ASSERT(!is_complex_v<T2> && is_complex_v<T1>,
matxInvalidType);
}
else {
136 changes: 82 additions & 54 deletions test/00_transform/FFT.cu
Original file line number Diff line number Diff line change
@@ -604,45 +604,55 @@ TYPED_TEST(FFTTestComplexTypes, IFFT1D1024PadC2C)
MATX_EXIT_HANDLER();
}

TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024R2C)
TYPED_TEST(FFTTestComplexTypes, FFT1D1024R2C)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

const index_t fft_dim = 1024;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim});
if constexpr (!detail::CheckFFTSupport<ExecType, TestType>()) {
GTEST_SKIP();
} else {
const index_t fft_dim = 1024;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim});

tensor_t<typename TestType::value_type, 1> av{{fft_dim}};
tensor_t<TestType, 1> avo{{fft_dim / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");
tensor_t<typename TestType::value_type, 1> av{{fft_dim}};
tensor_t<TestType, 1> avo{{fft_dim / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");

(avo = fft(av)).run(this->exec);
this->exec.sync();
(avo = fft(av)).run(this->exec);
this->exec.sync();

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
}
MATX_EXIT_HANDLER();
}

TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024PadR2C)
TYPED_TEST(FFTTestComplexTypes, FFT1D1024PadR2C)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

const index_t fft_dim = 4;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim*2});
if constexpr (!detail::CheckFFTSupport<ExecType, TestType>()) {
GTEST_SKIP();
} else {
const index_t fft_dim = 4;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim*2});

tensor_t<typename TestType::value_type, 1> av{{fft_dim}};
tensor_t<TestType, 1> avo{{fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");
tensor_t<typename TestType::value_type, 1> av{{fft_dim}};
tensor_t<TestType, 1> avo{{fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");

(avo = fft(av, fft_dim*2)).run(this->exec);
this->exec.sync();
(avo = fft(av, fft_dim*2)).run(this->exec);
this->exec.sync();

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
}
MATX_EXIT_HANDLER();
}

@@ -710,19 +720,25 @@ TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024PadBatchedR2C)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
const index_t fft_dim = 4;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d_batched", {fft_dim, fft_dim, fft_dim*2});
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

tensor_t<typename TestType::value_type, 2> av{{fft_dim, fft_dim}};
tensor_t<TestType, 2> avo{{fft_dim, fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");
if constexpr (!detail::CheckFFTSupport<ExecType, TestType>()) {
GTEST_SKIP();
} else {
const index_t fft_dim = 4;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_1d_batched", {fft_dim, fft_dim, fft_dim*2});

(avo = fft(av, fft_dim*2)).run(this->exec);
this->exec.sync();
tensor_t<typename TestType::value_type, 2> av{{fft_dim, fft_dim}};
tensor_t<TestType, 2> avo{{fft_dim, fft_dim + 1}};
this->pb->NumpyToTensorView(av, "a_in");

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
(avo = fft(av, fft_dim*2)).run(this->exec);
this->exec.sync();

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
}
MATX_EXIT_HANDLER();
}

@@ -875,43 +891,55 @@ TYPED_TEST(FFTTestComplexTypes, IFFT2D16x32C2C)
MATX_EXIT_HANDLER();
}

TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16R2C)
TYPED_TEST(FFTTestComplexTypes, FFT2D16R2C)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
const index_t fft_dim = 16;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_2d", {fft_dim, fft_dim});
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

tensor_t<rtype, 2> av{{fft_dim, fft_dim}};
tensor_t<TestType, 2> avo{{fft_dim, fft_dim / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");
if constexpr (!detail::CheckFFTSupport<ExecType, TestType>()) {
GTEST_SKIP();
} else {
const index_t fft_dim = 16;
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_2d", {fft_dim, fft_dim});

(avo = fft2(av)).run(this->exec);
this->exec.sync();
tensor_t<rtype, 2> av{{fft_dim, fft_dim}};
tensor_t<TestType, 2> avo{{fft_dim, fft_dim / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
(avo = fft2(av)).run(this->exec);
this->exec.sync();

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
}
MATX_EXIT_HANDLER();
}

TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16x32R2C)
TYPED_TEST(FFTTestComplexTypes, FFT2D16x32R2C)
{
MATX_ENTER_HANDLER();
using TestType = cuda::std::tuple_element_t<0, TypeParam>;
const index_t fft_dim[] = {16, 32};
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_2d", {fft_dim[0], fft_dim[1]});
using ExecType = cuda::std::tuple_element_t<1, TypeParam>;

tensor_t<rtype, 2> av{{fft_dim[0], fft_dim[1]}};
tensor_t<TestType, 2> avo{{fft_dim[0], fft_dim[1] / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");
if constexpr (!detail::CheckFFTSupport<ExecType, TestType>()) {
GTEST_SKIP();
} else {
const index_t fft_dim[] = {16, 32};
using rtype = typename TestType::value_type;
this->pb->template InitAndRunTVGenerator<rtype>(
"00_transforms", "fft_operators", "rfft_2d", {fft_dim[0], fft_dim[1]});

(avo = fft2(av)).run(this->exec);
this->exec.sync();
tensor_t<rtype, 2> av{{fft_dim[0], fft_dim[1]}};
tensor_t<TestType, 2> avo{{fft_dim[0], fft_dim[1] / 2 + 1}};
this->pb->NumpyToTensorView(av, "a_in");

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
(avo = fft2(av)).run(this->exec);
this->exec.sync();

MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh);
}
MATX_EXIT_HANDLER();
}