Skip to content

Commit ac2e407

Browse files
committed
[mlir][tblgen] add concrete create methods
1 parent 2dfcc43 commit ac2e407

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

mlir/include/mlir/TableGen/Class.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ class MethodParameter {
7171
StringRef getName() const { return name; }
7272
/// Returns true if the parameter has a default value.
7373
bool hasDefaultValue() const { return !defaultValue.empty(); }
74+
StringRef getDefaultValue() const { return defaultValue; }
75+
bool isOptional() const { return optional; }
7476

7577
private:
7678
/// The C++ type.

mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
230230
231231
)";
232232

233+
static const char *const inlineCreateBody = R"(
234+
OperationState __state__(loc, getOperationName());
235+
build(builder, __state__, {0});
236+
auto __res__ = dyn_cast<{1}>(builder.create(__state__));
237+
assert(__res__ && "builder didn't return the right type");
238+
return __res__;
239+
)";
240+
233241
//===----------------------------------------------------------------------===//
234242
// Utility structs and functions
235243
//===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
665673
// Generates the build() method that takes each operand/attribute
666674
// as a stand-alone parameter.
667675
void genSeparateArgParamBuilder();
676+
void genInlineCreateBody(const SmallVector<MethodParameter> &paramList);
668677

669678
// Generates the build() method that takes each operand/attribute as a
670679
// stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2566,36 @@ static bool canInferType(const Operator &op) {
25572566
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
25582567
}
25592568

2569+
void OpEmitter::genInlineCreateBody(
2570+
const SmallVector<MethodParameter> &paramList) {
2571+
SmallVector<MethodParameter> createParamList;
2572+
SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
2573+
createParamList.emplace_back("::mlir::OpBuilder &", "builder");
2574+
std::string locParamName = "loc";
2575+
while (llvm::find_if(paramList, [&locParamName](const MethodParameter &p) {
2576+
return p.getName() == locParamName;
2577+
})) {
2578+
locParamName += "_";
2579+
}
2580+
createParamList.emplace_back("::mlir::Location", locParamName);
2581+
2582+
for (auto &param : paramList) {
2583+
if (param.getType() == "::mlir::OpBuilder &" or
2584+
param.getType() == "::mlir::OperationState &")
2585+
continue;
2586+
createParamList.emplace_back(param.getType(), param.getName(),
2587+
param.getDefaultValue(), param.isOptional());
2588+
nonBuilderStateArgsList.push_back(param.getName());
2589+
}
2590+
auto *c = opClass.addStaticMethod(opClass.getClassName(), "create",
2591+
createParamList);
2592+
std::string nonBuilderStateArgs = "";
2593+
llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
2594+
interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
2595+
c->body() << llvm::formatv(inlineCreateBody, nonBuilderStateArgs,
2596+
opClass.getClassName());
2597+
}
2598+
25602599
void OpEmitter::genSeparateArgParamBuilder() {
25612600
SmallVector<AttrParamKind, 2> attrBuilderType;
25622601
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2573,10 +2612,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
25732612
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
25742613
attrType);
25752614

2576-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2615+
auto *m = opClass.addStaticMethod("void", "build", paramList);
25772616
// If the builder is redundant, skip generating the method.
25782617
if (!m)
25792618
return;
2619+
genInlineCreateBody(paramList);
2620+
25802621
auto &body = m->body();
25812622
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
25822623
/*isRawValueAttr=*/attrType ==
@@ -2701,10 +2742,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
27012742
if (op.getNumVariadicRegions())
27022743
paramList.emplace_back("unsigned", "numRegions");
27032744

2704-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2745+
auto *m = opClass.addStaticMethod("void", "build", paramList);
27052746
// If the builder is redundant, skip generating the method
27062747
if (!m)
27072748
return;
2749+
genInlineCreateBody(paramList);
27082750
auto &body = m->body();
27092751

27102752
// Operands
@@ -2815,10 +2857,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
28152857
if (op.getNumVariadicRegions())
28162858
paramList.emplace_back("unsigned", "numRegions");
28172859

2818-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2860+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28192861
// If the builder is redundant, skip generating the method
28202862
if (!m)
28212863
return;
2864+
genInlineCreateBody(paramList);
28222865
auto &body = m->body();
28232866

28242867
int numResults = op.getNumResults();
@@ -2895,10 +2938,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
28952938
buildParamList(paramList, inferredAttributes, resultNames,
28962939
TypeParamKind::None, attrType);
28972940

2898-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2941+
auto *m = opClass.addStaticMethod("void", "build", paramList);
28992942
// If the builder is redundant, skip generating the method
29002943
if (!m)
29012944
return;
2945+
genInlineCreateBody(paramList);
29022946
auto &body = m->body();
29032947
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
29042948
/*isRawValueAttr=*/attrType ==
@@ -2937,10 +2981,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
29372981
: "attributes";
29382982
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
29392983
attributesName, "{}");
2940-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
2984+
auto *m = opClass.addStaticMethod("void", "build", paramList);
29412985
// If the builder is redundant, skip generating the method
29422986
if (!m)
29432987
return;
2988+
genInlineCreateBody(paramList);
29442989

29452990
auto &body = m->body();
29462991

@@ -3103,10 +3148,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
31033148
if (op.getNumVariadicRegions())
31043149
paramList.emplace_back("unsigned", "numRegions");
31053150

3106-
auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
3151+
auto *m = opClass.addStaticMethod("void", "build", paramList);
31073152
// If the builder is redundant, skip generating the method
31083153
if (!m)
31093154
return;
3155+
genInlineCreateBody(paramList);
31103156
auto &body = m->body();
31113157

31123158
// Operands

0 commit comments

Comments
 (0)