
Commit 2ae719c

feat: added kernel builder for attn

1 parent 83a4415 commit 2ae719c

File tree

6 files changed (+133, −39)

src/kernels/attention/common/fmha_block.h
src/kernels/attention/fmha_runner.h
src/kernels/attention/kernel/builders/kernel_builder_decl.h
src/kernels/attention/kernel/builders/sm120_kernel_builder.inl
src/kernels/attention/kernel/kernel_builder.h
src/kernels/attention/tests/sm120_fmha_test.cu

src/kernels/attention/common/fmha_block.h

Lines changed: 4 additions & 8 deletions
@@ -14,17 +14,13 @@ using namespace cute;
 // AttentionTile specialization for AttentionParams
 template <typename TileShape,  // (BLK_M, BLK_N, BLK_K)
           typename Element,    // Element type
+          typename StrideQ,    // (B, Q, H, D)
+          typename StrideK,    // (B, K, KH, D)
+          typename StrideV,    // (B, K, KH, D)
+          typename StrideO,    // (B, Q, H, D)
           bool kLocal>
 struct FmhaBlock {
-  // (B, Q, H, D)
-  using StrideQ = Stride<int64_t, int64_t, int64_t, _1>;
-  using StrideO = StrideQ;
-  // (B, K, KH, D)
-  using StrideK = Stride<int64_t, int64_t, int64_t, _1>;
-  using StrideV = StrideK;
-
   // Host side parameters
-
   struct Arguments {
     const void* __restrict__ q_ptr;
     const void* __restrict__ k_ptr;
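The strides that used to be hard-wired aliases inside FmhaBlock are now injected as template parameters, so the block no longer assumes a single packed layout. A minimal host-side sketch of how a matching stride value could be built with CuTe for a packed (B, Q, H, D) tensor; the helper name and extent arguments are illustrative, not part of this commit:

#include <cute/layout.hpp>  // cute::make_stride, cute::_1

// Element strides for a packed (B, Q, H, D) tensor with D innermost,
// matching Stride<int64_t, int64_t, int64_t, _1>.
inline auto make_bqhd_stride(int64_t q_len, int64_t n_heads, int64_t head_dim) {
  return cute::make_stride(q_len * n_heads * head_dim,  // batch stride
                           n_heads * head_dim,          // per-token stride
                           head_dim,                    // per-head stride
                           cute::_1{});                 // contiguous head dim
}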

src/kernels/attention/fmha_runner.h

Lines changed: 22 additions & 29 deletions
@@ -5,20 +5,14 @@
 #include <cute/layout.hpp>
 #include <cute/tensor.hpp>
 
-#include "collective/sm120_collective_epilogue.cuh"
-#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
-#include "common/fmha_block.h"
-#include "common/tile_scheduler.cuh"
 #include "device/fmha.cuh"
 #include "fmha_params.h"
-#include "kernel/sm120_kernel_fmha_ws.cuh"
+#include "kernel/kernel_builder.h"  // IWYU pragma: keep
 
 namespace llm {
-// ? Should include ArchTag?
-// * select right kernel based on ArchTag?
-// ? how to support fast compliling?
+// TODO: support fast compiling
 // * only compile the kernel for the target compute capability
-template <typename Element, int kHeadDim>
+template <class ArchTag, typename Element, int kHeadDim>
 class FmhaRunner {
  public:
   static bool run(const FmhaParams& params, cudaStream_t stream = nullptr) {
@@ -64,26 +58,25 @@ class FmhaRunner {
 
     using TileShape = Shape<Int<BLK_M>, Int<BLK_N>, Int<kHeadDim>>;
 
-    using Block = FmhaBlock<TileShape, Element, LOCAL>;
-
-    using CollectiveMainloop = Sm120CollectiveFMhaWs<TileShape,
-                                                     Element,
-                                                     EVEN_K,
-                                                     ALIBI,
-                                                     SOFT_CAP,
-                                                     LOCAL,
-                                                     KV_USE_TMA>;
-    using CollectiveEpilogue =
-        Sm120CollectiveEpilogue<TileShape, Element, EVEN_K>;
-
-    // TODO: support persistent kernels
-    using TileScheduler = SingleTileScheduler;
-
-    using AttnKernel = Sm120KernelFmhaWs<ProblemShape,
-                                         Block,
-                                         CollectiveMainloop,
-                                         CollectiveEpilogue,
-                                         TileScheduler>;
+    // (B, Q, H, D)
+    using StrideQ = Stride<int64_t, int64_t, int64_t, _1>;
+    using StrideK = Stride<int64_t, int64_t, int64_t, _1>;
+    using StrideV = StrideK;
+    using StrideO = StrideQ;
+
+    using AttnKernel = typename KernelBuilder<ArchTag,
+                                              ProblemShape,
+                                              TileShape,
+                                              Element,
+                                              StrideQ,
+                                              StrideK,
+                                              StrideV,
+                                              StrideO,
+                                              EVEN_K,
+                                              ALIBI,
+                                              SOFT_CAP,
+                                              LOCAL,
+                                              KV_USE_TMA>::Kernel;
 
     assert(params.n_heads % params.n_kv_heads == 0 &&
            "n_heads must be divisible by n_kv_heads");
src/kernels/attention/kernel/builders/kernel_builder_decl.h

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+#pragma once
+
+namespace llm {
+
+template <class ArchTag,
+          class ProblemShape,
+          class TileShape,
+          class Element,
+          class StrideQ,
+          class StrideK,
+          class StrideV,
+          class StrideO,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          bool KV_USE_TMA,
+          class Enable = void>
+struct KernelBuilder {
+  static_assert(sizeof(Element) == 0,
+                "Could not build a kernel for given parameters.");
+};
+
+}  // namespace llm
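The primary template exists only to fail loudly: the static_assert condition depends on a template parameter, so it is evaluated only when no partial specialization matches and the primary template itself gets instantiated. A self-contained toy showing the same mechanism; all names here are illustrative:

#include <type_traits>

struct Sm120Tag {};

// Primary template: instantiating it means "no builder matched".
template <class ArchTag, class Element, class Enable = void>
struct ToyBuilder {
  // sizeof(Element) == 0 is always false, but because it depends on a
  // template parameter the assert only fires on instantiation.
  static_assert(sizeof(Element) == 0, "no kernel for these parameters");
};

// Specialization gated on the arch tag and an element-type constraint.
template <class Element>
struct ToyBuilder<Sm120Tag, Element,
                  std::enable_if_t<std::is_floating_point_v<Element>>> {
  using Kernel = void (*)(const Element*, Element*, int);
};

using Ok = ToyBuilder<Sm120Tag, float>::Kernel;     // compiles
// using Bad = ToyBuilder<Sm120Tag, int>::Kernel;   // static_assert fires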
src/kernels/attention/kernel/builders/sm120_kernel_builder.inl

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+#pragma once
+
+// #include <cuda.h>
+// #include <cuda_runtime.h>
+// #include <cutlass/arch/arch.h>
+// #include <cutlass/arch/reg_reconfig.h>
+
+// #include <cute/layout.hpp>
+// #include <cute/tensor.hpp>
+// #include <cutlass/pipeline/pipeline.hpp>
+
+#include <cutlass/arch/arch.h>
+
+#include <cute/tensor.hpp>
+
+#include "collective/sm120_collective_epilogue.cuh"
+#include "collective/sm120_collective_fmha_mainloop_ws.cuh"
+#include "common/fmha_block.h"
+#include "common/tile_scheduler.cuh"
+#include "kernel/sm120_kernel_fmha_ws.cuh"
+#include "kernel_builder_decl.h"
+
+namespace llm {
+
+template <class ProblemShape,
+          class TileShape,
+          class Element,
+          class StrideQ,
+          class StrideK,
+          class StrideV,
+          class StrideO,
+          bool EVEN_K,
+          bool ALIBI,
+          bool SOFT_CAP,
+          bool LOCAL,
+          bool KV_USE_TMA>
+struct KernelBuilder<cutlass::arch::Sm120,
+                     ProblemShape,
+                     TileShape,
+                     Element,
+                     StrideQ,
+                     StrideK,
+                     StrideV,
+                     StrideO,
+                     EVEN_K,
+                     ALIBI,
+                     SOFT_CAP,
+                     LOCAL,
+                     KV_USE_TMA,
+                     cute::enable_if_t<not cute::is_tuple_v<Element>>> {
+  using Block =
+      FmhaBlock<TileShape, Element, StrideQ, StrideK, StrideV, StrideO, LOCAL>;
+
+  using CollectiveMainloop = Sm120CollectiveFMhaWs<TileShape,
+                                                   Element,
+                                                   EVEN_K,
+                                                   ALIBI,
+                                                   SOFT_CAP,
+                                                   LOCAL,
+                                                   KV_USE_TMA>;
+  using CollectiveEpilogue =
+      Sm120CollectiveEpilogue<TileShape, Element, EVEN_K>;
+
+  // TODO: support persistent kernels
+  using TileScheduler = SingleTileScheduler;
+
+  using Kernel = Sm120KernelFmhaWs<ProblemShape,
+                                   Block,
+                                   CollectiveMainloop,
+                                   CollectiveEpilogue,
+                                   TileScheduler>;
+};
+
+}  // namespace llm
src/kernels/attention/kernel/kernel_builder.h

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+#pragma once
+
+// kernel builder declarations
+#include "builders/kernel_builder_decl.h"  // IWYU pragma: keep
+
+// kernel builder implementations
+#include "builders/sm120_kernel_builder.inl"  // IWYU pragma: keep

src/kernels/attention/tests/sm120_fmha_test.cu

Lines changed: 2 additions & 2 deletions
@@ -89,10 +89,10 @@ torch::Tensor sm120_fmha(
           : nullptr;
 
   // params.max_q_len = max_q_len;
-
+  using ArchTag = cutlass::arch::Sm120;
   DISPATCH_TORCH_DTYPE_(query.dtype(), Dtype, [&] {
     DISPATCH_HEAD_DIM_(head_dim, HEAD_DIM, [&] {
-      FmhaRunner<Dtype, HEAD_DIM>::run(params, /*stream=*/nullptr);
+      FmhaRunner<ArchTag, Dtype, HEAD_DIM>::run(params, /*stream=*/nullptr);
     });
   });
   return out;
