Showing 25 changed files with 571 additions and 89 deletions.
@@ -0,0 +1,45 @@
# Input arguments for `oneshot`, `train`, `eval` entrypoints

Parsers in `llm-compressor` define the input arguments required for various entry points, including `oneshot`, `train`, and `eval`.

Each entry point (e.g., `oneshot`) carries out its logic based on the provided input arguments `model`, `recipe`, and `dataset`:

```python
from llmcompressor.transformers import oneshot

model = ...
recipe = ...
dataset = ...
oneshot(model=model, recipe=recipe, dataset=dataset)
```

Users can further control execution by providing additional arguments. For example, to save the optimized model after completion, specify the `output_dir` parameter:

```python
oneshot(
    ...,
    output_dir=...,
)
```

These input arguments are passed as keyword arguments to the entry point function and are parsed using Hugging Face's [argument parser](https://github.com/huggingface/transformers/blob/main/src/transformers/hf_argparser.py). The parsers define the acceptable inputs; any argument passed in must therefore be defined in one of the parsers.

`llm-compressor` uses four parsers, located in `llm_compressor/arg_parser` (a parsing sketch follows the list below):

* ModelArguments
* DatasetArguments
* RecipeArguments
* TrainingArguments
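
As a rough illustration of how these dataclasses plug into Hugging Face's parser, the sketch below builds an `HfArgumentParser` from all four classes and parses a flat dictionary of keyword arguments. The import path is assumed from the `llm_compressor/arg_parser` layout above, and the model, dataset, and recipe values are placeholders; this is not verbatim library code.

```python
# Minimal sketch, assuming the `llm_compressor/arg_parser` layout described above;
# adjust the import path to the actual package name. Values are placeholders.
from transformers import HfArgumentParser

from llm_compressor.arg_parser import (
    DatasetArguments,
    ModelArguments,
    RecipeArguments,
    TrainingArguments,
)

parser = HfArgumentParser(
    (ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments)
)

# parse_dict splits one flat mapping of overrides into the four dataclasses,
# much like the entry points do with their keyword arguments.
model_args, dataset_args, recipe_args, training_args = parser.parse_dict(
    {
        "model": "path/or/hf-stub",           # placeholder model identifier
        "dataset": "my_calibration_dataset",  # placeholder dataset name
        "recipe": "recipe.yaml",              # placeholder recipe path
        "num_calibration_samples": 512,
        "output_dir": "./parser_output",
    },
    allow_extra_keys=True,
)
```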
## ModelArguments
Handles model loading and saving. For example, `ModelArguments.model` can be a Hugging Face model identifier or an instance of `AutoModelForCausalLM`. The `save_compressed` flag is a boolean that determines whether the model is saved in compressed safetensors format to minimize disk usage.
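
As a hedged sketch (the model identifier, recipe path, and dataset name below are placeholders), either form of `model` can be forwarded to `oneshot` together with the saving-related fields:

```python
from transformers import AutoModelForCausalLM

from llmcompressor.transformers import oneshot

# `model` accepts an HF identifier string or an already-instantiated model;
# the identifier below is a placeholder.
loaded_model = AutoModelForCausalLM.from_pretrained("path/or/hf-stub")

oneshot(
    model=loaded_model,
    recipe="recipe.yaml",       # placeholder recipe path
    dataset="my_dataset",       # placeholder dataset name
    save_compressed=True,       # save in compressed safetensors format
    output_dir="./compressed",  # where the optimized model is written
)
```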
## DatasetArguments
Manages data loading and preprocessing. The `dataset` argument can specify a Hugging Face dataset stub or a local dataset compatible with [`load_dataset`](https://github.com/huggingface/datasets/blob/3a4e74a9ace62ecd5c9cde7dcb6bcabd65cc7857/src/datasets/load.py#L1905). The `preprocessing_func` argument is a callable that applies custom logic, such as formatting the data using a chat template.
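
A hedged sketch of supplying a custom `preprocessing_func`; the dataset name and its column names are placeholders, and the formatting only illustrates the kind of chat-template logic such a callable might apply:

```python
from llmcompressor.transformers import oneshot

def format_as_chat(example):
    # Illustrative only: fold each raw record into a single `text` field,
    # the default column consumed after preprocessing (see `text_column`).
    return {"text": f"User: {example['question']}\nAssistant: {example['answer']}"}

oneshot(
    model="path/or/hf-stub",   # placeholder model identifier
    recipe="recipe.yaml",      # placeholder recipe path
    dataset="my_qa_dataset",   # placeholder dataset loadable via load_dataset
    preprocessing_func=format_as_chat,
)
```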
## RecipeArguments
Defines the model recipe. A `recipe` consists of user-defined instructions for optimizing the model. Examples of recipes can be found in the `/examples` directory.
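
A hedged sketch of pointing an entry point at a recipe file and overriding one of its values through `recipe_args`; the file path and the `key=value` pair are hypothetical:

```python
from llmcompressor.transformers import oneshot

oneshot(
    model="path/or/hf-stub",           # placeholder model identifier
    dataset="my_calibration_dataset",  # placeholder dataset name
    recipe="examples/my_recipe.yaml",  # hypothetical recipe path
    # Hypothetical override evaluated inside the recipe, following the
    # key1=value1 format documented by RecipeArguments.recipe_args.
    recipe_args=["num_bits=4"],
)
```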
## TrainingArguments
Specifies training parameters based on Hugging Face's [TrainingArguments class](https://github.com/huggingface/transformers/blob/main/src/transformers/training_args.py). These parameters include settings such as the learning rate (`learning_rate`) and the optimizer to use (`optim`).
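
A hedged sketch of fine-tuning through the `train` entry point (assumed to be importable alongside `oneshot`); the model, dataset, and recipe values are placeholders, and the training fields mirror Hugging Face's `TrainingArguments`:

```python
from llmcompressor.transformers import train  # assumed to sit alongside `oneshot`

train(
    model="path/or/hf-stub",     # placeholder model identifier
    dataset="my_train_dataset",  # placeholder dataset name
    recipe="recipe.yaml",        # placeholder recipe path
    output_dir="./finetuned",
    # Standard Hugging Face TrainingArguments fields:
    learning_rate=1e-5,
    optim="adamw_torch",
    num_train_epochs=1,
)
```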
@@ -0,0 +1,6 @@ | ||
# flake8: noqa | ||
|
||
from .dataset_arguments import DatasetArguments | ||
from .model_arguments import ModelArguments | ||
from .recipe_arguments import RecipeArguments | ||
from .training_arguments import TrainingArguments |
@@ -0,0 +1,189 @@
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Union

from transformers import DefaultDataCollator


@dataclass
class DVCDatasetArguments:
    """
    Arguments for training using DVC
    """

    dvc_data_repository: Optional[str] = field(
        default=None,
        metadata={"help": "Path to repository used for dvc_dataset_path"},
    )


@dataclass
class CustomDatasetArguments(DVCDatasetArguments):
    """
    Arguments for training using custom datasets
    """

    dataset_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Path to the custom dataset. Supports json, csv, dvc. "
                "For DVC, the DVC dataset to load, in the format dvc://path. "
                "For csv or json, the path containing the dataset. "
            ),
        },
    )

    text_column: str = field(
        default="text",
        metadata={
            "help": (
                "Optional key to be used as the `text` input to tokenizer/processor "
                "after dataset preprocessing"
            )
        },
    )

    remove_columns: Union[None, str, List] = field(
        default=None,
        metadata={"help": "Column names to remove after preprocessing (deprecated)"},
    )

    preprocessing_func: Union[None, str, Callable] = field(
        default=None,
        metadata={
            "help": (
                "Typically a function which applies a chat template. Can take the form "
                "of either a function to apply to the dataset, a name defined in "
                "src/llmcompressor/transformers/utils/preprocessing_functions.py, or "
                "a path to a function definition of the form /path/to/file.py:func"
            )
        },
    )

    data_collator: Callable[[Any], Any] = field(
        default_factory=lambda: DefaultDataCollator(),
        metadata={"help": "The function used to form a batch from the dataset"},
    )


@dataclass
class DatasetArguments(CustomDatasetArguments):
    """
    Arguments pertaining to what data we are going to input our model for
    calibration, training or eval.

    Using `HfArgumentParser` we can turn this class into argparse
    arguments to be able to specify them on the command line.
    """

    dataset: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The name of the dataset to use (via the datasets library). "
                "Supports input as a string or DatasetDict from HF"
            )
        },
    )
    dataset_config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": ("The configuration name of the dataset to use"),
        },
    )
    max_seq_length: int = field(
        default=384,
        metadata={
            "help": "The maximum total input sequence length after tokenization. "
            "Sequences longer than this will be truncated, sequences shorter will "
            "be padded."
        },
    )
    concatenate_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to concatenate datapoints to fill max_seq_length"
        },
    )
    raw_kwargs: Dict = field(
        default_factory=dict,
        metadata={"help": "Additional keyword args to pass to datasets load_data"},
    )
    splits: Union[None, str, List, Dict] = field(
        default=None,
        metadata={"help": "Optional percentages of each split to download"},
    )
    num_calibration_samples: Optional[int] = field(
        default=512,
        metadata={"help": "Number of samples to use for one-shot calibration"},
    )
    shuffle_calibration_samples: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Whether to shuffle the dataset before selecting calibration data"
        },
    )
    streaming: Optional[bool] = field(
        default=False,
        metadata={"help": "True to stream data from a cloud dataset"},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached preprocessed datasets or not."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. If False, "
            "will pad the samples dynamically when batching to the maximum length "
            "in the batch (which can be faster on GPU but will be slower on TPU)."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number "
            "of training examples to this value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number "
            "of evaluation examples to this value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "For debugging purposes or quicker training, truncate the number of "
                "prediction examples to this value if set."
            ),
        },
    )
    min_tokens_per_module: Optional[float] = field(
        default=None,
        metadata={
            "help": (
                "The minimum percentage of tokens (out of the total number) "
                "that the module should 'receive' throughout the forward "
                "pass of the calibration. If a module receives fewer tokens, "
                "a warning will be logged. Defaults to 1/num_of_experts. "
                "Note: this argument is only relevant for MoE models."
            ),
        },
    )
    trust_remote_code_data: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for datasets defined on the Hub using "
            "a dataset script. This option should only be set to True for "
            "repositories you trust and in which you have read the code, as it "
            "will execute code present on the Hub on your local machine."
        },
    )
@@ -0,0 +1,92 @@
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ModelArguments:
    """
    Model variables used for oneshot calibration, finetuning and
    stage runners (sequential run of oneshot and finetune).
    """

    model: str = field(
        metadata={
            "help": (
                "A pretrained model or a string as a path to pretrained model, "
                "HF stub, or model identifier from huggingface.co/models."
            )
        },
    )
    distill_teacher: Optional[str] = field(
        default=None,
        metadata={
            "help": "Teacher model (a trained text generation model)",
        },
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained config name or path if not the same as model_name"
        },
    )
    tokenizer: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained tokenizer name or path if not the same as model_name"
        },
    )
    processor: Optional[str] = field(
        default=None,
        metadata={
            "help": "Pretrained processor name or path if not the same as model_name"
        },
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained data from huggingface.co"},
    )

    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": "Will use token generated when running `transformers-cli login` "
            "(necessary to use this script with private models)"
        },
    )
    precision: str = field(
        default="auto",
        metadata={"help": "Precision to cast model weights to, default to auto"},
    )

    tie_word_embeddings: bool = field(
        default=False,
        metadata={
            "help": "Whether the model's input and output word embeddings "
            "should be tied. Note that this is only relevant if the "
            "model has an output word embedding layer."
        },
    )
    trust_remote_code_model: bool = field(
        default=False,
        metadata={
            "help": "Whether or not to allow for custom models to execute their "
            "own modeling files. This option should only be set to True for "
            "repositories you trust and in which you have read the code"
        },
    )
    save_compressed: Optional[bool] = field(
        default=True,
        metadata={"help": "Whether to compress sparse models during save"},
    )
    oneshot_device: Optional[str] = field(
        default="cuda:0",
        metadata={"help": "Device to run oneshot calibration on"},
    )
    model_revision: str = field(
        default="main",
        metadata={
            "help": "The specific model version to use "
            "(can be a branch name, tag name or commit id)"
        },
    )
@@ -0,0 +1,32 @@
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class RecipeArguments:
    """Recipe and session variables"""

    recipe: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a LLM Compressor sparsification recipe",
        },
    )
    recipe_args: Optional[List[str]] = field(
        default=None,
        metadata={
            "help": (
                "List of recipe arguments to evaluate, of the format key1=value1 "
                "key2=value2"
            )
        },
    )
    clear_sparse_session: Optional[bool] = field(
        default=False,
        metadata={
            "help": (
                "Whether to clear CompressionSession/CompressionLifecycle "
                "data between runs."
            )
        },
    )