Skip to content

Commit 40e8437

Browse files
asraacopybara-github
authored andcommitted
[do not submit] test if plaintexts for a conv are all zeros
This tests to see if any of the plaintext diagonals from a matvec (part of conv) were all zeros and could be eliminated. PiperOrigin-RevId: 836692282
1 parent 72800d5 commit 40e8437

File tree

7 files changed

+312
-17
lines changed

7 files changed

+312
-17
lines changed

lib/Target/OpenFhePke/Interpreter.cpp

Lines changed: 61 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <functional>
88
#include <map>
99
#include <memory>
10+
#include <ostream>
1011
#include <string>
1112
#include <utility>
1213
#include <variant>
@@ -176,23 +177,23 @@ std::vector<TypedCppValue> Interpreter::interpret(
176177

177178
void Interpreter::visit(Operation* op) {
178179
llvm::TypeSwitch<Operation*>(op)
179-
.Case<arith::ConstantOp, arith::AddIOp, arith::SubIOp, arith::MulIOp,
180-
arith::MulFOp, arith::DivSIOp, arith::RemSIOp, arith::AndIOp,
181-
arith::CmpIOp, arith::SelectOp, arith::ExtUIOp, arith::ExtSIOp,
182-
arith::ExtFOp, arith::IndexCastOp, arith::MinSIOp, arith::MaxSIOp,
183-
tensor::EmptyOp, tensor::ExtractOp, tensor::InsertOp,
184-
tensor::SplatOp, tensor::FromElementsOp, tensor::ConcatOp,
185-
tensor::ExtractSliceOp, tensor::InsertSliceOp,
186-
tensor::CollapseShapeOp, tensor::ExpandShapeOp, linalg::BroadcastOp,
187-
scf::ForOp, scf::IfOp, scf::YieldOp, affine::AffineForOp,
188-
affine::AffineYieldOp, lwe::RLWEDecodeOp, AddOp, AddPlainOp, SubOp,
189-
SubPlainOp, MulOp, MulNoRelinOp, MulPlainOp, MulConstOp, NegateOp,
190-
SquareOp, RelinOp, ModReduceOp, LevelReduceOp, RotOp, AutomorphOp,
191-
KeySwitchOp, BootstrapOp, EncryptOp, DecryptOp,
192-
MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp, GenParamsOp,
193-
GenContextOp, GenRotKeyOp, GenMulKeyOp, GenBootstrapKeyOp,
194-
SetupBootstrapOp, FastRotationOp, FastRotationPrecomputeOp>(
195-
[&](auto op) { visit(op); })
180+
.Case<arith::ConstantOp, arith::AddIOp, arith::AddFOp, arith::SubIOp,
181+
arith::MulIOp, arith::MulFOp, arith::DivSIOp, arith::RemSIOp,
182+
arith::AndIOp, arith::CmpIOp, arith::SelectOp, arith::ExtUIOp,
183+
arith::ExtSIOp, arith::ExtFOp, arith::FloorDivSIOp,
184+
arith::IndexCastOp, arith::MinSIOp, arith::MaxSIOp, tensor::EmptyOp,
185+
tensor::ExtractOp, tensor::InsertOp, tensor::SplatOp,
186+
tensor::FromElementsOp, tensor::ConcatOp, tensor::ExtractSliceOp,
187+
tensor::InsertSliceOp, tensor::CollapseShapeOp,
188+
tensor::ExpandShapeOp, linalg::BroadcastOp, scf::ForOp, scf::IfOp,
189+
scf::YieldOp, affine::AffineForOp, affine::AffineYieldOp,
190+
lwe::RLWEDecodeOp, AddOp, AddPlainOp, SubOp, SubPlainOp, MulOp,
191+
MulNoRelinOp, MulPlainOp, MulConstOp, NegateOp, SquareOp, RelinOp,
192+
ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp,
193+
BootstrapOp, EncryptOp, DecryptOp, MakePackedPlaintextOp,
194+
MakeCKKSPackedPlaintextOp, GenParamsOp, GenContextOp, GenRotKeyOp,
195+
GenMulKeyOp, GenBootstrapKeyOp, SetupBootstrapOp, FastRotationOp,
196+
FastRotationPrecomputeOp>([&](auto op) { visit(op); })
196197
.Default([&](Operation* op) {
197198
OperationName opName = op->getName();
198199
op->emitError() << "Unsupported operation " << opName.getStringRef()
@@ -305,6 +306,15 @@ void Interpreter::visit(arith::AddIOp op) {
305306
env.insert_or_assign(op.getResult(), std::move(result));
306307
}
307308

309+
void Interpreter::visit(arith::AddFOp op) {
310+
auto lhs = env.at(op.getLhs());
311+
auto rhs = env.at(op.getRhs());
312+
auto result = applyBinop(
313+
op, lhs, rhs, [](int a, int b) { return a + b; },
314+
[](float a, float b) { return a + b; });
315+
env.insert_or_assign(op.getResult(), std::move(result));
316+
}
317+
308318
void Interpreter::visit(arith::SubIOp op) {
309319
auto lhs = env.at(op.getLhs());
310320
auto rhs = env.at(op.getRhs());
@@ -442,6 +452,16 @@ void Interpreter::visit(arith::ExtUIOp op) {
442452
}
443453
}
444454

455+
void Interpreter::visit(arith::FloorDivSIOp op) {
456+
auto lhs = env.at(op.getLhs());
457+
auto rhs = env.at(op.getRhs());
458+
auto result = applyBinop(
459+
op, lhs, rhs,
460+
[](int a, int b) { return std::floor(static_cast<float>(a) / b); },
461+
[](float a, float b) { return std::floor(a / b); });
462+
env.insert_or_assign(op.getResult(), std::move(result));
463+
}
464+
445465
void Interpreter::visit(arith::ExtSIOp op) {
446466
auto operand = env.at(op.getIn());
447467
// For signed extension, we just convert the value to a larger type
@@ -1381,7 +1401,31 @@ void Interpreter::visit(MakePackedPlaintextOp op) {
13811401
}
13821402

13831403
void Interpreter::visit(MakeCKKSPackedPlaintextOp op) {
1404+
auto checkIfZero = [&](TypedCppValue value) -> bool {
1405+
if (std::holds_alternative<std::vector<float>>(value.value)) {
1406+
auto vec = std::get<std::vector<float>>(value.value);
1407+
if (llvm::all_of(vec, [](float v) { return v == 0.0f; })) {
1408+
return true;
1409+
}
1410+
} else if (std::holds_alternative<std::vector<double>>(value.value)) {
1411+
auto vec = std::get<std::vector<double>>(value.value);
1412+
if (llvm::all_of(vec, [](double v) { return v == 0.0; })) {
1413+
return true;
1414+
}
1415+
} else if (std::holds_alternative<std::vector<int>>(value.value)) {
1416+
auto vec = std::get<std::vector<int>>(value.value);
1417+
if (llvm::all_of(vec, [](int v) { return v == 0; })) {
1418+
return true;
1419+
}
1420+
}
1421+
return false;
1422+
};
1423+
13841424
auto value = env.at(op.getValue());
1425+
if (checkIfZero(value)) {
1426+
op.dump();
1427+
std::cout << "MakeCKKSPackedPlaintextOp: value is zero" << std::endl;
1428+
}
13851429
auto cc = std::get<CryptoContextT>(env.at(op.getCryptoContext()).value);
13861430
if (std::holds_alternative<std::vector<float>>(value.value)) {
13871431
auto vec = std::get<std::vector<float>>(value.value);

lib/Target/OpenFhePke/Interpreter.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,15 @@ class Interpreter {
9797

9898
// Upstream ops
9999
void visit(arith::AddIOp op);
100+
void visit(arith::AddFOp op);
100101
void visit(arith::AndIOp op);
101102
void visit(arith::CmpIOp op);
102103
void visit(arith::ConstantOp op);
103104
void visit(arith::DivSIOp op);
104105
void visit(arith::ExtFOp op);
105106
void visit(arith::ExtSIOp op);
106107
void visit(arith::ExtUIOp op);
108+
void visit(arith::FloorDivSIOp op);
107109
void visit(arith::IndexCastOp op);
108110
void visit(arith::MulIOp op);
109111
void visit(arith::MulFOp op);

lib/Target/OpenFhePke/InterpreterTest.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,26 @@ TEST(InterpreterTest, TestAdd) {
6868
EXPECT_EQ(std::get<int>(results[0].value), 7);
6969
}
7070

71+
TEST(InterpreterTest, TestFloorDivSI) {
72+
MLIRContext context;
73+
initContext(context);
74+
auto module = parseTest(&context, R"mlir(
75+
module {
76+
func.func @main(%a: i32, %b: i32) -> i32 {
77+
%c = arith.floordivsi %a, %b : i32
78+
return %c : i32
79+
}
80+
}
81+
)mlir");
82+
Interpreter interpreter(module.get());
83+
std::string entryFunction = "main";
84+
std::vector<TypedCppValue> inputs = {TypedCppValue(-7), TypedCppValue(3)};
85+
std::vector<TypedCppValue> results =
86+
interpreter.interpret(entryFunction, inputs);
87+
EXPECT_EQ(results.size(), 1);
88+
EXPECT_EQ(std::get<int>(results[0].value), -3);
89+
}
90+
7191
TEST(InterpreterTest, TestElementwiseAdd) {
7292
MLIRContext context;
7393
initContext(context);

tests/Examples/common/lenet.mlir

Lines changed: 120 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
load("@heir//tests/Examples/openfhe:test.bzl", "openfhe_interpreter_test")
2+
3+
package(default_applicable_licenses = ["@heir//:license"])
4+
5+
openfhe_interpreter_test(
6+
name = "conv_layer_interpreter_test",
7+
copts = select({
8+
"@heir//:config_openfhe_enable_timing": ["-DOPENFHE_ENABLE_TIMING"],
9+
"//conditions:default": [],
10+
}),
11+
generated_heir_opt_filename = "module.openfhe.mlir",
12+
heir_opt_flags = [
13+
"--annotate-module=backend=openfhe scheme=ckks",
14+
"--torch-linalg-to-ckks=ciphertext-degree=8192",
15+
"--scheme-to-openfhe",
16+
],
17+
mlir_src = "@heir//tests/Examples/openfhe/ckks/conv_diagonal:conv_layer.mlir",
18+
test_src = "conv_layer_interpreter_test.cpp",
19+
deps = [
20+
"@llvm-project//mlir:IR",
21+
"@llvm-project//mlir:Parser",
22+
],
23+
)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
2+
#map1 = affine_map<(d0, d1) -> (d0, d1)>
3+
#map2 = affine_map<(d0, d1) -> (d1)>
4+
module {
5+
func.func @conv(%arg0: tensor<1x1x32x32xf32> {secret.secret}) -> tensor<1x6x28x28xf32> {
6+
%cst = arith.constant dense_resource<torch_tensor_6_torch.float32> : tensor<6xf32>
7+
%cst_3 = arith.constant dense_resource<torch_tensor_6_1_5_5_torch.float32> : tensor<6x1x5x5xf32>
8+
%0 = tensor.empty() : tensor<1x6x28x28xf32>
9+
%broadcasted = linalg.broadcast ins(%cst : tensor<6xf32>) outs(%0 : tensor<1x6x28x28xf32>) dimensions = [0, 2, 3]
10+
%1 = linalg.conv_2d_nchw_fchw {dilations = dense<1> : vector<2xi64>, strides = dense<1> : vector<2xi64>} ins(%arg0, %cst_3 : tensor<1x1x32x32xf32>, tensor<6x1x5x5xf32>) outs(%broadcasted : tensor<1x6x28x28xf32>) -> tensor<1x6x28x28xf32>
11+
return %1 : tensor<1x6x28x28xf32>
12+
}
13+
}
14+
15+
{-#
16+
dialect_resources: {
17+
builtin: {
18+
torch_tensor_6_torch.float32: "0x0400000008B030BEDAE5A43C7D68BF3D7881233ECDE4E13DE065E2BD",
19+
torch_tensor_6_1_5_5_torch.float32: "0x04000000E75A4ABC47F4A6BDF371873C45C5423E7D09EABD4D3E0C3EC74455BD8D00FFBC7A370B3DABA7283E0BB749BE6DD9E7BD20EC49BD12E338BE2DF820BDFD38DABDAD26A3BD2AB6C2BD800DEDBD32E722BEC013FA3DE328323E4261253E002AB4BD80910DBC6DE3EDBD3522023EDADC263E40D5CFBC720F153E93A019BD5DF6CFBD77B8FE3D67A2EE3C5018DBBD32E541BEF3920BBE22612A3E9A8215BC67A286BD004F81BD47EE803DC0276C3D73C0F6BDCADB283E9770DDBDBD173FBE80D1343E335C383E1A4ED93C3339E83C7005433EBA0F74BD9A620ABE07241FBDDA51F6BDA01C293D80105B3D78E6213E9D52EABD35C135BE67AFAE3C577C0CBE0A68D1BDA080EE3DC072F8BD5F3533BEB5AA36BEA74E37BDB57C443EB01193BD33ECD0BB40B4CEBDCB542C3E5A31C9BD0DC05ABD059A0F3EB31F98BD03A0FDBD609019BE0038FEBC42D319BEFD8DEE3DDDFAB43D537C9B3D9AE1B53AAB8C1DBEC066CFBD5F8417BEB34135BE9ACE9ABDF066DBBD4070203E57902DBECD848BBC67C7A2BBE792073E2087E23D6D4035BEE7A331BD9AD19E3C6788F0BD6AAB9E3DEDA1E2BDEAA4063E8A9D00BE1F6025BE735E0B3E4A4D883D9DE8A43D1790E03DA04C71BD5771ABBD0D0E5D3D8DAE953D93570D3D5DBAE7BDA7CA103EF3D290BC138849BE1DD82B3E7013073ECAAD483ECDF9573C425103BE530FAF3D4AF5C7BD8012B3BDA75A703D334B683D7A5BFC3D1DE83FBEE76374BC4AF031BE977688BDEAC638BE4D0BA8BC7A23CBBD73F2D2BD4BD318BE288129BEE3C7473EA2B72EBE1D874BBE139A3D3E42614CBEE70A523D5BC1413E67FCAABD20D02E3D"
20+
}
21+
}
22+
#-}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#include <string>
2+
#include <utility>
3+
#include <vector>
4+
5+
#include "gtest/gtest.h" // from @googletest
6+
#include "lib/Target/OpenFhePke/Interpreter.h"
7+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
8+
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
9+
#include "mlir/include/mlir/IR/OwningOpRef.h" // from @llvm-project
10+
#include "mlir/include/mlir/Parser/Parser.h" // from @llvm-project
11+
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
12+
#include "src/pke/include/openfhe.h" // from @openfhe
13+
14+
namespace mlir {
15+
namespace heir {
16+
namespace openfhe {
17+
18+
using namespace lbcrypto;
19+
using CryptoContextT = CryptoContext<DCRTPoly>;
20+
21+
// Copybara manages this declaration via regex
22+
const std::string WORKSPACE_BASE = "";
23+
24+
TEST(MatmulInterpreterTest, RunTest) {
25+
// Generated by the bazel rule
26+
std::string generatedModuleFile =
27+
WORKSPACE_BASE +
28+
"tests/Examples/openfhe/ckks/conv_diagonal/module.openfhe.mlir";
29+
30+
// Load the MLIR module from a file
31+
MLIRContext context;
32+
initContext(context);
33+
OwningOpRef<ModuleOp> module =
34+
parseSourceFile<ModuleOp>(generatedModuleFile, &context);
35+
Interpreter interpreter(module.get());
36+
37+
TypedCppValue ccInitial =
38+
interpreter.interpret("conv__generate_crypto_context", {})[0];
39+
40+
auto keyPair = std::get<CryptoContextT>(ccInitial.value)->KeyGen();
41+
auto publicKey = keyPair.publicKey;
42+
auto secretKey = keyPair.secretKey;
43+
std::vector<TypedCppValue> args = {ccInitial, TypedCppValue(secretKey)};
44+
TypedCppValue cc = std::move(
45+
interpreter.interpret("conv__configure_crypto_context", args)[0]);
46+
47+
// tensor<1x1x32x32xf32>
48+
std::vector<float> arg0Vals(32 * 32, 0.1f);
49+
50+
TypedCppValue arg0Enc = interpreter.interpret(
51+
"conv__encrypt__arg0",
52+
{cc, TypedCppValue(arg0Vals), TypedCppValue(publicKey)})[0];
53+
54+
TypedCppValue outputEncrypted =
55+
interpreter.interpret("conv", {cc, arg0Enc})[0];
56+
57+
#ifdef OPENFHE_ENABLE_TIMING
58+
interpreter.printTimingResults();
59+
#endif
60+
}
61+
62+
} // namespace openfhe
63+
} // namespace heir
64+
} // namespace mlir

0 commit comments

Comments
 (0)