fix: skip DDP and optimizer for prm_teacher (inference-only mode)#115

Open
HeJW0303 wants to merge 1 commit into
Gen-Verse:mainfrom
HeJW0303:fix/prm-teacher-inference-only
Summary

PRM Teacher only performs forward passes to compute teacher log-probabilities for the OPD loss; it never runs a backward pass or a gradient update. However, the current code loads it with full DDP wrapping and an Adam optimizer, which wastes ~160 GB of memory and causes OOM on single-GPU deployments.

Changes

File: slime/slime/backends/megatron_utils/model.py

  1. New function _setup_inference_only_model(args, role): Creates model with wrap_with_ddp=False, avoiding DDP gradient buffer allocation.
  2. New branch in initialize_model_and_optimizer(): When role == "prm_teacher", calls the inference-only setup, loads checkpoint without optimizer/scheduler, and returns early.
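The control flow of the two changes above can be sketched as follows. This is a minimal illustration, not the actual Megatron or slime code: `ToyModel`, `ToyOptimizer`, `setup_model`, the `"actor"` role name, and the simplified signature and return value of `initialize_model_and_optimizer` are all hypothetical stand-ins; only the `role == "prm_teacher"` check and the `wrap_with_ddp=False` argument come from the description.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class ToyModel:
    """Hypothetical stand-in for a model chunk."""
    wrapped_with_ddp: bool
    training: bool = True


@dataclass
class ToyOptimizer:
    """Hypothetical stand-in for the Adam optimizer and its state."""
    name: str = "adam"


def setup_model(wrap_with_ddp: bool) -> ToyModel:
    # Per the PR description, wrap_with_ddp=False avoids allocating
    # DDP gradient buffers; here the flag is only recorded.
    return ToyModel(wrapped_with_ddp=wrap_with_ddp)


def initialize_model_and_optimizer(role: str) -> Tuple[ToyModel, Optional[ToyOptimizer]]:
    if role == "prm_teacher":
        # Inference-only path: no DDP wrapper, eval mode, no optimizer.
        model = setup_model(wrap_with_ddp=False)
        model.training = False
        return model, None  # early return: the optimizer is never built

    # Training path (unchanged): DDP wrapper plus optimizer.
    model = setup_model(wrap_with_ddp=True)
    return model, ToyOptimizer()
```

Calling `initialize_model_and_optimizer("prm_teacher")` in this sketch yields an unwrapped model in eval mode and `None` in place of the optimizer, while any other role still takes the full training path.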

Memory Impact (27B BF16 model, single GPU)

| Component | Before (training mode) | After (inference-only) |
| --- | --- | --- |
| Model params (BF16) | 54 GB | 54 GB |
| DDP gradient buffers | 54 GB | 0 |
| Adam m + v (FP32) | 108 GB | 0 |
| Total | ~216 GB (OOM) | ~54 GB (fits) |
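The figures in the table can be reproduced with a few lines of arithmetic. The sketch below assumes exactly 27e9 parameters, 1 GB = 1e9 bytes, and 2-byte (BF16) gradient buffers; the helper `gb` is hypothetical. Note that the table's 108 GB of Adam state corresponds to 4 bytes of optimizer state per parameter in total; two full FP32 moments would need 8 bytes per parameter, so the real saving may be larger than stated.

```python
NUM_PARAMS = 27_000_000_000  # assumed exact parameter count for a "27B" model


def gb(num_params: int, bytes_per_param: int) -> float:
    """Memory in GB (1 GB = 1e9 bytes) for one tensor set of the model's size."""
    return num_params * bytes_per_param / 1e9


params_gb = gb(NUM_PARAMS, 2)       # BF16 weights: 2 bytes per parameter
grad_buffers_gb = gb(NUM_PARAMS, 2) # assumed BF16 gradient buffers
adam_state_gb = gb(NUM_PARAMS, 4)   # table figure: 4 bytes per parameter in total

training_total_gb = params_gb + grad_buffers_gb + adam_state_gb
inference_total_gb = params_gb
saved_gb = training_total_gb - inference_total_gb

# For comparison: two separate FP32 moments (m and v), 4 bytes each.
adam_two_fp32_moments_gb = gb(NUM_PARAMS, 8)

print(f"training:  {training_total_gb:.0f} GB")
print(f"inference: {inference_total_gb:.0f} GB")
print(f"saved:     {saved_gb:.0f} GB")
```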

Safety Verification

  • Checkpoint key matching: load_checkpoint -> unwrap_model() handles both DDP-wrapped and non-DDP models correctly via the sharded_state_dict path
  • eval + no_grad: forward_only() has the @torch.no_grad() decorator, and model_module.eval() is called
  • Gradient isolation: the output of compute_prm_teacher_log_probs is computed under @torch.no_grad() and moved to CPU with .cpu(), so it never participates in the training loss.backward()
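The eval + no_grad and gradient-isolation points can be checked on a toy model. The sketch below is plain PyTorch, not the slime implementation: the tiny `teacher` network and this simplified `forward_only` body are assumptions that only borrow the function name from the PR. It shows that a tensor produced under `@torch.no_grad()` carries no autograd graph, so it cannot contribute gradients to a later `loss.backward()`.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def forward_only(model: nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    """Return per-token log-probabilities on CPU, with autograd disabled."""
    logits = model(tokens)                         # [batch, seq, vocab]
    log_probs = torch.log_softmax(logits, dim=-1)
    # Toy scoring rule: log-probability assigned to each input token.
    picked = log_probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return picked.cpu()


# Toy stand-in for the teacher: an embedding followed by a vocabulary projection.
vocab_size, hidden_size = 32, 16
teacher = nn.Sequential(
    nn.Embedding(vocab_size, hidden_size),
    nn.Linear(hidden_size, vocab_size),
)
teacher.eval()  # inference mode for layers such as dropout

tokens = torch.randint(0, vocab_size, (2, 5))
teacher_log_probs = forward_only(teacher, tokens)

# No autograd graph is attached to the output, and no gradients were
# accumulated on the teacher's parameters.
print(teacher_log_probs.requires_grad, teacher_log_probs.grad_fn)
```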

Testing

  • Model: Qwen3.5-27B (BF16)
  • Config: PRM Teacher on 1x GPU (TP=1)
  • Result: Model loads successfully, log-prob computation works correctly, no OOM

  PRM Teacher only performs forward passes to compute log-probabilities
  and never requires gradient updates. Loading with DDP + optimizer
  wastes ~160GB memory (gradient buffers + Adam states), making large
  models impossible to run on a single GPU.

  Changes:
  - Add _setup_inference_only_model() with wrap_with_ddp=False
  - Add early return path for role=='prm_teacher' in
    initialize_model_and_optimizer() that skips DDP/optimizer

  Tested with Qwen3.5-27B.