Commit 6adccd7

[https://nvbugs/5606268][fix] Separate cuda graph workspace to prevent IMA (#8685)
Signed-off-by: Junyi Xu <[email protected]>
1 parent e9aa8b2 commit 6adccd7
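Background, not part of the commit: a CUDA graph records the device pointers of every tensor it touches at capture time and dereferences those same pointers on every replay. Any workspace used while capturing must therefore keep a stable address; if the eager path later reallocates a shared workspace, replays touch freed memory and trigger the illegal memory access (IMA) this fix prevents. A minimal PyTorch sketch of that failure mode, with illustrative variable names:

# Minimal sketch (illustrative, not TensorRT-LLM code): CUDA graph replay
# reuses the device addresses recorded at capture time.
import torch

static_in = torch.zeros(4, device='cuda')
static_out = torch.empty_like(static_in)

# Warm-up on a side stream (recommended before capture so lazy CUDA
# initialization happens outside the capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # The addresses of static_in/static_out are baked into the graph here.
    static_out.copy_(static_in * 2)

static_in.fill_(3.0)   # OK: in-place mutation, address unchanged
g.replay()
print(static_out)      # tensor([6., 6., 6., 6.], device='cuda:0')

# Rebinding the name would allocate a NEW buffer; the graph would still
# dereference the OLD pointer on replay. Once that memory is freed or
# reused, replay reads garbage or raises an illegal memory access (IMA).
# Giving graph runs their own persistent workspace, as this commit does,
# removes that hazard.
# static_in = torch.full((4, ), 5.0, device='cuda')
# g.replay()  # would NOT see the new buffer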

File tree

1 file changed: 12 additions (+), 2 deletions (-)

  • tensorrt_llm/_torch/attention_backend/trtllm.py

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 12 additions & 2 deletions
@@ -536,6 +536,7 @@ def is_nvfp4_output_kernel_available(
 @dataclass(kw_only=True)
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
+    cuda_graph_workspace: Optional[torch.Tensor] = None

     # TrtllmAttention needs to know the beam width to access the cache indirection buffer
     # when beam search is enabled.
@@ -693,6 +694,14 @@ def get_empty_like(like_tensor: torch.Tensor,
                 device='cuda',
                 dtype=torch.int8,
             )
+
+        if self.cuda_graph_workspace is None:
+            self.cuda_graph_workspace = torch.empty(
+                (0, ),
+                device='cuda',
+                dtype=torch.int8,
+            )
+
         if self.kv_cache_manager is not None:
             self.kv_cache_block_offsets = get_empty(
                 [
@@ -1276,8 +1285,9 @@ def forward(
             host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
             host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
             block_ids_per_seq=metadata.block_ids_per_seq,
-            workspace=metadata.
-            workspace,  # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
+            # Re-enable this: passing None makes fp8 MLA hit an invalid CUDA free.
+            workspace=metadata.workspace
+            if not metadata.is_cuda_graph else metadata.cuda_graph_workspace,
             cache_indirection=metadata.cache_indirection,
             kv_scale_orig_quant=self.kv_scale_orig_quant,
             kv_scale_quant_orig=self.kv_scale_quant_orig,
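Taken together, the three hunks implement a two-workspace pattern: the eager-mode workspace may be reallocated as sizing changes, while the CUDA-graph workspace is allocated once and never rebound. The hedged sketch below isolates that pattern outside TensorRT-LLM; the class and helper names (Metadata, pick_workspace) are hypothetical, and only the field names mirror the diff:

# Simplified, hypothetical mirror of the change above -- not the actual
# TrtllmAttentionMetadata class.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass(kw_only=True)
class Metadata:
    is_cuda_graph: bool = False
    workspace: Optional[torch.Tensor] = None
    cuda_graph_workspace: Optional[torch.Tensor] = None

    def prepare(self) -> None:
        # Eager workspace: safe to reallocate or resize between steps.
        if self.workspace is None:
            self.workspace = torch.empty((0, ), device='cuda', dtype=torch.int8)
        # Graph workspace: allocated once and left alone, so the address
        # captured into a CUDA graph stays valid for every replay.
        if self.cuda_graph_workspace is None:
            self.cuda_graph_workspace = torch.empty((0, ),
                                                    device='cuda',
                                                    dtype=torch.int8)


def pick_workspace(metadata: Metadata) -> torch.Tensor:
    # Mirrors the forward() change: graph-captured kernels are never handed
    # the resizable eager buffer.
    return (metadata.cuda_graph_workspace
            if metadata.is_cuda_graph else metadata.workspace)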
