diff --git a/include/vast/Dialect/Parser/Ops.td b/include/vast/Dialect/Parser/Ops.td index 30b47cee05..65bb3dfdbb 100644 --- a/include/vast/Dialect/Parser/Ops.td +++ b/include/vast/Dialect/Parser/Ops.td @@ -19,6 +19,8 @@ def Parser_Source { let summary = "Source of parsed data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -31,6 +33,8 @@ def Paser_Sink { let summary = "Sink of parsed data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -43,6 +47,8 @@ def Parser_Parse { let summary = "Parsing operation data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -69,6 +75,8 @@ def Parse_MaybeParse { let summary = "Maybe parsing operation data."; + let hasFolder = 1; + let assemblyFormat = [{ $arguments attr-dict `:` functional-type($arguments, $result) }]; @@ -76,13 +84,15 @@ def Parse_MaybeParse def Parse_Cast : Parser_Op< "cast" > - , Arguments< (ins Parser_AnyDataType:$arguments) > + , Arguments< (ins Parser_AnyDataType:$operand) > , Results< (outs Parser_AnyDataType:$result) > { let summary = "Casting operation."; + let hasFolder = 1; + let assemblyFormat = [{ - $arguments attr-dict `:` functional-type($arguments, $result) + $operand attr-dict `:` functional-type($operand, $result) }]; } diff --git a/lib/vast/Dialect/Parser/Ops.cpp b/lib/vast/Dialect/Parser/Ops.cpp index f9b5ef7639..3d8d312458 100644 --- a/lib/vast/Dialect/Parser/Ops.cpp +++ b/lib/vast/Dialect/Parser/Ops.cpp @@ -15,20 +15,53 @@ using namespace vast::pr; namespace vast::pr { - using fold_result_t = ::llvm::SmallVectorImpl< ::mlir::OpFoldResult >; + using fold_result = ::mlir::OpFoldResult; + using fold_results = ::llvm::SmallVectorImpl< fold_result >; - logical_result NoParse::fold(FoldAdaptor adaptor, fold_result_t &results) { - auto change = mlir::failure(); - auto op = getOperation(); + template< typename op_t > + logical_result forward_same_operation( + op_t op, auto adaptor, fold_results &results + ) { + if (op.getNumOperands() == 1 && op.getNumResults() == 1) { + if (auto operand = op.getOperand(0); mlir::isa< op_t >(operand.getDefiningOp())) { + if (operand.getType() == op->getOpResult(0).getType()) { + results.push_back(operand); + return mlir::success(); + } + } + } + + return mlir::failure(); + } + + logical_result Source::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result Sink::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result Parse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result NoParse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } + + logical_result MaybeParse::fold(FoldAdaptor adaptor, fold_results &results) { + return forward_same_operation(*this, adaptor, results); + } - for (auto [idx, operand] : llvm::reverse(llvm::enumerate(getOperands()))) { - if (auto noparse = mlir::dyn_cast< NoParse >(operand.getDefiningOp())) { - op->eraseOperand(idx); - change = mlir::success(); + fold_result Cast::fold(FoldAdaptor adaptor) { + if (auto operand = getOperand(); mlir::isa< Cast >(operand.getDefiningOp())) { + if (operand.getType() == getType()) { + return operand; } } - return change; + return {}; } } // namespace vast::pr