[Cosmetic] Rename data_args to dataset_args #1206

Merged
merged 10 commits on Mar 5, 2025
8 changes: 4 additions & 4 deletions examples/trl_mixin/ex_trl_distillation.py
@@ -19,12 +19,12 @@
max_seq_length = 512

# Load gsm8k using SparseML dataset tools
data_args = DatasetArguments(
dataset_args = DatasetArguments(
dataset="gsm8k", dataset_config_name="main", max_seq_length=max_seq_length
)
dataset_manager = TextGenerationDataset.load_from_registry(
data_args.dataset,
data_args=data_args,
dataset_args.dataset,
dataset_args=dataset_args,
split="train",
processor=tokenizer,
)
@@ -69,7 +69,7 @@
train_dataset=train_dataset,
data_collator=data_collator,
trl_sft_config_args=trl_sft_config_args,
data_args=data_args,
dataset_args=dataset_args,
model_args=model_args,
)
trainer.train()
26 changes: 15 additions & 11 deletions src/llmcompressor/entrypoints/oneshot.py
@@ -35,7 +35,7 @@ class Oneshot:
`kwargs` are parsed into:
- `model_args`: Arguments for loading and configuring a pretrained model
(e.g., `AutoModelForCausalLM`).
- `data_args`: Arguments for dataset-related configurations, such as
- `dataset_args`: Arguments for dataset-related configurations, such as
calibration dataloaders.
- `recipe_args`: Arguments for defining and configuring recipes that specify
optimization actions.
@@ -108,24 +108,23 @@ def __init__(
"""
Initializes the `Oneshot` class with provided arguments.

Parses the input keyword arguments into `model_args`, `data_args`, and
Parses the input keyword arguments into `model_args`, `dataset_args`, and
`recipe_args`. Performs preprocessing to initialize the model and
tokenizer/processor.

:param model_args: ModelArguments parameters, responsible for controlling
model loading and saving logic
:param data_args: DatasetArguments parameters, responsible for controlling
:param dataset_args: DatasetArguments parameters, responsible for controlling
dataset loading, preprocessing and dataloader loading
:param recipe_args: RecipeArguments parameters, responsible for containing
recipe-related parameters
:param output_dir: Path to save the output model after carrying out oneshot

"""

model_args, dataset_args, recipe_args, _, output_dir = parse_args(**kwargs)

self.model_args = model_args
self.data_args = dataset_args
self.dataset_args = dataset_args
self.recipe_args = recipe_args
self.output_dir = output_dir

@@ -136,14 +135,19 @@ def __init__(

@classmethod
def from_args(
cls, model_args, data_args, recipe_args, output_dir, do_preprocess: bool = True
cls,
model_args,
dataset_args,
recipe_args,
output_dir,
do_preprocess: bool = True,
):
"""
Used only for the stage runner to populate the args.
"""
instance = super().__new__(cls)
instance.model_args = model_args
instance.data_args = data_args
instance.dataset_args = dataset_args
instance.recipe_args = recipe_args
instance.output_dir = output_dir

@@ -176,7 +180,7 @@ def __call__(self):
self.processor = self.model_args.processor

calibration_dataloader = get_calibration_dataloader(
self.data_args, self.processor
self.dataset_args, self.processor
)
self.apply_recipe_modifiers(
calibration_dataloader=calibration_dataloader,
@@ -242,7 +246,7 @@ def _pre_process(self):
- Applies patches to fix tied tensor issues and modifies `save_pretrained`
behavior.
- Initializes the processor if specified as a path or `None`.
- Sets the minimum tokens per module if `data_args` are provided.
- Sets the minimum tokens per module if `dataset_args` are provided.

Raises:
FileNotFoundError: If the model or processor path is invalid.
@@ -265,8 +269,8 @@ def _pre_process(self):
self.processor = self.model_args.processor

# Set minimum tokens per module if data arguments are provided
if self.data_args:
self.min_tokens_per_module = self.data_args.min_tokens_per_module
if self.dataset_args:
self.min_tokens_per_module = self.dataset_args.min_tokens_per_module

def check_tied_embeddings(self):
"""
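For context on the docstring changes above: after this PR, the keyword arguments passed to `Oneshot` are parsed into `model_args`, `dataset_args` (previously `data_args`), and `recipe_args`. A minimal sketch of the caller-side effect, assuming the module path shown in this diff; the individual keyword names (`model`, `dataset`, `recipe`) are illustrative assumptions drawn from the argument dataclasses rather than something this PR defines:

```python
# Hedged sketch: parse_args() splits kwargs into model_args, dataset_args
# (renamed from data_args by this PR), and recipe_args.
from llmcompressor.entrypoints.oneshot import Oneshot  # module path from this diff

oneshot_job = Oneshot(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # hypothetical model id, parsed into model_args
    dataset="gsm8k",                     # parsed into dataset_args
    recipe="recipe.yaml",                # hypothetical recipe path, parsed into recipe_args
    output_dir="./oneshot_out",
)

# The parsed dataset configuration is now exposed as `dataset_args`:
print(oneshot_job.dataset_args)  # previously: oneshot_job.data_args
oneshot_job()                    # loads the model, runs calibration, applies recipe modifiers
```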
76 changes: 40 additions & 36 deletions src/llmcompressor/transformers/finetune/data/base.py
@@ -31,7 +31,7 @@ class TextGenerationDataset(RegistryMixin):
3. Tokenize dataset using model tokenizer/processor
4. Apply post processing such as grouping text and/or adding labels for finetuning

:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""
@@ -41,11 +41,11 @@ class TextGenerationDataset(RegistryMixin):

def __init__(
self,
data_args: DatasetArguments,
dataset_args: DatasetArguments,
split: str,
processor: Processor,
):
self.data_args = data_args
self.dataset_args = dataset_args
self.split = split
self.processor = processor

@@ -58,23 +58,23 @@ def __init__(
self.tokenizer.pad_token = self.tokenizer.eos_token

# configure sequence length
max_seq_length = data_args.max_seq_length
if data_args.max_seq_length > self.tokenizer.model_max_length:
max_seq_length = dataset_args.max_seq_length
if dataset_args.max_seq_length > self.tokenizer.model_max_length:
logger.warning(
f"The max_seq_length passed ({max_seq_length}) is larger than "
f"maximum length for model ({self.tokenizer.model_max_length}). "
f"Using max_seq_length={self.tokenizer.model_max_length}."
)
self.max_seq_length = min(
data_args.max_seq_length, self.tokenizer.model_max_length
dataset_args.max_seq_length, self.tokenizer.model_max_length
)

# configure padding
self.padding = (
False
if self.data_args.concatenate_data
if self.dataset_args.concatenate_data
else "max_length"
if self.data_args.pad_to_max_length
if self.dataset_args.pad_to_max_length
else False
)

@@ -83,7 +83,7 @@ def __init__(
self.padding = False

def __call__(self, add_labels: bool = True) -> DatasetType:
dataset = self.data_args.dataset
dataset = self.dataset_args.dataset

if isinstance(dataset, str):
# load dataset: load from huggingface or disk
@@ -96,8 +96,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
dataset,
self.preprocess,
batched=False,
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Preprocessing",
)
logger.debug(f"Dataset after preprocessing: {get_columns(dataset)}")
@@ -121,20 +121,20 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
# regardless of `batched` argument
remove_columns=get_columns(dataset), # assumes that input names
# and output names are disjoint
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Tokenizing",
)
logger.debug(f"Model kwargs after tokenizing: {get_columns(dataset)}")

if self.data_args.concatenate_data:
if self.dataset_args.concatenate_data:
# postprocess: group text
dataset = self.map(
dataset,
self.group_text,
batched=True,
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Concatenating data",
)
logger.debug(f"Model kwargs after concatenating: {get_columns(dataset)}")
@@ -145,8 +145,8 @@ def __call__(self, add_labels: bool = True) -> DatasetType:
dataset,
self.add_labels,
batched=False, # not compatible with batching, need row lengths
num_proc=self.data_args.preprocessing_num_workers,
load_from_cache_file=not self.data_args.overwrite_cache,
num_proc=self.dataset_args.preprocessing_num_workers,
load_from_cache_file=not self.dataset_args.overwrite_cache,
desc="Adding labels",
)
logger.debug(f"Model kwargs after adding labels: {get_columns(dataset)}")
@@ -165,27 +165,31 @@ def load_dataset(self):
:param cache_dir: disk location to search for cached dataset
:return: the requested dataset
"""
if self.data_args.dataset_path is not None:
if self.data_args.dvc_data_repository is not None:
self.data_args.raw_kwargs["storage_options"] = {
"url": self.data_args.dvc_data_repository
if self.dataset_args.dataset_path is not None:
if self.dataset_args.dvc_data_repository is not None:
self.dataset_args.raw_kwargs["storage_options"] = {
"url": self.dataset_args.dvc_data_repository
}
self.data_args.raw_kwargs["data_files"] = self.data_args.dataset_path
self.dataset_args.raw_kwargs["data_files"] = (
self.dataset_args.dataset_path
)
else:
self.data_args.raw_kwargs["data_files"] = get_custom_datasets_from_path(
self.data_args.dataset_path,
self.data_args.dataset
if hasattr(self.data_args, "dataset")
else self.data_args.dataset_name,
self.dataset_args.raw_kwargs["data_files"] = (
get_custom_datasets_from_path(
self.dataset_args.dataset_path,
self.dataset_args.dataset
if hasattr(self.dataset_args, "dataset")
else self.dataset_args.dataset_name,
)
)

logger.debug(f"Loading dataset {self.data_args.dataset}")
logger.debug(f"Loading dataset {self.dataset_args.dataset}")
return get_raw_dataset(
self.data_args,
self.dataset_args,
None,
split=self.split,
streaming=self.data_args.streaming,
**self.data_args.raw_kwargs,
streaming=self.dataset_args.streaming,
**self.dataset_args.raw_kwargs,
)

@cached_property
Expand All @@ -194,7 +198,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
The function must return keys which correspond to processor/tokenizer kwargs,
optionally including PROMPT_KEY
"""
preprocessing_func = self.data_args.preprocessing_func
preprocessing_func = self.dataset_args.preprocessing_func

if callable(preprocessing_func):
return preprocessing_func
@@ -218,9 +222,9 @@ def dataset_template(self) -> Union[Callable[[Any], Any], None]:
def rename_columns(self, dataset: DatasetType) -> DatasetType:
# rename columns to match processor/tokenizer kwargs
column_names = get_columns(dataset)
if self.data_args.text_column in column_names and "text" not in column_names:
logger.debug(f"Renaming column `{self.data_args.text_column}` to `text`")
dataset = dataset.rename_column(self.data_args.text_column, "text")
if self.dataset_args.text_column in column_names and "text" not in column_names:
logger.debug(f"Renaming column `{self.dataset_args.text_column}` to `text`")
dataset = dataset.rename_column(self.dataset_args.text_column, "text")

return dataset

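The `TextGenerationDataset` docstring above describes a load, preprocess, tokenize, postprocess pipeline driven entirely by `dataset_args`. A minimal sketch of driving it directly, mirroring the `load_from_registry` call in the example script changed by this PR; the import paths are assumptions for illustration:

```python
# Hedged sketch of the pipeline described in the TextGenerationDataset docstring.
# Only the keyword rename (data_args -> dataset_args) is what this PR changes.
from transformers import AutoTokenizer

from llmcompressor.args import DatasetArguments  # assumed import path
from llmcompressor.transformers.finetune.data import TextGenerationDataset  # assumed import path

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # any HF tokenizer

dataset_args = DatasetArguments(
    dataset="gsm8k", dataset_config_name="main", max_seq_length=512
)

# Step 1: resolve the registered dataset class for "gsm8k"
dataset_manager = TextGenerationDataset.load_from_registry(
    dataset_args.dataset,
    dataset_args=dataset_args,  # keyword renamed from data_args
    split="train",
    processor=tokenizer,
)

# Steps 2-4: preprocess, tokenize, and optionally add labels on __call__
calibration_dataset = dataset_manager(add_labels=False)
```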
14 changes: 8 additions & 6 deletions src/llmcompressor/transformers/finetune/data/c4.py
@@ -13,14 +13,16 @@ class C4Dataset(TextGenerationDataset):
"""
Child text generation class for the C4 dataset

:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""

def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "allenai/c4"
data_args.text_column = "text"
def __init__(
self, dataset_args: "DatasetArguments", split: str, processor: Processor
):
dataset_args = deepcopy(dataset_args)
dataset_args.dataset = "allenai/c4"
dataset_args.text_column = "text"

super().__init__(data_args=data_args, split=split, processor=processor)
super().__init__(dataset_args=dataset_args, split=split, processor=processor)
14 changes: 8 additions & 6 deletions src/llmcompressor/transformers/finetune/data/cnn_dailymail.py
@@ -13,19 +13,21 @@ class CNNDailyMailDataset(TextGenerationDataset):
"""
Text generation class for the CNN/DailyMail dataset

:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
:param processor: processor or tokenizer to use on dataset
"""

SAMPLE_TEMPLATE = "Article:\n{article}\n\n### Summarization:\n{highlights}\n"

def __init__(self, data_args: "DatasetArguments", split: str, processor: Processor):
data_args = deepcopy(data_args)
data_args.dataset = "cnn_dailymail"
data_args.dataset_config_name = "3.0.0"
def __init__(
self, dataset_args: "DatasetArguments", split: str, processor: Processor
):
dataset_args = deepcopy(dataset_args)
dataset_args.dataset = "cnn_dailymail"
dataset_args.dataset_config_name = "3.0.0"

super().__init__(data_args=data_args, split=split, processor=processor)
super().__init__(dataset_args=dataset_args, split=split, processor=processor)

def dataset_template(self, sample):
return {
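The `C4Dataset` and `CNNDailyMailDataset` changes above show the pattern dataset subclasses follow after the rename: deep-copy `dataset_args`, override dataset-specific fields, and pass `dataset_args=` to the parent constructor. A hedged sketch of a hypothetical subclass written in the new style (the dataset id and import path are assumptions; registration with the dataset registry is omitted):

```python
# Hypothetical subclass mirroring the C4Dataset/CNNDailyMailDataset pattern above;
# it only illustrates the data_args -> dataset_args rename.
from copy import deepcopy

from llmcompressor.transformers.finetune.data import TextGenerationDataset  # assumed import path


class TinyStoriesDataset(TextGenerationDataset):
    """Child text generation class for a hypothetical TinyStories dataset."""

    def __init__(self, dataset_args, split: str, processor):
        dataset_args = deepcopy(dataset_args)            # avoid mutating the caller's args
        dataset_args.dataset = "roneneldan/TinyStories"  # hypothetical HF dataset id
        dataset_args.text_column = "text"                # column the base class expects

        # keyword renamed from data_args to dataset_args by this PR
        super().__init__(dataset_args=dataset_args, split=split, processor=processor)
```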
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/data/custom.py
@@ -7,7 +7,7 @@ class CustomDataset(TextGenerationDataset):
Child text generation class for custom local dataset supporting load
for csv and json

:param data_args: configuration settings for dataset loading
:param dataset_args: configuration settings for dataset loading
:param split: split from dataset to load, for instance `test` or `train[:5%]`
Can also be set to None to load all the splits
:param processor: processor or tokenizer to use on dataset