Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cuda source cleanup , refactor and fixes #1328

Merged
Merged
Next Next commit
remove kcompress
abhilash1910 authored Aug 20, 2024
commit a29f44e81e6bfab38180e2f448771b44e2edb8de
80 changes: 0 additions & 80 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
@@ -519,86 +519,6 @@ __global__ void kHistogramScatterAdd2D(float* histogram, int *index1, int *index
}
}

template<typename T, int BLOCK_SIZE, int NUM_MAX>
__global__ void kCompressMax(T * __restrict__ const A, T* out, unsigned char* out_idx, const int n)
{
typedef cub::WarpReduce<T> WarpReduce;
__shared__ typename WarpReduce::TempStorage temp_storage;
typedef cub::BlockLoad<T, BLOCK_SIZE/8 , 8, cub::BLOCK_LOAD_WARP_TRANSPOSE> LoadT;
__shared__ typename LoadT::TempStorage loadt;

const int warp_idx = threadIdx.x/32;
const int valid_items = n - (blockIdx.x*BLOCK_SIZE) > BLOCK_SIZE ? BLOCK_SIZE : n - (blockIdx.x*BLOCK_SIZE);

// BLOCK_SIZE/32 == number of warps
__shared__ int smem_max_indices[8*BLOCK_SIZE/32];
__shared__ float smem_max_values[8*BLOCK_SIZE/32];

T values[8];
T max1 = -64000.0f;
T max2 = -64000.0f;
int max_idx1 = -1;
int max_idx2 = -1;
int sign1 = -1;
int sign2 = -1;

// 1. load 8 values per thread
// 2. compute 2-max in registers (64 max per warp)
// 3. do warp reduction + broadcast back
// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
// 5. Repeat (3) 8 times for top 8 values in 256
// 6. store with byte index

LoadT(loadt).Load(&(A[(blockIdx.x*BLOCK_SIZE)]), values, valid_items, (T)0.0f);
#pragma unroll 8
for(int i = 0; i < 8; i++)
{
T absval = fabsf(values[i]);
if(absval > max1)
{
max1 = values[i];
sign1 = signbit(values[i]);
max_idx1 = 8*threadIdx.x + i;
}
else if(absval > max2)
{
max2 = values[i];
sign2 = signbit(values[i]);
max_idx2 = 8*threadIdx.x + i;
}
}

float warp_max;
for(int i = 0; i < 8; i++)
{
// 3. do warp reduction + broadcast back
warp_max = WarpReduce(temp_storage).Reduce(max1, cub::Max());
warp_max = cub::ShuffleIndex<32>(warp_max, 0, 0xffffffff);

// 4. Up-shift maxed value, write index into shared memory, replace with 2nd largest
if(warp_max == max1)
{
smem_max_values[warp_idx*8 + i] = sign1 != 0 ? -max1 : max1;
smem_max_indices[warp_idx*8 + i] = max_idx1;

sign1 = sign2;
max1 = max2;
max_idx1 = max_idx2;

max2 = -64000.0f;
}
__syncwarp();
}

if(threadIdx.x % 32 < 8)
{
// offset: 8 values per 256 input values
//
int offset = BLOCK_SIZE*blockIdx.x*BLOCK_SIZE/32*8;
}

}

#define THREADS_ESTIMATE 512
#define NUM_ESTIMATE 8
#define BLOCK_ESTIMATE 4096