 import numpy as np
 import torch
 from torch import Tensor
+from typing_extensions import deprecated

 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict

@@ -244,10 +245,12 @@ def fill(A, value, device=None, prefetch=True):
     elementwise_func("fill", A, None, value)


+@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def arange(A, device=None):
     elementwise_func("arange", A, None, 0)


+@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def _mul(A, B, device=None):
     elementwise_func("_mul", A, B, 0)

@@ -414,7 +417,7 @@ def create_quantile_map(A, total_bits=8):
     return q


-# TODO: Deprecate
+@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
 def get_special_format_str():
     return "row"

@@ -475,6 +478,10 @@ def post_call(prev_device):
     torch.cuda.set_device(prev_device)


+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def get_transform_func(dtype, orderA, orderOut, transpose=False):
     name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
     if not hasattr(lib, name):
@@ -486,6 +493,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
     return getattr(lib, name)


+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
     # init_func = torch.empty
     init_func = torch.zeros
@@ -525,6 +536,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans
         raise NotImplementedError(f"To_order not supported: {to_order}")


+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def nvidia_transform(
     A,
     to_order,
@@ -1424,6 +1439,7 @@ def dequantize_4bit(
     return out


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def quantize(
     A: Tensor,
     code: Optional[torch.Tensor] = None,
@@ -1443,6 +1459,7 @@ def quantize(
     return out, (absmax, code)


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequantize(
     A: Tensor,
     state: Optional[Tuple[Tensor, Tensor]] = None,
@@ -1463,6 +1480,7 @@ def dequantize(
     return out * state[0]


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     """
     Quantizes input tensor to 8-bit.
@@ -1493,6 +1511,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
     return out


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     """
     Dequantizes the 8-bit tensor to 32-bit.
@@ -1627,6 +1646,11 @@ def optimizer_update_32bit(
     post_call(prev_device)


+@deprecated(
+    "This function is deprecated and will be removed in a future release. "
+    "Please use optimizer_update_8bit_blockwise instead. ",
+    category=FutureWarning,
+)
 def optimizer_update_8bit(
     optimizer_name: str,
     g: Tensor,
@@ -1827,6 +1851,7 @@ def optimizer_update_8bit_blockwise(
     post_call(prev_device)


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
     """Applies percentile clipping

@@ -2516,11 +2541,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)


-def extract_outliers_new(A: torch.Tensor, threshold: float):
-    outlier_mask = A.abs() >= threshold
-    return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
-
-
 def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     # TODO: Optimize/write CUDA kernel for this?
     # Note: for inference, use the new int8_vectorwise_quant.
@@ -2576,6 +2596,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     return out_row, row_stats, coo_tensor


+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
     prev_device = pre_call(A.device)
     if state is None:
@@ -2772,6 +2796,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
 C = 127.0


+@deprecated(
+    "This function is deprecated and will be removed in a future release. "
+    "Consider using `int8_vectorwise_quant` instead.",
+    category=FutureWarning,
+)
 def vectorwise_quant(x, dim=1, quant_type="vector"):
     if quant_type == "linear":
         max1 = torch.abs(x).max().float()
@@ -2816,6 +2845,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
         return None


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def vectorwise_dequant(xq, max1, quant_type="vector"):
     if quant_type == "vector":
         x = (xq / C * max1).to(torch.float32)
@@ -2824,6 +2854,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
         return None


+@deprecated(
+    "This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.",
+    category=FutureWarning,
+)
 def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
     if quant_type == "linear":
         norm = S1 * S2 / (C * C)
@@ -2883,6 +2917,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
         return None


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
     offset = B.float().t().sum(0) * (SA[0] + SA[1])
     x = xq.float()
@@ -2898,7 +2933,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):


 def extract_outliers(A, SA, idx):
-    # TODO: Implement for row-major
     shapeA = SA[0]
     formatA = SA[1]
     assert formatA in ["col_turing", "col_ampere"]
@@ -2923,6 +2957,7 @@ def extract_outliers(A, SA, idx):
     return out


+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def pipeline_test(A, batch_size):
     out = torch.zeros_like(A)
     lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))
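
Note: the decorators added throughout this diff rely on the runtime behavior of typing_extensions.deprecated (PEP 702), which emits a warning of the configured category each time a decorated callable is invoked. Below is a minimal sketch of that behavior, not part of this change; old_api is a hypothetical name, and typing_extensions >= 4.5 is assumed.

import warnings

from typing_extensions import deprecated  # runtime warning support requires typing_extensions >= 4.5


@deprecated("old_api will be removed in a future release.", category=FutureWarning)
def old_api(x):
    return x


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_api(1)  # the call itself emits the warning
    assert issubclass(caught[-1].category, FutureWarning)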