1010import  numpy  as  np 
1111import  torch 
1212from  torch  import  Tensor 
13+ from  typing_extensions  import  deprecated 
1314
1415from  bitsandbytes .utils  import  pack_dict_to_tensor , unpack_tensor_to_dict 
1516
@@ -244,10 +245,12 @@ def fill(A, value, device=None, prefetch=True):
244245    elementwise_func ("fill" , A , None , value )
245246
246247
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def arange(A, device=None):
    """Overwrite tensor ``A`` in place via the ``"arange"`` elementwise kernel.

    Presumably fills ``A`` with the index sequence 0..n-1 (like ``torch.arange``);
    the kernel itself lives in the native library — confirm against ``elementwise_func``.

    Args:
        A: tensor to fill in place.
        device: unused; kept for backward compatibility with older callers.
    """
    # NOTE(review): deprecation message unified with the other @deprecated
    # markers in this file ("This function is deprecated and will be removed
    # in a future release.") for consistency.
    elementwise_func("arange", A, None, 0)
249251
250252
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
def _mul(A, B, device=None):
    """Apply the ``"_mul"`` elementwise kernel to tensors ``A`` and ``B``.

    Presumably performs an in-place elementwise multiply of ``A`` by ``B``;
    the kernel implementation is in the native library — confirm against
    ``elementwise_func``.

    Args:
        A: left operand, modified in place.
        B: right operand.
        device: unused; kept for backward compatibility with older callers.
    """
    # NOTE(review): deprecation message unified with the other @deprecated
    # markers in this file for consistency.
    elementwise_func("_mul", A, B, 0)
253256
@@ -414,7 +417,7 @@ def create_quantile_map(A, total_bits=8):
414417    return  q 
415418
416419
417- # TODO: Deprecate 
420+ @ deprecated ( "This function is deprecated and will be removed in a future version." ,  category = FutureWarning ) 
def get_special_format_str():
    """Report the tensor storage layout identifier.

    Returns:
        str: always the literal ``"row"`` — row-major is the only layout
        still supported (the layout-transform helpers in this file are
        deprecated in favor of row-major only).
    """
    layout = "row"
    return layout
420423
@@ -475,6 +478,10 @@ def post_call(prev_device):
475478    torch .cuda .set_device (prev_device )
476479
477480
481+ @deprecated ( 
482+     "The layout transformation operations will be removed in a future release. Please use row-major layout only." , 
483+     category = FutureWarning , 
484+ ) 
478485def  get_transform_func (dtype , orderA , orderOut , transpose = False ):
479486    name  =  f'ctransform_{ (8  if  dtype  ==  torch .int8  else  32 )} { orderA } { orderOut } { "t"  if  transpose  else  "n" }  
480487    if  not  hasattr (lib , name ):
@@ -486,6 +493,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
486493        return  getattr (lib , name )
487494
488495
496+ @deprecated ( 
497+     "The layout transformation operations will be removed in a future release. Please use row-major layout only." , 
498+     category = FutureWarning , 
499+ ) 
489500def  get_transform_buffer (shape , dtype , device , to_order , from_order = "row" , transpose = False ):
490501    # init_func = torch.empty 
491502    init_func  =  torch .zeros 
@@ -525,6 +536,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans
525536        raise  NotImplementedError (f"To_order not supported: { to_order }  )
526537
527538
539+ @deprecated ( 
540+     "The layout transformation operations will be removed in a future release. Please use row-major layout only." , 
541+     category = FutureWarning , 
542+ ) 
528543def  nvidia_transform (
529544    A ,
530545    to_order ,
@@ -1424,6 +1439,7 @@ def dequantize_4bit(
14241439        return  out 
14251440
14261441
1442+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
14271443def  quantize (
14281444    A : Tensor ,
14291445    code : Optional [torch .Tensor ] =  None ,
@@ -1443,6 +1459,7 @@ def quantize(
14431459    return  out , (absmax , code )
14441460
14451461
1462+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
14461463def  dequantize (
14471464    A : Tensor ,
14481465    state : Optional [Tuple [Tensor , Tensor ]] =  None ,
@@ -1463,6 +1480,7 @@ def dequantize(
14631480    return  out  *  state [0 ]
14641481
14651482
1483+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
14661484def  quantize_no_absmax (A : Tensor , code : Tensor , out : Optional [torch .Tensor ] =  None ) ->  Tensor :
14671485    """ 
14681486    Quantizes input tensor to 8-bit. 
@@ -1493,6 +1511,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
14931511    return  out 
14941512
14951513
1514+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
14961515def  dequantize_no_absmax (A : Tensor , code : Tensor , out : Optional [torch .Tensor ] =  None ) ->  Tensor :
14971516    """ 
14981517    Dequantizes the 8-bit tensor to 32-bit. 
@@ -1627,6 +1646,11 @@ def optimizer_update_32bit(
16271646    post_call (prev_device )
16281647
16291648
1649+ @deprecated ( 
1650+     "This function is deprecated and will be removed in a future release. "  
1651+     "Please use optimizer_update_8bit_blockwise instead. " , 
1652+     category = FutureWarning , 
1653+ ) 
16301654def  optimizer_update_8bit (
16311655    optimizer_name : str ,
16321656    g : Tensor ,
@@ -1827,6 +1851,7 @@ def optimizer_update_8bit_blockwise(
18271851    post_call (prev_device )
18281852
18291853
1854+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
18301855def  percentile_clipping (grad : Tensor , gnorm_vec : Tensor , step : int , percentile : int  =  5 ):
18311856    """Applies percentile clipping 
18321857
@@ -2516,11 +2541,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
25162541    return  COOSparseTensor (rows , cols , nnz , rowidx , colidx , values )
25172542
25182543
2519- def  extract_outliers_new (A : torch .Tensor , threshold : float ):
2520-     outlier_mask  =  A .abs () >=  threshold 
2521-     return  A .masked_fill (outlier_mask  ==  False , 0.0 ).to_sparse_coo ()
2522- 
2523- 
25242544def  double_quant (A : torch .Tensor , col_stats = None , row_stats = None , out_col = None , out_row = None , threshold = 0.0 ):
25252545    # TODO: Optimize/write CUDA kernel for this? 
25262546    # Note: for inference, use the new int8_vectorwise_quant. 
@@ -2576,6 +2596,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
25762596    return  out_row , row_stats , coo_tensor 
25772597
25782598
2599+ @deprecated ( 
2600+     "The layout transformation operations will be removed in a future release. Please use row-major layout only." , 
2601+     category = FutureWarning , 
2602+ ) 
25792603def  transform (A , to_order , from_order = "row" , out = None , transpose = False , state = None , ld = None ):
25802604    prev_device  =  pre_call (A .device )
25812605    if  state  is  None :
@@ -2772,6 +2796,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
27722796C  =  127.0 
27732797
27742798
2799+ @deprecated ( 
2800+     "This function is deprecated and will be removed in a future release. "  
2801+     "Consider using `int8_vectorwise_quant` instead." , 
2802+     category = FutureWarning , 
2803+ ) 
27752804def  vectorwise_quant (x , dim = 1 , quant_type = "vector" ):
27762805    if  quant_type  ==  "linear" :
27772806        max1  =  torch .abs (x ).max ().float ()
@@ -2816,6 +2845,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
28162845        return  None 
28172846
28182847
2848+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
28192849def  vectorwise_dequant (xq , max1 , quant_type = "vector" ):
28202850    if  quant_type  ==  "vector" :
28212851        x  =  (xq  /  C  *  max1 ).to (torch .float32 )
@@ -2824,6 +2854,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
28242854        return  None 
28252855
28262856
2857+ @deprecated ( 
2858+     "This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead." , 
2859+     category = FutureWarning , 
2860+ ) 
28272861def  vectorwise_mm_dequant (xq , S1 , S2 , dtype = torch .half , quant_type = "vector" ):
28282862    if  quant_type  ==  "linear" :
28292863        norm  =  S1  *  S2  /  (C  *  C )
@@ -2883,6 +2917,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
28832917        return  None 
28842918
28852919
2920+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
28862921def  dequant_min_max (xq , A , B , SA , SB , dtype = torch .half ):
28872922    offset  =  B .float ().t ().sum (0 ) *  (SA [0 ] +  SA [1 ])
28882923    x  =  xq .float ()
@@ -2898,7 +2933,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
28982933
28992934
29002935def  extract_outliers (A , SA , idx ):
2901-     # TODO: Implement for row-major 
29022936    shapeA  =  SA [0 ]
29032937    formatA  =  SA [1 ]
29042938    assert  formatA  in  ["col_turing" , "col_ampere" ]
@@ -2923,6 +2957,7 @@ def extract_outliers(A, SA, idx):
29232957    return  out 
29242958
29252959
2960+ @deprecated ("This function is deprecated and will be removed in a future release." , category = FutureWarning ) 
29262961def  pipeline_test (A , batch_size ):
29272962    out  =  torch .zeros_like (A )
29282963    lib .cpipeline_test (get_ptr (A ), get_ptr (out ), ct .c_size_t (A .numel ()), ct .c_size_t (batch_size ))
0 commit comments