-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquantize.py
320 lines (270 loc) · 10 KB
/
quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import argparse
import os
import logging
import shutil
import subprocess
import sys
import json
import yaml
import faulthandler
import huggingface_hub
import torch
import datasets
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from accelerate import cpu_offload
# Module-level logger; main() configures the root handler via logging.basicConfig.
LOGGER = logging.getLogger(__name__)
def main():
    """Quantize a Hugging Face causal LM and write provenance metadata.

    Parses the command line, then (unless ``--update-metadata-only`` is set)
    runs the selected backend -- AutoAWQ, GPTQModel or llm-compressor -- and
    saves the quantized model plus its tokenizer into
    ``./<basename(model)>-<quantization>``. In every case it finishes by
    calling ``write_metadata`` on that directory.

    Raises:
        NotImplementedError: For a quantization choice with no backend branch.
        ValueError: From the AWQ path when the installed transformers version
            is newer than 4.47.1.
    """
    args = _parse_args()
    model_name = os.path.basename(args.model)
    quant_name = f"{model_name}-{args.quantization}"

    logging.basicConfig(level=logging.INFO)
    LOGGER.info(args)
    # Log only the variable this script depends on: dumping all of os.environ
    # would write secrets such as HF_TOKEN into the logs.
    LOGGER.info(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")

    target_device = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")[0]
    LOGGER.info(f"Using cuda:{target_device}")
    # NOTE: `0` index means the first **visible** cuda device
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)

    if args.update_metadata_only:
        os.makedirs(quant_name, exist_ok=True)
        write_metadata(args, quant_name, device)
        return

    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)

    if args.quantization == "AWQ-Int4":
        _quantize_awq(args, tokenizer, quant_name)
    elif args.quantization in ("GPTQ-Int4", "GPTQ-Int8"):
        _quantize_gptq(args, tokenizer, quant_name, device)
    elif args.quantization in ("Static-F8", "Dynamic-F8"):
        _quantize_fp8(args, tokenizer, quant_name, device)
    else:
        raise NotImplementedError(args.quantization)

    write_metadata(args, quant_name, device)


def _parse_args():
    """Build the command-line parser and return the parsed arguments."""
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument(
        "-q",
        "--quantization",
        choices=[
            "AWQ-Int4",
            "GPTQ-Int8",
            "GPTQ-Int4",
            "Static-F8",
            "Dynamic-F8",
        ],
        required=True,
        help="The type of quantization to apply",
    )
    parser.add_argument(
        "-m",
        "--model",
        default=None,
        required=True,
        help="The base model. Should be huggingface tag.",
    )
    parser.add_argument(
        "--dataset", default="HuggingFaceH4/ultrachat_200k", help="Calibration data"
    )
    parser.add_argument(
        "--dataset-split", default="train_sft", help="Split for calibration data"
    )
    parser.add_argument(
        "--dataset-name",
        default=None,
        help="Name for calibration data, passed to datasets.load_dataset.",
    )
    parser.add_argument(
        "--num-samples",
        default=512,
        type=int,
        help="Number of items from dataset to use for calibration",
    )
    parser.add_argument(
        "--seq-length",
        default=2048,
        type=int,
        help="Sequence length for calibration data",
    )
    parser.add_argument(
        "--batch-size",
        default=32,
        type=int,
        help="Number of calibration samples to process at the same time.",
    )
    parser.add_argument("--update-metadata-only", default=False, action="store_true")
    return parser.parse_args()


def _load_calibration_dataset(args, tokenizer):
    """Return a shuffled calibration subset whose only column is ``text``.

    Each row's ``messages`` field is rendered to a string with the tokenizer's
    chat template. When the split has fewer rows than ``--num-samples``, the
    count is clamped and ``args.num_samples`` is updated in place so the
    written metadata reports the number of samples actually used.
    """
    ds = datasets.load_dataset(
        args.dataset, args.dataset_name, split=args.dataset_split
    )
    if args.num_samples > len(ds):
        # select(range(n)) would raise IndexError past the end of the split.
        LOGGER.warning(
            f"--num-samples {args.num_samples} exceeds the {len(ds)} rows in "
            f"{args.dataset}:{args.dataset_split}; using all {len(ds)} rows."
        )
        args.num_samples = len(ds)
    ds = ds.shuffle(seed=0).select(range(args.num_samples))

    def render_chat(example):
        return {
            "text": tokenizer.apply_chat_template(example["messages"], tokenize=False)
        }

    return ds.map(render_chat, remove_columns=ds.column_names)


def _quantize_awq(args, tokenizer, quant_name):
    """Run AutoAWQ 4-bit quantization and save model + tokenizer to ``quant_name``."""
    from awq import AutoAWQForCausalLM
    from packaging.version import Version

    # Checked before loading the dataset or model so a bad environment fails fast.
    if Version(transformers.__version__) > Version("4.47.1"):
        raise ValueError(
            "There's currently a bug in autoawq with newer transformers versions. Please downgrade to 4.47.1"
        )

    LOGGER.info("Initializing calibration data...")
    calib_texts = [row["text"] for row in _load_calibration_dataset(args, tokenizer)]

    LOGGER.info("Initializing model...")
    model = AutoAWQForCausalLM.from_pretrained(
        args.model, low_cpu_mem_usage=True, use_cache=False
    )

    LOGGER.info("Quantizing...")
    model.quantize(
        tokenizer,
        calib_data=calib_texts,
        max_calib_samples=args.num_samples,
        max_calib_seq_len=args.seq_length,
        n_parallel_calib_samples=args.batch_size,
    )
    LOGGER.info("Saving...")
    model.save_quantized(quant_name)
    tokenizer.save_pretrained(quant_name)


def _quantize_gptq(args, tokenizer, quant_name, device):
    """Run GPTQModel 4- or 8-bit quantization and save model + tokenizer to ``quant_name``."""
    from gptqmodel import GPTQModel, QuantizeConfig

    LOGGER.info("Initializing model...")
    model = GPTQModel.load(
        args.model,
        QuantizeConfig(
            bits=4 if args.quantization == "GPTQ-Int4" else 8,
            group_size=128,
            device=device,
        ),
    )

    LOGGER.info("Initializing calibration data...")
    ds = _load_calibration_dataset(args, tokenizer)

    def tokenize(sample):
        # add_special_tokens=False: presumably because the chat template already
        # emits the special tokens (e.g. BOS) -- confirm per tokenizer.
        return tokenizer(
            sample["text"],
            padding=False,
            max_length=args.seq_length,
            truncation=True,
            add_special_tokens=False,
        )

    ds = ds.map(tokenize, remove_columns=ds.column_names)

    LOGGER.info("Quantizing...")
    model.quantize(ds.to_list(), batch_size=args.batch_size, tokenizer=tokenizer)
    LOGGER.info("Saving...")
    model.save_quantized(quant_name)
    tokenizer.save_pretrained(quant_name)


def _quantize_fp8(args, tokenizer, quant_name, device):
    """Apply llm-compressor FP8 (static or dynamic) quantization to Linear layers.

    See https://github.com/vllm-project/llm-compressor/tree/main/examples/quantization_w8a8_fp8
    """
    from llmcompressor.transformers import oneshot
    from llmcompressor.modifiers.quantization import QuantizationModifier

    LOGGER.info("Initializing model...")
    # NOTE(review): no torch_dtype is passed, so depending on the transformers
    # version the non-quantized weights may be loaded (and saved) in float32
    # rather than the checkpoint dtype -- confirm this is intended.
    model = AutoModelForCausalLM.from_pretrained(args.model, use_cache=False)
    # Keep weights on CPU and move each module to the GPU only while it runs.
    cpu_offload(model, execution_device=device)

    scheme = {
        "Static-F8": "FP8",
        "Dynamic-F8": "FP8_DYNAMIC",
    }[args.quantization]

    LOGGER.info("Quantizing...")
    # NOTE(review): no calibration dataset is passed to oneshot; confirm the
    # static "FP8" scheme does not need one for its activation scales.
    oneshot(
        model=model,
        tokenizer=tokenizer,
        recipe=QuantizationModifier(
            targets="Linear",
            scheme=scheme,
            ignore=["lm_head"],
        ),
    )
    LOGGER.info("Saving...")
    model.save_pretrained(quant_name)
    tokenizer.save_pretrained(quant_name)
def write_metadata(args, metdata_dir, device: torch.device):
    """Populate ``metdata_dir`` with provenance metadata for a quantized model.

    Copies the base model's ``*.md`` files from the Hugging Face Hub into the
    directory, then rewrites ``README.md`` as: YAML front matter (taken from the
    base README, with ``base_model``/``license`` filled in if missing), a new
    "# Quantization" section, and the base README body. Finally dumps
    ``pip freeze`` and the CLI arguments next to it.

    Args:
        args: Parsed CLI namespace (reads ``model``, ``quantization``,
            ``dataset``, ``num_samples``, ``batch_size``, ``seq_length``).
        metdata_dir: Output directory; created if it does not exist.
        device: CUDA device used for quantization; its name goes into the README.

    Raises:
        ValueError: If ``device`` is not a CUDA device.
        subprocess.CalledProcessError: If ``git rev-parse`` or ``pip freeze`` fails.
    """
    os.makedirs(metdata_dir, exist_ok=True)

    LOGGER.info("Downloading base model readmes.")
    hf_cache_dir = huggingface_hub.snapshot_download(args.model, allow_patterns="*.md")
    for fname in os.listdir(hf_cache_dir):
        if not fname.endswith(".md"):
            continue
        LOGGER.info(f"Copying {hf_cache_dir}/{fname} into {metdata_dir}")
        shutil.copy(
            os.path.join(hf_cache_dir, fname),
            os.path.join(metdata_dir, fname),
        )

    if "AWQ" in args.quantization:
        import awq

        quantization_library = (
            f"[autoawq=={awq.__version__}](https://github.com/casper-hansen/AutoAWQ)"
        )
    elif "GPTQ" in args.quantization:
        import gptqmodel

        quantization_library = f"[gptqmodel=={gptqmodel.__version__}](https://github.com/ModelCloud/GPTQModel)"
    else:
        import llmcompressor

        quantization_library = f"[llmcompressor=={llmcompressor.__version__}](https://github.com/vllm-project/llm-compressor)"
    LOGGER.info(f"Using {quantization_library}")

    # Commit of the checkout in the current working directory.
    commit_hash = subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
    LOGGER.info(f"lambda-quant commit {commit_hash}")

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if device.type != "cuda":
        raise ValueError(f"write_metadata expects a CUDA device, got {device}")
    device_name = torch.cuda.get_device_name(device)

    quantization_info = [
        f"Quantized using {quantization_library} on an `{device_name}` GPU.",
    ]
    # Only the AWQ and GPTQ paths consume calibration data.
    if "AWQ" in args.quantization or "GPTQ" in args.quantization:
        quantization_info.extend(
            [
                "",
                f"Calibrated with `{args.num_samples}` samples from `{args.dataset}`, `--batch-size {args.batch_size}`, `--seq-length {args.seq_length}`.",
            ]
        )

    new_lines = [
        "# Quantization",
        f"Created with [lambda-quant](https://github.com/LambdaLabsML/lambda-quant/tree/{commit_hash}) on `Python {sys.version}`",
        "",
        f"Base Model: [{args.model}](https://huggingface.co/{args.model})",
        "",
        *quantization_info,
        "",
        "Steps to create:",
        "1. `git clone https://github.com/LambdaLabsML/lambda-quant`",
        f"2. `git checkout {commit_hash}`",
        f"3. `python {' '.join(sys.argv)}`",
        "",
        "## Evaluation",
        "TODO",
        "",
        "## Benchmarks",
        "TODO",
        "",
        "# Base Model README.md",
        "",
    ]

    readme_path = os.path.join(metdata_dir, "README.md")
    try:
        with open(readme_path, encoding="utf-8") as fp:
            readme_content = fp.read()
    except FileNotFoundError:
        # The base model may ship no README.md; start from an empty one.
        LOGGER.warning(f"{readme_path} not found; creating a new README.md")
        readme_content = ""

    front_matter, readme_body = _split_front_matter(readme_content)
    metadata = yaml.safe_load(front_matter) if front_matter else None
    if not isinstance(metadata, dict):
        # Empty or non-mapping front matter cannot hold model-card keys.
        metadata = {}
    metadata.setdefault("base_model", args.model)
    # NOTE(review): defaults to MIT when the base model declares no license --
    # confirm this is intended, since it asserts a license for the upload.
    metadata.setdefault("license", "mit")
    metadata_block = "---\n" + yaml.dump(metadata) + "---\n"

    new_content = "\n".join(new_lines)
    LOGGER.info(f"Writing {new_content} into README.md")
    with open(readme_path, "w", encoding="utf-8") as fp:
        fp.write(metadata_block + "\n" + new_content + "\n" + readme_body)

    LOGGER.info(f"Dumping `pip freeze` to {metdata_dir}/requirements-lambda-quant.txt")
    # Use the running interpreter's pip so the freeze matches this environment,
    # not whichever `pip` happens to be first on PATH.
    freeze = subprocess.check_output([sys.executable, "-m", "pip", "freeze"]).decode()
    with open(f"{metdata_dir}/requirements-lambda-quant.txt", "w", encoding="utf-8") as fp:
        fp.write(freeze)

    LOGGER.info(f"Dumping `args` to {metdata_dir}/args-lambda-quant.json")
    with open(f"{metdata_dir}/args-lambda-quant.json", "w", encoding="utf-8") as fp:
        json.dump(vars(args), fp)


def _split_front_matter(readme_content):
    """Split a README into ``(front_matter_text, body)``.

    Front matter is recognised only when the first line (ignoring a UTF-8 BOM
    and surrounding whitespace) is exactly ``---`` and a later line is exactly
    ``---``. The returned front matter excludes both delimiter lines; the body
    is everything after the closing delimiter. When no well-formed front
    matter exists, returns ``(None, readme_content)`` unchanged, so a ``---``
    horizontal rule in the body is never mistaken for metadata.
    """
    lines = readme_content.splitlines(keepends=True)
    if not lines or lines[0].lstrip("\ufeff").strip() != "---":
        return None, readme_content
    for idx, line in enumerate(lines[1:], start=1):
        if line.strip() == "---":
            return "".join(lines[1:idx]), "".join(lines[idx + 1:])
    # Opening delimiter without a closing one: treat as plain content.
    return None, readme_content
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()