Skip to content

Commit 0fe389a

Browse files
Merge pull request #1906 from ZenithalHourlyRate:mgmt-init-multi-use
PiperOrigin-RevId: 771058735
2 parents c80991f + e4190c5 commit 0fe389a

File tree

9 files changed

+106
-16
lines changed

9 files changed

+106
-16
lines changed

lib/Dialect/Mgmt/IR/MgmtOps.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ namespace mlir {
77
namespace heir {
88
namespace mgmt {
99

10+
//===----------------------------------------------------------------------===//
11+
// Canonicalization Patterns
12+
//===----------------------------------------------------------------------===//
13+
1014
// Kept inside a namespace because it generates a function called
1115
// populateWithGenerated, which can conflict with other generated patterns.
1216
#include "lib/Dialect/Mgmt/IR/MgmtCanonicalization.cpp.inc"
@@ -28,6 +32,16 @@ void AdjustScaleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2832
results.add<ModReduceAfterAdjustScale>(context);
2933
}
3034

35+
//===----------------------------------------------------------------------===//
36+
// Utils
37+
//===----------------------------------------------------------------------===//
38+
39+
void cleanupInitOp(Operation *top) {
40+
top->walk([&](mgmt::InitOp initOp) {
41+
if (initOp->use_empty()) initOp.erase();
42+
});
43+
}
44+
3145
} // namespace mgmt
3246
} // namespace heir
3347
} // namespace mlir

lib/Dialect/Mgmt/IR/MgmtOps.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,4 +12,15 @@
1212
#define GET_OP_CLASSES
1313
#include "lib/Dialect/Mgmt/IR/MgmtOps.h.inc"
1414

15+
namespace mlir {
16+
namespace heir {
17+
namespace mgmt {
18+
19+
/// Remove all unused mgmt.init ops from the top operation.
20+
void cleanupInitOp(Operation *top);
21+
22+
} // namespace mgmt
23+
} // namespace heir
24+
} // namespace mlir
25+
1526
#endif // LIB_DIALECT_MGMT_IR_MGMTOPS_H_

lib/Dialect/Mgmt/IR/MgmtOps.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def Mgmt_AdjustScaleOp : Mgmt_Op<"adjust_scale"> {
143143
let hasCanonicalizer = 1;
144144
}
145145

146-
def Mgmt_InitOp : Mgmt_Op<"init"> {
146+
def Mgmt_InitOp : Mgmt_Op<"init",
147+
[MemoryEffects<[MemWrite]>, ElementwiseMappable, SameOperandsAndResultType, ConditionallySpeculatable]> {
148+
147149
let summary = "Init the plaintext with mgmt attributes";
148150

149151
let description = [{
@@ -161,13 +163,30 @@ def Mgmt_InitOp : Mgmt_Op<"init"> {
161163

162164
To address the problem, for each _use_ of the plaintext, we insert an `mgmt.init`
163165
operation to initialize the plaintext with `mgmt` attributes.
166+
167+
Technical reasons for registering memory effects:
168+
169+
Register a (bogus) memory effect to prevent CSE from merging this op.
170+
Two mgmt.init ops could be seen as equivalent only if they have the same
171+
MgmtAttr with *level/dimension/scale* annotated, otherwise we could not
172+
judge whether they are equivalent or not. In practice, we create the op first
173+
and only in later analyses we know whether they are equivalent or not.
174+
175+
ConditionallySpeculatable is for isSpeculatable check in hoisting canonicalization.
164176
}];
165177

166178
let arguments = (ins
167179
AnyType:$input
168180
);
169181
let results = (outs AnyType:$output);
170182
let assemblyFormat = "operands attr-dict `:` type($output)";
183+
184+
let extraClassDeclaration = [{
185+
/// Interface method for ConditionallySpeculatable.
186+
Speculation::Speculatability getSpeculatability() {
187+
return Speculation::Speculatable;
188+
}
189+
}];
171190
}
172191

173192
#endif // LIB_DIALECT_MGMT_IR_MGMTOPS_TD_

lib/Dialect/Secret/Conversions/SecretToBGV/SecretToBGV.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ struct SecretToBGV : public impl::SecretToBGVBase<SecretToBGV> {
274274
}
275275

276276
clearAttrs(getOperation(), mgmt::MgmtDialect::kArgMgmtAttrName);
277+
mgmt::cleanupInitOp(getOperation());
277278
}
278279
};
279280

lib/Dialect/Secret/Conversions/SecretToCKKS/SecretToCKKS.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ struct SecretToCKKS : public impl::SecretToCKKSBase<SecretToCKKS> {
407407
}
408408

409409
clearAttrs(getOperation(), mgmt::MgmtDialect::kArgMgmtAttrName);
410+
mgmt::cleanupInitOp(getOperation());
410411
}
411412
};
412413

lib/Transforms/SecretInsertMgmt/SecretInsertMgmtBFV.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ struct SecretInsertMgmtBFV
9393
(void)walkAndApplyPatterns(getOperation(), std::move(patternsRelinearize));
9494

9595
auto level = maxMulDepth;
96-
// call Canonicalizer here because mgmt.init ops need to be moved out of the
97-
// secret.generic.
98-
// annotate mgmt attribute with all levels set to mulDepth
96+
// 1. Canonicalizer moves mgmt::InitOp out of secret.generic.
97+
// 2. AnnotateMgmt will merge level and dimension into MgmtAttr, for further
98+
// lowering. For B/FV, all levels should be set to mulDepth.
9999
OpPassManager pipeline("builtin.module");
100100
pipeline.addPass(createCanonicalizerPass());
101101
mgmt::AnnotateMgmtOptions annotateMgmtOptions;

lib/Transforms/SecretInsertMgmt/SecretInsertMgmtBGV.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ struct SecretInsertMgmtBGV
126126
(void)walkAndApplyPatterns(getOperation(), std::move(patternsMulDepth));
127127
}
128128

129-
// call Canonicalizer here because mgmt ops need to be ordered
130-
// call CSE here because there may be redundant mod reduce
131-
// one Value may get mod reduced multiple times in
132-
// multiple Uses
133-
//
134-
// also run annotate-mgmt for lowering
129+
// 1. Canonicalizer reorders mgmt ops like Rescale/LevelReduce/AdjustScale.
130+
// This is important for AnnotateMgmt.
131+
// Canonicalizer also moves mgmt::InitOp out of secret.generic.
132+
// 2. CSE removes redundant mgmt::ModReduceOp.
133+
// 3. AnnotateMgmt will merge level and dimension into MgmtAttr, for further
134+
// lowering.
135135
OpPassManager pipeline("builtin.module");
136136
pipeline.addPass(createCanonicalizerPass());
137137
pipeline.addPass(createCSEPass());

lib/Transforms/SecretInsertMgmt/SecretInsertMgmtCKKS.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -226,12 +226,12 @@ struct SecretInsertMgmtCKKS
226226
(void)walkAndApplyPatterns(getOperation(), std::move(patternsMulDepth));
227227
}
228228

229-
// call Canonicalizer here because mgmt ops need to be ordered
230-
// call CSE here because there may be redundant mod reduce
231-
// one Value may get mod reduced multiple times in
232-
// multiple Uses
233-
//
234-
// also run annotate-mgmt for lowering
229+
// 1. Canonicalizer reorders mgmt ops like Rescale/LevelReduce/AdjustScale.
230+
// This is important for AnnotateMgmt.
231+
// Canonicalizer also moves mgmt::InitOp out of secret.generic.
232+
// 2. CSE removes redundant mgmt::ModReduceOp.
233+
// 3. AnnotateMgmt will merge level and dimension into MgmtAttr, for further
234+
// lowering.
235235
OpPassManager pipeline("builtin.module");
236236
pipeline.addPass(createCanonicalizerPass());
237237
pipeline.addPass(createCSEPass());

tests/Transforms/secret_insert_mgmt/bgv/init.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,48 @@ module {
5252
} -> !secret.secret<i32>
5353
return %0 : !secret.secret<i32>
5454
}
55+
56+
// CHECK: @pt_multiple_uses
57+
// CHECK-SAME: (%[[arg0:.*]]: !secret.secret<i16> {mgmt.mgmt = #mgmt.mgmt<level = 2>},
58+
// CHECK-SAME: %[[arg1:.*]]: i16)
59+
func.func @pt_multiple_uses(%arg0: !secret.secret<i16>, %arg1: i16) -> (!secret.secret<i16>) {
60+
// CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : i16
61+
// CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt<level = 1>} : i16
62+
%0 = secret.generic(%arg0: !secret.secret<i16>) {
63+
^body(%input0: i16):
64+
// CHECK: arith.addi
65+
// CHECK-SAME: %[[v0]]
66+
%1 = arith.addi %input0, %arg1 : i16
67+
// CHECK: arith.muli
68+
%2 = arith.muli %1, %1 : i16
69+
// CHECK: arith.muli
70+
// CHECK-SAME: %[[v1]]
71+
%3 = arith.muli %2, %arg1 : i16
72+
secret.yield %3 : i16
73+
} -> !secret.secret<i16>
74+
return %0 : !secret.secret<i16>
75+
}
76+
77+
// CHECK: @pt_multiple_uses_2
78+
// CHECK-SAME: (%[[arg0:.*]]: !secret.secret<i16> {mgmt.mgmt = #mgmt.mgmt<level = 2>},
79+
// CHECK-SAME: %[[arg1:.*]]: i16)
80+
func.func @pt_multiple_uses_2(%arg0: !secret.secret<i16>, %arg1: i16) -> (!secret.secret<i16>) {
81+
// CHECK: %[[v0:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : i16
82+
// Note: these two mgmt.init should not merge, as later optimization like lazy relinearization
83+
// or populate-scale will make them different.
84+
// CHECK: %[[v1:.*]] = mgmt.init %[[arg1]] {mgmt.mgmt = #mgmt.mgmt<level = 2>} : i16
85+
%0 = secret.generic(%arg0: !secret.secret<i16>) {
86+
^body(%input0: i16):
87+
// CHECK: arith.addi
88+
// CHECK-SAME: %[[v0]]
89+
%1 = arith.addi %input0, %arg1 : i16
90+
// CHECK: arith.muli
91+
%2 = arith.muli %1, %1 : i16
92+
// CHECK: arith.addi
93+
// CHECK-SAME: %[[v1]]
94+
%3 = arith.addi %2, %arg1 : i16
95+
secret.yield %3 : i16
96+
} -> !secret.secret<i16>
97+
return %0 : !secret.secret<i16>
98+
}
5599
}

0 commit comments

Comments
 (0)