Skip to content

[mlir][vector] Support multi-dimensional vectors in VectorFromElementsLowering #151175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

yangtetris
Copy link
Contributor

@yangtetris yangtetris commented Jul 29, 2025

This patch extends the VectorFromElementsLowering conversion pattern to support
vectors of any rank, removing the previous restriction to 0D/1D vectors only.

Implementation Details:

  1. 0D vectors: Handled explicitly since LLVMTypeConverter converts them to
    length-1 1D vectors
  2. 1D vectors: Direct construction using llvm.insertelement operations
  3. N-D vectors: Two-phase construction:
    • Build 1D vectors for the innermost dimension using llvm.insertelement
    • Assemble them into the nested aggregate structure using llvm.insertvalue
      and nDVectorIterate
  4. Use direct LLVM dialect operations instead of intermediate vector.insert operations for efficiency

Example:

// Before: Failed for rank > 1
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>

// After: Converts to nested aggregate
%poison = llvm.mlir.poison : !llvm.array<2 x vector<2xf32>>
%inner0 = llvm.insertelement %e0, %poison_1d[%c0] : vector<2xf32>
%inner0 = llvm.insertelement %e1, %inner0[%c1] : vector<2xf32>
%inner1 = llvm.insertelement %e2, %poison_1d[%c0] : vector<2xf32>
%inner1 = llvm.insertelement %e3, %inner1[%c1] : vector<2xf32>
%result = llvm.insertvalue %inner0, %poison[0] : !llvm.array<2 x vector<2xf32>>
%result = llvm.insertvalue %inner1, %result[1] : !llvm.array<2 x vector<2xf32>>

@llvmbot
Copy link
Member

llvmbot commented Jul 29, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Yang Bai (yangtetris)

Changes

This patch extends the VectorFromElementsLowering conversion pattern to support
vectors of any rank, removing the previous restriction to 1D vectors only.

Implementation Details:

  1. 0D vectors: Handled explicitly since LLVMTypeConverter converts them to
    length-1 1D vectors
  2. 1D vectors: Direct construction using llvm.insertelement operations
  3. N-D vectors: Two-phase construction:
    • Build 1D vectors for the innermost dimension using llvm.insertelement
    • Assemble them into the nested aggregate structure using llvm.insertvalue
      and nDVectorIterate
  4. Use direct LLVM dialect operations instead of intermediate vector.insert operations for efficiency

Example:

// Before: Failed for rank &gt; 1
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector&lt;2x2xf32&gt;

// After: Converts to nested aggregate
%poison = llvm.mlir.poison : !llvm.array&lt;2 x vector&lt;2xf32&gt;&gt;
%inner0 = llvm.insertelement %e0, %poison_1d[%c0] : vector&lt;2xf32&gt;
%inner0 = llvm.insertelement %e1, %inner0[%c1] : vector&lt;2xf32&gt;
%inner1 = llvm.insertelement %e2, %poison_1d[%c0] : vector&lt;2xf32&gt;
%inner1 = llvm.insertelement %e3, %inner1[%c1] : vector&lt;2xf32&gt;
%result = llvm.insertvalue %inner0, %poison[0] : !llvm.array&lt;2 x vector&lt;2xf32&gt;&gt;
%result = llvm.insertvalue %inner1, %result[1] : !llvm.array&lt;2 x vector&lt;2xf32&gt;&gt;

Full diff: https://github.com/llvm/llvm-project/pull/151175.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp (+55-8)
  • (modified) mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir (+24)
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 17a79e3815b97..26d056cadb19c 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1890,15 +1890,62 @@ struct VectorFromElementsLowering
                   ConversionPatternRewriter &rewriter) const override {
     Location loc = fromElementsOp.getLoc();
     VectorType vectorType = fromElementsOp.getType();
-    // TODO: Multi-dimensional vectors lower to !llvm.array<... x vector<>>.
-    // Such ops should be handled in the same way as vector.insert.
-    if (vectorType.getRank() > 1)
-      return rewriter.notifyMatchFailure(fromElementsOp,
-                                         "rank > 1 vectors are not supported");
     Type llvmType = typeConverter->convertType(vectorType);
-    Value result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
-    for (auto [idx, val] : llvm::enumerate(adaptor.getElements()))
-      result = vector::InsertOp::create(rewriter, loc, val, result, idx);
+    Type llvmIndexType = typeConverter->convertType(rewriter.getIndexType());
+
+    Value result;
+    // 0D vectors are converted to legnth-1 1D vectors by LLVMTypeConverter.
+    if (vectorType.getRank() == 0) {
+      result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+      auto index0 = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, 0);
+      result = LLVM::InsertElementOp::create(rewriter, loc, result, adaptor.getElements().front(), index0);
+      rewriter.replaceOp(fromElementsOp, result);
+      return success();
+    }
+    
+    // Build 1D vectors for the innermost dimension
+    int64_t innerDimSize = vectorType.getShape().back();
+    int64_t numInnerVectors = vectorType.getNumElements() / innerDimSize;
+
+    SmallVector<Value> innerVectors;
+    innerVectors.reserve(numInnerVectors);
+
+    auto innerVectorType = VectorType::get(innerDimSize, vectorType.getElementType());
+    Type llvmInnerType = typeConverter->convertType(innerVectorType);
+
+    int64_t elementInVectorIdx = 0;
+    Value innerVector;
+    for (auto val : adaptor.getElements()) {
+      if (elementInVectorIdx == 0)
+        innerVector = LLVM::PoisonOp::create(rewriter, loc, llvmInnerType);
+      auto position = LLVM::ConstantOp::create(rewriter, loc, llvmIndexType, elementInVectorIdx);
+      innerVector = LLVM::InsertElementOp::create(rewriter, loc, llvmInnerType, innerVector, val, position);
+      if (++elementInVectorIdx == innerDimSize) {
+        innerVectors.push_back(innerVector);
+        elementInVectorIdx = 0;
+      }
+    }
+
+    // For 1D vectors, we can just return the first innermost vector.
+    if (vectorType.getRank() == 1) {
+      rewriter.replaceOp(fromElementsOp, innerVectors.front());
+      return success();
+    }
+
+    // Now build the nested aggregate structure from these 1D vectors.
+    result = LLVM::PoisonOp::create(rewriter, loc, llvmType);
+    
+    // Use the same iteration approach as VectorBroadcastScalarToNdLowering to
+    // insert the 1D vectors into the aggregate.
+    auto vectorTypeInfo = LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
+    if (!vectorTypeInfo.llvmNDVectorTy)
+      return failure();
+    int64_t vectorIdx = 0;
+    nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
+      result = LLVM::InsertValueOp::create(rewriter, loc, result, 
+                                           innerVectors[vectorIdx++], position);
+    });
+    
     rewriter.replaceOp(fromElementsOp, result);
     return success();
   }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 31e17fb3e3cc6..834858c0b7c8f 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -2286,6 +2286,30 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
 
 // -----
 
+// CHECK-LABEL: func.func @from_elements_3d(
+//  CHECK-SAME:     %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32)
+//       CHECK:   %[[UNDEF_VEC0:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_0:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC0_0:.*]] = llvm.insertelement %[[ARG_0]], %[[UNDEF_VEC0]][%[[C0_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_0:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC0_1:.*]] = llvm.insertelement %[[ARG_1]], %[[VEC0_0]][%[[C1_0]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_VEC1:.*]] = llvm.mlir.poison : vector<2xf32>
+//       CHECK:   %[[C0_1:.*]] = llvm.mlir.constant(0 : i64) : i64
+//       CHECK:   %[[VEC1_0:.*]] = llvm.insertelement %[[ARG_2]], %[[UNDEF_VEC1]][%[[C0_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[C1_1:.*]] = llvm.mlir.constant(1 : i64) : i64
+//       CHECK:   %[[VEC1_1:.*]] = llvm.insertelement %[[ARG_3]], %[[VEC1_0]][%[[C1_1]] : i64] : vector<2xf32>
+//       CHECK:   %[[UNDEF_RES:.*]] = llvm.mlir.poison : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_0:.*]] = llvm.insertvalue %[[VEC0_1]], %[[UNDEF_RES]][0, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[RES_1:.*]] = llvm.insertvalue %[[VEC1_1]], %[[RES_0]][1, 0] : !llvm.array<2 x array<1 x vector<2xf32>>>
+//       CHECK:   %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[RES_1]] : !llvm.array<2 x array<1 x vector<2xf32>>> to vector<2x1x2xf32>
+//       CHECK:   return %[[CAST]]
+func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32>
+  return %0 : vector<2x1x2xf32>
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // vector.to_elements
 //===----------------------------------------------------------------------===//

Copy link

github-actions bot commented Jul 29, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just minor comments. Feel free to address them before landing. Thanks!

Comment on lines 1942 to 1951
// Use the same iteration approach as VectorBroadcastScalarToNdLowering to
// insert the 1D vectors into the aggregate.
auto vectorTypeInfo =
LLVM::detail::extractNDVectorTypeInfo(vectorType, *getTypeConverter());
if (!vectorTypeInfo.llvmNDVectorTy)
return failure();
int64_t vectorIdx = 0;
nDVectorIterate(vectorTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
result = LLVM::InsertValueOp::create(rewriter, loc, result,
innerVectors[vectorIdx++], position);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a change to refactor this code for both cases? This sounds like a common pattern that other ops might need as well...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. Other vector ops might also use this pattern. I just added a new overload to nDVectorIterate which accepts a VectorType and internally calls extractNDVectorTypeInfo. But I didn't change the usage in VectorBroadcastScalarToNdLowering, because it needs to do some things first that depend on extractNDVectorTypeInfo before it can execute nDVectorIterate.

yangtetris and others added 2 commits July 31, 2025 10:23
Co-authored-by: Nicolas Vasilache <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants