Commit 0356811

Merge branch 'main' of https://github.com/intel/intel-xpu-backend-for-triton into amyachev/issue4172

2 parents: 61f0107 + f36355f

File tree: 16 files changed, +152 / -182 lines

.github/workflows/build-test-reusable.yml

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ env:
 jobs:
   build:
     name: Build
+    timeout-minutes: 720
     runs-on: ${{ fromJson(inputs.runner_label && format('["linux", "{0}"]', inputs.runner_label) || format('["linux", "{0}", "{1}", "{2}"]', inputs.device, inputs.driver_version, inputs.runner_version)) }}
     defaults:
       run:

.github/workflows/triton-benchmarks.yml

Lines changed: 25 additions & 23 deletions
@@ -140,6 +140,7 @@ jobs:
         python build_report.py $REPORTS/softmax-performance.csv $REPORTS/softmax-xetla-report.csv --benchmark softmax --compiler xetla --param_cols "N" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

     - name: Run Triton GEMM kernel benchmark
+<<<<<<< HEAD
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py') }}
       run: |
         cd benchmarks/triton_kernels_benchmark
@@ -154,6 +155,8 @@ jobs:
         fi

     - name: Run Triton GEMM kernel benchmark - new shapes
+=======
+>>>>>>> f36355f85676fd6aabbd807e9b4c4859679bb966
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'gemm_benchmark.py_newshapes')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'gemm_benchmark.py_newshapes') }}
       run: |
         cd benchmarks/triton_kernels_benchmark
@@ -274,8 +277,8 @@ jobs:
         python flash_attention_benchmark.py --reports $REPORTS --n_runs $N_RUNS

         source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-        python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+        python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-triton-report.csv --benchmark flash-attn --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+        python build_report.py $REPORTS/attn-performance.csv $REPORTS/attn-xetla-report.csv --benchmark flash-attn --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

     - name: Run Triton FA bwd kernel benchmark
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_bwd_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_bwd_benchmark.py') }}
@@ -286,8 +289,8 @@ jobs:
         mv $REPORTS/attn-performance.csv $REPORTS/attn-bwd-performance.csv

         source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-triton-report.csv --benchmark attn-bwd --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-        python build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-xetla-report.csv --benchmark attn-bwd --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
+        python build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-triton-report.csv --benchmark flash-attn-bwd --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+        python build_report.py $REPORTS/attn-bwd-performance.csv $REPORTS/attn-bwd-xetla-report.csv --benchmark flash-attn-bwd --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

     - name: Run Triton FA fwd kernel benchmark - with tensor descriptors
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flash_attention_tensor_desc_benchmark.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flash_attention_tensor_desc_benchmark.py') }}
@@ -297,22 +300,8 @@ jobs:
         mv $REPORTS/attn-performance.csv $REPORTS/attn-tensor-desc-performance.csv

         source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-        python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG
-
-    - name: Run Prefix Sums kernel benchmark
-      if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'prefix_sums.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }}
-      run: |
-        cd benchmarks/triton_kernels_benchmark
-        python prefix_sums.py --reports $REPORTS --n_runs $N_RUNS
-        source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/prefix-sums.csv $REPORTS/prefix_sums-triton-report.csv --benchmark prefix_sums --compiler triton --param_cols "N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
-
-    - name: Run micro benchmark
-      if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'micro_benchmarks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'micro_benchmarks') }}
-      run: |
-        cd benchmarks/micro_benchmarks
-        python run_benchmarks.py --reports $REPORTS
+        python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-triton-report.csv --benchmark flash-attn-tensor-desc --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+        python build_report.py $REPORTS/attn-tensor-desc-performance.csv $REPORTS/attn-tensor-desc-xetla-report.csv --benchmark flash-attn-tensor-desc --compiler xetla --param_cols "Z,H,N_CTX,D_HEAD,CAUSAL" --tflops_col XeTLA-TFlops --hbm_col "XeTLA-GB/s" --tag $TAG

     - name: Run Triton FlexAttention Causal Mask fwd kernel benchmark
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_causal_mask.py') }}
@@ -321,7 +310,7 @@ jobs:
         python flex_attention_benchmark_causal_mask.py --reports $REPORTS --n_runs $N_RUNS

         source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flexAttnCausal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+        python build_report.py $REPORTS/flexAttnCausal-performance.csv $REPORTS/flexAttnCausal-triton-report.csv --benchmark flex-attn-causal --compiler triton --param_cols "Z,H_q,H_kv,N_CTX_q,N_CTX_kv,D_HEAD_qk,D_HEAD_v" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG

     - name: Run Triton FlexAttention Custom Masks fwd kernel benchmark
       if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'flex_attention_benchmark_custom_masks.py') }}
@@ -330,9 +319,22 @@ jobs:
         python flex_attention_benchmark_custom_masks.py --reports $REPORTS --n_runs $N_RUNS

         source ../../scripts/capture-hw-details.sh
-        python build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-triton-report.csv --benchmark flexAttnMasks --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG --mask
-        python build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-onednn-report.csv --benchmark flexAttnMasks --compiler onednn --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG --mask
+        python build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-triton-report.csv --benchmark flex-attn-masks --compiler triton --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG --mask
+        python build_report.py $REPORTS/flexAttnMasks-performance.csv $REPORTS/flexAttnMasks-onednn-report.csv --benchmark flex-attn-masks --compiler onednn --param_cols "Z,H,N_CTX,D_HEAD,MASK" --tflops_col OneDNN-TFlops --hbm_col "OneDNN-GB/s" --tag $TAG --mask

+    - name: Run Prefix Sums kernel benchmark
+      if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'prefix_sums.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'prefix_sums.py') }}
+      run: |
+        cd benchmarks/triton_kernels_benchmark
+        python prefix_sums.py --reports $REPORTS --n_runs $N_RUNS
+        source ../../scripts/capture-hw-details.sh
+        python build_report.py $REPORTS/prefix-sums.csv $REPORTS/prefix_sums-triton-report.csv --benchmark prefix_sums --compiler triton --param_cols "N" --tflops_col Triton-TFlops --hbm_col "Triton-GB/s" --tag $TAG
+
+    - name: Run micro benchmark
+      if: ${{ steps.install.outcome == 'success' && !cancelled() && (inputs.benchmarks == '' || contains(fromJson(inputs.benchmarks || '[]'), 'micro_benchmarks.py')) && !contains(fromJson(inputs.skip_benchmarks || '[]'), 'micro_benchmarks') }}
+      run: |
+        cd benchmarks/micro_benchmarks
+        python run_benchmarks.py --reports $REPORTS

     - name: Upload benchmark reports
       if: ${{ steps.install.outcome == 'success' && !cancelled() }}
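
The functional edit in these steps is the set of --benchmark tags passed to build_report.py (attn becomes flash-attn, flexAttnCausal becomes flex-attn-causal, and so on), plus moving the Prefix Sums and micro benchmark steps after the FlexAttention steps. For reference, a hypothetical argparse sketch of the command-line surface these steps rely on; only the flags visible in the workflow are modeled, and the defaults and help text are assumptions rather than the script's actual source:

import argparse

# Hypothetical reconstruction of build_report.py's CLI as invoked above;
# flag names come from the workflow, everything else is an assumption.
parser = argparse.ArgumentParser(description="Turn a raw benchmark CSV into a tagged report CSV")
parser.add_argument("source_csv", help="raw results, e.g. $REPORTS/attn-performance.csv")
parser.add_argument("target_csv", help="report to write, e.g. $REPORTS/attn-triton-report.csv")
parser.add_argument("--benchmark", required=True, help="report tag, e.g. flash-attn or flex-attn-causal")
parser.add_argument("--compiler", required=True, help="triton, xetla, cutlass or onednn")
parser.add_argument("--param_cols", required=True, help="comma-separated shape columns, e.g. Z,H,N_CTX,D_HEAD,CAUSAL")
parser.add_argument("--tflops_col", required=True, help="column holding TFLOPS for this compiler")
parser.add_argument("--hbm_col", help="column holding achieved bandwidth in GB/s")
parser.add_argument("--tag", default="", help="value of $TAG exported by capture-hw-details.sh")
parser.add_argument("--mask", action="store_true", help="only passed by the FlexAttention custom-masks step")

if __name__ == "__main__":
    print(parser.parse_args())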

benchmarks/triton_kernels_benchmark/flash_attention_benchmark.py

Lines changed: 2 additions & 2 deletions
@@ -160,12 +160,12 @@ def _attn_fwd_with_block_pointers(Q, K, V, sm_scale, M, Out, #
 configs = [
     triton.Config({'BLOCK_M': BM, 'BLOCK_N': BN, 'grf_mode': 'large', 'one_matrix_per_load_for_bt': True}, num_stages=s, num_warps=w) \
     for BM in [128, 256] \
-    for BN in [32, 64, 128] \
+    for BN in [32, 64] \
     for s in [2, 3, 4] \
     for w in [8, 16, 32] \
 ]

-tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL'])
+tuner = triton.autotune(configs, key=['N_CTX', 'BLOCK_DMODEL', 'STAGE'])


 @triton.jit
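
This change narrows the autotuning sweep (BLOCK_N no longer includes 128) and adds STAGE to the autotune key, so the tuner caches a best configuration per STAGE value instead of sharing one across stages. A minimal, self-contained sketch of the resulting config grid, using plain dictionaries instead of triton.Config so it runs without a GPU:

import itertools

def build_config_grid():
    # Same sweep as the benchmark after this commit: BLOCK_N=128 is gone.
    block_m = [128, 256]
    block_n = [32, 64]
    num_stages = [2, 3, 4]
    num_warps = [8, 16, 32]
    return [
        {"BLOCK_M": bm, "BLOCK_N": bn, "num_stages": s, "num_warps": w}
        for bm, bn, s, w in itertools.product(block_m, block_n, num_stages, num_warps)
    ]

if __name__ == "__main__":
    grid = build_config_grid()
    print(len(grid))  # 2 * 2 * 3 * 3 = 36 candidates, down from 54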

benchmarks/triton_kernels_benchmark/gemm_benchmark.py

Lines changed: 5 additions & 12 deletions
@@ -232,15 +232,13 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
     return a_shape, b_shape


-NEW_X_VALS = [ #
+X_VALS = [ #
+    [1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]
+] + [ #
     [1, m, n, 4096] for m in [1, 8] for n in [1024, 4096, 6144, 14336, 28672, 128256]
 ] + [ #
     [1, m, 4096, 14336] for m in [1, 8]
 ] + [ #
-    [1, 8192, 4096, 4096] #
-]
-
-X_VALS = [[1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]] + [
     [1, 1, 13824, 5120],
     [1, 4, 12288, 4096],
     [1, 512, 8192, 8192],
@@ -261,6 +259,7 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
     [32, 4096, 128, 4096],
     [4096, 8, 128, 16384],
     [4096, 8, 16384, 128],
+    [1, 8192, 4096, 4096],
 ]

 DEVICE_NAME = torch.xpu.get_device_name()
@@ -281,16 +280,13 @@ def is_enough_memory(x_val):
     return enough_memory


-if os.getenv('NEW_SHAPES', '1') == '1':
-    X_VALS += NEW_X_VALS
 X_VALS = [x_val for x_val in X_VALS if is_enough_memory(x_val)]


 def get_benchmark(
     providers_filter: Optional[list[str]] = None,
     transpose_a=False,
     transpose_b=False,
-    new_shapes=False,
     matmul_kernel=matmul_kernel_with_block_pointers,
     matmul_kernel_batched=matmul_kernel_with_block_pointers_batched,
     plot_name='matmul-performance',
@@ -303,10 +299,8 @@ def get_shapes(B, M, N, K, transpose_a, transpose_b):
         'triton': 'Triton',
         'onednn': 'OneDNN',
     }
-    # use_xetla and use_cutlass
+    # use_cutlass
     if not (transpose_a or transpose_b):
-        if not new_shapes:
-            supported_providers['xetla'] = 'XeTLA'
         if '580' not in torch.xpu.get_device_name():
             # FIXME: enable cutlass on bmg
             supported_providers['cutlass'] = 'CUTLASS'
@@ -459,6 +453,5 @@ def cutlass_invoker():
     _benchmark = get_benchmark(
         transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
         transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
-        new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
     )
     _benchmark.run(show_plots=False, print_data=True)
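
With NEW_X_VALS folded into X_VALS, the NEW_SHAPES environment toggle and the new_shapes parameter disappear, here and in the tensor-descriptor and tensor-of-pointer variants below; every shape simply passes through the is_enough_memory filter. A small self-contained sketch of that pattern with an illustrative memory model (the real check is device-dependent and lives in is_enough_memory):

# Illustrative only: build one flat list of (B, M, N, K) GEMM shapes and
# filter it with a memory predicate, instead of gating extra shapes behind
# an environment variable. The byte budget below is an assumption.
MAX_BYTES = 8 * 1024**3   # assumed device memory budget for the example
BYTES_PER_ELEMENT = 2     # fp16 / bf16 operands

shapes = [
    [1, 1024 * i, 1024 * i, 1024 * i] for i in [1, 2, 4, 8]
] + [
    [1, m, n, 4096] for m in [1, 8] for n in [1024, 4096, 6144, 14336, 28672, 128256]
] + [
    [1, m, 4096, 14336] for m in [1, 8]
] + [
    [1, 8192, 4096, 4096],
]

def is_enough_memory_example(x_val):
    b, m, n, k = x_val
    # A, B and the output C all resident at once.
    required = b * (m * k + k * n + m * n) * BYTES_PER_ELEMENT
    return required <= MAX_BYTES

shapes = [x_val for x_val in shapes if is_enough_memory_example(x_val)]
print(f"{len(shapes)} shapes kept")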

benchmarks/triton_kernels_benchmark/gemm_tensor_desc_benchmark.py

Lines changed: 0 additions & 3 deletions
@@ -117,7 +117,6 @@ def get_benchmark(
     providers_filter: Optional[List[str]] = None,
     transpose_a=False,
     transpose_b=False,
-    new_shapes=True,
 ):
     return gemm_benchmark.get_benchmark(
         providers_filter=providers_filter,
@@ -126,14 +125,12 @@ def get_benchmark(
         plot_name='matmul-tensor-desc-performance',
         transpose_a=transpose_a,
         transpose_b=transpose_b,
-        new_shapes=new_shapes,
     )


 if __name__ == '__main__':
     _benchmark = get_benchmark(
         transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
         transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
-        new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
     )
     _benchmark.run(show_plots=False, print_data=True)

benchmarks/triton_kernels_benchmark/gemm_tensor_of_ptr_benchmark.py

Lines changed: 0 additions & 3 deletions
@@ -124,7 +124,6 @@ def get_benchmark(
     providers_filter: Optional[List[str]] = None,
     transpose_a=False,
     transpose_b=False,
-    new_shapes=True,
 ):
     return gemm_benchmark.get_benchmark(
         providers_filter=providers_filter,
@@ -133,14 +132,12 @@ def get_benchmark(
         plot_name='matmul-tensor-of-ptr-performance',
         transpose_a=transpose_a,
         transpose_b=transpose_b,
-        new_shapes=new_shapes,
     )


 if __name__ == '__main__':
     _benchmark = get_benchmark(
         transpose_a=(os.getenv('TRANSPOSE_A', '0') == '1'),
         transpose_b=(os.getenv('TRANSPOSE_B', '0') == '1'),
-        new_shapes=(os.getenv('NEW_SHAPES', '1') == '1'),
     )
     _benchmark.run(show_plots=False, print_data=True)

python/triton/_utils.py

Lines changed: 0 additions & 23 deletions
@@ -35,26 +35,3 @@ def _impl(path: tuple[int, ...], current: Any):
     _impl((), iterable)

     return list(ret.keys())
-
-
-class ClassPropertyDescriptor:
-
-    def __init__(self, fget, fset=None):
-        self.fget = fget
-        self.fset = fset
-
-    def __get__(self, obj, cls):
-        return self.fget(cls)
-
-    def __set__(self, obj, value):
-        if self.fset is None:
-            raise AttributeError("can't set attribute")
-        self.fset(obj.__class__, value)
-
-    def setter(self, fset):
-        self.fset = fset
-        return self
-
-
-def classproperty(func):
-    return ClassPropertyDescriptor(func)
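
The ClassPropertyDescriptor removed here let a class expose a property-like attribute whose reads and writes were forwarded elsewhere (the hook knobs used in compiler.py and jit.py below). A compact, self-contained sketch of what the pattern did, with a hypothetical Hooks class and module-level state standing in for Triton's knobs:

_state = {"enter_hook": None}  # stands in for knobs.runtime in this sketch

class ClassPropertyDescriptor:
    # Same shape as the deleted helper: attribute access calls class-level
    # getter/setter functions instead of touching instance state.
    def __init__(self, fget, fset=None):
        self.fget = fget
        self.fset = fset

    def __get__(self, obj, cls):
        return self.fget(cls)

    def __set__(self, obj, value):
        if self.fset is None:
            raise AttributeError("can't set attribute")
        self.fset(obj.__class__, value)

    def setter(self, fset):
        self.fset = fset
        return self

def classproperty(func):
    return ClassPropertyDescriptor(func)

class Hooks:  # hypothetical consumer of the descriptor
    @classproperty
    def enter_hook(cls):
        return _state["enter_hook"]

    @enter_hook.setter
    def enter_hook(cls, value):
        _state["enter_hook"] = value

Hooks().enter_hook = print        # instance write is forwarded to _state
assert Hooks.enter_hook is print  # class read comes back from _state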

python/triton/compiler/compiler.py

Lines changed: 0 additions & 19 deletions
@@ -9,7 +9,6 @@
 from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager
 from ..runtime.driver import driver
 from ..tools.disasm import get_sass, get_spvdis
-from .._utils import classproperty
 # TODO: this shouldn't be here
 from .code_generator import ast_to_ttir
 from pathlib import Path
@@ -432,24 +431,6 @@ def __missing__(self, key):

 class CompiledKernel:

-    # FIXME: remove launch_enter_hook/launch_exit_hook properties
-    # when pytorch has a compatible layer for the new API.
-    @classproperty
-    def launch_enter_hook(cls):
-        return knobs.runtime.launch_enter_hook
-
-    @launch_enter_hook.setter
-    def launch_enter_hook(cls, value):
-        knobs.runtime.launch_enter_hook = value
-
-    @classproperty
-    def launch_exit_hook(cls):
-        return knobs.runtime.launch_exit_hook
-
-    @launch_exit_hook.setter
-    def launch_exit_hook(cls, value):
-        knobs.runtime.launch_exit_hook = value
-
     def __init__(self, src, metadata_group, hash):
         from collections import namedtuple
         metadata_path = next((Path(p) for c, p in metadata_group.items() if c.endswith(".json")))
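
With the classproperty shims gone, CompiledKernel.launch_enter_hook and launch_exit_hook no longer exist as settable class attributes; the hooks live only on the knobs object the removed getters already pointed at. A hedged sketch of the direct form, assuming a Triton build where triton.knobs is importable (as the knobs usage in these files implies); the hook's argument is illustrative:

import triton.knobs as knobs

def on_launch_enter(metadata):
    # Illustrative body; Triton passes launch metadata to the hook.
    print("launching kernel:", metadata)

# Previously: CompiledKernel.launch_enter_hook = on_launch_enter
knobs.runtime.launch_enter_hook = on_launch_enter
knobs.runtime.launch_exit_hook = None  # no exit hook in this sketch

The same applies to the jit.py change below: JITFunction.cache_hook and compiled_hook now map directly to knobs.runtime.jit_cache_hook and knobs.runtime.jit_post_compile_hook.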

python/triton/runtime/jit.py

Lines changed: 1 addition & 19 deletions
@@ -14,7 +14,7 @@
 from types import ModuleType
 from .. import knobs
 from ..runtime.driver import driver
-from .._utils import find_paths_if, get_iterable_path, classproperty
+from .._utils import find_paths_if, get_iterable_path

 TRITON_MODULE = __name__[:-len(".runtime.jit")]

@@ -494,24 +494,6 @@ class JitFunctionInfo:

 class JITFunction(KernelInterface[T]):

-    # FIXME: remove cache_hook/compiled_hook properties
-    # when pytorch has a compatible layer for the new API.
-    @classproperty
-    def cache_hook(cls):
-        return knobs.runtime.jit_cache_hook
-
-    @cache_hook.setter
-    def cache_hook(cls, value):
-        knobs.runtime.jit_cache_hook = value
-
-    @classproperty
-    def compiled_hook(cls):
-        return knobs.runtime.jit_post_compile_hook
-
-    @compiled_hook.setter
-    def compiled_hook(cls, value):
-        knobs.runtime.jit_post_compile_hook = value
-
     def _call_hook(
         self,
         hook,
