|
7 | 7 | #include <functional> |
8 | 8 | #include <map> |
9 | 9 | #include <memory> |
| 10 | +#include <ostream> |
10 | 11 | #include <string> |
11 | 12 | #include <utility> |
12 | 13 | #include <variant> |
@@ -176,23 +177,23 @@ std::vector<TypedCppValue> Interpreter::interpret( |
176 | 177 |
|
177 | 178 | void Interpreter::visit(Operation* op) { |
178 | 179 | llvm::TypeSwitch<Operation*>(op) |
179 | | - .Case<arith::ConstantOp, arith::AddIOp, arith::SubIOp, arith::MulIOp, |
180 | | - arith::MulFOp, arith::DivSIOp, arith::RemSIOp, arith::AndIOp, |
181 | | - arith::CmpIOp, arith::SelectOp, arith::ExtUIOp, arith::ExtSIOp, |
182 | | - arith::ExtFOp, arith::IndexCastOp, arith::MinSIOp, arith::MaxSIOp, |
183 | | - tensor::EmptyOp, tensor::ExtractOp, tensor::InsertOp, |
184 | | - tensor::SplatOp, tensor::FromElementsOp, tensor::ConcatOp, |
185 | | - tensor::ExtractSliceOp, tensor::InsertSliceOp, |
186 | | - tensor::CollapseShapeOp, tensor::ExpandShapeOp, linalg::BroadcastOp, |
187 | | - scf::ForOp, scf::IfOp, scf::YieldOp, affine::AffineForOp, |
188 | | - affine::AffineYieldOp, lwe::RLWEDecodeOp, AddOp, AddPlainOp, SubOp, |
189 | | - SubPlainOp, MulOp, MulNoRelinOp, MulPlainOp, MulConstOp, NegateOp, |
190 | | - SquareOp, RelinOp, ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, |
191 | | - KeySwitchOp, BootstrapOp, EncryptOp, DecryptOp, |
192 | | - MakePackedPlaintextOp, MakeCKKSPackedPlaintextOp, GenParamsOp, |
193 | | - GenContextOp, GenRotKeyOp, GenMulKeyOp, GenBootstrapKeyOp, |
194 | | - SetupBootstrapOp, FastRotationOp, FastRotationPrecomputeOp>( |
195 | | - [&](auto op) { visit(op); }) |
| 180 | + .Case<arith::ConstantOp, arith::AddIOp, arith::AddFOp, arith::SubIOp, |
| 181 | + arith::MulIOp, arith::MulFOp, arith::DivSIOp, arith::RemSIOp, |
| 182 | + arith::AndIOp, arith::CmpIOp, arith::SelectOp, arith::ExtUIOp, |
| 183 | + arith::ExtSIOp, arith::ExtFOp, arith::FloorDivSIOp, |
| 184 | + arith::IndexCastOp, arith::MinSIOp, arith::MaxSIOp, tensor::EmptyOp, |
| 185 | + tensor::ExtractOp, tensor::InsertOp, tensor::SplatOp, |
| 186 | + tensor::FromElementsOp, tensor::ConcatOp, tensor::ExtractSliceOp, |
| 187 | + tensor::InsertSliceOp, tensor::CollapseShapeOp, |
| 188 | + tensor::ExpandShapeOp, linalg::BroadcastOp, scf::ForOp, scf::IfOp, |
| 189 | + scf::YieldOp, affine::AffineForOp, affine::AffineYieldOp, |
| 190 | + lwe::RLWEDecodeOp, AddOp, AddPlainOp, SubOp, SubPlainOp, MulOp, |
| 191 | + MulNoRelinOp, MulPlainOp, MulConstOp, NegateOp, SquareOp, RelinOp, |
| 192 | + ModReduceOp, LevelReduceOp, RotOp, AutomorphOp, KeySwitchOp, |
| 193 | + BootstrapOp, EncryptOp, DecryptOp, MakePackedPlaintextOp, |
| 194 | + MakeCKKSPackedPlaintextOp, GenParamsOp, GenContextOp, GenRotKeyOp, |
| 195 | + GenMulKeyOp, GenBootstrapKeyOp, SetupBootstrapOp, FastRotationOp, |
| 196 | + FastRotationPrecomputeOp>([&](auto op) { visit(op); }) |
196 | 197 | .Default([&](Operation* op) { |
197 | 198 | OperationName opName = op->getName(); |
198 | 199 | op->emitError() << "Unsupported operation " << opName.getStringRef() |
@@ -305,6 +306,15 @@ void Interpreter::visit(arith::AddIOp op) { |
305 | 306 | env.insert_or_assign(op.getResult(), std::move(result)); |
306 | 307 | } |
307 | 308 |
|
| 309 | +void Interpreter::visit(arith::AddFOp op) { |
| 310 | + auto lhs = env.at(op.getLhs()); |
| 311 | + auto rhs = env.at(op.getRhs()); |
| 312 | + auto result = applyBinop( |
| 313 | + op, lhs, rhs, [](int a, int b) { return a + b; }, |
| 314 | + [](float a, float b) { return a + b; }); |
| 315 | + env.insert_or_assign(op.getResult(), std::move(result)); |
| 316 | +} |
| 317 | + |
308 | 318 | void Interpreter::visit(arith::SubIOp op) { |
309 | 319 | auto lhs = env.at(op.getLhs()); |
310 | 320 | auto rhs = env.at(op.getRhs()); |
@@ -442,6 +452,16 @@ void Interpreter::visit(arith::ExtUIOp op) { |
442 | 452 | } |
443 | 453 | } |
444 | 454 |
|
| 455 | +void Interpreter::visit(arith::FloorDivSIOp op) { |
| 456 | + auto lhs = env.at(op.getLhs()); |
| 457 | + auto rhs = env.at(op.getRhs()); |
| 458 | + auto result = applyBinop( |
| 459 | + op, lhs, rhs, |
| 460 | + [](int a, int b) { return std::floor(static_cast<float>(a) / b); }, |
| 461 | + [](float a, float b) { return std::floor(a / b); }); |
| 462 | + env.insert_or_assign(op.getResult(), std::move(result)); |
| 463 | +} |
| 464 | + |
445 | 465 | void Interpreter::visit(arith::ExtSIOp op) { |
446 | 466 | auto operand = env.at(op.getIn()); |
447 | 467 | // For signed extension, we just convert the value to a larger type |
@@ -1381,7 +1401,31 @@ void Interpreter::visit(MakePackedPlaintextOp op) { |
1381 | 1401 | } |
1382 | 1402 |
|
1383 | 1403 | void Interpreter::visit(MakeCKKSPackedPlaintextOp op) { |
| 1404 | + auto checkIfZero = [&](TypedCppValue value) -> bool { |
| 1405 | + if (std::holds_alternative<std::vector<float>>(value.value)) { |
| 1406 | + auto vec = std::get<std::vector<float>>(value.value); |
| 1407 | + if (llvm::all_of(vec, [](float v) { return v == 0.0f; })) { |
| 1408 | + return true; |
| 1409 | + } |
| 1410 | + } else if (std::holds_alternative<std::vector<double>>(value.value)) { |
| 1411 | + auto vec = std::get<std::vector<double>>(value.value); |
| 1412 | + if (llvm::all_of(vec, [](double v) { return v == 0.0; })) { |
| 1413 | + return true; |
| 1414 | + } |
| 1415 | + } else if (std::holds_alternative<std::vector<int>>(value.value)) { |
| 1416 | + auto vec = std::get<std::vector<int>>(value.value); |
| 1417 | + if (llvm::all_of(vec, [](int v) { return v == 0; })) { |
| 1418 | + return true; |
| 1419 | + } |
| 1420 | + } |
| 1421 | + return false; |
| 1422 | + }; |
| 1423 | + |
1384 | 1424 | auto value = env.at(op.getValue()); |
| 1425 | + if (checkIfZero(value)) { |
| 1426 | + op.dump(); |
| 1427 | + std::cout << "MakeCKKSPackedPlaintextOp: value is zero" << std::endl; |
| 1428 | + } |
1385 | 1429 | auto cc = std::get<CryptoContextT>(env.at(op.getCryptoContext()).value); |
1386 | 1430 | if (std::holds_alternative<std::vector<float>>(value.value)) { |
1387 | 1431 | auto vec = std::get<std::vector<float>>(value.value); |
|
0 commit comments