Skip to content

Commit 6aea34c

Browse files
committed
Experimental improvements on block_witness_gen
1 parent ec8458a commit 6aea34c

File tree

1 file changed

+129
-92
lines changed

1 file changed

+129
-92
lines changed

spartan_parallel/src/lib.rs

Lines changed: 129 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ mod unipoly;
3838
use std::{
3939
cmp::{max, Ordering},
4040
fs::File,
41-
io::Write,
41+
io::Write, iter::zip,
4242
};
4343

4444
use dense_mlpoly::{DensePolynomial, PolyEvalProof};
@@ -50,6 +50,7 @@ use merlin::Transcript;
5050
use r1csinstance::{R1CSCommitment, R1CSDecommitment, R1CSEvalProof, R1CSInstance};
5151
use r1csproof::R1CSProof;
5252
use random::RandomTape;
53+
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator};
5354
use scalar::SpartanExtensionField;
5455
use serde::{Deserialize, Serialize};
5556
use timer::Timer;
@@ -1271,37 +1272,60 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
12711272
// w2 is _, _, ZO, r * i1, r^2 * i2, r^3 * i3, ...
12721273
// where ZO * r^n = r^n * o0 + r^(n + 1) * o1, ...,
12731274
// are used by the consistency check
1274-
let perm_exec_w2 = {
1275-
let mut perm_exec_w2: Vec<Vec<S>> = exec_inputs_list
1276-
.iter()
1277-
.map(|input| {
1278-
[
1279-
vec![S::field_zero(); 3],
1280-
(1..2 * num_inputs_unpadded - 2)
1281-
.map(|j| perm_w0[j] * input[j + 2])
1282-
.collect(),
1283-
vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded],
1284-
]
1285-
.concat()
1286-
})
1287-
.collect();
1288-
for q in 0..consis_num_proofs {
1289-
perm_exec_w2[q][0] = exec_inputs_list[q][0];
1290-
perm_exec_w2[q][1] = exec_inputs_list[q][0];
1275+
let (perm_exec_w2, perm_exec_w3) = {
1276+
let perm_exec_w2: Vec<Vec<S>>;
1277+
let mut perm_exec_w3: Vec<Vec<S>>;
1278+
// Entries that do not depend on others can be generated in parallel
1279+
(perm_exec_w2, perm_exec_w3) = (0..consis_num_proofs).into_par_iter().map(|q| {
1280+
// perm_exec_w2
1281+
let mut perm_exec_w2_q = [
1282+
vec![S::field_zero(); 3],
1283+
(1..2 * num_inputs_unpadded - 2)
1284+
.map(|j| perm_w0[j] * exec_inputs_list[q][j + 2])
1285+
.collect(),
1286+
vec![S::field_zero(); num_ios - 2 * num_inputs_unpadded],
1287+
].concat();
1288+
perm_exec_w2_q[0] = exec_inputs_list[q][0];
1289+
perm_exec_w2_q[1] = exec_inputs_list[q][0];
12911290
for i in 0..num_inputs_unpadded - 1 {
12921291
let perm = if i == 0 { S::field_one() } else { perm_w0[i] };
1293-
perm_exec_w2[q][0] = perm_exec_w2[q][0] + perm * exec_inputs_list[q][2 + i];
1294-
perm_exec_w2[q][2] =
1295-
perm_exec_w2[q][2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i];
1292+
perm_exec_w2_q[0] = perm_exec_w2_q[0] + perm * exec_inputs_list[q][2 + i];
1293+
perm_exec_w2_q[2] =
1294+
perm_exec_w2_q[2] + perm * exec_inputs_list[q][2 + (num_inputs_unpadded - 1) + i];
12961295
}
1297-
perm_exec_w2[q][0] = perm_exec_w2[q][0] * exec_inputs_list[q][0];
1298-
let ZO = perm_exec_w2[q][2];
1299-
perm_exec_w2[q][1] = perm_exec_w2[q][1] + ZO;
1300-
perm_exec_w2[q][1] = perm_exec_w2[q][1] * exec_inputs_list[q][0];
1296+
perm_exec_w2_q[0] = perm_exec_w2_q[0] * exec_inputs_list[q][0];
1297+
let ZO = perm_exec_w2_q[2];
1298+
perm_exec_w2_q[1] = perm_exec_w2_q[1] + ZO;
1299+
perm_exec_w2_q[1] = perm_exec_w2_q[1] * exec_inputs_list[q][0];
1300+
1301+
// perm_exec_w3
1302+
let mut perm_exec_w3_q = vec![S::field_zero(); 8];
1303+
perm_exec_w3_q[0] = exec_inputs_list[q][0];
1304+
perm_exec_w3_q[1] = perm_exec_w3_q[0]
1305+
* (comb_tau
1306+
- perm_exec_w2_q[3..]
1307+
.iter()
1308+
.fold(S::field_zero(), |a, b| a + *b)
1309+
- exec_inputs_list[q][2]);
1310+
perm_exec_w3_q[4] = perm_exec_w2_q[0];
1311+
perm_exec_w3_q[5] = perm_exec_w2_q[1];
1312+
1313+
(perm_exec_w2_q, perm_exec_w3_q)
1314+
}).unzip();
1315+
// Generate sequential entries separately
1316+
for q in (0..consis_num_proofs).rev() {
1317+
if q != consis_num_proofs - 1 {
1318+
perm_exec_w3[q][3] = perm_exec_w3[q][1]
1319+
* (perm_exec_w3[q + 1][2] + S::field_one() - perm_exec_w3[q + 1][0]);
1320+
} else {
1321+
perm_exec_w3[q][3] = perm_exec_w3[q][1];
1322+
}
1323+
perm_exec_w3[q][2] = perm_exec_w3[q][0] * perm_exec_w3[q][3];
13011324
}
1302-
perm_exec_w2
1325+
(perm_exec_w2, perm_exec_w3)
13031326
};
1304-
// w3 is [v, x, pi, D]
1327+
1328+
/*
13051329
let perm_exec_w3 = {
13061330
let mut perm_exec_w3: Vec<Vec<S>> = vec![Vec::new(); consis_num_proofs];
13071331
for q in (0..consis_num_proofs).rev() {
@@ -1325,6 +1349,7 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
13251349
}
13261350
perm_exec_w3
13271351
};
1352+
*/
13281353
// commit the witnesses and inputs separately instance-by-instance
13291354
let (perm_exec_poly_w2, perm_exec_poly_w3, perm_exec_poly_w3_shifted) = {
13301355
let perm_exec_poly_w2 = {
@@ -1376,7 +1401,7 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
13761401
// w3 is [v, x, pi, D, pi, D, pi, D]
13771402
let mut block_w3: Vec<Vec<Vec<S>>> = Vec::new();
13781403
let block_w2_prover = {
1379-
let mut block_w2 = Vec::new();
1404+
let mut block_w2: Vec<Vec<Vec<S>>> = Vec::new();
13801405
let block_w2_size_list: Vec<usize> = (0..block_num_instances)
13811406
.map(|i| {
13821407
(2 * num_inputs_unpadded + 2 * block_num_phy_ops[i] + 4 * block_num_vir_ops[i])
@@ -1408,113 +1433,125 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
14081433
|b: usize, i: usize| 2 * num_inputs_unpadded + 2 * block_num_phy_ops[b] + 4 * i + 3;
14091434

14101435
for p in 0..block_num_instances {
1411-
block_w2.push(vec![Vec::new(); block_num_proofs[p]]);
1412-
block_w3.push(vec![Vec::new(); block_num_proofs[p]]);
1413-
for q in (0..block_num_proofs[p]).rev() {
1436+
let block_w2_p: Vec<Vec<S>>;
1437+
let mut block_w3_p: Vec<Vec<S>>;
1438+
// Entries that do not depend on others can be generated in parallel
1439+
(block_w2_p, block_w3_p) = (0..block_num_proofs[p]).into_par_iter().map(|q| {
14141440
let V_CNST = block_vars_mat[p][q][0];
14151441
// For INPUT
1416-
block_w2[p][q] = vec![S::field_zero(); block_w2_size_list[p]];
1442+
let mut q2 = vec![S::field_zero(); block_w2_size_list[p]];
14171443

1418-
block_w2[p][q][0] = block_vars_mat[p][q][0];
1419-
block_w2[p][q][1] = block_vars_mat[p][q][0];
1444+
q2[0] = block_vars_mat[p][q][0];
1445+
q2[1] = block_vars_mat[p][q][0];
14201446
for i in 1..2 * (num_inputs_unpadded - 1) {
1421-
block_w2[p][q][2 + i] =
1422-
block_w2[p][q][2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2];
1447+
q2[2 + i] = q2[2 + i] + perm_w0[i] * block_vars_mat[p][q][i + 2];
14231448
}
14241449
for i in 0..num_inputs_unpadded - 1 {
14251450
let perm = if i == 0 { S::field_one() } else { perm_w0[i] };
1426-
block_w2[p][q][0] = block_w2[p][q][0] + perm * block_vars_mat[p][q][2 + i];
1427-
block_w2[p][q][2] =
1428-
block_w2[p][q][2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i];
1429-
}
1430-
block_w2[p][q][0] = block_w2[p][q][0] * block_vars_mat[p][q][0];
1431-
let ZO = block_w2[p][q][2];
1432-
block_w2[p][q][1] = block_w2[p][q][1] + ZO;
1433-
block_w2[p][q][1] = block_w2[p][q][1] * block_vars_mat[p][q][0];
1434-
block_w3[p][q] = vec![S::field_zero(); 8];
1435-
block_w3[p][q][0] = block_vars_mat[p][q][0];
1436-
block_w3[p][q][1] = block_w3[p][q][0]
1437-
* (comb_tau
1438-
- block_w2[p][q][3..]
1439-
.iter()
1440-
.fold(S::field_zero(), |a, b| a + *b)
1441-
- block_vars_mat[p][q][2]);
1442-
if q != block_num_proofs[p] - 1 {
1443-
block_w3[p][q][3] = block_w3[p][q][1]
1444-
* (block_w3[p][q + 1][2] + S::field_one() - block_w3[p][q + 1][0]);
1445-
} else {
1446-
block_w3[p][q][3] = block_w3[p][q][1];
1451+
q2[0] = q2[0] + perm * block_vars_mat[p][q][2 + i];
1452+
q2[2] = q2[2] + perm * block_vars_mat[p][q][2 + (num_inputs_unpadded - 1) + i];
14471453
}
1448-
block_w3[p][q][2] = block_w3[p][q][0] * block_w3[p][q][3];
1454+
q2[0] = q2[0] * block_vars_mat[p][q][0];
1455+
let ZO = q2[2];
1456+
q2[1] = q2[1] + ZO;
1457+
q2[1] = q2[1] * block_vars_mat[p][q][0];
1458+
let mut q3 = vec![S::field_zero(); 8];
1459+
q3[0] = block_vars_mat[p][q][0];
14491460

14501461
// For PHY
14511462
// Compute PMR, PMC
14521463
for i in 0..block_num_phy_ops[p] {
14531464
// PMR = r * PD
1454-
block_w2[p][q][V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)];
1465+
q2[V_PMR(i)] = comb_r * block_vars_mat[p][q][io_width + V_PD(i)];
14551466
// PMC = (1 or PMC[i-1]) * (tau - PA - PMR)
14561467
let t = if i == 0 {
14571468
V_CNST
14581469
} else {
1459-
block_w2[p][q][V_PMC(i - 1)]
1470+
q2[V_PMC(i - 1)]
14601471
};
1461-
block_w2[p][q][V_PMC(i)] = t
1462-
* (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - block_w2[p][q][V_PMR(i)]);
1472+
q2[V_PMC(i)] = t
1473+
* (comb_tau - block_vars_mat[p][q][io_width + V_PA(i)] - q2[V_PMR(i)]);
14631474
}
1464-
// Compute x
1465-
let px = if block_num_phy_ops[p] == 0 {
1466-
V_CNST
1467-
} else {
1468-
block_w2[p][q][V_PMC(block_num_phy_ops[p] - 1)]
1469-
};
1470-
// Compute D and pi
1471-
if q != block_num_proofs[p] - 1 {
1472-
block_w3[p][q][5] =
1473-
px * (block_w3[p][q + 1][4] + S::field_one() - block_w3[p][q + 1][0]);
1474-
} else {
1475-
block_w3[p][q][5] = px;
1476-
}
1477-
block_w3[p][q][4] = V_CNST * block_w3[p][q][5];
14781475

14791476
// For VIR
14801477
// Compute VMR1, VMR2, VMR3, VMC
14811478
for i in 0..block_num_vir_ops[p] {
14821479
// VMR1 = r * VD
1483-
block_w2[p][q][V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)];
1480+
q2[V_VMR1(p, i)] = comb_r * block_vars_mat[p][q][io_width + V_VD(p, i)];
14841481
// VMR2 = r^2 * VL
1485-
block_w2[p][q][V_VMR2(p, i)] =
1482+
q2[V_VMR2(p, i)] =
14861483
comb_r * comb_r * block_vars_mat[p][q][io_width + V_VL(p, i)];
14871484
// VMR1 = r^3 * VT
1488-
block_w2[p][q][V_VMR3(p, i)] =
1485+
q2[V_VMR3(p, i)] =
14891486
comb_r * comb_r * comb_r * block_vars_mat[p][q][io_width + V_VT(p, i)];
14901487
// VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3)
14911488
let t = if i == 0 {
14921489
V_CNST
14931490
} else {
1494-
block_w2[p][q][V_VMC(p, i - 1)]
1491+
q2[V_VMC(p, i - 1)]
14951492
};
1496-
block_w2[p][q][V_VMC(p, i)] = t
1493+
q2[V_VMC(p, i)] = t
14971494
* (comb_tau
14981495
- block_vars_mat[p][q][io_width + V_VA(p, i)]
1499-
- block_w2[p][q][V_VMR1(p, i)]
1500-
- block_w2[p][q][V_VMR2(p, i)]
1501-
- block_w2[p][q][V_VMR3(p, i)]);
1496+
- q2[V_VMR1(p, i)]
1497+
- q2[V_VMR2(p, i)]
1498+
- q2[V_VMR3(p, i)]);
1499+
}
1500+
(q2, q3)
1501+
}).unzip();
1502+
// Generate sequential entries separately
1503+
for q in (0..block_num_proofs[p]).rev() {
1504+
let V_CNST = block_vars_mat[p][q][0];
1505+
// For INPUT
1506+
block_w3_p[q][1] = block_w3_p[q][0]
1507+
* (comb_tau
1508+
- block_w2_p[q][3..]
1509+
.iter()
1510+
.fold(S::field_zero(), |a, b| a + *b)
1511+
- block_vars_mat[p][q][2]);
1512+
if q != block_num_proofs[p] - 1 {
1513+
block_w3_p[q][3] = block_w3_p[q][1]
1514+
* (block_w3_p[q + 1][2] + S::field_one() - block_w3_p[q + 1][0]);
1515+
} else {
1516+
block_w3_p[q][3] = block_w3_p[q][1];
15021517
}
1518+
block_w3_p[q][2] = block_w3_p[q][0] * block_w3_p[q][3];
1519+
1520+
// For PHY
1521+
// Compute x
1522+
let px = if block_num_phy_ops[p] == 0 {
1523+
V_CNST
1524+
} else {
1525+
block_w2_p[q][V_PMC(block_num_phy_ops[p] - 1)]
1526+
};
1527+
// Compute D and pi
1528+
if q != block_num_proofs[p] - 1 {
1529+
block_w3_p[q][5] =
1530+
px * (block_w3_p[q + 1][4] + S::field_one() - block_w3_p[q + 1][0]);
1531+
} else {
1532+
block_w3_p[q][5] = px;
1533+
}
1534+
block_w3_p[q][4] = V_CNST * block_w3_p[q][5];
1535+
1536+
// For VIR
15031537
// Compute x
15041538
let vx = if block_num_vir_ops[p] == 0 {
15051539
V_CNST
15061540
} else {
1507-
block_w2[p][q][V_VMC(p, block_num_vir_ops[p] - 1)]
1541+
block_w2_p[q][V_VMC(p, block_num_vir_ops[p] - 1)]
15081542
};
15091543
// Compute D and pi
15101544
if q != block_num_proofs[p] - 1 {
1511-
block_w3[p][q][7] =
1512-
vx * (block_w3[p][q + 1][6] + S::field_one() - block_w3[p][q + 1][0]);
1545+
block_w3_p[q][7] =
1546+
vx * (block_w3_p[q + 1][6] + S::field_one() - block_w3_p[q + 1][0]);
15131547
} else {
1514-
block_w3[p][q][7] = vx;
1548+
block_w3_p[q][7] = vx;
15151549
}
1516-
block_w3[p][q][6] = V_CNST * block_w3[p][q][7];
1550+
block_w3_p[q][6] = V_CNST * block_w3_p[q][7];
15171551
}
1552+
1553+
block_w2.push(block_w2_p);
1554+
block_w3.push(block_w3_p);
15181555
}
15191556

15201557
// commit the witnesses and inputs separately instance-by-instance
@@ -1574,21 +1611,21 @@ impl<S: SpartanExtensionField + Send + Sync> SNARK<S> {
15741611
let perm_w0_prover = ProverWitnessSecInfo::new(vec![vec![perm_w0]], vec![perm_poly_w0]);
15751612
let perm_exec_w2_prover =
15761613
ProverWitnessSecInfo::new(vec![perm_exec_w2], vec![perm_exec_poly_w2]);
1577-
let perm_exec_w3_prover =
1578-
ProverWitnessSecInfo::new(vec![perm_exec_w3.clone()], vec![perm_exec_poly_w3]);
15791614
let perm_exec_w3_shifted_prover = ProverWitnessSecInfo::new(
15801615
vec![[perm_exec_w3[1..].to_vec(), vec![vec![S::field_zero(); 8]]].concat()],
15811616
vec![perm_exec_poly_w3_shifted],
15821617
);
1618+
let perm_exec_w3_prover =
1619+
ProverWitnessSecInfo::new(vec![perm_exec_w3], vec![perm_exec_poly_w3]);
15831620

1584-
let block_w3_prover = ProverWitnessSecInfo::new(block_w3.clone(), block_poly_w3_list);
15851621
let block_w3_shifted_prover = ProverWitnessSecInfo::new(
15861622
block_w3
15871623
.iter()
15881624
.map(|i| [i[1..].to_vec(), vec![vec![S::field_zero(); 8]]].concat())
15891625
.collect(),
15901626
block_poly_w3_list_shifted,
15911627
);
1628+
let block_w3_prover = ProverWitnessSecInfo::new(block_w3, block_poly_w3_list);
15921629

15931630
(
15941631
comb_tau,

0 commit comments

Comments
 (0)