|
32 | 32 |
|
33 | 33 | #pragma once |
34 | 34 |
|
| 35 | +#include "matx/kernels/pwelch.cuh" |
| 36 | + |
35 | 37 | namespace matx |
36 | 38 | { |
37 | | - |
// Selects how the averaged periodogram is scaled before it is written out.
// As consumed by detail::pwelch_kernel in this file:
//   Spectrum     -> sum over batches of |X|^2, divided by the batch count
//   Density      -> the Spectrum value additionally divided by fs
//   Spectrum_dB  -> 10*log10 of the Spectrum value
//   Density_dB   -> 10*log10 of the Density value
38 | | - enum PwelchOutputScaleMode {
39 | | - PwelchOutputScaleMode_Spectrum,
40 | | - PwelchOutputScaleMode_Density,
41 | | - PwelchOutputScaleMode_Spectrum_dB,
42 | | - PwelchOutputScaleMode_Density_dB
43 | | - };
44 | | - |
45 | | - namespace detail { |
// Averages squared magnitudes across the first dimension of t_in and writes one
// scaled value per index of the second dimension into t_out.
//
// Template parameters:
//   OUTPUT_SCALE_MODE  compile-time selector; exactly one scaling branch is instantiated.
//   T_IN               2D operand indexed as t_in(batch, bin) with a Shape() accessor.
//   T_OUT              1D operand indexed as t_out(bin); T_OUT::value_type is the accumulator type.
// Parameters:
//   t_in   input values; Shape()[0] is read as the batch count, Shape()[1] as nfft.
//   t_out  destination, written at indices [0, nfft).
//   fs     divisor used only by the Density and Density_dB modes.
//
// Launch layout: only the .x components of blockIdx/blockDim/threadIdx are used, one
// thread per output index; threads with tid >= nfft do nothing, so the grid must
// supply at least nfft threads for every output to be written.
46 | | - template<PwelchOutputScaleMode OUTPUT_SCALE_MODE, typename T_IN, typename T_OUT>
47 | | - __global__ void pwelch_kernel(const T_IN t_in, T_OUT t_out, typename T_OUT::value_type fs)
48 | | - {
// NOTE(review): blockIdx.x * blockDim.x is evaluated in 32-bit unsigned arithmetic
// and only afterwards converted to index_t — confirm nfft can never approach 2^32.
49 | | - const index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
50 | | - const index_t batches = t_in.Shape()[0];
51 | | - const index_t nfft = t_in.Shape()[1];
52 | | -
53 | | - if (tid < nfft) // bounds guard for the tail of the grid
54 | | - {
55 | | - typename T_OUT::value_type pxx = 0; // running sum of squared magnitudes for this index
56 | | - constexpr typename T_OUT::value_type ten = 10; // dB multiplier in the output value type
57 | | -
// Serial accumulation over batches; cuda::std::norm yields the squared magnitude.
58 | | - for (index_t batch = 0; batch < batches; batch++)
59 | | - {
60 | | - pxx += cuda::std::norm(t_in(batch, tid));
61 | | - }
62 | | -
63 | | - if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum)
64 | | - {
65 | | - t_out(tid) = pxx / batches; // mean over batches
66 | | - }
67 | | - else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density)
68 | | - {
69 | | - t_out(tid) = pxx / (batches * fs); // mean over batches, further divided by fs
70 | | - }
71 | | - else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Spectrum_dB)
72 | | - {
73 | | - pxx /= batches;
74 | | - if (pxx != 0) // avoid log10(0)
75 | | - {
76 | | - t_out(tid) = ten * cuda::std::log10(pxx);
77 | | - }
78 | | - else
79 | | - {
// An exactly-zero value is reported as the most negative finite value of the output type.
80 | | - t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
81 | | - }
82 | | - }
83 | | - else if constexpr (OUTPUT_SCALE_MODE == PwelchOutputScaleMode_Density_dB)
84 | | - {
85 | | - pxx /= (batches * fs);
86 | | - if (pxx != 0) // avoid log10(0)
87 | | - {
88 | | - t_out(tid) = ten * cuda::std::log10(pxx);
89 | | - }
90 | | - else
91 | | - {
// Same zero handling as the Spectrum_dB branch.
92 | | - t_out(tid) = cuda::std::numeric_limits<typename T_OUT::value_type>::lowest();
93 | | - }
94 | | - }
95 | | - }
96 | | - }
97 | | - }; |
98 | | - |
// Declaration of an int with external linkage; its definition and any readers are
// outside the visible lines. NOTE(review): presumably selects a pwelch algorithm
// variant — confirm against the defining translation unit.
99 | | - extern int g_pwelch_alg_mode;
100 | 39 | template <typename PxxType, typename xType, typename wType> |
101 | 40 | __MATX_INLINE__ void pwelch_impl(PxxType Pxx, const xType& x, const wType& w, index_t nperseg, index_t noverlap, index_t nfft, PwelchOutputScaleMode output_scale_mode, typename PxxType::value_type fs, cudaStream_t stream=0) |
102 | | - { |
| 41 | + { |
| 42 | + #ifndef __CUDACC__ |
| 43 | + MATX_THROW(matxNotSupported, "pwelch not supported on host"); |
| 44 | + #else |
103 | 45 | MATX_NVTX_START("", matx::MATX_NVTX_LOG_API) |
104 | 46 |
|
105 | 47 | MATX_ASSERT_STR(Pxx.Rank() == x.Rank(), matxInvalidDim, "pwelch: Pxx rank must be the same as x rank"); |
@@ -141,6 +83,6 @@ namespace matx |
141 | 83 | { |
142 | 84 | detail::pwelch_kernel<PwelchOutputScaleMode_Density_dB><<<bpk, tpb, 0, stream>>>(X_with_overlaps, Pxx, fs); |
143 | 85 | } |
144 | | - } |
145 | | - |
| 86 | + #endif |
| 87 | + } |
146 | 88 | } // end namespace matx |
0 commit comments