
Commit 0f32809

Reland [mlir][x86vector] AVX Convert/Broadcast BF16 to F32 instructions (#136830)
Quick fix for PR #135143, which failed to build on the `amd` and `arm` bots. See the logs in that PR for the errors.
1 parent 013aab4 commit 0f32809

9 files changed: +341, -12 lines

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 121 additions & 2 deletions
@@ -83,7 +83,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
 }
 
@@ -404,8 +404,127 @@ def DotOp : AVX_LowOp<"dot", [Pure,
     }
   }];
   let extraClassDeclaration = [{
-    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&);
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
   }];
 }
 
+
+//----------------------------------------------------------------------------//
+// AVX: Convert packed BF16 even-indexed/odd-indexed elements into packed F32
+//----------------------------------------------------------------------------//
+
+def CvtPackedEvenIndexedBF16ToF32Op : AVX_Op<"cvt.packed.even.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Convert packed BF16 even-indexed elements into packed F32 Data.";
+  let description = [{
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed BF16 (16-bit) floating-point even-indexed elements stored at
+    memory locations starting at location `__A` to packed single-precision
+    (32-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vcvtneebf162ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+}
+
+def CvtPackedOddIndexedBF16ToF32Op : AVX_Op<"cvt.packed.odd.indexed.bf16_to_f32", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Convert packed BF16 odd-indexed elements into packed F32 Data.";
+  let description = [{
+    #### From the Intel Intrinsics Guide:
+
+    Convert packed BF16 (16-bit) floating-point odd-indexed elements stored at
+    memory locations starting at location `__A` to packed single-precision
+    (32-bit) floating-point elements, and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vcvtneobf162ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+}
+
+//----------------------------------------------------------------------------//
+// AVX: Convert BF16 to F32 and broadcast into packed F32
+//----------------------------------------------------------------------------//
+
+def BcstBF16ToPackedF32Op : AVX_Op<"bcst.bf16_to_f32.packed", [MemoryEffects<[MemRead]>,
+    DeclareOpInterfaceMethods<OneToOneIntrinsicOpInterface>]> {
+  let summary = "AVX: Broadcasts BF16 into packed F32 Data.";
+  let description = [{
+    #### From the Intel Intrinsics Guide:
+
+    Convert scalar BF16 (16-bit) floating-point element stored at memory locations
+    starting at location `__A` to a single-precision (32-bit) floating-point,
+    broadcast it to packed single-precision (32-bit) floating-point elements,
+    and store the results in `dst`.
+
+    Example:
+    ```mlir
+    %dst = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+    ```
+  }];
+  let arguments = (ins AnyMemRef:$a);
+  let results = (outs VectorOfLengthAndType<[4, 8], [F32]>:$dst);
+  let assemblyFormat =
+    "$a attr-dict`:` type($a)`->` type($dst)";
+
+  let extraClassDefinition = [{
+    std::string $cppClass::getIntrinsicName() {
+      std::string intr = "llvm.x86.vbcstnebf162ps";
+      VectorType vecType = getDst().getType();
+      unsigned elemBitWidth = vecType.getElementTypeBitWidth();
+      unsigned opBitWidth = vecType.getShape()[0] * elemBitWidth;
+      intr += std::to_string(opBitWidth);
+      return intr;
+    }
+  }];
+
+  let extraClassDeclaration = [{
+    SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
+  }];
+
+}
 
 #endif // X86VECTOR_OPS
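
Note: the even- and odd-indexed conversions are complementary, together covering every bf16 element of the source buffer, and `getIntrinsicName` selects the 128- or 256-bit intrinsic variant from the result's total bit width (e.g. 8 x 32 = 256). A minimal illustrative sketch; the function and value names below are invented for this note and are not part of the patch:

```mlir
// Hypothetical example: widen a memref<16xbf16> into its even- and
// odd-indexed halves as f32 (each 8 x 32 bits, so the 256-bit intrinsic
// variants are picked after lowering).
func.func @split_bf16_to_f32(%src: memref<16xbf16>)
    -> (vector<8xf32>, vector<8xf32>) {
  %even = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %src
            : memref<16xbf16> -> vector<8xf32>
  %odd  = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %src
            : memref<16xbf16> -> vector<8xf32>
  return %even, %odd : vector<8xf32>, vector<8xf32>
}
```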

mlir/include/mlir/Dialect/X86Vector/X86VectorDialect.h

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@
 #define MLIR_DIALECT_X86VECTOR_X86VECTORDIALECT_H_
 
 #include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Conversion/LLVMCommon/Pattern.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
       }],
       /*retType=*/"SmallVector<Value>",
       /*methodName=*/"getIntrinsicOperands",
-      /*args=*/(ins "::mlir::RewriterBase &":$rewriter),
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
       /*methodBody=*/"",
       /*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
     >,

mlir/lib/Dialect/X86Vector/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ add_mlir_dialect_library(MLIRX86VectorDialect
 
   LINK_LIBS PUBLIC
   MLIRIR
+  MLIRLLVMCommonConversion
   MLIRLLVMDialect
   MLIRSideEffectInterfaces
   )

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 41 additions & 3 deletions
@@ -31,6 +31,26 @@ void x86vector::X86VectorDialect::initialize() {
       >();
 }
 
+static SmallVector<Value>
+getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
+                 RewriterBase &rewriter,
+                 const LLVMTypeConverter &typeConverter) {
+  SmallVector<Value> operands;
+  auto opType = memrefVal.getType();
+
+  Type llvmStructType = typeConverter.convertType(opType);
+  Value llvmStruct =
+      rewriter
+          .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
+          .getResult(0);
+  MemRefDescriptor memRefDescriptor(llvmStruct);
+
+  Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
+  operands.push_back(ptr);
+
+  return operands;
+}
+
 LogicalResult x86vector::MaskCompressOp::verify() {
   if (getSrc() && getConstantSrc())
     return emitError("cannot use both src and constant_src");
@@ -45,8 +65,8 @@ LogicalResult x86vector::MaskCompressOp::verify() {
   return success();
 }
 
-SmallVector<Value>
-x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
+SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
   auto loc = getLoc();
 
   auto opType = getA().getType();
@@ -64,7 +84,8 @@ x86vector::MaskCompressOp::getIntrinsicOperands(RewriterBase &rewriter) {
 }
 
 SmallVector<Value>
-x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
+x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
+                                       const LLVMTypeConverter &typeConverter) {
   SmallVector<Value> operands(getOperands());
   // Dot product of all elements, broadcasted to all elements.
   Value scale =
@@ -74,5 +95,22 @@ x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter) {
   return operands;
 }
 
+SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
+SmallVector<Value>
+x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
+    RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
+  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
+}
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
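
For context, `getMemrefBuffPtr` casts the memref operand to its LLVM descriptor struct and hands back a single buffer pointer. The sketch below shows roughly the IR this materializes for a `memref<16xbf16>` operand, assuming the default descriptor layout; `MemRefDescriptor::bufferPtr` would additionally fold in a dynamic offset, which an identity-layout `memref<16xbf16>` does not carry:

```mlir
// Sketch only: the pass emits this descriptor handling internally.
func.func @buffer_ptr_sketch(%a: memref<16xbf16>) -> !llvm.ptr {
  %desc = builtin.unrealized_conversion_cast %a
      : memref<16xbf16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
  // Field 1 of the descriptor is the aligned buffer pointer.
  %aligned = llvm.extractvalue %desc[1]
      : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
  return %aligned : !llvm.ptr
}
```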

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 6 additions & 5 deletions
@@ -96,8 +96,8 @@ struct OneToOneIntrinsicOpConversion
   LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
                                 PatternRewriter &rewriter) const override {
     return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
-                            op.getIntrinsicOperands(rewriter), typeConverter,
-                            rewriter);
+                            op.getIntrinsicOperands(rewriter, typeConverter),
+                            typeConverter, rewriter);
   }
 
 private:
@@ -114,7 +114,8 @@ void mlir::populateX86VectorLegalizeForLLVMExportPatterns(
 
 void mlir::configureX86VectorLegalizeForExportTarget(
     LLVMConversionTarget &target) {
-  target.addIllegalOp<MaskCompressOp, MaskRndScaleOp, MaskScaleFOp,
-                      Vp2IntersectOp, DotBF16Op, CvtPackedF32ToBF16Op, RsqrtOp,
-                      DotOp>();
+  target.addIllegalOp<
+      MaskCompressOp, MaskRndScaleOp, MaskScaleFOp, Vp2IntersectOp, DotBF16Op,
+      CvtPackedF32ToBF16Op, CvtPackedEvenIndexedBF16ToF32Op,
+      CvtPackedOddIndexedBF16ToF32Op, BcstBF16ToPackedF32Op, RsqrtOp, DotOp>();
 }
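
Marking the new ops illegal routes them through `OneToOneIntrinsicOpConversion`, which now forwards the type converter into `getIntrinsicOperands` so the memref operand can be replaced by a raw pointer. A comment-only sketch of the before/after shape for one op (SSA names invented; the exact surrounding IR depends on the pass pipeline):

```mlir
// Input (illustrative):
//   %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a
//          : memref<16xbf16> -> vector<8xf32>
//
// lowers to roughly (compare the CHECK lines in the tests below):
//
//   %0 = llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"(%ptr)
//          : (!llvm.ptr) -> vector<8xf32>
```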

mlir/test/Dialect/X86Vector/legalize-for-llvm.mlir

Lines changed: 54 additions & 0 deletions
@@ -95,6 +95,60 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+  %a: memref<8xbf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps128"
+  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+  %a: memref<16xbf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneebf162ps256"
+  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+  %a: memref<8xbf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps128"
+  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+  %a: memref<16xbf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vcvtneobf162ps256"
+  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_128
+func.func @avxbf16_bsct_bf16_to_f32_packed_128(
+  %a: memref<1xbf16>) -> vector<4xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps128"
+  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bsct_bf16_to_f32_packed_256
+func.func @avxbf16_bsct_bf16_to_f32_packed_256(
+  %a: memref<1xbf16>) -> vector<8xf32>
+{
+  // CHECK: llvm.call_intrinsic "llvm.x86.vbcstnebf162ps256"
+  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {

mlir/test/Dialect/X86Vector/roundtrip.mlir

Lines changed: 60 additions & 0 deletions
@@ -94,6 +94,66 @@ func.func @avx512bf16_cvt_packed_f32_to_bf16_512(
   return %0 : vector<16xbf16>
 }
 
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_128(
+  %a: memref<8xbf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_even_indexed_bf16_to_f32_256(
+  %a: memref<16xbf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.even.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_128(
+  %a: memref<8xbf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<8xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<8xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256
+func.func @avxbf16_cvt_packed_odd_indexed_bf16_to_f32_256(
+  %a: memref<16xbf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 {{.*}} :
+  // CHECK-SAME: memref<16xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.cvt.packed.odd.indexed.bf16_to_f32 %a : memref<16xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_128
+func.func @avxbf16_bcst_bf16_to_f32_128(
+  %a: memref<1xbf16>) -> vector<4xf32>
+{
+  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xbf16> -> vector<4xf32>
+  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<4xf32>
+  return %0 : vector<4xf32>
+}
+
+// CHECK-LABEL: func @avxbf16_bcst_bf16_to_f32_256
+func.func @avxbf16_bcst_bf16_to_f32_256(
+  %a: memref<1xbf16>) -> vector<8xf32>
+{
+  // CHECK: x86vector.avx.bcst.bf16_to_f32.packed {{.*}} :
+  // CHECK-SAME: memref<1xbf16> -> vector<8xf32>
+  %0 = x86vector.avx.bcst.bf16_to_f32.packed %a : memref<1xbf16> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
 // CHECK-LABEL: func @avx_rsqrt
 func.func @avx_rsqrt(%a: vector<8xf32>) -> (vector<8xf32>)
 {
