@@ -65,3 +65,43 @@ module @eltwise_add attributes {gpu.container_module} {
65
65
}
66
66
func.func private @printMemrefBF16 (memref <*xbf16 >)
67
67
}
68
+
69
+
70
// Unified-shared-memory (USM) variant of the eltwise-add test: the kernel
// reads the two input memrefs directly and writes its result into a
// host_shared GPU allocation that the host then copies out.
module @eltwise_add_usm attributes {gpu.container_module} {
  // Both inputs are this 0.5-filled constant; the kernel adds a further 0.5,
  // so every one of the 10x20 = 200 output elements is expected to be 1.5.
  memref.global "private" constant @__constant_10x20xbf16 : memref<10x20xbf16> = dense<5.000000e-01>
  // Launches one workgroup per element over a 10x20 grid, copies the
  // host_shared result into a plain host allocation, and releases the GPU
  // buffer before returning.
  func.func @test(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>) -> memref<10x20xbf16> {
    %c20 = arith.constant 20 : index
    %c10 = arith.constant 10 : index
    %c1 = arith.constant 1 : index
    // Output buffer visible to both host and device (USM).
    %memref_1 = gpu.alloc host_shared () : memref<10x20xbf16>
    gpu.launch_func @test_kernel::@test_kernel blocks in (%c10, %c20, %c1) threads in (%c1, %c1, %c1) args(%arg0 : memref<10x20xbf16>, %arg1 : memref<10x20xbf16>, %memref_1 : memref<10x20xbf16>)
    // Copy to ordinary host memory so the GPU buffer can be freed here.
    %alloc = memref.alloc() : memref<10x20xbf16>
    memref.copy %memref_1, %alloc : memref<10x20xbf16> to memref<10x20xbf16>
    gpu.dealloc %memref_1 : memref<10x20xbf16>
    return %alloc : memref<10x20xbf16>
  }
  gpu.module @test_kernel attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Float16Buffer, Int64, Int16, Int8, Kernel, Linkage, Vector16, GenericPointer, Groups, Float16, Float64, AtomicFloat32AddEXT, ExpectAssumeKHR, SubgroupDispatch, VectorComputeINTEL, VectorAnyINTEL, Bfloat16ConversionINTEL], [SPV_EXT_shader_atomic_float_add, SPV_KHR_expect_assume, SPV_INTEL_vector_compute, SPV_INTEL_bfloat16_conversion]>, api=OpenCL, #spirv.resource_limits<>>} {
    // Single-thread workgroups: (block_id_x, block_id_y) indexes the element.
    gpu.func @test_kernel(%arg0: memref<10x20xbf16>, %arg1: memref<10x20xbf16>, %arg2: memref<10x20xbf16>) kernel attributes {VectorComputeFunctionINTEL, gpu.known_block_size = array<i32: 1, 1, 1>, gpu.known_grid_size = array<i32: 10, 20, 1>, spirv.entry_point_abi = #spirv.entry_point_abi<>} {
      %block_id_x = gpu.block_id x
      %block_id_y = gpu.block_id y
      %cst = arith.constant 0.5 : bf16
      %0 = memref.load %arg0[%block_id_x, %block_id_y] : memref<10x20xbf16>
      %1 = memref.load %arg1[%block_id_x, %block_id_y] : memref<10x20xbf16>
      // out = a + b + 0.5
      %2 = arith.addf %0, %1 : bf16
      %3 = arith.addf %2, %cst : bf16
      memref.store %3, %arg2[%block_id_x, %block_id_y] : memref<10x20xbf16>
      gpu.return
    }
  }
  func.func @main() {
    %0 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
    %1 = memref.get_global @__constant_10x20xbf16 : memref<10x20xbf16>
    %2 = call @test(%0, %1) : (memref<10x20xbf16>, memref<10x20xbf16>) -> memref<10x20xbf16>
    %cast = memref.cast %2 : memref<10x20xbf16> to memref<*xbf16>
    // CHECK: Unranked Memref base@ = {{(0x)?[0-9a-f]*}}
    // CHECK-COUNT-200: 1.5
    call @printMemrefBF16(%cast) : (memref<*xbf16>) -> ()
    return
  }
  func.func private @printMemrefBF16(memref<*xbf16>) attributes {llvm.emit_c_interface}
}
0 commit comments