Skip to content

Commit 139e376

Browse files
committed
nvgpu/nvvm
1 parent ed1f046 commit 139e376

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

mlir_utils/dialects/ext/nvgpu.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from textwrap import dedent
2+
3+
from mlir.dialects._nvgpu_enum_gen import (
4+
TensorMapSwizzleKind,
5+
TensorMapL2PromoKind,
6+
TensorMapOOBKind,
7+
TensorMapInterleaveKind,
8+
)
9+
from mlir.ir import Attribute, Type
10+
11+
12+
def tensormap_descriptor(
    tensor,
    swizzle: TensorMapSwizzleKind = TensorMapSwizzleKind.SWIZZLE_NONE,
    l2promo: TensorMapL2PromoKind = TensorMapL2PromoKind.L2PROMO_NONE,
    oob: TensorMapOOBKind = TensorMapOOBKind.OOB_NAN,
    interleave: TensorMapInterleaveKind = TensorMapInterleaveKind.INTERLEAVE_NONE,
    context=None,
):
    """Build an ``!nvgpu.tensormap.descriptor`` type from its textual form.

    Every argument except ``context`` is interpolated into the type's assembly
    text with f-string formatting, and the resulting text is handed to
    ``Type.parse``.

    Args:
        tensor: Value rendered into the ``tensor = ...`` field (the tests pass
            a memref type here).
        swizzle: Value rendered into the ``swizzle = ...`` field.
        l2promo: Value rendered into the ``l2promo = ...`` field.
        oob: Value rendered into the ``oob = ...`` field.
        interleave: Value rendered into the ``interleave = ...`` field.
        context: Forwarded unchanged to ``Type.parse`` as ``context=``.

    Returns:
        Whatever ``Type.parse`` returns for the assembled descriptor text.
    """
    # dedent() removes the common leading indentation so the text handed to the
    # parser starts directly with the `!nvgpu...` type mnemonic.
    descriptor_asm = dedent(
        f"""\
        !nvgpu.tensormap.descriptor<tensor = {tensor},
        swizzle = {swizzle},
        l2promo = {l2promo},
        oob = {oob},
        interleave = {interleave}>
        """
    )
    return Type.parse(descriptor_asm, context=context)

tests/test_nvgpu_nvvm.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
from textwrap import dedent
2+
3+
import pytest
4+
5+
import mlir_utils.types as T
6+
from mlir_utils.dialects.ext.arith import constant
7+
from mlir_utils.dialects.ext.func import func
8+
from mlir_utils.dialects.ext.nvgpu import tensormap_descriptor
9+
from mlir_utils.dialects.memref import cast
10+
from mlir_utils.dialects.nvgpu import tma_create_descriptor
11+
12+
# noinspection PyUnresolvedReferences
13+
from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext
14+
15+
# Needed since the `ctx` fixture isn't defined here nor in conftest.py.
# The marker must be bound to the module-level name `pytestmark` for pytest to
# apply it; a bare `pytest.mark.usefixtures(...)` expression is discarded.
pytestmark = pytest.mark.usefixtures("ctx")
17+
18+
19+
def test_basic(ctx: MLIRContext):
    """Emit a function that creates a TMA descriptor and check the printed IR."""

    @func
    def create_tensor_map(
        device_ptr_2d: T.memref(64, 128, element_type=T.f32),
    ):
        # Box dimensions handed to the descriptor-creation op.
        box_dim0 = constant(64, index=True)
        box_dim1 = constant(128, index=True)
        # The expected IR below shows this cast producing memref<*xf32>.
        unranked_ptr = cast(T.memref(element_type=T.f32), device_ptr_2d)
        descriptor_type = tensormap_descriptor(
            T.memref(32, 32, T.f32, memory_space=3)
        )
        tensor_map_2d = tma_create_descriptor(
            descriptor_type, unranked_ptr, [box_dim0, box_dim1]
        )

    create_tensor_map.emit()

    ctx.module.operation.verify()

    expected_ir = dedent(
        """\
        module {
          func.func @create_tensor_map(%arg0: memref<64x128xf32>) {
            %c64 = arith.constant 64 : index
            %c128 = arith.constant 128 : index
            %cast = memref.cast %arg0 : memref<64x128xf32> to memref<*xf32>
            %0 = nvgpu.tma.create.descriptor %cast box[%c64, %c128] : memref<*xf32> -> <tensor = memref<32x32xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>
            return
          }
        }
        """
    )
    filecheck(expected_ir, ctx.module)

0 commit comments

Comments
 (0)