1 | | -# -*- coding: utf-8 -*- |
2 | | - |
3 | | -""" |
4 | | -Using User-Defined Triton Kernels with ``torch.compile`` |
5 | | -========================================================= |
6 | | -**Author:** `Oguz Ulgen <https://github.com/oulgen>`_ |
7 | | -""" |
8 | | - |
9 | | -###################################################################### |
10 | | -# User-defined Triton kernels can be used to optimize specific parts of your |
11 | | -# model's computation. These kernels are written in Triton's language, which is designed |
12 | | -# to make it easier to achieve peak hardware performance. By using user-defined Triton |
13 | | -# kernels with ``torch.compile``, you can integrate these optimized computations into |
14 | | -# your PyTorch model, potentially achieving significant performance improvements. |
15 | | -# |
16 | | -# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``. |
17 | | -# |
18 | | -# Prerequisites |
19 | | -# ------------------- |
20 | | -# |
21 | | -# Before starting this recipe, make sure that you have the following: |
22 | | -# |
23 | | -# * Basic understanding of ``torch.compile`` and Triton. See: |
24 | | -# |
25 | | -# * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__ |
26 | | -# * `Introduction to torch.compile <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`__ |
27 | | -# * `Triton language documentation <https://triton-lang.org/main/index.html>`__ |
28 | | -# |
29 | | -# * PyTorch 2.3 or later |
30 | | -# * A GPU that supports Triton |
31 | | -# |
32 | | - |
33 | | -import torch |
34 | | -from torch.utils._triton import has_triton |
35 | | - |
36 | | -###################################################################### |
37 | | -# Basic Usage |
38 | | -# -------------------- |
39 | | -# |
40 | | -# In this example, we will use a simple vector addition kernel from the Triton documentation |
41 | | -# with ``torch.compile``. |
42 | | -# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__. |
43 | | -# |
44 | | - |
45 | | -if not has_triton(): |
46 | | - print("Skipping because triton is not supported on this device.") |
47 | | -else: |
48 | | - import triton |
49 | | - from triton import language as tl |
50 | | - |
51 | | - @triton.jit |
52 | | - def add_kernel( |
53 | | - in_ptr0, |
54 | | - in_ptr1, |
55 | | - out_ptr, |
56 | | - n_elements, |
57 | | - BLOCK_SIZE: "tl.constexpr", |
58 | | - ): |
59 | | - pid = tl.program_id(axis=0) |
60 | | - block_start = pid * BLOCK_SIZE |
61 | | - offsets = block_start + tl.arange(0, BLOCK_SIZE) |
62 | | - mask = offsets < n_elements |
63 | | - x = tl.load(in_ptr0 + offsets, mask=mask) |
64 | | - y = tl.load(in_ptr1 + offsets, mask=mask) |
65 | | - output = x + y |
66 | | - tl.store(out_ptr + offsets, output, mask=mask) |
67 | | - |
68 | | - @torch.compile(fullgraph=True) |
69 | | - def add_fn(x, y): |
70 | | - output = torch.zeros_like(x) |
71 | | - n_elements = output.numel() |
72 | | - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
73 | | - add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4) |
74 | | - return output |
75 | | - |
76 | | - x = torch.randn(4, device="cuda") |
77 | | - y = torch.randn(4, device="cuda") |
78 | | - out = add_fn(x, y) |
79 | | - print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") |
80 | | - |
81 | | -###################################################################### |
82 | | -# Advanced Usage |
83 | | -# ------------------------------------------------------------------- |
84 | | -# |
85 | | -# Triton's autotune feature is a powerful tool that automatically optimizes the configuration |
86 | | -# parameters of your Triton kernels. It explores a range of possible configurations and |
87 | | -# selects the one that delivers the best performance for your specific use case. |
88 | | -# |
89 | | -# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch |
90 | | -# model is running as efficiently as possible. Here is an example of using ``torch.compile`` |
91 | | -# and ``triton.autotune``. |
92 | | -# |
93 | | -# .. note:: |
94 | | -# |
95 | | -# ``torch.compile`` only supports configs and key arguments to ``triton.autotune``. |
96 | | - |
97 | | -if not has_triton(): |
98 | | - print("Skipping because triton is not supported on this device.") |
99 | | -else: |
100 | | - import triton |
101 | | - from triton import language as tl |
102 | | - |
103 | | - @triton.autotune( |
104 | | - configs=[ |
105 | | - triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8), |
106 | | - triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4), |
107 | | - triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8), |
108 | | - triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4), |
109 | | - ], |
110 | | - key=[], |
111 | | - ) |
112 | | - @triton.jit |
113 | | - def add_kernel_autotuned( |
114 | | - in_ptr0, |
115 | | - in_ptr1, |
116 | | - out_ptr, |
117 | | - n_elements, |
118 | | - BLOCK_SIZE: "tl.constexpr", |
119 | | - ): |
120 | | - pid = tl.program_id(axis=0) |
121 | | - block_start = pid * BLOCK_SIZE |
122 | | - offsets = block_start + tl.arange(0, BLOCK_SIZE) |
123 | | - mask = offsets < n_elements |
124 | | - x = tl.load(in_ptr0 + offsets, mask=mask) |
125 | | - y = tl.load(in_ptr1 + offsets, mask=mask) |
126 | | - output = x + y |
127 | | - tl.store(out_ptr + offsets, output, mask=mask) |
128 | | - |
129 | | - @torch.compile(fullgraph=True) |
130 | | - def add_fn(x, y): |
131 | | - output = torch.zeros_like(x) |
132 | | - n_elements = output.numel() |
133 | | - grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
134 | | - add_kernel_autotuned[grid](x, y, output, n_elements) |
135 | | - return output |
136 | | - |
137 | | - x = torch.randn(4, device="cuda") |
138 | | - y = torch.randn(4, device="cuda") |
139 | | - out = add_fn(x, y) |
140 | | - print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") |
141 | | - |
142 | | -###################################################################### |
143 | | -# Composability and Limitations |
144 | | -# -------------------------------------------------------------------- |
145 | | -# |
146 | | -# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile`` |
147 | | -# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor. |
148 | | -# You can use these features together to build complex, high-performance models. |
149 | | -# |
150 | | -# However, there are certain limitations to be aware of: |
151 | | -# |
152 | | -# * **Tensor Subclasses:** Currently, there is no support for |
153 | | -# tensor subclasses and other advanced features. |
154 | | -# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or |
155 | | -# before ``triton.autotune``, it cannot be used after ``triton.autotune``. This |
156 | | -# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used |
157 | | -# together, ``triton.heuristics`` must be used first. |
158 | | -# |
159 | | -# Conclusion |
160 | | -# ----------- |
161 | | -# In this recipe, we explored how to utilize user-defined Triton kernels |
162 | | -# with ``torch.compile``. We delved into the basic usage of a simple |
163 | | -# vector addition kernel and advanced usage involving Triton's autotune |
164 | | -# feature. We also discussed the composability of user-defined Triton |
165 | | -# kernels with other PyTorch features and highlighted some current limitations. |
166 | | -# |
167 | | -# See Also |
168 | | -# --------- |
169 | | -# |
170 | | -# * `Compiling the Optimizers <https://tutorials.pytorch.kr/recipes/compiling_optimizer.html>`__ |
171 | | -# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://tutorials.pytorch.kr/intermediate/scaled_dot_product_attention_tutorial.html>`__ |
| 1 | +# -*- coding: utf-8 -*- |
| 2 | + |
| 3 | +""" |
| 4 | +Using User-Defined Triton Kernels with ``torch.compile`` |
| 5 | +========================================================= |
| 6 | +**Author:** `Oguz Ulgen <https://github.com/oulgen>`_ |
| 7 | +**Translated by:** `구경선 <https://github.com/kookyungseon>`_, `이채운 <https://github.com/dlcodns>`_ |
| 8 | +""" |
| 9 | + |
| 10 | +###################################################################### |
| 11 | +# User-defined Triton kernels can be used to optimize specific parts of your |
| 12 | +# model's computation. These kernels are written in Triton's language, which is designed |
| 13 | +# to make it easier to achieve peak hardware performance. By using user-defined Triton |
| 14 | +# kernels with ``torch.compile``, you can integrate these optimized computations into |
| 15 | +# your PyTorch model, potentially achieving significant performance improvements. |
| 16 | +# |
| 17 | +# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``. |
| 18 | +# |
| 19 | +# Prerequisites |
| 20 | +# ------------------- |
| 21 | +# |
| 22 | +# Before starting this recipe, make sure that you have the following: |
| | +# |
| 23 | +# * Basic understanding of ``torch.compile`` and Triton. See: |
| 24 | +# |
| 25 | +# * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__ |
| 26 | +# * `Introduction to torch.compile <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`__ |
| 27 | +# * `Triton language documentation <https://triton-lang.org/main/index.html>`__ |
| 28 | +# |
| 29 | +# * PyTorch 2.3 or later |
| 30 | +# * A GPU that supports Triton |
| 31 | +# |
| 32 | + |
| 33 | +import torch |
| 34 | +from torch.utils._triton import has_triton |
| 35 | + |
| 36 | +###################################################################### |
| 37 | +# Basic Usage |
| 38 | +# -------------------- |
| 39 | +# |
| 40 | +# In this example, we will use a simple vector addition kernel from the Triton documentation |
| 41 | +# with ``torch.compile``. |
| 42 | +# For reference, see the `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__. |
| 43 | +# |
| 44 | + |
| 45 | +if not has_triton(): |
| 46 | + print("Skipping because triton is not supported on this device.") |
| 47 | +else: |
| 48 | + import triton |
| 49 | + from triton import language as tl |
| 50 | + |
| 51 | + @triton.jit |
| 52 | + def add_kernel( |
| 53 | + in_ptr0, |
| 54 | + in_ptr1, |
| 55 | + out_ptr, |
| 56 | + n_elements, |
| 57 | + BLOCK_SIZE: "tl.constexpr", |
| 58 | + ): |
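| | +        # Each program instance computes one BLOCK_SIZE-wide slice of the |
| | +        # output; the mask keeps out-of-range lanes from touching memory. |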
| 59 | + pid = tl.program_id(axis=0) |
| 60 | + block_start = pid * BLOCK_SIZE |
| 61 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 62 | + mask = offsets < n_elements |
| 63 | + x = tl.load(in_ptr0 + offsets, mask=mask) |
| 64 | + y = tl.load(in_ptr1 + offsets, mask=mask) |
| 65 | + output = x + y |
| 66 | + tl.store(out_ptr + offsets, output, mask=mask) |
| 67 | + |
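| | +    # ``fullgraph=True`` makes ``torch.compile`` raise an error instead of |
| | +    # silently falling back to eager mode if the function cannot be captured |
| | +    # in a single graph. |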
| 68 | + @torch.compile(fullgraph=True) |
| 69 | + def add_fn(x, y): |
| 70 | + output = torch.zeros_like(x) |
| 71 | + n_elements = output.numel() |
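| | +        # The grid callable receives the kernel's meta-parameters (here |
| | +        # BLOCK_SIZE) and returns the launch grid: one program per block. |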
| 72 | + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| 73 | + add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4) |
| 74 | + return output |
| 75 | + |
| 76 | + x = torch.randn(4, device="cuda") |
| 77 | + y = torch.randn(4, device="cuda") |
| 78 | + out = add_fn(x, y) |
| 79 | + print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") |
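| | +    # Optional sanity check: the compiled kernel should match eager addition. |
| | +    torch.testing.assert_close(out, x + y) |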
| 80 | + |
| 81 | +###################################################################### |
| 82 | +# Advanced Usage |
| 83 | +# ------------------------------------------------------------------- |
| 84 | +# |
| 85 | +# Triton's autotune feature is a powerful tool that automatically optimizes the |
| 86 | +# configuration parameters of your Triton kernels. It explores a range of possible |
| | +# configurations and selects the one that delivers the best performance for your specific use case. |
| 87 | +# |
| 88 | +# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch |
| 89 | +# model is running as efficiently as possible. Here is an example of using ``torch.compile`` |
| | +# and ``triton.autotune``. |
| 90 | +# |
| 91 | +# .. note:: |
| | +# |
| 92 | +#    ``torch.compile`` only supports configs and key arguments to ``triton.autotune``. |
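| | +#    The ``key`` list names the kernel arguments whose value changes trigger |
| | +#    re-tuning; an empty ``key`` (as in the example below) tunes once and |
| | +#    reuses the best config thereafter. |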
| 93 | +# |
| 94 | + |
| 95 | +if not has_triton(): |
| 96 | + print("Skipping because triton is not supported on this device.") |
| 97 | +else: |
| 98 | + import triton |
| 99 | + from triton import language as tl |
| 100 | + |
| 101 | + @triton.autotune( |
| 102 | + configs=[ |
| 103 | + triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8), |
| 104 | + triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4), |
| 105 | + triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8), |
| 106 | + triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4), |
| 107 | + ], |
| 108 | + key=[], |
| 109 | + ) |
| 110 | + @triton.jit |
| 111 | + def add_kernel_autotuned( |
| 112 | + in_ptr0, |
| 113 | + in_ptr1, |
| 114 | + out_ptr, |
| 115 | + n_elements, |
| 116 | + BLOCK_SIZE: "tl.constexpr", |
| 117 | + ): |
| 118 | + pid = tl.program_id(axis=0) |
| 119 | + block_start = pid * BLOCK_SIZE |
| 120 | + offsets = block_start + tl.arange(0, BLOCK_SIZE) |
| 121 | + mask = offsets < n_elements |
| 122 | + x = tl.load(in_ptr0 + offsets, mask=mask) |
| 123 | + y = tl.load(in_ptr1 + offsets, mask=mask) |
| 124 | + output = x + y |
| 125 | + tl.store(out_ptr + offsets, output, mask=mask) |
| 126 | + |
| 127 | + @torch.compile(fullgraph=True) |
| 128 | + def add_fn(x, y): |
| 129 | + output = torch.zeros_like(x) |
| 130 | + n_elements = output.numel() |
| 131 | + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| 132 | + add_kernel_autotuned[grid](x, y, output, n_elements) |
| 133 | + return output |
| 134 | + |
| 135 | + x = torch.randn(4, device="cuda") |
| 136 | + y = torch.randn(4, device="cuda") |
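| | +    # Note that the first call pays for compilation plus autotuning, since |
| | +    # each config above is benchmarked; subsequent calls reuse the winner. |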
| 137 | + out = add_fn(x, y) |
| 138 | + print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}") |
| 139 | + |
| 140 | +###################################################################### |
| 141 | +# Composability and Limitations |
| 142 | +# -------------------------------------------------------------------- |
| 143 | +# |
| 144 | +# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile`` |
| 145 | +# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor. |
| 146 | +# You can use these features together to build complex, high-performance models. |
| 147 | +# |
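| | +# As a minimal sketch (the ``_TritonAdd`` and ``add_with_grad`` names are ours, |
| | +# for illustration only), the ``add_kernel`` defined above can be wrapped in a |
| | +# ``torch.autograd.Function``. Addition is linear, so the backward pass simply |
| | +# routes the incoming gradient to both inputs. |
| | + |
| | +if has_triton(): |
| | + |
| | +    class _TritonAdd(torch.autograd.Function): |
| | +        @staticmethod |
| | +        def forward(ctx, x, y): |
| | +            output = torch.empty_like(x) |
| | +            n_elements = output.numel() |
| | +            grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) |
| | +            add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4) |
| | +            return output |
| | + |
| | +        @staticmethod |
| | +        def backward(ctx, grad_output): |
| | +            # d(x + y)/dx = d(x + y)/dy = 1 |
| | +            return grad_output, grad_output |
| | + |
| | +    @torch.compile(fullgraph=True) |
| | +    def add_with_grad(x, y): |
| | +        return _TritonAdd.apply(x, y) |
| | + |
| | +    a = torch.randn(4, device="cuda", requires_grad=True) |
| | +    b = torch.randn(4, device="cuda", requires_grad=True) |
| | +    add_with_grad(a, b).sum().backward() |
| | +    print(f"grad wrt a: {a.grad}")  # all ones; likewise for b.grad |
| | + |
| | +###################################################################### |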
| 148 | +# However, there are certain limitations to be aware of: |
| 149 | +# |
| 150 | +# * **Tensor Subclasses:** Currently, there is no support for tensor subclasses and other advanced features. |
| 151 | +# |
| 152 | +# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or |
| 153 | +#   before ``triton.autotune``, it cannot be used after ``triton.autotune``. This means |
| 154 | +#   that when the two are combined, ``triton.heuristics`` must come first, as in the |
| | +#   sketch after this list. |
| 155 | +# |
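| | +# A minimal sketch of the supported ordering (``add_kernel_heuristic`` and the |
| | +# ``EVEN_N`` flag are hypothetical names for illustration): ``triton.heuristics`` |
| | +# is listed below ``triton.autotune``, so it is applied to the kernel first. |
| | + |
| | +if has_triton(): |
| | + |
| | +    @triton.autotune( |
| | +        configs=[triton.Config({"BLOCK_SIZE": 4}, num_warps=4)], |
| | +        key=[], |
| | +    ) |
| | +    @triton.heuristics({"EVEN_N": lambda args: args["n_elements"] % 2 == 0}) |
| | +    @triton.jit |
| | +    def add_kernel_heuristic( |
| | +        in_ptr0, |
| | +        in_ptr1, |
| | +        out_ptr, |
| | +        n_elements, |
| | +        BLOCK_SIZE: "tl.constexpr", |
| | +        EVEN_N: "tl.constexpr", |
| | +    ): |
| | +        # EVEN_N is a compile-time flag filled in by the heuristic; a real |
| | +        # kernel could branch on it. We keep the bounds mask regardless. |
| | +        pid = tl.program_id(axis=0) |
| | +        offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) |
| | +        mask = offsets < n_elements |
| | +        x = tl.load(in_ptr0 + offsets, mask=mask) |
| | +        y = tl.load(in_ptr1 + offsets, mask=mask) |
| | +        tl.store(out_ptr + offsets, x + y, mask=mask) |
| | + |
| | +    a = torch.randn(6, device="cuda") |
| | +    b = torch.randn(6, device="cuda") |
| | +    c = torch.empty_like(a) |
| | +    grid = lambda meta: (triton.cdiv(a.numel(), meta["BLOCK_SIZE"]),) |
| | +    add_kernel_heuristic[grid](a, b, c, a.numel()) |
| | + |
| | +###################################################################### |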
| 156 | +# Conclusion |
| 157 | +# ----------- |
| 158 | +# |
| 159 | +# In this recipe, we explored how to utilize user-defined Triton kernels |
| 160 | +# with ``torch.compile``. We covered the basic usage of a simple vector addition |
| 161 | +# kernel and advanced usage involving Triton's autotune feature. We also discussed |
| | +# the composability of user-defined Triton kernels with other PyTorch features |
| | +# and highlighted some current limitations. |
| 162 | +# |
| 163 | +# See Also |
| 164 | +# --------- |
| 165 | +# |
| 166 | +# * `Compiling the Optimizers <https://tutorials.pytorch.kr/recipes/compiling_optimizer.html>`__ |
| 167 | +# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://tutorials.pytorch.kr/intermediate/scaled_dot_product_attention_tutorial.html>`__ |