Commit 7d9d7af
Added workspace check and reflected this in test (#1991)
## 📌 Description

This PR attempts to fix #1986 (to be confirmed by the requester).

The issue is that `num_tokens` was larger than `MAX_TOKEN_NUM`, which results in an illegal memory access (IMA) or even a hang. To address this, I added a validation check. This required a non-breaking API change:

* `create_ipc_workspace_for_all_reduce_fusion` now has an optional `create_metadata` bool, which results in an additional return value.
  * It is made optional because an additional return value could break the API.
* `trtllm_allreduce_fusion` now takes an optional `metadata` dictionary.
  * When provided, this will run the validation check.
  * Again, this is optional to avoid breaking the API.

In addition, this PR deprecates the older AllReduce functionality so it can be removed in a major version bump.

## 🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

### ✅ Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

> If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/).

## 🧪 Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **API Changes**
  * Workspace creation can optionally return metadata describing the workspace configuration (`create_metadata` flag).
  * Allreduce fusion operations accept optional metadata to validate runtime parameters against the workspace and raise clear errors on mismatch.
  * A workspace destruction endpoint was renamed for naming consistency.
  * Legacy wrappers were marked deprecated and now point users toward the newer fusion variants.
1 parent c857f09 commit 7d9d7af
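To make the description above concrete, here is a minimal sketch of how a caller might opt into the new metadata path. The numeric sizes, the helper name `make_fusion_workspace`, and the `from flashinfer import comm` import style are assumptions for illustration; the `comm.trtllm_create_ipc_workspace_for_all_reduce_fusion` parameters are the ones touched by this commit.

```python
from typing import Optional

import torch.distributed as dist

from flashinfer import comm


def make_fusion_workspace(
    rank: int, world_size: int, group: Optional[dist.ProcessGroup] = None
):
    # create_metadata=True adds a third return value describing the workspace.
    ipc_handles, workspace_tensor, metadata = (
        comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
            rank,
            world_size,
            1024,  # max_token_num, illustrative
            4096,  # hidden_dim, illustrative
            use_fp32_lamport=False,
            group=group,
            create_metadata=True,
        )
    )
    # metadata records max_token_num, tp_size, hidden_dim, use_fp32_lamport and
    # the computed buffer sizes. Passing it later as metadata=... to
    # comm.trtllm_allreduce_fusion turns an out-of-bounds launch (the IMA/hang
    # from #1986) into a clear ValueError.
    return ipc_handles, workspace_tensor, metadata
```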

2 files changed: 98 additions, 7 deletions


flashinfer/comm/trtllm_ar.py

Lines changed: 90 additions & 4 deletions
```diff
@@ -19,6 +19,7 @@
 from ctypes import c_void_p, cast
 from types import SimpleNamespace
 from typing import List, Optional, Tuple, Union
+from typing_extensions import deprecated
 
 import torch
 import torch.distributed as dist
@@ -121,6 +122,9 @@ def trtllm_lamport_initialize_all(
         buffer_0_ptr, buffer_1_ptr, buffer_2_ptr, size, dtype
     )
 
+@deprecated(
+    "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated, use trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion instead"
+)
 @register_custom_op(
     "flashinfer::trtllm_custom_all_reduce",
     mutates_args=[
@@ -393,6 +397,9 @@ def trtllm_moe_finalize_allreduce_fusion(
 LamportTokenNumThreshold = 16
 
 
+@deprecated(
+    "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated, use trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion instead"
+)
 def trtllm_create_ipc_workspace_for_all_reduce(
     rank: int,
     tp_size: int,
@@ -493,14 +500,18 @@ def trtllm_destroy_ipc_workspace_for_all_reduce(
 MAX_COMM_SIZE = 2147483647 & ~((1 << 21) - 1)  # MAX_INT32 rounded down to 2MB
 
 
+# @TODO(nvmbreughe): on a next major bump, remove create_metadata and make create_metadata=True the default behavior
 def trtllm_create_ipc_workspace_for_all_reduce_fusion(
     tp_rank: int,
     tp_size: int,
     max_token_num: int,
     hidden_dim,
     use_fp32_lamport: bool = False,
     group: Optional[ProcessGroup] = None,
-) -> Tuple[List[List[int]], torch.Tensor]:
+    create_metadata: bool = False,
+) -> Union[
+    Tuple[List[List[int]], torch.Tensor], Tuple[List[List[int]], torch.Tensor, dict]
+]:
     """
     Parameters:
     - tp_rank: the rank of the current process.
@@ -509,6 +520,13 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
     - hidden_dim: the dimension of the hidden states.
     - use_fp32_lamport: if True, we will use fp32 datatype in allreduce fusion.
     - group: the process group to use.
+    - create_metadata: if True, return metadata dict as third element (default: False).
+
+    Returns:
+    - If create_metadata=False: (ipc_handles, workspace_tensor)
+    - If create_metadata=True: (ipc_handles, workspace_tensor, metadata)
+      where metadata contains: tp_rank, tp_size, max_token_num, hidden_dim,
+      use_fp32_lamport, buffer_size, flag_size, lamport_comm_size, lamport_buffer_size
 
     Note:
     We would init 3 IPC buffers for trtllm_custom_all_reduce_fusion.
@@ -517,8 +535,8 @@
     where:
     - buffer_size: tp_size * max_token_num * hidden_dim * sizeof(half)
     - flag_size: tp_size * BarrierFlagCount * sizeof(int)
-    - lamport_buffer_size: tp_size * max(max_token_num, OneShotMaxToken) * tp_size * hidden_dim * sizeof(half)
-
+    - lamport_buffer_size: tp_size * max_token_num * tp_size * hidden_dim * sizeof(half)
+      where sizeof(elem) = 2 (fp16/bf16) or 4 (fp32 when use_fp32_lamport=True)
     The workspace is passed as workspace field in AllReduceFusionParams.
 
     We use tp_size and world_size here interchangeably (allReduceFusion).
@@ -608,7 +626,21 @@
 
     dist.barrier(group=group)  # must sync after create_workspace
 
-    return ipc_handles, workspace_tensor
+    if create_metadata:
+        metadata = {
+            "tp_rank": tp_rank,
+            "tp_size": tp_size,
+            "max_token_num": max_token_num,
+            "hidden_dim": hidden_dim,
+            "use_fp32_lamport": use_fp32_lamport,
+            "buffer_size": buffer_size,
+            "flag_size": flag_size,
+            "lamport_comm_size": lamport_comm_size,
+            "lamport_buffer_size": lamport_buffer_size,
+        }
+        return ipc_handles, workspace_tensor, metadata
+    else:
+        return ipc_handles, workspace_tensor
 
 
 def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
@@ -675,6 +707,9 @@ def trtllm_lamport_initialize_all(
     )
 
 
+@deprecated(
+    "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated, use trtllm_create_ipc_workspace_for_all_reduce_fusion and trtllm_allreduce_fusion instead"
+)
 def trtllm_custom_all_reduce(
     inp: torch.Tensor,
     out: torch.Tensor,
@@ -791,6 +826,7 @@ def trtllm_allreduce_fusion(
     rms_eps: Optional[float],
     scale_factor: Optional[Union[torch.Tensor, float]],
     layout_code: Optional[QuantizationSFLayout],
+    metadata: Optional[dict] = None,
 ) -> None:
     """
     Parameters:
@@ -815,8 +851,58 @@
     - rms_eps: the rms epsilon value.
     - scale_factor: the scale factor. For cudaGraphs safety, it should be a tensor.
     - layout_code: the layout code.
+    - metadata: optional workspace metadata dict from create_ipc_workspace_for_all_reduce_fusion.
+      If provided, validates that token_num <= max_token_num, world_size == tp_size,
+      and hidden_dim == workspace hidden_dim. Raises ValueError if validation fails.
     """
 
+    # Validate against workspace metadata if provided
+    if metadata is not None:
+        errors = []
+        required_keys = ["max_token_num", "tp_size", "hidden_dim", "use_fp32_lamport"]
+        for key in required_keys:
+            if key not in metadata:
+                errors.append(f"Workspace metadata is missing required key: {key}")
+        if errors:
+            error_msg = "Workspace metadata validation failed:\n" + "\n".join(
+                f"  - {e}" for e in errors
+            )
+            raise ValueError(error_msg)
+
+        # Check 1: token_num must not exceed max_token_num
+        if token_num > metadata["max_token_num"]:
+            errors.append(
+                f"token_num ({token_num}) exceeds workspace max_token_num ({metadata['max_token_num']}). "
+                f"This may cause Illegal Memory Access."
+            )
+
+        # Check 2: world_size must match tp_size
+        if world_size != metadata["tp_size"]:
+            errors.append(
+                f"world_size ({world_size}) does not match workspace tp_size ({metadata['tp_size']}). "
+                f"Workspace was created for tp_size={metadata['tp_size']}."
+            )
+
+        # Check 3: hidden_dim must match
+        if hidden_dim != metadata["hidden_dim"]:
+            errors.append(
+                f"hidden_dim ({hidden_dim}) does not match workspace hidden_dim ({metadata['hidden_dim']}). "
+                f"Workspace was created for hidden_dim={metadata['hidden_dim']}."
+            )
+
+        # Check 4: use_fp32_lamport must match
+        if metadata["use_fp32_lamport"] != (allreduce_in.dtype == torch.float32):
+            errors.append(
+                f"use_fp32_lamport ({metadata['use_fp32_lamport']}) does not match allreduce_in.dtype ({allreduce_in.dtype}). "
+                f"Workspace was created for use_fp32_lamport={metadata['use_fp32_lamport']}."
+            )
+
+        if errors:
+            error_msg = "Workspace validation failed:\n" + "\n".join(
+                f"  - {e}" for e in errors
+            )
+            raise ValueError(error_msg)
+
     if use_oneshot is None:
         use_oneshot = _should_use_oneshot(
             token_num, hidden_dim, allreduce_in.dtype, world_size
```
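As a quick sanity check on the sizing formulas in the updated docstring, here is an illustrative calculation. The tp_size, max_token_num, and hidden_dim values are made up, and flag_size is omitted because BarrierFlagCount is not shown in this diff.

```python
# Illustrative workspace sizing using the docstring formulas; values are made up.
tp_size = 2
max_token_num = 1024
hidden_dim = 4096
elem_size = 2  # bytes per element for fp16/bf16; 4 when use_fp32_lamport=True

buffer_size = tp_size * max_token_num * hidden_dim * elem_size
lamport_buffer_size = tp_size * max_token_num * tp_size * hidden_dim * elem_size

print(buffer_size)          # 16777216 bytes, i.e. 16 MiB
print(lamport_buffer_size)  # 33554432 bytes, i.e. 32 MiB
```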

tests/comm/test_trtllm_allreduce_fusion.py

Lines changed: 8 additions & 3 deletions
```diff
@@ -57,15 +57,16 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini
 
     lamport_use_fp32 = dtype == torch.float32
 
-    # create workspace for allreduce fusion
-    ipc_handles, workspace_tensor = (
+    # create workspace for allreduce fusion with metadata
+    ipc_handles, workspace_tensor, workspace_metadata = (
         comm.trtllm_create_ipc_workspace_for_all_reduce_fusion(
             rank,
             world_size,
             MAX_TOKEN_NUM,
             hidden_dim,
             group=group,
             use_fp32_lamport=lamport_use_fp32,
+            create_metadata=True,  # Get metadata for validation
         )
     )
 
@@ -184,6 +185,7 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini
                 rms_eps=rms_eps,
                 scale_factor=scale_factor,
                 layout_code=swizzled_layout_code,
+                metadata=workspace_metadata,
             )
 
             # NOTE: in real case, you dont have to set all optional params. You could set those required by fusion pattern.
@@ -213,6 +215,7 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini
                 rms_eps=rms_eps,
                 scale_factor=scale_factor,
                 layout_code=swizzled_layout_code,
+                metadata=workspace_metadata,
             )
             # replay
             g.replay()
@@ -304,7 +307,9 @@ def _run_correctness_worker(world_size, rank, dtype, hidden_dim, distributed_ini
     finally:
         dist.barrier(group=group)
 
-        comm.trtllm_destroy_ipc_workspace_for_all_reduce(ipc_handles, group=group)
+        comm.trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
+            ipc_handles, group=group
+        )
 
         dist.destroy_process_group(group=group)
```
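For callers who would rather fail before reaching the kernel entry point at all, the same checks can be run as a pre-flight step against the metadata dict alone. This helper is a hypothetical sketch, not part of the PR.

```python
def check_fits_workspace(
    token_num: int, hidden_dim: int, world_size: int, metadata: dict
) -> None:
    """Raise early if a launch would not fit the fusion workspace (hypothetical helper)."""
    if token_num > metadata["max_token_num"]:
        raise ValueError(
            f"token_num ({token_num}) exceeds workspace max_token_num "
            f"({metadata['max_token_num']}); split the batch or create a larger workspace."
        )
    if world_size != metadata["tp_size"]:
        raise ValueError(
            f"world_size ({world_size}) != workspace tp_size ({metadata['tp_size']})"
        )
    if hidden_dim != metadata["hidden_dim"]:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) != workspace hidden_dim ({metadata['hidden_dim']})"
        )
```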