Skip to content
Closed

OSPP #345

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions psi/algorithm/ypir/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2025 The secretflow authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("//bazel:psi.bzl", "psi_cc_library", "psi_cc_test")

package(default_visibility = ["//visibility:public"])

psi_cc_library(
name = "util",
srcs = ["util.cc"],
hdrs = ["util.h"],
deps = [
"//psi/algorithm/spiral:gadget",
"//psi/algorithm/spiral:params",
"//psi/algorithm/spiral:poly_matrix",
"//psi/algorithm/spiral:poly_matrix_utils",
],
)

psi_cc_test(
name = "rlwe_test",
srcs = ["rlwe_test.cc"],
deps = [
":util",
"//psi/algorithm/spiral:gadget",
"//psi/algorithm/spiral:params",
"//psi/algorithm/spiral:util",
"//psi/algorithm/spiral:poly_matrix",
"//psi/algorithm/spiral:poly_matrix_utils",
"//psi/algorithm/spiral:spiral_client",
"@yacl//yacl/crypto/rand",
"@yacl//yacl/crypto/tools:prg",
],
)
46 changes: 46 additions & 0 deletions psi/algorithm/ypir/pir_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright 2025 The secretflow authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "psi/algorithm/ypir/server.h"
#include "yacl/link/test_util.h"

namespace pir::ypir {

inline void GenerateDatabase(std::vector<std::vector<uint64_t>> &database) {
size_t row = static_cast<size_t>(sqrt(kTestSize));
size_t col = static_cast<size_t>(sqrt(kTestSize));
database.resize(row);
for (size_t i = 0; i < row; i++) {
database[i] =
pir::simple::GenerateRandomVector(col, kTestPlainModulus, true);
}
}

TEST(PIRTest, AllWorkflow) {
pir::ypir::YPIRServer server(1 << 6, 1 << 6, 1 << 10, 1 << 11, 65537, 1 << 32, 1 << 56, 12345);
std::vector<std::vector<uint64_t>> database;
GenerateDatabase(database);
// Phase 1: Sets LWE matrix for server and client
server.SetDatabase(database);
server.GenerateLweMatrix();

uint128_t server_seed = server.GetSeed();
auto server_hint_vec = server.GetHint();

client.Setup(server_seed, server_hint_vec);
}

} // namespace pir::ypir
111 changes: 111 additions & 0 deletions psi/algorithm/ypir/rlwe_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#include "gtest/gtest.h"
#include "yacl/crypto/rand/rand.h"
#include "yacl/crypto/tools/prg.h"

#include "../spiral/gadget.h"
#include "../spiral/poly_matrix.h"
#include "../spiral/poly_matrix_utils.h"
#include "../spiral/public_keys.h"
#include "../spiral/util.h"
#include "../spiral/spiral_client.h"
#include "../spiral/arith/number_theory.h"
#include "psi/algorithm/ypir/util.h"

namespace psi::ypir {

// TEST(RLWETest, AllWorkflow) {
// auto params = spiral::util::GetFastExpansionTestingParam();

// size_t t_exp = params.TExpLeft();

// uint64_t p = params.PtModulus();
// uint64_t q = params.Modulus();
// auto sigma_raw = spiral::PolyMatrixRaw::Zero(params.PolyLen(), 1, 1);
// uint64_t scale_k = params.ScaleK();

// auto pt_seed = yacl::crypto::SecureRandU128();
// yacl::crypto::Prg<uint64_t> pt_rng(pt_seed);
// std::vector<uint64_t> m_auto(params.PolyLen());
// for (size_t i = 0; i < params.PolyLen(); ++i) {
// m_auto[i] = pt_rng() % p;
// sigma_raw.Data()[i] = (m_auto[i] * scale_k) % q;
// }
// auto sigma_ntt = spiral::ToNtt(params, sigma_raw);

// class ClientDerive : public spiral::SpiralClient {
// public:
// using spiral::SpiralClient::SpiralClient;
// using spiral::SpiralClient::EncryptMatrixRegev;
// using spiral::SpiralClient::DecryptMatrixRegev;
// };
// ClientDerive client(params);
// auto pks = client.GenPublicKeys();
// ASSERT_FALSE(pks.v_expansion_left_.empty());
// auto pub_param = pks.v_expansion_left_[0];

// size_t t = (params.PolyLen() / 2) + 1;

// auto seed = yacl::crypto::SecureRandU128();
// yacl::crypto::Prg<uint64_t> rng(seed);
// auto pub_seed = yacl::crypto::SecureRandU128();
// yacl::crypto::Prg<uint64_t> rng_pub(pub_seed);
// auto ct = client.EncryptMatrixRegev(sigma_ntt, rng, rng_pub);

// auto out = psi::ypir::HomomorphicAutomorph(params, t, t_exp, ct, pub_param);
// ASSERT_TRUE(out.IsNtt());
// ASSERT_EQ(out.Rows(), static_cast<size_t>(2));
// ASSERT_EQ(out.Cols(), static_cast<size_t>(1));

// auto dec_ntt = client.DecryptMatrixRegev(ct);
// auto dec_raw = spiral::FromNtt(params, dec_ntt);
// auto expect = spiral::Automorphism(params, sigma_raw, t);
// ASSERT_EQ(dec_raw.Data().size(), expect.Data().size());
// for (size_t i = 0; i < expect.Data().size(); ++i) {
// uint64_t v_dec_p = spiral::arith::Rescale(dec_raw.Data()[i], q, p);
// uint64_t v_exp_p = spiral::arith::Rescale(expect.Data()[i], q, p);
// ASSERT_EQ(v_dec_p, v_exp_p) << "mismatch at coeff " << i;
// }
// }

TEST(RLWETest, EncDecCorrect) {
auto params = spiral::util::GetFastExpansionTestingParam();

uint64_t p = params.PtModulus();
uint64_t q = params.Modulus();
uint64_t scale_k = params.ScaleK();
auto sigma_raw = spiral::PolyMatrixRaw::Zero(params.PolyLen(), 1, 1);
auto pt_seed = yacl::crypto::SecureRandU128();
yacl::crypto::Prg<uint64_t> pt_rng(pt_seed);
std::vector<uint64_t> m_enc(params.PolyLen());
for (size_t i = 0; i < params.PolyLen(); ++i) {
m_enc[i] = pt_rng() % p;
sigma_raw.Data()[i] = (m_enc[i] * scale_k) % q;
}
auto sigma_ntt = spiral::ToNtt(params, sigma_raw);

class ClientDerive : public spiral::SpiralClient {
public:
using spiral::SpiralClient::SpiralClient;
using spiral::SpiralClient::EncryptMatrixRegev;
using spiral::SpiralClient::DecryptMatrixRegev;
};
ClientDerive client(params);

auto seed = yacl::crypto::SecureRandU128();
yacl::crypto::Prg<uint64_t> rng(seed);
auto pub_seed = yacl::crypto::SecureRandU128();
yacl::crypto::Prg<uint64_t> rng_pub(pub_seed);

auto ct = client.EncryptMatrixRegev(sigma_ntt, rng, rng_pub);
auto dec_ntt = client.DecryptMatrixRegev(ct);
auto dec_raw = spiral::FromNtt(params, dec_ntt);

ASSERT_EQ(dec_raw.Data().size(), sigma_raw.Data().size());
for (size_t i = 0; i < sigma_raw.Data().size(); ++i) {
uint64_t v_rescaled = spiral::arith::Rescale(dec_raw.Data()[i], params.Modulus(), params.PtModulus());
uint64_t v_exp = m_enc[i];
ASSERT_EQ(v_rescaled, v_exp) << "mismatch at coeff " << i;
}
}

}
48 changes: 48 additions & 0 deletions psi/algorithm/ypir/server.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2025 The secretflow authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "psi/algorithm/ypir/server.h"
#include "experiment/pir/simplepir/util.h"

#include "yacl/crypto/tools/prg.h"

using namespace std;

namespace pir::ypir {
YPIRServer::YPIRServer(size_t row_num, size_t col_num, size_t d1,
size_t d2, uint64_t p, uint64_t q1, uint64_t q2, uint128_t seed) :
row_num_(row_num), col_num_(col_num), d1_(d1),
d2_(d2), p_(p), q1_(q1), q2_(q2), seed_(seed) {
YACL_ENFORCE(row_num > 0, "row_num must be positive");
YACL_ENFORCE(col_num > 0, "col_num must be positive");
YACL_ENFORCE(d1 > 0, "d1 must be positive");
YACL_ENFORCE(d2 > 0, "d2 must be positive");
YACL_ENFORCE(p > 0, "p must be positive");
YACL_ENFORCE(q1 > 0, "q1 must be positive");
YACL_ENFORCE(q2 > 0, "q2 must be positive");
YACL_ENFORCE(row_num % d1 == 0, "row_num must be divisible by d1");
YACL_ENFORCE(col_num % d2 == 0, "col_num must be divisible by d2");

delta1_ = q1_ / p_;
delta2_ = q2_ / p_;
}

void YPIRServer::gen_db() {
database_.resize(row_num_);
for (size_t i = 0; i < row_num_; i++) {
database_[i] = pir::simple::GenerateRandomVector(col_num_, p_, true);
}
}

}
41 changes: 41 additions & 0 deletions psi/algorithm/ypir/server.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright 2025 The secretflow authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <vector>

namespace pir::ypir {
class YPIRServer {
public:
YPIRServer(size_t row_num, size_t col_num, size_t d1, size_t d2, uint64_t p, uint64_t q1, uint64_t q2, uint128_t seed);

void gen_db();

private:
size_t row_num_;
size_t col_num_;
size_t d1_;
size_t d2_;
uint64_t p_;
uint64_t q1_;
uint64_t q2_;
uint128_t seed_;
uint64_t delta1_;
uint64_t delta2_;
std::vector<std::vector<uint64_t>> database_;
};

} // namespace pir::ypir
80 changes: 80 additions & 0 deletions psi/algorithm/ypir/util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#include "psi/algorithm/ypir/util.h"

#include <cstring>

#include "../spiral/gadget.h"
#include "../spiral/poly_matrix_utils.h"

namespace psi::ypir {

using namespace psi::spiral;

PolyMatrixNtt HomomorphicAutomorph(const Params& params, size_t t,
size_t t_exp, const PolyMatrixNtt& ct,
const PolyMatrixNtt& pub_param) {
YACL_ENFORCE(ct.Rows() == static_cast<size_t>(2));
YACL_ENFORCE(ct.Cols() == static_cast<size_t>(1));

auto ct_raw = PolyMatrixRaw::Zero(params.PolyLen(), 2, 1);
FromNtt(params, ct_raw, ct);
auto ct_auto = Automorphism(params, ct_raw, t);

auto ginv_ct = PolyMatrixRaw::Zero(params.PolyLen(), t_exp, 1);
psi::spiral::util::GadgetInvertRdim(params, ginv_ct, ct_auto, 1);

auto ginv_ct_ntt = PolyMatrixNtt::Zero(params.CrtCount(), params.PolyLen(), t_exp, 1);
ToNttNoReduce(params, ginv_ct_ntt, ginv_ct);

auto w_times_ginv_ct = Multiply(params, pub_param, ginv_ct_ntt);

auto ct_auto_1 = PolyMatrixRaw::Zero(params.PolyLen(), 1, 1);

std::memcpy(ct_auto_1.Data().data(),
ct_auto.Data().data() + ct_auto.PolyStartIndex(1, 0),
sizeof(uint64_t) * ct_auto.NumWords());
auto ct_auto_1_ntt = ToNtt(params, ct_auto_1);

auto res = Add(params, ct_auto_1_ntt.PadTop(1), w_times_ginv_ct);
return res;
}

PolyMatrixNtt RingPackLwesInner(
const Params& params,
size_t ell,
size_t start_idx,
const std::vector<PolyMatrixNtt>& rlwe_cts,
const std::vector<PolyMatrixNtt>& pub_params,
const std::pair<std::vector<PolyMatrixNtt>, std::vector<PolyMatrixNtt>>& y_constants) {
YACL_ENFORCE_EQ(pub_params.size(), params.PolyLenLog2());

if (ell == 0) {
return rlwe_cts[start_idx];
}

size_t step = 1ULL << (params.PolyLenLog2() - ell);
size_t even = start_idx;
size_t odd = start_idx + step;

auto ct_even = RingPackLwesInner(params, ell - 1, even, rlwe_cts, pub_params, y_constants);
auto ct_odd = RingPackLwesInner(params, ell - 1, odd, rlwe_cts, pub_params, y_constants);

const auto& y = y_constants.first[ell - 1];
const auto& neg_y = y_constants.second[ell - 1];

auto y_times_ct_odd = ScalarMultiply(params, y, ct_odd);
auto neg_y_times_ct_odd = ScalarMultiply(params, neg_y, ct_odd);

auto ct_sum_1 = ct_even;
AddInto(params, ct_sum_1, neg_y_times_ct_odd);
AddInto(params, ct_even, y_times_ct_odd);

size_t t = (1ULL << ell) + 1;
const auto& pub_param = pub_params[params.PolyLenLog2() - 1 - (ell - 1)];
auto ct_sum_1_automorphed = HomomorphicAutomorph(params, t, params.TExpLeft(), ct_sum_1, pub_param);

return Add(params, ct_even, ct_sum_1_automorphed);
}

} // namespace psi::ypir


Loading
Loading