Skip to content

Commit d3f9d76

Browse files
asraacopybara-github
authored andcommitted
ckks: add optional targetLevel attribute to ckks.bootstrap
See the issue I filed #2436 PiperOrigin-RevId: 838756394
1 parent de41a81 commit d3f9d76

File tree

5 files changed

+66
-1
lines changed

5 files changed

+66
-1
lines changed

lib/Dialect/CKKS/IR/CKKSOps.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
#include "lib/Dialect/CKKS/IR/CKKSOps.h"
22

3+
#include <cstdint>
34
#include <optional>
45

56
#include "lib/Dialect/LWE/IR/LWEOps.h"
67
#include "lib/Dialect/LWE/IR/LWEPatterns.h"
8+
#include "lib/Dialect/LWE/IR/LWETypes.h"
79
#include "mlir/include/mlir/IR/Location.h" // from @llvm-project
810
#include "mlir/include/mlir/IR/MLIRContext.h" // from @llvm-project
911
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
@@ -33,6 +35,21 @@ LogicalResult RescaleOp::verify() {
3335

3436
LogicalResult LevelReduceOp::verify() { return lwe::verifyLevelReduceOp(this); }
3537

38+
LogicalResult BootstrapOp::verify() {
39+
std::optional<int64_t> targetLevel = getTargetLevel();
40+
if (targetLevel.has_value()) {
41+
// If a target level is specified, then the result ciphertext must have that
42+
// many levels.
43+
lwe::LWECiphertextType outputType = lwe::getCtTy(getOutput());
44+
if (outputType.getModulusChain().getCurrent() != targetLevel.value()) {
45+
return emitOpError() << "output ciphertext must have "
46+
<< targetLevel.value() << " levels but has "
47+
<< outputType.getModulusChain().getCurrent();
48+
}
49+
}
50+
return success();
51+
}
52+
3653
//===----------------------------------------------------------------------===//
3754
// Op type inference.
3855
//===----------------------------------------------------------------------===//

lib/Dialect/CKKS/IR/CKKSOps.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,13 +213,16 @@ def CKKS_BootstrapOp : CKKS_Op<"bootstrap", [ElementwiseMappable]> {
213213
}];
214214

215215
let arguments = (ins
216-
LWECiphertextLike:$input
216+
LWECiphertextLike:$input,
217+
OptionalAttr<I64Attr>:$targetLevel
217218
);
218219

219220
let results = (outs
220221
LWECiphertextLike:$output
221222
);
222223

224+
let hasVerifier = 1;
225+
223226
let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` qualified(type($output))" ;
224227
}
225228

lib/Dialect/LWE/Conversions/LWEToOpenfhe/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ cc_library(
1818
"@heir//lib/Dialect/CKKS/IR:Dialect",
1919
"@heir//lib/Dialect/LWE/IR:Dialect",
2020
"@heir//lib/Dialect/Openfhe/IR:Dialect",
21+
"@heir//lib/Parameters/CKKS:Params",
2122
"@heir//lib/Utils",
2223
"@heir//lib/Utils:ConversionUtils",
2324
"@llvm-project//llvm:Support",

lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "lib/Dialect/BGV/IR/BGVDialect.h"
66
#include "lib/Dialect/BGV/IR/BGVOps.h"
7+
#include "lib/Dialect/CKKS/IR/CKKSAttributes.h"
78
#include "lib/Dialect/CKKS/IR/CKKSDialect.h"
89
#include "lib/Dialect/CKKS/IR/CKKSOps.h"
910
#include "lib/Dialect/LWE/Conversions/LWEToOpenfhe/LWEToOpenfhe.h"
@@ -14,6 +15,7 @@
1415
#include "lib/Dialect/Openfhe/IR/OpenfheDialect.h"
1516
#include "lib/Dialect/Openfhe/IR/OpenfheOps.h"
1617
#include "lib/Dialect/Openfhe/IR/OpenfheTypes.h"
18+
#include "lib/Parameters/CKKS/Params.h"
1719
#include "lib/Utils/ConversionUtils.h"
1820
#include "lib/Utils/Utils.h"
1921
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
@@ -22,6 +24,7 @@
2224
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
2325
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
2426
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
27+
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
2528
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
2629
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
2730
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
@@ -275,6 +278,20 @@ struct ConvertBootstrapOp : public OpConversionPattern<ckks::BootstrapOp> {
275278
return rewriter.notifyMatchFailure(op, "No crypto context arg");
276279
}
277280

281+
// TODO(#2436): Support bootstrap target level in OpenFHE
282+
if (op.getTargetLevel().has_value()) {
283+
// Right now we don't support any bootstrap ops with a target level.
284+
// Ideally, we would want to check that the target level is equal to the
285+
// number of Qis available in the scheme parameters (max levels) minus the
286+
// levels consumed by bootstrapping to emit a full bootstrap op. The
287+
// latter info is not persisted in the IR. So we simply rely on higher
288+
// level passes with access to the bootstrap waterline to remove the
289+
// target level attribute.
290+
// TODO(#1207): Persist the number of consumed levels from bootstrapping
291+
return rewriter.notifyMatchFailure(
292+
op, "variadic bootstrapping is not supported in OpenFHE");
293+
}
294+
278295
Value cryptoContext = result.value();
279296
rewriter.replaceOpWithNewOp<openfhe::BootstrapOp>(
280297
op, op.getOutput().getType(), cryptoContext, adaptor.getInput());

tests/Dialect/CKKS/IR/verifier.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,30 @@ module attributes {ckks.schemeParam = #ckks.scheme_param<logN = 13, Q = [3602879
8888
return %ct_2 : !ct_L0_
8989
}
9090
}
91+
92+
// -----
93+
94+
!Z35184372121601_i64_ = !mod_arith.int<35184372121601 : i64>
95+
!Z36028797019389953_i64_ = !mod_arith.int<36028797019389953 : i64>
96+
// note the scaling factor is 45
97+
// after mul it should be 90
98+
#inverse_canonical_encoding = #lwe.inverse_canonical_encoding<scaling_factor = 45>
99+
#key = #lwe.key<>
100+
#modulus_chain_L1_C0_ = #lwe.modulus_chain<elements = <36028797019389953 : i64, 35184372121601 : i64>, current = 0>
101+
#modulus_chain_L1_C1_ = #lwe.modulus_chain<elements = <36028797019389953 : i64, 35184372121601 : i64>, current = 1>
102+
#ring_f64_1_x1024_ = #polynomial.ring<coefficientType = f64, polynomialModulus = <1 + x**1024>>
103+
!rns_L0_ = !rns.rns<!Z36028797019389953_i64_>
104+
!rns_L1_ = !rns.rns<!Z36028797019389953_i64_, !Z35184372121601_i64_>
105+
#ring_rns_L0_1_x1024_ = #polynomial.ring<coefficientType = !rns_L0_, polynomialModulus = <1 + x**1024>>
106+
#ring_rns_L1_1_x1024_ = #polynomial.ring<coefficientType = !rns_L1_, polynomialModulus = <1 + x**1024>>
107+
#ciphertext_space_L0_ = #lwe.ciphertext_space<ring = #ring_rns_L0_1_x1024_, encryption_type = lsb>
108+
#ciphertext_space_L1_ = #lwe.ciphertext_space<ring = #ring_rns_L1_1_x1024_, encryption_type = lsb>
109+
!ct_L0_ = !lwe.lwe_ciphertext<application_data = <message_type = i16>, plaintext_space = <ring = #ring_f64_1_x1024_, encoding = #inverse_canonical_encoding>, ciphertext_space = #ciphertext_space_L0_, key = #key, modulus_chain = #modulus_chain_L1_C0_>
110+
!ct_L1_ = !lwe.lwe_ciphertext<application_data = <message_type = i16>, plaintext_space = <ring = #ring_f64_1_x1024_, encoding = #inverse_canonical_encoding>, ciphertext_space = #ciphertext_space_L1_, key = #key, modulus_chain = #modulus_chain_L1_C1_>
111+
module attributes {ckks.schemeParam = #ckks.scheme_param<logN = 13, Q = [36028797019389953, 35184372121601], P = [36028797019488257], logDefaultScale = 45>, scheme.ckks} {
112+
func.func @bootstrap(%ct: !ct_L0_) -> !ct_L1_ {
113+
// expected-error@+1 {{'ckks.bootstrap' op output ciphertext must have 2 levels but has 1}}
114+
%ct_2 = ckks.bootstrap %ct {targetLevel = 2} : !ct_L0_ -> !ct_L1_
115+
return %ct_2 : !ct_L1_
116+
}
117+
}

0 commit comments

Comments
 (0)