From 8cb9716cd7580bc7794dbfabedd0f091257543bb Mon Sep 17 00:00:00 2001
From: YeonwooSung
Date: Sun, 16 Feb 2025 18:04:16 +0900
Subject: [PATCH] feat: Add script for GRPO training

---
 LLMs/training/train_grpo.py | 166 ++++++++++++++++++++++++++++++++++++
 1 file changed, 166 insertions(+)
 create mode 100644 LLMs/training/train_grpo.py

diff --git a/LLMs/training/train_grpo.py b/LLMs/training/train_grpo.py
new file mode 100644
index 0000000..0d844e0
--- /dev/null
+++ b/LLMs/training/train_grpo.py
@@ -0,0 +1,166 @@
+import re
+import torch
+from datasets import load_dataset, Dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import LoraConfig
+from trl import GRPOConfig, GRPOTrainer
+
+# Load and prep dataset
+
+SYSTEM_PROMPT = """
+Respond in the following format:
+<reasoning>
+...
+</reasoning>
+<answer>
+...
+</answer>
+"""
+
+XML_COT_FORMAT = """\
+<reasoning>
+{reasoning}
+</reasoning>
+<answer>
+{answer}
+</answer>
+"""
+
+def extract_xml_answer(text: str) -> str:
+    answer = text.split("<answer>")[-1]
+    answer = answer.split("</answer>")[0]
+    return answer.strip()
+
+def extract_hash_answer(text: str) -> str | None:
+    if "####" not in text:
+        return None
+    return text.split("####")[1].strip().replace(",", "").replace("$", "")
+
+# uncomment middle messages for 1-shot prompting
+def get_gsm8k_questions(split = "train") -> Dataset:
+    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
+    data = data.map(lambda x: { # type: ignore
+        'prompt': [
+            {'role': 'system', 'content': SYSTEM_PROMPT},
+            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
+            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
+            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
+            #    answer="7"
+            #)},
+            {'role': 'user', 'content': x['question']}
+        ],
+        'answer': extract_hash_answer(x['answer'])
+    }) # type: ignore
+    return data # type: ignore
+
+dataset = get_gsm8k_questions()
+
+# Reward functions
+def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
+    responses = [completion[0]['content'] for completion in completions]
+    q = prompts[0][-1]['content']
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def int_reward_func(completions, **kwargs) -> list[float]:
+    responses = [completion[0]['content'] for completion in completions]
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+def strict_format_reward_func(completions, **kwargs) -> list[float]:
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
+    responses = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
+    return [0.5 if match else 0.0 for match in matches]
+
+def soft_format_reward_func(completions, **kwargs) -> list[float]:
+    """Reward function that checks if the completion has a specific format."""
+    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
+    responses = [completion[0]["content"] for completion in completions]
+    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
+    return [0.5 if match else 0.0 for match in matches]
+
+def count_xml(text) -> float:
+    count = 0.0
+    if text.count("<reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n</reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
+
+def xmlcount_reward_func(completions, **kwargs) -> list[float]:
+    contents = [completion[0]["content"] for completion in completions]
+    return [count_xml(c) for c in contents]
+
+#model_name = "meta-llama/Llama-3.2-1B-Instruct"
+model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+if "Llama" in model_name:
+    output_dir = "outputs/Llama-1B-GRPO"
+    run_name = "Llama-1B-GRPO-gsm8k"
+else:
+    output_dir="outputs/Qwen-1.5B-GRPO"
+    run_name="Qwen-1.5B-GRPO-gsm8k"
+
+training_args = GRPOConfig(
+    output_dir=output_dir,
+    run_name=run_name,
+    learning_rate=5e-6,
+    adam_beta1 = 0.9,
+    adam_beta2 = 0.99,
+    weight_decay = 0.1,
+    warmup_ratio = 0.1,
+    lr_scheduler_type='cosine',
+    logging_steps=1,
+    bf16=True,
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=4,
+    num_generations=16,
+    max_prompt_length=256,
+    max_completion_length=786,
+    num_train_epochs=1,
+    save_steps=100,
+    max_grad_norm=0.1,
+    report_to="wandb",
+    log_on_each_node=False,
+)
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=64,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
+    task_type="CAUSAL_LM",
+    lora_dropout=0.05,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map=None
+).to("cuda")
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+tokenizer.pad_token = tokenizer.eos_token
+
+# use peft at your own risk; not working for me with multi-GPU training
+trainer = GRPOTrainer(
+    model=model,
+    processing_class=tokenizer,
+    reward_funcs=[
+        xmlcount_reward_func,
+        soft_format_reward_func,
+        strict_format_reward_func,
+        int_reward_func,
+        correctness_reward_func],
+    args=training_args,
+    train_dataset=dataset,
+    #peft_config=peft_config
+)
+trainer.train()
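
A minimal sanity check of the reward-function interface defined in train_grpo.py (a sketch, not part of the patch; it assumes the reward functions above are in scope, e.g. in a REPL after defining them, and that, as in the script, each completion is a list of chat messages whose first entry holds the generated text):

    # Hypothetical inputs for illustration only; these names are not part of the patch.
    sample_prompt = [[{"role": "user", "content": "What is 2 + 2?"}]]
    sample_completion = [[{
        "role": "assistant",
        "content": "<reasoning>\n2 + 2 = 4\n</reasoning>\n<answer>\n4\n</answer>\n",
    }]]

    print(xmlcount_reward_func(sample_completion))        # [0.5]  every tag appears exactly once
    print(soft_format_reward_func(sample_completion))     # [0.5]  tags present, whitespace relaxed
    print(strict_format_reward_func(sample_completion))   # [0.5]  exact newline-delimited layout
    print(int_reward_func(sample_completion))             # [0.5]  extracted answer is an integer
    print(correctness_reward_func(sample_prompt, sample_completion, ["4"]))  # [2.0]

Each format-related reward contributes at most 0.5 and correctness contributes 2.0, so a fully formatted, correct completion earns roughly 4.0 per sample.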