
Commit 0aefeb0
int8 more cleanup
matthewdouglas committed Nov 4, 2024
1 parent 875414e commit 0aefeb0
Showing 1 changed file with 5 additions and 14 deletions.
19 changes: 5 additions & 14 deletions bitsandbytes/autograd/_functions.py
@@ -264,7 +264,7 @@ class MatmulLtState:
     has_fp16_weights = True
     memory_efficient_backward = False
     use_pool = False
-    formatB = "row"  # F.get_special_format_str() TODO: Deprecate/remove
+    formatB = "row"  # TODO: Deprecate/remove
 
     def reset_grads(self):
         self.CB = None
@@ -394,9 +394,9 @@ def forward(
         output_shape = (*input_shape[:-1], state.CB.shape[0])
 
         if len(input_shape) == 3:
-            return output.reshape(output_shape)  # .clone()
-        else:
-            return output
+            return output.reshape(output_shape)
+
+        return output
 
     @staticmethod
     def backward(ctx, grad_output):
@@ -418,11 +418,6 @@ def backward(ctx, grad_output):
         if len(grad_output.shape) == 3:
             grad_output = grad_output.reshape(-1, grad_output.shape[-1]).contiguous()
 
-        # if req_gradB:
-        #     grad_B = torch.matmul(grad_output.t(), A)
-        #     if state.threshold > 0.0 and subA is not None:
-        #         grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
-        # Cgrad, Cgradt, SCgrad, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))
         if req_gradB:
             Cgrad, _, _, SCgradt, _ = F.double_quant(grad_output.to(torch.float16))

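Note on the surviving req_gradB path: F.double_quant quantizes the fp16 gradient to int8 with per-row absmax scaling (plus a transposed, column-scaled copy) before the int8 matmul that produces grad_B. Below is a minimal sketch of row-wise absmax quantization, assuming plain round-to-nearest; it is illustrative only, not the bitsandbytes kernel.

import torch

def rowwise_absmax_int8(x: torch.Tensor):
    # Scale each row so its largest magnitude maps to 127, then round to int8.
    absmax = x.abs().amax(dim=1, keepdim=True).clamp_min(1e-8)  # per-row scale
    q = torch.round(x * (127.0 / absmax)).to(torch.int8)        # int8 codes
    return q, absmax.squeeze(1)                                 # codes + row stats

grad_output = torch.randn(4, 8)
Cgrad, SCgrad = rowwise_absmax_int8(grad_output)

# Round-trip check: dequantize and compare; error is at most half a step per row.
dequant = Cgrad.float() * (SCgrad.unsqueeze(1) / 127.0)
assert (grad_output - dequant).abs().max() < 0.05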
@@ -432,15 +427,11 @@ def backward(ctx, grad_output):
             grad_B[:, idx] += torch.matmul(grad_output.t(), subA)
 
         if req_gradA:
-            # grad_output @ B.T
-            # if state.CBt is not None:
-            #     gradA32, SgradA32 = F.igemmlt(Cgrad, state.CBt.t())
-            #     grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape).to(ctx.dtype_A)
             if state.CB is not None:
                 CB = state.CB.to(ctx.dtype_A, copy=True).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
                 grad_A = torch.matmul(grad_output, CB).view(ctx.grad_shape).to(ctx.dtype_A)
             else:
-                raise Exception("State must contain either CBt or CB matrix for backward")
+                raise Exception("State must contain CB matrix for backward")
 
         return grad_A, grad_B, None, grad_bias, None

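The surviving grad_A branch dequantizes the stored int8 weight before the matmul: CB holds the int8 weight rows and SCB the per-row absmax statistics, so the floating-point weight is recovered as CB * SCB / 127. A small numeric sketch of that identity, with shapes and variable names assumed for illustration:

import torch

# Assumed shapes for illustration: W is (out_features, in_features).
out_features, in_features, batch = 16, 32, 4
W = torch.randn(out_features, in_features)
SCB = W.abs().amax(dim=1)                                        # per-row absmax
CB = torch.round(W * (127.0 / SCB.unsqueeze(1))).to(torch.int8)  # int8 weight

grad_output = torch.randn(batch, out_features)

# Mirrors: state.CB.to(dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
CB_fp = CB.float() * (SCB.unsqueeze(1) / 127.0)   # dequantized weight
grad_A = grad_output @ CB_fp                      # gradient wrt the layer input

# Matches the full-precision gradient up to quantization error.
print(torch.allclose(grad_A, grad_output @ W, atol=0.2))

Since the forward pass computes output = input @ W.T, the input gradient is grad_output @ W, which the code approximates with the dequantized CB.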
