Mark some functions for deprecation.
matthewdouglas committed Oct 24, 2024
1 parent 01bf54e commit 32979b4
Showing 4 changed files with 53 additions and 18 deletions.
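
The change applied across these files is the deprecated decorator from typing_extensions (PEP 702): wrapping a callable makes every use emit the given warning category. A minimal sketch of the pattern, assuming a typing_extensions version that provides deprecated and using a hypothetical helper name:

import warnings

from typing_extensions import deprecated


@deprecated("old_helper is deprecated and will be removed in a future release.", category=FutureWarning)
def old_helper(x: int) -> int:
    return x * 2


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_helper(3)

# The decorated call should have emitted a FutureWarning.
assert any(issubclass(w.category, FutureWarning) for w in caught)
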
16 changes: 8 additions & 8 deletions bitsandbytes/autograd/_functions.py
@@ -5,6 +5,7 @@
from warnings import warn

import torch
from typing_extensions import deprecated

import bitsandbytes.functional as F

@@ -97,6 +98,10 @@ def undo_layout(permuted_tensor: torch.Tensor, tile_indices: torch.LongTensor) -
return outputs.reshape(rows, cols).contiguous()


@deprecated(
"MatMul8bit is deprecated and will be removed in a future release. Please use MatMul8bitLt instead.",
category=FutureWarning,
)
class MatMul8bit(torch.autograd.Function):
@staticmethod
def forward(ctx, A, B, out=None, quant_type="vector", precision=None):
@@ -208,6 +213,7 @@ def backward(ctx, grad_output):
matmul_cublas = MatMul8bit.apply


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def supports_igemmlt(device: torch.device) -> bool:
"""check if this device supports the optimized int8 kernel"""
if torch.cuda.get_device_capability(device=device) < (7, 5):
@@ -219,6 +225,7 @@ def supports_igemmlt(device: torch.device) -> bool:
return True


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _get_tile_size(format):
assert format in (
"col_turing",
@@ -227,6 +234,7 @@ def _get_tile_size(format):
return (8, 32) if format == "col_turing" else (32, 32)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def get_tile_inds(format, device):
transform = lambda x: F.transform(x.to(device), from_order="row", to_order=format)[0].to(x.device)
with torch.no_grad():
@@ -331,14 +339,6 @@ def forward(
# 2. Quantize B
state.CB, state.SCB, _ = F.int8_vectorwise_quant(B.to(torch.float16))

# (
# state.CB,
# state.CBt,
# state.SCB,
# state.SCBt,
# _,
# ) = F.double_quant(B.to(torch.float16))

if state.threshold > 0.0 and coo_tensorA is not None:
state.idx = torch.unique(coo_tensorA._indices()[1]).long()

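
For downstream code that still touches MatMul8bit, supports_igemmlt, _get_tile_size, or get_tile_inds, the decorators above surface as a FutureWarning: at call time for the functions, and at instantiation or subclassing for the MatMul8bit class. A hypothetical way to keep logs quiet while migrating (not project guidance; the message prefixes simply mirror the strings added in this file):

import warnings

# Suppress only the deprecations introduced here while callers migrate away
# from the old 8-bit matmul path toward MatMul8bitLt.
for msg in ("MatMul8bit is deprecated", "This function is deprecated"):
    warnings.filterwarnings("ignore", message=msg, category=FutureWarning)
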
2 changes: 1 addition & 1 deletion bitsandbytes/cextension.py
@@ -113,6 +113,6 @@ def get_native_library() -> BNBNativeLibrary:
Inspect the output of the command and see if you can locate CUDA libraries. You might need to add them
to your LD_LIBRARY_PATH. If you suspect a bug, please take the information from python -m bitsandbytes
and open an issue at: https://github.com/TimDettmers/bitsandbytes/issues
and open an issue at: https://github.com/bitsandbytes-foundation/bitsandbytes/issues
""",
)
49 changes: 42 additions & 7 deletions bitsandbytes/functional.py
@@ -10,6 +10,7 @@
import numpy as np
import torch
from torch import Tensor
from typing_extensions import deprecated

from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

@@ -244,10 +245,12 @@ def fill(A, value, device=None, prefetch=True):
elementwise_func("fill", A, None, value)


@deprecated("Function will be removed in a future release.", category=FutureWarning)
def arange(A, device=None):
elementwise_func("arange", A, None, 0)


@deprecated("Function will be removed in a future release.", category=FutureWarning)
def _mul(A, B, device=None):
elementwise_func("_mul", A, B, 0)

@@ -414,7 +417,7 @@ def create_quantile_map(A, total_bits=8):
return q


# TODO: Deprecate
@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
def get_special_format_str():
return "row"

@@ -475,6 +478,10 @@ def post_call(prev_device):
torch.cuda.set_device(prev_device)


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_func(dtype, orderA, orderOut, transpose=False):
name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
if not hasattr(lib, name):
@@ -486,6 +493,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
return getattr(lib, name)


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
# init_func = torch.empty
init_func = torch.zeros
@@ -525,6 +536,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans
raise NotImplementedError(f"To_order not supported: {to_order}")


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def nvidia_transform(
A,
to_order,
@@ -1424,6 +1439,7 @@ def dequantize_4bit(
return out


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize(
A: Tensor,
code: Optional[torch.Tensor] = None,
@@ -1443,6 +1459,7 @@ def quantize(
return out, (absmax, code)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize(
A: Tensor,
state: Optional[Tuple[Tensor, Tensor]] = None,
@@ -1463,6 +1480,7 @@ def dequantize(
return out * state[0]


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
"""
Quantizes input tensor to 8-bit.
@@ -1493,6 +1511,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
return out


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
"""
Dequantizes the 8-bit tensor to 32-bit.
@@ -1627,6 +1646,11 @@ def optimizer_update_32bit(
post_call(prev_device)


@deprecated(
"This function is deprecated and will be removed in a future release. "
"Please use optimizer_update_8bit_blockwise instead. ",
category=FutureWarning,
)
def optimizer_update_8bit(
optimizer_name: str,
g: Tensor,
@@ -1827,6 +1851,7 @@ def optimizer_update_8bit_blockwise(
post_call(prev_device)


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
"""Applies percentile clipping
@@ -2516,11 +2541,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


def extract_outliers_new(A: torch.Tensor, threshold: float):
outlier_mask = A.abs() >= threshold
return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()


def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
# TODO: Optimize/write CUDA kernel for this?
# Note: for inference, use the new int8_vectorwise_quant.
@@ -2576,6 +2596,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
return out_row, row_stats, coo_tensor


@deprecated(
"The layout transformation operations will be removed in a future release. Please use row-major layout only.",
category=FutureWarning,
)
def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
prev_device = pre_call(A.device)
if state is None:
@@ -2772,6 +2796,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0


@deprecated(
"This function is deprecated and will be removed in a future release. "
"Consider using `int8_vectorwise_quant` instead.",
category=FutureWarning,
)
def vectorwise_quant(x, dim=1, quant_type="vector"):
if quant_type == "linear":
max1 = torch.abs(x).max().float()
@@ -2816,6 +2845,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
return None


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def vectorwise_dequant(xq, max1, quant_type="vector"):
if quant_type == "vector":
x = (xq / C * max1).to(torch.float32)
@@ -2824,6 +2854,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
return None


@deprecated(
"This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.",
category=FutureWarning,
)
def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
if quant_type == "linear":
norm = S1 * S2 / (C * C)
@@ -2883,6 +2917,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
return None


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
offset = B.float().t().sum(0) * (SA[0] + SA[1])
x = xq.float()
@@ -2898,7 +2933,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):


def extract_outliers(A, SA, idx):
# TODO: Implement for row-major
shapeA = SA[0]
formatA = SA[1]
assert formatA in ["col_turing", "col_ampere"]
@@ -2923,6 +2957,7 @@ def extract_outliers(A, SA, idx):
return out


@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def pipeline_test(A, batch_size):
out = torch.zeros_like(A)
lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
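
Most of the functional.py deprecations steer callers away from the layout-transform and per-tensor quantization helpers toward the row-major int8 path. A rough sketch of the non-deprecated call referenced in the diff, with illustrative shape and threshold and assuming a CUDA device is available:

import torch

import bitsandbytes.functional as F

A = torch.randn(16, 32, dtype=torch.float16, device="cuda")

# Per the hunk above, int8_vectorwise_quant returns the quantized int8 rows,
# the per-row scale statistics, and an optional sparse outlier tensor.
CA, row_stats, coo = F.int8_vectorwise_quant(A, threshold=6.0)
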
4 changes: 2 additions & 2 deletions setup.py
@@ -31,10 +31,10 @@ def has_ext_modules(self):
description="k-bit optimizers and matrix multiplication routines.",
license="MIT",
keywords="gpu optimizers optimization 8-bit quantization compression",
url="https://github.com/TimDettmers/bitsandbytes",
url="https://github.com/bitsandbytes-foundation/bitsandbytes",
packages=find_packages(),
package_data={"": libs},
install_requires=["torch", "numpy"],
install_requires=["torch", "numpy", "typing_extensions>=4.8.0"],
extras_require={
"benchmark": ["pandas", "matplotlib"],
"test": ["scipy", "lion_pytorch"],
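
The new typing_extensions>=4.8.0 floor ensures the deprecated decorator used above is importable at runtime. A quick, hypothetical sanity check after upgrading an environment:

from importlib.metadata import version

import typing_extensions

# The decorator relied on throughout this commit must be present.
assert hasattr(typing_extensions, "deprecated")
print(version("typing_extensions"))
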
