
Fix LoRA model training #1385

Open
lifeiteng wants to merge 4 commits into areal-project:main from lifeiteng:fixlora

Conversation

@lifeiteng

Description

Two independent but related bug fixes that unblock LoRA RL training on the
SGLang backend. Both surface on a single 24GB GPU with enable_offload=false
(actor + sglang co-resident), where the adapter load/request lifecycle is most
stressed.

1. Unload stale LoRA adapters on disk weight update (61d1884a)
Every train step loaded a new versioned adapter lora-<name>-v{N} via
/load_lora_adapter but never unloaded older versions, so sglang accumulated
one adapter per step. VRAM crept up until sglang hung inside
update_weights_from_disk — observed stalling at step 39 (step 38's update was
10.3s; step 39 hit the 600s read timeout), well below the memory ceiling.
Fix: emit a best-effort /unload_lora_adapter for the version that has
fallen outside the retention window (max_head_offpolicyness + 2, enough to
cover in-flight off-policy rollouts). The unload is logged-and-ignored on
failure so it can never break a weight update.
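The retention rule can be sketched as below. This is a minimal illustration, not the PR's code: `stale_adapter_name` is a hypothetical helper, the `<lora_name>-v{N}` naming follows the description, and the exact boundary (unload v{N-keep} once N >= keep) is an assumption drawn from the commit message.

```python
from typing import Optional


def stale_adapter_name(lora_name: str, version: int, keep_versions: int) -> Optional[str]:
    """Return the versioned adapter to unload after loading `version`, if any.

    Hypothetical helper. keep_versions <= 0 disables unloading (load-only),
    and nothing is unloaded until a version has left the retention window.
    """
    if keep_versions <= 0:
        return None
    stale = version - keep_versions
    if stale < 0:
        return None
    return f"{lora_name}-v{stale}"


# Retention window per the description: max_head_offpolicyness + 2.
max_head_offpolicyness = 1  # illustrative value, not taken from the PR
keep = max_head_offpolicyness + 2

print(stale_adapter_name("lora-gsm8k", 2, keep))  # None
print(stale_adapter_name("lora-gsm8k", 3, keep))  # lora-gsm8k-v0
```

With a window of 3, versions v1 to v3 stay resident after step 3, so in-flight rollouts that still reference a recent adapter are not broken by the unload.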

2. Thread lora_name through to generation requests (67f7a4a0)
ArealOpenAI rebuilt GenerationHyperparameters without lora_name, so
/generate always used the dataclass default default_lora while the trainer
loaded the configured adapter. SGLang then rejected every rollout with
LoRA adapter that has never been loaded: default_lora-v0.
Fix: thread the configured adapter name through the request side, with
PPOConfig.__post_init__ syncing it from gconfig.lora_name as the single
source of truth.
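A rough sketch of the sync, using simplified stand-in dataclasses rather than the real config classes. `GenConfig`, `RolloutConfig`, and `TrainerConfig` are hypothetical names, and the empty-string default for the rollout-side `lora_name` is an assumption; only the two-line condition in `__post_init__` mirrors the diff under review.

```python
from dataclasses import dataclass, field


@dataclass
class GenConfig:  # stand-in for GenerationHyperparameters
    lora_name: str = "default_lora"


@dataclass
class RolloutConfig:  # stand-in for InferenceEngineConfig
    use_lora: bool = False
    lora_name: str = ""  # assumed default: empty means "not set"


@dataclass
class TrainerConfig:  # stand-in for PPOConfig
    gconfig: GenConfig = field(default_factory=GenConfig)
    rollout: RolloutConfig = field(default_factory=RolloutConfig)

    def __post_init__(self) -> None:
        # Fill the rollout-side name only when LoRA is on and no name was set.
        if self.rollout.use_lora and not self.rollout.lora_name:
            self.rollout.lora_name = self.gconfig.lora_name


cfg = TrainerConfig(
    gconfig=GenConfig(lora_name="lora-gsm8k"),
    rollout=RolloutConfig(use_lora=True),
)
print(cfg.rollout.lora_name)  # lora-gsm8k
```

An explicitly configured rollout name is left alone, so the sync only fills a gap and never overrides a user setting.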

Net effect: the load side and the request side now agree on <lora_name>-vN,
and stale adapters are reclaimed so long runs no longer hang.

Related Issue

N/A — discovered while running LoRA GRPO on a single-GPU colocated setup.
(Replace with Fixes #<id> if a tracking issue exists.)

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (N/A — internal bug fix, no user-facing docs)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable): N/A — all changes are additive,
backward-compatible.

Additional Context

Scope: 2 commits, 8 files, +133/−1, including a new 65-line unit test.

New tests: tests/test_sglang_lora_unload.py, CPU-only (assert on the
request builder; no GPU / sglang server required):

$ pytest tests/test_sglang_lora_unload.py -q
5 passed in 4.45s

Covers: stale-version unload beyond the retention window; no unload within the
window; lora_keep_versions=0 preserving the old load-only behaviour; the
unload being best-effort; and the non-LoRA disk path staying untouched.
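The behaviours listed above can be illustrated with a toy request builder. Everything here is a sketch under assumptions: this `HttpRequest` is a simplified stand-in (only the `best_effort` field is named in the PR), `build_lora_disk_update_requests` is a hypothetical function, and the payload keys are assumed rather than copied from the real builder.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class HttpRequest:  # simplified stand-in, not the class in areal/api/io_struct.py
    endpoint: str
    payload: Dict[str, Any] = field(default_factory=dict)
    best_effort: bool = False


def build_lora_disk_update_requests(
    lora_name: str, version: int, path: str, keep_versions: int
) -> List[HttpRequest]:
    """Build the load request, plus a best-effort unload for the stale version."""
    requests = [
        HttpRequest(
            "/load_lora_adapter",
            # Payload keys are assumptions for illustration.
            {"lora_name": f"{lora_name}-v{version}", "lora_path": path},
        )
    ]
    stale = version - keep_versions
    if keep_versions > 0 and stale >= 0:
        requests.append(
            HttpRequest(
                "/unload_lora_adapter",
                {"lora_name": f"{lora_name}-v{stale}"},
                best_effort=True,  # a failed unload must not fail the update
            )
        )
    return requests


reqs = build_lora_disk_update_requests("lora-gsm8k", 5, "/tmp/ckpt", keep_versions=3)
print([(r.endpoint, r.payload["lora_name"], r.best_effort) for r in reqs])
```

Because the builder is pure data construction, checks like these can run on CPU without a GPU or a live sglang server, which matches how the PR describes its tests.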

Backward compatibility (non-breaking):

  • WeightUpdateMeta.lora_keep_versions, HttpRequest.best_effort, and
    InferenceEngineConfig.lora_name are additive fields with defaults; existing
    configs and the full-model weight-update path are unchanged.
  • lora_keep_versions=0 reproduces the exact prior behaviour (load only).

Note: commit 1 restores logic from standalone commit 386328b9 that was
lost in the controller-v2 refactor.

Key files:

  • areal/api/io_struct.py: WeightUpdateMeta.lora_keep_versions, HttpRequest.best_effort
  • areal/engine/sglang_remote.py: best-effort /unload_lora_adapter in the disk-update request builder
  • areal/infra/remote_inf_engine.py: disk-update executor ignores best-effort failures
  • areal/trainer/rl_trainer.py: derives lora_keep_versions from rollout.max_head_offpolicyness
  • areal/api/cli_args.py: InferenceEngineConfig.lora_name + PPOConfig.__post_init__ sync
  • areal/experimental/openai/{client.py, proxy/proxy_rollout_server.py}: pass lora_name on the request side

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces a mechanism to unload stale LoRA adapter versions from the inference server, preventing VRAM accumulation and potential hangs during training. It adds a lora_keep_versions parameter to track the retention window, supports best-effort HTTP requests for cleanup, and propagates the LoRA adapter name through the OpenAI client and rollout configuration. Feedback suggests automatically syncing rollout.use_lora from actor.use_lora to avoid configuration mismatches, and catching Exception instead of BaseException when ignoring best-effort request failures to prevent catching system-exiting exceptions.
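The point about `Exception` versus `BaseException` can be shown with a small sketch. `Step` and `run_steps` are hypothetical names, not the executor in areal/infra/remote_inf_engine.py; the sketch only demonstrates the pattern of logging and skipping best-effort failures while still letting interpreter-level exceptions through.

```python
import logging
from dataclasses import dataclass
from typing import Callable, Iterable, List

logger = logging.getLogger("weight_update_sketch")


@dataclass
class Step:  # hypothetical stand-in for one request in a weight update
    name: str
    run: Callable[[], None]
    best_effort: bool = False


def run_steps(steps: Iterable[Step]) -> List[str]:
    """Run steps in order; log and skip failures of best-effort steps."""
    completed = []
    for step in steps:
        try:
            step.run()
        except Exception as exc:
            # Exception, not BaseException: KeyboardInterrupt and SystemExit
            # are not caught here and still propagate to the caller.
            if not step.best_effort:
                raise
            logger.warning("ignoring best-effort failure in %s: %s", step.name, exc)
            continue
        completed.append(step.name)
    return completed


def failing_unload() -> None:
    raise RuntimeError("adapter not found")


print(run_steps([Step("load", lambda: None), Step("unload", failing_unload, best_effort=True)]))
# ['load']
```

Catching `BaseException` in the same spot would also swallow Ctrl-C or a shutdown request raised during a best-effort step, which is the risk the review flags.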

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment thread areal/api/cli_args.py
Comment on lines +2946 to +2947
    if self.rollout.use_lora and not self.rollout.lora_name:
        self.rollout.lora_name = self.gconfig.lora_name
Contributor


high

To prevent configuration mismatch bugs where actor.use_lora is enabled but rollout.use_lora is left disabled (which would cause rollout to run without LoRA), we should automatically sync rollout.use_lora from actor.use_lora during initialization.

Suggested change
-   if self.rollout.use_lora and not self.rollout.lora_name:
-       self.rollout.lora_name = self.gconfig.lora_name
+   if self.actor.use_lora:
+       self.rollout.use_lora = True
+   if self.rollout.use_lora and not self.rollout.lora_name:
+       self.rollout.lora_name = self.gconfig.lora_name

Comment thread areal/infra/remote_inf_engine.py Outdated
Comment thread areal/api/cli_args.py
    default=False,
    metadata={"help": "Whether to use LoRA. Should be same as actors LORA option."},
)
lora_name: str = field(
Collaborator


Please run docs/generate_cli_docs.py to update both the en and zh versions of cli_reference.md, and add them to this PR.

Author


done

lifeiteng and others added 4 commits June 5, 2026 10:27
ArealOpenAI rebuilt GenerationHyperparameters without lora_name,
so /generate requests always used the dataclass default "default_lora" while
the trainer loaded the configured adapter (e.g. "lora-gsm8k-v0"). SGLang then
rejected every rollout with "LoRA adapter that has never been loaded:
default_lora-v0".

Thread the configured adapter name through the request side:
- Add InferenceEngineConfig.lora_name; PPOConfig.__post_init__ syncs it from
  gconfig.lora_name when rollout.use_lora (single source of truth).
- ArealOpenAI and its AsyncCompletionsWithReward / AsyncResponsesWithReward
  resources accept lora_name and set it on the rebuilt gconfig.
- proxy_rollout_server._setup_openai_client passes config.lora_name.

Load and request sides now agree on '<lora_name>-vN'.

Each train step loaded a new versioned adapter (lora-<name>-v{N}) via
/load_lora_adapter but never unloaded older versions, so sglang accumulated
one adapter per step. VRAM crept up until sglang hung inside
update_weights_from_disk -- on the single 24GB GPU with enable_offload=false
(actor + sglang co-resident) it stalled at step 39 (step 38 update_weights was
10.3s, step 39 hit the 600s read timeout), well below the memory ceiling.

Unload the version that has fallen outside the retention window
(max_head_offpolicyness + 2, enough to cover off-policy rollouts) as a
best-effort request, logged-and-ignored on failure so it never breaks the
weight update.

- Add WeightUpdateMeta.lora_keep_versions and HttpRequest.best_effort
- SGLangBackend appends a best-effort /unload_lora_adapter for v{N-keep}
- disk-update executor ignores best-effort request failures
- rl_trainer derives lora_keep_versions from rollout.max_head_offpolicyness

Restored from standalone commit 386328b9 (lost in the controller v2 refactor).

Add the lora_name InferenceEngineConfig entry to the EN/ZH CLI
reference tables, generated from the new config field. The adapter
name lets generation requests select the served LoRA adapter and is
auto-filled from gconfig.lora_name by PPOConfig.__post_init__ so the
load and request sides stay in sync.

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

2 participants