Skip to content

Commit

Permalink
Experimental improvements on block_witness_gen
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Jan 3, 2025
1 parent ec8458a commit 6aea34c
Showing 1 changed file with 129 additions and 92 deletions.
221 changes: 129 additions & 92 deletions spartan_parallel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ mod unipoly;
use std::{
cmp::{max, Ordering},
fs::File,
io::Write,
io::Write, iter::zip,
};

use dense_mlpoly::{DensePolynomial, PolyEvalProof};
Expand All @@ -50,6 +50,7 @@ use merlin::Transcript;
use r1csinstance::{R1CSCommitment, R1CSDecommitment, R1CSEvalProof, R1CSInstance};
use r1csproof::R1CSProof;
use random::RandomTape;
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
use scalar::SpartanExtensionField;
use serde::{Deserialize, Serialize};
use timer::Timer;
Expand Down Expand Up @@ -1271,37 +1272,60 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
// w2 is _, _, ZO, r * i1, r^2 * i2, r^3 * i3, ...
// where ZO * r^n = r^n * o0 + r^(n + 1) * o1, ...,
// are used by the consistency check
let perm_exec_w2 = {
let mut perm_exec_w2: Vec<Vec<S>> = exec_inputs_list
.iter()
.map(|input| {
[
vec![S::field_zero(); 3],
(1..2 * num_inputs_unpadded - 2)
.map(|j| perm_w0[j] * input[j + 2])
.collect(),
vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded],
]
.concat()
})
.collect();
for q in 0..consis_num_proofs {
perm_exec_w2[q][0] = exec_inputs_list[q][0];
perm_exec_w2[q][1] = exec_inputs_list[q][0];
let (perm_exec_w2, perm_exec_w3) = {
let perm_exec_w2: Vec<Vec<S>>;
let mut perm_exec_w3: Vec<Vec<S>>;
// Entries that do not depend on others can be generated in parallel
(perm_exec_w2, perm_exec_w3) = (0..consis_num_proofs).into_par_iter().map(|q| {
// perm_exec_w2
let mut perm_exec_w2_q = [
vec![S::field_zero(); 3],
(1..2 * num_inputs_unpadded - 2)
.map(|j| perm_w0[j] * exec_inputs_list[q][j + 2])
.collect(),
vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded],
].concat();
perm_exec_w2_q[0] = exec_inputs_list[q][0];
perm_exec_w2_q[1] = exec_inputs_list[q][0];
for i in 0..num_inputs_unpadded - 1 {
let perm = if i == 0 { S::field_one() } else { perm_w0[i] };
perm_exec_w2[q][0] = perm_exec_w2[q][0] + perm * exec_inputs_list[q][2 + i];
perm_exec_w2[q][2] =
perm_exec_w2[q][2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i];
perm_exec_w2_q[0] = perm_exec_w2_q[0] + perm * exec_inputs_list[q][2 + i];
perm_exec_w2_q[2] =
perm_exec_w2_q[2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i];
}
perm_exec_w2[q][0] = perm_exec_w2[q][0] * exec_inputs_list[q][0];
let ZO = perm_exec_w2[q][2];
perm_exec_w2[q][1] = perm_exec_w2[q][1] + ZO;
perm_exec_w2[q][1] = perm_exec_w2[q][1] * exec_inputs_list[q][0];
perm_exec_w2_q[0] = perm_exec_w2_q[0] * exec_inputs_list[q][0];
let ZO = perm_exec_w2_q[2];
perm_exec_w2_q[1] = perm_exec_w2_q[1] + ZO;
perm_exec_w2_q[1] = perm_exec_w2_q[1] * exec_inputs_list[q][0];

// perm_exec_w3
let mut perm_exec_w3_q = vec![S::field_zero(); 8];
perm_exec_w3_q[0] = exec_inputs_list[q][0];
perm_exec_w3_q[1] = perm_exec_w3_q[0]
* (comb_tau
- perm_exec_w2_q[3..]
.iter()
.fold(S::field_zero(), |a, b| a + *b)
- exec_inputs_list[q][2]);
perm_exec_w3_q[4] = perm_exec_w2_q[0];
perm_exec_w3_q[5] = perm_exec_w2_q[1];

(perm_exec_w2_q, perm_exec_w3_q)
}).unzip();
// Generate sequential entries separately
for q in (0..consis_num_proofs).rev() {
if q != consis_num_proofs - 1 {
perm_exec_w3[q][3] = perm_exec_w3[q][1]
* (perm_exec_w3[q + 1][2] + S::field_one() - perm_exec_w3[q + 1][0]);
} else {
perm_exec_w3[q][3] = perm_exec_w3[q][1];
}
perm_exec_w3[q][2] = perm_exec_w3[q][0] * perm_exec_w3[q][3];
}
perm_exec_w2
(perm_exec_w2, perm_exec_w3)
};
// w3 is [v, x, pi, D]

/*
let perm_exec_w3 = {
let mut perm_exec_w3: Vec<Vec<S>> = vec![Vec::new(); consis_num_proofs];
for q in (0..consis_num_proofs).rev() {
Expand All @@ -1325,6 +1349,7 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
}
perm_exec_w3
};
*/
// commit the witnesses and inputs separately instance-by-instance
let (perm_exec_poly_w2, perm_exec_poly_w3, perm_exec_poly_w3_shifted) = {
let perm_exec_poly_w2 = {
Expand Down Expand Up @@ -1376,7 +1401,7 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
// w3 is [v, x, pi, D, pi, D, pi, D]
let mut block_w3: Vec<Vec<Vec<S>>> = Vec::new();
let block_w2_prover = {
let mut block_w2 = Vec::new();
let mut block_w2: Vec<Vec<Vec<S>>> = Vec::new();
let block_w2_size_list: Vec<usize> = (0..block_num_instances)
.map(|i| {
(2 * num_inputs_unpadded + 2 * block_num_phy_ops[i] + 4 * block_num_vir_ops[i])
Expand Down Expand Up @@ -1408,113 +1433,125 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
|b: usize, i: usize| 2 * num_inputs_unpadded + 2 * block_num_phy_ops[b] + 4 * i + 3;

for p in 0..block_num_instances {
block_w2.push(vec![Vec::new(); block_num_proofs[p]]);
block_w3.push(vec![Vec::new(); block_num_proofs[p]]);
for q in (0..block_num_proofs[p]).rev() {
let block_w2_p: Vec<Vec<S>>;
let mut block_w3_p: Vec<Vec<S>>;
// Entries that do not depend on others can be generated in parallel
(block_w2_p, block_w3_p) = (0..block_num_proofs[p]).into_par_iter().map(|q| {
let V_CNST = block_vars_mat[p][q][0];
// For INPUT
block_w2[p][q] = vec![S::field_zero(); block_w2_size_list[p]];
let mut q2 = vec![S::field_zero(); block_w2_size_list[p]];

block_w2[p][q][0] = block_vars_mat[p][q][0];
block_w2[p][q][1] = block_vars_mat[p][q][0];
q2[0] = block_vars_mat[p][q][0];
q2[1] = block_vars_mat[p][q][0];
for i in 1..2 * (num_inputs_unpadded - 1) {
block_w2[p][q][2 + i] =
block_w2[p][q][2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2];
q2[2 + i] = q2[2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2];
}
for i in 0..num_inputs_unpadded - 1 {
let perm = if i == 0 { S::field_one() } else { perm_w0[i] };
block_w2[p][q][0] = block_w2[p][q][0] + perm * block_vars_mat[p][q][2 + i];
block_w2[p][q][2] =
block_w2[p][q][2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i];
}
block_w2[p][q][0] = block_w2[p][q][0] * block_vars_mat[p][q][0];
let ZO = block_w2[p][q][2];
block_w2[p][q][1] = block_w2[p][q][1] + ZO;
block_w2[p][q][1] = block_w2[p][q][1] * block_vars_mat[p][q][0];
block_w3[p][q] = vec![S::field_zero(); 8];
block_w3[p][q][0] = block_vars_mat[p][q][0];
block_w3[p][q][1] = block_w3[p][q][0]
* (comb_tau
- block_w2[p][q][3..]
.iter()
.fold(S::field_zero(), |a, b| a + *b)
- block_vars_mat[p][q][2]);
if q != block_num_proofs[p] - 1 {
block_w3[p][q][3] = block_w3[p][q][1]
* (block_w3[p][q + 1][2] + S::field_one() - block_w3[p][q + 1][0]);
} else {
block_w3[p][q][3] = block_w3[p][q][1];
q2[0] = q2[0] + perm * block_vars_mat[p][q][2 + i];
q2[2] = q2[2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i];
}
block_w3[p][q][2] = block_w3[p][q][0] * block_w3[p][q][3];
q2[0] = q2[0] * block_vars_mat[p][q][0];
let ZO = q2[2];
q2[1] = q2[1] + ZO;
q2[1] = q2[1] * block_vars_mat[p][q][0];
let mut q3 = vec![S::field_zero(); 8];
q3[0] = block_vars_mat[p][q][0];

// For PHY
// Compute PMR, PMC
for i in 0..block_num_phy_ops[p] {
// PMR = r * PD
block_w2[p][q][V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)];
q2[V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)];
// PMC = (1 or PMC[i-1]) * (tau - PA - PMR)
let t = if i == 0 {
V_CNST
} else {
block_w2[p][q][V_PMC(i - 1)]
q2[V_PMC(i - 1)]
};
block_w2[p][q][V_PMC(i)] = t
* (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - block_w2[p][q][V_PMR(i)]);
q2[V_PMC(i)] = t
* (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - q2[V_PMR(i)]);
}
// Compute x
let px = if block_num_phy_ops[p] == 0 {
V_CNST
} else {
block_w2[p][q][V_PMC(block_num_phy_ops[p] - 1)]
};
// Compute D and pi
if q != block_num_proofs[p] - 1 {
block_w3[p][q][5] =
px * (block_w3[p][q + 1][4] + S::field_one() - block_w3[p][q + 1][0]);
} else {
block_w3[p][q][5] = px;
}
block_w3[p][q][4] = V_CNST * block_w3[p][q][5];

// For VIR
// Compute VMR1, VMR2, VMR3, VMC
for i in 0..block_num_vir_ops[p] {
// VMR1 = r * VD
block_w2[p][q][V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)];
q2[V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)];
// VMR2 = r^2 * VL
block_w2[p][q][V_VMR2(p, i)] =
q2[V_VMR2(p, i)] =
comb_r * comb_r * block_vars_mat[p][q][io_width + V_VL(p, i)];
// VMR3 = r^3 * VT
block_w2[p][q][V_VMR3(p, i)] =
q2[V_VMR3(p, i)] =
comb_r * comb_r * comb_r * block_vars_mat[p][q][io_width + V_VT(p, i)];
// VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3)
let t = if i == 0 {
V_CNST
} else {
block_w2[p][q][V_VMC(p, i - 1)]
q2[V_VMC(p, i - 1)]
};
block_w2[p][q][V_VMC(p, i)] = t
q2[V_VMC(p, i)] = t
* (comb_tau
- block_vars_mat[p][q][io_width + V_VA(p, i)]
- block_w2[p][q][V_VMR1(p, i)]
- block_w2[p][q][V_VMR2(p, i)]
- block_w2[p][q][V_VMR3(p, i)]);
- q2[V_VMR1(p, i)]
- q2[V_VMR2(p, i)]
- q2[V_VMR3(p, i)]);
}
(q2, q3)
}).unzip();
// Generate sequential entries separately
for q in (0..block_num_proofs[p]).rev() {
let V_CNST = block_vars_mat[p][q][0];
// For INPUT
block_w3_p[q][1] = block_w3_p[q][0]
* (comb_tau
- block_w2_p[q][3..]
.iter()
.fold(S::field_zero(), |a, b| a + *b)
- block_vars_mat[p][q][2]);
if q != block_num_proofs[p] - 1 {
block_w3_p[q][3] = block_w3_p[q][1]
* (block_w3_p[q + 1][2] + S::field_one() - block_w3_p[q + 1][0]);
} else {
block_w3_p[q][3] = block_w3_p[q][1];
}
block_w3_p[q][2] = block_w3_p[q][0] * block_w3_p[q][3];

// For PHY
// Compute x
let px = if block_num_phy_ops[p] == 0 {
V_CNST
} else {
block_w2_p[q][V_PMC(block_num_phy_ops[p] - 1)]
};
// Compute D and pi
if q != block_num_proofs[p] - 1 {
block_w3_p[q][5] =
px * (block_w3_p[q + 1][4] + S::field_one() - block_w3_p[q + 1][0]);
} else {
block_w3_p[q][5] = px;
}
block_w3_p[q][4] = V_CNST * block_w3_p[q][5];

// For VIR
// Compute x
let vx = if block_num_vir_ops[p] == 0 {
V_CNST
} else {
block_w2[p][q][V_VMC(p, block_num_vir_ops[p] - 1)]
block_w2_p[q][V_VMC(p, block_num_vir_ops[p] - 1)]
};
// Compute D and pi
if q != block_num_proofs[p] - 1 {
block_w3[p][q][7] =
vx * (block_w3[p][q + 1][6] + S::field_one() - block_w3[p][q + 1][0]);
block_w3_p[q][7] =
vx * (block_w3_p[q + 1][6] + S::field_one() - block_w3_p[q + 1][0]);
} else {
block_w3[p][q][7] = vx;
block_w3_p[q][7] = vx;
}
block_w3[p][q][6] = V_CNST * block_w3[p][q][7];
block_w3_p[q][6] = V_CNST * block_w3_p[q][7];
}

block_w2.push(block_w2_p);
block_w3.push(block_w3_p);
}

// commit the witnesses and inputs separately instance-by-instance
Expand Down Expand Up @@ -1574,21 +1611,21 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
let perm_w0_prover = ProverWitnessSecInfo::new(vec![vec![perm_w0]], vec![perm_poly_w0]);
let perm_exec_w2_prover =
ProverWitnessSecInfo::new(vec![perm_exec_w2], vec![perm_exec_poly_w2]);
let perm_exec_w3_prover =
ProverWitnessSecInfo::new(vec![perm_exec_w3.clone()], vec![perm_exec_poly_w3]);
let perm_exec_w3_shifted_prover = ProverWitnessSecInfo::new(
vec![[perm_exec_w3[1..].to_vec(), vec![vec![S::field_zero(); 8]]].concat()],
vec![perm_exec_poly_w3_shifted],
);
let perm_exec_w3_prover =
ProverWitnessSecInfo::new(vec![perm_exec_w3], vec![perm_exec_poly_w3]);

let block_w3_prover = ProverWitnessSecInfo::new(block_w3.clone(), block_poly_w3_list);
let block_w3_shifted_prover = ProverWitnessSecInfo::new(
block_w3
.iter()
.map(|i| [i[1..].to_vec(), vec![vec![S::field_zero(); 8]]].concat())
.collect(),
block_poly_w3_list_shifted,
);
let block_w3_prover = ProverWitnessSecInfo::new(block_w3, block_poly_w3_list);

(
comb_tau,
Expand Down

0 comments on commit 6aea34c

Please sign in to comment.