feat(awex): FSDP colocate weight update via CUDA IPC #1361
Code Review
This pull request implements FSDP colocate weight transfer in the AwexFSDPAdapter, enabling FSDP-trained models to update SGLang inference weights via CUDA IPC on shared GPUs. The changes include adding methods for initializing and executing colocate updates, as well as memory management functions (release_memory and resume_memory) to offload weights to the CPU. Comprehensive unit and integration tests have also been added. Feedback was provided regarding the robustness of the _execute_colocate_weight_update_locked method, specifically recommending the use of a try...finally block to ensure weights are re-offloaded even if an exception occurs, preventing potential GPU OOM issues.
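The try/finally recommendation above can be sketched roughly as follows. This is a minimal illustration, not the adapter's actual code: `ColocateAdapterSketch`, `execute_update`, the `weights_on_gpu` flag, and `failing_transfer` are hypothetical names, and only `release_memory` / `resume_memory` with the `"weights"` tag come from the PR.

```python
class ColocateAdapterSketch:
    """Hypothetical stand-in for an adapter that keeps weights offloaded to CPU."""

    def __init__(self):
        self.weights_on_gpu = False
        self.events = []

    def resume_memory(self, tags):
        # Simulate reloading the tagged buffers onto the GPU.
        self.weights_on_gpu = True
        self.events.append(("resume", tuple(tags)))

    def release_memory(self, tags):
        # Simulate offloading the tagged buffers back to the CPU.
        self.weights_on_gpu = False
        self.events.append(("release", tuple(tags)))

    def execute_update(self, transfer):
        self.resume_memory(["weights"])
        try:
            transfer()
        finally:
            # Runs on success and on failure, so weights never stay resident.
            self.release_memory(["weights"])


def failing_transfer():
    # Hypothetical transfer step that fails partway through.
    raise RuntimeError("simulated IPC failure")


adapter = ColocateAdapterSketch()
try:
    adapter.execute_update(failing_transfer)
except RuntimeError:
    pass  # the error still propagates; only the cleanup is guaranteed
```

Without the `finally` clause, the exception would leave the training weights on the GPU next to the inference engine, which is the OOM risk the review points out.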
Implement the full colocate weight-update lifecycle for AwexFSDPAdapter:
- Add colocate state fields and init_colocate_weight_update to publish
  per-rank local shards via CUDA IPC handles
- Implement execute_colocate_weight_update with all-gathered metadata
  alignment and try/finally release semantics
- Implement release_memory/resume_memory('weights') for memory management
  during colocated inference
- Wrap save_parameters with colocate resume/release for checkpoint safety
- Lazy-import Megatron in adapter factory for FSDP-only users
- Add DP e2e test and unit tests for the colocate path
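As a rough illustration of the per-rank shard metadata that the colocate path publishes, the sketch below computes a local shape and global offset for a dim-0 shard. `ShardMeta` and `shard0_metadata` are hypothetical names, and the ceil-division chunking is an assumption modeled on `torch.chunk`-style splitting; it is not taken from the adapter's code.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ShardMeta:
    shape: Tuple[int, ...]          # local shard shape on this rank
    global_offset: Tuple[int, ...]  # where the shard starts in the global tensor


def shard0_metadata(global_shape, rank, world_size):
    """Describe the dim-0 chunk owned by `rank` (assumed ceil-division chunking)."""
    rows = global_shape[0]
    chunk = -(-rows // world_size)   # ceil(rows / world_size)
    start = min(rank * chunk, rows)  # clamp so trailing ranks may own empty shards
    stop = min(start + chunk, rows)
    local_shape = (stop - start,) + tuple(global_shape[1:])
    offset = (start,) + (0,) * (len(global_shape) - 1)
    return ShardMeta(local_shape, offset)


# Example: a [2048, 1024] weight split across two train ranks.
rank0 = shard0_metadata((2048, 1024), rank=0, world_size=2)
rank1 = shard0_metadata((2048, 1024), rank=1, world_size=2)
```

Under these assumptions rank 1 reports a `[1024, 1024]` shard starting at row 1024, so a receiver can place it without knowing how the other ranks were split.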
Hi @guozhihao-224, thanks for the PR! This is a useful first step for colocated FSDP weight update. From my side, I think we can merge this first and leave the remaining items as follow-ups. A few possible follow-ups:
These should not block this PR, but the invariant tests would make the current path safer.
Description
Adds FSDP colocate weight transfer in `AwexFSDPAdapter` so FSDP-trained models can update SGLang inference weights via CUDA IPC on shared GPUs, mirroring the existing Megatron colocate path.
What changed:
- 4 colocate methods on `AwexFSDPAdapter`: `init_colocate_weight_update`, `execute_colocate_weight_update`, `release_memory(["weights"])`, `resume_memory(["weights"])`.
- `_iter_hf_params_local` helper: yields each train rank's DTensor `_local_tensor` (its `Shard(0)` chunk), or the plain tensor if the param isn't a DTensor; reloads CPU-offloaded tensors to GPU. Skips `lm_head.weight` when `tie_word_embeddings=True` so the train-side key set matches inference (SGLang/vLLM collapse the tied head into `model.embed_tokens.weight`).
- `get_weight_metadata` reports each train rank's truthful `Shard(0)` metadata: `shape = local shape`, `global_offset = where this chunk starts in the global tensor`. The colocate IPC payload from `_iter_hf_params_local` matches that contract exactly, so awex's standard `slice_tensor` (shard-relative `train_slices`) indexes correctly into each rank's payload, and cross-engine P2P slices that reassemble the full tensor on the infer side are computed against truthful per-rank ownership.
- `save_parameters` wrapped with resume/release so the gateway debug `/awex/debug/get_parameters` path works after colocate offloads training weights.
- `_create_training_adapter` lazy-imports `MegatronEngine`. The eager import previously transitively pulled in `megatron.bridge` → `transformer_engine`, which `pyproject.toml` deliberately marks never-install; FSDP-only deployments couldn't start the awex worker before this fix.
- 13 mocked unit tests in `test_fsdp_colocate_unit.py` covering protocol-level correctness without GPU.
- New multi-GPU e2e test `test_awex_fsdp_colocate_dp_e2e_weight_update` (gated by `multi_gpu and sglang and slow`).
Related Issue
Follow-up to #1310 (colocated CUDA IPC weight transfer for Megatron).
Additional Context
GPU verification (manual, not in CI)
Verified manually on a 2-GPU host with `flash-attn 2.8.3` matching `torch 2.9.1+cu129` / py3.12.
DP=2 splits `q_proj.weight` `[2048, 1024]` into `[1024, 1024]` chunks across train ranks; both halves match bit-exactly (`rtol=0, atol=0`), confirming train rank 1's local-shard IPC payload reaches the right region of infer rank 0's full tensor via awex's standard transfer plan. `[4gpu]` and `[8gpu]` deselected due to GPU count.
Open verification items (need 4+ GPU host)
- `test_awex_fsdp_colocate_dp_e2e_weight_update[4gpu]` / `[8gpu]`
- `test_awex_fsdp_e2e_weight_update[4gpu]` / `[8gpu]` (separated NCCL P2P path; metadata-data contract is now self-consistent for this path too, but not GPU-verified yet)
Out of scope (file separately)
- `pyproject.toml`'s `flash-attn-4`-only install is incomplete for transformers's flash-attn-2 detection; needs an additional pre-built `flash-attn==2.8.3` wheel (already done by `Dockerfile`). `uv sync --extra cuda` alone leaves `transformers` thinking fa2 is available but the import fails.
- `nightly.yml` workflow is a placeholder; the `slow + multi_gpu` tests, including this PR's e2e, never get auto-run.
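
The bit-exact check described under GPU verification can be mimicked without a GPU. The sketch below is a pure-Python stand-in, not the e2e test itself: `split_rows` and `reassemble` are hypothetical helpers, nested lists replace tensors, and exact list equality plays the role of comparing with `rtol=0, atol=0`.

```python
def split_rows(full, world_size):
    """Chunk rows along dim 0; return one (row_offset, rows) pair per rank."""
    chunk = -(-len(full) // world_size)  # ceil division (assumed chunking rule)
    return [(r * chunk, full[r * chunk:(r + 1) * chunk]) for r in range(world_size)]


def reassemble(shards, total_rows):
    """Place each rank's rows at its global row offset."""
    out = [None] * total_rows
    for offset, rows in shards:
        for i, row in enumerate(rows):
            out[offset + i] = list(row)
    return out


# An 8x4 "weight" split across two hypothetical train ranks.
full = [[float(r * 4 + c) for c in range(4)] for r in range(8)]
shards = split_rows(full, world_size=2)
rebuilt = reassemble(shards, total_rows=len(full))

# Exact equality: any misplaced or altered element makes this False.
matches = rebuilt == full
```

If a rank's offset were reported incorrectly, its rows would land in the wrong region and the exact comparison would fail, which is the property the manual 2-GPU run checks on real tensors.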