-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][tblgen] add concrete create methods #147168
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[mlir][tblgen] add concrete create methods #147168
Conversation
ac2e407
to
5393a8c
Compare
5393a8c
to
2902d5a
Compare
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-core Author: Maksim Levental (makslevental) ChangesCurrently To improve QoL, this PR adds static create methods to the ops themselves like static arith::ConstantIntOp create(OpBuilder& builder, Location location, int64_t value, unsigned width); Now if one types See https://discourse.llvm.org/t/rfc-building-mlir-operation-observed-caveats-and-proposed-solution/87204/13 for more info. Full diff: https://github.com/llvm/llvm-project/pull/147168.diff 2 Files Affected:
diff --git a/mlir/include/mlir/TableGen/Class.h b/mlir/include/mlir/TableGen/Class.h
index f750a34a3b2ba..69cefbbc43e0a 100644
--- a/mlir/include/mlir/TableGen/Class.h
+++ b/mlir/include/mlir/TableGen/Class.h
@@ -71,6 +71,8 @@ class MethodParameter {
StringRef getName() const { return name; }
/// Returns true if the parameter has a default value.
bool hasDefaultValue() const { return !defaultValue.empty(); }
+ StringRef getDefaultValue() const { return defaultValue; }
+ bool isOptional() const { return optional; }
private:
/// The C++ type.
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 6008ed4673d1b..65094dcaeb6d8 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -230,6 +230,14 @@ static const char *const opCommentHeader = R"(
)";
+static const char *const inlineCreateBody = R"(
+ OperationState __state__({0}, getOperationName());
+ build(builder, __state__, {1});
+ auto __res__ = dyn_cast<{2}>(builder.create(__state__));
+ assert(__res__ && "builder didn't return the right type");
+ return __res__;
+)";
+
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
@@ -665,6 +673,7 @@ class OpEmitter {
// Generates the build() method that takes each operand/attribute
// as a stand-alone parameter.
void genSeparateArgParamBuilder();
+ void genInlineCreateBody(const SmallVector<MethodParameter> ¶mList);
// Generates the build() method that takes each operand/attribute as a
// stand-alone parameter. The generated build() method uses first operand's
@@ -2557,6 +2566,36 @@ static bool canInferType(const Operator &op) {
return op.getTrait("::mlir::InferTypeOpInterface::Trait");
}
+void OpEmitter::genInlineCreateBody(
+ const SmallVector<MethodParameter> ¶mList) {
+ SmallVector<MethodParameter> createParamList;
+ SmallVector<llvm::StringRef, 4> nonBuilderStateArgsList;
+ createParamList.emplace_back("::mlir::OpBuilder &", "builder");
+ std::string locParamName = "loc";
+ while (llvm::find_if(paramList, [&locParamName](const MethodParameter &p) {
+ return p.getName().str() == locParamName;
+ })) {
+ locParamName += "_";
+ }
+ createParamList.emplace_back("::mlir::Location", locParamName);
+
+ for (auto ¶m : paramList) {
+ if (param.getType() == "::mlir::OpBuilder &" ||
+ param.getType() == "::mlir::OperationState &")
+ continue;
+ createParamList.emplace_back(param.getType(), param.getName(),
+ param.getDefaultValue(), param.isOptional());
+ nonBuilderStateArgsList.push_back(param.getName());
+ }
+ auto *c = opClass.addStaticMethod(opClass.getClassName(), "create",
+ createParamList);
+ std::string nonBuilderStateArgs = "";
+ llvm::raw_string_ostream nonBuilderStateArgsOS(nonBuilderStateArgs);
+ interleaveComma(nonBuilderStateArgsList, nonBuilderStateArgsOS);
+ c->body() << llvm::formatv(inlineCreateBody, locParamName,
+ nonBuilderStateArgs, opClass.getClassName());
+}
+
void OpEmitter::genSeparateArgParamBuilder() {
SmallVector<AttrParamKind, 2> attrBuilderType;
attrBuilderType.push_back(AttrParamKind::WrappedAttr);
@@ -2573,10 +2612,12 @@ void OpEmitter::genSeparateArgParamBuilder() {
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
attrType);
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method.
if (!m)
return;
+ genInlineCreateBody(paramList);
+
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
/*isRawValueAttr=*/attrType ==
@@ -2701,10 +2742,11 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder(
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method
if (!m)
return;
+ genInlineCreateBody(paramList);
auto &body = m->body();
// Operands
@@ -2815,10 +2857,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder(
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method
if (!m)
return;
+ genInlineCreateBody(paramList);
auto &body = m->body();
int numResults = op.getNumResults();
@@ -2895,10 +2938,11 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
buildParamList(paramList, inferredAttributes, resultNames,
TypeParamKind::None, attrType);
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method
if (!m)
return;
+ genInlineCreateBody(paramList);
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
/*isRawValueAttr=*/attrType ==
@@ -2937,10 +2981,11 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(
: "attributes";
paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>",
attributesName, "{}");
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method
if (!m)
return;
+ genInlineCreateBody(paramList);
auto &body = m->body();
@@ -3103,10 +3148,11 @@ void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) {
if (op.getNumVariadicRegions())
paramList.emplace_back("unsigned", "numRegions");
- auto *m = opClass.addStaticMethod("void", "build", std::move(paramList));
+ auto *m = opClass.addStaticMethod("void", "build", paramList);
// If the builder is redundant, skip generating the method
if (!m)
return;
+ genInlineCreateBody(paramList);
auto &body = m->body();
// Operands
|
ae157aa
to
c99b3c0
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems like a reasonable incremental change on the current state (not too far from the current syntax and not more verbose, almost same number of characters), preserving compatibility with existing builders (important since they can be handwritten today).
It does not close the door to moving to another pattern (like the fluent API that is explored, or something else).
It's a bit annoying that people would wonder why they would see in the codebase either builder.create<Op>(...)
or Op::create(builder, ...)
. We're losing consistency and I would rather not. I'm not sure if we should have a coding guideline: plan to migrate to this new pattern and ultimately deprecating builder.create<Op<(...)
(at least upstream)?
Let me ask again the others we previously talked to about adding support SemaCodeComplete side. The cost of adding support may be better trade off than adding additional and migration (I mean the feature clangd side is more useful in general, I have no idea about the cost ... but will ask). |
I can send an NFC PR after this one to update all the upstream uses (a regex will be sufficient).
|
We would have to deprecate the other form completely everywhere, that's a cost to all users. It can be invoked from template classes/patterns, so a plain regex wouldn't suffice. So prudent to consider alternatives and their timelines. I'm not proposing to solve this for all possible ways to forward argument templated methods, just those here and we have a fairly standard forwarding. clangd would be main one as its part of this larger project indeed - impossible to cater to all (else we'd probably be using C++11 here still). |
Is deprecating the current usage necessary? |
Upstream, I would say yes: consistency is important in the codebase IMO: having to completely equivalent way of writing it is hurting readability and accessibility to the codebase. |
c99b3c0
to
0317629
Compare
Here is what the update would look like #147311. It's a lot of files so it would probably be better to chop by affected dir. Note currently it's missing instances like
which are "custom" but these are easily fixed since if they're actually being used now the corresponding |
Oh so I was wrong when I wrote before that:
I forgot the extra builder leading argument here. How do you write a manual builder now? Ideally we wouldn't add complexity in the body of the builder for users. |
If you look at before and after
it's actually the exact same number of chars 😄 (it actually looks like a savings but it's not - it's just that
I don't think there's really any issue here - we won't have to change any existing impls. Handling these "custom" builders just means adding // Arith.h
static void build(OpBuilder &builder, OperationState &result, int64_t value);
+ static arith::ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value);
// ArithOps.cpp
void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, int64_t value) {
...
}
+ arith::ConstantIndexOp::create(OpBuilder &builder, Location location, int64_t value) {
+ mlir::OperationState state(location, getOperationName());
+ build(builder, state, value);
+ auto result = llvm::dyn_cast<ConstantIndexOp>(builder.create(state));
+ assert(result && "builder didn't return the right type");
+ return result;
+ } A little tedious but I don't think there are that many (and they're easily found like I pointed out). |
I didn't mean the number of char: I meant about the compatibility aspects. I initially thought your scheme would be...
Well it's not backward compatible, whether it's an issue or not is TDB :) But thinking a bit more: can't you automatically generate from TablGen the diff you just showed from the custom builder declaration in ODS ( |
Currently
builder.create<...>
does not in any meaningful way hint/show the various builders an op supports (arg names/types) becausecreate
forwards the args tobuild
.To improve QoL, this PR adds static create methods to the ops themselves like
Now if one types
arith::ConstantIntO::create(builder,...
instead ofbuilder.create<arith::ConstantIntO>(...
auto-complete/hints will pop up.See https://discourse.llvm.org/t/rfc-building-mlir-operation-observed-caveats-and-proposed-solution/87204/13 for more info.
TODO: this still needs tests.