#!/usr/bin/env python
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Adapted from https://github.com/huggingface/transformers
# vllm-project: no copyright

import os
import warnings
from pathlib import PosixPath
from typing import Optional

from compressed_tensors.utils.helpers import deprecated
from loguru import logger
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoProcessor,
    HfArgumentParser,
    PreTrainedModel,
    set_seed,
)
from transformers.utils.quantization_config import CompressedTensorsConfig

from llmcompressor.args import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)
from llmcompressor.core import pre_initialize_structure, reset_session
from llmcompressor.pytorch.model_load.helpers import (
    fallback_to_cpu,
    get_session_model,
    initialize_recipe,
    parse_dtype,
    save_checkpoint,
)
from llmcompressor.recipe import Recipe, StageRunType
from llmcompressor.transformers.finetune.runner import StageRunner
from llmcompressor.transformers.finetune.trainer import Trainer
from llmcompressor.transformers.sparsification.compressed_tensors_utils import (
    modify_save_pretrained,
    patch_tied_tensors_bug,
)
from llmcompressor.transformers.sparsification.sparse_model import (
    get_processor_name_from_model,
)
from llmcompressor.transformers.utils.helpers import (
    detect_last_checkpoint,
    is_model_ct_quantized_from_path,
)
from llmcompressor.typing import Processor
from llmcompressor.utils.fsdp.helpers import is_fsdp_model


def train(**kwargs):
    """
    CLI entrypoint for running training
    """
    model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
    training_args.do_train = True
    main(model_args, data_args, recipe_args, training_args)


def eval(**kwargs):
    """
    CLI entrypoint for running evaluation
    """
    model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
    training_args.do_eval = True
    main(model_args, data_args, recipe_args, training_args)


@deprecated(
    message=(
        "`from llmcompressor.transformers import oneshot` is deprecated, "
        "please use `from llmcompressor import oneshot`."
    )
)
def oneshot(**kwargs) -> None:
    from llmcompressor import oneshot

    oneshot(**kwargs)


def apply(**kwargs):
    """
    CLI entrypoint for any of training, oneshot
    """
    report_to = kwargs.get("report_to", None)
    model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
    training_args.run_stages = True
    if report_to is None:  # user didn't specify any reporters
        # get rid of the reporters inferred from hugging face
        training_args.report_to = []
    main(model_args, data_args, recipe_args, training_args)


def compress(**kwargs):
    apply(**kwargs)
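

# Usage sketch: the entrypoints above forward **kwargs directly into parse_args(),
# so they accept any field of ModelArguments, DatasetArguments, RecipeArguments, or
# TrainingArguments as a keyword argument. A minimal, illustrative example follows;
# the literal values (model id, dataset name, output directory) are placeholders and
# `dataset` is assumed to be a DatasetArguments field.
def _example_apply_invocation():  # pragma: no cover - illustrative sketch only
    apply(
        model="path/or/hub-id-of-model",  # ModelArguments.model
        dataset="name-of-registered-dataset",  # assumed DatasetArguments.dataset
        recipe="recipe.yaml",  # RecipeArguments.recipe
        output_dir="./compressed-output",  # TrainingArguments.output_dir
        num_train_epochs=1,  # standard Hugging Face TrainingArguments field
    )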


def parse_args(**kwargs):
    """
    Parses kwargs by grouping into model, data or training arg groups:
        * model_args in
            src/llmcompressor/transformers/utils/arg_parser/model_args.py
        * data_args in
            src/llmcompressor/transformers/utils/arg_parser/data_args.py
        * recipe_args in
            src/llmcompressor/transformers/utils/arg_parser/recipe_args.py
        * training_args in
            src/llmcompressor/transformers/utils/arg_parser/training_args.py
    """
    parser = HfArgumentParser(
        (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
    )

    if not kwargs:
        parsed_args = parser.parse_args_into_dataclasses()
    else:
        parsed_args = parser.parse_dict(kwargs)

    model_args, data_args, recipe_args, training_args = parsed_args
    if recipe_args.recipe_args is not None:
        if not isinstance(recipe_args.recipe_args, dict):
            arg_dict = {}
            for recipe_arg in recipe_args.recipe_args:
                key, value = recipe_arg.split("=")
                arg_dict[key] = value
            recipe_args.recipe_args = arg_dict

    # raise deprecation warnings
    if data_args.remove_columns is not None:
        warnings.warn(
            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
            "columns which are invalid inputs to the tokenizer will be removed",
            DeprecationWarning,
        )

    # silently assign tokenizer to processor
    if model_args.tokenizer:
        if model_args.processor:
            raise ValueError("Cannot use both a tokenizer and processor")
        model_args.processor = model_args.tokenizer
        model_args.tokenizer = None

    return model_args, data_args, recipe_args, training_args
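

# Sketch of the kwargs grouping performed by parse_args(): HfArgumentParser.parse_dict
# routes each key to the dataclass that declares it, and "key=value" strings passed as
# recipe_args are folded into a dict. Field names below are taken from this module;
# the literal values are placeholders.
def _example_parse_args():  # pragma: no cover - illustrative sketch only
    model_args, data_args, recipe_args, training_args = parse_args(
        model="path/or/hub-id-of-model",  # routed to ModelArguments
        recipe="recipe.yaml",  # routed to RecipeArguments
        output_dir="./output",  # routed to TrainingArguments
        recipe_args=["observer=minmax"],  # list of "key=value" strings
    )
    assert recipe_args.recipe_args == {"observer": "minmax"}
    return model_args, data_args, recipe_args, training_args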


def initialize_model_from_path(
    model_args: ModelArguments,
    training_args: Optional[TrainingArguments] = None,
):
    # Load pretrained model
    # The .from_pretrained methods guarantee that only one local process can
    # concurrently download model & vocab.
    model_path = model_args.model
    config = AutoConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_path,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        tie_word_embeddings=model_args.tie_word_embeddings,
        trust_remote_code=model_args.trust_remote_code_model,
    )

    last_checkpoint = None
    teacher = None

    if training_args is not None:
        # Load teacher configuration if applicable
        teacher_config = (
            AutoConfig.from_pretrained(
                model_args.distill_teacher,
                use_auth_token=True if model_args.use_auth_token else None,
                tie_word_embeddings=model_args.tie_word_embeddings,
                trust_remote_code=model_args.trust_remote_code_model,
            )
            if model_args.distill_teacher
            else None
        )

        # Detect last checkpoint
        last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args)

        # Set seed before initializing model
        set_seed(training_args.seed)

        # Initialize teacher model if teacher path is provided
        if model_args.distill_teacher is not None:
            teacher_device_map = (
                None
                if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
                else "auto"
            )
            teacher_kwargs = {
                "config": teacher_config,
                "cache_dir": model_args.cache_dir,
                "use_auth_token": True if model_args.use_auth_token else None,
                "torch_dtype": parse_dtype(model_args.precision),
                "device_map": teacher_device_map,
                "trust_remote_code": model_args.trust_remote_code_model,
            }
            teacher = AutoModelForCausalLM.from_pretrained(
                model_args.distill_teacher,
                **teacher_kwargs,
            )
            if "sequence_length" in teacher_kwargs:
                teacher.seqlen = teacher_kwargs["sequence_length"]

    model_path = (
        last_checkpoint or model_args.model
        if hasattr(model_args, "model")
        else model_args.model_name_or_path
    )

    # Fallback to CPU if GPU requested and not available
    model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device)

    # Trainer handles device assignment for FSDP and training, don't do mapping here
    # if running oneshot outside of FSDP, apply user device settings
    fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
    device_map = model_args.oneshot_device
    if not fsdp_enabled and training_args is not None and training_args.do_train:
        device_map = "auto"

    model_kwargs = {
        "config": config,
        "cache_dir": model_args.cache_dir,
        "revision": model_args.model_revision,
        "use_auth_token": True if model_args.use_auth_token else None,
        "torch_dtype": parse_dtype(model_args.precision),
        "device_map": device_map,
        "trust_remote_code": model_args.trust_remote_code_model,
    }

    # this calls from_pretrained under the hood so should be FSDP safe
    # optimized models must be decompressed to carry out oneshot/train/etc
    if is_model_ct_quantized_from_path(model_path):
        model_kwargs["quantization_config"] = CompressedTensorsConfig(
            run_compressed=False
        )

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        **model_kwargs,
    )
    if "sequence_length" in model_kwargs:
        model.seqlen = model_kwargs["sequence_length"]

    return model, teacher
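

# Sketch of calling initialize_model_from_path() outside the CLI flow. Without
# TrainingArguments, the teacher/checkpoint/seed logic is skipped and the (possibly
# compressed) model is simply loaded onto the oneshot device. This assumes
# ModelArguments requires only its `model` field; the model id is a placeholder.
def _example_initialize_model():  # pragma: no cover - illustrative sketch only
    model_args = ModelArguments(model="path/or/hub-id-of-model")
    model, teacher = initialize_model_from_path(model_args)  # training_args=None
    assert teacher is None  # a teacher is only loaded when training_args are given
    return model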


def initialize_processor_from_path(
    model_args: ModelArguments,
    model: PreTrainedModel,
    teacher: Optional[PreTrainedModel] = None,
) -> Processor:
    processor_src = model_args.processor or get_processor_name_from_model(
        model, teacher
    )
    # The use_fast=True option is not currently supported safely in Transformers
    # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727  # noqa: E501
    try:
        processor = AutoProcessor.from_pretrained(
            processor_src,
            cache_dir=model_args.cache_dir,
            use_fast=True,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            trust_remote_code=model_args.trust_remote_code_model,
        )
    except Exception:
        logger.debug("Could not load fast processor, loading slow processor instead")
        processor = AutoProcessor.from_pretrained(
            processor_src,
            cache_dir=model_args.cache_dir,
            use_fast=False,
            revision=model_args.model_revision,
            use_auth_token=True if model_args.use_auth_token else None,
            trust_remote_code=model_args.trust_remote_code_model,
        )

    return processor


def main(
    model_args: ModelArguments,
    data_args: DatasetArguments,
    recipe_args: RecipeArguments,
    training_args: TrainingArguments,
):
    """
    Main entrypoint for finetuning text generation models. A model can be loaded from
    Hugging Face or disk, and resuming training from a checkpoint is supported.

    Lifecycle:
        - SparseAutoModel.text_generation_from_pretrained if model provided as
            string for model and teacher
        - AutoTokenizer.from_pretrained() if tokenizer provided as
            string for tokenizer
        - StageRunner.populate_datasets()
        - Trainer()
            - SessionMixIn()
            - HFTransformersTrainer()
        - StageRunner.train() and/or oneshot()

    :param model_args: Arguments pertaining to which model/config/tokenizer we are
        going to fine-tune from
    :param data_args: Arguments pertaining to what data we are going to input to our
        model for training
    :param training_args: Arguments pertaining to training loop configuration
    """
    # Temporary warning, to be removed
    if model_args.tie_word_embeddings is True:
        logger.warning(
            "The tie_word_embeddings flag is by default set to False. "
            "This guarantees that the one-shot algorithm saves the final "
            "weights without errors. Detected tie_word_embeddings=True. "
            "This may cause issues with the one-shot algorithm on save."
        )

    # Setup based on stage types if running stage mode
    if training_args.run_stages and recipe_args.recipe is not None:
        recipe_obj = Recipe.create_instance(recipe_args.recipe)
        for stage in recipe_obj.stages:
            run_type = stage.infer_run_type()
            if run_type is StageRunType.ONESHOT:
                training_args.do_oneshot = True
            elif run_type is StageRunType.TRAIN:
                training_args.do_train = True

    # Summary on each process
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, "
        f"n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, "
        f"16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = None
    teacher = model_args.distill_teacher
    # distill TODO: support for different processor for teacher?
    model = model_args.model
    if isinstance(model, str) or isinstance(model, PosixPath):
        model, teacher = initialize_model_from_path(
            model_args,
            training_args,
        )

    # patch a shared tensor bug in HF transformers
    # https://github.com/huggingface/transformers/issues/33689
    patch_tied_tensors_bug(model)

    if teacher is not None:
        teacher.eval()

    processor = model_args.processor
    if isinstance(processor, str) or processor is None:
        processor = initialize_processor_from_path(model_args, model, teacher)

    pre_initialize_structure(model=model)

    # initialize session manager
    initialize_recipe(model, None)

    # Load datasets
    stage_runner = StageRunner(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        recipe_args=recipe_args,
    )
    add_labels = training_args.do_train or training_args.run_stages
    stage_runner.populate_datasets(processor=processor, add_labels=add_labels)
    train_dataset = stage_runner.get_dataset_split("train")
    calib_dataset = stage_runner.get_dataset_split("calibration")

    trainer = Trainer(
        model_init=get_session_model,
        teacher=teacher,
        recipe=recipe_args.recipe,
        recipe_args=recipe_args.recipe_args,
        args=training_args,
        model_args=model_args,
        data_args=data_args,
        train_dataset=train_dataset or calib_dataset,
        processing_class=processor,
        data_collator=data_args.data_collator,
    )

    # wrap model.save_pretrained
    if is_fsdp_model(model):
        raise NotImplementedError(
            "FSDP models are not supported in the current release but will be "
            "supported in future releases of LLM Compressor"
        )
    else:
        modify_save_pretrained(model)

    stage_runner.trainer = trainer

    # alternating Training/One-shot
    if training_args.run_stages:
        checkpoint = None
        if last_checkpoint is not None:
            checkpoint = last_checkpoint
        stage_runner.run_sequential_stages(checkpoint)

        # exit immediately
        return

    # Training
    if training_args.do_train:
        checkpoint = None
        if training_args.resume_from_checkpoint is not None:
            checkpoint = training_args.resume_from_checkpoint
        elif last_checkpoint is not None:
            checkpoint = last_checkpoint
        stage_runner.train(checkpoint)

        # save if model was provided as a string or custom output_dir was set
        if isinstance(model_args.model, str) or (
            training_args.output_dir
            != TrainingArguments.__dataclass_fields__["output_dir"].default
            and trainer.accelerator.is_main_process
        ):
            save_checkpoint(
                save_path=training_args.output_dir,
                model=model,
                processor=processor,
                save_safetensors=True,
                save_compressed=model_args.save_compressed,
            )
        trainer.accelerator.wait_for_everyone()

    # Clean up the CompressionSession before exit if requested
    if recipe_args.clear_sparse_session:
        reset_session()
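

# Example CLI invocation (illustrative): when run as a script, apply() parses flags
# that mirror the argument dataclass fields via HfArgumentParser; the values below
# are placeholders.
#
#   python text_generation.py \
#       --model path/or/hub-id-of-model \
#       --recipe recipe.yaml \
#       --output_dir ./compressed-output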
if __name__ == "__main__":
    apply()