Skip to content

Commit 8e5d307

Browse files
committed
pre-commits
1 parent 17a264b commit 8e5d307

File tree

4 files changed

+125
-120
lines changed

4 files changed

+125
-120
lines changed

lib/Analysis/SchemeInfoAnalysis/SchemeInfoAnalysis.cpp

Lines changed: 71 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
1212
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
1313
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
14-
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
1514
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
1615
#include "mlir/include/mlir/Dialect/Math/IR/Math.h" // from @llvm-project
1716
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
@@ -31,8 +30,7 @@ namespace heir {
3130
LogicalResult SchemeInfoAnalysis::visitOperation(
3231
Operation *op, ArrayRef<const SchemeInfoLattice *> operands,
3332
ArrayRef<SchemeInfoLattice *> results) {
34-
LLVM_DEBUG(llvm::dbgs()
35-
<< "Visiting: " << op->getName() << ". ");
33+
LLVM_DEBUG(llvm::dbgs() << "Visiting: " << op->getName() << ". ");
3634

3735
auto propagate = [&](Value value, const NatureOfComputation &counter) {
3836
auto *oldNoc = getLatticeElement(value);
@@ -44,57 +42,65 @@ LogicalResult SchemeInfoAnalysis::visitOperation(
4442
// count integer arithmetic ops
4543
.Case<arith::AddIOp, arith::SubIOp, arith::MulIOp>([&](auto intOp) {
4644
auto newNoc = NatureOfComputation(0, 0, 1, 0, 0, 0);
47-
intOp->setAttr(numIntArithOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getIntArithOpsCount()));
48-
LLVM_DEBUG(llvm::dbgs()
49-
<< "Counting: " << newNoc << "\n");
45+
intOp->setAttr(numIntArithOpsAttrName,
46+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
47+
newNoc.getIntArithOpsCount()));
48+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
5049
propagate(intOp.getResult(), newNoc);
5150
})
5251
// count real arithmetic ops
5352
.Case<arith::AddFOp, arith::SubFOp, arith::MulFOp>([&](auto realOp) {
54-
auto newNoc = NatureOfComputation(0, 0, 0, 1, 0, 0);
55-
realOp->setAttr(numRealArithOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getRealArithOpsCount()));
56-
LLVM_DEBUG(llvm::dbgs()
57-
<< "Counting: " << newNoc << "\n");
53+
auto newNoc = NatureOfComputation(0, 0, 0, 1, 0, 0);
54+
realOp->setAttr(numRealArithOpsAttrName,
55+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
56+
newNoc.getRealArithOpsCount()));
57+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
5858
propagate(realOp->getResult(0), newNoc);
5959
})
6060
// count non linear ops
6161
.Case<math::AbsFOp, math::AbsIOp>([&](auto nonLinOp) {
6262
auto newNoc = NatureOfComputation(0, 0, 0, 0, 0, 1);
63-
nonLinOp->setAttr(numNonLinOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getNonLinOpsCount()));
64-
LLVM_DEBUG(llvm::dbgs()
65-
<< "Counting: " << newNoc << "\n");
63+
nonLinOp->setAttr(
64+
numNonLinOpsAttrName,
65+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
66+
newNoc.getNonLinOpsCount()));
67+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
6668
propagate(nonLinOp->getResult(0), newNoc);
6769
})
6870
// count bool ops
6971
.Case<arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto boolOp) {
7072
auto newNoc = NatureOfComputation(1, 0, 0, 0, 0, 0);
71-
boolOp->setAttr(numBoolOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getBoolOpsCount()));
72-
LLVM_DEBUG(llvm::dbgs()
73-
<< "Counting: " << newNoc << "\n");
73+
boolOp->setAttr(numBoolOpsAttrName,
74+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
75+
newNoc.getBoolOpsCount()));
76+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
7477
propagate(boolOp->getResult(0), newNoc);
7578
})
7679
// count bit ops
7780
.Case<arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp>([&](auto bitOp) {
7881
auto newNoc = NatureOfComputation(0, 1, 0, 0, 0, 0);
79-
bitOp->setAttr(numBitOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getBitOpsCount()));
80-
LLVM_DEBUG(llvm::dbgs()
81-
<< "Counting: " << newNoc << "\n");
82+
bitOp->setAttr(numBitOpsAttrName,
83+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
84+
newNoc.getBitOpsCount()));
85+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
8286
propagate(bitOp->getResult(0), newNoc);
8387
})
8488
// count real comparisons
8589
.Case<arith::CmpFOp>([&](auto cmpOps) {
8690
auto newNoc = NatureOfComputation(0, 0, 0, 1, 1, 0);
87-
cmpOps->setAttr(numCmpOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getCmpOpsCount()));
88-
LLVM_DEBUG(llvm::dbgs()
89-
<< "Counting: " << newNoc << "\n");
91+
cmpOps->setAttr(numCmpOpsAttrName,
92+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
93+
newNoc.getCmpOpsCount()));
94+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
9095
propagate(cmpOps->getResult(0), newNoc);
9196
})
92-
// count int comparisons
97+
// count int comparisons
9398
.Case<arith::CmpIOp>([&](auto cmpOps) {
9499
auto newNoc = NatureOfComputation(0, 0, 1, 0, 1, 0);
95-
cmpOps->setAttr(numCmpOpsAttrName, IntegerAttr::get(IntegerType::get(op->getContext(), 64), newNoc.getCmpOpsCount()));
96-
LLVM_DEBUG(llvm::dbgs()
97-
<< "Counting: " << newNoc << "\n");
100+
cmpOps->setAttr(numCmpOpsAttrName,
101+
IntegerAttr::get(IntegerType::get(op->getContext(), 64),
102+
newNoc.getCmpOpsCount()));
103+
LLVM_DEBUG(llvm::dbgs() << "Counting: " << newNoc << "\n");
98104
propagate(cmpOps->getResult(0), newNoc);
99105
});
100106
return success();
@@ -128,59 +134,65 @@ bool hasAtLeastOneRealOperand(Operation *op) {
128134
return false;
129135
}
130136

131-
static NatureOfComputation countNatComp(Operation *top, DataFlowSolver *solver) {
132-
auto counter = NatureOfComputation(0,0,0,0,0,0);
133-
top->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
134-
funcOp.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
135-
LLVM_DEBUG(llvm::dbgs()
136-
<< "Counting here: " << op->getName() << "\n");
137-
if (op->getNumResults() == 0) {
138-
return;
139-
}
140-
auto natcomp = solver->lookupState<SchemeInfoLattice>(op->getResult(0))->getValue();
141-
if (natcomp.isInitialized()) {
142-
counter = counter + natcomp;
143-
}
144-
});
137+
static NatureOfComputation countNatComp(Operation *top,
138+
DataFlowSolver *solver) {
139+
auto counter = NatureOfComputation(0, 0, 0, 0, 0, 0);
140+
top->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
141+
funcOp.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
142+
LLVM_DEBUG(llvm::dbgs() << "Counting here: " << op->getName() << "\n");
143+
if (op->getNumResults() == 0) {
144+
return;
145+
}
146+
auto natcomp =
147+
solver->lookupState<SchemeInfoLattice>(op->getResult(0))->getValue();
148+
if (natcomp.isInitialized()) {
149+
counter = counter + natcomp;
150+
}
151+
});
145152
});
146153
return counter;
147154
}
148155

149-
static NatureOfComputation getMaxNatComp(Operation *top, DataFlowSolver *solver) {
150-
auto maxNatComp = NatureOfComputation(0,0,0,0,0,0);
151-
top->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
152-
funcOp.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
153-
if (op->getNumResults() == 0) {
154-
return;
155-
}
156-
auto natcomp = solver->lookupState<SchemeInfoLattice>(op->getResult(0))->getValue();
157-
if (natcomp.isInitialized()) {
158-
maxNatComp = NatureOfComputation::max(maxNatComp, natcomp);
159-
}
160-
});
156+
static NatureOfComputation getMaxNatComp(Operation *top,
157+
DataFlowSolver *solver) {
158+
auto maxNatComp = NatureOfComputation(0, 0, 0, 0, 0, 0);
159+
top->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
160+
funcOp.getBody().walk<WalkOrder::PreOrder>([&](Operation *op) {
161+
if (op->getNumResults() == 0) {
162+
return;
163+
}
164+
auto natcomp =
165+
solver->lookupState<SchemeInfoLattice>(op->getResult(0))->getValue();
166+
if (natcomp.isInitialized()) {
167+
maxNatComp = NatureOfComputation::max(maxNatComp, natcomp);
168+
}
169+
});
161170
});
162171
return maxNatComp;
163172
}
164173

165174
void annotateNatureOfComputation(Operation *top, DataFlowSolver *solver) {
166-
167175
auto getIntegerAttr = [&](int level) {
168176
return IntegerAttr::get(IntegerType::get(top->getContext(), 64), level);
169177
};
170178

171179
auto maxNatComp = getMaxNatComp(top, solver);
172180
auto count = countNatComp(top, solver);
173181
top->walk<WalkOrder::PreOrder>([&](func::FuncOp funcOp) {
174-
funcOp->setAttr(numBoolOpsAttrName, getIntegerAttr(count.getBoolOpsCount()));
182+
funcOp->setAttr(numBoolOpsAttrName,
183+
getIntegerAttr(count.getBoolOpsCount()));
175184
funcOp->setAttr(numBitOpsAttrName, getIntegerAttr(count.getBitOpsCount()));
176-
funcOp->setAttr(numIntArithOpsAttrName, getIntegerAttr(count.getIntArithOpsCount()));
177-
funcOp->setAttr(numRealArithOpsAttrName, getIntegerAttr(count.getRealArithOpsCount()));
185+
funcOp->setAttr(numIntArithOpsAttrName,
186+
getIntegerAttr(count.getIntArithOpsCount()));
187+
funcOp->setAttr(numRealArithOpsAttrName,
188+
getIntegerAttr(count.getRealArithOpsCount()));
178189
funcOp->setAttr(numCmpOpsAttrName, getIntegerAttr(count.getCmpOpsCount()));
179-
funcOp->setAttr(numNonLinOpsAttrName, getIntegerAttr(count.getNonLinOpsCount()));
190+
funcOp->setAttr(numNonLinOpsAttrName,
191+
getIntegerAttr(count.getNonLinOpsCount()));
180192
LLVM_DEBUG(llvm::dbgs()
181-
<< "Writing annotations here: " << funcOp->getName() << "\n");
193+
<< "Writing annotations here: " << funcOp->getName() << "\n");
182194
});
183195
}
184196

185197
} // namespace heir
186-
} // namespace mlir
198+
} // namespace mlir

lib/Analysis/SchemeInfoAnalysis/SchemeInfoAnalysis.h

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
#include <cassert>
66
#include <optional>
77

8-
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
8+
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
99
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
10-
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
11-
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
12-
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
13-
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
14-
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
15-
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
10+
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
11+
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
12+
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
13+
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
14+
#include "mlir/include/mlir/Interfaces/CallInterfaces.h" // from @llvm-project
15+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
1616

1717
namespace mlir {
1818
namespace heir {
@@ -93,82 +93,80 @@ class NatureOfComputation {
9393
numCmpOps == rhs.numCmpOps && numNonLinOps == rhs.numNonLinOps;
9494
}
9595

96-
NatureOfComputation operator+(const NatureOfComputation &rhs) const {
97-
if (!isInitialized() && !rhs.isInitialized()) {
98-
return *this; // return the current object
96+
NatureOfComputation operator+(const NatureOfComputation &rhs) const {
97+
if (!isInitialized() && !rhs.isInitialized()) {
98+
return *this; // return the current object
9999
}
100100

101101
if (isInitialized() && !rhs.isInitialized()) {
102-
return *this; // return the current object
102+
return *this; // return the current object
103103
}
104104

105105
if (!isInitialized() && rhs.isInitialized()) {
106-
return rhs; // return the rhs object
106+
return rhs; // return the rhs object
107107
}
108108

109109
// Both are initialized
110110
return NatureOfComputation(
111-
numBoolOps + rhs.numBoolOps,
112-
numBitOps + rhs.numBitOps,
111+
numBoolOps + rhs.numBoolOps, numBitOps + rhs.numBitOps,
113112
numIntArithOps + rhs.numIntArithOps,
114-
numRealArithOps + rhs.numRealArithOps,
115-
numCmpOps + rhs.numCmpOps,
113+
numRealArithOps + rhs.numRealArithOps, numCmpOps + rhs.numCmpOps,
116114
numNonLinOps + rhs.numNonLinOps);
117115
}
118116

119117
StringRef getDominantAttributeName() const {
120118
assert(isInitialized() && "NatureOfComputation not initialized");
121-
//TODO: define what happens when all are equal
119+
// TODO: define what happens when all are equal
122120
int maxCount = numBoolOps;
123121
StringRef attributeName = numBoolOpsAttrName;
124122

125123
if (numBitOps > maxCount) {
126-
maxCount = numBitOps;
127-
attributeName = numBitOpsAttrName;
124+
maxCount = numBitOps;
125+
attributeName = numBitOpsAttrName;
128126
}
129127
if (numIntArithOps > maxCount) {
130-
maxCount = numIntArithOps;
131-
attributeName = numIntArithOpsAttrName;
128+
maxCount = numIntArithOps;
129+
attributeName = numIntArithOpsAttrName;
132130
}
133131
if (numRealArithOps > maxCount) {
134-
maxCount = numRealArithOps;
135-
attributeName = numRealArithOpsAttrName;
132+
maxCount = numRealArithOps;
133+
attributeName = numRealArithOpsAttrName;
136134
}
137135
if (numCmpOps > maxCount) {
138-
maxCount = numCmpOps;
139-
attributeName = numCmpOpsAttrName;
136+
maxCount = numCmpOps;
137+
attributeName = numCmpOpsAttrName;
140138
}
141139
if (numNonLinOps > maxCount) {
142-
maxCount = numNonLinOps;
143-
attributeName = numNonLinOpsAttrName;
140+
maxCount = numNonLinOps;
141+
attributeName = numNonLinOpsAttrName;
144142
}
145143

146144
return attributeName;
147-
}
145+
}
148146

149-
int getDominantComputationCount() const {
147+
int getDominantComputationCount() const {
150148
assert(isInitialized() && "NatureOfComputation not initialized");
151-
149+
152150
int maxCount = numBoolOps;
153151

154152
if (numBitOps > maxCount) {
155-
maxCount = numBitOps;
153+
maxCount = numBitOps;
156154
}
157155
if (numIntArithOps > maxCount) {
158-
maxCount = numIntArithOps;
156+
maxCount = numIntArithOps;
159157
}
160158
if (numRealArithOps > maxCount) {
161-
maxCount = numRealArithOps;
159+
maxCount = numRealArithOps;
162160
}
163161
if (numCmpOps > maxCount) {
164-
maxCount = numCmpOps;
162+
maxCount = numCmpOps;
165163
}
166164
if (numNonLinOps > maxCount) {
167-
maxCount = numNonLinOps;
165+
maxCount = numNonLinOps;
168166
}
169167

170-
return maxCount; // Returns the highest count after all comparisons
171-
}
168+
return maxCount; // Returns the highest count after all comparisons
169+
}
172170

173171
static NatureOfComputation max(const NatureOfComputation &lhs,
174172
const NatureOfComputation &rhs) {
@@ -198,8 +196,9 @@ int getDominantComputationCount() const {
198196
if (isInitialized()) {
199197
os << "NatureOfComputation(numBoolOps=" << numBoolOps
200198
<< "; numBitOps=" << numBitOps << "; numIntArithOps=" << numIntArithOps
201-
<< "; numRealArithOps=" << numRealArithOps << "; numCmpOps=" << numCmpOps
202-
<< "; numNonLinOps=" << numNonLinOps << ")";
199+
<< "; numRealArithOps=" << numRealArithOps
200+
<< "; numCmpOps=" << numCmpOps << "; numNonLinOps=" << numNonLinOps
201+
<< ")";
203202
} else {
204203
os << "NatureOfComputation(uninitialized)";
205204
}
@@ -227,10 +226,10 @@ class SchemeInfoLattice : public dataflow::Lattice<NatureOfComputation> {
227226
};
228227

229228
class SchemeInfoAnalysis
230-
: public dataflow::SparseForwardDataFlowAnalysis<SchemeInfoLattice>{
229+
: public dataflow::SparseForwardDataFlowAnalysis<SchemeInfoLattice> {
231230
public:
232231
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;
233-
232+
234233
LogicalResult visitOperation(Operation *op,
235234
ArrayRef<const SchemeInfoLattice *> operands,
236235
ArrayRef<SchemeInfoLattice *> results) override;

0 commit comments

Comments
 (0)