RLHF_PPO_Tuned_GPT2 fine-tunes GPT-2 with Reinforcement Learning from Human Feedback (RLHF) and Proximal Policy Optimization (PPO) to generate responses aligned with human preferences. Reward-based learning improves the helpfulness, truthfulness, and harmlessness of responses, and the tuned model clearly outperforms vanilla GPT-2 in a side-by-side evaluation.
- A reward model was trained on human preference data from yitingxie/rlhf-reward-datasets to score responses as positive or negative based on their quality (a minimal training sketch follows this list).
- The dataset was filtered, limiting prompts and responses to under 500 tokens for efficiency during training.
- The model was then refined with PPO, optimizing its behavior against the scores produced by the reward model so that responses align with human feedback.
- A KL-divergence penalty and a conservative learning rate were used to keep training stable, improving response safety and relevance.
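As a rough illustration of the reward-model step, the sketch below trains a scalar scoring head on preference pairs with a pairwise ranking loss. The column names (`prompt`, `chosen`, `rejected`) and the loss form are assumptions based on common RLHF setups; `src/train_reward_model.py` is the actual implementation.

```python
# Sketch: pairwise reward-model training (assumed column names; simplified, not the repo's exact code)
import torch
from torch.nn.functional import logsigmoid
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

dataset = load_dataset("yitingxie/rlhf-reward-datasets", split="train")

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.pad_token = tokenizer.eos_token                      # GPT-2 has no pad token by default
model = AutoModelForSequenceClassification.from_pretrained(    # single scalar "reward" output
    "gpt2-medium", num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

def score(texts):
    """Return one scalar reward per input text."""
    batch = tokenizer(texts, truncation=True, max_length=500,
                      padding=True, return_tensors="pt")
    return model(**batch).logits.squeeze(-1)

for example in dataset.select(range(1000)):                    # toy loop; the real script batches examples
    chosen = score([example["prompt"] + example["chosen"]])
    rejected = score([example["prompt"] + example["rejected"]])
    loss = -logsigmoid(chosen - rejected).mean()               # push chosen above rejected
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```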
- Base Model: GPT-2 Medium.
- Reward Model: Predicts feedback scores (positive or negative) based on the helpfulness and appropriateness of responses.
- PPO Head: Enhances the GPT-2 model by introducing reinforcement learning to align model outputs with human feedback.
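The PPO stage is typically built on a library such as trl. The sketch below uses trl's classic `PPOTrainer` API (roughly versions 0.4–0.7) with illustrative prompts and hyperparameters; it is an assumption about how `src/train_ppo_model.py` might look, not a copy of it. The frozen reference model supplies the KL penalty, and the small learning rate keeps updates conservative.

```python
# Sketch: PPO fine-tuning with trl's classic PPOTrainer API (illustrative values, not the repo's exact code)
import torch
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

config = PPOConfig(
    model_name="gpt2-medium",
    learning_rate=1.41e-5,   # conservative learning rate for stability
    init_kl_coef=0.2,        # strength of the KL penalty toward the reference model
    batch_size=1,            # toy values; the real script would batch prompts
    mini_batch_size=1,
)

tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)      # policy + value head
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)  # frozen reference

ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer)

# One PPO step: generate a response, score it, update the policy.
query = tokenizer("How can I stay safe online?", return_tensors="pt").input_ids[0]
full = ppo_trainer.generate(query, max_new_tokens=64, pad_token_id=tokenizer.eos_token_id)
response = full[0, query.shape[0]:]          # keep only the generated tokens
reward = torch.tensor(1.0)                   # placeholder: use the trained reward model's score here
stats = ppo_trainer.step([query], [response], [reward])
```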
During inference, responses were generated with both the vanilla GPT-2 and the PPO-tuned GPT-2 models and compared on three criteria:
- Helpfulness: Whether the response answered the question effectively.
- Truthfulness: The factual accuracy of the response.
- Harmlessness: Avoiding misleading or harmful content.
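For reference, a simplified sketch of how such side-by-side pairs can be produced. The PPO model path and CSV column names are assumed from the repository layout below; `src/generate_responses.py` is the authoritative version.

```python
# Sketch: side-by-side generation with the vanilla and PPO-tuned models (paths/columns assumed)
import os
import pandas as pd
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
vanilla = AutoModelForCausalLM.from_pretrained("gpt2-medium").to(device)
tuned = AutoModelForCausalLM.from_pretrained("data/ppo_gpt2_model").to(device)

def respond(model, prompt, max_new_tokens=80):
    """Generate one response and strip the prompt from the decoded output."""
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                do_sample=True, top_p=0.9,
                                pad_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(output[0, inputs.input_ids.shape[1]:], skip_special_tokens=True)

prompts = ["How can I protect my personal data online?"]   # the evaluation used 36 prompts
rows = [{"prompt": p,
         "vanilla_response": respond(vanilla, p),
         "ppo_response": respond(tuned, p)} for p in prompts]
os.makedirs("outputs", exist_ok=True)
pd.DataFrame(rows).to_csv("outputs/query_response_pairs.csv", index=False)
```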
A total of 36 prompts were used for evaluation. Each pair of responses was judged against the three criteria above using ChatGPT-4o in the web interface (the API was skipped to keep costs down, since the dataset was small). See the detailed analysis here: https://chatgpt.com/share/66eb698e-d66c-8004-9318-bf354ce88fa2
- vanilla_responses: better on 9 prompts
- ppo_responses: better on 23 prompts
- Neither: on 4 prompts, both responses were equally poor
The PPO-tuned GPT-2 model significantly outperformed the vanilla model, providing more relevant, less repetitive, and safer answers. PPO fine-tuning helped the model handle sensitive topics more responsibly, demonstrating greater care in avoiding harmful or misleading outputs.
- PPO-tuned responses were generally more coherent and aligned better with human feedback, making them more useful in conversational contexts.
- Vanilla GPT-2 responses often lacked the depth and safety precautions observed in the PPO-tuned model, frequently generating repetitive or less informative answers.
The RLHF PPO-tuned GPT-2 model offers significant improvements over the vanilla GPT-2, particularly in relevance, safety, and overall helpfulness. By combining RLHF and PPO, the model aligns more closely with human preferences, making it more suitable for real-world applications that require trustworthy and non-harmful outputs.
Kaggle notebook: https://www.kaggle.com/code/karthiksundaram123/rlhf-ppo-tuned
1. Clone the Repository
git clone https://github.com/yourusername/RLHF_PPO_Tuned_GPT2.git
cd RLHF_PPO_Tuned_GPT2
2. Create a Virtual Environment (Optional)
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
3. Install Dependencies
pip install -r requirements.txt
4. Ensure PyTorch with CUDA Support
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
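As an optional sanity check (not part of the original setup steps), confirm that PyTorch can see the GPU:
python -c "import torch; print(torch.cuda.is_available())"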
1. Preparing the Dataset
python src/prepare_dataset.py
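Under the hood, the filtering described earlier (prompts and responses capped at 500 tokens) could look roughly like the following; the dataset column names and output path are assumptions, and `src/prepare_dataset.py` remains the source of truth.

```python
# Sketch of the length filtering (assumed column names; simplified, not the repo's exact code)
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
dataset = load_dataset("yitingxie/rlhf-reward-datasets", split="train")

def short_enough(example, max_tokens=500):
    """Keep examples whose prompt plus either response stays under the token budget."""
    chosen_len = len(tokenizer(example["prompt"] + example["chosen"]).input_ids)
    rejected_len = len(tokenizer(example["prompt"] + example["rejected"]).input_ids)
    return chosen_len < max_tokens and rejected_len < max_tokens

filtered = dataset.filter(short_enough)
filtered.save_to_disk("data/filtered_rlhf_dataset")   # hypothetical output path
```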
2. Training the Reward Model
python src/train_reward_model.py
3. Training the PPO Model
python src/train_ppo_model.py
4. Generating Responses
python src/generate_responses.py
RLHF_PPO_Tuned_GPT2/
├── README.md
├── requirements.txt
├── .gitignore
├── src/
│   ├── prepare_dataset.py
│   ├── train_reward_model.py
│   ├── train_ppo_model.py
│   ├── generate_responses.py
│   └── utils.py
├── data/                         # Initially empty
│   ├── trained_reward_model/     # Created after reward model training
│   ├── ppo_gpt2_model/           # Created after PPO model training
│   └── ppo_gpt2_tokenizer/       # Tokenizer saved here
├── outputs/                      # Stores CSV outputs
│   └── query_response_pairs.csv  # Created after generating responses
├── logs/                         # Stores logs generated during training
└── LICENSE                       # Optional: Add a license file