 from .modules import WeightOnlyLinear
 
 DEBUG = False
+accelerator = auto_detect_accelerator()
 
 
 # ================ device related ===================
@@ -542,8 +543,10 @@ def forward(layer, *args, **kwargs):
         if self.run_fn:
             if self.run_args:
                 self.run_fn(self.model, *self.run_args)
+                accelerator.mark_step()
             else:
                 self.run_fn(self.model)
+                accelerator.mark_step()
         else:
             for batch in tqdm(self.dataloader):
                 if not self.use_layer_wise:
@@ -663,6 +666,7 @@ def tmp(_, inp, out):
             for j in range(batch_num):
                 cache_keyword_batch = self.gather_single_batch_from_dict(self.cache_key_arguments, j)
                 cache_positional_batch = self.gather_single_batch_from_list(self.cache_positional_arguments, j)
+                accelerator.mark_step()
                 out = transformer_block(*cache_positional_batch, **cache_keyword_batch)
                 out = self.track_hidden_states(out)
             self.cache_key_arguments["batch_num"] = batch_num
@@ -682,6 +686,9 @@ def tmp(_, inp, out):
                     W = load_value(self.model, full_layer_name + ".weight", model_path)
                 else:
                     W = sub_layers[layer_name].weight.data.clone()
+                accelerator.mark_step()
+                if "hpu" in self.device:
+                    W = W.to("cpu")
                 scale, zp, Q = gptq_for_this_block[layer_name].fasterquant(
                     W,
                     blocksize=weight_config_this_layer["block_size"],
@@ -854,6 +861,8 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
             self.quantizer.find_params(W, weight=True)
 
         H = self.H
+        if "hpu" in self.device:
+            H = H.to("cpu")
         del self.H
         dead = torch.diag(H) == 0
         H[dead, dead] = 1
@@ -958,6 +967,10 @@ def fasterquant(self, W, blocksize=128, percdamp=0.01, groupsize=-1, act_order=F
             zero.append(self.quantizer.zero)
         scale = torch.cat(scale, dim=1)
         zero = torch.cat(zero, dim=1)
+        if "hpu" in self.device:
+            scale = scale.to(self.device)
+            zero = zero.to(self.device)
+            Q = Q.to(self.device)
         return scale, zero, Q
 
     def free(self):
@@ -973,25 +986,25 @@ def free(self):
 class Quantizer(nn.Module):
     def __init__(self, shape=1):
         super(Quantizer, self).__init__()
-        self.register_buffer("maxq", torch.tensor(0))
+        self.maxq = 0
         self.register_buffer("scale", torch.zeros(shape))
         self.register_buffer("zero", torch.zeros(shape))
 
     def configure(self, weight_config_this_layer, norm=2.4, grid=100, maxshrink=0.8, trits=False):
         for k, v in weight_config_this_layer.items():
             setattr(self, k, v)
-        self.maxq = torch.tensor(2**self.bits - 1)
+        # self.maxq = torch.tensor(2**self.bits - 1)
+        self.maxq = 2**self.bits - 1
         self.scheme = "sym" if self.sym else "asym"
         self.double_quant_scheme = "sym" if self.double_quant_sym else "asym"
         self.norm = norm
         self.grid = grid
         self.maxshrink = maxshrink
         if trits:
-            self.maxq = torch.tensor(-1)
+            self.maxq = -1
 
     def find_params(self, x, weight=False):
         dev = x.device
-        self.maxq = self.maxq.to(dev)
         # NF4 FP4
         if self.dtype != "int":
             from .utility import quant_tensor
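
Note on the pattern the hunks above apply (not part of the patch itself): when the model sits on an HPU, the diff synchronizes the lazy-execution graph with accelerator.mark_step() before host/device copies, runs the GPTQ math (fasterquant) on CPU tensors, and moves scale/zero/Q back to the original device afterwards. A minimal sketch of that flow is below; the helper name and keyword-argument passing are hypothetical stand-ins for the real call sites, not the library's API.

import torch

def quantize_weight_with_hpu_offload(gptq_handler, W, device, accelerator, **quant_kwargs):
    # Flush pending lazy-mode ops before copying tensors between devices
    # (sketch: `accelerator` is whatever auto_detect_accelerator() returned).
    accelerator.mark_step()
    # Run the GPTQ solve on CPU when the target device is an HPU.
    if "hpu" in device:
        W = W.to("cpu")
    scale, zero, Q = gptq_handler.fasterquant(W, **quant_kwargs)
    # Move the quantization results back so downstream packing stays on `device`.
    if "hpu" in device:
        scale, zero, Q = scale.to(device), zero.to(device), Q.to(device)
    return scale, zero, Q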