diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 3e68a1703..988f25d4c 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -439,6 +439,11 @@ def is_on_gpu(tensors): return on_gpu +def get_tensor_stream(tensor: Tensor) -> torch.cuda.Stream: + stream = torch.cuda.current_stream(tensor.device) + return stream + + def get_ptr(A: Optional[Tensor]) -> Optional[ct.c_void_p]: """ Get the ctypes pointer from a PyTorch Tensor. @@ -973,6 +978,7 @@ def dequantize_blockwise( f"The blockwise of {quant_state.blocksize} is not supported. Supported values: [2048, 4096, 1024, 512, 256, 128, 64]", ) is_on_gpu([A, absmax, out]) + stream = get_tensor_stream(A) if out.dtype == torch.float32: lib.cdequantize_blockwise_fp32( get_ptr(quant_state.code), @@ -981,6 +987,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, # Used the _as_parameter_ attribute of torch.cuda.Stream, Similarly for the following ) elif out.dtype == torch.float16: lib.cdequantize_blockwise_fp16( @@ -990,6 +997,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, ) elif out.dtype == torch.bfloat16: lib.cdequantize_blockwise_bf16( @@ -999,6 +1007,7 @@ def dequantize_blockwise( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(A.numel()), + stream, ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -1176,7 +1185,6 @@ def quantize_4bit( prev_device = pre_call(A.device) is_on_gpu([A, out, absmax]) - if A.dtype == torch.float32: if quant_type == "fp4": lib.cquantize_blockwise_fp32_fp4( @@ -1356,6 +1364,7 @@ def dequantize_4bit( device = pre_call(A.device) is_on_gpu([A, absmax, out]) + stream = get_tensor_stream(A) if out.dtype == torch.float32: if quant_state.quant_type == "fp4": lib.cdequantize_blockwise_fp32_fp4( @@ -1365,6 +1374,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_fp32_nf4( @@ -1374,6 +1384,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) elif out.dtype == torch.float16: if quant_state.quant_type == "fp4": @@ -1384,6 +1395,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_fp16_nf4( @@ -1393,6 +1405,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) elif out.dtype == torch.bfloat16: if quant_state.quant_type == "fp4": @@ -1403,6 +1416,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: lib.cdequantize_blockwise_bf16_nf4( @@ -1412,6 +1426,7 @@ def dequantize_4bit( get_ptr(out), ct.c_int(quant_state.blocksize), ct.c_int(n), + stream, ) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") @@ -1518,7 +1533,8 @@ def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = if out is None: out = torch.zeros_like(A, dtype=torch.float32) is_on_gpu([code, A, out]) - lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel())) + stream = get_tensor_stream(A) + lib.cdequantize(get_ptr(code), get_ptr(A), get_ptr(out), ct.c_int(A.numel()), stream) post_call(prev_device) return out @@ -2013,7 +2029,7 @@ def gemv_4bit( lda = ct.c_int32(lda) ldb = ct.c_int32(ldb) ldc = ct.c_int32(ldc) - + stream = get_tensor_stream(A) if B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32]: if A.dtype == torch.float16: lib.cgemm_4bit_inference_naive_fp16( @@ -2029,6 +2045,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) elif A.dtype == torch.bfloat16: lib.cgemm_4bit_inference_naive_bf16( @@ -2044,6 +2061,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) elif A.dtype == torch.float32: lib.cgemm_4bit_inference_naive_fp32( @@ -2059,6 +2077,7 @@ def gemv_4bit( ldb, ldc, ct.c_int32(state.blocksize), + stream, ) else: raise NotImplementedError(f"Matmul not implemented for data type {A.dtype}") diff --git a/csrc/ops.cu b/csrc/ops.cu index 1566ced5e..fb439be89 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -44,11 +44,11 @@ void quantize(float *code, float *A, unsigned char *out, int n) CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -void dequantize(float *code, unsigned char *A, float *out, int n) +void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream) { int num_blocks = n/1024; num_blocks = n % 1024 == 0 ? num_blocks : num_blocks + 1; - kDequantize<<>>(code, A, out, n); + kDequantize<<>>(code, A, out, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -76,16 +76,16 @@ template void quantizeBlockwise(floa CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n) +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int blocksize, const int n, cudaStream_t stream) { + // printf("stream==%d\n",stream); int num_blocks = n/blocksize; num_blocks = n % blocksize == 0 ? num_blocks : num_blocks + 1; int tile_size = (DATA_TYPE > 0) ? 1024 : 512; - if(DATA_TYPE > 0) - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize/2, n); + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize/2, n); else - kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64>>>(code, A, absmax, out, blocksize, n); + kDequantizeBlockwise<<<(n+tile_size-1)/tile_size, 64, 0, stream>>>(code, A, absmax, out, blocksize, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -724,12 +724,11 @@ template void gemm_4bit_inference(int m, int n, int k, T * A, unsi //kgemm_4bit_inference<<< num_blocks, 32, 0, 0 >>>(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize) +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) { int num_blocks = (m+3)/4; - - kgemm_4bit_inference_naive<<< num_blocks, 128, 0, 0 >>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); + kgemm_4bit_inference_naive<<< num_blocks, 128, 0, stream>>>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } @@ -753,9 +752,9 @@ template void func(float *A, float *B, float value, long n); template void func(float *A, float *B, float value, long n); template void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); +template void gemm_4bit_inference_naive<__nv_bfloat16, 16>(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); +template void gemm_4bit_inference_naive(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); //template void gemm_host(int m, int n, int k, float * A, float* B, float * out, int lda, int ldb, int ldc, int bits); template void gemm_host(int m, int n, int k, half * A, half* B, half * out, int lda, int ldb, int ldc, int bits); @@ -795,15 +794,15 @@ template void quantizeBlockwise<__nv_bfloat16, 0, General8bit>(float * code, __n template void quantizeBlockwise<__nv_bfloat16, 0, FP4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); template void quantizeBlockwise<__nv_bfloat16, 0, NF4>(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); -template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, General8bit>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, FP4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); +template void dequantizeBlockwise<__nv_bfloat16, NF4>(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream); #define MAKE_optimizer32bit(name, gtype) \ template void optimizer32bit(gtype* g, gtype* p, gtype* return_updates, \ diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 828aa6a49..f54b18d5b 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -7,6 +7,7 @@ #ifndef ops_H #define ops_H +#include #include #include #include @@ -142,9 +143,9 @@ class ContextCusparse template void estimateQuantiles(T *A, float *code, float offset, int n); void quantize(float *code, float *A, unsigned char *out, int n); -void dequantize(float *code, unsigned char *A, float *out, int n); +void dequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream); template void quantizeBlockwise(float * code, T *A, float *absmax, unsigned char *out, float* rand, int rand_offset, int blocksize, const int n); -template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n); +template void dequantizeBlockwise(float *code, unsigned char *A, float *absmax, T *out, int block_size, const int n, cudaStream_t stream); template void optimizer32bit(T* g, T* p, T* return_updates, float* state1, float* state2, float *unorm, float max_unorm, float param_norm, @@ -196,7 +197,7 @@ void matmul4bite(half *A, unsigned char *B, half*out, int lda, int ldb, int rows template void gemm_host(int m, int n, int k, T * A, T* B, T * out, int lda, int ldb, int ldc, int bits); template void gemm_4bit_inference(int m, int n, int k, T * A, unsigned char* B, float *absmax, T * out, int lda, int ldb, int ldc, int blocksize); -template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize); +template void gemm_4bit_inference_naive(int m, int n, int k, T * A, unsigned char* B, float *absmax, float *datatype, T * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream); template void func(T *A, T *B, T value, long n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 6a3030f73..418ca76d3 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -31,14 +31,14 @@ void gemm_host_fp16(int M, int N, int K, half * A, half* B, half * out, int l void gemm_4bit_inference(int m, int n, int k, half * A, unsigned char* B, float *absmax, half * out, int lda, int ldb, int ldc, int blocksize) { gemm_4bit_inference(m, n, k, A, B, absmax, out, lda, ldb, ldc, blocksize); } -void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } -void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive<__nv_bfloat16, 16>(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } -void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) -{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } +void gemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) +{ gemm_4bit_inference_naive(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #define MAKE_ELEMENTWISE_FUNC(fname, type_name, ctype, FUNC) \ void fname##_##type_name(ctype *A, ctype *B, ctype value, long n){ func(A, B, value, n); } \ @@ -126,17 +126,17 @@ void quantizeBlockwise_fp32(float * code, float *A, float *absmax, unsigned char void quantizeBlockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } void quantizeBlockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise(NULL, A, absmax, out, NULL, 0, blocksize, n); } -void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ -void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } \ +void dequantizeBlockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } \ +void dequantizeBlockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } \ +void dequantizeBlockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } \ -void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream); } -void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n); } -void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n); } +void dequantizeBlockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, General8bit>(code, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, FP4>(NULL, A, absmax, out, blocksize, n, stream); } +void dequantizeBlockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise<__nv_bfloat16, NF4>(NULL, A, absmax, out, blocksize, n, stream); } #define MAKE_FUNC_TRANSFORM(fbits, fsrc, ftrgt, ftranspose, dtype, src, target, transpose, bits) \ @@ -195,11 +195,11 @@ extern "C" void cestimate_quantiles_fp32(float *A, float *code, float offset, int n){ estimateQuantiles_fp32(A, code, offset, n); } void cestimate_quantiles_fp16(half *A, float *code, float offset, int n){ estimateQuantiles_fp16(A, code, offset, n); } void cquantize(float *code, float *A, unsigned char *out, int n){ quantize(code, A, out, n); } - void cdequantize(float *code, unsigned char *A, float *out, int n){ dequantize(code, A, out, n); } + void cdequantize(float *code, unsigned char *A, float *out, int n, cudaStream_t stream){ dequantize(code, A, out, n, stream); } - void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp16_fp4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp16(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp16_nf4(float *code, unsigned char *A, float *absmax, half *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_fp16(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp16_fp4(float * code, half *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n); } @@ -209,17 +209,17 @@ extern "C" void cquantize_blockwise_fp32_fp4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_fp32_nf4(float * code, float *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp32_fp4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_fp32_nf4(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream); } void cquantize_blockwise_bf16(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_fp4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } void cquantize_blockwise_bf16_nf4(float * code, __nv_bfloat16 *A, float *absmax, unsigned char *out, int blocksize, const int n){ quantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n); } - void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n); } + void cdequantize_blockwise_bf16(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_fp4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream); } + void cdequantize_blockwise_bf16_nf4(float *code, unsigned char *A, float *absmax, __nv_bfloat16 *out, int blocksize, const int n, cudaStream_t stream){ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream); } #define MAKE_CFUNC32(name, gtype, gbits) \ void c##name##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \ @@ -406,14 +406,14 @@ extern "C" CMAKE_ELEMENTWISE_FUNC(arange, fp32, float, ARANGE) CMAKE_ELEMENTWISE_FUNC(_mul, fp32, float, _MUL) - void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp16(int m, int n, int k, half * A, unsigned char* B, float *absmax, float *datatype, half * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } - void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_bf16(int m, int n, int k, __nv_bfloat16 * A, unsigned char* B, float *absmax, float *datatype, __nv_bfloat16 * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } - void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize) - { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize); } + void cgemm_4bit_inference_naive_fp32(int m, int n, int k, float * A, unsigned char* B, float *absmax, float *datatype, float * out, int lda, int ldb, int ldc, int blocksize, cudaStream_t stream) + { gemm_4bit_inference_naive_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream); } #endif