SFT for D2L + Pre-Training (rename of the previous SFT) #102
[Diff not rendered (large file): an .ipynb notebook, discussed below.]
Collaborator: qq: what is this .ipynb file for?

Author: This is used to generate synthetic immigration data by rephrasing.
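For context, a rephrasing-style augmentation loop might look roughly like the sketch below. This is illustrative only: the notebook's actual prompt, model, and code are not part of this diff, and every name here is an assumption (the model id simply mirrors the demo script's base model).

```python
# Illustrative sketch only -- the notebook's real prompt/model are not shown
# in this PR.
from transformers import pipeline

generator = pipeline(
    "text-generation",
    model="mistralai/Mistral-7B-Instruct-v0.1",  # assumed, mirrors the demo
    device_map="auto",
)

def rephrase(text: str) -> str:
    """Generate one paraphrase of `text` as a synthetic training sample."""
    prompt = f"Rephrase the following while preserving its meaning:\n{text}\nRephrased:"
    out = generator(prompt, max_new_tokens=128, do_sample=True, temperature=0.9)
    # The pipeline returns the prompt plus the continuation; keep the latter.
    return out[0]["generated_text"][len(prompt):].strip()
```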
A new file, `example/rlhf/supervised_finetuning_demo_d2l.py` (+56 lines):

```python
"""Demo for supervised fine-tuning.

python -m example.rlhf.supervised_finetuning_demo_d2l
"""

from peft import LoraConfig
from pykoi.chat import QuestionAnswerDatabase
from pykoi.chat.db.constants import (QA_CSV_HEADER_ANSWER, QA_CSV_HEADER_ID,
                                     QA_CSV_HEADER_QUESTION,
                                     QA_CSV_HEADER_VOTE_STATUS)
from pykoi.rlhf import RLHFConfig, SupervisedFinetuning

# Get data from the local database.
qa_database = QuestionAnswerDatabase()
my_data_pd = qa_database.retrieve_all_question_answers_as_pandas()
my_data_pd = my_data_pd[
    [
        QA_CSV_HEADER_ID,
        QA_CSV_HEADER_QUESTION,
        QA_CSV_HEADER_ANSWER,
        QA_CSV_HEADER_VOTE_STATUS,
    ]
]

# Analyze the data.
print(my_data_pd)
print("My local database has {} samples in total".format(my_data_pd.shape[0]))

# Run supervised fine-tuning.
config = RLHFConfig(
    base_model_path="mistralai/Mistral-7B-Instruct-v0.1",
    dataset_type="local_csv",
    dataset_name="data/chapter22_trnvalfromseed_data_processed.csv",
    train_test_split_ratio=0,  # ratio for the test set. TODO(DH): combine train and eval
    max_seq_length=896,
    per_device_eval_batch_size=1,
    log_freq=20,
    # NOTE(DH): one epoch iterates the dataset once, so with a training batch
    # size of 1, log_freq=20 means logging every 20 entries (i.e., every
    # 0.12 epoch when the dataset has 166 entries).
    save_freq=40000,
    num_train_epochs=20,
    max_steps=-1,  # a positive value overrides num_train_epochs
    device_map="auto",
    lora_config_rl=LoraConfig(
        r=512,
        lora_alpha=1024,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        # other candidates: "gate_proj", "up_proj", "down_proj", "lm_head"
        bias="none",
        task_type="CAUSAL_LM",
    ),
    data_collator="DataCollatorForCompletionOnlyLM",
    no_evaluation=True,
    prepare_text="d2l",
    split="train[:10%]",
)
rlhf_step1_sft = SupervisedFinetuning(config)
rlhf_step1_sft.train_and_save("./models/rlhf_step1_sft")
```
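Two notes on the configuration above. `r=512` with `lora_alpha=1024` gives a LoRA scaling factor of `lora_alpha / r = 2`. And `split="train[:10%]"` uses the Hugging Face `datasets` slicing syntax; a minimal illustration of that syntax (independent of how pykoi loads the CSV internally):

```python
from datasets import load_dataset

# "train[:10%]" selects the first 10% of rows via datasets' split slicing.
ds = load_dataset(
    "csv",
    data_files="data/chapter22_trnvalfromseed_data_processed.csv",
    split="train[:10%]",
)
print(len(ds))  # roughly 10% of the CSV's rows
```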
Changes to the RLHF configuration (`RLHFConfig`):

```diff
@@ -5,6 +5,7 @@
 from accelerate import Accelerator
 from peft import LoraConfig, TaskType
+# TODO: DH: num_train_epochs=20,


 @dataclass
```

Collaborator (on the added TODO line): nit: what is this commented-out code for?

```diff
@@ -119,6 +120,7 @@ class RLHFConfig:
         default="./rlhf_checkpoints",
         metadata={"help": "Output directory for all model weights."},
     )
+    num_train_epochs: Optional[int] = field(default=5, metadata={"help": "Supervised fine-tuning training epochs."})
     log_freq: Optional[int] = field(default=1, metadata={"help": "Logging frequency."})
     eval_freq: Optional[int] = field(
         default=1000, metadata={"help": "Evaluation frequency."}
```

```diff
@@ -182,6 +184,18 @@ class RLHFConfig:
         ),
         metadata={"help": "LoRA configuration."},
     )
+    data_collator: Optional[str] = field(
+        default=None,
+        metadata={"help": "The name of the data collator to use for training."},
+    )
+    no_evaluation: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to disable evaluation during training."},
+    )
+    prepare_text: Optional[str] = field(
+        default="sample",
+        metadata={"help": "How to prepare the text for the model."},
+    )

     # Step 2 reward modeling parameters
     reward_model_path: Optional[str] = field(
```
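The new `data_collator` field carries the collator's name as a string. How pykoi resolves that string is not shown in this diff; a hypothetical sketch of one way to map it to the class added in the new file below:

```python
# Hypothetical wiring, not from this PR: map the configured collator name to
# the DataCollatorForCompletionOnlyLM class defined in the new file below.
COLLATOR_REGISTRY = {
    "DataCollatorForCompletionOnlyLM": DataCollatorForCompletionOnlyLM,
}

def resolve_collator(name, tokenizer):
    """Instantiate the collator named in RLHFConfig.data_collator."""
    if name is None:
        return None  # let the trainer fall back to its default collator
    return COLLATOR_REGISTRY[name](tokenizer=tokenizer, mlm=False)
```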
A new file (+40 lines) adds the completion-only collator referenced by `data_collator` above:

Collaborator: qq: do we still need this customized collator, per our discussion?

```python
from typing import Any, Dict, List, Union

import numpy as np
from transformers import DataCollatorForLanguageModeling


class DataCollatorForCompletionOnlyLM(DataCollatorForLanguageModeling):
    def torch_call(
        self, examples: List[Union[List[int], Any, Dict[str, Any]]]
    ) -> Dict[str, Any]:
        batch = super().torch_call(examples)

        # The prompt ends with the response key plus a newline. We encode this
        # and then try to find it in the sequence of tokens. This should just
        # be a single token.
        RESPONSE_KEY = "### Response:"
        RESPONSE_KEY_NL = f"{RESPONSE_KEY}\n"
        response_token_ids = self.tokenizer.encode(RESPONSE_KEY_NL)

        labels = batch["labels"].clone()

        for i in range(len(examples)):
            response_token_ids_start_idx = None
            for idx in np.where(batch["labels"][i] == response_token_ids[0])[0]:
                response_token_ids_start_idx = idx
                break

            if response_token_ids_start_idx is None:
                raise RuntimeError(
                    f"Could not find response key {response_token_ids} "
                    f'in token IDs {batch["labels"][i]}'
                )

            response_token_ids_end_idx = response_token_ids_start_idx + 1

            # Make the PyTorch loss function ignore all tokens up through
            # the end of the response key.
            labels[i, :response_token_ids_end_idx] = -100

        batch["labels"] = labels

        return batch
```
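A minimal sketch of the collator's effect, assuming a GPT-2 tokenizer purely for illustration (the PR itself targets Mistral). Labels up to and including the first token of "### Response:" become -100, so the loss covers only the completion; note the search keys on the first occurrence of that token, so the prompt itself must not contain it earlier.

```python
from transformers import AutoTokenizer

# Assumed tokenizer, for illustration only.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForCompletionOnlyLM(tokenizer=tokenizer, mlm=False)

text = "Question: What does SFT stand for?\n### Response:\nSupervised fine-tuning."
example = tokenizer(text)["input_ids"]
batch = collator.torch_call([example])

# Positions before (and including) the first "### Response:" token are -100.
print(batch["labels"][0])
```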