revert signature, remove training_args, remove unnec class
horheynm committed Feb 17, 2025
1 parent c2db397 · commit 0f1eb55
Showing 3 changed files with 11 additions and 23 deletions.
12 changes: 7 additions & 5 deletions examples/trl_mixin/ex_trl_constant.py
@@ -3,7 +3,7 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 from trl import DataCollatorForCompletionOnlyLM
 
-from llmcompressor.args import TrainingArguments
+from llmcompressor.args import ModelArguments
 
 model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
 output_dir = "./output_trl_sft_test_7b_gsm8k_sft_data"
@@ -39,21 +39,23 @@ def formatting_prompts_func(example):
 response_template = "Answer:"
 collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
 
-training_args = TrainingArguments(
+training_args = dict(
     output_dir=output_dir,
-    num_train_epochs=0.6,
+    num_train_epochs=0.1,
     logging_steps=50,
     gradient_checkpointing=True,
+    max_seq_length=512,
 )
+model_args = ModelArguments(model=model)
 
 trainer = SFTTrainer(
     model=model,
-    tokenizer=tokenizer,
+    processing_class=tokenizer,
     recipe=recipe,
     train_dataset=dataset,
     formatting_func=formatting_prompts_func,
     data_collator=collator,
     args=training_args,
-    max_seq_length=512,
+    model_args=model_args,
 )
 trainer.train()
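
Note: a minimal sketch of the new configuration style in this example, with keyword names taken from the diff above; it assumes the checkpoint is downloadable and that llmcompressor.args.ModelArguments accepts a `model` keyword as shown in the change:

# Sketch of the new configuration style (names taken from the diff above).
from transformers import AutoModelForCausalLM

from llmcompressor.args import ModelArguments

model_path = "neuralmagic/Llama-2-7b-pruned50-retrained"
model = AutoModelForCausalLM.from_pretrained(model_path)

# Training arguments are now a plain dict rather than a TrainingArguments
# instance; the mixin trainer expands it into a TRL SFTConfig internally.
training_args = dict(
    output_dir="./output_trl_sft_test_7b_gsm8k_sft_data",
    num_train_epochs=0.1,
    logging_steps=50,
    gradient_checkpointing=True,
    max_seq_length=512,  # now part of the args dict, not a trainer kwarg
)

# The model is handed to the session mixin through ModelArguments.
model_args = ModelArguments(model=model)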
20 changes: 3 additions & 17 deletions examples/trl_mixin/sft_trainer.py
@@ -1,7 +1,6 @@
 from trl import SFTConfig as TRLSFTConfig
 from trl import SFTTrainer as TRLSFTTrainer
 
-from llmcompressor.args import TrainingArguments
 from llmcompressor.transformers.finetune.session_mixin import SessionManagerMixIn
 
 __all__ = ["SFTTrainer"]
@@ -10,11 +9,9 @@
 class SFTTrainer(SessionManagerMixIn, TRLSFTTrainer):
     def __init__(self, *args, **kwargs):
         sft_config_args = kwargs.get("args")
-        if (
-            sft_config_args is not None
-            and sft_config_args.__class__.__name__ == "TrainingArguments"
-        ):
-            kwargs["args"] = SFTConfig(**sft_config_args.to_dict())
+        if sft_config_args is not None:
+            kwargs["args"] = TRLSFTConfig(**sft_config_args)
 
         super().__init__(*args, **kwargs)
 
     def _prepare_dataset(self, dataset, *args, **kwargs):
@@ -23,14 +20,3 @@ def _prepare_dataset(self, dataset, *args, **kwargs):
             return dataset
 
         return super()._prepare_dataset(dataset, *args, **kwargs)
-
-
-class SFTConfig(TrainingArguments, TRLSFTConfig):
-    """
-    This class is needed to wrap the llmcompressor.transformers.TrainingArguments
-    and TRLSFTConfig classes. This allows for the use of arguments and
-    configurations from both classes when training a model.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
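
Note: the removed SFTConfig shim (which multiply inherited from TrainingArguments and TRLSFTConfig) is replaced by expanding the user's dict straight into TRL's SFTConfig. A small sketch of the new path, assuming TRL's SFTConfig accepts the keys used in the example script:

# Sketch of the new conversion done in SFTTrainer.__init__ (assumes TRL's
# SFTConfig accepts the keys used in the example script above).
from trl import SFTConfig as TRLSFTConfig

training_args = dict(
    output_dir="./output",  # hypothetical path
    num_train_epochs=0.1,
    logging_steps=50,
    gradient_checkpointing=True,
    max_seq_length=512,
)

# A plain dict passed as `args` is expanded directly into a TRL SFTConfig,
# so no intermediate wrapper class is needed.
config = TRLSFTConfig(**training_args)
print(config.num_train_epochs)  # 0.1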
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/session_mixin.py
@@ -66,8 +66,8 @@ class SessionManagerMixIn:
     def __init__(
         self,
         recipe: str,
-        data_args: "DatasetArguments",
         model_args: "ModelArguments",
+        data_args: Optional["DatasetArguments"] = None,
         teacher: Optional[Union[Module, str]] = None,
         recipe_args: Optional[Union[Dict[str, Any], str]] = None,
         **kwargs,
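
Note: with the reverted signature, `model_args` is required and `data_args` becomes optional, so the mixin can be constructed without dataset arguments. A toy illustration of the ordering, using a hypothetical stand-in class rather than the real SessionManagerMixIn:

# Toy illustration of the reverted parameter order (hypothetical stand-in).
from typing import Any, Dict, Optional, Union


class MixinSignatureSketch:
    def __init__(
        self,
        recipe: str,
        model_args: "ModelArguments",                    # required again
        data_args: Optional["DatasetArguments"] = None,  # now optional
        teacher: Optional[Union["Module", str]] = None,
        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
        **kwargs,
    ):
        self.recipe = recipe
        self.model_args = model_args
        self.data_args = data_args  # None when no dataset arguments are given


# Callers can now omit data_args entirely ("recipe.yaml" and object() are
# placeholders for a real recipe path and ModelArguments instance):
state = MixinSignatureSketch(recipe="recipe.yaml", model_args=object())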
