Skip to content

Commit

Permalink
wip #2 on custom kernel fft
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel-geon-park committed Dec 26, 2017
1 parent f85a91f commit 159ef37
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 61 deletions.
71 changes: 53 additions & 18 deletions src/kernels/fft.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@

#include <cstdio>
#include <algorithm>
#include <mgcpp/kernels/bits/fft.cuh>

//#define BLK 8
#define BLK 64
#define BLK 64Lu
//#define BLK 64Lu
#define PI(T) static_cast<T>(3.141592653589793238462643383279502884197169399375105820974944)

namespace mgcpp
Expand Down Expand Up @@ -42,41 +44,74 @@ namespace mgcpp

/**
 * Block-local radix-2 butterfly passes over a real-valued input.
 *
 * Each block loads up to BLK consecutive samples of `x` into shared memory
 * as complex values (imaginary part 0), runs butterfly passes of span
 * k = 2, 4, ..., m on them, and writes the block's values to `y` at the
 * same global indices. Samples are loaded in natural order; this kernel
 * performs no reordering of its input or output.
 *
 * Launch layout: 1-D grid, blockDim.x == BLK, enough blocks to cover n.
 * Shared memory: BLK * sizeof(cmplx<T>), statically allocated.
 *
 * @param x  input samples, n elements
 * @param y  output values, n elements
 * @param n  number of samples
 * @param m  largest butterfly span; must not exceed BLK, otherwise the
 *           partner index runs past the shared buffer.
 *           NOTE(review): presumably a power of two -- confirm with caller.
 */
template<typename T>
__global__ void
mgblas_rfft_impl(T const *x, cmplx<T> *y, size_t n, size_t m)
{
    __shared__ cmplx<T> s[BLK];
    int const tid = threadIdx.x;
    // 64-bit global index so large inputs cannot overflow a 32-bit int.
    size_t const idx = static_cast<size_t>(blockIdx.x) * BLK + tid;
    bool const in_range = idx < n;

    // Every thread fills its slot (zero for the tail of a partial block) so
    // that no butterfly ever reads uninitialised shared memory.
    s[tid].real = in_range ? x[idx] : T(0);
    s[tid].imag = 0;
    __syncthreads();

    for (size_t k = 2; k <= m; k <<= 1) {
        size_t const half = k / 2;
        size_t const i = tid % k;
        // The lower half of each k-wide group owns the butterfly and
        // updates both of its endpoints.
        if (in_range && i < half) {
            size_t const a = tid;
            size_t const b = a + half;
            T phi = -2 * PI(T) * i / k;
            cmplx<T> z = {cos(phi), sin(phi)}; // z = W_k^i
            cmplx<T> u = s[a], v = s[b] * z;
            s[a] = u + v;
            s[b] = u - v;
        }
        // The barrier is outside every data-dependent branch: all threads
        // of the block must reach it, including out-of-range ones.
        __syncthreads();
    }

    if (in_range) {
        y[idx] = s[tid];
    }
}

for (int k = 2; k <= n; k <<= 1) {
int const i = idx % k;
if (i < k / 2) {
int const a = idx - blockIdx.x * blockDim.x;
int const b = a + k / 2;
T phi = -2 * PI(T) * i / k;
cmplx<T> z = {cos(phi), sin(phi)}; // z = W_k^(idx%k)
cmplx<T> u = s[a];
s[a] = u + s[b] * z;
s[b] = u - s[b] * z;
}
/**
 * Block-local radix-2 butterfly passes over strided complex data.
 *
 * Thread `idx` works on element `sidx = idx / jump + (idx % jump) * level`
 * with `jump = n / level`: it loads x[sidx] into shared memory, takes part
 * in butterfly passes of span k = 2, 4, ..., m, and stores its slot back to
 * y[sidx]. Each thread reads and writes the same global index, so `x` and
 * `y` may point to the same buffer.
 *
 * Launch layout: 1-D grid, blockDim.x == BLK, enough blocks to cover n.
 * Shared memory: BLK * sizeof(cmplx<T>), statically allocated.
 *
 * @param x      input values, n elements
 * @param y      output values, n elements (may alias x)
 * @param n      number of elements
 * @param level  stride multiplier of this stage; must satisfy
 *               0 < level <= n so that `jump` is non-zero
 * @param m      largest butterfly span; must not exceed BLK
 *
 * NOTE(review): the twiddle uses sidx modulo (k * level); whether this
 * composes with the first stage into a full DFT is not verifiable from
 * this block alone -- confirm against the reference transform.
 */
template<typename T>
__global__ void
mgblas_cfft_impl(cmplx<T> *x, cmplx<T> *y, size_t n, size_t level, size_t m)
{
    __shared__ cmplx<T> s[BLK];
    int const tid = threadIdx.x;
    // 64-bit indices so large inputs cannot overflow a 32-bit int.
    size_t const idx = static_cast<size_t>(blockIdx.x) * BLK + tid;
    size_t const jump = n / level;
    size_t const sidx = idx / jump + (idx % jump) * level;
    bool const in_range = sidx < n;

    // Every thread fills its slot (zero when it has no element) so that no
    // butterfly ever reads uninitialised shared memory.
    cmplx<T> const zero = {T(0), T(0)};
    s[tid] = in_range ? x[sidx] : zero;
    __syncthreads();

    for (size_t k = 2; k <= m; k <<= 1) {
        size_t const half = k / 2;
        size_t const i = tid % k;
        // The lower half of each k-wide group owns the butterfly and
        // updates both of its endpoints.
        if (in_range && i < half) {
            size_t const a = tid;
            size_t const b = a + half;
            size_t const span = k * level;
            T phi = -2 * PI(T) * (sidx % span) / span;
            cmplx<T> z = {cos(phi), sin(phi)}; // z = W_span^(sidx % span)
            cmplx<T> u = s[a], v = s[b] * z;
            s[a] = u + v;
            s[b] = u - v;
        }
        // The barrier is outside every data-dependent branch: all threads
        // of the block must reach it, including out-of-range ones.
        __syncthreads();
    }

    if (in_range) {
        y[sidx] = s[tid];
    }
}

/**
 * Single-precision real-input FFT driver for the custom kernels.
 *
 * Launches mgblas_rfft_impl once over the whole input, then launches
 * mgblas_cfft_impl once per further stage while the remaining span
 * m = n / BLK^j is greater than 1. All launches go to the default stream
 * and are not synchronised here.
 *
 * @param x  device pointer to n real input samples
 * @param y  device pointer to the output; it is reinterpreted as
 *           cmplx<float>* and n complex values are written, so it must
 *           provide room for n cmplx<float> elements
 * @param n  number of input samples; n == 0 is a no-op.
 *           NOTE(review): the stage arithmetic looks like it expects a
 *           power of two -- confirm with the callers.
 * @return   success. NOTE(review): launch errors are not inspected here;
 *           callers should check cudaGetLastError() after the call.
 */
kernel_status_t
mgblas_Srfft(float const *x, float *y, size_t n)
{
    // A zero-block grid is an invalid launch configuration; nothing to do.
    if (n == 0) {
        return success;
    }

    // Typed copy of the macro so std::min sees matching argument types
    // regardless of the literal suffix used in the #define.
    size_t const block = BLK;
    cmplx<float> *cy = reinterpret_cast<cmplx<float>*>(y);

    // Integer ceil-division: exact for every n, unlike rounding through float.
    unsigned int const grid_size =
        static_cast<unsigned int>((n + block - 1) / block);
    unsigned int const block_size = static_cast<unsigned int>(block);

    mgblas_rfft_impl<float><<<grid_size, block_size>>>(x, cy, n, std::min(n, block));
    for (size_t m = n / block, level = block; m > 1; level *= block, m /= block) {
        mgblas_cfft_impl<float><<<grid_size, block_size>>>(cy, cy, n, level, std::min(m, block));
    }

    return success;
}
Expand Down
99 changes: 56 additions & 43 deletions test/fft_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,43 +3,74 @@

#include <mgcpp/operations/fft.hpp>

#include <complex>
#include <random>
#include <stdexcept>
#include <utility>
#include <valarray>
using complex = std::complex<double>;
using carray = std::valarray<complex>;
constexpr double PI = 3.1415926535897932384626433832795028;

/**
 * In-place iterative radix-2 FFT, used as the CPU reference in the tests.
 *
 * @param a    samples, transformed in place; the length must be 0, 1 or a
 *             power of two
 * @param inv  false: twiddle factor exp(+2*pi*i/len), no scaling;
 *             true:  twiddle factor exp(-2*pi*i/len), result divided by n
 * @throws std::invalid_argument if the length is not a power of two (the
 *         bit-reversal loop below would otherwise never terminate)
 *
 * With inv == false the transform uses the exp(+...) sign, i.e. it yields
 * the complex conjugate of an exp(-...) forward transform when the input is
 * real. Real parts agree between the two conventions; imaginary parts have
 * opposite sign.
 */
void fft(carray &a, bool inv)
{
    std::size_t const size = a.size();
    if ((size & (size - 1)) != 0) {
        throw std::invalid_argument("fft: input length must be a power of two");
    }

    int n = static_cast<int>(size);

    // Bit-reversal permutation: j tracks the bit-reversed value of i.
    for (int i = 1, j = 0; i < n; i++) {
        int bit = n >> 1;
        for (; j >= bit; bit >>= 1) j -= bit;
        j += bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    // Butterfly passes over blocks of length 2, 4, ..., n.
    for (int len = 2; len <= n; len <<= 1) {
        complex wlen = std::polar(1., 2 * PI / len * (inv ? -1 : 1));
        for (int i = 0; i < n; i += len) {
            complex w(1);
            for (int j = 0; j < len / 2; j++) {
                complex u = a[i + j], v = a[i + j + len / 2] * w;
                a[i + j] = u + v;
                a[i + j + len / 2] = u - v;
                w *= wlen;
            }
        }
    }

    // The inverse transform carries the 1/n normalisation.
    if (inv) {
        for (int i = 0; i < n; i++) a[i] /= n;
    }
}

std::default_random_engine rng;
std::uniform_real_distribution<double> dist(0.0, 1.0);

// Runs the library rfft on 1024 random samples and checks the output length
// and, for every output bin, element 2 * bin of the result against the real
// part of the CPU reference transform.
TEST(fft_operation, float_real_to_complex_fwd_fft)
{
    size_t size = 1024;

    // Input: samples drawn from the file-level uniform distribution.
    mgcpp::device_vector<float> vec(size);
    for (size_t pos = 0; pos < size; ++pos) {
        vec.set_value(pos, dist(rng));
    }

    // Reference: the same values read back and transformed on the host.
    carray expected(size);
    for (size_t pos = 0; pos < vec.size(); ++pos) {
        expected[pos] = vec.check_value(pos);
    }
    fft(expected, false);

    mgcpp::device_vector<float> result;
    EXPECT_NO_THROW({result = mgcpp::strict::rfft(vec);});

    EXPECT_EQ(result.size(), size / 2 * 2 + 2);

    // Only the real component of each bin is compared.
    size_t const bins = result.size() / 2;
    for (size_t bin = 0; bin < bins; ++bin) {
        EXPECT_NEAR(result.check_value(bin * 2), expected[bin].real(), 1e-4);
    }
}

#include <mgcpp/kernels/bits/fft.cuh>
TEST(fft_operation, float_real_to_complex_fwd_fft_custom_kernel)
{
mgcpp::device_vector<float> vec({
1, 2, 1, -1, 1, 1, 1, 3, 1, 3, 1, 3, 1, 2, 1, 3
});
size_t size = 1024;

size_t size = vec.size();
mgcpp::device_vector<float> vec(size);
for (auto i = 0u; i < size; ++i) vec.set_value(i, dist(rng));

carray expected(size);
for (auto i = 0u; i < vec.size(); ++i)
expected[i] = vec.check_value(i);
fft(expected, false);

mgcpp::device_vector<float> result(size * 2);

Expand All @@ -55,28 +86,10 @@ TEST(fft_operation, float_real_to_complex_fwd_fft_custom_kernel)
}
mgcpp::mgblas_Srfft(vec.data(), result.data_mutable(), size);

float expected[] = {
24.000000, 0.000000,
-2.071930, 5.002081,
4.242640, 1.414214,
2.388955, -0.989537,
0.000000, 0.000000,
-2.388955, -0.989537,
-4.242640, 1.414214,
2.071930, 5.002081,
-8.000000, 0.000000,
2.07193, -5.00208,
-4.24264, -1.41421,
-2.38896, 0.989538,
0., 0.,
2.38896, 0.989538,
4.24264, -1.41421,
-2.07193, -5.00208
};
//EXPECT_EQ(result.size(), size * 2);

EXPECT_EQ(result.size(), size * 2);
for (auto i = 0u; i < result.size(); ++i) {
EXPECT_NEAR(result.check_value(i), expected[i], 1e-5);
for (auto i = 0u; i < result.size() / 2; ++i) {
EXPECT_NEAR(result.check_value(i * 2), expected[i].real(), 1e-4);
}
}

Expand Down

0 comments on commit 159ef37

Please sign in to comment.