
Merge branch 'multi-backend-refactor' into device_abstraction
pnunna93 authored Aug 25, 2024
2 parents 42cc717 + 18668d2 commit b22eb2e
Showing 4 changed files with 31 additions and 12 deletions.
12 changes: 6 additions & 6 deletions bitsandbytes/__init__.py
@@ -20,12 +20,12 @@
 from .cextension import lib
 from .nn import modules
 
-# NOTE: this is a temporary flag to allow outside libraries to employ conditional logic while the refactor is still in
-# alpha/beta: sth like `if getattr(bitsandbytes, "is_multi_backend_refactor_preview", False): do sth`
-# the getattr() call above would default to False and any string evaluates to True. This way we have temporary thing
-# that we can remove in Transformers with the next release after the official BNB multi-platform release; then
-# eventually making it the new default (e.g. just remove if statement and dedent in Transformers)
-is_multi_backend_refactor_preview = "TO BE REMOVED ONCE MERGED TO `main`"  # bool evals to True for str
+features = {"multi_backend"}
+supported_torch_devices = {
+    "cuda",  # includes ROCm
+    "xpu",  # Intel GPU
+    "cpu",
+}
 
 # Always register the CPU backend.
 register_backend("cpu", CPUBackend())
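The NOTE removed above described how downstream libraries (e.g. Transformers) were expected to gate behavior on the temporary `is_multi_backend_refactor_preview` string. After this change, the same conditional logic can key off the new `features` set and `supported_torch_devices` instead. A minimal sketch of such a check, assuming only the names introduced in this diff (the `bnb` alias and the fallback defaults are illustrative):

import bitsandbytes as bnb

# Sketch: feature gate a downstream library might use in place of the old string flag.
if "multi_backend" in getattr(bnb, "features", set()):
    # Multi-backend build: see which torch device types this install claims to support.
    devices = getattr(bnb, "supported_torch_devices", {"cuda"})
    print("bitsandbytes multi-backend devices:", devices)
else:
    # Older bitsandbytes: CUDA-only code path.
    print("single-backend bitsandbytes")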
18 changes: 14 additions & 4 deletions bitsandbytes/backends/cpu_xpu_common.py
@@ -387,9 +387,9 @@ def quantize_4bit_impl(
             -1,  # act_quant_mode. -1 means don't quant activation
         )
         state.absmax = torch.Tensor()
-        return torch.Tensor(), state
+        return torch.empty([1, 0], dtype=torch.uint8), state
 
-    return out, state
+    return out.unsqueeze(0), state
 
 
 @_maybe_torch_compile
@@ -428,6 +428,13 @@ def dequantize_4bit_impl(
         Dequantized tensor.
     """
 
+    if A.shape[0] == 1:
+        transpose = False
+        A = A.squeeze(0)
+    elif A.shape[1] == 1:
+        transpose = True
+        A = A.squeeze(1)
+
     if quant_state is None:
         assert absmax is not None and out is not None
 
@@ -484,7 +491,10 @@
         out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1]
 
     # take transpose here because weight is transposed (again) for computation
-    return out.t()
+    if transpose:
+        out = out.t()
+
+    return out
 
 
 # Do not need torch.compile here as we are calling torch/ipex kernel
@@ -523,7 +533,7 @@ def gemm_4bit_impl(
         assert state.op_context is not None
         output = torch.ops.torch_ipex.ipex_woq_linear(A, state.op_context.get_data_handle())
     else:
-        dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize)
+        dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize).t()
         output = torch.matmul(A, dqB.to(A.dtype))
     if out is not None:
         out.copy_(output)
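Taken together, the changes in this file move the CPU/XPU 4-bit path to a 2-D packed-weight convention: quantize_4bit_impl now returns its packed bytes with a leading unit dimension ([1, n], or an empty [1, 0] placeholder on the IPEX pre-pack path), dequantize_4bit_impl checks whether the incoming tensor is a row ([1, n]) or a column ([n, 1]) and only transposes the result in the column case, and gemm_4bit_impl applies the transpose itself before the matmul fallback. A standalone sketch of that shape convention, using toy tensors rather than the library's real packing:

import torch

packed = torch.arange(8, dtype=torch.uint8)   # pretend these are packed 4-bit bytes
stored = packed.unsqueeze(0)                  # stored as [1, n], as quantize_4bit_impl now returns

# Shape-based branch mirroring the new prologue of dequantize_4bit_impl.
A = stored
if A.shape[0] == 1:
    transpose, A = False, A.squeeze(0)
elif A.shape[1] == 1:
    transpose, A = True, A.squeeze(1)

out = A.to(torch.float32).reshape(2, 4)       # stand-in for the real dequantized weight (out_features=2, in_features=4)
if transpose:
    out = out.t()

x = torch.randn(3, 4)                         # toy activations
y = x @ out.t().to(x.dtype)                   # gemm fallback now transposes explicitly, like dequantize_4bit_impl(...).t()
print(y.shape)                                # torch.Size([3, 2])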
1 change: 1 addition & 0 deletions bitsandbytes/cextension.py
@@ -125,6 +125,7 @@ def get_native_library() -> BNBNativeLibrary:
         HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
         BNB_HIP_VERSION_SHORT = ""
         BNB_BACKEND = "CUDA"
+
     lib = get_native_library()
 except Exception as e:
     lib = None
12 changes: 10 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -449,6 +449,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
         if getattr(self.weight, "quant_state", None) is not None:
             for k, v in self.weight.quant_state.as_dict(packed=True).items():
                 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
+            if getattr(self.weight.quant_state, "op_context", None) is not None:
+                context = self.weight.quant_state.op_context
+                destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
+                self.weight.data = context.to_public(context.get_weight()).reshape([1, -1])
 
     def forward(self, x: torch.Tensor):
         # weights are cast automatically as Int8Params, but the bias has to be cast manually
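The new branch in _save_to_state_dict handles weights whose quant_state carries an IPEX op_context: the scales are exported under the usual "weight.absmax" key, and the pre-packed weight is converted back to its public layout and reshaped to [1, -1] before serialization. A self-contained toy mirroring that pattern with a fake context object (only the three methods used above are modeled; the real op_context comes from IPEX):

import torch

class _FakeOpContext:
    # Illustrative stand-in for the IPEX op_context referenced above.
    def get_scales(self):
        return torch.ones(2, 4)
    def get_weight(self):
        return torch.zeros(4, 4, dtype=torch.uint8)
    def to_public(self, w):
        return w  # the real IPEX call converts the blocked layout back to plain rows

destination, prefix = {}, "linear."
context = _FakeOpContext()
destination[prefix + "weight." + "absmax"] = context.get_scales().reshape(-1)
weight_data = context.to_public(context.get_weight()).reshape([1, -1])
print(sorted(destination), weight_data.shape)   # ['linear.weight.absmax'] torch.Size([1, 16])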
@@ -639,8 +643,12 @@ def to(self, *args, **kwargs):
 
         if device.type == "cuda" and self.data.device.type == "cpu":
             return self.cuda(device)
-        elif device.type == "cpu" and self.data.dtype != torch.int8:
-            return self.cpu()
+        elif device.type == "cpu":
+            if self.data.dtype == torch.int8:
+                self.CB = self.data
+                return self
+            else:
+                return self.cpu()
         else:
             new_param = Int8Params(
                 super().to(device=device, dtype=dtype, non_blocking=non_blocking),
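Int8Params.to("cpu") previously fell through to the generic copy path whenever the data was already int8; the new branch keeps the quantized tensor as-is and records it on CB so the CPU backend can pick it up. A hypothetical usage sketch, assuming an Int8Params that already holds int8 data (e.g. loaded from a pre-quantized checkpoint); the constructor call is an assumption, not part of this diff:

import torch
import bitsandbytes as bnb

w = torch.randint(-128, 127, (4, 4), dtype=torch.int8)
p = bnb.nn.Int8Params(w, requires_grad=False)   # assumed way to wrap pre-quantized int8 data

p_cpu = p.to("cpu")
# Per the new branch above, this should return the same parameter object,
# with its int8 data also recorded on CB.
print(p_cpu is p, p_cpu.CB.dtype)               # True torch.int8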
