
Commit 3f9f6f3

electron271, sstamenk, and matthewdouglas authored
add support for 64 block size on 32 warp size supported amd gpus (#1748)
* add support for 64 block size on 32 warp size supported amd gpus
* uncomment 64 block size support in csrc
* only enable 64 block size support on architectures with 32 warp size
* use BNB_WARP_SIZE instead of warpSize in ops.hip
* Reuse BNB_WARP_SIZE macro
* Remove unused WARP_SIZE definitions
* remove unused import
* Apply suggestion from @matthewdouglas
* Apply suggestion from @matthewdouglas
* Apply suggestion from @matthewdouglas

---------

Co-authored-by: sstamenk <[email protected]>
Co-authored-by: Matthew Douglas <[email protected]>
1 parent d1c2b0d commit 3f9f6f3

File tree: 11 files changed, +114 −69 lines


bitsandbytes/backends/cuda/ops.py

Lines changed: 5 additions & 5 deletions
@@ -8,7 +8,7 @@
 from bitsandbytes.functional import CUBLAS_Context, _cuda_device_of, _get_tensor_stream, get_ptr

 from ..._ops import register_kernel
-from ...cextension import HIP_ENVIRONMENT, lib
+from ...cextension import ROCM_WARP_SIZE_64, lib


 @register_kernel("bitsandbytes::int8_linear_matmul", "cuda")
@@ -211,7 +211,7 @@ def _get_col_absmax(
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
     torch._check_is_size(blocksize)

-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -269,7 +269,7 @@ def _(
 def _dequantize_blockwise_impl(
     A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -303,7 +303,7 @@ def _dequantize_blockwise_impl(
 def _(
     A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
@@ -385,7 +385,7 @@ def _dequantize_4bit_impl(
     dtype: torch.dtype,
     out: torch.Tensor,
 ) -> None:
-    if HIP_ENVIRONMENT:
+    if ROCM_WARP_SIZE_64:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
     else:
         torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
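
For context, the guard shared by the registered kernels above can be sketched as a standalone helper (a minimal sketch; _check_blocksize is a hypothetical name, not part of the library): blocksize 64 is only accepted when the wavefront/warp size is 32.

import torch

def _check_blocksize(blocksize: int, rocm_warp_size_64: bool) -> None:
    # Mirrors the per-kernel checks: 64 is excluded on 64-wide-wavefront ROCm GPUs.
    torch._check_is_size(blocksize)
    if rocm_warp_size_64:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128])
    else:
        torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])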

bitsandbytes/cextension.py

Lines changed: 8 additions & 1 deletion
@@ -9,7 +9,13 @@
 import torch

 from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_cuda_version_tuple, get_rocm_gpu_arch
+from bitsandbytes.cuda_specs import (
+    CUDASpecs,
+    get_cuda_specs,
+    get_cuda_version_tuple,
+    get_rocm_gpu_arch,
+    get_rocm_warpsize,
+)

 logger = logging.getLogger(__name__)

@@ -298,6 +304,7 @@ def get_native_library() -> BNBNativeLibrary:


 ROCM_GPU_ARCH = get_rocm_gpu_arch()
+ROCM_WARP_SIZE_64 = True if get_rocm_warpsize() == 64 else False

 HIP_ENVIRONMENT = False
 BNB_BACKEND = "CPU"

bitsandbytes/cuda_specs.py

Lines changed: 26 additions & 0 deletions
@@ -100,3 +100,29 @@ def get_rocm_gpu_arch() -> str:
                 """,
             )
         return "unknown"
+
+
+def get_rocm_warpsize() -> int:
+    """Get ROCm warp size."""
+    logger = logging.getLogger(__name__)
+    try:
+        if torch.version.hip:
+            result = subprocess.run(["rocminfo"], capture_output=True, text=True)
+            match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+            if match:
+                return int(match.group(1))
+            else:
+                # default to 64 to be safe
+                return 64
+        else:
+            # nvidia cards always use 32 warp size
+            return 32
+    except Exception as e:
+        logger.error(f"Could not detect ROCm warp size: {e}. Defaulting to 64. (some 4-bit functions may not work!)")
+        if torch.cuda.is_available():
+            logger.warning(
+                """
+ROCm warp size detection failed despite ROCm being available.
+                """,
+            )
+        return 64
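
For reference, the regular expression introduced here targets the "Wavefront Size" line that rocminfo prints. A small self-contained check against a fabricated excerpt (illustrative only, not real rocminfo output):

import re

sample_stdout = """
  Compute Unit:            304
  Wavefront Size:          64(0x40)
  Workgroup Max Size:      1024(0x400)
"""

match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", sample_stdout)
# Falls back to 64 when the line is absent, matching get_rocm_warpsize().
print(int(match.group(1)) if match else 64)  # -> 64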

bitsandbytes/functional.py

Lines changed: 7 additions & 7 deletions
@@ -15,7 +15,7 @@

 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

-from .cextension import HIP_ENVIRONMENT, lib
+from .cextension import ROCM_WARP_SIZE_64, lib

 name2qmap = {}

@@ -806,7 +806,7 @@ def quantize_fp4(
     quant_storage=torch.uint8,
 ):
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "fp4", quant_storage)


@@ -819,7 +819,7 @@ def quantize_nf4(
     quant_storage=torch.uint8,
 ):
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return quantize_4bit(A, absmax, out, blocksize, compress_statistics, "nf4", quant_storage)


@@ -857,7 +857,7 @@ def quantize_4bit(
     """

     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

     input_shape = A.shape

@@ -912,7 +912,7 @@ def dequantize_fp4(
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "fp4")


@@ -924,7 +924,7 @@ def dequantize_nf4(
     blocksize: Optional[int] = None,
 ) -> torch.Tensor:
     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128
     return dequantize_4bit(A, quant_state, absmax, out, blocksize, "nf4")


@@ -964,7 +964,7 @@ def dequantize_4bit(
     """

     if blocksize is None:
-        blocksize = 64 if not HIP_ENVIRONMENT else 128
+        blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

     if quant_state is None:
         assert absmax is not None and out is not None
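
A minimal usage sketch of the default-blocksize behavior these functions now share, assuming a CUDA or ROCm device is available (the default is 64 in general, and 128 only on ROCm GPUs with a 64-wide wavefront):

import torch
import bitsandbytes.functional as F

A = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed, quant_state = F.quantize_nf4(A)            # blocksize=None -> platform default
restored = F.dequantize_nf4(packed, quant_state)   # same default on the way back
print(quant_state.blocksize, restored.shape)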

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@
 import torch.nn.functional as F

 import bitsandbytes as bnb
-from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.cextension import ROCM_WARP_SIZE_64
 from bitsandbytes.functional import QuantState
 from bitsandbytes.optim import GlobalOptimManager
 from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
@@ -221,7 +221,7 @@ def __new__(
             data = torch.empty(0)

         if blocksize is None:
-            blocksize = 64 if not HIP_ENVIRONMENT else 128
+            blocksize = 64 if not ROCM_WARP_SIZE_64 else 128

         self = torch.Tensor._make_subclass(cls, data, requires_grad)
         self.blocksize = blocksize
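
A quick sketch of how this surfaces at the module level: leaving blocksize unset lets Params4bit pick the platform default shown above. This assumes a GPU build of bitsandbytes; the layer sizes are arbitrary.

import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear4bit(4096, 4096, quant_type="nf4", compute_dtype=torch.float16)
# 64 on CUDA and 32-wide-wavefront ROCm GPUs, 128 on 64-wide-wavefront ROCm GPUs.
print(layer.weight.blocksize)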

csrc/common_hip.cuh

Lines changed: 6 additions & 2 deletions
@@ -1,7 +1,11 @@
 #pragma once

-#define BNB_WARP_SIZE warpSize
+#ifdef __GFX9__
+#define BNB_WARP_SIZE 64
+#else
+#define BNB_WARP_SIZE 32
+#endif

 // These are set based on current BNB support for CDNA 2 & RDNA 3. Update as needed for future archs
-#define BNB_MAX_THREADS_PER_SM 2048
+#define BNB_MAX_THREADS_PER_CU 2048
 #define BNB_BF16_AVAILABLE true
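
A rough Python analogue of the compile-time choice above: GFX9 (CDNA/MI-series) GPUs use a 64-wide wavefront, while the other supported architectures use 32. This sketch assumes a ROCm build of PyTorch that exposes gcnArchName on the device properties; it is not part of this commit.

import torch

def bnb_warp_size(device: int = 0) -> int:
    if not torch.version.hip:
        return 32  # NVIDIA GPUs always have a 32-thread warp
    arch = torch.cuda.get_device_properties(device).gcnArchName
    # e.g. "gfx90a:sramecc+:xnack-" -> wave64; "gfx1100" -> wave32
    return 64 if arch.startswith("gfx9") else 32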

csrc/kernels.hip

Lines changed: 40 additions & 27 deletions
@@ -1885,7 +1885,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
 // rowStats [rows]
 // out [rows, cols]
 template<typename T, int THREADS, int SPARSE_DECOMP>
-__launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024)
+__launch_bounds__(1024, BNB_MAX_THREADS_PER_CU / 1024)
 __global__ void kInt8VectorQuant(T * __restrict__ A, int8_t* out, float* rowStats, float threshold, int rows, int cols) {

   // For sm50/sm52 and CUDA < 12.2 we need to do the reduction in fp32.
@@ -2018,11 +2018,6 @@ __global__ void kdequant_mm_int32_fp16(
 #define DENORM 1.0f/127.0f
 #define MAX_SPARSE_COUNT 32
 #define SMEM_SIZE 8*256
-#if defined(__GFX9__)
-#define WARP_SIZE 64
-#else
-#define WARP_SIZE 32
-#endif
 template <typename T, int SPMM_ITEMS, int BITS>
 __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *offset_rowidx, int *rowidx, int *colidx, half *values, T *B, half *out, float * __restrict__ const dequant_stats, int nnz, int rowsA, int rowsB, int colsB)
 {
@@ -2043,9 +2038,9 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   const int offset = local_max_idx == 0 ? 0 : offset_rowidx[local_max_idx-1];
   const int local_row_idx = rowidx[offset];

-  const int warp_id = threadIdx.x / WARP_SIZE;
-  const int warp_idx = threadIdx.x % WARP_SIZE;
-  const int warp_offset = (warp_id*WARP_SIZE)*SPMM_ITEMS;
+  const int warp_id = threadIdx.x / BNB_WARP_SIZE;
+  const int warp_idx = threadIdx.x % BNB_WARP_SIZE;
+  const int warp_offset = (warp_id*BNB_WARP_SIZE)*SPMM_ITEMS;
   const int num_items = BITS == 8 ? 8 : 8;
   int idx_col_B = warp_offset;
   int local_idx_col_B_offset = 0;
@@ -2065,7 +2060,7 @@ __global__ void kspmm_coo_very_sparse_naive(int *max_count, int *max_idx, int *o
   }

   // each thread processes SPMM_ITEMS=32 per iteration. We have 256 threads. 32*256=x192
-  // we expect each warp to be SPMM_ITEMS*WARP_SIZE apart
+  // we expect each warp to be SPMM_ITEMS*BNB_WARP_SIZE apart
   // we have a total of 128 bytes for the bank with a bank size of 4 bytes
   // added 3 bytes = 6 values between warps should reduce bank conflicts
   __shared__ half smem_dequant_stats[SMEM_SIZE];
@@ -2618,15 +2613,15 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc
 {

   // per threadblock:
-  // load step-by-step in chunks of [warp_size,warps]: 1xwarp_size * [warp_size,warps] -> [1,warps]
+  // load step-by-step in chunks of [BNB_WARP_SIZE,warps]: 1xBNB_WARP_SIZE * [BNB_WARP_SIZE,warps] -> [1,warps]
   // 4 warps -> 4 loads per iter
-  // 1xwarp_size * warp_sizex4 -> 1x4 outputs per thread block
-  typedef hipcub::WarpReduce<float, WARP_SIZE> WarpReduce;
-  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/WARP_SIZE];
+  // 1 x BNB_WARP_SIZE * BNB_WARP_SIZE x 4 -> 1x4 outputs per thread block
+  typedef hipcub::WarpReduce<float, BNB_WARP_SIZE> WarpReduce;
+  __shared__ typename WarpReduce::TempStorage temp_storage[THREADS/BNB_WARP_SIZE];

-  const int warp_idx = threadIdx.x / WARP_SIZE;
-  const int warp_lane = threadIdx.x % WARP_SIZE;
-  const int row_B = (THREADS/WARP_SIZE)*blockIdx.x + warp_idx;
+  const int warp_idx = threadIdx.x / BNB_WARP_SIZE;
+  const int warp_lane = threadIdx.x % BNB_WARP_SIZE;
+  const int row_B = (THREADS/BNB_WARP_SIZE)*blockIdx.x + warp_idx;
   const int offset_B = ldb * row_B;
   const int num_values_8bit = num_values_4bit/2;
   float local_C = 0.0f;
@@ -2645,7 +2640,7 @@ template <typename T, int THREADS, int BITS> __global__ void kgemm_4bit_inferenc

   // A: [1, K]
   // B: [M, K]
-  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += WARP_SIZE*num_values_4bit)
+  for(int inner_idx = warp_lane*num_values_4bit; inner_idx < K; inner_idx += BNB_WARP_SIZE*num_values_4bit)
   {
     const int inner_idx_halved = inner_idx/2;

@@ -2957,23 +2952,29 @@ MAKE_kQuantizeBlockwise(half, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, General8bit)
+#endif

 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, FP4)
+#endif

 MAKE_kQuantizeBlockwise(half, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(half, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(half, 64, 2, 0, NF4)
+#endif

 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 4096, 4, 1, General8bit)
@@ -2982,23 +2983,29 @@ MAKE_kQuantizeBlockwise(float, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, General8bit)
+#endif

 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, FP4)
+#endif

 MAKE_kQuantizeBlockwise(float, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(float, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(float, 64, 2, 0, NF4)
+#endif

 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 1, General8bit)
@@ -3007,23 +3014,29 @@ MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, General8bit)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, General8bit)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, General8bit)
+#endif

 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, FP4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, FP4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, FP4)
+#endif

 MAKE_kQuantizeBlockwise(hip_bfloat16, 4096, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 2048, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 1024, 4, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 512, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 256, 2, 0, NF4)
 MAKE_kQuantizeBlockwise(hip_bfloat16, 128, 2, 0, NF4)
-//MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#if BNB_WARP_SIZE == 32
+MAKE_kQuantizeBlockwise(hip_bfloat16, 64, 2, 0, NF4)
+#endif

 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, FP4>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
 template __global__ void kDequantizeBlockwise<half, 512, 64, 8, General8bit>(float *code, unsigned char * A, float * absmax, half *out, const int blocksize, const int n);
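
Taken together, the instantiations above mean the set of compiled blocksizes now depends on BNB_WARP_SIZE. An illustrative Python summary (not a library API) of which blocksizes kQuantizeBlockwise supports per wavefront width:

SUPPORTED_BLOCKSIZES = {
    32: [4096, 2048, 1024, 512, 256, 128, 64],  # wave32 (RDNA-class), matches CUDA
    64: [4096, 2048, 1024, 512, 256, 128],      # wave64 (CDNA-class)
}

def smallest_blocksize(warp_size: int) -> int:
    return min(SUPPORTED_BLOCKSIZES[warp_size])

assert smallest_blocksize(32) == 64
assert smallest_blocksize(64) == 128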
