diff --git a/.github/workflows/benchmarks-execute.yml b/.github/workflows/benchmarks-execute.yml index 741ccdb0f1..fdfdeb9047 100644 --- a/.github/workflows/benchmarks-execute.yml +++ b/.github/workflows/benchmarks-execute.yml @@ -1,8 +1,9 @@ -name: "benchmarks-execute" +name: "Execution benchmarks" on: push: - branches: ["main"] + # TODO(ayush): remove after feat/new-execution is merged + branches: ["main", "feat/new-execution"] pull_request: types: [opened, synchronize, reopened, labeled] branches: ["**"] @@ -18,6 +19,10 @@ on: - ".github/workflows/benchmarks-execute.yml" workflow_dispatch: +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} + cancel-in-progress: true + env: CARGO_TERM_COLOR: always @@ -26,15 +31,14 @@ jobs: runs-on: - runs-on=${{ github.run_id }} - runner=8cpu-linux-x64 + - extras=s3-cache steps: + - uses: runs-on/action@v1 - uses: actions/checkout@v4 - - - name: Set up Rust - uses: actions-rs/toolchain@v1 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 with: - profile: minimal - toolchain: stable - override: true + cache-on-failure: true - name: Run execution benchmarks working-directory: benchmarks/execute @@ -110,3 +114,41 @@ jobs: echo -e "\nBenchmark Summary:" cat "$SUMMARY_FILE" fi + + codspeed-benchmarks: + name: Run codspeed ${{ matrix.mode }} benchmarks + runs-on: + - runs-on=${{ github.run_id }} + - runner=8cpu-linux-x64 + - extras=s3-cache + + strategy: + matrix: + mode: [instrumentation, walltime] + env: + CODSPEED_RUNNER_MODE: ${{ matrix.mode }} + + steps: + - uses: runs-on/action@v1 + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + with: + cache-on-failure: true + + - name: Install cargo-binstall + uses: cargo-bins/cargo-binstall@main + - name: Install codspeed + run: cargo binstall --no-confirm --force cargo-codspeed + + - name: Build benchmarks + working-directory: benchmarks/execute + run: cargo codspeed build + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + working-directory: benchmarks/execute + run: cargo codspeed run + token: ${{ secrets.CODSPEED_TOKEN }} + env: + CODSPEED_RUNNER_MODE: ${{ matrix.mode }} diff --git a/.gitignore b/.gitignore index d794a5dc57..15fa79c61f 100644 --- a/.gitignore +++ b/.gitignore @@ -37,3 +37,6 @@ guest.syms # openvm generated files crates/cli/openvm/ + +# samply profile +profile.json.gz diff --git a/Cargo.lock b/Cargo.lock index 0faf467b81..ccce87fa32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -51,9 +51,9 @@ dependencies = [ [[package]] name = "ahash" -version = "0.8.11" +version = "0.8.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" dependencies = [ "cfg-if", "once_cell", @@ -76,28 +76,61 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "alloy-eip2124" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "741bdd7499908b3aa0b159bba11e71c8cddd009a2c2eb7a06e825f1ec87900a5" +dependencies = [ + "alloy-primitives 1.1.0", + "alloy-rlp", + "crc", + "serde", + "thiserror 2.0.12", +] + [[package]] name = "alloy-eip2930" -version = "0.1.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0069cf0642457f87a01a014f6dc29d5d893cd4fd8fddf0c3cdfad1bb3ebafc41" +checksum = "7b82752a889170df67bbb36d42ca63c531eb16274f0d7299ae2a680facba17bd" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.1.0", "alloy-rlp", "serde", ] [[package]] name = "alloy-eip7702" -version = "0.4.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c986539255fb839d1533c128e190e557e52ff652c9ef62939e233a81dd93f7e" +checksum = "804cefe429015b4244966c006d25bda5545fa9db5990e9c9079faf255052f50a" dependencies = [ - "alloy-primitives 0.8.25", + "alloy-primitives 1.1.0", "alloy-rlp", - "derive_more 1.0.0", "k256", "serde", + "thiserror 2.0.12", +] + +[[package]] +name = "alloy-eips" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "609515c1955b33af3d78d26357540f68c5551a90ef58fd53def04f2aa074ec43" +dependencies = [ + "alloy-eip2124", + "alloy-eip2930", + "alloy-eip7702", + "alloy-primitives 1.1.0", + "alloy-rlp", + "alloy-serde", + "auto_impl", + "c-kzg", + "derive_more 2.0.1", + "either", + "serde", + "sha2 0.10.9", ] [[package]] @@ -121,7 +154,7 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal", "itoa", "ruint", @@ -138,7 +171,7 @@ dependencies = [ "bytes", "cfg-if", "const-hex", - "derive_more 0.99.19", + "derive_more 0.99.20", "hex-literal", "itoa", "ruint", @@ -157,14 +190,41 @@ dependencies = [ "const-hex", "derive_more 2.0.1", "foldhash", - "hashbrown 0.15.2", - "indexmap 2.7.1", + "hashbrown 0.15.3", + "indexmap 2.9.0", "itoa", "k256", "keccak-asm", "paste", "proptest", - "rand", + "rand 0.8.5", + "ruint", + "rustc-hash 2.1.1", + "serde", + "sha3", + "tiny-keccak", +] + +[[package]] +name = "alloy-primitives" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a12fe11d0b8118e551c29e1a67ccb6d01cc07ef08086df30f07487146de6fa1" +dependencies = [ + "alloy-rlp", + "bytes", + "cfg-if", + "const-hex", + "derive_more 2.0.1", + "foldhash", + "hashbrown 0.15.3", + "indexmap 2.9.0", + "itoa", + "k256", + "keccak-asm", + "paste", + "proptest", + "rand 0.9.1", "ruint", "rustc-hash 2.1.1", "serde", @@ -191,7 +251,18 @@ checksum = "a40e1ef334153322fd878d07e86af7a529bcb86b2439525920a88eba87bcf943" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "alloy-serde" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4dba6ff08916bc0a9cbba121ce21f67c0b554c39cf174bc7b9df6c651bd3c3b" +dependencies = [ + "alloy-primitives 1.1.0", + "serde", + "serde_json", ] [[package]] @@ -205,7 +276,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -218,11 +289,11 @@ dependencies = [ "alloy-sol-macro-input", "const-hex", "heck", - "indexmap 2.7.1", + "indexmap 2.9.0", "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "syn-solidity", "tiny-keccak", ] @@ -241,7 +312,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.98", + "syn 2.0.101", "syn-solidity", ] @@ -252,7 +323,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d162f8524adfdfb0e4bd0505c734c985f3e2474eb022af32eef0d52a4f3935c" dependencies = [ "serde", - "winnow 0.7.3", + "winnow 0.7.10", ] [[package]] @@ -350,9 +421,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.96" +version = "1.0.98" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b964d184e89d9b6b67dd2715bc8e74cf3107fb2b529990c90cf517326150bf4" +checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487" [[package]] name = "ariadne" @@ -364,6 +435,51 @@ dependencies = [ "yansi 0.5.1", ] +[[package]] +name = "ark-bls12-381" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3df4dcc01ff89867cd86b0da835f23c3f02738353aaee7dde7495af71363b8d5" +dependencies = [ + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", +] + +[[package]] +name = "ark-bn254" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d69eab57e8d2663efa5c63135b2af4f396d66424f88954c21104125ab6b3e6bc" +dependencies = [ + "ark-ec", + "ark-ff 0.5.0", + "ark-r1cs-std", + "ark-std 0.5.0", +] + +[[package]] +name = "ark-ec" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d68f2d516162846c1238e755a7c4d131b892b70cc70c471a8e3ca3ed818fce" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-poly", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.3", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "zeroize", +] + [[package]] name = "ark-ff" version = "0.3.0" @@ -402,6 +518,26 @@ dependencies = [ "zeroize", ] +[[package]] +name = "ark-ff" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a177aba0ed1e0fbb62aa9f6d0502e9b46dad8c2eab04c14258a1212d2557ea70" +dependencies = [ + "ark-ff-asm 0.5.0", + "ark-ff-macros 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "educe", + "itertools 0.13.0", + "num-bigint 0.4.6", + "num-traits", + "paste", + "zeroize", +] + [[package]] name = "ark-ff-asm" version = "0.3.0" @@ -422,6 +558,16 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-asm" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" +dependencies = [ + "quote", + "syn 2.0.101", +] + [[package]] name = "ark-ff-macros" version = "0.3.0" @@ -447,6 +593,63 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "ark-ff-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09be120733ee33f7693ceaa202ca41accd5653b779563608f1234f78ae07c4b3" +dependencies = [ + "num-bigint 0.4.6", + "num-traits", + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "ark-poly" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "579305839da207f02b89cd1679e50e67b4331e2f9294a57693e5051b7703fe27" +dependencies = [ + "ahash", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", + "ark-std 0.5.0", + "educe", + "fnv", + "hashbrown 0.15.3", +] + +[[package]] +name = "ark-r1cs-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "941551ef1df4c7a401de7068758db6503598e6f01850bdb2cfdb614a1f9dbea1" +dependencies = [ + "ark-ec", + "ark-ff 0.5.0", + "ark-relations", + "ark-std 0.5.0", + "educe", + "num-bigint 0.4.6", + "num-integer", + "num-traits", + "tracing", +] + +[[package]] +name = "ark-relations" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec46ddc93e7af44bcab5230937635b06fb5744464dd6a7e7b083e80ebd274384" +dependencies = [ + "ark-ff 0.5.0", + "ark-std 0.5.0", + "tracing", + "tracing-subscriber 0.2.25", +] + [[package]] name = "ark-serialize" version = "0.3.0" @@ -468,6 +671,30 @@ dependencies = [ "num-bigint 0.4.6", ] +[[package]] +name = "ark-serialize" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f4d068aaf107ebcd7dfb52bc748f8030e0fc930ac8e360146ca54c1203088f7" +dependencies = [ + "ark-serialize-derive", + "ark-std 0.5.0", + "arrayvec", + "digest 0.10.7", + "num-bigint 0.4.6", +] + +[[package]] +name = "ark-serialize-derive" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "ark-std" version = "0.3.0" @@ -475,7 +702,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1df2c09229cbc5a028b1d70e00fdb2acee28b1055dfb5ca73eea49c5a25c4e7c" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -485,7 +712,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94893f1e0c6eeab764ade8dc4c0db24caf4fe7cbbaafc0eba0a9030f447b5185" dependencies = [ "num-traits", - "rand", + "rand 0.8.5", +] + +[[package]] +name = "ark-std" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "246a225cc6131e9ee4f24619af0f19d67761fff15d7ccc22e42b80846e69449a" +dependencies = [ + "num-traits", + "rand 0.8.5", ] [[package]] @@ -511,13 +748,13 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.86" +version = "0.1.88" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d" +checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -529,6 +766,12 @@ dependencies = [ "bytemuck", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "aurora-engine-modexp" version = "1.2.0" @@ -541,13 +784,13 @@ dependencies = [ [[package]] name = "auto_impl" -version = "1.2.1" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e12882f59de5360c748c4cbf569a042d5fb0eb515f7bea9c1f470b47f6ffbd73" +checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -558,9 +801,9 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "aws-config" -version = "1.5.18" +version = "1.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90aff65e86db5fe300752551c1b015ef72b708ac54bded8ef43d0d53cb7cb0b1" +checksum = "b6fcc63c9860579e4cb396239570e979376e70aab79e496621748a09913f8b36" dependencies = [ "aws-credential-types", "aws-runtime", @@ -568,7 +811,7 @@ dependencies = [ "aws-sdk-ssooidc", "aws-sdk-sts", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -577,7 +820,7 @@ dependencies = [ "bytes", "fastrand", "hex", - "http 0.2.12", + "http 1.3.1", "ring", "time", "tokio", @@ -588,9 +831,9 @@ dependencies = [ [[package]] name = "aws-credential-types" -version = "1.2.1" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60e8f6b615cb5fc60a98132268508ad104310f0cfb25a1c22eee76efdf9154da" +checksum = "687bc16bc431a8533fe0097c7f0182874767f920989d7260950172ae8e3c4465" dependencies = [ "aws-smithy-async", "aws-smithy-runtime-api", @@ -598,17 +841,40 @@ dependencies = [ "zeroize", ] +[[package]] +name = "aws-lc-rs" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fcc8f365936c834db5514fc45aee5b1202d677e6b40e48468aaaa8183ca8c7" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.29.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61b1d86e7705efe1be1b569bab41d4fa1e14e220b60a160f78de2db687add079" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "aws-runtime" -version = "1.5.5" +version = "1.5.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76dd04d39cc12844c0994f2c9c5a6f5184c22e9188ec1ff723de41910a21dcad" +checksum = "6c4063282c69991e57faab9e5cb21ae557e59f5b0fb285c196335243df8dc25c" dependencies = [ "aws-credential-types", "aws-sigv4", "aws-smithy-async", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", @@ -617,7 +883,6 @@ dependencies = [ "fastrand", "http 0.2.12", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "tracing", @@ -626,9 +891,9 @@ dependencies = [ [[package]] name = "aws-sdk-s3" -version = "1.78.0" +version = "1.85.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3038614b6cf7dd68d9a7b5b39563d04337eb3678d1d4173e356e927b0356158a" +checksum = "d5c82dae9304e7ced2ff6cca43dceb2d6de534c95a506ff0f168a7463c9a885d" dependencies = [ "aws-credential-types", "aws-runtime", @@ -636,7 +901,7 @@ dependencies = [ "aws-smithy-async", "aws-smithy-checksums", "aws-smithy-eventstream", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", @@ -648,32 +913,34 @@ dependencies = [ "hex", "hmac", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "lru", "once_cell", "percent-encoding", "regex-lite", - "sha2", + "sha2 0.10.9", "tracing", "url", ] [[package]] name = "aws-sdk-sso" -version = "1.61.0" +version = "1.67.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e65ff295979977039a25f5a0bf067a64bc5e6aa38f3cef4037cf42516265553c" +checksum = "0d4863da26489d1e6da91d7e12b10c17e86c14f94c53f416bd10e0a9c34057ba" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -682,20 +949,21 @@ dependencies = [ [[package]] name = "aws-sdk-ssooidc" -version = "1.62.0" +version = "1.68.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91430a60f754f235688387b75ee798ef00cfd09709a582be2b7525ebb5306d4f" +checksum = "95caa3998d7237789b57b95a8e031f60537adab21fa84c91e35bef9455c652e4" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-runtime", "aws-smithy-runtime-api", "aws-smithy-types", "aws-types", "bytes", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -704,14 +972,14 @@ dependencies = [ [[package]] name = "aws-sdk-sts" -version = "1.62.0" +version = "1.68.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9276e139d39fff5a0b0c984fc2d30f970f9a202da67234f948fda02e5bea1dbe" +checksum = "4939f6f449a37308a78c5a910fd91265479bd2bb11d186f0b8fc114d89ec828d" dependencies = [ "aws-credential-types", "aws-runtime", "aws-smithy-async", - "aws-smithy-http 0.61.1", + "aws-smithy-http", "aws-smithy-json", "aws-smithy-query", "aws-smithy-runtime", @@ -719,6 +987,7 @@ dependencies = [ "aws-smithy-types", "aws-smithy-xml", "aws-types", + "fastrand", "http 0.2.12", "once_cell", "regex-lite", @@ -727,13 +996,13 @@ dependencies = [ [[package]] name = "aws-sigv4" -version = "1.2.9" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9bfe75fad52793ce6dec0dc3d4b1f388f038b5eb866c8d4d7f3a8e21b5ea5051" +checksum = "3503af839bd8751d0bdc5a46b9cac93a003a353e635b0c12cf2376b5b53e41ea" dependencies = [ "aws-credential-types", "aws-smithy-eventstream", - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", @@ -742,12 +1011,11 @@ dependencies = [ "hex", "hmac", "http 0.2.12", - "http 1.2.0", - "once_cell", + "http 1.3.1", "p256 0.11.1", "percent-encoding", "ring", - "sha2", + "sha2 0.10.9", "subtle", "time", "tracing", @@ -756,9 +1024,9 @@ dependencies = [ [[package]] name = "aws-smithy-async" -version = "1.2.4" +version = "1.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa59d1327d8b5053c54bf2eaae63bf629ba9e904434d0835a28ed3c0ed0a614e" +checksum = "1e190749ea56f8c42bf15dd76c65e14f8f765233e6df9b0506d9d934ebef867c" dependencies = [ "futures-util", "pin-project-lite", @@ -767,11 +1035,11 @@ dependencies = [ [[package]] name = "aws-smithy-checksums" -version = "0.63.0" +version = "0.63.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2dc8d842d872529355c72632de49ef8c5a2949a4472f10e802f28cf925770c" +checksum = "b65d21e1ba6f2cdec92044f904356a19f5ad86961acf015741106cdfafd747c0" dependencies = [ - "aws-smithy-http 0.60.12", + "aws-smithy-http", "aws-smithy-types", "bytes", "crc32c", @@ -783,15 +1051,15 @@ dependencies = [ "md-5", "pin-project-lite", "sha1", - "sha2", + "sha2 0.10.9", "tracing", ] [[package]] name = "aws-smithy-eventstream" -version = "0.60.7" +version = "0.60.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "461e5e02f9864cba17cff30f007c2e37ade94d01e87cdb5204e44a84e6d38c17" +checksum = "7c45d3dddac16c5c59d553ece225a88870cf81b7b813c9cc17b78cf4685eac7a" dependencies = [ "aws-smithy-types", "bytes", @@ -800,18 +1068,19 @@ dependencies = [ [[package]] name = "aws-smithy-http" -version = "0.60.12" +version = "0.62.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7809c27ad8da6a6a68c454e651d4962479e81472aa19ae99e59f9aba1f9713cc" +checksum = "99335bec6cdc50a346fda1437f9fefe33abf8c99060739a546a16457f2862ca9" dependencies = [ + "aws-smithy-eventstream", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "bytes-utils", "futures-core", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", "percent-encoding", "pin-project-lite", "pin-utils", @@ -819,35 +1088,51 @@ dependencies = [ ] [[package]] -name = "aws-smithy-http" -version = "0.61.1" +name = "aws-smithy-http-client" +version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6f276f21c7921fe902826618d1423ae5bf74cf8c1b8472aee8434f3dfd31824" +checksum = "7e44697a9bded898dcd0b1cb997430d949b87f4f8940d91023ae9062bf218250" dependencies = [ - "aws-smithy-eventstream", + "aws-smithy-async", "aws-smithy-runtime-api", "aws-smithy-types", - "bytes", - "bytes-utils", - "futures-core", + "h2 0.4.10", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", - "once_cell", - "percent-encoding", + "hyper 0.14.32", + "hyper 1.6.0", + "hyper-rustls 0.24.2", + "hyper-rustls 0.27.5", + "hyper-util", "pin-project-lite", - "pin-utils", + "rustls 0.21.12", + "rustls 0.23.27", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tower", "tracing", ] [[package]] name = "aws-smithy-json" -version = "0.61.2" +version = "0.61.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "623a51127f24c30776c8b374295f2df78d92517386f77ba30773f15a30ce1422" +checksum = "92144e45819cae7dc62af23eac5a038a58aa544432d2102609654376a900bd07" dependencies = [ "aws-smithy-types", ] +[[package]] +name = "aws-smithy-observability" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9364d5989ac4dd918e5cc4c4bdcc61c9be17dcd2586ea7f69e348fc7c6cab393" +dependencies = [ + "aws-smithy-runtime-api", +] + [[package]] name = "aws-smithy-query" version = "0.60.7" @@ -860,42 +1145,39 @@ dependencies = [ [[package]] name = "aws-smithy-runtime" -version = "1.7.8" +version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d526a12d9ed61fadefda24abe2e682892ba288c2018bcb38b1b4c111d13f6d92" +checksum = "14302f06d1d5b7d333fd819943075b13d27c7700b414f574c3c35859bfb55d5e" dependencies = [ "aws-smithy-async", - "aws-smithy-http 0.60.12", + "aws-smithy-http", + "aws-smithy-http-client", + "aws-smithy-observability", "aws-smithy-runtime-api", "aws-smithy-types", "bytes", "fastrand", - "h2", "http 0.2.12", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", - "httparse", - "hyper", - "hyper-rustls", - "once_cell", "pin-project-lite", "pin-utils", - "rustls", "tokio", "tracing", ] [[package]] name = "aws-smithy-runtime-api" -version = "1.7.3" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92165296a47a812b267b4f41032ff8069ab7ff783696d217f0994a0d7ab585cd" +checksum = "a1e5d9e3a80a18afa109391fb5ad09c3daf887b516c6fd805a157c6ea7994a57" dependencies = [ "aws-smithy-async", "aws-smithy-types", "bytes", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "pin-project-lite", "tokio", "tracing", @@ -904,16 +1186,16 @@ dependencies = [ [[package]] name = "aws-smithy-types" -version = "1.2.13" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7b8a53819e42f10d0821f56da995e1470b199686a1809168db6ca485665f042" +checksum = "40076bd09fadbc12d5e026ae080d0930defa606856186e31d83ccc6a255eeaf3" dependencies = [ "base64-simd", "bytes", "bytes-utils", "futures-core", "http 0.2.12", - "http 1.2.0", + "http 1.3.1", "http-body 0.4.6", "http-body 1.0.1", "http-body-util", @@ -939,9 +1221,9 @@ dependencies = [ [[package]] name = "aws-types" -version = "1.3.5" +version = "1.3.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dfbd0a668309ec1f66c0f6bda4840dd6d4796ae26d699ebc266d7cc95c6d040f" +checksum = "8a322fec39e4df22777ed3ad8ea868ac2f94cd15e1a55f6ee8d8d6305057689a" dependencies = [ "aws-credential-types", "aws-smithy-async", @@ -953,9 +1235,9 @@ dependencies = [ [[package]] name = "backtrace" -version = "0.3.74" +version = "0.3.75" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d82cb332cdfaed17ae235a638438ac4d4839913cc2af585c3c6746e8f8bee1a" +checksum = "6806a6321ec58106fea15becdad98371e28d92ccbc7c8f1b3b6dd724fe8f1002" dependencies = [ "addr2line", "cfg-if", @@ -1009,9 +1291,9 @@ dependencies = [ [[package]] name = "base64ct" -version = "1.6.0" +version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" [[package]] name = "bincode" @@ -1022,13 +1304,45 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.101", + "which", +] + [[package]] name = "bit-set" version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" dependencies = [ - "bit-vec", + "bit-vec 0.6.3", +] + +[[package]] +name = "bit-set" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08807e080ed7f9d5433fa9b275196cfc35414f66a0c79d864dc51a0d825231a3" +dependencies = [ + "bit-vec 0.8.0", ] [[package]] @@ -1037,11 +1351,17 @@ version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" +[[package]] +name = "bit-vec" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e764a1d40d510daf35e07be9eb06e75770908c27d411ee6c92109c9840eaaf7" + [[package]] name = "bitcode" -version = "0.6.5" +version = "0.6.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18c1406a27371b2f76232a2259df6ab607b91b5a0a7476a7729ff590df5a969a" +checksum = "cf300f4aa6e66f3bdff11f1236a88c622fe47ea814524792240b4d554d9858ee" dependencies = [ "arrayvec", "bitcode_derive", @@ -1058,7 +1378,23 @@ checksum = "42b6b4cb608b8282dc3b53d0f4c9ab404655d562674c682db7e6c0458cc83c23" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "bitcoin-io" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b47c4ab7a93edb0c7198c5535ed9b52b63095f4e9b45279c6736cec4b856baf" + +[[package]] +name = "bitcoin_hashes" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb18c03d0db0247e147a21a6faafd5a7eb851c743db062de72018b6b7e8e4d16" +dependencies = [ + "bitcoin-io", + "hex-conservative", ] [[package]] @@ -1069,9 +1405,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.8.0" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" [[package]] name = "bitvec" @@ -1107,16 +1443,24 @@ dependencies = [ [[package]] name = "blake3" -version = "1.6.0" +version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1230237285e3e10cde447185e8975408ae24deaa67205ce684805c25bc0c7937" +checksum = "3888aaa89e4b2a40fca9848e400f6a658a5a3978de7be858e209cafa8be9a4a0" dependencies = [ "arrayref", "arrayvec", "cc", "cfg-if", "constant_time_eq 0.3.1", - "memmap2", +] + +[[package]] +name = "block-buffer" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4" +dependencies = [ + "generic-array", ] [[package]] @@ -1137,7 +1481,7 @@ dependencies = [ "ff 0.12.1", "group 0.12.1", "pairing 0.22.0", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -1155,9 +1499,9 @@ dependencies = [ [[package]] name = "bon" -version = "3.3.2" +version = "3.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7acc34ff59877422326db7d6f2d845a582b16396b6b08194942bf34c6528ab" +checksum = "ced38439e7a86a4761f7f7d5ded5ff009135939ecb464a24452eaa4c1696af7d" dependencies = [ "bon-macros", "rustversion", @@ -1165,9 +1509,9 @@ dependencies = [ [[package]] name = "bon-macros" -version = "3.3.2" +version = "3.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4159dd617a7fbc9be6a692fe69dc2954f8e6bb6bb5e4d7578467441390d77fd0" +checksum = "0ce61d2d3844c6b8d31b2353d9f66cf5e632b3e9549583fe3cac2f4f6136725e" dependencies = [ "darling", "ident_case", @@ -1175,7 +1519,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -1202,15 +1546,15 @@ checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf" [[package]] name = "byte-slice-cast" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" +checksum = "7575182f7272186991736b70173b0ea045398f984bf5ebbb3804736ce1330c9d" [[package]] name = "bytemuck" -version = "1.21.0" +version = "1.23.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef657dfab802224e671f5818e9a4935f9b1957ed18e58292690cc39e7a4092a3" +checksum = "9134a6ef01ce4b366b50689c94f82c14bc72bc5d0386829828a2e2752ef7958c" [[package]] name = "byteorder" @@ -1220,9 +1564,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9" +checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" dependencies = [ "serde", ] @@ -1259,9 +1603,9 @@ dependencies = [ [[package]] name = "c-kzg" -version = "1.0.3" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0307f72feab3300336fb803a57134159f6e20139af1357f36c54cb90d8e8928" +checksum = "7318cfa722931cb5fe0838b98d3ce5621e75f6a6408abc21721d80de9223f2e4" dependencies = [ "blst", "cc", @@ -1302,7 +1646,7 @@ dependencies = [ "target-lexicon", "tempfile", "tokio", - "toml 0.8.20", + "toml 0.8.22", "tracing", "vergen", ] @@ -1324,7 +1668,7 @@ checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037" dependencies = [ "camino", "cargo-platform", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -1338,15 +1682,24 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.14" +version = "1.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3d1b2e905a3a7b00a6141adb0e4c0bb941d11caf55349d863942a1cc44e3c9" +checksum = "32db95edf998450acc7881c932f94cd9b05c87b4b2599e8bab064753da4acfd1" dependencies = [ "jobserver", "libc", "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "1.0.0" @@ -1355,15 +1708,15 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "chrono" -version = "0.4.39" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825" +checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" dependencies = [ "android-tzdata", "iana-time-zone", "num-traits", "serde", - "windows-targets 0.52.6", + "windows-link", ] [[package]] @@ -1403,11 +1756,22 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" -version = "4.5.30" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" +checksum = "ed93b9805f8ba930df42c2590f05453d5ec36cbb85d018868a5b24d31f6ac000" dependencies = [ "clap_builder", "clap_derive", @@ -1415,26 +1779,27 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.30" +version = "4.5.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" +checksum = "379026ff283facf611b0ea629334361c4211d1b12ee01024eec1591133b04120" dependencies = [ "anstream", "anstyle", "clap_lex", "strsim", + "terminal_size", ] [[package]] name = "clap_derive" -version = "4.5.28" +version = "4.5.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ced95c6f4a675af3da73304b9ac4ed991640c36374e4b46795c49e17cf1ed" +checksum = "09176aae279615badda0765c0c0b3f6ed53f4709118af73cf4655d85d1530cd7" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -1443,12 +1808,90 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + +[[package]] +name = "codspeed" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93f4cce9c27c49c4f101fffeebb1826f41a9df2e7498b7cd4d95c0658b796c6c" +dependencies = [ + "colored", + "libc", + "serde", + "serde_json", + "uuid", +] + +[[package]] +name = "codspeed-divan-compat" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8620a09dfaf37b3c45f982c4b65bd8f9b0203944da3ffa705c0fcae6b84655ff" +dependencies = [ + "codspeed", + "codspeed-divan-compat-macros", + "codspeed-divan-compat-walltime", +] + +[[package]] +name = "codspeed-divan-compat-macros" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30fe872bc4214626b35d3a1706a905d0243503bb6ba3bb7be2fc59083d5d680c" +dependencies = [ + "divan-macros", + "itertools 0.14.0", + "proc-macro-crate", + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "codspeed-divan-compat-walltime" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "104caa97b36d4092d89e24e4b103b40ede1edab03c0372d19e14a33f9393132b" +dependencies = [ + "cfg-if", + "clap", + "codspeed", + "condtype", + "divan-macros", + "libc", + "regex-lite", +] + [[package]] name = "colorchoice" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" +[[package]] +name = "colored" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c" +dependencies = [ + "lazy_static", + "windows-sys 0.59.0", +] + +[[package]] +name = "condtype" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf0a07a401f374238ab8e2f11a104d2851bf9ce711ec69804834de8af45c7af" + [[package]] name = "const-default" version = "1.0.0" @@ -1522,6 +1965,16 @@ dependencies = [ "libc", ] +[[package]] +name = "core-foundation" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b55271e5c8c478ad3f38ad24ef34923091e0548492a266d19b3c0b4d82574c63" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -1539,9 +1992,9 @@ dependencies = [ [[package]] name = "crc" -version = "3.2.1" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69e6e4d7b33a94f0991c26729976b10ebde1d34c3ee82408fb536164fa10d636" +checksum = "9710d3b3739c2e349eb44fe848ad0b7c8cb1e42bd87ee49371df2f7acaf3e675" dependencies = [ "crc-catalog", ] @@ -1636,9 +2089,9 @@ dependencies = [ [[package]] name = "crossbeam-channel" -version = "0.5.14" +version = "0.5.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ba6d68e24814cb8de6bb986db8222d3a027d15872cabc0d18817bc3c0e4471" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" dependencies = [ "crossbeam-utils", ] @@ -1690,7 +2143,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef2b4b23cddf68b89b8f8069890e8c270d54e2d5fe1b143820234805e4cb17ef" dependencies = [ "generic-array", - "rand_core", + "rand_core 0.6.4", "subtle", "zeroize", ] @@ -1702,7 +2155,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" dependencies = [ "generic-array", - "rand_core", + "rand_core 0.6.4", "subtle", "zeroize", ] @@ -1719,9 +2172,9 @@ dependencies = [ [[package]] name = "darling" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f63b86c8a8826a49b8c21f08a2d07338eec8d900540f8630dc76284be802989" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ "darling_core", "darling_macro", @@ -1729,27 +2182,27 @@ dependencies = [ [[package]] name = "darling_core" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95133861a8032aaea082871032f5815eb9e98cef03fa916ab4500513994df9e5" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" dependencies = [ "fnv", "ident_case", "proc-macro2", "quote", "strsim", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] name = "darling_macro" -version = "0.20.10" +version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -1764,9 +2217,9 @@ dependencies = [ [[package]] name = "der" -version = "0.7.9" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f55bf8e7b65898637379c1b74eb1551107c8294ed26d855ceb9fd1a09cfc9bc0" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" dependencies = [ "const-oid", "zeroize", @@ -1774,9 +2227,9 @@ dependencies = [ [[package]] name = "deranged" -version = "0.3.11" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +checksum = "9c9e6a11ca8224451684bc0d7d5a7adbf8f2fd6887261a1cfc3c0432f9d4068e" dependencies = [ "powerfmt", "serde", @@ -1801,7 +2254,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -1812,20 +2265,31 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "derive-where" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e73f2692d4bd3cac41dca28934a39894200c9fabf49586d77d0e5954af1d7902" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", ] [[package]] name = "derive_more" -version = "0.99.19" +version = "0.99.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3da29a38df43d6f156149c9b43ded5e018ddff2a855cf2cfd62e8cd7d079c69f" +checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" dependencies = [ "convert_case", "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -1854,7 +2318,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "unicode-xid", ] @@ -1866,7 +2330,7 @@ checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "unicode-xid", ] @@ -1885,7 +2349,7 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", "crypto-common", "subtle", @@ -1941,7 +2405,18 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "divan-macros" +version = "0.1.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8dc51d98e636f5e3b0759a39257458b22619cac7e96d932da6eeb052891bb67c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", ] [[package]] @@ -1956,12 +2431,6 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" -[[package]] -name = "dyn-clone" -version = "1.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35" - [[package]] name = "ecdsa" version = "0.14.8" @@ -1980,7 +2449,7 @@ version = "0.16.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ - "der 0.7.9", + "der 0.7.10", "digest 0.10.7", "elliptic-curve 0.13.8", "rfc6979 0.4.0", @@ -1988,11 +2457,23 @@ dependencies = [ "spki 0.7.3", ] +[[package]] +name = "educe" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d7bc049e1bd8cdeb31b68bbd586a9464ecf9f3944af3958a7a9d0f8b9799417" +dependencies = [ + "enum-ordinalize", + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "either" -version = "1.13.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" [[package]] name = "elf" @@ -2014,7 +2495,7 @@ dependencies = [ "generic-array", "group 0.12.1", "pkcs8 0.9.0", - "rand_core", + "rand_core 0.6.4", "sec1 0.3.0", "subtle", "zeroize", @@ -2029,11 +2510,11 @@ dependencies = [ "base16ct 0.2.0", "crypto-bigint 0.5.5", "digest 0.10.7", - "ff 0.13.0", + "ff 0.13.1", "generic-array", "group 0.13.0", "pkcs8 0.10.2", - "rand_core", + "rand_core 0.6.4", "sec1 0.7.3", "subtle", "zeroize", @@ -2075,6 +2556,26 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" +[[package]] +name = "enum-ordinalize" +version = "4.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea0dcfa4e54eeb516fe454635a95753ddd39acda650ce703031c6973e315dd5" +dependencies = [ + "enum-ordinalize-derive", +] + +[[package]] +name = "enum-ordinalize-derive" +version = "4.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "enum_dispatch" version = "0.3.13" @@ -2084,7 +2585,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -2095,7 +2596,7 @@ checksum = "2f9ed6b3789237c8a0c1c505af1c7eb2c560df6186f01b098c3a1064ea532f38" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -2109,9 +2610,9 @@ dependencies = [ [[package]] name = "env_logger" -version = "0.11.6" +version = "0.11.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dcaee3d8e3cfc3fd92428d477bc97fc29ec8716d180c0d74c643bb26166660e0" +checksum = "13c863f0904021b108aa8b2f55046443e6b1ebde8fd4a15c399893aae4fa069f" dependencies = [ "anstream", "anstyle", @@ -2127,9 +2628,9 @@ checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" [[package]] name = "errno" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33d852cb9b869c2a9b3df2f71a3074817f01e1844f839a144f5fcef059a4eb5d" +checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e" dependencies = [ "libc", "windows-sys 0.59.0", @@ -2199,7 +2700,7 @@ dependencies = [ "k256", "num_enum", "open-fastrlp", - "rand", + "rand 0.8.5", "rlp", "serde", "serde_json", @@ -2219,7 +2720,7 @@ dependencies = [ "chrono", "ethers-core", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "thiserror 1.0.69", @@ -2246,10 +2747,10 @@ dependencies = [ "path-slash", "rayon", "regex", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "solang-parser", "svm-rs", "svm-rs-builds", @@ -2306,31 +2807,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d013fc25338cc558c5c2cfbad646908fb23591e2404481826742b651c9af7160" dependencies = [ "bitvec", - "rand_core", + "rand_core 0.6.4", "subtle", ] [[package]] name = "ff" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" +checksum = "c0b50bfb653653f9ca9095b427bed08ab8d75a137839d9ad64eb11810d5b6393" dependencies = [ "bitvec", "byteorder", "ff_derive", - "rand_core", + "rand_core 0.6.4", "subtle", ] [[package]] name = "ff_derive" -version = "0.13.0" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9f54704be45ed286151c5e11531316eaef5b8f5af7d597b806fdb8af108d84a" +checksum = "f10d12652036b0e99197587c6ba87a8fc3031986499973c030d8b44fcc151b60" dependencies = [ "addchain", - "cfg-if", "num-bigint 0.3.3", "num-integer", "num-traits", @@ -2348,7 +2848,7 @@ dependencies = [ "atomic", "pear", "serde", - "toml 0.8.20", + "toml 0.8.22", "uncased", "version_check", ] @@ -2360,7 +2860,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "835c052cb0c08c1acf6ffd71c022172e18723949c8282f2b9f27efbc51e64534" dependencies = [ "byteorder", - "rand", + "rand 0.8.5", "rustc-hex", "static_assertions", ] @@ -2373,9 +2873,9 @@ checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" [[package]] name = "flate2" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11faaf5a5236997af9848be0bef4db95824b1d534ebc64d0f0c6cf3e67bd38dc" +checksum = "7ced92e76e966ca2fd84c8f7aa01a4aea65b0eb6648d72f7c8f3e2764a67fece" dependencies = [ "crc32fast", "miniz_oxide", @@ -2389,9 +2889,9 @@ checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" [[package]] name = "foldhash" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" [[package]] name = "forge-fmt" @@ -2438,7 +2938,7 @@ dependencies = [ "regex", "reqwest", "revm-primitives 1.3.0", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", "serde_regex", @@ -2459,6 +2959,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "funty" version = "2.0.0" @@ -2494,7 +3000,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -2544,9 +3050,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -2555,26 +3061,26 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.3.1" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" dependencies = [ "cfg-if", "libc", - "wasi 0.13.3+wasi-0.2.2", - "windows-targets 0.52.6", + "r-efi", + "wasi 0.14.2+wasi-0.2.4", ] [[package]] name = "getset" -version = "0.1.4" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded738faa0e88d3abc9d1a13cb11adc2073c400969eeb8793cf7132589959fc" +checksum = "f3586f256131df87204eb733da72e3d3eb4f343c639f4b7be279ac7c48baeafe" dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -2589,7 +3095,7 @@ version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b903b73e45dc0c6c596f2d37eccece7c1c8bb6e4407b001096387c63d0d93724" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", "libc", "libgit2-sys", "log", @@ -2598,9 +3104,9 @@ dependencies = [ [[package]] name = "glam" -version = "0.30.0" +version = "0.30.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17fcdf9683c406c2fc4d124afd29c0d595e22210d633cbdb8695ba9935ab1dc6" +checksum = "6b46b9ca4690308844c644e7c634d68792467260e051c8543e0c7871662b3ba7" [[package]] name = "glob" @@ -2629,7 +3135,7 @@ checksum = "5dfbfb3a6cfbd390d5c9564ab283a0349b9b9fcd46a706c1eb10e0db70bfbac7" dependencies = [ "ff 0.12.1", "memuse", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -2639,8 +3145,8 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" dependencies = [ - "ff 0.13.0", - "rand_core", + "ff 0.13.1", + "rand_core 0.6.4", "subtle", ] @@ -2656,7 +3162,26 @@ dependencies = [ "futures-sink", "futures-util", "http 0.2.12", - "indexmap 2.7.1", + "indexmap 2.9.0", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9421a676d1b147b16b82c9225157dc629087ef8ec4d5e2960f9437a90dac0a5" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.3.1", + "indexmap 2.9.0", "slab", "tokio", "tokio-util", @@ -2665,9 +3190,9 @@ dependencies = [ [[package]] name = "half" -version = "2.4.1" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +checksum = "459196ed295495a68f7d7fe1d84f6c4b7ff0e21fe3017b2f283c6fac3ad803c9" dependencies = [ "cfg-if", "crunchy", @@ -2690,14 +3215,14 @@ checksum = "62f0ca78d12ac5c893f286d7cdfe3869290305ab8cac376e2592cdc8396da102" dependencies = [ "blake2b_simd", "crossbeam", - "ff 0.13.0", + "ff 0.13.1", "group 0.13.0", "halo2curves-axiom", "itertools 0.11.0", "maybe-rayon", "pairing 0.23.0", - "rand", - "rand_core", + "rand 0.8.5", + "rand_core 0.6.4", "rayon", "rustc-hash 1.1.0", "sha3", @@ -2718,7 +3243,7 @@ dependencies = [ "num-integer", "num-traits", "poseidon-primitives", - "rand_chacha", + "rand_chacha 0.3.1", "rayon", "rustc-hash 1.1.0", "serde", @@ -2736,9 +3261,9 @@ dependencies = [ "num-bigint 0.4.6", "num-integer", "num-traits", - "rand", - "rand_chacha", - "rand_core", + "rand 0.8.5", + "rand_chacha 0.3.1", + "rand_core 0.6.4", "rayon", "serde", "serde_json", @@ -2755,7 +3280,7 @@ dependencies = [ "ff 0.12.1", "group 0.12.1", "pasta_curves 0.4.1", - "rand_core", + "rand_core 0.6.4", "rayon", ] @@ -2767,7 +3292,7 @@ checksum = "b756596082144af6e57105a20403b7b80fe9dccd085700b74fae3af523b74dba" dependencies = [ "blake2", "digest 0.10.7", - "ff 0.13.0", + "ff 0.13.1", "group 0.13.0", "halo2derive", "hex", @@ -2777,12 +3302,12 @@ dependencies = [ "num-traits", "pairing 0.23.0", "paste", - "rand", - "rand_core", + "rand 0.8.5", + "rand_core 0.6.4", "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -2796,7 +3321,7 @@ checksum = "dd8309e4638b4f1bcf6613d72265a84074d26034c35edc5d605b5688e580b8b8" dependencies = [ "blake2b_simd", "digest 0.10.7", - "ff 0.13.0", + "ff 0.13.1", "group 0.13.0", "hex", "lazy_static", @@ -2805,12 +3330,12 @@ dependencies = [ "pairing 0.23.0", "pasta_curves 0.5.1", "paste", - "rand", - "rand_core", + "rand 0.8.5", + "rand_core 0.6.4", "rayon", "serde", "serde_arrays", - "sha2", + "sha2 0.10.9", "static_assertions", "subtle", "unroll", @@ -2848,9 +3373,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.15.2" +version = "0.15.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" +checksum = "84b26c544d002229e640969970a2e74021aadf6e2f96372b9c58eff97de08eb3" dependencies = [ "allocator-api2", "equivalent", @@ -2872,9 +3397,9 @@ checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "hermit-abi" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" +checksum = "f154ce46856750ed433c8649605bf7ed2de3bc35fd9d2a9f30cddd873c80cb08" [[package]] name = "hex" @@ -2885,6 +3410,15 @@ dependencies = [ "serde", ] +[[package]] +name = "hex-conservative" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5313b072ce3c597065a808dbf612c4c8e8590bdbf8b579508bf7a762c5eae6cd" +dependencies = [ + "arrayvec", +] + [[package]] name = "hex-literal" version = "0.4.1" @@ -2922,9 +3456,9 @@ dependencies = [ [[package]] name = "http" -version = "1.2.0" +version = "1.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea" +checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565" dependencies = [ "bytes", "fnv", @@ -2949,27 +3483,27 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1efedce1fb8e6913f23e0c92de8e62cd5b772a67e7b3946df930a62566c93184" dependencies = [ "bytes", - "http 1.2.0", + "http 1.3.1", ] [[package]] name = "http-body-util" -version = "0.1.2" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f" +checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a" dependencies = [ "bytes", - "futures-util", - "http 1.2.0", + "futures-core", + "http 1.3.1", "http-body 1.0.1", "pin-project-lite", ] [[package]] name = "httparse" -version = "1.10.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a" +checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "httpdate" @@ -2987,7 +3521,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -3002,31 +3536,90 @@ dependencies = [ ] [[package]] -name = "hyper-rustls" -version = "0.24.2" +name = "hyper" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.10", + "http 1.3.1", + "http-body 1.0.1", + "httparse", + "itoa", + "pin-project-lite", + "smallvec", + "tokio", + "want", +] + +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper 0.14.32", + "log", + "rustls 0.21.12", + "rustls-native-certs 0.6.3", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d191583f3da1305256f22463b9bb0471acad48a4e534a5218b9963e9c1f59b2" +dependencies = [ + "futures-util", + "http 1.3.1", + "hyper 1.6.0", + "hyper-util", + "rustls 0.23.27", + "rustls-native-certs 0.8.1", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.2", + "tower-service", +] + +[[package]] +name = "hyper-util" +version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" dependencies = [ + "bytes", + "futures-channel", "futures-util", - "http 0.2.12", - "hyper", - "log", - "rustls", - "rustls-native-certs", + "http 1.3.1", + "http-body 1.0.1", + "hyper 1.6.0", + "libc", + "pin-project-lite", + "socket2", "tokio", - "tokio-rustls", + "tower-service", + "tracing", ] [[package]] name = "iana-time-zone" -version = "0.1.61" +version = "0.1.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "235e081f3925a06703c2d0117ea8b91f042756fd6e7a6e5d901e8ca1a996b220" +checksum = "b0c919e5debc312ad217002b8048a17b7d83f80703865bbfcfebb0458b0b27d8" dependencies = [ "android_system_properties", "core-foundation-sys", "iana-time-zone-haiku", "js-sys", + "log", "wasm-bindgen", "windows-core", ] @@ -3042,21 +3635,22 @@ dependencies = [ [[package]] name = "icu_collections" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" +checksum = "200072f5d0e3614556f94a9930d5dc3e0662a652823904c3a75dc3b0af7fee47" dependencies = [ "displaydoc", + "potential_utf", "yoke", "zerofrom", "zerovec", ] [[package]] -name = "icu_locid" -version = "1.5.0" +name = "icu_locale_core" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" +checksum = "0cde2700ccaed3872079a65fb1a78f6c0a36c91570f28755dda67bc8f7d9f00a" dependencies = [ "displaydoc", "litemap", @@ -3065,31 +3659,11 @@ dependencies = [ "zerovec", ] -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - [[package]] name = "icu_normalizer" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" +checksum = "436880e8e18df4d7bbc06d58432329d6458cc84531f7ac5f024e93deadb37979" dependencies = [ "displaydoc", "icu_collections", @@ -3097,67 +3671,54 @@ dependencies = [ "icu_properties", "icu_provider", "smallvec", - "utf16_iter", - "utf8_iter", - "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" +checksum = "00210d6893afc98edb752b664b8890f0ef174c8adbb8d0be9710fa66fbbf72d3" [[package]] name = "icu_properties" -version = "1.5.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" +checksum = "2549ca8c7241c82f59c80ba2a6f415d931c5b58d24fb8412caa1a1f02c49139a" dependencies = [ "displaydoc", "icu_collections", - "icu_locid_transform", + "icu_locale_core", "icu_properties_data", "icu_provider", - "tinystr", + "potential_utf", + "zerotrie", "zerovec", ] [[package]] name = "icu_properties_data" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" +checksum = "8197e866e47b68f8f7d95249e172903bec06004b18b2937f1095d40a0c57de04" [[package]] name = "icu_provider" -version = "1.5.0" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" +checksum = "03c80da27b5f4187909049ee2d72f276f0d9f99a42c306bd0131ecfe04d8e5af" dependencies = [ "displaydoc", - "icu_locid", - "icu_provider_macros", + "icu_locale_core", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", + "zerotrie", "zerovec", ] -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.98", -] - [[package]] name = "ident_case" version = "1.0.1" @@ -3177,9 +3738,9 @@ dependencies = [ [[package]] name = "idna_adapter" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" +checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344" dependencies = [ "icu_normalizer", "icu_properties", @@ -3220,7 +3781,7 @@ checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -3242,12 +3803,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "2.7.1" +version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652" +checksum = "cea70ddb795996207ad57735b50c5982d8844f38ba9ee5f1aedcfb708a2aa11e" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown 0.15.3", "serde", ] @@ -3274,11 +3835,11 @@ checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130" [[package]] name = "is-terminal" -version = "0.4.15" +version = "0.4.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e19b23d53f35ce9f56aebc7d1bb4e6ac1e9c0db7ac85c8d1760c04379edced37" +checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9" dependencies = [ - "hermit-abi 0.4.0", + "hermit-abi 0.5.1", "libc", "windows-sys 0.59.0", ] @@ -3307,6 +3868,24 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -3318,16 +3897,17 @@ dependencies = [ [[package]] name = "itoa" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "jobserver" -version = "0.1.32" +version = "0.1.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +checksum = "38f262f097c174adebe41eb73d66ae9c06b2844fb0da69969647bbddd9b0538a" dependencies = [ + "getrandom 0.3.3", "libc", ] @@ -3351,7 +3931,7 @@ dependencies = [ "bls12_381", "ff 0.12.1", "group 0.12.1", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -3365,7 +3945,7 @@ dependencies = [ "ecdsa 0.16.9", "elliptic-curve 0.13.8", "once_cell", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -3394,7 +3974,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "55cb077ad656299f160924eb2912aa147d7339ea7d69e1b5517326fdcec3c1ca" dependencies = [ "ascii-canvas", - "bit-set", + "bit-set 0.5.3", "ena", "itertools 0.11.0", "lalrpop-util", @@ -3426,11 +4006,17 @@ dependencies = [ "spin", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" -version = "0.2.169" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "libgit2-sys" @@ -3444,17 +4030,27 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "libloading" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a793df0d7afeac54f95b471d3af7f0d4fb975699f972341a4b76988d49cdf0c" +dependencies = [ + "cfg-if", + "windows-targets 0.53.0", +] + [[package]] name = "libm" -version = "0.2.11" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8355be11b20d696c8f18f6cc018c4e372165b1fa8126cef092399c9951984ffa" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" [[package]] name = "libmimalloc-sys" -version = "0.1.39" +version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23aa6811d3bd4deb8a84dde645f943476d13b248d818edcf8ce0b2f37f036b44" +checksum = "ec9d6fac27761dabcd4ee73571cdb06b7022dc99089acbe5435691edffaac0f4" dependencies = [ "cc", "libc", @@ -3466,15 +4062,61 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", "libc", ] +[[package]] +name = "libsecp256k1" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e79019718125edc905a079a70cfa5f3820bc76139fc91d6f9abc27ea2a887139" +dependencies = [ + "arrayref", + "base64 0.22.1", + "digest 0.9.0", + "libsecp256k1-core", + "libsecp256k1-gen-ecmult", + "libsecp256k1-gen-genmult", + "rand 0.8.5", + "serde", + "sha2 0.9.9", +] + +[[package]] +name = "libsecp256k1-core" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5be9b9bb642d8522a44d533eab56c16c738301965504753b03ad1de3425d5451" +dependencies = [ + "crunchy", + "digest 0.9.0", + "subtle", +] + +[[package]] +name = "libsecp256k1-gen-ecmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3038c808c55c87e8a172643a7d87187fc6c4174468159cb3090659d55bcb4809" +dependencies = [ + "libsecp256k1-core", +] + +[[package]] +name = "libsecp256k1-gen-genmult" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3db8d6ba2cec9eacc40e6e8ccc98931840301f1006e95647ceb2dd5c3aa06f7c" +dependencies = [ + "libsecp256k1-core", +] + [[package]] name = "libz-sys" -version = "1.1.21" +version = "1.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df9b68e50e6e0b26f672573834882eb57759f6db9b3be2ea3c35c91188bb4eaa" +checksum = "8b70e7a7df205e92a1a4cd9aaae7898dac0aa555503cc0a649494d0d60e7651d" dependencies = [ "cc", "libc", @@ -3494,11 +4136,17 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" +[[package]] +name = "linux-raw-sys" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" + [[package]] name = "litemap" -version = "0.7.4" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" +checksum = "241eaef5fd12c88705a01fc1066c48c4b36e0dd4377dcdc7ec3942cea7a69956" [[package]] name = "lock_api" @@ -3518,9 +4166,9 @@ checksum = "9374ef4228402d4b7e403e5838cb880d9ee663314b0a900d5a6aabf0c213552e" [[package]] name = "log" -version = "0.4.25" +version = "0.4.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "lru" @@ -3528,7 +4176,7 @@ version = "0.12.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" dependencies = [ - "hashbrown 0.15.2", + "hashbrown 0.15.3", ] [[package]] @@ -3539,7 +4187,7 @@ checksum = "1b27834086c65ec3f9387b096d66e99f221cf081c2b738042aa252bcd41204e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -3594,9 +4242,9 @@ checksum = "3d97bbf43eb4f088f8ca469930cde17fa036207c9a5e02ccc5107c4e8b17c964" [[package]] name = "metrics" -version = "0.23.0" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "884adb57038347dfbaf2d5065887b6cf4312330dc8e94bc30a1a839bd79d3261" +checksum = "3045b4193fbdc5b5681f32f11070da9be3609f189a79f3390706d42587f46bb5" dependencies = [ "ahash", "portable-atomic", @@ -3608,7 +4256,7 @@ version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62a6a1f7141f1d9bc7a886b87536bbfc97752e08b369e1e0453a9acfab5f5da4" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.9.0", "itoa", "lockfree-object-pool", "metrics", @@ -3616,7 +4264,7 @@ dependencies = [ "once_cell", "tracing", "tracing-core", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -3629,7 +4277,7 @@ dependencies = [ "crossbeam-epoch", "crossbeam-utils", "hashbrown 0.14.5", - "indexmap 2.7.1", + "indexmap 2.9.0", "metrics", "num_cpus", "ordered-float", @@ -3640,9 +4288,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.43" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68914350ae34959d83f732418d51e2427a794055d0b9529f48259ac07af65633" +checksum = "995942f432bbb4822a7e9c3faa87a695185b0d09273ba85f097b54f4e458f2af" dependencies = [ "libmimalloc-sys", ] @@ -3653,11 +4301,17 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" -version = "0.8.4" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3b1c9bd4fe1f0f8b387f6eb9eb3b4a1aa26185e5750efb9140301703f62cd1b" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -3688,6 +4342,16 @@ dependencies = [ "smallvec", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -3731,7 +4395,7 @@ checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ "num-integer", "num-traits", - "rand", + "rand 0.8.5", "serde", ] @@ -3829,7 +4493,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -3856,7 +4520,7 @@ dependencies = [ "num-bigint 0.4.6", "num-integer", "num-traits", - "rand", + "rand 0.8.5", ] [[package]] @@ -3870,15 +4534,21 @@ dependencies = [ [[package]] name = "once_cell" -version = "1.20.3" +version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" + +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" [[package]] name = "open-fastrlp" @@ -3946,11 +4616,12 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde-big-array", "serde_with", "strum", + "test-case", ] [[package]] @@ -3959,7 +4630,7 @@ version = "0.1.0" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -3980,7 +4651,7 @@ version = "1.1.1-rc.0" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4017,22 +4688,26 @@ dependencies = [ name = "openvm-benchmarks-execute" version = "1.1.1-rc.0" dependencies = [ - "cargo-openvm", "clap", + "codspeed-divan-compat", "criterion", "derive_more 1.0.0", "eyre", "openvm-benchmarks-utils", + "openvm-bigint-circuit", + "openvm-bigint-transpiler", "openvm-circuit", "openvm-keccak256-circuit", "openvm-keccak256-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sdk", + "openvm-sha256-circuit", + "openvm-sha256-transpiler", "openvm-stark-sdk", "openvm-transpiler", + "serde", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -4064,7 +4739,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", - "rand_chacha", + "rand_chacha 0.3.1", "serde", "tiny-keccak", "tokio", @@ -4082,7 +4757,7 @@ dependencies = [ "openvm-transpiler", "tempfile", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -4102,8 +4777,9 @@ dependencies = [ "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", + "test-case", ] [[package]] @@ -4185,7 +4861,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "p3-baby-bear", - "rand", + "rand 0.8.5", "rustc-hash 2.1.1", "serde", "serde-big-array", @@ -4201,7 +4877,7 @@ version = "1.1.1-rc.0" dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4215,7 +4891,7 @@ dependencies = [ "openvm-circuit-primitives-derive", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "test-case", "tracing", ] @@ -4226,7 +4902,7 @@ version = "1.1.1-rc.0" dependencies = [ "itertools 0.14.0", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4249,7 +4925,7 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4278,7 +4954,7 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde_with", "strum", @@ -4335,7 +5011,7 @@ version = "1.1.1-rc.0" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4375,7 +5051,7 @@ dependencies = [ "quote", "strum", "strum_macros", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4396,7 +5072,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "p3-keccak-air", - "rand", + "rand 0.8.5", "serde", "serde-big-array", "strum", @@ -4444,7 +5120,7 @@ dependencies = [ name = "openvm-macros-common" version = "1.1.1-rc.0" dependencies = [ - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4461,7 +5137,7 @@ dependencies = [ "openvm-pairing-guest", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde_with", "tracing", @@ -4483,9 +5159,10 @@ dependencies = [ "openvm-native-compiler", "openvm-poseidon2-air", "openvm-rv32im-circuit", + "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde-big-array", "static_assertions", @@ -4511,7 +5188,7 @@ dependencies = [ "openvm-stark-backend", "openvm-stark-sdk", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", "snark-verifier-sdk", "strum", @@ -4524,7 +5201,7 @@ name = "openvm-native-compiler-derive" version = "1.1.1-rc.0" dependencies = [ "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -4548,7 +5225,7 @@ dependencies = [ "p3-fri", "p3-merkle-tree", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", "serde_json", "serde_with", @@ -4592,7 +5269,7 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "strum", ] @@ -4617,7 +5294,7 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-platform", "openvm-rv32im-guest", - "rand", + "rand 0.8.5", "serde", "strum_macros", "subtle", @@ -4645,7 +5322,7 @@ dependencies = [ "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", - "rand", + "rand 0.8.5", ] [[package]] @@ -4667,7 +5344,7 @@ version = "1.1.1-rc.0" dependencies = [ "critical-section", "embedded-alloc", - "getrandom 0.2.15", + "getrandom 0.2.16", "libm", "openvm-custom-insn", "openvm-rv32im-guest", @@ -4685,7 +5362,7 @@ dependencies = [ "p3-poseidon2", "p3-poseidon2-air", "p3-symmetric", - "rand", + "rand 0.8.5", "zkhash", ] @@ -4715,7 +5392,7 @@ dependencies = [ "openvm-rv32im-circuit", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde-big-array", "serde_with", @@ -4738,10 +5415,11 @@ dependencies = [ "openvm-rv32im-transpiler", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", "serde-big-array", "strum", + "test-case", ] [[package]] @@ -4844,8 +5522,8 @@ dependencies = [ "openvm-circuit-primitives", "openvm-stark-backend", "openvm-stark-sdk", - "rand", - "sha2", + "rand 0.8.5", + "sha2 0.10.9", ] [[package]] @@ -4864,9 +5542,9 @@ dependencies = [ "openvm-sha256-transpiler", "openvm-stark-backend", "openvm-stark-sdk", - "rand", + "rand 0.8.5", "serde", - "sha2", + "sha2 0.10.9", "strum", ] @@ -4875,7 +5553,7 @@ name = "openvm-sha256-guest" version = "1.1.1-rc.0" dependencies = [ "openvm-platform", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -4909,6 +5587,7 @@ dependencies = [ [[package]] name = "openvm-stark-backend" version = "1.0.0" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.1#e540dbdd09ef20db7207ad7f2674bece75a2b803" dependencies = [ "bitcode", "cfg-if", @@ -4936,10 +5615,11 @@ dependencies = [ [[package]] name = "openvm-stark-sdk" version = "1.0.0" +source = "git+https://github.com/openvm-org/stark-backend.git?tag=v1.0.1#e540dbdd09ef20db7207ad7f2674bece75a2b803" dependencies = [ "derivative", - "derive_more 0.99.19", - "ff 0.13.0", + "derive_more 0.99.20", + "ff 0.13.1", "itertools 0.14.0", "metrics", "metrics-tracing-context", @@ -4957,14 +5637,14 @@ dependencies = [ "p3-poseidon", "p3-poseidon2", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", "serde_json", "static_assertions", - "toml 0.8.20", + "toml 0.8.22", "tracing", "tracing-forest", - "tracing-subscriber", + "tracing-subscriber 0.3.19", "zkhash", ] @@ -5042,7 +5722,7 @@ checksum = "51f44edd08f51e2ade572f141051021c5af22677e42b7dd28a88155151c33594" dependencies = [ "ecdsa 0.14.8", "elliptic-curve 0.12.3", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5051,7 +5731,10 @@ version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9863ad85fa8f4460f9c48cb909d38a0d689dba1f6f6988a5e3e0d31071bcd4b" dependencies = [ + "ecdsa 0.16.9", "elliptic-curve 0.13.8", + "primeorder", + "sha2 0.10.9", ] [[package]] @@ -5073,7 +5756,7 @@ dependencies = [ "p3-monty-31", "p3-poseidon2", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", ] @@ -5092,13 +5775,13 @@ name = "p3-bn254-fr" version = "0.1.0" source = "git+https://github.com/Plonky3/Plonky3.git?rev=1ba4e5c#1ba4e5c40417f4f7aae86bcca56b6484b4b2490b" dependencies = [ - "ff 0.13.0", + "ff 0.13.1", "halo2curves", "num-bigint 0.4.6", "p3-field", "p3-poseidon2", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", ] @@ -5153,7 +5836,7 @@ dependencies = [ "nums", "p3-maybe-rayon", "p3-util", - "rand", + "rand 0.8.5", "serde", "tracing", ] @@ -5172,7 +5855,7 @@ dependencies = [ "p3-matrix", "p3-maybe-rayon", "p3-util", - "rand", + "rand 0.8.5", "serde", "tracing", ] @@ -5190,7 +5873,7 @@ dependencies = [ "p3-poseidon2", "p3-symmetric", "p3-util", - "rand", + "rand 0.8.5", "serde", ] @@ -5227,7 +5910,7 @@ dependencies = [ "p3-matrix", "p3-maybe-rayon", "p3-util", - "rand", + "rand 0.8.5", "tracing", ] @@ -5241,7 +5924,7 @@ dependencies = [ "p3-monty-31", "p3-poseidon2", "p3-symmetric", - "rand", + "rand 0.8.5", "serde", ] @@ -5254,7 +5937,7 @@ dependencies = [ "p3-field", "p3-maybe-rayon", "p3-util", - "rand", + "rand 0.8.5", "serde", "tracing", "transpose", @@ -5279,7 +5962,7 @@ dependencies = [ "p3-matrix", "p3-symmetric", "p3-util", - "rand", + "rand 0.8.5", ] [[package]] @@ -5294,7 +5977,7 @@ dependencies = [ "p3-maybe-rayon", "p3-symmetric", "p3-util", - "rand", + "rand 0.8.5", "serde", "tracing", ] @@ -5314,7 +5997,7 @@ dependencies = [ "p3-poseidon2", "p3-symmetric", "p3-util", - "rand", + "rand 0.8.5", "serde", "tracing", "transpose", @@ -5328,7 +6011,7 @@ dependencies = [ "p3-field", "p3-mds", "p3-symmetric", - "rand", + "rand 0.8.5", ] [[package]] @@ -5340,7 +6023,7 @@ dependencies = [ "p3-field", "p3-mds", "p3-symmetric", - "rand", + "rand 0.8.5", ] [[package]] @@ -5354,7 +6037,7 @@ dependencies = [ "p3-maybe-rayon", "p3-poseidon2", "p3-util", - "rand", + "rand 0.8.5", "tikv-jemallocator", "tracing", ] @@ -5438,7 +6121,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -5471,7 +6154,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" dependencies = [ "base64ct", - "rand_core", + "rand_core 0.6.4", "subtle", ] @@ -5485,7 +6168,7 @@ dependencies = [ "ff 0.12.1", "group 0.12.1", "lazy_static", - "rand", + "rand 0.8.5", "static_assertions", "subtle", ] @@ -5497,10 +6180,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e57598f73cc7e1b2ac63c79c517b31a0877cd7c402cdcaa311b5208de7a095" dependencies = [ "blake2b_simd", - "ff 0.13.0", + "ff 0.13.1", "group 0.13.0", "lazy_static", - "rand", + "rand 0.8.5", "static_assertions", "subtle", ] @@ -5526,7 +6209,7 @@ dependencies = [ "digest 0.10.7", "hmac", "password-hash", - "sha2", + "sha2 0.10.9", ] [[package]] @@ -5549,7 +6232,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -5560,12 +6243,12 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.15" +version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc" +checksum = "198db74531d58c70a361c42201efde7e2591e976d518caf7662a47dc5720e7b6" dependencies = [ "memchr", - "thiserror 2.0.11", + "thiserror 2.0.12", "ucd-trie", ] @@ -5576,7 +6259,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4c5cc86750666a3ed20bdaf5ca2a0344f9c67674cae0515bec2da16fbaa47db" dependencies = [ "fixedbitset", - "indexmap 2.7.1", + "indexmap 2.9.0", ] [[package]] @@ -5596,7 +6279,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand", + "rand 0.8.5", ] [[package]] @@ -5609,7 +6292,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -5649,15 +6332,15 @@ version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" dependencies = [ - "der 0.7.9", + "der 0.7.10", "spki 0.7.3", ] [[package]] name = "pkg-config" -version = "0.3.31" +version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" [[package]] name = "plotters" @@ -5689,9 +6372,9 @@ dependencies = [ [[package]] name = "portable-atomic" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "280dc24453071f1b63954171985a0b0d30058d287960968b9b2aca264c8d4ee6" +checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "poseidon-primitives" @@ -5700,14 +6383,23 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e4aaeda7a092e21165cc5f0cbc738e72a46f31c03c3cbd87b71ceae9d2d93bc" dependencies = [ "bitvec", - "ff 0.13.0", + "ff 0.13.1", "lazy_static", "log", - "rand", + "rand 0.8.5", "rand_xorshift", "thiserror 1.0.69", ] +[[package]] +name = "potential_utf" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7c30837279ca13e7c867e9e40053bc68740f988cb07f7ca6df43cc734b585" +dependencies = [ + "zerovec", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -5716,9 +6408,9 @@ checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" -version = "0.2.20" +version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ "zerocopy", ] @@ -5731,12 +6423,21 @@ checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" [[package]] name = "prettyplease" -version = "0.2.29" +version = "0.2.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6924ced06e1f7dfe3fa48d57b9f74f55d8915f5036121bef647ef4b204895fac" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" dependencies = [ "proc-macro2", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "primeorder" +version = "0.13.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "353e1ca18966c16d9deb1c69278edbc5f194139612772bd9537af60ac231e1e6" +dependencies = [ + "elliptic-curve 0.13.8", ] [[package]] @@ -5755,11 +6456,11 @@ dependencies = [ [[package]] name = "proc-macro-crate" -version = "3.2.0" +version = "3.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8ecf48c7ca261d60b74ab1a7b20da18bede46776b2e55535cb958eb595c5fa7b" +checksum = "edce586971a4dfaa28950c6f18ed55e0406c1ab88bbce2c6f6293a7aaba73d35" dependencies = [ - "toml_edit 0.22.24", + "toml_edit 0.22.26", ] [[package]] @@ -5781,14 +6482,14 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] name = "proc-macro2" -version = "1.0.93" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60946a68e5f9d28b0dc1c21bb8a97ee7d018a8b322fa57838ba31cc878e22d99" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -5801,24 +6502,24 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "version_check", "yansi 1.0.1", ] [[package]] name = "proptest" -version = "1.5.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4c2511913b88df1637da85cc8d96ec8e43a3f8bb8ccb71ee1ac240d6f3df58d" +checksum = "14cae93065090804185d3b75f0bf93b8eeda30c7a9b4a33d3bdb3988d6229e50" dependencies = [ - "bit-set", - "bit-vec", - "bitflags 2.8.0", + "bit-set 0.8.0", + "bit-vec 0.8.0", + "bitflags 2.9.0", "lazy_static", "num-traits", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "rand_xorshift", "regex-syntax 0.8.5", "rusty-fork", @@ -5849,13 +6550,19 @@ checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0" [[package]] name = "quote" -version = "1.0.38" +version = "1.0.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e4dccaaaf89514f546c693ddc140f729f958c247918a13380cccc6078391acc" +checksum = "1885c039570dc00dcb4ff087a89e185fd56bae234ddc7f056a945bf36467248d" dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74765f6d916ee2faa39bc8e68e4f3ed8949b48cccdac59983d287a7cb71ce9c5" + [[package]] name = "radium" version = "0.7.0" @@ -5879,8 +6586,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha", - "rand_core", + "rand_chacha 0.3.1", + "rand_core 0.6.4", + "serde", +] + +[[package]] +name = "rand" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", "serde", ] @@ -5891,7 +6609,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.6.4", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", ] [[package]] @@ -5900,7 +6628,17 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.3", + "serde", ] [[package]] @@ -5909,16 +6647,16 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d25bf25ec5ae4a3f1b92f929810509a2f53d7dca2f50b794ff57e3face536c8f" dependencies = [ - "rand_core", + "rand_core 0.6.4", ] [[package]] name = "raw-cpuid" -version = "11.4.0" +version = "11.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529468c1335c1c03919960dfefdb1b3648858c20d7ec2d0663e728e4a717efbc" +checksum = "c6df7ab838ed27997ba19a4664507e6f82b41fe6e20be42929332156e5e85146" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", ] [[package]] @@ -5943,11 +6681,11 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.11" +version = "0.5.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2f103c6d277498fbceb16e84d317e2a400f160f46904d5f5410848c829511a3" +checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", ] [[package]] @@ -5956,7 +6694,7 @@ version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", "libredox", "thiserror 1.0.69", ] @@ -6022,11 +6760,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", - "hyper", - "hyper-rustls", + "hyper 0.14.32", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -6034,7 +6772,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.12", "rustls-pemfile", "serde", "serde_json", @@ -6042,7 +6780,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -6054,46 +6792,162 @@ dependencies = [ [[package]] name = "revm" -version = "18.0.0" +version = "22.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15689a3c6a8d14b647b4666f2e236ef47b5a5133cdfd423f545947986fff7013" +checksum = "f5378e95ffe5c8377002dafeb6f7d370a55517cef7d6d6c16fc552253af3b123" +dependencies = [ + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database", + "revm-database-interface", + "revm-handler", + "revm-inspector", + "revm-interpreter", + "revm-precompile", + "revm-primitives 18.0.0", + "revm-state", +] + +[[package]] +name = "revm-bytecode" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e63e138d520c5c5bc25ecc82506e9e4e6e85a811809fc5251c594378dccabfc6" +dependencies = [ + "bitvec", + "phf", + "revm-primitives 18.0.0", + "serde", +] + +[[package]] +name = "revm-context" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9765628dfea4f3686aa8f2a72471c52801e6b38b601939ac16965f49bac66580" dependencies = [ - "auto_impl", "cfg-if", - "dyn-clone", + "derive-where", + "revm-bytecode", + "revm-context-interface", + "revm-database-interface", + "revm-primitives 18.0.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-context-interface" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82d74335aa1f14222cc4d3be1f62a029cc7dc03819cc8d080ff17b7e1d76375f" +dependencies = [ + "alloy-eip2930", + "alloy-eip7702", + "auto_impl", + "revm-database-interface", + "revm-primitives 18.0.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e5c80c5a2fd605f2119ee32a63fb3be941fb6a81ced8cdb3397abca28317224" +dependencies = [ + "alloy-eips", + "revm-bytecode", + "revm-database-interface", + "revm-primitives 18.0.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-database-interface" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0e4dfbc734b1ea67b5e8f8b3c7dc4283e2210d978cdaf6c7a45e97be5ea53b3" +dependencies = [ + "auto_impl", + "revm-primitives 18.0.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-handler" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8676379521c7bf179c31b685c5126ce7800eab5844122aef3231b97026d41a10" +dependencies = [ + "auto_impl", + "revm-bytecode", + "revm-context", + "revm-context-interface", + "revm-database-interface", "revm-interpreter", "revm-precompile", + "revm-primitives 18.0.0", + "revm-state", + "serde", +] + +[[package]] +name = "revm-inspector" +version = "3.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfed4ecf999a3f6ae776ae2d160478c5dca986a8c2d02168e04066b1e34c789e" +dependencies = [ + "auto_impl", + "revm-context", + "revm-database-interface", + "revm-handler", + "revm-interpreter", + "revm-primitives 18.0.0", + "revm-state", "serde", "serde_json", ] [[package]] name = "revm-interpreter" -version = "14.0.0" +version = "18.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74e3f11d0fed049a4a10f79820c59113a79b38aed4ebec786a79d5c667bfeb51" +checksum = "feb20260342003cfb791536e678ef5bbea1bfd1f8178b170e8885ff821985473" dependencies = [ - "revm-primitives 14.0.0", + "revm-bytecode", + "revm-context-interface", + "revm-primitives 18.0.0", "serde", ] [[package]] name = "revm-precompile" -version = "15.0.0" +version = "19.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e381060af24b750069a2b2d2c54bba273d84e8f5f9e8026fc9262298e26cc336" +checksum = "418e95eba68c9806c74f3e36cd5d2259170b61e90ac608b17ff8c435038ddace" dependencies = [ + "ark-bls12-381", + "ark-bn254", + "ark-ec", + "ark-ff 0.5.0", + "ark-serialize 0.5.0", "aurora-engine-modexp", "blst", "c-kzg", "cfg-if", "k256", + "libsecp256k1", "once_cell", - "revm-primitives 14.0.0", + "p256 0.13.2", + "revm-primitives 18.0.0", "ripemd", "secp256k1", - "sha2", - "substrate-bn", + "sha2 0.10.9", ] [[package]] @@ -6105,7 +6959,7 @@ dependencies = [ "alloy-primitives 0.4.2", "alloy-rlp", "auto_impl", - "bitflags 2.8.0", + "bitflags 2.9.0", "bitvec", "enumn", "hashbrown 0.14.5", @@ -6114,21 +6968,24 @@ dependencies = [ [[package]] name = "revm-primitives" -version = "14.0.0" +version = "18.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3702f132bb484f4f0d0ca4f6fbde3c82cfd745041abbedd6eda67730e1868ef0" +checksum = "1fc2283ff87358ec7501956c5dd8724a6c2be959c619c4861395ae5e0054575f" dependencies = [ - "alloy-eip2930", - "alloy-eip7702", - "alloy-primitives 0.8.25", - "auto_impl", - "bitflags 2.8.0", - "bitvec", - "c-kzg", - "cfg-if", - "dyn-clone", + "alloy-primitives 1.1.0", "enumn", - "hex", + "serde", +] + +[[package]] +name = "revm-state" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09dd121f6e66d75ab111fb51b4712f129511569bc3e41e6067ae760861418bd8" +dependencies = [ + "bitflags 2.9.0", + "revm-bytecode", + "revm-primitives 18.0.0", "serde", ] @@ -6155,13 +7012,13 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.13" +version = "0.17.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" dependencies = [ "cc", "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "libc", "untrusted", "windows-sys 0.52.0", @@ -6222,9 +7079,9 @@ dependencies = [ [[package]] name = "ruint" -version = "1.12.4" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5ef8fb1dd8de3870cb8400d51b4c2023854bbafd5431a3ac7e7317243e22d2f" +checksum = "78a46eb779843b2c4f21fac5773e25d6d5b7c8f0922876c91541790d2ca27eef" dependencies = [ "alloy-rlp", "ark-ff 0.3.0", @@ -6238,7 +7095,8 @@ dependencies = [ "parity-scale-codec", "primitive-types", "proptest", - "rand", + "rand 0.8.5", + "rand 0.9.1", "rlp", "ruint-macro", "serde", @@ -6291,7 +7149,7 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" dependencies = [ - "semver 1.0.25", + "semver 1.0.26", ] [[package]] @@ -6300,10 +7158,23 @@ version = "0.38.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", + "errno", + "libc", + "linux-raw-sys 0.4.15", + "windows-sys 0.59.0", +] + +[[package]] +name = "rustix" +version = "1.0.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c71e83d6afe7ff64890ec6b71d6a69bb8a610ab78ce364b3352876bb4c801266" +dependencies = [ + "bitflags 2.9.0", "errno", "libc", - "linux-raw-sys", + "linux-raw-sys 0.9.4", "windows-sys 0.59.0", ] @@ -6315,10 +7186,24 @@ checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" +dependencies = [ + "aws-lc-rs", + "once_cell", + "rustls-pki-types", + "rustls-webpki 0.103.3", + "subtle", + "zeroize", +] + [[package]] name = "rustls-native-certs" version = "0.6.3" @@ -6328,7 +7213,19 @@ dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", - "security-framework", + "security-framework 2.11.1", +] + +[[package]] +name = "rustls-native-certs" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7fcff2dd52b58a8d98a70243663a0d234c4e2b79235637849d15913394a247d3" +dependencies = [ + "openssl-probe", + "rustls-pki-types", + "schannel", + "security-framework 3.2.0", ] [[package]] @@ -6340,6 +7237,15 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pki-types" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "229a4a4c221013e7e1f1a043678c5cc39fe5171437c88fb47151a21e6f5b5c79" +dependencies = [ + "zeroize", +] + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -6350,11 +7256,23 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.103.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4a72fe2bcf7a6ac6fd7d0b9e5cb68aeb7d4c0a0271730218b3e92d43b4eb435" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4" +checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" [[package]] name = "rusty-fork" @@ -6370,9 +7288,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" [[package]] name = "same-file" @@ -6404,7 +7322,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -6453,7 +7371,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" dependencies = [ "base16ct 0.2.0", - "der 0.7.9", + "der 0.7.10", "generic-array", "pkcs8 0.10.2", "subtle", @@ -6462,11 +7380,12 @@ dependencies = [ [[package]] name = "secp256k1" -version = "0.29.1" +version = "0.30.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9465315bc9d4566e1724f0fffcbcc446268cb522e60f9a27bcded6b19c108113" +checksum = "b50c5943d326858130af85e049f2661ba3c78b26589b8ab98e65e80ae44a1252" dependencies = [ - "rand", + "bitcoin_hashes", + "rand 0.8.5", "secp256k1-sys", ] @@ -6485,8 +7404,21 @@ version = "2.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" dependencies = [ - "bitflags 2.8.0", - "core-foundation", + "bitflags 2.9.0", + "core-foundation 0.9.4", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework" +version = "3.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271720403f46ca04f7ba6f55d438f8bd878d6b8ca0a1046e8228c4145bcbb316" +dependencies = [ + "bitflags 2.9.0", + "core-foundation 0.10.0", "core-foundation-sys", "libc", "security-framework-sys", @@ -6513,9 +7445,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.25" +version = "1.0.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03" +checksum = "56e6fa9c48d24d85fb3de5ad847117517440f6beceb7798af16b4a87d616b8d0" dependencies = [ "serde", ] @@ -6531,9 +7463,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" dependencies = [ "serde_derive", ] @@ -6558,22 +7490,22 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.218" +version = "1.0.219" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] name = "serde_json" -version = "1.0.139" +version = "1.0.140" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" +checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.9.0", "itoa", "memchr", "ryu", @@ -6621,7 +7553,7 @@ dependencies = [ "chrono", "hex", "indexmap 1.9.3", - "indexmap 2.7.1", + "indexmap 2.9.0", "serde", "serde_derive", "serde_json", @@ -6638,7 +7570,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -6654,9 +7586,22 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.8" +version = "0.9.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800" +dependencies = [ + "block-buffer 0.9.0", + "cfg-if", + "cpufeatures", + "digest 0.9.0", + "opaque-debug", +] + +[[package]] +name = "sha2" +version = "0.10.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" +checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", @@ -6700,9 +7645,9 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook-registry" -version = "1.4.2" +version = "1.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +checksum = "9203b8055f63a2a00e2f593bb0510367fe707d7ff1e5c872de2f537b339e5410" dependencies = [ "libc", ] @@ -6714,7 +7659,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "74233d3b3b2f6d4b006dc19dee745e73e2a6bfb6f93607cd3b02bd5b00797d7c" dependencies = [ "digest 0.10.7", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -6724,7 +7669,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ "digest 0.10.7", - "rand_core", + "rand_core 0.6.4", ] [[package]] @@ -6750,15 +7695,15 @@ dependencies = [ [[package]] name = "smallvec" -version = "1.14.0" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" +checksum = "8917285742e9f3e1683f0a9c4e6b57960b7314d0b08d30d1ecd426713ee2eee9" [[package]] name = "snark-verifier" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28e4c4ed1edca41687fe2d8a09ba30badb0a5cc7fa56dd1159d62aeab7c99ace" +checksum = "4d798d8ce8e29b8820ecc1028ac44cc4fc0f0296728af6fe6a0c4db05782c0a4" dependencies = [ "halo2-base", "halo2-ecc", @@ -6769,7 +7714,7 @@ dependencies = [ "num-integer", "num-traits", "pairing 0.23.0", - "rand", + "rand 0.8.5", "revm", "ruint", "serde", @@ -6778,9 +7723,9 @@ dependencies = [ [[package]] name = "snark-verifier-sdk" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "babff70ce6292fce03f692d68569f76b8f6710dbac7be7fe5f32c915909c9065" +checksum = "a338d065044702bf751e87cf353daac63e2fc4c53a3e323cbcd98c603ee6e66c" dependencies = [ "bincode", "ethereum-types", @@ -6792,8 +7737,8 @@ dependencies = [ "num-bigint 0.4.6", "num-integer", "num-traits", - "rand", - "rand_chacha", + "rand 0.8.5", + "rand_chacha 0.3.1", "serde", "serde_json", "snark-verifier", @@ -6801,9 +7746,9 @@ dependencies = [ [[package]] name = "socket2" -version = "0.5.8" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +checksum = "4f5fd57c80058a56cf5c777ab8a126398ece8e442983605d280a44ce79d0edef" dependencies = [ "libc", "windows-sys 0.52.0", @@ -6846,7 +7791,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" dependencies = [ "base64ct", - "der 0.7.9", + "der 0.7.10", ] [[package]] @@ -6904,20 +7849,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.98", -] - -[[package]] -name = "substrate-bn" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b5bbfa79abbae15dd642ea8176a21a635ff3c00059961d1ea27ad04e5b441c" -dependencies = [ - "byteorder", - "crunchy", - "lazy_static", - "rand", - "rustc-hex", + "syn 2.0.101", ] [[package]] @@ -6950,10 +7882,10 @@ dependencies = [ "hex", "once_cell", "reqwest", - "semver 1.0.25", + "semver 1.0.26", "serde", "serde_json", - "sha2", + "sha2 0.10.9", "thiserror 1.0.69", "url", "zip", @@ -6967,7 +7899,7 @@ checksum = "aa64b5e8eecd3a8af7cfc311e29db31a268a62d5953233d3e8243ec77a71c4e3" dependencies = [ "build_const", "hex", - "semver 1.0.25", + "semver 1.0.26", "serde_json", "svm-rs", ] @@ -6985,9 +7917,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.98" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36147f1a48ae0ec2b5b3bc5b537d267457555a10dc06f3dbc8cb11ba3006d3b1" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -7003,7 +7935,7 @@ dependencies = [ "paste", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7014,13 +7946,13 @@ checksum = "2047c6ded9c721764247e62cd3b03c09ffc529b2ba5b10ec482ae507a4a70160" [[package]] name = "synstructure" -version = "0.13.1" +version = "0.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7030,7 +7962,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", - "core-foundation", + "core-foundation 0.9.4", "system-configuration-sys", ] @@ -7058,15 +7990,14 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" [[package]] name = "tempfile" -version = "3.17.1" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" +checksum = "e8a64e3985349f2441a1a9ef0b853f869006c3855f2cda6862a94d26ebb9d6a1" dependencies = [ - "cfg-if", "fastrand", - "getrandom 0.3.1", + "getrandom 0.3.3", "once_cell", - "rustix", + "rustix 1.0.7", "windows-sys 0.59.0", ] @@ -7081,6 +8012,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "terminal_size" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45c6481c4829e4cc63825e62c49186a34538b7b2750b73b266581ffb612fb5ed" +dependencies = [ + "rustix 1.0.7", + "windows-sys 0.59.0", +] + [[package]] name = "test-case" version = "3.3.1" @@ -7099,7 +8040,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7110,7 +8051,7 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "test-case-core", ] @@ -7122,7 +8063,7 @@ checksum = "e7f46083d221181166e5b6f6b1e5f1d499f3a76888826e6cb1d057554157cd0f" dependencies = [ "env_logger", "test-log-macros", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -7133,7 +8074,7 @@ checksum = "888d0c3c6db53c0fdab160d2ed5e12ba745383d3e85813f2ea0f2b1475ab553f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7147,11 +8088,11 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc" +checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" dependencies = [ - "thiserror-impl 2.0.11", + "thiserror-impl 2.0.12", ] [[package]] @@ -7162,18 +8103,18 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] name = "thiserror-impl" -version = "2.0.11" +version = "2.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2" +checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7217,9 +8158,9 @@ dependencies = [ [[package]] name = "time" -version = "0.3.37" +version = "0.3.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21" +checksum = "8a7619e19bc266e0f9c5e6686659d394bc57973859340060a69221e57dbc0c40" dependencies = [ "deranged", "itoa", @@ -7234,15 +8175,15 @@ dependencies = [ [[package]] name = "time-core" -version = "0.1.2" +version = "0.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" +checksum = "c9e9a38711f559d9e3ce1cdb06dd7c5b8ea546bc90052da6d06bb76da74bb07c" [[package]] name = "time-macros" -version = "0.2.19" +version = "0.2.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de" +checksum = "3526739392ec93fd8b359c8e98514cb3e8e021beb4e5f597b00a0221f8ed8a49" dependencies = [ "num-conv", "time-core", @@ -7259,9 +8200,9 @@ dependencies = [ [[package]] name = "tinystr" -version = "0.7.6" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" +checksum = "5d4f6d1145dcb577acf783d4e601bc1d76a13337bb54e6233add580b07344c8b" dependencies = [ "displaydoc", "zerovec", @@ -7279,9 +8220,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.44.2" +version = "1.45.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6b88822cbe49de4185e3a4cbf8321dd487cf5fe0c5c65695fef6346371e9c48" +checksum = "2513ca694ef9ede0fb23fe71a4ee4107cb102b9dc1930f6d0fd77aae068ae165" dependencies = [ "backtrace", "bytes", @@ -7302,7 +8243,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7311,15 +8252,25 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.12", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b" +dependencies = [ + "rustls 0.23.27", "tokio", ] [[package]] name = "tokio-util" -version = "0.7.13" +version = "0.7.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078" +checksum = "66a539a9ad6d5d281510d5bd368c973d636c02dbf8a67300bfb6b950696ad7df" dependencies = [ "bytes", "futures-core", @@ -7334,7 +8285,7 @@ version = "0.7.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd79e69d3b627db300ff956027cc6c3798cef26d22526befdfcd12feeb6d2257" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.9.0", "serde", "serde_spanned", "toml_datetime", @@ -7343,21 +8294,21 @@ dependencies = [ [[package]] name = "toml" -version = "0.8.20" +version = "0.8.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd87a5cdd6ffab733b2f74bc4fd7ee5fff6634124999ac278c35fc78c6120148" +checksum = "05ae329d1f08c4d17a59bed7ff5b5a769d062e64a62d34a3261b219e62cd5aae" dependencies = [ "serde", "serde_spanned", "toml_datetime", - "toml_edit 0.22.24", + "toml_edit 0.22.26", ] [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" dependencies = [ "serde", ] @@ -7368,7 +8319,7 @@ version = "0.19.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.9.0", "serde", "serde_spanned", "toml_datetime", @@ -7377,17 +8328,40 @@ dependencies = [ [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "310068873db2c5b3e7659d2cc35d21855dbafa50d1ce336397c666e3cb08137e" dependencies = [ - "indexmap 2.7.1", + "indexmap 2.9.0", "serde", "serde_spanned", "toml_datetime", - "winnow 0.7.3", + "toml_write", + "winnow 0.7.10", +] + +[[package]] +name = "toml_write" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfb942dfe1d8e29a7ee7fcbde5bd2b9a25fb89aa70caea2eba3bee836ff41076" + +[[package]] +name = "tower" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d039ad9159c98b70ecfd540b2573b97f7f52c3e8d9f8ad57a24b916a536975f9" +dependencies = [ + "tower-layer", + "tower-service", ] +[[package]] +name = "tower-layer" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" + [[package]] name = "tower-service" version = "0.3.3" @@ -7413,7 +8387,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -7436,7 +8410,7 @@ dependencies = [ "smallvec", "thiserror 1.0.69", "tracing", - "tracing-subscriber", + "tracing-subscriber 0.3.19", ] [[package]] @@ -7450,6 +8424,15 @@ dependencies = [ "tracing-core", ] +[[package]] +name = "tracing-subscriber" +version = "0.2.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e0d2eaa99c3c2e41547cfa109e910a68ea03823cccad4a0525dcbc9b01e8c71" +dependencies = [ + "tracing-core", +] + [[package]] name = "tracing-subscriber" version = "0.3.19" @@ -7525,9 +8508,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" +checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" [[package]] name = "unicode-width" @@ -7574,12 +8557,6 @@ version = "2.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - [[package]] name = "utf8_iter" version = "1.0.4" @@ -7594,9 +8571,12 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.13.2" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c1f41ffb7cf259f1ecc2876861a17e7142e63ead296f671f81f6ae85903e0d6" +checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +dependencies = [ + "getrandom 0.3.3", +] [[package]] name = "valuable" @@ -7671,9 +8651,9 @@ checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasi" -version = "0.13.3+wasi-0.2.2" +version = "0.14.2+wasi-0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2" +checksum = "9683f9a5a998d873c0d21fcbe3c083009670149a8fab228644b8bd36b2c48cb3" dependencies = [ "wit-bindgen-rt", ] @@ -7700,7 +8680,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "wasm-bindgen-shared", ] @@ -7735,7 +8715,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -7765,6 +8745,18 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "winapi" version = "0.3.9" @@ -7798,11 +8790,61 @@ checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-core" -version = "0.52.0" +version = "0.61.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +checksum = "4763c1de310c86d75a878046489e2e5ba02c649d185f21c67d4cf8a56d098980" dependencies = [ - "windows-targets 0.52.6", + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + +[[package]] +name = "windows-link" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" + +[[package]] +name = "windows-result" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c64fd11a4fd95df68efcfee5f44a294fe71b8bc6a91993e2791938abcc712252" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ba9642430ee452d5a7aa78d72907ebe8cfda358e8cb7918a2050581322f97" +dependencies = [ + "windows-link", ] [[package]] @@ -7856,13 +8898,29 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -7875,6 +8933,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -7887,6 +8951,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -7899,12 +8969,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -7917,6 +8999,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -7929,6 +9017,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -7941,6 +9035,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -7953,6 +9053,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.5.40" @@ -7964,9 +9070,9 @@ dependencies = [ [[package]] name = "winnow" -version = "0.7.3" +version = "0.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e7f4ea97f6f78012141bcdb6a216b2609f0979ada50b20ca5b52dde2eac2bb1" +checksum = "c06928c8748d81b05c9be96aad92e1b6ff01833332f281e8cfca3be4b35fc9ec" dependencies = [ "memchr", ] @@ -7983,24 +9089,18 @@ dependencies = [ [[package]] name = "wit-bindgen-rt" -version = "0.33.0" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c" +checksum = "6f42320e61fe2cfd34354ecb597f86f413484a798ba44a8ca1165c58d42da6c1" dependencies = [ - "bitflags 2.8.0", + "bitflags 2.9.0", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - [[package]] name = "writeable" -version = "0.5.5" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" +checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" [[package]] name = "wyz" @@ -8031,9 +9131,9 @@ checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" [[package]] name = "yoke" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" +checksum = "5f41bb01b8226ef4bfd589436a297c53d118f65921786300e427be8d487695cc" dependencies = [ "serde", "stable_deref_trait", @@ -8043,55 +9143,54 @@ dependencies = [ [[package]] name = "yoke-derive" -version = "0.7.5" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" +checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "synstructure", ] [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "byteorder", "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] name = "zerofrom" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" +checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" +checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", "synstructure", ] @@ -8112,14 +9211,25 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", +] + +[[package]] +name = "zerotrie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "36f0bbd478583f79edad978b407914f61b2972f5af6fa089686016be8f9af595" +dependencies = [ + "displaydoc", + "yoke", + "zerofrom", ] [[package]] name = "zerovec" -version = "0.10.4" +version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" +checksum = "4a05eb080e015ba39cc9e23bbe5e7fb04d5fb040350f99f34e338d5fdd294428" dependencies = [ "yoke", "zerofrom", @@ -8128,13 +9238,13 @@ dependencies = [ [[package]] name = "zerovec-derive" -version = "0.10.3" +version = "0.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" +checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.98", + "syn 2.0.101", ] [[package]] @@ -8176,9 +9286,9 @@ dependencies = [ "jubjub", "lazy_static", "pasta_curves 0.5.1", - "rand", + "rand 0.8.5", "serde", - "sha2", + "sha2 0.10.9", "sha3", "subtle", ] diff --git a/Cargo.toml b/Cargo.toml index 3f79027b71..87eafe6f98 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -107,8 +107,8 @@ lto = "thin" [workspace.dependencies] # Stark Backend -openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } -openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.0", default-features = false } +openvm-stark-backend = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.1", default-features = false } +openvm-stark-sdk = { git = "https://github.com/openvm-org/stark-backend.git", tag = "v1.0.1", default-features = false } # OpenVM openvm-sdk = { path = "crates/sdk", default-features = false } diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 319490220a..d1037ea75b 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -9,40 +9,51 @@ license.workspace = true [dependencies] openvm-benchmarks-utils.workspace = true -cargo-openvm.workspace = true +# cargo-openvm.workspace = true openvm-circuit.workspace = true -openvm-sdk.workspace = true +# openvm-sdk.workspace = true openvm-stark-sdk.workspace = true openvm-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true +openvm-bigint-circuit.workspace = true +openvm-bigint-transpiler.workspace = true openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true +openvm-sha256-circuit.workspace = true +openvm-sha256-transpiler.workspace = true +# bitcode.workspace = true clap = { version = "4.5.9", features = ["derive", "env"] } eyre.workspace = true tracing.workspace = true derive_more = { workspace = true, features = ["from"] } +serde = { workspace = true, features = ["derive"] } tracing-subscriber = { version = "0.3.17", features = ["std", "env-filter"] } [dev-dependencies] criterion = { version = "0.5", features = ["html_reports"] } +divan = { package = "codspeed-divan-compat", version = "*" } [features] default = ["jemalloc"] -profiling = ["openvm-sdk/profiling"] mimalloc = ["openvm-circuit/mimalloc"] jemalloc = ["openvm-circuit/jemalloc"] jemalloc-prof = ["openvm-circuit/jemalloc-prof"] nightly-features = ["openvm-circuit/nightly-features"] +profiling = ["openvm-circuit/function-span", "openvm-transpiler/function-span"] -[[bench]] -name = "fibonacci_execute" -harness = false +# [[bench]] +# name = "fibonacci_execute" +# harness = false + +# [[bench]] +# name = "regex_execute" +# harness = false [[bench]] -name = "regex_execute" +name = "execute" harness = false [package.metadata.cargo-shear] diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs new file mode 100644 index 0000000000..995c2bbbd8 --- /dev/null +++ b/benchmarks/execute/benches/execute.rs @@ -0,0 +1,108 @@ +use eyre::Result; +use openvm_benchmarks_utils::{get_elf_path, get_programs_dir, read_elf_file}; +use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; +use openvm_bigint_transpiler::Int256TranspilerExtension; +use openvm_circuit::{ + arch::{instructions::exe::VmExe, SystemConfig, VmExecutor}, + derive::VmConfig, +}; +use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; +use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, + Rv32MExecutor, Rv32MPeriphery, +}; +use openvm_rv32im_transpiler::{ + Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +}; +use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; +use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_stark_sdk::{ + openvm_stark_backend::{self, p3_field::PrimeField32}, + p3_baby_bear::BabyBear, +}; +use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use serde::{Deserialize, Serialize}; + +static AVAILABLE_PROGRAMS: &[&str] = &[ + "fibonacci_recursive", + "fibonacci_iterative", + "quicksort", + "bubblesort", + "factorial_iterative_u256", + "revm_snailtracer", + "keccak256", + "keccak256_iter", + "sha256", + "sha256_iter", + // "revm_transfer", + // "pairing", +]; + +// TODO(ayush): remove from here +#[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] +pub struct ExecuteConfig { + #[system] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub bigint: Int256, + #[extension] + pub keccak: Keccak256, + #[extension] + pub sha256: Sha256, +} + +impl Default for ExecuteConfig { + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I::default(), + rv32m: Rv32M::default(), + io: Rv32Io::default(), + bigint: Int256::default(), + keccak: Keccak256::default(), + sha256: Sha256::default(), + } + } +} + +fn main() { + divan::main(); +} + +/// Run a specific OpenVM program +fn run_program(program: &str) -> Result<()> { + let program_dir = get_programs_dir().join(program); + let elf_path = get_elf_path(&program_dir); + let elf = read_elf_file(&elf_path)?; + + let vm_config = ExecuteConfig::default(); + + let transpiler = Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Int256TranspilerExtension) + .with_extension(Keccak256TranspilerExtension) + .with_extension(Sha256TranspilerExtension); + + let exe = VmExe::from_elf(elf, transpiler)?; + + let executor = VmExecutor::new(vm_config); + executor + .execute_e1(exe, vec![], None) + .expect("Failed to execute program"); + + Ok(()) +} + +#[divan::bench(args = AVAILABLE_PROGRAMS, sample_count=10)] +fn benchmark_execute(program: &str) { + run_program(program).unwrap(); +} diff --git a/benchmarks/execute/benches/fibonacci_execute.rs b/benchmarks/execute/benches/fibonacci_execute.rs index 70952b53c9..d7eb47f04f 100644 --- a/benchmarks/execute/benches/fibonacci_execute.rs +++ b/benchmarks/execute/benches/fibonacci_execute.rs @@ -5,7 +5,8 @@ use openvm_rv32im_circuit::Rv32ImConfig; use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sdk::StdIn; +// TODO(ayush): add this back +// use openvm_sdk::StdIn; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -28,10 +29,11 @@ fn benchmark_function(c: &mut Criterion) { group.bench_function("execute", |b| { b.iter(|| { - let n = 100_000u64; - let mut stdin = StdIn::default(); - stdin.write(&n); - executor.execute(exe.clone(), stdin).unwrap(); + // TODO(ayush): add this back + // let n = 100_000u64; + // let mut stdin = StdIn::default(); + // stdin.write(&n); + executor.execute(exe.clone(), vec![]).unwrap(); }) }); diff --git a/benchmarks/execute/benches/regex_execute.rs b/benchmarks/execute/benches/regex_execute.rs index a3a110e344..d4116b5aab 100644 --- a/benchmarks/execute/benches/regex_execute.rs +++ b/benchmarks/execute/benches/regex_execute.rs @@ -1,47 +1,47 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use openvm_benchmarks_utils::{build_elf, get_programs_dir}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{transpiler::Transpiler, FromElf}; +// TODO(ayush): add this back +// use criterion::{black_box, criterion_group, criterion_main, Criterion}; +// use openvm_benchmarks_utils::{build_elf, get_programs_dir}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{transpiler::Transpiler, FromElf}; -fn benchmark_function(c: &mut Criterion) { - let program_dir = get_programs_dir().join("regex"); - let elf = build_elf(&program_dir, "release").unwrap(); +// fn benchmark_function(c: &mut Criterion) { +// let program_dir = get_programs_dir().join("regex"); +// let elf = build_elf(&program_dir, "release").unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); +// let exe = VmExe::from_elf( +// elf, +// Transpiler::::default() +// .with_extension(Rv32ITranspilerExtension) +// .with_extension(Rv32MTranspilerExtension) +// .with_extension(Rv32IoTranspilerExtension) +// .with_extension(Keccak256TranspilerExtension), +// ) +// .unwrap(); - let mut group = c.benchmark_group("regex"); - group.sample_size(10); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); +// let mut group = c.benchmark_group("regex"); +// group.sample_size(10); +// let config = Keccak256Rv32Config::default(); +// let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); +// let data = include_str!("../../guest/regex/regex_email.txt"); - let fe_bytes = data.to_owned().into_bytes(); - group.bench_function("execute", |b| { - b.iter(|| { - executor - .execute(exe.clone(), black_box(StdIn::from_bytes(&fe_bytes))) - .unwrap(); - }) - }); +// let fe_bytes = data.to_owned().into_bytes(); +// group.bench_function("execute", |b| { +// b.iter(|| { +// let input = black_box(Stdin::from_bytes(&fe_bytes)); +// executor.execute(exe.clone(), input).unwrap(); +// }) +// }); - group.finish(); -} +// group.finish(); +// } -criterion_group!(benches, benchmark_function); -criterion_main!(benches); +// criterion_group!(benches, benchmark_function); +// criterion_main!(benches); diff --git a/benchmarks/execute/examples/regex_execute.rs b/benchmarks/execute/examples/regex_execute.rs index 59705a19fd..3a6fd4162f 100644 --- a/benchmarks/execute/examples/regex_execute.rs +++ b/benchmarks/execute/examples/regex_execute.rs @@ -1,35 +1,35 @@ -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_keccak256_circuit::Keccak256Rv32Config; -use openvm_keccak256_transpiler::Keccak256TranspilerExtension; -use openvm_rv32im_transpiler::{ - Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, -}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use openvm_transpiler::{ - elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, -}; +// use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; +// use openvm_keccak256_circuit::Keccak256Rv32Config; +// use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +// use openvm_rv32im_transpiler::{ +// Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +// }; +// use openvm_sdk::StdIn; +// use openvm_stark_sdk::p3_baby_bear::BabyBear; +// use openvm_transpiler::{ +// elf::Elf, openvm_platform::memory::MEM_SIZE, transpiler::Transpiler, FromElf, +// }; fn main() { - let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); - let exe = VmExe::from_elf( - elf, - Transpiler::::default() - .with_extension(Rv32ITranspilerExtension) - .with_extension(Rv32MTranspilerExtension) - .with_extension(Rv32IoTranspilerExtension) - .with_extension(Keccak256TranspilerExtension), - ) - .unwrap(); + // let elf = Elf::decode(include_bytes!("regex-elf"), MEM_SIZE as u32).unwrap(); + // let exe = VmExe::from_elf( + // elf, + // Transpiler::::default() + // .with_extension(Rv32ITranspilerExtension) + // .with_extension(Rv32MTranspilerExtension) + // .with_extension(Rv32IoTranspilerExtension) + // .with_extension(Keccak256TranspilerExtension), + // ) + // .unwrap(); - let config = Keccak256Rv32Config::default(); - let executor = VmExecutor::::new(config); + // let config = Keccak256Rv32Config::default(); + // let executor = VmExecutor::::new(config); - let data = include_str!("../../guest/regex/regex_email.txt"); + // let data = include_str!("../../guest/regex/regex_email.txt"); - let timer = std::time::Instant::now(); - executor - .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) - .unwrap(); - println!("execute_time: {:?}", timer.elapsed()); + // let timer = std::time::Instant::now(); + // executor + // .execute(exe.clone(), StdIn::from_bytes(data.as_bytes())) + // .unwrap(); + // println!("execute_time: {:?}", timer.elapsed()); } diff --git a/benchmarks/execute/src/main.rs b/benchmarks/execute/src/main.rs index 80db3ec5a4..bbeb4ddba5 100644 --- a/benchmarks/execute/src/main.rs +++ b/benchmarks/execute/src/main.rs @@ -1,30 +1,46 @@ -use cargo_openvm::{default::DEFAULT_APP_CONFIG_PATH, util::read_config_toml_or_default}; -use clap::{Parser, ValueEnum}; +use clap::Parser; use eyre::Result; use openvm_benchmarks_utils::{get_elf_path, get_programs_dir, read_elf_file}; -use openvm_circuit::arch::{instructions::exe::VmExe, VmExecutor}; -use openvm_sdk::StdIn; -use openvm_stark_sdk::bench::run_with_metric_collection; -use openvm_transpiler::FromElf; - -#[derive(Debug, Clone, ValueEnum)] -enum BuildProfile { - Debug, - Release, -} +use openvm_bigint_circuit::{Int256, Int256Executor, Int256Periphery}; +use openvm_bigint_transpiler::Int256TranspilerExtension; +use openvm_circuit::{ + arch::{instructions::exe::VmExe, SystemConfig, VmExecutor}, + derive::VmConfig, +}; +use openvm_keccak256_circuit::{Keccak256, Keccak256Executor, Keccak256Periphery}; +use openvm_keccak256_transpiler::Keccak256TranspilerExtension; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, + Rv32MExecutor, Rv32MPeriphery, +}; +use openvm_rv32im_transpiler::{ + Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, +}; +use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256Periphery}; +use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_stark_sdk::{ + bench::run_with_metric_collection, + openvm_stark_backend::{self, p3_field::PrimeField32}, + p3_baby_bear::BabyBear, +}; +use openvm_transpiler::{transpiler::Transpiler, FromElf}; +use serde::{Deserialize, Serialize}; + +// const DEFAULT_APP_CONFIG_PATH: &str = "./openvm.toml"; static AVAILABLE_PROGRAMS: &[&str] = &[ "fibonacci_recursive", "fibonacci_iterative", "quicksort", "bubblesort", - "pairing", + "factorial_iterative_u256", + "revm_snailtracer", "keccak256", "keccak256_iter", "sha256", "sha256_iter", "revm_transfer", - "revm_snailtracer", + // "pairing", ]; #[derive(Parser)] @@ -51,6 +67,39 @@ struct Cli { verbose: bool, } +#[derive(Clone, Debug, VmConfig, Serialize, Deserialize)] +pub struct ExecuteConfig { + #[system] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub bigint: Int256, + #[extension] + pub keccak: Keccak256, + #[extension] + pub sha256: Sha256, +} + +impl Default for ExecuteConfig { + // TODO(ayush): this should be auto-derived as vmconfig should have a with_continuations method + fn default() -> Self { + Self { + system: SystemConfig::default().with_continuations(), + rv32i: Rv32I::default(), + rv32m: Rv32M::default(), + io: Rv32Io::default(), + bigint: Int256::default(), + keccak: Keccak256::default(), + sha256: Sha256::default(), + } + } +} + fn main() -> Result<()> { let cli = Cli::parse(); @@ -106,13 +155,72 @@ fn main() -> Result<()> { let elf_path = get_elf_path(&program_dir); let elf = read_elf_file(&elf_path)?; - let config_path = program_dir.join(DEFAULT_APP_CONFIG_PATH); - let vm_config = read_config_toml_or_default(&config_path)?.app_vm_config; + // let config_path = program_dir.join(DEFAULT_APP_CONFIG_PATH); + // let vm_config = read_config_toml_or_default(&config_path)?.app_vm_config; + // let transpiler = vm_config.transpiler; + let vm_config = ExecuteConfig::default(); - let exe = VmExe::from_elf(elf, vm_config.transpiler())?; + let transpiler = Transpiler::::default() + .with_extension(Rv32ITranspilerExtension) + .with_extension(Rv32MTranspilerExtension) + .with_extension(Rv32IoTranspilerExtension) + .with_extension(Int256TranspilerExtension) + .with_extension(Keccak256TranspilerExtension) + .with_extension(Sha256TranspilerExtension); + + let exe = VmExe::from_elf(elf, transpiler)?; let executor = VmExecutor::new(vm_config); - executor.execute(exe, StdIn::default())?; + executor + .execute_e1(exe.clone(), vec![], None) + // .execute(exe.clone(), vec![]) + // .execute_metered(exe.clone(), vec![], widths, interactions) + .expect("Failed to execute program"); + + // let vm = VirtualMachine::new(default_engine(), vm_config.clone()); + // let pk = vm.keygen(); + // let (widths, interactions): (Vec, Vec) = { + // let vk = pk.get_vk(); + // vk.inner + // .per_air + // .iter() + // .map(|vk| { + // let total_width = vk.params.width.preprocessed.unwrap_or(0) + // + vk.params.width.cached_mains.iter().sum::() + // + vk.params.width.common_main + // // TODO(ayush): no magic value 4. should come from stark config + // + vk.params.width.after_challenge.iter().sum::() * 4; + // (total_width, vk.symbolic_constraints.interactions.len()) + // }) + // .unzip() + // }; + + // // E2 to find segment points + // let segments = executor.execute_metered(exe.clone(), vec![], widths, interactions)?; + // for Segment { + // clk_start, + // num_cycles, + // .. + // } in segments + // { + // // E1 till clk_start + // let state = executor.execute_e1(exe.clone(), vec![], Some(clk_start))?; + // assert!(state.clk == clk_start); + // // E3/tracegen from clk_start for num_cycles beginning with state + // let mut result = + // executor.execute_and_generate_segment::( + // exe.clone(), + // state, + // num_cycles, + // )?; + // // let proof_input = result.per_segment.pop().unwrap(); + // // let proof = tracing::info_span!("prove_single") + // // .in_scope(|| vm.prove_single(&pk, proof_input)); + + // // let proof_bytes = bitcode::serialize(&proof)?; + // // tracing::info!("Proof size: {} bytes", proof_bytes.len()); + // } + tracing::info!("Completed program: {}", program); } tracing::info!("All programs executed successfully"); diff --git a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf index 0f81a3926f..1608170268 100755 Binary files a/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf and b/benchmarks/guest/bubblesort/elf/openvm-bubblesort-program.elf differ diff --git a/benchmarks/guest/bubblesort/src/main.rs b/benchmarks/guest/bubblesort/src/main.rs index 0dd7e51146..d859641504 100644 --- a/benchmarks/guest/bubblesort/src/main.rs +++ b/benchmarks/guest/bubblesort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 100; +const ARRAY_SIZE: usize = 1_000; fn bubblesort(arr: &mut [T]) { let len = arr.len(); diff --git a/benchmarks/guest/factorial_iterative_u256/Cargo.toml b/benchmarks/guest/factorial_iterative_u256/Cargo.toml new file mode 100644 index 0000000000..a0abd084d8 --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/Cargo.toml @@ -0,0 +1,17 @@ +[workspace] +[package] +name = "openvm-factorial-iterative-u256-program" +version = "0.0.0" +edition = "2021" + +[dependencies] +openvm = { path = "../../../crates/toolchain/openvm", features = ["std"] } +openvm-bigint-guest = { path = "../../../extensions/bigint/guest" } + +[features] +default = [] + +[profile.profiling] +inherits = "release" +debug = 2 +strip = false diff --git a/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf new file mode 100755 index 0000000000..222122b27b Binary files /dev/null and b/benchmarks/guest/factorial_iterative_u256/elf/openvm-factorial-iterative-u256-program.elf differ diff --git a/benchmarks/guest/factorial_iterative_u256/openvm.toml b/benchmarks/guest/factorial_iterative_u256/openvm.toml new file mode 100644 index 0000000000..b226887890 --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/openvm.toml @@ -0,0 +1,4 @@ +[app_vm_config.rv32i] +[app_vm_config.rv32m] +[app_vm_config.io] +[app_vm_config.bigint] diff --git a/benchmarks/guest/factorial_iterative_u256/src/main.rs b/benchmarks/guest/factorial_iterative_u256/src/main.rs new file mode 100644 index 0000000000..359de19ae7 --- /dev/null +++ b/benchmarks/guest/factorial_iterative_u256/src/main.rs @@ -0,0 +1,16 @@ +use core::hint::black_box; +use openvm as _; +use openvm_bigint_guest::U256; + +// This will overflow but that is fine +const N: u32 = 65_000; + +pub fn main() { + let mut acc = U256::from_u32(1); + let mut i = U256::from_u32(N); + while i > black_box(U256::ZERO) { + acc *= i.clone(); + i -= U256::from_u32(1); + } + black_box(acc); +} diff --git a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf index ac9fbf3e89..a7b1753491 100755 Binary files a/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf and b/benchmarks/guest/fibonacci_iterative/elf/openvm-fibonacci-iterative-program.elf differ diff --git a/benchmarks/guest/fibonacci_iterative/src/main.rs b/benchmarks/guest/fibonacci_iterative/src/main.rs index 09ceb5df41..f7ab8ec0f6 100644 --- a/benchmarks/guest/fibonacci_iterative/src/main.rs +++ b/benchmarks/guest/fibonacci_iterative/src/main.rs @@ -1,15 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 100_000; +const N: u32 = 900_000; pub fn main() { - let mut a: u64 = 0; - let mut b: u64 = 1; + let mut a: u32 = 0; + let mut b: u32 = 1; for _ in 0..black_box(N) { - let c: u64 = a.wrapping_add(b); + let c: u32 = a.wrapping_add(b); a = b; b = c; } - black_box(a); + reveal_u32(a, 0); } diff --git a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf index 7dee9d4286..8696e249f0 100755 Binary files a/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf and b/benchmarks/guest/fibonacci_recursive/elf/openvm-fibonacci-recursive-program.elf differ diff --git a/benchmarks/guest/fibonacci_recursive/src/main.rs b/benchmarks/guest/fibonacci_recursive/src/main.rs index fae64a1b0f..9020bc91ef 100644 --- a/benchmarks/guest/fibonacci_recursive/src/main.rs +++ b/benchmarks/guest/fibonacci_recursive/src/main.rs @@ -1,14 +1,15 @@ use core::hint::black_box; -use openvm as _; +use openvm::io::reveal_u32; -const N: u64 = 25; +const N: u32 = 27; pub fn main() { let n = black_box(N); - black_box(fibonacci(n)); + let result = fibonacci(n); + reveal_u32(result, 0); } -fn fibonacci(n: u64) -> u64 { +fn fibonacci(n: u32) -> u32 { if n == 0 { 0 } else if n == 1 { diff --git a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf index 7425897f99..3686d70424 100755 Binary files a/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf and b/benchmarks/guest/keccak256/elf/openvm-keccak256-program.elf differ diff --git a/benchmarks/guest/keccak256/src/main.rs b/benchmarks/guest/keccak256/src/main.rs index ee7ec8b09a..1d88ab2432 100644 --- a/benchmarks/guest/keccak256/src/main.rs +++ b/benchmarks/guest/keccak256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256_guest::keccak256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf index 0cf372eec3..390c945907 100755 Binary files a/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf and b/benchmarks/guest/keccak256_iter/elf/openvm-keccak256-iter-program.elf differ diff --git a/benchmarks/guest/keccak256_iter/src/main.rs b/benchmarks/guest/keccak256_iter/src/main.rs index 7ef36a5fa7..3c740c092c 100644 --- a/benchmarks/guest/keccak256_iter/src/main.rs +++ b/benchmarks/guest/keccak256_iter/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_keccak256_guest::keccak256; -const ITERATIONS: usize = 10_000; +const ITERATIONS: usize = 65_000; pub fn main() { // Initialize with hash of an empty vector diff --git a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf index 54af6272d6..be7ca2922a 100755 Binary files a/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf and b/benchmarks/guest/quicksort/elf/openvm-quicksort-program.elf differ diff --git a/benchmarks/guest/quicksort/src/main.rs b/benchmarks/guest/quicksort/src/main.rs index 30218cf40e..a6579306c7 100644 --- a/benchmarks/guest/quicksort/src/main.rs +++ b/benchmarks/guest/quicksort/src/main.rs @@ -1,7 +1,7 @@ use core::hint::black_box; use openvm as _; -const ARRAY_SIZE: usize = 1_000; +const ARRAY_SIZE: usize = 3_500; fn quicksort(arr: &mut [T]) { if arr.len() <= 1 { diff --git a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf index 9524e8f552..1f548633e2 100755 Binary files a/benchmarks/guest/sha256/elf/openvm-sha256-program.elf and b/benchmarks/guest/sha256/elf/openvm-sha256-program.elf differ diff --git a/benchmarks/guest/sha256/src/main.rs b/benchmarks/guest/sha256/src/main.rs index 2c22b2e369..a13bfbc8d5 100644 --- a/benchmarks/guest/sha256/src/main.rs +++ b/benchmarks/guest/sha256/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_sha256_guest::sha256; -const INPUT_LENGTH_BYTES: usize = 100 * 1024; // 100 KB +const INPUT_LENGTH_BYTES: usize = 384 * 1024; pub fn main() { let mut input = Vec::with_capacity(INPUT_LENGTH_BYTES); diff --git a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf index 95b469ece5..65ebf86403 100755 Binary files a/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf and b/benchmarks/guest/sha256_iter/elf/openvm-sha256-iter-program.elf differ diff --git a/benchmarks/guest/sha256_iter/src/main.rs b/benchmarks/guest/sha256_iter/src/main.rs index 7d0d23dd7f..fe2acc265b 100644 --- a/benchmarks/guest/sha256_iter/src/main.rs +++ b/benchmarks/guest/sha256_iter/src/main.rs @@ -3,7 +3,7 @@ use openvm as _; use openvm_sha256_guest::sha256; -const ITERATIONS: usize = 20_000; +const ITERATIONS: usize = 150_000; pub fn main() { // Initialize with hash of an empty vector diff --git a/benchmarks/utils/src/build-elfs.rs b/benchmarks/utils/src/build-elfs.rs index 3bed7cf6fd..3ce24c7c5c 100644 --- a/benchmarks/utils/src/build-elfs.rs +++ b/benchmarks/utils/src/build-elfs.rs @@ -63,6 +63,12 @@ fn main() -> Result<()> { let programs_to_build = if cli.programs.is_empty() { available_programs } else { + for prog in &cli.programs { + if !available_programs.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in available programs", prog); + } + } + available_programs .into_iter() .filter(|(name, _)| cli.programs.contains(name)) @@ -70,6 +76,12 @@ fn main() -> Result<()> { }; // Filter out skipped programs + for prog in &cli.skip { + if !programs_to_build.iter().any(|(name, _)| name == prog) { + tracing::warn!("Program '{}' not found in programs to skip", prog); + } + } + let programs_to_build = programs_to_build .into_iter() .filter(|(name, _)| !cli.skip.contains(name)) diff --git a/crates/circuits/mod-builder/src/core_chip.rs b/crates/circuits/mod-builder/src/core_chip.rs index 30e9c65dbb..1f219c62f7 100644 --- a/crates/circuits/mod-builder/src/core_chip.rs +++ b/crates/circuits/mod-builder/src/core_chip.rs @@ -1,24 +1,29 @@ use itertools::Itertools; use num_bigint::BigUint; use num_traits::Zero; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, DynAdapterInterface, DynArray, + MinimalInstruction, Result, StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, + VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ var_range::SharedVariableRangeCheckerChip, SubAir, TraceSubRowGenerator, }; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, rap::BaseAirWithPublicValues, }; use openvm_stark_sdk::p3_baby_bear::BabyBear; -use serde::{Deserialize, Serialize}; -use serde_with::{serde_as, DisplayFromStr}; use crate::{ utils::{biguint_to_limbs_vec, limbs_to_biguint}, @@ -165,27 +170,21 @@ where } } -#[serde_as] -#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] -pub struct FieldExpressionRecord { - #[serde_as(as = "Vec")] - pub inputs: Vec, - pub flags: Vec, -} - -pub struct FieldExpressionCoreChip { - pub air: FieldExpressionCoreAir, +// TODO(arayi): use lifetimes and references for fields +pub struct FieldExpressionStep { + adapter: A, + pub expr: FieldExpr, + pub offset: usize, + pub local_opcode_idx: Vec, + pub opcode_flag_idx: Vec, pub range_checker: SharedVariableRangeCheckerChip, - pub name: String, - - /// Whether to finalize the trace. True if all-zero rows don't satisfy the constraints (e.g. - /// there is int_add) pub should_finalize: bool, } -impl FieldExpressionCoreChip { +impl FieldExpressionStep { pub fn new( + adapter: A, expr: FieldExpr, offset: usize, local_opcode_idx: Vec, @@ -194,145 +193,206 @@ impl FieldExpressionCoreChip { name: &str, should_finalize: bool, ) -> Self { - let air = FieldExpressionCoreAir::new(expr, offset, local_opcode_idx, opcode_flag_idx); + let opcode_flag_idx = if opcode_flag_idx.is_empty() && expr.needs_setup() { + // single op chip that needs setup, so there is only one default flag, must be 0. + vec![0] + } else { + // multi ops chip or no-setup chip, use as is. + opcode_flag_idx + }; + assert_eq!(opcode_flag_idx.len(), local_opcode_idx.len() - 1); tracing::info!( - "FieldExpressionCoreChip: opcode={name}, main_width={}", - BaseAir::::width(&air) + "FieldExpressionCoreStep: opcode={name}, main_width={}", + BaseAir::::width(&expr) ); Self { - air, + adapter, + expr, + offset, + local_opcode_idx, + opcode_flag_idx, range_checker, name: name.to_string(), should_finalize, } } + pub fn num_inputs(&self) -> usize { + self.expr.builder.num_input + } + + pub fn num_vars(&self) -> usize { + self.expr.builder.num_variables + } + + pub fn num_flags(&self) -> usize { + self.expr.builder.num_flags + } - pub fn expr(&self) -> &FieldExpr { - &self.air.expr + pub fn output_indices(&self) -> &[usize] { + &self.expr.builder.output_indices } } -impl VmCoreChip for FieldExpressionCoreChip +impl TraceStep for FieldExpressionStep where - I: VmAdapterInterface, - I::Reads: Into>, - AdapterRuntimeContext: From>>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into>, + WriteData: From>, + TraceContext<'a> = (), + >, { - type Record = FieldExpressionRecord; - type Air = FieldExpressionCoreAir; + fn get_opcode_name(&self, _opcode: usize) -> String { + self.name.clone() + } - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let field_element_limbs = self.air.expr.canonical_num_limbs(); - let limb_bits = self.air.expr.canonical_limb_bits(); - let data: DynArray<_> = reads.into(); - let data = data.0; - assert_eq!(data.len(), self.air.num_inputs() * field_element_limbs); - let data_u32: Vec = data.iter().map(|x| x.as_canonical_u32()).collect(); - - let mut inputs = vec![]; - for i in 0..self.air.num_inputs() { - let start = i * field_element_limbs; - let end = start + field_element_limbs; - let limb_slice = &data_u32[start..end]; - let input = limbs_to_biguint(limb_slice, limb_bits); - inputs.push(input); - } + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = row_slice.split_at_mut(A::WIDTH); - let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(self.air.offset); - let mut flags = vec![]; - - // If the chip doesn't need setup, (right now) it must be single op chip and thus no flag is - // needed. Otherwise, there is a flag for each opcode and will be derived by - // is_valid - sum(flags). - if self.expr().needs_setup() { - flags = vec![false; self.air.num_flags()]; - self.air - .opcode_flag_idx - .iter() - .enumerate() - .for_each(|(i, &flag_idx)| { - flags[flag_idx] = local_opcode_idx == self.air.local_opcode_idx[i] - }); - } + A::start(*state.pc, state.memory, adapter_row); - let vars = self.air.expr.execute(inputs.clone(), flags.clone()); - assert_eq!(vars.len(), self.air.num_vars()); + let data: DynArray<_> = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); - let outputs: Vec = self - .air - .output_indices() - .iter() - .map(|&i| vars[i].clone()) - .collect(); - let writes: Vec = outputs - .iter() - .map(|x| biguint_to_limbs_vec(x.clone(), limb_bits, field_element_limbs)) - .concat() - .into_iter() - .map(|x| F::from_canonical_u32(x)) - .collect(); + let (writes, inputs, flags) = run_field_expression(self, &data, instruction); - let ctx = AdapterRuntimeContext::<_, DynAdapterInterface<_>>::without_pc(writes); - Ok((ctx.into(), FieldExpressionRecord { inputs, flags })) - } - - fn get_opcode_name(&self, _opcode: usize) -> String { - self.name.clone() - } + // TODO(arayi): Should move this to fill_trace_row + self.expr + .generate_subrow((self.range_checker.as_ref(), inputs, flags), core_row); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.air.expr.generate_subrow( - (self.range_checker.as_ref(), record.inputs, record.flags), - row_slice, - ); - } + self.adapter + .write(state.memory, instruction, adapter_row, &writes.into()); - fn air(&self) -> &Self::Air { - &self.air + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + *trace_offset += width; + Ok(()) } - fn finalize(&self, trace: &mut RowMajorMatrix, num_records: usize) { - if !self.should_finalize || num_records == 0 { - return; - } - - let core_width = >::width(&self.air); - let adapter_width = trace.width() - core_width; - let dummy_row = self.generate_dummy_trace_row(adapter_width, core_width); - for row in trace.rows_mut().skip(num_records) { - row.copy_from_slice(&dummy_row); - } + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row: &mut [F]) { + let (adapter_row, _) = row.split_at_mut(A::WIDTH); + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } -} -impl FieldExpressionCoreChip { // We will be setting is_valid = 0. That forces all flags be 0 (otherwise setup will be -1). // We generate a dummy row with all flags set to 0, then we set is_valid = 0. - fn generate_dummy_trace_row( - &self, - adapter_width: usize, - core_width: usize, - ) -> Vec { - let record = FieldExpressionRecord { - inputs: vec![BigUint::zero(); self.air.num_inputs()], - flags: vec![false; self.air.num_flags()], - }; - let mut row = vec![F::ZERO; adapter_width + core_width]; - let core_row = &mut row[adapter_width..]; + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, row: &mut [F]) { + if !self.should_finalize { + return; + } + let inputs: Vec = vec![BigUint::zero(); self.num_inputs()]; + let flags: Vec = vec![false; self.num_flags()]; + let core_row = &mut row[A::WIDTH..]; // We **do not** want this trace row to update the range checker // so we must create a temporary range checker let tmp_range_checker = SharedVariableRangeCheckerChip::new(self.range_checker.bus()); - self.air.expr.generate_subrow( - (tmp_range_checker.as_ref(), record.inputs, record.flags), - core_row, - ); + self.expr + .generate_subrow((tmp_range_checker.as_ref(), inputs, flags), core_row); core_row[0] = F::ZERO; // is_valid = 0 - row } } + +impl StepExecutorE1 for FieldExpressionStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1>, WriteData: From>>, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let data: DynArray<_> = self.adapter.read(state, instruction).into(); + + let writes = run_field_expression(self, &data, instruction).0; + self.adapter.write(state, instruction, &writes.into()); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) + } +} + +fn run_field_expression( + step: &FieldExpressionStep, + data: &DynArray, + instruction: &Instruction, +) -> (DynArray, Vec, Vec) { + let field_element_limbs = step.expr.canonical_num_limbs(); + let limb_bits = step.expr.canonical_limb_bits(); + + let data = data.0.iter().map(|&x| x as u32).collect_vec(); + + assert_eq!(data.len(), step.num_inputs() * field_element_limbs); + + let mut inputs = Vec::with_capacity(step.num_inputs()); + for i in 0..step.num_inputs() { + let start = i * field_element_limbs; + let end = start + field_element_limbs; + let limb_slice = &data[start..end]; + let input = limbs_to_biguint(limb_slice, limb_bits); + inputs.push(input); + } + + let Instruction { opcode, .. } = instruction; + let local_opcode_idx = opcode.local_opcode_idx(step.offset); + let mut flags = vec![]; + + // If the chip doesn't need setup, (right now) it must be single op chip and thus no flag is + // needed. Otherwise, there is a flag for each opcode and will be derived by + // is_valid - sum(flags). + if step.expr.needs_setup() { + flags = vec![false; step.num_flags()]; + step.opcode_flag_idx + .iter() + .enumerate() + .for_each(|(i, &flag_idx)| { + flags[flag_idx] = local_opcode_idx == step.local_opcode_idx[i] + }); + } + + let vars = step.expr.execute(inputs.clone(), flags.clone()); + assert_eq!(vars.len(), step.num_vars()); + + let outputs: Vec = step + .output_indices() + .iter() + .map(|&i| vars[i].clone()) + .collect(); + let writes: DynArray<_> = outputs + .iter() + .map(|x| biguint_to_limbs_vec(x.clone(), limb_bits, field_element_limbs)) + .concat() + .into_iter() + .map(|x| x as u8) + .collect::>() + .into(); + + (writes, inputs, flags) +} diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs index 96578984d0..b27af6ffa9 100644 --- a/crates/circuits/sha256-air/src/air.rs +++ b/crates/circuits/sha256-air/src/air.rs @@ -15,11 +15,11 @@ use openvm_stark_backend::{ use super::{ big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, u32_into_limbs, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, - SHA256_H, SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, - SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, + small_sig1_field, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, SHA256_H, + SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_WORD_BITS, + SHA256_WORD_U16S, SHA256_WORD_U8S, }; -use crate::constraint_word_addition; +use crate::{constraint_word_addition, u32_into_u16s}; /// Expects the message to be padded to a multiple of 512 bits #[derive(Clone, Debug)] @@ -154,7 +154,7 @@ impl Sha256Air { .assert_eq( a_limb, AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], + u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], ), ); @@ -166,7 +166,7 @@ impl Sha256Air { .assert_eq( e_limb, AB::Expr::from_canonical_u32( - u32_into_limbs::<2>(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], + u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], ), ); } @@ -561,9 +561,8 @@ impl Sha256Air { .map(|rw_idx| { ( rw_idx, - u32_into_limbs::( - SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i], - )[j] as usize, + u32_into_u16s(SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i])[j] + as usize, ) }) .collect::>(), diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs index 903b7b0695..5822bfe235 100644 --- a/crates/circuits/sha256-air/src/tests.rs +++ b/crates/circuits/sha256-air/src/tests.rs @@ -13,18 +13,19 @@ use openvm_stark_backend::{ interaction::{BusIndex, InteractionBuilder}, p3_air::{Air, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, prover::types::AirProofInput, rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + utils::disable_debug_builder, + verifier::VerificationError, AirRef, Chip, ChipUsageGetter, }; -use openvm_stark_sdk::utils::create_seeded_rng; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; use crate::{ - compose, small_sig0_field, Sha256Air, Sha256RoundCols, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, + Sha256Air, Sha256DigestCols, Sha256StepHelper, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, + SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, }; // A wrapper AIR purely for testing purposes @@ -50,6 +51,7 @@ impl Air for Sha256TestAir { // A wrapper Chip purely for testing purposes pub struct Sha256TestChip { pub air: Sha256TestAir, + pub step: Sha256StepHelper, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, } @@ -64,8 +66,9 @@ where fn generate_air_proof_input(self) -> AirProofInput { let trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), + &self.step, + self.bitwise_lookup_chip.as_ref(), + >>::width(&self.air.sub_air), self.records, ); AirProofInput::simple_no_pis(trace) @@ -86,10 +89,10 @@ impl ChipUsageGetter for Sha256TestChip { } const SELF_BUS_IDX: BusIndex = 28; -#[test] -fn rand_sha256_test() { +type F = BabyBear; + +fn create_chip_with_rand_records() -> (Sha256TestChip, SharedBitwiseOperationLookupChip<8>) { let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let len = rng.gen_range(1..100); @@ -105,129 +108,47 @@ fn rand_sha256_test() { air: Sha256TestAir { sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), }, + step: Sha256StepHelper::new(), bitwise_lookup_chip: bitwise_chip.clone(), records: random_records, }; + (chip, bitwise_chip) +} +#[test] +fn rand_sha256_test() { + let tester = VmChipTestBuilder::default(); + let (chip, bitwise_chip) = create_chip_with_rand_records(); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -// A wrapper Chip to test that the final_hash is properly constrained. -// This chip implements a malicious trace gen that violates the final_hash constraints. -pub struct Sha256TestBadFinalHashChip { - pub air: Sha256TestAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - pub records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -} - -impl Chip for Sha256TestBadFinalHashChip -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let mut trace = crate::generate_trace::>( - &self.air.sub_air, - self.bitwise_lookup_chip.clone(), - self.records.clone(), - ); - - // Set the final_hash in the digest row of the last block of each hash to zero. - // That is, every hash that this chip does will result in a final_hash of zero. - for (i, row) in self.records.iter().enumerate() { - if row.1 { - let last_digest_row_idx = (i + 1) * SHA256_ROWS_PER_BLOCK - 1; - let last_digest_row: &mut crate::Sha256DigestCols> = - trace.row_mut(last_digest_row_idx)[..SHA256_DIGEST_WIDTH].borrow_mut(); - // Set the final_hash to all zeros +#[test] +fn negative_sha256_test_bad_final_hash() { + let tester = VmChipTestBuilder::default(); + let (chip, bitwise_chip) = create_chip_with_rand_records(); + + // Set the final_hash to all zeros + let modify_trace = |trace: &mut RowMajorMatrix| { + trace.row_chunks_exact_mut(1).for_each(|row| { + let mut row_slice = row.row_slice(0).to_vec(); + let cols: &mut Sha256DigestCols = row_slice[..SHA256_DIGEST_WIDTH].borrow_mut(); + if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { for i in 0..SHA256_HASH_WORDS { for j in 0..SHA256_WORD_U8S { - last_digest_row.final_hash[i][j] = Val::::ZERO; + cols.final_hash[i][j] = F::ZERO; } } - - let (last_round_row, last_digest_row) = - trace.row_pair_mut(last_digest_row_idx - 1, last_digest_row_idx); - let last_round_row: &mut crate::Sha256RoundCols> = - last_round_row.borrow_mut(); - let last_digest_row: &mut crate::Sha256RoundCols> = - last_digest_row.borrow_mut(); - // fix the intermed_4 for the digest row - generate_intermed_4(last_round_row, last_digest_row); + row.values.copy_from_slice(&row_slice); } - } - - let non_padded_height = self.records.len() * SHA256_ROWS_PER_BLOCK; - let width = >>::width(&self.air.sub_air); - // recalculate the missing cells (second pass of generate_trace) - trace.values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.air.sub_air.generate_missing_cells(chunk, width, 0); - }); - - AirProofInput::simple_no_pis(trace) - } -} - -// Copy of private method in Sha256Air used for testing -/// Puts the correct intermed_4 in the `next_row` -fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, -) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } -} - -impl ChipUsageGetter for Sha256TestBadFinalHashChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.len() * SHA256_ROWS_PER_BLOCK - } - - fn trace_width(&self) -> usize { - max(SHA256_ROUND_WIDTH, SHA256_DIGEST_WIDTH) - } -} - -#[test] -#[should_panic] -fn test_sha256_final_hash_constraints() { - let mut rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|_| (array::from_fn(|_| rng.gen::()), true)) - .collect(); - let chip = Sha256TestBadFinalHashChip { - air: Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }, - bitwise_lookup_chip: bitwise_chip.clone(), - records: random_records, + }); }; - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .load(bitwise_chip) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); } diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs index eaf9174f50..21483de642 100644 --- a/crates/circuits/sha256-air/src/trace.rs +++ b/crates/circuits/sha256-air/src/trace.rs @@ -1,41 +1,47 @@ use std::{array, borrow::BorrowMut, ops::Range}; use openvm_circuit_primitives::{ - bitwise_op_lookup::SharedBitwiseOperationLookupChip, utils::next_power_of_two_or_zero, + bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder, + utils::next_power_of_two_or_zero, }; use openvm_stark_backend::{ - p3_air::BaseAir, p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, }; use sha2::{compress256, digest::generic_array::GenericArray}; use super::{ - air::Sha256Air, big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, - get_flag_pt_array, maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, - SHA256_DIGEST_WIDTH, SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, + big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array, + maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH, + SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, }; use crate::{ big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_limbs, SHA256_BLOCK_U8S, SHA256_BUFFER_SIZE, SHA256_H, SHA256_INVALID_CARRY_A, + u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A, SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, SHA256_WORD_BITS, SHA256_WORD_U16S, SHA256_WORD_U8S, }; +/// A helper struct for the SHA256 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha256StepHelper { + pub row_idx_encoder: Encoder, +} + /// The trace generation of SHA256 should be done in two passes. /// The first pass should do `get_block_trace` for every block and generate the invalid rows through /// `get_default_row` The second pass should go through all the blocks and call /// `generate_missing_cells` -impl Sha256Air { +impl Sha256StepHelper { + pub fn new() -> Self { + Self { + row_idx_encoder: Encoder::new(18, 2, false), + } + } /// This function takes the input_message (padding not handled), the previous hash, - /// and returns the new hash after processing the block input - pub fn get_block_hash( - prev_hash: &[u32; SHA256_HASH_WORDS], - input: [u8; SHA256_BLOCK_U8S], - ) -> [u32; SHA256_HASH_WORDS] { - let mut new_hash = *prev_hash; + /// and updates the prev_hash after processing the block input + pub fn get_block_hash(prev_hash: &mut [u32; SHA256_HASH_WORDS], input: [u8; SHA256_BLOCK_U8S]) { let input_array = [GenericArray::from(input)]; - compress256(&mut new_hash, &input_array); - new_hash + compress256(prev_hash, &input_array); } /// This function takes a 512-bit chunk of the input message (padding not handled), the previous @@ -52,18 +58,16 @@ impl Sha256Air { trace_width: usize, trace_start_col: usize, input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, prev_hash: &[u32; SHA256_HASH_WORDS], is_last_block: bool, global_block_idx: u32, local_block_idx: u32, - buffer_vals: &[[F; SHA256_BUFFER_SIZE]; 4], ) { #[cfg(debug_assertions)] { assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - assert!(self.bitwise_lookup_bus == bitwise_lookup_chip.bus()); if local_block_idx == 0 { assert!(*prev_hash == SHA256_H); } @@ -87,14 +91,10 @@ impl Sha256Air { cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); // W_idx = M_idx - if i < SHA256_ROWS_PER_BLOCK / SHA256_ROUNDS_PER_ROW { + if i < 4 { for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = u32_into_limbs::( - input[i * SHA256_ROUNDS_PER_ROW + j], - ) - .map(F::from_canonical_u32); - cols.message_schedule.carry_or_buffer[j] = - array::from_fn(|k| buffer_vals[i][j * SHA256_WORD_U16S * 2 + k]); + cols.message_schedule.w[j] = + u32_into_bits_field::(input[i * SHA256_ROUNDS_PER_ROW + j]); } } // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} @@ -108,14 +108,10 @@ impl Sha256Air { message_schedule[idx - 16], ]; let w: u32 = nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = - u32_into_limbs::(w).map(F::from_canonical_u32); + cols.message_schedule.w[j] = u32_into_bits_field::(w); - let nums_limbs = nums - .iter() - .map(|x| u32_into_limbs::(*x)) - .collect::>(); - let w_limbs = u32_into_limbs::(w); + let nums_limbs = nums.map(|x| u32_into_u16s(x)); + let w_limbs = u32_into_u16s(w); // fill in the carrys for k in 0..SHA256_WORD_U16S { @@ -157,25 +153,18 @@ impl Sha256Air { // e = d + t1 let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = - u32_into_limbs::(e).map(F::from_canonical_u32); - let e_limbs = u32_into_limbs::(e); + cols.work_vars.e[j] = u32_into_bits_field::(e); + let e_limbs = u32_into_u16s(e); // a = t1 + t2 let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = - u32_into_limbs::(a).map(F::from_canonical_u32); - let a_limbs = u32_into_limbs::(a); + cols.work_vars.a[j] = u32_into_bits_field::(a); + let a_limbs = u32_into_u16s(a); // fill in the carrys for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); - let t2_limb = t2.iter().fold(0, |acc, &num| { - acc + u32_into_limbs::(num)[k] - }); + let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); + let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - let mut e_limb = - t1_limb + u32_into_limbs::(work_vars[3])[k]; + let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k]; let mut a_limb = t1_limb + t2_limb; if k > 0 { a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); @@ -203,16 +192,14 @@ impl Sha256Air { if i > 0 { for j in 0..SHA256_ROUNDS_PER_ROW { let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_limbs::(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_limbs::(small_sig0( - message_schedule[idx - 3], - )); + let w_4 = u32_into_u16s(message_schedule[idx - 4]); + let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3])); cols.schedule_helper.intermed_4[j] = array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); if j < SHA256_ROUNDS_PER_ROW - 1 { let w_3 = message_schedule[idx - 3]; cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); + u32_into_u16s(w_3).map(F::from_canonical_u32); } } } @@ -223,8 +210,7 @@ impl Sha256Air { row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); for j in 0..SHA256_ROUNDS_PER_ROW - 1 { let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = - u32_into_limbs::(w_3).map(F::from_canonical_u32); + cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32); } cols.flags.is_round_row = F::ZERO; cols.flags.is_first_4_rows = F::ZERO; @@ -237,29 +223,27 @@ impl Sha256Air { cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); let final_hash: [u32; SHA256_HASH_WORDS] = array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u32; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| u32_into_limbs::(final_hash[i])); + let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = + array::from_fn(|i| final_hash[i].to_le_bytes()); // need to ensure final hash limbs are bytes, in order for // prev_hash[i] + work_vars[i] == final_hash[i] // to be constrained correctly for word in final_hash_limbs.iter() { for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32); } } cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(final_hash_limbs[i][j])) + array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j])) }); - cols.prev_hash = prev_hash - .map(|f| u32_into_limbs::(f).map(F::from_canonical_u32)); + cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32)); let hash = if is_last_block { - SHA256_H.map(u32_into_limbs::) + SHA256_H.map(u32_into_bits_field::) } else { cols.final_hash - .map(|f| limbs_into_u32(f.map(|x| x.as_canonical_u32()))) - .map(u32_into_limbs::) - } - .map(|x| x.map(F::from_canonical_u32)); + .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8))) + .map(u32_into_bits_field::) + }; for i in 0..SHA256_ROUNDS_PER_ROW { cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; @@ -338,7 +322,10 @@ impl Sha256Air { /// Fills the `cols` as a padding row /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row(self: &Sha256Air, cols: &mut Sha256RoundCols) { + pub fn generate_default_row( + self: &Sha256StepHelper, + cols: &mut Sha256RoundCols, + ) { cols.flags.is_round_row = F::ZERO; cols.flags.is_first_4_rows = F::ZERO; cols.flags.is_digest_row = F::ZERO; @@ -353,9 +340,7 @@ impl Sha256Air { cols.message_schedule.carry_or_buffer = [[F::ZERO; SHA256_WORD_U16S * 2]; SHA256_ROUNDS_PER_ROW]; - let hash = SHA256_H - .map(u32_into_limbs::) - .map(|x| x.map(F::from_canonical_u32)); + let hash = SHA256_H.map(u32_into_bits_field::); for i in 0..SHA256_ROUNDS_PER_ROW { cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; @@ -486,15 +471,16 @@ impl Sha256Air { } } +/// Generates a trace for a standalone SHA256 computation (currently only used for testing) /// `records` consists of pairs of `(input_block, is_last_block)`. pub fn generate_trace( - sub_air: &Sha256Air, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + step: &Sha256StepHelper, + bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, + width: usize, records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, ) -> RowMajorMatrix { let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; let height = next_power_of_two_or_zero(non_padded_height); - let width = >::width(sub_air); let mut values = F::zero_vec(height * width); struct BlockContext { @@ -522,7 +508,7 @@ pub fn generate_trace( prev_hash = SHA256_H; } else { local_block_idx += 1; - prev_hash = Sha256Air::get_block_hash(&prev_hash, input); + Sha256StepHelper::get_block_hash(&mut prev_hash, input); } } // first pass @@ -542,17 +528,16 @@ pub fn generate_trace( input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 })) }); - sub_air.generate_block_trace( + step.generate_block_trace( block, width, 0, &input_words, - bitwise_lookup_chip.clone(), + bitwise_lookup_chip, &prev_hash, is_last_block, global_block_idx, local_block_idx, - &[[F::ZERO; 16]; 4], ); }); // second pass: padding rows @@ -560,14 +545,14 @@ pub fn generate_trace( .par_chunks_mut(width) .for_each(|row| { let cols: &mut Sha256RoundCols = row.borrow_mut(); - sub_air.generate_default_row(cols); + step.generate_default_row(cols); }); // second pass: non-padding rows values[width..] .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) .take(non_padded_height / SHA256_ROWS_PER_BLOCK) .for_each(|chunk| { - sub_air.generate_missing_cells(chunk, width, 0); + step.generate_missing_cells(chunk, width, 0); }); RowMajorMatrix::new(values, width) } diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs index abf8b6e7f2..1d56314b13 100644 --- a/crates/circuits/sha256-air/src/utils.rs +++ b/crates/circuits/sha256-air/src/utils.rs @@ -74,10 +74,14 @@ pub const SHA256_H: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, ]; -/// Convert a u32 into a list of limbs in little endian -pub fn u32_into_limbs(num: u32) -> [u32; NUM_LIMBS] { - let limb_bits = 32 / NUM_LIMBS; - array::from_fn(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) +/// Convert a u32 into a list of bits in little endian then convert each bit into a field element +pub fn u32_into_bits_field(num: u32) -> [F; SHA256_WORD_BITS] { + array::from_fn(|i| F::from_bool((num >> i) & 1 == 1)) +} + +/// Convert a u32 into a an array of 2 16-bit limbs in little endian +pub fn u32_into_u16s(num: u32) -> [u32; 2] { + [num & 0xffff, num >> 16] } /// Convert a list of limbs in little endian into a u32 diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 532c9b8d1c..02b2fbb8b4 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -12,7 +12,7 @@ use openvm_circuit::{ SystemConfig, SystemExecutor, SystemPeriphery, VmChipComplex, VmConfig, VmInventoryError, }, circuit_derive::{Chip, ChipUsageGetter}, - derive::{AnyEnum, InstructionExecutor}, + derive::{AnyEnum, InsExecutorE1, InstructionExecutor}, }; use openvm_ecc_circuit::{ WeierstrassExtension, WeierstrassExtensionExecutor, WeierstrassExtensionPeriphery, @@ -63,7 +63,7 @@ pub struct SdkVmConfig { pub ecc: Option, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1)] pub enum SdkVmConfigExecutor { #[any_enum] System(SystemExecutor), diff --git a/crates/sdk/src/prover/vm/local.rs b/crates/sdk/src/prover/vm/local.rs index b56c6a1ad3..85c604bb98 100644 --- a/crates/sdk/src/prover/vm/local.rs +++ b/crates/sdk/src/prover/vm/local.rs @@ -99,7 +99,7 @@ where exe.clone(), input.clone(), |seg_idx, mut seg| { - final_memory = mem::take(&mut seg.final_memory); + final_memory = mem::take(&mut seg.ctrl.final_memory); let proof_input = info_span!("trace_gen", segment = seg_idx) .in_scope(|| seg.generate_proof_input(Some(committed_program.clone())))?; info_span!("prove_segment", segment = seg_idx) diff --git a/crates/toolchain/instructions/src/exe.rs b/crates/toolchain/instructions/src/exe.rs index fb84ec7da5..9db5f242ac 100644 --- a/crates/toolchain/instructions/src/exe.rs +++ b/crates/toolchain/instructions/src/exe.rs @@ -5,8 +5,9 @@ use serde::{Deserialize, Serialize}; use crate::program::Program; -/// Memory image is a map from (address space, address) to word. -pub type MemoryImage = BTreeMap<(u32, u32), F>; +// TODO[jpw]: delete this +/// Memory image is a map from (address space, address * size_of) to u8. +pub type SparseMemoryImage = BTreeMap<(u32, u32), u8>; /// Stores the starting address, end address, and name of a set of function. pub type FnBounds = BTreeMap; @@ -22,7 +23,7 @@ pub struct VmExe { /// Start address of pc. pub pc_start: u32, /// Initial memory image. - pub init_memory: MemoryImage, + pub init_memory: SparseMemoryImage, /// Starting + ending bounds for each function. pub fn_bounds: FnBounds, } @@ -40,7 +41,7 @@ impl VmExe { self.pc_start = pc_start; self } - pub fn with_init_memory(mut self, init_memory: MemoryImage) -> Self { + pub fn with_init_memory(mut self, init_memory: SparseMemoryImage) -> Self { self.init_memory = init_memory; self } diff --git a/crates/toolchain/platform/src/alloc.rs b/crates/toolchain/platform/src/alloc.rs new file mode 100644 index 0000000000..0af25a3671 --- /dev/null +++ b/crates/toolchain/platform/src/alloc.rs @@ -0,0 +1,62 @@ +extern crate alloc; + +use alloc::alloc::{alloc, dealloc, handle_alloc_error, Layout}; +use core::ptr::NonNull; + +/// Bytes allocated according to the given Layout +pub struct AlignedBuf { + pub ptr: *mut u8, + pub layout: Layout, +} + +impl AlignedBuf { + /// Allocate a new buffer whose start address is aligned to `align` bytes. + /// *NOTE* if `len` is zero then a creates new `NonNull` that is dangling and 16-byte aligned. + pub fn uninit(len: usize, align: usize) -> Self { + let layout = Layout::from_size_align(len, align).unwrap(); + if layout.size() == 0 { + return Self { + ptr: NonNull::::dangling().as_ptr() as *mut u8, + layout, + }; + } + // SAFETY: `len` is nonzero + let ptr = unsafe { alloc(layout) }; + if ptr.is_null() { + handle_alloc_error(layout); + } + AlignedBuf { ptr, layout } + } + + /// Allocate a new buffer whose start address is aligned to `align` bytes + /// and copy the given data into it. + /// + /// # Safety + /// - `bytes` must not be null + /// - `len` should not be zero + /// + /// See [alloc]. In particular `data` should not be empty. + pub unsafe fn new(bytes: *const u8, len: usize, align: usize) -> Self { + let buf = Self::uninit(len, align); + // SAFETY: + // - src and dst are not null + // - src and dst are allocated for size + // - no alignment requirements on u8 + // - non-overlapping since ptr is newly allocated + unsafe { + core::ptr::copy_nonoverlapping(bytes, buf.ptr, len); + } + + buf + } +} + +impl Drop for AlignedBuf { + fn drop(&mut self) { + if self.layout.size() != 0 { + unsafe { + dealloc(self.ptr, self.layout); + } + } + } +} diff --git a/crates/toolchain/platform/src/lib.rs b/crates/toolchain/platform/src/lib.rs index 901666d530..4d02cc3138 100644 --- a/crates/toolchain/platform/src/lib.rs +++ b/crates/toolchain/platform/src/lib.rs @@ -4,14 +4,17 @@ #![deny(rustdoc::broken_intra_doc_links)] #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] -#[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] +#[cfg(target_os = "zkvm")] pub use openvm_custom_insn::{custom_insn_i, custom_insn_r}; +#[cfg(target_os = "zkvm")] +pub mod alloc; #[cfg(all(feature = "export-getrandom", target_os = "zkvm"))] mod getrandom; #[cfg(all(feature = "rust-runtime", target_os = "zkvm"))] pub mod heap; #[cfg(all(feature = "export-libm", target_os = "zkvm"))] mod libm_extern; + pub mod memory; pub mod print; #[cfg(feature = "rust-runtime")] diff --git a/crates/toolchain/tests/Cargo.toml b/crates/toolchain/tests/Cargo.toml index d31d388c32..ce882932a7 100644 --- a/crates/toolchain/tests/Cargo.toml +++ b/crates/toolchain/tests/Cargo.toml @@ -8,11 +8,15 @@ homepage.workspace = true repository.workspace = true [dependencies] +openvm-build.workspace = true +openvm-transpiler.workspace = true +eyre.workspace = true +tempfile.workspace = true + +[dev-dependencies] openvm-stark-backend.workspace = true openvm-stark-sdk.workspace = true openvm-circuit = { workspace = true, features = ["test-utils"] } -openvm-transpiler.workspace = true -openvm-build.workspace = true openvm-algebra-transpiler.workspace = true openvm-bigint-circuit.workspace = true openvm-rv32im-circuit.workspace = true @@ -21,10 +25,7 @@ openvm-algebra-circuit.workspace = true openvm-ecc-guest = { workspace = true, features = ["halo2curves", "k256"] } openvm-instructions = { workspace = true } openvm-platform = { workspace = true } - -eyre.workspace = true test-case.workspace = true -tempfile.workspace = true serde = { workspace = true, features = ["alloc"] } derive_more = { workspace = true, features = ["from"] } diff --git a/crates/toolchain/tests/src/utils.rs b/crates/toolchain/tests/src/utils.rs deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/crates/toolchain/transpiler/src/util.rs b/crates/toolchain/transpiler/src/util.rs index d9135de153..c5711653ff 100644 --- a/crates/toolchain/transpiler/src/util.rs +++ b/crates/toolchain/transpiler/src/util.rs @@ -1,7 +1,7 @@ use std::collections::BTreeMap; use openvm_instructions::{ - exe::MemoryImage, + exe::SparseMemoryImage, instruction::Instruction, riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS}, utils::isize_to_field, @@ -165,17 +165,14 @@ pub fn nop() -> Instruction { } } -/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as, address) -> word) -pub fn elf_memory_image_to_openvm_memory_image( +/// Converts our memory image (u32 -> [u8; 4]) into Vm memory image ((as=2, address) -> byte) +pub fn elf_memory_image_to_openvm_memory_image( memory_image: BTreeMap, -) -> MemoryImage { - let mut result = MemoryImage::new(); +) -> SparseMemoryImage { + let mut result = SparseMemoryImage::new(); for (addr, word) in memory_image { for (i, byte) in word.to_le_bytes().into_iter().enumerate() { - result.insert( - (RV32_MEMORY_AS, addr + i as u32), - F::from_canonical_u8(byte), - ); + result.insert((RV32_MEMORY_AS, addr + i as u32), byte); } } result diff --git a/crates/vm/derive/src/lib.rs b/crates/vm/derive/src/lib.rs index 37dca6e4ed..a7913893df 100644 --- a/crates/vm/derive/src/lib.rs +++ b/crates/vm/derive/src/lib.rs @@ -113,6 +113,129 @@ pub fn instruction_executor_derive(input: TokenStream) -> TokenStream { } } +#[proc_macro_derive(InsExecutorE1)] +pub fn ins_executor_e1_executor_derive(input: TokenStream) -> TokenStream { + let ast: syn::DeriveInput = syn::parse(input).unwrap(); + + let name = &ast.ident; + let generics = &ast.generics; + let (impl_generics, ty_generics, _) = generics.split_for_impl(); + + match &ast.data { + Data::Struct(inner) => { + // Check if the struct has only one unnamed field + let inner_ty = match &inner.fields { + Fields::Unnamed(fields) => { + if fields.unnamed.len() != 1 { + panic!("Only one unnamed field is supported"); + } + fields.unnamed.first().unwrap().ty.clone() + } + _ => panic!("Only unnamed fields are supported"), + }; + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let mut new_generics = generics.clone(); + let where_clause = new_generics.make_where_clause(); + where_clause + .predicates + .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InsExecutorE1 }); + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE1 for #name #ty_generics #where_clause { + fn execute_e1( + &mut self, + state: &mut ::openvm_circuit::arch::VmStateMut<::openvm_circuit::system::memory::online::GuestMemory, Ctx>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + ) -> ::openvm_circuit::arch::Result<()> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1E2ExecutionCtx, + { + self.0.execute_e1(state, instruction) + } + + fn execute_metered( + &mut self, + state: &mut ::openvm_circuit::arch::VmStateMut<::openvm_circuit::system::memory::online::GuestMemory, ::openvm_circuit::arch::execution_mode::metered::MeteredCtx>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction, + chip_index: usize, + ) -> ::openvm_circuit::arch::Result<()> + where { + self.0.execute_metered(state, instruction, chip_index) + } + } + } + .into() + } + Data::Enum(e) => { + let variants = e + .variants + .iter() + .map(|variant| { + let variant_name = &variant.ident; + + let mut fields = variant.fields.iter(); + let field = fields.next().unwrap(); + assert!(fields.next().is_none(), "Only one field is supported"); + (variant_name, field) + }) + .collect::>(); + let first_ty_generic = ast + .generics + .params + .first() + .and_then(|param| match param { + GenericParam::Type(type_param) => Some(&type_param.ident), + _ => None, + }) + .expect("First generic must be type for Field"); + // Use full path ::openvm_circuit... so it can be used either within or outside the vm + // crate. Assume F is already generic of the field. + let execute_e1_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic>>::execute_e1(x, state, instruction) + } + }).collect::>(); + let execute_metered_arms = variants.iter().map(|(variant_name, field)| { + let field_ty = &field.ty; + quote! { + #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic>>::execute_metered(x, state, instruction, chip_index) + } + }).collect::>(); + + quote! { + impl #impl_generics ::openvm_circuit::arch::InsExecutorE1<#first_ty_generic> for #name #ty_generics { + fn execute_e1( + &mut self, + state: &mut ::openvm_circuit::arch::VmStateMut<::openvm_circuit::system::memory::online::GuestMemory, Ctx>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>, + ) -> ::openvm_circuit::arch::Result<()> + where + Ctx: ::openvm_circuit::arch::execution_mode::E1E2ExecutionCtx, + { + match self { + #(#execute_e1_arms,)* + } + } + + fn execute_metered( + &mut self, + state: &mut ::openvm_circuit::arch::VmStateMut<::openvm_circuit::system::memory::online::GuestMemory, ::openvm_circuit::arch::execution_mode::metered::MeteredCtx>, + instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>, + chip_index: usize, + ) -> ::openvm_circuit::arch::Result<()> { + match self { + #(#execute_metered_arms,)* + } + } + } + } + .into() + } + Data::Union(_) => unimplemented!("Unions are not supported"), + } +} + /// Derives `AnyEnum` trait on an enum type. /// By default an enum arm will just return `self` as `&dyn Any`. /// @@ -347,7 +470,7 @@ pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::T let periphery_type = Ident::new(&format!("{}Periphery", name), name.span()); TokenStream::from(quote! { - #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::openvm_circuit::derive::InstructionExecutor, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] + #[derive(::openvm_circuit::circuit_derive::ChipUsageGetter, ::openvm_circuit::circuit_derive::Chip, ::openvm_circuit::derive::InstructionExecutor, ::openvm_circuit::derive::InsExecutorE1, ::derive_more::derive::From, ::openvm_circuit::derive::AnyEnum)] pub enum #executor_type { #[any_enum] #source_name_upper(#source_executor_type), diff --git a/crates/vm/src/arch/config.rs b/crates/vm/src/arch/config.rs index 30d92130a1..239f9eec7c 100644 --- a/crates/vm/src/arch/config.rs +++ b/crates/vm/src/arch/config.rs @@ -8,9 +8,9 @@ use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use super::{ - segment::DefaultSegmentationStrategy, AnyEnum, InstructionExecutor, SegmentationStrategy, - SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, VmInventoryError, - PUBLIC_VALUES_AIR_ID, + segmentation_strategy::{DefaultSegmentationStrategy, SegmentationStrategy}, + AnyEnum, InstructionExecutor, SystemComplex, SystemExecutor, SystemPeriphery, VmChipComplex, + VmInventoryError, PUBLIC_VALUES_AIR_ID, }; use crate::system::memory::BOUNDARY_AIR_OFFSET; @@ -45,6 +45,7 @@ pub struct MemoryConfig { /// space `0` in memory. pub as_height: usize, /// The offset of the address space. Should be fixed to equal `1`. + // TODO[jpw]: remove this and make constant pub as_offset: u32, pub pointer_max_bits: usize, /// All timestamps must be in the range `[0, 2^clk_max_bits)`. Maximum allowed: 29. diff --git a/crates/vm/src/arch/execution.rs b/crates/vm/src/arch/execution.rs index 4edc88d355..6ac124aa42 100644 --- a/crates/vm/src/arch/execution.rs +++ b/crates/vm/src/arch/execution.rs @@ -1,4 +1,4 @@ -use std::{cell::RefCell, rc::Rc}; +use std::cell::RefCell; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -6,13 +6,22 @@ use openvm_instructions::{ }; use openvm_stark_backend::{ interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, }; use serde::{Deserialize, Serialize}; use thiserror::Error; -use super::Streams; -use crate::system::{memory::MemoryController, program::ProgramBus}; +use super::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + Streams, +}; +use crate::system::{ + memory::{ + online::{GuestMemory, TracingMemory}, + MemoryController, + }, + program::ProgramBus, +}; pub type Result = std::result::Result; @@ -66,8 +75,30 @@ pub enum ExecutionError { DidNotTerminate, #[error("program exit code {0}")] FailedWithExitCode(u32), + #[error("trace buffer out of bounds: requested {requested} but capacity is {capacity}")] + TraceBufferOutOfBounds { requested: usize, capacity: usize }, +} + +/// Global VM state accessible during instruction execution. +/// The state is generic in guest memory `MEM` and additional host state `CTX`. +/// The host state is execution context specific. +#[derive(derive_new::new)] +pub struct VmStateMut<'a, MEM, CTX> { + pub pc: &'a mut u32, + pub memory: &'a mut MEM, + pub ctx: &'a mut CTX, +} + +impl VmStateMut<'_, TracingMemory, CTX> { + // TODO: store as u32 directly + #[inline(always)] + pub fn ins_start(&self, from_state: &mut ExecutionState) { + from_state.pc = F::from_canonical_u32(*self.pc); + from_state.timestamp = F::from_canonical_u32(self.memory.timestamp); + } } +// TODO: old pub trait InstructionExecutor { /// Runtime execution of the instruction, if the instruction is owned by the /// current instance. May internally store records of this call for later trace generation. @@ -83,22 +114,58 @@ pub trait InstructionExecutor { fn get_opcode_name(&self, opcode: usize) -> String; } -impl> InstructionExecutor for RefCell { - fn execute( +/// New trait for instruction execution +pub trait InsExecutorE1 { + fn execute_e1( &mut self, - memory: &mut MemoryController, + state: &mut VmStateMut, instruction: &Instruction, - prev_state: ExecutionState, - ) -> Result> { - self.borrow_mut().execute(memory, instruction, prev_state) + ) -> Result<()> + where + F: PrimeField32, + Ctx: E1E2ExecutionCtx; + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> + where + F: PrimeField32; +} + +impl InsExecutorE1 for RefCell +where + C: InsExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + F: PrimeField32, + Ctx: E1E2ExecutionCtx, + { + self.borrow_mut().execute_e1(state, instruction) } - fn get_opcode_name(&self, opcode: usize) -> String { - self.borrow().get_opcode_name(opcode) + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> + where + F: PrimeField32, + { + self.borrow_mut() + .execute_metered(state, instruction, chip_index) } } -impl> InstructionExecutor for Rc> { +impl> InstructionExecutor for RefCell { fn execute( &mut self, memory: &mut MemoryController, @@ -325,11 +392,11 @@ impl From<(u32, Option)> for PcIncOrSet { pub trait PhantomSubExecutor: Send { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, streams: &mut Streams, discriminant: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()>; } diff --git a/crates/vm/src/arch/execution_control.rs b/crates/vm/src/arch/execution_control.rs new file mode 100644 index 0000000000..1be6a166fe --- /dev/null +++ b/crates/vm/src/arch/execution_control.rs @@ -0,0 +1,64 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{ExecutionError, VmChipComplex, VmConfig, VmSegmentState}; + +/// Trait for execution control, determining segmentation and stopping conditions +pub trait ExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + /// Host context + type Ctx; + + /// Determines if execution should suspend + fn should_suspend( + &mut self, + state: &mut VmSegmentState, + chip_complex: &VmChipComplex, + ) -> bool; + + /// Called before execution begins + fn on_start( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ); + + /// Called after suspend or terminate + fn on_suspend_or_terminate( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ); + + fn on_suspend( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + self.on_suspend_or_terminate(state, chip_complex, None); + } + + fn on_terminate( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: u32, + ) { + self.on_suspend_or_terminate(state, chip_complex, Some(exit_code)); + } + + /// Execute a single instruction + // TODO(ayush): change instruction to Instruction / PInstruction + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32; +} diff --git a/crates/vm/src/arch/execution_mode/e1.rs b/crates/vm/src/arch/execution_mode/e1.rs new file mode 100644 index 0000000000..2bb5a85a88 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/e1.rs @@ -0,0 +1,85 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::arch::{ + execution_control::ExecutionControl, execution_mode::E1E2ExecutionCtx, ExecutionError, + InsExecutorE1, VmChipComplex, VmConfig, VmSegmentState, VmStateMut, +}; + +pub type E1Ctx = (); + +impl E1E2ExecutionCtx for E1Ctx { + fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {} +} + +/// Implementation of the ExecutionControl trait using the old segmentation strategy +#[derive(Default, derive_new::new)] +pub struct E1ExecutionControl { + pub clk_end: Option, +} + +impl ExecutionControl for E1ExecutionControl +where + F: PrimeField32, + VC: VmConfig, + VC::Executor: InsExecutorE1, +{ + type Ctx = E1Ctx; + + fn should_suspend( + &mut self, + state: &mut VmSegmentState, + _chip_complex: &VmChipComplex, + ) -> bool { + if let Some(clk_end) = self.clk_end { + state.clk >= clk_end + } else { + false + } + } + + fn on_start( + &mut self, + _state: &mut VmSegmentState, + _chip_complex: &mut VmChipComplex, + ) { + } + + fn on_suspend_or_terminate( + &mut self, + _state: &mut VmSegmentState, + _chip_complex: &mut VmChipComplex, + _exit_code: Option, + ) { + } + + /// Execute a single instruction + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + let &Instruction { opcode, .. } = instruction; + + if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { + let mut vm_state = VmStateMut { + pc: &mut state.pc, + memory: state.memory.as_mut().unwrap(), + ctx: &mut state.ctx, + }; + executor.execute_e1(&mut vm_state, instruction)?; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + state.clk += 1; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/execution_mode/metered/bounded.rs b/crates/vm/src/arch/execution_mode/metered/bounded.rs new file mode 100644 index 0000000000..0a933a5493 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/bounded.rs @@ -0,0 +1,183 @@ +use openvm_instructions::riscv::RV32_IMM_AS; + +use crate::{ + arch::{execution_mode::E1E2ExecutionCtx, PUBLIC_VALUES_AIR_ID}, + system::memory::{dimensions::MemoryDimensions, CHUNK, CHUNK_BITS}, +}; + +// TODO(ayush): can segmentation also be triggered by timestamp overflow? should that be tracked? +#[derive(Debug)] +pub struct MeteredCtxBounded { + pub trace_heights: Vec, + + continuations_enabled: bool, + num_access_adapters: u8, + // TODO(ayush): take alignment into account for access adapters + #[allow(dead_code)] + as_byte_alignment_bits: Vec, + pub memory_dimensions: MemoryDimensions, + + // Indices of leaf nodes in the memory merkle tree + pub leaf_indices: Vec, +} + +impl MeteredCtxBounded { + pub fn new( + num_traces: usize, + continuations_enabled: bool, + num_access_adapters: u8, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + ) -> Self { + Self { + trace_heights: vec![0; num_traces], + continuations_enabled, + num_access_adapters, + as_byte_alignment_bits, + memory_dimensions, + leaf_indices: Vec::new(), + } + } +} + +impl MeteredCtxBounded { + fn update_boundary_merkle_heights(&mut self, address_space: u32, ptr: u32, size: u32) { + let boundary_idx = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + let poseidon2_idx = self.trace_heights.len() - 2; + + let num_blocks = (size + CHUNK as u32 - 1) >> CHUNK_BITS; + for i in 0..num_blocks { + let addr = ptr.wrapping_add(i * CHUNK as u32); + let block_id = addr >> CHUNK_BITS; + let leaf_id = self + .memory_dimensions + .label_to_index((address_space, block_id)); + + if let Err(insert_idx) = self.leaf_indices.binary_search(&leaf_id) { + self.leaf_indices.insert(insert_idx, leaf_id); + + self.trace_heights[boundary_idx] += 1; + self.trace_heights[poseidon2_idx] += 2; + + if self.continuations_enabled { + let pred_id = insert_idx.checked_sub(1).map(|idx| self.leaf_indices[idx]); + let succ_id = (insert_idx < self.leaf_indices.len() - 1) + .then(|| self.leaf_indices[insert_idx + 1]); + let height_change = calculate_merkle_node_updates( + leaf_id, + pred_id, + succ_id, + self.memory_dimensions.overall_height() as u32, + ); + self.trace_heights[boundary_idx + 1] += height_change * 2; + self.trace_heights[poseidon2_idx] += height_change * 2; + } + } + } + } + + fn update_adapter_heights_batch(&mut self, size: u32, num: u32) { + let adapter_offset = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + 2 + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + + apply_adapter_updates_batch(size, num, &mut self.trace_heights[adapter_offset..]); + } + + fn update_adapter_heights(&mut self, size: u32) { + self.update_adapter_heights_batch(size, 1); + } + + pub fn finalize_access_adapter_heights(&mut self) { + self.update_adapter_heights_batch(CHUNK as u32, self.leaf_indices.len() as u32); + } + + pub fn trace_heights_if_finalized(&mut self) -> Vec { + let num_leaves = self.leaf_indices.len() as u32; + let mut access_adapter_updates = vec![0; self.num_access_adapters as usize]; + apply_adapter_updates_batch(CHUNK as u32, num_leaves, &mut access_adapter_updates); + + let adapter_offset = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + 2 + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + self.trace_heights + .iter() + .enumerate() + .map(|(i, &height)| { + if i >= adapter_offset && i < adapter_offset + access_adapter_updates.len() { + height + access_adapter_updates[i - adapter_offset] + } else { + height + } + }) + .collect() + } +} + +impl E1E2ExecutionCtx for MeteredCtxBounded { + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) { + debug_assert!( + address_space != RV32_IMM_AS, + "address space must not be immediate" + ); + debug_assert!(size.is_power_of_two(), "size must be a power of 2"); + + // Handle access adapter updates + self.update_adapter_heights(size); + + // Handle merkle tree updates + // TODO(ayush): use a looser upper bound + // see if this can be approximated by total number of reads/writes for AS != register + self.update_boundary_merkle_heights(address_space, ptr, size); + } +} + +fn apply_adapter_updates_batch(size: u32, num: u32, trace_heights: &mut [u32]) { + let size_bits = size.ilog2(); + for adapter_bits in (3..=size_bits).rev() { + trace_heights[adapter_bits as usize - 1] += num << (size_bits - adapter_bits + 1); + } +} + +fn calculate_merkle_node_updates( + leaf_id: u64, + pred_id: Option, + succ_id: Option, + height: u32, +) -> u32 { + // First node requires height many updates + if pred_id.is_none() && succ_id.is_none() { + return height; + } + + // Calculate the difference in divergence + let mut diff = 0; + + // Add new divergences between pred and leaf_index + if let Some(p) = pred_id { + let new_divergence = (p ^ leaf_id).ilog2(); + diff += new_divergence; + } + + // Add new divergences between leaf_index and succ + if let Some(s) = succ_id { + let new_divergence = (leaf_id ^ s).ilog2(); + diff += new_divergence; + } + + // Remove old divergence between pred and succ if both existed + if let (Some(p), Some(s)) = (pred_id, succ_id) { + let old_divergence = (p ^ s).ilog2(); + diff -= old_divergence; + } + + diff +} diff --git a/crates/vm/src/arch/execution_mode/metered/exact.rs b/crates/vm/src/arch/execution_mode/metered/exact.rs new file mode 100644 index 0000000000..7640960813 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/exact.rs @@ -0,0 +1,393 @@ +use std::collections::BTreeMap; + +use openvm_instructions::riscv::RV32_IMM_AS; + +use crate::{ + arch::{execution_mode::E1E2ExecutionCtx, PUBLIC_VALUES_AIR_ID}, + system::memory::{dimensions::MemoryDimensions, CHUNK, CHUNK_BITS}, +}; + +// TODO(ayush): can segmentation also be triggered by timestamp overflow? should that be tracked? +#[derive(Debug)] +pub struct MeteredCtxExact { + pub trace_heights: Vec, + + continuations_enabled: bool, + num_access_adapters: u8, + as_byte_alignment_bits: Vec, + pub memory_dimensions: MemoryDimensions, + + // Map from (addr_space, addr) -> (size, offset) + pub last_memory_access: BTreeMap<(u8, u32), (u8, u8)>, + // Indices of leaf nodes in the memory merkle tree + pub leaf_indices: Vec, +} + +impl MeteredCtxExact { + pub fn new( + num_traces: usize, + continuations_enabled: bool, + num_access_adapters: u8, + as_byte_alignment_bits: Vec, + memory_dimensions: MemoryDimensions, + ) -> Self { + Self { + trace_heights: vec![0; num_traces], + continuations_enabled, + num_access_adapters, + as_byte_alignment_bits, + memory_dimensions, + last_memory_access: BTreeMap::new(), + leaf_indices: Vec::new(), + } + } +} + +impl MeteredCtxExact { + fn update_boundary_merkle_heights(&mut self, address_space: u32, ptr: u32, size: u32) { + let boundary_idx = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + let poseidon2_idx = self.trace_heights.len() - 2; + + let num_blocks = (size + CHUNK as u32 - 1) >> CHUNK_BITS; + for i in 0..num_blocks { + let addr = ptr.wrapping_add(i * CHUNK as u32); + let block_id = addr >> CHUNK_BITS; + let leaf_id = self + .memory_dimensions + .label_to_index((address_space, block_id)); + + if let Err(insert_idx) = self.leaf_indices.binary_search(&leaf_id) { + self.leaf_indices.insert(insert_idx, leaf_id); + + self.trace_heights[boundary_idx] += 1; + // NOTE: this is an upper bound since poseidon chip removes duplicates + self.trace_heights[poseidon2_idx] += 2; + + if self.continuations_enabled { + let pred_id = insert_idx.checked_sub(1).map(|idx| self.leaf_indices[idx]); + let succ_id = (insert_idx < self.leaf_indices.len() - 1) + .then(|| self.leaf_indices[insert_idx + 1]); + let height_change = calculate_merkle_node_updates( + leaf_id, + pred_id, + succ_id, + self.memory_dimensions.overall_height() as u32, + ); + self.trace_heights[boundary_idx + 1] += height_change * 2; + self.trace_heights[poseidon2_idx] += height_change * 2; + } + } + } + } + + #[allow(clippy::type_complexity)] + fn calculate_splits_and_merges( + &self, + address_space: u32, + ptr: u32, + size: u32, + ) -> (Vec<(u32, u32)>, Vec<(u32, u32)>) { + // Skip adapters if this is a repeated access to the same location with same size + let last_access = self.last_memory_access.get(&(address_space as u8, ptr)); + if matches!(last_access, Some(&(last_access_size, 0)) if size == last_access_size as u32) { + return (vec![], vec![]); + } + + // Go to the start of block + let mut ptr_start = ptr; + if let Some(&(_, last_access_offset)) = last_access { + ptr_start = ptr.wrapping_sub(last_access_offset as u32); + } + + let align_bits = self.as_byte_alignment_bits[address_space as usize] as usize; + let align = 1 << align_bits; + + // Split intersecting blocks to align bytes + let mut curr_block = ptr_start >> align_bits; + let end_block = curr_block + (size >> align_bits); + let mut splits = vec![]; + while curr_block < end_block { + let curr_block_size = if let Some(&(last_access_size, _)) = self + .last_memory_access + .get(&(address_space as u8, curr_block.wrapping_mul(align as u32))) + { + last_access_size as u32 + } else { + // Initial memory access only happens at CHUNK boundary + let chunk_ratio = 1 << (CHUNK_BITS - align_bits); + let chunk_offset = curr_block & (chunk_ratio - 1); + curr_block -= chunk_offset; + CHUNK as u32 + }; + + if curr_block_size > align as u32 { + let curr_ptr = curr_block.wrapping_mul(align as u32); + splits.push((curr_ptr, curr_block_size)); + } + + curr_block += curr_block_size >> align_bits; + } + // Merge added blocks from align to size bytes + let merges = vec![(ptr, size)]; + + (splits, merges) + } + + #[allow(clippy::type_complexity)] + fn apply_adapter_updates( + &mut self, + addr_space: u32, + ptr: u32, + size: u32, + trace_heights: &mut Option<&mut [u32]>, + memory_updates: &mut Option)>>, + ) { + let adapter_offset = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + 2 + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + + let (splits, merges) = self.calculate_splits_and_merges(addr_space, ptr, size); + for (curr_ptr, curr_size) in splits { + if let Some(trace_heights) = trace_heights { + apply_single_adapter_heights_update(trace_heights, curr_size); + } else { + apply_single_adapter_heights_update( + &mut self.trace_heights[adapter_offset..], + curr_size, + ); + } + let updates = add_memory_access_split_with_return( + &mut self.last_memory_access, + (addr_space, curr_ptr), + curr_size, + self.as_byte_alignment_bits[addr_space as usize], + ); + if let Some(memory_updates) = memory_updates { + memory_updates.extend(&updates); + } + } + for (curr_ptr, curr_size) in merges { + if let Some(trace_heights) = trace_heights { + apply_single_adapter_heights_update(trace_heights, curr_size); + } else { + apply_single_adapter_heights_update( + &mut self.trace_heights[adapter_offset..], + curr_size, + ); + } + let updates = add_memory_access_merge_with_return( + &mut self.last_memory_access, + (addr_space, curr_ptr), + curr_size, + self.as_byte_alignment_bits[addr_space as usize], + ); + if let Some(memory_updates) = memory_updates { + memory_updates.extend(updates); + } + } + } + + fn update_adapter_heights(&mut self, addr_space: u32, ptr: u32, size: u32) { + self.apply_adapter_updates(addr_space, ptr, size, &mut None, &mut None); + } + + pub fn finalize_access_adapter_heights(&mut self) { + let indices_to_process: Vec<_> = self + .leaf_indices + .iter() + .map(|&idx| { + let (addr_space, block_id) = self.memory_dimensions.index_to_label(idx); + (addr_space, block_id) + }) + .collect(); + for (addr_space, block_id) in indices_to_process { + self.update_adapter_heights(addr_space, block_id * CHUNK as u32, CHUNK as u32); + } + } + + pub fn trace_heights_if_finalized(&mut self) -> Vec { + let indices_to_process: Vec<_> = self + .leaf_indices + .iter() + .map(|&idx| { + let (addr_space, block_id) = self.memory_dimensions.index_to_label(idx); + (addr_space, block_id) + }) + .collect(); + + let mut access_adapter_updates = vec![0; self.num_access_adapters as usize]; + let mut memory_updates = Some(vec![]); + for (addr_space, block_id) in indices_to_process { + let ptr = block_id * CHUNK as u32; + self.apply_adapter_updates( + addr_space, + ptr, + CHUNK as u32, + &mut Some(&mut access_adapter_updates), + &mut memory_updates, + ); + } + + // Restore original memory state + for (key, old_value) in memory_updates.unwrap().into_iter().rev() { + match old_value { + Some(value) => { + self.last_memory_access.insert(key, value); + } + None => { + self.last_memory_access.remove(&key); + } + } + } + + let adapter_offset = if self.continuations_enabled { + PUBLIC_VALUES_AIR_ID + 2 + } else { + PUBLIC_VALUES_AIR_ID + 1 + }; + self.trace_heights + .iter() + .enumerate() + .map(|(i, &height)| { + if i >= adapter_offset && i < adapter_offset + access_adapter_updates.len() { + height + access_adapter_updates[i - adapter_offset] + } else { + height + } + }) + .collect() + } +} + +impl E1E2ExecutionCtx for MeteredCtxExact { + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32) { + debug_assert!( + address_space != RV32_IMM_AS, + "address space must not be immediate" + ); + debug_assert!(size.is_power_of_two(), "size must be a power of 2"); + + // Handle access adapter updates + self.update_adapter_heights(address_space, ptr, size); + + // Handle merkle tree updates + // TODO(ayush): see if this can be approximated by total number of reads/writes for AS != + // register + self.update_boundary_merkle_heights(address_space, ptr, size); + } +} + +fn apply_single_adapter_heights_update(trace_heights: &mut [u32], size: u32) { + let size_bits = size.ilog2(); + for adapter_bits in (3..=size_bits).rev() { + trace_heights[adapter_bits as usize - 1] += 1 << (size_bits - adapter_bits); + } +} + +#[allow(clippy::type_complexity)] +fn add_memory_access( + memory_access_map: &mut BTreeMap<(u8, u32), (u8, u8)>, + (address_space, ptr): (u32, u32), + size: u32, + align_bits: u8, + is_split: bool, +) -> Vec<((u8, u32), Option<(u8, u8)>)> { + let align = 1 << align_bits; + debug_assert_eq!( + size & (align as u32 - 1), + 0, + "Size must be a multiple of alignment" + ); + + let num_chunks = size >> align_bits; + let mut old_values = Vec::with_capacity(num_chunks as usize); + + for i in 0..num_chunks { + let curr_ptr = ptr.wrapping_add(i * align as u32); + let key = (address_space as u8, curr_ptr); + + let value = if is_split { + (align as u8, 0) + } else { + (size as u8, (i * align as u32) as u8) + }; + + let old_value = memory_access_map.insert(key, value); + old_values.push((key, old_value)); + } + + old_values +} + +#[allow(clippy::type_complexity)] +fn add_memory_access_split_with_return( + memory_access_map: &mut BTreeMap<(u8, u32), (u8, u8)>, + (address_space, ptr): (u32, u32), + size: u32, + align_bits: u8, +) -> Vec<((u8, u32), Option<(u8, u8)>)> { + add_memory_access( + memory_access_map, + (address_space, ptr), + size, + align_bits, + true, + ) +} + +#[allow(clippy::type_complexity)] +fn add_memory_access_merge_with_return( + memory_access_map: &mut BTreeMap<(u8, u32), (u8, u8)>, + (address_space, ptr): (u32, u32), + size: u32, + align_bits: u8, +) -> Vec<((u8, u32), Option<(u8, u8)>)> { + add_memory_access( + memory_access_map, + (address_space, ptr), + size, + align_bits, + false, + ) +} + +fn calculate_merkle_node_updates( + leaf_id: u64, + pred_id: Option, + succ_id: Option, + height: u32, +) -> u32 { + // First node requires height many updates + if pred_id.is_none() && succ_id.is_none() { + return height; + } + + // Calculate the difference in divergence + let mut diff = 0; + + // Add new divergences between pred and leaf_index + if let Some(p) = pred_id { + let new_divergence = (p ^ leaf_id).ilog2(); + diff += new_divergence; + } + + // Add new divergences between leaf_index and succ + if let Some(s) = succ_id { + let new_divergence = (leaf_id ^ s).ilog2(); + diff += new_divergence; + } + + // Remove old divergence between pred and succ if both existed + if let (Some(p), Some(s)) = (pred_id, succ_id) { + let old_divergence = (p ^ s).ilog2(); + diff -= old_divergence; + } + + diff +} diff --git a/crates/vm/src/arch/execution_mode/metered/mod.rs b/crates/vm/src/arch/execution_mode/metered/mod.rs new file mode 100644 index 0000000000..389fb2c1ed --- /dev/null +++ b/crates/vm/src/arch/execution_mode/metered/mod.rs @@ -0,0 +1,293 @@ +pub mod bounded; +pub mod exact; + +// pub use exact::MeteredCtxExact as MeteredCtx; +pub use bounded::MeteredCtxBounded as MeteredCtx; +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::{p3_field::PrimeField32, ChipUsageGetter}; +use p3_baby_bear::BabyBear; + +use crate::arch::{ + execution_control::ExecutionControl, ChipId, ExecutionError, InsExecutorE1, VmChipComplex, + VmConfig, VmSegmentState, VmStateMut, CONNECTOR_AIR_ID, DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, + DEFAULT_MAX_SEGMENT_LEN, PROGRAM_AIR_ID, PUBLIC_VALUES_AIR_ID, +}; + +/// Check segment every 100 instructions. +const SEGMENT_CHECK_INTERVAL: u64 = 100; + +// TODO(ayush): fix these values +const MAX_TRACE_HEIGHT: u32 = DEFAULT_MAX_SEGMENT_LEN as u32; +const MAX_TRACE_CELLS_PER_CHIP: usize = DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT; +const MAX_INTERACTIONS: usize = BabyBear::ORDER_U32 as usize; + +#[derive(derive_new::new, Debug)] +pub struct Segment { + pub clk_start: u64, + pub num_cycles: u64, + pub trace_heights: Vec, +} + +pub struct MeteredExecutionControl<'a> { + // Constants + air_names: &'a [String], + pub widths: &'a [usize], + pub interactions: &'a [usize], + // State + // TODO(ayush): should probably be in metered ctx + pub clk_last_segment_check: u64, + pub segments: Vec, +} + +impl<'a> MeteredExecutionControl<'a> { + pub fn new(air_names: &'a [String], widths: &'a [usize], interactions: &'a [usize]) -> Self { + Self { + air_names, + widths, + interactions, + clk_last_segment_check: 0, + segments: vec![], + } + } + + /// Calculate the total cells used based on trace heights and widths + fn calculate_total_cells(&self, trace_heights: &[u32]) -> usize { + trace_heights + .iter() + .zip(self.widths) + .map(|(&height, &width)| height.next_power_of_two() as usize * width) + .sum() + } + + /// Calculate the total interactions based on trace heights and interaction counts + fn calculate_total_interactions(&self, trace_heights: &[u32]) -> usize { + trace_heights + .iter() + .zip(self.interactions) + // We add 1 for the zero messages from the padding rows + .map(|(&height, &interactions)| (height + 1) as usize * interactions) + .sum() + } + + fn should_segment(&mut self, state: &mut VmSegmentState) -> bool { + let trace_heights = state.ctx.trace_heights_if_finalized(); + let max_trace_cells = MAX_TRACE_CELLS_PER_CHIP * trace_heights.len(); + for (i, &height) in trace_heights.iter().enumerate() { + let padded_height = height.next_power_of_two(); + if padded_height > MAX_TRACE_HEIGHT { + tracing::info!( + "Segment {:2} | clk {:9} | chip {} ({}) height ({:8}) > max ({:8})", + self.segments.len(), + self.clk_last_segment_check, + i, + self.air_names[i], + padded_height, + MAX_TRACE_HEIGHT + ); + return true; + } + } + + let total_cells = self.calculate_total_cells(&trace_heights); + if total_cells > max_trace_cells { + tracing::info!( + "Segment {:2} | clk {:9} | total cells ({:10}) > max ({:10})", + self.segments.len(), + self.clk_last_segment_check, + total_cells, + max_trace_cells + ); + return true; + } + + let total_interactions = self.calculate_total_interactions(&trace_heights); + if total_interactions > MAX_INTERACTIONS { + tracing::info!( + "Segment {:2} | clk {:9} | total interactions ({:11}) > max ({:11})", + self.segments.len(), + self.clk_last_segment_check, + total_interactions, + MAX_INTERACTIONS + ); + return true; + } + + false + } + + fn reset_segment( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) where + F: PrimeField32, + VC: VmConfig, + { + state.ctx.leaf_indices.clear(); + + // TODO(ayush): only reset trace heights for chips that are not constant height instead + // of refilling again + state.ctx.trace_heights.fill(0); + + // Program | Connector | Public Values | Memory ... | Executors (except Public Values) | + // Range Checker + state.ctx.trace_heights[PROGRAM_AIR_ID] = + chip_complex.program_chip().true_program_length as u32; + state.ctx.trace_heights[CONNECTOR_AIR_ID] = 2; + + let mut offset = if chip_complex.config().has_public_values_chip() { + PUBLIC_VALUES_AIR_ID + 1 + } else { + PUBLIC_VALUES_AIR_ID + }; + offset += chip_complex.memory_controller().num_airs(); + + // Periphery chips with constant heights + for (i, chip_id) in chip_complex + .inventory + .insertion_order + .iter() + .rev() + .enumerate() + { + if let &ChipId::Periphery(id) = chip_id { + if let Some(constant_height) = + chip_complex.inventory.periphery[id].constant_trace_height() + { + state.ctx.trace_heights[offset + i] = constant_height as u32; + } + } + } + + // Range checker chip + if let (Some(range_checker_height), Some(last_height)) = ( + chip_complex.range_checker_chip().constant_trace_height(), + state.ctx.trace_heights.last_mut(), + ) { + *last_height = range_checker_height as u32; + } + } + + fn check_segment_limits( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) where + F: PrimeField32, + VC: VmConfig, + { + // Avoid checking segment too often. + if state.clk < self.clk_last_segment_check + SEGMENT_CHECK_INTERVAL { + return; + } + + if self.should_segment(state) { + let clk_start = self + .segments + .last() + .map_or(0, |s| s.clk_start + s.num_cycles); + let segment = Segment { + clk_start, + num_cycles: self.clk_last_segment_check - clk_start, + // TODO(ayush): this is trace heights after overflow so an overestimate + trace_heights: state.ctx.trace_heights.clone(), + }; + self.segments.push(segment); + + self.reset_segment::(state, chip_complex); + } + + self.clk_last_segment_check = state.clk; + } +} + +impl ExecutionControl for MeteredExecutionControl<'_> +where + F: PrimeField32, + VC: VmConfig, + VC::Executor: InsExecutorE1, +{ + type Ctx = MeteredCtx; + + fn should_suspend( + &mut self, + _state: &mut VmSegmentState, + _chip_complex: &VmChipComplex, + ) -> bool { + false + } + + fn on_start( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + self.reset_segment::(state, chip_complex); + } + + fn on_suspend_or_terminate( + &mut self, + state: &mut VmSegmentState, + _chip_complex: &mut VmChipComplex, + _exit_code: Option, + ) { + state.ctx.finalize_access_adapter_heights(); + + tracing::info!( + "Segment {:2} | clk {:9} | terminated", + self.segments.len(), + state.clk, + ); + // Add the last segment + let clk_start = self + .segments + .last() + .map_or(0, |s| s.clk_start + s.num_cycles); + let segment = Segment { + clk_start, + num_cycles: state.clk - clk_start, + // TODO(ayush): this is trace heights after overflow so an overestimate + trace_heights: state.ctx.trace_heights.clone(), + }; + self.segments.push(segment); + } + + /// Execute a single instruction + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + // Check if segmentation needs to happen + self.check_segment_limits::(state, chip_complex); + + let mut offset = if chip_complex.config().has_public_values_chip() { + PUBLIC_VALUES_AIR_ID + 1 + } else { + PUBLIC_VALUES_AIR_ID + }; + offset += chip_complex.memory_controller().num_airs(); + + let &Instruction { opcode, .. } = instruction; + if let Some((executor, i)) = chip_complex.inventory.get_mut_executor_with_index(&opcode) { + let mut vm_state = VmStateMut { + pc: &mut state.pc, + memory: state.memory.as_mut().unwrap(), + ctx: &mut state.ctx, + }; + executor.execute_metered(&mut vm_state, instruction, offset + i)?; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + state.clk += 1; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/execution_mode/mod.rs b/crates/vm/src/arch/execution_mode/mod.rs new file mode 100644 index 0000000000..0f37eb3638 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/mod.rs @@ -0,0 +1,8 @@ +pub mod e1; +pub mod metered; +pub mod tracegen; + +// TODO(ayush): better name +pub trait E1E2ExecutionCtx { + fn on_memory_operation(&mut self, address_space: u32, ptr: u32, size: u32); +} diff --git a/crates/vm/src/arch/execution_mode/tracegen/mod.rs b/crates/vm/src/arch/execution_mode/tracegen/mod.rs new file mode 100644 index 0000000000..ea03d7966d --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen/mod.rs @@ -0,0 +1,7 @@ +mod normal; +mod segmentation; + +pub use normal::TracegenExecutionControl; +pub use segmentation::TracegenExecutionControlWithSegmentation; + +pub type TracegenCtx = (); diff --git a/crates/vm/src/arch/execution_mode/tracegen/normal.rs b/crates/vm/src/arch/execution_mode/tracegen/normal.rs new file mode 100644 index 0000000000..c2331caa85 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen/normal.rs @@ -0,0 +1,103 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + arch::{ + execution_control::ExecutionControl, ExecutionError, ExecutionState, InstructionExecutor, + VmChipComplex, VmConfig, VmSegmentState, + }, + system::memory::{MemoryImage, INITIAL_TIMESTAMP}, +}; + +pub type TracegenCtx = (); + +/// Implementation of the ExecutionControl trait using the old segmentation strategy +pub struct TracegenExecutionControl { + // State + pub clk_end: u64, + // TODO(ayush): do we need this if only executing one segment? + pub final_memory: Option, +} + +impl TracegenExecutionControl { + pub fn new(clk_end: u64) -> Self { + Self { + clk_end, + final_memory: None, + } + } +} + +impl ExecutionControl for TracegenExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + type Ctx = TracegenCtx; + + fn should_suspend( + &mut self, + state: &mut VmSegmentState, + _chip_complex: &VmChipComplex, + ) -> bool { + state.clk >= self.clk_end + } + + fn on_start( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + chip_complex + .connector_chip_mut() + .begin(ExecutionState::new(state.pc, INITIAL_TIMESTAMP + 1)); + } + + fn on_suspend_or_terminate( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ) { + // TODO(ayush): this should ideally not be here + self.final_memory = Some(chip_complex.base.memory_controller.memory_image().clone()); + + let timestamp = chip_complex.memory_controller().timestamp(); + chip_complex + .connector_chip_mut() + .end(ExecutionState::new(state.pc, timestamp), exit_code); + } + + /// Execute a single instruction + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + let timestamp = chip_complex.memory_controller().timestamp(); + + let &Instruction { opcode, .. } = instruction; + + if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { + let memory_controller = &mut chip_complex.base.memory_controller; + let new_state = executor.execute( + memory_controller, + instruction, + ExecutionState::new(state.pc, timestamp), + )?; + state.pc = new_state.pc; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + state.clk += 1; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/execution_mode/tracegen/segmentation.rs b/crates/vm/src/arch/execution_mode/tracegen/segmentation.rs new file mode 100644 index 0000000000..f193dc2059 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen/segmentation.rs @@ -0,0 +1,118 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::TracegenCtx; +use crate::{ + arch::{ + execution_control::ExecutionControl, ExecutionError, ExecutionState, InstructionExecutor, + VmChipComplex, VmConfig, VmSegmentState, + }, + system::memory::{MemoryImage, INITIAL_TIMESTAMP}, +}; + +/// Check segment every 100 instructions. +const SEGMENT_CHECK_INTERVAL: usize = 100; + +// TODO(ayush): fix this name since it's a mouthful +/// Implementation of the ExecutionControl trait using the old segmentation strategy +pub struct TracegenExecutionControlWithSegmentation { + // Constant + air_names: Vec, + // State + pub since_last_segment_check: usize, + pub final_memory: Option, +} + +impl TracegenExecutionControlWithSegmentation { + pub fn new(air_names: Vec) -> Self { + Self { + since_last_segment_check: 0, + air_names, + final_memory: None, + } + } +} + +impl ExecutionControl for TracegenExecutionControlWithSegmentation +where + F: PrimeField32, + VC: VmConfig, +{ + type Ctx = TracegenCtx; + + fn should_suspend( + &mut self, + _state: &mut VmSegmentState, + chip_complex: &VmChipComplex, + ) -> bool { + // Avoid checking segment too often. + if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL { + self.since_last_segment_check += 1; + return false; + } + self.since_last_segment_check = 0; + chip_complex.config().segmentation_strategy.should_segment( + &self.air_names, + &chip_complex.dynamic_trace_heights().collect::>(), + &chip_complex.current_trace_cells(), + ) + } + + fn on_start( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + chip_complex + .connector_chip_mut() + .begin(ExecutionState::new(state.pc, INITIAL_TIMESTAMP + 1)); + } + + fn on_suspend_or_terminate( + &mut self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ) { + // TODO(ayush): this should ideally not be here + self.final_memory = Some(chip_complex.base.memory_controller.memory_image().clone()); + + let timestamp = chip_complex.memory_controller().timestamp(); + chip_complex + .connector_chip_mut() + .end(ExecutionState::new(state.pc, timestamp), exit_code); + } + + /// Execute a single instruction + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + let timestamp = chip_complex.memory_controller().timestamp(); + + let &Instruction { opcode, .. } = instruction; + + if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { + let memory_controller = &mut chip_complex.base.memory_controller; + let new_state = executor.execute( + memory_controller, + instruction, + ExecutionState::new(state.pc, timestamp), + )?; + state.pc = new_state.pc; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + state.clk += 1; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/extensions.rs b/crates/vm/src/arch/extensions.rs index adda318f6a..fd34a1815a 100644 --- a/crates/vm/src/arch/extensions.rs +++ b/crates/vm/src/arch/extensions.rs @@ -10,7 +10,7 @@ use getset::Getters; use itertools::{zip_eq, Itertools}; #[cfg(feature = "bench-metrics")] use metrics::counter; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor}; use openvm_circuit_primitives::{ utils::next_power_of_two_or_zero, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, @@ -40,17 +40,23 @@ use super::{ }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; -use crate::system::{ - connector::VmConnectorChip, - memory::{ - offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, MemoryImage, OfflineMemory, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, +use crate::{ + arch::{ExecutionBridge, VmAirWrapper}, + system::{ + connector::VmConnectorChip, + memory::{ + offline_checker::{MemoryBridge, MemoryBus}, + MemoryController, MemoryImage, BOUNDARY_AIR_OFFSET, MERKLE_AIR_OFFSET, + }, + native_adapter::{NativeAdapterAir, NativeAdapterStep}, + phantom::PhantomChip, + poseidon2::Poseidon2PeripheryChip, + program::{ProgramBus, ProgramChip}, + public_values::{ + core::{PublicValuesCoreAir, PublicValuesCoreStep}, + PublicValuesChip, + }, }, - native_adapter::NativeAdapterChip, - phantom::PhantomChip, - poseidon2::Poseidon2PeripheryChip, - program::{ProgramBus, ProgramChip}, - public_values::{core::PublicValuesCoreChip, PublicValuesChip}, }; /// Global AIR ID in the VM circuit verifying key. @@ -205,7 +211,7 @@ pub struct VmInventory { pub periphery: Vec

, /// Order of insertion. The reverse of this will be the order the chips are destroyed /// to generate trace. - insertion_order: Vec, + pub insertion_order: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -331,6 +337,25 @@ impl VmInventory { self.executors.get_mut(*id) } + pub fn get_mut_executor_with_index(&mut self, opcode: &VmOpcode) -> Option<(&mut E, usize)> { + let id = *self.instruction_lookup.get(opcode)?; + + self.executors.get_mut(id).map(|executor| { + // TODO(ayush): cache this somewhere + let insertion_id = self + .insertion_order + .iter() + .rev() + .position(|chip_id| match chip_id { + ChipId::Executor(exec_id) => *exec_id == id, + _ => false, + }) + .unwrap(); + + (executor, insertion_id) + }) + } + pub fn executors(&self) -> &[E] { &self.executors } @@ -494,10 +519,6 @@ impl SystemBase { self.memory_controller.memory_bridge() } - pub fn offline_memory(&self) -> Arc>> { - self.memory_controller.offline_memory().clone() - } - pub fn execution_bus(&self) -> ExecutionBus { self.connector_chip.air.execution_bus } @@ -519,7 +540,7 @@ impl SystemBase { } } -#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor)] +#[derive(ChipUsageGetter, Chip, AnyEnum, From, InstructionExecutor, InsExecutorE1)] pub enum SystemExecutor { PublicValues(PublicValuesChip), Phantom(RefCell>), @@ -557,7 +578,6 @@ impl SystemComplex { ) }; let memory_bridge = memory_controller.memory_bridge(); - let offline_memory = memory_controller.offline_memory(); let program_chip = ProgramChip::new(program_bus); let connector_chip = VmConnectorChip::new( execution_bus, @@ -570,14 +590,29 @@ impl SystemComplex { // PublicValuesChip is required when num_public_values > 0 in single segment mode. if config.has_public_values_chip() { assert_eq!(inventory.executors().len(), Self::PV_EXECUTOR_IDX); + + // TODO(ayush): this should be decided after e2 execution + const MAX_INS_CAPACITY: usize = 1 << 22; let chip = PublicValuesChip::new( - NativeAdapterChip::new(execution_bus, program_bus, memory_bridge), - PublicValuesCoreChip::new( + VmAirWrapper::new( + NativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + PublicValuesCoreAir::new( + config.num_public_values, + config.max_constraint_degree as u32 - 1, + ), + ), + PublicValuesCoreStep::new( + NativeAdapterStep::new(), config.num_public_values, config.max_constraint_degree as u32 - 1, ), - offline_memory, + MAX_INS_CAPACITY, + memory_controller.helper(), ); + inventory .add_executor(chip, [PublishOpcode::PUBLISH.global_opcode()]) .unwrap(); @@ -776,11 +811,11 @@ impl VmChipComplex { .as_any_kind_mut() .downcast_mut() .expect("Poseidon2 chip required for persistent memory"); - self.base.memory_controller.finalize(Some(hasher)) + self.base.memory_controller.finalize(Some(hasher)); } else { self.base .memory_controller - .finalize(None::<&mut Poseidon2PeripheryChip>) + .finalize(None::<&mut Poseidon2PeripheryChip>); }; } @@ -788,7 +823,7 @@ impl VmChipComplex { self.base.program_chip.set_program(program); } - pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) { + pub(crate) fn set_initial_memory(&mut self, memory: MemoryImage) { self.base.memory_controller.set_initial_memory(memory); } @@ -809,7 +844,7 @@ impl VmChipComplex { } // we always need to special case it because we need to fix the air id. - fn public_values_chip_idx(&self) -> Option { + pub(crate) fn public_values_chip_idx(&self) -> Option { self.config .has_public_values_chip() .then_some(Self::PV_EXECUTOR_IDX) @@ -838,7 +873,7 @@ impl VmChipComplex { } /// Return air names of all chips in order. - pub(crate) fn air_names(&self) -> Vec + pub fn air_names(&self) -> Vec where E: ChipUsageGetter, P: ChipUsageGetter, @@ -851,6 +886,7 @@ impl VmChipComplex { .chain([self.range_checker_chip().air_name()]) .collect() } + /// Return trace heights of all chips in order corresponding to `air_names`. pub(crate) fn current_trace_heights(&self) -> Vec where diff --git a/crates/vm/src/arch/integration_api.rs b/crates/vm/src/arch/integration_api.rs index b1116d8c48..4294458b22 100644 --- a/crates/vm/src/arch/integration_api.rs +++ b/crates/vm/src/arch/integration_api.rs @@ -1,28 +1,28 @@ -use std::{ - array::from_fn, - borrow::Borrow, - marker::PhantomData, - sync::{Arc, Mutex}, -}; +use std::{array::from_fn, borrow::Borrow, marker::PhantomData, sync::Arc}; use openvm_circuit_primitives::utils::next_power_of_two_or_zero; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_stark_backend::{ - air_builders::{debug::DebugConstraintBuilder, symbolic::SymbolicRapBuilder}, config::{StarkGenericConfig, Val}, p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{FieldAlgebra, PrimeField32}, + p3_field::{Field, FieldAlgebra, PrimeField32}, p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_maybe_rayon::prelude::*, prover::types::AirProofInput, - rap::{get_air_name, BaseAirWithPublicValues, PartitionedBaseAir}, + rap::{get_air_name, AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, AirRef, Chip, ChipUsageGetter, }; use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use super::{ExecutionState, InstructionExecutor, Result}; -use crate::system::memory::{MemoryController, OfflineMemory}; +use super::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionState, InsExecutorE1, InstructionExecutor, Result, VmStateMut, +}; +use crate::system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, MemoryController, SharedMemoryHelper, +}; /// The interface between primitive AIR and machine adapter AIR. pub trait VmAdapterInterface { @@ -37,60 +37,6 @@ pub trait VmAdapterInterface { type ProcessedInstruction; } -/// The adapter owns all memory accesses and timestamp changes. -/// The adapter AIR should also own `ExecutionBridge` and `MemoryBridge`. -pub trait VmAdapterChip { - /// Records generated by adapter before main instruction execution - type ReadRecord: Send + Serialize + DeserializeOwned; - /// Records generated by adapter after main instruction execution - type WriteRecord: Send + Serialize + DeserializeOwned; - /// AdapterAir should not have public values - type Air: BaseAir + Clone; - - type Interface: VmAdapterInterface; - - /// Given instruction, perform memory reads and return only the read data that the integrator - /// needs to use. This is called at the start of instruction execution. - /// - /// The implementer may choose to store data in the `Self::ReadRecord` struct, for example in - /// an [Option], which will later be sent to the `postprocess` method. - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )>; - - /// Given instruction and the data to write, perform memory writes and return the `(record, - /// next_timestamp)` of the full adapter record for this instruction. This is guaranteed to - /// be called after `preprocess`. - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)>; - - /// Populates `row_slice` with values corresponding to `record`. - /// The provided `row_slice` will have length equal to `self.air().width()`. - /// This function will be called for each row in the trace which is being used, and all other - /// rows in the trace will be filled with zeroes. - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ); - - fn air(&self) -> &Self::Air; -} - pub trait VmAdapterAir: BaseAir { type Interface: VmAdapterInterface; @@ -111,6 +57,7 @@ pub trait VmAdapterAir: BaseAir { fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var; } +// TODO: delete /// Trait to be implemented on primitive chip to integrate with the machine. pub trait VmCoreChip> { /// Minimum data that must be recorded to be able to generate trace for one row of @@ -183,6 +130,7 @@ where } } +// TODO: delete pub struct AdapterRuntimeContext> { /// Leave as `None` to allow the adapter to decide the `to_pc` automatically. pub to_pc: Option, @@ -207,35 +155,110 @@ pub struct AdapterAirContext> { pub instruction: I::ProcessedInstruction, } -pub struct VmChipWrapper, C: VmCoreChip> { - pub adapter: A, - pub core: C, - pub records: Vec<(A::ReadRecord, A::WriteRecord, C::Record)>, - offline_memory: Arc>>, +/// Interface for trace generation of a single instruction.The trace is provided as a mutable +/// buffer during both instruction execution and trace generation. +/// It is expected that no additional memory allocation is necessary and the trace buffer +/// is sufficient, with possible overwriting. +pub trait TraceStep { + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + // TODO(ayush): combine to a single struct + trace: &mut [F], + trace_offset: &mut usize, + // TODO(ayush): move air inside step and remove width + width: usize, + ) -> Result<()>; + + /// Populates `trace`. This function will always be called after + /// [`TraceStep::execute`], so the `trace` should already contain context necessary to + /// fill in the rest of it. + // TODO(ayush): come up with a better abstraction for chips that fill a dynamic number of rows + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut [F], + width: usize, + rows_used: usize, + ) where + Self: Send + Sync, + F: Send + Sync, + { + trace[..rows_used * width] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_trace_row(mem_helper, row_slice); + }); + trace[rows_used * width..] + .par_chunks_exact_mut(width) + .for_each(|row_slice| { + self.fill_dummy_trace_row(mem_helper, row_slice); + }); + } + + /// Populates `row_slice`. This function will always be called after + /// [`TraceStep::execute`], so the `row_slice` should already contain context necessary to + /// fill in the rest of the row. This function will be called for each row in the trace which is + /// being used, and all other rows in the trace will be filled with zeroes. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { + unreachable!("fill_trace_row is not implemented") + } + + /// Populates `row_slice`. This function will be called on dummy rows. + /// By default the trace is padded with empty (all 0) rows to make the height a power of 2. + /// + /// The provided `row_slice` will have length equal to the width of the AIR. + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, _row_slice: &mut [F]) { + // By default, the row is filled with zeroes + } + /// Returns a list of public values to publish. + fn generate_public_values(&self) -> Vec { + vec![] + } + + /// Displayable opcode name for logging and debugging purposes. + fn get_opcode_name(&self, opcode: usize) -> String; } -const DEFAULT_RECORDS_CAPACITY: usize = 1 << 20; +// TODO(ayush): rename to ChipWithExecutionContext or something +pub struct NewVmChipWrapper { + pub air: AIR, + pub step: STEP, + pub trace_buffer: Vec, + // TODO(ayush): width should be a constant? + width: usize, + buffer_idx: usize, + mem_helper: SharedMemoryHelper, +} -impl VmChipWrapper +impl NewVmChipWrapper where - A: VmAdapterChip, - C: VmCoreChip, + F: Field, + AIR: BaseAir, { - pub fn new(adapter: A, core: C, offline_memory: Arc>>) -> Self { + pub fn new(air: AIR, step: STEP, height: usize, mem_helper: SharedMemoryHelper) -> Self { + assert!(height == 0 || height.is_power_of_two()); + let width = air.width(); + let trace_buffer = F::zero_vec(height * width); Self { - adapter, - core, - records: Vec::with_capacity(DEFAULT_RECORDS_CAPACITY), - offline_memory, + air, + step, + trace_buffer, + width, + buffer_idx: 0, + mem_helper, } } } -impl InstructionExecutor for VmChipWrapper +impl InstructionExecutor for NewVmChipWrapper where F: PrimeField32, - A: VmAdapterChip + Send + Sync, - M: VmCoreChip + Send + Sync, + STEP: TraceStep // TODO: CTX? + + StepExecutorE1, { fn execute( &mut self, @@ -243,104 +266,195 @@ where instruction: &Instruction, from_state: ExecutionState, ) -> Result> { - let (reads, read_record) = self.adapter.preprocess(memory, instruction)?; - let (output, core_record) = - self.core - .execute_instruction(instruction, from_state.pc, reads)?; - let (to_state, write_record) = - self.adapter - .postprocess(memory, instruction, from_state, output, &read_record)?; - self.records.push((read_record, write_record, core_record)); - Ok(to_state) + let mut pc = from_state.pc; + let state = VmStateMut { + pc: &mut pc, + memory: &mut memory.memory, + ctx: &mut (), + }; + self.step.execute( + state, + instruction, + &mut self.trace_buffer, + &mut self.buffer_idx, + self.width, + )?; + + Ok(ExecutionState { + pc, + timestamp: memory.memory.timestamp, + }) } fn get_opcode_name(&self, opcode: usize) -> String { - self.core.get_opcode_name(opcode) + self.step.get_opcode_name(opcode) } } // Note[jpw]: the statement we want is: -// - when A::Air is an AdapterAir for all AirBuilders needed by stark-backend -// - and when M::Air is an CoreAir for all AirBuilders needed by stark-backend, -// then VmAirWrapper is an Air for all AirBuilders needed -// by stark-backend, which is equivalent to saying it implements AirRef +// - `Air` is an `Air` for all `AB: AirBuilder`s needed by stark-backend +// which is equivalent to saying it implements AirRef // The where clauses to achieve this statement is unfortunately really verbose. -impl Chip for VmChipWrapper, A, C> +impl Chip for NewVmChipWrapper, AIR, STEP> where SC: StarkGenericConfig, Val: PrimeField32, - A: VmAdapterChip> + Send + Sync, - C: VmCoreChip, A::Interface> + Send + Sync, - A::Air: Send + Sync + 'static, - A::Air: VmAdapterAir>>, - A::Air: for<'a> VmAdapterAir>, - C::Air: Send + Sync + 'static, - C::Air: VmCoreAir< - SymbolicRapBuilder>, - >>>::Interface, - >, - C::Air: for<'a> VmCoreAir< - DebugConstraintBuilder<'a, SC>, - >>::Interface, - >, + STEP: TraceStep, ()> + Send + Sync, + AIR: Clone + AnyRap + 'static, { fn air(&self) -> AirRef { - let air: VmAirWrapper = VmAirWrapper { - adapter: self.adapter.air().clone(), - core: self.core.air().clone(), - }; - Arc::new(air) + Arc::new(self.air.clone()) } - fn generate_air_proof_input(self) -> AirProofInput { - let num_records = self.records.len(); - let height = next_power_of_two_or_zero(num_records); - let core_width = self.core.air().width(); - let adapter_width = self.adapter.air().width(); - let width = core_width + adapter_width; - let mut values = Val::::zero_vec(height * width); - - let memory = self.offline_memory.lock().unwrap(); - - // This zip only goes through records. - // The padding rows between records.len()..height are filled with zeros. - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row_slice, record)| { - let (adapter_row, core_row) = row_slice.split_at_mut(adapter_width); - self.adapter - .generate_trace_row(adapter_row, record.0, record.1, &memory); - self.core.generate_trace_row(core_row, record.2); - }); - - let mut trace = RowMajorMatrix::new(values, width); - self.core.finalize(&mut trace, num_records); - - AirProofInput::simple(trace, self.core.generate_public_values()) + fn generate_air_proof_input(mut self) -> AirProofInput { + assert_eq!(self.buffer_idx % self.width, 0); + let rows_used = self.current_trace_height(); + let height = next_power_of_two_or_zero(rows_used); + // This should be automatic since trace_buffer's height is a power of two: + assert!(height.checked_mul(self.width).unwrap() <= self.trace_buffer.len()); + self.trace_buffer.truncate(height * self.width); + let mem_helper = self.mem_helper.as_borrowed(); + self.step + .fill_trace(&mem_helper, &mut self.trace_buffer, self.width, rows_used); + drop(self.mem_helper); + let trace = RowMajorMatrix::new(self.trace_buffer, self.width); + // self.inner.finalize(&mut trace, num_records); + + AirProofInput::simple(trace, self.step.generate_public_values()) } } -impl ChipUsageGetter for VmChipWrapper +impl ChipUsageGetter for NewVmChipWrapper where - A: VmAdapterChip + Sync, - M: VmCoreChip + Sync, + C: Sync, { fn air_name(&self) -> String { - format!( - "<{},{}>", - get_air_name(self.adapter.air()), - get_air_name(self.core.air()) - ) + get_air_name(&self.air) } fn current_trace_height(&self) -> usize { - self.records.len() + self.buffer_idx / self.width } fn trace_width(&self) -> usize { - self.adapter.air().width() + self.core.air().width() + self.width + } +} + +// TODO[jpw]: switch read,write to store into abstract buffer, then fill_trace_row using buffer +/// A helper trait for expressing generic state accesses within the implementation of +/// [TraceStep]. Note that this is only a helper trait when the same interface of state access +/// is reused or shared by multiple implementations. It is not required to implement this trait if +/// it is easier to implement the [TraceStep] trait directly without this trait. +pub trait AdapterTraceStep { + /// Adapter row width + const WIDTH: usize; + type ReadData; + type WriteData; + /// The minimal amount of information needed to generate the sub-row of the trace matrix. + /// This type has a lifetime so other context, such as references to other chips, can be + /// provided. + type TraceContext<'a> + where + Self: 'a; + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]); + + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + ) -> Self::ReadData; + + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + data: &Self::WriteData, + ); + + // Note[jpw]: should we reuse TraceSubRowGenerator trait instead? + /// Post-execution filling of rest of adapter row. + fn fill_trace_row( + &self, + mem_helper: &MemoryAuxColsFactory, + ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], + ); +} + +pub trait AdapterExecutorE1 +where + F: PrimeField32, +{ + type ReadData; + type WriteData; + + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx; + + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx; +} + +// TODO: Rename core/step to operator +pub trait StepExecutorE1 { + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx; + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()>; +} + +impl InsExecutorE1 for NewVmChipWrapper +where + F: PrimeField32, + S: StepExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + self.step.execute_e1(state, instruction) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> + where + F: PrimeField32, + { + self.step.execute_metered(state, instruction, chip_index) } } +#[derive(Clone, Copy, derive_new::new)] pub struct VmAirWrapper { pub adapter: A, pub core: C, diff --git a/crates/vm/src/arch/mod.rs b/crates/vm/src/arch/mod.rs index 63ee5e6f8b..fdb2b7e49a 100644 --- a/crates/vm/src/arch/mod.rs +++ b/crates/vm/src/arch/mod.rs @@ -2,12 +2,17 @@ mod config; /// Instruction execution traits and types. /// Execution bus and interface. mod execution; +/// Module for controlling VM execution flow, including segmentation and instruction execution +pub mod execution_control; +pub mod execution_mode; /// Traits and builders to compose collections of chips into a virtual machine. mod extensions; /// Traits and wrappers to facilitate VM chip integration mod integration_api; /// Runtime execution and segmentation pub mod segment; +/// Strategy for determining when to segment VM execution +pub mod segmentation_strategy; /// Top level [VirtualMachine] constructor and API. pub mod vm; @@ -23,4 +28,5 @@ pub use execution::*; pub use extensions::*; pub use integration_api::*; pub use segment::*; +pub use segmentation_strategy::*; pub use vm::*; diff --git a/crates/vm/src/arch/segment.rs b/crates/vm/src/arch/segment.rs index 634632ce2b..af57ad7fc9 100644 --- a/crates/vm/src/arch/segment.rs +++ b/crates/vm/src/arch/segment.rs @@ -1,10 +1,7 @@ -use std::sync::Arc; - use backtrace::Backtrace; use openvm_instructions::{ exe::FnBounds, instruction::{DebugInfo, Instruction}, - program::Program, }; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig}, @@ -17,137 +14,46 @@ use openvm_stark_backend::{ }; use super::{ - ExecutionError, GenerationError, Streams, SystemBase, SystemConfig, VmChipComplex, - VmComplexTraceHeights, VmConfig, + execution_control::ExecutionControl, ExecutionError, GenerationError, SystemConfig, + VmChipComplex, VmComplexTraceHeights, VmConfig, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; use crate::{ - arch::{instructions::*, ExecutionState, InstructionExecutor}, - system::memory::MemoryImage, + arch::{instructions::*, InstructionExecutor}, + system::memory::online::GuestMemory, }; -/// Check segment every 100 instructions. -const SEGMENT_CHECK_INTERVAL: usize = 100; - -const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; -// a heuristic number for the maximum number of cells per chip in a segment -// a few reasons for this number: -// 1. `VmAirWrapper` is -// the chip with the most cells in a segment from the reth-benchmark. -// 2. `VmAirWrapper`: -// its trace width is 36 and its after challenge trace width is 80. -const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; - -pub trait SegmentationStrategy: - std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe -{ - /// Whether the execution should segment based on the trace heights and cells. - /// - /// Air names are provided for debugging purposes. - fn should_segment( - &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool; - - /// A strategy that segments more aggressively than the current one. - /// - /// Called when `should_segment` results in a segment that is infeasible. Execution will be - /// re-run with the stricter segmentation strategy. - fn stricter_strategy(&self) -> Arc; -} - -/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. -#[derive(Debug, Clone)] -pub struct DefaultSegmentationStrategy { - max_segment_len: usize, - max_cells_per_chip_in_segment: usize, -} - -impl Default for DefaultSegmentationStrategy { - fn default() -> Self { - Self { - max_segment_len: DEFAULT_MAX_SEGMENT_LEN, - max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, - } - } +pub struct VmSegmentState { + pub clk: u64, + pub pc: u32, + pub memory: Option, + pub exit_code: Option, + pub ctx: Ctx, } -impl DefaultSegmentationStrategy { - pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { - Self { - max_segment_len, - max_cells_per_chip_in_segment: max_segment_len * 120, - } - } - - pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self { +impl VmSegmentState { + pub fn new(clk: u64, pc: u32, memory: Option, ctx: Ctx) -> Self { Self { - max_segment_len, - max_cells_per_chip_in_segment, - } - } - - pub fn max_segment_len(&self) -> usize { - self.max_segment_len - } -} - -const SEGMENTATION_BACKOFF_FACTOR: usize = 4; - -impl SegmentationStrategy for DefaultSegmentationStrategy { - fn should_segment( - &self, - air_names: &[String], - trace_heights: &[usize], - trace_cells: &[usize], - ) -> bool { - for (i, &height) in trace_heights.iter().enumerate() { - if height > self.max_segment_len { - tracing::info!( - "Should segment because chip {} (name: {}) has height {}", - i, - air_names[i], - height - ); - return true; - } - } - for (i, &num_cells) in trace_cells.iter().enumerate() { - if num_cells > self.max_cells_per_chip_in_segment { - tracing::info!( - "Should segment because chip {} (name: {}) has {} cells", - i, - air_names[i], - num_cells - ); - return true; - } + clk, + pc, + memory, + ctx, + exit_code: None, } - false - } - - fn stricter_strategy(&self) -> Arc { - Arc::new(Self { - max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, - max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment - / SEGMENTATION_BACKOFF_FACTOR, - }) } } -pub struct ExecutionSegment +pub struct VmSegmentExecutor where F: PrimeField32, VC: VmConfig, + Ctrl: ExecutionControl, { pub chip_complex: VmChipComplex, - /// Memory image after segment was executed. Not used in trace generation. - pub final_memory: Option>, + /// Execution control for determining segmentation and stopping conditions + pub ctrl: Ctrl, - pub since_last_segment_check: usize, pub trace_height_constraints: Vec, /// Air names for debug purposes only. @@ -157,38 +63,24 @@ where pub metrics: VmMetrics, } -pub struct ExecutionSegmentState { - pub pc: u32, - pub is_terminated: bool, -} - -impl> ExecutionSegment { +impl VmSegmentExecutor +where + F: PrimeField32, + VC: VmConfig, + Ctrl: ExecutionControl, +{ /// Creates a new execution segment from a program and initial state, using parent VM config pub fn new( - config: &VC, - program: Program, - init_streams: Streams, - initial_memory: Option>, + chip_complex: VmChipComplex, trace_height_constraints: Vec, #[allow(unused_variables)] fn_bounds: FnBounds, + ctrl: Ctrl, ) -> Self { - let mut chip_complex = config.create_chip_complex().unwrap(); - chip_complex.set_streams(init_streams); - let program = if !config.system().profiling { - program.strip_debug_infos() - } else { - program - }; - chip_complex.set_program(program); - - if let Some(initial_memory) = initial_memory { - chip_complex.set_initial_memory(initial_memory); - } let air_names = chip_complex.air_names(); Self { chip_complex, - final_memory: None, + ctrl, air_names, trace_height_constraints, #[cfg(feature = "bench-metrics")] @@ -196,7 +88,6 @@ impl> ExecutionSegment { fn_bounds, ..Default::default() }, - since_last_segment_check: 0, } } @@ -211,133 +102,122 @@ impl> ExecutionSegment { .set_override_inventory_trace_heights(overridden_heights.inventory); } - /// Stopping is triggered by should_segment() - pub fn execute_from_pc( + /// Stopping is triggered by should_stop() or if VM is terminated + pub fn execute_from_state( &mut self, - mut pc: u32, - ) -> Result { - let mut timestamp = self.chip_complex.memory_controller().timestamp(); + state: &mut VmSegmentState, + ) -> Result<(), ExecutionError> { let mut prev_backtrace: Option = None; - self.chip_complex - .connector_chip_mut() - .begin(ExecutionState::new(pc, timestamp)); - - let mut did_terminate = false; + // Call the pre-execution hook + self.ctrl.on_start(state, &mut self.chip_complex); loop { - #[allow(unused_variables)] - let (opcode, dsl_instr) = { - let Self { - chip_complex, - #[cfg(feature = "bench-metrics")] - metrics, - .. - } = self; - let SystemBase { - program_chip, - memory_controller, - .. - } = &mut chip_complex.base; - - let (instruction, debug_info) = program_chip.get_instruction(pc)?; - tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction); - - #[allow(unused_variables)] - let (dsl_instr, trace) = debug_info.as_ref().map_or( - (None, None), - |DebugInfo { - dsl_instruction, - trace, - }| (Some(dsl_instruction), trace.as_ref()), - ); - - let &Instruction { opcode, c, .. } = instruction; - if opcode == SystemOpcode::TERMINATE.global_opcode() { - did_terminate = true; - self.chip_complex.connector_chip_mut().end( - ExecutionState::new(pc, timestamp), - Some(c.as_canonical_u32()), - ); - break; + if let Some(exit_code) = state.exit_code { + self.ctrl + .on_terminate(state, &mut self.chip_complex, exit_code); + break; + } + if self.should_suspend(state) { + self.ctrl.on_suspend(state, &mut self.chip_complex); + break; + } + + // Fetch, decode and execute single instruction + self.execute_instruction(state, &mut prev_backtrace)?; + } + Ok(()) + } + + /// Executes a single instruction and updates VM state + // TODO(ayush): clean this up, separate to smaller functions + fn execute_instruction( + &mut self, + state: &mut VmSegmentState, + prev_backtrace: &mut Option, + ) -> Result<(), ExecutionError> { + let pc = state.pc; + let timestamp = self.chip_complex.memory_controller().timestamp(); + + // Process an instruction and update VM state + let (instruction, debug_info) = self.chip_complex.base.program_chip.get_instruction(pc)?; + + tracing::trace!("pc: {pc:#x} | time: {timestamp} | {:?}", instruction); + + let &Instruction { opcode, c, .. } = instruction; + + // Handle termination instruction + if opcode == SystemOpcode::TERMINATE.global_opcode() { + state.exit_code = Some(c.as_canonical_u32()); + return Ok(()); + } + + // Extract debug info components + #[allow(unused_variables)] + let (dsl_instr, trace) = debug_info.as_ref().map_or( + (None, None), + |DebugInfo { + dsl_instruction, + trace, + }| (Some(dsl_instruction.clone()), trace.as_ref()), + ); + + // Handle phantom instructions + if opcode == SystemOpcode::PHANTOM.global_opcode() { + let discriminant = c.as_canonical_u32() as u16; + if let Some(phantom) = SysPhantom::from_repr(discriminant) { + tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}"); + + if phantom == SysPhantom::DebugPanic { + if let Some(mut backtrace) = prev_backtrace.take() { + backtrace.resolve(); + eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); + } else { + eprintln!("openvm program failure; no backtrace"); + } + return Err(ExecutionError::Fail { pc }); } - // Some phantom instruction handling is more convenient to do here than in - // PhantomChip. - if opcode == SystemOpcode::PHANTOM.global_opcode() { - // Note: the discriminant is the lower 16 bits of the c operand. - let discriminant = c.as_canonical_u32() as u16; - let phantom = SysPhantom::from_repr(discriminant); - tracing::trace!("pc: {pc:#x} | system phantom: {phantom:?}"); + #[cfg(feature = "bench-metrics")] + { + let dsl_str = dsl_instr.clone().unwrap_or_else(|| "Default".to_string()); match phantom { - Some(SysPhantom::DebugPanic) => { - if let Some(mut backtrace) = prev_backtrace { - backtrace.resolve(); - eprintln!("openvm program failure; backtrace:\n{:?}", backtrace); - } else { - eprintln!("openvm program failure; no backtrace"); - } - return Err(ExecutionError::Fail { pc }); - } - Some(SysPhantom::CtStart) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .start(dsl_instr.cloned().unwrap_or("Default".to_string())) - } - Some(SysPhantom::CtEnd) => - { - #[cfg(feature = "bench-metrics")] - metrics - .cycle_tracker - .end(dsl_instr.cloned().unwrap_or("Default".to_string())) - } + SysPhantom::CtStart => self.metrics.cycle_tracker.start(dsl_str), + SysPhantom::CtEnd => self.metrics.cycle_tracker.end(dsl_str), _ => {} } } - prev_backtrace = trace.cloned(); - - if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { - let next_state = InstructionExecutor::execute( - executor, - memory_controller, - instruction, - ExecutionState::new(pc, timestamp), - )?; - assert!(next_state.timestamp > timestamp); - pc = next_state.pc; - timestamp = next_state.timestamp; - } else { - return Err(ExecutionError::DisabledOperation { pc, opcode }); - }; - (opcode, dsl_instr.cloned()) - }; + } + } - #[cfg(feature = "bench-metrics")] + // TODO(ayush): move to vm state? + *prev_backtrace = trace.cloned(); + + // Execute the instruction using the control implementation + // TODO(AG): maybe avoid cloning the instruction? + self.ctrl + .execute_instruction(state, &instruction.clone(), &mut self.chip_complex)?; + + // Update metrics if enabled + #[cfg(feature = "bench-metrics")] + { self.update_instruction_metrics(pc, opcode, dsl_instr); + } - if self.should_segment() { - self.chip_complex - .connector_chip_mut() - .end(ExecutionState::new(pc, timestamp), None); - break; - } + Ok(()) + } + + /// Returns bool of whether to switch to next segment or not. + fn should_suspend(&mut self, state: &mut VmSegmentState) -> bool { + if !self.system_config().continuation_enabled { + return false; } - self.final_memory = Some( - self.chip_complex - .base - .memory_controller - .memory_image() - .clone(), - ); - Ok(ExecutionSegmentState { - pc, - is_terminated: did_terminate, - }) + // Check with the execution control policy + self.ctrl.should_suspend(state, &self.chip_complex) } + // TODO(ayush): this is not relevant for e1/e2 execution /// Generate ProofInput to prove the segment. Should be called after ::execute pub fn generate_proof_input( #[allow(unused_mut)] mut self, @@ -358,30 +238,28 @@ impl> ExecutionSegment { }) } - /// Returns bool of whether to switch to next segment or not. This is called every clock cycle - /// inside of Core trace generation. - fn should_segment(&mut self) -> bool { - if !self.system_config().continuation_enabled { - return false; - } - // Avoid checking segment too often. - if self.since_last_segment_check != SEGMENT_CHECK_INTERVAL { - self.since_last_segment_check += 1; - return false; + #[cfg(feature = "bench-metrics")] + #[allow(unused_variables)] + pub fn update_instruction_metrics( + &mut self, + pc: u32, + opcode: VmOpcode, + dsl_instr: Option, + ) { + self.metrics.cycle_count += 1; + + if self.system_config().profiling { + let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); + let opcode_name = executor.get_opcode_name(opcode.as_usize()); + self.metrics.update_trace_cells( + &self.air_names, + self.chip_complex.current_trace_cells(), + opcode_name, + dsl_instr, + ); + + #[cfg(feature = "function-span")] + self.metrics.update_current_fn(pc); } - self.since_last_segment_check = 0; - let segmentation_strategy = &self.system_config().segmentation_strategy; - segmentation_strategy.should_segment( - &self.air_names, - &self - .chip_complex - .dynamic_trace_heights() - .collect::>(), - &self.chip_complex.current_trace_cells(), - ) - } - - pub fn current_trace_cells(&self) -> Vec { - self.chip_complex.current_trace_cells() } } diff --git a/crates/vm/src/arch/segmentation_strategy.rs b/crates/vm/src/arch/segmentation_strategy.rs new file mode 100644 index 0000000000..000bc0f8e0 --- /dev/null +++ b/crates/vm/src/arch/segmentation_strategy.rs @@ -0,0 +1,109 @@ +use std::sync::Arc; + +pub const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; +// a heuristic number for the maximum number of cells per chip in a segment +// a few reasons for this number: +// 1. `VmAirWrapper` is +// the chip with the most cells in a segment from the reth-benchmark. +// 2. `VmAirWrapper`: +// its trace width is 36 and its after challenge trace width is 80. +pub const DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT: usize = DEFAULT_MAX_SEGMENT_LEN * 120; + +pub trait SegmentationStrategy: + std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe +{ + /// Whether the execution should segment based on the trace heights and cells. + /// + /// Air names are provided for debugging purposes. + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool; + + /// A strategy that segments more aggressively than the current one. + /// + /// Called when `should_segment` results in a segment that is infeasible. Execution will be + /// re-run with the stricter segmentation strategy. + fn stricter_strategy(&self) -> Arc; +} + +/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. +#[derive(Debug, Clone)] +pub struct DefaultSegmentationStrategy { + max_segment_len: usize, + max_cells_per_chip_in_segment: usize, +} + +impl Default for DefaultSegmentationStrategy { + fn default() -> Self { + Self { + max_segment_len: DEFAULT_MAX_SEGMENT_LEN, + max_cells_per_chip_in_segment: DEFAULT_MAX_CELLS_PER_CHIP_IN_SEGMENT, + } + } +} + +impl DefaultSegmentationStrategy { + pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { + Self { + max_segment_len, + max_cells_per_chip_in_segment: max_segment_len * 120, + } + } + + pub fn new(max_segment_len: usize, max_cells_per_chip_in_segment: usize) -> Self { + Self { + max_segment_len, + max_cells_per_chip_in_segment, + } + } + + pub fn max_segment_len(&self) -> usize { + self.max_segment_len + } +} + +const SEGMENTATION_BACKOFF_FACTOR: usize = 4; + +impl SegmentationStrategy for DefaultSegmentationStrategy { + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool { + for (i, &height) in trace_heights.iter().enumerate() { + if height > self.max_segment_len { + tracing::info!( + "Should segment because chip {} (name: {}) has height {}", + i, + air_names[i], + height + ); + return true; + } + } + for (i, &num_cells) in trace_cells.iter().enumerate() { + if num_cells > self.max_cells_per_chip_in_segment { + tracing::info!( + "Should segment because chip {} (name: {}) has {} cells", + i, + air_names[i], + num_cells + ); + return true; + } + } + false + } + + fn stricter_strategy(&self) -> Arc { + Arc::new(Self { + max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, + max_cells_per_chip_in_segment: self.max_cells_per_chip_in_segment + / SEGMENTATION_BACKOFF_FACTOR, + }) + } +} diff --git a/crates/vm/src/arch/testing/memory/air.rs b/crates/vm/src/arch/testing/memory/air.rs index 8a394c0cce..90c1b4ce49 100644 --- a/crates/vm/src/arch/testing/memory/air.rs +++ b/crates/vm/src/arch/testing/memory/air.rs @@ -1,46 +1,155 @@ -use std::{borrow::Borrow, mem::size_of}; +use std::{mem::size_of, sync::Arc}; -use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, BaseAir}, - p3_matrix::Matrix, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + prover::types::AirProofInput, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + AirRef, Chip, ChipUsageGetter, }; use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress}; -#[derive(Clone, Copy, Debug, AlignedBorrow, derive_new::new)] #[repr(C)] -pub struct DummyMemoryInteractionCols { - pub address: MemoryAddress, - pub data: [T; BLOCK_SIZE], - pub timestamp: T, +#[derive(Clone, Copy)] +pub struct DummyMemoryInteractionColsRef<'a, T> { + pub address: MemoryAddress<&'a T, &'a T>, + pub data: &'a [T], + pub timestamp: &'a T, /// The send frequency. Send corresponds to write. To read, set to negative. - pub count: T, + pub count: &'a T, +} + +#[repr(C)] +pub struct DummyMemoryInteractionColsMut<'a, T> { + pub address: MemoryAddress<&'a mut T, &'a mut T>, + pub data: &'a mut [T], + pub timestamp: &'a mut T, + /// The send frequency. Send corresponds to write. To read, set to negative. + pub count: &'a mut T, +} + +impl<'a, T> DummyMemoryInteractionColsRef<'a, T> { + pub fn from_slice(slice: &'a [T]) -> Self { + let (address, slice) = slice.split_at(size_of::>()); + let (count, slice) = slice.split_last().unwrap(); + let (timestamp, data) = slice.split_last().unwrap(); + Self { + address: MemoryAddress::new(&address[0], &address[1]), + data, + timestamp, + count, + } + } +} + +impl<'a, T> DummyMemoryInteractionColsMut<'a, T> { + pub fn from_mut_slice(slice: &'a mut [T]) -> Self { + let (addr_space, slice) = slice.split_first_mut().unwrap(); + let (ptr, slice) = slice.split_first_mut().unwrap(); + let (count, slice) = slice.split_last_mut().unwrap(); + let (timestamp, data) = slice.split_last_mut().unwrap(); + Self { + address: MemoryAddress::new(addr_space, ptr), + data, + timestamp, + count, + } + } } #[derive(Clone, Copy, Debug, derive_new::new)] -pub struct MemoryDummyAir { +pub struct MemoryDummyAir { pub bus: MemoryBus, + pub block_size: usize, } -impl BaseAirWithPublicValues for MemoryDummyAir {} -impl PartitionedBaseAir for MemoryDummyAir {} -impl BaseAir for MemoryDummyAir { +impl BaseAirWithPublicValues for MemoryDummyAir {} +impl PartitionedBaseAir for MemoryDummyAir {} +impl BaseAir for MemoryDummyAir { fn width(&self) -> usize { - size_of::>() + self.block_size + 4 } } -impl Air for MemoryDummyAir { +impl Air for MemoryDummyAir { fn eval(&self, builder: &mut AB) { let main = builder.main(); let local = main.row_slice(0); - let local: &DummyMemoryInteractionCols = (*local).borrow(); + let local = DummyMemoryInteractionColsRef::from_slice(&local); self.bus - .send(local.address, local.data.to_vec(), local.timestamp) - .eval(builder, local.count); + .send( + MemoryAddress::new(*local.address.address_space, *local.address.pointer), + local.data.to_vec(), + *local.timestamp, + ) + .eval(builder, *local.count); + } +} + +#[derive(Clone)] +pub struct MemoryDummyChip { + pub air: MemoryDummyAir, + pub trace: Vec, +} + +impl MemoryDummyChip { + pub fn new(air: MemoryDummyAir) -> Self { + Self { + air, + trace: Vec::new(), + } + } +} + +impl MemoryDummyChip { + pub fn send(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::ONE); + } + + pub fn receive(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32) { + self.push(addr_space, ptr, data, timestamp, F::NEG_ONE); + } + + pub fn push(&mut self, addr_space: u32, ptr: u32, data: &[F], timestamp: u32, count: F) { + assert_eq!(data.len(), self.air.block_size); + self.trace.push(F::from_canonical_u32(addr_space)); + self.trace.push(F::from_canonical_u32(ptr)); + self.trace.extend_from_slice(data); + self.trace.push(F::from_canonical_u32(timestamp)); + self.trace.push(count); + } +} + +impl Chip for MemoryDummyChip> +where + Val: PrimeField32, +{ + fn air(&self) -> AirRef { + Arc::new(self.air) + } + + fn generate_air_proof_input(mut self) -> AirProofInput { + let height = self.current_trace_height().next_power_of_two(); + let width = self.trace_width(); + self.trace.resize(height * width, Val::::ZERO); + + AirProofInput::simple_no_pis(RowMajorMatrix::new(self.trace, width)) + } +} + +impl ChipUsageGetter for MemoryDummyChip { + fn air_name(&self) -> String { + format!("MemoryDummyAir<{}>", self.air.block_size) + } + fn current_trace_height(&self) -> usize { + self.trace.len() / self.trace_width() + } + fn trace_width(&self) -> usize { + BaseAir::::width(&self.air) } } diff --git a/crates/vm/src/arch/testing/memory/mod.rs b/crates/vm/src/arch/testing/memory/mod.rs index ae1136bc7f..247ca3970d 100644 --- a/crates/vm/src/arch/testing/memory/mod.rs +++ b/crates/vm/src/arch/testing/memory/mod.rs @@ -1,140 +1,105 @@ -use std::{array::from_fn, borrow::BorrowMut as _, cell::RefCell, mem::size_of, rc::Rc, sync::Arc}; +use std::collections::HashMap; -use air::{DummyMemoryInteractionCols, MemoryDummyAir}; +use air::{MemoryDummyAir, MemoryDummyChip}; use openvm_circuit::system::memory::MemoryController; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - prover::types::AirProofInput, - AirRef, Chip, ChipUsageGetter, -}; -use rand::{seq::SliceRandom, Rng}; - -use crate::system::memory::{offline_checker::MemoryBus, MemoryAddress, RecordId}; +use openvm_stark_backend::p3_field::PrimeField32; +use rand::Rng; pub mod air; -const WORD_SIZE: usize = 1; - /// A dummy testing chip that will add unconstrained messages into the [MemoryBus]. /// Stores a log of raw messages to send/receive to the [MemoryBus]. /// /// It will create a [air::MemoryDummyAir] to add messages to MemoryBus. pub struct MemoryTester { - pub bus: MemoryBus, - pub controller: Rc>>, - /// Log of record ids - pub records: Vec, + /// Map from `block_size` to [MemoryDummyChip] of that block size + pub chip_for_block: HashMap>, + // TODO: make this just TracedMemory? + pub controller: MemoryController, } impl MemoryTester { - pub fn new(controller: Rc>>) -> Self { - let bus = controller.borrow().memory_bus; + pub fn new(controller: MemoryController) -> Self { + let bus = controller.memory_bus; + let mut chip_for_block = HashMap::new(); + for log_block_size in 0..6 { + let block_size = 1 << log_block_size; + let chip = MemoryDummyChip::new(MemoryDummyAir::new(bus, block_size)); + chip_for_block.insert(block_size, chip); + } Self { - bus, + chip_for_block, controller, - records: Vec::new(), - } - } - - /// Returns the cell value at the current timestamp according to `MemoryController`. - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - // core::BorrowMut confuses compiler - let (record_id, value) = - RefCell::borrow_mut(&self.controller).read_cell(addr_space, pointer); - self.records.push(record_id); - value - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - let [addr_space, pointer] = [address_space, pointer].map(F::from_canonical_usize); - let (record_id, _) = - RefCell::borrow_mut(&self.controller).write_cell(addr_space, pointer, value); - self.records.push(record_id); - } - - pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { - from_fn(|i| self.read_cell(address_space, pointer + i)) - } - - pub fn write( - &mut self, - address_space: usize, - mut pointer: usize, - cells: [F; N], - ) { - for cell in cells { - self.write_cell(address_space, pointer, cell); - pointer += 1; } } -} -impl Chip for MemoryTester> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(MemoryDummyAir::::new(self.bus)) - } - - fn generate_air_proof_input(self) -> AirProofInput { - let offline_memory = self.controller.borrow().offline_memory(); - let offline_memory = offline_memory.lock().unwrap(); - - let height = self.records.len().next_power_of_two(); - let width = self.trace_width(); - let mut values = Val::::zero_vec(2 * height * width); - // This zip only goes through records. The padding rows between records.len()..height - // are filled with zeros - in particular count = 0 so nothing is added to bus. - for (row, id) in values.chunks_mut(2 * width).zip(self.records) { - let (first, second) = row.split_at_mut(width); - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = first.borrow_mut(); - let record = offline_memory.record_by_id(id); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, - }; - row.data - .copy_from_slice(record.prev_data_slice().unwrap_or(record.data_slice())); - row.timestamp = Val::::from_canonical_u32(record.prev_timestamp); - row.count = -Val::::ONE; - - let row: &mut DummyMemoryInteractionCols, WORD_SIZE> = second.borrow_mut(); - row.address = MemoryAddress { - address_space: record.address_space, - pointer: record.pointer, + // TODO: change interface by implementing GuestMemory trait after everything works + pub fn read(&mut self, addr_space: usize, ptr: usize) -> [F; N] { + let controller = &mut self.controller; + let t = controller.memory.timestamp(); + // TODO: hack + let (t_prev, data) = if addr_space <= 3 { + let (t_prev, data) = unsafe { + controller + .memory + .read::(addr_space as u32, ptr as u32) }; - row.data.copy_from_slice(record.data_slice()); - row.timestamp = Val::::from_canonical_u32(record.timestamp); - row.count = Val::::ONE; - } - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) + (t_prev, data.map(F::from_canonical_u8)) + } else { + unsafe { + controller + .memory + .read::(addr_space as u32, ptr as u32) + } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); + + data } -} -impl ChipUsageGetter for MemoryTester { - fn air_name(&self) -> String { - "MemoryDummyAir".to_string() - } - fn current_trace_height(&self) -> usize { - self.records.len() - } - - fn trace_width(&self) -> usize { - size_of::>() + // TODO: see read + pub fn write(&mut self, addr_space: usize, ptr: usize, data: [F; N]) { + let controller = &mut self.controller; + let t = controller.memory.timestamp(); + // TODO: hack + let (t_prev, data_prev) = if addr_space <= 3 { + let (t_prev, data_prev) = unsafe { + controller.memory.write::( + addr_space as u32, + ptr as u32, + &data.map(|x| x.as_canonical_u32() as u8), + ) + }; + (t_prev, data_prev.map(F::from_canonical_u8)) + } else { + unsafe { + controller + .memory + .write::(addr_space as u32, ptr as u32, &data) + } + }; + self.chip_for_block.get_mut(&N).unwrap().receive( + addr_space as u32, + ptr as u32, + &data_prev, + t_prev, + ); + self.chip_for_block + .get_mut(&N) + .unwrap() + .send(addr_space as u32, ptr as u32, &data, t); } } -pub fn gen_address_space(rng: &mut R) -> usize -where - R: Rng + ?Sized, -{ - *[1, 2].choose(rng).unwrap() -} - pub fn gen_pointer(rng: &mut R, len: usize) -> usize where R: Rng + ?Sized, diff --git a/crates/vm/src/arch/testing/mod.rs b/crates/vm/src/arch/testing/mod.rs index 44b19177be..fd0c73dc1a 100644 --- a/crates/vm/src/arch/testing/mod.rs +++ b/crates/vm/src/arch/testing/mod.rs @@ -1,18 +1,14 @@ -use std::{ - cell::RefCell, - iter::zip, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{borrow::Borrow, iter::zip}; use openvm_circuit_primitives::var_range::{ SharedVariableRangeCheckerChip, VariableRangeCheckerBus, }; use openvm_instructions::instruction::Instruction; +use openvm_poseidon2_air::Poseidon2Config; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::VerificationData, - interaction::BusIndex, + interaction::{BusIndex, PermutationCheckBus}, p3_field::PrimeField32, p3_matrix::dense::{DenseMatrix, RowMajorMatrix}, prover::types::AirProofInput, @@ -32,13 +28,14 @@ use program::ProgramTester; use rand::{rngs::StdRng, RngCore, SeedableRng}; use tracing::Level; -use super::{ExecutionBus, InstructionExecutor, SystemPort}; +use super::{ExecutionBridge, ExecutionBus, InstructionExecutor, SystemPort}; use crate::{ arch::{ExecutionState, MemoryConfig}, system::{ memory::{ + interface::MemoryInterface, offline_checker::{MemoryBridge, MemoryBus}, - MemoryController, OfflineMemory, + MemoryController, SharedMemoryHelper, }, poseidon2::Poseidon2PeripheryChip, program::ProgramBus, @@ -48,11 +45,9 @@ use crate::{ pub mod execution; pub mod memory; pub mod program; -pub mod test_adapter; pub use execution::ExecutionTester; pub use memory::MemoryTester; -pub use test_adapter::TestAdapterChip; pub const EXECUTION_BUS: BusIndex = 0; pub const MEMORY_BUS: BusIndex = 1; @@ -76,7 +71,7 @@ pub struct VmChipTestBuilder { impl VmChipTestBuilder { pub fn new( - memory_controller: Rc>>, + memory_controller: MemoryController, execution_bus: ExecutionBus, program_bus: ProgramBus, rng: StdRng, @@ -110,16 +105,12 @@ impl VmChipTestBuilder { ) { let initial_state = ExecutionState { pc: initial_pc, - timestamp: self.memory.controller.borrow().timestamp(), + timestamp: self.memory.controller.timestamp(), }; tracing::debug!(?initial_state.timestamp); let final_state = executor - .execute( - &mut *self.memory.controller.borrow_mut(), - instruction, - initial_state, - ) + .execute(&mut self.memory.controller, instruction, initial_state) .expect("Expected the execution not to fail"); self.program.execute(instruction, &initial_state); @@ -130,14 +121,6 @@ impl VmChipTestBuilder { self.rng.next_u32() % (1 << (F::bits() - 2)) } - pub fn read_cell(&mut self, address_space: usize, pointer: usize) -> F { - self.memory.read_cell(address_space, pointer) - } - - pub fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) { - self.memory.write_cell(address_space, pointer, value); - } - pub fn read(&mut self, address_space: usize, pointer: usize) -> [F; N] { self.memory.read(address_space, pointer) } @@ -162,9 +145,22 @@ impl VmChipTestBuilder { pointer: usize, writes: Vec<[F; NUM_LIMBS]>, ) { - self.write(1usize, register, [F::from_canonical_usize(pointer)]); - for (i, &write) in writes.iter().enumerate() { - self.write(2usize, pointer + i * NUM_LIMBS, write); + self.write( + 1usize, + register, + pointer.to_le_bytes().map(F::from_canonical_u8), + ); + if NUM_LIMBS.is_power_of_two() { + for (i, &write) in writes.iter().enumerate() { + self.write(2usize, pointer + i * NUM_LIMBS, write); + } + } else { + for (i, &write) in writes.iter().enumerate() { + let ptr = pointer + i * NUM_LIMBS; + for j in (0..NUM_LIMBS).step_by(4) { + self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap()); + } + } } } @@ -176,6 +172,10 @@ impl VmChipTestBuilder { } } + pub fn execution_bridge(&self) -> ExecutionBridge { + ExecutionBridge::new(self.execution.bus, self.program.bus) + } + pub fn execution_bus(&self) -> ExecutionBus { self.execution.bus } @@ -185,27 +185,27 @@ impl VmChipTestBuilder { } pub fn memory_bus(&self) -> MemoryBus { - self.memory.bus + self.memory.controller.memory_bus } - pub fn memory_controller(&self) -> Rc>> { - self.memory.controller.clone() + pub fn memory_controller(&self) -> &MemoryController { + &self.memory.controller } pub fn range_checker(&self) -> SharedVariableRangeCheckerChip { - self.memory.controller.borrow().range_checker.clone() + self.memory.controller.range_checker.clone() } pub fn memory_bridge(&self) -> MemoryBridge { - self.memory.controller.borrow().memory_bridge() + self.memory.controller.memory_bridge() } - pub fn address_bits(&self) -> usize { - self.memory.controller.borrow().mem_config.pointer_max_bits + pub fn memory_helper(&self) -> SharedMemoryHelper { + self.memory.controller.helper() } - pub fn offline_memory_mutex_arc(&self) -> Arc>> { - self.memory_controller().borrow().offline_memory().clone() + pub fn address_bits(&self) -> usize { + self.memory.controller.mem_config.pointer_max_bits } pub fn get_default_register(&mut self, increment: usize) -> usize { @@ -247,10 +247,6 @@ type TestSC = BabyBearBlake3Config; impl VmChipTestBuilder { pub fn build(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() @@ -259,10 +255,6 @@ impl VmChipTestBuilder { tester.load(self.program) } pub fn build_babybear_poseidon2(self) -> VmChipTester { - self.memory - .controller - .borrow_mut() - .finalize(None::<&mut Poseidon2PeripheryChip>); let tester = VmChipTester { memory: Some(self.memory), ..Default::default() @@ -272,8 +264,34 @@ impl VmChipTestBuilder { } } +impl VmChipTestBuilder { + pub fn default_persistent() -> Self { + let mem_config = MemoryConfig::default(); + let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( + RANGE_CHECKER_BUS, + mem_config.decomp, + )); + let memory_controller = MemoryController::with_persistent_memory( + MemoryBus::new(MEMORY_BUS), + mem_config, + range_checker, + PermutationCheckBus::new(MEMORY_MERKLE_BUS), + PermutationCheckBus::new(POSEIDON2_DIRECT_BUS), + ); + Self { + memory: MemoryTester::new(memory_controller), + execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), + program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), + rng: StdRng::seed_from_u64(0), + default_register: 0, + default_pointer: 0, + } + } +} + impl Default for VmChipTestBuilder { fn default() -> Self { + setup_tracing_with_log_level(Level::INFO); let mem_config = MemoryConfig::default(); let range_checker = SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new( RANGE_CHECKER_BUS, @@ -285,7 +303,7 @@ impl Default for VmChipTestBuilder { range_checker, ); Self { - memory: MemoryTester::new(Rc::new(RefCell::new(memory_controller))), + memory: MemoryTester::new(memory_controller), execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)), program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)), rng: StdRng::seed_from_u64(0), @@ -326,19 +344,47 @@ where pub fn finalize(mut self) -> Self { if let Some(memory_tester) = self.memory.take() { - let memory_controller = memory_tester.controller.clone(); - let range_checker = memory_controller.borrow().range_checker.clone(); - self = self.load(memory_tester); // dummy memory interactions - { - let airs = memory_controller.borrow().airs(); - let air_proof_inputs = Rc::try_unwrap(memory_controller) - .unwrap_or_else(|_| panic!("Memory controller was not dropped")) - .into_inner() - .generate_air_proof_inputs(); - self.air_proof_inputs.extend( - zip(airs, air_proof_inputs).filter(|(_, input)| input.main_trace_height() > 0), - ); - } + // Balance memory boundaries + let mut memory_controller = memory_tester.controller; + let range_checker = memory_controller.range_checker.clone(); + match &memory_controller.interface_chip { + MemoryInterface::Volatile { .. } => { + memory_controller.finalize(None::<&mut Poseidon2PeripheryChip>>); + // dummy memory interactions: + for mem_chip in memory_tester.chip_for_block.into_values() { + self = self.load(mem_chip); + } + { + let airs = memory_controller.borrow().airs(); + let air_proof_inputs = memory_controller.generate_air_proof_inputs(); + self.air_proof_inputs.extend( + zip(airs, air_proof_inputs) + .filter(|(_, input)| input.main_trace_height() > 0), + ); + } + } + MemoryInterface::Persistent { .. } => { + let mut poseidon_chip = Poseidon2PeripheryChip::new( + Poseidon2Config::default(), + POSEIDON2_DIRECT_BUS, + 3, + ); + memory_controller.finalize(Some(&mut poseidon_chip)); + // dummy memory interactions: + for mem_chip in memory_tester.chip_for_block.into_values() { + self = self.load(mem_chip); + } + { + let airs = memory_controller.borrow().airs(); + let air_proof_inputs = memory_controller.generate_air_proof_inputs(); + self.air_proof_inputs.extend( + zip(airs, air_proof_inputs) + .filter(|(_, input)| input.main_trace_height() > 0), + ); + } + self = self.load(poseidon_chip); + } + }; self = self.load(range_checker); // this must be last because other trace generation // mutates its state } diff --git a/crates/vm/src/arch/testing/test_adapter.rs b/crates/vm/src/arch/testing/test_adapter.rs deleted file mode 100644 index bca9eed724..0000000000 --- a/crates/vm/src/arch/testing/test_adapter.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - collections::VecDeque, - fmt::Debug, -}; - -use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; -use openvm_stark_backend::{ - interaction::InteractionBuilder, - p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, -}; -use serde::{Deserialize, Serialize}; - -use crate::{ - arch::{ - AdapterAirContext, AdapterRuntimeContext, DynAdapterInterface, DynArray, ExecutionBridge, - ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - }, - system::memory::{MemoryController, OfflineMemory}, -}; - -// Replaces A: VmAdapterChip while testing VmCoreChip functionality, as it has no -// constraints and thus cannot cause a failure. -pub struct TestAdapterChip { - /// List of the return values of `preprocess` this chip should provide on each sequential call. - pub prank_reads: VecDeque>, - /// List of `pc_inc` to use in `postprocess` on each sequential call. - /// Defaults to `4` if not provided. - pub prank_pc_inc: VecDeque>, - - pub air: TestAdapterAir, -} - -impl TestAdapterChip { - pub fn new( - prank_reads: Vec>, - prank_pc_inc: Vec>, - execution_bridge: ExecutionBridge, - ) -> Self { - Self { - prank_reads: prank_reads.into(), - prank_pc_inc: prank_pc_inc.into(), - air: TestAdapterAir { execution_bridge }, - } - } -} - -#[derive(Clone, Serialize, Deserialize)] -pub struct TestAdapterRecord { - pub from_pc: u32, - pub operands: [T; 7], -} - -impl VmAdapterChip for TestAdapterChip { - type ReadRecord = (); - type WriteRecord = TestAdapterRecord; - type Air = TestAdapterAir; - type Interface = DynAdapterInterface; - - fn preprocess( - &mut self, - _memory: &mut MemoryController, - _instruction: &Instruction, - ) -> Result<(DynArray, Self::ReadRecord)> { - Ok(( - self.prank_reads - .pop_front() - .expect("Not enough prank reads provided") - .into(), - (), - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - _output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let pc_inc = self - .prank_pc_inc - .pop_front() - .map(|x| x.unwrap_or(4)) - .unwrap_or(4); - Ok(( - ExecutionState { - pc: from_state.pc + pc_inc, - timestamp: memory.timestamp(), - }, - TestAdapterRecord { - operands: [ - instruction.a, - instruction.b, - instruction.c, - instruction.d, - instruction.e, - instruction.f, - instruction.g, - ], - from_pc: from_state.pc, - }, - )) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - _memory: &OfflineMemory, - ) { - let cols: &mut TestAdapterCols = row_slice.borrow_mut(); - cols.from_pc = F::from_canonical_u32(write_record.from_pc); - cols.operands = write_record.operands; - // row_slice[0] = F::from_canonical_u32(write_record.from_pc); - // row_slice[1..].copy_from_slice(&write_record.operands); - } - - fn air(&self) -> &Self::Air { - &self.air - } -} - -#[derive(Clone, Copy, Debug)] -pub struct TestAdapterAir { - pub execution_bridge: ExecutionBridge, -} - -#[repr(C)] -#[derive(AlignedBorrow)] -pub struct TestAdapterCols { - pub from_pc: T, - pub operands: [T; 7], -} - -impl BaseAir for TestAdapterAir { - fn width(&self) -> usize { - TestAdapterCols::::width() - } -} - -impl VmAdapterAir for TestAdapterAir { - type Interface = DynAdapterInterface; - - fn eval( - &self, - builder: &mut AB, - local: &[AB::Var], - ctx: AdapterAirContext, - ) { - let processed_instruction: MinimalInstruction = ctx.instruction.into(); - let cols: &TestAdapterCols = local.borrow(); - self.execution_bridge - .execute_and_increment_or_set_pc( - processed_instruction.opcode, - cols.operands.to_vec(), - ExecutionState { - pc: cols.from_pc.into(), - timestamp: AB::Expr::ONE, - }, - AB::Expr::ZERO, - (4, ctx.to_pc), - ) - .eval(builder, processed_instruction.is_valid); - } - - fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var { - let cols: &TestAdapterCols = local.borrow(); - cols.from_pc - } -} diff --git a/crates/vm/src/arch/vm.rs b/crates/vm/src/arch/vm.rs index a826fb4137..bcdf433acd 100644 --- a/crates/vm/src/arch/vm.rs +++ b/crates/vm/src/arch/vm.rs @@ -1,13 +1,14 @@ use std::{borrow::Borrow, collections::VecDeque, marker::PhantomData, mem, sync::Arc}; use openvm_circuit::system::program::trace::compute_exe_commit; -use openvm_instructions::exe::VmExe; +use openvm_instructions::{exe::VmExe, program::Program}; use openvm_stark_backend::{ config::{Com, Domain, StarkGenericConfig, Val}, engine::StarkEngine, keygen::types::{LinearConstraint, MultiStarkProvingKey, MultiStarkVerifyingKey}, p3_commit::PolynomialSpace, p3_field::{FieldAlgebra, PrimeField32}, + p3_util::log2_strict_usize, proof::Proof, prover::types::{CommittedTraceData, ProofInput}, utils::metrics_span, @@ -19,17 +20,27 @@ use thiserror::Error; use tracing::info_span; use super::{ - ExecutionError, VmComplexTraceHeights, VmConfig, CONNECTOR_AIR_ID, MERKLE_AIR_ID, - PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, + execution_mode::{metered::Segment, tracegen::TracegenExecutionControlWithSegmentation}, + ExecutionError, InsExecutorE1, VmChipComplex, VmComplexTraceHeights, VmConfig, + VmInventoryError, CONNECTOR_AIR_ID, MERKLE_AIR_ID, PROGRAM_AIR_ID, PROGRAM_CACHED_TRACE_INDEX, }; #[cfg(feature = "bench-metrics")] use crate::metrics::VmMetrics; use crate::{ - arch::{hasher::poseidon2::vm_poseidon2_hasher, segment::ExecutionSegment}, + arch::{ + execution_mode::{ + e1::E1ExecutionControl, + metered::{MeteredCtx, MeteredExecutionControl}, + tracegen::TracegenExecutionControl, + }, + hasher::poseidon2::vm_poseidon2_hasher, + VmSegmentExecutor, VmSegmentState, + }, system::{ connector::{VmConnectorPvs, DEFAULT_SUSPEND_EXIT_CODE}, memory::{ merkle::MemoryMerklePvs, + online::GuestMemory, paged_vec::AddressMap, tree::public_values::{UserPublicValuesProof, UserPublicValuesProofError}, MemoryImage, CHUNK, @@ -47,7 +58,6 @@ pub enum GenerationError { } /// VM memory state for continuations. -pub type VmMemoryState = MemoryImage; #[derive(Clone, Default, Debug)] pub struct Streams { @@ -95,32 +105,41 @@ pub enum ExitCode { pub struct VmExecutorResult { pub per_segment: Vec>, /// When VM is running on persistent mode, public values are stored in a special memory space. - pub final_memory: Option>>, + pub final_memory: Option, } -pub struct VmExecutorNextSegmentState { - pub memory: MemoryImage, - pub input: Streams, +pub struct VmState +where + F: PrimeField32, +{ + pub clk: u64, pub pc: u32, + pub memory: MemoryImage, + pub input: Streams, #[cfg(feature = "bench-metrics")] pub metrics: VmMetrics, } -impl VmExecutorNextSegmentState { - pub fn new(memory: MemoryImage, input: impl Into>, pc: u32) -> Self { +impl VmState { + pub fn new(clk: u64, pc: u32, memory: MemoryImage, input: impl Into>) -> Self { Self { + clk, + pc, memory, input: input.into(), - pc, #[cfg(feature = "bench-metrics")] metrics: VmMetrics::default(), } } } -pub struct VmExecutorOneSegmentResult> { - pub segment: ExecutionSegment, - pub next_state: Option>, +pub struct VmExecutorOneSegmentResult +where + F: PrimeField32, + VC: VmConfig, +{ + pub segment: VmSegmentExecutor, + pub next_state: Option>, } impl VmExecutor @@ -164,29 +183,32 @@ where &self, exe: impl Into>, input: impl Into>, - mut f: impl FnMut(usize, ExecutionSegment) -> Result, + mut f: impl FnMut( + usize, + VmSegmentExecutor, + ) -> Result, map_err: impl Fn(ExecutionError) -> E, ) -> Result, E> { let mem_config = self.config.system().memory_config; let exe = exe.into(); - let mut segment_results = vec![]; - let memory = AddressMap::from_iter( + let memory = AddressMap::from_sparse( mem_config.as_offset, 1 << mem_config.as_height, 1 << mem_config.pointer_max_bits, exe.init_memory.clone(), ); + let pc = exe.pc_start; - let mut state = VmExecutorNextSegmentState::new(memory, input, pc); + let mut state = VmState::new(0, pc, memory, input); #[cfg(feature = "bench-metrics")] { state.metrics.fn_bounds = exe.fn_bounds.clone(); } - let mut segment_idx = 0; - + let mut segment_results = vec![]; loop { + let segment_idx = segment_results.len(); let _span = info_span!("execute_segment", segment = segment_idx).entered(); let one_segment_result = self .execute_until_segment(exe.clone(), state) @@ -196,7 +218,6 @@ where break; } state = one_segment_result.next_state.unwrap(); - segment_idx += 1; } tracing::debug!("Number of continuation segments: {}", segment_results.len()); #[cfg(feature = "bench-metrics")] @@ -209,7 +230,10 @@ where &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { + ) -> Result< + Vec>, + ExecutionError, + > { self.execute_and_then(exe, input, |_, seg| Ok(seg), |err| err) } @@ -220,17 +244,25 @@ where pub fn execute_until_segment( &self, exe: impl Into>, - from_state: VmExecutorNextSegmentState, + from_state: VmState, ) -> Result, ExecutionError> { let exe = exe.into(); - let mut segment = ExecutionSegment::new( + + let chip_complex = create_and_initialize_chip_complex( &self.config, exe.program.clone(), from_state.input, Some(from_state.memory), + ) + .unwrap(); + let ctrl = TracegenExecutionControlWithSegmentation::new(chip_complex.air_names()); + let mut segment = VmSegmentExecutor::new( + chip_complex, self.trace_height_constraints.clone(), exe.fn_bounds.clone(), + ctrl, ); + #[cfg(feature = "bench-metrics")] { segment.metrics = from_state.metrics; @@ -238,9 +270,13 @@ where if let Some(overridden_heights) = self.overridden_heights.as_ref() { segment.set_override_trace_heights(overridden_heights.clone()); } - let state = metrics_span("execute_time_ms", || segment.execute_from_pc(from_state.pc))?; - if state.is_terminated { + let mut exec_state = VmSegmentState::new(from_state.clk, from_state.pc, None, ()); + metrics_span("execute_time_ms", || { + segment.execute_from_state(&mut exec_state) + })?; + + if exec_state.exit_code.is_some() { return Ok(VmExecutorOneSegmentResult { segment, next_state: None, @@ -252,22 +288,24 @@ where "multiple segments require to enable continuations" ); assert_eq!( - state.pc, + exec_state.pc, segment.chip_complex.connector_chip().boundary_states[1] .unwrap() .pc ); - let final_memory = mem::take(&mut segment.final_memory) - .expect("final memory should be set in continuations segment"); let streams = segment.chip_complex.take_streams(); #[cfg(feature = "bench-metrics")] let metrics = segment.metrics.partial_take(); + + // TODO(ayush): this can probably be avoided + let memory = segment.ctrl.final_memory.as_ref().unwrap().clone(); Ok(VmExecutorOneSegmentResult { segment, - next_state: Some(VmExecutorNextSegmentState { - memory: final_memory, + next_state: Some(VmState { + clk: exec_state.clk, + pc: exec_state.pc, + memory, input: streams, - pc: state.pc, #[cfg(feature = "bench-metrics")] metrics, }), @@ -278,7 +316,7 @@ where &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { + ) -> Result, ExecutionError> { let mut last = None; self.execute_and_then( exe, @@ -290,7 +328,7 @@ where |err| err, )?; let last = last.expect("at least one segment must be executed"); - let final_memory = last.final_memory; + let final_memory = last.ctrl.final_memory; let end_state = last.chip_complex.connector_chip().boundary_states[1].expect("end state must be set"); if end_state.is_terminate != 1 { @@ -302,6 +340,176 @@ where Ok(final_memory) } + pub fn execute_e1( + &self, + exe: impl Into>, + input: impl Into>, + num_cycles: Option, + ) -> Result, ExecutionError> + where + VC::Executor: InsExecutorE1, + { + let mem_config = self.config.system().memory_config; + let exe = exe.into(); + let memory = AddressMap::from_sparse( + mem_config.as_offset, + 1 << mem_config.as_height, + 1 << mem_config.pointer_max_bits, + exe.init_memory.clone(), + ); + + let state = VmState::new(0, exe.pc_start, memory, input); + + let _span = info_span!("execute_e1_until_cycle").entered(); + + let chip_complex = create_and_initialize_chip_complex( + &self.config, + exe.program.clone(), + state.input, + None, + ) + .unwrap(); + let mut segment = VmSegmentExecutor::::new( + chip_complex, + self.trace_height_constraints.clone(), + exe.fn_bounds.clone(), + E1ExecutionControl::new(num_cycles), + ); + #[cfg(feature = "bench-metrics")] + { + segment.metrics = state.metrics; + } + + let mut exec_state = VmSegmentState::new( + state.clk, + state.pc, + Some(GuestMemory::new(state.memory)), + (), + ); + metrics_span("execute_time_ms", || { + segment.execute_from_state(&mut exec_state) + })?; + + if let Some(end_cycle) = num_cycles { + assert_eq!(exec_state.clk, end_cycle); + } else { + match exec_state.exit_code { + Some(code) => { + if code != ExitCode::Success as u32 { + return Err(ExecutionError::FailedWithExitCode(code)); + } + } + None => return Err(ExecutionError::DidNotTerminate), + }; + } + + let state = VmState { + clk: exec_state.clk, + pc: exec_state.pc, + memory: exec_state.memory.unwrap().memory, + input: segment.chip_complex.take_streams(), + #[cfg(feature = "bench-metrics")] + metrics: segment.metrics.partial_take(), + }; + + Ok(state) + } + + pub fn execute_metered( + &self, + exe: impl Into>, + input: impl Into>, + widths: Vec, + interactions: Vec, + ) -> Result, ExecutionError> + where + VC::Executor: InsExecutorE1, + { + let mem_config = self.config.system().memory_config; + let exe = exe.into(); + + let memory = AddressMap::from_sparse( + mem_config.as_offset, + 1 << mem_config.as_height, + 1 << mem_config.pointer_max_bits, + exe.init_memory.clone(), + ); + let state = VmState::new(0, exe.pc_start, memory, input); + + let _span = info_span!("execute_metered").entered(); + + let chip_complex = create_and_initialize_chip_complex( + &self.config, + exe.program.clone(), + state.input, + None, + ) + .unwrap(); + let air_names = chip_complex.air_names(); + let ctrl = MeteredExecutionControl::new(&air_names, &widths, &interactions); + let mut executor = VmSegmentExecutor::::new( + chip_complex, + self.trace_height_constraints.clone(), + exe.fn_bounds.clone(), + ctrl, + ); + + #[cfg(feature = "bench-metrics")] + { + executor.metrics = state.metrics; + } + + let continuations_enabled = executor + .chip_complex + .memory_controller() + .continuation_enabled(); + let num_access_adapters = executor + .chip_complex + .memory_controller() + .access_adapters + .num_access_adapters(); + let ctx = MeteredCtx::new( + widths.len(), + continuations_enabled, + num_access_adapters as u8, + executor + .chip_complex + .memory_controller() + .memory + .min_block_size + .iter() + .map(|&x| log2_strict_usize(x as usize) as u8) + .collect(), + executor + .chip_complex + .memory_controller() + .mem_config() + .memory_dimensions(), + ); + + let mut exec_state = VmSegmentState::new( + state.clk, + state.pc, + Some(GuestMemory::new(state.memory)), + ctx, + ); + metrics_span("execute_time_ms", || { + executor.execute_from_state(&mut exec_state) + })?; + + // Check exit code + match exec_state.exit_code { + Some(code) => { + if code != ExitCode::Success as u32 { + return Err(ExecutionError::FailedWithExitCode(code)); + } + } + None => return Err(ExecutionError::DidNotTerminate), + }; + + Ok(executor.ctrl.segments) + } + pub fn execute_and_generate( &self, exe: impl Into>, @@ -315,6 +523,63 @@ where self.execute_and_generate_impl(exe.into(), None, input) } + pub fn execute_and_generate_segment( + &self, + exe: impl Into>, + state: VmState, + num_cycles: u64, + ) -> Result, GenerationError> + where + Domain: PolynomialSpace, + VC::Executor: Chip, + VC::Periphery: Chip, + { + let _span = info_span!("execute_and_generate_segment").entered(); + + let exe = exe.into(); + let chip_complex = create_and_initialize_chip_complex( + &self.config, + exe.program.clone(), + state.input, + Some(state.memory), + ) + .unwrap(); + let ctrl = TracegenExecutionControl::new(state.clk + num_cycles); + let mut segment = VmSegmentExecutor::<_, VC, _>::new( + chip_complex, + self.trace_height_constraints.clone(), + exe.fn_bounds.clone(), + ctrl, + ); + + // TODO(ayush): do i need this? + if let Some(overridden_heights) = self.overridden_heights.as_ref() { + segment.set_override_trace_heights(overridden_heights.clone()); + } + + let mut exec_state = VmSegmentState::new(state.clk, state.pc, None, ()); + metrics_span("execute_from_state", || { + segment.execute_from_state(&mut exec_state) + })?; + + assert_eq!( + exec_state.pc, + segment.chip_complex.connector_chip().boundary_states[1] + .unwrap() + .pc + ); + + // TODO(ayush): avoid cloning + let final_memory = segment.ctrl.final_memory.clone(); + let proof_input = tracing::info_span!("generate_proof_input") + .in_scope(|| segment.generate_proof_input(None))?; + + Ok(VmExecutorResult { + per_segment: vec![proof_input], + final_memory, + }) + } + pub fn execute_and_generate_with_cached_program( &self, committed_exe: Arc>, @@ -350,7 +615,7 @@ where |seg_idx, mut seg| { // Note: this will only be Some on the last segment; otherwise it is // already moved into next segment state - final_memory = mem::take(&mut seg.final_memory); + final_memory = mem::take(&mut seg.ctrl.final_memory); tracing::info_span!("trace_gen", segment = seg_idx) .in_scope(|| seg.generate_proof_input(committed_program.clone())) }, @@ -433,7 +698,7 @@ where let air_heights = segment.chip_complex.current_trace_heights(); let vm_heights = segment.chip_complex.get_internal_trace_heights(); let public_values = if let Some(pv_chip) = segment.chip_complex.public_values_chip() { - pv_chip.core.get_custom_public_values() + pv_chip.step.get_custom_public_values() } else { vec![] }; @@ -466,20 +731,31 @@ where &self, exe: VmExe, input: impl Into>, - ) -> Result, ExecutionError> { - let pc_start = exe.pc_start; - let mut segment = ExecutionSegment::new( + ) -> Result, ExecutionError> + { + let chip_complex = create_and_initialize_chip_complex( &self.config, exe.program.clone(), input.into(), None, + ) + .unwrap(); + let ctrl = TracegenExecutionControlWithSegmentation::new(chip_complex.air_names()); + let mut segment = VmSegmentExecutor::new( + chip_complex, self.trace_height_constraints.clone(), exe.fn_bounds.clone(), + ctrl, ); + if let Some(overridden_heights) = self.overridden_heights.as_ref() { segment.set_override_trace_heights(overridden_heights.clone()); } - metrics_span("execute_time_ms", || segment.execute_from_pc(pc_start))?; + + let mut exec_state = VmSegmentState::new(0, exe.pc_start, None, ()); + metrics_span("execute_time_ms", || { + segment.execute_from_state(&mut exec_state) + })?; Ok(segment) } } @@ -587,7 +863,7 @@ where &self, exe: impl Into>, input: impl Into>, - ) -> Result>, ExecutionError> { + ) -> Result, ExecutionError> { self.executor.execute(exe, input) } @@ -859,3 +1135,33 @@ where } } } + +/// Create and initialize a chip complex with program, streams, and optional memory +pub fn create_and_initialize_chip_complex( + config: &VC, + program: Program, + init_streams: Streams, + initial_memory: Option, +) -> Result, VmInventoryError> +where + F: PrimeField32, + VC: VmConfig, +{ + let mut chip_complex = config.create_chip_complex()?; + chip_complex.set_streams(init_streams); + + // Strip debug info if profiling is disabled + let program = if !config.system().profiling { + program.strip_debug_infos() + } else { + program + }; + + chip_complex.set_program(program); + + if let Some(initial_memory) = initial_memory { + chip_complex.set_initial_memory(initial_memory); + } + + Ok(chip_complex) +} diff --git a/crates/vm/src/metrics/mod.rs b/crates/vm/src/metrics/mod.rs index 916e8251ac..1e24cafc80 100644 --- a/crates/vm/src/metrics/mod.rs +++ b/crates/vm/src/metrics/mod.rs @@ -2,13 +2,7 @@ use std::{collections::BTreeMap, mem}; use cycle_tracker::CycleTracker; use metrics::counter; -use openvm_instructions::{ - exe::{FnBound, FnBounds}, - VmOpcode, -}; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::arch::{ExecutionSegment, InstructionExecutor, VmConfig}; +use openvm_instructions::exe::{FnBound, FnBounds}; pub mod cycle_tracker; @@ -30,39 +24,8 @@ pub struct VmMetrics { pub(crate) current_trace_cells: Vec, } -impl ExecutionSegment -where - F: PrimeField32, - VC: VmConfig, -{ - /// Update metrics that increment per instruction - #[allow(unused_variables)] - pub fn update_instruction_metrics( - &mut self, - pc: u32, - opcode: VmOpcode, - dsl_instr: Option, - ) { - self.metrics.cycle_count += 1; - - if self.system_config().profiling { - let executor = self.chip_complex.inventory.get_executor(opcode).unwrap(); - let opcode_name = executor.get_opcode_name(opcode.as_usize()); - self.metrics.update_trace_cells( - &self.air_names, - self.current_trace_cells(), - opcode_name, - dsl_instr, - ); - - #[cfg(feature = "function-span")] - self.metrics.update_current_fn(pc); - } - } -} - impl VmMetrics { - fn update_trace_cells( + pub fn update_trace_cells( &mut self, air_names: &[String], now_trace_cells: Vec, diff --git a/crates/vm/src/system/connector/mod.rs b/crates/vm/src/system/connector/mod.rs index dc9ff88ea2..fd97d7e55d 100644 --- a/crates/vm/src/system/connector/mod.rs +++ b/crates/vm/src/system/connector/mod.rs @@ -215,6 +215,7 @@ impl VmConnectorChip { pub fn begin(&mut self, state: ExecutionState) { self.boundary_states[0] = Some(ConnectorCols { pc: state.pc, + // TODO(ayush): should this be hardcoded to INITIAL_TIMESTAMP? timestamp: state.timestamp, is_terminate: 0, exit_code: 0, diff --git a/crates/vm/src/system/memory/adapter/mod.rs b/crates/vm/src/system/memory/adapter/mod.rs index 64e79a920b..ea76117066 100644 --- a/crates/vm/src/system/memory/adapter/mod.rs +++ b/crates/vm/src/system/memory/adapter/mod.rs @@ -1,11 +1,10 @@ -use std::{borrow::BorrowMut, cmp::max, sync::Arc}; +use std::{borrow::BorrowMut, io::Cursor, sync::Arc}; pub use air::*; pub use columns::*; use enum_dispatch::enum_dispatch; use openvm_circuit_primitives::{ - is_less_than::IsLtSubAir, utils::next_power_of_two_or_zero, - var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator, + is_less_than::IsLtSubAir, var_range::SharedVariableRangeCheckerChip, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_stark_backend::{ @@ -13,8 +12,7 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_commit::PolynomialSpace, p3_field::PrimeField32, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, + p3_matrix::{dense::RowMajorMatrix, Matrix}, p3_util::log2_strict_usize, prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, @@ -32,7 +30,7 @@ pub struct AccessAdapterInventory { air_names: Vec, } -impl AccessAdapterInventory { +impl AccessAdapterInventory { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -80,6 +78,14 @@ impl AccessAdapterInventory { } } + pub fn set_trace(&mut self, index: usize, trace: Vec, width: usize) + where + F: PrimeField32, + { + let trace = RowMajorMatrix::new(trace, width); + self.chips[index].set_trace(trace); + } + #[cfg(test)] pub fn records_for_n(&self, n: usize) -> &[AccessAdapterRecord] { let idx = log2_strict_usize(n) - 1; @@ -134,7 +140,10 @@ impl AccessAdapterInventory { memory_bus: MemoryBus, clk_max_bits: usize, max_access_adapter_n: usize, - ) -> Option> { + ) -> Option> + where + F: Clone + Send + Sync, + { if N <= max_access_adapter_n { Some(GenericAccessAdapterChip::new::( range_checker, @@ -145,6 +154,39 @@ impl AccessAdapterInventory { None } } + + pub(crate) fn execute_split( + &mut self, + address: MemoryAddress, + values: &[F], + timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32, + { + let index = get_chip_index(values.len()); + self.chips[index].execute_split(address, values, timestamp, row_slice); + } + + pub(crate) fn execute_merge( + &mut self, + address: MemoryAddress, + values: &[F], + left_timestamp: u32, + right_timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32, + { + let index = get_chip_index(values.len()); + self.chips[index].execute_merge( + address, + values, + left_timestamp, + right_timestamp, + row_slice, + ); + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -173,6 +215,28 @@ pub trait GenericAccessAdapterChipTrait { fn generate_trace(self) -> RowMajorMatrix where F: PrimeField32; + fn set_trace(&mut self, trace: RowMajorMatrix) + where + F: PrimeField32; + + fn execute_split( + &mut self, + address: MemoryAddress, + values: &[F], + timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32; + + fn execute_merge( + &mut self, + address: MemoryAddress, + values: &[F], + left_timestamp: u32, + right_timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32; } #[derive(Chip, ChipUsageGetter)] @@ -186,7 +250,7 @@ enum GenericAccessAdapterChip { N32(AccessAdapterChip), } -impl GenericAccessAdapterChip { +impl GenericAccessAdapterChip { fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -216,13 +280,16 @@ impl GenericAccessAdapterChip { } } } + pub struct AccessAdapterChip { air: AccessAdapterAir, range_checker: SharedVariableRangeCheckerChip, pub records: Vec>, + trace: RowMajorMatrix, overridden_height: Option, } -impl AccessAdapterChip { + +impl AccessAdapterChip { pub fn new( range_checker: SharedVariableRangeCheckerChip, memory_bus: MemoryBus, @@ -233,6 +300,7 @@ impl AccessAdapterChip { air: AccessAdapterAir:: { memory_bus, lt_air }, range_checker, records: vec![], + trace: RowMajorMatrix::new(Vec::new(), 0), overridden_height: None, } } @@ -251,48 +319,135 @@ impl GenericAccessAdapterChipTrait for AccessAdapterChip::width(&self.air); - let height = if let Some(oh) = self.overridden_height { - assert!( - oh >= self.records.len(), - "Overridden height is less than the required height" - ); - oh - } else { - self.records.len() - }; - let height = next_power_of_two_or_zero(height); - let mut values = F::zero_vec(height * width); - - values - .par_chunks_mut(width) - .zip(self.records.into_par_iter()) - .for_each(|(row, record)| { - let row: &mut AccessAdapterCols = row.borrow_mut(); - - row.is_valid = F::ONE; - row.values = record.data.try_into().unwrap(); - row.address = MemoryAddress::new(record.address_space, record.start_index); - - let (left_timestamp, right_timestamp) = match record.kind { - AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp), - AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - } => (left_timestamp, right_timestamp), - }; - debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp); - - row.left_timestamp = F::from_canonical_u32(left_timestamp); - row.right_timestamp = F::from_canonical_u32(right_timestamp); - row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split); - - self.air.lt_air.generate_subrow( - (self.range_checker.as_ref(), left_timestamp, right_timestamp), - (&mut row.lt_aux, &mut row.is_right_larger), - ); - }); - RowMajorMatrix::new(values, width) + let mut trace = self.trace; + let height = trace.height(); + trace.pad_to_height(height.next_power_of_two(), F::ZERO); + trace + // TODO(AG): everything related to the calculated trace height + // needs to be in memory controller, who owns these traces. + + // let width = BaseAir::::width(&self.air); + // let height = if let Some(oh) = self.overridden_height { + // assert!( + // oh >= self.records.len(), + // "Overridden height is less than the required height" + // ); + // oh + // } else { + // self.records.len() + // }; + // let height = next_power_of_two_or_zero(height); + // let mut values = F::zero_vec(height * width); + + // values + // .par_chunks_mut(width) + // .zip(self.records.into_par_iter()) + // .for_each(|(row, record)| { + // let row: &mut AccessAdapterCols = row.borrow_mut(); + + // row.is_valid = F::ONE; + // row.values = record.data.try_into().unwrap(); + // row.address = MemoryAddress::new(record.address_space, record.start_index); + + // let (left_timestamp, right_timestamp) = match record.kind { + // AccessAdapterRecordKind::Split => (record.timestamp, record.timestamp), + // AccessAdapterRecordKind::Merge { + // left_timestamp, + // right_timestamp, + // } => (left_timestamp, right_timestamp), + // }; + // debug_assert_eq!(max(left_timestamp, right_timestamp), record.timestamp); + + // row.left_timestamp = F::from_canonical_u32(left_timestamp); + // row.right_timestamp = F::from_canonical_u32(right_timestamp); + // row.is_split = F::from_bool(record.kind == AccessAdapterRecordKind::Split); + + // self.air.lt_air.generate_subrow( + // (self.range_checker.as_ref(), left_timestamp, right_timestamp), + // (&mut row.lt_aux, &mut row.is_right_larger), + // ); + // }); + // RowMajorMatrix::new(values, width) + } + + fn set_trace(&mut self, trace: RowMajorMatrix) { + self.trace = trace; + } + + fn execute_split( + &mut self, + address: MemoryAddress, + values: &[F], + timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32, + { + let row: &mut AccessAdapterCols = row_slice.borrow_mut(); + row.is_valid = F::ONE; + row.is_split = F::ONE; + row.address = MemoryAddress::new( + F::from_canonical_u32(address.address_space), + F::from_canonical_u32(address.pointer), + ); + row.left_timestamp = F::from_canonical_u32(timestamp); + row.right_timestamp = F::from_canonical_u32(timestamp); + row.is_right_larger = F::ZERO; + debug_assert_eq!( + values.len(), + N, + "Input values slice length must match the access adapter type" + ); + // TODO: move this to `fill_trace_row` + self.air.lt_air.generate_subrow( + (self.range_checker.as_ref(), timestamp, timestamp), + (&mut row.lt_aux, &mut row.is_right_larger), + ); + + // SAFETY: `values` slice is asserted to have length N. `row.values` is an array of length + // N. Pointers are valid and regions do not overlap because exactly one of them is a + // part of the trace. + unsafe { + std::ptr::copy_nonoverlapping(values.as_ptr(), row.values.as_mut_ptr(), N); + } + } + + fn execute_merge( + &mut self, + address: MemoryAddress, + values: &[F], + left_timestamp: u32, + right_timestamp: u32, + row_slice: &mut [F], + ) where + F: PrimeField32, + { + let row: &mut AccessAdapterCols = row_slice.borrow_mut(); + row.is_valid = F::ONE; + row.is_split = F::ZERO; + row.address = MemoryAddress::new( + F::from_canonical_u32(address.address_space), + F::from_canonical_u32(address.pointer), + ); + row.left_timestamp = F::from_canonical_u32(left_timestamp); + row.right_timestamp = F::from_canonical_u32(right_timestamp); + debug_assert_eq!( + values.len(), + N, + "Input values slice length must match the access adapter type" + ); + // TODO: move this to `fill_trace_row` + self.air.lt_air.generate_subrow( + (self.range_checker.as_ref(), left_timestamp, right_timestamp), + (&mut row.lt_aux, &mut row.is_right_larger), + ); + + // SAFETY: `values` slice is asserted to have length N. `row.values` is an array of length + // N. Pointers are valid and regions do not overlap because exactly one of them is a + // part of the trace. + unsafe { + std::ptr::copy_nonoverlapping(values.as_ptr(), row.values.as_mut_ptr(), N); + } } } @@ -328,3 +483,51 @@ impl ChipUsageGetter for AccessAdapterChip { fn air_name(n: usize) -> String { format!("AccessAdapter<{}>", n) } + +#[inline(always)] +pub fn get_chip_index(block_size: usize) -> usize { + assert!( + block_size.is_power_of_two() && block_size >= 2, + "Invalid block size {} for split operation", + block_size + ); + let index = block_size.trailing_zeros() - 1; + index as usize +} + +pub struct AdapterInventoryTraceCursor { + // [AG] TODO: replace with a pre-allocated space + cursors: Vec>>, + widths: Vec, +} + +impl AdapterInventoryTraceCursor { + pub fn new(as_cnt: usize) -> Self { + let cursors = vec![Cursor::new(Vec::new()); as_cnt]; + let widths = vec![ + size_of::>(), + size_of::>(), + size_of::>(), + size_of::>(), + size_of::>(), + ]; + Self { cursors, widths } + } + + pub fn get_row_slice(&mut self, block_size: usize) -> &mut [F] { + let index = get_chip_index(block_size); + let begin = self.cursors[index].position() as usize; + let end = begin + self.widths[index]; + self.cursors[index].get_mut().resize(end, F::ZERO); + self.cursors[index].set_position(end as u64); + &mut self.cursors[index].get_mut()[begin..end] + } + + pub fn extract_trace(&mut self, index: usize) -> Vec { + std::mem::replace(&mut self.cursors[index], Cursor::new(Vec::new())).into_inner() + } + + pub fn width(&self, index: usize) -> usize { + self.widths[index] + } +} diff --git a/crates/vm/src/system/memory/controller/dimensions.rs b/crates/vm/src/system/memory/controller/dimensions.rs index 1082d3adf0..87976d2fa0 100644 --- a/crates/vm/src/system/memory/controller/dimensions.rs +++ b/crates/vm/src/system/memory/controller/dimensions.rs @@ -30,6 +30,15 @@ impl MemoryDimensions { debug_assert!(block_id < (1 << self.address_height)); (((addr_space - self.as_offset) as u64) << self.address_height) + block_id as u64 } + + /// Convert an index in the memory merkle tree to an address label (address space, block id). + /// + /// This function performs the inverse operation of `label_to_index`. + pub fn index_to_label(&self, index: u64) -> (u32, u32) { + let block_id = (index & ((1 << self.address_height) - 1)) as u32; + let addr_space = (index >> self.address_height) as u32 + self.as_offset; + (addr_space, block_id) + } } impl MemoryConfig { diff --git a/crates/vm/src/system/memory/controller/interface.rs b/crates/vm/src/system/memory/controller/interface.rs index b51e960a32..b00171a3c2 100644 --- a/crates/vm/src/system/memory/controller/interface.rs +++ b/crates/vm/src/system/memory/controller/interface.rs @@ -13,7 +13,7 @@ pub enum MemoryInterface { Persistent { boundary_chip: PersistentBoundaryChip, merkle_chip: MemoryMerkleChip, - initial_memory: MemoryImage, + initial_memory: MemoryImage, }, } diff --git a/crates/vm/src/system/memory/controller/mod.rs b/crates/vm/src/system/memory/controller/mod.rs index 680a03ab8e..6bc186e0b8 100644 --- a/crates/vm/src/system/memory/controller/mod.rs +++ b/crates/vm/src/system/memory/controller/mod.rs @@ -1,18 +1,12 @@ -use std::{ - array, - collections::BTreeMap, - iter, - marker::PhantomData, - mem, - sync::{Arc, Mutex}, -}; +use std::{collections::BTreeMap, iter, marker::PhantomData}; use getset::{Getters, MutGetters}; use openvm_circuit_primitives::{ assert_less_than::{AssertLtSubAir, LessThanAuxCols}, - is_zero::IsZeroSubAir, utils::next_power_of_two_or_zero, - var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, + }, TraceSubRowGenerator, }; use openvm_stark_backend::{ @@ -29,8 +23,10 @@ use serde::{Deserialize, Serialize}; use self::interface::MemoryInterface; use super::{ + online::INITIAL_TIMESTAMP, paged_vec::{AddressMap, PAGE_SIZE}, volatile::VolatileBoundaryChip, + MemoryAddress, }; use crate::{ arch::{hasher::HasherChip, MemoryConfig}, @@ -38,14 +34,9 @@ use crate::{ adapter::AccessAdapterInventory, dimensions::MemoryDimensions, merkle::{MemoryMerkleChip, SerialReceiver}, - offline::{MemoryRecord, OfflineMemory, INITIAL_TIMESTAMP}, - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryBus, MemoryReadAuxCols, - MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols, AUX_LEN, - }, - online::{Memory, MemoryLogEntry}, + offline_checker::{MemoryBaseAuxCols, MemoryBridge, MemoryBus, AUX_LEN}, + online::{AccessMetadata, TracingMemory}, persistent::PersistentBoundaryChip, - tree::MemoryNode, }, }; @@ -53,6 +44,8 @@ pub mod dimensions; pub mod interface; pub const CHUNK: usize = 8; +pub const CHUNK_BITS: usize = CHUNK.ilog2() as usize; + /// The offset of the Merkle AIR in AIRs of MemoryController. pub const MERKLE_AIR_OFFSET: usize = 1; /// The offset of the boundary AIR in AIRs of MemoryController. @@ -62,7 +55,7 @@ pub const BOUNDARY_AIR_OFFSET: usize = 0; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] pub struct RecordId(pub usize); -pub type MemoryImage = AddressMap; +pub type MemoryImage = AddressMap; #[repr(C)] #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -98,29 +91,8 @@ pub struct MemoryController { // Store separately to avoid smart pointer reference each time range_checker_bus: VariableRangeCheckerBus, // addr_space -> Memory data structure - memory: Memory, - /// A reference to the `OfflineMemory`. Will be populated after `finalize()`. - offline_memory: Arc>>, + pub memory: TracingMemory, pub access_adapters: AccessAdapterInventory, - // Filled during finalization. - final_state: Option>, -} - -#[allow(clippy::large_enum_variant)] -#[derive(Debug)] -enum FinalState { - Volatile(VolatileFinalState), - #[allow(dead_code)] - Persistent(PersistentFinalState), -} -#[derive(Debug, Default)] -struct VolatileFinalState { - _marker: PhantomData, -} -#[allow(dead_code)] -#[derive(Debug)] -struct PersistentFinalState { - final_memory: Equipartition, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -226,7 +198,6 @@ impl MemoryController { range_checker: SharedVariableRangeCheckerChip, ) -> Self { let range_checker_bus = range_checker.bus(); - let initial_memory = AddressMap::from_mem_config(&mem_config); assert!(mem_config.pointer_max_bits <= F::bits() - 2); assert!(mem_config.as_height < F::bits() - 2); let addr_space_max_bits = log2_ceil_usize( @@ -243,14 +214,7 @@ impl MemoryController { range_checker.clone(), ), }, - memory: Memory::new(&mem_config), - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - initial_memory, - 1, - memory_bus, - range_checker.clone(), - mem_config, - ))), + memory: TracingMemory::new(&mem_config, range_checker.clone(), memory_bus, 1), access_adapters: AccessAdapterInventory::new( range_checker.clone(), memory_bus, @@ -259,7 +223,6 @@ impl MemoryController { ), range_checker, range_checker_bus, - final_state: None, } } @@ -294,14 +257,8 @@ impl MemoryController { memory_bus, mem_config, interface_chip, - memory: Memory::new(&mem_config), // it is expected that the memory will be set later - offline_memory: Arc::new(Mutex::new(OfflineMemory::new( - AddressMap::from_mem_config(&mem_config), - CHUNK, - memory_bus, - range_checker.clone(), - mem_config, - ))), + memory: TracingMemory::new(&mem_config, range_checker.clone(), memory_bus, CHUNK), /* it is expected that the memory will be + * set later */ access_adapters: AccessAdapterInventory::new( range_checker.clone(), memory_bus, @@ -310,12 +267,11 @@ impl MemoryController { ), range_checker, range_checker_bus, - final_state: None, } } - pub fn memory_image(&self) -> &MemoryImage { - &self.memory.data + pub fn memory_image(&self) -> &MemoryImage { + &self.memory.data.memory } pub fn set_override_trace_heights(&mut self, overridden_heights: MemoryTraceHeights) { @@ -344,26 +300,30 @@ impl MemoryController { } } - pub fn set_initial_memory(&mut self, memory: MemoryImage) { + pub fn set_initial_memory(&mut self, memory: MemoryImage) { if self.timestamp() > INITIAL_TIMESTAMP + 1 { panic!("Cannot set initial memory after first timestamp"); } - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_initial_memory(memory.clone(), self.mem_config); - - self.memory = Memory::from_image(memory.clone(), self.mem_config.access_capacity); + if memory.is_empty() { + return; + } match &mut self.interface_chip { MemoryInterface::Volatile { .. } => { - assert!( - memory.is_empty(), - "Cannot set initial memory for volatile memory" - ); + panic!("Cannot set initial memory for volatile memory"); } MemoryInterface::Persistent { initial_memory, .. } => { - *initial_memory = memory; + *initial_memory = memory.clone(); } } + + self.memory = TracingMemory::new( + &self.mem_config, + self.range_checker.clone(), + self.memory_bus, + CHUNK, + ) + .with_image(memory, self.mem_config.access_capacity); } pub fn memory_bridge(&self) -> MemoryBridge { @@ -379,50 +339,61 @@ impl MemoryController { (record_id, data) } - pub fn read(&mut self, address_space: F, pointer: F) -> (RecordId, [F; N]) { + // TEMP[jpw]: Function is safe temporarily for refactoring + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, and it must be the + /// exact type used to represent a single memory cell in address space `address_space`. For + /// standard usage, `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + pub fn read( + &mut self, + address_space: F, + pointer: F, + ) -> (RecordId, [T; N]) { let address_space_u32 = address_space.as_canonical_u32(); let ptr_u32 = pointer.as_canonical_u32(); assert!( address_space == F::ZERO || ptr_u32 < (1 << self.mem_config.pointer_max_bits), "memory out of bounds: {ptr_u32:?}", ); + todo!() + // let (record_id, values) = unsafe { self.memory.read::(address_space_u32, ptr_u32) + // }; - let (record_id, values) = self.memory.read::(address_space_u32, ptr_u32); - - (record_id, values) + // (record_id, values) } /// Reads a word directly from memory without updating internal state. /// /// Any value returned is unconstrained. - pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> F { - self.unsafe_read::<1>(addr_space, ptr)[0] + pub fn unsafe_read_cell(&self, addr_space: F, ptr: F) -> T { + self.unsafe_read::(addr_space, ptr)[0] } /// Reads a word directly from memory without updating internal state. /// /// Any value returned is unconstrained. - pub fn unsafe_read(&self, addr_space: F, ptr: F) -> [F; N] { + pub fn unsafe_read(&self, addr_space: F, ptr: F) -> [T; N] { let addr_space = addr_space.as_canonical_u32(); let ptr = ptr.as_canonical_u32(); - array::from_fn(|i| self.memory.get(addr_space, ptr + i as u32)) + todo!() + // unsafe { array::from_fn(|i| self.memory.get::(addr_space, ptr + i as u32)) } } /// Writes `data` to the given cell. /// /// Returns the `RecordId` and previous data. - pub fn write_cell(&mut self, address_space: F, pointer: F, data: F) -> (RecordId, F) { - let (record_id, [data]) = self.write(address_space, pointer, [data]); + pub fn write_cell(&mut self, address_space: F, pointer: F, data: T) -> (RecordId, T) { + let (record_id, [data]) = self.write(address_space, pointer, &[data]); (record_id, data) } - pub fn write( + pub fn write( &mut self, address_space: F, pointer: F, - data: [F; N], - ) -> (RecordId, [F; N]) { - assert_ne!(address_space, F::ZERO); + data: &[T; N], + ) -> (RecordId, [T; N]) { + debug_assert_ne!(address_space, F::ZERO); let address_space_u32 = address_space.as_canonical_u32(); let ptr_u32 = pointer.as_canonical_u32(); assert!( @@ -430,13 +401,23 @@ impl MemoryController { "memory out of bounds: {ptr_u32:?}", ); - self.memory.write(address_space_u32, ptr_u32, data) + todo!() + // unsafe { self.memory.write::(address_space_u32, ptr_u32, data) } + } + + pub fn helper(&self) -> SharedMemoryHelper { + let range_bus = self.range_checker.bus(); + SharedMemoryHelper { + range_checker: self.range_checker.clone(), + timestamp_lt_air: AssertLtSubAir::new(range_bus, self.mem_config.clk_max_bits), + _marker: Default::default(), + } } pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { let range_bus = self.range_checker.bus(); MemoryAuxColsFactory { - range_checker: self.range_checker.clone(), + range_checker: self.range_checker.as_ref(), timestamp_lt_air: AssertLtSubAir::new(range_bus, self.mem_config.clk_max_bits), _marker: Default::default(), } @@ -454,106 +435,195 @@ impl MemoryController { self.memory.timestamp() } - fn replay_access_log(&mut self) { - let log = mem::take(&mut self.memory.log); - if log.is_empty() { - // Online memory logs may be empty, but offline memory may be replayed from external - // sources. In these cases, we skip the calls to replay access logs because - // `set_log_capacity` would panic. - tracing::debug!("skipping replay_access_log"); - return; - } - - let mut offline_memory = self.offline_memory.lock().unwrap(); - offline_memory.set_log_capacity(log.len()); - - for entry in log { - Self::replay_access( - entry, - &mut offline_memory, - &mut self.interface_chip, - &mut self.access_adapters, - ); - } - } - - /// Low-level API to replay a single memory access log entry and populate the [OfflineMemory], - /// [MemoryInterface], and `AccessAdapterInventory`. - pub fn replay_access( - entry: MemoryLogEntry, - offline_memory: &mut OfflineMemory, - interface_chip: &mut MemoryInterface, - adapter_records: &mut AccessAdapterInventory, - ) { - match entry { - MemoryLogEntry::Read { - address_space, - pointer, - len, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, len as u32); + /// Returns the equipartition of the touched blocks. + /// Has side effects (namely setting the traces for the access adapters). + fn touched_blocks_to_equipartition( + &mut self, + touched_blocks: Vec<((u32, u32), AccessMetadata)>, + ) -> TimestampedEquipartition { + let mut current_values = [F::ZERO; CHUNK]; + let mut current_cnt = 0; + let mut current_address = MemoryAddress::new(0, 0); + let mut current_timestamps = vec![0; CHUNK]; + let mut final_memory = TimestampedEquipartition::::new(); + for ((addr_space, ptr), metadata) in touched_blocks { + let AccessMetadata { + timestamp, + block_size, + } = metadata; + if current_cnt > 0 + && (current_address.address_space != addr_space + || current_address.pointer + CHUNK as u32 <= ptr) + { + let min_block_size = + self.memory.min_block_size[current_address.address_space as usize] as usize; + current_values[current_cnt..].fill(F::ZERO); + current_timestamps[(current_cnt / min_block_size)..].fill(INITIAL_TIMESTAMP); + self.memory.execute_merges::( + current_address, + min_block_size, + ¤t_values, + ¤t_timestamps, + ); + final_memory.insert( + (current_address.address_space, current_address.pointer), + TimestampedValues { + timestamp: *current_timestamps + .iter() + .take(current_cnt.div_ceil(min_block_size)) + .max() + .unwrap(), + values: current_values, + }, + ); + current_cnt = 0; + } + let min_block_size = self.memory.min_block_size[addr_space as usize] as usize; + if current_cnt == 0 { + let rem = ptr & (CHUNK as u32 - 1); + if rem != 0 { + current_values[..(rem as usize)].fill(F::ZERO); + current_address = MemoryAddress::new(addr_space, ptr - rem); + } else { + current_address = MemoryAddress::new(addr_space, ptr); } - offline_memory.read(address_space, pointer, len, adapter_records); + } else { + let offset = (ptr - current_address.pointer) as usize; + current_values[current_cnt..offset].fill(F::ZERO); + current_timestamps[(current_cnt / min_block_size)..(offset / min_block_size)] + .fill(INITIAL_TIMESTAMP); + current_cnt = offset; } - MemoryLogEntry::Write { - address_space, - pointer, - data, - } => { - if address_space != 0 { - interface_chip.touch_range(address_space, pointer, data.len() as u32); + debug_assert!(block_size >= min_block_size as u32); + debug_assert!(ptr % min_block_size as u32 == 0); + + let values = (0..block_size) + .map(|i| self.memory.data.memory.get_f::(addr_space, ptr + i)) + .collect::>(); + self.memory.execute_splits::( + MemoryAddress::new(addr_space, ptr), + min_block_size.min(CHUNK), + &values, + metadata.timestamp, + ); + if INITIAL_MERGES { + debug_assert_eq!(CHUNK, 1); + let initial_values = vec![F::ZERO; min_block_size]; + let initial_timestamps = vec![INITIAL_TIMESTAMP; min_block_size / CHUNK]; + for i in (0..block_size).step_by(min_block_size) { + self.memory.execute_merges::( + MemoryAddress::new(addr_space, ptr + i), + CHUNK, + &initial_values, + &initial_timestamps, + ); } - offline_memory.write(address_space, pointer, data, adapter_records); } - MemoryLogEntry::IncrementTimestampBy(amount) => { - offline_memory.increment_timestamp_by(amount); + for i in 0..block_size { + current_values[current_cnt] = values[i as usize]; + if current_cnt & (min_block_size - 1) == 0 { + current_timestamps[current_cnt / min_block_size] = timestamp; + } + current_cnt += 1; + if current_cnt == CHUNK { + self.memory.execute_merges::( + current_address, + min_block_size, + ¤t_values, + ¤t_timestamps, + ); + final_memory.insert( + (current_address.address_space, current_address.pointer), + TimestampedValues { + timestamp: *current_timestamps + .iter() + .take(current_cnt.div_ceil(min_block_size)) + .max() + .unwrap(), + values: current_values, + }, + ); + current_address.pointer += current_cnt as u32; + current_cnt = 0; + } } - }; + } + if current_cnt > 0 { + let min_block_size = + self.memory.min_block_size[current_address.address_space as usize] as usize; + current_values[current_cnt..].fill(F::ZERO); + current_timestamps[(current_cnt / min_block_size)..].fill(INITIAL_TIMESTAMP); + self.memory.execute_merges::( + current_address, + min_block_size, + ¤t_values, + ¤t_timestamps, + ); + final_memory.insert( + (current_address.address_space, current_address.pointer), + TimestampedValues { + timestamp: *current_timestamps + .iter() + .take(current_cnt.div_ceil(min_block_size)) + .max() + .unwrap(), + values: current_values, + }, + ); + } + + for i in 0..self.access_adapters.num_access_adapters() { + let width = self.memory.adapter_inventory_trace_cursor.width(i); + let trace = self.memory.adapter_inventory_trace_cursor.extract_trace(i); + self.access_adapters.set_trace(i, trace, width); + } + + final_memory } /// Returns the final memory state if persistent. + #[allow(clippy::assertions_on_constants)] pub fn finalize(&mut self, hasher: Option<&mut H>) where H: HasherChip + Sync + for<'a> SerialReceiver<&'a [F]>, { - if self.final_state.is_some() { - return; - } + let touched_blocks = self.memory.touched_blocks().collect::>(); - self.replay_access_log(); - let mut offline_memory = self.offline_memory.lock().unwrap(); + let mut final_memory_volatile = None; + let mut final_memory_persistent = None; + + match &self.interface_chip { + MemoryInterface::Volatile { .. } => { + final_memory_volatile = + Some(self.touched_blocks_to_equipartition::<1, true>(touched_blocks)); + } + MemoryInterface::Persistent { .. } => { + final_memory_persistent = + Some(self.touched_blocks_to_equipartition::(touched_blocks)); + } + } match &mut self.interface_chip { MemoryInterface::Volatile { boundary_chip } => { - let final_memory = offline_memory.finalize::<1>(&mut self.access_adapters); + let final_memory = final_memory_volatile.unwrap(); boundary_chip.finalize(final_memory); - self.final_state = Some(FinalState::Volatile(VolatileFinalState::default())); } MemoryInterface::Persistent { - merkle_chip, boundary_chip, + merkle_chip, initial_memory, } => { - let hasher = hasher.unwrap(); - let final_partition = offline_memory.finalize::(&mut self.access_adapters); + let final_memory = final_memory_persistent.unwrap(); - boundary_chip.finalize(initial_memory, &final_partition, hasher); - let final_memory_values = final_partition + let hasher = hasher.unwrap(); + boundary_chip.finalize(initial_memory, &final_memory, hasher); + let final_memory_values = final_memory .into_par_iter() .map(|(key, value)| (key, value.values)) .collect(); - let initial_node = MemoryNode::tree_from_memory( - merkle_chip.air.memory_dimensions, - initial_memory, - hasher, - ); - merkle_chip.finalize(&initial_node, &final_memory_values, hasher); - self.final_state = Some(FinalState::Persistent(PersistentFinalState { - final_memory: final_memory_values.clone(), - })); + merkle_chip.finalize(initial_memory.clone(), &final_memory_values, hasher); } - }; + } } pub fn generate_air_proof_inputs(self) -> Vec> @@ -694,74 +764,30 @@ impl MemoryController { ret.extend(self.access_adapters.get_cells()); ret } - - /// Returns a reference to the offline memory. - /// - /// Until `finalize` is called, the `OfflineMemory` does not contain useful state, and should - /// therefore not be used by any chip during execution. However, to obtain a reference to the - /// offline memory that will be useful in trace generation, a chip can call `offline_memory()` - /// and store the returned reference for later use. - pub fn offline_memory(&self) -> Arc>> { - self.offline_memory.clone() - } - pub fn get_memory_logs(&self) -> &Vec> { - &self.memory.log - } - pub fn set_memory_logs(&mut self, logs: Vec>) { - self.memory.log = logs; - } - pub fn take_memory_logs(&mut self) -> Vec> { - std::mem::take(&mut self.memory.log) - } } -pub struct MemoryAuxColsFactory { +/// Owned version of [MemoryAuxColsFactory]. +pub struct SharedMemoryHelper { pub(crate) range_checker: SharedVariableRangeCheckerChip, pub(crate) timestamp_lt_air: AssertLtSubAir, pub(crate) _marker: PhantomData, } +/// A helper for generating trace values in auxiliary memory columns related to the offline memory +/// argument. +pub struct MemoryAuxColsFactory<'a, T> { + pub(crate) range_checker: &'a VariableRangeCheckerChip, + pub(crate) timestamp_lt_air: AssertLtSubAir, + pub(crate) _marker: PhantomData, +} + // NOTE[jpw]: The `make_*_aux_cols` functions should be thread-safe so they can be used in // parallelized trace generation. -impl MemoryAuxColsFactory { - pub fn generate_read_aux(&self, read: &MemoryRecord, buffer: &mut MemoryReadAuxCols) { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - self.generate_base_aux(read, &mut buffer.base); - } - - pub fn generate_read_or_immediate_aux( - &self, - read: &MemoryRecord, - buffer: &mut MemoryReadOrImmediateAuxCols, - ) { - IsZeroSubAir.generate_subrow( - read.address_space, - (&mut buffer.is_zero_aux, &mut buffer.is_immediate), - ); - self.generate_base_aux(read, &mut buffer.base); - } - - pub fn generate_write_aux( - &self, - write: &MemoryRecord, - buffer: &mut MemoryWriteAuxCols, - ) { - buffer - .prev_data - .copy_from_slice(write.prev_data_slice().unwrap()); - self.generate_base_aux(write, &mut buffer.base); - } - - pub fn generate_base_aux(&self, record: &MemoryRecord, buffer: &mut MemoryBaseAuxCols) { - buffer.prev_timestamp = F::from_canonical_u32(record.prev_timestamp); - self.generate_timestamp_lt( - record.prev_timestamp, - record.timestamp, - &mut buffer.timestamp_lt_aux, - ); +impl MemoryAuxColsFactory<'_, F> { + /// Fill the trace assuming `prev_timestamp` is already provided in `buffer`. + pub fn fill_from_prev(&self, timestamp: u32, buffer: &mut MemoryBaseAuxCols) { + let prev_timestamp = buffer.prev_timestamp.as_canonical_u32(); + self.generate_timestamp_lt(prev_timestamp, timestamp, &mut buffer.timestamp_lt_aux); } fn generate_timestamp_lt( @@ -770,38 +796,16 @@ impl MemoryAuxColsFactory { timestamp: u32, buffer: &mut LessThanAuxCols, ) { - debug_assert!(prev_timestamp < timestamp); + debug_assert!( + prev_timestamp < timestamp, + "prev_timestamp {prev_timestamp} >= timestamp {timestamp}" + ); self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), + (self.range_checker, prev_timestamp, timestamp), &mut buffer.lower_decomp, ); } - /// In general, prefer `generate_read_aux` which writes in-place rather than this function. - pub fn make_read_aux_cols(&self, read: &MemoryRecord) -> MemoryReadAuxCols { - assert!( - !read.address_space.is_zero(), - "cannot make `MemoryReadAuxCols` for address space 0" - ); - MemoryReadAuxCols::new( - read.prev_timestamp, - self.generate_timestamp_lt_cols(read.prev_timestamp, read.timestamp), - ) - } - - /// In general, prefer `generate_write_aux` which writes in-place rather than this function. - pub fn make_write_aux_cols( - &self, - write: &MemoryRecord, - ) -> MemoryWriteAuxCols { - let prev_data = write.prev_data_slice().unwrap(); - MemoryWriteAuxCols::new( - prev_data.try_into().unwrap(), - F::from_canonical_u32(write.prev_timestamp), - self.generate_timestamp_lt_cols(write.prev_timestamp, write.timestamp), - ) - } - fn generate_timestamp_lt_cols( &self, prev_timestamp: u32, @@ -809,14 +813,22 @@ impl MemoryAuxColsFactory { ) -> LessThanAuxCols { debug_assert!(prev_timestamp < timestamp); let mut decomp = [F::ZERO; AUX_LEN]; - self.timestamp_lt_air.generate_subrow( - (self.range_checker.as_ref(), prev_timestamp, timestamp), - &mut decomp, - ); + self.timestamp_lt_air + .generate_subrow((self.range_checker, prev_timestamp, timestamp), &mut decomp); LessThanAuxCols::new(decomp) } } +impl SharedMemoryHelper { + pub fn as_borrowed(&self) -> MemoryAuxColsFactory<'_, T> { + MemoryAuxColsFactory { + range_checker: self.range_checker.as_ref(), + timestamp_lt_air: self.timestamp_lt_air, + _marker: PhantomData, + } + } +} + #[cfg(test)] mod tests { use openvm_circuit_primitives::var_range::{ @@ -857,9 +869,9 @@ mod tests { if rng.gen_bool(0.5) { let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - memory_controller.write(address_space, pointer, [data]); + memory_controller.write(address_space, pointer, &[data]); } else { - memory_controller.read::<1>(address_space, pointer); + memory_controller.read::(address_space, pointer); } } assert!(memory_controller diff --git a/crates/vm/src/system/memory/merkle/mod.rs b/crates/vm/src/system/memory/merkle/mod.rs index 74f8951bc4..4eac44bf81 100644 --- a/crates/vm/src/system/memory/merkle/mod.rs +++ b/crates/vm/src/system/memory/merkle/mod.rs @@ -1,17 +1,19 @@ use openvm_stark_backend::{interaction::PermutationCheckBus, p3_field::PrimeField32}; use rustc_hash::FxHashSet; -use super::controller::dimensions::MemoryDimensions; +use super::{controller::dimensions::MemoryDimensions, Equipartition, MemoryImage}; mod air; mod columns; mod trace; +mod tree; pub use air::*; pub use columns::*; pub(super) use trace::SerialReceiver; -#[cfg(test)] -mod tests; +// TODO: add back +// #[cfg(test)] +// mod tests; pub struct MemoryMerkleChip { pub air: MemoryMerkleAir, @@ -78,3 +80,17 @@ impl MemoryMerkleChip { } } } + +fn memory_to_partition( + memory: &MemoryImage, +) -> Equipartition { + let mut memory_partition = Equipartition::new(); + for ((address_space, pointer), value) in memory.items() { + let label = (address_space, pointer / N as u32); + let chunk = memory_partition + .entry(label) + .or_insert_with(|| [F::default(); N]); + chunk[(pointer % N as u32) as usize] = value; + } + memory_partition +} diff --git a/crates/vm/src/system/memory/merkle/tests/mod.rs b/crates/vm/src/system/memory/merkle/tests/mod.rs index 05c966dc23..65474093e3 100644 --- a/crates/vm/src/system/memory/merkle/tests/mod.rs +++ b/crates/vm/src/system/memory/merkle/tests/mod.rs @@ -7,7 +7,7 @@ use std::{ use openvm_stark_backend::{ interaction::{PermutationCheckBus, PermutationInteractionType}, - p3_field::FieldAlgebra, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::dense::RowMajorMatrix, prover::types::AirProofInput, Chip, ChipUsageGetter, @@ -19,6 +19,7 @@ use openvm_stark_sdk::{ }; use rand::RngCore; +use super::memory_to_partition; use crate::{ arch::testing::{MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS}, system::memory::{ @@ -39,9 +40,9 @@ const COMPRESSION_BUS: PermutationCheckBus = PermutationCheckBus::new(POSEIDON2_ fn test( memory_dimensions: MemoryDimensions, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, touched_labels: BTreeSet<(u32, u32)>, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) { let MemoryDimensions { as_height, @@ -51,30 +52,31 @@ fn test( let merkle_bus = PermutationCheckBus::new(MEMORY_MERKLE_BUS); // checking validity of test data - for ((address_space, pointer), value) in final_memory.items() { + for ((address_space, pointer), value) in final_memory.items::() { let label = pointer / CHUNK as u32; assert!(address_space - as_offset < (1 << as_height)); assert!(pointer < ((CHUNK << address_height).div_ceil(PAGE_SIZE) * PAGE_SIZE) as u32); - if initial_memory.get(&(address_space, pointer)) != Some(&value) { + if unsafe { initial_memory.get::((address_space, pointer)) } != value { assert!(touched_labels.contains(&(address_space, label))); } } - for key in initial_memory.items().map(|(key, _)| key) { - assert!(final_memory.get(&key).is_some()); - } - for &(address_space, label) in touched_labels.iter() { - let mut contains_some_key = false; - for i in 0..CHUNK { - if final_memory - .get(&(address_space, label * CHUNK as u32 + i as u32)) - .is_some() - { - contains_some_key = true; - break; - } - } - assert!(contains_some_key); - } + // for key in initial_memory.items().map(|(key, _)| key) { + // assert!(unsafe { final_memory.get(key).is_some() }); + // } + // for &(address_space, label) in touched_labels.iter() { + // let mut contains_some_key = false; + // for i in 0..CHUNK { + // if unsafe { + // final_memory + // .get((address_space, label * CHUNK as u32 + i as u32)) + // .is_some() + // } { + // contains_some_key = true; + // break; + // } + // } + // assert!(contains_some_key); + // } let mut hash_test_chip = HashTestChip::new(); @@ -126,12 +128,11 @@ fn test( }; for (address_space, address_label) in touched_labels { - let initial_values = array::from_fn(|i| { - initial_memory - .get(&(address_space, address_label * CHUNK as u32 + i as u32)) - .copied() - .unwrap_or_default() - }); + let initial_values = unsafe { + array::from_fn(|i| { + initial_memory.get((address_space, address_label * CHUNK as u32 + i as u32)) + }) + }; let as_label = address_space - as_offset; interaction( PermutationInteractionType::Send, @@ -180,20 +181,6 @@ fn test( .expect("Verification failed"); } -fn memory_to_partition( - memory: &MemoryImage, -) -> Equipartition { - let mut memory_partition = Equipartition::new(); - for ((address_space, pointer), value) in memory.items() { - let label = (address_space, pointer / N as u32); - let chunk = memory_partition - .entry(label) - .or_insert_with(|| [F::default(); N]); - chunk[(pointer % N as u32) as usize] = value; - } - memory_partition -} - fn random_test( height: usize, max_value: u32, @@ -203,8 +190,12 @@ fn random_test( let mut rng = create_seeded_rng(); let mut next_u32 = || rng.next_u64() as u32; - let mut initial_memory = AddressMap::new(1, 2, CHUNK << height); - let mut final_memory = AddressMap::new(1, 2, CHUNK << height); + let as_cnt = 2; + let mut initial_memory = AddressMap::new(1, as_cnt, CHUNK << height); + let mut final_memory = AddressMap::new(1, as_cnt, CHUNK << height); + // TEMP[jpw]: override so address space uses field element + initial_memory.cell_size = vec![4; as_cnt]; + final_memory.cell_size = vec![4; as_cnt]; let mut seen = HashSet::new(); let mut touched_labels = BTreeSet::new(); @@ -221,15 +212,19 @@ fn random_test( if is_initial && num_initial_addresses != 0 { num_initial_addresses -= 1; let value = BabyBear::from_canonical_u32(next_u32() % max_value); - initial_memory.insert(&(address_space, pointer), value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + initial_memory.insert((address_space, pointer), value); + final_memory.insert((address_space, pointer), value); + } } if is_touched && num_touched_addresses != 0 { num_touched_addresses -= 1; touched_labels.insert((address_space, label)); if value_changes || !is_initial { let value = BabyBear::from_canonical_u32(next_u32() % max_value); - final_memory.insert(&(address_space, pointer), value); + unsafe { + final_memory.insert((address_space, pointer), value); + } } } } diff --git a/crates/vm/src/system/memory/merkle/trace.rs b/crates/vm/src/system/memory/merkle/trace.rs index 52609f259a..5a5dfc35e0 100644 --- a/crates/vm/src/system/memory/merkle/trace.rs +++ b/crates/vm/src/system/memory/merkle/trace.rs @@ -1,6 +1,5 @@ use std::{ borrow::BorrowMut, - cmp::Reverse, sync::{atomic::AtomicU32, Arc}, }; @@ -11,16 +10,13 @@ use openvm_stark_backend::{ prover::types::AirProofInput, AirRef, Chip, ChipUsageGetter, }; -use rustc_hash::FxHashSet; use crate::{ arch::hasher::HasherChip, system::{ memory::{ - controller::dimensions::MemoryDimensions, - merkle::{FinalState, MemoryMerkleChip, MemoryMerkleCols}, - tree::MemoryNode::{self, NonLeaf}, - Equipartition, + merkle::{tree::MerkleTree, FinalState, MemoryMerkleChip, MemoryMerkleCols}, + Equipartition, MemoryImage, }, poseidon2::{ Poseidon2PeripheryBaseChip, Poseidon2PeripheryChip, PERIPHERY_POSEIDON2_WIDTH, @@ -31,37 +27,13 @@ use crate::{ impl MemoryMerkleChip { pub fn finalize( &mut self, - initial_tree: &MemoryNode, + initial_memory: MemoryImage, final_memory: &Equipartition, hasher: &mut impl HasherChip, ) { assert!(self.final_state.is_none(), "Merkle chip already finalized"); - // there needs to be a touched node with `height_section` = 0 - // shouldn't be a leaf because - // trace generation will expect an interaction from MemoryInterfaceChip in that case - if self.touched_nodes.len() == 1 { - self.touch_node(1, 0, 0); - } - - let mut rows = vec![]; - let mut tree_helper = TreeHelper { - memory_dimensions: self.air.memory_dimensions, - final_memory, - touched_nodes: &self.touched_nodes, - trace_rows: &mut rows, - }; - let final_tree = tree_helper.recur( - self.air.memory_dimensions.overall_height(), - initial_tree, - 0, - 0, - hasher, - ); - self.final_state = Some(FinalState { - rows, - init_root: initial_tree.hash(), - final_root: final_tree.hash(), - }); + let mut tree = MerkleTree::from_memory(initial_memory, &self.air.memory_dimensions, hasher); + self.final_state = Some(tree.finalize(hasher, final_memory, &self.air.memory_dimensions)); } } @@ -85,7 +57,8 @@ where } = self.final_state.unwrap(); // important that this sort be stable, // because we need the initial root to be first and the final root to be second - rows.sort_by_key(|row| Reverse(row.parent_height)); + rows.reverse(); + rows.swap(0, 1); let width = MemoryMerkleCols::, CHUNK>::width(); let mut height = rows.len().next_power_of_two(); @@ -122,136 +95,6 @@ impl ChipUsageGetter for MemoryMerkleChip { - memory_dimensions: MemoryDimensions, - final_memory: &'a Equipartition, - touched_nodes: &'a FxHashSet<(usize, u32, u32)>, - trace_rows: &'a mut Vec>, -} - -impl TreeHelper<'_, CHUNK, F> { - fn recur( - &mut self, - height: usize, - initial_node: &MemoryNode, - as_label: u32, - address_label: u32, - hasher: &mut impl HasherChip, - ) -> MemoryNode { - if height == 0 { - let address_space = as_label + self.memory_dimensions.as_offset; - let leaf_values = *self - .final_memory - .get(&(address_space, address_label)) - .unwrap_or(&[F::ZERO; CHUNK]); - MemoryNode::new_leaf(hasher.hash(&leaf_values)) - } else if let NonLeaf { - left: initial_left_node, - right: initial_right_node, - .. - } = initial_node.clone() - { - // Tell the hasher about this hash. - hasher.compress_and_record(&initial_left_node.hash(), &initial_right_node.hash()); - - let is_as_section = height > self.memory_dimensions.address_height; - - let (left_as_label, right_as_label) = if is_as_section { - (2 * as_label, 2 * as_label + 1) - } else { - (as_label, as_label) - }; - let (left_address_label, right_address_label) = if is_as_section { - (address_label, address_label) - } else { - (2 * address_label, 2 * address_label + 1) - }; - - let left_is_final = - !self - .touched_nodes - .contains(&(height - 1, left_as_label, left_address_label)); - - let final_left_node = if left_is_final { - initial_left_node - } else { - Arc::new(self.recur( - height - 1, - &initial_left_node, - left_as_label, - left_address_label, - hasher, - )) - }; - - let right_is_final = - !self - .touched_nodes - .contains(&(height - 1, right_as_label, right_address_label)); - - let final_right_node = if right_is_final { - initial_right_node - } else { - Arc::new(self.recur( - height - 1, - &initial_right_node, - right_as_label, - right_address_label, - hasher, - )) - }; - - let final_node = MemoryNode::new_nonleaf(final_left_node, final_right_node, hasher); - self.add_trace_row(height, as_label, address_label, initial_node, None); - self.add_trace_row( - height, - as_label, - address_label, - &final_node, - Some([left_is_final, right_is_final]), - ); - final_node - } else { - panic!("Leaf {:?} found at nonzero height {}", initial_node, height); - } - } - - /// Expects `node` to be NonLeaf - fn add_trace_row( - &mut self, - parent_height: usize, - as_label: u32, - address_label: u32, - node: &MemoryNode, - direction_changes: Option<[bool; 2]>, - ) { - let [left_direction_change, right_direction_change] = - direction_changes.unwrap_or([false; 2]); - let cols = if let NonLeaf { hash, left, right } = node { - MemoryMerkleCols { - expand_direction: if direction_changes.is_none() { - F::ONE - } else { - F::NEG_ONE - }, - height_section: F::from_bool(parent_height > self.memory_dimensions.address_height), - parent_height: F::from_canonical_usize(parent_height), - is_root: F::from_bool(parent_height == self.memory_dimensions.overall_height()), - parent_as_label: F::from_canonical_u32(as_label), - parent_address_label: F::from_canonical_u32(address_label), - parent_hash: *hash, - left_child_hash: left.hash(), - right_child_hash: right.hash(), - left_direction_different: F::from_bool(left_direction_change), - right_direction_different: F::from_bool(right_direction_change), - } - } else { - panic!("trace_rows expects node = {:?} to be NonLeaf", node); - }; - self.trace_rows.push(cols); - } -} - pub trait SerialReceiver { fn receive(&mut self, msg: T); } diff --git a/crates/vm/src/system/memory/merkle/tree.rs b/crates/vm/src/system/memory/merkle/tree.rs new file mode 100644 index 0000000000..74f3d54730 --- /dev/null +++ b/crates/vm/src/system/memory/merkle/tree.rs @@ -0,0 +1,228 @@ +use openvm_stark_backend::p3_field::PrimeField32; +use rustc_hash::FxHashMap; + +use super::{memory_to_partition, FinalState, MemoryMerkleCols}; +use crate::{ + arch::hasher::HasherChip, + system::memory::{dimensions::MemoryDimensions, AddressMap, Equipartition, PAGE_SIZE}, +}; + +#[derive(Debug)] +pub struct MerkleTree { + /// Height of the tree -- the root is the only node at height `height`, + /// and the leaves are at height `0`. + height: usize, + /// Nodes corresponding to all zeroes. + zero_nodes: Vec<[F; CHUNK]>, + /// Nodes in the tree that have ever been touched. + nodes: FxHashMap, +} + +impl MerkleTree { + pub fn new(height: usize, hasher: &impl HasherChip) -> Self { + Self { + height, + zero_nodes: (0..height + 1) + .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| { + let result = Some(*acc); + *acc = hasher.compress(acc, acc); + result + }) + .collect(), + nodes: FxHashMap::default(), + } + } + + /// Shared logic for both from_memory and finalize. + fn process_layers( + &mut self, + layer: Vec<(u64, [F; CHUNK])>, + md: &MemoryDimensions, + mut rows: Option<&mut Vec>>, + mut compress: CompressFn, + ) where + CompressFn: FnMut(&[F; CHUNK], &[F; CHUNK]) -> [F; CHUNK], + { + let mut layer = layer + .into_iter() + .map(|(index, values)| (index, values, self.get_node(index))) + .collect::>(); + for height in 1..=self.height { + let mut i = 0; + let mut new_layer = Vec::new(); + while i < layer.len() { + let (index, values, old_values) = layer[i]; + let par_index = index >> 1; + i += 1; + + let par_old_values = self.get_node(par_index); + + // Lowest `label_section_height` bits of `par_index` are the address label, + // The remaining highest are the address space label. + let label_section_height = md.address_height.saturating_sub(height); + let parent_address_label = (par_index & ((1 << label_section_height) - 1)) as u32; + let parent_as_label = + ((par_index & !(1 << (self.height - height))) >> label_section_height) as u32; + + self.nodes.insert(index, values); + + if i < layer.len() && layer[i].0 == index ^ 1 { + // sibling found + let (_, sibling_values, sibling_old_values) = layer[i]; + i += 1; + let combined = compress(&values, &sibling_values); + + // Only record rows if requested + if let Some(rows) = rows.as_deref_mut() { + rows.push(MemoryMerkleCols { + expand_direction: F::ONE, + height_section: F::from_bool(height > md.address_height), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32(parent_address_label), + parent_hash: self.get_node(par_index), + left_child_hash: old_values, + right_child_hash: sibling_old_values, + left_direction_different: F::ZERO, + right_direction_different: F::ZERO, + }); + rows.push(MemoryMerkleCols { + expand_direction: F::NEG_ONE, + height_section: F::from_bool(height > md.address_height), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32(parent_address_label), + parent_hash: combined, + left_child_hash: values, + right_child_hash: sibling_values, + left_direction_different: F::ZERO, + right_direction_different: F::ZERO, + }); + // This is a hacky way to say "and we also want to record the old values" + compress(&old_values, &sibling_old_values); + } + + self.nodes.insert(index ^ 1, sibling_values); + new_layer.push((par_index, combined, par_old_values)); + } else { + // no sibling found + let sibling_values = self.get_node(index ^ 1); + let is_left = index % 2 == 0; + let (left, right) = if is_left { + (values, sibling_values) + } else { + (sibling_values, values) + }; + let combined = compress(&left, &right); + + if let Some(rows) = rows.as_deref_mut() { + rows.push(MemoryMerkleCols { + expand_direction: F::ONE, + height_section: F::from_bool(height > md.address_height), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32(parent_address_label), + parent_hash: self.get_node(par_index), + left_child_hash: if is_left { old_values } else { left }, + right_child_hash: if is_left { right } else { old_values }, + left_direction_different: F::ZERO, + right_direction_different: F::ZERO, + }); + rows.push(MemoryMerkleCols { + expand_direction: F::NEG_ONE, + height_section: F::from_bool(height > md.address_height), + parent_height: F::from_canonical_usize(height), + is_root: F::from_bool(height == md.overall_height()), + parent_as_label: F::from_canonical_u32(parent_as_label), + parent_address_label: F::from_canonical_u32(parent_address_label), + parent_hash: combined, + left_child_hash: left, + right_child_hash: right, + left_direction_different: F::from_bool(!is_left), + right_direction_different: F::from_bool(is_left), + }); + // This is a hacky way to say "and we also want to record the old values" + if is_left { + compress(&old_values, &right); + } else { + compress(&left, &old_values); + } + } + + new_layer.push((par_index, combined, par_old_values)); + } + } + layer = new_layer; + } + if !layer.is_empty() { + assert_eq!(layer.len(), 1); + self.nodes.insert(layer[0].0, layer[0].1); + } + } + + pub fn from_memory( + initial_memory: AddressMap, + md: &MemoryDimensions, + hasher: &impl HasherChip, + ) -> Self { + let mut tree = Self::new(md.overall_height(), hasher); + let layer: Vec<_> = memory_to_partition(&initial_memory) + .iter() + .map(|((addr_sp, ptr), v)| { + ( + (1 << tree.height) + md.label_to_index((*addr_sp, *ptr)), + hasher.hash(v), + ) + }) + .collect(); + tree.process_layers(layer, md, None, |left, right| hasher.compress(left, right)); + tree + } + + pub fn finalize( + &mut self, + hasher: &mut impl HasherChip, + touched: &Equipartition, + md: &MemoryDimensions, + ) -> FinalState { + let init_root = self.get_node(1); + let layer: Vec<_> = touched + .iter() + .map(|((addr_sp, ptr), v)| { + ( + (1 << self.height) + md.label_to_index((*addr_sp, *ptr / CHUNK as u32)), + hasher.hash(v), + ) + }) + .collect(); + let mut rows = Vec::with_capacity(if touched.is_empty() { + 0 + } else { + layer + .iter() + .zip(layer.iter().skip(1)) + .fold(md.overall_height(), |acc, ((lhs, _), (rhs, _))| { + acc + (lhs ^ rhs).ilog2() as usize + }) + }); + self.process_layers(layer, md, Some(&mut rows), |left, right| { + hasher.compress_and_record(left, right) + }); + let final_root = self.get_node(1); + FinalState { + rows, + init_root, + final_root, + } + } + + fn get_node(&self, index: u64) -> [F; CHUNK] { + self.nodes + .get(&index) + .cloned() + .unwrap_or(self.zero_nodes[self.height - index.ilog2() as usize]) + } +} diff --git a/crates/vm/src/system/memory/mod.rs b/crates/vm/src/system/memory/mod.rs index ac6a7d85cf..1001520635 100644 --- a/crates/vm/src/system/memory/mod.rs +++ b/crates/vm/src/system/memory/mod.rs @@ -1,20 +1,20 @@ use openvm_circuit_primitives_derive::AlignedBorrow; -mod adapter; +pub mod adapter; mod controller; pub mod merkle; -mod offline; pub mod offline_checker; pub mod online; pub mod paged_vec; mod persistent; -#[cfg(test)] -mod tests; +// TODO: add back +// #[cfg(test)] +// mod tests; pub mod tree; mod volatile; pub use controller::*; -pub use offline::*; +pub use online::INITIAL_TIMESTAMP; pub use paged_vec::*; #[derive(PartialEq, Copy, Clone, Debug, Eq)] diff --git a/crates/vm/src/system/memory/offline.rs b/crates/vm/src/system/memory/offline.rs deleted file mode 100644 index 74bb238811..0000000000 --- a/crates/vm/src/system/memory/offline.rs +++ /dev/null @@ -1,1070 +0,0 @@ -use std::{array, cmp::max}; - -use openvm_circuit_primitives::{ - assert_less_than::AssertLtSubAir, var_range::SharedVariableRangeCheckerChip, -}; -use openvm_stark_backend::p3_field::PrimeField32; -use rustc_hash::FxHashSet; - -use super::{AddressMap, PagedVec, PAGE_SIZE}; -use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::{MemoryBridge, MemoryBus}, - MemoryAuxColsFactory, MemoryImage, RecordId, TimestampedEquipartition, TimestampedValues, - }, -}; - -pub const INITIAL_TIMESTAMP: u32 = 0; - -#[repr(C)] -#[derive(Clone, Default, PartialEq, Eq, Debug)] -struct BlockData { - pointer: u32, - timestamp: u32, - size: usize, -} - -struct BlockMap { - /// Block ids. 0 is a special value standing for the default block. - id: AddressMap, - /// The place where non-default blocks are stored. - storage: Vec, - initial_block_size: usize, -} - -impl BlockMap { - pub fn from_mem_config(mem_config: &MemoryConfig, initial_block_size: usize) -> Self { - assert!(initial_block_size.is_power_of_two()); - Self { - id: AddressMap::from_mem_config(mem_config), - storage: vec![], - initial_block_size, - } - } - - fn initial_block_data(pointer: u32, initial_block_size: usize) -> BlockData { - let aligned_pointer = (pointer / initial_block_size as u32) * initial_block_size as u32; - BlockData { - pointer: aligned_pointer, - size: initial_block_size, - timestamp: INITIAL_TIMESTAMP, - } - } - - pub fn get_without_adding(&self, address: &(u32, u32)) -> BlockData { - let idx = self.id.get(address).unwrap_or(&0); - if idx == &0 { - Self::initial_block_data(address.1, self.initial_block_size) - } else { - self.storage[idx - 1].clone() - } - } - - pub fn get(&mut self, address: &(u32, u32)) -> &BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - // `initial_block_size` is a power of two, as asserted in `from_mem_config`. - let pointer = pointer & !(self.initial_block_size as u32 - 1); - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last().unwrap() - } else { - &self.storage[idx - 1] - } - } - - pub fn get_mut(&mut self, address: &(u32, u32)) -> &mut BlockData { - let (address_space, pointer) = *address; - let idx = self.id.get(&(address_space, pointer)).unwrap_or(&0); - if idx == &0 { - let pointer = pointer - pointer % self.initial_block_size as u32; - self.set_range( - &(address_space, pointer), - self.initial_block_size, - Self::initial_block_data(pointer, self.initial_block_size), - ); - self.storage.last_mut().unwrap() - } else { - &mut self.storage[idx - 1] - } - } - - pub fn set_range(&mut self, address: &(u32, u32), len: usize, block: BlockData) { - let (address_space, pointer) = address; - self.storage.push(block); - for i in 0..len { - self.id - .insert(&(*address_space, pointer + i as u32), self.storage.len()); - } - } - - pub fn items(&self) -> impl Iterator + '_ { - self.id - .items() - .filter(|(_, idx)| *idx > 0) - .map(|(address, idx)| (address, &self.storage[idx - 1])) - } -} - -#[derive(Debug, Clone, PartialEq)] -pub struct MemoryRecord { - pub address_space: T, - pub pointer: T, - pub timestamp: u32, - pub prev_timestamp: u32, - data: Vec, - /// None if a read. - prev_data: Option>, -} - -impl MemoryRecord { - pub fn data_slice(&self) -> &[T] { - self.data.as_slice() - } - - pub fn prev_data_slice(&self) -> Option<&[T]> { - self.prev_data.as_deref() - } -} - -impl MemoryRecord { - pub fn data_at(&self, index: usize) -> T { - self.data[index] - } -} - -pub struct OfflineMemory { - block_data: BlockMap, - data: Vec>, - as_offset: u32, - timestamp: u32, - timestamp_max_bits: usize, - - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - - log: Vec>>, -} - -impl OfflineMemory { - /// Creates a new partition with the given initial block size. - /// - /// Panics if the initial block size is not a power of two. - pub fn new( - initial_memory: MemoryImage, - initial_block_size: usize, - memory_bus: MemoryBus, - range_checker: SharedVariableRangeCheckerChip, - config: MemoryConfig, - ) -> Self { - assert_eq!(initial_memory.as_offset, config.as_offset); - Self { - block_data: BlockMap::from_mem_config(&config, initial_block_size), - data: initial_memory.paged_vecs, - as_offset: config.as_offset, - timestamp: INITIAL_TIMESTAMP + 1, - timestamp_max_bits: config.clk_max_bits, - memory_bus, - range_checker, - log: vec![], - } - } - - pub fn set_initial_memory(&mut self, initial_memory: MemoryImage, config: MemoryConfig) { - assert_eq!(self.timestamp, INITIAL_TIMESTAMP + 1); - assert_eq!(initial_memory.as_offset, config.as_offset); - self.as_offset = config.as_offset; - self.data = initial_memory.paged_vecs; - } - - pub(super) fn set_log_capacity(&mut self, access_capacity: usize) { - assert!(self.log.is_empty()); - self.log = Vec::with_capacity(access_capacity); - } - - pub fn memory_bridge(&self) -> MemoryBridge { - MemoryBridge::new( - self.memory_bus, - self.timestamp_max_bits, - self.range_checker.bus(), - ) - } - - pub fn timestamp(&self) -> u32 { - self.timestamp - } - - /// Increments the current timestamp by one and returns the new value. - pub fn increment_timestamp(&mut self) { - self.increment_timestamp_by(1) - } - - /// Increments the current timestamp by a specified delta and returns the new value. - pub fn increment_timestamp_by(&mut self, delta: u32) { - self.log.push(None); - self.timestamp += delta; - } - - /// Writes an array of values to the memory at the specified address space and start index. - pub fn write( - &mut self, - address_space: u32, - pointer: u32, - values: Vec, - records: &mut AccessAdapterInventory, - ) { - let len = values.len(); - assert!(len.is_power_of_two()); - assert_ne!(address_space, 0); - - let prev_timestamp = self.access_updating_timestamp(address_space, pointer, len, records); - - debug_assert!(prev_timestamp < self.timestamp); - - let pointer = pointer as usize; - let prev_data = self.data[(address_space - self.as_offset) as usize] - .set_range(pointer..pointer + len, &values); - - let record = MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_usize(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: Some(prev_data), - }; - self.log.push(Some(record)); - self.timestamp += 1; - } - - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read( - &mut self, - address_space: u32, - pointer: u32, - len: usize, - adapter_records: &mut AccessAdapterInventory, - ) { - assert!(len.is_power_of_two()); - if address_space == 0 { - let pointer = F::from_canonical_u32(pointer); - self.log.push(Some(MemoryRecord { - address_space: F::ZERO, - pointer, - timestamp: self.timestamp, - prev_timestamp: 0, - data: vec![pointer], - prev_data: None, - })); - self.timestamp += 1; - return; - } - - let prev_timestamp = - self.access_updating_timestamp(address_space, pointer, len, adapter_records); - - debug_assert!(prev_timestamp < self.timestamp); - - let values = self.range_vec(address_space, pointer, len); - - self.log.push(Some(MemoryRecord { - address_space: F::from_canonical_u32(address_space), - pointer: F::from_canonical_u32(pointer), - timestamp: self.timestamp, - prev_timestamp, - data: values, - prev_data: None, - })); - self.timestamp += 1; - } - - pub fn record_by_id(&self, id: RecordId) -> &MemoryRecord { - self.log[id.0].as_ref().unwrap() - } - - pub fn finalize( - &mut self, - adapter_records: &mut AccessAdapterInventory, - ) -> TimestampedEquipartition { - // First make sure the partition we maintain in self.block_data is an equipartition. - // Grab all aligned pointers that need to be re-accessed. - let to_access: FxHashSet<_> = self - .block_data - .items() - .map(|((address_space, pointer), _)| (address_space, (pointer / N as u32) * N as u32)) - .collect(); - - for &(address_space, pointer) in to_access.iter() { - let block = self.block_data.get(&(address_space, pointer)); - if block.pointer != pointer || block.size != N { - self.access(address_space, pointer, N, adapter_records); - } - } - - let mut equipartition = TimestampedEquipartition::::new(); - for (address_space, pointer) in to_access { - let block = self.block_data.get(&(address_space, pointer)); - - debug_assert_eq!(block.pointer % N as u32, 0); - debug_assert_eq!(block.size, N); - - equipartition.insert( - (address_space, pointer / N as u32), - TimestampedValues { - timestamp: block.timestamp, - values: self.range_array::(address_space, pointer), - }, - ); - } - equipartition - } - - // Modifies the partition to ensure that there is a block starting at (address_space, query). - fn split_to_make_boundary( - &mut self, - address_space: u32, - query: u32, - records: &mut AccessAdapterInventory, - ) { - let lim = (self.data[(address_space - self.as_offset) as usize].memory_size()) as u32; - if query == lim { - return; - } - assert!(query < lim); - let original_block = self.block_containing(address_space, query); - if original_block.pointer == query { - return; - } - - let data = self.range_vec(address_space, original_block.pointer, original_block.size); - - let timestamp = original_block.timestamp; - - let mut cur_ptr = original_block.pointer; - let mut cur_size = original_block.size; - while cur_size > 0 { - // Split. - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(cur_ptr), - data: data[(cur_ptr - original_block.pointer) as usize - ..(cur_ptr - original_block.pointer) as usize + cur_size] - .to_vec(), - kind: AccessAdapterRecordKind::Split, - }); - - let half_size = cur_size / 2; - let half_size_u32 = half_size as u32; - let mid_ptr = cur_ptr + half_size_u32; - - if query <= mid_ptr { - // The right is finalized; add it to the partition. - let block = BlockData { - pointer: mid_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, mid_ptr), half_size, block); - } - if query >= cur_ptr + half_size_u32 { - // The left is finalized; add it to the partition. - let block = BlockData { - pointer: cur_ptr, - size: half_size, - timestamp, - }; - self.block_data - .set_range(&(address_space, cur_ptr), half_size, block); - } - if mid_ptr <= query { - cur_ptr = mid_ptr; - } - if cur_ptr == query { - break; - } - cur_size = half_size; - } - } - - fn access_updating_timestamp( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) -> u32 { - self.access(address_space, pointer, size, records); - - let mut prev_timestamp = None; - - let mut i = 0; - while i < size as u32 { - let block = self.block_data.get_mut(&(address_space, pointer + i)); - debug_assert!(i == 0 || prev_timestamp == Some(block.timestamp)); - prev_timestamp = Some(block.timestamp); - block.timestamp = self.timestamp; - i = block.pointer + block.size as u32; - } - prev_timestamp.unwrap() - } - - fn access( - &mut self, - address_space: u32, - pointer: u32, - size: usize, - records: &mut AccessAdapterInventory, - ) { - self.split_to_make_boundary(address_space, pointer, records); - self.split_to_make_boundary(address_space, pointer + size as u32, records); - - let block_data = self.block_containing(address_space, pointer); - - if block_data.pointer == pointer && block_data.size == size { - return; - } - assert!(size > 1); - - // Now recursively access left and right blocks to ensure they are in the partition. - let half_size = size / 2; - self.access(address_space, pointer, half_size, records); - self.access( - address_space, - pointer + half_size as u32, - half_size, - records, - ); - - self.merge_block_with_next(address_space, pointer, records); - } - - /// Merges the two adjacent blocks starting at (address_space, pointer). - /// - /// Panics if there is no block starting at (address_space, pointer) or if the two blocks - /// do not have the same size. - fn merge_block_with_next( - &mut self, - address_space: u32, - pointer: u32, - records: &mut AccessAdapterInventory, - ) { - let left_block = self.block_data.get(&(address_space, pointer)); - - let left_timestamp = left_block.timestamp; - let size = left_block.size; - - let right_timestamp = self - .block_data - .get(&(address_space, pointer + size as u32)) - .timestamp; - - let timestamp = max(left_timestamp, right_timestamp); - self.block_data.set_range( - &(address_space, pointer), - 2 * size, - BlockData { - pointer, - size: 2 * size, - timestamp, - }, - ); - records.add_record(AccessAdapterRecord { - timestamp, - address_space: F::from_canonical_u32(address_space), - start_index: F::from_canonical_u32(pointer), - data: self.range_vec(address_space, pointer, 2 * size), - kind: AccessAdapterRecordKind::Merge { - left_timestamp, - right_timestamp, - }, - }); - } - - fn block_containing(&mut self, address_space: u32, pointer: u32) -> BlockData { - self.block_data - .get_without_adding(&(address_space, pointer)) - } - - pub fn get(&self, address_space: u32, pointer: u32) -> F { - self.data[(address_space - self.as_offset) as usize] - .get(pointer as usize) - .cloned() - .unwrap_or_default() - } - - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - array::from_fn(|i| self.get(address_space, pointer + i as u32)) - } - - fn range_vec(&self, address_space: u32, pointer: u32, len: usize) -> Vec { - let pointer = pointer as usize; - self.data[(address_space - self.as_offset) as usize].range_vec(pointer..pointer + len) - } - - pub fn aux_cols_factory(&self) -> MemoryAuxColsFactory { - let range_bus = self.range_checker.bus(); - MemoryAuxColsFactory { - range_checker: self.range_checker.clone(), - timestamp_lt_air: AssertLtSubAir::new(range_bus, self.timestamp_max_bits), - _marker: Default::default(), - } - } - - // just for unit testing - #[cfg(test)] - fn last_record(&self) -> &MemoryRecord { - self.log.last().unwrap().as_ref().unwrap() - } -} - -#[cfg(test)] -mod tests { - use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, - }; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - - use super::{BlockData, MemoryRecord, OfflineMemory}; - use crate::{ - arch::MemoryConfig, - system::memory::{ - adapter::{AccessAdapterInventory, AccessAdapterRecord, AccessAdapterRecordKind}, - offline_checker::MemoryBus, - paged_vec::AddressMap, - MemoryImage, TimestampedValues, - }, - }; - - macro_rules! bb { - ($x:expr) => { - BabyBear::from_canonical_u32($x) - }; - } - - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } - } - - macro_rules! bbvec { - [$($x:expr),*] => { - vec![$(BabyBear::from_canonical_u32($x)),*] - } - } - - fn setup_test( - initial_memory: MemoryImage, - initial_block_size: usize, - ) -> (OfflineMemory, AccessAdapterInventory) { - let memory_bus = MemoryBus::new(0); - let range_checker = - SharedVariableRangeCheckerChip::new(VariableRangeCheckerBus::new(1, 29)); - let mem_config = MemoryConfig { - as_offset: initial_memory.as_offset, - ..Default::default() - }; - let memory = OfflineMemory::new( - initial_memory, - initial_block_size, - memory_bus, - range_checker.clone(), - mem_config, - ); - let access_adapter_inventory = AccessAdapterInventory::new( - range_checker, - memory_bus, - mem_config.clk_max_bits, - mem_config.max_access_adapter_n, - ); - (memory, access_adapter_inventory) - } - - #[test] - fn test_partition() { - let initial_memory = AddressMap::new(0, 1, 16); - let (mut memory, _) = setup_test(initial_memory, 8); - assert_eq!( - memory.block_containing(1, 13), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 8), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 15), - BlockData { - pointer: 8, - size: 8, - timestamp: 0, - } - ); - - assert_eq!( - memory.block_containing(1, 16), - BlockData { - pointer: 16, - size: 8, - timestamp: 0, - } - ); - } - - #[test] - fn test_write_read_initial_block_len_1() { - let (mut memory, mut access_adapters) = setup_test(MemoryImage::default(), 1); - let address_space = 1; - - memory.write(address_space, 0, bbvec![1, 2, 3, 4], &mut access_adapters); - - memory.read(address_space, 0, 2, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2]); - - memory.write(address_space, 2, bbvec![100], &mut access_adapters); - - memory.read(address_space, 0, 4, &mut access_adapters); - let read_record = memory.last_record(); - assert_eq!(read_record.data, bba![1, 2, 100, 4]); - } - - #[test] - fn test_records_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Above write first causes merge of [0:1] and [1:2] into [0:2]. - assert_eq!( - adapter_records.records_for_n(2)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [2:3] and [3:4] into [2:4]. - assert_eq!( - adapter_records.records_for_n(2)[1], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(2), - data: bbvec![0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // then merge [0:2] and [2:4] into [0:4]. - assert_eq!( - adapter_records.records_for_n(4)[0], - AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 0, - right_timestamp: 0, - }, - } - ); - // At time 1 we write [0:4]. - let write_record = memory.last_record(); - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - assert_eq!(adapter_records.total_records(), 3); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 4); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 5); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_records_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - let write_record = memory.last_record(); - - // Above write first causes split of [0:8] into [0:4] and [4:8]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - adapter_records.records_for_n(8).last().unwrap(), - &AccessAdapterRecord { - timestamp: 0, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![0, 0, 0, 0, 0, 0, 0, 0], - kind: AccessAdapterRecordKind::Split, - } - ); - // At time 1 we write [0:4]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 1, - prev_timestamp: 0, - data: bbvec![1, 2, 3, 4], - prev_data: Some(bbvec![0, 0, 0, 0]), - } - ); - assert_eq!(memory.timestamp(), 2); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - // At time 2 we read [0:4]. - assert_eq!(adapter_records.total_records(), 1); - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 2, - prev_timestamp: 1, - data: bbvec![1, 2, 3, 4], - prev_data: None, - } - ); - assert_eq!(memory.timestamp(), 3); - - memory.write(1, 0, bbvec![10, 11], &mut adapter_records); - let write_record = memory.last_record(); - // write causes split [0:4] into [0:2] and [2:4] (to prepare for write to [0:2]). - assert_eq!(adapter_records.total_records(), 2); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 2, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![1, 2, 3, 4], - kind: AccessAdapterRecordKind::Split, - } - ); - - // At time 3 we write [10, 11] into [0, 2]. - assert_eq!( - write_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 3, - prev_timestamp: 2, - data: bbvec![10, 11], - prev_data: Some(bbvec![1, 2]), - } - ); - - memory.read(1, 0, 4, &mut adapter_records); - let read_record = memory.last_record(); - assert_eq!(adapter_records.total_records(), 3); - assert_eq!( - adapter_records.records_for_n(4).last().unwrap(), - &AccessAdapterRecord { - timestamp: 3, - address_space: bb!(1), - start_index: bb!(0), - data: bbvec![10, 11, 3, 4], - kind: AccessAdapterRecordKind::Merge { - left_timestamp: 3, - right_timestamp: 2 - }, - } - ); - // At time 9 we read [0:4]. - assert_eq!( - read_record, - &MemoryRecord { - address_space: bb!(1), - pointer: bb!(0), - timestamp: 4, - prev_timestamp: 3, - data: bbvec![10, 11, 3, 4], - prev_data: None, - } - ); - } - - #[test] - fn test_get_initial_block_len_1() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 1); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_get_initial_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - - memory.write(2, 0, bbvec![4, 3, 2, 1], &mut adapter_records); - - assert_eq!(memory.get(2, 0), BabyBear::from_canonical_u32(4)); - assert_eq!(memory.get(2, 1), BabyBear::from_canonical_u32(3)); - assert_eq!(memory.get(2, 2), BabyBear::from_canonical_u32(2)); - assert_eq!(memory.get(2, 3), BabyBear::from_canonical_u32(1)); - assert_eq!(memory.get(2, 5), BabyBear::ZERO); - assert_eq!(memory.get(2, 9), BabyBear::ZERO); - assert_eq!(memory.get(1, 0), BabyBear::ZERO); - } - - #[test] - fn test_finalize_empty() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 4); - - let memory = memory.finalize::<4>(&mut adapter_records); - assert_eq!(memory.len(), 0); - assert_eq!(adapter_records.total_records(), 0); - } - - #[test] - fn test_finalize_block_len_8() { - let (mut memory, mut adapter_records) = setup_test(MemoryImage::default(), 8); - // Make block 0:4 in address space 1 active. - memory.write(1, 0, bbvec![1, 2, 3, 4], &mut adapter_records); - - // Make block 16:32 in address space 1 active. - memory.write( - 1, - 16, - bbvec![1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], - &mut adapter_records, - ); - - // Make block 64:72 in address space 2 active. - memory.write(2, 64, bbvec![8, 7, 6, 5, 4, 3, 2, 1], &mut adapter_records); - - let num_records_before_finalize = adapter_records.total_records(); - - // Finalize to a partition of size 8. - let final_memory = memory.finalize::<8>(&mut adapter_records); - assert_eq!(final_memory.len(), 4); - assert_eq!( - final_memory.get(&(1, 0)), - Some(&TimestampedValues { - values: bba![1, 2, 3, 4, 0, 0, 0, 0], - timestamp: 1, - }) - ); - // start_index = 16 corresponds to label = 2 - assert_eq!( - final_memory.get(&(1, 2)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 24 corresponds to label = 3 - assert_eq!( - final_memory.get(&(1, 3)), - Some(&TimestampedValues { - values: bba![1, 1, 1, 1, 1, 1, 1, 1], - timestamp: 2, - }) - ); - // start_index = 64 corresponds to label = 8 - assert_eq!( - final_memory.get(&(2, 8)), - Some(&TimestampedValues { - values: bba![8, 7, 6, 5, 4, 3, 2, 1], - timestamp: 3, - }) - ); - - // We need to do 1 + 1 + 0 = 2 adapters. - assert_eq!( - adapter_records.total_records() - num_records_before_finalize, - 2 - ); - } - - #[test] - fn test_write_read_initial_block_len_8_initial_memory() { - type F = BabyBear; - - // Initialize initial memory with blocks at indices 0 and 2 - let mut initial_memory = MemoryImage::default(); - for i in 0..8 { - initial_memory.insert(&(1, i), F::from_canonical_u32(i + 1)); - initial_memory.insert(&(1, 16 + i), F::from_canonical_u32(i + 1)); - } - - let (mut memory, mut adapter_records) = setup_test(initial_memory, 8); - - // Verify initial state of block 0 (pointers 0–8) - memory.read(1, 0, 8, &mut adapter_records); - let initial_read_record_0 = memory.last_record(); - assert_eq!(initial_read_record_0.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Verify initial state of block 2 (pointers 16–24) - memory.read(1, 16, 8, &mut adapter_records); - let initial_read_record_2 = memory.last_record(); - assert_eq!(initial_read_record_2.data, bbvec![1, 2, 3, 4, 5, 6, 7, 8]); - - // Test: Write a partial block to block 0 (pointer 0) and read back partially and fully - memory.write(1, 0, bbvec![9, 9, 9, 9], &mut adapter_records); - memory.read(1, 0, 2, &mut adapter_records); - let partial_read_record = memory.last_record(); - assert_eq!(partial_read_record.data, bbvec![9, 9]); - - memory.read(1, 0, 8, &mut adapter_records); - let full_read_record_0 = memory.last_record(); - assert_eq!(full_read_record_0.data, bbvec![9, 9, 9, 9, 5, 6, 7, 8]); - - // Test: Write a single element to pointer 2 and verify read in different lengths - memory.write(1, 2, bbvec![100], &mut adapter_records); - memory.read(1, 1, 4, &mut adapter_records); - let read_record_4 = memory.last_record(); - assert_eq!(read_record_4.data, bbvec![9, 100, 9, 5]); - - memory.read(1, 2, 8, &mut adapter_records); - let full_read_record_2 = memory.last_record(); - assert_eq!(full_read_record_2.data, bba![100, 9, 5, 6, 7, 8, 0, 0]); - - // Test: Write and read at the last pointer in block 2 (pointer 23, part of key (1, 2)) - memory.write(1, 23, bbvec![77], &mut adapter_records); - memory.read(1, 23, 2, &mut adapter_records); - let boundary_read_record = memory.last_record(); - assert_eq!(boundary_read_record.data, bba![77, 0]); // Last byte modified, ensuring boundary check - - // Test: Reading from an uninitialized block (should default to 0) - memory.read(1, 10, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - memory.read(1, 100, 4, &mut adapter_records); - let default_read_record = memory.last_record(); - assert_eq!(default_read_record.data, bba![0, 0, 0, 0]); - - // Test: Overwrite entire memory pointer 16–24 and verify - memory.write( - 1, - 16, - bbvec![50, 50, 50, 50, 50, 50, 50, 50], - &mut adapter_records, - ); - memory.read(1, 16, 8, &mut adapter_records); - let overwrite_read_record = memory.last_record(); - assert_eq!( - overwrite_read_record.data, - bba![50, 50, 50, 50, 50, 50, 50, 50] - ); // Verify entire block overwrite - } -} diff --git a/crates/vm/src/system/memory/offline_checker/bridge.rs b/crates/vm/src/system/memory/offline_checker/bridge.rs index 2c7e180cfb..3174309454 100644 --- a/crates/vm/src/system/memory/offline_checker/bridge.rs +++ b/crates/vm/src/system/memory/offline_checker/bridge.rs @@ -21,7 +21,7 @@ use crate::system::memory::{ /// be decomposed into) for the `AssertLtSubAir` in the `MemoryOfflineChecker`. /// Warning: This requires that (clk_max_bits + decomp - 1) / decomp = AUX_LEN /// in MemoryOfflineChecker (or whenever AssertLtSubAir is used) -pub(crate) const AUX_LEN: usize = 2; +pub const AUX_LEN: usize = 2; /// The [MemoryBridge] is used within AIR evaluation functions to constrain logical memory /// operations (read/write). It adds all necessary constraints and interactions. diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index 5a27b3e433..be5037d3ec 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -1,6 +1,8 @@ //! Defines auxiliary columns for memory operations: `MemoryReadAuxCols`, //! `MemoryReadWithImmediateAuxCols`, and `MemoryWriteAuxCols`. +use std::ops::DerefMut; + use openvm_circuit_primitives::is_less_than::LessThanAuxCols; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_stark_backend::p3_field::PrimeField32; @@ -16,13 +18,19 @@ pub struct MemoryBaseAuxCols { /// The previous timestamps in which the cells were accessed. pub(in crate::system::memory) prev_timestamp: T, /// The auxiliary columns to perform the less than check. - pub(in crate::system::memory) timestamp_lt_aux: LessThanAuxCols, + pub timestamp_lt_aux: LessThanAuxCols, +} + +impl MemoryBaseAuxCols { + pub fn set_prev(&mut self, prev_timestamp: F) { + self.prev_timestamp = prev_timestamp; + } } #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryWriteAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, + pub base: MemoryBaseAuxCols, pub(in crate::system::memory) prev_data: [T; N], } @@ -40,9 +48,7 @@ impl MemoryWriteAuxCols { prev_data, } } -} -impl MemoryWriteAuxCols { pub fn from_base(base: MemoryBaseAuxCols, prev_data: [T; N]) -> Self { Self { base, prev_data } } @@ -54,6 +60,12 @@ impl MemoryWriteAuxCols { pub fn prev_data(&self) -> &[T; N] { &self.prev_data } + + /// Sets the previous timestamp and data **without** updating the less than auxiliary columns. + pub fn set_prev(&mut self, timestamp: T, data: [T; N]) { + self.base.prev_timestamp = timestamp; + self.prev_data = data; + } } /// The auxiliary columns for a memory read operation with block size `N`. @@ -67,10 +79,7 @@ pub struct MemoryReadAuxCols { } impl MemoryReadAuxCols { - pub(in crate::system::memory) fn new( - prev_timestamp: u32, - timestamp_lt_aux: LessThanAuxCols, - ) -> Self { + pub fn new(prev_timestamp: u32, timestamp_lt_aux: LessThanAuxCols) -> Self { Self { base: MemoryBaseAuxCols { prev_timestamp: F::from_canonical_u32(prev_timestamp), @@ -82,14 +91,19 @@ impl MemoryReadAuxCols { pub fn get_base(self) -> MemoryBaseAuxCols { self.base } + + /// Sets the previous timestamp **without** updating the less than auxiliary columns. + pub fn set_prev(&mut self, timestamp: F) { + self.base.prev_timestamp = timestamp; + } } #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct MemoryReadOrImmediateAuxCols { - pub(crate) base: MemoryBaseAuxCols, - pub(crate) is_immediate: T, - pub(crate) is_zero_aux: T, + pub base: MemoryBaseAuxCols, + pub is_immediate: T, + pub is_zero_aux: T, } impl AsRef> for MemoryWriteAuxCols { @@ -102,3 +116,15 @@ impl AsRef> for MemoryWriteAuxCols unsafe { &*(self as *const MemoryWriteAuxCols as *const MemoryReadAuxCols) } } } + +impl AsMut> for MemoryWriteAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} + +impl AsMut> for MemoryReadAuxCols { + fn as_mut(&mut self) -> &mut MemoryBaseAuxCols { + &mut self.base + } +} diff --git a/crates/vm/src/system/memory/online.rs b/crates/vm/src/system/memory/online.rs index a5bf663e4c..b1e8a3b642 100644 --- a/crates/vm/src/system/memory/online.rs +++ b/crates/vm/src/system/memory/online.rs @@ -1,14 +1,149 @@ use std::fmt::Debug; +use getset::Getters; +use itertools::{izip, zip_eq}; +use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; -use super::paged_vec::{AddressMap, PAGE_SIZE}; -use crate::{ - arch::MemoryConfig, - system::memory::{offline::INITIAL_TIMESTAMP, MemoryImage, RecordId}, +use super::{ + adapter::{AccessAdapterInventory, AdapterInventoryTraceCursor}, + offline_checker::MemoryBus, + paged_vec::{AddressMap, PAGE_SIZE}, + Address, MemoryAddress, PagedVec, }; +use crate::{arch::MemoryConfig, system::memory::MemoryImage}; +pub const INITIAL_TIMESTAMP: u32 = 0; + +#[derive(Debug, Clone, derive_new::new)] +pub struct GuestMemory { + pub memory: AddressMap, +} + +impl GuestMemory { + /// Returns `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + pub unsafe fn read( + &self, + addr_space: u32, + ptr: u32, + ) -> [T; BLOCK_SIZE] + where + T: Copy + Debug, + { + debug_assert_eq!( + size_of::(), + self.memory.cell_size[(addr_space - self.memory.as_offset) as usize] + ); + let read = self + .memory + .paged_vecs + .get_unchecked((addr_space - self.memory.as_offset) as usize) + .get((ptr as usize) * size_of::()); + read + } + + /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` + /// + /// # Safety + /// See [`GuestMemory::read`]. + pub unsafe fn write( + &mut self, + addr_space: u32, + ptr: u32, + values: &[T; BLOCK_SIZE], + ) where + T: Copy + Debug, + { + debug_assert_eq!( + size_of::(), + self.memory.cell_size[(addr_space - self.memory.as_offset) as usize], + "addr_space={addr_space}" + ); + self.memory + .paged_vecs + .get_unchecked_mut((addr_space - self.memory.as_offset) as usize) + .set((ptr as usize) * size_of::(), values); + } + + /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` and returns + /// the previous values. + /// + /// # Safety + /// See [`GuestMemory::read`]. + #[inline(always)] + pub unsafe fn replace( + &mut self, + address_space: u32, + pointer: u32, + values: &[T; BLOCK_SIZE], + ) -> [T; BLOCK_SIZE] + where + T: Copy + Debug, + { + let prev = self.read(address_space, pointer); + self.write(address_space, pointer, values); + prev + } +} + +// /// API for guest memory conforming to OpenVM ISA +// pub trait GuestMemory { +// /// Returns `[pointer:BLOCK_SIZE]_{address_space}` +// /// +// /// # Safety +// /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, +// /// and it must be the exact type used to represent a single memory cell in +// /// address space `address_space`. For standard usage, +// /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. +// unsafe fn read( +// &self, +// address_space: u32, +// pointer: u32, +// ) -> [T; BLOCK_SIZE] +// where +// T: Copy + Debug; + +// /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` +// /// +// /// # Safety +// /// See [`GuestMemory::read`]. +// unsafe fn write( +// &mut self, +// address_space: u32, +// pointer: u32, +// values: &[T; BLOCK_SIZE], +// ) where +// T: Copy + Debug; + +// /// Writes `values` to `[pointer:BLOCK_SIZE]_{address_space}` and returns +// /// the previous values. +// /// +// /// # Safety +// /// See [`GuestMemory::read`]. +// #[inline(always)] +// unsafe fn replace( +// &mut self, +// address_space: u32, +// pointer: u32, +// values: &[T; BLOCK_SIZE], +// ) -> [T; BLOCK_SIZE] +// where +// T: Copy + Debug, +// { +// let prev = self.read(address_space, pointer); +// self.write(address_space, pointer, values); +// prev +// } +// } + +// TO BE DELETED #[derive(Debug, Clone, Serialize, Deserialize)] pub enum MemoryLogEntry { Read { @@ -24,128 +159,472 @@ pub enum MemoryLogEntry { IncrementTimestampBy(u32), } -/// A simple data structure to read to/write from memory. -/// -/// Stores a log of memory accesses to reconstruct aspects of memory state for trace generation. -#[derive(Debug)] -pub struct Memory { - pub(super) data: AddressMap, - pub(super) log: Vec>, - timestamp: u32, +// perf[jpw]: since we restrict `timestamp < 2^29`, we could pack `timestamp, log2(block_size)` +// into a single u32 to save half the memory, since `block_size` is a power of 2 and its log2 +// is less than 2^3. +#[repr(C)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, derive_new::new)] +pub struct AccessMetadata { + pub timestamp: u32, + pub block_size: u32, } -impl Memory { - pub fn new(mem_config: &MemoryConfig) -> Self { +impl AccessMetadata { + /// A marker indicating that the element is a part of a larger block which starts earlier. + pub const OCCUPIED: u32 = u32::MAX; +} + +/// Online memory that stores additional information for trace generation purposes. +/// In particular, keeps track of timestamp. +#[derive(Getters)] +pub struct TracingMemory { + pub timestamp: u32, + /// The initial block size -- this depends on the type of boundary chip. + initial_block_size: usize, + /// The underlying data memory, with memory cells typed by address space: see [AddressMap]. + // TODO: make generic in GuestMemory + #[getset(get = "pub")] + pub data: GuestMemory, + /// A map of `addr_space -> (ptr / min_block_size[addr_space] -> (timestamp: u32, block_size: + /// u32))` for the timestamp and block size of the latest access. + pub(super) meta: Vec>, + /// For each `addr_space`, the minimum block size allowed for memory accesses. In other words, + /// all memory accesses in `addr_space` must be aligned to this block size. + pub min_block_size: Vec, + pub access_adapter_inventory: AccessAdapterInventory, + pub adapter_inventory_trace_cursor: AdapterInventoryTraceCursor, +} + +impl TracingMemory { + // TODO: per-address space memory capacity specification + pub fn new( + mem_config: &MemoryConfig, + range_checker: SharedVariableRangeCheckerChip, + memory_bus: MemoryBus, + initial_block_size: usize, + ) -> Self { + assert_eq!(mem_config.as_offset, 1); + let num_cells = 1usize << mem_config.pointer_max_bits; // max cells per address space + let num_addr_sp = 1 + (1 << mem_config.as_height); + let mut min_block_size = vec![1; num_addr_sp]; + // TMP: hardcoding for now + min_block_size[1] = 4; + min_block_size[2] = 4; + min_block_size[3] = 4; + let meta = min_block_size + .iter() + .map(|&min_block_size| { + PagedVec::new( + num_cells + .checked_mul(size_of::()) + .unwrap() + .div_ceil(PAGE_SIZE * min_block_size as usize), + ) + }) + .collect(); Self { - data: AddressMap::from_mem_config(mem_config), + data: GuestMemory::new(AddressMap::from_mem_config(mem_config)), + meta, + min_block_size, timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(mem_config.access_capacity), + initial_block_size, + access_adapter_inventory: AccessAdapterInventory::new( + range_checker, + memory_bus, + mem_config.clk_max_bits, + mem_config.max_access_adapter_n, + ), + adapter_inventory_trace_cursor: AdapterInventoryTraceCursor::new(num_addr_sp), } } /// Instantiates a new `Memory` data structure from an image. - pub fn from_image(image: MemoryImage, access_capacity: usize) -> Self { - Self { - data: image, - timestamp: INITIAL_TIMESTAMP + 1, - log: Vec::with_capacity(access_capacity), + pub fn with_image(mut self, image: MemoryImage, _access_capacity: usize) -> Self { + for (i, (paged_vec, cell_size)) in izip!(&image.paged_vecs, &image.cell_size).enumerate() { + let num_cells = paged_vec.bytes_capacity() / cell_size; + + self.meta[i] = PagedVec::new( + num_cells + .checked_mul(size_of::()) + .unwrap() + .div_ceil(PAGE_SIZE * self.min_block_size[i] as usize), + ); } + self.data = GuestMemory::new(image); + self } - fn last_record_id(&self) -> RecordId { - RecordId(self.log.len() - 1) + #[inline(always)] + fn assert_alignment(&self, block_size: usize, align: usize, addr_space: u32, ptr: u32) { + debug_assert!(block_size.is_power_of_two()); + debug_assert_eq!(block_size % align, 0); + debug_assert_ne!(addr_space, 0); + debug_assert_eq!(align as u32, self.min_block_size[addr_space as usize]); + assert_eq!( + ptr % (align as u32), + 0, + "pointer={ptr} not aligned to {align}" + ); + } + + pub(crate) fn execute_splits( + &mut self, + address: MemoryAddress, + align: usize, + values: &[F], + timestamp: u32, + ) { + if UPDATE_META { + for i in 0..(values.len() / align) { + self.set_meta_block( + address.address_space as usize, + address.pointer as usize + i * align, + align, + align, + timestamp, + ); + } + } + let mut size = align; + let MemoryAddress { + address_space, + pointer, + } = address; + while size < values.len() { + size *= 2; + for i in (0..values.len()).step_by(size) { + self.access_adapter_inventory.execute_split( + MemoryAddress { + address_space, + pointer: pointer + i as u32, + }, + &values[i..i + size], + timestamp, + self.adapter_inventory_trace_cursor.get_row_slice(size), + ); + } + } + } + + pub(crate) fn execute_merges( + &mut self, + address: MemoryAddress, + align: usize, + values: &[F], + timestamps: &[u32], + ) { + if UPDATE_META { + self.set_meta_block( + address.address_space as usize, + address.pointer as usize, + align, + values.len(), + *timestamps.iter().max().unwrap(), + ); + } + let mut size = align; + let MemoryAddress { + address_space, + pointer, + } = address; + while size < values.len() { + size *= 2; + for i in (0..values.len()).step_by(size) { + let left_timestamp = timestamps[(i / align)..((i + size / 2) / align)] + .iter() + .max() + .unwrap(); + let right_timestamp = timestamps[((i + size / 2) / align)..((i + size) / align)] + .iter() + .max() + .unwrap(); + self.access_adapter_inventory.execute_merge( + MemoryAddress { + address_space, + pointer: pointer + i as u32, + }, + &values[i..i + size], + *left_timestamp, + *right_timestamp, + self.adapter_inventory_trace_cursor.get_row_slice(size), + ); + } + } } - /// Writes an array of values to the memory at the specified address space and start index. + /// Updates the metadata with the given block. + fn set_meta_block( + &mut self, + address_space: usize, + pointer: usize, + align: usize, + block_size: usize, + timestamp: u32, + ) { + let ptr = pointer / align; + let meta = unsafe { self.meta.get_unchecked_mut(address_space) }; + meta.set( + ptr * size_of::(), + &AccessMetadata { + timestamp, + block_size: block_size as u32, + }, + ); + for i in 1..(block_size / align) { + meta.set( + (ptr + i) * size_of::(), + &AccessMetadata { + timestamp, + block_size: AccessMetadata::OCCUPIED, + }, + ); + } + } + + /// Returns the timestamp of the previous access to `[pointer:BLOCK_SIZE]_{address_space}`. + /// If we need to split/merge/initialize something for this, we first do all the necessary + /// actions. In the end of this process, we have this segment intact in our `meta`. /// - /// Returns the `RecordId` for the memory record and the previous data. - pub fn write( + /// Caller must ensure alignment (e.g. via `assert_alignment`) prior to calling this function. + fn prev_access_time( + &mut self, + address_space: usize, + pointer: usize, + align: usize, + ) -> u32 { + let num_segs = BLOCK_SIZE / align; + + let begin = pointer / align; + let end = begin + BLOCK_SIZE / align; + + let mut prev_ts = INITIAL_TIMESTAMP; + let mut block_timestamps = vec![INITIAL_TIMESTAMP; num_segs]; + let mut cur_ptr = begin; + let need_to_merge = loop { + if cur_ptr >= end { + break true; + } + let mut current_metadata = self.meta[address_space] + .get::(cur_ptr * size_of::()); + if current_metadata.block_size == BLOCK_SIZE as u32 && cur_ptr + num_segs == end { + // We do not have to do anything + prev_ts = current_metadata.timestamp; + break false; + } else if current_metadata.block_size == 0 { + // Initialize + if self.initial_block_size < align { + // Only happens in volatile, so empty initial memory + self.set_meta_block( + address_space, + cur_ptr * align, + align, + align, + INITIAL_TIMESTAMP, + ); + current_metadata = AccessMetadata::new(INITIAL_TIMESTAMP, align as u32); + } else { + cur_ptr -= cur_ptr % (self.initial_block_size / align); + self.set_meta_block( + address_space, + cur_ptr * align, + align, + self.initial_block_size, + INITIAL_TIMESTAMP, + ); + current_metadata = + AccessMetadata::new(INITIAL_TIMESTAMP, self.initial_block_size as u32); + } + } + prev_ts = prev_ts.max(current_metadata.timestamp); + while current_metadata.block_size == AccessMetadata::OCCUPIED { + cur_ptr -= 1; + current_metadata = self.meta[address_space] + .get::(cur_ptr * size_of::()); + } + block_timestamps[cur_ptr.saturating_sub(begin) + ..((cur_ptr + (current_metadata.block_size as usize) / align).min(end) - begin)] + .fill(current_metadata.timestamp); + if current_metadata.block_size > align as u32 { + // Split + let address = MemoryAddress::new(address_space as u32, (cur_ptr * align) as u32); + let values = (0..current_metadata.block_size as usize) + .map(|i| { + self.data + .memory + .get_f(address.address_space, address.pointer + (i as u32)) + }) + .collect::>(); + self.execute_splits::(address, align, &values, current_metadata.timestamp); + } + cur_ptr += current_metadata.block_size as usize / align; + }; + if need_to_merge { + // Merge + let values = (0..BLOCK_SIZE) + .map(|i| { + self.data + .memory + .get_f(address_space as u32, (pointer + i) as u32) + }) + .collect::>(); + self.execute_merges::( + MemoryAddress::new(address_space as u32, pointer as u32), + align, + &values, + &block_timestamps, + ); + } + prev_ts + } + + /// Atomic read operation which increments the timestamp by 1. + /// Returns `(t_prev, [pointer:BLOCK_SIZE]_{address_space})` where `t_prev` is the + /// timestamp of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn read( &mut self, address_space: u32, pointer: u32, - values: [F; N], - ) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let t_prev = + self.prev_access_time::(address_space as usize, pointer as usize, ALIGN); + let t_curr = self.timestamp; + self.timestamp += 1; + let values = self.data.read(address_space, pointer); + self.set_meta_block( + address_space as usize, + pointer as usize, + ALIGN, + BLOCK_SIZE, + t_curr, + ); - let prev_data = self.data.set_range(&(address_space, pointer), &values); + (t_prev, values) + } - self.log.push(MemoryLogEntry::Write { - address_space, - pointer, - data: values.to_vec(), - }); + /// Atomic write operation that writes `values` into `[pointer:BLOCK_SIZE]_{address_space}` and + /// then increments the timestamp by 1. Returns `(t_prev, values_prev)` which equal the + /// timestamp and value `[pointer:BLOCK_SIZE]_{address_space}` of the last memory access. + /// + /// The previous memory access is treated as atomic even if previous accesses were for + /// a smaller block size. This is made possible by internal memory access adapters + /// that split/merge memory blocks. More specifically, the last memory access corresponding + /// to `t_prev` may refer to an atomic access inserted by the memory access adapters. + /// + /// # Assumptions + /// The `BLOCK_SIZE` is a multiple of `ALIGN`, which must equal the minimum block size + /// of `address_space`. + /// + /// # Safety + /// The type `T` must be stack-allocated `repr(C)` or `repr(transparent)`, + /// and it must be the exact type used to represent a single memory cell in + /// address space `address_space`. For standard usage, + /// `T` is either `u8` or `F` where `F` is the base field of the ZK backend. + /// + /// In addition: + /// - `address_space` must be valid. + #[inline(always)] + pub unsafe fn write( + &mut self, + address_space: u32, + pointer: u32, + values: &[T; BLOCK_SIZE], + ) -> (u32, [T; BLOCK_SIZE]) + where + T: Copy + Debug, + { + self.assert_alignment(BLOCK_SIZE, ALIGN, address_space, pointer); + let t_prev = + self.prev_access_time::(address_space as usize, pointer as usize, ALIGN); + let values_prev = self.data.replace(address_space, pointer, values); + let t_curr = self.timestamp; self.timestamp += 1; + self.set_meta_block( + address_space as usize, + pointer as usize, + ALIGN, + BLOCK_SIZE, + t_curr, + ); - (self.last_record_id(), prev_data) + (t_prev, values_prev) } - /// Reads an array of values from the memory at the specified address space and start index. - pub fn read(&mut self, address_space: u32, pointer: u32) -> (RecordId, [F; N]) { - assert!(N.is_power_of_two()); - - self.log.push(MemoryLogEntry::Read { - address_space, - pointer, - len: N, - }); - - let values = if address_space == 0 { - assert_eq!(N, 1, "cannot batch read from address space 0"); - [F::from_canonical_u32(pointer); N] - } else { - self.range_array::(address_space, pointer) - }; + pub fn increment_timestamp(&mut self) { self.timestamp += 1; - (self.last_record_id(), values) } pub fn increment_timestamp_by(&mut self, amount: u32) { self.timestamp += amount; - self.log.push(MemoryLogEntry::IncrementTimestampBy(amount)) } pub fn timestamp(&self) -> u32 { self.timestamp } - #[inline(always)] - pub fn get(&self, address_space: u32, pointer: u32) -> F { - *self.data.get(&(address_space, pointer)).unwrap_or(&F::ZERO) - } - - #[inline(always)] - fn range_array(&self, address_space: u32, pointer: u32) -> [F; N] { - self.data.get_range(&(address_space, pointer)) + /// Returns iterator over `((addr_space, ptr), (timestamp, block_size))` of the address, last + /// accessed timestamp, and block size of all memory blocks that have been accessed since this + /// instance of [TracingMemory] was constructed. This is similar to a soft-dirty mechanism, + /// where the memory data is loaded from an initial image and considered "clean", and then + /// all future accesses are marked as "dirty". + // block_size is initialized to 0, so nonzero block_size happens to also mark "dirty" cells + // **Assuming** for now that only the start of a block has nonzero block_size + pub fn touched_blocks(&self) -> impl Iterator + '_ { + zip_eq(&self.meta, &self.min_block_size) + .enumerate() + .flat_map(move |(addr_space, (page, &align))| { + page.iter::() + .filter_map(move |(idx, metadata)| { + (metadata.block_size != 0 + && metadata.block_size != AccessMetadata::OCCUPIED) + .then_some(((addr_space as u32, idx as u32 * align), metadata)) + }) + }) } } -#[cfg(test)] -mod tests { - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - - use super::Memory; - use crate::arch::MemoryConfig; +// #[cfg(test)] +// mod tests { +// use super::TracingMemory; +// use crate::arch::MemoryConfig; - macro_rules! bba { - [$($x:expr),*] => { - [$(BabyBear::from_canonical_u32($x)),*] - } - } - - #[test] - fn test_write_read() { - let mut memory = Memory::new(&MemoryConfig::default()); - let address_space = 1; +// #[test] +// fn test_write_read() { +// let mut memory = TracingMemory::new(&MemoryConfig::default()); +// let address_space = 1; - memory.write(address_space, 0, bba![1, 2, 3, 4]); +// unsafe { +// memory.write(address_space, 0, &[1u8, 2, 3, 4]); - let (_, data) = memory.read::<2>(address_space, 0); - assert_eq!(data, bba![1, 2]); +// let (_, data) = memory.read::(address_space, 0); +// assert_eq!(data, [1u8, 2]); - memory.write(address_space, 2, bba![100]); +// memory.write(address_space, 2, &[100u8]); - let (_, data) = memory.read::<4>(address_space, 0); - assert_eq!(data, bba![1, 2, 100, 4]); - } -} +// let (_, data) = memory.read::(address_space, 0); +// assert_eq!(data, [1u8, 2, 100, 4]); +// } +// } +// } diff --git a/crates/vm/src/system/memory/paged_vec.rs b/crates/vm/src/system/memory/paged_vec.rs index 8a8b030970..ead78f81ee 100644 --- a/crates/vm/src/system/memory/paged_vec.rs +++ b/crates/vm/src/system/memory/paged_vec.rs @@ -1,27 +1,41 @@ -use std::{mem::MaybeUninit, ops::Range, ptr}; - +use std::{ + alloc::{alloc, Layout}, + fmt::Debug, + marker::PhantomData, + mem::MaybeUninit, + ptr, +}; + +use itertools::{zip_eq, Itertools}; +use openvm_instructions::exe::SparseMemoryImage; +use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use crate::arch::MemoryConfig; /// (address_space, pointer) pub type Address = (u32, u32); +/// 4096 is the default page size on host architectures if huge pages is not enabled pub const PAGE_SIZE: usize = 1 << 12; +// TODO[jpw]: replace this with mmap implementation #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct PagedVec { - pub pages: Vec>>, +pub struct PagedVec { + /// Assume each page in `pages` is either unalloc or PAGE_SIZE bytes long and aligned to + /// PAGE_SIZE + pub pages: Vec>>, } // ------------------------------------------------------------------ // Common Helper Functions // These functions encapsulate the common logic for copying ranges // across pages, both for read-only and read-write (set) cases. -impl PagedVec { +impl PagedVec { // Copies a range of length `len` starting at index `start` // into the memory pointed to by `dst`. If the relevant page is not - // initialized, fills that portion with T::default(). - fn read_range_generic(&self, start: usize, len: usize, dst: *mut T) { + // initialized, fills that portion with `0u8`. + #[inline] + pub fn read_range_generic(&self, start: usize, len: usize, dst: *mut u8) { let start_page = start / PAGE_SIZE; let end_page = (start + len - 1) / PAGE_SIZE; unsafe { @@ -30,22 +44,27 @@ impl PagedVec { if let Some(page) = self.pages[start_page].as_ref() { ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); } else { - std::slice::from_raw_parts_mut(dst, len).fill(T::default()); + std::slice::from_raw_parts_mut(dst, len).fill(0u8); } } else { + debug_assert_eq!( + start_page + 1, + end_page, + "Range spans more than two pages: {:?}", + (start_page, end_page, start, len) + ); let offset = start % PAGE_SIZE; let first_part = PAGE_SIZE - offset; if let Some(page) = self.pages[start_page].as_ref() { ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); } else { - std::slice::from_raw_parts_mut(dst, first_part).fill(T::default()); + std::slice::from_raw_parts_mut(dst, first_part).fill(0u8); } let second_part = len - first_part; if let Some(page) = self.pages[end_page].as_ref() { ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); } else { - std::slice::from_raw_parts_mut(dst.add(first_part), second_part) - .fill(T::default()); + std::slice::from_raw_parts_mut(dst.add(first_part), second_part).fill(0u8); } } } @@ -55,29 +74,28 @@ impl PagedVec { // It copies the current values into the memory pointed to by `dst` // and then writes the new values into the underlying pages, // allocating pages (with defaults) if necessary. - fn set_range_generic(&mut self, start: usize, len: usize, new: *const T, dst: *mut T) { + #[inline] + pub fn set_range_generic(&mut self, start: usize, len: usize, new: *const u8, dst: *mut u8) { let start_page = start / PAGE_SIZE; let end_page = (start + len - 1) / PAGE_SIZE; unsafe { if start_page == end_page { let offset = start % PAGE_SIZE; - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); + let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, len); ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), len); } else { + assert_eq!(start_page + 1, end_page); let offset = start % PAGE_SIZE; let first_part = PAGE_SIZE - offset; { - let page = - self.pages[start_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); + let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); ptr::copy_nonoverlapping(page.as_ptr().add(offset), dst, first_part); ptr::copy_nonoverlapping(new, page.as_mut_ptr().add(offset), first_part); } let second_part = len - first_part; { - let page = - self.pages[end_page].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); + let page = self.pages[end_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); ptr::copy_nonoverlapping(page.as_ptr(), dst.add(first_part), second_part); ptr::copy_nonoverlapping(new.add(first_part), page.as_mut_ptr(), second_part); } @@ -88,118 +106,113 @@ impl PagedVec { // ------------------------------------------------------------------ // Implementation for types requiring Default + Clone -impl PagedVec { +impl PagedVec { pub fn new(num_pages: usize) -> Self { Self { pages: vec![None; num_pages], } } - pub fn get(&self, index: usize) -> Option<&T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_ref() - .map(|page| &page[index % PAGE_SIZE]) - } - - pub fn get_mut(&mut self, index: usize) -> Option<&mut T> { - let page_idx = index / PAGE_SIZE; - self.pages[page_idx] - .as_mut() - .map(|page| &mut page[index % PAGE_SIZE]) - } - - pub fn set(&mut self, index: usize, value: T) -> Option { - let page_idx = index / PAGE_SIZE; - if let Some(page) = self.pages[page_idx].as_mut() { - Some(std::mem::replace(&mut page[index % PAGE_SIZE], value)) - } else { - let page = self.pages[page_idx].get_or_insert_with(|| vec![T::default(); PAGE_SIZE]); - page[index % PAGE_SIZE] = value; - None - } - } - - #[inline(always)] - pub fn range_vec(&self, range: Range) -> Vec { - let len = range.end - range.start; - // Create a vector for uninitialized values. - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We set the length and then initialize every element via read_range_generic. - unsafe { - result.set_len(len); - self.read_range_generic(range.start, len, result.as_mut_ptr() as *mut T); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn set_range(&mut self, range: Range, values: &[T]) -> Vec { - let len = range.end - range.start; - assert_eq!(values.len(), len); - let mut result: Vec> = Vec::with_capacity(len); - // SAFETY: We will write to every element in result via set_range_generic. - unsafe { - result.set_len(len); - self.set_range_generic( - range.start, - len, - values.as_ptr(), - result.as_mut_ptr() as *mut T, - ); - std::mem::transmute::>, Vec>(result) - } - } - - pub fn memory_size(&self) -> usize { - self.pages.len() * PAGE_SIZE + /// Total capacity across available pages, in bytes. + pub fn bytes_capacity(&self) -> usize { + self.pages.len().checked_mul(PAGE_SIZE).unwrap() } pub fn is_empty(&self) -> bool { self.pages.iter().all(|page| page.is_none()) } -} -// ------------------------------------------------------------------ -// Implementation for types requiring Default + Copy -impl PagedVec { + /// # Panics + /// If `from..from + size_of()` is out of bounds. + #[inline(always)] + pub fn get(&self, from: usize) -> BLOCK { + // Create an uninitialized array of MaybeUninit + let mut result: MaybeUninit = MaybeUninit::uninit(); + self.read_range_generic(from, size_of::(), result.as_mut_ptr() as *mut u8); + // SAFETY: + // - All elements have been initialized (zero-initialized if page didn't exist). + // - `result` is aligned to `BLOCK` + unsafe { result.assume_init() } + } + + /// # Panics + /// If `start..start + size_of()` is out of bounds. + // @dev: `values` is passed by reference since the data is copied into memory. Even though the + // compiler probably optimizes it, we use reference to avoid any unnecessary copy of `values` + // onto the stack in the function call. #[inline(always)] - pub fn range_array(&self, from: usize) -> [T; N] { - // Create an uninitialized array of MaybeUninit - let mut result: [MaybeUninit; N] = unsafe { - // SAFETY: An uninitialized `[MaybeUninit; N]` is valid. - MaybeUninit::uninit().assume_init() - }; - self.read_range_generic(from, N, result.as_mut_ptr() as *mut T); - // SAFETY: All elements have been initialized. - unsafe { ptr::read(&result as *const _ as *const [T; N]) } + pub fn set(&mut self, start: usize, values: &BLOCK) { + let len = size_of::(); + let start_page = start / PAGE_SIZE; + let end_page = (start + len - 1) / PAGE_SIZE; + let src = values as *const _ as *const u8; + unsafe { + if start_page == end_page { + let offset = start % PAGE_SIZE; + let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); + ptr::copy_nonoverlapping(src, page.as_mut_ptr().add(offset), len); + } else { + assert_eq!(start_page + 1, end_page); + let offset = start % PAGE_SIZE; + let first_part = PAGE_SIZE - offset; + { + let page = self.pages[start_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); + ptr::copy_nonoverlapping(src, page.as_mut_ptr().add(offset), first_part); + } + let second_part = len - first_part; + { + let page = self.pages[end_page].get_or_insert_with(|| vec![0u8; PAGE_SIZE]); + ptr::copy_nonoverlapping(src.add(first_part), page.as_mut_ptr(), second_part); + } + } + } } + /// memcpy of new `values` into pages, memcpy of old existing values into new returned value. + /// # Panics + /// If `from..from + size_of()` is out of bounds. #[inline(always)] - pub fn set_range_array(&mut self, from: usize, values: &[T; N]) -> [T; N] { + pub fn replace(&mut self, from: usize, values: &BLOCK) -> BLOCK { // Create an uninitialized array for old values. - let mut result: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; - self.set_range_generic(from, N, values.as_ptr(), result.as_mut_ptr() as *mut T); - unsafe { ptr::read(&result as *const _ as *const [T; N]) } + let mut result: MaybeUninit = MaybeUninit::uninit(); + self.set_range_generic( + from, + size_of::(), + values as *const _ as *const u8, + result.as_mut_ptr() as *mut u8, + ); + // SAFETY: + // - All elements have been initialized (zero-initialized if page didn't exist). + // - `result` is aligned to `BLOCK` + unsafe { result.assume_init() } } } -impl PagedVec { - pub fn iter(&self) -> PagedVecIter<'_, T, PAGE_SIZE> { +impl PagedVec { + /// Iterate over [PagedVec] as iterator of elements of type `T`. + /// Iterator is over `(index, element)` where `index` is the byte index divided by + /// `size_of::()`. + /// + /// `T` must be stack allocated + pub fn iter(&self) -> PagedVecIter<'_, T, PAGE_SIZE> { + assert!(size_of::() <= PAGE_SIZE); PagedVecIter { vec: self, current_page: 0, current_index_in_page: 0, + phantom: PhantomData, } } } pub struct PagedVecIter<'a, T, const PAGE_SIZE: usize> { - vec: &'a PagedVec, + vec: &'a PagedVec, current_page: usize, current_index_in_page: usize, + phantom: PhantomData, } -impl Iterator for PagedVecIter<'_, T, PAGE_SIZE> { +impl Iterator for PagedVecIter<'_, T, PAGE_SIZE> { type Item = (usize, T); fn next(&mut self) -> Option { @@ -210,39 +223,56 @@ impl Iterator for PagedVecIter<'_, T, PAGE_SIZ debug_assert_eq!(self.current_index_in_page, 0); self.current_index_in_page = 0; } - if self.current_page >= self.vec.pages.len() { + let global_index = self.current_page * PAGE_SIZE + self.current_index_in_page; + if global_index + size_of::() > self.vec.bytes_capacity() { return None; } - let global_index = self.current_page * PAGE_SIZE + self.current_index_in_page; - let page = self.vec.pages[self.current_page].as_ref()?; - let value = page[self.current_index_in_page].clone(); + // PERF: this can be optimized + let value = self.vec.get(global_index); - self.current_index_in_page += 1; - if self.current_index_in_page == PAGE_SIZE { + self.current_index_in_page += size_of::(); + if self.current_index_in_page >= PAGE_SIZE { self.current_page += 1; - self.current_index_in_page = 0; + self.current_index_in_page -= PAGE_SIZE; } - Some((global_index, value)) + Some((global_index / size_of::(), value)) } } +/// Map from address space to guest memory. +/// The underlying memory is typeless, stored as raw bytes, but usage +/// implicitly assumes that each address space has memory cells of a fixed type (e.g., `u8, F`). +/// We do not use a typemap for performance reasons, and it is up to the user to enforce types. +/// Needless to say, this is a very `unsafe` API. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct AddressMap { - pub paged_vecs: Vec>, +pub struct AddressMap { + pub paged_vecs: Vec>, + /// byte size of cells per address space + pub cell_size: Vec, pub as_offset: u32, } -impl Default for AddressMap { +impl Default for AddressMap { fn default() -> Self { Self::from_mem_config(&MemoryConfig::default()) } } -impl AddressMap { +impl AddressMap { pub fn new(as_offset: u32, as_cnt: usize, mem_size: usize) -> Self { + // TMP: hardcoding for now + let mut cell_size = vec![1, 1, 1]; + cell_size.resize(as_cnt, 4); + let paged_vecs = cell_size + .iter() + .map(|&cell_size| { + PagedVec::new(mem_size.checked_mul(cell_size).unwrap().div_ceil(PAGE_SIZE)) + }) + .collect(); Self { - paged_vecs: vec![PagedVec::new(mem_size.div_ceil(PAGE_SIZE)); as_cnt], + paged_vecs, + cell_size, as_offset, } } @@ -253,51 +283,147 @@ impl AddressMap { 1 << mem_config.pointer_max_bits, ) } - pub fn items(&self) -> impl Iterator + '_ { - self.paged_vecs - .iter() + pub fn items(&self) -> impl Iterator + '_ { + zip_eq(&self.paged_vecs, &self.cell_size) .enumerate() - .flat_map(move |(as_idx, page)| { - page.iter() - .map(move |(ptr_idx, x)| ((as_idx as u32 + self.as_offset, ptr_idx as u32), x)) + .flat_map(move |(as_idx, (page, &cell_size))| { + // TODO: better way to handle address space conversions to F + if cell_size == 1 { + page.iter::() + .map(move |(ptr_idx, x)| { + ( + (as_idx as u32 + self.as_offset, ptr_idx as u32), + F::from_canonical_u8(x), + ) + }) + .collect_vec() + } else { + // TEMP + assert_eq!(cell_size, 4); + page.iter::() + .map(move |(ptr_idx, x)| { + ((as_idx as u32 + self.as_offset, ptr_idx as u32), x) + }) + .collect_vec() + } }) } - pub fn get(&self, address: &Address) -> Option<&T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get(address.1 as usize) + + pub fn get_f(&self, addr_space: u32, ptr: u32) -> F { + debug_assert_ne!(addr_space, 0); + // TODO: fix this + unsafe { + if addr_space <= 3 { + F::from_canonical_u8(self.get::((addr_space, ptr))) + } else { + self.get::((addr_space, ptr)) + } + } } - pub fn get_mut(&mut self, address: &Address) -> Option<&mut T> { - self.paged_vecs[(address.0 - self.as_offset) as usize].get_mut(address.1 as usize) + + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn get(&self, (addr_space, ptr): Address) -> T { + debug_assert_eq!( + size_of::(), + self.cell_size[(addr_space - self.as_offset) as usize] + ); + self.paged_vecs + .get_unchecked((addr_space - self.as_offset) as usize) + .get((ptr as usize) * size_of::()) + } + + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub fn read_range_generic( + &self, + (addr_space, ptr): Address, + len: usize, + ) -> Vec { + let mut block: Vec = Vec::with_capacity(len); + unsafe { + self.paged_vecs + .get_unchecked((addr_space - self.as_offset) as usize) + .read_range_generic( + (ptr as usize) * size_of::(), + len * size_of::(), + block.as_mut_ptr() as *mut u8, + ); + block.set_len(len); + } + block } - pub fn insert(&mut self, address: &Address, data: T) -> Option { - self.paged_vecs[(address.0 - self.as_offset) as usize].set(address.1 as usize, data) + + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub unsafe fn insert(&mut self, (addr_space, ptr): Address, data: T) -> T { + debug_assert_eq!( + size_of::(), + self.cell_size[(addr_space - self.as_offset) as usize] + ); + self.paged_vecs + .get_unchecked_mut((addr_space - self.as_offset) as usize) + .replace((ptr as usize) * size_of::(), &data) } pub fn is_empty(&self) -> bool { self.paged_vecs.iter().all(|page| page.is_empty()) } - pub fn from_iter( + // TODO[jpw]: stabilize the boundary memory image format and how to construct + /// # Safety + /// - `T` **must** be the correct type for a single memory cell for `addr_space` + /// - Assumes `addr_space` is within the configured memory and not out of bounds + pub fn from_sparse( as_offset: u32, as_cnt: usize, mem_size: usize, - iter: impl IntoIterator, + sparse_map: SparseMemoryImage, ) -> Self { let mut vec = Self::new(as_offset, as_cnt, mem_size); - for (address, data) in iter { - vec.insert(&address, data); + for ((addr_space, index), data_byte) in sparse_map.into_iter() { + vec.paged_vecs[(addr_space - as_offset) as usize].set(index as usize, &data_byte); } vec } } -impl AddressMap { - pub fn get_range(&self, address: &Address) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize].range_array(address.1 as usize) - } - pub fn set_range(&mut self, address: &Address, values: &[T; N]) -> [T; N] { - self.paged_vecs[(address.0 - self.as_offset) as usize] - .set_range_array(address.1 as usize, values) - } -} +// impl GuestMemory for AddressMap { +// unsafe fn read(&self, addr_space: u32, ptr: u32) -> [T; +// BLOCK_SIZE] where +// T: Copy + Debug, +// { +// debug_assert_eq!( +// size_of::(), +// self.cell_size[(addr_space - self.as_offset) as usize] +// ); +// let read = self +// .paged_vecs +// .get_unchecked((addr_space - self.as_offset) as usize) +// .get((ptr as usize) * size_of::()); +// read +// } + +// unsafe fn write( +// &mut self, +// addr_space: u32, +// ptr: u32, +// values: &[T; BLOCK_SIZE], +// ) where +// T: Copy + Debug, +// { +// debug_assert_eq!( +// size_of::(), +// self.cell_size[(addr_space - self.as_offset) as usize], +// "addr_space={addr_space}" +// ); +// self.paged_vecs +// .get_unchecked_mut((addr_space - self.as_offset) as usize) +// .set((ptr as usize) * size_of::(), values); +// } +// } #[cfg(test)] mod tests { @@ -305,143 +431,144 @@ mod tests { #[test] fn test_basic_get_set() { - let mut v = PagedVec::<_, 4>::new(3); - assert_eq!(v.get(0), None); - v.set(0, 42); - assert_eq!(v.get(0), Some(&42)); - } - - #[test] - fn test_cross_page_operations() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(3, 10); // Last element of first page - v.set(4, 20); // First element of second page - assert_eq!(v.get(3), Some(&10)); - assert_eq!(v.get(4), Some(&20)); - } - - #[test] - fn test_page_boundaries() { - let mut v = PagedVec::<_, 4>::new(2); - // Fill first page - v.set(0, 1); - v.set(1, 2); - v.set(2, 3); - v.set(3, 4); - // Fill second page - v.set(4, 5); - v.set(5, 6); - v.set(6, 7); - v.set(7, 8); - - // Verify all values - assert_eq!(v.range_vec(0..8), [1, 2, 3, 4, 5, 6, 7, 8]); - } - - #[test] - fn test_range_cross_page_boundary() { - let mut v = PagedVec::<_, 4>::new(2); - v.set_range(2..8, &[10, 11, 12, 13, 14, 15]); - assert_eq!(v.range_vec(2..8), [10, 11, 12, 13, 14, 15]); - } - - #[test] - fn test_large_indices() { - let mut v = PagedVec::<_, 4>::new(100); - let large_index = 399; - v.set(large_index, 42); - assert_eq!(v.get(large_index), Some(&42)); - } - - #[test] - fn test_range_operations_with_defaults() { - let mut v = PagedVec::<_, 4>::new(3); - v.set(2, 5); - v.set(5, 10); - - // Should include both set values and defaults - assert_eq!(v.range_vec(1..7), [0, 5, 0, 0, 10, 0]); - } - - #[test] - fn test_non_zero_default_type() { - let mut v: PagedVec = PagedVec::new(2); - assert_eq!(v.get(0), None); // bool's default - v.set(0, true); - assert_eq!(v.get(0), Some(&true)); - assert_eq!(v.get(1), Some(&false)); // because we created the page - } - - #[test] - fn test_set_range_overlapping_pages() { - let mut v = PagedVec::<_, 4>::new(3); - let test_data = [1, 2, 3, 4, 5, 6]; - v.set_range(2..8, &test_data); - - // Verify first page - assert_eq!(v.get(2), Some(&1)); - assert_eq!(v.get(3), Some(&2)); - - // Verify second page - assert_eq!(v.get(4), Some(&3)); - assert_eq!(v.get(5), Some(&4)); - assert_eq!(v.get(6), Some(&5)); - assert_eq!(v.get(7), Some(&6)); - } - - #[test] - fn test_overlapping_set_ranges() { - let mut v = PagedVec::<_, 4>::new(3); - - // Initial set_range - v.set_range(0..5, &[1, 2, 3, 4, 5]); - assert_eq!(v.range_vec(0..5), [1, 2, 3, 4, 5]); - - // Overlap from beginning - v.set_range(0..3, &[10, 20, 30]); - assert_eq!(v.range_vec(0..5), [10, 20, 30, 4, 5]); - - // Overlap in middle - v.set_range(2..4, &[42, 43]); - assert_eq!(v.range_vec(0..5), [10, 20, 42, 43, 5]); - - // Overlap at end - v.set_range(4..6, &[91, 92]); - assert_eq!(v.range_vec(0..6), [10, 20, 42, 43, 91, 92]); - } - - #[test] - fn test_overlapping_set_ranges_cross_pages() { - let mut v = PagedVec::<_, 4>::new(3); - - // Fill across first two pages - v.set_range(0..8, &[1, 2, 3, 4, 5, 6, 7, 8]); - - // Overlap end of first page and start of second - v.set_range(2..6, &[21, 22, 23, 24]); - assert_eq!(v.range_vec(0..8), [1, 2, 21, 22, 23, 24, 7, 8]); - - // Overlap multiple pages - v.set_range(1..7, &[31, 32, 33, 34, 35, 36]); - assert_eq!(v.range_vec(0..8), [1, 31, 32, 33, 34, 35, 36, 8]); - } - - #[test] - fn test_iterator() { - let mut v = PagedVec::<_, 4>::new(3); - - v.set_range(4..10, &[1, 2, 3, 4, 5, 6]); - let contents: Vec<_> = v.iter().collect(); - assert_eq!(contents.len(), 8); // two pages - - contents - .iter() - .take(6) - .enumerate() - .for_each(|(i, &(idx, val))| { - assert_eq!((idx, val), (4 + i, 1 + i)); - }); - assert_eq!(contents[6], (10, 0)); - assert_eq!(contents[7], (11, 0)); - } + let mut v = PagedVec::<16>::new(3); + assert_eq!(v.get::(0), 0u32); + v.set(0, &42u32); + assert_eq!(v.get::(0), 42u32); + } + + // TEMP: disable tests (need to update indexing * size_of) + // #[test] + // fn test_cross_page_operations() { + // let mut v = PagedVec::<16>::new(3); + // v.set(3, 10u32); // Last element of first page + // v.set(4, 20u32); // First element of second page + // assert_eq!(v.get(3), 10u32); + // assert_eq!(v.get(4), 20u32); + // } + + // #[test] + // fn test_page_boundaries() { + // let mut v = PagedVec::<16>::new(2); + // // Fill first page + // v.set(0, 1u32); + // v.set(1, 2u32); + // v.set(2, 3u32); + // v.set(3, 4u32); + // // Fill second page + // v.set(4, 5u32); + // v.set(5, 6u32); + // v.set(6, 7u32); + // v.set(7, 8u32); + + // // Verify all values + // assert_eq!(v.get::<[u32; 8]>(0), [1, 2, 3, 4, 5, 6, 7, 8]); + // } + + // #[test] + // fn test_range_cross_page_boundary() { + // let mut v = PagedVec::<16>::new(2); + // v.set::<[u32; 6]>(2, [10, 11, 12, 13, 14, 15]); + // assert_eq!(v.get::<[u32; 6]>(2), [10, 11, 12, 13, 14, 15]); + // } + + // #[test] + // fn test_large_indices() { + // let mut v = PagedVec::<16>::new(100); + // let large_index = 399; + // v.set(large_index, 42u32); + // assert_eq!(v.get(large_index), 42u32); + // } + + // #[test] + // fn test_range_operations_with_defaults() { + // let mut v = PagedVec::<16>::new(3); + // v.set(2, 5u32); + // v.set(5, 10u32); + + // // Should include both set values and defaults + // assert_eq!(v.range_vec(1..7), [0, 5, 0, 0, 10, 0]); + // } + + // #[test] + // fn test_non_zero_default_type() { + // let mut v: PagedVec<4> = PagedVec::new(2); + // assert_eq!(v.get(0), false); // bool's default + // v.set(0, true); + // assert_eq!(v.get(0), true); + // assert_eq!(v.get(1), false); // because we created the page + // } + + // #[test] + // fn test_set_range_overlapping_pages() { + // let mut v = PagedVec::<_, 16>::new(3); + // let test_data = [1u32, 2, 3, 4, 5, 6]; + // v.set(2, test_data); + + // // Verify first page + // assert_eq!(v.get(2), 1u32); + // assert_eq!(v.get(3), 2u32); + + // // Verify second page + // assert_eq!(v.get(4), 3u32); + // assert_eq!(v.get(5), 4u32); + // assert_eq!(v.get(6), 5u32); + // assert_eq!(v.get(7), 6u32); + // } + + // #[test] + // fn test_overlapping_set_ranges() { + // let mut v = PagedVec::<_, 16>::new(3); + + // // Initial set_range + // v.set(0, [1u32, 2, 3, 4, 5]); + // assert_eq!(v.get::<[u32; 5]>(0), [1, 2, 3, 4, 5]); + + // // Overlap from beginning + // v.set(0, [10u32, 20, 30]); + // assert_eq!(v.get::<[u32; 5]>(0), [10, 20, 30, 4, 5]); + + // // Overlap in middle + // v.set(2, [42u32, 43]); + // assert_eq!(v.get::<[u32; 5]>(0), [10, 20, 42, 43, 5]); + + // // Overlap at end + // v.set(4, [91u32, 92]); + // assert_eq!(v.get::<[u32; 6]>(0), [10, 20, 42, 43, 91, 92]); + // } + + // #[test] + // fn test_overlapping_set_ranges_cross_pages() { + // let mut v = PagedVec::<16>::new(3); + + // // Fill across first two pages + // v.set::<[u32; 8]>(0, [1, 2, 3, 4, 5, 6, 7, 8]); + + // // Overlap end of first page and start of second + // v.set::<[u32; 4]>(2, [21, 22, 23, 24]); + // assert_eq!(v.get::<[u32; 8]>(0), [1, 2, 21, 22, 23, 24, 7, 8]); + + // // Overlap multiple pages + // v.set::<[u32; 6]>(1, [31, 32, 33, 34, 35, 36]); + // assert_eq!(v.get::<[u32; 8]>(0), [1, 31, 32, 33, 34, 35, 36, 8]); + // } + + // #[test] + // fn test_iterator() { + // let mut v = PagedVec::<16>::new(3); + + // v.set(4..10, &[1, 2, 3, 4, 5, 6]); + // let contents: Vec<_> = v.iter().collect(); + // assert_eq!(contents.len(), 8); // two pages + + // contents + // .iter() + // .take(6) + // .enumerate() + // .for_each(|(i, &(idx, val))| { + // assert_eq!((idx, val), (4 + i, 1 + i)); + // }); + // assert_eq!(contents[6], (10, 0)); + // assert_eq!(contents[7], (11, 0)); + // } } diff --git a/crates/vm/src/system/memory/persistent.rs b/crates/vm/src/system/memory/persistent.rs index 55a178be4d..bdd1fa4d04 100644 --- a/crates/vm/src/system/memory/persistent.rs +++ b/crates/vm/src/system/memory/persistent.rs @@ -19,12 +19,12 @@ use openvm_stark_backend::{ }; use rustc_hash::FxHashSet; -use super::merkle::SerialReceiver; +use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues}; use crate::{ arch::hasher::Hasher, system::memory::{ dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage, - TimestampedEquipartition, INITIAL_TIMESTAMP, + TimestampedEquipartition, }, }; @@ -123,18 +123,18 @@ impl Air for PersistentBoundaryA pub struct PersistentBoundaryChip { pub air: PersistentBoundaryAir, - touched_labels: TouchedLabels, + pub touched_labels: TouchedLabels, overridden_height: Option, } #[derive(Debug)] -enum TouchedLabels { +pub enum TouchedLabels { Running(FxHashSet<(u32, u32)>), Final(Vec>), } #[derive(Debug)] -struct FinalTouchedLabel { +pub struct FinalTouchedLabel { address_space: u32, label: u32, init_values: [F; CHUNK], @@ -159,7 +159,8 @@ impl TouchedLabels { _ => panic!("Cannot touch after finalization"), } } - fn len(&self) -> usize { + + pub fn len(&self) -> usize { match self { TouchedLabels::Running(touched_labels) => touched_labels.len(), TouchedLabels::Final(touched_labels) => touched_labels.len(), @@ -200,45 +201,37 @@ impl PersistentBoundaryChip { pub fn finalize( &mut self, - initial_memory: &MemoryImage, + initial_memory: &MemoryImage, + // Only touched stuff final_memory: &TimestampedEquipartition, hasher: &mut H, ) where H: Hasher + Sync + for<'a> SerialReceiver<&'a [F]>, { - match &mut self.touched_labels { - TouchedLabels::Running(touched_labels) => { - let final_touched_labels: Vec<_> = touched_labels - .par_iter() - .map(|&(address_space, label)| { - let pointer = label * CHUNK as u32; - let init_values = array::from_fn(|i| { - *initial_memory - .get(&(address_space, pointer + i as u32)) - .unwrap_or(&F::ZERO) - }); - let initial_hash = hasher.hash(&init_values); - let timestamped_values = final_memory.get(&(address_space, label)).unwrap(); - let final_hash = hasher.hash(×tamped_values.values); - FinalTouchedLabel { - address_space, - label, - init_values, - final_values: timestamped_values.values, - init_hash: initial_hash, - final_hash, - final_timestamp: timestamped_values.timestamp, - } - }) - .collect(); - for l in &final_touched_labels { - hasher.receive(&l.init_values); - hasher.receive(&l.final_values); + let final_touched_labels: Vec<_> = final_memory + .par_iter() + .map(|(&(addr_space, ptr), &ts_values)| { + let TimestampedValues { timestamp, values } = ts_values; + let init_values = + array::from_fn(|i| initial_memory.get_f::(addr_space, ptr + i as u32)); + let initial_hash = hasher.hash(&init_values); + let final_hash = hasher.hash(&values); + FinalTouchedLabel { + address_space: addr_space, + label: ptr / CHUNK as u32, + init_values, + final_values: values, + init_hash: initial_hash, + final_hash, + final_timestamp: timestamp, } - self.touched_labels = TouchedLabels::Final(final_touched_labels); - } - _ => panic!("Cannot finalize after finalization"), + }) + .collect(); + for l in &final_touched_labels { + hasher.receive(&l.init_values); + hasher.receive(&l.final_values); } + self.touched_labels = TouchedLabels::Final(final_touched_labels); } } diff --git a/crates/vm/src/system/memory/tests.rs b/crates/vm/src/system/memory/tests.rs index 9ebb9306aa..b950e7862e 100644 --- a/crates/vm/src/system/memory/tests.rs +++ b/crates/vm/src/system/memory/tests.rs @@ -292,34 +292,35 @@ fn make_random_accesses( ) -> Vec { (0..1024) .map(|_| { - let address_space = F::from_canonical_u32(*[1, 2].choose(&mut rng).unwrap()); + let address_space = F::from_canonical_u32(*[4, 5].choose(&mut rng).unwrap()); match rng.gen_range(0..5) { 0 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); let data = F::from_canonical_u32(rng.gen_range(0..1 << 30)); - let (record_id, _) = memory_controller.write(address_space, pointer, [data]); + let (record_id, _) = memory_controller.write(address_space, pointer, &[data]); record_id } 1 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 1)); - let (record_id, _) = memory_controller.read::<1>(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } 2 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); - let (record_id, _) = memory_controller.read::<4>(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } 3 => { let pointer = F::from_canonical_usize(gen_pointer(rng, 4)); let data = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..1 << 30))); - let (record_id, _) = memory_controller.write::<4>(address_space, pointer, data); + let (record_id, _) = + memory_controller.write::(address_space, pointer, &data); record_id } 4 => { let pointer = F::from_canonical_usize(gen_pointer(rng, MAX)); - let (record_id, _) = memory_controller.read::(address_space, pointer); + let (record_id, _) = memory_controller.read::(address_space, pointer); record_id } _ => unreachable!(), diff --git a/crates/vm/src/system/memory/tree/mod.rs b/crates/vm/src/system/memory/tree/mod.rs index fcdb86d8ee..cb32429d5a 100644 --- a/crates/vm/src/system/memory/tree/mod.rs +++ b/crates/vm/src/system/memory/tree/mod.rs @@ -142,7 +142,7 @@ impl MemoryNode { pub fn tree_from_memory( memory_dimensions: MemoryDimensions, - memory: &MemoryImage, + memory: &MemoryImage, hasher: &(impl Hasher + Sync), ) -> MemoryNode { // Construct a Vec that includes the address space in the label calculation, diff --git a/crates/vm/src/system/memory/tree/public_values.rs b/crates/vm/src/system/memory/tree/public_values.rs index 1c6866b959..6df24ec4e0 100644 --- a/crates/vm/src/system/memory/tree/public_values.rs +++ b/crates/vm/src/system/memory/tree/public_values.rs @@ -11,6 +11,7 @@ use crate::{ }, }; +pub const PUBLIC_VALUES_AS: u32 = 3; pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = 2; /// Merkle proof for user public values in the memory state. @@ -51,7 +52,7 @@ impl UserPublicValuesProof { memory_dimensions: MemoryDimensions, num_public_values: usize, hasher: &(impl Hasher + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Self { let proof = compute_merkle_proof_to_user_public_values_root( memory_dimensions, @@ -121,7 +122,7 @@ fn compute_merkle_proof_to_user_public_values_root + Sync), - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec<[F; CHUNK]> { assert_eq!( num_public_values % CHUNK, @@ -169,7 +170,7 @@ fn compute_merkle_proof_to_user_public_values_root( memory_dimensions: &MemoryDimensions, num_public_values: usize, - final_memory: &MemoryImage, + final_memory: &MemoryImage, ) -> Vec { // All (addr, value) pairs in the public value address space. let f_as_start = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; @@ -198,6 +199,8 @@ pub fn extract_public_values( public_values } +// TODO: add back +/* #[cfg(test)] mod tests { use openvm_stark_backend::p3_field::FieldAlgebra; @@ -206,7 +209,7 @@ mod tests { use super::{UserPublicValuesProof, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET}; use crate::{ arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig}, - system::memory::{paged_vec::AddressMap, tree::MemoryNode, CHUNK}, + system::memory::{online::GuestMemory, paged_vec::AddressMap, tree::MemoryNode, CHUNK}, }; type F = BabyBear; @@ -218,12 +221,14 @@ mod tests { let memory_dimensions = vm_config.memory_config.memory_dimensions(); let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset; let num_public_values = 16; - let memory = AddressMap::from_iter( + let mut memory = AddressMap::new( memory_dimensions.as_offset, 1 << memory_dimensions.as_height, 1 << memory_dimensions.address_height, - [((pv_as, 15), F::ONE)], ); + unsafe { + memory.write::(pv_as, 15, &[F::ONE]); + } let mut expected_pvs = F::zero_vec(num_public_values); expected_pvs[15] = F::ONE; @@ -241,3 +246,4 @@ mod tests { .unwrap(); } } +*/ diff --git a/crates/vm/src/system/memory/volatile/mod.rs b/crates/vm/src/system/memory/volatile/mod.rs index e01162c789..e345acddda 100644 --- a/crates/vm/src/system/memory/volatile/mod.rs +++ b/crates/vm/src/system/memory/volatile/mod.rs @@ -183,7 +183,7 @@ pub struct VolatileBoundaryChip { pub air: VolatileBoundaryAir, range_checker: SharedVariableRangeCheckerChip, overridden_height: Option, - final_memory: Option>, + pub final_memory: Option>, addr_space_max_bits: usize, pointer_max_bits: usize, } diff --git a/crates/vm/src/system/mod.rs b/crates/vm/src/system/mod.rs index a1038ac86a..7164846d5b 100644 --- a/crates/vm/src/system/mod.rs +++ b/crates/vm/src/system/mod.rs @@ -1,5 +1,6 @@ pub mod connector; pub mod memory; +// Necessary for the PublicValuesChip pub mod native_adapter; /// Chip to handle phantom instructions. /// The Air will always constrain a NOP which advances pc by DEFAULT_PC_STEP. diff --git a/crates/vm/src/system/native_adapter/mod.rs b/crates/vm/src/system/native_adapter/mod.rs index 95c2c7c4a4..25729e9376 100644 --- a/crates/vm/src/system/native_adapter/mod.rs +++ b/crates/vm/src/system/native_adapter/mod.rs @@ -1,3 +1,5 @@ +mod util; + use std::{ borrow::{Borrow, BorrowMut}, marker::PhantomData, @@ -5,16 +7,12 @@ use std::{ use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + AdapterAirContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, + MinimalInstruction, VmAdapterAir, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + MemoryAddress, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -24,67 +22,13 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -use crate::system::memory::{OfflineMemory, RecordId}; - -/// R reads(R<=2), W writes(W<=1). -/// Operands: b for the first read, c for the second read, a for the first write. -/// If an operand is not used, its address space and pointer should be all 0. -#[derive(Debug)] -pub struct NativeAdapterChip { - pub air: NativeAdapterAir, - _phantom: PhantomData, -} - -impl NativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _phantom: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeReadRecord { - #[serde(with = "BigArray")] - pub reads: [(RecordId, [F; 1]); R], -} - -impl NativeReadRecord { - pub fn b(&self) -> &[F; 1] { - &self.reads[0].1 - } +use util::{tracing_read_or_imm_native, tracing_write_native, AS_NATIVE}; - pub fn c(&self) -> &[F; 1] { - &self.reads[1].1 - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeWriteRecord { - pub from_state: ExecutionState, - #[serde(with = "BigArray")] - pub writes: [(RecordId, [F; 1]); W], -} - -impl NativeWriteRecord { - pub fn a(&self) -> &[F; 1] { - &self.writes[0].1 - } -} +use super::memory::{online::TracingMemory, MemoryAuxColsFactory}; +use crate::{ + arch::{execution_mode::E1E2ExecutionCtx, AdapterExecutorE1, AdapterTraceStep, VmStateMut}, + system::memory::online::GuestMemory, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -205,101 +149,184 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeAdapterChip +/// R reads(R<=2), W writes(W<=1). +/// Operands: b for the first read, c for the second read, a for the first write. +/// If an operand is not used, its address space and pointer should be all 0. +#[derive(Debug, derive_new::new)] +pub struct NativeAdapterStep { + _phantom: PhantomData, +} + +impl AdapterTraceStep for NativeAdapterStep +where + F: PrimeField32, { - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = NativeAdapterAir; - type Interface = BasicAdapterInterface, R, W, 1, 1>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = size_of::>(); + type ReadData = [[F; 1]; R]; + type WriteData = [[F; 1]; W]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut NativeAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + adapter_row: &mut [F], + ) -> Self::ReadData { assert!(R <= 2); - let Instruction { b, c, e, f, .. } = *instruction; - let mut reads = Vec::with_capacity(R); + let &Instruction { b, e, f, c, .. } = instruction; + + let cols: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut(); + + let mut reads = [[F::ZERO; 1]; R]; if R >= 1 { - reads.push(memory.read::<1>(e, b)); + cols.reads_aux[0].address.pointer = b; + reads[0][0] = tracing_read_or_imm_native( + memory, + e.as_canonical_u32(), + b, + &mut cols.reads_aux[0].address.address_space, + &mut cols.reads_aux[0].read_aux, + ); } if R >= 2 { - reads.push(memory.read::<1>(f, c)); + cols.reads_aux[1].address.pointer = c; + reads[1][0] = tracing_read_or_imm_native( + memory, + f.as_canonical_u32(), + c, + &mut cols.reads_aux[1].address.address_space, + &mut cols.reads_aux[1].read_aux, + ); } - let i_reads: [_; R] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + reads } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { + adapter_row: &mut [F], + data: &Self::WriteData, + ) { assert!(W <= 1); - let Instruction { a, d, .. } = *instruction; - let mut writes = Vec::with_capacity(W); + + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS_NATIVE); + + let cols: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut(); + if W >= 1 { - let (record_id, _) = memory.write(d, a, output.writes[0]); - writes.push((record_id, output.writes[0])); + cols.writes_aux[0].address.address_space = F::from_canonical_u32(AS_NATIVE); + cols.writes_aux[0].address.pointer = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + &data[0], + &mut cols.writes_aux[0].write_aux, + ); } - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let row_slice: &mut NativeAdapterCols<_, R, W> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - for (i, read) in read_record.reads.iter().enumerate() { - let (id, _) = read; - let record = memory.record_by_id(*id); - aux_cols_factory - .generate_read_or_immediate_aux(record, &mut row_slice.reads_aux[i].read_aux); - row_slice.reads_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); + let adapter_row: &mut NativeAdapterCols<_, R, W> = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + for read_aux in &mut adapter_row.reads_aux { + mem_helper.fill_from_prev(timestamp, &mut read_aux.read_aux.base); + timestamp += 1; + + if read_aux.address.address_space.is_zero() { + read_aux.read_aux.is_immediate = F::ONE; + read_aux.read_aux.is_zero_aux = F::ZERO; + } else { + read_aux.read_aux.is_immediate = F::ZERO; + read_aux.read_aux.is_zero_aux = read_aux.address.address_space.inverse(); + } + } + + for write_aux in &mut adapter_row.writes_aux { + mem_helper.fill_from_prev(timestamp, write_aux.write_aux.as_mut()); + timestamp += 1; } + } +} + +impl AdapterExecutorE1 for NativeAdapterStep +where + F: PrimeField32, +{ + type ReadData = [F; R]; + type WriteData = [F; W]; - for (i, write) in write_record.writes.iter().enumerate() { - let (id, _) = write; - let record = memory.record_by_id(*id); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i].write_aux); - row_slice.writes_aux[i].address = - MemoryAddress::new(record.address_space, record.pointer); + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + assert!(R <= 2); + + let &Instruction { b, c, e, f, .. } = instruction; + + let mut reads = [F::ZERO; R]; + if R >= 1 { + let [value] = unsafe { + state + .memory + .read::(e.as_canonical_u32(), b.as_canonical_u32()) + }; + reads[0] = value; } + if R >= 2 { + let [value] = unsafe { + state + .memory + .read::(f.as_canonical_u32(), c.as_canonical_u32()) + }; + reads[1] = value; + } + reads } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + assert!(W <= 1); + + let &Instruction { a, d, .. } = instruction; + if W >= 1 { + unsafe { + state + .memory + .write(d.as_canonical_u32(), a.as_canonical_u32(), data) + }; + } } } diff --git a/crates/vm/src/system/native_adapter/util.rs b/crates/vm/src/system/native_adapter/util.rs new file mode 100644 index 0000000000..5db1a5a050 --- /dev/null +++ b/crates/vm/src/system/native_adapter/util.rs @@ -0,0 +1,98 @@ +// TODO(ayush): this whole file is copied from extensions/native and shouldn't be here +use openvm_circuit::system::memory::{ + offline_checker::{MemoryBaseAuxCols, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + online::TracingMemory, +}; +use openvm_stark_backend::p3_field::PrimeField32; + +// TODO(ayush): should be imported from somewhere +const AS_IMMEDIATE: u32 = 0; +pub(super) const AS_NATIVE: u32 = 4; + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:BLOCK_SIZE]_4)` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +fn timed_read( + memory: &mut TracingMemory, + ptr: u32, +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(AS_NATIVE, ptr) } +} + +#[inline(always)] +fn timed_write( + memory: &mut TracingMemory, + ptr: u32, + vals: &[F; BLOCK_SIZE], +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(AS_NATIVE, ptr, vals) } +} + +/// Reads register value at `ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read_native( + memory: &mut TracingMemory, + ptr: u32, + aux_cols: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read(memory, ptr); + aux_cols.set_prev(F::from_canonical_u32(t_prev)); + data +} + +/// Writes `ptr, vals` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: &[F; BLOCK_SIZE], + aux_cols: &mut MemoryWriteAuxCols, +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write(memory, ptr, vals); + aux_cols.set_prev(F::from_canonical_u32(t_prev), data_prev); +} + +/// Reads value at `_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read_or_imm_native( + memory: &mut TracingMemory, + addr_space: u32, + ptr_or_imm: F, + addr_space_mut: &mut F, + aux_cols: &mut MemoryReadOrImmediateAuxCols, +) -> F +where + F: PrimeField32, +{ + debug_assert!(addr_space == AS_IMMEDIATE || addr_space == AS_NATIVE); + + if addr_space == AS_IMMEDIATE { + *addr_space_mut = F::ZERO; + memory.increment_timestamp(); + ptr_or_imm + } else { + *addr_space_mut = F::from_canonical_u32(AS_NATIVE); + let data: [F; 1] = + tracing_read_native(memory, ptr_or_imm.as_canonical_u32(), &mut aux_cols.base); + data[0] + } +} diff --git a/crates/vm/src/system/phantom/mod.rs b/crates/vm/src/system/phantom/mod.rs index 28977fe2cd..cdf6f0dfa0 100644 --- a/crates/vm/src/system/phantom/mod.rs +++ b/crates/vm/src/system/phantom/mod.rs @@ -23,11 +23,12 @@ use rustc_hash::FxHashMap; use serde::{Deserialize, Serialize}; use serde_big_array::BigArray; -use super::memory::MemoryController; +use super::memory::{online::GuestMemory, MemoryController}; use crate::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, - PcIncOrSet, PhantomSubExecutor, Streams, + execution_mode::{e1::E1Ctx, metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InsExecutorE1, + InstructionExecutor, PcIncOrSet, PhantomSubExecutor, Streams, VmStateMut, }, system::program::ProgramBus, }; @@ -124,13 +125,19 @@ impl PhantomChip { } } -impl InstructionExecutor for PhantomChip { - fn execute( +impl InsExecutorE1 for PhantomChip +where + F: PrimeField32, +{ + fn execute_e1( &mut self, - memory: &mut MemoryController, + state: &mut VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + Ctx: E1E2ExecutionCtx, + { let &Instruction { opcode, a, b, c, .. } = instruction; @@ -145,38 +152,72 @@ impl InstructionExecutor for PhantomChip { .phantom_executors .get_mut(&discriminant) .ok_or_else(|| ExecutionError::PhantomNotFound { - pc: from_state.pc, + pc: *state.pc, discriminant, })?; let mut streams = self.streams.get().unwrap().lock().unwrap(); + // TODO(ayush): implement phantom subexecutor for new traits sub_executor .as_mut() .phantom_execute( - memory, + state.memory, &mut streams, discriminant, - a, - b, + a.as_canonical_u32(), + b.as_canonical_u32(), (c_u32 >> 16) as u16, ) .map_err(|e| ExecutionError::Phantom { - pc: from_state.pc, + pc: *state.pc, discriminant, inner: e, })?; } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + _chip_index: usize, + ) -> Result<(), ExecutionError> { + self.execute_e1(state, instruction)?; + + Ok(()) + } +} + +impl InstructionExecutor for PhantomChip { + fn execute( + &mut self, + memory: &mut MemoryController, + instruction: &Instruction, + from_state: ExecutionState, + ) -> Result, ExecutionError> { + let mut pc = from_state.pc; self.rows.push(PhantomCols { - pc: F::from_canonical_u32(from_state.pc), - operands: [a, b, c], - timestamp: F::from_canonical_u32(from_state.timestamp), + pc: F::from_canonical_u32(pc), + operands: [instruction.a, instruction.b, instruction.c], + timestamp: F::from_canonical_u32(memory.memory.timestamp), is_valid: F::ONE, }); + + let mut state = VmStateMut { + pc: &mut pc, + memory: &mut memory.memory.data, + ctx: &mut E1Ctx::default(), + }; + self.execute_e1(&mut state, instruction)?; memory.increment_timestamp(); - Ok(ExecutionState::new( - from_state.pc + DEFAULT_PC_STEP, - from_state.timestamp + 1, - )) + + Ok(ExecutionState { + pc, + timestamp: memory.memory.timestamp, + }) } fn get_opcode_name(&self, _: usize) -> String { diff --git a/crates/vm/src/system/poseidon2/trace.rs b/crates/vm/src/system/poseidon2/trace.rs index 2b6f3e6b0b..979585c830 100644 --- a/crates/vm/src/system/poseidon2/trace.rs +++ b/crates/vm/src/system/poseidon2/trace.rs @@ -8,7 +8,6 @@ use openvm_stark_backend::{ p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, prover::types::AirProofInput, - rap::get_air_name, AirRef, Chip, ChipUsageGetter, }; @@ -63,7 +62,7 @@ impl ChipUsageGetter for Poseidon2PeripheryBaseChip { fn air_name(&self) -> String { - get_air_name(&self.air) + format!("Poseidon2PeripheryAir", SBOX_REGISTERS) } fn current_trace_height(&self) -> usize { diff --git a/crates/vm/src/system/program/trace.rs b/crates/vm/src/system/program/trace.rs index d9e2abd956..bf050b74bb 100644 --- a/crates/vm/src/system/program/trace.rs +++ b/crates/vm/src/system/program/trace.rs @@ -82,17 +82,14 @@ where let memory_dimensions = memory_config.memory_dimensions(); let app_program_commit: &[Val; CHUNK] = self.committed_program.commitment.as_ref(); let mem_config = memory_config; - let init_memory_commit = MemoryNode::tree_from_memory( - memory_dimensions, - &AddressMap::from_iter( - mem_config.as_offset, - 1 << mem_config.as_height, - 1 << mem_config.pointer_max_bits, - self.exe.init_memory.clone(), - ), - &hasher, - ) - .hash(); + let memory_image = AddressMap::from_sparse( + mem_config.as_offset, + 1 << mem_config.as_height, + 1 << mem_config.pointer_max_bits, + self.exe.init_memory.clone(), + ); + let init_memory_commit = + MemoryNode::tree_from_memory(memory_dimensions, &memory_image, &hasher).hash(); Com::::from(compute_exe_commit( &hasher, app_program_commit, diff --git a/crates/vm/src/system/public_values/core.rs b/crates/vm/src/system/public_values/core.rs index de189f101b..0e54624c32 100644 --- a/crates/vm/src/system/public_values/core.rs +++ b/crates/vm/src/system/public_values/core.rs @@ -2,7 +2,10 @@ use std::sync::Mutex; use openvm_circuit_primitives::{encoder::Encoder, SubAir}; use openvm_instructions::{ - instruction::Instruction, LocalOpcode, PublishOpcode, PublishOpcode::PUBLISH, + instruction::Instruction, + program::DEFAULT_PC_STEP, + LocalOpcode, + PublishOpcode::{self, PUBLISH}, }; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -14,13 +17,19 @@ use serde::{Deserialize, Serialize}; use crate::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, MinimalInstruction, - Result, VmAdapterInterface, VmCoreAir, VmCoreChip, + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, BasicAdapterInterface, + MinimalInstruction, Result, StepExecutorE1, TraceStep, VmCoreAir, VmStateMut, + }, + system::{ + memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, + public_values::columns::PublicValuesCoreColsView, }, - system::public_values::columns::PublicValuesCoreColsView, }; pub(crate) type AdapterInterface = BasicAdapterInterface, 2, 0, 1, 1>; -pub(crate) type AdapterInterfaceReads = as VmAdapterInterface>::Reads; #[derive(Clone, Debug)] pub struct PublicValuesCoreAir { @@ -107,19 +116,25 @@ pub struct PublicValuesRecord { /// ATTENTION: If a specific public value is not provided, a default 0 will be used when generating /// the proof but in the perspective of constraints, it could be any value. -pub struct PublicValuesCoreChip { - air: PublicValuesCoreAir, +pub struct PublicValuesCoreStep { + adapter: A, + // TODO(ayush): put air here and take from air + encoder: Encoder, // Mutex is to make the struct Sync. But it actually won't be accessed by multiple threads. - custom_pvs: Mutex>>, + pub(crate) custom_pvs: Mutex>>, } -impl PublicValuesCoreChip { +impl PublicValuesCoreStep +where + F: PrimeField32, +{ /// **Note:** `max_degree` is the maximum degree of the constraint polynomials to represent the /// flags. If you want the overall AIR's constraint degree to be `<= max_constraint_degree`, /// then typically you should set `max_degree` to `max_constraint_degree - 1`. - pub fn new(num_custom_pvs: usize, max_degree: u32) -> Self { + pub fn new(adapter: A, num_custom_pvs: usize, max_degree: u32) -> Self { Self { - air: PublicValuesCoreAir::new(num_custom_pvs, max_degree), + adapter, + encoder: Encoder::new(num_custom_pvs, max_degree, true), custom_pvs: Mutex::new(vec![None; num_custom_pvs]), } } @@ -128,18 +143,39 @@ impl PublicValuesCoreChip { } } -impl VmCoreChip> for PublicValuesCoreChip { - type Record = PublicValuesRecord; - type Air = PublicValuesCoreAir; +impl TraceStep for PublicValuesCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [[F; 1]; 2], + WriteData = [[F; 1]; 0], + TraceContext<'a> = (), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - _instruction: &Instruction, - _from_pc: u32, - reads: AdapterInterfaceReads, - ) -> Result<(AdapterRuntimeContext>, Self::Record)> { - let [[value], [index]] = reads; + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [[value], [index]] = self.adapter.read(state.memory, instruction, adapter_row); { let idx: usize = index.as_canonical_u32() as usize; let mut custom_pvs = self.custom_pvs.lock().unwrap(); @@ -152,31 +188,32 @@ impl VmCoreChip> for PublicValuesCoreChi panic!("Custom public value {} already set", idx); } } - let output = AdapterRuntimeContext { - to_pc: None, - writes: [], - }; - let record = Self::Record { value, index }; - Ok((output, record)) - } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - PublishOpcode::from_usize(opcode - PublishOpcode::CLASS_OFFSET) - ) + let cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(core_row); + debug_assert_eq!(cols.width(), width - A::WIDTH); + + *cols.value = value; + *cols.index = index; + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + *trace_offset += width; + + Ok(()) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let mut cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(row_slice); - debug_assert_eq!(cols.width(), BaseAir::::width(&self.air)); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let cols = PublicValuesCoreColsView::<_, &mut F>::borrow_mut(core_row); + *cols.is_valid = F::ONE; - *cols.value = record.value; - *cols.index = record.index; - let idx: usize = record.index.as_canonical_u32() as usize; - let pt = self.air.encoder.get_flag_pt(idx); - for (i, var) in cols.custom_pv_vars.iter_mut().enumerate() { - **var = F::from_canonical_u32(pt[i]); + + let idx: usize = cols.index.as_canonical_u32() as usize; + let pt = self.encoder.get_flag_pt(idx); + for (i, var) in cols.custom_pv_vars.into_iter().enumerate() { + *var = F::from_canonical_u32(pt[i]); } } @@ -186,8 +223,50 @@ impl VmCoreChip> for PublicValuesCoreChi .map(|x| x.unwrap_or(F::ZERO)) .collect() } +} + +impl StepExecutorE1 for PublicValuesCoreStep +where + F: PrimeField32, + A: 'static + for<'a> AdapterExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let [value, index] = self.adapter.read(state, instruction); + + let idx: usize = index.as_canonical_u32() as usize; + { + let mut custom_pvs = self.custom_pvs.lock().unwrap(); + + if custom_pvs[idx].is_none() { + custom_pvs[idx] = Some(value); + } else { + // Not a hard constraint violation when publishing the same value twice but the + // program should avoid that. + panic!("Custom public value {} already set", idx); + } + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; - fn air(&self) -> &Self::Air { - &self.air + Ok(()) } } diff --git a/crates/vm/src/system/public_values/mod.rs b/crates/vm/src/system/public_values/mod.rs index 918606497b..5b1b28cc01 100644 --- a/crates/vm/src/system/public_values/mod.rs +++ b/crates/vm/src/system/public_values/mod.rs @@ -1,8 +1,10 @@ +use core::PublicValuesCoreStep; + use crate::{ - arch::{VmAirWrapper, VmChipWrapper}, + arch::{NewVmChipWrapper, VmAirWrapper}, system::{ - native_adapter::{NativeAdapterAir, NativeAdapterChip}, - public_values::core::{PublicValuesCoreAir, PublicValuesCoreChip}, + native_adapter::{NativeAdapterAir, NativeAdapterStep}, + public_values::core::PublicValuesCoreAir, }, }; @@ -14,5 +16,5 @@ pub mod core; mod tests; pub type PublicValuesAir = VmAirWrapper, PublicValuesCoreAir>; -pub type PublicValuesChip = - VmChipWrapper, PublicValuesCoreChip>; +pub type PublicValuesStep = PublicValuesCoreStep, F>; +pub type PublicValuesChip = NewVmChipWrapper>; diff --git a/crates/vm/src/utils/stark_utils.rs b/crates/vm/src/utils/stark_utils.rs index d940be5c75..d8aea2123d 100644 --- a/crates/vm/src/utils/stark_utils.rs +++ b/crates/vm/src/utils/stark_utils.rs @@ -16,9 +16,12 @@ use openvm_stark_sdk::{ utils::ProofInputForTest, }; -use crate::arch::{ - vm::{VirtualMachine, VmExecutor}, - Streams, VmConfig, VmMemoryState, +use crate::{ + arch::{ + vm::{VirtualMachine, VmExecutor}, + Streams, VmConfig, + }, + system::memory::MemoryImage, }; pub fn air_test(config: VC, exe: impl Into>) @@ -36,7 +39,7 @@ pub fn air_test_with_min_segments( exe: impl Into>, input: impl Into>, min_segments: usize, -) -> Option> +) -> Option where VC: VmConfig, VC::Executor: Chip, @@ -53,7 +56,7 @@ pub fn air_test_impl( input: impl Into>, min_segments: usize, debug: bool, -) -> Option> +) -> Option where VC: VmConfig, VC::Executor: Chip, diff --git a/crates/vm/tests/integration_test.rs b/crates/vm/tests/integration_test.rs index 168d756111..20d8f3f8e2 100644 --- a/crates/vm/tests/integration_test.rs +++ b/crates/vm/tests/integration_test.rs @@ -1,15 +1,18 @@ use std::{ collections::{BTreeMap, VecDeque}, iter::zip, + mem::transmute, sync::Arc, }; use openvm_circuit::{ arch::{ + create_and_initialize_chip_complex, + execution_mode::tracegen::TracegenExecutionControlWithSegmentation, hasher::{poseidon2::vm_poseidon2_hasher, Hasher}, - ChipId, ExecutionSegment, MemoryConfig, SingleSegmentVmExecutor, SystemConfig, - SystemTraceHeights, VirtualMachine, VmComplexTraceHeights, VmConfig, - VmInventoryTraceHeights, + ChipId, MemoryConfig, SingleSegmentVmExecutor, SystemConfig, SystemTraceHeights, + VirtualMachine, VmComplexTraceHeights, VmConfig, VmInventoryTraceHeights, + VmSegmentExecutor, VmSegmentState, }, system::{ memory::{MemoryTraceHeights, VolatileMemoryTraceHeights, CHUNK}, @@ -316,9 +319,8 @@ fn test_vm_initial_memory() { Instruction::::from_isize(TERMINATE.global_opcode(), 0, 0, 0, 0, 0), ]); - let init_memory: BTreeMap<_, _> = [((4, 7), BabyBear::from_canonical_u32(101))] - .into_iter() - .collect(); + let raw = unsafe { transmute::(BabyBear::from_canonical_u32(101)) }; + let init_memory = BTreeMap::from_iter((0..4).map(|i| ((4u32, 7u32 * 4 + i), raw[i as usize]))); let config = test_native_continuations_config(); let exe = VmExe { @@ -713,15 +715,24 @@ fn test_hint_load_1() { let program = Program::from_instructions(&instructions); - let mut segment = ExecutionSegment::new( + let chip_complex = create_and_initialize_chip_complex( &test_native_config(), program, vec![vec![F::ONE, F::TWO]].into(), None, + ) + .unwrap(); + let ctrl = TracegenExecutionControlWithSegmentation::new(chip_complex.air_names()); + let mut segment = VmSegmentExecutor::::new( + chip_complex, vec![], Default::default(), + ctrl, ); - segment.execute_from_pc(0).unwrap(); + + let mut exec_state = VmSegmentState::new(0, 0, None, ()); + segment.execute_from_state(&mut exec_state).unwrap(); + let streams = segment.chip_complex.take_streams(); assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ZERO])); @@ -750,22 +761,33 @@ fn test_hint_load_2() { let program = Program::from_instructions(&instructions); - let mut segment = ExecutionSegment::new( + let chip_complex = create_and_initialize_chip_complex( &test_native_config(), program, vec![vec![F::ONE, F::TWO], vec![F::TWO, F::ONE]].into(), None, + ) + .unwrap(); + let ctrl = TracegenExecutionControlWithSegmentation::new(chip_complex.air_names()); + let mut segment = VmSegmentExecutor::::new( + chip_complex, vec![], Default::default(), + ctrl, ); - segment.execute_from_pc(0).unwrap(); - assert_eq!( + + let mut exec_state = VmSegmentState::new(0, 0, None, ()); + segment.execute_from_state(&mut exec_state).unwrap(); + + let [read] = unsafe { segment .chip_complex .memory_controller() - .unsafe_read_cell(F::from_canonical_usize(4), F::from_canonical_usize(32)), - F::ZERO - ); + .memory + .data + .read::(4, 32) + }; + assert_eq!(read, F::ZERO); let streams = segment.chip_complex.take_streams(); assert!(streams.input_stream.is_empty()); assert_eq!(streams.hint_stream, VecDeque::from(vec![F::ONE])); diff --git a/extensions/algebra/circuit/Cargo.toml b/extensions/algebra/circuit/Cargo.toml index 7949fb0946..fdaab36ff3 100644 --- a/extensions/algebra/circuit/Cargo.toml +++ b/extensions/algebra/circuit/Cargo.toml @@ -37,3 +37,4 @@ openvm-mod-circuit-builder = { workspace = true, features = ["test-utils"] } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } openvm-pairing-guest = { workspace = true, features = ["halo2curves"] } +test-case = {workspace = true} \ No newline at end of file diff --git a/extensions/algebra/circuit/src/fp2_chip/addsub.rs b/extensions/algebra/circuit/src/fp2_chip/addsub.rs index 4eca1ad102..5165395b28 100644 --- a/extensions/algebra/circuit/src/fp2_chip/addsub.rs +++ b/extensions/algebra/circuit/src/fp2_chip/addsub.rs @@ -1,63 +1,26 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{Fp2Air, Fp2Chip, Fp2Step}; use crate::Fp2; -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2AddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2AddSubChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::ADD as usize, - Fp2Opcode::SUB as usize, - Fp2Opcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], - range_checker, - "Fp2AddSub", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - pub fn fp2_addsub_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -85,11 +48,73 @@ pub fn fp2_addsub_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] +pub struct Fp2AddSubChip( + pub Fp2Chip, +); + +impl + Fp2AddSubChip +{ + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + height: usize, + ) -> Self { + let (expr, is_add_flag, is_sub_flag) = fp2_addsub_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Fp2Opcode::ADD as usize, + Fp2Opcode::SUB as usize, + Fp2Opcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + let air = Fp2Air::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = Fp2Step::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + range_checker, + "Fp2AddSub", + false, + ); + Self(Fp2Chip::new(air, step, height, mem_helper)) + } + pub fn expr(&self) -> &FieldExpr { + &self.0.step.expr + } +} + #[cfg(test)] mod tests { use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -101,52 +126,30 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use super::Fp2AddSubChip; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const MAX_INS_CAPACITY: usize = 128; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; type F = BabyBear; - #[test] - fn test_fp2_addsub() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2AddSubChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + chip: &mut Fp2AddSubChip, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_sum = bn254_fq2_to_biguint_vec(x + y); let r_sum = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_sum.len(), 2); @@ -155,8 +158,6 @@ mod tests { let expected_sub = bn254_fq2_to_biguint_vec(x - y); let r_sub = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_sub.len(), 2); @@ -177,30 +178,57 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_ADDSUB as usize, + OFFSET + Fp2Opcode::SETUP_ADDSUB as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::ADD as usize, + OFFSET + Fp2Opcode::ADD as usize, ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::SUB as usize, + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::SUB as usize); + + tester.execute(chip, &setup_instruction); + tester.execute(chip, &instruction1); + tester.execute(chip, &instruction2); + } + + #[test] + fn test_fp2_addsub() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Fp2AddSubChip::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + OFFSET, + bitwise_chip.clone(), + tester.range_checker(), + MAX_INS_CAPACITY, ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut chip, &modulus); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/algebra/circuit/src/fp2_chip/mod.rs b/extensions/algebra/circuit/src/fp2_chip/mod.rs index cd316fd70c..75d15cbdbe 100644 --- a/extensions/algebra/circuit/src/fp2_chip/mod.rs +++ b/extensions/algebra/circuit/src/fp2_chip/mod.rs @@ -3,3 +3,17 @@ pub use addsub::*; mod muldiv; pub use muldiv::*; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionStep}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; + +pub(crate) type Fp2Air = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub(crate) type Fp2Step = + FieldExpressionStep>; + +pub(crate) type Fp2Chip = + NewVmChipWrapper, Fp2Step>; diff --git a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs index 83ef9565f3..ea4889e74e 100644 --- a/extensions/algebra/circuit/src/fp2_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/fp2_chip/muldiv.rs @@ -1,63 +1,26 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Fp2Opcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, SymbolicExpr, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{Fp2Air, Fp2Chip, Fp2Step}; use crate::Fp2; -// Input: Fp2 * 2 -// Output: Fp2 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp2MulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp2MulDivChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Fp2Opcode::MUL as usize, - Fp2Opcode::DIV as usize, - Fp2Opcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], - range_checker, - "Fp2MulDiv", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - pub fn fp2_muldiv_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -124,11 +87,74 @@ pub fn fp2_muldiv_expr( ) } +// Input: Fp2 * 2 +// Output: Fp2 +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] +pub struct Fp2MulDivChip( + pub Fp2Chip, +); + +impl + Fp2MulDivChip +{ + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + height: usize, + ) -> Self { + let (expr, is_mul_flag, is_div_flag) = fp2_muldiv_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Fp2Opcode::MUL as usize, + Fp2Opcode::DIV as usize, + Fp2Opcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + let air = Fp2Air::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = Fp2Step::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + opcode_flag_idx, + range_checker, + "Fp2MulDiv", + false, + ); + Self(Fp2Chip::new(air, step, height, mem_helper)) + } + + pub fn expr(&self) -> &FieldExpr { + &self.0.step.expr + } +} + #[cfg(test)] mod tests { use halo2curves_axiom::{bn256::Fq2, ff::Field}; use itertools::Itertools; + use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -140,57 +166,30 @@ mod tests { ExprBuilderConfig, }; use openvm_pairing_guest::bn254::BN254_MODULUS; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; + use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; + use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; - use super::Fp2MulDivChip; + use crate::fp2_chip::Fp2MulDivChip; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; + const OFFSET: usize = Fp2Opcode::CLASS_OFFSET; + const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; - #[test] - fn test_fp2_muldiv() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let modulus = BN254_MODULUS.clone(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = Fp2MulDivChip::new( - adapter, - config, - Fp2Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!( - chip.0.core.expr().builder.num_variables, - 2, - "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" - ); - - let mut rng = StdRng::seed_from_u64(42); + fn set_and_execute_rand( + tester: &mut VmChipTestBuilder, + chip: &mut Fp2MulDivChip, + modulus: &BigUint, + ) { + let mut rng = create_seeded_rng(); let x = Fq2::random(&mut rng); let y = Fq2::random(&mut rng); let inputs = [x.c0, x.c1, y.c0, y.c1].map(bn254_fq_to_biguint); let expected_mul = bn254_fq2_to_biguint_vec(x * y); let r_mul = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![true, false]); assert_eq!(r_mul.len(), 2); @@ -199,8 +198,6 @@ mod tests { let expected_div = bn254_fq2_to_biguint_vec(x * y.invert().unwrap()); let r_div = chip - .0 - .core .expr() .execute_with_output(inputs.to_vec(), vec![false, true]); assert_eq!(r_div.len(), 2); @@ -221,30 +218,62 @@ mod tests { .map(BabyBear::from_canonical_u32) }) .collect_vec(); - let modulus = - biguint_to_limbs::(modulus, LIMB_BITS).map(BabyBear::from_canonical_u32); + let modulus = biguint_to_limbs::(modulus.clone(), LIMB_BITS) + .map(BabyBear::from_canonical_u32); let zero = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( - &mut tester, + tester, vec![modulus, zero], vec![zero; 2], - chip.0.core.air.offset + Fp2Opcode::SETUP_MULDIV as usize, + OFFSET + Fp2Opcode::SETUP_MULDIV as usize, ); let instruction1 = rv32_write_heap_default( - &mut tester, + tester, x_limbs.clone(), y_limbs.clone(), - chip.0.core.air.offset + Fp2Opcode::MUL as usize, + OFFSET + Fp2Opcode::MUL as usize, ); - let instruction2 = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.0.core.air.offset + Fp2Opcode::DIV as usize, + let instruction2 = + rv32_write_heap_default(tester, x_limbs, y_limbs, OFFSET + Fp2Opcode::DIV as usize); + tester.execute(chip, &setup_instruction); + tester.execute(chip, &instruction1); + tester.execute(chip, &instruction2); + } + + #[test] + fn test_fp2_muldiv() { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let modulus = BN254_MODULUS.clone(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Fp2MulDivChip::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + OFFSET, + bitwise_chip.clone(), + tester.range_checker(), + MAX_INS_CAPACITY, ); - tester.execute(&mut chip, &setup_instruction); - tester.execute(&mut chip, &instruction1); - tester.execute(&mut chip, &instruction2); + assert_eq!( + chip.expr().builder.num_variables, + 2, + "Fp2MulDiv should only introduce new z Fp2 variable (2 Fp var)" + ); + + let num_ops = 10; + for _ in 0..num_ops { + set_and_execute_rand(&mut tester, &mut chip, &modulus); + } + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } diff --git a/extensions/algebra/circuit/src/fp2_extension.rs b/extensions/algebra/circuit/src/fp2_extension.rs index 940ec4c864..6fa87a0943 100644 --- a/extensions/algebra/circuit/src/fp2_extension.rs +++ b/extensions/algebra/circuit/src/fp2_extension.rs @@ -2,17 +2,18 @@ use derive_more::derive::From; use num_bigint::BigUint; use openvm_algebra_transpiler::Fp2Opcode; use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -20,6 +21,9 @@ use strum::EnumCount; use crate::fp2_chip::{Fp2AddSubChip, Fp2MulDivChip}; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct Fp2Extension { @@ -27,7 +31,7 @@ pub struct Fp2Extension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, AnyEnum, From)] pub enum Fp2ExtensionExecutor { // 32 limbs prime Fp2AddSubRv32_32(Fp2AddSubChip), @@ -58,6 +62,11 @@ impl VmExtension for Fp2Extension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -69,9 +78,6 @@ impl VmExtension for Fp2Extension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; let addsub_opcodes = (Fp2Opcode::ADD as usize)..=(Fp2Opcode::SETUP_ADDSUB as usize); let muldiv_opcodes = (Fp2Opcode::MUL as usize)..=(Fp2Opcode::SETUP_MULDIV as usize); @@ -91,28 +97,18 @@ impl VmExtension for Fp2Extension { num_limbs: 48, limb_bits: 8, }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); if bytes <= 32 { let addsub_chip = Fp2AddSubChip::new( - adapter_chip_32.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( Fp2ExtensionExecutor::Fp2AddSubRv32_32(addsub_chip), @@ -121,11 +117,15 @@ impl VmExtension for Fp2Extension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_32.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( Fp2ExtensionExecutor::Fp2MulDivRv32_32(muldiv_chip), @@ -135,11 +135,15 @@ impl VmExtension for Fp2Extension { )?; } else if bytes <= 48 { let addsub_chip = Fp2AddSubChip::new( - adapter_chip_48.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( Fp2ExtensionExecutor::Fp2AddSubRv32_48(addsub_chip), @@ -148,11 +152,15 @@ impl VmExtension for Fp2Extension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = Fp2MulDivChip::new( - adapter_chip_48.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( Fp2ExtensionExecutor::Fp2MulDivRv32_48(muldiv_chip), diff --git a/extensions/algebra/circuit/src/modular_chip/addsub.rs b/extensions/algebra/circuit/src/modular_chip/addsub.rs index 34bede150f..1c3e497b40 100644 --- a/extensions/algebra/circuit/src/modular_chip/addsub.rs +++ b/extensions/algebra/circuit/src/modular_chip/addsub.rs @@ -1,22 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{ModularAir, ModularChip, ModularStep}; + pub fn addsub_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -43,39 +46,58 @@ pub fn addsub_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] pub struct ModularAddSubChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, + pub ModularChip, ); impl ModularAddSubChip { pub fn new( - adapter: Rv32VecHeapAdapterChip, + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, config: ExprBuilderConfig, offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, + height: usize, ) -> Self { let (expr, is_add_flag, is_sub_flag) = addsub_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::ADD as usize, + Rv32ModularArithmeticOpcode::SUB as usize, + Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, + ]; + let opcode_flag_idx = vec![is_add_flag, is_sub_flag]; + let air = ModularAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = ModularStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), expr, offset, - vec![ - Rv32ModularArithmeticOpcode::ADD as usize, - Rv32ModularArithmeticOpcode::SUB as usize, - Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize, - ], - vec![is_add_flag, is_sub_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, "ModularAddSub", false, ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) + Self(ModularChip::new(air, step, height, mem_helper)) } } diff --git a/extensions/algebra/circuit/src/modular_chip/is_eq.rs b/extensions/algebra/circuit/src/modular_chip/is_eq.rs index fe91585466..b0dad5d7f5 100644 --- a/extensions/algebra/circuit/src/modular_chip/is_eq.rs +++ b/extensions/algebra/circuit/src/modular_chip/is_eq.rs @@ -5,9 +5,16 @@ use std::{ use num_bigint::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bigint::utils::big_uint_to_limbs, @@ -16,21 +23,20 @@ use openvm_circuit_primitives::{ SubAir, TraceSubRowGenerator, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; + // Given two numbers b and c, we want to prove that a) b == c or b != c, depending on // result of cmp_result and b) b, c < N for some modulus N that is passed into the AIR // at runtime (i.e. when chip is instantiated). #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct ModularIsEqualCoreCols { pub is_valid: T, pub is_setup: T, @@ -277,156 +283,201 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct ModularIsEqualCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; READ_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; READ_LIMBS], - pub cmp_result: T, - #[serde(with = "BigArray")] - pub eq_marker: [T; READ_LIMBS], - pub b_diff_idx: usize, - pub c_diff_idx: usize, - pub is_setup: bool, -} - -pub struct ModularIsEqualCoreChip< +#[derive(derive_new::new)] +pub struct ModularIsEqualStep< + A, const READ_LIMBS: usize, const WRITE_LIMBS: usize, const LIMB_BITS: usize, > { - pub air: ModularIsEqualCoreAir, + adapter: A, + pub modulus_limbs: [u8; READ_LIMBS], + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl - ModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - air: ModularIsEqualCoreAir::new(modulus, bitwise_lookup_chip.bus(), offset), - bitwise_lookup_chip, - } - } -} - -impl< - F: PrimeField32, - I: VmAdapterInterface, - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for ModularIsEqualCoreChip +impl + TraceStep for ModularIsEqualStep where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; READ_LIMBS]; 2]>, + WriteData: From<[u8; WRITE_LIMBS]>, + TraceContext<'a> = (), + >, { - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let data: [[F; READ_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (b_cmp, b_diff_idx) = run_unsigned_less_than::(&b, &self.air.modulus_limbs); - let (c_cmp, c_diff_idx) = run_unsigned_less_than::(&c, &self.air.modulus_limbs); - let is_setup = instruction.opcode.local_opcode_idx(self.air.offset) + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { opcode, .. } = instruction; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = row_slice.split_at_mut(A::WIDTH); + + let cols: &mut ModularIsEqualCoreCols = core_row.borrow_mut(); + + A::start(*state.pc, state.memory, adapter_row); + let [b, c] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + cols.b = b.map(F::from_canonical_u8); + cols.c = c.map(F::from_canonical_u8); + + let (b_cmp, _) = run_unsigned_less_than::(&b, &self.modulus_limbs); + let (c_cmp, _) = run_unsigned_less_than::(&c, &self.modulus_limbs); + let is_setup = instruction.opcode.local_opcode_idx(self.offset) == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize; + cols.is_setup = F::from_bool(is_setup); + if !is_setup { - assert!(b_cmp, "{:?} >= {:?}", b, self.air.modulus_limbs); - } - assert!(c_cmp, "{:?} >= {:?}", c, self.air.modulus_limbs); - if !is_setup { - self.bitwise_lookup_chip.request_range( - self.air.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1, - self.air.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1, - ); + assert!(b_cmp, "{:?} >= {:?}", b, self.modulus_limbs); } + assert!(c_cmp, "{:?} >= {:?}", c, self.modulus_limbs); - let mut eq_marker = [F::ZERO; READ_LIMBS]; - let mut cmp_result = F::ZERO; - self.air - .subair - .generate_subrow((&data[0], &data[1]), (&mut eq_marker, &mut cmp_result)); - - let mut writes = [F::ZERO; WRITE_LIMBS]; - writes[0] = cmp_result; - - let output = AdapterRuntimeContext::without_pc([writes]); - let record = ModularIsEqualCoreRecord { - is_setup, - b: data[0], - c: data[1], - cmp_result, - eq_marker, - b_diff_idx, - c_diff_idx, - }; + let mut write_data = [0u8; WRITE_LIMBS]; + write_data[0] = (b == c) as u8; + self.adapter + .write(state.memory, instruction, adapter_row, &write_data.into()); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + *trace_offset += width; - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32ModularArithmeticOpcode::from_usize(opcode - self.air.offset) - ) + Ok(()) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - row_slice.is_valid = F::ONE; - row_slice.is_setup = F::from_bool(record.is_setup); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - - row_slice.eq_marker = record.eq_marker; + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = row_slice.split_at_mut(A::WIDTH); + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + let cols: &mut ModularIsEqualCoreCols = core_row.borrow_mut(); - if !record.is_setup { - row_slice.b_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.b_diff_idx]) - - record.b[record.b_diff_idx]; + cols.is_valid = F::ONE; + let sub_air = IsEqArraySubAir::; + sub_air.generate_subrow( + (&cols.b, &cols.c), + (&mut cols.eq_marker, &mut cols.cmp_result), + ); + let b = cols.b.map(|x| x.as_canonical_u32() as u8); + let c = cols.c.map(|x| x.as_canonical_u32() as u8); + let (_, b_diff_idx) = run_unsigned_less_than::(&b, &self.modulus_limbs); + let (_, c_diff_idx) = run_unsigned_less_than::(&c, &self.modulus_limbs); + + if cols.is_setup != F::ONE { + cols.b_lt_diff = + F::from_canonical_u8(self.modulus_limbs[b_diff_idx]) - cols.b[b_diff_idx]; + self.bitwise_lookup_chip.request_range( + (self.modulus_limbs[b_diff_idx] - b[b_diff_idx] - 1) as u32, + (self.modulus_limbs[c_diff_idx] - c[c_diff_idx] - 1) as u32, + ); } - row_slice.c_lt_diff = F::from_canonical_u32(self.air.modulus_limbs[record.c_diff_idx]) - - record.c[record.c_diff_idx]; - row_slice.c_lt_mark = if record.b_diff_idx == record.c_diff_idx { + cols.c_lt_diff = F::from_canonical_u8(self.modulus_limbs[c_diff_idx]) - cols.c[c_diff_idx]; + cols.c_lt_mark = if b_diff_idx == c_diff_idx { F::ONE } else { F::from_canonical_u8(2) }; - row_slice.lt_marker = from_fn(|i| { - if i == record.b_diff_idx { + cols.lt_marker = from_fn(|i| { + if i == b_diff_idx { F::ONE - } else if i == record.c_diff_idx { - row_slice.c_lt_mark + } else if i == c_diff_idx { + cols.c_lt_mark } else { F::ZERO } }); } - fn air(&self) -> &Self::Air { - &self.air + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32ModularArithmeticOpcode::from_usize(opcode - self.offset) + ) + } +} + +impl + StepExecutorE1 for ModularIsEqualStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; READ_LIMBS]; 2]>, + WriteData: From<[u8; WRITE_LIMBS]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + let local_opcode = + Rv32ModularArithmeticOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + matches!( + local_opcode, + Rv32ModularArithmeticOpcode::IS_EQ | Rv32ModularArithmeticOpcode::SETUP_ISEQ + ); + + let [b, c] = self.adapter.read(state, instruction).into(); + let (b_cmp, _) = run_unsigned_less_than::(&b, &self.modulus_limbs); + let (c_cmp, _) = run_unsigned_less_than::(&c, &self.modulus_limbs); + let is_setup = instruction.opcode.local_opcode_idx(self.offset) + == Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize; + + if !is_setup { + assert!(b_cmp, "{:?} >= {:?}", b, self.modulus_limbs); + } + assert!(c_cmp, "{:?} >= {:?}", c, self.modulus_limbs); + + let mut write_data = [0u8; WRITE_LIMBS]; + write_data[0] = (b == c) as u8; + + self.adapter.write(state, instruction, &write_data.into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // Returns (cmp_result, diff_idx) pub(super) fn run_unsigned_less_than( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize) { for i in (0..NUM_LIMBS).rev() { if x[i] != y[i] { diff --git a/extensions/algebra/circuit/src/modular_chip/mod.rs b/extensions/algebra/circuit/src/modular_chip/mod.rs index 2dd9838206..1ba8876f5e 100644 --- a/extensions/algebra/circuit/src/modular_chip/mod.rs +++ b/extensions/algebra/circuit/src/modular_chip/mod.rs @@ -4,21 +4,56 @@ mod is_eq; pub use is_eq::*; mod muldiv; pub use muldiv::*; -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32_adapters::Rv32IsEqualModAdapterChip; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionStep}; +use openvm_rv32_adapters::{ + Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep, Rv32VecHeapAdapterAir, + Rv32VecHeapAdapterStep, +}; #[cfg(test)] mod tests; +pub(crate) type ModularAir = VmAirWrapper< + Rv32VecHeapAdapterAir<2, BLOCKS, BLOCKS, BLOCK_SIZE, BLOCK_SIZE>, + FieldExpressionCoreAir, +>; + +pub(crate) type ModularStep = + FieldExpressionStep>; + +pub(crate) type ModularChip = + NewVmChipWrapper, ModularStep>; + // Must have TOTAL_LIMBS = NUM_LANES * LANE_SIZE +pub type ModularIsEqualAir< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +> = VmAirWrapper< + Rv32IsEqualModAdapterAir<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + ModularIsEqualCoreAir, +>; + +pub type VmModularIsEqualStep< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, +> = ModularIsEqualStep< + Rv32IsEqualModeAdapterStep<2, NUM_LANES, LANE_SIZE, TOTAL_LIMBS>, + TOTAL_LIMBS, + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, +>; + pub type ModularIsEqualChip< F, const NUM_LANES: usize, const LANE_SIZE: usize, const TOTAL_LIMBS: usize, -> = VmChipWrapper< +> = NewVmChipWrapper< F, - Rv32IsEqualModAdapterChip, - ModularIsEqualCoreChip, + ModularIsEqualAir, + VmModularIsEqualStep, >; diff --git a/extensions/algebra/circuit/src/modular_chip/muldiv.rs b/extensions/algebra/circuit/src/modular_chip/muldiv.rs index 30f063e2b1..f3defa4507 100644 --- a/extensions/algebra/circuit/src/modular_chip/muldiv.rs +++ b/extensions/algebra/circuit/src/modular_chip/muldiv.rs @@ -1,22 +1,25 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; +use std::{cell::RefCell, rc::Rc}; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, }; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; +use openvm_instructions::riscv::RV32_CELL_BITS; use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, FieldVariable, SymbolicExpr, + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, SymbolicExpr, }; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; +use super::{ModularAir, ModularChip, ModularStep}; + pub fn muldiv_expr( config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus, @@ -58,39 +61,58 @@ pub fn muldiv_expr( ) } -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] pub struct ModularMulDivChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, + pub ModularChip, ); impl ModularMulDivChip { pub fn new( - adapter: Rv32VecHeapAdapterChip, + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, config: ExprBuilderConfig, offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, + height: usize, ) -> Self { let (expr, is_mul_flag, is_div_flag) = muldiv_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( + + let local_opcode_idx = vec![ + Rv32ModularArithmeticOpcode::MUL as usize, + Rv32ModularArithmeticOpcode::DIV as usize, + Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, + ]; + let opcode_flag_idx = vec![is_mul_flag, is_div_flag]; + let air = ModularAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new( + expr.clone(), + offset, + local_opcode_idx.clone(), + opcode_flag_idx.clone(), + ), + ); + + let step = ModularStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), expr, offset, - vec![ - Rv32ModularArithmeticOpcode::MUL as usize, - Rv32ModularArithmeticOpcode::DIV as usize, - Rv32ModularArithmeticOpcode::SETUP_MULDIV as usize, - ], - vec![is_mul_flag, is_div_flag], + local_opcode_idx, + opcode_flag_idx, range_checker, "ModularMulDiv", false, ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) + Self(ModularChip::new(air, step, height, mem_helper)) } } diff --git a/extensions/algebra/circuit/src/modular_chip/tests.rs b/extensions/algebra/circuit/src/modular_chip/tests.rs index 1ad3310f76..9a5de34345 100644 --- a/extensions/algebra/circuit/src/modular_chip/tests.rs +++ b/extensions/algebra/circuit/src/modular_chip/tests.rs @@ -6,7 +6,6 @@ use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::arch::{ instructions::LocalOpcode, testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, Result, VmAdapterInterface, VmChipWrapper, VmCoreChip, }; use openvm_circuit_primitives::{ bigint::utils::{big_uint_to_limbs, secp256k1_coord_prime, secp256k1_scalar_prime}, @@ -18,105 +17,61 @@ use openvm_mod_circuit_builder::{ ExprBuilderConfig, }; use openvm_pairing_guest::bls12_381::BLS12_381_MODULUS; -use openvm_rv32_adapters::{ - rv32_write_heap_default, write_ptr_reg, Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip, -}; +use openvm_rv32_adapters::{rv32_write_heap_default, write_ptr_reg}; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; -use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; +use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; use super::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreAir, ModularIsEqualCoreChip, - ModularIsEqualCoreCols, ModularIsEqualCoreRecord, ModularMulDivChip, + ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreAir, ModularIsEqualCoreCols, + ModularMulDivChip, }; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; +const _BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; -const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; -const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; +#[cfg(test)] +mod addsubtests { + use test_case::test_case; -#[test] -fn test_coord_addsub() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_addsub(opcode_offset, modulus); -} + use super::*; -#[test] -fn test_scalar_addsub() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_addsub(opcode_offset, modulus); -} + const ADD_LOCAL: usize = Rv32ModularArithmeticOpcode::ADD as usize; + + fn set_and_execute_addsub( + tester: &mut VmChipTestBuilder, + chip: &mut ModularAddSubChip, + modulus: &BigUint, + is_setup: bool, + ) { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), ADD_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; -fn test_addsub(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularAddSubChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![ADD_LOCAL + 2]; // setup - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - let op = rng.gen_range(0..2) + ADD_LOCAL; // 0 for add, 1 for sub - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); - } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); - } let expected_answer = match op - ADD_LOCAL { - 0 => (&a + &b) % &modulus, - 1 => (&a + &modulus - &b) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a + &b) % modulus, + 1 => (&a + modulus - &b) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -133,11 +88,11 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { let data_as = 2; let address1 = 0u32; let address2 = 128u32; - let address3 = (1 << 28) + 1234; // a large memory address to test heap adapter + let address3 = (1 << 28) + 1228; // a large memory address to test heap adapter - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); let a_limbs: [BabyBear; NUM_LIMBS] = biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); @@ -147,105 +102,92 @@ fn test_addsub(opcode_offset: usize, modulus: BigUint) { tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(chip.0.step.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(chip, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_addsub(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); -#[test] -fn test_coord_muldiv() { - let opcode_offset = 0; - let modulus = secp256k1_coord_prime(); - test_muldiv(opcode_offset, modulus); -} + // doing 1xNUM_LIMBS reads and writes + let mut chip = ModularAddSubChip::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.memory_helper(), + tester.address_bits(), + config, + Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, + bitwise_chip.clone(), + tester.range_checker(), + MAX_INS_CAPACITY, + ); -#[test] -fn test_scalar_muldiv() { - let opcode_offset = 4; - let modulus = secp256k1_scalar_prime(); - test_muldiv(opcode_offset, modulus); -} + for i in 0..num_ops { + set_and_execute_addsub(&mut tester, &mut chip, &modulus, i == 0); + } -fn test_muldiv(opcode_offset: usize, modulus: BigUint) { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: modulus.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - // doing 1xNUM_LIMBS reads and writes - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = ModularMulDivChip::new( - adapter, - config, - Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - let mut rng = create_seeded_rng(); - let num_tests = 50; - let mut all_ops = vec![MUL_LOCAL + 2]; - let mut all_a = vec![modulus.clone()]; - let mut all_b = vec![BigUint::zero()]; - - // First loop: generate all random test data. - for _ in 0..num_tests { - let a_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut a = BigUint::new(a_digits.clone()); - let b_digits: Vec<_> = (0..NUM_LIMBS) - .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) - .collect(); - let mut b = BigUint::new(b_digits.clone()); - - // let op = rng.gen_range(2..4); // 2 for mul, 3 for div - let op = MUL_LOCAL; - a %= &modulus; - b %= &modulus; - - all_ops.push(op); - all_a.push(a); - all_b.push(b); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } - // Second loop: actually run the tests. - for i in 0..=num_tests { - let op = all_ops[i]; - let a = all_a[i].clone(); - let b = all_b[i].clone(); - if i > 0 { - // if not setup - assert!(a < modulus); - assert!(b < modulus); - } +} + +#[cfg(test)] +mod muldivtests { + use test_case::test_case; + + use super::*; + + const MUL_LOCAL: usize = Rv32ModularArithmeticOpcode::MUL as usize; + + fn set_and_execute_muldiv( + tester: &mut VmChipTestBuilder, + chip: &mut ModularMulDivChip, + modulus: &BigUint, + is_setup: bool, + ) { + let mut rng = create_seeded_rng(); + + let (a, b, op) = if is_setup { + (modulus.clone(), BigUint::zero(), MUL_LOCAL + 2) + } else { + let a_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut a = BigUint::new(a_digits.clone()); + let b_digits: Vec<_> = (0..NUM_LIMBS) + .map(|_| rng.gen_range(0..(1 << LIMB_BITS))) + .collect(); + let mut b = BigUint::new(b_digits.clone()); + + let op = rng.gen_range(0..2) + MUL_LOCAL; // 0 for add, 1 for sub + a %= modulus; + b %= modulus; + (a, b, op) + }; + let expected_answer = match op - MUL_LOCAL { - 0 => (&a * &b) % &modulus, - 1 => (&a * b.modinv(&modulus).unwrap()) % &modulus, - 2 => a.clone() % &modulus, + 0 => (&a * &b) % modulus, + 1 => (&a * b.modinv(modulus).unwrap()) % modulus, + 2 => a.clone() % modulus, _ => panic!(), }; @@ -264,307 +206,369 @@ fn test_muldiv(opcode_offset: usize, modulus: BigUint) { let address2 = 128; let address3 = 256; - write_ptr_reg(&mut tester, ptr_as, addr_ptr1, address1); - write_ptr_reg(&mut tester, ptr_as, addr_ptr2, address2); - write_ptr_reg(&mut tester, ptr_as, addr_ptr3, address3); + write_ptr_reg(tester, ptr_as, addr_ptr1, address1); + write_ptr_reg(tester, ptr_as, addr_ptr2, address2); + write_ptr_reg(tester, ptr_as, addr_ptr3, address3); - let a_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let a_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(a.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address1 as usize, a_limbs); - let b_limbs: [BabyBear; NUM_LIMBS] = - biguint_to_limbs(b.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); + let b_limbs: [F; NUM_LIMBS] = + biguint_to_limbs(b.clone(), LIMB_BITS).map(F::from_canonical_u32); tester.write(data_as, address2 as usize, b_limbs); let instruction = Instruction::from_isize( - VmOpcode::from_usize(chip.0.core.air.offset + op), + VmOpcode::from_usize(chip.0.step.offset + op), addr_ptr3 as isize, addr_ptr1 as isize, addr_ptr2 as isize, ptr_as as isize, data_as as isize, ); - tester.execute(&mut chip, &instruction); + tester.execute(chip, &instruction); let expected_limbs = biguint_to_limbs::(expected_answer, LIMB_BITS); - for (i, expected) in expected_limbs.into_iter().enumerate() { - let address = address3 as usize + i; - let read_val = tester.read_cell(data_as, address); - assert_eq!(BabyBear::from_canonical_u32(expected), read_val); - } + let read_vals = tester.read::(data_as, address3 as usize); + assert_eq!(read_vals, expected_limbs.map(F::from_canonical_u32)); } - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - - tester.simple_test().expect("Verification failed"); -} -fn test_is_equal( - opcode_offset: usize, - modulus: BigUint, - num_tests: usize, -) { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = ModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), + #[test_case(0, secp256k1_coord_prime(), 50)] + #[test_case(4, secp256k1_scalar_prime(), 50)] + fn test_muldiv(opcode_offset: usize, modulus: BigUint, num_ops: usize) { + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); + let config = ExprBuilderConfig { + modulus: modulus.clone(), + num_limbs: NUM_LIMBS, + limb_bits: LIMB_BITS, + }; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + // doing 1xNUM_LIMBS reads and writes + let mut chip = ModularMulDivChip::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32ModularArithmeticOpcode::CLASS_OFFSET + opcode_offset, bitwise_chip.clone(), - ), - ModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); + tester.range_checker(), + MAX_INS_CAPACITY, + ); - { - let vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let modulus_limbs: [F; TOTAL_LIMBS] = std::array::from_fn(|i| { - if i < vec.len() { - F::from_canonical_usize(vec[i]) - } else { - F::ZERO - } - }); + for i in 0..num_ops { + set_and_execute_muldiv(&mut tester, &mut chip, &modulus, i == 0); + } + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![modulus_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + tester.simple_test().expect("Verification failed"); + } +} + +#[cfg(test)] +mod is_equal_tests { + use openvm_rv32_adapters::{Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep}; + use openvm_stark_backend::{ + p3_air::BaseAir, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, + }; + + use super::*; + use crate::modular_chip::{ModularIsEqualAir, ModularIsEqualStep}; + + fn create_test_chips< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + modulus: &BigUint, + modulus_limbs: [u8; TOTAL_LIMBS], + offset: usize, + ) -> ( + ModularIsEqualChip, + SharedBitwiseOperationLookupChip, + ) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let chip = ModularIsEqualChip::::new( + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ModularIsEqualCoreAir::new(modulus.clone(), bitwise_bus, offset), + ), + ModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + modulus_limbs, + offset, + bitwise_chip.clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - tester.execute(&mut chip, &setup_instruction); + + (chip, bitwise_chip) } - for _ in 0..num_tests { - let b = generate_field_element::(&modulus, &mut rng); - let c = if rng.gen_bool(0.5) { - b + + fn set_and_execute_is_equal< + const NUM_LANES: usize, + const LANE_SIZE: usize, + const TOTAL_LIMBS: usize, + >( + tester: &mut VmChipTestBuilder, + chip: &mut ModularIsEqualChip, + rng: &mut StdRng, + modulus: &BigUint, + offset: usize, + modulus_limbs: [F; TOTAL_LIMBS], + is_setup: bool, + b: Option<[F; TOTAL_LIMBS]>, + c: Option<[F; TOTAL_LIMBS]>, + ) { + let instruction = if is_setup { + rv32_write_heap_default::( + tester, + vec![modulus_limbs], + vec![[F::ZERO; TOTAL_LIMBS]], + offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, + ) } else { - generate_field_element::(&modulus, &mut rng) + let b = b.unwrap_or( + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32), + ); + let c = c.unwrap_or(if rng.gen_bool(0.5) { + b + } else { + generate_field_element::(modulus, rng) + .map(F::from_canonical_u32) + }); + + rv32_write_heap_default::( + tester, + vec![b], + vec![c], + offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, + ) }; + tester.execute(chip, &instruction); + } - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); + ////////////////////////////////////////////////////////////////////////////////////// + // POSITIVE TESTS + // + // Randomly generate computations and execute, ensuring that the generated trace + // passes all constraints. + ////////////////////////////////////////////////////////////////////////////////////// + + #[test] + fn test_modular_is_equal_1x32() { + test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); } - // Special case where b == c are close to the prime - let b_vec = big_uint_to_limbs(&modulus, LIMB_BITS); - let mut b = from_fn(|i| if i < b_vec.len() { b_vec[i] as u32 } else { 0 }); - b[0] -= 1; - let instruction = rv32_write_heap_default::( - &mut tester, - vec![b.map(F::from_canonical_u32)], - vec![b.map(F::from_canonical_u32)], - opcode_offset + Rv32ModularArithmeticOpcode::IS_EQ as usize, - ); - tester.execute(&mut chip, &instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + #[test] + fn test_modular_is_equal_3x16() { + test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); + } -#[test] -fn test_modular_is_equal_1x32() { - test_is_equal::<1, 32, 32>(17, secp256k1_coord_prime(), 100); -} + fn test_is_equal( + opcode_offset: usize, + modulus: BigUint, + num_tests: usize, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); -#[test] -fn test_modular_is_equal_3x16() { - test_is_equal::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 100); -} + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; TOTAL_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); -// Wrapper chip for testing a bad setup row -type BadModularIsEqualChip< - F, - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, -> = VmChipWrapper< - F, - Rv32IsEqualModAdapterChip, - BadModularIsEqualCoreChip, ->; - -// Wrapper chip for testing a bad setup row -struct BadModularIsEqualCoreChip< - const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, -> { - chip: ModularIsEqualCoreChip, -} + let (mut chip, bitwise_chip) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); -impl - BadModularIsEqualCoreChip -{ - pub fn new( - modulus: BigUint, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - offset: usize, - ) -> Self { - Self { - chip: ModularIsEqualCoreChip::new(modulus, bitwise_lookup_chip, offset), + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); + + for i in 0..num_tests { + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + i == 0, // the first test is a setup test + None, + None, + ); } + + // Special case where b == c are close to the prime + let mut b = modulus_limbs; + b[0] -= F::ONE; + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + false, + Some(b), + Some(b), + ); + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } -} -impl< - F: PrimeField32, - I: VmAdapterInterface, + ////////////////////////////////////////////////////////////////////////////////////// + // NEGATIVE TESTS + // + // Given a fake trace of a single operation, setup a chip and run the test. We replace + // part of the trace and check that the chip throws the expected error. + ////////////////////////////////////////////////////////////////////////////////////// + + /// Negative tests test for 3 "type" of errors determined by the value of b[0]: + fn run_negative_is_equal_test< + const NUM_LANES: usize, + const LANE_SIZE: usize, const READ_LIMBS: usize, - const WRITE_LIMBS: usize, - const LIMB_BITS: usize, - > VmCoreChip for BadModularIsEqualCoreChip -where - I::Reads: Into<[[F; READ_LIMBS]; 2]>, - I::Writes: From<[[F; WRITE_LIMBS]; 1]>, -{ - type Record = ModularIsEqualCoreRecord; - type Air = ModularIsEqualCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - // Override the b_diff_idx to be out of bounds. - // This will cause lt_marker to be all zeros except a 2. - // There was a bug in this case which allowed b to be less than N. - self.chip.execute_instruction(instruction, from_pc, reads) - } + >( + modulus: BigUint, + opcode_offset: usize, + test_case: usize, + expected_error: VerificationError, + ) { + let mut rng = create_seeded_rng(); + let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - fn get_opcode_name(&self, opcode: usize) -> String { - as VmCoreChip>::get_opcode_name(&self.chip, opcode) - } + let vec = big_uint_to_limbs(&modulus, LIMB_BITS); + let modulus_limbs: [u8; READ_LIMBS] = + from_fn(|i| if i < vec.len() { vec[i] as u8 } else { 0 }); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - as VmCoreChip>::generate_trace_row(&self.chip, row_slice, record.clone()); - let row_slice: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = row_slice.borrow_mut(); - // decide which bug to test based on b[0] - if record.b[0] == F::ONE { - // test the constraint that c_lt_mark = 2 when is_setup = 1 - row_slice.c_lt_mark = F::ONE; - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::ONE; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.b[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(2) { - // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or - // lt_marker_sum - is_setup - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } else if record.b[0] == F::from_canonical_u32(3) { - // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 - row_slice.c_lt_mark = F::from_canonical_u8(2); - row_slice.lt_marker = [F::ZERO; READ_LIMBS]; - row_slice.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); - row_slice.lt_marker[0] = F::ONE; - row_slice.b_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[0]) - record.b[0]; - row_slice.c_lt_diff = - F::from_canonical_u32(self.chip.air.modulus_limbs[READ_LIMBS - 1]) - - record.c[READ_LIMBS - 1]; - } - } + let (mut chip, bitwise_chip) = create_test_chips::( + &mut tester, + &modulus, + modulus_limbs, + opcode_offset, + ); - fn air(&self) -> &Self::Air { - as VmCoreChip>::air( - &self.chip, - ) - } -} + let modulus_limbs = modulus_limbs.map(F::from_canonical_u8); -// Test that passes the wrong modulus in the setup instruction. -// This proof should fail to verify. -fn test_is_equal_setup_bad< - const NUM_LANES: usize, - const LANE_SIZE: usize, - const TOTAL_LIMBS: usize, ->( - opcode_offset: usize, - modulus: BigUint, - b_val: u32, /* used to select which bug to test. currently only 1, 2, and 3 are supported - * (because there are three bugs to test) */ -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = BadModularIsEqualChip::::new( - Rv32IsEqualModAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ), - BadModularIsEqualCoreChip::new(modulus.clone(), bitwise_chip.clone(), opcode_offset), - tester.offline_memory_mutex_arc(), - ); - - let mut b_limbs = [F::ZERO; TOTAL_LIMBS]; - b_limbs[0] = F::from_canonical_u32(b_val); - let setup_instruction = rv32_write_heap_default::( - &mut tester, - vec![b_limbs], - vec![[F::ZERO; TOTAL_LIMBS]], - opcode_offset + Rv32ModularArithmeticOpcode::SETUP_ISEQ as usize, - ); - tester.execute(&mut chip, &setup_instruction); - - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} + set_and_execute_is_equal( + &mut tester, + &mut chip, + &mut rng, + &modulus, + opcode_offset, + modulus_limbs, + true, + None, + None, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 1); -} + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); + let cols: &mut ModularIsEqualCoreCols<_, READ_LIMBS> = + trace_row.split_at_mut(adapter_width).1.borrow_mut(); + if test_case == 1 { + // test the constraint that c_lt_mark = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(1); + cols.c_lt_mark = F::ONE; + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::ONE; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + cols.b_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.b[READ_LIMBS - 1]; + } else if test_case == 2 { + // test the constraint that b[i] = N[i] for all i when prefix_sum is not 1 or + // lt_marker_sum - is_setup + cols.b[0] = F::from_canonical_u32(2); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } else if test_case == 3 { + // test the constraint that sum_i lt_marker[i] = 2 when is_setup = 1 + cols.b[0] = F::from_canonical_u32(3); + cols.c_lt_mark = F::from_canonical_u8(2); + cols.lt_marker = [F::ZERO; READ_LIMBS]; + cols.lt_marker[READ_LIMBS - 1] = F::from_canonical_u8(2); + cols.lt_marker[0] = F::ONE; + cols.b_lt_diff = modulus_limbs[0] - cols.b[0]; + cols.c_lt_diff = modulus_limbs[READ_LIMBS - 1] - cols.c[READ_LIMBS - 1]; + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_1x32_2() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 2); -} + disable_debug_builder(); + let tester = tester + .build() + .load_and_prank_trace(chip, modify_trace) + .load(bitwise_chip) + .finalize(); + tester.simple_test_with_expected_error(expected_error); + } -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_1x32() { - test_is_equal_setup_bad::<1, 32, 32>(17, secp256k1_coord_prime(), 3); -} + #[test] + fn negative_test_modular_is_equal_1x32() { + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_1_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 1); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_2_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 2); -} + run_negative_is_equal_test::<1, 32, 32>( + secp256k1_coord_prime(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } -#[should_panic] -#[test] -fn test_modular_is_equal_setup_bad_3_3x16() { - test_is_equal_setup_bad::<3, 16, 48>(17, BLS12_381_MODULUS.clone(), 3); + #[test] + fn negative_test_modular_is_equal_3x16() { + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 1, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 2, + VerificationError::OodEvaluationMismatch, + ); + + run_negative_is_equal_test::<3, 16, 48>( + BLS12_381_MODULUS.clone(), + 17, + 3, + VerificationError::OodEvaluationMismatch, + ); + } } diff --git a/extensions/algebra/circuit/src/modular_extension.rs b/extensions/algebra/circuit/src/modular_extension.rs index 18a19becfa..bfaed25682 100644 --- a/extensions/algebra/circuit/src/modular_extension.rs +++ b/extensions/algebra/circuit/src/modular_extension.rs @@ -1,28 +1,37 @@ +use std::array; + use derive_more::derive::From; use num_bigint::BigUint; use openvm_algebra_transpiler::Rv32ModularArithmeticOpcode; use openvm_circuit::{ self, - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bigint::utils::big_uint_to_limbs, + bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, }; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{LocalOpcode, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::{Rv32IsEqualModAdapterChip, Rv32VecHeapAdapterChip}; +use openvm_rv32_adapters::{Rv32IsEqualModAdapterAir, Rv32IsEqualModeAdapterStep}; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; use strum::EnumCount; use crate::modular_chip::{ - ModularAddSubChip, ModularIsEqualChip, ModularIsEqualCoreChip, ModularMulDivChip, + ModularAddSubChip, ModularIsEqualAir, ModularIsEqualChip, ModularIsEqualCoreAir, + ModularIsEqualStep, ModularMulDivChip, }; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct ModularExtension { @@ -30,7 +39,7 @@ pub struct ModularExtension { pub supported_modulus: Vec, } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, AnyEnum, From, InsExecutorE1)] pub enum ModularExtensionExecutor { // 32 limbs prime ModularAddSubRv32_32(ModularAddSubChip), @@ -63,7 +72,11 @@ impl VmExtension for ModularExtension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -75,8 +88,6 @@ impl VmExtension for ModularExtension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; let addsub_opcodes = (Rv32ModularArithmeticOpcode::ADD as usize) ..=(Rv32ModularArithmeticOpcode::SETUP_ADDSUB as usize); @@ -101,28 +112,20 @@ impl VmExtension for ModularExtension { num_limbs: 48, limb_bits: 8, }; - let adapter_chip_32 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); - let adapter_chip_48 = Rv32VecHeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), - ); + + let modulus_limbs = big_uint_to_limbs(&modulus, 8); if bytes <= 32 { let addsub_chip = ModularAddSubChip::new( - adapter_chip_32.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( ModularExtensionExecutor::ModularAddSubRv32_32(addsub_chip), @@ -131,11 +134,15 @@ impl VmExtension for ModularExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = ModularMulDivChip::new( - adapter_chip_32.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( ModularExtensionExecutor::ModularMulDivRv32_32(muldiv_chip), @@ -143,20 +150,36 @@ impl VmExtension for ModularExtension { .clone() .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + execution_bridge.clone(), + memory_bridge.clone(), + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ModularIsEqualCoreAir::new( + modulus.clone(), + bitwise_lu_chip.bus(), + start_offset, + ), ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), + ModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + modulus_limbs, start_offset, + bitwise_lu_chip.clone(), ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( ModularExtensionExecutor::ModularIsEqualRv32_32(isequal_chip), @@ -166,11 +189,15 @@ impl VmExtension for ModularExtension { )?; } else if bytes <= 48 { let addsub_chip = ModularAddSubChip::new( - adapter_chip_48.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( ModularExtensionExecutor::ModularAddSubRv32_48(addsub_chip), @@ -179,11 +206,15 @@ impl VmExtension for ModularExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let muldiv_chip = ModularMulDivChip::new( - adapter_chip_48.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( ModularExtensionExecutor::ModularMulDivRv32_48(muldiv_chip), @@ -191,20 +222,35 @@ impl VmExtension for ModularExtension { .clone() .map(|x| VmOpcode::from_usize(x + start_offset)), )?; + let modulus_limbs = array::from_fn(|i| { + if i < modulus_limbs.len() { + modulus_limbs[i] as u8 + } else { + 0 + } + }); let isequal_chip = ModularIsEqualChip::new( - Rv32IsEqualModAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + ModularIsEqualAir::new( + Rv32IsEqualModAdapterAir::new( + execution_bridge.clone(), + memory_bridge.clone(), + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ModularIsEqualCoreAir::new( + modulus.clone(), + bitwise_lu_chip.bus(), + start_offset, + ), ), - ModularIsEqualCoreChip::new( - modulus.clone(), - bitwise_lu_chip.clone(), + ModularIsEqualStep::new( + Rv32IsEqualModeAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + modulus_limbs, start_offset, + bitwise_lu_chip.clone(), ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( ModularExtensionExecutor::ModularIsEqualRv32_48(isequal_chip), diff --git a/extensions/algebra/moduli-macros/src/lib.rs b/extensions/algebra/moduli-macros/src/lib.rs index 5d8d921f2f..18931d05fb 100644 --- a/extensions/algebra/moduli-macros/src/lib.rs +++ b/extensions/algebra/moduli-macros/src/lib.rs @@ -734,7 +734,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let mut externs = Vec::new(); let mut setups = Vec::new(); - let mut openvm_section = Vec::new(); let mut setup_all_moduli = Vec::new(); // List of all modular limbs in one (that is, with a compile-time known size) array. @@ -746,8 +745,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { for (mod_idx, item) in items.into_iter().enumerate() { let modulus = item.value(); - println!("[init] modulus #{} = {}", mod_idx, modulus); - let modulus_bytes = string_to_bytes(&modulus); let mut limbs = modulus_bytes.len(); let mut block_size = 32; @@ -782,28 +779,8 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { .collect::>() .join(""); - let serialized_modulus = - core::iter::once(1) // 1 for "modulus" - .chain(core::iter::once(mod_idx as u8)) // mod_idx is u8 for now (can make it u32), because we don't know the order of - // variables in the elf - .chain((modulus_bytes.len() as u32).to_le_bytes().iter().copied()) - .chain(modulus_bytes.iter().copied()) - .collect::>(); - let serialized_name = syn::Ident::new( - &format!("OPENVM_SERIALIZED_MODULUS_{}", mod_idx), - span.into(), - ); - let serialized_len = serialized_modulus.len(); let setup_function = syn::Ident::new(&format!("setup_{}", mod_idx), span.into()); - openvm_section.push(quote::quote_spanned! { span.into() => - #[cfg(target_os = "zkvm")] - #[link_section = ".openvm"] - #[no_mangle] - #[used] - static #serialized_name: [u8; #serialized_len] = [#(#serialized_modulus),*]; - }); - for op_type in ["add", "sub", "mul", "div"] { let func_name = syn::Ident::new( &format!("{}_extern_func_{}", op_type, modulus_hex), @@ -857,19 +834,12 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { pub fn #setup_function() { #[cfg(target_os = "zkvm")] { - let mut ptr = 0; - assert_eq!(#serialized_name[ptr], 1); - ptr += 1; - assert_eq!(#serialized_name[ptr], #mod_idx as u8); - ptr += 1; - assert_eq!(#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs); - ptr += 4; - let remaining = &#serialized_name[ptr..]; - // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment. #[repr(C, align(#block_size))] struct AlignedPlaceholder([u8; #limbs]); + const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]); + // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup. // The transpiler will transform this instruction, based on whether `rs2` is `x0`, `x1` or `x2`, into a `SETUP_ADDSUB`, `SETUP_MULDIV` or `SETUP_ISEQ` instruction. let mut uninit: core::mem::MaybeUninit = core::mem::MaybeUninit::uninit(); @@ -880,7 +850,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD ); openvm::platform::custom_insn_r!( @@ -890,7 +860,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = In uninit.as_mut_ptr(), - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV ); unsafe { @@ -903,7 +873,7 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize), rd = InOut tmp, - rs1 = In remaining.as_ptr(), + rs1 = In MODULUS_BYTES.0.as_ptr(), rs2 = Const "x2" // will be parsed as 2 and therefore transpiled to SETUP_ISEQ ); // rd = inout(reg) is necessary because this instruction will write to `rd` register @@ -916,7 +886,6 @@ pub fn moduli_init(input: TokenStream) -> TokenStream { let total_limbs_cnt = two_modular_limbs_flattened_list.len(); let cnt_limbs_list_len = limb_list_borders.len(); TokenStream::from(quote::quote_spanned! { span.into() => - #(#openvm_section)* #[cfg(target_os = "zkvm")] mod openvm_intrinsics_ffi { #(#externs)* diff --git a/extensions/bigint/circuit/Cargo.toml b/extensions/bigint/circuit/Cargo.toml index 09d68a9d1b..7d133ff151 100644 --- a/extensions/bigint/circuit/Cargo.toml +++ b/extensions/bigint/circuit/Cargo.toml @@ -29,6 +29,7 @@ serde.workspace = true openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } openvm-rv32-adapters = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/bigint/circuit/src/extension.rs b/extensions/bigint/circuit/src/extension.rs index b9eeeafd99..fac415c965 100644 --- a/extensions/bigint/circuit/src/extension.rs +++ b/extensions/bigint/circuit/src/extension.rs @@ -5,11 +5,12 @@ use openvm_bigint_transpiler::{ }; use openvm_circuit::{ arch::{ - SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + ExecutionBridge, SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, + VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, @@ -25,6 +26,9 @@ use serde::{Deserialize, Serialize}; use crate::*; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Int256Rv32Config { #[system] @@ -69,7 +73,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { [1 << 8, 32 * (1 << 8)] } -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum Int256Executor { BaseAlu256(Rv32BaseAlu256Chip), LessThan256(Rv32LessThan256Chip), @@ -101,6 +105,8 @@ impl VmExtension for Int256 { program_bus, memory_bridge, } = builder.system_port(); + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker_chip = builder.system_base().range_checker_chip.clone(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -113,8 +119,8 @@ impl VmExtension for Int256 { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; + + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let range_tuple_chip = if let Some(chip) = builder .find_chip::>() @@ -133,66 +139,97 @@ impl VmExtension for Int256 { }; let base_alu_chip = Rv32BaseAlu256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BaseAluCoreAir::new(bitwise_lu_chip.bus(), Rv32BaseAlu256Opcode::CLASS_OFFSET), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32BaseAlu256Opcode::CLASS_OFFSET, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( base_alu_chip, Rv32BaseAlu256Opcode::iter().map(|x| x.global_opcode()), )?; let less_than_chip = Rv32LessThan256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + LessThanCoreAir::new(bitwise_lu_chip.bus(), Rv32LessThan256Opcode::CLASS_OFFSET), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + Rv32LessThan256Opcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( less_than_chip, Rv32LessThan256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_equal_chip = Rv32BranchEqual256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchEqualCoreAir::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + Rv32BranchEqual256Opcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( branch_equal_chip, Rv32BranchEqual256Opcode::iter().map(|x| x.global_opcode()), )?; let branch_less_than_chip = Rv32BranchLessThan256Chip::new( - Rv32HeapBranchAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + BranchLessThanCoreAir::new( + bitwise_lu_chip.bus(), + Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), Rv32BranchLessThan256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( branch_less_than_chip, @@ -200,36 +237,53 @@ impl VmExtension for Int256 { )?; let multiplication_chip = Rv32Multiplication256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + MultiplicationCoreAir::new(*range_tuple_chip.bus(), Rv32Mul256Opcode::CLASS_OFFSET), ), - MultiplicationCoreChip::new(range_tuple_chip, Rv32Mul256Opcode::CLASS_OFFSET), - offline_memory.clone(), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), + range_tuple_chip.clone(), + Rv32Mul256Opcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( multiplication_chip, Rv32Mul256Opcode::iter().map(|x| x.global_opcode()), )?; let shift_chip = Rv32Shift256Chip::new( - Rv32HeapAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - address_bits, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + ), + ShiftCoreAir::new( + bitwise_lu_chip.bus(), + range_checker_chip.bus(), + Rv32Shift256Opcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(pointer_max_bits, bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), - range_checker_chip, + range_checker_chip.clone(), Rv32Shift256Opcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( shift_chip, Rv32Shift256Opcode::iter().map(|x| x.global_opcode()), diff --git a/extensions/bigint/circuit/src/lib.rs b/extensions/bigint/circuit/src/lib.rs index 295ef73db2..ba971f27a5 100644 --- a/extensions/bigint/circuit/src/lib.rs +++ b/extensions/bigint/circuit/src/lib.rs @@ -1,9 +1,15 @@ -use openvm_circuit::{self, arch::VmChipWrapper}; -use openvm_rv32_adapters::{Rv32HeapAdapterChip, Rv32HeapBranchAdapterChip}; +use openvm_circuit::{ + self, + arch::{NewVmChipWrapper, VmAirWrapper}, +}; +use openvm_rv32_adapters::{ + Rv32HeapAdapterAir, Rv32HeapAdapterStep, Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, +}; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV32_CELL_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BaseAluStep, BranchEqualCoreAir, BranchEqualStep, BranchLessThanCoreAir, + BranchLessThanStep, LessThanCoreAir, LessThanStep, MultiplicationCoreAir, MultiplicationStep, + ShiftCoreAir, ShiftStep, }; mod extension; @@ -12,38 +18,74 @@ pub use extension::*; #[cfg(test)] mod tests; -pub type Rv32BaseAlu256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - BaseAluCoreChip, +/// BaseAlu256 +pub type Rv32BaseAlu256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + BaseAluCoreAir, >; +pub type Rv32BaseAlu256Step = BaseAluStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, +>; +pub type Rv32BaseAlu256Chip = NewVmChipWrapper; -pub type Rv32LessThan256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - LessThanCoreChip, +/// LessThan256 +pub type Rv32LessThan256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + LessThanCoreAir, +>; +pub type Rv32LessThan256Step = LessThanStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32LessThan256Chip = NewVmChipWrapper; -pub type Rv32Multiplication256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - MultiplicationCoreChip, +/// Multiplication256 +pub type Rv32Multiplication256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + MultiplicationCoreAir, >; +pub type Rv32Multiplication256Step = MultiplicationStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, +>; +pub type Rv32Multiplication256Chip = + NewVmChipWrapper; -pub type Rv32Shift256Chip = VmChipWrapper< - F, - Rv32HeapAdapterChip, - ShiftCoreChip, +/// Shift256 +pub type Rv32Shift256Air = VmAirWrapper< + Rv32HeapAdapterAir<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + ShiftCoreAir, +>; +pub type Rv32Shift256Step = ShiftStep< + Rv32HeapAdapterStep<2, INT256_NUM_LIMBS, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32Shift256Chip = NewVmChipWrapper; -pub type Rv32BranchEqual256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchEqualCoreChip, +/// BranchEqual256 +pub type Rv32BranchEqual256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchEqualCoreAir, >; +pub type Rv32BranchEqual256Step = + BranchEqualStep, INT256_NUM_LIMBS>; +pub type Rv32BranchEqual256Chip = + NewVmChipWrapper; -pub type Rv32BranchLessThan256Chip = VmChipWrapper< - F, - Rv32HeapBranchAdapterChip, - BranchLessThanCoreChip, +/// BranchLessThan256 +pub type Rv32BranchLessThan256Air = VmAirWrapper< + Rv32HeapBranchAdapterAir<2, INT256_NUM_LIMBS>, + BranchLessThanCoreAir, +>; +pub type Rv32BranchLessThan256Step = BranchLessThanStep< + Rv32HeapBranchAdapterStep<2, INT256_NUM_LIMBS>, + INT256_NUM_LIMBS, + RV32_CELL_BITS, >; +pub type Rv32BranchLessThan256Chip = + NewVmChipWrapper; diff --git a/extensions/bigint/circuit/src/tests.rs b/extensions/bigint/circuit/src/tests.rs index 0e26352410..abaf29345c 100644 --- a/extensions/bigint/circuit/src/tests.rs +++ b/extensions/bigint/circuit/src/tests.rs @@ -5,7 +5,7 @@ use openvm_bigint_transpiler::{ use openvm_circuit::{ arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS}, - InstructionExecutor, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -13,171 +13,176 @@ use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; -use openvm_instructions::{program::PC_BITS, riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_instructions::{ + program::{DEFAULT_PC_STEP, PC_BITS}, + riscv::RV32_CELL_BITS, + LocalOpcode, +}; use openvm_rv32_adapters::{ - rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterChip, - Rv32HeapBranchAdapterChip, + rv32_heap_branch_default, rv32_write_heap_default, Rv32HeapAdapterAir, Rv32HeapAdapterStep, + Rv32HeapBranchAdapterAir, Rv32HeapBranchAdapterStep, }; use openvm_rv32im_circuit::{ adapters::{INT256_NUM_LIMBS, RV_B_TYPE_IMM_BITS}, - BaseAluCoreChip, BranchEqualCoreChip, BranchLessThanCoreChip, LessThanCoreChip, - MultiplicationCoreChip, ShiftCoreChip, + BaseAluCoreAir, BranchEqualCoreAir, BranchLessThanCoreAir, LessThanCoreAir, + MultiplicationCoreAir, ShiftCoreAir, }; use openvm_rv32im_transpiler::{ - BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, ShiftOpcode, + BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, MulOpcode, ShiftOpcode, }; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ Rv32BaseAlu256Chip, Rv32BranchEqual256Chip, Rv32BranchLessThan256Chip, Rv32LessThan256Chip, Rv32Multiplication256Chip, Rv32Shift256Chip, }; +use crate::{ + Rv32BaseAlu256Step, Rv32BranchEqual256Step, Rv32BranchLessThan256Step, Rv32LessThan256Step, + Rv32Multiplication256Step, Rv32Shift256Step, +}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); #[allow(clippy::type_complexity)] -fn run_int_256_rand_execute>( - opcode: usize, - num_ops: usize, - executor: &mut E, +fn set_and_execute_rand>( tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: usize, branch_fn: Option bool>, ) { - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let mut rng = create_seeded_rng(); let branch = branch_fn.is_some(); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - if branch { - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - let instruction = rv32_heap_branch_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - imm as isize, - opcode, - ); - - tester.execute_with_pc( - executor, - &instruction, - rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), - ); - - let cmp_result = branch_fn.unwrap()(opcode, &b, &c); - let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; - let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; - assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); - } else { - let instruction = rv32_write_heap_default( - tester, - vec![b.map(F::from_canonical_u32)], - vec![c.map(F::from_canonical_u32)], - opcode, - ); - tester.execute(executor, &instruction); - } + let b = generate_long_number::(rng); + let c = generate_long_number::(rng); + if branch { + let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); + let instruction = rv32_heap_branch_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + imm as isize, + opcode, + ); + + tester.execute_with_pc( + chip, + &instruction, + rng.gen_range((ABS_MAX_BRANCH as u32)..(1 << (PC_BITS - 1))), + ); + + let cmp_result = branch_fn.unwrap()(opcode, &b, &c); + let from_pc = tester.execution.last_from_pc().as_canonical_u32() as i32; + let to_pc = tester.execution.last_to_pc().as_canonical_u32() as i32; + assert_eq!(to_pc, from_pc + if cmp_result { imm } else { 4 }); + } else { + let instruction = rv32_write_heap_default( + tester, + vec![b.map(F::from_canonical_u32)], + vec![c.map(F::from_canonical_u32)], + opcode, + ); + tester.execute(chip, &instruction); } } +#[test_case(BaseAluOpcode::ADD, 24)] +#[test_case(BaseAluOpcode::SUB, 24)] +#[test_case(BaseAluOpcode::XOR, 24)] +#[test_case(BaseAluOpcode::OR, 24)] +#[test_case(BaseAluOpcode::AND, 24)] fn run_alu_256_rand_test(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BaseAlu256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAlu256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BaseAluCoreAir::new(bitwise_bus, offset), + ), + Rv32BaseAlu256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - BaseAluCoreChip::new(bitwise_chip.clone(), Rv32BaseAlu256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32BaseAlu256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn alu_256_add_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::ADD, 24); -} - -#[test] -fn alu_256_sub_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::SUB, 24); -} - -#[test] -fn alu_256_xor_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::XOR, 24); -} - -#[test] -fn alu_256_or_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::OR, 24); -} - -#[test] -fn alu_256_and_rand_test() { - run_alu_256_rand_test(BaseAluOpcode::AND, 24); -} - +#[test_case(LessThanOpcode::SLT, 24)] +#[test_case(LessThanOpcode::SLTU, 24)] fn run_lt_256_rand_test(opcode: LessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32LessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32LessThan256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + LessThanCoreAir::new(bitwise_bus, offset), + ), + Rv32LessThan256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), + offset, ), - LessThanCoreChip::new(bitwise_chip.clone(), Rv32LessThan256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32LessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn lt_256_slt_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLT, 24); -} - -#[test] -fn lt_256_sltu_rand_test() { - run_lt_256_rand_test(LessThanOpcode::SLTU, 24); -} +#[test_case(MulOpcode::MUL, 24)] +fn run_mul_256_rand_test(opcode: MulOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Mul256Opcode::CLASS_OFFSET; -fn run_mul_256_rand_test(num_ops: usize) { let range_tuple_bus = RangeTupleCheckerBus::new( RANGE_TUPLE_CHECKER_BUS, [ @@ -185,105 +190,120 @@ fn run_mul_256_rand_test(num_ops: usize) { (INT256_NUM_LIMBS * (1 << RV32_CELL_BITS)) as u32, ], ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Multiplication256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + MultiplicationCoreAir::new(range_tuple_bus, offset), + ), + Rv32Multiplication256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + range_tuple_chip.clone(), + offset, ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), Rv32Mul256Opcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - Rv32Mul256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } let tester = tester .build() .load(chip) - .load(range_tuple_checker) + .load(range_tuple_chip) .load(bitwise_chip) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn mul_256_rand_test() { - run_mul_256_rand_test(24); -} - +#[test_case(ShiftOpcode::SLL, 24)] +#[test_case(ShiftOpcode::SRL, 24)] +#[test_case(ShiftOpcode::SRA, 24)] fn run_shift_256_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32Shift256Opcode::CLASS_OFFSET; + + let range_checker_chip = tester.range_checker(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32Shift256Chip::::new( - Rv32HeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + ShiftCoreAir::new(bitwise_bus, range_checker_chip.bus(), offset), ), - ShiftCoreChip::new( + Rv32Shift256Step::new( + Rv32HeapAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), - Rv32Shift256Opcode::CLASS_OFFSET, + range_checker_chip.clone(), + offset, ), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - run_int_256_rand_execute( - opcode.local_usize() + Rv32Shift256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - None, - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + None, + ); + } + + drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn shift_256_sll_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SLL, 24); -} - -#[test] -fn shift_256_srl_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRL, 24); -} - -#[test] -fn shift_256_sra_rand_test() { - run_shift_256_rand_test(ShiftOpcode::SRA, 24); -} - +#[test_case(BranchEqualOpcode::BEQ, 24)] +#[test_case(BranchEqualOpcode::BNE, 24)] fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchEqual256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut chip = Rv32BranchEqual256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchEqualCoreAir::new(offset, DEFAULT_PC_STEP), + ), + Rv32BranchEqual256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), + offset, + DEFAULT_PC_STEP, ), - BranchEqualCoreChip::new(Rv32BranchEqual256Opcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { @@ -294,93 +314,79 @@ fn run_beq_256_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { == BranchEqualOpcode::BNE.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET) }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchEqual256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn beq_256_beq_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BEQ, 24); -} - -#[test] -fn beq_256_bne_rand_test() { - run_beq_256_rand_test(BranchEqualOpcode::BNE, 24); -} - +#[test_case(BranchLessThanOpcode::BLT, 24)] +#[test_case(BranchLessThanOpcode::BLTU, 24)] +#[test_case(BranchLessThanOpcode::BGE, 24)] +#[test_case(BranchLessThanOpcode::BGEU, 24)] fn run_blt_256_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let offset = Rv32BranchLessThan256Opcode::CLASS_OFFSET; + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); let mut chip = Rv32BranchLessThan256Chip::::new( - Rv32HeapBranchAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), + VmAirWrapper::new( + Rv32HeapBranchAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + ), + BranchLessThanCoreAir::new(bitwise_bus, offset), ), - BranchLessThanCoreChip::new( + Rv32BranchLessThan256Step::new( + Rv32HeapBranchAdapterStep::new(tester.address_bits(), bitwise_chip.clone()), bitwise_chip.clone(), - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + offset, ), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - - let branch_fn = |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| { - let opcode = - BranchLessThanOpcode::from_usize(opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET); - let (is_ge, is_signed) = match opcode { - BranchLessThanOpcode::BLT => (false, true), - BranchLessThanOpcode::BLTU => (false, false), - BranchLessThanOpcode::BGE => (true, true), - BranchLessThanOpcode::BGEU => (true, false), - }; - let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; - for (x, y) in x.iter().rev().zip(y.iter().rev()) { - if x != y { - return (x < y) ^ x_sign ^ y_sign ^ is_ge; + let branch_fn = + |opcode: usize, x: &[u32; INT256_NUM_LIMBS], y: &[u32; INT256_NUM_LIMBS]| -> bool { + let opcode = BranchLessThanOpcode::from_usize( + opcode - Rv32BranchLessThan256Opcode::CLASS_OFFSET, + ); + let (is_ge, is_signed) = match opcode { + BranchLessThanOpcode::BLT => (false, true), + BranchLessThanOpcode::BLTU => (false, false), + BranchLessThanOpcode::BGE => (true, true), + BranchLessThanOpcode::BGEU => (true, false), + }; + let x_sign = x[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + let y_sign = y[INT256_NUM_LIMBS - 1] >> (RV32_CELL_BITS - 1) != 0 && is_signed; + for (x, y) in x.iter().rev().zip(y.iter().rev()) { + if x != y { + return (x < y) ^ x_sign ^ y_sign ^ is_ge; + } } - } - is_ge - }; + is_ge + }; - run_int_256_rand_execute( - opcode.local_usize() + Rv32BranchLessThan256Opcode::CLASS_OFFSET, - num_ops, - &mut chip, - &mut tester, - Some(branch_fn), - ); + for _ in 0..num_ops { + set_and_execute_rand( + &mut tester, + &mut chip, + &mut rng, + opcode.local_usize() + offset, + Some(branch_fn), + ); + } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } - -#[test] -fn blt_256_blt_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLT, 24); -} - -#[test] -fn blt_256_bltu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BLTU, 24); -} - -#[test] -fn blt_256_bge_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGE, 24); -} - -#[test] -fn blt_256_bgeu_rand_test() { - run_blt_256_rand_test(BranchLessThanOpcode::BGEU, 24); -} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs index 24bcc52ef3..32ece53961 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/add_ne.rs @@ -1,7 +1,24 @@ use std::{cell::RefCell, rc::Rc}; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, +}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; // Assumes that (x1, y1), (x2, y2) both lie on the curve and are not the identity point. // Further assumes that x1, x2 are not equal in the coordinate field. @@ -26,3 +43,58 @@ pub fn ec_add_ne_expr( let builder = builder.borrow().clone(); FieldExpr::new(builder, range_bus, true) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. + +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] +pub struct EcAddNeChip( + pub WeierstrassChip, +); + +impl + EcAddNeChip +{ + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + height: usize, + ) -> Self { + let expr = ec_add_ne_expr(config, range_checker.bus()); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + ]; + + let air = WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ); + + let step = WeierstrassStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + vec![], + range_checker, + "EcAddNe", + false, + ); + Self(WeierstrassChip::new(air, step, height, mem_helper)) + } +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/double.rs b/extensions/ecc/circuit/src/weierstrass_chip/double.rs index 0ae55f2df7..e0478600c0 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/double.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/double.rs @@ -2,8 +2,25 @@ use std::{cell::RefCell, rc::Rc}; use num_bigint::BigUint; use num_traits::One; -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr, FieldVariable}; +use openvm_circuit::{ + arch::ExecutionBridge, + system::memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, +}; +use openvm_circuit_derive::{InsExecutorE1, InstructionExecutor}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, + Chip, ChipUsageGetter, +}; +use openvm_ecc_transpiler::Rv32WeierstrassOpcode; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_mod_circuit_builder::{ + ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreAir, FieldVariable, +}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{WeierstrassAir, WeierstrassChip, WeierstrassStep}; pub fn ec_double_ne_expr( config: ExprBuilderConfig, // The coordinate field. @@ -34,3 +51,58 @@ pub fn ec_double_ne_expr( let builder = builder.borrow().clone(); FieldExpr::new_with_setup_values(builder, range_bus, true, vec![a_biguint]) } + +/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. +/// BLOCKS: how many blocks do we need to represent one input or output +/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per +/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. +#[derive(Chip, ChipUsageGetter, InstructionExecutor, InsExecutorE1)] +pub struct EcDoubleChip( + pub WeierstrassChip, +); + +impl + EcDoubleChip +{ + pub fn new( + execution_bridge: ExecutionBridge, + memory_bridge: MemoryBridge, + mem_helper: SharedMemoryHelper, + pointer_max_bits: usize, + config: ExprBuilderConfig, + offset: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + range_checker: SharedVariableRangeCheckerChip, + a_biguint: BigUint, + height: usize, + ) -> Self { + let expr = ec_double_ne_expr(config, range_checker.bus(), a_biguint); + + let local_opcode_idx = vec![ + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + ]; + + let air = WeierstrassAir::new( + Rv32VecHeapAdapterAir::new( + execution_bridge, + memory_bridge, + bitwise_lookup_chip.bus(), + pointer_max_bits, + ), + FieldExpressionCoreAir::new(expr.clone(), offset, local_opcode_idx.clone(), vec![]), + ); + + let step = WeierstrassStep::new( + Rv32VecHeapAdapterStep::new(pointer_max_bits, bitwise_lookup_chip), + expr, + offset, + local_opcode_idx, + vec![], + range_checker, + "EcDouble", + true, + ); + Self(WeierstrassChip::new(air, step, height, mem_helper)) + } +} diff --git a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs index 0bcee1facf..2a837b6aac 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/mod.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/mod.rs @@ -1,99 +1,38 @@ mod add_ne; mod double; -use std::sync::Arc; - pub use add_ne::*; pub use double::*; #[cfg(test)] mod tests; -use std::sync::Mutex; - -use num_bigint::BigUint; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_ecc_transpiler::Rv32WeierstrassOpcode; -use openvm_mod_circuit_builder::{ExprBuilderConfig, FieldExpressionCoreChip}; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -/// BLOCK_SIZE: how many cells do we read at a time, must be a power of 2. -/// BLOCKS: how many blocks do we need to represent one input or output -/// For example, for bls12_381, BLOCK_SIZE = 16, each element has 3 blocks and with two elements per -/// input AffinePoint, BLOCKS = 6. For secp256k1, BLOCK_SIZE = 32, BLOCKS = 2. -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcAddNeChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcAddNeChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_add_ne_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_ADD_NE as usize, - Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, - ], - vec![], - range_checker, - "EcAddNe", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcDoubleChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - EcDoubleChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - offset: usize, - a: BigUint, - offline_memory: Arc>>, - ) -> Self { - let expr = ec_double_ne_expr(config, range_checker.bus(), a); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![ - Rv32WeierstrassOpcode::EC_DOUBLE as usize, - Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, - ], - vec![], - range_checker, - "EcDouble", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; +use openvm_mod_circuit_builder::{FieldExpressionCoreAir, FieldExpressionStep}; +use openvm_rv32_adapters::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterStep}; + +pub(crate) type WeierstrassAir< + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = VmAirWrapper< + Rv32VecHeapAdapterAir, + FieldExpressionCoreAir, +>; + +pub(crate) type WeierstrassStep< + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = FieldExpressionStep>; + +pub(crate) type WeierstrassChip< + F, + const NUM_READS: usize, + const BLOCKS: usize, + const BLOCK_SIZE: usize, +> = NewVmChipWrapper< + F, + WeierstrassAir, + WeierstrassStep, +>; diff --git a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs index 213918ec2e..2809dc23c1 100644 --- a/extensions/ecc/circuit/src/weierstrass_chip/tests.rs +++ b/extensions/ecc/circuit/src/weierstrass_chip/tests.rs @@ -10,7 +10,7 @@ use openvm_circuit_primitives::{ use openvm_ecc_transpiler::Rv32WeierstrassOpcode; use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; use openvm_mod_circuit_builder::{test_utils::biguint_to_limbs, ExprBuilderConfig, FieldExpr}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; +use openvm_rv32_adapters::rv32_write_heap_default; use openvm_stark_backend::p3_field::FieldAlgebra; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -19,6 +19,7 @@ use super::{EcAddNeChip, EcDoubleChip}; const NUM_LIMBS: usize = 32; const LIMB_BITS: usize = 8; const BLOCK_SIZE: usize = 32; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; lazy_static::lazy_static! { @@ -87,21 +88,20 @@ fn test_add_ne() { }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + + let mut chip = EcAddNeChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcAddNeChip::new( - adapter, config, Rv32WeierstrassOpcode::CLASS_OFFSET, + bitwise_chip.clone(), tester.range_checker(), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, ); - assert_eq!(chip.0.core.expr().builder.num_variables, 3); // lambda, x3, y3 + + assert_eq!(chip.0.step.expr.builder.num_variables, 3); // lambda, x3, y3 let (p1_x, p1_y) = SampleEcPoints[0].clone(); let (p2_x, p2_y) = SampleEcPoints[1].clone(); @@ -117,21 +117,21 @@ fn test_add_ne() { let r = chip .0 - .core - .expr() + .step + .expr .execute(vec![p1_x, p1_y, p2_x, p2_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[2].0); assert_eq!(r[2], SampleEcPoints[2].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(chip.0.core.expr()).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.expr).try_into().unwrap(); let mut one_limbs = [BabyBear::ONE; NUM_LIMBS]; one_limbs[0] = BabyBear::ONE; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, one_limbs], // inputs[0] = prime, others doesn't matter vec![one_limbs, one_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -139,7 +139,7 @@ fn test_add_ne() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![p2_x_limbs, p2_y_limbs], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::EC_ADD_NE as usize, ); tester.execute(&mut chip, &instruction); @@ -159,12 +159,17 @@ fn test_double() { }; let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + let mut chip = EcDoubleChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, bitwise_chip.clone(), + tester.range_checker(), + BigUint::zero(), + MAX_INS_CAPACITY, ); let (p1_x, p1_y) = SampleEcPoints[1].clone(); @@ -173,29 +178,21 @@ fn test_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - BigUint::zero(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(chip.0.step.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = chip.0.step.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 assert_eq!(r[1], SampleEcPoints[3].0); assert_eq!(r[2], SampleEcPoints[3].1); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.expr).try_into().unwrap(); let a_limbs = [BabyBear::ZERO; NUM_LIMBS]; let setup_instruction = rv32_write_heap_default( &mut tester, vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -203,7 +200,7 @@ fn test_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, ); tester.execute(&mut chip, &instruction); @@ -227,12 +224,18 @@ fn test_p256_double() { .unwrap(); let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), + + let mut chip = EcDoubleChip::::new( + tester.execution_bridge(), tester.memory_bridge(), + tester.memory_helper(), tester.address_bits(), + config, + Rv32WeierstrassOpcode::CLASS_OFFSET, bitwise_chip.clone(), + tester.range_checker(), + a.clone(), + MAX_INS_CAPACITY, ); // Testing data from: http://point-at-infinity.org/ecc/nisttv @@ -251,17 +254,9 @@ fn test_p256_double() { let p1_y_limbs = biguint_to_limbs::(p1_y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); - let mut chip = EcDoubleChip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - config, - Rv32WeierstrassOpcode::CLASS_OFFSET, - a.clone(), - tester.offline_memory_mutex_arc(), - ); - assert_eq!(chip.0.core.air.expr.builder.num_variables, 3); // lambda, x3, y3 + assert_eq!(chip.0.step.expr.builder.num_variables, 3); // lambda, x3, y3 - let r = chip.0.core.air.expr.execute(vec![p1_x, p1_y], vec![true]); + let r = chip.0.step.expr.execute(vec![p1_x, p1_y], vec![true]); assert_eq!(r.len(), 3); // lambda, x3, y3 let expected_double_x = BigUint::from_str_radix( "7CF27B188D034F7E8A52380304B51AC3C08969E277F21B35A60B48FC47669978", @@ -276,7 +271,7 @@ fn test_p256_double() { assert_eq!(r[1], expected_double_x); assert_eq!(r[2], expected_double_y); - let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.core.air.expr).try_into().unwrap(); + let prime_limbs: [BabyBear; NUM_LIMBS] = prime_limbs(&chip.0.step.expr).try_into().unwrap(); let a_limbs = biguint_to_limbs::(a.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32); let setup_instruction = rv32_write_heap_default( @@ -284,7 +279,7 @@ fn test_p256_double() { vec![prime_limbs, a_limbs], /* inputs[0] = prime, inputs[1] = a coeff of weierstrass * equation */ vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize, ); tester.execute(&mut chip, &setup_instruction); @@ -292,9 +287,12 @@ fn test_p256_double() { &mut tester, vec![p1_x_limbs, p1_y_limbs], vec![], - chip.0.core.air.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, + chip.0.step.offset + Rv32WeierstrassOpcode::EC_DOUBLE as usize, ); + tester.execute(&mut chip, &instruction); + // Adding another row to make sure there are dummy rows, and that the dummy row constraints are + // satisfied tester.execute(&mut chip, &instruction); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); diff --git a/extensions/ecc/circuit/src/weierstrass_extension.rs b/extensions/ecc/circuit/src/weierstrass_extension.rs index c5b23ccd0d..c2ce65406c 100644 --- a/extensions/ecc/circuit/src/weierstrass_extension.rs +++ b/extensions/ecc/circuit/src/weierstrass_extension.rs @@ -4,10 +4,12 @@ use num_traits::{FromPrimitive, Zero}; use once_cell::sync::Lazy; use openvm_algebra_guest::IntMod; use openvm_circuit::{ - arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, + arch::{ + ExecutionBridge, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; @@ -19,7 +21,6 @@ use openvm_ecc_guest::{ use openvm_ecc_transpiler::{EccPhantom, Rv32WeierstrassOpcode}; use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode}; use openvm_mod_circuit_builder::ExprBuilderConfig; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use serde_with::{serde_as, DisplayFromStr}; @@ -27,6 +28,9 @@ use strum::EnumCount; use super::{EcAddNeChip, EcDoubleChip}; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[serde_as] #[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)] pub struct CurveConfig { @@ -63,7 +67,7 @@ pub struct WeierstrassExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1)] pub enum WeierstrassExtensionExecutor { // 32 limbs prime EcAddNeRv32_32(EcAddNeChip), @@ -93,6 +97,11 @@ impl VmExtension for WeierstrassExtension { program_bus, memory_bridge, } = builder.system_port(); + + let execution_bridge = ExecutionBridge::new(execution_bus, program_bus); + let range_checker = builder.system_base().range_checker_chip.clone(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -104,9 +113,7 @@ impl VmExtension for WeierstrassExtension { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - let pointer_bits = builder.system_config().memory_config.pointer_max_bits; + let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize) ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize); let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize) @@ -128,18 +135,17 @@ impl VmExtension for WeierstrassExtension { }; if bytes <= 32 { let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); + inventory.add_executor( WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip), ec_add_ne_opcodes @@ -147,18 +153,16 @@ impl VmExtension for WeierstrassExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config32.clone(), start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), curve.a.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip), @@ -168,18 +172,17 @@ impl VmExtension for WeierstrassExtension { )?; } else if bytes <= 48 { let add_ne_chip = EcAddNeChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), range_checker.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); + inventory.add_executor( WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip), ec_add_ne_opcodes @@ -187,18 +190,16 @@ impl VmExtension for WeierstrassExtension { .map(|x| VmOpcode::from_usize(x + start_offset)), )?; let double_chip = EcDoubleChip::new( - Rv32VecHeapAdapterChip::::new( - execution_bus, - program_bus, - memory_bridge, - pointer_bits, - bitwise_lu_chip.clone(), - ), - range_checker.clone(), + execution_bridge.clone(), + memory_bridge.clone(), + builder.system_base().memory_controller.helper(), + pointer_max_bits, config48.clone(), start_offset, + bitwise_lu_chip.clone(), + range_checker.clone(), curve.a.clone(), - offline_memory.clone(), + MAX_INS_CAPACITY, ); inventory.add_executor( WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip), @@ -236,11 +237,14 @@ pub(crate) mod phantom { use num_traits::{FromPrimitive, One}; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_ecc_guest::weierstrass::DecompressionHint; - use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant}; - use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register; + use openvm_instructions::{ + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + PhantomDiscriminant, + }; + use openvm_rv32im_circuit::adapters::new_read_rv32_register; use openvm_stark_backend::p3_field::PrimeField32; use rand::{rngs::StdRng, SeedableRng}; @@ -260,11 +264,11 @@ pub(crate) mod phantom { impl PhantomSubExecutor for DecompressHintSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()> { let c_idx = c_upper as usize; @@ -275,7 +279,7 @@ pub(crate) mod phantom { ); } let curve = &self.supported_curves[c_idx]; - let rs1 = unsafe_read_rv32_register(memory, a); + let rs1 = new_read_rv32_register(memory, RV32_REGISTER_AS, a); let num_limbs: usize = if curve.modulus.bits().div_ceil(8) <= 32 { 32 } else if curve.modulus.bits().div_ceil(8) <= 48 { @@ -283,21 +287,15 @@ pub(crate) mod phantom { } else { bail!("Modulus too large") }; - let mut x_limbs: Vec = Vec::with_capacity(num_limbs); - for i in 0..num_limbs { - let limb = memory.unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + i as u32), - ); - x_limbs.push(limb.as_canonical_u32() as u8); - } + let x_limbs: Vec = memory + .memory + .read_range_generic((RV32_MEMORY_AS, rs1), num_limbs); let x = BigUint::from_bytes_le(&x_limbs); - let rs2 = unsafe_read_rv32_register(memory, b); - let rec_id = memory.unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2), - ); - let hint = self.decompress_point(x, rec_id.as_canonical_u32() & 1 == 1, c_idx); + let rs2 = new_read_rv32_register(memory, RV32_REGISTER_AS, b); + let rec_id = memory + .memory + .read_range_generic::((RV32_MEMORY_AS, rs2), 1)[0]; + let hint = self.decompress_point(x, rec_id & 1 == 1, c_idx); let hint_bytes = once(F::from_bool(hint.possible)) .chain(repeat(F::ZERO)) .take(4) @@ -311,6 +309,7 @@ pub(crate) mod phantom { ) .collect(); streams.hint_stream = hint_bytes; + Ok(()) } } @@ -442,11 +441,11 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NonQrHintSubEx { fn phantom_execute( &mut self, - _: &MemoryController, + _: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { let c_idx = c_upper as usize; diff --git a/extensions/keccak256/circuit/src/extension.rs b/extensions/keccak256/circuit/src/extension.rs index d24681fb55..616de6bfe5 100644 --- a/extensions/keccak256/circuit/src/extension.rs +++ b/extensions/keccak256/circuit/src/extension.rs @@ -1,3 +1,5 @@ +use std::result::Result; + use derive_more::derive::From; use openvm_circuit::{ arch::{ @@ -5,7 +7,7 @@ use openvm_circuit::{ }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::BitwiseOperationLookupBus; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::*; @@ -19,6 +21,9 @@ use strum::IntoEnumIterator; use crate::*; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Keccak256Rv32Config { #[system] @@ -48,7 +53,7 @@ impl Default for Keccak256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Keccak256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1)] pub enum Keccak256Executor { Keccak256(KeccakVmChip), } @@ -68,11 +73,8 @@ impl VmExtension for Keccak256 { builder: &mut VmInventoryBuilder, ) -> Result, VmInventoryError> { let mut inventory = VmInventory::new(); - let SystemPort { - execution_bus, - program_bus, - memory_bridge, - } = builder.system_port(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -84,17 +86,27 @@ impl VmExtension for Keccak256 { inventory.add_periphery_chip(chip.clone()); chip }; - let offline_memory = builder.system_base().offline_memory(); - let address_bits = builder.system_config().memory_config.pointer_max_bits; - let keccak_chip = KeccakVmChip::new( + let SystemPort { execution_bus, program_bus, memory_bridge, - address_bits, - bitwise_lu_chip, - Rv32KeccakOpcode::CLASS_OFFSET, - offline_memory, + } = builder.system_port(); + let keccak_chip = KeccakVmChip::new( + KeccakVmAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + pointer_max_bits, + Rv32KeccakOpcode::CLASS_OFFSET, + ), + KeccakVmStep::new( + bitwise_lu_chip.clone(), + Rv32KeccakOpcode::CLASS_OFFSET, + pointer_max_bits, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( keccak_chip, diff --git a/extensions/keccak256/circuit/src/lib.rs b/extensions/keccak256/circuit/src/lib.rs index c9fd1c9f5a..2f602d685c 100644 --- a/extensions/keccak256/circuit/src/lib.rs +++ b/extensions/keccak256/circuit/src/lib.rs @@ -1,17 +1,10 @@ //! Stateful keccak256 hasher. Handles full keccak sponge (padding, absorb, keccak-f) on //! variable length inputs read from VM memory. -use std::{ - array::from_fn, - cmp::min, - sync::{Arc, Mutex}, -}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_stark_backend::p3_field::PrimeField32; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; +use p3_keccak_air::NUM_ROUNDS; use tiny_keccak::{Hasher, Keccak}; -use utils::num_keccak_f; pub mod air; pub mod columns; @@ -26,17 +19,23 @@ mod tests; pub use air::KeccakVmAir; use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor}, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory, RecordId}, - program::ProgramBus, + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionBridge, NewVmChipWrapper, Result, StepExecutorE1, VmStateMut, }, + system::memory::online::GuestMemory, }; use openvm_instructions::{ - instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, }; use openvm_keccak256_transpiler::Rv32KeccakOpcode; -use openvm_rv32im_circuit::adapters::read_rv32_register; +use openvm_rv32im_circuit::adapters::{ + memory_read_from_state, memory_write_from_state, new_read_rv32_register_from_state, +}; +use utils::num_keccak_f; // ==== Constants for register/memory adapter ==== /// Register reads to get dst, src, len @@ -69,75 +68,38 @@ pub const KECCAK_DIGEST_BYTES: usize = 32; /// Number of 64-bit digest limbs. pub const KECCAK_DIGEST_U64S: usize = KECCAK_DIGEST_BYTES / 8; -pub struct KeccakVmChip { - pub air: KeccakVmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - - offset: usize, +pub type KeccakVmChip = NewVmChipWrapper; - offline_memory: Arc>>, -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakRecord { - pub pc: F, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_blocks: Vec, - pub digest_writes: [RecordId; KECCAK_DIGEST_WRITES], -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct KeccakInputBlock { - /// Memory reads for non-padding bytes in this block. Length is at most [KECCAK_RATE_BYTES / - /// KECCAK_WORD_SIZE]. - pub reads: Vec, - /// Index in `reads` of the memory read for < KECCAK_WORD_SIZE bytes, if any. - pub partial_read_idx: Option, - /// Bytes with padding. Can be derived from `bytes_read` but we store for convenience. - #[serde(with = "BigArray")] - pub padded_bytes: [u8; KECCAK_RATE_BYTES], - pub remaining_len: usize, - pub src: usize, - pub is_new_start: bool, +//#[derive(derive_new::new)] +pub struct KeccakVmStep { + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub offset: usize, + pub pointer_max_bits: usize, } -impl KeccakVmChip { +impl KeccakVmStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, offset: usize, - offline_memory: Arc>>, + pointer_max_bits: usize, ) -> Self { Self { - air: KeccakVmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - offset, - ), bitwise_lookup_chip, - records: Vec::new(), offset, - offline_memory, + pointer_max_bits, } } } -impl InstructionExecutor for KeccakVmChip { - fn execute( +impl StepExecutorE1 for KeccakVmStep { + fn execute_e1( &mut self, - memory: &mut MemoryController, + state: &mut VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { let &Instruction { opcode, a, @@ -147,140 +109,91 @@ impl InstructionExecutor for KeccakVmChip { e, .. } = instruction; - let local_opcode = Rv32KeccakOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - debug_assert_eq!(local_opcode, Rv32KeccakOpcode::KECCAK256); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); - let mut timestamp_delta = 3; - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); - } + debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + let dst = new_read_rv32_register_from_state(state, d, a.as_canonical_u32()); + let src = new_read_rv32_register_from_state(state, d, b.as_canonical_u32()); + let len = new_read_rv32_register_from_state(state, d, c.as_canonical_u32()); - let mut remaining_len = len as usize; - let num_blocks = num_keccak_f(remaining_len); - let mut input_blocks = Vec::with_capacity(num_blocks); let mut hasher = Keccak::v256(); - let mut src = src as usize; - for block_idx in 0..num_blocks { - if block_idx != 0 { - memory.increment_timestamp_by(KECCAK_REGISTER_READS as u32); - timestamp_delta += KECCAK_REGISTER_READS as u32; - } - let mut reads = Vec::with_capacity(KECCAK_RATE_BYTES); + // TODO(ayush): read in a single call + let mut message = Vec::with_capacity(len as usize); + for offset in (0..len as usize).step_by(KECCAK_WORD_SIZE) { + let read = memory_read_from_state::<_, KECCAK_WORD_SIZE>(state, e, src + offset as u32); + let copy_len = std::cmp::min(KECCAK_WORD_SIZE, (len as usize) - offset); + message.extend_from_slice(&read[..copy_len]); + } + hasher.update(&message); - let mut partial_read_idx = None; - let mut bytes = [0u8; KECCAK_RATE_BYTES]; - for i in (0..KECCAK_RATE_BYTES).step_by(KECCAK_WORD_SIZE) { - if i < remaining_len { - let read = - memory.read::(e, F::from_canonical_usize(src + i)); + let mut output = [0u8; 32]; + hasher.finalize(&mut output); + memory_write_from_state(state, e, dst, &output); - let chunk = read.1.map(|x| { - x.as_canonical_u32() - .try_into() - .expect("Memory cell not a byte") - }); - let copy_len = min(KECCAK_WORD_SIZE, remaining_len - i); - if copy_len != KECCAK_WORD_SIZE { - partial_read_idx = Some(reads.len()); - } - bytes[i..i + copy_len].copy_from_slice(&chunk[..copy_len]); - reads.push(read.0); - } else { - memory.increment_timestamp(); - } - timestamp_delta += 1; - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - let mut block = KeccakInputBlock { - reads, - partial_read_idx, - padded_bytes: bytes, - remaining_len, - src, - is_new_start: block_idx == 0, - }; - if block_idx != num_blocks - 1 { - src += KECCAK_RATE_BYTES; - remaining_len -= KECCAK_RATE_BYTES; - hasher.update(&block.padded_bytes); - } else { - // handle padding here since it is convenient - debug_assert!(remaining_len < KECCAK_RATE_BYTES); - hasher.update(&block.padded_bytes[..remaining_len]); + Ok(()) + } - if remaining_len == KECCAK_RATE_BYTES - 1 { - block.padded_bytes[remaining_len] = 0b1000_0001; - } else { - block.padded_bytes[remaining_len] = 0x01; - block.padded_bytes[KECCAK_RATE_BYTES - 1] = 0x80; - } - } - input_blocks.push(block); - } - let mut output = [0u8; 32]; - hasher.finalize(&mut output); - let dst = dst as usize; - let digest_writes: [_; KECCAK_DIGEST_WRITES] = from_fn(|i| { - timestamp_delta += 1; - memory - .write::( - e, - F::from_canonical_usize(dst + i * KECCAK_WORD_SIZE), - from_fn(|j| F::from_canonical_u8(output[i * KECCAK_WORD_SIZE + j])), - ) - .0 - }); - tracing::trace!("[runtime] keccak256 output: {:?}", output); + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); - let record = KeccakRecord { - pc: F::from_canonical_u32(from_state.pc), - dst_read, - src_read, - len_read, - input_blocks, - digest_writes, - }; + debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - // Add the events to chip state for later trace generation usage - self.records.push(record); + let dst = new_read_rv32_register_from_state(state, d, a.as_canonical_u32()); + let src = new_read_rv32_register_from_state(state, d, b.as_canonical_u32()); + let len = new_read_rv32_register_from_state(state, d, c.as_canonical_u32()); - // NOTE: Check this is consistent with KeccakVmAir::timestamp_change (we don't use it to - // avoid unnecessary conversions here) - let total_timestamp_delta = - len + (KECCAK_REGISTER_READS + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; - memory.increment_timestamp_by(total_timestamp_delta - timestamp_delta); + let num_blocks = num_keccak_f(len as usize); - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: from_state.timestamp + total_timestamp_delta, - }) - } + let mut message = Vec::with_capacity(len as usize); + for offset in (0..len as usize).step_by(KECCAK_WORD_SIZE) { + let read = memory_read_from_state::<_, KECCAK_WORD_SIZE>(state, e, src + offset as u32); + let copy_len = std::cmp::min(KECCAK_WORD_SIZE, (len as usize) - offset); + message.extend_from_slice(&read[..copy_len]); + } - fn get_opcode_name(&self, _: usize) -> String { - "KECCAK256".to_string() - } -} + let mut hasher = Keccak::v256(); + hasher.update(&message); -impl Default for KeccakInputBlock { - fn default() -> Self { - // Padding for empty byte array so padding constraints still hold - let mut padded_bytes = [0u8; KECCAK_RATE_BYTES]; - padded_bytes[0] = 0x01; - *padded_bytes.last_mut().unwrap() = 0x80; - Self { - padded_bytes, - partial_read_idx: None, - remaining_len: 0, - is_new_start: true, - reads: Vec::new(), - src: 0, + let mut output = [0u8; 32]; + hasher.finalize(&mut output); + + for (i, word) in output.chunks_exact(KECCAK_WORD_SIZE).enumerate() { + memory_write_from_state::<_, KECCAK_WORD_SIZE>( + state, + e, + dst + (i * KECCAK_WORD_SIZE) as u32, + word.try_into().unwrap(), + ); } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + state.ctx.trace_heights[chip_index] += (num_blocks * NUM_ROUNDS) as u32; + + Ok(()) } } diff --git a/extensions/keccak256/circuit/src/tests.rs b/extensions/keccak256/circuit/src/tests.rs index 65a34491b8..0732729024 100644 --- a/extensions/keccak256/circuit/src/tests.rs +++ b/extensions/keccak256/circuit/src/tests.rs @@ -1,11 +1,17 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use hex::FromHex; -use openvm_circuit::arch::testing::{VmChipTestBuilder, VmChipTester, BITWISE_OP_LOOKUP_BUS}; +use openvm_circuit::arch::testing::{ + memory::gen_pointer, VmChipTestBuilder, VmChipTester, BITWISE_OP_LOOKUP_BUS, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; use openvm_keccak256_transpiler::Rv32KeccakOpcode; use openvm_stark_backend::{ p3_field::FieldAlgebra, utils::disable_debug_builder, verifier::VerificationError, @@ -19,34 +25,56 @@ use rand::Rng; use tiny_keccak::Hasher; use super::{columns::KeccakVmCols, utils::num_keccak_f, KeccakVmChip}; +use crate::{KeccakVmAir, KeccakVmStep}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 4096; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> (KeccakVmChip, SharedBitwiseOperationLookupChip<8>) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); + let chip = KeccakVmChip::new( + KeccakVmAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + tester.address_bits(), + Rv32KeccakOpcode::CLASS_OFFSET, + ), + KeccakVmStep::new( + bitwise_chip.clone(), + Rv32KeccakOpcode::CLASS_OFFSET, + tester.address_bits(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + (chip, bitwise_chip) +} + // io is vector of (input, expected_output, prank_output) where prank_output is Some if the trace // will be replaced #[allow(clippy::type_complexity)] fn build_keccak256_test( io: Vec<(Vec, Option<[u8; 32]>, Option<[u8; 32]>)>, ) -> VmChipTester { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::<8>::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = KeccakVmChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - Rv32KeccakOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chips(&mut tester); - let mut dst = 0; - let src = 0; + let max_mem_ptr = 1 << (tester.address_bits() - 3); + let mut dst = rng.gen_range(0..max_mem_ptr) << 2; + let src = rng.gen_range(0..max_mem_ptr) << 2; for (input, expected_output, _) in &io { - let [a, b, c] = [0, 4, 8]; // space apart for register limbs - let [d, e] = [1, 2]; + let [a, b, c] = [ + gen_pointer(&mut rng, 4), + gen_pointer(&mut rng, 4), + gen_pointer(&mut rng, 4), + ]; // space apart for register limbs + let [d, e] = [RV32_REGISTER_AS as usize, RV32_MEMORY_AS as usize]; tester.write(d, a, (dst as u32).to_le_bytes().map(F::from_canonical_u8)); tester.write(d, b, (src as u32).to_le_bytes().map(F::from_canonical_u8)); @@ -55,9 +83,15 @@ fn build_keccak256_test( c, (input.len() as u32).to_le_bytes().map(F::from_canonical_u8), ); - for (i, byte) in input.iter().enumerate() { - tester.write_cell(e, src + i, F::from_canonical_u8(*byte)); - } + + input.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); tester.execute( &mut chip, @@ -71,13 +105,15 @@ fn build_keccak256_test( ), ); if let Some(output) = expected_output { - for (i, byte) in output.iter().enumerate() { - assert_eq!(tester.read_cell(e, dst + i), F::from_canonical_u8(*byte)); - } + assert_eq!( + output.map(F::from_canonical_u8), + tester.read::<32>(e, dst as usize) + ); } // shift dst to not deal with timestamps for pranking dst += 32; } + let mut tester = tester.build().load(chip).load(bitwise_chip).finalize(); let keccak_trace = tester.air_proof_inputs[2] @@ -152,7 +188,7 @@ fn test_keccak256_positive_kat_vectors() { let output = Vec::from_hex(output).unwrap(); io.push((input, Some(output.try_into().unwrap()), None)); } - let tester = build_keccak256_test(io); + tester.simple_test().expect("Verification failed"); } diff --git a/extensions/keccak256/circuit/src/trace.rs b/extensions/keccak256/circuit/src/trace.rs index c314c38eac..52ed933375 100644 --- a/extensions/keccak256/circuit/src/trace.rs +++ b/extensions/keccak256/circuit/src/trace.rs @@ -1,150 +1,219 @@ -use std::{array::from_fn, borrow::BorrowMut, sync::Arc}; +use std::{array::from_fn, borrow::BorrowMut, cmp::min}; -use openvm_circuit::system::memory::RecordId; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use openvm_circuit::{ + arch::{Result, TraceStep, VmStateMut}, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, +}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_keccak256_transpiler::Rv32KeccakOpcode; +use openvm_rv32im_circuit::adapters::{tracing_read, tracing_write}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, + p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, }; use p3_keccak_air::{ generate_trace_rows, NUM_KECCAK_COLS as NUM_KECCAK_PERM_COLS, NUM_ROUNDS, U64_LIMBS, }; -use tiny_keccak::keccakf; +use tiny_keccak::{keccakf, Hasher, Keccak}; use super::{ - columns::{KeccakInstructionCols, KeccakVmCols}, - KeccakVmChip, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_RATE_U16S, + columns::KeccakVmCols, KECCAK_ABSORB_READS, KECCAK_DIGEST_WRITES, KECCAK_RATE_BYTES, KECCAK_REGISTER_READS, NUM_ABSORB_ROUNDS, }; +use crate::{columns::NUM_KECCAK_VM_COLS, utils::num_keccak_f, KeccakVmStep, KECCAK_WORD_SIZE}; -impl Chip for KeccakVmChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air) - } +impl TraceStep for KeccakVmStep { + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(opcode, Rv32KeccakOpcode::KECCAK256.global_opcode()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - fn generate_air_proof_input(self) -> AirProofInput { - let trace_width = self.trace_width(); - let records = self.records; - let total_num_blocks: usize = records.iter().map(|r| r.input_blocks.len()).sum(); - let mut states = Vec::with_capacity(total_num_blocks); - let mut instruction_blocks = Vec::with_capacity(total_num_blocks); - let memory = self.offline_memory.lock().unwrap(); + let trace = &mut trace[*trace_offset..]; + let (dst, mut src, mut remaining_len) = { + let cols: &mut KeccakVmCols = trace[..width].borrow_mut(); + cols.instruction.start_timestamp = F::from_canonical_u32(state.memory.timestamp()); - #[derive(Clone)] - struct StateDiff { - /// hi-byte of pre-state - pre_hi: [u8; KECCAK_RATE_U16S], - /// hi-byte of post-state - post_hi: [u8; KECCAK_RATE_U16S], - /// if first block - register_reads: Option<[RecordId; KECCAK_REGISTER_READS]>, - /// if last block - digest_writes: Option<[RecordId; KECCAK_DIGEST_WRITES]>, - } + let a = a.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + let dst = tracing_read(state.memory, d, a, &mut cols.mem_oc.register_aux[0]); + let src = tracing_read(state.memory, d, b, &mut cols.mem_oc.register_aux[1]); + let len = tracing_read(state.memory, d, c, &mut cols.mem_oc.register_aux[2]); + ( + dst, + u32::from_le_bytes(src), + u32::from_le_bytes(len) as usize, + ) + }; - impl Default for StateDiff { - fn default() -> Self { - Self { - pre_hi: [0; KECCAK_RATE_U16S], - post_hi: [0; KECCAK_RATE_U16S], - register_reads: None, - digest_writes: None, + // Due to the AIR constraints, the final memory timestamp should be the following: + let final_timestamp = state.memory.timestamp() + + (remaining_len + KECCAK_ABSORB_READS + KECCAK_DIGEST_WRITES) as u32; + let num_blocks = num_keccak_f(remaining_len); + let mut hasher = Keccak::v256(); + + trace + .chunks_mut(width * NUM_ROUNDS) + .enumerate() + .take(num_blocks) + .for_each(|(block_idx, chunk)| { + let cols: &mut KeccakVmCols = chunk[..NUM_KECCAK_VM_COLS].borrow_mut(); + if block_idx != 0 { + cols.instruction.start_timestamp = + F::from_canonical_u32(state.memory.timestamp()); + + state + .memory + .increment_timestamp_by(KECCAK_REGISTER_READS as u32); } - } + cols.instruction.dst_ptr = a; + cols.instruction.src_ptr = b; + cols.instruction.len_ptr = c; + cols.instruction.dst = dst.map(F::from_canonical_u8); + cols.instruction + .src_limbs + .copy_from_slice(&src.to_le_bytes().map(F::from_canonical_u8)[1..]); + cols.instruction.len_limbs.copy_from_slice( + &(remaining_len as u32) + .to_le_bytes() + .map(F::from_canonical_u8)[1..], + ); + cols.instruction.src = F::from_canonical_u32(src); + cols.instruction.remaining_len = F::from_canonical_usize(remaining_len); + cols.instruction.pc = F::from_canonical_u32(*state.pc); + cols.sponge.is_new_start = F::from_bool(block_idx == 0); + + for i in (0..KECCAK_RATE_BYTES).step_by(KECCAK_WORD_SIZE) { + if i < remaining_len { + let read = tracing_read::<_, KECCAK_WORD_SIZE>( + state.memory, + e, + src + i as u32, + &mut cols.mem_oc.absorb_reads[i / KECCAK_WORD_SIZE], + ); + let copy_len = min(KECCAK_WORD_SIZE, remaining_len - i); + hasher.update(&read[..copy_len]); + cols.sponge.block_bytes[i..i + copy_len] + .copy_from_slice(&read.map(F::from_canonical_u8)[..copy_len]); + if copy_len != KECCAK_WORD_SIZE { + cols.mem_oc + .partial_block + .copy_from_slice(&read.map(F::from_canonical_u8)[1..]); + } + } else { + state.memory.increment_timestamp(); + } + } + if block_idx == num_blocks - 1 { + if remaining_len == KECCAK_RATE_BYTES - 1 { + cols.sponge.block_bytes[remaining_len] = F::from_canonical_u32(0b1000_0001); + } else { + cols.sponge.block_bytes[remaining_len] = F::from_canonical_u32(0x01); + cols.sponge.block_bytes[KECCAK_RATE_BYTES - 1] = + F::from_canonical_u32(0x80); + } + } else { + src += KECCAK_RATE_BYTES as u32; + remaining_len -= KECCAK_RATE_BYTES; + } + }); + + let last_row_offset = (num_blocks * NUM_ROUNDS - 1) * width; + let last_row: &mut KeccakVmCols = + trace[last_row_offset..last_row_offset + NUM_KECCAK_VM_COLS].borrow_mut(); + let mut digest = [0u8; 32]; + hasher.finalize(&mut digest); + for (i, word) in digest.chunks_exact(KECCAK_WORD_SIZE).enumerate() { + tracing_write::<_, KECCAK_WORD_SIZE>( + state.memory, + e, + u32::from_le_bytes(dst) + (i * KECCAK_WORD_SIZE) as u32, + word.try_into().unwrap(), + &mut last_row.mem_oc.digest_writes[i], + ); } - // prepare the states - let mut state: [u64; 25]; - for record in records { - let dst_read = memory.record_by_id(record.dst_read); - let src_read = memory.record_by_id(record.src_read); - let len_read = memory.record_by_id(record.len_read); + state.memory.timestamp = final_timestamp; + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + *trace_offset += num_blocks * NUM_ROUNDS * width; + Ok(()) + } + + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut [F], + width: usize, + rows_used: usize, + ) where + Self: Send + Sync, + F: Send + Sync, + { + if rows_used == 0 { + return; + } - state = [0u64; 25]; - let src_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = src_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let len_limbs: [_; RV32_REGISTER_NUM_LIMBS - 1] = len_read.data_slice() - [1..RV32_REGISTER_NUM_LIMBS] - .try_into() - .unwrap(); - let mut instruction = KeccakInstructionCols { - pc: record.pc, - is_enabled: Val::::ONE, - is_enabled_first_round: Val::::ZERO, - start_timestamp: Val::::from_canonical_u32(dst_read.timestamp), - dst_ptr: dst_read.pointer, - src_ptr: src_read.pointer, - len_ptr: len_read.pointer, - dst: dst_read.data_slice().try_into().unwrap(), - src_limbs, - src: Val::::from_canonical_usize(record.input_blocks[0].src), - len_limbs, - remaining_len: Val::::from_canonical_usize( - record.input_blocks[0].remaining_len, - ), - }; - let num_blocks = record.input_blocks.len(); - for (idx, block) in record.input_blocks.into_iter().enumerate() { + let num_blocks = rows_used.div_ceil(NUM_ROUNDS); + let mut states = Vec::with_capacity(num_blocks); + let mut state = [0u64; 25]; + trace + .chunks_mut(width * NUM_ROUNDS) + .take(num_blocks) + .for_each(|chunk| { + let cols: &mut KeccakVmCols = chunk[..NUM_KECCAK_VM_COLS].borrow_mut(); + if cols.sponge.is_new_start.is_one() { + // a new instruction is starting + state = [0u64; 25]; + } // absorb - for (bytes, s) in block.padded_bytes.chunks_exact(8).zip(state.iter_mut()) { + for (bytes, s) in cols + .sponge + .block_bytes + .chunks_exact(8) + .zip(state.iter_mut()) + { // u64 <-> bytes conversion is little-endian for (i, &byte) in bytes.iter().enumerate() { + let byte = byte.as_canonical_u32(); let s_byte = (*s >> (i * 8)) as u8; // Update bitwise lookup (i.e. xor) chip state: order matters! - if idx != 0 { - self.bitwise_lookup_chip - .request_xor(byte as u32, s_byte as u32); + if cols.sponge.is_new_start.is_zero() { + self.bitwise_lookup_chip.request_xor(byte, s_byte as u32); } *s ^= (byte as u64) << (i * 8); } } - let pre_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); states.push(state); keccakf(&mut state); - let post_hi: [u8; KECCAK_RATE_U16S] = - from_fn(|i| (state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8); - // Range check the final state - if idx == num_blocks - 1 { - for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { - for s_byte in s.to_le_bytes() { - self.bitwise_lookup_chip.request_xor(0, s_byte as u32); - } - } - } - let register_reads = - (idx == 0).then_some([record.dst_read, record.src_read, record.len_read]); - let digest_writes = (idx == num_blocks - 1).then_some(record.digest_writes); - let diff = StateDiff { - pre_hi, - post_hi, - register_reads, - digest_writes, - }; - instruction_blocks.push((instruction, diff, block)); - instruction.remaining_len -= Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.src += Val::::from_canonical_usize(KECCAK_RATE_BYTES); - instruction.start_timestamp += - Val::::from_canonical_usize(KECCAK_REGISTER_READS + KECCAK_ABSORB_READS); - } - } + }); // We need to transpose state matrices due to a plonky3 issue: https://github.com/Plonky3/Plonky3/issues/672 // Note: the fix for this issue will be a commit after the major Field crate refactor PR https://github.com/Plonky3/Plonky3/pull/640 // which will require a significant refactor to switch to. let p3_states = states - .into_iter() + .par_iter() .map(|state| { // transpose of 5x5 matrix from_fn(|i| { @@ -154,122 +223,147 @@ where }) }) .collect(); - let p3_keccak_trace: RowMajorMatrix> = generate_trace_rows(p3_states, 0); - let num_rows = p3_keccak_trace.height(); - // Every `NUM_ROUNDS` rows corresponds to one input block - let num_blocks = num_rows.div_ceil(NUM_ROUNDS); - // Resize with dummy `is_enabled = 0` - instruction_blocks.resize(num_blocks, Default::default()); - - let aux_cols_factory = memory.aux_cols_factory(); - - // Use unsafe alignment so we can parallelly write to the matrix - let mut trace = - RowMajorMatrix::new(Val::::zero_vec(num_rows * trace_width), trace_width); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.ptr_max_bits; + let p3_keccak_trace: RowMajorMatrix = generate_trace_rows(p3_states, 0); trace - .values - .par_chunks_mut(trace_width * NUM_ROUNDS) + .par_chunks_mut(width * NUM_ROUNDS) .zip( p3_keccak_trace .values .par_chunks(NUM_KECCAK_PERM_COLS * NUM_ROUNDS), ) - .zip(instruction_blocks.into_par_iter()) - .for_each(|((rows, p3_keccak_mat), (instruction, diff, block))| { - let height = rows.len() / trace_width; - for (row, p3_keccak_row) in rows - .chunks_exact_mut(trace_width) - .zip(p3_keccak_mat.chunks_exact(NUM_KECCAK_PERM_COLS)) + .enumerate() + .for_each(|(block_idx, (block, p3_keccak_block))| { + // let cols: &mut KeccakVmCols = block[..NUM_KECCAK_VM_COLS].borrow_mut(); + if block_idx >= num_blocks { + // fill in a dummy row + block + .par_chunks_mut(width) + .zip(p3_keccak_block.par_chunks_exact(NUM_KECCAK_PERM_COLS)) + .for_each(|(row, p3_keccak_row)| { + row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row); + let cols: &mut KeccakVmCols = row.borrow_mut(); + cols.sponge.block_bytes[0] = F::ONE; + cols.sponge.block_bytes[KECCAK_RATE_BYTES - 1] = + F::from_canonical_u32(0x80); + cols.sponge.is_padding_byte[0..KECCAK_RATE_BYTES].fill(F::ONE); + }); + + // The first row of the `dummy` block should have `is_new_start = F::ONE` + let first_dummy_row: &mut KeccakVmCols = block[..width].borrow_mut(); + first_dummy_row.sponge.is_new_start = F::ONE; + return; + } + + // the first row is treated differently + let (first_row, block) = block.split_at_mut(width); + first_row[..NUM_KECCAK_PERM_COLS] + .copy_from_slice(&p3_keccak_block[..NUM_KECCAK_PERM_COLS]); + let first_row: &mut KeccakVmCols = first_row.borrow_mut(); + first_row.instruction.is_enabled = F::ONE; + let remaining_len = first_row.instruction.remaining_len.as_canonical_u32() as usize; + for i in remaining_len..KECCAK_RATE_BYTES { + first_row.sponge.is_padding_byte[i] = F::ONE; + } + + for (row, p3_keccak_row) in block + .chunks_exact_mut(width) + .zip(p3_keccak_block.chunks_exact(NUM_KECCAK_PERM_COLS).skip(1)) { // Safety: `KeccakPermCols` **must** be the first field in `KeccakVmCols` row[..NUM_KECCAK_PERM_COLS].copy_from_slice(p3_keccak_row); - let row_mut: &mut KeccakVmCols> = row.borrow_mut(); - row_mut.instruction = instruction; + let cols: &mut KeccakVmCols = row.borrow_mut(); - row_mut.sponge.block_bytes = - block.padded_bytes.map(Val::::from_canonical_u8); - if let Some(partial_read_idx) = block.partial_read_idx { - let partial_read = memory.record_by_id(block.reads[partial_read_idx]); - row_mut - .mem_oc - .partial_block - .copy_from_slice(&partial_read.data_slice()[1..]); - } - for (i, is_padding) in row_mut.sponge.is_padding_byte.iter_mut().enumerate() { - *is_padding = Val::::from_bool(i >= block.remaining_len); - } + cols.instruction = first_row.instruction; + cols.sponge.block_bytes = first_row.sponge.block_bytes; + cols.sponge.is_padding_byte = first_row.sponge.is_padding_byte; + cols.mem_oc.partial_block = first_row.mem_oc.partial_block; } - let first_row: &mut KeccakVmCols> = rows[..trace_width].borrow_mut(); - first_row.sponge.is_new_start = Val::::from_bool(block.is_new_start); - first_row.sponge.state_hi = diff.pre_hi.map(Val::::from_canonical_u8); + + let (_, last_row) = block.split_at_mut(width * (NUM_ROUNDS - 2)); + let last_row: &mut KeccakVmCols = last_row.borrow_mut(); + first_row.instruction.is_enabled_first_round = first_row.instruction.is_enabled; - // Make memory access aux columns. Any aux column not explicitly defined defaults to - // all 0s - if let Some(register_reads) = diff.register_reads { - let need_range_check = [ - ®ister_reads[0], // dst - ®ister_reads[1], // src - ®ister_reads[2], // len - ®ister_reads[2], - ] - .map(|r| { - memory - .record_by_id(*r) - .data_slice() - .last() - .unwrap() - .as_canonical_u32() - }); - for bytes in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range( - bytes[0] << limb_shift_bits, - bytes[1] << limb_shift_bits, - ); - } - for (i, id) in register_reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.register_aux[i], + first_row.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8( + (states[block_idx][i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8, + ) + }); + + let start_timestamp = first_row.instruction.start_timestamp.as_canonical_u32(); + first_row + .mem_oc + .absorb_reads + .par_iter_mut() + .take(remaining_len.div_ceil(KECCAK_WORD_SIZE)) + .enumerate() + .for_each(|(i, read)| { + mem_helper.fill_from_prev( + start_timestamp + KECCAK_REGISTER_READS as u32 + i as u32, + read.as_mut(), ); - } - } - for (i, id) in block.reads.into_iter().enumerate() { - aux_cols_factory.generate_read_aux( - memory.record_by_id(id), - &mut first_row.mem_oc.absorb_reads[i], + }); + + // Check if the first row is a new start (e.g. register reads happened) + if first_row.sponge.is_new_start.is_one() { + let limb_shift_bits = + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + + self.bitwise_lookup_chip.request_range( + first_row.instruction.dst[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() + << limb_shift_bits, + first_row.instruction.src_limbs[RV32_REGISTER_NUM_LIMBS - 2] + .as_canonical_u32() + << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + first_row.instruction.len_limbs[RV32_REGISTER_NUM_LIMBS - 2] + .as_canonical_u32() + << limb_shift_bits, + first_row.instruction.len_limbs[RV32_REGISTER_NUM_LIMBS - 2] + .as_canonical_u32() + << limb_shift_bits, ); + first_row + .mem_oc + .register_aux + .par_iter_mut() + .enumerate() + .for_each(|(i, aux)| { + mem_helper.fill_from_prev(start_timestamp + i as u32, aux.as_mut()); + }); } - let last_row: &mut KeccakVmCols> = - rows[(height - 1) * trace_width..].borrow_mut(); - last_row.sponge.state_hi = diff.post_hi.map(Val::::from_canonical_u8); - last_row.inner.export = instruction.is_enabled - * Val::::from_bool(block.remaining_len < KECCAK_RATE_BYTES); - if let Some(digest_writes) = diff.digest_writes { - for (i, record_id) in digest_writes.into_iter().enumerate() { - let record = memory.record_by_id(record_id); - aux_cols_factory - .generate_write_aux(record, &mut last_row.mem_oc.digest_writes[i]); + let mut state = states[block_idx]; + keccakf(&mut state); + last_row.sponge.state_hi = from_fn(|i| { + F::from_canonical_u8((state[i / U64_LIMBS] >> ((i % U64_LIMBS) * 16 + 8)) as u8) + }); + last_row.inner.export = last_row.instruction.is_enabled + * F::from_bool(remaining_len < KECCAK_RATE_BYTES); + + // Check if this is the last block (e.g. digest write happened) + if remaining_len < KECCAK_RATE_BYTES { + let write_timestamp = + start_timestamp + KECCAK_REGISTER_READS as u32 + KECCAK_ABSORB_READS as u32; + last_row + .mem_oc + .digest_writes + .par_iter_mut() + .enumerate() + .for_each(|(i, write)| { + mem_helper.fill_from_prev(write_timestamp + i as u32, write.as_mut()); + }); + for s in state.into_iter().take(NUM_ABSORB_ROUNDS) { + for s_byte in s.to_le_bytes() { + self.bitwise_lookup_chip.request_xor(0, s_byte as u32); + } } } }); - - AirProofInput::simple_no_pis(trace) - } -} - -impl ChipUsageGetter for KeccakVmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - let num_blocks: usize = self.records.iter().map(|r| r.input_blocks.len()).sum(); - num_blocks * NUM_ROUNDS } - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32KeccakOpcode::KECCAK256) } } diff --git a/extensions/keccak256/guest/src/lib.rs b/extensions/keccak256/guest/src/lib.rs index 459c4c910d..86ba45e577 100644 --- a/extensions/keccak256/guest/src/lib.rs +++ b/extensions/keccak256/guest/src/lib.rs @@ -1,7 +1,9 @@ #![no_std] #[cfg(target_os = "zkvm")] -use core::mem::MaybeUninit; +extern crate alloc; +#[cfg(target_os = "zkvm")] +use {core::mem::MaybeUninit, openvm_platform::alloc::AlignedBuf}; /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; @@ -41,6 +43,43 @@ pub fn keccak256(input: &[u8]) -> [u8; 32] { #[inline(always)] #[no_mangle] extern "C" fn native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + const MIN_ALIGN: usize = 4; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, MIN_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, MIN_ALIGN); + __native_keccak256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_keccak256(bytes, len, output); + } + }; + } +} + +/// keccak256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_keccak256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!( opcode = OPCODE, funct3 = KECCAK256_FUNCT3, diff --git a/extensions/native/circuit/Cargo.toml b/extensions/native/circuit/Cargo.toml index 5d5913b4be..a81be9b41b 100644 --- a/extensions/native/circuit/Cargo.toml +++ b/extensions/native/circuit/Cargo.toml @@ -17,6 +17,7 @@ openvm-circuit = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-instructions = { workspace = true } openvm-rv32im-circuit = { workspace = true } +openvm-rv32im-transpiler = { workspace = true } openvm-native-compiler = { workspace = true } diff --git a/extensions/native/circuit/src/adapters/alu_native_adapter.rs b/extensions/native/circuit/src/adapters/alu_native_adapter.rs index e85797536f..55232950db 100644 --- a/extensions/native/circuit/src/adapters/alu_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/alu_native_adapter.rs @@ -1,21 +1,18 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, - }, - native_adapter::{NativeReadRecord, NativeWriteRecord}, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -27,27 +24,11 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct AluNativeAdapterChip { - pub air: AluNativeAdapterAir, - _marker: PhantomData, -} - -impl AluNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: AluNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} +use super::tracing_write_native; +use crate::adapters::{ + memory_read_or_imm_native_from_state, memory_write_native_from_state, + tracing_read_or_imm_native, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -144,88 +125,148 @@ impl VmAdapterAir for AluNativeAdapterAir { } } -impl VmAdapterChip for AluNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = NativeWriteRecord; - type Air = AluNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, 1, 1>; +#[derive(derive_new::new)] +pub struct AluNativeAdapterStep; + +impl AdapterTraceStep for AluNativeAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = [F; 1]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); - fn preprocess( - &mut self, - memory: &mut MemoryController, + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, e, f, .. } = *instruction; - - let reads = vec![memory.read::<1>(e, b), memory.read::<1>(f, c)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, c, e, f, .. } = instruction; + + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); + + adapter_row.b_pointer = b; + let rs1 = tracing_read_or_imm_native( + memory, + e.as_canonical_u32(), + b, + &mut adapter_row.e_as, + &mut adapter_row.reads_aux[0], + ); + adapter_row.c_pointer = c; + let rs2 = tracing_read_or_imm_native( + memory, + f.as_canonical_u32(), + c, + &mut adapter_row.f_as, + &mut adapter_row.reads_aux[1], + ); + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, .. } = *_instruction; - let writes = vec![memory.write( - F::from_canonical_u32(AS::Native as u32), - a, - output.writes[0], - )]; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: writes.try_into().unwrap(), - }, - )) + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, .. } = instruction; + + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); + adapter_row.a_pointer = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut adapter_row.write_aux, + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - let row_slice: &mut AluNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); + let adapter_row: &mut AluNativeAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); + mem_helper.fill_from_prev(timestamp, &mut adapter_row.reads_aux[0].base); + timestamp += 1; - row_slice.a_pointer = memory.record_by_id(write_record.writes[0].0).pointer; - row_slice.b_pointer = memory.record_by_id(read_record.reads[0].0).pointer; - row_slice.c_pointer = memory.record_by_id(read_record.reads[1].0).pointer; - row_slice.e_as = memory.record_by_id(read_record.reads[0].0).address_space; - row_slice.f_as = memory.record_by_id(read_record.reads[1].0).address_space; + mem_helper.fill_from_prev(timestamp, &mut adapter_row.reads_aux[1].base); + timestamp += 1; - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); - aux_cols_factory.generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i]); + mem_helper.fill_from_prev(timestamp, adapter_row.write_aux.as_mut()); + + if adapter_row.e_as.is_zero() { + adapter_row.reads_aux[0].is_immediate = F::ONE; + adapter_row.reads_aux[0].is_zero_aux = F::ZERO; + } else { + adapter_row.reads_aux[0].is_immediate = F::ZERO; + adapter_row.reads_aux[0].is_zero_aux = adapter_row.e_as.inverse(); } - let write = memory.record_by_id(write_record.writes[0].0); - aux_cols_factory.generate_write_aux(write, &mut row_slice.write_aux); + if adapter_row.f_as.is_zero() { + adapter_row.reads_aux[1].is_immediate = F::ONE; + adapter_row.reads_aux[1].is_zero_aux = F::ZERO; + } else { + adapter_row.reads_aux[1].is_immediate = F::ZERO; + adapter_row.reads_aux[1].is_zero_aux = adapter_row.f_as.inverse(); + } } +} + +impl AdapterExecutorE1 for AluNativeAdapterStep +where + F: PrimeField32, +{ + type ReadData = [F; 2]; + type WriteData = [F; 1]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { b, c, e, f, .. } = instruction; + + let rs1 = memory_read_or_imm_native_from_state(state, e.as_canonical_u32(), b); + let rs2 = memory_read_or_imm_native_from_state(state, f.as_canonical_u32(), c); + + [rs1, rs2] + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { a, .. } = instruction; - fn air(&self) -> &Self::Air { - &self.air + memory_write_native_from_state(state, a.as_canonical_u32(), data); } } diff --git a/extensions/native/circuit/src/adapters/branch_native_adapter.rs b/extensions/native/circuit/src/adapters/branch_native_adapter.rs index 7d3e97a6bf..d764f9ad42 100644 --- a/extensions/native/circuit/src/adapters/branch_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/branch_native_adapter.rs @@ -1,21 +1,18 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, - }, - native_adapter::NativeReadRecord, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadOrImmediateAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -27,27 +24,7 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, }; -#[derive(Debug)] -pub struct BranchNativeAdapterChip { - pub air: BranchNativeAdapterAir, - _marker: PhantomData, -} - -impl BranchNativeAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: BranchNativeAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} +use crate::adapters::{memory_read_or_imm_native_from_state, tracing_read_or_imm_native}; #[repr(C)] #[derive(AlignedBorrow)] @@ -145,71 +122,131 @@ impl VmAdapterAir for BranchNativeAdapterAir { } } -impl VmAdapterChip for BranchNativeAdapterChip { - type ReadRecord = NativeReadRecord; - type WriteRecord = ExecutionState; - type Air = BranchNativeAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, 1, 1>; +#[derive(derive_new::new)] +pub struct BranchNativeAdapterStep; + +impl AdapterTraceStep for BranchNativeAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; 2]; + type WriteData = (); + type TraceContext<'a> = (); - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut BranchNativeAdapterCols = adapter_row.borrow_mut(); + + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; - - let reads = vec![memory.read::<1>(d, a), memory.read::<1>(e, b)]; - let i_reads: [_; 2] = std::array::from_fn(|i| reads[i].1); - - Ok(( - i_reads, - Self::ReadRecord { - reads: reads.try_into().unwrap(), - }, - )) + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; + let adapter_row: &mut BranchNativeAdapterCols = adapter_row.borrow_mut(); + + adapter_row.reads_aux[0].address.pointer = a; + let rs1 = tracing_read_or_imm_native( + memory, + d.as_canonical_u32(), + a, + &mut adapter_row.reads_aux[0].address.address_space, + &mut adapter_row.reads_aux[0].read_aux, + ); + adapter_row.reads_aux[1].address.pointer = b; + let rs2 = tracing_read_or_imm_native( + memory, + e.as_canonical_u32(), + b, + &mut adapter_row.reads_aux[1].address.address_space, + &mut adapter_row.reads_aux[1].read_aux, + ); + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + _memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + _adapter_row: &mut [F], + _data: &Self::WriteData, + ) { } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let row_slice: &mut BranchNativeAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); + let adapter_row: &mut BranchNativeAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); - row_slice.from_state = write_record.map(F::from_canonical_u32); - for (i, x) in read_record.reads.iter().enumerate() { - let read = memory.record_by_id(x.0); + mem_helper.fill_from_prev(timestamp, &mut adapter_row.reads_aux[0].read_aux.base); + timestamp += 1; - row_slice.reads_aux[i].address = MemoryAddress::new(read.address_space, read.pointer); - aux_cols_factory - .generate_read_or_immediate_aux(read, &mut row_slice.reads_aux[i].read_aux); + mem_helper.fill_from_prev(timestamp, &mut adapter_row.reads_aux[1].read_aux.base); + + let read_aux0 = &mut adapter_row.reads_aux[0]; + if read_aux0.address.address_space.is_zero() { + read_aux0.read_aux.is_immediate = F::ONE; + read_aux0.read_aux.is_zero_aux = F::ZERO; + } else { + read_aux0.read_aux.is_immediate = F::ZERO; + read_aux0.read_aux.is_zero_aux = read_aux0.address.address_space.inverse(); + } + + let read_aux1 = &mut adapter_row.reads_aux[1]; + if read_aux1.address.address_space.is_zero() { + read_aux1.read_aux.is_immediate = F::ONE; + read_aux1.read_aux.is_zero_aux = F::ZERO; + } else { + read_aux1.read_aux.is_immediate = F::ZERO; + read_aux1.read_aux.is_zero_aux = read_aux1.address.address_space.inverse(); } } +} + +impl AdapterExecutorE1 for BranchNativeAdapterStep +where + F: PrimeField32, +{ + type ReadData = [F; 2]; + type WriteData = (); + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { a, b, d, e, .. } = instruction; + + let rs1 = memory_read_or_imm_native_from_state(state, d.as_canonical_u32(), a); + let rs2 = memory_read_or_imm_native_from_state(state, e.as_canonical_u32(), b); + + [rs1, rs2] + } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write( + &self, + _state: &mut VmStateMut, + _instruction: &Instruction, + _data: &Self::WriteData, + ) { } } diff --git a/extensions/native/circuit/src/adapters/convert_adapter.rs b/extensions/native/circuit/src/adapters/convert_adapter.rs index cac6d91bac..fe1ee70778 100644 --- a/extensions/native/circuit/src/adapters/convert_adapter.rs +++ b/extensions/native/circuit/src/adapters/convert_adapter.rs @@ -1,71 +1,33 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{ + instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_MEMORY_AS, +}; use openvm_native_compiler::conversion::AS; +use openvm_rv32im_circuit::adapters::{memory_write_from_state, tracing_write}; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorReadRecord { - #[serde(with = "BigArray")] - pub reads: [RecordId; NUM_READS], -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct VectorWriteRecord { - pub from_state: ExecutionState, - pub writes: [RecordId; 1], -} - -#[allow(dead_code)] -#[derive(Debug)] -pub struct ConvertAdapterChip { - pub air: ConvertAdapterAir, - _marker: PhantomData, -} -impl - ConvertAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: ConvertAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} +use crate::adapters::{memory_read_native_from_state, tracing_read_native}; #[repr(C)] #[derive(AlignedBorrow)] @@ -155,74 +117,132 @@ impl Vm } } -impl VmAdapterChip - for ConvertAdapterChip -{ - type ReadRecord = VectorReadRecord<1, READ_SIZE>; - type WriteRecord = VectorWriteRecord; - type Air = ConvertAdapterAir; - type Interface = BasicAdapterInterface, 1, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, e, .. } = *instruction; +#[derive(derive_new::new)] +pub struct ConvertAdapterStep; - let y_val = memory.read::(e, b); +impl AdapterTraceStep + for ConvertAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [F; READ_SIZE]; + type WriteData = [u8; WRITE_SIZE]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - Ok(([y_val.1], Self::ReadRecord { reads: [y_val.0] })) + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, e, .. } = instruction; + + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + let adapter_row: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + adapter_row.b_pointer = b; + let read = tracing_read_native( + memory, + b.as_canonical_u32(), + adapter_row.reads_aux[0].as_mut(), + ); + read } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (write_id, _) = memory.write::(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - writes: [write_id], - }, - )) + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_MEMORY_AS); + + let adapter_row: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + adapter_row.a_pointer = a; + tracing_write( + memory, + RV32_MEMORY_AS, + a.as_canonical_u32(), + data, + &mut adapter_row.writes_aux[0], + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut ConvertAdapterCols<_, READ_SIZE, WRITE_SIZE> = row_slice.borrow_mut(); + let adapter_row: &mut ConvertAdapterCols = + adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.writes_aux[0].as_mut()); + } +} + +impl AdapterExecutorE1 + for ConvertAdapterStep +where + F: PrimeField32, +{ + type ReadData = [F; READ_SIZE]; + type WriteData = [u8; WRITE_SIZE]; - let read = memory.record_by_id(read_record.reads[0]); - let write = memory.record_by_id(write_record.writes[0]); + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, e, .. } = instruction; - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = write.pointer; - row_slice.b_pointer = read.pointer; + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); - aux_cols_factory.generate_read_aux(read, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_write_aux(write, &mut row_slice.writes_aux[0]); + memory_read_native_from_state(state, b.as_canonical_u32()) } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_MEMORY_AS); + + memory_write_from_state(state, RV32_MEMORY_AS, a.as_canonical_u32(), data); } } diff --git a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs index 4bcf96d195..d6eaf0a3bf 100644 --- a/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs +++ b/extensions/native/circuit/src/adapters/loadstore_native_adapter.rs @@ -5,19 +5,17 @@ use std::{ use openvm_circuit::{ arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, - ExecutionBus, ExecutionState, Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ conversion::AS, NativeLoadStoreOpcode::{self, *}, @@ -27,7 +25,11 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; + +use crate::adapters::{ + memory_read_native, memory_read_native_from_state, memory_write_native_from_state, + tracing_read_native, tracing_write_native, +}; pub struct NativeLoadStoreInstruction { pub is_valid: T, @@ -48,55 +50,6 @@ impl VmAdapterInterface type ProcessedInstruction = NativeLoadStoreInstruction; } -#[derive(Debug)] -pub struct NativeLoadStoreAdapterChip { - pub air: NativeLoadStoreAdapterAir, - offset: usize, - _marker: PhantomData, -} - -impl NativeLoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offset: usize, - ) -> Self { - Self { - air: NativeLoadStoreAdapterAir { - memory_bridge, - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - }, - offset, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreReadRecord { - pub pointer_read: RecordId, - pub data_read: Option, - pub write_as: F, - pub write_ptr: F, - - pub a: F, - pub b: F, - pub c: F, - pub d: F, - pub e: F, -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct NativeLoadStoreWriteRecord { - pub from_state: ExecutionState, - pub write_id: RecordId, -} - #[repr(C)] #[derive(Clone, Debug, AlignedBorrow)] pub struct NativeLoadStoreAdapterCols { @@ -214,23 +167,37 @@ impl VmAdapterAir } } -impl VmAdapterChip - for NativeLoadStoreAdapterChip +#[derive(derive_new::new)] +pub struct NativeLoadStoreAdapterStep { + offset: usize, +} + +impl AdapterTraceStep + for NativeLoadStoreAdapterStep +where + F: PrimeField32, { - type ReadRecord = NativeLoadStoreReadRecord; - type WriteRecord = NativeLoadStoreWriteRecord; - type Air = NativeLoadStoreAdapterAir; - type Interface = NativeLoadStoreAdapterInterface; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = std::mem::size_of::>(); + type ReadData = (F, [F; NUM_CELLS]); + type WriteData = [F; NUM_CELLS]; + type TraceContext<'a> = F; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -238,100 +205,209 @@ impl VmAdapterChip d, e, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let read_as = d; - let read_ptr = c; - let read_cell = memory.read_cell(read_as, read_ptr); + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + adapter_row.a = a; + adapter_row.b = b; + adapter_row.c = c; + + // Read the pointer value from memory + let [read_cell] = tracing_read_native::( + memory, + c.as_canonical_u32(), + adapter_row.pointer_read_aux_cols.as_mut(), + ); - let (data_read_as, data_write_as) = { - match local_opcode { - LOADW => (e, d), - STOREW | HINT_STOREW => (d, e), - } + let (data_read_as, _) = match local_opcode { + LOADW => (e.as_canonical_u32(), d.as_canonical_u32()), + STOREW | HINT_STOREW => (d.as_canonical_u32(), e.as_canonical_u32()), }; - let (data_read_ptr, data_write_ptr) = { - match local_opcode { - LOADW => (read_cell.1 + b, a), - STOREW | HINT_STOREW => (a, read_cell.1 + b), - } + + debug_assert_eq!(data_read_as, AS::Native as u32); + + let (data_read_ptr, _) = match local_opcode { + LOADW => ((read_cell + b).as_canonical_u32(), a.as_canonical_u32()), + STOREW | HINT_STOREW => (a.as_canonical_u32(), (read_cell + b).as_canonical_u32()), }; - let data_read = match local_opcode { - HINT_STOREW => None, - LOADW | STOREW => Some(memory.read::(data_read_as, data_read_ptr)), + // Read data based on opcode + let data_read: [F; NUM_CELLS] = match local_opcode { + HINT_STOREW => [F::ZERO; NUM_CELLS], + LOADW | STOREW => tracing_read_native::( + memory, + data_read_ptr, + adapter_row.data_read_aux_cols.as_mut(), + ), }; - let record = NativeLoadStoreReadRecord { - pointer_read: read_cell.0, - data_read: data_read.map(|x| x.0), - write_as: data_write_as, - write_ptr: data_write_ptr, + + (read_cell, data_read) + } + + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { + opcode, a, b, c, d, e, + .. + } = instruction; + + // TODO(ayush): remove duplication + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + + let [read_cell] = memory_read_native::(memory.data(), c.as_canonical_u32()); + + let (_, data_write_as) = match local_opcode { + LOADW => (e.as_canonical_u32(), d.as_canonical_u32()), + STOREW | HINT_STOREW => (d.as_canonical_u32(), e.as_canonical_u32()), }; - Ok(( - (read_cell.1, data_read.map_or([F::ZERO; NUM_CELLS], |x| x.1)), - record, - )) - } + debug_assert_eq!(data_write_as, AS::Native as u32); + + let data_write_ptr = match local_opcode { + LOADW => a.as_canonical_u32(), + STOREW | HINT_STOREW => (read_cell + b).as_canonical_u32(), + }; + + adapter_row.data_write_pointer = F::from_canonical_u32(data_write_ptr); - fn postprocess( - &mut self, - memory: &mut MemoryController, - _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let (write_id, _) = - memory.write::(read_record.write_as, read_record.write_ptr, output.writes); - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state: from_state.map(F::from_canonical_u32), - write_id, - }, - )) + // Write data to memory + tracing_write_native( + memory, + data_write_ptr, + data, + &mut adapter_row.data_write_aux_cols, + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + is_hint_storew: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let cols: &mut NativeLoadStoreAdapterCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.from_state = write_record.from_state; - cols.a = read_record.a; - cols.b = read_record.b; - cols.c = read_record.c; - - let data_read = read_record.data_read.map(|read| memory.record_by_id(read)); - if let Some(data_read) = data_read { - aux_cols_factory.generate_read_aux(data_read, &mut cols.data_read_aux_cols); + let adapter_row: &mut NativeLoadStoreAdapterCols = adapter_row.borrow_mut(); + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + // Fill auxiliary columns for memory operations + mem_helper.fill_from_prev(timestamp, adapter_row.pointer_read_aux_cols.as_mut()); + timestamp += 1; + + if is_hint_storew.is_zero() { + mem_helper.fill_from_prev(timestamp, adapter_row.data_read_aux_cols.as_mut()); + timestamp += 1; } - let write = memory.record_by_id(write_record.write_id); - cols.data_write_pointer = write.pointer; + mem_helper.fill_from_prev(timestamp, adapter_row.data_write_aux_cols.as_mut()); + } +} - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_record.pointer_read), - &mut cols.pointer_read_aux_cols, - ); - aux_cols_factory.generate_write_aux(write, &mut cols.data_write_aux_cols); +impl AdapterExecutorE1 for NativeLoadStoreAdapterStep +where + F: PrimeField32, +{ + type ReadData = (F, [F; NUM_CELLS]); + type WriteData = [F; NUM_CELLS]; + + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [read_cell]: [F; 1] = memory_read_native_from_state(state, c.as_canonical_u32()); + + let data_read_as = match local_opcode { + LOADW => e.as_canonical_u32(), + STOREW | HINT_STOREW => d.as_canonical_u32(), + }; + + debug_assert_eq!(data_read_as, AS::Native as u32); + + let data_read_ptr = match local_opcode { + LOADW => (read_cell + b).as_canonical_u32(), + STOREW | HINT_STOREW => a.as_canonical_u32(), + }; + + let data_read: [F; NUM_CELLS] = match local_opcode { + HINT_STOREW => [F::ZERO; NUM_CELLS], + LOADW | STOREW => memory_read_native_from_state(state, data_read_ptr), + }; + + (read_cell, data_read) } - fn air(&self) -> &Self::Air { - &self.air + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [read_cell]: [F; 1] = memory_read_native(state.memory, c.as_canonical_u32()); + + let data_write_as = match local_opcode { + LOADW => d.as_canonical_u32(), + STOREW | HINT_STOREW => e.as_canonical_u32(), + }; + + debug_assert_eq!(data_write_as, AS::Native as u32); + + let data_write_ptr = match local_opcode { + LOADW => a.as_canonical_u32(), + STOREW | HINT_STOREW => (read_cell + b).as_canonical_u32(), + }; + + memory_write_native_from_state(state, data_write_ptr, data); } } diff --git a/extensions/native/circuit/src/adapters/mod.rs b/extensions/native/circuit/src/adapters/mod.rs index c5cd3b9422..6041f5861c 100644 --- a/extensions/native/circuit/src/adapters/mod.rs +++ b/extensions/native/circuit/src/adapters/mod.rs @@ -1,3 +1,13 @@ +use openvm_circuit::{ + arch::{execution_mode::E1E2ExecutionCtx, VmStateMut}, + system::memory::{ + offline_checker::{MemoryBaseAuxCols, MemoryReadOrImmediateAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + }, +}; +use openvm_native_compiler::conversion::AS; +use openvm_stark_backend::p3_field::PrimeField32; + pub mod alu_native_adapter; // 2 reads, 0 writes, imm support, jump support pub mod branch_native_adapter; @@ -6,3 +16,176 @@ pub mod convert_adapter; pub mod loadstore_native_adapter; // 2 reads, 1 write, read size = write size = N, no imm support, read/write to address space d pub mod native_vectorized_adapter; + +#[inline(always)] +pub fn memory_read_native(memory: &GuestMemory, ptr: u32) -> [F; N] +where + F: PrimeField32, +{ + // SAFETY: + // - address space `AS::Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(AS::Native as u32, ptr) } +} + +#[inline(always)] +pub fn memory_read_or_imm_native(memory: &GuestMemory, addr_space: u32, ptr_or_imm: F) -> F +where + F: PrimeField32, +{ + debug_assert!(addr_space == AS::Immediate as u32 || addr_space == AS::Native as u32); + + if addr_space == AS::Native as u32 { + let [result]: [F; 1] = memory_read_native(memory, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} +#[inline(always)] +pub fn memory_write_native(memory: &mut GuestMemory, ptr: u32, data: &[F; N]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `AS::Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(AS::Native as u32, ptr, data) } +} + +#[inline(always)] +pub fn memory_read_native_from_state( + state: &mut VmStateMut, + ptr: u32, +) -> [F; N] +where + F: PrimeField32, + Ctx: E1E2ExecutionCtx, +{ + state + .ctx + .on_memory_operation(AS::Native as u32, ptr, N as u32); + + memory_read_native(state.memory, ptr) +} + +#[inline(always)] +pub fn memory_read_or_imm_native_from_state( + state: &mut VmStateMut, + addr_space: u32, + ptr_or_imm: F, +) -> F +where + F: PrimeField32, + Ctx: E1E2ExecutionCtx, +{ + debug_assert!(addr_space == AS::Immediate as u32 || addr_space == AS::Native as u32); + + if addr_space == AS::Native as u32 { + let [result]: [F; 1] = memory_read_native_from_state(state, ptr_or_imm.as_canonical_u32()); + result + } else { + ptr_or_imm + } +} + +#[inline(always)] +pub fn memory_write_native_from_state( + state: &mut VmStateMut, + ptr: u32, + data: &[F; N], +) where + F: PrimeField32, + Ctx: E1E2ExecutionCtx, +{ + state + .ctx + .on_memory_operation(AS::Native as u32, ptr, N as u32); + + memory_write_native(state.memory, ptr, data) +} +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:BLOCK_SIZE]_4)` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +fn timed_read( + memory: &mut TracingMemory, + ptr: u32, +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.read::(AS::Native as u32, ptr) } +} + +#[inline(always)] +fn timed_write( + memory: &mut TracingMemory, + ptr: u32, + vals: &[F; BLOCK_SIZE], +) -> (u32, [F; BLOCK_SIZE]) +where + F: PrimeField32, +{ + // SAFETY: + // - address space `Native` will always have cell type `F` and minimum alignment of `1` + unsafe { memory.write::(AS::Native as u32, ptr, vals) } +} + +/// Reads register value at `ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read_native( + memory: &mut TracingMemory, + ptr: u32, + aux_cols: &mut MemoryBaseAuxCols, +) -> [F; BLOCK_SIZE] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read(memory, ptr); + aux_cols.set_prev(F::from_canonical_u32(t_prev)); + data +} + +/// Writes `ptr, vals` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write_native( + memory: &mut TracingMemory, + ptr: u32, + vals: &[F; BLOCK_SIZE], + aux_cols: &mut MemoryWriteAuxCols, +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write(memory, ptr, vals); + aux_cols.set_prev(F::from_canonical_u32(t_prev), data_prev); +} + +/// Reads value at `_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read_or_imm_native( + memory: &mut TracingMemory, + addr_space: u32, + ptr_or_imm: F, + addr_space_mut: &mut F, + aux_cols: &mut MemoryReadOrImmediateAuxCols, +) -> F +where + F: PrimeField32, +{ + debug_assert!(addr_space == AS::Immediate as u32 || addr_space == AS::Native as u32); + + if addr_space == AS::Immediate as u32 { + *addr_space_mut = F::ZERO; + memory.increment_timestamp(); + ptr_or_imm + } else { + *addr_space_mut = F::from_canonical_u32(AS::Native as u32); + let data: [F; 1] = + tracing_read_native(memory, ptr_or_imm.as_canonical_u32(), &mut aux_cols.base); + data[0] + } +} diff --git a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs index c151197297..bf57a3b4e6 100644 --- a/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs +++ b/extensions/native/circuit/src/adapters/native_vectorized_adapter.rs @@ -1,20 +1,18 @@ use std::{ borrow::{Borrow, BorrowMut}, - marker::PhantomData, + mem::size_of, }; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -25,44 +23,11 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -#[allow(dead_code)] -#[derive(Debug)] -pub struct NativeVectorizedAdapterChip { - pub air: NativeVectorizedAdapterAir, - _marker: PhantomData, -} - -impl NativeVectorizedAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: NativeVectorizedAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedReadRecord { - pub b: RecordId, - pub c: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct NativeVectorizedWriteRecord { - pub from_state: ExecutionState, - pub a: RecordId, -} +use super::{ + memory_read_native_from_state, memory_write_native_from_state, tracing_read_native, + tracing_write_native, +}; #[repr(C)] #[derive(AlignedBorrow)] @@ -156,80 +121,140 @@ impl VmAdapterAir for NativeVectoriz } } -impl VmAdapterChip for NativeVectorizedAdapterChip { - type ReadRecord = NativeVectorizedReadRecord; - type WriteRecord = NativeVectorizedWriteRecord; - type Air = NativeVectorizedAdapterAir; - type Interface = BasicAdapterInterface, 2, 1, N, N>; +#[derive(derive_new::new)] +pub struct NativeVectorizedAdapterStep; - fn preprocess( - &mut self, - memory: &mut MemoryController, +impl AdapterTraceStep for NativeVectorizedAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[F; N]; 2]; + type WriteData = [F; N]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp()); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let y_val = memory.read::(d, b); - let z_val = memory.read::(e, c); - - Ok(( - [y_val.1, z_val.1], - Self::ReadRecord { - b: y_val.0, - c: z_val.0, - }, - )) + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + adapter_row.b_pointer = b; + let y_val = tracing_read_native( + memory, + b.as_canonical_u32(), + adapter_row.reads_aux[0].as_mut(), + ); + adapter_row.c_pointer = c; + let z_val = tracing_read_native( + memory, + c.as_canonical_u32(), + adapter_row.reads_aux[1].as_mut(), + ); + + [y_val, z_val] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (a_val, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - a: a_val, - }, - )) + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + adapter_row.a_pointer = a; + tracing_write_native( + memory, + a.as_canonical_u32(), + data, + &mut adapter_row.writes_aux[0], + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut NativeVectorizedAdapterCols<_, N> = row_slice.borrow_mut(); - - let b_record = memory.record_by_id(read_record.b); - let c_record = memory.record_by_id(read_record.c); - let a_record = memory.record_by_id(write_record.a); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.a_pointer = a_record.pointer; - row_slice.b_pointer = b_record.pointer; - row_slice.c_pointer = c_record.pointer; - aux_cols_factory.generate_read_aux(b_record, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(c_record, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(a_record, &mut row_slice.writes_aux[0]); + let adapter_row: &mut NativeVectorizedAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[1].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.writes_aux[0].as_mut()); } +} + +impl AdapterExecutorE1 for NativeVectorizedAdapterStep +where + F: PrimeField32, +{ + type ReadData = [[F; N]; 2]; + type WriteData = [F; N]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, e, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); + debug_assert_eq!(e.as_canonical_u32(), AS::Native as u32); + + let y_val: [F; N] = memory_read_native_from_state(state, b.as_canonical_u32()); + let z_val: [F; N] = memory_read_native_from_state(state, c.as_canonical_u32()); + + [y_val, z_val] + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), AS::Native as u32); - fn air(&self) -> &Self::Air { - &self.air + memory_write_native_from_state(state, a.as_canonical_u32(), data); } } diff --git a/extensions/native/circuit/src/branch_eq/core.rs b/extensions/native/circuit/src/branch_eq/core.rs new file mode 100644 index 0000000000..d1b9359815 --- /dev/null +++ b/extensions/native/circuit/src/branch_eq/core.rs @@ -0,0 +1,168 @@ +use std::{array, borrow::BorrowMut}; + +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterExecutorE1, AdapterTraceStep, Result, StepExecutorE1, TraceStep, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, +}; +use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_native_compiler::NativeBranchEqualOpcode; +use openvm_rv32im_circuit::BranchEqualCoreCols; +use openvm_rv32im_transpiler::BranchEqualOpcode; +use openvm_stark_backend::p3_field::PrimeField32; + +pub struct NativeBranchEqualStep { + adapter: A, + pub offset: usize, + pub pc_step: u32, +} + +impl NativeBranchEqualStep { + pub fn new(adapter: A, offset: usize, pc_step: u32) -> Self { + Self { + adapter, + offset, + pc_step, + } + } +} + +impl TraceStep for NativeBranchEqualStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[F; 2]>, + WriteData = (), + TraceContext<'a> = (), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeBranchEqualOpcode::from_usize(opcode - self.offset) + ) + } + + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = + NativeBranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let (cmp_result, diff_idx, diff_inv_val) = run_eq(branch_eq_opcode, rs1, rs2); + + let core_row: &mut BranchEqualCoreCols<_, 1> = core_row.borrow_mut(); + core_row.a = [rs1]; + core_row.b = [rs2]; + core_row.cmp_result = F::from_bool(cmp_result); + core_row.imm = imm; + core_row.opcode_beq_flag = F::from_bool(branch_eq_opcode.0 == BranchEqualOpcode::BEQ); + core_row.opcode_bne_flag = F::from_bool(branch_eq_opcode.0 == BranchEqualOpcode::BNE); + core_row.diff_inv_marker = + array::from_fn(|i| if i == diff_idx { diff_inv_val } else { F::ZERO }); + + if cmp_result { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + } +} + +impl StepExecutorE1 for NativeBranchEqualStep +where + F: PrimeField32, + A: 'static + for<'a> AdapterExecutorE1, WriteData = ()>, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = + NativeBranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + + // TODO(ayush): probably don't need the other values + let (cmp_result, _, _) = run_eq::(branch_eq_opcode, rs1, rs2); + + if cmp_result { + // TODO(ayush): verify this is fine + // state.pc = state.pc.wrapping_add(imm.as_canonical_u32()); + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) + } +} + +// Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) +#[inline(always)] +pub(super) fn run_eq(local_opcode: NativeBranchEqualOpcode, x: F, y: F) -> (bool, usize, F) +where + F: PrimeField32, +{ + if x != y { + return ( + local_opcode.0 == BranchEqualOpcode::BNE, + 0, + (x - y).inverse(), + ); + } + (local_opcode.0 == BranchEqualOpcode::BEQ, 0, F::ZERO) +} diff --git a/extensions/native/circuit/src/branch_eq/mod.rs b/extensions/native/circuit/src/branch_eq/mod.rs index e1b566bb7f..272f83289e 100644 --- a/extensions/native/circuit/src/branch_eq/mod.rs +++ b/extensions/native/circuit/src/branch_eq/mod.rs @@ -1,8 +1,12 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; -use openvm_rv32im_circuit::{BranchEqualCoreAir, BranchEqualCoreChip}; +pub mod core; -use super::adapters::branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterChip}; +use core::NativeBranchEqualStep; + +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; +use openvm_rv32im_circuit::BranchEqualCoreAir; + +use crate::adapters::branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterStep}; pub type NativeBranchEqAir = VmAirWrapper>; -pub type NativeBranchEqChip = - VmChipWrapper, BranchEqualCoreChip<1>>; +pub type NativeBranchEqStep = NativeBranchEqualStep; +pub type NativeBranchEqChip = NewVmChipWrapper; diff --git a/extensions/native/circuit/src/castf/core.rs b/extensions/native/circuit/src/castf/core.rs index 664767e35e..0c096b402a 100644 --- a/extensions/native/circuit/src/castf/core.rs +++ b/extensions/native/circuit/src/castf/core.rs @@ -1,14 +1,21 @@ use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::var_range::{ SharedVariableRangeCheckerChip, VariableRangeCheckerBus, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::CastfOpcode; use openvm_rv32im_circuit::adapters::RV32_REGISTER_NUM_LIMBS; use openvm_stark_backend::{ @@ -17,7 +24,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; // LIMB_BITS is the size of the limbs in bits. pub(crate) const LIMB_BITS: usize = 8; @@ -32,7 +38,7 @@ pub struct CastFCoreCols { pub is_valid: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct CastFCoreAir { pub bus: VariableRangeCheckerBus, /* to communicate with the range checker that checks that * all limbs are < 2^LIMB_BITS */ @@ -104,44 +110,36 @@ where } } -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct CastFRecord { - pub in_val: F, - pub out_val: [u32; RV32_REGISTER_NUM_LIMBS], -} - -pub struct CastFCoreChip { - pub air: CastFCoreAir, +#[derive(derive_new::new)] +pub struct CastFCoreStep { + adapter: A, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl CastFCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { - Self { - air: CastFCoreAir { - bus: range_checker_chip.bus(), - }, - range_checker_chip, - } - } -} - -impl> VmCoreChip for CastFCoreChip +impl TraceStep for CastFCoreStep where - I::Reads: Into<[[F; 1]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [F; 1], + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + TraceContext<'a> = (), + >, { - type Record = CastFRecord; - type Air = CastFCoreAir; + fn get_opcode_name(&self, _opcode: usize) -> String { + format!("{:?}", CastfOpcode::CASTF) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; assert_eq!( @@ -149,52 +147,102 @@ where CastfOpcode::CASTF as usize ); - let y = reads.into()[0][0]; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [y] = self.adapter.read(state.memory, instruction, adapter_row); + let x = CastF::solve(y.as_canonical_u32()); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x.map(F::from_canonical_u32)].into(), - }; + let core_row: &mut CastFCoreCols = core_row.borrow_mut(); + core_row.in_val = y; + core_row.out_val = x.map(F::from_canonical_u8); + core_row.is_valid = F::ONE; - let record = CastFRecord { - in_val: y, - out_val: x, - }; + self.adapter + .write(state.memory, instruction, adapter_row, &x); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, _opcode: usize) -> String { - format!("{:?}", CastfOpcode::CASTF) + *trace_offset += width; + + Ok(()) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for (i, limb) in record.out_val.iter().enumerate() { - if i == 3 { - self.range_checker_chip.add_count(*limb, FINAL_LIMB_BITS); - } else { - self.range_checker_chip.add_count(*limb, LIMB_BITS); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut CastFCoreCols = core_row.borrow_mut(); + + if core_row.is_valid == F::ONE { + for (i, limb) in core_row.out_val.iter().enumerate() { + if i == 3 { + self.range_checker_chip + .add_count(limb.as_canonical_u32(), FINAL_LIMB_BITS); + } else { + self.range_checker_chip + .add_count(limb.as_canonical_u32(), LIMB_BITS); + } } } + } +} + +impl StepExecutorE1 for CastFCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + assert_eq!( + opcode.local_opcode_idx(CastfOpcode::CLASS_OFFSET), + CastfOpcode::CASTF as usize + ); + + let [y] = self.adapter.read(state, instruction); + + let x = CastF::solve(y.as_canonical_u32()); + + self.adapter.write(state, instruction, &x); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - let cols: &mut CastFCoreCols = row_slice.borrow_mut(); - cols.in_val = record.in_val; - cols.out_val = record.out_val.map(F::from_canonical_u32); - cols.is_valid = F::ONE; + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } pub struct CastF; impl CastF { - pub(super) fn solve(y: u32) -> [u32; RV32_REGISTER_NUM_LIMBS] { - let mut x = [0; 4]; + pub(super) fn solve(y: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + let mut x = [0u8; RV32_REGISTER_NUM_LIMBS]; for (i, limb) in x.iter_mut().enumerate() { - *limb = (y >> (8 * i)) & 0xFF; + *limb = ((y >> (8 * i)) & 0xFF) as u8; } x } diff --git a/extensions/native/circuit/src/castf/mod.rs b/extensions/native/circuit/src/castf/mod.rs index 9fbd77f245..b7ac2fc266 100644 --- a/extensions/native/circuit/src/castf/mod.rs +++ b/extensions/native/circuit/src/castf/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::convert_adapter::{ConvertAdapterAir, ConvertAdapterChip}; +use super::adapters::convert_adapter::{ConvertAdapterAir, ConvertAdapterStep}; #[cfg(test)] mod tests; @@ -9,4 +9,5 @@ mod core; pub use core::*; pub type CastFAir = VmAirWrapper, CastFCoreAir>; -pub type CastFChip = VmChipWrapper, CastFCoreChip>; +pub type CastFStep = CastFCoreStep>; +pub type CastFChip = NewVmChipWrapper; diff --git a/extensions/native/circuit/src/castf/tests.rs b/extensions/native/circuit/src/castf/tests.rs index 9758e6b956..5b816d5945 100644 --- a/extensions/native/circuit/src/castf/tests.rs +++ b/extensions/native/circuit/src/castf/tests.rs @@ -1,6 +1,9 @@ -use std::borrow::BorrowMut; +use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_circuit::arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder}, + VmAirWrapper, +}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_native_compiler::CastfOpcode; use openvm_stark_backend::{ @@ -13,11 +16,28 @@ use openvm_stark_sdk::{ use rand::{rngs::StdRng, Rng}; use super::{ - super::adapters::convert_adapter::{ConvertAdapterChip, ConvertAdapterCols}, - CastF, CastFChip, CastFCoreChip, CastFCoreCols, FINAL_LIMB_BITS, LIMB_BITS, + super::adapters::convert_adapter::{ConvertAdapterAir, ConvertAdapterCols, ConvertAdapterStep}, + CastF, CastFChip, CastFCoreAir, CastFCoreCols, CastFStep, FINAL_LIMB_BITS, LIMB_BITS, }; + +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip(tester: &VmChipTestBuilder) -> CastFChip { + CastFChip::::new( + VmAirWrapper::new( + ConvertAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + CastFCoreAir::new(tester.range_checker().bus()), + ), + CastFStep::new( + ConvertAdapterStep::<1, 4>::new(), + tester.range_checker().clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} + fn generate_uint_number(rng: &mut StdRng) -> u32 { rng.gen_range(0..(1 << 30) - 1) } @@ -37,7 +57,7 @@ fn prepare_castf_rand_write_execute( let operand1_f = F::from_canonical_u32(y); - tester.write_cell(as_y, address_y, operand1_f); + tester.memory.write(as_y, address_y, [operand1_f]); let x = CastF::solve(operand1); tester.execute( @@ -48,7 +68,7 @@ fn prepare_castf_rand_write_execute( ), ); assert_eq!( - x.map(F::from_canonical_u32), + x.map(F::from_canonical_u8), tester.read::<4>(as_x, address_x) ); } @@ -57,15 +77,7 @@ fn prepare_castf_rand_write_execute( fn castf_rand_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(tester.range_checker()), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let num_tests: usize = 3; for _ in 0..num_tests { @@ -81,15 +93,7 @@ fn castf_rand_test() { fn negative_castf_overflow_test() { let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.range_checker(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let mut rng = create_seeded_rng(); let y = generate_uint_number(&mut rng); @@ -125,15 +129,7 @@ fn negative_castf_overflow_test() { fn negative_castf_memread_test() { let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let mut rng = create_seeded_rng(); let y = generate_uint_number(&mut rng); @@ -169,15 +165,7 @@ fn negative_castf_memread_test() { fn negative_castf_memwrite_test() { let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let mut rng = create_seeded_rng(); let y = generate_uint_number(&mut rng); @@ -213,15 +201,7 @@ fn negative_castf_memwrite_test() { fn negative_castf_as_test() { let mut tester = VmChipTestBuilder::default(); let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = CastFChip::::new( - ConvertAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - CastFCoreChip::new(range_checker_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let mut rng = create_seeded_rng(); let y = generate_uint_number(&mut rng); diff --git a/extensions/native/circuit/src/extension.rs b/extensions/native/circuit/src/extension.rs index 8f73423e40..e76fb4fa49 100644 --- a/extensions/native/circuit/src/extension.rs +++ b/extensions/native/circuit/src/extension.rs @@ -1,17 +1,19 @@ -use air::VerifyBatchBus; -use alu_native_adapter::AluNativeAdapterChip; -use branch_native_adapter::BranchNativeAdapterChip; +use alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterStep}; +use branch_native_adapter::{BranchNativeAdapterAir, BranchNativeAdapterStep}; +use convert_adapter::{ConvertAdapterAir, ConvertAdapterStep}; use derive_more::derive::From; -use loadstore_native_adapter::NativeLoadStoreAdapterChip; -use native_vectorized_adapter::NativeVectorizedAdapterChip; +use fri::{FriReducedOpeningAir, FriReducedOpeningChip, FriReducedOpeningStep}; +use jal::{JalRangeCheckAir, JalRangeCheckChip, JalRangeCheckStep}; +use loadstore_native_adapter::{NativeLoadStoreAdapterAir, NativeLoadStoreAdapterStep}; +use native_vectorized_adapter::{NativeVectorizedAdapterAir, NativeVectorizedAdapterStep}; use openvm_circuit::{ arch::{ - ExecutionBridge, MemoryConfig, SystemConfig, SystemPort, VmExtension, VmInventory, - VmInventoryBuilder, VmInventoryError, + ExecutionBridge, MemoryConfig, SystemConfig, SystemPort, VmAirWrapper, VmExtension, + VmInventory, VmInventoryBuilder, VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_instructions::{program::DEFAULT_PC_STEP, LocalOpcode, PhantomDiscriminant}; use openvm_native_compiler::{ @@ -21,7 +23,7 @@ use openvm_native_compiler::{ }; use openvm_poseidon2_air::Poseidon2Config; use openvm_rv32im_circuit::{ - BranchEqualCoreChip, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, + BranchEqualCoreAir, Rv32I, Rv32IExecutor, Rv32IPeriphery, Rv32Io, Rv32IoExecutor, Rv32IoPeriphery, Rv32M, Rv32MExecutor, Rv32MPeriphery, }; use openvm_stark_backend::p3_field::PrimeField32; @@ -29,12 +31,15 @@ use serde::{Deserialize, Serialize}; use strum::IntoEnumIterator; use crate::{ - adapters::{convert_adapter::ConvertAdapterChip, *}, - chip::NativePoseidon2Chip, + adapters::*, phantom::*, + poseidon2::{air::VerifyBatchBus, new_native_poseidon2_chip, NativePoseidon2Chip}, *, }; +// TODO(ayush): this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[derive(Clone, Debug, Serialize, Deserialize, VmConfig, derive_new::new)] pub struct NativeConfig { #[system] @@ -63,7 +68,7 @@ impl NativeConfig { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Native; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum NativeExecutor { LoadStore(NativeLoadStoreChip), BlockLoadStore(NativeLoadStoreChip), @@ -94,58 +99,83 @@ impl VmExtension for Native { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); + + let range_checker = &builder.system_base().range_checker_chip; let mut load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + NativeLoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ), + NativeLoadStoreCoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStoreOpcode::CLASS_OFFSET), NativeLoadStoreOpcode::CLASS_OFFSET, ), - NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); - load_store_chip.core.set_streams(builder.streams().clone()); - + load_store_chip.step.set_streams(builder.streams().clone()); inventory.add_executor( load_store_chip, NativeLoadStoreOpcode::iter().map(|x| x.global_opcode()), )?; let mut block_load_store_chip = NativeLoadStoreChip::::new( - NativeLoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + NativeLoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + NativeLoadStoreCoreAir::new(NativeLoadStore4Opcode::CLASS_OFFSET), + ), + NativeLoadStoreCoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStore4Opcode::CLASS_OFFSET), NativeLoadStore4Opcode::CLASS_OFFSET, ), - NativeLoadStoreCoreChip::new(NativeLoadStore4Opcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); block_load_store_chip - .core + .step .set_streams(builder.streams().clone()); - inventory.add_executor( block_load_store_chip, NativeLoadStore4Opcode::iter().map(|x| x.global_opcode()), )?; let branch_equal_chip = NativeBranchEqChip::new( - BranchNativeAdapterChip::<_>::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + NativeBranchEqAir::new( + BranchNativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchEqualCoreAir::new(NativeBranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + NativeBranchEqStep::new( + BranchNativeAdapterStep::new(), + NativeBranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( branch_equal_chip, NativeBranchEqualOpcode::iter().map(|x| x.global_opcode()), )?; - let jal_chip = JalRangeCheckChip::new( - ExecutionBridge::new(execution_bus, program_bus), - offline_memory.clone(), - builder.system_base().range_checker_chip.clone(), + let jal_chip = JalRangeCheckChip::::new( + JalRangeCheckAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + range_checker.bus(), + ), + JalRangeCheckStep::new(range_checker.clone()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( jal_chip, @@ -155,44 +185,63 @@ impl VmExtension for Native { ], )?; - let field_arithmetic_chip = FieldArithmeticChip::new( - AluNativeAdapterChip::::new(execution_bus, program_bus, memory_bridge), - FieldArithmeticCoreChip::new(), - offline_memory.clone(), + let field_arithmetic_chip = FieldArithmeticChip::::new( + VmAirWrapper::new( + AluNativeAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FieldArithmeticCoreAir::new(), + ), + FieldArithmeticStep::new(AluNativeAdapterStep::new()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( field_arithmetic_chip, FieldArithmeticOpcode::iter().map(|x| x.global_opcode()), )?; - let field_extension_chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new(execution_bus, program_bus, memory_bridge), - FieldExtensionCoreChip::new(), - offline_memory.clone(), + let field_extension_chip = FieldExtensionChip::::new( + VmAirWrapper::new( + NativeVectorizedAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FieldExtensionCoreAir::new(), + ), + FieldExtensionStep::new(NativeVectorizedAdapterStep::new()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( field_extension_chip, FieldExtensionOpcode::iter().map(|x| x.global_opcode()), )?; - let fri_reduced_opening_chip = FriReducedOpeningChip::new( - execution_bus, - program_bus, - memory_bridge, - offline_memory.clone(), - builder.streams().clone(), + let fri_reduced_opening_chip = FriReducedOpeningChip::::new( + FriReducedOpeningAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + FriReducedOpeningStep::new(builder.streams().clone()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); + inventory.add_executor( fri_reduced_opening_chip, FriOpcode::iter().map(|x| x.global_opcode()), )?; - let poseidon2_chip = NativePoseidon2Chip::new( + let poseidon2_chip = new_native_poseidon2_chip( builder.system_port(), - offline_memory.clone(), Poseidon2Config::default(), VerifyBatchBus::new(builder.new_bus_idx()), builder.streams().clone(), + // TODO: this may use too much memory. + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( poseidon2_chip, @@ -236,7 +285,7 @@ pub(crate) mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; @@ -250,11 +299,11 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintInputSubEx { fn phantom_execute( &mut self, - _: &MemoryController, + _: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -275,11 +324,11 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintSliceSubEx { fn phantom_execute( &mut self, - _: &MemoryController, + _: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let hint = match streams.input_stream.pop_front() { @@ -298,15 +347,14 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativePrintSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, _: &mut Streams, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let value = memory.unsafe_read_cell(addr_space, a); + let [value] = unsafe { memory.read::(c_upper as u32, a) }; println!("{}", value); Ok(()) } @@ -315,18 +363,16 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintBitsSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + len: u32, c_upper: u16, ) -> eyre::Result<()> { - let addr_space = F::from_canonical_u16(c_upper); - let val = memory.unsafe_read_cell(addr_space, a); + let [val] = unsafe { memory.read::(c_upper as u32, a) }; let mut val = val.as_canonical_u32(); - let len = b.as_canonical_u32(); assert!(streams.hint_stream.is_empty()); for _ in 0..len { streams @@ -341,11 +387,11 @@ pub(crate) mod phantom { impl PhantomSubExecutor for NativeHintLoadSubEx { fn phantom_execute( &mut self, - _: &MemoryController, + _: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let payload = match streams.input_stream.pop_front() { @@ -367,7 +413,7 @@ pub(crate) mod phantom { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct CastFExtension; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum CastFExtensionExecutor { CastF(CastFChip), } @@ -391,13 +437,19 @@ impl VmExtension for CastFExtension { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); - let range_checker = builder.system_base().range_checker_chip.clone(); - - let castf_chip = CastFChip::new( - ConvertAdapterChip::new(execution_bus, program_bus, memory_bridge), - CastFCoreChip::new(range_checker.clone()), - offline_memory.clone(), + let range_checker = &builder.system_base().range_checker_chip; + + let castf_chip = CastFChip::::new( + VmAirWrapper::new( + ConvertAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + CastFCoreAir::new(range_checker.bus()), + ), + CastFStep::new(ConvertAdapterStep::<1, 4>::new(), range_checker.clone()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(castf_chip, [CastfOpcode::CASTF.global_opcode()])?; diff --git a/extensions/native/circuit/src/field_arithmetic/core.rs b/extensions/native/circuit/src/field_arithmetic/core.rs index c813f6a066..2ae4811450 100644 --- a/extensions/native/circuit/src/field_arithmetic/core.rs +++ b/extensions/native/circuit/src/field_arithmetic/core.rs @@ -1,12 +1,19 @@ use std::borrow::{Borrow, BorrowMut}; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::FieldArithmeticOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -31,7 +38,7 @@ pub struct FieldArithmeticCoreCols { pub divisor_inv: T, } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldArithmeticCoreAir {} impl BaseAir for FieldArithmeticCoreAir { @@ -114,91 +121,126 @@ pub struct FieldArithmeticRecord { pub c: F, } -pub struct FieldArithmeticCoreChip { - pub air: FieldArithmeticCoreAir, +#[derive(derive_new::new)] +pub struct FieldArithmeticCoreStep { + adapter: A, } -impl FieldArithmeticCoreChip { - pub fn new() -> Self { - Self { - air: FieldArithmeticCoreAir {}, - } +impl TraceStep for FieldArithmeticCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [F; 2], + WriteData = [F; 1], + TraceContext<'a> = (), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) + ) } -} -impl Default for FieldArithmeticCoreChip { - fn default() -> Self { - Self::new() + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, .. } = instruction; + let local_opcode = FieldArithmeticOpcode::from_usize( + opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET), + ); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [b_val, c_val] = self.adapter.read(state.memory, instruction, adapter_row); + + let a_val = FieldArithmetic::run_field_arithmetic(local_opcode, b_val, c_val).unwrap(); + + let core_row: &mut FieldArithmeticCoreCols<_> = core_row.borrow_mut(); + core_row.a = a_val; + core_row.b = b_val; + core_row.c = c_val; + + core_row.is_add = F::from_bool(local_opcode == FieldArithmeticOpcode::ADD); + core_row.is_sub = F::from_bool(local_opcode == FieldArithmeticOpcode::SUB); + core_row.is_mul = F::from_bool(local_opcode == FieldArithmeticOpcode::MUL); + core_row.is_div = F::from_bool(local_opcode == FieldArithmeticOpcode::DIV); + + self.adapter + .write(state.memory, instruction, adapter_row, &[a_val]); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut FieldArithmeticCoreCols<_> = core_row.borrow_mut(); + + core_row.divisor_inv = if core_row.is_div.is_zero() { + F::ZERO + } else { + core_row.c.inverse() + }; } } -impl> VmCoreChip for FieldArithmeticCoreChip +impl StepExecutorE1 for FieldArithmeticCoreStep where - I::Reads: Into<[[F; 1]; 2]>, - I::Writes: From<[[F; 1]; 1]>, + F: PrimeField32, + A: 'static + for<'a> AdapterExecutorE1, { - type Record = FieldArithmeticRecord; - type Air = FieldArithmeticCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute_e1( + &mut self, + state: &mut VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { let Instruction { opcode, .. } = instruction; + let local_opcode = FieldArithmeticOpcode::from_usize( opcode.local_opcode_idx(FieldArithmeticOpcode::CLASS_OFFSET), ); - let data: [[F; 1]; 2] = reads.into(); - let b = data[0][0]; - let c = data[1][0]; - let a = FieldArithmetic::run_field_arithmetic(local_opcode, b, c).unwrap(); - - let output: AdapterRuntimeContext = AdapterRuntimeContext { - to_pc: None, - writes: [[a]].into(), - }; + let [b_val, c_val] = self.adapter.read(state, instruction); + let a_val = FieldArithmetic::run_field_arithmetic(local_opcode, b_val, c_val).unwrap(); - let record = Self::Record { - opcode: local_opcode, - a, - b, - c, - }; + self.adapter.write(state, instruction, &[a_val]); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldArithmeticOpcode::from_usize(opcode - FieldArithmeticOpcode::CLASS_OFFSET) - ) + Ok(()) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldArithmeticRecord { opcode, a, b, c } = record; - let row_slice: &mut FieldArithmeticCoreCols<_> = row_slice.borrow_mut(); - row_slice.a = a; - row_slice.b = b; - row_slice.c = c; - - row_slice.is_add = F::from_bool(opcode == FieldArithmeticOpcode::ADD); - row_slice.is_sub = F::from_bool(opcode == FieldArithmeticOpcode::SUB); - row_slice.is_mul = F::from_bool(opcode == FieldArithmeticOpcode::MUL); - row_slice.is_div = F::from_bool(opcode == FieldArithmeticOpcode::DIV); - row_slice.divisor_inv = if opcode == FieldArithmeticOpcode::DIV { - c.inverse() - } else { - F::ZERO - }; - } + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; - fn air(&self) -> &Self::Air { - &self.air + Ok(()) } } diff --git a/extensions/native/circuit/src/field_arithmetic/mod.rs b/extensions/native/circuit/src/field_arithmetic/mod.rs index 865434cb37..1cf9a27925 100644 --- a/extensions/native/circuit/src/field_arithmetic/mod.rs +++ b/extensions/native/circuit/src/field_arithmetic/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterChip}; +use crate::adapters::alu_native_adapter::{AluNativeAdapterAir, AluNativeAdapterStep}; #[cfg(test)] mod tests; @@ -9,5 +9,5 @@ mod core; pub use core::*; pub type FieldArithmeticAir = VmAirWrapper; -pub type FieldArithmeticChip = - VmChipWrapper, FieldArithmeticCoreChip>; +pub type FieldArithmeticStep = FieldArithmeticCoreStep; +pub type FieldArithmeticChip = NewVmChipWrapper; diff --git a/extensions/native/circuit/src/field_arithmetic/tests.rs b/extensions/native/circuit/src/field_arithmetic/tests.rs index 8e69f8c44b..25f7a62895 100644 --- a/extensions/native/circuit/src/field_arithmetic/tests.rs +++ b/extensions/native/circuit/src/field_arithmetic/tests.rs @@ -1,6 +1,9 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_circuit::arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder}, + VmAirWrapper, +}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_native_compiler::FieldArithmeticOpcode; use openvm_stark_backend::{ @@ -17,9 +20,27 @@ use rand::Rng; use strum::EnumCount; use super::{ - core::FieldArithmeticCoreChip, FieldArithmetic, FieldArithmeticChip, FieldArithmeticCoreCols, + FieldArithmetic, FieldArithmeticChip, FieldArithmeticCoreAir, FieldArithmeticCoreCols, + FieldArithmeticStep, +}; +use crate::adapters::alu_native_adapter::{ + AluNativeAdapterAir, AluNativeAdapterCols, AluNativeAdapterStep, }; -use crate::adapters::alu_native_adapter::{AluNativeAdapterChip, AluNativeAdapterCols}; + +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +fn create_test_chip(tester: &VmChipTestBuilder) -> FieldArithmeticChip { + FieldArithmeticChip::::new( + VmAirWrapper::new( + AluNativeAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldArithmeticCoreAir::new(), + ), + FieldArithmeticStep::new(AluNativeAdapterStep::new()), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} #[test] fn new_field_arithmetic_air_test() { @@ -28,15 +49,7 @@ fn new_field_arithmetic_air_test() { let xy_address_space_range = || 0usize..=1; let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let mut rng = create_seeded_rng(); @@ -74,10 +87,10 @@ fn new_field_arithmetic_air_test() { ); if as1 != 0 { - tester.write_cell(as1, address1, operand1); + tester.write(as1, address1, [operand1]); } if as2 != 0 { - tester.write_cell(as2, address2, operand2); + tester.write(as2, address2, [operand2]); } tester.execute( &mut chip, @@ -86,7 +99,7 @@ fn new_field_arithmetic_air_test() { [result_address, address1, address2, result_as, as1, as2], ), ); - assert_eq!(result, tester.read_cell(result_as, result_address)); + assert_eq!(result, tester.read::<1>(result_as, result_address)[0]); } let mut tester = tester.build().load(chip).finalize(); @@ -122,17 +135,9 @@ fn new_field_arithmetic_air_test() { #[test] fn new_field_arithmetic_air_zero_div_zero() { let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); - tester.write_cell(4, 6, BabyBear::from_canonical_u32(111)); - tester.write_cell(4, 7, BabyBear::from_canonical_u32(222)); + let mut chip = create_test_chip(&tester); + tester.write(4, 6, [BabyBear::from_canonical_u32(111)]); + tester.write(4, 7, [BabyBear::from_canonical_u32(222)]); tester.execute( &mut chip, @@ -166,16 +171,8 @@ fn new_field_arithmetic_air_zero_div_zero() { #[test] fn new_field_arithmetic_air_test_panic() { let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldArithmeticChip::new( - AluNativeAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldArithmeticCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); - tester.write_cell(4, 0, BabyBear::ZERO); + let mut chip = create_test_chip(&tester); + tester.write(4, 0, [BabyBear::ZERO]); // should panic tester.execute( &mut chip, diff --git a/extensions/native/circuit/src/field_extension/core.rs b/extensions/native/circuit/src/field_extension/core.rs index d8c83fabdd..2b77bdc7cc 100644 --- a/extensions/native/circuit/src/field_extension/core.rs +++ b/extensions/native/circuit/src/field_extension/core.rs @@ -5,12 +5,19 @@ use std::{ }; use itertools::izip; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::FieldExtensionOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -38,7 +45,7 @@ pub struct FieldExtensionCoreCols { pub divisor_inv: [T; EXT_DEG], } -#[derive(Copy, Clone, Debug)] +#[derive(derive_new::new, Copy, Clone, Debug)] pub struct FieldExtensionCoreAir {} impl BaseAir for FieldExtensionCoreAir { @@ -141,90 +148,131 @@ pub struct FieldExtensionRecord { pub z: [F; EXT_DEG], } -pub struct FieldExtensionCoreChip { - pub air: FieldExtensionCoreAir, +#[derive(derive_new::new)] +pub struct FieldExtensionCoreStep { + adapter: A, } -impl FieldExtensionCoreChip { - pub fn new() -> Self { - Self { - air: FieldExtensionCoreAir {}, - } +impl TraceStep for FieldExtensionCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [[F; EXT_DEG]; 2], + WriteData = [F; EXT_DEG], + TraceContext<'a> = (), + >, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) + ) } -} -impl Default for FieldExtensionCoreChip { - fn default() -> Self { - Self::new() + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, .. } = instruction; + + let local_opcode = FieldExtensionOpcode::from_usize( + opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET), + ); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [y, z] = self.adapter.read(state.memory, instruction, adapter_row); + + let x = FieldExtension::solve(local_opcode, y, z).unwrap(); + + let core_row: &mut FieldExtensionCoreCols<_> = core_row.borrow_mut(); + core_row.x = x; + core_row.y = y; + core_row.z = z; + core_row.is_add = F::from_bool(local_opcode == FieldExtensionOpcode::FE4ADD); + core_row.is_sub = F::from_bool(local_opcode == FieldExtensionOpcode::FE4SUB); + core_row.is_mul = F::from_bool(local_opcode == FieldExtensionOpcode::BBE4MUL); + core_row.is_div = F::from_bool(local_opcode == FieldExtensionOpcode::BBE4DIV); + + self.adapter + .write(state.memory, instruction, adapter_row, &x); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut FieldExtensionCoreCols<_> = core_row.borrow_mut(); + + core_row.divisor_inv = if core_row.is_div.is_one() { + FieldExtension::invert(core_row.z) + } else { + [F::ZERO; EXT_DEG] + }; } } -impl> VmCoreChip for FieldExtensionCoreChip +impl StepExecutorE1 for FieldExtensionCoreStep where - I::Reads: Into<[[F; EXT_DEG]; 2]>, - I::Writes: From<[[F; EXT_DEG]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1, { - type Record = FieldExtensionRecord; - type Air = FieldExtensionCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute_e1( + &mut self, + state: &mut VmStateMut, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { let Instruction { opcode, .. } = instruction; - let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET); - let data: [[F; EXT_DEG]; 2] = reads.into(); - let y: [F; EXT_DEG] = data[0]; - let z: [F; EXT_DEG] = data[1]; + let local_opcode_idx = opcode.local_opcode_idx(FieldExtensionOpcode::CLASS_OFFSET); - let x = FieldExtension::solve(FieldExtensionOpcode::from_usize(local_opcode_idx), y, z) - .unwrap(); + let [y_val, z_val] = self.adapter.read(state, instruction); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [x].into(), - }; + let x_val = FieldExtension::solve( + FieldExtensionOpcode::from_usize(local_opcode_idx), + y_val, + z_val, + ) + .unwrap(); - let record = Self::Record { - opcode: FieldExtensionOpcode::from_usize(local_opcode_idx), - x, - y, - z, - }; + self.adapter.write(state, instruction, &x_val); - Ok((output, record)) - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - FieldExtensionOpcode::from_usize(opcode - FieldExtensionOpcode::CLASS_OFFSET) - ) + Ok(()) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let FieldExtensionRecord { opcode, x, y, z } = record; - let cols: &mut FieldExtensionCoreCols<_> = row_slice.borrow_mut(); - cols.x = x; - cols.y = y; - cols.z = z; - cols.is_add = F::from_bool(opcode == FieldExtensionOpcode::FE4ADD); - cols.is_sub = F::from_bool(opcode == FieldExtensionOpcode::FE4SUB); - cols.is_mul = F::from_bool(opcode == FieldExtensionOpcode::BBE4MUL); - cols.is_div = F::from_bool(opcode == FieldExtensionOpcode::BBE4DIV); - cols.divisor_inv = if opcode == FieldExtensionOpcode::BBE4DIV { - FieldExtension::invert(z) - } else { - [F::ZERO; EXT_DEG] - }; - } + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; - fn air(&self) -> &Self::Air { - &self.air + Ok(()) } } diff --git a/extensions/native/circuit/src/field_extension/mod.rs b/extensions/native/circuit/src/field_extension/mod.rs index d109deb528..c6bcf39f49 100644 --- a/extensions/native/circuit/src/field_extension/mod.rs +++ b/extensions/native/circuit/src/field_extension/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; use super::adapters::native_vectorized_adapter::{ - NativeVectorizedAdapterAir, NativeVectorizedAdapterChip, + NativeVectorizedAdapterAir, NativeVectorizedAdapterStep, }; #[cfg(test)] @@ -12,5 +12,5 @@ pub use core::*; pub type FieldExtensionAir = VmAirWrapper, FieldExtensionCoreAir>; -pub type FieldExtensionChip = - VmChipWrapper, FieldExtensionCoreChip>; +pub type FieldExtensionStep = FieldExtensionCoreStep>; +pub type FieldExtensionChip = NewVmChipWrapper; diff --git a/extensions/native/circuit/src/field_extension/tests.rs b/extensions/native/circuit/src/field_extension/tests.rs index 66d6c94004..e158957415 100644 --- a/extensions/native/circuit/src/field_extension/tests.rs +++ b/extensions/native/circuit/src/field_extension/tests.rs @@ -3,7 +3,10 @@ use std::{ ops::{Add, Div, Mul, Sub}, }; -use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; +use openvm_circuit::arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder}, + VmAirWrapper, +}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_native_compiler::FieldExtensionOpcode; use openvm_stark_backend::{ @@ -16,25 +19,32 @@ use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; use strum::EnumCount; -use super::{ - super::adapters::native_vectorized_adapter::NativeVectorizedAdapterChip, FieldExtension, - FieldExtensionChip, FieldExtensionCoreChip, +use super::{FieldExtension, FieldExtensionChip, FieldExtensionCoreAir, FieldExtensionCoreStep}; +use crate::adapters::native_vectorized_adapter::{ + NativeVectorizedAdapterAir, NativeVectorizedAdapterStep, }; +const MAX_INS_CAPACITY: usize = 128; +type F = BabyBear; + +fn create_test_chip(tester: &VmChipTestBuilder) -> FieldExtensionChip { + FieldExtensionChip::::new( + VmAirWrapper::new( + NativeVectorizedAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + FieldExtensionCoreAir::new(), + ), + FieldExtensionCoreStep::new(NativeVectorizedAdapterStep::new()), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} + #[test] fn new_field_extension_air_test() { type F = BabyBear; let mut tester = VmChipTestBuilder::default(); - let mut chip = FieldExtensionChip::new( - NativeVectorizedAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - FieldExtensionCoreChip::new(), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&tester); let trace_width = chip.trace_width(); let mut rng = create_seeded_rng(); diff --git a/extensions/native/circuit/src/fri/mod.rs b/extensions/native/circuit/src/fri/mod.rs index 7dbc3fd851..b970cfa557 100644 --- a/extensions/native/circuit/src/fri/mod.rs +++ b/extensions/native/circuit/src/fri/mod.rs @@ -5,38 +5,36 @@ use std::{ sync::{Arc, Mutex}, }; -use itertools::{zip_eq, Itertools}; +use itertools::zip_eq; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionBridge, ExecutionState, NewVmChipWrapper, Result, StepExecutorE1, Streams, + TraceStep, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols, AUX_LEN}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; +use openvm_circuit_primitives::is_less_than::LessThanAuxCols; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, FriOpcode::FRI_REDUCED_OPENING}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, + p3_matrix::Matrix, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, }; -use serde::{Deserialize, Serialize}; use static_assertions::const_assert_eq; use crate::{ + adapters::{ + memory_read_native, memory_write_native, tracing_read_native, tracing_write_native, + }, field_extension::{FieldExtension, EXT_DEG}, utils::const_max, }; @@ -219,8 +217,8 @@ const INSTRUCTION_READS: usize = 5; /// it starts with a Workload row (T1) and ends with either a Disabled or Instruction2 row (T7). /// The other transition constraints then ensure the proper state transitions from Workload to /// Instruction2. -#[derive(Copy, Clone, Debug)] -struct FriReducedOpeningAir { +#[derive(Copy, Clone, Debug, derive_new::new)] +pub struct FriReducedOpeningAir { execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, } @@ -544,94 +542,83 @@ fn elem_to_ext(elem: F) -> [F; EXT_DEG] { ret } -#[derive(Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct FriReducedOpeningRecord { - pub pc: F, - pub start_timestamp: F, - pub instruction: Instruction, - pub alpha_read: RecordId, - pub length_read: RecordId, - pub a_ptr_read: RecordId, - pub is_init_read: RecordId, - pub b_ptr_read: RecordId, - pub a_rws: Vec, - pub b_reads: Vec, - pub result_write: RecordId, +pub struct FriReducedOpeningStep { + pub height: usize, + streams: Arc>>, } -impl FriReducedOpeningRecord { - pub fn get_height(&self) -> usize { - // 2 for instruction rows - self.a_rws.len() + 2 +impl FriReducedOpeningStep { + pub fn new(streams: Arc>>) -> Self { + Self { height: 0, streams } } } -pub struct FriReducedOpeningChip { - air: FriReducedOpeningAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, - streams: Arc>>, -} -impl FriReducedOpeningChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, - streams: Arc>>, - ) -> Self { - let air = FriReducedOpeningAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }; - Self { - records: vec![], - air, - height: 0, - offline_memory, - streams, - } +impl TraceStep for FriReducedOpeningStep +where + F: PrimeField32, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); + String::from("FRI_REDUCED_OPENING") } -} -impl InstructionExecutor for FriReducedOpeningChip { + fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + trace: &mut [F], + trace_offset: &mut usize, + _width: usize, + ) -> Result<()> { let &Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, + a, + b, + c, + d, + e, + f, + g, .. } = instruction; - let addr_space = F::from_canonical_u32(AS::Native as u32); - let alpha_read = memory.read(addr_space, alpha_ptr); - let length_read = memory.read_cell(addr_space, length_ptr); - let a_ptr_read = memory.read_cell(addr_space, a_ptr_ptr); - let b_ptr_read = memory.read_cell(addr_space, b_ptr_ptr); - let is_init_read = memory.read_cell(addr_space, is_init_ptr); - let is_init = is_init_read.1.as_canonical_u32(); + let a_ptr_ptr = a.as_canonical_u32(); + let b_ptr_ptr = b.as_canonical_u32(); + let length_ptr = c.as_canonical_u32(); + let alpha_ptr = d.as_canonical_u32(); + let result_ptr = e.as_canonical_u32(); + let hint_id_ptr = f.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); - let hint_id_f = memory.unsafe_read_cell(addr_space, hint_id_ptr); + let timestamp_start = state.memory.timestamp(); + + // TODO(ayush): there should be a way to avoid this + let mut alpha_aux = MemoryReadAuxCols::new(0, LessThanAuxCols::new([F::ZERO; AUX_LEN])); + let alpha = tracing_read_native(state.memory, alpha_ptr, alpha_aux.as_mut()); + + let mut length_aux = MemoryReadAuxCols::new(0, LessThanAuxCols::new([F::ZERO; AUX_LEN])); + let [length]: [F; 1] = tracing_read_native(state.memory, length_ptr, length_aux.as_mut()); + + let mut a_ptr_aux = MemoryReadAuxCols::new(0, LessThanAuxCols::new([F::ZERO; AUX_LEN])); + let [a_ptr]: [F; 1] = tracing_read_native(state.memory, a_ptr_ptr, a_ptr_aux.as_mut()); + + let mut b_ptr_aux = MemoryReadAuxCols::new(0, LessThanAuxCols::new([F::ZERO; AUX_LEN])); + let [b_ptr]: [F; 1] = tracing_read_native(state.memory, b_ptr_ptr, b_ptr_aux.as_mut()); + + let mut is_init_aux = MemoryReadAuxCols::new(0, LessThanAuxCols::new([F::ZERO; AUX_LEN])); + let [is_init_read]: [F; 1] = + tracing_read_native(state.memory, is_init_ptr, is_init_aux.as_mut()); + let is_init = is_init_read.as_canonical_u32(); + + let [hint_id_f]: [F; 1] = memory_read_native(state.memory.data(), hint_id_ptr); let hint_id = hint_id_f.as_canonical_u32() as usize; - let alpha = alpha_read.1; - let length = length_read.1.as_canonical_u32() as usize; - let a_ptr = a_ptr_read.1; - let b_ptr = b_ptr_read.1; + let length = length.as_canonical_u32() as usize; - let mut a_rws = Vec::with_capacity(length); - let mut b_reads = Vec::with_capacity(length); - let mut result = [F::ZERO; EXT_DEG]; + let write_a = F::ONE - is_init_read; + + // TODO(ayush): why do we need this?should this be incremented only in tracegen execute? + // 2 for instruction rows + self.height += length + 2; let data = if is_init == 0 { let mut streams = self.streams.lock().unwrap(); @@ -640,122 +627,38 @@ impl InstructionExecutor for FriReducedOpeningChip { } else { vec![] }; + + let mut as_and_bs = Vec::with_capacity(length); #[allow(clippy::needless_range_loop)] for i in 0..length { - let a_rw = if is_init == 0 { - let (record_id, _) = - memory.write_cell(addr_space, a_ptr + F::from_canonical_usize(i), data[i]); - (record_id, data[i]) + // First read goes to last row + let start = *trace_offset + (length - i - 1) * OVERALL_WIDTH; + let cols: &mut WorkloadCols = trace[start..start + WL_WIDTH].borrow_mut(); + + let a_ptr_i = (a_ptr + F::from_canonical_usize(i)).as_canonical_u32(); + let [a]: [F; 1] = if is_init == 0 { + tracing_write_native(state.memory, a_ptr_i, &[data[i]], &mut cols.a_aux); + [data[i]] } else { - memory.read_cell(addr_space, a_ptr + F::from_canonical_usize(i)) + tracing_read_native(state.memory, a_ptr_i, cols.a_aux.as_mut()) }; - let b_read = - memory.read::(addr_space, b_ptr + F::from_canonical_usize(EXT_DEG * i)); - a_rws.push(a_rw); - b_reads.push(b_read); - } + let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32(); + let b = tracing_read_native::(state.memory, b_ptr_i, cols.b_aux.as_mut()); - for (a_rw, b_read) in a_rws.iter().rev().zip_eq(b_reads.iter().rev()) { - let a = a_rw.1; - let b = b_read.1; - // result = result * alpha + (b - a) - result = FieldExtension::add( - FieldExtension::multiply(result, alpha), - FieldExtension::subtract(b, elem_to_ext(a)), - ); + as_and_bs.push((a, b)); } - let (result_write, _) = memory.write(addr_space, result_ptr, result); - - let record = FriReducedOpeningRecord { - pc: F::from_canonical_u32(from_state.pc), - start_timestamp: F::from_canonical_u32(from_state.timestamp), - instruction: instruction.clone(), - alpha_read: alpha_read.0, - length_read: length_read.0, - a_ptr_read: a_ptr_read.0, - is_init_read: is_init_read.0, - b_ptr_read: b_ptr_read.0, - a_rws: a_rws.into_iter().map(|r| r.0).collect(), - b_reads: b_reads.into_iter().map(|r| r.0).collect(), - result_write, - }; - self.height += record.get_height(); - self.records.push(record); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } - - fn get_opcode_name(&self, opcode: usize) -> String { - assert_eq!(opcode, FRI_REDUCED_OPENING.global_opcode().as_usize()); - String::from("FRI_REDUCED_OPENING") - } -} + let mut result = [F::ZERO; EXT_DEG]; + for (i, (a, b)) in as_and_bs.into_iter().rev().enumerate() { + let start = *trace_offset + i * OVERALL_WIDTH; + let cols: &mut WorkloadCols = trace[start..start + WL_WIDTH].borrow_mut(); -fn record_to_rows( - record: FriReducedOpeningRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, -) { - let Instruction { - a: a_ptr_ptr, - b: b_ptr_ptr, - c: length_ptr, - d: alpha_ptr, - e: result_ptr, - f: hint_id_ptr, - g: is_init_ptr, - .. - } = record.instruction; - - let length_read = memory.record_by_id(record.length_read); - let alpha_read = memory.record_by_id(record.alpha_read); - let a_ptr_read = memory.record_by_id(record.a_ptr_read); - let b_ptr_read = memory.record_by_id(record.b_ptr_read); - let is_init_read = memory.record_by_id(record.is_init_read); - let is_init = is_init_read.data_at(0); - let write_a = F::ONE - is_init; - - let length = length_read.data_at(0).as_canonical_u32() as usize; - let alpha: [F; EXT_DEG] = alpha_read.data_slice().try_into().unwrap(); - let a_ptr = a_ptr_read.data_at(0); - let b_ptr = b_ptr_read.data_at(0); - - let mut result = [F::ZERO; EXT_DEG]; - - let alpha_aux = aux_cols_factory.make_read_aux_cols(alpha_read); - let length_aux = aux_cols_factory.make_read_aux_cols(length_read); - let a_ptr_aux = aux_cols_factory.make_read_aux_cols(a_ptr_read); - let b_ptr_aux = aux_cols_factory.make_read_aux_cols(b_ptr_read); - let is_init_aux = aux_cols_factory.make_read_aux_cols(is_init_read); - - let result_aux = aux_cols_factory.make_write_aux_cols(memory.record_by_id(record.result_write)); - - // WorkloadCols - for (i, (&a_record_id, &b_record_id)) in record - .a_rws - .iter() - .rev() - .zip_eq(record.b_reads.iter().rev()) - .enumerate() - { - let a_rw = memory.record_by_id(a_record_id); - let b_read = memory.record_by_id(b_record_id); - let a = a_rw.data_at(0); - let b: [F; EXT_DEG] = b_read.data_slice().try_into().unwrap(); - - let start = i * OVERALL_WIDTH; - let cols: &mut WorkloadCols = slice[start..start + WL_WIDTH].borrow_mut(); - *cols = WorkloadCols { - prefix: PrefixCols { + cols.prefix = PrefixCols { general: GeneralCols { is_workload_row: F::ONE, is_ins_row: F::ZERO, - timestamp: record.start_timestamp + F::from_canonical_usize((length - i) * 2), + timestamp: F::from_canonical_u32(timestamp_start) + + F::from_canonical_usize((length - i) * 2), }, a_or_is_first: a, data: DataCols { @@ -766,133 +669,226 @@ fn record_to_rows( result, alpha, }, - }, - // Generate write aux columns no matter `a` is read or written. When `a` is written, - // `prev_data` is not constrained. - a_aux: if a_rw.prev_data_slice().is_some() { - aux_cols_factory.make_write_aux_cols(a_rw) - } else { - let read_aux = aux_cols_factory.make_read_aux_cols(a_rw); - MemoryWriteAuxCols::from_base(read_aux.get_base(), [F::ZERO]) - }, - b, - b_aux: aux_cols_factory.make_read_aux_cols(b_read), - }; - // result = result * alpha + (b - a) - result = FieldExtension::add( - FieldExtension::multiply(result, alpha), - FieldExtension::subtract(b, elem_to_ext(a)), - ); - } - // Instruction1Cols - { - let start = length * OVERALL_WIDTH; - let cols: &mut Instruction1Cols = slice[start..start + INS_1_WIDTH].borrow_mut(); - *cols = Instruction1Cols { - prefix: PrefixCols { - general: GeneralCols { - is_workload_row: F::ZERO, - is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - a_or_is_first: F::ONE, - data: DataCols { - a_ptr, - write_a, - b_ptr, - idx: F::from_canonical_usize(length), - result, - alpha, + }; + cols.b = b; + + // result = result * alpha + (b - a) + result = FieldExtension::add( + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), + ); + } + + // Instruction1Cols + { + let start = *trace_offset + length * OVERALL_WIDTH; + let cols: &mut Instruction1Cols = trace[start..start + INS_1_WIDTH].borrow_mut(); + *cols = Instruction1Cols { + prefix: PrefixCols { + general: GeneralCols { + is_workload_row: F::ZERO, + is_ins_row: F::ONE, + timestamp: F::from_canonical_u32(timestamp_start), + }, + a_or_is_first: F::ONE, + data: DataCols { + a_ptr, + write_a, + b_ptr, + idx: F::from_canonical_usize(length), + result, + alpha, + }, }, - }, - pc: record.pc, - a_ptr_ptr, - a_ptr_aux, - b_ptr_ptr, - b_ptr_aux, - write_a_x_is_first: write_a, - }; - } - // Instruction2Cols - { - let start = (length + 1) * OVERALL_WIDTH; - let cols: &mut Instruction2Cols = slice[start..start + INS_2_WIDTH].borrow_mut(); - *cols = Instruction2Cols { - general: GeneralCols { + pc: F::from_canonical_u32(*state.pc), + a_ptr_ptr: a, + a_ptr_aux, + b_ptr_ptr: b, + b_ptr_aux, + write_a_x_is_first: write_a, + }; + } + + // Instruction2Cols + { + let start = *trace_offset + (length + 1) * OVERALL_WIDTH; + let cols: &mut Instruction2Cols = trace[start..start + INS_2_WIDTH].borrow_mut(); + cols.general = GeneralCols { is_workload_row: F::ZERO, is_ins_row: F::ONE, - timestamp: record.start_timestamp, - }, - is_first: F::ZERO, - length_ptr, - length_aux, - alpha_ptr, - alpha_aux, - result_ptr, - result_aux, - hint_id_ptr, - is_init_ptr, - is_init_aux, - write_a_x_is_first: F::ZERO, - }; - } -} + timestamp: F::from_canonical_u32(timestamp_start), + }; + cols.is_first = F::ZERO; + cols.length_ptr = c; + cols.length_aux = length_aux; + cols.alpha_ptr = d; + cols.alpha_aux = alpha_aux; + cols.result_ptr = e; + cols.hint_id_ptr = f; + cols.is_init_ptr = g; + cols.is_init_aux = is_init_aux; + cols.write_a_x_is_first = F::ZERO; + + tracing_write_native(state.memory, result_ptr, &result, &mut cols.result_aux); + + // TODO(ayush): this is a bad hack to make length available to fill_trace_row + cols.result_aux.base.timestamp_lt_aux.lower_decomp[0] = + F::from_canonical_u32(length as u32); + } -impl ChipUsageGetter for FriReducedOpeningChip { - fn air_name(&self) -> String { - "FriReducedOpeningAir".to_string() - } + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn current_trace_height(&self) -> usize { - self.height + *trace_offset += (length + 2) * OVERALL_WIDTH; + + Ok(()) } - fn trace_width(&self) -> usize { - OVERALL_WIDTH + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (is_workload_row, is_ins_row) = { + let cols: &GeneralCols = row_slice[..GENERAL_WIDTH].borrow(); + (cols.is_workload_row.is_one(), cols.is_ins_row.is_one()) + }; + + if is_workload_row { + let cols: &mut WorkloadCols = row_slice[..WL_WIDTH].borrow_mut(); + + let timestamp = cols.prefix.general.timestamp.as_canonical_u32(); + mem_helper.fill_from_prev(timestamp + 3, cols.a_aux.as_mut()); + mem_helper.fill_from_prev(timestamp + 4, cols.b_aux.as_mut()); + } + + if is_ins_row { + let is_ins_1_row = row_slice[GENERAL_WIDTH].is_one(); + + if is_ins_1_row { + let cols: &mut Instruction1Cols = row_slice[..INS_1_WIDTH].borrow_mut(); + let timestamp = cols.prefix.general.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp + 2, cols.a_ptr_aux.as_mut()); + mem_helper.fill_from_prev(timestamp + 3, cols.b_ptr_aux.as_mut()); + } else { + let cols: &mut Instruction2Cols = row_slice[..INS_2_WIDTH].borrow_mut(); + let timestamp = cols.general.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, cols.alpha_aux.as_mut()); + mem_helper.fill_from_prev(timestamp + 1, cols.length_aux.as_mut()); + mem_helper.fill_from_prev(timestamp + 4, cols.is_init_aux.as_mut()); + + // TODO(ayush): this is bad + let length = cols.result_aux.get_base().timestamp_lt_aux.lower_decomp[0]; + mem_helper.fill_from_prev( + timestamp + 5 + 2 * length.as_canonical_u32(), + cols.result_aux.as_mut(), + ); + } + } } } -impl Chip for FriReducedOpeningChip> +impl StepExecutorE1 for FriReducedOpeningStep where - Val: PrimeField32, + F: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let chunked_trace = { - let sizes: Vec<_> = self - .records - .par_iter() - .map(|record| OVERALL_WIDTH * record.get_height()) - .collect(); - variable_chunks_mut(&mut flat_trace, &sizes) + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { + a, + b, + c, + d, + e, + f, + g, + .. + } = instruction; + + let a_ptr_ptr = a.as_canonical_u32(); + let b_ptr_ptr = b.as_canonical_u32(); + let length_ptr = c.as_canonical_u32(); + let alpha_ptr = d.as_canonical_u32(); + let result_ptr = e.as_canonical_u32(); + let hint_id_ptr = f.as_canonical_u32(); + let is_init_ptr = g.as_canonical_u32(); + + let alpha = memory_read_native(state.memory, alpha_ptr); + let [length]: [F; 1] = memory_read_native(state.memory, length_ptr); + let [a_ptr]: [F; 1] = memory_read_native(state.memory, a_ptr_ptr); + let [b_ptr]: [F; 1] = memory_read_native(state.memory, b_ptr_ptr); + let [is_init_read]: [F; 1] = memory_read_native(state.memory, is_init_ptr); + let is_init = is_init_read.as_canonical_u32(); + + let [hint_id_f]: [F; 1] = memory_read_native(state.memory, hint_id_ptr); + let hint_id = hint_id_f.as_canonical_u32() as usize; + + let length = length.as_canonical_u32() as usize; + + let data = if is_init == 0 { + let mut streams = self.streams.lock().unwrap(); + let hint_steam = &mut streams.hint_space[hint_id]; + hint_steam.drain(0..length).collect() + } else { + vec![] }; - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); + let mut as_and_bs = Vec::with_capacity(length); + #[allow(clippy::needless_range_loop)] + for i in 0..length { + let a_ptr_i = (a_ptr + F::from_canonical_usize(i)).as_canonical_u32(); + let [a]: [F; 1] = if is_init == 0 { + memory_write_native(state.memory, a_ptr_i, &[data[i]]); + [data[i]] + } else { + memory_read_native(state.memory, a_ptr_i) + }; + let b_ptr_i = (b_ptr + F::from_canonical_usize(EXT_DEG * i)).as_canonical_u32(); + let b = memory_read_native::(state.memory, b_ptr_i); - self.records - .into_par_iter() - .zip_eq(chunked_trace.into_par_iter()) - .for_each(|(record, slice)| { - record_to_rows(record, &aux_cols_factory, slice, &memory); - }); + as_and_bs.push((a, b)); + } - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) + let mut result = [F::ZERO; EXT_DEG]; + for (a, b) in as_and_bs.into_iter().rev() { + // result = result * alpha + (b - a) + result = FieldExtension::add( + FieldExtension::multiply(result, alpha), + FieldExtension::subtract(b, elem_to_ext(a)), + ); + } + + // TODO(ayush): why do we need this?should this be incremented only in tracegen execute? + // 2 for instruction rows + self.height += length + 2; + + memory_write_native(state.memory, result_ptr, &result); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } -} -fn variable_chunks_mut<'a, T>(mut slice: &'a mut [T], sizes: &[usize]) -> Vec<&'a mut [T]> { - let mut result = Vec::with_capacity(sizes.len()); - for &size in sizes { - // split_at_mut guarantees disjoint slices - let (left, right) = slice.split_at_mut(size); - result.push(left); - slice = right; // move forward for the next chunk + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + let &Instruction { c, .. } = instruction; + + let length_ptr = c.as_canonical_u32(); + let [length]: [F; 1] = memory_read_native(state.memory, length_ptr); + + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += length.as_canonical_u32() + 2; + + Ok(()) } - result } + +pub type FriReducedOpeningChip = + NewVmChipWrapper>; diff --git a/extensions/native/circuit/src/fri/tests.rs b/extensions/native/circuit/src/fri/tests.rs index 97dcdbc532..5779be638c 100644 --- a/extensions/native/circuit/src/fri/tests.rs +++ b/extensions/native/circuit/src/fri/tests.rs @@ -15,8 +15,26 @@ use openvm_stark_backend::{ use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::Rng; -use super::{super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningChip, EXT_DEG}; -use crate::OVERALL_WIDTH; +use super::{ + super::field_extension::FieldExtension, elem_to_ext, FriReducedOpeningAir, + FriReducedOpeningChip, FriReducedOpeningStep, EXT_DEG, +}; +use crate::fri::OVERALL_WIDTH; + +const MAX_INS_CAPACITY: usize = 1024; +type F = BabyBear; + +fn create_test_chip( + tester: &VmChipTestBuilder, + streams: Arc>>, +) -> FriReducedOpeningChip { + FriReducedOpeningChip::::new( + FriReducedOpeningAir::new(tester.execution_bridge(), tester.memory_bridge()), + FriReducedOpeningStep::new(streams), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} fn compute_fri_mat_opening( alpha: [F; EXT_DEG], @@ -44,13 +62,7 @@ fn fri_mat_opening_air_test() { let mut tester = VmChipTestBuilder::default(); let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = FriReducedOpeningChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - streams.clone(), - ); + let mut chip = create_test_chip(&tester, streams.clone()); let mut rng = create_seeded_rng(); @@ -85,39 +97,39 @@ fn fri_mat_opening_air_test() { let address_space = 4usize; - /*tracing::debug!( - "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, x = {}, y = {}", - result_as, as1, as2, result_pointer, address1, address2, result, operand1, operand2, - );*/ + // tracing::debug!( + // "{opcode:?} d = {}, e = {}, f = {}, result_addr = {}, addr1 = {}, addr2 = {}, z = {}, + // x = {}, y = {}", result_as, as1, as2, result_pointer, address1, address2, + // result, operand1, operand2, ); tester.write(address_space, alpha_pointer, alpha); - tester.write_cell( + tester.write( address_space, length_pointer, - BabyBear::from_canonical_usize(length), + [BabyBear::from_canonical_usize(length)], ); - tester.write_cell( + tester.write( address_space, a_pointer_pointer, - BabyBear::from_canonical_usize(a_pointer), + [BabyBear::from_canonical_usize(a_pointer)], ); - tester.write_cell( + tester.write( address_space, b_pointer_pointer, - BabyBear::from_canonical_usize(b_pointer), + [BabyBear::from_canonical_usize(b_pointer)], ); let is_init = rng.gen_range(0..2); - tester.write_cell( + tester.write( address_space, is_init_ptr, - BabyBear::from_canonical_u32(is_init), + [BabyBear::from_canonical_u32(is_init)], ); if is_init == 0 { streams.lock().unwrap().hint_space[0].extend_from_slice(&a); } else { for (i, ai) in a.iter().enumerate() { - tester.write_cell(address_space, a_pointer + i, *ai); + tester.write(address_space, a_pointer + i, [*ai]); } } for (i, bi) in b.iter().enumerate() { @@ -142,7 +154,7 @@ fn fri_mat_opening_air_test() { assert_eq!(result, tester.read(address_space, result_pointer)); // Check that `a` was populated. for (i, ai) in a.iter().enumerate() { - let found = tester.read_cell(address_space, a_pointer + i); + let [found] = tester.read(address_space, a_pointer + i); assert_eq!(*ai, found); } } diff --git a/extensions/native/circuit/src/jal/mod.rs b/extensions/native/circuit/src/jal/mod.rs index 28322834a2..7fbeace742 100644 --- a/extensions/native/circuit/src/jal/mod.rs +++ b/extensions/native/circuit/src/jal/mod.rs @@ -1,40 +1,38 @@ use std::{ borrow::{Borrow, BorrowMut}, ops::Deref, - sync::{Arc, Mutex}, }; use openvm_circuit::{ - arch::{ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, PcIncOrSet}, + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionBridge, ExecutionError, ExecutionState, NewVmChipWrapper, PcIncOrSet, Result, + StepExecutorE1, TraceStep, VmStateMut, + }, system::memory::{ offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; -use openvm_circuit_primitives::{ - utils::next_power_of_two_or_zero, - var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip, - }, +use openvm_circuit_primitives::var_range::{ + SharedVariableRangeCheckerChip, VariableRangeCheckerBus, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{conversion::AS, NativeJalOpcode, NativeRangeCheckOpcode}; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, + p3_matrix::Matrix, rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - AirRef, Chip, ChipUsageGetter, }; -use serde::{Deserialize, Serialize}; use static_assertions::const_assert_eq; use AS::Native; +use crate::adapters::{memory_read_native, memory_write_native, tracing_write_native}; + #[cfg(test)] mod tests; @@ -57,7 +55,7 @@ struct JalRangeCheckCols { const OVERALL_WIDTH: usize = JalRangeCheckCols::::width(); const_assert_eq!(OVERALL_WIDTH, 12); -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct JalRangeCheckAir { execution_bridge: ExecutionBridge, memory_bridge: MemoryBridge, @@ -136,207 +134,211 @@ where } } -impl JalRangeCheckAir { - fn new( - execution_bridge: ExecutionBridge, - memory_bridge: MemoryBridge, - range_bus: VariableRangeCheckerBus, - ) -> Self { - Self { - execution_bridge, - memory_bridge, - range_bus, - } - } -} - -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct JalRangeCheckRecord { - pub state: ExecutionState, - pub a_rw: RecordId, - pub b: u32, - pub c: u8, - pub is_jal: bool, -} - /// Chip for JAL and RANGE_CHECK. These opcodes are logically irrelevant. Putting these opcodes into /// the same chip is just to save columns. -pub struct JalRangeCheckChip { - air: JalRangeCheckAir, - pub records: Vec, - offline_memory: Arc>>, +pub struct JalRangeCheckStep { range_checker_chip: SharedVariableRangeCheckerChip, /// If true, ignore execution errors. debug: bool, } -impl JalRangeCheckChip { - pub fn new( - execution_bridge: ExecutionBridge, - offline_memory: Arc>>, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - let memory_bridge = offline_memory.lock().unwrap().memory_bridge(); - let air = JalRangeCheckAir::new(execution_bridge, memory_bridge, range_checker_chip.bus()); +impl JalRangeCheckStep { + pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { Self { - air, - records: vec![], - offline_memory, range_checker_chip, debug: false, } } - pub fn with_debug(mut self) -> Self { + pub fn set_debug(&mut self) { self.debug = true; - self } } -impl InstructionExecutor for JalRangeCheckChip { +impl TraceStep for JalRangeCheckStep +where + F: PrimeField32, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); + let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK + .global_opcode() + .as_usize(); + if opcode == jal_opcode { + return String::from("JAL"); + } + if opcode == range_check_opcode { + return String::from("RANGE_CHECK"); + } + panic!("Unknown opcode {}", opcode); + } + fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { - if instruction.opcode == NativeJalOpcode::JAL.global_opcode() { - let (record_id, _) = memory.write( - F::from_canonical_u32(AS::Native as u32), - instruction.a, - [F::from_canonical_u32(from_state.pc + DEFAULT_PC_STEP)], + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { + opcode, a, b, c, .. + } = instruction; + + debug_assert!( + opcode == NativeJalOpcode::JAL.global_opcode() + || opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() + ); + + let row: &mut JalRangeCheckCols = + trace[*trace_offset..*trace_offset + width].borrow_mut(); + + row.state.pc = F::from_canonical_u32(*state.pc); + row.state.timestamp = F::from_canonical_u32(state.memory.timestamp); + + row.a_pointer = a; + row.b = b; + + if opcode == NativeJalOpcode::JAL.global_opcode() { + row.is_jal = F::ONE; + row.c = F::ZERO; + + tracing_write_native( + state.memory, + a.as_canonical_u32(), + &[F::from_canonical_u32( + state.pc.wrapping_add(DEFAULT_PC_STEP), + )], + &mut row.writes_aux, ); - let b = instruction.b.as_canonical_u32(); - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: 0, - is_jal: true, - }); - return Ok(ExecutionState { - pc: (F::from_canonical_u32(from_state.pc) + instruction.b).as_canonical_u32(), - timestamp: memory.timestamp(), - }); - } else if instruction.opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { - let d = F::from_canonical_u32(AS::Native as u32); - // This is a read, but we make the record have prev_data - let a_val = memory.unsafe_read_cell(d, instruction.a); - let (record_id, _) = memory.write(d, instruction.a, [a_val]); + // TODO(ayush): can this addition be done in u32 instead of F + *state.pc = (F::from_canonical_u32(*state.pc) + b).as_canonical_u32(); + } else if opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + row.is_jal = F::ZERO; + row.c = c; + + let [a_val]: [F; 1] = memory_read_native(state.memory.data(), a.as_canonical_u32()); + tracing_write_native( + state.memory, + a.as_canonical_u32(), + &[a_val], + &mut row.writes_aux, + ); + + // TODO(ayush): should this debug stuff be removed? let a_val = a_val.as_canonical_u32(); - let b = instruction.b.as_canonical_u32(); - let c = instruction.c.as_canonical_u32(); + let b = b.as_canonical_u32(); + let c = c.as_canonical_u32(); + debug_assert!(!self.debug || b <= 16); debug_assert!(!self.debug || c <= 14); + let x = a_val & ((1 << 16) - 1); if !self.debug && x >= 1 << b { - return Err(ExecutionError::Fail { pc: from_state.pc }); + return Err(ExecutionError::Fail { pc: *state.pc }); } let y = a_val >> 16; if !self.debug && y >= 1 << c { - return Err(ExecutionError::Fail { pc: from_state.pc }); + return Err(ExecutionError::Fail { pc: *state.pc }); } - self.records.push(JalRangeCheckRecord { - state: from_state, - a_rw: record_id, - b, - c: c as u8, - is_jal: false, - }); - return Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }); - } - panic!("Unknown opcode {}", instruction.opcode); - } - fn get_opcode_name(&self, opcode: usize) -> String { - let jal_opcode = NativeJalOpcode::JAL.global_opcode().as_usize(); - let range_check_opcode = NativeRangeCheckOpcode::RANGE_CHECK - .global_opcode() - .as_usize(); - if opcode == jal_opcode { - return String::from("JAL"); - } - if opcode == range_check_opcode { - return String::from("RANGE_CHECK"); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); } - panic!("Unknown opcode {}", opcode); - } -} -impl ChipUsageGetter for JalRangeCheckChip { - fn air_name(&self) -> String { - "JalRangeCheck".to_string() - } + *trace_offset += width; - fn current_trace_height(&self) -> usize { - self.records.len() + Ok(()) } - fn trace_width(&self) -> usize { - OVERALL_WIDTH + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let row: &mut JalRangeCheckCols<_> = row_slice.borrow_mut(); + + let timestamp = row.state.timestamp.as_canonical_u32(); + mem_helper.fill_from_prev(timestamp, row.writes_aux.as_mut()); + + row.is_range_check = F::ONE - row.is_jal; + + if row.is_range_check.is_one() { + let a_val = row.writes_aux.prev_data()[0]; + let a_val_u32 = a_val.as_canonical_u32(); + let y = a_val_u32 >> 16; + let x = a_val_u32 & ((1 << 16) - 1); + self.range_checker_chip + .add_count(x, row.b.as_canonical_u32() as usize); + self.range_checker_chip + .add_count(y, row.c.as_canonical_u32() as usize); + row.y = F::from_canonical_u32(y); + } } } -impl Chip for JalRangeCheckChip> +impl StepExecutorE1 for JalRangeCheckStep where - Val: PrimeField32, + F: PrimeField32, { - fn air(&self) -> AirRef { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - let height = next_power_of_two_or_zero(self.records.len()); - let mut flat_trace = Val::::zero_vec(OVERALL_WIDTH * height); - let memory = self.offline_memory.lock().unwrap(); - let aux_cols_factory = memory.aux_cols_factory(); - - self.records - .into_par_iter() - .zip(flat_trace.par_chunks_mut(OVERALL_WIDTH)) - .for_each(|(record, slice)| { - record_to_row( - record, - &aux_cols_factory, - self.range_checker_chip.as_ref(), - slice, - &memory, - ); - }); + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { opcode, a, b, .. } = instruction; + + debug_assert!( + opcode == NativeJalOpcode::JAL.global_opcode() + || opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() + ); + + if opcode == NativeJalOpcode::JAL.global_opcode() { + memory_write_native( + state.memory, + a.as_canonical_u32(), + &[F::from_canonical_u32( + state.pc.wrapping_add(DEFAULT_PC_STEP), + )], + ); + // TODO(ayush): can this addition be done in u32 instead of F + *state.pc = (F::from_canonical_u32(*state.pc) + b).as_canonical_u32(); + } else if opcode == NativeRangeCheckOpcode::RANGE_CHECK.global_opcode() { + // TODO(ayush): should this not call memory callback? + let [a_val]: [F; 1] = memory_read_native(state.memory, a.as_canonical_u32()); - let matrix = RowMajorMatrix::new(flat_trace, OVERALL_WIDTH); - AirProofInput::simple_no_pis(matrix) + memory_write_native(state.memory, a.as_canonical_u32(), &[a_val]); + + let a_val = a_val.as_canonical_u32(); + let b = instruction.b.as_canonical_u32(); + let c = instruction.c.as_canonical_u32(); + + debug_assert!(!self.debug || b <= 16); + debug_assert!(!self.debug || c <= 14); + + let x = a_val & ((1 << 16) - 1); + if !self.debug && x >= 1 << b { + return Err(ExecutionError::Fail { pc: *state.pc }); + } + let y = a_val >> 16; + if !self.debug && y >= 1 << c { + return Err(ExecutionError::Fail { pc: *state.pc }); + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + Ok(()) } -} -fn record_to_row( - record: JalRangeCheckRecord, - aux_cols_factory: &MemoryAuxColsFactory, - range_checker_chip: &VariableRangeCheckerChip, - slice: &mut [F], - memory: &OfflineMemory, -) { - let a_record = memory.record_by_id(record.a_rw); - let col: &mut JalRangeCheckCols<_> = slice.borrow_mut(); - col.is_jal = F::from_bool(record.is_jal); - col.is_range_check = F::from_bool(!record.is_jal); - col.a_pointer = a_record.pointer; - col.state = ExecutionState { - pc: F::from_canonical_u32(record.state.pc), - timestamp: F::from_canonical_u32(record.state.timestamp), - }; - aux_cols_factory.generate_write_aux(a_record, &mut col.writes_aux); - col.b = F::from_canonical_u32(record.b); - if !record.is_jal { - let a_val = a_record.data_at(0); - let a_val_u32 = a_val.as_canonical_u32(); - let y = a_val_u32 >> 16; - let x = a_val_u32 & ((1 << 16) - 1); - range_checker_chip.add_count(x, record.b as usize); - range_checker_chip.add_count(y, record.c as usize); - col.c = F::from_canonical_u32(record.c as u32); - col.y = F::from_canonical_u32(y); + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } + +pub type JalRangeCheckChip = NewVmChipWrapper; diff --git a/extensions/native/circuit/src/jal/tests.rs b/extensions/native/circuit/src/jal/tests.rs index dd56b73c8f..2494264c22 100644 --- a/extensions/native/circuit/src/jal/tests.rs +++ b/extensions/native/circuit/src/jal/tests.rs @@ -1,6 +1,6 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::{testing::VmChipTestBuilder, ExecutionBridge}; +use openvm_circuit::arch::testing::VmChipTestBuilder; use openvm_instructions::{ instruction::Instruction, program::{DEFAULT_PC_STEP, PC_BITS}, @@ -16,9 +16,25 @@ use openvm_stark_backend::{ use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use crate::{jal::JalRangeCheckCols, JalRangeCheckChip}; +use super::{JalRangeCheckAir, JalRangeCheckStep}; +use crate::jal::{JalRangeCheckChip, JalRangeCheckCols}; + +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip(tester: &VmChipTestBuilder) -> JalRangeCheckChip { + JalRangeCheckChip::::new( + JalRangeCheckAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + tester.range_checker().bus(), + ), + JalRangeCheckStep::new(tester.range_checker().clone()), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut JalRangeCheckChip, @@ -61,7 +77,7 @@ fn set_and_execute_range_check( for RangeCheckTestCase { val, x_bit, y_bit } in test_cases { let d = 4usize; - tester.write_cell(d, a, F::from_canonical_u32(val)); + tester.write(d, a, [F::from_canonical_u32(val)]); tester.execute_with_pc( chip, &Instruction::from_usize( @@ -73,19 +89,12 @@ fn set_and_execute_range_check( } } -fn setup() -> (StdRng, VmChipTestBuilder, JalRangeCheckChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - let execution_bridge = ExecutionBridge::new(tester.execution_bus(), tester.program_bus()); - let offline_memory = tester.offline_memory_mutex_arc(); - let range_checker = tester.range_checker(); - let chip = JalRangeCheckChip::::new(execution_bridge, offline_memory, range_checker); - (rng, tester, chip) -} - #[test] fn rand_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + let num_tests: usize = 100; for _ in 0..num_tests { set_and_execute(&mut tester, &mut chip, &mut rng, None, None); @@ -97,7 +106,10 @@ fn rand_jal_test() { #[test] fn rand_range_check_test() { - let (mut rng, mut tester, mut chip) = setup(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + let f = |x: u32, y: u32| RangeCheckTestCase { val: x + y * (1 << 16), x_bit: 32 - x.leading_zeros(), @@ -129,8 +141,11 @@ fn rand_range_check_test() { #[test] fn negative_range_check_test() { { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + chip.step.set_debug(); + set_and_execute_range_check( &mut tester, &mut chip, @@ -147,8 +162,11 @@ fn negative_range_check_test() { assert!(result.is_err()); } { - let (mut rng, mut tester, chip) = setup(); - let mut chip = chip.with_debug(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + chip.step.set_debug(); + set_and_execute_range_check( &mut tester, &mut chip, @@ -168,7 +186,10 @@ fn negative_range_check_test() { #[test] fn negative_jal_test() { - let (mut rng, mut tester, mut chip) = setup(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + set_and_execute(&mut tester, &mut chip, &mut rng, None, None); let tester = tester.build(); diff --git a/extensions/native/circuit/src/lib.rs b/extensions/native/circuit/src/lib.rs index 46c6bc890f..4e16d00233 100644 --- a/extensions/native/circuit/src/lib.rs +++ b/extensions/native/circuit/src/lib.rs @@ -13,10 +13,10 @@ pub use branch_eq::*; pub use castf::*; pub use field_arithmetic::*; pub use field_extension::*; -pub use fri::*; +// pub use fri::*; pub use jal::*; pub use loadstore::*; -pub use poseidon2::*; +// pub use poseidon2::*; mod extension; pub use extension::*; diff --git a/extensions/native/circuit/src/loadstore/core.rs b/extensions/native/circuit/src/loadstore/core.rs index 60c7bbdbdb..9d61fa95fb 100644 --- a/extensions/native/circuit/src/loadstore/core.rs +++ b/extensions/native/circuit/src/loadstore/core.rs @@ -4,12 +4,20 @@ use std::{ sync::{Arc, Mutex, OnceLock}, }; -use openvm_circuit::arch::{ - instructions::LocalOpcode, AdapterAirContext, AdapterRuntimeContext, ExecutionError, Result, - Streams, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + instructions::LocalOpcode, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ExecutionError, Result, + StepExecutorE1, Streams, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::instruction::Instruction; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP}; use openvm_native_compiler::NativeLoadStoreOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -44,7 +52,7 @@ pub struct NativeLoadStoreCoreRecord { pub data: [F; NUM_CELLS], } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, derive_new::new)] pub struct NativeLoadStoreCoreAir { pub offset: usize, } @@ -114,15 +122,23 @@ where } #[derive(Debug)] -pub struct NativeLoadStoreCoreChip { - pub air: NativeLoadStoreCoreAir, +pub struct NativeLoadStoreCoreStep +where + F: Field, +{ + adapter: A, + offset: usize, pub streams: OnceLock>>>, } -impl NativeLoadStoreCoreChip { - pub fn new(offset: usize) -> Self { +impl NativeLoadStoreCoreStep +where + F: Field, +{ + pub fn new(adapter: A, offset: usize) -> Self { Self { - air: NativeLoadStoreCoreAir:: { offset }, + adapter, + offset, streams: OnceLock::new(), } } @@ -131,69 +147,129 @@ impl NativeLoadStoreCoreChip { } } -impl Default for NativeLoadStoreCoreChip { - fn default() -> Self { - Self::new(NativeLoadStoreOpcode::CLASS_OFFSET) - } -} - -impl, const NUM_CELLS: usize> VmCoreChip - for NativeLoadStoreCoreChip +impl TraceStep + for NativeLoadStoreCoreStep where - I::Reads: Into<(F, [F; NUM_CELLS])>, - I::Writes: From<[F; NUM_CELLS]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (F, [F; NUM_CELLS]), + WriteData = [F; NUM_CELLS], + TraceContext<'a> = F, + >, { - type Record = NativeLoadStoreCoreRecord; - type Air = NativeLoadStoreCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + NativeLoadStoreOpcode::from_usize(opcode - self.offset) + ) + } - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, .. } = *instruction; - let local_opcode = - NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (pointer_read, data_read) = reads.into(); + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, .. } = instruction; + + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let (pointer_read, data_read) = self.adapter.read(state.memory, instruction, adapter_row); let data = if local_opcode == NativeLoadStoreOpcode::HINT_STOREW { let mut streams = self.streams.get().unwrap().lock().unwrap(); if streams.hint_stream.len() < NUM_CELLS { - return Err(ExecutionError::HintOutOfBounds { pc: from_pc }); + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); } array::from_fn(|_| streams.hint_stream.pop_front().unwrap()) } else { data_read }; - let output = AdapterRuntimeContext::without_pc(data); - let record = NativeLoadStoreCoreRecord { - opcode: NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), - pointer_read, - data, - }; - Ok((output, record)) + self.adapter + .write(state.memory, instruction, adapter_row, &data); + + let core_row: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + + core_row.pointer_read = pointer_read; + core_row.data = data; + core_row.is_loadw = F::from_bool(local_opcode == NativeLoadStoreOpcode::LOADW); + core_row.is_storew = F::from_bool(local_opcode == NativeLoadStoreOpcode::STOREW); + core_row.is_hint_storew = F::from_bool(local_opcode == NativeLoadStoreOpcode::HINT_STOREW); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - NativeLoadStoreOpcode::from_usize(opcode - self.air.offset) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + let core_row: &mut NativeLoadStoreCoreCols = core_row.borrow_mut(); + self.adapter + .fill_trace_row(mem_helper, core_row.is_hint_storew, adapter_row); } +} + +impl StepExecutorE1 for NativeLoadStoreCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let cols: &mut NativeLoadStoreCoreCols<_, NUM_CELLS> = row_slice.borrow_mut(); - cols.is_loadw = F::from_bool(record.opcode == NativeLoadStoreOpcode::LOADW); - cols.is_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::STOREW); - cols.is_hint_storew = F::from_bool(record.opcode == NativeLoadStoreOpcode::HINT_STOREW); + // Get the local opcode for this instruction + let local_opcode = NativeLoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - cols.pointer_read = record.pointer_read; - cols.data = record.data; + let (_, data_read) = self.adapter.read(state, instruction); + + let data = if local_opcode == NativeLoadStoreOpcode::HINT_STOREW { + let mut streams = self.streams.get().unwrap().lock().unwrap(); + if streams.hint_stream.len() < NUM_CELLS { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); + } + array::from_fn(|_| streams.hint_stream.pop_front().unwrap()) + } else { + data_read + }; + + self.adapter.write(state, instruction, &data); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } diff --git a/extensions/native/circuit/src/loadstore/mod.rs b/extensions/native/circuit/src/loadstore/mod.rs index 3dd51113a9..e9dadaa038 100644 --- a/extensions/native/circuit/src/loadstore/mod.rs +++ b/extensions/native/circuit/src/loadstore/mod.rs @@ -1,4 +1,4 @@ -use openvm_circuit::arch::{VmAirWrapper, VmChipWrapper}; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; #[cfg(test)] mod tests; @@ -7,13 +7,12 @@ mod core; pub use core::*; use super::adapters::loadstore_native_adapter::{ - NativeLoadStoreAdapterAir, NativeLoadStoreAdapterChip, + NativeLoadStoreAdapterAir, NativeLoadStoreAdapterStep, }; pub type NativeLoadStoreAir = VmAirWrapper, NativeLoadStoreCoreAir>; -pub type NativeLoadStoreChip = VmChipWrapper< - F, - NativeLoadStoreAdapterChip, - NativeLoadStoreCoreChip, ->; +pub type NativeLoadStoreStep = + NativeLoadStoreCoreStep, F, NUM_CELLS>; +pub type NativeLoadStoreChip = + NewVmChipWrapper, NativeLoadStoreStep>; diff --git a/extensions/native/circuit/src/loadstore/tests.rs b/extensions/native/circuit/src/loadstore/tests.rs index cd653c2fc0..af906cdfb9 100644 --- a/extensions/native/circuit/src/loadstore/tests.rs +++ b/extensions/native/circuit/src/loadstore/tests.rs @@ -1,17 +1,18 @@ use std::sync::{Arc, Mutex}; -use openvm_circuit::arch::{testing::VmChipTestBuilder, Streams}; +use openvm_circuit::arch::{testing::VmChipTestBuilder, Streams, VmAirWrapper}; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_native_compiler::NativeLoadStoreOpcode::{self, *}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{ - super::adapters::loadstore_native_adapter::NativeLoadStoreAdapterChip, NativeLoadStoreChip, - NativeLoadStoreCoreChip, +use super::{NativeLoadStoreChip, NativeLoadStoreCoreAir, NativeLoadStoreCoreStep}; +use crate::adapters::loadstore_native_adapter::{ + NativeLoadStoreAdapterAir, NativeLoadStoreAdapterStep, }; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; #[derive(Debug)] @@ -28,20 +29,22 @@ struct TestData { is_hint: bool, } -fn setup() -> (StdRng, VmChipTestBuilder, NativeLoadStoreChip) { - let rng = create_seeded_rng(); - let tester = VmChipTestBuilder::default(); - - let adapter = NativeLoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - NativeLoadStoreOpcode::CLASS_OFFSET, +fn create_test_chip(tester: &VmChipTestBuilder) -> NativeLoadStoreChip { + let mut chip = NativeLoadStoreChip::::new( + VmAirWrapper::new( + NativeLoadStoreAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + NativeLoadStoreCoreAir::new(NativeLoadStoreOpcode::CLASS_OFFSET), + ), + NativeLoadStoreCoreStep::new( + NativeLoadStoreAdapterStep::new(NativeLoadStoreOpcode::CLASS_OFFSET), + NativeLoadStoreOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - let mut inner = NativeLoadStoreCoreChip::new(NativeLoadStoreOpcode::CLASS_OFFSET); - inner.set_streams(Arc::new(Mutex::new(Streams::default()))); - let chip = NativeLoadStoreChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - (rng, tester, chip) + chip.step + .set_streams(Arc::new(Mutex::new(Streams::default()))); + chip } fn gen_test_data(rng: &mut StdRng, opcode: NativeLoadStoreOpcode) -> TestData { @@ -102,7 +105,7 @@ fn set_values( } if data.is_hint { for _ in 0..data.e.as_canonical_u32() { - chip.core + chip.step .streams .get() .unwrap() @@ -164,7 +167,11 @@ fn set_and_execute( #[test] fn rand_native_loadstore_test() { setup_tracing(); - let (mut rng, mut tester, mut chip) = setup(); + + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&tester); + for _ in 0..20 { set_and_execute(&mut tester, &mut chip, &mut rng, STOREW); set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); diff --git a/extensions/native/circuit/src/poseidon2/air.rs b/extensions/native/circuit/src/poseidon2/air.rs index 5ed28abd60..9d24966fe7 100644 --- a/extensions/native/circuit/src/poseidon2/air.rs +++ b/extensions/native/circuit/src/poseidon2/air.rs @@ -20,15 +20,13 @@ use openvm_stark_backend::{ rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; -use crate::{ +use crate::poseidon2::{ chip::{NUM_INITIAL_READS, NUM_SIMPLE_ACCESSES}, - poseidon2::{ - columns::{ - InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, - TopLevelSpecificCols, - }, - CHUNK, + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, }, + CHUNK, }; #[derive(Clone, Debug)] diff --git a/extensions/native/circuit/src/poseidon2/chip.rs b/extensions/native/circuit/src/poseidon2/chip.rs index 426b089a9c..627a5b6c67 100644 --- a/extensions/native/circuit/src/poseidon2/chip.rs +++ b/extensions/native/circuit/src/poseidon2/chip.rs @@ -1,10 +1,17 @@ -use std::sync::{Arc, Mutex}; +use std::{ + borrow::{Borrow, BorrowMut}, + sync::{Arc, Mutex}, +}; use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, Streams, SystemPort, + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + StepExecutorE1, Streams, TraceStep, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, }, - system::memory::{MemoryController, OfflineMemory, RecordId}, }; use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_native_compiler::{ @@ -12,151 +19,40 @@ use openvm_native_compiler::{ Poseidon2Opcode::{COMP_POS2, PERM_POS2}, VerifyBatchOpcode::VERIFY_BATCH, }; -use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir, Poseidon2SubChip}; +use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubChip}; use openvm_stark_backend::{ + p3_air::BaseAir, p3_field::{Field, PrimeField32}, p3_maybe_rayon::prelude::{ParallelIterator, ParallelSlice}, }; -use serde::{Deserialize, Serialize}; -use crate::poseidon2::{ - air::{NativePoseidon2Air, VerifyBatchBus}, - CHUNK, +use crate::{ + adapters::{ + memory_read_native, memory_write_native, tracing_read_native, tracing_write_native, + }, + poseidon2::{ + columns::{ + InsideRowSpecificCols, NativePoseidon2Cols, SimplePoseidonSpecificCols, + TopLevelSpecificCols, + }, + CHUNK, + }, }; -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct VerifyBatchRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub dim_base_pointer: F, - pub opened_base_pointer: F, - pub opened_length: usize, - pub index_base_pointer: F, - pub commit_pointer: F, - - pub dim_base_pointer_read: RecordId, - pub opened_base_pointer_read: RecordId, - pub opened_length_read: RecordId, - pub index_base_pointer_read: RecordId, - pub commit_pointer_read: RecordId, - - pub commit_read: RecordId, - pub initial_log_height: usize, - pub top_level: Vec>, -} - -impl VerifyBatchRecord { - pub fn opened_element_size_inv(&self) -> F { - self.instruction.g - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct TopLevelRecord { - // must be present in first record - pub incorporate_row: Option>, - // must be present in all bust last record - pub incorporate_sibling: Option>, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateSiblingRecord { - pub read_sibling_is_on_right: RecordId, - pub sibling_is_on_right: bool, - pub p2_input: [F; 2 * CHUNK], -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct IncorporateRowRecord { - pub chunks: Vec>, - pub initial_opened_index: usize, - pub final_opened_index: usize, - pub initial_height_read: RecordId, - pub final_height_read: RecordId, - pub p2_input: [F; 2 * CHUNK], -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct InsideRowRecord { - pub cells: Vec, - pub p2_input: [F; 2 * CHUNK], -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct CellRecord { - pub read: RecordId, - pub opened_index: usize, - pub read_row_pointer_and_length: Option, - pub row_pointer: usize, - pub row_end: usize, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct SimplePoseidonRecord { - pub from_state: ExecutionState, - pub instruction: Instruction, - - pub read_input_pointer_1: RecordId, - pub read_input_pointer_2: Option, - pub read_output_pointer: RecordId, - pub read_data_1: RecordId, - pub read_data_2: RecordId, - pub write_data_1: RecordId, - pub write_data_2: Option, - - pub input_pointer_1: F, - pub input_pointer_2: F, - pub output_pointer: F, - pub p2_input: [F; 2 * CHUNK], -} - -#[derive(Debug, Clone, Serialize, Deserialize, Default)] -#[serde(bound = "F: Field")] -pub struct NativePoseidon2RecordSet { - pub verify_batch_records: Vec>, - pub simple_permute_records: Vec>, -} - -pub struct NativePoseidon2Chip { - pub(super) air: NativePoseidon2Air, - pub record_set: NativePoseidon2RecordSet, - pub height: usize, - pub(super) offline_memory: Arc>>, +pub struct NativePoseidon2Step { + // pre-computed Poseidon2 sub cols for dummy rows. + empty_poseidon2_sub_cols: Vec, pub(super) subchip: Poseidon2SubChip, pub(super) streams: Arc>>, } -impl NativePoseidon2Chip { - pub fn new( - port: SystemPort, - offline_memory: Arc>>, - poseidon2_config: Poseidon2Config, - verify_batch_bus: VerifyBatchBus, - streams: Arc>>, - ) -> Self { - let air = NativePoseidon2Air { - execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), - memory_bridge: port.memory_bridge, - internal_bus: verify_batch_bus, - subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), - address_space: F::from_canonical_u32(AS::Native as u32), - }; +impl NativePoseidon2Step { + pub fn new(poseidon2_config: Poseidon2Config, streams: Arc>>) -> Self { + let subchip = Poseidon2SubChip::new(poseidon2_config.constants); + let empty_poseidon2_sub_cols = subchip.generate_trace(vec![[F::ZERO; CHUNK * 2]]).values; Self { - record_set: Default::default(), - air, - height: 0, - offline_memory, - subchip: Poseidon2SubChip::new(poseidon2_config.constants), + empty_poseidon2_sub_cols, + subchip, streams, } } @@ -172,18 +68,26 @@ impl NativePoseidon2Chip InstructionExecutor - for NativePoseidon2Chip +impl TraceStep + for NativePoseidon2Step { fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> openvm_circuit::arch::Result<()> { + debug_assert_eq!(width, NativePoseidon2Cols::::width()); + let init_timestamp_u32 = state.memory.timestamp; if instruction.opcode == PERM_POS2.global_opcode() || instruction.opcode == COMP_POS2.global_opcode() { + let cols: &mut NativePoseidon2Cols = + trace[*trace_offset..*trace_offset + width].borrow_mut(); + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); let &Instruction { a: output_register, b: input_register_1, @@ -192,22 +96,45 @@ impl InstructionExecutor e: data_address_space, .. } = instruction; + debug_assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + debug_assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + let [output_pointer]: [F; 1] = tracing_read_native( + state.memory, + output_register.as_canonical_u32(), + simple_cols.read_output_pointer.as_mut(), + ); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + let [input_pointer_1]: [F; 1] = tracing_read_native( + state.memory, + input_register_1.as_canonical_u32(), + simple_cols.read_input_pointer_1.as_mut(), + ); + let input_pointer_1_u32 = input_pointer_1.as_canonical_u32(); + let [input_pointer_2]: [F; 1] = if instruction.opcode == PERM_POS2.global_opcode() { + state.memory.increment_timestamp(); + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + tracing_read_native( + state.memory, + input_register_2.as_canonical_u32(), + simple_cols.read_input_pointer_2.as_mut(), + ) + }; + let input_pointer_2_u32 = input_pointer_2.as_canonical_u32(); + let data_1: [F; CHUNK] = tracing_read_native( + state.memory, + input_pointer_1_u32, + simple_cols.read_data_1.as_mut(), + ); + let data_2: [F; CHUNK] = tracing_read_native( + state.memory, + input_pointer_2_u32, + simple_cols.read_data_2.as_mut(), + ); - let (read_output_pointer, output_pointer) = - memory.read_cell(register_address_space, output_register); - let (read_input_pointer_1, input_pointer_1) = - memory.read_cell(register_address_space, input_register_1); - let (read_input_pointer_2, input_pointer_2) = - if instruction.opcode == PERM_POS2.global_opcode() { - memory.increment_timestamp(); - (None, input_pointer_1 + F::from_canonical_usize(CHUNK)) - } else { - let (read_input_pointer_2, input_pointer_2) = - memory.read_cell(register_address_space, input_register_2); - (Some(read_input_pointer_2), input_pointer_2) - }; - let (read_data_1, data_1) = memory.read::(data_address_space, input_pointer_1); - let (read_data_2, data_2) = memory.read::(data_address_space, input_pointer_2); let p2_input = std::array::from_fn(|i| { if i < CHUNK { data_1[i] @@ -216,50 +143,54 @@ impl InstructionExecutor } }); let output = self.subchip.permute(p2_input); - let (write_data_1, _) = memory.write::( - data_address_space, - output_pointer, - std::array::from_fn(|i| output[i]), + tracing_write_native( + state.memory, + output_pointer_u32, + &std::array::from_fn(|i| output[i]), + &mut simple_cols.write_data_1, ); - let write_data_2 = if instruction.opcode == PERM_POS2.global_opcode() { - Some( - memory - .write::( - data_address_space, - output_pointer + F::from_canonical_usize(CHUNK), - std::array::from_fn(|i| output[CHUNK + i]), - ) - .0, - ) + if instruction.opcode == PERM_POS2.global_opcode() { + tracing_write_native( + state.memory, + output_pointer_u32 + CHUNK as u32, + &std::array::from_fn(|i| output[i + CHUNK]), + &mut simple_cols.write_data_2, + ); } else { - memory.increment_timestamp(); - None - }; - - assert_eq!( - memory.timestamp(), - from_state.timestamp + NUM_SIMPLE_ACCESSES + state.memory.increment_timestamp(); + } + debug_assert_eq!( + state.memory.timestamp, + init_timestamp_u32 + NUM_SIMPLE_ACCESSES ); - - self.record_set - .simple_permute_records - .push(SimplePoseidonRecord { - from_state, - instruction: instruction.clone(), - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - }); - self.height += 1; + cols.incorporate_row = F::ZERO; + cols.incorporate_sibling = F::ZERO; + cols.inside_row = F::ZERO; + cols.simple = F::ONE; + cols.end_inside_row = F::ZERO; + cols.end_top_level = F::ZERO; + cols.is_exhausted = [F::ZERO; CHUNK - 1]; + cols.start_timestamp = F::from_canonical_u32(init_timestamp_u32); + + cols.inner.inputs = p2_input; + simple_cols.pc = F::from_canonical_u32(*state.pc); + simple_cols.is_compress = F::from_bool(instruction.opcode == COMP_POS2.global_opcode()); + simple_cols.output_register = output_register; + simple_cols.input_register_1 = input_register_1; + simple_cols.input_register_2 = input_register_2; + simple_cols.output_pointer = output_pointer; + simple_cols.input_pointer_1 = input_pointer_1; + simple_cols.input_pointer_2 = input_pointer_2; + + *trace_offset += width; } else if instruction.opcode == VERIFY_BATCH.global_opcode() { + let init_timestamp = F::from_canonical_u32(init_timestamp_u32); + let mut col_buffer = + vec![F::ZERO; NativePoseidon2Cols::::width()]; + let last_top_level_cols: &mut NativePoseidon2Cols = + col_buffer.as_mut_slice().borrow_mut(); + let ltl_specific_cols: &mut TopLevelSpecificCols = + last_top_level_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); let &Instruction { a: dim_register, b: opened_register, @@ -270,35 +201,116 @@ impl InstructionExecutor g: opened_element_size_inv, .. } = instruction; - let address_space = self.air.address_space; // calc inverse fast assuming opened_element_size in {1, 4} let mut opened_element_size = F::ONE; while opened_element_size * opened_element_size_inv != F::ONE { opened_element_size += F::ONE; } - let proof_id = memory.unsafe_read_cell(address_space, proof_id_ptr); - let (dim_base_pointer_read, dim_base_pointer) = - memory.read_cell(address_space, dim_register); - let (opened_base_pointer_read, opened_base_pointer) = - memory.read_cell(address_space, opened_register); - let (opened_length_read, opened_length) = - memory.read_cell(address_space, opened_length_register); - let (index_base_pointer_read, index_base_pointer) = - memory.read_cell(address_space, index_register); - let (commit_pointer_read, commit_pointer) = - memory.read_cell(address_space, commit_register); - let (commit_read, commit) = memory.read(address_space, commit_pointer); + let [proof_id]: [F; 1] = + memory_read_native(state.memory.data(), proof_id_ptr.as_canonical_u32()); + let [dim_base_pointer]: [F; 1] = tracing_read_native( + state.memory, + dim_register.as_canonical_u32(), + ltl_specific_cols.dim_base_pointer_read.as_mut(), + ); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = tracing_read_native( + state.memory, + opened_register.as_canonical_u32(), + ltl_specific_cols.opened_base_pointer_read.as_mut(), + ); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = tracing_read_native( + state.memory, + opened_length_register.as_canonical_u32(), + ltl_specific_cols.opened_length_read.as_mut(), + ); + let [index_base_pointer]: [F; 1] = tracing_read_native( + state.memory, + index_register.as_canonical_u32(), + ltl_specific_cols.index_base_pointer_read.as_mut(), + ); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = tracing_read_native( + state.memory, + commit_register.as_canonical_u32(), + ltl_specific_cols.commit_pointer_read.as_mut(), + ); + let commit = tracing_read_native( + state.memory, + commit_pointer.as_canonical_u32(), + ltl_specific_cols.commit_read.as_mut(), + ); let opened_length = opened_length.as_canonical_u32() as usize; + let [initial_log_height]: [F; 1] = + memory_read_native(state.memory.data(), dim_base_pointer_u32); + let initial_log_height_u32 = initial_log_height.as_canonical_u32(); + let mut log_height = initial_log_height_u32 as i32; + + // Number of non-inside rows, this is used to compute the offset of the inside row + // section. + let num_non_inside_rows = { + let mut log_height = initial_log_height_u32 as i32; + let mut opened_index = 0; + let mut num_non_inside_rows = 0; + while log_height >= 0 { + if opened_index < opened_length + && memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) + { + let mut row_pointer = 0; + let mut row_end = 0; + let mut is_first_in_segment = true; + + loop { + let mut cell_idx = 0; + for _ in 0..CHUNK { + if is_first_in_segment || row_pointer == row_end { + if is_first_in_segment { + is_first_in_segment = false; + } else { + opened_index += 1; + if opened_index == opened_length + || memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) + { + break; + } + } + let [new_row_pointer, row_len]: [F; 2] = memory_read_native( + state.memory.data(), + opened_base_pointer_u32 + 2 * opened_index as u32, + ); + row_pointer = new_row_pointer.as_canonical_u32() as usize; + row_end = row_pointer + + (opened_element_size * row_len).as_canonical_u32() + as usize; + } + cell_idx += 1; + row_pointer += 1; + } - let initial_log_height = memory - .unsafe_read_cell(address_space, dim_base_pointer) - .as_canonical_u32(); - let mut log_height = initial_log_height as i32; - let mut sibling_index = 0; + if cell_idx < CHUNK { + break; + } + } + num_non_inside_rows += 1; + } + if log_height != 0 { + num_non_inside_rows += 1; + } + log_height -= 1; + } + num_non_inside_rows + }; + let mut proof_index = 0; let mut opened_index = 0; - let mut top_level = vec![]; let mut root = [F::ZERO; CHUNK]; let sibling_proof: Vec<[F; CHUNK]> = { @@ -310,18 +322,21 @@ impl InstructionExecutor .collect() }; + let mut inside_row_offset = *trace_offset + num_non_inside_rows * width; + let mut non_inside_row_offset = *trace_offset; + while log_height >= 0 { - let incorporate_row = if opened_index < opened_length - && memory.unsafe_read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(opened_index), - ) == F::from_canonical_u32(log_height as u32) + if opened_index < opened_length + && memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) { + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + let incorporate_start_timestamp = state.memory.timestamp; let initial_opened_index = opened_index; - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } - let mut chunks = vec![]; let mut row_pointer = 0; let mut row_end = 0; @@ -332,166 +347,365 @@ impl InstructionExecutor let mut is_first_in_segment = true; loop { - let mut cells = vec![]; + let inside_cols: &mut NativePoseidon2Cols = + trace[inside_row_offset..inside_row_offset + width].borrow_mut(); + let inside_specific_cols: &mut InsideRowSpecificCols = inside_cols + .specific[..InsideRowSpecificCols::::width()] + .borrow_mut(); + let start_timestamp_u32 = state.memory.timestamp; + + let mut cells_idx = 0; for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { - let read_row_pointer_and_length = if is_first_in_segment - || row_pointer == row_end - { + let cell_cols = &mut inside_specific_cols.cells[cells_idx]; + if is_first_in_segment || row_pointer == row_end { if is_first_in_segment { is_first_in_segment = false; } else { opened_index += 1; if opened_index == opened_length - || memory.unsafe_read_cell( - address_space, - dim_base_pointer - + F::from_canonical_usize(opened_index), - ) != F::from_canonical_u32(log_height as u32) + || memory_read_native::( + state.memory.data(), + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) { break; } } - let (result, [new_row_pointer, row_len]) = memory.read( - address_space, - opened_base_pointer + F::from_canonical_usize(2 * opened_index), + let [new_row_pointer, row_len]: [F; 2] = tracing_read_native( + state.memory, + opened_base_pointer_u32 + 2 * opened_index as u32, + cell_cols.read_row_pointer_and_length.as_mut(), ); row_pointer = new_row_pointer.as_canonical_u32() as usize; row_end = row_pointer + (opened_element_size * row_len).as_canonical_u32() as usize; - Some(result) + cell_cols.is_first_in_row = F::ONE; } else { - memory.increment_timestamp(); - None - }; - let (read, value) = memory - .read_cell(address_space, F::from_canonical_usize(row_pointer)); - cells.push(CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - }); + state.memory.increment_timestamp(); + } + let [value]: [F; 1] = tracing_read_native( + state.memory, + row_pointer as u32, + cell_cols.read.as_mut(), + ); + + cell_cols.opened_index = F::from_canonical_usize(opened_index); + cell_cols.row_pointer = F::from_canonical_usize(row_pointer); + cell_cols.row_end = F::from_canonical_usize(row_end); + *chunk_elem = value; row_pointer += 1; + cells_idx += 1; } - if cells.is_empty() { + if cells_idx == 0 { break; } - let cells_len = cells.len(); - chunks.push(InsideRowRecord { - cells, - p2_input: rolling_hash, - }); - self.height += 1; + let p2_input = rolling_hash; prev_rolling_hash = Some(rolling_hash); self.subchip.permute_mut(&mut rolling_hash); - if cells_len < CHUNK { - for _ in 0..CHUNK - cells_len { - memory.increment_timestamp(); - memory.increment_timestamp(); + if cells_idx < CHUNK { + state + .memory + .increment_timestamp_by(2 * (CHUNK - cells_idx) as u32); + } + + inside_row_offset += width; + inside_cols.inner.inputs = p2_input; + inside_cols.incorporate_row = F::ZERO; + inside_cols.incorporate_sibling = F::ZERO; + inside_cols.inside_row = F::ONE; + inside_cols.simple = F::ZERO; + // `end_inside_row` of the last row will be set to 1 after this loop. + inside_cols.end_inside_row = F::ZERO; + inside_cols.end_top_level = F::ZERO; + inside_cols.opened_element_size_inv = opened_element_size_inv; + inside_cols.very_first_timestamp = + F::from_canonical_u32(incorporate_start_timestamp); + inside_cols.start_timestamp = F::from_canonical_u32(start_timestamp_u32); + + inside_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + inside_cols.opened_base_pointer = opened_base_pointer; + if cells_idx < CHUNK { + let exhausted_opened_idx = F::from_canonical_usize(opened_index - 1); + for exhausted_idx in cells_idx..CHUNK { + inside_cols.is_exhausted[exhausted_idx - 1] = F::ONE; + inside_specific_cols.cells[exhausted_idx].opened_index = + exhausted_opened_idx; } break; } } + { + let inside_cols: &mut NativePoseidon2Cols = + trace[inside_row_offset - width..inside_row_offset].borrow_mut(); + inside_cols.end_inside_row = F::ONE; + } + + let incorporate_cols: &mut NativePoseidon2Cols = + trace[non_inside_row_offset..non_inside_row_offset + width].borrow_mut(); + let top_level_specific_cols: &mut TopLevelSpecificCols = incorporate_cols + .specific[..TopLevelSpecificCols::::width()] + .borrow_mut(); + let final_opened_index = opened_index - 1; - let (initial_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(initial_opened_index), + let [height_check]: [F; 1] = tracing_read_native( + state.memory, + dim_base_pointer_u32 + initial_opened_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); - let (final_height_read, height_check) = memory.read_cell( - address_space, - dim_base_pointer + F::from_canonical_usize(final_opened_index), + let final_height_read_timestamp = state.memory.timestamp; + let [height_check]: [F; 1] = tracing_read_native( + state.memory, + dim_base_pointer_u32 + final_opened_index as u32, + top_level_specific_cols.read_final_height.as_mut(), ); assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); - - let (p2_input, new_root) = if log_height as u32 == initial_log_height { + let (p2_input, new_root) = if log_height as u32 == initial_log_height_u32 { (prev_rolling_hash.unwrap(), hash) } else { self.compress(root, hash) }; root = new_root; - self.height += 1; - Some(IncorporateRowRecord { - chunks, - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - }) - } else { - None - }; - - let incorporate_sibling = if log_height == 0 { - None - } else { - for _ in 0..NUM_INITIAL_READS { - memory.increment_timestamp(); - } + non_inside_row_offset += width; + + incorporate_cols.incorporate_row = F::ONE; + incorporate_cols.incorporate_sibling = F::ZERO; + incorporate_cols.inside_row = F::ZERO; + incorporate_cols.simple = F::ZERO; + incorporate_cols.end_inside_row = F::ZERO; + incorporate_cols.end_top_level = F::ZERO; + incorporate_cols.start_top_level = F::from_bool(proof_index == 0); + incorporate_cols.opened_element_size_inv = opened_element_size_inv; + incorporate_cols.very_first_timestamp = init_timestamp; + incorporate_cols.start_timestamp = F::from_canonical_u32( + incorporate_start_timestamp - NUM_INITIAL_READS as u32, + ); + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(final_height_read_timestamp + 1); + + incorporate_cols.inner.inputs = p2_input; + incorporate_cols.initial_opened_index = + F::from_canonical_usize(initial_opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(final_opened_index); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + incorporate_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + } - let (read_sibling_is_on_right, sibling_is_on_right) = memory.read_cell( - address_space, - index_base_pointer + F::from_canonical_usize(sibling_index), + if log_height != 0 { + let row_start_timestamp = state.memory.timestamp; + state + .memory + .increment_timestamp_by(NUM_INITIAL_READS as u32); + + let sibling_cols: &mut NativePoseidon2Cols = + trace[non_inside_row_offset..non_inside_row_offset + width].borrow_mut(); + let top_level_specific_cols: &mut TopLevelSpecificCols = + sibling_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + + let read_sibling_is_on_right_timestamp = state.memory.timestamp; + let [sibling_is_on_right]: [F; 1] = tracing_read_native( + state.memory, + index_base_pointer_u32 + proof_index as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), ); - let sibling_is_on_right = sibling_is_on_right == F::ONE; - let sibling = sibling_proof[sibling_index]; - let (p2_input, new_root) = if sibling_is_on_right { + let sibling = sibling_proof[proof_index]; + let (p2_input, new_root) = if sibling_is_on_right == F::ONE { self.compress(sibling, root) } else { self.compress(root, sibling) }; root = new_root; - self.height += 1; - Some(IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - }) + non_inside_row_offset += width; + + sibling_cols.inner.inputs = p2_input; + + sibling_cols.incorporate_row = F::ZERO; + sibling_cols.incorporate_sibling = F::ONE; + sibling_cols.inside_row = F::ZERO; + sibling_cols.simple = F::ZERO; + sibling_cols.end_inside_row = F::ZERO; + sibling_cols.end_top_level = F::ZERO; + sibling_cols.start_top_level = F::ZERO; + sibling_cols.opened_element_size_inv = opened_element_size_inv; + sibling_cols.very_first_timestamp = init_timestamp; + sibling_cols.start_timestamp = F::from_canonical_u32(row_start_timestamp); + + top_level_specific_cols.end_timestamp = + F::from_canonical_u32(read_sibling_is_on_right_timestamp + 1); + sibling_cols.initial_opened_index = F::from_canonical_usize(opened_index); + top_level_specific_cols.final_opened_index = + F::from_canonical_usize(opened_index - 1); + top_level_specific_cols.log_height = F::from_canonical_u32(log_height as u32); + top_level_specific_cols.opened_length = F::from_canonical_usize(opened_length); + top_level_specific_cols.dim_base_pointer = dim_base_pointer; + sibling_cols.opened_base_pointer = opened_base_pointer; + top_level_specific_cols.index_base_pointer = index_base_pointer; + + top_level_specific_cols.proof_index = F::from_canonical_usize(proof_index); + top_level_specific_cols.sibling_is_on_right = sibling_is_on_right; }; - top_level.push(TopLevelRecord { - incorporate_row, - incorporate_sibling, - }); - log_height -= 1; - sibling_index += 1; + proof_index += 1; } - + let ltl_trace_cols: &mut NativePoseidon2Cols = + trace[non_inside_row_offset - width..non_inside_row_offset].borrow_mut(); + let ltl_trace_specific_cols: &mut TopLevelSpecificCols = + ltl_trace_cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + ltl_trace_cols.end_top_level = F::ONE; + ltl_trace_specific_cols.pc = F::from_canonical_u32(*state.pc); + ltl_trace_specific_cols.dim_register = dim_register; + ltl_trace_specific_cols.opened_register = opened_register; + ltl_trace_specific_cols.opened_length_register = opened_length_register; + ltl_trace_specific_cols.proof_id = proof_id_ptr; + ltl_trace_specific_cols.index_register = index_register; + ltl_trace_specific_cols.commit_register = commit_register; + ltl_trace_specific_cols.commit_pointer = commit_pointer; + ltl_trace_specific_cols.dim_base_pointer_read = ltl_specific_cols.dim_base_pointer_read; + ltl_trace_specific_cols.opened_base_pointer_read = + ltl_specific_cols.opened_base_pointer_read; + ltl_trace_specific_cols.opened_length_read = ltl_specific_cols.opened_length_read; + ltl_trace_specific_cols.index_base_pointer_read = + ltl_specific_cols.index_base_pointer_read; + ltl_trace_specific_cols.commit_pointer_read = ltl_specific_cols.commit_pointer_read; + ltl_trace_specific_cols.commit_read = ltl_specific_cols.commit_read; + + *trace_offset = inside_row_offset; assert_eq!(commit, root); - self.record_set - .verify_batch_records - .push(VerifyBatchRecord { - from_state, - instruction: instruction.clone(), - dim_base_pointer, - opened_base_pointer, - opened_length, - index_base_pointer, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - initial_log_height: initial_log_height as usize, - top_level, - }); } else { unreachable!() } - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) + + *state.pc += DEFAULT_PC_STEP; + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let inner_cols = { + let cols: &NativePoseidon2Cols = row_slice.as_ref().borrow(); + &self.subchip.generate_trace(vec![cols.inner.inputs]).values + }; + let inner_width = self.subchip.air.width(); + row_slice[..inner_width].copy_from_slice(&inner_cols); + let cols: &mut NativePoseidon2Cols = row_slice.borrow_mut(); + + // Simple poseidon2 row + if cols.simple.is_one() { + let simple_cols: &mut SimplePoseidonSpecificCols = + cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + mem_helper.fill_from_prev( + start_timestamp_u32, + simple_cols.read_output_pointer.as_mut(), + ); + mem_helper.fill_from_prev( + start_timestamp_u32 + 1, + simple_cols.read_input_pointer_1.as_mut(), + ); + if simple_cols.is_compress.is_one() { + mem_helper.fill_from_prev( + start_timestamp_u32 + 2, + simple_cols.read_input_pointer_2.as_mut(), + ); + } + mem_helper.fill_from_prev(start_timestamp_u32 + 3, simple_cols.read_data_1.as_mut()); + mem_helper.fill_from_prev(start_timestamp_u32 + 4, simple_cols.read_data_2.as_mut()); + mem_helper.fill_from_prev(start_timestamp_u32 + 5, simple_cols.write_data_1.as_mut()); + if simple_cols.is_compress.is_zero() { + mem_helper + .fill_from_prev(start_timestamp_u32 + 6, simple_cols.write_data_2.as_mut()); + } + } else if cols.inside_row.is_one() { + let inside_row_specific_cols: &mut InsideRowSpecificCols = + cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + for (i, cell) in inside_row_specific_cols.cells.iter_mut().enumerate() { + if i > 0 && cols.is_exhausted[i - 1].is_one() { + break; + } + if cell.is_first_in_row.is_one() { + mem_helper.fill_from_prev( + start_timestamp_u32 + 2 * i as u32, + cell.read_row_pointer_and_length.as_mut(), + ); + } + mem_helper + .fill_from_prev(start_timestamp_u32 + 2 * i as u32 + 1, cell.read.as_mut()); + } + } else { + let top_level_specific_cols: &mut TopLevelSpecificCols = + cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); + let start_timestamp_u32 = cols.start_timestamp.as_canonical_u32(); + if cols.end_top_level.is_one() { + let very_start_timestamp_u32 = cols.very_first_timestamp.as_canonical_u32(); + mem_helper.fill_from_prev( + very_start_timestamp_u32, + top_level_specific_cols.dim_base_pointer_read.as_mut(), + ); + mem_helper.fill_from_prev( + very_start_timestamp_u32 + 1, + top_level_specific_cols.opened_base_pointer_read.as_mut(), + ); + mem_helper.fill_from_prev( + very_start_timestamp_u32 + 2, + top_level_specific_cols.opened_length_read.as_mut(), + ); + mem_helper.fill_from_prev( + very_start_timestamp_u32 + 3, + top_level_specific_cols.index_base_pointer_read.as_mut(), + ); + mem_helper.fill_from_prev( + very_start_timestamp_u32 + 4, + top_level_specific_cols.commit_pointer_read.as_mut(), + ); + mem_helper.fill_from_prev( + very_start_timestamp_u32 + 5, + top_level_specific_cols.commit_read.as_mut(), + ); + } + if cols.incorporate_row.is_one() { + let end_timestamp = top_level_specific_cols.end_timestamp.as_canonical_u32(); + mem_helper.fill_from_prev( + end_timestamp - 2, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + mem_helper.fill_from_prev( + end_timestamp - 1, + top_level_specific_cols.read_final_height.as_mut(), + ); + } else if cols.incorporate_sibling.is_one() { + mem_helper.fill_from_prev( + start_timestamp_u32 + NUM_INITIAL_READS as u32, + top_level_specific_cols + .read_initial_height_or_sibling_is_on_right + .as_mut(), + ); + } else { + unreachable!() + } + } + } + + fn fill_dummy_trace_row(&self, _mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let width = self.subchip.air.width(); + row_slice[..width].copy_from_slice(&self.empty_poseidon2_sub_cols); } fn get_opcode_name(&self, opcode: usize) -> String { @@ -506,3 +720,263 @@ impl InstructionExecutor } } } +impl StepExecutorE1 + for NativePoseidon2Step +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> openvm_circuit::arch::Result<()> + where + Ctx: E1E2ExecutionCtx, + { + self.execute_e1_impl(state, instruction); + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> openvm_circuit::arch::Result<()> { + let height = self.execute_e1_impl(state, instruction); + state.ctx.trace_heights[chip_index] += height as u32; + + Ok(()) + } +} + +impl NativePoseidon2Step { + /// Returns the number of used rows. + fn execute_e1_impl( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> usize + where + Ctx: E1E2ExecutionCtx, + { + let mut height = 0; + if instruction.opcode == PERM_POS2.global_opcode() + || instruction.opcode == COMP_POS2.global_opcode() + { + let &Instruction { + a: output_register, + b: input_register_1, + c: input_register_2, + d: register_address_space, + e: data_address_space, + .. + } = instruction; + debug_assert_eq!( + register_address_space, + F::from_canonical_u32(AS::Native as u32) + ); + debug_assert_eq!(data_address_space, F::from_canonical_u32(AS::Native as u32)); + let [output_pointer]: [F; 1] = + memory_read_native(state.memory, output_register.as_canonical_u32()); + let [input_pointer_1]: [F; 1] = + memory_read_native(state.memory, input_register_1.as_canonical_u32()); + let [input_pointer_2] = if instruction.opcode == PERM_POS2.global_opcode() { + [input_pointer_1 + F::from_canonical_usize(CHUNK)] + } else { + memory_read_native(state.memory, input_register_2.as_canonical_u32()) + }; + let data_1: [F; CHUNK] = + memory_read_native(state.memory, input_pointer_1.as_canonical_u32()); + let data_2: [F; CHUNK] = + memory_read_native(state.memory, input_pointer_2.as_canonical_u32()); + + let p2_input = std::array::from_fn(|i| { + if i < CHUNK { + data_1[i] + } else { + data_2[i - CHUNK] + } + }); + let output = self.subchip.permute(p2_input); + let output_pointer_u32 = output_pointer.as_canonical_u32(); + memory_write_native::( + state.memory, + output_pointer_u32, + &std::array::from_fn(|i| output[i]), + ); + if instruction.opcode == PERM_POS2.global_opcode() { + memory_write_native::( + state.memory, + output_pointer_u32 + CHUNK as u32, + &std::array::from_fn(|i| output[i + CHUNK]), + ); + } + + height += 1; + } else if instruction.opcode == VERIFY_BATCH.global_opcode() { + // TODO: Add a flag `optimistic_execution`. When the flag is true, we trust all inputs + // and skip all input validation computation during E1 execution. + let &Instruction { + a: dim_register, + b: opened_register, + c: opened_length_register, + d: proof_id_ptr, + e: index_register, + f: commit_register, + g: opened_element_size_inv, + .. + } = instruction; + // calc inverse fast assuming opened_element_size in {1, 4} + let mut opened_element_size = F::ONE; + while opened_element_size * opened_element_size_inv != F::ONE { + opened_element_size += F::ONE; + } + + let [proof_id]: [F; 1] = + memory_read_native(state.memory, proof_id_ptr.as_canonical_u32()); + let [dim_base_pointer]: [F; 1] = + memory_read_native(state.memory, dim_register.as_canonical_u32()); + let dim_base_pointer_u32 = dim_base_pointer.as_canonical_u32(); + let [opened_base_pointer]: [F; 1] = + memory_read_native(state.memory, opened_register.as_canonical_u32()); + let opened_base_pointer_u32 = opened_base_pointer.as_canonical_u32(); + let [opened_length]: [F; 1] = + memory_read_native(state.memory, opened_length_register.as_canonical_u32()); + let [index_base_pointer]: [F; 1] = + memory_read_native(state.memory, index_register.as_canonical_u32()); + let index_base_pointer_u32 = index_base_pointer.as_canonical_u32(); + let [commit_pointer]: [F; 1] = + memory_read_native(state.memory, commit_register.as_canonical_u32()); + let commit: [F; CHUNK] = + memory_read_native(state.memory, commit_pointer.as_canonical_u32()); + + let opened_length = opened_length.as_canonical_u32() as usize; + + let initial_log_height = { + let [height]: [F; 1] = memory_read_native(state.memory, dim_base_pointer_u32); + height.as_canonical_u32() + }; + + let mut log_height = initial_log_height as i32; + let mut sibling_index = 0; + let mut opened_index = 0; + + let mut root = [F::ZERO; CHUNK]; + let sibling_proof: Vec<[F; CHUNK]> = { + let streams = self.streams.lock().unwrap(); + let proof_idx = proof_id.as_canonical_u32() as usize; + streams.hint_space[proof_idx] + .par_chunks(CHUNK) + .map(|c| c.try_into().unwrap()) + .collect() + }; + + while log_height >= 0 { + if opened_index < opened_length + && memory_read_native::( + state.memory, + dim_base_pointer_u32 + opened_index as u32, + )[0] == F::from_canonical_u32(log_height as u32) + { + let initial_opened_index = opened_index; + + let mut row_pointer = 0; + let mut row_end = 0; + + let mut rolling_hash = [F::ZERO; 2 * CHUNK]; + + let mut is_first_in_segment = true; + + loop { + let mut cells_len = 0; + for chunk_elem in rolling_hash.iter_mut().take(CHUNK) { + if is_first_in_segment || row_pointer == row_end { + if is_first_in_segment { + is_first_in_segment = false; + } else { + opened_index += 1; + if opened_index == opened_length + || memory_read_native::( + state.memory, + dim_base_pointer_u32 + opened_index as u32, + )[0] != F::from_canonical_u32(log_height as u32) + { + break; + } + } + let [new_row_pointer, row_len]: [F; 2] = memory_read_native( + state.memory, + opened_base_pointer_u32 + 2 * opened_index as u32, + ); + row_pointer = new_row_pointer.as_canonical_u32() as usize; + row_end = row_pointer + + (opened_element_size * row_len).as_canonical_u32() as usize; + } + let [value]: [F; 1] = + memory_read_native(state.memory, row_pointer as u32); + cells_len += 1; + *chunk_elem = value; + row_pointer += 1; + } + if cells_len == 0 { + break; + } + height += 1; + self.subchip.permute_mut(&mut rolling_hash); + if cells_len < CHUNK { + break; + } + } + let final_opened_index = opened_index - 1; + let [height_check]: [F; 1] = memory_read_native( + state.memory, + dim_base_pointer_u32 + initial_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + let [height_check]: [F; 1] = memory_read_native( + state.memory, + dim_base_pointer_u32 + final_opened_index as u32, + ); + assert_eq!(height_check, F::from_canonical_u32(log_height as u32)); + + let hash: [F; CHUNK] = std::array::from_fn(|i| rolling_hash[i]); + + let new_root = if log_height as u32 == initial_log_height { + hash + } else { + self.compress(root, hash).1 + }; + root = new_root; + + height += 1; + } + + if log_height != 0 { + let [sibling_is_on_right]: [F; 1] = memory_read_native( + state.memory, + index_base_pointer_u32 + sibling_index as u32, + ); + let sibling_is_on_right = sibling_is_on_right == F::ONE; + let sibling = sibling_proof[sibling_index]; + let new_root = if sibling_is_on_right { + self.compress(sibling, root).1 + } else { + self.compress(root, sibling).1 + }; + root = new_root; + + height += 1; + } + + log_height -= 1; + sibling_index += 1; + } + + assert_eq!(commit, root); + } else { + unreachable!() + } + *state.pc += DEFAULT_PC_STEP; + + height + } +} diff --git a/extensions/native/circuit/src/poseidon2/mod.rs b/extensions/native/circuit/src/poseidon2/mod.rs index af503e20f4..7dde941126 100644 --- a/extensions/native/circuit/src/poseidon2/mod.rs +++ b/extensions/native/circuit/src/poseidon2/mod.rs @@ -1,8 +1,49 @@ +use std::sync::{Arc, Mutex}; + +use openvm_circuit::{ + arch::{ExecutionBridge, NewVmChipWrapper, Streams, SystemPort}, + system::memory::SharedMemoryHelper, +}; +use openvm_native_compiler::conversion::AS; +use openvm_poseidon2_air::{Poseidon2Config, Poseidon2SubAir}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::poseidon2::{ + air::{NativePoseidon2Air, VerifyBatchBus}, + chip::NativePoseidon2Step, +}; + pub mod air; pub mod chip; mod columns; #[cfg(test)] mod tests; -mod trace; const CHUNK: usize = 8; +pub type NativePoseidon2Chip = NewVmChipWrapper< + F, + NativePoseidon2Air, + NativePoseidon2Step, +>; + +pub fn new_native_poseidon2_chip( + port: SystemPort, + poseidon2_config: Poseidon2Config, + verify_batch_bus: VerifyBatchBus, + streams: Arc>>, + max_ins_capacity: usize, + mem_helper: SharedMemoryHelper, +) -> NativePoseidon2Chip { + NativePoseidon2Chip::::new( + NativePoseidon2Air { + execution_bridge: ExecutionBridge::new(port.execution_bus, port.program_bus), + memory_bridge: port.memory_bridge, + internal_bus: verify_batch_bus, + subair: Arc::new(Poseidon2SubAir::new(poseidon2_config.constants.into())), + address_space: F::from_canonical_u32(AS::Native as u32), + }, + NativePoseidon2Step::new(poseidon2_config, streams), + max_ins_capacity, + mem_helper, + ) +} diff --git a/extensions/native/circuit/src/poseidon2/tests.rs b/extensions/native/circuit/src/poseidon2/tests.rs index 32a0e483a3..b81af91380 100644 --- a/extensions/native/circuit/src/poseidon2/tests.rs +++ b/extensions/native/circuit/src/poseidon2/tests.rs @@ -34,11 +34,12 @@ use rand::{rngs::StdRng, Rng}; use super::air::VerifyBatchBus; use crate::{ - poseidon2::{chip::NativePoseidon2Chip, CHUNK}, + poseidon2::{new_native_poseidon2_chip, CHUNK}, NativeConfig, }; const VERIFY_BATCH_BUS: VerifyBatchBus = VerifyBatchBus::new(7); +const MAX_INS_CAPACITY: usize = 1 << 15; fn compute_commit( dim: &[usize], @@ -155,12 +156,13 @@ fn test(cases: [Case; N]) { let mut tester = VmChipTestBuilder::default(); let streams = Arc::new(Mutex::new(Streams::default())); - let mut chip = NativePoseidon2Chip::::new( + let mut chip = new_native_poseidon2_chip::( tester.system_port(), - tester.offline_memory_mutex_arc(), Poseidon2Config::default(), VERIFY_BATCH_BUS, streams.clone(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); let mut rng = create_seeded_rng(); @@ -174,7 +176,7 @@ fn test(cases: [Case; N]) { random_instance(&mut rng, row_lengths, opened_element_size, |left, right| { let concatenated = std::array::from_fn(|i| if i < CHUNK { left[i] } else { right[i - CHUNK] }); - let permuted = chip.subchip.permute(concatenated); + let permuted = chip.step.subchip.permute(concatenated); ( std::array::from_fn(|i| permuted[i]), std::array::from_fn(|i| permuted[i + CHUNK]), @@ -218,7 +220,7 @@ fn test(cases: [Case; N]) { [row_pointer, opened_row.len() / opened_element_size], ); for (j, &opened_value) in opened_row.iter().enumerate() { - tester.write_cell(address_space, row_pointer + j, opened_value); + tester.write(address_space, row_pointer + j, [opened_value]); } } streams @@ -226,7 +228,7 @@ fn test(cases: [Case; N]) { .push(proof.iter().flatten().copied().collect()); drop(streams); for (i, &bit) in sibling_is_on_right.iter().enumerate() { - tester.write_cell(address_space, index_base_pointer + i, F::from_bool(bit)); + tester.write(address_space, index_base_pointer + i, [F::from_bool(bit)]); } tester.write(address_space, commit_pointer, commit); @@ -385,12 +387,13 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester::new( + let mut chip = new_native_poseidon2_chip::( tester.system_port(), - tester.offline_memory_mutex_arc(), Poseidon2Config::default(), VERIFY_BATCH_BUS, streams.clone(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); let mut rng = create_seeded_rng(); @@ -417,12 +420,12 @@ fn tester_with_random_poseidon2_ops(num_ops: usize) -> VmChipTester ChipUsageGetter - for NativePoseidon2Chip -{ - fn air_name(&self) -> String { - "VerifyBatchAir".to_string() - } - - fn current_trace_height(&self) -> usize { - self.height - } - - fn trace_width(&self) -> usize { - NativePoseidon2Cols::::width() - } -} - -impl NativePoseidon2Chip { - fn generate_subair_cols(&self, input: [F; 2 * CHUNK], cols: &mut [F]) { - let inner_trace = self.subchip.generate_trace(vec![input]); - let inner_width = self.air.subair.width(); - cols[..inner_width].copy_from_slice(inner_trace.values.as_slice()); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_sibling_record_to_row( - &self, - record: &IncorporateSiblingRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - opened_index: usize, - log_height: usize, - ) { - let &IncorporateSiblingRecord { - read_sibling_is_on_right, - sibling_is_on_right, - p2_input, - } = record; - - let read_sibling_is_on_right = memory.record_by_id(read_sibling_is_on_right); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ONE; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::ZERO; - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = - F::from_canonical_u32(read_sibling_is_on_right.timestamp - NUM_INITIAL_READS as u32); - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = - F::from_canonical_usize(read_sibling_is_on_right.timestamp as usize + 1); - cols.initial_opened_index = F::from_canonical_usize(opened_index); - specific.final_opened_index = F::from_canonical_usize(opened_index - 1); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - read_sibling_is_on_right, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - specific.sibling_is_on_right = F::from_bool(sibling_is_on_right); - } - fn correct_last_top_level_row( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &VerifyBatchRecord { - from_state, - commit_pointer, - dim_base_pointer_read, - opened_base_pointer_read, - opened_length_read, - index_base_pointer_read, - commit_pointer_read, - commit_read, - .. - } = record; - let instruction = &record.instruction; - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.end_top_level = F::ONE; - - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.dim_register = instruction.a; - specific.opened_register = instruction.b; - specific.opened_length_register = instruction.c; - specific.proof_id = instruction.d; - specific.index_register = instruction.e; - specific.commit_register = instruction.f; - specific.commit_pointer = commit_pointer; - aux_cols_factory.generate_read_aux( - memory.record_by_id(dim_base_pointer_read), - &mut specific.dim_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_base_pointer_read), - &mut specific.opened_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(opened_length_read), - &mut specific.opened_length_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(index_base_pointer_read), - &mut specific.index_base_pointer_read, - ); - aux_cols_factory.generate_read_aux( - memory.record_by_id(commit_pointer_read), - &mut specific.commit_pointer_read, - ); - aux_cols_factory - .generate_read_aux(memory.record_by_id(commit_read), &mut specific.commit_read); - } - #[allow(clippy::too_many_arguments)] - fn incorporate_row_record_to_row( - &self, - record: &IncorporateRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &VerifyBatchRecord, - proof_index: usize, - log_height: usize, - ) { - let &IncorporateRowRecord { - initial_opened_index, - final_opened_index, - initial_height_read, - final_height_read, - p2_input, - .. - } = record; - - let initial_height_read = memory.record_by_id(initial_height_read); - let final_height_read = memory.record_by_id(final_height_read); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ONE; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ZERO; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.start_top_level = F::from_bool(proof_index == 0); - cols.opened_element_size_inv = parent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32(parent.from_state.timestamp); - cols.start_timestamp = F::from_canonical_u32( - memory - .record_by_id( - record.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp - - NUM_INITIAL_READS as u32, - ); - let specific: &mut TopLevelSpecificCols = - cols.specific[..TopLevelSpecificCols::::width()].borrow_mut(); - - specific.end_timestamp = F::from_canonical_u32(final_height_read.timestamp + 1); - - cols.initial_opened_index = F::from_canonical_usize(initial_opened_index); - specific.final_opened_index = F::from_canonical_usize(final_opened_index); - specific.log_height = F::from_canonical_usize(log_height); - specific.opened_length = F::from_canonical_usize(parent.opened_length); - specific.dim_base_pointer = parent.dim_base_pointer; - cols.opened_base_pointer = parent.opened_base_pointer; - specific.index_base_pointer = parent.index_base_pointer; - - specific.proof_index = F::from_canonical_usize(proof_index); - aux_cols_factory.generate_read_aux( - initial_height_read, - &mut specific.read_initial_height_or_sibling_is_on_right, - ); - aux_cols_factory.generate_read_aux(final_height_read, &mut specific.read_final_height); - } - #[allow(clippy::too_many_arguments)] - fn inside_row_record_to_row( - &self, - record: &InsideRowRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - parent: &IncorporateRowRecord, - grandparent: &VerifyBatchRecord, - is_last: bool, - ) { - let InsideRowRecord { cells, p2_input } = record; - - self.generate_subair_cols(*p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ONE; - cols.simple = F::ZERO; - cols.end_inside_row = F::from_bool(is_last); - cols.end_top_level = F::ZERO; - cols.opened_element_size_inv = grandparent.opened_element_size_inv(); - cols.very_first_timestamp = F::from_canonical_u32( - memory - .record_by_id( - parent.chunks[0].cells[0] - .read_row_pointer_and_length - .unwrap(), - ) - .timestamp, - ); - cols.start_timestamp = - F::from_canonical_u32(memory.record_by_id(cells[0].read).timestamp - 1); - let specific: &mut InsideRowSpecificCols = - cols.specific[..InsideRowSpecificCols::::width()].borrow_mut(); - - for (record, cell) in cells.iter().zip(specific.cells.iter_mut()) { - let &CellRecord { - read, - opened_index, - read_row_pointer_and_length, - row_pointer, - row_end, - } = record; - aux_cols_factory.generate_read_aux(memory.record_by_id(read), &mut cell.read); - cell.opened_index = F::from_canonical_usize(opened_index); - if let Some(read_row_pointer_and_length) = read_row_pointer_and_length { - aux_cols_factory.generate_read_aux( - memory.record_by_id(read_row_pointer_and_length), - &mut cell.read_row_pointer_and_length, - ); - } - cell.row_pointer = F::from_canonical_usize(row_pointer); - cell.row_end = F::from_canonical_usize(row_end); - cell.is_first_in_row = F::from_bool(read_row_pointer_and_length.is_some()); - } - - for cell in specific.cells.iter_mut().skip(cells.len()) { - cell.opened_index = F::from_canonical_usize(parent.final_opened_index); - } - - cols.is_exhausted = std::array::from_fn(|i| F::from_bool(i + 1 >= cells.len())); - - cols.initial_opened_index = F::from_canonical_usize(parent.initial_opened_index); - cols.opened_base_pointer = grandparent.opened_base_pointer; - } - // returns number of used cells - fn verify_batch_record_to_rows( - &self, - record: &VerifyBatchRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) -> usize { - let width = NativePoseidon2Cols::::width(); - let mut used_cells = 0; - - let mut opened_index = 0; - for (proof_index, top_level) in record.top_level.iter().enumerate() { - let log_height = record.initial_log_height - proof_index; - if let Some(incorporate_row) = &top_level.incorporate_row { - self.incorporate_row_record_to_row( - incorporate_row, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - log_height, - ); - opened_index = incorporate_row.final_opened_index + 1; - used_cells += width; - } - if let Some(incorporate_sibling) = &top_level.incorporate_sibling { - self.incorporate_sibling_record_to_row( - incorporate_sibling, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - record, - proof_index, - opened_index, - log_height, - ); - used_cells += width; - } - } - self.correct_last_top_level_row( - record, - aux_cols_factory, - &mut slice[used_cells - width..used_cells], - memory, - ); - - for top_level in record.top_level.iter() { - if let Some(incorporate_row) = &top_level.incorporate_row { - for (i, chunk) in incorporate_row.chunks.iter().enumerate() { - self.inside_row_record_to_row( - chunk, - aux_cols_factory, - &mut slice[used_cells..used_cells + width], - memory, - incorporate_row, - record, - i == incorporate_row.chunks.len() - 1, - ); - used_cells += width; - } - } - } - - used_cells - } - fn simple_record_to_row( - &self, - record: &SimplePoseidonRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - ) { - let &SimplePoseidonRecord { - from_state, - instruction: - Instruction { - opcode, - a: output_register, - b: input_register_1, - c: input_register_2, - .. - }, - read_input_pointer_1, - read_input_pointer_2, - read_output_pointer, - read_data_1, - read_data_2, - write_data_1, - write_data_2, - input_pointer_1, - input_pointer_2, - output_pointer, - p2_input, - } = record; - - let read_input_pointer_1 = memory.record_by_id(read_input_pointer_1); - let read_output_pointer = memory.record_by_id(read_output_pointer); - let read_data_1 = memory.record_by_id(read_data_1); - let read_data_2 = memory.record_by_id(read_data_2); - let write_data_1 = memory.record_by_id(write_data_1); - - self.generate_subair_cols(p2_input, slice); - let cols: &mut NativePoseidon2Cols = slice.borrow_mut(); - cols.incorporate_row = F::ZERO; - cols.incorporate_sibling = F::ZERO; - cols.inside_row = F::ZERO; - cols.simple = F::ONE; - cols.end_inside_row = F::ZERO; - cols.end_top_level = F::ZERO; - cols.is_exhausted = [F::ZERO; CHUNK - 1]; - - cols.start_timestamp = F::from_canonical_u32(from_state.timestamp); - let specific: &mut SimplePoseidonSpecificCols = - cols.specific[..SimplePoseidonSpecificCols::::width()].borrow_mut(); - - specific.pc = F::from_canonical_u32(from_state.pc); - specific.is_compress = F::from_bool(opcode == COMP_POS2.global_opcode()); - specific.output_register = output_register; - specific.input_register_1 = input_register_1; - specific.input_register_2 = input_register_2; - specific.output_pointer = output_pointer; - specific.input_pointer_1 = input_pointer_1; - specific.input_pointer_2 = input_pointer_2; - aux_cols_factory.generate_read_aux(read_output_pointer, &mut specific.read_output_pointer); - aux_cols_factory - .generate_read_aux(read_input_pointer_1, &mut specific.read_input_pointer_1); - aux_cols_factory.generate_read_aux(read_data_1, &mut specific.read_data_1); - aux_cols_factory.generate_read_aux(read_data_2, &mut specific.read_data_2); - aux_cols_factory.generate_write_aux(write_data_1, &mut specific.write_data_1); - - if opcode == COMP_POS2.global_opcode() { - let read_input_pointer_2 = memory.record_by_id(read_input_pointer_2.unwrap()); - aux_cols_factory - .generate_read_aux(read_input_pointer_2, &mut specific.read_input_pointer_2); - } else { - let write_data_2 = memory.record_by_id(write_data_2.unwrap()); - aux_cols_factory.generate_write_aux(write_data_2, &mut specific.write_data_2); - } - } - - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); - - let memory = self.offline_memory.lock().unwrap(); - - let aux_cols_factory = memory.aux_cols_factory(); - - let mut used_cells = 0; - for record in self.record_set.verify_batch_records.iter() { - used_cells += self.verify_batch_record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..], - &memory, - ); - } - for record in self.record_set.simple_permute_records.iter() { - self.simple_record_to_row( - record, - &aux_cols_factory, - &mut flat_trace[used_cells..used_cells + width], - &memory, - ); - used_cells += width; - } - // poseidon2 constraints are always checked - // following can be optimized to only hash [0; _] once - flat_trace[used_cells..] - .par_chunks_mut(width) - .for_each(|row| { - self.generate_subair_cols([F::ZERO; 2 * CHUNK], row); - }); - - RowMajorMatrix::new(flat_trace, width) - } -} - -impl Chip - for NativePoseidon2Chip, SBOX_REGISTERS> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) - } -} diff --git a/extensions/pairing/circuit/src/fp12_chip/add.rs b/extensions/pairing/circuit/src/fp12_chip/add.rs deleted file mode 100644 index 643c68ef27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/add.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_add_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.add(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/mod.rs b/extensions/pairing/circuit/src/fp12_chip/mod.rs deleted file mode 100644 index c6894d0d27..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -mod add; -mod mul; -mod sub; - -pub use add::*; -pub use mul::*; -pub use sub::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/fp12_chip/mul.rs b/extensions/pairing/circuit/src/fp12_chip/mul.rs deleted file mode 100644 index 0736981de7..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/mul.rs +++ /dev/null @@ -1,175 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::Fp12Opcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; -// Input: Fp12 * 2 -// Output: Fp12 -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct Fp12MulChip( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl - Fp12MulChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = fp12_mul_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![Fp12Opcode::MUL as usize], - vec![], - range_checker, - "Fp12Mul", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn fp12_mul_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.mul(&mut y, xi); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::{bn256::Fq12, ff::Field}; - use itertools::Itertools; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::algebra::field::FieldExtension; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::{ - test_utils::{biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec}, - ExprBuilderConfig, - }; - use openvm_pairing_guest::bn254::{BN254_MODULUS, BN254_XI_ISIZE}; - use openvm_rv32_adapters::rv32_write_heap_default_with_increment; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - const LIMB_BITS: usize = 8; - type F = BabyBear; - - #[test] - fn test_fp12_mul_bn254() { - const NUM_LIMBS: usize = 32; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let mut chip = Fp12MulChip::new( - adapter, - config, - BN254_XI_ISIZE, - Fp12Opcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(64); - let x = Fq12::random(&mut rng); - let y = Fq12::random(&mut rng); - let inputs = [x.to_coeffs(), y.to_coeffs()] - .concat() - .iter() - .flat_map(|&x| bn254_fq2_to_biguint_vec(x)) - .collect::>(); - - let cmp = bn254_fq12_to_biguint_vec(x * y); - let res = chip - .0 - .core - .expr() - .execute_with_output(inputs.clone(), vec![true]); - assert_eq!(res.len(), cmp.len()); - for i in 0..res.len() { - assert_eq!(res[i], cmp[i]); - } - - let x_limbs = inputs[..12] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let y_limbs = inputs[12..] - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect_vec(); - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - x_limbs, - y_limbs, - 512, - chip.0.core.air.offset + Fp12Opcode::MUL as usize, - ); - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/fp12_chip/sub.rs b/extensions/pairing/circuit/src/fp12_chip/sub.rs deleted file mode 100644 index 470e700910..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/sub.rs +++ /dev/null @@ -1,20 +0,0 @@ -use std::{cell::RefCell, rc::Rc}; - -use openvm_circuit_primitives::var_range::VariableRangeCheckerBus; -use openvm_mod_circuit_builder::{ExprBuilder, ExprBuilderConfig, FieldExpr}; - -use crate::Fp12; - -pub fn fp12_sub_expr(config: ExprBuilderConfig, range_bus: VariableRangeCheckerBus) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x = Fp12::new(builder.clone()); - let mut y = Fp12::new(builder.clone()); - let mut res = x.sub(&mut y); - res.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/fp12_chip/tests.rs b/extensions/pairing/circuit/src/fp12_chip/tests.rs deleted file mode 100644 index a9f6b235d5..0000000000 --- a/extensions/pairing/circuit/src/fp12_chip/tests.rs +++ /dev/null @@ -1,271 +0,0 @@ -use num_bigint::BigUint; -use openvm_circuit::arch::{ - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmChipWrapper, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bls12381_fq12_random, bn254_fq12_random, bn254_fq12_to_biguint_vec, - }, - ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_guest::{ - bls12_381::{ - BLS12_381_BLOCK_SIZE, BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, - BLS12_381_XI_ISIZE, - }, - bn254::{BN254_BLOCK_SIZE, BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, -}; -use openvm_pairing_transpiler::{Bls12381Fp12Opcode, Bn254Fp12Opcode, Fp12Opcode}; -use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; - -use super::{fp12_add_expr, fp12_mul_expr, fp12_sub_expr}; - -type F = BabyBear; - -#[allow(clippy::too_many_arguments)] -fn test_fp12_fn< - const INPUT_SIZE: usize, - const NUM_LIMBS: usize, - const LIMB_BITS: usize, - const BLOCK_SIZE: usize, ->( - mut tester: VmChipTestBuilder, - expr: FieldExpr, - offset: usize, - local_opcode_idx: usize, - name: &str, - x: Vec, - y: Vec, - var_len: usize, -) { - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![local_opcode_idx], - vec![], - tester.memory_controller().borrow().range_checker.clone(), - name, - false, - ); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let adapter = - Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - - let x_limbs = x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let y_limbs = y - .iter() - .map(|y| { - biguint_to_limbs::(y.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let mut chip = VmChipWrapper::new(adapter, core, tester.offline_memory_mutex_arc()); - - let res = chip.core.air.expr.execute([x, y].concat(), vec![]); - assert_eq!(res.len(), var_len); - - let instruction = rv32_write_heap_default( - &mut tester, - x_limbs, - y_limbs, - chip.core.air.offset + local_opcode_idx, - ); - tester.execute(&mut chip, &instruction); - - let run_tester = tester.build().load(chip).load(bitwise_chip).finalize(); - run_tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_fp12_add_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(1)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(2)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bn254Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(59)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(3)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bn254Fp12Sub", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_mul_bn254() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: BN254_NUM_LIMBS, - limb_bits: BN254_LIMB_BITS, - }; - let xi = BN254_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bn254_fq12_to_biguint_vec(bn254_fq12_random(5)); - let y = bn254_fq12_to_biguint_vec(bn254_fq12_random(25)); - - test_fp12_fn::<12, BN254_NUM_LIMBS, BN254_LIMB_BITS, BN254_BLOCK_SIZE>( - tester, - expr, - Bn254Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bn254Fp12Mul", - x, - y, - 33, - ); -} - -#[test] -fn test_fp12_add_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_add_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(3); - let y = bls12381_fq12_random(99); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::ADD as usize, - "Bls12381Fp12Add", - x, - y, - 12, - ); -} - -#[test] -fn test_fp12_sub_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let expr = fp12_sub_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - ); - - let x = bls12381_fq12_random(8); - let y = bls12381_fq12_random(9); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::SUB as usize, - "Bls12381Fp12Sub", - x, - y, - 12, - ); -} - -// NOTE[yj]: This test requires RUST_MIN_STACK=8388608 to run without overflowing the stack, so it -// is ignored by the test runner for now -#[test] -#[ignore] -fn test_fp12_mul_bls12381() { - let tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }; - let xi = BLS12_381_XI_ISIZE; - let expr = fp12_mul_expr( - config, - tester.memory_controller().borrow().range_checker.bus(), - xi, - ); - - let x = bls12381_fq12_random(5); - let y = bls12381_fq12_random(25); - - test_fp12_fn::<36, BLS12_381_NUM_LIMBS, BLS12_381_LIMB_BITS, BLS12_381_BLOCK_SIZE>( - tester, - expr, - Bls12381Fp12Opcode::CLASS_OFFSET, - Fp12Opcode::MUL as usize, - "Bls12381Fp12Mul", - x, - y, - 46, - ); -} diff --git a/extensions/pairing/circuit/src/lib.rs b/extensions/pairing/circuit/src/lib.rs index b2b962b7f7..f96d126555 100644 --- a/extensions/pairing/circuit/src/lib.rs +++ b/extensions/pairing/circuit/src/lib.rs @@ -1,11 +1,7 @@ mod config; mod fp12; -mod fp12_chip; -mod pairing_chip; mod pairing_extension; pub use config::*; pub use fp12::*; -pub use fp12_chip::*; -pub use pairing_chip::*; pub use pairing_extension::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs deleted file mode 100644 index 08857995f3..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_013_by_013; -mod mul_by_01234; - -pub use mul_013_by_013::*; -pub use mul_by_01234::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs deleted file mode 100644 index 36d1012e9b..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_013_by_013.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul013By013Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul013By013Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_013_by_013_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_013_BY_013 as usize], - vec![], - range_checker, - "Mul013By013", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_013_by_013_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); - let mut c0 = Fp2::new(builder.clone()); - let mut b1 = Fp2::new(builder.clone()); - let mut c1 = Fp2::new(builder.clone()); - - // where w⁶ = xi - // l0 * l1 = 1 + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ + (c0c1)w⁶ - // = (1 + c0c1 * xi) + (b0 + b1)w + (b0b1)w² + (c0 + c1)w³ + (b0c1 + b1c0)w⁴ - let l0 = c0.mul(&mut c1).int_mul(xi).int_add([1, 0]); - let l1 = b0.add(&mut b1); - let l2 = b0.mul(&mut b1); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut c1).add(&mut b1.mul(&mut c0)); - - [l0, l1, l2, l3, l4].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs deleted file mode 100644 index 996372e994..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/mul_by_01234.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: Fp12 (12 field elements), [Fp2; 5] (5 x 2 field elements) -// Output: Fp12 (12 field elements) -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy01234Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy01234Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_01234_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_01234 as usize], - vec![], - range_checker.clone(), - "MulBy01234", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_01234_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x1 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_01234(&mut x0, &mut x1, &mut x2, &mut x3, &mut x4, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs deleted file mode 100644 index 81da3169fa..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/d_type/tests.rs +++ /dev/null @@ -1,287 +0,0 @@ -use halo2curves_axiom::{ - bn256::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{ - test_utils::{ - biguint_to_limbs, bn254_fq12_to_biguint_vec, bn254_fq2_to_biguint_vec, bn254_fq_to_biguint, - }, - ExprBuilderConfig, -}; -use openvm_pairing_guest::{ - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS, BN254_XI_ISIZE}, - halo2curves_shims::bn254::{tangent_line_013, Bn254}, - pairing::{Evaluatable, LineMulDType, UnevaluatedLine}, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default, rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, - Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::{super::EvaluateLineChip, *}; - -type F = BabyBear; -const NUM_LIMBS: usize = 32; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 32; - -#[test] -fn test_mul_013_by_013() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul013By013Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(8); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_013::(ec_pt_0); - let line1 = tangent_line_013::(ec_pt_1); - let input_line0 = [ - bn254_fq2_to_biguint_vec(line0.b), - bn254_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bn254_fq2_to_biguint_vec(line1.b), - bn254_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bn254::mul_013_by_013(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bn254_fq_to_biguint(x.c0), bn254_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input_line0_limbs, - input_line1_limbs, - chip.0.core.air.offset + PairingOpcode::MUL_013_BY_013 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_mul_by_01234() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy01234Chip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - num_limbs: NUM_LIMBS, - limb_bits: LIMB_BITS, - }, - BN254_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(8); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x1 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - - let input_f = bn254_fq12_to_biguint_vec(f); - let input_x = [ - bn254_fq2_to_biguint_vec(x0), - bn254_fq2_to_biguint_vec(x1), - bn254_fq2_to_biguint_vec(x2), - bn254_fq2_to_biguint_vec(x3), - bn254_fq2_to_biguint_vec(x4), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bn254::mul_by_01234(&f, &[x0, x1, x2, x3, x4]); - let r_cmp_bigint = bn254_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_BY_01234 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn test_evaluate_line() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EvaluateLineChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(42); - let uneval_b = Fq2::random(&mut rng); - let uneval_c = Fq2::random(&mut rng); - let x_over_y = Fq::random(&mut rng); - let y_inv = Fq::random(&mut rng); - let mut inputs = vec![]; - inputs.extend(bn254_fq2_to_biguint_vec(uneval_b)); - inputs.extend(bn254_fq2_to_biguint_vec(uneval_c)); - inputs.push(bn254_fq_to_biguint(x_over_y)); - inputs.push(bn254_fq_to_biguint(y_inv)); - let input_limbs = inputs - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect(); - - let uneval: UnevaluatedLine = UnevaluatedLine { - b: uneval_b, - c: uneval_c, - }; - let evaluated = uneval.evaluate(&(x_over_y, y_inv)); - - let result = chip.0.core.expr().execute_with_output(inputs, vec![]); - assert_eq!(result.len(), 4); - assert_eq!(result[0], bn254_fq_to_biguint(evaluated.b.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(evaluated.b.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(evaluated.c.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(evaluated.c.c1)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs, - vec![], - chip.0.core.air.offset + PairingOpcode::EVALUATE_LINE as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs b/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs deleted file mode 100644 index dc0a8cdfe1..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/evaluate_line.rs +++ /dev/null @@ -1,102 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: UnevaluatedLine, (Fp, Fp) -// Output: EvaluatedLine -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EvaluateLineChip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EvaluateLineChip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = evaluate_line_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::EVALUATE_LINE as usize], - vec![], - range_checker, - "EvaluateLine", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn evaluate_line_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut uneval_b = Fp2::new(builder.clone()); - let mut uneval_c = Fp2::new(builder.clone()); - - let mut x_over_y = ExprBuilder::new_input(builder.clone()); - let mut y_inv = ExprBuilder::new_input(builder.clone()); - - let mut b = uneval_b.scalar_mul(&mut x_over_y); - let mut c = uneval_c.scalar_mul(&mut y_inv); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs deleted file mode 100644 index b454d260ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod mul_023_by_023; -mod mul_by_02345; - -pub use mul_023_by_023::*; -pub use mul_by_02345::*; - -#[cfg(test)] -mod tests; diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs deleted file mode 100644 index 0d760b886e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_023_by_023.rs +++ /dev/null @@ -1,101 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: line0.b, line0.c, line1.b, line1.c : 2 x 4 field elements -// Output: 5 Fp2 coefficients -> 10 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMul023By023Chip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMul023By023Chip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_023_by_023_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_023_BY_023 as usize], - vec![], - range_checker, - "Mul023By023", - true, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_023_by_023_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut b0 = Fp2::new(builder.clone()); // x2 - let mut c0 = Fp2::new(builder.clone()); // x3 - let mut b1 = Fp2::new(builder.clone()); // y2 - let mut c1 = Fp2::new(builder.clone()); // y3 - - // where w⁶ = xi - // l0 * l1 = c0c1 + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 +b1)w⁵ + w⁶ - // = (c0c1 + xi) + (c0b1 + c1b0)w² + (c0 + c1)w³ + (b0b1)w⁴ + (b0 + b1)w⁵ - let l0 = c0.mul(&mut c1).int_add(xi); - let l2 = c0.mul(&mut b1).add(&mut c1.mul(&mut b0)); - let l3 = c0.add(&mut c1); - let l4 = b0.mul(&mut b1); - let l5 = b0.add(&mut b1); - - [l0, l2, l3, l4, l5].map(|mut l| l.save_output()); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs deleted file mode 100644 index ad0e91e7bd..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/mul_by_02345.rs +++ /dev/null @@ -1,113 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapTwoReadsAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -use crate::Fp12; - -// Input: 2 Fp12: 2 x 12 field elements -// Output: Fp12 -> 12 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct EcLineMulBy02345Chip< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS1: usize, - const INPUT_BLOCKS2: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > EcLineMulBy02345Chip -{ - pub fn new( - adapter: Rv32VecHeapTwoReadsAdapterChip< - F, - INPUT_BLOCKS1, - INPUT_BLOCKS2, - OUTPUT_BLOCKS, - BLOCK_SIZE, - BLOCK_SIZE, - >, - range_checker: SharedVariableRangeCheckerChip, - config: ExprBuilderConfig, - xi: [isize; 2], - offset: usize, - offline_memory: Arc>>, - ) -> Self { - assert!( - xi[0].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); // not a hard rule, but we expect xi to be small - assert!( - xi[1].unsigned_abs() < 1 << config.limb_bits, - "expect xi to be small" - ); - let expr = mul_by_02345_expr(config, range_checker.bus(), xi); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MUL_BY_02345 as usize], - vec![], - range_checker, - "MulBy02345", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -pub fn mul_by_02345_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, - xi: [isize; 2], -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config.clone(), range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut f = Fp12::new(builder.clone()); - let mut x0 = Fp2::new(builder.clone()); - let mut x2 = Fp2::new(builder.clone()); - let mut x3 = Fp2::new(builder.clone()); - let mut x4 = Fp2::new(builder.clone()); - let mut x5 = Fp2::new(builder.clone()); - - let mut r = f.mul_by_02345(&mut x0, &mut x2, &mut x3, &mut x4, &mut x5, xi); - r.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs b/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs deleted file mode 100644 index 4331d2278e..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/m_type/tests.rs +++ /dev/null @@ -1,217 +0,0 @@ -use halo2curves_axiom::{ - bls12_381::{Fq, Fq12, Fq2, G1Affine}, - ff::Field, -}; -use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, -}; -use openvm_ecc_guest::AffinePoint; -use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; -use openvm_mod_circuit_builder::{test_utils::*, ExprBuilderConfig}; -use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS, BLS12_381_XI_ISIZE}, - halo2curves_shims::bls12_381::{tangent_line_023, Bls12_381}, - pairing::LineMulMType, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::{ - rv32_write_heap_default_with_increment, Rv32VecHeapAdapterChip, Rv32VecHeapTwoReadsAdapterChip, -}; -use openvm_stark_backend::p3_field::FieldAlgebra; -use openvm_stark_sdk::p3_baby_bear::BabyBear; -use rand::{rngs::StdRng, SeedableRng}; - -use super::*; - -type F = BabyBear; -const NUM_LIMBS: usize = 48; -const LIMB_BITS: usize = 8; -const BLOCK_SIZE: usize = 16; - -#[test] -fn test_mul_023_by_023() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMul023By023Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(15); - let mut rng1 = StdRng::seed_from_u64(95); - let rnd_pt_0 = G1Affine::random(&mut rng0); - let rnd_pt_1 = G1Affine::random(&mut rng1); - let ec_pt_0 = AffinePoint:: { - x: rnd_pt_0.x, - y: rnd_pt_0.y, - }; - let ec_pt_1 = AffinePoint:: { - x: rnd_pt_1.x, - y: rnd_pt_1.y, - }; - let line0 = tangent_line_023::(ec_pt_0); - let line1 = tangent_line_023::(ec_pt_1); - let input_line0 = [ - bls12381_fq2_to_biguint_vec(line0.b), - bls12381_fq2_to_biguint_vec(line0.c), - ] - .concat(); - let input_line1 = [ - bls12381_fq2_to_biguint_vec(line1.b), - bls12381_fq2_to_biguint_vec(line1.c), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_line0.clone(), input_line1.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 10); - - let r_cmp = Bls12_381::mul_023_by_023(&line0, &line1); - let r_cmp_bigint = r_cmp - .map(|x| [bls12381_fq_to_biguint(x.c0), bls12381_fq_to_biguint(x.c1)]) - .concat(); - - for i in 0..10 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_line0_limbs = input_line0 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_line1_limbs = input_line1 - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_line0_limbs, - input_line1_limbs, - 512, - chip.0.core.air.offset + PairingOpcode::MUL_023_BY_023 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} - -// NOTE[yj]: this test requires `RUST_MIN_STACK=8388608` to run otherwise it will overflow the stack -#[test] -#[ignore] -fn test_mul_by_02345() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapTwoReadsAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = EcLineMulBy02345Chip::new( - adapter, - tester.memory_controller().borrow().range_checker.clone(), - ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - num_limbs: BLS12_381_NUM_LIMBS, - limb_bits: BLS12_381_LIMB_BITS, - }, - BLS12_381_XI_ISIZE, - PairingOpcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); - - let mut rng = StdRng::seed_from_u64(19); - let f = Fq12::random(&mut rng); - let x0 = Fq2::random(&mut rng); - let x2 = Fq2::random(&mut rng); - let x3 = Fq2::random(&mut rng); - let x4 = Fq2::random(&mut rng); - let x5 = Fq2::random(&mut rng); - - let input_f = bls12381_fq12_to_biguint_vec(f); - let input_x = [ - bls12381_fq2_to_biguint_vec(x0), - bls12381_fq2_to_biguint_vec(x2), - bls12381_fq2_to_biguint_vec(x3), - bls12381_fq2_to_biguint_vec(x4), - bls12381_fq2_to_biguint_vec(x5), - ] - .concat(); - - let vars = chip - .0 - .core - .expr() - .execute([input_f.clone(), input_x.clone()].concat(), vec![]); - let output_indices = chip.0.core.expr().builder.output_indices.clone(); - let output = output_indices - .iter() - .map(|i| vars[*i].clone()) - .collect::>(); - assert_eq!(output.len(), 12); - - let r_cmp = Bls12_381::mul_by_02345(&f, &[x0, x2, x3, x4, x5]); - let r_cmp_bigint = bls12381_fq12_to_biguint_vec(r_cmp); - - for i in 0..12 { - assert_eq!(output[i], r_cmp_bigint[i]); - } - - let input_f_limbs = input_f - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - let input_x_limbs = input_x - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS).map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default_with_increment( - &mut tester, - input_f_limbs, - input_x_limbs, - 1024, - chip.0.core.air.offset + PairingOpcode::MUL_BY_02345 as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); -} diff --git a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs b/extensions/pairing/circuit/src/pairing_chip/line/mod.rs deleted file mode 100644 index acf02c72be..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/line/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod d_type; -mod evaluate_line; -mod m_type; - -pub use d_type::*; -pub use evaluate_line::*; -pub use m_type::*; diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs deleted file mode 100644 index 77084428c9..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_and_add_step.rs +++ /dev/null @@ -1,215 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: two AffinePoint: 4 field elements each -// Output: (AffinePoint, UnevaluatedLine, UnevaluatedLine) -> 2*2 + 2*2 + 2*2 = 12 -// field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleAndAddStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - pub VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleAndAddStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_and_add_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize], - vec![], - range_checker, - "MillerDoubleAndAddStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: openvm_pairing_guest::miller_step -pub fn miller_double_and_add_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - let mut x_q = Fp2::new(builder.clone()); - let mut y_q = Fp2::new(builder.clone()); - - // λ1 = (y_s - y_q) / (x_s - x_q) - let mut lambda1 = y_s.sub(&mut y_q).div(&mut x_s.sub(&mut x_q)); - let mut x_sq = lambda1.square().sub(&mut x_s).sub(&mut x_q); - // λ2 = -λ1 - 2y_s / (x_{s+q} - x_s) - let mut lambda2 = lambda1 - .neg() - .sub(&mut y_s.int_mul([2, 0]).div(&mut x_sq.sub(&mut x_s))); - let mut x_sqs = lambda2.square().sub(&mut x_s).sub(&mut x_sq); - let mut y_sqs = lambda2.mul(&mut (x_s.sub(&mut x_sqs))).sub(&mut y_s); - - x_sqs.save_output(); - y_sqs.save_output(); - - let mut b0 = lambda1.neg(); - let mut c0 = lambda1.mul(&mut x_s).sub(&mut y_s); - b0.save_output(); - c0.save_output(); - - let mut b1 = lambda2.neg(); - let mut c1 = lambda2.mul(&mut x_s).sub(&mut y_s); - b1.save_output(); - c1.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use halo2curves_axiom::bn256::G2Affine; - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{biguint_to_limbs, bn254_fq_to_biguint}; - use openvm_pairing_guest::{ - bn254::BN254_MODULUS, halo2curves_shims::bn254::Bn254, pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_and_add() { - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleAndAddStepChip::new( - adapter, - ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: LIMB_BITS, - num_limbs: NUM_LIMBS, - }, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let Q2 = G2Affine::random(&mut rng0); - let inputs = [ - Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1, Q2.x.c0, Q2.x.c1, Q2.y.c0, Q2.y.c1, - ] - .map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let Q_ecpoint2 = AffinePoint { x: Q2.x, y: Q2.y }; - let (Q_daa, l_qa, l_sqs) = Bn254::miller_double_and_add_step(&Q_ecpoint, &Q_ecpoint2); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 12); // AffinePoint and 4 Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_daa.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_daa.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_daa.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_daa.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_qa.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_qa.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_qa.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_qa.c.c1)); - assert_eq!(result[8], bn254_fq_to_biguint(l_sqs.b.c0)); - assert_eq!(result[9], bn254_fq_to_biguint(l_sqs.b.c1)); - assert_eq!(result[10], bn254_fq_to_biguint(l_sqs.c.c0)); - assert_eq!(result[11], bn254_fq_to_biguint(l_sqs.c.c1)); - - let input1_limbs = inputs[0..4] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let input2_limbs = inputs[4..8] - .iter() - .map(|x| { - biguint_to_limbs::(x.clone(), LIMB_BITS) - .map(BabyBear::from_canonical_u32) - }) - .collect::>(); - - let instruction = rv32_write_heap_default( - &mut tester, - input1_limbs, - input2_limbs, - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_AND_ADD_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs b/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs deleted file mode 100644 index 519eb473a5..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/miller_double_step.rs +++ /dev/null @@ -1,253 +0,0 @@ -use std::{ - cell::RefCell, - rc::Rc, - sync::{Arc, Mutex}, -}; - -use openvm_algebra_circuit::Fp2; -use openvm_circuit::{arch::VmChipWrapper, system::memory::OfflineMemory}; -use openvm_circuit_derive::InstructionExecutor; -use openvm_circuit_primitives::var_range::{ - SharedVariableRangeCheckerChip, VariableRangeCheckerBus, -}; -use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; -use openvm_mod_circuit_builder::{ - ExprBuilder, ExprBuilderConfig, FieldExpr, FieldExpressionCoreChip, -}; -use openvm_pairing_transpiler::PairingOpcode; -use openvm_rv32_adapters::Rv32VecHeapAdapterChip; -use openvm_stark_backend::p3_field::PrimeField32; - -// Input: AffinePoint: 4 field elements -// Output: (AffinePoint, Fp2, Fp2) -> 8 field elements -#[derive(Chip, ChipUsageGetter, InstructionExecutor)] -pub struct MillerDoubleStepChip< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, ->( - VmChipWrapper< - F, - Rv32VecHeapAdapterChip, - FieldExpressionCoreChip, - >, -); - -impl< - F: PrimeField32, - const INPUT_BLOCKS: usize, - const OUTPUT_BLOCKS: usize, - const BLOCK_SIZE: usize, - > MillerDoubleStepChip -{ - pub fn new( - adapter: Rv32VecHeapAdapterChip, - config: ExprBuilderConfig, - offset: usize, - range_checker: SharedVariableRangeCheckerChip, - offline_memory: Arc>>, - ) -> Self { - let expr = miller_double_step_expr(config, range_checker.bus()); - let core = FieldExpressionCoreChip::new( - expr, - offset, - vec![PairingOpcode::MILLER_DOUBLE_STEP as usize], - vec![], - range_checker, - "MillerDoubleStep", - false, - ); - Self(VmChipWrapper::new(adapter, core, offline_memory)) - } -} - -// Ref: https://github.com/openvm-org/openvm/blob/f7d6fa7b8ef247e579740eb652fcdf5a04259c28/lib/ecc-execution/src/common/miller_step.rs#L7 -pub fn miller_double_step_expr( - config: ExprBuilderConfig, - range_bus: VariableRangeCheckerBus, -) -> FieldExpr { - config.check_valid(); - let builder = ExprBuilder::new(config, range_bus.range_max_bits); - let builder = Rc::new(RefCell::new(builder)); - - let mut x_s = Fp2::new(builder.clone()); - let mut y_s = Fp2::new(builder.clone()); - - let mut three_x_square = x_s.square().int_mul([3, 0]); - let mut lambda = three_x_square.div(&mut y_s.int_mul([2, 0])); - let mut x_2s = lambda.square().sub(&mut x_s.int_mul([2, 0])); - let mut y_2s = lambda.mul(&mut (x_s.sub(&mut x_2s))).sub(&mut y_s); - x_2s.save_output(); - y_2s.save_output(); - - let mut b = lambda.neg(); - let mut c = lambda.mul(&mut x_s).sub(&mut y_s); - b.save_output(); - c.save_output(); - - let builder = builder.borrow().clone(); - FieldExpr::new(builder, range_bus, false) -} - -#[cfg(test)] -mod tests { - use openvm_circuit::arch::testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}; - use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, - }; - use openvm_ecc_guest::AffinePoint; - use openvm_instructions::{riscv::RV32_CELL_BITS, LocalOpcode}; - use openvm_mod_circuit_builder::test_utils::{ - biguint_to_limbs, bls12381_fq_to_biguint, bn254_fq_to_biguint, - }; - use openvm_pairing_guest::{ - bls12_381::{BLS12_381_LIMB_BITS, BLS12_381_MODULUS, BLS12_381_NUM_LIMBS}, - bn254::{BN254_LIMB_BITS, BN254_MODULUS, BN254_NUM_LIMBS}, - halo2curves_shims::{bls12_381::Bls12_381, bn254::Bn254}, - pairing::MillerStep, - }; - use openvm_pairing_transpiler::PairingOpcode; - use openvm_rv32_adapters::{rv32_write_heap_default, Rv32VecHeapAdapterChip}; - use openvm_stark_backend::p3_field::FieldAlgebra; - use openvm_stark_sdk::p3_baby_bear::BabyBear; - use rand::{rngs::StdRng, SeedableRng}; - - use super::*; - - type F = BabyBear; - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bn254() { - use halo2curves_axiom::bn256::G2Affine; - const NUM_LIMBS: usize = 32; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 32; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BN254_MODULUS.clone(), - limb_bits: BN254_LIMB_BITS, - num_limbs: BN254_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(2); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bn254_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bn254::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bn254_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bn254_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bn254_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bn254_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bn254_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bn254_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bn254_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bn254_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } - - #[test] - #[allow(non_snake_case)] - fn test_miller_double_bls12_381() { - use halo2curves_axiom::bls12_381::G2Affine; - const NUM_LIMBS: usize = 48; - const LIMB_BITS: usize = 8; - const BLOCK_SIZE: usize = 16; - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let config = ExprBuilderConfig { - modulus: BLS12_381_MODULUS.clone(), - limb_bits: BLS12_381_LIMB_BITS, - num_limbs: BLS12_381_NUM_LIMBS, - }; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let adapter = Rv32VecHeapAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - bitwise_chip.clone(), - ); - let mut chip = MillerDoubleStepChip::new( - adapter, - config, - PairingOpcode::CLASS_OFFSET, - tester.range_checker(), - tester.offline_memory_mutex_arc(), - ); - - let mut rng0 = StdRng::seed_from_u64(12); - let Q = G2Affine::random(&mut rng0); - let inputs = [Q.x.c0, Q.x.c1, Q.y.c0, Q.y.c1].map(bls12381_fq_to_biguint); - - let Q_ecpoint = AffinePoint { x: Q.x, y: Q.y }; - let (Q_acc_init, l_init) = Bls12_381::miller_double_step(&Q_ecpoint); - let result = chip - .0 - .core - .expr() - .execute_with_output(inputs.to_vec(), vec![]); - assert_eq!(result.len(), 8); // AffinePoint and two Fp2 coefficients - assert_eq!(result[0], bls12381_fq_to_biguint(Q_acc_init.x.c0)); - assert_eq!(result[1], bls12381_fq_to_biguint(Q_acc_init.x.c1)); - assert_eq!(result[2], bls12381_fq_to_biguint(Q_acc_init.y.c0)); - assert_eq!(result[3], bls12381_fq_to_biguint(Q_acc_init.y.c1)); - assert_eq!(result[4], bls12381_fq_to_biguint(l_init.b.c0)); - assert_eq!(result[5], bls12381_fq_to_biguint(l_init.b.c1)); - assert_eq!(result[6], bls12381_fq_to_biguint(l_init.c.c0)); - assert_eq!(result[7], bls12381_fq_to_biguint(l_init.c.c1)); - - let input_limbs = inputs - .map(|x| biguint_to_limbs::(x, LIMB_BITS).map(BabyBear::from_canonical_u32)); - - let instruction = rv32_write_heap_default( - &mut tester, - input_limbs.to_vec(), - vec![], - chip.0.core.air.offset + PairingOpcode::MILLER_DOUBLE_STEP as usize, - ); - - tester.execute(&mut chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); - } -} diff --git a/extensions/pairing/circuit/src/pairing_chip/mod.rs b/extensions/pairing/circuit/src/pairing_chip/mod.rs deleted file mode 100644 index df00df16ce..0000000000 --- a/extensions/pairing/circuit/src/pairing_chip/mod.rs +++ /dev/null @@ -1,8 +0,0 @@ -mod line; -mod miller_double_step; - -pub use line::*; -pub use miller_double_step::*; - -mod miller_double_and_add_step; -pub use miller_double_and_add_step::*; diff --git a/extensions/pairing/circuit/src/pairing_extension.rs b/extensions/pairing/circuit/src/pairing_extension.rs index eca4cea8dd..763da36e82 100644 --- a/extensions/pairing/circuit/src/pairing_extension.rs +++ b/extensions/pairing/circuit/src/pairing_extension.rs @@ -5,7 +5,7 @@ use openvm_circuit::{ arch::{VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor}; use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter}; use openvm_ecc_circuit::CurveConfig; @@ -19,8 +19,6 @@ use openvm_stark_backend::p3_field::PrimeField32; use serde::{Deserialize, Serialize}; use strum::FromRepr; -use super::*; - // All the supported pairing curves. #[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)] #[repr(usize)] @@ -60,14 +58,9 @@ pub struct PairingExtension { pub supported_curves: Vec, } -#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)] +#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum, InsExecutorE1)] pub enum PairingExtensionExecutor { - // bn254 (32 limbs) - MillerDoubleAndAddStepRv32_32(MillerDoubleAndAddStepChip), - EvaluateLineRv32_32(EvaluateLineChip), - // bls12-381 (48 limbs) - MillerDoubleAndAddStepRv32_48(MillerDoubleAndAddStepChip), - EvaluateLineRv32_48(EvaluateLineChip), + Phantom(PhantomChip), } #[derive(ChipUsageGetter, Chip, AnyEnum, From)] @@ -101,7 +94,7 @@ pub(crate) mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_ecc_guest::{algebra::field::FieldExtension, halo2curves::ff, AffinePoint}; use openvm_instructions::{ @@ -113,7 +106,7 @@ pub(crate) mod phantom { bn254::BN254_NUM_LIMBS, pairing::{FinalExp, MultiMillerLoop}, }; - use openvm_rv32im_circuit::adapters::{compose, unsafe_read_rv32_register}; + use openvm_rv32im_circuit::adapters::{memory_read, new_read_rv32_register}; use openvm_stark_backend::p3_field::PrimeField32; use super::PairingCurve; @@ -123,43 +116,40 @@ pub(crate) mod phantom { impl PhantomSubExecutor for PairingHintSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, c_upper: u16, ) -> eyre::Result<()> { - let rs1 = unsafe_read_rv32_register(memory, a); - let rs2 = unsafe_read_rv32_register(memory, b); + let rs1 = new_read_rv32_register(memory, 1, a); + let rs2 = new_read_rv32_register(memory, 1, b); hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper) } } fn hint_pairing( - memory: &MemoryController, + memory: &GuestMemory, hint_stream: &mut VecDeque, rs1: u32, rs2: u32, c_upper: u16, ) -> eyre::Result<()> { - let p_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1), - )); + let p_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs1)); // len in bytes - let p_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs1 + RV32_REGISTER_NUM_LIMBS as u32), - )); - let q_ptr = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2), + let p_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs1 + RV32_REGISTER_NUM_LIMBS as u32, )); + + let q_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs2)); // len in bytes - let q_len = compose(memory.unsafe_read( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(rs2 + RV32_REGISTER_NUM_LIMBS as u32), + let q_len = u32::from_le_bytes(memory_read( + memory, + RV32_MEMORY_AS, + rs2 + RV32_REGISTER_NUM_LIMBS as u32, )); match PairingCurve::from_repr(c_upper as usize) { @@ -255,22 +245,17 @@ pub(crate) mod phantom { } fn read_fp( - memory: &MemoryController, + memory: &GuestMemory, ptr: u32, ) -> eyre::Result where Fp::Repr: From<[u8; N]>, { - let mut repr = [0u8; N]; - for (i, byte) in repr.iter_mut().enumerate() { - *byte = memory - .unsafe_read_cell( - F::from_canonical_u32(RV32_MEMORY_AS), - F::from_canonical_u32(ptr + i as u32), - ) - .as_canonical_u32() - .try_into()?; - } + let repr: [u8; N] = memory + .memory + .read_range_generic((RV32_MEMORY_AS, ptr), N) + .try_into() + .unwrap(); Fp::from_repr(repr.into()) .into_option() .ok_or(eyre::eyre!("bad ff::PrimeField repr")) diff --git a/extensions/pairing/transpiler/src/lib.rs b/extensions/pairing/transpiler/src/lib.rs index 7777c37c91..e80deaf154 100644 --- a/extensions/pairing/transpiler/src/lib.rs +++ b/extensions/pairing/transpiler/src/lib.rs @@ -1,71 +1,11 @@ use openvm_instructions::{ - instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant, + instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, PhantomDiscriminant, }; -use openvm_instructions_derive::LocalOpcode; use openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3}; use openvm_stark_backend::p3_field::PrimeField32; use openvm_transpiler::{TranspilerExtension, TranspilerOutput}; use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -// NOTE: the following opcodes are enabled only in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x750] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum PairingOpcode { - MILLER_DOUBLE_AND_ADD_STEP, - MILLER_DOUBLE_STEP, - EVALUATE_LINE, - MUL_013_BY_013, - MUL_023_BY_023, - MUL_BY_01234, - MUL_BY_02345, -} - -// NOTE: Fp12 opcodes are only enabled in testing and not enabled in the VM Extension -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x700] -#[repr(usize)] -#[allow(non_camel_case_types)] -pub enum Fp12Opcode { - ADD, - SUB, - MUL, -} -const FP12_OPS: usize = 4; - -pub struct Bn254Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bn254Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() - } -} - -pub struct Bls12381Fp12Opcode(Fp12Opcode); - -impl LocalOpcode for Bls12381Fp12Opcode { - const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET + FP12_OPS; - - fn from_usize(value: usize) -> Self { - Self(Fp12Opcode::from_usize(value - FP12_OPS)) - } - - fn local_usize(&self) -> usize { - self.0.local_usize() + FP12_OPS - } -} +use strum::FromRepr; #[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)] #[repr(u16)] diff --git a/extensions/rv32-adapters/src/eq_mod.rs b/extensions/rv32-adapters/src/eq_mod.rs index ab80481f19..e8e2c88f47 100644 --- a/extensions/rv32-adapters/src/eq_mod.rs +++ b/extensions/rv32-adapters/src/eq_mod.rs @@ -1,22 +1,19 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -29,16 +26,14 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + memory_read_from_state, memory_write_from_state, new_read_rv32_register_from_state, + tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; -use serde_with::serde_as; /// This adapter reads from NUM_READS <= 2 pointers and writes to a register. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,7 +42,7 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes are to 32-bit register rd. #[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32IsEqualModAdapterCols< T, const NUM_READS: usize, @@ -227,209 +222,216 @@ impl< } } -pub struct Rv32IsEqualModAdapterChip< - F: Field, +pub struct Rv32IsEqualModeAdapterStep< const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, > { - pub air: Rv32IsEqualModAdapterAir, + pointer_max_bits: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } impl< - F: PrimeField32, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > Rv32IsEqualModAdapterChip + > Rv32IsEqualModeAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); Self { - air: Rv32IsEqualModAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, } } } -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModReadRecord< - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCK_SIZE: usize, -> { - #[serde(with = "BigArray")] - pub rs: [RecordId; NUM_READS], - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32IsEqualModWriteRecord { - pub from_state: ExecutionState, - pub rd_id: RecordId, -} - impl< F: PrimeField32, + CTX, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCK_SIZE: usize, const TOTAL_READ_SIZE: usize, - > VmAdapterChip - for Rv32IsEqualModAdapterChip + > AdapterTraceStep + for Rv32IsEqualModeAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32IsEqualModReadRecord; - type WriteRecord = Rv32IsEqualModWriteRecord; - type Air = Rv32IsEqualModAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - NUM_READS, - 1, - TOTAL_READ_SIZE, - RV32_REGISTER_NUM_LIMBS, - >; + const WIDTH: usize = + Rv32IsEqualModAdapterCols::::width(); + type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); + cols.from_state.pc = F::from_canonical_u32(pc); + cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + adapter_row: &mut [F], + ) -> Self::ReadData { let Instruction { b, c, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + let e = e.as_canonical_u32(); + let d = d.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { + // Read register values + let rs_vals: [_; NUM_READS] = from_fn(|i| { let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + cols.rs_ptr[i] = addr; + let rs_val = tracing_read(memory, d, addr.as_canonical_u32(), &mut cols.rs_read_aux[i]); + cols.rs_val[i] = rs_val.map(F::from_canonical_u8); + u32::from_le_bytes(rs_val) }); - let read_records = rs_vals.map(|address| { - debug_assert!(address < (1 << self.air.address_bits)); - from_fn(|i| { - memory - .read::(e, F::from_canonical_u32(address + (i * BLOCK_SIZE) as u32)) + // Read memory values + from_fn(|i| { + assert!(rs_vals[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)); + from_fn::<_, BLOCKS_PER_READ, _>(|j| { + tracing_read::<_, BLOCK_SIZE>( + memory, + e, + rs_vals[i] + (j * BLOCK_SIZE) as u32, + &mut cols.heap_read_aux[i][j], + ) }) - }); - - let read_data = read_records.map(|r| { - let read = r.map(|x| x.1); - let mut read_it = read.iter().flatten(); - from_fn(|_| *(read_it.next().unwrap())) - }); - let record = Rv32IsEqualModReadRecord { - rs: rs_records, - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) + .concat() + .try_into() + .unwrap() + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { + adapter_row: &mut [F], + data: &Self::WriteData, + ) { let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - debug_assert!( - memory.timestamp() - from_state.timestamp - == (NUM_READS * (BLOCKS_PER_READ + 1) + 1) as u32, - "timestamp delta is {}, expected {}", - memory.timestamp() - from_state.timestamp, - NUM_READS * (BLOCKS_PER_READ + 1) + 1 + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); + cols.rd_ptr = a; + tracing_write( + memory, + d.as_canonical_u32(), + a.as_canonical_u32(), + data, + &mut cols.writes_aux, ); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32IsEqualModAdapterCols = - row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rs = read_record.rs.map(|r| memory.record_by_id(r)); - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - for (j, x) in read_record.reads[i].iter().enumerate() { - let read = memory.record_by_id(*x); - aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]); - } - } + let cols: &mut Rv32IsEqualModAdapterCols = + adapter_row.borrow_mut(); + let mut timestamp = cols.from_state.timestamp.as_canonical_u32(); + let mut timestamp_pp = || { + timestamp += 1; + timestamp - 1 + }; - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); + cols.rs_read_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); - // Range checks - let need_range_check: [u32; 2] = from_fn(|i| { - if i < NUM_READS { - rs[i] - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - } else { - 0 - } + cols.heap_read_aux.iter_mut().for_each(|reads| { + reads + .iter_mut() + .for_each(|aux| mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut())); }); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; + + mem_helper.fill_from_prev(timestamp_pp(), cols.writes_aux.as_mut()); + + // Range checks: + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; self.bitwise_lookup_chip.request_range( - need_range_check[0] << limb_shift_bits, - need_range_check[1] << limb_shift_bits, + cols.rs_val[0][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + if NUM_READS > 1 { + cols.rs_val[1][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits + } else { + 0 + }, ); } +} - fn air(&self) -> &Self::Air { - &self.air +impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCK_SIZE: usize, + const TOTAL_READ_SIZE: usize, + > AdapterExecutorE1 + for Rv32IsEqualModeAdapterStep +{ + type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, e, .. } = *instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + // Read register values + let rs_vals = from_fn(|i| { + let addr = if i == 0 { b } else { c }; + new_read_rv32_register_from_state(state, d, addr.as_canonical_u32()) + }); + + // Read memory values + rs_vals.map(|address| { + assert!(address as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)); + memory_read_from_state(state, e, address) + }) + } + + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = *instruction; + memory_write_from_state(state, d.as_canonical_u32(), a.as_canonical_u32(), data); } } diff --git a/extensions/rv32-adapters/src/heap.rs b/extensions/rv32-adapters/src/heap.rs index cd9f93abbc..d596092e39 100644 --- a/extensions/rv32-adapters/src/heap.rs +++ b/extensions/rv32-adapters/src/heap.rs @@ -1,18 +1,14 @@ -use std::{ - array::{self, from_fn}, - borrow::Borrow, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, MinimalInstruction, VmAdapterAir, VmStateMut, }, - system::{ - memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory}, - program::ProgramBus, + system::memory::{ + offline_checker::MemoryBridge, + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -20,20 +16,15 @@ use openvm_circuit_primitives::bitwise_op_lookup::{ }; use openvm_instructions::{ instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, }; -use openvm_rv32im_circuit::adapters::read_rv32_register; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, PrimeField32}, }; -use super::{ - vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, - Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord, -}; +use crate::{Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols, Rv32VecHeapAdapterStep}; /// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -101,137 +92,106 @@ impl< } } -pub struct Rv32HeapAdapterChip< - F: Field, +pub struct Rv32HeapAdapterStep< const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, -> { - pub air: Rv32HeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} +>(Rv32VecHeapAdapterStep); -impl - Rv32HeapAdapterChip +impl + Rv32HeapAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); - Self { - air: Rv32HeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + Rv32HeapAdapterStep(Rv32VecHeapAdapterStep::new( + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, - } + )) } } -impl - VmAdapterChip for Rv32HeapAdapterChip +impl< + F: PrimeField32, + CTX, + const NUM_READS: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceStep for Rv32HeapAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord<1, WRITE_SIZE>; - type Air = Rv32HeapAdapterAir; - type Interface = - BasicAdapterInterface, NUM_READS, 1, READ_SIZE, WRITE_SIZE>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = + Rv32VecHeapAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; + + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_cols: &mut Rv32VecHeapAdapterCols = + adapter_row.borrow_mut(); + adapter_cols.from_state.pc = F::from_canonical_u32(pc); + adapter_cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; - - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); - - let read_records = rs_vals.map(|address| { - debug_assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - [memory.read::(e, F::from_canonical_u32(address))] - }); - let read_data = read_records.map(|r| r[0].1); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| array::from_fn(|i| r[i].0)), - }; - - Ok((read_data, record)) + adapter_row: &mut [F], + ) -> Self::ReadData { + let read_data = AdapterTraceStep::::read(&self.0, memory, instruction, adapter_row); + read_data.map(|r| r[0]) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let writes = [memory.write(e, read_record.rd_val, output.writes[0]).0]; - - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 6, - "timestamp delta is {}, expected 6", - timestamp_delta - ); + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + AdapterTraceStep::::write(&self.0, memory, instruction, adapter_row, data); + } - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, ctx: (), adapter_row: &mut [F]) { + AdapterTraceStep::::fill_trace_row(&self.0, mem_helper, ctx, adapter_row); } +} + +impl + AdapterExecutorE1 for Rv32HeapAdapterStep +{ + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; 1]; - fn generate_trace_row( + fn read( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ); + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let read_data = AdapterExecutorE1::::read(&self.0, state, instruction); + read_data.map(|r| r[0]) } - fn air(&self) -> &Self::Air { - &self.air + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + AdapterExecutorE1::::write(&self.0, state, instruction, data); } } diff --git a/extensions/rv32-adapters/src/heap_branch.rs b/extensions/rv32-adapters/src/heap_branch.rs index 29c9a151c9..8997b78f2f 100644 --- a/extensions/rv32-adapters/src/heap_branch.rs +++ b/extensions/rv32-adapters/src/heap_branch.rs @@ -1,23 +1,19 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, - iter::once, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -30,15 +26,14 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + memory_read_from_state, new_read_rv32_register_from_state, tracing_read, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; /// This adapter reads from NUM_READS <= 2 pointers. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -170,158 +165,169 @@ impl VmA } } -pub struct Rv32HeapBranchAdapterChip { - pub air: Rv32HeapBranchAdapterAir, +pub struct Rv32HeapBranchAdapterStep { + pub pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, } -impl - Rv32HeapBranchAdapterChip +impl + Rv32HeapBranchAdapterStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, + pointer_max_bits: usize, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, ) -> Self { assert!(NUM_READS <= 2); assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" ); Self { - air: Rv32HeapBranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, + pointer_max_bits, bitwise_lookup_chip, - _marker: PhantomData, } } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32HeapBranchReadRecord { - #[serde(with = "BigArray")] - pub rs_reads: [RecordId; NUM_READS], - #[serde(with = "BigArray")] - pub heap_reads: [RecordId; NUM_READS], -} - -impl VmAdapterChip - for Rv32HeapBranchAdapterChip +impl AdapterTraceStep + for Rv32HeapBranchAdapterStep +where + F: PrimeField32, { - type ReadRecord = Rv32HeapBranchReadRecord; - type WriteRecord = ExecutionState; - type Air = Rv32HeapBranchAdapterAir; - type Interface = BasicAdapterInterface, NUM_READS, 0, READ_SIZE, 0>; - - fn preprocess( - &mut self, - memory: &mut MemoryController, + const WIDTH: usize = Rv32HeapBranchAdapterCols::::width(); + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let cols: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = + adapter_row.borrow_mut(); + cols.from_state.pc = F::from_canonical_u32(pc); + cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + adapter_row: &mut [F], + ) -> Self::ReadData { let Instruction { a, b, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { - let addr = if i == 0 { a } else { b }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record - }); + let cols: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = + adapter_row.borrow_mut(); - let heap_records = rs_vals.map(|address| { - assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits)); - memory.read::(e, F::from_canonical_u32(address)) + // Read register values + let rs_vals: [_; NUM_READS] = from_fn(|i| { + let addr = if i == 0 { a } else { b }; + cols.rs_ptr[i] = addr; + let rs_val = tracing_read(memory, d, addr.as_canonical_u32(), &mut cols.rs_read_aux[i]); + cols.rs_val[i] = rs_val.map(F::from_canonical_u8); + u32::from_le_bytes(rs_val) }); - let record = Rv32HeapBranchReadRecord { - rs_reads: rs_records, - heap_reads: heap_records.map(|r| r.0), - }; - Ok((heap_records.map(|r| r.1), record)) + // Read memory values + from_fn(|i| { + assert!(rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)); + tracing_read(memory, e, rs_vals[i], &mut cols.heap_read_aux[i]) + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + _memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 4, - "timestamp delta is {}, expected 4", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - from_state, - )) + _adapter_row: &mut [F], + _data: &Self::WriteData, + ) { + // This function intentionally does nothing } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = - row_slice.borrow_mut(); - row_slice.from_state = write_record.map(F::from_canonical_u32); + let cols: &mut Rv32HeapBranchAdapterCols = + adapter_row.borrow_mut(); - let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r)); + let mut timestamp = cols.from_state.timestamp.as_canonical_u32(); + let mut timestamp_pp = || { + timestamp += 1; + timestamp - 1 + }; - for (i, rs_read) in rs_reads.iter().enumerate() { - row_slice.rs_ptr[i] = rs_read.pointer; - row_slice.rs_val[i].copy_from_slice(rs_read.data_slice()); - aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]); - } + cols.rs_read_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); - for (i, heap_read) in read_record.heap_reads.iter().enumerate() { - let record = memory.record_by_id(*heap_read); - aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]); - } + cols.heap_read_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); // Range checks: - let need_range_check: Vec = rs_reads - .iter() - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }) - .chain(once(0)) // in case NUM_READS is odd - .collect(); - debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits; - for pair in need_range_check.chunks_exact(2) { - self.bitwise_lookup_chip - .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); - } + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + self.bitwise_lookup_chip.request_range( + cols.rs_val[0][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + if NUM_READS > 1 { + cols.rs_val[1][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits + } else { + 0 + }, + ); } +} + +impl AdapterExecutorE1 + for Rv32HeapBranchAdapterStep +{ + type ReadData = [[u8; READ_SIZE]; NUM_READS]; + type WriteData = (); - fn air(&self) -> &Self::Air { - &self.air + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, b, d, e, .. } = *instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + // Read register values + let rs_vals = from_fn(|i| { + let addr = if i == 0 { a } else { b }; + new_read_rv32_register_from_state(state, d, addr.as_canonical_u32()) + }); + + // Read memory values + rs_vals.map(|address| { + assert!(address as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)); + memory_read_from_state(state, e, address) + }) + } + + fn write( + &self, + _state: &mut VmStateMut, + _instruction: &Instruction, + _data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + // This function intentionally does nothing } } diff --git a/extensions/rv32-adapters/src/vec_heap.rs b/extensions/rv32-adapters/src/vec_heap.rs index fab0df3334..924cff2f3c 100644 --- a/extensions/rv32-adapters/src/vec_heap.rs +++ b/extensions/rv32-adapters/src/vec_heap.rs @@ -2,21 +2,18 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::{once, zip}, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -29,15 +26,15 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + abstract_compose, memory_read_from_state, memory_write_from_state, + new_read_rv32_register_from_state, tracing_read, tracing_write, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from R (R <= 2) pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -46,89 +43,8 @@ use serde_with::serde_as; /// starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`). /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -#[derive(Clone)] -pub struct Rv32VecHeapAdapterChip< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: - Rv32VecHeapAdapterAir, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapAdapterChip -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!(NUM_READS <= 2); - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapReadRecord< - F: Field, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - #[serde_as(as = "[_; NUM_READS]")] - pub rs: [RecordId; NUM_READS], - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")] - pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS], -} - #[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Rv32VecHeapWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - -#[repr(C)] -#[derive(AlignedBorrow)] +#[derive(AlignedBorrow, Debug)] pub struct Rv32VecHeapAdapterCols< T, const NUM_READS: usize, @@ -346,201 +262,258 @@ impl< } } +#[derive(derive_new::new)] +pub struct Rv32VecHeapAdapterStep< + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< F: PrimeField32, + CTX, const NUM_READS: usize, const BLOCKS_PER_READ: usize, const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapAdapterChip< - F, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > + > AdapterTraceStep + for Rv32VecHeapAdapterStep { - type ReadRecord = Rv32VecHeapReadRecord; - type WriteRecord = Rv32VecHeapWriteRecord; - type Air = - Rv32VecHeapAdapterAir; - type Interface = VecHeapAdapterInterface< + const WIDTH: usize = Rv32VecHeapAdapterCols::< F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - >; + >::width(); + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type TraceContext<'a> = (); + + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + adapter_cols.from_state.pc = F::from_canonical_u32(pc); + adapter_cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - fn preprocess( - &mut self, - memory: &mut MemoryController, + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { + adapter_row: &mut [F], + ) -> Self::ReadData { let Instruction { a, b, c, d, e, .. } = *instruction; - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + let e = e.as_canonical_u32(); + let d = d.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); // Read register values - let mut rs_vals = [0; NUM_READS]; - let rs_records: [_; NUM_READS] = from_fn(|i| { + let rs_vals: [_; NUM_READS] = from_fn(|i| { let addr = if i == 0 { b } else { c }; - let (record, val) = read_rv32_register(memory, d, addr); - rs_vals[i] = val; - record + cols.rs_ptr[i] = addr; + let rs_val = tracing_read(memory, d, addr.as_canonical_u32(), &mut cols.rs_read_aux[i]); + cols.rs_val[i] = rs_val.map(F::from_canonical_u8); + u32::from_le_bytes(rs_val) }); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); + + cols.rd_ptr = a; + let rd_val = tracing_read(memory, d, a.as_canonical_u32(), &mut cols.rd_read_aux); + cols.rd_val = rd_val.map(F::from_canonical_u8); // Read memory values - let read_records = rs_vals.map(|address| { + from_fn(|i| { assert!( - address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits) + rs_vals[i] as usize + READ_SIZE * BLOCKS_PER_READ - 1 + < (1 << self.pointer_max_bits) ); - from_fn(|i| { - memory.read::(e, F::from_canonical_u32(address + (i * READ_SIZE) as u32)) + from_fn(|j| { + tracing_read( + memory, + e, + rs_vals[i] + (j * READ_SIZE) as u32, + &mut cols.reads_aux[i][j], + ) }) - }); - let read_data = read_records.map(|r| r.map(|x| x.1)); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapReadRecord { - rs: rs_records, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads: read_records.map(|r| r.map(|x| x.0)), - }; - - Ok((read_data, record)) + }) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let e = instruction.e.as_canonical_u32(); + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + let rd_val = u32::from_le_bytes(cols.rd_val.map(|x| x.as_canonical_u32() as u8)); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for (i, block) in data.iter().enumerate() { + tracing_write( + memory, e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, + rd_val + (i * WRITE_SIZE) as u32, + block, + &mut cols.writes_aux[i], ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + } } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - vec_heap_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + let cols: &mut Rv32VecHeapAdapterCols< + F, + NUM_READS, + BLOCKS_PER_READ, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); - fn air(&self) -> &Self::Air { - &self.air - } -} + let mut timestamp = cols.from_state.timestamp.as_canonical_u32(); + let mut timestamp_pp = || { + timestamp += 1; + timestamp - 1 + }; -pub(super) fn vec_heap_generate_trace_row_impl< - F: PrimeField32, - const NUM_READS: usize, - const BLOCKS_PER_READ: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapReadRecord, - write_record: &Rv32VecHeapWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapAdapterCols< - F, - NUM_READS, - BLOCKS_PER_READ, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs = read_record - .rs - .into_iter() - .map(|r| memory.record_by_id(r)) - .collect::>(); - - row_slice.rd_ptr = rd.pointer; - row_slice.rd_val.copy_from_slice(rd.data_slice()); - - for (i, r) in rs.iter().enumerate() { - row_slice.rs_ptr[i] = r.pointer; - row_slice.rs_val[i].copy_from_slice(r.data_slice()); - aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]); - } + cols.rs_read_aux + .iter_mut() + .for_each(|aux| mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut())); + mem_helper.fill_from_prev(timestamp_pp(), cols.rd_read_aux.as_mut()); - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); + cols.reads_aux.iter_mut().for_each(|reads| { + reads + .iter_mut() + .for_each(|aux| mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut())); + }); + + cols.writes_aux.iter_mut().for_each(|write| { + mem_helper.fill_from_prev(timestamp_pp(), write.as_mut()); + }); - for (i, reads) in read_record.reads.iter().enumerate() { - for (j, &x) in reads.iter().enumerate() { - let record = memory.record_by_id(x); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]); + // Range checks: + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + if NUM_READS > 1 { + self.bitwise_lookup_chip.request_range( + cols.rs_val[0][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + cols.rs_val[1][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + cols.rd_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + cols.rd_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + ); + } else { + self.bitwise_lookup_chip.request_range( + cols.rs_val[0][RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + cols.rd_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + ); } } +} - for (i, &w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } +impl< + F: PrimeField32, + const NUM_READS: usize, + const BLOCKS_PER_READ: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterExecutorE1 + for Rv32VecHeapAdapterStep +{ + type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS]; + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; - // Range checks: - let need_range_check: Vec = rs - .iter() - .chain(std::iter::repeat_n(&rd, 2)) - .map(|record| { - record - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, e, .. } = *instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + // Read register values + let rs_vals = from_fn(|i| { + let addr = if i == 0 { b } else { c }; + new_read_rv32_register_from_state(state, d, addr.as_canonical_u32()) + }); + + // Read memory values + rs_vals.map(|address| { + assert!( + address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.pointer_max_bits) + ); + from_fn(|i| memory_read_from_state(state, e, address + (i * READ_SIZE) as u32)) }) - .collect(); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + } + + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, e, .. } = *instruction; + let rd_val = + new_read_rv32_register_from_state(state, d.as_canonical_u32(), a.as_canonical_u32()); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for (i, block) in data.iter().enumerate() { + memory_write_from_state( + state, + e.as_canonical_u32(), + rd_val + (i * WRITE_SIZE) as u32, + block, + ); + } } } diff --git a/extensions/rv32-adapters/src/vec_heap_two_reads.rs b/extensions/rv32-adapters/src/vec_heap_two_reads.rs index f829db8bbc..2929ece00b 100644 --- a/extensions/rv32-adapters/src/vec_heap_two_reads.rs +++ b/extensions/rv32-adapters/src/vec_heap_two_reads.rs @@ -2,21 +2,18 @@ use std::{ array::from_fn, borrow::{Borrow, BorrowMut}, iter::zip, - marker::PhantomData, }; use itertools::izip; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + ExecutionBridge, ExecutionState, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ @@ -29,15 +26,15 @@ use openvm_instructions::{ riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, }; use openvm_rv32im_circuit::adapters::{ - abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + abstract_compose, memory_read_from_state, memory_write_from_state, + new_read_rv32_register_from_state, tracing_read, tracing_write, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, }; use openvm_stark_backend::{ interaction::InteractionBuilder, p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use serde_with::serde_as; /// This adapter reads from 2 pointers and writes to 1 pointer. /// * The data is read from the heap (address space 2), and the pointers are read from registers @@ -47,99 +44,6 @@ use serde_with::serde_as; /// * NOTE that the two reads can read different numbers of blocks. /// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the /// heap, starting from the address in `rd`. -pub struct Rv32VecHeapTwoReadsAdapterChip< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, -> { - pub air: Rv32VecHeapTwoReadsAdapterAir< - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - >, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, - > - Rv32VecHeapTwoReadsAdapterChip< - F, - BLOCKS_PER_READ1, - BLOCKS_PER_READ2, - BLOCKS_PER_WRITE, - READ_SIZE, - WRITE_SIZE, - > -{ - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - assert!( - RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS, - "address_bits={address_bits} needs to be large enough for high limb range check" - ); - Self { - air: Rv32VecHeapTwoReadsAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bus: bitwise_lookup_chip.bus(), - address_bits, - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32VecHeapTwoReadsReadRecord< - F: Field, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const READ_SIZE: usize, -> { - /// Read register value from address space e=1 - pub rs1: RecordId, - pub rs2: RecordId, - /// Read register value from address space d=1 - pub rd: RecordId, - - pub rd_val: F, - - #[serde_as(as = "[_; BLOCKS_PER_READ1]")] - pub reads1: [RecordId; BLOCKS_PER_READ1], - #[serde_as(as = "[_; BLOCKS_PER_READ2]")] - pub reads2: [RecordId; BLOCKS_PER_READ2], -} - -#[repr(C)] -#[serde_as] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct Rv32VecHeapTwoReadsWriteRecord { - pub from_state: ExecutionState, - #[serde_as(as = "[_; BLOCKS_PER_WRITE]")] - pub writes: [RecordId; BLOCKS_PER_WRITE], -} - #[repr(C)] #[derive(AlignedBorrow)] pub struct Rv32VecHeapTwoReadsAdapterCols< @@ -372,16 +276,25 @@ impl< } } +pub struct Rv32VecHeapTwoReadsAdapterStep< + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, +> { + pointer_max_bits: usize, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} + impl< - F: PrimeField32, const BLOCKS_PER_READ1: usize, const BLOCKS_PER_READ2: usize, const BLOCKS_PER_WRITE: usize, const READ_SIZE: usize, const WRITE_SIZE: usize, - > VmAdapterChip - for Rv32VecHeapTwoReadsAdapterChip< - F, + > + Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, @@ -389,189 +302,277 @@ impl< WRITE_SIZE, > { - type ReadRecord = - Rv32VecHeapTwoReadsReadRecord; - type WriteRecord = Rv32VecHeapTwoReadsWriteRecord; - type Air = Rv32VecHeapTwoReadsAdapterAir< + pub fn new( + pointer_max_bits: usize, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + ) -> Self { + assert!( + RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS, + "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check" + ); + Self { + pointer_max_bits, + bitwise_lookup_chip, + } + } +} + +impl< + F: PrimeField32, + CTX, + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterTraceStep + for Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - >; - type Interface = VecHeapTwoReadsAdapterInterface< + > +{ + const WIDTH: usize = Rv32VecHeapTwoReadsAdapterCols::< F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - >; + >::width(); - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, c, d, e, .. } = *instruction; + type ReadData = ( + [[u8; READ_SIZE]; BLOCKS_PER_READ1], + [[u8; READ_SIZE]; BLOCKS_PER_READ2], + ); + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; + type TraceContext<'a> = (); - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_cols: &mut Rv32VecHeapTwoReadsAdapterCols< + F, + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + adapter_cols.from_state.pc = F::from_canonical_u32(pc); + adapter_cols.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - let (rs1_record, rs1_val) = read_rv32_register(memory, d, b); - let (rs2_record, rs2_val) = read_rv32_register(memory, d, c); - let (rd_record, rd_val) = read_rv32_register(memory, d, a); + fn read( + &self, + memory: &mut TracingMemory, + instruction: &Instruction, + adapter_row: &mut [F], + ) -> Self::ReadData { + let Instruction { a, b, c, d, e, .. } = *instruction; - assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.air.address_bits)); - let read1_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs1_val + (i * READ_SIZE) as u32)) - }); - let read1_data = read1_records.map(|r| r.1); - assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.air.address_bits)); - let read2_records = from_fn(|i| { - memory.read::(e, F::from_canonical_u32(rs2_val + (i * READ_SIZE) as u32)) - }); - let read2_data = read2_records.map(|r| r.1); - assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits)); - - let record = Rv32VecHeapTwoReadsReadRecord { - rs1: rs1_record, - rs2: rs2_record, - rd: rd_record, - rd_val: F::from_canonical_u32(rd_val), - reads1: read1_records.map(|r| r.0), - reads2: read2_records.map(|r| r.0), - }; + let e = e.as_canonical_u32(); + let d = d.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - Ok(((read1_data, read2_data), record)) + let cols: &mut Rv32VecHeapTwoReadsAdapterCols< + F, + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + // Read register values + cols.rs1_ptr = b; + let rs1_val = tracing_read(memory, d, b.as_canonical_u32(), &mut cols.rs1_read_aux); + cols.rs1_val = rs1_val.map(F::from_canonical_u8); + let rs1_val = u32::from_le_bytes(rs1_val); + cols.rs2_ptr = c; + let rs2_val = tracing_read(memory, d, c.as_canonical_u32(), &mut cols.rs2_read_aux); + cols.rs2_val = rs2_val.map(F::from_canonical_u8); + let rs2_val = u32::from_le_bytes(rs2_val); + + cols.rd_ptr = a; + let rd_val = tracing_read(memory, d, a.as_canonical_u32(), &mut cols.rd_read_aux); + cols.rd_val = rd_val.map(F::from_canonical_u8); + assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.pointer_max_bits)); + assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.pointer_max_bits)); + + ( + from_fn(|i| { + tracing_read( + memory, + e, + rs1_val + (i * READ_SIZE) as u32, + &mut cols.reads1_aux[i], + ) + }), + from_fn(|i| { + tracing_read( + memory, + e, + rs2_val + (i * READ_SIZE) as u32, + &mut cols.reads2_aux[i], + ) + }), + ) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let e = instruction.e; - let mut i = 0; - let writes = output.writes.map(|write| { - let (record_id, _) = memory.write( + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let e = instruction.e.as_canonical_u32(); + let cols: &mut Rv32VecHeapTwoReadsAdapterCols< + F, + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + let rd_val = u32::from_le_bytes(cols.rd_val.map(|x| x.as_canonical_u32() as u8)); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for (i, block) in data.iter().enumerate() { + tracing_write( + memory, e, - read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32), - write, + rd_val + (i * WRITE_SIZE) as u32, + block, + &mut cols.writes_aux[i], ); - i += 1; - record_id - }); - - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, writes }, - )) + } } - fn generate_trace_row( + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &openvm_circuit::system::memory::MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - vec_heap_two_reads_generate_trace_row_impl( - row_slice, - &read_record, - &write_record, - self.bitwise_lookup_chip.clone(), - self.air.address_bits, - memory, - ) - } + let cols: &mut Rv32VecHeapTwoReadsAdapterCols< + F, + BLOCKS_PER_READ1, + BLOCKS_PER_READ2, + BLOCKS_PER_WRITE, + READ_SIZE, + WRITE_SIZE, + > = adapter_row.borrow_mut(); + + let mut timestamp = cols.from_state.timestamp.as_canonical_u32(); + let mut timestamp_pp = || { + timestamp += 1; + timestamp - 1 + }; + + mem_helper.fill_from_prev(timestamp_pp(), cols.rs1_read_aux.as_mut()); + mem_helper.fill_from_prev(timestamp_pp(), cols.rs2_read_aux.as_mut()); + mem_helper.fill_from_prev(timestamp_pp(), cols.rd_read_aux.as_mut()); + cols.reads1_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); + cols.reads2_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); + cols.writes_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp_pp(), aux.as_mut()); + }); + + debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - fn air(&self) -> &Self::Air { - &self.air + let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits; + self.bitwise_lookup_chip.request_range( + cols.rs1_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + cols.rs2_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + ); + self.bitwise_lookup_chip.request_range( + cols.rd_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + cols.rd_val[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() << limb_shift_bits, + ); } } -pub(super) fn vec_heap_two_reads_generate_trace_row_impl< - F: PrimeField32, - const BLOCKS_PER_READ1: usize, - const BLOCKS_PER_READ2: usize, - const BLOCKS_PER_WRITE: usize, - const READ_SIZE: usize, - const WRITE_SIZE: usize, ->( - row_slice: &mut [F], - read_record: &Rv32VecHeapTwoReadsReadRecord, - write_record: &Rv32VecHeapTwoReadsWriteRecord, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - address_bits: usize, - memory: &OfflineMemory, -) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32VecHeapTwoReadsAdapterCols< - F, +impl< + F: PrimeField32, + const BLOCKS_PER_READ1: usize, + const BLOCKS_PER_READ2: usize, + const BLOCKS_PER_WRITE: usize, + const READ_SIZE: usize, + const WRITE_SIZE: usize, + > AdapterExecutorE1 + for Rv32VecHeapTwoReadsAdapterStep< BLOCKS_PER_READ1, BLOCKS_PER_READ2, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE, - > = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - - let rd = memory.record_by_id(read_record.rd); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - - row_slice.rd_ptr = rd.pointer; - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - - row_slice.rd_val.copy_from_slice(rd.data_slice()); - row_slice.rs1_val.copy_from_slice(rs1.data_slice()); - row_slice.rs2_val.copy_from_slice(rs2.data_slice()); - - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux); - aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux); - - for (i, r) in read_record.reads1.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]); - } + > +{ + type ReadData = ( + [[u8; READ_SIZE]; BLOCKS_PER_READ1], + [[u8; READ_SIZE]; BLOCKS_PER_READ2], + ); + type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE]; - for (i, r) in read_record.reads2.iter().enumerate() { - let record = memory.record_by_id(*r); - aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]); + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, e, .. } = *instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + // Read register values + let rs1_val = new_read_rv32_register_from_state(state, d, b.as_canonical_u32()); + let rs2_val = new_read_rv32_register_from_state(state, d, c.as_canonical_u32()); + + assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.pointer_max_bits)); + assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.pointer_max_bits)); + // Read memory values + let read_data1 = + from_fn(|i| memory_read_from_state(state, e, rs1_val + (i * READ_SIZE) as u32)); + let read_data2 = + from_fn(|i| memory_read_from_state(state, e, rs2_val + (i * READ_SIZE) as u32)); + + (read_data1, read_data2) } - for (i, w) in write_record.writes.iter().enumerate() { - let record = memory.record_by_id(*w); - aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]); - } - // Range checks: - let need_range_check = [ - &read_record.rs1, - &read_record.rs2, - &read_record.rd, - &read_record.rd, - ] - .map(|record| { - memory - .record_by_id(*record) - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - }); - debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS); - let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits; - for pair in need_range_check.chunks_exact(2) { - bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits); + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, e, .. } = *instruction; + + let rd_val = + new_read_rv32_register_from_state(state, d.as_canonical_u32(), a.as_canonical_u32()); + assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.pointer_max_bits)); + + for (i, block) in data.iter().enumerate() { + memory_write_from_state( + state, + e.as_canonical_u32(), + rd_val + (i * WRITE_SIZE) as u32, + block, + ); + } } } diff --git a/extensions/rv32im/circuit/Cargo.toml b/extensions/rv32im/circuit/Cargo.toml index 8b20385104..0e28dd3093 100644 --- a/extensions/rv32im/circuit/Cargo.toml +++ b/extensions/rv32im/circuit/Cargo.toml @@ -21,6 +21,7 @@ derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true eyre.workspace = true + # for div_rem: num-bigint.workspace = true num-integer.workspace = true @@ -30,6 +31,7 @@ serde-big-array.workspace = true [dev-dependencies] openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils"] } +test-case.workspace = true [features] default = ["parallel", "jemalloc"] diff --git a/extensions/rv32im/circuit/src/adapters/alu.rs b/extensions/rv32im/circuit/src/adapters/alu.rs index b61e2a224a..271405fbab 100644 --- a/extensions/rv32im/circuit/src/adapters/alu.rs +++ b/extensions/rv32im/circuit/src/adapters/alu.rs @@ -1,20 +1,15 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ @@ -32,60 +27,11 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; - -/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. -/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c -/// is an immediate). -pub struct Rv32BaseAluAdapterChip { - pub air: Rv32BaseAluAdapterAir, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - _marker: PhantomData, -} - -impl Rv32BaseAluAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - ) -> Self { - Self { - air: Rv32BaseAluAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - }, - bitwise_lookup_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluReadRecord { - /// Read register value from address space d=1 - pub rs1: RecordId, - /// Either - /// - read rs2 register value or - /// - if `rs2_is_imm` is true, this is None - pub rs2: Option, - /// immediate value of rs2 or 0 - pub rs2_imm: F, -} -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32BaseAluWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd: (RecordId, [F; 4]), -} +use super::{ + tracing_read, tracing_read_imm, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; +use crate::adapters::{imm_to_bytes, memory_read_from_state, memory_write_from_state}; #[repr(C)] #[derive(AlignedBorrow)] @@ -101,7 +47,9 @@ pub struct Rv32BaseAluAdapterCols { pub writes_aux: MemoryWriteAuxCols, } -#[allow(dead_code)] +/// Reads instructions of the form OP a, b, c, d, e where \[a:4\]_d = \[b:4\]_d op \[c:4\]_e. +/// Operand d can only be 1, and e can be either 1 (for register reads) or 0 (when c +/// is an immediate). #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32BaseAluAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -213,129 +161,173 @@ impl VmAdapterAir for Rv32BaseAluAdapterAir { } } -impl VmAdapterChip for Rv32BaseAluAdapterChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; +#[derive(derive_new::new)] +pub struct Rv32BaseAluAdapterStep { + // TODO(arayi): use reference to bitwise lookup chip with lifetimes instead + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, +} - fn preprocess( - &mut self, - memory: &mut MemoryController, +impl AdapterTraceStep + for Rv32BaseAluAdapterStep +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, c, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert!( - e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS + e.as_canonical_u32() == RV32_REGISTER_AS || e.as_canonical_u32() == RV32_IMM_AS ); - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - debug_assert_eq!(c_u32 >> 24, 0); - memory.increment_timestamp(); - ( - None, - [ - c_u32 as u8, - (c_u32 >> 8) as u8, - (c_u32 >> 16) as u8, - (c_u32 >> 16) as u8, - ] - .map(F::from_canonical_u8), - c, + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + + adapter_row.rs1_ptr = b; + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut adapter_row.reads_aux[0], + ); + + let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS { + adapter_row.rs2_as = e; + adapter_row.rs2 = c; + + tracing_read( + memory, + RV32_REGISTER_AS, + c.as_canonical_u32(), + &mut adapter_row.reads_aux[1], ) } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) + adapter_row.rs2_as = e; + + tracing_read_imm(memory, c.as_canonical_u32(), &mut adapter_row.rs2) }; - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = instruction; - let rd = memory.write(*d, *a, output.writes[0]); + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, .. } = instruction; - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta - ); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd }, - )) + adapter_row.rd_ptr = a; + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &data[0], + &mut adapter_row.writes_aux, + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _ctx: (), + adapter_row: &mut [F], ) { - let row_slice: &mut Rv32BaseAluAdapterCols<_> = row_slice.borrow_mut(); - let aux_cols_factory = memory.aux_cols_factory(); - - let rd = memory.record_by_id(write_record.rd.0); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - row_slice.rd_ptr = rd.pointer; - - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = read_record.rs2.map(|rs2| memory.record_by_id(rs2)); - row_slice.rs1_ptr = rs1.pointer; - - if let Some(rs2) = rs2 { - row_slice.rs2 = rs2.pointer; - row_slice.rs2_as = rs2.address_space; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); + let adapter_row: &mut Rv32BaseAluAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); + timestamp += 1; + + if !adapter_row.rs2_as.is_zero() { + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[1].as_mut()); } else { - row_slice.rs2 = read_record.rs2_imm; - row_slice.rs2_as = F::ZERO; - let rs2_imm = row_slice.rs2.as_canonical_u32(); + let rs2_imm = adapter_row.rs2.as_canonical_u32(); let mask = (1 << RV32_CELL_BITS) - 1; self.bitwise_lookup_chip .request_range(rs2_imm & mask, (rs2_imm >> 8) & mask); - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - // row_slice.reads_aux[1] is disabled } - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.writes_aux.as_mut()); + } +} + +impl AdapterExecutorE1 for Rv32BaseAluAdapterStep +where + F: PrimeField32, +{ + // TODO(ayush): directly use u32 + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, e, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert!( + e.as_canonical_u32() == RV32_IMM_AS || e.as_canonical_u32() == RV32_REGISTER_AS + ); + + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, b.as_canonical_u32()); + + let rs2 = if e.as_canonical_u32() == RV32_REGISTER_AS { + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, c.as_canonical_u32()); + rs2 + } else { + imm_to_bytes(c.as_canonical_u32()) + }; + + [rs1, rs2] } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + rd: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + memory_write_from_state(state, d.as_canonical_u32(), a.as_canonical_u32(), &rd[0]); } } diff --git a/extensions/rv32im/circuit/src/adapters/branch.rs b/extensions/rv32im/circuit/src/adapters/branch.rs index 3e26f37f4c..7af331d7c5 100644 --- a/extensions/rv32im/circuit/src/adapters/branch.rs +++ b/extensions/rv32im/circuit/src/adapters/branch.rs @@ -1,20 +1,15 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -26,48 +21,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. -/// Operands d and e can only be 1. -#[derive(Debug)] -pub struct Rv32BranchAdapterChip { - pub air: Rv32BranchAdapterAir, - _marker: PhantomData, -} - -impl Rv32BranchAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32BranchAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchReadRecord { - /// Read register value from address space d = 1 - pub rs1: RecordId, - /// Read register value from address space e = 1 - pub rs2: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32BranchWriteRecord { - pub from_state: ExecutionState, -} +use crate::adapters::{memory_read_from_state, tracing_read}; #[repr(C)] #[derive(AlignedBorrow)] @@ -149,80 +105,125 @@ impl VmAdapterAir for Rv32BranchAdapterAir { } } -impl VmAdapterChip for Rv32BranchAdapterChip { - type ReadRecord = Rv32BranchReadRecord; - type WriteRecord = Rv32BranchWriteRecord; - type Air = Rv32BranchAdapterAir; - type Interface = BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>; +/// Reads instructions of the form OP a, b, c, d, e where if(\[a:4\]_d op \[b:4\]_e) pc += c. +/// Operands d and e can only be 1. +#[derive(derive_new::new)] +pub struct Rv32BranchAdapterStep; + +impl AdapterTraceStep for Rv32BranchAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = (); + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32BranchAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { a, b, d, e, .. } = *instruction; + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { a, b, d, e, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS); - let rs1 = memory.read::(d, a); - let rs2 = memory.read::(e, b); + let adapter_row: &mut Rv32BranchAdapterCols = adapter_row.borrow_mut(); + + adapter_row.rs1_ptr = a; + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut adapter_row.reads_aux[0], + ); + adapter_row.rs2_ptr = b; + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut adapter_row.reads_aux[1], + ); - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + _memory: &mut TracingMemory, _instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 2, - "timestamp delta is {}, expected 2", - timestamp_delta - ); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state }, - )) + _adapter_row: &mut [F], + _data: &Self::WriteData, + ) { } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32BranchAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); + let adapter_row: &mut Rv32BranchAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[1].as_mut()); + } +} + +impl AdapterExecutorE1 for Rv32BranchAdapterStep +where + F: PrimeField32, +{ + // TODO(ayush): directly use u32 + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = (); + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, b, d, e, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_REGISTER_AS); + + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32()); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, b.as_canonical_u32()); + + [rs1, rs2] } - fn air(&self) -> &Self::Air { - &self.air + #[inline(always)] + fn write( + &self, + _state: &mut VmStateMut, + _instruction: &Instruction, + _data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { } } diff --git a/extensions/rv32im/circuit/src/adapters/jalr.rs b/extensions/rv32im/circuit/src/adapters/jalr.rs index f7dbf623b8..29bc4552df 100644 --- a/extensions/rv32im/circuit/src/adapters/jalr.rs +++ b/extensions/rv32im/circuit/src/adapters/jalr.rs @@ -1,20 +1,15 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, Result, SignedImmInstruction, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, SignedImmInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::utils::not; @@ -27,44 +22,11 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) -#[derive(Debug)] -pub struct Rv32JalrAdapterChip { - pub air: Rv32JalrAdapterAir, - _marker: PhantomData, -} - -impl Rv32JalrAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32JalrAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrReadRecord { - pub rs1: RecordId, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32JalrWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::{ + memory_read_from_state, memory_write_from_state, tracing_read, tracing_write, +}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -179,84 +141,144 @@ impl VmAdapterAir for Rv32JalrAdapterAir { } } -impl VmAdapterChip for Rv32JalrAdapterChip { - type ReadRecord = Rv32JalrReadRecord; - type WriteRecord = Rv32JalrWriteRecord; - type Air = Rv32JalrAdapterAir; - type Interface = BasicAdapterInterface< - F, - SignedImmInstruction, - 1, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - fn preprocess( - &mut self, - memory: &mut MemoryController, +// This adapter reads from [b:4]_d (rs1) and writes to [a:4]_d (rd) +#[derive(derive_new::new)] +pub struct Rv32JalrAdapterStep; + +impl AdapterTraceStep for Rv32JalrAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [u8; RV32_REGISTER_NUM_LIMBS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, d, .. } = *instruction; + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, d, .. } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - let rs1 = memory.read::(d, b); + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); - Ok(([rs1.1], Rv32JalrReadRecord { rs1: rs1.0 })) + adapter_row.rs1_ptr = b; + tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut adapter_row.rs1_aux_cols, + ) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, f: enabled, .. - } = *instruction; - let rd_id = if enabled != F::ZERO { - let (record_id, _) = memory.write(d, a, output.writes[0]); - Some(record_id) + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + if enabled != F::ZERO { + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + + adapter_row.needs_write = F::ONE; + + adapter_row.rd_ptr = a; + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data, + &mut adapter_row.rd_aux_cols, + ); } else { memory.increment_timestamp(); - None - }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + } } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32JalrAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1); - adapter_cols.rs1_ptr = rs1.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - if let Some(id) = write_record.rd_id { - let rd = memory.record_by_id(id); - adapter_cols.rd_ptr = rd.pointer; - adapter_cols.needs_write = F::ONE; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.rs1_aux_cols.as_mut()); + timestamp += 1; + + if adapter_row.needs_write.is_one() { + mem_helper.fill_from_prev(timestamp, adapter_row.rd_aux_cols.as_mut()); } } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterExecutorE1 for Rv32JalrAdapterStep +where + F: PrimeField32, +{ + // TODO(ayush): directly use u32 + type ReadData = [u8; RV32_REGISTER_NUM_LIMBS]; + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, b.as_canonical_u32()); + + rs1 + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { + a, d, f: enabled, .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + if *enabled != F::ZERO { + memory_write_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32(), data); + } } } diff --git a/extensions/rv32im/circuit/src/adapters/loadstore.rs b/extensions/rv32im/circuit/src/adapters/loadstore.rs index b92680a0c7..9ca3135af2 100644 --- a/extensions/rv32im/circuit/src/adapters/loadstore.rs +++ b/extensions/rv32im/circuit/src/adapters/loadstore.rs @@ -6,17 +6,13 @@ use std::{ use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState, - Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface, VmStateMut, }, - system::{ - memory::{ - offline_checker::{ - MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols, - }, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::{ @@ -36,10 +32,12 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; -use super::{compose, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::RV32_CELL_BITS; +use super::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{ + memory_read, memory_read_from_state, memory_write_from_state, tracing_read, + tracing_write_with_base_aux, RV32_CELL_BITS, +}; /// LoadStore Adapter handles all memory and register operations, so it must be aware /// of the instruction type, specifically whether it is a load or store @@ -64,22 +62,6 @@ pub struct LoadStoreInstruction { pub store_shift_amount: T, } -/// The LoadStoreAdapter separates Runtime and Air AdapterInterfaces. -/// This is necessary because `prev_data` should be owned by the core chip and sent to the adapter, -/// and it must have an AB::Var type in AIR as to satisfy the memory_bridge interface. -/// This is achieved by having different types for reads and writes in Air AdapterInterface. -/// This method ensures that there are no modifications to the global interfaces. -/// -/// Here 2 reads represent read_data and prev_data, -/// The second element of the tuple in Reads is the shift amount needed to be passed to the core -/// chip Getting the intermediate pointer is completely internal to the adapter and shouldn't be a -/// part of the AdapterInterface -pub struct Rv32LoadStoreAdapterRuntimeInterface(PhantomData); -impl VmAdapterInterface for Rv32LoadStoreAdapterRuntimeInterface { - type Reads = ([[T; RV32_REGISTER_NUM_LIMBS]; 2], T); - type Writes = [[T; RV32_REGISTER_NUM_LIMBS]; 1]; - type ProcessedInstruction = (); -} pub struct Rv32LoadStoreAdapterAirInterface(PhantomData); /// Using AB::Var for prev_data and AB::Expr for read_data @@ -92,65 +74,6 @@ impl VmAdapterInterface for Rv32LoadStoreAdapt type ProcessedInstruction = LoadStoreInstruction; } -/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. -/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. -/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. -pub struct Rv32LoadStoreAdapterChip { - pub air: Rv32LoadStoreAdapterAir, - pub range_checker_chip: SharedVariableRangeCheckerChip, - _marker: PhantomData, -} - -impl Rv32LoadStoreAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - pointer_max_bits: usize, - range_checker_chip: SharedVariableRangeCheckerChip, - ) -> Self { - assert!(range_checker_chip.range_max_bits() >= 15); - Self { - air: Rv32LoadStoreAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - range_bus: range_checker_chip.bus(), - pointer_max_bits, - }, - range_checker_chip, - _marker: PhantomData, - } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreReadRecord { - pub rs1_record: RecordId, - /// This will be a read from a register in case of Stores and a read from RISC-V memory in case - /// of Loads. - pub read: RecordId, - pub rs1_ptr: F, - pub imm: F, - pub imm_sign: F, - pub mem_as: F, - pub mem_ptr_limbs: [u32; 2], - pub shift_amount: u32, -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32LoadStoreWriteRecord { - /// This will be a write to a register in case of Load and a write to RISC-V memory in case of - /// Stores. For better struct packing, `RecordId(usize::MAX)` is used to indicate that - /// there is no write. - pub write_id: RecordId, - pub from_state: ExecutionState, - pub rd_rs2_ptr: F, -} - #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32LoadStoreAdapterCols { @@ -366,22 +289,46 @@ impl VmAdapterAir for Rv32LoadStoreAdapterAir { } } -impl VmAdapterChip for Rv32LoadStoreAdapterChip { - type ReadRecord = Rv32LoadStoreReadRecord; - type WriteRecord = Rv32LoadStoreWriteRecord; - type Air = Rv32LoadStoreAdapterAir; - type Interface = Rv32LoadStoreAdapterRuntimeInterface; +/// This chip reads rs1 and gets a intermediate memory pointer address with rs1 + imm. +/// In case of Loads, reads from the shifted intermediate pointer and writes to rd. +/// In case of Stores, reads from rs2 and writes to the shifted intermediate pointer. +pub struct Rv32LoadStoreAdapterStep { + pointer_max_bits: usize, +} + +impl Rv32LoadStoreAdapterStep { + pub fn new(pointer_max_bits: usize) -> Self { + Self { pointer_max_bits } + } +} + +impl AdapterTraceStep for Rv32LoadStoreAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = ( + ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]), + u32, + ); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type TraceContext<'a> = &'a SharedVariableRangeCheckerChip; + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - #[allow(clippy::type_complexity)] - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { opcode, a, b, @@ -390,16 +337,26 @@ impl VmAdapterChip for Rv32LoadStoreAdapterChip { e, g, .. - } = *instruction; + } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let rs1_record = memory.read::(d, b); - let rs1_val = compose(rs1_record.1); + adapter_row.rs1_ptr = b; + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut adapter_row.rs1_aux_cols, + ); + + let rs1_val = u32::from_le_bytes(rs1); let imm = c.as_canonical_u32(); let imm_sign = g.as_canonical_u32(); let imm_extended = imm + imm_sign * 0xffff0000; @@ -407,137 +364,329 @@ impl VmAdapterChip for Rv32LoadStoreAdapterChip { let ptr_val = rs1_val.wrapping_add(imm_extended); let shift_amount = ptr_val % 4; assert!( - ptr_val < (1 << self.air.pointer_max_bits), + ptr_val < (1 << self.pointer_max_bits), "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", - self.air.pointer_max_bits + self.pointer_max_bits ); let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff)); let ptr_val = ptr_val - shift_amount; - let read_record = match local_opcode { - LOADW | LOADB | LOADH | LOADBU | LOADHU => { - memory.read::(e, F::from_canonical_u32(ptr_val)) - } - STOREW | STOREH | STOREB => memory.read::(d, a), + let read_data = match local_opcode { + LOADW | LOADB | LOADH | LOADBU | LOADHU => tracing_read( + memory, + e.as_canonical_u32(), + ptr_val, + &mut adapter_row.read_data_aux, + ), + STOREW | STOREH | STOREB => tracing_read( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &mut adapter_row.read_data_aux, + ), }; // We need to keep values of some cells to keep them unchanged when writing to those cells let prev_data = match local_opcode { - STOREW | STOREH | STOREB => array::from_fn(|i| { - memory.unsafe_read_cell(e, F::from_canonical_usize(ptr_val as usize + i)) - }), + STOREW | STOREH | STOREB => { + if e.as_canonical_u32() == 4 { + unsafe { + memory + .data() + .read::(4, ptr_val) + .map(|x| x.as_canonical_u32() as u8) + } + } else { + memory_read(memory.data(), e.as_canonical_u32(), ptr_val) + } + } LOADW | LOADB | LOADH | LOADBU | LOADHU => { - array::from_fn(|i| memory.unsafe_read_cell(d, a + F::from_canonical_usize(i))) + memory_read(memory.data(), d.as_canonical_u32(), a.as_canonical_u32()) } }; - Ok(( - ( - [prev_data, read_record.1], - F::from_canonical_u32(shift_amount), - ), - Self::ReadRecord { - rs1_record: rs1_record.0, - rs1_ptr: b, - read: read_record.0, - imm: c, - imm_sign: g, - shift_amount, - mem_ptr_limbs, - mem_as: e, - }, - )) + adapter_row + .rs1_data + .copy_from_slice(&rs1.map(F::from_canonical_u8)); + adapter_row.imm = c; + adapter_row.imm_sign = g; + adapter_row.mem_ptr_limbs = mem_ptr_limbs.map(F::from_canonical_u32); + adapter_row.mem_as = e; + + ((prev_data, read_data), shift_amount) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { opcode, a, + c, d, e, f: enabled, + g, .. - } = *instruction; + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); let local_opcode = Rv32LoadStoreOpcode::from_usize( opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let write_id = if enabled != F::ZERO { - let (record_id, _) = match local_opcode { + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + + let rs1 = adapter_row.rs1_data.map(|x| x.as_canonical_u32() as u8); + + let rs1_val = u32::from_le_bytes(rs1); + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + let ptr_val = rs1_val.wrapping_add(imm_extended); + assert!( + ptr_val < (1 << self.pointer_max_bits), + "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", + self.pointer_max_bits + ); + + let mem_ptr_limbs: [u32; 2] = + array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff)); + + if enabled != F::ZERO { + adapter_row.needs_write = F::ONE; + + match local_opcode { STOREW | STOREH | STOREB => { - let ptr = read_record.mem_ptr_limbs[0] - + read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2)); - memory.write(e, F::from_canonical_u32(ptr & 0xfffffffc), output.writes[0]) + let ptr = mem_ptr_limbs[0] + mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2)); + let ptr = ptr & 0xfffffffc; + + // TODO(arayi): This workaround should be temporary + if e.as_canonical_u32() == 4 { + let (t_prev, _) = unsafe { + memory.write::( + e.as_canonical_u32(), + ptr, + &data.map(F::from_canonical_u8), + ) + }; + adapter_row + .write_base_aux + .set_prev(F::from_canonical_u32(t_prev)); + } else { + tracing_write_with_base_aux( + memory, + e.as_canonical_u32(), + ptr, + data, + &mut adapter_row.write_base_aux, + ); + } + } + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + tracing_write_with_base_aux( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data, + &mut adapter_row.write_base_aux, + ); } - LOADW | LOADB | LOADH | LOADBU | LOADHU => memory.write(d, a, output.writes[0]), }; - record_id + adapter_row.rd_rs2_ptr = a; } else { memory.increment_timestamp(); - // RecordId will never get to usize::MAX, so it can be used as a flag for no write - RecordId(usize::MAX) }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - write_id, - rd_rs2_ptr: a, - }, - )) } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + range_checker_chip: &SharedVariableRangeCheckerChip, + adapter_row: &mut [F], ) { - self.range_checker_chip.add_count( - (read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4, + // TODO(ayush): should this be here? + assert!(range_checker_chip.range_max_bits() >= 15); + + let adapter_row: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); + + let rs1 = adapter_row.rs1_data.map(|x| x.as_canonical_u32() as u8); + let rs1_val = u32::from_le_bytes(rs1); + + let imm = adapter_row.imm.as_canonical_u32(); + let imm_sign = adapter_row.imm_sign.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + let ptr_val = rs1_val.wrapping_add(imm_extended); + let shift_amount = ptr_val % 4; + + range_checker_chip.add_count( + (adapter_row.mem_ptr_limbs[0].as_canonical_u32() - shift_amount) / 4, RV32_CELL_BITS * 2 - 2, ); - self.range_checker_chip.add_count( - read_record.mem_ptr_limbs[1], - self.air.pointer_max_bits - RV32_CELL_BITS * 2, + range_checker_chip.add_count( + adapter_row.mem_ptr_limbs[1].as_canonical_u32(), + self.pointer_max_bits - RV32_CELL_BITS * 2, ); - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rs1 = memory.record_by_id(read_record.rs1_record); - adapter_cols.rs1_data.copy_from_slice(rs1.data_slice()); - aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols); - adapter_cols.rs1_ptr = read_record.rs1_ptr; - adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr; - let read = memory.record_by_id(read_record.read); - aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux); - adapter_cols.imm = read_record.imm; - adapter_cols.imm_sign = read_record.imm_sign; - adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32); - adapter_cols.mem_as = read_record.mem_as; - if write_record.write_id.0 != usize::MAX { - let write = memory.record_by_id(write_record.write_id); - aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux); - adapter_cols.needs_write = F::ONE; + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.rs1_aux_cols.as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.read_data_aux.as_mut()); + timestamp += 1; + + if adapter_row.needs_write.is_one() { + mem_helper.fill_from_prev(timestamp, &mut adapter_row.write_base_aux); } } +} + +impl AdapterExecutorE1 for Rv32LoadStoreAdapterStep +where + F: PrimeField32, +{ + // TODO(ayush): directly use u32 + type ReadData = ( + ([u8; RV32_REGISTER_NUM_LIMBS], [u8; RV32_REGISTER_NUM_LIMBS]), + u32, + ); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { + opcode, + a, + b, + c, + d, + e, + g, + .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, d.as_canonical_u32(), b.as_canonical_u32()); + let rs1_val = u32::from_le_bytes(rs1_bytes); + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + let ptr_val = rs1_val.wrapping_add(imm_extended); + assert!( + ptr_val < (1 << self.pointer_max_bits), + "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", + self.pointer_max_bits + ); + let shift_amount = ptr_val % 4; + + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let read_data: [u8; RV32_REGISTER_NUM_LIMBS] = match local_opcode { + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + memory_read_from_state(state, e.as_canonical_u32(), ptr_val) + } + STOREW | STOREH | STOREB => { + memory_read_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32()) + } + }; + + // For stores, we need the previous memory content to preserve unchanged bytes + let prev_data: [u8; RV32_REGISTER_NUM_LIMBS] = match local_opcode { + STOREW | STOREH | STOREB => memory_read(state.memory, e.as_canonical_u32(), ptr_val), + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + memory_read(state.memory, RV32_REGISTER_AS, a.as_canonical_u32()) + } + }; + + ((prev_data, read_data), shift_amount) + } + + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + data: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + // TODO(ayush): remove duplication with read + let &Instruction { + opcode, + a, + b, + c, + d, + e, + f: enabled, + g, + .. + } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert!(e.as_canonical_u32() != RV32_IMM_AS); + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + + let rs1_bytes: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read(state.memory, RV32_REGISTER_AS, b.as_canonical_u32()); + let rs1_val = u32::from_le_bytes(rs1_bytes); + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + let ptr_val = rs1_val.wrapping_add(imm_extended); + assert!( + ptr_val < (1 << self.pointer_max_bits), + "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}", + self.pointer_max_bits + ); + let shift_amount = ptr_val % 4; + + let ptr_val = ptr_val - shift_amount; // aligned ptr + + let mem_ptr_limbs: [u32; 2] = + array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff)); - fn air(&self) -> &Self::Air { - &self.air + if enabled != F::ZERO { + match local_opcode { + STOREW | STOREH | STOREB => { + let ptr = mem_ptr_limbs[0] + mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2)); + memory_write_from_state(state, e.as_canonical_u32(), ptr & 0xfffffffc, data); + } + LOADW | LOADB | LOADH | LOADBU | LOADHU => { + memory_write_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32(), data); + } + } + } } } diff --git a/extensions/rv32im/circuit/src/adapters/mod.rs b/extensions/rv32im/circuit/src/adapters/mod.rs index ab15671b74..388bfb9d32 100644 --- a/extensions/rv32im/circuit/src/adapters/mod.rs +++ b/extensions/rv32im/circuit/src/adapters/mod.rs @@ -1,6 +1,15 @@ use std::ops::Mul; -use openvm_circuit::system::memory::{MemoryController, RecordId}; +use openvm_circuit::{ + arch::{execution_mode::E1E2ExecutionCtx, VmStateMut}, + system::memory::{ + offline_checker::{MemoryBaseAuxCols, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + tree::public_values::PUBLIC_VALUES_AS, + MemoryController, RecordId, + }, +}; +use openvm_instructions::riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}; use openvm_stark_backend::p3_field::{FieldAlgebra, PrimeField32}; mod alu; @@ -46,6 +55,187 @@ pub fn decompose(value: u32) -> [F; RV32_REGISTER_NUM_LIMBS] { }) } +#[inline(always)] +pub fn imm_to_bytes(imm: u32) -> [u8; RV32_REGISTER_NUM_LIMBS] { + debug_assert_eq!(imm >> 24, 0); + let mut imm_le = imm.to_le_bytes(); + imm_le[3] = imm_le[2]; + imm_le +} + +#[inline(always)] +pub fn memory_read(memory: &GuestMemory, address_space: u32, ptr: u32) -> [u8; N] { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS, + ); + + // TODO(ayush): PUBLIC_VALUES_AS safety? + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn memory_write( + memory: &mut GuestMemory, + address_space: u32, + ptr: u32, + data: &[u8; N], +) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // TODO(ayush): PUBLIC_VALUES_AS safety? + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +#[inline(always)] +pub fn memory_read_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, +) -> [u8; N] +where + Ctx: E1E2ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_read(state.memory, address_space, ptr) +} + +#[inline(always)] +pub fn memory_write_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, + data: &[u8; N], +) where + Ctx: E1E2ExecutionCtx, +{ + state.ctx.on_memory_operation(address_space, ptr, N as u32); + + memory_write(state.memory, address_space, ptr, data) +} + +/// Atomic read operation which increments the timestamp by 1. +/// Returns `(t_prev, [ptr:4]_{address_space})` where `t_prev` is the timestamp of the last memory +/// access. +#[inline(always)] +pub fn timed_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, +) -> (u32, [u8; N]) { + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.read::(address_space, ptr) } +} + +#[inline(always)] +pub fn timed_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: &[u8; N], +) -> (u32, [u8; N]) { + // TODO(ayush): should this allow public values address space + debug_assert!( + address_space == RV32_REGISTER_AS + || address_space == RV32_MEMORY_AS + || address_space == PUBLIC_VALUES_AS + ); + + // SAFETY: + // - address space `RV32_REGISTER_AS` and `RV32_MEMORY_AS` will always have cell type `u8` and + // minimum alignment of `RV32_REGISTER_NUM_LIMBS` + unsafe { memory.write::(address_space, ptr, data) } +} + +/// Reads register value at `reg_ptr` from memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_read( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + aux_cols: &mut MemoryReadAuxCols, /* TODO[jpw]: switch to raw u8 + * buffer */ +) -> [u8; N] +where + F: PrimeField32, +{ + let (t_prev, data) = timed_read(memory, address_space, ptr); + aux_cols.set_prev(F::from_canonical_u32(t_prev)); + data +} + +#[inline(always)] +pub fn tracing_read_imm( + memory: &mut TracingMemory, + imm: u32, + imm_mut: &mut F, +) -> [u8; RV32_REGISTER_NUM_LIMBS] +where + F: PrimeField32, +{ + *imm_mut = F::from_canonical_u32(imm); + memory.increment_timestamp(); + imm_to_bytes(imm) +} + +/// Writes `reg_ptr, reg_val` into memory and records the memory access in mutable buffer. +/// Trace generation relevant to this memory access can be done fully from the recorded buffer. +#[inline(always)] +pub fn tracing_write( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: &[u8; N], + aux_cols: &mut MemoryWriteAuxCols, /* TODO[jpw]: switch to raw + * u8 + * buffer */ +) where + F: PrimeField32, +{ + let (t_prev, data_prev) = timed_write(memory, address_space, ptr, data); + aux_cols.set_prev( + F::from_canonical_u32(t_prev), + data_prev.map(F::from_canonical_u8), + ); +} + +// TODO(ayush): this is bad but not sure how to avoid +#[inline(always)] +pub fn tracing_write_with_base_aux( + memory: &mut TracingMemory, + address_space: u32, + ptr: u32, + data: &[u8; N], + base_aux_cols: &mut MemoryBaseAuxCols, +) where + F: PrimeField32, +{ + let (t_prev, _) = timed_write(memory, address_space, ptr, data); + base_aux_cols.set_prev(F::from_canonical_u32(t_prev)); +} + +// TODO: delete /// Read register value as [RV32_REGISTER_NUM_LIMBS] limbs from memory. /// Returns the read record and the register value as u32. /// Does not make any range check calls. @@ -55,16 +245,34 @@ pub fn read_rv32_register( pointer: F, ) -> (RecordId, u32) { debug_assert_eq!(address_space, F::ONE); - let record = memory.read::(address_space, pointer); - let val = compose(record.1); + let record = memory.read::(address_space, pointer); + let val = u32::from_le_bytes(record.1); (record.0, val) } +#[inline(always)] +pub fn new_read_rv32_register(memory: &GuestMemory, address_space: u32, ptr: u32) -> u32 { + u32::from_le_bytes(memory_read(memory, address_space, ptr)) +} + +// TODO(AG): if "register", why `address_space` is not hardcoded to be 1? +#[inline(always)] +pub fn new_read_rv32_register_from_state( + state: &mut VmStateMut, + address_space: u32, + ptr: u32, +) -> u32 +where + Ctx: E1E2ExecutionCtx, +{ + u32::from_le_bytes(memory_read_from_state(state, address_space, ptr)) +} + /// Peeks at the value of a register without updating the memory state or incrementing the /// timestamp. pub fn unsafe_read_rv32_register(memory: &MemoryController, pointer: F) -> u32 { - let data = memory.unsafe_read::(F::ONE, pointer); - compose(data) + let data = memory.unsafe_read::(F::ONE, pointer); + u32::from_le_bytes(data) } pub fn abstract_compose>( @@ -76,3 +284,8 @@ pub fn abstract_compose>( acc + limb * T::from_canonical_u32(1 << (i * RV32_CELL_BITS)) }) } + +// TEMP[jpw] +pub fn tmp_convert_to_u8s(data: [F; N]) -> [u8; N] { + data.map(|x| x.as_canonical_u32() as u8) +} diff --git a/extensions/rv32im/circuit/src/adapters/mul.rs b/extensions/rv32im/circuit/src/adapters/mul.rs index a82e83acaa..de5460e402 100644 --- a/extensions/rv32im/circuit/src/adapters/mul.rs +++ b/extensions/rv32im/circuit/src/adapters/mul.rs @@ -1,20 +1,15 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -26,49 +21,9 @@ use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; - -use super::RV32_REGISTER_NUM_LIMBS; - -/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. -/// Operand d can only be 1, and there is no immediate support. -#[derive(Debug)] -pub struct Rv32MultAdapterChip { - pub air: Rv32MultAdapterAir, - _marker: PhantomData, -} - -impl Rv32MultAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32MultAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultReadRecord { - /// Reads from operand registers - pub rs1: RecordId, - pub rs2: RecordId, -} - -#[repr(C)] -#[derive(Debug, Serialize, Deserialize)] -pub struct Rv32MultWriteRecord { - pub from_state: ExecutionState, - /// Write to destination register - pub rd_id: RecordId, -} +use super::{tracing_write, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{memory_read_from_state, memory_write_from_state, tracing_read}; #[repr(C)] #[derive(AlignedBorrow)] @@ -81,6 +36,8 @@ pub struct Rv32MultAdapterCols { pub writes_aux: MemoryWriteAuxCols, } +/// Reads instructions of the form OP a, b, c, d where \[a:4\]_d = \[b:4\]_d op \[c:4\]_d. +/// Operand d can only be 1, and there is no immediate support. #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32MultAdapterAir { pub(super) execution_bridge: ExecutionBridge, @@ -167,92 +124,143 @@ impl VmAdapterAir for Rv32MultAdapterAir { } } -impl VmAdapterChip for Rv32MultAdapterChip { - type ReadRecord = Rv32MultReadRecord; - type WriteRecord = Rv32MultWriteRecord; - type Air = Rv32MultAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; +#[derive(derive_new::new)] +pub struct Rv32MultAdapterStep; + +impl AdapterTraceStep for Rv32MultAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, .. } = *instruction; + adapter_row: &mut [F], + ) -> Self::ReadData { + let &Instruction { b, c, d, .. } = instruction; debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - let rs1 = memory.read::(d, b); - let rs2 = memory.read::(d, c); + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); - Ok(( - [rs1.1, rs2.1], - Self::ReadRecord { - rs1: rs1.0, - rs2: rs2.0, - }, - )) + adapter_row.rs1_ptr = b; + let rs1 = tracing_read( + memory, + RV32_REGISTER_AS, + b.as_canonical_u32(), + &mut adapter_row.reads_aux[0], + ); + adapter_row.rs2_ptr = c; + let rs2 = tracing_read( + memory, + RV32_REGISTER_AS, + c.as_canonical_u32(), + &mut adapter_row.reads_aux[1], + ); + + [rs1, rs2] } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, .. } = instruction; - let timestamp_delta = memory.timestamp() - from_state.timestamp; - debug_assert!( - timestamp_delta == 3, - "timestamp delta is {}, expected 3", - timestamp_delta - ); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); - Ok(( - ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + adapter_row.rd_ptr = a; + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + &data[0], + &mut adapter_row.writes_aux, + ) } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let row_slice: &mut Rv32MultAdapterCols<_> = row_slice.borrow_mut(); - row_slice.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id); - row_slice.rd_ptr = rd.pointer; - let rs1 = memory.record_by_id(read_record.rs1); - let rs2 = memory.record_by_id(read_record.rs2); - row_slice.rs1_ptr = rs1.pointer; - row_slice.rs2_ptr = rs2.pointer; - aux_cols_factory.generate_read_aux(rs1, &mut row_slice.reads_aux[0]); - aux_cols_factory.generate_read_aux(rs2, &mut row_slice.reads_aux[1]); - aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux); + let adapter_row: &mut Rv32MultAdapterCols = adapter_row.borrow_mut(); + + let mut timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[0].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.reads_aux[1].as_mut()); + timestamp += 1; + + mem_helper.fill_from_prev(timestamp, adapter_row.writes_aux.as_mut()); } +} + +impl AdapterExecutorE1 for Rv32MultAdapterStep +where + F: PrimeField32, +{ + // TODO(ayush): directly use u32 + type ReadData = [[u8; RV32_REGISTER_NUM_LIMBS]; 2]; + type WriteData = [[u8; RV32_REGISTER_NUM_LIMBS]; 1]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { b, c, d, .. } = instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + let rs1: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, b.as_canonical_u32()); + let rs2: [u8; RV32_REGISTER_NUM_LIMBS] = + memory_read_from_state(state, RV32_REGISTER_AS, c.as_canonical_u32()); + + [rs1, rs2] + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + rd: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = *instruction; + + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - fn air(&self) -> &Self::Air { - &self.air + memory_write_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32(), &rd[0]); } } diff --git a/extensions/rv32im/circuit/src/adapters/rdwrite.rs b/extensions/rv32im/circuit/src/adapters/rdwrite.rs index abd4d8eb17..d577d32a0b 100644 --- a/extensions/rv32im/circuit/src/adapters/rdwrite.rs +++ b/extensions/rv32im/circuit/src/adapters/rdwrite.rs @@ -1,20 +1,15 @@ -use std::{ - borrow::{Borrow, BorrowMut}, - marker::PhantomData, -}; +use std::borrow::{Borrow, BorrowMut}; use openvm_circuit::{ arch::{ - AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, - ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip, - VmAdapterInterface, + execution_mode::E1E2ExecutionCtx, AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, + BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir, + VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryWriteAuxCols}, - MemoryAddress, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, }, }; use openvm_circuit_primitives::utils::not; @@ -27,59 +22,9 @@ use openvm_stark_backend::{ p3_air::{AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, }; -use serde::{Deserialize, Serialize}; use super::RV32_REGISTER_NUM_LIMBS; - -/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32RdWriteAdapterChip { - pub air: Rv32RdWriteAdapterAir, - _marker: PhantomData, -} - -/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 -#[derive(Debug)] -pub struct Rv32CondRdWriteAdapterChip { - /// Do not use the inner air directly, use `air` instead. - inner: Rv32RdWriteAdapterChip, - pub air: Rv32CondRdWriteAdapterAir, -} - -impl Rv32RdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - Self { - air: Rv32RdWriteAdapterAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - }, - _marker: PhantomData, - } - } -} - -impl Rv32CondRdWriteAdapterChip { - pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, - memory_bridge: MemoryBridge, - ) -> Self { - let inner = Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge); - let air = Rv32CondRdWriteAdapterAir { inner: inner.air }; - Self { inner, air } - } -} - -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32RdWriteWriteRecord { - pub from_state: ExecutionState, - pub rd_id: Option, -} +use crate::adapters::{memory_write_from_state, tracing_write}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -92,16 +37,18 @@ pub struct Rv32RdWriteAdapterCols { #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32CondRdWriteAdapterCols { - inner: Rv32RdWriteAdapterCols, + pub inner: Rv32RdWriteAdapterCols, pub needs_write: T, } +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32RdWriteAdapterAir { pub(super) memory_bridge: MemoryBridge, pub(super) execution_bridge: ExecutionBridge, } +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 #[derive(Clone, Copy, Debug, derive_new::new)] pub struct Rv32CondRdWriteAdapterAir { inner: Rv32RdWriteAdapterAir, @@ -241,131 +188,237 @@ impl VmAdapterAir for Rv32CondRdWriteAdapterAir { } } -impl VmAdapterChip for Rv32RdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32RdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; +/// This adapter doesn't read anything, and writes to \[a:4\]_d, where d == 1 +#[derive(derive_new::new)] +pub struct Rv32RdWriteAdapterStep; + +impl AdapterTraceStep for Rv32RdWriteAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); + adapter_row.from_state.pc = F::from_canonical_u32(pc); + adapter_row.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } + + #[inline(always)] + fn read( + &self, + _memory: &mut TracingMemory, + _instruction: &Instruction, + _adapter_row: &mut [F], + ) -> Self::ReadData { + } - fn preprocess( - &mut self, - _memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let d = instruction.d; + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let &Instruction { a, d, .. } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - Ok(([], ())) - } + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let (rd_id, _) = memory.write(d, a, output.writes[0]); - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { - from_state, - rd_id: Some(rd_id), - }, - )) + adapter_row.rd_ptr = a; + tracing_write( + memory, + RV32_REGISTER_AS, + a.as_canonical_u32(), + data, + &mut adapter_row.rd_aux_cols, + ); } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + _trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32RdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32); - let rd = memory.record_by_id(write_record.rd_id.unwrap()); - adapter_cols.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.rd_aux_cols); + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); + + let timestamp = adapter_row.from_state.timestamp.as_canonical_u32(); + + mem_helper.fill_from_prev(timestamp, adapter_row.rd_aux_cols.as_mut()); } +} + +impl AdapterExecutorE1 for Rv32RdWriteAdapterStep +where + F: PrimeField32, +{ + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + + #[inline(always)] + fn read( + &self, + _state: &mut VmStateMut, + _instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + rd: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { a, d, .. } = instruction; - fn air(&self) -> &Self::Air { - &self.air + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + + memory_write_from_state(state, RV32_REGISTER_AS, a.as_canonical_u32(), rd); } } -impl VmAdapterChip for Rv32CondRdWriteAdapterChip { - type ReadRecord = (); - type WriteRecord = Rv32RdWriteWriteRecord; - type Air = Rv32CondRdWriteAdapterAir; - type Interface = BasicAdapterInterface, 0, 1, 0, RV32_REGISTER_NUM_LIMBS>; +/// This adapter doesn't read anything, and **maybe** writes to \[a:4\]_d, where d == 1 +#[derive(derive_new::new)] +pub struct Rv32CondRdWriteAdapterStep { + inner: Rv32RdWriteAdapterStep, +} + +impl AdapterTraceStep for Rv32CondRdWriteAdapterStep +where + F: PrimeField32, +{ + const WIDTH: usize = size_of::>(); + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + type TraceContext<'a> = (); + + #[inline(always)] + fn start(pc: u32, memory: &TracingMemory, adapter_row: &mut [F]) { + let adapter_row: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); + + adapter_row.inner.from_state.pc = F::from_canonical_u32(pc); + adapter_row.inner.from_state.timestamp = F::from_canonical_u32(memory.timestamp); + } - fn preprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn read( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - self.inner.preprocess(memory, instruction) + adapter_row: &mut [F], + ) -> Self::ReadData { + >::read( + &self.inner, + memory, + instruction, + adapter_row, + ) } - fn postprocess( - &mut self, - memory: &mut MemoryController, + #[inline(always)] + fn write( + &self, + memory: &mut TracingMemory, instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - let Instruction { a, d, .. } = *instruction; - let rd_id = if instruction.f != F::ZERO { - let (rd_id, _) = memory.write(d, a, output.writes[0]); - Some(rd_id) + adapter_row: &mut [F], + data: &Self::WriteData, + ) { + let Instruction { f: enabled, .. } = instruction; + + if *enabled != F::ZERO { + let (inner_row, needs_write) = unsafe { + adapter_row.split_at_mut_unchecked(size_of::>()) + }; + + needs_write[0] = F::ONE; + >::write( + &self.inner, + memory, + instruction, + inner_row, + data, + ); } else { memory.increment_timestamp(); - None - }; - - Ok(( - ExecutionState { - pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP), - timestamp: memory.timestamp(), - }, - Self::WriteRecord { from_state, rd_id }, - )) + } } - fn generate_trace_row( + #[inline(always)] + fn fill_trace_row( &self, - row_slice: &mut [F], - _read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, + mem_helper: &MemoryAuxColsFactory, + trace_ctx: Self::TraceContext<'_>, + adapter_row: &mut [F], ) { - let aux_cols_factory = memory.aux_cols_factory(); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = row_slice.borrow_mut(); - adapter_cols.inner.from_state = write_record.from_state.map(F::from_canonical_u32); - if let Some(rd_id) = write_record.rd_id { - let rd = memory.record_by_id(rd_id); - adapter_cols.inner.rd_ptr = rd.pointer; - aux_cols_factory.generate_write_aux(rd, &mut adapter_cols.inner.rd_aux_cols); - adapter_cols.needs_write = F::ONE; + let adapter_row_ref: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); + + if adapter_row_ref.needs_write.is_one() { + let (inner_row, _) = unsafe { + adapter_row.split_at_mut_unchecked(size_of::>()) + }; + + >::fill_trace_row( + &self.inner, + mem_helper, + trace_ctx, + inner_row, + ) } } +} - fn air(&self) -> &Self::Air { - &self.air +impl AdapterExecutorE1 for Rv32CondRdWriteAdapterStep +where + F: PrimeField32, +{ + type ReadData = (); + type WriteData = [u8; RV32_REGISTER_NUM_LIMBS]; + + #[inline(always)] + fn read( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Self::ReadData + where + Ctx: E1E2ExecutionCtx, + { + >::read(&self.inner, state, instruction) + } + + #[inline(always)] + fn write( + &self, + state: &mut VmStateMut, + instruction: &Instruction, + rd: &Self::WriteData, + ) where + Ctx: E1E2ExecutionCtx, + { + let Instruction { f: enabled, .. } = instruction; + + if *enabled != F::ZERO { + >::write( + &self.inner, + state, + instruction, + rd, + ) + } } } diff --git a/extensions/rv32im/circuit/src/auipc/core.rs b/extensions/rv32im/circuit/src/auipc/core.rs index 8ec9e274f6..7e037df397 100644 --- a/extensions/rv32im/circuit/src/auipc/core.rs +++ b/extensions/rv32im/circuit/src/auipc/core.rs @@ -1,17 +1,28 @@ use std::{ - array, + array::{self, from_fn}, borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ImmInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, +}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -19,11 +30,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; - -use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +use crate::adapters::{Rv32RdWriteAdapterCols, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] @@ -36,7 +44,7 @@ pub struct Rv32AuipcCoreCols { pub rd_data: [T; RV32_REGISTER_NUM_LIMBS], } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32AuipcCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -185,117 +193,167 @@ where } } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Rv32AuipcCoreRecord { - pub imm_limbs: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub pc_limbs: [F; RV32_REGISTER_NUM_LIMBS - 2], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], -} - -pub struct Rv32AuipcCoreChip { - pub air: Rv32AuipcCoreAir, +pub struct Rv32AuipcCoreStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32AuipcCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { +impl Rv32AuipcCoreStep { + pub fn new( + adapter: A, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + ) -> Self { Self { - air: Rv32AuipcCoreAir { - bus: bitwise_lookup_chip.bus(), - }, + adapter, bitwise_lookup_chip, } } } -impl> VmCoreChip for Rv32AuipcCoreChip +impl TraceStep for Rv32AuipcCoreStep where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (), + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + TraceContext<'a> = (), + >, { - type Record = Rv32AuipcCoreRecord; - type Air = Rv32AuipcCoreAir; + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", AUIPC) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32AuipcOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET), - ); - let imm = instruction.c.as_canonical_u32(); - let rd_data = run_auipc(local_opcode, from_pc, imm); - let rd_data_field = rd_data.map(F::from_canonical_u32); - - let output = AdapterRuntimeContext::without_pc([rd_data_field]); - - let imm_limbs = array::from_fn(|i| (imm >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); - let pc_limbs: [u32; RV32_REGISTER_NUM_LIMBS] = - array::from_fn(|i| (from_pc >> (i * RV32_CELL_BITS)) & RV32_LIMB_MAX); + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { opcode, c: imm, .. } = instruction; - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { - self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); - } + let local_opcode = + Rv32AuipcOpcode::from_usize(opcode.local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET)); - let mut need_range_check: Vec = Vec::new(); - for limb in imm_limbs { - need_range_check.push(limb); - } + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; - for (i, limb) in pc_limbs.iter().enumerate().skip(1) { - if i == pc_limbs.len() - 1 { - need_range_check.push((*limb) << (pc_limbs.len() * RV32_CELL_BITS - PC_BITS)); - } else { - need_range_check.push(*limb); - } - } + A::start(*state.pc, state.memory, adapter_row); - for pair in need_range_check.chunks(2) { - self.bitwise_lookup_chip.request_range(pair[0], pair[1]); - } + let imm_u32 = imm.as_canonical_u32(); + let rd = run_auipc(local_opcode, *state.pc, imm_u32); + + let core_row: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); + core_row.rd_data = rd.map(F::from_canonical_u8); + + // TODO(ayush): see if there's a better way + // We decompose during fill_trace_row later: + core_row.imm_limbs[0] = *imm; + + self.adapter + .write(state.memory, instruction, adapter_row, &rd); - Ok(( - output, - Self::Record { - imm_limbs: imm_limbs.map(F::from_canonical_u32), - pc_limbs: array::from_fn(|i| F::from_canonical_u32(pc_limbs[i + 1])), - rd_data: rd_data.map(F::from_canonical_u32), - }, - )) + // TODO(ayush): add increment_pc function to vmstate + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32AuipcOpcode::from_usize(opcode - Rv32AuipcOpcode::CLASS_OFFSET) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); + + core_row.is_valid = F::ONE; + + // TODO(ayush): this is bad since we're treating adapters as generic. maybe + // add a .state() function to adapters or get_from_pc like in air + let adapter_row: &mut Rv32RdWriteAdapterCols = adapter_row.borrow_mut(); + let from_pc = adapter_row.from_state.pc.as_canonical_u32(); + + let pc_limbs = from_pc.to_le_bytes(); + let imm = core_row.imm_limbs[0].as_canonical_u32(); + let imm_limbs = imm.to_le_bytes(); + debug_assert_eq!(imm_limbs[3], 0); + core_row.imm_limbs = from_fn(|i| F::from_canonical_u8(imm_limbs[i])); + // only the middle 2 limbs: + core_row.pc_limbs = from_fn(|i| F::from_canonical_u8(pc_limbs[i + 1])); + + // range checks: + let rd_data = core_row.rd_data.map(|x| x.as_canonical_u32()); + for pair in rd_data.chunks_exact(2) { + self.bitwise_lookup_chip.request_range(pair[0], pair[1]); + } + // hardcoding for performance: first 3 limbs of imm_limbs, last 3 limbs of pc_limbs where + // most significant limb of pc_limbs is shifted up + self.bitwise_lookup_chip + .request_range(imm_limbs[0] as u32, imm_limbs[1] as u32); + self.bitwise_lookup_chip + .request_range(imm_limbs[2] as u32, pc_limbs[1] as u32); + let msl_shift = RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - PC_BITS; + self.bitwise_lookup_chip + .request_range(pc_limbs[2] as u32, (pc_limbs[3] as u32) << msl_shift); } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32AuipcCoreCols = row_slice.borrow_mut(); - core_cols.imm_limbs = record.imm_limbs; - core_cols.pc_limbs = record.pc_limbs; - core_cols.rd_data = record.rd_data; - core_cols.is_valid = F::ONE; +impl StepExecutorE1 for Rv32AuipcCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, c: imm, .. } = instruction; + + let local_opcode = + Rv32AuipcOpcode::from_usize(opcode.local_opcode_idx(Rv32AuipcOpcode::CLASS_OFFSET)); + + let imm = imm.as_canonical_u32(); + let rd = run_auipc(local_opcode, *state.pc, imm); + + self.adapter.write(state, instruction, &rd); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // returns rd_data +// TODO(ayush): remove _opcode +#[inline(always)] pub(super) fn run_auipc( _opcode: Rv32AuipcOpcode, pc: u32, imm: u32, -) -> [u32; RV32_REGISTER_NUM_LIMBS] { +) -> [u8; RV32_REGISTER_NUM_LIMBS] { let rd = pc.wrapping_add(imm << RV32_CELL_BITS); - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX) + rd.to_le_bytes() } diff --git a/extensions/rv32im/circuit/src/auipc/mod.rs b/extensions/rv32im/circuit/src/auipc/mod.rs index 6e2234bfbd..f2aa252ee8 100644 --- a/extensions/rv32im/circuit/src/auipc/mod.rs +++ b/extensions/rv32im/circuit/src/auipc/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32RdWriteAdapterChip; +use crate::adapters::{Rv32RdWriteAdapterAir, Rv32RdWriteAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,6 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32AuipcChip = VmChipWrapper, Rv32AuipcCoreChip>; +pub type Rv32AuipcAir = VmAirWrapper; +pub type Rv32AuipcStep = Rv32AuipcCoreStep; +pub type Rv32AuipcChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/auipc/tests.rs b/extensions/rv32im/circuit/src/auipc/tests.rs index 2c8a399198..42f1cd1468 100644 --- a/extensions/rv32im/circuit/src/auipc/tests.rs +++ b/extensions/rv32im/circuit/src/auipc/tests.rs @@ -1,31 +1,60 @@ use std::borrow::BorrowMut; -use openvm_circuit::arch::{testing::VmChipTestBuilder, VmAdapterChip}; +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + VmAirWrapper, +}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::Rv32AuipcOpcode::{self, *}; use openvm_stark_backend::{ - interaction::BusIndex, p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreChip, Rv32AuipcCoreCols}; -use crate::adapters::{Rv32RdWriteAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::{run_auipc, Rv32AuipcChip, Rv32AuipcCoreAir, Rv32AuipcCoreCols, Rv32AuipcStep}; +use crate::{ + adapters::{ + Rv32RdWriteAdapterAir, Rv32RdWriteAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + test_utils::get_verification_error, +}; const IMM_BITS: usize = 24; -const BITWISE_OP_LOOKUP_BUS: BusIndex = 9; - +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32AuipcChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let chip = Rv32AuipcChip::::new( + VmAirWrapper::new( + Rv32RdWriteAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32AuipcCoreAir::new(bitwise_bus), + ), + Rv32AuipcStep::new(Rv32RdWriteAdapterStep::new(), bitwise_chip.clone()), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut Rv32AuipcChip, @@ -43,10 +72,8 @@ fn set_and_execute( initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), ); let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); - let rd_data = run_auipc(opcode, initial_pc, imm as u32); - - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -59,17 +86,8 @@ fn set_and_execute( #[test] fn rand_auipc_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); let num_tests: usize = 100; for _ in 0..num_tests { @@ -84,32 +102,26 @@ fn rand_auipc_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct AuipcPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, +} + fn run_negative_auipc_test( opcode: Rv32AuipcOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - pc_limbs: Option<[u32; RV32_REGISTER_NUM_LIMBS - 2]>, - expected_error: VerificationError, + prank_vals: AuipcPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32AuipcCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32AuipcChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); set_and_execute( &mut tester, @@ -120,39 +132,32 @@ fn run_negative_auipc_test( initial_pc, ); - let tester = tester.build(); - - let auipc_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let auipc_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = auipc_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32AuipcCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_limbs { + if let Some(data) = prank_vals.imm_limbs { core_cols.imm_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = pc_limbs { + if let Some(data) = prank_vals.pc_limbs { core_cols.pc_limbs = data.map(F::from_canonical_u32); } - *auipc_trace = RowMajorMatrix::new(trace_row, auipc_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; + disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -161,47 +166,53 @@ fn invalid_limb_negative_tests() { AUIPC, Some(9722891), None, - None, - Some([107, 46, 81]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([107, 46, 81]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(0), Some(2110400), - Some([194, 51, 32, 240]), - None, - Some([51, 32]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([194, 51, 32, 240]), + pc_limbs: Some([51, 32]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([206, 166]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([206, 166]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - Some([30, 92, 82, 132]), - None, - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + rd_data: Some([30, 92, 82, 132]), + ..Default::default() + }, + false, ); - run_negative_auipc_test( AUIPC, None, Some(876487877), - Some([197, 202, 49, 70]), - Some([166, 243, 17]), - Some([36, 62]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([197, 202, 49, 70]), + imm_limbs: Some([166, 243, 17]), + pc_limbs: Some([36, 62]), + }, + true, ); } @@ -211,37 +222,42 @@ fn overflow_negative_tests() { AUIPC, Some(256264), None, - None, - Some([3592, 219, 3]), - None, - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + imm_limbs: Some([3592, 219, 3]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, None, None, - None, - None, - Some([0, 0]), - VerificationError::OodEvaluationMismatch, + AuipcPrankValues { + pc_limbs: Some([0, 0]), + ..Default::default() + }, + false, ); run_negative_auipc_test( AUIPC, Some(255), None, - None, - Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), - None, - VerificationError::ChallengePhaseError, + AuipcPrankValues { + imm_limbs: Some([F::NEG_ONE.as_canonical_u32(), 1, 0]), + ..Default::default() + }, + true, ); run_negative_auipc_test( AUIPC, Some(0), Some(255), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - Some([0, 0, 0]), - Some([1, 0]), - VerificationError::ChallengePhaseError, + AuipcPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + imm_limbs: Some([0, 0, 0]), + pc_limbs: Some([1, 0]), + }, + true, ); } @@ -251,27 +267,6 @@ fn overflow_negative_tests() { /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32RdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32AuipcCoreChip::new(bitwise_chip); - let mut chip = Rv32AuipcChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, AUIPC, None, None); - } -} - #[test] fn run_auipc_sanity_test() { let opcode = AUIPC; diff --git a/extensions/rv32im/circuit/src/base_alu/core.rs b/extensions/rv32im/circuit/src/base_alu/core.rs index a87418cc91..b63ef95479 100644 --- a/extensions/rv32im/circuit/src/base_alu/core.rs +++ b/extensions/rv32im/circuit/src/base_alu/core.rs @@ -1,18 +1,26 @@ use std::{ array, borrow::{Borrow, BorrowMut}, + iter::zip, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::BaseAluOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +28,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -38,10 +44,10 @@ pub struct BaseAluCoreCols { pub opcode_and_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BaseAluCoreAir { pub bus: BitwiseOperationLookupBus, - offset: usize, + pub offset: usize, } impl BaseAir @@ -165,175 +171,217 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct BaseAluCoreRecord { - pub opcode: BaseAluOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], -} - -pub struct BaseAluCoreChip { - pub air: BaseAluCoreAir, +pub struct BaseAluStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl BaseAluCoreChip { +impl BaseAluStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { - air: BaseAluCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, } } } -impl VmCoreChip - for BaseAluCoreChip +impl TraceStep + for BaseAluStep where F: PrimeField32, - I: VmAdapterInterface, - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = BaseAluCoreRecord; - type Air = BaseAluCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BaseAluOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; - let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let a = run_alu::(local_opcode, &b, &c); + let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let output = AdapterRuntimeContext { - to_pc: None, - writes: [a.map(F::from_canonical_u32)].into(), - }; + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let rd = run_alu::(local_opcode, &rs1, &rs2); + + let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); + core_row.a = rd.map(F::from_canonical_u8); + core_row.b = rs1.map(F::from_canonical_u8); + core_row.c = rs2.map(F::from_canonical_u8); + core_row.opcode_add_flag = F::from_bool(local_opcode == BaseAluOpcode::ADD); + core_row.opcode_sub_flag = F::from_bool(local_opcode == BaseAluOpcode::SUB); + core_row.opcode_xor_flag = F::from_bool(local_opcode == BaseAluOpcode::XOR); + core_row.opcode_or_flag = F::from_bool(local_opcode == BaseAluOpcode::OR); + core_row.opcode_and_flag = F::from_bool(local_opcode == BaseAluOpcode::AND); - if local_opcode == BaseAluOpcode::ADD || local_opcode == BaseAluOpcode::SUB { - for a_val in a { + self.adapter + .write(state.memory, instruction, adapter_row, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut BaseAluCoreCols = core_row.borrow_mut(); + + if core_row.opcode_add_flag == F::ONE || core_row.opcode_sub_flag == F::ONE { + for a_val in core_row.a.map(|x| x.as_canonical_u32()) { self.bitwise_lookup_chip.request_xor(a_val, a_val); } } else { - for (b_val, c_val) in b.iter().zip(c.iter()) { - self.bitwise_lookup_chip.request_xor(*b_val, *c_val); + let b = core_row.b.map(|x| x.as_canonical_u32()); + let c = core_row.c.map(|x| x.as_canonical_u32()); + for (b_val, c_val) in zip(b, c) { + self.bitwise_lookup_chip.request_xor(b_val, c_val); } } + } +} - let record = Self::Record { - opcode: local_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - }; +impl StepExecutorE1 + for BaseAluStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; - Ok((output, record)) - } + let local_opcode = BaseAluOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", BaseAluOpcode::from_usize(opcode - self.air.offset)) - } + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + let rd = run_alu::(local_opcode, &rs1, &rs2); + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BaseAluCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.opcode_add_flag = F::from_bool(record.opcode == BaseAluOpcode::ADD); - row_slice.opcode_sub_flag = F::from_bool(record.opcode == BaseAluOpcode::SUB); - row_slice.opcode_xor_flag = F::from_bool(record.opcode == BaseAluOpcode::XOR); - row_slice.opcode_or_flag = F::from_bool(record.opcode == BaseAluOpcode::OR); - row_slice.opcode_and_flag = F::from_bool(record.opcode == BaseAluOpcode::AND); + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } +#[inline(always)] pub(super) fn run_alu( opcode: BaseAluOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + debug_assert!(LIMB_BITS <= 8, "specialize for bytes"); match opcode { BaseAluOpcode::ADD => run_add::(x, y), BaseAluOpcode::SUB => run_subtract::(x, y), - BaseAluOpcode::XOR => run_xor::(x, y), - BaseAluOpcode::OR => run_or::(x, y), - BaseAluOpcode::AND => run_and::(x, y), + BaseAluOpcode::XOR => run_xor::(x, y), + BaseAluOpcode::OR => run_or::(x, y), + BaseAluOpcode::AND => run_and::(x, y), } } +#[inline(always)] fn run_add( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - z[i] = x[i] + y[i] + if i > 0 { carry[i - 1] } else { 0 }; - carry[i] = z[i] >> LIMB_BITS; - z[i] &= (1 << LIMB_BITS) - 1; + let mut overflow = + (x[i] as u16) + (y[i] as u16) + if i > 0 { carry[i - 1] as u16 } else { 0 }; + carry[i] = (overflow >> LIMB_BITS) as u8; + overflow &= (1u16 << LIMB_BITS) - 1; + z[i] = overflow as u8; } z } +#[inline(always)] fn run_subtract( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { - let mut z = [0u32; NUM_LIMBS]; - let mut carry = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> [u8; NUM_LIMBS] { + let mut z = [0u8; NUM_LIMBS]; + let mut carry = [0u8; NUM_LIMBS]; for i in 0..NUM_LIMBS { - let rhs = y[i] + if i > 0 { carry[i - 1] } else { 0 }; - if x[i] >= rhs { - z[i] = x[i] - rhs; + let rhs = y[i] as u16 + if i > 0 { carry[i - 1] as u16 } else { 0 }; + if x[i] as u16 >= rhs { + z[i] = x[i] - rhs as u8; carry[i] = 0; } else { - z[i] = x[i] + (1 << LIMB_BITS) - rhs; + z[i] = (x[i] as u16 + (1u16 << LIMB_BITS) - rhs) as u8; carry[i] = 1; } } z } -fn run_xor( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_xor(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] ^ y[i]) } -fn run_or( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_or(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] | y[i]) } -fn run_and( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> [u32; NUM_LIMBS] { +#[inline(always)] +fn run_and(x: &[u8; NUM_LIMBS], y: &[u8; NUM_LIMBS]) -> [u8; NUM_LIMBS] { array::from_fn(|i| x[i] & y[i]) } diff --git a/extensions/rv32im/circuit/src/base_alu/mod.rs b/extensions/rv32im/circuit/src/base_alu/mod.rs index cbda8ce555..266a7ee453 100644 --- a/extensions/rv32im/circuit/src/base_alu/mod.rs +++ b/extensions/rv32im/circuit/src/base_alu/mod.rs @@ -1,7 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BaseAluAdapterChip; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -9,8 +10,8 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32BaseAluChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - BaseAluCoreChip, ->; +pub type Rv32BaseAluAir = + VmAirWrapper>; +pub type Rv32BaseAluStep = + BaseAluStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32BaseAluChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/base_alu/tests.rs b/extensions/rv32im/circuit/src/base_alu/tests.rs index 165cd12526..813eb34435 100644 --- a/extensions/rv32im/circuit/src/base_alu/tests.rs +++ b/extensions/rv32im/circuit/src/base_alu/tests.rs @@ -1,45 +1,110 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge, ExecutionState, - MinimalInstruction, Result, VmAdapterChip, VmAdapterInterface, VmChipWrapper, - }, - system::memory::{MemoryController, OfflineMemory}, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::BaseAluOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::BaseAluOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, - p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_field::{FieldAlgebra, PrimeField32}, p3_matrix::{ dense::{DenseMatrix, RowMajorMatrix}, Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_alu, BaseAluCoreChip, Rv32BaseAluChip}; +use super::{core::run_alu, BaseAluCoreAir, Rv32BaseAluChip, Rv32BaseAluStep}; use crate::{ adapters::{ - Rv32BaseAluAdapterAir, Rv32BaseAluAdapterChip, Rv32BaseAluReadRecord, - Rv32BaseAluWriteRecord, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, }, base_alu::BaseAluCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32BaseAluChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let chip = Rv32BaseAluChip::new( + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + BaseAluCoreAir::new(bitwise_bus, BaseAluOpcode::CLASS_OFFSET), + ), + Rv32BaseAluStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), + bitwise_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip) +} + +fn set_and_execute( + tester: &mut VmChipTestBuilder, + chip: &mut Rv32BaseAluChip, + rng: &mut StdRng, + opcode: BaseAluOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); + + let a = run_alu::(opcode, &b, &c) + .map(F::from_canonical_u8); + assert_eq!(a, tester.read::(1, rd)) +} + ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS // @@ -47,135 +112,105 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// -fn run_rv32_alu_rand_test(opcode: BaseAluOpcode, num_ops: usize) { +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test(opcode: BaseAluOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let a = run_alu::(opcode, &b, &c) - .map(F::from_canonical_u32); - assert_eq!(a, tester.read::(1, rd)) + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_alu_add_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::ADD, 100); -} +#[test_case(ADD, 100)] +#[test_case(SUB, 100)] +#[test_case(XOR, 100)] +#[test_case(OR, 100)] +#[test_case(AND, 100)] +fn rand_rv32_alu_test_persistent(opcode: BaseAluOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); -#[test] -fn rv32_alu_sub_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::SUB, 100); -} + let mut tester = VmChipTestBuilder::default_persistent(); + let (mut chip, bitwise_chip) = create_test_chip(&tester); -#[test] -fn rv32_alu_xor_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::XOR, 100); -} + // TODO(AG): make a more meaningful test for memory accesses + tester.write(2, 1024, [F::ONE; 4]); + tester.write(2, 1028, [F::ONE; 4]); + let sm = tester.read(2, 1024); + assert_eq!(sm, [F::ONE; 8]); -#[test] -fn rv32_alu_or_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::OR, 100); -} + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } -#[test] -fn rv32_alu_and_rand_test() { - run_rv32_alu_rand_test(BaseAluOpcode::AND, 100); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BaseAluTestChip = - VmChipWrapper, BaseAluCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_alu_negative_test( +fn run_negative_alu_test( opcode: BaseAluOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + prank_opcode_flags: Option<[bool; 5]>, + is_imm: Option, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + is_imm, + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - if (opcode == BaseAluOpcode::ADD || opcode == BaseAluOpcode::SUB) - && a.iter().all(|&a_val| a_val < (1 << RV32_CELL_BITS)) - { - bitwise_chip.clear(); - for a_val in a { - bitwise_chip.request_xor(a_val, a_val); - } - } - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BaseAluCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + if let Some(prank_c) = prank_c { + cols.c = prank_c.map(F::from_canonical_u32); + } + if let Some(prank_opcode_flags) = prank_opcode_flags { + cols.opcode_add_flag = F::from_bool(prank_opcode_flags[0]); + cols.opcode_and_flag = F::from_bool(prank_opcode_flags[1]); + cols.opcode_or_flag = F::from_bool(prank_opcode_flags[2]); + cols.opcode_sub_flag = F::from_bool(prank_opcode_flags[3]); + cols.opcode_xor_flag = F::from_bool(prank_opcode_flags[4]); + } + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -184,90 +219,135 @@ fn run_rv32_alu_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_alu_add_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [246, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_add_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::ADD, + run_negative_alu_test( + ADD, [500, 0, 0, 0], [250, 0, 0, 0], [250, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_sub_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [255, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, false, ); } #[test] fn rv32_alu_sub_out_of_range_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::SUB, + run_negative_alu_test( + SUB, [F::NEG_ONE.as_canonical_u32(), 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_xor_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::XOR, + run_negative_alu_test( + XOR, [255, 255, 255, 255], [0, 0, 1, 0], [255, 255, 255, 255], + None, + None, + None, true, ); } #[test] fn rv32_alu_or_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::OR, + run_negative_alu_test( + OR, [255, 255, 255, 255], [255, 255, 255, 254], [0, 0, 0, 0], + None, + None, + None, true, ); } #[test] fn rv32_alu_and_wrong_negative_test() { - run_rv32_alu_negative_test( - BaseAluOpcode::AND, + run_negative_alu_test( + AND, [255, 255, 255, 255], [0, 0, 1, 0], [0, 0, 0, 0], + None, + None, + None, true, ); } +#[test] +fn rv32_alu_adapter_unconstrained_imm_limb_test() { + run_negative_alu_test( + ADD, + [255, 7, 0, 0], + [0, 0, 0, 0], + [255, 7, 0, 0], + Some([511, 6, 0, 0]), + None, + Some(true), + true, + ); +} + +#[test] +fn rv32_alu_adapter_unconstrained_rs2_read_test() { + run_negative_alu_test( + ADD, + [2, 2, 2, 2], + [1, 1, 1, 1], + [1, 1, 1, 1], + None, + Some([false, false, false, false, false]), + Some(false), + false, + ); +} + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -276,10 +356,10 @@ fn rv32_alu_and_wrong_negative_test() { #[test] fn run_add_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; - let result = run_alu::(BaseAluOpcode::ADD, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [23, 205, 73, 49]; + let result = run_alu::(ADD, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -287,10 +367,10 @@ fn run_add_sanity_test() { #[test] fn run_sub_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; - let result = run_alu::(BaseAluOpcode::SUB, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [179, 118, 240, 172]; + let result = run_alu::(SUB, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -298,10 +378,10 @@ fn run_sub_sanity_test() { #[test] fn run_xor_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; - let result = run_alu::(BaseAluOpcode::XOR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [215, 138, 49, 173]; + let result = run_alu::(XOR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -309,10 +389,10 @@ fn run_xor_sanity_test() { #[test] fn run_or_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; - let result = run_alu::(BaseAluOpcode::OR, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [247, 171, 61, 239]; + let result = run_alu::(OR, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -320,195 +400,11 @@ fn run_or_sanity_test() { #[test] fn run_and_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; - let result = run_alu::(BaseAluOpcode::AND, &x, &y); + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [229, 33, 29, 111]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [50, 171, 44, 194]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [32, 33, 12, 66]; + let result = run_alu::(AND, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } } - -////////////////////////////////////////////////////////////////////////////////////// -// ADAPTER TESTS -// -// Ensure that the adapter is correct. -////////////////////////////////////////////////////////////////////////////////////// - -// A pranking chip where `preprocess` can have `rs2` limbs that overflow. -struct Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip); - -impl VmAdapterChip for Rv32BaseAluAdapterTestChip { - type ReadRecord = Rv32BaseAluReadRecord; - type WriteRecord = Rv32BaseAluWriteRecord; - type Air = Rv32BaseAluAdapterAir; - type Interface = BasicAdapterInterface< - F, - MinimalInstruction, - 2, - 1, - RV32_REGISTER_NUM_LIMBS, - RV32_REGISTER_NUM_LIMBS, - >; - - fn preprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - ) -> Result<( - >::Reads, - Self::ReadRecord, - )> { - let Instruction { b, c, d, e, .. } = *instruction; - - let rs1 = memory.read::(d, b); - let (rs2, rs2_data, rs2_imm) = if e.is_zero() { - let c_u32 = c.as_canonical_u32(); - memory.increment_timestamp(); - let mask1 = (1 << 9) - 1; - let mask2 = (1 << 3) - 2; - ( - None, - [ - (c_u32 & mask1) as u16, - ((c_u32 >> 8) & mask2) as u16, - (c_u32 >> 16) as u16, - (c_u32 >> 16) as u16, - ] - .map(F::from_canonical_u16), - c, - ) - } else { - let rs2_read = memory.read::(e, c); - (Some(rs2_read.0), rs2_read.1, F::ZERO) - }; - - Ok(( - [rs1.1, rs2_data], - Self::ReadRecord { - rs1: rs1.0, - rs2, - rs2_imm, - }, - )) - } - - fn postprocess( - &mut self, - memory: &mut MemoryController, - instruction: &Instruction, - from_state: ExecutionState, - output: AdapterRuntimeContext, - _read_record: &Self::ReadRecord, - ) -> Result<(ExecutionState, Self::WriteRecord)> { - self.0 - .postprocess(memory, instruction, from_state, output, _read_record) - } - - fn generate_trace_row( - &self, - row_slice: &mut [F], - read_record: Self::ReadRecord, - write_record: Self::WriteRecord, - memory: &OfflineMemory, - ) { - self.0 - .generate_trace_row(row_slice, read_record, write_record, memory) - } - - fn air(&self) -> &Self::Air { - self.0.air() - } -} - -#[test] -fn rv32_alu_adapter_unconstrained_imm_limb_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = VmChipWrapper::new( - Rv32BaseAluAdapterTestChip(Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - )), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [0, 0, 0, 0]; - let (c_imm, c) = { - let imm = (1 << 11) - 1; - let fake_c = [(1 << 9) - 1, (1 << 3) - 2, 0, 0]; - (Some(imm), fake_c) - }; - - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - disable_debug_builder(); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test_with_expected_error(VerificationError::ChallengePhaseError); -} - -#[test] -fn rv32_alu_adapter_unconstrained_rs2_read_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BaseAluChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - BaseAluCoreChip::new(bitwise_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - let b = [1, 1, 1, 1]; - let c = [1, 1, 1, 1]; - let (instruction, _rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - BaseAluOpcode::ADD.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - - let modify_trace = |trace: &mut DenseMatrix| { - let mut values = trace.row_slice(0).to_vec(); - let mut dummy_values = values.clone(); - let cols: &mut BaseAluCoreCols = - dummy_values.split_at_mut(adapter_width).1.borrow_mut(); - cols.opcode_add_flag = F::ZERO; - values.extend(dummy_values); - *trace = RowMajorMatrix::new(values, trace_width); - }; - - disable_debug_builder(); - let tester = tester - .build() - .load_and_prank_trace(chip, modify_trace) - .load(bitwise_chip) - .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); -} diff --git a/extensions/rv32im/circuit/src/branch_eq/core.rs b/extensions/rv32im/circuit/src/branch_eq/core.rs index bb04d86ee5..91547cf3f1 100644 --- a/extensions/rv32im/circuit/src/branch_eq/core.rs +++ b/extensions/rv32im/circuit/src/branch_eq/core.rs @@ -3,9 +3,16 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ImmInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::utils::not; use openvm_circuit_primitives_derive::AlignedBorrow; @@ -17,8 +24,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -37,7 +42,7 @@ pub struct BranchEqualCoreCols { pub diff_inv_marker: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct BranchEqualCoreAir { offset: usize, pc_step: u32, @@ -134,116 +139,152 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchEqualCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub imm: T, - pub diff_inv_val: T, - pub diff_idx: usize, - pub opcode: BranchEqualOpcode, +pub struct BranchEqualStep { + adapter: A, + pub offset: usize, + pub pc_step: u32, } -#[derive(Debug)] -pub struct BranchEqualCoreChip { - pub air: BranchEqualCoreAir, -} - -impl BranchEqualCoreChip { - pub fn new(offset: usize, pc_step: u32) -> Self { +impl BranchEqualStep { + pub fn new(adapter: A, offset: usize, pc_step: u32) -> Self { Self { - air: BranchEqualCoreAir { offset, pc_step }, + adapter, + offset, + pc_step, } } } -impl, const NUM_LIMBS: usize> VmCoreChip - for BranchEqualCoreChip +impl TraceStep for BranchEqualStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData = (), + TraceContext<'a> = (), + >, { - type Record = BranchEqualCoreRecord; - type Air = BranchEqualCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", BranchEqualOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let branch_eq_opcode = - BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let x = data[0].map(|x| x.as_canonical_u32()); - let y = data[1].map(|y| y.as_canonical_u32()); - let (cmp_result, diff_idx, diff_inv_val) = run_eq::(branch_eq_opcode, &x, &y); - - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchEqualCoreRecord { - opcode: branch_eq_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - imm, - diff_idx, - diff_inv_val, - }; - - Ok((output, record)) + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let (cmp_result, diff_idx, diff_inv_val) = run_eq(branch_eq_opcode, &rs1, &rs2); + + let core_row: &mut BranchEqualCoreCols<_, NUM_LIMBS> = core_row.borrow_mut(); + core_row.a = rs1.map(F::from_canonical_u8); + core_row.b = rs2.map(F::from_canonical_u8); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.imm = imm; + core_row.opcode_beq_flag = F::from_bool(branch_eq_opcode == BranchEqualOpcode::BEQ); + core_row.opcode_bne_flag = F::from_bool(branch_eq_opcode == BranchEqualOpcode::BNE); + core_row.diff_inv_marker = + array::from_fn(|i| if i == diff_idx { diff_inv_val } else { F::ZERO }); + + if cmp_result { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchEqualOpcode::from_usize(opcode - self.air.offset) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchEqualCoreCols<_, NUM_LIMBS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.imm = record.imm; - row_slice.opcode_beq_flag = F::from_bool(record.opcode == BranchEqualOpcode::BEQ); - row_slice.opcode_bne_flag = F::from_bool(record.opcode == BranchEqualOpcode::BNE); - row_slice.diff_inv_marker = array::from_fn(|i| { - if i == record.diff_idx { - record.diff_inv_val - } else { - F::ZERO - } - }); +impl StepExecutorE1 for BranchEqualStep +where + F: PrimeField32, + A: 'static + for<'a> AdapterExecutorE1, WriteData = ()>, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let branch_eq_opcode = BranchEqualOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + + // TODO(ayush): probably don't need the other values + let (cmp_result, _, _) = run_eq::(branch_eq_opcode, &rs1, &rs2); + + if cmp_result { + // TODO(ayush): verify this is fine + // state.pc = state.pc.wrapping_add(imm.as_canonical_u32()); + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(self.pc_step); + } + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // Returns (cmp_result, diff_idx, x[diff_idx] - y[diff_idx]) -pub(super) fn run_eq( +#[inline(always)] +pub(super) fn run_eq( local_opcode: BranchEqualOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> (bool, usize, F) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> (bool, usize, F) +where + F: PrimeField32, +{ for i in 0..NUM_LIMBS { if x[i] != y[i] { return ( local_opcode == BranchEqualOpcode::BNE, i, - (F::from_canonical_u32(x[i]) - F::from_canonical_u32(y[i])).inverse(), + (F::from_canonical_u8(x[i]) - F::from_canonical_u8(y[i])).inverse(), ); } } diff --git a/extensions/rv32im/circuit/src/branch_eq/mod.rs b/extensions/rv32im/circuit/src/branch_eq/mod.rs index 7d53946a73..9172197e2a 100644 --- a/extensions/rv32im/circuit/src/branch_eq/mod.rs +++ b/extensions/rv32im/circuit/src/branch_eq/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; use super::adapters::RV32_REGISTER_NUM_LIMBS; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterStep}; mod core; pub use core::*; @@ -9,5 +9,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32BranchEqualChip = - VmChipWrapper, BranchEqualCoreChip>; +pub type Rv32BranchEqualAir = + VmAirWrapper>; +pub type Rv32BranchEqualStep = BranchEqualStep; +pub type Rv32BranchEqualChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/branch_eq/tests.rs b/extensions/rv32im/circuit/src/branch_eq/tests.rs index c16858b071..e5da1f901f 100644 --- a/extensions/rv32im/circuit/src/branch_eq/tests.rs +++ b/extensions/rv32im/circuit/src/branch_eq/tests.rs @@ -1,11 +1,14 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, + testing::{memory::gen_pointer, VmChipTestBuilder}, + InstructionExecutor, VmAirWrapper, +}; +use openvm_instructions::{ + instruction::Instruction, + program::{DEFAULT_PC_STEP, PC_BITS}, + LocalOpcode, }; -use openvm_instructions::{instruction::Instruction, program::PC_BITS, LocalOpcode}; use openvm_rv32im_transpiler::BranchEqualOpcode; use openvm_stark_backend::{ p3_air::BaseAir, @@ -15,42 +18,67 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::{run_eq, BranchEqualCoreChip}, + core::{run_eq, BranchEqualStep}, BranchEqualCoreCols, Rv32BranchEqualChip, }; -use crate::adapters::{Rv32BranchAdapterChip, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS}; +use crate::{ + adapters::{ + Rv32BranchAdapterAir, Rv32BranchAdapterStep, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + }, + test_utils::get_verification_error, + BranchEqualCoreAir, +}; type F = BabyBear; - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); + +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32BranchEqualChip { + Rv32BranchEqualChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + BranchEqualStep::new( + Rv32BranchAdapterStep::new(), + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_eq_rand_execute>( +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, rng: &mut StdRng, + opcode: BranchEqualOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); + let initial_pc = rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))); tester.execute_with_pc( chip, &Instruction::from_isize( @@ -61,7 +89,7 @@ fn run_rv32_branch_eq_rand_execute>( 1, 1, ), - rng.gen_range(imm.unsigned_abs()..(1 << (PC_BITS - 1))), + initial_pc, ); let (cmp_result, _, _) = run_eq::(opcode, &a, &b); @@ -72,94 +100,71 @@ fn run_rv32_branch_eq_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_eq_rand_test(opcode: BranchEqualOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchEqualOpcode::BEQ, 100)] +#[test_case(BranchEqualOpcode::BNE, 100)] +fn rand_rv32_branch_eq_test(opcode: BranchEqualOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)); - let b = if rng.gen_bool(0.5) { - a - } else { - array::from_fn(|_| rng.gen_range(0..F::ORDER_U32)) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_eq_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_beq_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BEQ, 100); -} - -#[test] -fn rv32_bne_rand_test() { - run_rv32_branch_eq_rand_test(BranchEqualOpcode::BNE, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchEqualTestChip = - VmChipWrapper, BranchEqualCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_beq_negative_test( +fn run_negative_branch_eq_test( opcode: BranchEqualOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, - diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: Option, + prank_diff_inv_marker: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { - let imm = 16u32; + let imm = 16i32; + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchEqualTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, 4), - tester.offline_memory_mutex_arc(), - ); + let mut chip = create_test_chip(&mut tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut BranchEqualCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.cmp_result = F::from_bool(cmp_result); - if let Some(diff_inv_marker) = diff_inv_marker { + if let Some(cmp_result) = prank_cmp_result { + cols.cmp_result = F::from_bool(cmp_result); + } + if let Some(diff_inv_marker) = prank_diff_inv_marker { cols.diff_inv_marker = diff_inv_marker.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -167,88 +172,96 @@ fn run_rv32_beq_negative_test( .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_beq_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), None, + false, ); } #[test] fn rv32_beq_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 0, 7], - true, + Some(true), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_beq_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BEQ, [0, 0, 7, 0], [0, 0, 7, 0], - false, + Some(false), Some([0, 0, 1, 0]), + false, ); } #[test] fn rv32_bne_wrong_cmp_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), None, + false, ); - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), None, + false, ); } #[test] fn rv32_bne_zero_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 0, 7], - false, + Some(false), Some([0, 0, 0, 0]), + false, ); } #[test] fn rv32_bne_invalid_inv_marker_negative_test() { - run_rv32_beq_negative_test( + run_negative_branch_eq_test( BranchEqualOpcode::BNE, [0, 0, 7, 0], [0, 0, 7, 0], - true, + Some(true), Some([0, 0, 1, 0]), + false, ); } @@ -259,38 +272,37 @@ fn rv32_bne_invalid_inv_marker_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let core = - BranchEqualCoreChip::::new(BranchEqualOpcode::CLASS_OFFSET, 4); - - let mut instruction = Instruction:: { - opcode: BranchEqualOpcode::BEQ.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60].map(F::from_canonical_u32); - let y: [F; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchEqualOpcode::BNE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, y]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let mut chip = create_test_chip(&mut tester); + + let x = [19, 4, 179, 60]; + let y = [19, 32, 180, 60]; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchEqualOpcode::BEQ, + Some(x), + Some(y), + Some(8), + ); + + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchEqualOpcode::BNE, + Some(x), + Some(y), + Some(8), + ); } #[test] fn run_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; let (cmp_result, _, diff_val) = run_eq::(BranchEqualOpcode::BEQ, &x, &x); assert!(cmp_result); @@ -304,13 +316,13 @@ fn run_eq_sanity_test() { #[test] fn run_ne_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 4, 1790, 60]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [19, 32, 1804, 60]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 4, 17, 60]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [19, 32, 18, 60]; let (cmp_result, diff_idx, diff_val) = run_eq::(BranchEqualOpcode::BEQ, &x, &y); assert!(!cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); @@ -318,7 +330,7 @@ fn run_ne_sanity_test() { run_eq::(BranchEqualOpcode::BNE, &x, &y); assert!(cmp_result); assert_eq!( - diff_val * (F::from_canonical_u32(x[diff_idx]) - F::from_canonical_u32(y[diff_idx])), + diff_val * (F::from_canonical_u8(x[diff_idx]) - F::from_canonical_u8(y[diff_idx])), F::ONE ); } diff --git a/extensions/rv32im/circuit/src/branch_lt/core.rs b/extensions/rv32im/circuit/src/branch_lt/core.rs index 3eebb02146..9a777b2d1a 100644 --- a/extensions/rv32im/circuit/src/branch_lt/core.rs +++ b/extensions/rv32im/circuit/src/branch_lt/core.rs @@ -3,9 +3,16 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ImmInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, @@ -20,8 +27,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -53,7 +58,7 @@ pub struct BranchLessThanCoreCols { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -187,67 +192,72 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -pub struct BranchLessThanCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - pub cmp_result: T, - pub cmp_lt: T, - pub imm: T, - pub a_msb_f: T, - pub b_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: BranchLessThanOpcode, -} - -pub struct BranchLessThanCoreChip { - pub air: BranchLessThanCoreAir, +pub struct BranchLessThanStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl BranchLessThanCoreChip { +impl + BranchLessThanStep +{ pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { - air: BranchLessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, } } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for BranchLessThanCoreChip +impl TraceStep + for BranchLessThanStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: Default, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData = (), + TraceContext<'a> = (), + >, { - type Record = BranchLessThanCoreRecord; - type Air = BranchLessThanCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + BranchLessThanOpcode::from_usize(opcode - self.offset) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let Instruction { opcode, c: imm, .. } = *instruction; - let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let a = data[0].map(|x| x.as_canonical_u32()); - let b = data[1].map(|y| y.as_canonical_u32()); + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let &Instruction { opcode, c: imm, .. } = instruction; + + let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + let (cmp_result, diff_idx, a_sign, b_sign) = - run_cmp::(blt_opcode, &a, &b); + run_cmp::(blt_opcode, &rs1, &rs2); let signed = matches!( blt_opcode, @@ -259,6 +269,9 @@ where ); let cmp_lt = cmp_result ^ ge_opcode; + let a = rs1.map(u32::from); + let b = rs2.map(u32::from); + // We range check (a_msb_f + 128) and (b_msb_f + 128) if signed, // a_msb_f and b_msb_f if not let (a_msb_f, a_msb_range) = if a_sign { @@ -283,8 +296,6 @@ where b[NUM_LIMBS - 1] + ((signed as u32) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(a_msb_range, b_msb_range); let diff_val = if diff_idx == NUM_LIMBS { 0 @@ -301,65 +312,98 @@ where a[diff_idx] - b[diff_idx] }; + let core_row: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + core_row.a = rs1.map(F::from_canonical_u8); + core_row.b = rs2.map(F::from_canonical_u8); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.cmp_lt = F::from_bool(cmp_lt); + core_row.imm = imm; + core_row.a_msb_f = a_msb_f; + core_row.b_msb_f = b_msb_f; + core_row.diff_marker = array::from_fn(|i| F::from_bool(i == diff_idx)); + core_row.diff_val = F::from_canonical_u32(diff_val); + core_row.opcode_blt_flag = F::from_bool(blt_opcode == BranchLessThanOpcode::BLT); + core_row.opcode_bltu_flag = F::from_bool(blt_opcode == BranchLessThanOpcode::BLTU); + core_row.opcode_bge_flag = F::from_bool(blt_opcode == BranchLessThanOpcode::BGE); + core_row.opcode_bgeu_flag = F::from_bool(blt_opcode == BranchLessThanOpcode::BGEU); + + if cmp_result { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + // TODO(ayush): move to fill_trace_row + self.bitwise_lookup_chip + .request_range(a_msb_range, b_msb_range); + if diff_idx != NUM_LIMBS { self.bitwise_lookup_chip.request_range(diff_val - 1, 0); } - let output = AdapterRuntimeContext { - to_pc: cmp_result.then_some((F::from_canonical_u32(from_pc) + imm).as_canonical_u32()), - writes: Default::default(), - }; - let record = BranchLessThanCoreRecord { - opcode: blt_opcode, - a: data[0], - b: data[1], - cmp_result: F::from_bool(cmp_result), - cmp_lt: F::from_bool(cmp_lt), - imm, - a_msb_f, - b_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; + *trace_offset += width; - Ok((output, record)) + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - BranchLessThanOpcode::from_usize(opcode - self.air.offset) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } +} + +impl StepExecutorE1 + for BranchLessThanStep +where + F: PrimeField32, + A: 'static + for<'a> AdapterExecutorE1, WriteData = ()>, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { opcode, c: imm, .. } = instruction; + + let blt_opcode = BranchLessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut BranchLessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.cmp_result = record.cmp_result; - row_slice.cmp_lt = record.cmp_lt; - row_slice.imm = record.imm; - row_slice.a_msb_f = record.a_msb_f; - row_slice.b_msb_f = record.b_msb_f; - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); - row_slice.diff_val = record.diff_val; - row_slice.opcode_blt_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLT); - row_slice.opcode_bltu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BLTU); - row_slice.opcode_bge_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGE); - row_slice.opcode_bgeu_flag = F::from_bool(record.opcode == BranchLessThanOpcode::BGEU); + // TODO(ayush): probably don't need the other values + let (cmp_result, _, _, _) = run_cmp::(blt_opcode, &rs1, &rs2); + + if cmp_result { + *state.pc = (F::from_canonical_u32(*state.pc) + imm).as_canonical_u32(); + } else { + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + } + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_cmp( local_opcode: BranchLessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { let signed = local_opcode == BranchLessThanOpcode::BLT || local_opcode == BranchLessThanOpcode::BGE; diff --git a/extensions/rv32im/circuit/src/branch_lt/mod.rs b/extensions/rv32im/circuit/src/branch_lt/mod.rs index b0bf8fc417..dba3751be2 100644 --- a/extensions/rv32im/circuit/src/branch_lt/mod.rs +++ b/extensions/rv32im/circuit/src/branch_lt/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32BranchAdapterChip; +use crate::adapters::{Rv32BranchAdapterAir, Rv32BranchAdapterStep}; mod core; pub use core::*; @@ -9,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32BranchLessThanChip = VmChipWrapper< - F, - Rv32BranchAdapterChip, - BranchLessThanCoreChip, +pub type Rv32BranchLessThanAir = VmAirWrapper< + Rv32BranchAdapterAir, + BranchLessThanCoreAir, >; +pub type Rv32BranchLessThanStep = + BranchLessThanStep; +pub type Rv32BranchLessThanChip = + NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/branch_lt/tests.rs b/extensions/rv32im/circuit/src/branch_lt/tests.rs index 8c1d7f697a..2c84f4b1ea 100644 --- a/extensions/rv32im/circuit/src/branch_lt/tests.rs +++ b/extensions/rv32im/circuit/src/branch_lt/tests.rs @@ -1,12 +1,11 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ - testing::{memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - BasicAdapterInterface, ExecutionBridge, ImmInstruction, InstructionExecutor, VmAdapterChip, - VmChipWrapper, VmCoreChip, + testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }, - utils::{generate_long_number, i32_to_f}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -21,46 +20,76 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::{ - core::{run_cmp, BranchLessThanCoreChip}, + core::{run_cmp, BranchLessThanStep}, Rv32BranchLessThanChip, }; use crate::{ adapters::{ - Rv32BranchAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_B_TYPE_IMM_BITS, + Rv32BranchAdapterAir, Rv32BranchAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + RV_B_TYPE_IMM_BITS, }, branch_lt::BranchLessThanCoreCols, + test_utils::get_verification_error, + BranchLessThanCoreAir, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; +const ABS_MAX_IMM: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32BranchLessThanChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let chip = Rv32BranchLessThanChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + BranchLessThanCoreAir::new(bitwise_bus, BranchLessThanOpcode::CLASS_OFFSET), + ), + BranchLessThanStep::new( + Rv32BranchAdapterStep::new(), + bitwise_chip.clone(), + BranchLessThanOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_branch_lt_rand_execute>( +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - imm: i32, rng: &mut StdRng, + opcode: BranchLessThanOpcode, + a: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + imm: Option, ) { + let a = a.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let b = b.unwrap_or(if rng.gen_bool(0.5) { + a + } else { + array::from_fn(|_| rng.gen_range(0..=u8::MAX)) + }); + + let imm = imm.unwrap_or(rng.gen_range((-ABS_MAX_IMM)..ABS_MAX_IMM)); let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - tester.write::(1, rs1, a.map(F::from_canonical_u32)); - tester.write::(1, rs2, b.map(F::from_canonical_u32)); + tester.write::(1, rs1, a.map(F::from_canonical_u8)); + tester.write::(1, rs2, b.map(F::from_canonical_u8)); tester.execute_with_pc( chip, @@ -83,93 +112,57 @@ fn run_rv32_branch_lt_rand_execute>( assert_eq!(to_pc, from_pc + pc_inc); } -fn run_rv32_branch_lt_rand_test(opcode: BranchLessThanOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); - const ABS_MAX_BRANCH: i32 = 1 << (RV_B_TYPE_IMM_BITS - 1); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(BranchLessThanOpcode::BLT, 100)] +#[test_case(BranchLessThanOpcode::BLTU, 100)] +#[test_case(BranchLessThanOpcode::BGE, 100)] +#[test_case(BranchLessThanOpcode::BGEU, 100)] +fn rand_branch_lt_test(opcode: BranchLessThanOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanChip::::new( - Rv32BranchAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); for _ in 0..num_ops { - let a = generate_long_number::(&mut rng); - let b = if rng.gen_bool(0.5) { - a - } else { - generate_long_number::(&mut rng) - }; - let imm = rng.gen_range((-ABS_MAX_BRANCH)..ABS_MAX_BRANCH); - run_rv32_branch_lt_rand_execute(&mut tester, &mut chip, opcode, a, b, imm, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } // Test special case where b = c - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, &mut chip, - opcode, - [101, 128, 202, 255], - [101, 128, 202, 255], - 24, &mut rng, + opcode, + Some([101, 128, 202, 255]), + Some([101, 128, 202, 255]), + Some(24), ); - run_rv32_branch_lt_rand_execute( + set_and_execute( &mut tester, &mut chip, - opcode, - [36, 0, 0, 0], - [36, 0, 0, 0], - 24, &mut rng, + opcode, + Some([36, 0, 0, 0]), + Some([36, 0, 0, 0]), + Some(24), ); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_blt_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLT, 10); -} - -#[test] -fn rv32_bltu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BLTU, 12); -} - -#[test] -fn rv32_bge_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGE, 12); -} - -#[test] -fn rv32_bgeu_rand_test() { - run_rv32_branch_lt_rand_test(BranchLessThanOpcode::BGEU, 12); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32BranchLessThanTestChip = VmChipWrapper< - F, - TestAdapterChip, - BranchLessThanCoreChip, ->; - #[derive(Clone, Copy, Default, PartialEq)] struct BranchLessThanPrankValues { pub a_msb: Option, @@ -179,66 +172,31 @@ struct BranchLessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_blt_negative_test( +fn run_negative_branch_lt_test( opcode: BranchLessThanOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + a: [u8; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: BranchLessThanPrankValues, interaction_error: bool, ) { - let imm = 16u32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32BranchLessThanTestChip::::new( - TestAdapterChip::new( - vec![[a.map(F::from_canonical_u32), b.map(F::from_canonical_u32)].concat()], - vec![if cmp_result { Some(imm) } else { None }], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - BranchLessThanCoreChip::new(bitwise_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let imm = 16i32; + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, imm as usize, 1, 1]), + &mut rng, + opcode, + Some(a), + Some(b), + Some(imm), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); + let adapter_width = BaseAir::::width(&chip.air.adapter); let ge_opcode = opcode == BranchLessThanOpcode::BGE || opcode == BranchLessThanOpcode::BGEU; - let (_, _, a_sign, b_sign) = run_cmp::(opcode, &a, &b); - - if prank_vals != BranchLessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let a_msb = prank_vals.a_msb.unwrap_or( - a[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if a_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let signed_offset = match opcode { - BranchLessThanOpcode::BLT | BranchLessThanOpcode::BGE => 1 << (RV32_CELL_BITS - 1), - _ => 0, - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (a_msb + signed_offset) as u8 as u32, - (b_msb + signed_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - } let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); @@ -257,10 +215,10 @@ fn run_rv32_blt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); - cols.cmp_lt = F::from_bool(ge_opcode ^ cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); + cols.cmp_lt = F::from_bool(ge_opcode ^ prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -269,11 +227,7 @@ fn run_rv32_blt_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -281,10 +235,10 @@ fn rv32_blt_wrong_lt_cmp_negative_test() { let a = [145, 34, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -292,10 +246,10 @@ fn rv32_blt_wrong_ge_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -303,10 +257,10 @@ fn rv32_blt_wrong_eq_cmp_negative_test() { let a = [73, 35, 25, 205]; let b = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -317,10 +271,10 @@ fn rv32_blt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -332,10 +286,10 @@ fn rv32_blt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } #[test] @@ -347,10 +301,10 @@ fn rv32_blt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -362,10 +316,10 @@ fn rv32_blt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -378,8 +332,8 @@ fn rv32_blt_signed_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, false); } #[test] @@ -392,8 +346,8 @@ fn rv32_blt_signed_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, true, prank_vals, true); } #[test] @@ -406,8 +360,8 @@ fn rv32_blt_signed_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, false); } #[test] @@ -420,8 +374,8 @@ fn rv32_blt_signed_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLT, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGE, a, b, false, prank_vals, true); } #[test] @@ -434,8 +388,8 @@ fn rv32_blt_unsigned_wrong_a_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, false); } #[test] @@ -448,8 +402,8 @@ fn rv32_blt_unsigned_wrong_a_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, false, prank_vals, true); } #[test] @@ -462,8 +416,8 @@ fn rv32_blt_unsigned_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, false); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, false); } #[test] @@ -476,8 +430,8 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_blt_negative_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); - run_rv32_blt_negative_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BLTU, a, b, false, prank_vals, true); + run_negative_branch_lt_test(BranchLessThanOpcode::BGEU, a, b, true, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -487,42 +441,37 @@ fn rv32_blt_unsigned_wrong_b_msb_sign_negative_test() { /////////////////////////////////////////////////////////////////////////////////////// #[test] -fn execute_pc_increment_sanity_test() { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let core = BranchLessThanCoreChip::::new( - bitwise_chip, - BranchLessThanOpcode::CLASS_OFFSET, +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chip(&mut tester); + + let x = [145, 34, 25, 205]; + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BLT, + Some(x), + Some(x), + Some(8), ); - let mut instruction = Instruction:: { - opcode: BranchLessThanOpcode::BLT.global_opcode(), - c: F::from_canonical_u8(8), - ..Default::default() - }; - let x: [F; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205].map(F::from_canonical_u32); - - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_none()); - - instruction.opcode = BranchLessThanOpcode::BGE.global_opcode(); - let result = as VmCoreChip< - F, - BasicAdapterInterface, 2, 0, RV32_REGISTER_NUM_LIMBS, 0>, - >>::execute_instruction(&core, &instruction, 0, [x, x]); - let (output, _) = result.expect("execute_instruction failed"); - assert!(output.to_pc.is_some()); - assert_eq!(output.to_pc.unwrap(), 8); + set_and_execute( + &mut tester, + &mut chip, + &mut rng, + BranchLessThanOpcode::BGE, + Some(x), + Some(x), + Some(8), + ); } #[test] fn run_cmp_unsigned_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::(BranchLessThanOpcode::BLTU, &x, &y); assert!(cmp_result); @@ -540,8 +489,8 @@ fn run_cmp_unsigned_sanity_test() { #[test] fn run_cmp_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::(BranchLessThanOpcode::BLT, &x, &y); assert!(cmp_result); @@ -559,8 +508,8 @@ fn run_cmp_same_sign_sanity_test() { #[test] fn run_cmp_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::(BranchLessThanOpcode::BLT, &x, &y); assert!(!cmp_result); @@ -578,7 +527,7 @@ fn run_cmp_diff_sign_sanity_test() { #[test] fn run_cmp_eq_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = run_cmp::(BranchLessThanOpcode::BLT, &x, &x); assert!(!cmp_result); diff --git a/extensions/rv32im/circuit/src/divrem/core.rs b/extensions/rv32im/circuit/src/divrem/core.rs index bad043d582..8850681e31 100644 --- a/extensions/rv32im/circuit/src/divrem/core.rs +++ b/extensions/rv32im/circuit/src/divrem/core.rs @@ -5,9 +5,16 @@ use std::{ use num_bigint::BigUint; use num_integer::Integer; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, @@ -15,7 +22,7 @@ use openvm_circuit_primitives::{ utils::{not, select}, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::DivRemOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -23,8 +30,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -67,7 +72,7 @@ pub struct DivRemCoreCols { pub opcode_remu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct DivRemCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -342,14 +347,24 @@ where } } -pub struct DivRemCoreChip { - pub air: DivRemCoreAir, +#[derive(Debug, Eq, PartialEq)] +#[repr(u8)] +pub(super) enum DivRemCoreSpecialCase { + None, + ZeroDivisor, + SignedOverflow, +} + +pub struct DivRemStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl DivRemCoreChip { +impl DivRemStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize, @@ -369,82 +384,62 @@ impl DivRemCoreChip { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub q: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r: [T; NUM_LIMBS], - pub zero_divisor: T, - pub r_zero: T, - pub b_sign: T, - pub c_sign: T, - pub q_sign: T, - pub sign_xor: T, - pub c_sum_inv: T, - pub r_sum_inv: T, - #[serde(with = "BigArray")] - pub r_prime: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub r_inv: [T; NUM_LIMBS], - pub lt_diff_val: T, - pub lt_diff_idx: usize, - pub opcode: DivRemOpcode, -} - -#[derive(Debug, Eq, PartialEq)] -#[repr(u8)] -pub(super) enum DivRemCoreSpecialCase { - None, - ZeroDivisor, - SignedOverflow, -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for DivRemCoreChip +impl TraceStep + for DivRemStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = DivRemCoreRecord; - type Air = DivRemCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", DivRemOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; - let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; + let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM; + let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); + let b = rs1.map(u32::from); + let c = rs2.map(u32::from); let (q, r, b_sign, c_sign, q_sign, case) = run_divrem::(is_signed, &b, &c); + // TODO(ayush): move parts to fill_trace_row let carries = run_mul_carries::(is_signed, &c, &q, &r, q_sign); for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[q[i], carries[i]]); @@ -469,7 +464,7 @@ where ); } - let c_sum_f = data[1].iter().fold(F::ZERO, |acc, c| acc + *c); + let c_sum_f = F::from_canonical_u32(c.iter().sum()); let c_sum_inv_f = c_sum_f.try_inverse().unwrap_or(F::ZERO); let r_sum_f = r @@ -491,67 +486,116 @@ where }; let r_prime_f = r_prime.map(F::from_canonical_u32); - let output = AdapterRuntimeContext::without_pc([ - (if is_div { &q } else { &r }).map(F::from_canonical_u32) - ]); - let record = DivRemCoreRecord { - opcode: divrem_opcode, - b: data[0], - c: data[1], - q: q.map(F::from_canonical_u32), - r: r.map(F::from_canonical_u32), - zero_divisor: F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor), - r_zero: F::from_bool(r_zero), - b_sign: F::from_bool(b_sign), - c_sign: F::from_bool(c_sign), - q_sign: F::from_bool(q_sign), - sign_xor: F::from_bool(sign_xor), - c_sum_inv: c_sum_inv_f, - r_sum_inv: r_sum_inv_f, - r_prime: r_prime_f, - r_inv: r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()), - lt_diff_val: F::from_canonical_u32(lt_diff_val), - lt_diff_idx, + + let core_row: &mut DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + core_row.b = rs1.map(F::from_canonical_u8); + core_row.c = rs2.map(F::from_canonical_u8); + core_row.q = q.map(F::from_canonical_u32); + core_row.r = r.map(F::from_canonical_u32); + core_row.zero_divisor = F::from_bool(case == DivRemCoreSpecialCase::ZeroDivisor); + core_row.r_zero = F::from_bool(r_zero); + core_row.b_sign = F::from_bool(b_sign); + core_row.c_sign = F::from_bool(c_sign); + core_row.q_sign = F::from_bool(q_sign); + core_row.sign_xor = F::from_bool(sign_xor); + core_row.c_sum_inv = c_sum_inv_f; + core_row.r_sum_inv = r_sum_inv_f; + core_row.r_prime = r_prime_f; + core_row.r_inv = r_prime_f.map(|r| (r - F::from_canonical_u32(256)).inverse()); + core_row.lt_marker = array::from_fn(|i| F::from_bool(i == lt_diff_idx)); + core_row.lt_diff = F::from_canonical_u32(lt_diff_val); + core_row.opcode_div_flag = F::from_bool(divrem_opcode == DivRemOpcode::DIV); + core_row.opcode_divu_flag = F::from_bool(divrem_opcode == DivRemOpcode::DIVU); + core_row.opcode_rem_flag = F::from_bool(divrem_opcode == DivRemOpcode::REM); + core_row.opcode_remu_flag = F::from_bool(divrem_opcode == DivRemOpcode::REMU); + + let rd = if is_div { + q.map(|x| x as u8) + } else { + r.map(|x| x as u8) }; - Ok((output, record)) + self.adapter + .write(state.memory, instruction, adapter_row, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", DivRemOpcode::from_usize(opcode - self.air.offset)) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut DivRemCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.q = record.q; - row_slice.r = record.r; - row_slice.zero_divisor = record.zero_divisor; - row_slice.r_zero = record.r_zero; - row_slice.b_sign = record.b_sign; - row_slice.c_sign = record.c_sign; - row_slice.q_sign = record.q_sign; - row_slice.sign_xor = record.sign_xor; - row_slice.c_sum_inv = record.c_sum_inv; - row_slice.r_sum_inv = record.r_sum_inv; - row_slice.r_prime = record.r_prime; - row_slice.r_inv = record.r_inv; - row_slice.lt_marker = array::from_fn(|i| F::from_bool(i == record.lt_diff_idx)); - row_slice.lt_diff = record.lt_diff_val; - row_slice.opcode_div_flag = F::from_bool(record.opcode == DivRemOpcode::DIV); - row_slice.opcode_divu_flag = F::from_bool(record.opcode == DivRemOpcode::DIVU); - row_slice.opcode_rem_flag = F::from_bool(record.opcode == DivRemOpcode::REM); - row_slice.opcode_remu_flag = F::from_bool(record.opcode == DivRemOpcode::REMU); +impl StepExecutorE1 + for DivRemStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = *instruction; + + // Determine opcode and operation type + let divrem_opcode = DivRemOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + let rs1 = rs1.map(u32::from); + let rs2 = rs2.map(u32::from); + + let is_div = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::DIVU; + let is_signed = divrem_opcode == DivRemOpcode::DIV || divrem_opcode == DivRemOpcode::REM; + + // Perform division/remainder computation + let (q, r, _, _, _, _) = run_divrem::(is_signed, &rs1, &rs2); + + // Determine result based on operation type (DIV or REM) + let rd = if is_div { + q.map(|x| x as u8) + } else { + r.map(|x| x as u8) + }; + + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // Returns (quotient, remainder, x_sign, y_sign, q_sign, case) where case = 0 for normal, 1 // for zero divisor, and 2 for signed overflow +#[inline(always)] pub(super) fn run_divrem( signed: bool, x: &[u32; NUM_LIMBS], @@ -628,6 +672,7 @@ pub(super) fn run_divrem( (q, r, x_sign, y_sign, q_sign, DivRemCoreSpecialCase::None) } +#[inline(always)] pub(super) fn run_sltu_diff_idx( x: &[u32; NUM_LIMBS], y: &[u32; NUM_LIMBS], @@ -644,6 +689,7 @@ pub(super) fn run_sltu_diff_idx( } // returns carries of d * q + r +#[inline(always)] pub(super) fn run_mul_carries( signed: bool, d: &[u32; NUM_LIMBS], @@ -684,6 +730,7 @@ pub(super) fn run_mul_carries( carry } +#[inline(always)] fn limbs_to_biguint( x: &[u32; NUM_LIMBS], ) -> BigUint { @@ -696,6 +743,7 @@ fn limbs_to_biguint( res } +#[inline(always)] fn biguint_to_limbs( x: &BigUint, ) -> [u32; NUM_LIMBS] { @@ -711,6 +759,7 @@ fn biguint_to_limbs( res } +#[inline(always)] fn negate( x: &[u32; NUM_LIMBS], ) -> [u32; NUM_LIMBS] { diff --git a/extensions/rv32im/circuit/src/divrem/mod.rs b/extensions/rv32im/circuit/src/divrem/mod.rs index 979ab38dc3..ab75cebf18 100644 --- a/extensions/rv32im/circuit/src/divrem/mod.rs +++ b/extensions/rv32im/circuit/src/divrem/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,8 +9,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32DivRemChip = VmChipWrapper< - F, - Rv32MultAdapterChip, - DivRemCoreChip, ->; +pub type Rv32DivRemAir = + VmAirWrapper>; +pub type Rv32DivRemStep = DivRemStep; +pub type Rv32DivRemChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/divrem/tests.rs b/extensions/rv32im/circuit/src/divrem/tests.rs index 41d8a9cc46..da775724d2 100644 --- a/extensions/rv32im/circuit/src/divrem/tests.rs +++ b/extensions/rv32im/circuit/src/divrem/tests.rs @@ -3,10 +3,9 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, + memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS, }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -15,7 +14,7 @@ use openvm_circuit_primitives::{ range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::DivRemOpcode; +use openvm_rv32im_transpiler::DivRemOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{Field, FieldAlgebra}, @@ -24,29 +23,26 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; use super::core::run_divrem; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, divrem::{ - run_mul_carries, run_sltu_diff_idx, DivRemCoreChip, DivRemCoreCols, DivRemCoreSpecialCase, + run_mul_carries, run_sltu_diff_idx, DivRemCoreCols, DivRemCoreSpecialCase, DivRemStep, Rv32DivRemChip, }, + test_utils::get_verification_error, + DivRemCoreAir, }; type F = BabyBear; - -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; fn limb_sra( x: [u32; NUM_LIMBS], @@ -57,15 +53,58 @@ fn limb_sra( array::from_fn(|i| if i + shift < NUM_LIMBS { x[i] } else { ext }) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32DivRemChip, + SharedBitwiseOperationLookupChip, + SharedRangeTupleCheckerChip<2>, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); + + let chip = Rv32DivRemChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + DivRemCoreAir::new(bitwise_bus, range_tuple_bus, DivRemOpcode::CLASS_OFFSET), + ), + DivRemStep::new( + Rv32MultAdapterStep::new(), + bitwise_chip.clone(), + range_tuple_chip.clone(), + DivRemOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip, range_tuple_chip) +} + #[allow(clippy::too_many_arguments)] -fn run_rv32_divrem_rand_write_execute>( - opcode: DivRemOpcode, +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], rng: &mut StdRng, + opcode: DivRemOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(limb_sra::( + generate_long_number::(rng), + rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)), + )); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -73,8 +112,8 @@ fn run_rv32_divrem_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let is_div = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::DIVU; - let is_signed = opcode == DivRemOpcode::DIV || opcode == DivRemOpcode::REM; + let is_div = opcode == DIV || opcode == DIVU; + let is_signed = opcode == DIV || opcode == REM; let (q, r, _, _, _, _) = run_divrem::(is_signed, &b, &c); @@ -89,136 +128,101 @@ fn run_rv32_divrem_rand_write_execute>( ); } -fn run_rv32_divrem_rand_test(opcode: DivRemOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(DIV, 100)] +#[test_case(DIVU, 100)] +#[test_case(REM, 100)] +#[test_case(REMU, 100)] +fn rand_divrem_test(opcode: DivRemOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_checker.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let leading_zeros = rng.gen_range(0..(RV32_REGISTER_NUM_LIMBS - 1)); - let c = limb_sra::( - generate_long_number::(&mut rng), - leading_zeros, - ); - run_rv32_divrem_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } // Test special cases in addition to random cases (i.e. zero divisor with b > 0, // zero divisor with b < 0, r = 0 (3 cases), and signed overflow). - run_rv32_divrem_rand_write_execute( - opcode, + set_and_execute( &mut tester, &mut chip, - [98, 188, 163, 127], - [0, 0, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 127]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [98, 188, 163, 229], - [0, 0, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([98, 188, 163, 229]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 128], - [0, 1, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 128]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 127], - [0, 1, 0, 0], &mut rng, - ); - run_rv32_divrem_rand_write_execute( opcode, + Some([0, 0, 0, 127]), + Some([0, 1, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 0], - [0, 0, 0, 0], &mut rng, + opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), ); - run_rv32_divrem_rand_write_execute( + set_and_execute( + &mut tester, + &mut chip, + &mut rng, opcode, + Some([0, 0, 0, 0]), + Some([0, 0, 0, 0]), + ); + set_and_execute( &mut tester, &mut chip, - [0, 0, 0, 128], - [255, 255, 255, 255], &mut rng, + opcode, + Some([0, 0, 0, 128]), + Some([255, 255, 255, 255]), ); let tester = tester .build() .load(chip) .load(bitwise_chip) - .load(range_tuple_checker) + .load(range_tuple_chip) .finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_div_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIV, 100); -} - -#[test] -fn rv32_divu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::DIVU, 100); -} - -#[test] -fn rv32_rem_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REM, 100); -} - -#[test] -fn rv32_remu_rand_test() { - run_rv32_divrem_rand_test(DivRemOpcode::REMU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32DivRemTestChip = - VmChipWrapper, DivRemCoreChip>; - #[derive(Default, Clone, Copy)] struct DivRemPrankValues { pub q: Option<[u32; NUM_LIMBS]>, @@ -229,84 +233,20 @@ struct DivRemPrankValues { pub r_zero: Option, } -fn run_rv32_divrem_negative_test( - signed: bool, +fn run_negative_divrem_test( + opcode: DivRemOpcode, b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - prank_vals: &DivRemPrankValues, + prank_vals: DivRemPrankValues, interaction_error: bool, ) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32DivRemTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat(); 2], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - DivRemCoreChip::new( - bitwise_chip.clone(), - range_tuple_chip.clone(), - DivRemOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); - - let (div_opcode, rem_opcode) = if signed { - (DivRemOpcode::DIV, DivRemOpcode::REM) - } else { - (DivRemOpcode::DIVU, DivRemOpcode::REMU) - }; - tester.execute( - &mut chip, - &Instruction::from_usize(div_opcode.global_opcode(), [0, 0, 0, 1, 1]), - ); - tester.execute( - &mut chip, - &Instruction::from_usize(rem_opcode.global_opcode(), [0, 0, 0, 1, 1]), - ); - - let (q, r, b_sign, c_sign, q_sign, case) = - run_divrem::(signed, &b, &c); - let q = prank_vals.q.unwrap_or(q); - let r = prank_vals.r.unwrap_or(r); - let carries = - run_mul_carries::(signed, &c, &q, &r, q_sign); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[q[i], carries[i]]); - range_tuple_chip.add_count(&[r[i], carries[i + RV32_REGISTER_NUM_LIMBS]]); - } - - if let Some(diff_val) = prank_vals.diff_val { - bitwise_chip.clear(); - if signed { - let b_sign_mask = if b_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - let c_sign_mask = if c_sign { 1 << (RV32_CELL_BITS - 1) } else { 0 }; - bitwise_chip.request_range( - (b[RV32_REGISTER_NUM_LIMBS - 1] - b_sign_mask) << 1, - (c[RV32_REGISTER_NUM_LIMBS - 1] - c_sign_mask) << 1, - ); - } - if case == DivRemCoreSpecialCase::None { - bitwise_chip.request_range(diff_val - 1, 0); - } - } + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut DivRemCoreCols = @@ -338,7 +278,7 @@ fn run_rv32_divrem_negative_test( cols.r_zero = F::from_bool(r_zero); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -348,11 +288,7 @@ fn run_rv32_divrem_negative_test( .load(bitwise_chip) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -363,7 +299,8 @@ fn rv32_divrem_unsigned_wrong_q_negative_test() { q: Some([245, 168, 7, 0]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -376,7 +313,8 @@ fn rv32_divrem_unsigned_wrong_r_negative_test() { diff_val: Some(31), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -387,7 +325,8 @@ fn rv32_divrem_unsigned_high_mult_negative_test() { q: Some([128, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -400,7 +339,8 @@ fn rv32_divrem_unsigned_zero_divisor_wrong_r_negative_test() { diff_val: Some(255), ..Default::default() }; - run_rv32_divrem_negative_test(false, b, c, &prank_vals, true); + run_negative_divrem_test(DIVU, b, c, prank_vals, true); + run_negative_divrem_test(REMU, b, c, prank_vals, true); } #[test] @@ -411,7 +351,8 @@ fn rv32_divrem_signed_wrong_q_negative_test() { q: Some([74, 61, 255, 255]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -424,7 +365,8 @@ fn rv32_divrem_signed_wrong_r_negative_test() { diff_val: Some(20), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -435,7 +377,8 @@ fn rv32_divrem_signed_high_mult_negative_test() { q: Some([1, 0, 0, 1]), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -449,7 +392,8 @@ fn rv32_divrem_signed_r_wrong_sign_negative_test() { diff_val: Some(192), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -463,7 +407,8 @@ fn rv32_divrem_signed_r_wrong_prime_negative_test() { diff_val: Some(36), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -476,7 +421,8 @@ fn rv32_divrem_signed_zero_divisor_wrong_r_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, true); + run_negative_divrem_test(DIV, b, c, prank_vals, true); + run_negative_divrem_test(REM, b, c, prank_vals, true); } #[test] @@ -491,8 +437,10 @@ fn rv32_divrem_false_zero_divisor_flag_negative_test() { zero_divisor: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -507,8 +455,10 @@ fn rv32_divrem_false_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -519,8 +469,10 @@ fn rv32_divrem_unset_zero_divisor_flag_negative_test() { zero_divisor: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -532,8 +484,10 @@ fn rv32_divrem_wrong_r_zero_flag_negative_test() { r_zero: Some(true), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } #[test] @@ -544,8 +498,10 @@ fn rv32_divrem_unset_r_zero_flag_negative_test() { r_zero: Some(false), ..Default::default() }; - run_rv32_divrem_negative_test(true, b, c, &prank_vals, false); - run_rv32_divrem_negative_test(false, b, c, &prank_vals, false); + run_negative_divrem_test(DIVU, b, c, prank_vals, false); + run_negative_divrem_test(REMU, b, c, prank_vals, false); + run_negative_divrem_test(DIV, b, c, prank_vals, false); + run_negative_divrem_test(REM, b, c, prank_vals, false); } /////////////////////////////////////////////////////////////////////////////////////// diff --git a/extensions/rv32im/circuit/src/extension.rs b/extensions/rv32im/circuit/src/extension.rs index f1f67d3994..844840e1b3 100644 --- a/extensions/rv32im/circuit/src/extension.rs +++ b/extensions/rv32im/circuit/src/extension.rs @@ -1,11 +1,12 @@ use derive_more::derive::From; use openvm_circuit::{ arch::{ - SystemConfig, SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError, + ExecutionBridge, SystemConfig, SystemPort, VmAirWrapper, VmExtension, VmInventory, + VmInventoryBuilder, VmInventoryError, }, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, @@ -23,6 +24,9 @@ use strum::IntoEnumIterator; use crate::{adapters::*, *}; +// TODO(ayush): this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + /// Config for a VM with base extension and IO extension #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Rv32IConfig { @@ -127,7 +131,7 @@ fn default_range_tuple_checker_sizes() -> [u32; 2] { // ============ Executor and Periphery Enums for Extension ============ /// RISC-V 32-bit Base (RV32I) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum Rv32IExecutor { // Rv32 (for standard 32-bit integers): BaseAlu(Rv32BaseAluChip), @@ -143,7 +147,7 @@ pub enum Rv32IExecutor { } /// RISC-V 32-bit Multiplication Extension (RV32M) Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum Rv32MExecutor { Multiplication(Rv32MultiplicationChip), MultiplicationHigh(Rv32MulHChip), @@ -151,7 +155,7 @@ pub enum Rv32MExecutor { } /// RISC-V 32-bit Io Instruction Executors -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, InsExecutorE1, From, AnyEnum)] pub enum Rv32IoExecutor { HintStore(Rv32HintStoreChip), } @@ -197,7 +201,6 @@ impl VmExtension for Rv32I { } = builder.system_port(); let range_checker = builder.system_base().range_checker_chip.clone(); - let offline_memory = builder.system_base().offline_memory(); let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; let bitwise_lu_chip = if let Some(&chip) = builder @@ -213,14 +216,21 @@ impl VmExtension for Rv32I { }; let base_alu_chip = Rv32BaseAluChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + BaseAluCoreAir::new(bitwise_lu_chip.bus(), BaseAluOpcode::CLASS_OFFSET), + ), + Rv32BaseAluStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + BaseAluOpcode::CLASS_OFFSET, ), - BaseAluCoreChip::new(bitwise_lu_chip.clone(), BaseAluOpcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( base_alu_chip, @@ -228,43 +238,65 @@ impl VmExtension for Rv32I { )?; let lt_chip = Rv32LessThanChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + LessThanCoreAir::new(bitwise_lu_chip.bus(), LessThanOpcode::CLASS_OFFSET), + ), + LessThanStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), + LessThanOpcode::CLASS_OFFSET, ), - LessThanCoreChip::new(bitwise_lu_chip.clone(), LessThanOpcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(lt_chip, LessThanOpcode::iter().map(|x| x.global_opcode()))?; let shift_chip = Rv32ShiftChip::new( - Rv32BaseAluAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - bitwise_lu_chip.clone(), + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + ), + ShiftCoreAir::new( + bitwise_lu_chip.bus(), + range_checker.bus(), + ShiftOpcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + ShiftStep::new( + Rv32BaseAluAdapterStep::new(bitwise_lu_chip.clone()), bitwise_lu_chip.clone(), range_checker.clone(), ShiftOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(shift_chip, ShiftOpcode::iter().map(|x| x.global_opcode()))?; let load_store_chip = Rv32LoadStoreChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + range_checker.bus(), + pointer_max_bits, + ), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), + ), + LoadStoreStep::new( + Rv32LoadStoreAdapterStep::new(pointer_max_bits), range_checker.clone(), + Rv32LoadStoreOpcode::CLASS_OFFSET, ), - LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( load_store_chip, @@ -274,15 +306,21 @@ impl VmExtension for Rv32I { )?; let load_sign_extend_chip = Rv32LoadSignExtendChip::new( - Rv32LoadStoreAdapterChip::new( - execution_bus, - program_bus, - memory_bridge, - pointer_max_bits, + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + range_checker.bus(), + pointer_max_bits, + ), + LoadSignExtendCoreAir::new(range_checker.bus()), + ), + LoadSignExtendStep::new( + Rv32LoadStoreAdapterStep::new(pointer_max_bits), range_checker.clone(), ), - LoadSignExtendCoreChip::new(range_checker.clone()), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( load_sign_extend_chip, @@ -290,49 +328,99 @@ impl VmExtension for Rv32I { )?; let beq_chip = Rv32BranchEqualChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchEqualCoreChip::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), - offline_memory.clone(), + VmAirWrapper::new( + Rv32BranchAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchEqualCoreAir::new(BranchEqualOpcode::CLASS_OFFSET, DEFAULT_PC_STEP), + ), + BranchEqualStep::new( + Rv32BranchAdapterStep::new(), + BranchEqualOpcode::CLASS_OFFSET, + DEFAULT_PC_STEP, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( beq_chip, BranchEqualOpcode::iter().map(|x| x.global_opcode()), )?; - let blt_chip = Rv32BranchLessThanChip::new( - Rv32BranchAdapterChip::new(execution_bus, program_bus, memory_bridge), - BranchLessThanCoreChip::new( + let blt_chip = Rv32BranchLessThanChip::::new( + VmAirWrapper::new( + Rv32BranchAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + BranchLessThanCoreAir::new( + bitwise_lu_chip.bus(), + BranchLessThanOpcode::CLASS_OFFSET, + ), + ), + BranchLessThanStep::new( + Rv32BranchAdapterStep::new(), bitwise_lu_chip.clone(), BranchLessThanOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( blt_chip, BranchLessThanOpcode::iter().map(|x| x.global_opcode()), )?; - let jal_lui_chip = Rv32JalLuiChip::new( - Rv32CondRdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalLuiCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + let jal_lui_chip = Rv32JalLuiChip::::new( + VmAirWrapper::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + )), + Rv32JalLuiCoreAir::new(bitwise_lu_chip.bus()), + ), + Rv32JalLuiStep::new( + Rv32CondRdWriteAdapterStep::new(Rv32RdWriteAdapterStep::new()), + bitwise_lu_chip.clone(), + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( jal_lui_chip, Rv32JalLuiOpcode::iter().map(|x| x.global_opcode()), )?; - let jalr_chip = Rv32JalrChip::new( - Rv32JalrAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32JalrCoreChip::new(bitwise_lu_chip.clone(), range_checker.clone()), - offline_memory.clone(), + let jalr_chip = Rv32JalrChip::::new( + VmAirWrapper::new( + Rv32JalrAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + Rv32JalrCoreAir::new(bitwise_lu_chip.bus(), range_checker.bus()), + ), + Rv32JalrStep::new( + Rv32JalrAdapterStep::new(), + bitwise_lu_chip.clone(), + range_checker.clone(), + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(jalr_chip, Rv32JalrOpcode::iter().map(|x| x.global_opcode()))?; - let auipc_chip = Rv32AuipcChip::new( - Rv32RdWriteAdapterChip::new(execution_bus, program_bus, memory_bridge), - Rv32AuipcCoreChip::new(bitwise_lu_chip.clone()), - offline_memory.clone(), + let auipc_chip = Rv32AuipcChip::::new( + VmAirWrapper::new( + Rv32RdWriteAdapterAir::new( + memory_bridge, + ExecutionBridge::new(execution_bus, program_bus), + ), + Rv32AuipcCoreAir::new(bitwise_lu_chip.bus()), + ), + Rv32AuipcStep::new(Rv32RdWriteAdapterStep::new(), bitwise_lu_chip.clone()), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( auipc_chip, @@ -371,7 +459,6 @@ impl VmExtension for Rv32M { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -401,28 +488,63 @@ impl VmExtension for Rv32M { chip }; - let mul_chip = Rv32MultiplicationChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - offline_memory.clone(), + let mul_chip = Rv32MultiplicationChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + // TODO(ayush): bus should return value not reference + MultiplicationCoreAir::new(*range_tuple_checker.bus(), MulOpcode::CLASS_OFFSET), + ), + MultiplicationStep::new( + Rv32MultAdapterStep::new(), + range_tuple_checker.clone(), + MulOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(mul_chip, MulOpcode::iter().map(|x| x.global_opcode()))?; - let mul_h_chip = Rv32MulHChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - MulHCoreChip::new(bitwise_lu_chip.clone(), range_tuple_checker.clone()), - offline_memory.clone(), + let mul_h_chip = Rv32MulHChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + MulHCoreAir::new(bitwise_lu_chip.bus(), *range_tuple_checker.bus()), + ), + MulHStep::new( + Rv32MultAdapterStep::new(), + bitwise_lu_chip.clone(), + range_tuple_checker.clone(), + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor(mul_h_chip, MulHOpcode::iter().map(|x| x.global_opcode()))?; - let div_rem_chip = Rv32DivRemChip::new( - Rv32MultAdapterChip::new(execution_bus, program_bus, memory_bridge), - DivRemCoreChip::new( + let div_rem_chip = Rv32DivRemChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + ), + DivRemCoreAir::new( + bitwise_lu_chip.bus(), + *range_tuple_checker.bus(), + DivRemOpcode::CLASS_OFFSET, + ), + ), + DivRemStep::new( + Rv32MultAdapterStep::new(), bitwise_lu_chip.clone(), range_tuple_checker.clone(), DivRemOpcode::CLASS_OFFSET, ), - offline_memory.clone(), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( div_rem_chip, @@ -447,7 +569,6 @@ impl VmExtension for Rv32Io { program_bus, memory_bridge, } = builder.system_port(); - let offline_memory = builder.system_base().offline_memory(); let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() @@ -461,16 +582,23 @@ impl VmExtension for Rv32Io { chip }; - let mut hintstore_chip = Rv32HintStoreChip::new( - execution_bus, - program_bus, - bitwise_lu_chip.clone(), - memory_bridge, - offline_memory.clone(), - builder.system_config().memory_config.pointer_max_bits, - Rv32HintStoreOpcode::CLASS_OFFSET, + let mut hintstore_chip = Rv32HintStoreChip::::new( + Rv32HintStoreAir::new( + ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lu_chip.bus(), + Rv32HintStoreOpcode::CLASS_OFFSET, + builder.system_config().memory_config.pointer_max_bits, + ), + Rv32HintStoreStep::new( + bitwise_lu_chip, + builder.system_config().memory_config.pointer_max_bits, + Rv32HintStoreOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); - hintstore_chip.set_streams(builder.streams().clone()); + hintstore_chip.step.set_streams(builder.streams().clone()); inventory.add_executor( hintstore_chip, @@ -486,13 +614,13 @@ mod phantom { use eyre::bail; use openvm_circuit::{ arch::{PhantomSubExecutor, Streams}, - system::memory::MemoryController, + system::memory::online::GuestMemory, }; use openvm_instructions::PhantomDiscriminant; use openvm_stark_backend::p3_field::{Field, PrimeField32}; use rand::{rngs::OsRng, Rng}; - use crate::adapters::unsafe_read_rv32_register; + use crate::adapters::{memory_read, new_read_rv32_register}; pub struct Rv32HintInputSubEx; pub struct Rv32HintRandomSubEx { @@ -508,11 +636,11 @@ mod phantom { impl PhantomSubExecutor for Rv32HintInputSubEx { fn phantom_execute( &mut self, - _: &MemoryController, + _: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - _: F, - _: F, + _: u32, + _: u32, _: u16, ) -> eyre::Result<()> { let mut hint = match streams.input_stream.pop_front() { @@ -539,14 +667,14 @@ mod phantom { impl PhantomSubExecutor for Rv32HintRandomSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, streams: &mut Streams, _: PhantomDiscriminant, - a: F, - _: F, + a: u32, + _: u32, _: u16, ) -> eyre::Result<()> { - let len = unsafe_read_rv32_register(memory, a) as usize; + let len = new_read_rv32_register(memory, 1, a) as usize; streams.hint_stream.clear(); streams.hint_stream.extend( std::iter::repeat_with(|| F::from_canonical_u8(self.rng.gen::())).take(len * 4), @@ -558,22 +686,18 @@ mod phantom { impl PhantomSubExecutor for Rv32PrintStrSubEx { fn phantom_execute( &mut self, - memory: &MemoryController, + memory: &GuestMemory, _: &mut Streams, _: PhantomDiscriminant, - a: F, - b: F, + a: u32, + b: u32, _: u16, ) -> eyre::Result<()> { - let rd = unsafe_read_rv32_register(memory, a); - let rs1 = unsafe_read_rv32_register(memory, b); + let rd = new_read_rv32_register(memory, 1, a); + let rs1 = new_read_rv32_register(memory, 1, b); let bytes = (0..rs1) - .map(|i| -> eyre::Result { - let val = memory.unsafe_read_cell(F::TWO, F::from_canonical_u32(rd + i)); - let byte: u8 = val.as_canonical_u32().try_into()?; - Ok(byte) - }) - .collect::>>()?; + .map(|i| memory_read::<1>(memory, 2, rd + i)[0]) + .collect::>(); let peeked_str = String::from_utf8(bytes)?; print!("{peeked_str}"); Ok(()) diff --git a/extensions/rv32im/circuit/src/hintstore/mod.rs b/extensions/rv32im/circuit/src/hintstore/mod.rs index 6f70a584d0..6917862b62 100644 --- a/extensions/rv32im/circuit/src/hintstore/mod.rs +++ b/extensions/rv32im/circuit/src/hintstore/mod.rs @@ -5,19 +5,19 @@ use std::{ use openvm_circuit::{ arch::{ - ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams, + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + ExecutionBridge, ExecutionError, ExecutionState, NewVmChipWrapper, Result, StepExecutorE1, + Streams, TraceStep, VmStateMut, }, - system::{ - memory::{ - offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, - MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId, - }, - program::ProgramBus, + system::memory::{ + offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols}, + online::{GuestMemory, TracingMemory}, + MemoryAddress, MemoryAuxColsFactory, RecordId, }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, - utils::{next_power_of_two_or_zero, not}, + utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; use openvm_instructions::{ @@ -31,18 +31,18 @@ use openvm_rv32im_transpiler::{ Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW}, }; use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, interaction::InteractionBuilder, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::types::AirProofInput, - rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir}, - Chip, ChipUsageGetter, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, }; use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, decompose}; +use crate::adapters::{ + decompose, memory_read, memory_read_from_state, memory_write_from_state, tracing_read, + tracing_write, +}; #[cfg(test)] mod tests; @@ -70,7 +70,7 @@ pub struct Rv32HintStoreCols { pub num_words_aux_cols: MemoryReadAuxCols, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct Rv32HintStoreAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, @@ -277,53 +277,54 @@ pub struct Rv32HintStoreRecord { pub hints: Vec<([F; RV32_REGISTER_NUM_LIMBS], RecordId)>, } -pub struct Rv32HintStoreChip { - air: Rv32HintStoreAir, - pub records: Vec>, - pub height: usize, - offline_memory: Arc>>, +pub struct Rv32HintStoreStep { + pointer_max_bits: usize, + offset: usize, pub streams: OnceLock>>>, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32HintStoreChip { +impl Rv32HintStoreStep { pub fn new( - execution_bus: ExecutionBus, - program_bus: ProgramBus, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, - memory_bridge: MemoryBridge, - offline_memory: Arc>>, pointer_max_bits: usize, offset: usize, ) -> Self { - let air = Rv32HintStoreAir { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_operation_lookup_bus: bitwise_lookup_chip.bus(), - offset, - pointer_max_bits, - }; Self { - records: vec![], - air, - height: 0, - offline_memory, + pointer_max_bits, + offset, streams: OnceLock::new(), bitwise_lookup_chip, } } + pub fn set_streams(&mut self, streams: Arc>>) { self.streams.set(streams).unwrap(); } } -impl InstructionExecutor for Rv32HintStoreChip { +impl TraceStep for Rv32HintStoreStep +where + F: PrimeField32, +{ + fn get_opcode_name(&self, opcode: usize) -> String { + if opcode == HINT_STOREW.global_opcode().as_usize() { + String::from("HINT_STOREW") + } else if opcode == HINT_BUFFER.global_opcode().as_usize() { + String::from("HINT_BUFFER") + } else { + unreachable!("unsupported opcode: {}", opcode) + } + } + fn execute( &mut self, - memory: &mut MemoryController, + state: VmStateMut, CTX>, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let &Instruction { opcode, a: num_words_ptr, @@ -332,194 +333,218 @@ impl InstructionExecutor for Rv32HintStoreChip { e, .. } = instruction; + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let local_opcode = - Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let (mem_ptr_read, mem_ptr_limbs) = memory.read::(d, mem_ptr_ptr); - let (num_words, num_words_read) = if local_opcode == HINT_STOREW { - memory.increment_timestamp(); - (1, None) + let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row: &mut Rv32HintStoreCols = + trace[*trace_offset..*trace_offset + width].borrow_mut(); + + row.from_state.pc = F::from_canonical_u32(*state.pc); + row.from_state.timestamp = F::from_canonical_u32(state.memory.timestamp); + + row.mem_ptr_ptr = mem_ptr_ptr; + let mem_ptr_limbs: [u8; RV32_REGISTER_NUM_LIMBS] = tracing_read( + state.memory, + RV32_REGISTER_AS, + mem_ptr_ptr.as_canonical_u32(), + &mut row.mem_ptr_aux_cols, + ); + let mem_ptr = u32::from_le_bytes(mem_ptr_limbs); + debug_assert!(mem_ptr <= (1 << self.pointer_max_bits)); + + row.num_words_ptr = num_words_ptr; + let num_words = if local_opcode == HINT_STOREW { + row.is_single = F::ONE; + state.memory.increment_timestamp(); + 1 } else { - let (num_words_read, num_words_limbs) = - memory.read::(d, num_words_ptr); - (compose(num_words_limbs), Some(num_words_read)) + row.is_buffer_start = F::ONE; + row.is_buffer = F::ONE; + let num_words_limbs: [u8; RV32_REGISTER_NUM_LIMBS] = tracing_read( + state.memory, + RV32_REGISTER_AS, + num_words_ptr.as_canonical_u32(), + &mut row.num_words_aux_cols, + ); + u32::from_le_bytes(num_words_limbs) }; debug_assert_ne!(num_words, 0); - debug_assert!(num_words <= (1 << self.air.pointer_max_bits)); - - let mem_ptr = compose(mem_ptr_limbs); - - debug_assert!(mem_ptr <= (1 << self.air.pointer_max_bits)); + debug_assert!(num_words <= (1 << self.pointer_max_bits)); let mut streams = self.streams.get().unwrap().lock().unwrap(); if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { - return Err(ExecutionError::HintOutOfBounds { pc: from_state.pc }); + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); } - let mut record = Rv32HintStoreRecord { - from_state, - instruction: instruction.clone(), - mem_ptr_read, - mem_ptr, - num_words, - num_words_read, - hints: vec![], - }; + let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); + let num_words_msl = num_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); + // TODO(ayush): see if this can be moved to fill_trace_row + self.bitwise_lookup_chip.request_range( + mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits), + num_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits), + ); + + for word_index in 0..(num_words as usize) { + let offset = *trace_offset + word_index * width; + let row: &mut Rv32HintStoreCols = trace[offset..offset + width].borrow_mut(); - for word_index in 0..num_words { if word_index != 0 { - memory.increment_timestamp(); - memory.increment_timestamp(); + row.is_buffer = F::ONE; + row.from_state.timestamp = F::from_canonical_u32(state.memory.timestamp); + + state.memory.increment_timestamp(); + state.memory.increment_timestamp(); } - let data: [F; RV32_REGISTER_NUM_LIMBS] = + let data_f: [F; RV32_REGISTER_NUM_LIMBS] = std::array::from_fn(|_| streams.hint_stream.pop_front().unwrap()); - let (write, _) = memory.write( - e, - F::from_canonical_u32(mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index)), - data, + let data: [u8; RV32_REGISTER_NUM_LIMBS] = + data_f.map(|byte| byte.as_canonical_u32() as u8); + + let mem_ptr_word = mem_ptr + (RV32_REGISTER_NUM_LIMBS * word_index) as u32; + + row.data = data_f; + tracing_write( + state.memory, + RV32_MEMORY_AS, + mem_ptr_word, + &data, + &mut row.write_aux, ); - record.hints.push((data, write)); + + row.rem_words_limbs = decompose(num_words - word_index as u32); + row.mem_ptr_limbs = decompose(mem_ptr_word); } - self.height += record.hints.len(); - self.records.push(record); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - let next_state = ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }; - Ok(next_state) + *trace_offset += (num_words as usize) * width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - if opcode == HINT_STOREW.global_opcode().as_usize() { - String::from("HINT_STOREW") - } else if opcode == HINT_BUFFER.global_opcode().as_usize() { - String::from("HINT_BUFFER") - } else { - unreachable!("unsupported opcode: {}", opcode) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let row: &mut Rv32HintStoreCols = row_slice.borrow_mut(); + + let mut timestamp = row.from_state.timestamp.as_canonical_u32(); + + if row.is_single.is_one() || row.is_buffer_start.is_one() { + mem_helper.fill_from_prev(timestamp, row.mem_ptr_aux_cols.as_mut()); } - } -} + timestamp += 1; -impl ChipUsageGetter for Rv32HintStoreChip { - fn air_name(&self) -> String { - "Rv32HintStoreAir".to_string() - } + if row.is_buffer_start.is_one() { + mem_helper.fill_from_prev(timestamp, row.num_words_aux_cols.as_mut()); + } + timestamp += 1; - fn current_trace_height(&self) -> usize { - self.height - } + mem_helper.fill_from_prev(timestamp, row.write_aux.as_mut()); - fn trace_width(&self) -> usize { - Rv32HintStoreCols::::width() + for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) { + self.bitwise_lookup_chip.request_range( + row.data[2 * half].as_canonical_u32(), + row.data[2 * half + 1].as_canonical_u32(), + ); + } } } -impl Rv32HintStoreChip { - // returns number of used u32s - fn record_to_rows( - record: Rv32HintStoreRecord, - aux_cols_factory: &MemoryAuxColsFactory, - slice: &mut [F], - memory: &OfflineMemory, - bitwise_lookup_chip: &SharedBitwiseOperationLookupChip, - pointer_max_bits: usize, - ) -> usize { - let width = Rv32HintStoreCols::::width(); - let cols: &mut Rv32HintStoreCols = slice[..width].borrow_mut(); - - cols.is_single = F::from_bool(record.num_words_read.is_none()); - cols.is_buffer = F::from_bool(record.num_words_read.is_some()); - cols.is_buffer_start = cols.is_buffer; - - cols.from_state = record.from_state.map(F::from_canonical_u32); - cols.mem_ptr_ptr = record.instruction.b; - aux_cols_factory.generate_read_aux( - memory.record_by_id(record.mem_ptr_read), - &mut cols.mem_ptr_aux_cols, - ); +impl StepExecutorE1 for Rv32HintStoreStep +where + F: PrimeField32, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let &Instruction { + opcode, + a: num_words_ptr, + b: mem_ptr_ptr, + d, + e, + .. + } = instruction; - cols.num_words_ptr = record.instruction.a; - if let Some(num_words_read) = record.num_words_read { - aux_cols_factory.generate_read_aux( - memory.record_by_id(num_words_read), - &mut cols.num_words_aux_cols, - ); - } + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - let mut mem_ptr = record.mem_ptr; - let mut rem_words = record.num_words; - let mut used_u32s = 0; + let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - let rem_words_msl = rem_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS); - bitwise_lookup_chip.request_range( - mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), - rem_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), - ); - for (i, &(data, write)) in record.hints.iter().enumerate() { - for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) { - bitwise_lookup_chip.request_range( - data[2 * half].as_canonical_u32(), - data[2 * half + 1].as_canonical_u32(), - ); - } + let mem_ptr_limbs = + memory_read_from_state(state, RV32_REGISTER_AS, mem_ptr_ptr.as_canonical_u32()); + let mem_ptr = u32::from_le_bytes(mem_ptr_limbs); + debug_assert!(mem_ptr <= (1 << self.pointer_max_bits)); - let cols: &mut Rv32HintStoreCols = slice[used_u32s..used_u32s + width].borrow_mut(); - cols.from_state.timestamp = - F::from_canonical_u32(record.from_state.timestamp + (3 * i as u32)); - cols.data = data; - aux_cols_factory.generate_write_aux(memory.record_by_id(write), &mut cols.write_aux); - cols.rem_words_limbs = decompose(rem_words); - cols.mem_ptr_limbs = decompose(mem_ptr); - if i != 0 { - cols.is_buffer = F::ONE; - } - used_u32s += width; - mem_ptr += RV32_REGISTER_NUM_LIMBS as u32; - rem_words -= 1; + let num_words = if local_opcode == HINT_STOREW { + 1 + } else { + let num_words_limbs = + memory_read_from_state(state, RV32_REGISTER_AS, num_words_ptr.as_canonical_u32()); + u32::from_le_bytes(num_words_limbs) + }; + debug_assert_ne!(num_words, 0); + debug_assert!(num_words <= (1 << self.pointer_max_bits)); + + let mut streams = self.streams.get().unwrap().lock().unwrap(); + if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize { + return Err(ExecutionError::HintOutOfBounds { pc: *state.pc }); } - used_u32s - } + for word_index in 0..num_words { + let data: [u8; RV32_REGISTER_NUM_LIMBS] = std::array::from_fn(|_| { + streams.hint_stream.pop_front().unwrap().as_canonical_u32() as u8 + }); + memory_write_from_state( + state, + RV32_MEMORY_AS, + mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index), + &data, + ); + } - fn generate_trace(self) -> RowMajorMatrix { - let width = self.trace_width(); - let height = next_power_of_two_or_zero(self.height); - let mut flat_trace = F::zero_vec(width * height); + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - let memory = self.offline_memory.lock().unwrap(); + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + let &Instruction { + opcode, + a: num_words_ptr, + .. + } = instruction; - let aux_cols_factory = memory.aux_cols_factory(); + let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let mut used_u32s = 0; - for record in self.records { - used_u32s += Self::record_to_rows( - record, - &aux_cols_factory, - &mut flat_trace[used_u32s..], - &memory, - &self.bitwise_lookup_chip, - self.air.pointer_max_bits, + let num_words = if local_opcode == HINT_STOREW { + 1 + } else { + let num_words_limbs = memory_read( + state.memory, + RV32_REGISTER_AS, + num_words_ptr.as_canonical_u32(), ); - } - // padding rows can just be all zeros - RowMajorMatrix::new(flat_trace, width) - } -} + u32::from_le_bytes(num_words_limbs) + }; -impl Chip for Rv32HintStoreChip> -where - Val: PrimeField32, -{ - fn air(&self) -> Arc> { - Arc::new(self.air) - } - fn generate_air_proof_input(self) -> AirProofInput { - AirProofInput::simple_no_pis(self.generate_trace()) + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += num_words; + + Ok(()) } } + +pub type Rv32HintStoreChip = NewVmChipWrapper>; diff --git a/extensions/rv32im/circuit/src/hintstore/tests.rs b/extensions/rv32im/circuit/src/hintstore/tests.rs index 204070762c..c56bfe185d 100644 --- a/extensions/rv32im/circuit/src/hintstore/tests.rs +++ b/extensions/rv32im/circuit/src/hintstore/tests.rs @@ -6,7 +6,7 @@ use std::{ use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - Streams, + ExecutionBridge, Streams, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -24,15 +24,41 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::{Rv32HintStoreChip, Rv32HintStoreCols}; -use crate::adapters::decompose; +use super::{Rv32HintStoreAir, Rv32HintStoreChip, Rv32HintStoreCols, Rv32HintStoreStep}; +use crate::{adapters::decompose, test_utils::get_verification_error}; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 1024; + +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32HintStoreChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Rv32HintStoreChip::::new( + Rv32HintStoreAir::new( + ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), + tester.memory_bridge(), + bitwise_chip.bus(), + 0, + tester.address_bits(), + ), + Rv32HintStoreStep::new(bitwise_chip.clone(), tester.address_bits(), 0), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + chip.step + .set_streams(Arc::new(Mutex::new(Streams::default()))); + (chip, bitwise_chip) +} fn set_and_execute( tester: &mut VmChipTestBuilder, @@ -40,15 +66,9 @@ fn set_and_execute( rng: &mut StdRng, opcode: Rv32HintStoreOpcode, ) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; + let mem_ptr = rng + .gen_range(0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - 2))) + << 2; let b = gen_pointer(rng, 4); tester.write(1, b, decompose(mem_ptr)); @@ -56,7 +76,8 @@ fn set_and_execute( let read_data: [F; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|_| F::from_canonical_u32(rng.gen_range(0..(1 << RV32_CELL_BITS)))); for data in read_data { - chip.streams + chip.step + .streams .get() .unwrap() .lock() @@ -80,15 +101,9 @@ fn set_and_execute_buffer( rng: &mut StdRng, opcode: Rv32HintStoreOpcode, ) { - let mem_ptr = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - 2)), - ) << 2; + let mem_ptr = rng + .gen_range(0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - 2))) + << 2; let b = gen_pointer(rng, 4); tester.write(1, b, decompose(mem_ptr)); @@ -102,7 +117,8 @@ fn set_and_execute_buffer( .collect(); for i in 0..num_words { for datum in data[i as usize] { - chip.streams + chip.step + .streams .get() .unwrap() .lock() @@ -131,30 +147,15 @@ fn set_and_execute_buffer( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn rand_hintstore_test() { - setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); - - let num_tests: usize = 8; - for _ in 0..num_tests { + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); + let num_ops: usize = 100; + for _ in 0..num_ops { if rng.gen_bool(0.5) { set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); } else { @@ -162,7 +163,6 @@ fn rand_hintstore_test() { } } - drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -171,64 +171,44 @@ fn rand_hintstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// #[allow(clippy::too_many_arguments)] fn run_negative_hintstore_test( opcode: Rv32HintStoreOpcode, - data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - expected_error: VerificationError, + prank_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); set_and_execute(&mut tester, &mut chip, &mut rng, opcode); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); let cols: &mut Rv32HintStoreCols = trace_row.as_mut_slice().borrow_mut(); - if let Some(data) = data { + if let Some(data) = prank_data { cols.data = data.map(F::from_canonical_u32); } *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn negative_hintstore_tests() { - run_negative_hintstore_test( - HINT_STOREW, - Some([92, 187, 45, 280]), - VerificationError::ChallengePhaseError, - ); + run_negative_hintstore_test(HINT_STOREW, Some([92, 187, 45, 280]), true); } + /////////////////////////////////////////////////////////////////////////////////////// /// SANITY TESTS /// @@ -239,22 +219,10 @@ fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - - let mut chip = Rv32HintStoreChip::::new( - tester.execution_bus(), - tester.program_bus(), - bitwise_chip.clone(), - tester.memory_bridge(), - tester.offline_memory_mutex_arc(), - tester.address_bits(), - 0, - ); - chip.set_streams(Arc::new(Mutex::new(Streams::default()))); + let (mut chip, _) = create_test_chip(&mut tester); - let num_tests: usize = 100; - for _ in 0..num_tests { + let num_ops: usize = 10; + for _ in 0..num_ops { set_and_execute(&mut tester, &mut chip, &mut rng, HINT_STOREW); } } diff --git a/extensions/rv32im/circuit/src/jal_lui/core.rs b/extensions/rv32im/circuit/src/jal_lui/core.rs index 2ba10e615e..836223316e 100644 --- a/extensions/rv32im/circuit/src/jal_lui/core.rs +++ b/extensions/rv32im/circuit/src/jal_lui/core.rs @@ -1,11 +1,15 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, -}; - -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, ImmInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -23,10 +27,11 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS}; +const ADDITIONAL_BITS: u32 = 0b11000000; + #[repr(C)] #[derive(Debug, Clone, AlignedBorrow)] pub struct Rv32JalLuiCoreCols { @@ -36,7 +41,7 @@ pub struct Rv32JalLuiCoreCols { pub is_lui: T, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy, derive_new::new)] pub struct Rv32JalLuiCoreAir { pub bus: BitwiseOperationLookupBus, } @@ -140,125 +145,174 @@ where } } -#[repr(C)] -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(bound = "F: Field")] -pub struct Rv32JalLuiCoreRecord { - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS], - pub imm: F, - pub is_jal: bool, - pub is_lui: bool, -} - -pub struct Rv32JalLuiCoreChip { - pub air: Rv32JalLuiCoreAir, +pub struct Rv32JalLuiCoreStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl Rv32JalLuiCoreChip { - pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip) -> Self { +impl Rv32JalLuiCoreStep { + pub fn new( + adapter: A, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + ) -> Self { Self { - air: Rv32JalLuiCoreAir { - bus: bitwise_lookup_chip.bus(), - }, + adapter, bitwise_lookup_chip, } } } -impl> VmCoreChip for Rv32JalLuiCoreChip +impl TraceStep for Rv32JalLuiCoreStep where - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (), + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + TraceContext<'a> = (), + >, { - type Record = Rv32JalLuiCoreRecord; - type Air = Rv32JalLuiCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - _reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = Rv32JalLuiOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET), - ); - let imm = instruction.c; + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { opcode, c: imm, .. } = instruction; + + let local_opcode = + Rv32JalLuiOpcode::from_usize(opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + A::start(*state.pc, state.memory, adapter_row); + + let core_row: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); + + // `c` can be "negative" as a field element + let imm_f = imm.as_canonical_u32(); let signed_imm = match local_opcode { JAL => { - // Note: signed_imm is a signed integer and imm is a field element - (imm + F::from_canonical_u32(1 << (RV_J_TYPE_IMM_BITS - 1))).as_canonical_u32() - as i32 - - (1 << (RV_J_TYPE_IMM_BITS - 1)) + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) + } } - LUI => imm.as_canonical_u32() as i32, + LUI => imm_f as i32, }; - let (to_pc, rd_data) = run_jal_lui(local_opcode, from_pc, signed_imm); + let (to_pc, rd_data) = run_jal_lui(local_opcode, *state.pc, signed_imm); - for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) { - self.bitwise_lookup_chip - .request_range(rd_data[i * 2], rd_data[i * 2 + 1]); - } + core_row.rd_data = rd_data.map(F::from_canonical_u8); + core_row.imm = instruction.c; + core_row.is_jal = F::from_bool(local_opcode == JAL); + core_row.is_lui = F::from_bool(local_opcode == LUI); + + self.adapter + .write(state.memory, instruction, adapter_row, &rd_data); + + *state.pc = to_pc; + + *trace_offset += width; + + Ok(()) + } - if local_opcode == JAL { - let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); - let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x)); + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + let core_row: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let rd_data = core_row.rd_data.map(|x| x.as_canonical_u32()); + for pair in rd_data.chunks_exact(2) { + self.bitwise_lookup_chip.request_range(pair[0], pair[1]); + } + if core_row.is_jal == F::ONE { self.bitwise_lookup_chip - .request_xor(rd_data[3], additional_bits); + .request_xor(rd_data[3], ADDITIONAL_BITS); } + } +} - let rd_data = rd_data.map(F::from_canonical_u32); +impl StepExecutorE1 for Rv32JalLuiCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, c: imm, .. } = instruction; - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), + let local_opcode = + Rv32JalLuiOpcode::from_usize(opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET)); + + let imm_f = imm.as_canonical_u32(); + let signed_imm = match local_opcode { + JAL => { + if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) { + imm_f as i32 + } else { + let neg_imm_f = F::ORDER_U32 - imm_f; + debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1))); + -(neg_imm_f as i32) + } + } + LUI => imm_f as i32, }; + let (to_pc, rd) = run_jal_lui(local_opcode, *state.pc, signed_imm); - Ok(( - output, - Rv32JalLuiCoreRecord { - rd_data, - imm, - is_jal: local_opcode == JAL, - is_lui: local_opcode == LUI, - }, - )) - } + self.adapter.write(state, instruction, &rd); - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET) - ) - } + *state.pc = to_pc; - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut Rv32JalLuiCoreCols = row_slice.borrow_mut(); - core_cols.rd_data = record.rd_data; - core_cols.imm = record.imm; - core_cols.is_jal = F::from_bool(record.is_jal); - core_cols.is_lui = F::from_bool(record.is_lui); + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // returns (to_pc, rd_data) +#[inline(always)] pub(super) fn run_jal_lui( opcode: Rv32JalLuiOpcode, pc: u32, imm: i32, -) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) { +) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) { match opcode { JAL => { - let rd_data = array::from_fn(|i| { - ((pc + DEFAULT_PC_STEP) >> (8 * i)) & ((1 << RV32_CELL_BITS) - 1) - }); + let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes(); let next_pc = pc as i32 + imm; assert!(next_pc >= 0); (next_pc as u32, rd_data) @@ -266,9 +320,15 @@ pub(super) fn run_jal_lui( LUI => { let imm = imm as u32; let rd = imm << 12; - let rd_data = - array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1)); - (pc + DEFAULT_PC_STEP, rd_data) + (pc + DEFAULT_PC_STEP, rd.to_le_bytes()) } } } + +// TODO(ayush): move from here +#[test] +fn test_additional_bits() { + let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1); + let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1u32 << x)); + assert_eq!(additional_bits, ADDITIONAL_BITS); +} diff --git a/extensions/rv32im/circuit/src/jal_lui/mod.rs b/extensions/rv32im/circuit/src/jal_lui/mod.rs index 779b710bea..85b7b3ce35 100644 --- a/extensions/rv32im/circuit/src/jal_lui/mod.rs +++ b/extensions/rv32im/circuit/src/jal_lui/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32CondRdWriteAdapterChip; +use crate::adapters::{Rv32CondRdWriteAdapterAir, Rv32CondRdWriteAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,6 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalLuiChip = VmChipWrapper, Rv32JalLuiCoreChip>; +pub type Rv32JalLuiAir = VmAirWrapper; +pub type Rv32JalLuiStep = Rv32JalLuiCoreStep; +pub type Rv32JalLuiChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/jal_lui/tests.rs b/extensions/rv32im/circuit/src/jal_lui/tests.rs index 35e258cbfb..ec09769612 100644 --- a/extensions/rv32im/circuit/src/jal_lui/tests.rs +++ b/extensions/rv32im/circuit/src/jal_lui/tests.rs @@ -2,7 +2,7 @@ use std::borrow::BorrowMut; use openvm_circuit::arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -12,27 +12,60 @@ use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreChip}; +use super::{run_jal_lui, Rv32JalLuiChip, Rv32JalLuiCoreAir, Rv32JalLuiStep}; use crate::{ adapters::{ - Rv32CondRdWriteAdapterChip, Rv32CondRdWriteAdapterCols, RV32_CELL_BITS, - RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS, + Rv32CondRdWriteAdapterAir, Rv32CondRdWriteAdapterCols, Rv32CondRdWriteAdapterStep, + Rv32RdWriteAdapterAir, Rv32RdWriteAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + RV_IS_TYPE_IMM_BITS, }, jal_lui::Rv32JalLuiCoreCols, + test_utils::get_verification_error, }; const IMM_BITS: usize = 20; const LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; +const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32JalLuiChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let chip = Rv32JalLuiChip::::new( + VmAirWrapper::new( + Rv32CondRdWriteAdapterAir::new(Rv32RdWriteAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + )), + Rv32JalLuiCoreAir::new(bitwise_bus), + ), + Rv32JalLuiStep::new( + Rv32CondRdWriteAdapterStep::new(Rv32RdWriteAdapterStep::new()), + bitwise_chip.clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + (chip, bitwise_chip) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut Rv32JalLuiChip, @@ -71,7 +104,7 @@ fn set_and_execute( let rd_data = if needs_write { rd_data } else { [0; 4] }; assert_eq!(next_pc, final_pc); - assert_eq!(rd_data.map(F::from_canonical_u32), tester.read::<4>(1, a)); + assert_eq!(rd_data.map(F::from_canonical_u8), tester.read::<4>(1, a)); } /////////////////////////////////////////////////////////////////////////////////////// @@ -81,25 +114,15 @@ fn set_and_execute( /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_jal_lui_test() { +#[test_case(JAL, 100)] +#[test_case(LUI, 100)] +fn rand_jal_lui_test(opcode: Rv32JalLuiOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } let tester = tester.build().load(chip).load(bitwise_chip).finalize(); @@ -109,35 +132,29 @@ fn rand_jal_lui_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalLuiPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub imm: Option, + pub is_jal: Option, + pub is_lui: Option, + pub needs_write: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jal_lui_test( opcode: Rv32JalLuiOpcode, initial_imm: Option, initial_pc: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - imm: Option, - is_jal: Option, - is_lui: Option, - needs_write: Option, - expected_error: VerificationError, + prank_vals: JalLuiPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let core = Rv32JalLuiCoreChip::new(bitwise_chip.clone()); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&tester); set_and_execute( &mut tester, @@ -148,51 +165,43 @@ fn run_negative_jal_lui_test( initial_pc, ); - let tester = tester.build(); - - let jal_lui_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jal_lui_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jal_lui_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); - let adapter_cols: &mut Rv32CondRdWriteAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut Rv32JalLuiCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(imm) = imm { + if let Some(imm) = prank_vals.imm { core_cols.imm = if imm < 0 { F::NEG_ONE * F::from_canonical_u32((-imm) as u32) } else { F::from_canonical_u32(imm as u32) }; } - if let Some(is_jal) = is_jal { + if let Some(is_jal) = prank_vals.is_jal { core_cols.is_jal = F::from_bool(is_jal); } - if let Some(is_lui) = is_lui { + if let Some(is_lui) = prank_vals.is_lui { core_cols.is_lui = F::from_bool(is_lui); } - - if let Some(needs_write) = needs_write { + if let Some(needs_write) = prank_vals.needs_write { adapter_cols.needs_write = F::from_bool(needs_write); } - *jal_lui_trace = RowMajorMatrix::new(trace_row, jal_lui_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -201,34 +210,35 @@ fn opcode_flag_negative_test() { JAL, None, None, - None, - None, - Some(false), - Some(true), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(true), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, None, - None, - None, - Some(false), - Some(false), - Some(false), - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + is_jal: Some(false), + is_lui: Some(false), + needs_write: Some(false), + ..Default::default() + }, + true, ); run_negative_jal_lui_test( LUI, None, None, - None, - None, - Some(true), - Some(false), - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + is_jal: Some(true), + is_lui: Some(false), + ..Default::default() + }, + false, ); } @@ -238,67 +248,61 @@ fn overflow_negative_tests() { JAL, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([LIMB_MAX, LIMB_MAX, LIMB_MAX, LIMB_MAX]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + rd_data: Some([0, LIMB_MAX, LIMB_MAX, LIMB_MAX + 1]), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-1), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-1), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( LUI, None, None, - None, - Some(-28), - None, - None, - None, - VerificationError::OodEvaluationMismatch, + JalLuiPrankValues { + imm: Some(-28), + ..Default::default() + }, + false, ); run_negative_jal_lui_test( JAL, None, Some(251), - Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalLuiPrankValues { + rd_data: Some([F::NEG_ONE.as_canonical_u32(), 1, 0, 0]), + ..Default::default() + }, + true, ); } @@ -307,25 +311,12 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// + #[test] fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let adapter = Rv32CondRdWriteAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let core = Rv32JalLuiCoreChip::new(bitwise_chip); - let mut chip = Rv32JalLuiChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute(&mut tester, &mut chip, &mut rng, JAL, None, None); - set_and_execute(&mut tester, &mut chip, &mut rng, LUI, None, None); - } + let (mut chip, _) = create_test_chip(&tester); set_and_execute( &mut tester, diff --git a/extensions/rv32im/circuit/src/jalr/core.rs b/extensions/rv32im/circuit/src/jalr/core.rs index fd89c1e317..1c543718ed 100644 --- a/extensions/rv32im/circuit/src/jalr/core.rs +++ b/extensions/rv32im/circuit/src/jalr/core.rs @@ -3,9 +3,16 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, SignedImmInstruction, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, Result, SignedImmInstruction, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, @@ -24,9 +31,8 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use crate::adapters::{compose, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{compose, Rv32JalrAdapterCols, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1; @@ -46,18 +52,7 @@ pub struct Rv32JalrCoreCols { pub imm_sign: T, } -#[repr(C)] -#[derive(Serialize, Deserialize)] -pub struct Rv32JalrCoreRecord { - pub imm: F, - pub rs1_data: [F; RV32_REGISTER_NUM_LIMBS], - pub rd_data: [F; RV32_REGISTER_NUM_LIMBS - 1], - pub to_pc_least_sig_bit: F, - pub to_pc_limbs: [u32; 2], - pub imm_sign: F, -} - -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct Rv32JalrCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -181,116 +176,208 @@ where } } -pub struct Rv32JalrCoreChip { - pub air: Rv32JalrCoreAir, +pub struct Rv32JalrCoreStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl Rv32JalrCoreChip { +impl Rv32JalrCoreStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, ) -> Self { assert!(range_checker_chip.range_max_bits() >= 16); Self { - air: Rv32JalrCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - }, + adapter, bitwise_lookup_chip, range_checker_chip, } } } -impl> VmCoreChip for Rv32JalrCoreChip +impl TraceStep for Rv32JalrCoreStep where - I::Reads: Into<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, - I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = [u8; RV32_REGISTER_NUM_LIMBS], + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + TraceContext<'a> = (), + >, { - type Record = Rv32JalrCoreRecord; - type Air = Rv32JalrCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, c, g, .. } = *instruction; + let local_opcode = Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET)); + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let rs1 = self.adapter.read(state.memory, instruction, adapter_row); + // TODO(ayush): avoid this conversion + let rs1_val = compose(rs1.map(F::from_canonical_u8)); + let imm = c.as_canonical_u32(); let imm_sign = g.as_canonical_u32(); let imm_extended = imm + imm_sign * 0xffff0000; - let rs1 = reads.into()[0]; - let rs1_val = compose(rs1); + // TODO(ayush): this is bad since we're treating adapters as generic. maybe + // add a .state() function to adapters or get_from_pc like in air + let adapter_row_ref: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + let from_pc = adapter_row_ref.from_state.pc.as_canonical_u32(); let (to_pc, rd_data) = run_jalr(local_opcode, from_pc, imm_extended, rs1_val); - self.bitwise_lookup_chip - .request_range(rd_data[0], rd_data[1]); - self.range_checker_chip - .add_count(rd_data[2], RV32_CELL_BITS); - self.range_checker_chip - .add_count(rd_data[3], PC_BITS - RV32_CELL_BITS * 3); - let mask = (1 << 15) - 1; let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1; let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask)); - let rd_data = rd_data.map(F::from_canonical_u32); + let core_row: &mut Rv32JalrCoreCols = core_row.borrow_mut(); + core_row.imm = c; + core_row.rd_data = array::from_fn(|i| F::from_canonical_u32(rd_data[i + 1])); + core_row.rs1_data = rs1.map(F::from_canonical_u8); + core_row.to_pc_least_sig_bit = F::from_canonical_u32(to_pc_least_sig_bit); + core_row.to_pc_limbs = to_pc_limbs.map(F::from_canonical_u32); + core_row.imm_sign = g; + core_row.is_valid = F::ONE; - let output = AdapterRuntimeContext { - to_pc: Some(to_pc), - writes: [rd_data].into(), - }; - - Ok(( - output, - Rv32JalrCoreRecord { - imm: c, - rd_data: array::from_fn(|i| rd_data[i + 1]), - rs1_data: rs1, - to_pc_least_sig_bit: F::from_canonical_u32(to_pc_least_sig_bit), - to_pc_limbs, - imm_sign: g, - }, - )) + self.adapter.write( + state.memory, + instruction, + adapter_row, + &rd_data.map(|x| x as u8), + ); + + *state.pc = to_pc; + + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + let core_row: &mut Rv32JalrCoreCols = core_row.borrow_mut(); + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + // TODO(ayush): this shouldn't be here since it is generic on A + let adapter_row: &mut Rv32JalrAdapterCols = adapter_row.borrow_mut(); + + // composed is the composition of 3 most significant limbs of rd + let composed = core_row + .rd_data + .iter() + .enumerate() + .fold(F::ZERO, |acc, (i, &val)| { + acc + val * F::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS)) + }); + + let least_sig_limb = + adapter_row.from_state.pc + F::from_canonical_u32(DEFAULT_PC_STEP) - composed; + + let rd_data: [F; RV32_REGISTER_NUM_LIMBS] = array::from_fn(|i| { + if i == 0 { + least_sig_limb + } else { + core_row.rd_data[i - 1] + } + }); + + self.bitwise_lookup_chip + .request_range(rd_data[0].as_canonical_u32(), rd_data[1].as_canonical_u32()); + + self.range_checker_chip + .add_count(rd_data[2].as_canonical_u32(), RV32_CELL_BITS); + self.range_checker_chip + .add_count(rd_data[3].as_canonical_u32(), PC_BITS - RV32_CELL_BITS * 3); + + self.range_checker_chip + .add_count(core_row.to_pc_limbs[0].as_canonical_u32(), 15); + self.range_checker_chip + .add_count(core_row.to_pc_limbs[1].as_canonical_u32(), 14); } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - self.range_checker_chip.add_count(record.to_pc_limbs[0], 15); - self.range_checker_chip.add_count(record.to_pc_limbs[1], 14); - - let core_cols: &mut Rv32JalrCoreCols = row_slice.borrow_mut(); - core_cols.imm = record.imm; - core_cols.rd_data = record.rd_data; - core_cols.rs1_data = record.rs1_data; - core_cols.to_pc_least_sig_bit = record.to_pc_least_sig_bit; - core_cols.to_pc_limbs = record.to_pc_limbs.map(F::from_canonical_u32); - core_cols.imm_sign = record.imm_sign; - core_cols.is_valid = F::ONE; +impl StepExecutorE1 for Rv32JalrCoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData = [u8; RV32_REGISTER_NUM_LIMBS], + WriteData = [u8; RV32_REGISTER_NUM_LIMBS], + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, c, g, .. } = instruction; + + let local_opcode = + Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET)); + + let rs1 = self.adapter.read(state, instruction); + let rs1 = u32::from_le_bytes(rs1); + + let imm = c.as_canonical_u32(); + let imm_sign = g.as_canonical_u32(); + let imm_extended = imm + imm_sign * 0xffff0000; + + // TODO(ayush): should this be [u8; 4]? + let (to_pc, rd) = run_jalr(local_opcode, *state.pc, imm_extended, rs1); + let rd = rd.map(|x| x as u8); + + self.adapter.write(state, instruction, &rd); + + *state.pc = to_pc; + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // returns (to_pc, rd_data) +#[inline(always)] pub(super) fn run_jalr( _opcode: Rv32JalrOpcode, pc: u32, diff --git a/extensions/rv32im/circuit/src/jalr/mod.rs b/extensions/rv32im/circuit/src/jalr/mod.rs index 1d85dcbe4a..f3e6f9d9a3 100644 --- a/extensions/rv32im/circuit/src/jalr/mod.rs +++ b/extensions/rv32im/circuit/src/jalr/mod.rs @@ -1,6 +1,6 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use crate::adapters::Rv32JalrAdapterChip; +use crate::adapters::{Rv32JalrAdapterAir, Rv32JalrAdapterStep}; mod core; pub use core::*; @@ -8,4 +8,6 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32JalrChip = VmChipWrapper, Rv32JalrCoreChip>; +pub type Rv32JalrAir = VmAirWrapper; +pub type Rv32JalrStep = Rv32JalrCoreStep; +pub type Rv32JalrChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/jalr/tests.rs b/extensions/rv32im/circuit/src/jalr/tests.rs index e22d97967f..c7d54b525d 100644 --- a/extensions/rv32im/circuit/src/jalr/tests.rs +++ b/extensions/rv32im/circuit/src/jalr/tests.rs @@ -2,7 +2,7 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - VmAdapterChip, + VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -12,26 +12,60 @@ use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, utils::disable_debug_builder, - verifier::VerificationError, - Chip, ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use super::Rv32JalrCoreAir; use crate::{ - adapters::{compose, Rv32JalrAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreChip, Rv32JalrCoreCols}, + adapters::{ + compose, Rv32JalrAdapterAir, Rv32JalrAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + jalr::{run_jalr, Rv32JalrChip, Rv32JalrCoreCols, Rv32JalrStep}, + test_utils::get_verification_error, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; + type F = BabyBear; fn into_limbs(num: u32) -> [u32; 4] { array::from_fn(|i| (num >> (8 * i)) & 255) } +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32JalrChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_checker_chip = tester.memory_controller().range_checker.clone(); + + let chip = Rv32JalrChip::::new( + VmAirWrapper::new( + Rv32JalrAdapterAir::new(tester.memory_bridge(), tester.execution_bridge()), + Rv32JalrCoreAir::new(bitwise_bus, range_checker_chip.bus()), + ), + Rv32JalrStep::new( + Rv32JalrAdapterStep::new(), + bitwise_chip.clone(), + range_checker_chip.clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip) +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, @@ -55,6 +89,7 @@ fn set_and_execute( tester.write(1, b, rs1); + let initial_pc = initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))); tester.execute_with_pc( chip, &Instruction::from_usize( @@ -69,9 +104,8 @@ fn set_and_execute( imm_sign as usize, ], ), - initial_pc.unwrap_or(rng.gen_range(0..(1 << PC_BITS))), + initial_pc, ); - let initial_pc = tester.execution.last_from_pc().as_canonical_u32(); let final_pc = tester.execution.last_to_pc().as_canonical_u32(); let rs1 = compose(rs1); @@ -92,21 +126,11 @@ fn set_and_execute( #[test] fn rand_jalr_test() { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { + let num_ops = 100; + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, @@ -119,7 +143,6 @@ fn rand_jalr_test() { ); } - drop(range_checker_chip); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -128,10 +151,18 @@ fn rand_jalr_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// +#[derive(Clone, Copy, Default, PartialEq)] +struct JalrPrankValues { + pub rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, + pub rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + pub to_pc_least_sig_bit: Option, + pub to_pc_limbs: Option<[u32; 2]>, + pub imm_sign: Option, +} + #[allow(clippy::too_many_arguments)] fn run_negative_jalr_test( opcode: Rv32JalrOpcode, @@ -139,27 +170,13 @@ fn run_negative_jalr_test( initial_rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, initial_imm: Option, initial_imm_sign: Option, - rd_data: Option<[u32; RV32_REGISTER_NUM_LIMBS - 1]>, - rs1_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, - to_pc_least_sig_bit: Option, - to_pc_limbs: Option<[u32; 2]>, - imm_sign: Option, - expected_error: VerificationError, + prank_vals: JalrPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let adapter_width = BaseAir::::width(adapter.air()); - let inner = Rv32JalrCoreChip::new(bitwise_chip.clone(), range_checker_chip.clone()); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); + let (mut chip, bitwise_chip) = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -172,49 +189,38 @@ fn run_negative_jalr_test( initial_rs1, ); - let tester = tester.build(); - - let jalr_trace_width = chip.trace_width(); - let air = chip.air(); - let mut chip_input = chip.generate_air_proof_input(); - let jalr_trace = chip_input.raw.common_main.as_mut().unwrap(); - { - let mut trace_row = jalr_trace.row_slice(0).to_vec(); - + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { + let mut trace_row = trace.row_slice(0).to_vec(); let (_, core_row) = trace_row.split_at_mut(adapter_width); - let core_cols: &mut Rv32JalrCoreCols = core_row.borrow_mut(); - if let Some(data) = rd_data { + if let Some(data) = prank_vals.rd_data { core_cols.rd_data = data.map(F::from_canonical_u32); } - - if let Some(data) = rs1_data { + if let Some(data) = prank_vals.rs1_data { core_cols.rs1_data = data.map(F::from_canonical_u32); } - - if let Some(data) = to_pc_least_sig_bit { + if let Some(data) = prank_vals.to_pc_least_sig_bit { core_cols.to_pc_least_sig_bit = F::from_canonical_u32(data); } - - if let Some(data) = to_pc_limbs { + if let Some(data) = prank_vals.to_pc_limbs { core_cols.to_pc_limbs = data.map(F::from_canonical_u32); } - - if let Some(data) = imm_sign { + if let Some(data) = prank_vals.imm_sign { core_cols.imm_sign = F::from_canonical_u32(data); } - *jalr_trace = RowMajorMatrix::new(trace_row, jalr_trace_width); - } + *trace = RowMajorMatrix::new(trace_row, trace.width()); + }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester - .load_air_proof_input((air, chip_input)) + .build() + .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -225,12 +231,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(0), - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(1), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -239,12 +244,11 @@ fn invalid_cols_negative_tests() { None, Some(15362), Some(1), - None, - None, - None, - None, - Some(0), - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + imm_sign: Some(0), + ..Default::default() + }, + false, ); run_negative_jalr_test( @@ -253,12 +257,11 @@ fn invalid_cols_negative_tests() { Some([23, 154, 67, 28]), Some(42512), Some(1), - None, - None, - Some(0), - None, - None, - VerificationError::OodEvaluationMismatch, + JalrPrankValues { + to_pc_least_sig_bit: Some(0), + ..Default::default() + }, + false, ); } @@ -270,12 +273,11 @@ fn overflow_negative_tests() { None, None, None, - Some([1, 0, 0]), - None, - None, - None, - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + rd_data: Some([1, 0, 0]), + ..Default::default() + }, + true, ); run_negative_jalr_test( @@ -284,15 +286,14 @@ fn overflow_negative_tests() { Some([0, 0, 0, 0]), Some((1 << 15) - 2), Some(0), - None, - None, - None, - Some([ - (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), - 1, - ]), - None, - VerificationError::ChallengePhaseError, + JalrPrankValues { + to_pc_limbs: Some([ + (F::NEG_ONE * F::from_canonical_u32((1 << 14) + 1)).as_canonical_u32(), + 1, + ]), + ..Default::default() + }, + true, ); } @@ -301,36 +302,6 @@ fn overflow_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - - let adapter = Rv32JalrAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ); - let inner = Rv32JalrCoreChip::new(bitwise_chip, range_checker_chip); - let mut chip = Rv32JalrChip::::new(adapter, inner, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - JALR, - None, - None, - None, - None, - ); - } -} #[test] fn run_jalr_sanity_test() { diff --git a/extensions/rv32im/circuit/src/less_than/core.rs b/extensions/rv32im/circuit/src/less_than/core.rs index a605dc43de..fddad57900 100644 --- a/extensions/rv32im/circuit/src/less_than/core.rs +++ b/extensions/rv32im/circuit/src/less_than/core.rs @@ -3,16 +3,23 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, utils::not, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::LessThanOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +27,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -45,7 +50,7 @@ pub struct LessThanCoreCols { pub diff_val: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct LessThanCoreAir { pub bus: BitwiseOperationLookupBus, offset: usize, @@ -163,95 +168,132 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct LessThanCoreRecord { - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub cmp_result: T, - pub b_msb_f: T, - pub c_msb_f: T, - pub diff_val: T, - pub diff_idx: usize, - pub opcode: LessThanOpcode, -} - -pub struct LessThanCoreChip { - pub air: LessThanCoreAir, +pub struct LessThanStep { + adapter: A, + offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, } -impl LessThanCoreChip { +impl LessThanStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, ) -> Self { Self { - air: LessThanCoreAir { - bus: bitwise_lookup_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, } } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for LessThanCoreChip +impl TraceStep + for LessThanStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = LessThanCoreRecord; - type Air = LessThanCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", LessThanOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + debug_assert!(LIMB_BITS <= 8); + let Instruction { opcode, .. } = instruction; - let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); + let local_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let (cmp_result, _, _, _) = run_less_than::(local_opcode, &rs1, &rs2); + + let core_row: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + core_row.b = rs1.map(F::from_canonical_u8); + core_row.c = rs2.map(F::from_canonical_u8); + core_row.opcode_slt_flag = F::from_bool(local_opcode == LessThanOpcode::SLT); + core_row.opcode_sltu_flag = F::from_bool(local_opcode == LessThanOpcode::SLTU); + + let mut output = [0u8; NUM_LIMBS]; + output[0] = cmp_result as u8; + + self.adapter + .write(state.memory, instruction, adapter_row, &[output].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + + let b = core_row.b.map(|x| x.as_canonical_u32() as u8); + let c = core_row.c.map(|x| x.as_canonical_u32() as u8); + // It's easier (and faster?) to re-execute + let local_opcode = if core_row.opcode_slt_flag.is_one() { + LessThanOpcode::SLT + } else { + LessThanOpcode::SLTU + }; let (cmp_result, diff_idx, b_sign, c_sign) = - run_less_than::(less_than_opcode, &b, &c); + run_less_than::(local_opcode, &b, &c); // We range check (b_msb_f + 128) and (c_msb_f + 128) if signed, // b_msb_f and c_msb_f if not let (b_msb_f, b_msb_range) = if b_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - b[NUM_LIMBS - 1]), - b[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - b[NUM_LIMBS - 1] as u16), + b[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(b[NUM_LIMBS - 1]), + F::from_canonical_u8(b[NUM_LIMBS - 1]), b[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + + (((local_opcode == LessThanOpcode::SLT) as u8) << (LIMB_BITS - 1)), ) }; let (c_msb_f, c_msb_range) = if c_sign { ( - -F::from_canonical_u32((1 << LIMB_BITS) - c[NUM_LIMBS - 1]), - c[NUM_LIMBS - 1] - (1 << (LIMB_BITS - 1)), + -F::from_canonical_u16((1u16 << LIMB_BITS) - c[NUM_LIMBS - 1] as u16), + c[NUM_LIMBS - 1] - (1u8 << (LIMB_BITS - 1)), ) } else { ( - F::from_canonical_u32(c[NUM_LIMBS - 1]), + F::from_canonical_u8(c[NUM_LIMBS - 1]), c[NUM_LIMBS - 1] - + (((less_than_opcode == LessThanOpcode::SLT) as u32) << (LIMB_BITS - 1)), + + (((local_opcode == LessThanOpcode::SLT) as u8) << (LIMB_BITS - 1)), ) }; - self.bitwise_lookup_chip - .request_range(b_msb_range, c_msb_range); let diff_val = if diff_idx == NUM_LIMBS { 0 @@ -263,60 +305,83 @@ where } .as_canonical_u32() } else if cmp_result { - c[diff_idx] - b[diff_idx] + (c[diff_idx] - b[diff_idx]) as u32 } else { - b[diff_idx] - c[diff_idx] + (b[diff_idx] - c[diff_idx]) as u32 }; + self.bitwise_lookup_chip + .request_range(b_msb_range as u32, c_msb_range as u32); if diff_idx != NUM_LIMBS { self.bitwise_lookup_chip.request_range(diff_val - 1, 0); } - let mut writes = [0u32; NUM_LIMBS]; - writes[0] = cmp_result as u32; - - let output = AdapterRuntimeContext::without_pc([writes.map(F::from_canonical_u32)]); - let record = LessThanCoreRecord { - opcode: less_than_opcode, - b: data[0], - c: data[1], - cmp_result: F::from_bool(cmp_result), - b_msb_f, - c_msb_f, - diff_val: F::from_canonical_u32(diff_val), - diff_idx, - }; - - Ok((output, record)) + core_row.diff_val = F::from_canonical_u32(diff_val); + core_row.cmp_result = F::from_bool(cmp_result); + core_row.b_msb_f = b_msb_f; + core_row.c_msb_f = c_msb_f; + core_row.diff_val = F::from_canonical_u32(diff_val); + core_row.diff_marker = array::from_fn(|i| F::from_bool(i == diff_idx)); } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", LessThanOpcode::from_usize(opcode - self.air.offset)) - } +impl StepExecutorE1 + for LessThanStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + let less_than_opcode = LessThanOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut LessThanCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.cmp_result = record.cmp_result; - row_slice.b_msb_f = record.b_msb_f; - row_slice.c_msb_f = record.c_msb_f; - row_slice.diff_val = record.diff_val; - row_slice.opcode_slt_flag = F::from_bool(record.opcode == LessThanOpcode::SLT); - row_slice.opcode_sltu_flag = F::from_bool(record.opcode == LessThanOpcode::SLTU); - row_slice.diff_marker = array::from_fn(|i| F::from_bool(i == record.diff_idx)); + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + + // Run the comparison + let (cmp_result, _, _, _) = + run_less_than::(less_than_opcode, &rs1, &rs2); + let mut rd = [0u8; NUM_LIMBS]; + rd[0] = cmp_result as u8; + + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // Returns (cmp_result, diff_idx, x_sign, y_sign) +#[inline(always)] pub(super) fn run_less_than( opcode: LessThanOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], ) -> (bool, usize, bool, bool) { let x_sign = (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; let y_sign = (y[NUM_LIMBS - 1] >> (LIMB_BITS - 1) == 1) && opcode == LessThanOpcode::SLT; diff --git a/extensions/rv32im/circuit/src/less_than/mod.rs b/extensions/rv32im/circuit/src/less_than/mod.rs index f8247d2d33..7fc8937d3c 100644 --- a/extensions/rv32im/circuit/src/less_than/mod.rs +++ b/extensions/rv32im/circuit/src/less_than/mod.rs @@ -1,6 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +10,8 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32LessThanChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - LessThanCoreChip, ->; +pub type Rv32LessThanAir = + VmAirWrapper>; +pub type Rv32LessThanStep = + LessThanStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32LessThanChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/less_than/tests.rs b/extensions/rv32im/circuit/src/less_than/tests.rs index 18d64bf5f6..373c2920f7 100644 --- a/extensions/rv32im/circuit/src/less_than/tests.rs +++ b/extensions/rv32im/circuit/src/less_than/tests.rs @@ -1,17 +1,17 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }, - utils::{generate_long_number, i32_to_f}, + utils::i32_to_f, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::LessThanOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::LessThanOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::{FieldAlgebra, PrimeField32}, @@ -20,20 +20,95 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_less_than, LessThanCoreChip, Rv32LessThanChip}; +use super::{core::run_less_than, LessThanCoreAir, LessThanStep, Rv32LessThanChip}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, less_than::LessThanCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; + +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32LessThanChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let chip = Rv32LessThanChip::::new( + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + LessThanCoreAir::new(bitwise_bus, LessThanOpcode::CLASS_OFFSET), + ), + LessThanStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), + bitwise_chip.clone(), + LessThanOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: LessThanOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) + } else { + generate_rv32_is_type_immediate(rng) + }; + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), + ) + }; + + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); + + let (cmp, _, _, _) = run_less_than::(opcode, &b, &c); + let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; + a[0] = F::from_bool(cmp); + assert_eq!(a, tester.read::(1, rd)); +} ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS @@ -42,100 +117,51 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLT, 100)] +#[test_case(SLTU, 100)] fn run_rv32_lt_rand_test(opcode: LessThanOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) - } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) - }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let (cmp, _, _, _) = - run_less_than::(opcode, &b, &c); - let mut a = [F::ZERO; RV32_REGISTER_NUM_LIMBS]; - a[0] = F::from_bool(cmp); - assert_eq!(a, tester.read::(1, rd)); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); } // Test special case where b = c let b = [101, 128, 202, 255]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - None, - opcode.global_opcode().as_usize(), + &mut chip, &mut rng, + opcode, + Some(b), + Some(false), + Some(b), ); - tester.execute(&mut chip, &instruction); let b = [36, 0, 0, 0]; - let (instruction, _) = rv32_rand_write_register_or_imm( + set_and_execute( &mut tester, - b, - b, - Some(36), - opcode.global_opcode().as_usize(), + &mut chip, &mut rng, + opcode, + Some(b), + Some(true), + Some(b), ); - tester.execute(&mut chip, &instruction); let tester = tester.build().load(chip).load(bitwise_chip).finalize(); tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_slt_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLT, 100); -} - -#[test] -fn rv32_sltu_rand_test() { - run_rv32_lt_rand_test(LessThanOpcode::SLTU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32LessThanTestChip = - VmChipWrapper, LessThanCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct LessThanPrankValues { pub b_msb: Option, @@ -145,67 +171,29 @@ struct LessThanPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_lt_negative_test( +fn run_negative_less_than_test( opcode: LessThanOpcode, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - cmp_result: bool, + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_cmp_result: bool, prank_vals: LessThanPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let mut chip = Rv32LessThanTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - LessThanCoreChip::new(bitwise_chip.clone(), LessThanOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, b_sign, c_sign) = - run_less_than::(opcode, &b, &c); - - if prank_vals != LessThanPrankValues::default() { - debug_assert!(prank_vals.diff_val.is_some()); - let b_msb = prank_vals.b_msb.unwrap_or( - b[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if b_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let c_msb = prank_vals.c_msb.unwrap_or( - c[RV32_REGISTER_NUM_LIMBS - 1] as i32 - if c_sign { 1 << RV32_CELL_BITS } else { 0 }, - ); - let sign_offset = if opcode == LessThanOpcode::SLT { - 1 << (RV32_CELL_BITS - 1) - } else { - 0 - }; - - bitwise_chip.clear(); - bitwise_chip.request_range( - (b_msb + sign_offset) as u8 as u32, - (c_msb + sign_offset) as u8 as u32, - ); - - let diff_val = prank_vals - .diff_val - .unwrap() - .clamp(0, (1 << RV32_CELL_BITS) - 1); - if diff_val > 0 { - bitwise_chip.request_range(diff_val - 1, 0); - } - }; - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut LessThanCoreCols = @@ -223,9 +211,9 @@ fn run_rv32_lt_negative_test( if let Some(diff_val) = prank_vals.diff_val { cols.diff_val = F::from_canonical_u32(diff_val); } - cols.cmp_result = F::from_bool(cmp_result); + cols.cmp_result = F::from_bool(prank_cmp_result); - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -234,11 +222,7 @@ fn run_rv32_lt_negative_test( .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -246,8 +230,8 @@ fn rv32_lt_wrong_false_cmp_negative_test() { let b = [145, 34, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -255,8 +239,8 @@ fn rv32_lt_wrong_true_cmp_negative_test() { let b = [73, 35, 25, 205]; let c = [145, 34, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -264,8 +248,8 @@ fn rv32_lt_wrong_eq_negative_test() { let b = [73, 35, 25, 205]; let c = [73, 35, 25, 205]; let prank_vals = Default::default(); - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -276,8 +260,8 @@ fn rv32_lt_fake_diff_val_negative_test() { diff_val: Some(F::NEG_ONE.as_canonical_u32()), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -289,8 +273,8 @@ fn rv32_lt_zero_diff_val_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } #[test] @@ -302,8 +286,8 @@ fn rv32_lt_fake_diff_marker_negative_test() { diff_val: Some(72), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -315,8 +299,8 @@ fn rv32_lt_zero_diff_marker_negative_test() { diff_val: Some(0), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -329,7 +313,7 @@ fn rv32_slt_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, false); + run_negative_less_than_test(SLT, b, c, false, prank_vals, false); } #[test] @@ -342,7 +326,7 @@ fn rv32_slt_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, false, prank_vals, true); + run_negative_less_than_test(SLT, b, c, false, prank_vals, true); } #[test] @@ -355,7 +339,7 @@ fn rv32_slt_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, false); + run_negative_less_than_test(SLT, b, c, true, prank_vals, false); } #[test] @@ -368,7 +352,7 @@ fn rv32_slt_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLT, b, c, true, prank_vals, true); + run_negative_less_than_test(SLT, b, c, true, prank_vals, true); } #[test] @@ -381,7 +365,7 @@ fn rv32_sltu_wrong_b_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, false); } #[test] @@ -394,7 +378,7 @@ fn rv32_sltu_wrong_b_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, true, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, true, prank_vals, true); } #[test] @@ -407,7 +391,7 @@ fn rv32_sltu_wrong_c_msb_negative_test() { diff_val: Some(1), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, false); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, false); } #[test] @@ -420,7 +404,7 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { diff_val: Some(256), ..Default::default() }; - run_rv32_lt_negative_test(LessThanOpcode::SLTU, b, c, false, prank_vals, true); + run_negative_less_than_test(SLTU, b, c, false, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -431,10 +415,10 @@ fn rv32_sltu_wrong_c_msb_sign_negative_test() { #[test] fn run_sltu_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLTU, &x, &y); + run_less_than::(SLTU, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(!x_sign); // unsigned @@ -443,10 +427,10 @@ fn run_sltu_sanity_test() { #[test] fn run_slt_same_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [145, 34, 25, 205]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [73, 35, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(SLT, &x, &y); assert!(cmp_result); assert_eq!(diff_idx, 1); assert!(x_sign); // negative @@ -455,10 +439,10 @@ fn run_slt_same_sign_sanity_test() { #[test] fn run_slt_diff_sign_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [173, 34, 25, 205]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &y); + run_less_than::(SLT, &x, &y); assert!(!cmp_result); assert_eq!(diff_idx, 3); assert!(!x_sign); // positive @@ -467,9 +451,9 @@ fn run_slt_diff_sign_sanity_test() { #[test] fn run_less_than_equal_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 35, 25, 55]; let (cmp_result, diff_idx, x_sign, y_sign) = - run_less_than::(LessThanOpcode::SLT, &x, &x); + run_less_than::(SLT, &x, &x); assert!(!cmp_result); assert_eq!(diff_idx, RV32_REGISTER_NUM_LIMBS); assert!(!x_sign); // positive diff --git a/extensions/rv32im/circuit/src/load_sign_extend/core.rs b/extensions/rv32im/circuit/src/load_sign_extend/core.rs index 2284d6815c..5142502200 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/core.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/core.rs @@ -3,15 +3,23 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, Result, StepExecutorE1, TraceStep, + VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ utils::select, var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -59,7 +67,7 @@ pub struct LoadSignExtendCoreRecord { pub most_sig_bit: bool, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadSignExtendCoreAir { pub range_bus: VariableRangeCheckerBus, } @@ -178,53 +186,73 @@ where } } -pub struct LoadSignExtendCoreChip { - pub air: LoadSignExtendCoreAir, +pub struct LoadSignExtendStep { + adapter: A, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl LoadSignExtendCoreChip { - pub fn new(range_checker_chip: SharedVariableRangeCheckerChip) -> Self { +impl + LoadSignExtendStep +{ + pub fn new(adapter: A, range_checker_chip: SharedVariableRangeCheckerChip) -> Self { Self { - air: LoadSignExtendCoreAir:: { - range_bus: range_checker_chip.bus(), - }, + adapter, range_checker_chip, } } } -impl, const NUM_CELLS: usize, const LIMB_BITS: usize> - VmCoreChip for LoadSignExtendCoreChip +impl TraceStep + for LoadSignExtendStep where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (([u8; NUM_CELLS], [u8; NUM_CELLS]), u32), + WriteData = [u8; NUM_CELLS], + TraceContext<'a> = &'a SharedVariableRangeCheckerChip, + >, { - type Record = LoadSignExtendCoreRecord; - type Air = LoadSignExtendCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { opcode, .. } = instruction; + let local_opcode = Rv32LoadStoreOpcode::from_usize( - instruction - .opcode - .local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), ); - let (data, shift_amount) = reads.into(); - let shift_amount = shift_amount.as_canonical_u32(); + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let ((prev_data, read_data), shift_amount) = + self.adapter.read(state.memory, instruction, adapter_row); + let prev_data = prev_data.map(F::from_canonical_u8); + let read_data = read_data.map(F::from_canonical_u8); + + // TODO(ayush): should functions operate on u8 limbs instead of F? let write_data: [F; NUM_CELLS] = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>( local_opcode, - data[1], - data[0], + read_data, + prev_data, shift_amount, ); - let output = AdapterRuntimeContext::without_pc([write_data]); let most_sig_limb = match local_opcode { LOADB => write_data[0], @@ -234,50 +262,107 @@ where .as_canonical_u32(); let most_sig_bit = most_sig_limb & (1 << (LIMB_BITS - 1)); + + let read_shift = shift_amount & 2; + + let core_row: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); + core_row.opcode_loadb_flag0 = + F::from_bool(local_opcode == LOADB && (shift_amount & 1) == 0); + core_row.opcode_loadb_flag1 = + F::from_bool(local_opcode == LOADB && (shift_amount & 1) == 1); + core_row.opcode_loadh_flag = F::from_bool(local_opcode == LOADH); + core_row.shift_most_sig_bit = F::from_canonical_u32((shift_amount & 2) >> 1); + core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0); + core_row.prev_data = prev_data; + core_row.shifted_read_data = + array::from_fn(|i| read_data[(i + read_shift as usize) % NUM_CELLS]); + + self.adapter.write( + state.memory, + instruction, + adapter_row, + &write_data.map(|x| x.as_canonical_u32() as u8), + ); + + // TODO(ayush): move to fill_trace_row self.range_checker_chip .add_count(most_sig_limb - most_sig_bit, LIMB_BITS - 1); - let read_shift = shift_amount & 2; + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; - Ok(( - output, - LoadSignExtendCoreRecord { - opcode: local_opcode, - most_sig_bit: most_sig_bit != 0, - prev_data: data[0], - shifted_read_data: array::from_fn(|i| { - data[1][(i + read_shift as usize) % NUM_CELLS] - }), - shift_amount, - }, - )) + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + let _core_row: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); + + self.adapter + .fill_trace_row(mem_helper, &self.range_checker_chip, adapter_row); } +} + +impl StepExecutorE1 + for LoadSignExtendStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData = (([u8; NUM_CELLS], [u8; NUM_CELLS]), u32), + WriteData = [u8; NUM_CELLS], + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + let local_opcode = Rv32LoadStoreOpcode::from_usize( + opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET), + ); + + let ((_, read_data), shift_amount) = self.adapter.read(state, instruction); + let read_data = read_data.map(F::from_canonical_u8); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadSignExtendCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; - let shift = record.shift_amount; - core_cols.opcode_loadb_flag0 = F::from_bool(opcode == LOADB && (shift & 1) == 0); - core_cols.opcode_loadb_flag1 = F::from_bool(opcode == LOADB && (shift & 1) == 1); - core_cols.opcode_loadh_flag = F::from_bool(opcode == LOADH); - core_cols.shift_most_sig_bit = F::from_canonical_u32((shift & 2) >> 1); - core_cols.data_most_sig_bit = F::from_bool(record.most_sig_bit); - core_cols.prev_data = record.prev_data; - core_cols.shifted_read_data = record.shifted_read_data; + // TODO(ayush): clean this up for e1 + let write_data = run_write_data_sign_extend::<_, NUM_CELLS, LIMB_BITS>( + local_opcode, + read_data, + [F::ZERO; NUM_CELLS], + shift_amount, + ); + let write_data = write_data.map(|x| x.as_canonical_u32() as u8); + + self.adapter.write(state, instruction, &write_data); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } +// TODO(ayush): remove _prev_data +#[inline(always)] pub(super) fn run_write_data_sign_extend< F: PrimeField32, const NUM_CELLS: usize, diff --git a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs index 79efbe912e..43c563e432 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/mod.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/mod.rs @@ -1,7 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use crate::adapters::Rv32LoadStoreAdapterChip; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep}; mod core; pub use core::*; @@ -9,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32LoadSignExtendChip = VmChipWrapper< - F, - Rv32LoadStoreAdapterChip, - LoadSignExtendCoreChip, +pub type Rv32LoadSignExtendAir = VmAirWrapper< + Rv32LoadStoreAdapterAir, + LoadSignExtendCoreAir, >; +pub type Rv32LoadSignExtendStep = + LoadSignExtendStep; +pub type Rv32LoadSignExtendChip = + NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs index 0fe6d859d1..b3f4805a00 100644 --- a/extensions/rv32im/circuit/src/load_sign_extend/tests.rs +++ b/extensions/rv32im/circuit/src/load_sign_extend/tests.rs @@ -2,7 +2,7 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::arch::{ testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, + VmAirWrapper, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; @@ -14,19 +14,24 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::run_write_data_sign_extend; +use super::{run_write_data_sign_extend, LoadSignExtendCoreAir}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + compose, Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep, RV32_CELL_BITS, + RV32_REGISTER_NUM_LIMBS, + }, load_sign_extend::LoadSignExtendCoreCols, - LoadSignExtendCoreChip, Rv32LoadSignExtendChip, + test_utils::get_verification_error, + LoadSignExtendStep, Rv32LoadSignExtendChip, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; @@ -34,6 +39,27 @@ fn into_limbs(num: u32) -> [u32; array::from_fn(|i| (num >> (LIMB_BITS * i)) & ((1 << LIMB_BITS) - 1)) } +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32LoadSignExtendChip { + let range_checker_chip = tester.memory_controller().range_checker.clone(); + Rv32LoadSignExtendChip::::new( + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadSignExtendCoreAir::new(range_checker_chip.bus()), + ), + LoadSignExtendStep::new( + Rv32LoadStoreAdapterStep::new(tester.address_bits()), + range_checker_chip.clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, @@ -55,13 +81,7 @@ fn set_and_execute( _ => unreachable!(), }; let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), + 0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - alignment)), ) << alignment; let rs1 = rs1 @@ -123,40 +143,19 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_load_sign_extend_test() { - setup_tracing(); +#[test_case(LOADB, 100)] +#[test_case(LOADH, 100)] +fn rand_load_sign_extend_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADB, - None, - None, - None, - None, - ); + let mut chip = create_test_chip(&mut tester); + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, &mut rng, - LOADH, + opcode, None, None, None, @@ -172,36 +171,29 @@ fn rand_load_sign_extend_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, - read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadSignExtPrankValues { data_most_sig_bit: Option, shift_most_sig_bit: Option, opcode_flags: Option<[bool; 3]>, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_load_sign_extend_test( + opcode: Rv32LoadStoreOpcode, + read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - expected_error: VerificationError, + prank_vals: LoadSignExtPrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip.clone()); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -214,78 +206,78 @@ fn run_negative_loadstore_test( imm_sign, ); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); let core_cols: &mut LoadSignExtendCoreCols = core_row.borrow_mut(); - if let Some(shifted_read_data) = read_data { core_cols.shifted_read_data = shifted_read_data.map(F::from_canonical_u32); } - - if let Some(data_most_sig_bit) = data_most_sig_bit { + if let Some(data_most_sig_bit) = prank_vals.data_most_sig_bit { core_cols.data_most_sig_bit = F::from_canonical_u32(data_most_sig_bit); } - if let Some(shift_most_sig_bit) = shift_most_sig_bit { + if let Some(shift_most_sig_bit) = prank_vals.shift_most_sig_bit { core_cols.shift_most_sig_bit = F::from_canonical_u32(shift_most_sig_bit); } - - if let Some(opcode_flags) = opcode_flags { + if let Some(opcode_flags) = prank_vals.opcode_flags { core_cols.opcode_loadb_flag0 = F::from_bool(opcode_flags[0]); core_cols.opcode_loadb_flag1 = F::from_bool(opcode_flags[1]); core_cols.opcode_loadh_flag = F::from_bool(opcode_flags[2]); } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn loadstore_negative_tests() { - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, Some([233, 187, 145, 238]), - Some(0), - None, None, None, None, - None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + data_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADH, None, - None, - Some(0), - None, Some([202, 109, 183, 26]), Some(31212), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + shift_most_sig_bit: Some(0), + ..Default::default() + }, + true, ); - run_negative_loadstore_test( + run_negative_load_sign_extend_test( LOADB, None, - None, - None, - Some([true, false, false]), Some([250, 132, 77, 5]), Some(47741), None, - VerificationError::ChallengePhaseError, + LoadSignExtPrankValues { + opcode_flags: Some([true, false, false]), + ..Default::default() + }, + true, ); } @@ -294,46 +286,6 @@ fn loadstore_negative_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadSignExtendCoreChip::new(range_checker_chip); - let mut chip = - Rv32LoadSignExtendChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 10; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADH, - None, - None, - None, - None, - ); - } -} #[test] fn solve_loadh_extend_sign_sanity_test() { diff --git a/extensions/rv32im/circuit/src/loadstore/core.rs b/extensions/rv32im/circuit/src/loadstore/core.rs index 36beb10629..9b9eeb59ab 100644 --- a/extensions/rv32im/circuit/src/loadstore/core.rs +++ b/extensions/rv32im/circuit/src/loadstore/core.rs @@ -1,10 +1,19 @@ use std::borrow::{Borrow, BorrowMut}; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, Result, StepExecutorE1, TraceStep, + VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; +use openvm_circuit_primitives::var_range::SharedVariableRangeCheckerChip; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *}; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -70,7 +79,7 @@ pub struct LoadStoreCoreRecord { pub write_data: [F; NUM_CELLS], } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, derive_new::new)] pub struct LoadStoreCoreAir { pub offset: usize, } @@ -246,70 +255,74 @@ where } } -#[derive(Debug)] -pub struct LoadStoreCoreChip { - pub air: LoadStoreCoreAir, +pub struct LoadStoreStep { + adapter: A, + pub range_checker_chip: SharedVariableRangeCheckerChip, + pub offset: usize, } -impl LoadStoreCoreChip { - pub fn new(offset: usize) -> Self { +impl LoadStoreStep { + pub fn new( + adapter: A, + range_checker_chip: SharedVariableRangeCheckerChip, + offset: usize, + ) -> Self { Self { - air: LoadStoreCoreAir { offset }, + adapter, + range_checker_chip, + offset, } } } -impl, const NUM_CELLS: usize> VmCoreChip - for LoadStoreCoreChip +impl TraceStep for LoadStoreStep where - I::Reads: Into<([[F; NUM_CELLS]; 2], F)>, - I::Writes: From<[[F; NUM_CELLS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData = (([u8; NUM_CELLS], [u8; NUM_CELLS]), u32), + WriteData = [u8; NUM_CELLS], + TraceContext<'a> = &'a SharedVariableRangeCheckerChip, + >, { - type Record = LoadStoreCoreRecord; - type Air = LoadStoreCoreAir; - - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, - instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { - let local_opcode = - Rv32LoadStoreOpcode::from_usize(instruction.opcode.local_opcode_idx(self.air.offset)); - - let (reads, shift_amount) = reads.into(); - let shift = shift_amount.as_canonical_u32(); - let prev_data = reads[0]; - let read_data = reads[1]; - let write_data = run_write_data(local_opcode, read_data, prev_data, shift); - let output = AdapterRuntimeContext::without_pc([write_data]); - - Ok(( - output, - LoadStoreCoreRecord { - opcode: local_opcode, - shift, - prev_data, - read_data, - write_data, - }, - )) - } - fn get_opcode_name(&self, opcode: usize) -> String { format!( "{:?}", - Rv32LoadStoreOpcode::from_usize(opcode - self.air.offset) + Rv32LoadStoreOpcode::from_usize(opcode - self.offset) ) } - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let core_cols: &mut LoadStoreCoreCols = row_slice.borrow_mut(); - let opcode = record.opcode; + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { opcode, .. } = instruction; + + let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let ((prev_data, read_data), shift) = + self.adapter.read(state.memory, instruction, adapter_row); + let prev_data = prev_data.map(F::from_canonical_u8); + let read_data = read_data.map(F::from_canonical_u8); + + let write_data = run_write_data(local_opcode, read_data, prev_data, shift); + + let core_cols: &mut LoadStoreCoreCols = core_row.borrow_mut(); + let flags = &mut core_cols.flags; *flags = [F::ZERO; 4]; - match (opcode, record.shift) { + match (local_opcode, shift) { (LOADW, 0) => flags[0] = F::TWO, (LOADHU, 0) => flags[1] = F::TWO, (LOADHU, 2) => flags[2] = F::TWO, @@ -328,18 +341,86 @@ where (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE), _ => unreachable!(), }; - core_cols.prev_data = record.prev_data; - core_cols.read_data = record.read_data; + core_cols.prev_data = prev_data; + core_cols.read_data = read_data; core_cols.is_valid = F::ONE; - core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode)); - core_cols.write_data = record.write_data; + core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&local_opcode)); + core_cols.write_data = write_data; + + self.adapter.write( + state.memory, + instruction, + adapter_row, + &write_data.map(|x| x.as_canonical_u32() as u8), + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter + .fill_trace_row(mem_helper, &self.range_checker_chip, adapter_row); + } +} + +impl StepExecutorE1 for LoadStoreStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData = (([u8; NUM_CELLS], [u8; NUM_CELLS]), u32), + WriteData = [u8; NUM_CELLS], + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + // Get the local opcode for this instruction + let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let ((prev_data, read_data), shift_amount) = self.adapter.read(state, instruction); + let prev_data = prev_data.map(F::from_canonical_u8); + let read_data = read_data.map(F::from_canonical_u8); + + // Process the data according to the load/store type and alignment + let write_data = run_write_data(local_opcode, read_data, prev_data, shift_amount); + let write_data = write_data.map(|x| x.as_canonical_u32() as u8); + + self.adapter.write(state, instruction, &write_data); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } +#[inline(always)] pub(super) fn run_write_data( opcode: Rv32LoadStoreOpcode, read_data: [F; NUM_CELLS], diff --git a/extensions/rv32im/circuit/src/loadstore/mod.rs b/extensions/rv32im/circuit/src/loadstore/mod.rs index 825f82166c..51735afbd7 100644 --- a/extensions/rv32im/circuit/src/loadstore/mod.rs +++ b/extensions/rv32im/circuit/src/loadstore/mod.rs @@ -2,12 +2,15 @@ mod core; pub use core::*; -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32LoadStoreAdapterChip, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::RV32_REGISTER_NUM_LIMBS; +use crate::adapters::{Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterStep}; #[cfg(test)] mod tests; -pub type Rv32LoadStoreChip = - VmChipWrapper, LoadStoreCoreChip>; +pub type Rv32LoadStoreAir = + VmAirWrapper>; +pub type Rv32LoadStoreStep = LoadStoreStep; +pub type Rv32LoadStoreChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/loadstore/tests.rs b/extensions/rv32im/circuit/src/loadstore/tests.rs index 0fbfa137b9..e090570cb4 100644 --- a/extensions/rv32im/circuit/src/loadstore/tests.rs +++ b/extensions/rv32im/circuit/src/loadstore/tests.rs @@ -3,7 +3,7 @@ use std::{array, borrow::BorrowMut}; use openvm_circuit::{ arch::{ testing::{memory::gen_pointer, VmChipTestBuilder}, - VmAdapterChip, + VmAirWrapper, }, utils::u32_into_limbs, }; @@ -17,21 +17,49 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, }; -use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, seq::SliceRandom, Rng}; +use test_case::test_case; -use super::{run_write_data, LoadStoreCoreChip, Rv32LoadStoreChip}; +use super::{run_write_data, LoadStoreCoreAir, LoadStoreStep, Rv32LoadStoreChip}; use crate::{ - adapters::{compose, Rv32LoadStoreAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + adapters::{ + compose, Rv32LoadStoreAdapterAir, Rv32LoadStoreAdapterCols, Rv32LoadStoreAdapterStep, + RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, loadstore::LoadStoreCoreCols, + test_utils::get_verification_error, }; const IMM_BITS: usize = 16; +const MAX_INS_CAPACITY: usize = 128; type F = BabyBear; +fn create_test_chip(tester: &mut VmChipTestBuilder) -> Rv32LoadStoreChip { + let range_checker_chip = tester.range_checker(); + + Rv32LoadStoreChip::::new( + VmAirWrapper::new( + Rv32LoadStoreAdapterAir::new( + tester.memory_bridge(), + tester.execution_bridge(), + range_checker_chip.bus(), + tester.address_bits(), + ), + LoadStoreCoreAir::new(Rv32LoadStoreOpcode::CLASS_OFFSET), + ), + LoadStoreStep::new( + Rv32LoadStoreAdapterStep::new(tester.address_bits()), + range_checker_chip.clone(), + Rv32LoadStoreOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ) +} + #[allow(clippy::too_many_arguments)] fn set_and_execute( tester: &mut VmChipTestBuilder, @@ -55,13 +83,7 @@ fn set_and_execute( }; let ptr_val = rng.gen_range( - 0..(1 - << (tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits - - alignment)), + 0..(1 << (tester.memory_controller().mem_config().pointer_max_bits - alignment)), ) << alignment; let rs1 = rs1 @@ -143,80 +165,23 @@ fn set_and_execute( /// Randomly generate computations and execute, ensuring that the generated trace /// passes all constraints. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_loadstore_test() { - setup_tracing(); +#[test_case(LOADW, 100)] +#[test_case(LOADBU, 100)] +#[test_case(LOADHU, 100)] +#[test_case(STOREW, 100)] +#[test_case(STOREB, 100)] +#[test_case(STOREH, 100)] +fn rand_loadstore_test(opcode: Rv32LoadStoreOpcode, num_ops: usize) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = create_test_chip(&mut tester); - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); + for _ in 0..num_ops { set_and_execute( &mut tester, &mut chip, &mut rng, - LOADHU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, + opcode, None, None, None, @@ -224,7 +189,6 @@ fn rand_loadstore_test() { ); } - drop(range_checker_chip); let tester = tester.build().load(chip).finalize(); tester.simple_test().expect("Verification failed"); } @@ -233,38 +197,31 @@ fn rand_loadstore_test() { // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adaptor is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -#[allow(clippy::too_many_arguments)] -fn run_negative_loadstore_test( - opcode: Rv32LoadStoreOpcode, +#[derive(Clone, Copy, Default, PartialEq)] +struct LoadStorePrankValues { read_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, prev_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, write_data: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, flags: Option<[u32; 4]>, is_load: Option, + mem_as: Option, +} + +#[allow(clippy::too_many_arguments)] +fn run_negative_loadstore_test( + opcode: Rv32LoadStoreOpcode, rs1: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, imm: Option, imm_sign: Option, - mem_as: Option, - expected_error: VerificationError, + prank_vals: LoadStorePrankValues, + interaction_error: bool, ) { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let adapter_width = BaseAir::::width(adapter.air()); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); + let mut chip = create_test_chip(&mut tester); set_and_execute( &mut tester, @@ -274,38 +231,45 @@ fn run_negative_loadstore_test( rs1, imm, imm_sign, - mem_as, + None, ); + let adapter_width = BaseAir::::width(&chip.air.adapter); + let modify_trace = |trace: &mut DenseMatrix| { let mut trace_row = trace.row_slice(0).to_vec(); - let (_, core_row) = trace_row.split_at_mut(adapter_width); + let (adapter_row, core_row) = trace_row.split_at_mut(adapter_width); + let adapter_cols: &mut Rv32LoadStoreAdapterCols = adapter_row.borrow_mut(); let core_cols: &mut LoadStoreCoreCols = core_row.borrow_mut(); - if let Some(read_data) = read_data { + + if let Some(read_data) = prank_vals.read_data { core_cols.read_data = read_data.map(F::from_canonical_u32); } - if let Some(prev_data) = prev_data { + if let Some(prev_data) = prank_vals.prev_data { core_cols.prev_data = prev_data.map(F::from_canonical_u32); } - if let Some(write_data) = write_data { + if let Some(write_data) = prank_vals.write_data { core_cols.write_data = write_data.map(F::from_canonical_u32); } - if let Some(flags) = flags { + if let Some(flags) = prank_vals.flags { core_cols.flags = flags.map(F::from_canonical_u32); } - if let Some(is_load) = is_load { + if let Some(is_load) = prank_vals.is_load { core_cols.is_load = F::from_bool(is_load); } + if let Some(mem_as) = prank_vals.mem_as { + adapter_cols.mem_as = F::from_canonical_u32(mem_as); + } + *trace = RowMajorMatrix::new(trace_row, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .finalize(); - tester.simple_test_with_expected_error(expected_error); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -315,41 +279,36 @@ fn negative_wrong_opcode_tests() { None, None, None, - None, - Some(false), - None, - None, - None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + is_load: Some(false), + ..Default::default() + }, + false, ); run_negative_loadstore_test( LOADBU, - None, - None, - None, - Some([0, 0, 0, 2]), - None, Some([4, 0, 0, 0]), Some(1), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([0, 0, 0, 2]), + ..Default::default() + }, + false, ); run_negative_loadstore_test( STOREH, - None, - None, - None, - Some([1, 0, 1, 0]), - Some(true), Some([11, 169, 76, 28]), Some(37121), None, - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + flags: Some([1, 0, 1, 0]), + is_load: Some(true), + ..Default::default() + }, + false, ); } @@ -357,30 +316,34 @@ fn negative_wrong_opcode_tests() { fn negative_write_data_tests() { run_negative_loadstore_test( LOADHU, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 33, 0, 0]), - Some([0, 2, 0, 0]), - Some(true), Some([13, 11, 156, 23]), Some(43641), None, - None, - VerificationError::ChallengePhaseError, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 33, 0, 0]), + flags: Some([0, 2, 0, 0]), + is_load: Some(true), + mem_as: None, + }, + true, ); run_negative_loadstore_test( STOREB, - Some([175, 33, 198, 250]), - Some([90, 121, 64, 205]), - Some([175, 121, 64, 205]), - Some([0, 0, 1, 1]), - None, Some([45, 123, 87, 24]), Some(28122), Some(0), - None, - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + read_data: Some([175, 33, 198, 250]), + prev_data: Some([90, 121, 64, 205]), + write_data: Some([175, 121, 64, 205]), + flags: Some([0, 0, 1, 1]), + is_load: None, + mem_as: None, + }, + false, ); } @@ -391,39 +354,35 @@ fn negative_wrong_address_space_tests() { None, None, None, - None, - None, - None, - None, - None, - Some(3), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(3), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( LOADW, None, None, None, - None, - None, - None, - None, - None, - Some(4), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(4), + ..Default::default() + }, + false, ); + run_negative_loadstore_test( STOREW, None, None, None, - None, - None, - None, - None, - None, - Some(1), - VerificationError::OodEvaluationMismatch, + LoadStorePrankValues { + mem_as: Some(1), + ..Default::default() + }, + false, ); } @@ -432,86 +391,6 @@ fn negative_wrong_address_space_tests() { /// /// Ensure that solve functions produce the correct results. /////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let adapter = Rv32LoadStoreAdapterChip::::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - tester.address_bits(), - range_checker_chip.clone(), - ); - let core = LoadStoreCoreChip::new(Rv32LoadStoreOpcode::CLASS_OFFSET); - let mut chip = Rv32LoadStoreChip::::new(adapter, core, tester.offline_memory_mutex_arc()); - - let num_tests: usize = 100; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADBU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - LOADHU, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREW, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREB, - None, - None, - None, - None, - ); - set_and_execute( - &mut tester, - &mut chip, - &mut rng, - STOREH, - None, - None, - None, - None, - ); - } -} - #[test] fn run_loadw_storew_sanity_test() { let read_data = [138, 45, 202, 76].map(F::from_canonical_u32); diff --git a/extensions/rv32im/circuit/src/mul/core.rs b/extensions/rv32im/circuit/src/mul/core.rs index fa65a6cf09..0aa431baa8 100644 --- a/extensions/rv32im/circuit/src/mul/core.rs +++ b/extensions/rv32im/circuit/src/mul/core.rs @@ -3,13 +3,20 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::MulOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -17,8 +24,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; #[repr(C)] #[derive(AlignedBorrow)] @@ -29,7 +34,7 @@ pub struct MultiplicationCoreCols { pub bus: RangeTupleCheckerBus<2>, pub offset: usize, @@ -110,13 +115,20 @@ where } #[derive(Debug)] -pub struct MultiplicationCoreChip { - pub air: MultiplicationCoreAir, +pub struct MultiplicationStep { + adapter: A, + pub offset: usize, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MultiplicationCoreChip { - pub fn new(range_tuple_chip: SharedRangeTupleCheckerChip<2>, offset: usize) -> Self { +impl + MultiplicationStep +{ + pub fn new( + adapter: A, + range_tuple_chip: SharedRangeTupleCheckerChip<2>, + offset: usize, + ) -> Self { // The RangeTupleChecker is used to range check (a[i], carry[i]) pairs where 0 <= i // < NUM_LIMBS. a[i] must have LIMB_BITS bits and carry[i] is the sum of i + 1 bytes // (with LIMB_BITS bits). @@ -132,102 +144,157 @@ impl MultiplicationCoreChip { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MultiplicationCoreChip +impl TraceStep + for MultiplicationStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = MultiplicationCoreRecord; - type Air = MultiplicationCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", MulOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; + assert_eq!( - MulOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)), + MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)), MulOpcode::MUL ); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, carry) = run_mul::(&b, &c); + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + let (a, carry) = run_mul::(&rs1, &rs2); + + let core_row: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + core_row.a = a.map(F::from_canonical_u8); + core_row.b = rs1.map(F::from_canonical_u8); + core_row.c = rs2.map(F::from_canonical_u8); + core_row.is_valid = F::ONE; + + // TODO(ayush): move to fill_trace_row for (a, carry) in a.iter().zip(carry.iter()) { - self.range_tuple_chip.add_count(&[*a, *carry]); + self.range_tuple_chip.add_count(&[*a as u32, *carry]); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MultiplicationCoreRecord { - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - }; + // TODO(ayush): avoid this conversion + self.adapter + .write(state.memory, instruction, adapter_row, &[a].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; - Ok((output, record)) + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", MulOpcode::from_usize(opcode - self.air.offset)) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } +} + +impl StepExecutorE1 + for MultiplicationStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + // Verify the opcode is MUL + // TODO(ayush): debug_assert + assert_eq!( + MulOpcode::from_usize(opcode.local_opcode_idx(self.offset)), + MulOpcode::MUL + ); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + + let (rd, _) = run_mul::(&rs1, &rs2); - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MultiplicationCoreCols<_, NUM_LIMBS, LIMB_BITS> = - row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.is_valid = F::ONE; + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // returns mul, carry +#[inline(always)] pub(super) fn run_mul( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], [u32; NUM_LIMBS]) { - let mut result = [0; NUM_LIMBS]; - let mut carry = [0; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], [u32; NUM_LIMBS]) { + let mut result = [0u8; NUM_LIMBS]; + let mut carry = [0u32; NUM_LIMBS]; for i in 0..NUM_LIMBS { + let mut res = 0u32; if i > 0 { - result[i] = carry[i - 1]; + res = carry[i - 1]; } for j in 0..=i { - result[i] += x[j] * y[i - j]; + res += (x[j] as u32) * (y[i - j] as u32); } - carry[i] = result[i] >> LIMB_BITS; - result[i] %= 1 << LIMB_BITS; + carry[i] = res >> LIMB_BITS; + res %= 1u32 << LIMB_BITS; + result[i] = res as u8; } (result, carry) } diff --git a/extensions/rv32im/circuit/src/mul/mod.rs b/extensions/rv32im/circuit/src/mul/mod.rs index 5f28439977..00a3fe77bb 100644 --- a/extensions/rv32im/circuit/src/mul/mod.rs +++ b/extensions/rv32im/circuit/src/mul/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,8 +9,11 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32MultiplicationChip = VmChipWrapper< - F, - Rv32MultAdapterChip, - MultiplicationCoreChip, +pub type Rv32MultiplicationAir = VmAirWrapper< + Rv32MultAdapterAir, + MultiplicationCoreAir, >; +pub type Rv32MultiplicationStep = + MultiplicationStep; +pub type Rv32MultiplicationChip = + NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/mul/tests.rs b/extensions/rv32im/circuit/src/mul/tests.rs index b942c24cc3..e21195e088 100644 --- a/extensions/rv32im/circuit/src/mul/tests.rs +++ b/extensions/rv32im/circuit/src/mul/tests.rs @@ -1,15 +1,12 @@ -use std::borrow::BorrowMut; +use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, RANGE_TUPLE_CHECKER_BUS}, + InstructionExecutor, VmAirWrapper, }; use openvm_circuit_primitives::range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::MulOpcode::{self, MUL}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -18,20 +15,74 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; use super::core::run_mul; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mul::{MultiplicationCoreChip, MultiplicationCoreCols, Rv32MultiplicationChip}, - test_utils::rv32_rand_write_register_or_imm, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + mul::{MultiplicationCoreCols, MultiplicationStep, Rv32MultiplicationChip}, + test_utils::{get_verification_error, rv32_rand_write_register_or_imm}, + MultiplicationCoreAir, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> (Rv32MultiplicationChip, SharedRangeTupleCheckerChip<2>) { + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + + let chip = Rv32MultiplicationChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MultiplicationCoreAir::new(range_tuple_bus, MulOpcode::CLASS_OFFSET), + ), + MultiplicationStep::new( + Rv32MultAdapterStep::new(), + range_tuple_checker.clone(), + MulOpcode::CLASS_OFFSET, + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, range_tuple_checker) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: MulOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let c = c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + + let (mut instruction, rd) = + rv32_rand_write_register_or_imm(tester, b, c, None, opcode.global_opcode().as_usize(), rng); + + instruction.e = F::ZERO; + tester.execute(chip, &instruction); + + let (a, _) = run_mul::(&b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) +} + ////////////////////////////////////////////////////////////////////////////////////// // POSITIVE TESTS // @@ -39,48 +90,15 @@ type F = BabyBear; // passes all constraints. ////////////////////////////////////////////////////////////////////////////////////// -fn run_rv32_mul_rand_test(num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; +#[test] +fn run_rv32_mul_rand_test() { let mut rng = create_seeded_rng(); - - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MultiplicationChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - MultiplicationCoreChip::new(range_tuple_checker.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, range_tuple_checker) = create_test_chip(&mut tester); + let num_ops = 100; for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - - let (mut instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - None, - MulOpcode::MUL.global_opcode().as_usize(), - &mut rng, - ); - instruction.e = F::ZERO; - tester.execute(&mut chip, &instruction); - - let (a, _) = run_mul::(&b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) - ) + set_and_execute(&mut tester, &mut chip, &mut rng, MUL, None, None); } let tester = tester @@ -91,74 +109,36 @@ fn run_rv32_mul_rand_test(num_ops: usize) { tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mul_rand_test() { - run_rv32_mul_rand_test(1); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MultiplicationTestChip = VmChipWrapper< - F, - TestAdapterChip, - MultiplicationCoreChip, ->; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mul_negative_test( - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], - is_valid: bool, +fn run_negative_mul_test( + opcode: MulOpcode, + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], + prank_is_valid: bool, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MultiplicationTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MultiplicationCoreChip::new(range_tuple_chip.clone(), MulOpcode::CLASS_OFFSET), - tester.offline_memory_mutex_arc(), - ); - - tester.execute( - &mut chip, - &Instruction::from_usize(MulOpcode::MUL.global_opcode(), [0, 0, 0, 1, 0]), - ); + let (mut chip, range_tuple_chip) = create_test_chip(&mut tester); - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, carry) = run_mul::(&b, &c); - - range_tuple_chip.clear(); - if is_valid { - for (a, carry) in a.iter().zip(carry.iter()) { - range_tuple_chip.add_count(&[*a, *carry]); - } - } + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MultiplicationCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.is_valid = F::from_bool(is_valid); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.is_valid = F::from_bool(prank_is_valid); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -167,16 +147,13 @@ fn run_rv32_mul_negative_test( .load_and_prank_trace(chip, modify_trace) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mul_wrong_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -187,7 +164,8 @@ fn rv32_mul_wrong_negative_test() { #[test] fn rv32_mul_is_valid_false_negative_test() { - run_rv32_mul_negative_test( + run_negative_mul_test( + MUL, [63, 247, 125, 234], [51, 109, 78, 142], [197, 85, 150, 32], @@ -204,9 +182,9 @@ fn rv32_mul_is_valid_false_negative_test() { #[test] fn run_mul_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [197, 85, 150, 32]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [51, 109, 78, 142]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [63, 247, 125, 232]; let c: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (result, carry) = run_mul::(&x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { diff --git a/extensions/rv32im/circuit/src/mulh/core.rs b/extensions/rv32im/circuit/src/mulh/core.rs index 16aa8fd550..354db4ae5a 100644 --- a/extensions/rv32im/circuit/src/mulh/core.rs +++ b/extensions/rv32im/circuit/src/mulh/core.rs @@ -3,16 +3,23 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::MulHOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -20,8 +27,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -40,7 +45,7 @@ pub struct MulHCoreCols { pub opcode_mulhu_flag: T, } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct MulHCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_tuple_bus: RangeTupleCheckerBus<2>, @@ -183,14 +188,15 @@ where } } -pub struct MulHCoreChip { - pub air: MulHCoreAir, +pub struct MulHStep { + adapter: A, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_tuple_chip: SharedRangeTupleCheckerChip<2>, } -impl MulHCoreChip { +impl MulHStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_tuple_chip: SharedRangeTupleCheckerChip<2>, ) -> Self { @@ -209,56 +215,71 @@ impl MulHCoreChip { - pub opcode: MulHOpcode, - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub a_mul: [T; NUM_LIMBS], - pub b_ext: T, - pub c_ext: T, -} - -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for MulHCoreChip +impl TraceStep + for MulHStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = MulHCoreRecord; - type Air = MulHCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!( + "{:?}", + MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) + ) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; + let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let b = rs1.map(u32::from); + let c = rs2.map(u32::from); let (a, a_mul, carry, b_ext, c_ext) = run_mulh::(mulh_opcode, &b, &c); + let core_row: &mut MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = core_row.borrow_mut(); + core_row.a = a.map(F::from_canonical_u32); + core_row.b = b.map(F::from_canonical_u32); + core_row.c = c.map(F::from_canonical_u32); + core_row.a_mul = a_mul.map(F::from_canonical_u32); + core_row.b_ext = F::from_canonical_u32(b_ext); + core_row.c_ext = F::from_canonical_u32(c_ext); + core_row.opcode_mulh_flag = F::from_bool(mulh_opcode == MulHOpcode::MULH); + core_row.opcode_mulhsu_flag = F::from_bool(mulh_opcode == MulHOpcode::MULHSU); + core_row.opcode_mulhu_flag = F::from_bool(mulh_opcode == MulHOpcode::MULHU); + + // TODO(ayush): move to fill_trace_row for i in 0..NUM_LIMBS { self.range_tuple_chip.add_count(&[a_mul[i], carry[i]]); self.range_tuple_chip @@ -274,46 +295,77 @@ where ); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = MulHCoreRecord { - opcode: mulh_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - a_mul: a_mul.map(F::from_canonical_u32), - b_ext: F::from_canonical_u32(b_ext), - c_ext: F::from_canonical_u32(c_ext), - }; + // TODO(ayush): avoid this conversion + let a = a.map(|x| x as u8); + self.adapter + .write(state.memory, instruction, adapter_row, &[a].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - Ok((output, record)) + *trace_offset += width; + + Ok(()) } - fn get_opcode_name(&self, opcode: usize) -> String { - format!( - "{:?}", - MulHOpcode::from_usize(opcode - MulHOpcode::CLASS_OFFSET) - ) + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, _core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); } +} - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - let row_slice: &mut MulHCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.a_mul = record.a_mul; - row_slice.b_ext = record.b_ext; - row_slice.c_ext = record.c_ext; - row_slice.opcode_mulh_flag = F::from_bool(record.opcode == MulHOpcode::MULH); - row_slice.opcode_mulhsu_flag = F::from_bool(record.opcode == MulHOpcode::MULHSU); - row_slice.opcode_mulhu_flag = F::from_bool(record.opcode == MulHOpcode::MULHU); +impl StepExecutorE1 + for MulHStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; + + let mulh_opcode = MulHOpcode::from_usize(opcode.local_opcode_idx(MulHOpcode::CLASS_OFFSET)); + + let [rs1, rs2] = self.adapter.read(state, instruction).into(); + let rs1 = rs1.map(u32::from); + let rs2 = rs2.map(u32::from); + + let (rd, _, _, _, _) = run_mulh::(mulh_opcode, &rs1, &rs2); + let rd = rd.map(|x| x as u8); + + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } // returns mulh[[s]u], mul, carry, x_ext, y_ext +#[inline(always)] pub(super) fn run_mulh( opcode: MulHOpcode, x: &[u32; NUM_LIMBS], diff --git a/extensions/rv32im/circuit/src/mulh/mod.rs b/extensions/rv32im/circuit/src/mulh/mod.rs index 284b77191a..39607bdf5c 100644 --- a/extensions/rv32im/circuit/src/mulh/mod.rs +++ b/extensions/rv32im/circuit/src/mulh/mod.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use crate::adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep}; mod core; pub use core::*; @@ -8,5 +9,7 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32MulHChip = - VmChipWrapper, MulHCoreChip>; +pub type Rv32MulHAir = + VmAirWrapper>; +pub type Rv32MulHStep = MulHStep; +pub type Rv32MulHChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/mulh/tests.rs b/extensions/rv32im/circuit/src/mulh/tests.rs index 1c7cf5b5cb..54bff347c1 100644 --- a/extensions/rv32im/circuit/src/mulh/tests.rs +++ b/extensions/rv32im/circuit/src/mulh/tests.rs @@ -3,10 +3,9 @@ use std::borrow::BorrowMut; use openvm_circuit::{ arch::{ testing::{ - memory::gen_pointer, TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, - RANGE_TUPLE_CHECKER_BUS, + memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, RANGE_TUPLE_CHECKER_BUS, }, - ExecutionBridge, InstructionExecutor, VmAdapterChip, VmChipWrapper, + InstructionExecutor, VmAirWrapper, }, utils::generate_long_number, }; @@ -15,7 +14,7 @@ use openvm_circuit_primitives::{ range_tuple::{RangeTupleCheckerBus, SharedRangeTupleCheckerChip}, }; use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::MulHOpcode; +use openvm_rv32im_transpiler::MulHOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -24,36 +23,75 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::rngs::StdRng; +use test_case::test_case; use super::core::run_mulh; use crate::{ - adapters::{Rv32MultAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - mulh::{MulHCoreChip, MulHCoreCols, Rv32MulHChip}, + adapters::{Rv32MultAdapterAir, Rv32MultAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, + mulh::{MulHCoreCols, MulHStep, Rv32MulHChip}, + test_utils::get_verification_error, + MulHCoreAir, }; +const MAX_INS_CAPACITY: usize = 128; +// the max number of limbs we currently support MUL for is 32 (i.e. for U256s) +const MAX_NUM_LIMBS: u32 = 32; type F = BabyBear; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip( + tester: &mut VmChipTestBuilder, +) -> ( + Rv32MulHChip, + SharedBitwiseOperationLookupChip, + SharedRangeTupleCheckerChip<2>, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let range_tuple_bus = RangeTupleCheckerBus::new( + RANGE_TUPLE_CHECKER_BUS, + [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], + ); + + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); + + let chip = Rv32MulHChip::::new( + VmAirWrapper::new( + Rv32MultAdapterAir::new(tester.execution_bridge(), tester.memory_bridge()), + MulHCoreAir::new(bitwise_bus, range_tuple_bus), + ), + MulHStep::new( + Rv32MultAdapterStep::new(), + bitwise_chip.clone(), + range_tuple_checker.clone(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + + (chip, bitwise_chip, range_tuple_checker) +} #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_rand_write_execute>( - opcode: MulHOpcode, +fn set_and_execute>( tester: &mut VmChipTestBuilder, chip: &mut E, - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], rng: &mut StdRng, + opcode: MulHOpcode, + b: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, + c: Option<[u32; RV32_REGISTER_NUM_LIMBS]>, ) { + let b = b.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let c = c.unwrap_or(generate_long_number::< + RV32_REGISTER_NUM_LIMBS, + RV32_CELL_BITS, + >(rng)); + let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); let rd = gen_pointer(rng, 4); @@ -61,47 +99,35 @@ fn run_rv32_mulh_rand_write_execute>( tester.write::(1, rs1, b.map(F::from_canonical_u32)); tester.write::(1, rs2, c.map(F::from_canonical_u32)); - let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); tester.execute( chip, &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 0]), ); + let (a, _, _, _, _) = run_mulh::(opcode, &b, &c); assert_eq!( a.map(F::from_canonical_u32), tester.read::(1, rd) ); } +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// + +#[test_case(MULH, 100)] +#[test_case(MULHSU, 100)] +#[test_case(MULHU, 100)] fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { - // the max number of limbs we currently support MUL for is 32 (i.e. for U256s) - const MAX_NUM_LIMBS: u32 = 32; let mut rng = create_seeded_rng(); - - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_checker = SharedRangeTupleCheckerChip::new(range_tuple_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHChip::::new( - Rv32MultAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_checker.clone()), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_checker) = create_test_chip(&mut tester); for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let c = generate_long_number::(&mut rng); - run_rv32_mulh_rand_write_execute(opcode, &mut tester, &mut chip, b, c, &mut rng); + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None); } let tester = tester @@ -113,88 +139,40 @@ fn run_rv32_mulh_rand_test(opcode: MulHOpcode, num_ops: usize) { tester.simple_test().expect("Verification failed"); } -#[test] -fn rv32_mulh_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULH, 100); -} - -#[test] -fn rv32_mulhsu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHSU, 100); -} - -#[test] -fn rv32_mulhu_rand_test() { - run_rv32_mulh_rand_test(MulHOpcode::MULHU, 100); -} - ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32MulHTestChip = - VmChipWrapper, MulHCoreChip>; - #[allow(clippy::too_many_arguments)] -fn run_rv32_mulh_negative_test( +fn run_negative_mulh_test( opcode: MulHOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], b: [u32; RV32_REGISTER_NUM_LIMBS], c: [u32; RV32_REGISTER_NUM_LIMBS], - a_mul: [u32; RV32_REGISTER_NUM_LIMBS], - b_ext: u32, - c_ext: u32, + prank_a_mul: [u32; RV32_REGISTER_NUM_LIMBS], + prank_b_ext: u32, + prank_c_ext: u32, interaction_error: bool, ) { - const MAX_NUM_LIMBS: u32 = 32; - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let range_tuple_bus = RangeTupleCheckerBus::new( - RANGE_TUPLE_CHECKER_BUS, - [1 << RV32_CELL_BITS, MAX_NUM_LIMBS * (1 << RV32_CELL_BITS)], - ); - - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let range_tuple_chip = SharedRangeTupleCheckerChip::new(range_tuple_bus); - + let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32MulHTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - MulHCoreChip::new(bitwise_chip.clone(), range_tuple_chip.clone()), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip, range_tuple_chip) = create_test_chip(&mut tester); - tester.execute( - &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 0]), - ); - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - let (_, _, carry, _, _) = run_mulh::(opcode, &b, &c); - - range_tuple_chip.clear(); - for i in 0..RV32_REGISTER_NUM_LIMBS { - range_tuple_chip.add_count(&[a_mul[i], carry[i]]); - range_tuple_chip.add_count(&[a[i], carry[RV32_REGISTER_NUM_LIMBS + i]]); - } + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, Some(b), Some(c)); + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut MulHCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); - cols.a_mul = a_mul.map(F::from_canonical_u32); - cols.b_ext = F::from_canonical_u32(b_ext); - cols.c_ext = F::from_canonical_u32(c_ext); - *trace = RowMajorMatrix::new(values, trace_width); + cols.a = prank_a.map(F::from_canonical_u32); + cols.a_mul = prank_a_mul.map(F::from_canonical_u32); + cols.b_ext = F::from_canonical_u32(prank_b_ext); + cols.c_ext = F::from_canonical_u32(prank_c_ext); + *trace = RowMajorMatrix::new(values, trace.width()); }; disable_debug_builder(); @@ -204,17 +182,13 @@ fn run_rv32_mulh_negative_test( .load(bitwise_chip) .load(range_tuple_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] fn rv32_mulh_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -227,8 +201,8 @@ fn rv32_mulh_wrong_a_mul_negative_test() { #[test] fn rv32_mulh_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [130, 9, 135, 242], [197, 85, 150, 32], [51, 109, 78, 142], @@ -241,8 +215,8 @@ fn rv32_mulh_wrong_a_negative_test() { #[test] fn rv32_mulh_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -255,8 +229,8 @@ fn rv32_mulh_wrong_ext_negative_test() { #[test] fn rv32_mulh_invalid_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULH, + run_negative_mulh_test( + MULH, [3, 2, 2, 2], [0, 0, 0, 128], [2, 0, 0, 0], @@ -269,8 +243,8 @@ fn rv32_mulh_invalid_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 202], [197, 85, 150, 160], [51, 109, 78, 142], @@ -283,8 +257,8 @@ fn rv32_mulhsu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhsu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [174, 40, 246, 201], [197, 85, 150, 160], [51, 109, 78, 142], @@ -297,8 +271,8 @@ fn rv32_mulhsu_wrong_a_negative_test() { #[test] fn rv32_mulhsu_wrong_b_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [1, 0, 0, 0], [0, 0, 0, 128], [2, 0, 0, 0], @@ -311,8 +285,8 @@ fn rv32_mulhsu_wrong_b_ext_negative_test() { #[test] fn rv32_mulhsu_wrong_c_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHSU, + run_negative_mulh_test( + MULHSU, [0, 0, 0, 64], [0, 0, 0, 128], [0, 0, 0, 128], @@ -325,8 +299,8 @@ fn rv32_mulhsu_wrong_c_ext_negative_test() { #[test] fn rv32_mulhu_wrong_a_mul_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 241], [197, 85, 150, 32], [51, 109, 78, 142], @@ -339,8 +313,8 @@ fn rv32_mulhu_wrong_a_mul_negative_test() { #[test] fn rv32_mulhu_wrong_a_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [130, 9, 135, 240], [197, 85, 150, 32], [51, 109, 78, 142], @@ -353,8 +327,8 @@ fn rv32_mulhu_wrong_a_negative_test() { #[test] fn rv32_mulhu_wrong_ext_negative_test() { - run_rv32_mulh_negative_test( - MulHOpcode::MULHU, + run_negative_mulh_test( + MULHU, [255, 255, 255, 255], [0, 0, 0, 128], [2, 0, 0, 0], @@ -380,7 +354,7 @@ fn run_mulh_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [303, 375, 449, 463]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULH, &x, &y); + run_mulh::(MULH, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -400,7 +374,7 @@ fn run_mulhu_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHU, &x, &y); + run_mulh::(MULHU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -420,7 +394,7 @@ fn run_mulhsu_pos_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [107, 93, 18, 0]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 205]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); @@ -440,7 +414,7 @@ fn run_mulhsu_neg_sanity_test() { let c: [u32; RV32_REGISTER_NUM_LIMBS] = [212, 292, 326, 379]; let c_mul: [u32; RV32_REGISTER_NUM_LIMBS] = [39, 100, 126, 231]; let (res, res_mul, carry, x_ext, y_ext) = - run_mulh::(MulHOpcode::MULHSU, &x, &y); + run_mulh::(MULHSU, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], res[i]); assert_eq!(z_mul[i], res_mul[i]); diff --git a/extensions/rv32im/circuit/src/shift/core.rs b/extensions/rv32im/circuit/src/shift/core.rs index cada97685e..b53131482c 100644 --- a/extensions/rv32im/circuit/src/shift/core.rs +++ b/extensions/rv32im/circuit/src/shift/core.rs @@ -3,9 +3,16 @@ use std::{ borrow::{Borrow, BorrowMut}, }; -use openvm_circuit::arch::{ - AdapterAirContext, AdapterRuntimeContext, MinimalInstruction, Result, VmAdapterInterface, - VmCoreAir, VmCoreChip, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + AdapterAirContext, AdapterExecutorE1, AdapterTraceStep, MinimalInstruction, Result, + StepExecutorE1, TraceStep, VmAdapterInterface, VmCoreAir, VmStateMut, + }, + system::memory::{ + online::{GuestMemory, TracingMemory}, + MemoryAuxColsFactory, + }, }; use openvm_circuit_primitives::{ bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip}, @@ -13,7 +20,7 @@ use openvm_circuit_primitives::{ var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus}, }; use openvm_circuit_primitives_derive::AlignedBorrow; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; +use openvm_instructions::{instruction::Instruction, program::DEFAULT_PC_STEP, LocalOpcode}; use openvm_rv32im_transpiler::ShiftOpcode; use openvm_stark_backend::{ interaction::InteractionBuilder, @@ -21,8 +28,6 @@ use openvm_stark_backend::{ p3_field::{Field, FieldAlgebra, PrimeField32}, rap::BaseAirWithPublicValues, }; -use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use serde_big_array::BigArray; use strum::IntoEnumIterator; #[repr(C)] @@ -51,7 +56,7 @@ pub struct ShiftCoreCols { pub bit_shift_carry: [T; NUM_LIMBS], } -#[derive(Copy, Clone, Debug)] +#[derive(Copy, Clone, Debug, derive_new::new)] pub struct ShiftCoreAir { pub bitwise_lookup_bus: BitwiseOperationLookupBus, pub range_bus: VariableRangeCheckerBus, @@ -237,155 +242,205 @@ where } } -#[repr(C)] -#[derive(Clone, Debug, Serialize, Deserialize)] -#[serde(bound = "T: Serialize + DeserializeOwned")] -pub struct ShiftCoreRecord { - #[serde(with = "BigArray")] - pub a: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub b: [T; NUM_LIMBS], - #[serde(with = "BigArray")] - pub c: [T; NUM_LIMBS], - pub b_sign: T, - #[serde(with = "BigArray")] - pub bit_shift_carry: [u32; NUM_LIMBS], - pub bit_shift: usize, - pub limb_shift: usize, - pub opcode: ShiftOpcode, -} - -pub struct ShiftCoreChip { - pub air: ShiftCoreAir, +pub struct ShiftStep { + adapter: A, + pub offset: usize, pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, pub range_checker_chip: SharedVariableRangeCheckerChip, } -impl ShiftCoreChip { +impl ShiftStep { pub fn new( + adapter: A, bitwise_lookup_chip: SharedBitwiseOperationLookupChip, range_checker_chip: SharedVariableRangeCheckerChip, offset: usize, ) -> Self { assert_eq!(NUM_LIMBS % 2, 0, "Number of limbs must be divisible by 2"); Self { - air: ShiftCoreAir { - bitwise_lookup_bus: bitwise_lookup_chip.bus(), - range_bus: range_checker_chip.bus(), - offset, - }, + adapter, + offset, bitwise_lookup_chip, range_checker_chip, } } } -impl, const NUM_LIMBS: usize, const LIMB_BITS: usize> - VmCoreChip for ShiftCoreChip +impl TraceStep + for ShiftStep where - I::Reads: Into<[[F; NUM_LIMBS]; 2]>, - I::Writes: From<[[F; NUM_LIMBS]; 1]>, + F: PrimeField32, + A: 'static + + for<'a> AdapterTraceStep< + F, + CTX, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + TraceContext<'a> = (), + >, { - type Record = ShiftCoreRecord; - type Air = ShiftCoreAir; + fn get_opcode_name(&self, opcode: usize) -> String { + format!("{:?}", ShiftOpcode::from_usize(opcode - self.offset)) + } - #[allow(clippy::type_complexity)] - fn execute_instruction( - &self, + fn execute( + &mut self, + state: VmStateMut, CTX>, instruction: &Instruction, - _from_pc: u32, - reads: I::Reads, - ) -> Result<(AdapterRuntimeContext, Self::Record)> { + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { let Instruction { opcode, .. } = instruction; - let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.air.offset)); - let data: [[F; NUM_LIMBS]; 2] = reads.into(); - let b = data[0].map(|x| x.as_canonical_u32()); - let c = data[1].map(|y| y.as_canonical_u32()); - let (a, limb_shift, bit_shift) = run_shift::(shift_opcode, &b, &c); + let local_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); + + let row_slice = &mut trace[*trace_offset..*trace_offset + width]; + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + A::start(*state.pc, state.memory, adapter_row); + + let [rs1, rs2] = self + .adapter + .read(state.memory, instruction, adapter_row) + .into(); + + let (output, limb_shift, bit_shift) = + run_shift::(local_opcode, &rs1, &rs2); + + let core_row: &mut ShiftCoreCols = core_row.borrow_mut(); + core_row.a = output.map(F::from_canonical_u8); + core_row.b = rs1.map(F::from_canonical_u8); + core_row.c = rs2.map(F::from_canonical_u8); + // To be transformed later in fill_trace_row: + core_row.opcode_sll_flag = F::from_canonical_usize(local_opcode as usize); + core_row.bit_shift_marker[0] = F::from_canonical_usize(bit_shift); + core_row.limb_shift_marker[0] = F::from_canonical_usize(limb_shift); - let bit_shift_carry = array::from_fn(|i| match shift_opcode { + self.adapter + .write(state.memory, instruction, adapter_row, &[output].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + *trace_offset += width; + + Ok(()) + } + + fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory, row_slice: &mut [F]) { + let (adapter_row, core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) }; + + self.adapter.fill_trace_row(mem_helper, (), adapter_row); + + let core_row: &mut ShiftCoreCols = core_row.borrow_mut(); + + let local_opcode = + ShiftOpcode::from_usize(core_row.opcode_sll_flag.as_canonical_u32() as usize); + + let bit_shift = core_row.bit_shift_marker[0].as_canonical_u32() as usize; + let limb_shift = core_row.limb_shift_marker[0].as_canonical_u32() as usize; + let b = core_row.b.map(|x| x.as_canonical_u32()); + let c = core_row.c.map(|x| x.as_canonical_u32()); + + let bit_shift_carry = array::from_fn(|i| match local_opcode { ShiftOpcode::SLL => b[i] >> (LIMB_BITS - bit_shift), _ => b[i] % (1 << bit_shift), }); + for carry_val in bit_shift_carry { + self.range_checker_chip.add_count(carry_val, bit_shift); + } + + let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); + self.range_checker_chip.add_count( + (((c[0] as usize) - bit_shift - limb_shift * LIMB_BITS) >> num_bits_log) as u32, + LIMB_BITS - num_bits_log as usize, + ); + let mut b_sign = 0; - if shift_opcode == ShiftOpcode::SRA { + if local_opcode == ShiftOpcode::SRA { b_sign = b[NUM_LIMBS - 1] >> (LIMB_BITS - 1); self.bitwise_lookup_chip .request_xor(b[NUM_LIMBS - 1], 1 << (LIMB_BITS - 1)); } - for i in 0..(NUM_LIMBS / 2) { + for pair in core_row.a.chunks_exact(2) { self.bitwise_lookup_chip - .request_range(a[i * 2], a[i * 2 + 1]); + .request_range(pair[0].as_canonical_u32(), pair[1].as_canonical_u32()); } - let output = AdapterRuntimeContext::without_pc([a.map(F::from_canonical_u32)]); - let record = ShiftCoreRecord { - opcode: shift_opcode, - a: a.map(F::from_canonical_u32), - b: data[0], - c: data[1], - bit_shift_carry, - bit_shift, - limb_shift, - b_sign: F::from_canonical_u32(b_sign), + core_row.bit_multiplier_left = match local_opcode { + ShiftOpcode::SLL => F::from_canonical_usize(1 << bit_shift), + _ => F::ZERO, }; - - Ok((output, record)) + core_row.bit_multiplier_right = match local_opcode { + ShiftOpcode::SLL => F::ZERO, + _ => F::from_canonical_usize(1 << bit_shift), + }; + core_row.b_sign = F::from_canonical_u32(b_sign); + core_row.bit_shift_marker = array::from_fn(|i| F::from_bool(i == bit_shift)); + core_row.limb_shift_marker = array::from_fn(|i| F::from_bool(i == limb_shift)); + core_row.bit_shift_carry = bit_shift_carry.map(F::from_canonical_u32); + core_row.opcode_sll_flag = F::from_bool(local_opcode == ShiftOpcode::SLL); + core_row.opcode_srl_flag = F::from_bool(local_opcode == ShiftOpcode::SRL); + core_row.opcode_sra_flag = F::from_bool(local_opcode == ShiftOpcode::SRA); } +} - fn get_opcode_name(&self, opcode: usize) -> String { - format!("{:?}", ShiftOpcode::from_usize(opcode - self.air.offset)) - } +impl StepExecutorE1 + for ShiftStep +where + F: PrimeField32, + A: 'static + + for<'a> AdapterExecutorE1< + F, + ReadData: Into<[[u8; NUM_LIMBS]; 2]>, + WriteData: From<[[u8; NUM_LIMBS]; 1]>, + >, +{ + fn execute_e1( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { + let Instruction { opcode, .. } = instruction; - fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) { - for carry_val in record.bit_shift_carry { - self.range_checker_chip - .add_count(carry_val, record.bit_shift); - } + let shift_opcode = ShiftOpcode::from_usize(opcode.local_opcode_idx(self.offset)); - let num_bits_log = (NUM_LIMBS * LIMB_BITS).ilog2(); - self.range_checker_chip.add_count( - (((record.c[0].as_canonical_u32() as usize) - - record.bit_shift - - record.limb_shift * LIMB_BITS) - >> num_bits_log) as u32, - LIMB_BITS - num_bits_log as usize, - ); + let [rs1, rs2] = self.adapter.read(state, instruction).into(); - let row_slice: &mut ShiftCoreCols<_, NUM_LIMBS, LIMB_BITS> = row_slice.borrow_mut(); - row_slice.a = record.a; - row_slice.b = record.b; - row_slice.c = record.c; - row_slice.bit_multiplier_left = match record.opcode { - ShiftOpcode::SLL => F::from_canonical_usize(1 << record.bit_shift), - _ => F::ZERO, - }; - row_slice.bit_multiplier_right = match record.opcode { - ShiftOpcode::SLL => F::ZERO, - _ => F::from_canonical_usize(1 << record.bit_shift), - }; - row_slice.b_sign = record.b_sign; - row_slice.bit_shift_marker = array::from_fn(|i| F::from_bool(i == record.bit_shift)); - row_slice.limb_shift_marker = array::from_fn(|i| F::from_bool(i == record.limb_shift)); - row_slice.bit_shift_carry = record.bit_shift_carry.map(F::from_canonical_u32); - row_slice.opcode_sll_flag = F::from_bool(record.opcode == ShiftOpcode::SLL); - row_slice.opcode_srl_flag = F::from_bool(record.opcode == ShiftOpcode::SRL); - row_slice.opcode_sra_flag = F::from_bool(record.opcode == ShiftOpcode::SRA); + let (rd, _, _) = run_shift::(shift_opcode, &rs1, &rs2); + + self.adapter.write(state, instruction, &[rd].into()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) } - fn air(&self) -> &Self::Air { - &self.air + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + self.execute_e1(state, instruction)?; + state.ctx.trace_heights[chip_index] += 1; + + Ok(()) } } +// Returns (result, limb_shift, bit_shift) +#[inline(always)] pub(super) fn run_shift( opcode: ShiftOpcode, - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { match opcode { ShiftOpcode::SLL => run_shift_left::(x, y), ShiftOpcode::SRL => run_shift_right::(x, y, true), @@ -393,53 +448,60 @@ pub(super) fn run_shift( } } +#[inline(always)] fn run_shift_left( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], -) -> ([u32; NUM_LIMBS], usize, usize) { - let mut result = [0u32; NUM_LIMBS]; + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], +) -> ([u8; NUM_LIMBS], usize, usize) { + let mut result = [0u8; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in limb_shift..NUM_LIMBS { result[i] = if i > limb_shift { - ((x[i - limb_shift] << bit_shift) + (x[i - limb_shift - 1] >> (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + (((x[i - limb_shift] as u16) << bit_shift) + | ((x[i - limb_shift - 1] as u16) >> (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - (x[i - limb_shift] << bit_shift) % (1 << LIMB_BITS) - }; + ((x[i - limb_shift] as u16) << bit_shift) % (1u16 << LIMB_BITS) + } as u8; } (result, limb_shift, bit_shift) } +#[inline(always)] fn run_shift_right( - x: &[u32; NUM_LIMBS], - y: &[u32; NUM_LIMBS], + x: &[u8; NUM_LIMBS], + y: &[u8; NUM_LIMBS], logical: bool, -) -> ([u32; NUM_LIMBS], usize, usize) { +) -> ([u8; NUM_LIMBS], usize, usize) { let fill = if logical { 0 } else { - ((1 << LIMB_BITS) - 1) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) + (((1u16 << LIMB_BITS) - 1) as u8) * (x[NUM_LIMBS - 1] >> (LIMB_BITS - 1)) }; let mut result = [fill; NUM_LIMBS]; let (limb_shift, bit_shift) = get_shift::(y); for i in 0..(NUM_LIMBS - limb_shift) { - result[i] = if i + limb_shift + 1 < NUM_LIMBS { - ((x[i + limb_shift] >> bit_shift) + (x[i + limb_shift + 1] << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) + let res = if i + limb_shift + 1 < NUM_LIMBS { + (((x[i + limb_shift] >> bit_shift) as u16) + | ((x[i + limb_shift + 1] as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) } else { - ((x[i + limb_shift] >> bit_shift) + (fill << (LIMB_BITS - bit_shift))) - % (1 << LIMB_BITS) - } + (((x[i + limb_shift] >> bit_shift) as u16) | ((fill as u16) << (LIMB_BITS - bit_shift))) + % (1u16 << LIMB_BITS) + }; + result[i] = res as u8; } (result, limb_shift, bit_shift) } -fn get_shift(y: &[u32]) -> (usize, usize) { - // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so so the shift is defined +#[inline(always)] +fn get_shift(y: &[u8]) -> (usize, usize) { + debug_assert!(NUM_LIMBS * LIMB_BITS <= (1 << LIMB_BITS)); + // We assume `NUM_LIMBS * LIMB_BITS <= 2^LIMB_BITS` so the shift is defined // entirely in y[0]. let shift = (y[0] as usize) % (NUM_LIMBS * LIMB_BITS); (shift / LIMB_BITS, shift % LIMB_BITS) diff --git a/extensions/rv32im/circuit/src/shift/mod.rs b/extensions/rv32im/circuit/src/shift/mod.rs index 58d5ad022b..e6fa09d53c 100644 --- a/extensions/rv32im/circuit/src/shift/mod.rs +++ b/extensions/rv32im/circuit/src/shift/mod.rs @@ -1,6 +1,8 @@ -use openvm_circuit::arch::VmChipWrapper; +use openvm_circuit::arch::{NewVmChipWrapper, VmAirWrapper}; -use super::adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; +use super::adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, +}; mod core; pub use core::*; @@ -8,8 +10,8 @@ pub use core::*; #[cfg(test)] mod tests; -pub type Rv32ShiftChip = VmChipWrapper< - F, - Rv32BaseAluAdapterChip, - ShiftCoreChip, ->; +pub type Rv32ShiftAir = + VmAirWrapper>; +pub type Rv32ShiftStep = + ShiftStep, RV32_REGISTER_NUM_LIMBS, RV32_CELL_BITS>; +pub type Rv32ShiftChip = NewVmChipWrapper; diff --git a/extensions/rv32im/circuit/src/shift/tests.rs b/extensions/rv32im/circuit/src/shift/tests.rs index 7a3ef6e72c..fd111e2342 100644 --- a/extensions/rv32im/circuit/src/shift/tests.rs +++ b/extensions/rv32im/circuit/src/shift/tests.rs @@ -1,17 +1,14 @@ use std::{array, borrow::BorrowMut}; -use openvm_circuit::{ - arch::{ - testing::{TestAdapterChip, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - ExecutionBridge, VmAdapterChip, VmChipWrapper, - }, - utils::generate_long_number, +use openvm_circuit::arch::{ + testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + InstructionExecutor, VmAirWrapper, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; -use openvm_instructions::{instruction::Instruction, LocalOpcode}; -use openvm_rv32im_transpiler::ShiftOpcode; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_transpiler::ShiftOpcode::{self, *}; use openvm_stark_backend::{ p3_air::BaseAir, p3_field::FieldAlgebra, @@ -20,108 +17,129 @@ use openvm_stark_backend::{ Matrix, }, utils::disable_debug_builder, - verifier::VerificationError, - ChipUsageGetter, }; use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; +use rand::{rngs::StdRng, Rng}; +use test_case::test_case; -use super::{core::run_shift, Rv32ShiftChip, ShiftCoreChip}; +use super::{core::run_shift, Rv32ShiftChip, ShiftCoreAir, ShiftCoreCols, ShiftStep}; use crate::{ - adapters::{Rv32BaseAluAdapterChip, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}, - shift::ShiftCoreCols, - test_utils::{generate_rv32_is_type_immediate, rv32_rand_write_register_or_imm}, + adapters::{ + Rv32BaseAluAdapterAir, Rv32BaseAluAdapterStep, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, + }, + test_utils::{ + generate_rv32_is_type_immediate, get_verification_error, rv32_rand_write_register_or_imm, + }, }; type F = BabyBear; +const MAX_INS_CAPACITY: usize = 128; -////////////////////////////////////////////////////////////////////////////////////// -// POSITIVE TESTS -// -// Randomly generate computations and execute, ensuring that the generated trace -// passes all constraints. -////////////////////////////////////////////////////////////////////////////////////// - -fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { - let mut rng = create_seeded_rng(); +fn create_test_chip( + tester: &VmChipTestBuilder, +) -> ( + Rv32ShiftChip, + SharedBitwiseOperationLookupChip, +) { let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut tester = VmChipTestBuilder::default(); - let mut chip = Rv32ShiftChip::::new( - Rv32BaseAluAdapterChip::new( - tester.execution_bus(), - tester.program_bus(), - tester.memory_bridge(), - bitwise_chip.clone(), + let chip = Rv32ShiftChip::::new( + VmAirWrapper::new( + Rv32BaseAluAdapterAir::new( + tester.execution_bridge(), + tester.memory_bridge(), + bitwise_bus, + ), + ShiftCoreAir::new( + bitwise_bus, + tester.range_checker().bus(), + ShiftOpcode::CLASS_OFFSET, + ), ), - ShiftCoreChip::new( + ShiftStep::new( + Rv32BaseAluAdapterStep::new(bitwise_chip.clone()), bitwise_chip.clone(), - tester.memory_controller().borrow().range_checker.clone(), + tester.range_checker().clone(), ShiftOpcode::CLASS_OFFSET, ), - tester.offline_memory_mutex_arc(), + MAX_INS_CAPACITY, + tester.memory_helper(), ); - for _ in 0..num_ops { - let b = generate_long_number::(&mut rng); - let (c_imm, c) = if rng.gen_bool(0.5) { - ( - None, - generate_long_number::(&mut rng), - ) + (chip, bitwise_chip) +} + +#[allow(clippy::too_many_arguments)] +fn set_and_execute>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: ShiftOpcode, + b: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, + is_imm: Option, + c: Option<[u8; RV32_REGISTER_NUM_LIMBS]>, +) { + let b = b.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))); + let (c_imm, c) = if is_imm.unwrap_or(rng.gen_bool(0.5)) { + let (imm, c) = if let Some(c) = c { + ((u32::from_le_bytes(c) & 0xFFFFFF) as usize, c) } else { - let (imm, c) = generate_rv32_is_type_immediate(&mut rng); - (Some(imm), c) + generate_rv32_is_type_immediate(rng) }; - - let (instruction, rd) = rv32_rand_write_register_or_imm( - &mut tester, - b, - c, - c_imm, - opcode.global_opcode().as_usize(), - &mut rng, - ); - tester.execute(&mut chip, &instruction); - - let (a, _, _) = run_shift::(opcode, &b, &c); - assert_eq!( - a.map(F::from_canonical_u32), - tester.read::(1, rd) + (Some(imm), c) + } else { + ( + None, + c.unwrap_or(array::from_fn(|_| rng.gen_range(0..=u8::MAX))), ) - } + }; + let (instruction, rd) = rv32_rand_write_register_or_imm( + tester, + b, + c, + c_imm, + opcode.global_opcode().as_usize(), + rng, + ); + tester.execute(chip, &instruction); - let tester = tester.build().load(chip).load(bitwise_chip).finalize(); - tester.simple_test().expect("Verification failed"); + let (a, _, _) = run_shift::(opcode, &b, &c); + assert_eq!( + a.map(F::from_canonical_u8), + tester.read::(1, rd) + ) } -#[test] -fn rv32_shift_sll_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SLL, 100); -} +////////////////////////////////////////////////////////////////////////////////////// +// POSITIVE TESTS +// +// Randomly generate computations and execute, ensuring that the generated trace +// passes all constraints. +////////////////////////////////////////////////////////////////////////////////////// +#[test_case(SLL, 100)] +#[test_case(SRL, 100)] +#[test_case(SRA, 100)] +fn run_rv32_shift_rand_test(opcode: ShiftOpcode, num_ops: usize) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chip(&tester); -#[test] -fn rv32_shift_srl_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRL, 100); -} + for _ in 0..num_ops { + set_and_execute(&mut tester, &mut chip, &mut rng, opcode, None, None, None); + } -#[test] -fn rv32_shift_sra_rand_test() { - run_rv32_shift_rand_test(ShiftOpcode::SRA, 100); + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); } ////////////////////////////////////////////////////////////////////////////////////// // NEGATIVE TESTS // // Given a fake trace of a single operation, setup a chip and run the test. We replace -// the write part of the trace and check that the core chip throws the expected error. -// A dummy adapter is used so memory interactions don't indirectly cause false passes. +// part of the trace and check that the chip throws the expected error. ////////////////////////////////////////////////////////////////////////////////////// -type Rv32ShiftTestChip = - VmChipWrapper, ShiftCoreChip>; - #[derive(Clone, Copy, Default, PartialEq)] struct ShiftPrankValues { pub bit_shift: Option, @@ -134,63 +152,35 @@ struct ShiftPrankValues { } #[allow(clippy::too_many_arguments)] -fn run_rv32_shift_negative_test( +fn run_negative_shift_test( opcode: ShiftOpcode, - a: [u32; RV32_REGISTER_NUM_LIMBS], - b: [u32; RV32_REGISTER_NUM_LIMBS], - c: [u32; RV32_REGISTER_NUM_LIMBS], + prank_a: [u32; RV32_REGISTER_NUM_LIMBS], + b: [u8; RV32_REGISTER_NUM_LIMBS], + c: [u8; RV32_REGISTER_NUM_LIMBS], prank_vals: ShiftPrankValues, interaction_error: bool, ) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut rng = create_seeded_rng(); let mut tester: VmChipTestBuilder = VmChipTestBuilder::default(); - let range_checker_chip = tester.memory_controller().borrow().range_checker.clone(); - let mut chip = Rv32ShiftTestChip::::new( - TestAdapterChip::new( - vec![[b.map(F::from_canonical_u32), c.map(F::from_canonical_u32)].concat()], - vec![None], - ExecutionBridge::new(tester.execution_bus(), tester.program_bus()), - ), - ShiftCoreChip::new( - bitwise_chip.clone(), - range_checker_chip.clone(), - ShiftOpcode::CLASS_OFFSET, - ), - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chip(&tester); - tester.execute( + set_and_execute( + &mut tester, &mut chip, - &Instruction::from_usize(opcode.global_opcode(), [0, 0, 0, 1, 1]), + &mut rng, + opcode, + Some(b), + Some(false), + Some(c), ); - let bit_shift = prank_vals - .bit_shift - .unwrap_or(c[0] % (RV32_CELL_BITS as u32)); - let bit_shift_carry = prank_vals - .bit_shift_carry - .unwrap_or(array::from_fn(|i| match opcode { - ShiftOpcode::SLL => b[i] >> ((RV32_CELL_BITS as u32) - bit_shift), - _ => b[i] % (1 << bit_shift), - })); - - range_checker_chip.clear(); - range_checker_chip.add_count(bit_shift, RV32_CELL_BITS.ilog2() as usize); - for (a_val, carry_val) in a.iter().zip(bit_shift_carry.iter()) { - range_checker_chip.add_count(*a_val, RV32_CELL_BITS); - range_checker_chip.add_count(*carry_val, bit_shift as usize); - } - - let trace_width = chip.trace_width(); - let adapter_width = BaseAir::::width(chip.adapter.air()); - + let adapter_width = BaseAir::::width(&chip.air.adapter); let modify_trace = |trace: &mut DenseMatrix| { let mut values = trace.row_slice(0).to_vec(); let cols: &mut ShiftCoreCols = values.split_at_mut(adapter_width).1.borrow_mut(); - cols.a = a.map(F::from_canonical_u32); + cols.a = prank_a.map(F::from_canonical_u32); if let Some(bit_multiplier_left) = prank_vals.bit_multiplier_left { cols.bit_multiplier_left = F::from_canonical_u32(bit_multiplier_left); } @@ -210,21 +200,16 @@ fn run_rv32_shift_negative_test( cols.bit_shift_carry = bit_shift_carry.map(F::from_canonical_u32); } - *trace = RowMajorMatrix::new(values, trace_width); + *trace = RowMajorMatrix::new(values, trace.width()); }; - drop(range_checker_chip); disable_debug_builder(); let tester = tester .build() .load_and_prank_trace(chip, modify_trace) .load(bitwise_chip) .finalize(); - tester.simple_test_with_expected_error(if interaction_error { - VerificationError::ChallengePhaseError - } else { - VerificationError::OodEvaluationMismatch - }); + tester.simple_test_with_expected_error(get_verification_error(interaction_error)); } #[test] @@ -233,9 +218,9 @@ fn rv32_shift_wrong_negative_test() { let b = [1, 0, 0, 0]; let c = [1, 0, 0, 0]; let prank_vals = Default::default(); - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -249,7 +234,7 @@ fn rv32_sll_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -261,7 +246,7 @@ fn rv32_sll_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 0, 1, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -273,7 +258,7 @@ fn rv32_sll_wrong_bit_carry_negative_test() { bit_shift_carry: Some([0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, true); + run_negative_shift_test(SLL, a, b, c, prank_vals, true); } #[test] @@ -286,7 +271,7 @@ fn rv32_sll_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(1), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SLL, a, b, c, prank_vals, false); + run_negative_shift_test(SLL, a, b, c, prank_vals, false); } #[test] @@ -300,7 +285,7 @@ fn rv32_srl_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -312,7 +297,7 @@ fn rv32_srl_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); } #[test] @@ -325,8 +310,8 @@ fn rv32_srx_wrong_bit_mult_side_negative_test() { bit_multiplier_right: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRL, a, b, c, prank_vals, false); - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRL, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -340,7 +325,7 @@ fn rv32_sra_wrong_bit_shift_negative_test() { bit_shift_marker: Some([0, 0, 1, 0, 0, 0, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -352,7 +337,7 @@ fn rv32_sra_wrong_limb_shift_negative_test() { limb_shift_marker: Some([0, 1, 0, 0]), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, false); + run_negative_shift_test(SRA, a, b, c, prank_vals, false); } #[test] @@ -364,7 +349,7 @@ fn rv32_sra_wrong_sign_negative_test() { b_sign: Some(0), ..Default::default() }; - run_rv32_shift_negative_test(ShiftOpcode::SRA, a, b, c, prank_vals, true); + run_negative_shift_test(SRA, a, b, c, prank_vals, true); } /////////////////////////////////////////////////////////////////////////////////////// @@ -375,11 +360,11 @@ fn rv32_sra_wrong_sign_negative_test() { #[test] fn run_sll_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [45, 7, 61, 186]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [91, 0, 100, 0]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [0, 0, 0, 104]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SLL, &x, &y); + run_shift::(SLL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -390,11 +375,11 @@ fn run_sll_sanity_test() { #[test] fn run_srl_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [49, 190, 190, 190]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 100, 0, 0]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRL, &x, &y); + run_shift::(SRL, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } @@ -405,11 +390,11 @@ fn run_srl_sanity_test() { #[test] fn run_sra_sanity_test() { - let x: [u32; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; - let y: [u32; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; - let z: [u32; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; + let x: [u8; RV32_REGISTER_NUM_LIMBS] = [31, 190, 221, 200]; + let y: [u8; RV32_REGISTER_NUM_LIMBS] = [113, 20, 50, 80]; + let z: [u8; RV32_REGISTER_NUM_LIMBS] = [110, 228, 255, 255]; let (result, limb_shift, bit_shift) = - run_shift::(ShiftOpcode::SRA, &x, &y); + run_shift::(SRA, &x, &y); for i in 0..RV32_REGISTER_NUM_LIMBS { assert_eq!(z[i], result[i]) } diff --git a/extensions/rv32im/circuit/src/test_utils.rs b/extensions/rv32im/circuit/src/test_utils.rs index 8a105ff990..f018b0d845 100644 --- a/extensions/rv32im/circuit/src/test_utils.rs +++ b/extensions/rv32im/circuit/src/test_utils.rs @@ -1,6 +1,6 @@ use openvm_circuit::arch::testing::{memory::gen_pointer, VmChipTestBuilder}; use openvm_instructions::{instruction::Instruction, VmOpcode}; -use openvm_stark_backend::p3_field::FieldAlgebra; +use openvm_stark_backend::{p3_field::FieldAlgebra, verifier::VerificationError}; use openvm_stark_sdk::p3_baby_bear::BabyBear; use rand::{rngs::StdRng, Rng}; @@ -10,8 +10,8 @@ use super::adapters::{RV32_REGISTER_NUM_LIMBS, RV_IS_TYPE_IMM_BITS}; #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] pub fn rv32_rand_write_register_or_imm( tester: &mut VmChipTestBuilder, - rs1_writes: [u32; NUM_LIMBS], - rs2_writes: [u32; NUM_LIMBS], + rs1_writes: [u8; NUM_LIMBS], + rs2_writes: [u8; NUM_LIMBS], imm: Option, opcode_with_offset: usize, rng: &mut StdRng, @@ -22,9 +22,9 @@ pub fn rv32_rand_write_register_or_imm( let rs2 = imm.unwrap_or_else(|| gen_pointer(rng, NUM_LIMBS)); let rd = gen_pointer(rng, NUM_LIMBS); - tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs1, rs1_writes.map(BabyBear::from_canonical_u8)); if !rs2_is_imm { - tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u32)); + tester.write::(1, rs2, rs2_writes.map(BabyBear::from_canonical_u8)); } ( @@ -37,9 +37,7 @@ pub fn rv32_rand_write_register_or_imm( } #[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] -pub fn generate_rv32_is_type_immediate( - rng: &mut StdRng, -) -> (usize, [u32; RV32_REGISTER_NUM_LIMBS]) { +pub fn generate_rv32_is_type_immediate(rng: &mut StdRng) -> (usize, [u8; RV32_REGISTER_NUM_LIMBS]) { let mut imm: u32 = rng.gen_range(0..(1 << RV_IS_TYPE_IMM_BITS)); if (imm & 0x800) != 0 { imm |= !0xFFF @@ -51,7 +49,17 @@ pub fn generate_rv32_is_type_immediate( (imm >> 8) as u8, (imm >> 16) as u8, (imm >> 16) as u8, - ] - .map(|x| x as u32), + ], ) } + +/// Returns the corresponding verification error based on whether +/// an interaction error or a constraint error is expected +#[cfg_attr(all(feature = "test-utils", not(test)), allow(dead_code))] +pub fn get_verification_error(is_interaction_error: bool) -> VerificationError { + if is_interaction_error { + VerificationError::ChallengePhaseError + } else { + VerificationError::OodEvaluationMismatch + } +} diff --git a/extensions/rv32im/tests/src/lib.rs b/extensions/rv32im/tests/src/lib.rs index e59088eac8..eae71da4fd 100644 --- a/extensions/rv32im/tests/src/lib.rs +++ b/extensions/rv32im/tests/src/lib.rs @@ -131,7 +131,7 @@ mod tests { let config = Rv32IConfig::default(); let executor = VmExecutor::::new(config.clone()); let final_memory = executor.execute(exe, vec![])?.unwrap(); - let hasher = vm_poseidon2_hasher(); + let hasher = vm_poseidon2_hasher::(); let pv_proof = UserPublicValuesProof::compute( config.system.memory_config.memory_dimensions(), 64, diff --git a/extensions/sha256/circuit/src/extension.rs b/extensions/sha256/circuit/src/extension.rs index 76a6c1ec0c..deba8b2d10 100644 --- a/extensions/sha256/circuit/src/extension.rs +++ b/extensions/sha256/circuit/src/extension.rs @@ -3,7 +3,7 @@ use openvm_circuit::{ arch::{SystemConfig, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError}, system::phantom::PhantomChip, }; -use openvm_circuit_derive::{AnyEnum, InstructionExecutor, VmConfig}; +use openvm_circuit_derive::{AnyEnum, InsExecutorE1, InstructionExecutor, VmConfig}; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, }; @@ -20,6 +20,9 @@ use strum::IntoEnumIterator; use crate::*; +// TODO: this should be decided after e2 execution +const MAX_INS_CAPACITY: usize = 1 << 22; + #[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] pub struct Sha256Rv32Config { #[system] @@ -49,7 +52,7 @@ impl Default for Sha256Rv32Config { #[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] pub struct Sha256; -#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)] +#[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum, InsExecutorE1)] pub enum Sha256Executor { Sha256(Sha256VmChip), } @@ -69,6 +72,8 @@ impl VmExtension for Sha256 { builder: &mut VmInventoryBuilder, ) -> Result, VmInventoryError> { let mut inventory = VmInventory::new(); + let pointer_max_bits = builder.system_config().memory_config.pointer_max_bits; + let bitwise_lu_chip = if let Some(&chip) = builder .find_chip::>() .first() @@ -82,12 +87,19 @@ impl VmExtension for Sha256 { }; let sha256_chip = Sha256VmChip::new( - builder.system_port(), - builder.system_config().memory_config.pointer_max_bits, - bitwise_lu_chip, - builder.new_bus_idx(), - Rv32Sha256Opcode::CLASS_OFFSET, - builder.system_base().offline_memory(), + Sha256VmAir::new( + builder.system_port(), + bitwise_lu_chip.bus(), + pointer_max_bits, + builder.new_bus_idx(), + ), + Sha256VmStep::new( + bitwise_lu_chip.clone(), + Rv32Sha256Opcode::CLASS_OFFSET, + pointer_max_bits, + ), + MAX_INS_CAPACITY, + builder.system_base().memory_controller.helper(), ); inventory.add_executor( sha256_chip, diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs index 286df07697..d8e4c02348 100644 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ b/extensions/sha256/circuit/src/sha256_chip/air.rs @@ -1,7 +1,7 @@ use std::{array, borrow::Borrow, cmp::min}; use openvm_circuit::{ - arch::ExecutionBridge, + arch::{ExecutionBridge, SystemPort}, system::memory::{offline_checker::MemoryBridge, MemoryAddress}, }; use openvm_circuit_primitives::{ @@ -17,7 +17,7 @@ use openvm_sha256_air::{ }; use openvm_sha256_transpiler::Rv32Sha256Opcode; use openvm_stark_backend::{ - interaction::InteractionBuilder, + interaction::{BusIndex, InteractionBuilder}, p3_air::{Air, AirBuilder, BaseAir}, p3_field::{Field, FieldAlgebra}, p3_matrix::Matrix, @@ -31,7 +31,7 @@ use super::{ /// Sha256VmAir does all constraints related to message padding and /// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug, derive_new::new)] +#[derive(Clone, Debug)] pub struct Sha256VmAir { pub execution_bridge: ExecutionBridge, pub memory_bridge: MemoryBridge, @@ -44,6 +44,28 @@ pub struct Sha256VmAir { pub(super) padding_encoder: Encoder, } +impl Sha256VmAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + ptr_max_bits, + sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), + } + } +} + impl BaseAirWithPublicValues for Sha256VmAir {} impl PartitionedBaseAir for Sha256VmAir {} impl BaseAir for Sha256VmAir { diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha256/circuit/src/sha256_chip/mod.rs index 4c40eca5d8..b87069b2e2 100644 --- a/extensions/sha256/circuit/src/sha256_chip/mod.rs +++ b/extensions/sha256/circuit/src/sha256_chip/mod.rs @@ -1,13 +1,12 @@ //! Sha256 hasher. Handles full sha256 hashing with padding. //! variable length inputs read from VM memory. -use std::{ - array, - cmp::{max, min}, - sync::{Arc, Mutex}, -}; -use openvm_circuit::arch::{ - ExecutionBridge, ExecutionError, ExecutionState, InstructionExecutor, SystemPort, +use openvm_circuit::{ + arch::{ + execution_mode::{metered::MeteredCtx, E1E2ExecutionCtx}, + NewVmChipWrapper, Result, StepExecutorE1, VmStateMut, + }, + system::memory::online::GuestMemory, }; use openvm_circuit_primitives::{ bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, @@ -18,11 +17,12 @@ use openvm_instructions::{ riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS}, LocalOpcode, }; -use openvm_rv32im_circuit::adapters::read_rv32_register; -use openvm_sha256_air::{Sha256Air, SHA256_BLOCK_BITS}; +use openvm_rv32im_circuit::adapters::{ + memory_read_from_state, memory_write_from_state, new_read_rv32_register_from_state, +}; +use openvm_sha256_air::{Sha256StepHelper, SHA256_BLOCK_BITS, SHA256_ROWS_PER_BLOCK}; use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{interaction::BusIndex, p3_field::PrimeField32}; -use serde::{Deserialize, Serialize}; +use openvm_stark_backend::p3_field::PrimeField32; use sha2::{Digest, Sha256}; mod air; @@ -31,7 +31,6 @@ mod trace; pub use air::*; pub use columns::*; -use openvm_circuit::system::memory::{MemoryController, OfflineMemory, RecordId}; #[cfg(test)] mod tests; @@ -47,64 +46,42 @@ const SHA256_WRITE_SIZE: usize = 32; pub const SHA256_BLOCK_CELLS: usize = SHA256_BLOCK_BITS / RV32_CELL_BITS; /// Number of rows we will do a read on for each SHA256 block pub const SHA256_NUM_READ_ROWS: usize = SHA256_BLOCK_CELLS / SHA256_READ_SIZE; -pub struct Sha256VmChip { - pub air: Sha256VmAir, - /// IO and memory data necessary for each opcode call - pub records: Vec>, - pub offline_memory: Arc>>, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - - offset: usize, -} -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub struct Sha256Record { - pub from_state: ExecutionState, - pub dst_read: RecordId, - pub src_read: RecordId, - pub len_read: RecordId, - pub input_records: Vec<[RecordId; SHA256_NUM_READ_ROWS]>, - pub input_message: Vec<[[u8; SHA256_READ_SIZE]; SHA256_NUM_READ_ROWS]>, - pub digest_write: RecordId, +pub type Sha256VmChip = NewVmChipWrapper; + +pub struct Sha256VmStep { + pub inner: Sha256StepHelper, + pub padding_encoder: Encoder, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub offset: usize, + pub pointer_max_bits: usize, } -impl Sha256VmChip { +impl Sha256VmStep { pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - address_bits: usize, - bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, - self_bus_idx: BusIndex, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, offset: usize, - offline_memory: Arc>>, + pointer_max_bits: usize, ) -> Self { Self { - air: Sha256VmAir::new( - ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_chip.bus(), - address_bits, - Sha256Air::new(bitwise_lookup_chip.bus(), self_bus_idx), - Encoder::new(PaddingFlags::COUNT, 2, false), - ), + inner: Sha256StepHelper::new(), + padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), bitwise_lookup_chip, - records: Vec::new(), offset, - offline_memory, + pointer_max_bits, } } } -impl InstructionExecutor for Sha256VmChip { - fn execute( +impl StepExecutorE1 for Sha256VmStep { + fn execute_e1( &mut self, - memory: &mut MemoryController, + state: &mut VmStateMut, instruction: &Instruction, - from_state: ExecutionState, - ) -> Result, ExecutionError> { + ) -> Result<()> + where + Ctx: E1E2ExecutionCtx, + { let &Instruction { opcode, a, @@ -114,86 +91,86 @@ impl InstructionExecutor for Sha256VmChip { e, .. } = instruction; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); let local_opcode = opcode.local_opcode_idx(self.offset); debug_assert_eq!(local_opcode, Rv32Sha256Opcode::SHA256.local_usize()); - debug_assert_eq!(d, F::from_canonical_u32(RV32_REGISTER_AS)); - debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS)); - - debug_assert_eq!(from_state.timestamp, memory.timestamp()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + let dst = new_read_rv32_register_from_state(state, d, a.as_canonical_u32()); + let src = new_read_rv32_register_from_state(state, d, b.as_canonical_u32()); + let len = new_read_rv32_register_from_state(state, d, c.as_canonical_u32()); - let (dst_read, dst) = read_rv32_register(memory, d, a); - let (src_read, src) = read_rv32_register(memory, d, b); - let (len_read, len) = read_rv32_register(memory, d, c); + debug_assert!(src + len <= (1 << self.pointer_max_bits)); + let mut hasher = Sha256::new(); - #[cfg(debug_assertions)] - { - assert!(dst < (1 << self.air.ptr_max_bits)); - assert!(src < (1 << self.air.ptr_max_bits)); - assert!(len < (1 << self.air.ptr_max_bits)); + // TODO(ayush): read in a single call + let mut message = Vec::with_capacity(len as usize); + for offset in (0..len as usize).step_by(SHA256_READ_SIZE) { + let read = memory_read_from_state::<_, SHA256_READ_SIZE>(state, e, src + offset as u32); + let copy_len = std::cmp::min(SHA256_READ_SIZE, (len as usize) - offset); + message.extend_from_slice(&read[..copy_len]); } + hasher.update(&message); + + let output = hasher.finalize(); + memory_write_from_state(state, e, dst, output.as_ref()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } + + fn execute_metered( + &mut self, + state: &mut VmStateMut, + instruction: &Instruction, + chip_index: usize, + ) -> Result<()> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + let local_opcode = opcode.local_opcode_idx(self.offset); + + debug_assert_eq!(local_opcode, Rv32Sha256Opcode::SHA256.local_usize()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); + + let dst = new_read_rv32_register_from_state(state, d, a.as_canonical_u32()); + let src = new_read_rv32_register_from_state(state, d, b.as_canonical_u32()); + let len = new_read_rv32_register_from_state(state, d, c.as_canonical_u32()); + + debug_assert!(src + len <= (1 << self.pointer_max_bits)); - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] let num_blocks = ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - src as usize + num_blocks * SHA256_BLOCK_CELLS <= (1 << self.air.ptr_max_bits) - ); - let mut hasher = Sha256::new(); - let mut input_records = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut input_message = Vec::with_capacity(num_blocks * SHA256_NUM_READ_ROWS); - let mut read_ptr = src; - for _ in 0..num_blocks { - let block_reads_records = array::from_fn(|i| { - memory.read( - e, - F::from_canonical_u32(read_ptr + (i * SHA256_READ_SIZE) as u32), - ) - }); - let block_reads_bytes = array::from_fn(|i| { - // we add to the hasher only the bytes that are part of the message - let num_reads = min( - SHA256_READ_SIZE, - (max(read_ptr, src + len) - read_ptr) as usize, - ); - let row_input = block_reads_records[i] - .1 - .map(|x| x.as_canonical_u32().try_into().unwrap()); - hasher.update(&row_input[..num_reads]); - read_ptr += SHA256_READ_SIZE as u32; - row_input - }); - input_records.push(block_reads_records.map(|x| x.0)); - input_message.push(block_reads_bytes); + let mut message = Vec::with_capacity(len as usize); + for offset in (0..len as usize).step_by(SHA256_READ_SIZE) { + let read = memory_read_from_state::<_, SHA256_READ_SIZE>(state, e, src + offset as u32); + let copy_len = std::cmp::min(SHA256_READ_SIZE, (len as usize) - offset); + message.extend_from_slice(&read[..copy_len]); } - let mut digest = [0u8; SHA256_WRITE_SIZE]; - digest.copy_from_slice(hasher.finalize().as_ref()); - let (digest_write, _) = memory.write( - e, - F::from_canonical_u32(dst), - digest.map(|b| F::from_canonical_u8(b)), - ); - - self.records.push(Sha256Record { - from_state: from_state.map(F::from_canonical_u32), - dst_read, - src_read, - len_read, - input_records, - input_message, - digest_write, - }); - - Ok(ExecutionState { - pc: from_state.pc + DEFAULT_PC_STEP, - timestamp: memory.timestamp(), - }) - } + let mut hasher = Sha256::new(); + hasher.update(&message); + + let output = hasher.finalize(); + memory_write_from_state(state, e, dst, output.as_ref()); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + state.ctx.trace_heights[chip_index] += (num_blocks * SHA256_ROWS_PER_BLOCK) as u32; - fn get_opcode_name(&self, _: usize) -> String { - "SHA256".to_string() + Ok(()) } } diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs index 55bc076e2c..f4d78fe308 100644 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ b/extensions/sha256/circuit/src/sha256_chip/tests.rs @@ -1,6 +1,7 @@ -use openvm_circuit::arch::{ - testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, - SystemPort, +use std::array; + +use openvm_circuit::arch::testing::{ + memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS, }; use openvm_circuit_primitives::bitwise_op_lookup::{ BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, @@ -12,11 +13,39 @@ use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; use rand::{rngs::StdRng, Rng}; -use super::Sha256VmChip; +use super::{Sha256VmAir, Sha256VmChip, Sha256VmStep}; use crate::{sha256_solve, Sha256VmDigestCols, Sha256VmRoundCols}; type F = BabyBear; -const BUS_IDX: BusIndex = 28; +const SELF_BUS_IDX: BusIndex = 28; +const MAX_INS_CAPACITY: usize = 4096; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> ( + Sha256VmChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let chip = Sha256VmChip::new( + Sha256VmAir::new( + tester.system_port(), + bitwise_bus, + tester.address_bits(), + SELF_BUS_IDX, + ), + Sha256VmStep::new( + bitwise_chip.clone(), + Rv32Sha256Opcode::CLASS_OFFSET, + tester.address_bits(), + ), + MAX_INS_CAPACITY, + tester.memory_helper(), + ); + (chip, bitwise_chip) +} + fn set_and_execute( tester: &mut VmChipTestBuilder, chip: &mut Sha256VmChip, @@ -25,7 +54,7 @@ fn set_and_execute( message: Option<&[u8]>, len: Option, ) { - let len = len.unwrap_or(rng.gen_range(1..100000)); + let len = len.unwrap_or(rng.gen_range(1..3000)); let tmp = get_random_message(rng, len); let message: &[u8] = message.unwrap_or(&tmp); let len = message.len(); @@ -34,12 +63,7 @@ fn set_and_execute( let rs1 = gen_pointer(rng, 4); let rs2 = gen_pointer(rng, 4); - let max_mem_ptr: u32 = 1 - << tester - .memory_controller() - .borrow() - .mem_config() - .pointer_max_bits; + let max_mem_ptr: u32 = 1 << tester.address_bits(); let dst_ptr = rng.gen_range(0..max_mem_ptr); let dst_ptr = dst_ptr ^ (dst_ptr & 3); tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); @@ -48,9 +72,14 @@ fn set_and_execute( tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - for (i, &byte) in message.iter().enumerate() { - tester.write(2, src_ptr as usize + i, [F::from_canonical_u8(byte)]); - } + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); tester.execute( chip, @@ -75,23 +104,10 @@ fn rand_sha256_test() { setup_tracing(); let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut chip, bitwise_chip) = create_test_chips(&mut tester); - let num_tests: usize = 3; - for _ in 0..num_tests { + let num_ops: usize = 10; + for _ in 0..num_ops { set_and_execute(&mut tester, &mut chip, &mut rng, SHA256, None, None); } @@ -108,20 +124,7 @@ fn rand_sha256_test() { fn execute_roundtrip_sanity_test() { let mut rng = create_seeded_rng(); let mut tester = VmChipTestBuilder::default(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); - let mut chip = Sha256VmChip::new( - SystemPort { - execution_bus: tester.execution_bus(), - program_bus: tester.program_bus(), - memory_bridge: tester.memory_bridge(), - }, - tester.address_bits(), - bitwise_chip.clone(), - BUS_IDX, - Rv32Sha256Opcode::CLASS_OFFSET, - tester.offline_memory_mutex_arc(), - ); + let (mut chip, _) = create_test_chips(&mut tester); println!( "Sha256VmDigestCols::width(): {}", diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs index c02cd00dd8..e69e748073 100644 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ b/extensions/sha256/circuit/src/sha256_chip/trace.rs @@ -1,351 +1,418 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; +use std::{ + array, + borrow::{Borrow, BorrowMut}, +}; -use openvm_circuit_primitives::utils::next_power_of_two_or_zero; -use openvm_instructions::riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS}; -use openvm_rv32im_circuit::adapters::compose; -use openvm_sha256_air::{ - get_flag_pt_array, limbs_into_u32, Sha256Air, SHA256_BLOCK_WORDS, SHA256_BUFFER_SIZE, SHA256_H, - SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, +use openvm_circuit::{ + arch::{Result, TraceStep, VmStateMut}, + system::memory::{online::TracingMemory, MemoryAuxColsFactory}, }; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_air::BaseAir, - p3_field::{FieldAlgebra, PrimeField32}, - p3_matrix::dense::RowMajorMatrix, - p3_maybe_rayon::prelude::*, - prover::types::AirProofInput, - rap::get_air_name, - AirRef, Chip, ChipUsageGetter, +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, }; +use openvm_rv32im_circuit::adapters::{memory_read, tracing_read, tracing_write}; +use openvm_sha256_air::{ + get_flag_pt_array, u32_into_u16s, Sha256StepHelper, SHA256_BLOCK_BITS, SHA256_BLOCK_WORDS, + SHA256_H, SHA256_ROWS_PER_BLOCK, SHA256_WORD_U8S, +}; +use openvm_sha256_transpiler::Rv32Sha256Opcode; +use openvm_stark_backend::{p3_field::PrimeField32, p3_maybe_rayon::prelude::*}; use super::{ - Sha256VmChip, Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, SHA256VM_ROUND_WIDTH, + Sha256VmDigestCols, Sha256VmRoundCols, Sha256VmStep, SHA256VM_CONTROL_WIDTH, + SHA256VM_DIGEST_WIDTH, }; use crate::{ sha256_chip::{PaddingFlags, SHA256_READ_SIZE}, - SHA256_BLOCK_CELLS, + SHA256VM_ROUND_WIDTH, SHA256_BLOCK_CELLS, }; -impl Chip for Sha256VmChip> -where - Val: PrimeField32, -{ - fn air(&self) -> AirRef { - Arc::new(self.air.clone()) - } +impl TraceStep for Sha256VmStep { + fn execute( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + trace: &mut [F], + trace_offset: &mut usize, + width: usize, + ) -> Result<()> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. + } = instruction; + let d = d.as_canonical_u32(); + let e = e.as_canonical_u32(); + debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode()); + debug_assert_eq!(d, RV32_REGISTER_AS); + debug_assert_eq!(e, RV32_MEMORY_AS); - fn generate_air_proof_input(self) -> AirProofInput { - let non_padded_height = self.current_trace_height(); - let height = next_power_of_two_or_zero(non_padded_height); - let width = self.trace_width(); - let mut values = Val::::zero_vec(height * width); - if height == 0 { - return AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)); - } - let records = self.records; - let offline_memory = self.offline_memory.lock().unwrap(); - let memory_aux_cols_factory = offline_memory.aux_cols_factory(); - - let mem_ptr_shift: u32 = - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.air.ptr_max_bits); + let trace = &mut trace[*trace_offset..]; + // Doing an untraced read to get the length to get the correct places to store the aux data + let len = u32::from_le_bytes(memory_read(state.memory.data(), d, c.as_canonical_u32())); + // need to pad with one 1 bit, 64 bits for the message length and then pad until the length + // is divisible by [SHA256_BLOCK_BITS] + let num_blocks = ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - let mut states = Vec::with_capacity(height.div_ceil(SHA256_ROWS_PER_BLOCK)); - let mut global_block_idx = 0; - for (record_idx, record) in records.iter().enumerate() { - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); + let last_row_offset = (num_blocks * SHA256_ROWS_PER_BLOCK - 1) * width; + let (dst, mut src) = { + let last_digest_row: &mut Sha256VmDigestCols = + trace[last_row_offset..last_row_offset + SHA256VM_DIGEST_WIDTH].borrow_mut(); - self.bitwise_lookup_chip.request_range( - dst_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, - src_read - .data_at(RV32_REGISTER_NUM_LIMBS - 1) - .as_canonical_u32() - * mem_ptr_shift, + last_digest_row.from_state.timestamp = F::from_canonical_u32(state.memory.timestamp()); + last_digest_row.from_state.pc = F::from_canonical_u32(*state.pc); + let dst = tracing_read( + state.memory, + d, + a.as_canonical_u32(), + &mut last_digest_row.register_reads_aux[0], ); - let len = compose(len_read.data_slice().try_into().unwrap()); - let mut state = &None; - for (i, input_message) in record.input_message.iter().enumerate() { - let input_message = input_message - .iter() - .flatten() - .copied() - .collect::>() - .try_into() - .unwrap(); - states.push(Some(Self::generate_state( - state, - input_message, - record_idx, + let src = tracing_read( + state.memory, + d, + b.as_canonical_u32(), + &mut last_digest_row.register_reads_aux[1], + ); + let len = tracing_read::<_, RV32_REGISTER_NUM_LIMBS>( + state.memory, + d, + c.as_canonical_u32(), + &mut last_digest_row.register_reads_aux[2], + ); + + last_digest_row.rd_ptr = *a; + last_digest_row.rs1_ptr = *b; + last_digest_row.rs2_ptr = *c; + last_digest_row.dst_ptr = dst.map(F::from_canonical_u8); + last_digest_row.src_ptr = src.map(F::from_canonical_u8); + last_digest_row.len_data = len.map(F::from_canonical_u8); + (u32::from_le_bytes(dst), u32::from_le_bytes(src)) + }; + + // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used + debug_assert!( + src as usize + num_blocks * SHA256_BLOCK_CELLS <= (1 << self.pointer_max_bits) + ); + + // // We can deduce the global block index from the trace offset + // // Note: global block index is 1-based + // let global_idx = *trace_offset / (SHA256_ROWS_PER_BLOCK * width) + 1; + let mut prev_hash = SHA256_H; + trace + .chunks_mut(width * SHA256_ROWS_PER_BLOCK) + .enumerate() + .take(num_blocks) + .for_each(|(block_idx, block_slice)| { + let is_last_block = block_idx == num_blocks - 1; + let mut read_data = [[0u8; SHA256_READ_SIZE]; 4]; + block_slice + .chunks_mut(width) + .enumerate() + .take(4) + .for_each(|(row_idx, row)| { + let cols: &mut Sha256VmRoundCols = + row[..SHA256VM_ROUND_WIDTH].borrow_mut(); + read_data[row_idx] = tracing_read::<_, SHA256_READ_SIZE>( + state.memory, + e, + src, + &mut cols.read_aux, + ); + cols.inner + .message_schedule + .carry_or_buffer + .iter_mut() + .zip( + read_data[row_idx] + .map(F::from_canonical_u8) + .chunks_exact(SHA256_WORD_U8S), + ) + .for_each(|(buffer, data)| { + buffer.copy_from_slice(data); + }); + src += SHA256_READ_SIZE as u32; + }); + + let digest_row = &mut block_slice[(SHA256_ROWS_PER_BLOCK - 1) * width..]; + let digest_cols: &mut Sha256VmDigestCols = + digest_row[..SHA256VM_DIGEST_WIDTH].borrow_mut(); + digest_cols.inner.prev_hash = + prev_hash.map(|x| u32_into_u16s(x).map(F::from_canonical_u32)); + digest_cols.inner.flags.local_block_idx = F::from_canonical_usize(block_idx); + digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); + digest_cols.control.len = F::from_canonical_u32(len); + digest_cols.control.read_ptr = F::from_canonical_u32(src); + digest_cols.control.cur_timestamp = F::from_canonical_u32(state.memory.timestamp()); + let padded_input = get_padded_input( + read_data.concat().try_into().unwrap(), len, - i == record.input_records.len() - 1, - ))); - state = &states[global_block_idx]; - global_block_idx += 1; - } + block_idx, + is_last_block, + ); + Sha256StepHelper::get_block_hash(&mut prev_hash, padded_input); + }); + + let last_digest_row: &mut Sha256VmDigestCols = + trace[last_row_offset..last_row_offset + SHA256VM_DIGEST_WIDTH].borrow_mut(); + tracing_write( + state.memory, + e, + dst, + &prev_hash + .map(|x| x.to_be_bytes()) + .concat() + .try_into() + .unwrap(), + &mut last_digest_row.writes_aux, + ); + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + *trace_offset += num_blocks * SHA256_ROWS_PER_BLOCK * width; + Ok(()) + } + + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut [F], + width: usize, + rows_used: usize, + ) where + Self: Send + Sync, + F: Send + Sync, + { + if rows_used == 0 { + return; } - states.extend(std::iter::repeat_n( - None, - (height - non_padded_height).div_ceil(SHA256_ROWS_PER_BLOCK), - )); + + let mem_ptr_shift: u32 = + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits); // During the first pass we will fill out most of the matrix // But there are some cells that can't be generated by the first pass so we will do a second // pass over the matrix - values + trace .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(states.into_par_iter().enumerate()) - .for_each(|(block, (global_block_idx, state))| { - // Fill in a valid block - if let Some(state) = state { - let mut has_padding_occurred = - state.local_block_idx * SHA256_BLOCK_CELLS > state.message_len as usize; - let message_left = if has_padding_occurred { - 0 - } else { - state.message_len as usize - state.local_block_idx * SHA256_BLOCK_CELLS - }; - let is_last_block = state.is_last_block; - let buffer: [[Val; SHA256_BUFFER_SIZE]; 4] = array::from_fn(|j| { - array::from_fn(|k| { - Val::::from_canonical_u8( - state.block_input_message[j * SHA256_BUFFER_SIZE + k], - ) - }) + .enumerate() + .for_each(|(block_idx, block_slice)| { + if block_idx * SHA256_ROWS_PER_BLOCK >= rows_used { + // Fill in the invalid rows + block_slice.par_chunks_mut(width).for_each(|row| { + let cols: &mut Sha256VmRoundCols = + row[..SHA256VM_ROUND_WIDTH].borrow_mut(); + self.inner.generate_default_row(&mut cols.inner); }); + return; + } - let padded_message: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| { - limbs_into_u32::(array::from_fn(|k| { - state.block_padded_message[(j + 1) * SHA256_WORD_U8S - k - 1] as u32 - })) - }); + // The read data is kept in the buffer of the first 4 round cols + let read_data: [u8; SHA256_BLOCK_CELLS] = block_slice + .chunks_exact(width) + .take(4) + .map(|row| { + let cols: &Sha256VmRoundCols = row[..SHA256VM_ROUND_WIDTH].borrow(); + cols.inner.message_schedule.carry_or_buffer.as_flattened() + }) + .flatten() + .map(|x| x.as_canonical_u32() as u8) + .collect::>() + .try_into() + .unwrap(); - self.air.sha256_subair.generate_block_trace::>( - block, - width, - SHA256VM_CONTROL_WIDTH, - &padded_message, - self.bitwise_lookup_chip.clone(), - &state.hash, - is_last_block, - global_block_idx as u32 + 1, - state.local_block_idx as u32, - &buffer, - ); + let digest_offset = width * (SHA256_ROWS_PER_BLOCK - 1); + let (local_block_idx, len, is_last_block, prev_hash) = { + let digest_cols: &mut Sha256VmDigestCols = block_slice + [digest_offset..digest_offset + SHA256VM_DIGEST_WIDTH] + .borrow_mut(); + ( + digest_cols.inner.flags.local_block_idx.as_canonical_u32() as usize, + digest_cols.control.len.as_canonical_u32(), + digest_cols.inner.flags.is_last_block.is_one(), + digest_cols + .inner + .prev_hash + .map(|x| x[0].as_canonical_u32() + (x[1].as_canonical_u32() << 16)), + ) + }; + let mut has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > len as usize; + let message_left = if has_padding_occurred { + 0 + } else { + len as usize - local_block_idx * SHA256_BLOCK_CELLS + }; - let block_reads = records[state.message_idx].input_records - [state.local_block_idx] - .map(|record_id| offline_memory.record_by_id(record_id)); + let padded_input = get_padded_input(read_data, len, local_block_idx, is_last_block); + let padded_input: [u32; SHA256_BLOCK_WORDS] = array::from_fn(|j| { + u32::from_be_bytes( + padded_input[j * SHA256_WORD_U8S..(j + 1) * SHA256_WORD_U8S] + .try_into() + .unwrap(), + ) + }); - let mut read_ptr = block_reads[0].pointer; - let mut cur_timestamp = Val::::from_canonical_u32(block_reads[0].timestamp); + self.inner.generate_block_trace::( + block_slice, + width, + SHA256VM_CONTROL_WIDTH, + &padded_input, + self.bitwise_lookup_chip.as_ref(), + &prev_hash, + is_last_block, + block_idx as u32 + 1, // global block index is 1-based + local_block_idx as u32, + ); - let read_size = Val::::from_canonical_usize(SHA256_READ_SIZE); - for row in 0..SHA256_ROWS_PER_BLOCK { - let row_slice = &mut block[row * width..(row + 1) * width]; - if row < 16 { - let cols: &mut Sha256VmRoundCols> = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; - if row < 4 { - read_ptr += read_size; - cur_timestamp += Val::::ONE; - memory_aux_cols_factory - .generate_read_aux(block_reads[row], &mut cols.read_aux); + let (round_rows, digest_row) = block_slice.split_at_mut(digest_offset); + let digest_cols: &mut Sha256VmDigestCols = + digest_row[..SHA256VM_DIGEST_WIDTH].borrow_mut(); + let len = digest_cols.control.len; + let read_ptr = digest_cols.control.read_ptr; + let timestamp = digest_cols.control.cur_timestamp; - if (row + 1) * SHA256_READ_SIZE <= message_left { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(Val::::from_canonical_u32); - } else if !has_padding_occurred { - has_padding_occurred = true; - let len = message_left - row * SHA256_READ_SIZE; - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(Val::::from_canonical_u32); + // Fill in the first 4 round rows + round_rows + .chunks_mut(width) + .take(4) + .enumerate() + .for_each(|(row, row_slice)| { + let cols: &mut Sha256VmRoundCols = + row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); + cols.control.len = len; + cols.control.read_ptr = + read_ptr - F::from_canonical_usize(SHA256_READ_SIZE * (4 - row)); + cols.control.cur_timestamp = timestamp - F::from_canonical_usize(4 - row); + mem_helper.fill_from_prev( + cols.control.cur_timestamp.as_canonical_u32(), + cols.read_aux.as_mut(), + ); + if (row + 1) * SHA256_READ_SIZE <= message_left { + cols.control.pad_flags = get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotPadding as usize, + ) + .map(F::from_canonical_u32); + } else if !has_padding_occurred { + has_padding_occurred = true; + let len = message_left - row * SHA256_READ_SIZE; + cols.control.pad_flags = get_flag_pt_array( + &self.padding_encoder, + if row == 3 && is_last_block { + PaddingFlags::FirstPadding0_LastRow } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - if row == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(Val::::from_canonical_u32); - } - } else { - cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(Val::::from_canonical_u32); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); + PaddingFlags::FirstPadding0 + } as usize + + len, + ) + .map(F::from_canonical_u32); } else { - if is_last_block { - has_padding_occurred = false; - } - let cols: &mut Sha256VmDigestCols> = - row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - cols.control.len = Val::::from_canonical_u32(state.message_len); - cols.control.read_ptr = read_ptr; - cols.control.cur_timestamp = cur_timestamp; cols.control.pad_flags = get_flag_pt_array( - &self.air.padding_encoder, - PaddingFlags::NotConsidered as usize, + &self.padding_encoder, + if row == 3 && is_last_block { + PaddingFlags::EntirePaddingLastRow + } else { + PaddingFlags::EntirePadding + } as usize, ) - .map(Val::::from_canonical_u32); - if is_last_block { - let record = &records[state.message_idx]; - let dst_read = offline_memory.record_by_id(record.dst_read); - let src_read = offline_memory.record_by_id(record.src_read); - let len_read = offline_memory.record_by_id(record.len_read); - let digest_write = offline_memory.record_by_id(record.digest_write); - cols.from_state = record.from_state; - cols.rd_ptr = dst_read.pointer; - cols.rs1_ptr = src_read.pointer; - cols.rs2_ptr = len_read.pointer; - cols.dst_ptr.copy_from_slice(dst_read.data_slice()); - cols.src_ptr.copy_from_slice(src_read.data_slice()); - cols.len_data.copy_from_slice(len_read.data_slice()); - memory_aux_cols_factory - .generate_read_aux(dst_read, &mut cols.register_reads_aux[0]); - memory_aux_cols_factory - .generate_read_aux(src_read, &mut cols.register_reads_aux[1]); - memory_aux_cols_factory - .generate_read_aux(len_read, &mut cols.register_reads_aux[2]); - memory_aux_cols_factory - .generate_write_aux(digest_write, &mut cols.writes_aux); - } - cols.control.padding_occurred = - Val::::from_bool(has_padding_occurred); + .map(F::from_canonical_u32); } - } + cols.control.padding_occurred = F::from_bool(has_padding_occurred); + }); + + // Fill in the remaining round rows + + round_rows + .par_chunks_mut(width) + .skip(4) + .for_each(|row_slice| { + let cols: &mut Sha256VmRoundCols = + row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); + cols.control.len = len; + cols.control.read_ptr = read_ptr; + cols.control.cur_timestamp = timestamp; + cols.control.pad_flags = get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotConsidered as usize, + ) + .map(F::from_canonical_u32); + cols.control.padding_occurred = F::from_bool(has_padding_occurred); + }); + + // Fill in the digest row + if is_last_block { + has_padding_occurred = false; } - // Fill in the invalid rows - else { - block.par_chunks_mut(width).for_each(|row| { - let cols: &mut Sha256VmRoundCols> = row.borrow_mut(); - self.air.sha256_subair.generate_default_row(&mut cols.inner); - }) + digest_cols.control.pad_flags = + get_flag_pt_array(&self.padding_encoder, PaddingFlags::NotConsidered as usize) + .map(F::from_canonical_u32); + if is_last_block { + let mut timestamp = digest_cols.from_state.timestamp.as_canonical_u32(); + digest_cols.register_reads_aux.iter_mut().for_each(|aux| { + mem_helper.fill_from_prev(timestamp, aux.as_mut()); + timestamp += 1; + }); + mem_helper.fill_from_prev( + digest_cols.control.cur_timestamp.as_canonical_u32(), + digest_cols.writes_aux.as_mut(), + ); + self.bitwise_lookup_chip.request_range( + digest_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() + * mem_ptr_shift, + digest_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1].as_canonical_u32() + * mem_ptr_shift, + ); } + digest_cols.control.padding_occurred = F::from_bool(has_padding_occurred); }); // Do a second pass over the trace to fill in the missing values // Note, we need to skip the very first row - values[width..] + trace[width..] .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) + .take(rows_used / SHA256_ROWS_PER_BLOCK) .for_each(|chunk| { - self.air - .sha256_subair + self.inner .generate_missing_cells(chunk, width, SHA256VM_CONTROL_WIDTH); }); - - AirProofInput::simple_no_pis(RowMajorMatrix::new(values, width)) - } -} - -impl ChipUsageGetter for Sha256VmChip { - fn air_name(&self) -> String { - get_air_name(&self.air) - } - fn current_trace_height(&self) -> usize { - self.records.iter().fold(0, |acc, record| { - acc + record.input_records.len() * SHA256_ROWS_PER_BLOCK - }) } - fn trace_width(&self) -> usize { - BaseAir::::width(&self.air) + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", Rv32Sha256Opcode::SHA256) } } -/// This is the state information that a block will use to generate its trace -#[derive(Debug, Clone)] -struct Sha256State { - hash: [u32; SHA256_HASH_WORDS], - local_block_idx: usize, +fn get_padded_input( + block_input: [u8; SHA256_BLOCK_CELLS], message_len: u32, - block_input_message: [u8; SHA256_BLOCK_CELLS], - block_padded_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, + local_block_idx: usize, is_last_block: bool, -} - -impl Sha256VmChip { - fn generate_state( - prev_state: &Option, - block_input_message: [u8; SHA256_BLOCK_CELLS], - message_idx: usize, - message_len: u32, - is_last_block: bool, - ) -> Sha256State { - let local_block_idx = if let Some(prev_state) = prev_state { - prev_state.local_block_idx + 1 - } else { - 0 - }; - let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize; - let message_left = if has_padding_occurred { - 0 - } else { - message_len as usize - local_block_idx * SHA256_BLOCK_CELLS - }; +) -> [u8; SHA256_BLOCK_CELLS] { + let has_padding_occurred = local_block_idx * SHA256_BLOCK_CELLS > message_len as usize; + let message_left = if has_padding_occurred { + 0 + } else { + message_len as usize - local_block_idx * SHA256_BLOCK_CELLS + }; - let padded_message_bytes: [u8; SHA256_BLOCK_CELLS] = array::from_fn(|j| { - if j < message_left { - block_input_message[j] - } else if j == message_left && !has_padding_occurred { - 1 << (RV32_CELL_BITS - 1) - } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 { - 0u8 - } else { - let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS; - ((message_len * RV32_CELL_BITS as u32) - .checked_shr(shift_amount as u32) - .unwrap_or(0) - & ((1 << RV32_CELL_BITS) - 1)) as u8 - } - }); - - if let Some(prev_state) = prev_state { - Sha256State { - hash: Sha256Air::get_block_hash(&prev_state.hash, prev_state.block_padded_message), - local_block_idx, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } + array::from_fn(|j| { + if j < message_left { + block_input[j] + } else if j == message_left && !has_padding_occurred { + 1 << (RV32_CELL_BITS - 1) + } else if !is_last_block || j < SHA256_BLOCK_CELLS - 4 { + 0u8 } else { - Sha256State { - hash: SHA256_H, - local_block_idx: 0, - message_len, - block_input_message, - block_padded_message: padded_message_bytes, - message_idx, - is_last_block, - } + let shift_amount = (SHA256_BLOCK_CELLS - j - 1) * RV32_CELL_BITS; + ((message_len * RV32_CELL_BITS as u32) + .checked_shr(shift_amount as u32) + .unwrap_or(0) + & ((1 << RV32_CELL_BITS) - 1)) as u8 } - } + }) } diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs index cb34bcd5aa..5dea93b35a 100644 --- a/extensions/sha256/guest/src/lib.rs +++ b/extensions/sha256/guest/src/lib.rs @@ -1,5 +1,8 @@ #![no_std] +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + /// This is custom-0 defined in RISC-V spec document pub const OPCODE: u8 = 0x0b; pub const SHA256_FUNCT3: u8 = 0b100; @@ -13,7 +16,8 @@ pub fn sha256(input: &[u8]) -> [u8; 32] { output } -/// zkvm native implementation of sha256 +/// Native hook for sha256 +/// /// # Safety /// /// The VM accepts the preimage by pointer and length, and writes the @@ -21,11 +25,54 @@ pub fn sha256(input: &[u8]) -> [u8; 32] { /// - `bytes` must point to an input buffer at least `len` long. /// - `output` must point to a buffer that is at least 32-bytes long. /// -/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf +/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf #[cfg(target_os = "zkvm")] #[inline(always)] #[no_mangle] extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { + // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or + // `output` are not aligned to 4 bytes. + // The minimum alignment required for the input and output buffers + const MIN_ALIGN: usize = 4; + // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes + const INPUT_ALIGN: usize = 16; + // The preferred alignment for the output buffer, since the output is written in chunks of 32 + // bytes + const OUTPUT_ALIGN: usize = 32; + unsafe { + if bytes as usize % MIN_ALIGN != 0 { + let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(aligned_buff.ptr, len, output); + } + } else { + if output as usize % MIN_ALIGN != 0 { + let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); + __native_sha256(bytes, len, aligned_out.ptr); + core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); + } else { + __native_sha256(bytes, len, output); + } + }; + } +} + +/// sha256 intrinsic binding +/// +/// # Safety +/// +/// The VM accepts the preimage by pointer and length, and writes the +/// 32-byte hash. +/// - `bytes` must point to an input buffer at least `len` long. +/// - `output` must point to a buffer that is at least 32-bytes long. +/// - `bytes` and `output` must be 4-byte aligned. +#[cfg(target_os = "zkvm")] +#[inline(always)] +fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); }