Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ LogicalResult convertRotateAndReduceOp(RotateAndReduceOp op) {
reduceOp = op.getReduceOp()->getValue().str();
}
implementedKernel = implementRotateAndReduce(vectorLeaf, plaintextsLeaf,
period, steps, reduceOp);
period, steps, {}, reduceOp);
IRRewriter rewriter(op.getContext());
rewriter.setInsertionPointAfter(op);
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Expand Down
30 changes: 24 additions & 6 deletions lib/Kernel/KernelImplementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <optional>
#include <string>
Expand Down Expand Up @@ -117,6 +118,7 @@ std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
implementBabyStepGiantStep(
const T& giantSteppedOperand, const T& babySteppedOperand, int64_t period,
int64_t steps, DagExtractor<T> extractFunc,
std::map<int, bool> nonZeroDiagonals = {},
const DerivedRotationIndexFn& derivedRotationIndexFn =
defaultDerivedRotationIndexFn) {
using NodeTy = ArithmeticDagNode<T>;
Expand Down Expand Up @@ -147,15 +149,29 @@ implementBabyStepGiantStep(
int64_t innerRotAmount =
derivedRotationIndexFn(giantStepSize, j, i, period);
size_t extractionIndex = i + j * giantStepSize;

// If the extractIndex is not in the nonZeroDiagonals, then the value is
// zero and we can skip the multiplication.
if (!nonZeroDiagonals.empty() &&
!nonZeroDiagonals.contains(extractionIndex)) {
continue;
}

auto plaintext = extractFunc(babySteppedDag, extractionIndex);
auto rotatedPlaintext = NodeTy::leftRotate(plaintext, innerRotAmount);
auto multiplied = NodeTy::mul(rotatedPlaintext, babyStepVals[i]);
innerSum =
innerSum == nullptr ? multiplied : NodeTy::add(innerSum, multiplied);
}

auto rotatedSum = NodeTy::leftRotate(innerSum, period * j * giantStepSize);
result = result == nullptr ? rotatedSum : NodeTy::add(result, rotatedSum);
auto rotatedSum =
innerSum == nullptr
? nullptr
: NodeTy::leftRotate(innerSum, period * j * giantStepSize);
result = result == nullptr
? rotatedSum
: (rotatedSum == nullptr ? result
: NodeTy::add(result, rotatedSum));
}

return result;
Expand Down Expand Up @@ -185,6 +201,7 @@ std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
std::shared_ptr<ArithmeticDagNode<T>>>
implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
int64_t period, int64_t steps,
std::map<int, bool> nonZeroDiagonals = {},
const std::string& reduceOp = "arith.addi") {
using NodeTy = ArithmeticDagNode<T>;
auto performReduction = [&](std::shared_ptr<NodeTy> left,
Expand Down Expand Up @@ -217,7 +234,7 @@ implementRotateAndReduce(const T& vector, std::optional<T> plaintexts,
};

return implementBabyStepGiantStep<T>(vector, plaintexts.value(), period,
steps, extractFunc);
steps, extractFunc, nonZeroDiagonals);
}

// Returns an arithmetic DAG that implements a baby-step-giant-step between
Expand Down Expand Up @@ -248,7 +265,7 @@ implementCiphertextCiphertextBabyStepGiantStep(
int64_t extractionIndex) { return babySteppedDag; };

return implementBabyStepGiantStep<T>(giantSteppedOperand, babySteppedOperand,
period, steps, extractFunc,
period, steps, extractFunc, {},
derivedRotationIndexFn);
}

Expand All @@ -260,13 +277,14 @@ template <typename T>
std::enable_if_t<std::is_base_of<AbstractValue, T>::value,
std::shared_ptr<ArithmeticDagNode<T>>>
implementHaleviShoup(const T& vector, const T& matrix,
std::vector<int64_t> originalMatrixShape) {
std::vector<int64_t> originalMatrixShape,
std::map<int, bool> nonZeroDiagonals = {}) {
using NodeTy = ArithmeticDagNode<T>;
int64_t numRotations = matrix.getShape()[0];

auto rotateAndReduceResult = implementRotateAndReduce<T>(
vector, std::optional<T>(matrix), /*period=*/1,
/*steps=*/numRotations);
/*steps=*/numRotations, nonZeroDiagonals);

auto summedShifts = rotateAndReduceResult;

Expand Down
4 changes: 3 additions & 1 deletion lib/Pipelines/ArithmeticPipelineRegistration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ void mlirToSecretArithmeticPipelineBuilder(
pm.addPass(createActivationCanonicalizations());
pm.addPass(createSelectRewrite());
pm.addPass(createCompareToSignRewrite());
pm.addPass(createCanonicalizerPass());
pm.addPass(createCSEPass());

// Vectorize and optimize rotations
// TODO(#2320): figure out where this fits in the new pipeline
Expand Down Expand Up @@ -493,8 +495,8 @@ BackendPipelineBuilder toLattigoPipelineBuilder() {

void linalgPreprocessingBuilder(OpPassManager& manager) {
manager.addPass(createInlineActivations());
manager.addPass(createDropUnitDims());
manager.addPass(createLinalgCanonicalizations());
manager.addPass(createDropUnitDims());
manager.addPass(createFoldConstantTensors());
manager.addPass(createCanonicalizerPass());
manager.addPass(createSymbolDCEPass());
Expand Down
12 changes: 12 additions & 0 deletions lib/Target/OpenFhePke/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ std::vector<TypedCppValue> Interpreter::interpret(
}

void Interpreter::visit(Operation* op) {
std::cout << "Visiting operation: " << op->getName().getStringRef().str()
<< "\n";
llvm::TypeSwitch<Operation*>(op)
.Case<arith::ConstantOp, arith::AddIOp, arith::AddFOp, arith::SubIOp,
arith::MulIOp, arith::MulFOp, arith::DivSIOp, arith::RemSIOp,
Expand All @@ -199,6 +201,16 @@ void Interpreter::visit(Operation* op) {
op->emitError() << "Unsupported operation " << opName.getStringRef()
<< " in interpreter";
});
// If any of the operations op operands have no more uses, then remove
// them from the end.
if (!op->getParentOfType<affine::AffineForOp>() &&
!op->getParentOfType<scf::ForOp>()) {
for (auto operand : op->getOperands()) {
if (liveness.isDeadAfter(operand, op)) {
env.erase(operand);
}
}
}
}

void Interpreter::visit(arith::ConstantOp op) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -674,13 +674,30 @@ struct ConvertLinalgConv2D
cast<TypedValue<RankedTensorType>>(adaptor.getInputs()[1]);
SSAValue matrixLeaf(matrix);

// The original matrix shape is the shape of the expanded filter.
// The original matrix shape is the shape of the expanded filter before
// diagonalization. This is 28x28 for LeNet
RankedTensorType expandedMatrixType = get2dConvFilterExpandedType(
cast<RankedTensorType>(op.getInputs()[1].getType()),
cast<RankedTensorType>(op.getInputs()[0].getType()), /*padding=*/0);

// Get non-zero diagonals of the diagonalized expanded filter matrix.
LayoutAttr filterLayout = getLayoutAttr(adaptor.getInputs()[1]);
auto filterRelation = filterLayout.getIntegerRelation();
PointCollector collector;
getRangePoints(filterRelation, collector);
std::map<int, bool> nonZeroDiagonals;
for (auto point : collector.points) {
nonZeroDiagonals[point[0]] = true;
}
for (auto [ct, val] : nonZeroDiagonals) {
std::cout << ct << ", ";
}
std::cout << "\n";
std::cout << "nonZero diagonal size: " << nonZeroDiagonals.size() << "\n";

std::shared_ptr<ArithmeticDagNode<SSAValue>> implementedKernel =
implementHaleviShoup(vectorLeaf, matrixLeaf,
expandedMatrixType.getShape());
expandedMatrixType.getShape(), nonZeroDiagonals);

rewriter.setInsertionPointAfter(op);
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,9 +517,9 @@ struct RewriteAvgPoolAsConv2D
RankedTensorType twoDInputType =
RankedTensorType::get({inputTy.getDimSize(2), inputTy.getDimSize(3)},
inputTy.getElementType());
Value convOutput = rewriter.create<tensor::EmptyOp>(
poolOp.getLoc(), twoDOutputType.getShape(),
twoDOutputType.getElementType());
Value convOutput = tensor::EmptyOp::create(rewriter, poolOp.getLoc(),
twoDOutputType.getShape(),
twoDOutputType.getElementType());
for (int n = 0; n < inputTy.getDimSize(0); ++n) {
for (int c = 0; c < inputTy.getDimSize(1); ++c) {
// Compute the 2-D constant convolution.
Expand All @@ -531,9 +531,9 @@ struct RewriteAvgPoolAsConv2D
rewriter.getIndexAttr(inputTy.getDimSize(2)),
rewriter.getIndexAttr(inputTy.getDimSize(3))};
SmallVector<OpFoldResult> strides(4, rewriter.getIndexAttr(1));
auto extractInputOp = rewriter.create<tensor::ExtractSliceOp>(
poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0], offsets,
inputSizes, strides);
auto extractInputOp = tensor::ExtractSliceOp::create(
rewriter, poolOp.getLoc(), twoDInputType, poolOp.getInputs()[0],
offsets, inputSizes, strides);

auto convOp = linalg::Conv2DOp::create(
rewriter, poolOp.getLoc(), twoDOutputType,
Expand All @@ -543,8 +543,8 @@ struct RewriteAvgPoolAsConv2D
rewriter.getIndexAttr(1), rewriter.getIndexAttr(1),
rewriter.getIndexAttr(outputTy.getDimSize(2)),
rewriter.getIndexAttr(outputTy.getDimSize(3))};
outputVal = rewriter.create<tensor::InsertSliceOp>(
poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
outputVal = tensor::InsertSliceOp::create(
rewriter, poolOp.getLoc(), convOp.getResult(0), outputVal, offsets,
outputSizes, strides);
}
}
Expand Down
14 changes: 9 additions & 5 deletions lib/Utils/Layout/Codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
SmallVector<Value> args = getArgs(expr);
auto op =
arith::CmpIOp::create(b, islCmpToMlirAttr[type], args[0], args[1]);
return op->getResult(0);
return arith::ExtSIOp::create(b, b.getI32Type(), op->getResult(0));
}

if (type == isl_ast_op_select) {
Expand All @@ -229,7 +229,7 @@ Value buildIslExpr(isl_ast_expr* expr, std::map<std::string, Value> ivToValue,
auto eqOp =
arith::CmpIOp::create(b, arith::CmpIPredicate::eq, op,
arith::ConstantIntOp::create(b, 0, 32));
return eqOp->getResult(0);
return arith::ExtSIOp::create(b, b.getI32Type(), eqOp->getResult(0));
}
isl_ast_expr_op_type opType = isl_ast_expr_get_op_type(expr);
char* cStr = isl_ast_expr_to_C_str(expr);
Expand Down Expand Up @@ -318,9 +318,13 @@ FailureOr<scf::ValueVector> MLIRLoopNestGenerator::visitAstNodeIf(
isl_ast_expr_free(cond);

// Build scf if operation with the result types of the iter args
auto ifOp =
scf::IfOp::create(builder_, currentLoc_, TypeRange(currentIterArgs_),
condVal, /*addThenBlock=*/true, /*addElseBlock=*/true);
// Convert condVal to an i1
auto condValI1 =
arith::CmpIOp::create(builder_, arith::CmpIPredicate::eq, condVal,
arith::ConstantIntOp::create(builder_, 1, 32));
auto ifOp = scf::IfOp::create(builder_, currentLoc_,
TypeRange(currentIterArgs_), condValI1,
/*addThenBlock=*/true, /*addElseBlock=*/true);

// TODO:(#2120): Handle ISL else conditions.
isl_ast_node* elseNode = isl_ast_node_if_get_else_node(node);
Expand Down
Loading
Loading