1111#include " llvm/include/llvm/Support/Debug.h" // from @llvm-project
1212#include " mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
1313#include " mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
14- #include " mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
1514#include " mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
1615#include " mlir/include/mlir/Dialect/Math/IR/Math.h" // from @llvm-project
1716#include " mlir/include/mlir/IR/Attributes.h" // from @llvm-project
@@ -31,8 +30,7 @@ namespace heir {
3130LogicalResult SchemeInfoAnalysis::visitOperation (
3231 Operation *op, ArrayRef<const SchemeInfoLattice *> operands,
3332 ArrayRef<SchemeInfoLattice *> results) {
34- LLVM_DEBUG (llvm::dbgs ()
35- << " Visiting: " << op->getName () << " . " );
33+ LLVM_DEBUG (llvm::dbgs () << " Visiting: " << op->getName () << " . " );
3634
3735 auto propagate = [&](Value value, const NatureOfComputation &counter) {
3836 auto *oldNoc = getLatticeElement (value);
@@ -44,57 +42,65 @@ LogicalResult SchemeInfoAnalysis::visitOperation(
4442 // count integer arithmetic ops
4543 .Case <arith::AddIOp, arith::SubIOp, arith::MulIOp>([&](auto intOp) {
4644 auto newNoc = NatureOfComputation (0 , 0 , 1 , 0 , 0 , 0 );
47- intOp->setAttr (numIntArithOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getIntArithOpsCount ()));
48- LLVM_DEBUG (llvm::dbgs ()
49- << " Counting: " << newNoc << " \n " );
45+ intOp->setAttr (numIntArithOpsAttrName,
46+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
47+ newNoc.getIntArithOpsCount ()));
48+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
5049 propagate (intOp.getResult (), newNoc);
5150 })
5251 // count real arithmetic ops
5352 .Case <arith::AddFOp, arith::SubFOp, arith::MulFOp>([&](auto realOp) {
54- auto newNoc = NatureOfComputation (0 , 0 , 0 , 1 , 0 , 0 );
55- realOp->setAttr (numRealArithOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getRealArithOpsCount ()));
56- LLVM_DEBUG (llvm::dbgs ()
57- << " Counting: " << newNoc << " \n " );
53+ auto newNoc = NatureOfComputation (0 , 0 , 0 , 1 , 0 , 0 );
54+ realOp->setAttr (numRealArithOpsAttrName,
55+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
56+ newNoc.getRealArithOpsCount ()));
57+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
5858 propagate (realOp->getResult (0 ), newNoc);
5959 })
6060 // count non linear ops
6161 .Case <math::AbsFOp, math::AbsIOp>([&](auto nonLinOp) {
6262 auto newNoc = NatureOfComputation (0 , 0 , 0 , 0 , 0 , 1 );
63- nonLinOp->setAttr (numNonLinOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getNonLinOpsCount ()));
64- LLVM_DEBUG (llvm::dbgs ()
65- << " Counting: " << newNoc << " \n " );
63+ nonLinOp->setAttr (
64+ numNonLinOpsAttrName,
65+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
66+ newNoc.getNonLinOpsCount ()));
67+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
6668 propagate (nonLinOp->getResult (0 ), newNoc);
6769 })
6870 // count bool ops
6971 .Case <arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto boolOp) {
7072 auto newNoc = NatureOfComputation (1 , 0 , 0 , 0 , 0 , 0 );
71- boolOp->setAttr (numBoolOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getBoolOpsCount ()));
72- LLVM_DEBUG (llvm::dbgs ()
73- << " Counting: " << newNoc << " \n " );
73+ boolOp->setAttr (numBoolOpsAttrName,
74+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
75+ newNoc.getBoolOpsCount ()));
76+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
7477 propagate (boolOp->getResult (0 ), newNoc);
7578 })
7679 // count bit ops
7780 .Case <arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp>([&](auto bitOp) {
7881 auto newNoc = NatureOfComputation (0 , 1 , 0 , 0 , 0 , 0 );
79- bitOp->setAttr (numBitOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getBitOpsCount ()));
80- LLVM_DEBUG (llvm::dbgs ()
81- << " Counting: " << newNoc << " \n " );
82+ bitOp->setAttr (numBitOpsAttrName,
83+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
84+ newNoc.getBitOpsCount ()));
85+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
8286 propagate (bitOp->getResult (0 ), newNoc);
8387 })
8488 // count real comparisons
8589 .Case <arith::CmpFOp>([&](auto cmpOps) {
8690 auto newNoc = NatureOfComputation (0 , 0 , 0 , 1 , 1 , 0 );
87- cmpOps->setAttr (numCmpOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getCmpOpsCount ()));
88- LLVM_DEBUG (llvm::dbgs ()
89- << " Counting: " << newNoc << " \n " );
91+ cmpOps->setAttr (numCmpOpsAttrName,
92+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
93+ newNoc.getCmpOpsCount ()));
94+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
9095 propagate (cmpOps->getResult (0 ), newNoc);
9196 })
92- // count int comparisons
97+ // count int comparisons
9398 .Case <arith::CmpIOp>([&](auto cmpOps) {
9499 auto newNoc = NatureOfComputation (0 , 0 , 1 , 0 , 1 , 0 );
95- cmpOps->setAttr (numCmpOpsAttrName, IntegerAttr::get (IntegerType::get (op->getContext (), 64 ), newNoc.getCmpOpsCount ()));
96- LLVM_DEBUG (llvm::dbgs ()
97- << " Counting: " << newNoc << " \n " );
100+ cmpOps->setAttr (numCmpOpsAttrName,
101+ IntegerAttr::get (IntegerType::get (op->getContext (), 64 ),
102+ newNoc.getCmpOpsCount ()));
103+ LLVM_DEBUG (llvm::dbgs () << " Counting: " << newNoc << " \n " );
98104 propagate (cmpOps->getResult (0 ), newNoc);
99105 });
100106 return success ();
@@ -128,59 +134,65 @@ bool hasAtLeastOneRealOperand(Operation *op) {
128134 return false ;
129135}
130136
131- static NatureOfComputation countNatComp (Operation *top, DataFlowSolver *solver) {
132- auto counter = NatureOfComputation (0 ,0 ,0 ,0 ,0 ,0 );
133- top->walk <WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
134- funcOp.getBody ().walk <WalkOrder::PreOrder>([&](Operation *op) {
135- LLVM_DEBUG (llvm::dbgs ()
136- << " Counting here: " << op->getName () << " \n " );
137- if (op->getNumResults () == 0 ) {
138- return ;
139- }
140- auto natcomp = solver->lookupState <SchemeInfoLattice>(op->getResult (0 ))->getValue ();
141- if (natcomp.isInitialized ()) {
142- counter = counter + natcomp;
143- }
144- });
137+ static NatureOfComputation countNatComp (Operation *top,
138+ DataFlowSolver *solver) {
139+ auto counter = NatureOfComputation (0 , 0 , 0 , 0 , 0 , 0 );
140+ top->walk <WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
141+ funcOp.getBody ().walk <WalkOrder::PreOrder>([&](Operation *op) {
142+ LLVM_DEBUG (llvm::dbgs () << " Counting here: " << op->getName () << " \n " );
143+ if (op->getNumResults () == 0 ) {
144+ return ;
145+ }
146+ auto natcomp =
147+ solver->lookupState <SchemeInfoLattice>(op->getResult (0 ))->getValue ();
148+ if (natcomp.isInitialized ()) {
149+ counter = counter + natcomp;
150+ }
151+ });
145152 });
146153 return counter;
147154}
148155
149- static NatureOfComputation getMaxNatComp (Operation *top, DataFlowSolver *solver) {
150- auto maxNatComp = NatureOfComputation (0 ,0 ,0 ,0 ,0 ,0 );
151- top->walk <WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
152- funcOp.getBody ().walk <WalkOrder::PreOrder>([&](Operation *op) {
153- if (op->getNumResults () == 0 ) {
154- return ;
155- }
156- auto natcomp = solver->lookupState <SchemeInfoLattice>(op->getResult (0 ))->getValue ();
157- if (natcomp.isInitialized ()) {
158- maxNatComp = NatureOfComputation::max (maxNatComp, natcomp);
159- }
160- });
156+ static NatureOfComputation getMaxNatComp (Operation *top,
157+ DataFlowSolver *solver) {
158+ auto maxNatComp = NatureOfComputation (0 , 0 , 0 , 0 , 0 , 0 );
159+ top->walk <WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
160+ funcOp.getBody ().walk <WalkOrder::PreOrder>([&](Operation *op) {
161+ if (op->getNumResults () == 0 ) {
162+ return ;
163+ }
164+ auto natcomp =
165+ solver->lookupState <SchemeInfoLattice>(op->getResult (0 ))->getValue ();
166+ if (natcomp.isInitialized ()) {
167+ maxNatComp = NatureOfComputation::max (maxNatComp, natcomp);
168+ }
169+ });
161170 });
162171 return maxNatComp;
163172}
164173
165174void annotateNatureOfComputation (Operation *top, DataFlowSolver *solver) {
166-
167175 auto getIntegerAttr = [&](int level) {
168176 return IntegerAttr::get (IntegerType::get (top->getContext (), 64 ), level);
169177 };
170178
171179 auto maxNatComp = getMaxNatComp (top, solver);
172180 auto count = countNatComp (top, solver);
173181 top->walk <WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
174- funcOp->setAttr (numBoolOpsAttrName, getIntegerAttr (count.getBoolOpsCount ()));
182+ funcOp->setAttr (numBoolOpsAttrName,
183+ getIntegerAttr (count.getBoolOpsCount ()));
175184 funcOp->setAttr (numBitOpsAttrName, getIntegerAttr (count.getBitOpsCount ()));
176- funcOp->setAttr (numIntArithOpsAttrName, getIntegerAttr (count.getIntArithOpsCount ()));
177- funcOp->setAttr (numRealArithOpsAttrName, getIntegerAttr (count.getRealArithOpsCount ()));
185+ funcOp->setAttr (numIntArithOpsAttrName,
186+ getIntegerAttr (count.getIntArithOpsCount ()));
187+ funcOp->setAttr (numRealArithOpsAttrName,
188+ getIntegerAttr (count.getRealArithOpsCount ()));
178189 funcOp->setAttr (numCmpOpsAttrName, getIntegerAttr (count.getCmpOpsCount ()));
179- funcOp->setAttr (numNonLinOpsAttrName, getIntegerAttr (count.getNonLinOpsCount ()));
190+ funcOp->setAttr (numNonLinOpsAttrName,
191+ getIntegerAttr (count.getNonLinOpsCount ()));
180192 LLVM_DEBUG (llvm::dbgs ()
181- << " Writing annotations here: " << funcOp->getName () << " \n " );
193+ << " Writing annotations here: " << funcOp->getName () << " \n " );
182194 });
183195}
184196
185197} // namespace heir
186- } // namespace mlir
198+ } // namespace mlir
0 commit comments