@@ -44,6 +44,12 @@ static llvm::cl::opt<bool>
44
44
llvm::cl::desc (" Replace emission of malloc/free by alloca" ),
45
45
llvm::cl::init(false ));
46
46
47
+ static llvm::cl::opt<bool > clUseBarePtrCallConv (
48
+ PASS_NAME " -use-bare-ptr-memref-call-conv" ,
49
+ llvm::cl::desc (" Replace FuncOp's MemRef arguments with "
50
+ " bare pointers to the MemRef element types" ),
51
+ llvm::cl::init(false ));
52
+
47
53
LLVMTypeConverter::LLVMTypeConverter (MLIRContext *ctx)
48
54
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
49
55
assert (llvmDialect && " LLVM IR dialect is not registered" );
@@ -239,6 +245,60 @@ Type LLVMTypeConverter::convertStandardType(Type t) {
239
245
.Default ([](Type) { return Type (); });
240
246
}
241
247
248
+ // Converts function signature following LLVMTypeConverter approach but
249
+ // replacing the type of MemRef arguments with a bare LLVM pointer to
250
+ // the MemRef element type.
251
+ LLVM::LLVMType BarePtrTypeConverter::convertFunctionSignature (
252
+ FunctionType type, bool isVariadic,
253
+ LLVMTypeConverter::SignatureConversion &result) {
254
+ // Convert argument types one by one and check for errors.
255
+ for (auto &en : llvm::enumerate (type.getInputs ())) {
256
+ Type type = en.value ();
257
+ Type converted;
258
+ if (auto memrefTy = type.dyn_cast <MemRefType>())
259
+ converted = convertMemRefTypeToBarePtr (memrefTy)
260
+ .dyn_cast_or_null <LLVM::LLVMType>();
261
+ else
262
+ converted = convertType (type).dyn_cast_or_null <LLVM::LLVMType>();
263
+
264
+ if (!converted)
265
+ return {};
266
+ result.addInputs (en.index (), converted);
267
+ }
268
+
269
+ SmallVector<LLVM::LLVMType, 8 > argTypes;
270
+ argTypes.reserve (llvm::size (result.getConvertedTypes ()));
271
+ for (Type type : result.getConvertedTypes ())
272
+ argTypes.push_back (unwrap (type));
273
+
274
+ // If function does not return anything, create the void result type, if it
275
+ // returns on element, convert it, otherwise pack the result types into a
276
+ // struct.
277
+ LLVM::LLVMType resultType =
278
+ type.getNumResults () == 0
279
+ ? LLVM::LLVMType::getVoidTy (llvmDialect)
280
+ : unwrap (packFunctionResults (type.getResults ()));
281
+ if (!resultType)
282
+ return {};
283
+ return LLVM::LLVMType::getFunctionTy (resultType, argTypes, isVariadic);
284
+ }
285
+
286
+ // Converts MemRefType to a bare LLVM pointer to the MemRef element type.
287
+ Type BarePtrTypeConverter::convertMemRefTypeToBarePtr (MemRefType type) {
288
+ int64_t offset;
289
+ SmallVector<int64_t , 4 > strides;
290
+ bool strideSuccess = succeeded (getStridesAndOffset (type, strides, offset));
291
+ assert (strideSuccess &&
292
+ " Non-strided layout maps must have been normalized away" );
293
+ (void )strideSuccess;
294
+
295
+ LLVM::LLVMType elementType = unwrap (convertType (type.getElementType ()));
296
+ if (!elementType)
297
+ return {};
298
+ auto ptrTy = elementType.getPointerTo (type.getMemorySpace ());
299
+ return ptrTy;
300
+ }
301
+
242
302
LLVMOpLowering::LLVMOpLowering (StringRef rootOpName, MLIRContext *context,
243
303
LLVMTypeConverter &lowering_,
244
304
PatternBenefit benefit)
@@ -548,7 +608,84 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
548
608
for (unsigned idx : promotedArgIndices) {
549
609
BlockArgument arg = firstBlock->getArgument (idx);
550
610
Value loaded = rewriter.create <LLVM::LoadOp>(funcOp.getLoc (), arg);
551
- rewriter.replaceUsesOfBlockArgument (arg, loaded);
611
+ rewriter.replaceUsesOfWith (arg, loaded);
612
+ }
613
+ }
614
+
615
+ rewriter.eraseOp (op);
616
+ return matchSuccess ();
617
+ }
618
+ };
619
+
620
+ // FuncOp conversion that converts MemRef arguments to bare pointers to the type
621
+ // of the MemRef.
622
+ struct BarePtrFuncOpConversion : public LLVMLegalizationPattern <FuncOp> {
623
+ using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
624
+
625
+ PatternMatchResult
626
+ matchAndRewrite (Operation *op, ArrayRef<Value> operands,
627
+ ConversionPatternRewriter &rewriter) const override {
628
+ auto funcOp = cast<FuncOp>(op);
629
+ FunctionType type = funcOp.getType ();
630
+ auto funcLoc = funcOp.getLoc ();
631
+
632
+ // Store the positions of memref-typed arguments so that we can promote them
633
+ // to MemRef descriptor structs at the beginning of the function.
634
+ SmallVector<std::pair<unsigned , Type>, 4 > promotedArgIndices;
635
+ promotedArgIndices.reserve (type.getNumInputs ());
636
+ for (auto en : llvm::enumerate (type.getInputs ())) {
637
+ if (en.value ().isa <MemRefType>())
638
+ promotedArgIndices.push_back ({en.index (), en.value ()});
639
+ }
640
+
641
+ // Convert the original function arguments. MemRef types are lowered to bare
642
+ // pointers to the MemRef element type.
643
+ auto varargsAttr = funcOp.getAttrOfType <BoolAttr>(" std.varargs" );
644
+ TypeConverter::SignatureConversion result (funcOp.getNumArguments ());
645
+ auto llvmType = lowering.convertFunctionSignature (
646
+ funcOp.getType (), varargsAttr && varargsAttr.getValue (), result);
647
+
648
+ // Only retain those attributes that are not constructed by build.
649
+ SmallVector<NamedAttribute, 4 > attributes;
650
+ for (const auto &attr : funcOp.getAttrs ()) {
651
+ if (attr.first .is (SymbolTable::getSymbolAttrName ()) ||
652
+ attr.first .is (impl::getTypeAttrName ()) ||
653
+ attr.first .is (" std.varargs" ))
654
+ continue ;
655
+ attributes.push_back (attr);
656
+ }
657
+
658
+ // Create an LLVM function, use external linkage by default until MLIR
659
+ // functions have linkage.
660
+ auto newFuncOp =
661
+ rewriter.create <LLVM::LLVMFuncOp>(funcLoc, funcOp.getName (), llvmType,
662
+ LLVM::Linkage::External, attributes);
663
+ rewriter.inlineRegionBefore (funcOp.getBody (), newFuncOp.getBody (),
664
+ newFuncOp.end ());
665
+
666
+ // Tell the rewriter to convert the region signature.
667
+ rewriter.applySignatureConversion (&newFuncOp.getBody (), result);
668
+
669
+ // Promote bare pointers from MemRef arguments to a MemRef descriptor struct
670
+ // at the beginning of the function so that all the MemRefs in the function
671
+ // have a uniform representation.
672
+ if (!newFuncOp.getBody ().empty ()) {
673
+ Block *firstBlock = &newFuncOp.getBody ().front ();
674
+ rewriter.setInsertionPoint (firstBlock, firstBlock->begin ());
675
+ for (auto argIdxTypePair : promotedArgIndices) {
676
+ // Replace argument with a placeholder (undef), promote argument to a
677
+ // MemRef descriptor and replace placeholder with the last instruction
678
+ // of the MemRef descriptor. The placeholder is needed to avoid
679
+ // replacing argument uses in the MemRef descriptor instructions.
680
+ BlockArgument arg = firstBlock->getArgument (argIdxTypePair.first );
681
+ Value placeHolder =
682
+ rewriter.create <LLVM::UndefOp>(funcLoc, arg.getType ());
683
+ rewriter.replaceUsesOfWith (arg, placeHolder);
684
+ auto desc = MemRefDescriptor::fromStaticShape (
685
+ rewriter, funcLoc, lowering,
686
+ argIdxTypePair.second .cast <MemRefType>(), arg);
687
+ rewriter.replaceUsesOfWith (placeHolder, desc);
688
+ placeHolder.getDefiningOp ()->erase ();
552
689
}
553
690
}
554
691
@@ -2126,7 +2263,6 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
2126
2263
// clang-format off
2127
2264
patterns.insert <
2128
2265
DimOpLowering,
2129
- FuncOpConversion,
2130
2266
LoadOpLowering,
2131
2267
MemRefCastOpLowering,
2132
2268
StoreOpLowering,
@@ -2139,8 +2275,26 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
2139
2275
// clang-format on
2140
2276
}
2141
2277
2278
+ void mlir::populateStdToLLVMDefaultFuncOpConversionPattern (
2279
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2280
+ patterns.insert <FuncOpConversion>(*converter.getDialect (), converter);
2281
+ }
2282
+
2142
2283
void mlir::populateStdToLLVMConversionPatterns (
2143
2284
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2285
+ populateStdToLLVMDefaultFuncOpConversionPattern (converter, patterns);
2286
+ populateStdToLLVMNonMemoryConversionPatterns (converter, patterns);
2287
+ populateStdToLLVMMemoryConversionPatters (converter, patterns);
2288
+ }
2289
+
2290
+ void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern (
2291
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2292
+ patterns.insert <BarePtrFuncOpConversion>(*converter.getDialect (), converter);
2293
+ }
2294
+
2295
+ void mlir::populateStdToLLVMBarePtrConversionPatterns (
2296
+ LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2297
+ populateStdToLLVMBarePtrFuncOpConversionPattern (converter, patterns);
2144
2298
populateStdToLLVMNonMemoryConversionPatterns (converter, patterns);
2145
2299
populateStdToLLVMMemoryConversionPatters (converter, patterns);
2146
2300
}
@@ -2210,6 +2364,12 @@ makeStandardToLLVMTypeConverter(MLIRContext *context) {
2210
2364
return std::make_unique<LLVMTypeConverter>(context);
2211
2365
}
2212
2366
2367
+ // / Create an instance of BarePtrTypeConverter in the given context.
2368
+ static std::unique_ptr<LLVMTypeConverter>
2369
+ makeStandardToLLVMBarePtrTypeConverter (MLIRContext *context) {
2370
+ return std::make_unique<BarePtrTypeConverter>(context);
2371
+ }
2372
+
2213
2373
namespace {
2214
2374
// / A pass converting MLIR operations into the LLVM IR dialect.
2215
2375
struct LLVMLoweringPass : public ModulePass <LLVMLoweringPass> {
@@ -2274,6 +2434,9 @@ static PassRegistration<LLVMLoweringPass>
2274
2434
" Standard to the LLVM dialect" ,
2275
2435
[] {
2276
2436
return std::make_unique<LLVMLoweringPass>(
2277
- clUseAlloca.getValue (), populateStdToLLVMConversionPatterns,
2278
- makeStandardToLLVMTypeConverter);
2437
+ clUseAlloca.getValue (),
2438
+ clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns
2439
+ : populateStdToLLVMConversionPatterns,
2440
+ clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter
2441
+ : makeStandardToLLVMTypeConverter);
2279
2442
});
0 commit comments