2121#include " llvm/ADT/StringMap.h"
2222#include " llvm/ADT/StringRef.h"
2323#include " llvm/ADT/Twine.h"
24+ #include " llvm/ADT/TypeSwitch.h"
2425#include " llvm/Support/ErrorHandling.h"
2526#include " llvm/Support/Path.h"
2627
@@ -42,24 +43,66 @@ template <typename TargetOp> class StdRecognizer {
4243 template <size_t ... Indices>
4344 static TargetOp buildCall (CIRBaseBuilderTy &builder, CallOp call,
4445 std::index_sequence<Indices...>) {
45- return builder.create <TargetOp>(call.getLoc (), call.getResult ().getType (),
46- call.getCalleeAttr (),
47- call.getOperand (Indices)...);
46+ return builder.create <TargetOp>(
47+ call.getLoc (),
48+ (call.getResult () ? call.getResult ().getType () : mlir::TypeRange{}),
49+ call.getCalleeAttr (), call.getOperand (Indices)...);
4850 }
4951
5052public:
51- static bool raise (CallOp call, mlir::MLIRContext &context, bool remark) {
53+ static FuncOp getCalleeFromSymbol (mlir::ModuleOp theModule,
54+ llvm::StringRef name) {
55+ auto global = mlir::SymbolTable::lookupSymbolIn (theModule, name);
56+ assert (global && " expected to find symbol for function" );
57+ return dyn_cast<FuncOp>(global);
58+ }
59+
60+ static std::optional<StringRef>
61+ getRecordName (const clang::CXXRecordDecl *rd) {
62+ if (!rd || !rd->getDeclContext ()->isStdNamespace ())
63+ return std::nullopt ;
64+
65+ if (rd->getDeclName ().isIdentifier ())
66+ return rd->getName ();
67+
68+ return std::nullopt ;
69+ }
70+
71+ static std::optional<std::string>
72+ resolveSpecialMember (mlir::Attribute specialMember) {
73+ return TypeSwitch<Attribute, std::optional<std::string>>(specialMember)
74+ .Case <CXXCtorAttr, CXXDtorAttr>(
75+ [](auto attr) -> std::optional<std::string> {
76+ if (!attr.getRecordDecl ())
77+ return std::nullopt ;
78+ if (auto recordName = getRecordName (*attr.getRecordDecl ()))
79+ return recordName->str () + " _" + attr.getMnemonic ().str ();
80+ return std::nullopt ;
81+ })
82+ .Default ([](Attribute) { return std::nullopt ; });
83+ }
84+
85+ static bool raise (mlir::ModuleOp theModule, CallOp call,
86+ mlir::MLIRContext &context, bool remark) {
5287 constexpr int numArgs = TargetOp::getNumArgs ();
5388 if (call.getNumOperands () != numArgs)
5489 return false ;
5590
56- auto callExprAttr = call.getAstAttr ();
5791 llvm::StringRef stdFuncName = TargetOp::getFunctionName ();
58- if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
59- return false ;
60-
61- if (!checkArguments (call.getArgOperands ()))
62- return false ;
92+ auto calleeFunc = getCalleeFromSymbol (theModule, *call.getCallee ());
93+
94+ if (auto specialMember = calleeFunc.getCxxSpecialMemberAttr ()) {
95+ auto resolved = resolveSpecialMember (specialMember);
96+ if (!resolved || *resolved != stdFuncName.str ())
97+ return false ;
98+ } else {
99+ auto callExprAttr = call.getAstAttr ();
100+ if (!callExprAttr || !callExprAttr.isStdFunctionCall (stdFuncName))
101+ return false ;
102+
103+ if (!checkArguments (call.getArgOperands ()))
104+ return false ;
105+ }
63106
64107 if (remark)
65108 mlir::emitRemark (call.getLoc ())
@@ -194,12 +237,16 @@ void IdiomRecognizerPass::recognizeCall(CallOp call) {
194237
195238 bool remark = opts.emitRemarkFoundCalls ();
196239
197- using StdFunctionsRecognizer = std::tuple<StdRecognizer<StdFindOp>>;
240+ using StdFunctionsRecognizer =
241+ std::tuple<StdRecognizer<StdFindOp>, StdRecognizer<StdVectorCtorOp>,
242+ StdRecognizer<StdVectorDtorOp>>;
198243
199244 // MSVC requires explicitly capturing these variables.
200245 std::apply (
201246 [&, call, remark, this ](auto ... recognizers) {
202- (decltype (recognizers)::raise (call, this ->getContext (), remark) || ...);
247+ (decltype (recognizers)::raise (theModule, call, this ->getContext (),
248+ remark) ||
249+ ...);
203250 },
204251 StdFunctionsRecognizer ());
205252}
0 commit comments