Skip to content

Commit 19bad95

Browse files
committed
[mlir] Introduce bare ptr calling convention for MemRefs in LLVM dialect
Summary: This patch introduces an alternative calling convention for MemRef function arguments in LLVM dialect. It converts MemRef function arguments to LLVM bare pointers to the MemRef element type instead of creating a MemRef descriptor. Bare pointers are then promoted to a MemRef descriptors at the beginning of the function. This calling convention is only enabled with a flag. This is a stepping stone towards having an alternative and simpler lowering for MemRefs when dynamic shapes are not needed. It can also be used to temporarily overcome the issue with passing 'noalias' attribute for MemRef arguments, discussed in [1, 2], since we can now convert: func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) { return } into: llvm.func @check_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) { %0 = llvm.mlir.undef ... %1 = llvm.insertvalue %arg0, %0[0] ... %2 = llvm.insertvalue %arg0, %1[1] ... ... llvm.return } Related discussion: [1] tensorflow/mlir#309 [2] tensorflow/mlir#337 WIP: I plan to move all the tests with only static shapes from convert-memref-ops.mlir to an independent file so that we can also have coverage for those tests with this alternative calling convention. Reviewers: ftynse, bondhugula, nicolasvasilache Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, aartbik, liufengdb, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D72802
1 parent d629525 commit 19bad95

File tree

7 files changed

+221
-15
lines changed

7 files changed

+221
-15
lines changed

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

+23-4
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,9 @@ class LLVMTypeConverter : public TypeConverter {
4747
/// Convert a function type. The arguments and results are converted one by
4848
/// one and results are packed into a wrapped LLVM IR structure type. `result`
4949
/// is populated with argument mapping.
50-
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
51-
SignatureConversion &result);
50+
virtual LLVM::LLVMType convertFunctionSignature(FunctionType type,
51+
bool isVariadic,
52+
SignatureConversion &result);
5253

5354
/// Convert a non-empty list of types to be returned from a function into a
5455
/// supported LLVM IR type. In particular, if more than one values is
@@ -81,6 +82,9 @@ class LLVMTypeConverter : public TypeConverter {
8182
llvm::Module *module;
8283
LLVM::LLVMDialect *llvmDialect;
8384

85+
// Extract an LLVM IR dialect type.
86+
LLVM::LLVMType unwrap(Type type);
87+
8488
private:
8589
Type convertStandardType(Type type);
8690

@@ -120,9 +124,24 @@ class LLVMTypeConverter : public TypeConverter {
120124
// Get the LLVM representation of the index type based on the bitwidth of the
121125
// pointer as defined by the data layout of the module.
122126
LLVM::LLVMType getIndexType();
127+
};
123128

124-
// Extract an LLVM IR dialect type.
125-
LLVM::LLVMType unwrap(Type type);
129+
/// Custom LLVMTypeConverter that overrides `convertFunctionSignature` to
130+
/// replace the type of MemRef function arguments with bare pointer to the
131+
/// MemRef element type.
132+
class BarePtrTypeConverter : public mlir::LLVMTypeConverter {
133+
public:
134+
using LLVMTypeConverter::LLVMTypeConverter;
135+
136+
/// Converts function signature following LLVMTypeConverter approach but
137+
/// replacing the type of MemRef arguments with a bare LLVM pointer to
138+
/// the MemRef element type.
139+
mlir::LLVM::LLVMType convertFunctionSignature(
140+
mlir::FunctionType type, bool isVariadic,
141+
mlir::LLVMTypeConverter::SignatureConversion &result) override;
142+
143+
private:
144+
mlir::Type convertMemRefTypeToBarePtr(mlir::MemRefType type);
126145
};
127146

128147
/// Helper class to produce LLVM dialect operations extracting or inserting

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h

+19-3
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ using LLVMTypeConverterMaker =
4444
std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;
4545

4646
/// Collect a set of patterns to convert memory-related operations from the
47-
/// Standard dialect to the LLVM dialect, excluding the memory-related
48-
/// operations.
47+
/// Standard dialect to the LLVM dialect, excluding non-memory-related
48+
/// operations and FuncOp.
4949
void populateStdToLLVMMemoryConversionPatters(
5050
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
5151

@@ -54,10 +54,26 @@ void populateStdToLLVMMemoryConversionPatters(
5454
void populateStdToLLVMNonMemoryConversionPatterns(
5555
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
5656

57-
/// Collect a set of patterns to convert from the Standard dialect to LLVM.
57+
/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
58+
void populateStdToLLVMDefaultFuncOpConversionPattern(
59+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
60+
61+
/// Collect a set of default patterns to convert from the Standard dialect to
62+
/// LLVM.
5863
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
5964
OwningRewritePatternList &patterns);
6065

66+
/// Collect the pattern to convert a FuncOp to the LLVM dialect using the bare
67+
/// pointer calling convertion for MemRef function arguments.
68+
void populateStdToLLVMBarePtrFuncOpConversionPattern(
69+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
70+
71+
/// Collect a set of patterns to convert from the Standard dialect to
72+
/// LLVM using the bare pointer calling convention for MemRef function
73+
/// arguments.
74+
void populateStdToLLVMBarePtrConversionPatterns(
75+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
76+
6177
/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
6278
/// By default stdlib malloc/free are used for allocating MemRef payloads.
6379
/// Specifying `useAlloca-true` emits stack allocations instead. In the future

mlir/include/mlir/Transforms/DialectConversion.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
321321
TypeConverter::SignatureConversion &conversion);
322322

323323
/// Replace all the uses of the block argument `from` with value `to`.
324-
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
324+
void replaceUsesOfWith(Value from, Value to);
325325

326326
/// Return the converted value that replaces 'key'. Return 'key' if there is
327327
/// no such a converted value.

mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -667,7 +667,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {
667667

668668
BlockArgument arg = block.getArgument(en.index());
669669
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
670-
rewriter.replaceUsesOfBlockArgument(arg, loaded);
670+
rewriter.replaceUsesOfWith(arg, loaded);
671671
}
672672
}
673673

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

+167-4
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ static llvm::cl::opt<bool>
4444
llvm::cl::desc("Replace emission of malloc/free by alloca"),
4545
llvm::cl::init(false));
4646

47+
static llvm::cl::opt<bool> clUseBarePtrCallConv(
48+
PASS_NAME "-use-bare-ptr-memref-call-conv",
49+
llvm::cl::desc("Replace FuncOp's MemRef arguments with "
50+
"bare pointers to the MemRef element types"),
51+
llvm::cl::init(false));
52+
4753
LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
4854
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
4955
assert(llvmDialect && "LLVM IR dialect is not registered");
@@ -239,6 +245,60 @@ Type LLVMTypeConverter::convertStandardType(Type t) {
239245
.Default([](Type) { return Type(); });
240246
}
241247

248+
// Converts function signature following LLVMTypeConverter approach but
249+
// replacing the type of MemRef arguments with a bare LLVM pointer to
250+
// the MemRef element type.
251+
LLVM::LLVMType BarePtrTypeConverter::convertFunctionSignature(
252+
FunctionType type, bool isVariadic,
253+
LLVMTypeConverter::SignatureConversion &result) {
254+
// Convert argument types one by one and check for errors.
255+
for (auto &en : llvm::enumerate(type.getInputs())) {
256+
Type type = en.value();
257+
Type converted;
258+
if (auto memrefTy = type.dyn_cast<MemRefType>())
259+
converted = convertMemRefTypeToBarePtr(memrefTy)
260+
.dyn_cast_or_null<LLVM::LLVMType>();
261+
else
262+
converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();
263+
264+
if (!converted)
265+
return {};
266+
result.addInputs(en.index(), converted);
267+
}
268+
269+
SmallVector<LLVM::LLVMType, 8> argTypes;
270+
argTypes.reserve(llvm::size(result.getConvertedTypes()));
271+
for (Type type : result.getConvertedTypes())
272+
argTypes.push_back(unwrap(type));
273+
274+
// If function does not return anything, create the void result type, if it
275+
// returns on element, convert it, otherwise pack the result types into a
276+
// struct.
277+
LLVM::LLVMType resultType =
278+
type.getNumResults() == 0
279+
? LLVM::LLVMType::getVoidTy(llvmDialect)
280+
: unwrap(packFunctionResults(type.getResults()));
281+
if (!resultType)
282+
return {};
283+
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
284+
}
285+
286+
// Converts MemRefType to a bare LLVM pointer to the MemRef element type.
287+
Type BarePtrTypeConverter::convertMemRefTypeToBarePtr(MemRefType type) {
288+
int64_t offset;
289+
SmallVector<int64_t, 4> strides;
290+
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
291+
assert(strideSuccess &&
292+
"Non-strided layout maps must have been normalized away");
293+
(void)strideSuccess;
294+
295+
LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
296+
if (!elementType)
297+
return {};
298+
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
299+
return ptrTy;
300+
}
301+
242302
LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
243303
LLVMTypeConverter &lowering_,
244304
PatternBenefit benefit)
@@ -548,7 +608,84 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
548608
for (unsigned idx : promotedArgIndices) {
549609
BlockArgument arg = firstBlock->getArgument(idx);
550610
Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
551-
rewriter.replaceUsesOfBlockArgument(arg, loaded);
611+
rewriter.replaceUsesOfWith(arg, loaded);
612+
}
613+
}
614+
615+
rewriter.eraseOp(op);
616+
return matchSuccess();
617+
}
618+
};
619+
620+
// FuncOp conversion that converts MemRef arguments to bare pointers to the type
621+
// of the MemRef.
622+
struct BarePtrFuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
623+
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;
624+
625+
PatternMatchResult
626+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
627+
ConversionPatternRewriter &rewriter) const override {
628+
auto funcOp = cast<FuncOp>(op);
629+
FunctionType type = funcOp.getType();
630+
auto funcLoc = funcOp.getLoc();
631+
632+
// Store the positions of memref-typed arguments so that we can promote them
633+
// to MemRef descriptor structs at the beginning of the function.
634+
SmallVector<std::pair<unsigned, Type>, 4> promotedArgIndices;
635+
promotedArgIndices.reserve(type.getNumInputs());
636+
for (auto en : llvm::enumerate(type.getInputs())) {
637+
if (en.value().isa<MemRefType>())
638+
promotedArgIndices.push_back({en.index(), en.value()});
639+
}
640+
641+
// Convert the original function arguments. MemRef types are lowered to bare
642+
// pointers to the MemRef element type.
643+
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
644+
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
645+
auto llvmType = lowering.convertFunctionSignature(
646+
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);
647+
648+
// Only retain those attributes that are not constructed by build.
649+
SmallVector<NamedAttribute, 4> attributes;
650+
for (const auto &attr : funcOp.getAttrs()) {
651+
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
652+
attr.first.is(impl::getTypeAttrName()) ||
653+
attr.first.is("std.varargs"))
654+
continue;
655+
attributes.push_back(attr);
656+
}
657+
658+
// Create an LLVM function, use external linkage by default until MLIR
659+
// functions have linkage.
660+
auto newFuncOp =
661+
rewriter.create<LLVM::LLVMFuncOp>(funcLoc, funcOp.getName(), llvmType,
662+
LLVM::Linkage::External, attributes);
663+
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
664+
newFuncOp.end());
665+
666+
// Tell the rewriter to convert the region signature.
667+
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);
668+
669+
// Promote bare pointers from MemRef arguments to a MemRef descriptor struct
670+
// at the beginning of the function so that all the MemRefs in the function
671+
// have a uniform representation.
672+
if (!newFuncOp.getBody().empty()) {
673+
Block *firstBlock = &newFuncOp.getBody().front();
674+
rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
675+
for (auto argIdxTypePair : promotedArgIndices) {
676+
// Replace argument with a placeholder (undef), promote argument to a
677+
// MemRef descriptor and replace placeholder with the last instruction
678+
// of the MemRef descriptor. The placeholder is needed to avoid
679+
// replacing argument uses in the MemRef descriptor instructions.
680+
BlockArgument arg = firstBlock->getArgument(argIdxTypePair.first);
681+
Value placeHolder =
682+
rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
683+
rewriter.replaceUsesOfWith(arg, placeHolder);
684+
auto desc = MemRefDescriptor::fromStaticShape(
685+
rewriter, funcLoc, lowering,
686+
argIdxTypePair.second.cast<MemRefType>(), arg);
687+
rewriter.replaceUsesOfWith(placeHolder, desc);
688+
placeHolder.getDefiningOp()->erase();
552689
}
553690
}
554691

@@ -2126,7 +2263,6 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
21262263
// clang-format off
21272264
patterns.insert<
21282265
DimOpLowering,
2129-
FuncOpConversion,
21302266
LoadOpLowering,
21312267
MemRefCastOpLowering,
21322268
StoreOpLowering,
@@ -2139,8 +2275,26 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
21392275
// clang-format on
21402276
}
21412277

2278+
void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
2279+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2280+
patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
2281+
}
2282+
21422283
void mlir::populateStdToLLVMConversionPatterns(
21432284
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2285+
populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
2286+
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
2287+
populateStdToLLVMMemoryConversionPatters(converter, patterns);
2288+
}
2289+
2290+
void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern(
2291+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2292+
patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
2293+
}
2294+
2295+
void mlir::populateStdToLLVMBarePtrConversionPatterns(
2296+
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
2297+
populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
21442298
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
21452299
populateStdToLLVMMemoryConversionPatters(converter, patterns);
21462300
}
@@ -2210,6 +2364,12 @@ makeStandardToLLVMTypeConverter(MLIRContext *context) {
22102364
return std::make_unique<LLVMTypeConverter>(context);
22112365
}
22122366

2367+
/// Create an instance of BarePtrTypeConverter in the given context.
2368+
static std::unique_ptr<LLVMTypeConverter>
2369+
makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) {
2370+
return std::make_unique<BarePtrTypeConverter>(context);
2371+
}
2372+
22132373
namespace {
22142374
/// A pass converting MLIR operations into the LLVM IR dialect.
22152375
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
@@ -2274,6 +2434,9 @@ static PassRegistration<LLVMLoweringPass>
22742434
"Standard to the LLVM dialect",
22752435
[] {
22762436
return std::make_unique<LLVMLoweringPass>(
2277-
clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
2278-
makeStandardToLLVMTypeConverter);
2437+
clUseAlloca.getValue(),
2438+
clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns
2439+
: populateStdToLLVMConversionPatterns,
2440+
clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter
2441+
: makeStandardToLLVMTypeConverter);
22792442
});

mlir/lib/Transforms/DialectConversion.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -861,8 +861,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
861861
return impl->applySignatureConversion(region, conversion);
862862
}
863863

864-
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
865-
Value to) {
864+
void ConversionPatternRewriter::replaceUsesOfWith(Value from, Value to) {
866865
for (auto &u : from.getUses()) {
867866
if (u.getOwner() == to.getDefiningOp())
868867
continue;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR
2+
3+
// BAREPTR-LABEL: func @check_noalias
4+
// BAREPTR-SAME: [[ARG:%.*]]: !llvm<"float*"> {llvm.noalias = true}
5+
func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
6+
return
7+
}
8+
9+
// WIP: Move tests with static shapes from convert-memref-ops.mlir here.

0 commit comments

Comments
 (0)