@@ -10,6 +10,7 @@
 import numpy as np
 import torch
 from torch import Tensor
+from typing_extensions import deprecated
 
 from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
 
@@ -244,10 +245,12 @@ def fill(A, value, device=None, prefetch=True):
     elementwise_func("fill", A, None, value)
 
 
+@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def arange(A, device=None):
     elementwise_func("arange", A, None, 0)
 
 
+@deprecated("Function will be removed in a future release.", category=FutureWarning)
 def _mul(A, B, device=None):
     elementwise_func("_mul", A, B, 0)
 
@@ -414,7 +417,7 @@ def create_quantile_map(A, total_bits=8):
     return q
 
 
-# TODO: Deprecate
+@deprecated("This function is deprecated and will be removed in a future version.", category=FutureWarning)
 def get_special_format_str():
     return "row"
 
@@ -475,6 +478,10 @@ def post_call(prev_device):
     torch.cuda.set_device(prev_device)
 
 
+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def get_transform_func(dtype, orderA, orderOut, transpose=False):
     name = f'ctransform_{(8 if dtype == torch.int8 else 32)}_{orderA}_to_{orderOut}_{"t" if transpose else "n"}'
     if not hasattr(lib, name):
@@ -486,6 +493,10 @@ def get_transform_func(dtype, orderA, orderOut, transpose=False):
     return getattr(lib, name)
 
 
+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def get_transform_buffer(shape, dtype, device, to_order, from_order="row", transpose=False):
     # init_func = torch.empty
     init_func = torch.zeros
@@ -525,6 +536,10 @@ def get_transform_buffer(shape, dtype, device, to_order, from_order="row", trans
         raise NotImplementedError(f"To_order not supported: {to_order}")
 
 
+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def nvidia_transform(
     A,
     to_order,
@@ -1424,6 +1439,7 @@ def dequantize_4bit(
     return out
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def quantize(
     A: Tensor,
     code: Optional[torch.Tensor] = None,
@@ -1443,6 +1459,7 @@ def quantize(
     return out, (absmax, code)
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequantize(
     A: Tensor,
     state: Optional[Tuple[Tensor, Tensor]] = None,
@@ -1463,6 +1480,7 @@ def dequantize(
     return out * state[0]
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     """
     Quantizes input tensor to 8-bit.
@@ -1493,6 +1511,7 @@ def quantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = No
     return out
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequantize_no_absmax(A: Tensor, code: Tensor, out: Optional[torch.Tensor] = None) -> Tensor:
     """
     Dequantizes the 8-bit tensor to 32-bit.
@@ -1627,6 +1646,11 @@ def optimizer_update_32bit(
     post_call(prev_device)
 
 
+@deprecated(
+    "This function is deprecated and will be removed in a future release. "
+    "Please use optimizer_update_8bit_blockwise instead.",
+    category=FutureWarning,
+)
 def optimizer_update_8bit(
     optimizer_name: str,
     g: Tensor,
@@ -1827,6 +1851,7 @@ def optimizer_update_8bit_blockwise(
     post_call(prev_device)
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def percentile_clipping(grad: Tensor, gnorm_vec: Tensor, step: int, percentile: int = 5):
     """Applies percentile clipping
 
@@ -2516,11 +2541,6 @@ def coo_zeros(rows, cols, nnz, device, dtype=torch.half):
     return COOSparseTensor(rows, cols, nnz, rowidx, colidx, values)
 
 
-def extract_outliers_new(A: torch.Tensor, threshold: float):
-    outlier_mask = A.abs() >= threshold
-    return A.masked_fill(outlier_mask == False, 0.0).to_sparse_coo()
-
-
 def double_quant(A: torch.Tensor, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0):
     # TODO: Optimize/write CUDA kernel for this?
     # Note: for inference, use the new int8_vectorwise_quant.
@@ -2576,6 +2596,10 @@ def int8_vectorwise_quant(A: torch.Tensor, threshold=0.0):
     return out_row, row_stats, coo_tensor
 
 
+@deprecated(
+    "The layout transformation operations will be removed in a future release. Please use row-major layout only.",
+    category=FutureWarning,
+)
 def transform(A, to_order, from_order="row", out=None, transpose=False, state=None, ld=None):
     prev_device = pre_call(A.device)
     if state is None:
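The deprecation messages in this commit steer callers toward the newer int8 routines. A minimal usage sketch, assuming only what this diff shows: the `int8_vectorwise_quant(A: torch.Tensor, threshold=0.0)` signature and its `out_row, row_stats, coo_tensor` return. The fp16/CUDA requirements below are assumptions, not confirmed by the diff.

import torch

from bitsandbytes.functional import int8_vectorwise_quant

A = torch.randn(16, 32, dtype=torch.float16, device="cuda")

# out_row: row-major int8 quantized tensor; row_stats: per-row scaling stats;
# coo_tensor: sparse tensor of outliers past `threshold` (assumed None when
# no outlier extraction is requested).
out_row, row_stats, coo_tensor = int8_vectorwise_quant(A, threshold=6.0)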
@@ -2772,6 +2796,11 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
 C = 127.0
 
 
+@deprecated(
+    "This function is deprecated and will be removed in a future release. "
+    "Consider using `int8_vectorwise_quant` instead.",
+    category=FutureWarning,
+)
 def vectorwise_quant(x, dim=1, quant_type="vector"):
     if quant_type == "linear":
         max1 = torch.abs(x).max().float()
@@ -2816,6 +2845,7 @@ def vectorwise_quant(x, dim=1, quant_type="vector"):
         return None
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def vectorwise_dequant(xq, max1, quant_type="vector"):
     if quant_type == "vector":
         x = (xq / C * max1).to(torch.float32)
@@ -2824,6 +2854,10 @@ def vectorwise_dequant(xq, max1, quant_type="vector"):
         return None
 
 
+@deprecated(
+    "This function is deprecated and will be removed in a future release. Consider using `mm_dequant` instead.",
+    category=FutureWarning,
+)
 def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
     if quant_type == "linear":
         norm = S1 * S2 / (C * C)
@@ -2883,6 +2917,7 @@ def vectorwise_mm_dequant(xq, S1, S2, dtype=torch.half, quant_type="vector"):
         return None
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
     offset = B.float().t().sum(0) * (SA[0] + SA[1])
     x = xq.float()
@@ -2898,7 +2933,6 @@ def dequant_min_max(xq, A, B, SA, SB, dtype=torch.half):
 
 
 def extract_outliers(A, SA, idx):
-    # TODO: Implement for row-major
     shapeA = SA[0]
     formatA = SA[1]
     assert formatA in ["col_turing", "col_ampere"]
@@ -2923,6 +2957,7 @@ def extract_outliers(A, SA, idx):
     return out
 
 
+@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
 def pipeline_test(A, batch_size):
     out = torch.zeros_like(A)
     lib.cpipeline_test(get_ptr(A), get_ptr(out), ct.c_size_t(A.numel()), ct.c_size_t(batch_size))