diff --git a/.github/workflows/push-image-olmo.yml b/.github/workflows/push-image-olmo.yml new file mode 100644 index 000000000..28a8c3467 --- /dev/null +++ b/.github/workflows/push-image-olmo.yml @@ -0,0 +1,81 @@ +# This is an example workflow file. +# +# When you add a new image, copy this file and then replace all mentions of "hello-world" with +# the name of your new image. +# +# Read through the rest of the comments in this file to figure out how it works, and what else +# you need to change. +name: build_open_instruct_olmo + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + push: + # Run this workflow anytime a push updates one of the files in the image's directory + # (other than the README), and anytime there's a new release tag for this image. + paths: + - 'open_instruct/**' + - '!open_instruct/README.md' + - 'requirements-olmo.txt' + - 'Dockerfile.olmo' + - '.github/workflows/push-image-olmo.yml' + # Note: add the .olmo Dockerfile + requirements file here if setting up an auto build for them + branches: [main] + # pull_request: # note, comment this out for running on every push + # # Also run on PRs that update the files in the image's directory (other than README). + # branches: [main] + # paths: + # - 'open_instruct/**' + # - '!open_instruct/README.md' + # - 'requirements-olmo.txt' + # - 'Dockerfile.olmo' + workflow_dispatch: # This allows us to manually trigger a build through the GitHub UI. + +env: + DOCKER_BUILDKIT: "1" + +jobs: + build: + name: open_instruct + runs-on: ubuntu-latest + timeout-minutes: 60 + if: (github.event_name != 'workflow_run') || (github.event.workflow_run.conclusion == 'success') + steps: + - uses: actions/checkout@v3 + with: + repository: allenai/oe-eval-internal + path: './oe-eval-internal' + ssh-key: ${{ secrets.OE_EVAL_GIT_CLONE_ACCESS_PRIVATE_SSH_DEPLOY_KEY }} + + - name: Setup environment + uses: ./.github/actions/setup + with: + beaker_token: ${{ secrets.BEAKER_TOKEN }} + # ghcr_token: ${{ secrets.GHCR_TOKEN }} + # ghcr_user: ${{ secrets.GHCR_USER }} + + # Large image builds can run out of disk space on the default runner, so free some up first + - name: Delete huge unnecessary tools folder + run: rm -rf /opt/hostedtoolcache /usr/share/dotnet "$AGENT_TOOLSDIRECTORY" + + - name: Build image + run: | + docker build \ + --build-arg BUILDKIT_INLINE_CACHE=1 \ + --build-arg CUDA=12.1.0 --build-arg \ + TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 \ + -f Dockerfile.olmo . \ + -t open_instruct_olmo + + - name: Check image + run: | + docker run --rm open_instruct_olmo + - name: Push image + # if: github.event_name != 'pull_request' + uses: ./.github/actions/push + with: + image: open_instruct_olmo # this is the tag of the image we just built in the previous step + beaker: open_instruct_olmo_auto # this is the name of the image on Beaker + latest: true # this flag says we should also push this as the 'latest' version to GHCR diff --git a/.github/workflows/push-image.yml b/.github/workflows/push-image.yml index 40e205fb3..f5a35fdc0 100644 --- a/.github/workflows/push-image.yml +++ b/.github/workflows/push-image.yml @@ -44,8 +44,6 @@ jobs: timeout-minutes: 60 if: (github.event_name != 'workflow_run') || (github.event.workflow_run.conclusion == 'success') steps: - - uses: actions/checkout@v3 - - uses: actions/checkout@v3 with: repository: allenai/oe-eval-internal path: './oe-eval-internal' @@ -69,7 +67,6 @@ jobs: --build-arg BUILDKIT_INLINE_CACHE=1 \ --build-arg CUDA=12.1.0 --build-arg \ TARGET=cudnn8-devel --build-arg DIST=ubuntu20.04 \ - --build-arg REQUIRE=requirements.txt .
\ -t open_instruct diff --git a/Dockerfile b/Dockerfile index dd6b95a97..626922ab7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,84 +1,4 @@ -ARG CUDA -ARG DIST -ARG TARGET -FROM --platform=linux/amd64 nvidia/cuda:${CUDA}-${TARGET}-${DIST} - -ARG DEBIAN_FRONTEND="noninteractive" -ENV TZ="America/Los_Angeles" - -# Install base tools. -RUN apt-get update && apt-get install -y \ - build-essential \ - curl \ - git \ - jq \ - language-pack-en \ - make \ - sudo \ - unzip \ - vim \ - wget \ - parallel \ - iputils-ping \ - tmux - -ARG BEAKER_VERSION -RUN curl --silent \ - --connect-timeout 5 \ - --max-time 10 \ - --retry 5 \ - --retry-delay 0 \ - --retry-max-time 40 \ - --output beaker.tar.gz \ - "https://beaker.org/api/v3/release/cli?os=linux&arch=amd64&version=${BEAKER_VERSION}" \ - && tar -zxf beaker.tar.gz -C /usr/local/bin/ ./beaker \ - && rm beaker.tar.gz - -# This ensures the dynamic linker (or NVIDIA's container runtime, I'm not sure) -# puts the right NVIDIA things in the right place (that THOR requires). -ENV NVIDIA_DRIVER_CAPABILITIES=graphics,utility,compute - -# Install conda. We give anyone in the users group the ability to run -# conda commands and install packages in the base (default) environment. -# Things installed into the default environment won't persist, but we prefer -# convenience in this case and try to make sure the user is aware of this -# with a message that's printed when the session starts. -RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh \ - && echo "32d73e1bc33fda089d7cd9ef4c1be542616bd8e437d1f77afeeaf7afdb019787 Miniconda3-py310_23.1.0-1-Linux-x86_64.sh" \ - | sha256sum --check \ - && bash Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b -p /opt/miniconda3 \ - && rm Miniconda3-py310_23.1.0-1-Linux-x86_64.sh - -ENV PATH=/opt/miniconda3/bin:/opt/miniconda3/condabin:$PATH -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH - -# Install a few additional utilities via pip -RUN /opt/miniconda3/bin/pip install --no-cache-dir \ - gpustat \ - jupyter \ - beaker-gantry \ - oocmap - -# Ensure users can modify their container environment. -RUN echo '%users ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers - -# Make the base image friendlier for interactive workloads. This makes things like the man command -# work. -RUN yes | unminimize - -# Install MLNX OFED user-space drivers -# See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile -ENV MOFED_VER 5.8-1.1.2.1 -ENV OS_VER ubuntu20.04 -ENV PLATFORM x86_64 -RUN wget --quiet https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VER}/MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ - tar -xvf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ - MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}/mlnxofedinstall --basic --user-space-only --without-fw-update -q && \ - rm -rf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM} && \ - rm MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz - -# The -l flag makes bash act as a login shell and load /etc/profile, etc. 
-ENTRYPOINT ["bash", "-l"] +FROM ghcr.io/allenai/cuda:12.1-cudnn8-dev-ubuntu20.04-v1.2.116 WORKDIR /stage/ @@ -106,6 +26,7 @@ COPY configs configs COPY scripts scripts COPY mason.py mason.py RUN chmod +x scripts/* +RUN pip cache purge # for interactive session RUN chmod -R 777 /stage/ diff --git a/Dockerfile.olmo b/Dockerfile.olmo new file mode 100644 index 000000000..40ef4377a --- /dev/null +++ b/Dockerfile.olmo @@ -0,0 +1,121 @@ +ARG CUDA +ARG DIST +ARG TARGET +FROM --platform=linux/amd64 nvidia/cuda:${CUDA}-${TARGET}-${DIST} + +ARG DEBIAN_FRONTEND="noninteractive" +ENV TZ="America/Los_Angeles" + +# Install base tools. +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + git \ + jq \ + language-pack-en \ + make \ + sudo \ + unzip \ + vim \ + wget \ + parallel \ + iputils-ping \ + tmux + +ARG BEAKER_VERSION +RUN curl --silent \ + --connect-timeout 5 \ + --max-time 10 \ + --retry 5 \ + --retry-delay 0 \ + --retry-max-time 40 \ + --output beaker.tar.gz \ + "https://beaker.org/api/v3/release/cli?os=linux&arch=amd64&version=${BEAKER_VERSION}" \ + && tar -zxf beaker.tar.gz -C /usr/local/bin/ ./beaker \ + && rm beaker.tar.gz + +# This ensures the dynamic linker (or NVIDIA's container runtime, I'm not sure) +# puts the right NVIDIA things in the right place (that THOR requires). +ENV NVIDIA_DRIVER_CAPABILITIES=graphics,utility,compute + +# Install conda. We give anyone in the users group the ability to run +# conda commands and install packages in the base (default) environment. +# Things installed into the default environment won't persist, but we prefer +# convenience in this case and try to make sure the user is aware of this +# with a message that's printed when the session starts. +RUN wget https://repo.anaconda.com/miniconda/Miniconda3-py310_23.1.0-1-Linux-x86_64.sh \ + && echo "32d73e1bc33fda089d7cd9ef4c1be542616bd8e437d1f77afeeaf7afdb019787 Miniconda3-py310_23.1.0-1-Linux-x86_64.sh" \ + | sha256sum --check \ + && bash Miniconda3-py310_23.1.0-1-Linux-x86_64.sh -b -p /opt/miniconda3 \ + && rm Miniconda3-py310_23.1.0-1-Linux-x86_64.sh + +ENV PATH=/opt/miniconda3/bin:/opt/miniconda3/condabin:$PATH +ENV LD_LIBRARY_PATH=/usr/local/cuda/lib:/usr/local/cuda/lib64:$LD_LIBRARY_PATH + +# Install a few additional utilities via pip +RUN /opt/miniconda3/bin/pip install --no-cache-dir \ + gpustat \ + jupyter \ + beaker-gantry \ + oocmap + +# Ensure users can modify their container environment. +RUN echo '%users ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers + +# Make the base image friendlier for interactive workloads. This makes things like the man command +# work. +RUN yes | unminimize + +# Install MLNX OFED user-space drivers +# See https://docs.nvidia.com/networking/pages/releaseview.action?pageId=15049785#Howto:DeployRDMAacceleratedDockercontaineroverInfiniBandfabric.-Dockerfile +ENV MOFED_VER 5.8-1.1.2.1 +ENV OS_VER ubuntu20.04 +ENV PLATFORM x86_64 +RUN wget --quiet https://content.mellanox.com/ofed/MLNX_OFED-${MOFED_VER}/MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ + tar -xvf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz && \ + MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}/mlnxofedinstall --basic --user-space-only --without-fw-update -q && \ + rm -rf MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM} && \ + rm MLNX_OFED_LINUX-${MOFED_VER}-${OS_VER}-${PLATFORM}.tgz + +# The -l flag makes bash act as a login shell and load /etc/profile, etc. 
+ENTRYPOINT ["bash", "-l"] + +WORKDIR /stage/ + +# TODO When updating flash-attn or torch in the future, make sure to update the version in the requirements.txt file. +ENV HF_HUB_ENABLE_HF_TRANSFER=1 +COPY requirements-olmo.txt . +RUN pip install --upgrade pip "setuptools<70.0.0" wheel +# TODO, unpin setuptools when this issue in flash attention is resolved +RUN pip install torch==2.4.0 torchvision==0.19.0 torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cu121 +RUN pip install packaging +RUN pip install flash-attn==2.5.9.post1 --no-build-isolation +# for newest olmo's, move to requirements when ai2-olmo supports torch 2.4 +# core is a dependency of ai2-olmo +RUN pip install ai2-olmo-core==0.1.0 omegaconf +# RUN pip install ai2-olmo>=0.5.0 --no-deps +# TODO Update Once this is merged https://github.com/allenai/OLMo/pull/719, then next release +RUN pip install git+https://github.com/allenai/OLMo.git@47f8f5abb40eb100c6623be12e1648c841b2ab99 --no-deps +RUN pip install -r requirements-olmo.txt + +RUN pip install git+https://github.com/AkshitaB/vllm.git +RUN pip install vllm-flash-attn + + +# NLTK download +RUN python -m nltk.downloader punkt +COPY open_instruct open_instruct +COPY oe-eval-internal oe-eval-internal + +# install the package in editable mode +COPY pyproject.toml . +RUN pip install -e . +COPY .git/ ./.git/ +COPY eval eval +COPY configs configs +COPY scripts scripts +COPY mason.py mason.py +RUN chmod +x scripts/* + +# for interactive session +RUN chmod -R 777 /stage/ diff --git a/configs/beaker_configs/beaker_configs/ray_node_setup.sh b/configs/beaker_configs/beaker_configs/ray_node_setup.sh new file mode 100644 index 000000000..e746cff1c --- /dev/null +++ b/configs/beaker_configs/beaker_configs/ray_node_setup.sh @@ -0,0 +1,21 @@ +export CURRENT_DATETIME=$(python -c "import datetime; import pytz; print(datetime.datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%m%d%y_%H%M%S'))") +export PYTHONPATH=$REPO_PATH +export PATH="/root/.local/bin:$PATH" + + +echo CURRENT_DATETIME=$CURRENT_DATETIME +echo PYTHONPATH=$PYTHONPATH +echo PATH=$PATH + +# python3 -c "import os, ray; print(os.path.dirname(ray.__file__))" + +RAY_NODE_PORT=8888 +ray stop --force + +if [ "$BEAKER_REPLICA_RANK" == "0" ]; then + echo "Starting Ray head node" + ray start --head --port=$RAY_NODE_PORT +else + echo "Starting Ray worker node $BEAKER_REPLICA_RANK" + ray start --address="${BEAKER_LEADER_REPLICA_HOSTNAME}:${RAY_NODE_PORT}" --block +fi \ No newline at end of file diff --git a/configs/beaker_configs/default_finetune_offloading.yaml b/configs/beaker_configs/default_finetune_offloading.yaml new file mode 100644 index 000000000..4722b4e13 --- /dev/null +++ b/configs/beaker_configs/default_finetune_offloading.yaml @@ -0,0 +1,69 @@ +version: v2 +description: open-instruct-finetune +budget: ai2/oe-adapt +tasks: + - name: open-instruct-finetune + image: + beaker: nathanl/open_instruct_auto + command: [ + '/bin/sh', '-c' + ] + arguments: ['PYTHONPATH="/stage:$PYTHONPATH" accelerate launch + --mixed_precision bf16 + --num_machines 1 + --num_processes 4 + --use_deepspeed + --deepspeed_config_file configs/ds_configs/stage3_offloading_accelerate.conf + open_instruct/finetune.py + --model_name_or_path /hf_llama_models + --use_flash_attn + --tokenizer_name /hf_llama_models + --max_seq_length 2048 + --preprocessing_num_workers 16 + --per_device_train_batch_size 2 + --gradient_accumulation_steps 16 + --learning_rate 2e-5 + --lr_scheduler_type linear + --warmup_ratio 0.03 + --weight_decay 0. 
+ --num_train_epochs 2 + --output_dir /output/ + --with_tracking + --report_to tensorboard + --logging_steps 1 + '] + envVars: + - name: CUDA_DEVICE_ORDER + value: PCI_BUS_ID + - name: TRANSFORMERS_CACHE + value: ./cache/ + - name: WANDB_API_KEY + secret: WANDB_API_KEY + - name: WANDB_PROJECT + value: open-instruct + - name: WANDB_WATCH + value: false + - name: WANDB_LOG_MODEL + value: false + - name: WANDB_DISABLED + value: true + - name: HF_TOKEN + secret: HF_TOKEN + # datasets: # example for how to include datasets in mounting + # - mountPath: /data + # source: + # beaker: Yizhongw03/processed_open_instruct_data + # - mountPath: /mmlu + # source: + # beaker: Yizhongw03/mmlu + # - mountPath: /hf_llama_models + # source: + # beaker: Yizhongw03/hf_llama_model_7B + result: + path: /output + resources: + gpuCount: 4 + context: + cluster: ai2/allennlp-cirrascale + priority: high + preemptible: false \ No newline at end of file diff --git a/configs/train_configs/dpo/olmo_7b_0924.yaml b/configs/train_configs/dpo/olmo_7b_0924.yaml new file mode 100644 index 000000000..3028bfca8 --- /dev/null +++ b/configs/train_configs/dpo/olmo_7b_0924.yaml @@ -0,0 +1,29 @@ +model_name_or_path: /model +model_revision: main +use_flash_attn: true +gradient_checkpointing: true +dataset_mixer: + allenai/ultrafeedback_binarized_cleaned_train: 1.0 + ai2-adapt-dev/DaringAnteater-prefs-RM-filter: 1.0 + ai2-adapt-dev/WildChat-prefs-280824: 1.0 + allenai/tulu-3-hardcoded-preferences: 1.0 +tokenizer_name: /model +use_slow_tokenizer: true +max_seq_length: 2048 +preprocessing_num_workers: 16 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 16 # designed for 8 GPUs, so batch size 128 +learning_rate: 5.0e-7 +lr_scheduler_type: linear +warmup_ratio: 0.1 +weight_decay: 0.0 +num_train_epochs: 1 +output_dir: /output +with_tracking: true +report_to: + - wandb +logging_steps: 1 +use_lora: false +dpo_loss_type: dpo_norm +dpo_beta: 5 +checkpointing_steps: 1000 \ No newline at end of file diff --git a/configs/train_configs/sft/olmo_7b_0924.yaml b/configs/train_configs/sft/olmo_7b_0924.yaml new file mode 100644 index 000000000..e8264bd0a --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924.yaml @@ -0,0 +1,22 @@ +model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_revision: main +use_flash_attn: true +tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_name: allenai/llama-3-tulu-v3.3-mix-preview +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 8 # should run with this set to 16 for 1 node only +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml new file mode 100644 index 000000000..4539f713a --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_permissive.yaml @@ -0,0 +1,36 @@ +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +model_revision: main +use_flash_attn: true 
+# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_mixer: + ai2-adapt-dev/metamath-qa-reformat: 1.0 # MIT License + natolambert/tulu-v2-sft-mixture-flan: 1.0 # FLAN Apache 2.0 + natolambert/tulu-v2-sft-mixture-cot: 1.0 # FLAN Apache 2.0 + allenai/openassistant-guanaco-reformatted: 1.0 # Apache 2.0 + ai2-adapt-dev/codefeedback-single-turn-reformat-magicoder: 1.0 # MIT MagiCoder section of CodeFeedback + ai2-adapt-dev/aya_dataset-reformat: 1.0 # Apache 2.0 + ai2-adapt-dev/SlimOrca-reformat: 0.25 # MIT License + ai2-adapt-dev/Daring-Anteater-reformat: 1.0 # CC BY 4.0 + ai2-adapt-dev/WebInstructSub-reformat-apache: 0.1 # Apache 2.0 + ai2-adapt-dev/Table-GPT-All-train: 0.5 # MIT +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 +gradient_accumulation_steps: 4 # designed for 4 nodes +# gradient_accumulation_steps: 16 # designed for 1 nodes +gradient_checkpointing: true +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/ +with_tracking: true +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml new file mode 100644 index 000000000..491fc4502 --- /dev/null +++ b/configs/train_configs/sft/olmo_7b_0924_fw2_tulu_v3.4.yaml @@ -0,0 +1,28 @@ +# model_name_or_path: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +model_name_or_path: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +model_revision: main +use_flash_attn: true +# tokenizer_name: ai2-adapt-dev/OLMo-medium-peteish7-anneal-from-928646-50B-nowup-dclm07-flan +tokenizer_name: /adapt-data/ai2-llm/checkpoints/OLMo-medium/peteish7-anneal-from-928646-50B-nowup-dclm07-fw2/step11931-hf +use_slow_tokenizer: false # olmo models only use fast tokenizers +dataset_name: allenai/tulu-v3.4-mix-preview +max_seq_length: 4096 +preprocessing_num_workers: 128 +per_device_train_batch_size: 1 +# gradient_accumulation_steps: 4 # designed for 4 nodes +gradient_accumulation_steps: 8 # designed for 2 nodes +# gradient_accumulation_steps: 16 # designed for 1 nodes +gradient_checkpointing: true +learning_rate: 2.0e-06 +lr_scheduler_type: linear +warmup_ratio: 0.03 +weight_decay: 0.0 +num_train_epochs: 3 +output_dir: /output/ +with_tracking: true +reduce_loss: mean +report_to: + - wandb +logging_steps: 1 +checkpointing_steps: epoch +add_bos: true \ No newline at end of file diff --git a/open_instruct/dpo_tune.py b/open_instruct/dpo_tune.py index 92a783cfb..d9908fb93 100644 --- a/open_instruct/dpo_tune.py +++ b/open_instruct/dpo_tune.py @@ -67,6 +67,7 @@ from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, clean_last_n_checkpoints, get_datasets, get_last_checkpoint_path, @@ -469,6 +470,12 @@ def prepare_deepspeed(accelerator, model): def main(args: FlatArguments): + # try to import OLMo for automodel + if check_hf_olmo_availability(): + # allows AutoModel... 
to work with OLMo models that are not yet in transformers + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -673,7 +680,9 @@ def load_model(): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or ( + check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast) + ): # newer OLMo models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/finetune.py b/open_instruct/finetune.py index 355d103b5..ea9ced3c0 100644 --- a/open_instruct/finetune.py +++ b/open_instruct/finetune.py @@ -56,6 +56,7 @@ from open_instruct.model_utils import push_folder_to_hub, save_with_accelerate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, clean_last_n_checkpoints, get_datasets, get_last_checkpoint_path, @@ -453,6 +454,12 @@ def encode_sft_example(example, tokenizer, max_seq_length): def main(args: FlatArguments): + # try to import OLMo for automodel + if check_hf_olmo_availability(): + # allows AutoModel... to work with OLMo models that are not yet in transformers + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + # Initialize the accelerator. We will let the accelerator handle device placement for us in this example. # If we're using tracking, we also need to initialize it here and it will by default pick up all supported trackers # in the environment @@ -646,7 +653,9 @@ def main(args: FlatArguments): 0, 1, ], "LlamaTokenizer should only add one special token - the pad_token, or no tokens if pad token present." - elif isinstance(tokenizer, GPTNeoXTokenizerFast): + elif isinstance(tokenizer, GPTNeoXTokenizerFast) or ( + check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast) + ): # newer OLMo models use this tokenizer if tokenizer.bos_token is None: tokenizer.bos_token = tokenizer.eos_token diff --git a/open_instruct/ground_truth_utils.py b/open_instruct/ground_truth_utils.py index 5c1705123..1f4a21d68 100644 --- a/open_instruct/ground_truth_utils.py +++ b/open_instruct/ground_truth_utils.py @@ -149,4 +149,4 @@ def verify_flan_sample(model_output, ground_truth_answer): test_model_output = "<|assistant|>\nThe answer is $\\boxed{3.14}$" for sample in ds['train']: print(sample) - verify_ifeval_sample(test_model_output, sample['ground_truth']) + verify_ifeval_sample(test_model_output, sample['ground_truth']) \ No newline at end of file diff --git a/open_instruct/mix_data.py b/open_instruct/mix_data.py index 85fa815b1..6c488849d 100644 --- a/open_instruct/mix_data.py +++ b/open_instruct/mix_data.py @@ -14,9 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License.
-from open_instruct.finetune import FlatArguments # script for mixing and saving data +from open_instruct.finetune import FlatArguments from open_instruct.utils import ArgumentParserPlus, get_datasets # Run as module for local imports, e.g.: diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index f0a2d48bf..dd5dbeb1f 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -177,8 +177,8 @@ def get_reward( output = lm_backbone( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, - return_dict=True, + # position_ids=position_ids, + # return_dict=True, output_hidden_states=True, use_cache=False, # otherwise mistral-based RM would error out ) @@ -266,12 +266,12 @@ def forward( The output of the model, including hidden states. """ attention_mask = query_responses != pad_token_id - position_ids = attention_mask.cumsum(1) - attention_mask.long() + # position_ids = attention_mask.cumsum(1) - attention_mask.long() input_ids = torch.masked_fill(query_responses, ~attention_mask, 0) return model( input_ids=input_ids, attention_mask=attention_mask, - position_ids=position_ids, + # # position_ids=position_ids, return_dict=True, output_hidden_states=True, ) diff --git a/open_instruct/olmo_adapter/modeling_olmo2.py b/open_instruct/olmo_adapter/modeling_olmo2.py new file mode 100644 index 000000000..19c428edb --- /dev/null +++ b/open_instruct/olmo_adapter/modeling_olmo2.py @@ -0,0 +1,221 @@ +from typing import Callable, Optional, Union, List, Tuple +import torch +from torch import nn +from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss +from transformers import PreTrainedModel +from transformers.modeling_outputs import SequenceClassifierOutputWithPast +from hf_olmo import OLMoTokenizerFast, OLMoConfig, OLMoForCausalLM +from hf_olmo.modeling_olmo import OLMo, create_model_config_from_pretrained_config, ActivationCheckpointingStrategy + +class OLMoForSequenceClassification(PreTrainedModel): + + config_class = OLMoConfig + base_model_prefix = "model" + _no_split_modules = ["OLMoBlock"] + _supports_flash_attn_2 = True + _supports_sdpa = True + supports_gradient_checkpointing = True + + def __init__(self, config: OLMoConfig, model: Optional[OLMo] = None, init_params: bool = False): + super().__init__(config) + self._gradient_checkpointing_func: Optional[Callable] = None + self._gradient_checkpointing = False + + self.num_labels = config.num_labels + if not model: + model_config = create_model_config_from_pretrained_config(config) + # Initialize model (always on CPU to start with so we don't run out of GPU memory). + model_config.init_device = "cpu" + self.model = OLMo(model_config, init_params=init_params) + else: + self.model = model + + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + + @property + def gradient_checkpointing(self) -> bool: + return self._gradient_checkpointing + + @gradient_checkpointing.setter + def gradient_checkpointing(self, enabled: bool): + if self._gradient_checkpointing == enabled: + return + + # HF does not specify a way to pass checkpointing strategies, so we pick + # whole layer as our strategy. We can make this configurable later if needed. 
+ checkpointing_strategy = ActivationCheckpointingStrategy.whole_layer if enabled else None + self.model.set_activation_checkpointing( + checkpointing_strategy, checkpoint_func=self._gradient_checkpointing_func + ) + self._gradient_checkpointing = enabled + + def get_input_embeddings(self) -> torch.nn.Module: + return self.model.transformer.wte + + def set_input_embeddings(self, value: torch.nn.Module): + self.model.transformer.wte = value + + def get_output_embeddings(self): + if self.config.weight_tying: + return self.model.transformer.wte + else: + return self.model.transformer.ff_out + + def set_output_embeddings(self, value: torch.nn.Module): + if self.config.weight_tying: + self.model.transformer.wte = value + else: + self.model.transformer.ff_out = value + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> torch.nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.embedding_size`. + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + Arguments: + new_num_tokens (`int`, *optional*): + The new number of tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + pad_to_multiple_of (`int`, *optional*): + If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to + `None` will just pad the embedding to a multiple of `pad_to_multiple_of`. + This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability + `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more + details about this, or help on choosing the correct value for resizing, refer to this guide: + https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. + Note: + This method differs from the base class implementation by resizing the `embedding_size` attribute of the + model configuration instead of the `vocab_size`. It also includes a warning if the resized `embedding_size` + is less than the `vocab_size`. In OLMo, `embedding_size` refers to the dimensionality of the model's token + embeddings, while `vocab_size` refers to the number of unique tokens in the vocabulary. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + + # Update base model and current model config + self.config.embedding_size = model_embeds.weight.shape[0] + self.model.config.embedding_size = model_embeds.weight.shape[0] + + # Check if the embedding size is less than the vocab size + if self.config.embedding_size < self.config.vocab_size: + warning_message = ( + f"Resizing token embeddings to size {self.config.embedding_size}, which is less than the vocab size " + f"{self.config.vocab_size} defined in the model configuration. Make sure your tokenizer's vocabulary " + "size is less than or equal to the new token embedding size." 
+ ) + # log.warning(warning_message) + print(warning_message) + + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + """ + Forward pass for sequence classification with OLMo. + + Args: + input_ids: Input token IDs + attention_mask: Attention mask + position_ids: Position IDs for positional encoding + past_key_values: Past key values for incremental decoding + inputs_embeds: Pre-computed input embeddings + labels: Labels for computing loss + use_cache: Whether to use cached key/values + output_attentions: Whether to output attention weights + output_hidden_states: Whether to output hidden states + return_dict: Whether to return a ModelOutput object + + Returns: + SequenceClassifierOutputWithPast or tuple: Classification outputs + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return 
SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/open_instruct/olmo_adapter/olmo_1124_vllm.py b/open_instruct/olmo_adapter/olmo_1124_vllm.py new file mode 100644 index 000000000..a5a6061b0 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_1124_vllm.py @@ -0,0 +1,539 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + + +class FlippedSiluAndMul(SiluAndMul): + """OLMo is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). 
+ """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + attention_layer_norm = True + if attention_layer_norm: + # TODO: finish adding qk norm and norm_after + self.k_norm = RMSNorm( + (config.hidden_size // config.num_attention_heads) * config.num_key_value_heads, + eps=config.rms_norm_eps, + #elementwise_affine=config.attention_layer_norm_with_affine, + #bias=False, + ) + self.q_norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + #q = self.q_norm(q) + #k = self.k_norm(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + try: + self.intermediate_size = config.intermediate_size + except AttributeError: + if config.mlp_hidden_size is not None: + self.intermediate_size = config.mlp_hidden_size // 2 + else: + self.intermediate_size = (config.hidden_size * config.mlp_ratio) // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. 
+ self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + + self.norm_after = True + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + """ + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + """ + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + if self.norm_after: + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + if self.norm_after: + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: Union[OlmoConfig], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm( + config.hidden_size, + eps=config.rms_norm_eps, + #elementwise_affine=config.layer_norm_with_affine, + #bias=config.bias_for_layer_norm + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, hidden_size) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # Apply blocks one-by-one. + for layer_idx, decoder_layer in enumerate(self.layers): + # shape: (batch_size, seq_len, hidden_size) + hidden_states = decoder_layer( + positions, + hidden_states, + kv_caches[layer_idx], + attn_metadata, + ) + + # Apply final layer norm. 
+ # shape: (batch_size, seq_len or 1, hidden_size) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoNewForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = OlmoModel(config, cache_config, quant_config) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + #self.unpadded_vocab_size, + config.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + #org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + mapper = {} + "loaded weights -> uninitialized model weights" + for layer_i in range(self.config.num_hidden_layers): + mapper[f"model.layers.{layer_i}.post_attention_layernorm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight" + mapper[f"model.layers.{layer_i}.post_feedforward_layernorm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight" + # from rich.pretty import pprint + # pprint(mapper) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + # print("loaded", name, param) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[mapper.get(name, name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # print("loaded", name, param) + + +# loaded model.embed_tokens.weight +# loaded model.layers.0.self_attn.o_proj.weight +# loaded model.layers.0.self_attn.q_norm.weight +# loaded model.layers.0.self_attn.k_norm.weight +# loaded model.layers.0.mlp.down_proj.weight +# loaded model.layers.0.post_attention_layernorm.weight +# loaded model.layers.0.post_feedforward_layernorm.weight +# loaded model.layers.1.self_attn.o_proj.weight +# loaded model.layers.1.self_attn.q_norm.weight +# loaded model.layers.1.self_attn.k_norm.weight +# loaded model.layers.1.mlp.down_proj.weight +# loaded model.layers.1.post_attention_layernorm.weight +# loaded model.layers.1.post_feedforward_layernorm.weight +# loaded model.layers.2.self_attn.o_proj.weight +# loaded model.layers.2.self_attn.q_norm.weight +# loaded model.layers.2.self_attn.k_norm.weight +# loaded model.layers.2.mlp.down_proj.weight +# loaded model.layers.2.post_attention_layernorm.weight +# loaded model.layers.2.post_feedforward_layernorm.weight +# loaded model.norm.weight +# loaded lm_head.weight + +# OlmoNewForCausalLM( +# (model): OlmoModel( +# (embed_tokens): VocabParallelEmbedding(num_embeddings=100352, embedding_dim=4096, org_vocab_size=100352, num_embeddings_padded=100352, tp_size=1) +# (layers): ModuleList( +# (0-31): 32 x OlmoDecoderLayer( +# (self_attn): OlmoAttention( +# (qkv_proj): QKVParallelLinear(in_features=4096, output_features=12288, bias=False, tp_size=1, gather_output=False) +# (k_norm): RMSNorm(hidden_size=4096, eps=1e-06) +# (q_norm): RMSNorm(hidden_size=4096, eps=1e-06) +# (rotary_emb): RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=4096, base=500000, is_neox_style=True) +# (attn): Attention(head_size=128, num_heads=32, num_kv_heads=32, scale=0.08838834764831845, backend=FlashAttentionImpl) +# (o_proj): RowParallelLinear(input_features=4096, output_features=4096, bias=False, tp_size=1, reduce_results=True) +# ) +# (mlp): OlmoMLP( +# (gate_up_proj): MergedColumnParallelLinear(in_features=4096, output_features=22016, bias=False, tp_size=1, gather_output=False) +# (act_fn): FlippedSiluAndMul() +# (down_proj): RowParallelLinear(input_features=11008, output_features=4096, bias=False, tp_size=1, reduce_results=True) +# ) +# (input_layernorm): RMSNorm(hidden_size=4096, eps=1e-06) +# (post_attention_layernorm): RMSNorm(hidden_size=4096, eps=1e-06) +# ) +# ) +# (norm): RMSNorm(hidden_size=4096, eps=1e-06) +# ) +# (lm_head): ParallelLMHead(num_embeddings=100352, embedding_dim=4096, org_vocab_size=100352, num_embeddings_padded=100352, tp_size=1) +# (logits_processor): LogitsProcessor(vocab_size=100352, forg_vocab_size=100352, scale=1.0, logits_as_input=False) +# (sampler): Sampler() +# ) +# Olmo1124ForCausalLM( +# (model): Olmo1124Model( +# (embed_tokens): Embedding(100352, 4096, padding_idx=100277) +# (layers): ModuleList( +# (0-31): 32 x Olmo1124DecoderLayer( +# (self_attn): Olmo1124SdpaAttention( +# (q_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (k_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (v_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (o_proj): Linear(in_features=4096, out_features=4096, bias=False) +# (rotary_emb): Olmo1124RotaryEmbedding() +# (q_norm): Olmo1124RMSNorm((4096,), eps=1e-06) +# (k_norm): 
Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# (mlp): Olmo1124MLP( +# (gate_proj): Linear(in_features=4096, out_features=11008, bias=False) +# (up_proj): Linear(in_features=4096, out_features=11008, bias=False) +# (down_proj): Linear(in_features=11008, out_features=4096, bias=False) +# (act_fn): SiLU() +# ) +# (post_attention_layernorm): Olmo1124RMSNorm((4096,), eps=1e-06) +# (post_feedforward_layernorm): Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# ) +# (norm): Olmo1124RMSNorm((4096,), eps=1e-06) +# ) +# (lm_head): Linear(in_features=4096, out_features=100352, bias=False) +# ) + +# OLMoForCausalLM( +# (model): OLMo( +# (transformer): ModuleDict( +# (wte): Embedding(100352, 4096) +# (emb_drop): Dropout(p=0.0, inplace=False) +# (ln_f): RMSLayerNorm() +# (blocks): ModuleList( +# (0-31): 32 x OLMoSequentialBlock( +# (dropout): Dropout(p=0.0, inplace=False) +# (k_norm): RMSLayerNorm() +# (q_norm): RMSLayerNorm() +# (act): SwiGLU() +# (attn_out): Linear(in_features=4096, out_features=4096, bias=False) +# (ff_out): Linear(in_features=11008, out_features=4096, bias=False) +# (rotary_emb): RotaryEmbedding() +# (att_proj): Linear(in_features=4096, out_features=12288, bias=False) +# (ff_proj): Linear(in_features=4096, out_features=22016, bias=False) +# (attn_norm): RMSLayerNorm() +# (ff_norm): RMSLayerNorm() +# ) +# ) +# (ff_out): Linear(in_features=4096, out_features=100352, bias=False) +# ) +# ) +# ) \ No newline at end of file diff --git a/open_instruct/olmo_adapter/olmo_1124_vllm2.py b/open_instruct/olmo_adapter/olmo_1124_vllm2.py new file mode 100644 index 000000000..d7bcf8972 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_1124_vllm2.py @@ -0,0 +1,376 @@ +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, + maybe_prefix) + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). 
+ """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.attention_bias, + quant_config=quant_config, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = SiluAndMul() + + # Feed-forward output projection. + self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. 
+ self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +@support_torch_compile +class OlmoModel(nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") + self.norm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + # Apply blocks one-by-one. + for i in range(self.start_layer, self.end_layer): + # shape: (batch_size, seq_len, d_model) + hidden_states = self.layers[i]( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoForCausalLM(nn.Module, SupportsPP): + """ + Extremely barebones HF model wrapper. 
+ """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.model = OlmoModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = get_sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) \ No newline at end of file diff --git a/open_instruct/olmo_adapter/olmo_new.py b/open_instruct/olmo_adapter/olmo_new.py new file mode 100644 index 000000000..0eb0e3db4 --- /dev/null +++ b/open_instruct/olmo_adapter/olmo_new.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.40.1/src/transformers/models/olmo/modeling_olmo.py +# Copyright 2024 The vLLM team. +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import OlmoConfig +from hf_olmo.configuration_olmo import OLMoConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + + +class FlippedSiluAndMul(SiluAndMul): + """OLMo is trained with SwiGLU with flipped halves.""" + + def forward(self, x: torch.Tensor): + a, b = x.chunk(2, dim=-1) + flipped = torch.cat((b, a), dim=-1) + return super().forward(flipped) + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as + ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). 
+ """ + + def __init__( + self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + self.total_num_heads = config.num_attention_heads + + assert self.hidden_size % self.total_num_heads == 0 + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_sequence_length + self.rope_theta = config.rope_theta + self.clip_qkv = config.clip_qkv + + # Attention input projection. Projects x -> (q, k, v) + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=config.include_bias, + quant_config=quant_config, + ) + + if config.attention_layer_norm: + # TODO: finish adding qk norm and norm_after + self.k_norm = RMSNorm( + (config.d_model // config.n_heads) * config.effective_n_kv_heads, + eps=config.layer_norm_eps, + #elementwise_affine=config.attention_layer_norm_with_affine, + #bias=False, + ) + self.q_norm = RMSNorm( + config.hidden_size, + eps=config.layer_norm_eps, + ) + + # Rotary embeddings. + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = Attention(self.num_heads, + self.head_dim, + scale=self.scaling, + cache_config=cache_config, + quant_config=quant_config) + + # Attention output projection. + self.o_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.include_bias, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q = self.q_norm.forward_native(q) + k = self.k_norm.forward_native(k) + #q = self.q_norm(q) + #k = self.k_norm(k) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as + ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OlmoConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + try: + self.intermediate_size = config.intermediate_size + except AttributeError: + if config.mlp_hidden_size is not None: + self.intermediate_size = config.mlp_hidden_size // 2 + else: + self.intermediate_size = (config.d_model * config.mlp_ratio) // 2 + + # Feed-forward input projection. + self.gate_up_proj = MergedColumnParallelLinear( + self.hidden_size, + [self.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + ) + + # Activation function. + self.act_fn = FlippedSiluAndMul() + + # Feed-forward output projection. 
+ self.down_proj = RowParallelLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + quant_config=quant_config, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class OlmoDecoderLayer(nn.Module): + """ + This is a typical transformer block where the output is + computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + # Attention block. + self.self_attn = OlmoAttention(config, cache_config, quant_config) + + # MLP block. + self.mlp = OlmoMLP(config, quant_config) + + # LayerNorm + + self.norm_after = config.norm_after + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.layer_norm_eps) + + """ + self.input_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + elementwise_affine=False, + bias=False) + """ + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + residual = hidden_states + if self.norm_after: + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, kv_cache, + attn_metadata) + hidden_states = hidden_states + residual + + # MLP block. + residual = hidden_states + if self.norm_after: + hidden_states = self.mlp(hidden_states) + hidden_states = self.post_attention_layernorm(hidden_states) + else: + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: Union[OlmoConfig, OLMoConfig], + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding(config.embedding_size, + config.hidden_size) + self.layers = nn.ModuleList([ + OlmoDecoderLayer(config, cache_config, quant_config) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm( + config.hidden_size, + eps=config.layer_norm_eps, + #elementwise_affine=config.layer_norm_with_affine, + #bias=config.bias_for_layer_norm + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) + + # embed positions + hidden_states = inputs_embeds + + # Apply blocks one-by-one. + for layer_idx, decoder_layer in enumerate(self.layers): + # shape: (batch_size, seq_len, d_model) + hidden_states = decoder_layer( + positions, + hidden_states, + kv_caches[layer_idx], + attn_metadata, + ) + + # Apply final layer norm. 
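
OlmoDecoderLayer above switches between pre-norm and post-norm residual blocks based on config.norm_after. A schematic sketch of the two orderings (block and norm stand for the attention or MLP sub-layer and its RMSNorm; the function name and toy callables are illustrative only):

import torch

def residual_sublayer(x, block, norm, norm_after: bool):
    """Apply one residual sub-layer in either post-norm or pre-norm order."""
    residual = x
    if norm_after:            # post-norm: normalize the sub-layer output
        x = norm(block(x))
    else:                     # pre-norm: normalize the sub-layer input
        x = block(norm(x))
    return residual + x

# Toy usage with stand-in callables:
x = torch.ones(3)
out = residual_sublayer(x, block=lambda t: 2 * t, norm=lambda t: t / 2, norm_after=True)
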
+ # shape: (batch_size, seq_len or 1, d_model) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class OlmoNewForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, + config: OlmoConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.model = OlmoModel(config, cache_config, quant_config) + if config.weight_tying: + self.lm_head = self.model.embed_tokens + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + #self.unpadded_vocab_size, + config.embedding_size, + config.hidden_size, + org_num_embeddings=config.embedding_size, + #org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + self.logits_processor = LogitsProcessor(config.embedding_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + attn_metadata=attn_metadata, + ) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def _create_map(self): + mapper = {} + for layer_i in range(self.config.n_layers): + mapper[f"model.transformer.blocks.{layer_i}.att_proj.weight"] = f"model.layers.{layer_i}.self_attn.qkv_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.attn_out.weight"] = f"model.layers.{layer_i}.self_attn.o_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_proj.weight"] = f"model.layers.{layer_i}.mlp.gate_up_proj.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_out.weight"] = f"model.layers.{layer_i}.mlp.down_proj.weight" + + mapper[f"model.transformer.blocks.{layer_i}.attn_norm.weight"] = f"model.layers.{layer_i}.input_layernorm.weight" + mapper[f"model.transformer.blocks.{layer_i}.ff_norm.weight"] = f"model.layers.{layer_i}.post_attention_layernorm.weight" + mapper[f"model.transformer.blocks.{layer_i}.k_norm.weight"] = f"model.layers.{layer_i}.self_attn.k_norm.weight" + mapper[f"model.transformer.blocks.{layer_i}.q_norm.weight"] = f"model.layers.{layer_i}.self_attn.q_norm.weight" + + mapper["model.transformer.ln_f.weight"] = "model.norm.weight" + mapper["model.transformer.wte.weight"] = "model.embed_tokens.weight" + mapper["model.transformer.ff_out.weight"] = "lm_head.weight" + return mapper + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + mapper = self._create_map() + # print("mapper", mapper) + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models 
trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.weight_tying and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[mapper.get(name, name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl.py b/open_instruct/ppo_vllm_thread_ray_gtrl.py index 084d91943..de36c7489 100644 --- a/open_instruct/ppo_vllm_thread_ray_gtrl.py +++ b/open_instruct/ppo_vllm_thread_ray_gtrl.py @@ -1370,7 +1370,7 @@ def vllm_generate( # Ai2 logic: we use /output to store the artifacts of the job, so we # make a copy of the model to `/output` in the end. - if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0: + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output": shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) print("finished training") @@ -1472,7 +1472,7 @@ def launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = --preemptible \ --use_hf_tokenizer_template \ --beaker_image "nathanl/open_instruct_auto" \ - --upload_to_hf allenai/tulu-3-evals \ + --upload_to_hf {args.hf_metadata_dataset} \ --run_oe_eval_experiments \ --evaluate_on_weka \ --run_safety_evaluations \ diff --git a/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py new file mode 100644 index 000000000..d39f14fed --- /dev/null +++ b/open_instruct/ppo_vllm_thread_ray_gtrl_olmo.py @@ -0,0 +1,1833 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# --------------------------------------------------------------------- +# Part of the code is adapted from https://github.com/OpenRLHF/OpenRLHF +# which has the following license: +# Copyright [yyyy] [name of copyright owner] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import json +import logging +import os +import random +import shutil +import socket +import subprocess +import threading +import time +from argparse import Namespace +from dataclasses import asdict, dataclass, field +from queue import Empty, Queue +from typing import Any, Callable, Iterator, List, Literal, Optional, Tuple + +import deepspeed +import numpy as np +import pandas as pd +import ray +import torch +import torch.distributed as dist +import torch.nn.functional as F +import torch.utils +import torch.utils.data +import vllm +from datasets import Dataset, DatasetDict +from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus +from huggingface_hub import HfApi +from ray.util.placement_group import PlacementGroup, placement_group +from ray.util.queue import Queue as RayQueue +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy +from rich.pretty import pprint +from torch.utils.tensorboard import SummaryWriter +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForSequenceClassification, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + get_scheduler, +) +from transformers.integrations import HfDeepSpeedConfig +from vllm import SamplingParams + +from open_instruct.dataset_processor import ( + CHAT_TEMPLATES, + DATASET_SOURCE_KEY, + GROUND_TRUTHS_KEY, + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTGroundTruthDatasetProcessor, + SimpleGenerateCollatorWithGroundTruth, + visualize_token, +) +from open_instruct.model_utils import ( + ModelConfig, + apply_verifiable_reward, + disable_dropout_in_model, + exact_div, + first_true_indices, + forward, + get_reward, + print_rich_single_line_metrics, + print_rich_table, + push_folder_to_hub, + truncate_response, +) +from open_instruct.utils import ( + ArgumentParserPlus, + BeakerRuntimeConfig, + combine_dataset, + get_wandb_tags, + is_beaker_job, + maybe_get_beaker_config, + maybe_use_ai2_hf_entity, + maybe_use_ai2_wandb_entity, + upload_metadata_to_hf, +) +from open_instruct.vllm_utils2 import create_vllm_engines, init_process_group + +api = HfApi() +INVALID_LOGPROB = 1.0 + + +@dataclass +class Args: + # required dataset args + dataset_mixer: str = None + """A dictionary of datasets (local or HF) to sample from.""" + dataset_train_splits: List[str] = None + """The dataset splits to use for training""" + dataset_eval_mixer: Optional[str] = None + """A dictionary of datasets (local or HF) to sample from for evaluation""" + dataset_eval_splits: Optional[List[str]] = None + """The dataset splits to use for evaluation""" + dataset_mixer_dict: Optional[dict] = None + """The dataset mixer as a dictionary""" + dataset_eval_mixer_dict: Optional[dict] = None + """The dataset eval mixer as a dictionary""" + + # common args + exp_name: str = os.path.basename(__file__)[: -len(".py")] + """The name of this experiment""" + seed: int = 1 + """Seed of the experiment""" + run_name: Optional[str] = None + """A unique name of this run""" + + # optimizer args + eps: float = 1e-5 + """The epsilon value for the optimizer""" + learning_rate: float = 2e-5 + """The initial learning 
rate for AdamW optimizer.""" + lr_scheduler_type: Literal[ + "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup" + ] = "linear" + """Which scheduler to use""" + warm_up_steps: int = 0 + """Number of warm up steps for the scheduler""" + warmup_ratio: float = 0.0 + """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)""" + + # various batch sizes + num_train_epochs: int = 1 + """Number of epochs to train""" + gradient_accumulation_steps: Optional[int] = None + """The number of gradient accumulation steps""" + per_device_train_batch_size: Optional[int] = 1 + """The forward batch size per device (local_micro_batch_size)""" + per_device_eval_batch_size: Optional[int] = 1 + """The forward batch size per device for evaluation (local_micro_batch_size)""" + total_episodes: Optional[int] = 100000 + """The total number of episodes in the dataset""" + world_size: Optional[int] = None + """The number of processes (GPUs) to use""" + micro_batch_size: Optional[int] = None + """The micro batch size across devices (HF's `per_device_train_batch_size` * `world_size`)""" + local_rollout_batch_size: int = 64 + """The number of rollout episodes per iteration per device""" + rollout_batch_size: Optional[int] = None + """The number of rollout episodes per iteration""" + num_training_steps: Optional[int] = None + """The number of training_steps to train""" + num_evals: int = 4 + """The number of evaluations to run throughout training""" + eval_freq: Optional[int] = None + """The frequency of evaluation steps""" + local_dataloader_batch_size: Optional[int] = None + """The batch size per GPU for the dataloader""" + save_freq: int = -1 + """How many train steps to save the model""" + + # online settings + num_epochs: int = 4 + """the number of epochs to train""" + num_mini_batches: int = 1 + """Number of minibatches to split a batch into""" + local_mini_batch_size: int = 64 + """the mini batch size per GPU""" + mini_batch_size: Optional[int] = None + """the mini batch size across GPUs""" + local_rollout_forward_batch_size: int = 64 + """per rank no grad forward pass in the rollout phase""" + reward_model_path: str = "EleutherAI/pythia-160m" + """the path to the reward model""" + reward_model_revision: Optional[str] = None + """the revision of the reward model""" + init_value_from_scratch: bool = False + """whether to initialize the value model from scratch""" + + # generation config + response_length: int = 53 + """the length of the response""" + stop_token: Optional[Literal["eos", "period"]] = None + """the stop token""" + stop_token_id: Optional[int] = None + """the truncation token id""" + min_response_length: int = 0 + """stop only after this many tokens""" + temperature: float = 0.7 + """the sampling temperature""" + penalty_reward_value: float = -1.0 + """the reward value for responses that do not contain `stop_token_id`""" + non_stop_penalty: bool = False + """whether to penalize responses that do not contain `stop_token_id`""" + number_samples_per_prompt: int = 1 + """the number of samples to generate per prompt, useful for easy-star""" + + # online PPO specific args + beta: float = 0.05 + """the beta value of the RLHF objective (KL coefficient)""" + whiten_rewards: bool = False + """whether to whiten the rewards""" + cliprange: float = 0.2 + """the clip range""" + vf_coef: float = 0.1 + """the value function coefficient""" + cliprange_value: float = 0.2 + """the clip range for the value function""" + gamma: float = 1 + """the discount 
factor""" + lam: float = 0.95 + """the lambda value for GAE""" + kl_estimator: Literal["kl1", "kl2", "kl3"] = "kl1" + """the KL estimator to use""" + apply_verifiable_reward: bool = False + """whether to apply verifiable reward""" + reward_model_multiplier: float = 1.0 + """the reward model multiplier, for down/upscaling the reward model output""" + answer_extraction_model: str = None + + # async setting + async_mode: bool = True + """Whether to run the generation in async mode which learns from the second latest policy like Cleanba (https://arxiv.org/abs/2310.00036)""" + + # ray + actor_num_gpus_per_node: List[int] = field(default_factory=lambda: [1]) + """number of gpus per node for actor""" + vllm_num_engines: int = 1 + """number of vLLM Engines, set to 0 to disable vLLM""" + vllm_tensor_parallel_size: int = 1 + """tensor parallel size of vLLM Engine for multi-GPU inference""" + vllm_sync_backend: str = "nccl" + """DeepSpeed -> vLLM weight sync backend""" + enable_prefix_caching: bool = False + """whether to enable prefix caching""" + deepspeed_stage: int = 0 + """the deepspeed stage""" + gather_whole_model: bool = True + """whether to gather the whole model to boardcast (not doable for 70B but can be faster for 8B)""" + + # wandb and HF tracking configs + with_tracking: bool = False + """If toggled, this experiment will be tracked with Weights and Biases""" + wandb_project_name: str = "open_instruct_internal" + """The wandb's project name""" + wandb_entity: Optional[str] = None + """The entity (team) of wandb's project""" + push_to_hub: bool = True + """Whether to upload the saved model to huggingface""" + hf_entity: Optional[str] = None + """The user or org name of the model repository from the Hugging Face Hub""" + hf_repo_id: Optional[str] = None + """The id of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_revision: Optional[str] = None + """The revision of the saved model in the Hugging Face Hub (can be autoset if not given)""" + hf_repo_url: Optional[str] = None + """The url of the saved model in the Hugging Face Hub (will be autoset)""" + output_dir: Optional[str] = None + """Where to save the model""" + checkpoint_output_dir: Optional[str] = None + """Where to save the model checkpoints in case of preemption""" + + # Ai2 specific settings + try_launch_beaker_eval_jobs: bool = True + """Whether to launch beaker evaluation jobs after training""" + try_launch_beaker_eval_jobs_on_weka: bool = False + """Whether to launch beaker evaluation jobs after training on weka""" + oe_eval_tasks: Optional[List[str]] = None + """The beaker evaluation tasks to launch""" + hf_metadata_dataset: Optional[str] = "allenai/tulu-3-evals" + """What dataset to upload the metadata to. 
If unset, don't upload metadata""" + + def __post_init__(self): + self.dataset_mixer_dict, self.dataset_mixer = process_dataset_mixer(self.dataset_mixer) + if self.dataset_eval_mixer is not None: + self.dataset_eval_mixer_dict, self.dataset_eval_mixer = process_dataset_mixer(self.dataset_eval_mixer) + + +def process_dataset_mixer(value) -> Tuple[Optional[dict], Optional[str]]: + # if passed through cli: convert the dataset mixers to dictionaries + if isinstance(value, str): + return json.loads(value), value + # if passed through yaml: convert the dataset mixers to strings + elif isinstance(value, dict): + return value, json.dumps(value) + else: + raise ValueError("Input must be either a string or a dictionary") + + +def calculate_runtime_args(args: Args, model_config: ModelConfig): + """calculate (in-place) runtime args such as the effective batch size, word size, etc.""" + # accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps) + # args.world_size = accelerator.num_processes + args.run_name = f"{args.exp_name}__{args.seed}__{int(time.time())}" + args.gradient_accumulation_steps = exact_div( + args.local_mini_batch_size, + args.per_device_train_batch_size, + "`local_mini_batch_size` must be a multiple of `per_device_train_batch_size`", + ) + args.world_size = sum(args.actor_num_gpus_per_node) + args.micro_batch_size = int(args.per_device_train_batch_size * args.world_size) + args.rollout_batch_size = int(args.local_rollout_batch_size * args.world_size) + args.mini_batch_size = int(args.local_mini_batch_size * args.world_size) + args.num_training_steps = args.total_episodes // (args.rollout_batch_size * args.number_samples_per_prompt) + args.eval_freq = max(1, args.num_training_steps // args.num_evals) + # PPO logic: do checks and set up dataloader batch size + if args.whiten_rewards: + assert ( + args.local_mini_batch_size >= 8 + ), f"Per-rank minibatch size {args.local_mini_batch_size} is insufficient for whitening" + args.local_dataloader_batch_size = args.rollout_batch_size + if args.push_to_hub: + if args.hf_repo_id is None: # auto-generate one + args.hf_repo_id = "open_instruct_dev" + if args.hf_entity is None: # first try to use AI2 entity + args.hf_entity = maybe_use_ai2_hf_entity() + if args.hf_entity is None: # then try to use the user's entity + args.hf_entity = HfApi().whoami()["name"] + args.hf_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" + if args.hf_repo_revision is None: # auto-generate one + args.hf_repo_revision = args.run_name + args.hf_repo_url = f"https://huggingface.co/{args.hf_repo_id}/tree/{args.hf_repo_revision}" + + if args.with_tracking: + if args.wandb_entity is None: + args.wandb_entity = maybe_use_ai2_wandb_entity() + + +def get_train_ds_config( + offload, + adam_offload=False, + stage=0, + bf16=True, + max_norm=1.0, + zpg=8, + grad_accum_dtype=None, + disable_trace_cache=True, +): + device = "cpu" if offload else "none" + zero_opt_dict = { + "stage": stage, + "offload_param": {"device": device}, + "offload_optimizer": { + "device": "cpu" if adam_offload else "none", + "pin_memory": True, + }, + "sub_group_size": "auto", + "stage3_max_live_parameters": "auto", + "stage3_max_reuse_distance": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_prefetch_bucket_size": "auto", + "reduce_bucket_size": "auto", + # # ZeRO++ + # "zero_hpz_partition_size": zpg, + # "zero_quantized_weights": False, + # "zero_quantized_gradients": False, + } + if disable_trace_cache: + zero_opt_dict["stage3_prefetch_bucket_size"] = 0 + 
zero_opt_dict["stage3_max_live_parameters"] = 0 + zero_opt_dict["stage3_max_reuse_distance"] = 0 + + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "gradient_clipping": max_norm, + "prescale_gradients": False, + "wall_clock_breakdown": False, + "data_types": {"grad_accum_dtype": grad_accum_dtype if grad_accum_dtype else "fp32"}, + } + + +def get_eval_ds_config( + offload, + stage=0, + bf16=True, +): + zero_opt_dict = { + "stage": stage, + "stage3_param_persistence_threshold": "auto", + "offload_param": { + "device": "cpu" if offload else "none", + "pin_memory": True, + }, + } + return { + "steps_per_print": 100, + "zero_optimization": zero_opt_dict, + "bf16": { + "enabled": bf16, + }, + "prescale_gradients": False, + "wall_clock_breakdown": False, + } + + +def get_optimizer_grouped_parameters( + model, + weight_decay, + no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], +): + optimizer_grouped_parameters = [ + { + "params": [ + p + for n, p in model.named_parameters() + if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": weight_decay, + }, + { + "params": [ + p + for n, p in model.named_parameters() + if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) + ], + "weight_decay": 0.0, + }, + ] + return optimizer_grouped_parameters + + +def _z3_params_to_fetch(param_list): + return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] + + +def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + if axis is not None: + return (values * mask).sum(axis=axis) / mask.sum(axis=axis) + else: + return (values * mask).sum() / mask.sum() + + +def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: + """Compute variance of tensor with masked values.""" + mean = masked_mean(values, mask) + centered_values = values - mean + variance = masked_mean(centered_values**2, mask) + if unbiased: + mask_sum = mask.sum() + if mask_sum == 0: + raise ValueError( + "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" + "try increase the `mini_batch_size` or `gradient_accumulation_steps`" + ) + # note that if mask_sum == 1, then there is a division by zero issue + # to avoid it you just need to use a larger minibatch_size + bessel_correction = mask_sum / (mask_sum - 1) + variance = variance * bessel_correction + return variance + + +def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: + """Whiten values with masked values.""" + mean, var = masked_mean(values, mask), masked_var(values, mask) + whitened = (values - mean) * torch.rsqrt(var + 1e-8) + if not shift_mean: + whitened += mean + return whitened + + +def remove_padding(sequences, pad_token_id): + return [[inneritem for inneritem in item if inneritem != pad_token_id] for item in sequences] + + +class ShufflingIterator: + def __init__(self, data: np.ndarray, batch_size: int, seed: Optional[int] = None): + self.data = data.copy() + self.batch_size = batch_size + self.index = 0 + self.rng = np.random.default_rng(seed) + self.rng.shuffle(self.data) + + # Ensure the effective dataset size is divisible by batch_size + self.effective_size = len(self.data) - (len(self.data) % batch_size) + + def __iter__(self) -> Iterator[List[int]]: + return self + + def 
__next__(self) -> List[int]: + if self.index >= self.effective_size: + self.index = 0 + self.rng.shuffle(self.data) + + end_index = self.index + self.batch_size + batch = self.data[self.index : end_index].tolist() + self.index = end_index + + return batch + + +class RayProcess: + def __init__(self, world_size, rank, local_rank, master_addr, master_port): + logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", + ) + + # olmo 1124; pip install git+https://github.com/vwxyzjn/transformers.git@olmo1124_classification + from transformers.models.olmo2.modeling_olmo2 import Olmo2ForSequenceClassification, Olmo2Config + AutoModelForSequenceClassification.register(Olmo2Config, Olmo2ForSequenceClassification) + self.world_size = world_size + self.rank = rank + self.local_rank = local_rank + self.master_addr = master_addr if master_addr else self.get_current_node_ip() + self.master_port = master_port if master_port else self.get_free_port() + os.environ["MASTER_ADDR"] = self.master_addr + os.environ["MASTER_PORT"] = str(self.master_port) + os.environ["WORLD_SIZE"] = str(self.world_size) + os.environ["RANK"] = str(self.rank) + # NOTE: Ray will automatically set the CUDA_VISIBLE_DEVICES + # environment variable for each actor, so always set device to 0 + # os.environ["LOCAL_RANK"] = str(self._local_rank) + os.environ["LOCAL_RANK"] = "0" + random.seed(self.rank) + np.random.seed(self.rank) + torch.manual_seed(self.rank) + + @staticmethod + def get_current_node_ip(): + address = ray._private.services.get_node_ip_address() + # strip ipv6 address + return address.strip("[]") + + @staticmethod + def get_free_port(): + with socket.socket() as sock: + sock.bind(("", 0)) + return sock.getsockname()[1] + + def get_master_addr_port(self): + return self.master_addr, self.master_port + + def empty_cache(self) -> None: + torch.cuda.empty_cache() + + +@ray.remote(num_gpus=1) +class PolicyTrainerRayProcess(RayProcess): + def from_pretrained( + self, args: Args, model_config: ModelConfig, beaker_config: BeakerRuntimeConfig, wandb_url: str + ): + self.args = args + self.model_config = model_config + self.beaker_config = beaker_config + self.wandb_url = wandb_url + torch.cuda.set_device(self.local_rank) + deepspeed.init_distributed() + + ds_config = get_train_ds_config( + offload=False, + adam_offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. 
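
The comment above refers to the non-Trainer DeepSpeed integration: constructing an HfDeepSpeedConfig from a ZeRO-3 config before calling from_pretrained lets transformers partition the weights with deepspeed.zero.Init instead of materializing the full model on every rank. A condensed sketch of the pattern, assuming a distributed process group has already been initialized as in from_pretrained above (the config values and model id are illustrative):

import torch
from transformers import AutoModelForCausalLM
from transformers.integrations import HfDeepSpeedConfig

ds_config = {
    "zero_optimization": {"stage": 3},
    "train_micro_batch_size_per_gpu": 1,
    "train_batch_size": 8,                  # illustrative; the trainer fills these in from Args
    "bf16": {"enabled": True},
}
dschf = HfDeepSpeedConfig(ds_config)        # must stay referenced while from_pretrained runs
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/pythia-160m",               # example model id (the default reward_model_path)
    torch_dtype=torch.bfloat16,
)
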
+ if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.original_tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision + ) + self.policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.policy) + self.policy.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.policy, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.policy.parameters(), lr=args.learning_rate) + num_training_steps = args.num_training_steps * args.num_train_epochs * args.num_epochs + warm_up_steps = args.warm_up_steps + if args.warmup_ratio >= 0.0: + warm_up_steps = int(num_training_steps * args.warmup_ratio) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, + ) + print(ds_config) + self.model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.policy, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.model.train() + + # value model + self.value_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + if args.init_value_from_scratch: + self.value_model.init_weights() # re-initialize the value model from scratch + disable_dropout_in_model(self.value_model) + self.value_model.gradient_checkpointing_enable() + # AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam + # AdamOptimizer = FusedAdam + # weight_decay = 0.0 + # optim_params = get_optimizer_grouped_parameters(self.value_model, weight_decay) + # self.optimizer = AdamOptimizer(optim_params, lr=args.learning_rate) + self.optimizer = torch.optim.AdamW(self.value_model.parameters(), lr=args.learning_rate) + scheduler = get_scheduler( + args.lr_scheduler_type, + optimizer=self.optimizer, + num_warmup_steps=warm_up_steps, + num_training_steps=num_training_steps, + ) + self.value_model, self.optimizer, _, self.scheduler = deepspeed.initialize( + model=self.value_model, + optimizer=self.optimizer, + config=ds_config, + lr_scheduler=scheduler, + dist_init_required=True, + ) + self.value_model.train() + + # reference model + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + # Costa: MAGIC: it's actually needed to initialize this `dschf`, so + # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration + # next line instructs transformers to partition the model directly over multiple gpus using + # deepspeed.zero.Init when model's `from_pretrained` method is called. 
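
Earlier in from_pretrained, the scheduler's warmup length is derived either from warm_up_steps or, when warmup_ratio is set, from the total step count. A tiny arithmetic sketch of that conversion with made-up numbers:

# Mirrors the warmup handling in from_pretrained above; all numbers are illustrative.
num_training_steps = 400 * 1 * 4        # args.num_training_steps * num_train_epochs * num_epochs
warm_up_steps = 0                       # args.warm_up_steps
warmup_ratio = 0.03                     # args.warmup_ratio
if warmup_ratio >= 0.0:                 # the ratio, when given, overrides warm_up_steps
    warm_up_steps = int(num_training_steps * warmup_ratio)
print(warm_up_steps)                    # -> 48
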
+ if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: + dschf = HfDeepSpeedConfig(ds_config) + else: + dschf = None + print(f"{dschf=}") + + self.ref_policy: PreTrainedModel = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, + revision=model_config.model_revision, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.ref_policy) + self.ref_policy, *_ = deepspeed.initialize(model=self.ref_policy, config=ds_config) + self.ref_policy.eval() + + # reward model + self.reward_model: PreTrainedModel = AutoModelForSequenceClassification.from_pretrained( + args.reward_model_path, + revision=args.reward_model_revision, + num_labels=1, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_2", + use_cache=False, + ) + disable_dropout_in_model(self.reward_model) + ds_config = get_eval_ds_config( + offload=False, + stage=args.deepspeed_stage, + bf16=True, + ) + ds_config["train_micro_batch_size_per_gpu"] = args.per_device_train_batch_size + ds_config["train_batch_size"] = args.mini_batch_size + self.reward_model, *_ = deepspeed.initialize(model=self.reward_model, config=ds_config) + self.reward_model.eval() + + def get_vocab_size(self): + return self.policy.config.vocab_size + + def forward( + self, + query_response: torch.LongTensor, + response: torch.LongTensor, + pad_token_id: int, + context_length: int, + temperature: float, + ) -> torch.Tensor: + output = forward(self.model, query_response, pad_token_id) + logits = output.logits[:, context_length - 1 : -1] + logits /= temperature + 1e-7 + all_logprob = F.log_softmax(logits, dim=-1) + logprob = torch.gather(all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + return logprob + + def train( + self, + train_dataset: Dataset, + eval_dataset: Dataset, + tokenizer: PreTrainedTokenizer, + vllm_engines: List[ray.actor.ActorHandle], + metrics_queue: RayQueue, + data_collator: Callable, + ): + torch.set_printoptions(precision=4, sci_mode=False) + + args = self.args + + accelerator = Namespace() + accelerator.process_index = self.rank + accelerator.num_processes = self.world_size + accelerator.is_main_process = self.rank == 0 + torch.distributed.barrier() + if self.rank == 0: + master_address = ray._private.services.get_node_ip_address() + with socket.socket() as sock: + sock.bind(("", 0)) + master_port = sock.getsockname()[1] + vllm_num_engines, vllm_tensor_parallel_size = ( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + ) + world_size = vllm_num_engines * vllm_tensor_parallel_size + 1 + backend = args.vllm_sync_backend + # https://github.com/OpenRLHF/OpenRLHF/issues/313 + if vllm.__version__ > "0.4.2" and os.getenv("NCCL_P2P_DISABLE", "0") == "0": + backend = "gloo" + print( + "Warning: using --vllm_sync_backend=gloo for vLLM version > 0.4.2 (or export NCCL_P2P_DISABLE=1)" + ) + refs = [ + engine.init_process_group.remote( + master_address, + master_port, + i * vllm_tensor_parallel_size + 1, + world_size, + "openrlhf", + backend=backend, + ) + for i, engine in enumerate(vllm_engines) + ] + self.model_update_group = init_process_group( + backend=backend, + init_method=f"tcp://{master_address}:{master_port}", + world_size=world_size, + rank=0, + group_name="openrlhf", + ) + ray.get(refs) + torch.distributed.barrier() + + def broadcast_to_vllm(): + # avoid OOM + torch.cuda.empty_cache() + model = self.model.module + count, num_params = 0, len(list(model.named_parameters())) + refss = [] + if 
args.gather_whole_model: + with deepspeed.zero.GatheredParameters(model.parameters(), enabled=args.deepspeed_stage == 3): + for name, param in model.named_parameters(): + count += 1 # empty_cache at last param + # Fire all vllm engines for broadcast + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + else: # broadcast each parameter independently + for name, param in model.named_parameters(): + count += 1 + if torch.distributed.get_rank() == 0: + shape = param.shape if args.deepspeed_stage != 3 else param.ds_shape + refs = [ + engine.update_weight.remote( + name, dtype=param.dtype, shape=shape, empty_cache=count == num_params + ) + for engine in vllm_engines + ] + refss.extend(refs) + with deepspeed.zero.GatheredParameters([param], enabled=args.deepspeed_stage == 3): + if torch.distributed.get_rank() == 0: + torch.distributed.broadcast(param.data, 0, group=self.model_update_group) + if torch.distributed.get_rank() == 0: + ray.get(refss) + + # broadcast_to_vllm() + if args.stop_token: + if args.stop_token == "eos": + args.stop_token_id = tokenizer.eos_token_id + if args.stop_token == "period": + args.stop_token_id = tokenizer.encode(".")[0] + # data_collator = SimpleGenerateCollator(pad_token_id=tokenizer.pad_token_id) + train_dataset_idxs = np.arange(len(train_dataset)) + shuffling_iter = ShufflingIterator(train_dataset_idxs, args.rollout_batch_size, seed=args.seed) + + # hack to left pad + def repeat_generator(): + while True: + batch_idxs = next(shuffling_iter) + yield [train_dataset[i] for i in batch_idxs] + + iter_dataloader = iter(repeat_generator()) + generation_config = SamplingParams( + temperature=args.temperature, + top_p=1.0, + max_tokens=args.response_length, + include_stop_str_in_output=True, + n=args.number_samples_per_prompt, + ) + # print("setup async queues") + param_prompt_Q = None + response_ids_Q = None + evaluation_Q = None + response_ids_Q = Queue(maxsize=1) + param_prompt_Q = Queue(maxsize=1) + evaluation_Q = Queue(maxsize=1) + num_eval_samples = 32 + sample_evaluation_prompt_token_ids = None + if eval_dataset is not None: + sample_evaluation_prompt_token_ids = eval_dataset[:num_eval_samples][INPUT_IDS_PROMPT_KEY] + + def vllm_generate( + generation_config: SamplingParams, + response_ids_Q: Queue, + param_prompt_Q: Queue, + num_training_steps: int, + sample_evaluation_prompt_token_ids: Optional[List[int]], + evaluation_Q: Queue, + eval_freq: int, + resume_training_step: int, + ): + llm = vllm_engines[0] + for training_step in range(resume_training_step, num_training_steps + 1): + items = param_prompt_Q.get() + if items is None: + break + unwrapped_model, g_queries_list = items + # if unwrapped_model is not None: + generation_start_time = time.time() + + outputs = ray.get( + llm.generate.remote( + sampling_params=generation_config, prompt_token_ids=g_queries_list, use_tqdm=False + ) + ) + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + print(f"🔥🔥🔥 Generation time: {time.time() - generation_start_time:.2f} seconds") + response_ids_Q.put(response_ids) + + if sample_evaluation_prompt_token_ids is not None and (training_step - 1) % eval_freq == 0: + outputs = ray.get( + llm.generate.remote( 
+ prompt_token_ids=sample_evaluation_prompt_token_ids, + sampling_params=generation_config, + use_tqdm=False, + ) + ) + # for evaluation, even if we have multiple outputs, we only look at one of them for simplicity + response_ids = [list(output.outputs[0].token_ids) for output in outputs] + evaluation_Q.put(response_ids) + + resume_training_step = 1 + if accelerator.is_main_process: + thread = threading.Thread( + target=vllm_generate, + args=( + generation_config, + response_ids_Q, + param_prompt_Q, + args.num_training_steps, + sample_evaluation_prompt_token_ids, + evaluation_Q, + args.eval_freq, + resume_training_step, + ), + ) + thread.start() + print("vllm generate thread starts") + + # set up the metrics and initial states + device = torch.device(self.local_rank) + g_vllm_responses = torch.zeros( + (args.rollout_batch_size * args.number_samples_per_prompt, args.response_length), + device=device, + dtype=torch.long, + ) + stats_shape = ( + args.num_epochs, + args.num_mini_batches * args.number_samples_per_prompt, + args.gradient_accumulation_steps, + ) + approxkl_stats = torch.zeros(stats_shape, device=device) + pg_clipfrac_stats = torch.zeros(stats_shape, device=device) + pg_loss_stats = torch.zeros(stats_shape, device=device) + vf_loss_stats = torch.zeros(stats_shape, device=device) + vf_clipfrac_stats = torch.zeros(stats_shape, device=device) + entropy_stats = torch.zeros(stats_shape, device=device) + ratio_stats = torch.zeros(stats_shape, device=device) + local_metrics = torch.zeros((20,), device=device) + episode = args.rollout_batch_size * (resume_training_step - 1) + + # training loop + start_time = time.time() + global_data = next(iter_dataloader) + data = data_collator( + global_data[self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size] + ) + global_queries = data_collator(global_data)[ + INPUT_IDS_PROMPT_KEY + ].tolist() # can be simplified since we `remove_padding` later anyway + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + + answer_extraction_model = None + answer_extraction_tokenizer = None + # for _ in range(1, resume_training_step): # we didn't store scheduler state + # scheduler.step() + + for training_step in range(resume_training_step, args.num_training_steps + 1): + episode += args.rollout_batch_size * args.number_samples_per_prompt # each sample is an episode + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + if accelerator.is_main_process: + df = None + try: + evaluation_responses = evaluation_Q.get(timeout=0.01) + print("🔥🔥🔥 Evaluation responses received") + table = {} + table["prompt"] = tokenizer.batch_decode(sample_evaluation_prompt_token_ids) + table["response"] = tokenizer.batch_decode(evaluation_responses) + table["response"] = [item.replace(tokenizer.pad_token, "") for item in table["response"]] + df = pd.DataFrame(table) + del table + except Empty: + print("🙈 Evaluation responses not received") + + # (optionally) evaluate the model + if args.async_mode: + if training_step != 1: + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = 
data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + else: + if training_step != 1: + # NOTE: important: the indent here is different for sync mode + # we also set to use `queries = queries_next` immediately + global_data = next(iter_dataloader) + data = data_collator( + global_data[ + self.rank * args.local_rollout_batch_size : (self.rank + 1) * args.local_rollout_batch_size + ] + ) + global_queries = data_collator(global_data)[INPUT_IDS_PROMPT_KEY].tolist() + queries_next = data[INPUT_IDS_PROMPT_KEY].to(device) + ground_truths_next = data[GROUND_TRUTHS_KEY] + datasets_next = data[DATASET_SOURCE_KEY] + start_time = time.time() + broadcast_to_vllm() + print( + f"🔥🔥🔥 Loading weights using shared memory; Time to load weights: {time.time() - start_time:.2f} seconds" + ) + if accelerator.is_main_process: + param_prompt_Q.put((None, remove_padding(global_queries, tokenizer.pad_token_id))) + queries = queries_next + ground_truths = ground_truths_next + datasets = datasets_next + + torch.cuda.empty_cache() + # print('get reward stuff starts') + # if we generate multiple samples per prompt, we need to repeat the queries and ground truths + # to match the vllm outputs. + if args.number_samples_per_prompt > 1: + queries = queries.repeat_interleave(args.number_samples_per_prompt, dim=0) + ground_truths = [gt for gt in ground_truths for _ in range(args.number_samples_per_prompt)] + datasets = [ds for ds in datasets for _ in range(args.number_samples_per_prompt)] + + training_time_start = time.time() + with torch.no_grad(): + context_length = queries.shape[1] + responses = [] + postprocessed_responses = [] + logprobs = [] + ref_logprobs = [] + scores = [] + verifiable_counts = [] + sequence_lengths = [] + values = [] + if accelerator.is_main_process: + g_response_token_ids = response_ids_Q.get() + DUMMY_PAD_TOKEN = 0 # we can't use tokenizer.pad_token_id because it's outside vocab and `torch.gather(all_logprob, 2, response.unsqueeze(-1))` will error out + g_padded_response_ids = [ + response + [DUMMY_PAD_TOKEN] * (args.response_length - len(response)) + for response in g_response_token_ids + ] + g_padded_response_ids = torch.tensor(g_padded_response_ids, device=device) + g_vllm_responses[:] = g_padded_response_ids + dist.broadcast(g_vllm_responses, src=0) + local_vllm_responses = g_vllm_responses[ + accelerator.process_index * queries.shape[0] : (accelerator.process_index + 1) * queries.shape[0] + ] + # print(f"{local_vllm_responses.shape=}, {local_vllm_responses=}") + query_responses = torch.cat((queries, local_vllm_responses), 1) + for i in range(0, queries.shape[0], args.local_rollout_forward_batch_size): + # print(f"get reward stuff starts {i=}") + query = queries[i : i + args.local_rollout_forward_batch_size] + query_response = query_responses[i : i + args.local_rollout_forward_batch_size] + response = query_response[:, context_length:] + + logprob = self.forward( + query_response, response, tokenizer.pad_token_id, context_length, args.temperature + ) + torch.cuda.empty_cache() + + ref_output = forward(self.ref_policy, query_response, tokenizer.pad_token_id) + ref_logits = ref_output.logits[:, context_length - 1 : -1] + ref_logits /= 
args.temperature + 1e-7 + ref_all_logprob = F.log_softmax(ref_logits, dim=-1) + ref_logprob = torch.gather(ref_all_logprob, 2, response.unsqueeze(-1)).squeeze(-1) + del ref_output, ref_logits, ref_all_logprob + torch.cuda.empty_cache() + + # Response Processing 1. truncate response after the first occurrence of `stop_token_id` + postprocessed_response = response + if args.stop_token_id is not None: # handle the edge case when stop_token_id exists but is 0 + postprocessed_response = truncate_response( + args.stop_token_id, tokenizer.pad_token_id, response + ) + # print("get reward stuff starts 2") + # Response Processing 2. run reward model on the truncated responses + postprocessed_query_response = torch.cat((query, postprocessed_response), 1) + sequence_length = first_true_indices(postprocessed_response == tokenizer.pad_token_id) - 1 + _, score, _ = get_reward( + self.reward_model, postprocessed_query_response, tokenizer.pad_token_id, context_length + ) + # if self.rank == 0 and i == 0: + # print(postprocessed_query_response[0].tolist(), tokenizer.decode(postprocessed_query_response[0])) + if args.reward_model_multiplier != 1.0: + score *= args.reward_model_multiplier + # also apply verifiable reward + if args.apply_verifiable_reward: + # we need to batch the gt to match query. + ground_truth = ground_truths[i : i + args.local_rollout_forward_batch_size] + dataset = datasets[i : i + args.local_rollout_forward_batch_size] + verifiable_reward, verifiable_count = apply_verifiable_reward( + postprocessed_query_response, + tokenizer, + ground_truth, + dataset, + verify_reward=10, + answer_extraction_model=answer_extraction_model, + answer_extraction_tokenizer=answer_extraction_tokenizer, + ) + score += verifiable_reward + else: + verifiable_count = torch.tensor([0.0], device=device).float() + full_value, _, _ = get_reward( + self.value_model, query_response, tokenizer.pad_token_id, context_length + ) + value = full_value[:, context_length - 1 : -1].squeeze(-1) + + responses.append(response) + postprocessed_responses.append(postprocessed_response) + logprobs.append(logprob) + ref_logprobs.append(ref_logprob) + sequence_lengths.append(sequence_length) + scores.append(score) + values.append(value) + verifiable_counts.append(verifiable_count) + # print(f"get reward stuff starts 5") + + responses = torch.cat(responses, 0) + postprocessed_responses = torch.cat(postprocessed_responses, 0) + logprobs = torch.cat(logprobs, 0) + ref_logprobs = torch.cat(ref_logprobs, 0) + sequence_lengths = torch.cat(sequence_lengths, 0) + scores = torch.cat(scores, 0) + verifiable_counts = torch.cat(verifiable_counts, 0) + verifiable_correct_rate = verifiable_counts.sum() / queries.shape[0] + values = torch.cat(values, 0) + # print(f"get reward stuff finished") + del (logprob, ref_logprob, full_value, value, score) + gc.collect() + torch.cuda.empty_cache() + + # Response Processing 3. filter response. 
Ensure that the sample contains stop_token_id + # responses not passing that filter will receive a low (fixed) score + # only query humans on responses that pass that filter + contain_stop_token = torch.any(postprocessed_responses == args.stop_token_id, dim=-1) + # NOTE: only apply the stop token filter if the response is long enough + # otherwise the model could learn to generate the first token as the stop token + contain_stop_token = contain_stop_token & (sequence_lengths >= args.min_response_length) + if args.non_stop_penalty: + scores = torch.where( + contain_stop_token, scores, torch.full_like(scores, args.penalty_reward_value) + ) + + # be very careful with `padding_mask_p1`; see https://excalidraw.com/#json=LWnzG4w2k5DjF_EOL_xPt,e2w3a-hFJ_gX5vOfeyXGTw + response_idxs = torch.arange(responses.shape[1], device=responses.device).repeat(responses.shape[0], 1) + padding_mask = response_idxs > sequence_lengths.unsqueeze(1) + logprobs = torch.masked_fill(logprobs, padding_mask, INVALID_LOGPROB) + ref_logprobs = torch.masked_fill(ref_logprobs, padding_mask, INVALID_LOGPROB) + sequence_lengths_p1 = sequence_lengths + 1 + padding_mask_p1 = response_idxs > (sequence_lengths_p1.unsqueeze(1)) + values = torch.masked_fill(values, padding_mask_p1, 0) + # print(f"get reward stuff finished 2") + + # 4. compute rewards + kl1 = logprobs - ref_logprobs + kl2 = (kl1) ** 2 / 2 + kl3 = (-kl1).exp() - 1 + kl1 + if args.kl_estimator == "kl1": + kl = kl1 + elif args.kl_estimator == "kl2": + kl = kl2 + elif args.kl_estimator == "kl3": + kl = kl3 + # if self.rank==0: + # print(f"{logprobs[0][:40]=}, {ref_logprobs[0][:40]=}, {kl.sum(1)=}") + non_score_reward = -args.beta * kl + non_score_reward_sum = non_score_reward.sum(1) + rlhf_reward = scores + non_score_reward_sum + rewards = non_score_reward.clone() + actual_start = torch.arange(rewards.size(0), device=rewards.device) + actual_end = torch.where(sequence_lengths_p1 < rewards.size(1), sequence_lengths_p1, sequence_lengths) + rewards[[actual_start, actual_end]] += scores + # print(f"get reward stuff finished 3") + + # 5. whiten rewards + if args.whiten_rewards: + rewards = masked_whiten(rewards, mask=~padding_mask_p1, shift_mean=False) + rewards = torch.masked_fill(rewards, padding_mask_p1, 0) + + # print('gae') + # 6. 
compute advantages and returns + lastgaelam = 0 + advantages_reversed = [] + gen_length = responses.shape[1] + for t in reversed(range(gen_length)): + nextvalues = values[:, t + 1] if t < gen_length - 1 else 0.0 + delta = rewards[:, t] + args.gamma * nextvalues - values[:, t] + lastgaelam = delta + args.gamma * args.lam * lastgaelam + advantages_reversed.append(lastgaelam) + advantages = torch.stack(advantages_reversed[::-1], axis=1) + returns = advantages + values + advantages = masked_whiten(advantages, ~padding_mask) + advantages = torch.masked_fill(advantages, padding_mask, 0) + torch.cuda.empty_cache() + + # print('training starts') + # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch + for epoch_idx in range(args.num_epochs): + b_inds = np.random.permutation(args.local_rollout_batch_size * args.number_samples_per_prompt) + minibatch_idx = 0 + for mini_batch_start in range( + 0, args.local_rollout_batch_size * args.number_samples_per_prompt, args.local_mini_batch_size + ): + mini_batch_end = mini_batch_start + args.local_mini_batch_size + mini_batch_inds = b_inds[mini_batch_start:mini_batch_end] + gradient_accumulation_idx = 0 + for micro_batch_start in range(0, args.local_mini_batch_size, args.per_device_train_batch_size): + # print("micro batch start", micro_batch_start, self.rank) + micro_batch_end = micro_batch_start + args.per_device_train_batch_size + micro_batch_inds = mini_batch_inds[micro_batch_start:micro_batch_end] + mb_advantage = advantages[micro_batch_inds] + mb_responses = responses[micro_batch_inds] + mb_query_responses = query_responses[micro_batch_inds] + mb_logprobs = logprobs[micro_batch_inds] + mb_return = returns[micro_batch_inds] + mb_values = values[micro_batch_inds] + mb_padding_mask_p1 = padding_mask_p1[micro_batch_inds] + + vpred_temp = get_reward( + self.value_model, mb_query_responses, tokenizer.pad_token_id, context_length + ) + vpred_temp = vpred_temp[0] + vpred = vpred_temp[:, context_length - 1 : -1].squeeze(-1) + vpred = torch.masked_fill(vpred, mb_padding_mask_p1, 0) + vpredclipped = torch.clamp( + vpred, + mb_values - args.cliprange_value, + mb_values + args.cliprange_value, + ) + vf_losses1 = torch.square(vpred - mb_return) + vf_losses2 = torch.square(vpredclipped - mb_return) + vf_loss_max = torch.max(vf_losses1, vf_losses2) + vf_loss = 0.5 * masked_mean(vf_loss_max, ~mb_padding_mask_p1) + self.value_model.backward(vf_loss * args.vf_coef) + self.value_model.step() + + new_logprobs = self.forward( + mb_query_responses, mb_responses, tokenizer.pad_token_id, context_length, args.temperature + ) + # if self.rank==0: + # print(f"{new_logprobs[0][:40]=}, {mb_logprobs[0][:40]=}") + new_logprobs = torch.masked_fill(new_logprobs, padding_mask[micro_batch_inds], INVALID_LOGPROB) + logprobs_diff = new_logprobs - mb_logprobs + ratio = torch.exp(logprobs_diff) + pg_losses = -mb_advantage * ratio + pg_losses2 = -mb_advantage * torch.clamp(ratio, 1.0 - args.cliprange, 1.0 + args.cliprange) + pg_loss_max = torch.max(pg_losses, pg_losses2) + pg_loss = masked_mean(pg_loss_max, ~padding_mask[micro_batch_inds]) + loss = pg_loss + self.model.backward(loss) + # print("backward loss", self.rank, "micro batch start", micro_batch_start) + # print("trying to step", self.rank, "micro batch start", micro_batch_start) + self.model.step() + # print("step", self.rank, "micro batch start", micro_batch_start) + with torch.no_grad(): + # print("waiting for value model step", self.rank, "micro batch start", micro_batch_start) + 
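+                        # The diagnostics computed below are logged per micro-batch:
+                        #   approxkl   ~ 0.5 * E[(logprob_new - logprob_old)^2], a cheap estimate of how far the policy moved
+                        #   *_clipfrac = fraction of non-padding tokens where the clipped (value / policy) objective was the active one
+                        #   ratio      = exp(logprob_new - logprob_old); its mean and variance are reported as val/ratio and val/ratio_var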
# vf_loss, vf_clipfrac = ray.get(value_model_step_future) + vf_clipfrac = masked_mean((vf_losses2 > vf_losses1).float(), ~mb_padding_mask_p1) + pg_clipfrac = masked_mean( + (pg_losses2 > pg_losses).float(), ~padding_mask[micro_batch_inds] + ) + # print("value model stepped", self.rank, "micro batch start", micro_batch_start) + # prob_dist = torch.nn.functional.softmax(logits, dim=-1) + # entropy = torch.logsumexp(logits, dim=-1) - torch.sum(prob_dist * logits, dim=-1) + approxkl = 0.5 * (logprobs_diff**2).mean() + approxkl_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = approxkl + pg_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_clipfrac + pg_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = pg_loss + vf_loss_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_loss + vf_clipfrac_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = vf_clipfrac + # entropy_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = entropy.mean() + ratio_stats[epoch_idx, minibatch_idx, gradient_accumulation_idx] = ratio.mean() + gradient_accumulation_idx += 1 + minibatch_idx += 1 + # fmt: off + del mb_advantage, mb_responses, mb_query_responses, mb_logprobs, mb_return, mb_values, mb_padding_mask_p1 + del new_logprobs, logprobs_diff, ratio, pg_losses, pg_losses2, pg_loss_max, pg_loss, loss + # del vpred_temp, vpred, vpredclipped, vf_losses1, vf_losses2, vf_loss_max + # del vf_loss, vf_clipfrac, pg_clipfrac, approxkl + # fmt: on + # del everything and empty cache + torch.cuda.empty_cache() + del b_inds, mini_batch_inds + # print("start metrics") + with torch.no_grad(): + local_metrics[0] = sequence_lengths.float().mean() + local_metrics[1] = (responses == args.stop_token_id).sum().float().mean() + local_metrics[2] = kl.sum(1).mean() + local_metrics[3] = (-logprobs).sum(1).mean() + local_metrics[4] = non_score_reward_sum.mean() + local_metrics[5] = rlhf_reward.mean() + local_metrics[6] = scores.mean() + local_metrics[7] = approxkl_stats.mean() + local_metrics[8] = pg_clipfrac_stats.mean() + local_metrics[9] = pg_loss_stats.mean() + local_metrics[10] = vf_loss_stats.mean() + local_metrics[11] = vf_clipfrac_stats.mean() + local_metrics[12] = entropy_stats.mean() + local_metrics[13] = ratio_stats.mean() + local_metrics[14] = ratio_stats.var() + local_metrics[15] = ((kl) ** 2 / 2).sum(1).mean() + local_metrics[16] = ((-kl).exp() - 1 + kl).sum(1).mean() + local_metrics[17] = verifiable_correct_rate + local_metrics[18] = contain_stop_token.float().mean() + # global_metrics = accelerator.reduce(local_metrics, reduction="mean").tolist() + local_metrics /= dist.get_world_size() + dist.all_reduce(local_metrics, op=dist.ReduceOp.SUM) + global_metrics = local_metrics.tolist() + metrics = { + "episode": episode, + "training_step": training_step, + "lr": self.scheduler.get_last_lr()[0], + "epoch": episode / len(train_dataset), + "time/from_scratch": time.time() - start_time, + "time/training": time.time() - training_time_start, + "val/sequence_lengths": global_metrics[0], + "val/num_stop_token_ids": global_metrics[1], + "objective/kl": global_metrics[2], + "objective/kl2": global_metrics[15], + "objective/kl3": global_metrics[16], + "objective/entropy": global_metrics[3], + "objective/non_score_reward": global_metrics[4], + "objective/rlhf_reward": global_metrics[5], + "objective/scores": global_metrics[6], + "policy/approxkl_avg": global_metrics[7], + "policy/clipfrac_avg": global_metrics[8], + "loss/policy_avg": global_metrics[9], + 
"loss/value_avg": global_metrics[10], + "val/clipfrac_avg": global_metrics[11], + "policy/entropy_avg": global_metrics[12], + "val/ratio": global_metrics[13], + "val/ratio_var": global_metrics[14], + "objective/verifiable_correct_rate": global_metrics[17], + "val/stop_token_rate": global_metrics[18], + } + if accelerator.is_main_process: + print_rich_single_line_metrics(metrics) + metrics_queue.put((metrics, episode, df)) + del (queries, responses, postprocessed_responses, logprobs, ref_logprobs, sequence_lengths, scores, values) + del (global_metrics, metrics, kl, non_score_reward, non_score_reward_sum, rlhf_reward) + gc.collect() + torch.cuda.empty_cache() + # print(f"finished training {training_step}") + + # save steps + if args.save_freq > 0 and training_step % args.save_freq == 0: + checkpoint_dir = f"{args.output_dir}_checkpoints" + os.makedirs(checkpoint_dir, exist_ok=True) + step_dir = os.path.join(checkpoint_dir, f"step_{training_step}") + os.makedirs(step_dir, exist_ok=True) + print(f"Saving model at step {training_step} to {step_dir}") + self.save_model(step_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(step_dir, training_step) + print(f"Saving final model at step {training_step} to {args.output_dir}") + self.save_model(args.output_dir) + if args.try_launch_beaker_eval_jobs_on_weka: + self.launch_ai2_evals_on_weka(args.output_dir) + + # Ai2 logic: we use /output to store the artifacts of the job, so we + # make a copy of the model to `/output` in the end. + if self.rank == 0 and len(self.beaker_config.beaker_dataset_id_urls) > 0 and args.output_dir != "/output": + shutil.copytree(args.output_dir, "/output", dirs_exist_ok=True) + print("finished training") + + def save_model(self, output_dir: str) -> None: + if self.rank == 0: + os.makedirs(output_dir, exist_ok=True) + + # save model weights for ZeRO2/3 + model_to_save = self.model + if hasattr(model_to_save, "module"): + model_to_save = model_to_save.module + + # gather parameters + output_state_dict = {} + for k, v in model_to_save.named_parameters(): + # only gather z3 params + params_to_fetch = _z3_params_to_fetch([v]) + with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): + vv = v.data.cpu() + if self.rank == 0: + output_state_dict[k] = vv + + if self.rank == 0: + state_dict = model_to_save.state_dict() + + # copy named_buffers with `persistent=True` + for k, v in model_to_save.named_buffers(): + if k not in state_dict: + continue + vv = v.data.cpu() + output_state_dict[k] = vv + + state_dict_keys = set(state_dict.keys()) + output_state_dict_keys = set(output_state_dict.keys()) + + # corner case for tie_word_embeddings, such as Qwen2-0.5B + if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: + state_dict_keys.remove("lm_head.weight") + + assert state_dict_keys.issubset( + output_state_dict_keys + ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" + + # # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 + # if isinstance(model_to_save, PeftModel): + # model_to_save.save_pretrained(output_dir, **kwargs) + # if self.stage == 3: + # torch.save( + # get_peft_model_state_dict(model_to_save, output_state_dict), + # os.path.join(output_dir, "adapter_model.bin"), + # ) + # else: + # save model + model_to_save.save_pretrained(output_dir, state_dict=output_state_dict) + + # save tokenizer + self.original_tokenizer.save_pretrained(output_dir) + + def 
launch_ai2_evals_on_weka(self, step_dir: str, training_step: Optional[int] = None) -> None: + """auto eval the metrics as `f"{args.exp_name}_step_{training_step}"` in our leaderboard""" + args = self.args + beaker_config = self.beaker_config + model_config = self.model_config + wandb_url = self.wandb_url + # Ai2 specific logic + if is_beaker_job() and self.rank == 0: + if training_step is not None: + leaderboard_name = f"{args.hf_repo_revision}_step_{training_step}" + else: + leaderboard_name = args.hf_repo_revision + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "ppo", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb_url, + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + leaderboard_name, # to match what the auto-evals name as. + ) + + command = f"""\ +python scripts/submit_eval_jobs.py \ + --model_name {leaderboard_name} \ + --location {step_dir} \ + --cluster ai2/saturn-cirrascale ai2/neptune-cirrascale \ + --is_tuned \ + --workspace "tulu-3-results" \ + --priority high \ + --preemptible \ + --use_hf_tokenizer_template \ + --beaker_image "nathanl/open_instruct_auto" \ + --upload_to_hf {args.hf_metadata_dataset} \ + --run_oe_eval_experiments \ + --evaluate_on_weka \ + --run_safety_evaluations \ + --skip_oi_evals""" + if args.oe_eval_tasks is not None: + command += f" --oe_eval_tasks {','.join(args.oe_eval_tasks)}" + print(f"Launching eval jobs with command: {command}") + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + +def kill_ray_cluster_if_a_worker_dies(object_refs: List[Any], stop_event: threading.Event): + while True: + if stop_event.is_set(): + break + for ref in object_refs: + try: + ray.get(ref, timeout=0.01) + except ray.exceptions.GetTimeoutError: + pass + except Exception as e: + print(e) + print(f"Actor {ref} died") + time.sleep(120) + ray.shutdown() + os._exit(1) # Force shutdown the process + + time.sleep(30) + + +class ModelGroup: + def __init__( + self, + pg: PlacementGroup, + ray_process_cls: RayProcess, + num_gpus_per_node: List[int], + ): + self.pg = pg + self.ray_process_cls = ray_process_cls + self.num_gpus_per_node = num_gpus_per_node + self.num_gpus_per_actor = 1 + self.num_cpus_per_actor = 4 + self.models = [] + world_size = sum(self.num_gpus_per_node) + master_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=PlacementGroupSchedulingStrategy( + placement_group=self.pg, placement_group_bundle_index=0 + ), + ).remote(world_size, 0, 0, None, None) + + self.models.append(master_policy) + master_addr, master_port = ray.get(master_policy.get_master_addr_port.remote()) + + def get_bundle_index(rank, num_gpus_per_node): + """given a rank and a list of num_gpus_per_node, return the 
index of the bundle that the rank belongs to""" + bundle_idx = 0 + while rank >= num_gpus_per_node[bundle_idx]: + rank -= num_gpus_per_node[bundle_idx] + bundle_idx += 1 + return bundle_idx + + assert get_bundle_index(0, [7, 8, 4]) == 0 + assert get_bundle_index(1, [7, 8, 4]) == 0 + assert get_bundle_index(7, [7, 8, 4]) == 1 + assert get_bundle_index(8, [7, 8, 4]) == 1 + assert get_bundle_index(9, [7, 8, 4]) == 1 + assert get_bundle_index(16, [7, 8, 4]) == 2 + + # Setup worker models + for rank in range(1, world_size): + print(f"{rank=}, {world_size=}, {rank=}, {master_addr=}, {master_port=}") + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=self.pg, + placement_group_bundle_index=get_bundle_index(rank, self.num_gpus_per_node), + ) + worker_policy = ray_process_cls.options( + num_cpus=self.num_cpus_per_actor, + num_gpus=self.num_gpus_per_actor, + scheduling_strategy=scheduling_strategy, + ).remote(world_size, rank, 0, master_addr, master_port) + self.models.append(worker_policy) + + +def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + calculate_runtime_args(args, model_config) + + # set up experiment tracking and seeds + all_configs = {} + beaker_config = None + if is_beaker_job(): + args.checkpoint_output_dir = os.environ.get("CHECKPOINT_OUTPUT_DIR", None) + beaker_config = maybe_get_beaker_config() + all_configs.update(vars(beaker_config)) + all_configs.update(**asdict(args), **asdict(dataset_config), **asdict(model_config)) + if args.with_tracking: + import wandb + + wandb.init( + project=args.wandb_project_name, + entity=args.wandb_entity, + sync_tensorboard=True, + config=all_configs, + name=args.run_name, + save_code=True, + tags=[args.exp_name] + get_wandb_tags(), + ) + writer = SummaryWriter(f"runs/{args.run_name}") + writer.add_text( + "hyperparameters", + "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), + ) + + # create a tokenizer (pad from right) + # if check_hf_olmo_availability(): + # # allows AutoModel... to work with not in transformers olmo models + # import hf_olmo # noqa + # from hf_olmo import OLMoTokenizerFast + config = AutoConfig.from_pretrained(model_config.model_name_or_path, revision=model_config.model_revision) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, revision=model_config.model_revision, padding_side="right" + ) + # if check_hf_olmo_availability(): + # print("Using exsiting tokenier chat template...") + # pass + # else: + # tokenizer.chat_template = CHAT_TEMPLATES[dataset_config.chat_template] + # if config.architectures == "LlamaForCausalLM" and config.bos_token_id == 128000: + # tokenizer.pad_token_id = 128002 # <|reserved_special_token_0|> + # elif check_hf_olmo_availability() and isinstance(tokenizer, OLMoTokenizerFast): + # # OLMo newer models use this tokenizer + # breakpoint() + # if tokenizer.bos_token is None: + # tokenizer.bos_token = tokenizer.eos_token + # assert ( + # args.add_bos + # ), "For OLMo with GPTNeoX, you must add bos token to the beginning of the input sequence." + # # else, pythia / other models + # else: + # num_added_tokens = tokenizer.add_special_tokens( + # { + # "pad_token": "", + # } + # ) + # assert num_added_tokens == 1, "GPTNeoXTokenizer should only add one special token - the pad_token." 
+ # else: + # tokenizer.add_special_tokens({"pad_token": "[PAD]"}) # NOTE: we do not resize the embedding + + # create the dataset + dataset_dict = DatasetDict() + dataset_processor = SFTGroundTruthDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + if len(args.dataset_train_splits) != len(args.dataset_mixer_dict) and len(args.dataset_train_splits) == 1: + args.dataset_train_splits = [args.dataset_train_splits[0]] * len(args.dataset_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_train_splits[0]} split for all datasets." + ) + if len(args.dataset_eval_splits) != len(args.dataset_eval_mixer_dict) and len(args.dataset_eval_splits) == 1: + args.dataset_eval_splits = [args.dataset_eval_splits[0]] * len(args.dataset_eval_mixer_dict) + print( + f"Dataset splits not provided for all datasets. Using the same {args.dataset_eval_splits[0]} split for all datasets." + ) + train_dataset = combine_dataset( + args.dataset_mixer_dict, + splits=args.dataset_train_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + train_dataset = train_dataset.select( + range(0, min(len(train_dataset), dataset_config.sanity_check_max_samples)) + ) + train_dataset = dataset_processor.tokenize(train_dataset) + train_dataset = dataset_processor.filter(train_dataset, need_contain_labels=False) + dataset_dict["train"] = train_dataset + eval_dataset = None + if args.dataset_eval_mixer is not None: + eval_dataset = combine_dataset( + args.dataset_eval_mixer_dict, + splits=args.dataset_eval_splits, + columns_to_keep=[ + dataset_config.sft_messages_key, + dataset_config.ground_truths_key, + dataset_config.dataset_source_key, + ], + ) + if dataset_config.sanity_check: + eval_dataset = eval_dataset.select( + range(0, min(len(eval_dataset), dataset_config.sanity_check_max_samples)) + ) + eval_dataset = dataset_processor.tokenize(eval_dataset) + eval_dataset = dataset_processor.filter(eval_dataset, need_contain_labels=False) + dataset_dict["eval"] = eval_dataset + data_collator = SimpleGenerateCollatorWithGroundTruth(pad_token_id=tokenizer.pad_token_id) + + # some more runtime logging + pprint([args, dataset_config, model_config]) + visualize_token(train_dataset[0][INPUT_IDS_PROMPT_KEY], tokenizer) + if args.with_tracking: + # upload the visualized token length + dataset_processor.get_token_length_visualization( + dataset_dict, save_path=f"runs/{args.run_name}/token_length.png" + ) + wandb.log({"token_length": wandb.Image(f"runs/{args.run_name}/token_length.png")}) + + # create the model and optimizer + pg = None + bundles = [{"GPU": actor_num_gpus, "CPU": actor_num_gpus * 10} for actor_num_gpus in args.actor_num_gpus_per_node] + pg = placement_group(bundles, strategy="STRICT_SPREAD") + ray.get(pg.ready()) + + inits = [] + policy_group = ModelGroup( + pg, + PolicyTrainerRayProcess, + args.actor_num_gpus_per_node, + ) + wandb_url = wandb.run.get_url() if args.with_tracking else None + inits.extend( + model.from_pretrained.remote(args, model_config, beaker_config, wandb_url) for model in policy_group.models + ) + max_len = dataset_config.max_prompt_token_length + args.response_length + vllm_engines = create_vllm_engines( + args.vllm_num_engines, + args.vllm_tensor_parallel_size, + model_config.model_name_or_path, + model_config.model_revision, + args.seed, + args.enable_prefix_caching, + max_len, + ) + + metrics_queue = RayQueue() + ray.get(inits) + 
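+    # At this point the Ray setup is complete:
+    #   * `policy_group.models` holds one PolicyTrainerRayProcess actor per GPU (DeepSpeed training),
+    #     all scheduled on the STRICT_SPREAD placement group created above;
+    #   * `vllm_engines` are separate actors used only for generation -- each training step, rank 0
+    #     broadcasts the latest policy weights to them (see `broadcast_to_vllm`);
+    #   * training metrics (and sample completions) come back to this driver through `metrics_queue`.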
print("======== all models initialized =========") + ray.get(policy_group.models[0].get_vocab_size.remote()) + # print(f"{policy_vocab_size=}, {reward_vocab_size=}") + # if policy_vocab_size != reward_vocab_size: + # ray.shutdown() # shutdown here so this error message is not buried in the logs + # raise ValueError( + # "Policy and reward model must have the same vocab size. " + # f"Policy: {policy_vocab_size}, Reward: {reward_vocab_size}. " + # "If they don't have the same vocab size, the policy could generate tokens which " + # "is going to cause index out of bound error in the reward model." + # ) + + refs = [] + for i, policy_model in enumerate(policy_group.models): + refs.append( + policy_model.train.remote( + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + vllm_engines=vllm_engines, + metrics_queue=metrics_queue, + data_collator=data_collator, + ) + ) + + # somtimes a worker dies due to CUDA issues, but the rest of the cluster would just hang + # so we need kill the ray cluster when this happens. + stop_event = threading.Event() + threading.Thread(target=kill_ray_cluster_if_a_worker_dies, args=(refs, stop_event)).start() + + # train and gather metrics + resume_training_step = 1 + for training_step in range(resume_training_step, args.num_training_steps + 1): + result = metrics_queue.get() + metrics, episode, df = result + for key, value in metrics.items(): + writer.add_scalar(key, value, episode) + + if df is not None: + if args.with_tracking: + wandb.log({"sample_completions": wandb.Table(dataframe=df)}) + else: + print_rich_table(df.iloc[:1]) + ray.get(refs) + + # save model + ray.shutdown() + stop_event.set() + + # Ai2 specific logic + if is_beaker_job(): + if args.hf_metadata_dataset: + dataset_list = list(args.dataset_mixer_dict.keys()) + # mainly just focussing here on what would be useful for the leaderboard. + # wandb will have even more useful information. + metadata_blob = { + "model_name": args.exp_name, + "model_type": "sft", + "datasets": dataset_list, + "base_model": model_config.model_name_or_path, + "wandb_path": wandb.run.get_url(), + "beaker_experiment": beaker_config.beaker_experiment_url, + "beaker_datasets": beaker_config.beaker_dataset_id_urls, + } + upload_metadata_to_hf( + metadata_blob, + "metadata.json", + args.hf_metadata_dataset, + "results/" + args.hf_repo_revision, # to match what the auto-evals name as. 
+ ) + + if args.try_launch_beaker_eval_jobs and len(beaker_config.beaker_dataset_id_urls) > 0: + command = f"""\ + python mason.py \ + --cluster ai2/allennlp-cirrascale ai2/general-cirrascale-a5000 ai2/general-cirrascale-a5000 ai2/s2-cirrascale ai2/general-cirrascale \ + --priority low \ + --preemptible \ + --budget ai2/allennlp \ + --workspace ai2/tulu-2-improvements \ + --image nathanl/open_instruct_auto \ + --pure_docker_mode \ + --gpus 0 -- python scripts/wait_beaker_dataset_model_upload_then_evaluate_model.py \ + --beaker_workload_id {beaker_config.beaker_workload_id} \ + --upload_to_hf {args.hf_metadata_dataset} \ + --model_name {args.hf_repo_revision} + """ + process = subprocess.Popen(["bash", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + print(f"Submit jobs after model training is finished - Stdout:\n{stdout.decode()}") + print(f"Submit jobs after model training is finished - Stderr:\n{stderr.decode()}") + print(f"Submit jobs after model training is finished - process return code: {process.returncode}") + + accelerator = Namespace() + accelerator.is_main_process = True # hack + if args.push_to_hub: + print("Pushing model to hub") + push_folder_to_hub( + accelerator, + args.output_dir, + args.hf_repo_id, + args.hf_repo_revision, + ) + + # The `checkpoint_output_dir` is only used in case of preemption and should be deleted if the run was successful. + # We use `--save_freq` to save intermediate checkpoints in the output folder instead. + if args.checkpoint_output_dir is not None and os.path.exists(args.checkpoint_output_dir): + shutil.rmtree(args.checkpoint_output_dir, ignore_errors=True) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, ModelConfig)) + main(*parser.parse()) diff --git a/open_instruct/rejection_sampling/generation_check_ground_truth.py b/open_instruct/rejection_sampling/generation_check_ground_truth.py new file mode 100644 index 000000000..1e3dd4240 --- /dev/null +++ b/open_instruct/rejection_sampling/generation_check_ground_truth.py @@ -0,0 +1,261 @@ +# Copyright 2024 AllenAI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
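+
+# This script is a quick check of ground-truth (verifiable) rewards for a trained model: it samples
+# `num_completions` generations per prompt with vLLM, scores every completion with
+# `apply_verifiable_reward` (called with verify_reward=10), and then prints pass@k / maj@k style
+# statistics over the sampled completions. Illustrative example of the pass@k computation used below:
+#     rewards = torch.tensor([[10.0, 0.0, 0.0], [0.0, 0.0, 0.0]])  # 2 prompts x 3 completions
+#     pass_at_k = (rewards.sum(dim=1) > 1).float().mean()          # -> 0.5: only prompt 0 has a correct sample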
+ +""" +python open_instruct/rejection_sampling/generation_check_ground_truth.py \ + --model_name_or_path allenai/open_instruct_dev \ + --revision L3.1-8B-v3.9-nc-fixed-2__meta-llama_Llama-3.1-8B__123__1730531285 \ + --num_completions 3 \ + --dataset_mixer_list ai2-adapt-dev/math_ground_truth 1.0 \ + --dataset_splits train \ + --dataset_end_idx 10 + + +python open_instruct/rejection_sampling/generation_check_ground_truth.py \ + --model_name_or_path allenai/open_instruct_dev \ + --revision olmo_7b_soup_anneal_v3.9_4_DPO___model__42__1730863426 \ + --num_completions 5 \ + --dataset_mixer_list ai2-adapt-dev/gsm8k_ground_truth 1.0 \ + --dataset_splits train \ + --dataset_end_idx 20 +""" +import asyncio +import copy +import json +import os +import sys +import time +from collections import defaultdict +from dataclasses import asdict, dataclass +from pprint import pformat +from typing import Dict, List, Optional + +from huggingface_hub import HfApi +from huggingface_hub.repocard import RepoCard +from rich.pretty import pprint +import torch +from transformers import AutoTokenizer +from vllm import LLM, SamplingParams + +from open_instruct.dataset_processor import ( + INPUT_IDS_PROMPT_KEY, + DatasetConfig, + SFTDatasetProcessor, +) +from open_instruct.model_utils import apply_verifiable_reward +from open_instruct.utils import ArgumentParserPlus, check_hf_olmo_availability, combine_dataset + +api = HfApi() +# we don't use `multiprocessing.cpu_count()` because typically we only have 12 CPUs +# and that the shards might be small +NUM_CPUS_FOR_DATASET_MAP = 4 +if check_hf_olmo_availability(): + # allows AutoModel... to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + from vllm.model_executor.models import ModelRegistry + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + +@dataclass +class Args: + dataset_mixer_list: List[str] + dataset_splits: List[str] = None + dataset_start_idx: int = 0 + dataset_end_idx: Optional[int] = None + + model_name_or_path: str = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr" + revision: str = "main" + save_filename: str = "completions.jsonl" + + # upload config + hf_repo_id: str = os.path.basename(__file__)[: -len(".py")] + push_to_hub: bool = False + hf_entity: Optional[str] = None + add_timestamp: bool = True + + +@dataclass +class GenerationArgs: + num_completions: int = 3 + temperature: float = 0.8 + response_length: int = 2048 + top_p: float = 0.9 + tensor_parallel_size: int = 1 + + +def save_jsonl(save_filename: str, table: Dict[str, List]): + first_key = list(table.keys())[0] + os.makedirs(os.path.dirname(save_filename), exist_ok=True) + with open(save_filename, "w") as outfile: + for i in range(len(table[first_key])): + json.dump({key: table[key][i] for key in table}, outfile) + outfile.write("\n") + + +def generate_with_vllm(model_name_or_path: str, revision: str, prompt_token_ids: List[int], gen_args: GenerationArgs): + llm = LLM( + model=model_name_or_path, + revision=revision, + tokenizer_revision=revision, + tensor_parallel_size=gen_args.tensor_parallel_size, + max_model_len=gen_args.response_length, + ) + + # filter out prompts which are beyond the model's max token length + max_model_len = llm.llm_engine.scheduler_config.max_model_len + prompt_token_ids_len = len(prompt_token_ids) + prompt_token_ids = [item for item in prompt_token_ids if len(item) < max_model_len] + if len(prompt_token_ids) != 
prompt_token_ids_len: + print(f"Filtered out {prompt_token_ids_len - len(prompt_token_ids)} prompts which exceeds max token length") + + outputs = llm.generate( + prompt_token_ids=prompt_token_ids, + sampling_params=SamplingParams( + n=gen_args.num_completions, + temperature=gen_args.temperature, + top_p=1.0, + max_tokens=gen_args.response_length, + include_stop_str_in_output=True, + ), + ) + + response_ids = [list(out.token_ids) for output in outputs for out in output.outputs] + return response_ids + + +def format_conversation(messages: list) -> str: + formatted_conversation = [] + + # Iterate through the messages + for message in messages: # Exclude the last assistant message + role = "User A" if message["role"] == "user" else "User B" + content = message["content"].strip() + formatted_conversation.append(f"{role}: {content}") + + # Join the conversation with a single newline + return "\n".join(formatted_conversation) + + +def main(args: Args, dataset_config: DatasetConfig, gen_args: GenerationArgs): + if len(args.dataset_splits) != len(args.dataset_mixer_list) // 2: + args.dataset_splits = ["train"] * (len(args.dataset_mixer_list) // 2) + print(f"Using default dataset_splits: {args.dataset_splits} for {(len(args.dataset_mixer_list) // 2)} datasets") + + dataset = combine_dataset( + args.dataset_mixer_list, + splits=args.dataset_splits, + columns_to_keep=[dataset_config.sft_messages_key, dataset_config.ground_truths_key, dataset_config.dataset_source_key], + ) + if args.dataset_end_idx is None: + args.dataset_end_idx = len(dataset) + dataset = dataset.select(range(args.dataset_start_idx, args.dataset_end_idx)) + pprint([dataset_config, args, gen_args]) + + + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, revision=args.revision) + dataset_processor = SFTDatasetProcessor(tokenizer=tokenizer, config=dataset_config) + dataset = dataset_processor.tokenize(dataset) + dataset = dataset_processor.filter(dataset) + prompt_token_ids = dataset[INPUT_IDS_PROMPT_KEY] + ground_truth = dataset[dataset_config.ground_truths_key] + dataset_source = dataset[dataset_config.dataset_source_key] + response_ids = generate_with_vllm(args.model_name_or_path, args.revision, prompt_token_ids, gen_args) + + + # repeat prompt_token_ids, ground truth, dataset source args.num_completions times + prompt_token_ids = [prompt_token_ids[i] for i in range(len(prompt_token_ids)) for _ in range(gen_args.num_completions)] + ground_truth = [ground_truth[i] for i in range(len(ground_truth)) for _ in range(gen_args.num_completions)] + dataset_source = [dataset_source[i] for i in range(len(dataset_source)) for _ in range(gen_args.num_completions)] + + # left pad prompt token ids with 0 + max_seq_len = max(len(item) for item in prompt_token_ids) + padded_prompt_token_ids = [[0] * (max_seq_len - len(item)) + item for item in prompt_token_ids] + # right pad response token ids with 0 + max_seq_len = max(len(item) for item in response_ids) + padded_response_ids = [item + [0] * (max_seq_len - len(item)) for item in response_ids] + padded_prompt_token_ids = torch.tensor(padded_prompt_token_ids) + padded_response_ids = torch.tensor(padded_response_ids) + query_response = torch.concat([padded_prompt_token_ids, padded_response_ids], dim=1) + verifiable_reward, _ = apply_verifiable_reward( + query_response, + tokenizer, + ground_truth, + dataset_source, + verify_reward=10, + ) + import math + verifiable_reward = verifiable_reward.reshape(-1, gen_args.num_completions) + pass_at_k = (verifiable_reward.sum(dim=1) > 
1).float().mean() + maj_at_k = (verifiable_reward.sum(dim=1) > math.ceil(gen_args.num_completions / 2)).float().mean() + printa = lambda i: print(tokenizer.decode(query_response[i]), ground_truth[i]) + print(f"{verifiable_reward=}") + print(f"{pass_at_k=}") + print(f"{maj_at_k=}") + breakpoint() + + # save_jsonl(args.save_filename, table) + +# if args.push_to_hub: +# if args.hf_entity is None: +# args.hf_entity = api.whoami()["name"] +# full_repo_id = f"{args.hf_entity}/{args.hf_repo_id}" +# timestamp = f"_{int(time.time())}" +# if args.add_timestamp: +# full_repo_id += timestamp +# api.create_repo(full_repo_id, repo_type="dataset", exist_ok=True) +# for f in [__file__, args.save_filename]: +# api.upload_file( +# path_or_fileobj=f, +# path_in_repo=f.split("/")[-1], +# repo_id=full_repo_id, +# repo_type="dataset", +# ) +# repo_full_url = f"https://huggingface.co/datasets/{full_repo_id}" +# print(f"Pushed to {repo_full_url}") +# run_command = " ".join(["python"] + sys.argv) +# sft_card = RepoCard( +# content=f"""\ +# # allenai/open_instruct: Generation Dataset + +# See https://github.com/allenai/open-instruct/blob/main/docs/algorithms/rejection_sampling.md for more detail + +# ## Configs + +# ``` +# args: +# {pformat(vars(args))} + +# dataset_config: +# {pformat(vars(dataset_config))} + +# gen_args: +# {pformat(vars(gen_args))} +# ``` + +# ## Reproduce this dataset + +# 1. Download the `{[f.split("/")[-1] for f in [__file__, args.save_filename]]}` from the {repo_full_url}. +# 2. Run `{run_command}` +# """ +# ) +# sft_card.push_to_hub( +# full_repo_id, +# repo_type="dataset", +# ) + + +if __name__ == "__main__": + parser = ArgumentParserPlus((Args, DatasetConfig, GenerationArgs)) + main(*parser.parse()) diff --git a/open_instruct/reward_modeling.py b/open_instruct/reward_modeling.py index c59140840..a2a922d46 100644 --- a/open_instruct/reward_modeling.py +++ b/open_instruct/reward_modeling.py @@ -47,6 +47,7 @@ from open_instruct.reward_modeling_eval import evaluate from open_instruct.utils import ( ArgumentParserPlus, + check_hf_olmo_availability, combine_dataset, get_wandb_tags, is_beaker_job, @@ -55,6 +56,8 @@ maybe_use_ai2_wandb_entity, ) +from transformers.models.olmo_1124.modeling_olmo_1124 import Olmo1124ForSequenceClassification, Olmo1124Config +AutoModelForSequenceClassification.register(Olmo1124Config, Olmo1124ForSequenceClassification) api = HfApi() @@ -195,6 +198,12 @@ def layer_init(layer: nn.Module, std: float): def main(args: Args, dataset_config: DatasetConfig, model_config: ModelConfig): + if check_hf_olmo_availability(): + # allows AutoModel... 
to work with not in transformers olmo models + import hf_olmo # noqa + from hf_olmo import OLMoTokenizerFast + from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + AutoModelForSequenceClassification.register(hf_olmo.OLMoConfig, OLMoForSequenceClassification) accelerator = calculate_runtime_args_and_accelerator(args, model_config) local_seed = args.seed + accelerator.process_index diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 95a17410e..550ecb43e 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -14,6 +14,7 @@ import dataclasses import functools +import importlib import json import logging import os @@ -53,6 +54,40 @@ """ +# ---------------------------------------------------------------------------- +# Import utilities +def check_hf_olmo_availability(return_version: bool = False) -> Union[dict, bool]: + pkg_name = "hf_olmo" + + # Check if the package spec exists + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + + if package_exists: + try: + # Primary method to get the package version + package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: + # Fallback method + try: + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") + except ImportError: + package_exists = False + package_version = "N/A" + + if return_version: + return { + "available": package_exists, + "version": package_version, + "python_version": sys.version, + "os": os.name, + "platform": sys.platform, + } + else: + return package_exists + + # ---------------------------------------------------------------------------- # Dataset utilities def is_openai_format(messages: Any) -> bool: diff --git a/open_instruct/vllm_utils2.py b/open_instruct/vllm_utils2.py index 7dbdf6a54..09aa73fff 100644 --- a/open_instruct/vllm_utils2.py +++ b/open_instruct/vllm_utils2.py @@ -35,7 +35,6 @@ ) from vllm.worker.worker import Worker - # Copy from pytorch to allow creating multiple main groups. 
# https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py def init_process_group( @@ -131,7 +130,12 @@ def update_weight(self, name, dtype, shape, empty_cache=False): class LLMRayActor: def __init__(self, *args, **kwargs): import vllm - + + from vllm.model_executor.models import ModelRegistry + from open_instruct.olmo_adapter.olmo_1124_vllm import OlmoNewForCausalLM + ModelRegistry.register_model("Olmo1124ForCausalLM", OlmoNewForCausalLM) + ModelRegistry.register_model("Olmo2ForCausalLM", OlmoNewForCausalLM) + self.__version__ = vllm.__version__ assert self.__version__ >= "0.4.1", "OpenRLHF only supports vLLM >= 0.4.1" diff --git a/open_instruct/x.py b/open_instruct/x.py new file mode 100644 index 000000000..f0e4bb1e7 --- /dev/null +++ b/open_instruct/x.py @@ -0,0 +1,52 @@ +from typing import Optional + +from hf_olmo import * +from transformers import ( + AutoModelForCausalLM, + AutoModelForSequenceClassification, +) + +from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + + +class Args: + model_name_or_path: str = "/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf" + trust_remote_code: bool = True + revision: Optional[str] = None + + +# def main(args: Args): +# model = AutoModelForCausalLM.from_pretrained( +# args, +# trust_remote_code=True, +# ) + + +if __name__ == "__main__": + # instead of installing from source, https://github.com/AkshitaB/vllm/blob/c96643ec56da3ab8cefba03cadf7731788e756b5/vllm/model_executor/models/__init__.py#L49 + # here we just register the new model class + from vllm.model_executor.models import ModelRegistry + + ModelRegistry.register_model("OLMoForCausalLM", OlmoNewForCausalLM) + from vllm import LLM, SamplingParams + + model = AutoModelForCausalLM.from_pretrained( + "/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf", + trust_remote_code=True, + ) + from vllm.model_executor.models import ModelRegistry + + from open_instruct.olmo_adapter.modeling_olmo2 import OLMoForSequenceClassification + from open_instruct.olmo_adapter.olmo_new import OlmoNewForCausalLM + + AutoModelForSequenceClassification.register(OLMoConfig, OLMoForSequenceClassification) + + s = SamplingParams(temperature=0.0) + llm = LLM( + model="/net/nfs.cirrascale/allennlp/akshitab/model-checkpoints/peteish7/step11931-unsharded-hf", + trust_remote_code=True, + gpu_memory_utilization=0.90, + ) + + vllm_out = llm.generate(["How is the weather today"], sampling_params=s) + print(vllm_out[0].outputs[0].text) diff --git a/requirements-olmo.txt b/requirements-olmo.txt new file mode 100644 index 000000000..1ec51fb2f --- /dev/null +++ b/requirements-olmo.txt @@ -0,0 +1,46 @@ +# TODO When updating flash-attn or torch in the future, make sure to update the version in the Dockerfile +torch==2.4.0 +scipy +packaging +sentencepiece +datasets +deepspeed==0.14.4 +accelerate==0.31.0 +peft>=0.11.1 +bitsandbytes>=0.41.1 +evaluate>=0.4.0 +tokenizers==0.19.1 +protobuf +transformers==4.43.4 +openai>=1.0.0 +tiktoken +rouge_score +tensorboard +wandb +gradio>=3.50.2 +termcolor +jsonlines +unidic-lite +einops +flash-attn==2.5.9.post1 # should really only be in dockerfile. 
Local env often doesn't have GPUs +fire +alpaca-eval==0.6.2 +# for human eval web app +flask +openpyxl +# for ifeval +nltk==3.8.1 +langdetect +immutabledict +# for math evaluations +antlr4-python3-runtime==4.9.2 +mpmath==1.3.0 +sympy==1.12.0 +# for linting +black +flake8 +isort +autoflake +pytest +hf_transfer +beaker-py \ No newline at end of file diff --git a/scripts/eval/oe-eval.sh b/scripts/eval/oe-eval.sh index 6ee361267..5b162c326 100755 --- a/scripts/eval/oe-eval.sh +++ b/scripts/eval/oe-eval.sh @@ -192,4 +192,4 @@ for TASK in "${TASKS[@]}"; do --beaker-retries 2 \ --beaker-priority "$PRIORITY" fi -done +done \ No newline at end of file diff --git a/scripts/eval_constraints/if_functions.py b/scripts/eval_constraints/if_functions.py deleted file mode 100644 index ada85cf77..000000000 --- a/scripts/eval_constraints/if_functions.py +++ /dev/null @@ -1,487 +0,0 @@ -import re -import json -import langdetect -from typing import List - -""" -This module contains functions to verify constraints in the responses generated by the model. -It covers all 25 constraints from the IFEval taxonomy. To be used either for eval or for ground truth rewards. -""" - - - -# include keywords: Include keywords {keyword1}, {keyword2} in your response - -def verify_keywords(text, keyword_list): - """ - Verify if the response contains all the specified keywords. - - Args: - response (str): The response text to check - keyword_list (list): A list of keywords to check for - - Returns: - bool: True if all keywords are present in the response, False otherwise - """ - # Convert response to lowercase for case-insensitive matching - response_lower = text.lower() - - # Check if all keywords are present in the response - return all(keyword.lower() in response_lower for keyword in keyword_list) - - -# Keyword Frequency: In your response, the word {word} should appear {N} times. -def verify_keyword_frequency(text, word, N): - """ - Verifies if a keyword appears exactly N times in the given text. - - Args: - text (str): The text to analyze - keyword (str): The keyword to count - expected_count (int): The expected number of occurrences - - Returns: - tuple: (bool, int) - (Whether constraint is met, actual count found) - """ - # Convert text to lowercase to make the search case-insensitive - text = text.lower() - keyword = word.lower() - - # Split text into words and remove punctuation - import re - words = re.findall(r'\b\w+\b', text) - - # Count actual occurrences - actual_count = sum(1 for word in words if word == keyword) - - # Check if constraint is met - constraint_met = actual_count == N - - return constraint_met - - -# Forbidden Words: Do not include keywords {forbidden words} in the response. -def validate_forbidden_words(text, forbidden_words): - """ - Validates that the text does not contain any of the specified forbidden words. 
- - Args: - text (str): The text to check for forbidden words - forbidden_words (list[str]): A list of forbidden words - - Returns: - tuple[bool, list[str]]: A tuple containing: - - Boolean indicating if any forbidden words are present - - List of forbidden words found in the text - - Example: - text = "This is a message that should not contain any bad words" - forbidden_words = ["bad", "evil", "harmful"] - result = validate_forbidden_words(text, forbidden_words) - """ - # Convert text to lowercase for case-insensitive matching - text_lower = text.lower() - - # Check each forbidden word - found_words = [word for word in forbidden_words if word.lower() in text_lower] - - # Return results - return len(found_words) == 0 - - -# Letter Frequency : In your response, the letter {letter} should appear {N} times. - -def verify_letter_frequency(text: str, letter: str, N: int) -> bool: - """ - Verifies if a given letter appears exactly the specified number of times in the text. - - Args: - text (str): The text to check - letter (str): The letter to count (case-sensitive) - target_count (int): The expected number of occurrences - - Returns: - bool: True if the constraint is met, False otherwise - - Example: - >>> verify_letter_frequency("hello world", "l", 3) - True - >>> verify_letter_frequency("hello world", "o", 2) - True - >>> verify_letter_frequency("hello world", "x", 0) - True - """ - if len(letter) != 1: - raise ValueError("Letter parameter must be a single character") - - actual_count = text.count(letter) - return actual_count == N - - -# Response Language: Your ENTIRE response should be in {language}, no other language is allowed. - -def validate_response_language(text, language): - """ - Validates that the entire response is in the specified language. - - Args: - text (str): The text to check - language (str): The language code (e.g., 'en' for English) - - Returns: - bool: True if the response is entirely in the specified language, False otherwise - - Example: - text = "This is an English sentence" - language = "en" - result = validate_response_language(text, language) - """ - from langdetect import detect - - # Detect the language of the text - detected_language = detect(text) - # Check if the detected language matches the expected language - return detected_language == language - - -# Number Paragraphs: Your response should contain {N} paragraphs. 
You separate paragraphs using the markdown divider:
-# * * *
-def verify_paragraph_count(text: str, N: int) -> bool:
-    """
-    Verifies that a text contains the expected number of paragraphs,
-    where paragraphs are separated by markdown dividers '* * *'
-
-    Args:
-        text (str): The text to analyze
-        expected_count (int): Expected number of paragraphs
-
-    Returns:
-        bool: True if the text contains exactly the expected number of paragraphs,
-              False otherwise
-
-    Example:
-        text = "First paragraph\n* * *\nSecond paragraph"
-        verify_paragraph_count(text, 2)
-        True
-    """
-    def clean_text(text: str) -> str:
-        """Remove extra whitespace and normalize line endings"""
-        return '\n'.join(line.strip() for line in text.splitlines()).strip()
-
-    # Clean the input text
-    text = clean_text(text)
-
-    # Split text by markdown divider
-    # Add 1 to count since n dividers create n+1 paragraphs
-    paragraphs = text.split('* * *')
-    actual_count = len(paragraphs)
-
-    # Verify each split resulted in non-empty content
-    valid_paragraphs = [p.strip() for p in paragraphs if p.strip()]
-    if len(valid_paragraphs) != actual_count:
-        return False
-
-    return actual_count == N
-
-
-# Number Words: Answer with at least / around / at most {N} words
-
-def validate_word_constraint(text: str, N: int, quantifier: str) -> bool:
-    """
-    Validates if a text meets specified word count constraints.
-
-    Args:
-        text (str): The text to check
-        count (int): The target word count
-        qualifier (str): The type of constraint ('at least', 'around', 'at most')
-
-    Returns:
-        bool: True if the constraint is met, False otherwise
-
-    Raises:
-        ValueError: If an invalid qualifier is provided
-    """
-    # Remove extra whitespace and split into words
-    words = text.strip().split()
-    actual_count = len(words)
-
-    # Define tolerance for "around" qualifier (±10% of target count)
-    tolerance = max(round(N * 0.1), 1)
-
-    if quantifier == "at least":
-        return actual_count >= N
-    elif quantifier == "at most":
-        return actual_count <= N
-    elif quantifier == "around":
-        return abs(actual_count - N) <= tolerance
-    else:
-        return False
-
-
-# Number Sentences: Answer with at least / around / at most {N} sentences.
-def verify_sentence_constraint(text: str, N: int, quantifier: str) -> bool:
-    """
-    Verifies if a text contains the expected number of sentences.
-
-    Args:
-        text (str): The text to analyze
-        N (int): The expected number of sentences
-        quantifier (str): The quantifier ('at least', 'around', 'at most')
-
-    Returns:
-        bool: True if the text contains the expected number of sentences, False otherwise
-    """
-    # Split the text into sentences
-    sentences = re.split(r'(?<=[.!?]) +', text)
-
-    # Count the number of sentences and check against the quantifier
-    actual_count = len(sentences)
-    if quantifier == 'at least':
-        return actual_count >= N
-    elif quantifier == 'around':
-        return abs(actual_count - N) <= 1
-    elif quantifier == 'at most':
-        return actual_count <= N
-    else:
-        return False
-
-
-# Number Paragraphs + First Word in i-th Paragraph: There should be {N} paragraphs. Paragraphs and only paragraphs
-# are separated with each other by two line breaks. The {i}-th paragraph must start with word {first word}.
-def validate_paragraphs(text, N, first_word, i):
-    """
-    Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific
-    word.
-# Number Paragraphs + First Word in i-th Paragraph: There should be {N} paragraphs. Paragraphs and only paragraphs
-# are separated with each other by two line breaks. The {i}-th paragraph must start with word {first word}.
-def validate_paragraphs(text, N, first_word, i):
-    """
-    Validates that a text contains the expected number of paragraphs and that the i-th paragraph starts with a specific
-    word.
-
-    Args:
-        text (str): The text to analyze
-        N (int): The expected number of paragraphs
-        first_word (str): The expected first word of the i-th paragraph
-        i (int): The index of the paragraph to check (1-indexed)
-
-    Returns:
-        bool: True if the text meets the paragraph and first word requirements, False otherwise
-    """
-    # Split the text into paragraphs
-    paragraphs = text.split('\n\n')
-
-    # Check if the number of paragraphs is as expected
-    if len(paragraphs) != N:
-        return False
-
-    # Check if the i-th paragraph starts with the specified first word
-    if paragraphs[i - 1].strip().startswith(first_word):
-        return True
-    return False
-
-
-# Postscript: At the end of your response, please explicitly add a postscript starting with {postscript marker}
-
-def verify_postscript(text, postscript_marker):
-    """
-    Verifies if a text contains a postscript starting with '{postscript marker}'
-
-    Args:
-        text (str): The text to verify
-
-    Returns:
-        bool: True if the text contains a valid postscript, False otherwise
-    """
-    # Check if the text contains the postscript marker
-    if postscript_marker in text:
-        # Get the index of the marker
-        marker_index = text.find(postscript_marker)
-        # Check if the marker appears near the end
-        remaining_text = text[marker_index:].strip()
-        # Verify it's not just the marker alone
-        return len(remaining_text) > len(postscript_marker)
-    return False
-
-
-# Number Placeholder: The response must contain at least {N} placeholders represented by square brackets,
-# such as [address].
-def validate_placeholders(text: str, N: int) -> tuple[bool, List[str]]:
-    """
-    Validates if a text contains at least the specified number of placeholders in square brackets.
-
-    Args:
-        text (str): The text to check for placeholders
-        min_placeholders (int): Minimum number of placeholders required
-
-    Returns:
-        tuple[bool, List[str]]: A tuple containing:
-            - Boolean indicating if the text meets the placeholder requirement
-            - List of found placeholders
-
-    Example:
-        >>> text = "Hello [name], your [item] will be delivered to [address]"
-        >>> validate_placeholders(text, 2)
-        (True, ['name', 'item', 'address'])
-    """
-    # Find all placeholders using regex
-    pattern = r'\[(.*?)\]'
-    placeholders = re.findall(pattern, text)
-
-    # Check if the number of placeholders meets the requirement
-    has_enough = len(placeholders) >= N
-
-    return has_enough, placeholders
-
-
-# Number Bullets: Your answer must contain exactly {N} bullet points. Use the markdown bullet points such as: * This
-# is a point.
-def verify_bullet_points(text: str, N: int) -> tuple[bool, str]:
-    """
-    Verifies if a text contains exactly N bullet points in markdown format.
-    Returns a tuple of (is_valid, message).
-
-    Args:
-        text (str): The text to check
-        expected_count (int): The expected number of bullet points
-
-    Returns:
-        tuple[bool, str]: (True if constraint is met, explanation message)
-    """
-    # Split text into lines and count lines starting with * or -
-    lines = text.split('\n')
-    bullet_points = [line.strip() for line in lines if line.strip().startswith(('*', '-'))]
-    actual_count = len(bullet_points)
-
-    if actual_count == N:
-        return True
-    else:
-        return False
-
-
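The placeholder check above is just a non-greedy bracket regex; a self-contained sketch of it on the docstring's own example text (editor's sketch, not part of the patch):

import re

text = "Hello [name], your [item] will be delivered to [address]"
placeholders = re.findall(r'\[(.*?)\]', text)  # same pattern as validate_placeholders
print(placeholders)            # ['name', 'item', 'address']
print(len(placeholders) >= 2)  # True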
-# Title: Your answer must contain a title, wrapped in double angular brackets, such as <<poem of joy>>.
-def validate_title(text: str) -> bool:
-    pattern = r"<<(.*?)>>"
-    matches = re.findall(pattern, text)
-
-    if len(matches) > 0:
-        return True
-    else:
-        return False
-
-
-# Choose: From Answer with one of the following options: {options}
-def validate_choice(text: str, options: list) -> bool:
-    for option in options:
-        if text in option:
-            return True
-    return False
-
-
-# Minimum Number Highlighted Section: Highlight at least {N} sections in your answer with markdown, i.e. *highlighted
-# section*
-def validate_highlighted_sections(text: str, N: int) -> bool:
-    pattern = r"\*(.*?)\*"
-    matches = re.findall(pattern, text)
-
-    if len(matches) >= N:
-        return True
-    else:
-        return False
-
-
-# Multiple Sections: Your response must have {N} sections. Mark the beginning of each section with {section splitter} X.
-
-def validate_sections(text: str, N: int, section_splitter: str) -> bool:
-    sections = text.split(section_splitter)
-    # The first section might not start with the splitter, so we adjust for this
-    if sections[0] == '':
-        sections.pop(0)
-    if len(sections) == N:
-        return True
-    else:
-        return False
-
-
-# JSON Format : Entire output should be wrapped in JSON format.
-def validate_json_format(text: str) -> bool:
-    try:
-        json_object = json.loads(text)
-    except ValueError as e:
-        return False
-    return True
-
-
-# Repeat Prompt: First, repeat the request without change, then give your answer (do not say anything before
-# repeating the request; the request you need to repeat does not include this sentence)
-def validate_repeat_prompt(text: str, original_prompt: str) -> bool:
-    if text.startswith(original_prompt):
-        return True
-    else:
-        return False
-
-
-# Two Responses: Give two different responses. Responses and only responses should be separated by 6 asterisk
-# symbols: ******.
-def validate_two_responses(text: str) -> bool:
-    if text.count('******') == 1:
-        response_list = text.split('******')
-        first_response = response_list[0].strip()
-        second_response = response_list[1].strip()
-        if first_response != second_response:
-            return True
-    return False
-
-
-# All Uppercase: Your entire response should be in English, capital letters only.
-def validate_uppercase(text: str) -> bool:
-    # Check if the response is the same as the uppercase version of the response
-    if text == text.upper():
-        return True
-    else:
-        return False
-
-
-# All Lowercase: Your entire response should be in English, and in all lowercase letters. No capital letters are
-# allowed.
-def validate_lowercase(text: str) -> bool:
-    # Check if the response is the same as the lowercase version of the response
-    if text == text.lower():
-        return True
-    else:
-        return False
-
-
-# Frequency of All-capital Words: In your response, words with all capital letters should appear at least / around /
-# at most {N} times.
-def validate_frequency_capital_words(text: str, N: int, quantifier: str) -> bool:
-    words = re.findall(r'\b[A-Z]+\b', text)
-    if quantifier == 'at least':
-        return len(words) >= N
-    elif quantifier == 'around':
-        return len(words) == N
-    elif quantifier == 'at most':
-        return len(words) <= N
-    else:
-        return False
-
-
-# End Checker: Finish your response with this exact phrase {end phrase}. No other words should follow this phrase.
-def validate_end(text: str, end_phrase: str) -> bool:
-    # Check if the response ends with the end phrase
-    if text.endswith(end_phrase):
-        return True
-    else:
-        return False
-
-
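One caveat worth noting about the all-capital-words pattern above: \b[A-Z]+\b also matches single capital letters, so words like "I" count toward N. A quick self-contained sketch (editor's sketch, not part of the patch; the sample sentence is invented):

import re

print(re.findall(r'\b[A-Z]+\b', "I think NASA and ESA are OK"))
# ['I', 'NASA', 'ESA', 'OK'] -- one-letter capitals are counted as all-caps words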
-# Quotation: Wrap your entire response with double quotation marks.
-def validate_quotation(text: str) -> bool:
-    if text.startswith('"') and text.endswith('"'):
-        return True
-    else:
-        return False
-
-
-# No Commas: In your entire response, refrain from the use of any commas.
-def validate_no_commas(text: str) -> bool:
-    if ',' not in text:
-        return True
-    else:
-        return False
diff --git a/scripts/submit_eval_jobs.py b/scripts/submit_eval_jobs.py
index 0c6f7cd18..48e9d061b 100755
--- a/scripts/submit_eval_jobs.py
+++ b/scripts/submit_eval_jobs.py
@@ -660,7 +660,7 @@ def adjust_gpus(task_spec, experiment_group, model_name, gpu_multiplier):
     experiment_name = experiment_name.replace('β', '').replace(r"{", "").replace(r"}", "")  # hack: remove characters beaker doesn't like
     d["description"] = experiment_name
     # specific image for safety eval
-    d["tasks"][0]["image"]["beaker"] = "hamishivi/open-safety"
+    d["tasks"][0]["image"]["beaker"] = "hamishivi/open_safety_1124"
     if args.use_alternate_safety_image:
         d["tasks"][0]["image"]["beaker"] = args.use_alternate_safety_image
     d["tasks"] = [d["tasks"][0]]
diff --git a/scripts/submit_finetune_job.py b/scripts/submit_finetune_job.py
index 1f0a6d5b4..02fae860b 100644
--- a/scripts/submit_finetune_job.py
+++ b/scripts/submit_finetune_job.py
@@ -167,7 +167,7 @@ def parse_args(args):
     d['tasks'][0]['arguments'][0] = new_arguments

     # name and description
-    exp_name = f"open_instruct_finetune_{model_name}_{now}"
+    exp_name = f"open_instruct_finetune_{model_name}_{now}"[:128]
     d['description'] = exp_name
     d['tasks'][0]['name'] = exp_name

@@ -333,6 +333,14 @@ def parse_args(args):
         d['tasks'][0]['envVars'].append({
             'name': 'WANDB_API_KEY', 'secret': f"{beaker_whoami}_WANDB_API_KEY"
         })
+
+    # Weka setting
+    if args.mount_on_weka:
+        if d['tasks'][0].get('datasets') is None:
+            d['tasks'][0]['datasets'] = []
+        d['tasks'][0]['datasets'].append({
+            'mountPath': f"{args.weka_mount_path}", 'source': {'weka': f"{args.mount_on_weka}"}
+        })

     # mount datasets
     if args.datasets:
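For the new Weka branch above, the entry appended to the Beaker task spec ends up shaped like the following. This is an editor's sketch, not part of the patch; the mount path and bucket name are invented placeholders standing in for args.weka_mount_path and args.mount_on_weka.

# Sketch only: illustrative values, not real cluster or bucket names.
task = {'datasets': []}
task['datasets'].append({
    'mountPath': '/weka',                  # would come from args.weka_mount_path
    'source': {'weka': 'my-weka-bucket'},  # would come from args.mount_on_weka
})
print(task['datasets'])  # [{'mountPath': '/weka', 'source': {'weka': 'my-weka-bucket'}}]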