Skip to content

Commit 4df3eb8

Browse files
committed
add bf16 test
1 parent aecf9ab commit 4df3eb8

File tree

1 file changed

+40
-0
lines changed

1 file changed

+40
-0
lines changed

test/Transforms/BF16ToGPU/EltwiseAdd.bf16.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,43 @@ module @eltwise_add attributes {gpu.container_module} {
6565
}
6666
func.func private @printMemrefBF16(memref<*xbf16>)
6767
}
68+
69+
70+
module @eltwise_add_usm attributes {gpu.container_module} {
71+
memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01>
72+
func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
73+
%c20 = arith.constant 20 : index
74+
%c10 = arith.constant 10 : index
75+
%c1 = arith.constant 1 : index
76+
%memref_1 = gpu.alloc host_shared () : memref<10x20xbf16>
77+
gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%arg0 : memref<10x20xbf16>, %arg1 : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>)
78+
%alloc = memref.alloc() : memref<10x20xbf16>
79+
memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16>
80+
gpu.dealloc %memref_1 : memref<10x20xbf16>
81+
return %alloc : memref<10x20xbf16>
82+
}
83+
gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
84+
gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
85+
%block_id_x = gpu.block_id x
86+
%block_id_y = gpu.block_id y
87+
%cst = arith.constant 0.5 : bf16
88+
%0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
89+
%1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
90+
%2 = arith.addf %0, %1 : bf16
91+
%3 = arith.addf %2, %cst : bf16
92+
memref.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
93+
gpu.return
94+
}
95+
}
96+
func.func @main() {
97+
%0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
98+
%1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
99+
%2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16>
100+
%cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16>
101+
// CHECK: Unranked Memref base@ = {{(0x)?[-9a-f]*}}
102+
// CHECK-COUNT-200: 1.5
103+
call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
104+
return
105+
}
106+
func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
107+
}

0 commit comments

Comments
 (0)