You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
We tackled not only invariance in the same framework, but across two different frameworks. This was a challenging task as it required effectively auditing every single invocation of every kernel. We heavily leveraged the forward pass kernels from vLLM’s [recent batch invariance](https://docs.vllm.ai/en/latest/features/batch_invariance/) work and wrote simple backward passes for these.
52
50
53
-
Then, we wrote a generic reinforcement learning script using GSM8K and a correctness reward. We run everything synchronously, alternating between trainer and generator on a single host. This is demonstrative of exactly on-policy execution, but is not very common in large scale runs.
51
+
We tackled not only invariance in the same framework, but across two different frameworks. This was a challenging task as it required effectively auditing every single invocation of every kernel. We heavily leveraged the forward pass kernels from vLLM’s [recent batch invariance](https://docs.vllm.ai/en/latest/features/batch_invariance/) work and wrote [simple backward passes](https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py) for these.
52
+
53
+
```
54
+
from vllm.model_executor.layers.activation import SiluAndMul
55
+
from vllm.model_executor.layers.batch_invariant import rms_norm
56
+
```
57
+
58
+
vLLM has many heavily optimized fused operations, such as the SiLU MLPs and RMSNorms (with added residuals). To maintain bitwise equivalence, we imported the exact operations for the forward passes. These operations needed custom backward passes registered, but this could be done in the same vanilla PyTorch TorchTitan is written in!
59
+
60
+
```
61
+
from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
62
+
```
63
+
64
+
While writing the code, we ensured that the original non-invariant TorchTitan could be used. To make this optional and to reduce configuration parameters we leveraged vLLM’s exposed `vllm_is_batch_invariant` function.
65
+
66
+
Then, we wrote a generic reinforcement learning script using GSM8K and a correctness reward. We used TorchTitan’s utilities for a trainer and wrote a custom generator. Our generator, `VLLMRolloutEngine`, wraps simple functionality like calling generate and updating weights. We run everything synchronously, alternating between trainer and generator on a single host. This is demonstrative of exactly on-policy execution, but is not very common in large scale runs.
67
+
68
+
Note that we did not use `torch.compile` for the TorchTitan model, and thus enforced eager mode for vLLM. It is straightforward to remove this constraint, but a `torch.compile` version of the TorchTitan model would need to be built. vLLM heavily leverages `torch.compile` and is able to maintain batch-invariance with it - but to maintain cross-framework compatibility would require a change to the trained version of the model. This will be pursued in followup work!
54
69
55
70
While building this, testing was straightforward as we are able to use exact bitwise checks to ensure the forward logprobs and the perplexity generated by the trainer are identical. We will continue to improve the performance of vLLM and simplify the integration to support all TorchTitan models. To follow this work, please see the linked RFC: [#28326](https://github.com/vllm-project/vllm/issues/28326).
0 commit comments