Skip to content

Commit dad9257

Browse files
asraacopybara-github
authored andcommitted
test: add LeNet MNIST CNN model e2e test (small)
PiperOrigin-RevId: 833907437
1 parent 23ad4e5 commit dad9257

File tree

12 files changed

+218185
-22
lines changed

12 files changed

+218185
-22
lines changed

lib/Pipelines/ArithmeticPipelineRegistration.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,8 @@ void mlirToSecretArithmeticPipelineBuilder(
154154
pm.addPass(createActivationCanonicalizations());
155155
pm.addPass(createSelectRewrite());
156156
pm.addPass(createCompareToSignRewrite());
157+
pm.addPass(createCanonicalizerPass());
158+
pm.addPass(createCSEPass());
157159

158160
// Vectorize and optimize rotations
159161
// TODO(#2320): figure out where this fits in the new pipeline
@@ -493,8 +495,8 @@ BackendPipelineBuilder toLattigoPipelineBuilder() {
493495

494496
void linalgPreprocessingBuilder(OpPassManager& manager) {
495497
manager.addPass(createInlineActivations());
496-
manager.addPass(createDropUnitDims());
497498
manager.addPass(createLinalgCanonicalizations());
499+
manager.addPass(createDropUnitDims());
498500
manager.addPass(createFoldConstantTensors());
499501
manager.addPass(createCanonicalizerPass());
500502
manager.addPass(createSymbolDCEPass());

lib/Target/OpenFhePke/Interpreter.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,8 @@ std::vector<TypedCppValue> Interpreter::interpret(
175175
}
176176

177177
void Interpreter::visit(Operation* op) {
178+
std::cout << "Visiting operation: " << op->getName().getStringRef().str()
179+
<< "\n";
178180
llvm::TypeSwitch<Operation*>(op)
179181
.Case<arith::ConstantOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
180182
arith::MulFOp, arith::DivSIOp, arith::RemSIOp, arith::AndIOp,
@@ -198,6 +200,16 @@ void Interpreter::visit(Operation* op) {
198200
op->emitError() << "Unsupported operation " << opName.getStringRef()
199201
<< " in interpreter";
200202
});
203+
// If any of the operations op operands have no more uses, then remove
204+
// them from the end.
205+
if (!op->getParentOfType<affine::AffineForOp>() &&
206+
!op->getParentOfType<scf::ForOp>()) {
207+
for (auto operand : op->getOperands()) {
208+
if (liveness.isDeadAfter(operand, op)) {
209+
env.erase(operand);
210+
}
211+
}
212+
}
201213
}
202214

203215
void Interpreter::visit(arith::ConstantOp op) {

lib/Transforms/LinalgCanonicalizations/LinalgCanonicalizations.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,9 @@ struct RewriteAvgPoolAsConv2D
517517
RankedTensorType twoDInputType =
518518
RankedTensorType::get({inputTy.getDimSize(2), inputTy.getDimSize(3)},
519519
inputTy.getElementType());
520-
Value convOutput = rewriter.create<tensor::EmptyOp>(
521-
poolOp.getLoc(), twoDOutputType.getShape(),
522-
twoDOutputType.getElementType());
520+
Value convOutput = tensor::EmptyOp::create(rewriter, poolOp.getLoc(),
521+
twoDOutputType.getShape(),
522+
twoDOutputType.getElementType());
523523
for (int n = 0; n < inputTy.getDimSize(0); ++n) {
524524
for (int c = 0; c < inputTy.getDimSize(1); ++c) {
525525
// Compute the 2-D constant convolution.
@@ -531,9 +531,9 @@ struct RewriteAvgPoolAsConv2D
531531
rewriter.getIndexAttr(inputTy.getDimSize(2)),
532532
rewriter.getIndexAttr(inputTy.getDimSize(3))};
533533
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
534-
auto extractInputOp = rewriter.create<tensor::ExtractSliceOp>(
535-
poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0], offsets,
536-
inputSizes, strides);
534+
auto extractInputOp = tensor::ExtractSliceOp::create(
535+
rewriter, poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0],
536+
offsets, inputSizes, strides);
537537

538538
auto convOp = linalg::Conv2DOp::create(
539539
rewriter, poolOp.getLoc(), twoDOutputType,
@@ -543,8 +543,8 @@ struct RewriteAvgPoolAsConv2D
543543
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
544544
rewriter.getIndexAttr(outputTy.getDimSize(2)),
545545
rewriter.getIndexAttr(outputTy.getDimSize(3))};
546-
outputVal = rewriter.create<tensor::InsertSliceOp>(
547-
poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
546+
outputVal = tensor::InsertSliceOp::create(
547+
rewriter, poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
548548
outputSizes, strides);
549549
}
550550
}

lib/Utils/Layout/Codegen.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,13 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
212212
SmallVector<Value> args = getArgs(expr);
213213
auto op =
214214
arith::CmpIOp::create(b, islCmpToMlirAttr[type], args[0], args[1]);
215-
return op->getResult(0);
216-
}
217-
218-
if (type == isl_ast_op_select) {
215+
auto indexCastOp = arith::IndexCastOp::create(b, b.getIndexType(), op);
216+
return indexCastOp->getResult(0);
217+
} else if (type == isl_ast_op_select) {
219218
// Select op
220219
SmallVector<Value> args = getArgs(expr);
221-
auto op = arith::SelectOp::create(b, args[0], args[1], args[2]);
220+
auto condI1 = arith::IndexCastOp::create(b, b.getI1Type(), args[0]);
221+
auto op = arith::SelectOp::create(b, condI1, args[1], args[2]);
222222
return op->getResult(0);
223223
}
224224

@@ -228,7 +228,9 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
228228
auto op = arith::RemSIOp::create(b, args[0], args[1]);
229229
auto eqOp = arith::CmpIOp::create(b, arith::CmpIPredicate::eq, op,
230230
arith::ConstantIndexOp::create(b, 0));
231-
return eqOp->getResult(0);
231+
auto indexCastOp =
232+
arith::IndexCastOp::create(b, b.getIndexType(), eqOp);
233+
return indexCastOp->getResult(0);
232234
}
233235
isl_ast_expr_op_type opType = isl_ast_expr_get_op_type(expr);
234236
char* cStr = isl_ast_expr_to_C_str(expr);
@@ -317,9 +319,12 @@ FailureOr<scf::ValueVector> MLIRLoopNestGenerator::visitAstNodeIf(
317319
isl_ast_expr_free(cond);
318320

319321
// Build scf if operation with the result types of the iter args
320-
auto ifOp =
321-
scf::IfOp::create(builder_, currentLoc_, TypeRange(currentIterArgs_),
322-
condVal, /*addThenBlock=*/true, /*addElseBlock=*/true);
322+
// Convert condVal to an i1
323+
auto condValI1 =
324+
arith::IndexCastOp::create(builder_, builder_.getI1Type(), condVal);
325+
auto ifOp = scf::IfOp::create(builder_, currentLoc_,
326+
TypeRange(currentIterArgs_), condValI1,
327+
/*addThenBlock=*/true, /*addElseBlock=*/true);
323328

324329
// TODO:(#2120): Handle ISL else conditions.
325330
isl_ast_node* elseNode = isl_ast_node_if_get_else_node(node);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_interpreter_test")
2+
3+
# See README.md for setup required to run these tests
4+
load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
5+
6+
package(default_applicable_licenses = ["@heir//:license"])
7+
8+
openfhe_interpreter_test(
9+
name = "lenet_test",
10+
timeout = "eternal",
11+
generated_heir_opt_filename = "lenet_heir_opt.mlir",
12+
heir_opt_flags = [
13+
"--annotate-module=backend=openfhe scheme=ckks",
14+
"--torch-linalg-to-ckks=ciphertext-degree=1024",
15+
"--scheme-to-openfhe",
16+
],
17+
mlir_src = "lenet.mlir",
18+
tags = [
19+
"manual",
20+
"notap",
21+
],
22+
test_src = "lenet_test.cpp",
23+
deps = [
24+
"@llvm-project//mlir:IR",
25+
"@llvm-project//mlir:Parser",
26+
],
27+
)
28+
29+
cc_binary(
30+
name = "lenet_binary",
31+
srcs = ["lenet_main.cpp"],
32+
data = [
33+
"lenet.openfhe.mlir",
34+
],
35+
tags = [
36+
"manual",
37+
"notap",
38+
],
39+
deps = [
40+
"@bazel_tools//tools/cpp/runfiles",
41+
"@heir//lib/Target/OpenFhePke:Interpreter",
42+
"@llvm-project//mlir:IR",
43+
"@llvm-project//mlir:Parser",
44+
"@openfhe//:core",
45+
"@openfhe//:pke",
46+
],
47+
)

tests/Examples/openfhe/ckks/lenet/lenet.mlir

Lines changed: 120 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)