
Commit bc02c45

[TK] Add test for fused attention and ops needed (#480)
This patch adds a test for f16 attention. The fused_attention kernel needed dtype conversion, so that is also added. As a side effect, implicit vector conversion (for unary ops) is now also supported.
1 parent d21bb9a · commit bc02c45

File tree: 6 files changed, +194 -39 lines
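The new primitive is a dtype cast: in the test added below, the f32 accumulator block is narrowed to f16 with tkl.to_dtype(qkT, tkl.f16) before the second dot. As a rough analogy for the semantics only (not the kernel API), a plain PyTorch sketch of the narrow/widen casts that the builder lowers to arith.truncf and arith.extf:

import torch

# Analogy only: tkl.to_dtype changes the element type of a value, much like a
# torch dtype cast. Narrowing corresponds to arith.truncf, widening to
# arith.extf in the builder changes below.
acc_f32 = torch.randn(256, 128, dtype=torch.float32)
acc_f16 = acc_f32.to(torch.float16)   # narrow: f32 -> f16
acc_back = acc_f16.to(torch.float32)  # widen:  f16 -> f32
print(acc_f16.dtype, acc_back.dtype)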

core/shark_turbine/kernel/_support/tracing.py

Lines changed: 8 additions & 1 deletion
@@ -203,6 +203,14 @@ def handle_thread_program_id(self, op, axis: int) -> Index:
         )
         return proxy
 
+    def handle_to_dtype(self, op, val, dtype):
+        return self.region_graph.create_proxy(
+            "call_function",
+            op,
+            args=(val, dtype),
+            kwargs={},
+        )
+
     def handle_kernel_buffer_getitem(self, op, kernel_buffer: KernelBuffer, key):
         return self.region_graph.create_proxy(
             "call_function",
@@ -349,7 +357,6 @@ def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions
             i for i in range(len(shape)) if i not in broadcast_dimensions
         )
         permutation = permutation + tuple(broadcast_dimensions)
-        print(permutation)
 
         # Transpose
         return self.region_graph.create_proxy(
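Note that the new tracing handler does not execute the cast; it only records the op as a call_function node in the region graph, mirroring how torch.fx builds graphs. A minimal standalone torch.fx sketch of that recording pattern (to_dtype_stub and this hand-built graph are illustrative stand-ins, not the kernel's RegionGraph API):

import torch
import torch.fx as fx


def to_dtype_stub(val, dtype):
    # Hypothetical stand-in for the traced tk op object; never executed here.
    raise NotImplementedError


# Record the op as a call_function node with its args, roughly the shape of
# what handle_to_dtype does via region_graph.create_proxy("call_function", ...).
graph = fx.Graph()
x = graph.placeholder("x")
casted = graph.call_function(to_dtype_stub, args=(x, torch.float16))
graph.output(casted)
print(graph)  # prints the placeholder, the recorded call_function node, and the output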

core/shark_turbine/kernel/compiler/builder.py

Lines changed: 68 additions & 23 deletions
@@ -32,12 +32,12 @@
     F64Type,
 )
 
-# TODO: Have a way upstream to check if a floating point type.
-FLOAT_TYPES_ASM = {
-    "bf16",
-    "f16",
-    "f32",
-    "f64",
+# TODO: Use FloatType from upstream when available.
+FLOAT_BITWIDTHS = {
+    "bf16": 16,
+    "f16": 16,
+    "f32": 32,
+    "f64": 64,
     # TODO: FP8 types.
 }
 
@@ -87,28 +87,54 @@ def __init__(
 
 class _ScalarBuilder:
     def is_floating_point_type(self, t: IrType) -> bool:
-        return str(t) in FLOAT_TYPES_ASM
+        # TODO: Use FloatType from upstream when available.
+        return str(t) in FLOAT_BITWIDTHS
 
     def is_integer_type(self, t: IrType) -> bool:
         return IntegerType.isinstance(t)
 
     def is_index_type(self, t: IrType) -> bool:
         return IndexType.isinstance(t)
 
-    def promote(self, value: Value, to_type: IrType) -> Value:
-        value_type = value.type
+    def get_typeclass(self, t: IrType, index_same_as_integer=False) -> str:
+        # If this is a vector type, get the element type.
+        if isinstance(t, VectorType):
+            t = t.element_type
+        if self.is_floating_point_type(t):
+            return "float"
+        if self.is_integer_type(t):
+            return "integer"
+        if self.is_index_type(t):
+            return "integer" if index_same_as_integer else "index"
+        raise CodegenError(f"Unknown typeclass for type `{t}`")
+
+    def get_float_bitwidth(self, t: IrType) -> int:
+        # If this is a vector type, get the element type.
+        if isinstance(t, VectorType):
+            t = t.element_type
+        return FLOAT_BITWIDTHS[str(t)]
+
+    def to_dtype(self, value: IRProxyValue, dtype: IrType) -> IRProxyValue:
+        value_type = value.ir_value.type
+        # Create a vector type for dtype if value is a vector.
+        to_type = dtype
+        if isinstance(value_type, VectorType):
+            to_type = VectorType.get(value_type.shape, dtype)
+
         # Short-circuit if already the right type.
         if value_type == to_type:
             return value
 
-        attr_name = f"promote_{value_type}_to_{to_type}"
+        value_typeclass = self.get_typeclass(value_type)
+        to_typeclass = self.get_typeclass(dtype)
+        attr_name = f"to_dtype_{value_typeclass}_to_{to_typeclass}"
         try:
             handler = getattr(self, attr_name)
         except AttributeError:
             raise CodegenError(
                 f"No implemented path to implicitly promote scalar `{value_type}` to `{to_type}` (tried '{attr_name}')"
             )
-        return handler(value, to_type)
+        return IRProxyValue(handler(value.ir_value, to_type))
 
     def constant_attr(self, val: int | float, element_type: IrType) -> Attribute:
         if self.is_integer_type(element_type) or self.is_index_type(element_type):
@@ -153,7 +179,7 @@ def binary_arithmetic(
             f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir_type} and {rhs_ir_type} due to element type mismatch"
         )
 
-        typeclass = "float" if self.is_floating_point_type(lhs_ir_type) else "integer"
+        typeclass = self.get_typeclass(lhs_ir_type, True)
         attr_name = f"binary_{op}_{typeclass}"
         try:
             handler = getattr(self, attr_name)
@@ -176,9 +202,7 @@ def binary_vector_arithmetic(
             f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} due to element type mismatch"
        )
 
-        typeclass = (
-            "float" if self.is_floating_point_type(lhs_element_type) else "integer"
-        )
+        typeclass = self.get_typeclass(lhs_element_type, True)
         attr_name = f"binary_{op}_{typeclass}"
         try:
             handler = getattr(self, attr_name)
@@ -190,7 +214,7 @@ def binary_vector_arithmetic(
 
     def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
         val_ir_type = val.ir_value.type
-        typeclass = "float" if self.is_floating_point_type(val_ir_type) else "integer"
+        typeclass = self.get_typeclass(val_ir_type, True)
         attr_name = f"unary_{op}_{typeclass}"
         try:
             handler = getattr(self, attr_name)
@@ -203,9 +227,7 @@ def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
     def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
         val_ir = val.ir_value
         val_element_type = VectorType(val_ir.type).element_type
-        typeclass = (
-            "float" if self.is_floating_point_type(val_element_type) else "integer"
-        )
+        typeclass = self.get_typeclass(val_element_type, True)
         attr_name = f"unary_{op}_{typeclass}"
         try:
             handler = getattr(self, attr_name)
@@ -217,10 +239,33 @@ def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
 
     ### Specializations
 
-    def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
-        i32_type = IntegerType.get_signless(32)
-        i32 = arith_d.index_cast(i32_type, value)
-        return arith_d.sitofp(to_type, i32)
+    # Casting
+    def to_dtype_index_to_integer(self, value: Value, to_type: IrType) -> Value:
+        return arith_d.index_cast(to_type, value)
+
+    def to_dtype_index_to_float(self, value: Value, to_type: IrType) -> Value:
+        # Cast index to integer, and then ask for a integer to float cast.
+        # TODO: I don't really know how to query the machine bitwidth here,
+        # so using 64.
+        casted_to_int = arith_d.index_cast(IntegerType.get_signless(64), value)
+        return self.to_dtype(IRProxyValue(casted_to_int), to_type).ir_value
+
+    def to_dtype_integer_to_float(self, value: Value, to_type: IrType) -> Value:
+        # sitofp
+        casted_to_float = arith_d.sitofp(to_type, value)
+        return self.to_dtype(IRProxyValue(casted_to_float), to_type).ir_value
+
+    def to_dtype_float_to_float(self, value: Value, to_type: IrType) -> Value:
+        # Check bitwidth to determine if we need to extend or narrow
+        from_type = value.type
+        from_bitwidth = self.get_float_bitwidth(from_type)
+        to_bitwidth = self.get_float_bitwidth(to_type)
+        if from_bitwidth < to_bitwidth:
+            return arith_d.extf(to_type, value)
+        elif from_bitwidth > to_bitwidth:
+            return arith_d.truncf(to_type, value)
+        else:
+            raise CodegenError(f"NYI: Cast from {from_type} to {to_type}")
 
     # Binary integer/integer arithmetic.
     def binary_add_integer(self, lhs: IRProxyValue, rhs: IRProxyValue) -> IRProxyValue:
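The builder resolves these casts the same way it already resolves arithmetic: it derives a typeclass string for source and destination, then looks up a to_dtype_<from>_to_<to> method by name with getattr. A simplified, MLIR-free sketch of that dispatch (CastDispatcher and its handlers are illustrative stand-ins, not the real _ScalarBuilder):

FLOAT_BITWIDTHS = {"bf16": 16, "f16": 16, "f32": 32, "f64": 64}


class CastDispatcher:
    # Toy model of the getattr-based cast dispatch used by _ScalarBuilder.

    def to_dtype(self, value, from_tc, to_tc, from_ty, to_ty):
        attr_name = f"to_dtype_{from_tc}_to_{to_tc}"
        try:
            handler = getattr(self, attr_name)
        except AttributeError:
            raise NotImplementedError(f"no cast path (tried '{attr_name}')")
        return handler(value, from_ty, to_ty)

    def to_dtype_integer_to_float(self, value, from_ty, to_ty):
        return float(value)  # stands in for arith.sitofp

    def to_dtype_float_to_float(self, value, from_ty, to_ty):
        # Mirrors the bitwidth comparison above: widen (extf) or narrow (truncf).
        if FLOAT_BITWIDTHS[from_ty] < FLOAT_BITWIDTHS[to_ty]:
            return ("extf", value)
        if FLOAT_BITWIDTHS[from_ty] > FLOAT_BITWIDTHS[to_ty]:
            return ("truncf", value)
        raise NotImplementedError(f"NYI: cast {from_ty} -> {to_ty}")


d = CastDispatcher()
print(d.to_dtype(3, "integer", "float", "i32", "f32"))  # 3.0, via the sitofp path
print(d.to_dtype(1.5, "float", "float", "f32", "f16"))  # ('truncf', 1.5)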

core/shark_turbine/kernel/compiler/vector_codegen.py

Lines changed: 18 additions & 11 deletions
@@ -358,6 +358,18 @@ def _(emitter: ThreadEmitter, node: fx.Node):
     )
 
 
+@handle_op(tkl.to_dtype)
+def _(emitter: ThreadEmitter, node: fx.Node):
+    try:
+        (val, dtype) = node.args
+    except ValueError as e:
+        raise ValidationError("Malformed arguments") from e
+
+    ir_type = cast_dtype(emitter, dtype)
+    casted = cast_vector(emitter, val, element_type=ir_type)
+    emitter.bind_node_proxy(node, IRProxyValue(casted))
+
+
 @handle_op(ops.kernel_buffer_getitem)
 def _(emitter: ThreadEmitter, node: fx.Node):
     try:
@@ -828,24 +840,19 @@ def cast_vector(
     emitter: ThreadEmitter, value, *, element_type: Optional[IrType] = None
 ):
     proxy_value = cast_py_value(emitter, value)
-    value = proxy_value.ir_value
 
-    # Promote scalar types correctly first.
-    if element_type and not ShapedType.isinstance(value.type):
+    # Cast scalar types correctly first.
+    if element_type and not ShapedType.isinstance(proxy_value.ir_value.type):
         # Implicit scalar type promotion.
-        value = ScalarBuilder.promote(value, element_type)
+        proxy_value = ScalarBuilder.to_dtype(proxy_value, element_type)
+
+    value = proxy_value.ir_value
 
     # After scalar promotion, promote to vector.
     if VectorType.isinstance(value.type):
         # Already a vector. Coerce or return.
         if element_type is not None:
-            vector_type = VectorType(value.type)
-            if vector_type.element_type == element_type:
-                return value
-            # TODO: Implement vector element type conversion.
-            raise CodegenError(
-                f"Implicit conversion of vector element types not supported (`{vector_type.element_type}` -> `{element_type}`)"
-            )
+            value = ScalarBuilder.to_dtype(proxy_value, element_type).ir_value
         # No target element_type.
         return value
     else:

core/shark_turbine/kernel/lang/prims.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
     "broadcast",
     "broadcast_in_dim",
     "transpose",
+    "to_dtype",
 ]
 
 
@@ -31,6 +32,7 @@ def is_debug() -> bool:
 
 # Core language operations
 program_id = ops.thread_program_id
+to_dtype = ops.to_dtype
 
 # Math Operations
 exp2 = ops.exp2

core/shark_turbine/kernel/ops/core.py

Lines changed: 9 additions & 4 deletions
@@ -1,17 +1,17 @@
-from typing import Any
+from typing import Any, TypeVar
 import typing
 
 if typing.TYPE_CHECKING:
     from ..lang.types import Index, Vector
 
-from .base import (
-    define_op,
-)
+from .base import define_op
+from .._support.dtype import DataType
 
 __all__ = [
     "kernel_buffer_getitem",
     "kernel_buffer_setitem",
     "thread_program_id",
+    "to_dtype",
 ]
 
 
@@ -28,3 +28,8 @@ def kernel_buffer_setitem(kernel_buffer, key, item) -> None:
 @define_op
 def thread_program_id(axis: int) -> "Index":
     ...
+
+
+@define_op
+def to_dtype(val, dtype: DataType):
+    ...

Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
+import logging
+import unittest
+
+import torch
+import shark_turbine.kernel as tk
+import shark_turbine.kernel.lang as tkl
+
+BATCH = tkl.sym.BATCH
+N_HEADS = tkl.sym.N_HEADS
+N_CTX = tkl.sym.N_CTX
+D_HEAD = tkl.sym.D_HEAD
+
+BLOCK_N = tkl.sym.BLOCK_N
+BLOCK_M = tkl.sym.BLOCK_M
+
+
+class Test(unittest.TestCase):
+    def testFusedAttention(self):
+        @tk.gen.thread(N_CTX // BLOCK_M, BATCH * N_HEADS)
+        def fused_attention(
+            Q: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD].of(tkl.f16),
+            K: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD].of(tkl.f16),
+            V: tkl.InputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD].of(tkl.f16),
+            O: tkl.OutputBuffer[BATCH, N_HEADS, N_CTX, D_HEAD].of(tkl.f16),
+        ):
+            grid_n = tkl.program_id(0)
+            grid_m = tkl.program_id(1)
+
+            batch = grid_m // N_HEADS
+            head = grid_m % N_HEADS
+
+            q = tkl.load(Q, (batch, head, grid_n * BLOCK_M, 0), (BLOCK_M, D_HEAD))
+            acc_init = tkl.constant((BLOCK_M, D_HEAD), tkl.f32, 0.0)
+            max_stat_init = tkl.constant((BLOCK_M,), tkl.f32, -1e9)
+            sum_stat_init = tkl.constant((BLOCK_M,), tkl.f32, 0.0)
+
+            @tkl.for_loop(
+                0, N_CTX, BLOCK_N, init_args=[max_stat_init, sum_stat_init, acc_init]
+            )
+            def body(i, old_max, old_sum, old_acc):
+                k = tkl.load(K, (batch, head, i, 0), (BLOCK_N, D_HEAD))
+                kT = tkl.transpose(k, (1, 0))
+
+                qkT = tkl.constant((BLOCK_M, BLOCK_N), tkl.f32, 0.0)
+                qkT = tkl.dot(q, kT, qkT)
+
+                new_max = tkl.max(qkT, axis=1, acc=old_max)
+                broadcasted_max = tkl.broadcast_in_dim(
+                    new_max, (BLOCK_M, BLOCK_N), (0,)
+                )
+                partial_softmax = tkl.exp2(qkT - broadcasted_max)
+                scale_factor = tkl.exp2(old_max - new_max)
+                scaled_old_sum = scale_factor * old_sum
+                new_sum = tkl.sum(partial_softmax, axis=1, acc=scaled_old_sum)
+                broadcasted_scale_factor = tkl.broadcast_in_dim(
+                    scale_factor, (BLOCK_M, D_HEAD), (0,)
+                )
+                new_acc = old_acc * broadcasted_scale_factor
+
+                v = tkl.load(V, (batch, head, i, 0), (BLOCK_N, D_HEAD))
+                qkT16 = tkl.to_dtype(qkT, tkl.f16)
+                new_acc = tkl.dot(qkT16, v, new_acc)
+
+                return (new_max, new_sum, new_acc)
+
+            sum_stat = body[1]
+            result = body[2]
+            one = tkl.constant((BLOCK_M,), tkl.f32, 1.0)
+            one_by_sum = one / sum_stat
+            result = tkl.broadcast_in_dim(one_by_sum, (BLOCK_M, D_HEAD), (0,)) * result
+            tkl.store(O, (batch, head, grid_n * BLOCK_M, 0), result)
+
+        Q = torch.randn(4, 48, 1024, 64)
+        K = torch.randn(4, 48, 1024, 64)
+        V = torch.randn(4, 48, 1024, 64)
+        O = torch.randn(4, 48, 1024, 64)
+
+        with tk.gen.TestLaunchContext(
+            {
+                BLOCK_N: 128,
+                BLOCK_M: 256,
+            }
+        ):
+            fused_attention(Q, K, V, O)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.DEBUG)
+    unittest.main()
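The loop body in this test is the streaming (online) softmax recurrence: it carries a running row max, a running sum of exponentials, and an accumulator that is rescaled whenever the max grows. A small PyTorch reference for one batch/head slice (my own sketch, not part of the commit), checked against a dense softmax; the kernel uses exp2, which matches a softmax of qkT scaled by ln 2:

import math

import torch


def streaming_attention(q, k, v, block_n=128):
    # Running statistics per query row: max, sum of exponentials, accumulator.
    m = torch.full((q.shape[0],), -1e9)
    s = torch.zeros(q.shape[0])
    acc = torch.zeros(q.shape[0], v.shape[1])
    for i in range(0, k.shape[0], block_n):
        kb, vb = k[i : i + block_n], v[i : i + block_n]
        qk = q @ kb.T
        new_m = torch.maximum(m, qk.max(dim=1).values)
        p = torch.exp2(qk - new_m[:, None])  # partial softmax numerator
        scale = torch.exp2(m - new_m)        # rescales the old statistics
        s = scale * s + p.sum(dim=1)
        acc = acc * scale[:, None] + p @ vb
        m = new_m
    return acc / s[:, None]


q = torch.randn(256, 64)
k = torch.randn(1024, 64)
v = torch.randn(1024, 64)
ref = torch.softmax((q @ k.T) * math.log(2.0), dim=-1) @ v
assert torch.allclose(streaming_attention(q, k, v), ref, atol=1e-3)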
