From 6aea34c557e8f2e715f20b165a2a58990877a08a Mon Sep 17 00:00:00 2001 From: Kunming Jiang Date: Fri, 3 Jan 2025 12:05:37 -0500 Subject: [PATCH] Experimental improvements on block_witness_gen --- spartan_parallel/src/lib.rs | 221 +++++++++++++++++++++--------------- 1 file changed, 129 insertions(+), 92 deletions(-) diff --git a/spartan_parallel/src/lib.rs b/spartan_parallel/src/lib.rs index b23320c9..e1e54e7c 100644 --- a/spartan_parallel/src/lib.rs +++ b/spartan_parallel/src/lib.rs @@ -38,7 +38,7 @@ mod unipoly; use std::{ cmp::{max, Ordering}, fs::File, - io::Write, + io::Write, iter::zip, }; use dense_mlpoly::{DensePolynomial, PolyEvalProof}; @@ -50,6 +50,7 @@ use merlin::Transcript; use r1csinstance::{R1CSCommitment, R1CSDecommitment, R1CSEvalProof, R1CSInstance}; use r1csproof::R1CSProof; use random::RandomTape; +use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator}; use scalar::SpartanExtensionField; use serde::{Deserialize, Serialize}; use timer::Timer; @@ -1271,37 +1272,60 @@ impl SNARK { // w2 is _, _, ZO, r * i1, r^2 * i2, r^3 * i3, ... // where ZO * r^n = r^n * o0 + r^(n + 1) * o1, ..., // are used by the consistency check - let perm_exec_w2 = { - let mut perm_exec_w2: Vec> = exec_inputs_list - .iter() - .map(|input| { - [ - vec![S::field_zero(); 3], - (1..2 * num_inputs_unpadded - 2) - .map(|j| perm_w0[j] * input[j + 2]) - .collect(), - vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded], - ] - .concat() - }) - .collect(); - for q in 0..consis_num_proofs { - perm_exec_w2[q][0] = exec_inputs_list[q][0]; - perm_exec_w2[q][1] = exec_inputs_list[q][0]; + let (perm_exec_w2, perm_exec_w3) = { + let perm_exec_w2: Vec>; + let mut perm_exec_w3: Vec>; + // Entries that do not depend on others can be generated in parallel + (perm_exec_w2, perm_exec_w3) = (0..consis_num_proofs).into_par_iter().map(|q| { + // perm_exec_w2 + let mut perm_exec_w2_q = [ + vec![S::field_zero(); 3], + (1..2 * num_inputs_unpadded - 2) + .map(|j| perm_w0[j] * exec_inputs_list[q][j + 2]) + .collect(), + vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded], + ].concat(); + perm_exec_w2_q[0] = exec_inputs_list[q][0]; + perm_exec_w2_q[1] = exec_inputs_list[q][0]; for i in 0..num_inputs_unpadded - 1 { let perm = if i == 0 { S::field_one() } else { perm_w0[i] }; - perm_exec_w2[q][0] = perm_exec_w2[q][0] + perm * exec_inputs_list[q][2 + i]; - perm_exec_w2[q][2] = - perm_exec_w2[q][2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i]; + perm_exec_w2_q[0] = perm_exec_w2_q[0] + perm * exec_inputs_list[q][2 + i]; + perm_exec_w2_q[2] = + perm_exec_w2_q[2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i]; } - perm_exec_w2[q][0] = perm_exec_w2[q][0] * exec_inputs_list[q][0]; - let ZO = perm_exec_w2[q][2]; - perm_exec_w2[q][1] = perm_exec_w2[q][1] + ZO; - perm_exec_w2[q][1] = perm_exec_w2[q][1] * exec_inputs_list[q][0]; + perm_exec_w2_q[0] = perm_exec_w2_q[0] * exec_inputs_list[q][0]; + let ZO = perm_exec_w2_q[2]; + perm_exec_w2_q[1] = perm_exec_w2_q[1] + ZO; + perm_exec_w2_q[1] = perm_exec_w2_q[1] * exec_inputs_list[q][0]; + + // perm_exec_w3 + let mut perm_exec_w3_q = vec![S::field_zero(); 8]; + perm_exec_w3_q[0] = exec_inputs_list[q][0]; + perm_exec_w3_q[1] = perm_exec_w3_q[0] + * (comb_tau + - perm_exec_w2_q[3..] + .iter() + .fold(S::field_zero(), |a, b| a + *b) + - exec_inputs_list[q][2]); + perm_exec_w3_q[4] = perm_exec_w2_q[0]; + perm_exec_w3_q[5] = perm_exec_w2_q[1]; + + (perm_exec_w2_q, perm_exec_w3_q) + }).unzip(); + // Generate sequential entries separately + for q in (0..consis_num_proofs).rev() { + if q != consis_num_proofs - 1 { + perm_exec_w3[q][3] = perm_exec_w3[q][1] + * (perm_exec_w3[q + 1][2] + S::field_one() - perm_exec_w3[q + 1][0]); + } else { + perm_exec_w3[q][3] = perm_exec_w3[q][1]; + } + perm_exec_w3[q][2] = perm_exec_w3[q][0] * perm_exec_w3[q][3]; } - perm_exec_w2 + (perm_exec_w2, perm_exec_w3) }; - // w3 is [v, x, pi, D] + + /* let perm_exec_w3 = { let mut perm_exec_w3: Vec> = vec![Vec::new(); consis_num_proofs]; for q in (0..consis_num_proofs).rev() { @@ -1325,6 +1349,7 @@ impl SNARK { } perm_exec_w3 }; + */ // commit the witnesses and inputs separately instance-by-instance let (perm_exec_poly_w2, perm_exec_poly_w3, perm_exec_poly_w3_shifted) = { let perm_exec_poly_w2 = { @@ -1376,7 +1401,7 @@ impl SNARK { // w3 is [v, x, pi, D, pi, D, pi, D] let mut block_w3: Vec>> = Vec::new(); let block_w2_prover = { - let mut block_w2 = Vec::new(); + let mut block_w2: Vec>> = Vec::new(); let block_w2_size_list: Vec = (0..block_num_instances) .map(|i| { (2 * num_inputs_unpadded + 2 * block_num_phy_ops[i] + 4 * block_num_vir_ops[i]) @@ -1408,113 +1433,125 @@ impl SNARK { |b: usize, i: usize| 2 * num_inputs_unpadded + 2 * block_num_phy_ops[b] + 4 * i + 3; for p in 0..block_num_instances { - block_w2.push(vec![Vec::new(); block_num_proofs[p]]); - block_w3.push(vec![Vec::new(); block_num_proofs[p]]); - for q in (0..block_num_proofs[p]).rev() { + let block_w2_p: Vec>; + let mut block_w3_p: Vec>; + // Entries that do not depend on others can be generated in parallel + (block_w2_p, block_w3_p) = (0..block_num_proofs[p]).into_par_iter().map(|q| { let V_CNST = block_vars_mat[p][q][0]; // For INPUT - block_w2[p][q] = vec![S::field_zero(); block_w2_size_list[p]]; + let mut q2 = vec![S::field_zero(); block_w2_size_list[p]]; - block_w2[p][q][0] = block_vars_mat[p][q][0]; - block_w2[p][q][1] = block_vars_mat[p][q][0]; + q2[0] = block_vars_mat[p][q][0]; + q2[1] = block_vars_mat[p][q][0]; for i in 1..2 * (num_inputs_unpadded - 1) { - block_w2[p][q][2 + i] = - block_w2[p][q][2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2]; + q2[2 + i] = q2[2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2]; } for i in 0..num_inputs_unpadded - 1 { let perm = if i == 0 { S::field_one() } else { perm_w0[i] }; - block_w2[p][q][0] = block_w2[p][q][0] + perm * block_vars_mat[p][q][2 + i]; - block_w2[p][q][2] = - block_w2[p][q][2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i]; - } - block_w2[p][q][0] = block_w2[p][q][0] * block_vars_mat[p][q][0]; - let ZO = block_w2[p][q][2]; - block_w2[p][q][1] = block_w2[p][q][1] + ZO; - block_w2[p][q][1] = block_w2[p][q][1] * block_vars_mat[p][q][0]; - block_w3[p][q] = vec![S::field_zero(); 8]; - block_w3[p][q][0] = block_vars_mat[p][q][0]; - block_w3[p][q][1] = block_w3[p][q][0] - * (comb_tau - - block_w2[p][q][3..] - .iter() - .fold(S::field_zero(), |a, b| a + *b) - - block_vars_mat[p][q][2]); - if q != block_num_proofs[p] - 1 { - block_w3[p][q][3] = block_w3[p][q][1] - * (block_w3[p][q + 1][2] + S::field_one() - block_w3[p][q + 1][0]); - } else { - block_w3[p][q][3] = block_w3[p][q][1]; + q2[0] = q2[0] + perm * block_vars_mat[p][q][2 + i]; + q2[2] = q2[2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i]; } - block_w3[p][q][2] = block_w3[p][q][0] * block_w3[p][q][3]; + q2[0] = q2[0] * block_vars_mat[p][q][0]; + let ZO = q2[2]; + q2[1] = q2[1] + ZO; + q2[1] = q2[1] * block_vars_mat[p][q][0]; + let mut q3 = vec![S::field_zero(); 8]; + q3[0] = block_vars_mat[p][q][0]; // For PHY // Compute PMR, PMC for i in 0..block_num_phy_ops[p] { // PMR = r * PD - block_w2[p][q][V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)]; + q2[V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)]; // PMC = (1 or PMC[i-1]) * (tau - PA - PMR) let t = if i == 0 { V_CNST } else { - block_w2[p][q][V_PMC(i - 1)] + q2[V_PMC(i - 1)] }; - block_w2[p][q][V_PMC(i)] = t - * (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - block_w2[p][q][V_PMR(i)]); + q2[V_PMC(i)] = t + * (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - q2[V_PMR(i)]); } - // Compute x - let px = if block_num_phy_ops[p] == 0 { - V_CNST - } else { - block_w2[p][q][V_PMC(block_num_phy_ops[p] - 1)] - }; - // Compute D and pi - if q != block_num_proofs[p] - 1 { - block_w3[p][q][5] = - px * (block_w3[p][q + 1][4] + S::field_one() - block_w3[p][q + 1][0]); - } else { - block_w3[p][q][5] = px; - } - block_w3[p][q][4] = V_CNST * block_w3[p][q][5]; // For VIR // Compute VMR1, VMR2, VMR3, VMC for i in 0..block_num_vir_ops[p] { // VMR1 = r * VD - block_w2[p][q][V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)]; + q2[V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)]; // VMR2 = r^2 * VL - block_w2[p][q][V_VMR2(p, i)] = + q2[V_VMR2(p, i)] = comb_r * comb_r * block_vars_mat[p][q][io_width + V_VL(p, i)]; // VMR1 = r^3 * VT - block_w2[p][q][V_VMR3(p, i)] = + q2[V_VMR3(p, i)] = comb_r * comb_r * comb_r * block_vars_mat[p][q][io_width + V_VT(p, i)]; // VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3) let t = if i == 0 { V_CNST } else { - block_w2[p][q][V_VMC(p, i - 1)] + q2[V_VMC(p, i - 1)] }; - block_w2[p][q][V_VMC(p, i)] = t + q2[V_VMC(p, i)] = t * (comb_tau - block_vars_mat[p][q][io_width + V_VA(p, i)] - - block_w2[p][q][V_VMR1(p, i)] - - block_w2[p][q][V_VMR2(p, i)] - - block_w2[p][q][V_VMR3(p, i)]); + - q2[V_VMR1(p, i)] + - q2[V_VMR2(p, i)] + - q2[V_VMR3(p, i)]); + } + (q2, q3) + }).unzip(); + // Generate sequential entries separately + for q in (0..block_num_proofs[p]).rev() { + let V_CNST = block_vars_mat[p][q][0]; + // For INPUT + block_w3_p[q][1] = block_w3_p[q][0] + * (comb_tau + - block_w2_p[q][3..] + .iter() + .fold(S::field_zero(), |a, b| a + *b) + - block_vars_mat[p][q][2]); + if q != block_num_proofs[p] - 1 { + block_w3_p[q][3] = block_w3_p[q][1] + * (block_w3_p[q + 1][2] + S::field_one() - block_w3_p[q + 1][0]); + } else { + block_w3_p[q][3] = block_w3_p[q][1]; } + block_w3_p[q][2] = block_w3_p[q][0] * block_w3_p[q][3]; + + // For PHY + // Compute x + let px = if block_num_phy_ops[p] == 0 { + V_CNST + } else { + block_w2_p[q][V_PMC(block_num_phy_ops[p] - 1)] + }; + // Compute D and pi + if q != block_num_proofs[p] - 1 { + block_w3_p[q][5] = + px * (block_w3_p[q + 1][4] + S::field_one() - block_w3_p[q + 1][0]); + } else { + block_w3_p[q][5] = px; + } + block_w3_p[q][4] = V_CNST * block_w3_p[q][5]; + + // For VIR // Compute x let vx = if block_num_vir_ops[p] == 0 { V_CNST } else { - block_w2[p][q][V_VMC(p, block_num_vir_ops[p] - 1)] + block_w2_p[q][V_VMC(p, block_num_vir_ops[p] - 1)] }; // Compute D and pi if q != block_num_proofs[p] - 1 { - block_w3[p][q][7] = - vx * (block_w3[p][q + 1][6] + S::field_one() - block_w3[p][q + 1][0]); + block_w3_p[q][7] = + vx * (block_w3_p[q + 1][6] + S::field_one() - block_w3_p[q + 1][0]); } else { - block_w3[p][q][7] = vx; + block_w3_p[q][7] = vx; } - block_w3[p][q][6] = V_CNST * block_w3[p][q][7]; + block_w3_p[q][6] = V_CNST * block_w3_p[q][7]; } + + block_w2.push(block_w2_p); + block_w3.push(block_w3_p); } // commit the witnesses and inputs separately instance-by-instance @@ -1574,14 +1611,13 @@ impl SNARK { let perm_w0_prover = ProverWitnessSecInfo::new(vec![vec![perm_w0]], vec![perm_poly_w0]); let perm_exec_w2_prover = ProverWitnessSecInfo::new(vec![perm_exec_w2], vec![perm_exec_poly_w2]); - let perm_exec_w3_prover = - ProverWitnessSecInfo::new(vec![perm_exec_w3.clone()], vec![perm_exec_poly_w3]); let perm_exec_w3_shifted_prover = ProverWitnessSecInfo::new( vec![[perm_exec_w3[1..].to_vec(), vec![vec![S::field_zero(); 8]]].concat()], vec![perm_exec_poly_w3_shifted], ); + let perm_exec_w3_prover = + ProverWitnessSecInfo::new(vec![perm_exec_w3], vec![perm_exec_poly_w3]); - let block_w3_prover = ProverWitnessSecInfo::new(block_w3.clone(), block_poly_w3_list); let block_w3_shifted_prover = ProverWitnessSecInfo::new( block_w3 .iter() @@ -1589,6 +1625,7 @@ impl SNARK { .collect(), block_poly_w3_list_shifted, ); + let block_w3_prover = ProverWitnessSecInfo::new(block_w3, block_poly_w3_list); ( comb_tau,