Skip to content

Commit 2ba78b8

Browse files
committed
fix type inference on mul visitor
1 parent 4ecb2f2 commit 2ba78b8

File tree

2 files changed

+45
-11
lines changed

2 files changed

+45
-11
lines changed

lib/Dialect/Orion/Conversions/OrionToCKKS/IRMaterializingVisitor.cpp

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -182,12 +182,35 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
182182
dyn_cast<lwe::LWEPlaintextType>(lhs.getType());
183183
lwe::LWEPlaintextType rhsPlaintextType =
184184
dyn_cast<lwe::LWEPlaintextType>(rhs.getType());
185+
RankedTensorType lhsTensorType = dyn_cast<RankedTensorType>(lhs.getType());
186+
RankedTensorType rhsTensorType = dyn_cast<RankedTensorType>(rhs.getType());
187+
188+
// Plaintext-Cleartext case
189+
if (lhsPlaintextType && rhsTensorType) {
190+
auto encodedLhs = cast<lwe::RLWEEncodeOp>(lhs.getDefiningOp());
191+
return arith::MulFOp::create(builder, encodedLhs.getInput(), rhs)
192+
.getResult();
193+
}
194+
if (lhsTensorType && rhsPlaintextType) {
195+
auto encodedRhs = cast<lwe::RLWEEncodeOp>(rhs.getDefiningOp());
196+
return arith::MulFOp::create(builder, lhs, encodedRhs.getInput())
197+
.getResult();
198+
}
199+
200+
// Plaintext-plaintext case
201+
if (lhsPlaintextType && rhsPlaintextType) {
202+
auto encodedLhs = cast<lwe::RLWEEncodeOp>(lhs.getDefiningOp());
203+
auto encodedRhs = cast<lwe::RLWEEncodeOp>(rhs.getDefiningOp());
204+
auto cleartextOp = arith::MulFOp::create(builder, encodedLhs.getInput(),
205+
encodedRhs.getInput());
206+
return lwe::RLWEEncodeOp::create(
207+
builder, lhsPlaintextType, cleartextOp.getResult(),
208+
lhsPlaintextType.getPlaintextSpace().getEncoding(),
209+
lhsPlaintextType.getPlaintextSpace().getRing());
210+
}
185211

186212
// Ciphertext-Plaintext case
187213
if (lhsCiphertextType && rhsPlaintextType) {
188-
LLVM_DEBUG(llvm::dbgs()
189-
<< "Handling Ct-Pt mul:\n\n lhs type = " << lhsCiphertextType
190-
<< ",\n\n rhs type = " << rhsPlaintextType << "\n\n");
191214
auto newRhs = encodeCleartextOperand(lhsCiphertextType, rhs,
192215
/*useDefaultScale=*/true);
193216
auto ctPtOp = ckks::MulPlainOp::create(builder, lhs, newRhs);
@@ -196,9 +219,24 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
196219
/*rescale=*/rescaleAfterCtPtMul);
197220
}
198221
if (lhsPlaintextType && rhsCiphertextType) {
199-
LLVM_DEBUG(llvm::dbgs()
200-
<< "Handling Ct-Pt mul(1):\n\n lhs type = " << lhsPlaintextType
201-
<< ",\n\n rhs type = " << rhsCiphertextType << "\n\n");
222+
auto newLhs = encodeCleartextOperand(rhsCiphertextType, lhs,
223+
/*useDefaultScale=*/true);
224+
auto ctPtOp = ckks::MulPlainOp::create(builder, newLhs, rhs);
225+
return relinAndRescale(ctPtOp.getResult(),
226+
/*relinearize=*/false,
227+
/*rescale=*/rescaleAfterCtPtMul);
228+
}
229+
230+
// Ciphertext-Cleartext case
231+
if (lhsCiphertextType && rhsTensorType) {
232+
auto newRhs = encodeCleartextOperand(lhsCiphertextType, rhs,
233+
/*useDefaultScale=*/true);
234+
auto ctPtOp = ckks::MulPlainOp::create(builder, lhs, newRhs);
235+
return relinAndRescale(ctPtOp.getResult(),
236+
/*relinearize=*/false,
237+
/*rescale=*/rescaleAfterCtPtMul);
238+
}
239+
if (lhsTensorType && rhsCiphertextType) {
202240
auto newLhs = encodeCleartextOperand(rhsCiphertextType, lhs,
203241
/*useDefaultScale=*/true);
204242
auto ctPtOp = ckks::MulPlainOp::create(builder, newLhs, rhs);
@@ -209,9 +247,6 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
209247

210248
// Ciphertext-ciphertext case
211249
assert(lhsCiphertextType && rhsCiphertextType);
212-
LLVM_DEBUG(llvm::dbgs() << "Handling Ct-Ct mul(2):\n\n lhs type = "
213-
<< lhsCiphertextType << ",\n\n rhs type = "
214-
<< rhsCiphertextType << "\n\n");
215250
auto ctCtOp = ckks::MulOp::create(builder, lhs, rhs).getResult();
216251
return relinAndRescale(ctCtOp, /*relinearize=*/true,
217252
/*rescale=*/true);

lib/Dialect/Orion/Conversions/OrionToCKKS/OrionToCKKS.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,9 @@ WalkResult handleMul(ckks::MulOp op) {
487487
int64_t operandIndex = (lhsLevel < rhsLevel) ? 1 : 0;
488488
debugLevelAndScale(levelReduceOp.getResult().getType(), "operand");
489489
op->setOperand(operandIndex, levelReduceOp.getResult());
490-
return handleInferTypeOpInterface(op);
491490
}
492491

493-
return WalkResult::advance();
492+
return handleInferTypeOpInterface(op);
494493
}
495494

496495
struct OrionToCKKS : public impl::OrionToCKKSBase<OrionToCKKS> {

0 commit comments

Comments
 (0)