Merge branch 'debug' into cuda-bin-switch-and-cli
TimDettmers committed Aug 4, 2022
2 parents 96bc209 + ab72a12 commit 758c717
Showing 5 changed files with 180 additions and 199 deletions.
2 changes: 1 addition & 1 deletion Makefile
@@ -58,7 +58,7 @@ CC_cublasLt111 += -gencode arch=compute_86,code=sm_86


all: $(ROOT_DIR)/dependencies/cub $(BUILD_DIR) env
	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' --use_fast_math -Xptxas=-v -dc $(FILES_CUDA) $(INCLUDE) $(LIB) --output-directory $(BUILD_DIR)
	$(NVCC) $(COMPUTE_CAPABILITY) -Xcompiler '-fPIC' -dlink $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o -o $(BUILD_DIR)/link.o
	$(GPP) -std=c++14 -DBUILD_CUDA -shared -fPIC $(INCLUDE) $(BUILD_DIR)/ops.o $(BUILD_DIR)/kernels.o $(BUILD_DIR)/link.o $(FILES_CPP) -o ./bitsandbytes/libbitsandbytes_cuda$(CUDA_VERSION).so $(LIB)

17 changes: 15 additions & 2 deletions bitsandbytes/autograd/_functions.py
@@ -1,7 +1,7 @@
from dataclasses import dataclass

import torch

import math
import bitsandbytes as bnb
import bitsandbytes.functional as F
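
The new math import backs the empty-input check added further down: math.prod(A.shape) is zero whenever any dimension of A is zero. A small illustrative sketch (hypothetical shapes; math.prod needs Python 3.8+):

import math
import torch

A = torch.empty(0, 4, dtype=torch.float16)
print(math.prod(A.shape))   # 0 -> the forward pass below treats this as an empty input
print(A.numel())            # also 0; an equivalent emptiness check on the tensor itself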

@@ -199,6 +199,17 @@ def reset_grads(self):
class MatMul8bitLt(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, state=MatmulLtState()):
        # default to pytorch behavior if inputs are empty
        ctx.is_empty = False
        if math.prod(A.shape) == 0:
            ctx.is_empty = True
            ctx.A = A
            ctx.B = B
            if A.shape[-1] == B.shape[0]:
                return torch.empty(A.shape[:-1]+B.shape[1:], dtype=torch.float16, device=A.device)
            else:
                return torch.empty(A.shape[:-1]+B.shape[:1], dtype=torch.float16, device=A.device)

        # 1. Quantize A
        # 2. Quantize B
        # 3. Matmul
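
For context, a minimal usage sketch of the empty-input shortcut added above, assuming a CUDA build of bitsandbytes is importable and a CUDA device is available (illustrative only; the shortcut itself allocates just an empty tensor and launches no kernels):

import torch
from bitsandbytes.autograd._functions import matmul  # matmul = MatMul8bitLt.apply, defined at the end of this file

A = torch.empty(0, 4, dtype=torch.float16, device="cuda")   # zero rows -> math.prod(A.shape) == 0
B = torch.empty(4, 8, dtype=torch.float16, device="cuda")
out = matmul(A, B)                 # returns before any 8-bit quantization happens
print(out.shape, out.dtype)        # torch.Size([0, 8]) torch.float16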
@@ -339,6 +350,8 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.is_empty:
            return torch.zeros_like(ctx.A), torch.zeros_like(ctx.B), None, None
        req_gradA, req_gradB = ctx.req_grads
        CAt, subA = ctx.tensors
        SCAt, idx = ctx.tensor_states
@@ -375,7 +388,7 @@ def backward(ctx, grad_output):
ctx.grad_shape
)

        return grad_A, grad_B, None, None, None, None, None
        return grad_A, grad_B, None, None


matmul = MatMul8bitLt.apply
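
A hedged sketch of the gradient path these hunks add, under the same assumed setup as the earlier sketch: for empty inputs, backward returns zero gradients shaped like the saved A and B, and the shortened return tuple lines up with forward's four arguments (A, B, out, state):

A = torch.empty(0, 4, dtype=torch.float16, device="cuda", requires_grad=True)
B = torch.randn(4, 8, dtype=torch.float16, device="cuda", requires_grad=True)
out = matmul(A, B)         # forward sets ctx.is_empty and stashes A and B on ctx
out.sum().backward()       # takes the early return in backward
assert A.grad.shape == A.shape     # zeros shaped (0, 4)
assert torch.all(B.grad == 0)      # zeros shaped (4, 8)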
(Diffs for the remaining three changed files are not shown here.)
