Commit 22d082f

[recipe] feat: add open math reasoning (#3767)
### What does this PR do?

- Add open math reasoning recipe using sft trainer with model engine
- Support setting none to val dataset in sft trainer
- Fix main_eval
- Using aiohttp for main_generation_server to avoid hang in AsyncOpenAI

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent 8ec9bf6 commit 22d082f

File tree: 12 files changed, +459 / -56 lines
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@

# Open math reasoning

## Introduction

In this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with the backend-agnostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate an end-to-end SFT workflow.

Note that you may need to modify the paths in the following scripts.
## Dataset Preprocessing

### Download Dataset

```bash
hf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning
hf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24
hf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25
```

### Preprocess the dataset

```bash
python3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning
```
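The preprocessing script writes one SFT record per row: a `messages` list (user problem, assistant solution, each carrying a `loss_mask`) plus an `extra_info` dict. Below is a minimal sketch for spot-checking the output, assuming `pandas` and `pyarrow` are installed; the path is a placeholder.

```python
# Spot-check the preprocessed SFT parquet (sketch only; adjust the path as needed).
import pandas as pd

df = pd.read_parquet("/path/to/open_math_reasoning/cot_dataset.parquet")
print(len(df), "records")

for msg in df.iloc[0]["messages"]:
    # loss_mask is 0 for the user turn and 1 for the assistant turn,
    # so only the solution tokens contribute to the SFT loss.
    print(msg["role"], msg["loss_mask"], str(msg["content"])[:80])
```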
### Prepare the eval dataset

```bash
python3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset
```
## Train the model using SFT

### FSDP backend

```bash
export CKPT_HOME=/path/to/ckpt
export BACKEND=fsdp2
export MODEL_ID=Qwen/Qwen3-8B-Base
export TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet
bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh
```

### Megatron backend

TODO
## Eval the model

### Merge the checkpoint into Hugging Face format

```bash
python -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface
```

### Generate the responses

```bash
export MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface
bash recipe/open_math_reasoning/run_generation.sh
```

### Evaluate the responses

```bash
bash recipe/open_math_reasoning/run_eval.sh
```

You should see results like:

```python
{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335}
```
Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


def compute_score_data_source(data_source, response, ground_truth):
    from verl.utils.reward_score.math_reward import compute_score

    if data_source in ["aime24", "aime25"]:
        return compute_score(response, ground_truth)
    else:
        raise ValueError(f"Unknown data source: {data_source}")
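main_eval hands each generated response to this function together with its data source and ground truth. A hedged sketch of calling it directly: the response text is made up, the return-value comment reflects verl's rule-based math reward, and the import assumes you run from the repo root with it on `PYTHONPATH`.

```python
# Hypothetical direct call, mirroring how main_eval passes
# (data_source, response, ground_truth) to the custom reward function.
from recipe.open_math_reasoning.compute_score import compute_score_data_source

score = compute_score_data_source(
    data_source="aime24",
    response="... step-by-step reasoning ... The final answer is \\boxed{204}.",
    ground_truth="204",
)
print(score)  # expected to be 1.0 when the boxed answer matches the ground truth
```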
Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# prepare eval dataset including AIME'24, AIME'25

# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24
# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25

import os

import datasets

from verl.utils.reward_score.math_reward import remove_boxed

instruction_following = "Please reason step by step, and put your final answer within \\boxed{}."


def make_map_fn(data_source):
    def process_fn(example, idx):
        question_raw = example.pop("problem")

        question = question_raw + " " + instruction_following

        if "solution" not in example:
            example["solution"] = example["answer"]

        answer_raw = example.pop("solution")

        example.clear()

        try:
            solution = remove_boxed(answer_raw)
        except Exception:
            solution = answer_raw

        data = {
            "data_source": data_source,
            "prompt": [
                {
                    "role": "user",
                    "content": question,
                }
            ],
            "ability": "math",
            "reward_model": {"style": "rule", "ground_truth": solution},
            "extra_info": {
                "index": idx,
                "answer": answer_raw,
                "question": question_raw,
            },
        }
        return data

    return process_fn


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
    parser.add_argument(
        "--local_save_dir", default="~/data/math-ai", help="The save directory for the preprocessed dataset."
    )

    args = parser.parse_args()

    if args.local_dataset_path is not None:
        aime24_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime24")
        aime25_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime25")
    else:
        aime24_dataset_path = "math-ai/aime24"
        aime25_dataset_path = "math-ai/aime25"

    aime24_dataset = datasets.load_dataset(aime24_dataset_path, split="test")
    aime25_dataset = datasets.load_dataset(aime25_dataset_path, split="test")

    aime24_dataset = aime24_dataset.map(function=make_map_fn("aime24"), with_indices=True)
    aime25_dataset = aime25_dataset.map(function=make_map_fn("aime25"), with_indices=True)

    local_save_dir = os.path.expanduser(args.local_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)

    aime24_dataset.to_parquet(os.path.join(local_save_dir, "aime24_test.parquet"))
    aime25_dataset.to_parquet(os.path.join(local_save_dir, "aime25_test.parquet"))
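Each processed row follows the rule-based reward schema: the instruction is appended to the problem, and the boxed answer becomes the ground truth (falling back to the raw answer when unboxing fails). As a rough illustration, with made-up field values, a record looks like the sketch below.

```python
# Illustrative shape of one preprocessed AIME evaluation record (values are made up).
example_record = {
    "data_source": "aime24",
    "prompt": [
        {
            "role": "user",
            "content": "Find the number of ... Please reason step by step, and put your final answer within \\boxed{}.",
        }
    ],
    "ability": "math",
    "reward_model": {"style": "rule", "ground_truth": "204"},
    "extra_info": {"index": 0, "answer": "204", "question": "Find the number of ..."},
}
```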
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@

# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
    --local-dir /path/to/nvidia/OpenMathReasoning
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
    --local-dir /opt/tiger/nvidia/OpenMathReasoning
"""

import argparse
import os

import datasets

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
    parser.add_argument(
        "--local_save_dir",
        default="~/data/open_math_reasoning",
        help="The save directory for the preprocessed dataset.",
    )

    args = parser.parse_args()
    local_dataset_path = args.local_dataset_path

    data_source = "nvidia/OpenMathReasoning"

    if local_dataset_path is not None:
        dataset = datasets.load_dataset(local_dataset_path, split="cot")
    else:
        dataset = datasets.load_dataset(data_source, split="cot")

    def make_map_fn(split):
        def process_fn(example, idx):
            question = example.pop("problem")
            solution = example.pop("generated_solution")

            extra_info = {}
            for key, value in example.items():
                extra_info[key] = value
            example.clear()

            data = {
                "messages": [
                    {"role": "user", "content": question, "loss_mask": 0},
                    {"role": "assistant", "content": solution, "loss_mask": 1},
                ],
                "extra_info": extra_info,
            }
            return data

        return process_fn

    # filter out data where the problem_type is not has_answer_extracted
    dataset = dataset.filter(lambda example: example["problem_type"] == "has_answer_extracted")
    dataset = dataset.map(function=make_map_fn("cot"), with_indices=True)
    local_save_dir = os.path.expanduser(args.local_save_dir)
    os.makedirs(local_save_dir, exist_ok=True)
    dataset.to_parquet(os.path.join(local_save_dir, "cot_dataset.parquet"))
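The per-message `loss_mask` (0 on the user problem, 1 on the generated solution) marks which turn is trained on. The sketch below only illustrates the general idea of turn-level loss masking after tokenization; it is not verl's actual trainer code.

```python
# Conceptual sketch of turn-level loss masking (not verl's actual SFT trainer code).
import torch
import torch.nn.functional as F


def masked_sft_loss(logits: torch.Tensor, labels: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    """logits: [T, V], labels: [T], loss_mask: [T] with 1 on assistant tokens."""
    per_token = F.cross_entropy(logits, labels, reduction="none")  # [T]
    mask = loss_mask.float()
    # Average only over tokens that belong to the assistant (loss_mask == 1) turn.
    return (per_token * mask).sum() / mask.sum().clamp(min=1.0)


# Toy usage with random tensors: the first three tokens play the role of the user turn.
T, V = 8, 32
loss = masked_sft_loss(
    torch.randn(T, V),
    torch.randint(0, V, (T,)),
    torch.tensor([0, 0, 0, 1, 1, 1, 1, 1]),
)
print(loss.item())
```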
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@

#!/usr/bin/env bash

# Evaluation
python3 -m verl.trainer.main_eval \
    data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \
    custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \
    custom_reward_function.name=compute_score_data_source
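main_eval loads the generated parquet, scores every sampled response with the custom reward function configured above, and reports a mean per data source; this is where the `test_score/aime24` and `test_score/aime25` numbers in the README come from. The following is only a rough sketch of that aggregation: the `responses` column name and the parquet path are assumptions, and it is not verl's actual implementation.

```python
# Rough sketch of the per-data-source aggregation performed at eval time
# (not verl's actual code; the "responses" column and output path are assumptions).
from collections import defaultdict

import pandas as pd

from recipe.open_math_reasoning.compute_score import compute_score_data_source  # run from repo root

df = pd.read_parquet("/path/to/gen/qwen_8b_gen_test.parquet")  # placeholder path
scores = defaultdict(list)
for _, row in df.iterrows():
    ground_truth = row["reward_model"]["ground_truth"]
    for response in row["responses"]:  # rollout.n responses generated per prompt
        scores[row["data_source"]].append(
            compute_score_data_source(row["data_source"], response, ground_truth)
        )

print({f"test_score/{src}": sum(vals) / len(vals) for src, vals in scores.items()})
```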
Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@

#!/usr/bin/env bash

MODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface}

NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
NNODES=${NNODES:-1}
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet}
GEN_TP=${GEN_TP:-1}  # Tensor parallel size for generation (defaults to 1)

aime24_test_path=${HOME}/data/math-ai/aime24_test.parquet
aime25_test_path=${HOME}/data/math-ai/aime25_test.parquet
train_files="['$aime24_test_path', '$aime25_test_path']"

python3 -m verl.trainer.main_generation_server \
    trainer.nnodes="${NNODES}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.model.trust_remote_code=True \
    actor_rollout_ref.rollout.temperature=1.0 \
    actor_rollout_ref.rollout.top_p=0.7 \
    actor_rollout_ref.rollout.prompt_length=2048 \
    actor_rollout_ref.rollout.response_length=20480 \
    actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.n=32 \
    data.train_files="$train_files" \
    data.prompt_key=prompt \
    +data.output_path="${OUTPUT_PATH}"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env bash
2+
set -xeuo pipefail
3+
4+
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
5+
6+
TRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet}
7+
8+
backend=${BACKEND:-fsdp}
9+
10+
project_name=verl_sft_test
11+
12+
RESUME_MODE=auto
13+
MODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base}
14+
15+
SP_SIZE=${SP_SIZE:-8}
16+
FSDP_SIZE=${FSDP_SIZE:-16}
17+
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}
18+
19+
TP_SIZE=${TP_SIZE:-1}
20+
PP_SIZE=${PP_SIZE:-1}
21+
VPP_SIZE=${VPP_SIZE:-null}
22+
CP_SIZE=${CP_SIZE:-1}
23+
24+
PAD_MODE=${PAD_MODE:-no_padding}
25+
26+
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
27+
28+
FSDP_ENGINE_CONFIG="\
29+
engine=${backend} \
30+
optim=${backend} \
31+
optim.lr=2e-5 \
32+
optim.lr_warmup_steps_ratio=0.01 \
33+
optim.weight_decay=0.1 \
34+
optim.betas="[0.9,0.95]" \
35+
optim.clip_grad=1.0 \
36+
optim.min_lr_ratio=0.1 \
37+
optim.warmup_style=cosine \
38+
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
39+
engine.strategy=${FSDP_STRATEGY} \
40+
engine.fsdp_size=${FSDP_SIZE}"
41+
42+
43+
MEGATRON_ENGINE_CONFIG="\
44+
engine=${backend} \
45+
optim=${backend} \
46+
optim.lr=1e-5 \
47+
optim.lr_warmup_steps_ratio=0.2 \
48+
optim.weight_decay=0.1 \
49+
optim.betas="[0.9,0.95]" \
50+
optim.clip_grad=1.0 \
51+
optim.lr_warmup_init=0 \
52+
optim.lr_decay_style=cosine \
53+
optim.min_lr=1e-6 \
54+
engine.tensor_model_parallel_size=${TP_SIZE} \
55+
engine.pipeline_model_parallel_size=${PP_SIZE} \
56+
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
57+
engine.context_parallel_size=${CP_SIZE}"
58+
59+
if [ "$backend" = "fsdp" ]; then
60+
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
61+
echo "Using fsdp engine"
62+
exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1
63+
else
64+
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
65+
echo "Using megatron engine"
66+
exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
67+
fi
68+
69+
CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}}
70+
mkdir -p "${CKPT_HOME}"
71+
72+
torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \
73+
${ENTRYPOINT} \
74+
data.train_files="${TRAIN_FILES}" \
75+
data.train_batch_size=96 \
76+
data.max_length=32768 \
77+
data.pad_mode=${PAD_MODE} \
78+
data.truncation=error \
79+
data.use_dynamic_bsz=True \
80+
data.max_token_len_per_gpu=65536 \
81+
data.messages_key=messages \
82+
model.path=$MODEL_ID \
83+
model.use_remove_padding=${USE_REMOVE_PADDING} \
84+
${ENGINE_CONFIG} \
85+
trainer.test_freq=-1 \
86+
trainer.save_freq=4000 \
87+
trainer.logger=['console','wandb'] \
88+
trainer.project_name="${project_name}" \
89+
trainer.experiment_name="${exp_name}" \
90+
trainer.total_epochs=1 \
91+
trainer.default_local_dir="${CKPT_HOME}" \
92+
trainer.resume_mode=${RESUME_MODE} \
93+
trainer.max_ckpt_to_keep=5 \
94+
checkpoint.save_contents=[model,optimizer,extra]

verl/trainer/config/sft_trainer_engine.yaml

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@ data:
   max_token_len_per_gpu: 8192
   use_dynamic_bsz: True
   train_files: ~/data/gsm8k/train.parquet
-  val_files: ~/data/gsm8k/test.parquet
+  val_files: null
   # Multi-turn settings
   messages_key: messages # Key for messages list in multi-turn mode
   tools_key: tools # Key for tools list in multi-turn mode
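The new `null` default pairs with the run script's `trainer.test_freq=-1`, letting SFT run without any validation set. As a small illustrative check of what `val_files: null` resolves to on the Python side, assuming the config is loaded with OmegaConf (sketch only, not the trainer's actual code):

```python
# Sketch: `val_files: null` in YAML arrives as None, so validation can simply be skipped.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": {"train_files": "~/data/gsm8k/train.parquet", "val_files": None}})
if cfg.data.val_files is None:
    print("No val_files configured; skipping validation (see trainer.test_freq=-1).")
else:
    print("Validation enabled on", cfg.data.val_files)
```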
