From 18c96598da3c7e447203f3e3af54cf078abd6eb9 Mon Sep 17 00:00:00 2001
From: V-E-D
Date: Wed, 28 May 2025 15:37:41 +0530
Subject: [PATCH] Fix for matmul_4bit out Parameter Issue

---
 bitsandbytes/autograd/_functions.py | 26 ++++++++++++++++++--------
 bitsandbytes/backends/cuda/ops.py   | 29 ++++++++++++++++++++++-------
 2 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index c7ad3a82c..fd0aa8923 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -299,9 +299,6 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
 
 
 class MatMul4Bit(torch.autograd.Function):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
@@ -319,7 +316,15 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        # Use linear function which correctly handles 1D and 2D inputs
+        result = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+
+        # If out is provided, resize it if necessary and copy the result
+        if out is not None:
+            if out.shape != result.shape:
+                out.resize_(result.shape)
+            out.copy_(result)
+            result = out
 
         # 3. Save state
         ctx.state = quant_state
@@ -330,7 +335,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         else:
             ctx.tensors = (None, None)
 
-        return output
+        return result
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -385,9 +390,14 @@ def matmul_4bit(
             )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            # For 1D case, we'll use the MatMul4Bit implementation which correctly handles out parameter
+            if out is not None and A.dim() == 1:
+                return MatMul4Bit.apply(A, B, out, bias, quant_state)
+
+            # For other cases, use gemv_4bit
+            result = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
-                out += bias
-            return out
+                result += bias
+            return result
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index efdef2871..a32a005e4 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -427,11 +427,18 @@ def _(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
+    expected_shape = (*A.shape[:-1], shapeB[0])
+
+    if len(A.shape) == 1 and len(out.shape) == 2 and out.shape[0] == 1:
+        out = out.view(shapeB[0])
+        expected_shape = (shapeB[0],)
+
     torch._check(
-        out.shape == (*A.shape[:-1], shapeB[0]),
-        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+        out.shape == expected_shape,
+        lambda: f"Expected out.shape == {expected_shape}, got {out.shape}",
     )
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+
     _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
 
 
@@ -445,10 +452,13 @@ def _gemv_4bit_impl(
     out: torch.Tensor,
 ) -> None:
     torch._check_is_size(blocksize)
-    torch._check(
-        A.numel() == A.size(-1),
-        lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    )
+
+    is_1d = A.dim() == 1
+    if is_1d:
+        A_reshaped = A.view(1, -1)
+    else:
+        A_reshaped = A
+
     torch._check(
         A.dtype in [torch.float16, torch.bfloat16, torch.float32],
         lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
@@ -465,11 +475,16 @@ def _gemv_4bit_impl(
     k = ct.c_int32(shapeB[1])
 
     lda = m
-    ldb = ct.c_int32((A.shape[-1] + 1) // 2)
+    ldb = ct.c_int32((A_reshaped.shape[-1] + 1) // 2)
     ldc = m
 
     stream = _get_tensor_stream(A)
 
+    if is_1d and out.dim() > 1:
+        out_view = out.view(-1)
+    else:
+        out_view = out
+
     with _cuda_device_of(A):
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(
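
For reference (not part of the patch): a minimal usage sketch of the scenario the fix targets, namely calling matmul_4bit with a 1D activation and a caller-supplied out buffer. The weight shape (256x128), the nf4 quant type, and the CUDA device are illustrative assumptions, not taken from the patch.

import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

# Illustrative sizes only: a 128-feature input projected to 256 outputs.
W = torch.randn(256, 128, dtype=torch.float16, device="cuda")

# quantize_4bit returns the packed 4-bit tensor and its quantization state.
qB, quant_state = F.quantize_4bit(W, quant_type="nf4")

# A 1D activation with requires_grad=False takes the inference path in matmul_4bit.
A = torch.randn(128, dtype=torch.float16, device="cuda")

# Preallocated output buffer; the patch routes the 1D-input-plus-out combination
# through MatMul4Bit so that the buffer is filled and returned.
out = torch.empty(256, dtype=torch.float16, device="cuda")

result = bnb.matmul_4bit(A, qB.t(), quant_state, out=out)

# With the patch applied, the returned tensor is the provided buffer.
assert result.data_ptr() == out.data_ptr()
print(result.shape)  # torch.Size([256])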