diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index ff3ddf000..d1e392e75 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -1,9 +1,9 @@ from sft_trainer import SFTTrainer from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator -from llmcompressor.transformers import ( - DataTrainingArguments, - TextGenerationDataset, +from llmcompressor.transformers import TextGenerationDataset +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, TrainingArguments, ) @@ -21,7 +21,7 @@ tokenizer = AutoTokenizer.from_pretrained(model_path) # Load gsm8k using SparseML dataset tools -data_args = DataTrainingArguments( +data_args = DatasetArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=512 ) dataset_manager = TextGenerationDataset.load_from_registry( diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index fa8e434d4..30c97df7a 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -8,12 +8,12 @@ from datasets.formatting.formatting import LazyRow from loguru import logger -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( LABELS_MASK_VALUE, get_custom_datasets_from_path, get_raw_dataset, ) +from llmcompressor.transformers.utils.arg_parser import DatasetArguments from llmcompressor.transformers.utils.preprocessing_functions import ( PreprocessingFunctionRegistry, ) @@ -41,7 +41,7 @@ class TextGenerationDataset(RegistryMixin): def __init__( self, - data_args: DataTrainingArguments, + data_args: DatasetArguments, split: str, processor: Processor, ): diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index e50d4d0c6..bf3feeee7 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="c4") @@ -18,7 +18,7 @@ class C4Dataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "allenai/c4" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index 06ad3ecfa..506f760d0 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="cnn_dailymail") @@ -20,7 +20,7 @@ class CNNDailyMailDataset(TextGenerationDataset): SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, 
data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "cnn_dailymail" data_args.dataset_config_name = "3.0.0" diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 932bfa54c..ca3caec03 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="evolcodealpaca") @@ -25,7 +25,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "theblackcat102/evol-codealpaca-v1" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index f19b053e1..4528c5340 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -7,7 +7,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="flickr", alias="flickr30k") @@ -31,7 +31,7 @@ class Flickr30K(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "lmms-lab/flickr30k" diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index beae5dfec..8ee26145d 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="gsm8k") @@ -20,7 +20,7 @@ class GSM8KDataset(TextGenerationDataset): GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "gsm8k" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 3b25986ca..0dbf064e5 100644 --- a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from 
llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="open_platypus") @@ -28,7 +28,7 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "garage-bAInd/Open-Platypus" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index c7f0bbac1..db0be0599 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="ptb") @@ -18,7 +18,7 @@ class PtbDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "ptb_text_only" data_args.text_column = "sentence" diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index 62c012e83..f914ae5d4 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -7,7 +7,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="ultrachat_200k") @@ -33,7 +33,7 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "HuggingFaceH4/ultrachat_200k" data_args.text_column = "messages" diff --git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index a559399d8..5e58c3c94 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -5,7 +5,7 @@ from llmcompressor.typing import Processor if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments as DataArgs + from llmcompressor.transformers.utils.arg_parser import DatasetArguments @TextGenerationDataset.register(name="wikitext") @@ -18,7 +18,7 @@ class WikiTextDataset(TextGenerationDataset): :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DataArgs", split: str, processor: Processor): + def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): data_args = deepcopy(data_args) data_args.dataset = "Salesforce/wikitext" data_args.text_column = "text" diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 0a07c45eb..c1aec5164 100644 --- 
a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -16,13 +16,20 @@ from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.recipe import Recipe, StageRunType from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, make_dataset_splits, ) -from llmcompressor.transformers.finetune.model_args import ModelArguments -from llmcompressor.transformers.finetune.training_args import TrainingArguments +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, +) +from llmcompressor.transformers.utils.arg_parser.training_arguments import ( + DEFAULT_OUTPUT_DIR, +) +from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict from llmcompressor.typing import Processor from llmcompressor.utils.fsdp.helpers import is_fsdp_model, save_model_and_recipe @@ -46,13 +53,15 @@ class StageRunner: def __init__( self, - data_args: "DataTrainingArguments", + data_args: "DatasetArguments", model_args: "ModelArguments", training_args: "TrainingArguments", + recipe_args: "RecipeArguments", ): self._data_args = data_args self._model_args = model_args self._training_args = training_args + self._recipe_args = recipe_args self.datasets = {} self.trainer = None @@ -214,7 +223,7 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): :param checkpoint: optional checkpoint to pick up a stage from """ - recipe_obj = Recipe.create_instance(self._training_args.recipe) + recipe_obj = Recipe.create_instance(self._recipe_args.recipe) with self.trainer.accelerator.main_process_first(): checkpoint_dir = self._model_args.model completed_stages = get_completed_stages(checkpoint_dir) @@ -251,21 +260,30 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None): # run stage if run_type is StageRunType.ONESHOT: - self.one_shot(stage=stage_name) + from llmcompressor.transformers.calibration import Oneshot + + model = get_session_model() + self._model_args.model = model + + oneshot = Oneshot( + output_dir=self._training_args.output_dir, + **get_dataclass_as_dict(self._model_args, ModelArguments), + **get_dataclass_as_dict(self._data_args, DatasetArguments), + **get_dataclass_as_dict(self._recipe_args, RecipeArguments), + ) + + oneshot.run(stage_name=stage_name) elif run_type is StageRunType.TRAIN: self.train(checkpoint=checkpoint, stage=stage_name) checkpoint = None - if ( - self._training_args.output_dir - != TrainingArguments.__dataclass_fields__["output_dir"].default - ): + if self._training_args.output_dir != DEFAULT_OUTPUT_DIR: save_model_and_recipe( model=self.trainer.model, save_path=self._output_dir, processor=self.processor, save_safetensors=self._training_args.save_safetensors, - save_compressed=self._training_args.save_compressed, + save_compressed=self._model_args.save_compressed, ) # save stage to checkpoint dir diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 27860aeb4..07b9ba1ef 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -7,13 +7,12 @@ import torch from loguru import logger from torch.nn import Module -from torch.utils.data import DataLoader, 
IterableDataset +from torch.utils.data import IterableDataset from transformers.trainer_callback import TrainerState from transformers.trainer_utils import get_last_checkpoint from llmcompressor.core import ( active_session, - apply, callbacks, create_session, finalize, @@ -36,8 +35,10 @@ from llmcompressor.utils.pytorch import qat_active if TYPE_CHECKING: - from llmcompressor.transformers import DataTrainingArguments - + from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + ) __all__ = [ "SessionManagerMixIn", @@ -68,12 +69,14 @@ def __init__( self, recipe: Optional[str] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, - data_args: Optional["DataTrainingArguments"] = None, + data_args: Optional["DatasetArguments"] = None, + model_args: Optional["ModelArguments"] = None, teacher: Optional[Union[Module, str]] = None, **kwargs, ): self.recipe = recipe self.recipe_args = recipe_args + self.model_args = model_args self.teacher = teacher # parse training and metadata args @@ -374,8 +377,8 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): self.initialize_session(epoch=epoch, checkpoint=checkpoint, stage=stage) # do not save checkpoints as compressed - original_save_compressed = self.args.save_compressed - self.args.save_compressed = False + original_save_compressed = self.model_args.save_compressed + self.model_args.save_compressed = False # train with accelerator self.accelerator.wait_for_everyone() @@ -383,7 +386,7 @@ def train(self, *args, stage: Optional[str] = None, **kwargs): self.accelerator.wait_for_everyone() # restore original setting for saving final model - self.args.save_compressed = original_save_compressed + self.model_args.save_compressed = original_save_compressed # lifecycle self.finalize_session() @@ -428,31 +431,6 @@ def predict(self, *args, **kwargs): return output - def one_shot( - self, calibration_data: Optional[DataLoader] = None, stage: Optional[str] = None - ): - """ - Run oneshot calibration on the active model - - :param stage: which stage of the recipe to run, or None to run whole recipe - :param calib_data: dataloader of calibration data - """ - apply( - recipe=self.recipe, - recipe_stage=stage, - recipe_args=self.recipe_args, - model=self.model, - calib_data=calibration_data, - start=-1, - copy_data=False, - accelerator=self.accelerator, - min_tokens_per_module=self.min_tokens_per_module, - ) - - # log model sparsity - # self.maybe_log_model_sparsification() - self.accelerator.wait_for_everyone() - def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): """ Override of the save_model function and expects it to exist in the parent. 
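For reference, a minimal sketch of the call path that replaces the removed one_shot() trainer method, mirroring the runner.py changes earlier in this patch; the model path, recipe path, dataset, and stage name below are illustrative placeholders rather than part of the patch:

    from llmcompressor.transformers.calibration import Oneshot
    from llmcompressor.transformers.utils.arg_parser import (
        DatasetArguments,
        ModelArguments,
        RecipeArguments,
    )
    from llmcompressor.transformers.utils.arg_parser.utils import get_dataclass_as_dict

    # argument dataclasses as the stage runner holds them (placeholder values)
    model_args = ModelArguments(model="path/to/model")
    data_args = DatasetArguments(dataset="open_platypus", num_calibration_samples=512)
    recipe_args = RecipeArguments(recipe="recipe.yaml")

    # one-shot stages are now dispatched through the standalone Oneshot class
    # instead of SessionManagerMixIn.one_shot()
    oneshot = Oneshot(
        output_dir="./output",
        **get_dataclass_as_dict(model_args, ModelArguments),
        **get_dataclass_as_dict(data_args, DatasetArguments),
        **get_dataclass_as_dict(recipe_args, RecipeArguments),
    )
    oneshot.run(stage_name="sparsity_stage")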
@@ -474,7 +452,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): if not is_fsdp_model(self.model): self.model.save_pretrained( output_dir, - save_compressed=self.args.save_compressed, + save_compressed=self.model_args.save_compressed, safe_serialization=self.args.save_safetensors, ) else: # FSDP model @@ -482,7 +460,7 @@ def save_model(self, output_dir: str, _internal_call=False, _is_oneshot=False): model=self.model, accelerator=self.accelerator, output_dir=output_dir, - save_compressed=self.args.save_compressed, + save_compressed=self.model_args.save_compressed, save_safetensors=self.metadata.get("save_safetensors", False), ) diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 61e6441bb..6c71610a9 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -20,6 +20,7 @@ import os import warnings from pathlib import PosixPath +from typing import Optional from loguru import logger from transformers import ( @@ -40,18 +41,22 @@ parse_dtype, ) from llmcompressor.recipe import Recipe, StageRunType -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments -from llmcompressor.transformers.finetune.model_args import ModelArguments from llmcompressor.transformers.finetune.runner import StageRunner from llmcompressor.transformers.finetune.trainer import Trainer -from llmcompressor.transformers.finetune.training_args import TrainingArguments from llmcompressor.transformers.sparsification.compressed_tensors_utils import ( modify_fsdp_model_save_pretrained, modify_save_pretrained, patch_tied_tensors_bug, ) from llmcompressor.transformers.sparsification.sparse_model import ( - get_shared_processor_src, + get_processor_from_model, +) +from llmcompressor.transformers.utils.arg_parser import ( + DEFAULT_OUTPUT_DIR, + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, ) from llmcompressor.transformers.utils.helpers import ( detect_last_checkpoint, @@ -65,27 +70,33 @@ def train(**kwargs): """ CLI entrypoint for running training """ - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) training_args.do_train = True - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def eval(**kwargs): """ CLI entrypoint for running evaluation """ - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, **kwargs + ) training_args.do_eval = True - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def oneshot(**kwargs): + from llmcompressor.transformers.calibration.oneshot import Oneshot + """ CLI entrypoint for running oneshot calibration """ - model_args, data_args, training_args = parse_args(**kwargs) - training_args.do_oneshot = True - main(model_args, data_args, training_args) + oneshot = Oneshot(**kwargs) + oneshot.run() + return oneshot # alias @@ -97,12 +108,15 @@ def apply(**kwargs): CLI entrypoint for any of training, eval, predict or oneshot """ report_to = kwargs.get("report_to", None) - model_args, data_args, training_args = parse_args(**kwargs) + model_args, data_args, recipe_args, training_args, _ = parse_args( + include_training_args=True, 
**kwargs + ) + training_args.run_stages = True if report_to is None: # user didn't specify any reporters # get rid of the reporters inferred from hugging face training_args.report_to = [] - main(model_args, data_args, training_args) + main(model_args, data_args, recipe_args, training_args) def compress(**kwargs): @@ -111,60 +125,100 @@ def load_dataset(dataset_name: str, **kwargs): parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) + (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments) ) - model_args, data_args, training_args = parser.parse_dict(kwargs) + _, data_args, _, _ = parser.parse_dict(kwargs) data_args["dataset_name"] = dataset_name -def parse_args(**kwargs): +def parse_args(include_training_args: bool = False, **kwargs): """ Parses kwargs by grouping into model, data or training arg groups: - * model_args in src/llmcompressor/transformers/finetune/model_args.py - * data_args in src/llmcompressor/transformers/finetune/data/data_args.py - * training_args in src/llmcompressor/transformers/finetune/training_args.py + * model_args in + src/llmcompressor/transformers/utils/arg_parser/model_arguments.py + * data_args in + src/llmcompressor/transformers/utils/arg_parser/data_arguments.py + * recipe_args in + src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py + * training_args in + src/llmcompressor/transformers/utils/arg_parser/training_arguments.py + + Throws deprecation warnings + + :param include_training_args: Add training_args in the output if set to True. + Note that instantiating training_args will reset HF accelerator and change its + internal state. This dataclass should be instantiated only once to avoid + conflict with Accelerate library's accelerator. - Throws depreciation warnings """ - parser = HfArgumentParser( - (ModelArguments, DataTrainingArguments, TrainingArguments) - ) - if not kwargs: - model_args, data_args, training_args = parser.parse_args_into_dataclasses() + output_dir = kwargs.pop("output_dir", DEFAULT_OUTPUT_DIR) + + if include_training_args: + parser = HfArgumentParser( + (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments) + ) else: - model_args, data_args, training_args = parser.parse_dict(kwargs) + parser = HfArgumentParser((ModelArguments, DatasetArguments, RecipeArguments)) + + if not kwargs: + # if output_dir passed from cli, pop to avoid using training_args + def _get_output_dir_from_argv() -> Optional[str]: + import sys + + output_dir = None + if "--output_dir" in sys.argv: + index = sys.argv.index("--output_dir") + sys.argv.pop(index) + if index < len(sys.argv): # Check if value exists after the flag + output_dir = sys.argv.pop(index) + + return output_dir - if training_args.recipe_args is not None: - if not isinstance(training_args.recipe_args, dict): - arg_dict = {} - for recipe_arg in training_args.recipe_args: - key, value = recipe_arg.split("=") - arg_dict[key] = value - training_args.recipe_args = arg_dict + output_dir = _get_output_dir_from_argv() or output_dir - # raise depreciation warnings + parsed_args = parser.parse_args_into_dataclasses() + else: + parsed_args = parser.parse_dict(kwargs) + + # Unpack parsed arguments based on the presence of training arguments + if include_training_args: + model_args, data_args, recipe_args, training_args = parsed_args + if output_dir is not None: + training_args.output_dir = output_dir + else: + model_args, data_args, recipe_args = parsed_args + training_args = None + + if recipe_args.recipe_args is not None: +
if not isinstance(recipe_args.recipe_args, dict): + recipe_args.recipe_args = { + key: value + for arg in recipe_args.recipe_args + for key, value in [arg.split("=")] + } + + # Raise deprecation warnings if data_args.remove_columns is not None: warnings.warn( - "`remove_columns` argument is depreciated. When tokenizing datasets, all " - "columns which are invalid inputs the tokenizer will be removed", + "`remove_columns` argument is deprecated. When tokenizing datasets, all " + "columns which are invalid inputs to the tokenizer will be removed.", DeprecationWarning, ) - # silently assign tokenizer to processor + # Silently assign tokenizer to processor if model_args.tokenizer: if model_args.processor: - raise ValueError("Cannot use both a tokenizer and processor") + raise ValueError("Cannot use both a tokenizer and processor.") model_args.processor = model_args.tokenizer - model_args.tokenizer = None + model_args.tokenizer = None - return model_args, data_args, training_args + return model_args, data_args, recipe_args, training_args, output_dir def initialize_model_from_path( model_args: ModelArguments, - training_args: TrainingArguments, + training_args: Optional[TrainingArguments] = None, ): - last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) # Load pretrained model # The .from_pretrained methods guarantee that only one local process can # concurrently download model & vocab. @@ -177,16 +231,23 @@ def initialize_model_from_path( tie_word_embeddings=model_args.tie_word_embeddings, trust_remote_code=model_args.trust_remote_code_model, ) - teacher_config = ( - AutoConfig.from_pretrained( - model_args.distill_teacher, - use_auth_token=True if model_args.use_auth_token else None, - tie_word_embeddings=model_args.tie_word_embeddings, - trust_remote_code=model_args.trust_remote_code_model, + + last_checkpoint = None + + if training_args is not None: + teacher_config = ( + AutoConfig.from_pretrained( + model_args.distill_teacher, + use_auth_token=True if model_args.use_auth_token else None, + tie_word_embeddings=model_args.tie_word_embeddings, + trust_remote_code=model_args.trust_remote_code_model, + ) + if model_args.distill_teacher + else None ) - if model_args.distill_teacher - else None - ) + last_checkpoint = detect_last_checkpoint(training_args, model_args=model_args) + # Set seed before initializing model. + set_seed(training_args.seed) model_path = ( last_checkpoint or model_args.model @@ -194,21 +255,18 @@ def initialize_model_from_path( else model_args.model_name_or_path ) - # Set seed before initializing model. 
- set_seed(training_args.seed) - # Fallback to CPU if GPU requested and not available - training_args.oneshot_device = fallback_to_cpu(training_args.oneshot_device) + model_args.oneshot_device = fallback_to_cpu(model_args.oneshot_device) # Trainer handles device assignment for FSDP and training, don't do mapping here # if running oneshot outside of FSDP, apply user device settings - device_map = None + fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" - if not fsdp_enabled and training_args.do_oneshot: - device_map = training_args.oneshot_device - logger.warning(f"Moving {model_path} to device {device_map} for One-Shot") - elif not fsdp_enabled: + + device_map = model_args.oneshot_device + if not fsdp_enabled and training_args is not None and training_args.do_train: device_map = "auto" + model_kwargs = { "config": config, "cache_dir": model_args.cache_dir, @@ -218,15 +276,7 @@ def initialize_model_from_path( "device_map": device_map, "trust_remote_code": model_args.trust_remote_code_model, } - teacher_device_map = None if fsdp_enabled else "auto" - teacher_kwargs = { - "config": teacher_config, - "cache_dir": model_args.cache_dir, - "use_auth_token": True if model_args.use_auth_token else None, - "torch_dtype": parse_dtype(model_args.precision), - "device_map": teacher_device_map, - "trust_remote_code": model_args.trust_remote_code_model, - } + # this calls from_pretrained under the hood so should be FSDP safe # optimized models must be decompressed to carry out oneshot/train/etc @@ -242,25 +292,38 @@ def initialize_model_from_path( if "sequence_length" in model_kwargs: model.seqlen = model_kwargs["sequence_length"] - teacher = ( - AutoModelForCausalLM.from_pretrained( - model_args.distill_teacher, - **teacher_kwargs, + teacher = None + if training_args is not None: + teacher_device_map = None if fsdp_enabled else "auto" + teacher_kwargs = { + "config": teacher_config, + "cache_dir": model_args.cache_dir, + "use_auth_token": True if model_args.use_auth_token else None, + "torch_dtype": parse_dtype(model_args.precision), + "device_map": teacher_device_map, + "trust_remote_code": model_args.trust_remote_code_model, + } + + teacher = ( + AutoModelForCausalLM.from_pretrained( + model_args.distill_teacher, + **teacher_kwargs, + ) + if model_args.distill_teacher is not None + else None ) - if model_args.distill_teacher is not None - else None - ) - if teacher is not None and "sequence_length" in teacher_kwargs: - teacher.seqlen = teacher_kwargs["sequence_length"] + if teacher is not None and "sequence_length" in teacher_kwargs: + teacher.seqlen = teacher_kwargs["sequence_length"] - return teacher, model_path, model + return model, teacher def initialize_processor_from_path( - model_args: ModelArguments, model: PreTrainedModel, teacher: PreTrainedModel + model_args: ModelArguments, + model: PreTrainedModel, + teacher: Optional[PreTrainedModel] = None, ) -> Processor: - processor_src = model_args.processor - processor_src = processor_src or get_shared_processor_src(model, teacher) + processor_src = model_args.processor or get_processor_from_model(model, teacher) # The use_fast=True option is not currently supported safely in Transformers # See: https://github.com/huggingface/transformers/pull/34836#issuecomment-2491809727 # noqa: E501 try: @@ -288,7 +351,8 @@ def initialize_processor_from_path( def main( model_args: ModelArguments, - data_args: DataTrainingArguments, + data_args: DatasetArguments, + recipe_args: RecipeArguments, training_args: TrainingArguments, ): """ @@ 
-323,8 +387,8 @@ def main( ) # Setup based on stage types if running stage mode - if training_args.run_stages and training_args.recipe is not None: - recipe_obj = Recipe.create_instance(training_args.recipe) + if training_args.run_stages and recipe_args.recipe is not None: + recipe_obj = Recipe.create_instance(recipe_args.recipe) for stage in recipe_obj.stages: run_type = stage.infer_run_type() if run_type is StageRunType.ONESHOT: @@ -348,7 +412,7 @@ def main( model = model_args.model if isinstance(model, str) or isinstance(model, PosixPath): - (teacher, _model_path, model) = initialize_model_from_path( + (model, teacher) = initialize_model_from_path( model_args, training_args, ) @@ -371,7 +435,10 @@ def main( # Load datasets stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + model_args=model_args, + data_args=data_args, + training_args=training_args, + recipe_args=recipe_args, ) add_labels = training_args.do_train or training_args.run_stages stage_runner.populate_datasets(processor=processor, add_labels=add_labels) @@ -379,13 +446,13 @@ def main( eval_dataset = stage_runner.get_dataset_split("validation") calib_dataset = stage_runner.get_dataset_split("calibration") - # Initialize our Trainer trainer = Trainer( model_init=get_session_model, teacher=teacher, - recipe=training_args.recipe, - recipe_args=training_args.recipe_args, + recipe=recipe_args.recipe, + recipe_args=recipe_args.recipe_args, args=training_args, + model_args=model_args, data_args=data_args, train_dataset=train_dataset or calib_dataset, eval_dataset=eval_dataset, @@ -437,13 +504,13 @@ def main( != TrainingArguments.__dataclass_fields__["output_dir"].default ): model.save_pretrained( - training_args.output_dir, save_compressed=training_args.save_compressed + training_args.output_dir, save_compressed=model_args.save_compressed ) if processor is not None: processor.save_pretrained(training_args.output_dir) # Clean up the CompressionSession before exit if requested - if training_args.clear_sparse_session: + if recipe_args.clear_sparse_session: reset_session() diff --git a/src/llmcompressor/transformers/utils/arg_parser/__init__.py b/src/llmcompressor/transformers/utils/arg_parser/__init__.py new file mode 100644 index 000000000..cbb9224af --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa + +from .data_arguments import DatasetArguments +from .model_arguments import ModelArguments +from .recipe_arguments import RecipeArguments +from .training_arguments import DEFAULT_OUTPUT_DIR, TrainingArguments diff --git a/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py new file mode 100644 index 000000000..50d3277f4 --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/data_arguments.py @@ -0,0 +1,189 @@ +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Union + +from transformers import DefaultDataCollator + + +@dataclass +class DVCDatasetArguments: + """ + Arguments for training using DVC + """ + + dvc_data_repository: Optional[str] = field( + default=None, + metadata={"help": "Path to repository used for dvc_dataset_path"}, + ) + + +@dataclass +class CustomDatasetArguments(DVCDatasetArguments): + """ + Arguments for training using custom datasets + """ + + dataset_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to the custom dataset. 
Supports json, csv, dvc. " + "For DVC, the path to the dvc dataset to load, in the format dvc://path. " + "For csv or json, the path containing the dataset. " + ), + }, + ) + + text_column: str = field( + default="text", + metadata={ + "help": ( + "Optional key to be used as the `text` input to tokenizer/processor " + "after dataset preprocessing" + ) + }, + ) + + remove_columns: Union[None, str, List] = field( + default=None, + metadata={"help": "Column names to remove after preprocessing (deprecated)"}, + ) + + preprocessing_func: Union[None, str, Callable] = field( + default=None, + metadata={ + "help": ( + "Typically a function which applies a chat template. Can take the form " + "of either a function to apply to the dataset, a name defined in " + "src/llmcompressor/transformers/utils/preprocessing_functions.py, or " + "a path to a function definition of the form /path/to/file.py:func" + ) + }, + ) + + data_collator: Callable[[Any], Any] = field( + default_factory=lambda: DefaultDataCollator(), + metadata={"help": "The function used to form a batch from the dataset"}, + ) + + +@dataclass +class DatasetArguments(CustomDatasetArguments): + """ + Arguments pertaining to what data we are going to input our model for + calibration, training or eval + + Using `HfArgumentParser` we can turn this class into argparse + arguments to be able to specify them on the command line + """ + + dataset: Optional[str] = field( + default=None, + metadata={ + "help": ( + "The name of the dataset to use (via the datasets library). " + "Supports input as a string or DatasetDict from HF" + ) + }, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": ("The configuration name of the dataset to use"), + }, + ) + max_seq_length: int = field( + default=384, + metadata={ + "help": "The maximum total input sequence length after tokenization. " + "Sequences longer than this will be truncated, sequences shorter will " + "be padded." + }, + ) + concatenate_data: bool = field( + default=False, + metadata={ + "help": "Whether or not to concatenate datapoints to fill max_seq_length" + }, + ) + raw_kwargs: Dict = field( + default_factory=dict, + metadata={"help": "Additional keyword args to pass to datasets load_data"}, + ) + splits: Union[None, str, List, Dict] = field( + default=None, + metadata={"help": "Optional percentages of each split to download"}, + ) + num_calibration_samples: Optional[int] = field( + default=512, + metadata={"help": "Number of samples to use for one-shot calibration"}, + ) + shuffle_calibration_samples: Optional[bool] = field( + default=True, + metadata={ + "help": "whether to shuffle the dataset before selecting calibration data" + }, + ) + streaming: Optional[bool] = field( + default=False, + metadata={"help": "True to stream data from a cloud dataset"}, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached preprocessed datasets or not."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. If False, " + "will pad the samples dynamically when batching to the maximum length " + "in the batch (which can be faster on GPU but will be slower on TPU)."
+ }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number " + "of training examples to this value if set." + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number " + "of evaluation examples to this value if set." + }, + ) + max_predict_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of " + "prediction examples to this value if set." + ), + }, + ) + min_tokens_per_module: Optional[float] = field( + default=None, + metadata={ + "help": ( + "The minimum percentage of tokens (out of the total number) " + "that the module should 'receive' throughout the forward " + "pass of the calibration. If a module receives fewer tokens, " + "a warning will be logged. Defaults to 1/num_of_experts. " + "Note: this argument is only relevant for MoE models" + ), + }, + ) + trust_remote_code_data: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for datasets defined on the Hub using " + "a dataset script. This option should only be set to True for " + "repositories you trust and in which you have read the code, as it " + "will execute code present on the Hub on your local machine." + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py new file mode 100644 index 000000000..ce424812a --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/model_arguments.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class ModelArguments: + """ + Model variables used for oneshot calibration, training or finetuning and + stage runners (combination of oneshot and finetune going back and forth) + + """ + + model: str = field( + metadata={ + "help": ( + "A pretrained model or a string as a path to pretrained model, " + "HF stub, or model identifier from huggingface.co/models." + ) + }, + ) + distill_teacher: Optional[str] = field( + default=None, + metadata={ + "help": "Teacher model (a trained text generation model)", + }, + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + processor: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained processor name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pretrained data from huggingface.co"}, + ) + + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use token generated when running `transformers-cli login` " + "(necessary to use this script with private models)" + }, + ) + precision: str = field( + default="auto", + metadata={"help": "Precision to cast model weights to, default to auto"}, + ) + + tie_word_embeddings: bool = field( + default=False, + metadata={ + "help": "Whether the model's input and output word embeddings " + "should be tied. Note that this is only relevant if the " + "model has an output word embedding layer."
+ }, + ) + trust_remote_code_model: bool = field( + default=False, + metadata={ + "help": "Whether or not to allow for custom models to execute their " + "own modeling files. This option should only be set to True for " + "repositories you trust and in which you have read the code" + }, + ) + save_compressed: Optional[bool] = field( + default=True, + metadata={"help": "Whether to compress sparse models during save"}, + ) + oneshot_device: Optional[str] = field( + default="cuda:0", + metadata={"help": "Device to run oneshot calibration on"}, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use " + "(can be a branch name, tag name or commit id)" + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py new file mode 100644 index 000000000..fbe535d7e --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/recipe_arguments.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field +from typing import List, Optional + + +@dataclass +class RecipeArguments: + """Recipe and session variables""" + + recipe: Optional[str] = field( + default=None, + metadata={ + "help": "Path to an LLM Compressor sparsification recipe", + }, + ) + recipe_args: Optional[List[str]] = field( + default=None, + metadata={ + "help": ( + "List of recipe arguments to evaluate, of the format key1=value1 " + "key2=value2" + ) + }, + ) + clear_sparse_session: Optional[bool] = field( + default=False, + metadata={ + "help": ( + "Whether to clear CompressionSession/CompressionLifecycle " + "data between runs." + ) + }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py new file mode 100644 index 000000000..7b61193b0 --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/training_arguments.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments as HFTrainingArgs + +__all__ = ["TrainingArguments", "DEFAULT_OUTPUT_DIR"] + +DEFAULT_OUTPUT_DIR = "./output" + + +@dataclass +class TrainingArguments(HFTrainingArgs): + """ + Training arguments specific to LLM Compressor Transformers workflow using + HFTrainingArgs as base class + + """ + + do_oneshot: Optional[bool] = field( + default=False, + metadata={"help": "Whether to run one-shot calibration in stages"}, + ) + run_stages: Optional[bool] = field( + default=False, metadata={"help": "Whether to trigger recipe stage by stage"} + ) + output_dir: str = field( + default=DEFAULT_OUTPUT_DIR, + metadata={ + "help": "The output directory where the model predictions and " + "checkpoints will be written."
+ }, + ) diff --git a/src/llmcompressor/transformers/utils/arg_parser/utils.py b/src/llmcompressor/transformers/utils/arg_parser/utils.py new file mode 100644 index 000000000..48455fa15 --- /dev/null +++ b/src/llmcompressor/transformers/utils/arg_parser/utils.py @@ -0,0 +1,30 @@ +from dataclasses import fields +from typing import Any, Dict, Union + +from .data_arguments import DatasetArguments +from .model_arguments import ModelArguments +from .recipe_arguments import RecipeArguments +from .training_arguments import TrainingArguments + +__all__ = [ + "get_dataclass_as_dict", +] + + +def get_dataclass_as_dict( + dataclass_instance: Union[ + "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments" + ], + dataclass_class: Union[ + "ModelArguments", "RecipeArguments", "DatasetArguments", "TrainingArguments" + ], +) -> Dict[str, Any]: + """ + Get the dataclass instance attributes as a dict, neglecting the inherited class. + Ex. dataclass_class=TrainingArguments will ignore HFTrainingArguments + + """ + return { + field.name: getattr(dataclass_instance, field.name) + for field in fields(dataclass_class) + } diff --git a/src/llmcompressor/transformers/utils/helpers.py b/src/llmcompressor/transformers/utils/helpers.py index c1dcef119..80c4b446e 100644 --- a/src/llmcompressor/transformers/utils/helpers.py +++ b/src/llmcompressor/transformers/utils/helpers.py @@ -14,7 +14,10 @@ from transformers.trainer_utils import get_last_checkpoint if TYPE_CHECKING: - from llmcompressor.transformers import ModelArguments, TrainingArguments + from llmcompressor.transformers.utils.arg_parser import ( + ModelArguments, + TrainingArguments, + ) __all__ = [ "RECIPE_FILE_NAME", diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index 13eab66c9..cefcdaa54 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -13,7 +13,7 @@ from llmcompressor.pytorch.utils import tensors_to_device from llmcompressor.transformers import oneshot from llmcompressor.transformers.finetune.data import TextGenerationDataset -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.transformers.utils.arg_parser import DatasetArguments from tests.testing_utils import parse_params, requires_gpu CONFIGS_DIRECTORY = "tests/llmcompressor/transformers/compression/configs" @@ -59,10 +59,9 @@ def _run_oneshot(model, recipe, dataset, output_dir): max_seq_length = 512 pad_to_max_length = False - oneshot( + oneshot_run = oneshot( model=model, dataset=dataset, - overwrite_output_dir=True, output_dir=output_dir, max_seq_length=max_seq_length, num_calibration_samples=num_calibration_samples, @@ -72,10 +71,8 @@ splits={"calibration": "train_gen[:5%]"}, save_compressed=False, ) - from llmcompressor.pytorch.model_load.helpers import get_session_model - # note: get_session_model() is None outside of function scope - return get_session_model() + return oneshot_run.model def _get_quant_info(self, model): quant_info_weights = {} @@ -147,7 +144,7 @@ def _get_dataloader(self, data_args, tokenizer): @torch.no_grad() def test_perplexity(self): tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) diff --git 
a/tests/llmcompressor/transformers/finetune/data/conftest.py b/tests/llmcompressor/transformers/finetune/data/conftest.py index a7a347d99..a4182721d 100644 --- a/tests/llmcompressor/transformers/finetune/data/conftest.py +++ b/tests/llmcompressor/transformers/finetune/data/conftest.py @@ -1,7 +1,7 @@ import pytest from transformers import AutoTokenizer -from llmcompressor.transformers.finetune.model_args import ModelArguments +from llmcompressor.transformers.utils.arg_parser import ModelArguments @pytest.fixture diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 812b26a56..4b907b6a0 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -1,15 +1,15 @@ import pytest -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments from llmcompressor.transformers.finetune.data.data_helpers import ( get_raw_dataset, make_dataset_splits, ) +from llmcompressor.transformers.utils.arg_parser import DatasetArguments @pytest.mark.unit def test_combined_datasets(): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) raw_wikitext2 = get_raw_dataset(data_args) @@ -33,7 +33,7 @@ def test_combined_datasets(): @pytest.mark.unit def test_separate_datasets(): splits = {"train": "train[:10%]", "validation": "train[10%:20%]"} - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index 64514b252..75be8102c 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -5,22 +5,23 @@ from datasets import IterableDataset, load_dataset from parameterized import parameterized -from llmcompressor.transformers import ( - DataTrainingArguments, - ModelArguments, - TextGenerationDataset, - TrainingArguments, -) +from llmcompressor.transformers import TextGenerationDataset from llmcompressor.transformers.finetune.data.data_helpers import ( format_calibration_data, ) from llmcompressor.transformers.finetune.runner import StageRunner +from llmcompressor.transformers.utils.arg_parser import ( + DatasetArguments, + ModelArguments, + RecipeArguments, + TrainingArguments, +) @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -53,7 +54,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -96,9 +97,7 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( - dataset="open_platypus", max_seq_length=4096 - ) + self.data_args = DatasetArguments(dataset="open_platypus", max_seq_length=4096) @pytest.fixture(autouse=True) def prepare_fixture(self, 
tiny_llama_tokenizer): @@ -120,7 +119,7 @@ def test_max_seq_len_clipped(self): @pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -167,7 +166,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, @@ -206,7 +205,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -235,7 +234,7 @@ def test_evol(self): @pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DataTrainingArguments( + self.data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -276,15 +275,19 @@ def prepare_fixture(self, tiny_llama_tokenizer): [["train"], ["train[60%:]"], [{"train": "train[:20%]"}], [None]] ) def test_split_loading(self, split_def): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, ) training_args = TrainingArguments(do_train=True, output_dir="dummy") model_args = ModelArguments(model=None) + recipe_args = RecipeArguments() stage_runner = StageRunner( - model_args=model_args, data_args=data_args, training_args=training_args + model_args=model_args, + data_args=data_args, + training_args=training_args, + recipe_args=recipe_args, ) stage_runner.populate_datasets(processor=self.tiny_llama_tokenizer) @@ -318,10 +321,11 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DataTrainingArguments( + data_args=DatasetArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), + recipe_args=RecipeArguments(), ) stage_runner.populate_datasets(processor=None) calib_dataset = stage_runner.get_dataset_split("calibration") diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 9aee4c20f..11dc9034f 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -6,12 +6,12 @@ TextGenerationDataset, WikiTextDataset, ) -from llmcompressor.transformers.finetune.data.data_args import DataTrainingArguments +from llmcompressor.transformers.utils.arg_parser import DatasetArguments @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="c4", concatenate_data=True) + data_args = DatasetArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, @@ -27,7 +27,7 @@ def test_c4_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_wikitext_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) wiki_manager = TextGenerationDataset.load_from_registry( @@ 
-45,7 +45,7 @@ def test_wikitext_initializes(tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_open_platypus_initializes(tiny_llama_tokenizer): - data_args = DataTrainingArguments(dataset="open_platypus", pad_to_max_length=False) + data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) op_manager = TextGenerationDataset.load_from_registry( data_args.dataset, data_args=data_args, diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index fe699570a..6a6fc9bf3 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -23,12 +23,10 @@ def labeled_dataloader(self, dataset_name, model_name): from transformers import AutoTokenizer, DefaultDataCollator from llmcompressor.transformers.finetune.data import TextGenerationDataset - from llmcompressor.transformers.finetune.data.data_args import ( - DataTrainingArguments, - ) + from llmcompressor.transformers.utils.arg_parser import DatasetArguments tokenizer = AutoTokenizer.from_pretrained(model_name) - data_args = DataTrainingArguments( + data_args = DatasetArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False,
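For reference, a minimal sketch of how the reworked oneshot() entrypoint is used after this change: it now constructs and runs an Oneshot object and returns it, so callers read the calibrated model from the return value instead of from get_session_model(), as the updated test_quantization.py above does. The model stub and recipe path are illustrative placeholders:

    from llmcompressor.transformers import oneshot

    oneshot_run = oneshot(
        model="path/to/model",        # placeholder model path or HF stub
        dataset="open_platypus",
        recipe="recipe.yaml",         # placeholder recipe path
        output_dir="./output",
        num_calibration_samples=512,
        pad_to_max_length=False,
        save_compressed=False,
    )
    calibrated_model = oneshot_run.model  # replaces the old get_session_model() lookup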