diff --git a/examples/trl_mixin/ex_trl_distillation.py b/examples/trl_mixin/ex_trl_distillation.py index ebd14c5d2..4ebb53276 100644 --- a/examples/trl_mixin/ex_trl_distillation.py +++ b/examples/trl_mixin/ex_trl_distillation.py @@ -19,12 +19,12 @@ max_seq_length = 512 # Load gsm8k using SparseML dataset tools -data_args = DatasetArguments( +dataset_args = DatasetArguments( dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length ) dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train", processor=tokenizer, ) @@ -69,7 +69,7 @@ train_dataset=train_dataset, data_collator=data_collator, trl_sft_config_args=trl_sft_config_args, - data_args=data_args, + dataset_args=dataset_args, model_args=model_args, ) trainer.train() diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index ecdebf46b..ea6481043 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -35,7 +35,7 @@ class Oneshot: `kwargs` are parsed into: - `model_args`: Arguments for loading and configuring a pretrained model (e.g., `AutoModelForCausalLM`). - - `data_args`: Arguments for dataset-related configurations, such as + - `dataset_args`: Arguments for dataset-related configurations, such as calibration dataloaders. - `recipe_args`: Arguments for defining and configuring recipes that specify optimization actions. @@ -108,24 +108,23 @@ def __init__( """ Initializes the `Oneshot` class with provided arguments. - Parses the input keyword arguments into `model_args`, `data_args`, and + Parses the input keyword arguments into `model_args`, `dataset_args`, and `recipe_args`. Performs preprocessing to initialize the model and tokenizer/processor. 
:param model_args: ModelArguments parameters, responsible for controlling model loading and saving logic - :param data_args: DatasetArguments parameters, responsible for controlling + :param dataset_args: DatasetArguments parameters, responsible for controlling dataset loading, preprocessing and dataloader loading :param recipe_args: RecipeArguments parameters, responsible for containing recipe-related parameters :param output_dir: Path to save the output model after carrying out oneshot """ - model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs) self.model_args = model_args - self.data_args = dataset_args + self.dataset_args = dataset_args self.recipe_args = recipe_args self.output_dir = output_dir @@ -136,14 +135,19 @@ def __init__( @classmethod def from_args( - cls, model_args, data_args, recipe_args, output_dir, do_preprocess: bool = True + cls, + model_args, + dataset_args, + recipe_args, + output_dir, + do_preprocess: bool = True, ): """ Used only for the stage runner to populate the args. """ instance = super().__new__(cls) instance.model_args = model_args - instance.data_args = data_args + instance.dataset_args = dataset_args instance.recipe_args = recipe_args instance.output_dir = output_dir @@ -176,7 +180,7 @@ def __call__(self): self.processor = self.model_args.processor calibration_dataloader = get_calibration_dataloader( - self.data_args, self.processor + self.dataset_args, self.processor ) self.apply_recipe_modifiers( calibration_dataloader=calibration_dataloader, @@ -242,7 +246,7 @@ def _pre_process(self): - Applies patches to fix tied tensor issues and modifies `save_pretrained` behavior. - Initializes the processor if specified as a path or `None`. - - Sets the minimum tokens per module if `data_args` are provided. + - Sets the minimum tokens per module if `dataset_args` are provided. Raises: FileNotFoundError: If the model or processor path is invalid. 
@@ -265,8 +269,8 @@ def _pre_process(self): self.processor = self.model_args.processor # Set minimum tokens per module if data arguments are provided - if self.data_args: - self.min_tokens_per_module = self.data_args.min_tokens_per_module + if self.dataset_args: + self.min_tokens_per_module = self.dataset_args.min_tokens_per_module def check_tied_embeddings(self): """ diff --git a/src/llmcompressor/transformers/finetune/data/base.py b/src/llmcompressor/transformers/finetune/data/base.py index bde819e43..593682cbd 100644 --- a/src/llmcompressor/transformers/finetune/data/base.py +++ b/src/llmcompressor/transformers/finetune/data/base.py @@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin): 3. Tokenize dataset using model tokenizer/processor 4. Apply post processing such as grouping text and/or adding labels for finetuning - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin): def __init__( self, - data_args: DatasetArguments, + dataset_args: DatasetArguments, split: str, processor: Processor, ): - self.data_args = data_args + self.dataset_args = dataset_args self.split = split self.processor = processor @@ -58,23 +58,23 @@ def __init__( self.tokenizer.pad_token = self.tokenizer.eos_token # configure sequence length - max_seq_length = data_args.max_seq_length - if data_args.max_seq_length > self.tokenizer.model_max_length: + max_seq_length = dataset_args.max_seq_length + if dataset_args.max_seq_length > self.tokenizer.model_max_length: logger.warning( f"The max_seq_length passed ({max_seq_length}) is larger than " f"maximum length for model ({self.tokenizer.model_max_length}). " f"Using max_seq_length={self.tokenizer.model_max_length}." 
) self.max_seq_length = min( - data_args.max_seq_length, self.tokenizer.model_max_length + dataset_args.max_seq_length, self.tokenizer.model_max_length ) # configure padding self.padding = ( False - if self.data_args.concatenate_data + if self.dataset_args.concatenate_data else "max_length" - if self.data_args.pad_to_max_length + if self.dataset_args.pad_to_max_length else False ) @@ -83,7 +83,7 @@ def __init__( self.padding = False def __call__(self, add_labels: bool = True) -> DatasetType: - dataset = self.data_args.dataset + dataset = self.dataset_args.dataset if isinstance(dataset, str): # load dataset: load from huggingface or disk @@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.preprocess, batched=False, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Preprocessing", ) logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}") @@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType: # regardless of `batched` argument remove_columns=get_columns(dataset), # assumes that input names # and output names are disjoint - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Tokenizing", ) logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}") - if self.data_args.concatenate_data: + if self.dataset_args.concatenate_data: # postprocess: group text dataset = self.map( dataset, self.group_text, batched=True, - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not 
self.dataset_args.overwrite_cache, desc="Concatenating data", ) logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}") @@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType: dataset, self.add_labels, batched=False, # not compatible with batching, need row lengths - num_proc=self.data_args.preprocessing_num_workers, - load_from_cache_file=not self.data_args.overwrite_cache, + num_proc=self.dataset_args.preprocessing_num_workers, + load_from_cache_file=not self.dataset_args.overwrite_cache, desc="Adding labels", ) logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}") @@ -165,27 +165,31 @@ def load_dataset(self): :param cache_dir: disk location to search for cached dataset :return: the requested dataset """ - if self.data_args.dataset_path is not None: - if self.data_args.dvc_data_repository is not None: - self.data_args.raw_kwargs["storage_options"] = { - "url": self.data_args.dvc_data_repository + if self.dataset_args.dataset_path is not None: + if self.dataset_args.dvc_data_repository is not None: + self.dataset_args.raw_kwargs["storage_options"] = { + "url": self.dataset_args.dvc_data_repository } - self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path + self.dataset_args.raw_kwargs["data_files"] = ( + self.dataset_args.dataset_path + ) else: - self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path( - self.data_args.dataset_path, - self.data_args.dataset - if hasattr(self.data_args, "dataset") - else self.data_args.dataset_name, + self.dataset_args.raw_kwargs["data_files"] = ( + get_custom_datasets_from_path( + self.dataset_args.dataset_path, + self.dataset_args.dataset + if hasattr(self.dataset_args, "dataset") + else self.dataset_args.dataset_name, + ) ) - logger.debug(f"Loading dataset {self.data_args.dataset}") + logger.debug(f"Loading dataset {self.dataset_args.dataset}") return get_raw_dataset( - self.data_args, + self.dataset_args, None, split=self.split, - 
streaming=self.data_args.streaming, - **self.data_args.raw_kwargs, + streaming=self.dataset_args.streaming, + **self.dataset_args.raw_kwargs, ) @cached_property @@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]: The function must return keys which correspond to processor/tokenizer kwargs, optionally including PROMPT_KEY """ - preprocessing_func = self.data_args.preprocessing_func + preprocessing_func = self.dataset_args.preprocessing_func if callable(preprocessing_func): return preprocessing_func @@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]: def rename_columns(self, dataset: DatasetType) -> DatasetType: # rename columns to match processor/tokenizer kwargs column_names = get_columns(dataset) - if self.data_args.text_column in column_names and "text" not in column_names: - logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`") - dataset = dataset.rename_column(self.data_args.text_column, "text") + if self.dataset_args.text_column in column_names and "text" not in column_names: + logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`") + dataset = dataset.rename_column(self.dataset_args.text_column, "text") return dataset diff --git a/src/llmcompressor/transformers/finetune/data/c4.py b/src/llmcompressor/transformers/finetune/data/c4.py index 988a4adc3..e4fe6431c 100644 --- a/src/llmcompressor/transformers/finetune/data/c4.py +++ b/src/llmcompressor/transformers/finetune/data/c4.py @@ -13,14 +13,16 @@ class C4Dataset(TextGenerationDataset): """ Child text generation class for the C4 dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = 
deepcopy(data_args) - data_args.dataset = "allenai/c4" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "allenai/c4" + dataset_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) diff --git a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py index b005724aa..fcc67482f 100644 --- a/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py +++ b/src/llmcompressor/transformers/finetune/data/cnn_dailymail.py @@ -13,19 +13,21 @@ class CNNDailyMailDataset(TextGenerationDataset): """ Text generation class for the CNN/DailyMail dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n" - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "cnn_dailymail" - data_args.dataset_config_name = "3.0.0" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "cnn_dailymail" + dataset_args.dataset_config_name = "3.0.0" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): return { diff --git a/src/llmcompressor/transformers/finetune/data/custom.py b/src/llmcompressor/transformers/finetune/data/custom.py index 
7cff3c1d9..1239e08be 100644 --- a/src/llmcompressor/transformers/finetune/data/custom.py +++ b/src/llmcompressor/transformers/finetune/data/custom.py @@ -7,7 +7,7 @@ class CustomDataset(TextGenerationDataset): Child text generation class for custom local dataset supporting load for csv and json - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` Can also be set to None to load all the splits :param processor: processor or tokenizer to use on dataset diff --git a/src/llmcompressor/transformers/finetune/data/data_helpers.py b/src/llmcompressor/transformers/finetune/data/data_helpers.py index cf9b81f69..bd28de314 100644 --- a/src/llmcompressor/transformers/finetune/data/data_helpers.py +++ b/src/llmcompressor/transformers/finetune/data/data_helpers.py @@ -70,7 +70,7 @@ def format_calibration_data( def get_raw_dataset( - data_args, + dataset_args, cache_dir: Optional[str] = None, streaming: Optional[bool] = False, **kwargs, @@ -84,11 +84,11 @@ def get_raw_dataset( """ raw_datasets = load_dataset( - data_args.dataset, - data_args.dataset_config_name, + dataset_args.dataset, + dataset_args.dataset_config_name, cache_dir=cache_dir, streaming=streaming, - trust_remote_code=data_args.trust_remote_code_data, + trust_remote_code=dataset_args.trust_remote_code_data, **kwargs, ) return raw_datasets @@ -235,26 +235,26 @@ def do_transform(candidate: str) -> bool: def get_calibration_dataloader( - data_args, + dataset_args, processor, add_labels: bool = False, # for oneshot do_oneshot=True, ) -> torch.utils.data.DataLoader: """ - Loads datasets for each flow based on data_args, stores a Dataset for each + Loads datasets for each flow based on dataset_args, stores a Dataset for each enabled flow in self.datasets :param processor: processor or tokenizer to use for dataset tokenization :param add_labels: if True, add labels column 
to dataset splits """ - if data_args.dataset is None: + if dataset_args.dataset is None: logger.info( "Running oneshot without calibration data. This is expected for " "weight-only and dynamic quantization" ) return - splits = data_args.splits + splits = dataset_args.splits tokenized_datasets = {} def _get_split_name(inp_str): @@ -272,9 +272,11 @@ def _get_split_name(inp_str): splits = {_get_split_name(s): s for s in splits} # default to custom dataset if dataset provided isn't a string - registry_id = data_args.dataset if isinstance(data_args.dataset, str) else "custom" + registry_id = ( + dataset_args.dataset if isinstance(dataset_args.dataset, str) else "custom" + ) for split_name, split_str in splits.items(): - dataset = data_args.dataset + dataset = dataset_args.dataset if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: # dataset is already tokenized tokenized_datasets[split_name] = dataset @@ -286,7 +288,7 @@ def _get_split_name(inp_str): dataset_manager = TextGenerationDataset.load_from_registry( registry_id, - data_args=data_args, + dataset_args=dataset_args, split=split_str, processor=processor, ) @@ -301,7 +303,7 @@ def _get_split_name(inp_str): return format_calibration_data( tokenized_dataset=calibration_dataset, - num_calibration_samples=data_args.num_calibration_samples, - do_shuffle=data_args.shuffle_calibration_samples, - collate_fn=data_args.data_collator, + num_calibration_samples=dataset_args.num_calibration_samples, + do_shuffle=dataset_args.shuffle_calibration_samples, + collate_fn=dataset_args.data_collator, ) diff --git a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py index 3de833738..8a7892c13 100644 --- a/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py +++ b/src/llmcompressor/transformers/finetune/data/evolcodealpaca.py @@ -13,7 +13,7 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): """ Child text generation class 
for the Evol Code Alpaca dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -25,12 +25,14 @@ class EvolCodeAlpacaDataset(TextGenerationDataset): "\n\n### Response:\n" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "theblackcat102/evol-codealpaca-v1" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "theblackcat102/evol-codealpaca-v1" + dataset_args.text_column = "text" - super().__init__(data_args, split=split, processor=processor) + super().__init__(dataset_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.EVOL_ALPACA_TEMPLATE.format(instruction=sample["instruction"]) diff --git a/src/llmcompressor/transformers/finetune/data/flickr_30k.py b/src/llmcompressor/transformers/finetune/data/flickr_30k.py index 6e11c3aaf..8ada07a0e 100644 --- a/src/llmcompressor/transformers/finetune/data/flickr_30k.py +++ b/src/llmcompressor/transformers/finetune/data/flickr_30k.py @@ -13,7 +13,7 @@ @TextGenerationDataset.register(name="flickr", alias="flickr30k") class Flickr30K(TextGenerationDataset): """ - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -31,11 +31,13 @@ class Flickr30K(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = 
deepcopy(data_args) - data_args.dataset = "lmms-lab/flickr30k" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "lmms-lab/flickr30k" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) if ( self.tokenizer is not None diff --git a/src/llmcompressor/transformers/finetune/data/gsm8k.py b/src/llmcompressor/transformers/finetune/data/gsm8k.py index 4f61d1726..ae1318571 100644 --- a/src/llmcompressor/transformers/finetune/data/gsm8k.py +++ b/src/llmcompressor/transformers/finetune/data/gsm8k.py @@ -13,19 +13,21 @@ class GSM8KDataset(TextGenerationDataset): """ Child text generation class for the Grade School Math 8k dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ GSM_TEMPLATE = "Question: {question}\nAnswer:" - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "gsm8k" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "gsm8k" + dataset_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): prompt = self.GSM_TEMPLATE.format(question=sample["question"]) diff --git a/src/llmcompressor/transformers/finetune/data/open_platypus.py b/src/llmcompressor/transformers/finetune/data/open_platypus.py index 33c9ddc86..81413e785 100644 --- 
a/src/llmcompressor/transformers/finetune/data/open_platypus.py +++ b/src/llmcompressor/transformers/finetune/data/open_platypus.py @@ -13,7 +13,7 @@ class OpenPlatypusDataset(TextGenerationDataset): """ Child text generation class for the Open Platypus dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -28,11 +28,13 @@ class OpenPlatypusDataset(TextGenerationDataset): "instruction}\n\n### Response:\n", } - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "garage-bAInd/Open-Platypus" - data_args.text_column = "text" - super().__init__(data_args=data_args, split=split, processor=processor) + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "garage-bAInd/Open-Platypus" + dataset_args.text_column = "text" + super().__init__(dataset_args=dataset_args, split=split, processor=processor) def dataset_template(self, sample): if "input" in sample and sample["input"] != "": diff --git a/src/llmcompressor/transformers/finetune/data/ptb.py b/src/llmcompressor/transformers/finetune/data/ptb.py index 7966fe4d0..8f03ad509 100644 --- a/src/llmcompressor/transformers/finetune/data/ptb.py +++ b/src/llmcompressor/transformers/finetune/data/ptb.py @@ -13,18 +13,20 @@ class PtbDataset(TextGenerationDataset): """ Child text generation class for the PTB dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: 
"DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "ptb_text_only" - data_args.text_column = "sentence" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "ptb_text_only" + dataset_args.text_column = "sentence" super().__init__( - data_args=data_args, + dataset_args=dataset_args, split=split, processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py index fad57c076..296eb3db5 100644 --- a/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py +++ b/src/llmcompressor/transformers/finetune/data/ultrachat_200k.py @@ -15,7 +15,7 @@ class UltraChatDataset(TextGenerationDataset): """ Child text generation class for the Ultra Chat 200k dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ @@ -33,15 +33,17 @@ class UltraChatDataset(TextGenerationDataset): "{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" ) - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "HuggingFaceH4/ultrachat_200k" - data_args.text_column = "messages" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "HuggingFaceH4/ultrachat_200k" + dataset_args.text_column = "messages" if split in ["train", "test"]: split += "_sft" - super().__init__(data_args=data_args, split=split, processor=processor) + super().__init__(dataset_args=dataset_args, split=split, processor=processor) if ( self.tokenizer is not None diff 
--git a/src/llmcompressor/transformers/finetune/data/wikitext.py b/src/llmcompressor/transformers/finetune/data/wikitext.py index 868c9d951..73142d671 100644 --- a/src/llmcompressor/transformers/finetune/data/wikitext.py +++ b/src/llmcompressor/transformers/finetune/data/wikitext.py @@ -13,18 +13,20 @@ class WikiTextDataset(TextGenerationDataset): """ Child text generation class for the Open Platypus dataset - :param data_args: configuration settings for dataset loading + :param dataset_args: configuration settings for dataset loading :param split: split from dataset to load, for instance `test` or `train[:5%]` :param processor: processor or tokenizer to use on dataset """ - def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor): - data_args = deepcopy(data_args) - data_args.dataset = "Salesforce/wikitext" - data_args.text_column = "text" + def __init__( + self, dataset_args: "DatasetArguments", split: str, processor: Processor + ): + dataset_args = deepcopy(dataset_args) + dataset_args.dataset = "Salesforce/wikitext" + dataset_args.text_column = "text" super().__init__( - data_args=data_args, + dataset_args=dataset_args, split=split, processor=processor, ) diff --git a/src/llmcompressor/transformers/finetune/runner.py b/src/llmcompressor/transformers/finetune/runner.py index 1735a99b8..75d963aa5 100644 --- a/src/llmcompressor/transformers/finetune/runner.py +++ b/src/llmcompressor/transformers/finetune/runner.py @@ -42,19 +42,19 @@ class StageRunner: - train() :param model_args: Arguments pertaining to model/config/processor - :param data_args: Arguments pertaining to what data to use for different flows + :param dataset_args: Arguments pertaining to what data to use for different flows :param training_args: Arguments pertaining to training loop configuration :model: unwrapped model to run flows on """ def __init__( self, - data_args: "DatasetArguments", + dataset_args: "DatasetArguments", model_args: "ModelArguments", training_args: 
"TrainingArguments", recipe_args: "RecipeArguments", ): - self._data_args = data_args + self._dataset_args = dataset_args self._model_args = model_args self._training_args = training_args self._recipe_args = recipe_args @@ -67,13 +67,13 @@ def __init__( def populate_datasets(self, processor: Processor, add_labels: bool = True): """ - Loads datasets for each flow based on data_args, stores a Dataset for each + Loads datasets for each flow based on dataset_args, stores a Dataset for each enabled flow in self.datasets :param processor: processor or tokenizer to use for dataset tokenization :param add_labels: if True, add labels column to dataset splits """ - if self._data_args.dataset is None: + if self._dataset_args.dataset is None: self.processor = self._model_args.processor logger.info( "Running oneshot without calibration data. This is expected for " @@ -81,7 +81,7 @@ def populate_datasets(self, processor: Processor, add_labels: bool = True): ) return - splits = self._data_args.splits + splits = self._dataset_args.splits tokenized_datasets = {} def _get_split_name(inp_str): @@ -100,12 +100,12 @@ def _get_split_name(inp_str): # default to custom dataset if dataset provided isn't a string registry_id = ( - self._data_args.dataset - if isinstance(self._data_args.dataset, str) + self._dataset_args.dataset + if isinstance(self._dataset_args.dataset, str) else "custom" ) for split_name, split_str in splits.items(): - dataset = self._data_args.dataset + dataset = self._dataset_args.dataset if hasattr(dataset, "column_names") and "input_ids" in dataset.column_names: # dataset is already tokenized tokenized_datasets[split_name] = dataset @@ -113,7 +113,7 @@ def _get_split_name(inp_str): # dataset needs to be tokenized dataset_manager = TextGenerationDataset.load_from_registry( registry_id, - data_args=self._data_args, + dataset_args=self._dataset_args, split=split_str, processor=processor, ) @@ -202,7 +202,7 @@ def run_sequential_stages( oneshot = Oneshot.from_args( 
model_args=self._model_args, - data_args=self._data_args, + dataset_args=self._dataset_args, recipe_args=self._recipe_args, output_dir=self._training_args.output_dir, do_preprocess=do_preprocess, @@ -210,9 +210,9 @@ def run_sequential_stages( calib_data = format_calibration_data( tokenized_dataset=self.get_dataset_split("calibration"), - num_calibration_samples=self._data_args.num_calibration_samples, - do_shuffle=self._data_args.shuffle_calibration_samples, - collate_fn=self._data_args.data_collator, + num_calibration_samples=self._dataset_args.num_calibration_samples, + do_shuffle=self._dataset_args.shuffle_calibration_samples, + collate_fn=self._dataset_args.data_collator, accelerator=self.trainer.accelerator, ) diff --git a/src/llmcompressor/transformers/finetune/session_mixin.py b/src/llmcompressor/transformers/finetune/session_mixin.py index 27882d7d6..f64916e69 100644 --- a/src/llmcompressor/transformers/finetune/session_mixin.py +++ b/src/llmcompressor/transformers/finetune/session_mixin.py @@ -54,14 +54,14 @@ class SessionManagerMixIn: :param recipe: path to recipe file to apply during training :param recipe_args: additional kwargs to use for evaluating recipe - :param data_args: kwargs for configuring dataset loading + :param dataset_args: kwargs for configuring dataset loading :param teacher: optional teacher model to use for distillation """ def __init__( self, recipe: str, - data_args: "DatasetArguments", + dataset_args: "DatasetArguments", model_args: "ModelArguments", teacher: Optional[Union[Module, str]] = None, recipe_args: Optional[Union[Dict[str, Any], str]] = None, @@ -77,7 +77,7 @@ def __init__( self.metadata = None if training_args is not None: - # trl_sft_trainer pathway. Both training_args and data_args + # trl_sft_trainer pathway. Both training_args and dataset_args # have `max_seq_length` which causes collision error. 
This is the # only shared parameter, where training arg is `TRLSFTConfig` that # inherits HuggingFace's `TrainingArguments` @@ -87,7 +87,7 @@ def __init__( training_args_dict.pop("max_seq_length") ) logger.warning( - "Detected `max_seq_length` in both data_args ", + "Detected `max_seq_length` in both dataset_args ", "and training_args. This is expected for TRL in distillation. ", "Updating metadata to `training_args_max_seq_length`", ) @@ -95,7 +95,7 @@ def __init__( self.metadata = self._extract_metadata( metadata_args=METADATA_ARGS, training_args_dict=training_args_dict, - data_args_dict=asdict(data_args) if data_args else {}, + dataset_args_dict=asdict(dataset_args) if dataset_args else {}, ) # setup metrics and session @@ -125,8 +125,8 @@ def __init__( if self.is_fsdp_enabled: self._prepare_model_for_fsdp() - if data_args is not None: - self.min_tokens_per_module = data_args.min_tokens_per_module + if dataset_args is not None: + self.min_tokens_per_module = dataset_args.min_tokens_per_module def initialize_session( self, @@ -459,16 +459,16 @@ def _extract_metadata( self, metadata_args: List[str], training_args_dict: Dict[str, Any], - data_args_dict: Dict[str, Any], + dataset_args_dict: Dict[str, Any], ) -> Dict[str, Any]: metadata = {} - if not training_args_dict.keys().isdisjoint(data_args_dict.keys()): + if not training_args_dict.keys().isdisjoint(dataset_args_dict.keys()): raise ValueError( "Found common keys in `training_args` and `data args`. " "This is prohibitive and may lead to undesired behavior." 
) - args_dict = {**training_args_dict, **data_args_dict} + args_dict = {**training_args_dict, **dataset_args_dict} for arg in metadata_args: if arg not in args_dict.keys(): diff --git a/src/llmcompressor/transformers/finetune/text_generation.py b/src/llmcompressor/transformers/finetune/text_generation.py index 9a3623f60..d03867b85 100644 --- a/src/llmcompressor/transformers/finetune/text_generation.py +++ b/src/llmcompressor/transformers/finetune/text_generation.py @@ -68,18 +68,18 @@ def train(**kwargs): """ CLI entrypoint for running training """ - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) + model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs) training_args.do_train = True - main(model_args, data_args, recipe_args, training_args) + main(model_args, dataset_args, recipe_args, training_args) def eval(**kwargs): """ CLI entrypoint for running evaluation """ - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) + model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs) training_args.do_eval = True - main(model_args, data_args, recipe_args, training_args) + main(model_args, dataset_args, recipe_args, training_args) @deprecated( @@ -99,13 +99,13 @@ def apply(**kwargs): CLI entrypoint for any of training, oneshot """ report_to = kwargs.get("report_to", None) - model_args, data_args, recipe_args, training_args = parse_args(**kwargs) + model_args, dataset_args, recipe_args, training_args = parse_args(**kwargs) training_args.run_stages = True if report_to is None: # user didn't specify any reporters # get rid of the reporters inferred from hugging face training_args.report_to = [] - main(model_args, data_args, recipe_args, training_args) + main(model_args, dataset_args, recipe_args, training_args) def compress(**kwargs): @@ -117,8 +117,8 @@ def parse_args(**kwargs): Parses kwargs by grouping into model, data or training arg groups: * model_args in 
src/llmcompressor/transformers/utils/arg_parser/model_args.py - * data_args in - src/llmcompressor/transformers/utils/arg_parser/data_args.py + * dataset_args in + src/llmcompressor/transformers/utils/arg_parser/dataset_args.py * recipe_args in src/llmcompressor/transformers/utils/arg_parser/recipe_args.py * training_args in @@ -134,7 +134,7 @@ def parse_args(**kwargs): else: parsed_args = parser.parse_dict(kwargs) - model_args, data_args, recipe_args, training_args = parsed_args + model_args, dataset_args, recipe_args, training_args = parsed_args if recipe_args.recipe_args is not None: if not isinstance(recipe_args.recipe_args, dict): arg_dict = {} @@ -144,7 +144,7 @@ def parse_args(**kwargs): recipe_args.recipe_args = arg_dict # raise depreciation warnings - if data_args.remove_columns is not None: + if dataset_args.remove_columns is not None: warnings.warn( "`remove_columns` argument is depreciated. When tokenizing datasets, all " "columns which are invalid inputs the tokenizer will be removed", @@ -158,7 +158,7 @@ def parse_args(**kwargs): model_args.processor = model_args.tokenizer model_args.tokenizer = None - return model_args, data_args, recipe_args, training_args + return model_args, dataset_args, recipe_args, training_args def initialize_model_from_path( @@ -304,7 +304,7 @@ def initialize_processor_from_path( def main( model_args: ModelArguments, - data_args: DatasetArguments, + dataset_args: DatasetArguments, recipe_args: RecipeArguments, training_args: TrainingArguments, ): @@ -326,8 +326,8 @@ def main( :param model_args: Arguments pertaining to which model/config/tokenizer we are going to fine-tune from - :param data_args: Arguments pertaining to what data we are going to input our model - for training + :param dataset_args: Arguments pertaining to what data we are going to input + our model for training :param training_args: Arguments pertaining to training loop configuration """ @@ -384,7 +384,7 @@ def main( # Load datasets stage_runner = 
StageRunner( model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, training_args=training_args, recipe_args=recipe_args, ) @@ -400,10 +400,10 @@ def main( recipe_args=recipe_args.recipe_args, args=training_args, model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, train_dataset=train_dataset or calib_dataset, processing_class=processor, - data_collator=data_args.data_collator, + data_collator=dataset_args.data_collator, ) # wrap model.save_pretrained diff --git a/src/llmcompressor/transformers/tracing/debug.py b/src/llmcompressor/transformers/tracing/debug.py index ccce917a7..2bb399b3c 100644 --- a/src/llmcompressor/transformers/tracing/debug.py +++ b/src/llmcompressor/transformers/tracing/debug.py @@ -63,11 +63,11 @@ def trace( print("Loaded model") # Prepare sample data - data_args = DatasetArguments(**get_dataset_kwargs(modality)) + dataset_args = DatasetArguments(**get_dataset_kwargs(modality)) dataset = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, - split=data_args.splits["calibration"], + dataset_args.dataset, + dataset_args=dataset_args, + split=dataset_args.splits["calibration"], processor=processor, )(add_labels=False) sample_input = next(iter(dataset)) @@ -89,7 +89,7 @@ def trace( "\nAttempting trace\n" f" model_id={model_id}\n" f" model_class={model_class.__name__}\n" - f" dataset={data_args.dataset}\n" + f" dataset={dataset_args.dataset}\n" f" split={dataset.split}\n" f" inputs={sample_input.keys()}\n" f" sequential_targets={sequential_targets}\n" diff --git a/tests/llmcompressor/entrypoints/test_oneshot.py b/tests/llmcompressor/entrypoints/test_oneshot.py index 1d00c828f..ba0cb3a3a 100644 --- a/tests/llmcompressor/entrypoints/test_oneshot.py +++ b/tests/llmcompressor/entrypoints/test_oneshot.py @@ -17,7 +17,7 @@ def test_oneshot_from_args(): output_dir = "bar_output_dir" - model_args, data_args, recipe_args, _, output_dir = parse_args( + model_args, dataset_args, 
recipe_args, _, output_dir = parse_args( model=model, dataset=dataset, recipe=recipe, @@ -26,10 +26,10 @@ def test_oneshot_from_args(): output_dir=output_dir, ) - oneshot = Oneshot.from_args(model_args, data_args, recipe_args, output_dir) + oneshot = Oneshot.from_args(model_args, dataset_args, recipe_args, output_dir) assert oneshot.model == model assert oneshot.model_args is model_args - assert oneshot.data_args is data_args + assert oneshot.dataset_args is dataset_args assert oneshot.recipe_args is recipe_args assert oneshot.model_args is model_args assert oneshot.output_dir is output_dir diff --git a/tests/llmcompressor/transformers/compression/test_quantization.py b/tests/llmcompressor/transformers/compression/test_quantization.py index d7a4bdba7..8a4f46fb5 100644 --- a/tests/llmcompressor/transformers/compression/test_quantization.py +++ b/tests/llmcompressor/transformers/compression/test_quantization.py @@ -123,10 +123,10 @@ def test_quantization_reload(self): assert o_zp.dtype == n_zp.dtype assert torch.equal(o_zp, n_zp) - def _get_dataloader(self, data_args, tokenizer): + def _get_dataloader(self, dataset_args, tokenizer): dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train_gen[:5%]", processor=tokenizer, ) @@ -145,11 +145,11 @@ def test_perplexity(self): if self.ppl_threshold is None: pytest.skip("Skipping perplexity calculation.") tokenizer = AutoTokenizer.from_pretrained(self.model_stub) - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="ultrachat-200k", max_seq_length=self.max_seq_length, ) - dataloader = self._get_dataloader(data_args, tokenizer) + dataloader = self._get_dataloader(dataset_args, tokenizer) total_ppl = 0.0 total_non_nan = 0 diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py index 
aa55a752c..39165ffe6 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_helpers.py @@ -9,10 +9,10 @@ @pytest.mark.unit def test_combined_datasets(): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) - raw_wikitext2 = get_raw_dataset(data_args) + raw_wikitext2 = get_raw_dataset(dataset_args) datasets = {"all": raw_wikitext2} split_datasets = make_dataset_splits(datasets, do_train=True) assert split_datasets.get("train") is not None @@ -23,13 +23,13 @@ @pytest.mark.unit def test_separate_datasets(): splits = {"train": "train[:5%]", "validation": "train[5%:7%]"} - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) datasets = {} for split_name, split_str in splits.items(): - raw_wikitext2 = get_raw_dataset(data_args, split=split_str) + raw_wikitext2 = get_raw_dataset(dataset_args, split=split_str) datasets[split_name] = raw_wikitext2 split_datasets = make_dataset_splits(datasets, do_train=True) diff --git a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py index db539a74c..7198e0da3 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py +++ b/tests/llmcompressor/transformers/finetune/data/test_dataset_loading.py @@ -21,7 +21,7 @@ @pytest.mark.unit class TestConcentrationTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -33,8 +33,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_concatenation_tokenization(self): wiki_manager = 
TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[:5%]", processor=self.tiny_llama_tokenizer, ) @@ -54,7 +54,7 @@ def test_concatenation_tokenization(self): @pytest.mark.unit class TestNoPaddingTokenization(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="open_platypus", pad_to_max_length=False ) @@ -65,8 +65,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_no_padding_tokenization(self): op_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[5%:7%]", processor=self.tiny_llama_tokenizer, ) @@ -75,7 +75,7 @@ def test_no_padding_tokenization(self): dataset, op_manager.preprocess, batched=False, - num_proc=op_manager.data_args.preprocessing_num_workers, + num_proc=op_manager.dataset_args.preprocessing_num_workers, ) dataset = op_manager.rename_columns(dataset) # rename self.assertGreater(len(dataset), 0) @@ -97,7 +97,9 @@ def test_no_padding_tokenization(self): @pytest.mark.unit class TestMaxSeqLenClipped(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments(dataset="open_platypus", max_seq_length=4096) + self.dataset_args = DatasetArguments( + dataset="open_platypus", max_seq_length=4096 + ) @pytest.fixture(autouse=True) def prepare_fixture(self, tiny_llama_tokenizer): @@ -105,8 +107,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_max_seq_len_clipped(self): op_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[95%:]", processor=self.tiny_llama_tokenizer, ) @@ -119,7 +121,7 @@ def test_max_seq_len_clipped(self): 
@pytest.mark.unit class TestDatasetKwargsAndPercent(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", raw_kwargs={ "data_files": { @@ -134,16 +136,16 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_dataset_kwargs_and_percentages(self): c4_manager_a = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[5%:6%]", processor=self.tiny_llama_tokenizer, ) raw_dataset_a = c4_manager_a.load_dataset() c4_manager_b = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[6%:8%]", processor=self.tiny_llama_tokenizer, ) @@ -166,15 +168,15 @@ def prepare_fixture(self, tiny_llama_tokenizer): ] ) def test_datasets(self, dataset_key, dataset_config, split, do_concat): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset=dataset_key, dataset_config_name=dataset_config, concatenate_data=do_concat, trust_remote_code_data=True, ) manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=split, processor=self.tiny_llama_tokenizer, ) @@ -205,7 +207,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): self.tiny_llama_tokenizer = tiny_llama_tokenizer def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="evolcodealpaca", dataset_config_name=None, concatenate_data=False, @@ -213,8 +215,8 @@ def setUp(self): def test_evol(self): evol_manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train[:2%]", processor=self.tiny_llama_tokenizer, ) @@ -234,7 +236,7 @@ def test_evol(self): 
@pytest.mark.unit class TestStreamLoading(unittest.TestCase): def setUp(self): - self.data_args = DatasetArguments( + self.dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1", concatenate_data=True, @@ -247,8 +249,8 @@ def prepare_fixture(self, tiny_llama_tokenizer): def test_stream_loading(self): manager = TextGenerationDataset.load_from_registry( - self.data_args.dataset, - data_args=self.data_args, + self.dataset_args.dataset, + dataset_args=self.dataset_args, split="train", processor=self.tiny_llama_tokenizer, ) @@ -273,7 +275,7 @@ def prepare_fixture(self, tiny_llama_tokenizer): @parameterized.expand([["train[95%:]"], [{"train": "train[:5%]"}]]) def test_split_loading(self, split_def): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="open_platypus", splits=split_def, trust_remote_code_data=True, @@ -283,7 +285,7 @@ def test_split_loading(self, split_def): recipe_args = RecipeArguments() stage_runner = StageRunner( model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, training_args=training_args, recipe_args=recipe_args, ) @@ -319,7 +321,7 @@ def preprocess(sample): ) stage_runner = StageRunner( model_args=None, - data_args=DatasetArguments( + dataset_args=DatasetArguments( dataset=tokenized_dataset, shuffle_calibration_samples=False ), training_args=TrainingArguments(do_oneshot=True), @@ -337,7 +339,7 @@ def preprocess(sample): calib_dataloader = format_calibration_data( tokenized_dataset=calib_dataset, num_calibration_samples=self.num_calib_samples, - do_shuffle=stage_runner._data_args.shuffle_calibration_samples, + do_shuffle=stage_runner._dataset_args.shuffle_calibration_samples, ) self.assertEqual(len(calib_dataloader), self.num_calib_samples) dataloader_sample = next(iter(calib_dataloader))["input_ids"] diff --git a/tests/llmcompressor/transformers/finetune/data/test_registry.py b/tests/llmcompressor/transformers/finetune/data/test_registry.py index 
694a9b6d3..29895b4a4 100644 --- a/tests/llmcompressor/transformers/finetune/data/test_registry.py +++ b/tests/llmcompressor/transformers/finetune/data/test_registry.py @@ -11,49 +11,49 @@ @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_c4_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="c4", concatenate_data=True) + dataset_args = DatasetArguments(dataset="c4", concatenate_data=True) c4_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(c4_manager, TextGenerationDataset) assert isinstance(c4_manager, C4Dataset) - assert c4_manager.data_args.text_column == "text" + assert c4_manager.dataset_args.text_column == "text" assert not c4_manager.padding - assert c4_manager.max_seq_length == data_args.max_seq_length + assert c4_manager.max_seq_length == dataset_args.max_seq_length @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_wikitext_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset="wikitext", dataset_config_name="wikitext-2-raw-v1" ) wiki_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(wiki_manager, TextGenerationDataset) assert isinstance(wiki_manager, WikiTextDataset) - assert wiki_manager.data_args.text_column == "text" + assert wiki_manager.dataset_args.text_column == "text" assert wiki_manager.padding == "max_length" - assert wiki_manager.max_seq_length == data_args.max_seq_length + assert wiki_manager.max_seq_length == dataset_args.max_seq_length @pytest.mark.usefixtures("tiny_llama_tokenizer") def test_open_platypus_initializes(tiny_llama_tokenizer): - data_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) + 
dataset_args = DatasetArguments(dataset="open_platypus", pad_to_max_length=False) op_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split=None, processor=tiny_llama_tokenizer, ) assert isinstance(op_manager, TextGenerationDataset) assert isinstance(op_manager, OpenPlatypusDataset) - assert op_manager.data_args.text_column == "text" + assert op_manager.dataset_args.text_column == "text" assert not op_manager.padding - assert op_manager.max_seq_length == data_args.max_seq_length + assert op_manager.max_seq_length == dataset_args.max_seq_length diff --git a/tests/llmcompressor/transformers/finetune/test_session_mixin.py b/tests/llmcompressor/transformers/finetune/test_session_mixin.py index 4fa981de9..65e5140bf 100644 --- a/tests/llmcompressor/transformers/finetune/test_session_mixin.py +++ b/tests/llmcompressor/transformers/finetune/test_session_mixin.py @@ -15,7 +15,7 @@ def __init__( recipe: Optional[str], recipe_args: Optional[Union[Dict[str, Any], str]] = None, model_args: Optional[Union[Dict[str, Any], str]] = None, - data_args: Optional[Union[Dict[str, Any], str]] = None, + dataset_args: Optional[Union[Dict[str, Any], str]] = None, teacher: Optional[Union[Module, str]] = None, **kwargs, ): @@ -24,7 +24,7 @@ def __init__( recipe=recipe, recipe_args=recipe_args, model_args=model_args, - data_args=data_args, + dataset_args=dataset_args, teacher=teacher, **kwargs, ) diff --git a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py index 1016cf422..5528a443e 100644 --- a/tests/llmcompressor/transformers/obcq/test_obcq_completion.py +++ b/tests/llmcompressor/transformers/obcq/test_obcq_completion.py @@ -26,14 +26,14 @@ def labeled_dataloader(self, dataset_name, model_name): from llmcompressor.transformers.finetune.data import TextGenerationDataset tokenizer = 
AutoTokenizer.from_pretrained(model_name) - data_args = DatasetArguments( + dataset_args = DatasetArguments( dataset=dataset_name, max_seq_length=512, pad_to_max_length=False, ) dataset_manager = TextGenerationDataset.load_from_registry( - data_args.dataset, - data_args=data_args, + dataset_args.dataset, + dataset_args=dataset_args, split="train", processor=tokenizer, )