@@ -1,11 +1,11 @@
 # Copyright The FMS Model Optimizer Authors
-
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,6 +21,8 @@
 # Standard
 from pathlib import Path
 import logging
+import os
+import sys
 
 # Third Party
 from datasets import load_from_disk
@@ -34,7 +36,6 @@
 )
 import torch
 
-import os
 # Local
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.custom_ext_kernels.utils import (
@@ -48,14 +49,14 @@
     get_act_scales_1gpu,
 )
 from fms_mo.utils.aiu_utils import save_for_aiu
-from fms_mo.utils.dq_utils import config_quantize_smooth_layers
-from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
-from fms_mo.utils.utils import patch_torch_bmm, prepare_input
 from fms_mo.utils.dq_inf import (
-    save_vllm_fp8,
-    convert_fp8_vllm_to_fms_mo,
     check_quantization_setting,
+    convert_fp8_vllm_to_fms_mo,
+    save_vllm_fp8,
 )
+from fms_mo.utils.dq_utils import config_quantize_smooth_layers
+from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
+from fms_mo.utils.utils import patch_torch_bmm, prepare_input
 
 logger = logging.getLogger(__name__)
 
@@ -133,16 +134,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )
 
-    inference = model.config.to_dict().get("quantization_config",None)
+    inference_qconfig = None
+    if hasattr(model, "config"):
+        inference_qconfig = model.config.to_dict().get("quantization_config", None)
 
-    if inference:
-        quant_setting = check_quantization_setting(inference)
+    if inference_qconfig:
+        quant_setting = check_quantization_setting(inference_qconfig)
         if quant_setting:
             logger.info("Quantization config settings validated")
-            model = convert_fp8_vllm_to_fms_mo(model = model)
+            model = convert_fp8_vllm_to_fms_mo(model=model)
         else:
-            exit("__This quantization config is wrong/not supported__")
-
+            sys.exit("Error: This quantization config is wrong/not supported")
 
     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
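Context for the hunk above: checkpoints produced by vLLM-style FP8 quantization tooling carry a "quantization_config" block in the Hugging Face model config, and the new inference_qconfig guard keys off it. Below is a minimal standalone sketch of that detection flow, using the same fms_mo.utils.dq_inf helpers this commit imports; the checkpoint path is a placeholder, not something from the diff.

# Sketch of the inference-mode detection above; "path/to/fp8-checkpoint" is a
# placeholder for any model whose config.json carries "quantization_config".
from transformers import AutoModelForCausalLM

from fms_mo.utils.dq_inf import (
    check_quantization_setting,
    convert_fp8_vllm_to_fms_mo,
)

model = AutoModelForCausalLM.from_pretrained("path/to/fp8-checkpoint")
qconfig = model.config.to_dict().get("quantization_config", None)
if qconfig and check_quantization_setting(qconfig):
    # rewrite the vLLM-style FP8 modules into fms_mo's own representation
    model = convert_fp8_vllm_to_fms_mo(model=model)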
@@ -152,23 +154,29 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device}")
     logger.info(f"Tokenizer is {tokenizer}, block_size is {block_size}")
 
-    if not inference:
+    if not inference_qconfig:
         logger.info("quantization mode activated, initializing the qcfg file")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
-        if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"):
+        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
             if fms_mo_args.override_fms_args:
-                logger.info("qcfg file found and some parameters are being over-written")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args)
+                logger.info(
+                    "qcfg file found and some parameters are being over-written"
+                )
+                qcfg = qconfig_init(
+                    recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
+                )
             else:
                 logger.info("qcfg file found, loading the qcfg file")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
         else:
-            logger.info("qcfg file not found in {model_args.model_name_or_path},\
-                         falling back to the default dq recipe"
-                         )
+            logger.info(
+                f"qcfg file not found in {model_args.model_name_or_path}, "
+                "falling back to the default dq recipe"
+            )
             qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+        qcfg["inference"] = True
 
     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
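The inference branch above resolves the quantization config in a fixed priority order: a qcfg.json shipped next to the checkpoint wins, optionally overridden by CLI args, and the built-in "dq" recipe is the fallback. The sketch below restates that order outside the diff; resolve_qcfg is an illustrative helper, not part of this commit.

# Illustrative helper mirroring the qcfg resolution order in the hunk above.
import os

from fms_mo import qconfig_init


def resolve_qcfg(ckpt_dir, fms_mo_args):
    if os.path.isfile(ckpt_dir + "/qcfg.json"):
        if fms_mo_args.override_fms_args:
            # load qcfg.json, then overwrite selected fields from the CLI args
            return qconfig_init(recipe=ckpt_dir + "/qcfg", args=fms_mo_args)
        # load qcfg.json verbatim
        return qconfig_init(recipe=ckpt_dir + "/qcfg")
    # no qcfg.json shipped with the checkpoint: fall back to the dq recipe
    return qconfig_init(recipe="dq", args=fms_mo_args)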
@@ -193,7 +201,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
 
     qcfg["model"] = model_args.model_name_or_path
     # config layers to skip, smooth scale
-    config_quantize_smooth_layers(qcfg)
+    if not inference_qconfig:
+        config_quantize_smooth_layers(qcfg)
 
     use_dynamo = True
     # use dynamo as default unless really needed, False -> fallback to TorchScript tracing
@@ -225,7 +234,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )
 
     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not inference and qcfg["smoothq"] :
+    if not inference_qconfig and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
            # user provided a scale file (or a dir)
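For reference, the default scale-file naming in the hunk above flattens the model id into a single filename under ./act_scales/, assuming the .pt suffix shown there. The model id below is only an example.

# Example of the act_scales naming convention used above.
from pathlib import Path

model_id = "ibm-granite/granite-3.0-8b-instruct"  # placeholder model id
scale_file = Path(f"./act_scales/{model_id.replace('/', '-')}.pt")
print(scale_file)  # act_scales/ibm-granite-granite-3.0-8b-instruct.pt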
@@ -259,12 +268,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             use_layer_name_pattern_matching=use_layer_name_pattern_matching,
             use_dynamo=use_dynamo,
             dev=dev,
-            mode=inference,
             save_fname="dq",
         )
         logger.info(f"Quantized model {model}")
         logger.info("==" * 20)
-    if not inference:
+    if not inference_qconfig:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
@@ -295,11 +303,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
             logger.info(
                 f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
             )
-            save_vllm_fp8(model,qcfg,tokenizer,opt_args.output_dir)
+            save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
         elif opt_args.save_ckpt:
             logger.info(
                 f"Saving quantized model and tokenizer to {opt_args.output_dir}"
-                )
+            )
             model.save_pretrained(opt_args.output_dir, use_safetensors=True)
             tokenizer.save_pretrained(opt_args.output_dir)
 