diff --git a/spartan_parallel/src/r1csproof.rs b/spartan_parallel/src/r1csproof.rs index 05a52d94..a55a47dd 100644 --- a/spartan_parallel/src/r1csproof.rs +++ b/spartan_parallel/src/r1csproof.rs @@ -13,12 +13,17 @@ use crate::{ProverWitnessSecInfo, VerifierWitnessSecInfo}; use merlin::Transcript; use serde::{Deserialize, Serialize}; use std::cmp::min; +use std::iter::zip; #[derive(Serialize, Deserialize, Debug)] pub struct R1CSProof { sc_proof_phase1: SumcheckInstanceProof, sc_proof_phase2: SumcheckInstanceProof, claims_phase2: (S, S, S), + // Need to commit vars for short and long witnesses separately + // The long version must exist, the short version might not + eval_vars_at_ry_list: Vec>, + eval_vars_at_ry: S, // proof_eval_vars_at_ry_list: Vec>, } @@ -255,8 +260,9 @@ impl R1CSProof { &poly_Cz.index(0, 0, 0, 0), ); - // prove the final step of sum-check #1 - let _taus_bound_rx = tau_claim; + S::append_field_to_transcript(b"Az_claim", transcript, *Az_claim); + S::append_field_to_transcript(b"Bz_claim", transcript, *Bz_claim); + S::append_field_to_transcript(b"Cz_claim", transcript, *Cz_claim); // Separate the result rx into rp, rq, and rx let (rx_rev, rq_rev) = rx.split_at(num_rounds_x); @@ -376,7 +382,7 @@ impl R1CSProof { let mut Zr_list = Vec::new(); // List of evaluations separated by witness_secs let mut eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs]; - + let mut raw_eval_vars_at_ry_list = vec![Vec::new(); num_witness_secs]; // Does not multiply ry_factor for i in 0..num_witness_secs { let w = witness_secs[i]; let wit_sec_num_instance = w.w_mat.len(); @@ -409,9 +415,8 @@ impl R1CSProof { } else { eval_vars_at_ry_list[i] .push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]); - eval_vars_at_ry_list[i] - .push(eval_vars_at_ry * ry_factors[num_rounds_y - w.num_inputs[p].log_2()]); } + raw_eval_vars_at_ry_list[i].push(eval_vars_at_ry); } } @@ -434,13 +439,6 @@ impl R1CSProof { // So we need to multiply each entry by (1 - rq0)(1 - rq1) let mut eval_vars_comb_list = Vec::new(); for p in 0..num_instances { - let _wit_sec_p = |i: usize| { - if witness_secs[i].w_mat.len() == 1 { - 0 - } else { - p - } - }; let wit_sec_p = |i: usize| { if witness_secs[i].w_mat.len() == 1 { 0 @@ -490,7 +488,7 @@ impl R1CSProof { timer_polyeval.stop(); let poly_vars = DensePolynomial::new(eval_vars_comb_list); - let _eval_vars_at_ry = poly_vars.evaluate(&rp); + let eval_vars_at_ry = poly_vars.evaluate(&rp); timer_prove.stop(); @@ -499,6 +497,8 @@ impl R1CSProof { sc_proof_phase1, sc_proof_phase2, claims_phase2: (*Az_claim, *Bz_claim, *Cz_claim), + eval_vars_at_ry_list: raw_eval_vars_at_ry_list, + eval_vars_at_ry, // proof_eval_vars_at_ry_list, }, [rp, rq_rev, rx, [rw, ry].concat()], @@ -509,7 +509,7 @@ impl R1CSProof { &self, num_instances: usize, max_num_proofs: usize, - _num_proofs: &Vec, + num_proofs: &Vec, max_num_inputs: usize, // NUM_WITNESS_SECS @@ -525,7 +525,7 @@ impl R1CSProof { witness_secs: Vec<&VerifierWitnessSecInfo>, num_cons: usize, - _evals: &[S; 3], + evals: &[S; 3], transcript: &mut Transcript, ) -> Result<[Vec; 4], ProofVerifyError> { >::append_protocol_name( @@ -551,7 +551,7 @@ impl R1CSProof { let tau_q = transcript.challenge_vector(b"challenge_tau_q", num_rounds_q); let tau_x = transcript.challenge_vector(b"challenge_tau_x", num_rounds_x); - let (_, rx) = self.sc_proof_phase1.verify( + let (claim_post_phase_1, rx) = self.sc_proof_phase1.verify( S::field_zero(), num_rounds_x + num_rounds_q + num_rounds_p, 3, @@ -563,7 +563,7 @@ impl R1CSProof { let (rq_rev, rp_round1) = rq_rev.split_at(num_rounds_q); let rx: Vec = rx_rev.iter().copied().rev().collect(); let rq_rev = rq_rev.to_vec(); - let _rq: Vec = rq_rev.iter().copied().rev().collect(); + let rq: Vec = rq_rev.iter().copied().rev().collect(); let rp_round1 = rp_round1.to_vec(); // taus_bound_rx is really taus_bound_rx_rq_rp @@ -578,7 +578,16 @@ impl R1CSProof { let taus_bound_rx: S = (0..rx_rev.len()) .map(|i| rx_rev[i] * tau_x[i] + (S::field_one() - rx_rev[i]) * (S::field_one() - tau_x[i])) .product(); - let _taus_bound_rx = taus_bound_rp * taus_bound_rq * taus_bound_rx; + let taus_bound_rx = taus_bound_rp * taus_bound_rq * taus_bound_rx; + + // perform the intermediate sum-check test with claimed Az, Bz, and Cz + let (Az_claim, Bz_claim, Cz_claim) = self.claims_phase2; + S::append_field_to_transcript(b"Az_claim", transcript, Az_claim); + S::append_field_to_transcript(b"Bz_claim", transcript, Bz_claim); + S::append_field_to_transcript(b"Cz_claim", transcript, Cz_claim); + + // debug_zk + // assert_eq!(taus_bound_rx * (Az_claim * Bz_claim - Cz_claim), claim_post_phase_1); // derive three public challenges and then derive a joint claim let r_A: S = transcript.challenge_scalar(b"challenge_Az"); @@ -589,7 +598,7 @@ impl R1CSProof { let claim_phase2 = r_A * Az_claim + r_B * Bz_claim + r_C * Cz_claim; // verify the joint claim with a sum-check protocol - let (_, ry) = self.sc_proof_phase2.verify( + let (claim_post_phase_2, ry) = self.sc_proof_phase2.verify( claim_phase2, num_rounds_y + num_rounds_w + num_rounds_p, 3, @@ -604,7 +613,7 @@ impl R1CSProof { let ry: Vec = ry_rev.iter().copied().rev().collect(); // An Eq function to match p with rp - let _p_rp_poly_bound_ry: S = (0..rp.len()) + let p_rp_poly_bound_ry: S = (0..rp.len()) .map(|i| rp[i] * rp_round1[i] + (S::field_one() - rp[i]) * (S::field_one() - rp_round1[i])) .product(); @@ -621,13 +630,14 @@ impl R1CSProof { let timer_commit_opening = Timer::new("verify_sc_commitment_opening"); let mut num_proofs_list = Vec::new(); let mut num_inputs_list = Vec::new(); - + let mut eval_Zr_list = Vec::new(); for i in 0..num_witness_secs { let w = witness_secs[i]; let wit_sec_num_instance = w.num_proofs.len(); for p in 0..wit_sec_num_instance { num_proofs_list.push(w.num_proofs[p]); num_inputs_list.push(w.num_inputs[p]); + eval_Zr_list.push(self.eval_vars_at_ry_list[i][p]); } } @@ -643,15 +653,24 @@ impl R1CSProof { */ // Then on rp + let mut expected_eval_vars_list = Vec::new(); for p in 0..num_instances { - let _wit_sec_p = |i: usize| { + let wit_sec_p = |i: usize| { if witness_secs[i].num_proofs.len() == 1 { 0 } else { p } }; - let _prefix_list = match num_witness_secs.next_power_of_two() { + let c = |i: usize| { + if witness_secs[i].num_inputs[wit_sec_p(i)] >= max_num_inputs { + self.eval_vars_at_ry_list[i][wit_sec_p(i)] + } else { + self.eval_vars_at_ry_list[i][wit_sec_p(i)] + * ry_factors[num_rounds_y - witness_secs[i].num_inputs[wit_sec_p(i)].log_2()] + } + }; + let prefix_list = match num_witness_secs.next_power_of_two() { 1 => { vec![S::field_one()] } @@ -682,10 +701,33 @@ impl R1CSProof { panic!("Unsupported num_witness_secs: {}", num_witness_secs); } }; + let mut eval_vars_comb = + (1..num_witness_secs).fold(prefix_list[0] * c(0), |s, i| s + prefix_list[i] * c(i)); + for q in 0..(num_rounds_q - num_proofs[p].log_2()) { + eval_vars_comb *= S::field_one() - rq[q]; + } + expected_eval_vars_list.push(eval_vars_comb); } + let EQ_p = &EqPolynomial::new(rp.clone()).evals()[..num_instances]; + let expected_eval_vars_at_ry = + zip(EQ_p, expected_eval_vars_list).fold(S::field_zero(), |s, (a, b)| s + *a * b); + + assert_eq!(expected_eval_vars_at_ry, self.eval_vars_at_ry); + timer_commit_opening.stop(); + // compute commitment to eval_Z_at_ry = (ONE - ry[0]) * self.eval_vars_at_ry + ry[0] * poly_input_eval + let eval_Z_at_ry = &self.eval_vars_at_ry; + + // perform the final check in the second sum-check protocol + let [eval_A_r, eval_B_r, eval_C_r] = evals; + let expected_claim_post_phase2 = + (r_A * *eval_A_r + r_B * *eval_B_r + r_C * *eval_C_r) * *eval_Z_at_ry * p_rp_poly_bound_ry; + + // verify proof that expected_claim_post_phase2 == claim_post_phase2 + assert_eq!(claim_post_phase_2, expected_claim_post_phase2); + Ok([rp, rq_rev, rx, [rw, ry].concat()]) } }