diff --git a/include/matx/transforms/fft/fft_common.h b/include/matx/transforms/fft/fft_common.h index 6ad1a4367..fe5641a52 100644 --- a/include/matx/transforms/fft/fft_common.h +++ b/include/matx/transforms/fft/fft_common.h @@ -197,16 +197,15 @@ namespace detail { if constexpr (is_complex_half_v) { return FFTType::C2C; } - else if constexpr (is_half_v) { + else if constexpr (is_half_v || is_matx_half_v) { return FFTType::R2C; } } - else if constexpr (is_half_v && is_complex_half_v) { + else if constexpr ((is_half_v || is_matx_half_v) && is_complex_half_v) { return FFTType::C2R; } - //else { - return FFTType::C2C; - //} + + return FFTType::C2C; } } diff --git a/include/matx/transforms/fft/fft_cuda.h b/include/matx/transforms/fft/fft_cuda.h index f4793b07e..aa87639b0 100644 --- a/include/matx/transforms/fft/fft_cuda.h +++ b/include/matx/transforms/fft/fft_cuda.h @@ -169,9 +169,7 @@ template class matxCUDAFFTPlan_t params.irank = i.Rank(); params.orank = o.Rank(); - params.transform_type = DeduceFFTTransformType(); - params.input_type = matxCUDAFFTPlan_t::GetInputType(); params.output_type = matxCUDAFFTPlan_t::GetOutputType(); params.exec_type = matxCUDAFFTPlan_t::GetExecType(); @@ -406,7 +404,7 @@ matxCUDAFFTPlan1D_t(OutTensorType &o, const InTensorType &i, cudaStream_t stream } else if (this->params_.transform_type == FFTType::R2C || this->params_.transform_type == FFTType::D2Z) { - if (is_cuda_complex_v || !is_cuda_complex_v) { + if (is_complex_v || !is_complex_v) { MATX_THROW(matxInvalidType, "FFT types inconsistent with R2C/D2Z transform"); } if (this->params_.n[0] != i.Size(InTensorType::Rank()-1) || @@ -534,14 +532,14 @@ class matxCUDAFFTPlan2D_t : public matxCUDAFFTPlan_tparams_.transform_type == FFTType::Z2D) { MATX_ASSERT((o.Size(RANK-2) * (o.Size(RANK-1) / 2 + 1)) == i.Size(RANK-1) * i.Size(RANK-2), matxInvalidSize); - MATX_ASSERT(!is_cuda_complex_v && is_cuda_complex_v, + MATX_ASSERT(!is_complex_v && is_complex_v, matxInvalidType); } else if (this->params_.transform_type == FFTType::R2C || this->params_.transform_type == FFTType::D2Z) { MATX_ASSERT(o.Size(RANK-1) * o.Size(RANK-2) == (i.Size(RANK-2) * (i.Size(RANK-1) / 2 + 1)), matxInvalidSize); - MATX_ASSERT(!is_cuda_complex_v && is_cuda_complex_v, + MATX_ASSERT(!is_complex_v && is_complex_v, matxInvalidType); } else { diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu index d91370f4a..b4bb32ed2 100644 --- a/test/00_transform/FFT.cu +++ b/test/00_transform/FFT.cu @@ -604,45 +604,55 @@ TYPED_TEST(FFTTestComplexTypes, IFFT1D1024PadC2C) MATX_EXIT_HANDLER(); } -TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024R2C) +TYPED_TEST(FFTTestComplexTypes, FFT1D1024R2C) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - const index_t fft_dim = 1024; - using rtype = typename TestType::value_type; - this->pb->template InitAndRunTVGenerator( - "00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim}); + if constexpr (!detail::CheckFFTSupport()) { + GTEST_SKIP(); + } else { + const index_t fft_dim = 1024; + using rtype = typename TestType::value_type; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim}); - tensor_t av{{fft_dim}}; - tensor_t avo{{fft_dim / 2 + 1}}; - this->pb->NumpyToTensorView(av, "a_in"); + tensor_t av{{fft_dim}}; + tensor_t avo{{fft_dim / 2 + 1}}; + this->pb->NumpyToTensorView(av, "a_in"); - (avo = fft(av)).run(this->exec); - this->exec.sync(); + (avo = fft(av)).run(this->exec); + this->exec.sync(); - MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + } MATX_EXIT_HANDLER(); } -TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024PadR2C) +TYPED_TEST(FFTTestComplexTypes, FFT1D1024PadR2C) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - const index_t fft_dim = 4; - using rtype = typename TestType::value_type; - this->pb->template InitAndRunTVGenerator( - "00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim*2}); + if constexpr (!detail::CheckFFTSupport()) { + GTEST_SKIP(); + } else { + const index_t fft_dim = 4; + using rtype = typename TestType::value_type; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "rfft_1d", {fft_dim, fft_dim*2}); - tensor_t av{{fft_dim}}; - tensor_t avo{{fft_dim + 1}}; - this->pb->NumpyToTensorView(av, "a_in"); + tensor_t av{{fft_dim}}; + tensor_t avo{{fft_dim + 1}}; + this->pb->NumpyToTensorView(av, "a_in"); - (avo = fft(av, fft_dim*2)).run(this->exec); - this->exec.sync(); + (avo = fft(av, fft_dim*2)).run(this->exec); + this->exec.sync(); - MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + } MATX_EXIT_HANDLER(); } @@ -710,19 +720,25 @@ TYPED_TEST(FFTTestComplexNonHalfTypesAllExecs, FFT1D1024PadBatchedR2C) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; - const index_t fft_dim = 4; - using rtype = typename TestType::value_type; - this->pb->template InitAndRunTVGenerator( - "00_transforms", "fft_operators", "rfft_1d_batched", {fft_dim, fft_dim, fft_dim*2}); + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - tensor_t av{{fft_dim, fft_dim}}; - tensor_t avo{{fft_dim, fft_dim + 1}}; - this->pb->NumpyToTensorView(av, "a_in"); + if constexpr (!detail::CheckFFTSupport()) { + GTEST_SKIP(); + } else { + const index_t fft_dim = 4; + using rtype = typename TestType::value_type; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "rfft_1d_batched", {fft_dim, fft_dim, fft_dim*2}); - (avo = fft(av, fft_dim*2)).run(this->exec); - this->exec.sync(); + tensor_t av{{fft_dim, fft_dim}}; + tensor_t avo{{fft_dim, fft_dim + 1}}; + this->pb->NumpyToTensorView(av, "a_in"); - MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + (avo = fft(av, fft_dim*2)).run(this->exec); + this->exec.sync(); + + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + } MATX_EXIT_HANDLER(); } @@ -875,43 +891,55 @@ TYPED_TEST(FFTTestComplexTypes, IFFT2D16x32C2C) MATX_EXIT_HANDLER(); } -TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16R2C) +TYPED_TEST(FFTTestComplexTypes, FFT2D16R2C) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; - const index_t fft_dim = 16; - using rtype = typename TestType::value_type; - this->pb->template InitAndRunTVGenerator( - "00_transforms", "fft_operators", "rfft_2d", {fft_dim, fft_dim}); + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - tensor_t av{{fft_dim, fft_dim}}; - tensor_t avo{{fft_dim, fft_dim / 2 + 1}}; - this->pb->NumpyToTensorView(av, "a_in"); + if constexpr (!detail::CheckFFTSupport()) { + GTEST_SKIP(); + } else { + const index_t fft_dim = 16; + using rtype = typename TestType::value_type; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "rfft_2d", {fft_dim, fft_dim}); - (avo = fft2(av)).run(this->exec); - this->exec.sync(); + tensor_t av{{fft_dim, fft_dim}}; + tensor_t avo{{fft_dim, fft_dim / 2 + 1}}; + this->pb->NumpyToTensorView(av, "a_in"); - MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + (avo = fft2(av)).run(this->exec); + this->exec.sync(); + + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + } MATX_EXIT_HANDLER(); } -TYPED_TEST(FFTTestComplexNonHalfTypes, FFT2D16x32R2C) +TYPED_TEST(FFTTestComplexTypes, FFT2D16x32R2C) { MATX_ENTER_HANDLER(); using TestType = cuda::std::tuple_element_t<0, TypeParam>; - const index_t fft_dim[] = {16, 32}; - using rtype = typename TestType::value_type; - this->pb->template InitAndRunTVGenerator( - "00_transforms", "fft_operators", "rfft_2d", {fft_dim[0], fft_dim[1]}); + using ExecType = cuda::std::tuple_element_t<1, TypeParam>; - tensor_t av{{fft_dim[0], fft_dim[1]}}; - tensor_t avo{{fft_dim[0], fft_dim[1] / 2 + 1}}; - this->pb->NumpyToTensorView(av, "a_in"); + if constexpr (!detail::CheckFFTSupport()) { + GTEST_SKIP(); + } else { + const index_t fft_dim[] = {16, 32}; + using rtype = typename TestType::value_type; + this->pb->template InitAndRunTVGenerator( + "00_transforms", "fft_operators", "rfft_2d", {fft_dim[0], fft_dim[1]}); - (avo = fft2(av)).run(this->exec); - this->exec.sync(); + tensor_t av{{fft_dim[0], fft_dim[1]}}; + tensor_t avo{{fft_dim[0], fft_dim[1] / 2 + 1}}; + this->pb->NumpyToTensorView(av, "a_in"); - MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + (avo = fft2(av)).run(this->exec); + this->exec.sync(); + + MATX_TEST_ASSERT_COMPARE(this->pb, avo, "a_out", this->thresh); + } MATX_EXIT_HANDLER(); }