diff --git a/LLMs/training/sft_mistral_7b_garole/train.py b/LLMs/training/sft_mistral_7b_garole/train.py new file mode 100644 index 0000000..795e41c --- /dev/null +++ b/LLMs/training/sft_mistral_7b_garole/train.py @@ -0,0 +1,31 @@ +import torch +import datasets +from transformers import TrainingArguments, AutoConfig, AutoTokenizer, AutoModelForCausalLM +import trl + +train_dataset = datasets.load_dataset('imdb', split='train') + +args = TrainingArguments( + output_dir="./test-galore", + max_steps=100, + per_device_train_batch_size=2, + optim="galore_adamw", + optim_target_modules=["attn", "mlp"] +) + +model_id = "mistralai/Mistral-7B-v0.1" + +config = AutoConfig.from_pretrained(model_id) + +tokenizer = AutoTokenizer.from_pretrained(model_id) +model = AutoModelForCausalLM.from_config(config).to(0) + +trainer = trl.SFTTrainer( + model=model, + args=args, + train_dataset=train_dataset, + dataset_text_field='text', + max_seq_length=512, +) + +trainer.train()