Skip to content

Commit

Permalink
feat: Add SFT training code for mistral-7b with garole
Browse files Browse the repository at this point in the history
  • Loading branch information
YeonwooSung committed Mar 21, 2024
1 parent 131c708 commit 3768c3e
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions LLMs/training/sft_mistral_7b_garole/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch
import datasets
from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM
import trl

train_dataset = datasets.load_dataset('imdb', split='train')

args = TrainingArguments(
output_dir="./test-galore",
max_steps=100,
per_device_train_batch_size=2,
optim="galore_adamw",
optim_target_modules=["attn", "mlp"]
)

model_id = "mistralai/Mistral-7B-v0.1"

config = AutoConfig.from_pretrained(model_id)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_config(config).to(0)

trainer = trl.SFTTrainer(
model=model,
args=args,
train_dataset=train_dataset,
dataset_text_field='text',
max_seq_length=512,
)

trainer.train()

0 comments on commit 3768c3e

Please sign in to comment.