Skip to content

Commit 4085f86

Browse files
authored
[CHUNK_PREFILL] fix tiling shape (#51)
* fix tiling shape Signed-off-by: Yizhou Wang <[email protected]> * use empty instead of zero Signed-off-by: Yizhou Wang <[email protected]> --------- Signed-off-by: Yizhou Wang <[email protected]>
1 parent 133e493 commit 4085f86

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

csrc/flash_attn/flash_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ std::vector<at::Tensor> mha_varlen_fwd(
3636
if (out_.has_value()) {
3737
out = *out_;
3838
} else {
39-
out = torch::zeros_like(q);
39+
out = torch::empty_like(q);
4040
}
4141

4242
cutlass_chunk_prefill_impl(queue, q, k, v, out, block_table_, cu_seqlens_q,

csrc/xpu/cutlass_kernels/chunk_prefill_kernel.hpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,10 @@ class FMHAPrefillChunk {
292292

293293
Tensor mQ_mkl = cute::get_xe_tensor(
294294
make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l)
295-
Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(
296-
seq_len_kv_cache, head_size_qk * num_heads_kv, 1)); // (n_cache,k,l)
297-
Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(
298-
head_size_vo * num_heads_kv, seq_len_kv_cache, 1)); // (n_cache,k,l)
295+
Tensor mK_cache_nkl = cute::get_xe_tensor(
296+
make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l)
297+
Tensor mV_cache_nkl = cute::get_xe_tensor(
298+
make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l)
299299

300300
Tensor mQ_mk = mQ_mkl(_, _, 0);
301301
Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k)

0 commit comments

Comments
 (0)