From d4b9d340795fcb31e77f82d64d0555d0f9b657b5 Mon Sep 17 00:00:00 2001 From: guozhihao-224 Date: Tue, 26 May 2026 19:23:09 +0800 Subject: [PATCH] feat(awex): implement FSDP colocate weight update via CUDA IPC Implement the full colocate weight-update lifecycle for AwexFSDPAdapter: - Add colocate state fields and init_colocate_weight_update to publish per-rank local shards via CUDA IPC handles - Implement execute_colocate_weight_update with all-gathered metadata alignment and try/finally release semantics - Implement release_memory/resume_memory('weights') for memory management during colocated inference - Wrap save_parameters with colocate resume/release for checkpoint safety - Lazy-import Megatron in adapter factory for FSDP-only users - Add DP e2e test and unit tests for the colocate path --- .../training_service/worker/awex.py | 12 +- .../weight_update/awex/fsdp_adapter.py | 302 +++++++++- .../weight_update/test_fsdp_colocate_unit.py | 550 ++++++++++++++++++ .../weight_update/test_nccl_integration.py | 159 +++++ 4 files changed, 1014 insertions(+), 9 deletions(-) create mode 100644 tests/experimental/weight_update/test_fsdp_colocate_unit.py diff --git a/areal/experimental/training_service/worker/awex.py b/areal/experimental/training_service/worker/awex.py index e206810b83..5d6a051076 100644 --- a/areal/experimental/training_service/worker/awex.py +++ b/areal/experimental/training_service/worker/awex.py @@ -194,15 +194,19 @@ def action(): def _create_training_adapter(engine): from areal.engine.fsdp_engine import FSDPEngine - from areal.engine.megatron_engine import MegatronEngine from areal.experimental.weight_update.awex.fsdp_adapter import AwexFSDPAdapter - from areal.experimental.weight_update.awex.megatron_adapter import ( - AwexMegatronAdapter, - ) if isinstance(engine, FSDPEngine): return AwexFSDPAdapter(engine) + # Lazy-import Megatron path: megatron-bridge requires transformer-engine, + # which pyproject.toml deliberately marks as never-install. 
FSDP-only + # users should not pay the import cost (or hit ImportError). + from areal.engine.megatron_engine import MegatronEngine + from areal.experimental.weight_update.awex.megatron_adapter import ( + AwexMegatronAdapter, + ) + if isinstance(engine, MegatronEngine): return AwexMegatronAdapter(engine) diff --git a/areal/experimental/weight_update/awex/fsdp_adapter.py b/areal/experimental/weight_update/awex/fsdp_adapter.py index 9c061bacb8..316121fdab 100644 --- a/areal/experimental/weight_update/awex/fsdp_adapter.py +++ b/areal/experimental/weight_update/awex/fsdp_adapter.py @@ -2,9 +2,13 @@ from __future__ import annotations # pyright: reportMissingImports=false +import gc import os +import threading +import time from typing import TYPE_CHECKING +import httpx import torch from awex.meta.weight_meta import ( ParameterMeta, @@ -15,6 +19,10 @@ from awex.sharding.rank_info import RankInfo from awex.transfer.nccl_comm import batch_send_recv, nccl_build_send_ops from awex.transfer.transfer_plan import TransferPlan, TransferPlanBuilder +from awex.util.tensor_util import ( + cuda_ipc_serialize, + group_tensors_by_shape_and_dtype, +) from torch.distributed.tensor import DTensor from torch.distributed.tensor.placement_types import Shard @@ -36,7 +44,8 @@ class AwexFSDPAdapter(AwexTrainingAdapter): - """Awex training adapter wrapping FSDPEngine for shard-direct NCCL P2P updates.""" + """Awex training adapter wrapping FSDPEngine for shard-direct NCCL P2P + updates and colocated CUDA-IPC updates (DP-only).""" def __init__(self, engine: FSDPEngine): self._engine = engine @@ -44,6 +53,20 @@ def __init__(self, engine: FSDPEngine): self._weights_update_group = None self._transfer_rank: int | None = None + # ── Colocate state ──────────────────────────────────────── + self._colocate_lock = threading.Lock() + self._colocate_pair_name: str | None = None + self._colocate_kv_store_url: str | None = None + self._colocate_transfer_rank: int | None = None + 
self._colocate_infer_world_size: int | None = None + self._colocate_admin_api_key: str = "areal-admin-key" + self._colocate_http_client: httpx.Client | None = None + self._colocate_timeout_s: float = 120.0 + + # ── release_memory / resume_memory state ────────────────── + self._released_tags: set[str] = set() + self._offloaded_weights: dict[str, torch.Tensor] = {} + @property def parallelism_strategy(self) -> dict: mesh = self._engine.world_mesh @@ -70,6 +93,9 @@ def get_weight_metadata(self) -> list[ParameterMeta]: for raw_name, param in self._engine.model.named_parameters(): name = self._to_hf_name(raw_name) if self._tie_word_embeddings and name == "lm_head.weight": + # Inference engines (SGLang/vLLM) collapse the tied head into + # `model.embed_tokens.weight`; mirror that to keep param key + # sets aligned across train and infer. continue tensor = param.data if isinstance(tensor, DTensor): @@ -119,9 +145,16 @@ def get_local_shard_parameters( return local_params def save_parameters(self, save_path: str, names: list[str] | None = None) -> None: - params = self.get_local_shard_parameters(names) - cpu_params = {k: v.detach().cpu().clone() for k, v in params.items()} - torch.save(cpu_params, save_path) + weights_offloaded = "weights" in self._released_tags + if weights_offloaded: + self.resume_memory(tags=["weights"]) + try: + params = self.get_local_shard_parameters(names) + cpu_params = {k: v.detach().cpu().clone() for k, v in params.items()} + torch.save(cpu_params, save_path) + finally: + if weights_offloaded: + self.release_memory(tags=["weights"]) def init_weight_update_group( self, @@ -195,6 +228,259 @@ def teardown_weight_update_group(self) -> None: self._weights_update_group = None self._transfer_plan = None self._transfer_rank = None + if self._colocate_http_client is not None: + self._colocate_http_client.close() + self._colocate_http_client = None + + # ── Colocated weight transfer methods ───────────────────────────────── + + def 
init_colocate_weight_update( + self, + pair_name: str, + kv_store_url: str, + transfer_rank: int, + infer_world_size: int, + train_world_size: int, + num_engines: int, + master_port: int, + admin_api_key: str = "areal-admin-key", + timeout_s: float = 120.0, + ) -> None: + del train_world_size, num_engines, master_port # not needed on training side + self._colocate_pair_name = pair_name + self._colocate_kv_store_url = kv_store_url + self._colocate_transfer_rank = transfer_rank + self._colocate_infer_world_size = infer_world_size + self._colocate_admin_api_key = admin_api_key + self._colocate_timeout_s = timeout_s + if self._colocate_http_client is None: + self._colocate_http_client = httpx.Client() + logger.info( + "Initialized colocate weight update for pair '%s', transfer_rank=%d", + pair_name, + transfer_rank, + ) + + def _iter_hf_params_local(self): + """Yield (hf_name, local_shard_tensor) for every parameter on this rank. + + Each rank yields its own DTensor ``_local_tensor`` (the actual Shard(0) + chunk owned by this rank), or the plain tensor if the parameter is not + a DTensor. This matches the ``Shard(0)`` metadata reported by + ``_extract_dtensor_shard_meta`` so that awex's ``slice_tensor`` + contract holds: ``send_parameters[name].shape == shard_meta.shape``. + + Cross-engine reassembly (i.e. infer rank 0 needs train rank 1's slice + for the second half of every param) is handled by the awex transfer + plan via ``_recv_transfer_plan`` + the colocate transport's P2P phase + — we just publish our local chunk via CUDA IPC and let awex route + slices to whichever infer rank needs them. + + When ``tie_word_embeddings`` is set, ``lm_head.weight`` is skipped so + the train-side key set matches inference engines (e.g. SGLang) which + collapse the tied head into ``model.embed_tokens.weight``. 
+ """ + device = self._engine.device + for raw_name, param in self._engine.model.named_parameters(): + name = self._to_hf_name(raw_name) + if self._tie_word_embeddings and name == "lm_head.weight": + continue + tensor = param.data + if isinstance(tensor, DTensor): + local = tensor._local_tensor + else: + local = tensor + if local.device.type == "cpu": + local = local.to(device, non_blocking=True) + yield name, local.detach() + + def execute_colocate_weight_update(self, version: int) -> None: + with self._colocate_lock: + self._execute_colocate_weight_update_locked(version) + + def _execute_colocate_weight_update_locked(self, version: int) -> None: + kv_store_url = self._colocate_kv_store_url + pair_name = self._colocate_pair_name + transfer_rank = self._colocate_transfer_rank + assert self._colocate_http_client is not None, ( + "init_colocate_weight_update must be called first" + ) + client = self._colocate_http_client + auth_headers = {"Authorization": f"Bearer {self._colocate_admin_api_key}"} + timeout_s = self._colocate_timeout_s + + weights_offloaded = "weights" in self._released_tags + if weights_offloaded: + self.resume_memory(tags=["weights"]) + try: + # Publish each rank's local DTensor shard (no all-gather). The awex + # transfer plan routes the right slice from each rank's IPC payload + # to whichever infer rank needs it. 
+ params: dict[str, torch.Tensor] = {} + for hf_name, tensor in self._iter_hf_params_local(): + params[hf_name] = tensor + tensors = list(params.values()) + names = list(params.keys()) + + group_tensors, metadata = group_tensors_by_shape_and_dtype(tensors) + torch.cuda.synchronize() + + del tensors + + group_shared = [t.share_memory_() for t in group_tensors] + serialized_weights = cuda_ipc_serialize((group_shared, metadata, names)) + torch.cuda.synchronize() + + kv_key = f"colocate_weights_rank{transfer_rank}_{version}" + client.put( + f"{kv_store_url}/weight_meta/{pair_name}/{kv_key}", + json={"value": serialized_weights.hex()}, + headers=auth_headers, + timeout=timeout_s, + ) + + logger.info( + "Serialized %d params (%d groups) for colocate transfer v%d, rank %d", + len(names), + len(group_shared), + version, + transfer_rank, + ) + + done_key = f"colocate_done_rank{transfer_rank}_{version}" + deadline = time.monotonic() + timeout_s + poll_count = 0 + last_status = -1 + while time.monotonic() < deadline: + resp = client.get( + f"{kv_store_url}/weight_meta/{pair_name}/{done_key}", + timeout=5.0, + ) + last_status = resp.status_code + if resp.status_code == 200: + break + poll_count += 1 + time.sleep(0.1) + else: + raise TimeoutError( + f"Inference did not signal completion within {timeout_s}s " + f"(waiting_key={done_key}, put_key={kv_key}, " + f"polls={poll_count}, last_status={last_status})" + ) + + del group_shared, group_tensors, serialized_weights + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + finally: + if weights_offloaded: + self.release_memory(tags=["weights"]) + + # ── Memory release / resume (colocate only) ─────────────────────────── + + _SUPPORTED_RELEASE_TAGS = {"weights"} + + def release_memory(self, tags: list[str] | None = None) -> None: + """Release GPU memory for ``weights`` tag by offloading to CPU. + + v1 supports only ``weights``. 
``optimizer`` is intentionally not + supported on FSDP2 because per-parameter optimizer state lives in + sharded DTensor form and replacing it requires coordination with + FSDP2's reshard hooks (see ``release_memory`` design notes). + + Unsupported tags are logged as warnings and ignored. + """ + tags = tags or list(self._SUPPORTED_RELEASE_TAGS) + + unsupported = [t for t in tags if t not in self._SUPPORTED_RELEASE_TAGS] + if unsupported: + logger.warning( + "release_memory: tags %s not supported by FSDP adapter " + "(supported: %s), ignoring", + unsupported, + sorted(self._SUPPORTED_RELEASE_TAGS), + ) + + tags_to_release = [ + t + for t in tags + if t in self._SUPPORTED_RELEASE_TAGS and t not in self._released_tags + ] + if not tags_to_release: + logger.info("release_memory: tags=%s already released, skipping", tags) + return + + if "weights" in tags_to_release: + self._offload_model_weights() + self._released_tags.add("weights") + + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + logger.info("release_memory: done for tags=%s", tags_to_release) + + def _offload_model_weights(self) -> None: + if self._engine.model is None: + return + for name, param in self._engine.model.named_parameters(): + tensor = param.data + if isinstance(tensor, DTensor): + local = tensor._local_tensor + if local.is_cuda: + self._offloaded_weights[name] = local.detach().to( + "cpu", non_blocking=True + ) + tensor._local_tensor = torch.empty( + 0, dtype=local.dtype, device="cpu" + ) + else: + if tensor.is_cuda: + self._offloaded_weights[name] = tensor.detach().to( + "cpu", non_blocking=True + ) + param.data = torch.empty(0, dtype=tensor.dtype, device="cpu") + logger.info( + "Offloaded %d FSDP weight tensors to CPU", + len(self._offloaded_weights), + ) + + def resume_memory(self, tags: list[str] | None = None) -> None: + tags = tags or list(self._SUPPORTED_RELEASE_TAGS) + + tags_to_resume = [ + t + for t in tags + if t in self._SUPPORTED_RELEASE_TAGS and t in 
self._released_tags + ] + if not tags_to_resume: + logger.info("resume_memory: tags=%s not released, skipping", tags) + return + + if "weights" in tags_to_resume: + self._reload_model_weights() + self._released_tags.discard("weights") + + torch.cuda.synchronize() + logger.info("resume_memory: done for tags=%s", tags_to_resume) + + def _reload_model_weights(self) -> None: + if not self._offloaded_weights: + return + if self._engine.model is None: + return + + device = self._engine.device + for name, param in self._engine.model.named_parameters(): + if name not in self._offloaded_weights: + continue + saved = self._offloaded_weights[name] + tensor = param.data + if isinstance(tensor, DTensor): + tensor._local_tensor = saved.to(device, non_blocking=True) + else: + param.data = saved.to(device, non_blocking=True) + self._offloaded_weights.clear() + logger.info("Reloaded FSDP weights to GPU") def _to_hf_name(self, name: str) -> str: if self._engine.is_vision_model and is_qwen_vl_model( @@ -285,12 +571,18 @@ def _extract_dtensor_shard_meta( dtensor: DTensor, rank_info: RankInfo, ) -> ParameterShardMeta: + # Report this rank's actual local Shard(0) chunk: shape = local shape, + # global_offset = where this chunk starts in the global tensor. The + # colocate IPC payload (`_iter_hf_params_local`) publishes that same + # local chunk, so the awex transfer plan's `train_slices` (which are + # relative to shard.start_offset) correctly index into it, and any + # cross-engine P2P slice that reassembles the full tensor on the infer + # side is computed against truthful per-rank ownership. 
local_tensor = dtensor._local_tensor sharding_dim, num_shards = self._extract_dtensor_sharding(dtensor) sharding_type = ( ShardingType.TP_SHARDING if num_shards > 1 else ShardingType.NO_SHARDING ) - return ParameterShardMeta( tp_rank=rank_info.tp_rank, attn_tp_rank=rank_info.attn_tp_rank, diff --git a/tests/experimental/weight_update/test_fsdp_colocate_unit.py b/tests/experimental/weight_update/test_fsdp_colocate_unit.py new file mode 100644 index 0000000000..33659a9bf0 --- /dev/null +++ b/tests/experimental/weight_update/test_fsdp_colocate_unit.py @@ -0,0 +1,550 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for AwexFSDPAdapter colocate methods. + +These tests run without GPU/distributed by mocking the engine, awex IPC +helpers, and httpx client. They verify protocol-level correctness only; +real CUDA IPC + DTensor behavior is exercised in the multi-GPU e2e test +in test_nccl_integration.py. +""" + +from __future__ import annotations + +import threading +import time +from unittest.mock import MagicMock + +import httpx +import pytest +import torch + + +@pytest.fixture +def fsdp_adapter(): + """Build an AwexFSDPAdapter with a fully mocked FSDPEngine.""" + from areal.experimental.weight_update.awex.fsdp_adapter import ( + AwexFSDPAdapter, + ) + + mock_engine = MagicMock() + mock_engine.world_size = 2 + mock_engine.rank = 0 + mock_engine.dp_rank = 0 + mock_engine.data_parallel_world_size = 2 + mock_engine.is_vision_model = False + mock_engine.model_config.model_type = "qwen3" + # _iter_hf_params_local moves CPU-offloaded shards onto _engine.device; + # in unit tests we keep tensors on CPU to avoid touching CUDA. 
+ mock_engine.device = torch.device("cpu") + + adapter = AwexFSDPAdapter.__new__(AwexFSDPAdapter) + adapter._engine = mock_engine + adapter._transfer_plan = None + adapter._weights_update_group = None + adapter._transfer_rank = None + return adapter + + +def test_constructor_exposes_colocate_state_fields(fsdp_adapter): + """Every colocate state field used by the four new methods must exist + after construction so that init/execute can store into them safely.""" + # Use the real constructor path (not __new__) for this test. + from areal.experimental.weight_update.awex.fsdp_adapter import ( + AwexFSDPAdapter, + ) + + mock_engine = MagicMock() + adapter = AwexFSDPAdapter(mock_engine) + + assert isinstance(adapter._colocate_lock, type(threading.Lock())) + assert adapter._colocate_http_client is None + assert adapter._colocate_admin_api_key == "areal-admin-key" + assert adapter._colocate_timeout_s == 120.0 + assert adapter._released_tags == set() + assert adapter._offloaded_weights == {} + + +def test_init_colocate_stores_config_and_creates_http_client(fsdp_adapter): + fsdp_adapter._colocate_lock = threading.Lock() + fsdp_adapter._colocate_http_client = None + fsdp_adapter._colocate_admin_api_key = "areal-admin-key" + fsdp_adapter._colocate_timeout_s = 120.0 + + fsdp_adapter.init_colocate_weight_update( + pair_name="pair_a", + kv_store_url="http://gw:8000", + transfer_rank=1, + infer_world_size=2, + train_world_size=2, + num_engines=1, + master_port=29500, + admin_api_key="custom-key", + timeout_s=60.0, + ) + + assert fsdp_adapter._colocate_pair_name == "pair_a" + assert fsdp_adapter._colocate_kv_store_url == "http://gw:8000" + assert fsdp_adapter._colocate_transfer_rank == 1 + assert fsdp_adapter._colocate_infer_world_size == 2 + assert fsdp_adapter._colocate_admin_api_key == "custom-key" + assert fsdp_adapter._colocate_timeout_s == 60.0 + assert isinstance(fsdp_adapter._colocate_http_client, httpx.Client) + + +def 
test_init_colocate_is_idempotent_for_http_client(fsdp_adapter): + fsdp_adapter._colocate_lock = threading.Lock() + existing = httpx.Client() + fsdp_adapter._colocate_http_client = existing + fsdp_adapter._colocate_admin_api_key = "areal-admin-key" + fsdp_adapter._colocate_timeout_s = 120.0 + + fsdp_adapter.init_colocate_weight_update( + pair_name="pair_b", + kv_store_url="http://gw:8000", + transfer_rank=0, + infer_world_size=1, + train_world_size=1, + num_engines=1, + master_port=29500, + ) + + assert fsdp_adapter._colocate_http_client is existing + existing.close() + + +def test_iter_hf_params_local_yields_local_shards_not_full_tensors(fsdp_adapter): + """Colocate IPC must publish each rank's local DTensor shard, not the + all-gathered full tensor — the awex transfer plan reassembles full + tensors on the inference side via cross-engine P2P slicing.""" + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + local_shard = torch.zeros(4, 4) # half of an [8, 4] global tensor + plain = torch.ones(2, 3) + + mock_dtensor = MagicMock() + mock_dtensor._local_tensor = local_shard + mock_param_dt = MagicMock() + mock_param_dt.data = mock_dtensor + + mock_param_plain = MagicMock() + mock_param_plain.data = plain + + fsdp_adapter._engine.model.named_parameters.return_value = [ + ("model.layers.0.weight", mock_param_dt), + ("model.embed_tokens.weight", mock_param_plain), + ] + fsdp_adapter._engine.device = torch.device("cpu") # for unit test + # Sentinel: must NOT be called by the local-shard path. 
+ fsdp_adapter._engine._get_full_tensor = MagicMock( + side_effect=AssertionError("must not all-gather in colocate publish") + ) + + original_isinstance = isinstance + + def fake_isinstance(obj, cls): + if cls is mod.DTensor: + return obj is mock_dtensor + return original_isinstance(obj, cls) + + import builtins + + builtins.isinstance = fake_isinstance + try: + captured = dict(fsdp_adapter._iter_hf_params_local()) + finally: + builtins.isinstance = original_isinstance + + # The helper detaches its return value to drop autograd; storage is shared + # but the Python object differs, so use data_ptr() to verify identity. + assert captured["model.layers.0.weight"].data_ptr() == local_shard.data_ptr() + assert captured["model.embed_tokens.weight"].data_ptr() == plain.data_ptr() + fsdp_adapter._engine._get_full_tensor.assert_not_called() + + +def test_execute_colocate_puts_weights_then_polls_done(fsdp_adapter, monkeypatch): + """execute_colocate must: PUT weights → poll done key → succeed.""" + fsdp_adapter._colocate_lock = threading.Lock() + fsdp_adapter._colocate_pair_name = "p1" + fsdp_adapter._colocate_kv_store_url = "http://gw:8000" + fsdp_adapter._colocate_transfer_rank = 0 + fsdp_adapter._colocate_admin_api_key = "k" + fsdp_adapter._colocate_timeout_s = 5.0 + fsdp_adapter._released_tags = set() + + plain = torch.ones(2, 3) + mock_param = MagicMock() + mock_param.data = plain + fsdp_adapter._engine.model.named_parameters.return_value = [ + ("model.embed_tokens.weight", mock_param), + ] + + captured_puts: list[tuple] = [] + poll_state = {"calls": 0} + + def mock_put(url, json=None, headers=None, timeout=None): + captured_puts.append((url, json, headers)) + resp = MagicMock() + resp.status_code = 200 + return resp + + def mock_get(url, timeout=None): + poll_state["calls"] += 1 + resp = MagicMock() + # First two polls return 404, third returns 200. 
+ resp.status_code = 200 if poll_state["calls"] >= 3 else 404 + if resp.status_code == 200: + resp.json = lambda: {"value": True} + return resp + + mock_client = MagicMock() + mock_client.put.side_effect = mock_put + mock_client.get.side_effect = mock_get + fsdp_adapter._colocate_http_client = mock_client + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + monkeypatch.setattr( + mod, + "group_tensors_by_shape_and_dtype", + lambda tensors: (list(tensors), [{"shape": (2, 3)}]), + ) + monkeypatch.setattr( + mod, + "cuda_ipc_serialize", + lambda payload: b"\xde\xad\xbe\xef", + ) + # share_memory_ on a CPU tensor is a no-op for unit testing. + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(time, "sleep", lambda s: None) + + fsdp_adapter.execute_colocate_weight_update(version=7) + + assert len(captured_puts) == 1 + url, payload, headers = captured_puts[0] + assert url == "http://gw:8000/weight_meta/p1/colocate_weights_rank0_7" + assert payload == {"value": "deadbeef"} + assert headers == {"Authorization": "Bearer k"} + + assert poll_state["calls"] >= 3 # at least one 404 then a 200 + + +def test_execute_colocate_raises_timeout_when_done_never_appears( + fsdp_adapter, monkeypatch +): + fsdp_adapter._colocate_lock = threading.Lock() + fsdp_adapter._colocate_pair_name = "p" + fsdp_adapter._colocate_kv_store_url = "http://gw:8000" + fsdp_adapter._colocate_transfer_rank = 0 + fsdp_adapter._colocate_admin_api_key = "k" + fsdp_adapter._colocate_timeout_s = 0.05 + fsdp_adapter._released_tags = set() + + fsdp_adapter._engine.model.named_parameters.return_value = [] + + mock_client = MagicMock() + mock_client.put.return_value = MagicMock(status_code=200) + mock_client.get.return_value = MagicMock(status_code=404) + fsdp_adapter._colocate_http_client = mock_client + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + monkeypatch.setattr( + mod, + "group_tensors_by_shape_and_dtype", + lambda tensors: ([], 
[]), + ) + monkeypatch.setattr(mod, "cuda_ipc_serialize", lambda payload: b"") + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(time, "sleep", lambda s: None) + + with pytest.raises(TimeoutError, match="did not signal completion"): + fsdp_adapter.execute_colocate_weight_update(version=1) + + +def test_release_memory_offloads_dtensor_local_to_cpu(fsdp_adapter, monkeypatch): + """release_memory(['weights']) must offload _local_tensor to CPU and + record the original tensor for resume. + + Use MagicMock for local tensors with is_cuda=True — real CPU tensors + cannot have is_cuda reassigned, and production code gates on is_cuda. + """ + cpu_offloaded_dt_local = torch.zeros(2, 4) + mock_dt_local = MagicMock() + mock_dt_local.is_cuda = True + mock_dt_local.dtype = torch.float32 + mock_dt_local.detach.return_value.to.return_value = cpu_offloaded_dt_local + + mock_dtensor = MagicMock() + mock_dtensor._local_tensor = mock_dt_local + mock_param = MagicMock() + mock_param.data = mock_dtensor + + cpu_offloaded_plain = torch.zeros(4, 4) + mock_plain = MagicMock() + mock_plain.is_cuda = True + mock_plain.dtype = torch.float32 + mock_plain.detach.return_value.to.return_value = cpu_offloaded_plain + plain_param = MagicMock() + plain_param.data = mock_plain + + fsdp_adapter._engine.model.named_parameters.return_value = [ + ("model.layers.0.weight", mock_param), + ("model.embed_tokens.weight", plain_param), + ] + + fsdp_adapter._released_tags = set() + fsdp_adapter._offloaded_weights = {} + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + original_isinstance = isinstance + + def fake_isinstance(obj, cls): + if cls is mod.DTensor: + return obj is mock_dtensor + return original_isinstance(obj, cls) + + import builtins + + builtins.isinstance = fake_isinstance + try: + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(torch.cuda, "empty_cache", lambda: None) + 
fsdp_adapter.release_memory(tags=["weights"]) + finally: + builtins.isinstance = original_isinstance + + assert "weights" in fsdp_adapter._released_tags + assert ( + fsdp_adapter._offloaded_weights["model.layers.0.weight"] + is cpu_offloaded_dt_local + ) + assert ( + fsdp_adapter._offloaded_weights["model.embed_tokens.weight"] + is cpu_offloaded_plain + ) + # DTensor: _local_tensor swapped to empty CPU + assert mock_dtensor._local_tensor is not mock_dt_local + # Plain: param.data swapped to empty CPU + assert plain_param.data is not mock_plain + + +def test_release_memory_is_idempotent(fsdp_adapter): + saved = torch.zeros(1) + fsdp_adapter._released_tags = {"weights"} + fsdp_adapter._offloaded_weights = {"x": saved} + fsdp_adapter._engine.model.named_parameters.return_value = [] + + fsdp_adapter.release_memory(tags=["weights"]) + + # Already released → must not re-iterate or reset offloaded dict. + assert fsdp_adapter._released_tags == {"weights"} + assert fsdp_adapter._offloaded_weights == {"x": saved} + + +def test_release_memory_warns_on_unsupported_tags(fsdp_adapter, caplog): + fsdp_adapter._released_tags = set() + fsdp_adapter._offloaded_weights = {} + fsdp_adapter._engine.model.named_parameters.return_value = [] + + fsdp_adapter.release_memory(tags=["optimizer"]) + + # FSDP adapter v1 only supports "weights"; "optimizer" is logged as warning + # and ignored (not added to _released_tags). 
+ assert "optimizer" not in fsdp_adapter._released_tags + + +def test_resume_memory_restores_offloaded_dtensor_local(fsdp_adapter, monkeypatch): + """After release+resume, _local_tensor must hold the original CUDA + tensor (or a copy on _engine.device for our mock).""" + saved = torch.ones(2, 4) + + mock_dtensor = MagicMock() + mock_dtensor._local_tensor = torch.empty(0, dtype=torch.float32, device="cpu") + mock_param = MagicMock() + mock_param.data = mock_dtensor + + plain_saved = torch.ones(2, 3) + plain_param = MagicMock() + plain_param.data = torch.empty(0, dtype=torch.float32, device="cpu") + + fsdp_adapter._engine.model.named_parameters.return_value = [ + ("model.layers.0.weight", mock_param), + ("model.embed_tokens.weight", plain_param), + ] + fsdp_adapter._engine.device = torch.device("cpu") # for unit test + fsdp_adapter._offloaded_weights = { + "model.layers.0.weight": saved, + "model.embed_tokens.weight": plain_saved, + } + fsdp_adapter._released_tags = {"weights"} + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + original_isinstance = isinstance + + def fake_isinstance(obj, cls): + if cls is mod.DTensor: + return obj is mock_dtensor + return original_isinstance(obj, cls) + + import builtins + + builtins.isinstance = fake_isinstance + try: + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + fsdp_adapter.resume_memory(tags=["weights"]) + finally: + builtins.isinstance = original_isinstance + + assert "weights" not in fsdp_adapter._released_tags + assert fsdp_adapter._offloaded_weights == {} + assert torch.equal(mock_dtensor._local_tensor, saved) + assert torch.equal(plain_param.data, plain_saved) + + +def test_resume_memory_is_noop_when_not_released(fsdp_adapter): + fsdp_adapter._released_tags = set() + fsdp_adapter._offloaded_weights = {} + fsdp_adapter._engine.model.named_parameters.return_value = [] + + fsdp_adapter.resume_memory(tags=["weights"]) + + assert fsdp_adapter._released_tags == set() + + +def 
test_execute_colocate_re_releases_on_timeout_when_offloaded( + fsdp_adapter, monkeypatch +): + """If the transfer body raises (TimeoutError, network failure, etc.), + release_memory must still run when weights were originally offloaded — + otherwise _released_tags ends up out of sync with physical state.""" + fsdp_adapter._colocate_lock = threading.Lock() + fsdp_adapter._colocate_pair_name = "p" + fsdp_adapter._colocate_kv_store_url = "http://gw:8000" + fsdp_adapter._colocate_transfer_rank = 0 + fsdp_adapter._colocate_admin_api_key = "k" + fsdp_adapter._colocate_timeout_s = 0.05 + fsdp_adapter._released_tags = {"weights"} # was offloaded before call + fsdp_adapter._offloaded_weights = {} + + fsdp_adapter._engine.model.named_parameters.return_value = [] + + call_order: list[str] = [] + + def track_resume(tags=None): + call_order.append(f"resume:{tags}") + fsdp_adapter._released_tags.discard("weights") + + def track_release(tags=None): + call_order.append(f"release:{tags}") + fsdp_adapter._released_tags.add("weights") + + fsdp_adapter.resume_memory = track_resume + fsdp_adapter.release_memory = track_release + + mock_client = MagicMock() + mock_client.put.return_value = MagicMock(status_code=200) + mock_client.get.return_value = MagicMock(status_code=404) # never 200 → timeout + fsdp_adapter._colocate_http_client = mock_client + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + monkeypatch.setattr( + mod, + "group_tensors_by_shape_and_dtype", + lambda tensors: ([], []), + ) + monkeypatch.setattr(mod, "cuda_ipc_serialize", lambda payload: b"") + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(time, "sleep", lambda s: None) + + with pytest.raises(TimeoutError, match="did not signal completion"): + fsdp_adapter.execute_colocate_weight_update(version=1) + + # Resume ran first, then release ran in finally despite the raise. 
+ # _released_tags ends back at {"weights"} so the offload-state invariant + # is preserved across exceptions. + assert call_order == ["resume:['weights']", "release:['weights']"] + assert "weights" in fsdp_adapter._released_tags + + +def test_execute_colocate_resumes_then_re_releases_when_offloaded( + fsdp_adapter, monkeypatch +): + """If weights were already offloaded, execute_colocate must: + resume → run transfer → release.""" + fsdp_adapter._colocate_lock = threading.Lock() + fsdp_adapter._colocate_pair_name = "p" + fsdp_adapter._colocate_kv_store_url = "http://gw:8000" + fsdp_adapter._colocate_transfer_rank = 0 + fsdp_adapter._colocate_admin_api_key = "k" + fsdp_adapter._colocate_timeout_s = 5.0 + fsdp_adapter._released_tags = {"weights"} # already released + fsdp_adapter._offloaded_weights = {} + + fsdp_adapter._engine.model.named_parameters.return_value = [] + + call_order: list[str] = [] + + def track_resume(tags=None): + call_order.append(f"resume:{tags}") + fsdp_adapter._released_tags.discard("weights") + + def track_release(tags=None): + call_order.append(f"release:{tags}") + fsdp_adapter._released_tags.add("weights") + + fsdp_adapter.resume_memory = track_resume + fsdp_adapter.release_memory = track_release + + mock_client = MagicMock() + mock_client.put.return_value = MagicMock(status_code=200) + done_resp = MagicMock(status_code=200) + done_resp.json = lambda: {"value": True} + mock_client.get.return_value = done_resp + fsdp_adapter._colocate_http_client = mock_client + + import areal.experimental.weight_update.awex.fsdp_adapter as mod + + monkeypatch.setattr( + mod, + "group_tensors_by_shape_and_dtype", + lambda tensors: ([], []), + ) + monkeypatch.setattr(mod, "cuda_ipc_serialize", lambda payload: b"") + monkeypatch.setattr(torch.cuda, "synchronize", lambda: None) + monkeypatch.setattr(time, "sleep", lambda s: None) + + fsdp_adapter.execute_colocate_weight_update(version=3) + + assert call_order == ["resume:['weights']", "release:['weights']"] + + 
+def test_save_parameters_resumes_offloaded_weights_then_re_releases(fsdp_adapter):
+    """Check that ``save_parameters`` brackets the readback with resume/release.
+
+    The adapter starts with ``"weights"`` in ``_released_tags``. The test
+    expects exactly one ``resume_memory`` call followed by one
+    ``release_memory`` call, and a single ``get_local_shard_parameters``
+    lookup for the requested parameter name.
+    """
+    # Function-scope imports: only this test needs them.
+    import tempfile
+    from pathlib import Path
+
+    shard = torch.ones(2, 3)
+    fsdp_adapter._released_tags = {"weights"}
+    fsdp_adapter._offloaded_weights = {}
+
+    call_order: list[str] = []
+
+    def track_resume(tags=None):
+        call_order.append(f"resume:{tags}")
+        fsdp_adapter._released_tags.discard("weights")
+
+    def track_release(tags=None):
+        call_order.append(f"release:{tags}")
+        fsdp_adapter._released_tags.add("weights")
+
+    fsdp_adapter.resume_memory = track_resume
+    fsdp_adapter.release_memory = track_release
+    fsdp_adapter.get_local_shard_parameters = MagicMock(
+        return_value={"model.embed_tokens.weight": shard}
+    )
+
+    # Write into a per-test scratch directory rather than a fixed /tmp path,
+    # so concurrent runs cannot collide and nothing is left behind.
+    with tempfile.TemporaryDirectory() as scratch_dir:
+        fsdp_adapter.save_parameters(
+            str(Path(scratch_dir) / "params.pt"),
+            names=["model.embed_tokens.weight"],
+        )
+
+    assert call_order == ["resume:['weights']", "release:['weights']"]
+    fsdp_adapter.get_local_shard_parameters.assert_called_once_with(
+        ["model.embed_tokens.weight"]
+    )
diff --git a/tests/experimental/weight_update/test_nccl_integration.py b/tests/experimental/weight_update/test_nccl_integration.py
index 8b9751b9fb..7fab567fcc 100644
--- a/tests/experimental/weight_update/test_nccl_integration.py
+++ b/tests/experimental/weight_update/test_nccl_integration.py
@@ -1218,3 +1218,162 @@ def test_awex_megatron_colocate_dp_multi_version_e2e(tmp_path_factory):
         train_ctrl.destroy()
         inf_ctrl.destroy()
         scheduler.delete_workers(None)
+
+
+def _run_fsdp_colocate_e2e(
+    n_gpus: int,
+    pair_name: str,
+    tag: str,
+    tmp_path_factory,
+) -> None:
+    """E2E for colocated FSDPEngine + SGLang on the same n_gpus.
+
+    Mirrors _run_megatron_colocate_e2e but uses FSDPEngine for training.
+    DP-only (FSDP TP=PP=EP=1).
+
+    Flow: start the inference controller, randomize parameters on every
+    inference worker, start the training controller, connect a
+    WeightUpdateController with ``colocate=True``, push version 1, then
+    check that generation still answers and that the weights validate.
+
+    Args:
+        n_gpus: Number of GPUs; used for the scheduler device list and for
+            the ``sglang:d{n_gpus}`` / ``fsdp:d{n_gpus}`` backend strings.
+        pair_name: Name passed to ``WeightUpdateController.connect``.
+        tag: Label used for the temp directory, the local scheduler and the
+            experiment name.
+        tmp_path_factory: Pytest ``tmp_path_factory`` fixture.
+    """
+    from areal.api import FinetuneSpec
+    from areal.api.cli_args import (
+        InferenceEngineConfig,
+        OptimizerConfig,
+        SchedulingSpec,
+        TrainEngineConfig,
+    )
+    from areal.experimental.inference_service.controller.controller import (
+        RolloutControllerV2,
+    )
+    from areal.experimental.training_service.controller.controller import (
+        GatewayTrainController,
+    )
+    from areal.experimental.weight_update.controller import (
+        WeightUpdateController,
+        WeightUpdateControllerConfig,
+    )
+
+    tmp = tmp_path_factory.mktemp(tag)
+    model_path = _get_test_model_path()
+    # Inference and training share this scheduler and the same GPU list.
+    scheduler = _make_local_scheduler(tmp, tag, gpu_devices=list(range(n_gpus)))
+
+    inf_config = InferenceEngineConfig(
+        tokenizer_path=model_path,
+        backend=f"sglang:d{n_gpus}",
+        scheduling_spec=(
+            SchedulingSpec(
+                gpu=1,
+                cmd="python -m areal.experimental.inference_service.guard",
+            ),
+        ),
+        consumer_batch_size=8,
+        max_head_offpolicyness=1024,
+        setup_timeout=300.0,
+        admin_api_key="test-admin",
+    )
+    inf_ctrl = RolloutControllerV2(config=inf_config, scheduler=scheduler)
+
+    train_config = TrainEngineConfig(
+        backend=f"fsdp:d{n_gpus}",
+        experiment_name=f"test-awex-{tag}",
+        trial_name="t0",
+        path=model_path,
+        optimizer=OptimizerConfig(),
+        _version="v2",
+        setup_timeout=300.0,
+        scheduling_spec=(
+            SchedulingSpec(
+                gpu=1,
+                cmd="python -m areal.experimental.training_service.guard",
+                # NOTE(review): NCCL cuMem/NVLS are switched off for the
+                # training workers; the reason is not visible here.
+                env_vars=dict(NCCL_CUMEM_ENABLE="0", NCCL_NVLS_ENABLE="0"),
+            ),
+        ),
+    )
+    train_ctrl = GatewayTrainController(
+        train_engine="areal.engine.fsdp_engine.FSDPEngine",
+        config=train_config,
+        scheduler=scheduler,
+    )
+
+    # Created inside the try block; stays None if setup fails before that
+    # point, so the finally clause knows whether there is anything to destroy.
+    wu_ctrl: WeightUpdateController | None = None
+    try:
+        inf_ctrl.initialize(
+            role="rollout",
+            server_args={"model_path": model_path, "mem_fraction_static": 0.7},
+            wait=True,
+        )
+        inf_worker_urls = list(inf_ctrl._inf_addrs)
+
+        # Randomize the inference-side parameters first (presumably so that a
+        # no-op weight update cannot pass the validation below).
+        for url in inf_worker_urls:
+            resp = httpx.post(f"{url}/awex/debug/randomize_parameters", timeout=120.0)
+            assert resp.status_code == 200, f"randomize_parameters failed: {resp.text}"
+
+        train_ctrl.initialize(
+            role="actor",
+            ft_spec=FinetuneSpec(
+                total_train_epochs=1, dataset_size=100, train_batch_size=2
+            ),
+            wait=True,
+        )
+        train_worker_urls = list(train_ctrl._worker_addrs)
+
+        wu_ctrl = WeightUpdateController(
+            config=WeightUpdateControllerConfig(host="127.0.0.1", request_timeout=300.0)
+        )
+        wu_ctrl.initialize()
+        assert wu_ctrl.health_check()
+
+        wu_ctrl.connect(
+            pair_name=pair_name,
+            train_worker_urls=train_worker_urls,
+            inference_worker_urls=inf_worker_urls,
+            colocate=True,
+        )
+
+        result = wu_ctrl.update_weights(version=1)
+        assert result.status == "ok"
+        assert result.version == 1
+
+        wu_ctrl.disconnect()
+
+        # Smoke check: the first inference worker must still generate.
+        gen_resp = httpx.post(
+            f"{inf_worker_urls[0]}/generate",
+            json={
+                "text": "Hello",
+                "sampling_params": {"max_new_tokens": 5, "temperature": 0},
+            },
+            timeout=30.0,
+        )
+        assert gen_resp.status_code == 200, (
+            f"Generation failed after weight update: {gen_resp.text}"
+        )
+
+        _validate_weight_update_correctness(
+            train_worker_urls=train_worker_urls,
+            inf_worker_url=inf_worker_urls[0],
+            param_dir=tmp,
+        )
+    finally:
+        # Tear down: weight-update controller, training, inference, workers.
+        if wu_ctrl is not None:
+            wu_ctrl.destroy()
+        train_ctrl.destroy()
+        inf_ctrl.destroy()
+        scheduler.delete_workers(None)
+
+
+@pytest.mark.multi_gpu
+@pytest.mark.slow
+@pytest.mark.sglang
+@pytest.mark.parametrize("n_gpus", [2, 4, 8], ids=["2gpu", "4gpu", "8gpu"])
+def test_awex_fsdp_colocate_dp_e2e_weight_update(n_gpus, tmp_path_factory):
+    """Full round trip: colocated FSDPEngine (pure DP) + SGLang on same GPUs.
+
+    DP-only, mirrors test_awex_megatron_colocate_dp_e2e_weight_update.
+    Verifies that FSDPEngine._get_full_tensor() materialization + CUDA-IPC
+    transfer produces parameters on the inference side that match the
+    training side bit-exactly (via FSDP shard-concat validation).
+
+    NOTE(review): the commit message describes publishing per-rank local
+    shards over CUDA IPC; confirm the ``_get_full_tensor()`` wording above.
+
+    Skipped when fewer than ``n_gpus`` devices are available.
+    """
+    if current_platform.device_count() < n_gpus:
+        pytest.skip(f"This test requires {n_gpus} GPUs")
+    _run_fsdp_colocate_e2e(
+        n_gpus=n_gpus,
+        pair_name=f"test_fsdp_colocate_dp{n_gpus}",
+        tag=f"fsdp_colocate_dp{n_gpus}",
+        tmp_path_factory=tmp_path_factory,
+    )