Skip to content

Commit

Permalink
Add CKKS/BFV Negate and Sub Ptxt/Ctxt benchmarks
Browse files Browse the repository at this point in the history
Add benchmarks of high and low level NTT API

Add faster low level NTT benchmark
  • Loading branch information
jlhcrawford authored and Wei Dai committed Sep 15, 2021
1 parent c2f4d37 commit 99a952c
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 0 deletions.
1 change: 1 addition & 0 deletions native/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ if(SEAL_BUILD_BENCH)
PRIVATE
${CMAKE_CURRENT_LIST_DIR}/bench.cpp
${CMAKE_CURRENT_LIST_DIR}/keygen.cpp
${CMAKE_CURRENT_LIST_DIR}/ntt.cpp
${CMAKE_CURRENT_LIST_DIR}/bfv.cpp
${CMAKE_CURRENT_LIST_DIR}/ckks.cpp
)
Expand Down
12 changes: 12 additions & 0 deletions native/bench/bench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ namespace sealbench
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, DecodeBatch, bm_bfv_decode_batch, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateAddCt, bm_bfv_add_ct, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateAddPt, bm_bfv_add_pt, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateNegate, bm_bfv_negate, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateSubCt, bm_bfv_sub_ct, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateSubPt, bm_bfv_sub_pt, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateMulCt, bm_bfv_mul_ct, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateMulPt, bm_bfv_mul_pt, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(BFV, n, log_q, EvaluateSquare, bm_bfv_square, bm_env_bfv);
Expand All @@ -84,6 +87,9 @@ namespace sealbench
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, DecodeDouble, bm_ckks_decode_double, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateAddCt, bm_ckks_add_ct, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateAddPt, bm_ckks_add_pt, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateNegate, bm_ckks_negate, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateSubCt, bm_ckks_sub_ct, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateSubPt, bm_ckks_sub_pt, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateMulCt, bm_ckks_mul_ct, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateMulPt, bm_ckks_mul_pt, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateSquare, bm_ckks_square, bm_env_ckks);
Expand All @@ -96,6 +102,12 @@ namespace sealbench
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateRelinInplace, bm_ckks_relin_inplace, bm_env_ckks);
SEAL_BENCHMARK_REGISTER(CKKS, n, log_q, EvaluateRotate, bm_ckks_rotate, bm_env_ckks);
}
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, ForwardNTT, bm_forward_ntt, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, InverseNTT, bm_inverse_ntt, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, ForwardNTTLowLevel, bm_forward_ntt_low_level, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, InverseNTTLowLevel, bm_inverse_ntt_low_level, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, ForwardNTTLowLevelLazy, bm_forward_ntt_low_level_lazy, bm_env_bfv);
SEAL_BENCHMARK_REGISTER(NTT, n, log_q, InverseNTTLowLevelLazy, bm_inverse_ntt_low_level_lazy, bm_env_bfv);
}

} // namespace sealbench
Expand Down
14 changes: 14 additions & 0 deletions native/bench/bench.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ namespace sealbench
std::vector<seal::Ciphertext> ct_;
}; // namespace BMEnv

// NTT benchmark cases
void bm_forward_ntt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_inverse_ntt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_forward_ntt_low_level(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_inverse_ntt_low_level(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_forward_ntt_low_level_lazy(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_inverse_ntt_low_level_lazy(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);

// KeyGen benchmark cases
void bm_keygen_secret(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_keygen_public(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
Expand All @@ -350,6 +358,9 @@ namespace sealbench
void bm_bfv_decode_batch(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_add_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_add_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_negate(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_sub_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_sub_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_mul_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_mul_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_bfv_square(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
Expand All @@ -366,6 +377,9 @@ namespace sealbench
void bm_ckks_decode_double(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_add_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_add_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_negate(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_sub_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_sub_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_mul_ct(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_mul_pt(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
void bm_ckks_square(benchmark::State &state, std::shared_ptr<BMEnv> bm_env);
Expand Down
42 changes: 42 additions & 0 deletions native/bench/bfv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,48 @@ namespace sealbench
}
}

void bm_bfv_negate(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
bm_env->evaluator()->negate(ct[0], ct[2]);
}
}

void bm_bfv_sub_ct(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);
bm_env->randomize_ct_bfv(ct[1]);

state.ResumeTiming();
bm_env->evaluator()->sub(ct[0], ct[1], ct[2]);
}
}

void bm_bfv_sub_pt(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
Plaintext &pt = bm_env->pt()[0];
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);
bm_env->randomize_pt_bfv(pt);

state.ResumeTiming();
bm_env->evaluator()->sub_plain(ct[0], pt, ct[2]);
}
}

void bm_bfv_mul_ct(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
Expand Down
50 changes: 50 additions & 0 deletions native/bench/ckks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,56 @@ namespace sealbench
}
}

void bm_ckks_negate(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
double scale = bm_env->safe_scale();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_ckks(ct[0]);
ct[0].scale() = scale;

state.ResumeTiming();
bm_env->evaluator()->negate(ct[0], ct[2]);
}
}

void bm_ckks_sub_ct(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
double scale = bm_env->safe_scale();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_ckks(ct[0]);
ct[0].scale() = scale;
bm_env->randomize_ct_ckks(ct[1]);
ct[1].scale() = scale;

state.ResumeTiming();
bm_env->evaluator()->sub(ct[0], ct[1], ct[2]);
}
}

void bm_ckks_sub_pt(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
Plaintext &pt = bm_env->pt()[0];
double scale = bm_env->safe_scale();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_ckks(ct[0]);
ct[0].scale() = scale;
bm_env->randomize_pt_ckks(pt);
pt.scale() = scale;

state.ResumeTiming();
bm_env->evaluator()->sub_plain(ct[0], pt, ct[2]);
}
}

void bm_ckks_mul_ct(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
Expand Down
109 changes: 109 additions & 0 deletions native/bench/ntt.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.

#include "seal/seal.h"
#include "seal/util/rlwe.h"
#include "bench.h"

using namespace benchmark;
using namespace sealbench;
using namespace seal;
using namespace std;

/**
This file defines benchmarks for NTT-related HE primitives.
*/

namespace sealbench
{
void bm_forward_ntt(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
bm_env->evaluator()->transform_to_ntt(ct[0], ct[2]);
}
}

void bm_inverse_ntt(State &state, shared_ptr<BMEnv> bm_env)
{
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);
bm_env->evaluator()->transform_to_ntt_inplace(ct[0]);

state.ResumeTiming();
bm_env->evaluator()->transform_from_ntt(ct[0], ct[2]);
}
}

void bm_forward_ntt_low_level(State &state, shared_ptr<BMEnv> bm_env)
{
parms_id_type parms_id = bm_env->context().first_parms_id();
auto context_data = bm_env->context().get_context_data(parms_id);
const auto &small_ntt_tables = context_data->small_ntt_tables();
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
ntt_negacyclic_harvey(ct[0].data(), small_ntt_tables[0]);
}
}

void bm_inverse_ntt_low_level(State &state, shared_ptr<BMEnv> bm_env)
{
parms_id_type parms_id = bm_env->context().first_parms_id();
auto context_data = bm_env->context().get_context_data(parms_id);
const auto &small_ntt_tables = context_data->small_ntt_tables();
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
inverse_ntt_negacyclic_harvey(ct[0].data(), small_ntt_tables[0]);
}
}

void bm_forward_ntt_low_level_lazy(State &state, shared_ptr<BMEnv> bm_env)
{
parms_id_type parms_id = bm_env->context().first_parms_id();
auto context_data = bm_env->context().get_context_data(parms_id);
const auto &small_ntt_tables = context_data->small_ntt_tables();
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
ntt_negacyclic_harvey_lazy(ct[0].data(), small_ntt_tables[0]);
}
}

void bm_inverse_ntt_low_level_lazy(State &state, shared_ptr<BMEnv> bm_env)
{
parms_id_type parms_id = bm_env->context().first_parms_id();
auto context_data = bm_env->context().get_context_data(parms_id);
const auto &small_ntt_tables = context_data->small_ntt_tables();
vector<Ciphertext> &ct = bm_env->ct();
for (auto _ : state)
{
state.PauseTiming();
bm_env->randomize_ct_bfv(ct[0]);

state.ResumeTiming();
inverse_ntt_negacyclic_harvey_lazy(ct[0].data(), small_ntt_tables[0]);
}
}
} // namespace sealbench

0 comments on commit 99a952c

Please sign in to comment.