diff --git a/Cargo.lock b/Cargo.lock index e650c9b..a62a669 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -516,9 +516,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.27" +version = "1.2.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d487aa071b5f64da6f19a3e848e3578944b726ee5a4854b82172f02aa876bfdc" +checksum = "deec109607ca693028562ed836a5f1c4b8bd77755c4e132fc5ce11b0b6211ae7" dependencies = [ "shlex", ] @@ -526,7 +526,7 @@ dependencies = [ [[package]] name = "ceno-examples" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "glob", ] @@ -559,6 +559,7 @@ dependencies = [ "p3-challenger", "p3-commit", "p3-field", + "p3-fri", "p3-goldilocks", "p3-matrix", "p3-monty-31", @@ -570,12 +571,13 @@ dependencies = [ "sumcheck", "tracing", "transcript", + "witness", ] [[package]] name = "ceno_emul" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "anyhow", "ceno_rt", @@ -598,7 +600,7 @@ dependencies = [ [[package]] name = "ceno_host" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "anyhow", "ceno_emul", @@ -611,7 +613,7 @@ dependencies = [ [[package]] name = "ceno_rt" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "getrandom 0.2.16", "rkyv", @@ -620,7 +622,7 @@ dependencies = [ [[package]] name = "ceno_zkvm" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "base64", "bincode", @@ -634,8 +636,10 @@ dependencies = [ "gkr_iop", "glob", "itertools 0.13.0", + "keccakf", "mpcs", "multilinear_extensions", + "ndarray", "num-traits", "p3", "parse-size", @@ -706,9 +710,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40b6887a1d8685cebccf115538db5c0efe625ccac9696ad45c409d96566e910f" +checksum = "be92d32e80243a54711e5d7ce823c35c41c9d929dc4ab58e1276f625841aadf9" dependencies = [ "clap_builder", "clap_derive", @@ -716,9 +720,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0c66c08ce9f0c698cbce5c0279d0bb6ac936d8674174fe48f736533b964f59e" +checksum = "707eab41e9622f9139419d573eca0900137718000c517d47da73045f54331c3d" dependencies = [ "anstream", "anstyle", @@ -728,9 +732,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.40" +version = "4.5.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c7947ae4cc3d851207c1adb5b5e260ff0cca11446b1d6d1423788e442257ce" +checksum = "ef4f52386a59ca4c860f7393bcf8abd8dfd91ecccc0f774635ff68e92eeef491" dependencies = [ "heck", "proc-macro2", @@ -1137,7 +1141,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "once_cell", "p3", @@ -1227,23 +1231,29 @@ checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f" [[package]] name = "gkr_iop" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "ark-std 0.5.0", "bincode", + "ceno_emul", "clap", "either", "ff_ext", "itertools 0.13.0", + "mpcs", "multilinear_extensions", - "ndarray", + "p3", "p3-field", "p3-goldilocks", + "p3-util", "rand", "rayon", "serde", + "strum", + "strum_macros", "sumcheck", "thiserror", + "thread_local", "tiny-keccak", "tracing", "tracing-forest", @@ -1561,6 +1571,15 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "keccakf" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d4ade81a4c9327bf19dcd0bd45784b99f86243edca6be0de19fc2d3aa8a4de2" +dependencies = [ + "crunchy", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -1584,9 +1603,9 @@ checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libredox" -version = "0.1.4" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1580801010e535496706ba011c15f8532df6b42297d2e471fec38ceadd8c0638" +checksum = "4488594b9328dee448adb906d8b126d9b7deb7cf5c22161ee591610bb1be83c0" dependencies = [ "bitflags", "libc", @@ -1699,7 +1718,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "aes", "bincode", @@ -1730,7 +1749,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "either", "ff_ext", @@ -1928,8 +1947,8 @@ checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" [[package]] name = "openvm" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "bytemuck", "num-bigint 0.4.6", @@ -1941,8 +1960,8 @@ dependencies = [ [[package]] name = "openvm-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "cfg-if", @@ -1972,8 +1991,8 @@ dependencies = [ [[package]] name = "openvm-circuit-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "itertools 0.14.0", "quote", @@ -1982,8 +2001,8 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "itertools 0.14.0", @@ -1997,8 +2016,8 @@ dependencies = [ [[package]] name = "openvm-circuit-primitives-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "itertools 0.14.0", "quote", @@ -2008,7 +2027,7 @@ dependencies = [ [[package]] name = "openvm-custom-insn" version = "0.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "proc-macro2", "quote", @@ -2017,8 +2036,8 @@ dependencies = [ [[package]] name = "openvm-instructions" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "derive-new 0.6.0", @@ -2034,8 +2053,8 @@ dependencies = [ [[package]] name = "openvm-instructions-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "quote", "syn 2.0.104", @@ -2043,8 +2062,8 @@ dependencies = [ [[package]] name = "openvm-native-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2070,8 +2089,8 @@ dependencies = [ [[package]] name = "openvm-native-compiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "backtrace", "itertools 0.14.0", @@ -2092,8 +2111,8 @@ dependencies = [ [[package]] name = "openvm-native-compiler-derive" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "quote", "syn 2.0.104", @@ -2101,8 +2120,8 @@ dependencies = [ [[package]] name = "openvm-native-recursion" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "cfg-if", "itertools 0.14.0", @@ -2125,10 +2144,9 @@ dependencies = [ [[package]] name = "openvm-platform" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ - "getrandom 0.2.16", "libm", "openvm-custom-insn", "openvm-rv32im-guest", @@ -2136,8 +2154,8 @@ dependencies = [ [[package]] name = "openvm-poseidon2-air" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derivative", "lazy_static", @@ -2153,8 +2171,8 @@ dependencies = [ [[package]] name = "openvm-rv32im-circuit" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "derive-new 0.6.0", "derive_more 1.0.0", @@ -2176,17 +2194,18 @@ dependencies = [ [[package]] name = "openvm-rv32im-guest" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "openvm-custom-insn", + "p3-field", "strum_macros", ] [[package]] name = "openvm-rv32im-transpiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "openvm-instructions", "openvm-instructions-derive", @@ -2201,8 +2220,8 @@ dependencies = [ [[package]] name = "openvm-stark-backend" -version = "1.0.0" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.0#884f8e6aabf72bde00dc51f1f1121277bff73b1e" +version = "1.1.1" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=f48090c9febd021f8ee0349bc929a775fb1fa3ad#f48090c9febd021f8ee0349bc929a775fb1fa3ad" dependencies = [ "bitcode", "cfg-if", @@ -2226,8 +2245,8 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" -version = "1.0.0" -source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.0#884f8e6aabf72bde00dc51f1f1121277bff73b1e" +version = "1.1.1" +source = "git+https://github.com/openvm-org/stark-backend.git?rev=f48090c9febd021f8ee0349bc929a775fb1fa3ad#f48090c9febd021f8ee0349bc929a775fb1fa3ad" dependencies = [ "derivative", "derive_more 0.99.20", @@ -2244,6 +2263,7 @@ dependencies = [ "p3-fri", "p3-goldilocks", "p3-keccak", + "p3-koala-bear", "p3-merkle-tree", "p3-poseidon", "p3-poseidon2", @@ -2261,8 +2281,8 @@ dependencies = [ [[package]] name = "openvm-transpiler" -version = "1.1.0" -source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#0a8b3571c5e123ba47c224afc02df08afc79784a" +version = "1.2.1-rc.0" +source = "git+https://github.com/scroll-tech/openvm.git?branch=feat%2Fnative_multi_observe#831470c9d5fbc4cd15c60dc87b2f7b75b2c28a2e" dependencies = [ "elf", "eyre", @@ -2291,7 +2311,7 @@ checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "p3-baby-bear", "p3-challenger", @@ -2313,7 +2333,7 @@ dependencies = [ [[package]] name = "p3-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-matrix", @@ -2322,7 +2342,7 @@ dependencies = [ [[package]] name = "p3-baby-bear" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-mds", @@ -2336,7 +2356,7 @@ dependencies = [ [[package]] name = "p3-blake3" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "blake3", "p3-symmetric", @@ -2346,7 +2366,7 @@ dependencies = [ [[package]] name = "p3-bn254-fr" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "ff 0.13.1", "halo2curves 0.8.0", @@ -2361,7 +2381,7 @@ dependencies = [ [[package]] name = "p3-challenger" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-maybe-rayon", @@ -2373,7 +2393,7 @@ dependencies = [ [[package]] name = "p3-commit" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-challenger", @@ -2387,7 +2407,7 @@ dependencies = [ [[package]] name = "p3-dft" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2400,7 +2420,7 @@ dependencies = [ [[package]] name = "p3-field" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "num-bigint 0.4.6", @@ -2417,7 +2437,7 @@ dependencies = [ [[package]] name = "p3-fri" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-challenger", @@ -2436,7 +2456,7 @@ dependencies = [ [[package]] name = "p3-goldilocks" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "num-bigint 0.4.6", "p3-dft", @@ -2453,7 +2473,7 @@ dependencies = [ [[package]] name = "p3-interpolation" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-matrix", @@ -2464,7 +2484,7 @@ dependencies = [ [[package]] name = "p3-keccak" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2473,10 +2493,24 @@ dependencies = [ "tiny-keccak", ] +[[package]] +name = "p3-koala-bear" +version = "0.1.0" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" +dependencies = [ + "p3-field", + "p3-mds", + "p3-monty-31", + "p3-poseidon2", + "p3-symmetric", + "rand", + "serde", +] + [[package]] name = "p3-matrix" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2491,7 +2525,7 @@ dependencies = [ [[package]] name = "p3-maybe-rayon" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "rayon", ] @@ -2499,7 +2533,7 @@ dependencies = [ [[package]] name = "p3-mds" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-dft", @@ -2513,7 +2547,7 @@ dependencies = [ [[package]] name = "p3-merkle-tree" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-commit", @@ -2530,7 +2564,7 @@ dependencies = [ [[package]] name = "p3-monty-31" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "num-bigint 0.4.6", @@ -2551,7 +2585,7 @@ dependencies = [ [[package]] name = "p3-poseidon" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-field", "p3-mds", @@ -2562,7 +2596,7 @@ dependencies = [ [[package]] name = "p3-poseidon2" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "gcd", "p3-field", @@ -2574,7 +2608,7 @@ dependencies = [ [[package]] name = "p3-poseidon2-air" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "p3-air", "p3-field", @@ -2590,7 +2624,7 @@ dependencies = [ [[package]] name = "p3-symmetric" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-field", @@ -2600,7 +2634,7 @@ dependencies = [ [[package]] name = "p3-uni-stark" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "itertools 0.14.0", "p3-air", @@ -2618,7 +2652,7 @@ dependencies = [ [[package]] name = "p3-util" version = "0.1.0" -source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" +source = "git+https://github.com/Plonky3/Plonky3.git?rev=539bbc84085efb609f4f62cb03cf49588388abdb#539bbc84085efb609f4f62cb03cf49588388abdb" dependencies = [ "serde", ] @@ -2735,7 +2769,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "criterion", "ff_ext", @@ -3093,15 +3127,15 @@ dependencies = [ [[package]] name = "rustix" -version = "1.0.7" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +checksum = "11181fbabf243db407ef8df94a6ce0b2f9a733bd8be4ad02b4eda9602296cac8" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.59.0", + "windows-sys 0.60.2", ] [[package]] @@ -3203,9 +3237,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.141" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" dependencies = [ "itoa", "memchr", @@ -3344,7 +3378,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "crossbeam-channel", "either", @@ -3363,7 +3397,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "itertools 0.13.0", "p3", @@ -3551,7 +3585,7 @@ dependencies = [ "serde_spanned", "toml_datetime", "toml_write", - "winnow 0.7.11", + "winnow 0.7.12", ] [[package]] @@ -3637,7 +3671,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "crossbeam-channel", "ff_ext", @@ -3815,7 +3849,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "bincode", "blake2", @@ -4032,9 +4066,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.11" +version = "0.7.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74c7b26e3480b707944fc872477815d29a8e429d2f93a1ce000f5fa84a15cbcd" +checksum = "f3edebf492c8125044983378ecb5766203ad3b4c2f7a922bd7dd207f6d443e95" dependencies = [ "memchr", ] @@ -4051,7 +4085,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/ceno.git?branch=feat%2Fsmaller_field_support#70a4f50b3dd26919b464d3350d1af34e847e418e" +source = "git+https://github.com/scroll-tech/ceno.git?branch=build%2Fsmaller_field_support_plonky3_539bbc#1acd5edfd68727e431aad0daead3a1ef80c918f2" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index 5b70bcb..41488ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,8 @@ openvm-native-circuit = { git = "https://github.com/scroll-tech/openvm.git", bra openvm-native-compiler = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-compiler-derive = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } openvm-native-recursion = { git = "https://github.com/scroll-tech/openvm.git", branch = "feat/native_multi_observe", default-features = false } -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", rev = "f48090c9febd021f8ee0349bc929a775fb1fa3ad", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", rev = "f48090c9febd021f8ee0349bc929a775fb1fa3ad", default-features = false } rand = { version = "0.8.5", default-features = false } itertools = { version = "0.13.0", default-features = false } @@ -19,16 +19,17 @@ bincode = "1.3.3" tracing = "0.1.40" # Plonky3 -p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "1ba4e5c" } -p3-goldilocks = { git = "https://github.com/Plonky3/plonky3", rev = "1ba4e5c" } +p3-air = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-field = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-commit = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-matrix = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-util = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-challenger = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-monty-31 = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-symmetric = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-fri = { git = "https://github.com/Plonky3/Plonky3.git", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } +p3-goldilocks = { git = "https://github.com/Plonky3/plonky3", rev = "539bbc84085efb609f4f62cb03cf49588388abdb" } # WHIR ark-std = { version = "0.5", features = ["std"] } @@ -37,15 +38,26 @@ ark-poly = "0.5" ark-serialize = "0.5" # Ceno -ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "multilinear_extensions" } -ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "sumcheck" } -ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support", package = "transcript" } -ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } -ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "feat/smaller_field_support" } +ceno_mle = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "multilinear_extensions" } +ceno_sumcheck = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "sumcheck" } +ceno_transcript = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "transcript" } +ceno_witness = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc", package = "witness" } +ceno_zkvm = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +ceno_emul = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +mpcs = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } +ff_ext = { git = "https://github.com/scroll-tech/ceno.git", branch = "build/smaller_field_support_plonky3_539bbc" } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [features] bench-metrics = ["openvm-circuit/bench-metrics"] + +# [patch."https://github.com/scroll-tech/ceno.git"] +# ceno_mle = { path = "../ceno/multilinear_extensions", package = "multilinear_extensions" } +# ceno_sumcheck = { path = "../ceno/sumcheck", package = "sumcheck" } +# ceno_transcript = { path = "../ceno/transcript", package = "transcript" } +# ceno_witness = { path = "../ceno/witness", package = "witness" } +# ceno_zkvm = { path = "../ceno/ceno_zkvm" } +# ceno_emul = { path = "../ceno/ceno_emul" } +# mpcs = { path = "../ceno/mpcs" } +# ff_ext = { path = "../ceno/ff_ext" } diff --git a/src/arithmetics/mod.rs b/src/arithmetics/mod.rs index 68cecae..c79494d 100644 --- a/src/arithmetics/mod.rs +++ b/src/arithmetics/mod.rs @@ -1,6 +1,6 @@ use crate::tower_verifier::binding::PointAndEvalVariable; -use crate::zkvm_verifier::binding::ZKVMOpcodeProofInputVariable; -use ceno_mle::expression::{Expression, Fixed, Instance}; +use crate::zkvm_verifier::binding::ZKVMChipProofInputVariable; +use ceno_mle::{Expression, Fixed, Instance}; use ceno_zkvm::structs::{ChallengeId, WitnessId}; use ff_ext::ExtensionField; use ff_ext::{BabyBearExt4, SmallField}; @@ -41,10 +41,6 @@ pub unsafe fn exts_to_felts( builder: &mut Builder, exts: &Array>, ) -> Array> { - assert!( - matches!(exts, Array::Dyn(_, _)), - "Expected dynamic array of Exts" - ); let f_len: Usize = builder.eval(exts.len() * Usize::from(C::EF::D)); let f_arr: Array> = Array::Dyn(exts.ptr(), f_len); f_arr @@ -101,6 +97,23 @@ pub fn evaluate_at_point_degree_1( builder.eval(r * (right - left) + left) } +pub fn fixed_dot_product( + builder: &mut Builder, + a: &[Ext], + b: &Array>, + zero: Ext, +) -> Ext<::F, ::EF> { + // simple trick to prefer AddE(1 cycle) than AddEI(4 cycles) + let acc: Ext = builder.eval(zero + zero); + + for (i, va) in a.iter().enumerate() { + let vb = builder.get(b, i); + builder.assign(&acc, acc + *va * vb); + } + + acc +} + pub struct PolyEvaluator { powers_of_2: Array>, } @@ -191,19 +204,20 @@ pub fn dot_product( acc } -pub fn fixed_dot_product( +pub fn dot_product_pt_n_eval( builder: &mut Builder, - a: &[Ext], + pt_and_eval: &Array>, b: &Array>, - zero: Ext, ) -> Ext<::F, ::EF> { - // simple trick to prefer AddE(1 cycle) than AddEI(4 cycles) - let acc: Ext = builder.eval(zero + zero); - - for (i, va) in a.iter().enumerate() { - let vb = builder.get(b, i); - builder.assign(&acc, acc + *va * vb); - } + let acc: Ext = builder.eval(C::F::ZERO); + + iter_zip!(builder, pt_and_eval, b).for_each(|idx_vec, builder| { + let ptr_a = idx_vec[0]; + let ptr_b = idx_vec[1]; + let v_a = builder.iter_ptr_get(&pt_and_eval, ptr_a); + let v_b = builder.iter_ptr_get(&b, ptr_b); + builder.assign(&acc, acc + v_a.eval * v_b); + }); acc } @@ -281,6 +295,32 @@ pub fn eq_eval( acc } +// Evaluate eq polynomial. +pub fn eq_eval_with_index( + builder: &mut Builder, + x: &Array>, + y: &Array>, + xlo: Usize, + ylo: Usize, + len: Usize, +) -> Ext { + let acc: Ext = builder.constant(C::EF::ONE); + + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let ptr_x: Var = builder.eval(xlo.clone() + i); + let ptr_y: Var = builder.eval(ylo.clone() + i); + let v_x = builder.get(&x, ptr_x); + let v_y = builder.get(&y, ptr_y); + let xi_yi: Ext = builder.eval(v_x * v_y); + let one: Ext = builder.constant(C::EF::ONE); + let new_acc: Ext = builder.eval(acc * (xi_yi + xi_yi - v_x - v_y + one)); + builder.assign(&acc, new_acc); + }); + + acc +} + // Multiply all elements in the Array pub fn product( builder: &mut Builder, @@ -327,21 +367,29 @@ pub fn sum( acc } -// Extend an array by one element -pub fn extend( +// Join two arrays +pub fn join( builder: &mut Builder, - arr: &Array>, - elem: &Ext, + a: &Array>, + b: &Array>, ) -> Array> { - let new_len: Var = builder.eval(arr.len() + C::N::ONE); - let out = builder.dyn_array(new_len); + let a_len = a.len(); + let b_len = b.len(); + let out_len = builder.eval_expr(a_len.clone() + b_len.clone()); + let out = builder.dyn_array(out_len); - builder.range(0, arr.len()).for_each(|i_vec, builder| { + builder.range(0, a_len.clone()).for_each(|i_vec, builder| { let i = i_vec[0]; - let val = builder.get(arr, i); - builder.set_value(&out, i, val); + let a_val = builder.get(a, i); + builder.set(&out, i, a_val); + }); + + builder.range(0, b_len).for_each(|i_vec, builder| { + let b_i = i_vec[0]; + let i = builder.eval_expr(b_i + a_len.clone()); + let b_val = builder.get(b, b_i); + builder.set(&out, i, b_val); }); - builder.set_value(&out, arr.len(), elem.clone()); out } @@ -374,7 +422,7 @@ pub fn gen_alpha_pows( pub fn eq_eval_less_or_equal_than( builder: &mut Builder, _challenger: &mut DuplexChallengerVariable, - opcode_proof: &ZKVMOpcodeProofInputVariable, + opcode_proof: &ZKVMChipProofInputVariable, a: &Array>, b: &Array>, ) -> Ext { @@ -471,6 +519,35 @@ pub fn build_eq_x_r_vec_sequential( evals } +pub fn build_eq_x_r_vec_sequential_with_offset( + builder: &mut Builder, + r: &Array>, + offset: Usize, +) -> Array> { + // we build eq(x,r) from its evaluations + // we want to evaluate eq(x,r) over x \in {0, 1}^num_vars + // for example, with num_vars = 4, x is a binary vector of 4, then + // 0 0 0 0 -> (1-r0) * (1-r1) * (1-r2) * (1-r3) + // 1 0 0 0 -> r0 * (1-r1) * (1-r2) * (1-r3) + // 0 1 0 0 -> (1-r0) * r1 * (1-r2) * (1-r3) + // 1 1 0 0 -> r0 * r1 * (1-r2) * (1-r3) + // .... + // 1 1 1 1 -> r0 * r1 * r2 * r3 + // we will need 2^num_var evaluations + + let r_len: Var = builder.eval(r.len() - offset); + let evals_len: Felt = builder.constant(C::F::ONE); + let evals_len = builder.exp_power_of_2_v::>(evals_len, r_len); + let evals_len = builder.cast_felt_to_var(evals_len); + + let evals: Array> = builder.dyn_array(evals_len); + // _debug + // build_eq_x_r_helper_sequential_offset(r, &mut evals, E::ONE); + // unsafe { std::mem::transmute(evals) } + // FIXME: this function is not implemented yet + evals +} + pub fn ceil_log2(x: usize) -> usize { assert!(x > 0, "ceil_log2: x must be positive"); // Calculate the number of bits in usize @@ -829,7 +906,7 @@ impl UniPolyExtrapolator { p_i: &Array>, eval_at: Ext, ) -> Ext { - let res: Ext = builder.eval(self.constants[0] + self.constants[0]); + let res: Ext = builder.constant(C::EF::ZERO); builder.if_eq(p_i.len(), Usize::from(4)).then_or_else( |builder| { @@ -884,8 +961,8 @@ impl UniPolyExtrapolator { let p_i_0 = builder.get(p_i, 0); let p_i_1 = builder.get(p_i, 1); - let t0: Ext = builder.eval(self.constants[5] * p_i_0 / d0); - let t1: Ext = builder.eval(self.constants[1] * p_i_1 / d1); + let t0: Ext = builder.eval(self.constants[5] * p_i_0 * d0.inverse()); + let t1: Ext = builder.eval(self.constants[1] * p_i_1 * d1.inverse()); builder.eval(l * (t0 + t1)) } @@ -909,9 +986,9 @@ impl UniPolyExtrapolator { let p_i_1: Ext = builder.get(p_i, 1); let p_i_2: Ext = builder.get(p_i, 2); - let t0: Ext = builder.eval(self.constants[6] * p_i_0 / d0); - let t1: Ext = builder.eval(self.constants[5] * p_i_1 / d1); - let t2: Ext = builder.eval(self.constants[6] * p_i_2 / d2); + let t0: Ext = builder.eval(self.constants[6] * p_i_0 * d0.inverse()); + let t1: Ext = builder.eval(self.constants[5] * p_i_1 * d1.inverse()); + let t2: Ext = builder.eval(self.constants[6] * p_i_2 * d2.inverse()); builder.eval(l * (t0 + t1 + t2)) } @@ -938,10 +1015,10 @@ impl UniPolyExtrapolator { let p_i_2: Ext = builder.get(p_i, 2); let p_i_3: Ext = builder.get(p_i, 3); - let t0: Ext = builder.eval(self.constants[9] * p_i_0 / d0); - let t1: Ext = builder.eval(self.constants[6] * p_i_1 / d1); - let t2: Ext = builder.eval(self.constants[7] * p_i_2 / d2); - let t3: Ext = builder.eval(self.constants[8] * p_i_3 / d3); + let t0: Ext = builder.eval(self.constants[9] * p_i_0 * d0.inverse()); + let t1: Ext = builder.eval(self.constants[6] * p_i_1 * d1.inverse()); + let t2: Ext = builder.eval(self.constants[7] * p_i_2 * d2.inverse()); + let t3: Ext = builder.eval(self.constants[8] * p_i_3 * d3.inverse()); builder.eval(l * (t0 + t1 + t2 + t3)) } @@ -971,12 +1048,30 @@ impl UniPolyExtrapolator { let p_i_3: Ext = builder.get(p_i, 3); let p_i_4: Ext = builder.get(p_i, 4); - let t0: Ext = builder.eval(self.constants[11] * p_i_0 / d0); - let t1: Ext = builder.eval(self.constants[9] * p_i_1 / d1); - let t2: Ext = builder.eval(self.constants[10] * p_i_2 / d2); - let t3: Ext = builder.eval(self.constants[9] * p_i_3 / d3); - let t4: Ext = builder.eval(self.constants[11] * p_i_4 / d4); + let t0: Ext = builder.eval(self.constants[11] * p_i_0 * d0.inverse()); + let t1: Ext = builder.eval(self.constants[9] * p_i_1 * d1.inverse()); + let t2: Ext = builder.eval(self.constants[10] * p_i_2 * d2.inverse()); + let t3: Ext = builder.eval(self.constants[9] * p_i_3 * d3.inverse()); + let t4: Ext = builder.eval(self.constants[11] * p_i_4 * d4.inverse()); builder.eval(l * (t0 + t1 + t2 + t3 + t4)) } } + +pub fn extend( + builder: &mut Builder, + arr: &Array>, + elem: &Ext, +) -> Array> { + let new_len: Var = builder.eval(arr.len() + C::N::ONE); + let out = builder.dyn_array(new_len); + + builder.range(0, arr.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let val = builder.get(arr, i); + builder.set_value(&out, i, val); + }); + builder.set_value(&out, arr.len(), elem.clone()); + + out +} diff --git a/src/basefold_verifier/basefold.rs b/src/basefold_verifier/basefold.rs new file mode 100644 index 0000000..8340e54 --- /dev/null +++ b/src/basefold_verifier/basefold.rs @@ -0,0 +1,224 @@ +use std::collections::BTreeMap; + +use itertools::Itertools; +use mpcs::basefold::BasefoldProof as InnerBasefoldProof; +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_compiler_derive::iter_zip; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; + +use crate::{ + basefold_verifier::{ + hash::{Hash, HashVariable}, + query_phase::{ + PointAndEvals, PointAndEvalsVariable, QueryOpeningProofs, QueryOpeningProofsVariable, + }, + }, + tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, +}; + +use super::{mmcs::*, structs::DEGREE}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub type HashDigest = MmcsCommitment; +#[derive(Deserialize, Debug)] +pub struct BasefoldCommitment { + pub commit: HashDigest, + pub log2_max_codeword_size: usize, +} + +use mpcs::BasefoldCommitment as InnerBasefoldCommitment; + +impl From> for BasefoldCommitment { + fn from(value: InnerBasefoldCommitment) -> Self { + Self { + commit: Hash { + value: value.commit().into(), + }, + log2_max_codeword_size: value.log2_max_codeword_size, + } + } +} + +impl Hintable for BasefoldCommitment { + type HintVariable = BasefoldCommitmentVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = HashDigest::read(builder); + let log2_max_codeword_size = Usize::Var(usize::read(builder)); + + BasefoldCommitmentVariable { + commit, + log2_max_codeword_size, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commit.write()); + stream.extend(>::write( + &self.log2_max_codeword_size, + )); + stream + } +} + +pub type HashDigestVariable = MmcsCommitmentVariable; +#[derive(DslVariable, Clone)] +pub struct BasefoldCommitmentVariable { + pub commit: HashDigestVariable, + pub log2_max_codeword_size: Usize, +} + +#[derive(Deserialize)] +pub struct BasefoldProof { + pub commits: Vec, + pub final_message: Vec>, + pub query_opening_proof: QueryOpeningProofs, + pub sumcheck_proof: Vec, + pub pow_witness: F, +} + +#[derive(DslVariable, Clone)] +pub struct BasefoldProofVariable { + pub commits: Array>, + pub final_message: Array>>, + pub query_opening_proof: QueryOpeningProofsVariable, + pub sumcheck_proof: Array>, + pub pow_witness: Felt, +} + +impl Hintable for BasefoldProof { + type HintVariable = BasefoldProofVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let commits = Vec::::read(builder); + let final_message = Vec::>::read(builder); + let query_opening_proof = QueryOpeningProofs::read(builder); + let sumcheck_proof = Vec::::read(builder); + let pow_witness = F::read(builder); + BasefoldProofVariable { + commits, + final_message, + query_opening_proof, + sumcheck_proof, + pow_witness, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commits.write()); + stream.extend(self.final_message.write()); + stream.extend(self.query_opening_proof.write()); + stream.extend(self.sumcheck_proof.write()); + stream.extend(self.pow_witness.write()); + stream + } +} + +impl From> for BasefoldProof { + fn from(proof: InnerBasefoldProof) -> Self { + BasefoldProof { + commits: proof.commits.iter().map(|c| c.clone().into()).collect(), + final_message: proof.final_message.into(), + query_opening_proof: proof + .query_opening_proof + .iter() + .map(|proof| proof.clone().into()) + .collect(), + sumcheck_proof: proof.sumcheck_proof.map_or(vec![], |proof| { + proof.into_iter().map(|proof| proof.into()).collect() + }), + pow_witness: proof.pow_witness, + } + } +} + +#[derive(Deserialize)] +pub struct RoundOpening { + pub num_var: usize, + pub point_and_evals: PointAndEvals, +} + +#[derive(DslVariable, Clone)] +pub struct RoundOpeningVariable { + pub num_var: Var, + pub point_and_evals: PointAndEvalsVariable, +} + +impl Hintable for RoundOpening { + type HintVariable = RoundOpeningVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let num_var = usize::read(builder); + let point_and_evals = PointAndEvals::read(builder); + RoundOpeningVariable { + num_var, + point_and_evals, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = vec![]; + stream.extend(>::write(&self.num_var)); + stream.extend(self.point_and_evals.write()); + stream + } +} + +impl VecAutoHintable for RoundOpening {} + +#[derive(Deserialize)] +pub struct Round { + pub commit: BasefoldCommitment, + pub openings: Vec, +} + +#[derive(DslVariable, Clone)] +pub struct RoundVariable { + pub commit: BasefoldCommitmentVariable, + pub openings: Array>, + pub perm: Array>, +} + +impl Hintable for Round { + type HintVariable = RoundVariable; + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = BasefoldCommitment::read(builder); + let openings = Vec::::read(builder); + let perm = Vec::::read(builder); + RoundVariable { + commit, + openings, + perm, + } + } + + fn write(&self) -> Vec::N>> { + let mut perm = vec![0; self.openings.len()]; + self.openings + .iter() + .enumerate() + // the original order + .map(|(i, opening)| (i, opening.num_var)) + .sorted_by(|(_, nv_a), (_, nv_b)| Ord::cmp(nv_b, nv_a)) + .enumerate() + // j is the new index where i is the original index + .map(|(j, (i, _))| (i, j)) + .for_each(|(i, j)| { + perm[i] = j; + }); + let mut stream = vec![]; + stream.extend(self.commit.write()); + stream.extend(self.openings.write()); + stream.extend(perm.write()); + + stream + } +} + +impl VecAutoHintable for Round {} diff --git a/src/basefold_verifier/extension_mmcs.rs b/src/basefold_verifier/extension_mmcs.rs new file mode 100644 index 0000000..6e44863 --- /dev/null +++ b/src/basefold_verifier/extension_mmcs.rs @@ -0,0 +1,89 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +use super::{mmcs::*, structs::*}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub struct ExtMmcsVerifierInput { + pub commit: MmcsCommitment, + pub dimensions: Vec, + pub index: usize, + pub opened_values: Vec>, + pub proof: MmcsProof, +} + +#[derive(DslVariable, Clone)] +pub struct ExtMmcsVerifierInputVariable { + pub commit: MmcsCommitmentVariable, + pub dimensions: Array>, + pub index_bits: Array>, + pub opened_values: Array>>, + pub proof: HintSlice, +} + +impl Hintable for ExtMmcsVerifierInput { + type HintVariable = ExtMmcsVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = MmcsCommitment::read(builder); + let dimensions = Vec::::read(builder); + let index_bits = Vec::::read(builder); + let opened_values = Vec::>::read(builder); + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + let proof = HintSlice { length, id }; + + ExtMmcsVerifierInputVariable { + commit, + dimensions, + index_bits, + opened_values, + proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.commit.write()); + stream.extend(self.dimensions.write()); + let mut index_bits = Vec::new(); + let mut index = self.index; + while index > 0 { + index_bits.push(index % 2); + index /= 2; + } + stream.extend( as Hintable>::write(&index_bits)); + stream.extend(self.opened_values.write()); + stream.extend( + self.proof + .iter() + .map(|p| p.to_vec()) + .collect::>() + .write(), + ); + stream + } +} + +pub(crate) fn ext_mmcs_verify_batch( + builder: &mut Builder, + input: ExtMmcsVerifierInputVariable, +) { + let dimensions = match input.dimensions { + Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), + _ => panic!("Expected a dynamic array of felts"), + }; + + builder.verify_batch_ext( + &dimensions, + &input.opened_values, + input.proof.id.get_var(), + &input.index_bits, + &input.commit.value, + ); +} diff --git a/src/basefold_verifier/field.rs b/src/basefold_verifier/field.rs new file mode 100644 index 0000000..64eea03 --- /dev/null +++ b/src/basefold_verifier/field.rs @@ -0,0 +1,54 @@ +const TWO_ADICITY: usize = 32; +const TWO_ADIC_GENERATORS: [usize; 33] = [ + 0x0000000000000001, + 0xffffffff00000000, + 0x0001000000000000, + 0xfffffffeff000001, + 0xefffffff00000001, + 0x00003fffffffc000, + 0x0000008000000000, + 0xf80007ff08000001, + 0xbf79143ce60ca966, + 0x1905d02a5c411f4e, + 0x9d8f2ad78bfed972, + 0x0653b4801da1c8cf, + 0xf2c35199959dfcb6, + 0x1544ef2335d17997, + 0xe0ee099310bba1e2, + 0xf6b2cffe2306baac, + 0x54df9630bf79450e, + 0xabd0a6e8aa3d8a0e, + 0x81281a7b05f9beac, + 0xfbd41c6b8caa3302, + 0x30ba2ecd5e93e76d, + 0xf502aef532322654, + 0x4b2a18ade67246b5, + 0xea9d5a1336fbc98b, + 0x86cdcc31c307e171, + 0x4bbaf5976ecfefd8, + 0xed41d05b78d6e286, + 0x10d78dd8915a171d, + 0x59049500004a4485, + 0xdfa8c93ba46d2666, + 0x7e9bd009b86a0845, + 0x400a7f755588e659, + 0x185629dcda58878c, +]; + +use openvm_native_compiler::prelude::*; +use p3_field::FieldAlgebra; + +fn two_adic_generator( + builder: &mut Builder, + bits: Var, +) -> Var { + let bits_limit = builder.eval(Usize::from(TWO_ADICITY) + Usize::from(1)); + builder.assert_less_than_slow_small_rhs(bits, bits_limit); + + let two_adic_generator: Array::F>> = builder.dyn_array(TWO_ADICITY + 1); + builder.range(0, TWO_ADICITY + 1).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set_value(&two_adic_generator, i, C::F::from_canonical_usize(TWO_ADIC_GENERATORS[i.value()])); + }); + builder.get(&two_adic_generator, bits) +} \ No newline at end of file diff --git a/src/basefold_verifier/hash.rs b/src/basefold_verifier/hash.rs new file mode 100644 index 0000000..8cede68 --- /dev/null +++ b/src/basefold_verifier/hash.rs @@ -0,0 +1,76 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; +use serde::Deserialize; + +use super::structs::DEGREE; + +pub const DIGEST_ELEMS: usize = 8; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +#[derive(Deserialize, Default, Debug)] +pub struct Hash { + pub value: [F; DIGEST_ELEMS], +} + +impl From> for Hash { + fn from(hash: p3_symmetric::Hash) -> Self { + Hash { value: hash.into() } + } +} + +#[derive(DslVariable, Clone)] +pub struct HashVariable { + pub value: Array>, +} + +impl VecAutoHintable for Hash {} + +impl Hintable for Hash { + type HintVariable = HashVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let value = builder.hint_felts_fixed(DIGEST_ELEMS); + + HashVariable { value } + } + + fn write(&self) -> Vec::N>> { + self.value.map(|felt| vec![felt]).to_vec() + } +} + +#[cfg(test)] +mod tests { + use openvm_circuit::arch::{SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + + use crate::basefold_verifier::basefold::HashDigest; + + use super::*; + #[test] + fn test_read_to_hash_variable() { + // simple test program + let mut builder = AsmBuilder::::default(); + let _digest = HashDigest::read(&mut builder); + builder.halt(); + + // configure the VM executor + let system_config = SystemConfig::default().with_max_segment_len(1 << 20); + let config = NativeConfig::new(system_config, Native); + let executor = VmExecutor::new(config); + + // prepare input + let mut input = Vec::new(); + input.extend(Hash::default().write()); + + // execute the program + let program = builder.compile_isa(); + executor.execute(program, input).unwrap(); + } +} diff --git a/src/basefold_verifier/mmcs.rs b/src/basefold_verifier/mmcs.rs new file mode 100644 index 0000000..1f030dc --- /dev/null +++ b/src/basefold_verifier/mmcs.rs @@ -0,0 +1,270 @@ +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::{hints::Hintable, vars::HintSlice}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +use crate::basefold_verifier::utils::{read_hint_slice, write_mmcs_proof}; + +use super::{hash::*, structs::*}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub type MmcsCommitment = Hash; +pub type MmcsProof = Vec<[F; DIGEST_ELEMS]>; + +pub struct MmcsVerifierInput { + pub commit: MmcsCommitment, + pub dimensions: Vec, + pub index: usize, + pub opened_values: Vec>, + pub proof: MmcsProof, +} + +pub type MmcsCommitmentVariable = HashVariable; + +#[derive(DslVariable, Clone)] +pub struct MmcsVerifierInputVariable { + pub commit: MmcsCommitmentVariable, + pub dimensions: Array>, + pub index_bits: Array>, + pub opened_values: Array>>, + pub proof: HintSlice, +} + +impl Hintable for MmcsVerifierInput { + type HintVariable = MmcsVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let commit = MmcsCommitment::read(builder); + let dimensions = Vec::::read(builder); + let index_bits = Vec::::read(builder); + let opened_values = Vec::>::read(builder); + let proof = read_hint_slice(builder); + + MmcsVerifierInputVariable { + commit, + dimensions, + index_bits, + opened_values, + proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + + let idx_bits = (0..self.proof.len()) + .scan(self.index, |acc, _| { + let bit = *acc & 0x01; + *acc >>= 1; + + Some(bit) + }) + .collect::>(); + + stream.extend(self.commit.write()); + stream.extend(self.dimensions.write()); + stream.extend(idx_bits.write()); + stream.extend(self.opened_values.write()); + stream.extend(write_mmcs_proof(&self.proof)); + + stream + } +} + +pub fn mmcs_verify_batch(builder: &mut Builder, input: MmcsVerifierInputVariable) { + let dimensions = match input.dimensions { + Array::Dyn(ptr, len) => Array::Dyn(ptr, len.clone()), + _ => panic!("Expected a dynamic array of felts"), + }; + builder.verify_batch_felt( + &dimensions, + &input.opened_values, + input.proof.id.get_var(), + &input.index_bits, + &input.commit.value, + ); +} + +#[cfg(test)] +pub mod tests { + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::hints::Hintable; + use p3_field::FieldAlgebra; + + use super::{mmcs_verify_batch, MmcsCommitment, MmcsVerifierInput, E, F}; + + /// The witness in this test is produced by: + /// https://github.com/Jiangkm3/Plonky3 branch cyte/mmcs-poseidon2-constants + /// cargo test --package p3-merkle-tree --lib -- mmcs::tests::size_gaps --exact --show-output + #[allow(dead_code)] + pub fn build_mmcs_verify_batch() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let mmcs_input = MmcsVerifierInput::read(&mut builder); + mmcs_verify_batch(&mut builder, mmcs_input); + builder.halt(); + + // Pass in witness stream + let f = |n: usize| F::from_canonical_usize(n); + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + let commit = MmcsCommitment { + value: [ + f(414821839), + f(366064801), + f(76927727), + f(1054874897), + f(522043147), + f(638338172), + f(1583746438), + f(941156703), + ], + }; + let dimensions = vec![7, 0, 0]; + let index = 6; + let opened_values = vec![ + vec![ + f(783379538), + f(1083745632), + f(1297755122), + f(739705382), + f(1249630435), + f(1794480926), + f(706129135), + f(51286871), + ], + vec![ + f(1782820525), + f(487690259), + f(1939320991), + f(1236615939), + f(1149125220), + f(1681169264), + f(418636771), + f(1198975790), + ], + vec![ + f(1782820525), + f(487690259), + f(1939320991), + f(1236615939), + f(1149125220), + f(1681169264), + f(418636771), + f(1198975790), + ], + ]; + let proof = vec![ + [ + f(709175359), + f(862600965), + f(21724453), + f(1644204827), + f(1122851899), + f(902491334), + f(187250228), + f(766400688), + ], + [ + f(1500388444), + f(788589576), + f(699109303), + f(1804289606), + f(295155621), + f(328080503), + f(198482491), + f(1942550078), + ], + [ + f(132120813), + f(362247724), + f(635527855), + f(709381234), + f(1331884835), + f(1016275827), + f(962247980), + f(1772849136), + ], + [ + f(1707124288), + f(1917010688), + f(261076785), + f(346295418), + f(1637246858), + f(1607442625), + f(777235843), + f(194556598), + ], + [ + f(1410853257), + f(1598063795), + f(1111574219), + f(1465562989), + f(1102456901), + f(1433687377), + f(1376477958), + f(1087266135), + ], + [ + f(278709284), + f(1823086849), + f(1969802325), + f(633552560), + f(1780238760), + f(297873878), + f(421105965), + f(1357131680), + ], + [ + f(883611536), + f(685305811), + f(56966874), + f(170904280), + f(1353579462), + f(1357636937), + f(1565241058), + f(209109553), + ], + ]; + let mmcs_input = MmcsVerifierInput { + commit, + dimensions, + index, + opened_values, + proof, + }; + witness_stream.extend(mmcs_input.write()); + + // PROGRAM + let program: Program = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_mmcs_verify_batch() { + let (program, witness) = build_mmcs_verify_batch(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} diff --git a/src/basefold_verifier/mod.rs b/src/basefold_verifier/mod.rs new file mode 100644 index 0000000..fac53c8 --- /dev/null +++ b/src/basefold_verifier/mod.rs @@ -0,0 +1,10 @@ +pub(crate) mod basefold; +pub(crate) mod extension_mmcs; +pub(crate) mod hash; +pub(crate) mod mmcs; +pub(crate) mod query_phase; +pub(crate) mod rs; +pub(crate) mod structs; +// pub(crate) mod field; +pub(crate) mod utils; +pub(crate) mod verifier; diff --git a/src/basefold_verifier/query_phase.rs b/src/basefold_verifier/query_phase.rs new file mode 100644 index 0000000..2497666 --- /dev/null +++ b/src/basefold_verifier/query_phase.rs @@ -0,0 +1,862 @@ +// Note: check all XXX comments! + +use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; +use mpcs::basefold::QueryOpeningProof as InnerQueryOpeningProof; +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_compiler_derive::iter_zip; +use openvm_native_recursion::{ + hints::{Hintable, VecAutoHintable}, + vars::HintSlice, +}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_commit::ExtensionMmcs; +use p3_field::{Field, FieldAlgebra}; +use serde::Deserialize; + +use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, utils::*}; +use crate::{ + arithmetics::eq_eval_with_index, + tower_verifier::{binding::*, program::interpolate_uni_poly}, +}; + +pub type F = BabyBear; +pub type E = BabyBearExt4; +pub type InnerConfig = AsmConfig; + +use p3_fri::BatchOpening as InnerBatchOpening; +use p3_fri::CommitPhaseProofStep as InnerCommitPhaseProofStep; + +/// We have to define a struct similar to p3_fri::BatchOpening as +/// the trait `Hintable` is defined in another crate inside OpenVM +#[derive(Deserialize)] +pub struct BatchOpening { + pub opened_values: Vec>, + pub opening_proof: MmcsProof, +} + +impl + From< + InnerBatchOpening< + ::BaseField, + <::BaseField as PoseidonField>::MMCS, + >, + > for BatchOpening +{ + fn from( + inner: InnerBatchOpening< + ::BaseField, + <::BaseField as PoseidonField>::MMCS, + >, + ) -> Self { + Self { + opened_values: inner.opened_values, + opening_proof: inner.opening_proof.into(), + } + } +} + +#[derive(DslVariable, Clone)] +pub struct BatchOpeningVariable { + pub opened_values: Array>>, + pub opening_proof: HintSlice, +} + +impl Hintable for BatchOpening { + type HintVariable = BatchOpeningVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let opened_values = Vec::>::read(builder); + let opening_proof = read_hint_slice(builder); + + BatchOpeningVariable { + opened_values, + opening_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.opened_values.write()); + stream.extend(write_mmcs_proof(&self.opening_proof)); + + stream + } +} + +impl VecAutoHintable for BatchOpening {} + +/// TODO: use `openvm_native_recursion::fri::types::FriCommitPhaseProofStepVariable` instead +#[derive(Deserialize)] +pub struct CommitPhaseProofStep { + pub sibling_value: E, + pub opening_proof: MmcsProof, +} + +pub type ExtMmcs = ExtensionMmcs< + ::BaseField, + E, + <::BaseField as PoseidonField>::MMCS, +>; +impl From>> for CommitPhaseProofStep { + fn from(inner: InnerCommitPhaseProofStep>) -> Self { + Self { + sibling_value: inner.sibling_value, + opening_proof: inner.opening_proof.into(), + } + } +} + +#[derive(DslVariable, Clone)] +pub struct CommitPhaseProofStepVariable { + pub sibling_value: Ext, + pub opening_proof: HintSlice, +} + +impl VecAutoHintable for CommitPhaseProofStep {} + +impl Hintable for CommitPhaseProofStep { + type HintVariable = CommitPhaseProofStepVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let sibling_value = E::read(builder); + let opening_proof = read_hint_slice(builder); + + CommitPhaseProofStepVariable { + sibling_value, + opening_proof, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.sibling_value.write()); + stream.extend(vec![ + vec![F::from_canonical_usize(self.opening_proof.len())], + self.opening_proof + .iter() + .flatten() + .copied() + .collect::>(), + ]); + stream + } +} + +#[derive(Deserialize)] +pub struct QueryOpeningProof { + pub input_proofs: Vec, + pub commit_phase_openings: Vec, +} + +impl From> for QueryOpeningProof { + fn from(proof: InnerQueryOpeningProof) -> Self { + Self { + input_proofs: proof + .input_proofs + .into_iter() + .map(|proof| proof.into()) + .collect(), + commit_phase_openings: proof + .commit_phase_openings + .into_iter() + .map(|proof| proof.into()) + .collect(), + } + } +} + +#[derive(DslVariable, Clone)] +pub struct QueryOpeningProofVariable { + pub input_proofs: Array>, + pub commit_phase_openings: Array>, +} + +pub(crate) type QueryOpeningProofs = Vec; +pub(crate) type QueryOpeningProofsVariable = Array>; + +impl VecAutoHintable for QueryOpeningProof {} + +impl Hintable for QueryOpeningProof { + type HintVariable = QueryOpeningProofVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let input_proofs = Vec::::read(builder); + let commit_phase_openings = Vec::::read(builder); + QueryOpeningProofVariable { + input_proofs, + commit_phase_openings, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.input_proofs.write()); + stream.extend(self.commit_phase_openings.write()); + stream + } +} + +#[derive(Deserialize)] +// NOTE: Different from PointAndEval in tower_verifier! +pub struct PointAndEvals { + pub point: Point, + pub evals: Vec, +} +impl Hintable for PointAndEvals { + type HintVariable = PointAndEvalsVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let point = Point::read(builder); + let evals = Vec::::read(builder); + PointAndEvalsVariable { point, evals } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.point.write()); + stream.extend(self.evals.write()); + stream + } +} +impl VecAutoHintable for PointAndEvals {} + +#[derive(DslVariable, Clone)] +pub struct PointAndEvalsVariable { + pub point: PointVariable, + pub evals: Array>, +} + +#[derive(Deserialize)] +pub struct QueryPhaseVerifierInput { + // pub t_inv_halves: Vec::BaseField>>, + pub max_num_var: usize, + pub batch_coeffs: Vec, + pub fold_challenges: Vec, + pub indices: Vec, + pub proof: BasefoldProof, + pub rounds: Vec, +} + +impl Hintable for QueryPhaseVerifierInput { + type HintVariable = QueryPhaseVerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + // let t_inv_halves = Vec::>::read(builder); + let max_num_var = Usize::Var(usize::read(builder)); + let batch_coeffs = Vec::::read(builder); + let fold_challenges = Vec::::read(builder); + let indices = Vec::::read(builder); + let proof = BasefoldProof::read(builder); + let rounds = Vec::::read(builder); + + QueryPhaseVerifierInputVariable { + // t_inv_halves, + max_num_var, + batch_coeffs, + fold_challenges, + indices, + proof, + rounds, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + // stream.extend(self.t_inv_halves.write()); + stream.extend(>::write(&self.max_num_var)); + stream.extend(self.batch_coeffs.write()); + stream.extend(self.fold_challenges.write()); + stream.extend(self.indices.write()); + stream.extend(self.proof.write()); + stream.extend(self.rounds.write()); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct QueryPhaseVerifierInputVariable { + // pub t_inv_halves: Array>>, + pub max_num_var: Usize, + pub batch_coeffs: Array>, + pub fold_challenges: Array>, + pub indices: Array>, + pub proof: BasefoldProofVariable, + pub rounds: Array>, +} + +#[derive(DslVariable, Clone)] +pub struct PackedCodeword { + pub low: Ext, + pub high: Ext, +} + +pub(crate) fn batch_verifier_query_phase( + builder: &mut Builder, + input: QueryPhaseVerifierInputVariable, +) { + let inv_2 = builder.constant(C::F::from_canonical_u32(0x3c000001)); + let two_adic_generators_inverses: Array> = builder.dyn_array(28); + for (index, val) in [ + 0x1usize, 0x78000000, 0x67055c21, 0x5ee99486, 0xbb4c4e4, 0x2d4cc4da, 0x669d6090, + 0x17b56c64, 0x67456167, 0x688442f9, 0x145e952d, 0x4fe61226, 0x4c734715, 0x11c33e2a, + 0x62c3d2b1, 0x77cad399, 0x54c131f4, 0x4cabd6a6, 0x5cf5713f, 0x3e9430e8, 0xba067a3, + 0x18adc27d, 0x21fd55bc, 0x4b859b3d, 0x3bd57996, 0x4483d85a, 0x3a26eef8, 0x1a427a41, + ] + .iter() + .enumerate() + { + let generator = builder.constant(C::F::from_canonical_usize(*val).inverse()); + builder.set_value(&two_adic_generators_inverses, index, generator); + } + + // encode_small + let final_message = &input.proof.final_message; + let final_rmm_values_len = builder.get(final_message, 0).len(); + let final_rmm_values = builder.dyn_array(final_rmm_values_len.clone()); + + builder + .range(0, final_rmm_values_len.clone()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let row_len = final_message.len(); + let sum = builder.constant(C::EF::ZERO); + builder.range(0, row_len).for_each(|j_vec, builder| { + let j = j_vec[0]; + let row = builder.get(final_message, j); + let row_j = builder.get(&row, i); + builder.assign(&sum, sum + row_j); + }); + builder.set_value(&final_rmm_values, i, sum); + }); + + let final_rmm = RowMajorMatrixVariable { + values: final_rmm_values, + width: builder.eval(Usize::from(1)), + }; + let final_codeword = encode_small(builder, final_rmm); + + let log2_max_codeword_size: Var = + builder.eval(input.max_num_var.clone() + Usize::from(get_rate_log())); + + let zero: Ext = builder.constant(C::EF::ZERO); + + iter_zip!(builder, input.indices, input.proof.query_opening_proof).for_each( + |ptr_vec, builder| { + // TODO: change type of input.indices to be `Array>>` + let idx = builder.iter_ptr_get(&input.indices, ptr_vec[0]); + let idx = builder.unsafe_cast_var_to_felt(idx); + let idx_bits = builder.num2bits_f(idx, C::N::bits() as u32); + // assert idx_bits[log2_max_codeword_size..] == 0 + builder + .range(log2_max_codeword_size, idx_bits.len()) + .for_each(|i_vec, builder| { + let bit = builder.get(&idx_bits, i_vec[0]); + builder.assert_eq::>(bit, Usize::from(0)); + }); + let idx_bits = idx_bits.slice(builder, 1, log2_max_codeword_size.clone()); + + let reduced_codeword_by_height: Array> = + builder.dyn_array(log2_max_codeword_size.clone()); + // initialize reduced_codeword_by_height with zeroes + iter_zip!(builder, reduced_codeword_by_height).for_each(|ptr_vec, builder| { + let zero_codeword = PackedCodeword { + low: zero.clone(), + high: zero.clone(), + }; + builder.set_value(&reduced_codeword_by_height, ptr_vec[0], zero_codeword); + }); + let query = builder.iter_ptr_get(&input.proof.query_opening_proof, ptr_vec[1]); + let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); + + builder.assert_usize_eq(query.input_proofs.len(), input.rounds.len()); + iter_zip!(builder, query.input_proofs, input.rounds).for_each(|ptr_vec, builder| { + let batch_opening = builder.iter_ptr_get(&query.input_proofs, ptr_vec[0]); + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[1]); + let opened_values = batch_opening.opened_values; + let perm_opened_values = builder.dyn_array(opened_values.len()); + let dimensions = builder.dyn_array(opened_values.len()); + let opening_proof = batch_opening.opening_proof; + + // reorder (opened values, dimension) according to the permutation + builder + .range(0, opened_values.len()) + .for_each(|j_vec, builder| { + let j = j_vec[0]; + let mat_j = builder.get(&opened_values, j); + let num_var_j = builder.get(&round.openings, j).num_var; + let height_j = builder.eval(num_var_j + Usize::from(get_rate_log() - 1)); + + let permuted_j = builder.get(&round.perm, j); + + builder.set_value(&perm_opened_values, permuted_j, mat_j); + builder.set_value(&dimensions, permuted_j, height_j); + }); + + // i >>= (log2_max_codeword_size - commit.log2_max_codeword_size); + let bits_shift: Var = builder + .eval(log2_max_codeword_size.clone() - round.commit.log2_max_codeword_size); + let reduced_idx_bits = idx_bits.slice(builder, bits_shift, idx_bits.len()); + + // verify input mmcs + let mmcs_verifier_input = MmcsVerifierInputVariable { + commit: round.commit.commit.clone(), + dimensions: dimensions, + index_bits: reduced_idx_bits, + opened_values: perm_opened_values, + proof: opening_proof, + }; + + mmcs_verify_batch(builder, mmcs_verifier_input); + + // TODO: optimize this procedure + iter_zip!(builder, opened_values, round.openings).for_each(|ptr_vec, builder| { + let opened_value = builder.iter_ptr_get(&opened_values, ptr_vec[0]); + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[1]); + let log2_height: Var = + builder.eval(opening.num_var + Usize::from(get_rate_log() - 1)); + let width = opening.point_and_evals.evals.len(); + + let batch_coeffs_next_offset: Var = + builder.eval(batch_coeffs_offset + width.clone()); + let coeffs = input.batch_coeffs.slice( + builder, + batch_coeffs_offset.clone(), + batch_coeffs_next_offset.clone(), + ); + let low_values = opened_value.slice(builder, 0, width.clone()); + let high_values = + opened_value.slice(builder, width.clone(), opened_value.len()); + let low: Ext = builder.constant(C::EF::ZERO); + let high: Ext = builder.constant(C::EF::ZERO); + + iter_zip!(builder, coeffs, low_values, high_values).for_each( + |ptr_vec, builder| { + let coeff = builder.iter_ptr_get(&coeffs, ptr_vec[0]); + let low_value = builder.iter_ptr_get(&low_values, ptr_vec[1]); + let high_value = builder.iter_ptr_get(&high_values, ptr_vec[2]); + + builder.assign(&low, low + coeff * low_value); + builder.assign(&high, high + coeff * high_value); + }, + ); + let codeword: PackedCodeword = PackedCodeword { low, high }; + let codeword_acc = builder.get(&reduced_codeword_by_height, log2_height); + + // reduced_openings[log2_height] += codeword + builder.assign(&codeword_acc.low, codeword_acc.low + codeword.low); + builder.assign(&codeword_acc.high, codeword_acc.high + codeword.high); + + builder.set_value(&reduced_codeword_by_height, log2_height, codeword_acc); + builder.assign(&batch_coeffs_offset, batch_coeffs_next_offset); + }); + }); + + let opening_ext = query.commit_phase_openings; + + // fold 1st codeword + let cur_num_var: Var = builder.eval(input.max_num_var.clone()); + let log2_height: Var = + builder.eval(cur_num_var + Usize::from(get_rate_log() - 1)); + + let r = builder.get(&input.fold_challenges, 0); + let codeword = builder.get(&reduced_codeword_by_height, log2_height); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators_inverses, + log2_height, + &idx_bits, + inv_2, + ); + let folded = codeword_fold_with_challenge::( + builder, + codeword.low, + codeword.high, + r, + coeff, + inv_2, + ); + + // check commit phases + let commits = &input.proof.commits; + builder.assert_eq::>( + commits.len() + Usize::from(1), + input.fold_challenges.len(), + ); + builder.assert_eq::>(commits.len(), opening_ext.len()); + builder.range(0, commits.len()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let commit = builder.get(&commits, i); + let commit_phase_step = builder.get(&opening_ext, i); + let i_plus_one = builder.eval_expr(i + Usize::from(1)); + + let sibling_value = commit_phase_step.sibling_value; + let proof = commit_phase_step.opening_proof; + + builder.assign(&cur_num_var, cur_num_var - Usize::from(1)); + builder.assign(&log2_height, log2_height - Usize::from(1)); + + let folded_idx = builder.get(&idx_bits, i); + let new_involved_packed_codeword = + builder.get(&reduced_codeword_by_height, log2_height.clone()); + + builder.if_eq(folded_idx, Usize::from(0)).then_or_else( + |builder| { + builder.assign(&folded, folded + new_involved_packed_codeword.low); + }, + |builder| { + builder.assign(&folded, folded + new_involved_packed_codeword.high); + }, + ); + + // leafs + let leafs = builder.dyn_array(2); + let sibling_idx = builder.eval_expr(RVar::from(1) - folded_idx); + builder.set_value(&leafs, folded_idx, folded); + builder.set_value(&leafs, sibling_idx, sibling_value); + + // idx >>= 1 + let idx_pair = idx_bits.slice(builder, i_plus_one, idx_bits.len()); + + // mmcs_ext.verify_batch + let dimensions = builder.dyn_array(1); + let opened_values = builder.dyn_array(1); + builder.set_value(&opened_values, 0, leafs.clone()); + builder.set_value(&dimensions, 0, log2_height.clone()); + let ext_mmcs_verifier_input = ExtMmcsVerifierInputVariable { + commit: commit.clone(), + dimensions, + index_bits: idx_pair.clone(), + opened_values, + proof, + }; + ext_mmcs_verify_batch::(builder, ext_mmcs_verifier_input); + + let r = builder.get(&input.fold_challenges, i_plus_one); + let coeff = verifier_folding_coeffs_level( + builder, + &two_adic_generators_inverses, + log2_height, + &idx_pair, + inv_2, + ); + let left = builder.get(&leafs, 0); + let right = builder.get(&leafs, 1); + let new_folded = + codeword_fold_with_challenge(builder, left, right, r, coeff, inv_2); + builder.assign(&folded, new_folded); + }); + + // assert that final_value[i] = folded + let final_idx: Var = builder.constant(C::N::ZERO); + builder + .range(commits.len(), idx_bits.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let bit = builder.get(&idx_bits, i); + builder.assign( + &final_idx, + final_idx * SymbolicVar::from(C::N::from_canonical_u16(2)) + bit, + ); + }); + let final_value = builder.get(&final_codeword.values, final_idx); + builder.assert_eq::>(final_value, folded); + }, + ); + + // 1. check initial claim match with first round sumcheck value + let batch_coeffs_offset: Var = builder.constant(C::N::ZERO); + let expected_sum: Ext = builder.constant(C::EF::ZERO); + iter_zip!(builder, input.rounds).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[0]); + iter_zip!(builder, round.openings).for_each(|ptr_vec, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[0]); + // TODO: filter out openings with num_var >= get_basecode_msg_size_log::() + let var_diff: Var = builder.eval(input.max_num_var.get_var() - opening.num_var); + let scalar_var = pow_2(builder, var_diff); + let scalar = builder.unsafe_cast_var_to_felt(scalar_var); + iter_zip!(builder, opening.point_and_evals.evals).for_each(|ptr_vec, builder| { + let eval = builder.iter_ptr_get(&opening.point_and_evals.evals, ptr_vec[0]); + let coeff = builder.get(&input.batch_coeffs, batch_coeffs_offset); + let val: Ext = builder.eval(eval * coeff * scalar); + builder.assign(&expected_sum, expected_sum + val); + builder.assign(&batch_coeffs_offset, batch_coeffs_offset + Usize::from(1)); + }); + }); + }); + let sum: Ext = { + let sumcheck_evals = builder.get(&input.proof.sumcheck_proof, 0).evaluations; + let eval0 = builder.get(&sumcheck_evals, 0); + let eval1 = builder.get(&sumcheck_evals, 1); + builder.eval(eval0 + eval1) + }; + builder.assert_eq::>(expected_sum, sum); + + // 2. check every round of sumcheck match with prev claims + let fold_len_minus_one: Var = builder.eval(input.fold_challenges.len() - Usize::from(1)); + builder + .range(0, fold_len_minus_one) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let evals = builder.get(&input.proof.sumcheck_proof, i).evaluations; + let challenge = builder.get(&input.fold_challenges, i); + let left = interpolate_uni_poly(builder, &evals, challenge); + let i_plus_one = builder.eval_expr(i + Usize::from(1)); + let next_evals = builder + .get(&input.proof.sumcheck_proof, i_plus_one) + .evaluations; + let eval0 = builder.get(&next_evals, 0); + let eval1 = builder.get(&next_evals, 1); + let right: Ext = builder.eval(eval0 + eval1); + builder.assert_eq::>(left, right); + }); + + // 3. check final evaluation are correct + let final_evals = builder + .get(&input.proof.sumcheck_proof, fold_len_minus_one.clone()) + .evaluations; + let final_challenge = builder.get(&input.fold_challenges, fold_len_minus_one.clone()); + let left = interpolate_uni_poly(builder, &final_evals, final_challenge); + let right: Ext = builder.constant(C::EF::ZERO); + let one: Var = builder.constant(C::N::ONE); + let j: Var = builder.constant(C::N::ZERO); + // \sum_i eq(p, [r,i]) * f(r,i) + iter_zip!(builder, input.rounds,).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&input.rounds, ptr_vec[0]); + // TODO: filter out openings with num_var >= get_basecode_msg_size_log::() + iter_zip!(builder, round.openings).for_each(|ptr_vec, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec[0]); + let point_and_evals = &opening.point_and_evals; + let point = &point_and_evals.point; + + let num_vars_evaluated: Var = + builder.eval(point.fs.len() - Usize::from(get_basecode_msg_size_log())); + let final_message = builder.get(&input.proof.final_message, j); + + // coeff is the eq polynomial evaluated at the first challenge.len() variables + let ylo = builder.eval(input.fold_challenges.len() - num_vars_evaluated); + let coeff = eq_eval_with_index( + builder, + &point.fs, + &input.fold_challenges, + Usize::from(0), + Usize::Var(ylo), + Usize::Var(num_vars_evaluated), + ); + + // compute \sum_i eq(p[..num_vars_evaluated], r) * eq(p[num_vars_evaluated..], i) * f(r,i) + // + // We always assume that num_vars_evaluated is equal to p.len() + // so that the above sum only has one item and the final evaluation vector has only one element. + builder.assert_eq::>(final_message.len(), one); + let final_message = builder.get(&final_message, 0); + let dot_prod: Ext = builder.eval(final_message * coeff); + builder.assign(&right, right + dot_prod); + + builder.assign(&j, j + Usize::from(1)); + }); + }); + builder.assert_eq::>(j, input.proof.final_message.len()); + builder.assert_eq::>(left, right); +} + +#[cfg(test)] +pub mod tests { + use ceno_transcript::{BasicTranscript, Transcript}; + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use mpcs::{ + pcs_batch_commit, pcs_trim, util::hash::write_digest_to_transcript, BasefoldDefault, + PolynomialCommitmentScheme, + }; + use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam}; + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::p3_challenger::GrindingChallenger; + use openvm_stark_sdk::p3_baby_bear::BabyBear; + use rand::thread_rng; + + type F = BabyBear; + type E = BabyBearExt4; + type PCS = BasefoldDefault; + + use crate::basefold_verifier::basefold::{Round, RoundOpening}; + use crate::basefold_verifier::query_phase::PointAndEvals; + use crate::tower_verifier::binding::Point; + + use super::{batch_verifier_query_phase, QueryPhaseVerifierInput}; + + pub fn build_batch_verifier_query_phase( + input: QueryPhaseVerifierInput, + ) -> (Program, Vec>) { + // build test program + let mut builder = AsmBuilder::::default(); + let query_phase_input = QueryPhaseVerifierInput::read(&mut builder); + batch_verifier_query_phase(&mut builder, query_phase_input); + builder.halt(); + let program = builder.compile_isa(); + + // prepare input + let mut witness_stream: Vec> = Vec::new(); + witness_stream.extend(input.write()); + + (program, witness_stream) + } + + fn construct_test(dimensions: Vec<(usize, usize)>) { + let mut rng = thread_rng(); + + // setup PCS + let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + + let mut num_total_polys = 0; + let (matrices, mles): (Vec<_>, Vec<_>) = dimensions + .into_iter() + .map(|(num_vars, width)| { + let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); + let mles = m.to_mles(); + num_total_polys += width; + + (m, mles) + }) + .unzip(); + + // commit to matrices + let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); + let comm = PCS::get_pure_commitment(&pcs_data); + + let point_and_evals = mles + .iter() + .map(|mles| { + let point = E::random_vec(mles[0].num_vars(), &mut rng); + let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + + (point, evals) + }) + .collect_vec(); + + // batch open + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![(&pcs_data, point_and_evals.clone())]; + let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); + + // batch verify + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![( + comm, + point_and_evals + .iter() + .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) + .collect_vec(), + )]; + PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) + .expect("Native verification failed"); + + let mut transcript = BasicTranscript::::new(&[]); + let batch_coeffs = + transcript.sample_and_append_challenge_pows(num_total_polys, b"batch coeffs"); + + let max_num_var = point_and_evals + .iter() + .map(|(point, _)| point.len()) + .max() + .unwrap(); + let num_rounds = max_num_var; // The final message is of length 1 + + // prepare folding challenges via sumcheck round msg + FRI commitment + let mut fold_challenges: Vec = Vec::with_capacity(num_rounds); + let commits = &opening_proof.commits; + + let sumcheck_messages = opening_proof.sumcheck_proof.as_ref().unwrap(); + for i in 0..num_rounds { + transcript.append_field_element_exts(sumcheck_messages[i].evaluations.as_slice()); + fold_challenges.push( + transcript + .sample_and_append_challenge(b"commit round") + .elements, + ); + if i < num_rounds - 1 { + write_digest_to_transcript(&commits[i], &mut transcript); + } + } + transcript.append_field_element_exts_iter(opening_proof.final_message.iter().flatten()); + + // check pow + let pow_bits = vp.get_pow_bits_by_level(mpcs::PowStrategy::FriPow); + if pow_bits > 0 { + assert!(transcript.check_witness(pow_bits, opening_proof.pow_witness)); + } + + let queries: Vec<_> = transcript.sample_bits_and_append_vec( + b"query indices", + >::get_number_queries(), + max_num_var + >::get_rate_log(), + ); + + let query_input = QueryPhaseVerifierInput { + max_num_var, + fold_challenges, + batch_coeffs, + indices: queries, + proof: opening_proof.into(), + rounds: rounds + .into_iter() + .map(|round| Round { + commit: round.0.into(), + openings: round + .1 + .into_iter() + .map(|(num_var, (point, evals))| RoundOpening { + num_var, + point_and_evals: PointAndEvals { + point: Point { fs: point }, + evals, + }, + }) + .collect(), + }) + .collect(), + }; + let (program, witness) = build_batch_verifier_query_phase(query_input); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program.clone(), witness.clone()).unwrap(); + + // _debug + let results = executor.execute_segments(program, witness).unwrap(); + for seg in results { + println!("=> cycle count: {:?}", seg.metrics.cycle_count); + } + } + + #[test] + fn test_simple_batch() { + for num_var in 5..20 { + construct_test(vec![(num_var, 20)]); + } + } + + #[test] + fn test_decreasing_batch() { + construct_test(vec![ + (14, 20), + (14, 40), + (13, 30), + (12, 30), + (11, 10), + (10, 15), + ]); + } + + #[test] + fn test_random_batch() { + construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]); + } +} diff --git a/src/basefold_verifier/rs.rs b/src/basefold_verifier/rs.rs new file mode 100644 index 0000000..b2730c7 --- /dev/null +++ b/src/basefold_verifier/rs.rs @@ -0,0 +1,303 @@ +// Note: check all XXX comments! + +use std::{cell::RefCell, collections::BTreeMap}; + +use openvm_native_compiler::{asm::AsmConfig, prelude::*}; +use openvm_native_recursion::hints::Hintable; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; +use p3_field::FieldAlgebra; +use serde::Deserialize; + +use super::structs::*; +use super::utils::{pow_felt, pow_felt_bits}; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +pub struct DenseMatrix { + pub values: Vec, + pub width: usize, +} + +impl Hintable for DenseMatrix { + type HintVariable = DenseMatrixVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let values = Vec::::read(builder); + let width = usize::read(builder); + + DenseMatrixVariable { values, width } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.values.write()); + stream.extend(>::write(&self.width)); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct DenseMatrixVariable { + pub values: Array>, + pub width: Var, +} +pub type RowMajorMatrixVariable = DenseMatrixVariable; + +impl DenseMatrixVariable { + pub fn height(&self, builder: &mut Builder) -> Var { + // Supply height as hint + let height = builder.hint_var(); + builder + .if_eq(self.width.clone(), Usize::from(0)) + .then(|builder| { + builder.assert_usize_eq(height, Usize::from(0)); + }); + builder + .if_ne(self.width.clone(), Usize::from(0)) + .then(|builder| { + // XXX: check that width * height is not a field multiplication + builder.assert_usize_eq(self.width.clone() * height, self.values.len()); + }); + height + } + + // XXX: Find better ways to handle this without cloning + pub fn pad_to_height( + &self, + builder: &mut Builder, + new_height: Var, + fill: Ext, + ) { + // XXX: Not necessary, only for testing purpose + let old_height = self.height(builder); + builder.assert_less_than_slow_small_rhs(old_height, new_height + RVar::from(1)); + let new_size = builder.eval_expr(self.width.clone() * new_height.clone()); + let evals: Array> = builder.dyn_array(new_size); + builder + .range(0, self.values.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp: Ext = builder.get(&self.values, i); + builder.set(&evals, i, tmp); + }); + builder + .range(self.values.len(), evals.len()) + .for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&evals, i, fill); + }); + builder.assign(&self.values, evals); + } +} + +pub fn get_rate_log() -> usize { + 1 +} + +pub fn get_basecode_msg_size_log() -> usize { + 0 +} + +pub fn get_num_queries() -> usize { + 100 +} + +pub fn verifier_folding_coeffs_level( + builder: &mut Builder, + two_adic_generators_inverses: &Array>, + level: Var, + index_bits: &Array>, + two_inv: Felt, +) -> Felt { + let level_plus_one = builder.eval::, _>(level + C::N::ONE); + let g_inv = builder.get(two_adic_generators_inverses, level_plus_one); + + let g_inv_index = pow_felt_bits(builder, g_inv, index_bits, level.into()); + + builder.eval(g_inv_index * two_inv) +} + +/// The DIT FFT algorithm. +#[derive(Deserialize)] +pub struct Radix2Dit { + pub twiddles: RefCell>>, +} + +impl Hintable for Radix2Dit { + type HintVariable = Radix2DitVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let twiddles = Vec::::read(builder); + + Radix2DitVariable { twiddles } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + // XXX: process BTreeMap + let twiddles_vec: Vec = Vec::new(); + stream.extend(twiddles_vec.write()); + stream + } +} + +#[derive(DslVariable, Clone)] +pub struct Radix2DitVariable { + /// Memoized twiddle factors for each length log_n. + /// Precise definition is a map from usize to E + pub twiddles: Array>, +} + +/* +impl Radix2DitVariable { + fn dft_batch( + &self, + builder: &mut Builder, + mat: RowMajorMatrixVariable + ) -> RowMajorMatrixVariable { + let h = mat.height(builder); + let log_h = builder.hint_var(); + let log_h_minus_1: Var = builder.eval(log_h - Usize::from(1)); + let purported_h_lower_bound = pow_2(builder, log_h_minus_1); + let purported_h_upper_bound = pow_2(builder, log_h); + builder.assert_less_than_slow_small_rhs(purported_h_lower_bound, h); + builder.assert_less_than_slow_small_rhs(h, purported_h_upper_bound); + + // TODO: support memoization + // Compute twiddle factors, or take memoized ones if already available. + let twiddles = { + let root = F::two_adic_generator(log_h); + root.powers().take(1 << log_h).collect() + }; + + // DIT butterfly + reverse_matrix_index_bits(&mut mat); + for layer in 0..log_h { + dit_layer(&mut mat.as_view_mut(), layer, twiddles); + } + mat + } +} +*/ + +#[derive(Deserialize)] +pub struct RSCodeVerifierParameters { + pub full_message_size_log: usize, +} + +#[derive(DslVariable, Clone)] +pub struct RSCodeVerifierParametersVariable { + pub full_message_size_log: Usize, +} + +/* +pub(crate) fn encode_small( + builder: &mut Builder, + vp: RSCodeVerifierParametersVariable, + rmm: RowMajorMatrixVariable, +) -> RowMajorMatrixVariable { + let m = rmm; + // Add current setup this is unnecessary + let old_height = m.height(builder); + let new_height = builder.eval_expr( + old_height * Usize::from(1 << get_rate_log()) + ); + m.pad_to_height(builder, new_height, Ext::new(0)); + m +} +*/ + +/// Encode the last message sent from the prover to the verifier +/// in the commit phase. Currently, for simplicity, we drop the +/// early stopping strategy so the last message has just one +/// element, and the encoding is simply repeating this element +/// by the expansion rate. +pub(crate) fn encode_small( + builder: &mut Builder, + rmm: RowMajorMatrixVariable, // Assumed to have only one row and one column +) -> RowMajorMatrixVariable { + // XXX: nondeterministically supply the results for now + let result = builder.array(2); // Assume the expansion rate is fixed to 2 by now + let value = builder.get(&rmm.values, 0); + builder.range(0, 2).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set_value(&result, i, value); + }); + DenseMatrixVariable { + values: result, + width: builder.eval(Usize::from(1)), + } +} + +pub mod tests { + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_compiler::prelude::*; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, + }; + use p3_field::{extension::BinomialExtensionField, FieldAlgebra}; + type SC = BabyBearPoseidon2Config; + + type F = BabyBear; + type E = BinomialExtensionField; + type EF = ::Challenge; + use super::{DenseMatrix, InnerConfig}; + + #[allow(dead_code)] + pub fn build_test_dense_matrix_pad() -> (Program, Vec>) { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + // Witness inputs + let dense_matrix_variable = DenseMatrix::read(&mut builder); + let new_height = builder.eval(Usize::from(8)); + let fill = Ext::new(0); + dense_matrix_variable.pad_to_height(&mut builder, new_height, fill); + builder.halt(); + + // Pass in witness stream + let mut witness_stream: Vec< + Vec>, + > = Vec::new(); + + let verifier_input = DenseMatrix { + values: vec![E::ONE; 25], + width: 5, + }; + witness_stream.extend(verifier_input.write()); + // Hint for height + witness_stream.extend(>::write(&5)); + + let program: Program< + p3_monty_31::MontyField31, + > = builder.compile_isa(); + + (program, witness_stream) + } + + #[test] + fn test_dense_matrix_pad() { + let (program, witness) = build_test_dense_matrix_pad(); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program, witness).unwrap(); + + // _debug + // let results = executor.execute_segments(program, witness).unwrap(); + // for seg in results { + // println!("=> cycle count: {:?}", seg.metrics.cycle_count); + // } + } +} diff --git a/src/basefold_verifier/structs.rs b/src/basefold_verifier/structs.rs new file mode 100644 index 0000000..1f4ffa8 --- /dev/null +++ b/src/basefold_verifier/structs.rs @@ -0,0 +1,42 @@ +use openvm_native_compiler::{asm::AsmConfig, ir::*}; +use openvm_native_compiler_derive::DslVariable; +use openvm_native_recursion::hints::{Hintable, VecAutoHintable}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::extension::BinomialExtensionField; + +pub const DEGREE: usize = 4; + +pub type F = BabyBear; +pub type E = BinomialExtensionField; +pub type InnerConfig = AsmConfig; + +#[derive(DslVariable, Clone)] +pub struct DimensionsVariable { + pub width: Var, + pub height: Var, +} + +pub struct Dimensions { + pub width: usize, + pub height: usize, +} + +impl VecAutoHintable for Dimensions {} + +impl Hintable for Dimensions { + type HintVariable = DimensionsVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let width = usize::read(builder); + let height = usize::read(builder); + + DimensionsVariable { width, height } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.width)); + stream.extend(>::write(&self.height)); + stream + } +} diff --git a/src/basefold_verifier/utils.rs b/src/basefold_verifier/utils.rs new file mode 100644 index 0000000..86f877a --- /dev/null +++ b/src/basefold_verifier/utils.rs @@ -0,0 +1,315 @@ +use openvm_native_compiler::ir::*; +use openvm_native_recursion::vars::HintSlice; +use p3_baby_bear::BabyBear; +use p3_field::FieldAlgebra; + +use crate::basefold_verifier::mmcs::MmcsProof; + +// XXX: more efficient pow implementation +pub fn pow(builder: &mut Builder, base: Var, exponent: Var) -> Var { + let value: Var = builder.constant(C::N::ONE); + builder.range(0, exponent).for_each(|_, builder| { + builder.assign(&value, value * base); + }); + value +} + +// XXX: more efficient pow implementation +pub fn pow_felt( + builder: &mut Builder, + base: Felt, + exponent: Var, +) -> Felt { + let value: Felt = builder.constant(C::F::ONE); + builder.range(0, exponent).for_each(|_, builder| { + builder.assign(&value, value * base); + }); + value +} + +// XXX: more efficient pow implementation +pub fn pow_felt_bits( + builder: &mut Builder, + base: Felt, + exponent_bits: &Array>, // FIXME: Should be big endian? There is a bit_reverse_rows() in Ceno native code + exponent_len: Usize, +) -> Felt { + let value: Felt = builder.constant(C::F::ONE); + + // Little endian + // let repeated_squared: Felt = base; + // builder.range(0, exponent_len).for_each(|ptr, builder| { + // let ptr = ptr[0]; + // let bit = builder.get(exponent_bits, ptr); + // builder.if_eq(bit, C::N::ONE).then(|builder| { + // builder.assign(&value, value * repeated_squared); + // }); + // builder.assign(&repeated_squared, repeated_squared * repeated_squared); + // }); + + // Big endian + builder.range(0, exponent_len).for_each(|ptr, builder| { + let ptr = ptr[0]; + builder.assign(&value, value * value); + let bit = builder.get(exponent_bits, ptr); + builder.if_eq(bit, C::N::ONE).then(|builder| { + builder.assign(&value, value * base); + }); + }); + value +} + +pub fn pow_2(builder: &mut Builder, exponent: Var) -> Var { + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + pow(builder, two, exponent) +} + +// XXX: Equally outrageously inefficient +pub fn next_power_of_two(builder: &mut Builder, value: Var) -> Var { + // Non-deterministically supply the exponent n such that + // 2^n < v <= 2^{n+1} + // Ignore if v == 1 + let n: Var = builder.hint_var(); + let ret = pow_2(builder, n); + builder.if_eq(value, Usize::from(1)).then(|builder| { + builder.assign(&ret, Usize::from(1)); + }); + builder.if_ne(value, Usize::from(1)).then(|builder| { + builder.assert_less_than_slow_bit_decomp(ret, value); + let two: Var = builder.constant(C::N::from_canonical_usize(2)); + builder.assign(&ret, ret * two); + let ret_plus_one = builder.eval(ret.clone() + Usize::from(1)); + builder.assert_less_than_slow_bit_decomp(value, ret_plus_one); + }); + ret +} + +// Dot product: li * ri +pub fn dot_product( + builder: &mut Builder, + li: &Array>, + ri: &Array, +) -> Ext +where + F: openvm_native_compiler::ir::MemVariable + 'static, +{ + dot_product_with_index::(builder, li, ri, Usize::from(0), Usize::from(0), li.len()) +} + +// Generic dot product of li[llo..llo+len] * ri[rlo..rlo+len] +pub fn dot_product_with_index( + builder: &mut Builder, + li: &Array>, + ri: &Array, + llo: Usize, + rlo: Usize, + len: Usize, +) -> Ext +where + F: openvm_native_compiler::ir::MemVariable + 'static, +{ + let ret: Ext = builder.constant(C::EF::ZERO); + + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let lidx: Var = builder.eval(llo.clone() + i); + let ridx: Var = builder.eval(rlo.clone() + i); + let l = builder.get(li, lidx); + let r = builder.get(ri, ridx); + builder.assign(&ret, ret + l * r); + }); + ret +} + +// Convert the first len entries of binary to decimal +// BIN is in big endian +pub fn bin_to_dec( + builder: &mut Builder, + bin: &Array>, + len: Var, +) -> Var { + let value: Var = builder.constant(C::N::ZERO); + let two: Var = builder.constant(C::N::TWO); + builder.range(0, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.assign(&value, value * two); + let next_bit = builder.get(bin, i); + builder.assign(&value, value + next_bit); + }); + value +} + +// Convert start to end entries of binary to decimal in little endian +pub fn bin_to_dec_le( + builder: &mut Builder, + bin: &Array>, + start: Var, + end: Var, +) -> Var { + let value: Var = builder.constant(C::N::ZERO); + let two: Var = builder.constant(C::N::TWO); + let power_of_two: Var = builder.constant(C::N::ONE); + builder.range(start, end).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_bit = builder.get(bin, i); + builder.assign(&value, value + power_of_two * next_bit); + builder.assign(&power_of_two, power_of_two * two); + }); + value +} + +// Sort a list in decreasing order, returns: +// 1. The original index of each sorted entry +// 2. Number of unique entries +// 3. Number of counts of each unique entry +pub fn sort_with_count( + builder: &mut Builder, + list: &Array, + ind: Ind, // Convert loaded out entries into comparable ones +) -> ( + Array>, + Var, + Array>, + Array>, +) +where + E: openvm_native_compiler::ir::MemVariable, + N: Into::N>> + + openvm_native_compiler::ir::Variable, + Ind: Fn(E) -> N, +{ + let len = list.len(); + // Nondeterministically supplies: + // 1. num_unique_entries: number of different entries + // 2. entry_order: after sorting by decreasing order, the original index of each entry + // To ensure that entry_order represents sorted index, assert that + // 1. It has the same length as list (checked by requesting list.len() hints) + // 2. It does not contain the same index twice (checked via a correspondence array) + // 3. Sorted entries are in decreasing order + // While checking, record: + // 1. count_per_unique_entry: for each unique entry value, count of entries of that value + let num_unique_entries = builder.hint_var(); + let count_per_unique_entry = builder.dyn_array(num_unique_entries); + let sorted_unique_num_vars = builder.dyn_array(num_unique_entries); + let zero: Ext = builder.constant(C::EF::ZERO); + let one: Ext = builder.constant(C::EF::ONE); + let entries_sort_surjective: Array> = builder.dyn_array(len.clone()); + builder.range(0, len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + builder.set(&entries_sort_surjective, i, zero.clone()); + }); + + let entries_order = builder.dyn_array(len.clone()); + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&entries_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&entries_sort_surjective, next_order, one.clone()); + builder.set_value(&entries_order, 0, next_order); + let last_entry = ind(builder.get(&list, next_order)); + + let last_unique_entry_index: Var = builder.eval(Usize::from(0)); + let last_count_per_unique_entry: Var = builder.eval(Usize::from(1)); + builder.range(1, len).for_each(|i_vec, builder| { + let i = i_vec[0]; + let next_order = builder.hint_var(); + // Check surjection + let surjective = builder.get(&entries_sort_surjective, next_order); + builder.assert_ext_eq(surjective, zero.clone()); + builder.set(&entries_sort_surjective, next_order, one.clone()); + // Check entries + let next_entry = ind(builder.get(&list, next_order)); + builder + .if_eq(last_entry.clone(), next_entry.clone()) + .then(|builder| { + // next_entry == last_entry + builder.assign( + &last_count_per_unique_entry, + last_count_per_unique_entry + Usize::from(1), + ); + }); + builder + .if_ne(last_entry.clone(), next_entry.clone()) + .then(|builder| { + // next_entry < last_entry + builder.assert_less_than_slow_small_rhs(next_entry.clone(), last_entry.clone()); + + // Update count_per_unique_entry + builder.set( + &count_per_unique_entry, + last_unique_entry_index, + last_count_per_unique_entry, + ); + builder.set( + &sorted_unique_num_vars, + last_unique_entry_index, + last_entry.clone(), + ); + builder.assign(&last_entry, next_entry.clone()); + builder.assign( + &last_unique_entry_index, + last_unique_entry_index + Usize::from(1), + ); + builder.assign(&last_count_per_unique_entry, Usize::from(1)); + }); + + builder.set_value(&entries_order, i, next_order); + }); + + // Final check on num_unique_entries and count_per_unique_entry + builder.set( + &count_per_unique_entry, + last_unique_entry_index, + last_count_per_unique_entry, + ); + builder.set( + &sorted_unique_num_vars, + last_unique_entry_index, + last_entry.clone(), + ); + builder.assign( + &last_unique_entry_index, + last_unique_entry_index + Usize::from(1), + ); + builder.assert_var_eq(last_unique_entry_index, num_unique_entries); + + ( + entries_order, + num_unique_entries, + count_per_unique_entry, + sorted_unique_num_vars, + ) +} + +pub fn codeword_fold_with_challenge( + builder: &mut Builder, + left: Ext, + right: Ext, + challenge: Ext, + coeff: Felt, + inv_2: Felt, +) -> Ext { + // original (left, right) = (lo + hi*x, lo - hi*x), lo, hi are codeword, but after times x it's not codeword + // recover left & right codeword via (lo, hi) = ((left + right) / 2, (left - right) / 2x) + let lo: Ext = builder.eval((left + right) * inv_2); + let hi: Ext = builder.eval((left - right) * coeff); + // e.g. coeff = (2 * dit_butterfly)^(-1) in rs code + // we do fold on (lo, hi) to get folded = (1-r) * lo + r * hi + // (with lo, hi are two codewords), as it match perfectly with raw message in + // lagrange domain fixed variable + let ret: Ext = builder.eval(lo + challenge * (hi - lo)); + ret +} + +pub(crate) fn read_hint_slice(builder: &mut Builder) -> HintSlice { + let length = Usize::from(builder.hint_var()); + let id = Usize::from(builder.hint_load()); + HintSlice { length, id } +} + +pub(crate) fn write_mmcs_proof(proof: &MmcsProof) -> Vec> { + vec![ + vec![BabyBear::from_canonical_usize(proof.len())], + proof.iter().flatten().copied().collect::>(), + ] +} diff --git a/src/basefold_verifier/verifier.rs b/src/basefold_verifier/verifier.rs new file mode 100644 index 0000000..1ea6c12 --- /dev/null +++ b/src/basefold_verifier/verifier.rs @@ -0,0 +1,373 @@ +use crate::{ + basefold_verifier::query_phase::{batch_verifier_query_phase, QueryPhaseVerifierInputVariable}, + transcript::{transcript_check_pow_witness, transcript_observe_label}, +}; + +use super::{basefold::*, extension_mmcs::*, mmcs::*, rs::*, structs::*, utils::*}; +use ff_ext::{BabyBearExt4, ExtensionField, PoseidonField}; +use openvm_native_compiler::{asm::AsmConfig, ir::FromConstant, prelude::*}; +use openvm_native_compiler_derive::iter_zip; +use openvm_native_recursion::{ + challenger::{ + duplex::DuplexChallengerVariable, CanObserveDigest, CanObserveVariable, + CanSampleBitsVariable, CanSampleVariable, FeltChallenger, + }, + hints::{Hintable, VecAutoHintable}, + vars::HintSlice, +}; +use openvm_stark_sdk::p3_baby_bear::BabyBear; +use p3_field::FieldAlgebra; + +pub type F = BabyBear; +pub type E = BabyBearExt4; +pub type InnerConfig = AsmConfig; + +pub fn batch_verify( + builder: &mut Builder, + max_num_var: Var, + rounds: Array>, + proof: BasefoldProofVariable, + challenger: &mut DuplexChallengerVariable, +) { + builder.assert_nonzero(&proof.final_message.len()); + builder.assert_nonzero(&proof.sumcheck_proof.len()); + + // we don't support early stopping for now + iter_zip!(builder, proof.final_message).for_each(|ptr_vec, builder| { + let final_message_len = builder.iter_ptr_get(&proof.final_message, ptr_vec[0]).len(); + builder.assert_eq::>( + final_message_len, + Usize::from(1 << get_basecode_msg_size_log()), + ); + }); + + builder.assert_eq::>( + proof.query_opening_proof.len(), + Usize::from(get_num_queries()), + ); + + // Compute the total number of polynomials across all rounds + let total_num_polys: Var = builder.constant(C::N::ZERO); + iter_zip!(builder, rounds).for_each(|ptr_vec, builder| { + let openings = builder.iter_ptr_get(&rounds, ptr_vec[0]).openings; + iter_zip!(builder, openings).for_each(|ptr_vec_openings, builder| { + let evals_num = builder + .iter_ptr_get(&openings, ptr_vec_openings[0]) + .point_and_evals + .evals + .len(); + builder.assign(&total_num_polys, total_num_polys + evals_num); + }); + }); + + // get batch coeffs + transcript_observe_label(builder, challenger, b"batch coeffs"); + let batch_coeff = challenger.sample_ext(builder); + let running_coeff = + as FromConstant>::constant(C::EF::from_canonical_usize(1), builder); + let batch_coeffs: Array> = builder.dyn_array(total_num_polys); + + iter_zip!(builder, batch_coeffs).for_each(|ptr_vec_batch_coeffs, builder| { + builder.iter_ptr_set(&batch_coeffs, ptr_vec_batch_coeffs[0], running_coeff); + builder.assign(&running_coeff, running_coeff * batch_coeff); + }); + + // The max num var is provided by the prover and not guaranteed to be correct. + // Check that + // 1. it is greater than or equal to every num var; + // 2. it is equal to at least one of the num vars by multiplying all the differences + // together and assert the product is zero. + let diff_product: Var = builder.eval(Usize::from(1)); + iter_zip!(builder, rounds).for_each(|ptr_vec, builder| { + let round = builder.iter_ptr_get(&rounds, ptr_vec[0]); + + iter_zip!(builder, round.openings).for_each(|ptr_vec_opening, builder| { + let opening = builder.iter_ptr_get(&round.openings, ptr_vec_opening[0]); + let diff: Var = builder.eval(max_num_var.clone() - opening.num_var); + // num_var is always smaller than 32. + builder.range_check_var(diff, 5); + builder.assign(&diff_product, diff_product * diff); + }); + }); + // Check that at least one num_var is equal to max_num_var + let zero: Var = builder.eval(C::N::ZERO); + builder.assert_eq::>(diff_product, zero); + + let num_rounds: Var = + builder.eval(max_num_var - Usize::from(get_basecode_msg_size_log())); + + let fold_challenges: Array> = builder.dyn_array(max_num_var); + builder.range(0, num_rounds).for_each(|index_vec, builder| { + let sumcheck_message = builder.get(&proof.sumcheck_proof, index_vec[0]).evaluations; + iter_zip!(builder, sumcheck_message).for_each(|ptr_vec_sumcheck_message, builder| { + let elem = builder.iter_ptr_get(&sumcheck_message, ptr_vec_sumcheck_message[0]); + let elem_felts = builder.ext2felt(elem); + challenger.observe_slice(builder, elem_felts); + }); + + transcript_observe_label(builder, challenger, b"commit round"); + let challenge = challenger.sample_ext(builder); + builder.set(&fold_challenges, index_vec[0], challenge); + builder + .if_ne(index_vec[0], num_rounds - Usize::from(1)) + .then(|builder| { + let commit = builder.get(&proof.commits, index_vec[0]); + challenger.observe_digest(builder, commit.value.into()); + }); + }); + + iter_zip!(builder, proof.final_message).for_each(|ptr_vec_sumcheck_message, builder| { + // Each final message should contain a single element, since the final + // message size log is assumed to be zero + let elems = builder.iter_ptr_get(&proof.final_message, ptr_vec_sumcheck_message[0]); + let elem = builder.get(&elems, 0); + let elem_felts = builder.ext2felt(elem); + challenger.observe_slice(builder, elem_felts); + }); + + transcript_check_pow_witness(builder, challenger, 16, proof.pow_witness); // TODO: avoid hardcoding pow bits + transcript_observe_label(builder, challenger, b"query indices"); + let queries: Array> = builder.dyn_array(get_num_queries()); + builder + .range(0, get_num_queries()) + .for_each(|index_vec, builder| { + let number_of_bits = builder.eval_expr(max_num_var + Usize::from(get_rate_log())); + let query = challenger.sample_bits(builder, number_of_bits); + // TODO: the index will need to be split back to bits in query phase, so it's + // probably better to avoid converting bits to integer altogether + let number_of_bits = builder.eval(max_num_var + Usize::from(get_rate_log())); + let query = bin_to_dec_le(builder, &query, zero, number_of_bits); + builder.set(&queries, index_vec[0], query); + }); + + let input = QueryPhaseVerifierInputVariable { + max_num_var: builder.eval(max_num_var), + batch_coeffs, + fold_challenges, + indices: queries, + proof, + rounds, + }; + batch_verifier_query_phase(builder, input); +} + +#[cfg(test)] +pub mod tests { + use std::{cmp::Reverse, collections::BTreeMap, iter::once}; + + use ceno_mle::mle::MultilinearExtension; + use ceno_transcript::{BasicTranscript, Transcript}; + use ff_ext::{BabyBearExt4, FromUniformBytes}; + use itertools::Itertools; + use mpcs::{ + pcs_batch_commit, pcs_setup, pcs_trim, util::hash::write_digest_to_transcript, + BasefoldDefault, PolynomialCommitmentScheme, + }; + use mpcs::{BasefoldRSParams, BasefoldSpec, PCSFriParam}; + use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; + use openvm_native_circuit::{Native, NativeConfig}; + use openvm_native_compiler::asm::AsmBuilder; + use openvm_native_recursion::challenger::duplex::DuplexChallengerVariable; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::p3_challenger::GrindingChallenger; + use openvm_stark_sdk::config::baby_bear_poseidon2::Challenger; + use openvm_stark_sdk::p3_baby_bear::BabyBear; + use p3_field::Field; + use p3_field::FieldAlgebra; + use rand::thread_rng; + use serde::Deserialize; + + type F = BabyBear; + type E = BabyBearExt4; + type PCS = BasefoldDefault; + + use super::{batch_verify, BasefoldProof, BasefoldProofVariable, InnerConfig, RoundVariable}; + use crate::basefold_verifier::basefold::{Round, RoundOpening}; + use crate::basefold_verifier::query_phase::PointAndEvals; + use crate::{ + basefold_verifier::{ + basefold::BasefoldCommitment, + query_phase::{BatchOpening, CommitPhaseProofStep, QueryOpeningProof}, + }, + tower_verifier::binding::{Point, PointAndEval}, + }; + use openvm_native_compiler::{asm::AsmConfig, prelude::*}; + + #[derive(Deserialize)] + pub struct VerifierInput { + pub max_num_var: usize, + pub proof: BasefoldProof, + pub rounds: Vec, + } + + impl Hintable for VerifierInput { + type HintVariable = VerifierInputVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let max_num_var = usize::read(builder); + let proof = BasefoldProof::read(builder); + let rounds = Vec::::read(builder); + + VerifierInputVariable { + max_num_var, + proof, + rounds, + } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(>::write(&self.max_num_var)); + stream.extend(self.proof.write()); + stream.extend(self.rounds.write()); + stream + } + } + + #[derive(DslVariable, Clone)] + pub struct VerifierInputVariable { + pub max_num_var: Var, + pub proof: BasefoldProofVariable, + pub rounds: Array>, + } + + #[allow(dead_code)] + pub fn build_batch_verifier(input: VerifierInput) -> (Program, Vec>) { + // build test program + let mut builder = AsmBuilder::::default(); + let mut challenger = DuplexChallengerVariable::new(&mut builder); + let verifier_input = VerifierInput::read(&mut builder); + batch_verify( + &mut builder, + verifier_input.max_num_var, + verifier_input.rounds, + verifier_input.proof, + &mut challenger, + ); + builder.halt(); + let program = builder.compile_isa(); + + let mut witness_stream: Vec> = Vec::new(); + witness_stream.extend(input.write()); + + (program, witness_stream) + } + + fn construct_test(dimensions: Vec<(usize, usize)>) { + let mut rng = thread_rng(); + + // setup PCS + let pp = PCS::setup(1 << 20, mpcs::SecurityLevel::Conjecture100bits).unwrap(); + let (pp, vp) = pcs_trim::(pp, 1 << 20).unwrap(); + + let mut num_total_polys = 0; + let (matrices, mles): (Vec<_>, Vec<_>) = dimensions + .into_iter() + .map(|(num_vars, width)| { + let m = ceno_witness::RowMajorMatrix::::rand(&mut rng, 1 << num_vars, width); + let mles = m.to_mles(); + num_total_polys += width; + + (m, mles) + }) + .unzip(); + + // commit to matrices + let pcs_data = pcs_batch_commit::(&pp, matrices).unwrap(); + let comm = PCS::get_pure_commitment(&pcs_data); + + let point_and_evals = mles + .iter() + .map(|mles| { + let point = E::random_vec(mles[0].num_vars(), &mut rng); + let evals = mles.iter().map(|mle| mle.evaluate(&point)).collect_vec(); + + (point, evals) + }) + .collect_vec(); + + // batch open + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![(&pcs_data, point_and_evals.clone())]; + let opening_proof = PCS::batch_open(&pp, rounds, &mut transcript).unwrap(); + + // batch verify + let mut transcript = BasicTranscript::::new(&[]); + let rounds = vec![( + comm, + point_and_evals + .iter() + .map(|(point, evals)| (point.len(), (point.clone(), evals.clone()))) + .collect_vec(), + )]; + PCS::batch_verify(&vp, rounds.clone(), &opening_proof, &mut transcript) + .expect("Native verification failed"); + + let max_num_var = point_and_evals + .iter() + .map(|(point, _)| point.len()) + .max() + .unwrap(); + + let verifier_input = VerifierInput { + max_num_var, + rounds: rounds + .into_iter() + .map(|(commit, openings)| Round { + commit: commit.into(), + openings: openings + .into_iter() + .map(|(num_var, (point, evals))| RoundOpening { + num_var, + point_and_evals: PointAndEvals { + point: Point { fs: point }, + evals, + }, + }) + .collect(), + }) + .collect(), + proof: opening_proof.into(), + }; + + let (program, witness) = build_batch_verifier(verifier_input); + + let system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + executor.execute(program.clone(), witness.clone()).unwrap(); + + // _debug + let results = executor.execute_segments(program, witness).unwrap(); + for seg in results { + println!("=> cycle count: {:?}", seg.metrics.cycle_count); + } + } + + #[test] + fn test_simple_batch() { + for num_var in 5..20 { + construct_test(vec![(num_var, 20)]); + } + } + + #[test] + fn test_decreasing_batch() { + construct_test(vec![ + (14, 20), + (14, 40), + (13, 30), + (12, 30), + (11, 10), + (10, 15), + ]); + } + + #[test] + fn test_random_batch() { + construct_test(vec![(10, 20), (12, 30), (11, 10), (12, 15)]); + } +} diff --git a/src/e2e/mod.rs b/src/e2e/mod.rs index 229c0cd..b707d72 100644 --- a/src/e2e/mod.rs +++ b/src/e2e/mod.rs @@ -1,13 +1,12 @@ +use crate::basefold_verifier::basefold::BasefoldCommitment; +use crate::basefold_verifier::query_phase::QueryPhaseVerifierInput; use crate::tower_verifier::binding::IOPProverMessage; use crate::zkvm_verifier::binding::ZKVMProofInput; -use crate::zkvm_verifier::binding::{ - TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, -}; +use crate::zkvm_verifier::binding::{TowerProofInput, ZKVMChipProofInput, E, F}; use crate::zkvm_verifier::verifier::verify_zkvm_proof; use ceno_mle::util::ceil_log2; use ff_ext::BabyBearExt4; use itertools::Itertools; -use mpcs::BasefoldCommitment; use mpcs::{Basefold, BasefoldRSParams}; use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; use openvm_native_circuit::{Native, NativeConfig}; @@ -22,9 +21,7 @@ use openvm_stark_sdk::config::setup_tracing_with_log_level; use openvm_stark_sdk::{ config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, }; -use std::collections::HashMap; use std::fs::File; -use std::thread; type SC = BabyBearPoseidon2Config; type EF = ::Challenge; @@ -34,67 +31,10 @@ use ceno_zkvm::{ structs::ZKVMVerifyingKey, }; -#[derive(Debug, Clone)] -pub struct SubcircuitParams { - pub id: usize, - pub order_idx: usize, - pub type_order_idx: usize, - pub name: String, - pub num_instances: usize, - pub is_opcode: bool, -} - pub fn parse_zkvm_proof_import( zkvm_proof: ZKVMProof>, verifier: &ZKVMVerifier>, -) -> (ZKVMProofInput, Vec) { - let subcircuit_names = verifier.vk.circuit_vks.keys().collect_vec(); - - let mut opcode_num_instances_lookup: HashMap = HashMap::new(); - let mut table_num_instances_lookup: HashMap = HashMap::new(); - for (index, num_instances) in &zkvm_proof.num_instances { - if let Some(_opcode_proof) = zkvm_proof.opcode_proofs.get(index) { - opcode_num_instances_lookup.insert(index.clone(), num_instances.clone()); - } else if let Some(_table_proof) = zkvm_proof.table_proofs.get(index) { - table_num_instances_lookup.insert(index.clone(), num_instances.clone()); - } else { - unreachable!("respective proof of index {} should exist", index) - } - } - - let mut order_idx: usize = 0; - let mut opcode_order_idx: usize = 0; - let mut table_order_idx: usize = 0; - let mut proving_sequence: Vec = vec![]; - for (index, _) in &zkvm_proof.num_instances { - let name = subcircuit_names[*index].clone(); - if zkvm_proof.opcode_proofs.get(index).is_some() { - proving_sequence.push(SubcircuitParams { - id: *index, - order_idx: order_idx.clone(), - type_order_idx: opcode_order_idx.clone(), - name: name.clone(), - num_instances: opcode_num_instances_lookup.get(index).unwrap().clone(), - is_opcode: true, - }); - opcode_order_idx += 1; - } else if zkvm_proof.table_proofs.get(index).is_some() { - proving_sequence.push(SubcircuitParams { - id: *index, - order_idx: order_idx.clone(), - type_order_idx: table_order_idx.clone(), - name: name.clone(), - num_instances: table_num_instances_lookup.get(index).unwrap().clone(), - is_opcode: false, - }); - table_order_idx += 1; - } else { - unreachable!("respective proof of index {} should exist", index) - } - - order_idx += 1; - } - +) -> ZKVMProofInput { let raw_pi = zkvm_proof .raw_pi .iter() @@ -119,14 +59,14 @@ pub fn parse_zkvm_proof_import( }) .collect::>(); - let mut opcode_proofs_vec: Vec = vec![]; - for (opcode_id, opcode_proof) in &zkvm_proof.opcode_proofs { + let mut chip_proofs: Vec = vec![]; + for (chip_id, chip_proof) in &zkvm_proof.chip_proofs { let mut record_r_out_evals: Vec> = vec![]; let mut record_w_out_evals: Vec> = vec![]; let mut record_lk_out_evals: Vec> = vec![]; - let record_r_out_evals_len: usize = opcode_proof.r_out_evals.len(); - for v_vec in &opcode_proof.r_out_evals { + let record_r_out_evals_len: usize = chip_proof.r_out_evals.len(); + for v_vec in &chip_proof.r_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -135,8 +75,8 @@ pub fn parse_zkvm_proof_import( } record_r_out_evals.push(arr); } - let record_w_out_evals_len: usize = opcode_proof.w_out_evals.len(); - for v_vec in &opcode_proof.w_out_evals { + let record_w_out_evals_len: usize = chip_proof.w_out_evals.len(); + for v_vec in &chip_proof.w_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -145,8 +85,8 @@ pub fn parse_zkvm_proof_import( } record_w_out_evals.push(arr); } - let record_lk_out_evals_len: usize = opcode_proof.lk_out_evals.len(); - for v_vec in &opcode_proof.lk_out_evals { + let record_lk_out_evals_len: usize = chip_proof.lk_out_evals.len(); + for v_vec in &chip_proof.lk_out_evals { let mut arr: Vec = vec![]; for v in v_vec { let v_e: E = @@ -160,7 +100,7 @@ pub fn parse_zkvm_proof_import( let mut tower_proof = TowerProofInput::default(); let mut proofs: Vec> = vec![]; - for proof in &opcode_proof.tower_proof.proofs { + for proof in &chip_proof.tower_proof.proofs { let mut proof_messages: Vec = vec![]; for m in proof { let mut evaluations_vec: Vec = vec![]; @@ -180,7 +120,7 @@ pub fn parse_zkvm_proof_import( tower_proof.proofs = proofs; let mut prod_specs_eval: Vec>> = vec![]; - for inner_val in &opcode_proof.tower_proof.prod_specs_eval { + for inner_val in &chip_proof.tower_proof.prod_specs_eval { let mut inner_v: Vec> = vec![]; for inner_evals_val in inner_val { let mut inner_evals_v: Vec = vec![]; @@ -198,7 +138,7 @@ pub fn parse_zkvm_proof_import( tower_proof.prod_specs_eval = prod_specs_eval; let mut logup_specs_eval: Vec>> = vec![]; - for inner_val in &opcode_proof.tower_proof.logup_specs_eval { + for inner_val in &chip_proof.tower_proof.logup_specs_eval { let mut inner_v: Vec> = vec![]; for inner_evals_val in inner_val { let mut inner_evals_v: Vec = vec![]; @@ -217,8 +157,8 @@ pub fn parse_zkvm_proof_import( // main constraint and select sumcheck proof let mut main_sumcheck_proofs: Vec = vec![]; - if opcode_proof.main_sumcheck_proofs.is_some() { - for m in opcode_proof.main_sumcheck_proofs.as_ref().unwrap() { + if chip_proof.main_sumcheck_proofs.is_some() { + for m in chip_proof.main_sumcheck_proofs.as_ref().unwrap() { let mut evaluations_vec: Vec = vec![]; for v in &m.evaluations { let v_e: E = @@ -232,20 +172,20 @@ pub fn parse_zkvm_proof_import( } let mut wits_in_evals: Vec = vec![]; - for v in &opcode_proof.wits_in_evals { + for v in &chip_proof.wits_in_evals { let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); wits_in_evals.push(v_e); } let mut fixed_in_evals: Vec = vec![]; - for v in &opcode_proof.fixed_in_evals { + for v in &chip_proof.fixed_in_evals { let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); fixed_in_evals.push(v_e); } - opcode_proofs_vec.push(ZKVMOpcodeProofInput { - idx: opcode_id.clone(), - num_instances: opcode_num_instances_lookup.get(opcode_id).unwrap().clone(), + chip_proofs.push(ZKVMChipProofInput { + idx: chip_id.clone(), + num_instances: chip_proof.num_instances, record_r_out_evals_len, record_w_out_evals_len, record_lk_out_evals_len, @@ -259,146 +199,19 @@ pub fn parse_zkvm_proof_import( }); } - let mut table_proofs_vec: Vec = vec![]; - for (table_id, table_proof) in &zkvm_proof.table_proofs { - let mut record_r_out_evals: Vec> = vec![]; - let mut record_w_out_evals: Vec> = vec![]; - let mut record_lk_out_evals: Vec> = vec![]; - - let record_r_out_evals_len: usize = table_proof.r_out_evals.len(); - for v_vec in &table_proof.r_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_r_out_evals.push(arr); - } - let record_w_out_evals_len: usize = table_proof.w_out_evals.len(); - for v_vec in &table_proof.w_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_w_out_evals.push(arr); - } - let record_lk_out_evals_len: usize = table_proof.lk_out_evals.len(); - for v_vec in &table_proof.lk_out_evals { - let mut arr: Vec = vec![]; - for v in v_vec { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - arr.push(v_e); - } - record_lk_out_evals.push(arr); - } - - // Tower proof - let mut tower_proof = TowerProofInput::default(); - let mut proofs: Vec> = vec![]; - - for proof in &table_proof.tower_proof.proofs { - let mut proof_messages: Vec = vec![]; - for m in proof { - let mut evaluations_vec: Vec = vec![]; - - for v in &m.evaluations { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - evaluations_vec.push(v_e); - } - proof_messages.push(IOPProverMessage { - evaluations: evaluations_vec, - }); - } - proofs.push(proof_messages); - } - tower_proof.num_proofs = proofs.len(); - tower_proof.proofs = proofs; - - let mut prod_specs_eval: Vec>> = vec![]; - for inner_val in &table_proof.tower_proof.prod_specs_eval { - let mut inner_v: Vec> = vec![]; - for inner_evals_val in inner_val { - let mut inner_evals_v: Vec = vec![]; - - for v in inner_evals_val { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - inner_evals_v.push(v_e); - } - inner_v.push(inner_evals_v); - } - prod_specs_eval.push(inner_v); - } - tower_proof.num_prod_specs = prod_specs_eval.len(); - tower_proof.prod_specs_eval = prod_specs_eval; - - let mut logup_specs_eval: Vec>> = vec![]; - for inner_val in &table_proof.tower_proof.logup_specs_eval { - let mut inner_v: Vec> = vec![]; - for inner_evals_val in inner_val { - let mut inner_evals_v: Vec = vec![]; - - for v in inner_evals_val { - let v_e: E = - serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - inner_evals_v.push(v_e); - } - inner_v.push(inner_evals_v); - } - logup_specs_eval.push(inner_v); - } - tower_proof.num_logup_specs = logup_specs_eval.len(); - tower_proof.logup_specs_eval = logup_specs_eval; - - let mut fixed_in_evals: Vec = vec![]; - for v in &table_proof.fixed_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - fixed_in_evals.push(v_e); - } - let mut wits_in_evals: Vec = vec![]; - for v in &table_proof.wits_in_evals { - let v_e: E = serde_json::from_value(serde_json::to_value(v.clone()).unwrap()).unwrap(); - wits_in_evals.push(v_e); - } + let witin_commit: mpcs::BasefoldCommitment = + serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); + let witin_commit: BasefoldCommitment = witin_commit.into(); - let num_instances = table_num_instances_lookup.get(table_id).unwrap().clone(); + let pcs_proof = zkvm_proof.opening_proof.into(); - table_proofs_vec.push(ZKVMTableProofInput { - idx: table_id.clone(), - num_instances, - record_r_out_evals_len, - record_w_out_evals_len, - record_lk_out_evals_len, - record_r_out_evals, - record_w_out_evals, - record_lk_out_evals, - tower_proof, - fixed_in_evals, - wits_in_evals, - }); + ZKVMProofInput { + raw_pi, + pi_evals, + chip_proofs, + witin_commit, + pcs_proof, } - - let witin_commit: BasefoldCommitment = - serde_json::from_value(serde_json::to_value(zkvm_proof.witin_commit).unwrap()).unwrap(); - let fixed_commit = verifier.vk.fixed_commit.clone(); - - ( - ZKVMProofInput { - raw_pi, - pi_evals, - opcode_proofs: opcode_proofs_vec, - table_proofs: table_proofs_vec, - witin_commit, - fixed_commit, - num_instances: zkvm_proof.num_instances.clone(), - }, - proving_sequence, - ) } pub fn inner_test_thread() { @@ -416,19 +229,14 @@ pub fn inner_test_thread() { .expect("Failed to deserialize vk file"); let verifier = ZKVMVerifier::new(vk); - let (zkvm_proof_input, proving_sequence) = parse_zkvm_proof_import(zkvm_proof, &verifier); + let zkvm_proof_input = parse_zkvm_proof_import(zkvm_proof, &verifier); // OpenVM DSL let mut builder = AsmBuilder::::default(); // Obtain witness inputs let zkvm_proof_input_variables = ZKVMProofInput::read(&mut builder); - verify_zkvm_proof( - &mut builder, - zkvm_proof_input_variables, - &verifier, - proving_sequence, - ); + verify_zkvm_proof(&mut builder, zkvm_proof_input_variables, &verifier); builder.halt(); // Pass in witness stream @@ -475,7 +283,7 @@ pub fn inner_test_thread() { pub fn test_zkvm_proof_verifier_from_bincode_exports() { let stack_size = 64 * 1024 * 1024; // 64 MB - let handler = thread::Builder::new() + let handler = std::thread::Builder::new() .stack_size(stack_size) .spawn(inner_test_thread) .expect("Failed to spawn thread"); diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 438284d..347f556 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,126 +1,114 @@ -use crate::arithmetics::{challenger_multi_observe, exts_to_felts, print_felt_arr}; -use crate::e2e::SubcircuitParams; -use crate::tower_verifier::binding::IOPProverMessage; -use crate::tower_verifier::program::verify_tower_proof; -use crate::transcript::transcript_observe_label; -use crate::zkvm_verifier::binding::ZKVMProofInput; -use crate::zkvm_verifier::binding::{ - TowerProofInput, ZKVMOpcodeProofInput, ZKVMTableProofInput, E, F, -}; -use crate::zkvm_verifier::verifier::verify_zkvm_proof; -use crate::{ - arithmetics::{ - build_eq_x_r_vec_sequential, ceil_log2, concat, dot_product as ext_dot_product, - eq_eval_less_or_equal_than, eval_wellform_address_vec, gen_alpha_pows, max_usize_arr, - max_usize_vec, next_pow2_instance_padding, product, sum as ext_sum, - }, - tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, -}; -use ceno_mle::expression::StructuralWitIn; -use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; -use ff_ext::BabyBearExt4; -use itertools::interleave; -use itertools::max; -use itertools::Itertools; -use mpcs::BasefoldCommitment; -use mpcs::{Basefold, BasefoldRSParams}; -use openvm_circuit::arch::{instructions::program::Program, SystemConfig, VmExecutor}; -use openvm_native_circuit::{Native, NativeConfig}; -use openvm_native_compiler::conversion::convert_program; -use openvm_native_compiler::prelude::*; -use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; -use openvm_native_compiler_derive::iter_zip; -use openvm_native_recursion::challenger::{self, CanSampleVariable}; -use openvm_native_recursion::challenger::{ - duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, -}; -use openvm_native_recursion::hints::Hintable; -use openvm_stark_backend::config::StarkGenericConfig; -use openvm_stark_sdk::{ - config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, -}; -use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; -use std::collections::HashMap; -use std::fs::File; -use std::marker::PhantomData; - -type Pcs = Basefold; -const NUM_FANIN: usize = 2; -const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup -const SEL_DEGREE: usize = 2; - -type SC = BabyBearPoseidon2Config; -type EF = ::Challenge; - -#[test] -pub fn test_native_multi_observe() { - // OpenVM DSL - let mut builder = AsmBuilder::::default(); - - vm_program(&mut builder); - - builder.halt(); - - // Pass in witness stream - let witness_stream: Vec< - Vec>, - > = Vec::new(); - - // Compile program - let options = CompilerOptions::default().with_cycle_tracker(); - let mut compiler = AsmCompiler::new(options.word_size); - compiler.build(builder.operations); - let asm_code = compiler.code(); - let program = convert_program(asm_code, options); - - let mut system_config = SystemConfig::default() - .with_public_values(4) - .with_max_segment_len((1 << 25) - 100); - system_config.profiling = true; - let config = NativeConfig::new(system_config, Native); - - let executor = VmExecutor::::new(config); - - // Alternative execution - // executor.execute(program, witness_stream).unwrap(); - - let res = executor - .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) - .unwrap(); - - for (i, seg) in res.iter().enumerate() { - println!("=> segment {:?} metrics: {:?}", i, seg.metrics); +#[cfg(test)] +mod tests { + + use crate::arithmetics::{challenger_multi_observe, exts_to_felts}; + + use crate::zkvm_verifier::binding::{E, F}; + use ceno_mle::expression::StructuralWitIn; + use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; + use ff_ext::BabyBearExt4; + use itertools::interleave; + use itertools::max; + use itertools::Itertools; + use mpcs::BasefoldCommitment; + use mpcs::{Basefold, BasefoldRSParams}; + use openvm_circuit::arch::SystemConfig; + use openvm_circuit::arch::VmExecutor; + use openvm_native_circuit::Native; + use openvm_native_circuit::NativeConfig; + use openvm_native_compiler::conversion::convert_program; + use openvm_native_compiler::prelude::*; + use openvm_native_compiler::{asm::AsmBuilder, conversion::CompilerOptions}; + use openvm_native_compiler_derive::iter_zip; + use openvm_native_recursion::challenger::{self, CanSampleVariable}; + use openvm_native_recursion::challenger::{ + duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, + }; + use openvm_native_recursion::hints::Hintable; + use openvm_stark_backend::config::StarkGenericConfig; + use openvm_stark_sdk::{ + config::baby_bear_poseidon2::BabyBearPoseidon2Config, p3_baby_bear::BabyBear, + }; + use p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra}; + + type Pcs = Basefold; + const NUM_FANIN: usize = 2; + const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup + const SEL_DEGREE: usize = 2; + + type SC = BabyBearPoseidon2Config; + type EF = ::Challenge; + + #[test] + pub fn test_native_multi_observe() { + // OpenVM DSL + let mut builder = AsmBuilder::::default(); + + vm_program(&mut builder); + + builder.halt(); + + // Pass in witness stream + let witness_stream: Vec< + Vec>, + > = Vec::new(); + + // Compile program + let options = CompilerOptions::default().with_cycle_tracker(); + let mut compiler = AsmCompiler::new(options.word_size); + compiler.build(builder.operations); + let asm_code = compiler.code(); + let program = convert_program(asm_code, options); + + let mut system_config = SystemConfig::default() + .with_public_values(4) + .with_max_segment_len((1 << 25) - 100); + system_config.profiling = true; + let config = NativeConfig::new(system_config, Native); + + let executor = VmExecutor::::new(config); + + // Alternative execution + // executor.execute(program, witness_stream).unwrap(); + + let res = executor + .execute_and_then(program, witness_stream, |_, seg| Ok(seg), |err| err) + .unwrap(); + + for (i, seg) in res.iter().enumerate() { + println!("=> segment {:?} metrics: {:?}", i, seg.metrics); + } } -} -fn vm_program(builder: &mut Builder) { - let e1: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(16)); - let e2: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(32)); - let e3: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(64)); - let e4: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(128)); - let e5: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(256)); - let len: usize = 5; + fn vm_program(builder: &mut Builder) { + let e1: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(16)); + let e2: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(32)); + let e3: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(64)); + let e4: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(128)); + let e5: Ext = builder.constant(C::EF::GENERATOR.exp_power_of_2(256)); + let len: usize = 5; - let e_arr: Array> = builder.dyn_array(len); - builder.set(&e_arr, 0, e1); - builder.set(&e_arr, 1, e2); - builder.set(&e_arr, 2, e3); - builder.set(&e_arr, 3, e4); - builder.set(&e_arr, 4, e5); + let e_arr: Array> = builder.dyn_array(len); + builder.set(&e_arr, 0, e1); + builder.set(&e_arr, 1, e2); + builder.set(&e_arr, 2, e3); + builder.set(&e_arr, 3, e4); + builder.set(&e_arr, 4, e5); - unsafe { - let mut c1 = DuplexChallengerVariable::new(builder); - let mut c2 = DuplexChallengerVariable::new(builder); + unsafe { + let mut c1 = DuplexChallengerVariable::new(builder); + let mut c2 = DuplexChallengerVariable::new(builder); - let f_arr1 = exts_to_felts(builder, &e_arr); - let f_arr2 = f_arr1.clone(); + let f_arr1 = exts_to_felts(builder, &e_arr); + let f_arr2 = f_arr1.clone(); - challenger_multi_observe(builder, &mut c1, &f_arr1); - let test_e1 = c1.sample(builder); + challenger_multi_observe(builder, &mut c1, &f_arr1); + let test_e1 = c1.sample(builder); - c2.observe_slice(builder, f_arr2); - let test_e2 = c2.sample(builder); + c2.observe_slice(builder, f_arr2); + let test_e2 = c2.sample(builder); - builder.assert_felt_eq(test_e1, test_e2); + builder.assert_felt_eq(test_e1, test_e2); + } } } diff --git a/src/lib.rs b/src/lib.rs index a48c288..04165c2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod arithmetics; +mod basefold_verifier; pub mod constants; mod tower_verifier; mod transcript; diff --git a/src/tower_verifier/binding.rs b/src/tower_verifier/binding.rs index 112ce1a..26c888a 100644 --- a/src/tower_verifier/binding.rs +++ b/src/tower_verifier/binding.rs @@ -11,6 +11,7 @@ pub type InnerConfig = AsmConfig; use openvm_stark_sdk::p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; +use serde::{Deserialize, Serialize}; #[derive(DslVariable, Clone)] pub struct PointVariable { @@ -28,8 +29,9 @@ pub struct IOPProverMessageVariable { pub evaluations: Array>, } +#[derive(Clone, Deserialize)] pub struct Point { - pub fs: Vec, + pub fs: Vec, } impl Hintable for Point { type HintVariable = PointVariable; @@ -48,10 +50,42 @@ impl Hintable for Point { } impl VecAutoHintable for Point {} -#[derive(Debug)] +pub struct PointAndEval { + pub point: Point, + pub eval: E, +} +impl Hintable for PointAndEval { + type HintVariable = PointAndEvalVariable; + + fn read(builder: &mut Builder) -> Self::HintVariable { + let point = Point::read(builder); + let eval = E::read(builder); + PointAndEvalVariable { point, eval } + } + + fn write(&self) -> Vec::N>> { + let mut stream = Vec::new(); + stream.extend(self.point.write()); + stream.extend(self.eval.write()); + stream + } +} +impl VecAutoHintable for PointAndEval {} + +#[derive(Debug, Deserialize)] pub struct IOPProverMessage { pub evaluations: Vec, } + +use ceno_sumcheck::structs::IOPProverMessage as InnerIOPProverMessage; +impl From> for IOPProverMessage { + fn from(value: InnerIOPProverMessage) -> Self { + IOPProverMessage { + evaluations: value.evaluations, + } + } +} + impl Hintable for IOPProverMessage { type HintVariable = IOPProverMessageVariable; diff --git a/src/tower_verifier/program.rs b/src/tower_verifier/program.rs index 6c90245..85d097b 100644 --- a/src/tower_verifier/program.rs +++ b/src/tower_verifier/program.rs @@ -15,6 +15,77 @@ use openvm_native_recursion::challenger::{ }; use p3_field::FieldAlgebra; +pub(crate) fn interpolate_uni_poly( + builder: &mut Builder, + p_i: &Array>, + eval_at: Ext, +) -> Ext { + let len = p_i.len(); + let evals: Array> = builder.dyn_array(len.clone()); + let prod: Ext = builder.eval(eval_at); + + builder.set(&evals, 0, eval_at); + + // `prod = \prod_{j} (eval_at - j)` + let e: Ext = builder.constant(C::EF::ONE); + let one: Ext = builder.constant(C::EF::ONE); + builder.range(1, len.clone()).for_each(|i_vec, builder| { + let i = i_vec[0]; + let tmp: Ext = builder.constant(C::EF::ONE); + builder.assign(&tmp, eval_at - e); + builder.set(&evals, i, tmp); + builder.assign(&prod, prod * tmp); + builder.assign(&e, e + one); + }); + + let denom_up: Ext = builder.constant(C::EF::ONE); + let i: Ext = builder.constant(C::EF::ONE); + builder.assign(&i, i + one); + builder.range(2, len.clone()).for_each(|_i_vec, builder| { + builder.assign(&denom_up, denom_up * i); + builder.assign(&i, i + one); + }); + let denom_down: Ext = builder.constant(C::EF::ONE); + + let idx_vec_len: RVar = builder.eval_expr(len.clone() - RVar::from(1)); + let idx_vec: Array> = builder.dyn_array(idx_vec_len); + let idx_val: Ext = builder.constant(C::EF::ONE); + builder.range(0, idx_vec.len()).for_each(|i_vec, builder| { + builder.set(&idx_vec, i_vec[0], idx_val); + builder.assign(&idx_val, idx_val + one); + }); + let idx_rev = reverse(builder, &idx_vec); + let res = builder.constant(C::EF::ZERO); + + let len_f = idx_val.clone(); + let neg_one: Ext = builder.constant(C::EF::NEG_ONE); + let evals_rev = reverse(builder, &evals); + let p_i_rev = reverse(builder, &p_i); + + let mut idx_pos: RVar = builder.eval_expr(len.clone() - RVar::from(1)); + iter_zip!(builder, idx_rev, evals_rev, p_i_rev).for_each(|ptr_vec, builder| { + let idx = builder.iter_ptr_get(&idx_rev, ptr_vec[0]); + let eval = builder.iter_ptr_get(&evals_rev, ptr_vec[1]); + let up_eval_inv: Ext = builder.eval(denom_up * eval); + builder.assign(&up_eval_inv, up_eval_inv.inverse()); + let p = builder.iter_ptr_get(&p_i_rev, ptr_vec[2]); + + builder.assign(&res, res + p * prod * denom_down * up_eval_inv); + builder.assign(&denom_up, denom_up * (len_f - idx) * neg_one); + builder.assign(&denom_down, denom_down * idx); + + idx_pos = builder.eval_expr(idx_pos - RVar::from(1)); + }); + + let p_i_0 = builder.get(&p_i, 0); + let eval_0 = builder.get(&evals, 0); + let up_eval_inv: Ext = builder.eval(denom_up * eval_0); + builder.assign(&up_eval_inv, up_eval_inv.inverse()); + builder.assign(&res, res + p_i_0 * prod * denom_down * up_eval_inv); + + res +} + // Interpolate a uni-variate degree-`p_i.len()-1` polynomial and evaluate this // polynomial at `eval_at`: // @@ -480,15 +551,17 @@ pub fn verify_tower_proof( }, // update point and eval only for last layer |builder| { - builder.set( - &prod_spec_point_n_eval, - spec_index, - PointAndEvalVariable { + let point_and_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: evals, - }, + }); + builder.set_value( + &prod_spec_point_n_eval, + spec_index, + point_and_eval, ); }, ); @@ -546,26 +619,22 @@ pub fn verify_tower_proof( }, // update point and eval only for last layer |builder| { - builder.set( - &logup_spec_p_point_n_eval, - spec_index, - PointAndEvalVariable { + let p_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: p_eval, - }, - ); - builder.set( - &logup_spec_q_point_n_eval, - spec_index, - PointAndEvalVariable { + }); + let q_eval: PointAndEvalVariable = + builder.eval(PointAndEvalVariable { point: PointVariable { fs: rt_prime.clone(), }, eval: q_eval, - }, - ); + }); + builder.set_value(&logup_spec_p_point_n_eval, spec_index, p_eval); + builder.set_value(&logup_spec_q_point_n_eval, spec_index, q_eval); }, ); }); @@ -578,12 +647,15 @@ pub fn verify_tower_proof( builder.cycle_tracker_end("derive next layer's expected sum"); - next_rt = PointAndEvalVariable { - point: PointVariable { - fs: rt_prime.clone(), + builder.assign( + &next_rt, + PointAndEvalVariable { + point: PointVariable { + fs: rt_prime.clone(), + }, + eval: curr_eval.clone(), }, - eval: curr_eval.clone(), - }; + ); }); ( @@ -602,15 +674,14 @@ mod tests { use crate::tower_verifier::binding::TowerVerifierInput; use crate::tower_verifier::program::iop_verifier_state_verify; use crate::tower_verifier::program::verify_tower_proof; - use ceno_mle::mle::{DenseMultilinearExtension, IntoMLE, MultilinearExtension}; - use ceno_mle::virtual_poly::ArcMultilinearExtension; + use ceno_mle::mle::ArcMultilinearExtension; + use ceno_mle::mle::{IntoMLE, MultilinearExtension}; use ceno_mle::virtual_polys::VirtualPolynomials; use ceno_sumcheck::structs::IOPProverState; use ceno_transcript::BasicTranscript; use ceno_zkvm::scheme::constants::NUM_FANIN; - use ceno_zkvm::scheme::utils::infer_tower_logup_witness; - use ceno_zkvm::scheme::utils::infer_tower_product_witness; - use ceno_zkvm::structs::TowerProver; + use ceno_zkvm::scheme::hal::TowerProver; + use ceno_zkvm::scheme::hal::TowerProverSpec; use ff_ext::BabyBearExt4; use ff_ext::FieldFrom; use ff_ext::FromUniformBytes; @@ -684,9 +755,8 @@ mod tests { // run sumcheck prover to get sumcheck proof let mut rng = thread_rng(); - let (mles, expected_sum) = - DenseMultilinearExtension::::random_mle_list(nv, degree, &mut rng); - let mles: Vec> = + let (mles, expected_sum) = MultilinearExtension::::random_mle_list(nv, degree, &mut rng); + let mles: Vec> = mles.into_iter().map(|mle| mle as _).collect_vec(); let mut virtual_poly: VirtualPolynomials<'_, E> = VirtualPolynomials::new(1, nv); virtual_poly.add_mle_list(mles.iter().collect_vec(), E::from_v(1)); @@ -756,9 +826,9 @@ mod tests { setup_tracing_with_log_level(tracing::Level::WARN); - let records: Vec> = (0..num_prod_specs) + let records: Vec> = (0..num_prod_specs) .map(|_| { - DenseMultilinearExtension::from_evaluations_ext_vec( + MultilinearExtension::from_evaluations_ext_vec( nv - 1, E::random_vec(1 << (nv - 1), &mut rng), ) @@ -766,10 +836,7 @@ mod tests { .collect_vec(); let denom_records = (0..num_logup_specs) .map(|_| { - DenseMultilinearExtension::from_evaluations_ext_vec( - nv, - E::random_vec(1 << nv, &mut rng), - ) + MultilinearExtension::from_evaluations_ext_vec(nv, E::random_vec(1 << nv, &mut rng)) }) .collect_vec(); @@ -810,7 +877,7 @@ mod tests { first.to_vec().into_mle().into(), second.to_vec().into_mle().into(), ]; - ceno_zkvm::structs::TowerProverSpec { + TowerProverSpec { witness: infer_tower_logup_witness(None, last_layer), } }) diff --git a/src/transcript/mod.rs b/src/transcript/mod.rs index 9a61b80..45af2e2 100644 --- a/src/transcript/mod.rs +++ b/src/transcript/mod.rs @@ -1,9 +1,9 @@ use ff_ext::{BabyBearExt4, ExtensionField as CenoExtensionField, SmallField}; use openvm_native_compiler::prelude::*; -use openvm_native_recursion::challenger::ChallengerVariable; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use openvm_native_recursion::challenger::{CanSampleBitsVariable, ChallengerVariable}; use p3_field::FieldAlgebra; pub fn transcript_observe_label( @@ -17,3 +17,18 @@ pub fn transcript_observe_label( challenger.observe(builder, f); } } + +pub fn transcript_check_pow_witness( + builder: &mut Builder, + challenger: &mut DuplexChallengerVariable, + nbits: usize, + witness: Felt, +) { + let nbits = builder.eval_expr(Usize::from(nbits)); + challenger.observe(builder, witness); + let bits = challenger.sample_bits(builder, nbits); + builder.range(0, nbits).for_each(|index_vec, builder| { + let bit = builder.get(&bits, index_vec[0]); + builder.assert_eq::>(bit, Usize::from(0)); + }); +} diff --git a/src/zkvm_verifier/binding.rs b/src/zkvm_verifier/binding.rs index 4007238..5e35664 100644 --- a/src/zkvm_verifier/binding.rs +++ b/src/zkvm_verifier/binding.rs @@ -1,11 +1,17 @@ use crate::arithmetics::next_pow2_instance_padding; +use crate::basefold_verifier::basefold::{ + BasefoldCommitment, BasefoldCommitmentVariable, BasefoldProof, BasefoldProofVariable, +}; +use crate::basefold_verifier::query_phase::{ + QueryPhaseVerifierInput, QueryPhaseVerifierInputVariable, +}; use crate::{ arithmetics::ceil_log2, tower_verifier::binding::{IOPProverMessage, IOPProverMessageVariable}, }; use ark_std::iterable::Iterable; use ff_ext::BabyBearExt4; -use mpcs::BasefoldCommitment; +use itertools::Itertools; use openvm_native_compiler::{ asm::AsmConfig, ir::{Array, Builder, Config, Felt}, @@ -25,18 +31,12 @@ pub struct ZKVMProofInputVariable { pub raw_pi: Array>>, pub raw_pi_num_variables: Array>, pub pi_evals: Array>, - pub opcode_proofs: Array>, - pub table_proofs: Array>, - - pub witin_commit: Array>, - pub witin_commit_trivial_commits: Array>>, - pub witin_commit_log2_max_codeword_size: Felt, - - pub has_fixed_commit: Usize, - pub fixed_commit: Array>, - pub fixed_commit_trivial_commits: Array>>, - pub fixed_commit_log2_max_codeword_size: Felt, - pub num_instances: Array>>, + pub chip_proofs: Array>, + pub max_num_var: Var, + pub witin_commit: BasefoldCommitmentVariable, + pub witin_perm: Array>, + pub fixed_perm: Array>, + pub pcs_proof: BasefoldProofVariable, } #[derive(DslVariable, Clone)] @@ -50,7 +50,7 @@ pub struct TowerProofInputVariable { } #[derive(DslVariable, Clone)] -pub struct ZKVMOpcodeProofInputVariable { +pub struct ZKVMChipProofInputVariable { pub idx: Usize, pub idx_felt: Felt, pub num_instances: Usize, @@ -72,36 +72,15 @@ pub struct ZKVMOpcodeProofInputVariable { pub fixed_in_evals: Array>, } -#[derive(DslVariable, Clone)] -pub struct ZKVMTableProofInputVariable { - pub idx: Usize, - pub idx_felt: Felt, - pub num_instances: Usize, - pub log2_num_instances: Usize, - - pub record_r_out_evals_len: Usize, - pub record_w_out_evals_len: Usize, - pub record_lk_out_evals_len: Usize, - - pub record_r_out_evals: Array>>, - pub record_w_out_evals: Array>>, - pub record_lk_out_evals: Array>>, - - pub tower_proof: TowerProofInputVariable, - pub fixed_in_evals: Array>, - pub wits_in_evals: Array>, -} - pub(crate) struct ZKVMProofInput { pub raw_pi: Vec>, // Evaluation of raw_pi. pub pi_evals: Vec, - pub opcode_proofs: Vec, - pub table_proofs: Vec, - pub witin_commit: BasefoldCommitment, - pub fixed_commit: Option>, - pub num_instances: Vec<(usize, usize)>, + pub chip_proofs: Vec, + pub witin_commit: BasefoldCommitment, + pub pcs_proof: BasefoldProof, } + impl Hintable for ZKVMProofInput { type HintVariable = ZKVMProofInputVariable; @@ -109,117 +88,71 @@ impl Hintable for ZKVMProofInput { let raw_pi = Vec::>::read(builder); let raw_pi_num_variables = Vec::::read(builder); let pi_evals = Vec::::read(builder); - let opcode_proofs = Vec::::read(builder); - let table_proofs = Vec::::read(builder); - - let witin_commit = Vec::::read(builder); - let witin_commit_trivial_commits = Vec::>::read(builder); - let witin_commit_log2_max_codeword_size = F::read(builder); - - let has_fixed_commit = Usize::Var(usize::read(builder)); - let fixed_commit = Vec::::read(builder); - let fixed_commit_trivial_commits = Vec::>::read(builder); - let fixed_commit_log2_max_codeword_size = F::read(builder); - - let num_instances = Vec::>::read(builder); + let chip_proofs = Vec::::read(builder); + let max_num_var = usize::read(builder); + let witin_commit = BasefoldCommitment::read(builder); + let witin_perm = Vec::::read(builder); + let fixed_perm = Vec::::read(builder); + let pcs_proof = BasefoldProof::read(builder); ZKVMProofInputVariable { raw_pi, raw_pi_num_variables, pi_evals, - opcode_proofs, - table_proofs, + chip_proofs, + max_num_var, witin_commit, - witin_commit_trivial_commits, - witin_commit_log2_max_codeword_size, - has_fixed_commit, - fixed_commit, - fixed_commit_trivial_commits, - fixed_commit_log2_max_codeword_size, - num_instances, + witin_perm, + fixed_perm, + pcs_proof, } } fn write(&self) -> Vec::N>> { let mut stream = Vec::new(); - stream.extend(self.raw_pi.write()); + let raw_pi_num_variables: Vec = self + .raw_pi + .iter() + .map(|v| ceil_log2(v.len().next_power_of_two())) + .collect(); + let witin_num_vars = self + .chip_proofs + .iter() + .map(|proof| ceil_log2(proof.num_instances).max(1)) + .collect::>(); + let fixed_num_vars = self + .chip_proofs + .iter() + .filter(|proof| proof.fixed_in_evals.len() > 0) + .map(|proof| ceil_log2(proof.num_instances).max(1)) + .collect::>(); + let max_num_var = witin_num_vars.iter().map(|x| *x).max().unwrap_or(0); + let get_perm = |v: Vec| { + let mut perm = vec![0; v.len()]; + v.into_iter() + // the original order + .enumerate() + .sorted_by(|(_, nv_a), (_, nv_b)| Ord::cmp(nv_b, nv_a)) + .enumerate() + // j is the new index where i is the original index + .map(|(j, (i, _))| (i, j)) + .for_each(|(i, j)| { + perm[i] = j; + }); + perm + }; + let witin_perm = get_perm(witin_num_vars); + let fixed_perm = get_perm(fixed_num_vars); - let mut raw_pi_num_variables: Vec = vec![]; - for v in &self.raw_pi { - raw_pi_num_variables.push(ceil_log2(v.len().next_power_of_two())); - } + stream.extend(self.raw_pi.write()); stream.extend(raw_pi_num_variables.write()); - stream.extend(self.pi_evals.write()); - stream.extend(self.opcode_proofs.write()); - stream.extend(self.table_proofs.write()); - - // Write in witin_commit - let mut cmt_vec: Vec = vec![]; - self.witin_commit.commit().iter().for_each(|x| { - let f: F = serde_json::from_value(serde_json::to_value(&x).unwrap()).unwrap(); - cmt_vec.push(f); - }); - let mut witin_commit_trivial_commits: Vec> = vec![]; - for trivial_commit in &self.witin_commit.trivial_commits { - let mut t_cmt_vec: Vec = vec![]; - trivial_commit.iter().for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - t_cmt_vec.push(f); - }); - witin_commit_trivial_commits.push(t_cmt_vec); - } - let witin_commit_log2_max_codeword_size = - F::from_canonical_u32(self.witin_commit.log2_max_codeword_size as u32); - stream.extend(cmt_vec.write()); - stream.extend(witin_commit_trivial_commits.write()); - stream.extend(witin_commit_log2_max_codeword_size.write()); - - // Write in fixed_commit - let has_fixed_commit: usize = if self.fixed_commit.is_some() { 1 } else { 0 }; - let mut fixed_commit_vec: Vec = vec![]; - let mut fixed_commit_trivial_commits: Vec> = vec![]; - let mut fixed_commit_log2_max_codeword_size: F = F::ZERO.clone(); - if has_fixed_commit > 0 { - self.fixed_commit - .as_ref() - .unwrap() - .commit() - .iter() - .for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - fixed_commit_vec.push(f); - }); - - for trivial_commit in &self.fixed_commit.as_ref().unwrap().trivial_commits { - let mut t_cmt_vec: Vec = vec![]; - trivial_commit.iter().for_each(|x| { - let f: F = - serde_json::from_value(serde_json::to_value(x.clone()).unwrap()).unwrap(); - t_cmt_vec.push(f); - }); - fixed_commit_trivial_commits.push(t_cmt_vec); - } - fixed_commit_log2_max_codeword_size = F::from_canonical_u32( - self.fixed_commit.as_ref().unwrap().log2_max_codeword_size as u32, - ); - } - stream.extend(>::write(&has_fixed_commit)); - stream.extend(fixed_commit_vec.write()); - stream.extend(fixed_commit_trivial_commits.write()); - stream.extend(fixed_commit_log2_max_codeword_size.write()); - - // Write num_instances - let mut num_instances_vec: Vec> = vec![]; - for (circuit_size, num_var) in &self.num_instances { - num_instances_vec.push(vec![ - F::from_canonical_usize(*circuit_size), - F::from_canonical_usize(*num_var), - ]); - } - stream.extend(num_instances_vec.write()); + stream.extend(self.chip_proofs.write()); + stream.extend(>::write(&max_num_var)); + stream.extend(self.witin_commit.write()); + stream.extend(witin_perm.write()); + stream.extend(fixed_perm.write()); + stream.extend(self.pcs_proof.write()); stream } @@ -236,6 +169,7 @@ pub struct TowerProofInput { pub num_logup_specs: usize, pub logup_specs_eval: Vec>>, } + impl Hintable for TowerProofInput { type HintVariable = TowerProofInputVariable; @@ -296,7 +230,7 @@ impl Hintable for TowerProofInput { } } -pub struct ZKVMOpcodeProofInput { +pub struct ZKVMChipProofInput { pub idx: usize, pub num_instances: usize, @@ -315,9 +249,11 @@ pub struct ZKVMOpcodeProofInput { pub wits_in_evals: Vec, pub fixed_in_evals: Vec, } -impl VecAutoHintable for ZKVMOpcodeProofInput {} -impl Hintable for ZKVMOpcodeProofInput { - type HintVariable = ZKVMOpcodeProofInputVariable; + +impl VecAutoHintable for ZKVMChipProofInput {} + +impl Hintable for ZKVMChipProofInput { + type HintVariable = ZKVMChipProofInputVariable; fn read(builder: &mut Builder) -> Self::HintVariable { let idx = Usize::Var(usize::read(builder)); @@ -339,7 +275,7 @@ impl Hintable for ZKVMOpcodeProofInput { let wits_in_evals = Vec::::read(builder); let fixed_in_evals = Vec::::read(builder); - ZKVMOpcodeProofInputVariable { + ZKVMChipProofInputVariable { idx, idx_felt, num_instances, @@ -400,92 +336,3 @@ impl Hintable for ZKVMOpcodeProofInput { stream } } - -pub struct ZKVMTableProofInput { - pub idx: usize, - pub num_instances: usize, - - // tower evaluation at layer 1 - pub record_r_out_evals_len: usize, - pub record_w_out_evals_len: usize, - pub record_lk_out_evals_len: usize, - pub record_r_out_evals: Vec>, - pub record_w_out_evals: Vec>, - pub record_lk_out_evals: Vec>, - - pub tower_proof: TowerProofInput, - - pub fixed_in_evals: Vec, - pub wits_in_evals: Vec, -} -impl VecAutoHintable for ZKVMTableProofInput {} -impl Hintable for ZKVMTableProofInput { - type HintVariable = ZKVMTableProofInputVariable; - - fn read(builder: &mut Builder) -> Self::HintVariable { - let idx = Usize::Var(usize::read(builder)); - let idx_felt = F::read(builder); - - let num_instances = Usize::Var(usize::read(builder)); - let log2_num_instances = Usize::Var(usize::read(builder)); - - let record_r_out_evals_len = Usize::Var(usize::read(builder)); - let record_w_out_evals_len = Usize::Var(usize::read(builder)); - let record_lk_out_evals_len = Usize::Var(usize::read(builder)); - - let record_r_out_evals = Vec::>::read(builder); - let record_w_out_evals = Vec::>::read(builder); - let record_lk_out_evals = Vec::>::read(builder); - - let tower_proof = TowerProofInput::read(builder); - let fixed_in_evals = Vec::::read(builder); - let wits_in_evals = Vec::::read(builder); - - ZKVMTableProofInputVariable { - idx, - idx_felt, - num_instances, - log2_num_instances, - record_r_out_evals_len, - record_w_out_evals_len, - record_lk_out_evals_len, - record_r_out_evals, - record_w_out_evals, - record_lk_out_evals, - tower_proof, - fixed_in_evals, - wits_in_evals, - } - } - - fn write(&self) -> Vec::N>> { - let mut stream = Vec::new(); - stream.extend(>::write(&self.idx)); - - let idx_u32: F = F::from_canonical_u32(self.idx as u32); - stream.extend(idx_u32.write()); - - stream.extend(>::write(&self.num_instances)); - let log2_num_instances = ceil_log2(self.num_instances); - stream.extend(>::write(&log2_num_instances)); - - stream.extend(>::write( - &self.record_r_out_evals_len, - )); - stream.extend(>::write( - &self.record_w_out_evals_len, - )); - stream.extend(>::write( - &self.record_lk_out_evals_len, - )); - - stream.extend(self.record_r_out_evals.write()); - stream.extend(self.record_w_out_evals.write()); - stream.extend(self.record_lk_out_evals.write()); - - stream.extend(self.tower_proof.write()); - stream.extend(self.fixed_in_evals.write()); - stream.extend(self.wits_in_evals.write()); - stream - } -} diff --git a/src/zkvm_verifier/verifier.rs b/src/zkvm_verifier/verifier.rs index 699b64a..2fabd37 100644 --- a/src/zkvm_verifier/verifier.rs +++ b/src/zkvm_verifier/verifier.rs @@ -1,11 +1,15 @@ -use super::binding::{ - ZKVMOpcodeProofInputVariable, ZKVMProofInputVariable, ZKVMTableProofInputVariable, -}; +use super::binding::{ZKVMChipProofInputVariable, ZKVMProofInputVariable}; use crate::arithmetics::{ challenger_multi_observe, eval_ceno_expr_with_instance, print_ext_arr, print_felt_arr, PolyEvaluator, UniPolyExtrapolator, }; -use crate::e2e::SubcircuitParams; +use crate::basefold_verifier::basefold::{ + BasefoldCommitmentVariable, RoundOpeningVariable, RoundVariable, +}; +use crate::basefold_verifier::mmcs::MmcsCommitmentVariable; +use crate::basefold_verifier::query_phase::PointAndEvalsVariable; +use crate::basefold_verifier::utils::pow_2; +use crate::basefold_verifier::verifier::batch_verify; use crate::tower_verifier::program::verify_tower_proof; use crate::transcript::transcript_observe_label; use crate::{ @@ -17,20 +21,25 @@ use crate::{ tower_verifier::{binding::PointVariable, program::iop_verifier_state_verify}, }; use ceno_mle::expression::{Instance, StructuralWitIn}; +use ceno_zkvm::e2e::B; +use ceno_zkvm::structs::VerifyingKey; use ceno_zkvm::{circuit_builder::SetTableSpec, scheme::verifier::ZKVMVerifier}; use ff_ext::BabyBearExt4; -use itertools::interleave; use itertools::max; +use itertools::{interleave, Itertools}; use mpcs::{Basefold, BasefoldRSParams}; use openvm_native_compiler::prelude::*; use openvm_native_compiler_derive::iter_zip; use openvm_native_recursion::challenger::{ duplex::DuplexChallengerVariable, CanObserveVariable, FeltChallenger, }; +use p3_baby_bear::BabyBear; use p3_field::{Field, FieldAlgebra}; +type F = BabyBear; type E = BabyBearExt4; type Pcs = Basefold; + const NUM_FANIN: usize = 2; const MAINCONSTRAIN_SUMCHECK_BATCH_SIZE: usize = 3; // read/write/lookup const SEL_DEGREE: usize = 2; @@ -68,11 +77,10 @@ pub fn transcript_group_sample_ext( e } -pub fn verify_zkvm_proof( +pub fn verify_zkvm_proof>( builder: &mut Builder, zkvm_proof_input: ZKVMProofInputVariable, - ceno_constraint_system: &ZKVMVerifier, - proving_sequence: Vec, + vk: &ZKVMVerifier, ) { let mut challenger = DuplexChallengerVariable::new(builder); transcript_observe_label(builder, &mut challenger, b"riscv"); @@ -100,44 +108,63 @@ pub fn verify_zkvm_proof( }, ); - challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.fixed_commit); - iter_zip!(builder, zkvm_proof_input.fixed_commit_trivial_commits).for_each( - |ptr_vec, builder| { - let trivial_cmt = - builder.iter_ptr_get(&zkvm_proof_input.fixed_commit_trivial_commits, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &trivial_cmt); - }, - ); - challenger.observe( - builder, - zkvm_proof_input.fixed_commit_log2_max_codeword_size, - ); + let fixed_commit = if let Some(fixed_commit) = vk.vk.fixed_commit.as_ref() { + let commit: crate::basefold_verifier::hash::Hash = fixed_commit.commit().into(); + let commit_array: Array> = builder.dyn_array(commit.value.len()); + commit.value.into_iter().enumerate().for_each(|(i, v)| { + let v = builder.constant(v); + // TODO: put fixed commit to public values + // builder.commit_public_value(v); + + builder.set_value(&commit_array, i, v); + }); + challenger_multi_observe(builder, &mut challenger, &commit_array); + + // FIXME: do not hardcode this in the program + let log2_max_codeword_size_felt = builder.constant(C::F::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); + let log2_max_codeword_size: Var = builder.constant(C::N::from_canonical_usize( + fixed_commit.log2_max_codeword_size, + )); + + challenger.observe(builder, log2_max_codeword_size_felt); + + Some(BasefoldCommitmentVariable { + commit: MmcsCommitmentVariable { + value: commit_array, + }, + log2_max_codeword_size: log2_max_codeword_size.into(), + }) + } else { + None + }; let zero_f: Felt = builder.constant(C::F::ZERO); - iter_zip!(builder, zkvm_proof_input.num_instances).for_each(|ptr_vec, builder| { - let ns = builder.iter_ptr_get(&zkvm_proof_input.num_instances, ptr_vec[0]); - let circuit_size = builder.get(&ns, 0); - let num_var = builder.get(&ns, 1); + iter_zip!(builder, zkvm_proof_input.chip_proofs).for_each(|ptr_vec, builder| { + let chip_proof = builder.iter_ptr_get(&zkvm_proof_input.chip_proofs, ptr_vec[0]); + let num_instances = builder.unsafe_cast_var_to_felt(chip_proof.num_instances.get_var()); - challenger.observe(builder, circuit_size); + challenger.observe(builder, chip_proof.idx_felt); challenger.observe(builder, zero_f); - challenger.observe(builder, num_var); + challenger.observe(builder, num_instances); challenger.observe(builder, zero_f); }); - challenger_multi_observe(builder, &mut challenger, &zkvm_proof_input.witin_commit); - - iter_zip!(builder, zkvm_proof_input.witin_commit_trivial_commits).for_each( - |ptr_vec, builder| { - let trivial_cmt = - builder.iter_ptr_get(&zkvm_proof_input.witin_commit_trivial_commits, ptr_vec[0]); - challenger_multi_observe(builder, &mut challenger, &trivial_cmt); - }, - ); - challenger.observe( + challenger_multi_observe( builder, - zkvm_proof_input.witin_commit_log2_max_codeword_size, + &mut challenger, + &zkvm_proof_input.witin_commit.commit.value, ); + { + let log2_max_codeword_size = builder.unsafe_cast_var_to_felt( + zkvm_proof_input + .witin_commit + .log2_max_codeword_size + .get_var(), + ); + challenger.observe(builder, log2_max_codeword_size); + } let alpha = challenger.sample_ext(builder); let beta = challenger.sample_ext(builder); @@ -150,135 +177,178 @@ pub fn verify_zkvm_proof( let mut poly_evaluator = PolyEvaluator::new(builder); let dummy_table_item = alpha.clone(); - let dummy_table_item_multiplicity: Ext = builder.constant(C::EF::ZERO); - - let mut rt_points: Vec>> = Vec::with_capacity(proving_sequence.len()); - let mut evaluations: Vec>> = - Vec::with_capacity(2 * proving_sequence.len()); // witin + fixed thus *2 - - for subcircuit_params in proving_sequence { - if subcircuit_params.is_opcode { - let opcode_proof = builder.get( - &zkvm_proof_input.opcode_proofs, - subcircuit_params.type_order_idx, - ); - let id_f: Felt = - builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); - challenger.observe(builder, id_f); - - builder.cycle_tracker_start("Verify opcode proof"); - let input_opening_point = verify_opcode_proof( - builder, - &mut challenger, - &opcode_proof, - &zkvm_proof_input.pi_evals, - &challenges, - &subcircuit_params, - &ceno_constraint_system, - &mut unipoly_extrapolator, - ); - builder.cycle_tracker_end("Verify opcode proof"); + let dummy_table_item_multiplicity: Var = builder.constant(C::N::ZERO); + + let num_fixed_opening = vk + .vk + .circuit_vks + .values() + .filter(|c| c.get_cs().num_fixed() > 0) + .count(); + let witin_openings: Array> = + builder.dyn_array(zkvm_proof_input.chip_proofs.len()); + let fixed_openings: Array> = + builder.dyn_array(Usize::from(num_fixed_opening)); + let num_chips_verified: Usize = builder.eval(C::N::ZERO); + let num_chips_have_fixed: Usize = builder.eval(C::N::ZERO); + + let chip_indices: Array> = builder.dyn_array(zkvm_proof_input.chip_proofs.len()); + builder + .range(0, chip_indices.len()) + .for_each(|idx_vec, builder| { + let i = idx_vec[0]; + let chip_proof = builder.get(&zkvm_proof_input.chip_proofs, i); + builder.set(&chip_indices, i, chip_proof.idx); + }); - rt_points.push(input_opening_point); - evaluations.push(opcode_proof.wits_in_evals); + // iterate over all chips + for (i, chip_vk) in vk.vk.circuit_vks.values().enumerate() { + let chip_id: Var = builder.get(&chip_indices, num_chips_verified.get_var()); + builder.if_eq(chip_id, RVar::from(i)).then(|builder| { + let chip_proof = + builder.get(&zkvm_proof_input.chip_proofs, num_chips_verified.get_var()); + challenger.observe(builder, chip_proof.idx_felt); + + builder.cycle_tracker_start("Verify chip proof"); + let input_opening_point = if chip_vk.get_cs().is_opcode_circuit() { + verify_opcode_proof( + builder, + &mut challenger, + &chip_proof, + &zkvm_proof_input.pi_evals, + &challenges, + &chip_vk, + &mut unipoly_extrapolator, + ) + } else { + verify_table_proof( + builder, + &mut challenger, + &chip_proof, + &zkvm_proof_input.pi_evals, + &challenges, + &chip_vk, + &mut unipoly_extrapolator, + ) + }; + builder.cycle_tracker_end("Verify chip proof"); // getting the number of dummy padding item that we used in this opcode circuit - let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); - let num_instances = subcircuit_params.num_instances; - let num_lks = cs.lk_expressions.len(); - let num_padded_instance = next_pow2_instance_padding(num_instances) - num_instances; - - let new_multiplicity: Ext = - builder.constant(C::EF::from_canonical_usize(num_lks * num_padded_instance)); - builder.assign( - &dummy_table_item_multiplicity, - dummy_table_item_multiplicity + new_multiplicity, - ); + if chip_vk.get_cs().is_opcode_circuit() { + let num_lks = chip_vk.get_cs().num_lks(); + // FIXME: use builder to compute this + let num_instances = pow_2(builder, chip_proof.log2_num_instances.get_var()); + let num_padded_instance: Var = + builder.eval(num_instances - chip_proof.num_instances); + + let new_multiplicity: Usize = + builder.eval(Usize::from(num_lks) * Usize::from(num_padded_instance)); + builder.assign( + &dummy_table_item_multiplicity, + dummy_table_item_multiplicity + new_multiplicity, + ); + } - let record_r_out_evals_prod = nested_product(builder, &opcode_proof.record_r_out_evals); + let record_r_out_evals_prod = nested_product(builder, &chip_proof.record_r_out_evals); builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - let record_w_out_evals_prod = nested_product(builder, &opcode_proof.record_w_out_evals); + let record_w_out_evals_prod = nested_product(builder, &chip_proof.record_w_out_evals); builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - iter_zip!(builder, opcode_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { - let evals = builder.iter_ptr_get(&opcode_proof.record_lk_out_evals, ptr_vec[0]); + let sign: Ext = if chip_vk.get_cs().is_opcode_circuit() { + builder.constant(C::EF::ONE) + } else { + builder.constant(-C::EF::ONE) + }; + + iter_zip!(builder, chip_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { + let evals = builder.iter_ptr_get(&chip_proof.record_lk_out_evals, ptr_vec[0]); let p1 = builder.get(&evals, 0); let p2 = builder.get(&evals, 1); let q1 = builder.get(&evals, 2); let q2 = builder.get(&evals, 3); - builder.assign(&logup_sum, logup_sum + p1 * q1.inverse()); - builder.assign(&logup_sum, logup_sum + p2 * q2.inverse()); + builder.assign(&logup_sum, logup_sum + sign * p1 * q1.inverse()); + builder.assign(&logup_sum, logup_sum + sign * p2 * q2.inverse()); }); - } else { - let table_proof = builder.get( - &zkvm_proof_input.table_proofs, - subcircuit_params.type_order_idx, - ); - let id_f: Felt = - builder.constant(C::F::from_canonical_usize(subcircuit_params.id)); - challenger.observe(builder, id_f); - let input_opening_point = verify_table_proof( - builder, - &mut challenger, - &table_proof, - &zkvm_proof_input.raw_pi, - &zkvm_proof_input.raw_pi_num_variables, - &zkvm_proof_input.pi_evals, - &challenges, - &subcircuit_params, - ceno_constraint_system, - &mut unipoly_extrapolator, - &mut poly_evaluator, + builder.assert_usize_eq( + chip_proof.log2_num_instances.clone(), + input_opening_point.len(), ); - rt_points.push(input_opening_point); - evaluations.push(table_proof.wits_in_evals); - let cs = ceno_constraint_system.vk.circuit_vks[&subcircuit_params.name].get_cs(); - if cs.num_fixed > 0 { - evaluations.push(table_proof.fixed_in_evals); - } - - iter_zip!(builder, table_proof.record_lk_out_evals).for_each(|ptr_vec, builder| { - let evals = builder.iter_ptr_get(&table_proof.record_lk_out_evals, ptr_vec[0]); - let p1 = builder.get(&evals, 0); - let p2 = builder.get(&evals, 1); - let q1 = builder.get(&evals, 2); - let q2 = builder.get(&evals, 3); - builder.assign( - &logup_sum, - logup_sum - p1 * q1.inverse() - p2 * q2.inverse(), - ); + let witin_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: chip_proof.log2_num_instances.get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { + fs: input_opening_point.clone(), + }, + evals: chip_proof.wits_in_evals, + }, }); + builder.set_value(&witin_openings, num_chips_verified.get_var(), witin_round); + + if chip_vk.get_cs().num_fixed() > 0 { + let fixed_round: RoundOpeningVariable = builder.eval(RoundOpeningVariable { + num_var: chip_proof.log2_num_instances.get_var(), + point_and_evals: PointAndEvalsVariable { + point: PointVariable { + fs: input_opening_point.clone(), + }, + evals: chip_proof.fixed_in_evals, + }, + }); - let record_w_out_evals_prod = nested_product(builder, &table_proof.record_w_out_evals); - builder.assign(&prod_w, prod_w * record_w_out_evals_prod); - let record_r_out_evals_prod = nested_product(builder, &table_proof.record_r_out_evals); - builder.assign(&prod_r, prod_r * record_r_out_evals_prod); - } + builder.set_value(&fixed_openings, num_chips_have_fixed.get_var(), fixed_round); + + builder.inc(&num_chips_have_fixed); + } + + builder.inc(&num_chips_verified); + }); } + builder.assert_usize_eq(num_chips_have_fixed, Usize::from(num_fixed_opening)); + builder.assert_eq::>(num_chips_verified, chip_indices.len()); + let dummy_table_item_multiplicity = + builder.unsafe_cast_var_to_felt(dummy_table_item_multiplicity); builder.assign( &logup_sum, logup_sum - dummy_table_item_multiplicity * dummy_table_item.inverse(), ); - /* TODO: MPCS - PCS::batch_verify( - &self.vk.vp, - &vm_proof.num_instances, - &rt_points, - self.vk.fixed_commit.as_ref(), - &vm_proof.witin_commit, - &evaluations, - &vm_proof.fixed_witin_opening_proof, - &self.vk.circuit_num_polys, - &mut transcript, - ) - .map_err(ZKVMError::PCSError)?; - */ + let rounds = if num_fixed_opening > 0 { + builder.dyn_array(2) + } else { + builder.dyn_array(1) + }; + builder.set( + &rounds, + 0, + RoundVariable { + commit: zkvm_proof_input.witin_commit, + openings: witin_openings, + perm: zkvm_proof_input.witin_perm.clone(), + }, + ); + if num_fixed_opening > 0 { + builder.set( + &rounds, + 1, + RoundVariable { + commit: fixed_commit.unwrap(), + openings: fixed_openings, + perm: zkvm_proof_input.fixed_perm, + }, + ); + } + batch_verify( + builder, + zkvm_proof_input.max_num_var, + rounds, + zkvm_proof_input.pcs_proof, + &mut challenger, + ); let empty_arr: Array> = builder.dyn_array(0); let initial_global_state = eval_ceno_expr_with_instance( @@ -288,7 +358,7 @@ pub fn verify_zkvm_proof( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &ceno_constraint_system.vk.initial_global_state_expr, + &vk.vk.initial_global_state_expr, ); builder.assign(&prod_w, prod_w * initial_global_state); @@ -299,7 +369,7 @@ pub fn verify_zkvm_proof( &empty_arr, &zkvm_proof_input.pi_evals, &challenges, - &ceno_constraint_system.vk.finalize_global_state_expr, + &vk.vk.finalize_global_state_expr, ); builder.assign(&prod_r, prod_r * finalize_global_state); @@ -311,20 +381,19 @@ pub fn verify_zkvm_proof( pub fn verify_opcode_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - opcode_proof: &ZKVMOpcodeProofInputVariable, + opcode_proof: &ZKVMChipProofInputVariable, pi_evals: &Array>, challenges: &Array>, - subcircuit_params: &SubcircuitParams, - cs: &ZKVMVerifier, + vk: &VerifyingKey, unipoly_extrapolator: &mut UniPolyExtrapolator, ) -> Array> { - let cs = &cs.vk.circuit_vks[&subcircuit_params.name].cs; + let cs = vk.get_cs(); let one: Ext = builder.constant(C::EF::ONE); let zero: Ext = builder.constant(C::EF::ZERO); - let r_len = cs.r_expressions.len(); - let w_len = cs.w_expressions.len(); - let lk_len = cs.lk_expressions.len(); + let r_len = cs.zkvm_v1_css.r_expressions.len(); + let w_len = cs.zkvm_v1_css.w_expressions.len(); + let lk_len = cs.zkvm_v1_css.lk_expressions.len(); let num_batched = r_len + w_len + lk_len; let chip_record_alpha: Ext = builder.get(challenges, 0); @@ -383,7 +452,7 @@ pub fn verify_opcode_proof( let alpha_len = builder.eval( num_rw_records.clone() + lk_counts_per_instance - + Usize::from(cs.assert_zero_sumcheck_expressions.len()), + + Usize::from(cs.zkvm_v1_css.assert_zero_sumcheck_expressions.len()), ); transcript_observe_label(builder, challenger, b"combine subset evals"); let alpha_pow = gen_alpha_pows(builder, challenger, alpha_len); @@ -411,7 +480,7 @@ pub fn verify_opcode_proof( let log2_num_instances_var: Var = RVar::from(log2_num_instances.clone()).variable(); let log2_num_instances_f: Felt = builder.unsafe_cast_var_to_felt(log2_num_instances_var); - let max_non_lc_degree: usize = cs.max_non_lc_degree; + let max_non_lc_degree: usize = cs.zkvm_v1_css.max_non_lc_degree; let main_sel_subclaim_max_degree: Felt = builder.constant(C::F::from_canonical_u32( SEL_DEGREE.max(max_non_lc_degree + 1) as u32, )); @@ -442,9 +511,10 @@ pub fn verify_opcode_proof( let empty_arr: Array> = builder.dyn_array(0); let rw_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.r_expressions + cs.zkvm_v1_css + .r_expressions .iter() - .chain(cs.w_expressions.iter()) + .chain(cs.zkvm_v1_css.w_expressions.iter()) .for_each(|expr| { let e = eval_ceno_expr_with_instance( builder, @@ -462,7 +532,7 @@ pub fn verify_opcode_proof( builder.assign(&rw_expressions_sum, rw_expressions_sum * sel); let lk_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.lk_expressions.iter().for_each(|expr| { + cs.zkvm_v1_css.lk_expressions.iter().for_each(|expr| { let e = eval_ceno_expr_with_instance( builder, &empty_arr, @@ -482,21 +552,24 @@ pub fn verify_opcode_proof( builder.assign(&lk_expressions_sum, lk_expressions_sum * sel); let zero_expressions_sum: Ext = builder.constant(C::EF::ZERO); - cs.assert_zero_sumcheck_expressions.iter().for_each(|expr| { - // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening - let e = eval_ceno_expr_with_instance( - builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, - ); - let alpha = builder.get(&alpha_pow, alpha_idx); - builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); - builder.assign(&zero_expressions_sum, zero_expressions_sum + alpha * e); - }); + cs.zkvm_v1_css + .assert_zero_sumcheck_expressions + .iter() + .for_each(|expr| { + // evaluate zero expression by all wits_in_evals because they share the unique input_opening_point opening + let e = eval_ceno_expr_with_instance( + builder, + &empty_arr, + &opcode_proof.wits_in_evals, + &empty_arr, + pi_evals, + challenges, + expr, + ); + let alpha = builder.get(&alpha_pow, alpha_idx); + builder.assign(&alpha_idx, alpha_idx + Usize::from(1)); + builder.assign(&zero_expressions_sum, zero_expressions_sum + alpha * e); + }); builder.assign(&zero_expressions_sum, zero_expressions_sum * sel); let computed_eval: Ext = @@ -504,18 +577,21 @@ pub fn verify_opcode_proof( builder.assert_ext_eq(computed_eval, expected_evaluation); // verify zero expression (degree = 1) statement, thus no sumcheck - cs.assert_zero_expressions.iter().for_each(|expr| { - let e = eval_ceno_expr_with_instance( - builder, - &empty_arr, - &opcode_proof.wits_in_evals, - &empty_arr, - pi_evals, - challenges, - expr, - ); - builder.assert_ext_eq(e, zero); - }); + cs.zkvm_v1_css + .assert_zero_expressions + .iter() + .for_each(|expr| { + let e = eval_ceno_expr_with_instance( + builder, + &empty_arr, + &opcode_proof.wits_in_evals, + &empty_arr, + pi_evals, + challenges, + expr, + ); + builder.assert_ext_eq(e, zero); + }); input_opening_point } @@ -523,23 +599,23 @@ pub fn verify_opcode_proof( pub fn verify_table_proof( builder: &mut Builder, challenger: &mut DuplexChallengerVariable, - table_proof: &ZKVMTableProofInputVariable, - raw_pi: &Array>>, - raw_pi_num_variables: &Array>, + table_proof: &ZKVMChipProofInputVariable, + // raw_pi: &Array>>, + // raw_pi_num_variables: &Array>, pi_evals: &Array>, challenges: &Array>, - subcircuit_params: &SubcircuitParams, - cs: &ZKVMVerifier, + vk: &VerifyingKey, unipoly_extrapolator: &mut UniPolyExtrapolator, - poly_evaluator: &mut PolyEvaluator, + // poly_evaluator: &mut PolyEvaluator, ) -> Array> { - let cs = cs.vk.circuit_vks[&subcircuit_params.name].get_cs(); + let cs = vk.get_cs(); let tower_proof: &super::binding::TowerProofInputVariable = &table_proof.tower_proof; let r_expected_rounds: Array> = - builder.dyn_array(cs.r_table_expressions.len() * 2); + builder.dyn_array(cs.zkvm_v1_css.r_table_expressions.len() * 2); cs // only iterate r set, as read/write set round should match + .zkvm_v1_css .r_table_expressions .iter() .enumerate() @@ -563,8 +639,9 @@ pub fn verify_table_proof( }); let lk_expected_rounds: Array> = - builder.dyn_array(cs.lk_table_expressions.len()); - cs.lk_table_expressions + builder.dyn_array(cs.zkvm_v1_css.lk_table_expressions.len()); + cs.zkvm_v1_css + .lk_table_expressions .iter() .enumerate() .for_each(|(idx, expr)| { @@ -611,15 +688,17 @@ pub fn verify_table_proof( builder.assert_usize_eq( logup_q_point_and_eval.len(), - Usize::from(cs.lk_table_expressions.len()), + Usize::from(cs.zkvm_v1_css.lk_table_expressions.len()), ); builder.assert_usize_eq( logup_p_point_and_eval.len(), - Usize::from(cs.lk_table_expressions.len()), + Usize::from(cs.zkvm_v1_css.lk_table_expressions.len()), ); builder.assert_usize_eq( prod_point_and_eval.len(), - Usize::from(cs.r_table_expressions.len() + cs.w_table_expressions.len()), + Usize::from( + cs.zkvm_v1_css.r_table_expressions.len() + cs.zkvm_v1_css.w_table_expressions.len(), + ), ); // in table proof, we always skip same point sumcheck for now @@ -628,10 +707,16 @@ pub fn verify_table_proof( // evaluate structural witness from verifier let set_table_exprs = cs + .zkvm_v1_css .r_table_expressions .iter() .map(|r| &r.table_spec) - .chain(cs.lk_table_expressions.iter().map(|r| &r.table_spec)) + .chain( + cs.zkvm_v1_css + .lk_table_expressions + .iter() + .map(|r| &r.table_spec), + ) .collect::>(); let structural_witnesses_vec: Vec> = set_table_exprs .iter() @@ -709,12 +794,13 @@ pub fn verify_table_proof( // verify records (degree = 1) statement, thus no sumcheck interleave( - &cs.r_table_expressions, // r - &cs.w_table_expressions, // w + &cs.zkvm_v1_css.r_table_expressions, // r + &cs.zkvm_v1_css.w_table_expressions, // w ) .map(|rw| &rw.expr) .chain( - cs.lk_table_expressions + cs.zkvm_v1_css + .lk_table_expressions .iter() .flat_map(|lk| vec![&lk.multiplicity, &lk.values]), // p, q ) @@ -734,8 +820,9 @@ pub fn verify_table_proof( builder.assert_ext_eq(e, expected_evals); }); + /* TODO: enable this // assume public io is tiny vector, so we evaluate it directly without PCS - for &Instance(idx) in cs.instance_name_map.keys() { + for &Instance(idx) in cs.instance_name_map().keys() { let poly = builder.get(raw_pi, idx); let poly_num_vars = builder.get(raw_pi_num_variables, idx); let eval_point = rt_tower.fs.slice(builder, 0, poly_num_vars); @@ -743,6 +830,7 @@ pub fn verify_table_proof( let eval = builder.get(&pi_evals, idx); builder.assert_ext_eq(eval, expected_eval); } + */ rt_tower.fs }