@@ -1885,7 +1885,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 // rowStats [rows]
 // out [rows, cols]
 template <typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
 __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t * out, float * rowStats, float threshold, int rows, int cols) {
 
   // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
@@ -2018,11 +2018,6 @@ __global__ void kdequant_mm_int32_fp16(
 #define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
-#if defined(__GFX9__)
-#define WARP_SIZE 64
-#else
-#define WARP_SIZE 32
-#endif
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
@@ -2043,9 +2038,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
   const int local_row_idx = rowidx[offset];
 
-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int warp_idx = threadIdx.x % WARP_SIZE;
-  const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
+  const int warp_id = threadIdx.x / BNB_WARP_SIZE;
+  const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
+  const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
   const int num_items = BITS == 8 ? 8 : 8;
   int idx_col_B = warp_offset;
   int local_idx_col_B_offset = 0;
@@ -2065,7 +2060,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   }
 
   // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256 = 8192
-  // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
+  // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
   // we have a total of 128 bytes for the bank with a bank size of 4 bytes
   // added 3 bytes = 6 values between warps should reduce bank conflicts
   __shared__ half smem_dequant_stats[SMEM_SIZE];
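
Note: the five lines deleted at the top of this hunk were a per-file warp-size macro; the added lines switch to a shared BNB_WARP_SIZE constant instead. A minimal sketch of what that centralized definition presumably looks like, carried over from the deleted #if (the header it actually lives in is not shown in this patch):

// Sketch only: shared warp-size constant assumed to replace the file-local macro.
// The GFX9 check mirrors the lines removed above: 64 on wave64 GFX9 targets,
// 32 on wave32 targets and CUDA.
#if defined(__GFX9__)
#define BNB_WARP_SIZE 64
#else
#define BNB_WARP_SIZE 32
#endif
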
@@ -2618,15 +2613,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
26182613{
26192614
26202615 // per threadblock:
2621- // load step-by-step in chunks of [warp_size ,warps]: 1xwarp_size * [warp_size ,warps] -> [1,warps]
2616+ // load step-by-step in chunks of [BNB_WARP_SIZE ,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE ,warps] -> [1,warps]
26222617 // 4 warps -> 4 loads per iter
2623- // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
2624- typedef hipcub::WarpReduce<float , WARP_SIZE > WarpReduce;
2625- __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE ];
2618+ // 1 x BNB_WARP_SIZE * BNB_WARP_SIZE x 4 -> 1x4 outputs per thread block
2619+ typedef hipcub::WarpReduce<float , BNB_WARP_SIZE > WarpReduce;
2620+ __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE ];
26262621
2627- const int warp_idx = threadIdx.x / WARP_SIZE ;
2628- const int warp_lane = threadIdx.x % WARP_SIZE ;
2629- const int row_B = (THREADS/WARP_SIZE )*blockIdx.x + warp_idx;
2622+ const int warp_idx = threadIdx.x / BNB_WARP_SIZE ;
2623+ const int warp_lane = threadIdx.x % BNB_WARP_SIZE ;
2624+ const int row_B = (THREADS/BNB_WARP_SIZE )*blockIdx.x + warp_idx;
26302625 const int offset_B = ldb * row_B;
26312626 const int num_values_8bit = num_values_4bit/2 ;
26322627 float local_C = 0 .0f ;
@@ -2645,7 +2640,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
26452640
26462641 // A: [1, K]
26472642 // B: [M, K]
2648- for (int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE *num_values_4bit)
2643+ for (int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE *num_values_4bit)
26492644 {
26502645 const int inner_idx_halved = inner_idx/2 ;
26512646
@@ -2957,23 +2952,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
-// MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
-// MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
-// MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -2982,23 +2983,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
-// MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
-// MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
-// MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3007,23 +3014,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
-// MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
-// MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#endif
 
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
-// MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#endif
 
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
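
On the instantiation changes above: the blocksize-64 variants presumably launch blocks of only 64/2 = 32 threads (reading the macro's second and third arguments as block size and items per thread), which is a full wavefront only when BNB_WARP_SIZE == 32; on wave64 builds they are compiled out, leaving 128 as the smallest compiled block size. A hedged host-side sketch of the fallback a caller might apply; the function name and policy are assumptions, not taken from this patch:

// Illustrative only: pick a block size that has a compiled kernel variant.
static inline int bnb_supported_blocksize(int requested) {
#if BNB_WARP_SIZE == 32
    return requested;                          // blocksize-64 variants are available
#else
    return requested < 128 ? 128 : requested;  // wave64 builds: smallest compiled variant is 128
#endif
}
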