@@ -182,12 +182,35 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
182182 dyn_cast<lwe::LWEPlaintextType>(lhs.getType ());
183183 lwe::LWEPlaintextType rhsPlaintextType =
184184 dyn_cast<lwe::LWEPlaintextType>(rhs.getType ());
185+ RankedTensorType lhsTensorType = dyn_cast<RankedTensorType>(lhs.getType ());
186+ RankedTensorType rhsTensorType = dyn_cast<RankedTensorType>(rhs.getType ());
187+
188+ // Plaintext-Cleartext case
189+ if (lhsPlaintextType && rhsTensorType) {
190+ auto encodedLhs = cast<lwe::RLWEEncodeOp>(lhs.getDefiningOp ());
191+ return arith::MulFOp::create (builder, encodedLhs.getInput (), rhs)
192+ .getResult ();
193+ }
194+ if (lhsTensorType && rhsPlaintextType) {
195+ auto encodedRhs = cast<lwe::RLWEEncodeOp>(rhs.getDefiningOp ());
196+ return arith::MulFOp::create (builder, lhs, encodedRhs.getInput ())
197+ .getResult ();
198+ }
199+
200+ // Plaintext-plaintext case
201+ if (lhsPlaintextType && rhsPlaintextType) {
202+ auto encodedLhs = cast<lwe::RLWEEncodeOp>(lhs.getDefiningOp ());
203+ auto encodedRhs = cast<lwe::RLWEEncodeOp>(rhs.getDefiningOp ());
204+ auto cleartextOp = arith::MulFOp::create (builder, encodedLhs.getInput (),
205+ encodedRhs.getInput ());
206+ return lwe::RLWEEncodeOp::create (
207+ builder, lhsPlaintextType, cleartextOp.getResult (),
208+ lhsPlaintextType.getPlaintextSpace ().getEncoding (),
209+ lhsPlaintextType.getPlaintextSpace ().getRing ());
210+ }
185211
186212 // Ciphertext-Plaintext case
187213 if (lhsCiphertextType && rhsPlaintextType) {
188- LLVM_DEBUG (llvm::dbgs ()
189- << " Handling Ct-Pt mul:\n\n lhs type = " << lhsCiphertextType
190- << " ,\n\n rhs type = " << rhsPlaintextType << " \n\n " );
191214 auto newRhs = encodeCleartextOperand (lhsCiphertextType, rhs,
192215 /* useDefaultScale=*/ true );
193216 auto ctPtOp = ckks::MulPlainOp::create (builder, lhs, newRhs);
@@ -196,9 +219,24 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
196219 /* rescale=*/ rescaleAfterCtPtMul);
197220 }
198221 if (lhsPlaintextType && rhsCiphertextType) {
199- LLVM_DEBUG (llvm::dbgs ()
200- << " Handling Ct-Pt mul(1):\n\n lhs type = " << lhsPlaintextType
201- << " ,\n\n rhs type = " << rhsCiphertextType << " \n\n " );
222+ auto newLhs = encodeCleartextOperand (rhsCiphertextType, lhs,
223+ /* useDefaultScale=*/ true );
224+ auto ctPtOp = ckks::MulPlainOp::create (builder, newLhs, rhs);
225+ return relinAndRescale (ctPtOp.getResult (),
226+ /* relinearize=*/ false ,
227+ /* rescale=*/ rescaleAfterCtPtMul);
228+ }
229+
230+ // Ciphertext-Cleartext case
231+ if (lhsCiphertextType && rhsTensorType) {
232+ auto newRhs = encodeCleartextOperand (lhsCiphertextType, rhs,
233+ /* useDefaultScale=*/ true );
234+ auto ctPtOp = ckks::MulPlainOp::create (builder, lhs, newRhs);
235+ return relinAndRescale (ctPtOp.getResult (),
236+ /* relinearize=*/ false ,
237+ /* rescale=*/ rescaleAfterCtPtMul);
238+ }
239+ if (lhsTensorType && rhsCiphertextType) {
202240 auto newLhs = encodeCleartextOperand (rhsCiphertextType, lhs,
203241 /* useDefaultScale=*/ true );
204242 auto ctPtOp = ckks::MulPlainOp::create (builder, newLhs, rhs);
@@ -209,9 +247,6 @@ Value IRMaterializingVisitor::operator()(const MultiplyNode<SSAValue>& node) {
209247
210248 // Ciphertext-ciphertext case
211249 assert (lhsCiphertextType && rhsCiphertextType);
212- LLVM_DEBUG (llvm::dbgs () << " Handling Ct-Ct mul(2):\n\n lhs type = "
213- << lhsCiphertextType << " ,\n\n rhs type = "
214- << rhsCiphertextType << " \n\n " );
215250 auto ctCtOp = ckks::MulOp::create (builder, lhs, rhs).getResult ();
216251 return relinAndRescale (ctCtOp, /* relinearize=*/ true ,
217252 /* rescale=*/ true );
0 commit comments