Skip to content

Commit 673746b

Browse files
authored
recipes_source/torch_compile_user_defined_triton_kernel_tutorial.py 번역 (#918)
1 parent 10f1fab commit 673746b

File tree

1 file changed

+167
-171
lines changed

1 file changed

+167
-171
lines changed
Lines changed: 167 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -1,171 +1,167 @@
1-
# -*- coding: utf-8 -*-
2-
3-
"""
4-
Using User-Defined Triton Kernels with ``torch.compile``
5-
=========================================================
6-
**Author:** `Oguz Ulgen <https://github.com/oulgen>`_
7-
"""
8-
9-
######################################################################
10-
# User-defined Triton kernels can be used to optimize specific parts of your
11-
# model's computation. These kernels are written in Triton's language, which is designed
12-
# to make it easier to achieve peak hardware performance. By using user-defined Triton
13-
# kernels with ``torch.compile``, you can integrate these optimized computations into
14-
# your PyTorch model, potentially achieving significant performance improvements.
15-
#
16-
# This recipe demonstrates how you can use user-defined Triton kernels with ``torch.compile``.
17-
#
18-
# Prerequisites
19-
# -------------------
20-
#
21-
# Before starting this recipe, make sure that you have the following:
22-
#
23-
# * Basic understanding of ``torch.compile`` and Triton. See:
24-
#
25-
# * `torch.compiler API documentation <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
26-
# * `Introduction to torch.compile <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`__
27-
# * `Triton language documentation <https://triton-lang.org/main/index.html>`__
28-
#
29-
# * PyTorch 2.3 or later
30-
# * A GPU that supports Triton
31-
#
32-
33-
import torch
34-
from torch.utils._triton import has_triton
35-
36-
######################################################################
37-
# Basic Usage
38-
# --------------------
39-
#
40-
# In this example, we will use a simple vector addition kernel from the Triton documentation
41-
# with ``torch.compile``.
42-
# For reference, see `Triton documentation <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
43-
#
44-
45-
if not has_triton():
46-
print("Skipping because triton is not supported on this device.")
47-
else:
48-
import triton
49-
from triton import language as tl
50-
51-
@triton.jit
52-
def add_kernel(
53-
in_ptr0,
54-
in_ptr1,
55-
out_ptr,
56-
n_elements,
57-
BLOCK_SIZE: "tl.constexpr",
58-
):
59-
pid = tl.program_id(axis=0)
60-
block_start = pid * BLOCK_SIZE
61-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
62-
mask = offsets < n_elements
63-
x = tl.load(in_ptr0 + offsets, mask=mask)
64-
y = tl.load(in_ptr1 + offsets, mask=mask)
65-
output = x + y
66-
tl.store(out_ptr + offsets, output, mask=mask)
67-
68-
@torch.compile(fullgraph=True)
69-
def add_fn(x, y):
70-
output = torch.zeros_like(x)
71-
n_elements = output.numel()
72-
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
73-
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
74-
return output
75-
76-
x = torch.randn(4, device="cuda")
77-
y = torch.randn(4, device="cuda")
78-
out = add_fn(x, y)
79-
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
80-
81-
######################################################################
82-
# Advanced Usage
83-
# -------------------------------------------------------------------
84-
#
85-
# Triton's autotune feature is a powerful tool that automatically optimizes the configuration
86-
# parameters of your Triton kernels. It explores a range of possible configurations and
87-
# selects the one that delivers the best performance for your specific use case.
88-
#
89-
# When used with ``torch.compile``, ``triton.autotune`` can help ensure that your PyTorch
90-
# model is running as efficiently as possible. Here is an example of using ``torch.compile``
91-
# and ``triton.autotune``.
92-
#
93-
# .. note::
94-
#
95-
# ``torch.compile`` only supports configs and key arguments to ``triton.autotune``.
96-
97-
if not has_triton():
98-
print("Skipping because triton is not supported on this device.")
99-
else:
100-
import triton
101-
from triton import language as tl
102-
103-
@triton.autotune(
104-
configs=[
105-
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
106-
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
107-
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
108-
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
109-
],
110-
key=[],
111-
)
112-
@triton.jit
113-
def add_kernel_autotuned(
114-
in_ptr0,
115-
in_ptr1,
116-
out_ptr,
117-
n_elements,
118-
BLOCK_SIZE: "tl.constexpr",
119-
):
120-
pid = tl.program_id(axis=0)
121-
block_start = pid * BLOCK_SIZE
122-
offsets = block_start + tl.arange(0, BLOCK_SIZE)
123-
mask = offsets < n_elements
124-
x = tl.load(in_ptr0 + offsets, mask=mask)
125-
y = tl.load(in_ptr1 + offsets, mask=mask)
126-
output = x + y
127-
tl.store(out_ptr + offsets, output, mask=mask)
128-
129-
@torch.compile(fullgraph=True)
130-
def add_fn(x, y):
131-
output = torch.zeros_like(x)
132-
n_elements = output.numel()
133-
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
134-
add_kernel_autotuned[grid](x, y, output, n_elements)
135-
return output
136-
137-
x = torch.randn(4, device="cuda")
138-
y = torch.randn(4, device="cuda")
139-
out = add_fn(x, y)
140-
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
141-
142-
######################################################################
143-
# Composability and Limitations
144-
# --------------------------------------------------------------------
145-
#
146-
# As of PyTorch 2.3, the support for user-defined Triton kernels in ``torch.compile``
147-
# includes dynamic shapes, ``torch.autograd.Function``, JIT inductor, and AOT inductor.
148-
# You can use these features together to build complex, high-performance models.
149-
#
150-
# However, there are certain limitations to be aware of:
151-
#
152-
# * **Tensor Subclasses:** Currently, there is no support for
153-
# tensor subclasses and other advanced features.
154-
# * **Triton Features:** While ``triton.heuristics`` can be used either standalone or
155-
# before ``triton.autotune``, it cannot be used after ``triton.autotune``. This
156-
# implies that if ``triton.heuristics`` and ``triton.autotune`` are to be used
157-
# together, ``triton.heuristics`` must be used first.
158-
#
159-
# Conclusion
160-
# -----------
161-
# In this recipe, we explored how to utilize user-defined Triton kernels
162-
# with ``torch.compile``. We delved into the basic usage of a simple
163-
# vector addition kernel and advanced usage involving Triton's autotune
164-
# feature. We also discussed the composability of user-defined Triton
165-
# kernels with other PyTorch features and highlighted some current limitations.
166-
#
167-
# See Also
168-
# ---------
169-
#
170-
# * `Compiling the Optimizers <https://tutorials.pytorch.kr/recipes/compiling_optimizer.html>`__
171-
# * `Implementing High-Performance Transformers with Scaled Dot Product Attention <https://tutorials.pytorch.kr/intermediate/scaled_dot_product_attention_tutorial.html>`__
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
사용자 정의 Triton 커널을 ``torch.compile``과 함께 사용하기
5+
=========================================================
6+
**저자:** `Oguz Ulgen <https://github.com/oulgen>`_
7+
**번역:** `구경선 <https://github.com/kookyungseon>`_, `이채운 <https://github.com/dlcodns>`_
8+
"""
9+
10+
######################################################################
11+
# 사용자 정의 Triton 커널을 사용하면 모델의 특정 부분의 계산을 최적화할 수 있습니다.
12+
# 이 커널들은 Triton의 언어로 작성된 것으로 설계되었습니다.
13+
# 사용자 정의 Triton을 사용하여 하드웨어 성능을 최고로 향상시킵니다.
14+
# ``torch.compile``을 사용하는 커널은 이러한 최적화된 계산을 통합할 수 있습니다.
15+
# PyTorch 모델을 통해 상당한 성능 향상을 실현할 수 있습니다.
16+
#
17+
# 이 레시피는 사용자 정의 Triton 커널을 ``torch.compile``과 함께 사용할 수 있는 방법을 보여줍니다.
18+
#
19+
# 전제조건
20+
# -------------------
21+
#
22+
# 이 레시피를 시작하기 전에 다음이 있는지 확인합니다:
23+
# * ``torch.compile`` 및 Triton에 대한 기본적인 이해. 참조:
24+
#
25+
# * `torch.compiler API 설명서 <https://pytorch.org/docs/stable/torch.compiler.html#torch-compiler>`__
26+
# * `torch.compile 소개 <https://tutorials.pytorch.kr/intermediate/torch_compile_tutorial.html>`__
27+
# * `Triton 언어 문서 <https://triton-lang.org/main/index.html>`__
28+
#
29+
# * PyTorch 2.3 이상
30+
# * Triton을 지원하는 GPU
31+
#
32+
33+
import torch
34+
from torch.utils._triton import has_triton
35+
36+
######################################################################
37+
# 기본 사용법
38+
# --------------------
39+
#
40+
# 이 예에서는 Triton 문서의 간단한 벡터 덧셈 커널을 사용합니다.
41+
# ``torch.compile``과 함께.
42+
# 참고, `Triton 문서를 참고하세요 <https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html>`__.
43+
#
44+
45+
if not has_triton():
46+
print("Skipping because triton is not supported on this device.")
47+
else:
48+
import triton
49+
from triton import language as tl
50+
51+
@triton.jit
52+
def add_kernel(
53+
in_ptr0,
54+
in_ptr1,
55+
out_ptr,
56+
n_elements,
57+
BLOCK_SIZE: "tl.constexpr",
58+
):
59+
pid = tl.program_id(axis=0)
60+
block_start = pid * BLOCK_SIZE
61+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
62+
mask = offsets < n_elements
63+
x = tl.load(in_ptr0 + offsets, mask=mask)
64+
y = tl.load(in_ptr1 + offsets, mask=mask)
65+
output = x + y
66+
tl.store(out_ptr + offsets, output, mask=mask)
67+
68+
@torch.compile(fullgraph=True)
69+
def add_fn(x, y):
70+
output = torch.zeros_like(x)
71+
n_elements = output.numel()
72+
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
73+
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=4)
74+
return output
75+
76+
x = torch.randn(4, device="cuda")
77+
y = torch.randn(4, device="cuda")
78+
out = add_fn(x, y)
79+
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
80+
81+
######################################################################
82+
# 고급 사용법
83+
# -------------------------------------------------------------------
84+
#
85+
# Triton의 자동 튜닝 기능은 Triton 커널의 구성 매개변수를 자동으로 최적화해주는 강력한 도구입니다.
86+
# 다양한 설정을 검토하여 특정 사용 사례에 최적의 성능을 제공하는 구성을 선택합니다.
87+
#
88+
# ``torch.compile``과 함께 사용할 경우 ``triton.autotune``을 사용하면 PyTorch 모델을 최대한 효율적으로
89+
# 실행할 수 있습니다. 아래는 ``torch.compile``과 ``triton.autotune``을 사용하는 예제입니다.
90+
#
91+
# .. note::
92+
# ``torch.compile``은 ``triton.autotune``에 대한 configs와 key 인수만 지원합니다.
93+
#
94+
95+
if not has_triton():
96+
print("Skipping because triton is not supported on this device.")
97+
else:
98+
import triton
99+
from triton import language as tl
100+
101+
@triton.autotune(
102+
configs=[
103+
triton.Config({"BLOCK_SIZE": 4}, num_stages=3, num_warps=8),
104+
triton.Config({"BLOCK_SIZE": 4}, num_stages=4, num_warps=4),
105+
triton.Config({"BLOCK_SIZE": 2}, num_stages=3, num_warps=8),
106+
triton.Config({"BLOCK_SIZE": 2}, num_stages=4, num_warps=4),
107+
],
108+
key=[],
109+
)
110+
@triton.jit
111+
def add_kernel_autotuned(
112+
in_ptr0,
113+
in_ptr1,
114+
out_ptr,
115+
n_elements,
116+
BLOCK_SIZE: "tl.constexpr",
117+
):
118+
pid = tl.program_id(axis=0)
119+
block_start = pid * BLOCK_SIZE
120+
offsets = block_start + tl.arange(0, BLOCK_SIZE)
121+
mask = offsets < n_elements
122+
x = tl.load(in_ptr0 + offsets, mask=mask)
123+
y = tl.load(in_ptr1 + offsets, mask=mask)
124+
output = x + y
125+
tl.store(out_ptr + offsets, output, mask=mask)
126+
127+
@torch.compile(fullgraph=True)
128+
def add_fn(x, y):
129+
output = torch.zeros_like(x)
130+
n_elements = output.numel()
131+
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
132+
add_kernel_autotuned[grid](x, y, output, n_elements)
133+
return output
134+
135+
x = torch.randn(4, device="cuda")
136+
y = torch.randn(4, device="cuda")
137+
out = add_fn(x, y)
138+
print(f"Vector addition of\nX:\t{x}\nY:\t{y}\nis equal to\n{out}")
139+
140+
######################################################################
141+
# 호환성과 제한사항
142+
# --------------------------------------------------------------------
143+
#
144+
# PyTorch 2.3 버전 기준으로, ``torch.compile``의 사용자 정의 Triton 커널에는 동적 모양,
145+
# ``torch.autograd.Function``, JIT inductor, AOT inductor가 지원됩니다. 이 기능들을
146+
# 조합하여 복잡하고 고성능인 모델을 구축할 수 있습니다.
147+
#
148+
# 그러나 알아두어야 할 몇 가지 제한 사항이 있습니다.
149+
#
150+
# * **Tensor Subclasses:** 현재로서는 Tensor 하위 클래스 및 기타 고급 기능은 지원되지 않습니다.
151+
#
152+
# * **Triton Features:** ``triton.heuristics``는 단독으로 사용하거나 ``triton.autotune`` 앞에서
153+
# 사용할 수 있지만, ``triton.autotune`` 뒤에서는 사용할 수 없습니다. 따라서 ``triton.heuristics``와
154+
# ``triton.autotune``을 함께 사용하려면 ``triton.heuristics``를 먼저 사용해야 합니다.
155+
#
156+
# 결론
157+
# -----------
158+
#
159+
# 이 레시피에서는 사용자 정의 Triton 커널을 ``torch.compile``로 활용하는 방법을 알아보았습니다. 간단한
160+
# 벡터 덧셈 커널의 기본 사용법과 Triton의 자동 튜닝 기능을 포함한 고급 사용법에 대해 다뤘습니다. 또한 사용자
161+
# 정의 Triton 커널과 다른 PyTorch 기능의 조합 가능성에 대해 논의하고 현재의 몇 가지 제한 사항을 강조했습니다.
162+
#
163+
# 관련 항목
164+
# ---------
165+
#
166+
# * `Optimizer 컴파일하기 <https://tutorials.pytorch.kr/recipes/compiling_optimizer.html>`__
167+
# * `Scaled Dot Product Attention을 활용한 고성능 Transformer 구현하기 <https://tutorials.pytorch.kr/intermediate/scaled_dot_product_attention_tutorial.html>`__

0 commit comments

Comments
Β (0)