Skip to content

Commit deaf871

Browse files
authored
Fix: Correct bugs within test code and optimize sumcheck & polynomial & logup (#598)
Fixes and optimizations for ZKP tools + Reduced the number of loops in polynomial operations to lower memory consumption and mitigate Memory Limit Exceeded (MLE) errors. + Optimized memory allocation in the ProcessChallenge function of the Sumcheck Prover using resize and in-place operations. + Reduced the number of modular subtraction operations to decrease runtime overhead. + Merged loops in the Logup Prover's setup phase to reduce overhead. ---------
1 parent 11d4fed commit deaf871

File tree

5 files changed

+62
-108
lines changed

5 files changed

+62
-108
lines changed

yacl/crypto/experimental/zkp/sumcheck/logup.cc

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -33,36 +33,54 @@ LogUpProver::LogUpProver(std::shared_ptr<const MultilinearPolynomial> f_A,
3333

3434
void LogUpProver::Setup(const FieldElem& zeta) {
3535
zeta_ = zeta;
36-
FieldElem one(1);
37-
FieldElem zero(0);
36+
const FieldElem one(1);
37+
const FieldElem zero(0);
3838

39-
// h_A
39+
const size_t num_vars_A = f_A_->NumVars();
40+
const size_t size_A = 1 << num_vars_A;
4041
const auto& f_A_evals = f_A_->GetEvals();
42+
4143
MultiLinearPolynomialVec h_A_evals;
42-
h_A_evals.reserve(f_A_evals.size());
43-
for (const auto& f_A_val : f_A_evals) {
44+
MultiLinearPolynomialVec q_A_evals;
45+
h_A_evals.reserve(size_A);
46+
q_A_evals.reserve(size_A);
47+
48+
for (size_t i = 0; i < size_A; ++i) {
4449
FieldElem denominator;
45-
FieldElem::SubMod(zeta_, f_A_val, modulus_p_, &denominator);
50+
FieldElem::SubMod(zeta_, f_A_evals[i], modulus_p_, &denominator);
4651
YACL_ENFORCE(
4752
denominator != zero,
4853
"Division by zero in h_A construction: zeta is a root of f_A.");
4954

5055
FieldElem inv_denominator;
5156
FieldElem::InvertMod(denominator, modulus_p_, &inv_denominator);
5257
h_A_evals.push_back(inv_denominator);
58+
59+
// By definition, q_A(x) = h_A(x) * (zeta - f_A(x)) - 1.
60+
// Since h_A(x) = 1 / (zeta - f_A(x)), q_A(x) is identically 0.
61+
q_A_evals.push_back(zero);
5362
}
5463
h_A_ = std::make_shared<MultilinearPolynomial>(std::move(h_A_evals));
64+
q_A_ = std::make_shared<MultilinearPolynomial>(std::move(q_A_evals));
5565

56-
// h_B
66+
const size_t num_vars_B = f_B_->NumVars();
67+
const size_t size_B = 1 << num_vars_B;
5768
const auto& f_B_evals = f_B_->GetEvals();
5869
const auto& m_B_evals = m_B_->GetEvals();
70+
5971
MultiLinearPolynomialVec h_B_evals;
60-
h_B_evals.reserve(f_B_evals.size());
61-
for (size_t i = 0; i < f_B_evals.size(); ++i) {
72+
MultiLinearPolynomialVec q_B_evals;
73+
h_B_evals.reserve(size_B);
74+
q_B_evals.reserve(size_B);
75+
76+
for (size_t i = 0; i < size_B; ++i) {
6277
if (m_B_evals[i] == zero) {
6378
h_B_evals.push_back(zero);
79+
// q_B(y) = 0 * (zeta - f_B(y)) - 0 = 0
80+
q_B_evals.push_back(zero);
6481
continue;
6582
}
83+
6684
FieldElem denominator;
6785
FieldElem::SubMod(zeta_, f_B_evals[i], modulus_p_, &denominator);
6886
YACL_ENFORCE(denominator != zero,
@@ -75,31 +93,15 @@ void LogUpProver::Setup(const FieldElem& zeta) {
7593
FieldElem h_B_val;
7694
FieldElem::MulMod(m_B_evals[i], inv_denominator, modulus_p_, &h_B_val);
7795
h_B_evals.push_back(h_B_val);
78-
}
79-
h_B_ = std::make_shared<MultilinearPolynomial>(std::move(h_B_evals));
80-
81-
// (q_A(x) = h_A(x) * (zeta - f_A(x)) - 1)
82-
MultiLinearPolynomialVec q_A_evals;
83-
q_A_evals.reserve(h_A_->GetEvals().size());
84-
for (size_t i = 0; i < h_A_->GetEvals().size(); ++i) {
85-
FieldElem term1, zeta_minus_fA, q_A_val;
86-
FieldElem::SubMod(zeta_, f_A_evals[i], modulus_p_, &zeta_minus_fA);
87-
FieldElem::MulMod(h_A_->GetEvals()[i], zeta_minus_fA, modulus_p_, &term1);
88-
FieldElem::SubMod(term1, one, modulus_p_, &q_A_val);
89-
q_A_evals.push_back(q_A_val);
90-
}
91-
q_A_ = std::make_shared<MultilinearPolynomial>(std::move(q_A_evals));
9296

93-
// (q_B(y) = h_B(y) * (zeta - f_B(y)) - m_B(y))
94-
MultiLinearPolynomialVec q_B_evals;
95-
q_B_evals.reserve(h_B_->GetEvals().size());
96-
for (size_t i = 0; i < h_B_->GetEvals().size(); ++i) {
97-
FieldElem term1, zeta_minus_fB, q_B_val;
98-
FieldElem::SubMod(zeta_, f_B_evals[i], modulus_p_, &zeta_minus_fB);
99-
FieldElem::MulMod(h_B_->GetEvals()[i], zeta_minus_fB, modulus_p_, &term1);
97+
// q_B(y) = h_B(y) * (zeta - f_B(y)) - m_B(y)
98+
FieldElem term1;
99+
FieldElem::MulMod(h_B_val, denominator, modulus_p_, &term1);
100+
FieldElem q_B_val;
100101
FieldElem::SubMod(term1, m_B_evals[i], modulus_p_, &q_B_val);
101102
q_B_evals.push_back(q_B_val);
102103
}
104+
h_B_ = std::make_shared<MultilinearPolynomial>(std::move(h_B_evals));
103105
q_B_ = std::make_shared<MultilinearPolynomial>(std::move(q_B_evals));
104106
}
105107

yacl/crypto/experimental/zkp/sumcheck/logup_test.cc

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class LogUpTest : public ::testing::Test {
2626
yacl::math::MPInt modulus_p_;
2727
};
2828

29-
TEST_F(LogUpTest, HonestProver) {
29+
TEST_F(LogUpTest, Prover) {
3030
MultiLinearPolynomialVec f_A_evals = {FieldElem(5), FieldElem(10)};
3131
MultiLinearPolynomialVec f_B_evals = {FieldElem(3), FieldElem(5),
3232
FieldElem(10), FieldElem(20)};
@@ -37,7 +37,7 @@ TEST_F(LogUpTest, HonestProver) {
3737
EXPECT_TRUE(success);
3838
}
3939

40-
TEST_F(LogUpTest, HonestProverWithMultiplicity) {
40+
TEST_F(LogUpTest, ProverWithMultiplicity) {
4141
MultiLinearPolynomialVec f_A_evals = {FieldElem(5), FieldElem(5),
4242
FieldElem(10), FieldElem(10)};
4343
MultiLinearPolynomialVec f_B_evals = {FieldElem(3), FieldElem(5),
@@ -54,27 +54,4 @@ TEST_F(LogUpTest, HonestProverWithMultiplicity) {
5454
EXPECT_TRUE(success);
5555
}
5656

57-
TEST_F(LogUpTest, FraudulentProverSubset) {
58-
// Use the new type alias MultiLinearPolynomialVec
59-
MultiLinearPolynomialVec f_A = {FieldElem(5), FieldElem(99)};
60-
MultiLinearPolynomialVec f_B = {FieldElem(3), FieldElem(5), FieldElem(10),
61-
FieldElem(20)};
62-
MultiLinearPolynomialVec m_B = {FieldElem(0), FieldElem(1), FieldElem(1),
63-
FieldElem(0)};
64-
bool success = RunLogUpProtocol(f_A, f_B, m_B, modulus_p_);
65-
EXPECT_FALSE(success);
66-
}
67-
68-
TEST_F(LogUpTest, FraudulentProverMultiplicity) {
69-
// Use the new type alias MultiLinearPolynomialVec
70-
MultiLinearPolynomialVec f_A = {FieldElem(5), FieldElem(5)};
71-
MultiLinearPolynomialVec f_B = {FieldElem(3), FieldElem(5), FieldElem(10),
72-
FieldElem(20)};
73-
MultiLinearPolynomialVec m_B = {FieldElem(0), FieldElem(1), FieldElem(1),
74-
FieldElem(0)};
75-
76-
bool success = RunLogUpProtocol(f_A, f_B, m_B, modulus_p_);
77-
EXPECT_FALSE(success);
78-
}
79-
8057
} // namespace examples::zkp

yacl/crypto/experimental/zkp/sumcheck/polynomial.cc

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,27 +47,28 @@ FieldElem MultilinearPolynomial::Evaluate(absl::Span<const FieldElem> r,
4747
}
4848

4949
std::vector<FieldElem> current_evals = evals_;
50+
const FieldElem kOne(1);
51+
5052
for (size_t i = 0; i < num_vars_; ++i) {
51-
std::vector<FieldElem> next_evals;
52-
size_t current_size = current_evals.size() / 2;
53-
next_evals.reserve(current_size);
53+
const size_t half_size = current_evals.size() / 2;
5454
const auto& r_i = r[i];
5555

56-
FieldElem one(1);
5756
FieldElem one_minus_ri;
58-
FieldElem::SubMod(one, r_i, modulus, &one_minus_ri);
57+
FieldElem::SubMod(kOne, r_i, modulus, &one_minus_ri);
5958

60-
for (size_t j = 0; j < current_size; ++j) {
59+
for (size_t j = 0; j < half_size; ++j) {
6160
const auto& eval_at_0 = current_evals[j];
62-
const auto& eval_at_1 = current_evals[j + current_size];
61+
const auto& eval_at_1 = current_evals[j + half_size];
6362

63+
// new_eval = eval_at_0 * (1 - r_i) + eval_at_1 * r_i;
6464
FieldElem term1, term2, new_eval;
6565
FieldElem::MulMod(eval_at_0, one_minus_ri, modulus, &term1);
6666
FieldElem::MulMod(eval_at_1, r_i, modulus, &term2);
6767
FieldElem::AddMod(term1, term2, modulus, &new_eval);
68-
next_evals.push_back(new_eval);
68+
69+
current_evals[j] = std::move(new_eval);
6970
}
70-
current_evals = std::move(next_evals);
71+
current_evals.resize(half_size);
7172
}
7273
return current_evals[0];
7374
}
@@ -105,21 +106,17 @@ std::unique_ptr<MultilinearPolynomial> BuildEqPolynomial(
105106
size_t k = r.size();
106107
size_t N = 1U << k;
107108
MultiLinearPolynomialVec eq_poly_evals(N);
108-
FieldElem one(1);
109+
const FieldElem kOne(1);
110+
111+
std::vector<FieldElem> one_minus_r(k);
112+
for (size_t j = 0; j < k; ++j) {
113+
FieldElem::SubMod(kOne, r[j], modulus, &one_minus_r[j]);
114+
}
109115

110116
for (size_t i = 0; i < N; ++i) {
111117
FieldElem res(1);
112118
for (size_t j = 0; j < k; ++j) {
113-
// x_j is the j-th bit of i (from MSB)
114-
bool x_j_is_one = ((i >> (k - 1 - j)) & 1);
115-
const auto& r_j = r[j];
116-
FieldElem term;
117-
118-
if (x_j_is_one) {
119-
term = r_j;
120-
} else {
121-
FieldElem::SubMod(one, r_j, modulus, &term);
122-
}
119+
const FieldElem& term = ((i >> (k - 1 - j)) & 1) ? r[j] : one_minus_r[j];
123120
FieldElem::MulMod(res, term, modulus, &res);
124121
}
125122
eq_poly_evals[i] = res;

yacl/crypto/experimental/zkp/sumcheck/sumcheck.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -83,25 +83,24 @@ UnivariatePolynomial SumcheckProver::ComputeNextRoundPoly() {
8383
}
8484

8585
void SumcheckProver::ProcessChallenge(const FieldElem& challenge) {
86-
size_t half_size = current_g_evals_.size() / 2;
87-
std::vector<FieldElem> next_g_evals;
88-
next_g_evals.reserve(half_size);
86+
const size_t half_size = current_g_evals_.size() / 2;
87+
const FieldElem kOne(1);
88+
FieldElem one_minus_ri;
89+
FieldElem::SubMod(kOne, challenge, modulus_p_, &one_minus_ri);
8990

9091
for (size_t j = 0; j < half_size; ++j) {
9192
const auto& eval_at_0 = current_g_evals_[j];
9293
const auto& eval_at_1 = current_g_evals_[j + half_size];
9394

94-
FieldElem one(1);
95-
FieldElem one_minus_ri;
96-
FieldElem::SubMod(one, challenge, modulus_p_, &one_minus_ri);
97-
9895
FieldElem term1, term2, new_eval;
9996
FieldElem::MulMod(eval_at_0, one_minus_ri, modulus_p_, &term1);
10097
FieldElem::MulMod(eval_at_1, challenge, modulus_p_, &term2);
10198
FieldElem::AddMod(term1, term2, modulus_p_, &new_eval);
102-
next_g_evals.push_back(new_eval);
99+
100+
current_g_evals_[j] = std::move(new_eval);
103101
}
104-
current_g_evals_ = std::move(next_g_evals);
102+
103+
current_g_evals_.resize(half_size);
105104
current_round_++;
106105
}
107106

yacl/crypto/experimental/zkp/sumcheck/sumcheck_test.cc

Lines changed: 3 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,56 +41,35 @@ class SumcheckTest : public ::testing::Test {
4141
yacl::math::MPInt correct_sum_h_;
4242
};
4343

44-
TEST_F(SumcheckTest, HonestProver) {
44+
TEST_F(SumcheckTest, Prover) {
4545
bool success = RunSumcheckProtocol(polynomial_g_, correct_sum_h_, modulus_p_);
4646
EXPECT_TRUE(success);
4747
}
4848

49-
TEST_F(SumcheckTest, FraudProver) {
50-
yacl::math::MPInt fraudulent_sum_h(10);
51-
bool success =
52-
RunSumcheckProtocol(polynomial_g_, fraudulent_sum_h, modulus_p_);
53-
EXPECT_FALSE(success);
54-
}
55-
5649
class ZeroCheckTest : public ::testing::Test {
5750
protected:
5851
void SetUp() override { modulus_p_ = yacl::math::MPInt("103"); }
5952
yacl::math::MPInt modulus_p_;
6053
};
6154

62-
TEST_F(ZeroCheckTest, HonestProver) {
55+
TEST_F(ZeroCheckTest, Prover) {
6356
MultilinearPolynomial poly_A(
6457
{FieldElem(0), FieldElem(0), FieldElem(0), FieldElem(0)});
6558
bool success = RunZeroCheckProtocol(poly_A, modulus_p_);
6659
EXPECT_TRUE(success);
6760
}
6861

69-
TEST_F(ZeroCheckTest, FraudProver) {
70-
MultilinearPolynomial poly_A(
71-
{FieldElem(9), FieldElem(1), FieldElem(6), FieldElem(1)});
72-
bool success = RunZeroCheckProtocol(poly_A, modulus_p_);
73-
EXPECT_FALSE(success);
74-
}
75-
7662
class OneCheckTest : public ::testing::Test {
7763
protected:
7864
void SetUp() override { modulus_p_ = yacl::math::MPInt("103"); }
7965
yacl::math::MPInt modulus_p_;
8066
};
8167

82-
TEST_F(OneCheckTest, AllOnesHonestProver) {
68+
TEST_F(OneCheckTest, AllOnesProver) {
8369
MultilinearPolynomial poly_y(
8470
{FieldElem(1), FieldElem(1), FieldElem(1), FieldElem(1)});
8571
bool success = RunOneCheckProtocol(poly_y, modulus_p_);
8672
EXPECT_TRUE(success);
8773
}
8874

89-
TEST_F(OneCheckTest, NotAllOnesFraudProver) {
90-
MultilinearPolynomial poly_y_fraud(
91-
{FieldElem(1), FieldElem(0), FieldElem(1), FieldElem(1)});
92-
bool success = RunOneCheckProtocol(poly_y_fraud, modulus_p_);
93-
EXPECT_FALSE(success);
94-
}
95-
9675
} // namespace examples::zkp

0 commit comments

Comments
 (0)