
Commit c1acd99 (parent: 971231c)
Author: Prashant Kumar

Add comprehensive tests to exercise the kernel across available dtypes.

Added softmax and gemm kernels to test across the available float and int dtypes.

File tree

7 files changed: +208 −0 lines

core/shark_turbine/kernel/_support/dtype.py
core/shark_turbine/kernel/_support/tracing.py
core/shark_turbine/kernel/compiler/builder.py
core/shark_turbine/kernel/compiler/vector_codegen.py
core/shark_turbine/kernel/lang/prims.py
core/shark_turbine/kernel/ops/math.py
core/tests/kernel/coverage_test.py

core/shark_turbine/kernel/_support/dtype.py

Lines changed: 21 additions & 0 deletions

@@ -16,6 +16,19 @@
 _FLOAT_TYPES = ["f16", "f32", "f64"]
 _INDEX_TYPES = ["index"]

+import torch
+
+_TKL_TO_TORCH_DTYPE = {
+    "f16": torch.half,
+    "f32": torch.float,
+    "f64": torch.double,
+    "i1": torch.bool,
+    "i8": torch.int8,
+    "i16": torch.int16,
+    "i32": torch.int32,
+    "i64": torch.int64,
+}
+
 # TODO: this should really be a type.
 class DataType:
@@ -44,6 +57,14 @@ def is_float_asm(self):
     def is_index_asm(self):
         return self._name in _INDEX_TYPES

+    def to_torch_type(self):
+        try:
+            return _TKL_TO_TORCH_DTYPE[self._name]
+        except KeyError:
+            print(
+                f"The support for '{self._name}' dtype to torch type isn't implemented."
+            )
+
 bool = DataType("bool", "i1")
 i4 = DataType("i4")
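The new to_torch_type() helper is what lets the coverage tests below move tensors to the matching torch dtype. A minimal usage sketch (assuming, as the tests do, that the DataType instances are re-exported via shark_turbine.kernel.lang; using i4 as the unmapped example is an assumption):

import torch
import shark_turbine.kernel.lang as tkl

# Mapped dtypes resolve to the corresponding torch dtype object.
assert tkl.f16.to_torch_type() is torch.half
assert tkl.i32.to_torch_type() is torch.int32

# An unmapped dtype such as i4 prints the "isn't implemented" message; the
# except branch has no return, so the call implicitly yields None.
assert tkl.i4.to_torch_type() is None

Note that falling through to None means a missing mapping only surfaces downstream; raising in the except branch would fail faster.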

core/shark_turbine/kernel/_support/tracing.py

Lines changed: 8 additions & 0 deletions

@@ -286,6 +286,14 @@ def handle_exp2(self, op, val):
             kwargs={},
         )

+    def handle_rsqrt(self, op, val):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(val,),
+            kwargs={},
+        )
+
     def handle_vector_constant(
         self, op, shape: Tuple[int, ...], dtype, value: int | float
     ):
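handle_rsqrt mirrors the existing handle_exp2 hook: during tracing, a tkl.rsqrt call is recorded as a call_function proxy node in the region graph rather than being executed. For intuition only, a plain torch.fx analogue (not shark_turbine code) that produces the same kind of node:

import torch
import torch.fx as fx

def f(x):
    # Stand-in for tkl.rsqrt inside a traced kernel body.
    return torch.rsqrt(x)

gm = fx.symbolic_trace(f)
# The traced graph holds a single call_function node targeting torch.rsqrt,
# the same node shape handle_rsqrt creates via region_graph.create_proxy.
print(gm.graph)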

core/shark_turbine/kernel/compiler/builder.py

Lines changed: 3 additions & 0 deletions

@@ -306,5 +306,8 @@ def binary_truediv_float(
     def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue:
         return IRProxyValue(math_d.exp2(val.ir_value))

+    def unary_rsqrt_float(self, val: IRProxyValue) -> IRProxyValue:
+        return IRProxyValue(math_d.rsqrt(val.ir_value))
+

 ScalarBuilder = _ScalarBuilder()

core/shark_turbine/kernel/compiler/vector_codegen.py

Lines changed: 1 addition & 0 deletions

@@ -238,6 +238,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):

 UNARY_ARITHMETIC_OPS = [
     (tkl.exp2, "exp2"),
+    (tkl.rsqrt, "rsqrt"),
 ]
core/shark_turbine/kernel/lang/prims.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
     "broadcast_in_dim",
     "transpose",
     "to_dtype",
+    "rsqrt",
 ]

@@ -37,6 +38,7 @@ def is_debug() -> bool:
 # Math Operations
 exp2 = ops.exp2
 constant = ops.vector_constant
+rsqrt = ops.rsqrt

 # Reduction Operations
 max = ops.vector_max

core/shark_turbine/kernel/ops/math.py

Lines changed: 6 additions & 0 deletions

@@ -10,6 +10,7 @@

 __all__ = [
     "exp2",
+    "rsqrt",
     "vector_constant",
 ]

@@ -22,3 +23,8 @@ def exp2(val):
 @define_op
 def vector_constant(shape: Tuple[int, ...], dtype, value: int | float) -> "Vector":
     ...
+
+
+@define_op
+def rsqrt(val):
+    ...
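With this stub in place, rsqrt follows the same four-step recipe the codebase already uses for exp2: declare the op here with @define_op, record it during tracing (handle_rsqrt), lower it in the MLIR builder (unary_rsqrt_float via math_d.rsqrt), and register it in the codegen dispatch table (UNARY_ARITHMETIC_OPS). A self-contained toy illustration of that registration-plus-dispatch pattern (not shark_turbine code; all names below are invented for the sketch):

import math

OP_TABLE = {}

def define_op(f):
    # Toy analogue of the repo's define_op decorator: register the stub by name.
    OP_TABLE[f.__name__] = f
    return f

@define_op
def rsqrt(val):
    ...

class ToyScalarBuilder:
    # Resolved by name, like _ScalarBuilder's unary_<op>_<type> methods.
    def unary_rsqrt_float(self, val: float) -> float:
        return 1.0 / math.sqrt(val)

UNARY_ARITHMETIC_OPS = [(rsqrt, "rsqrt")]

builder = ToyScalarBuilder()
op, name = UNARY_ARITHMETIC_OPS[0]
assert op is OP_TABLE["rsqrt"]
assert abs(getattr(builder, f"unary_{name}_float")(4.0) - 0.5) < 1e-9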

core/tests/kernel/coverage_test.py

Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
+import torch
+import shark_turbine.kernel as tk
+import shark_turbine.kernel.lang as tkl
+import pytest
+
+
+FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
+INT_DTYPES = [
+    tkl.bool,
+    tkl.i8,
+    tkl.i16,
+    tkl.i32,
+    tkl.i64,
+]
+
+
+def rms_norm_krnl(dtype, input, weight, output):
+    M = tkl.sym.M
+    K = tkl.sym.K
+
+    @tk.gen.thread(M)
+    def rms_norm_kernel(
+        input: tkl.InputBuffer[M, K, dtype],
+        weight: tk.lang.InputBuffer[M, K, dtype],
+        output: tk.lang.OutputBuffer[M, K, dtype],
+    ):
+        row_index = tk.lang.program_id(0)
+        eps = tkl.constant((1,), dtype, 0.00001)
+        zero = tkl.constant((1,), dtype, 0.0)
+        input_row = input[row_index, :]
+        sq_inp = input_row * input_row
+        sq_inp_red = tkl.sum(sq_inp)
+        # TODO: The input_row * zero is just dummy computation to pass in the right shapes,
+        # otherwise it leads to 'error: unknown: 'math.exp2' op operand #0 must be floating-point-like, but got 'vector<f16>'
+        denom = tkl.rsqrt(input_row * zero + sq_inp_red)
+        denom_eta = denom + eps
+        output[row_index, :] = denom_eta * input_row * weight[row_index, :]
+
+    with tk.gen.TestLaunchContext():
+        rms_norm_kernel(input, weight, output)
+
+
+def iota_krnl(dtype, input):
+    M = tkl.sym.M
+
+    @tk.gen.thread(M)
+    def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
+        a = (
+            tkl.constant((17, 37, 19), dtype, 5)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 5.0)
+        )
+        b = (
+            tkl.constant((17, 37, 19), dtype, 10)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 10.0)
+        )
+        c = (
+            tkl.constant((17, 37, 19), dtype, 2)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 2.0)
+        )
+        if dtype in INT_DTYPES:
+            c = (a * b) // c
+        else:
+            c = (a * b) / c
+        c = c + a - b
+
+    with tk.gen.TestLaunchContext():
+        iota_kernel(input)
+
+
+def softmax_krnl(dtype, input, output):
+    M = tkl.sym.M
+    K = tkl.sym.K
+
+    @tk.gen.thread(M)
+    def softmax_kernel(
+        input: tk.lang.InputBuffer[M, K, dtype],
+        output: tk.lang.OutputBuffer[M, K, dtype],
+    ):
+        row_index = tk.lang.program_id(0)
+        input_row = input[row_index, :]
+        numerator = tkl.exp2(input_row - tkl.max(input_row))
+        if dtype in INT_DTYPES:
+            output_row = numerator // tkl.sum(numerator)
+        else:
+            output_row = numerator / tkl.sum(numerator)
+        output[row_index, :] = output_row
+
+    with tk.gen.TestLaunchContext():
+        softmax_kernel(input, output)
+
+
+def gemm_fx_kernel(dtype, A, B, output):
+    N = tkl.sym.N
+    M = tkl.sym.M
+    K = tkl.sym.K
+    BLOCK_SIZE = tkl.sym.BLOCK_SIZE
+
+    @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
+    def gemm_kernel(
+        A: tkl.InputBuffer[N, K, dtype],
+        B: tkl.InputBuffer[K, M, dtype],
+        output: tkl.OutputBuffer[N, M, dtype],
+    ):
+        grid_n = tkl.program_id(0)
+        grid_m = tkl.program_id(1)
+
+        acc = None
+        # TODO: Only considering the float and integer cases.
+        if dtype in INT_DTYPES:
+            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
+        else:
+            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)
+
+        @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
+        def body(i, c):
+            a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
+            b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
+            return (tkl.dot(a, b, c),)
+
+        tkl.store(output, (grid_n, grid_m), body[0])
+
+    with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
+        gemm_kernel(A, B, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
+)
+def test_iota_krnl(dtype):
+    input = torch.zeros(17)
+    iota_krnl(dtype, input)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES],
+)
+def test_rms_norm_krnl(dtype):
+    input = torch.randn(128, 64).to(dtype.to_torch_type())
+    weight = torch.randn(128, 64).to(dtype.to_torch_type())
+    output = torch.randn(128, 64).to(dtype.to_torch_type())
+    rms_norm_krnl(dtype, input, weight, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES],
+)
+def test_softmax_krnl(dtype):
+    input = torch.randn(128, 64).to(dtype.to_torch_type())
+    output = torch.randn(128, 64).to(dtype.to_torch_type())
+    softmax_krnl(dtype, input, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
+)
+def test_gemm_krnl(dtype):
+    A = torch.randn(512, 1024).to(dtype.to_torch_type())
+    B = torch.randn(1024, 2048).to(dtype.to_torch_type())
+    output = torch.zeros(512, 2048).to(dtype.to_torch_type())
+    gemm_fx_kernel(dtype, A, B, output)
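Each parametrized test expands to one case per dtype, so test_gemm_krnl alone covers all three float and five int dtypes. A minimal sketch for running the new file programmatically (equivalent to invoking pytest -v core/tests/kernel/coverage_test.py from the repository root):

import sys
import pytest

# pytest.main takes CLI-style arguments and returns the exit code.
sys.exit(pytest.main(["-v", "core/tests/kernel/coverage_test.py"]))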
