Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Logpolar optimisation #99

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
WIP
neon60 committed Mar 6, 2025
commit c908694c38a822897ed33dde3d87a3c2fd65c933
79 changes: 40 additions & 39 deletions tomobar/cuda_kernels/fft_us_kernels.cu
Original file line number Diff line number Diff line change
@@ -99,20 +99,10 @@ extern "C" __global__ void gather_kernel_center(float2 *g, float2 *f, float *the

const int center_half_size = center_size/2;

//int tx = blockDim.x * blockIdx.x + threadIdx.x;
//int ty = blockDim.y * blockIdx.y + threadIdx.y;

//int tx = max(0, n + m - center_half_size) + blockDim.x * blockIdx.x + threadIdx.x;
//int ty = max(0, n + m - center_half_size) + blockIdx.y;
//int tz = blockDim.z * blockIdx.z + threadIdx.z;

int tx = max(0, n + m - center_half_size) + blockDim.z * blockIdx.z + threadIdx.z;
int ty = max(0, n + m - center_half_size) + blockDim.y * blockIdx.y + threadIdx.y;
int tz = blockDim.x * blockIdx.x + threadIdx.x;

int proj_count = 1; //blockDim.y;
int proj_offset = 0; //threadIdx.y % proj_count;

if (tx >= 2 * n + 2 * m || ty >= 2 * n + 2 * m || tz >= nz)
return;

@@ -138,7 +128,7 @@ extern "C" __global__ void gather_kernel_center(float2 *g, float2 *f, float *the
// Point coordinates
float2 point = make_float2(float(tx - (n+m)) / float(2 * n), float((n+m) - ty) / float(2 * n));

for( int proj_index = proj_offset; proj_index < nproj; proj_index+=proj_count) {
for (int proj_index = 0; proj_index < nproj; proj_index++) {

float sintheta, costheta;
__sincosf(theta[proj_index], &sintheta, &costheta);
@@ -179,40 +169,51 @@ extern "C" __global__ void gather_kernel_center(float2 *g, float2 *f, float *the
radius_max = radius_max < 0 ? 0 : radius_max;
radius_max = radius_max > (n-1) ? (n-1) : radius_max;

for( int radius_index = radius_min; radius_index < radius_max; radius_index++) {

int g_ind = radius_index + proj_index * n + tz * n * nproj;

float x0 = (radius_index - n / 2) / (float)(n) * costheta;
float y0 = (radius_index - n / 2) / (float)(n) * sintheta;

if (x0 >= 0.5f)
x0 = 0.5f - 1e-5;
if (y0 >= 0.5f)
y0 = 0.5f - 1e-5;

float w0 = point.x - x0;
float w1 = point.y - y0;
float w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1));

float2 g0, g0t;

g0.x = g[g_ind].x;
g0.y = g[g_ind].y;
g0t.x = w*g0.x;
g0t.y = w*g0.y;

f_value.x += g0t.x;
f_value.y += g0t.y;
constexpr int length = 4;
float2 f_values[length];
for (int radius_index = radius_min; radius_index < radius_max; radius_index+=length) {

#pragma unroll
for (int i = 0; i < length; i++) {
int g_ind = radius_index + i + proj_index * n + tz * n * nproj;
if( radius_index + i < radius_max ) {
f_values[i].x = g[g_ind].x;
f_values[i].y = g[g_ind].y;
} else {
f_values[i].x = 0.f;
f_values[i].y = 0.f;
}
}

#pragma unroll
for (int i = 0; i < length; i++) {
float x0 = (radius_index + i - n / 2) / (float)n * costheta;
float y0 = (radius_index + i - n / 2) / (float)n * sintheta;

if (x0 >= 0.5f)
x0 = 0.5f - 1e-5;
if (y0 >= 0.5f)
y0 = 0.5f - 1e-5;

float w0 = point.x - x0;
float w1 = point.y - y0;
float w = coeff0 * __expf(coeff1 * (w0 * w0 + w1 * w1));

f_values[i].x *= w;
f_values[i].y *= w;
}

#pragma unroll
for (int i = 0; i < length; i++) {
f_value.x += f_values[i].x;
f_value.y += f_values[i].y;
}
}
}
}

f[f_ind].x = f_value.x;
f[f_ind].y = f_value.y;

// atomicAdd(&(f[f_ind].x), f_value.x);
// atomicAdd(&(f[f_ind].y), f_value.y);
}

extern "C" __global__ void wrap_kernel(float2 *f,
2 changes: 1 addition & 1 deletion tomobar/methodsDIR_CuPy.py
Original file line number Diff line number Diff line change
@@ -292,7 +292,7 @@ def FOURIER_INV(self, data: xp.ndarray, **kwargs) -> xp.ndarray:
# y1 = -(n - n / 2) / n * np.sin(theta);
# print(x1 - x0)
# print(y1 - y0)
center_size = 256
center_size = 512

# STEP2: interpolation (gathering) in the frequency domain
block_dim_x = 32