-
Notifications
You must be signed in to change notification settings - Fork 718
/
Copy pathintel_seal_ext.h
105 lines (89 loc) · 3.81 KB
/
intel_seal_ext.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT license.
#pragma once
#ifdef SEAL_USE_INTEL_HEXL
#include "seal/util/locks.h"
#include <unordered_map>
#include "hexl/hexl.hpp"
namespace intel
{
namespace seal_ext
{
struct HashPair
{
template <class T1, class T2>
size_t operator()(const std::pair<T1, T2> &p) const
{
auto hash1 = std::hash<T1>{}(std::get<0>(p));
auto hash2 = std::hash<T2>{}(std::get<1>(p));
return hash_combine(hash1, hash2);
}
static size_t hash_combine(size_t lhs, size_t rhs)
{
lhs ^= rhs + 0x9e3779b9 + (lhs << 6) + (lhs >> 2);
return lhs;
}
};
static std::unordered_map<std::pair<uint64_t, uint64_t>, intel::hexl::NTT, HashPair> ntt_cache_;
static seal::util::ReaderWriterLocker ntt_cache_locker_;
/**
Returns a HEXL NTT object corresponding to the given parameters.
@param[in] N The polynomial modulus degree
@param[in] modulus The modulus
@param[in] root The root of unity
*/
static intel::hexl::NTT get_ntt(size_t N, uint64_t modulus, uint64_t root)
{
std::pair<uint64_t, uint64_t> key{ N, modulus };
// Enable shared access of NTT already present
{
seal::util::ReaderLock reader_lock(ntt_cache_locker_.acquire_read());
auto ntt_it = ntt_cache_.find(key);
if (ntt_it != ntt_cache_.end())
{
return ntt_it->second;
}
}
// Deal with NTT not yet present
seal::util::WriterLock write_lock(ntt_cache_locker_.acquire_write());
// Check ntt_cache for value (maybe added by another thread)
auto ntt_it = ntt_cache_.find(key);
if (ntt_it == ntt_cache_.end())
{
ntt_it = ntt_cache_.emplace(std::move(key), intel::hexl::NTT(N, modulus, root)).first;
}
return ntt_it->second;
}
/**
Computes for forward negacyclic NTT from the given parameters.
@param[in,out] operand The data on which to compute the NTT.
@param[in] N The polynomial modulus degree
@param[in] modulus The modulus
@param[in] root The root of unity
@param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus)
@param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus)
*/
static void compute_forward_ntt(
seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor,
uint64_t output_mod_factor)
{
get_ntt(N, modulus, root).ComputeForward(operand, operand, input_mod_factor, output_mod_factor);
}
/**
Computes for inverse negacyclic NTT from the given parameters.
@param[in,out] operand The data on which to compute the NTT.
@param[in] N The polynomial modulus degree
@param[in] modulus The modulus
@param[in] root The root of unity
@param[in] input_mod_factor Bounds the input data to the range [0, input_mod_factor * modulus)
@param[in] output_mod_factor Bounds the output data to the range [0, output_mod_factor * modulus)
*/
static void compute_inverse_ntt(
seal::util::CoeffIter operand, size_t N, uint64_t modulus, uint64_t root, uint64_t input_mod_factor,
uint64_t output_mod_factor)
{
get_ntt(N, modulus, root).ComputeInverse(operand, operand, input_mod_factor, output_mod_factor);
}
} // namespace seal_ext
} // namespace intel
#endif