Skip to content

Commit dd00d8e

Browse files
committed
* Implement PartialEq, Eq, Hash for Cell and AssignedCell
* Add table, compressed and normal rows count. * Add rows and table rows to cost model. * Ignore unassigned cells if they are multiplied by zero * Some format values are written as "Scalar(0x..)" The hotfix was to change the stripping rules, but this is probably an incorrect implementation of certain traits for one of the curves.
1 parent 0092a15 commit dd00d8e

File tree

9 files changed

+152
-61
lines changed

9 files changed

+152
-61
lines changed

src/circuit.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ pub trait Chip<F: Field>: Sized {
5050
}
5151

5252
/// Index of a region in a layouter
53-
#[derive(Clone, Copy, Debug)]
53+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
5454
pub struct RegionIndex(usize);
5555

5656
impl From<usize> for RegionIndex {
@@ -86,7 +86,7 @@ impl std::ops::Deref for RegionStart {
8686
}
8787

8888
/// A pointer to a cell within a circuit.
89-
#[derive(Clone, Copy, Debug)]
89+
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
9090
pub struct Cell {
9191
/// Identifies the region in which this cell resides.
9292
pub region_index: RegionIndex,
@@ -104,6 +104,21 @@ pub struct AssignedCell<V, F: Field> {
104104
_marker: PhantomData<F>,
105105
}
106106

107+
impl<V, F: Field> PartialEq for AssignedCell<V, F> {
108+
fn eq(&self, other: &Self) -> bool {
109+
self.cell == other.cell
110+
}
111+
}
112+
113+
impl<V, F: Field> Eq for AssignedCell<V, F> {}
114+
115+
use std::hash::{Hash, Hasher};
116+
impl<V, F: Field> Hash for AssignedCell<V, F> {
117+
fn hash<H: Hasher>(&self, state: &mut H) {
118+
self.cell.hash(state)
119+
}
120+
}
121+
107122
impl<V, F: Field> AssignedCell<V, F> {
108123
/// Returns the value of the [`AssignedCell`].
109124
pub fn value(&self) -> Value<&V> {

src/circuit/floor_planner/v1/strategy.rs

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,22 +213,8 @@ pub fn slot_in_biggest_advice_first(
213213
advice_cols * shape.row_count()
214214
};
215215

216-
// This used to incorrectly use `sort_unstable_by_key` with non-unique keys, which gave
217-
// output that differed between 32-bit and 64-bit platforms, and potentially between Rust
218-
// versions.
219-
// We now use `sort_by_cached_key` with non-unique keys, and rely on `region_shapes`
220-
// being sorted by region index (which we also rely on below to return `RegionStart`s
221-
// in the correct order).
222-
#[cfg(not(feature = "floor-planner-v1-legacy-pdqsort"))]
223216
sorted_regions.sort_by_cached_key(sort_key);
224217

225-
// To preserve compatibility, when the "floor-planner-v1-legacy-pdqsort" feature is enabled,
226-
// we use a copy of the pdqsort implementation from the Rust 1.56.1 standard library, fixed
227-
// to its behaviour on 64-bit platforms.
228-
// https://github.com/rust-lang/rust/blob/1.56.1/library/core/src/slice/mod.rs#L2365-L2402
229-
#[cfg(feature = "floor-planner-v1-legacy-pdqsort")]
230-
halo2_legacy_pdqsort::sort::quicksort(&mut sorted_regions, |a, b| sort_key(a).lt(&sort_key(b)));
231-
232218
sorted_regions.reverse();
233219

234220
// Lay out the sorted regions.

src/dev.rs

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,14 @@ pub use tfp::TracingFloorPlanner;
4747
#[cfg(feature = "dev-graph")]
4848
mod graph;
4949

50+
use crate::plonk::VirtualCell;
5051
use crate::rational::Rational;
5152
#[cfg(feature = "dev-graph")]
5253
#[cfg_attr(docsrs, doc(cfg(feature = "dev-graph")))]
5354
pub use graph::{circuit_dot_graph, layout::CircuitLayout};
5455

56+
use crate::poly::Rotation;
57+
5558
#[derive(Debug)]
5659
struct Region {
5760
/// The name of the region. Not required to be unique.
@@ -820,7 +823,15 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
820823
}
821824
_ => {
822825
// Check that it was assigned!
823-
if r.cells.contains_key(&(cell.column, cell_row)) {
826+
if r.cells.contains_key(&(cell.column, cell_row))
827+
|| gate.polynomials().par_iter().all(|expr| {
828+
self.cell_is_irrelevant(
829+
cell,
830+
expr,
831+
gate_row as usize,
832+
)
833+
})
834+
{
824835
None
825836
} else {
826837
Some(VerifyFailure::CellNotAssigned {
@@ -1124,6 +1135,69 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
11241135
}
11251136
}
11261137

1138+
// Checks if the given expression is guaranteed to be constantly zero at the given offset.
1139+
fn expr_is_constantly_zero(&self, expr: &Expression<F>, offset: usize) -> bool {
1140+
match expr {
1141+
Expression::Constant(constant) => constant.is_zero().into(),
1142+
Expression::Selector(selector) => !self.selectors[selector.0][offset],
1143+
Expression::Fixed(query) => match self.fixed[query.column_index][offset] {
1144+
CellValue::Assigned(value) => value.is_zero().into(),
1145+
_ => false,
1146+
},
1147+
Expression::Scaled(e, factor) => {
1148+
factor.is_zero().into() || self.expr_is_constantly_zero(e, offset)
1149+
}
1150+
Expression::Sum(e1, e2) => {
1151+
self.expr_is_constantly_zero(e1, offset) && self.expr_is_constantly_zero(e2, offset)
1152+
}
1153+
Expression::Product(e1, e2) => {
1154+
self.expr_is_constantly_zero(e1, offset) || self.expr_is_constantly_zero(e2, offset)
1155+
}
1156+
_ => false,
1157+
}
1158+
}
1159+
1160+
// Verify that the value of the given cell within the given expression is
1161+
// irrelevant to the evaluation of the expression. This may be because
1162+
// the cell is always multiplied by an expression that evaluates to 0, or
1163+
// because the cell is not being queried in the expression at all.
1164+
fn cell_is_irrelevant(&self, cell: &VirtualCell, expr: &Expression<F>, offset: usize) -> bool {
1165+
// Check if a given query (defined by its columnd and rotation, since we
1166+
// want this function to support different query types) is equal to `cell`.
1167+
let eq_query = |query_column: usize, query_rotation: Rotation, col_type: Any| {
1168+
cell.column.index() == query_column
1169+
&& cell.column.column_type() == &col_type
1170+
&& query_rotation == cell.rotation
1171+
};
1172+
match expr {
1173+
Expression::Constant(_) | Expression::Selector(_) => true,
1174+
Expression::Fixed(query) => !eq_query(query.column_index, query.rotation(), Any::Fixed),
1175+
Expression::Advice(query) => !eq_query(
1176+
query.column_index,
1177+
query.rotation(),
1178+
Any::Advice(Advice::new(query.phase)),
1179+
),
1180+
Expression::Instance(query) => {
1181+
!eq_query(query.column_index, query.rotation(), Any::Instance)
1182+
}
1183+
Expression::Challenge(_) => true,
1184+
Expression::Negated(e) => self.cell_is_irrelevant(cell, e, offset),
1185+
Expression::Sum(e1, e2) => {
1186+
self.cell_is_irrelevant(cell, e1, offset)
1187+
&& self.cell_is_irrelevant(cell, e2, offset)
1188+
}
1189+
Expression::Product(e1, e2) => {
1190+
(self.expr_is_constantly_zero(e1, offset)
1191+
|| self.expr_is_constantly_zero(e2, offset))
1192+
|| (self.cell_is_irrelevant(cell, e1, offset)
1193+
&& self.cell_is_irrelevant(cell, e2, offset))
1194+
}
1195+
Expression::Scaled(e, factor) => {
1196+
factor.is_zero().into() || self.cell_is_irrelevant(cell, e, offset)
1197+
}
1198+
}
1199+
}
1200+
11271201
/// Panics if the circuit being checked by this `MockProver` is not satisfied.
11281202
///
11291203
/// Any verification failures will be pretty-printed to stderr before the function

src/dev/cost_model.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ use std::collections::HashSet;
55
use std::panic::AssertUnwindSafe;
66
use std::{iter, num::ParseIntError, panic, str::FromStr};
77

8+
use crate::plonk::Any::Fixed;
89
use crate::plonk::Circuit;
910
use ff::{Field, FromUniformBytes};
1011
use serde::Deserialize;
1112
use serde_derive::Serialize;
12-
use crate::plonk::Any::Fixed;
1313

1414
use super::MockProver;
1515

@@ -88,7 +88,8 @@ impl FromStr for Poly {
8888
pub struct Lookup;
8989

9090
impl Lookup {
91-
fn queries(&self) -> impl Iterator<Item = Poly> {
91+
/// Returns the queries of the lookup argument
92+
pub fn queries(&self) -> impl Iterator<Item = Poly> {
9293
// - product commitments at x and \omega x
9394
// - input commitments at x and x_inv
9495
// - table commitments at x
@@ -110,7 +111,8 @@ pub struct Permutation {
110111
}
111112

112113
impl Permutation {
113-
fn queries(&self) -> impl Iterator<Item = Poly> {
114+
/// Returns the queries of the Permutation argument
115+
pub fn queries(&self) -> impl Iterator<Item = Poly> {
114116
// - product commitments at x and x_inv
115117
// - polynomial commitments at x
116118
let product = "0,-1".parse().unwrap();
@@ -120,13 +122,22 @@ impl Permutation {
120122
.chain(Some(product))
121123
.chain(iter::repeat(poly).take(self.columns))
122124
}
125+
126+
/// Returns the number of columns of the Permutation argument
127+
pub fn nr_columns(&self) -> usize {
128+
self.columns
129+
}
123130
}
124131

125132
/// High-level specifications of an abstract circuit.
126133
#[derive(Debug, Deserialize, Serialize)]
127134
pub struct ModelCircuit {
128135
/// Power-of-2 bound on the number of rows in the circuit.
129136
pub k: usize,
137+
/// Number of rows in the circuit (not including table rows).
138+
pub rows: usize,
139+
/// Number of table rows in the circuit.
140+
pub table_rows: usize,
130141
/// Maximum degree of the circuit.
131142
pub max_deg: usize,
132143
/// Number of advice columns.
@@ -224,6 +235,8 @@ impl CostOptions {
224235

225236
ModelCircuit {
226237
k: self.min_k,
238+
rows: self.rows_count,
239+
table_rows: self.table_rows_count,
227240
max_deg: self.max_degree,
228241
advice_columns: self.advice.len(),
229242
lookups: self.lookup.len(),
@@ -260,7 +273,7 @@ fn run_mock_prover_with_fallback<F: Ord + Field + FromUniformBytes<64>, C: Circu
260273
panic::catch_unwind(AssertUnwindSafe(|| {
261274
MockProver::run(k, circuit, instances.clone()).unwrap()
262275
}))
263-
.ok()
276+
.ok()
264277
})
265278
.expect("A circuit which can be implemented with at most 2^24 rows.")
266279
}
@@ -338,11 +351,7 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
338351
// columns (see that [`plonk::circuit::TableColumn` is a wrapper
339352
// around `Column<Fixed>`]). All of a table region's rows are
340353
// counted towards `table_rows_count.`
341-
if region
342-
.columns
343-
.iter()
344-
.all(|c| *c.column_type() == Fixed)
345-
{
354+
if region.columns.iter().all(|c| *c.column_type() == Fixed) {
346355
table_rows_count += (end + 1) - start;
347356
} else {
348357
rows_count += (end + 1) - start;
@@ -358,9 +367,9 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
358367
table_rows_count + cs.blinding_factors(),
359368
instance_len,
360369
]
361-
.into_iter()
362-
.max()
363-
.unwrap();
370+
.into_iter()
371+
.max()
372+
.unwrap();
364373
if min_k == instance_len {
365374
println!("WARNING: The dominant factor in your circuit's size is the number of public inputs, which causes the verifier to perform linear work.");
366375
}

src/dev/util.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ pub(super) fn format_value<F: Field>(v: F) -> String {
6565
// Format value as hex.
6666
let s = format!("{v:?}");
6767
// Remove leading zeroes.
68-
let s = s.strip_prefix("0x").unwrap();
68+
let s = s.split_once("0x").unwrap().1.split(')').next().unwrap();
6969
let s = s.trim_start_matches('0');
7070
format!("0x{s}")
7171
}

src/helpers.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ pub trait SerdeCurveAffine: PrimeCurveAffine + SerdeObject + Default {
4949
/// Reads an element from the buffer and parses it according to the `format`:
5050
/// - `Processed`: Reads a compressed curve element and decompress it
5151
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
52-
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
52+
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
5353
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
54-
/// does not perform any checks
54+
/// does not perform any checks
5555
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
5656
match format {
5757
SerdeFormat::Processed => <Self as CurveRead>::read(reader),
@@ -83,9 +83,9 @@ impl<C: PrimeCurveAffine + SerdeObject + Default> SerdeCurveAffine for C {}
8383
pub trait SerdePrimeField: PrimeField + SerdeObject {
8484
/// Reads a field element as bytes from the buffer according to the `format`:
8585
/// - `Processed`: Reads a field element in standard form, with endianness specified by the
86-
/// `PrimeField` implementation, and checks that the element is less than the modulus.
86+
/// `PrimeField` implementation, and checks that the element is less than the modulus.
8787
/// - `RawBytes`: Reads a field element from raw bytes in its internal Montgomery representations,
88-
/// and checks that the element is less than the modulus.
88+
/// and checks that the element is less than the modulus.
8989
/// - `RawBytesUnchecked`: Reads a field element in Montgomery form and performs no checks.
9090
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
9191
match format {
@@ -103,9 +103,9 @@ pub trait SerdePrimeField: PrimeField + SerdeObject {
103103

104104
/// Writes a field element as bytes to the buffer according to the `format`:
105105
/// - `Processed`: Writes a field element in standard form, with endianness specified by the
106-
/// `PrimeField` implementation.
106+
/// `PrimeField` implementation.
107107
/// - Otherwise: Writes a field element into raw bytes in its internal Montgomery representation,
108-
/// WITHOUT performing the expensive Montgomery reduction.
108+
/// WITHOUT performing the expensive Montgomery reduction.
109109
fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
110110
match format {
111111
SerdeFormat::Processed => writer.write_all(self.to_repr().as_ref()),

src/plonk.rs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,12 @@ where
103103
///
104104
/// Reads a curve element from the buffer and parses it according to the `format`:
105105
/// - `Processed`: Reads a compressed curve element and decompresses it.
106-
/// Reads a field element in standard form, with endianness specified by the
107-
/// `PrimeField` implementation, and checks that the element is less than the modulus.
106+
/// Reads a field element in standard form, with endianness specified by the
107+
/// `PrimeField` implementation, and checks that the element is less than the modulus.
108108
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
109-
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
109+
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
110110
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
111-
/// does not perform any checks
111+
/// does not perform any checks
112112
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
113113
reader: &mut R,
114114
format: SerdeFormat,
@@ -334,12 +334,12 @@ where
334334
///
335335
/// Writes a curve element according to `format`:
336336
/// - `Processed`: Writes a compressed curve element with coordinates in standard form.
337-
/// Writes a field element in standard form, with endianness specified by the
337+
/// Writes a field element in standard form, with endianness specified by the
338338
/// `PrimeField` implementation.
339339
/// - Otherwise: Writes an uncompressed curve element with coordinates in Montgomery form
340-
/// Writes a field element into raw bytes in its internal Montgomery representation,
341-
/// WITHOUT performing the expensive Montgomery reduction.
342-
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
340+
/// Writes a field element into raw bytes in its internal Montgomery representation,
341+
/// WITHOUT performing the expensive Montgomery reduction.
342+
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
343343
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
344344
self.vk.write(writer, format)?;
345345
self.l0.write(writer, format)?;
@@ -357,12 +357,12 @@ where
357357
///
358358
/// Reads a curve element from the buffer and parses it according to the `format`:
359359
/// - `Processed`: Reads a compressed curve element and decompresses it.
360-
/// Reads a field element in standard form, with endianness specified by the
361-
/// `PrimeField` implementation, and checks that the element is less than the modulus.
360+
/// Reads a field element in standard form, with endianness specified by the
361+
/// `PrimeField` implementation, and checks that the element is less than the modulus.
362362
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
363-
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
363+
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
364364
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
365-
/// does not perform any checks
365+
/// does not perform any checks
366366
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
367367
reader: &mut R,
368368
format: SerdeFormat,

src/plonk/keygen.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use std::ops::Range;
44

55
use ff::{Field, FromUniformBytes, WithSmallOrderMulGroup};
6+
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
67

78
use super::{
89
circuit::{
@@ -312,12 +313,12 @@ where
312313
);
313314

314315
let fixed_polys: Vec<_> = fixed
315-
.iter()
316+
.par_iter()
316317
.map(|poly| vk.domain.lagrange_to_coeff(poly.clone()))
317318
.collect();
318319

319320
let fixed_cosets = fixed_polys
320-
.iter()
321+
.par_iter()
321322
.map(|poly| vk.domain.coeff_to_extended(poly.clone()))
322323
.collect();
323324

0 commit comments

Comments
 (0)