
Commit af7c4e2

Merge OpenAI Triton commit 9451f8f (#4013)
This PR changes the Triton base from 981e987 to 9451f8f (Apr 23). Pass rate: 88.73%.
2 parents 3ddb3e3 + 98a8bfe commit af7c4e2


63 files changed: +1528 / -1153 lines

.github/workflows/integration-tests.yml (7 additions & 7 deletions)

@@ -114,7 +114,7 @@ jobs:
   if: env.enable_integration == 'true'
   run: |
     if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
-      echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"], ["gb200-runner-set"]]'
+      echo '::set-output name=matrix-CUDA::[["nvidia-a100"], ["nvidia-h100"], ["nvidia-gb200"]]'
       echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["amd-gfx942"]]'
       echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
     else
@@ -232,7 +232,7 @@ jobs:
   env:
     CUDA_HOME: "/usr/local/cuda"
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     echo "PATH is '$PATH'"
@@ -244,23 +244,23 @@ jobs:
   run: make test-lit
 - name: Run python tests on CUDA
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     make test-unit
 - name: Run interpreter tests
-  if: ${{ matrix.runner[0] == 'h100-runner-set' }}
+  if: ${{ matrix.runner[0] == 'nvidia-h100' }}
   run: make test-interpret
 - name: Run regression tests
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     make test-regression
 - name: Run C++ unittests
   run: make test-cpp
 - name: Run Proton tests
-  if: ${{ matrix.runner[0] != 'gb200-runner-set' }}
+  if: ${{ matrix.runner[0] != 'nvidia-gb200' }}
   run: make test-proton
 - name: Inspect cache directories
   run: |
@@ -409,7 +409,7 @@ jobs:
     cd python/test/regression
     python3 -m pytest -s -n 8 ./test_cast_matmul.py
 - name: Run Proton tests
-  if: ${{ matrix.runner[0] != 'gb200-runner-set' }}
+  if: ${{ matrix.runner[0] != 'nvidia-gb200' }}
   run: make test-proton
 - name: Run C++ unittests
   run: make test-cpp

.github/workflows/integration-tests.yml.in (6 additions & 6 deletions)

@@ -123,7 +123,7 @@ jobs:
   if: env.enable_integration == 'true'
   run: |
     if [ x"${{ github.repository }}" == x"triton-lang/triton" ]; then
-      echo '::set-output name=matrix-CUDA::[["a100-runner-set"], ["h100-runner-set"], ["gb200-runner-set"]]'
+      echo '::set-output name=matrix-CUDA::[["nvidia-a100"], ["nvidia-h100"], ["nvidia-gb200"]]'
       echo '::set-output name=matrix-HIP::[["self-hosted", "gfx90a"], ["amd-gfx942"]]'
       echo '::set-output name=matrix-MACOS::[["macos-latest"]]'
     else
@@ -264,7 +264,7 @@ jobs:
   env:
     CUDA_HOME: "/usr/local/cuda"
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     echo "PATH is '$PATH'"
@@ -281,18 +281,18 @@ jobs:

 - name: Run python tests on CUDA
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     make test-unit

 - name: Run interpreter tests
-  if: ${{ matrix.runner[0] == 'h100-runner-set' }}
+  if: ${{ matrix.runner[0] == 'nvidia-h100' }}
   run: make test-interpret

 - name: Run regression tests
   run: |
-    if [ "${{ matrix.runner[0] }}" == "gb200-runner-set" ]; then
+    if [ "${{ matrix.runner[0] }}" == "nvidia-gb200" ]; then
       source /venv/bin/activate
     fi
     make test-regression
@@ -303,7 +303,7 @@ jobs:

 - &run-proton-tests-step
   name: Run Proton tests
-  if: ${{ matrix.runner[0] != 'gb200-runner-set' }}
+  if: ${{ matrix.runner[0] != 'nvidia-gb200' }}
   run: make test-proton

 - *inspect-cache-directories-step

bench/bench/bench_mlp.py (4 additions & 4 deletions)

@@ -3,11 +3,11 @@
 import triton.profiler as proton
 import torch
 import triton_bench.swiglu
-from triton_bench.mxfp import downcast_to_mxfp
+from triton_bench.numerics_details.mxfp import downcast_to_mxfp
 from triton_bench.matmul_ogs import MicroscalingCtx, matmul_ogs, PrecisionConfig, FlexCtx
 from triton_bench.numerics import InFlexData
 from triton_bench.routing import routing
-from triton_bench.meta import cuda_capability_geq, is_hip, get_cdna_version
+from triton_bench.target_info import is_hip, get_cdna_version

 if torch.cuda.is_available() and not is_hip():
     from triton._C.libtriton import nvidia
@@ -152,5 +152,5 @@ def bench_mlp(batch, dim1, dim2, n_expts_tot, n_expts_act, x_dtype, w_dtype,
     qxdtype = "fp8" if has_native_mx4 else "bf16"
     print(bench_mlp(8192, 8192, 8192, 1, 1, "fp8", "fp8", TP=1, EP=1, name="dense"))
     print(bench_mlp(8192, 8192, 8192, 1, 1, qxdtype, "mx4", TP=1, EP=1, name="dense"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=2, name="llama4"))
-    print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=2, name="llama4"))
+    print(bench_mlp(2048, 5120, 8192, 128, 4, "fp8", "fp8", TP=4, EP=1, name="llama4"))
+    print(bench_mlp(2048, 5120, 8192, 128, 4, qxdtype, "mx4", TP=4, EP=1, name="llama4"))
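
The same module moves appear again in bench/tests/test_matmul.py below. For orientation, a minimal sketch of the new triton_bench import layout, using only paths shown in this commit's diffs:

# Sketch of the relocated triton_bench modules (paths taken from the diffs in this commit):
# mxfp helpers now live under numerics_details, and the target helpers formerly in
# triton_bench.meta are now imported from triton_bench.target_info.
from triton_bench.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
from triton_bench.target_info import is_hip, get_cdna_version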

bench/tests/test_compact.py renamed to bench/tests/test_compaction.py (4 additions & 4 deletions)

@@ -1,6 +1,6 @@
 import pytest
 import torch
-from triton_bench.compact import masked_compact, masked_compact_torch
+from triton_bench.compaction import compaction, compaction_torch


 @pytest.mark.parametrize("n_tokens, n_cols, k, p", [
@@ -9,7 +9,7 @@
     (131, 128, 16, 0.6),
     (496, 128, 16, 0.),
 ])
-def test_masked_compact(n_tokens, n_cols, k, p):
+def test_compaction(n_tokens, n_cols, k, p):
     device = "cuda"
     yi = torch.rand((n_tokens, n_cols), device=device).argsort(dim=-1)
     yi = yi[:, :k].to(torch.int32)
@@ -23,7 +23,7 @@ def test_masked_compact(n_tokens, n_cols, k, p):
     chunks = mask.view(*mask.shape[:-1], -1, 32)
     weights = (1 << torch.arange(32, dtype=torch.int32, device=device))
     bitmask = (chunks.int() * weights).sum(dim=-1)
-    yv_ref, yi_ref = masked_compact_torch(yv, yi, bitmask)
-    yv_tri, yi_tri = masked_compact(yv, yi, bitmask)
+    yv_ref, yi_ref = compaction_torch(yv, yi, bitmask)
+    yv_tri, yi_tri = compaction(yv, yi, bitmask)
     assert torch.all(yi_ref == yi_tri)
     assert torch.all(yv_ref == yv_tri)
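
The test above packs a boolean keep-mask into 32-bit words before calling compaction. A standalone sketch of that packing step (CPU tensors and shapes chosen here for illustration, not taken from the diff):

import torch

def pack_bitmask(mask: torch.Tensor) -> torch.Tensor:
    # Pack a boolean mask of shape (n_rows, n_cols) into one integer word per
    # 32 columns, mirroring the packing done in test_compaction.
    assert mask.shape[-1] % 32 == 0
    chunks = mask.view(*mask.shape[:-1], -1, 32)
    weights = 1 << torch.arange(32, dtype=torch.int32, device=mask.device)
    return (chunks.int() * weights).sum(dim=-1)

mask = torch.rand(4, 128) < 0.5      # keep roughly half of the columns
bitmask = pack_bitmask(mask)         # one word per 32 columns
row, col = 0, 5                      # spot-check one bit against the mask
assert bool((bitmask[row, col // 32] >> (col % 32)) & 1) == bool(mask[row, col])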

bench/tests/test_matmul.py (2 additions & 2 deletions)

@@ -11,11 +11,11 @@
 from triton_bench.matmul_ogs import matmul_ogs, matmul_ogs_torch
 # numerics utilities
 from triton_bench.numerics import InFlexData, OutFlexData
-from triton_bench.mxfp import downcast_to_mxfp, upcast_from_mxfp
+from triton_bench.numerics_details.mxfp import downcast_to_mxfp, upcast_from_mxfp
 # testing utilities
 from triton_bench.testing import assert_close, compute_actual_scale
 # target-specific utilities
-from triton_bench.meta import is_hip
+from triton_bench.target_info import is_hip

 # ---------------
 # initialize data

bench/tests/test_swiglu.py (14 additions & 6 deletions)

@@ -1,8 +1,12 @@
+from triton_bench.routing import routing_torch
 from triton_bench.swiglu import swiglu, swiglu_torch, PrecisionConfig
 from triton_bench.testing import assert_close
 import torch
 import pytest

+from .test_routing import init_data as init_routing_data
+from .test_routing import ref_expt_data
+
 # ---------------
 # initialize data
 # ---------------
@@ -15,10 +19,6 @@ def alloc_rand(shape, device, dtype, requires_grad=True):
     return torch.randn(shape, device=device, dtype=dtype, requires_grad=requires_grad)


-def alloc_rand_like(x):
-    return alloc_rand(x.shape, x.device, x.dtype, x.requires_grad)
-
-
 # ---------------
 # unit tests
 # ---------------
@@ -30,9 +30,17 @@ def test_op(M, N, limit, alpha=0.5):
     torch.manual_seed(2)
     dev = "cuda"
     dtype = torch.bfloat16
+    # initialize expert data
+    n_expts_tot = 6
+    n_expts_act = 2
+    logits = init_routing_data(M, n_expts_tot).detach()
+    routing_data, _, _ = routing_torch(logits, n_expts_act)
+    expt_data = ref_expt_data(routing_data, M * n_expts_act, block_m=128)
+    n_tokens = expt_data[2 * n_expts_tot].sum()
+
     # initialize data
-    x = alloc_rand([M, N], device=dev, dtype=torch.bfloat16)
+    x = alloc_rand([n_tokens, N], device=dev, dtype=dtype)
     precision_config = PrecisionConfig(limit=limit)
-    tri_y = swiglu(x, alpha, precision_config)
+    tri_y = swiglu(x, alpha, precision_config, expt_data, n_expts_tot)
     ref_y = swiglu_torch(x, alpha, precision_config)
     assert_close(tri_y, ref_y)

bench/triton_bench/compact.py renamed to bench/triton_bench/compaction.py (4 additions & 22 deletions)

@@ -1,27 +1,9 @@
 import torch
-import triton
-import triton.language as tl
+from .compaction_details._masked_compaction import _masked_compaction
 from triton_bench import Bitmatrix


-@triton.jit
-def _masked_compact(Yv, Yi, BitMask, stride_bm, RetYv, RetYi, sentinel, K: tl.constexpr):
-    pid_m = tl.program_id(0)
-    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
-    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
-    div = yi // 32
-    rem = yi % 32
-    active_bits = (tl.load(BitMask + pid_m * stride_bm + div) >> rem) & 1
-    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
-    rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K))
-    write_indx = exc_cumsum + rev_arange
-    yv = tl.where(active_bits, yv, sentinel)
-    yi = tl.where(active_bits, yi, sentinel)
-    tl.store(RetYv + pid_m * K + write_indx, yv)
-    tl.store(RetYi + pid_m * K + write_indx, yi)
-
-
-def masked_compact(yv, yi, bitmask, sentinel=-1):
+def compaction(yv, yi, bitmask, sentinel=-1):
     """
     Return compacted copies of *yv* and *yi* based on a per-row bitmask.

@@ -53,7 +35,7 @@ def masked_compact(yv, yi, bitmask, sentinel=-1):
     if isinstance(bitmask, Bitmatrix):
         bitmask = bitmask.data

-    _masked_compact[(n_rows, )](
+    _masked_compaction[(n_rows, )](
         yv, yi, bitmask, bitmask.stride(0),  # inputs
         ret_yv, ret_yi,  # outputs
         sentinel,  # sentinel
@@ -62,7 +44,7 @@ def masked_compact(yv, yi, bitmask, sentinel=-1):
     return ret_yv, ret_yi


-def masked_compact_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
+def compaction_torch(yv: torch.Tensor, yi: torch.Tensor, bitmask: torch.Tensor, sentinel=-1):
     """
     reference implementation of `masked_compact`
     """
bench/triton_bench/compaction_details/_masked_compaction.py (new file, 19 additions & 0 deletions; path as imported by compaction.py above)

@@ -0,0 +1,19 @@
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _masked_compaction(Yv, Yi, BitMask, stride_bm, RetYv, RetYi, sentinel, K: tl.constexpr):
+    pid_m = tl.program_id(0)
+    yv = tl.load(Yv + pid_m * K + tl.arange(0, K))
+    yi = tl.load(Yi + pid_m * K + tl.arange(0, K))
+    div = yi // 32
+    rem = yi % 32
+    active_bits = (tl.load(BitMask + pid_m * stride_bm + div) >> rem) & 1
+    exc_cumsum = tl.cumsum(active_bits, 0) - active_bits
+    rev_arange = tl.where(active_bits, 0, K - 1 - tl.arange(0, K))
+    write_indx = exc_cumsum + rev_arange
+    yv = tl.where(active_bits, yv, sentinel)
+    yi = tl.where(active_bits, yi, sentinel)
+    tl.store(RetYv + pid_m * K + write_indx, yv)
+    tl.store(RetYi + pid_m * K + write_indx, yi)
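
The kernel derives each element's destination slot from an exclusive cumulative sum over the active bits (for kept entries) plus a reversed arange (for dropped ones). A small plain-Python walk-through of that index math for a single row, using an illustrative K = 8 and an activity pattern not taken from the diff:

# One row, K = 8; active_bits[j] = 1 means index j's bit is set in the bitmask.
K = 8
active_bits = [1, 0, 1, 1, 0, 0, 1, 0]

exc_cumsum = []          # exclusive prefix sum of active_bits
running = 0
for a in active_bits:
    exc_cumsum.append(running)
    running += a

write_indx = [
    exc_cumsum[j] if a else exc_cumsum[j] + (K - 1 - j)
    for j, a in enumerate(active_bits)
]
print(write_indx)        # [0, 7, 1, 2, 6, 5, 3, 4]
# Kept entries (j = 0, 2, 3, 6) land in slots 0..3 in their original order;
# dropped entries are replaced by the sentinel and fill the trailing slots.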
