diff --git a/LLMs/training/train_grpo.py b/LLMs/training/train_grpo.py
new file mode 100644
index 0000000..0d844e0
--- /dev/null
+++ b/LLMs/training/train_grpo.py
@@ -0,0 +1,166 @@
+import re
+import torch
+from datasets import load_dataset, Dataset
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from peft import LoraConfig
+from trl import GRPOConfig, GRPOTrainer
+
+# Load and prep dataset
+
+SYSTEM_PROMPT = """
+Respond in the following format:
+<reasoning>
+...
+</reasoning>
+<answer>
+...
+</answer>
+"""
+
+XML_COT_FORMAT = """\
+<reasoning>
+{reasoning}
+</reasoning>
+<answer>
+{answer}
+</answer>
+"""
+
+def extract_xml_answer(text: str) -> str:
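+    """Return the text between the <answer> tags of a completion."""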
+    answer = text.split("<answer>")[-1]
+    answer = answer.split("</answer>")[0]
+    return answer.strip()
+
+def extract_hash_answer(text: str) -> str | None:
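+    """Pull the final answer from a GSM8K label (the text after '####'), dropping commas and '$'."""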
+ if "####" not in text:
+ return None
+ return text.split("####")[1].strip().replace(",", "").replace("$", "")
+
+# uncomment middle messages for 1-shot prompting
+def get_gsm8k_questions(split = "train") -> Dataset:
+    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
+    data = data.map(lambda x: { # type: ignore
+        'prompt': [
+            {'role': 'system', 'content': SYSTEM_PROMPT},
+            #{'role': 'user', 'content': 'What is the largest single-digit prime number?'},
+            #{'role': 'assistant', 'content': XML_COT_FORMAT.format(
+            #    reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
+            #    answer="7"
+            #)},
+            {'role': 'user', 'content': x['question']}
+        ],
+        'answer': extract_hash_answer(x['answer'])
+    }) # type: ignore
+    return data # type: ignore
+
+dataset = get_gsm8k_questions()
+
+# Reward functions
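+# Each reward function scores a batch of completions independently; GRPOTrainer
+# combines them (summed by default). Correctness carries the largest weight (2.0),
+# the rest are small shaping rewards for the <reasoning>/<answer> format.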
+def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
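+    """Main reward: 2.0 when the extracted <answer> matches the gold answer exactly, else 0.0."""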
+    responses = [completion[0]['content'] for completion in completions]
+    q = prompts[0][-1]['content']
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
+    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+
+def int_reward_func(completions, **kwargs) -> list[float]:
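+    """0.5 when the extracted answer is a bare integer, as GSM8K answers typically are."""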
+    responses = [completion[0]['content'] for completion in completions]
+    extracted_responses = [extract_xml_answer(r) for r in responses]
+    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]
+
+def strict_format_reward_func(completions, **kwargs) -> list[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r"^\n.*?\n\n\n.*?\n\n$"
+ responses = [completion[0]["content"] for completion in completions]
+ matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
+ return [0.5 if match else 0.0 for match in matches]
+
+def soft_format_reward_func(completions, **kwargs) -> list[float]:
+ """Reward function that checks if the completion has a specific format."""
+ pattern = r".*?\s*.*?"
+ responses = [completion[0]["content"] for completion in completions]
+ matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
+ return [0.5 if match else 0.0 for match in matches]
+
+def count_xml(text) -> float:
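+    """Give 0.125 for each of the four correctly placed XML tags and subtract a
+    small penalty for any text trailing the closing </answer> tag."""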
+    count = 0.0
+    if text.count("<reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n</reasoning>\n") == 1:
+        count += 0.125
+    if text.count("\n<answer>\n") == 1:
+        count += 0.125
+        count -= len(text.split("\n</answer>\n")[-1])*0.001
+    if text.count("\n</answer>") == 1:
+        count += 0.125
+        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
+    return count
+
+def xmlcount_reward_func(completions, **kwargs) -> list[float]:
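+    """Apply count_xml's per-tag partial credit to each completion."""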
+    contents = [completion[0]["content"] for completion in completions]
+    return [count_xml(c) for c in contents]
+
+#model_name = "meta-llama/Llama-3.2-1B-Instruct"
+model_name = "Qwen/Qwen2.5-1.5B-Instruct"
+
+if "Llama" in model_name:
+ output_dir = "outputs/Llama-1B-GRPO"
+ run_name = "Llama-1B-GRPO-gsm8k"
+else:
+ output_dir="outputs/Qwen-1.5B-GRPO"
+ run_name="Qwen-1.5B-GRPO-gsm8k"
+
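+# num_generations is the GRPO group size: completions sampled per prompt, whose
+# rewards are normalized within the group to form the advantages.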
+training_args = GRPOConfig(
+    output_dir=output_dir,
+    run_name=run_name,
+    learning_rate=5e-6,
+    adam_beta1=0.9,
+    adam_beta2=0.99,
+    weight_decay=0.1,
+    warmup_ratio=0.1,
+    lr_scheduler_type='cosine',
+    logging_steps=1,
+    bf16=True,
+    per_device_train_batch_size=1,
+    gradient_accumulation_steps=4,
+    num_generations=16,
+    max_prompt_length=256,
+    max_completion_length=786,
+    num_train_epochs=1,
+    save_steps=100,
+    max_grad_norm=0.1,
+    report_to="wandb",
+    log_on_each_node=False,
+)
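+# LoRA over the attention and MLP projections; only takes effect if peft_config is passed to the trainer below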
+peft_config = LoraConfig(
+    r=16,
+    lora_alpha=64,
+    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
+    task_type="CAUSAL_LM",
+    lora_dropout=0.05,
+)
+model = AutoModelForCausalLM.from_pretrained(
+    model_name,
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_2",
+    device_map=None
+).to("cuda")
+
+tokenizer = AutoTokenizer.from_pretrained(model_name)
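+# make sure a pad token is set for batched generation (reuse EOS)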
+tokenizer.pad_token = tokenizer.eos_token
+
+# use peft at your own risk; not working for me with multi-GPU training
+trainer = GRPOTrainer(
+    model=model,
+    processing_class=tokenizer,
+    reward_funcs=[
+        xmlcount_reward_func,
+        soft_format_reward_func,
+        strict_format_reward_func,
+        int_reward_func,
+        correctness_reward_func],
+    args=training_args,
+    train_dataset=dataset,
+    #peft_config=peft_config
+)
+trainer.train()
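+# optionally persist the final weights after training, e.g.:
+# trainer.save_model(output_dir)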