
Commit b458d18

fix: updated the code to reflect PR update
Signed-off-by: Omobayode Fagbohungbe <[email protected]>
1 parent 0b5d68a commit b458d18

File tree: 12 files changed, +175 -111 lines changed

.pylintrc

Lines changed: 2 additions & 1 deletion
@@ -69,7 +69,8 @@ ignored-modules=gptqmodel,
                 llmcompressor,
                 cutlass_mm,
                 pygraphviz,
-                matplotlib
+                matplotlib,
+                compressed_tensors

 # Python code to execute, usually for sys.path manipulation such as
 # pygtk.require().

fms_mo/dq.py

Lines changed: 36 additions & 28 deletions
@@ -1,11 +1,11 @@
 # Copyright The FMS Model Optimizer Authors
-
+#
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
-
+#
 #     http://www.apache.org/licenses/LICENSE-2.0
-
+#
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -21,6 +21,8 @@
 # Standard
 from pathlib import Path
 import logging
+import os
+import sys

 # Third Party
 from datasets import load_from_disk
@@ -34,7 +36,6 @@
 )
 import torch

-import os
 # Local
 from fms_mo import qconfig_init, qmodel_prep
 from fms_mo.custom_ext_kernels.utils import (
@@ -48,14 +49,14 @@
     get_act_scales_1gpu,
 )
 from fms_mo.utils.aiu_utils import save_for_aiu
-from fms_mo.utils.dq_utils import config_quantize_smooth_layers
-from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
-from fms_mo.utils.utils import patch_torch_bmm, prepare_input
 from fms_mo.utils.dq_inf import (
-    save_vllm_fp8,
-    convert_fp8_vllm_to_fms_mo,
     check_quantization_setting,
+    convert_fp8_vllm_to_fms_mo,
+    save_vllm_fp8,
 )
+from fms_mo.utils.dq_utils import config_quantize_smooth_layers
+from fms_mo.utils.eval_utils import Evaluator, eval_llm_1GPU
+from fms_mo.utils.utils import patch_torch_bmm, prepare_input

 logger = logging.getLogger(__name__)

@@ -133,16 +134,17 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         low_cpu_mem_usage=bool(model_args.device_map),
     )

-    inference= model.config.to_dict().get("quantization_config",None)
+    inference_qconfig = None
+    if hasattr(model, "config"):
+        inference_qconfig = model.config.to_dict().get("quantization_config", None)

-    if inference:
-        quant_setting = check_quantization_setting(inference)
+    if inference_qconfig:
+        quant_setting = check_quantization_setting(inference_qconfig)
         if quant_setting:
             logger.info("Quantization config settings validated ")
-            model = convert_fp8_vllm_to_fms_mo(model = model)
+            model = convert_fp8_vllm_to_fms_mo(model=model)
         else:
-            exit("__This quantization config is wrong/not supported__")
-
+            sys.exit("Error: This quantization config is wrong/not supported")

     embedding_size = model.get_input_embeddings().weight.shape[0]
     if len(tokenizer) > embedding_size:
@@ -152,23 +154,29 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     logger.info(f"Model is at {model.device} after intialization")
     logger.info(f"Tokenizer is {tokenizer}, block size is {block_size}")

-    if not inference:
+    if not inference_qconfig:
         logger.info("quantization mode activated, initalizing the qcfg file ")
         qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
     else:
         logger.info("inference mode activated")
-        if os.path.isfile(model_args.model_name_or_path+"/qcfg.json"):
+        if os.path.isfile(model_args.model_name_or_path + "/qcfg.json"):
             if fms_mo_args.override_fms_args:
-                logger.info("qcfg file found and some parameters are being over-written ")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg", args=fms_mo_args)
+                logger.info(
+                    "qcfg file found and some parameters are being over-written "
+                )
+                qcfg = qconfig_init(
+                    recipe=model_args.model_name_or_path + "/qcfg", args=fms_mo_args
+                )
             else:
                 logger.info("qcfg file found, loading the qcfg file ")
-                qcfg = qconfig_init(recipe=model_args.model_name_or_path+"/qcfg")
+                qcfg = qconfig_init(recipe=model_args.model_name_or_path + "/qcfg")
         else:
-            logger.info("qcfg file not found in {model_args.model_name_or_path},\
+            logger.info(
+                "qcfg file not found in {model_args.model_name_or_path},\
                 loading fms_mo_args and recipe"
-                )
+            )
             qcfg = qconfig_init(recipe="dq", args=fms_mo_args)
+            qcfg["inference"] = True

     model_size = model_size_Wb(model, unit="GB")
     gpu_mem_util_per = model_size / total_gpu_memory
@@ -193,7 +201,8 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):

     qcfg["model"] = model_args.model_name_or_path
     # config layers to skip, smooth scale
-    config_quantize_smooth_layers(qcfg)
+    if not inference_qconfig:
+        config_quantize_smooth_layers(qcfg)

     use_dynamo = True
     # use dynamo as default unless really needed, False -> fallback to TorchScript tracing
@@ -225,7 +234,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
     )

     # For loading or creating smoothquant scale. Sometimes we may include scales in ckpt as well.
-    if not inference and qcfg["smoothq"] :
+    if not inference_qconfig and qcfg["smoothq"]:
         scale_file = Path(f"./act_scales/{qcfg['model'].replace('/', '-')}.pt")
         if qcfg.get("act_scale_path", None):
             # user provided a scale file (or a dir)
@@ -259,12 +268,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
         use_layer_name_pattern_matching=use_layer_name_pattern_matching,
         use_dynamo=use_dynamo,
         dev=dev,
-        mode=inference,
         save_fname="dq",
     )
     logger.info(f"Quantized model {model}")
     logger.info("==" * 20)
-    if not inference:
+    if not inference_qconfig:
         if qcfg["smoothq"]:
             logger.info("Starting to apply smooth scale")
             dq_llm(model, act_scales, qcfg)
@@ -295,11 +303,11 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
            logger.info(
                f"Saving model processed for vLLM and tokenizer to {opt_args.output_dir}"
            )
-            save_vllm_fp8(model,qcfg,tokenizer,opt_args.output_dir)
+            save_vllm_fp8(model, qcfg, tokenizer, opt_args.output_dir)
        elif opt_args.save_ckpt:
            logger.info(
                f"Saving quantized model and tokenizer to {opt_args.output_dir}"
-                )
+            )
            model.save_pretrained(opt_args.output_dir, use_safetensors=True)
            tokenizer.save_pretrained(opt_args.output_dir)
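For context on the run_dq changes above: the commit renames `inference` to `inference_qconfig`, guards the config lookup behind `hasattr(model, "config")`, and replaces the bare `exit()` with `sys.exit()`. Below is a minimal sketch of that detection flow, not the project's code; it assumes a Hugging Face-style model whose `config.to_dict()` may carry a `quantization_config` entry, and the two helpers are passed in as parameters rather than imported, so nothing beyond what the diff shows is assumed about `fms_mo.utils.dq_inf`. The function name is hypothetical.

# Illustrative sketch only; mirrors the detection pattern in the hunk above.
import sys


def maybe_convert_fp8_checkpoint(model, check_quantization_setting, convert_fp8_vllm_to_fms_mo):
    """Detect a vLLM fp8 quantization_config on a loaded model and convert it."""
    inference_qconfig = None
    if hasattr(model, "config"):
        # A missing key yields None, so plain checkpoints fall through untouched.
        inference_qconfig = model.config.to_dict().get("quantization_config", None)

    if inference_qconfig:
        if check_quantization_setting(inference_qconfig):
            model = convert_fp8_vllm_to_fms_mo(model=model)
        else:
            # sys.exit instead of the bare exit() builtin, which is meant for the REPL.
            sys.exit("Error: This quantization config is wrong/not supported")

    return model, inference_qconfig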

fms_mo/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -281,7 +281,6 @@ def forward(self, x):
            )

            # pylint: disable=not-callable
-
            return F.linear(x, self.W_fp, self.bias)
        else:
            qinput = self.quantize_feature(x / scale).to(x.dtype)
@@ -297,6 +296,7 @@ def forward(self, x):
            )

            qbias = self.bias
+
            # pylint: disable=not-callable
            output = F.linear(qinput, qweight, qbias)


fms_mo/prep.py

Lines changed: 28 additions & 21 deletions
@@ -23,7 +23,7 @@
 # Third Party
 from torch import nn
 import torch
-import compressed_tensors
+
 # Local
 from fms_mo.calib import qmodel_calib
 from fms_mo.modules import QBmm_modules, QConv2d_modules, QLinear_modules, QLSTM_modules
@@ -391,13 +391,19 @@ def make_quant_module(module, curr_full_name, qcfg, verbose=False):
     # For nn.Linear
     elif isinstance(module, nn.Linear):
         if module.__class__ != nn.Linear:
-            if isinstance(module, compressed_tensors.linear.compressed_linear.CompressedLinear):
-                pass
-            else:
-                logger.warning(
-                    f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
-                    "Please make sure it doesn't wrap BN and activ func."
-                    "Otherwise please create an equivalen Linear wrapper and change qcfg['mapping']."
+            if available_packages["compressed_tensors"]:
+                # Third Party
+                import compressed_tensors
+
+                if isinstance(
+                    module, compressed_tensors.linear.compressed_linear.CompressedLinear
+                ):
+                    pass
+                else:
+                    logger.warning(
+                        f"{curr_full_name} {type(module)} seems to be a wrapper of Linear."
+                        "Please make sure it doesn't wrap BN and activ func. Otherwise"
+                        "please create an equivalen Linear wrapper and change qcfg['mapping']."
                )
         QLin = mapping.get(nn.Linear, None)
         if QLin is None:
@@ -572,6 +578,7 @@ def has_quantized_module(model):
     """Check if model is already quantized - do not want to quantize twice if so"""
     return any(isinstance(m, quantized_modules) for m in model.modules())

+
 def swap_qbmm(model: nn.Module, qcfg: dict):
     """Go through all model.named_modules(), try to create an equivalent
     Qbmm layer to replace each of the existing linear Bmm layers.
@@ -581,14 +588,13 @@ def swap_qbmm(model: nn.Module, qcfg: dict):
         qcfg (dict): quant config

     Returns: updated model is returned with the Qbmm added
-
+
     """

+    # Local
     from fms_mo.modules import QBmm

-    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"][
-        "which2patch_contextmanager"
-    ]
+    qcfg["which2patch_contextmanager"] = qcfg["bmm_prep"]["which2patch_contextmanager"]
     isbmm = qcfg["which2patch_contextmanager"] == "torch.bmm"
     for mod_name, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
         mod_bmm_happened = model.get_submodule(mod_name)
@@ -608,6 +614,7 @@ def swap_qbmm(model: nn.Module, qcfg: dict):
         )
         setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)

+
 def qmodel_prep(
     model,
     dloader,
@@ -619,7 +626,6 @@ def qmodel_prep(
     Qcali=False,
     dev=None,
     use_dynamo=False,
-    mode=False,
     verbose=False,
     **kwargs,
 ):
@@ -695,14 +701,13 @@ def qmodel_prep(
     Returns:
         nn.Module: quantized model ready for further PTQ/QAT
     """
-    if mode:
-
-        if qcfg.get("QBmm"):
-            swap_qbmm(model,qcfg)
+    if qcfg["inference"]:
+        if qcfg.get("QBmm"):
+            swap_qbmm(model, qcfg)

-        model = q_any_net_5(model, qcfg, verbose = False)
+        model = q_any_net_5(model, qcfg, verbose=False)
         return model
-
+
     sys.setrecursionlimit(4000)

     currDev = next(model.parameters()).device if dev is None else dev
@@ -951,8 +956,10 @@ def qmodel_prep(
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=DPorDDPdevices
        )
-
-    qconfig_save(qcfg, fname=qcfg["output_folder"]+"/qcfg.json")
+    if qcfg["output_folder"] is None:
+        qconfig_save(qcfg, fname="qcfg.json")
+    else:
+        qconfig_save(qcfg, fname=qcfg["output_folder"] + "/qcfg.json")
     qcfg["tb_writer"] = tb_writer

     logger.info(f"--- Quantized model --- \n{model}\n")
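The prep.py change drops the unconditional top-level `import compressed_tensors` in favor of a lazy import guarded by the repo's `available_packages` check, so the package stays optional. Below is a generic sketch of that guard pattern, not the project's code: it substitutes `importlib.util.find_spec` for the `available_packages` registry, and the `CompressedLinear` attribute path is copied from the diff above. The helper name is hypothetical.

# Generic optional-dependency guard; the real code keys off available_packages.
import importlib.util

from torch import nn


def is_compressed_linear(module: nn.Module) -> bool:
    """True only if compressed_tensors is installed and module is its CompressedLinear."""
    if importlib.util.find_spec("compressed_tensors") is None:
        return False

    # Import lazily, only after the availability check, so environments without
    # the package never hit an ImportError at module import time.
    import compressed_tensors  # pylint: disable=import-outside-toplevel

    return isinstance(
        module, compressed_tensors.linear.compressed_linear.CompressedLinear
    )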

fms_mo/quant/quantizers.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def get_weight_quantizer(
     recompute=False,
     perGp=None,
     use_subnormal=False,
-    emulate = True,
+    emulate=True,
 ):
     """Return a quantizer for weight quantization
     Regular quantizers:

fms_mo/recipes/dq.json

Lines changed: 3 additions & 1 deletion
@@ -10,5 +10,7 @@
     "eval_ckpt": true,
     "nbits_bmm1" : 32,
     "nbits_bmm2" : 32,
-    "nbits_kvcache" : 32
+    "nbits_kvcache" : 32,
+    "inference": false,
+    "output_folder": null
 }
File renamed without changes.
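The two keys added to the dq recipe give every run an explicit mode and save target: "inference" defaults to false (set to true at runtime in the inference branch when no qcfg.json is found next to the checkpoint), and "output_folder" may be null, in which case the prep.py hunk above writes qcfg.json to the current directory. A small illustrative sketch of that fallback follows; the helper name is hypothetical and not the project's API.

# Illustrative only: shows how a null output_folder falls back to ./qcfg.json.
def resolve_qcfg_path(qcfg: dict) -> str:
    """Return the path qcfg.json should be written to, per the recipe defaults."""
    if qcfg.get("output_folder") is None:
        return "qcfg.json"
    return qcfg["output_folder"] + "/qcfg.json"


# With the defaults added in this commit:
assert resolve_qcfg_path({"inference": False, "output_folder": None}) == "qcfg.json"
assert resolve_qcfg_path({"output_folder": "runs/dq"}) == "runs/dq/qcfg.json"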
