diff --git a/core/shark_turbine/kernel/_support/dtype.py b/core/shark_turbine/kernel/_support/dtype.py
index 2b56012ef..c0b233c8f 100644
--- a/core/shark_turbine/kernel/_support/dtype.py
+++ b/core/shark_turbine/kernel/_support/dtype.py
@@ -16,6 +16,19 @@
 _FLOAT_TYPES = ["f16", "f32", "f64"]
 _INDEX_TYPES = ["index"]
 
+import torch
+
+_TKL_TO_TORCH_DTYPE = {
+    "f16": torch.half,
+    "f32": torch.float,
+    "f64": torch.double,
+    "i1": torch.bool,
+    "i8": torch.int8,
+    "i16": torch.int16,
+    "i32": torch.int32,
+    "i64": torch.int64,
+}
+
 
 # TODO: this should really be a type.
 class DataType:
@@ -44,6 +57,14 @@ def is_float_asm(self):
     def is_index_asm(self):
         return self._name in _INDEX_TYPES
 
+    def to_torch_type(self):
+        try:
+            return _TKL_TO_TORCH_DTYPE[self._name]
+        except KeyError:
+            print(
+                f"The support for '{self._name}' dtype to torch type isn't implemented."
+            )
+
 
 bool = DataType("bool", "i1")
 i4 = DataType("i4")
diff --git a/core/shark_turbine/kernel/_support/tracing.py b/core/shark_turbine/kernel/_support/tracing.py
index 6d550d652..52b2dca84 100644
--- a/core/shark_turbine/kernel/_support/tracing.py
+++ b/core/shark_turbine/kernel/_support/tracing.py
@@ -286,6 +286,14 @@ def handle_exp2(self, op, val):
             kwargs={},
         )
 
+    def handle_rsqrt(self, op, val):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(val,),
+            kwargs={},
+        )
+
     def handle_vector_constant(
         self, op, shape: Tuple[int, ...], dtype, value: int | float
     ):
diff --git a/core/shark_turbine/kernel/compiler/builder.py b/core/shark_turbine/kernel/compiler/builder.py
index c7231250b..ee6030a70 100644
--- a/core/shark_turbine/kernel/compiler/builder.py
+++ b/core/shark_turbine/kernel/compiler/builder.py
@@ -306,5 +306,8 @@ def binary_truediv_float(
     def unary_exp2_float(self, val: IRProxyValue) -> IRProxyValue:
         return IRProxyValue(math_d.exp2(val.ir_value))
 
+    def unary_rsqrt_float(self, val: IRProxyValue) -> IRProxyValue:
+        return IRProxyValue(math_d.rsqrt(val.ir_value))
+
 
 ScalarBuilder = _ScalarBuilder()
diff --git a/core/shark_turbine/kernel/compiler/vector_codegen.py b/core/shark_turbine/kernel/compiler/vector_codegen.py
index 5ecb8ef93..9bdbc8a4b 100644
--- a/core/shark_turbine/kernel/compiler/vector_codegen.py
+++ b/core/shark_turbine/kernel/compiler/vector_codegen.py
@@ -238,6 +238,7 @@ def _(emitter: ThreadEmitter, node: fx.Node):
 
 
 UNARY_ARITHMETIC_OPS = [
     (tkl.exp2, "exp2"),
+    (tkl.rsqrt, "rsqrt"),
 ]
 
diff --git a/core/shark_turbine/kernel/lang/prims.py b/core/shark_turbine/kernel/lang/prims.py
index b9f163085..f0a4fa0b9 100644
--- a/core/shark_turbine/kernel/lang/prims.py
+++ b/core/shark_turbine/kernel/lang/prims.py
@@ -22,6 +22,7 @@
     "broadcast_in_dim",
     "transpose",
     "to_dtype",
+    "rsqrt",
 ]
 
 
@@ -37,6 +38,7 @@ def is_debug() -> bool:
 # Math Operations
 exp2 = ops.exp2
 constant = ops.vector_constant
+rsqrt = ops.rsqrt
 
 # Reduction Operations
 max = ops.vector_max
diff --git a/core/shark_turbine/kernel/ops/math.py b/core/shark_turbine/kernel/ops/math.py
index 0b617baa5..57ac1888f 100644
--- a/core/shark_turbine/kernel/ops/math.py
+++ b/core/shark_turbine/kernel/ops/math.py
@@ -10,6 +10,7 @@
 __all__ = [
     "exp2",
+    "rsqrt",
     "vector_constant",
 ]
 
 
@@ -22,3 +23,8 @@ def exp2(val):
 @define_op
 def vector_constant(shape: Tuple[int, ...], dtype, value: int | float) -> "Vector":
     ...
+
+
+@define_op
+def rsqrt(val):
+    ...
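For reference, a minimal usage sketch (not part of the diff) of the DataType.to_torch_type() helper added above; it assumes a standard torch install and relies only on the mappings listed in _TKL_TO_TORCH_DTYPE:

    import torch
    import shark_turbine.kernel.lang as tkl

    # "f32" and "i32" map to the torch aliases stored in _TKL_TO_TORCH_DTYPE.
    assert tkl.f32.to_torch_type() is torch.float
    assert tkl.i32.to_torch_type() is torch.int32
    x = torch.randn(4, 4).to(tkl.f16.to_torch_type())  # torch.half tensor

    # Note: any dtype missing from the table (e.g. i4) currently prints a
    # message inside to_torch_type() and implicitly returns None, rather
    # than raising.
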
diff --git a/core/tests/kernel/coverage_test.py b/core/tests/kernel/coverage_test.py
new file mode 100644
index 000000000..05a58946b
--- /dev/null
+++ b/core/tests/kernel/coverage_test.py
@@ -0,0 +1,167 @@
+import torch
+import shark_turbine.kernel as tk
+import shark_turbine.kernel.lang as tkl
+import pytest
+
+
+FLOAT_DTYPES = [tkl.f16, tkl.f32, tkl.f64]
+INT_DTYPES = [
+    tkl.bool,
+    tkl.i8,
+    tkl.i16,
+    tkl.i32,
+    tkl.i64,
+]
+
+
+def rms_norm_krnl(dtype, input, weight, output):
+    M = tkl.sym.M
+    K = tkl.sym.K
+
+    @tk.gen.thread(M)
+    def rms_norm_kernel(
+        input: tkl.OutputBuffer[M, K, dtype],
+        weight: tk.lang.InputBuffer[M, K, dtype],
+        output: tk.lang.OutputBuffer[M, K, dtype],
+    ):
+        row_index = tk.lang.program_id(0)
+        eps = tkl.constant((1,), dtype, 0.00001)
+        zero = tkl.constant((1,), dtype, 0.0)
+        input_row = input[row_index, :]
+        sq_inp = input_row * input_row
+        sq_inp_red = tkl.sum(sq_inp)
+        # TODO: The input_row * zero is just dummy computation to pass in the right shapes,
+        # otherwise it leads to 'error: unknown: 'math.exp2' op operand #0 must be floating-point-like, but got 'vector'
+        denom = tkl.rsqrt(input_row * zero + sq_inp_red)
+        denom_eta = denom + eps
+        output[row_index, :] = denom_eta * input_row * weight[row_index, :]
+
+    with tk.gen.TestLaunchContext():
+        rms_norm_kernel(input, weight, output)
+
+
+def iota_krnl(dtype, input):
+    M = tkl.sym.M
+
+    @tk.gen.thread(M)
+    def iota_kernel(out: tkl.OutputBuffer[M, dtype]):
+        a = (
+            tkl.constant((17, 37, 19), dtype, 5)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 5.0)
+        )
+        b = (
+            tkl.constant((17, 37, 19), dtype, 10)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 10.0)
+        )
+        c = (
+            tkl.constant((17, 37, 19), dtype, 2)
+            if dtype in INT_DTYPES
+            else tkl.constant((17, 37, 19), dtype, 2.0)
+        )
+        if dtype in INT_DTYPES:
+            c = (a * b) // c
+        else:
+            c = (a * b) / c
+        c = c + a - b
+
+    with tk.gen.TestLaunchContext():
+        iota_kernel(input)
+
+
+def softmax_krnl(dtype, input, output):
+    M = tkl.sym.M
+    K = tkl.sym.K
+
+    @tk.gen.thread(M)
+    def softmax_kernel(
+        input: tk.lang.InputBuffer[M, K, dtype],
+        output: tk.lang.OutputBuffer[M, K, dtype],
+    ):
+        row_index = tk.lang.program_id(0)
+        input_row = input[row_index, :]
+        numerator = tkl.exp2(input_row - tkl.max(input_row))
+        if dtype in INT_DTYPES:
+            output_row = numerator // tkl.sum(numerator)
+        else:
+            output_row = numerator / tkl.sum(numerator)
+        output[row_index, :] = output_row
+
+    with tk.gen.TestLaunchContext():
+        softmax_kernel(input, output)
+
+
+def gemm_fx_kernel(dtype, A, B, output):
+    N = tkl.sym.N
+    M = tkl.sym.M
+    K = tkl.sym.K
+    BLOCK_SIZE = tkl.sym.BLOCK_SIZE
+
+    @tk.gen.thread(N // BLOCK_SIZE, M // BLOCK_SIZE)
+    def gemm_kernel(
+        A: tkl.InputBuffer[N, K, dtype],
+        B: tkl.InputBuffer[K, M, dtype],
+        output: tkl.OutputBuffer[N, M, dtype],
+    ):
+        grid_n = tkl.program_id(0)
+        grid_m = tkl.program_id(1)
+
+        acc = None
+        # TODO: Only considering the float and integer cases.
+        if dtype in INT_DTYPES:
+            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0)
+        else:
+            acc = tkl.constant((BLOCK_SIZE, BLOCK_SIZE), dtype, 0.0)
+
+        @tkl.for_loop(0, K // BLOCK_SIZE, init_args=[acc])
+        def body(i, c):
+            a = tkl.load(A, (grid_n, i * BLOCK_SIZE), (BLOCK_SIZE, BLOCK_SIZE))
+            b = tkl.load(B, (i * BLOCK_SIZE, grid_m), (BLOCK_SIZE, BLOCK_SIZE))
+            return (tkl.dot(a, b, c),)
+
+        tkl.store(output, (grid_n, grid_m), body[0])
+
+    with tk.gen.TestLaunchContext({BLOCK_SIZE: 32}):
+        gemm_kernel(A, B, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
+)
+def test_iota_krnl(dtype):
+    input = torch.zeros(17)
+    iota_krnl(dtype, input)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES],
+)
+def test_rms_norm_krnl(dtype):
+    input = torch.randn(128, 64).to(dtype.to_torch_type())
+    weight = torch.randn(128, 64).to(dtype.to_torch_type())
+    output = torch.randn(128, 64).to(dtype.to_torch_type())
+    rms_norm_krnl(dtype, input, weight, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES],
+)
+def test_softmax_krnl(dtype):
+    input = torch.randn(128, 64).to(dtype.to_torch_type())
+    output = torch.randn(128, 64).to(dtype.to_torch_type())
+    softmax_krnl(dtype, input, output)
+
+
+@pytest.mark.parametrize(
+    ("dtype",),
+    [(x,) for x in FLOAT_DTYPES + INT_DTYPES],
+)
+def test_gemm_krnl(dtype):
+    A = torch.randn(512, 1024).to(dtype.to_torch_type())
+    B = torch.randn(1024, 2048).to(dtype.to_torch_type())
+    output = torch.zeros(512, 2048).to(dtype.to_torch_type())
+    gemm_fx_kernel(dtype, A, B, output)
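
The file above is a plain pytest module, so the new coverage tests can be run with pytest against core/tests/kernel/coverage_test.py. As a minimal standalone sketch (assumption: run in the same environment as the tests, reusing the softmax_krnl helper defined in this file), a single kernel launch without pytest looks like:

    import torch
    import shark_turbine.kernel.lang as tkl

    # Mirrors test_softmax_krnl for the f32 case; the kernel is launched
    # under tk.gen.TestLaunchContext inside softmax_krnl.
    input = torch.randn(128, 64).to(tkl.f32.to_torch_type())
    output = torch.zeros(128, 64).to(tkl.f32.to_torch_type())
    softmax_krnl(tkl.f32, input, output)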