diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index 0a9c210..47362f6 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -36,7 +36,7 @@ std::vector mha_varlen_fwd( if (out_.has_value()) { out = *out_; } else { - out = torch::zeros_like(q); + out = torch::empty_like(q); } cutlass_chunk_prefill_impl(queue, q, k, v, out, block_table_, cu_seqlens_q, diff --git a/csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp b/csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp index 1f26194..de026a9 100644 --- a/csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp +++ b/csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp @@ -292,10 +292,10 @@ class FMHAPrefillChunk { Tensor mQ_mkl = cute::get_xe_tensor( make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l) - Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape( - seq_len_kv_cache, head_size_qk * num_heads_kv, 1)); // (n_cache,k,l) - Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape( - head_size_vo * num_heads_kv, seq_len_kv_cache, 1)); // (n_cache,k,l) + Tensor mK_cache_nkl = cute::get_xe_tensor( + make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l) + Tensor mV_cache_nkl = cute::get_xe_tensor( + make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l) Tensor mQ_mk = mQ_mkl(_, _, 0); Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k)