@@ -536,6 +536,7 @@ def is_nvfp4_output_kernel_available(
 @dataclass(kw_only=True)
 class TrtllmAttentionMetadata(AttentionMetadata):
     workspace: Optional[torch.Tensor] = None
+    cuda_graph_workspace: Optional[torch.Tensor] = None
 
     # TrtllmAttention needs to know the beam width to access the cache indirection buffer
     # when beam search is enabled.
@@ -693,6 +694,14 @@ def get_empty_like(like_tensor: torch.Tensor,
                 device='cuda',
                 dtype=torch.int8,
             )
+
+        if self.cuda_graph_workspace is None:
+            self.cuda_graph_workspace = torch.empty(
+                (0, ),
+                device='cuda',
+                dtype=torch.int8,
+            )
+
         if self.kv_cache_manager is not None:
             self.kv_cache_block_offsets = get_empty(
                 [
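
The new buffer follows the same lazy-init pattern as the existing workspace: it is allocated once as a zero-size int8 tensor, so downstream code never has to pass None (which, per the comment in the next hunk, triggers an invalid CUDA free in the fp8 MLA path), and it then lives on the metadata object so the same allocation is reused across graph capture and replays. A minimal standalone sketch of the pattern; the two workspace field names match the ones added in this patch, but the class and helpers are illustrative, not the actual TensorRT-LLM API:

import torch
from dataclasses import dataclass
from typing import Optional


@dataclass(kw_only=True)
class MetadataSketch:
    # The two workspace fields mirror the ones added in this patch;
    # is_cuda_graph stands in for the real metadata's property.
    workspace: Optional[torch.Tensor] = None
    cuda_graph_workspace: Optional[torch.Tensor] = None
    is_cuda_graph: bool = False

    def prepare(self) -> None:
        # Allocate the placeholder exactly once; later calls keep the same
        # tensor, so its device allocation stays stable for graph replays.
        if self.cuda_graph_workspace is None:
            self.cuda_graph_workspace = torch.empty((0, ),
                                                    device='cuda',
                                                    dtype=torch.int8)

    def pick_workspace(self) -> Optional[torch.Tensor]:
        # The same selection that the forward() hunk below performs inline.
        return (self.cuda_graph_workspace
                if self.is_cuda_graph else self.workspace)
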
@@ -1276,8 +1285,9 @@ def forward(
             host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
             host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
             block_ids_per_seq=metadata.block_ids_per_seq,
-            workspace=metadata.
-            workspace,  # re-enable it, if pass None to it, fp8 mla will encounter invalid cuda free issue.
+            # Keep passing a real workspace here: if None is passed, fp8 MLA hits an invalid CUDA free.
+            workspace=metadata.workspace
+            if not metadata.is_cuda_graph else metadata.cuda_graph_workspace,
             cache_indirection=metadata.cache_indirection,
             kv_scale_orig_quant=self.kv_scale_orig_quant,
             kv_scale_quant_orig=self.kv_scale_quant_orig,
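
Why a dedicated buffer is needed at all: kernels captured into a CUDA graph record the device addresses of their arguments, and every replay reuses those addresses verbatim, so any workspace touched during capture must keep the same allocation for the graph's whole lifetime. The regular per-forward workspace may be resized or freed between calls, which would leave a captured graph pointing at stale memory. A short standalone PyTorch demo of that constraint (plain torch.cuda.CUDAGraph; none of this is TensorRT-LLM code):

import torch

# The buffer must outlive every replay of the graph.
persistent_ws = torch.zeros(8, device='cuda')

# Warm up on a side stream before capture, as PyTorch recommends.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    persistent_ws.add_(1)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    # Capture records the device address of persistent_ws; nothing executes yet.
    persistent_ws.add_(1)

g.replay()
g.replay()
torch.cuda.synchronize()
# Each replay re-ran add_ against the same recorded address: values are 3.0
# (1 from warmup + 2 replays; capture itself does not execute the kernel).
print(persistent_ws)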