|
| 1 | +from textwrap import dedent |
| 2 | + |
| 3 | +import pytest |
| 4 | + |
| 5 | +import mlir_utils.types as T |
| 6 | +from mlir_utils.dialects.ext.arith import constant |
| 7 | +from mlir_utils.dialects.ext.func import func |
| 8 | +from mlir_utils.dialects.ext.nvgpu import tensormap_descriptor |
| 9 | +from mlir_utils.dialects.memref import cast |
| 10 | +from mlir_utils.dialects.nvgpu import tma_create_descriptor |
| 11 | + |
| 12 | +# noinspection PyUnresolvedReferences |
| 13 | +from mlir_utils.testing import mlir_ctx as ctx, filecheck, MLIRContext |
| 14 | + |
| 15 | +# needed since the fix isn't defined here nor conftest.py |
| 16 | +pytest.mark.usefixtures("ctx") |
| 17 | + |
| 18 | + |
| 19 | +def test_basic(ctx: MLIRContext): |
| 20 | + @func |
| 21 | + def create_tensor_map( |
| 22 | + device_ptr_2d: T.memref(64, 128, element_type=T.f32), |
| 23 | + ): |
| 24 | + crd0 = constant(64, index=True) |
| 25 | + crd1 = constant(128, index=True) |
| 26 | + device_ptr_2d_unranked = cast(T.memref(element_type=T.f32), device_ptr_2d) |
| 27 | + tensor_map_2d = tensormap_descriptor(T.memref(32, 32, T.f32, memory_space=3)) |
| 28 | + tensor_map_2d = tma_create_descriptor( |
| 29 | + tensor_map_2d, device_ptr_2d_unranked, [crd0, crd1] |
| 30 | + ) |
| 31 | + |
| 32 | + create_tensor_map.emit() |
| 33 | + |
| 34 | + ctx.module.operation.verify() |
| 35 | + correct = dedent( |
| 36 | + """\ |
| 37 | + module { |
| 38 | + func.func @create_tensor_map(%arg0: memref<64x128xf32>) { |
| 39 | + %c64 = arith.constant 64 : index |
| 40 | + %c128 = arith.constant 128 : index |
| 41 | + %cast = memref.cast %arg0 : memref<64x128xf32> to memref<*xf32> |
| 42 | + %0 = nvgpu.tma.create.descriptor %cast box[%c64, %c128] : memref<*xf32> -> <tensor = memref<32x32xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> |
| 43 | + return |
| 44 | + } |
| 45 | + } |
| 46 | + """ |
| 47 | + ) |
| 48 | + filecheck(correct, ctx.module) |
0 commit comments