From 65aaaec2499e36c04602ca74a1d7d14ba371a699 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 5 Sep 2024 19:41:20 +0000 Subject: [PATCH 01/53] init --- .gitignore | 2 ++ Dockerfile | 4 ++++ open_instruct/dpo_tune.py | 8 +++++++- open_instruct/finetune.py | 8 +++++++- open_instruct/utils.py | 38 ++++++++++++++++++++++++++++++++++++++ requirements.txt | 2 +- 6 files changed, 59 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 55d2caedc..5952c206c 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,8 @@ rejection_sampling/shards1 token_length.png *.tfevents.* +oe-eval-internal/ + results models wandb diff --git a/Dockerfile b/Dockerfile index dd6b95a97..7579a02fc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -90,6 +90,10 @@ RUN pip install --upgrade pip "setuptools<70.0.0" wheel RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation +# for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 +# core is a dependency of ai2-olmo +RUN pip install ai2-olmo-core==0.1.0 +RUN pip install ai2-olmo>=0.5.0 --no-deps RUN pip install -r requirements.txt # NLTK download diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index 9c5cf1714..74bb522fd 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -17,13 +17,13 @@ DPO tuning script. Adapted from our finetuning script. """ +import json import logging import math import os import random import subprocess import time -import json from copy import deepcopy from dataclasses import dataclass, field from datetime import timedelta @@ -65,6 +65,7 @@ from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, clean_last_n_checkpoints, get_datasets, get_last_checkpoint_path, @@ -499,6 +500,11 @@ def prepare_deepspeed(accelerator, model): def main(args: FlatArguments): + # try to import OLMo for automodel + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index e8d69c08e..48ba726e1 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -14,13 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging import math import os import random import subprocess import time -import json from dataclasses import dataclass, field from datetime import timedelta from functools import partial @@ -55,6 +55,7 @@ from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, clean_last_n_checkpoints, get_datasets, get_last_checkpoint_path, @@ -448,6 +449,11 @@ def _concat_messages(messages): def main(args: FlatArguments): + # try to import OLMo for automodel + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 34fc96dfe..0a95b3f28 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -14,6 +14,7 @@ import dataclasses import functools +import importlib import json import logging import os @@ -51,6 +52,43 @@ """ +# ---------------------------------------------------------------------------- +# Import utilities +def check_hf_olmo_availability(return_version: bool = True) -> Union[dict, bool]: + pkg_name = "hf_olmo" + + # Check if the package spec exists + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method + try: + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") + if package_version == "N/A": + package_exists = False + except ImportError: + package_exists = False + + logger.debug(f"Detected {pkg_name} version: {package_version}") + + if return_version: + return { + "available": package_exists, + "version": package_version, + "python_version": sys.version, + "os": os.name, + "platform": sys.platform, + } + else: + return package_exists + + # ---------------------------------------------------------------------------- # Dataset utilities def is_openai_format(messages: Any) -> bool: diff --git a/requirements.txt b/requirements.txt index 39a3073d6..9b9ff095b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -44,4 +44,4 @@ isort autoflake pytest hf_transfer -beaker-py +beaker-py \ No newline at end of file From 7da49a9b5a8f1491267603f2371700d4aca5c2f7 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Thu, 5 Sep 2024 20:35:18 +0000 Subject: [PATCH 02/53] up --- Dockerfile | 2 +- configs/train_configs/sft/olmo_7b_0924.yaml | 22 +++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 configs/train_configs/sft/olmo_7b_0924.yaml diff --git a/Dockerfile b/Dockerfile index 7579a02fc..87181fb2f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -92,7 +92,7 @@ RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation # for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 # core is a dependency of ai2-olmo -RUN pip install ai2-olmo-core==0.1.0 +RUN pip install ai2-olmo-core==0.1.0 omegaconf RUN pip install ai2-olmo>=0.5.0 --no-deps RUN pip install -r requirements.txt diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml new file mode 100644 index 000000000..d76b2e560 --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -0,0 +1,22 @@ +model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_revision: main +use_flash_attn: true +tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_name: allenai/tulu-v2-sft-mixture-olmo-2048 +max_seq_length: 2048 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 # note, this is set up for 8 GPUs +gradient_accumulation_steps: 16 +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: output/olmo_instruct/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file From d929a895aa80a04ad4254989e1051a5b36c76e44 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 17:23:18 +0000 Subject: [PATCH 03/53] branch dockerfile --- .github/workflows/push-image-olmo.yml | 82 ++++++++++++++++++ .github/workflows/push-image.yml | 2 - Dockerfile | 4 +- Dockerfile.olmo | 115 ++++++++++++++++++++++++++ requirements-olmo.txt | 48 +++++++++++ 5 files changed, 247 insertions(+), 4 deletions(-) create mode 100644 .github/workflows/push-image-olmo.yml create mode 100644 Dockerfile.olmo create mode 100644 requirements-olmo.txt diff --git a/.github/workflows/push-image-olmo.yml b/.github/workflows/push-image-olmo.yml new file mode 100644 index 000000000..8fdafc49f --- /dev/null +++ b/.github/workflows/push-image-olmo.yml @@ -0,0 +1,82 @@ +# This is an example workflow file. +# +# When you add a new image, copy this file and then change all mentions of "hello-world" with +# the name of your new image. +# +# Read through the rest of the comments in this file to figure out how it works, and what else +# you need to change. +name: build_open_instruct_olmo + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + push: + # Run this workflow anytime a push updates one of the files in the image's directory + # (other than the README), and anytime there's a new release tag for this image. + paths: + - 'open_instruct/**' + - '!open_instruct/README.md' + - 'requirements-olmo.txt' + - 'Dockerfile.olmo' + - '.github/workflows/push-image-olmo.yml' + # Note, add .olmo dockerfile + requirements if adding auto build to those + branches: [main] + # pull_request: # note, comment this out for running on every push + # # Also run on PRs that update the files in the image's directory (other than README). + # branches: [main] + # paths: + # - 'open_instruct/**' + # - '!open_instruct/README.md' + # - 'requirements-olmo.txt' + # - 'Dockerfile.olmo' + workflow_dispatch: # This allows us to manually trigger a build through the GitHub UI. + +env: + DOCKER_BUILDKIT: "1" + +jobs: + build: + name: open_instruct + runs-on: ubuntu-latest + timeout-minutes: 60 + if: (github.event_name != 'workflow_run') || (github.event.workflow_run.conclusion == 'success') + steps: + - uses: actions/checkout@v3 + with: + repository: allenai/oe-eval-internal + path: './oe-eval-internal' + ssh-key: ${{ secrets.OE_EVAL_GIT_CLONE_ACCESS_PRIVATE_SSH_DEPLOY_KEY }} + + - name: Setup environment + uses: ./.github/actions/setup + with: + beaker_token: ${{ secrets.BEAKER_TOKEN }} + # ghcr_token: ${{ secrets.GHCR_TOKEN }} + # ghcr_user: ${{ secrets.GHCR_USER }} + + # big images fail, trying this + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache /usr/share/dotnet "$AGENT_TOOLSDIRECTORY" + + - name: Build image + run: | + docker build \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + --build-arg CUDA=12.1.0 --build-arg \ + TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 \ + --build-arg REQUIRE=requirements.txt \ + -f Dockerfile.olmo . \ + -t open_instruct_olmo + + - name: Check image + run: | + docker run --rm open_instruct_olmo + - name: Push image + # if: github.event_name != 'pull_request' + uses: ./.github/actions/push + with: + image: open_instruct_olmo # this is the tag of the image we just built in the previous step + beaker: open_instruct_olmo_auto # this is the name of the image on Beaker + latest: true # this flag says we should also push this as the 'latest' version to GHCR diff --git a/.github/workflows/push-image.yml b/.github/workflows/push-image.yml index 40e205fb3..4e29f4d28 100644 --- a/.github/workflows/push-image.yml +++ b/.github/workflows/push-image.yml @@ -44,8 +44,6 @@ jobs: timeout-minutes: 60 if: (github.event_name != 'workflow_run') || (github.event.workflow_run.conclusion == 'success') steps: - - uses: actions/checkout@v3 - - uses: actions/checkout@v3 with: repository: allenai/oe-eval-internal diff --git a/Dockerfile b/Dockerfile index 87181fb2f..9f4e1fb00 100644 --- a/Dockerfile +++ b/Dockerfile @@ -92,8 +92,8 @@ RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation # for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 # core is a dependency of ai2-olmo -RUN pip install ai2-olmo-core==0.1.0 omegaconf -RUN pip install ai2-olmo>=0.5.0 --no-deps +# RUN pip install ai2-olmo-core==0.1.0 omegaconf +# RUN pip install ai2-olmo>=0.5.0 --no-deps RUN pip install -r requirements.txt # NLTK download diff --git a/Dockerfile.olmo b/Dockerfile.olmo new file mode 100644 index 000000000..38f8a1db4 --- /dev/null +++ b/Dockerfile.olmo @@ -0,0 +1,115 @@ +ARG CUDA +ARG DIST +ARG TARGET +FROM --platform=linux/amd64 nvidia/cuda:${CUDA}-${TARGET}-${DIST} + +ARG DEBIAN_FRONTEND="noninteractive" +ENV TZ="America/Los_Angeles" + +# Install base tools. +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + jq \ + language-pack-en \ + make \ + sudo \ + unzip \ + vim \ + wget \ + parallel \ + iputils-ping \ + tmux + +ARG BEAKER_VERSION +RUN curl --silent \ + --connect-timeout 5 \ + --max-time 10 \ + --retry 5 \ + --retry-delay 0 \ + --retry-max-time 40 \ + --output beaker.tar.gz \ + "https://beaker.org/api/v3/release/cli?os=linux&arch=amd64&version=${BEAKER_VERSION}" \ + && tar -zxf beaker.tar.gz -C /usr/local/bin/ ./beaker \ + && rm beaker.tar.gz + +# This ensures the dynamic linker (or NVIDIA's container runtime, I'm not sure) +# puts the right NVIDIA things in the right place (that THOR requires). +ENV NVIDIA_DRIVER_CAPABILITIES=graphics,utility,compute + +# Install conda. We give anyone in the users group the ability to run +# conda commands and install packages in the base (default) environment. +# Things installed into the default environment won't persist, but we prefer +# convenience in this case and try to make sure the user is aware of this +# with a message that's printed when the session starts. +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh \ + && echo "32d73e1bc33fda089d7cd9ef4c1be542616bd8e437d1f77afeeaf7afdb019787 Miniconda3-py310_23.1.0-1-Linux-x86_64.sh" \ + | sha256sum --check \ + && bash Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b -p /opt/miniconda3 \ + && rm Miniconda3-py310_23.1.0-1-Linux-x86_64.sh + +ENV PATH=/opt/miniconda3/bin:/opt/miniconda3/condabin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + +# Install a few additional utilities via pip +RUN /opt/miniconda3/bin/pip install --no-cache-dir \ + gpustat \ + jupyter \ + beaker-gantry \ + oocmap + +# Ensure users can modify their container environment. +RUN echo '%users ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers + +# Make the base image friendlier for interactive workloads. This makes things like the man command +# work. +RUN yes | unminimize + +# Install MLNX OFED user-space drivers +# See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile +ENV MOFED_VER 5.8-1.1.2.1 +ENV OS_VER ubuntu20.04 +ENV PLATFORM x86_64 +RUN wget --quiet https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VER}/MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ + tar -xvf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ + MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}/mlnxofedinstall --basic --user-space-only --without-fw-update -q && \ + rm -rf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM} && \ + rm MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz + +# The -l flag makes bash act as a login shell and load /etc/profile, etc. +ENTRYPOINT ["bash", "-l"] + +WORKDIR /stage/ + +# TODO When updating flash-attn or torch in the future, make sure to update the version in the requirements.txt file. +ENV HF_HUB_ENABLE_HF_TRANSFER=1 +COPY requirements.txt . +RUN pip install --upgrade pip "setuptools<70.0.0" wheel +# TODO, unpin setuptools when this issue in flash attention is resolved +RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 +RUN pip install packaging +RUN pip install flash-attn==2.5.9 --no-build-isolation +# for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 +# core is a dependency of ai2-olmo +# RUN pip install ai2-olmo-core==0.1.0 omegaconf +# RUN pip install ai2-olmo>=0.5.0 --no-deps +RUN pip install -r requirements-olmo.txt + +# NLTK download +RUN python -m nltk.downloader punkt +COPY open_instruct open_instruct +COPY oe-eval-internal oe-eval-internal + +# install the package in editable mode +COPY pyproject.toml . +RUN pip install -e . +COPY .git/ ./.git/ +COPY eval eval +COPY configs configs +COPY scripts scripts +COPY mason.py mason.py +RUN chmod +x scripts/* + +# for interactive session +RUN chmod -R 777 /stage/ diff --git a/requirements-olmo.txt b/requirements-olmo.txt new file mode 100644 index 000000000..611ebda53 --- /dev/null +++ b/requirements-olmo.txt @@ -0,0 +1,48 @@ +# TODO When updating flash-attn or torch in the future, make sure to update the version in the Dockerfile +torch==2.4.0 +ai2-olmo-core==0.1.0 +ai2-olmo>=0.5.0 +scipy +packaging +sentencepiece +datasets +deepspeed==0.14.4 +accelerate==0.31.0 +peft>=0.11.1 +bitsandbytes>=0.41.1 +evaluate>=0.4.0 +tokenizers==0.19.1 +protobuf +transformers==4.43.4 +openai>=1.0.0 +tiktoken +rouge_score +tensorboard +wandb +gradio>=3.50.2 +termcolor +jsonlines +unidic-lite +einops +flash-attn==2.5.8 # should really only be in dockerfile. Local env often doesn't have GPUs +fire +alpaca-eval==0.6.2 +# for human eval web app +flask +openpyxl +# for ifeval +nltk==3.8.1 +langdetect +immutabledict +# for math evaluations +antlr4-python3-runtime==4.11.0 +mpmath==1.3.0 +sympy==1.12.0 +# for linting +black +flake8 +isort +autoflake +pytest +hf_transfer +beaker-py \ No newline at end of file From 6ef706c10b9d6e9656db34d544cd57b5a40cb274 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 17:57:40 +0000 Subject: [PATCH 04/53] update --- Dockerfile.olmo | 8 ++++---- requirements-olmo.txt | 4 +--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/Dockerfile.olmo b/Dockerfile.olmo index 38f8a1db4..6259756ae 100644 --- a/Dockerfile.olmo +++ b/Dockerfile.olmo @@ -84,16 +84,16 @@ WORKDIR /stage/ # TODO When updating flash-attn or torch in the future, make sure to update the version in the requirements.txt file. ENV HF_HUB_ENABLE_HF_TRANSFER=1 -COPY requirements.txt . +COPY requirements-olmo.txt . RUN pip install --upgrade pip "setuptools<70.0.0" wheel # TODO, unpin setuptools when this issue in flash attention is resolved RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 RUN pip install packaging -RUN pip install flash-attn==2.5.9 --no-build-isolation +RUN pip install flash-attn==2.5.9.post1 --no-build-isolation # for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 # core is a dependency of ai2-olmo -# RUN pip install ai2-olmo-core==0.1.0 omegaconf -# RUN pip install ai2-olmo>=0.5.0 --no-deps +RUN pip install ai2-olmo-core==0.1.0 omegaconf +RUN pip install ai2-olmo>=0.5.0 --no-deps RUN pip install -r requirements-olmo.txt # NLTK download diff --git a/requirements-olmo.txt b/requirements-olmo.txt index 611ebda53..93367a9e0 100644 --- a/requirements-olmo.txt +++ b/requirements-olmo.txt @@ -1,7 +1,5 @@ # TODO When updating flash-attn or torch in the future, make sure to update the version in the Dockerfile torch==2.4.0 -ai2-olmo-core==0.1.0 -ai2-olmo>=0.5.0 scipy packaging sentencepiece @@ -24,7 +22,7 @@ termcolor jsonlines unidic-lite einops -flash-attn==2.5.8 # should really only be in dockerfile. Local env often doesn't have GPUs +flash-attn==2.5.9.post1 # should really only be in dockerfile. Local env often doesn't have GPUs fire alpaca-eval==0.6.2 # for human eval web app From a4797fb6eda9a1ad6fe8f9d643995e7cdfbfe8b5 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 20:55:35 +0000 Subject: [PATCH 05/53] debugging and minor fixes --- .github/workflows/push-image-olmo.yml | 1 - .github/workflows/push-image.yml | 1 - README.md | 2 +- configs/train_configs/sft/olmo_7b_0924.yaml | 2 +- open_instruct/dpo_tune.py | 3 ++- open_instruct/finetune.py | 3 ++- open_instruct/utils.py | 2 -- requirements-olmo.txt | 2 +- 8 files changed, 7 insertions(+), 9 deletions(-) diff --git a/.github/workflows/push-image-olmo.yml b/.github/workflows/push-image-olmo.yml index 8fdafc49f..28a8c3467 100644 --- a/.github/workflows/push-image-olmo.yml +++ b/.github/workflows/push-image-olmo.yml @@ -66,7 +66,6 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg CUDA=12.1.0 --build-arg \ TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 \ - --build-arg REQUIRE=requirements.txt \ -f Dockerfile.olmo . \ -t open_instruct_olmo diff --git a/.github/workflows/push-image.yml b/.github/workflows/push-image.yml index 4e29f4d28..f5a35fdc0 100644 --- a/.github/workflows/push-image.yml +++ b/.github/workflows/push-image.yml @@ -67,7 +67,6 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg CUDA=12.1.0 --build-arg \ TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 \ - --build-arg REQUIRE=requirements.txt . \ -t open_instruct diff --git a/README.md b/README.md index e0ba44914..701358cd0 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ pip install -r weight-diff-requirements.txt For a second installation strategy, if you'd like to *run experiments within a Docker environment*, you can create one using: ```bash -docker build --build-arg CUDA=12.1.0 --build-arg TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 --build-arg REQUIRE=requirements.txt . -t open_instruct +docker build --build-arg CUDA=12.1.0 --build-arg TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 . -t open_instruct # if you are interally at AI2, you can create an image like this: beaker image create open_instruct -n open_instruct -w ai2/$(whoami) diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index d76b2e560..317e106d9 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -1,6 +1,6 @@ model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan model_revision: main -use_flash_attn: true +use_flash_attn: false tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan use_slow_tokenizer: false # olmo models only use fast tokenizers dataset_name: allenai/tulu-v2-sft-mixture-olmo-2048 diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index 74bb522fd..13b7c2f44 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -504,6 +504,7 @@ def main(args: FlatArguments): if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -678,7 +679,7 @@ def load_model(): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 48ba726e1..211c43887 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -453,6 +453,7 @@ def main(args: FlatArguments): if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -649,7 +650,7 @@ def main(args: FlatArguments): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): # noqa # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 0a95b3f28..08046e17e 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -75,8 +75,6 @@ def check_hf_olmo_availability(return_version: bool = True) -> Union[dict, bool] except ImportError: package_exists = False - logger.debug(f"Detected {pkg_name} version: {package_version}") - if return_version: return { "available": package_exists, diff --git a/requirements-olmo.txt b/requirements-olmo.txt index 93367a9e0..1ec51fb2f 100644 --- a/requirements-olmo.txt +++ b/requirements-olmo.txt @@ -33,7 +33,7 @@ nltk==3.8.1 langdetect immutabledict # for math evaluations -antlr4-python3-runtime==4.11.0 +antlr4-python3-runtime==4.9.2 mpmath==1.3.0 sympy==1.12.0 # for linting From 56d6d8644c8e66e3586350ba963394b6e89a6def Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 21:32:33 +0000 Subject: [PATCH 06/53] nit and style --- open_instruct/dpo_tune.py | 4 +++- open_instruct/finetune.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index 13b7c2f44..cd0676648 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -679,7 +679,9 @@ def load_model(): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or ( + check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast) + ): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 211c43887..4e917ff77 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -650,7 +650,7 @@ def main(args: FlatArguments): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): # noqa + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): # noqa # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token From 50500ea5f6f995b18e791242649cb6cb14044319 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 22:13:51 +0000 Subject: [PATCH 07/53] fixes --- open_instruct/finetune.py | 4 ++-- open_instruct/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 4e917ff77..da6398f94 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -650,7 +650,7 @@ def main(args: FlatArguments): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, OLMoTokenizerFast): # noqa + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or (check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast)): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token @@ -1015,7 +1015,7 @@ def main(args: FlatArguments): if is_beaker_job() and accelerator.is_main_process: # dpo script only supports these two options right now for datasets if args.dataset_mixer: - dataset_list = args.dataset_mixer.keys() + dataset_list = list(args.dataset_mixer.keys()) elif args.dataset_mixer_list: dataset_list = args.dataset_mixer_list[::2] # even indices elif args.dataset_name: diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 08046e17e..c45e5873d 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -54,7 +54,7 @@ # ---------------------------------------------------------------------------- # Import utilities -def check_hf_olmo_availability(return_version: bool = True) -> Union[dict, bool]: +def check_hf_olmo_availability(return_version: bool = False) -> Union[dict, bool]: pkg_name = "hf_olmo" # Check if the package spec exists From 5eb61ccb38da34d43f40d02c97dcc6741e047c06 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 9 Sep 2024 23:16:00 +0000 Subject: [PATCH 08/53] add weka mounting --- configs/train_configs/sft/olmo_7b_0924.yaml | 10 ++++++---- scripts/submit_finetune_job.py | 12 +++++++++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index 317e106d9..74bd932c2 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -1,10 +1,12 @@ -model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf model_revision: main use_flash_attn: false -tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf use_slow_tokenizer: false # olmo models only use fast tokenizers -dataset_name: allenai/tulu-v2-sft-mixture-olmo-2048 -max_seq_length: 2048 +dataset_name: allenai/llama-3-tulu-v3.3-mix-preview +max_seq_length: 4096 preprocessing_num_workers: 128 per_device_train_batch_size: 1 # note, this is set up for 8 GPUs gradient_accumulation_steps: 16 diff --git a/scripts/submit_finetune_job.py b/scripts/submit_finetune_job.py index ef4e87138..37df9509a 100644 --- a/scripts/submit_finetune_job.py +++ b/scripts/submit_finetune_job.py @@ -25,6 +25,8 @@ def main(): parser.add_argument("--num_nodes", type=int, default=1, help="Number of nodes to use") parser.add_argument("--image", type=str, default="nathanl/open_instruct_auto", help="Beaker image to use.") parser.add_argument("--workspace", type=str, default="ai2/tulu-2-improvements", help="Beaker workspace to use.") + parser.add_argument("--mount_on_weka", type=str, default=None, help="Mount a Weka directory to the job") + parser.add_argument("--weka_mount_path", type=str, default="/models", help="Path to mount the Weka directory") # allow unknown args from CLI, use this to modify loaded config in bash scripts for sweeping # Note, can only override args in --config passed (not default FlatArguments class in open_instruct/utils.py) @@ -166,7 +168,7 @@ def parse_args(args): d['tasks'][0]['arguments'][0] = new_arguments # name and description - exp_name = f"open_instruct_finetune_{model_name}_{now}" + exp_name = f"open_instruct_finetune_{model_name}_{now}"[:128] d['description'] = exp_name d['tasks'][0]['name'] = exp_name @@ -220,6 +222,14 @@ def parse_args(args): d['tasks'][0]['envVars'].append({ 'name': 'WANDB_API_KEY', 'secret': f"{beaker_whoami}_WANDB_API_KEY" }) + + # Weka setting + if args.mount_on_weka: + if d['tasks'][0].get('datasets') is None: + d['tasks'][0]['datasets'] = [] + d['tasks'][0]['datasets'].append({ + 'mountPath': f"{args.weka_mount_path}", 'source': {'weka': f"{args.mount_on_weka}"} + }) # optionally, print to debug config print(d) From b1344102cd0738e0579d6623c32793cb3dd02f4f Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 10 Sep 2024 00:46:57 +0000 Subject: [PATCH 09/53] up --- configs/train_configs/sft/olmo_7b_0924.yaml | 8 +++---- .../train_configs/sft/olmo_7b_0924_fw2.yaml | 24 +++++++++++++++++++ scripts/submit_finetune_job.py | 4 ++-- 3 files changed, 30 insertions(+), 6 deletions(-) create mode 100644 configs/train_configs/sft/olmo_7b_0924_fw2.yaml diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index 74bd932c2..91d72a1ec 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -1,21 +1,21 @@ # model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -model_name_or_path: /oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf model_revision: main use_flash_attn: false # tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -tokenizer_name: /oe-training-default/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf use_slow_tokenizer: false # olmo models only use fast tokenizers dataset_name: allenai/llama-3-tulu-v3.3-mix-preview max_seq_length: 4096 preprocessing_num_workers: 128 per_device_train_batch_size: 1 # note, this is set up for 8 GPUs -gradient_accumulation_steps: 16 +gradient_accumulation_steps: 8 learning_rate: 2.0e-06 lr_scheduler_type: linear warmup_ratio: 0.03 weight_decay: 0.0 num_train_epochs: 3 -output_dir: output/olmo_instruct/ +output_dir: /output/olmo_instruct/ with_tracking: true report_to: - wandb diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2.yaml new file mode 100644 index 000000000..ff376aebf --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924_fw2.yaml @@ -0,0 +1,24 @@ +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +model_revision: main +use_flash_attn: false +# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_name: allenai/llama-3-tulu-v3.3-mix-preview +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 # note, this is set up for 8 GPUs +gradient_accumulation_steps: 4 # designed for 4 nodes +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/olmo_instruct/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/scripts/submit_finetune_job.py b/scripts/submit_finetune_job.py index 37df9509a..9991003f8 100644 --- a/scripts/submit_finetune_job.py +++ b/scripts/submit_finetune_job.py @@ -26,7 +26,7 @@ def main(): parser.add_argument("--image", type=str, default="nathanl/open_instruct_auto", help="Beaker image to use.") parser.add_argument("--workspace", type=str, default="ai2/tulu-2-improvements", help="Beaker workspace to use.") parser.add_argument("--mount_on_weka", type=str, default=None, help="Mount a Weka directory to the job") - parser.add_argument("--weka_mount_path", type=str, default="/models", help="Path to mount the Weka directory") + parser.add_argument("--weka_mount_path", type=str, default="/adapt-data", help="Path to mount the Weka directory") # allow unknown args from CLI, use this to modify loaded config in bash scripts for sweeping # Note, can only override args in --config passed (not default FlatArguments class in open_instruct/utils.py) @@ -173,7 +173,7 @@ def parse_args(args): d['tasks'][0]['name'] = exp_name # add cluster-specific env vars - if args.cluster == "ai2/jupiter-cirrascale-2": + if args.cluster == "ai2/jupiter-cirrascale-2" and args.num_nodes > 1: d['tasks'][0]['envVars'] += [ { "name": "NCCL_SOCKET_IFNAME", From d007da01bd5d75489fdf92d8d73c9c792eea3dc0 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 10 Sep 2024 16:07:33 +0000 Subject: [PATCH 10/53] add hardcode flash_attn --- open_instruct/finetune.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index da6398f94..f0aa3cab2 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -616,6 +616,7 @@ def main(args: FlatArguments): trust_remote_code=args.trust_remote_code, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + flash_attention=True if args.use_flash_attn else False, # TODO remove with ai2-olmo > 0.5.0 revision=args.model_revision, token=os.getenv("HF_TOKEN", None), ) @@ -628,6 +629,7 @@ def main(args: FlatArguments): low_cpu_mem_usage=args.low_cpu_mem_usage, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + flash_attention=True if args.use_flash_attn else False, # TODO remove with ai2-olmo > 0.5.0 revision=args.model_revision, token=os.getenv("HF_TOKEN", None), ) From 5d82ea2bd5293ab19c7c8346aba20031c01f5b6e Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 10 Sep 2024 16:28:16 +0000 Subject: [PATCH 11/53] tweaks --- open_instruct/finetune.py | 40 +++++++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index f0aa3cab2..b2e9ed283 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -453,7 +453,7 @@ def main(args: FlatArguments): if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models import hf_olmo # noqa - from hf_olmo import OLMoTokenizerFast + from hf_olmo import OLMoTokenizerFast, OLMoConfig # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -616,23 +616,35 @@ def main(args: FlatArguments): trust_remote_code=args.trust_remote_code, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - flash_attention=True if args.use_flash_attn else False, # TODO remove with ai2-olmo > 0.5.0 revision=args.model_revision, token=os.getenv("HF_TOKEN", None), ) else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - flash_attention=True if args.use_flash_attn else False, # TODO remove with ai2-olmo > 0.5.0 - revision=args.model_revision, - token=os.getenv("HF_TOKEN", None), - ) + if (check_hf_olmo_availability() and isinstance(config, OLMoConfig)): + # handles flash_attn in config. TODO remove on ai2-olmo > 0.5.0 + config.flash_attention = args.use_flash_attn + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=args.low_cpu_mem_usage, + torch_dtype=torch.bfloat16, + revision=args.model_revision, + token=os.getenv("HF_TOKEN", None), + ) + else: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=args.low_cpu_mem_usage, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + revision=args.model_revision, + token=os.getenv("HF_TOKEN", None), + ) else: logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) From 11397bf76e67d718da72ecd77856722315db338e Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 10 Sep 2024 17:07:05 +0000 Subject: [PATCH 12/53] making it work nicely --- configs/train_configs/sft/olmo_7b_0924.yaml | 2 +- .../train_configs/sft/olmo_7b_0924_fw2.yaml | 2 +- .../sft/olmo_7b_0924_fw2_permissive.yaml | 31 +++++++++++++++++++ open_instruct/finetune.py | 9 ++++-- open_instruct/utils.py | 3 +- 5 files changed, 40 insertions(+), 7 deletions(-) create mode 100644 configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index 91d72a1ec..b08a6ac2a 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -1,7 +1,7 @@ # model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf model_revision: main -use_flash_attn: false +use_flash_attn: true # tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf use_slow_tokenizer: false # olmo models only use fast tokenizers diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2.yaml index ff376aebf..3a40cd915 100644 --- a/configs/train_configs/sft/olmo_7b_0924_fw2.yaml +++ b/configs/train_configs/sft/olmo_7b_0924_fw2.yaml @@ -1,7 +1,7 @@ # model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf model_revision: main -use_flash_attn: false +use_flash_attn: true # tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf use_slow_tokenizer: false # olmo models only use fast tokenizers diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml new file mode 100644 index 000000000..72d3637b6 --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml @@ -0,0 +1,31 @@ +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +model_revision: main +use_flash_attn: true +# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_mixer: + ai2-adapt-dev/metamath-qa-reformat: 1.0 # MIT License + nvidia/Daring-Anteater: 1.0 # CC BY 4.0, + natolambert/tulu-v2-sft-mixture-flan: 1.0 # FLAN Apache 2.0 + natolambert/tulu-v2-sft-mixture-cot: 1.0 # FLAN Apache 2.0 + Open-Orca/OpenOrca: .02 # MIT + allenai/openassistant-guanaco-reformatted: 1.0 # Apache 2.0 + ai2-adapt-dev/codefeedback-single-turn-reformat-magicoder: 1.0 # MIT MagiCoder section of CodeFeedback +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 # note, this is set up for 8 GPUs +gradient_accumulation_steps: 4 # designed for 4 nodes +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/olmo_instruct/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index b2e9ed283..5842e6811 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -453,7 +453,7 @@ def main(args: FlatArguments): if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models import hf_olmo # noqa - from hf_olmo import OLMoTokenizerFast, OLMoConfig + from hf_olmo import OLMoConfig, OLMoTokenizerFast # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -620,7 +620,8 @@ def main(args: FlatArguments): token=os.getenv("HF_TOKEN", None), ) else: - if (check_hf_olmo_availability() and isinstance(config, OLMoConfig)): + if check_hf_olmo_availability() and isinstance(config, OLMoConfig): + logger.info("Temporary loading for recent OLMo Models") # handles flash_attn in config. TODO remove on ai2-olmo > 0.5.0 config.flash_attention = args.use_flash_attn model = AutoModelForCausalLM.from_pretrained( @@ -664,7 +665,9 @@ def main(args: FlatArguments): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast) or (check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast)): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or ( + check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast) + ): # OLMo newer models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/utils.py b/open_instruct/utils.py index c45e5873d..b7fd1ae1a 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -70,10 +70,9 @@ def check_hf_olmo_availability(return_version: bool = False) -> Union[dict, bool try: package = importlib.import_module(pkg_name) package_version = getattr(package, "__version__", "N/A") - if package_version == "N/A": - package_exists = False except ImportError: package_exists = False + package_version = "N/A" if return_version: return { From ee5c7d22cce9aebfc46f5ab35d7eaa3f47e5d944 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 11 Sep 2024 22:13:59 +0000 Subject: [PATCH 13/53] clean --- Dockerfile.olmo | 3 +- .../default_finetune_offloading.yaml | 69 +++++++++++++++++++ configs/train_configs/sft/olmo_7b_0924.yaml | 10 ++- .../train_configs/sft/olmo_7b_0924_fw2.yaml | 24 ------- .../sft/olmo_7b_0924_fw2_permissive.yaml | 13 ++-- open_instruct/finetune.py | 52 +++++++------- 6 files changed, 110 insertions(+), 61 deletions(-) create mode 100644 configs/beaker_configs/default_finetune_offloading.yaml delete mode 100644 configs/train_configs/sft/olmo_7b_0924_fw2.yaml diff --git a/Dockerfile.olmo b/Dockerfile.olmo index 6259756ae..1f8da727d 100644 --- a/Dockerfile.olmo +++ b/Dockerfile.olmo @@ -93,7 +93,8 @@ RUN pip install flash-attn==2.5.9.post1 --no-build-isolation # for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 # core is a dependency of ai2-olmo RUN pip install ai2-olmo-core==0.1.0 omegaconf -RUN pip install ai2-olmo>=0.5.0 --no-deps +# RUN pip install ai2-olmo>=0.5.0 --no-deps +RUN pip install git+https://github.com/allenai/OLMo.git@shanea/hf-olmo-gradient-checkpointing --no-deps RUN pip install -r requirements-olmo.txt # NLTK download diff --git a/configs/beaker_configs/default_finetune_offloading.yaml b/configs/beaker_configs/default_finetune_offloading.yaml new file mode 100644 index 000000000..4722b4e13 --- /dev/null +++ b/configs/beaker_configs/default_finetune_offloading.yaml @@ -0,0 +1,69 @@ +version: v2 +description: open-instruct-finetune +budget: ai2/oe-adapt +tasks: + - name: open-instruct-finetune + image: + beaker: nathanl/open_instruct_auto + command: [ + '/bin/sh', '-c' + ] + arguments: ['PYTHONPATH="/stage:$PYTHONPATH" accelerate launch + --mixed_precision bf16 + --num_machines 1 + --num_processes 4 + --use_deepspeed + --deepspeed_config_file configs/ds_configs/stage3_offloading_accelerate.conf + open_instruct/finetune.py + --model_name_or_path /hf_llama_models + --use_flash_attn + --tokenizer_name /hf_llama_models + --max_seq_length 2048 + --preprocessing_num_workers 16 + --per_device_train_batch_size 2 + --gradient_accumulation_steps 16 + --learning_rate 2e-5 + --lr_scheduler_type linear + --warmup_ratio 0.03 + --weight_decay 0. + --num_train_epochs 2 + --output_dir /output/ + --with_tracking + --report_to tensorboard + --logging_steps 1 + '] + envVars: + - name: CUDA_DEVICE_ORDER + value: PCI_BUS_ID + - name: TRANSFORMERS_CACHE + value: ./cache/ + - name: WANDB_API_KEY + secret: WANDB_API_KEY + - name: WANDB_PROJECT + value: open-instruct + - name: WANDB_WATCH + value: false + - name: WANDB_LOG_MODEL + value: false + - name: WANDB_DISABLED + value: true + - name: HF_TOKEN + secret: HF_TOKEN + # datasets: # example for how to include datasets in mounting + # - mountPath: /data + # source: + # beaker: Yizhongw03/processed_open_instruct_data + # - mountPath: /mmlu + # source: + # beaker: Yizhongw03/mmlu + # - mountPath: /hf_llama_models + # source: + # beaker: Yizhongw03/hf_llama_model_7B + result: + path: /output + resources: + gpuCount: 4 + context: + cluster: ai2/allennlp-cirrascale + priority: high + preemptible: false \ No newline at end of file diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index b08a6ac2a..7f62e1dca 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -1,15 +1,13 @@ -# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf +model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan model_revision: main use_flash_attn: true -# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw3/step11931-hf +tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan use_slow_tokenizer: false # olmo models only use fast tokenizers dataset_name: allenai/llama-3-tulu-v3.3-mix-preview max_seq_length: 4096 preprocessing_num_workers: 128 -per_device_train_batch_size: 1 # note, this is set up for 8 GPUs -gradient_accumulation_steps: 8 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 # should run with this set to 16 for 1 node only learning_rate: 2.0e-06 lr_scheduler_type: linear warmup_ratio: 0.03 diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2.yaml deleted file mode 100644 index 3a40cd915..000000000 --- a/configs/train_configs/sft/olmo_7b_0924_fw2.yaml +++ /dev/null @@ -1,24 +0,0 @@ -# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf -model_revision: main -use_flash_attn: true -# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf -use_slow_tokenizer: false # olmo models only use fast tokenizers -dataset_name: allenai/llama-3-tulu-v3.3-mix-preview -max_seq_length: 4096 -preprocessing_num_workers: 128 -per_device_train_batch_size: 1 # note, this is set up for 8 GPUs -gradient_accumulation_steps: 4 # designed for 4 nodes -learning_rate: 2.0e-06 -lr_scheduler_type: linear -warmup_ratio: 0.03 -weight_decay: 0.0 -num_train_epochs: 3 -output_dir: /output/olmo_instruct/ -with_tracking: true -report_to: - - wandb -logging_steps: 1 -checkpointing_steps: epoch -add_bos: true \ No newline at end of file diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml index 72d3637b6..4539f713a 100644 --- a/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml @@ -7,22 +7,27 @@ tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from use_slow_tokenizer: false # olmo models only use fast tokenizers dataset_mixer: ai2-adapt-dev/metamath-qa-reformat: 1.0 # MIT License - nvidia/Daring-Anteater: 1.0 # CC BY 4.0, natolambert/tulu-v2-sft-mixture-flan: 1.0 # FLAN Apache 2.0 natolambert/tulu-v2-sft-mixture-cot: 1.0 # FLAN Apache 2.0 - Open-Orca/OpenOrca: .02 # MIT allenai/openassistant-guanaco-reformatted: 1.0 # Apache 2.0 ai2-adapt-dev/codefeedback-single-turn-reformat-magicoder: 1.0 # MIT MagiCoder section of CodeFeedback + ai2-adapt-dev/aya_dataset-reformat: 1.0 # Apache 2.0 + ai2-adapt-dev/SlimOrca-reformat: 0.25 # MIT License + ai2-adapt-dev/Daring-Anteater-reformat: 1.0 # CC BY 4.0 + ai2-adapt-dev/WebInstructSub-reformat-apache: 0.1 # Apache 2.0 + ai2-adapt-dev/Table-GPT-All-train: 0.5 # MIT max_seq_length: 4096 preprocessing_num_workers: 128 -per_device_train_batch_size: 1 # note, this is set up for 8 GPUs +per_device_train_batch_size: 1 gradient_accumulation_steps: 4 # designed for 4 nodes +# gradient_accumulation_steps: 16 # designed for 1 nodes +gradient_checkpointing: true learning_rate: 2.0e-06 lr_scheduler_type: linear warmup_ratio: 0.03 weight_decay: 0.0 num_train_epochs: 3 -output_dir: /output/olmo_instruct/ +output_dir: /output/ with_tracking: true report_to: - wandb diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 5842e6811..f771790f8 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -620,32 +620,32 @@ def main(args: FlatArguments): token=os.getenv("HF_TOKEN", None), ) else: - if check_hf_olmo_availability() and isinstance(config, OLMoConfig): - logger.info("Temporary loading for recent OLMo Models") - # handles flash_attn in config. TODO remove on ai2-olmo > 0.5.0 - config.flash_attention = args.use_flash_attn - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - torch_dtype=torch.bfloat16, - revision=args.model_revision, - token=os.getenv("HF_TOKEN", None), - ) - else: - model = AutoModelForCausalLM.from_pretrained( - args.model_name_or_path, - from_tf=bool(".ckpt" in args.model_name_or_path), - config=config, - trust_remote_code=args.trust_remote_code, - low_cpu_mem_usage=args.low_cpu_mem_usage, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", - revision=args.model_revision, - token=os.getenv("HF_TOKEN", None), - ) + # if check_hf_olmo_availability() and isinstance(config, OLMoConfig): + # logger.info("Temporary loading for recent OLMo Models") + # # handles flash_attn in config. TODO remove on ai2-olmo > 0.5.0 + # config.flash_attention = args.use_flash_attn + # model = AutoModelForCausalLM.from_pretrained( + # args.model_name_or_path, + # from_tf=bool(".ckpt" in args.model_name_or_path), + # config=config, + # trust_remote_code=args.trust_remote_code, + # low_cpu_mem_usage=args.low_cpu_mem_usage, + # torch_dtype=torch.bfloat16, + # revision=args.model_revision, + # token=os.getenv("HF_TOKEN", None), + # ) + # else: + model = AutoModelForCausalLM.from_pretrained( + args.model_name_or_path, + from_tf=bool(".ckpt" in args.model_name_or_path), + config=config, + trust_remote_code=args.trust_remote_code, + low_cpu_mem_usage=args.low_cpu_mem_usage, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2" if args.use_flash_attn else "eager", + revision=args.model_revision, + token=os.getenv("HF_TOKEN", None), + ) else: logger.info("Training new model from scratch") model = AutoModelForCausalLM.from_config(config) From fd16aeef6e97de44b938b424db5ccab460b53233 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 11 Sep 2024 22:14:44 +0000 Subject: [PATCH 14/53] clean --- open_instruct/finetune.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index f771790f8..488d4e551 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -453,7 +453,7 @@ def main(args: FlatArguments): if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models import hf_olmo # noqa - from hf_olmo import OLMoConfig, OLMoTokenizerFast + from hf_olmo import OLMoTokenizerFast # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers @@ -620,21 +620,6 @@ def main(args: FlatArguments): token=os.getenv("HF_TOKEN", None), ) else: - # if check_hf_olmo_availability() and isinstance(config, OLMoConfig): - # logger.info("Temporary loading for recent OLMo Models") - # # handles flash_attn in config. TODO remove on ai2-olmo > 0.5.0 - # config.flash_attention = args.use_flash_attn - # model = AutoModelForCausalLM.from_pretrained( - # args.model_name_or_path, - # from_tf=bool(".ckpt" in args.model_name_or_path), - # config=config, - # trust_remote_code=args.trust_remote_code, - # low_cpu_mem_usage=args.low_cpu_mem_usage, - # torch_dtype=torch.bfloat16, - # revision=args.model_revision, - # token=os.getenv("HF_TOKEN", None), - # ) - # else: model = AutoModelForCausalLM.from_pretrained( args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), From fa447f8e8c3e6ef0e6cfe6958398e6187054d203 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 11 Sep 2024 22:16:51 +0000 Subject: [PATCH 15/53] clean --- Dockerfile | 4 ---- 1 file changed, 4 deletions(-) diff --git a/Dockerfile b/Dockerfile index 9f4e1fb00..dd6b95a97 100644 --- a/Dockerfile +++ b/Dockerfile @@ -90,10 +90,6 @@ RUN pip install --upgrade pip "setuptools<70.0.0" wheel RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation -# for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 -# core is a dependency of ai2-olmo -# RUN pip install ai2-olmo-core==0.1.0 omegaconf -# RUN pip install ai2-olmo>=0.5.0 --no-deps RUN pip install -r requirements.txt # NLTK download From e02e98547eb0f929837002c7191e9ee29f581117 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 11 Sep 2024 22:26:17 +0000 Subject: [PATCH 16/53] up --- Dockerfile.olmo | 1 + 1 file changed, 1 insertion(+) diff --git a/Dockerfile.olmo b/Dockerfile.olmo index 1f8da727d..040abad5c 100644 --- a/Dockerfile.olmo +++ b/Dockerfile.olmo @@ -94,6 +94,7 @@ RUN pip install flash-attn==2.5.9.post1 --no-build-isolation # core is a dependency of ai2-olmo RUN pip install ai2-olmo-core==0.1.0 omegaconf # RUN pip install ai2-olmo>=0.5.0 --no-deps +# TODO Update Once this is merged https://github.com/allenai/OLMo/pull/719, then next release RUN pip install git+https://github.com/allenai/OLMo.git@shanea/hf-olmo-gradient-checkpointing --no-deps RUN pip install -r requirements-olmo.txt From a0a32bf0b4f73d3756b767d38c947dbf9a422cdb Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Wed, 11 Sep 2024 23:15:32 +0000 Subject: [PATCH 17/53] no longer install from branch --- Dockerfile.olmo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Dockerfile.olmo b/Dockerfile.olmo index 040abad5c..10e2d9f12 100644 --- a/Dockerfile.olmo +++ b/Dockerfile.olmo @@ -95,7 +95,7 @@ RUN pip install flash-attn==2.5.9.post1 --no-build-isolation RUN pip install ai2-olmo-core==0.1.0 omegaconf # RUN pip install ai2-olmo>=0.5.0 --no-deps # TODO Update Once this is merged https://github.com/allenai/OLMo/pull/719, then next release -RUN pip install git+https://github.com/allenai/OLMo.git@shanea/hf-olmo-gradient-checkpointing --no-deps +RUN pip install git+https://github.com/allenai/OLMo.git@47f8f5abb40eb100c6623be12e1648c841b2ab99 --no-deps RUN pip install -r requirements-olmo.txt # NLTK download From 503e61e7d4d3b0eabd29c2015ac1fbe491281333 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 16 Sep 2024 19:49:46 +0000 Subject: [PATCH 18/53] fixes --- configs/train_configs/sft/olmo_7b_0924.yaml | 2 +- .../sft/olmo_7b_0924_fw2_tulu_v3.4.yaml | 26 +++++++++++++++++++ open_instruct/mix_data.py | 6 +++-- 3 files changed, 31 insertions(+), 3 deletions(-) create mode 100644 configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml index 7f62e1dca..e8264bd0a 100644 --- a/configs/train_configs/sft/olmo_7b_0924.yaml +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -13,7 +13,7 @@ lr_scheduler_type: linear warmup_ratio: 0.03 weight_decay: 0.0 num_train_epochs: 3 -output_dir: /output/olmo_instruct/ +output_dir: /output/ with_tracking: true report_to: - wandb diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml new file mode 100644 index 000000000..3efd900c6 --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml @@ -0,0 +1,26 @@ +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +model_revision: main +use_flash_attn: true +# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_name: allenai/tulu-v3.4-mix-preview +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 4 # designed for 4 nodes +# gradient_accumulation_steps: 16 # designed for 1 nodes +gradient_checkpointing: true +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/open_instruct/mix_data.py b/open_instruct/mix_data.py index 6e7260c8f..c34bf0274 100644 --- a/open_instruct/mix_data.py +++ b/open_instruct/mix_data.py @@ -15,11 +15,13 @@ # limitations under the License. # script for mixing and saving data -from .utils import ArgumentParserPlus, FlatArguments, get_datasets +from open_instruct.utils import ArgumentParserPlus, get_datasets +from open_instruct.finetune import FlatArguments # Run as module for local imports, e.g.: -# python -m open_instruct.mix_data configs/train_configs/sft/default.yaml --dataset_mix_dir=output/tmp/ +# python open_instruct/mix_data.py configs/train_configs/sft/tulu3_8b_preview_mix_v3.4.yaml --dataset_mix_dir=output/tmp/ # can pass --save_to_hub=allenai/tulu-v3.1-mix-preview-4096-OLMoE +# note that = is needed with our argparser def main(): From 982726435268af12ade94d8713a39e747e0a413b Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 24 Sep 2024 00:52:04 +0000 Subject: [PATCH 19/53] dpo config --- configs/train_configs/dpo/olmo_7b_0924.yaml | 29 +++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 configs/train_configs/dpo/olmo_7b_0924.yaml diff --git a/configs/train_configs/dpo/olmo_7b_0924.yaml b/configs/train_configs/dpo/olmo_7b_0924.yaml new file mode 100644 index 000000000..3028bfca8 --- /dev/null +++ b/configs/train_configs/dpo/olmo_7b_0924.yaml @@ -0,0 +1,29 @@ +model_name_or_path: /model +model_revision: main +use_flash_attn: true +gradient_checkpointing: true +dataset_mixer: + allenai/ultrafeedback_binarized_cleaned_train: 1.0 + ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0 + ai2-adapt-dev/WildChat-prefs-280824: 1.0 + allenai/tulu-3-hardcoded-preferences: 1.0 +tokenizer_name: /model +use_slow_tokenizer: true +max_seq_length: 2048 +preprocessing_num_workers: 16 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128 +learning_rate: 5.0e-7 +lr_scheduler_type: linear +warmup_ratio: 0.1 +weight_decay: 0.0 +num_train_epochs: 1 +output_dir: /output +with_tracking: true +report_to: + - wandb +logging_steps: 1 +use_lora: false +dpo_loss_type: dpo_norm +dpo_beta: 5 +checkpointing_steps: 1000 \ No newline at end of file From b175717e14ea64506485f7261cc1d02d921d5c89 Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Tue, 24 Sep 2024 16:30:10 +0000 Subject: [PATCH 20/53] temp olmo changes --- Dockerfile.olmo | 4 ++++ configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml | 4 +++- scripts/eval/oe-eval.sh | 5 ++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/Dockerfile.olmo b/Dockerfile.olmo index 10e2d9f12..40ef4377a 100644 --- a/Dockerfile.olmo +++ b/Dockerfile.olmo @@ -98,6 +98,10 @@ RUN pip install ai2-olmo-core==0.1.0 omegaconf RUN pip install git+https://github.com/allenai/OLMo.git@47f8f5abb40eb100c6623be12e1648c841b2ab99 --no-deps RUN pip install -r requirements-olmo.txt +RUN pip install git+https://github.com/AkshitaB/vllm.git +RUN pip install vllm-flash-attn + + # NLTK download RUN python -m nltk.downloader punkt COPY open_instruct open_instruct diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml index 3efd900c6..491fc4502 100644 --- a/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml @@ -9,7 +9,8 @@ dataset_name: allenai/tulu-v3.4-mix-preview max_seq_length: 4096 preprocessing_num_workers: 128 per_device_train_batch_size: 1 -gradient_accumulation_steps: 4 # designed for 4 nodes +# gradient_accumulation_steps: 4 # designed for 4 nodes +gradient_accumulation_steps: 8 # designed for 2 nodes # gradient_accumulation_steps: 16 # designed for 1 nodes gradient_checkpointing: true learning_rate: 2.0e-06 @@ -19,6 +20,7 @@ weight_decay: 0.0 num_train_epochs: 3 output_dir: /output/ with_tracking: true +reduce_loss: mean report_to: - wandb logging_steps: 1 diff --git a/scripts/eval/oe-eval.sh b/scripts/eval/oe-eval.sh index e250faaef..3cb3a68e1 100755 --- a/scripts/eval/oe-eval.sh +++ b/scripts/eval/oe-eval.sh @@ -92,7 +92,10 @@ TASKS=( "alpaca_eval_v2::tulu" "truthfulqa::tulu" ) -MODEL_TYPE="--model-type vllm" +# For models without VLLM (experimental architectures) +# comment out the VLLM arg and set GPU_COUNT_OTHER to 1 +# also consider lowering the batch size (VLLM arg), maybe to 5, VLLM handles it differently +# MODEL_TYPE="--model-type vllm" BATCH_SIZE_VLLM=10000 BATCH_SIZE_OTHER=1 GPU_COUNT=1 From b444e800e0fe9eb09c7cbeb415d7f769dcf9a215 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 1 Oct 2024 09:31:38 -0700 Subject: [PATCH 21/53] first pass --- open_instruct/dataset_processor.py | 5 +++++ open_instruct/model_utils.py | 27 +++++++++++++++++++++++++++ open_instruct/ppo_vllm_thread.py | 12 ++++++++++++ 3 files changed, 44 insertions(+) diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 9df94a2cf..018b2735c 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -51,6 +51,7 @@ ATTENTION_MASK_REJECTED_KEY = "attention_mask_rejected" INPUT_IDS_PROMPT_KEY = "input_ids_prompt" ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt" +GROUND_TRUTHS_KEY = "ground_truth" # NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only # also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY` @@ -162,6 +163,9 @@ class DatasetConfig: # columns names for SFT dataset sft_messages_key: str = SFT_MESSAGE_KEY + # columns name for the ground truth + ground_truths_key: str = GROUND_TRUTHS_KEY + # columns names for binary dataset binary_messages_key: str = SFT_MESSAGE_KEY label: str = BINARY_LABEL_KEY @@ -372,6 +376,7 @@ def tokenize_fn(row): if self.config.train_only_on_prompt: labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) row[LABELS_KEY] = labels + row[GROUND_TRUTHS_KEY] = row[self.config.ground_truths_key] return row return dataset.map( diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index c858c3e53..6c465a044 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -205,6 +205,33 @@ def get_reward( sequence_lengths, ) +import re + +def apply_verifiable_reward( + query_responses: torch.Tensor, tokenizer, ground_truth: str +): + # decode the responses + decoded_responses = tokenizer.batch_decode(query_responses, skip_special_tokens=True) + # compare with ground truth. + # use same logic as in gsm8k evaluation + predictions = [] + rewards = [] + for response in decoded_responses: + # replace numbers like `x,xxx` with `xxxx` + response = re.sub(r"(\d),(\d)", r"\1\2", response) + numbers = re.findall(r"[-+]?\d*\.\d+|\d+", response) + if numbers: + predictions.append(numbers[-1]) + else: + predictions.append(response) + for prediction, gt in zip(predictions, ground_truth): + if prediction == gt: + print("Ground truth matched with prediction, giving a bonus reward 🤗") + rewards.append(10) + else: + rewards.append(0) + return torch.tensor(rewards, device=query_responses.device) + def forward( model: torch.nn.Module, diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 75e06d134..b9229a1e8 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -38,6 +38,7 @@ from open_instruct.dataset_processor import ( CHAT_TEMPLATES, INPUT_IDS_PROMPT_KEY, + GROUND_TRUTHS_KEY, DatasetConfig, SFTDatasetProcessor, SimpleGenerateCollator, @@ -50,6 +51,7 @@ first_true_indices, forward, get_reward, + get_verifiable_reward, prepare_deepspeed, print_rich_single_line_metrics, print_rich_table, @@ -680,6 +682,7 @@ def repeat_generator(): start_time = time.time() data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next) for _ in range(1, resume_training_step): # we didn't store scheduler state @@ -689,6 +692,7 @@ def repeat_generator(): episode += args.batch_size scheduler.step() queries = queries_next + ground_truths = ground_truths_next if ph.preemptied: break @@ -717,6 +721,7 @@ def repeat_generator(): if training_step != 1: data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) else: if training_step != 1: @@ -724,8 +729,10 @@ def repeat_generator(): # we also set to use `queries = queries_next` immediately data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) queries = queries_next + ground_truths_next = ground_truths_next training_time_start = time.time() with torch.no_grad(): @@ -791,6 +798,11 @@ def repeat_generator(): _, score, _ = get_reward( reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) + # also apply verifiable reward + verifiable_reward = get_verifiable_reward( + postprocessed_query_response, tokenizer, ground_truths + ) + score += verifiable_reward unwrapped_value_model = accelerator.unwrap_model(model).value_model full_value, _, _ = get_reward( unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length From f0569a35ceae3e4dded4ed5794aa8bdabd603173 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 2 Oct 2024 18:23:59 +0000 Subject: [PATCH 22/53] fix spelling, ground truth stuff --- docs/algorithms/online_dpo.md | 16 +++---- docs/algorithms/ppo.md | 14 +++--- docs/algorithms/reward_modeling.md | 14 +++--- open_instruct/dataset_processor.py | 57 +++++++++++++++++++++---- open_instruct/online_dpo_vllm_thread.py | 10 ++--- open_instruct/ppo_vllm_thread.py | 48 ++++++++++++--------- test.sh | 30 ++++++------- 7 files changed, 120 insertions(+), 69 deletions(-) diff --git a/docs/algorithms/online_dpo.md b/docs/algorithms/online_dpo.md index 354fe9fe9..3f7ad1b90 100644 --- a/docs/algorithms/online_dpo.md +++ b/docs/algorithms/online_dpo.md @@ -31,7 +31,7 @@ python open_instruct/online_dpo_vllm_thread.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_stop_penalty \ @@ -43,7 +43,7 @@ python open_instruct/online_dpo_vllm_thread.py \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 64 \ --max_token_length 2048 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --beta 0.1 \ --output_dir models/rm/rm_sentiment_1b \ @@ -64,7 +64,7 @@ python open_instruct/online_dpo_vllm_thread.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_stop_penalty \ @@ -76,7 +76,7 @@ python open_instruct/online_dpo_vllm_thread.py \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 64 \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --beta 0.1 \ --output_dir models/rm/rm_sentiment_1b \ @@ -112,7 +112,7 @@ python mason.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_vllm_thread_tldr \ --per_device_train_batch_size 16 \ @@ -158,7 +158,7 @@ python mason.py \ --dataset_eval_mixer '{"HuggingFaceH4/no_robots": 1.0}' \ --dataset_eval_splits test \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 8e-7 \ --output_dir /output/ \ --chat_template tulu \ @@ -211,7 +211,7 @@ python mason.py \ --dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \ --dataset_eval_splits test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 8e-7 \ --output_dir /output/ \ --chat_template tulu \ @@ -265,7 +265,7 @@ python mason.py \ --dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \ --dataset_eval_splits test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 8e-7 \ --output_dir /output/ \ --chat_template tulu \ diff --git a/docs/algorithms/ppo.md b/docs/algorithms/ppo.md index 69305fb1d..1b4985e00 100644 --- a/docs/algorithms/ppo.md +++ b/docs/algorithms/ppo.md @@ -31,7 +31,7 @@ python open_instruct/ppo_vllm_thread.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_stop_penalty \ @@ -43,7 +43,7 @@ python open_instruct/ppo_vllm_thread.py \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 64 \ --max_token_length 2048 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --beta 0.1 \ --output_dir models/rm/rm_sentiment_1b \ @@ -64,7 +64,7 @@ python open_instruct/ppo_vllm_thread.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_stop_penalty \ @@ -76,7 +76,7 @@ python open_instruct/ppo_vllm_thread.py \ --per_device_eval_batch_size 2 \ --gradient_accumulation_steps 64 \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --beta 0.1 \ --output_dir models/rm/rm_sentiment_1b \ @@ -112,7 +112,7 @@ python mason.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo_vllm_thread_tldr \ --per_device_train_batch_size 16 \ @@ -158,7 +158,7 @@ python mason.py \ --dataset_eval_mixer '{"HuggingFaceH4/no_robots": 1.0}' \ --dataset_eval_splits test \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 8e-7 \ --output_dir /output/ \ --chat_template tulu \ @@ -211,7 +211,7 @@ python mason.py \ --dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \ --dataset_eval_splits test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 8e-7 \ --output_dir /output/ \ --chat_template tulu \ diff --git a/docs/algorithms/reward_modeling.md b/docs/algorithms/reward_modeling.md index 53e1c6d3c..e675e4b05 100644 --- a/docs/algorithms/reward_modeling.md +++ b/docs/algorithms/reward_modeling.md @@ -40,7 +40,7 @@ python -i open_instruct/reward_modeling.py \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 32 \ --max_token_length 1024 \ - --max_prompt_token_lenth 1024 \ + --max_prompt_token_length 1024 \ --num_train_epochs 1 \ --output_dir models/rm/rm \ --sanity_check \ @@ -71,7 +71,7 @@ python mason.py \ --per_device_eval_batch_size 16 \ --gradient_accumulation_steps 4 \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --output_dir models/rm/rm_sentiment_1b \ --with_tracking \ @@ -103,7 +103,7 @@ python mason.py \ --per_device_eval_batch_size 8 \ --gradient_accumulation_steps 8 \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --num_train_epochs 1 \ --output_dir models/rm/rm_sentiment_1b \ --with_tracking \ @@ -134,7 +134,7 @@ python mason.py \ --per_device_eval_batch_size 8 \ --gradient_accumulation_steps 4 \ --max_token_length 2048 \ - --max_prompt_token_lenth 1024 \ + --max_prompt_token_length 1024 \ --num_train_epochs 1 \ --output_dir models/rm/rm_hh_1b \ --with_tracking \ @@ -165,7 +165,7 @@ python mason.py \ --per_device_eval_batch_size 8 \ --gradient_accumulation_steps 4 \ --max_token_length 2048 \ - --max_prompt_token_lenth 1024 \ + --max_prompt_token_length 1024 \ --num_train_epochs 1 \ --output_dir models/rm/rm_hh_1b \ --with_tracking \ @@ -198,7 +198,7 @@ python mason.py \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 32 \ --max_token_length 1024 \ - --max_prompt_token_lenth 1024 \ + --max_prompt_token_length 1024 \ --num_train_epochs 1 \ --output_dir models/rm/rm_tulu_8b \ --gradient_checkpointing \ @@ -391,7 +391,7 @@ dataset_config = DatasetConfig( dataset_name="trl-internal-testing/sentiment-trl-style", chat_template="simple_chat", max_token_length=1024, - max_prompt_token_lenth=1024, + max_prompt_token_length=1024, ) tokenizer = AutoTokenizer.from_pretrained("gpt2") tokenizer.chat_template = CHAT_TEMPLATES["simple_chat"] diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 018b2735c..f18b612fa 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -174,7 +174,7 @@ class DatasetConfig: # filter config max_token_length: Optional[int] = None - max_prompt_token_lenth: Optional[int] = None + max_prompt_token_length: Optional[int] = None # dataset.map config sanity_check: bool = False @@ -314,8 +314,8 @@ def tokenize_fn(row): def filter(self, dataset: Union[Dataset, DatasetDict]): def filter_fn(row): return ( - len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth - if self.config.max_prompt_token_lenth is not None + len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length + if self.config.max_prompt_token_length is not None else ( True and len(row[INPUT_IDS_CHOSEN_KEY]) <= self.config.max_token_length if self.config.max_token_length is not None @@ -366,8 +366,12 @@ def get_token_length_visualization(self, dataset: DatasetDict, save_path: str = class SFTDatasetProcessor(DatasetProcessor): def tokenize(self, dataset: Dataset): def tokenize_fn(row): + if len(row[self.config.sft_messages_key]) == 1: + prompt = row[self.config.sft_messages_key] + else: + prompt = row[self.config.sft_messages_key][:-1] row[INPUT_IDS_PROMPT_KEY] = self.tokenizer.apply_chat_template( - row[self.config.sft_messages_key][:-1], + prompt, add_generation_prompt=True, ) row[INPUT_IDS_KEY] = self.tokenizer.apply_chat_template(row[self.config.sft_messages_key]) @@ -386,18 +390,18 @@ def tokenize_fn(row): desc="Tokenizing and reformatting SFT data", ) - def filter(self, dataset: Dataset): + def filter(self, dataset: Dataset, need_contain_labels: bool = True): def filter_fn(row): max_prompt_token_length_ok = True - if self.config.max_prompt_token_lenth is not None: - max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_lenth + if self.config.max_prompt_token_length is not None: + max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length max_token_length_ok = True if self.config.max_token_length is not None: max_token_length_ok = len(row[INPUT_IDS_KEY]) <= self.config.max_token_length contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) - return max_prompt_token_length_ok and max_token_length_ok and contain_some_labels + return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) return dataset.filter( filter_fn, @@ -515,6 +519,43 @@ def __call__(self, batch: list[dict]): INPUT_IDS_PROMPT_KEY: padded_sequences, } +class SimpleGenerateCollatorWithGroundTruth: + """Simple collator for generation task (always pad from the LEFT)""" + + def __init__(self, pad_token_id: int): + self.pad_token_id = pad_token_id + + def __call__(self, batch: list[dict]): + """the input will have input_ids_prompt""" + # Find max length in the batch + max_length = -1 + for i in range(len(batch)): + max_length = max(max_length, len(batch[i][INPUT_IDS_PROMPT_KEY])) + assert max_length > 0, "the dataset is empty" + + # Initialize lists to store padded sequences and attention masks + padded_sequences = [] + + for i in range(len(batch)): + # Calculate padding length + pad_length = max_length - len(batch[i][INPUT_IDS_PROMPT_KEY]) + + # Pad from the left + padding = [self.pad_token_id] * pad_length + padded_sequence = padding + batch[i][INPUT_IDS_PROMPT_KEY] + padded_sequences.append(padded_sequence) + + # Convert to tensors + padded_sequences = torch.tensor(padded_sequences) + + # ground truths + ground_truths = [x[GROUND_TRUTHS_KEY] for x in batch] + + return { + INPUT_IDS_PROMPT_KEY: padded_sequences, + GROUND_TRUTHS_KEY: ground_truths, + } + if __name__ == "__main__": # too little data; it should just use 1 CPU despite the number of available CPUs diff --git a/open_instruct/online_dpo_vllm_thread.py b/open_instruct/online_dpo_vllm_thread.py index 761eee46e..c93713149 100644 --- a/open_instruct/online_dpo_vllm_thread.py +++ b/open_instruct/online_dpo_vllm_thread.py @@ -509,7 +509,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # handle preemption class PreemptionHandler: - preemptied = False + preempted = False def __init__(self): signal.signal(signal.SIGTERM, self.exit_gracefully) @@ -528,7 +528,7 @@ def exit_gracefully(self, signum, frame): print("vllm thread terminated") except Exception as e: print(e) - self.preemptied = True + self.preempted = True ph = PreemptionHandler() @@ -571,7 +571,7 @@ def repeat_generator(): args=( model_config.model_name_or_path, model_config.model_revision, - dataset_config.max_prompt_token_lenth + args.response_length, + dataset_config.max_prompt_token_length + args.response_length, args.vllm_device, args.vllm_gpu_memory_utilization, generation_config, @@ -614,7 +614,7 @@ def repeat_generator(): episode += args.batch_size scheduler.step() queries = queries_next - if ph.preemptied: + if ph.preempted: break if accelerator.is_main_process: @@ -932,7 +932,7 @@ def repeat_generator(): gc.collect() torch.cuda.empty_cache() - if not ph.preemptied: + if not ph.preempted: # save model os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) original_tokenizer = AutoTokenizer.from_pretrained( diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index b9229a1e8..9776bfcbe 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -41,7 +41,7 @@ GROUND_TRUTHS_KEY, DatasetConfig, SFTDatasetProcessor, - SimpleGenerateCollator, + SimpleGenerateCollatorWithGroundTruth, visualize_token, ) from open_instruct.model_utils import ( @@ -51,7 +51,7 @@ first_true_indices, forward, get_reward, - get_verifiable_reward, + apply_verifiable_reward, prepare_deepspeed, print_rich_single_line_metrics, print_rich_table, @@ -171,6 +171,8 @@ class Args: """the reward value for responses that do not contain `stop_token_id`""" non_stop_penalty: bool = False """whether to penalize responses that do not contain `stop_token_id`""" + number_samples_per_prompt: int = 1 + """the number of samples to generate per prompt, useful for easy-star""" # online PPO specific args beta: float = 0.05 @@ -329,7 +331,7 @@ def vllm_generate( ) generation_start_time = time.time() outputs = llm.generate(prompt_token_ids=g_queries_list, sampling_params=generation_config) - response_ids = [list(output.outputs[0].token_ids) for output in outputs] + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") response_ids_Q.put(response_ids) @@ -337,6 +339,7 @@ def vllm_generate( outputs = llm.generate( prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config ) + # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity response_ids = [list(output.outputs[0].token_ids) for output in outputs] evaluation_Q.put(response_ids) @@ -462,7 +465,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): train_dataset = combine_dataset( args.dataset_mixer_dict, splits=args.dataset_train_splits, - columns_to_keep=[dataset_config.sft_messages_key], + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key], ) if dataset_config.sanity_check: train_dataset = train_dataset.select( @@ -470,19 +473,19 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): ) with accelerator.main_process_first(): train_dataset = dataset_processor.tokenize(train_dataset) - train_dataset = dataset_processor.filter(train_dataset) + train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) dataset_dict["train"] = train_dataset eval_dataset = None if args.dataset_eval_mixer is not None: eval_dataset = combine_dataset( args.dataset_eval_mixer_dict, splits=args.dataset_eval_splits, - columns_to_keep=[dataset_config.sft_messages_key], + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key], ) eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) with accelerator.main_process_first(): eval_dataset = dataset_processor.tokenize(eval_dataset) - eval_dataset = dataset_processor.filter(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) dataset_dict["eval"] = eval_dataset # some more runtime logging @@ -551,7 +554,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): num_warmup_steps=args.warm_up_steps, num_training_steps=args.num_training_steps * args.num_train_epochs, ) - data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) dataloader = DataLoader( train_dataset, batch_size=args.local_dataloader_batch_size, @@ -585,7 +588,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # handle preemption class PreemptionHandler: - preemptied = False + preempted = False def __init__(self): signal.signal(signal.SIGTERM, self.exit_gracefully) @@ -604,7 +607,7 @@ def exit_gracefully(self, signum, frame): print("vllm thread terminated") except Exception as e: print(e) - self.preemptied = True + self.preempted = True ph = PreemptionHandler() @@ -629,6 +632,7 @@ def repeat_generator(): top_p=1.0, max_tokens=args.response_length, include_stop_str_in_output=True, + n=args.number_samples_per_prompt, ) param_prompt_Q = None response_ids_Q = None @@ -647,7 +651,7 @@ def repeat_generator(): args=( model_config.model_name_or_path, model_config.model_revision, - dataset_config.max_prompt_token_lenth + args.response_length, + dataset_config.max_prompt_token_length + args.response_length, args.vllm_device, args.vllm_gpu_memory_utilization, generation_config, @@ -663,7 +667,7 @@ def repeat_generator(): thread.start() torch.cuda.set_device(device) - g_vllm_responses = torch.zeros((args.batch_size, args.response_length), device=device, dtype=torch.long) + g_vllm_responses = torch.zeros((args.batch_size * args.number_samples_per_prompt, args.response_length), device=device, dtype=torch.long) # set up the metrics and initial states stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps) @@ -682,18 +686,18 @@ def repeat_generator(): start_time = time.time() data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next) for _ in range(1, resume_training_step): # we didn't store scheduler state scheduler.step() for training_step in range(resume_training_step, args.num_training_steps + 1): - episode += args.batch_size + episode += args.batch_size * args.number_samples_per_prompt # each sample is an episode scheduler.step() queries = queries_next ground_truths = ground_truths_next - if ph.preemptied: + if ph.preempted: break if accelerator.is_main_process: @@ -721,7 +725,7 @@ def repeat_generator(): if training_step != 1: data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) else: if training_step != 1: @@ -729,11 +733,17 @@ def repeat_generator(): # we also set to use `queries = queries_next` immediately data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) queries = queries_next ground_truths_next = ground_truths_next + # if we generate multiple samples per prompt, we need to repeat the queries and ground truths + # to match the vllm outputs. + if args.number_samples_per_prompt > 1: + queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) + ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + training_time_start = time.time() with torch.no_grad(): context_length = queries.shape[1] @@ -799,7 +809,7 @@ def repeat_generator(): reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) # also apply verifiable reward - verifiable_reward = get_verifiable_reward( + verifiable_reward = apply_verifiable_reward( postprocessed_query_response, tokenizer, ground_truths ) score += verifiable_reward @@ -1019,7 +1029,7 @@ def repeat_generator(): gc.collect() torch.cuda.empty_cache() - if not ph.preemptied: + if not ph.preempted: # save model os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) original_tokenizer = AutoTokenizer.from_pretrained( diff --git a/test.sh b/test.sh index f38fb8fb8..df2136d03 100644 --- a/test.sh +++ b/test.sh @@ -59,7 +59,7 @@ accelerate launch --num_processes 2 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_tulu3 \ @@ -89,7 +89,7 @@ accelerate launch --num_processes 1 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_tulu3 \ @@ -119,7 +119,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_tulu3 \ @@ -310,7 +310,7 @@ accelerate launch --num_processes 2 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir models/minimal/online_dpo_tulu2_llama333 \ @@ -342,7 +342,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir models/minimal/online_dpo_tulu2_llama333 \ @@ -373,7 +373,7 @@ accelerate launch --num_processes 2 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir models/minimal/online_dpo_tulu2_llama333 \ @@ -403,7 +403,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train \ --dataset_eval_split validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 5e-7 \ --output_dir models/minimal/online_dpo_tulu2_llama333 \ --chat_template simple_concat_with_space \ @@ -432,7 +432,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train \ --dataset_eval_split validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 5e-7 \ --output_dir models/minimal/online_dpo_tulu2_llama333 \ --chat_template simple_concat_with_space \ @@ -461,7 +461,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir /output/ \ @@ -498,7 +498,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir /output/ \ @@ -530,7 +530,7 @@ accelerate launch --num_processes 8 --config_file configs/ds_configs/deepspeed_z --dataset_train_split train_prefs \ --dataset_eval_split test_prefs \ --max_token_length 512 \ - --max_prompt_token_lenth 256 \ + --max_prompt_token_length 256 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir /output/ \ @@ -563,7 +563,7 @@ accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_z --dataset_eval_mixer '{"allenai/ultrafeedback_binarized_cleaned": 1.0}' \ --dataset_eval_splits test_prefs \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --sft_messages_key chosen \ --learning_rate 5e-7 \ --output_dir /output/ \ @@ -594,7 +594,7 @@ python open_instruct/online_dpo_vllm_thread.py \ --dataset_mixer '{"HuggingFaceH4/no_robots": 1.0}' \ --dataset_train_splits train \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 5e-7 \ --output_dir /output/ \ --chat_template tulu \ @@ -633,7 +633,7 @@ python mason.py \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_vllm_thread_tldr \ --per_device_train_batch_size 2 \ @@ -660,7 +660,7 @@ accelerate launch --num_processes 3 --config_file configs/ds_configs/deepspeed_z --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_lenth 512 \ + --max_prompt_token_length 512 \ --learning_rate 3e-6 \ --output_dir models/minimal/online_dpo_vllm_thread_tldr \ --per_device_train_batch_size 2 \ From 8e0f517040158d7e4e75c1a62832eb4a6a8f990b Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 2 Oct 2024 18:42:44 +0000 Subject: [PATCH 23/53] fix misspelling --- open_instruct/online_dpo_vllm_thread.py | 2 +- open_instruct/ppo_vllm_thread.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/open_instruct/online_dpo_vllm_thread.py b/open_instruct/online_dpo_vllm_thread.py index c93713149..d3015d9c0 100644 --- a/open_instruct/online_dpo_vllm_thread.py +++ b/open_instruct/online_dpo_vllm_thread.py @@ -909,7 +909,7 @@ def repeat_generator(): "val/num_stop_token_ids": global_metrics[1], "objective/kl": global_metrics[2], "objective/kl2": global_metrics[15], - "ojbective/kl3": global_metrics[16], + "objective/kl3": global_metrics[16], "objective/entropy": global_metrics[3], "objective/non_score_reward": global_metrics[4], "objective/rlhf_reward": global_metrics[5], diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 9776bfcbe..cf2d1aa2a 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -1006,7 +1006,7 @@ def repeat_generator(): "val/num_stop_token_ids": global_metrics[1], "objective/kl": global_metrics[2], "objective/kl2": global_metrics[15], - "ojbective/kl3": global_metrics[16], + "objective/kl3": global_metrics[16], "objective/entropy": global_metrics[3], "objective/non_score_reward": global_metrics[4], "objective/rlhf_reward": global_metrics[5], From 6eebf7fd9d3dba0446a918903880d3d5bc7f77ce Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Fri, 4 Oct 2024 20:13:24 +0000 Subject: [PATCH 24/53] count verifieds and intermediate saving --- open_instruct/model_utils.py | 4 +++- open_instruct/ppo_vllm_thread.py | 40 ++++++++++++++++++++++++++++---- 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 6c465a044..1da83b7ae 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -230,7 +230,9 @@ def apply_verifiable_reward( rewards.append(10) else: rewards.append(0) - return torch.tensor(rewards, device=query_responses.device) + rewards_tensors = torch.tensor(rewards, device=query_responses.device) + # return rewards and count of times we applied reward + return rewards_tensors, (rewards_tensors > 0).sum().float().view(-1) def forward( diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index cf2d1aa2a..4c64e669c 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -139,6 +139,8 @@ class Args: """The frequency of evaluation steps""" local_dataloader_batch_size: Optional[int] = None """The batch size per GPU for the dataloader""" + save_freq: int = -1 + """How many train steps to save the model""" # online settings num_epochs: int = 4 @@ -190,6 +192,11 @@ class Args: lam: float = 0.95 """the lambda value for GAE""" kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + """the KL estimator to use""" + apply_verifiable_reward: bool = False + """whether to apply verifiable reward""" + reward_model_multiplier: float = 1.0 + """the reward model multiplier, for down/upscaling the reward model output""" # vLLM settings. NOTE: currently we need to place the vLLM model on a separate GPU # for generation to work properly because vLLM would pre-alocate the memory. @@ -752,6 +759,7 @@ def repeat_generator(): logprobs = [] ref_logprobs = [] scores = [] + verifiable_counts = [] sequence_lengths = [] values = [] if accelerator.is_main_process: @@ -808,11 +816,16 @@ def repeat_generator(): _, score, _ = get_reward( reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) + if args.reward_model_multiplier != 1.0: + score *= args.reward_model_multiplier # also apply verifiable reward - verifiable_reward = apply_verifiable_reward( - postprocessed_query_response, tokenizer, ground_truths - ) - score += verifiable_reward + if args.apply_verifiable_reward: + verifiable_reward, verifiable_count = apply_verifiable_reward( + postprocessed_query_response, tokenizer, ground_truths + ) + score += verifiable_reward + else: + verifiable_count = torch.tensor([0.0], device=device).float() unwrapped_value_model = accelerator.unwrap_model(model).value_model full_value, _, _ = get_reward( unwrapped_value_model, query_response, tokenizer.pad_token_id, context_length @@ -826,6 +839,7 @@ def repeat_generator(): sequence_lengths.append(sequence_length) scores.append(score) values.append(value) + verifiable_counts.append(verifiable_count) responses = torch.cat(responses, 0) postprocessed_responses = torch.cat(postprocessed_responses, 0) logprobs = torch.cat(logprobs, 0) @@ -833,6 +847,7 @@ def repeat_generator(): sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) global_scores = accelerator.gather(scores) + verifiable_counts = torch.cat(verifiable_counts, 0) accelerator.print(f"global_scores: {global_scores}, {global_scores.mean()}") values = torch.cat(values, 0) del (logprob, ref_logprob, full_value, value, score) @@ -994,6 +1009,7 @@ def repeat_generator(): local_metrics[14] = ratio_stats.var() local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + local_metrics[17] = verifiable_counts.sum() global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() metrics = { "episode": episode, @@ -1019,6 +1035,7 @@ def repeat_generator(): "policy/entropy_avg": global_metrics[12], "val/ratio": global_metrics[13], "val/ratio_var": global_metrics[14], + "objective/verifiable_counts": global_metrics[17] } if accelerator.is_main_process: print_rich_single_line_metrics(metrics) @@ -1029,6 +1046,21 @@ def repeat_generator(): gc.collect() torch.cuda.empty_cache() + # save steps + if arg.save_freq > 0 and training_step % args.save_freq == 0: + os.makedirs(os.path.join(os.path.dirname(args.output_dir), f"step_{training_step}"), exist_ok=True) + original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + save_with_accelerate( + accelerator, + model, + original_tokenizer, + args.output_dir, + model_attribute_to_save="policy", + ) + del original_tokenizer + if not ph.preempted: # save model os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) From f9a0b3c739afac87306ed6498343f6112096c5b6 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 7 Oct 2024 18:26:04 +0000 Subject: [PATCH 25/53] save intermediate steps --- open_instruct/ppo_vllm_thread.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 4c64e669c..703f201d0 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -1047,8 +1047,9 @@ def repeat_generator(): torch.cuda.empty_cache() # save steps - if arg.save_freq > 0 and training_step % args.save_freq == 0: - os.makedirs(os.path.join(os.path.dirname(args.output_dir), f"step_{training_step}"), exist_ok=True) + if args.save_freq > 0 and training_step % args.save_freq == 0: + step_dir = os.path.join(args.output_dir, f"step_{training_step}") + os.makedirs(step_dir, exist_ok=True) original_tokenizer = AutoTokenizer.from_pretrained( model_config.model_name_or_path, revision=model_config.model_revision ) @@ -1056,7 +1057,7 @@ def repeat_generator(): accelerator, model, original_tokenizer, - args.output_dir, + step_dir, model_attribute_to_save="policy", ) del original_tokenizer From bad193330f557b32f95ac08521cf0148db0cc076 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 8 Oct 2024 20:23:53 +0000 Subject: [PATCH 26/53] small fix to logging --- open_instruct/ppo_vllm_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 703f201d0..42a7e6f02 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -1009,7 +1009,7 @@ def repeat_generator(): local_metrics[14] = ratio_stats.var() local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() - local_metrics[17] = verifiable_counts.sum() + local_metrics[17] = verifiable_counts.mean() # verifiable count = % of time we trigger the verifiable reward global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() metrics = { "episode": episode, From 028315d118a3368a514750977bdcdd8eb9b426a0 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 9 Oct 2024 19:06:57 +0000 Subject: [PATCH 27/53] fix bug for forward rollout batching --- open_instruct/ppo_vllm_thread.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 42a7e6f02..a56e5b2ff 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -743,7 +743,7 @@ def repeat_generator(): ground_truths_next = data[GROUND_TRUTHS_KEY] send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) queries = queries_next - ground_truths_next = ground_truths_next + ground_truths = ground_truths_next # if we generate multiple samples per prompt, we need to repeat the queries and ground truths # to match the vllm outputs. @@ -820,8 +820,10 @@ def repeat_generator(): score *= args.reward_model_multiplier # also apply verifiable reward if args.apply_verifiable_reward: + # we need to batch the gt to match query. + ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] verifiable_reward, verifiable_count = apply_verifiable_reward( - postprocessed_query_response, tokenizer, ground_truths + postprocessed_query_response, tokenizer, ground_truth ) score += verifiable_reward else: From faa7dc022fca0f05b030bf9ade85282627b95258 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sat, 12 Oct 2024 00:12:29 +0000 Subject: [PATCH 28/53] support gsm8k and math, more flexibility in future --- open_instruct/dataset_processor.py | 9 + open_instruct/ground_truth_utils.py | 71 +++++++ open_instruct/math_utils.py | 283 ++++++++++++++++++++++++++++ open_instruct/model_utils.py | 27 ++- open_instruct/ppo_vllm_thread.py | 14 +- scripts/create_ground_truth_data.py | 154 +++++++++++++++ 6 files changed, 540 insertions(+), 18 deletions(-) create mode 100644 open_instruct/ground_truth_utils.py create mode 100644 open_instruct/math_utils.py create mode 100644 scripts/create_ground_truth_data.py diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index f18b612fa..9dde6f63e 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -52,6 +52,7 @@ INPUT_IDS_PROMPT_KEY = "input_ids_prompt" ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt" GROUND_TRUTHS_KEY = "ground_truth" +DATASET_SOURCE_KEY = "dataset" # NOTE (Costa): the `INPUT_IDS_PROMPT_KEY` is just for visualization purposes only # also we don't really need `ATTENTION_MASK_CHOSEN_KEY` and `ATTENTION_MASK_REJECTED_KEY` @@ -166,6 +167,9 @@ class DatasetConfig: # columns name for the ground truth ground_truths_key: str = GROUND_TRUTHS_KEY + # columns name for dataset source + dataset_source_key: str = DATASET_SOURCE_KEY + # columns names for binary dataset binary_messages_key: str = SFT_MESSAGE_KEY label: str = BINARY_LABEL_KEY @@ -381,6 +385,7 @@ def tokenize_fn(row): labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) row[LABELS_KEY] = labels row[GROUND_TRUTHS_KEY] = row[self.config.ground_truths_key] + row[DATASET_SOURCE_KEY] = row[self.config.dataset_source_key] return row return dataset.map( @@ -551,9 +556,13 @@ def __call__(self, batch: list[dict]): # ground truths ground_truths = [x[GROUND_TRUTHS_KEY] for x in batch] + # datasets + datasets = [x[DATASET_SOURCE_KEY] for x in batch] + return { INPUT_IDS_PROMPT_KEY: padded_sequences, GROUND_TRUTHS_KEY: ground_truths, + DATASET_SOURCE_KEY: datasets, } diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py new file mode 100644 index 000000000..8373305dd --- /dev/null +++ b/open_instruct/ground_truth_utils.py @@ -0,0 +1,71 @@ +''' +Collection of 'ground truth rewards' for different datasets/tasks. +Used to give feedback to the model based on the ground truth answer. +''' +import re +from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv + + +def verify_gsm8k_sample(model_output, ground_truth_answer): + # gsm is easy: extract numbers, and then just compare last number with answer. + # matches how we do eval. + predictions = None + # replace numbers like `x,xxx` with `xxxx` + response = re.sub(r"(\d),(\d)", r"\1\2", model_output) + numbers = re.findall(r"[-+]?\d*\.\d+|\d+", response) + if numbers: + predictions = numbers[-1] + else: + predictions = response + return str(predictions).lower() == str(ground_truth_answer).lower() + + +def verify_math_sample(model_output, ground_truth_answer): + raw_answer = model_output + # for math, more complex. We will try a few different ways to extract the answer. + # this roughly follows 'flex em' in oe-eval-internal + all_answers = [] + # First, try find answer in \boxed{}. + boxed_answer = last_boxed_only_string(raw_answer) + if boxed_answer is not None: + try: + boxed_answer = remove_boxed(boxed_answer) + except AssertionError: + boxed_answer = None + if boxed_answer is not None: + all_answers.append(boxed_answer) + # Second, try to extract via minerva format. + minerva_answer = normalize_final_answer(get_unnormalized_answer(raw_answer)) + if minerva_answer is not None and minerva_answer != "[invalidanswer]": + all_answers.append(minerva_answer) + # If nothing still, try to find the last latex-formatted answer + if len(all_answers) == 0: + dollars = [m.start() for m in re.finditer("\\$", raw_answer)] + if len(dollars) > 1: + # Add the answer between the second to last and last dollar sign + answer = normalize_final_answer(raw_answer[dollars[-2] + 1 : dollars[-1]]) + all_answers.append(answer) + # otherwise, just take the full output. Probably wont work, bit of a yolo. + if len(all_answers) == 0: + all_answers.append(normalize_final_answer(model_output)) + # now, compare all answers to ground truth. + matched = False + for answer in all_answers: + if is_equiv(answer, ground_truth_answer): + matched = True + break + elif hendrycks_is_equiv(answer, ground_truth_answer): + matched = True + break + # if we got any match, we are good. + return matched + + +def verify_ifeval_sample(model_output, constraint_list): + # TODO: IFeval. probably have some constraint list we check against. + pass + + +def verify_flan_sample(model_output, ground_truth_answer): + # TODO: flan. we could do BLEU/ROUGE.... or maybe something like BertScore? + pass \ No newline at end of file diff --git a/open_instruct/math_utils.py b/open_instruct/math_utils.py new file mode 100644 index 000000000..01a2fbaaf --- /dev/null +++ b/open_instruct/math_utils.py @@ -0,0 +1,283 @@ +import re +import sympy +import logging +from typing import Optional + +eval_logger = logging.getLogger("math_utils") + +# from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L187 +def last_boxed_only_string(string: str) -> Optional[str]: + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + + return retval + + +def remove_boxed(s: str) -> str: + if "\\boxed " in s: + left = "\\boxed " + assert s[: len(left)] == left + return s[len(left) :] + + left = "\\boxed{" + + assert s[: len(left)] == left + assert s[-1] == "}" + + return s[len(left) : -1] + + +def get_unnormalized_answer(text: str) -> str: + INVALID_ANSWER = "[invalidanswer]" + end_seq = "I hope it is correct." + text += end_seq + match = re.search( + r"Final Answer: The final answer is(.*?). I hope it is correct.", + text, + ) + if match: + return match.group(1).strip() + else: + return INVALID_ANSWER + + +SUBSTITUTIONS = [ + ("an ", ""), + ("a ", ""), + (".$", "$"), + ("\\$", ""), + (r"\ ", ""), + (" ", ""), + ("mbox", "text"), + (",\\text{and}", ","), + ("\\text{and}", ","), + ("\\text{m}", "\\text{}"), +] +REMOVED_EXPRESSIONS = [ + "square", + "ways", + "integers", + "dollars", + "mph", + "inches", + "ft", + "hours", + "km", + "units", + "\\ldots", + "sue", + "points", + "feet", + "minutes", + "digits", + "cents", + "degrees", + "cm", + "gm", + "pounds", + "meters", + "meals", + "edges", + "students", + "childrentickets", + "multiples", + "\\text{s}", + "\\text{.}", + "\\text{\ns}", + "\\text{}^2", + "\\text{}^3", + "\\text{\n}", + "\\text{}", + r"\mathrm{th}", + r"^\circ", + r"^{\circ}", + r"\;", + r",\!", + "{,}", + '"', + "\\dots", +] + +def normalize_final_answer(final_answer: str) -> str: + """ + Normalize a final answer to a quantitative reasoning question. + + Copied character for character from appendix D of Lewkowycz et al. (2022) + """ + final_answer = final_answer.split("=")[-1] + + for before, after in SUBSTITUTIONS: + final_answer = final_answer.replace(before, after) + for expr in REMOVED_EXPRESSIONS: + final_answer = final_answer.replace(expr, "") + + # Extract answer that is in LaTeX math, is bold, + # is surrounded by a box, etc. + final_answer = re.sub(r"(.*?)(\$)(.*?)(\$)(.*)", "$\\3$", final_answer) + final_answer = re.sub(r"(\\text\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\textbf\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\overline\{)(.*?)(\})", "\\2", final_answer) + final_answer = re.sub(r"(\\boxed\{)(.*)(\})", "\\2", final_answer) + + # Normalize shorthand TeX: + # \fracab -> \frac{a}{b} + # \frac{abc}{bef} -> \frac{abc}{bef} + # \fracabc -> \frac{a}{b}c + # \sqrta -> \sqrt{a} + # \sqrtab -> sqrt{a}b + final_answer = re.sub(r"(frac)([^{])(.)", "frac{\\2}{\\3}", final_answer) + final_answer = re.sub(r"(sqrt)([^{])", "sqrt{\\2}", final_answer) + final_answer = final_answer.replace("$", "") + + # Normalize 100,000 -> 100000 + if final_answer.replace(",", "").isdigit(): + final_answer = final_answer.replace(",", "") + + return final_answer + + +def is_equiv(x1: str, x2: str) -> bool: + """ + x1 and x2 are normalized latex string + """ + try: + with timeout(seconds=5): + try: + parsed_x1 = parse_latex(x1) + parsed_x2 = parse_latex(x2) + except ( + sympy.parsing.latex.errors.LaTeXParsingError, + sympy.SympifyError, + TypeError, + ): + eval_logger.debug(f"couldn't parse one of {x1} or {x2}") + return False + + try: + diff = parsed_x1 - parsed_x2 + except TypeError: + eval_logger.debug(f"couldn't subtract {x1} and {x2}") + return False + + try: + if sympy.simplify(diff) == 0: + return True + else: + return False + except ValueError: + eval_logger.debug( + f"Had some trouble simplifying when comparing {x1} and {x2}" + ) + except TimeoutError: + eval_logger.debug(f"Timed out comparing {x1} and {x2}") + return False + except ImportError as e: + eval_logger.error(e) + raise + except Exception as e: + eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") + return False + +def strip_string(string): + # linebreaks + string = string.replace("\n", "") + + # remove inverse spaces + string = string.replace("\\!", "") + + # replace \\ with \ + string = string.replace("\\\\", "\\") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + + # remove units (on the right) + string = remove_right_units(string) + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") # noqa: W605 + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + # fix sqrt3 --> sqrt{3} + string = fix_sqrt(string) + + # remove spaces + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = fix_fracs(string) + + # manually change 0.5 --> \frac{1}{2} + if string == "0.5": + string = "\\frac{1}{2}" + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = fix_a_slash_b(string) + + return string + + +def hendrycks_is_equiv(str1, str2, verbose=False): + if str1 is None and str2 is None: + print("WARNING: Both None") + return True + if str1 is None or str2 is None: + return False + + try: + ss1 = strip_string(str1) + ss2 = strip_string(str2) + if verbose: + print(ss1, ss2) + return ss1 == ss2 + except Exception: + return str1 == str2 diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 1da83b7ae..222c35e52 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -39,6 +39,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from open_instruct.utils import retry_on_exception +from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample @dataclass @@ -205,29 +206,25 @@ def get_reward( sequence_lengths, ) -import re def apply_verifiable_reward( - query_responses: torch.Tensor, tokenizer, ground_truth: str + query_responses: torch.Tensor, tokenizer, ground_truths: List[str], datasets: List[str], verify_reward : int = 10 ): # decode the responses decoded_responses = tokenizer.batch_decode(query_responses, skip_special_tokens=True) # compare with ground truth. # use same logic as in gsm8k evaluation - predictions = [] rewards = [] - for response in decoded_responses: - # replace numbers like `x,xxx` with `xxxx` - response = re.sub(r"(\d),(\d)", r"\1\2", response) - numbers = re.findall(r"[-+]?\d*\.\d+|\d+", response) - if numbers: - predictions.append(numbers[-1]) - else: - predictions.append(response) - for prediction, gt in zip(predictions, ground_truth): - if prediction == gt: - print("Ground truth matched with prediction, giving a bonus reward 🤗") - rewards.append(10) + for prediction, ground_truth, dataset in zip(decoded_responses, ground_truths, datasets): + verified = False + if dataset.lower() == 'gsm8k': + verified = verify_gsm8k_sample(prediction, ground_truth) + elif dataset.lower() == 'math': + verified = verify_math_sample(prediction, ground_truth) + # if verified, give reward + if verified: + print("Applying ground truth reward 🤗") + rewards.append(verify_reward) else: rewards.append(0) rewards_tensors = torch.tensor(rewards, device=query_responses.device) diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index a56e5b2ff..8b73bc053 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -39,6 +39,7 @@ CHAT_TEMPLATES, INPUT_IDS_PROMPT_KEY, GROUND_TRUTHS_KEY, + DATASET_SOURCE_KEY, DatasetConfig, SFTDatasetProcessor, SimpleGenerateCollatorWithGroundTruth, @@ -472,7 +473,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): train_dataset = combine_dataset( args.dataset_mixer_dict, splits=args.dataset_train_splits, - columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key], + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], ) if dataset_config.sanity_check: train_dataset = train_dataset.select( @@ -487,7 +488,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): eval_dataset = combine_dataset( args.dataset_eval_mixer_dict, splits=args.dataset_eval_splits, - columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key], + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], ) eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) with accelerator.main_process_first(): @@ -694,6 +695,7 @@ def repeat_generator(): data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) ground_truths_next = data[GROUND_TRUTHS_KEY] + datsets_next = data[DATASET_SOURCE_KEY] send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next) for _ in range(1, resume_training_step): # we didn't store scheduler state @@ -704,6 +706,7 @@ def repeat_generator(): scheduler.step() queries = queries_next ground_truths = ground_truths_next + datasets = datsets_next if ph.preempted: break @@ -733,6 +736,7 @@ def repeat_generator(): data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) else: if training_step != 1: @@ -741,15 +745,18 @@ def repeat_generator(): data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] send_queries(accelerator, generation_model, tokenizer, param_prompt_Q, queries_next) queries = queries_next ground_truths = ground_truths_next + datasets = datasets_next # if we generate multiple samples per prompt, we need to repeat the queries and ground truths # to match the vllm outputs. if args.number_samples_per_prompt > 1: queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] training_time_start = time.time() with torch.no_grad(): @@ -822,8 +829,9 @@ def repeat_generator(): if args.apply_verifiable_reward: # we need to batch the gt to match query. ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] + dataset = datasets[i : i + args.local_rollout_forward_batch_size] verifiable_reward, verifiable_count = apply_verifiable_reward( - postprocessed_query_response, tokenizer, ground_truth + postprocessed_query_response, tokenizer, ground_truth, dataset ) score += verifiable_reward else: diff --git a/scripts/create_ground_truth_data.py b/scripts/create_ground_truth_data.py new file mode 100644 index 000000000..42c8ab206 --- /dev/null +++ b/scripts/create_ground_truth_data.py @@ -0,0 +1,154 @@ +''' +My dumb script to create ground truth data for GTRL training. +''' +import os +import re +import random + +from datasets import load_dataset, Dataset, DatasetDict + +from open_instruct.math_utils import remove_boxed, last_boxed_only_string + +# exemplars we will use to prompt the model +GSM8K_EXEMPLARS = [ + { + "question": "There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?", + "cot_answer": "There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. So the answer is 6.", + "short_answer": "6" + }, + { + "question": "If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?", + "cot_answer": "There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. So the answer is 5.", + "short_answer": "5" + }, + { + "question": "Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?", + "cot_answer": "Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. So the answer is 39.", + "short_answer": "39" + }, + { + "question": "Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?", + "cot_answer": "Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. So the answer is 8.", + "short_answer": "8" + }, + { + "question": "Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?", + "cot_answer": "Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. So the answer is 9.", + "short_answer": "9" + }, + { + "question": "There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?", + "cot_answer": "There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. So the answer is 29.", + "short_answer": "29" + }, + { + "question": "Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?", + "cot_answer": "Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. So the answer is 33.", + "short_answer": "33" + }, + { + "question": "Olivia has $23. She bought five bagels for $3 each. How much money does she have left?", + "cot_answer": "Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. So the answer is 8.", + "short_answer": "8" + } +] + + +MATH_EXAMPLARS = [ + { + "question": "Find the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.}", + "cot_answer": "The expressions inside each square root must be non-negative.\nTherefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$.\nAlso, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$.\nTherefore, the domain of the expression is $\\boxed{[2,5)}$.", + "short_answer": "[2,5)" + }, + { + "question": "If $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$", + "cot_answer": "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$", + "short_answer": "24" + }, + { + "question": "Terrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?", + "cot_answer": "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*}\n30n&=480\\\\\n\\Rightarrow\\qquad n&=480/30=\\boxed{16}\n\\end{align*}", + "short_answer": "16" + }, + { + "question": "If the system of equations\n\n\\begin{align*}\n6x-4y&=a,\\\\\n6y-9x &=b.\n\\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.", + "cot_answer": "If we multiply the first equation by $-\\frac{3}{2}$, we obtain\n\n$$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have\n\n$$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$", + "short_answer": "-\\frac{2}{3}" + } +] + +# now, we construct gsm8k data +gsm8k_prompt = "" +for sample in GSM8K_EXEMPLARS: + gsm8k_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n" + +gsm8k_dataset = load_dataset("gsm8k", "main", split="train") +new_data = [] +for sample in gsm8k_dataset: + answer = sample["answer"].split("####")[-1].strip() + new_data.append({ + "messages": [{"role": "user", "content": gsm8k_prompt + f"Question: {sample['question'].strip()}"}], + "ground_truth": answer, + "dataset": "gsm8k" + }) + +# also make a test split for eval +gsm8k_dataset = load_dataset("gsm8k", "main", split="test") +test_data = [] +for sample in gsm8k_dataset: + answer = sample["answer"].split("####")[-1].strip() + test_data.append({ + "messages": [{"role": "user", "content": gsm8k_prompt + f"Question: {sample['question'].strip()}"}], + "ground_truth": answer, + "dataset": "gsm8k" + }) + +# now, we construct math data +math_prompt = "" +for sample in MATH_EXAMPLARS: + math_prompt += f"Question: {sample['question'].strip()}\nAnswer:{sample['cot_answer'].strip()}\n\n" +math_dataset = load_dataset("lighteval/MATH", "all", split="train") +for sample in math_dataset: + # same code used to extract answer for eval + answer = remove_boxed(last_boxed_only_string(sample["solution"])) + if answer is None: + print("skipping") + continue + new_data.append({ + "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['problem'].strip()}"}], + "ground_truth": answer, + "dataset": "MATH" + }) + +# combine into one dataset and push +random.shuffle(new_data) +train_dataset = Dataset.from_list(new_data) +test_dataset = Dataset.from_list(test_data) +dataset = DatasetDict({"train": train_dataset, "test": test_dataset}) +dataset.push_to_hub("ai2-adapt-dev/gsm8k_math_ground_truth") + +# alternate dataset: metamathqa! +metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train") +# let's re-use the MATH prompt. +new_data = [] +def extract_answer(text): + # Regular expression to match content after "The answer is:" including numbers, LaTeX fractions, or other expressions + pattern = r'The answer is:\s*([^\s.]+)' + matches = re.findall(pattern, text) + return matches[-1] if matches else None +for sample in metamathqa_dataset: + # same code used to extract answer for eval + answer = extract_answer(sample["response"]) + if answer is None: + print("skipping") + continue + new_data.append({ + "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['query'].strip()}"}], + "ground_truth": answer, + "dataset": "MATH" # lets use the math eval setup + }) + +# combine into one dataset and push +random.shuffle(new_data) +dataset = Dataset.from_list(new_data) +dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth") \ No newline at end of file From 59702432ebecdf4da9c9753457f585d17ad47080 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 16 Oct 2024 19:54:55 +0000 Subject: [PATCH 29/53] add costas plo thing --- open_instruct/ppo_vllm_thread.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 8b73bc053..1159aa898 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -198,6 +198,8 @@ class Args: """whether to apply verifiable reward""" reward_model_multiplier: float = 1.0 """the reward model multiplier, for down/upscaling the reward model output""" + use_plo: bool = False + """Toggle PLO: https://arxiv.org/pdf/2010.03956 sec 3.2. Disable grads on 0 reward samples""" # vLLM settings. NOTE: currently we need to place the vLLM model on a separate GPU # for generation to work properly because vLLM would pre-alocate the memory. @@ -924,6 +926,12 @@ def repeat_generator(): advantages = torch.masked_fill(advantages, padding_mask, 0) torch.cuda.empty_cache() + # if we have plo, apply it here! basically will be augmenting the padding mask + if args.use_plo: + zero_scores = (scores == 0)[:, None] + padding_mask = padding_mask | zero_scores + padding_mask_p1 = padding_mask_p1 | zero_scores + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch for epoch_idx in range(args.num_epochs): b_inds = np.random.permutation(args.local_batch_size) @@ -967,7 +975,10 @@ def repeat_generator(): pg_loss_max = torch.max(pg_losses, pg_losses2) pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) loss = pg_loss + args.vf_coef * vf_loss - + # can happen when no rewards in padding mask, so its 'all padding + if torch.isnan(loss).any(): + # just continue to the next loop, skip this. + continue accelerator.backward(loss) optimizer.step() optimizer.zero_grad() From f8fb8ebdc34269c18db1194ca88126d98ee2251b Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 21 Oct 2024 23:54:32 +0000 Subject: [PATCH 30/53] add numina math --- scripts/create_ground_truth_data.py | 81 ++++++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/scripts/create_ground_truth_data.py b/scripts/create_ground_truth_data.py index 42c8ab206..74a9b7de1 100644 --- a/scripts/create_ground_truth_data.py +++ b/scripts/create_ground_truth_data.py @@ -4,6 +4,7 @@ import os import re import random +from tqdm import tqdm from datasets import load_dataset, Dataset, DatasetDict @@ -76,6 +77,11 @@ "short_answer": "-\\frac{2}{3}" } ] +math_messages = [ + [{'role': 'user', 'content': sample['question']}, {'role': 'assistant', 'content': sample['cot_answer']}] for sample in MATH_EXAMPLARS +] +# flatten +math_messages = [item for sublist in math_messages for item in sublist] # now, we construct gsm8k data gsm8k_prompt = "" @@ -151,4 +157,77 @@ def extract_answer(text): # combine into one dataset and push random.shuffle(new_data) dataset = Dataset.from_list(new_data) -dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth") \ No newline at end of file +dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth") + +# alternate dataset: numina-tir +metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train") +# let's re-use the MATH prompt. +new_data = [] +def find_last_outermost_boxed(string): + idx = string.rfind("\\boxed") + if "\\boxed " in string: + return "\\boxed " + string.split("\\boxed ")[-1].split("$")[0] + if idx < 0: + idx = string.rfind("\\fbox") + if idx < 0: + return None + + i = idx + right_brace_idx = None + num_left_braces_open = 0 + while i < len(string): + if string[i] == "{": + num_left_braces_open += 1 + if string[i] == "}": + num_left_braces_open -= 1 + if num_left_braces_open == 0: + right_brace_idx = i + break + i += 1 + + if right_brace_idx is None: + retval = None + else: + retval = string[idx : right_brace_idx + 1] + if retval is not None: + retval = retval[7:-1] # remove \boxed{} + return retval +for sample in tqdm(metamathqa_dataset): + # same code used to extract answer for eval + answer = find_last_outermost_boxed(sample["solution"]) + if answer is None: + print("skipping") + continue + # lets use multi-turn cot prompt instead + new_data.append({ + "messages": math_messages + [{"role": "user", "content": f"{sample['problem'].strip()}"}], + "ground_truth": answer, + "dataset": "MATH" # lets use the math eval setup + }) + +# combine into one dataset and push +random.shuffle(new_data) +dataset = Dataset.from_list(new_data) +dataset.push_to_hub("ai2-adapt-dev/numinamath_tir_ground_truth") + +# alternate dataset: numina-cot (much, much larger) +metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train") +# let's re-use the MATH prompt. +new_data = [] +for sample in tqdm(metamathqa_dataset): + # same code used to extract answer for eval + answer = find_last_outermost_boxed(sample["solution"]) + if answer is None: + print("skipping") + continue + # lets use multi-turn cot prompt instead + new_data.append({ + "messages": math_messages + [{"role": "user", "content": f"{sample['problem'].strip()}"}], + "ground_truth": answer, + "dataset": "MATH" # lets use the math eval setup + }) + +# combine into one dataset and push +random.shuffle(new_data) +dataset = Dataset.from_list(new_data) +dataset.push_to_hub("ai2-adapt-dev/numinamath_cot_ground_truth") \ No newline at end of file From eda48496b210aebf0332f13d6d85c663d7d16677 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Tue, 22 Oct 2024 20:57:18 -0700 Subject: [PATCH 31/53] remove plo, add value model rand init, first stab at rephrase model loading --- open_instruct/model_utils.py | 14 +++++- open_instruct/ppo_vllm_thread.py | 35 ++++++++------- scripts/create_ground_truth_data.py | 70 ++++++++++++++--------------- 3 files changed, 65 insertions(+), 54 deletions(-) diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 222c35e52..979dd27cb 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -39,7 +39,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from open_instruct.utils import retry_on_exception -from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample +from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample, verify_strict_math_sample @dataclass @@ -208,10 +208,20 @@ def get_reward( def apply_verifiable_reward( - query_responses: torch.Tensor, tokenizer, ground_truths: List[str], datasets: List[str], verify_reward : int = 10 + query_responses: torch.Tensor, tokenizer, ground_truths: List[str], datasets: List[str], verify_reward : int = 10, answer_extraction_model: Optional[torch.nn.Module] = None, answer_extraction_tokenizer: Optional[PreTrainedTokenizer] = None ): # decode the responses decoded_responses = tokenizer.batch_decode(query_responses, skip_special_tokens=True) + # if we have an answer extraction model, use it to extract the answer from the response + if answer_extraction_model is not None: + prompt = "Thus, the final answer is:" + # add the prompt to the responses + decoded_responses = [f"{response} {prompt}" for response in decoded_responses] + # extract the answer + answer_extraction_inputs = answer_extraction_tokenizer(decoded_responses, return_tensors="pt", padding=True, truncation=True) + answer_extraction_outputs = answer_extraction_model(**answer_extraction_inputs) + # get the predicted answer + decoded_responses = answer_extraction_tokenizer.batch_decode(answer_extraction_outputs.logits.argmax(-1), skip_special_tokens=True) # compare with ground truth. # use same logic as in gsm8k evaluation rewards = [] diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 1159aa898..f60fdf6aa 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -158,6 +158,8 @@ class Args: """the path to the reward model""" reward_model_revision: Optional[str] = None """the revision of the reward model""" + init_value_from_scratch: bool = False + """whether to initialize the value model from scratch""" # generation config response_length: int = 53 @@ -198,8 +200,7 @@ class Args: """whether to apply verifiable reward""" reward_model_multiplier: float = 1.0 """the reward model multiplier, for down/upscaling the reward model output""" - use_plo: bool = False - """Toggle PLO: https://arxiv.org/pdf/2010.03956 sec 3.2. Disable grads on 0 reward samples""" + answer_extraction_model: str = None # vLLM settings. NOTE: currently we need to place the vLLM model on a separate GPU # for generation to work properly because vLLM would pre-alocate the memory. @@ -275,7 +276,7 @@ def calculate_runtime_args_and_accelerator(args: Args, model_config: ModelConfig args.local_mini_batch_size = exact_div( args.local_batch_size, args.num_mini_batches, "`local_batch_size` must be a multiple of `num_mini_batches`" ) - args.num_training_steps = args.total_episodes // args.batch_size + args.num_training_steps = args.total_episodes // (args.batch_size * args.number_samples_per_prompt) args.eval_freq = max(1, args.num_training_steps // args.num_evals) # PPO logic: do checks and set up dataloader batch size if args.whiten_rewards: @@ -532,6 +533,8 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): attn_implementation="flash_attention_2", use_cache=False, ) + if args.init_value_from_scratch: + value_model.init_weights() # re-initialize the value model from scratch reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( args.reward_model_path, revision=args.reward_model_revision, @@ -680,7 +683,7 @@ def repeat_generator(): g_vllm_responses = torch.zeros((args.batch_size * args.number_samples_per_prompt, args.response_length), device=device, dtype=torch.long) # set up the metrics and initial states - stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + stats_shape = (args.num_epochs, args.num_mini_batches * args.number_samples_per_prompt, args.gradient_accumulation_steps) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) pg_loss_stats = torch.zeros(stats_shape, device=device) @@ -692,6 +695,14 @@ def repeat_generator(): episode = args.batch_size * (resume_training_step - 1) model.train() + # setup extraction model. For now keep on CPU? + if args.answer_extraction_model: + answer_extraction_model = AutoModelForCausalLM.from_pretrained(args.answer_extraction_model) + answer_extraction_tokenizer = AutoTokenizer.from_pretrained(aargs.answer_extraction_model) + else: + answer_extraction_model = None + answer_extraction_tokenizer = None + # training loop start_time = time.time() data = next(iter_dataloader) @@ -833,7 +844,7 @@ def repeat_generator(): ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] dataset = datasets[i : i + args.local_rollout_forward_batch_size] verifiable_reward, verifiable_count = apply_verifiable_reward( - postprocessed_query_response, tokenizer, ground_truth, dataset + postprocessed_query_response, tokenizer, ground_truth, dataset, verify_reward=10, answer_extraction_model=answer_extraction_model, answer_extraction_tokenizer=answer_extraction_tokenizer ) score += verifiable_reward else: @@ -926,17 +937,11 @@ def repeat_generator(): advantages = torch.masked_fill(advantages, padding_mask, 0) torch.cuda.empty_cache() - # if we have plo, apply it here! basically will be augmenting the padding mask - if args.use_plo: - zero_scores = (scores == 0)[:, None] - padding_mask = padding_mask | zero_scores - padding_mask_p1 = padding_mask_p1 | zero_scores - # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch for epoch_idx in range(args.num_epochs): - b_inds = np.random.permutation(args.local_batch_size) + b_inds = np.random.permutation(args.local_batch_size * args.number_samples_per_prompt) minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size, args.local_mini_batch_size): + for mini_batch_start in range(0, args.local_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size): mini_batch_end = mini_batch_start + args.local_mini_batch_size mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 @@ -975,10 +980,6 @@ def repeat_generator(): pg_loss_max = torch.max(pg_losses, pg_losses2) pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) loss = pg_loss + args.vf_coef * vf_loss - # can happen when no rewards in padding mask, so its 'all padding - if torch.isnan(loss).any(): - # just continue to the next loop, skip this. - continue accelerator.backward(loss) optimizer.step() optimizer.zero_grad() diff --git a/scripts/create_ground_truth_data.py b/scripts/create_ground_truth_data.py index 74a9b7de1..a4bd9ced8 100644 --- a/scripts/create_ground_truth_data.py +++ b/scripts/create_ground_truth_data.py @@ -127,37 +127,37 @@ }) # combine into one dataset and push -random.shuffle(new_data) -train_dataset = Dataset.from_list(new_data) -test_dataset = Dataset.from_list(test_data) -dataset = DatasetDict({"train": train_dataset, "test": test_dataset}) -dataset.push_to_hub("ai2-adapt-dev/gsm8k_math_ground_truth") - -# alternate dataset: metamathqa! -metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train") -# let's re-use the MATH prompt. -new_data = [] -def extract_answer(text): - # Regular expression to match content after "The answer is:" including numbers, LaTeX fractions, or other expressions - pattern = r'The answer is:\s*([^\s.]+)' - matches = re.findall(pattern, text) - return matches[-1] if matches else None -for sample in metamathqa_dataset: - # same code used to extract answer for eval - answer = extract_answer(sample["response"]) - if answer is None: - print("skipping") - continue - new_data.append({ - "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['query'].strip()}"}], - "ground_truth": answer, - "dataset": "MATH" # lets use the math eval setup - }) - -# combine into one dataset and push -random.shuffle(new_data) -dataset = Dataset.from_list(new_data) -dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth") +# random.shuffle(new_data) +# train_dataset = Dataset.from_list(new_data) +# test_dataset = Dataset.from_list(test_data) +# dataset = DatasetDict({"train": train_dataset, "test": test_dataset}) +# dataset.push_to_hub("ai2-adapt-dev/gsm8k_math_ground_truth") + +# # alternate dataset: metamathqa! +# metamathqa_dataset = load_dataset("meta-math/MetaMathQA", "main", split="train") +# # let's re-use the MATH prompt. +# new_data = [] +# def extract_answer(text): +# # Regular expression to match content after "The answer is:" including numbers, LaTeX fractions, or other expressions +# pattern = r'The answer is:\s*([^\s.]+)' +# matches = re.findall(pattern, text) +# return matches[-1] if matches else None +# for sample in metamathqa_dataset: +# # same code used to extract answer for eval +# answer = extract_answer(sample["response"]) +# if answer is None: +# print("skipping") +# continue +# new_data.append({ +# "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['query'].strip()}"}], +# "ground_truth": answer, +# "dataset": "MATH" # lets use the math eval setup +# }) + +# # combine into one dataset and push +# random.shuffle(new_data) +# dataset = Dataset.from_list(new_data) +# dataset.push_to_hub("ai2-adapt-dev/metamathqa_ground_truth") # alternate dataset: numina-tir metamathqa_dataset = load_dataset("AI-MO/NuminaMath-TIR", split="train") @@ -200,7 +200,7 @@ def find_last_outermost_boxed(string): continue # lets use multi-turn cot prompt instead new_data.append({ - "messages": math_messages + [{"role": "user", "content": f"{sample['problem'].strip()}"}], + "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['problem'].strip()}"}], "ground_truth": answer, "dataset": "MATH" # lets use the math eval setup }) @@ -208,7 +208,7 @@ def find_last_outermost_boxed(string): # combine into one dataset and push random.shuffle(new_data) dataset = Dataset.from_list(new_data) -dataset.push_to_hub("ai2-adapt-dev/numinamath_tir_ground_truth") +dataset.push_to_hub("ai2-adapt-dev/numinamath_tir_ground_truth_one_turn") # alternate dataset: numina-cot (much, much larger) metamathqa_dataset = load_dataset("AI-MO/NuminaMath-CoT", split="train") @@ -222,7 +222,7 @@ def find_last_outermost_boxed(string): continue # lets use multi-turn cot prompt instead new_data.append({ - "messages": math_messages + [{"role": "user", "content": f"{sample['problem'].strip()}"}], + "messages": [{"role": "user", "content": math_prompt + f"Question: {sample['problem'].strip()}"}], "ground_truth": answer, "dataset": "MATH" # lets use the math eval setup }) @@ -230,4 +230,4 @@ def find_last_outermost_boxed(string): # combine into one dataset and push random.shuffle(new_data) dataset = Dataset.from_list(new_data) -dataset.push_to_hub("ai2-adapt-dev/numinamath_cot_ground_truth") \ No newline at end of file +dataset.push_to_hub("ai2-adapt-dev/numinamath_cot_ground_truth_one_turn") \ No newline at end of file From b709f37b8bcfc50a14c1d513c4e16b1145012c12 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 23 Oct 2024 22:17:09 +0000 Subject: [PATCH 32/53] math strict verify --- open_instruct/ground_truth_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 8373305dd..bbaae94b1 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -61,6 +61,30 @@ def verify_math_sample(model_output, ground_truth_answer): return matched +def verify_strict_math_sample(model_output, ground_truth_answer): + raw_answer = model_output + # just trying minerva format. + all_answers = [] + # Second, try to extract via minerva format. + minerva_answer = normalize_final_answer(get_unnormalized_answer(raw_answer)) + if minerva_answer is not None and minerva_answer != "[invalidanswer]": + all_answers.append(minerva_answer) + # otherwise, just take the full output. Probably wont work, bit of a yolo. + if len(all_answers) == 0: + all_answers.append(normalize_final_answer(model_output)) + # now, compare all answers to ground truth. + matched = False + for answer in all_answers: + if is_equiv(answer, ground_truth_answer): + matched = True + break + elif hendrycks_is_equiv(answer, ground_truth_answer): + matched = True + break + # if we got any match, we are good. + return matched + + def verify_ifeval_sample(model_output, constraint_list): # TODO: IFeval. probably have some constraint list we check against. pass From 51f0b2aeb7b0a97b08963dd54828ce1e1de24ebf Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Thu, 24 Oct 2024 16:50:33 +0000 Subject: [PATCH 33/53] ifeval code --- open_instruct/ground_truth_utils.py | 39 ++- open_instruct/if_functions.py | 517 ++++++++++++++++++++++++++++ open_instruct/model_utils.py | 4 +- 3 files changed, 554 insertions(+), 6 deletions(-) create mode 100644 open_instruct/if_functions.py diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index bbaae94b1..7697bde5d 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -3,8 +3,9 @@ Used to give feedback to the model based on the ground truth answer. ''' import re +import json from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv - +from open_instruct.if_functions import IF_FUNCTIONS_MAP def verify_gsm8k_sample(model_output, ground_truth_answer): # gsm is easy: extract numbers, and then just compare last number with answer. @@ -85,11 +86,39 @@ def verify_strict_math_sample(model_output, ground_truth_answer): return matched -def verify_ifeval_sample(model_output, constraint_list): - # TODO: IFeval. probably have some constraint list we check against. - pass +def verify_ifeval_sample(model_output, constraint): + # TODO: just pass in final answer. this should be fine for other evals too. + answer = model_output.split("<|assistant|>\n")[-1].strip() + if "func_name" not in constraint: + print("WARNING: constraint missing func_name") + print(constraint) + return False + # first, parse out the constraint string. + func_name = constraint.pop("func_name") + # get the function + func = IF_FUNCTIONS_MAP[func_name] + # for now, ignore this due to data issues. TODO: fix. + if func_name == "validate_repeat_prompt": + return False # missing 'original_prompt' + if func_name == "validate_sections": + return False # missing 'section_splitter' + # now, run the function + # pop out any none args + non_none_args = {k:v for k,v in constraint.items() if v is not None} + # sometimes we have extra args, sometimes not. + if len(constraint) == 0: + return func(model_output) + return func(answer, **non_none_args) def verify_flan_sample(model_output, ground_truth_answer): # TODO: flan. we could do BLEU/ROUGE.... or maybe something like BertScore? - pass \ No newline at end of file + pass + +# debug code +if __name__ == "__main__": + from datasets import load_dataset + ds = load_dataset("ai2-adapt-dev/prompts_with_constraints_for_ground_truth") + test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$" + for sample in ds['train']: + verify_ifeval_sample(test_model_output, sample['ground_truth']) \ No newline at end of file diff --git a/open_instruct/if_functions.py b/open_instruct/if_functions.py new file mode 100644 index 000000000..bfb1602d8 --- /dev/null +++ b/open_instruct/if_functions.py @@ -0,0 +1,517 @@ +import re +import json +import langdetect +from typing import List + + +""" +This module contains functions to verify constraints in the responses generated by the model. +It covers all 25 constraints from the IFEval taxonomy. To be used either for eval or for ground truth rewards. +""" + + +# include keywords: Include keywords {keyword1}, {keyword2} in your response + +def verify_keywords(response, keyword_list): + """ + Verify if the response contains all the specified keywords. + + Args: + response (str): The response text to check + keyword_list (list): A list of keywords to check for + + Returns: + bool: True if all keywords are present in the response, False otherwise + """ + # Convert response to lowercase for case-insensitive matching + response_lower = response.lower() + + # Check if all keywords are present in the response + return all(keyword.lower() in response_lower for keyword in keyword_list) + + +# Keyword Frequency: In your response, the word {word} should appear {N} times. +# may take multiple keywords. +def verify_keyword_frequency(text, word, N): + """ + Verifies if a keyword appears exactly N times in the given text. + + Args: + text (str): The text to analyze + keyword_list (List[str]): The keywords to count + expected_count (int): The expected number of occurrences for each keyword + + Returns: + tuple: (bool, int) - (Whether constraint is met, actual count found) + """ + # Convert text to lowercase to make the search case-insensitive + text = text.lower() + keyword = word.lower() + + # Split text into words and remove punctuation + import re + words = re.findall(r'\b\w+\b', text) + + # Count actual occurrences + actual_count = sum(1 for word in words if word == keyword) + + # Check if constraint is met + constraint_met = actual_count == N + + return constraint_met + + +# Forbidden Words: Do not include keywords {forbidden words} in the response. +def validate_forbidden_words(text, forbidden_words): + """ + Validates that the text does not contain any of the specified forbidden words. + + Args: + text (str): The text to check for forbidden words + forbidden_words (list[str]): A list of forbidden words + + Returns: + tuple[bool, list[str]]: A tuple containing: + - Boolean indicating if any forbidden words are present + - List of forbidden words found in the text + + Example: + text = "This is a message that should not contain any bad words" + forbidden_words = ["bad", "evil", "harmful"] + result = validate_forbidden_words(text, forbidden_words) + """ + # Convert text to lowercase for case-insensitive matching + text_lower = text.lower() + + # Check each forbidden word + found_words = [word for word in forbidden_words if word.lower() in text_lower] + + # Return results + return len(found_words) == 0 + + +# Letter Frequency : In your response, the letter {letter} should appear {N} times. + +def verify_letter_frequency(text: str, letter: str, N: int) -> bool: + """ + Verifies if a given letter appears exactly the specified number of times in the text. + + Args: + text (str): The text to check + letter (str): The letter to count (case-sensitive) + target_count (int): The expected number of occurrences + + Returns: + bool: True if the constraint is met, False otherwise + + Example: + >>> verify_letter_frequency("hello world", "l", 3) + True + >>> verify_letter_frequency("hello world", "o", 2) + True + >>> verify_letter_frequency("hello world", "x", 0) + True + """ + if len(letter) != 1: + raise ValueError("Letter parameter must be a single character") + + actual_count = text.count(letter) + return actual_count == N + + +# Response Language: Your ENTIRE response should be in {language}, no other language is allowed. + +def validate_response_language(text, language): + """ + Validates that the entire response is in the specified language. + + Args: + text (str): The text to check + language (str): The language code (e.g., 'en' for English) + + Returns: + bool: True if the response is entirely in the specified language, False otherwise + + Example: + text = "This is an English sentence" + language = "en" + result = validate_response_language(text, language) + """ + from langdetect import detect + + # Detect the language of the text + detected_language = detect(text) + # Check if the detected language matches the expected language + return detected_language == language + + +# Number Paragraphs: Your response should contain {N} paragraphs. You separate paragraphs using the markdown divider: +# * * * +def verify_paragraph_count(text: str, N: int) -> bool: + """ + Verifies that a text contains the expected number of paragraphs, + where paragraphs are separated by markdown dividers '* * *' + + Args: + text (str): The text to analyze + expected_count (int): Expected number of paragraphs + + Returns: + bool: True if the text contains exactly the expected number of paragraphs, + False otherwise + + Example: + text = "First paragraph\n* * *\nSecond paragraph" + verify_paragraph_count(text, 2) + True + """ + def clean_text(text: str) -> str: + """Remove extra whitespace and normalize line endings""" + return '\n'.join(line.strip() for line in text.splitlines()).strip() + + # Clean the input text + text = clean_text(text) + + # Split text by markdown divider + # Add 1 to count since n dividers create n+1 paragraphs + paragraphs = text.split('* * *') + actual_count = len(paragraphs) + + # Verify each split resulted in non-empty content + valid_paragraphs = [p.strip() for p in paragraphs if p.strip()] + if len(valid_paragraphs) != actual_count: + return False + + return actual_count == N + + +# Number Words: Answer with at least / around / at most {N} words + +def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: + """ + Validates if a text meets specified word count constraints. + + Args: + text (str): The text to check + count (int): The target word count + qualifier (str): The type of constraint ('at least', 'around', 'at most') + + Returns: + bool: True if the constraint is met, False otherwise + + Raises: + ValueError: If an invalid qualifier is provided + """ + # Remove extra whitespace and split into words + words = text.strip().split() + actual_count = len(words) + + # Define tolerance for "around" qualifier (±10% of target count) + tolerance = max(round(N * 0.1), 1) + + if quantifier == "at least": + return actual_count >= N + elif quantifier == "at most": + return actual_count <= N + elif quantifier == "around": + return abs(actual_count - N) <= tolerance + else: + return False + + +# Number Sentences: Answer with at least / around / at most {N} sentences. +def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: + """ + Verifies if a text contains the expected number of sentences. + + Args: + text (str): The text to analyze + N (int): The expected number of sentences + quantifier (str): The quantifier ('at least', 'around', 'at most') + + Returns: + bool: True if the text contains the expected number of sentences, False otherwise + """ + # Split the text into sentences + sentences = re.split(r'(?= N + elif quantifier == 'around': + return abs(actual_count - N) <= 1 + elif quantifier == 'at most': + return actual_count <= N + else: + return False + + +# Number Paragraphs + First Word in i-th Paragraph: There should be {N} paragraphs. Paragraphs and only paragraphs +# are separated with each other by two line breaks. The {i}-th paragraph must start with word {first word}. +def validate_paragraphs(text, N, first_word, i): + """ + Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific + word. + + Args: + text (str): The text to analyze + N (int): The expected number of paragraphs + first_word (str): The expected first word of the i-th paragraph + i (int): The index of the paragraph to check (1-indexed) + + Returns: + bool: True if the text meets the paragraph and first word requirements, False otherwise + """ + # Split the text into paragraphs + paragraphs = text.split('\n\n') + + # Check if the number of paragraphs is as expected + if len(paragraphs) != N: + return False + + # Check if the i-th paragraph starts with the specified first word + if paragraphs[i - 1].strip().startswith(first_word): + return True + return False + + +# Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker} + +def verify_postscript(text, postscript_marker): + """ + Verifies if a text contains a postscript starting with '{postscript marker}' + + Args: + text (str): The text to verify + + Returns: + bool: True if the text contains a valid postscript, False otherwise + """ + # Check if the text contains the postscript marker + if postscript_marker in text: + # Get the index of the marker + marker_index = text.find(postscript_marker) + # Check if the marker appears near the end + remaining_text = text[marker_index:].strip() + # Verify it's not just the marker alone + return len(remaining_text) > len(postscript_marker) + return False + + +# Number Placeholder: The response must contain at least {N} placeholders represented by square brackets, +# such as [address]. +def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]: + """ + Validates if a text contains at least the specified number of placeholders in square brackets. + + Args: + text (str): The text to check for placeholders + min_placeholders (int): Minimum number of placeholders required + + Returns: + tuple[bool, List[str]]: A tuple containing: + - Boolean indicating if the text meets the placeholder requirement + - List of found placeholders + + Example: + >>> text = "Hello [name], your [item] will be delivered to [address]" + >>> validate_placeholders(text, 2) + (True, ['name', 'item', 'address']) + """ + # Find all placeholders using regex + pattern = r'\[(.*?)\]' + placeholders = re.findall(pattern, text) + + # Check if the number of placeholders meets the requirement + has_enough = len(placeholders) >= N + + return has_enough, placeholders + + +# Number Bullets: Your answer must contain exactly {N} bullet points. Use the markdown bullet points such as: * This +# is a point. +def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: + """ + Verifies if a text contains exactly N bullet points in markdown format. + Returns a tuple of (is_valid, message). + + Args: + text (str): The text to check + expected_count (int): The expected number of bullet points + + Returns: + tuple[bool, str]: (True if constraint is met, explanation message) + """ + # Split text into lines and count lines starting with * or - + lines = text.split('\n') + bullet_points = [line.strip() for line in lines if line.strip().startswith(('*', '-'))] + actual_count = len(bullet_points) + + if actual_count == N: + return True + else: + return False + + +# Title: Your answer must contain a title, wrapped in double angular brackets, such as <>. +def validate_title(answer: str) -> bool: + pattern = r"<<(.*?)>>" + matches = re.findall(pattern, answer) + + if len(matches) > 0: + return True + else: + return False + + +# Choose: From Answer with one of the following options: {options} +def validate_choice(answer: str, options: list) -> bool: + for option in options: + if answer in option: + return True + return False + + +# Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with markdown, i.e. *highlighted +# section* +def validate_highlighted_sections(answer: str, N: int) -> bool: + pattern = r"\*(.*?)\*" + matches = re.findall(pattern, answer) + + if len(matches) >= N: + return True + else: + return False + + +# Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with {section splitter} X. + +def validate_sections(answer: str, N: int, section_splitter: str) -> bool: + sections = answer.split(section_splitter) + # The first section might not start with the splitter, so we adjust for this + if sections[0] == '': + sections.pop(0) + if len(sections) == N: + return True + else: + return False + + +# JSON Format : Entire output should be wrapped in JSON format. +def validate_json_format(data: str) -> bool: + try: + json_object = json.loads(data) + except ValueError as e: + return False + return True + + +# Repeat Prompt: First, repeat the request without change, then give your answer (do not say anything before +# repeating the request; the request you need to repeat does not include this sentence) +def validate_repeat_prompt(full_response: str, original_prompt: str) -> bool: + response_split = full_response.split('. ', 1) + if response_split[0] == original_prompt: + return True + else: + return False + + +# Two Responses: Give two different responses. Responses and only responses should be separated by 6 asterisk +# symbols: ******. +def validate_two_responses(responses: str) -> bool: + if responses.count('******') == 1: + response_list = responses.split('******') + first_response = response_list[0].strip() + second_response = response_list[1].strip() + if first_response != second_response: + return True + return False + + +# All Uppercase: Your entire response should be in English, capital letters only. +def validate_uppercase(response: str) -> bool: + # Check if the response is the same as the uppercase version of the response + if response == response.upper(): + return True + else: + return False + + +# All Lowercase: Your entire response should be in English, and in all lowercase letters. No capital letters are +# allowed. +def validate_lowercase(response: str) -> bool: + # Check if the response is the same as the lowercase version of the response + if response == response.lower(): + return True + else: + return False + + +# Frequency of All-capital Words: In your response, words with all capital letters should appear at least / around / +# at most {N} times. +def validate_frequency_capital_words(response: str, N: int, quantifier: str) -> bool: + words = re.findall(r'\b[A-Z]+\b', response) + if quantifier == 'at least': + return len(words) >= N + elif quantifier == 'around': + return len(words) == N + elif quantifier == 'at most': + return len(words) <= N + else: + return False + + +# End Checker: Finish your response with this exact phrase {end phrase}. No other words should follow this phrase. +def validate_end(response: str, end_phrase: str) -> bool: + # Check if the response ends with the end phrase + if response.endswith(end_phrase): + return True + else: + return False + + +# Quotation: Wrap your entire response with double quotation marks. +def validate_quotation(response: str) -> bool: + if response.startswith('"') and response.endswith('"'): + return True + else: + return False + + +# No Commas: In your entire response, refrain from the use of any commas. +def validate_no_commas(response: str) -> bool: + if ',' not in response: + return True + else: + return False + +IF_FUNCTIONS_MAP = { + 'verify_keywords': verify_keywords, + 'verify_keyword_frequency': verify_keyword_frequency, + 'validate_forbidden_words': validate_forbidden_words, + 'verify_letter_frequency': verify_letter_frequency, + 'validate_response_language': validate_response_language, + 'verify_paragraph_count': verify_paragraph_count, + 'validate_word_constraint': validate_word_constraint, + 'verify_sentence_constraint': verify_sentence_constraint, + 'validate_paragraphs': validate_paragraphs, + 'verify_postscript': verify_postscript, + 'validate_placeholders': validate_placeholders, + 'verify_bullet_points': verify_bullet_points, + 'validate_title': validate_title, + 'validate_choice': validate_choice, + 'validate_highlighted_sections': validate_highlighted_sections, + 'validate_sections': validate_sections, + 'validate_json_format': validate_json_format, + 'validate_repeat_prompt': validate_repeat_prompt, + 'validate_two_responses': validate_two_responses, + 'validate_uppercase': validate_uppercase, + 'validate_lowercase': validate_lowercase, + 'validate_frequency_capital_words': validate_frequency_capital_words, + 'validate_end': validate_end, + 'validate_quotation': validate_quotation, + 'validate_no_commas': validate_no_commas +} diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 979dd27cb..37003deb7 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -39,7 +39,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from open_instruct.utils import retry_on_exception -from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample, verify_strict_math_sample +from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample, verify_strict_math_sample, verify_ifeval_sample @dataclass @@ -231,6 +231,8 @@ def apply_verifiable_reward( verified = verify_gsm8k_sample(prediction, ground_truth) elif dataset.lower() == 'math': verified = verify_math_sample(prediction, ground_truth) + elif dataset.lower() == 'ifeval': + verified = verify_ifeval_sample(prediction, ground_truth) # if verified, give reward if verified: print("Applying ground truth reward 🤗") From 79ec9605381a0573efc693a464e82c8931c3a3c8 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Thu, 24 Oct 2024 16:50:59 +0000 Subject: [PATCH 34/53] ifeval debug --- open_instruct/ground_truth_utils.py | 1 + scripts/eval_constraints/if_functions.py | 488 ----------------------- 2 files changed, 1 insertion(+), 488 deletions(-) delete mode 100644 scripts/eval_constraints/if_functions.py diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 7697bde5d..b6f3134e6 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -121,4 +121,5 @@ def verify_flan_sample(model_output, ground_truth_answer): ds = load_dataset("ai2-adapt-dev/prompts_with_constraints_for_ground_truth") test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$" for sample in ds['train']: + print(sample) verify_ifeval_sample(test_model_output, sample['ground_truth']) \ No newline at end of file diff --git a/scripts/eval_constraints/if_functions.py b/scripts/eval_constraints/if_functions.py deleted file mode 100644 index 88ef69ec5..000000000 --- a/scripts/eval_constraints/if_functions.py +++ /dev/null @@ -1,488 +0,0 @@ -import re -import json -import langdetect -from typing import List - - -""" -This module contains functions to verify constraints in the responses generated by the model. -It covers all 25 constraints from the IFEval taxonomy. To be used either for eval or for ground truth rewards. -""" - - -# include keywords: Include keywords {keyword1}, {keyword2} in your response - -def verify_keywords(response, keyword_list): - """ - Verify if the response contains all the specified keywords. - - Args: - response (str): The response text to check - keyword_list (list): A list of keywords to check for - - Returns: - bool: True if all keywords are present in the response, False otherwise - """ - # Convert response to lowercase for case-insensitive matching - response_lower = response.lower() - - # Check if all keywords are present in the response - return all(keyword.lower() in response_lower for keyword in keyword_list) - - -# Keyword Frequency: In your response, the word {word} should appear {N} times. -def verify_keyword_frequency(text, word, N): - """ - Verifies if a keyword appears exactly N times in the given text. - - Args: - text (str): The text to analyze - keyword (str): The keyword to count - expected_count (int): The expected number of occurrences - - Returns: - tuple: (bool, int) - (Whether constraint is met, actual count found) - """ - # Convert text to lowercase to make the search case-insensitive - text = text.lower() - keyword = word.lower() - - # Split text into words and remove punctuation - import re - words = re.findall(r'\b\w+\b', text) - - # Count actual occurrences - actual_count = sum(1 for word in words if word == keyword) - - # Check if constraint is met - constraint_met = actual_count == N - - return constraint_met - - -# Forbidden Words: Do not include keywords {forbidden words} in the response. -def validate_forbidden_words(text, forbidden_words): - """ - Validates that the text does not contain any of the specified forbidden words. - - Args: - text (str): The text to check for forbidden words - forbidden_words (list[str]): A list of forbidden words - - Returns: - tuple[bool, list[str]]: A tuple containing: - - Boolean indicating if any forbidden words are present - - List of forbidden words found in the text - - Example: - text = "This is a message that should not contain any bad words" - forbidden_words = ["bad", "evil", "harmful"] - result = validate_forbidden_words(text, forbidden_words) - """ - # Convert text to lowercase for case-insensitive matching - text_lower = text.lower() - - # Check each forbidden word - found_words = [word for word in forbidden_words if word.lower() in text_lower] - - # Return results - return len(found_words) == 0 - - -# Letter Frequency : In your response, the letter {letter} should appear {N} times. - -def verify_letter_frequency(text: str, letter: str, N: int) -> bool: - """ - Verifies if a given letter appears exactly the specified number of times in the text. - - Args: - text (str): The text to check - letter (str): The letter to count (case-sensitive) - target_count (int): The expected number of occurrences - - Returns: - bool: True if the constraint is met, False otherwise - - Example: - >>> verify_letter_frequency("hello world", "l", 3) - True - >>> verify_letter_frequency("hello world", "o", 2) - True - >>> verify_letter_frequency("hello world", "x", 0) - True - """ - if len(letter) != 1: - raise ValueError("Letter parameter must be a single character") - - actual_count = text.count(letter) - return actual_count == N - - -# Response Language: Your ENTIRE response should be in {language}, no other language is allowed. - -def validate_response_language(text, language): - """ - Validates that the entire response is in the specified language. - - Args: - text (str): The text to check - language (str): The language code (e.g., 'en' for English) - - Returns: - bool: True if the response is entirely in the specified language, False otherwise - - Example: - text = "This is an English sentence" - language = "en" - result = validate_response_language(text, language) - """ - from langdetect import detect - - # Detect the language of the text - detected_language = detect(text) - # Check if the detected language matches the expected language - return detected_language == language - - -# Number Paragraphs: Your response should contain {N} paragraphs. You separate paragraphs using the markdown divider: -# * * * -def verify_paragraph_count(text: str, N: int) -> bool: - """ - Verifies that a text contains the expected number of paragraphs, - where paragraphs are separated by markdown dividers '* * *' - - Args: - text (str): The text to analyze - expected_count (int): Expected number of paragraphs - - Returns: - bool: True if the text contains exactly the expected number of paragraphs, - False otherwise - - Example: - text = "First paragraph\n* * *\nSecond paragraph" - verify_paragraph_count(text, 2) - True - """ - def clean_text(text: str) -> str: - """Remove extra whitespace and normalize line endings""" - return '\n'.join(line.strip() for line in text.splitlines()).strip() - - # Clean the input text - text = clean_text(text) - - # Split text by markdown divider - # Add 1 to count since n dividers create n+1 paragraphs - paragraphs = text.split('* * *') - actual_count = len(paragraphs) - - # Verify each split resulted in non-empty content - valid_paragraphs = [p.strip() for p in paragraphs if p.strip()] - if len(valid_paragraphs) != actual_count: - return False - - return actual_count == N - - -# Number Words: Answer with at least / around / at most {N} words - -def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: - """ - Validates if a text meets specified word count constraints. - - Args: - text (str): The text to check - count (int): The target word count - qualifier (str): The type of constraint ('at least', 'around', 'at most') - - Returns: - bool: True if the constraint is met, False otherwise - - Raises: - ValueError: If an invalid qualifier is provided - """ - # Remove extra whitespace and split into words - words = text.strip().split() - actual_count = len(words) - - # Define tolerance for "around" qualifier (±10% of target count) - tolerance = max(round(N * 0.1), 1) - - if quantifier == "at least": - return actual_count >= N - elif quantifier == "at most": - return actual_count <= N - elif quantifier == "around": - return abs(actual_count - N) <= tolerance - else: - return False - - -# Number Sentences: Answer with at least / around / at most {N} sentences. -def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: - """ - Verifies if a text contains the expected number of sentences. - - Args: - text (str): The text to analyze - N (int): The expected number of sentences - quantifier (str): The quantifier ('at least', 'around', 'at most') - - Returns: - bool: True if the text contains the expected number of sentences, False otherwise - """ - # Split the text into sentences - sentences = re.split(r'(?= N - elif quantifier == 'around': - return abs(actual_count - N) <= 1 - elif quantifier == 'at most': - return actual_count <= N - else: - return False - - -# Number Paragraphs + First Word in i-th Paragraph: There should be {N} paragraphs. Paragraphs and only paragraphs -# are separated with each other by two line breaks. The {i}-th paragraph must start with word {first word}. -def validate_paragraphs(text, N, first_word, i): - """ - Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific - word. - - Args: - text (str): The text to analyze - N (int): The expected number of paragraphs - first_word (str): The expected first word of the i-th paragraph - i (int): The index of the paragraph to check (1-indexed) - - Returns: - bool: True if the text meets the paragraph and first word requirements, False otherwise - """ - # Split the text into paragraphs - paragraphs = text.split('\n\n') - - # Check if the number of paragraphs is as expected - if len(paragraphs) != N: - return False - - # Check if the i-th paragraph starts with the specified first word - if paragraphs[i - 1].strip().startswith(first_word): - return True - return False - - -# Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker} - -def verify_postscript(text, postscript_marker): - """ - Verifies if a text contains a postscript starting with '{postscript marker}' - - Args: - text (str): The text to verify - - Returns: - bool: True if the text contains a valid postscript, False otherwise - """ - # Check if the text contains the postscript marker - if postscript_marker in text: - # Get the index of the marker - marker_index = text.find(postscript_marker) - # Check if the marker appears near the end - remaining_text = text[marker_index:].strip() - # Verify it's not just the marker alone - return len(remaining_text) > len(postscript_marker) - return False - - -# Number Placeholder: The response must contain at least {N} placeholders represented by square brackets, -# such as [address]. -def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]: - """ - Validates if a text contains at least the specified number of placeholders in square brackets. - - Args: - text (str): The text to check for placeholders - min_placeholders (int): Minimum number of placeholders required - - Returns: - tuple[bool, List[str]]: A tuple containing: - - Boolean indicating if the text meets the placeholder requirement - - List of found placeholders - - Example: - >>> text = "Hello [name], your [item] will be delivered to [address]" - >>> validate_placeholders(text, 2) - (True, ['name', 'item', 'address']) - """ - # Find all placeholders using regex - pattern = r'\[(.*?)\]' - placeholders = re.findall(pattern, text) - - # Check if the number of placeholders meets the requirement - has_enough = len(placeholders) >= N - - return has_enough, placeholders - - -# Number Bullets: Your answer must contain exactly {N} bullet points. Use the markdown bullet points such as: * This -# is a point. -def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: - """ - Verifies if a text contains exactly N bullet points in markdown format. - Returns a tuple of (is_valid, message). - - Args: - text (str): The text to check - expected_count (int): The expected number of bullet points - - Returns: - tuple[bool, str]: (True if constraint is met, explanation message) - """ - # Split text into lines and count lines starting with * or - - lines = text.split('\n') - bullet_points = [line.strip() for line in lines if line.strip().startswith(('*', '-'))] - actual_count = len(bullet_points) - - if actual_count == N: - return True - else: - return False - - -# Title: Your answer must contain a title, wrapped in double angular brackets, such as <>. -def validate_title(answer: str) -> bool: - pattern = r"<<(.*?)>>" - matches = re.findall(pattern, answer) - - if len(matches) > 0: - return True - else: - return False - - -# Choose: From Answer with one of the following options: {options} -def validate_choice(answer: str, options: list) -> bool: - for option in options: - if answer in option: - return True - return False - - -# Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with markdown, i.e. *highlighted -# section* -def validate_highlighted_sections(answer: str, N: int) -> bool: - pattern = r"\*(.*?)\*" - matches = re.findall(pattern, answer) - - if len(matches) >= N: - return True - else: - return False - - -# Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with {section splitter} X. - -def validate_sections(answer: str, N: int, section_splitter: str) -> bool: - sections = answer.split(section_splitter) - # The first section might not start with the splitter, so we adjust for this - if sections[0] == '': - sections.pop(0) - if len(sections) == N: - return True - else: - return False - - -# JSON Format : Entire output should be wrapped in JSON format. -def validate_json_format(data: str) -> bool: - try: - json_object = json.loads(data) - except ValueError as e: - return False - return True - - -# Repeat Prompt: First, repeat the request without change, then give your answer (do not say anything before -# repeating the request; the request you need to repeat does not include this sentence) -def validate_repeat_prompt(full_response: str, original_prompt: str) -> bool: - response_split = full_response.split('. ', 1) - if response_split[0] == original_prompt: - return True - else: - return False - - -# Two Responses: Give two different responses. Responses and only responses should be separated by 6 asterisk -# symbols: ******. -def validate_two_responses(responses: str) -> bool: - if responses.count('******') == 1: - response_list = responses.split('******') - first_response = response_list[0].strip() - second_response = response_list[1].strip() - if first_response != second_response: - return True - return False - - -# All Uppercase: Your entire response should be in English, capital letters only. -def validate_uppercase(response: str) -> bool: - # Check if the response is the same as the uppercase version of the response - if response == response.upper(): - return True - else: - return False - - -# All Lowercase: Your entire response should be in English, and in all lowercase letters. No capital letters are -# allowed. -def validate_lowercase(response: str) -> bool: - # Check if the response is the same as the lowercase version of the response - if response == response.lower(): - return True - else: - return False - - -# Frequency of All-capital Words: In your response, words with all capital letters should appear at least / around / -# at most {N} times. -def validate_frequency_capital_words(response: str, N: int, quantifier: str) -> bool: - words = re.findall(r'\b[A-Z]+\b', response) - if quantifier == 'at least': - return len(words) >= N - elif quantifier == 'around': - return len(words) == N - elif quantifier == 'at most': - return len(words) <= N - else: - return False - - -# End Checker: Finish your response with this exact phrase {end phrase}. No other words should follow this phrase. -def validate_end(response: str, end_phrase: str) -> bool: - # Check if the response ends with the end phrase - if response.endswith(end_phrase): - return True - else: - return False - - -# Quotation: Wrap your entire response with double quotation marks. -def validate_quotation(response: str) -> bool: - if response.startswith('"') and response.endswith('"'): - return True - else: - return False - - -# No Commas: In your entire response, refrain from the use of any commas. -def validate_no_commas(response: str) -> bool: - if ',' not in response: - return True - else: - return False From b1b47bf4874b3c971bf7b6cf1200cf5e36a3b267 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Thu, 24 Oct 2024 17:46:27 +0000 Subject: [PATCH 35/53] incorporate val fixes --- open_instruct/if_functions.py | 65 +++++++++++++++++------------------ 1 file changed, 31 insertions(+), 34 deletions(-) diff --git a/open_instruct/if_functions.py b/open_instruct/if_functions.py index bfb1602d8..78b0a959e 100644 --- a/open_instruct/if_functions.py +++ b/open_instruct/if_functions.py @@ -3,7 +3,6 @@ import langdetect from typing import List - """ This module contains functions to verify constraints in the responses generated by the model. It covers all 25 constraints from the IFEval taxonomy. To be used either for eval or for ground truth rewards. @@ -12,7 +11,7 @@ # include keywords: Include keywords {keyword1}, {keyword2} in your response -def verify_keywords(response, keyword_list): +def verify_keywords(text, keyword_list): """ Verify if the response contains all the specified keywords. @@ -24,22 +23,21 @@ def verify_keywords(response, keyword_list): bool: True if all keywords are present in the response, False otherwise """ # Convert response to lowercase for case-insensitive matching - response_lower = response.lower() + response_lower = text.lower() # Check if all keywords are present in the response return all(keyword.lower() in response_lower for keyword in keyword_list) # Keyword Frequency: In your response, the word {word} should appear {N} times. -# may take multiple keywords. def verify_keyword_frequency(text, word, N): """ Verifies if a keyword appears exactly N times in the given text. Args: text (str): The text to analyze - keyword_list (List[str]): The keywords to count - expected_count (int): The expected number of occurrences for each keyword + keyword (str): The keyword to count + expected_count (int): The expected number of occurrences Returns: tuple: (bool, int) - (Whether constraint is met, actual count found) @@ -357,9 +355,9 @@ def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: # Title: Your answer must contain a title, wrapped in double angular brackets, such as <>. -def validate_title(answer: str) -> bool: +def validate_title(text: str) -> bool: pattern = r"<<(.*?)>>" - matches = re.findall(pattern, answer) + matches = re.findall(pattern, text) if len(matches) > 0: return True @@ -368,18 +366,18 @@ def validate_title(answer: str) -> bool: # Choose: From Answer with one of the following options: {options} -def validate_choice(answer: str, options: list) -> bool: +def validate_choice(text: str, options: list) -> bool: for option in options: - if answer in option: + if text in option: return True return False # Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with markdown, i.e. *highlighted # section* -def validate_highlighted_sections(answer: str, N: int) -> bool: +def validate_highlighted_sections(text: str, N: int) -> bool: pattern = r"\*(.*?)\*" - matches = re.findall(pattern, answer) + matches = re.findall(pattern, text) if len(matches) >= N: return True @@ -389,8 +387,8 @@ def validate_highlighted_sections(answer: str, N: int) -> bool: # Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with {section splitter} X. -def validate_sections(answer: str, N: int, section_splitter: str) -> bool: - sections = answer.split(section_splitter) +def validate_sections(text: str, N: int, section_splitter: str) -> bool: + sections = text.split(section_splitter) # The first section might not start with the splitter, so we adjust for this if sections[0] == '': sections.pop(0) @@ -401,9 +399,9 @@ def validate_sections(answer: str, N: int, section_splitter: str) -> bool: # JSON Format : Entire output should be wrapped in JSON format. -def validate_json_format(data: str) -> bool: +def validate_json_format(text: str) -> bool: try: - json_object = json.loads(data) + json_object = json.loads(text) except ValueError as e: return False return True @@ -411,9 +409,8 @@ def validate_json_format(data: str) -> bool: # Repeat Prompt: First, repeat the request without change, then give your answer (do not say anything before # repeating the request; the request you need to repeat does not include this sentence) -def validate_repeat_prompt(full_response: str, original_prompt: str) -> bool: - response_split = full_response.split('. ', 1) - if response_split[0] == original_prompt: +def validate_repeat_prompt(text: str, original_prompt: str) -> bool: + if text.startswith(original_prompt): return True else: return False @@ -421,9 +418,9 @@ def validate_repeat_prompt(full_response: str, original_prompt: str) -> bool: # Two Responses: Give two different responses. Responses and only responses should be separated by 6 asterisk # symbols: ******. -def validate_two_responses(responses: str) -> bool: - if responses.count('******') == 1: - response_list = responses.split('******') +def validate_two_responses(text: str) -> bool: + if text.count('******') == 1: + response_list = text.split('******') first_response = response_list[0].strip() second_response = response_list[1].strip() if first_response != second_response: @@ -432,9 +429,9 @@ def validate_two_responses(responses: str) -> bool: # All Uppercase: Your entire response should be in English, capital letters only. -def validate_uppercase(response: str) -> bool: +def validate_uppercase(text: str) -> bool: # Check if the response is the same as the uppercase version of the response - if response == response.upper(): + if text == text.upper(): return True else: return False @@ -442,9 +439,9 @@ def validate_uppercase(response: str) -> bool: # All Lowercase: Your entire response should be in English, and in all lowercase letters. No capital letters are # allowed. -def validate_lowercase(response: str) -> bool: +def validate_lowercase(text: str) -> bool: # Check if the response is the same as the lowercase version of the response - if response == response.lower(): + if text == text.lower(): return True else: return False @@ -452,8 +449,8 @@ def validate_lowercase(response: str) -> bool: # Frequency of All-capital Words: In your response, words with all capital letters should appear at least / around / # at most {N} times. -def validate_frequency_capital_words(response: str, N: int, quantifier: str) -> bool: - words = re.findall(r'\b[A-Z]+\b', response) +def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool: + words = re.findall(r'\b[A-Z]+\b', text) if quantifier == 'at least': return len(words) >= N elif quantifier == 'around': @@ -465,25 +462,25 @@ def validate_frequency_capital_words(response: str, N: int, quantifier: str) -> # End Checker: Finish your response with this exact phrase {end phrase}. No other words should follow this phrase. -def validate_end(response: str, end_phrase: str) -> bool: +def validate_end(text: str, end_phrase: str) -> bool: # Check if the response ends with the end phrase - if response.endswith(end_phrase): + if text.endswith(end_phrase): return True else: return False # Quotation: Wrap your entire response with double quotation marks. -def validate_quotation(response: str) -> bool: - if response.startswith('"') and response.endswith('"'): +def validate_quotation(text: str) -> bool: + if text.startswith('"') and text.endswith('"'): return True else: return False # No Commas: In your entire response, refrain from the use of any commas. -def validate_no_commas(response: str) -> bool: - if ',' not in response: +def validate_no_commas(text: str) -> bool: + if ',' not in text: return True else: return False From d61038e894975a19d79f42536c30a1203c9066ce Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Thu, 24 Oct 2024 19:40:34 +0000 Subject: [PATCH 36/53] data fixed, remove skips --- open_instruct/ground_truth_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index b6f3134e6..3d1cff51f 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -97,11 +97,6 @@ def verify_ifeval_sample(model_output, constraint): func_name = constraint.pop("func_name") # get the function func = IF_FUNCTIONS_MAP[func_name] - # for now, ignore this due to data issues. TODO: fix. - if func_name == "validate_repeat_prompt": - return False # missing 'original_prompt' - if func_name == "validate_sections": - return False # missing 'section_splitter' # now, run the function # pop out any none args non_none_args = {k:v for k,v in constraint.items() if v is not None} From 77f619d2da2557a0741201e17b6e197ae05db76a Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 29 Oct 2024 15:57:56 -0400 Subject: [PATCH 37/53] Prototype ppo + ray (#390) * Prototype ppo + ray * reduce gradient * push * push changes * quick push * cache changes; this actually works with 6 nodes * push changes * psuh changes * push the latest change * push changes * Fix uploading * Make style * style and quality * update docs * update mason.py * log wandb tables * update docs * make style quality * make sure to save the right thing * push changes * push * push * push changes * push * push * fix * remove preemption code * fix * push changes * push * quick fix * quick push --- docs/algorithms/ppo.md | 243 ++- mason.py | 13 +- open_instruct/dataset_processor.py | 65 +- open_instruct/ppo_vllm_thread.py | 225 ++- open_instruct/ppo_vllm_thread_ray_gtrl.py | 1684 +++++++++++++++++++ open_instruct/ppo_vllm_thread_ray_old.py | 1799 +++++++++++++++++++++ open_instruct/vllm_utils2.py | 240 +++ 7 files changed, 4136 insertions(+), 133 deletions(-) create mode 100644 open_instruct/ppo_vllm_thread_ray_gtrl.py create mode 100644 open_instruct/ppo_vllm_thread_ray_old.py create mode 100644 open_instruct/vllm_utils2.py diff --git a/docs/algorithms/ppo.md b/docs/algorithms/ppo.md index 1b4985e00..3de7f1aa4 100644 --- a/docs/algorithms/ppo.md +++ b/docs/algorithms/ppo.md @@ -88,6 +88,44 @@ python open_instruct/ppo_vllm_thread.py \ --gradient_checkpointing \ --with_tracking \ --push_to_hub + +# LEVEL 0.2: 3 GPUs; 1 GPU for actor and reference, 1 GPU for critic and reward model, and +# 1 GPU for vLLM generation. +python open_instruct/ppo_vllm_thread_ray_old.py \ + --dataset_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_eval_splits validation \ + --max_token_length 1024 \ + --max_prompt_token_length 512 \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo_ray \ + --per_device_train_batch_size 8 \ + --local_rollout_forward_batch_size 8 \ + --local_mini_batch_size 128 \ + --local_rollout_batch_size 128 \ + --actor_num_gpus_per_node 1 \ + --ref_num_gpus_per_node 1 \ + --colocate_actor_ref \ + --critic_num_gpus_per_node 1 \ + --reward_num_gpus_per_node 1 \ + --colocate_critic_reward \ + --colocate_actor_ref \ + --vllm_tensor_parallel_size 1 \ + --deepspeed_stage 3 \ + --num_epochs 1 \ + --num_mini_batches 1 \ + --total_episodes 10000 \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --non_stop_penalty \ + --stop_token eos \ + --beta 0.05 \ + --response_length 53 \ + --hf_metadata_dataset "" \ + --no_try_launch_beaker_eval_jobs \ + --with_tracking \ + --push_to_hub ``` @@ -100,43 +138,127 @@ Here we are using --vllm_device cuda:7 to say we want to launch the vllm generat # for running TL;DR you can likely use GPUs with less memory python mason.py \ --image nathanl/open_instruct_auto --pure_docker_mode \ - --cluster ai2/pluto-cirrascale ai2/prior-cirrascale ai2/s2-cirrascale ai2/general-cirrascale \ - --priority normal \ - --resumable \ + --cluster ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale \ + --priority high \ + --workspace ai2/tulu-3-dev \ --preemptible \ --budget ai2/allennlp \ --gpus 8 -- accelerate launch --num_processes 7 --config_file configs/ds_configs/deepspeed_zero3.yaml \ open_instruct/ppo_vllm_thread.py \ + --exp_name "ppo_vllm_thread_ds3" \ --dataset_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_train_splits train \ --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ --dataset_eval_splits validation \ --max_token_length 1024 \ - --max_prompt_token_length 512 \ + --max_prompt_token_lenth 512 \ --learning_rate 3e-6 \ --output_dir models/minimal/ppo_vllm_thread_tldr \ --per_device_train_batch_size 16 \ --local_rollout_forward_batch_size 32 \ --gradient_accumulation_steps 4 \ - --num_epochs 1 \ + --num_epochs 4 \ --num_mini_batches 1 \ --total_episodes 1000000 \ --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ --non_stop_penalty \ --stop_token eos \ - --beta 0.1 \ + --beta 0.05 \ --response_length 53 \ --with_tracking \ --push_to_hub \ --hf_metadata_dataset '""' \ --no_try_launch_beaker_eval_jobs \ --vllm_device cuda:7 + +# or with ray +python mason.py \ + --cluster ai2/allennlp-elara-cirrascale ai2/jupiter-cirrascale-2 ai2/saturn-cirrascale --image costah/open_instruct_ppo_ray --pure_docker_mode \ + --priority high \ + --workspace ai2/tulu-3-dev \ + --preemptible \ + --budget ai2/allennlp \ + --gpus 8 -- python open_instruct/ppo_vllm_thread_ray1.py \ + --exp_name ppo_vllm_thread_ray_not_ds_adamw \ + --dataset_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_eval_splits validation \ + --max_token_length 1024 \ + --max_prompt_token_length 512 \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo_ray \ + --per_device_train_batch_size 16 \ + --local_rollout_forward_batch_size 32 \ + --local_mini_batch_size 256 \ + --local_rollout_batch_size 256 \ + --actor_num_gpus_per_node 4 \ + --ref_num_gpus_per_node 4 \ + --colocate_actor_ref \ + --critic_num_gpus_per_node 2 \ + --reward_num_gpus_per_node 1 \ + --vllm_tensor_parallel_size 1 \ + --deepspeed_stage 3 \ + --num_epochs 4 \ + --num_mini_batches 1 \ + --total_episodes 1000000 \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --non_stop_penalty \ + --stop_token eos \ + --beta 0.05 \ + --response_length 53 \ + --hf_metadata_dataset '""' \ + --no_try_launch_beaker_eval_jobs \ + --with_tracking \ + --push_to_hub \ + + + +python open_instruct/ppo_vllm_thread_ray2.py \ + --exp_name ppo_vllm_thread_ray_not_ds_adamw \ + --dataset_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"trl-internal-testing/tldr-preference-sft-trl-style": 1.0}' \ + --dataset_eval_splits validation \ + --max_token_length 1024 \ + --max_prompt_token_length 512 \ + --learning_rate 3e-6 \ + --output_dir models/minimal/ppo_ray \ + --per_device_train_batch_size 16 \ + --local_rollout_forward_batch_size 32 \ + --local_mini_batch_size 256 \ + --local_rollout_batch_size 256 \ + --colocate_everything \ + --actor_num_gpus_per_node 7 \ + --ref_num_gpus_per_node 7 \ + --critic_num_gpus_per_node 7 \ + --reward_num_gpus_per_node 7 \ + --vllm_tensor_parallel_size 1 \ + --deepspeed_stage 3 \ + --num_epochs 4 \ + --num_mini_batches 1 \ + --total_episodes 1000000 \ + --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \ + --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \ + --non_stop_penalty \ + --stop_token eos \ + --beta 0.05 \ + --response_length 53 \ + --hf_metadata_dataset '""' \ + --no_try_launch_beaker_eval_jobs ``` * Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/by8j2ejp * Trained model: https://huggingface.co/vwxyzjn/ppo_vllm_thread__cleanrl_EleutherAI_pythia-1b-deduped__sft__tldr/tree/ppo_vllm_thread__1__1726110645 +Ray's exps: + +* Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/dixgu9lk +* Trained model: https://huggingface.co/allenai/open_instruct_dev/tree/ppo_vllm_thread_ray__1__1729740264 + +https://wandb.ai/ai2-llm/open_instruct_internal/runs/kzs3h76g?nw=nwusercostah ### LEVEL 2: 8 GPU; Huggingface no robot @@ -184,11 +306,62 @@ python mason.py \ --gradient_checkpointing \ --with_tracking \ --push_to_hub + + # --cluster ai2/pluto-cirrascale \ +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 --image nathanl/open_instruct_auto --pure_docker_mode \ + --workspace ai2/tulu-3-dev \ + --priority high \ + --preemptible \ + --budget ai2/allennlp \ + --gpus 8 -- python open_instruct/ppo_vllm_thread_ray.py \ + --dataset_mixer '{"HuggingFaceH4/no_robots": 1.0}' \ + --dataset_train_splits train \ + --dataset_eval_mixer '{"HuggingFaceH4/no_robots": 1.0}' \ + --dataset_eval_splits test \ + --max_token_length 1024 \ + --max_prompt_token_length 512 \ + --learning_rate 8e-7 \ + --output_dir models/minimal/ppo_ray \ + --per_device_train_batch_size 1 \ + --local_rollout_forward_batch_size 1 \ + --local_mini_batch_size 128 \ + --local_rollout_batch_size 128 \ + --actor_num_gpus_per_node 4 \ + --ref_num_gpus_per_node 4 \ + --colocate_actor_ref \ + --critic_num_gpus_per_node 2 \ + --reward_num_gpus_per_node 1 \ + --vllm_tensor_parallel_size 1 \ + --deepspeed_stage 3 \ + --num_epochs 1 \ + --num_mini_batches 1 \ + --total_episodes 100000 \ + --model_name_or_path allenai/open_instruct_dev \ + --model_revision finetune__meta-llama_Meta-Llama-3.1-8B__42__1726352218 \ + --reward_model_path allenai/open_instruct_dev \ + --reward_model_revision reward_modeling__2__1727077902 \ + --non_stop_penalty \ + --stop_token eos \ + --penalty_reward_value -10.0 \ + --beta 0.03 \ + --num_evals 3 \ + --seed 3 \ + --response_length 1024 \ + --gradient_checkpointing \ + --with_tracking \ + --push_to_hub ``` * Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/jvjegpcq * Trained model: https://huggingface.co/vwxyzjn/ppo_vllm_thread_beta_0.03__allenai_open_instruct_dev/tree/ppo_vllm_thread_beta_0.03__3__1726244716 +Ray's exps: + +* Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/kzs3h76g +* Trained model: https://huggingface.co/allenai/open_instruct_dev/tree/ppo_vllm_thread_ray__3__1729717405 + + ### LEVEL 3: 8 GPU; Training on ultrafeedback RM @@ -238,8 +411,62 @@ python mason.py \ --push_to_hub ``` -* Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/z9035fv5/overview -* Trained model: https://huggingface.co/vwxyzjn/ppo_vllm_thread_beta_0.03__allenai_open_instruct_dev/tree/ppo_vllm_thread_beta_0.03__1__1726282755 +### LEVEL 4: 48 GPUs; 70B training + +```bash +TULU3_RM_REPO=allenai/open_instruct_dev +TULU3_RM1_REV=reward_modeling__2__1726890037 +python mason.py \ + --cluster ai2/jupiter-cirrascale-2 --image nathanl/open_instruct_auto --pure_docker_mode \ + --workspace ai2/tulu-3-dev \ + --priority high \ + --preemptible \ + --num_nodes 6 \ + --image costah/open_instruct_ppo_ray \ + --budget ai2/allennlp \ + --gpus 8 -- source configs/beaker_configs/beaker_configs/ray_node_setup.sh \&\& python open_instruct/ppo_vllm_thread_ray1.py \ + --exp_name ppo_vllm_thread_ray_70B \ + --dataset_mixer "{\"HuggingFaceH4/no_robots\": 9500}" \ + --dataset_train_splits train \ + --dataset_eval_mixer "{\"HuggingFaceH4/no_robots\": 1.0}" \ + --dataset_eval_splits test \ + --max_token_length 1024 \ + --max_prompt_token_length 1024 \ + --learning_rate 4e-7 \ + --output_dir models/minimal/ppo_ray_tldr \ + --per_device_train_batch_size 1 \ + --local_rollout_forward_batch_size 1 \ + --local_mini_batch_size 16 \ + --local_rollout_batch_size 16 \ + --actor_num_gpus_per_node 8 \ + --actor_num_nodes 4 \ + --ref_num_gpus_per_node 8 \ + --critic_num_gpus_per_node 4 \ + --reward_num_gpus_per_node 1 \ + --vllm_num_engines 1 \ + --vllm_tensor_parallel_size 2 \ + --deepspeed_stage 3 \ + --num_epochs 1 \ + --num_mini_batches 1 \ + --total_episodes 100000 \ + --model_name_or_path /weka/oe-adapt-default/jacobm/models/L3.1-70B-v3.7-nc \ + --reward_model_path $TULU3_RM_REPO \ + --reward_model_revision $TULU3_RM1_REV \ + --non_stop_penalty \ + --stop_token eos \ + --penalty_reward_value -10.0 \ + --beta 0.04 \ + --num_evals 3 \ + --seed 3 \ + --response_length 512 \ + --gradient_checkpointing \ + --with_tracking \ + --push_to_hub +``` + +* Beaker link: https://beaker.org/ex/01JAZENQBHGRQZ8MVAQGBX3GW3/tasks/01JAZENQBRYJN59GX6NDD8BK5J/job/01JAZENQGGYM59P5YK6X4XVSBQ +* Tracked experiment: https://wandb.ai/ai2-llm/open_instruct_internal/runs/w944sr4x +* Trained model: https://huggingface.co/allenai/open_instruct_dev/tree/ppo_vllm_thread_ray_70B__3__1729779947 ### Quality of life tools diff --git a/mason.py b/mason.py index 6f95d6452..20048db9e 100644 --- a/mason.py +++ b/mason.py @@ -17,6 +17,7 @@ def parse_beaker_dataset(dataset_str): "ai2/jupiter-cirrascale-2", "ai2/saturn-cirrascale", "ai2/neptune-cirrascale", + "ai2/allennlp-elara-cirrascale", ] @@ -193,17 +194,21 @@ def get_env_vars(pure_docker_mode: bool, cluster: List[str], beaker_secrets: Lis # if all cluster is in weka, we mount the weka elif all(c in WEKA_CLUSTERS for c in cluster): env_vars.extend([ + beaker.EnvVar( + name="HF_HOME", + value="/weka/oe-adapt-default/allennlp/.cache/huggingface", + ), beaker.EnvVar( name="HF_DATASETS_CACHE", - value="/weka/allennlp/.cache/huggingface", + value="/weka/oe-adapt-default/allennlp/.cache/huggingface", ), beaker.EnvVar( name="HF_HUB_CACHE", - value="/weka/allennlp/.cache/hub", + value="/weka/oe-adapt-default/allennlp/.cache/hub", ), beaker.EnvVar( name="CHECKPOINT_OUTPUT_DIR", - value=f"/weka/allennlp/deletable_checkpoint_states/{global_wandb_id}", + value=f"/weka/oe-adapt-default/allennlp/deletable_checkpoint_states/{global_wandb_id}", ), ]) if num_nodes > 1: @@ -256,7 +261,7 @@ def get_datasets(beaker_datasets, cluster: List[str]): res = [ beaker.DataMount( source=beaker.DataSource(weka="oe-adapt-default"), - mount_path="/weka", + mount_path="/weka/oe-adapt-default", ), ] for beaker_dataset in beaker_datasets: diff --git a/open_instruct/dataset_processor.py b/open_instruct/dataset_processor.py index 18f401b05..70247855e 100644 --- a/open_instruct/dataset_processor.py +++ b/open_instruct/dataset_processor.py @@ -372,6 +372,66 @@ def get_token_length_visualization(self, dataset: DatasetDict, save_path: str = class SFTDatasetProcessor(DatasetProcessor): + def tokenize(self, dataset: Dataset): + def tokenize_fn(row): + if len(row[self.config.sft_messages_key]) == 1: + prompt = row[self.config.sft_messages_key] + else: + prompt = row[self.config.sft_messages_key][:-1] + row[INPUT_IDS_PROMPT_KEY] = self.tokenizer.apply_chat_template( + prompt, + add_generation_prompt=True, + ) + row[INPUT_IDS_KEY] = self.tokenizer.apply_chat_template(row[self.config.sft_messages_key]) + row[ATTENTION_MASK_KEY] = [1] * len(row[INPUT_IDS_KEY]) + labels = copy.deepcopy(row[INPUT_IDS_KEY]) + if self.config.train_only_on_prompt: + labels[: len(row[INPUT_IDS_PROMPT_KEY])] = [-100] * len(row[INPUT_IDS_PROMPT_KEY]) + row[LABELS_KEY] = labels + return row + + return dataset.map( + tokenize_fn, + num_proc=get_num_proc(len(dataset), self.config.num_proc, APPLY_CHAT_TEMPLATE_EXAMPLE_PER_SECOND_PER_CPU), + load_from_cache_file=self.config.load_from_cache_file, + desc="Tokenizing and reformatting SFT data", + ) + + def filter(self, dataset: Dataset, need_contain_labels: bool = True): + def filter_fn(row): + max_prompt_token_length_ok = True + if self.config.max_prompt_token_length is not None: + max_prompt_token_length_ok = len(row[INPUT_IDS_PROMPT_KEY]) <= self.config.max_prompt_token_length + + max_token_length_ok = True + if self.config.max_token_length is not None: + max_token_length_ok = len(row[INPUT_IDS_KEY]) <= self.config.max_token_length + + contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) + return ( + max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + ) + + return dataset.filter( + filter_fn, + num_proc=get_num_proc(len(dataset), self.config.num_proc, FILTER_EXAMPLE_PER_SECOND_PER_CPU), + load_from_cache_file=self.config.load_from_cache_file, + desc="Filtering SFT data", + ) + + def get_token_length_stats(self, dataset: Union[Dataset, DatasetDict]): + return super().get_token_length_stats(features=[INPUT_IDS_PROMPT_KEY, INPUT_IDS_KEY], dataset=dataset) + + def get_token_length_visualization(self, dataset: DatasetDict, save_path: str = "tmp.png", bins: int = 30): + return super().get_token_length_visualization( + features=[INPUT_IDS_PROMPT_KEY, INPUT_IDS_KEY], + dataset=dataset, + save_path=save_path, + bins=bins, + ) + + +class SFTGroundTruthDatasetProcessor(DatasetProcessor): def tokenize(self, dataset: Dataset): def tokenize_fn(row): if len(row[self.config.sft_messages_key]) == 1: @@ -410,7 +470,9 @@ def filter_fn(row): max_token_length_ok = len(row[INPUT_IDS_KEY]) <= self.config.max_token_length contain_some_labels = any(x != -100 for x in row[LABELS_KEY]) - return max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + return ( + max_prompt_token_length_ok and max_token_length_ok and (contain_some_labels or not need_contain_labels) + ) return dataset.filter( filter_fn, @@ -528,6 +590,7 @@ def __call__(self, batch: list[dict]): INPUT_IDS_PROMPT_KEY: padded_sequences, } + class SimpleGenerateCollatorWithGroundTruth: """Simple collator for generation task (always pad from the LEFT)""" diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 56474fef2..9a3ee1dd1 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -3,7 +3,6 @@ import os import random import shutil -import signal import subprocess import threading import time @@ -37,22 +36,22 @@ from open_instruct.dataset_processor import ( CHAT_TEMPLATES, - INPUT_IDS_PROMPT_KEY, - GROUND_TRUTHS_KEY, DATASET_SOURCE_KEY, + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, DatasetConfig, - SFTDatasetProcessor, + SFTGroundTruthDatasetProcessor, SimpleGenerateCollatorWithGroundTruth, visualize_token, ) from open_instruct.model_utils import ( ModelConfig, + apply_verifiable_reward, disable_dropout_in_model, exact_div, first_true_indices, forward, get_reward, - apply_verifiable_reward, prepare_deepspeed, print_rich_single_line_metrics, print_rich_table, @@ -472,11 +471,15 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # create the dataset dataset_dict = DatasetDict() - dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) train_dataset = combine_dataset( args.dataset_mixer_dict, splits=args.dataset_train_splits, - columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], ) if dataset_config.sanity_check: train_dataset = train_dataset.select( @@ -491,7 +494,11 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): eval_dataset = combine_dataset( args.dataset_eval_mixer_dict, splits=args.dataset_eval_splits, - columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], ) eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) with accelerator.main_process_first(): @@ -599,31 +606,6 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): break resume_training_step > 1 - # handle preemption - class PreemptionHandler: - preempted = False - - def __init__(self): - signal.signal(signal.SIGTERM, self.exit_gracefully) - - def exit_gracefully(self, signum, frame): - output_dir = os.path.join(args.checkpoint_output_dir, f"step_{training_step - 1}") - print(f"SIGTERM received, saving to {output_dir} from {accelerator.local_process_index}") - accelerator.save_state(output_dir) - if accelerator.is_main_process and args.with_tracking: - wandb.log({"preempted": True}, commit=True) - wandb.mark_preempting() - if accelerator.is_main_process: - try: - param_prompt_Q.put(None, timeout=20) - response_ids_Q.get(timeout=20) - print("vllm thread terminated") - except Exception as e: - print(e) - self.preempted = True - - ph = PreemptionHandler() - # deepspeed setup is_deepspeed_enabled = getattr(accelerator.state, "deepspeed_plugin", None) is not None mixed_precision = accelerator.state.mixed_precision @@ -680,10 +662,16 @@ def repeat_generator(): thread.start() torch.cuda.set_device(device) - g_vllm_responses = torch.zeros((args.batch_size * args.number_samples_per_prompt, args.response_length), device=device, dtype=torch.long) + g_vllm_responses = torch.zeros( + (args.batch_size * args.number_samples_per_prompt, args.response_length), device=device, dtype=torch.long + ) # set up the metrics and initial states - stats_shape = (args.num_epochs, args.num_mini_batches * args.number_samples_per_prompt, args.gradient_accumulation_steps) + stats_shape = ( + args.num_epochs, + args.num_mini_batches * args.number_samples_per_prompt, + args.gradient_accumulation_steps, + ) approxkl_stats = torch.zeros(stats_shape, device=device) pg_clipfrac_stats = torch.zeros(stats_shape, device=device) pg_loss_stats = torch.zeros(stats_shape, device=device) @@ -698,7 +686,7 @@ def repeat_generator(): # setup extraction model. For now keep on CPU? if args.answer_extraction_model: answer_extraction_model = AutoModelForCausalLM.from_pretrained(args.answer_extraction_model) - answer_extraction_tokenizer = AutoTokenizer.from_pretrained(aargs.answer_extraction_model) + answer_extraction_tokenizer = AutoTokenizer.from_pretrained(args.answer_extraction_model) else: answer_extraction_model = None answer_extraction_tokenizer = None @@ -708,7 +696,7 @@ def repeat_generator(): data = next(iter_dataloader) queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) ground_truths_next = data[GROUND_TRUTHS_KEY] - datsets_next = data[DATASET_SOURCE_KEY] + datasets_next = data[DATASET_SOURCE_KEY] send_queries(accelerator, None, tokenizer, param_prompt_Q, queries_next) for _ in range(1, resume_training_step): # we didn't store scheduler state @@ -719,9 +707,7 @@ def repeat_generator(): scheduler.step() queries = queries_next ground_truths = ground_truths_next - datasets = datsets_next - if ph.preempted: - break + datasets = datasets_next if accelerator.is_main_process: try: @@ -789,18 +775,11 @@ def repeat_generator(): response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) for response in g_response_token_ids ] - for item in g_padded_response_ids: - assert len(item) == args.response_length - for inner_item in item: - if not inner_item < config.vocab_size: - assert inner_item < config.vocab_size, f"{inner_item=}, {tokenizer.vocab_size=}" g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) g_vllm_responses[:] = g_padded_response_ids broadcast(g_vllm_responses, 0) local_vllm_responses = g_vllm_responses[ - accelerator.process_index - * queries.shape[0] : (accelerator.process_index + 1) - * queries.shape[0] + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] ] query_responses = torch.cat((queries, local_vllm_responses), 1) for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): @@ -838,13 +817,19 @@ def repeat_generator(): ) if args.reward_model_multiplier != 1.0: score *= args.reward_model_multiplier - # also apply verifiable reward + # also apply verifiable reward if args.apply_verifiable_reward: # we need to batch the gt to match query. ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] dataset = datasets[i : i + args.local_rollout_forward_batch_size] verifiable_reward, verifiable_count = apply_verifiable_reward( - postprocessed_query_response, tokenizer, ground_truth, dataset, verify_reward=10, answer_extraction_model=answer_extraction_model, answer_extraction_tokenizer=answer_extraction_tokenizer + postprocessed_query_response, + tokenizer, + ground_truth, + dataset, + verify_reward=10, + answer_extraction_model=answer_extraction_model, + answer_extraction_tokenizer=answer_extraction_tokenizer, ) score += verifiable_reward else: @@ -869,9 +854,8 @@ def repeat_generator(): ref_logprobs = torch.cat(ref_logprobs, 0) sequence_lengths = torch.cat(sequence_lengths, 0) scores = torch.cat(scores, 0) - global_scores = accelerator.gather(scores) verifiable_counts = torch.cat(verifiable_counts, 0) - accelerator.print(f"global_scores: {global_scores}, {global_scores.mean()}") + verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] values = torch.cat(values, 0) del (logprob, ref_logprob, full_value, value, score) gc.collect() @@ -908,7 +892,7 @@ def repeat_generator(): kl = kl2 elif args.kl_estimator == "kl3": kl = kl3 - print(f"{accelerator.local_process_index=}, {kl.sum(1)=}") + # print(f"{accelerator.local_process_index=}, {kl.sum(1)=}") non_score_reward = -args.beta * kl non_score_reward_sum = non_score_reward.sum(1) rlhf_reward = scores + non_score_reward_sum @@ -941,7 +925,9 @@ def repeat_generator(): for epoch_idx in range(args.num_epochs): b_inds = np.random.permutation(args.local_batch_size * args.number_samples_per_prompt) minibatch_idx = 0 - for mini_batch_start in range(0, args.local_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size): + for mini_batch_start in range( + 0, args.local_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size + ): mini_batch_end = mini_batch_start + args.local_mini_batch_size mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 @@ -1031,7 +1017,7 @@ def repeat_generator(): local_metrics[14] = ratio_stats.var() local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() - local_metrics[17] = verifiable_counts.mean() # verifiable count = % of time we trigger the verifiable reward + local_metrics[17] = verifiable_correct_rate global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() metrics = { "episode": episode, @@ -1057,7 +1043,7 @@ def repeat_generator(): "policy/entropy_avg": global_metrics[12], "val/ratio": global_metrics[13], "val/ratio_var": global_metrics[14], - "objective/verifiable_counts": global_metrics[17] + "objective/verifiable_correct_rate": global_metrics[17], } if accelerator.is_main_process: print_rich_single_line_metrics(metrics) @@ -1084,74 +1070,73 @@ def repeat_generator(): ) del original_tokenizer - if not ph.preempted: - # save model - os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) - original_tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, revision=model_config.model_revision - ) - save_with_accelerate( + # save model + os.makedirs(os.path.dirname(args.output_dir), exist_ok=True) + original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + save_with_accelerate( + accelerator, + model, + original_tokenizer, + args.output_dir, + model_attribute_to_save="policy", + ) + + # Ai2 specific logic + if is_beaker_job() and accelerator.is_main_process: + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + if args.push_to_hub: + push_folder_to_hub( accelerator, - model, - original_tokenizer, args.output_dir, - model_attribute_to_save="policy", + args.hf_repo_id, + args.hf_repo_revision, ) - # Ai2 specific logic - if is_beaker_job() and accelerator.is_main_process: - if args.hf_metadata_dataset: - dataset_list = list(args.dataset_mixer_dict.keys()) - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "sft", - "datasets": dataset_list, - "base_model": model_config.model_name_or_path, - "wandb_path": wandb.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.hf_repo_revision, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: - command = f"""\ - python mason.py \ - --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --model_name {args.hf_repo_revision} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - if args.push_to_hub: - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - - if accelerator.is_main_process: - # remove args.checkpoint_output_dir - if os.path.exists(args.checkpoint_output_dir): - shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + if accelerator.is_main_process: + # remove args.checkpoint_output_dir + if os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) if __name__ == "__main__": diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py new file mode 100644 index 000000000..d4a3d63c9 --- /dev/null +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -0,0 +1,1684 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF +# which has the following license: +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import logging +import os +import random +import shutil +import socket +import subprocess +import threading +import time +from argparse import Namespace +from dataclasses import asdict, dataclass, field +from queue import Empty, Queue +from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple + +import deepspeed +import numpy as np +import pandas as pd +import ray +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils +import torch.utils.data +import vllm +from datasets import Dataset, DatasetDict +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from huggingface_hub import HfApi +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rich.pretty import pprint +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + get_scheduler, +) +from transformers.deepspeed import HfDeepSpeedConfig +from vllm import SamplingParams + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + DATASET_SOURCE_KEY, + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTGroundTruthDatasetProcessor, + SimpleGenerateCollatorWithGroundTruth, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + apply_verifiable_reward, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + print_rich_single_line_metrics, + print_rich_table, + push_folder_to_hub, + truncate_response, +) +from open_instruct.utils import ( + ArgumentParserPlus, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) +from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group + +api = HfApi() +INVALID_LOGPROB = 1.0 + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: Optional[int] = None + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = 100000 + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_rollout_batch_size: int = 64 + """The number of rollout episodes per iteration per device""" + rollout_batch_size: Optional[int] = None + """The number of rollout episodes per iteration""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 4 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + local_dataloader_batch_size: Optional[int] = None + """The batch size per GPU for the dataloader""" + save_freq: int = -1 + """How many train steps to save the model""" + + # online settings + num_epochs: int = 4 + """the number of epochs to train""" + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: int = 64 + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + reward_model_revision: Optional[str] = None + """the revision of the reward model""" + init_value_from_scratch: bool = False + """whether to initialize the value model from scratch""" + + # generation config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos", "period"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + min_response_length: int = 0 + """stop only after this many tokens""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: float = -1.0 + """the reward value for responses that do not contain `stop_token_id`""" + non_stop_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + number_samples_per_prompt: int = 1 + """the number of samples to generate per prompt, useful for easy-star""" + + # online PPO specific args + beta: float = 0.05 + """the beta value of the RLHF objective (KL coefficient)""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + cliprange: float = 0.2 + """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" + cliprange_value: float = 0.2 + """the clip range for the value function""" + gamma: float = 1 + """the discount factor""" + lam: float = 0.95 + """the lambda value for GAE""" + kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + """the KL estimator to use""" + apply_verifiable_reward: bool = False + """whether to apply verifiable reward""" + reward_model_multiplier: float = 1.0 + """the reward model multiplier, for down/upscaling the reward model output""" + answer_extraction_model: str = None + + # async setting + async_mode: bool = True + """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" + + # ray + actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1]) + """number of gpus per node for actor""" + vllm_num_engines: int = 1 + """number of vLLM Engines, set to 0 to disable vLLM""" + vllm_tensor_parallel_size: int = 1 + """tensor parallel size of vLLM Engine for multi-GPU inference""" + vllm_sync_backend: str = "nccl" + """DeepSpeed -> vLLM weight sync backend""" + enable_prefix_caching: bool = False + """whether to enable prefix caching""" + deepspeed_stage: int = 0 + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + checkpoint_output_dir: Optional[str] = None + """Where to save the model checkpoints in case of preemption""" + + # Ai2 specific settings + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args(args: Args, model_config: ModelConfig): + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # args.world_size = accelerator.num_processes + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.gradient_accumulation_steps = exact_div( + args.local_mini_batch_size, + args.per_device_train_batch_size, + "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", + ) + args.world_size = sum(args.actor_num_gpus_per_node) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) + args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt) + args.eval_freq = max(1, args.num_training_steps // args.num_evals) + # PPO logic: do checks and set up dataloader batch size + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + args.local_dataloader_batch_size = args.rollout_batch_size + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + +def get_train_ds_config( + offload, + adam_offload=False, + stage=0, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + disable_trace_cache=True, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # # ZeRO++ + # "zero_hpz_partition_size": zpg, + # "zero_quantized_weights": False, + # "zero_quantized_gradients": False, + } + if disable_trace_cache: + zero_opt_dict["stage3_prefetch_bucket_size"] = 0 + zero_opt_dict["stage3_max_live_parameters"] = 0 + zero_opt_dict["stage3_max_reuse_distance"] = 0 + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def remove_padding(sequences, pad_token_id): + return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + + # Ensure the effective dataset size is divisible by batch_size + self.effective_size = len(self.data) - (len(self.data) % batch_size) + + def __iter__(self) -> Iterator[List[int]]: + return self + + def __next__(self) -> List[int]: + if self.index >= self.effective_size: + self.index = 0 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + self.index = end_index + + return batch + + +class RayProcess: + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self.world_size = world_size + self.rank = rank + self.local_rank = local_rank + self.master_addr = master_addr if master_addr else self.get_current_node_ip() + self.master_port = master_port if master_port else self.get_free_port() + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + # os.environ["LOCAL_RANK"] = str(self._local_rank) + os.environ["LOCAL_RANK"] = "0" + random.seed(self.rank) + np.random.seed(self.rank) + torch.manual_seed(self.rank) + + @staticmethod + def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self.master_addr, self.master_port + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class PolicyTrainerRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig): + self.args = args + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.policy.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.policy, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + # value model + self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + if args.init_value_from_scratch: + self.value_model.init_weights() # re-initialize the value model from scratch + disable_dropout_in_model(self.value_model) + self.value_model.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.value_model, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.value_model.train() + + # reference model + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.ref_policy) + self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) + self.ref_policy.eval() + + # reward model + self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.reward_model) + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config) + self.reward_model.eval() + + def get_vocab_size(self): + return self.policy.config.vocab_size + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + def train( + self, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizer, + vllm_engines: List[ray.actor.ActorHandle], + metrics_queue: RayQueue, + data_collator: Callable, + ): + torch.set_printoptions(precision=4, sci_mode=False) + + args = self.args + + accelerator = Namespace() + accelerator.process_index = self.rank + accelerator.num_processes = self.world_size + accelerator.is_main_process = self.rank == 0 + torch.distributed.barrier() + if self.rank == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + vllm_num_engines, vllm_tensor_parallel_size = ( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + backend = args.vllm_sync_backend + # https://github.com/OpenRLHF/OpenRLHF/issues/313 + if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": + backend = "gloo" + print( + "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" + ) + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + "openrlhf", + backend=backend, + ) + for i, engine in enumerate(vllm_engines) + ] + self.model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) + ray.get(refs) + torch.distributed.barrier() + + def broadcast_to_vllm(): + # avoid OOM + torch.cuda.empty_cache() + model = self.model.module + count, num_params = 0, len(list(model.named_parameters())) + refss = [] + with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + # print(f"broadcasting {name=} {shape=}") + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + # with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + # ray.get(refs) + # print(f"broadcasting {name=} {shape=} success") + if torch.distributed.get_rank() == 0: + ray.get(refss) + + # broadcast_to_vllm() + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + if args.stop_token == "period": + args.stop_token_id = tokenizer.encode(".")[0] + # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + train_dataset_idxs = np.arange(len(train_dataset)) + shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) + + # hack to left pad + def repeat_generator(): + while True: + batch_idxs = next(shuffling_iter) + yield [train_dataset[i] for i in batch_idxs] + + iter_dataloader = iter(repeat_generator()) + generation_config = SamplingParams( + temperature=args.temperature, + top_p=1.0, + max_tokens=args.response_length, + include_stop_str_in_output=True, + n=args.number_samples_per_prompt, + ) + print("setup async queues") + param_prompt_Q = None + response_ids_Q = None + evaluation_Q = None + response_ids_Q = Queue(maxsize=1) + param_prompt_Q = Queue(maxsize=1) + evaluation_Q = Queue(maxsize=1) + num_eval_samples = 32 + sample_evaluation_prompt_token_ids = None + if eval_dataset is not None: + sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] + + def vllm_generate( + generation_config: SamplingParams, + response_ids_Q: Queue, + param_prompt_Q: Queue, + num_training_steps: int, + sample_evaluation_prompt_token_ids: Optional[List[int]], + evaluation_Q: Queue, + eval_freq: int, + resume_training_step: int, + ): + llm = vllm_engines[0] + for training_step in range(resume_training_step, num_training_steps + 1): + items = param_prompt_Q.get() + if items is None: + break + unwrapped_model, g_queries_list = items + # if unwrapped_model is not None: + generation_start_time = time.time() + + outputs = ray.get( + llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list) + ) + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") + response_ids_Q.put(response_ids) + + if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: + outputs = ray.get( + llm.generate.remote( + prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config + ) + ) + # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + evaluation_Q.put(response_ids) + + resume_training_step = 1 + if accelerator.is_main_process: + thread = threading.Thread( + target=vllm_generate, + args=( + generation_config, + response_ids_Q, + param_prompt_Q, + args.num_training_steps, + sample_evaluation_prompt_token_ids, + evaluation_Q, + args.eval_freq, + resume_training_step, + ), + ) + thread.start() + print("vllm generate thread starts") + + # set up the metrics and initial states + device = torch.device(self.local_rank) + g_vllm_responses = torch.zeros( + (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length), + device=device, + dtype=torch.long, + ) + stats_shape = ( + args.num_epochs, + args.num_mini_batches * args.number_samples_per_prompt, + args.gradient_accumulation_steps, + ) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + local_metrics = torch.zeros((20,), device=device) + episode = args.rollout_batch_size * (resume_training_step - 1) + + # training loop + start_time = time.time() + global_data = next(iter_dataloader) + data = data_collator( + global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] + ) + global_queries = data_collator(global_data)[ + INPUT_IDS_PROMPT_KEY + ].tolist() # can be simplified since we `remove_padding` later anyway + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + + answer_extraction_model = None + answer_extraction_tokenizer = None + # for _ in range(1, resume_training_step): # we didn't store scheduler state + # scheduler.step() + + for training_step in range(resume_training_step, args.num_training_steps + 1): + episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + if accelerator.is_main_process: + df = None + try: + evaluation_responses = evaluation_Q.get(timeout=0.01) + print("🔥🔥🔥 Evaluation responses received") + table = {} + table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) + table["response"] = tokenizer.batch_decode(evaluation_responses) + table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] + df = pd.DataFrame(table) + del table + except Empty: + print("🙈 Evaluation responses not received") + + # (optionally) evaluate the model + if args.async_mode: + if training_step != 1: + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + else: + if training_step != 1: + # NOTE: important: the indent here is different for sync mode + # we also set to use `queries = queries_next` immediately + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + torch.cuda.empty_cache() + # print('get reward stuff starts') + # if we generate multiple samples per prompt, we need to repeat the queries and ground truths + # to match the vllm outputs. + if args.number_samples_per_prompt > 1: + queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) + ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] + + training_time_start = time.time() + with torch.no_grad(): + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + verifiable_counts = [] + sequence_lengths = [] + values = [] + if accelerator.is_main_process: + g_response_token_ids = response_ids_Q.get() + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_token_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + g_vllm_responses[:] = g_padded_response_ids + dist.broadcast(g_vllm_responses, src=0) + local_vllm_responses = g_vllm_responses[ + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] + ] + # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") + query_responses = torch.cat((queries, local_vllm_responses), 1) + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + print(f"get reward stuff starts {i=}") + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + logprob = self.forward( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + torch.cuda.empty_cache() + + ref_output = forward(self.ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + # print("get reward stuff starts 2") + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + if args.reward_model_multiplier != 1.0: + score *= args.reward_model_multiplier + # also apply verifiable reward + if args.apply_verifiable_reward: + # we need to batch the gt to match query. + ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] + dataset = datasets[i : i + args.local_rollout_forward_batch_size] + verifiable_reward, verifiable_count = apply_verifiable_reward( + postprocessed_query_response, + tokenizer, + ground_truth, + dataset, + verify_reward=10, + answer_extraction_model=answer_extraction_model, + answer_extraction_tokenizer=answer_extraction_tokenizer, + ) + score += verifiable_reward + else: + verifiable_count = torch.tensor([0.0], device=device).float() + full_value, _, _ = get_reward( + self.value_model, query_response, tokenizer.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + verifiable_counts.append(verifiable_count) + # print(f"get reward stuff starts 5") + + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + verifiable_counts = torch.cat(verifiable_counts, 0) + verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] + values = torch.cat(values, 0) + # print(f"get reward stuff finished") + del (logprob, ref_logprob, full_value, value, score) + gc.collect() + torch.cuda.empty_cache() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) + # NOTE: only apply the stop token filter if the response is long enough + # otherwise the model could learn to generate the first token as the stop token + contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) + if args.non_stop_penalty: + scores = torch.where( + contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) + ) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + # print(f"get reward stuff finished 2") + + # 4. compute rewards + kl1 = logprobs - ref_logprobs + kl2 = (kl1) ** 2 / 2 + kl3 = (-kl1).exp() - 1 + kl1 + if args.kl_estimator == "kl1": + kl = kl1 + elif args.kl_estimator == "kl2": + kl = kl2 + elif args.kl_estimator == "kl3": + kl = kl3 + # if self.rank==0: + # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") + non_score_reward = -args.beta * kl + non_score_reward_sum = non_score_reward.sum(1) + rlhf_reward = scores + non_score_reward_sum + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + # print(f"get reward stuff finished 3") + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # print('gae') + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # print('training starts') + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch + for epoch_idx in range(args.num_epochs): + b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt) + minibatch_idx = 0 + for mini_batch_start in range( + 0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + print("micro batch start", micro_batch_start, self.rank) + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] + + vpred_temp = get_reward( + self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length + ) + vpred_temp = vpred_temp[0] + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) + self.value_model.backward(vf_loss * args.vf_coef) + self.value_model.step() + + new_logprobs = self.forward( + mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature + ) + # if self.rank==0: + # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + self.model.backward(loss) + # print("backward loss", self.rank, "micro batch start", micro_batch_start) + # print("trying to step", self.rank, "micro batch start", micro_batch_start) + self.model.step() + # print("step", self.rank, "micro batch start", micro_batch_start) + with torch.no_grad(): + # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) + # vf_loss, vf_clipfrac = ray.get(value_model_step_future) + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + # print("value model stepped", self.rank, "micro batch start", micro_batch_start) + # prob_dist = torch.nn.functional.softmax(logits, dim=-1) + # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # fmt: off + del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 + del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss + # del vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max + # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl + # fmt: on + # del everything and empty cache + torch.cuda.empty_cache() + del b_inds, mini_batch_inds + # print("start metrics") + with torch.no_grad(): + local_metrics[0] = sequence_lengths.float().mean() + local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() + local_metrics[2] = kl.sum(1).mean() + local_metrics[3] = (-logprobs).sum(1).mean() + local_metrics[4] = non_score_reward_sum.mean() + local_metrics[5] = rlhf_reward.mean() + local_metrics[6] = scores.mean() + local_metrics[7] = approxkl_stats.mean() + local_metrics[8] = pg_clipfrac_stats.mean() + local_metrics[9] = pg_loss_stats.mean() + local_metrics[10] = vf_loss_stats.mean() + local_metrics[11] = vf_clipfrac_stats.mean() + local_metrics[12] = entropy_stats.mean() + local_metrics[13] = ratio_stats.mean() + local_metrics[14] = ratio_stats.var() + local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() + local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + local_metrics[17] = verifiable_correct_rate + # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + local_metrics /= dist.get_world_size() + dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) + global_metrics = local_metrics.tolist() + metrics = { + "episode": episode, + "training_step": training_step, + "lr": self.scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "time/from_scratch": time.time() - start_time, + "time/training": time.time() - training_time_start, + "val/sequence_lengths": global_metrics[0], + "val/num_stop_token_ids": global_metrics[1], + "objective/kl": global_metrics[2], + "objective/kl2": global_metrics[15], + "objective/kl3": global_metrics[16], + "objective/entropy": global_metrics[3], + "objective/non_score_reward": global_metrics[4], + "objective/rlhf_reward": global_metrics[5], + "objective/scores": global_metrics[6], + "policy/approxkl_avg": global_metrics[7], + "policy/clipfrac_avg": global_metrics[8], + "loss/policy_avg": global_metrics[9], + "loss/value_avg": global_metrics[10], + "val/clipfrac_avg": global_metrics[11], + "policy/entropy_avg": global_metrics[12], + "val/ratio": global_metrics[13], + "val/ratio_var": global_metrics[14], + "objective/verifiable_correct_rate": global_metrics[17], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + metrics_queue.put((metrics, episode, df)) + del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) + del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) + gc.collect() + torch.cuda.empty_cache() + print(f"finished training {training_step}") + + # save steps + if args.save_freq > 0 and training_step % args.save_freq == 0: + step_dir = os.path.join(args.output_dir, f"step_{training_step}") + os.makedirs(step_dir, exist_ok=True) + self.save_model(step_dir) + print("finished training") + + def save_model(self, output_dir: str) -> None: + if self.rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self.model + if hasattr(model_to_save, "module"): + model_to_save = model_to_save.module + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.rank == 0: + output_state_dict[k] = vv + + if self.rank == 0: + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + # if isinstance(model_to_save, PeftModel): + # model_to_save.save_pretrained(output_dir, **kwargs) + # if self.stage == 3: + # torch.save( + # get_peft_model_state_dict(model_to_save, output_state_dict), + # os.path.join(output_dir, "adapter_model.bin"), + # ) + # else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) + + # save tokenizer + self.original_tokenizer.save_pretrained(output_dir) + + +def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): + while True: + if stop_event.is_set(): + break + for ref in object_refs: + try: + ray.get(ref, timeout=0.01) + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + print(e) + print(f"Actor {ref} died") + time.sleep(120) + ray.shutdown() + os._exit(1) # Force shutdown the process + + time.sleep(30) + + +class ModelGroup: + def __init__( + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: List[int], + ): + self.pg = pg + self.ray_process_cls = ray_process_cls + self.num_gpus_per_node = num_gpus_per_node + self.num_gpus_per_actor = 1 + self.num_cpus_per_actor = 4 + self.models = [] + world_size = sum(self.num_gpus_per_node) + master_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=self.pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, 0, None, None) + + self.models.append(master_policy) + master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) + + def get_bundle_index(rank, num_gpus_per_node): + """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" + bundle_idx = 0 + while rank >= num_gpus_per_node[bundle_idx]: + rank -= num_gpus_per_node[bundle_idx] + bundle_idx += 1 + return bundle_idx + + assert get_bundle_index(0, [7, 8, 4]) == 0 + assert get_bundle_index(1, [7, 8, 4]) == 0 + assert get_bundle_index(7, [7, 8, 4]) == 1 + assert get_bundle_index(8, [7, 8, 4]) == 1 + assert get_bundle_index(9, [7, 8, 4]) == 1 + assert get_bundle_index(16, [7, 8, 4]) == 2 + + # Setup worker models + for rank in range(1, world_size): + print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.pg, + placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), + ) + worker_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=scheduling_strategy, + ).remote(world_size, rank, 0, master_addr, master_port) + self.models.append(worker_policy) + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + calculate_runtime_args(args, model_config) + + # set up experiment tracking and seeds + all_configs = {} + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) + beaker_config = maybe_get_beaker_config() + # try saving to the beaker `/output`, which will be uploaded to the beaker dataset + if len(beaker_config.beaker_dataset_id_urls) > 0: + args.output_dir = "/output" + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # create a tokenizer (pad from right) + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + + # create the dataset + dataset_dict = DatasetDict() + dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) + dataset_dict["eval"] = eval_dataset + data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) + + # some more runtime logging + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + pg = None + bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.actor_num_gpus_per_node] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + inits = [] + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.actor_num_gpus_per_node, + ) + inits.extend(model.from_pretrained.remote(args, model_config) for model in policy_group.models) + max_len = dataset_config.max_prompt_token_length + args.response_length + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + model_config.model_name_or_path, + model_config.model_revision, + args.seed, + args.enable_prefix_caching, + max_len, + ) + + metrics_queue = RayQueue() + ray.get(inits) + print("======== all models initialized =========") + ray.get(policy_group.models[0].get_vocab_size.remote()) + # print(f"{policy_vocab_size=}, {reward_vocab_size=}") + # if policy_vocab_size != reward_vocab_size: + # ray.shutdown() # shutdown here so this error message is not buried in the logs + # raise ValueError( + # "Policy and reward model must have the same vocab size. " + # f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " + # "If they don't have the same vocab size, the policy could generate tokens which " + # "is going to cause index out of bound error in the reward model." + # ) + + refs = [] + for i, policy_model in enumerate(policy_group.models): + refs.append( + policy_model.train.remote( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + vllm_engines=vllm_engines, + metrics_queue=metrics_queue, + data_collator=data_collator, + ) + ) + + # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang + # so we need kill the ray cluster when this happens. + stop_event = threading.Event() + threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() + + # train and gather metrics + resume_training_step = 1 + for training_step in range(resume_training_step, args.num_training_steps + 1): + result = metrics_queue.get() + metrics, episode, df = result + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + if df is not None: + if args.with_tracking: + wandb.log({"sample_completions": wandb.Table(dataframe=df)}) + else: + print_rich_table(df) + ray.get(refs) + + # save model + ray.get([policy_model.save_model.remote(args.output_dir) for policy_model in policy_group.models]) + ray.shutdown() + stop_event.set() + + # hack + accelerator = Namespace() + accelerator.is_main_process = True + + # Ai2 specific logic + if is_beaker_job() and accelerator.is_main_process: + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + if args.push_to_hub: + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + if accelerator.is_main_process: + # remove args.checkpoint_output_dir + if os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) diff --git a/open_instruct/ppo_vllm_thread_ray_old.py b/open_instruct/ppo_vllm_thread_ray_old.py new file mode 100644 index 000000000..822c8a18d --- /dev/null +++ b/open_instruct/ppo_vllm_thread_ray_old.py @@ -0,0 +1,1799 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF +# which has the following license: +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import logging +import os +import random +import shutil +import socket +import subprocess +import threading +import time +from argparse import Namespace +from dataclasses import asdict, dataclass +from queue import Empty, Queue +from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple + +import deepspeed +import numpy as np +import pandas as pd +import ray +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils +import torch.utils.data +import vllm +from datasets import Dataset, DatasetDict +from deepspeed.ops.adam import FusedAdam +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from huggingface_hub import HfApi +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rich.pretty import pprint +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + get_scheduler, +) +from transformers.deepspeed import HfDeepSpeedConfig +from vllm import SamplingParams + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTDatasetProcessor, + SimpleGenerateCollator, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + print_rich_single_line_metrics, + push_folder_to_hub, + truncate_response, +) +from open_instruct.utils import ( + ArgumentParserPlus, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) +from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group + +api = HfApi() +INVALID_LOGPROB = 1.0 + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: Optional[int] = None + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = 100000 + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_rollout_batch_size: int = 64 + """The number of rollout episodes per iteration per device""" + rollout_batch_size: Optional[int] = None + """The number of rollout episodes per iteration""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 4 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + local_dataloader_batch_size: Optional[int] = None + """The batch size per GPU for the dataloader""" + + # online settings + num_epochs: int = 4 + """the number of epochs to train""" + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: int = 64 + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + reward_model_revision: Optional[str] = None + """the revision of the reward model""" + + # generation config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos", "period"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + min_response_length: int = 0 + """stop only after this many tokens""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: float = -1.0 + """the reward value for responses that do not contain `stop_token_id`""" + non_stop_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + + # online PPO specific args + beta: float = 0.05 + """the beta value of the RLHF objective (KL coefficient)""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + cliprange: float = 0.2 + """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" + cliprange_value: float = 0.2 + """the clip range for the value function""" + gamma: float = 1 + """the discount factor""" + lam: float = 0.95 + """the lambda value for GAE""" + kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + + # async setting + async_mode: bool = True + """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" + + # ray + actor_num_nodes: int = 1 + """number of nodes for actor""" + actor_num_gpus_per_node: int = 8 + """number of gpus per node for actor""" + ref_num_nodes: int = 1 + """number of nodes for reference""" + ref_num_gpus_per_node: int = 8 + """number of gpus per node for reference""" + colocate_actor_ref: bool = False + """whether to colocate reference and actor model, if true, they will share same gpus.""" + reward_num_nodes: int = 1 + """number of nodes for reward model""" + reward_num_gpus_per_node: int = 8 + """number of gpus per node for reward model""" + critic_num_nodes: int = 1 + """number of nodes for critic""" + critic_num_gpus_per_node: int = 8 + """number of gpus per node for critic""" + colocate_critic_reward: bool = False + """whether to colocate critic and reward model, if true, they will share same gpus.""" + vllm_num_engines: int = 1 + """number of vLLM Engines, set to 0 to disable vLLM""" + vllm_tensor_parallel_size: int = 1 + """tensor parallel size of vLLM Engine for multi-GPU inference""" + vllm_sync_backend: str = "nccl" + """DeepSpeed -> vLLM weight sync backend""" + enable_prefix_caching: bool = False + """whether to enable prefix caching""" + deepspeed_stage: int = 0 + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + checkpoint_output_dir: Optional[str] = None + """Where to save the model checkpoints in case of preemption""" + + # Ai2 specific settings + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args(args: Args, model_config: ModelConfig): + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # args.world_size = accelerator.num_processes + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.gradient_accumulation_steps = exact_div( + args.local_mini_batch_size, + args.per_device_train_batch_size, + "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", + ) + args.world_size = args.actor_num_gpus_per_node * args.actor_num_nodes + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) + args.num_training_steps = args.total_episodes // args.rollout_batch_size + args.eval_freq = max(1, args.num_training_steps // args.num_evals) + # PPO logic: do checks and set up dataloader batch size + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + args.local_dataloader_batch_size = args.rollout_batch_size + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + +def get_train_ds_config( + offload, + adam_offload=False, + stage=0, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + disable_trace_cache=True, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # # ZeRO++ + # "zero_hpz_partition_size": zpg, + # "zero_quantized_weights": False, + # "zero_quantized_gradients": False, + } + if disable_trace_cache: + zero_opt_dict["stage3_prefetch_bucket_size"] = 0 + zero_opt_dict["stage3_max_live_parameters"] = 0 + zero_opt_dict["stage3_max_reuse_distance"] = 0 + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def remove_padding(sequences, pad_token_id): + return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + + # Ensure the effective dataset size is divisible by batch_size + self.effective_size = len(self.data) - (len(self.data) % batch_size) + + def __iter__(self) -> Iterator[List[int]]: + return self + + def __next__(self) -> List[int]: + if self.index >= self.effective_size: + self.index = 0 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + self.index = end_index + + return batch + + +class RayProcess: + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self.world_size = world_size + self.rank = rank + self.local_rank = local_rank + self.master_addr = master_addr if master_addr else self.get_current_node_ip() + self.master_port = master_port if master_port else self.get_free_port() + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + # os.environ["LOCAL_RANK"] = str(self._local_rank) + os.environ["LOCAL_RANK"] = "0" + random.seed(self.rank) + np.random.seed(self.rank) + torch.manual_seed(self.rank) + + @staticmethod + def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self.master_addr, self.master_port + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class PolicyTrainerRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): + self.args = args + self.num_gpu_per_node = num_gpus_per_node + self.num_nodes = num_nodes + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.policy.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + AdamOptimizer = FusedAdam + weight_decay = 0.0 + optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) + self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + # self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.policy, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + def get_vocab_size(self): + return self.policy.config.vocab_size + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + def train( + self, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizer, + value_model: ray.actor.ActorHandle, + ref_model: ray.actor.ActorHandle, + reward_model: ray.actor.ActorHandle, + vllm_engines: List[ray.actor.ActorHandle], + metrics_queue: RayQueue, + data_collator: Callable, + ): + torch.set_printoptions(precision=4, sci_mode=False) + + args = self.args + + accelerator = Namespace() + accelerator.process_index = self.rank + accelerator.num_processes = self.world_size + accelerator.is_main_process = self.rank == 0 + torch.distributed.barrier() + if self.rank == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + vllm_num_engines, vllm_tensor_parallel_size = ( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + backend = args.vllm_sync_backend + # https://github.com/OpenRLHF/OpenRLHF/issues/313 + if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": + backend = "gloo" + print( + "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" + ) + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + "openrlhf", + backend=backend, + ) + for i, engine in enumerate(vllm_engines) + ] + self.model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) + ray.get(refs) + torch.distributed.barrier() + + def broadcast_to_vllm(): + # avoid OOM + torch.cuda.empty_cache() + model = self.model.module + count, num_params = 0, len(list(model.named_parameters())) + refss = [] + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + # print(f"broadcasting {name=} {shape=}") + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 + with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + # ray.get(refs) + # print(f"broadcasting {name=} {shape=} success") + if torch.distributed.get_rank() == 0: + ray.get(refss) + + # broadcast_to_vllm() + print(f"broadcasted to vllm finished {self.rank=} {self.local_rank=}, {self.world_size=}") + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + if args.stop_token == "period": + args.stop_token_id = tokenizer.encode(".")[0] + # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + train_dataset_idxs = np.arange(len(train_dataset)) + shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) + print(f"2broadcasted to vllm finished {self.rank=} {self.local_rank=}, {self.world_size=}") + + # hack to left pad + def repeat_generator(): + while True: + batch_idxs = next(shuffling_iter) + yield [train_dataset[i] for i in batch_idxs] + + iter_dataloader = iter(repeat_generator()) + generation_config = SamplingParams( + temperature=args.temperature, + top_p=1.0, + max_tokens=args.response_length, + include_stop_str_in_output=True, + ) + print("setup async queues") + param_prompt_Q = None + response_ids_Q = None + evaluation_Q = None + response_ids_Q = Queue(maxsize=1) + param_prompt_Q = Queue(maxsize=1) + evaluation_Q = Queue(maxsize=1) + num_eval_samples = 32 + sample_evaluation_prompt_token_ids = None + if eval_dataset is not None: + sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] + + def vllm_generate( + generation_config: SamplingParams, + response_ids_Q: Queue, + param_prompt_Q: Queue, + num_training_steps: int, + sample_evaluation_prompt_token_ids: Optional[List[int]], + evaluation_Q: Queue, + eval_freq: int, + resume_training_step: int, + ): + llm = vllm_engines[0] + for training_step in range(resume_training_step, num_training_steps + 1): + items = param_prompt_Q.get() + if items is None: + break + unwrapped_model, g_queries_list = items + # if unwrapped_model is not None: + generation_start_time = time.time() + + outputs = ray.get( + llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list) + ) + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") + response_ids_Q.put(response_ids) + + if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: + outputs = ray.get( + llm.generate.remote( + prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config + ) + ) + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + evaluation_Q.put(response_ids) + + resume_training_step = 1 + if accelerator.is_main_process: + thread = threading.Thread( + target=vllm_generate, + args=( + generation_config, + response_ids_Q, + param_prompt_Q, + args.num_training_steps, + sample_evaluation_prompt_token_ids, + evaluation_Q, + args.eval_freq, + resume_training_step, + ), + ) + thread.start() + print("vllm generate thread starts") + + # set up the metrics and initial states + device = torch.device(self.local_rank) + g_vllm_responses = torch.zeros( + (args.rollout_batch_size, args.response_length), device=device, dtype=torch.long + ) + stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + local_metrics = torch.zeros((20,), device=device) + episode = args.rollout_batch_size * (resume_training_step - 1) + + # training loop + start_time = time.time() + global_data = next(iter_dataloader) + data = data_collator( + global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] + ) + global_queries = data_collator(global_data)[ + INPUT_IDS_PROMPT_KEY + ].tolist() # can be simplified since we `remove_padding` later anyway + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + + # for _ in range(1, resume_training_step): # we didn't store scheduler state + # scheduler.step() + + for training_step in range(resume_training_step, args.num_training_steps + 1): + episode += args.rollout_batch_size + queries = queries_next + + if accelerator.is_main_process: + df = None + try: + evaluation_responses = evaluation_Q.get(timeout=0.01) + print("🔥🔥🔥 Evaluation responses received") + table = {} + table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) + table["response"] = tokenizer.batch_decode(evaluation_responses) + table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] + df = pd.DataFrame(table) + del table + except Empty: + print("🙈 Evaluation responses not received") + + # (optionally) evaluate the model + if args.async_mode: + if training_step != 1: + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + else: + if training_step != 1: + # NOTE: important: the indent here is different for sync mode + # we also set to use `queries = queries_next` immediately + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + queries = queries_next + + # print('get reward stuff starts') + training_time_start = time.time() + with torch.no_grad(): + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + sequence_lengths = [] + values = [] + if accelerator.is_main_process: + g_response_token_ids = response_ids_Q.get() + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_token_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + print(f"{g_padded_response_ids.shape=}") + print(f"{g_vllm_responses.shape=}") + g_vllm_responses[:] = g_padded_response_ids + dist.broadcast(g_vllm_responses, src=0) + local_vllm_responses = g_vllm_responses[ + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] + ] + # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") + query_responses = torch.cat((queries, local_vllm_responses), 1) + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + print(f"get reward stuff starts {i=}") + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + # 1. launch ref model future + ref_logprob_future = ref_model.forward.remote( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + + # 2. launch reward model future + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + # print("get reward stuff starts 2") + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + reward_future = reward_model.forward.remote( + postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + + # print("get reward stuff starts 3") + # 3. launch value model future + value_future = value_model.forward.remote(query_response, tokenizer.pad_token_id, context_length) + + # 4. do local forward pass + logprob = self.forward( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + torch.cuda.empty_cache() + + # print("get reward stuff starts 4") + # 5. get results from futures + _, score, _ = ray.get(reward_future) + # print(f"{score.shape=}") + full_value, _, _ = ray.get(value_future) + # print(f"{full_value.shape=}") + ref_logprob = ray.get(ref_logprob_future) + # print(f"{ref_logprob.shape=}") + if args.colocate_critic_reward: + ray.get([value_model.empty_cache.remote()]) + ray.get([reward_model.empty_cache.remote()]) + if args.colocate_actor_ref: + ray.get([ref_model.empty_cache.remote()]) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + # print(f"get reward stuff starts 5") + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + values = torch.cat(values, 0) + # print(f"get reward stuff finished") + del (logprob, ref_logprob, full_value, value, score) + gc.collect() + torch.cuda.empty_cache() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) + # NOTE: only apply the stop token filter if the response is long enough + # otherwise the model could learn to generate the first token as the stop token + contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) + if args.non_stop_penalty: + scores = torch.where( + contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) + ) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + # print(f"get reward stuff finished 2") + + # 4. compute rewards + kl1 = logprobs - ref_logprobs + kl2 = (kl1) ** 2 / 2 + kl3 = (-kl1).exp() - 1 + kl1 + if args.kl_estimator == "kl1": + kl = kl1 + elif args.kl_estimator == "kl2": + kl = kl2 + elif args.kl_estimator == "kl3": + kl = kl3 + # if self.rank==0: + # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") + non_score_reward = -args.beta * kl + non_score_reward_sum = non_score_reward.sum(1) + rlhf_reward = scores + non_score_reward_sum + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + # print(f"get reward stuff finished 3") + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # print('gae') + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # print('training starts') + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch + for epoch_idx in range(args.num_epochs): + b_inds = np.random.permutation(args.local_rollout_batch_size) + minibatch_idx = 0 + for mini_batch_start in range(0, args.local_rollout_batch_size, args.local_mini_batch_size): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + print("micro batch start", micro_batch_start, self.rank) + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] + + value_model_step_future = value_model.step.remote( + mb_query_responses, + tokenizer.pad_token_id, + context_length, + mb_padding_mask_p1, + mb_return, + mb_values, + args.cliprange_value, + args.vf_coef, + ) + new_logprobs = self.forward( + mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature + ) + # if self.rank==0: + # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + self.model.backward(loss) + # print("backward loss", self.rank, "micro batch start", micro_batch_start) + # print("trying to step", self.rank, "micro batch start", micro_batch_start) + self.model.step() + # print("step", self.rank, "micro batch start", micro_batch_start) + with torch.no_grad(): + # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) + vf_loss, vf_clipfrac = ray.get(value_model_step_future) + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + # print("value model stepped", self.rank, "micro batch start", micro_batch_start) + # prob_dist = torch.nn.functional.softmax(logits, dim=-1) + # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # fmt: off + del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 + del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss + # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl + # fmt: on + # del everything and empty cache + torch.cuda.empty_cache() + del b_inds, mini_batch_inds + # print("start metrics") + with torch.no_grad(): + local_metrics[0] = sequence_lengths.float().mean() + local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() + local_metrics[2] = kl.sum(1).mean() + local_metrics[3] = (-logprobs).sum(1).mean() + local_metrics[4] = non_score_reward_sum.mean() + local_metrics[5] = rlhf_reward.mean() + local_metrics[6] = scores.mean() + local_metrics[7] = approxkl_stats.mean() + local_metrics[8] = pg_clipfrac_stats.mean() + local_metrics[9] = pg_loss_stats.mean() + local_metrics[10] = vf_loss_stats.mean() + local_metrics[11] = vf_clipfrac_stats.mean() + local_metrics[12] = entropy_stats.mean() + local_metrics[13] = ratio_stats.mean() + local_metrics[14] = ratio_stats.var() + local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() + local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + local_metrics /= dist.get_world_size() + dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) + global_metrics = local_metrics.tolist() + metrics = { + "episode": episode, + "training_step": training_step, + "lr": self.scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "time/from_scratch": time.time() - start_time, + "time/training": time.time() - training_time_start, + "val/sequence_lengths": global_metrics[0], + "val/num_stop_token_ids": global_metrics[1], + "objective/kl": global_metrics[2], + "objective/kl2": global_metrics[15], + "ojbective/kl3": global_metrics[16], + "objective/entropy": global_metrics[3], + "objective/non_score_reward": global_metrics[4], + "objective/rlhf_reward": global_metrics[5], + "objective/scores": global_metrics[6], + "policy/approxkl_avg": global_metrics[7], + "policy/clipfrac_avg": global_metrics[8], + "loss/policy_avg": global_metrics[9], + "loss/value_avg": global_metrics[10], + "val/clipfrac_avg": global_metrics[11], + "policy/entropy_avg": global_metrics[12], + "val/ratio": global_metrics[13], + "val/ratio_var": global_metrics[14], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + metrics_queue.put((metrics, episode, df)) + del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) + del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) + gc.collect() + torch.cuda.empty_cache() + print(f"finished training {training_step}") + print("finished training") + + def save_model(self, tokenizer: PreTrainedTokenizer, output_dir: str) -> None: + if self.rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self.model + if hasattr(model_to_save, "module"): + model_to_save = model_to_save.module + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.rank == 0: + output_state_dict[k] = vv + + if self.rank == 0: + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + # if isinstance(model_to_save, PeftModel): + # model_to_save.save_pretrained(output_dir, **kwargs) + # if self.stage == 3: + # torch.save( + # get_peft_model_state_dict(model_to_save, output_state_dict), + # os.path.join(output_dir, "adapter_model.bin"), + # ) + # else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) + + # save tokenizer + tokenizer.save_pretrained(output_dir) + + +@ray.remote(num_gpus=1) +class ReferenceModelRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): + self.args = args + self.num_gpu_per_node = num_gpus_per_node + self.num_nodes = num_nodes + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.model, *_ = deepspeed.initialize(model=self.policy, config=ds_config, dist_init_required=True) + self.model.eval() + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + +@ray.remote(num_gpus=1) +class ValueTrainerRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): + self.args = args + self.num_gpu_per_node = num_gpus_per_node + self.num_nodes = num_nodes + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.value_model) + self.value_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + AdamOptimizer = FusedAdam + weight_decay = 0.0 + optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) + self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + # self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + # print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.value_model, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + def forward( + self, query_responses: torch.Tensor, pad_token_id: int, context_length: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return get_reward(self.value_model, query_responses, pad_token_id, context_length) + + def step( + self, + query_responses: torch.Tensor, + pad_token_id: int, + context_length: int, + mb_padding_mask_p1: torch.Tensor, + mb_return: torch.Tensor, + mb_values: torch.Tensor, + cliprange_value: float, + vf_coef: float, + ) -> None: + torch.cuda.empty_cache() + vpred_temp = self.forward(query_responses, pad_token_id, context_length) + vpred_temp = vpred_temp[0] + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) + vpredclipped = torch.clamp( + vpred, + mb_values - cliprange_value, + mb_values + cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) + self.model.backward(vf_loss * vf_coef) + self.model.step() + with torch.no_grad(): + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) + del (vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max) + return vf_loss, vf_clipfrac + + +@ray.remote(num_gpus=1) +class RewardModelRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): + deepspeed.init_distributed() + self.num_gpu_per_node = num_gpus_per_node + self.num_nodes = num_nodes + self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.reward_model) + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + self.model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config, dist_init_required=True) + self.model.eval() + + def forward( + self, query_responses: torch.Tensor, pad_token_id: int, context_length: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return get_reward(self.reward_model, query_responses, pad_token_id, context_length) + + def get_vocab_size(self): + return self.reward_model.config.vocab_size + + +def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): + while True: + if stop_event.is_set(): + break + for ref in object_refs: + try: + ray.get(ref, timeout=0.01) + except ray.exceptions.GetTimeoutError: + pass + except ray.exceptions.ActorDiedError as e: + ray.shutdown() + print(f"Actor {ref} died") + print(e) + os._exit(1) # Force shutdown the process + + time.sleep(30) + + +class ModelGroup: + def __init__( + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_actor: int, + num_gpus_per_node: int, + num_nodes: int, + ): + self.pg = pg + self.ray_process_cls = ray_process_cls + self.num_gpus_per_actor = num_gpus_per_actor + self.num_gpus_per_node = num_gpus_per_node + self.num_nodes = num_nodes + self.models = [] + + world_size = num_gpus_per_node * num_nodes + if self.num_gpus_per_actor > 1 and self.pg is None: + bundles = [{"GPU": self.num_gpus_per_actor, "CPU": self.num_gpus_per_actor} for _ in range(self.num_nodes)] + + self.pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(self.pg.ready()) + if self.pg: + master_policy = ray_process_cls.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=self.pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, 0, None, None) + else: + master_policy = ray_process_cls.options( + num_cpus=num_gpus_per_actor, + num_gpus=num_gpus_per_actor, + ).remote(world_size, 0, 0, None, None) + + self.models.append(master_policy) + master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) + + # Setup worker models + for rank in range(1, world_size): + print(f"{rank=}, {world_size, rank, 0, master_addr, master_port=}") + scheduling_strategy = None + if pg is not None: + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.pg, + placement_group_bundle_index=rank // self.num_gpus_per_node, + ) + worker_policy = ray_process_cls.options( + num_cpus=self.num_gpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=scheduling_strategy, + ).remote(world_size, rank, 0, master_addr, master_port) + self.models.append(worker_policy) + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + calculate_runtime_args(args, model_config) + + # set up experiment tracking and seeds + all_configs = {} + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) + beaker_config = maybe_get_beaker_config() + # try saving to the beaker `/output`, which will be uploaded to the beaker dataset + if len(beaker_config.beaker_dataset_id_urls) > 0: + args.output_dir = "/output" + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # create a tokenizer (pad from right) + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + + # create the dataset + dataset_dict = DatasetDict() + dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[dataset_config.sft_messages_key], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[dataset_config.sft_messages_key], + ) + eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset) + dataset_dict["eval"] = eval_dataset + data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + + # some more runtime logging + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + pg = None + if args.colocate_actor_ref: + assert ( + args.actor_num_nodes == args.ref_num_nodes and args.actor_num_gpus_per_node == args.ref_num_gpus_per_node + ), "num_nodes and num_gpus_per_node must be the same when colocate actor and ref model." + + bundles = [ + {"GPU": args.actor_num_gpus_per_node, "CPU": args.actor_num_gpus_per_node} + for _ in range(args.actor_num_nodes) + ] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + inits = [] + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + 0.75 if args.colocate_actor_ref else 1, + args.actor_num_gpus_per_node, + args.actor_num_nodes, + ) + inits.extend( + model.from_pretrained.remote(args, model_config, args.actor_num_gpus_per_node, args.actor_num_nodes) + for model in policy_group.models + ) + ref_model_group = ModelGroup( + pg, + ReferenceModelRayProcess, + 0.25 if args.colocate_actor_ref else 1, + args.ref_num_gpus_per_node, + args.ref_num_nodes, + ) + inits.extend( + model.from_pretrained.remote(args, model_config, args.ref_num_gpus_per_node, args.ref_num_nodes) + for model in ref_model_group.models + ) + + # if colocated, create placement group for critic and reward model explicitly. + pg = None + if args.colocate_critic_reward: + assert ( + args.critic_num_nodes == args.reward_num_nodes + and args.critic_num_gpus_per_node == args.reward_num_gpus_per_node + ), "num_nodes and num_gpus_per_node must be the same when colocate critic and reward model." + + bundles = [ + {"GPU": args.critic_num_gpus_per_node, "CPU": args.critic_num_gpus_per_node} + for _ in range(args.critic_num_nodes) + ] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + value_model_group = ModelGroup( + pg, + ValueTrainerRayProcess, + 0.75 if args.colocate_critic_reward else 1, + args.critic_num_gpus_per_node, + args.critic_num_nodes, + ) + inits.extend( + model.from_pretrained.remote(args, model_config, args.critic_num_gpus_per_node, args.critic_num_nodes) + for model in value_model_group.models + ) + reward_model_group = ModelGroup( + pg, + RewardModelRayProcess, + 0.25 if args.colocate_critic_reward else 1, + args.reward_num_gpus_per_node, + args.reward_num_nodes, + ) + inits.extend( + model.from_pretrained.remote(args, model_config, args.reward_num_gpus_per_node, args.reward_num_nodes) + for model in reward_model_group.models + ) + + max_len = dataset_config.max_prompt_token_length + args.response_length + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + model_config.model_name_or_path, + model_config.model_revision, + args.seed, + args.enable_prefix_caching, + max_len, + ) + + metrics_queue = RayQueue() + ray.get(inits) + print("======== all models initialized =========") + policy_vocab_size = ray.get(policy_group.models[0].get_vocab_size.remote()) + reward_vocab_size = ray.get(reward_model_group.models[0].get_vocab_size.remote()) + print(f"{policy_vocab_size=}, {reward_vocab_size=}") + if policy_vocab_size != reward_vocab_size: + ray.shutdown() # shutdown here so this error message is not buried in the logs + raise ValueError( + "Policy and reward model must have the same vocab size. " + f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " + "If they don't have the same vocab size, the policy could generate tokens which " + "is going to cause index out of bound error in the reward model." + ) + + refs = [] + for i, policy_model in enumerate(policy_group.models): + value_model = value_model_group.models[i % len(value_model_group.models)] + ref_model = ref_model_group.models[i % len(ref_model_group.models)] + reward_model = reward_model_group.models[i % len(reward_model_group.models)] + print(f"{value_model=}, {i=}, {len(value_model_group.models)=}") + refs.append( + policy_model.train.remote( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + value_model=value_model, + ref_model=ref_model, + reward_model=reward_model, + vllm_engines=vllm_engines, + metrics_queue=metrics_queue, + data_collator=data_collator, + ) + ) + + # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang + # so we need kill the ray cluster when this happens. + stop_event = threading.Event() + threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() + + # train and gather metrics + resume_training_step = 1 + for training_step in range(resume_training_step, args.num_training_steps + 1): + result = metrics_queue.get() + metrics, episode, df = result + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + if df is not None: + if args.with_tracking: + wandb.log({"sample_completions": wandb.Table(dataframe=df)}) + # else: + # print_rich_table(df) + ray.get(refs) + + # save model + original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + ray.get( + [policy_model.save_model.remote(original_tokenizer, args.output_dir) for policy_model in policy_group.models] + ) + ray.shutdown() + stop_event.set() + + # Ai2 specific logic + if is_beaker_job(): + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + accelerator = Namespace() + accelerator.is_main_process = True # hack + if args.push_to_hub: + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + if accelerator.is_main_process: + # remove args.checkpoint_output_dir + if os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py new file mode 100644 index 000000000..7dbdf6a54 --- /dev/null +++ b/open_instruct/vllm_utils2.py @@ -0,0 +1,240 @@ +# Taken and modified from https://github.com/huggingface/trl +# Copyright 2024 The AllenAI Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This file is copied from https://github.com/OpenRLHF/OpenRLHF""" + + +from datetime import timedelta +from typing import Any, Optional, Union + +import ray +import torch +import torch.distributed +from ray.util.placement_group import placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from torch.distributed.distributed_c10d import ( + Backend, + PrefixStore, + Store, + _new_process_group_helper, + _world, + default_pg_timeout, + rendezvous, +) +from vllm.worker.worker import Worker + + +# Copy from pytorch to allow creating multiple main groups. +# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py +def init_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or (init_method is None), "Cannot specify both init_method and store." + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + if backend: + backend = Backend(backend) + else: + backend = Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + # backward compatible API + if store is None: + rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + + # Use a PrefixStore to avoid accidental overrides of keys used by + # different systems (e.g. RPC) in case the store is multi-tenant. + store = PrefixStore(group_name, store) + + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + pg_options=pg_options, + timeout=timeout, + ) + + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + + return pg + + +class WorkerWrap(Worker): + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl"): + """Init torch process group for model weights update""" + assert torch.distributed.is_initialized(), "default torch process group must be initialized" + assert group_name != "", "group name must not be empty" + + rank = torch.distributed.get_rank() + rank_offset + self._model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=rank, + group_name=group_name, + ) + print( + f"init_process_group: master_address={master_address}, master_port={master_port}, ", + f"rank={rank}, world_size={world_size}, group_name={group_name}", + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + """Broadcast weight to all vllm workers from source rank 0 (actor model)""" + # print(f"update_weight: {name}, dtype: {dtype}, shape: {shape}, rank: {torch.distributed.get_rank()}, world_size: {torch.distributed.get_world_size()}") + # if torch.distributed.get_rank() == 0: + # print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") + + assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" + weight = torch.empty(shape, dtype=dtype, device="cuda") + torch.distributed.broadcast(weight, 0, group=self._model_update_group) + + self.model_runner.model.load_weights(weights=[(name, weight)]) + + del weight + # TODO: should we empty cache if all weights have updated? + # if empty_cache: + # torch.cuda.empty_cache() + + +@ray.remote +class LLMRayActor: + def __init__(self, *args, **kwargs): + import vllm + + self.__version__ = vllm.__version__ + assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1" + + self.use_gpu_executor = kwargs["tensor_parallel_size"] == 1 + + # See https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py + if self.use_gpu_executor: + + vllm.worker.worker.Worker = WorkerWrap + else: + # RayGPUExecutor + # See the patch https://github.com/vllm-project/vllm/commit/479d69fad0538f04cb22bf13e76ff91cfeb8a4e5 + kwargs["worker_use_ray"] = True + + if vllm.__version__ > "0.4.1": + RayWorkerWrapperPath = vllm.executor.ray_utils + else: + RayWorkerWrapperPath = vllm.engine.ray_utils + + class RayWorkerWrapper(RayWorkerWrapperPath.RayWorkerWrapper): + def __init__(self, *args, **kwargs) -> None: + kwargs["worker_module_name"] = "open_instruct.vllm_utils2" + kwargs["worker_class_name"] = "WorkerWrap" + super().__init__(*args, **kwargs) + + RayWorkerWrapperPath.RayWorkerWrapper = RayWorkerWrapper + + self.llm = vllm.LLM(*args, **kwargs) + + def generate(self, *args, **kwargs): + return self.llm.generate(*args, **kwargs) + + def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend): + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.init_process_group( + master_address, master_port, rank_offset, world_size, group_name, backend + ) + else: + return self.llm.llm_engine.model_executor._run_workers( + "init_process_group", master_address, master_port, rank_offset, world_size, group_name, backend + ) + + def update_weight(self, name, dtype, shape, empty_cache=False): + self.stop_remote_worker_execution_loop() + + if self.use_gpu_executor: + return self.llm.llm_engine.model_executor.driver_worker.update_weight(name, dtype, shape, empty_cache) + else: + return self.llm.llm_engine.model_executor._run_workers("update_weight", name, dtype, shape, empty_cache) + + def stop_remote_worker_execution_loop(self): + # Fix error for using 2 communication group + # https://github.com/vllm-project/vllm/commit/eb6d3c264d0cd8e44dec16bca7947fbe96415ce9#diff-e1ad69e38e033accddfa5480ec808c4740eb39244d1ef51cc3407e20dde8cfd4 + if self.__version__ > "0.4.2": + self.llm.llm_engine.model_executor.stop_remote_worker_execution_loop() + + +def create_vllm_engines( + num_engines: int, + tensor_parallel_size: int, + pretrain: str, + revision: str, + seed: int, + enable_prefix_caching: bool, + max_model_len: int, +): + vllm_engines = [] + for i in range(num_engines): + # When tensor_parallel_size=1, vLLM init model in LLMEngine directly, assign 1 GPU for it. + num_gpus = int(tensor_parallel_size == 1) + scheduling_strategy = None + + if tensor_parallel_size > 1: + bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size + pg = placement_group(bundles) + ray.get(pg.ready()) + + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0 + ) + print(f"vllm: {num_gpus=}, {num_engines=}") + vllm_engines.append( + LLMRayActor.options( + num_cpus=1, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + ).remote( + pretrain, + revision=revision, + tokenizer_revision=revision, + trust_remote_code=True, + tensor_parallel_size=tensor_parallel_size, + dtype="bfloat16", + seed=seed + i, + enable_prefix_caching=enable_prefix_caching, + max_model_len=max_model_len, + ) + ) + + return vllm_engines + + +if __name__ == "__main__": + llm = LLMRayActor.remote("meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2) + output = ray.get(llm.generate.remote("San Franciso is a")) + print(f"output: {output}") From f6a2b75177fe2af4f9bafacdb4f32dca3a8c5001 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 30 Oct 2024 15:52:08 -0700 Subject: [PATCH 38/53] add weka save override --- open_instruct/ppo_vllm_thread_ray_gtrl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index d4a3d63c9..38a6ab3d8 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -275,6 +275,8 @@ class Args: """Where to save the model""" checkpoint_output_dir: Optional[str] = None """Where to save the model checkpoints in case of preemption""" + overwrite_beaker_output_dir: Optional[str] = None + """Where to save in a beaker job, if not just /output. Useful with weka.""" # Ai2 specific settings try_launch_beaker_eval_jobs: bool = True @@ -1466,6 +1468,9 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # try saving to the beaker `/output`, which will be uploaded to the beaker dataset if len(beaker_config.beaker_dataset_id_urls) > 0: args.output_dir = "/output" + # if the user has asked to save to a specific directory, use that instead + if args.overwrite_beaker_output_dir is not None: + args.output_dir = args.overwrite_beaker_output_dir all_configs.update(vars(beaker_config)) all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) if args.with_tracking: From b59659a79e1c178c8d43d801cec351fa5e32df2f Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 30 Oct 2024 16:35:53 -0700 Subject: [PATCH 39/53] add multinode ray file --- configs/beaker_configs/ray_node_setup.sh | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100755 configs/beaker_configs/ray_node_setup.sh diff --git a/configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/ray_node_setup.sh new file mode 100755 index 000000000..fe075749c --- /dev/null +++ b/configs/beaker_configs/ray_node_setup.sh @@ -0,0 +1,21 @@ +export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))") +export PYTHONPATH=$REPO_PATH +export PATH="/root/.local/bin:$PATH" + + +echo CURRENT_DATETIME=$CURRENT_DATETIME +echo PYTHONPATH=$PYTHONPATH +echo PATH=$PATH + +# python3 -c "import os, ray; print(os.path.dirname(ray.__file__))" + +RAY_NODE_PORT=8888 +ray stop --force + +if [ "$BEAKER_REPLICA_RANK" == "0" ]; then + echo "Starting Ray head node" + ray start --head --port=$RAY_NODE_PORT +else + echo "Starting Ray worker node $BEAKER_REPLICA_RANK" + ray start --address="${BEAKER_LEADER_REPLICA_HOSTNAME}:${RAY_NODE_PORT}" --block +fi From 36a2ed474ff17fe8d72474a3a6be7b4e4228b616 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 30 Oct 2024 17:28:04 -0700 Subject: [PATCH 40/53] lint and fix --- open_instruct/ground_truth_utils.py | 7 +- open_instruct/if_functions.py | 682 ++++++++++++++-------------- open_instruct/math_utils.py | 97 +++- open_instruct/model_utils.py | 2 +- 4 files changed, 441 insertions(+), 347 deletions(-) diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 3d1cff51f..2340900c0 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -3,10 +3,10 @@ Used to give feedback to the model based on the ground truth answer. ''' import re -import json from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv from open_instruct.if_functions import IF_FUNCTIONS_MAP + def verify_gsm8k_sample(model_output, ground_truth_answer): # gsm is easy: extract numbers, and then just compare last number with answer. # matches how we do eval. @@ -99,7 +99,7 @@ def verify_ifeval_sample(model_output, constraint): func = IF_FUNCTIONS_MAP[func_name] # now, run the function # pop out any none args - non_none_args = {k:v for k,v in constraint.items() if v is not None} + non_none_args = {k: v for k, v in constraint.items() if v is not None} # sometimes we have extra args, sometimes not. if len(constraint) == 0: return func(model_output) @@ -110,6 +110,7 @@ def verify_flan_sample(model_output, ground_truth_answer): # TODO: flan. we could do BLEU/ROUGE.... or maybe something like BertScore? pass + # debug code if __name__ == "__main__": from datasets import load_dataset @@ -117,4 +118,4 @@ def verify_flan_sample(model_output, ground_truth_answer): test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$" for sample in ds['train']: print(sample) - verify_ifeval_sample(test_model_output, sample['ground_truth']) \ No newline at end of file + verify_ifeval_sample(test_model_output, sample['ground_truth']) diff --git a/open_instruct/if_functions.py b/open_instruct/if_functions.py index fa21c55ce..367bdfda9 100644 --- a/open_instruct/if_functions.py +++ b/open_instruct/if_functions.py @@ -1,6 +1,5 @@ import re import json -import langdetect from typing import List """ @@ -9,507 +8,506 @@ """ - # include keywords: Include keywords {keyword1}, {keyword2} in your response - def verify_keywords(text, keyword_list): - """ - Verify if the response contains all the specified keywords. + """ + Verify if the response contains all the specified keywords. - Args: - response (str): The response text to check - keyword_list (list): A list of keywords to check for + Args: + response (str): The response text to check + keyword_list (list): A list of keywords to check for - Returns: - bool: True if all keywords are present in the response, False otherwise - """ - # Convert response to lowercase for case-insensitive matching - response_lower = text.lower() + Returns: + bool: True if all keywords are present in the response, False otherwise + """ + # Convert response to lowercase for case-insensitive matching + response_lower = text.lower() - # Check if all keywords are present in the response - return all(keyword.lower() in response_lower for keyword in keyword_list) + # Check if all keywords are present in the response + return all(keyword.lower() in response_lower for keyword in keyword_list) # Keyword Frequency: In your response, the word {word} should appear {N} times. def verify_keyword_frequency(text, word, N): - """ - Verifies if a keyword appears exactly N times in the given text. + """ + Verifies if a keyword appears exactly N times in the given text. - Args: - text (str): The text to analyze - keyword (str): The keyword to count - expected_count (int): The expected number of occurrences + Args: + text (str): The text to analyze + keyword (str): The keyword to count + expected_count (int): The expected number of occurrences - Returns: - tuple: (bool, int) - (Whether constraint is met, actual count found) - """ - # Convert text to lowercase to make the search case-insensitive - text = text.lower() - keyword = word.lower() + Returns: + tuple: (bool, int) - (Whether constraint is met, actual count found) + """ + # Convert text to lowercase to make the search case-insensitive + text = text.lower() + keyword = word.lower() - # Split text into words and remove punctuation - import re - words = re.findall(r'\b\w+\b', text) + # Split text into words and remove punctuation + import re + words = re.findall(r'\b\w+\b', text) - # Count actual occurrences - actual_count = sum(1 for word in words if word == keyword) + # Count actual occurrences + actual_count = sum(1 for word in words if word == keyword) - # Check if constraint is met - constraint_met = actual_count == N + # Check if constraint is met + constraint_met = actual_count == N - return constraint_met + return constraint_met # Forbidden Words: Do not include keywords {forbidden words} in the response. def validate_forbidden_words(text, forbidden_words): - """ - Validates that the text does not contain any of the specified forbidden words. + """ + Validates that the text does not contain any of the specified forbidden words. - Args: - text (str): The text to check for forbidden words - forbidden_words (list[str]): A list of forbidden words + Args: + text (str): The text to check for forbidden words + forbidden_words (list[str]): A list of forbidden words - Returns: - tuple[bool, list[str]]: A tuple containing: - - Boolean indicating if any forbidden words are present - - List of forbidden words found in the text + Returns: + tuple[bool, list[str]]: A tuple containing: + - Boolean indicating if any forbidden words are present + - List of forbidden words found in the text - Example: - text = "This is a message that should not contain any bad words" - forbidden_words = ["bad", "evil", "harmful"] - result = validate_forbidden_words(text, forbidden_words) - """ - # Convert text to lowercase for case-insensitive matching - text_lower = text.lower() + Example: + text = "This is a message that should not contain any bad words" + forbidden_words = ["bad", "evil", "harmful"] + result = validate_forbidden_words(text, forbidden_words) + """ + # Convert text to lowercase for case-insensitive matching + text_lower = text.lower() - # Check each forbidden word - found_words = [word for word in forbidden_words if word.lower() in text_lower] + # Check each forbidden word + found_words = [word for word in forbidden_words if word.lower() in text_lower] - # Return results - return len(found_words) == 0 + # Return results + return len(found_words) == 0 # Letter Frequency : In your response, the letter {letter} should appear {N} times. def verify_letter_frequency(text: str, letter: str, N: int) -> bool: - """ - Verifies if a given letter appears exactly the specified number of times in the text. + """ + Verifies if a given letter appears exactly the specified number of times in the text. - Args: - text (str): The text to check - letter (str): The letter to count (case-sensitive) - target_count (int): The expected number of occurrences + Args: + text (str): The text to check + letter (str): The letter to count (case-sensitive) + target_count (int): The expected number of occurrences - Returns: - bool: True if the constraint is met, False otherwise + Returns: + bool: True if the constraint is met, False otherwise - Example: - >>> verify_letter_frequency("hello world", "l", 3) - True - >>> verify_letter_frequency("hello world", "o", 2) - True - >>> verify_letter_frequency("hello world", "x", 0) - True - """ - if len(letter) != 1: - raise ValueError("Letter parameter must be a single character") + Example: + >>> verify_letter_frequency("hello world", "l", 3) + True + >>> verify_letter_frequency("hello world", "o", 2) + True + >>> verify_letter_frequency("hello world", "x", 0) + True + """ + if len(letter) != 1: + raise ValueError("Letter parameter must be a single character") - actual_count = text.count(letter) - return actual_count == N + actual_count = text.count(letter) + return actual_count == N # Response Language: Your ENTIRE response should be in {language}, no other language is allowed. def validate_response_language(text, language): - """ - Validates that the entire response is in the specified language. + """ + Validates that the entire response is in the specified language. - Args: - text (str): The text to check - language (str): The language code (e.g., 'en' for English) + Args: + text (str): The text to check + language (str): The language code (e.g., 'en' for English) - Returns: - bool: True if the response is entirely in the specified language, False otherwise + Returns: + bool: True if the response is entirely in the specified language, False otherwise - Example: - text = "This is an English sentence" - language = "en" - result = validate_response_language(text, language) - """ - from langdetect import detect + Example: + text = "This is an English sentence" + language = "en" + result = validate_response_language(text, language) + """ + from langdetect import detect - # Detect the language of the text - detected_language = detect(text) - # Check if the detected language matches the expected language - return detected_language == language + # Detect the language of the text + detected_language = detect(text) + # Check if the detected language matches the expected language + return detected_language == language # Number Paragraphs: Your response should contain {N} paragraphs. You separate paragraphs using the markdown divider: # * * * def verify_paragraph_count(text: str, N: int) -> bool: - """ - Verifies that a text contains the expected number of paragraphs, - where paragraphs are separated by markdown dividers '* * *' + """ + Verifies that a text contains the expected number of paragraphs, + where paragraphs are separated by markdown dividers '* * *' - Args: - text (str): The text to analyze - expected_count (int): Expected number of paragraphs + Args: + text (str): The text to analyze + expected_count (int): Expected number of paragraphs - Returns: - bool: True if the text contains exactly the expected number of paragraphs, - False otherwise + Returns: + bool: True if the text contains exactly the expected number of paragraphs, + False otherwise - Example: - text = "First paragraph\n* * *\nSecond paragraph" - verify_paragraph_count(text, 2) - True - """ - def clean_text(text: str) -> str: - """Remove extra whitespace and normalize line endings""" - return '\n'.join(line.strip() for line in text.splitlines()).strip() + Example: + text = "First paragraph\n* * *\nSecond paragraph" + verify_paragraph_count(text, 2) + True + """ + def clean_text(text: str) -> str: + """Remove extra whitespace and normalize line endings""" + return '\n'.join(line.strip() for line in text.splitlines()).strip() - # Clean the input text - text = clean_text(text) + # Clean the input text + text = clean_text(text) - # Split text by markdown divider - # Add 1 to count since n dividers create n+1 paragraphs - paragraphs = text.split('* * *') - actual_count = len(paragraphs) + # Split text by markdown divider + # Add 1 to count since n dividers create n+1 paragraphs + paragraphs = text.split('* * *') + actual_count = len(paragraphs) - # Verify each split resulted in non-empty content - valid_paragraphs = [p.strip() for p in paragraphs if p.strip()] - if len(valid_paragraphs) != actual_count: - return False + # Verify each split resulted in non-empty content + valid_paragraphs = [p.strip() for p in paragraphs if p.strip()] + if len(valid_paragraphs) != actual_count: + return False - return actual_count == N + return actual_count == N # Number Words: Answer with at least / around / at most {N} words def validate_word_constraint(text: str, N: int, quantifier: str) -> bool: - """ - Validates if a text meets specified word count constraints. + """ + Validates if a text meets specified word count constraints. - Args: - text (str): The text to check - count (int): The target word count - qualifier (str): The type of constraint ('at least', 'around', 'at most') + Args: + text (str): The text to check + count (int): The target word count + qualifier (str): The type of constraint ('at least', 'around', 'at most') - Returns: - bool: True if the constraint is met, False otherwise + Returns: + bool: True if the constraint is met, False otherwise - Raises: - ValueError: If an invalid qualifier is provided - """ - # Remove extra whitespace and split into words - words = text.strip().split() - actual_count = len(words) + Raises: + ValueError: If an invalid qualifier is provided + """ + # Remove extra whitespace and split into words + words = text.strip().split() + actual_count = len(words) - # Define tolerance for "around" qualifier (±10% of target count) - tolerance = max(round(N * 0.1), 1) + # Define tolerance for "around" qualifier (±10% of target count) + tolerance = max(round(N * 0.1), 1) - if quantifier == "at least": - return actual_count >= N - elif quantifier == "at most": - return actual_count <= N - elif quantifier == "around": - return abs(actual_count - N) <= tolerance - else: - return False + if quantifier == "at least": + return actual_count >= N + elif quantifier == "at most": + return actual_count <= N + elif quantifier == "around": + return abs(actual_count - N) <= tolerance + else: + return False # Number Sentences: Answer with at least / around / at most {N} sentences. def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool: - """ - Verifies if a text contains the expected number of sentences. - - Args: - text (str): The text to analyze - N (int): The expected number of sentences - quantifier (str): The quantifier ('at least', 'around', 'at most') - - Returns: - bool: True if the text contains the expected number of sentences, False otherwise - """ - # Split the text into sentences - sentences = re.split(r'(?= N - elif quantifier == 'around': - return abs(actual_count - N) <= 1 - elif quantifier == 'at most': - return actual_count <= N - else: - return False + """ + Verifies if a text contains the expected number of sentences. + + Args: + text (str): The text to analyze + N (int): The expected number of sentences + quantifier (str): The quantifier ('at least', 'around', 'at most') + + Returns: + bool: True if the text contains the expected number of sentences, False otherwise + """ + # Split the text into sentences + sentences = re.split(r'(?= N + elif quantifier == 'around': + return abs(actual_count - N) <= 1 + elif quantifier == 'at most': + return actual_count <= N + else: + return False # Number Paragraphs + First Word in i-th Paragraph: There should be {N} paragraphs. Paragraphs and only paragraphs # are separated with each other by two line breaks. The {i}-th paragraph must start with word {first word}. def validate_paragraphs(text, N, first_word, i): - """ - Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific - word. + """ + Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific + word. - Args: - text (str): The text to analyze - N (int): The expected number of paragraphs - first_word (str): The expected first word of the i-th paragraph - i (int): The index of the paragraph to check (1-indexed) + Args: + text (str): The text to analyze + N (int): The expected number of paragraphs + first_word (str): The expected first word of the i-th paragraph + i (int): The index of the paragraph to check (1-indexed) - Returns: - bool: True if the text meets the paragraph and first word requirements, False otherwise - """ - # Split the text into paragraphs - paragraphs = text.split('\n\n') + Returns: + bool: True if the text meets the paragraph and first word requirements, False otherwise + """ + # Split the text into paragraphs + paragraphs = text.split('\n\n') - # Check if the number of paragraphs is as expected - if len(paragraphs) != N: - return False + # Check if the number of paragraphs is as expected + if len(paragraphs) != N: + return False - # Check if the i-th paragraph starts with the specified first word - if paragraphs[i - 1].strip().startswith(first_word): - return True - return False + # Check if the i-th paragraph starts with the specified first word + if paragraphs[i - 1].strip().startswith(first_word): + return True + return False # Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker} def verify_postscript(text, postscript_marker): - """ - Verifies if a text contains a postscript starting with '{postscript marker}' - - Args: - text (str): The text to verify - - Returns: - bool: True if the text contains a valid postscript, False otherwise - """ - # Check if the text contains the postscript marker - if postscript_marker in text: - # Get the index of the marker - marker_index = text.find(postscript_marker) - # Check if the marker appears near the end - remaining_text = text[marker_index:].strip() - # Verify it's not just the marker alone - return len(remaining_text) > len(postscript_marker) - return False + """ + Verifies if a text contains a postscript starting with '{postscript marker}' + + Args: + text (str): The text to verify + + Returns: + bool: True if the text contains a valid postscript, False otherwise + """ + # Check if the text contains the postscript marker + if postscript_marker in text: + # Get the index of the marker + marker_index = text.find(postscript_marker) + # Check if the marker appears near the end + remaining_text = text[marker_index:].strip() + # Verify it's not just the marker alone + return len(remaining_text) > len(postscript_marker) + return False # Number Placeholder: The response must contain at least {N} placeholders represented by square brackets, # such as [address]. def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]: - """ - Validates if a text contains at least the specified number of placeholders in square brackets. + """ + Validates if a text contains at least the specified number of placeholders in square brackets. - Args: - text (str): The text to check for placeholders - min_placeholders (int): Minimum number of placeholders required + Args: + text (str): The text to check for placeholders + min_placeholders (int): Minimum number of placeholders required - Returns: - tuple[bool, List[str]]: A tuple containing: - - Boolean indicating if the text meets the placeholder requirement - - List of found placeholders + Returns: + tuple[bool, List[str]]: A tuple containing: + - Boolean indicating if the text meets the placeholder requirement + - List of found placeholders - Example: - >>> text = "Hello [name], your [item] will be delivered to [address]" - >>> validate_placeholders(text, 2) - (True, ['name', 'item', 'address']) - """ - # Find all placeholders using regex - pattern = r'\[(.*?)\]' - placeholders = re.findall(pattern, text) + Example: + >>> text = "Hello [name], your [item] will be delivered to [address]" + >>> validate_placeholders(text, 2) + (True, ['name', 'item', 'address']) + """ + # Find all placeholders using regex + pattern = r'\[(.*?)\]' + placeholders = re.findall(pattern, text) - # Check if the number of placeholders meets the requirement - has_enough = len(placeholders) >= N + # Check if the number of placeholders meets the requirement + has_enough = len(placeholders) >= N - return has_enough, placeholders + return has_enough, placeholders # Number Bullets: Your answer must contain exactly {N} bullet points. Use the markdown bullet points such as: * This # is a point. def verify_bullet_points(text: str, N: int) -> tuple[bool, str]: - """ - Verifies if a text contains exactly N bullet points in markdown format. - Returns a tuple of (is_valid, message). + """ + Verifies if a text contains exactly N bullet points in markdown format. + Returns a tuple of (is_valid, message). - Args: - text (str): The text to check - expected_count (int): The expected number of bullet points + Args: + text (str): The text to check + expected_count (int): The expected number of bullet points - Returns: - tuple[bool, str]: (True if constraint is met, explanation message) - """ - # Split text into lines and count lines starting with * or - - lines = text.split('\n') - bullet_points = [line.strip() for line in lines if line.strip().startswith(('*', '-'))] - actual_count = len(bullet_points) + Returns: + tuple[bool, str]: (True if constraint is met, explanation message) + """ + # Split text into lines and count lines starting with * or - + lines = text.split('\n') + bullet_points = [line.strip() for line in lines if line.strip().startswith(('*', '-'))] + actual_count = len(bullet_points) - if actual_count == N: - return True - else: - return False + if actual_count == N: + return True + else: + return False # Title: Your answer must contain a title, wrapped in double angular brackets, such as <>. def validate_title(text: str) -> bool: - pattern = r"<<(.*?)>>" - matches = re.findall(pattern, text) + pattern = r"<<(.*?)>>" + matches = re.findall(pattern, text) - if len(matches) > 0: - return True - else: - return False + if len(matches) > 0: + return True + else: + return False # Choose: From Answer with one of the following options: {options} def validate_choice(text: str, options: list) -> bool: - for option in options: - if text in option: - return True - return False + for option in options: + if text in option: + return True + return False # Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with markdown, i.e. *highlighted # section* def validate_highlighted_sections(text: str, N: int) -> bool: - pattern = r"\*(.*?)\*" - matches = re.findall(pattern, text) + pattern = r"\*(.*?)\*" + matches = re.findall(pattern, text) - if len(matches) >= N: - return True - else: - return False + if len(matches) >= N: + return True + else: + return False # Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with {section splitter} X. def validate_sections(text: str, N: int, section_splitter: str) -> bool: - sections = text.split(section_splitter) - # The first section might not start with the splitter, so we adjust for this - if sections[0] == '': - sections.pop(0) - if len(sections) == N: - return True - else: - return False + sections = text.split(section_splitter) + # The first section might not start with the splitter, so we adjust for this + if sections[0] == '': + sections.pop(0) + if len(sections) == N: + return True + else: + return False # JSON Format : Entire output should be wrapped in JSON format. def validate_json_format(text: str) -> bool: - try: - json_object = json.loads(text) - except ValueError as e: - return False - return True + try: + json.loads(text) + except ValueError: + return False + return True # Repeat Prompt: First, repeat the request without change, then give your answer (do not say anything before # repeating the request; the request you need to repeat does not include this sentence) def validate_repeat_prompt(text: str, original_prompt: str) -> bool: - if text.startswith(original_prompt): - return True - else: - return False + if text.startswith(original_prompt): + return True + else: + return False # Two Responses: Give two different responses. Responses and only responses should be separated by 6 asterisk # symbols: ******. def validate_two_responses(text: str) -> bool: - if text.count('******') == 1: - response_list = text.split('******') - first_response = response_list[0].strip() - second_response = response_list[1].strip() - if first_response != second_response: - return True - return False + if text.count('******') == 1: + response_list = text.split('******') + first_response = response_list[0].strip() + second_response = response_list[1].strip() + if first_response != second_response: + return True + return False # All Uppercase: Your entire response should be in English, capital letters only. def validate_uppercase(text: str) -> bool: - # Check if the response is the same as the uppercase version of the response - if text == text.upper(): - return True - else: - return False + # Check if the response is the same as the uppercase version of the response + if text == text.upper(): + return True + else: + return False # All Lowercase: Your entire response should be in English, and in all lowercase letters. No capital letters are # allowed. def validate_lowercase(text: str) -> bool: - # Check if the response is the same as the lowercase version of the response - if text == text.lower(): - return True - else: - return False + # Check if the response is the same as the lowercase version of the response + if text == text.lower(): + return True + else: + return False # Frequency of All-capital Words: In your response, words with all capital letters should appear at least / around / # at most {N} times. def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool: - words = re.findall(r'\b[A-Z]+\b', text) - if quantifier == 'at least': - return len(words) >= N - elif quantifier == 'around': - return len(words) == N - elif quantifier == 'at most': - return len(words) <= N - else: - return False + words = re.findall(r'\b[A-Z]+\b', text) + if quantifier == 'at least': + return len(words) >= N + elif quantifier == 'around': + return len(words) == N + elif quantifier == 'at most': + return len(words) <= N + else: + return False # End Checker: Finish your response with this exact phrase {end phrase}. No other words should follow this phrase. def validate_end(text: str, end_phrase: str) -> bool: - # Check if the response ends with the end phrase - if text.endswith(end_phrase): - return True - else: - return False + # Check if the response ends with the end phrase + if text.endswith(end_phrase): + return True + else: + return False # Quotation: Wrap your entire response with double quotation marks. def validate_quotation(text: str) -> bool: - if text.startswith('"') and text.endswith('"'): - return True - else: - return False + if text.startswith('"') and text.endswith('"'): + return True + else: + return False # No Commas: In your entire response, refrain from the use of any commas. def validate_no_commas(text: str) -> bool: - if ',' not in text: - return True - else: - return False + if ',' not in text: + return True + else: + return False + IF_FUNCTIONS_MAP = { - 'verify_keywords': verify_keywords, - 'verify_keyword_frequency': verify_keyword_frequency, - 'validate_forbidden_words': validate_forbidden_words, - 'verify_letter_frequency': verify_letter_frequency, - 'validate_response_language': validate_response_language, - 'verify_paragraph_count': verify_paragraph_count, - 'validate_word_constraint': validate_word_constraint, - 'verify_sentence_constraint': verify_sentence_constraint, - 'validate_paragraphs': validate_paragraphs, - 'verify_postscript': verify_postscript, - 'validate_placeholders': validate_placeholders, - 'verify_bullet_points': verify_bullet_points, - 'validate_title': validate_title, - 'validate_choice': validate_choice, - 'validate_highlighted_sections': validate_highlighted_sections, - 'validate_sections': validate_sections, - 'validate_json_format': validate_json_format, - 'validate_repeat_prompt': validate_repeat_prompt, - 'validate_two_responses': validate_two_responses, - 'validate_uppercase': validate_uppercase, - 'validate_lowercase': validate_lowercase, - 'validate_frequency_capital_words': validate_frequency_capital_words, - 'validate_end': validate_end, - 'validate_quotation': validate_quotation, - 'validate_no_commas': validate_no_commas + 'verify_keywords': verify_keywords, + 'verify_keyword_frequency': verify_keyword_frequency, + 'validate_forbidden_words': validate_forbidden_words, + 'verify_letter_frequency': verify_letter_frequency, + 'validate_response_language': validate_response_language, + 'verify_paragraph_count': verify_paragraph_count, + 'validate_word_constraint': validate_word_constraint, + 'verify_sentence_constraint': verify_sentence_constraint, + 'validate_paragraphs': validate_paragraphs, + 'verify_postscript': verify_postscript, + 'validate_placeholders': validate_placeholders, + 'verify_bullet_points': verify_bullet_points, + 'validate_title': validate_title, + 'validate_choice': validate_choice, + 'validate_highlighted_sections': validate_highlighted_sections, + 'validate_sections': validate_sections, + 'validate_json_format': validate_json_format, + 'validate_repeat_prompt': validate_repeat_prompt, + 'validate_two_responses': validate_two_responses, + 'validate_uppercase': validate_uppercase, + 'validate_lowercase': validate_lowercase, + 'validate_frequency_capital_words': validate_frequency_capital_words, + 'validate_end': validate_end, + 'validate_quotation': validate_quotation, + 'validate_no_commas': validate_no_commas } diff --git a/open_instruct/math_utils.py b/open_instruct/math_utils.py index 01a2fbaaf..56a0a8cc4 100644 --- a/open_instruct/math_utils.py +++ b/open_instruct/math_utils.py @@ -1,10 +1,15 @@ import re -import sympy +import signal import logging from typing import Optional +import sympy +from sympy.parsing.latex import parse_latex + + eval_logger = logging.getLogger("math_utils") + # from https://github.com/EleutherAI/lm-evaluation-harness/blob/main/lm_eval/tasks/minerva_math/utils.py#L187 def last_boxed_only_string(string: str) -> Optional[str]: idx = string.rfind("\\boxed") @@ -121,6 +126,7 @@ def get_unnormalized_answer(text: str) -> str: "\\dots", ] + def normalize_final_answer(final_answer: str) -> str: """ Normalize a final answer to a quantitative reasoning question. @@ -159,6 +165,22 @@ def normalize_final_answer(final_answer: str) -> str: return final_answer +class timeout: + def __init__(self, seconds=1, error_message="Timeout"): + self.seconds = seconds + self.error_message = error_message + + def handle_timeout(self, signum, frame): + raise TimeoutError(self.error_message) + + def __enter__(self): + signal.signal(signal.SIGALRM, self.handle_timeout) + signal.alarm(self.seconds) + + def __exit__(self, type, value, traceback): + signal.alarm(0) + + def is_equiv(x1: str, x2: str) -> bool: """ x1 and x2 are normalized latex string @@ -201,6 +223,79 @@ def is_equiv(x1: str, x2: str) -> bool: eval_logger.debug(f"Failed comparing {x1} and {x2} with {e}") return False + +def fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except AssertionError: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + a = int(a) + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except AssertionError: + return string + + +def remove_right_units(string): + # "\\text{ " only ever occurs (at least in the val set) when describing units + if "\\text{ " in string: + splits = string.split("\\text{ ") + assert len(splits) == 2 + return splits[0] + else: + return string + + +def fix_sqrt(string): + if "\\sqrt" not in string: + return string + splits = string.split("\\sqrt") + new_string = splits[0] + for split in splits[1:]: + if split[0] != "{": + a = split[0] + new_substr = "\\sqrt{" + a + "}" + split[1:] + else: + new_substr = "\\sqrt" + split + new_string += new_substr + return new_string + + def strip_string(string): # linebreaks string = string.replace("\n", "") diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 37003deb7..e3669afe4 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -39,7 +39,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizer from open_instruct.utils import retry_on_exception -from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample, verify_strict_math_sample, verify_ifeval_sample +from open_instruct.ground_truth_utils import verify_gsm8k_sample, verify_math_sample, verify_ifeval_sample @dataclass From 527c51fefa72ea7abcebb532165fc61883c1acfa Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Thu, 31 Oct 2024 19:13:09 -0700 Subject: [PATCH 41/53] first stab at flan --- open_instruct/ground_truth_utils.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 2340900c0..08a6f0ad5 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -3,6 +3,7 @@ Used to give feedback to the model based on the ground truth answer. ''' import re +import string from open_instruct.math_utils import last_boxed_only_string, remove_boxed, get_unnormalized_answer, normalize_final_answer, is_equiv, hendrycks_is_equiv from open_instruct.if_functions import IF_FUNCTIONS_MAP @@ -106,9 +107,32 @@ def verify_ifeval_sample(model_output, constraint): return func(answer, **non_none_args) +def normalize_answer(s): + """ + Lower text and remove punctuation, articles and extra whitespace. + From https://github.com/huggingface/evaluate/blob/main/metrics/squad/compute_score.py + """ + + def remove_articles(text): + return re.sub(r"\b(a|an|the)\b", " ", text) + + def white_space_fix(text): + return " ".join(text.split()) + + def remove_punc(text): + exclude = set(string.punctuation) + return "".join(ch for ch in text if ch not in exclude) + + def lower(text): + return text.lower() + + return white_space_fix(remove_articles(remove_punc(lower(s)))) + + def verify_flan_sample(model_output, ground_truth_answer): - # TODO: flan. we could do BLEU/ROUGE.... or maybe something like BertScore? - pass + # Flan! we will just use... exact match with some basic cleaning, after extracting the answer. + answer_string = model_output.split("The answer is: ")[-1].strip() + return normalize_answer(answer_string) == normalize_answer(ground_truth_answer) # debug code From 281e28eb245648dc9f3dc4b38c664d1f0f06f2f7 Mon Sep 17 00:00:00 2001 From: nouhadziri Date: Mon, 4 Nov 2024 15:05:34 -0500 Subject: [PATCH 42/53] add olmo training --- .../sft/{ => olmo}/olmo_7b_0924.yaml | 0 .../sft/olmo/olmo_7b_0924_v3.9_safety.yaml | 24 +++++++++++++++++++ open_instruct/olmo/scripts/sft/olmo_test.sh | 7 ++++++ 3 files changed, 31 insertions(+) rename configs/train_configs/sft/{ => olmo}/olmo_7b_0924.yaml (100%) create mode 100644 configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml create mode 100644 open_instruct/olmo/scripts/sft/olmo_test.sh diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo/olmo_7b_0924.yaml similarity index 100% rename from configs/train_configs/sft/olmo_7b_0924.yaml rename to configs/train_configs/sft/olmo/olmo_7b_0924.yaml diff --git a/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml b/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml new file mode 100644 index 000000000..49f926e9f --- /dev/null +++ b/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml @@ -0,0 +1,24 @@ +model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_revision: main +use_flash_attn: true +tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_mixer: + ai2-adapt-dev/synthetic-cot-wildguarmixtrain: 86759 + ai2-adapt-dev/tulu_v3.9_wildjailbreak_decontaminated_unused: 209574 # all +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 # should run with this set to 16 for 1 node only +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/open_instruct/olmo/scripts/sft/olmo_test.sh b/open_instruct/olmo/scripts/sft/olmo_test.sh new file mode 100644 index 000000000..f8ff7b63e --- /dev/null +++ b/open_instruct/olmo/scripts/sft/olmo_test.sh @@ -0,0 +1,7 @@ +python scripts/submit_finetune_job.py \ + --default_beaker_config configs/beaker_configs/default_finetune_offloading.yaml \ + --config configs/train_configs/sft/olmo_7b_0924_v3.9_safety.yaml \ + --cluster ai2/jupiter-cirrascale-2 \ + --priority high \ + --exp_name nd-SFT-olmo_7b_0924_v3.9_safety \ + --num_gpus 8 \ No newline at end of file From e8ddd567fb4f16d69341499fe81328a5d1dc6df8 Mon Sep 17 00:00:00 2001 From: nouhadziri Date: Mon, 4 Nov 2024 15:15:03 -0500 Subject: [PATCH 43/53] fix dir in config --- open_instruct/olmo/scripts/sft/olmo_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/olmo/scripts/sft/olmo_test.sh b/open_instruct/olmo/scripts/sft/olmo_test.sh index f8ff7b63e..73c4b8153 100644 --- a/open_instruct/olmo/scripts/sft/olmo_test.sh +++ b/open_instruct/olmo/scripts/sft/olmo_test.sh @@ -1,6 +1,6 @@ python scripts/submit_finetune_job.py \ --default_beaker_config configs/beaker_configs/default_finetune_offloading.yaml \ - --config configs/train_configs/sft/olmo_7b_0924_v3.9_safety.yaml \ + --config configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml \ --cluster ai2/jupiter-cirrascale-2 \ --priority high \ --exp_name nd-SFT-olmo_7b_0924_v3.9_safety \ From abd25bf81a5194b28cbd56ee3af8a7ffe7a5d46b Mon Sep 17 00:00:00 2001 From: nouhadziri Date: Mon, 4 Nov 2024 16:04:59 -0500 Subject: [PATCH 44/53] rollback my changes --- .../sft/olmo/olmo_7b_0924_v3.9_safety.yaml | 24 ------------------- .../sft/{olmo => }/olmo_7b_0924.yaml | 0 open_instruct/olmo/scripts/sft/olmo_test.sh | 7 ------ 3 files changed, 31 deletions(-) delete mode 100644 configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml rename configs/train_configs/sft/{olmo => }/olmo_7b_0924.yaml (100%) delete mode 100644 open_instruct/olmo/scripts/sft/olmo_test.sh diff --git a/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml b/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml deleted file mode 100644 index 49f926e9f..000000000 --- a/configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml +++ /dev/null @@ -1,24 +0,0 @@ -model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -model_revision: main -use_flash_attn: true -tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan -use_slow_tokenizer: false # olmo models only use fast tokenizers -dataset_mixer: - ai2-adapt-dev/synthetic-cot-wildguarmixtrain: 86759 - ai2-adapt-dev/tulu_v3.9_wildjailbreak_decontaminated_unused: 209574 # all -max_seq_length: 4096 -preprocessing_num_workers: 128 -per_device_train_batch_size: 1 -gradient_accumulation_steps: 8 # should run with this set to 16 for 1 node only -learning_rate: 2.0e-06 -lr_scheduler_type: linear -warmup_ratio: 0.03 -weight_decay: 0.0 -num_train_epochs: 3 -output_dir: /output/ -with_tracking: true -report_to: - - wandb -logging_steps: 1 -checkpointing_steps: epoch -add_bos: true \ No newline at end of file diff --git a/configs/train_configs/sft/olmo/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml similarity index 100% rename from configs/train_configs/sft/olmo/olmo_7b_0924.yaml rename to configs/train_configs/sft/olmo_7b_0924.yaml diff --git a/open_instruct/olmo/scripts/sft/olmo_test.sh b/open_instruct/olmo/scripts/sft/olmo_test.sh deleted file mode 100644 index 73c4b8153..000000000 --- a/open_instruct/olmo/scripts/sft/olmo_test.sh +++ /dev/null @@ -1,7 +0,0 @@ -python scripts/submit_finetune_job.py \ - --default_beaker_config configs/beaker_configs/default_finetune_offloading.yaml \ - --config configs/train_configs/sft/olmo/olmo_7b_0924_v3.9_safety.yaml \ - --cluster ai2/jupiter-cirrascale-2 \ - --priority high \ - --exp_name nd-SFT-olmo_7b_0924_v3.9_safety \ - --num_gpus 8 \ No newline at end of file From f037460cca72ef0d2a23b36878e210404c606ce5 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 5 Nov 2024 13:24:11 -0500 Subject: [PATCH 45/53] eval on intermediate checkpoints (#414) --- open_instruct/model_utils.py | 3 + .../ppo_vllm_thread_ray_gtrl_weka.py | 1780 +++++++++++++++++ scripts/eval/oe-eval.sh | 57 +- scripts/submit_eval_jobs.py | 79 +- 4 files changed, 1903 insertions(+), 16 deletions(-) create mode 100644 open_instruct/ppo_vllm_thread_ray_gtrl_weka.py diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index e3669afe4..f0a2d48bf 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -227,6 +227,9 @@ def apply_verifiable_reward( rewards = [] for prediction, ground_truth, dataset in zip(decoded_responses, ground_truths, datasets): verified = False + if ground_truth is None: + rewards.append(0) + continue if dataset.lower() == 'gsm8k': verified = verify_gsm8k_sample(prediction, ground_truth) elif dataset.lower() == 'math': diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py b/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py new file mode 100644 index 000000000..306554c2b --- /dev/null +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py @@ -0,0 +1,1780 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF +# which has the following license: +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import logging +import os +import random +import shutil +import socket +import subprocess +import threading +import time +from argparse import Namespace +from dataclasses import asdict, dataclass, field +from queue import Empty, Queue +from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple + +import deepspeed +import numpy as np +import pandas as pd +import ray +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils +import torch.utils.data +import vllm +from datasets import Dataset, DatasetDict +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from huggingface_hub import HfApi +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rich.pretty import pprint +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + get_scheduler, +) +from transformers.integrations import HfDeepSpeedConfig +from vllm import SamplingParams + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + DATASET_SOURCE_KEY, + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTGroundTruthDatasetProcessor, + SimpleGenerateCollatorWithGroundTruth, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + apply_verifiable_reward, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + print_rich_single_line_metrics, + print_rich_table, + push_folder_to_hub, + truncate_response, +) +from open_instruct.utils import ( + ArgumentParserPlus, + BeakerRuntimeConfig, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) +from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group + +api = HfApi() +INVALID_LOGPROB = 1.0 + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: Optional[int] = None + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = 100000 + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_rollout_batch_size: int = 64 + """The number of rollout episodes per iteration per device""" + rollout_batch_size: Optional[int] = None + """The number of rollout episodes per iteration""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 4 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + local_dataloader_batch_size: Optional[int] = None + """The batch size per GPU for the dataloader""" + save_freq: int = -1 + """How many train steps to save the model""" + + # online settings + num_epochs: int = 4 + """the number of epochs to train""" + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: int = 64 + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + reward_model_revision: Optional[str] = None + """the revision of the reward model""" + init_value_from_scratch: bool = False + """whether to initialize the value model from scratch""" + + # generation config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos", "period"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + min_response_length: int = 0 + """stop only after this many tokens""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: float = -1.0 + """the reward value for responses that do not contain `stop_token_id`""" + non_stop_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + number_samples_per_prompt: int = 1 + """the number of samples to generate per prompt, useful for easy-star""" + + # online PPO specific args + beta: float = 0.05 + """the beta value of the RLHF objective (KL coefficient)""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + cliprange: float = 0.2 + """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" + cliprange_value: float = 0.2 + """the clip range for the value function""" + gamma: float = 1 + """the discount factor""" + lam: float = 0.95 + """the lambda value for GAE""" + kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + """the KL estimator to use""" + apply_verifiable_reward: bool = False + """whether to apply verifiable reward""" + reward_model_multiplier: float = 1.0 + """the reward model multiplier, for down/upscaling the reward model output""" + answer_extraction_model: str = None + + # async setting + async_mode: bool = True + """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" + + # ray + actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1]) + """number of gpus per node for actor""" + vllm_num_engines: int = 1 + """number of vLLM Engines, set to 0 to disable vLLM""" + vllm_tensor_parallel_size: int = 1 + """tensor parallel size of vLLM Engine for multi-GPU inference""" + vllm_sync_backend: str = "nccl" + """DeepSpeed -> vLLM weight sync backend""" + enable_prefix_caching: bool = False + """whether to enable prefix caching""" + deepspeed_stage: int = 0 + """the deepspeed stage""" + gather_whole_model: bool = False + """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + checkpoint_output_dir: Optional[str] = None + """Where to save the model checkpoints in case of preemption""" + + # Ai2 specific settings + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + try_launch_beaker_eval_jobs_on_weka: bool = False + """Whether to launch beaker evaluation jobs after training on weka""" + oe_eval_tasks: Optional[List[str]] = None + """The beaker evaluation tasks to launch""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args(args: Args, model_config: ModelConfig): + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # args.world_size = accelerator.num_processes + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.gradient_accumulation_steps = exact_div( + args.local_mini_batch_size, + args.per_device_train_batch_size, + "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", + ) + args.world_size = sum(args.actor_num_gpus_per_node) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) + args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt) + args.eval_freq = max(1, args.num_training_steps // args.num_evals) + # PPO logic: do checks and set up dataloader batch size + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + args.local_dataloader_batch_size = args.rollout_batch_size + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + +def get_train_ds_config( + offload, + adam_offload=False, + stage=0, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + disable_trace_cache=True, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # # ZeRO++ + # "zero_hpz_partition_size": zpg, + # "zero_quantized_weights": False, + # "zero_quantized_gradients": False, + } + if disable_trace_cache: + zero_opt_dict["stage3_prefetch_bucket_size"] = 0 + zero_opt_dict["stage3_max_live_parameters"] = 0 + zero_opt_dict["stage3_max_reuse_distance"] = 0 + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def remove_padding(sequences, pad_token_id): + return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + + # Ensure the effective dataset size is divisible by batch_size + self.effective_size = len(self.data) - (len(self.data) % batch_size) + + def __iter__(self) -> Iterator[List[int]]: + return self + + def __next__(self) -> List[int]: + if self.index >= self.effective_size: + self.index = 0 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + self.index = end_index + + return batch + + +class RayProcess: + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + self.world_size = world_size + self.rank = rank + self.local_rank = local_rank + self.master_addr = master_addr if master_addr else self.get_current_node_ip() + self.master_port = master_port if master_port else self.get_free_port() + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + # os.environ["LOCAL_RANK"] = str(self._local_rank) + os.environ["LOCAL_RANK"] = "0" + random.seed(self.rank) + np.random.seed(self.rank) + torch.manual_seed(self.rank) + + @staticmethod + def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self.master_addr, self.master_port + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class PolicyTrainerRayProcess(RayProcess): + def from_pretrained(self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str): + self.args = args + self.model_config = model_config + self.beaker_config = beaker_config + self.wandb_url = wandb_url + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.policy.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.policy, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + # value model + self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + if args.init_value_from_scratch: + self.value_model.init_weights() # re-initialize the value model from scratch + disable_dropout_in_model(self.value_model) + self.value_model.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.value_model, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.value_model.train() + + # reference model + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.ref_policy) + self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) + self.ref_policy.eval() + + # reward model + self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.reward_model) + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config) + self.reward_model.eval() + + def get_vocab_size(self): + return self.policy.config.vocab_size + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + def train( + self, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizer, + vllm_engines: List[ray.actor.ActorHandle], + metrics_queue: RayQueue, + data_collator: Callable, + ): + torch.set_printoptions(precision=4, sci_mode=False) + + args = self.args + + accelerator = Namespace() + accelerator.process_index = self.rank + accelerator.num_processes = self.world_size + accelerator.is_main_process = self.rank == 0 + torch.distributed.barrier() + if self.rank == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + vllm_num_engines, vllm_tensor_parallel_size = ( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + backend = args.vllm_sync_backend + # https://github.com/OpenRLHF/OpenRLHF/issues/313 + if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": + backend = "gloo" + print( + "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" + ) + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + "openrlhf", + backend=backend, + ) + for i, engine in enumerate(vllm_engines) + ] + self.model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) + ray.get(refs) + torch.distributed.barrier() + + def broadcast_to_vllm(): + # avoid OOM + torch.cuda.empty_cache() + model = self.model.module + count, num_params = 0, len(list(model.named_parameters())) + refss = [] + if args.gather_whole_model: + with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + else: # broadcast each parameter independently + for name, param in model.named_parameters(): + count += 1 + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + if torch.distributed.get_rank() == 0: + ray.get(refss) + + # broadcast_to_vllm() + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + if args.stop_token == "period": + args.stop_token_id = tokenizer.encode(".")[0] + # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + train_dataset_idxs = np.arange(len(train_dataset)) + shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) + + # hack to left pad + def repeat_generator(): + while True: + batch_idxs = next(shuffling_iter) + yield [train_dataset[i] for i in batch_idxs] + + iter_dataloader = iter(repeat_generator()) + generation_config = SamplingParams( + temperature=args.temperature, + top_p=1.0, + max_tokens=args.response_length, + include_stop_str_in_output=True, + n=args.number_samples_per_prompt, + ) + # print("setup async queues") + param_prompt_Q = None + response_ids_Q = None + evaluation_Q = None + response_ids_Q = Queue(maxsize=1) + param_prompt_Q = Queue(maxsize=1) + evaluation_Q = Queue(maxsize=1) + num_eval_samples = 32 + sample_evaluation_prompt_token_ids = None + if eval_dataset is not None: + sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] + + def vllm_generate( + generation_config: SamplingParams, + response_ids_Q: Queue, + param_prompt_Q: Queue, + num_training_steps: int, + sample_evaluation_prompt_token_ids: Optional[List[int]], + evaluation_Q: Queue, + eval_freq: int, + resume_training_step: int, + ): + llm = vllm_engines[0] + for training_step in range(resume_training_step, num_training_steps + 1): + items = param_prompt_Q.get() + if items is None: + break + unwrapped_model, g_queries_list = items + # if unwrapped_model is not None: + generation_start_time = time.time() + + outputs = ray.get( + llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False) + ) + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") + response_ids_Q.put(response_ids) + + if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: + outputs = ray.get( + llm.generate.remote( + prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config, use_tqdm=False + ) + ) + # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + evaluation_Q.put(response_ids) + + resume_training_step = 1 + if accelerator.is_main_process: + thread = threading.Thread( + target=vllm_generate, + args=( + generation_config, + response_ids_Q, + param_prompt_Q, + args.num_training_steps, + sample_evaluation_prompt_token_ids, + evaluation_Q, + args.eval_freq, + resume_training_step, + ), + ) + thread.start() + print("vllm generate thread starts") + + # set up the metrics and initial states + device = torch.device(self.local_rank) + g_vllm_responses = torch.zeros( + (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length), + device=device, + dtype=torch.long, + ) + stats_shape = ( + args.num_epochs, + args.num_mini_batches * args.number_samples_per_prompt, + args.gradient_accumulation_steps, + ) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + local_metrics = torch.zeros((20,), device=device) + episode = args.rollout_batch_size * (resume_training_step - 1) + + # training loop + start_time = time.time() + global_data = next(iter_dataloader) + data = data_collator( + global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] + ) + global_queries = data_collator(global_data)[ + INPUT_IDS_PROMPT_KEY + ].tolist() # can be simplified since we `remove_padding` later anyway + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + + answer_extraction_model = None + answer_extraction_tokenizer = None + # for _ in range(1, resume_training_step): # we didn't store scheduler state + # scheduler.step() + + for training_step in range(resume_training_step, args.num_training_steps + 1): + episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + if accelerator.is_main_process: + df = None + try: + evaluation_responses = evaluation_Q.get(timeout=0.01) + print("🔥🔥🔥 Evaluation responses received") + table = {} + table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) + table["response"] = tokenizer.batch_decode(evaluation_responses) + table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] + df = pd.DataFrame(table) + del table + except Empty: + print("🙈 Evaluation responses not received") + + # (optionally) evaluate the model + if args.async_mode: + if training_step != 1: + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + else: + if training_step != 1: + # NOTE: important: the indent here is different for sync mode + # we also set to use `queries = queries_next` immediately + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + torch.cuda.empty_cache() + # print('get reward stuff starts') + # if we generate multiple samples per prompt, we need to repeat the queries and ground truths + # to match the vllm outputs. + if args.number_samples_per_prompt > 1: + queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) + ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] + + training_time_start = time.time() + with torch.no_grad(): + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + verifiable_counts = [] + sequence_lengths = [] + values = [] + if accelerator.is_main_process: + g_response_token_ids = response_ids_Q.get() + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_token_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + g_vllm_responses[:] = g_padded_response_ids + dist.broadcast(g_vllm_responses, src=0) + local_vllm_responses = g_vllm_responses[ + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] + ] + # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") + query_responses = torch.cat((queries, local_vllm_responses), 1) + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + # print(f"get reward stuff starts {i=}") + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + logprob = self.forward( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + torch.cuda.empty_cache() + + ref_output = forward(self.ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + # print("get reward stuff starts 2") + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + if args.reward_model_multiplier != 1.0: + score *= args.reward_model_multiplier + # also apply verifiable reward + if args.apply_verifiable_reward: + # we need to batch the gt to match query. + ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] + dataset = datasets[i : i + args.local_rollout_forward_batch_size] + verifiable_reward, verifiable_count = apply_verifiable_reward( + postprocessed_query_response, + tokenizer, + ground_truth, + dataset, + verify_reward=10, + answer_extraction_model=answer_extraction_model, + answer_extraction_tokenizer=answer_extraction_tokenizer, + ) + score += verifiable_reward + else: + verifiable_count = torch.tensor([0.0], device=device).float() + full_value, _, _ = get_reward( + self.value_model, query_response, tokenizer.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + verifiable_counts.append(verifiable_count) + # print(f"get reward stuff starts 5") + + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + verifiable_counts = torch.cat(verifiable_counts, 0) + verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] + values = torch.cat(values, 0) + # print(f"get reward stuff finished") + del (logprob, ref_logprob, full_value, value, score) + gc.collect() + torch.cuda.empty_cache() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) + # NOTE: only apply the stop token filter if the response is long enough + # otherwise the model could learn to generate the first token as the stop token + contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) + if args.non_stop_penalty: + scores = torch.where( + contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) + ) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + # print(f"get reward stuff finished 2") + + # 4. compute rewards + kl1 = logprobs - ref_logprobs + kl2 = (kl1) ** 2 / 2 + kl3 = (-kl1).exp() - 1 + kl1 + if args.kl_estimator == "kl1": + kl = kl1 + elif args.kl_estimator == "kl2": + kl = kl2 + elif args.kl_estimator == "kl3": + kl = kl3 + # if self.rank==0: + # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") + non_score_reward = -args.beta * kl + non_score_reward_sum = non_score_reward.sum(1) + rlhf_reward = scores + non_score_reward_sum + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + # print(f"get reward stuff finished 3") + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # print('gae') + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # print('training starts') + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch + for epoch_idx in range(args.num_epochs): + b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt) + minibatch_idx = 0 + for mini_batch_start in range( + 0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + # print("micro batch start", micro_batch_start, self.rank) + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] + + vpred_temp = get_reward( + self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length + ) + vpred_temp = vpred_temp[0] + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) + self.value_model.backward(vf_loss * args.vf_coef) + self.value_model.step() + + new_logprobs = self.forward( + mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature + ) + # if self.rank==0: + # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + self.model.backward(loss) + # print("backward loss", self.rank, "micro batch start", micro_batch_start) + # print("trying to step", self.rank, "micro batch start", micro_batch_start) + self.model.step() + # print("step", self.rank, "micro batch start", micro_batch_start) + with torch.no_grad(): + # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) + # vf_loss, vf_clipfrac = ray.get(value_model_step_future) + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + # print("value model stepped", self.rank, "micro batch start", micro_batch_start) + # prob_dist = torch.nn.functional.softmax(logits, dim=-1) + # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # fmt: off + del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 + del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss + # del vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max + # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl + # fmt: on + # del everything and empty cache + torch.cuda.empty_cache() + del b_inds, mini_batch_inds + # print("start metrics") + with torch.no_grad(): + local_metrics[0] = sequence_lengths.float().mean() + local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() + local_metrics[2] = kl.sum(1).mean() + local_metrics[3] = (-logprobs).sum(1).mean() + local_metrics[4] = non_score_reward_sum.mean() + local_metrics[5] = rlhf_reward.mean() + local_metrics[6] = scores.mean() + local_metrics[7] = approxkl_stats.mean() + local_metrics[8] = pg_clipfrac_stats.mean() + local_metrics[9] = pg_loss_stats.mean() + local_metrics[10] = vf_loss_stats.mean() + local_metrics[11] = vf_clipfrac_stats.mean() + local_metrics[12] = entropy_stats.mean() + local_metrics[13] = ratio_stats.mean() + local_metrics[14] = ratio_stats.var() + print(ratio_stats) + local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() + local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + local_metrics[17] = verifiable_correct_rate + # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + local_metrics /= dist.get_world_size() + dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) + global_metrics = local_metrics.tolist() + metrics = { + "episode": episode, + "training_step": training_step, + "lr": self.scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "time/from_scratch": time.time() - start_time, + "time/training": time.time() - training_time_start, + "val/sequence_lengths": global_metrics[0], + "val/num_stop_token_ids": global_metrics[1], + "objective/kl": global_metrics[2], + "objective/kl2": global_metrics[15], + "objective/kl3": global_metrics[16], + "objective/entropy": global_metrics[3], + "objective/non_score_reward": global_metrics[4], + "objective/rlhf_reward": global_metrics[5], + "objective/scores": global_metrics[6], + "policy/approxkl_avg": global_metrics[7], + "policy/clipfrac_avg": global_metrics[8], + "loss/policy_avg": global_metrics[9], + "loss/value_avg": global_metrics[10], + "val/clipfrac_avg": global_metrics[11], + "policy/entropy_avg": global_metrics[12], + "val/ratio": global_metrics[13], + "val/ratio_var": global_metrics[14], + "objective/verifiable_correct_rate": global_metrics[17], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + metrics_queue.put((metrics, episode, df)) + del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) + del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) + gc.collect() + torch.cuda.empty_cache() + # print(f"finished training {training_step}") + + # save steps + if args.save_freq > 0 and training_step % args.save_freq == 0: + checkpoint_dir = f"{args.output_dir}_checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") + os.makedirs(step_dir, exist_ok=True) + print(f"Saving model at step {training_step} to {step_dir}") + self.save_model(step_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(step_dir, training_step) + print(f"Saving final model at step {training_step} to {args.output_dir}") + self.save_model(args.output_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(args.output_dir) + + # Ai2 logic: we use /output to store the artifacts of the job, so we + # make a copy of the model to `/output` in the end. + if len(self.beaker_config.beaker_dataset_id_urls) > 0: + shutil.copytree(args.output_dir, "/output") + print("finished training") + + + def save_model(self, output_dir: str) -> None: + if self.rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self.model + if hasattr(model_to_save, "module"): + model_to_save = model_to_save.module + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.rank == 0: + output_state_dict[k] = vv + + if self.rank == 0: + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + # if isinstance(model_to_save, PeftModel): + # model_to_save.save_pretrained(output_dir, **kwargs) + # if self.stage == 3: + # torch.save( + # get_peft_model_state_dict(model_to_save, output_state_dict), + # os.path.join(output_dir, "adapter_model.bin"), + # ) + # else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) + + # save tokenizer + self.original_tokenizer.save_pretrained(output_dir) + + def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = None) -> None: + """auto eval the metrics as `f"{args.exp_name}_step_{training_step}"` in our leaderboard""" + args = self.args + beaker_config = self.beaker_config + model_config = self.model_config + wandb_url = self.wandb_url + # Ai2 specific logic + if is_beaker_job() and self.rank == 0: + if training_step is not None: + leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" + else: + leaderboard_name = args.hf_repo_revision + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "ppo", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb_url, + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + leaderboard_name, # to match what the auto-evals name as. + ) + + command = f"""\ +python scripts/submit_eval_jobs.py \ + --model_name {leaderboard_name} \ + --location {step_dir} \ + --cluster ai2/saturn-cirrascale \ + --is_tuned \ + --workspace "tulu-3-results" \ + --preemptible \ + --use_hf_tokenizer_template \ + --beaker_image "nathanl/open_instruct_auto" \ + --upload_to_hf allenai/tulu-3-evals \ + --run_oe_eval_experiments \ + --evaluate_on_weka \ + --run_safety_evaluations \ + --skip_oi_evals""" + if args.oe_eval_tasks is not None: + command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}" + print(f"Launching eval jobs with command: {command}") + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + +def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): + while True: + if stop_event.is_set(): + break + for ref in object_refs: + try: + ray.get(ref, timeout=0.01) + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + print(e) + print(f"Actor {ref} died") + time.sleep(120) + ray.shutdown() + os._exit(1) # Force shutdown the process + + time.sleep(30) + + +class ModelGroup: + def __init__( + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: List[int], + ): + self.pg = pg + self.ray_process_cls = ray_process_cls + self.num_gpus_per_node = num_gpus_per_node + self.num_gpus_per_actor = 1 + self.num_cpus_per_actor = 4 + self.models = [] + world_size = sum(self.num_gpus_per_node) + master_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=self.pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, 0, None, None) + + self.models.append(master_policy) + master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) + + def get_bundle_index(rank, num_gpus_per_node): + """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" + bundle_idx = 0 + while rank >= num_gpus_per_node[bundle_idx]: + rank -= num_gpus_per_node[bundle_idx] + bundle_idx += 1 + return bundle_idx + + assert get_bundle_index(0, [7, 8, 4]) == 0 + assert get_bundle_index(1, [7, 8, 4]) == 0 + assert get_bundle_index(7, [7, 8, 4]) == 1 + assert get_bundle_index(8, [7, 8, 4]) == 1 + assert get_bundle_index(9, [7, 8, 4]) == 1 + assert get_bundle_index(16, [7, 8, 4]) == 2 + + # Setup worker models + for rank in range(1, world_size): + print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.pg, + placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), + ) + worker_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=scheduling_strategy, + ).remote(world_size, rank, 0, master_addr, master_port) + self.models.append(worker_policy) + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + calculate_runtime_args(args, model_config) + + # set up experiment tracking and seeds + all_configs = {} + beaker_config = None + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) + beaker_config = maybe_get_beaker_config() + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # create a tokenizer (pad from right) + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + + # create the dataset + dataset_dict = DatasetDict() + dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: + args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) + print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets.") + if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: + args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) + print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets.") + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) + dataset_dict["eval"] = eval_dataset + data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) + + # some more runtime logging + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + pg = None + bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.actor_num_gpus_per_node] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + inits = [] + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.actor_num_gpus_per_node, + ) + wandb_url = wandb.run.get_url() if args.with_tracking else None + inits.extend(model.from_pretrained.remote(args, model_config, beaker_config, wandb_url) for model in policy_group.models) + max_len = dataset_config.max_prompt_token_length + args.response_length + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + model_config.model_name_or_path, + model_config.model_revision, + args.seed, + args.enable_prefix_caching, + max_len, + ) + + metrics_queue = RayQueue() + ray.get(inits) + print("======== all models initialized =========") + ray.get(policy_group.models[0].get_vocab_size.remote()) + # print(f"{policy_vocab_size=}, {reward_vocab_size=}") + # if policy_vocab_size != reward_vocab_size: + # ray.shutdown() # shutdown here so this error message is not buried in the logs + # raise ValueError( + # "Policy and reward model must have the same vocab size. " + # f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " + # "If they don't have the same vocab size, the policy could generate tokens which " + # "is going to cause index out of bound error in the reward model." + # ) + + refs = [] + for i, policy_model in enumerate(policy_group.models): + refs.append( + policy_model.train.remote( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + vllm_engines=vllm_engines, + metrics_queue=metrics_queue, + data_collator=data_collator, + ) + ) + + # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang + # so we need kill the ray cluster when this happens. + stop_event = threading.Event() + threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() + + # train and gather metrics + resume_training_step = 1 + for training_step in range(resume_training_step, args.num_training_steps + 1): + result = metrics_queue.get() + metrics, episode, df = result + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + if df is not None: + if args.with_tracking: + wandb.log({"sample_completions": wandb.Table(dataframe=df)}) + # else: + # print_rich_table(df) + ray.get(refs) + + # save model + ray.shutdown() + stop_event.set() + + # Ai2 specific logic + if is_beaker_job(): + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + accelerator = Namespace() + accelerator.is_main_process = True # hack + if args.push_to_hub: + print("Pushing model to hub") + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. + # We use `--save_freq` to save intermediate checkpoints in the output folder instead. + if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) diff --git a/scripts/eval/oe-eval.sh b/scripts/eval/oe-eval.sh index ed605b685..320c08912 100755 --- a/scripts/eval/oe-eval.sh +++ b/scripts/eval/oe-eval.sh @@ -1,6 +1,6 @@ #!/bin/bash -set -ex +# set -ex # A script for using oe-eval for our development! # to use, clone oe-eval (https://github.com/allenai/oe-eval-internal) into the top level dir of this repo. @@ -36,7 +36,8 @@ set -ex # Function to print usage usage() { - echo "Usage: $0 --model-name MODEL_NAME --model-location MODEL_LOCATION [--num_gpus GPUS] [--hf-upload] [--revision REVISION] [--max-length ]" + echo "Usage: $0 --model-name MODEL_NAME --model-location MODEL_LOCATION [--num_gpus GPUS] [--hf-upload] [--revision REVISION] [--max-length ] [--tasks TASKS] [--evaluate_on_weka]" + echo "TASKS should be a comma-separated list of task specifications (e.g., 'gsm8k::tulu,bbh:cot::tulu')" exit 1 } @@ -49,14 +50,13 @@ while [[ "$#" -gt 0 ]]; do --hf-upload) HF_UPLOAD="true" ;; --revision) REVISION="$2"; shift ;; --max-length) MAX_LENGTH="$2"; shift ;; + --tasks) CUSTOM_TASKS="$2"; shift ;; + --evaluate_on_weka) EVALUATE_ON_WEKA="true" ;; *) echo "Unknown parameter passed: $1"; usage ;; esac shift done -# Optional: Default number of GPUs if not specified -NUM_GPUS="${NUM_GPUS:-1}" - # Check required arguments if [[ -z "$MODEL_NAME" || -z "$MODEL_LOCATION" ]]; then echo "Error: --model-name and --model-location are required." @@ -69,6 +69,7 @@ MODEL_NAME_SAFE=${MODEL_NAME//\//_} # Set defaults for optional arguments HF_UPLOAD="${HF_UPLOAD:-false}" MAX_LENGTH="${MAX_LENGTH:-4096}" +EVALUATE_ON_WEKA="${EVALUATE_ON_WEKA:-false}" # Set HF_UPLOAD_ARG if HF_UPLOAD is true if [ "$HF_UPLOAD" == "true" ]; then @@ -84,7 +85,8 @@ else REVISION_ARG="" fi -TASKS=( +# Define default tasks if no custom tasks provided +DEFAULT_TASKS=( "gsm8k::tulu" "bbh:cot::tulu" "drop::llama3" @@ -97,6 +99,17 @@ TASKS=( "alpaca_eval_v2::tulu" "truthfulqa::tulu" ) + +# If custom tasks provided, convert comma-separated string to array +if [[ -n "$CUSTOM_TASKS" ]]; then + IFS=',' read -ra TASKS <<< "$CUSTOM_TASKS" +else + TASKS=("${DEFAULT_TASKS[@]}") +fi + +# Optional: Default number of GPUs if not specified +NUM_GPUS="${NUM_GPUS:-1}" + MODEL_TYPE="--model-type vllm" BATCH_SIZE_VLLM=10000 BATCH_SIZE_OTHER=1 @@ -106,6 +119,7 @@ GPU_COUNT_OTHER=$((NUM_GPUS * 2)) MODEL_TYPE_OTHER="" for TASK in "${TASKS[@]}"; do + echo $TASK # mmlu and truthfulqa need different batch sizes and gpu counts if [[ "$TASK" == "mmlu:mc::tulu" || "$TASK" == "truthfulqa::tulu" ]]; then BATCH_SIZE=$BATCH_SIZE_OTHER @@ -117,5 +131,34 @@ for TASK in "${TASKS[@]}"; do GPU_COUNT=$GPU_COUNT fi - python oe-eval-internal/oe_eval/launch.py --model "$MODEL_NAME" --beaker-workspace "ai2/tulu-3-results" --beaker-budget ai2/oe-adapt --task "$TASK" $MODEL_TYPE --batch-size "$BATCH_SIZE" --model-args "{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}}" ${HF_UPLOAD_ARG} --gpus "$GPU_COUNT" --gantry-args '{"env-secret": "OPENAI_API_KEY=openai_api_key"}' ${REVISION_ARG} --beaker-retries 2 + if [ "$EVALUATE_ON_WEKA" == "true" ]; then + python oe-eval-internal/oe_eval/launch.py \ + --model "$MODEL_NAME" \ + --beaker-workspace "ai2/tulu-3-results" \ + --beaker-budget ai2/oe-adapt \ + --task "$TASK" \ + $MODEL_TYPE \ + --batch-size "$BATCH_SIZE" \ + --model-args "{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}}" \ + ${HF_UPLOAD_ARG} \ + --gpus "$GPU_COUNT" \ + --gantry-args '{"env-secret": "OPENAI_API_KEY=openai_api_key", "weka": "oe-adapt-default:/weka/oe-adapt-default"}' \ + ${REVISION_ARG} \ + --beaker-retries 2 \ + --cluster ai2/saturn-cirrascale --beaker-priority "high" + else + python oe-eval-internal/oe_eval/launch.py \ + --model "$MODEL_NAME" \ + --beaker-workspace "ai2/tulu-3-results" \ + --beaker-budget ai2/oe-adapt \ + --task "$TASK" \ + $MODEL_TYPE \ + --batch-size "$BATCH_SIZE" \ + --model-args "{\"model_path\":\"${MODEL_LOCATION}\", \"max_length\": ${MAX_LENGTH}}" \ + ${HF_UPLOAD_ARG} \ + --gpus "$GPU_COUNT" \ + --gantry-args '{"env-secret": "OPENAI_API_KEY=openai_api_key"}' \ + ${REVISION_ARG} \ + --beaker-retries 2 + fi done diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index 52aacea3c..e11d7aaa6 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -66,6 +66,31 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): # Launcher +NFS_CLUSTERS = [ + "ai2/allennlp-cirrascale", + "ai2/aristo-cirrascale", + "ai2/climate-cirrascale", + "ai2/general-cirrascale", + "ai2/general-cirrascale-a5000", + "ai2/mosaic-cirrascale", + "ai2/mosaic-cirrascale-a100", + "ai2/pluto-cirrascale", + "ai2/prior-cirrascale", + "ai2/s2-cirrascale", + "ai2/s2-cirrascale-l40", +] + +WEKA_CLUSTERS = [ + "ai2/jupiter-cirrascale-2", + "ai2/saturn-cirrascale", + "ai2/neptune-cirrascale", + "ai2/allennlp-elara-cirrascale", +] +GCP_CLUSTERS = [ + "ai2/augusta-google-1" +] + + today = date.today().strftime("%m%d%Y") parser = argparse.ArgumentParser() @@ -102,6 +127,8 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): parser.add_argument("--run_safety_evaluations", action="store_true", help="Run the OE safety evaluations too.") parser.add_argument("--skip_oi_evals", action="store_true", help="Don't run open instruct evals.") parser.add_argument("--oe_eval_max_length", type=int, default=4096, help="Max length for OE eval.") +parser.add_argument("--evaluate_on_weka", action="store_true", help="Evaluate OE eval on Beaker.") +parser.add_argument("--oe_eval_tasks", type=str, default=None, help="Evaluate OE eval on Beaker.") args = parser.parse_args() @@ -121,11 +148,25 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): d1['tasks'][0]['resources']['gpuCount'] = 1 # remove nfs if asked or jupiter in cluster list. -nfs_available = True -if args.no_nfs or any(["jupiter" in c for c in cluster]): - # remove the NFS dataset - last element in the list. - d1['tasks'][0]['datasets'] = d1['tasks'][0]['datasets'][:-1] - nfs_available = False +nfs_available = False +weka_available = False +if all(c in NFS_CLUSTERS for c in cluster): + d1['tasks'][0]['datasets'].append({ + 'mountPath': "/net/nfs.cirrascale", + "source": { + "hostPath": "/net/nfs.cirrascale" + } + }) + nfs_available = True +elif all(c in WEKA_CLUSTERS for c in cluster): + d1['tasks'][0]['datasets'].append({ + 'mountPath': "/weka/oe-adapt-default", + "source": { + "weka": "oe-adapt-default" + } + }) + weka_available = True + # Use a different image if requested. if args.beaker_image is not None: @@ -462,7 +503,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): task_spec['arguments'] = [task_spec['arguments'][0].replace("--model_name_or_path /model", f"--model_name_or_path {model_info[1]} --hf_revision {args.hf_revision}")] task_spec['arguments'] = [task_spec['arguments'][0].replace("--tokenizer_name_or_path /model", f"--tokenizer_name_or_path {model_info[1]}")] elif model_info[1].startswith("/"): # if it's a local model, load it from the local directory - assert nfs_available, "NFS is required for path-based models." # to be safe. + assert nfs_available or weka_available, "NFS / Weka is required for path-based models." # to be safe. task_spec['arguments'] = [task_spec['arguments'][0].replace("--model_name_or_path /model", f"--model_name_or_path {model_info[1]}")] task_spec['arguments'] = [task_spec['arguments'][0].replace("--tokenizer_name_or_path /model", f"--tokenizer_name_or_path {model_info[1]}")] else: # if it's a beaker model, mount the beaker dataset to `/model` @@ -582,10 +623,16 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): ## model location munging: if beaker, use beaker://. If hf, just name if model_info[0].startswith("hf-"): oe_eval_cmd += f" --model-location {model_info[1]}" + elif model_info[1].startswith("/"): + oe_eval_cmd += f" --model-location {model_info[1]}" else: oe_eval_cmd += f" --model-location beaker://{model_info[1]}" if args.hf_revision: oe_eval_cmd += f" --revision {args.hf_revision}" + if args.evaluate_on_weka: + oe_eval_cmd += " --evaluate_on_weka" + if args.oe_eval_tasks: + oe_eval_cmd += f" --tasks {args.oe_eval_tasks}" # add string with number of gpus num_gpus = task_spec['resources']['gpuCount'] # if num_gpus > 1, double it again for oe-eval configs @@ -603,7 +650,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): # create an experiment that runs the safety eval tasks if args.run_safety_evaluations: # just take the original spec we had, modify it for safety eval. - experiment_name = f"oi_safety_{model_name}" + experiment_name = f"oi_safety_{model_name.replace('β', '')}" d["description"] = experiment_name # specific image for safety eval d["tasks"][0]["image"]["beaker"] = "hamishivi/open-safety" @@ -611,7 +658,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): task_spec = d["tasks"][0] task_spec["name"] = experiment_name task_spec["arguments"][0] = f''' -PYTHONPATH=. python evaluation/run_all_generation_benchmarks.py \ +VLLM_WORKER_MULTIPROC_METHOD=spawn PYTHONPATH=. python evaluation/run_all_generation_benchmarks.py \ --model_name_or_path /model \ --model_input_template_path_or_name hf \ --report_output_path /output/metrics.json \ @@ -622,12 +669,26 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): task_spec['arguments'] = [task_spec['arguments'][0].replace("--model_name_or_path /model", f"--model_name_or_path {model_info[1]} --hf_revision {args.hf_revision}")] task_spec['arguments'] = [task_spec['arguments'][0].replace("--tokenizer_name_or_path /model", f"--tokenizer_name_or_path {model_info[1]}")] elif model_info[1].startswith("/"): # if it's a local model, load it from the local directory - assert nfs_available, "NFS is required for path-based models." # to be safe. + assert nfs_available or weka_available, "NFS / Weka is required for path-based models." # to be safe. task_spec['arguments'] = [task_spec['arguments'][0].replace("--model_name_or_path /model", "--model_name_or_path "+model_info[1])] task_spec['arguments'] = [task_spec['arguments'][0].replace("--tokenizer_name_or_path /model", "--tokenizer_name_or_path "+model_info[1])] else: # if it's a beaker model, mount the beaker dataset to `/model` task_spec['datasets'][1]['source']['beaker'] = model_info[1] + task_spec = adjust_gpus( + task_spec=task_spec, + experiment_group="safety_eval", + model_name=model_info[0], + gpu_multiplier=args.gpu_multiplier, + ) + + # add gpu information. + # we just assume you want to use all the gpus for one task at a time + if "70B" in model_info[0]: + task_spec['resources']['gpuCount'] = 8 + num_gpus = task_spec['resources']['gpuCount'] + task_spec["arguments"][0]+= f" --min_gpus_per_task {num_gpus}" + if args.upload_to_hf: hf_dataset = args.upload_to_hf # to match the way oe-eval script works. From 63a44494f39f675b196775e18a25b2393aa09e40 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 5 Nov 2024 21:27:44 +0000 Subject: [PATCH 46/53] quick change --- open_instruct/ppo_vllm_thread.py | 17 +- open_instruct/ppo_vllm_thread_ray_gtrl.py | 188 +- .../ppo_vllm_thread_ray_gtrl_weka.py | 1780 ---------------- open_instruct/ppo_vllm_thread_ray_old.py | 1799 ----------------- 4 files changed, 160 insertions(+), 3624 deletions(-) delete mode 100644 open_instruct/ppo_vllm_thread_ray_gtrl_weka.py delete mode 100644 open_instruct/ppo_vllm_thread_ray_old.py diff --git a/open_instruct/ppo_vllm_thread.py b/open_instruct/ppo_vllm_thread.py index 9a3ee1dd1..6ee84dcc1 100644 --- a/open_instruct/ppo_vllm_thread.py +++ b/open_instruct/ppo_vllm_thread.py @@ -427,7 +427,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # set up experiment tracking and seeds all_configs = {} if is_beaker_job(): - args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) beaker_config = maybe_get_beaker_config() # try saving to the beaker `/output`, which will be uploaded to the beaker dataset if len(beaker_config.beaker_dataset_id_urls) > 0: @@ -472,6 +472,16 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # create the dataset dataset_dict = DatasetDict() dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: + args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets." + ) + if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: + args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets." + ) train_dataset = combine_dataset( args.dataset_mixer_dict, splits=args.dataset_train_splits, @@ -1134,8 +1144,9 @@ def repeat_generator(): ) if accelerator.is_main_process: - # remove args.checkpoint_output_dir - if os.path.exists(args.checkpoint_output_dir): + # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. + # We use `--save_freq` to save intermediate checkpoints in the output folder instead. + if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 38a6ab3d8..4c9377bb2 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -70,7 +70,7 @@ PreTrainedTokenizer, get_scheduler, ) -from transformers.deepspeed import HfDeepSpeedConfig +from transformers.integrations import HfDeepSpeedConfig from vllm import SamplingParams from open_instruct.dataset_processor import ( @@ -98,6 +98,7 @@ ) from open_instruct.utils import ( ArgumentParserPlus, + BeakerRuntimeConfig, combine_dataset, get_wandb_tags, is_beaker_job, @@ -253,6 +254,9 @@ class Args: enable_prefix_caching: bool = False """whether to enable prefix caching""" deepspeed_stage: int = 0 + """the deepspeed stage""" + gather_whole_model: bool = True + """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" # wandb and HF tracking configs with_tracking: bool = False @@ -275,12 +279,14 @@ class Args: """Where to save the model""" checkpoint_output_dir: Optional[str] = None """Where to save the model checkpoints in case of preemption""" - overwrite_beaker_output_dir: Optional[str] = None - """Where to save in a beaker job, if not just /output. Useful with weka.""" # Ai2 specific settings try_launch_beaker_eval_jobs: bool = True """Whether to launch beaker evaluation jobs after training""" + try_launch_beaker_eval_jobs_on_weka: bool = False + """Whether to launch beaker evaluation jobs after training on weka""" + oe_eval_tasks: Optional[List[str]] = None + """The beaker evaluation tasks to launch""" hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" """What dataset to upload the metadata to. If unset, don't upload metadata""" @@ -552,8 +558,13 @@ def empty_cache(self) -> None: @ray.remote(num_gpus=1) class PolicyTrainerRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig): + def from_pretrained( + self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str + ): self.args = args + self.model_config = model_config + self.beaker_config = beaker_config + self.wandb_url = wandb_url torch.cuda.set_device(self.local_rank) deepspeed.init_distributed() @@ -772,14 +783,27 @@ def broadcast_to_vllm(): model = self.model.module count, num_params = 0, len(list(model.named_parameters())) refss = [] - with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + if args.gather_whole_model: + with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + else: # broadcast each parameter independently for name, param in model.named_parameters(): - count += 1 # empty_cache at last param - - # Fire all vllm engines for broadcast + count += 1 if torch.distributed.get_rank() == 0: shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape - # print(f"broadcasting {name=} {shape=}") refs = [ engine.update_weight.remote( name, dtype=param.dtype, shape=shape, empty_cache=count == num_params @@ -787,12 +811,9 @@ def broadcast_to_vllm(): for engine in vllm_engines ] refss.extend(refs) - # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 - # with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): - if torch.distributed.get_rank() == 0: - torch.distributed.broadcast(param.data, 0, group=self.model_update_group) - # ray.get(refs) - # print(f"broadcasting {name=} {shape=} success") + with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) if torch.distributed.get_rank() == 0: ray.get(refss) @@ -820,7 +841,7 @@ def repeat_generator(): include_stop_str_in_output=True, n=args.number_samples_per_prompt, ) - print("setup async queues") + # print("setup async queues") param_prompt_Q = None response_ids_Q = None evaluation_Q = None @@ -852,7 +873,9 @@ def vllm_generate( generation_start_time = time.time() outputs = ray.get( - llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list) + llm.generate.remote( + sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False + ) ) response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") @@ -861,7 +884,9 @@ def vllm_generate( if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: outputs = ray.get( llm.generate.remote( - prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config + prompt_token_ids=sample_evaluation_prompt_token_ids, + sampling_params=generation_config, + use_tqdm=False, ) ) # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity @@ -1030,7 +1055,7 @@ def vllm_generate( # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") query_responses = torch.cat((queries, local_vllm_responses), 1) for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - print(f"get reward stuff starts {i=}") + # print(f"get reward stuff starts {i=}") query = queries[i : i + args.local_rollout_forward_batch_size] query_response = query_responses[i : i + args.local_rollout_forward_batch_size] response = query_response[:, context_length:] @@ -1185,7 +1210,7 @@ def vllm_generate( mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] gradient_accumulation_idx = 0 for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - print("micro batch start", micro_batch_start, self.rank) + # print("micro batch start", micro_batch_start, self.rank) micro_batch_end = micro_batch_start + args.per_device_train_batch_size micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] mb_advantage = advantages[micro_batch_inds] @@ -1318,13 +1343,27 @@ def vllm_generate( del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) gc.collect() torch.cuda.empty_cache() - print(f"finished training {training_step}") + # print(f"finished training {training_step}") # save steps if args.save_freq > 0 and training_step % args.save_freq == 0: - step_dir = os.path.join(args.output_dir, f"step_{training_step}") + checkpoint_dir = f"{args.output_dir}_checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") os.makedirs(step_dir, exist_ok=True) + print(f"Saving model at step {training_step} to {step_dir}") self.save_model(step_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(step_dir, training_step) + print(f"Saving final model at step {training_step} to {args.output_dir}") + self.save_model(args.output_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(args.output_dir) + + # Ai2 logic: we use /output to store the artifacts of the job, so we + # make a copy of the model to `/output` in the end. + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0: + shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") def save_model(self, output_dir: str) -> None: @@ -1382,6 +1421,62 @@ def save_model(self, output_dir: str) -> None: # save tokenizer self.original_tokenizer.save_pretrained(output_dir) + def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = None) -> None: + """auto eval the metrics as `f"{args.exp_name}_step_{training_step}"` in our leaderboard""" + args = self.args + beaker_config = self.beaker_config + model_config = self.model_config + wandb_url = self.wandb_url + # Ai2 specific logic + if is_beaker_job() and self.rank == 0: + if training_step is not None: + leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" + else: + leaderboard_name = args.hf_repo_revision + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "ppo", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb_url, + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + leaderboard_name, # to match what the auto-evals name as. + ) + + command = f"""\ +python scripts/submit_eval_jobs.py \ + --model_name {leaderboard_name} \ + --location {step_dir} \ + --cluster ai2/saturn-cirrascale \ + --is_tuned \ + --workspace "tulu-3-results" \ + --preemptible \ + --use_hf_tokenizer_template \ + --beaker_image "nathanl/open_instruct_auto" \ + --upload_to_hf allenai/tulu-3-evals \ + --run_oe_eval_experiments \ + --evaluate_on_weka \ + --run_safety_evaluations \ + --skip_oi_evals""" + if args.oe_eval_tasks is not None: + command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}" + print(f"Launching eval jobs with command: {command}") + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): while True: @@ -1462,15 +1557,10 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # set up experiment tracking and seeds all_configs = {} + beaker_config = None if is_beaker_job(): - args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) beaker_config = maybe_get_beaker_config() - # try saving to the beaker `/output`, which will be uploaded to the beaker dataset - if len(beaker_config.beaker_dataset_id_urls) > 0: - args.output_dir = "/output" - # if the user has asked to save to a specific directory, use that instead - if args.overwrite_beaker_output_dir is not None: - args.output_dir = args.overwrite_beaker_output_dir all_configs.update(vars(beaker_config)) all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) if args.with_tracking: @@ -1505,6 +1595,16 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): # create the dataset dataset_dict = DatasetDict() dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: + args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets." + ) + if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: + args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets." + ) train_dataset = combine_dataset( args.dataset_mixer_dict, splits=args.dataset_train_splits, @@ -1532,7 +1632,10 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): dataset_config.dataset_source_key, ], ) - eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) + if dataset_config.sanity_check: + eval_dataset = eval_dataset.select( + range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples)) + ) eval_dataset = dataset_processor.tokenize(eval_dataset) eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) dataset_dict["eval"] = eval_dataset @@ -1560,7 +1663,10 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): PolicyTrainerRayProcess, args.actor_num_gpus_per_node, ) - inits.extend(model.from_pretrained.remote(args, model_config) for model in policy_group.models) + wandb_url = wandb.run.get_url() if args.with_tracking else None + inits.extend( + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url) for model in policy_group.models + ) max_len = dataset_config.max_prompt_token_length + args.response_length vllm_engines = create_vllm_engines( args.vllm_num_engines, @@ -1616,20 +1722,15 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): if args.with_tracking: wandb.log({"sample_completions": wandb.Table(dataframe=df)}) else: - print_rich_table(df) + print_rich_table(df.iloc[:1]) ray.get(refs) # save model - ray.get([policy_model.save_model.remote(args.output_dir) for policy_model in policy_group.models]) ray.shutdown() stop_event.set() - # hack - accelerator = Namespace() - accelerator.is_main_process = True - # Ai2 specific logic - if is_beaker_job() and accelerator.is_main_process: + if is_beaker_job(): if args.hf_metadata_dataset: dataset_list = list(args.dataset_mixer_dict.keys()) # mainly just focussing here on what would be useful for the leaderboard. @@ -1670,7 +1771,10 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + accelerator = Namespace() + accelerator.is_main_process = True # hack if args.push_to_hub: + print("Pushing model to hub") push_folder_to_hub( accelerator, args.output_dir, @@ -1678,10 +1782,10 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): args.hf_repo_revision, ) - if accelerator.is_main_process: - # remove args.checkpoint_output_dir - if os.path.exists(args.checkpoint_output_dir): - shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. + # We use `--save_freq` to save intermediate checkpoints in the output folder instead. + if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) if __name__ == "__main__": diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py b/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py deleted file mode 100644 index 306554c2b..000000000 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_weka.py +++ /dev/null @@ -1,1780 +0,0 @@ -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# --------------------------------------------------------------------- -# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF -# which has the following license: -# Copyright [yyyy] [name of copyright owner] -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import json -import logging -import os -import random -import shutil -import socket -import subprocess -import threading -import time -from argparse import Namespace -from dataclasses import asdict, dataclass, field -from queue import Empty, Queue -from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple - -import deepspeed -import numpy as np -import pandas as pd -import ray -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.utils -import torch.utils.data -import vllm -from datasets import Dataset, DatasetDict -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from huggingface_hub import HfApi -from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.queue import Queue as RayQueue -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rich.pretty import pprint -from torch.utils.tensorboard import SummaryWriter -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForSequenceClassification, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizer, - get_scheduler, -) -from transformers.integrations import HfDeepSpeedConfig -from vllm import SamplingParams - -from open_instruct.dataset_processor import ( - CHAT_TEMPLATES, - DATASET_SOURCE_KEY, - GROUND_TRUTHS_KEY, - INPUT_IDS_PROMPT_KEY, - DatasetConfig, - SFTGroundTruthDatasetProcessor, - SimpleGenerateCollatorWithGroundTruth, - visualize_token, -) -from open_instruct.model_utils import ( - ModelConfig, - apply_verifiable_reward, - disable_dropout_in_model, - exact_div, - first_true_indices, - forward, - get_reward, - print_rich_single_line_metrics, - print_rich_table, - push_folder_to_hub, - truncate_response, -) -from open_instruct.utils import ( - ArgumentParserPlus, - BeakerRuntimeConfig, - combine_dataset, - get_wandb_tags, - is_beaker_job, - maybe_get_beaker_config, - maybe_use_ai2_hf_entity, - maybe_use_ai2_wandb_entity, - upload_metadata_to_hf, -) -from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group - -api = HfApi() -INVALID_LOGPROB = 1.0 - - -@dataclass -class Args: - # required dataset args - dataset_mixer: str = None - """A dictionary of datasets (local or HF) to sample from.""" - dataset_train_splits: List[str] = None - """The dataset splits to use for training""" - dataset_eval_mixer: Optional[str] = None - """A dictionary of datasets (local or HF) to sample from for evaluation""" - dataset_eval_splits: Optional[List[str]] = None - """The dataset splits to use for evaluation""" - dataset_mixer_dict: Optional[dict] = None - """The dataset mixer as a dictionary""" - dataset_eval_mixer_dict: Optional[dict] = None - """The dataset eval mixer as a dictionary""" - - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """The name of this experiment""" - seed: int = 1 - """Seed of the experiment""" - run_name: Optional[str] = None - """A unique name of this run""" - - # optimizer args - eps: float = 1e-5 - """The epsilon value for the optimizer""" - learning_rate: float = 2e-5 - """The initial learning rate for AdamW optimizer.""" - lr_scheduler_type: Literal[ - "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" - ] = "linear" - """Which scheduler to use""" - warm_up_steps: int = 0 - """Number of warm up steps for the scheduler""" - - # various batch sizes - num_train_epochs: int = 1 - """Number of epochs to train""" - gradient_accumulation_steps: Optional[int] = None - """The number of gradient accumulation steps""" - per_device_train_batch_size: Optional[int] = 1 - """The forward batch size per device (local_micro_batch_size)""" - per_device_eval_batch_size: Optional[int] = 1 - """The forward batch size per device for evaluation (local_micro_batch_size)""" - total_episodes: Optional[int] = 100000 - """The total number of episodes in the dataset""" - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - micro_batch_size: Optional[int] = None - """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" - local_rollout_batch_size: int = 64 - """The number of rollout episodes per iteration per device""" - rollout_batch_size: Optional[int] = None - """The number of rollout episodes per iteration""" - num_training_steps: Optional[int] = None - """The number of training_steps to train""" - num_evals: int = 4 - """The number of evaluations to run throughout training""" - eval_freq: Optional[int] = None - """The frequency of evaluation steps""" - local_dataloader_batch_size: Optional[int] = None - """The batch size per GPU for the dataloader""" - save_freq: int = -1 - """How many train steps to save the model""" - - # online settings - num_epochs: int = 4 - """the number of epochs to train""" - num_mini_batches: int = 1 - """Number of minibatches to split a batch into""" - local_mini_batch_size: int = 64 - """the mini batch size per GPU""" - mini_batch_size: Optional[int] = None - """the mini batch size across GPUs""" - local_rollout_forward_batch_size: int = 64 - """per rank no grad forward pass in the rollout phase""" - reward_model_path: str = "EleutherAI/pythia-160m" - """the path to the reward model""" - reward_model_revision: Optional[str] = None - """the revision of the reward model""" - init_value_from_scratch: bool = False - """whether to initialize the value model from scratch""" - - # generation config - response_length: int = 53 - """the length of the response""" - stop_token: Optional[Literal["eos", "period"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the truncation token id""" - min_response_length: int = 0 - """stop only after this many tokens""" - temperature: float = 0.7 - """the sampling temperature""" - penalty_reward_value: float = -1.0 - """the reward value for responses that do not contain `stop_token_id`""" - non_stop_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" - number_samples_per_prompt: int = 1 - """the number of samples to generate per prompt, useful for easy-star""" - - # online PPO specific args - beta: float = 0.05 - """the beta value of the RLHF objective (KL coefficient)""" - whiten_rewards: bool = False - """whether to whiten the rewards""" - cliprange: float = 0.2 - """the clip range""" - vf_coef: float = 0.1 - """the value function coefficient""" - cliprange_value: float = 0.2 - """the clip range for the value function""" - gamma: float = 1 - """the discount factor""" - lam: float = 0.95 - """the lambda value for GAE""" - kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" - """the KL estimator to use""" - apply_verifiable_reward: bool = False - """whether to apply verifiable reward""" - reward_model_multiplier: float = 1.0 - """the reward model multiplier, for down/upscaling the reward model output""" - answer_extraction_model: str = None - - # async setting - async_mode: bool = True - """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" - - # ray - actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1]) - """number of gpus per node for actor""" - vllm_num_engines: int = 1 - """number of vLLM Engines, set to 0 to disable vLLM""" - vllm_tensor_parallel_size: int = 1 - """tensor parallel size of vLLM Engine for multi-GPU inference""" - vllm_sync_backend: str = "nccl" - """DeepSpeed -> vLLM weight sync backend""" - enable_prefix_caching: bool = False - """whether to enable prefix caching""" - deepspeed_stage: int = 0 - """the deepspeed stage""" - gather_whole_model: bool = False - """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" - - # wandb and HF tracking configs - with_tracking: bool = False - """If toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "open_instruct_internal" - """The wandb's project name""" - wandb_entity: Optional[str] = None - """The entity (team) of wandb's project""" - push_to_hub: bool = True - """Whether to upload the saved model to huggingface""" - hf_entity: Optional[str] = None - """The user or org name of the model repository from the Hugging Face Hub""" - hf_repo_id: Optional[str] = None - """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_revision: Optional[str] = None - """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_url: Optional[str] = None - """The url of the saved model in the Hugging Face Hub (will be autoset)""" - output_dir: Optional[str] = None - """Where to save the model""" - checkpoint_output_dir: Optional[str] = None - """Where to save the model checkpoints in case of preemption""" - - # Ai2 specific settings - try_launch_beaker_eval_jobs: bool = True - """Whether to launch beaker evaluation jobs after training""" - try_launch_beaker_eval_jobs_on_weka: bool = False - """Whether to launch beaker evaluation jobs after training on weka""" - oe_eval_tasks: Optional[List[str]] = None - """The beaker evaluation tasks to launch""" - hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" - """What dataset to upload the metadata to. If unset, don't upload metadata""" - - def __post_init__(self): - self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) - if self.dataset_eval_mixer is not None: - self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) - - -def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: - # if passed through cli: convert the dataset mixers to dictionaries - if isinstance(value, str): - return json.loads(value), value - # if passed through yaml: convert the dataset mixers to strings - elif isinstance(value, dict): - return value, json.dumps(value) - else: - raise ValueError("Input must be either a string or a dictionary") - - -def calculate_runtime_args(args: Args, model_config: ModelConfig): - """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" - # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - # args.world_size = accelerator.num_processes - args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - args.gradient_accumulation_steps = exact_div( - args.local_mini_batch_size, - args.per_device_train_batch_size, - "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", - ) - args.world_size = sum(args.actor_num_gpus_per_node) - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) - args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt) - args.eval_freq = max(1, args.num_training_steps // args.num_evals) - # PPO logic: do checks and set up dataloader batch size - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - args.local_dataloader_batch_size = args.rollout_batch_size - if args.push_to_hub: - if args.hf_repo_id is None: # auto-generate one - args.hf_repo_id = "open_instruct_dev" - if args.hf_entity is None: # first try to use AI2 entity - args.hf_entity = maybe_use_ai2_hf_entity() - if args.hf_entity is None: # then try to use the user's entity - args.hf_entity = HfApi().whoami()["name"] - args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - if args.hf_repo_revision is None: # auto-generate one - args.hf_repo_revision = args.run_name - args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" - - if args.with_tracking: - if args.wandb_entity is None: - args.wandb_entity = maybe_use_ai2_wandb_entity() - - -def get_train_ds_config( - offload, - adam_offload=False, - stage=0, - bf16=True, - max_norm=1.0, - zpg=8, - grad_accum_dtype=None, - disable_trace_cache=True, -): - device = "cpu" if offload else "none" - zero_opt_dict = { - "stage": stage, - "offload_param": {"device": device}, - "offload_optimizer": { - "device": "cpu" if adam_offload else "none", - "pin_memory": True, - }, - "sub_group_size": "auto", - "stage3_max_live_parameters": "auto", - "stage3_max_reuse_distance": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_prefetch_bucket_size": "auto", - "reduce_bucket_size": "auto", - # # ZeRO++ - # "zero_hpz_partition_size": zpg, - # "zero_quantized_weights": False, - # "zero_quantized_gradients": False, - } - if disable_trace_cache: - zero_opt_dict["stage3_prefetch_bucket_size"] = 0 - zero_opt_dict["stage3_max_live_parameters"] = 0 - zero_opt_dict["stage3_max_reuse_distance"] = 0 - - return { - "steps_per_print": 100, - "zero_optimization": zero_opt_dict, - "bf16": { - "enabled": bf16, - }, - "gradient_clipping": max_norm, - "prescale_gradients": False, - "wall_clock_breakdown": False, - "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, - } - - -def get_eval_ds_config( - offload, - stage=0, - bf16=True, -): - zero_opt_dict = { - "stage": stage, - "stage3_param_persistence_threshold": "auto", - "offload_param": { - "device": "cpu" if offload else "none", - "pin_memory": True, - }, - } - return { - "steps_per_print": 100, - "zero_optimization": zero_opt_dict, - "bf16": { - "enabled": bf16, - }, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - - -def get_optimizer_grouped_parameters( - model, - weight_decay, - no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], -): - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in model.named_parameters() - if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) - ], - "weight_decay": weight_decay, - }, - { - "params": [ - p - for n, p in model.named_parameters() - if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) - ], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - -def _z3_params_to_fetch(param_list): - return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] - - -def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" - if axis is not None: - return (values * mask).sum(axis=axis) / mask.sum(axis=axis) - else: - return (values * mask).sum() / mask.sum() - - -def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: - """Compute variance of tensor with masked values.""" - mean = masked_mean(values, mask) - centered_values = values - mean - variance = masked_mean(centered_values**2, mask) - if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError( - "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" - "try increase the `mini_batch_size` or `gradient_accumulation_steps`" - ) - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size - bessel_correction = mask_sum / (mask_sum - 1) - variance = variance * bessel_correction - return variance - - -def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: - """Whiten values with masked values.""" - mean, var = masked_mean(values, mask), masked_var(values, mask) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -def remove_padding(sequences, pad_token_id): - return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] - - -class ShufflingIterator: - def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): - self.data = data.copy() - self.batch_size = batch_size - self.index = 0 - self.rng = np.random.default_rng(seed) - self.rng.shuffle(self.data) - - # Ensure the effective dataset size is divisible by batch_size - self.effective_size = len(self.data) - (len(self.data) % batch_size) - - def __iter__(self) -> Iterator[List[int]]: - return self - - def __next__(self) -> List[int]: - if self.index >= self.effective_size: - self.index = 0 - self.rng.shuffle(self.data) - - end_index = self.index + self.batch_size - batch = self.data[self.index : end_index].tolist() - self.index = end_index - - return batch - - -class RayProcess: - def __init__(self, world_size, rank, local_rank, master_addr, master_port): - logging.basicConfig( - format="%(asctime)s %(levelname)-8s %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S", - ) - self.world_size = world_size - self.rank = rank - self.local_rank = local_rank - self.master_addr = master_addr if master_addr else self.get_current_node_ip() - self.master_port = master_port if master_port else self.get_free_port() - os.environ["MASTER_ADDR"] = self.master_addr - os.environ["MASTER_PORT"] = str(self.master_port) - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["RANK"] = str(self.rank) - # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES - # environment variable for each actor, so always set device to 0 - # os.environ["LOCAL_RANK"] = str(self._local_rank) - os.environ["LOCAL_RANK"] = "0" - random.seed(self.rank) - np.random.seed(self.rank) - torch.manual_seed(self.rank) - - @staticmethod - def get_current_node_ip(): - address = ray._private.services.get_node_ip_address() - # strip ipv6 address - return address.strip("[]") - - @staticmethod - def get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - def get_master_addr_port(self): - return self.master_addr, self.master_port - - def empty_cache(self) -> None: - torch.cuda.empty_cache() - - -@ray.remote(num_gpus=1) -class PolicyTrainerRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str): - self.args = args - self.model_config = model_config - self.beaker_config = beaker_config - self.wandb_url = wandb_url - torch.cuda.set_device(self.local_rank) - deepspeed.init_distributed() - - ds_config = get_train_ds_config( - offload=False, - adam_offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - # Costa: MAGIC: it's actually needed to initialize this `dschf`, so - # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration - # next line instructs transformers to partition the model directly over multiple gpus using - # deepspeed.zero.Init when model's `from_pretrained` method is called. - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - print(f"{dschf=}") - - self.original_tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, revision=model_config.model_revision - ) - self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.policy) - self.policy.gradient_checkpointing_enable() - # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam - # AdamOptimizer = FusedAdam - # weight_decay = 0.0 - # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) - # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) - self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) - scheduler = get_scheduler( - args.lr_scheduler_type, - optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, - ) - print(ds_config) - self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( - model=self.policy, - optimizer=self.optimizer, - config=ds_config, - lr_scheduler=scheduler, - dist_init_required=True, - ) - self.model.train() - - # value model - self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( - args.reward_model_path, - revision=args.reward_model_revision, - num_labels=1, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - if args.init_value_from_scratch: - self.value_model.init_weights() # re-initialize the value model from scratch - disable_dropout_in_model(self.value_model) - self.value_model.gradient_checkpointing_enable() - # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam - # AdamOptimizer = FusedAdam - # weight_decay = 0.0 - # optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) - # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) - self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) - scheduler = get_scheduler( - args.lr_scheduler_type, - optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, - ) - self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( - model=self.value_model, - optimizer=self.optimizer, - config=ds_config, - lr_scheduler=scheduler, - dist_init_required=True, - ) - self.value_model.train() - - # reference model - ds_config = get_eval_ds_config( - offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - # Costa: MAGIC: it's actually needed to initialize this `dschf`, so - # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration - # next line instructs transformers to partition the model directly over multiple gpus using - # deepspeed.zero.Init when model's `from_pretrained` method is called. - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - print(f"{dschf=}") - - self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.ref_policy) - self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) - self.ref_policy.eval() - - # reward model - self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( - args.reward_model_path, - revision=args.reward_model_revision, - num_labels=1, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.reward_model) - ds_config = get_eval_ds_config( - offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config) - self.reward_model.eval() - - def get_vocab_size(self): - return self.policy.config.vocab_size - - def forward( - self, - query_response: torch.LongTensor, - response: torch.LongTensor, - pad_token_id: int, - context_length: int, - temperature: float, - ) -> torch.Tensor: - output = forward(self.model, query_response, pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= temperature + 1e-7 - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - return logprob - - def train( - self, - train_dataset: Dataset, - eval_dataset: Dataset, - tokenizer: PreTrainedTokenizer, - vllm_engines: List[ray.actor.ActorHandle], - metrics_queue: RayQueue, - data_collator: Callable, - ): - torch.set_printoptions(precision=4, sci_mode=False) - - args = self.args - - accelerator = Namespace() - accelerator.process_index = self.rank - accelerator.num_processes = self.world_size - accelerator.is_main_process = self.rank == 0 - torch.distributed.barrier() - if self.rank == 0: - master_address = ray._private.services.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - master_port = sock.getsockname()[1] - vllm_num_engines, vllm_tensor_parallel_size = ( - args.vllm_num_engines, - args.vllm_tensor_parallel_size, - ) - world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 - backend = args.vllm_sync_backend - # https://github.com/OpenRLHF/OpenRLHF/issues/313 - if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": - backend = "gloo" - print( - "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" - ) - refs = [ - engine.init_process_group.remote( - master_address, - master_port, - i * vllm_tensor_parallel_size + 1, - world_size, - "openrlhf", - backend=backend, - ) - for i, engine in enumerate(vllm_engines) - ] - self.model_update_group = init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name="openrlhf", - ) - ray.get(refs) - torch.distributed.barrier() - - def broadcast_to_vllm(): - # avoid OOM - torch.cuda.empty_cache() - model = self.model.module - count, num_params = 0, len(list(model.named_parameters())) - refss = [] - if args.gather_whole_model: - with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): - for name, param in model.named_parameters(): - count += 1 # empty_cache at last param - # Fire all vllm engines for broadcast - if torch.distributed.get_rank() == 0: - shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape - refs = [ - engine.update_weight.remote( - name, dtype=param.dtype, shape=shape, empty_cache=count == num_params - ) - for engine in vllm_engines - ] - refss.extend(refs) - if torch.distributed.get_rank() == 0: - torch.distributed.broadcast(param.data, 0, group=self.model_update_group) - else: # broadcast each parameter independently - for name, param in model.named_parameters(): - count += 1 - if torch.distributed.get_rank() == 0: - shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape - refs = [ - engine.update_weight.remote( - name, dtype=param.dtype, shape=shape, empty_cache=count == num_params - ) - for engine in vllm_engines - ] - refss.extend(refs) - with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): - if torch.distributed.get_rank() == 0: - torch.distributed.broadcast(param.data, 0, group=self.model_update_group) - if torch.distributed.get_rank() == 0: - ray.get(refss) - - # broadcast_to_vllm() - if args.stop_token: - if args.stop_token == "eos": - args.stop_token_id = tokenizer.eos_token_id - if args.stop_token == "period": - args.stop_token_id = tokenizer.encode(".")[0] - # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) - train_dataset_idxs = np.arange(len(train_dataset)) - shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) - - # hack to left pad - def repeat_generator(): - while True: - batch_idxs = next(shuffling_iter) - yield [train_dataset[i] for i in batch_idxs] - - iter_dataloader = iter(repeat_generator()) - generation_config = SamplingParams( - temperature=args.temperature, - top_p=1.0, - max_tokens=args.response_length, - include_stop_str_in_output=True, - n=args.number_samples_per_prompt, - ) - # print("setup async queues") - param_prompt_Q = None - response_ids_Q = None - evaluation_Q = None - response_ids_Q = Queue(maxsize=1) - param_prompt_Q = Queue(maxsize=1) - evaluation_Q = Queue(maxsize=1) - num_eval_samples = 32 - sample_evaluation_prompt_token_ids = None - if eval_dataset is not None: - sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] - - def vllm_generate( - generation_config: SamplingParams, - response_ids_Q: Queue, - param_prompt_Q: Queue, - num_training_steps: int, - sample_evaluation_prompt_token_ids: Optional[List[int]], - evaluation_Q: Queue, - eval_freq: int, - resume_training_step: int, - ): - llm = vllm_engines[0] - for training_step in range(resume_training_step, num_training_steps + 1): - items = param_prompt_Q.get() - if items is None: - break - unwrapped_model, g_queries_list = items - # if unwrapped_model is not None: - generation_start_time = time.time() - - outputs = ray.get( - llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False) - ) - response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] - print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") - response_ids_Q.put(response_ids) - - if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: - outputs = ray.get( - llm.generate.remote( - prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config, use_tqdm=False - ) - ) - # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity - response_ids = [list(output.outputs[0].token_ids) for output in outputs] - evaluation_Q.put(response_ids) - - resume_training_step = 1 - if accelerator.is_main_process: - thread = threading.Thread( - target=vllm_generate, - args=( - generation_config, - response_ids_Q, - param_prompt_Q, - args.num_training_steps, - sample_evaluation_prompt_token_ids, - evaluation_Q, - args.eval_freq, - resume_training_step, - ), - ) - thread.start() - print("vllm generate thread starts") - - # set up the metrics and initial states - device = torch.device(self.local_rank) - g_vllm_responses = torch.zeros( - (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length), - device=device, - dtype=torch.long, - ) - stats_shape = ( - args.num_epochs, - args.num_mini_batches * args.number_samples_per_prompt, - args.gradient_accumulation_steps, - ) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - local_metrics = torch.zeros((20,), device=device) - episode = args.rollout_batch_size * (resume_training_step - 1) - - # training loop - start_time = time.time() - global_data = next(iter_dataloader) - data = data_collator( - global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] - ) - global_queries = data_collator(global_data)[ - INPUT_IDS_PROMPT_KEY - ].tolist() # can be simplified since we `remove_padding` later anyway - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY] - datasets_next = data[DATASET_SOURCE_KEY] - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - - answer_extraction_model = None - answer_extraction_tokenizer = None - # for _ in range(1, resume_training_step): # we didn't store scheduler state - # scheduler.step() - - for training_step in range(resume_training_step, args.num_training_steps + 1): - episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode - queries = queries_next - ground_truths = ground_truths_next - datasets = datasets_next - - if accelerator.is_main_process: - df = None - try: - evaluation_responses = evaluation_Q.get(timeout=0.01) - print("🔥🔥🔥 Evaluation responses received") - table = {} - table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) - table["response"] = tokenizer.batch_decode(evaluation_responses) - table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] - df = pd.DataFrame(table) - del table - except Empty: - print("🙈 Evaluation responses not received") - - # (optionally) evaluate the model - if args.async_mode: - if training_step != 1: - global_data = next(iter_dataloader) - data = data_collator( - global_data[ - self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size - ] - ) - global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY] - datasets_next = data[DATASET_SOURCE_KEY] - - start_time = time.time() - broadcast_to_vllm() - print( - f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" - ) - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - else: - if training_step != 1: - # NOTE: important: the indent here is different for sync mode - # we also set to use `queries = queries_next` immediately - global_data = next(iter_dataloader) - data = data_collator( - global_data[ - self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size - ] - ) - global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - ground_truths_next = data[GROUND_TRUTHS_KEY] - datasets_next = data[DATASET_SOURCE_KEY] - start_time = time.time() - broadcast_to_vllm() - print( - f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" - ) - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - queries = queries_next - ground_truths = ground_truths_next - datasets = datasets_next - - torch.cuda.empty_cache() - # print('get reward stuff starts') - # if we generate multiple samples per prompt, we need to repeat the queries and ground truths - # to match the vllm outputs. - if args.number_samples_per_prompt > 1: - queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) - ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] - datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] - - training_time_start = time.time() - with torch.no_grad(): - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - verifiable_counts = [] - sequence_lengths = [] - values = [] - if accelerator.is_main_process: - g_response_token_ids = response_ids_Q.get() - DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out - g_padded_response_ids = [ - response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) - for response in g_response_token_ids - ] - g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) - g_vllm_responses[:] = g_padded_response_ids - dist.broadcast(g_vllm_responses, src=0) - local_vllm_responses = g_vllm_responses[ - accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] - ] - # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") - query_responses = torch.cat((queries, local_vllm_responses), 1) - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - # print(f"get reward stuff starts {i=}") - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - - logprob = self.forward( - query_response, response, tokenizer.pad_token_id, context_length, args.temperature - ) - torch.cuda.empty_cache() - - ref_output = forward(self.ref_policy, query_response, tokenizer.pad_token_id) - ref_logits = ref_output.logits[:, context_length - 1 : -1] - ref_logits /= args.temperature + 1e-7 - ref_all_logprob = F.log_softmax(ref_logits, dim=-1) - ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - del ref_output, ref_logits, ref_all_logprob - torch.cuda.empty_cache() - - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) - # print("get reward stuff starts 2") - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - _, score, _ = get_reward( - self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - if args.reward_model_multiplier != 1.0: - score *= args.reward_model_multiplier - # also apply verifiable reward - if args.apply_verifiable_reward: - # we need to batch the gt to match query. - ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] - dataset = datasets[i : i + args.local_rollout_forward_batch_size] - verifiable_reward, verifiable_count = apply_verifiable_reward( - postprocessed_query_response, - tokenizer, - ground_truth, - dataset, - verify_reward=10, - answer_extraction_model=answer_extraction_model, - answer_extraction_tokenizer=answer_extraction_tokenizer, - ) - score += verifiable_reward - else: - verifiable_count = torch.tensor([0.0], device=device).float() - full_value, _, _ = get_reward( - self.value_model, query_response, tokenizer.pad_token_id, context_length - ) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - verifiable_counts.append(verifiable_count) - # print(f"get reward stuff starts 5") - - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - verifiable_counts = torch.cat(verifiable_counts, 0) - verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] - values = torch.cat(values, 0) - # print(f"get reward stuff finished") - del (logprob, ref_logprob, full_value, value, score) - gc.collect() - torch.cuda.empty_cache() - - # Response Processing 3. filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) - # NOTE: only apply the stop token filter if the response is long enough - # otherwise the model could learn to generate the first token as the stop token - contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) - if args.non_stop_penalty: - scores = torch.where( - contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) - ) - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) - # print(f"get reward stuff finished 2") - - # 4. compute rewards - kl1 = logprobs - ref_logprobs - kl2 = (kl1) ** 2 / 2 - kl3 = (-kl1).exp() - 1 + kl1 - if args.kl_estimator == "kl1": - kl = kl1 - elif args.kl_estimator == "kl2": - kl = kl2 - elif args.kl_estimator == "kl3": - kl = kl3 - # if self.rank==0: - # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") - non_score_reward = -args.beta * kl - non_score_reward_sum = non_score_reward.sum(1) - rlhf_reward = scores + non_score_reward_sum - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[[actual_start, actual_end]] += scores - # print(f"get reward stuff finished 3") - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # print('gae') - # 6. compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) - torch.cuda.empty_cache() - - # print('training starts') - # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch - for epoch_idx in range(args.num_epochs): - b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt) - minibatch_idx = 0 - for mini_batch_start in range( - 0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size - ): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - # print("micro batch start", micro_batch_start, self.rank) - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] - - vpred_temp = get_reward( - self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length - ) - vpred_temp = vpred_temp[0] - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) - vpredclipped = torch.clamp( - vpred, - mb_values - args.cliprange_value, - mb_values + args.cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) - self.value_model.backward(vf_loss * args.vf_coef) - self.value_model.step() - - new_logprobs = self.forward( - mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature - ) - # if self.rank==0: - # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") - new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss - self.model.backward(loss) - # print("backward loss", self.rank, "micro batch start", micro_batch_start) - # print("trying to step", self.rank, "micro batch start", micro_batch_start) - self.model.step() - # print("step", self.rank, "micro batch start", micro_batch_start) - with torch.no_grad(): - # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) - # vf_loss, vf_clipfrac = ray.get(value_model_step_future) - vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] - ) - # print("value model stepped", self.rank, "micro batch start", micro_batch_start) - # prob_dist = torch.nn.functional.softmax(logits, dim=-1) - # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # fmt: off - del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 - del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss - # del vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max - # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl - # fmt: on - # del everything and empty cache - torch.cuda.empty_cache() - del b_inds, mini_batch_inds - # print("start metrics") - with torch.no_grad(): - local_metrics[0] = sequence_lengths.float().mean() - local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() - local_metrics[2] = kl.sum(1).mean() - local_metrics[3] = (-logprobs).sum(1).mean() - local_metrics[4] = non_score_reward_sum.mean() - local_metrics[5] = rlhf_reward.mean() - local_metrics[6] = scores.mean() - local_metrics[7] = approxkl_stats.mean() - local_metrics[8] = pg_clipfrac_stats.mean() - local_metrics[9] = pg_loss_stats.mean() - local_metrics[10] = vf_loss_stats.mean() - local_metrics[11] = vf_clipfrac_stats.mean() - local_metrics[12] = entropy_stats.mean() - local_metrics[13] = ratio_stats.mean() - local_metrics[14] = ratio_stats.var() - print(ratio_stats) - local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() - local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() - local_metrics[17] = verifiable_correct_rate - # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() - local_metrics /= dist.get_world_size() - dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) - global_metrics = local_metrics.tolist() - metrics = { - "episode": episode, - "training_step": training_step, - "lr": self.scheduler.get_last_lr()[0], - "epoch": episode / len(train_dataset), - "time/from_scratch": time.time() - start_time, - "time/training": time.time() - training_time_start, - "val/sequence_lengths": global_metrics[0], - "val/num_stop_token_ids": global_metrics[1], - "objective/kl": global_metrics[2], - "objective/kl2": global_metrics[15], - "objective/kl3": global_metrics[16], - "objective/entropy": global_metrics[3], - "objective/non_score_reward": global_metrics[4], - "objective/rlhf_reward": global_metrics[5], - "objective/scores": global_metrics[6], - "policy/approxkl_avg": global_metrics[7], - "policy/clipfrac_avg": global_metrics[8], - "loss/policy_avg": global_metrics[9], - "loss/value_avg": global_metrics[10], - "val/clipfrac_avg": global_metrics[11], - "policy/entropy_avg": global_metrics[12], - "val/ratio": global_metrics[13], - "val/ratio_var": global_metrics[14], - "objective/verifiable_correct_rate": global_metrics[17], - } - if accelerator.is_main_process: - print_rich_single_line_metrics(metrics) - metrics_queue.put((metrics, episode, df)) - del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) - del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) - gc.collect() - torch.cuda.empty_cache() - # print(f"finished training {training_step}") - - # save steps - if args.save_freq > 0 and training_step % args.save_freq == 0: - checkpoint_dir = f"{args.output_dir}_checkpoints" - os.makedirs(checkpoint_dir, exist_ok=True) - step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") - os.makedirs(step_dir, exist_ok=True) - print(f"Saving model at step {training_step} to {step_dir}") - self.save_model(step_dir) - if args.try_launch_beaker_eval_jobs_on_weka: - self.launch_ai2_evals_on_weka(step_dir, training_step) - print(f"Saving final model at step {training_step} to {args.output_dir}") - self.save_model(args.output_dir) - if args.try_launch_beaker_eval_jobs_on_weka: - self.launch_ai2_evals_on_weka(args.output_dir) - - # Ai2 logic: we use /output to store the artifacts of the job, so we - # make a copy of the model to `/output` in the end. - if len(self.beaker_config.beaker_dataset_id_urls) > 0: - shutil.copytree(args.output_dir, "/output") - print("finished training") - - - def save_model(self, output_dir: str) -> None: - if self.rank == 0: - os.makedirs(output_dir, exist_ok=True) - - # save model weights for ZeRO2/3 - model_to_save = self.model - if hasattr(model_to_save, "module"): - model_to_save = model_to_save.module - - # gather parameters - output_state_dict = {} - for k, v in model_to_save.named_parameters(): - # only gather z3 params - params_to_fetch = _z3_params_to_fetch([v]) - with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): - vv = v.data.cpu() - if self.rank == 0: - output_state_dict[k] = vv - - if self.rank == 0: - state_dict = model_to_save.state_dict() - - # copy named_buffers with `persistent=True` - for k, v in model_to_save.named_buffers(): - if k not in state_dict: - continue - vv = v.data.cpu() - output_state_dict[k] = vv - - state_dict_keys = set(state_dict.keys()) - output_state_dict_keys = set(output_state_dict.keys()) - - # corner case for tie_word_embeddings, such as Qwen2-0.5B - if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: - state_dict_keys.remove("lm_head.weight") - - assert state_dict_keys.issubset( - output_state_dict_keys - ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" - - # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 - # if isinstance(model_to_save, PeftModel): - # model_to_save.save_pretrained(output_dir, **kwargs) - # if self.stage == 3: - # torch.save( - # get_peft_model_state_dict(model_to_save, output_state_dict), - # os.path.join(output_dir, "adapter_model.bin"), - # ) - # else: - # save model - model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) - - # save tokenizer - self.original_tokenizer.save_pretrained(output_dir) - - def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = None) -> None: - """auto eval the metrics as `f"{args.exp_name}_step_{training_step}"` in our leaderboard""" - args = self.args - beaker_config = self.beaker_config - model_config = self.model_config - wandb_url = self.wandb_url - # Ai2 specific logic - if is_beaker_job() and self.rank == 0: - if training_step is not None: - leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" - else: - leaderboard_name = args.hf_repo_revision - if args.hf_metadata_dataset: - dataset_list = list(args.dataset_mixer_dict.keys()) - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "ppo", - "datasets": dataset_list, - "base_model": model_config.model_name_or_path, - "wandb_path": wandb_url, - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + leaderboard_name, # to match what the auto-evals name as. - ) - - command = f"""\ -python scripts/submit_eval_jobs.py \ - --model_name {leaderboard_name} \ - --location {step_dir} \ - --cluster ai2/saturn-cirrascale \ - --is_tuned \ - --workspace "tulu-3-results" \ - --preemptible \ - --use_hf_tokenizer_template \ - --beaker_image "nathanl/open_instruct_auto" \ - --upload_to_hf allenai/tulu-3-evals \ - --run_oe_eval_experiments \ - --evaluate_on_weka \ - --run_safety_evaluations \ - --skip_oi_evals""" - if args.oe_eval_tasks is not None: - command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}" - print(f"Launching eval jobs with command: {command}") - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - -def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): - while True: - if stop_event.is_set(): - break - for ref in object_refs: - try: - ray.get(ref, timeout=0.01) - except ray.exceptions.GetTimeoutError: - pass - except Exception as e: - print(e) - print(f"Actor {ref} died") - time.sleep(120) - ray.shutdown() - os._exit(1) # Force shutdown the process - - time.sleep(30) - - -class ModelGroup: - def __init__( - self, - pg: PlacementGroup, - ray_process_cls: RayProcess, - num_gpus_per_node: List[int], - ): - self.pg = pg - self.ray_process_cls = ray_process_cls - self.num_gpus_per_node = num_gpus_per_node - self.num_gpus_per_actor = 1 - self.num_cpus_per_actor = 4 - self.models = [] - world_size = sum(self.num_gpus_per_node) - master_policy = ray_process_cls.options( - num_cpus=self.num_cpus_per_actor, - num_gpus=self.num_gpus_per_actor, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=self.pg, placement_group_bundle_index=0 - ), - ).remote(world_size, 0, 0, None, None) - - self.models.append(master_policy) - master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) - - def get_bundle_index(rank, num_gpus_per_node): - """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" - bundle_idx = 0 - while rank >= num_gpus_per_node[bundle_idx]: - rank -= num_gpus_per_node[bundle_idx] - bundle_idx += 1 - return bundle_idx - - assert get_bundle_index(0, [7, 8, 4]) == 0 - assert get_bundle_index(1, [7, 8, 4]) == 0 - assert get_bundle_index(7, [7, 8, 4]) == 1 - assert get_bundle_index(8, [7, 8, 4]) == 1 - assert get_bundle_index(9, [7, 8, 4]) == 1 - assert get_bundle_index(16, [7, 8, 4]) == 2 - - # Setup worker models - for rank in range(1, world_size): - print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=self.pg, - placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), - ) - worker_policy = ray_process_cls.options( - num_cpus=self.num_cpus_per_actor, - num_gpus=self.num_gpus_per_actor, - scheduling_strategy=scheduling_strategy, - ).remote(world_size, rank, 0, master_addr, master_port) - self.models.append(worker_policy) - - -def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): - calculate_runtime_args(args, model_config) - - # set up experiment tracking and seeds - all_configs = {} - beaker_config = None - if is_beaker_job(): - args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) - beaker_config = maybe_get_beaker_config() - all_configs.update(vars(beaker_config)) - all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) - if args.with_tracking: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=all_configs, - name=args.run_name, - save_code=True, - tags=[args.exp_name] + get_wandb_tags(), - ) - writer = SummaryWriter(f"runs/{args.run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - - # create a tokenizer (pad from right) - config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) - tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" - ) - if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: - tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding - tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] - - # create the dataset - dataset_dict = DatasetDict() - dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) - if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: - args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) - print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets.") - if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: - args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) - print(f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets.") - train_dataset = combine_dataset( - args.dataset_mixer_dict, - splits=args.dataset_train_splits, - columns_to_keep=[ - dataset_config.sft_messages_key, - dataset_config.ground_truths_key, - dataset_config.dataset_source_key, - ], - ) - if dataset_config.sanity_check: - train_dataset = train_dataset.select( - range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) - ) - train_dataset = dataset_processor.tokenize(train_dataset) - train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) - dataset_dict["train"] = train_dataset - eval_dataset = None - if args.dataset_eval_mixer is not None: - eval_dataset = combine_dataset( - args.dataset_eval_mixer_dict, - splits=args.dataset_eval_splits, - columns_to_keep=[ - dataset_config.sft_messages_key, - dataset_config.ground_truths_key, - dataset_config.dataset_source_key, - ], - ) - if dataset_config.sanity_check: - eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) - eval_dataset = dataset_processor.tokenize(eval_dataset) - eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) - dataset_dict["eval"] = eval_dataset - data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) - - # some more runtime logging - pprint([args, dataset_config, model_config]) - visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) - if args.with_tracking: - # upload the visualized token length - dataset_processor.get_token_length_visualization( - dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" - ) - wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) - - # create the model and optimizer - pg = None - bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.actor_num_gpus_per_node] - pg = placement_group(bundles, strategy="STRICT_SPREAD") - ray.get(pg.ready()) - - inits = [] - policy_group = ModelGroup( - pg, - PolicyTrainerRayProcess, - args.actor_num_gpus_per_node, - ) - wandb_url = wandb.run.get_url() if args.with_tracking else None - inits.extend(model.from_pretrained.remote(args, model_config, beaker_config, wandb_url) for model in policy_group.models) - max_len = dataset_config.max_prompt_token_length + args.response_length - vllm_engines = create_vllm_engines( - args.vllm_num_engines, - args.vllm_tensor_parallel_size, - model_config.model_name_or_path, - model_config.model_revision, - args.seed, - args.enable_prefix_caching, - max_len, - ) - - metrics_queue = RayQueue() - ray.get(inits) - print("======== all models initialized =========") - ray.get(policy_group.models[0].get_vocab_size.remote()) - # print(f"{policy_vocab_size=}, {reward_vocab_size=}") - # if policy_vocab_size != reward_vocab_size: - # ray.shutdown() # shutdown here so this error message is not buried in the logs - # raise ValueError( - # "Policy and reward model must have the same vocab size. " - # f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " - # "If they don't have the same vocab size, the policy could generate tokens which " - # "is going to cause index out of bound error in the reward model." - # ) - - refs = [] - for i, policy_model in enumerate(policy_group.models): - refs.append( - policy_model.train.remote( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - vllm_engines=vllm_engines, - metrics_queue=metrics_queue, - data_collator=data_collator, - ) - ) - - # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang - # so we need kill the ray cluster when this happens. - stop_event = threading.Event() - threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() - - # train and gather metrics - resume_training_step = 1 - for training_step in range(resume_training_step, args.num_training_steps + 1): - result = metrics_queue.get() - metrics, episode, df = result - for key, value in metrics.items(): - writer.add_scalar(key, value, episode) - - if df is not None: - if args.with_tracking: - wandb.log({"sample_completions": wandb.Table(dataframe=df)}) - # else: - # print_rich_table(df) - ray.get(refs) - - # save model - ray.shutdown() - stop_event.set() - - # Ai2 specific logic - if is_beaker_job(): - if args.hf_metadata_dataset: - dataset_list = list(args.dataset_mixer_dict.keys()) - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "sft", - "datasets": dataset_list, - "base_model": model_config.model_name_or_path, - "wandb_path": wandb.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.hf_repo_revision, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: - command = f"""\ - python mason.py \ - --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --model_name {args.hf_repo_revision} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - accelerator = Namespace() - accelerator.is_main_process = True # hack - if args.push_to_hub: - print("Pushing model to hub") - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - - # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. - # We use `--save_freq` to save intermediate checkpoints in the output folder instead. - if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): - shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) - - -if __name__ == "__main__": - parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) - main(*parser.parse()) diff --git a/open_instruct/ppo_vllm_thread_ray_old.py b/open_instruct/ppo_vllm_thread_ray_old.py deleted file mode 100644 index 822c8a18d..000000000 --- a/open_instruct/ppo_vllm_thread_ray_old.py +++ /dev/null @@ -1,1799 +0,0 @@ -# Copyright 2024 AllenAI. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# --------------------------------------------------------------------- -# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF -# which has the following license: -# Copyright [yyyy] [name of copyright owner] -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import gc -import json -import logging -import os -import random -import shutil -import socket -import subprocess -import threading -import time -from argparse import Namespace -from dataclasses import asdict, dataclass -from queue import Empty, Queue -from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple - -import deepspeed -import numpy as np -import pandas as pd -import ray -import torch -import torch.distributed as dist -import torch.nn.functional as F -import torch.utils -import torch.utils.data -import vllm -from datasets import Dataset, DatasetDict -from deepspeed.ops.adam import FusedAdam -from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus -from huggingface_hub import HfApi -from ray.util.placement_group import PlacementGroup, placement_group -from ray.util.queue import Queue as RayQueue -from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy -from rich.pretty import pprint -from torch.utils.tensorboard import SummaryWriter -from transformers import ( - AutoConfig, - AutoModelForCausalLM, - AutoModelForSequenceClassification, - AutoTokenizer, - PreTrainedModel, - PreTrainedTokenizer, - get_scheduler, -) -from transformers.deepspeed import HfDeepSpeedConfig -from vllm import SamplingParams - -from open_instruct.dataset_processor import ( - CHAT_TEMPLATES, - INPUT_IDS_PROMPT_KEY, - DatasetConfig, - SFTDatasetProcessor, - SimpleGenerateCollator, - visualize_token, -) -from open_instruct.model_utils import ( - ModelConfig, - disable_dropout_in_model, - exact_div, - first_true_indices, - forward, - get_reward, - print_rich_single_line_metrics, - push_folder_to_hub, - truncate_response, -) -from open_instruct.utils import ( - ArgumentParserPlus, - combine_dataset, - get_wandb_tags, - is_beaker_job, - maybe_get_beaker_config, - maybe_use_ai2_hf_entity, - maybe_use_ai2_wandb_entity, - upload_metadata_to_hf, -) -from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group - -api = HfApi() -INVALID_LOGPROB = 1.0 - - -@dataclass -class Args: - # required dataset args - dataset_mixer: str = None - """A dictionary of datasets (local or HF) to sample from.""" - dataset_train_splits: List[str] = None - """The dataset splits to use for training""" - dataset_eval_mixer: Optional[str] = None - """A dictionary of datasets (local or HF) to sample from for evaluation""" - dataset_eval_splits: Optional[List[str]] = None - """The dataset splits to use for evaluation""" - dataset_mixer_dict: Optional[dict] = None - """The dataset mixer as a dictionary""" - dataset_eval_mixer_dict: Optional[dict] = None - """The dataset eval mixer as a dictionary""" - - # common args - exp_name: str = os.path.basename(__file__)[: -len(".py")] - """The name of this experiment""" - seed: int = 1 - """Seed of the experiment""" - run_name: Optional[str] = None - """A unique name of this run""" - - # optimizer args - eps: float = 1e-5 - """The epsilon value for the optimizer""" - learning_rate: float = 2e-5 - """The initial learning rate for AdamW optimizer.""" - lr_scheduler_type: Literal[ - "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" - ] = "linear" - """Which scheduler to use""" - warm_up_steps: int = 0 - """Number of warm up steps for the scheduler""" - - # various batch sizes - num_train_epochs: int = 1 - """Number of epochs to train""" - gradient_accumulation_steps: Optional[int] = None - """The number of gradient accumulation steps""" - per_device_train_batch_size: Optional[int] = 1 - """The forward batch size per device (local_micro_batch_size)""" - per_device_eval_batch_size: Optional[int] = 1 - """The forward batch size per device for evaluation (local_micro_batch_size)""" - total_episodes: Optional[int] = 100000 - """The total number of episodes in the dataset""" - world_size: Optional[int] = None - """The number of processes (GPUs) to use""" - micro_batch_size: Optional[int] = None - """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" - local_rollout_batch_size: int = 64 - """The number of rollout episodes per iteration per device""" - rollout_batch_size: Optional[int] = None - """The number of rollout episodes per iteration""" - num_training_steps: Optional[int] = None - """The number of training_steps to train""" - num_evals: int = 4 - """The number of evaluations to run throughout training""" - eval_freq: Optional[int] = None - """The frequency of evaluation steps""" - local_dataloader_batch_size: Optional[int] = None - """The batch size per GPU for the dataloader""" - - # online settings - num_epochs: int = 4 - """the number of epochs to train""" - num_mini_batches: int = 1 - """Number of minibatches to split a batch into""" - local_mini_batch_size: int = 64 - """the mini batch size per GPU""" - mini_batch_size: Optional[int] = None - """the mini batch size across GPUs""" - local_rollout_forward_batch_size: int = 64 - """per rank no grad forward pass in the rollout phase""" - reward_model_path: str = "EleutherAI/pythia-160m" - """the path to the reward model""" - reward_model_revision: Optional[str] = None - """the revision of the reward model""" - - # generation config - response_length: int = 53 - """the length of the response""" - stop_token: Optional[Literal["eos", "period"]] = None - """the stop token""" - stop_token_id: Optional[int] = None - """the truncation token id""" - min_response_length: int = 0 - """stop only after this many tokens""" - temperature: float = 0.7 - """the sampling temperature""" - penalty_reward_value: float = -1.0 - """the reward value for responses that do not contain `stop_token_id`""" - non_stop_penalty: bool = False - """whether to penalize responses that do not contain `stop_token_id`""" - - # online PPO specific args - beta: float = 0.05 - """the beta value of the RLHF objective (KL coefficient)""" - whiten_rewards: bool = False - """whether to whiten the rewards""" - cliprange: float = 0.2 - """the clip range""" - vf_coef: float = 0.1 - """the value function coefficient""" - cliprange_value: float = 0.2 - """the clip range for the value function""" - gamma: float = 1 - """the discount factor""" - lam: float = 0.95 - """the lambda value for GAE""" - kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" - - # async setting - async_mode: bool = True - """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" - - # ray - actor_num_nodes: int = 1 - """number of nodes for actor""" - actor_num_gpus_per_node: int = 8 - """number of gpus per node for actor""" - ref_num_nodes: int = 1 - """number of nodes for reference""" - ref_num_gpus_per_node: int = 8 - """number of gpus per node for reference""" - colocate_actor_ref: bool = False - """whether to colocate reference and actor model, if true, they will share same gpus.""" - reward_num_nodes: int = 1 - """number of nodes for reward model""" - reward_num_gpus_per_node: int = 8 - """number of gpus per node for reward model""" - critic_num_nodes: int = 1 - """number of nodes for critic""" - critic_num_gpus_per_node: int = 8 - """number of gpus per node for critic""" - colocate_critic_reward: bool = False - """whether to colocate critic and reward model, if true, they will share same gpus.""" - vllm_num_engines: int = 1 - """number of vLLM Engines, set to 0 to disable vLLM""" - vllm_tensor_parallel_size: int = 1 - """tensor parallel size of vLLM Engine for multi-GPU inference""" - vllm_sync_backend: str = "nccl" - """DeepSpeed -> vLLM weight sync backend""" - enable_prefix_caching: bool = False - """whether to enable prefix caching""" - deepspeed_stage: int = 0 - - # wandb and HF tracking configs - with_tracking: bool = False - """If toggled, this experiment will be tracked with Weights and Biases""" - wandb_project_name: str = "open_instruct_internal" - """The wandb's project name""" - wandb_entity: Optional[str] = None - """The entity (team) of wandb's project""" - push_to_hub: bool = True - """Whether to upload the saved model to huggingface""" - hf_entity: Optional[str] = None - """The user or org name of the model repository from the Hugging Face Hub""" - hf_repo_id: Optional[str] = None - """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_revision: Optional[str] = None - """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" - hf_repo_url: Optional[str] = None - """The url of the saved model in the Hugging Face Hub (will be autoset)""" - output_dir: Optional[str] = None - """Where to save the model""" - checkpoint_output_dir: Optional[str] = None - """Where to save the model checkpoints in case of preemption""" - - # Ai2 specific settings - try_launch_beaker_eval_jobs: bool = True - """Whether to launch beaker evaluation jobs after training""" - hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" - """What dataset to upload the metadata to. If unset, don't upload metadata""" - - def __post_init__(self): - self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) - if self.dataset_eval_mixer is not None: - self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) - - -def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: - # if passed through cli: convert the dataset mixers to dictionaries - if isinstance(value, str): - return json.loads(value), value - # if passed through yaml: convert the dataset mixers to strings - elif isinstance(value, dict): - return value, json.dumps(value) - else: - raise ValueError("Input must be either a string or a dictionary") - - -def calculate_runtime_args(args: Args, model_config: ModelConfig): - """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" - # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) - # args.world_size = accelerator.num_processes - args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" - args.gradient_accumulation_steps = exact_div( - args.local_mini_batch_size, - args.per_device_train_batch_size, - "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", - ) - args.world_size = args.actor_num_gpus_per_node * args.actor_num_nodes - args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) - args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) - args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) - args.num_training_steps = args.total_episodes // args.rollout_batch_size - args.eval_freq = max(1, args.num_training_steps // args.num_evals) - # PPO logic: do checks and set up dataloader batch size - if args.whiten_rewards: - assert ( - args.local_mini_batch_size >= 8 - ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" - args.local_dataloader_batch_size = args.rollout_batch_size - if args.push_to_hub: - if args.hf_repo_id is None: # auto-generate one - args.hf_repo_id = "open_instruct_dev" - if args.hf_entity is None: # first try to use AI2 entity - args.hf_entity = maybe_use_ai2_hf_entity() - if args.hf_entity is None: # then try to use the user's entity - args.hf_entity = HfApi().whoami()["name"] - args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" - if args.hf_repo_revision is None: # auto-generate one - args.hf_repo_revision = args.run_name - args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" - - if args.with_tracking: - if args.wandb_entity is None: - args.wandb_entity = maybe_use_ai2_wandb_entity() - - -def get_train_ds_config( - offload, - adam_offload=False, - stage=0, - bf16=True, - max_norm=1.0, - zpg=8, - grad_accum_dtype=None, - disable_trace_cache=True, -): - device = "cpu" if offload else "none" - zero_opt_dict = { - "stage": stage, - "offload_param": {"device": device}, - "offload_optimizer": { - "device": "cpu" if adam_offload else "none", - "pin_memory": True, - }, - "sub_group_size": "auto", - "stage3_max_live_parameters": "auto", - "stage3_max_reuse_distance": "auto", - "stage3_param_persistence_threshold": "auto", - "stage3_prefetch_bucket_size": "auto", - "reduce_bucket_size": "auto", - # # ZeRO++ - # "zero_hpz_partition_size": zpg, - # "zero_quantized_weights": False, - # "zero_quantized_gradients": False, - } - if disable_trace_cache: - zero_opt_dict["stage3_prefetch_bucket_size"] = 0 - zero_opt_dict["stage3_max_live_parameters"] = 0 - zero_opt_dict["stage3_max_reuse_distance"] = 0 - - return { - "steps_per_print": 100, - "zero_optimization": zero_opt_dict, - "bf16": { - "enabled": bf16, - }, - "gradient_clipping": max_norm, - "prescale_gradients": False, - "wall_clock_breakdown": False, - "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, - } - - -def get_eval_ds_config( - offload, - stage=0, - bf16=True, -): - zero_opt_dict = { - "stage": stage, - "stage3_param_persistence_threshold": "auto", - "offload_param": { - "device": "cpu" if offload else "none", - "pin_memory": True, - }, - } - return { - "steps_per_print": 100, - "zero_optimization": zero_opt_dict, - "bf16": { - "enabled": bf16, - }, - "prescale_gradients": False, - "wall_clock_breakdown": False, - } - - -def get_optimizer_grouped_parameters( - model, - weight_decay, - no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], -): - optimizer_grouped_parameters = [ - { - "params": [ - p - for n, p in model.named_parameters() - if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) - ], - "weight_decay": weight_decay, - }, - { - "params": [ - p - for n, p in model.named_parameters() - if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) - ], - "weight_decay": 0.0, - }, - ] - return optimizer_grouped_parameters - - -def _z3_params_to_fetch(param_list): - return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] - - -def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" - if axis is not None: - return (values * mask).sum(axis=axis) / mask.sum(axis=axis) - else: - return (values * mask).sum() / mask.sum() - - -def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: - """Compute variance of tensor with masked values.""" - mean = masked_mean(values, mask) - centered_values = values - mean - variance = masked_mean(centered_values**2, mask) - if unbiased: - mask_sum = mask.sum() - if mask_sum == 0: - raise ValueError( - "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" - "try increase the `mini_batch_size` or `gradient_accumulation_steps`" - ) - # note that if mask_sum == 1, then there is a division by zero issue - # to avoid it you just need to use a larger minibatch_size - bessel_correction = mask_sum / (mask_sum - 1) - variance = variance * bessel_correction - return variance - - -def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: - """Whiten values with masked values.""" - mean, var = masked_mean(values, mask), masked_var(values, mask) - whitened = (values - mean) * torch.rsqrt(var + 1e-8) - if not shift_mean: - whitened += mean - return whitened - - -def remove_padding(sequences, pad_token_id): - return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] - - -class ShufflingIterator: - def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): - self.data = data.copy() - self.batch_size = batch_size - self.index = 0 - self.rng = np.random.default_rng(seed) - self.rng.shuffle(self.data) - - # Ensure the effective dataset size is divisible by batch_size - self.effective_size = len(self.data) - (len(self.data) % batch_size) - - def __iter__(self) -> Iterator[List[int]]: - return self - - def __next__(self) -> List[int]: - if self.index >= self.effective_size: - self.index = 0 - self.rng.shuffle(self.data) - - end_index = self.index + self.batch_size - batch = self.data[self.index : end_index].tolist() - self.index = end_index - - return batch - - -class RayProcess: - def __init__(self, world_size, rank, local_rank, master_addr, master_port): - logging.basicConfig( - format="%(asctime)s %(levelname)-8s %(message)s", - level=logging.INFO, - datefmt="%Y-%m-%d %H:%M:%S", - ) - self.world_size = world_size - self.rank = rank - self.local_rank = local_rank - self.master_addr = master_addr if master_addr else self.get_current_node_ip() - self.master_port = master_port if master_port else self.get_free_port() - os.environ["MASTER_ADDR"] = self.master_addr - os.environ["MASTER_PORT"] = str(self.master_port) - os.environ["WORLD_SIZE"] = str(self.world_size) - os.environ["RANK"] = str(self.rank) - # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES - # environment variable for each actor, so always set device to 0 - # os.environ["LOCAL_RANK"] = str(self._local_rank) - os.environ["LOCAL_RANK"] = "0" - random.seed(self.rank) - np.random.seed(self.rank) - torch.manual_seed(self.rank) - - @staticmethod - def get_current_node_ip(): - address = ray._private.services.get_node_ip_address() - # strip ipv6 address - return address.strip("[]") - - @staticmethod - def get_free_port(): - with socket.socket() as sock: - sock.bind(("", 0)) - return sock.getsockname()[1] - - def get_master_addr_port(self): - return self.master_addr, self.master_port - - def empty_cache(self) -> None: - torch.cuda.empty_cache() - - -@ray.remote(num_gpus=1) -class PolicyTrainerRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): - self.args = args - self.num_gpu_per_node = num_gpus_per_node - self.num_nodes = num_nodes - torch.cuda.set_device(self.local_rank) - deepspeed.init_distributed() - - ds_config = get_train_ds_config( - offload=False, - adam_offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - # Costa: MAGIC: it's actually needed to initialize this `dschf`, so - # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration - # next line instructs transformers to partition the model directly over multiple gpus using - # deepspeed.zero.Init when model's `from_pretrained` method is called. - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - print(f"{dschf=}") - - self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.policy) - self.policy.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam - AdamOptimizer = FusedAdam - weight_decay = 0.0 - optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) - self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) - # self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) - scheduler = get_scheduler( - args.lr_scheduler_type, - optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, - ) - print(ds_config) - self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( - model=self.policy, - optimizer=self.optimizer, - config=ds_config, - lr_scheduler=scheduler, - dist_init_required=True, - ) - self.model.train() - - def get_vocab_size(self): - return self.policy.config.vocab_size - - def forward( - self, - query_response: torch.LongTensor, - response: torch.LongTensor, - pad_token_id: int, - context_length: int, - temperature: float, - ) -> torch.Tensor: - output = forward(self.model, query_response, pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= temperature + 1e-7 - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - return logprob - - def train( - self, - train_dataset: Dataset, - eval_dataset: Dataset, - tokenizer: PreTrainedTokenizer, - value_model: ray.actor.ActorHandle, - ref_model: ray.actor.ActorHandle, - reward_model: ray.actor.ActorHandle, - vllm_engines: List[ray.actor.ActorHandle], - metrics_queue: RayQueue, - data_collator: Callable, - ): - torch.set_printoptions(precision=4, sci_mode=False) - - args = self.args - - accelerator = Namespace() - accelerator.process_index = self.rank - accelerator.num_processes = self.world_size - accelerator.is_main_process = self.rank == 0 - torch.distributed.barrier() - if self.rank == 0: - master_address = ray._private.services.get_node_ip_address() - with socket.socket() as sock: - sock.bind(("", 0)) - master_port = sock.getsockname()[1] - vllm_num_engines, vllm_tensor_parallel_size = ( - args.vllm_num_engines, - args.vllm_tensor_parallel_size, - ) - world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 - backend = args.vllm_sync_backend - # https://github.com/OpenRLHF/OpenRLHF/issues/313 - if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": - backend = "gloo" - print( - "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" - ) - refs = [ - engine.init_process_group.remote( - master_address, - master_port, - i * vllm_tensor_parallel_size + 1, - world_size, - "openrlhf", - backend=backend, - ) - for i, engine in enumerate(vllm_engines) - ] - self.model_update_group = init_process_group( - backend=backend, - init_method=f"tcp://{master_address}:{master_port}", - world_size=world_size, - rank=0, - group_name="openrlhf", - ) - ray.get(refs) - torch.distributed.barrier() - - def broadcast_to_vllm(): - # avoid OOM - torch.cuda.empty_cache() - model = self.model.module - count, num_params = 0, len(list(model.named_parameters())) - refss = [] - for name, param in model.named_parameters(): - count += 1 # empty_cache at last param - - # Fire all vllm engines for broadcast - if torch.distributed.get_rank() == 0: - shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape - # print(f"broadcasting {name=} {shape=}") - refs = [ - engine.update_weight.remote( - name, dtype=param.dtype, shape=shape, empty_cache=count == num_params - ) - for engine in vllm_engines - ] - refss.extend(refs) - # For ZeRO-3, allgather sharded parameter and broadcast to all vllm engines by rank 0 - with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): - if torch.distributed.get_rank() == 0: - torch.distributed.broadcast(param.data, 0, group=self.model_update_group) - # ray.get(refs) - # print(f"broadcasting {name=} {shape=} success") - if torch.distributed.get_rank() == 0: - ray.get(refss) - - # broadcast_to_vllm() - print(f"broadcasted to vllm finished {self.rank=} {self.local_rank=}, {self.world_size=}") - if args.stop_token: - if args.stop_token == "eos": - args.stop_token_id = tokenizer.eos_token_id - if args.stop_token == "period": - args.stop_token_id = tokenizer.encode(".")[0] - # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) - train_dataset_idxs = np.arange(len(train_dataset)) - shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) - print(f"2broadcasted to vllm finished {self.rank=} {self.local_rank=}, {self.world_size=}") - - # hack to left pad - def repeat_generator(): - while True: - batch_idxs = next(shuffling_iter) - yield [train_dataset[i] for i in batch_idxs] - - iter_dataloader = iter(repeat_generator()) - generation_config = SamplingParams( - temperature=args.temperature, - top_p=1.0, - max_tokens=args.response_length, - include_stop_str_in_output=True, - ) - print("setup async queues") - param_prompt_Q = None - response_ids_Q = None - evaluation_Q = None - response_ids_Q = Queue(maxsize=1) - param_prompt_Q = Queue(maxsize=1) - evaluation_Q = Queue(maxsize=1) - num_eval_samples = 32 - sample_evaluation_prompt_token_ids = None - if eval_dataset is not None: - sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] - - def vllm_generate( - generation_config: SamplingParams, - response_ids_Q: Queue, - param_prompt_Q: Queue, - num_training_steps: int, - sample_evaluation_prompt_token_ids: Optional[List[int]], - evaluation_Q: Queue, - eval_freq: int, - resume_training_step: int, - ): - llm = vllm_engines[0] - for training_step in range(resume_training_step, num_training_steps + 1): - items = param_prompt_Q.get() - if items is None: - break - unwrapped_model, g_queries_list = items - # if unwrapped_model is not None: - generation_start_time = time.time() - - outputs = ray.get( - llm.generate.remote(sampling_params=generation_config, prompt_token_ids=g_queries_list) - ) - response_ids = [list(output.outputs[0].token_ids) for output in outputs] - print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") - response_ids_Q.put(response_ids) - - if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: - outputs = ray.get( - llm.generate.remote( - prompt_token_ids=sample_evaluation_prompt_token_ids, sampling_params=generation_config - ) - ) - response_ids = [list(output.outputs[0].token_ids) for output in outputs] - evaluation_Q.put(response_ids) - - resume_training_step = 1 - if accelerator.is_main_process: - thread = threading.Thread( - target=vllm_generate, - args=( - generation_config, - response_ids_Q, - param_prompt_Q, - args.num_training_steps, - sample_evaluation_prompt_token_ids, - evaluation_Q, - args.eval_freq, - resume_training_step, - ), - ) - thread.start() - print("vllm generate thread starts") - - # set up the metrics and initial states - device = torch.device(self.local_rank) - g_vllm_responses = torch.zeros( - (args.rollout_batch_size, args.response_length), device=device, dtype=torch.long - ) - stats_shape = (args.num_epochs, args.num_mini_batches, args.gradient_accumulation_steps) - approxkl_stats = torch.zeros(stats_shape, device=device) - pg_clipfrac_stats = torch.zeros(stats_shape, device=device) - pg_loss_stats = torch.zeros(stats_shape, device=device) - vf_loss_stats = torch.zeros(stats_shape, device=device) - vf_clipfrac_stats = torch.zeros(stats_shape, device=device) - entropy_stats = torch.zeros(stats_shape, device=device) - ratio_stats = torch.zeros(stats_shape, device=device) - local_metrics = torch.zeros((20,), device=device) - episode = args.rollout_batch_size * (resume_training_step - 1) - - # training loop - start_time = time.time() - global_data = next(iter_dataloader) - data = data_collator( - global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] - ) - global_queries = data_collator(global_data)[ - INPUT_IDS_PROMPT_KEY - ].tolist() # can be simplified since we `remove_padding` later anyway - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - - # for _ in range(1, resume_training_step): # we didn't store scheduler state - # scheduler.step() - - for training_step in range(resume_training_step, args.num_training_steps + 1): - episode += args.rollout_batch_size - queries = queries_next - - if accelerator.is_main_process: - df = None - try: - evaluation_responses = evaluation_Q.get(timeout=0.01) - print("🔥🔥🔥 Evaluation responses received") - table = {} - table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) - table["response"] = tokenizer.batch_decode(evaluation_responses) - table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] - df = pd.DataFrame(table) - del table - except Empty: - print("🙈 Evaluation responses not received") - - # (optionally) evaluate the model - if args.async_mode: - if training_step != 1: - global_data = next(iter_dataloader) - data = data_collator( - global_data[ - self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size - ] - ) - global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - - start_time = time.time() - broadcast_to_vllm() - print( - f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" - ) - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - else: - if training_step != 1: - # NOTE: important: the indent here is different for sync mode - # we also set to use `queries = queries_next` immediately - global_data = next(iter_dataloader) - data = data_collator( - global_data[ - self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size - ] - ) - global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() - queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) - start_time = time.time() - broadcast_to_vllm() - print( - f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" - ) - if accelerator.is_main_process: - param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) - queries = queries_next - - # print('get reward stuff starts') - training_time_start = time.time() - with torch.no_grad(): - context_length = queries.shape[1] - responses = [] - postprocessed_responses = [] - logprobs = [] - ref_logprobs = [] - scores = [] - sequence_lengths = [] - values = [] - if accelerator.is_main_process: - g_response_token_ids = response_ids_Q.get() - DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out - g_padded_response_ids = [ - response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) - for response in g_response_token_ids - ] - g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) - print(f"{g_padded_response_ids.shape=}") - print(f"{g_vllm_responses.shape=}") - g_vllm_responses[:] = g_padded_response_ids - dist.broadcast(g_vllm_responses, src=0) - local_vllm_responses = g_vllm_responses[ - accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] - ] - # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") - query_responses = torch.cat((queries, local_vllm_responses), 1) - for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): - print(f"get reward stuff starts {i=}") - query = queries[i : i + args.local_rollout_forward_batch_size] - query_response = query_responses[i : i + args.local_rollout_forward_batch_size] - response = query_response[:, context_length:] - - # 1. launch ref model future - ref_logprob_future = ref_model.forward.remote( - query_response, response, tokenizer.pad_token_id, context_length, args.temperature - ) - - # 2. launch reward model future - # Response Processing 1. truncate response after the first occurrence of `stop_token_id` - postprocessed_response = response - if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 - postprocessed_response = truncate_response( - args.stop_token_id, tokenizer.pad_token_id, response - ) - # print("get reward stuff starts 2") - # Response Processing 2. run reward model on the truncated responses - postprocessed_query_response = torch.cat((query, postprocessed_response), 1) - sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 - reward_future = reward_model.forward.remote( - postprocessed_query_response, tokenizer.pad_token_id, context_length - ) - - # print("get reward stuff starts 3") - # 3. launch value model future - value_future = value_model.forward.remote(query_response, tokenizer.pad_token_id, context_length) - - # 4. do local forward pass - logprob = self.forward( - query_response, response, tokenizer.pad_token_id, context_length, args.temperature - ) - torch.cuda.empty_cache() - - # print("get reward stuff starts 4") - # 5. get results from futures - _, score, _ = ray.get(reward_future) - # print(f"{score.shape=}") - full_value, _, _ = ray.get(value_future) - # print(f"{full_value.shape=}") - ref_logprob = ray.get(ref_logprob_future) - # print(f"{ref_logprob.shape=}") - if args.colocate_critic_reward: - ray.get([value_model.empty_cache.remote()]) - ray.get([reward_model.empty_cache.remote()]) - if args.colocate_actor_ref: - ray.get([ref_model.empty_cache.remote()]) - value = full_value[:, context_length - 1 : -1].squeeze(-1) - responses.append(response) - postprocessed_responses.append(postprocessed_response) - logprobs.append(logprob) - ref_logprobs.append(ref_logprob) - sequence_lengths.append(sequence_length) - scores.append(score) - values.append(value) - # print(f"get reward stuff starts 5") - responses = torch.cat(responses, 0) - postprocessed_responses = torch.cat(postprocessed_responses, 0) - logprobs = torch.cat(logprobs, 0) - ref_logprobs = torch.cat(ref_logprobs, 0) - sequence_lengths = torch.cat(sequence_lengths, 0) - scores = torch.cat(scores, 0) - values = torch.cat(values, 0) - # print(f"get reward stuff finished") - del (logprob, ref_logprob, full_value, value, score) - gc.collect() - torch.cuda.empty_cache() - - # Response Processing 3. filter response. Ensure that the sample contains stop_token_id - # responses not passing that filter will receive a low (fixed) score - # only query humans on responses that pass that filter - contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) - # NOTE: only apply the stop token filter if the response is long enough - # otherwise the model could learn to generate the first token as the stop token - contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) - if args.non_stop_penalty: - scores = torch.where( - contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) - ) - - # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw - response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) - padding_mask = response_idxs > sequence_lengths.unsqueeze(1) - logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) - ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) - sequence_lengths_p1 = sequence_lengths + 1 - padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) - values = torch.masked_fill(values, padding_mask_p1, 0) - # print(f"get reward stuff finished 2") - - # 4. compute rewards - kl1 = logprobs - ref_logprobs - kl2 = (kl1) ** 2 / 2 - kl3 = (-kl1).exp() - 1 + kl1 - if args.kl_estimator == "kl1": - kl = kl1 - elif args.kl_estimator == "kl2": - kl = kl2 - elif args.kl_estimator == "kl3": - kl = kl3 - # if self.rank==0: - # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") - non_score_reward = -args.beta * kl - non_score_reward_sum = non_score_reward.sum(1) - rlhf_reward = scores + non_score_reward_sum - rewards = non_score_reward.clone() - actual_start = torch.arange(rewards.size(0), device=rewards.device) - actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) - rewards[[actual_start, actual_end]] += scores - # print(f"get reward stuff finished 3") - - # 5. whiten rewards - if args.whiten_rewards: - rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) - rewards = torch.masked_fill(rewards, padding_mask_p1, 0) - - # print('gae') - # 6. compute advantages and returns - lastgaelam = 0 - advantages_reversed = [] - gen_length = responses.shape[1] - for t in reversed(range(gen_length)): - nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 - delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] - lastgaelam = delta + args.gamma * args.lam * lastgaelam - advantages_reversed.append(lastgaelam) - advantages = torch.stack(advantages_reversed[::-1], axis=1) - returns = advantages + values - advantages = masked_whiten(advantages, ~padding_mask) - advantages = torch.masked_fill(advantages, padding_mask, 0) - torch.cuda.empty_cache() - - # print('training starts') - # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch - for epoch_idx in range(args.num_epochs): - b_inds = np.random.permutation(args.local_rollout_batch_size) - minibatch_idx = 0 - for mini_batch_start in range(0, args.local_rollout_batch_size, args.local_mini_batch_size): - mini_batch_end = mini_batch_start + args.local_mini_batch_size - mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] - gradient_accumulation_idx = 0 - for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): - print("micro batch start", micro_batch_start, self.rank) - micro_batch_end = micro_batch_start + args.per_device_train_batch_size - micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] - mb_advantage = advantages[micro_batch_inds] - mb_responses = responses[micro_batch_inds] - mb_query_responses = query_responses[micro_batch_inds] - mb_logprobs = logprobs[micro_batch_inds] - mb_return = returns[micro_batch_inds] - mb_values = values[micro_batch_inds] - mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] - - value_model_step_future = value_model.step.remote( - mb_query_responses, - tokenizer.pad_token_id, - context_length, - mb_padding_mask_p1, - mb_return, - mb_values, - args.cliprange_value, - args.vf_coef, - ) - new_logprobs = self.forward( - mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature - ) - # if self.rank==0: - # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") - new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) - logprobs_diff = new_logprobs - mb_logprobs - ratio = torch.exp(logprobs_diff) - pg_losses = -mb_advantage * ratio - pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) - pg_loss_max = torch.max(pg_losses, pg_losses2) - pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) - loss = pg_loss - self.model.backward(loss) - # print("backward loss", self.rank, "micro batch start", micro_batch_start) - # print("trying to step", self.rank, "micro batch start", micro_batch_start) - self.model.step() - # print("step", self.rank, "micro batch start", micro_batch_start) - with torch.no_grad(): - # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) - vf_loss, vf_clipfrac = ray.get(value_model_step_future) - pg_clipfrac = masked_mean( - (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] - ) - # print("value model stepped", self.rank, "micro batch start", micro_batch_start) - # prob_dist = torch.nn.functional.softmax(logits, dim=-1) - # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) - approxkl = 0.5 * (logprobs_diff**2).mean() - approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl - pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac - pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss - vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss - vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac - # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() - ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() - gradient_accumulation_idx += 1 - minibatch_idx += 1 - # fmt: off - del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 - del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss - # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl - # fmt: on - # del everything and empty cache - torch.cuda.empty_cache() - del b_inds, mini_batch_inds - # print("start metrics") - with torch.no_grad(): - local_metrics[0] = sequence_lengths.float().mean() - local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() - local_metrics[2] = kl.sum(1).mean() - local_metrics[3] = (-logprobs).sum(1).mean() - local_metrics[4] = non_score_reward_sum.mean() - local_metrics[5] = rlhf_reward.mean() - local_metrics[6] = scores.mean() - local_metrics[7] = approxkl_stats.mean() - local_metrics[8] = pg_clipfrac_stats.mean() - local_metrics[9] = pg_loss_stats.mean() - local_metrics[10] = vf_loss_stats.mean() - local_metrics[11] = vf_clipfrac_stats.mean() - local_metrics[12] = entropy_stats.mean() - local_metrics[13] = ratio_stats.mean() - local_metrics[14] = ratio_stats.var() - local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() - local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() - # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() - local_metrics /= dist.get_world_size() - dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) - global_metrics = local_metrics.tolist() - metrics = { - "episode": episode, - "training_step": training_step, - "lr": self.scheduler.get_last_lr()[0], - "epoch": episode / len(train_dataset), - "time/from_scratch": time.time() - start_time, - "time/training": time.time() - training_time_start, - "val/sequence_lengths": global_metrics[0], - "val/num_stop_token_ids": global_metrics[1], - "objective/kl": global_metrics[2], - "objective/kl2": global_metrics[15], - "ojbective/kl3": global_metrics[16], - "objective/entropy": global_metrics[3], - "objective/non_score_reward": global_metrics[4], - "objective/rlhf_reward": global_metrics[5], - "objective/scores": global_metrics[6], - "policy/approxkl_avg": global_metrics[7], - "policy/clipfrac_avg": global_metrics[8], - "loss/policy_avg": global_metrics[9], - "loss/value_avg": global_metrics[10], - "val/clipfrac_avg": global_metrics[11], - "policy/entropy_avg": global_metrics[12], - "val/ratio": global_metrics[13], - "val/ratio_var": global_metrics[14], - } - if accelerator.is_main_process: - print_rich_single_line_metrics(metrics) - metrics_queue.put((metrics, episode, df)) - del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) - del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) - gc.collect() - torch.cuda.empty_cache() - print(f"finished training {training_step}") - print("finished training") - - def save_model(self, tokenizer: PreTrainedTokenizer, output_dir: str) -> None: - if self.rank == 0: - os.makedirs(output_dir, exist_ok=True) - - # save model weights for ZeRO2/3 - model_to_save = self.model - if hasattr(model_to_save, "module"): - model_to_save = model_to_save.module - - # gather parameters - output_state_dict = {} - for k, v in model_to_save.named_parameters(): - # only gather z3 params - params_to_fetch = _z3_params_to_fetch([v]) - with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): - vv = v.data.cpu() - if self.rank == 0: - output_state_dict[k] = vv - - if self.rank == 0: - state_dict = model_to_save.state_dict() - - # copy named_buffers with `persistent=True` - for k, v in model_to_save.named_buffers(): - if k not in state_dict: - continue - vv = v.data.cpu() - output_state_dict[k] = vv - - state_dict_keys = set(state_dict.keys()) - output_state_dict_keys = set(output_state_dict.keys()) - - # corner case for tie_word_embeddings, such as Qwen2-0.5B - if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: - state_dict_keys.remove("lm_head.weight") - - assert state_dict_keys.issubset( - output_state_dict_keys - ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" - - # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 - # if isinstance(model_to_save, PeftModel): - # model_to_save.save_pretrained(output_dir, **kwargs) - # if self.stage == 3: - # torch.save( - # get_peft_model_state_dict(model_to_save, output_state_dict), - # os.path.join(output_dir, "adapter_model.bin"), - # ) - # else: - # save model - model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) - - # save tokenizer - tokenizer.save_pretrained(output_dir) - - -@ray.remote(num_gpus=1) -class ReferenceModelRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): - self.args = args - self.num_gpu_per_node = num_gpus_per_node - self.num_nodes = num_nodes - torch.cuda.set_device(self.local_rank) - deepspeed.init_distributed() - ds_config = get_eval_ds_config( - offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - # Costa: MAGIC: it's actually needed to initialize this `dschf`, so - # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration - # next line instructs transformers to partition the model directly over multiple gpus using - # deepspeed.zero.Init when model's `from_pretrained` method is called. - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - print(f"{dschf=}") - - self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( - model_config.model_name_or_path, - revision=model_config.model_revision, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.policy) - self.model, *_ = deepspeed.initialize(model=self.policy, config=ds_config, dist_init_required=True) - self.model.eval() - - def forward( - self, - query_response: torch.LongTensor, - response: torch.LongTensor, - pad_token_id: int, - context_length: int, - temperature: float, - ) -> torch.Tensor: - output = forward(self.model, query_response, pad_token_id) - logits = output.logits[:, context_length - 1 : -1] - logits /= temperature + 1e-7 - all_logprob = F.log_softmax(logits, dim=-1) - logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) - return logprob - - -@ray.remote(num_gpus=1) -class ValueTrainerRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): - self.args = args - self.num_gpu_per_node = num_gpus_per_node - self.num_nodes = num_nodes - torch.cuda.set_device(self.local_rank) - deepspeed.init_distributed() - - ds_config = get_train_ds_config( - offload=False, - adam_offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - # Costa: MAGIC: it's actually needed to initialize this `dschf`, so - # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration - # next line instructs transformers to partition the model directly over multiple gpus using - # deepspeed.zero.Init when model's `from_pretrained` method is called. - if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: - dschf = HfDeepSpeedConfig(ds_config) - else: - dschf = None - print(f"{dschf=}") - - self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( - args.reward_model_path, - revision=args.reward_model_revision, - num_labels=1, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.value_model) - self.value_model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False}) - # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam - AdamOptimizer = FusedAdam - weight_decay = 0.0 - optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) - self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) - # self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) - scheduler = get_scheduler( - args.lr_scheduler_type, - optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, - ) - # print(ds_config) - self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( - model=self.value_model, - optimizer=self.optimizer, - config=ds_config, - lr_scheduler=scheduler, - dist_init_required=True, - ) - self.model.train() - - def forward( - self, query_responses: torch.Tensor, pad_token_id: int, context_length: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return get_reward(self.value_model, query_responses, pad_token_id, context_length) - - def step( - self, - query_responses: torch.Tensor, - pad_token_id: int, - context_length: int, - mb_padding_mask_p1: torch.Tensor, - mb_return: torch.Tensor, - mb_values: torch.Tensor, - cliprange_value: float, - vf_coef: float, - ) -> None: - torch.cuda.empty_cache() - vpred_temp = self.forward(query_responses, pad_token_id, context_length) - vpred_temp = vpred_temp[0] - vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) - vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) - vpredclipped = torch.clamp( - vpred, - mb_values - cliprange_value, - mb_values + cliprange_value, - ) - vf_losses1 = torch.square(vpred - mb_return) - vf_losses2 = torch.square(vpredclipped - mb_return) - vf_loss_max = torch.max(vf_losses1, vf_losses2) - vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) - self.model.backward(vf_loss * vf_coef) - self.model.step() - with torch.no_grad(): - vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) - del (vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max) - return vf_loss, vf_clipfrac - - -@ray.remote(num_gpus=1) -class RewardModelRayProcess(RayProcess): - def from_pretrained(self, args: Args, model_config: ModelConfig, num_gpus_per_node: int, num_nodes: int): - deepspeed.init_distributed() - self.num_gpu_per_node = num_gpus_per_node - self.num_nodes = num_nodes - self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( - args.reward_model_path, - revision=args.reward_model_revision, - num_labels=1, - torch_dtype=torch.bfloat16, - attn_implementation="flash_attention_2", - use_cache=False, - ) - disable_dropout_in_model(self.reward_model) - ds_config = get_eval_ds_config( - offload=False, - stage=args.deepspeed_stage, - bf16=True, - ) - ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size - ds_config["train_batch_size"] = args.mini_batch_size - self.model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config, dist_init_required=True) - self.model.eval() - - def forward( - self, query_responses: torch.Tensor, pad_token_id: int, context_length: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return get_reward(self.reward_model, query_responses, pad_token_id, context_length) - - def get_vocab_size(self): - return self.reward_model.config.vocab_size - - -def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): - while True: - if stop_event.is_set(): - break - for ref in object_refs: - try: - ray.get(ref, timeout=0.01) - except ray.exceptions.GetTimeoutError: - pass - except ray.exceptions.ActorDiedError as e: - ray.shutdown() - print(f"Actor {ref} died") - print(e) - os._exit(1) # Force shutdown the process - - time.sleep(30) - - -class ModelGroup: - def __init__( - self, - pg: PlacementGroup, - ray_process_cls: RayProcess, - num_gpus_per_actor: int, - num_gpus_per_node: int, - num_nodes: int, - ): - self.pg = pg - self.ray_process_cls = ray_process_cls - self.num_gpus_per_actor = num_gpus_per_actor - self.num_gpus_per_node = num_gpus_per_node - self.num_nodes = num_nodes - self.models = [] - - world_size = num_gpus_per_node * num_nodes - if self.num_gpus_per_actor > 1 and self.pg is None: - bundles = [{"GPU": self.num_gpus_per_actor, "CPU": self.num_gpus_per_actor} for _ in range(self.num_nodes)] - - self.pg = placement_group(bundles, strategy="STRICT_SPREAD") - ray.get(self.pg.ready()) - if self.pg: - master_policy = ray_process_cls.options( - num_cpus=num_gpus_per_actor, - num_gpus=num_gpus_per_actor, - scheduling_strategy=PlacementGroupSchedulingStrategy( - placement_group=self.pg, placement_group_bundle_index=0 - ), - ).remote(world_size, 0, 0, None, None) - else: - master_policy = ray_process_cls.options( - num_cpus=num_gpus_per_actor, - num_gpus=num_gpus_per_actor, - ).remote(world_size, 0, 0, None, None) - - self.models.append(master_policy) - master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) - - # Setup worker models - for rank in range(1, world_size): - print(f"{rank=}, {world_size, rank, 0, master_addr, master_port=}") - scheduling_strategy = None - if pg is not None: - scheduling_strategy = PlacementGroupSchedulingStrategy( - placement_group=self.pg, - placement_group_bundle_index=rank // self.num_gpus_per_node, - ) - worker_policy = ray_process_cls.options( - num_cpus=self.num_gpus_per_actor, - num_gpus=self.num_gpus_per_actor, - scheduling_strategy=scheduling_strategy, - ).remote(world_size, rank, 0, master_addr, master_port) - self.models.append(worker_policy) - - -def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): - calculate_runtime_args(args, model_config) - - # set up experiment tracking and seeds - all_configs = {} - if is_beaker_job(): - args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", args.output_dir) - beaker_config = maybe_get_beaker_config() - # try saving to the beaker `/output`, which will be uploaded to the beaker dataset - if len(beaker_config.beaker_dataset_id_urls) > 0: - args.output_dir = "/output" - all_configs.update(vars(beaker_config)) - all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) - if args.with_tracking: - import wandb - - wandb.init( - project=args.wandb_project_name, - entity=args.wandb_entity, - sync_tensorboard=True, - config=all_configs, - name=args.run_name, - save_code=True, - tags=[args.exp_name] + get_wandb_tags(), - ) - writer = SummaryWriter(f"runs/{args.run_name}") - writer.add_text( - "hyperparameters", - "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), - ) - - # create a tokenizer (pad from right) - config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) - tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" - ) - if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: - tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> - else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding - tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] - - # create the dataset - dataset_dict = DatasetDict() - dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config) - train_dataset = combine_dataset( - args.dataset_mixer_dict, - splits=args.dataset_train_splits, - columns_to_keep=[dataset_config.sft_messages_key], - ) - if dataset_config.sanity_check: - train_dataset = train_dataset.select( - range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) - ) - train_dataset = dataset_processor.tokenize(train_dataset) - train_dataset = dataset_processor.filter(train_dataset) - dataset_dict["train"] = train_dataset - eval_dataset = None - if args.dataset_eval_mixer is not None: - eval_dataset = combine_dataset( - args.dataset_eval_mixer_dict, - splits=args.dataset_eval_splits, - columns_to_keep=[dataset_config.sft_messages_key], - ) - eval_dataset = eval_dataset.select(range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples))) - eval_dataset = dataset_processor.tokenize(eval_dataset) - eval_dataset = dataset_processor.filter(eval_dataset) - dataset_dict["eval"] = eval_dataset - data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) - - # some more runtime logging - pprint([args, dataset_config, model_config]) - visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) - if args.with_tracking: - # upload the visualized token length - dataset_processor.get_token_length_visualization( - dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" - ) - wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) - - # create the model and optimizer - pg = None - if args.colocate_actor_ref: - assert ( - args.actor_num_nodes == args.ref_num_nodes and args.actor_num_gpus_per_node == args.ref_num_gpus_per_node - ), "num_nodes and num_gpus_per_node must be the same when colocate actor and ref model." - - bundles = [ - {"GPU": args.actor_num_gpus_per_node, "CPU": args.actor_num_gpus_per_node} - for _ in range(args.actor_num_nodes) - ] - pg = placement_group(bundles, strategy="STRICT_SPREAD") - ray.get(pg.ready()) - - inits = [] - policy_group = ModelGroup( - pg, - PolicyTrainerRayProcess, - 0.75 if args.colocate_actor_ref else 1, - args.actor_num_gpus_per_node, - args.actor_num_nodes, - ) - inits.extend( - model.from_pretrained.remote(args, model_config, args.actor_num_gpus_per_node, args.actor_num_nodes) - for model in policy_group.models - ) - ref_model_group = ModelGroup( - pg, - ReferenceModelRayProcess, - 0.25 if args.colocate_actor_ref else 1, - args.ref_num_gpus_per_node, - args.ref_num_nodes, - ) - inits.extend( - model.from_pretrained.remote(args, model_config, args.ref_num_gpus_per_node, args.ref_num_nodes) - for model in ref_model_group.models - ) - - # if colocated, create placement group for critic and reward model explicitly. - pg = None - if args.colocate_critic_reward: - assert ( - args.critic_num_nodes == args.reward_num_nodes - and args.critic_num_gpus_per_node == args.reward_num_gpus_per_node - ), "num_nodes and num_gpus_per_node must be the same when colocate critic and reward model." - - bundles = [ - {"GPU": args.critic_num_gpus_per_node, "CPU": args.critic_num_gpus_per_node} - for _ in range(args.critic_num_nodes) - ] - pg = placement_group(bundles, strategy="STRICT_SPREAD") - ray.get(pg.ready()) - - value_model_group = ModelGroup( - pg, - ValueTrainerRayProcess, - 0.75 if args.colocate_critic_reward else 1, - args.critic_num_gpus_per_node, - args.critic_num_nodes, - ) - inits.extend( - model.from_pretrained.remote(args, model_config, args.critic_num_gpus_per_node, args.critic_num_nodes) - for model in value_model_group.models - ) - reward_model_group = ModelGroup( - pg, - RewardModelRayProcess, - 0.25 if args.colocate_critic_reward else 1, - args.reward_num_gpus_per_node, - args.reward_num_nodes, - ) - inits.extend( - model.from_pretrained.remote(args, model_config, args.reward_num_gpus_per_node, args.reward_num_nodes) - for model in reward_model_group.models - ) - - max_len = dataset_config.max_prompt_token_length + args.response_length - vllm_engines = create_vllm_engines( - args.vllm_num_engines, - args.vllm_tensor_parallel_size, - model_config.model_name_or_path, - model_config.model_revision, - args.seed, - args.enable_prefix_caching, - max_len, - ) - - metrics_queue = RayQueue() - ray.get(inits) - print("======== all models initialized =========") - policy_vocab_size = ray.get(policy_group.models[0].get_vocab_size.remote()) - reward_vocab_size = ray.get(reward_model_group.models[0].get_vocab_size.remote()) - print(f"{policy_vocab_size=}, {reward_vocab_size=}") - if policy_vocab_size != reward_vocab_size: - ray.shutdown() # shutdown here so this error message is not buried in the logs - raise ValueError( - "Policy and reward model must have the same vocab size. " - f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " - "If they don't have the same vocab size, the policy could generate tokens which " - "is going to cause index out of bound error in the reward model." - ) - - refs = [] - for i, policy_model in enumerate(policy_group.models): - value_model = value_model_group.models[i % len(value_model_group.models)] - ref_model = ref_model_group.models[i % len(ref_model_group.models)] - reward_model = reward_model_group.models[i % len(reward_model_group.models)] - print(f"{value_model=}, {i=}, {len(value_model_group.models)=}") - refs.append( - policy_model.train.remote( - train_dataset=train_dataset, - eval_dataset=eval_dataset, - tokenizer=tokenizer, - value_model=value_model, - ref_model=ref_model, - reward_model=reward_model, - vllm_engines=vllm_engines, - metrics_queue=metrics_queue, - data_collator=data_collator, - ) - ) - - # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang - # so we need kill the ray cluster when this happens. - stop_event = threading.Event() - threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() - - # train and gather metrics - resume_training_step = 1 - for training_step in range(resume_training_step, args.num_training_steps + 1): - result = metrics_queue.get() - metrics, episode, df = result - for key, value in metrics.items(): - writer.add_scalar(key, value, episode) - - if df is not None: - if args.with_tracking: - wandb.log({"sample_completions": wandb.Table(dataframe=df)}) - # else: - # print_rich_table(df) - ray.get(refs) - - # save model - original_tokenizer = AutoTokenizer.from_pretrained( - model_config.model_name_or_path, revision=model_config.model_revision - ) - ray.get( - [policy_model.save_model.remote(original_tokenizer, args.output_dir) for policy_model in policy_group.models] - ) - ray.shutdown() - stop_event.set() - - # Ai2 specific logic - if is_beaker_job(): - if args.hf_metadata_dataset: - dataset_list = list(args.dataset_mixer_dict.keys()) - # mainly just focussing here on what would be useful for the leaderboard. - # wandb will have even more useful information. - metadata_blob = { - "model_name": args.exp_name, - "model_type": "sft", - "datasets": dataset_list, - "base_model": model_config.model_name_or_path, - "wandb_path": wandb.run.get_url(), - "beaker_experiment": beaker_config.beaker_experiment_url, - "beaker_datasets": beaker_config.beaker_dataset_id_urls, - } - upload_metadata_to_hf( - metadata_blob, - "metadata.json", - args.hf_metadata_dataset, - "results/" + args.hf_repo_revision, # to match what the auto-evals name as. - ) - - if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: - command = f"""\ - python mason.py \ - --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ - --priority low \ - --preemptible \ - --budget ai2/allennlp \ - --workspace ai2/tulu-2-improvements \ - --image nathanl/open_instruct_auto \ - --pure_docker_mode \ - --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ - --beaker_workload_id {beaker_config.beaker_workload_id} \ - --model_name {args.hf_repo_revision} - """ - process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) - stdout, stderr = process.communicate() - print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") - print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") - print(f"Submit jobs after model training is finished - process return code: {process.returncode}") - - accelerator = Namespace() - accelerator.is_main_process = True # hack - if args.push_to_hub: - push_folder_to_hub( - accelerator, - args.output_dir, - args.hf_repo_id, - args.hf_repo_revision, - ) - - if accelerator.is_main_process: - # remove args.checkpoint_output_dir - if os.path.exists(args.checkpoint_output_dir): - shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) - - -if __name__ == "__main__": - parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) - main(*parser.parse()) From ebdf456bbc7e3c4fbb94916a1f1022f1f2d80d53 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 8 Nov 2024 17:34:45 +0000 Subject: [PATCH 47/53] update OLMo code --- .../beaker_configs/ray_node_setup.sh | 21 + open_instruct/model_utils.py | 8 +- open_instruct/olmo_adapter/modeling_olmo2.py | 221 ++ open_instruct/olmo_adapter/olmo_new.py | 448 ++++ .../ppo_vllm_thread_ray_gtrl_olmo.py | 1808 +++++++++++++++++ open_instruct/reward_modeling.py | 7 + open_instruct/vllm_utils2.py | 8 + open_instruct/x.py | 52 + 8 files changed, 2569 insertions(+), 4 deletions(-) create mode 100644 configs/beaker_configs/beaker_configs/ray_node_setup.sh create mode 100644 open_instruct/olmo_adapter/modeling_olmo2.py create mode 100644 open_instruct/olmo_adapter/olmo_new.py create mode 100644 open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py create mode 100644 open_instruct/x.py diff --git a/configs/beaker_configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/beaker_configs/ray_node_setup.sh new file mode 100644 index 000000000..e746cff1c --- /dev/null +++ b/configs/beaker_configs/beaker_configs/ray_node_setup.sh @@ -0,0 +1,21 @@ +export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))") +export PYTHONPATH=$REPO_PATH +export PATH="/root/.local/bin:$PATH" + + +echo CURRENT_DATETIME=$CURRENT_DATETIME +echo PYTHONPATH=$PYTHONPATH +echo PATH=$PATH + +# python3 -c "import os, ray; print(os.path.dirname(ray.__file__))" + +RAY_NODE_PORT=8888 +ray stop --force + +if [ "$BEAKER_REPLICA_RANK" == "0" ]; then + echo "Starting Ray head node" + ray start --head --port=$RAY_NODE_PORT +else + echo "Starting Ray worker node $BEAKER_REPLICA_RANK" + ray start --address="${BEAKER_LEADER_REPLICA_HOSTNAME}:${RAY_NODE_PORT}" --block +fi \ No newline at end of file diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index f0a2d48bf..dd5dbeb1f 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -177,8 +177,8 @@ def get_reward( output = lm_backbone( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, + # position_ids=position_ids, + # return_dict=True, output_hidden_states=True, use_cache=False, # otherwise mistral-based RM would error out ) @@ -266,12 +266,12 @@ def forward( The output of the model, including hidden states. """ attention_mask = query_responses != pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() + # position_ids = attention_mask.cumsum(1) - attention_mask.long() input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return model( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) diff --git a/open_instruct/olmo_adapter/modeling_olmo2.py b/open_instruct/olmo_adapter/modeling_olmo2.py new file mode 100644 index 000000000..19c428edb --- /dev/null +++ b/open_instruct/olmo_adapter/modeling_olmo2.py @@ -0,0 +1,221 @@ +from typing import Callable, Optional, Union, List, Tuple +import torch +from torch import nn +from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss +from transformers import PreTrainedModel +from transformers.modeling_outputs import SequenceClassifierOutputWithPast +from hf_olmo import OLMoTokenizerFast, OLMoConfig, OLMoForCausalLM +from hf_olmo.modeling_olmo import OLMo, create_model_config_from_pretrained_config, ActivationCheckpointingStrategy + +class OLMoForSequenceClassification(PreTrainedModel): + + config_class = OLMoConfig + base_model_prefix = "model" + _no_split_modules = ["OLMoBlock"] + _supports_flash_attn_2 = True + _supports_sdpa = True + supports_gradient_checkpointing = True + + def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False): + super().__init__(config) + self._gradient_checkpointing_func: Optional[Callable] = None + self._gradient_checkpointing = False + + self.num_labels = config.num_labels + if not model: + model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + self.model = OLMo(model_config, init_params=init_params) + else: + self.model = model + + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + + @property + def gradient_checkpointing(self) -> bool: + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, enabled: bool): + if self._gradient_checkpointing == enabled: + return + + # HF does not specify a way to pass checkpointing strategies, so we pick + # whole layer as our strategy. We can make this configurable later if needed. + checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None + self.model.set_activation_checkpointing( + checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func + ) + self._gradient_checkpointing = enabled + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.transformer.wte + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.transformer.wte = value + + def get_output_embeddings(self): + if self.config.weight_tying: + return self.model.transformer.wte + else: + return self.model.transformer.ff_out + + def set_output_embeddings(self, value: torch.nn.Module): + if self.config.weight_tying: + self.model.transformer.wte = value + else: + self.model.transformer.ff_out = value + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> torch.nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`. + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + Arguments: + new_num_tokens (`int`, *optional*): + The new number of tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. + Note: + This method differs from the base class implementation by resizing the `embedding_size` attribute of the + model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size` + is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token + embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Update base model and current model config + self.config.embedding_size = model_embeds.weight.shape[0] + self.model.config.embedding_size = model_embeds.weight.shape[0] + + # Check if the embedding size is less than the vocab size + if self.config.embedding_size < self.config.vocab_size: + warning_message = ( + f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size " + f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary " + "size is less than or equal to the new token embedding size." + ) + # log.warning(warning_message) + print(warning_message) + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + """ + Forward pass for sequence classification with OLMo. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + position_ids: Position IDs for positional encoding + past_key_values: Past key values for incremental decoding + inputs_embeds: Pre-computed input embeddings + labels: Labels for computing loss + use_cache: Whether to use cached key/values + output_attentions: Whether to output attention weights + output_hidden_states: Whether to output hidden states + return_dict: Whether to return a ModelOutput object + + Returns: + SequenceClassifierOutputWithPast or tuple: Classification outputs + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/open_instruct/olmo_adapter/olmo_new.py b/open_instruct/olmo_adapter/olmo_new.py new file mode 100644 index 000000000..0eb0e3db4 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_new.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig +from hf_olmo.configuration_olmo import OLMoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + + +class FlippedSiluAndMul(SiluAndMul): + """OLMo is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_sequence_length + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.include_bias, + quant_config=quant_config, + ) + + if config.attention_layer_norm: + # TODO: finish adding qk norm and norm_after + self.k_norm = RMSNorm( + (config.d_model // config.n_heads) * config.effective_n_kv_heads, + eps=config.layer_norm_eps, + #elementwise_affine=config.attention_layer_norm_with_affine, + #bias=False, + ) + self.q_norm = RMSNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.include_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + #q = self.q_norm(q) + #k = self.k_norm(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + try: + self.intermediate_size = config.intermediate_size + except AttributeError: + if config.mlp_hidden_size is not None: + self.intermediate_size = config.mlp_hidden_size // 2 + else: + self.intermediate_size = (config.d_model * config.mlp_ratio) // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = FlippedSiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + + self.norm_after = config.norm_after + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + + """ + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + """ + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + if self.norm_after: + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + if self.norm_after: + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: Union[OlmoConfig, OLMoConfig], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.embedding_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm( + config.hidden_size, + eps=config.layer_norm_eps, + #elementwise_affine=config.layer_norm_with_affine, + #bias=config.bias_for_layer_norm + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # Apply blocks one-by-one. + for layer_idx, decoder_layer in enumerate(self.layers): + # shape: (batch_size, seq_len, d_model) + hidden_states = decoder_layer( + positions, + hidden_states, + kv_caches[layer_idx], + attn_metadata, + ) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoNewForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = OlmoModel(config, cache_config, quant_config) + if config.weight_tying: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + #self.unpadded_vocab_size, + config.embedding_size, + config.hidden_size, + org_num_embeddings=config.embedding_size, + #org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.embedding_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def _create_map(self): + mapper = {} + for layer_i in range(self.config.n_layers): + mapper[f"model.transformer.blocks.{layer_i}.att_proj.weight"] = f"model.layers.{layer_i}.self_attn.qkv_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.attn_out.weight"] = f"model.layers.{layer_i}.self_attn.o_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_proj.weight"] = f"model.layers.{layer_i}.mlp.gate_up_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_out.weight"] = f"model.layers.{layer_i}.mlp.down_proj.weight" + + mapper[f"model.transformer.blocks.{layer_i}.attn_norm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_norm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight" + mapper[f"model.transformer.blocks.{layer_i}.k_norm.weight"] = f"model.layers.{layer_i}.self_attn.k_norm.weight" + mapper[f"model.transformer.blocks.{layer_i}.q_norm.weight"] = f"model.layers.{layer_i}.self_attn.q_norm.weight" + + mapper["model.transformer.ln_f.weight"] = "model.norm.weight" + mapper["model.transformer.wte.weight"] = "model.embed_tokens.weight" + mapper["model.transformer.ff_out.weight"] = "lm_head.weight" + return mapper + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + mapper = self._create_map() + # print("mapper", mapper) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.weight_tying and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[mapper.get(name, name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py new file mode 100644 index 000000000..81ab3eeb2 --- /dev/null +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -0,0 +1,1808 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF +# which has the following license: +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import logging +import os +import random +import shutil +import socket +import subprocess +import threading +import time +from argparse import Namespace +from dataclasses import asdict, dataclass, field +from queue import Empty, Queue +from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple + +import deepspeed +import numpy as np +import pandas as pd +import ray +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils +import torch.utils.data +import vllm +from datasets import Dataset, DatasetDict +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from huggingface_hub import HfApi +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rich.pretty import pprint +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + get_scheduler, +) +from transformers.integrations import HfDeepSpeedConfig +from vllm import SamplingParams + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + DATASET_SOURCE_KEY, + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTGroundTruthDatasetProcessor, + SimpleGenerateCollatorWithGroundTruth, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + apply_verifiable_reward, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + print_rich_single_line_metrics, + print_rich_table, + push_folder_to_hub, + truncate_response, +) +from open_instruct.utils import ( + ArgumentParserPlus, + BeakerRuntimeConfig, + check_hf_olmo_availability, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) +from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group + +api = HfApi() +INVALID_LOGPROB = 1.0 + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: Optional[int] = None + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = 100000 + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_rollout_batch_size: int = 64 + """The number of rollout episodes per iteration per device""" + rollout_batch_size: Optional[int] = None + """The number of rollout episodes per iteration""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 4 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + local_dataloader_batch_size: Optional[int] = None + """The batch size per GPU for the dataloader""" + save_freq: int = -1 + """How many train steps to save the model""" + + # online settings + num_epochs: int = 4 + """the number of epochs to train""" + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: int = 64 + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + reward_model_revision: Optional[str] = None + """the revision of the reward model""" + init_value_from_scratch: bool = False + """whether to initialize the value model from scratch""" + + # generation config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos", "period"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + min_response_length: int = 0 + """stop only after this many tokens""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: float = -1.0 + """the reward value for responses that do not contain `stop_token_id`""" + non_stop_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + number_samples_per_prompt: int = 1 + """the number of samples to generate per prompt, useful for easy-star""" + + # online PPO specific args + beta: float = 0.05 + """the beta value of the RLHF objective (KL coefficient)""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + cliprange: float = 0.2 + """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" + cliprange_value: float = 0.2 + """the clip range for the value function""" + gamma: float = 1 + """the discount factor""" + lam: float = 0.95 + """the lambda value for GAE""" + kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + """the KL estimator to use""" + apply_verifiable_reward: bool = False + """whether to apply verifiable reward""" + reward_model_multiplier: float = 1.0 + """the reward model multiplier, for down/upscaling the reward model output""" + answer_extraction_model: str = None + + # async setting + async_mode: bool = True + """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" + + # ray + actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1]) + """number of gpus per node for actor""" + vllm_num_engines: int = 1 + """number of vLLM Engines, set to 0 to disable vLLM""" + vllm_tensor_parallel_size: int = 1 + """tensor parallel size of vLLM Engine for multi-GPU inference""" + vllm_sync_backend: str = "nccl" + """DeepSpeed -> vLLM weight sync backend""" + enable_prefix_caching: bool = False + """whether to enable prefix caching""" + deepspeed_stage: int = 0 + """the deepspeed stage""" + gather_whole_model: bool = True + """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + checkpoint_output_dir: Optional[str] = None + """Where to save the model checkpoints in case of preemption""" + + # Ai2 specific settings + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + try_launch_beaker_eval_jobs_on_weka: bool = False + """Whether to launch beaker evaluation jobs after training on weka""" + oe_eval_tasks: Optional[List[str]] = None + """The beaker evaluation tasks to launch""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. If unset, don't upload metadata""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args(args: Args, model_config: ModelConfig): + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # args.world_size = accelerator.num_processes + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.gradient_accumulation_steps = exact_div( + args.local_mini_batch_size, + args.per_device_train_batch_size, + "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", + ) + args.world_size = sum(args.actor_num_gpus_per_node) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) + args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt) + args.eval_freq = max(1, args.num_training_steps // args.num_evals) + # PPO logic: do checks and set up dataloader batch size + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + args.local_dataloader_batch_size = args.rollout_batch_size + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + +def get_train_ds_config( + offload, + adam_offload=False, + stage=0, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + disable_trace_cache=True, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # # ZeRO++ + # "zero_hpz_partition_size": zpg, + # "zero_quantized_weights": False, + # "zero_quantized_gradients": False, + } + if disable_trace_cache: + zero_opt_dict["stage3_prefetch_bucket_size"] = 0 + zero_opt_dict["stage3_max_live_parameters"] = 0 + zero_opt_dict["stage3_max_reuse_distance"] = 0 + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def remove_padding(sequences, pad_token_id): + return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + + # Ensure the effective dataset size is divisible by batch_size + self.effective_size = len(self.data) - (len(self.data) % batch_size) + + def __iter__(self) -> Iterator[List[int]]: + return self + + def __next__(self) -> List[int]: + if self.index >= self.effective_size: + self.index = 0 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + self.index = end_index + + return batch + + +class RayProcess: + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + from vllm.model_executor.models import ModelRegistry + AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + self.world_size = world_size + self.rank = rank + self.local_rank = local_rank + self.master_addr = master_addr if master_addr else self.get_current_node_ip() + self.master_port = master_port if master_port else self.get_free_port() + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + # os.environ["LOCAL_RANK"] = str(self._local_rank) + os.environ["LOCAL_RANK"] = "0" + random.seed(self.rank) + np.random.seed(self.rank) + torch.manual_seed(self.rank) + + @staticmethod + def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self.master_addr, self.master_port + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class PolicyTrainerRayProcess(RayProcess): + def from_pretrained( + self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str + ): + self.args = args + self.model_config = model_config + self.beaker_config = beaker_config + self.wandb_url = wandb_url + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.policy.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.policy, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + # value model + self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + if args.init_value_from_scratch: + self.value_model.init_weights() # re-initialize the value model from scratch + disable_dropout_in_model(self.value_model) + self.value_model.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=args.warm_up_steps, + num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + ) + self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.value_model, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.value_model.train() + + # reference model + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. + if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.ref_policy) + self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) + self.ref_policy.eval() + + # reward model + self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.reward_model) + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config) + self.reward_model.eval() + + def get_vocab_size(self): + return self.policy.config.vocab_size + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + def train( + self, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizer, + vllm_engines: List[ray.actor.ActorHandle], + metrics_queue: RayQueue, + data_collator: Callable, + ): + torch.set_printoptions(precision=4, sci_mode=False) + + args = self.args + + accelerator = Namespace() + accelerator.process_index = self.rank + accelerator.num_processes = self.world_size + accelerator.is_main_process = self.rank == 0 + torch.distributed.barrier() + if self.rank == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + vllm_num_engines, vllm_tensor_parallel_size = ( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + backend = args.vllm_sync_backend + # https://github.com/OpenRLHF/OpenRLHF/issues/313 + if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": + backend = "gloo" + print( + "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" + ) + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + "openrlhf", + backend=backend, + ) + for i, engine in enumerate(vllm_engines) + ] + self.model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) + ray.get(refs) + torch.distributed.barrier() + + def broadcast_to_vllm(): + # avoid OOM + torch.cuda.empty_cache() + model = self.model.module + count, num_params = 0, len(list(model.named_parameters())) + refss = [] + if args.gather_whole_model: + with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + else: # broadcast each parameter independently + for name, param in model.named_parameters(): + count += 1 + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + if torch.distributed.get_rank() == 0: + ray.get(refss) + + # broadcast_to_vllm() + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + if args.stop_token == "period": + args.stop_token_id = tokenizer.encode(".")[0] + # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + train_dataset_idxs = np.arange(len(train_dataset)) + shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) + + # hack to left pad + def repeat_generator(): + while True: + batch_idxs = next(shuffling_iter) + yield [train_dataset[i] for i in batch_idxs] + + iter_dataloader = iter(repeat_generator()) + generation_config = SamplingParams( + temperature=args.temperature, + top_p=1.0, + max_tokens=args.response_length, + include_stop_str_in_output=True, + n=args.number_samples_per_prompt, + ) + # print("setup async queues") + param_prompt_Q = None + response_ids_Q = None + evaluation_Q = None + response_ids_Q = Queue(maxsize=1) + param_prompt_Q = Queue(maxsize=1) + evaluation_Q = Queue(maxsize=1) + num_eval_samples = 32 + sample_evaluation_prompt_token_ids = None + if eval_dataset is not None: + sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] + + def vllm_generate( + generation_config: SamplingParams, + response_ids_Q: Queue, + param_prompt_Q: Queue, + num_training_steps: int, + sample_evaluation_prompt_token_ids: Optional[List[int]], + evaluation_Q: Queue, + eval_freq: int, + resume_training_step: int, + ): + llm = vllm_engines[0] + for training_step in range(resume_training_step, num_training_steps + 1): + items = param_prompt_Q.get() + if items is None: + break + unwrapped_model, g_queries_list = items + # if unwrapped_model is not None: + generation_start_time = time.time() + + outputs = ray.get( + llm.generate.remote( + sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False + ) + ) + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") + response_ids_Q.put(response_ids) + + if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: + outputs = ray.get( + llm.generate.remote( + prompt_token_ids=sample_evaluation_prompt_token_ids, + sampling_params=generation_config, + use_tqdm=False, + ) + ) + # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + evaluation_Q.put(response_ids) + + resume_training_step = 1 + if accelerator.is_main_process: + thread = threading.Thread( + target=vllm_generate, + args=( + generation_config, + response_ids_Q, + param_prompt_Q, + args.num_training_steps, + sample_evaluation_prompt_token_ids, + evaluation_Q, + args.eval_freq, + resume_training_step, + ), + ) + thread.start() + print("vllm generate thread starts") + + # set up the metrics and initial states + device = torch.device(self.local_rank) + g_vllm_responses = torch.zeros( + (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length), + device=device, + dtype=torch.long, + ) + stats_shape = ( + args.num_epochs, + args.num_mini_batches * args.number_samples_per_prompt, + args.gradient_accumulation_steps, + ) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + local_metrics = torch.zeros((20,), device=device) + episode = args.rollout_batch_size * (resume_training_step - 1) + + # training loop + start_time = time.time() + global_data = next(iter_dataloader) + data = data_collator( + global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] + ) + global_queries = data_collator(global_data)[ + INPUT_IDS_PROMPT_KEY + ].tolist() # can be simplified since we `remove_padding` later anyway + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + + answer_extraction_model = None + answer_extraction_tokenizer = None + # for _ in range(1, resume_training_step): # we didn't store scheduler state + # scheduler.step() + + for training_step in range(resume_training_step, args.num_training_steps + 1): + episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + if accelerator.is_main_process: + df = None + try: + evaluation_responses = evaluation_Q.get(timeout=0.01) + print("🔥🔥🔥 Evaluation responses received") + table = {} + table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) + table["response"] = tokenizer.batch_decode(evaluation_responses) + table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] + df = pd.DataFrame(table) + del table + except Empty: + print("🙈 Evaluation responses not received") + + # (optionally) evaluate the model + if args.async_mode: + if training_step != 1: + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + else: + if training_step != 1: + # NOTE: important: the indent here is different for sync mode + # we also set to use `queries = queries_next` immediately + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + torch.cuda.empty_cache() + # print('get reward stuff starts') + # if we generate multiple samples per prompt, we need to repeat the queries and ground truths + # to match the vllm outputs. + if args.number_samples_per_prompt > 1: + queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) + ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] + + training_time_start = time.time() + with torch.no_grad(): + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + verifiable_counts = [] + sequence_lengths = [] + values = [] + if accelerator.is_main_process: + g_response_token_ids = response_ids_Q.get() + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_token_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + g_vllm_responses[:] = g_padded_response_ids + dist.broadcast(g_vllm_responses, src=0) + local_vllm_responses = g_vllm_responses[ + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] + ] + # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") + query_responses = torch.cat((queries, local_vllm_responses), 1) + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + # print(f"get reward stuff starts {i=}") + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + logprob = self.forward( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + torch.cuda.empty_cache() + + ref_output = forward(self.ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + # print("get reward stuff starts 2") + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + if args.reward_model_multiplier != 1.0: + score *= args.reward_model_multiplier + # also apply verifiable reward + if args.apply_verifiable_reward: + # we need to batch the gt to match query. + ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] + dataset = datasets[i : i + args.local_rollout_forward_batch_size] + verifiable_reward, verifiable_count = apply_verifiable_reward( + postprocessed_query_response, + tokenizer, + ground_truth, + dataset, + verify_reward=10, + answer_extraction_model=answer_extraction_model, + answer_extraction_tokenizer=answer_extraction_tokenizer, + ) + score += verifiable_reward + else: + verifiable_count = torch.tensor([0.0], device=device).float() + full_value, _, _ = get_reward( + self.value_model, query_response, tokenizer.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + verifiable_counts.append(verifiable_count) + # print(f"get reward stuff starts 5") + + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + verifiable_counts = torch.cat(verifiable_counts, 0) + verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] + values = torch.cat(values, 0) + # print(f"get reward stuff finished") + del (logprob, ref_logprob, full_value, value, score) + gc.collect() + torch.cuda.empty_cache() + + # Response Processing 3. filter response. Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) + # NOTE: only apply the stop token filter if the response is long enough + # otherwise the model could learn to generate the first token as the stop token + contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) + if args.non_stop_penalty: + scores = torch.where( + contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) + ) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + # print(f"get reward stuff finished 2") + + # 4. compute rewards + kl1 = logprobs - ref_logprobs + kl2 = (kl1) ** 2 / 2 + kl3 = (-kl1).exp() - 1 + kl1 + if args.kl_estimator == "kl1": + kl = kl1 + elif args.kl_estimator == "kl2": + kl = kl2 + elif args.kl_estimator == "kl3": + kl = kl3 + # if self.rank==0: + # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") + non_score_reward = -args.beta * kl + non_score_reward_sum = non_score_reward.sum(1) + rlhf_reward = scores + non_score_reward_sum + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + # print(f"get reward stuff finished 3") + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # print('gae') + # 6. compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # print('training starts') + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch + for epoch_idx in range(args.num_epochs): + b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt) + minibatch_idx = 0 + for mini_batch_start in range( + 0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + # print("micro batch start", micro_batch_start, self.rank) + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] + + vpred_temp = get_reward( + self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length + ) + vpred_temp = vpred_temp[0] + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) + self.value_model.backward(vf_loss * args.vf_coef) + self.value_model.step() + + new_logprobs = self.forward( + mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature + ) + # if self.rank==0: + # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + self.model.backward(loss) + # print("backward loss", self.rank, "micro batch start", micro_batch_start) + # print("trying to step", self.rank, "micro batch start", micro_batch_start) + self.model.step() + # print("step", self.rank, "micro batch start", micro_batch_start) + with torch.no_grad(): + # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) + # vf_loss, vf_clipfrac = ray.get(value_model_step_future) + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + # print("value model stepped", self.rank, "micro batch start", micro_batch_start) + # prob_dist = torch.nn.functional.softmax(logits, dim=-1) + # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # fmt: off + del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 + del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss + # del vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max + # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl + # fmt: on + # del everything and empty cache + torch.cuda.empty_cache() + del b_inds, mini_batch_inds + # print("start metrics") + with torch.no_grad(): + local_metrics[0] = sequence_lengths.float().mean() + local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() + local_metrics[2] = kl.sum(1).mean() + local_metrics[3] = (-logprobs).sum(1).mean() + local_metrics[4] = non_score_reward_sum.mean() + local_metrics[5] = rlhf_reward.mean() + local_metrics[6] = scores.mean() + local_metrics[7] = approxkl_stats.mean() + local_metrics[8] = pg_clipfrac_stats.mean() + local_metrics[9] = pg_loss_stats.mean() + local_metrics[10] = vf_loss_stats.mean() + local_metrics[11] = vf_clipfrac_stats.mean() + local_metrics[12] = entropy_stats.mean() + local_metrics[13] = ratio_stats.mean() + local_metrics[14] = ratio_stats.var() + local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() + local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + local_metrics[17] = verifiable_correct_rate + # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + local_metrics /= dist.get_world_size() + dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) + global_metrics = local_metrics.tolist() + metrics = { + "episode": episode, + "training_step": training_step, + "lr": self.scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "time/from_scratch": time.time() - start_time, + "time/training": time.time() - training_time_start, + "val/sequence_lengths": global_metrics[0], + "val/num_stop_token_ids": global_metrics[1], + "objective/kl": global_metrics[2], + "objective/kl2": global_metrics[15], + "objective/kl3": global_metrics[16], + "objective/entropy": global_metrics[3], + "objective/non_score_reward": global_metrics[4], + "objective/rlhf_reward": global_metrics[5], + "objective/scores": global_metrics[6], + "policy/approxkl_avg": global_metrics[7], + "policy/clipfrac_avg": global_metrics[8], + "loss/policy_avg": global_metrics[9], + "loss/value_avg": global_metrics[10], + "val/clipfrac_avg": global_metrics[11], + "policy/entropy_avg": global_metrics[12], + "val/ratio": global_metrics[13], + "val/ratio_var": global_metrics[14], + "objective/verifiable_correct_rate": global_metrics[17], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + metrics_queue.put((metrics, episode, df)) + del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) + del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) + gc.collect() + torch.cuda.empty_cache() + # print(f"finished training {training_step}") + + # save steps + if args.save_freq > 0 and training_step % args.save_freq == 0: + checkpoint_dir = f"{args.output_dir}_checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") + os.makedirs(step_dir, exist_ok=True) + print(f"Saving model at step {training_step} to {step_dir}") + self.save_model(step_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(step_dir, training_step) + print(f"Saving final model at step {training_step} to {args.output_dir}") + self.save_model(args.output_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(args.output_dir) + + # Ai2 logic: we use /output to store the artifacts of the job, so we + # make a copy of the model to `/output` in the end. + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0: + shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) + print("finished training") + + def save_model(self, output_dir: str) -> None: + if self.rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self.model + if hasattr(model_to_save, "module"): + model_to_save = model_to_save.module + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.rank == 0: + output_state_dict[k] = vv + + if self.rank == 0: + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + # if isinstance(model_to_save, PeftModel): + # model_to_save.save_pretrained(output_dir, **kwargs) + # if self.stage == 3: + # torch.save( + # get_peft_model_state_dict(model_to_save, output_state_dict), + # os.path.join(output_dir, "adapter_model.bin"), + # ) + # else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) + + # save tokenizer + self.original_tokenizer.save_pretrained(output_dir) + + def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = None) -> None: + """auto eval the metrics as `f"{args.exp_name}_step_{training_step}"` in our leaderboard""" + args = self.args + beaker_config = self.beaker_config + model_config = self.model_config + wandb_url = self.wandb_url + # Ai2 specific logic + if is_beaker_job() and self.rank == 0: + if training_step is not None: + leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" + else: + leaderboard_name = args.hf_repo_revision + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "ppo", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb_url, + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + leaderboard_name, # to match what the auto-evals name as. + ) + + command = f"""\ +python scripts/submit_eval_jobs.py \ + --model_name {leaderboard_name} \ + --location {step_dir} \ + --cluster ai2/saturn-cirrascale \ + --is_tuned \ + --workspace "tulu-3-results" \ + --preemptible \ + --use_hf_tokenizer_template \ + --beaker_image "nathanl/open_instruct_auto" \ + --upload_to_hf allenai/tulu-3-evals \ + --run_oe_eval_experiments \ + --evaluate_on_weka \ + --run_safety_evaluations \ + --skip_oi_evals""" + if args.oe_eval_tasks is not None: + command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}" + print(f"Launching eval jobs with command: {command}") + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + +def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): + while True: + if stop_event.is_set(): + break + for ref in object_refs: + try: + ray.get(ref, timeout=0.01) + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + print(e) + print(f"Actor {ref} died") + time.sleep(120) + ray.shutdown() + os._exit(1) # Force shutdown the process + + time.sleep(30) + + +class ModelGroup: + def __init__( + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: List[int], + ): + self.pg = pg + self.ray_process_cls = ray_process_cls + self.num_gpus_per_node = num_gpus_per_node + self.num_gpus_per_actor = 1 + self.num_cpus_per_actor = 4 + self.models = [] + world_size = sum(self.num_gpus_per_node) + master_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=self.pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, 0, None, None) + + self.models.append(master_policy) + master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) + + def get_bundle_index(rank, num_gpus_per_node): + """given a rank and a list of num_gpus_per_node, return the index of the bundle that the rank belongs to""" + bundle_idx = 0 + while rank >= num_gpus_per_node[bundle_idx]: + rank -= num_gpus_per_node[bundle_idx] + bundle_idx += 1 + return bundle_idx + + assert get_bundle_index(0, [7, 8, 4]) == 0 + assert get_bundle_index(1, [7, 8, 4]) == 0 + assert get_bundle_index(7, [7, 8, 4]) == 1 + assert get_bundle_index(8, [7, 8, 4]) == 1 + assert get_bundle_index(9, [7, 8, 4]) == 1 + assert get_bundle_index(16, [7, 8, 4]) == 2 + + # Setup worker models + for rank in range(1, world_size): + print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.pg, + placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), + ) + worker_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=scheduling_strategy, + ).remote(world_size, rank, 0, master_addr, master_port) + self.models.append(worker_policy) + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + calculate_runtime_args(args, model_config) + + # set up experiment tracking and seeds + all_configs = {} + beaker_config = None + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) + beaker_config = maybe_get_beaker_config() + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # create a tokenizer (pad from right) + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + else: + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + + # create the dataset + dataset_dict = DatasetDict() + dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: + args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets." + ) + if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: + args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets." + ) + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + eval_dataset = eval_dataset.select( + range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples)) + ) + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) + dataset_dict["eval"] = eval_dataset + data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) + + # some more runtime logging + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + pg = None + bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.actor_num_gpus_per_node] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + inits = [] + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.actor_num_gpus_per_node, + ) + wandb_url = wandb.run.get_url() if args.with_tracking else None + inits.extend( + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url) for model in policy_group.models + ) + max_len = dataset_config.max_prompt_token_length + args.response_length + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + model_config.model_name_or_path, + model_config.model_revision, + args.seed, + args.enable_prefix_caching, + max_len, + ) + + metrics_queue = RayQueue() + ray.get(inits) + print("======== all models initialized =========") + ray.get(policy_group.models[0].get_vocab_size.remote()) + # print(f"{policy_vocab_size=}, {reward_vocab_size=}") + # if policy_vocab_size != reward_vocab_size: + # ray.shutdown() # shutdown here so this error message is not buried in the logs + # raise ValueError( + # "Policy and reward model must have the same vocab size. " + # f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " + # "If they don't have the same vocab size, the policy could generate tokens which " + # "is going to cause index out of bound error in the reward model." + # ) + + refs = [] + for i, policy_model in enumerate(policy_group.models): + refs.append( + policy_model.train.remote( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + vllm_engines=vllm_engines, + metrics_queue=metrics_queue, + data_collator=data_collator, + ) + ) + + # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang + # so we need kill the ray cluster when this happens. + stop_event = threading.Event() + threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() + + # train and gather metrics + resume_training_step = 1 + for training_step in range(resume_training_step, args.num_training_steps + 1): + result = metrics_queue.get() + metrics, episode, df = result + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + if df is not None: + if args.with_tracking: + wandb.log({"sample_completions": wandb.Table(dataframe=df)}) + else: + print_rich_table(df.iloc[:1]) + ray.get(refs) + + # save model + ray.shutdown() + stop_event.set() + + # Ai2 specific logic + if is_beaker_job(): + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. + ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + accelerator = Namespace() + accelerator.is_main_process = True # hack + if args.push_to_hub: + print("Pushing model to hub") + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. + # We use `--save_freq` to save intermediate checkpoints in the output folder instead. + if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py index 75618d9bc..e3e927d20 100644 --- a/open_instruct/reward_modeling.py +++ b/open_instruct/reward_modeling.py @@ -47,6 +47,7 @@ from open_instruct.reward_modeling_eval import evaluate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, combine_dataset, get_wandb_tags, is_beaker_job, @@ -195,6 +196,12 @@ def layer_init(layer: nn.Module, std: float): def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) accelerator = calculate_runtime_args_and_accelerator(args, model_config) local_seed = args.seed + accelerator.process_index diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py index 7dbdf6a54..b50f76fa0 100644 --- a/open_instruct/vllm_utils2.py +++ b/open_instruct/vllm_utils2.py @@ -35,6 +35,8 @@ ) from vllm.worker.worker import Worker +from open_instruct.utils import check_hf_olmo_availability + # Copy from pytorch to allow creating multiple main groups. # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py @@ -131,6 +133,12 @@ def update_weight(self, name, dtype, shape, empty_cache=False): class LLMRayActor: def __init__(self, *args, **kwargs): import vllm + if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + from vllm.model_executor.models import ModelRegistry + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) self.__version__ = vllm.__version__ assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1" diff --git a/open_instruct/x.py b/open_instruct/x.py new file mode 100644 index 000000000..f0e4bb1e7 --- /dev/null +++ b/open_instruct/x.py @@ -0,0 +1,52 @@ +from typing import Optional + +from hf_olmo import * +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, +) + +from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + + +class Args: + model_name_or_path: str = "/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf" + trust_remote_code: bool = True + revision: Optional[str] = None + + +# def main(args: Args): +# model = AutoModelForCausalLM.from_pretrained( +# args, +# trust_remote_code=True, +# ) + + +if __name__ == "__main__": + # instead of installing from source, https://github.com/AkshitaB/vllm/blob/c96643ec56da3ab8cefba03cadf7731788e756b5/vllm/model_executor/models/__init__.py#L49 + # here we just register the new model class + from vllm.model_executor.models import ModelRegistry + + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + from vllm import LLM, SamplingParams + + model = AutoModelForCausalLM.from_pretrained( + "/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf", + trust_remote_code=True, + ) + from vllm.model_executor.models import ModelRegistry + + from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + + AutoModelForSequenceClassification.register(OLMoConfig, OLMoForSequenceClassification) + + s = SamplingParams(temperature=0.0) + llm = LLM( + model="/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf", + trust_remote_code=True, + gpu_memory_utilization=0.90, + ) + + vllm_out = llm.generate(["How is the weather today"], sampling_params=s) + print(vllm_out[0].outputs[0].text) From 3422229c58e42bb6b2627899d969444b595ce44e Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Mon, 11 Nov 2024 18:27:02 +0000 Subject: [PATCH 48/53] push changes --- Dockerfile | 1 + .../ppo_vllm_thread_ray_gtrl_olmo.py | 28 +- .../generation_check_ground_truth.py | 261 ++++++++++++++++++ 3 files changed, 286 insertions(+), 4 deletions(-) create mode 100644 open_instruct/rejection_sampling/generation_check_ground_truth.py diff --git a/Dockerfile b/Dockerfile index dd6b95a97..f89725c3d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -91,6 +91,7 @@ RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url h RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation RUN pip install -r requirements.txt +RUN pip install ai2_olmo # NLTK download RUN python -m nltk.downloader punkt diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py index 81ab3eeb2..5410b450e 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -1601,11 +1601,31 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): tokenizer = AutoTokenizer.from_pretrained( model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" ) - if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: - tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + if check_hf_olmo_availability(): + print("Using exsiting tokenier chat template...") + pass else: - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding - tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + # elif check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast): + # # OLMo newer models use this tokenizer + # breakpoint() + # if tokenizer.bos_token is None: + # tokenizer.bos_token = tokenizer.eos_token + # assert ( + # args.add_bos + # ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # # else, pythia / other models + # else: + # num_added_tokens = tokenizer.add_special_tokens( + # { + # "pad_token": "", + # } + # ) + # assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." + # else: + # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding # create the dataset dataset_dict = DatasetDict() diff --git a/open_instruct/rejection_sampling/generation_check_ground_truth.py b/open_instruct/rejection_sampling/generation_check_ground_truth.py new file mode 100644 index 000000000..1e3dd4240 --- /dev/null +++ b/open_instruct/rejection_sampling/generation_check_ground_truth.py @@ -0,0 +1,261 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +python open_instruct/rejection_sampling/generation_check_ground_truth.py \ + --model_name_or_path allenai/open_instruct_dev \ + --revision L3.1-8B-v3.9-nc-fixed-2__meta-llama_Llama-3.1-8B__123__1730531285 \ + --num_completions 3 \ + --dataset_mixer_list ai2-adapt-dev/math_ground_truth 1.0 \ + --dataset_splits train \ + --dataset_end_idx 10 + + +python open_instruct/rejection_sampling/generation_check_ground_truth.py \ + --model_name_or_path allenai/open_instruct_dev \ + --revision olmo_7b_soup_anneal_v3.9_4_DPO___model__42__1730863426 \ + --num_completions 5 \ + --dataset_mixer_list ai2-adapt-dev/gsm8k_ground_truth 1.0 \ + --dataset_splits train \ + --dataset_end_idx 20 +""" +import asyncio +import copy +import json +import os +import sys +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from pprint import pformat +from typing import Dict, List, Optional + +from huggingface_hub import HfApi +from huggingface_hub.repocard import RepoCard +from rich.pretty import pprint +import torch +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +from open_instruct.dataset_processor import ( + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTDatasetProcessor, +) +from open_instruct.model_utils import apply_verifiable_reward +from open_instruct.utils import ArgumentParserPlus, check_hf_olmo_availability, combine_dataset + +api = HfApi() +# we don't use `multiprocessing.cpu_count()` because typically we only have 12 CPUs +# and that the shards might be small +NUM_CPUS_FOR_DATASET_MAP = 4 +if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + from vllm.model_executor.models import ModelRegistry + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + +@dataclass +class Args: + dataset_mixer_list: List[str] + dataset_splits: List[str] = None + dataset_start_idx: int = 0 + dataset_end_idx: Optional[int] = None + + model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr" + revision: str = "main" + save_filename: str = "completions.jsonl" + + # upload config + hf_repo_id: str = os.path.basename(__file__)[: -len(".py")] + push_to_hub: bool = False + hf_entity: Optional[str] = None + add_timestamp: bool = True + + +@dataclass +class GenerationArgs: + num_completions: int = 3 + temperature: float = 0.8 + response_length: int = 2048 + top_p: float = 0.9 + tensor_parallel_size: int = 1 + + +def save_jsonl(save_filename: str, table: Dict[str, List]): + first_key = list(table.keys())[0] + os.makedirs(os.path.dirname(save_filename), exist_ok=True) + with open(save_filename, "w") as outfile: + for i in range(len(table[first_key])): + json.dump({key: table[key][i] for key in table}, outfile) + outfile.write("\n") + + +def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids: List[int], gen_args: GenerationArgs): + llm = LLM( + model=model_name_or_path, + revision=revision, + tokenizer_revision=revision, + tensor_parallel_size=gen_args.tensor_parallel_size, + max_model_len=gen_args.response_length, + ) + + # filter out prompts which are beyond the model's max token length + max_model_len = llm.llm_engine.scheduler_config.max_model_len + prompt_token_ids_len = len(prompt_token_ids) + prompt_token_ids = [item for item in prompt_token_ids if len(item) < max_model_len] + if len(prompt_token_ids) != prompt_token_ids_len: + print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max token length") + + outputs = llm.generate( + prompt_token_ids=prompt_token_ids, + sampling_params=SamplingParams( + n=gen_args.num_completions, + temperature=gen_args.temperature, + top_p=1.0, + max_tokens=gen_args.response_length, + include_stop_str_in_output=True, + ), + ) + + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + return response_ids + + +def format_conversation(messages: list) -> str: + formatted_conversation = [] + + # Iterate through the messages + for message in messages: # Exclude the last assistant message + role = "User A" if message["role"] == "user" else "User B" + content = message["content"].strip() + formatted_conversation.append(f"{role}: {content}") + + # Join the conversation with a single newline + return "\n".join(formatted_conversation) + + +def main(args: Args, dataset_config: DatasetConfig, gen_args: GenerationArgs): + if len(args.dataset_splits) != len(args.dataset_mixer_list) // 2: + args.dataset_splits = ["train"] * (len(args.dataset_mixer_list) // 2) + print(f"Using default dataset_splits: {args.dataset_splits} for {(len(args.dataset_mixer_list) // 2)} datasets") + + dataset = combine_dataset( + args.dataset_mixer_list, + splits=args.dataset_splits, + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], + ) + if args.dataset_end_idx is None: + args.dataset_end_idx = len(dataset) + dataset = dataset.select(range(args.dataset_start_idx, args.dataset_end_idx)) + pprint([dataset_config, args, gen_args]) + + + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, revision=args.revision) + dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + dataset = dataset_processor.tokenize(dataset) + dataset = dataset_processor.filter(dataset) + prompt_token_ids = dataset[INPUT_IDS_PROMPT_KEY] + ground_truth = dataset[dataset_config.ground_truths_key] + dataset_source = dataset[dataset_config.dataset_source_key] + response_ids = generate_with_vllm(args.model_name_or_path, args.revision, prompt_token_ids, gen_args) + + + # repeat prompt_token_ids, ground truth, dataset source args.num_completions times + prompt_token_ids = [prompt_token_ids[i] for i in range(len(prompt_token_ids)) for _ in range(gen_args.num_completions)] + ground_truth = [ground_truth[i] for i in range(len(ground_truth)) for _ in range(gen_args.num_completions)] + dataset_source = [dataset_source[i] for i in range(len(dataset_source)) for _ in range(gen_args.num_completions)] + + # left pad prompt token ids with 0 + max_seq_len = max(len(item) for item in prompt_token_ids) + padded_prompt_token_ids = [[0] * (max_seq_len - len(item)) + item for item in prompt_token_ids] + # right pad response token ids with 0 + max_seq_len = max(len(item) for item in response_ids) + padded_response_ids = [item + [0] * (max_seq_len - len(item)) for item in response_ids] + padded_prompt_token_ids = torch.tensor(padded_prompt_token_ids) + padded_response_ids = torch.tensor(padded_response_ids) + query_response = torch.concat([padded_prompt_token_ids, padded_response_ids], dim=1) + verifiable_reward, _ = apply_verifiable_reward( + query_response, + tokenizer, + ground_truth, + dataset_source, + verify_reward=10, + ) + import math + verifiable_reward = verifiable_reward.reshape(-1, gen_args.num_completions) + pass_at_k = (verifiable_reward.sum(dim=1) > 1).float().mean() + maj_at_k = (verifiable_reward.sum(dim=1) > math.ceil(gen_args.num_completions / 2)).float().mean() + printa = lambda i: print(tokenizer.decode(query_response[i]), ground_truth[i]) + print(f"{verifiable_reward=}") + print(f"{pass_at_k=}") + print(f"{maj_at_k=}") + breakpoint() + + # save_jsonl(args.save_filename, table) + +# if args.push_to_hub: +# if args.hf_entity is None: +# args.hf_entity = api.whoami()["name"] +# full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" +# timestamp = f"_{int(time.time())}" +# if args.add_timestamp: +# full_repo_id += timestamp +# api.create_repo(full_repo_id, repo_type="dataset", exist_ok=True) +# for f in [__file__, args.save_filename]: +# api.upload_file( +# path_or_fileobj=f, +# path_in_repo=f.split("/")[-1], +# repo_id=full_repo_id, +# repo_type="dataset", +# ) +# repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}" +# print(f"Pushed to {repo_full_url}") +# run_command = " ".join(["python"] + sys.argv) +# sft_card = RepoCard( +# content=f"""\ +# # allenai/open_instruct: Generation Dataset + +# See https://github.com/allenai/open-instruct/blob/main/docs/algorithms/rejection_sampling.md for more detail + +# ## Configs + +# ``` +# args: +# {pformat(vars(args))} + +# dataset_config: +# {pformat(vars(dataset_config))} + +# gen_args: +# {pformat(vars(gen_args))} +# ``` + +# ## Reproduce this dataset + +# 1. Download the `{[f.split("/")[-1] for f in [__file__, args.save_filename]]}` from the {repo_full_url}. +# 2. Run `{run_command}` +# """ +# ) +# sft_card.push_to_hub( +# full_repo_id, +# repo_type="dataset", +# ) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, GenerationArgs)) + main(*parser.parse()) From 7905e63b1e0539ac31ab0e01a2a119ff3e2424a3 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Fri, 15 Nov 2024 22:44:13 +0000 Subject: [PATCH 49/53] push changes --- .../ppo_vllm_thread_ray_gtrl_olmo.py | 33 +++++++++++++------ open_instruct/vllm_utils2.py | 13 +++++--- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py index 5410b450e..f7b5ddc2c 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -522,15 +522,26 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port): datefmt="%Y-%m-%d %H:%M:%S", ) - if check_hf_olmo_availability(): - # allows AutoModel... to work with not in transformers olmo models - import hf_olmo # noqa - from hf_olmo import OLMoTokenizerFast - from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification - from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM - from vllm.model_executor.models import ModelRegistry - AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) - ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + # olmo 1124; pip install git+https://github.com/vwxyzjn/transformers.git@olmo1124_classification + from vllm.model_executor.models import ModelRegistry + from transformers.models.olmo_1124.modeling_olmo_1124 import Olmo1124ForSequenceClassification, Olmo1124Config + AutoModelForSequenceClassification.register(Olmo1124Config, Olmo1124ForSequenceClassification) + from open_instruct.olmo_adapter.olmo_1124_vllm import OlmoNewForCausalLM + ModelRegistry.register_model("Olmo1124ForCausalLM", OlmoNewForCausalLM) + + # # hf_olmo + # import hf_olmo # noqa + # from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + # AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) + # from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + # ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + + # other hf olmo + from open_instruct.olmo_adapter.modeling_olmo3 import OlmoForSequenceClassification + from open_instruct.olmo_adapter.modeling_olmoe3 import OlmoeForSequenceClassification + from transformers import OlmoConfig, OlmoeConfig + AutoModelForSequenceClassification.register(OlmoConfig, OlmoForSequenceClassification) + AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification) self.world_size = world_size self.rank = rank self.local_rank = local_rank @@ -1097,6 +1108,8 @@ def vllm_generate( _, score, _ = get_reward( self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) + if self.rank == 0 and i == 0: + print(postprocessed_query_response[0].tolist(), tokenizer.decode(postprocessed_query_response[0])) if args.reward_model_multiplier != 1.0: score *= args.reward_model_multiplier # also apply verifiable reward @@ -1373,7 +1386,7 @@ def vllm_generate( # Ai2 logic: we use /output to store the artifacts of the job, so we # make a copy of the model to `/output` in the end. - if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0: + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output": shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py index b50f76fa0..3f97ffb55 100644 --- a/open_instruct/vllm_utils2.py +++ b/open_instruct/vllm_utils2.py @@ -133,12 +133,15 @@ def update_weight(self, name, dtype, shape, empty_cache=False): class LLMRayActor: def __init__(self, *args, **kwargs): import vllm - if check_hf_olmo_availability(): + from vllm.model_executor.models import ModelRegistry + # if check_hf_olmo_availability(): # allows AutoModel... to work with not in transformers olmo models - import hf_olmo # noqa - from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM - from vllm.model_executor.models import ModelRegistry - ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + # import hf_olmo # noqa + # from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + # ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + from open_instruct.olmo_adapter.olmo_1124_vllm import OlmoNewForCausalLM + ModelRegistry.register_model("Olmo1124ForCausalLM", OlmoNewForCausalLM) + self.__version__ = vllm.__version__ assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1" From e857e36819e8b4b66958f37d06854a77885fce3f Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Sat, 23 Nov 2024 20:29:52 +0000 Subject: [PATCH 50/53] quick change --- Dockerfile | 84 +-- open_instruct/olmo_adapter/olmo_1124_vllm.py | 539 ++++++++++++++++++ open_instruct/olmo_adapter/olmo_1124_vllm2.py | 376 ++++++++++++ .../ppo_vllm_thread_ray_gtrl_olmo.py | 60 +- 4 files changed, 943 insertions(+), 116 deletions(-) create mode 100644 open_instruct/olmo_adapter/olmo_1124_vllm.py create mode 100644 open_instruct/olmo_adapter/olmo_1124_vllm2.py diff --git a/Dockerfile b/Dockerfile index f89725c3d..626922ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,84 +1,4 @@ -ARG CUDA -ARG DIST -ARG TARGET -FROM --platform=linux/amd64 nvidia/cuda:${CUDA}-${TARGET}-${DIST} - -ARG DEBIAN_FRONTEND="noninteractive" -ENV TZ="America/Los_Angeles" - -# Install base tools. -RUN apt-get update && apt-get install -y \ - build-essential \ - curl \ - git \ - jq \ - language-pack-en \ - make \ - sudo \ - unzip \ - vim \ - wget \ - parallel \ - iputils-ping \ - tmux - -ARG BEAKER_VERSION -RUN curl --silent \ - --connect-timeout 5 \ - --max-time 10 \ - --retry 5 \ - --retry-delay 0 \ - --retry-max-time 40 \ - --output beaker.tar.gz \ - "https://beaker.org/api/v3/release/cli?os=linux&arch=amd64&version=${BEAKER_VERSION}" \ - && tar -zxf beaker.tar.gz -C /usr/local/bin/ ./beaker \ - && rm beaker.tar.gz - -# This ensures the dynamic linker (or NVIDIA's container runtime, I'm not sure) -# puts the right NVIDIA things in the right place (that THOR requires). -ENV NVIDIA_DRIVER_CAPABILITIES=graphics,utility,compute - -# Install conda. We give anyone in the users group the ability to run -# conda commands and install packages in the base (default) environment. -# Things installed into the default environment won't persist, but we prefer -# convenience in this case and try to make sure the user is aware of this -# with a message that's printed when the session starts. -RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh \ - && echo "32d73e1bc33fda089d7cd9ef4c1be542616bd8e437d1f77afeeaf7afdb019787 Miniconda3-py310_23.1.0-1-Linux-x86_64.sh" \ - | sha256sum --check \ - && bash Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b -p /opt/miniconda3 \ - && rm Miniconda3-py310_23.1.0-1-Linux-x86_64.sh - -ENV PATH=/opt/miniconda3/bin:/opt/miniconda3/condabin:$PATH -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - -# Install a few additional utilities via pip -RUN /opt/miniconda3/bin/pip install --no-cache-dir \ - gpustat \ - jupyter \ - beaker-gantry \ - oocmap - -# Ensure users can modify their container environment. -RUN echo '%users ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers - -# Make the base image friendlier for interactive workloads. This makes things like the man command -# work. -RUN yes | unminimize - -# Install MLNX OFED user-space drivers -# See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile -ENV MOFED_VER 5.8-1.1.2.1 -ENV OS_VER ubuntu20.04 -ENV PLATFORM x86_64 -RUN wget --quiet https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VER}/MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ - tar -xvf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ - MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}/mlnxofedinstall --basic --user-space-only --without-fw-update -q && \ - rm -rf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM} && \ - rm MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz - -# The -l flag makes bash act as a login shell and load /etc/profile, etc. -ENTRYPOINT ["bash", "-l"] +FROM ghcr.io/allenai/cuda:12.1-cudnn8-dev-ubuntu20.04-v1.2.116 WORKDIR /stage/ @@ -91,7 +11,6 @@ RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url h RUN pip install packaging RUN pip install flash-attn==2.6.3 --no-build-isolation RUN pip install -r requirements.txt -RUN pip install ai2_olmo # NLTK download RUN python -m nltk.downloader punkt @@ -107,6 +26,7 @@ COPY configs configs COPY scripts scripts COPY mason.py mason.py RUN chmod +x scripts/* +RUN pip cache purge # for interactive session RUN chmod -R 777 /stage/ diff --git a/open_instruct/olmo_adapter/olmo_1124_vllm.py b/open_instruct/olmo_adapter/olmo_1124_vllm.py new file mode 100644 index 000000000..a5a6061b0 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_1124_vllm.py @@ -0,0 +1,539 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + + +class FlippedSiluAndMul(SiluAndMul): + """OLMo is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + attention_layer_norm = True + if attention_layer_norm: + # TODO: finish adding qk norm and norm_after + self.k_norm = RMSNorm( + (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, + eps=config.rms_norm_eps, + #elementwise_affine=config.attention_layer_norm_with_affine, + #bias=False, + ) + self.q_norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + #q = self.q_norm(q) + #k = self.k_norm(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + try: + self.intermediate_size = config.intermediate_size + except AttributeError: + if config.mlp_hidden_size is not None: + self.intermediate_size = config.mlp_hidden_size // 2 + else: + self.intermediate_size = (config.hidden_size * config.mlp_ratio) // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + + self.norm_after = True + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + """ + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + """ + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + if self.norm_after: + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + if self.norm_after: + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: Union[OlmoConfig], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + #elementwise_affine=config.layer_norm_with_affine, + #bias=config.bias_for_layer_norm + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, hidden_size) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # Apply blocks one-by-one. + for layer_idx, decoder_layer in enumerate(self.layers): + # shape: (batch_size, seq_len, hidden_size) + hidden_states = decoder_layer( + positions, + hidden_states, + kv_caches[layer_idx], + attn_metadata, + ) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, hidden_size) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoNewForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = OlmoModel(config, cache_config, quant_config) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + #self.unpadded_vocab_size, + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + #org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + mapper = {} + "loaded weights -> uninitialized model weights" + for layer_i in range(self.config.num_hidden_layers): + mapper[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight" + mapper[f"model.layers.{layer_i}.post_feedforward_layernorm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight" + # from rich.pretty import pprint + # pprint(mapper) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + # print("loaded", name, param) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[mapper.get(name, name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # print("loaded", name, param) + + +# loaded model.embed_tokens.weight +# loaded model.layers.0.self_attn.o_proj.weight +# loaded model.layers.0.self_attn.q_norm.weight +# loaded model.layers.0.self_attn.k_norm.weight +# loaded model.layers.0.mlp.down_proj.weight +# loaded model.layers.0.post_attention_layernorm.weight +# loaded model.layers.0.post_feedforward_layernorm.weight +# loaded model.layers.1.self_attn.o_proj.weight +# loaded model.layers.1.self_attn.q_norm.weight +# loaded model.layers.1.self_attn.k_norm.weight +# loaded model.layers.1.mlp.down_proj.weight +# loaded model.layers.1.post_attention_layernorm.weight +# loaded model.layers.1.post_feedforward_layernorm.weight +# loaded model.layers.2.self_attn.o_proj.weight +# loaded model.layers.2.self_attn.q_norm.weight +# loaded model.layers.2.self_attn.k_norm.weight +# loaded model.layers.2.mlp.down_proj.weight +# loaded model.layers.2.post_attention_layernorm.weight +# loaded model.layers.2.post_feedforward_layernorm.weight +# loaded model.norm.weight +# loaded lm_head.weight + +# OlmoNewForCausalLM( +# (model): OlmoModel( +# (embed_tokens): VocabParallelEmbedding(num_embeddings=100352, embedding_dim=4096, org_vocab_size=100352, num_embeddings_padded=100352, tp_size=1) +# (layers): ModuleList( +# (0-31): 32 x OlmoDecoderLayer( +# (self_attn): OlmoAttention( +# (qkv_proj): QKVParallelLinear(in_features=4096, output_features=12288, bias=False, tp_size=1, gather_output=False) +# (k_norm): RMSNorm(hidden_size=4096, eps=1e-06) +# (q_norm): RMSNorm(hidden_size=4096, eps=1e-06) +# (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=4096, base=500000, is_neox_style=True) +# (attn): Attention(head_size=128, num_heads=32, num_kv_heads=32, scale=0.08838834764831845, backend=FlashAttentionImpl) +# (o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True) +# ) +# (mlp): OlmoMLP( +# (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=22016, bias=False, tp_size=1, gather_output=False) +# (act_fn): FlippedSiluAndMul() +# (down_proj): RowParallelLinear(input_features=11008, output_features=4096, bias=False, tp_size=1, reduce_results=True) +# ) +# (input_layernorm): RMSNorm(hidden_size=4096, eps=1e-06) +# (post_attention_layernorm): RMSNorm(hidden_size=4096, eps=1e-06) +# ) +# ) +# (norm): RMSNorm(hidden_size=4096, eps=1e-06) +# ) +# (lm_head): ParallelLMHead(num_embeddings=100352, embedding_dim=4096, org_vocab_size=100352, num_embeddings_padded=100352, tp_size=1) +# (logits_processor): LogitsProcessor(vocab_size=100352, forg_vocab_size=100352, scale=1.0, logits_as_input=False) +# (sampler): Sampler() +# ) +# Olmo1124ForCausalLM( +# (model): Olmo1124Model( +# (embed_tokens): Embedding(100352, 4096, padding_idx=100277) +# (layers): ModuleList( +# (0-31): 32 x Olmo1124DecoderLayer( +# (self_attn): Olmo1124SdpaAttention( +# (q_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (k_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (v_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (o_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (rotary_emb): Olmo1124RotaryEmbedding() +# (q_norm): Olmo1124RMSNorm((4096,), eps=1e-06) +# (k_norm): Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# (mlp): Olmo1124MLP( +# (gate_proj): Linear(in_features=4096, out_features=11008, bias=False) +# (up_proj): Linear(in_features=4096, out_features=11008, bias=False) +# (down_proj): Linear(in_features=11008, out_features=4096, bias=False) +# (act_fn): SiLU() +# ) +# (post_attention_layernorm): Olmo1124RMSNorm((4096,), eps=1e-06) +# (post_feedforward_layernorm): Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# ) +# (norm): Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# (lm_head): Linear(in_features=4096, out_features=100352, bias=False) +# ) + +# OLMoForCausalLM( +# (model): OLMo( +# (transformer): ModuleDict( +# (wte): Embedding(100352, 4096) +# (emb_drop): Dropout(p=0.0, inplace=False) +# (ln_f): RMSLayerNorm() +# (blocks): ModuleList( +# (0-31): 32 x OLMoSequentialBlock( +# (dropout): Dropout(p=0.0, inplace=False) +# (k_norm): RMSLayerNorm() +# (q_norm): RMSLayerNorm() +# (act): SwiGLU() +# (attn_out): Linear(in_features=4096, out_features=4096, bias=False) +# (ff_out): Linear(in_features=11008, out_features=4096, bias=False) +# (rotary_emb): RotaryEmbedding() +# (att_proj): Linear(in_features=4096, out_features=12288, bias=False) +# (ff_proj): Linear(in_features=4096, out_features=22016, bias=False) +# (attn_norm): RMSLayerNorm() +# (ff_norm): RMSLayerNorm() +# ) +# ) +# (ff_out): Linear(in_features=4096, out_features=100352, bias=False) +# ) +# ) +# ) \ No newline at end of file diff --git a/open_instruct/olmo_adapter/olmo_1124_vllm2.py b/open_instruct/olmo_adapter/olmo_1124_vllm2.py new file mode 100644 index 000000000..d7bcf8972 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_1124_vllm2.py @@ -0,0 +1,376 @@ +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class OlmoModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Apply blocks one-by-one. + for i in range(self.start_layer, self.end_layer): + # shape: (batch_size, seq_len, d_model) + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = OlmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py index f7b5ddc2c..f979b3a42 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -99,7 +99,6 @@ from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, - check_hf_olmo_availability, combine_dataset, get_wandb_tags, is_beaker_job, @@ -149,6 +148,8 @@ class Args: """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + warmup_ratio: float = 0.0 + """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)""" # various batch sizes num_train_epochs: int = 1 @@ -523,25 +524,8 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port): ) # olmo 1124; pip install git+https://github.com/vwxyzjn/transformers.git@olmo1124_classification - from vllm.model_executor.models import ModelRegistry from transformers.models.olmo_1124.modeling_olmo_1124 import Olmo1124ForSequenceClassification, Olmo1124Config AutoModelForSequenceClassification.register(Olmo1124Config, Olmo1124ForSequenceClassification) - from open_instruct.olmo_adapter.olmo_1124_vllm import OlmoNewForCausalLM - ModelRegistry.register_model("Olmo1124ForCausalLM", OlmoNewForCausalLM) - - # # hf_olmo - # import hf_olmo # noqa - # from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification - # AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) - # from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM - # ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) - - # other hf olmo - from open_instruct.olmo_adapter.modeling_olmo3 import OlmoForSequenceClassification - from open_instruct.olmo_adapter.modeling_olmoe3 import OlmoeForSequenceClassification - from transformers import OlmoConfig, OlmoeConfig - AutoModelForSequenceClassification.register(OlmoConfig, OlmoForSequenceClassification) - AutoModelForSequenceClassification.register(OlmoeConfig, OlmoeForSequenceClassification) self.world_size = world_size self.rank = rank self.local_rank = local_rank @@ -626,11 +610,15 @@ def from_pretrained( # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + num_training_steps = args.num_training_steps * args.num_train_epochs * args.num_epochs + warm_up_steps = args.warm_up_steps + if args.warmup_ratio >= 0.0: + warm_up_steps = int(num_training_steps * args.warmup_ratio) scheduler = get_scheduler( args.lr_scheduler_type, optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, ) print(ds_config) self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( @@ -664,8 +652,8 @@ def from_pretrained( scheduler = get_scheduler( args.lr_scheduler_type, optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, ) self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( model=self.value_model, @@ -1108,8 +1096,8 @@ def vllm_generate( _, score, _ = get_reward( self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length ) - if self.rank == 0 and i == 0: - print(postprocessed_query_response[0].tolist(), tokenizer.decode(postprocessed_query_response[0])) + # if self.rank == 0 and i == 0: + # print(postprocessed_query_response[0].tolist(), tokenizer.decode(postprocessed_query_response[0])) if args.reward_model_multiplier != 1.0: score *= args.reward_model_multiplier # also apply verifiable reward @@ -1330,6 +1318,7 @@ def vllm_generate( local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() local_metrics[17] = verifiable_correct_rate + local_metrics[18] = contain_stop_token.float().mean() # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() local_metrics /= dist.get_world_size() dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) @@ -1359,6 +1348,7 @@ def vllm_generate( "val/ratio": global_metrics[13], "val/ratio_var": global_metrics[14], "objective/verifiable_correct_rate": global_metrics[17], + "val/stop_token_rate": global_metrics[18], } if accelerator.is_main_process: print_rich_single_line_metrics(metrics) @@ -1481,9 +1471,10 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = python scripts/submit_eval_jobs.py \ --model_name {leaderboard_name} \ --location {step_dir} \ - --cluster ai2/saturn-cirrascale \ + --cluster ai2/saturn-cirrascale ai2/neptune-cirrascale \ --is_tuned \ --workspace "tulu-3-results" \ + --priority high \ --preemptible \ --use_hf_tokenizer_template \ --beaker_image "nathanl/open_instruct_auto" \ @@ -1606,19 +1597,19 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): ) # create a tokenizer (pad from right) - if check_hf_olmo_availability(): - # allows AutoModel... to work with not in transformers olmo models - import hf_olmo # noqa - from hf_olmo import OLMoTokenizerFast + # if check_hf_olmo_availability(): + # # allows AutoModel... to work with not in transformers olmo models + # import hf_olmo # noqa + # from hf_olmo import OLMoTokenizerFast config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) tokenizer = AutoTokenizer.from_pretrained( model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" ) - if check_hf_olmo_availability(): - print("Using exsiting tokenier chat template...") - pass - else: - tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + # if check_hf_olmo_availability(): + # print("Using exsiting tokenier chat template...") + # pass + # else: + # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> # elif check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast): @@ -1811,6 +1802,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): --pure_docker_mode \ --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ --beaker_workload_id {beaker_config.beaker_workload_id} \ + --upload_to_hf {args.hf_metadata_dataset} \ --model_name {args.hf_repo_revision} """ process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) From 0bebd50f90d694c939470107b7ea0b538344ecab Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Tue, 26 Nov 2024 19:51:04 +0000 Subject: [PATCH 51/53] push --- open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py | 2 +- open_instruct/reward_modeling.py | 2 ++ scripts/submit_eval_jobs.py | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py index f979b3a42..6c3c893cc 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -1478,7 +1478,7 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = --preemptible \ --use_hf_tokenizer_template \ --beaker_image "nathanl/open_instruct_auto" \ - --upload_to_hf allenai/tulu-3-evals \ + --upload_to_hf {args.hf_metadata_dataset} \ --run_oe_eval_experiments \ --evaluate_on_weka \ --run_safety_evaluations \ diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py index ffc4c50cd..a2a922d46 100644 --- a/open_instruct/reward_modeling.py +++ b/open_instruct/reward_modeling.py @@ -56,6 +56,8 @@ maybe_use_ai2_wandb_entity, ) +from transformers.models.olmo_1124.modeling_olmo_1124 import Olmo1124ForSequenceClassification, Olmo1124Config +AutoModelForSequenceClassification.register(Olmo1124Config, Olmo1124ForSequenceClassification) api = HfApi() diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py index 0c6f7cd18..48e9d061b 100755 --- a/scripts/submit_eval_jobs.py +++ b/scripts/submit_eval_jobs.py @@ -660,7 +660,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier): experiment_name = experiment_name.replace('β', '').replace(r"{", "").replace(r"}", "") # hack: remove characters beaker doesn't like d["description"] = experiment_name # specific image for safety eval - d["tasks"][0]["image"]["beaker"] = "hamishivi/open-safety" + d["tasks"][0]["image"]["beaker"] = "hamishivi/open_safety_1124" if args.use_alternate_safety_image: d["tasks"][0]["image"]["beaker"] = args.use_alternate_safety_image d["tasks"] = [d["tasks"][0]] From b452f92266ec0f097ce8fd3a6cd1b554160eabb7 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 8 Jan 2025 06:28:54 -0800 Subject: [PATCH 52/53] push changes --- open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py | 4 ++-- open_instruct/vllm_utils2.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py index 6c3c893cc..d39f14fed 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -524,8 +524,8 @@ def __init__(self, world_size, rank, local_rank, master_addr, master_port): ) # olmo 1124; pip install git+https://github.com/vwxyzjn/transformers.git@olmo1124_classification - from transformers.models.olmo_1124.modeling_olmo_1124 import Olmo1124ForSequenceClassification, Olmo1124Config - AutoModelForSequenceClassification.register(Olmo1124Config, Olmo1124ForSequenceClassification) + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForSequenceClassification, Olmo2Config + AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification) self.world_size = world_size self.rank = rank self.local_rank = local_rank diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py index 13fefeabb..09aa73fff 100644 --- a/open_instruct/vllm_utils2.py +++ b/open_instruct/vllm_utils2.py @@ -134,6 +134,7 @@ def __init__(self, *args, **kwargs): from vllm.model_executor.models import ModelRegistry from open_instruct.olmo_adapter.olmo_1124_vllm import OlmoNewForCausalLM ModelRegistry.register_model("Olmo1124ForCausalLM", OlmoNewForCausalLM) + ModelRegistry.register_model("Olmo2ForCausalLM", OlmoNewForCausalLM) self.__version__ = vllm.__version__ assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1" From f732a739c7e9371007993c352e01515b431f7c45 Mon Sep 17 00:00:00 2001 From: Costa Huang Date: Wed, 8 Jan 2025 06:30:51 -0800 Subject: [PATCH 53/53] push changes --- open_instruct/ppo_vllm_thread_ray_gtrl.py | 24 ++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 4c9377bb2..de36c7489 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -148,6 +148,8 @@ class Args: """Which scheduler to use""" warm_up_steps: int = 0 """Number of warm up steps for the scheduler""" + warmup_ratio: float = 0.0 + """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)""" # various batch sizes num_train_epochs: int = 1 @@ -604,11 +606,15 @@ def from_pretrained( # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + num_training_steps = args.num_training_steps * args.num_train_epochs * args.num_epochs + warm_up_steps = args.warm_up_steps + if args.warmup_ratio >= 0.0: + warm_up_steps = int(num_training_steps * args.warmup_ratio) scheduler = get_scheduler( args.lr_scheduler_type, optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, ) print(ds_config) self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( @@ -642,8 +648,8 @@ def from_pretrained( scheduler = get_scheduler( args.lr_scheduler_type, optimizer=self.optimizer, - num_warmup_steps=args.warm_up_steps, - num_training_steps=args.num_training_steps * args.num_train_epochs * args.num_epochs, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, ) self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( model=self.value_model, @@ -1306,6 +1312,7 @@ def vllm_generate( local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() local_metrics[17] = verifiable_correct_rate + local_metrics[18] = contain_stop_token.float().mean() # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() local_metrics /= dist.get_world_size() dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) @@ -1335,6 +1342,7 @@ def vllm_generate( "val/ratio": global_metrics[13], "val/ratio_var": global_metrics[14], "objective/verifiable_correct_rate": global_metrics[17], + "val/stop_token_rate": global_metrics[18], } if accelerator.is_main_process: print_rich_single_line_metrics(metrics) @@ -1362,7 +1370,7 @@ def vllm_generate( # Ai2 logic: we use /output to store the artifacts of the job, so we # make a copy of the model to `/output` in the end. - if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0: + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output": shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") @@ -1457,13 +1465,14 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = python scripts/submit_eval_jobs.py \ --model_name {leaderboard_name} \ --location {step_dir} \ - --cluster ai2/saturn-cirrascale \ + --cluster ai2/saturn-cirrascale ai2/neptune-cirrascale \ --is_tuned \ --workspace "tulu-3-results" \ + --priority high \ --preemptible \ --use_hf_tokenizer_template \ --beaker_image "nathanl/open_instruct_auto" \ - --upload_to_hf allenai/tulu-3-evals \ + --upload_to_hf {args.hf_metadata_dataset} \ --run_oe_eval_experiments \ --evaluate_on_weka \ --run_safety_evaluations \ @@ -1763,6 +1772,7 @@ def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): --pure_docker_mode \ --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ --beaker_workload_id {beaker_config.beaker_workload_id} \ + --upload_to_hf {args.hf_metadata_dataset} \ --model_name {args.hf_repo_revision} """ process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE)