
Commit a726eee

asraa authored and copybara-github committed
test: add LeNet MNIST CNN model e2e test (small)
PiperOrigin-RevId: 833907437
1 parent d3f9d76 commit a726eee

19 files changed: +118558 -27 lines


lib/Dialect/TensorExt/Transforms/ImplementRotateAndReduce.cpp

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
     reduceOp = op.getReduceOp()->getValue().str();
   }
   implementedKernel = implementRotateAndReduce(vectorLeaf, plaintextsLeaf,
-                                               period, steps, reduceOp);
+                                               period, steps, {}, reduceOp);
   IRRewriter rewriter(op.getContext());
   rewriter.setInsertionPointAfter(op);
   ImplicitLocOpBuilder b(op.getLoc(), rewriter);
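
This one-line call-site update is forced by the signature change in lib/Kernel/KernelImplementation.h below: the new defaulted nonZeroDiagonals parameter sits before the already-defaulted reduceOp, so any caller passing reduceOp positionally must now supply a placeholder {}. A standalone sketch of the pitfall with hypothetical names (this rotateAndReduce is not the repo's function):

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

// Hypothetical stand-in for implementRotateAndReduce's new signature: a
// defaulted parameter inserted *before* an existing defaulted one.
void rotateAndReduce(int64_t period, int64_t steps,
                     std::map<int, bool> nonZeroDiagonals = {},
                     const std::string& reduceOp = "arith.addi") {
  std::cout << reduceOp << " over " << steps << " steps, "
            << nonZeroDiagonals.size() << " known non-zero diagonals\n";
}

int main() {
  // A caller that previously wrote rotateAndReduce(1, 8, "arith.muli") no
  // longer compiles: the string literal cannot bind to the map parameter.
  // It must pass an explicit placeholder to reach reduceOp positionally:
  rotateAndReduce(1, 8, {}, "arith.muli");
}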

lib/Kernel/KernelImplementation.h

Lines changed: 24 additions & 6 deletions
@@ -5,6 +5,7 @@
 #include <cmath>
 #include <cstddef>
 #include <cstdint>
+#include <iostream>
 #include <memory>
 #include <optional>
 #include <string>
@@ -117,6 +118,7 @@ std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
 implementBabyStepGiantStep(
     const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
     int64_t steps, DagExtractor<T> extractFunc,
+    std::map<int, bool> nonZeroDiagonals = {},
     const DerivedRotationIndexFn& derivedRotationIndexFn =
         defaultDerivedRotationIndexFn) {
   using NodeTy = ArithmeticDagNode<T>;
@@ -147,15 +149,29 @@ implementBabyStepGiantStep(
       int64_t innerRotAmount =
           derivedRotationIndexFn(giantStepSize, j, i, period);
       size_t extractionIndex = i + j * giantStepSize;
+
+      // If the extractionIndex is not in the nonZeroDiagonals, then the
+      // value is zero and we can skip the multiplication.
+      if (!nonZeroDiagonals.empty() &&
+          !nonZeroDiagonals.contains(extractionIndex)) {
+        continue;
+      }
+
       auto plaintext = extractFunc(babySteppedDag, extractionIndex);
       auto rotatedPlaintext = NodeTy::leftRotate(plaintext, innerRotAmount);
       auto multiplied = NodeTy::mul(rotatedPlaintext, babyStepVals[i]);
       innerSum =
           innerSum == nullptr ? multiplied : NodeTy::add(innerSum, multiplied);
     }

-    auto rotatedSum = NodeTy::leftRotate(innerSum, period * j * giantStepSize);
-    result = result == nullptr ? rotatedSum : NodeTy::add(result, rotatedSum);
+    auto rotatedSum =
+        innerSum == nullptr
+            ? nullptr
+            : NodeTy::leftRotate(innerSum, period * j * giantStepSize);
+    result = result == nullptr
+                 ? rotatedSum
+                 : (rotatedSum == nullptr ? result
+                                          : NodeTy::add(result, rotatedSum));
   }

   return result;
@@ -185,6 +201,7 @@ std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
                  std::shared_ptr<ArithmeticDagNode<T>>>
 implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
                          int64_t period, int64_t steps,
+                         std::map<int, bool> nonZeroDiagonals = {},
                          const std::string& reduceOp = "arith.addi") {
   using NodeTy = ArithmeticDagNode<T>;
   auto performReduction = [&](std::shared_ptr<NodeTy> left,
@@ -217,7 +234,7 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
   };

   return implementBabyStepGiantStep<T>(vector, plaintexts.value(), period,
-                                       steps, extractFunc);
+                                       steps, extractFunc, nonZeroDiagonals);
 }

 // Returns an arithmetic DAG that implements a baby-step-giant-step between
@@ -248,7 +265,7 @@ implementCiphertextCiphertextBabyStepGiantStep(
       int64_t extractionIndex) { return babySteppedDag; };

   return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
-                                       period, steps, extractFunc,
+                                       period, steps, extractFunc, {},
                                        derivedRotationIndexFn);
 }

@@ -260,13 +277,14 @@ template <typename T>
 std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
                  std::shared_ptr<ArithmeticDagNode<T>>>
 implementHaleviShoup(const T& vector, const T& matrix,
-                     std::vector<int64_t> originalMatrixShape) {
+                     std::vector<int64_t> originalMatrixShape,
+                     std::map<int, bool> nonZeroDiagonals = {}) {
   using NodeTy = ArithmeticDagNode<T>;
   int64_t numRotations = matrix.getShape()[0];

   auto rotateAndReduceResult = implementRotateAndReduce<T>(
       vector, std::optional<T>(matrix), /*period=*/1,
-      /*steps=*/numRotations);
+      /*steps=*/numRotations, nonZeroDiagonals);

   auto summedShifts = rotateAndReduceResult;
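
A minimal standalone sketch of the skip logic added above (all names and the toy work counts are illustrative, not from the repo): any extraction index i + j * giantStepSize that is absent from nonZeroDiagonals corresponds to an all-zero diagonal, so its rotate-and-multiply is dropped, and a giant step whose entire inner sum was skipped also drops its outer rotation, which is what the nullptr guards around rotatedSum and result handle.

#include <cstdint>
#include <iostream>
#include <map>

int main() {
  const int64_t giantStepSize = 4, numGiantSteps = 4;  // 16 diagonals total
  // Suppose only these diagonals of the diagonalized matrix are non-zero.
  std::map<int, bool> nonZeroDiagonals = {{0, true}, {5, true}, {6, true}};

  int64_t mulCount = 0, rotCount = 0;
  for (int64_t j = 0; j < numGiantSteps; ++j) {
    bool innerSumNonEmpty = false;
    for (int64_t i = 0; i < giantStepSize; ++i) {
      int64_t extractionIndex = i + j * giantStepSize;
      // Same guard as the patch: a missing index means a zero diagonal,
      // so the plaintext-ciphertext multiply is skipped entirely.
      if (!nonZeroDiagonals.empty() &&
          !nonZeroDiagonals.count(extractionIndex)) {
        continue;
      }
      ++mulCount;
      innerSumNonEmpty = true;
    }
    // The giant-step rotation is only needed if the inner sum exists,
    // mirroring the nullptr checks added around rotatedSum and result.
    if (innerSumNonEmpty) ++rotCount;
  }
  std::cout << mulCount << " multiplies, " << rotCount
            << " giant-step rotations (vs. 16 and 4 without the skip)\n";
}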

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 3 additions & 1 deletion
@@ -154,6 +154,8 @@ void mlirToSecretArithmeticPipelineBuilder(
   pm.addPass(createActivationCanonicalizations());
   pm.addPass(createSelectRewrite());
   pm.addPass(createCompareToSignRewrite());
+  pm.addPass(createCanonicalizerPass());
+  pm.addPass(createCSEPass());

   // Vectorize and optimize rotations
   // TODO(#2320): figure out where this fits in the new pipeline
@@ -493,8 +495,8 @@ BackendPipelineBuilder toLattigoPipelineBuilder() {

 void linalgPreprocessingBuilder(OpPassManager& manager) {
   manager.addPass(createInlineActivations());
-  manager.addPass(createDropUnitDims());
   manager.addPass(createLinalgCanonicalizations());
+  manager.addPass(createDropUnitDims());
   manager.addPass(createFoldConstantTensors());
   manager.addPass(createCanonicalizerPass());
   manager.addPass(createSymbolDCEPass());
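
The canonicalizer/CSE additions follow a common phase-ordering idiom: run cleanup immediately after a batch of rewrites so that later passes (here the rotation optimizations) see folded, deduplicated IR. A sketch of the idiom using only upstream MLIR passes; addCleanup is a hypothetical helper, not part of this pipeline:

#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/Passes.h"

// Hypothetical helper: canonicalize first so that equivalent expressions
// take an identical form, then let CSE deduplicate them.
void addCleanup(mlir::OpPassManager& pm) {
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass());
}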

lib/Target/OpenFhePke/Interpreter.cpp

Lines changed: 12 additions & 0 deletions
@@ -176,6 +176,8 @@ std::vector<TypedCppValue> Interpreter::interpret(
 }

 void Interpreter::visit(Operation* op) {
+  std::cout << "Visiting operation: " << op->getName().getStringRef().str()
+            << "\n";
   llvm::TypeSwitch<Operation*>(op)
       .Case<arith::ConstantOp, arith::AddIOp, arith::AddFOp, arith::SubIOp,
             arith::MulIOp, arith::MulFOp, arith::DivSIOp, arith::RemSIOp,
@@ -199,6 +201,16 @@ void Interpreter::visit(Operation* op) {
         op->emitError() << "Unsupported operation " << opName.getStringRef()
                         << " in interpreter";
       });
+  // If any of the op's operands have no more uses, then remove them
+  // from the environment.
+  if (!op->getParentOfType<affine::AffineForOp>() &&
+      !op->getParentOfType<scf::ForOp>()) {
+    for (auto operand : op->getOperands()) {
+      if (liveness.isDeadAfter(operand, op)) {
+        env.erase(operand);
+      }
+    }
+  }
 }

 void Interpreter::visit(arith::ConstantOp op) {
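
A minimal sketch of the reclamation idea, with a toy environment in place of the interpreter's MLIR values (all names here are illustrative): once an operation is the last user of an operand, the operand's runtime value can be erased from the environment, but only outside affine/scf loops, since a loop body visits its operations more than once.

#include <iostream>
#include <map>
#include <string>
#include <vector>

struct Op {
  std::string name;
  std::vector<int> operands;  // ids of the values this op reads
};

int main() {
  // Toy straight-line program: %2 = f(%0, %1); %3 = g(%2); %4 = h(%2, %3)
  std::vector<Op> program = {{"f", {0, 1}}, {"g", {2}}, {"h", {2, 3}}};

  // Precompute the index of each value's last use (a stand-in for the
  // interpreter's liveness.isDeadAfter query).
  std::map<int, size_t> lastUse;
  for (size_t i = 0; i < program.size(); ++i)
    for (int v : program[i].operands) lastUse[v] = i;

  std::map<int, double> env = {{0, 1.0}, {1, 2.0}};
  for (size_t i = 0; i < program.size(); ++i) {
    env[2 + static_cast<int>(i)] = 0.0;  // pretend we computed a result
    // Free operands whose last use was this op, mirroring env.erase(operand).
    for (int v : program[i].operands)
      if (lastUse[v] == i) env.erase(v);
    std::cout << "after " << program[i].name << ": " << env.size()
              << " live values\n";
  }
}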

lib/Transforms/ConvertToCiphertextSemantics/ConvertToCiphertextSemantics.cpp

Lines changed: 19 additions & 2 deletions
@@ -674,13 +674,30 @@ struct ConvertLinalgConv2D
         cast<TypedValue<RankedTensorType>>(adaptor.getInputs()[1]);
     SSAValue matrixLeaf(matrix);

-    // The original matrix shape is the shape of the expanded filter.
+    // The original matrix shape is the shape of the expanded filter before
+    // diagonalization. This is 28x28 for LeNet.
     RankedTensorType expandedMatrixType = get2dConvFilterExpandedType(
         cast<RankedTensorType>(op.getInputs()[1].getType()),
         cast<RankedTensorType>(op.getInputs()[0].getType()), /*padding=*/0);
+
+    // Get non-zero diagonals of the diagonalized expanded filter matrix.
+    LayoutAttr filterLayout = getLayoutAttr(adaptor.getInputs()[1]);
+    auto filterRelation = filterLayout.getIntegerRelation();
+    PointCollector collector;
+    getRangePoints(filterRelation, collector);
+    std::map<int, bool> nonZeroDiagonals;
+    for (auto point : collector.points) {
+      nonZeroDiagonals[point[0]] = true;
+    }
+    for (auto [ct, val] : nonZeroDiagonals) {
+      std::cout << ct << ", ";
+    }
+    std::cout << "\n";
+    std::cout << "nonZero diagonal size: " << nonZeroDiagonals.size() << "\n";
+
     std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel =
         implementHaleviShoup(vectorLeaf, matrixLeaf,
-                             expandedMatrixType.getShape());
+                             expandedMatrixType.getShape(), nonZeroDiagonals);

     rewriter.setInsertionPointAfter(op);
     ImplicitLocOpBuilder b(op.getLoc(), rewriter);
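
A minimal sketch of the diagonal collection, with a plain vector of points standing in for PointCollector and getRangePoints (assumptions, since those helpers' internals are not in this diff): the first coordinate of each point in the filter layout's range names a diagonal, and keying a std::map<int, bool> by it deduplicates down to the set handed to implementHaleviShoup.

#include <iostream>
#include <map>
#include <vector>

int main() {
  // Stand-in for the points of the filter layout's integer relation range:
  // each point is (diagonal index, slot index).
  std::vector<std::vector<int>> points = {
      {0, 0}, {0, 3}, {2, 1}, {2, 5}, {7, 4}};

  // Same dedup as the patch: keying a map by the first coordinate leaves
  // one entry per diagonal that contains at least one layout point.
  std::map<int, bool> nonZeroDiagonals;
  for (const auto& point : points) nonZeroDiagonals[point[0]] = true;

  for (const auto& [ct, val] : nonZeroDiagonals) std::cout << ct << ", ";
  std::cout << "\nnon-zero diagonal count: " << nonZeroDiagonals.size()
            << "\n";  // prints 0, 2, 7, then 3
}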

lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp

Lines changed: 8 additions & 8 deletions
@@ -517,9 +517,9 @@ struct RewriteAvgPoolAsConv2D
     RankedTensorType twoDInputType =
         RankedTensorType::get({inputTy.getDimSize(2), inputTy.getDimSize(3)},
                               inputTy.getElementType());
-    Value convOutput = rewriter.create<tensor::EmptyOp>(
-        poolOp.getLoc(), twoDOutputType.getShape(),
-        twoDOutputType.getElementType());
+    Value convOutput = tensor::EmptyOp::create(rewriter, poolOp.getLoc(),
+                                               twoDOutputType.getShape(),
+                                               twoDOutputType.getElementType());
     for (int n = 0; n < inputTy.getDimSize(0); ++n) {
       for (int c = 0; c < inputTy.getDimSize(1); ++c) {
         // Compute the 2-D constant convolution.
@@ -531,9 +531,9 @@ struct RewriteAvgPoolAsConv2D
             rewriter.getIndexAttr(inputTy.getDimSize(2)),
             rewriter.getIndexAttr(inputTy.getDimSize(3))};
         SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
-        auto extractInputOp = rewriter.create<tensor::ExtractSliceOp>(
-            poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0], offsets,
-            inputSizes, strides);
+        auto extractInputOp = tensor::ExtractSliceOp::create(
+            rewriter, poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0],
+            offsets, inputSizes, strides);

         auto convOp = linalg::Conv2DOp::create(
             rewriter, poolOp.getLoc(), twoDOutputType,
@@ -543,8 +543,8 @@ struct RewriteAvgPoolAsConv2D
             rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
             rewriter.getIndexAttr(outputTy.getDimSize(2)),
             rewriter.getIndexAttr(outputTy.getDimSize(3))};
-        outputVal = rewriter.create<tensor::InsertSliceOp>(
-            poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
+        outputVal = tensor::InsertSliceOp::create(
+            rewriter, poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
             outputSizes, strides);
       }
     }
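
The three hunks above are a mechanical migration to MLIR's newer static builder form: the op class becomes the receiver and the builder moves to the first argument, with everything else unchanged. The pattern as a fragment (loc, shape, and elementType stand in for the values used in the diff):

// Before: builder is the receiver, op type is a template argument.
Value out = rewriter.create<tensor::EmptyOp>(loc, shape, elementType);

// After: op type is the receiver, builder is the first argument.
Value out = tensor::EmptyOp::create(rewriter, loc, shape, elementType);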

lib/Utils/Layout/Codegen.cpp

Lines changed: 9 additions & 5 deletions
@@ -212,7 +212,7 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
     SmallVector<Value> args = getArgs(expr);
     auto op =
         arith::CmpIOp::create(b, islCmpToMlirAttr[type], args[0], args[1]);
-    return op->getResult(0);
+    return arith::ExtSIOp::create(b, b.getI32Type(), op->getResult(0));
   }

   if (type == isl_ast_op_select) {
@@ -229,7 +229,7 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
     auto eqOp =
         arith::CmpIOp::create(b, arith::CmpIPredicate::eq, op,
                               arith::ConstantIntOp::create(b, 0, 32));
-    return eqOp->getResult(0);
+    return arith::ExtSIOp::create(b, b.getI32Type(), eqOp->getResult(0));
   }
   isl_ast_expr_op_type opType = isl_ast_expr_get_op_type(expr);
   char* cStr = isl_ast_expr_to_C_str(expr);
@@ -318,9 +318,13 @@ FailureOr<scf::ValueVector> MLIRLoopNestGenerator::visitAstNodeIf(
   isl_ast_expr_free(cond);

   // Build scf if operation with the result types of the iter args
-  auto ifOp =
-      scf::IfOp::create(builder_, currentLoc_, TypeRange(currentIterArgs_),
-                        condVal, /*addThenBlock=*/true, /*addElseBlock=*/true);
+  // Convert condVal to an i1
+  auto condValI1 =
+      arith::CmpIOp::create(builder_, arith::CmpIPredicate::eq, condVal,
+                            arith::ConstantIntOp::create(builder_, 1, 32));
+  auto ifOp = scf::IfOp::create(builder_, currentLoc_,
+                                TypeRange(currentIterArgs_), condValI1,
+                                /*addThenBlock=*/true, /*addElseBlock=*/true);

   // TODO:(#2120): Handle ISL else conditions.
   isl_ast_node* elseNode = isl_ast_node_if_get_else_node(node);
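
A toy model of the width round-trip these three hunks set up (plain C++ standing in for the MLIR ops): buildIslExpr now widens every i1 comparison to i32 with arith.extsi so ISL expressions compose uniformly, and visitAstNodeIf narrows back to i1 with an equality against constant 1 before feeding scf.if. Note that sign-extension encodes true as -1, so the eq-1 recovery presumes producers that emit a literal 1; which encoding the surrounding codegen guarantees is outside this diff.

#include <cstdint>
#include <iostream>

// Toy model of the convention after this patch: ISL expressions are built
// uniformly as i32 values, so i1 comparison results are widened at the leaf
// (arith.extsi) and narrowed back to i1 only at the scf.if boundary
// (arith.cmpi eq against constant 1).
int32_t extsi(bool b) { return b ? -1 : 0; }  // sign-extension: true -> -1

int main() {
  int32_t leaf = extsi(3 < 5);  // widened comparison result
  bool boundary = (leaf == 1);  // visitAstNodeIf's recovery step
  // Sign-extended true is -1, so eq-against-1 only recovers `true` for
  // producers that emit a literal 1.
  std::cout << leaf << " -> " << std::boolalpha << boundary << "\n";
}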

0 commit comments
