diff --git a/.github/workflows/claude-code-review.yml b/.github/workflows/claude-code-review.yml new file mode 100644 index 0000000000..4f6145beb8 --- /dev/null +++ b/.github/workflows/claude-code-review.yml @@ -0,0 +1,44 @@ +name: Claude Code Review + +on: + pull_request: + types: [opened, synchronize, ready_for_review, reopened] + # Optional: Only run on specific file changes + # paths: + # - "src/**/*.ts" + # - "src/**/*.tsx" + # - "src/**/*.js" + # - "src/**/*.jsx" + +jobs: + claude-review: + # Optional: Filter by PR author + # if: | + # github.event.pull_request.user.login == 'external-contributor' || + # github.event.pull_request.user.login == 'new-developer' || + # github.event.pull_request.author_association == 'FIRST_TIME_CONTRIBUTOR' + + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code Review + id: claude-review + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + plugin_marketplaces: 'https://github.com/anthropics/claude-code.git' + plugins: 'code-review@claude-code-plugins' + prompt: '/code-review:code-review ${{ github.repository }}/pull/${{ github.event.pull_request.number }}' + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://code.claude.com/docs/en/cli-reference for available options + diff --git a/.github/workflows/claude.yaml b/.github/workflows/claude.yaml new file mode 100644 index 0000000000..4f06a94cce --- /dev/null +++ b/.github/workflows/claude.yaml @@ -0,0 +1,29 @@ +name: Claude + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened] + pull_request_review: + types: [submitted] + +jobs: + claude: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + 
(github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && contains(github.event.issue.body, '@claude')) + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + issues: write + id-token: write + steps: + - uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} \ No newline at end of file diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml new file mode 100644 index 0000000000..4848be3671 --- /dev/null +++ b/.github/workflows/claude.yml @@ -0,0 +1,50 @@ +name: Claude Code + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +jobs: + claude: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude'))) + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + actions: read # Required for Claude to read CI results on PRs + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude Code + id: claude + uses: anthropics/claude-code-action@v1 + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + + # This is an optional setting that allows Claude to read CI results on PRs + additional_permissions: | + actions: read + + # Optional: Give a custom prompt to Claude. 
If this is not specified, Claude will perform the instructions specified in the comment that tagged it. + # prompt: 'Update the pull request description to include a summary of changes.' + + # Optional: Add claude_args to customize behavior and configuration + # See https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md + # or https://code.claude.com/docs/en/cli-reference for available options + # claude_args: '--allowed-tools Bash(gh pr *)' + diff --git a/.github/workflows/config/.secrets.baseline b/.github/workflows/config/.secrets.baseline index a502b59eb6..555dd017bd 100644 --- a/.github/workflows/config/.secrets.baseline +++ b/.github/workflows/config/.secrets.baseline @@ -131,7 +131,7 @@ "filename": "docs/testing.md", "hashed_secret": "3f3b8ce7c4fec509b2b74ee3e1d98170278ffe4b", "is_verified": false, - "line_number": 116 + "line_number": 113 } ], "tests/unit/test_version_check.py": [ @@ -144,5 +144,5 @@ } ] }, - "generated_at": "2026-02-24T15:55:12Z" + "generated_at": "2026-04-02T18:53:27Z" } diff --git a/3rdparty/Megatron-LM-workspace/setup.py b/3rdparty/Megatron-LM-workspace/setup.py index d6339e726a..cd18f5322d 100644 --- a/3rdparty/Megatron-LM-workspace/setup.py +++ b/3rdparty/Megatron-LM-workspace/setup.py @@ -51,7 +51,7 @@ # TODO(https://github.com/NVIDIA-NeMo/RL/issues/2111): upgrade to core_cu13 when we move to CUDA 13 base container "transformer-engine[pytorch,core_cu12]", # VCS dependency - must match pyproject.toml [tool.uv.sources] - "nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@63154570cea17f8805a7fd15cc3b8cc2919ba575", + "nvidia-resiliency-ext @ git+https://github.com/NVIDIA/nvidia-resiliency-ext.git@15a851565f06e279f18c3ac5e1296b1bcb63be24", "tqdm", "einops~=0.8", "tensorstore~=0.1,!=0.1.46,!=0.1.72", diff --git a/docs/about/algorithms/index.md b/docs/about/algorithms/index.md index 9f4bec628b..fe88e2636b 100644 --- a/docs/about/algorithms/index.md +++ b/docs/about/algorithms/index.md @@ -7,12 
+7,12 @@ NeMo RL supports multiple training algorithms for post-training large language m | Algorithms | Single Node | Multi-node | |------------|-------------|------------| | [GRPO](grpo.md) | [GRPO Single Node](grpo.md#grpo-single-node) | [GRPO Multi-node](grpo.md#grpo-multi-node): [GRPO Qwen2.5-32B](grpo.md#grpo-qwen25-32b), [GRPO Multi-Turn](grpo.md#grpo-multi-turn) | -|DAPO (dapo.md)| similar to GRPO example| similar to GRPO example| | [DAPO](dapo.md) | [DAPO Single Node](dapo.md#dapo-single-node) | [DAPO Multi-node](dapo.md#dapo-multi-node) | | [On-policy Distillation](on-policy-distillation.md) | [Distillation Single Node](on-policy-distillation.md#on-policy-distillation-single-node) | [Distillation Multi-node](on-policy-distillation.md#on-policy-distillation-multi-node) | | [Supervised Fine-Tuning (SFT)](sft.md) | [SFT Single Node](sft.md#sft-single-node) | [SFT Multi-node](sft.md#sft-multi-node) | | [DPO](dpo.md) | [DPO Single Node](dpo.md#dpo-single-node) | [DPO Multi-node](dpo.md#dpo-multi-node) | | [RM](rm.md) | [RM Single Node](rm.md#rm-single-node) | [RM Multi-node](rm.md#rm-multi-node) | + On-policy distillation is also supported in the PyTorch DTensor path. 
```{toctree} :maxdepth: 2 diff --git a/docs/about/model-support.md b/docs/about/model-support.md index 0d49606031..6454a315db 100644 --- a/docs/about/model-support.md +++ b/docs/about/model-support.md @@ -2,7 +2,7 @@ ## Broad coverage for 🤗Hugging Face models via [NeMo AutoModel](https://github.com/NVIDIA-NeMo/Automodel) -NeMo-RL support 🤗Hugging Face models from the following classes +NeMo-RL supports 🤗Hugging Face models from the following classes - LLMs ([AutoModelForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModelForCausalLM)) - VLMs ([AutoModelForImageTextToText](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoModelForImageTextToText)) diff --git a/docs/about/performance-summary.md b/docs/about/performance-summary.md index 05f55d8862..61ece32d23 100644 --- a/docs/about/performance-summary.md +++ b/docs/about/performance-summary.md @@ -1,7 +1,7 @@ # Performance -As part of the NVIDIA NeMo Framework, NeMo RL, provides optimal performance for reinforcement learning on generative AI models by incorporating the latest optimizations - such as refit optimizations, mixed-precision training, and off-policy training. +As part of the NVIDIA NeMo Framework, NeMo RL provides optimal performance for reinforcement learning on generative AI models by incorporating the latest optimizations - such as refit optimizations, mixed-precision training, and off-policy training. This page provides performance benchmarks for LLMs and VLMs using NeMo RL across different GPU systems and configurations. The recipes to reproduce these runs, in yaml file form, can be found under [this folder](https://github.com/NVIDIA-NeMo/RL/tree/r0.5.0/examples/configs/recipes/llm/performance). 
@@ -16,13 +16,13 @@ This page provides performance benchmarks for LLMs and VLMs using NeMo RL across - **EP**: Expert Parallel Size - **T-**: Training related - **G-**: Generation related -- **Training backend**: NeMo RL have two training backends: Megatron and PyTorch DTensor. This performance summary currently only shows number from Megatron backend. +- **Training backend**: NeMo RL has two training backends: Megatron and PyTorch DTensor. This performance summary currently only shows numbers from the Megatron backend. ## Performance Metrics Since reinforcement learning consists of training, generation and transition between the two, performance measurement also reflects this. Specifically, we track the following metrics: - **Step time**: Time for each step, which includes training, generation, policy logprobs, and refit time. -- **Tokens/sec/GPU**: The rate at the tokens are processed by a stage (such as training, generation, or refitting) on a single GPU: +- **Tokens/sec/GPU**: The rate at which the tokens are processed by a stage (such as training, generation, or refitting) on a single GPU: $$ \text{Tokens/sec/GPU} = \frac{\text{Total Tokens Processed}}{\text{Time for Stage} \times \text{Number of GPUs}} @@ -98,4 +98,4 @@ The performance data includes: Note: * All Mixture-of-expert (MoE) model training uses token drop-less. -* The following metrics are extracted from the average of 5 steps: G-Average Seq len, Tokens/sec/gpu, Total Step time(s). Because of the averaging, the numbers in table does not completely match the equation stated in Performance Metrics above but the difference is small. +* The following metrics are extracted from the average of 5 steps: G-Average Seq len, Tokens/sec/gpu, Total Step time(s). Because of the averaging, the numbers in the table do not completely match the equation stated in Performance Metrics above but the difference is small. 
diff --git a/docs/debugging.md b/docs/debugging.md index cd3b55d354..02b4cfc1be 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -33,7 +33,7 @@ The first node is always the head node, so we need to port forward the dashboard # on the login node is likely taken by someone else. ssh -L $LOCAL_PORT:localhost:$DASHBOARD_PORT -N node-12 -# Example chosing a port other than 8265 for the LOCAL_PORT +# Example choosing a port other than 8265 for the LOCAL_PORT ssh -L 52640:localhost:8265 -N node-12 ``` diff --git a/docs/fp8.md b/docs/fp8.md index 19e7a86d6a..975efc735f 100644 --- a/docs/fp8.md +++ b/docs/fp8.md @@ -93,4 +93,4 @@ The above results are from Llama-3.1-8B-Instruct GRPO experiments. You can run t * For FP8: `examples/configs/grpo_math_8B_megatron_fp8.yaml` In the experiment in this figure, enabling FP8 rollout and training gives 15%-25% decrease in step time, and the validation accuracy curves match up to 1000 steps. -Efforts are ongoing to performs longer runs and further optimize performance. +Efforts are ongoing to perform longer runs and further optimize performance. diff --git a/docs/nsys-profiling.md b/docs/nsys-profiling.md index 7931a9cc95..dfbf085786 100644 --- a/docs/nsys-profiling.md +++ b/docs/nsys-profiling.md @@ -22,7 +22,7 @@ export NRL_NSYS_WORKER_PATTERNS="*policy*,*vllm*" Set the `NRL_NSYS_PROFILE_STEP_RANGE` environment variable to control which training steps the profiler captures. Its format is colon separated integers representing `start:stop`, where `start` is inclusive and `stop` is exclusive -(same as slice syntax `arr[start:stop]`). Note that the `start` is 1-index, so `NRL_NSYS_PROFILE_STEP_RANGE=0:10` would error. +(same as slice syntax `arr[start:stop]`). Note that the `start` is 1-indexed, so `NRL_NSYS_PROFILE_STEP_RANGE=0:10` would error. 
```bash export NRL_NSYS_PROFILE_STEP_RANGE=3:5 diff --git a/docs/testing.md b/docs/testing.md index 984054aad4..e1c9d26773 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -55,7 +55,6 @@ Limitations and tips: - The remote-aware selection uses a conservative static import map (no dynamic import resolution). If a test loads code dynamically that isn’t visible via imports, you may need to run it explicitly once to seed the map. - The helper is test-only and does not alter library behavior. It activates automatically when you pass `--testmon`. -Refreshing remote-selection artifacts ### Refreshing Remote-Selection Artifacts If you change test layout or significantly refactor imports, the remote-selection artifacts may become stale. To rebuild them, delete the following files at the repo root and re-run with `--testmon` to seed again: @@ -68,9 +67,7 @@ rm .nrl_remote_map.json .nrl_remote_state.json ### Run Unit Tests in a Hermetic Environment -For environments lacking necessary dependencies (e.g., `gcc`, `nvcc`) -or where environmental configuration may be problematic, tests can be run -in Docker with this script: +For environments lacking necessary dependencies (e.g., `gcc`, `nvcc`) or where environmental configuration may be problematic, tests can be run in Docker with this script: ```sh CONTAINER=... bash tests/run_unit_in_docker.sh @@ -155,7 +152,6 @@ Functional tests are located under `tests/functional/`. uv run bash tests/functional/sft.sh ``` -At the end of each functional test, the metric checks will be printed as well as At the end of each functional test, the metric checks will be printed as well as whether they pass or fail. 
Here is an example: ```text @@ -169,8 +165,6 @@ At the end of each functional test, the metric checks will be printed as well as ### Run Functional Tests in a Hermetic Environment -For environments lacking necessary dependencies (e.g., `gcc`, `nvcc`) -or where environmental configuration may be problematic, tests can be run For environments lacking necessary dependencies (e.g., `gcc`, `nvcc`) or where environmental configuration may be problematic, tests can be run in Docker with this script: ```sh diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 85953eb0ce..4b3e9be881 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -2515,9 +2515,21 @@ def async_grpo_train( }, } + # Register trajectory collector as a named Ray actor so the rlix pipeline can + # look it up for set_weight_version calls (spec: nemorl-port-plan.md lines 490, 538, 603). + _rlix_pipeline_id = os.environ.get("PIPELINE_ID", "") + _rlix_ray_namespace = os.environ.get("ROLL_RAY_NAMESPACE", "") + _tc_name = ( + f"rlix:trajectory_collector:{_rlix_pipeline_id}" + if _rlix_pipeline_id + else None + ) + _tc_namespace = _rlix_ray_namespace if _rlix_ray_namespace else None + # Initialize trajectory collector with synchronized collection trajectory_collector = AsyncTrajectoryCollector.options( - runtime_env=_tc_runtime_env + runtime_env=_tc_runtime_env, + **({"name": _tc_name, "namespace": _tc_namespace} if _tc_name else {}), ).remote( policy_generation=policy_generation, tokenizer=tokenizer, diff --git a/nemo_rl/models/generation/vllm/vllm_backend.py b/nemo_rl/models/generation/vllm/vllm_backend.py index 9237788be1..c8fd8e3131 100644 --- a/nemo_rl/models/generation/vllm/vllm_backend.py +++ b/nemo_rl/models/generation/vllm/vllm_backend.py @@ -309,6 +309,293 @@ def _load_model_weights(weights, model_runner): return True + # ------------------------------------------------------------------ + # rlix integration: selective sync receiver methods (Feature 4) + # 
------------------------------------------------------------------ + + def setup_collective_group( + self, + model_update_name: str, + comm_plan: dict, + mode: str, + timeout_s: float | None = None, + ) -> None: + """Join a dynamic NCCL group for selective model weight broadcast. + + Stores the group in ``self._model_update_groups[group_name]``. + + Args: + model_update_name: Unique sync identifier. + comm_plan: Communication plan with master_addr/port/world_size. + mode: 'receiver' (inference workers are always receivers). + timeout_s: Optional NCCL init timeout in seconds (unused; StatelessProcessGroup uses its own timeout). + """ + from nemo_rl.distributed.stateless_process_group import StatelessProcessGroup + + if not hasattr(self, "_model_update_groups"): + self._model_update_groups: dict = {} # pyrefly: ignore[implicitly-defined-attribute] + + plan_entry = comm_plan[next(iter(comm_plan))] + group_name: str = plan_entry["group_name"] + master_addr: str = plan_entry["master_addr"] + master_port: int = int(plan_entry["master_port"]) + tgt_devices: list = plan_entry.get("tgt_devices", []) + world_size = 1 + len(tgt_devices) + + local_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + rank = 1 + for i, dev in enumerate(tgt_devices): + if int(dev.get("rank", -1)) == local_rank: + rank = i + 1 + break + + pg = StatelessProcessGroup( + master_address=master_addr, + port=master_port, + rank=rank, + world_size=world_size, + ) + pg.init_nccl_communicator(device=self.device) + self._model_update_groups[group_name] = pg + + def update_parameter_in_bucket( + self, + payload: dict, + ipc_local_ranks: list[int], + model_update_transport: str, + is_lora: bool = False, + ) -> None: + """Receive a packed weight bucket and load it into the model. + + Two transport modes (spec: nemorl-port-plan.md lines 316-323, 344-345): + + ``"cpu_serialize"`` โ€” payload contains ``cpu_uint8_bucket`` (CPU uint8 tensor). 
+ DMA copies the buffer to GPU, then unpacks per-param tensors. + Used for cross-GPU or containerized deployments where CUDA IPC is unavailable. + + ``"cuda_ipc"`` โ€” payload contains ``cuda_ipc_handle`` (CUDA IPC handle tuple). + Rebuilds the GPU tensor directly via ``rebuild_cuda_tensor_from_ipc()`` + (zero-copy for same-physical-GPU colocated processes). + Required when sender and receiver share a physical GPU; NCCL cannot + form a group between two ranks on the same GPU (spec line 316). + + Args: + payload: Transport dict. cpu_serialize: keys ``param_names``, ``shapes``, + ``dtypes``, ``offsets``, ``used_bytes``, ``cpu_uint8_bucket``. + cuda_ipc: same keys but ``cuda_ipc_handle`` instead of ``cpu_uint8_bucket``. + ipc_local_ranks: Infer-local ranks that should process this bucket. + model_update_transport: ``"cpu_serialize"`` or ``"cuda_ipc"``. + is_lora: Reserved for LoRA adapter weights (not yet used). + """ + # Use the vLLM worker's own rank, not the distributed process-group rank. + # ipc_local_ranks contains LOCAL ranks within the worker (comm-plan contract, + # spec nemorl-port-plan.md:406-412); torch.distributed.get_rank() would be + # the wrong identity in multi-node setups. (Matches ROLL worker.py:757.) + local_rank = getattr(self, "rank", None) + if local_rank is None: + # Fallback for workers that don't expose .rank โ€” use distributed rank. + local_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + if local_rank not in ipc_local_ranks: + return + + from rlix.pipeline.bucket_cache import BucketRecord, unpack_bucket_record + + # --- Reconstruct named tensors from transport payload --- + if model_update_transport == "cuda_ipc": + # Zero-copy: sender and receiver share the same physical GPU. + # Rebuild GPU tensor from IPC handle โ€” no CPU roundtrip. + # Spec line 316: NCCL cannot form a group on the same GPU; IPC is required. 
+ from torch.multiprocessing.reductions import rebuild_cuda_tensor + device_id = self.device.index if hasattr(self.device, "index") else 0 + ipc_args = list(payload["cuda_ipc_handle"][0]) + ipc_args[6] = device_id # patch device index + buf_gpu = rebuild_cuda_tensor(*ipc_args) + torch.cuda.current_stream().synchronize() + + # Reconstruct named tensors directly from GPU buffer (no CPU copy) + weights = [] + for name, shape, dtype, offset in zip( + payload["param_names"], payload["shapes"], + payload["dtypes"], payload["offsets"] + ): + num_elements = 1 + for s in shape: + num_elements *= s + nbytes = num_elements * torch.empty(0, dtype=dtype).element_size() + flat = buf_gpu[offset : offset + nbytes].view(dtype) + weights.append((name, flat.reshape(shape))) + else: + # cpu_serialize: DMA copy pinned CPU buffer โ†’ GPU, then unpack + buf_gpu = payload["cpu_uint8_bucket"].pin_memory().to(self.device, non_blocking=True) + torch.cuda.current_stream().synchronize() + + record = BucketRecord( + param_names=payload["param_names"], + shapes=payload["shapes"], + dtypes=payload["dtypes"], + offsets=payload["offsets"], + used_bytes=payload["used_bytes"], + cpu_uint8_bucket=buf_gpu.cpu(), + ) + named_tensors = unpack_bucket_record(record) + weights = [(name, t.to(self.device)) for name, t in named_tensors] + + from nemo_rl.models.generation.vllm.quantization import fp8 + + policy_weights, draft_weights = self._split_policy_and_draft_weights(weights) + if fp8.is_fp8_model(self.model_runner.vllm_config): + fp8.load_weights(policy_weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=policy_weights) + self._load_draft_weights(draft_weights) + torch.cuda.current_stream().synchronize() + del buf_gpu, weights, policy_weights, draft_weights + + def broadcast_parameter( + self, + group_name: str, + names: list[str], + dtypes: list, + shapes: list, + broadcast_local_ranks: list[int], + is_lora: bool = False, + ) -> None: + """Receive a packed NCCL broadcast 
and load weights into the model. + + Reuses the ``packed_broadcast_consumer`` pattern from + ``update_weights_from_collective`` (vllm_backend.py:294โ€“303). + + Args: + group_name: NCCL group name created by ``setup_collective_group``. + names: HF param names in order (matches sender's bucket). + dtypes: Per-param dtypes. + shapes: Per-param shapes. + broadcast_local_ranks: Infer-local ranks that participate. + Ranks not in this list return immediately (guard). + is_lora: If True, payload contains LoRA adapter weights (reserved, + not yet used โ€” base weights always use False). + """ + if not hasattr(self, "_model_update_groups"): + return + if group_name not in self._model_update_groups: + return + + # Use the vLLM worker's own local rank (same convention as + # update_parameter_in_bucket above). torch.distributed.get_rank() + # returns the global rank in WORLD; broadcast_local_ranks carries + # LOCAL ranks within the worker, so under TP>1 / multi-node the + # global rank never matches and every receiver silently + # early-returns. + local_rank = getattr(self, "rank", None) + if local_rank is None: + local_rank = ( + torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + ) + if local_rank not in broadcast_local_ranks: + return + + group = self._model_update_groups[group_name] + + # Calculate total buffer size (aligned, same as sender packing). 
+ from nemo_rl.models.policy.utils import calculate_aligned_size + + total_bytes = 0 + for name, shape, dtype in zip(names, shapes, dtypes): + num_elements = 1 + for s in shape: + num_elements *= s + nbytes = num_elements * torch.empty(0, dtype=dtype).element_size() + total_bytes = calculate_aligned_size(total_bytes + nbytes) + + recv_buf = torch.zeros(total_bytes, dtype=torch.uint8, device=self.device) + group.broadcast(recv_buf, src=0) + + weights = [] + offset = 0 + for name, shape, dtype in zip(names, shapes, dtypes): + num_elements = 1 + for s in shape: + num_elements *= s + nbytes = num_elements * torch.empty(0, dtype=dtype).element_size() + flat = recv_buf[offset : offset + nbytes].view(dtype) + weights.append((name, flat.reshape(shape))) + offset = calculate_aligned_size(offset + nbytes) + + from nemo_rl.models.generation.vllm.quantization import fp8 + + policy_weights, draft_weights = self._split_policy_and_draft_weights(weights) + if fp8.is_fp8_model(self.model_runner.vllm_config): + fp8.load_weights(policy_weights, self.model_runner) + else: + self.model_runner.model.load_weights(weights=policy_weights) + self._load_draft_weights(draft_weights) + torch.cuda.current_stream().synchronize() + del recv_buf, weights, policy_weights, draft_weights + + def destroy_collective_group(self, group_name: str) -> None: + """Destroy a dynamic NCCL group. + + No-op guard: IPC-only ranks never join the NCCL group, so + ``group_name`` may not be present. + + Args: + group_name: Group name as used in ``setup_collective_group``. + """ + import torch.distributed as dist + + if not hasattr(self, "_model_update_groups"): + return + if group_name not in self._model_update_groups: + return + pg = self._model_update_groups.pop(group_name) + try: + dist.destroy_process_group(pg) + except Exception: + pass + + def verify_model(self, expected_stats: dict) -> None: + """Verify model weights match expected statistics after sync. 
+ + Args: + expected_stats: Dict with keys ``sum``, ``max``, ``min`` computed + by the sender over all weight tensors. + + Raises: + RuntimeError: If any statistic deviates from expected by > 1e-3. + """ + state_dict = self.model_runner.model.state_dict() + vals = [t.float() for t in state_dict.values() if t.numel() > 0] + if not vals: + return + all_flat = torch.cat([v.flatten() for v in vals]) + actual = { + "sum": float(all_flat.sum()), + "max": float(all_flat.max()), + "min": float(all_flat.min()), + } + tol = 1e-3 + for key in ("sum", "max", "min"): + if key not in expected_stats: + continue + if abs(actual[key] - expected_stats[key]) > tol * (abs(expected_stats[key]) + 1.0): + raise RuntimeError( + f"verify_model: {key} mismatch: " + f"expected={expected_stats[key]:.6f} actual={actual[key]:.6f}" + ) + + def finalize_weight_update(self) -> None: + """Run post-loading weight processing (FP8 KV cache, etc.). + + Must be called after all buckets have been loaded via + ``update_parameter_in_bucket`` or ``broadcast_parameter``. 
+ """ + from vllm.model_executor.model_loader.utils import process_weights_after_loading + + process_weights_after_loading( + self.model_runner.model, self.model_config, self.device + ) + self._maybe_process_fp8_kv_cache() + def cleanup(self) -> None: """Shutdown and cleanup resources.""" # Close ZMQ socket and context if they exist diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 0faaad17a1..a75718322f 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -851,6 +851,122 @@ def update_weights_from_collective(self) -> list[ray.ObjectRef]: # this function should co-work with lm_policy, so we should wait for all futures to complete outside return futures + # ------------------------------------------------------------------ + # rlix integration: selective sync receiver pass-throughs (Feature 4) + # ------------------------------------------------------------------ + # + # Every receiver pass-through below MUST dispatch to ALL TP/PP ranks + # (run_rank_0_only_axes=[]). NCCL collectives require every + # participating rank to call init_process_group / join the group; if + # we filter to TP rank 0 only, ranks 1..N-1 never join โ†’ all + # subsequent collectives (broadcast, finalize) deadlock or assert. + + def setup_collective_group( + self, + model_update_name: str, + comm_plan: dict, + mode: str, + timeout_s: float | None = None, + ) -> None: + """Pass-through: join NCCL group on all infer workers. + + Awaits sub-worker futures before returning so the caller's ray.get() + correctly barriers on completion (spec: nemorl-port-plan.md phase barriers). 
+ """ + futures = self.worker_group.run_all_workers_single_data( + "setup_collective_group", + model_update_name=model_update_name, + comm_plan=comm_plan, + mode=mode, + timeout_s=timeout_s, + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + + def update_parameter_in_bucket( + self, + payload: dict, + ipc_local_ranks: list[int], + model_update_transport: str, + is_lora: bool = False, + ) -> None: + """Pass-through: receive a packed weight bucket on IPC-local workers. + + Awaits sub-worker futures so caller ray.get() barriers on weight load completion. + """ + futures = self.worker_group.run_all_workers_single_data( + "update_parameter_in_bucket", + payload=payload, + ipc_local_ranks=ipc_local_ranks, + model_update_transport=model_update_transport, + is_lora=is_lora, + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + + def broadcast_parameter( + self, + group_name: str, + names: list[str], + dtypes: list, + shapes: list, + broadcast_local_ranks: list[int], + is_lora: bool = False, + ) -> None: + """Pass-through: receive NCCL broadcast and load weights. + + Awaits sub-worker futures so caller ray.get() barriers on weight load completion. + """ + futures = self.worker_group.run_all_workers_single_data( + "broadcast_parameter", + group_name=group_name, + names=names, + dtypes=dtypes, + shapes=shapes, + broadcast_local_ranks=broadcast_local_ranks, + is_lora=is_lora, + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + + def destroy_collective_group(self, group_name: str) -> None: + """Pass-through: destroy NCCL group on all infer workers (no-op for non-members). + + Awaits sub-worker futures so caller ray.get() confirms teardown. 
+ """ + futures = self.worker_group.run_all_workers_single_data( + "destroy_collective_group", + group_name=group_name, + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + + def verify_model(self, expected_stats: dict) -> None: + """Pass-through: verify weight stats on infer workers.""" + futures = self.worker_group.run_all_workers_single_data( + "verify_model", + expected_stats=expected_stats, + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + + def finalize_weight_update(self) -> None: + """Pass-through: run post-load weight processing on all infer workers. + + Awaits sub-worker futures so caller ray.get() confirms finalization. + """ + futures = self.worker_group.run_all_workers_single_data( + "finalize_weight_update", + run_rank_0_only_axes=[], + ) + if futures: + ray.get(futures) + def start_gpu_profiling(self) -> None: """Start GPU profiling.""" futures = self.worker_group.run_all_workers_single_data("start_gpu_profiling") diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 8df5e1f15c..bb1ebae172 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -1109,6 +1109,435 @@ def broadcast_weights_for_collective( post_iter_func=lambda x: x[1], ) + # ------------------------------------------------------------------ + # rlix integration: CPU bucket cache support (Feature 4) + # ------------------------------------------------------------------ + # + # Module-level helpers defined immediately below the class are used + # by the rlix integration methods. They are module-level (not class + # methods) so they can be tested without constructing a full worker. + # + # Two-pointer versioning mirrors ROLL megatron_strategy.py:1049โ€“1065: + # build_latest_bucket_cache(v) โ€” called after train_step; all PP ranks + # participate in the collective gather. 
+ # promote_active_checkpoint(v) โ€” called by BucketCacheLifecycle.promote() + # to atomically switch the active pointer. + # + # selective_sync_active_cache() โ€” called by ModelUpdateService on the + # owner rank to transport buckets to infer + # workers (IPC or NCCL per bucket). + # ------------------------------------------------------------------ + + def _rlix_is_cache_owner(self) -> bool: + """Return True only for the single rank that builds/holds the cache.""" + return ( + parallel_state.is_pipeline_first_stage() + and parallel_state.get_tensor_model_parallel_rank() == 0 + and parallel_state.get_data_parallel_rank() == 0 + and parallel_state.get_context_parallel_rank() == 0 + ) + + def _rlix_get_versioned_cache(self): + """Lazy-init and return the per-worker VersionedBucketCache.""" + import threading as _threading + + from rlix.pipeline.bucket_cache import VersionedBucketCache + + if not hasattr(self, "_rlix_versioned_cache"): + self._rlix_versioned_cache = VersionedBucketCache() + self._rlix_model_update_groups: dict = {} + self._rlix_cache_init_lock = _threading.Lock() + return self._rlix_versioned_cache + + @torch.no_grad() + def build_latest_bucket_cache(self, checkpoint_version: int) -> None: + """Gather all HF weights into CPU bucket cache as a new 'latest' version. + + ALL PP/TP/EP ranks must participate simultaneously to keep the Megatron + PP collective alive. Only the cache owner (pp_rank==0, dp_rank==0, + tp_rank==0, cp_rank==0) stores the resulting List[BucketRecord]. + Non-owners drain the generator and return. + + Called by the pipeline after train_step returns. Equivalent to + ROLL worker.py:363 build_latest_bucket_cache. + + Args: + checkpoint_version: Step number (or -1 for base model). 
+ """ + import logging + + from rlix.pipeline.bucket_cache import _bucket_named_tensors + + logger = logging.getLogger(__name__) + checkpoint_version = int(checkpoint_version) + is_owner = self._rlix_is_cache_owner() + + # Ensure refit_conversion_tasks is populated (needed by the iterator). + self.prepare_refit_info() + + # Bucket packing: accumulate named tensors until we reach bucket_size_bytes, + # then flush a BucketRecord. Size is read from the worker config (key + # "rlix_bucket_size_bytes") or the env var RLIX_BUCKET_SIZE_BYTES. + # A hardcoded silent default is intentionally prohibited โ€” callers must set + # the config or env var so the value is always visible in logs. + bucket_size_bytes: int = _rlix_get_bucket_size_bytes(self) + if is_owner and checkpoint_version == -1: + # Init-time VRAM check: verify the bucket fits in available GPU memory. + _rlix_check_vram(bucket_size_bytes, logger) + + buckets = [] + current_batch: list = [] + current_bytes = 0 + + for name, tensor in self._iter_params_with_optional_kv_scales(): + if not is_owner: + # Non-owner: must still exhaust the generator to keep the PP + # collective alive; do NOT store anything. + continue + + cpu_t = tensor.detach().cpu().contiguous() + nbytes = cpu_t.numel() * cpu_t.element_size() + + # Fail fast: a single tensor larger than bucket_size_bytes can never be + # staged within the GPU VRAM budget (spec: nemorl-port-plan.md line 342-343). + # Matches ROLL's send_recv_utils.py assertion pattern. + if nbytes > bucket_size_bytes: + raise RuntimeError( + f"[rlix] Parameter '{name}' ({nbytes >> 20} MB) exceeds " + f"bucket_size_bytes ({bucket_size_bytes >> 20} MB). " + "Increase RLIX_BUCKET_SIZE_BYTES or bucket_size_bytes config." + ) + + if current_batch and current_bytes + nbytes > bucket_size_bytes: + # Flush current batch into a BucketRecord before appending. 
+ buckets.append(_bucket_named_tensors(current_batch)) + current_batch = [] + current_bytes = 0 + + current_batch.append((name, cpu_t)) + current_bytes += nbytes + + if is_owner and current_batch: + buckets.append(_bucket_named_tensors(current_batch)) + + if is_owner: + cache = self._rlix_get_versioned_cache() + cache.build_latest(checkpoint_version, buckets) + total_bytes = sum(b.cpu_uint8_bucket.numel() for b in buckets) + logger.info( + "[rlix] build_latest_bucket_cache version=%d " + "buckets=%d total_bytes=%d", + checkpoint_version, len(buckets), total_bytes, + ) + # Host-RAM fail-fast: two-pointer versioning keeps โ‰ค 2 full model copies. + # Check against actual packed model size, not per-bucket size. + # Spec: nemorl-port-plan.md line 337 โ€” "ไผฐ็ฎ— total_cpu_cache_bytes โ€ฆ fail fast". + if checkpoint_version == -1: + try: + import psutil as _psutil + available_ram = _psutil.virtual_memory().available + ram_budget = int(available_ram * 0.8) + two_copy = 2 * total_bytes + if two_copy > ram_budget: + raise RuntimeError( + f"[rlix] Host RAM budget exceeded: " + f"2 ร— model ({two_copy >> 20} MB) > " + f"80% of available RAM ({ram_budget >> 20} MB). " + "Reduce model size or increase host RAM." + ) + logger.info( + "[rlix] host_ram_check_ok two_copy=%d MB available_ram=%d MB", + two_copy >> 20, available_ram >> 20, + ) + except ImportError: + logger.warning("[rlix] psutil not installed โ€” skipping host-RAM budget check") + + def promote_active_checkpoint(self, version: int) -> None: + """Atomically switch the active cache pointer to *version*. + + Non-owner ranks return immediately (no-op). Only the cache owner + (pp_rank==0, dp_rank==0, tp_rank==0, cp_rank==0) has a live cache. + + Called by ``BucketCacheLifecycle.promote()`` after + ``build_latest_bucket_cache(version)`` has completed on all workers. + Equivalent to ROLL worker.py:387 promote_active_checkpoint. + + Args: + version: Must match a version passed to ``build_latest_bucket_cache``. 
+ """ + import logging + + logger = logging.getLogger(__name__) + version = int(version) + + if not self._rlix_is_cache_owner(): + return + + cache = self._rlix_get_versioned_cache() + cache.promote(version) + logger.info("[rlix] promote_active_checkpoint version=%d", version) + + @torch.no_grad() + def selective_sync_active_cache( + self, + sync_id: str, + comm_plan: Optional[dict], + tgt_dp_ranks: list[int], + tgt_workers: list, + tgt_device_mapping: list[int], + tgt_num_gpus_per_worker: int, + adapters_to_sync: Optional[list[str]] = None, + model_update_transport: str = "cpu_serialize", + ) -> Optional[dict]: + """Transport active cache buckets to inference workers (IPC or NCCL). + + Non-owner ranks return immediately. Owner holds the cache lock for + the entire transport loop to prevent a concurrent promote/build from + racing the sender read. + + Per-bucket staging constraint: CPUโ†’GPU one bucket at a time; delete + immediately after the barrier. Forbidden to load the full model to + GPU at once. + + Args: + sync_id: Unique sync identifier (used for group name lookup). + comm_plan: Communication plan built by ModelUpdateService for the + owner rank. Non-owners receive None. + tgt_dp_ranks: Inference DP ranks to update. + tgt_workers: All inference worker Ray actor handles. + tgt_device_mapping: GPU device indices per infer worker. + tgt_num_gpus_per_worker: Number of GPUs per infer worker. + adapters_to_sync: Unused; reserved for LoRA adapter sync. + + Returns: + ``{"weight_stats": {...}}`` from the owner for post-sync + verification, or ``None`` from non-owners. 
+ """ + import logging + + import torch.distributed as dist + + logger = logging.getLogger(__name__) + + if not self._rlix_is_cache_owner() or comm_plan is None: + return None + + cache = self._rlix_get_versioned_cache() + ipc_targets: list[dict] = comm_plan[next(iter(comm_plan))].get("ipc_targets", []) + broadcast_local_ranks_by_dp_rank: dict[int, list[int]] = ( + comm_plan[next(iter(comm_plan))].get("broadcast_local_ranks_by_dp_rank", {}) + ) + group_name: str = comm_plan[next(iter(comm_plan))]["group_name"] + dp_rank_to_worker = { + int(dp_rank): tgt_workers[dp_rank] + for dp_rank in tgt_dp_ranks + } + + # Hold cache lock for the entire transport to prevent a concurrent + # promote/build from modifying the active pointer during transport. + with cache._cache_lock: + buckets = cache.get_active_buckets() + n_buckets = len(buckets) + + for bucket_idx, bucket in enumerate(buckets): + # Stage single bucket CPUโ†’GPU; release immediately after barrier. + staging_buf: Optional[torch.Tensor] = None + try: + staging_buf = bucket.cpu_uint8_bucket.pin_memory().cuda() + logger.info( + "[ModelUpdateService] bucket_send bucket_idx=%d/%d " + "bytes=%d group_name=%s sync_id=%s", + bucket_idx, n_buckets, bucket.used_bytes, group_name, sync_id, + ) + + recv_refs = [] + + # IPC path: colocated same-GPU workers. + # model_update_transport selects the payload format: + # - "cuda_ipc": CUDA IPC handle (zero-copy, same physical GPU). + # Spec line 316: NCCL CANNOT form a group on the same GPU; IPC is required. + # - "cpu_serialize": CPU uint8 bucket DMA to receiver GPU. + # Used when CUDA IPC is unavailable (e.g. containerized or cross-GPU). + for ipc_target in ipc_targets: + dp_rank = int(ipc_target["dp_rank"]) + local_ranks = ipc_target["local_ranks"] + + if model_update_transport == "cuda_ipc": + # Zero-copy IPC: share the GPU staging buffer with the colocated process. 
+ from nemo_rl.models.policy.utils import get_handle_from_tensor + torch.cuda.current_stream().synchronize() + cuda_ipc_handle = get_handle_from_tensor(staging_buf) + payload = { + "param_names": bucket.param_names, + "shapes": bucket.shapes, + "dtypes": bucket.dtypes, + "offsets": bucket.offsets, + "used_bytes": bucket.used_bytes, + "cuda_ipc_handle": cuda_ipc_handle, + } + else: + # cpu_serialize: send the CPU uint8 bucket (DMA on receiver side). + payload = { + "param_names": bucket.param_names, + "shapes": bucket.shapes, + "dtypes": bucket.dtypes, + "offsets": bucket.offsets, + "used_bytes": bucket.used_bytes, + "cpu_uint8_bucket": bucket.cpu_uint8_bucket, + } + + recv_refs.append( + dp_rank_to_worker[dp_rank].update_parameter_in_bucket.remote( + payload, local_ranks, model_update_transport + ) + ) + + # NCCL broadcast path: cross-GPU workers. + if group_name in self._rlix_model_update_groups: + nccl_group = self._rlix_model_update_groups[group_name] + + # Dispatch receiver .remote() calls FIRST, so the + # actor scheduler queues them BEFORE this worker + # blocks on the collective. dist.broadcast() is a + # synchronous NCCL collective โ€” it pins the Python + # thread until every participating rank arrives. + # If we issued .remote() after, the sender thread + # would already be blocked, .remote() would never + # submit, and receivers would never join the + # collective โ†’ deadlock. + for dp_rank, broadcast_local_ranks in broadcast_local_ranks_by_dp_rank.items(): + recv_refs.append( + dp_rank_to_worker[int(dp_rank)].broadcast_parameter.remote( + group_name, + bucket.param_names, + bucket.dtypes, + bucket.shapes, + broadcast_local_ranks, + ) + ) + + # Now enter the collective. Receivers' broadcast_parameter + # implementations call group.broadcast(recv_buf, src=0) + # and rendezvous with this call. 
+ dist.broadcast(staging_buf, src=0, group=nccl_group) + + import ray as _ray + _ray.get(recv_refs) + + logger.info( + "[ModelUpdateService] bucket_ack bucket_idx=%d/%d sync_id=%s", + bucket_idx, n_buckets, sync_id, + ) + finally: + # Release GPU staging buffer immediately after barrier. + del staging_buf + staging_buf = None + + # Flush GPU streams before teardown: dist.broadcast is async; synchronize + # ensures all NCCL kernels have completed before destroying the communicator. + # _ray.get(recv_refs) above already confirmed receivers finished, so this + # just ensures sender-side CUDA stream is clean. + torch.cuda.synchronize() + # Tear down the NCCL collective group while still holding _cache_lock. + # Spec (nemorl-port-plan.md line 402): lock must span "cache lookup โ†’ + # transport โ†’ NCCL teardown" โ€” releasing before teardown completes + # would allow a concurrent build_latest / promote to race the sender. + self.destroy_collective_group(group_name) + + # Compute weight stats for optional post-sync verification. + weight_stats: dict = {} + try: + sd = {n: t for n, t in self._iter_params_with_optional_kv_scales()} + vals = [t.float() for t in sd.values() if t.numel() > 0] + if vals: + all_flat = torch.cat([v.flatten() for v in vals]) + weight_stats = { + "sum": float(all_flat.sum()), + "max": float(all_flat.max()), + "min": float(all_flat.min()), + } + except Exception: + pass + + return {"weight_stats": weight_stats} + + def setup_collective_group( + self, + model_update_name: str, + comm_plan: dict, + mode: str, + timeout_s: Optional[float] = None, + ) -> None: + """Join a dynamic NCCL group for selective model weight broadcast. + + The sender (mode='sender') joins as rank 0; receivers join at + their assigned rank from the comm_plan. + + Args: + model_update_name: Unique sync identifier (used as group name). + comm_plan: Communication plan dict with master_addr/port and + world size info. + mode: 'sender' (rank 0) or 'receiver'. 
+ timeout_s: Optional NCCL init timeout in seconds. + """ + from nemo_rl.distributed.stateless_process_group import StatelessProcessGroup + + cache = self._rlix_get_versioned_cache() + plan_entry = comm_plan[next(iter(comm_plan))] + group_name: str = plan_entry["group_name"] + master_addr: str = plan_entry["master_addr"] + master_port: int = int(plan_entry["master_port"]) + + if mode == "sender": + tgt_devices = plan_entry.get("tgt_devices", []) + world_size = 1 + len(tgt_devices) + rank = 0 + else: + # Receiver: find our rank from tgt_devices list. + tgt_devices = plan_entry.get("tgt_devices", []) + world_size = 1 + len(tgt_devices) + local_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + rank = 1 # default; real multi-receiver assignment handled by StatelessProcessGroup ordering + for i, dev in enumerate(tgt_devices): + if int(dev.get("rank", -1)) == local_rank: + rank = i + 1 + break + + pg = StatelessProcessGroup( + master_address=master_addr, + port=master_port, + rank=rank, + world_size=world_size, + ) + pg.init_nccl_communicator(device=self.device if hasattr(self, "device") else torch.device("cuda")) + self._rlix_model_update_groups[group_name] = pg + + def destroy_collective_group(self, group_name: str) -> None: + """Destroy a dynamic NCCL group created by setup_collective_group. + + No-op if the group does not exist (IPC-only ranks never join the + NCCL group, so this guard is required). + + Args: + group_name: Group name as used in setup_collective_group. 
+ """ + import logging + + import torch.distributed as dist + + logger = logging.getLogger(__name__) + groups = getattr(self, "_rlix_model_update_groups", {}) + if group_name not in groups: + return + pg = groups.pop(group_name) + try: + dist.destroy_process_group(pg) + except Exception as exc: + logger.warning( + "[rlix] destroy_collective_group failed group_name=%s: %s", + group_name, exc, + ) + def prepare_for_lp_inference(self): self.model = self.move_model(self.model, "cuda", move_grads=False) self.model.eval() @@ -1608,6 +2037,113 @@ def _percentile(values: list[float], p: float) -> float: return final_result +# --------------------------------------------------------------------------- +# rlix module-level helpers for bucket cache (Feature 4) +# --------------------------------------------------------------------------- +# These are module-level so they can be imported and tested without +# constructing a full MegatronPolicyWorkerImpl. +# --------------------------------------------------------------------------- + +_RLIX_BUCKET_SIZE_ENV = "RLIX_BUCKET_SIZE_BYTES" +_RLIX_BUCKET_SIZE_DEFAULT = 256 * 1024 * 1024 # 256 MB documented default +# Transport scratch (NCCL send-side staging overhead estimate). +_RLIX_TRANSPORT_SCRATCH_MB = 64 + + +def _rlix_get_bucket_size_bytes(worker) -> int: + """Return the configured bucket size in bytes for rlix cache building. + + Priority order: + 1. ``worker.cfg["rlix"]["bucket_size_bytes"]`` (explicit config key). + 2. ``RLIX_BUCKET_SIZE_BYTES`` environment variable. + 3. Documented default ``_RLIX_BUCKET_SIZE_DEFAULT`` (256 MB), emitted + as a WARNING so users know the default is active. + + This function is intentionally NOT a silent fallback โ€” every code path + logs the active value so callers are always aware. + + Args: + worker: MegatronPolicyWorkerImpl instance (for cfg access). + + Returns: + Bucket size in bytes (positive int). + + Raises: + ValueError: If the resolved value is <= 0. 
+ """ + import logging + import os + + logger = logging.getLogger(__name__) + + # 1. Worker config + cfg = getattr(worker, "cfg", {}) or {} + rlix_cfg = cfg.get("rlix", {}) or {} + if "bucket_size_bytes" in rlix_cfg: + val = int(rlix_cfg["bucket_size_bytes"]) + logger.info("[rlix] bucket_size_bytes=%d (from worker.cfg['rlix'])", val) + if val <= 0: + raise ValueError(f"[rlix] bucket_size_bytes must be > 0, got {val}") + return val + + # 2. Environment variable + env_val = os.environ.get(_RLIX_BUCKET_SIZE_ENV) + if env_val is not None: + val = int(env_val) + logger.info("[rlix] bucket_size_bytes=%d (from env %s)", val, _RLIX_BUCKET_SIZE_ENV) + if val <= 0: + raise ValueError(f"[rlix] {_RLIX_BUCKET_SIZE_ENV} must be > 0, got {val}") + return val + + # Spec (nemorl-port-plan.md line 343): bucket_size_bytes must be an explicit + # configuration value โ€” no implicit default is allowed. Fail fast so operators + # are forced to make the staging-VRAM budget decision visible in config. + raise RuntimeError( + "[rlix] bucket_size_bytes is not configured. " + f"Set worker.cfg['rlix']['bucket_size_bytes'] or env {_RLIX_BUCKET_SIZE_ENV}. " + "No implicit default is permitted (spec: nemorl-port-plan.md line 343)." + ) + + +def _rlix_check_vram(bucket_size_bytes: int, logger) -> None: + """Fail fast if bucket_size_bytes exceeds available GPU VRAM margin. + + Called once at init time (when ``checkpoint_version == -1``). + Peak staging VRAM estimate: ``bucket_size_bytes + _RLIX_TRANSPORT_SCRATCH_MB * 1024^2``. + + Args: + bucket_size_bytes: Configured bucket size in bytes. + logger: Logger instance (already has worker context). + + Raises: + RuntimeError: If estimated peak staging VRAM exceeds 90% of free GPU memory. 
+ """ + try: + import torch + + free_bytes, total_bytes = torch.cuda.mem_get_info() + scratch_bytes = _RLIX_TRANSPORT_SCRATCH_MB * 1024 * 1024 + peak_bytes = bucket_size_bytes + scratch_bytes + threshold = 0.9 * free_bytes + logger.info( + "[rlix] vram_check free_gb=%.2f peak_staging_gb=%.2f bucket_size_mb=%d", + free_bytes / 1024 ** 3, + peak_bytes / 1024 ** 3, + bucket_size_bytes // (1024 * 1024), + ) + if peak_bytes > threshold: + raise RuntimeError( + f"[rlix] bucket_size_bytes={bucket_size_bytes} exceeds VRAM margin: " + f"peak_staging={peak_bytes / 1024**3:.2f} GB > 90% of free={free_bytes / 1024**3:.2f} GB. " + f"Reduce RLIX_BUCKET_SIZE_BYTES or free GPU memory before training." + ) + except RuntimeError: + raise + except Exception as exc: + # Non-CUDA environments (CPU-only, mock): skip the check. + logger.debug("[rlix] vram_check skipped: %s", exc) + + @ray.remote( runtime_env=get_runtime_env_for_policy_worker("megatron_policy_worker") ) # pragma: no cover diff --git a/tests/unit/models/generation/test_pr3_review_fixes.py b/tests/unit/models/generation/test_pr3_review_fixes.py new file mode 100644 index 0000000000..167037a14b --- /dev/null +++ b/tests/unit/models/generation/test_pr3_review_fixes.py @@ -0,0 +1,276 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Structural unit tests for the rlix-task2 selective-sync review fixes. 
+ +Three independent fixes; each is a CORRECTNESS issue that surfaces only +under TP>1 / multi-receiver topology. Tests here are mock-only โ€” no +Ray, no GPUs โ€” so they run in CI alongside the existing fast suite. + +Behavioral verification of the deadlock path (Bug 1) requires real +Ray-actor scheduling under NCCL and is left to upstream multi-GPU +selective-sync integration tests; the lexical guard at the bottom of +this file catches obvious source-level regressions in the meantime. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------- +# Bug 3 โ€” vllm_generation receiver pass-throughs MUST dispatch to every +# TP/PP rank (run_rank_0_only_axes=[]). Filtering to TP rank 0 only +# leaves ranks 1..N-1 outside the NCCL collective โ†’ deadlock. +# --------------------------------------------------------------------- + +# The 6 receiver pass-through methods that participate in the +# selective-sync NCCL collective. Each MUST forward +# run_rank_0_only_axes=[] to run_all_workers_single_data. 
+_RECEIVER_PASS_THROUGHS = [ + ("setup_collective_group", { + "model_update_name": "g0", + "comm_plan": {"g0": {"group_name": "g0", "master_addr": "1.2.3.4", + "master_port": 1234, "tgt_devices": []}}, + "mode": "receiver", + "timeout_s": 10.0, + }), + ("update_parameter_in_bucket", { + "payload": {"param_names": [], "shapes": [], "dtypes": [], + "offsets": [], "used_bytes": 0, "cpu_uint8_bucket": b""}, + "ipc_local_ranks": [0], + "model_update_transport": "cpu_serialize", + }), + ("broadcast_parameter", { + "group_name": "g0", + "names": [], + "dtypes": [], + "shapes": [], + "broadcast_local_ranks": [0], + }), + ("destroy_collective_group", {"group_name": "g0"}), + ("verify_model", {"expected_stats": {}}), + ("finalize_weight_update", {}), +] + + +@pytest.mark.parametrize("method_name,kwargs", _RECEIVER_PASS_THROUGHS, + ids=[m for m, _ in _RECEIVER_PASS_THROUGHS]) +def test_receiver_passthroughs_dispatch_to_all_ranks(method_name, kwargs): + """Every receiver pass-through must forward run_rank_0_only_axes=[].""" + from nemo_rl.models.generation.vllm.vllm_generation import VllmGeneration + + fake_worker_group = MagicMock() + fake_worker_group.run_all_workers_single_data.return_value = [] + + instance = MagicMock(spec=VllmGeneration) + instance.worker_group = fake_worker_group + # Bind the real method to the mock instance so we exercise the actual + # implementation rather than the auto-mock. 
+ method = getattr(VllmGeneration, method_name) + method(instance, **kwargs) + + fake_worker_group.run_all_workers_single_data.assert_called_once() + _, call_kwargs = fake_worker_group.run_all_workers_single_data.call_args + assert call_kwargs.get("run_rank_0_only_axes") == [], ( + f"{method_name}: must dispatch to all TP/PP ranks " + f"(run_rank_0_only_axes=[]); got " + f"{call_kwargs.get('run_rank_0_only_axes')!r}" + ) + + +# --------------------------------------------------------------------- +# Bug 2 โ€” vllm_backend.broadcast_parameter must use self.rank (worker- +# local) when comparing against broadcast_local_ranks (also worker-local +# ranks). Falling back to torch.distributed.get_rank() (global rank) +# under TP>1 / multi-node never matches โ†’ silent receiver early-return โ†’ +# the broadcast collective is never entered on the receiver side. +# --------------------------------------------------------------------- + + +def _make_vllm_extension(rank=None): + """Build a minimal stand-in for VllmInternalWorkerExtension that + exposes the attributes broadcast_parameter touches up to (and + including) the post-guard ``torch.zeros(..., device=self.device)`` + call. We assert on whether torch.zeros was invoked rather than on + the deeper ``group.broadcast`` because the function path between the + two requires more state that's expensive to fake. + """ + from nemo_rl.models.generation.vllm.vllm_backend import ( + VllmInternalWorkerExtension, + ) + import torch + + inst = VllmInternalWorkerExtension.__new__(VllmInternalWorkerExtension) + if rank is not None: + inst.rank = rank + # Required by post-guard ``torch.zeros(..., device=self.device)``. + inst.device = torch.device("cpu") + # ``broadcast_parameter`` checks ``_model_update_groups``; populate + # with a stand-in so the rank-comparison branch is reached. 
+ fake_group = MagicMock() + fake_group.broadcast = MagicMock() + inst._model_update_groups = {"g0": fake_group} + return inst + + +def _broadcast_args(): + """Minimal args that make ``broadcast_parameter`` reach the rank + check. ``names``/``dtypes``/``shapes`` empty so the aligned-size + arithmetic is trivially 0 and no real tensor work is required.""" + return dict( + group_name="g0", + names=[], + dtypes=[], + shapes=[], + ) + + +# Sentinel exception raised by ``torch.zeros`` mock โ€” surfaces "the +# function reached the post-guard code path" without requiring us to +# fake the rest of the model state (which sits past torch.zeros and +# would fail with AttributeError otherwise). +class _PastGuardSentinel(RuntimeError): + pass + + +def test_broadcast_parameter_uses_self_rank_when_set(): + """When ``self.rank`` is in ``broadcast_local_ranks``, the receiver + must NOT early-return โ€” proven by observing the sentinel raised + from the first post-guard ``torch.zeros`` call.""" + from nemo_rl.models.generation.vllm import vllm_backend + + inst = _make_vllm_extension(rank=1) + + # Patch torch.distributed.get_rank to a different value so a buggy + # implementation (using global rank) would early-return on + # `42 not in [1]` without raising the sentinel. + with patch.object(vllm_backend, "torch") as mock_torch: + mock_torch.distributed.is_initialized.return_value = True + mock_torch.distributed.get_rank.return_value = 42 + # Allow `torch.empty(0, dtype=...).element_size()` to succeed so + # the aligned-size loop doesn't crash before torch.zeros. + mock_torch.empty.return_value.element_size.return_value = 1 + # Raise sentinel from the first post-guard call. If the function + # early-returned at the rank check, we never get here. 
+ mock_torch.zeros.side_effect = _PastGuardSentinel("past guard") + + with pytest.raises(_PastGuardSentinel): + vllm_backend.VllmInternalWorkerExtension.broadcast_parameter( + inst, + broadcast_local_ranks=[1], + **_broadcast_args(), + ) + + +def test_broadcast_parameter_skips_when_rank_not_in_local_ranks(): + """When ``self.rank`` is NOT in ``broadcast_local_ranks``, the + receiver early-returns (sentinel never raised).""" + from nemo_rl.models.generation.vllm import vllm_backend + + inst = _make_vllm_extension(rank=1) + + with patch.object(vllm_backend, "torch") as mock_torch: + mock_torch.distributed.is_initialized.return_value = True + mock_torch.distributed.get_rank.return_value = 42 + # If reached, would raise โ€” but we expect early return. + mock_torch.zeros.side_effect = _PastGuardSentinel("past guard") + + # Should return None cleanly (no exception) โ€” the early return + # at `if local_rank not in broadcast_local_ranks: return` fires. + result = vllm_backend.VllmInternalWorkerExtension.broadcast_parameter( + inst, + broadcast_local_ranks=[0], # rank 1 not in here + **_broadcast_args(), + ) + assert result is None + + +def test_broadcast_parameter_falls_back_to_global_rank_when_self_rank_absent(): + """Backward-compat: callers that don't set ``self.rank`` fall + through to ``torch.distributed.get_rank()`` (the original behavior).""" + from nemo_rl.models.generation.vllm import vllm_backend + + inst = _make_vllm_extension(rank=None) + # Don't set inst.rank at all โ€” getattr returns None. + + with patch.object(vllm_backend, "torch") as mock_torch: + mock_torch.distributed.is_initialized.return_value = True + # Global rank 0 is in [0] โ†’ not skipped โ†’ sentinel raised. 
+ mock_torch.distributed.get_rank.return_value = 0 + mock_torch.empty.return_value.element_size.return_value = 1 + mock_torch.zeros.side_effect = _PastGuardSentinel("past guard") + + with pytest.raises(_PastGuardSentinel): + vllm_backend.VllmInternalWorkerExtension.broadcast_parameter( + inst, + broadcast_local_ranks=[0], + **_broadcast_args(), + ) + + +# --------------------------------------------------------------------- +# Bug 1 โ€” sender selective_sync_active_cache must dispatch all +# broadcast_parameter receivers BEFORE entering dist.broadcast(). The +# reverse ordering deadlocks: the sender's Python thread is pinned +# inside the collective and never submits the .remote() calls. +# --------------------------------------------------------------------- +# +# This bug only manifests under real Ray-actor scheduling โ€” a unit test +# of the function in isolation can't reproduce the deadlock because +# `.remote()` and `dist.broadcast()` both return synchronously when +# their dependencies are mocked. The lexical guard below catches +# obvious regressions (someone moving `dist.broadcast()` back above the +# `.remote()` loop in a future refactor); behavioral verification under +# real Ray + NCCL is the responsibility of upstream selective-sync +# integration tests on multi-GPU hardware. + + +def test_selective_sync_dispatch_ordering_lexical(): + """Lexical guard: in selective_sync_active_cache, the + broadcast_parameter.remote(...) dispatch loop appears BEFORE the + sender-side dist.broadcast(...) call within the NCCL-broadcast + branch. + + Reads the source file directly via Path (no Python import) so the + test does not pull in megatron / megatron_bridge dependencies that + may not be present on every CI image. 
+ """ + from pathlib import Path + + repo_root = Path(__file__).resolve().parents[4] + src_path = repo_root / "nemo_rl" / "models" / "policy" / "workers" / "megatron_policy_worker.py" + assert src_path.is_file(), f"could not locate {src_path}" + text = src_path.read_text() + + # Find the NCCL-broadcast branch; assert dispatch loop comes before + # the dist.broadcast call. + branch_marker = "if group_name in self._rlix_model_update_groups:" + branch_idx = text.find(branch_marker) + assert branch_idx > 0, "could not locate NCCL-broadcast branch marker" + + # Search within the next ~3000 chars (covers the per-bucket loop body). + region = text[branch_idx : branch_idx + 3000] + + dispatch_idx = region.find(".broadcast_parameter.remote(") + sender_idx = region.find("dist.broadcast(staging_buf") + assert dispatch_idx > 0, "dispatch loop not found in NCCL branch" + assert sender_idx > 0, "sender dist.broadcast not found in NCCL branch" + assert dispatch_idx < sender_idx, ( + "regression: sender dist.broadcast appears BEFORE receiver " + ".remote() dispatch โ€” this reintroduces the deadlock the fix " + "was meant to prevent." + )