From b8051b5ee23fcfd4cb7e7d245a89d76dffdc880b Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Mon, 22 Jul 2024 10:35:39 -0700 Subject: [PATCH] Fix function signature for constraints with return values --- lib/Transform/Arith/MulToAddPdll.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/Transform/Arith/MulToAddPdll.cpp b/lib/Transform/Arith/MulToAddPdll.cpp index 36cc48a..f76c834 100644 --- a/lib/Transform/Arith/MulToAddPdll.cpp +++ b/lib/Transform/Arith/MulToAddPdll.cpp @@ -1,5 +1,6 @@ #include "lib/Transform/Arith/MulToAddPdll.h" #include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/include/mlir/Pass/Pass.h" @@ -10,15 +11,18 @@ namespace tutorial { #define GEN_PASS_DEF_MULTOADDPDLL #include "lib/Transform/Arith/Passes.h.inc" -Attribute halveImpl(PatternRewriter &rewriter, Attribute attr) { - IntegerAttr cAttr = ::llvm::cast<::mlir::IntegerAttr>(attr); + +LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results, + ArrayRef args) { + Attribute attr = args[0].cast(); + IntegerAttr cAttr = cast(attr); int64_t value = cAttr.getValue().getSExtValue(); - return rewriter.getIntegerAttr(cAttr.getType(), value / 2); + results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2)); + return success(); } void registerNativeConstraints(RewritePatternSet &patterns) { - patterns.getPDLPatterns().registerConstraintFunction( - "Halve", halveImpl); + patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl); } struct MulToAddPdll : impl::MulToAddPdllBase {