Commit 08c3cf6

Merge pull request #1 from vllm-project/zhuohan/bitwise-rewrote
2 parents 9a4cf9a + 54251d3

Lines changed: 54 additions & 42 deletions

---
layout: post
title: "No More Train-Inference Mismatch: Bitwise Consistent On-Policy Reinforcement Learning with vLLM and TorchTitan"
author: "vLLM and TorchTitan Teams"
---

We demonstrate an open-source, bitwise-consistent on-policy RL run with [TorchTitan](https://github.com/pytorch/torchtitan) as the training engine and [vLLM](https://github.com/vllm-project/vllm) as the inference engine. Built on top of [vLLM's recent work on batch-invariant inference](https://docs.vllm.ai/en/latest/features/batch_invariance/), we show how to run an RL fine-tune of Qwen3 1.7B with bitwise-matching training and inference numerics by following [our open-sourced instructions](https://github.com/pytorch/torchtitan/tree/main/torchtitan/experiments/deterministic_vllm_rl):

![](/assets/figures/2025-11-10-bitwise-exact-rl/rl-script-demo.png)

Reinforcement learning has been shown to amplify tiny numerical mismatches between trainer and sampler, leading to non-deterministic and unstable training behavior ([He et al.](https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/) & [Yao, Liu et al.](https://fengyao.notion.site/off-policy-rl)). Our own results confirm the impact of numerics on RL: running the sampler with different kernels than the trainer (`batch_inv_OFF`) shows a reduced reward over 100 steps, while with bitwise-exact training enabled (`batch_inv_ON`, where `kl_div` is always exactly 0.0), the model not only trains in fewer steps but reaches a higher total reward.

![](/assets/figures/2025-11-10-bitwise-exact-rl/reward-comparison.png)

## Approach

Training and inference frameworks often use vastly different kernels because of their different workload properties. Even within an inference framework, different kernels can be chosen for different scenarios: kernels for high batch sizes parallelize heavily along the batch dimension, while kernels for low batch sizes parallelize more within a single instance to better utilize the GPU's parallel cores. These kernels reduce, and therefore round, in different orders, so the same inputs can produce slightly different outputs; the accumulated mismatch between training and inference leads to worse RL results.
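
To make the rounding effect concrete, here is a toy PyTorch illustration (not vLLM code): two individually correct reduction orders over the same bf16 values can round at different points and end up disagreeing in the final bits.

```
import torch

torch.manual_seed(0)
x = torch.randn(4096, dtype=torch.bfloat16)

# One long reduction over all 4096 elements.
sequential = x.sum()

# A tiled reduction: sum 64 chunks of 64 elements, rounding each partial
# sum to bf16, then sum the partials.
tiled = x.view(64, 64).sum(dim=1).sum()

# Both results are "correct", yet they frequently differ in the last bits.
print(sequential.item(), tiled.item(), torch.equal(sequential, tiled))
```

Whichever reduction order a dispatched kernel happens to use, the trainer and the sampler must use the same one, or their outputs drift apart bit by bit.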

In this work, we tackled invariance across two different frameworks: TorchTitan as the training framework and vLLM as the inference framework. We audited every single invocation of every kernel during the forward pass to make sure it is bitwise equivalent across the two frameworks. We leveraged the forward-pass kernels from vLLM’s [recent batch invariance](https://docs.vllm.ai/en/latest/features/batch_invariance/) work and wrote [simple backward passes](https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py) for these ops.
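
To give a flavor of what “bitwise equivalent” means in practice, the helper below (a simplified sketch, not the project’s actual test code) compares two tensors by their raw bit patterns instead of a numerical tolerance; this is the kind of check we can apply to trainer- and sampler-side outputs such as per-token logprobs.

```
import torch

def assert_bitwise_equal(a: torch.Tensor, b: torch.Tensor) -> None:
    """Fail unless `a` and `b` are bit-for-bit identical (no tolerances)."""
    assert a.dtype == b.dtype and a.shape == b.shape
    # Reinterpret the storage as integers so that even -0.0 vs 0.0 or
    # differing NaN payloads count as mismatches.
    int_dtype = {torch.bfloat16: torch.int16,
                 torch.float16: torch.int16,
                 torch.float32: torch.int32}[a.dtype]
    assert torch.equal(a.contiguous().view(int_dtype), b.contiguous().view(int_dtype))
```

A `torch.allclose`-style tolerance would hide exactly the kind of drift that makes RL training silently off-policy.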

vLLM has many heavily optimized fused operations, such as the SiLU MLPs and RMSNorms (with added residuals). To maintain bitwise equivalence, we imported these exact operations for the forward passes. They needed custom backward passes registered, and this could be done in the same vanilla PyTorch that TorchTitan is written in.
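
As an example, the snippet below sketches how one of these forward ops can be wrapped with a hand-written backward pass. It is illustrative only (the real registrations live in the `batch_invariant_backward.py` file linked above) and assumes a functional `rms_norm(x, weight, eps)` exposed by vLLM’s batch-invariant module.

```
import torch

# Forward kernel reused from vLLM, so the trainer sees bitwise-identical activations.
from vllm.model_executor.layers.batch_invariant import rms_norm


class RMSNormWithBackward(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, eps):
        ctx.save_for_backward(x, weight)
        ctx.eps = eps
        return rms_norm(x, weight, eps)

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        # Analytic RMSNorm gradient, recomputing the normalization statistics.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + ctx.eps)
        x_hat = x * inv_rms
        grad_weight = (grad_out * x_hat).sum(dim=tuple(range(x.dim() - 1)))
        gw = grad_out * weight
        grad_x = inv_rms * (gw - x_hat * (gw * x_hat).mean(-1, keepdim=True))
        return grad_x, grad_weight, None
```

The same pattern applies to the fused `SiluAndMul` activation and the residual-adding RMSNorm variants.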

For the RL demo, we wrote a generic reinforcement learning script using GSM8K and a correctness reward. We used TorchTitan’s utilities for the trainer and wrote a custom generator. Our generator, `VLLMRolloutEngine`, wraps simple functionality like calling generate and updating weights. We run everything synchronously, alternating between trainer and generator on a single host. This demonstrates exactly on-policy execution, though such a setup is not very common in large-scale runs.
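
Schematically, one step of that synchronous loop looks like the sketch below; the `generate`, `train_step`, and `update_weights` method names are illustrative stand-ins rather than the actual script’s API.

```
def rl_step(trainer, rollout_engine, prompts, reward_fn):
    # 1. Sample rollouts from vLLM with the current policy weights.
    rollouts = rollout_engine.generate(prompts)

    # 2. Score each completion (for GSM8K: is the final answer correct?).
    rewards = [reward_fn(rollout) for rollout in rollouts]

    # 3. Policy-gradient update in TorchTitan. Because the trainer's forward
    #    pass matches the sampler bitwise, the rollouts stay exactly on-policy.
    metrics = trainer.train_step(rollouts, rewards)

    # 4. Push the updated weights back into the vLLM engine before sampling again.
    rollout_engine.update_weights(trainer.state_dict())
    return metrics
```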

## What’s Next

We will continue to push forward on bitwise consistent training and inference. To follow this work, please see the linked RFC: [#28326](https://github.com/vllm-project/vllm/issues/28326). More specifically, we will focus on the following directions:

**Unified model definition.** Although we have demonstrated bitwise-equivalent training and inference results, there are still two copies of the model code, one for training and one for inference. This is easy for a first integration but fragile for long-term maintenance: any slight change to either copy of the model code will break the equivalence between training and inference and lead to numerical mismatches. Sharing a single model definition across the training and inference frameworks will eliminate the possibility of accidental human error and make the bitwise-matching property easier to maintain.

**Compilation support.** For now, we do not use `torch.compile` for the TorchTitan model, and thus enforce eager mode for vLLM. It is straightforward to remove this constraint, but a `torch.compile` version of the TorchTitan model would need to be built. vLLM heavily leverages `torch.compile` and is able to maintain batch invariance with it, but maintaining cross-framework compatibility would require a corresponding change to the trainer-side version of the model. This will be pursued in follow-up work!

**RL performance.** Our current results show that the bitwise RL run is 2.4x slower than the non-bitwise case. We will continue to improve the performance of vLLM with better tuning of the batch-invariant kernels, as well as by leveraging technologies such as compilation.

**Wider model support.** We plan to extend this bitwise-consistent RL framework beyond Qwen3 1.7B to support other open models. We will also generalize the auditing tools and backward implementations to cover a broader range of operator types, making bitwise training-inference consistency a scalable and reusable feature.

---
*Authors:
Bram Wasti, Wentao Ye, Teja Rao, Michael Goin, Paul Zhang, Tianyu Liu, Natalia Gimelshein, Woosuk Kwon, Kaichao You, Zhuohan Li*

<!--

DEPRECATED

## Background

Across the septillions of FLOPs used in pre-training, numerical mismatches have had effectively imperceptible impact. Pre-training typically runs at a fixed batch size which induces the same reduction kernels to be run - often side-stepping the issue entirely.

Reinforcement learning, on the other hand, seems to almost exclusively run different reduction algorithms due to its inference-heavy (and thus largely latency and memory-bound) nature. Kernels optimized for low-batch size inference typically run reductions without tiling, whereas kernels for training models parallelize heavily to reuse data and amp up compute utilization. That means the generators and the trainers are typically running completely different kernels!

So intuitively, why might this be an issue? A rudimentary explanation is that the training becomes implicitly “off-policy” because the outputs from the generator do not match the outputs a trainer might produce given the same inputs.

Floating point numbers are effectively a binary scientific notation. They utilize three components: a sign bit (s), a mantissa (M) and an exponent (e).
<p align="center">
<img width="340" height="130" src="/assets/figures/2025-11-10-bitwise-exact-rl/floating-point-representation.png" />
</p>

Each of these components is represented as an integer and suffers from the exact same rounding errors you might expect. In bf16, the most commonly used representation for machine learning, 7 bits are dedicated to the mantissa. This is not very many bits! The value 3.0 can be represented exactly, but a value like 3.6 cannot…

<p align="center">
<img width="480" height="355" src="/assets/figures/2025-11-10-bitwise-exact-rl/bf16-rounding-example.png" />
</p>

When you want a new value in bf16 you end up rounding it to the nearest available value. What’s of particular interest today is the implication of this rounding process happening at different points in a sequence of additions.

![](/assets/figures/2025-11-10-bitwise-exact-rl/rounding-sequence.png)

These rounding steps can cause two of the exact same inputs to generate *different* outputs! That means the same framework on the same hardware with the same inputs and the same weights can produce distinct outputs if *any* of the logic *anywhere* in the execution dispatches a different (but still correct) kernel. -->

<!-- ```
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
``` -->

<!-- While writing the code, we ensured that the original non-invariant TorchTitan could be used. To make this optional and to reduce configuration parameters we leveraged vLLM’s exposed `vllm_is_batch_invariant` function. -->

<!-- ```
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.batch_invariant import rms_norm
``` -->
