# Fix for matmul_4bit out Parameter Issue #1659

**Open** · wants to merge 2 commits into `main`
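
For context, the sketch below shows the call pattern this PR targets: `matmul_4bit` called with a 1D activation and a caller-provided `out` buffer. The shapes, dtypes, and `quantize_4bit` weight layout are illustrative assumptions, not taken verbatim from issue #1659.

```python
# Minimal sketch of the affected call pattern (assumptions: CUDA device, nf4 weights,
# illustrative shapes; the weight transpose follows bnb.nn.Linear4bit's usage).
import torch
import bitsandbytes as bnb
import bitsandbytes.functional as F

W = torch.randn(64, 128, dtype=torch.float16, device="cuda")   # (out_features, in_features)
B, quant_state = F.quantize_4bit(W, quant_type="nf4")          # packed 4-bit weight + metadata

A = torch.randn(128, dtype=torch.float16, device="cuda")       # 1D activation vector
out = torch.empty(64, dtype=torch.float16, device="cuda")      # caller-provided output buffer

# Before this change the inference (gemv) path could mishandle `out` for 1D inputs;
# with it, the result is written into `out` and also returned.
result = bnb.matmul_4bit(A, B.t(), quant_state, out=out)
print(result.shape)  # torch.Size([64])
```
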
**`bitsandbytes/autograd/_functions.py`** (26 changes: 18 additions & 8 deletions)

```diff
@@ -357,9 +357,6 @@ def backward(ctx, grad_output):
 
 
 class MatMul4Bit(torch.autograd.Function):
-    # forward is the same, but we added the fallback for pre-turing GPUs
-    # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
-
     @staticmethod
     def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
         # default of pytorch behavior if inputs are empty
@@ -377,7 +374,15 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
 
         # 1. Dequantize
         # 2. MatmulnN
-        output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+        # Use linear function which correctly handles 1D and 2D inputs
+        result = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
+
+        # If out is provided, resize it if necessary and copy the result
+        if out is not None:
+            if out.shape != result.shape:
+                out.resize_(result.shape)
+            out.copy_(result)
+            result = out
 
         # 3. Save state
         ctx.state = quant_state
@@ -388,7 +393,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
         else:
             ctx.tensors = (None, None)
 
-        return output
+        return result
 
     @staticmethod
     def backward(ctx, grad_output):
```
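
The `out` handling added to `forward` follows the common PyTorch convention for `out=` arguments: compute into a fresh tensor, then grow the caller's buffer if needed and copy into it. A standalone sketch of that pattern, with a hypothetical `linear_with_out` helper standing in for the real code:

```python
# Standalone sketch of the out-buffer convention used above (hypothetical helper).
import torch

def linear_with_out(A, W, bias=None, out=None):
    result = torch.nn.functional.linear(A, W, bias)
    if out is not None:
        if out.shape != result.shape:
            out.resize_(result.shape)  # may fail if `out` is a view or shares storage
        out.copy_(result)              # write the result into the caller's buffer
        return out
    return result

x, w = torch.randn(3, 8), torch.randn(4, 8)
buf = torch.empty(0)
y = linear_with_out(x, w, out=buf)
print(y.shape, y is buf)  # torch.Size([3, 4]) True
```
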
```diff
@@ -458,9 +463,14 @@ def matmul_4bit(
             )
             return MatMul4Bit.apply(A, B, out, bias, quant_state)
         else:
-            out = F.gemv_4bit(A, B.t(), out, state=quant_state)
+            # For 1D case, we'll use the MatMul4Bit implementation which correctly handles out parameter
+            if out is not None and A.dim() == 1:
+                return MatMul4Bit.apply(A, B, out, bias, quant_state)
+
+            # For other cases, use gemv_4bit
+            result = F.gemv_4bit(A, B.t(), out, state=quant_state)
             if bias is not None:
-                out += bias
-            return out
+                result += bias
+            return result
     else:
         return MatMul4Bit.apply(A, B, out, bias, quant_state)
```
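
Routing the 1D-plus-`out` case through `MatMul4Bit.apply` gives up the fused GEMV kernel for that call in exchange for the dequantize-then-`linear` path above, which now fills the caller's buffer via the copy logic added to `forward`; every other inference call continues to go through `F.gemv_4bit` as before.
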
**`bitsandbytes/backends/cuda/ops.py`** (49 changes: 31 additions & 18 deletions)

```diff
@@ -427,11 +427,18 @@ def _(
     blocksize: int,
     out: torch.Tensor,
 ) -> None:
+    expected_shape = (*A.shape[:-1], shapeB[0])
+
+    if len(A.shape) == 1 and len(out.shape) == 2 and out.shape[0] == 1:
+        out = out.view(shapeB[0])
+        expected_shape = (shapeB[0],)
+
     torch._check(
-        out.shape == (*A.shape[:-1], shapeB[0]),
-        lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+        out.shape == expected_shape,
+        lambda: f"Expected out.shape == {expected_shape}, got {out.shape}",
     )
     torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
 
     _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
```
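
The net effect of the new shape check: a `(1, N)` buffer passed alongside a 1D `A` is now viewed as a flat `(N,)` vector before validation instead of being rejected. A small illustration with a hypothetical helper and made-up shapes:

```python
# Hypothetical standalone illustration of the out-shape normalization above.
import torch

def normalize_out(A: torch.Tensor, out: torch.Tensor, shapeB) -> torch.Tensor:
    expected_shape = (*A.shape[:-1], shapeB[0])
    if A.dim() == 1 and out.dim() == 2 and out.shape[0] == 1:
        out = out.view(shapeB[0])      # treat a (1, N) buffer as a flat (N,) vector
        expected_shape = (shapeB[0],)
    assert out.shape == expected_shape, f"Expected {expected_shape}, got {out.shape}"
    return out

A = torch.randn(128)                   # 1D input
buf = torch.empty(1, 64)               # (1, N) buffer that previously failed the check
print(normalize_out(A, buf, (64, 128)).shape)  # torch.Size([64])
```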


```diff
@@ -446,32 +453,38 @@ def _gemv_4bit_impl(
 ) -> None:
     torch._check_is_size(blocksize)
 
-    # Note: these checks are not strictly necessary, and cost more than they are worth, so they are commented out for now.
-    # torch._check(
-    #     A.numel() == A.size(-1),
-    #     lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
-    # )
-    # torch._check(
-    #     A.dtype in [torch.float16, torch.bfloat16, torch.float32],
-    #     lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
-    # )
-    # torch._check(
-    #     B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
-    #     lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
-    # )
-    # torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
-    # torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
+    is_1d = A.dim() == 1
+    if is_1d:
+        A_reshaped = A.view(1, -1)
+    else:
+        A_reshaped = A
+
+    torch._check(
+        A.dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
+    )
+    torch._check(
+        B.dtype in [torch.uint8, torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got {B.dtype}",
+    )
+    torch._check(absmax.dtype == torch.float32, lambda: f"absmax must be float32, got {absmax.dtype}")
+    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")
 
     m = ct.c_int32(shapeB[0])
     n = ct.c_int32(1)
     k = ct.c_int32(shapeB[1])
 
     lda = m
-    ldb = ct.c_int32((A.shape[-1] + 1) // 2)
+    ldb = ct.c_int32((A_reshaped.shape[-1] + 1) // 2)
     ldc = m
 
     stream = _get_tensor_stream(A)
 
+    if is_1d and out.dim() > 1:
+        out_view = out.view(-1)
+    else:
+        out_view = out
+
     with _cuda_device_of(A):
         if A.dtype == torch.float16:
             lib.cgemm_4bit_inference_naive_fp16(
```
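
Inside `_gemv_4bit_impl`, the same 1D case drives the new `A_reshaped`/`out_view` pair: the kernel is fed a single-row matrix and a flat output buffer. A toy illustration of that reshaping (assumed shapes, no kernel call):

```python
# Toy illustration of the 1D handling sketched above (assumed shapes; no CUDA needed).
import torch

A = torch.randn(128, dtype=torch.float16)       # 1D activation
out = torch.empty(1, 64, dtype=torch.float16)   # buffer as the caller allocated it

is_1d = A.dim() == 1
A_reshaped = A.view(1, -1) if is_1d else A                   # single-row matrix for the kernel
out_view = out.view(-1) if is_1d and out.dim() > 1 else out  # flat vector for the kernel to fill

print(A_reshaped.shape, out_view.shape)  # torch.Size([1, 128]) torch.Size([64])
```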