ran black and isort for coherent code formatting
Titus-von-Koeller committed Aug 1, 2022
1 parent 597a852 commit bfa0e33
Showing 25 changed files with 3,866 additions and 1,998 deletions.
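
isort normalizes import ordering and grouping, while black normalizes everything else (quotes, line breaks, trailing commas), so running both yields a deterministic style across the repository. A minimal sketch of reproducing such a pass; the tool versions and settings used for this commit are assumptions, not taken from the repository's configuration:

    import subprocess

    # Hypothetical reproduction: format the package in place with default settings.
    for tool in ("isort", "black"):
        subprocess.run(["python", "-m", tool, "bitsandbytes"], check=True)

If isort is not configured to be black-compatible (e.g. profile = "black"), the two tools can disagree on multi-line imports, which may be why the wrapped import below uses isort's grid style rather than black's exploded style.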
20 changes: 11 additions & 9 deletions bitsandbytes/__init__.py
@@ -1,16 +1,18 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .nn import modules
from .autograd._functions import mm_cublas, bmm_cublas, matmul_cublas, matmul, MatmulLtState
from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
matmul_cublas, mm_cublas)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

if COMPILED_WITH_CUDA:
from .optim import adam

__pdoc__ = {'libbitsandbytes': False,
'optim.optimizer.Optimizer8bit': False,
'optim.optimizer.MockArgs': False
}
__pdoc__ = {
"libbitsandbytes": False,
"optim.optimizer.Optimizer8bit": False,
"optim.optimizer.MockArgs": False,
}
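
The `__pdoc__` rewrite just above also illustrates black's "magic trailing comma": once a collection literal carries a trailing comma, black keeps it exploded one element per line instead of collapsing it. A tiny standalone illustration with hypothetical values:

    xs = {"a": 1, "b": 2}  # no trailing comma and fits the line: stays collapsed
    ys = {
        "a": 1,
        "b": 2,
    }  # trailing comma present: black preserves the one-per-line layout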
134 changes: 88 additions & 46 deletions bitsandbytes/autograd/_functions.py
@@ -1,21 +1,24 @@
from dataclasses import dataclass

import torch

import bitsandbytes as bnb
import bitsandbytes.functional as F

from dataclasses import dataclass

tensor = torch.Tensor

'''
"""
This class pools outlier dimensions across layers.
This is particularly important for small models where outlier features
are less systematic and occur with low frequency.
'''
"""


class GlobalOutlierPooler(object):
_instance = None

def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")

def initialize(self):
self.outliers = set()
@@ -29,25 +32,29 @@ def get_instance(cls):
return cls._instance

def add_outliers(self, outlier_idx, feature_dim):
if self.model_dim is None: self.model_dim = feature_dim
if feature_dim != self.model_dim: return # we do not encode outliers for the 2nd FFN layer
if self.model_dim is None:
self.model_dim = feature_dim
if feature_dim != self.model_dim:
return # we do not encode outliers for the 2nd FFN layer

self.outliers.update(outlier_idx.tolist())

def get_current_outlier_idx(self):
return torch.Tensor(list(self.outliers)).to(torch.int64)
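
Per the docstring above, the pooler is a process-wide singleton that accumulates outlier feature indices across layers. A hypothetical interaction with it, using made-up indices and a made-up model dimension:

    import torch
    from bitsandbytes.autograd._functions import GlobalOutlierPooler

    pooler = GlobalOutlierPooler.get_instance()        # __init__ raises; always use get_instance()
    pooler.add_outliers(torch.tensor([7, 120]), 1024)  # first call pins model_dim to 1024
    pooler.add_outliers(torch.tensor([3]), 4096)       # 4096 != model_dim (2nd FFN layer): ignored
    idx = pooler.get_current_outlier_idx()             # int64 tensor([7, 120]), order not guaranteed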

class MatMul8bit(torch.autograd.Function):

class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type='vector', precision=[8, 8, 8]):
def forward(ctx, A, B, out=None, quant_type="vector", precision=[8, 8, 8]):

if precision[0] != 8:
with torch.no_grad():
output = torch.matmul(A, B)
else:
if len(B.shape) == 2: dim = 0
else: dim = 1
if len(B.shape) == 2:
dim = 0
else:
dim = 1
qA, SA = F.vectorwise_quant(A, dim=-1, quant_type=quant_type)
qB, SB = F.vectorwise_quant(B, dim=dim, quant_type=quant_type)
iout = F.igemm(qA, qB)
@@ -84,21 +91,41 @@ def backward(ctx, grad_output):
else:
if len(B.shape) == 2 and len(A.shape) == 3:
grad_output = grad_output.contiguous()
if not grad_output.is_contiguous(): grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(grad_output.view(-1, grad_output.shape[2]), dim=0, quant_type=quant_type)
if not A.is_contiguous(): A = A.contiguous()
qA, S2 = F.vectorwise_quant(A.view(-1, A.shape[2]), dim=0, quant_type=quant_type)
if not grad_output.is_contiguous():
grad_output.contiguous()
qgrad_output, S1 = F.vectorwise_quant(
grad_output.view(-1, grad_output.shape[2]),
dim=0,
quant_type=quant_type,
)
if not A.is_contiguous():
A = A.contiguous()
qA, S2 = F.vectorwise_quant(
A.view(-1, A.shape[2]), dim=0, quant_type=quant_type
)
igrad_B = F.igemm(qA.t(), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.t(), S1, grad_output.dtype, quant_type)
grad_B = F.vectorwise_mm_dequant(
igrad_B, S2.t(), S1, grad_output.dtype, quant_type
)
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(igrad_B, S2.permute(permute_dim), S1, grad_output.dtype, quant_type)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
S2.permute(permute_dim),
S1,
grad_output.dtype,
quant_type,
)

if A.requires_grad:
if len(grad_output.shape) == 3: dims = [2]
else: dims = [1]
if len(grad_output.shape) == 3:
dims = [2]
else:
dims = [1]

if len(B.shape) == 3:
# bio -> boi
@@ -113,10 +140,14 @@ def backward(ctx, grad_output):
with torch.no_grad():
grad_A = torch.matmul(grad_output, B.permute(permute_dim))
else:
qgrad_output, S1 = F.vectorwise_quant(grad_output, dim=dims, quant_type=quant_type)
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type)
grad_A = F.vectorwise_mm_dequant(
igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
)

return grad_A, grad_B, None, None, None
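
Forward and backward above share one pattern: quantize each operand vector-wise to int8, run the integer GEMM, then dequantize the int32 result with the two scaling statistics. A condensed sketch of that round trip using the same functional calls; shapes, device, and the dequant argument order are illustrative assumptions:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
    B = torch.randn(64, 32, dtype=torch.float16, device="cuda")

    qA, SA = F.vectorwise_quant(A, dim=-1, quant_type="vector")  # int8 rows of A + scales
    qB, SB = F.vectorwise_quant(B, dim=0, quant_type="vector")   # int8 cols of B + scales
    iout = F.igemm(qA, qB)                                       # int32 accumulator
    out = F.vectorwise_mm_dequant(iout, SA, SB, A.dtype, "vector")  # back to fp16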

@@ -125,6 +156,7 @@ def backward(ctx, grad_output):
bmm_cublas = MatMul8bit.apply
matmul_cublas = MatMul8bit.apply


@dataclass
class MatmulLtState:
CB = None
@@ -159,7 +191,6 @@ def reset_grads(self):


class MatMul8bitLt(torch.autograd.Function):

@staticmethod
def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
@@ -171,11 +202,15 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
requires_gradB = B.requires_grad
formatB = state.formatB
input_shape = A.shape
if state.outlier_pool is None: state.outlier_pool = GlobalOutlierPooler.get_instance()
assert A.dtype == torch.float16, f'The input data type needs to be fp16 but {A.dtype} was found!'
if state.outlier_pool is None:
state.outlier_pool = GlobalOutlierPooler.get_instance()
assert (
A.dtype == torch.float16
), f"The input data type needs to be fp16 but {A.dtype} was found!"

# 1. Quantize A
if len(A.shape) == 3: A = A.view(-1, A.shape[-1]).contiguous()
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)

if state.threshold > 0.0 and coo_tensorA is not None:
@@ -191,8 +226,8 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
#state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
#if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
# outlier_idx = torch.unique(coo_tensorA.colidx).long()
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
@@ -203,24 +238,24 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
# state.idx = outlier_idx
# state.subB = (state.CB[:, state.idx].float().t().contiguous()*(state.SCB/127)).half()

#if state.idx is not None:
# if state.idx is not None:
# # extract outliers
# CA[:, state.idx] = 0
# CAt[:, state.idx] = 0
# subA = A[:, state.idx]
#else:
# else:
# subA = None
else:
if not state.has_fp16_weights and state.CxB is None:
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
subA = None


# 2. Quantize B
if state.has_fp16_weights:
has_grad = (True if (getattr(B, 'grad', None) is not None) else False)
has_grad = True if (getattr(B, "grad", None) is not None) else False
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed: B = B.contiguous()
if is_transposed:
B = B.contiguous()

if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
@@ -234,14 +269,16 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):

outlier_idx = torch.unique(coo_tensorA.colidx)
state.idx = outlier_idx
#state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
#if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# state.outlier_pool.add_outliers(outlier_idx, A.shape[-1])
# if state.use_pool and state.outlier_pool.model_dim == A.shape[-1]:
# # do not use pool for 2nd FFN layer
# state.idx = state.outlier_pool.get_current_outlier_idx().to(A.device)
#else:
# else:
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (outliers*state.SCB.view(-1, 1)/127.0).t().contiguous().half()
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
subA = A[:, state.idx.long()]
@@ -254,7 +291,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
output_shape = (input_shape[0], shapeB[0])

# 3. Matmul
C32A, SA = F.transform(CA, 'col32')
C32A, SA = F.transform(CA, "col32")
out32, Sout32 = F.igemmlt(C32A, state.CxB, SA, state.SB)
output = F.mm_dequant(out32, Sout32, SCA, state.SCB)

@@ -277,7 +314,7 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
ctx.tensor_states = (None, None)
ctx.save_for_backward(None, None)

#clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
# clone_func = torch.clone if len(output_shape) == 3 else lambda x : x
clone_func = torch.clone
return clone_func(output.view(output_shape))
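
Steps 1–3 above compose the int8 Lt pipeline: double-quantize the fp16 activations (with optional outlier extraction above the threshold), transform both operands into the GPU tile layout, run the int8 GEMM, and dequantize. A condensed sketch of the happy path with the same calls; the weight layout string and shapes are assumptions:

    import torch
    import bitsandbytes.functional as F

    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda")

    CA, _, SCA, _, coo_A = F.double_quant(A, threshold=6.0)  # int8 + row scales (+ outlier coo)
    CW, _, SCW, _, _ = F.double_quant(W)
    C32A, SA = F.transform(CA, "col32")                      # activations -> col32 tiles
    CxW, SW = F.transform(CW, to_order="col_turing")         # weight format assumed (Turing GPUs)
    out32, Sout32 = F.igemmlt(C32A, CxW, SA, SW)             # int8 GEMM, int32 output
    out = F.mm_dequant(out32, Sout32, SCA, SCW)              # dequantize to fp16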

@@ -288,7 +325,7 @@ def backward(ctx, grad_output):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert state.has_fp16_weights, 'Backprop only supported for fp16 weights.'
assert state.has_fp16_weights, "Backprop only supported for fp16 weights."

if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
@@ -298,28 +335,33 @@ def backward(ctx, grad_output):
Cgrad, Cgradt, SCgrad, SCgradt, coo_tensor = F.double_quant(grad_output)
if req_gradB:
CxAt, SAt = F.transform(CAt, formatB, transpose=True)
C32grad, Sgrad = F.transform(Cgradt, 'col32', transpose=True)
C32grad, Sgrad = F.transform(Cgradt, "col32", transpose=True)
gradB32, SgradB32 = F.igemmlt(C32grad, CxAt, Sgrad, SAt)
grad_B = F.mm_dequant(gradB32, SgradB32, SCgradt, SCAt)
if state.threshold > 0.0 and subA is not None:
grad_B[:, idx] += torch.matmul(grad_output.t(), subA)

if req_gradA:
C32grad, Sgrad = F.transform(Cgrad, 'col32')
C32grad, Sgrad = F.transform(Cgrad, "col32")
if state.CxBt is None:
state.CxBt, state.SBt = F.transform(state.CBt, to_order=formatB, transpose=True)
state.CxBt, state.SBt = F.transform(
state.CBt, to_order=formatB, transpose=True
)
gradA32, SgradA32 = F.igemmlt(C32grad, state.CxBt, Sgrad, state.SBt)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(ctx.grad_shape)
grad_A = F.mm_dequant(gradA32, SgradA32, SCgrad, state.SCBt).view(
ctx.grad_shape
)

return grad_A, grad_B, None, None, None, None, None


matmul = MatMul8bitLt.apply


def matmul(A : tensor, B : tensor, out : tensor=None, state : MatmulLtState = None, threshold=0.0):
def matmul(
A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
):
state = state or MatmulLtState()
if threshold > 0.0:
state.threshold = threshold
return MatMul8bitLt.apply(A, B, out, state)
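
For reference, a hypothetical call into this wrapper; the shapes, the (out_features, in_features) weight layout, and the threshold value are illustrative, and fp16 CUDA tensors are required since the forward asserts A.dtype == torch.float16:

    import torch
    import bitsandbytes as bnb

    A = torch.randn(8, 1024, dtype=torch.float16, device="cuda")
    W = torch.randn(4096, 1024, dtype=torch.float16, device="cuda", requires_grad=True)

    out = bnb.matmul(A, W, threshold=6.0)  # outlier decomposition enabled; out has shape (8, 4096)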

23 changes: 15 additions & 8 deletions bitsandbytes/cextension.py
@@ -1,24 +1,29 @@
import ctypes as ct
import os
from warnings import warn

from bitsandbytes.cuda_setup import evaluate_cuda_setup


class CUDALibrary_Singleton(object):
_instance = None

def __init__(self):
raise RuntimeError('Call get_instance() instead')
raise RuntimeError("Call get_instance() instead")

def initialize(self):
self.context = {}
binary_name = evaluate_cuda_setup()
if not os.path.exists(os.path.dirname(__file__) + f'/{binary_name}'):
print(f'TODO: compile library for specific version: {binary_name}')
print('defaulting to libbitsandbytes.so')
self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + '/libbitsandbytes.so')
if not os.path.exists(os.path.dirname(__file__) + f"/{binary_name}"):
print(f"TODO: compile library for specific version: {binary_name}")
print("defaulting to libbitsandbytes.so")
self.lib = ct.cdll.LoadLibrary(
os.path.dirname(__file__) + "/libbitsandbytes.so"
)
else:
self.lib = ct.cdll.LoadLibrary(os.path.dirname(__file__) + f'/{binary_name}')
self.lib = ct.cdll.LoadLibrary(
os.path.dirname(__file__) + f"/{binary_name}"
)

@classmethod
def get_instance(cls):
@@ -35,6 +40,8 @@ def get_instance(cls):
lib.get_cusparse.restype = ct.c_void_p
COMPILED_WITH_CUDA = True
except AttributeError:
warn("The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable.")
warn(
"The installed version of bitsandbytes was compiled without GPU support. "
"8-bit optimizers and GPU quantization are unavailable."
)
COMPILED_WITH_CUDA = False
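
A short sketch of how the flag and the singleton above are typically consumed; the guard itself is a hypothetical caller, but the names mirror the code in this diff:

    from bitsandbytes.cextension import COMPILED_WITH_CUDA, CUDALibrary_Singleton

    if COMPILED_WITH_CUDA:
        lib = CUDALibrary_Singleton.get_instance().lib  # ctypes handle to the CUDA binary
    else:
        lib = None  # CPU-only build: 8-bit optimizers and GPU quantization unavailable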
(Diffs for the remaining 22 changed files are not shown.)