[BUG] Poor pwelch performance #887
Description
Describe the Bug
The reduction stage of pwelch runs roughly 10x slower than a comparable hand-written CUDA implementation.
The relevant parameters in this case are:
- FFT size: 65536
- window size: 500
- window overlap: 250
- number of input samples: 164750
- number of FFTs: 658
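As a quick check, the FFT count follows from the usual Welch framing arithmetic, where the hop between segment starts is the window size minus the overlap. This sketch assumes only full segments are used (trailing samples that do not fill a window are dropped). `num_segments` is a hypothetical helper, not a MatX function.

```cpp
#include <cstdint>

// Number of overlapping segments produced by Welch-style framing.
// Assumption: only full segments are kept; leftover samples are dropped.
constexpr std::int64_t num_segments(std::int64_t n_samples,
                                    std::int64_t nperseg,
                                    std::int64_t noverlap) {
  const std::int64_t hop = nperseg - noverlap;  // step between segment starts
  return (n_samples - nperseg) / hop + 1;
}

// The parameters above: (164750 - 500) / 250 + 1 = 658 FFTs.
static_assert(num_segments(164750, 500, 250) == 658, "segment count mismatch");
```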
Looking at the pwelch_impl source, it appears to be effectively three operations:
1. FFTs
2. `X_with_overlaps = conj(X_with_overlaps) * X_with_overlaps`
3. `Pxx = sum(mag_sq_X_with_overlaps, {0}) * norm_factor`
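For checking numerical results, steps 2 and 3 can be mirrored on the host as two separate passes. This is a CPU sketch, not MatX code: `pxx_two_pass` is a hypothetical name, the row-major `[num_ffts][num_bins]` layout is an assumption, and `norm_factor` is taken as given rather than derived.

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Two-pass host reference for steps 2 and 3 (not MatX code).
// X is assumed row-major: num_ffts rows of num_bins complex values.
// X is taken by value because step 2 overwrites it in place.
std::vector<float> pxx_two_pass(std::vector<std::complex<float>> X,
                                std::size_t num_ffts, std::size_t num_bins,
                                float norm_factor) {
  // Step 2: X = conj(X) * X, in place. The real part becomes |X|^2 and the
  // imaginary part becomes zero (up to rounding).
  for (auto& x : X) x = std::conj(x) * x;

  // Step 3: sum over dimension 0 (the FFT index), then scale.
  std::vector<float> pxx(num_bins, 0.0f);
  for (std::size_t f = 0; f < num_ffts; ++f)
    for (std::size_t b = 0; b < num_bins; ++b)
      pxx[b] += X[f * num_bins + b].real();
  for (auto& p : pxx) p *= norm_factor;
  return pxx;
}
```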
On my system, with MatX, step 2 takes about 40 μs and step 3 takes about 380 μs.
A custom CUDA kernel that combines steps 2 and 3 takes about 30 μs.
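The hand-written kernel computes steps 2 and 3 in a single pass. One plausible reason it is faster is that fusion avoids writing the magnitude-squared buffer and then reading it back; this is an interpretation, not something the profile proves. A host sketch of the fused form, assuming the same row-major `[num_ffts][num_bins]` layout and using a hypothetical `pxx_fused` name:

```cpp
#include <complex>
#include <cstddef>
#include <vector>

// Fused single-pass host sketch of steps 2 and 3: each input element is read
// once and |X|^2 is accumulated directly, with no intermediate buffer.
std::vector<float> pxx_fused(const std::vector<std::complex<float>>& X,
                             std::size_t num_ffts, std::size_t num_bins,
                             float norm_factor) {
  std::vector<float> pxx(num_bins);
  for (std::size_t b = 0; b < num_bins; ++b) {   // like one GPU thread per bin
    float acc = 0.0f;
    for (std::size_t f = 0; f < num_ffts; ++f)   // walk down the FFT axis
      acc += std::norm(X[f * num_bins + b]);     // std::norm returns |z|^2
    pxx[b] = acc * norm_factor;
  }
  return pxx;
}
```

Its output is step 3 applied to the output of step 2: for each bin `b`, the sum over `f` of `|X[f][b]|^2`, times `norm_factor`.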
To Reproduce
- Allocate a `complex<float>` input tensor of size 164750
- Create a Hamming window of size 500
- Call `matx::pwelch(input, window, 500, 250, 65536)`
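If the window is generated outside MatX, a symmetric Hamming window of length 500 can be built directly. The report does not say whether a symmetric or periodic window was used, so the symmetric form here is an assumption, and `hamming_window` is a hypothetical helper.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Symmetric Hamming window: w[n] = 0.54 - 0.46 * cos(2*pi*n / (N - 1)).
// Assumption: symmetric convention; a periodic window would divide by N instead.
std::vector<float> hamming_window(std::size_t N) {
  std::vector<float> w(N);
  if (N == 1) { w[0] = 1.0f; return w; }
  const double pi = std::acos(-1.0);
  for (std::size_t n = 0; n < N; ++n)
    w[n] = static_cast<float>(0.54 - 0.46 * std::cos(2.0 * pi * n / (N - 1)));
  return w;
}
```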
Expected Behavior
Performance of the reduction stages (steps 2 and 3) comparable to the hand-written kernel.
Code Snippets
The custom CUDA kernel used for comparison is:
```cuda
// Each thread handles one or more frequency bins via a grid-stride loop.
// fft_buff is row-major [num_ffts][num_bins], so consecutive threads read
// consecutive bins and the global loads are coalesced.
template<typename ComplexT, typename RealT>
__global__ void ComputePowerSpectrumMultiFFT(const ComplexT* fft_buff, RealT* power_spectrum,
                                             int num_ffts, int num_bins)
{
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  const int num_threads = blockDim.x * gridDim.x;
  for (int bin = id; bin < num_bins; bin += num_threads)
  {
    RealT pwr = 0;
    for (int fft = 0; fft < num_ffts; fft++)
    {
      pwr += norm(fft_buff[fft * num_bins + bin]);  // |X|^2 for this bin
    }
    pwr /= num_ffts;  // average over all FFTs
    // our kernel is outputting in dB, but that difference shouldn't matter.
    // Although, it would be nice if MatX allowed the `10 * log10` to be fused
    // with the rest of pwelch somehow.
    if (pwr <= 0)
      power_spectrum[bin] = 0;
    else
      power_spectrum[bin] = 10 * log10(pwr);
  }
}
```
This kernel does not perform the FFT portion of pwelch. That is intentional: this report concerns only the reduction stages, not the FFTs.
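Without a way to fuse the `10 * log10` mentioned in the kernel comment, the dB conversion becomes a separate elementwise pass over the `num_bins` outputs. A host sketch that applies the same guard as the kernel (non-positive power maps to 0 rather than -inf or NaN); `to_db_inplace` is a hypothetical name.

```cpp
#include <cmath>
#include <vector>

// Separate (unfused) dB conversion with the same guard as the custom kernel.
void to_db_inplace(std::vector<float>& pxx) {
  for (auto& p : pxx)
    p = (p <= 0.0f) ? 0.0f : 10.0f * std::log10(p);
}
```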
System Details
- OS: Rocky 9
- CUDA version: 12.3
- g++ version: 11.4.1
Additional Context
Kernel timings and launch configurations from the nsys profile:

| Kernel | Begins | Ends | Duration | Grid | Block |
|---|---|---|---|---|---|
| Magnitude squared (MatX) | 14.0713 s | 14.0713 s | 43.936 μs | `<<<256, 31, 1>>>` | `<<<256, 1, 1>>>` |
| Reduction + normalization (MatX) | 14.0713 s | 14.0717 s | 383.104 μs | `<<<65536, 1, 1>>>` | `<<<256, 1, 1>>>` |
| Reduction + normalization (custom kernel) | 14.0684 s | 14.0684 s | 29.344 μs | `<<<64, 1, 1>>>` | `<<<1024, 1, 1>>>` |
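The two reduction launch configurations can be compared with simple arithmetic. The custom kernel's 64 x 1024 threads give exactly one thread per bin, and adjacent threads read adjacent bins. For the MatX reduction, the 65536-block grid suggests one block per output bin, but that is an assumption about MatX internals, not something the profile shows. Under that assumption, and assuming a row-major `[num_ffts][num_bins]` layout, each block spreads only 658 terms over 256 threads, and the terms for one bin sit 65536 complex values apart in memory, which would make those loads uncoalesced.

```cpp
#include <cstdint>

// Launch-geometry arithmetic from the nsys numbers.
constexpr std::int64_t kNumBins = 65536;  // FFT size = number of output bins
constexpr std::int64_t kNumFfts = 658;    // terms summed per bin

// Custom kernel: 64 blocks x 1024 threads = one thread per bin.
constexpr std::int64_t kCustomThreads = 64 * 1024;
static_assert(kCustomThreads == kNumBins, "one thread per bin");

// MatX reduction: 65536 blocks x 256 threads. Assumption (not confirmed by
// the profile): one block per output bin, so each thread handles at most
// ceil(658 / 256) = 3 terms.
constexpr std::int64_t kMatxBlocks = 65536;
constexpr std::int64_t kMatxThreadsPerBlock = 256;
constexpr std::int64_t kTermsPerThread =
    (kNumFfts + kMatxThreadsPerBlock - 1) / kMatxThreadsPerBlock;
static_assert(kMatxBlocks == kNumBins, "grid size equals bin count");
static_assert(kTermsPerThread == 3, "ceil(658 / 256)");
```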