-
Notifications
You must be signed in to change notification settings - Fork 14.6k
[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering #151175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering #151175
Conversation
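@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Yang Bai (yangtetris) Changes: This patch extends the VectorFromElementsLowering conversion pattern to support vectors of any rank, removing the previous restriction to 0D/1D vectors only. Implementation Details: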
Example: // Before: Failed for rank > 1
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>
// After: Converts to nested aggregate
%poison = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>>
%inner0 = llvm.insertelement %e0, %poison_1d[%c0] : vector<2xf32>
%inner0 = llvm.insertelement %e1, %inner0[%c1] : vector<2xf32>
%inner1 = llvm.insertelement %e2, %poison_1d[%c0] : vector<2xf32>
%inner1 = llvm.insertelement %e3, %inner1[%c1] : vector<2xf32>
%result = llvm.insertvalue %inner0, %poison[0] : !llvm.array<2 x vector<2xf32>>
%result = llvm.insertvalue %inner1, %result[1] : !llvm.array<2 x vector<2xf32>> Full diff: https://github.com/llvm/llvm-project/pull/151175.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
ConversionPatternRewriter &rewriter) const override {
Location loc = fromElementsOp.getLoc();
VectorType vectorType = fromElementsOp.getType();
- // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
- // Such ops should be handled in the same way as vector.insert.
- if (vectorType.getRank() > 1)
- return rewriter.notifyMatchFailure(fromElementsOp,
- "rank > 1 vectors are not supported");
Type llvmType = typeConverter->convertType(vectorType);
- Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
- for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
- result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+ Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+ Value result;
+ // 0D vectors are converted to length-1 1D vectors by LLVMTypeConverter.
+ if (vectorType.getRank() == 0) {
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+ auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+ result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+ rewriter.replaceOp(fromElementsOp, result);
+ return success();
+ }
+
+ // Build 1D vectors for the innermost dimension
+ int64_t innerDimSize = vectorType.getShape().back();
+ int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+ SmallVector<Value> innerVectors;
+ innerVectors.reserve(numInnerVectors);
+
+ auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+ Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+ int64_t elementInVectorIdx = 0;
+ Value innerVector;
+ for (auto val : adaptor.getElements()) {
+ if (elementInVectorIdx == 0)
+ innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+ auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+ innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+ if (++elementInVectorIdx == innerDimSize) {
+ innerVectors.push_back(innerVector);
+ elementInVectorIdx = 0;
+ }
+ }
+
+ // For 1D vectors, we can just return the first innermost vector.
+ if (vectorType.getRank() == 1) {
+ rewriter.replaceOp(fromElementsOp, innerVectors.front());
+ return success();
+ }
+
+ // Now build the nested aggregate structure from these 1D vectors.
+ result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+
+ // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+ // insert the 1D vectors into the aggregate.
+ auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+ if (!vectorTypeInfo.llvmNDVectorTy)
+ return failure();
+ int64_t vectorIdx = 0;
+ nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+ result = LLVM::InsertValueOp::create(rewriter, loc, result,
+ innerVectors[vectorIdx++], position);
+ });
+
rewriter.replaceOp(fromElementsOp, result);
return success();
}
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// CHECK-LABEL: func.func @from_elements_3d(
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+// CHECK: %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+// CHECK: %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+// CHECK: %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+// CHECK: %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+// CHECK: %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+// CHECK: %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+// CHECK: %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+// CHECK: return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+ return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.to_elements
//===----------------------------------------------------------------------===//
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just minor comments. Feel free to address them before landing. Thanks!
// Use the same iteration approach as VectorBroadcastScalarToNdLowering to | ||
// insert the 1D vectors into the aggregate. | ||
auto vectorTypeInfo = | ||
LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter()); | ||
if (!vectorTypeInfo.llvmNDVectorTy) | ||
return failure(); | ||
int64_t vectorIdx = 0; | ||
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) { | ||
result = LLVM::InsertValueOp::create(rewriter, loc, result, | ||
innerVectors[vectorIdx++], position); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a chance to refactor this code for both cases? This sounds like a common pattern that other ops might need as well...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah. Other vector ops might also use this pattern. I just added a new overload to `nDVectorIterate` which accepts a `VectorType` and internally calls `extractNDVectorTypeInfo`. But I didn't change the usage in `VectorBroadcastScalarToNdLowering`, because it needs to do some things first that depend on `extractNDVectorTypeInfo` before it can execute `nDVectorIterate`.
Co-authored-by: Nicolas Vasilache <[email protected]>
This patch extends the
VectorFromElementsLowering
conversion pattern to support vectors of any rank, removing the previous restriction to 0D/1D vectors only.
Implementation Details:
length-1 1D vectors
llvm.insertelement
llvm.insertvalue
and
nDVectorIterate
Example: