Skip to content

Commit

Permalink
* Implement PartialEq, Eq, Hash for Cell and AssignedCell
Browse files Browse the repository at this point in the history
* Add table, compressed and normal rows count.

* Add rows and table rows to cost model.

* Ignore unassigned cells if they are multiplied by zero

* Some format values are written as

"Scalar(0x..)"

The hotfix was to change the stripping rules, but this is probably
an incorrect implementation of certain traits for one of the curves.
  • Loading branch information
iquerejeta committed Dec 12, 2024
1 parent 0092a15 commit dd00d8e
Show file tree
Hide file tree
Showing 9 changed files with 152 additions and 61 deletions.
19 changes: 17 additions & 2 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ pub trait Chip<F: Field>: Sized {
}

/// Index of a region in a layouter
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct RegionIndex(usize);

impl From<usize> for RegionIndex {
Expand Down Expand Up @@ -86,7 +86,7 @@ impl std::ops::Deref for RegionStart {
}

/// A pointer to a cell within a circuit.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct Cell {
/// Identifies the region in which this cell resides.
pub region_index: RegionIndex,
Expand All @@ -104,6 +104,21 @@ pub struct AssignedCell<V, F: Field> {
_marker: PhantomData<F>,
}

impl<V, F: Field> PartialEq for AssignedCell<V, F> {
fn eq(&self, other: &Self) -> bool {
self.cell == other.cell
}
}

impl<V, F: Field> Eq for AssignedCell<V, F> {}

use std::hash::{Hash, Hasher};
impl<V, F: Field> Hash for AssignedCell<V, F> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.cell.hash(state)
}
}

impl<V, F: Field> AssignedCell<V, F> {
/// Returns the value of the [`AssignedCell`].
pub fn value(&self) -> Value<&V> {
Expand Down
14 changes: 0 additions & 14 deletions src/circuit/floor_planner/v1/strategy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -213,22 +213,8 @@ pub fn slot_in_biggest_advice_first(
advice_cols * shape.row_count()
};

// This used to incorrectly use `sort_unstable_by_key` with non-unique keys, which gave
// output that differed between 32-bit and 64-bit platforms, and potentially between Rust
// versions.
// We now use `sort_by_cached_key` with non-unique keys, and rely on `region_shapes`
// being sorted by region index (which we also rely on below to return `RegionStart`s
// in the correct order).
#[cfg(not(feature = "floor-planner-v1-legacy-pdqsort"))]
sorted_regions.sort_by_cached_key(sort_key);

// To preserve compatibility, when the "floor-planner-v1-legacy-pdqsort" feature is enabled,
// we use a copy of the pdqsort implementation from the Rust 1.56.1 standard library, fixed
// to its behaviour on 64-bit platforms.
// https://github.com/rust-lang/rust/blob/1.56.1/library/core/src/slice/mod.rs#L2365-L2402
#[cfg(feature = "floor-planner-v1-legacy-pdqsort")]
halo2_legacy_pdqsort::sort::quicksort(&mut sorted_regions, |a, b| sort_key(a).lt(&sort_key(b)));

sorted_regions.reverse();

// Lay out the sorted regions.
Expand Down
76 changes: 75 additions & 1 deletion src/dev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,14 @@ pub use tfp::TracingFloorPlanner;
#[cfg(feature = "dev-graph")]
mod graph;

use crate::plonk::VirtualCell;
use crate::rational::Rational;
#[cfg(feature = "dev-graph")]
#[cfg_attr(docsrs, doc(cfg(feature = "dev-graph")))]
pub use graph::{circuit_dot_graph, layout::CircuitLayout};

use crate::poly::Rotation;

#[derive(Debug)]
struct Region {
/// The name of the region. Not required to be unique.
Expand Down Expand Up @@ -820,7 +823,15 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
}
_ => {
// Check that it was assigned!
if r.cells.contains_key(&(cell.column, cell_row)) {
if r.cells.contains_key(&(cell.column, cell_row))
|| gate.polynomials().par_iter().all(|expr| {
self.cell_is_irrelevant(
cell,
expr,
gate_row as usize,
)
})
{
None
} else {
Some(VerifyFailure::CellNotAssigned {
Expand Down Expand Up @@ -1124,6 +1135,69 @@ impl<F: FromUniformBytes<64> + Ord> MockProver<F> {
}
}

// Checks if the given expression is guaranteed to be constantly zero at the given offset.
fn expr_is_constantly_zero(&self, expr: &Expression<F>, offset: usize) -> bool {
match expr {
Expression::Constant(constant) => constant.is_zero().into(),
Expression::Selector(selector) => !self.selectors[selector.0][offset],
Expression::Fixed(query) => match self.fixed[query.column_index][offset] {
CellValue::Assigned(value) => value.is_zero().into(),
_ => false,
},
Expression::Scaled(e, factor) => {
factor.is_zero().into() || self.expr_is_constantly_zero(e, offset)
}
Expression::Sum(e1, e2) => {
self.expr_is_constantly_zero(e1, offset) && self.expr_is_constantly_zero(e2, offset)
}
Expression::Product(e1, e2) => {
self.expr_is_constantly_zero(e1, offset) || self.expr_is_constantly_zero(e2, offset)
}
_ => false,
}
}

// Verify that the value of the given cell within the given expression is
// irrelevant to the evaluation of the expression. This may be because
// the cell is always multiplied by an expression that evaluates to 0, or
// because the cell is not being queried in the expression at all.
fn cell_is_irrelevant(&self, cell: &VirtualCell, expr: &Expression<F>, offset: usize) -> bool {
// Check if a given query (defined by its columnd and rotation, since we
// want this function to support different query types) is equal to `cell`.
let eq_query = |query_column: usize, query_rotation: Rotation, col_type: Any| {
cell.column.index() == query_column
&& cell.column.column_type() == &col_type
&& query_rotation == cell.rotation
};
match expr {
Expression::Constant(_) | Expression::Selector(_) => true,
Expression::Fixed(query) => !eq_query(query.column_index, query.rotation(), Any::Fixed),
Expression::Advice(query) => !eq_query(
query.column_index,
query.rotation(),
Any::Advice(Advice::new(query.phase)),
),
Expression::Instance(query) => {
!eq_query(query.column_index, query.rotation(), Any::Instance)
}
Expression::Challenge(_) => true,
Expression::Negated(e) => self.cell_is_irrelevant(cell, e, offset),
Expression::Sum(e1, e2) => {
self.cell_is_irrelevant(cell, e1, offset)
&& self.cell_is_irrelevant(cell, e2, offset)
}
Expression::Product(e1, e2) => {
(self.expr_is_constantly_zero(e1, offset)
|| self.expr_is_constantly_zero(e2, offset))
|| (self.cell_is_irrelevant(cell, e1, offset)
&& self.cell_is_irrelevant(cell, e2, offset))
}
Expression::Scaled(e, factor) => {
factor.is_zero().into() || self.cell_is_irrelevant(cell, e, offset)
}
}
}

/// Panics if the circuit being checked by this `MockProver` is not satisfied.
///
/// Any verification failures will be pretty-printed to stderr before the function
Expand Down
33 changes: 21 additions & 12 deletions src/dev/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ use std::collections::HashSet;
use std::panic::AssertUnwindSafe;
use std::{iter, num::ParseIntError, panic, str::FromStr};

use crate::plonk::Any::Fixed;
use crate::plonk::Circuit;
use ff::{Field, FromUniformBytes};
use serde::Deserialize;
use serde_derive::Serialize;
use crate::plonk::Any::Fixed;

use super::MockProver;

Expand Down Expand Up @@ -88,7 +88,8 @@ impl FromStr for Poly {
pub struct Lookup;

impl Lookup {
fn queries(&self) -> impl Iterator<Item = Poly> {
/// Returns the queries of the lookup argument
pub fn queries(&self) -> impl Iterator<Item = Poly> {
// - product commitments at x and \omega x
// - input commitments at x and x_inv
// - table commitments at x
Expand All @@ -110,7 +111,8 @@ pub struct Permutation {
}

impl Permutation {
fn queries(&self) -> impl Iterator<Item = Poly> {
/// Returns the queries of the Permutation argument
pub fn queries(&self) -> impl Iterator<Item = Poly> {
// - product commitments at x and x_inv
// - polynomial commitments at x
let product = "0,-1".parse().unwrap();
Expand All @@ -120,13 +122,22 @@ impl Permutation {
.chain(Some(product))
.chain(iter::repeat(poly).take(self.columns))
}

/// Returns the number of columns of the Permutation argument
pub fn nr_columns(&self) -> usize {
self.columns
}
}

/// High-level specifications of an abstract circuit.
#[derive(Debug, Deserialize, Serialize)]
pub struct ModelCircuit {
/// Power-of-2 bound on the number of rows in the circuit.
pub k: usize,
/// Number of rows in the circuit (not including table rows).
pub rows: usize,
/// Number of table rows in the circuit.
pub table_rows: usize,
/// Maximum degree of the circuit.
pub max_deg: usize,
/// Number of advice columns.
Expand Down Expand Up @@ -224,6 +235,8 @@ impl CostOptions {

ModelCircuit {
k: self.min_k,
rows: self.rows_count,
table_rows: self.table_rows_count,
max_deg: self.max_degree,
advice_columns: self.advice.len(),
lookups: self.lookup.len(),
Expand Down Expand Up @@ -260,7 +273,7 @@ fn run_mock_prover_with_fallback<F: Ord + Field + FromUniformBytes<64>, C: Circu
panic::catch_unwind(AssertUnwindSafe(|| {
MockProver::run(k, circuit, instances.clone()).unwrap()
}))
.ok()
.ok()
})
.expect("A circuit which can be implemented with at most 2^24 rows.")
}
Expand Down Expand Up @@ -338,11 +351,7 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
// columns (see that [`plonk::circuit::TableColumn` is a wrapper
// around `Column<Fixed>`]). All of a table region's rows are
// counted towards `table_rows_count.`
if region
.columns
.iter()
.all(|c| *c.column_type() == Fixed)
{
if region.columns.iter().all(|c| *c.column_type() == Fixed) {
table_rows_count += (end + 1) - start;
} else {
rows_count += (end + 1) - start;
Expand All @@ -358,9 +367,9 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
table_rows_count + cs.blinding_factors(),
instance_len,
]
.into_iter()
.max()
.unwrap();
.into_iter()
.max()
.unwrap();
if min_k == instance_len {
println!("WARNING: The dominant factor in your circuit's size is the number of public inputs, which causes the verifier to perform linear work.");
}
Expand Down
2 changes: 1 addition & 1 deletion src/dev/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub(super) fn format_value<F: Field>(v: F) -> String {
// Format value as hex.
let s = format!("{v:?}");
// Remove leading zeroes.
let s = s.strip_prefix("0x").unwrap();
let s = s.split_once("0x").unwrap().1.split(')').next().unwrap();
let s = s.trim_start_matches('0');
format!("0x{s}")
}
Expand Down
12 changes: 6 additions & 6 deletions src/helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ pub trait SerdeCurveAffine: PrimeCurveAffine + SerdeObject + Default {
/// Reads an element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompress it
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
match format {
SerdeFormat::Processed => <Self as CurveRead>::read(reader),
Expand Down Expand Up @@ -83,9 +83,9 @@ impl<C: PrimeCurveAffine + SerdeObject + Default> SerdeCurveAffine for C {}
pub trait SerdePrimeField: PrimeField + SerdeObject {
/// Reads a field element as bytes from the buffer according to the `format`:
/// - `Processed`: Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads a field element from raw bytes in its internal Montgomery representations,
/// and checks that the element is less than the modulus.
/// and checks that the element is less than the modulus.
/// - `RawBytesUnchecked`: Reads a field element in Montgomery form and performs no checks.
fn read<R: io::Read>(reader: &mut R, format: SerdeFormat) -> io::Result<Self> {
match format {
Expand All @@ -103,9 +103,9 @@ pub trait SerdePrimeField: PrimeField + SerdeObject {

/// Writes a field element as bytes to the buffer according to the `format`:
/// - `Processed`: Writes a field element in standard form, with endianness specified by the
/// `PrimeField` implementation.
/// `PrimeField` implementation.
/// - Otherwise: Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// WITHOUT performing the expensive Montgomery reduction.
fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
match format {
SerdeFormat::Processed => writer.write_all(self.to_repr().as_ref()),
Expand Down
24 changes: 12 additions & 12 deletions src/plonk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ where
///
/// Reads a curve element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompresses it.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
reader: &mut R,
format: SerdeFormat,
Expand Down Expand Up @@ -334,12 +334,12 @@ where
///
/// Writes a curve element according to `format`:
/// - `Processed`: Writes a compressed curve element with coordinates in standard form.
/// Writes a field element in standard form, with endianness specified by the
/// Writes a field element in standard form, with endianness specified by the
/// `PrimeField` implementation.
/// - Otherwise: Writes an uncompressed curve element with coordinates in Montgomery form
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
/// Writes a field element into raw bytes in its internal Montgomery representation,
/// WITHOUT performing the expensive Montgomery reduction.
/// Does so by first writing the verifying key and then serializing the rest of the data (in the form of field polynomials)
pub fn write<W: io::Write>(&self, writer: &mut W, format: SerdeFormat) -> io::Result<()> {
self.vk.write(writer, format)?;
self.l0.write(writer, format)?;
Expand All @@ -357,12 +357,12 @@ where
///
/// Reads a curve element from the buffer and parses it according to the `format`:
/// - `Processed`: Reads a compressed curve element and decompresses it.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// Reads a field element in standard form, with endianness specified by the
/// `PrimeField` implementation, and checks that the element is less than the modulus.
/// - `RawBytes`: Reads an uncompressed curve element with coordinates in Montgomery form.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// Checks that field elements are less than modulus, and then checks that the point is on the curve.
/// - `RawBytesUnchecked`: Reads an uncompressed curve element with coordinates in Montgomery form;
/// does not perform any checks
/// does not perform any checks
pub fn read<R: io::Read, ConcreteCircuit: Circuit<F>>(
reader: &mut R,
format: SerdeFormat,
Expand Down
5 changes: 3 additions & 2 deletions src/plonk/keygen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::ops::Range;

use ff::{Field, FromUniformBytes, WithSmallOrderMulGroup};
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

use super::{
circuit::{
Expand Down Expand Up @@ -312,12 +313,12 @@ where
);

let fixed_polys: Vec<_> = fixed
.iter()
.par_iter()
.map(|poly| vk.domain.lagrange_to_coeff(poly.clone()))
.collect();

let fixed_cosets = fixed_polys
.iter()
.par_iter()
.map(|poly| vk.domain.coeff_to_extended(poly.clone()))
.collect();

Expand Down
Loading

0 comments on commit dd00d8e

Please sign in to comment.