3 changes: 3 additions & 0 deletions .idea/.gitignore


28 changes: 28 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml


6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml


7 changes: 7 additions & 0 deletions .idea/misc.xml


8 changes: 8 additions & 0 deletions .idea/modules.xml


12 changes: 12 additions & 0 deletions .idea/tiny-grpo.iml


6 changes: 6 additions & 0 deletions .idea/vcs.xml


19 changes: 14 additions & 5 deletions README.md
@@ -1,6 +1,6 @@
 # Minimal GRPO implementation
 
-Goal: Working toy implementation of llama-3.2-3b locally RL training with GRPO. Understanding the algorithm & hyper parameters. Just running everything locally on a single node.
+Since I had a smaller 12 GB GPU, I tested this with a smaller number of samples and a smaller instruct model than originally proposed.
+Goal: Working toy implementation of local RL training of HuggingFaceTB/SmolLM-135M-Instruct with GRPO. Understanding the algorithm & hyperparameters. Just running everything locally on a single node.

### Setup

@@ -16,16 +16,25 @@ conda activate grpo
 ```
 pip install -r requirements.txt
 pip install flash-attn --no-build-isolation
+
+# May need a newer nvcc (check `nvcc --version`) for flash-attn to build
 ```

-3. Play with the source in `train.py`
+3. Play with the source in `train_ds2.py`
+Since I had only one 12 GB 3060 GPU, I modified the code to run on a single GPU instead of distributed.
 ```
-python train.py
+python train_ds2.py
 ```
 
 with multiple gpu
 
 ```
 torchrun --nproc_per_node=8 train.py
 ```

 ### Inspiration
 
+- [tiny-grpo](https://github.com/open-thought/tiny-grpo)
 - [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF)
 - [Spinning Up in Deep RL](https://spinningup.openai.com/en/latest/)

Binary file added __pycache__/ckpt_utils.cpython-312.pyc
Binary file added __pycache__/loss.cpython-312.pyc
Binary file added __pycache__/replay_buffer.cpython-312.pyc
34 changes: 34 additions & 0 deletions ckpt_utils.py
@@ -0,0 +1,34 @@
from torch.distributed.checkpoint.state_dict import (
    set_optimizer_state_dict,
    set_model_state_dict,
    get_model_state_dict,
    get_optimizer_state_dict,
)
import torch.distributed.checkpoint as dcp
import torch.distributed as dist


def save_checkpoint(model, optimizer, path):
    """Save model and optimizer state using distributed checkpoint"""
    model_state = get_model_state_dict(model=model)
    optimizer_state = get_optimizer_state_dict(model=model, optimizers=optimizer)

    state_dict = {"model": model_state, "optimizer": optimizer_state}

    dcp.save(state_dict=state_dict, storage_writer=dcp.FileSystemWriter(path))


def load_checkpoint(model, optimizer, path):
    """Load model and optimizer state using distributed checkpoint"""
    # DCP loads in place into a state dict with the right shapes, so build one
    # from the live model/optimizer first, then flush the loaded values back.
    dcp_state_dict = {
        "model": get_model_state_dict(model=model),
        "optimizer": get_optimizer_state_dict(model=model, optimizers=optimizer),
    }

    dcp.load(dcp_state_dict, storage_reader=dcp.FileSystemReader(path))

    set_model_state_dict(model=model, model_state_dict=dcp_state_dict["model"])
    set_optimizer_state_dict(
        model=model, optimizers=optimizer, optim_state_dict=dcp_state_dict["optimizer"]
    )