Skip to content

Commit

Permalink
Progress on WHIR integration
Browse files Browse the repository at this point in the history
  • Loading branch information
Kunming Jiang committed Feb 4, 2025
1 parent 2eeab15 commit 30ad655
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 167 deletions.
213 changes: 70 additions & 143 deletions spartan_parallel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -906,28 +906,47 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
let (pp, _error) = Pcs::trim(param, num_vars).unwrap();
// create a multilinear polynomial using the supplied assignment for variables
let poly = DenseMultilinearExtension::from_evaluation_vec_smart(mat_concat_p.len().log_2(), mat_concat_p);
println!("POLY_SIZE: {:?}", poly.num_vars.pow2());
let p_comm = Pcs::commit(&pp, &poly).unwrap();
let v_comm = Pcs::get_pure_commitment(&p_comm);
(poly, p_comm, v_comm)
}

// Convert a list of matrices into mles and commitments
fn mats_to_comms(mats: &Vec<Vec<Vec<E>>>, pp: &ProverParam<E, Pcs>) -> (
Vec<DenseMultilinearExtension<E>>,
Vec<Pcs::CommitmentWithWitness>,
// Convert a matrix into a prover witness sec and a commitment
fn mat_to_prove_wit_sec(mat: Vec<Vec<E>>, pp: &ProverParam<E, Pcs>) -> (
ProverWitnessSecInfo<E, Pcs>,
Pcs::Commitment,
) {
let (poly, p_comm, v_comm) = Self::mat_to_comm(&mat, pp);
let prover_wit_sec = ProverWitnessSecInfo::new(vec![mat], vec![poly], vec![p_comm]);
(prover_wit_sec, v_comm)
}

// Convert a matrix into a prover witness sec without committing
fn mat_to_prover_wit_sec_no_commit(mat: Vec<Vec<E>>) -> ProverWitnessSecInfo<E, Pcs> {
// Flatten the witnesses into a Q_i * X list
let mat_concat_p: Vec<E::BaseField> = mat.clone().into_iter().flatten().map(|e| e.as_bases()[0].clone()).collect();
// create a multilinear polynomial using the supplied assignment for variables
let poly = DenseMultilinearExtension::from_evaluation_vec_smart(mat_concat_p.len().log_2(), mat_concat_p);
let prover_wit_sec = ProverWitnessSecInfo::new(vec![mat], vec![poly], Vec::new()); // No commitment
prover_wit_sec
}

// Convert a list of matrices into prover witness secs and commitments
fn mats_to_prove_wit_sec(mats: Vec<Vec<Vec<E>>>, pp: &ProverParam<E, Pcs>) -> (
ProverWitnessSecInfo<E, Pcs>,
Vec<Pcs::Commitment>,
) {
let mut polys = Vec::new();
let mut p_comms = Vec::new();
let mut v_comms = Vec::new();
for mat in mats {
let (poly, pc, vc) = Self::mat_to_comm(mat, pp);
for mat in &mats {
let (poly, p_comm, v_comm) = Self::mat_to_comm(mat, pp);
polys.push(poly);
p_comms.push(pc);
v_comms.push(vc);
p_comms.push(p_comm);
v_comms.push(v_comm);
}
(polys, p_comms, v_comms)
let prover_wit_sec = ProverWitnessSecInfo::new(mats, polys, p_comms);
(prover_wit_sec, v_comms)
}

/// A method to produce a SNARK proof of the satisfiability of an R1CS instance
Expand Down Expand Up @@ -1329,8 +1348,6 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
perm_w0.extend(vec![E::ZERO; num_ios - 2 * num_inputs_unpadded]);
perm_w0
};
// create a multilinear polynomial using the supplied assignment for variables
let perm_w0_mle = DenseMultilinearExtension::from_evaluation_vec_smart(perm_w0.len().log_2(), perm_w0);

// PERM_EXEC
// w2 is _, _, ZO, r * i1, r^2 * i2, r^3 * i3, ...
Expand Down Expand Up @@ -1553,32 +1570,14 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
};
let block_w3_shifted_mat = block_w3.iter().map(|i| [i[1..].to_vec(), vec![vec![E::ZERO; 8]]].concat()).collect();

let (perm_exec_w2_mle, perm_exec_w2_p_comm, perm_exec_w2_v_comm) = Self::mat_to_comm(&perm_exec_w2, &poly_pp);
let (perm_exec_w3_mle, perm_exec_w3_p_comm, perm_exec_w3_v_comm) = Self::mat_to_comm(&perm_exec_w3, &poly_pp);
let (perm_exec_w3_shifted_mle, perm_exec_w3_shifted_p_comm, perm_exec_w3_shifted_v_comm) = Self::mat_to_comm(&perm_exec_w3_shifted_mat, &poly_pp);

let perm_w0_prover = ProverWitnessSecInfo::new(vec![vec![perm_w0]], vec![perm_w0_mle], _);
let perm_exec_w2_prover =
ProverWitnessSecInfo::new(vec![perm_exec_w2], vec![perm_exec_w2_mle], vec![perm_exec_w2_p_comm]);
let perm_exec_w3_shifted_prover = ProverWitnessSecInfo::new(
vec![perm_exec_w3_shifted_mat],
vec![perm_exec_w3_shifted_mle],
vec![perm_exec_w3_shifted_p_comm],
);
let perm_exec_w3_prover =
ProverWitnessSecInfo::new(vec![perm_exec_w3], vec![perm_exec_w3_mle], vec![perm_exec_w3_p_comm]);

let (block_w2_mle, block_w2_p_comm, block_w2_v_comm) = Self::mats_to_comms(&block_w2, &poly_pp);
let (block_w3_mle, block_w3_p_comm, block_w3_v_comm) = Self::mats_to_comms(&block_w3, &poly_pp);
let (block_w3_shifted_mle, block_w3_shifted_p_comm, block_w3_shifted_v_comm) = Self::mats_to_comms(&block_w3_shifted_mat, &poly_pp);

let block_w2_prover = ProverWitnessSecInfo::new(block_w2, block_w2_mle, block_w2_p_comm);
let block_w3_shifted_prover = ProverWitnessSecInfo::new(
block_w3_shifted_mat,
block_w3_shifted_mle,
block_w3_shifted_p_comm,
);
let block_w3_prover = ProverWitnessSecInfo::new(block_w3, block_w3_mle, block_w3_p_comm);
let perm_w0_prover = Self::mat_to_prover_wit_sec_no_commit(vec![perm_w0]); // Do not commit perm_w0
let (perm_exec_w2_prover, perm_exec_w2_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w2, &poly_pp);
let (perm_exec_w3_prover, perm_exec_w3_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w3, &poly_pp);
let (perm_exec_w3_shifted_prover, perm_exec_w3_shifted_v_comm) = Self::mat_to_prove_wit_sec(perm_exec_w3_shifted_mat, &poly_pp);

let (block_w2_prover, block_w2_v_comm) = Self::mats_to_prove_wit_sec(block_w2, &poly_pp);
let (block_w3_prover, block_w3_v_comm) = Self::mats_to_prove_wit_sec(block_w3, &poly_pp);
let (block_w3_shifted_prover, block_w3_shifted_v_comm) = Self::mats_to_prove_wit_sec(block_w3_shifted_mat, &poly_pp);

(
comb_tau,
Expand Down Expand Up @@ -1727,59 +1726,24 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
// WITNESS COMMITMENTS
// --
let timer_commit = Timer::new("input_commit");
let (block_vars_mle_list, block_vars_p_comm_list, block_vars_v_comm_list) = Self::mats_to_comms(&block_vars_mat, &poly_pp);
let (exec_inputs_mle, exec_inputs_p_comm, exec_inputs_v_comm) = Self::mat_to_comm(&exec_inputs_list, &poly_pp);
let (poly_init_phy_mems,) = {
if total_num_init_phy_mem_accesses > 0 {
let poly_init_mems = {
let init_mems = init_phy_mems_list.clone().into_iter().flatten().collect();
// create a multilinear polynomial using the supplied assignment for variables
let poly_init_mems = DensePolynomial::new(init_mems);
poly_init_mems
};
(vec![poly_init_mems],)
} else {
(Vec::new(),)
}
let (block_vars_prover, block_vars_v_comm_list) = Self::mats_to_prove_wit_sec(block_vars_mat, &poly_pp);
let (exec_inputs_prover, exec_inputs_v_comm) = Self::mat_to_prove_wit_sec(exec_inputs_list, &poly_pp);
let init_phy_mems_prover = if total_num_init_phy_mem_accesses > 0 {
Self::mat_to_prover_wit_sec_no_commit(init_phy_mems_list)
} else {
ProverWitnessSecInfo::dummy()
};
let (poly_init_vir_mems,) = {
if total_num_init_vir_mem_accesses > 0 {
let poly_init_mems = {
let init_mems = init_vir_mems_list.clone().into_iter().flatten().collect();
// create a multilinear polynomial using the supplied assignment for variables
let poly_init_mems = DensePolynomial::new(init_mems);
poly_init_mems
};
(vec![poly_init_mems],)
} else {
(Vec::new(),)
}
let init_vir_mems_prover = if total_num_init_vir_mem_accesses > 0 {
Self::mat_to_prover_wit_sec_no_commit(init_vir_mems_list)
} else {
ProverWitnessSecInfo::dummy()
};

let (addr_phy_mems_prover, addr_phy_mems_comm, addr_phy_mems_shifted_prover, addr_phy_mems_shifted_comm) = {
if total_num_phy_mem_accesses > 0 {
// Remove the first entry and shift the remaining entries up by one
let addr_phy_mems_shifted_list = vec![addr_phy_mems_list[1..].to_vec(), vec![vec![E::ZERO; PHY_MEM_WIDTH]]].concat();
let (addr_phy_mems_mle, addr_phy_mems_p_comm, addr_phy_mems_v_comm) = Self::mat_to_comm(&addr_phy_mems_list, &poly_pp);
let (addr_phy_mems_shifted_mle, addr_phy_mems_shifted_p_comm, addr_phy_mems_shifted_v_comm) = Self::mat_to_comm(&addr_phy_mems_shifted_list, &poly_pp);
// Used later by coherence check
let addr_phy_mems_shifted_prover = {
let addr_phy_mems_shifted = addr_phy_mems_shifted_list
.into_iter()
.flatten()
.collect();
let addr_phy_mems_shifted_prover = ProverWitnessSecInfo::new(
vec![addr_phy_mems_shifted_list],
vec![addr_phy_mems_shifted_mle],
vec![addr_phy_mems_shifted_p_comm],
);
addr_phy_mems_shifted_prover
};
let addr_phy_mems_prover = ProverWitnessSecInfo::new(
vec![addr_phy_mems_list],
vec![addr_phy_mems_mle],
vec![addr_phy_mems_p_comm]
);
let (addr_phy_mems_prover, addr_phy_mems_v_comm) = Self::mat_to_prove_wit_sec(addr_phy_mems_list, &poly_pp);
let (addr_phy_mems_shifted_prover, addr_phy_mems_shifted_v_comm) = Self::mat_to_prove_wit_sec(addr_phy_mems_shifted_list, &poly_pp);
(addr_phy_mems_prover, Some(addr_phy_mems_v_comm), addr_phy_mems_shifted_prover, Some(addr_phy_mems_shifted_v_comm))
} else {
(ProverWitnessSecInfo::dummy(), None, ProverWitnessSecInfo::dummy(), None)
Expand All @@ -1789,41 +1753,16 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
if total_num_vir_mem_accesses > 0 {
// Remove the first entry and shift the remaining entries up by one
let addr_vir_mems_shifted_list = vec![addr_vir_mems_list[1..].to_vec(), vec![vec![E::ZERO; VIR_MEM_WIDTH]]].concat();
let (addr_vir_mems_mle, addr_vir_mems_p_comm, addr_vir_mems_v_comm) = Self::mat_to_comm(&addr_vir_mems_list, &poly_pp);
let (addr_vir_mems_shifted_mle, addr_vir_mems_shifted_p_comm, addr_vir_mems_shifted_v_comm) = Self::mat_to_comm(&addr_vir_mems_shifted_list, &poly_pp);
// Used later by coherence check
let addr_vir_mems_shifted_prover = {
let addr_vir_mems_shifted = addr_vir_mems_shifted_list
.into_iter()
.flatten()
.collect();
let addr_vir_mems_shifted_prover = ProverWitnessSecInfo::new(
vec![addr_vir_mems_shifted_list],
vec![addr_vir_mems_shifted_mle],
vec![addr_vir_mems_shifted_p_comm],
);
addr_vir_mems_shifted_prover
};
let addr_vir_mems_prover = ProverWitnessSecInfo::new(
vec![addr_vir_mems_list],
vec![addr_vir_mems_mle],
vec![addr_vir_mems_p_comm]
);
let (addr_ts_bits_prover, addr_ts_bits_comm) = {
let (addr_ts_bits_mle, addr_ts_bits_p_comm, addr_ts_bits_v_comm) = Self::mat_to_comm(&addr_ts_bits_list, &poly_pp);

let addr_ts_bits = addr_ts_bits_list.clone().into_iter().flatten().collect();
let addr_ts_bits_prover =
ProverWitnessSecInfo::new(vec![addr_ts_bits_list], vec![addr_ts_bits_mle], vec![addr_ts_bits_p_comm]);
(addr_ts_bits_prover, addr_ts_bits_v_comm)
};
let (addr_vir_mems_prover, addr_vir_mems_v_comm) = Self::mat_to_prove_wit_sec(addr_vir_mems_list, &poly_pp);
let (addr_vir_mems_shifted_prover, addr_vir_mems_shifted_v_comm) = Self::mat_to_prove_wit_sec(addr_vir_mems_shifted_list, &poly_pp);
let (addr_ts_bits_prover, addr_ts_bits_v_comm) = Self::mat_to_prove_wit_sec(addr_ts_bits_list, &poly_pp);
(
addr_vir_mems_prover,
Some(addr_vir_mems_v_comm),
addr_vir_mems_shifted_prover,
Some(addr_vir_mems_shifted_v_comm),
addr_ts_bits_prover,
Some(addr_ts_bits_comm),
Some(addr_ts_bits_v_comm),
)
} else {
(
Expand All @@ -1836,112 +1775,100 @@ impl<E: ExtensionField + Send + Sync, Pcs: PolynomialCommitmentScheme<E>> SNARK<
)
}
};
let block_vars_prover = ProverWitnessSecInfo::new(block_vars_mat, block_vars_mle_list, block_vars_p_comm_list);
let exec_inputs_prover = ProverWitnessSecInfo::new(vec![exec_inputs_list], vec![exec_inputs_mle], vec![exec_inputs_p_comm]);
let init_phy_mems_prover = if total_num_init_phy_mem_accesses > 0 {
ProverWitnessSecInfo::new(vec![init_phy_mems_list], poly_init_phy_mems)
} else {
ProverWitnessSecInfo::dummy()
};
let init_vir_mems_prover = if total_num_init_vir_mem_accesses > 0 {
ProverWitnessSecInfo::new(vec![init_vir_mems_list], poly_init_vir_mems)
} else {
ProverWitnessSecInfo::dummy()
};
timer_commit.stop();

// Record total size of witnesses:
let block_witness_sizes: Vec<usize> = [
block_vars_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
block_w2_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
block_w3_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
block_w3_shifted_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
]
.concat();
let exec_witness_sizes: Vec<usize> = [
exec_inputs_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
perm_exec_w2_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
perm_exec_w3_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
perm_exec_w3_shifted_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
]
.concat();
let mem_witness_sizes: Vec<usize> = [
addr_phy_mems_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
phy_mem_addr_w2_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
phy_mem_addr_w3_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
phy_mem_addr_w3_shifted_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
addr_vir_mems_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
addr_ts_bits_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
vir_mem_addr_w2_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
vir_mem_addr_w3_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
vir_mem_addr_w3_shifted_prover
.poly_w
.iter()
.map(|i| i.len())
.map(|i| i.evaluations.len())
.collect::<Vec<usize>>(),
]
.concat();
Expand Down
Loading

0 comments on commit 30ad655

Please sign in to comment.