fix: skip DDP and optimizer for prm_teacher (inference-only mode) #115
Open
HeJW0303 wants to merge 1 commit into
PRM Teacher only performs forward passes to compute log-probabilities
and never requires gradient updates. Loading with DDP + optimizer
wastes ~160GB memory (gradient buffers + Adam states), making large
models impossible to run on a single GPU.
Changes:
- Add _setup_inference_only_model() with wrap_with_ddp=False
- Add early return path for role=='prm_teacher' in
initialize_model_and_optimizer() that skips DDP/optimizer
Tested with Qwen3.5-27B.
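The early-return path described above can be pictured with a small control-flow sketch. This is not the code from the PR: `build_model`, `ModelBundle`, and the `_sketch` function names are hypothetical stand-ins, and the dictionaries only mimic the real model, optimizer, and scheduler objects. It shows only that the `prm_teacher` role never reaches DDP wrapping or optimizer construction.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelBundle:
    """Hypothetical container for what the initializer hands back."""
    model: dict
    optimizer: Optional[dict]
    scheduler: Optional[dict]


def build_model(role: str, wrap_with_ddp: bool) -> dict:
    # Stand-in for real model construction; only records whether
    # DDP wrapping (and its gradient buffers) was requested.
    return {"role": role, "wrapped_with_ddp": wrap_with_ddp}


def setup_inference_only_model_sketch(role: str) -> dict:
    # Inference-only path: build the model without DDP wrapping.
    return build_model(role, wrap_with_ddp=False)


def initialize_model_and_optimizer_sketch(role: str) -> ModelBundle:
    if role == "prm_teacher":
        # Early return: no DDP, no optimizer, no scheduler.
        model = setup_inference_only_model_sketch(role)
        return ModelBundle(model=model, optimizer=None, scheduler=None)

    # Training roles keep the full setup.
    model = build_model(role, wrap_with_ddp=True)
    optimizer = {"type": "adam"}        # placeholder optimizer state
    scheduler = {"type": "placeholder"}  # placeholder LR scheduler
    return ModelBundle(model=model, optimizer=optimizer, scheduler=scheduler)


teacher = initialize_model_and_optimizer_sketch("prm_teacher")
actor = initialize_model_and_optimizer_sketch("actor")
print(teacher.optimizer is None, teacher.model["wrapped_with_ddp"])
print(actor.optimizer is not None, actor.model["wrapped_with_ddp"])
```

Keeping the branch at the top of the initializer means the training path is left untouched, which is why the change is low-risk for other roles.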
Summary
PRM Teacher only performs forward passes to compute teacher log-probabilities for OPD loss. It never does backward passes or gradient updates. However, the current code loads it with full DDP wrapping + Adam optimizer, wasting ~160GB of memory and causing OOM on single-GPU deployments.
Changes
File: slime/slime/backends/megatron_utils/model.py
- _setup_inference_only_model(args, role): Creates model with wrap_with_ddp=False, avoiding DDP gradient buffer allocation.
- initialize_model_and_optimizer(): When role == "prm_teacher", calls the inference-only setup, loads checkpoint without optimizer/scheduler, and returns early.
Memory Impact (27B BF16 model, single GPU)
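As a rough back-of-the-envelope check on the ~160GB figure, the sketch below converts parameter counts into decimal gigabytes. The 27B parameter count and BF16 weights come from this PR; the per-parameter split of gradient and Adam state is an illustrative assumption, since the actual dtypes depend on the training configuration.

```python
GB = 1e9  # decimal gigabytes


def state_gb(num_params: float, bytes_per_param: float) -> float:
    """Memory in GB for a per-parameter state of the given width."""
    return num_params * bytes_per_param / GB


NUM_PARAMS = 27e9  # 27B parameters, as stated in the PR

# BF16 weights take 2 bytes per parameter; these are needed for inference too.
weights_gb = state_gb(NUM_PARAMS, 2)

# Bytes of training-only state per parameter implied by the ~160GB figure.
implied_bytes_per_param = 160 * GB / NUM_PARAMS

# Illustrative split (assumption, not from the PR): a 2-byte gradient
# buffer plus two 2-byte Adam moments per parameter.
example_overhead_gb = state_gb(NUM_PARAMS, 2 + 2 + 2)

print(f"weights: {weights_gb:.0f} GB")
print(f"implied training-only bytes/param: {implied_bytes_per_param:.2f}")
print(f"example training-only overhead: {example_overhead_gb:.0f} GB")
```

Under these assumptions the training-only state is about three times the size of the weights themselves, which is consistent with the claim that dropping it is what makes a single-GPU deployment feasible.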
Safety Verification
- load_checkpoint -> unwrap_model() handles DDP/non-DDP correctly via sharded_state_dict path
- forward_only() has @torch.no_grad() decorator; model_module.eval() is set
- compute_prm_teacher_log_probs output is .cpu() under @torch.no_grad(), never participates in training loss.backward()
Testing