Skip to content

Commit fc385b6

Browse files
committed
Allow num_inputs to be different per witness_sec
1 parent 2f2ecd7 commit fc385b6

File tree

8 files changed

+329
-454
lines changed

8 files changed

+329
-454
lines changed

spartan_parallel/src/custom_dense_mlpoly.rs

Lines changed: 136 additions & 161 deletions
Large diffs are not rendered by default.

spartan_parallel/src/instance.rs

Lines changed: 53 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ pub struct Instance<S: SpartanExtensionField> {
2626
pub digest: Vec<u8>,
2727
}
2828

29-
impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
29+
impl<S: SpartanExtensionField> Instance<S> {
3030
/// Constructs a new `Instance` and an associated satisfying assignment
3131
pub fn new(
3232
num_instances: usize,
@@ -38,6 +38,8 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
3838
B: &Vec<Vec<(usize, usize, [u8; 32])>>,
3939
C: &Vec<Vec<(usize, usize, [u8; 32])>>,
4040
) -> Result<Instance<S>, R1CSError> {
41+
let ZERO = S::field_zero();
42+
4143
let (max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded) = {
4244
let max_num_vars_padded = {
4345
let mut max_num_vars_padded = max_num_vars;
@@ -82,12 +84,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
8284
}
8385
}
8486

85-
(
86-
max_num_vars_padded,
87-
num_vars_padded,
88-
max_num_cons_padded,
89-
num_cons_padded,
90-
)
87+
(max_num_vars_padded, num_vars_padded, max_num_cons_padded, num_cons_padded)
9188
};
9289

9390
let bytes_to_scalar =
@@ -124,7 +121,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
124121
// we do not need to pad otherwise because the dummy constraints are implicit in the sum-check protocol
125122
if num_cons[b] == 0 || num_cons[b] == 1 {
126123
for i in tups.len()..num_cons_padded[b] {
127-
mat.push((i, num_vars[b], S::field_zero()));
124+
mat.push((i, num_vars[b], ZERO));
128125
}
129126
}
130127

@@ -245,10 +242,10 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
245242
/// Verify the correctness of each block execution, as well as extracting all memory operations
246243
///
247244
/// Input composition: (if every segment exists)
248-
/// INPUT + VAR Challenges BLOCK_W2 BLOCK_W3 BLOCK_W3_SHIFTED
249-
/// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
250-
/// v i0 ... PA0 PD0 ... VA0 VD0 ... | tau r r^2 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | v x pi D pi D pi D | v x pi D pi D pi D
251-
/// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR
245+
/// INPUT + VAR BLOCK_W2 Challenges BLOCK_W3 BLOCK_W3_SHIFTED
246+
/// 0 1 2 IOW +1 +2 +3 +4 +5 | 0 1 2 3 4 NIU 1 2 3 2NP +1 +2 +3 +4 | 0 1 2 3 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
247+
/// v i0 ... PA0 PD0 ... VA0 VD0 ... | _ _ ZO r*i1 ... MR MC MR ... MR1 MR2 MR3 MC MR1 ... | tau r r^2 ... | v x pi D pi D pi D | v x pi D pi D pi D
248+
/// INPUT PHY VIR INPUT PHY VIR INPUT PHY VIR
252249
///
253250
/// VAR:
254251
/// We assume that the witnesses are of the following format:
@@ -271,7 +268,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
271268
/// - VMR3 = r^3 * VT
272269
/// - VMC = (1 or VMC[i-1]) * (tau - VA - VMR1 - VMR2 - VMR3)
273270
/// The final product is stored in X = MC[NV - 1]
274-
///
271+
///
275272
/// If in COMMIT_MODE, commit instance by num_vars_per_block, rounded to the nearest power of four
276273
pub fn gen_block_inst<const PRINT_SIZE: bool, const COMMIT_MODE: bool>(
277274
num_instances: usize,
@@ -306,20 +303,12 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
306303
max_size_per_group.insert(next_group_size(*num_vars), num_vars.next_power_of_two());
307304
}
308305
}
309-
num_vars_per_block
310-
.iter()
311-
.map(|i| {
312-
max_size_per_group
313-
.get(&next_group_size(*i))
314-
.unwrap()
315-
.clone()
316-
})
317-
.collect()
306+
num_vars_per_block.iter().map(|i| max_size_per_group.get(&next_group_size(*i)).unwrap().clone()).collect()
318307
} else {
319308
vec![num_vars; num_instances]
320309
};
321310

322-
if PRINT_SIZE {
311+
if PRINT_SIZE && !COMMIT_MODE {
323312
println!("\n\n--\nBLOCK INSTS");
324313
println!(
325314
"{:10} {:>4} {:>4} {:>4} {:>4}",
@@ -348,37 +337,30 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
348337
let V_VD = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 1;
349338
let V_VL = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 2;
350339
let V_VT = |b: usize, i: usize| io_width + 2 * num_phy_ops[b] + 4 * i + 3;
351-
// in CHALLENGES, not used if !has_mem_op
352-
let V_tau = |b: usize| num_vars_padded_per_block[b];
353-
let V_r = |b: usize, i: usize| num_vars_padded_per_block[b] + i;
354340
// in BLOCK_W2 / INPUT_W2
355341
let V_input_dot_prod = |b: usize, i: usize| {
356342
if i == 0 {
357343
V_input(0)
358344
} else {
359-
2 * num_vars_padded_per_block[b] + 2 + i
345+
num_vars_padded_per_block[b] + 2 + i
360346
}
361347
};
362-
let V_output_dot_prod =
363-
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i;
348+
let V_output_dot_prod = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 + (num_inputs_unpadded - 1) + i;
364349
// in BLOCK_W2 / PHY_W2
365-
let V_PMR =
366-
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i;
367-
let V_PMC =
368-
|b: usize, i: usize| 2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1;
350+
let V_PMR = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i;
351+
let V_PMC = |b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * i + 1;
369352
// in BLOCK_W2 / VIR_W2
370-
let V_VMR1 = |b: usize, i: usize| {
371-
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i
372-
};
373-
let V_VMR2 = |b: usize, i: usize| {
374-
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1
375-
};
376-
let V_VMR3 = |b: usize, i: usize| {
377-
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2
378-
};
379-
let V_VMC = |b: usize, i: usize| {
380-
2 * num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3
381-
};
353+
let V_VMR1 =
354+
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i;
355+
let V_VMR2 =
356+
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 1;
357+
let V_VMR3 =
358+
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 2;
359+
let V_VMC =
360+
|b: usize, i: usize| num_vars_padded_per_block[b] + 2 * num_inputs_unpadded + 2 * num_phy_ops[b] + 4 * i + 3;
361+
// in CHALLENGES, not used if !has_mem_op
362+
let V_tau = |b: usize| 2 * num_vars_padded_per_block[b];
363+
let V_r = |b: usize, i: usize| 2 * num_vars_padded_per_block[b] + i;
382364
// in BLOCK_W3
383365
let V_v = |b: usize| 3 * num_vars_padded_per_block[b];
384366
let V_x = |b: usize| 3 * num_vars_padded_per_block[b] + 1;
@@ -703,7 +685,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
703685
B_list.push(B);
704686
C_list.push(C);
705687

706-
if PRINT_SIZE {
688+
if PRINT_SIZE && !COMMIT_MODE {
707689
let max_nnz = max(tmp_nnz_A, max(tmp_nnz_B, tmp_nnz_C));
708690
let total_var = num_vars_per_block[b]
709691
+ 2 * num_inputs_unpadded.next_power_of_two()
@@ -724,7 +706,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
724706
}
725707
}
726708

727-
if PRINT_SIZE {
709+
if PRINT_SIZE && !COMMIT_MODE {
728710
println!("Total Num of Blocks: {}", num_instances);
729711
println!("Total Inst Commit Size: {}", total_inst_commit_size);
730712
println!("Total Var Commit Size: {}", total_var_commit_size);
@@ -744,10 +726,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
744726
max_cons_per_group.insert(num_vars_padded_per_block[i], block_num_cons[i]);
745727
}
746728
}
747-
num_vars_padded_per_block
748-
.iter()
749-
.map(|i| max_cons_per_group.get(i).unwrap().clone())
750-
.collect()
729+
num_vars_padded_per_block.iter().map(|i| max_cons_per_group.get(i).unwrap().clone()).collect()
751730
} else {
752731
block_num_cons
753732
}
@@ -759,10 +738,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
759738
block_max_num_cons,
760739
num_cons_padded_per_block,
761740
block_num_vars,
762-
num_vars_padded_per_block
763-
.into_iter()
764-
.map(|i| 8 * i)
765-
.collect(),
741+
num_vars_padded_per_block.into_iter().map(|i| 8 * i).collect(),
766742
&A_list,
767743
&B_list,
768744
&C_list,
@@ -816,9 +792,14 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
816792
/// D2 = D1 * (ls[i+1] - STORE)
817793
/// Where STORE = 0
818794
/// Input composition:
819-
/// Op[k] Op[k + 1] D2 & bits of ts[k + 1] - ts[k]
820-
/// 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4
821-
/// v D1 a d ls ts _ _ | v D1 a d ls ts _ _ | D2 EQ B0 B1 ...
795+
/// bits of ts[k + 1] - ts[k] Op[k] Op[k + 1]
796+
/// 0 1 2 3 4 | 0 1 2 3 4 5 6 7 | 0 1 2 3 4 5 6 7
797+
/// D2 EQ B0 B1 ... | v D1 a d ls ts _ _ | v D1 a d ls ts _ _
798+
///
799+
/// If ADDR_NONCONSEC, address comparison of VIR uses <= instead of +1, with the following expression
800+
/// ts | addr
801+
/// 0 1 2 3 4 | 0 1 2 3 4 5
802+
/// D2 EQ B0 B1 ... | D4 INV EQ B0 B1 ...
822803
pub fn gen_pairwise_check_inst<const PRINT_SIZE: bool>(
823804
max_ts_width: usize,
824805
mem_addr_ts_bits_size: usize,
@@ -834,14 +815,14 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
834815
"", "con", "var", "nnz", "exec"
835816
);
836817
}
818+
837819
// Variable used by printing
838820
let mut total_inst_commit_size = 0;
839821
let mut total_var_commit_size = 0;
840822
let mut total_cons_exec_size = 0;
841823

842824
let pairwise_check_num_vars = max(8, mem_addr_ts_bits_size);
843825
let pairwise_check_max_num_cons = 8 + max_ts_width;
844-
let pairwise_check_num_cons = vec![2, 4, 8 + max_ts_width];
845826
let pairwise_check_num_non_zero_entries: usize = max(13 + max_ts_width, 5 + 2 * max_ts_width);
846827

847828
let pairwise_check_inst = {
@@ -972,23 +953,24 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
972953
let (A, B, C) = {
973954
let width = pairwise_check_num_vars;
974955

975-
let V_valid = 0;
956+
// TS_BITS
957+
let V_D2 = 0;
958+
let V_EQ = 1;
959+
let V_B = |i| 2 + i;
960+
// OP[K], OP[K + 1]
961+
let V_valid = width;
976962
let V_cnst = V_valid;
977-
let V_D1 = 1;
978-
let V_addr = 2;
979-
let V_data = 3;
980-
let V_ls = 4;
981-
let V_ts = 5;
982-
let V_D2 = 2 * width;
983-
let V_EQ = 2 * width + 1;
984-
let V_B = |i| 2 * width + 2 + i;
963+
let V_D1 = width + 1;
964+
let V_addr = width + 2;
965+
let V_data = width + 3;
966+
let V_ls = width + 4;
967+
let V_ts = width + 5;
985968

986969
let mut A: Vec<(usize, usize, [u8; 32])> = Vec::new();
987970
let mut B: Vec<(usize, usize, [u8; 32])> = Vec::new();
988971
let mut C: Vec<(usize, usize, [u8; 32])> = Vec::new();
989972

990973
let mut num_cons = 0;
991-
// Sortedness
992974
// (v[k] - 1) * v[k + 1] = 0
993975
(A, B, C) = Instance::<S>::gen_constr(
994976
A,
@@ -1000,6 +982,7 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
1000982
vec![],
1001983
);
1002984
num_cons += 1;
985+
// Sortedness
1003986
// D1[k] = v[k + 1] * (1 - addr[k + 1] + addr[k])
1004987
(A, B, C) = Instance::<S>::gen_constr(
1005988
A,
@@ -1403,4 +1386,4 @@ impl<S: SpartanExtensionField + Send + Sync> Instance<S> {
14031386
perm_root_inst,
14041387
)
14051388
}
1406-
}
1389+
}

0 commit comments

Comments
 (0)