reran black with linelength 80 for greater readability
Titus-von-Koeller committed Aug 1, 2022
1 parent 3fd06fb commit ea7c14f
Showing 17 changed files with 665 additions and 203 deletions.
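The commit message says the reformatting is a re-run of the Black formatter at line length 80 (Black's default is 88). The exact command is not recorded in the commit; an invocation along the lines of "black --line-length 80 bitsandbytes" produces diffs like the ones below. Purely for illustration, the same kind of rewrap can be reproduced from Python through Black's (unstable) API; the snippet below is a sketch under that assumption, not part of the commit.

# Sketch only (not part of the commit): reproduces the kind of rewrap this
# commit applies. Black's Python API carries no stability promise; the usual
# route is the CLI, e.g. black --line-length 80 bitsandbytes (assumed command).
import black

SRC = (
    "from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,\n"
    "                                  matmul_cublas, mm_cublas)\n"
)

# line_length=80 matches the commit message; Black's default is 88.
print(black.format_str(SRC, mode=black.Mode(line_length=80)))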
9 changes: 7 additions & 2 deletions bitsandbytes/__init__.py
@@ -3,8 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .autograd._functions import (MatmulLtState, bmm_cublas, matmul,
matmul_cublas, mm_cublas)
from .autograd._functions import (
MatmulLtState,
bmm_cublas,
matmul,
matmul_cublas,
mm_cublas,
)
from .cextension import COMPILED_WITH_CUDA
from .nn import modules

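The names re-exported above (matmul, MatmulLtState, and the *_cublas helpers) are the functional entry points into the 8-bit matmul code reformatted in the next file, autograd/_functions.py. A minimal usage sketch follows, assuming a CUDA build of bitsandbytes and an fp16-capable GPU; shapes follow the nn.Linear convention in which B holds (out_features, in_features).

# Usage sketch, not part of this commit. Assumes a CUDA-enabled bitsandbytes
# install and an fp16-capable GPU; bnb.matmul treats B like an nn.Linear
# weight of shape (out_features, in_features).
import torch
import bitsandbytes as bnb

A = torch.randn(16, 64, dtype=torch.float16, device="cuda")
B = torch.randn(32, 64, dtype=torch.float16, device="cuda")

state = bnb.MatmulLtState()    # carries quantized buffers across calls

# threshold > 0.0 switches on the mixed-precision outlier path shown below
out = bnb.matmul(A, B, state=state, threshold=6.0)
print(out.shape)               # expected: torch.Size([16, 32])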
45 changes: 36 additions & 9 deletions bitsandbytes/autograd/_functions.py
@@ -111,7 +111,9 @@ def backward(ctx, grad_output):
qgrad_output, S1 = F.vectorwise_quant(
grad_output, dim=dims, quant_type=quant_type
)
qA, S2 = F.vectorwise_quant(A, dim=dims, quant_type=quant_type)
qA, S2 = F.vectorwise_quant(
A, dim=dims, quant_type=quant_type
)
igrad_B = F.igemm(qA.permute(permute_dim), qgrad_output)
grad_B = F.vectorwise_mm_dequant(
igrad_B,
@@ -146,7 +148,11 @@ def backward(ctx, grad_output):
qB, S3 = F.vectorwise_quant(B, dim=dim_B, quant_type=quant_type)
igrad_A = F.igemm(qgrad_output, qB.permute(permute_dim))
grad_A = F.vectorwise_mm_dequant(
igrad_A, S1, S3.permute(permute_dim), grad_output.dtype, quant_type
igrad_A,
S1,
S3.permute(permute_dim),
grad_output.dtype,
quant_type,
)

return grad_A, grad_B, None, None, None
@@ -211,7 +217,9 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
# 1. Quantize A
if len(A.shape) == 3:
A = A.view(-1, A.shape[-1]).contiguous()
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(A, threshold=state.threshold)
CA, CAt, SCA, SCAt, coo_tensorA = F.double_quant(
A, threshold=state.threshold
)

if state.threshold > 0.0 and coo_tensorA is not None:
if state.has_fp16_weights:
@@ -225,7 +233,9 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
if state.CxB is None:
# B in in 8-bit row-major, we can transform it back to 16-bit to extract outlier dimensions
# we also need to convert it to the turing/ampere format
state.CxB, state.SB = F.transform(state.CB, to_order=formatB)
state.CxB, state.SB = F.transform(
state.CB, to_order=formatB
)
# state.B = (state.CB.float()*(state.SCB.view(-1, 1)/127)).half()
# if state.threshold > 0.0 and coo_tensorA is not None and state.idx is None and state.CB is not None:
# # generate outlier index and subB
@@ -259,7 +269,13 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):

if (state.is_training and not has_grad) or state.CxB is None:
state.reset_grads()
CB, state.CBt, state.SCB, state.SCBt, coo_tensorB = F.double_quant(B)
(
CB,
state.CBt,
state.SCB,
state.SCBt,
coo_tensorB,
) = F.double_quant(B)
state.CxB, state.SB = F.transform(CB, to_order=formatB)
else:
has_grad = False
@@ -277,7 +293,10 @@ def forward(ctx, A, B, out=None, state=MatmulLtState()):
# state.idx = outlier_idx
outliers = F.extract_outliers(state.CxB, state.SB, state.idx.int())
state.subB = (
(outliers * state.SCB.view(-1, 1) / 127.0).t().contiguous().half()
(outliers * state.SCB.view(-1, 1) / 127.0)
.t()
.contiguous()
.half()
)
CA[:, state.idx.long()] = 0
CAt[:, state.idx.long()] = 0
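For context on what the reformatted calls above are doing: F.double_quant produces an int8 version of A plus per-row scales, and the CA[:, state.idx.long()] = 0 lines blank out the outlier columns that are instead kept in fp16 via state.subB. The sketch below is a rough, self-contained illustration of that idea; it is not the library's actual double_quant or extract_outliers implementation.

# Illustration only: row-wise absmax int8 quantization with outlier columns
# held back in fp16, mirroring the idea behind the lines above (not the
# actual bitsandbytes kernels).
import torch

def rowwise_quant_with_outliers(A: torch.Tensor, threshold: float = 6.0):
    # columns containing any |value| above the threshold are treated as outliers
    outlier_cols = (A.abs() > threshold).any(dim=0).nonzero().flatten()

    # row-wise absmax scaling into int8
    scale = A.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    CA = torch.round(A / scale * 127).to(torch.int8)

    # zero the outlier columns in the int8 tensor; they stay in fp16 instead
    CA[:, outlier_cols] = 0
    A_outliers = A[:, outlier_cols].clone()
    return CA, scale, outlier_cols, A_outliers

A = torch.randn(4, 8) * 2
CA, scale, idx, outliers = rowwise_quant_with_outliers(A, threshold=3.0)
print("outlier columns:", idx.tolist())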
@@ -325,10 +344,14 @@ def backward(ctx, grad_output):
SCAt, idx = ctx.tensor_states
formatB = ctx.formatB
state = ctx.state
assert state.has_fp16_weights, "Backprop only supported for fp16 weights."
assert (
state.has_fp16_weights
), "Backprop only supported for fp16 weights."

if len(grad_output.shape) == 3:
grad_output = grad_output.view(-1, grad_output.shape[-1]).contiguous()
grad_output = grad_output.view(
-1, grad_output.shape[-1]
).contiguous()

grad_A = grad_B = None

@@ -359,7 +382,11 @@ def backward(ctx, grad_output):


def matmul(
A: tensor, B: tensor, out: tensor = None, state: MatmulLtState = None, threshold=0.0
A: tensor,
B: tensor,
out: tensor = None,
state: MatmulLtState = None,
threshold=0.0,
):
state = state or MatmulLtState()
if threshold > 0.0:
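Most users reach the matmul/MatmulLtState machinery reformatted above through bitsandbytes.nn rather than by calling matmul directly. A hedged sketch of that path follows; the Linear8bitLt keyword names (has_fp16_weights, threshold) come from the project's documentation of this period rather than from this diff, so treat them as assumptions.

# Sketch of the usual module-level entry point to the code above. The keyword
# names has_fp16_weights/threshold are assumptions based on the project docs
# of this era, not shown in this diff. Requires a CUDA GPU.
import torch
import bitsandbytes as bnb

layer = bnb.nn.Linear8bitLt(
    64,                      # in_features
    32,                      # out_features
    bias=True,
    has_fp16_weights=False,  # store weights in int8 for inference
    threshold=6.0,           # route outlier features through fp16
).cuda()                     # .cuda() triggers the int8 weight conversion

x = torch.randn(16, 64, dtype=torch.float16, device="cuda")
with torch.no_grad():
    y = layer(x)
print(y.shape)               # expected: torch.Size([16, 32])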
45 changes: 38 additions & 7 deletions bitsandbytes/cuda_setup.py
@@ -1,7 +1,7 @@
"""
build is dependent on
- compute capability
- dependent on GPU family
extract factors the build is dependent on:
[X] compute capability
[ ] TODO: Q - What if we have multiple GPUs of different makes?
- CUDA version
- Software:
- CPU-only: only CPU quantization functions (no optimizer, no matrix multipl)
@@ -19,17 +19,40 @@
"""

import ctypes
import shlex
import subprocess
from os import environ as env
from pathlib import Path
from typing import Set, Union

from .utils import print_err, warn_of_missing_prerequisite


def execute_and_return(command_string: str) -> Tuple[str, str]:
def _decode(subprocess_err_out_tuple):
return tuple(
to_decode.decode("UTF-8").strip()
for to_decode in subprocess_err_out_tuple
)

def execute_and_return_decoded_std_streams(command_string):
return _decode(
subprocess.Popen(
shlex.split(command_string),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
).communicate()
)

std_out, std_err = execute_and_return_decoded_std_streams(command_string)
return std_out, std_err


def check_cuda_result(cuda, result_val):
if result_val != 0:
# TODO: undefined name 'error_str'
cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
print(f"Count not initialize CUDA - failure!")
print("Count not initialize CUDA - failure!")
raise Exception("CUDA exception!")
return result_val
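
The TODO in check_cuda_result above notes that error_str is undefined in that scope (a c_char_p of that name is only created further down, inside get_compute_capability). The snippet below is a standalone sketch of how the driver-API error string could be fetched locally with ctypes; it is an illustration, not part of the commit.

# Illustration of the pattern the TODO above points at: cuGetErrorString needs
# a local char* to write into. Not part of this commit.
import ctypes

def describe_cuda_error(cuda, result_val: int) -> str:
    error_str = ctypes.c_char_p()
    # CUresult cuGetErrorString(CUresult error, const char **pStr)
    cuda.cuGetErrorString(result_val, ctypes.byref(error_str))
    return error_str.value.decode() if error_str.value else "unknown CUDA error"

# usage sketch: cuda = ctypes.CDLL("libcuda.so"); print(describe_cuda_error(cuda, 100))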

Expand All @@ -53,15 +76,19 @@ def get_compute_capability():

result = ctypes.c_int()
device = ctypes.c_int()
# TODO: local variable 'context' is assigned to but never used
context = ctypes.c_void_p()
# TODO: local variable 'error_str' is assigned to but never used
error_str = ctypes.c_char_p()

result = check_cuda_result(cuda, cuda.cuInit(0))

result = check_cuda_result(cuda, cuda.cuDeviceGetCount(ctypes.byref(nGpus)))
ccs = []
for i in range(nGpus.value):
result = check_cuda_result(cuda, cuda.cuDeviceGet(ctypes.byref(device), i))
result = check_cuda_result(
cuda, cuda.cuDeviceGet(ctypes.byref(device), i)
)
result = check_cuda_result(
cuda,
cuda.cuDeviceComputeCapability(
@@ -114,11 +141,15 @@ def get_cuda_runtime_lib_path(
} - non_existent_directories

if len(cuda_runtime_libs) > 1:
err_msg = f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
err_msg = (
f"Found duplicate {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
)
raise FileNotFoundError(err_msg)

elif len(cuda_runtime_libs) < 1:
err_msg = f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
err_msg = (
f"Did not find {CUDA_RUNTIME_LIB} files: {cuda_runtime_libs}.."
)
raise FileNotFoundError(err_msg)

single_cuda_runtime_lib_dir = next(iter(cuda_runtime_libs))
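The cuda_setup helpers above are normally driven during import, but they can also be exercised directly. A hedged sketch follows; it assumes an NVIDIA driver and a working CUDA build, and it does not call get_cuda_runtime_lib_path because that function's parameters are truncated in this diff.

# Sketch of calling the cuda_setup helpers from the diff above directly.
# Assumes an NVIDIA driver/GPU and a working CUDA build; the exact return
# value of get_compute_capability is not visible here, so it is printed as-is.
from bitsandbytes.cuda_setup import execute_and_return, get_compute_capability

std_out, std_err = execute_and_return(
    "nvidia-smi --query-gpu=name --format=csv,noheader"
)
print(std_out or std_err)

print("compute capability:", get_compute_capability())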
