Skip to content

Commit

Permalink
cost_model: implement [compute_min_k]
Browse files Browse the repository at this point in the history
  • Loading branch information
miguel-ambrona committed Dec 18, 2024
1 parent 00ef2ce commit c3661b8
Showing 1 changed file with 10 additions and 14 deletions.
24 changes: 10 additions & 14 deletions halo2_frontend/src/dev/cost_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::collections::HashSet;
use std::panic::AssertUnwindSafe;
use std::{iter, num::ParseIntError, panic, str::FromStr};

use crate::circuit::compile_circuit;
use crate::plonk::Circuit;
use halo2_middleware::ff::{Field, FromUniformBytes};
use serde::Deserialize;
Expand Down Expand Up @@ -280,18 +281,16 @@ pub fn from_circuit_to_model_circuit<
options.into_model_circuit::<COMM, SCALAR>(comm_scheme)
}

fn run_mock_prover_with_fallback<F: Ord + Field + FromUniformBytes<64>, C: Circuit<F>>(
circuit: &C,
instances: Vec<Vec<F>>,
) -> MockProver<F> {
(5..25)
pub fn compute_min_k<F: Ord + Field + FromUniformBytes<64>, C: Circuit<F>>(circuit: &C) -> u32 {
// TODO: We could optimize the order here.
let (_, _, cs) = (5..25)
.find_map(|k| {
panic::catch_unwind(AssertUnwindSafe(|| {
MockProver::run(k, circuit, instances.clone()).unwrap()
}))
.ok()
panic::catch_unwind(AssertUnwindSafe(|| compile_circuit(k, circuit, false))).ok()
})
.expect("A circuit which can be implemented with at most 2^24 rows.")
.ok()
.unwrap();
cs.degree() as u32
}

/// Given a circuit, this function returns [CostOptions]. If no upper bound for `k` is
Expand All @@ -303,11 +302,8 @@ pub fn from_circuit_to_cost_model_options<F: Ord + Field + FromUniformBytes<64>,
) -> CostOptions {
let instance_len = instances.iter().map(Vec::len).max().unwrap_or(0);

let prover = if let Some(k) = k_upper_bound {
MockProver::run(k, circuit, instances).unwrap()
} else {
run_mock_prover_with_fallback(circuit, instances.clone())
};
let k = k_upper_bound.unwrap_or_else(|| compute_min_k(circuit));
let prover = MockProver::run(k, circuit, instances).unwrap();

let cs = prover.cs;

Expand Down

0 comments on commit c3661b8

Please sign in to comment.