diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index 824291a8b..c6814fc60 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '**' + - "**" merge_group: types: [checks_requested] workflow_dispatch: {} @@ -25,7 +25,6 @@ env: LLVM_VERSION: "14.0" LLVM_FEATURE_NAME: "14-0" - jobs: # Check if changes were made to the relevant files. # Always returns true if running on the default branch, to ensure all changes are thoroughly checked. @@ -43,25 +42,25 @@ jobs: model: ${{ steps.filter.outputs.model == 'true' || steps.override.outputs.out == 'true' }} llvm: ${{ steps.filter.outputs.llvm == 'true' || steps.override.outputs.out == 'true' }} steps: - - uses: actions/checkout@v4 - - name: Override label - id: override - run: | - echo "Label contains run-ci-checks: $OVERRIDE_LABEL" - if [ "$OVERRIDE_LABEL" == "true" ]; then - echo "Overriding due to label 'run-ci-checks'" - echo "out=true" >> $GITHUB_OUTPUT - elif [ "$DEFAULT_BRANCH" == "true" ]; then - echo "Overriding due to running on the default branch" - echo "out=true" >> $GITHUB_OUTPUT - fi - env: - OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} - DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} - - uses: dorny/paths-filter@v3 - id: filter - with: - filters: .github/change-filters.yml + - uses: actions/checkout@v4 + - name: Override label + id: override + run: | + echo "Label contains run-ci-checks: $OVERRIDE_LABEL" + if [ "$OVERRIDE_LABEL" == "true" ]; then + echo "Overriding due to label 'run-ci-checks'" + echo "out=true" >> $GITHUB_OUTPUT + elif [ "$DEFAULT_BRANCH" == "true" ]; then + echo "Overriding due to running on the default branch" + echo "out=true" >> $GITHUB_OUTPUT + fi + env: + OVERRIDE_LABEL: ${{ github.event_name == 'pull_request' && contains( github.event.pull_request.labels.*.name, 'run-ci-checks') }} + DEFAULT_BRANCH: ${{ github.ref_name == github.event.repository.default_branch }} + - uses: dorny/paths-filter@v3 + id: filter + with: + filters: .github/change-filters.yml check: needs: changes @@ -109,7 +108,7 @@ jobs: - name: Override criterion with the CodSpeed harness run: cargo add --dev codspeed-criterion-compat --rename criterion --package hugr - name: Build benchmarks - run: cargo codspeed build --profile bench --features extension_inference,declarative,model_unstable,llvm,llvm-test + run: cargo codspeed build --profile bench --features declarative,llvm,llvm-test - name: Run benchmarks uses: CodSpeedHQ/action@v3 with: @@ -234,7 +233,7 @@ jobs: id: toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: "1.75" + toolchain: "1.85" - name: Install nightly toolchain uses: dtolnay/rust-toolchain@master with: @@ -253,12 +252,10 @@ jobs: cargo binstall cargo-minimal-versions --force - name: Pin transitive dependencies not compatible with our MSRV # Add new dependencies as needed if the check fails due to - # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.75.0" + # "package `XXX` cannot be built because it requires rustc YYY or newer, while the currently active rustc version is 1.85.0" run: | - rm Cargo.lock - cargo add -p hugr half@2.4.1 - cargo add -p hugr litemap@0.7.4 - cargo add -p hugr zerofrom@0.1.5 + # rm Cargo.lock + # cargo add -p hugr half@2.4.1 - name: Build with no features run: cargo minimal-versions --direct test --verbose --no-default-features --no-run - name: Tests with no features diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 569ccfec1..b6e481bd5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: # built into a binary build (without using `maturin`) # # This feature list should be kept in sync with the `hugr-py/pyproject.toml` - entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + entry: cargo test --workspace --exclude 'hugr-py' --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' language: system files: \.rs$ pass_filenames: false @@ -100,10 +100,7 @@ repos: - id: py-test name: pytest description: Run python tests - # We need to rebuild `hugr-cli` without the `extension_inference` feature - # to avoid test errors. - # TODO: Remove this once the issue is fixed. - entry: sh -c "cargo build -p hugr-cli && uv run pytest" + entry: sh -c "uv run pytest" language: system files: \.py$ pass_filenames: false diff --git a/Cargo.lock b/Cargo.lock index 7198a4ea5..b05cba1e8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -24,7 +24,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "getrandom 0.2.15", + "getrandom 0.2.16", "once_cell", "serde", "version_check", @@ -301,9 +301,9 @@ checksum = "38c99613cb3cd7429889a08dfcf651721ca971c86afa30798461f8eee994de47" [[package]] name = "bstr" -version = "1.11.3" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "531a9155a481e2ee699d4f98f43c0ca4ff8ee1bfd55c31e9e98fb29d2b176fe0" +checksum = "234113d19d0d7d613b40e86fb654acf958910802bcceab913a4f9e7cda03b1a4" dependencies = [ "memchr", "regex-automata", @@ -351,9 +351,9 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.2.18" +version = "1.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525046617d8376e3db1deffb079e91cef90a89fc3ca5c185bbf8c9ecdd15cd5c" +checksum = "04da6a0d40b948dfc4fa8f5bbf402b0fc1a64a28dbf7d12ffd683550f2c1b63a" dependencies = [ "jobserver", "libc", @@ -936,9 +936,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" dependencies = [ "cfg-if", "libc", @@ -1458,9 +1458,9 @@ checksum = "f4c7245a08504955605670dbf141fceab975f15ca21570696aebe9d2e71576bd" [[package]] name = "inkwell" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fb405537710d51f6bdbc8471365ddd4cd6d3a3c3ad6e0c8291691031ba94b2" +checksum = "e67349bd7578d4afebbe15eaa642a80b884e8623db74b1716611b131feb1deef" dependencies = [ "either", "inkwell_internals", @@ -1472,9 +1472,9 @@ dependencies = [ [[package]] name = "inkwell_internals" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd28cfd4cfba665d47d31c08a6ba637eed16770abca2eccbbc3ca831fef1e44" +checksum = "f365c8de536236cfdebd0ba2130de22acefed18b1fb99c32783b3840aec5fb46" dependencies = [ "proc-macro2", "quote", @@ -1483,14 +1483,12 @@ dependencies = [ [[package]] name = "insta" -version = "1.42.2" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50259abbaa67d11d2bcafc7ba1d094ed7a0c70e3ce893f0d0997f73558cb3084" +checksum = "ab2d11b2f17a45095b8c3603928ba29d7d918d7129d0d0641a36ba73cf07daa6" dependencies = [ "console", - "linked-hash-map", "once_cell", - "pin-project", "serde", "similar", ] @@ -1622,21 +1620,15 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.171" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" - -[[package]] -name = "linked-hash-map" -version = "0.5.6" +version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f" +checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" [[package]] name = "linux-raw-sys" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe7db12097d22ec582439daf8618b8fdd1a7bef6270e9af3b1ebcd30893cf413" +checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" [[package]] name = "litemap" @@ -1696,9 +1688,9 @@ checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" [[package]] name = "miniz_oxide" -version = "0.8.7" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff70ce3e48ae43fa075863cef62e8b43b71a4f2382229920e0df362592919430" +checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a" dependencies = [ "adler2", ] @@ -1937,26 +1929,6 @@ dependencies = [ "serde", ] -[[package]] -name = "pin-project" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "677f1add503faace112b9f1373e43e9e054bfdd22ff1a63c1bc485eaec6a6a8a" -dependencies = [ - "pin-project-internal", -] - -[[package]] -name = "pin-project-internal" -version = "1.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e918e4ff8c4549eb882f14b3a4bc8c8bc93de829416eacf579f1207a8fbf861" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "pin-project-lite" version = "0.2.16" @@ -2011,9 +1983,9 @@ checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" [[package]] name = "portgraph" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a9ea69cfb011d5f17af28813ec37a0a9668a063090e14ad75dc5fc07ba01b47" +checksum = "5fdce52d51ec359351ff3c209fafb6f133562abf52d951ce5821c0184798d979" dependencies = [ "bitvec", "delegate", @@ -2029,7 +2001,7 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" dependencies = [ - "zerocopy 0.8.24", + "zerocopy 0.8.25", ] [[package]] @@ -2094,9 +2066,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] @@ -2259,7 +2231,7 @@ version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.2.16", ] [[package]] @@ -2694,9 +2666,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -2830,15 +2802,15 @@ dependencies = [ [[package]] name = "toml_datetime" -version = "0.6.8" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" +checksum = "3da5db5a963e24bc68be8b17b6fa82814bb22ee8660f192bb182771d498f09a3" [[package]] name = "toml_edit" -version = "0.22.24" +version = "0.22.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17b4795ff5edd201c7cd6dca065ae59972ce77d1b80fa0a84d94950ece7d1474" +checksum = "10558ed0bd2a1562e630926a2d1f0b98c827da99fabd3fe20920a59642504485" dependencies = [ "indexmap", "toml_datetime", @@ -3418,9 +3390,9 @@ checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" [[package]] name = "winnow" -version = "0.7.6" +version = "0.7.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63d3fcd9bba44b03821e7d699eeee959f3126dcc4aa8e4ae18ec617c2a5cea10" +checksum = "6cb8234a863ea0e8cd7284fcdd4f145233eb00fee02bbdd9861aec44e6477bc5" dependencies = [ "memchr", ] @@ -3496,11 +3468,11 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2586fea28e186957ef732a5f8b3be2da217d65c5969d4b1e17f973ebbe876879" +checksum = "a1702d9583232ddb9174e01bb7c15a2ab8fb1bc6f227aa1233858c351a3ba0cb" dependencies = [ - "zerocopy-derive 0.8.24", + "zerocopy-derive 0.8.25", ] [[package]] @@ -3516,9 +3488,9 @@ dependencies = [ [[package]] name = "zerocopy-derive" -version = "0.8.24" +version = "0.8.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" +checksum = "28a6e20d751156648aa063f3800b706ee209a32c0b4d9f24be3d980b01be55ef" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index 3031df1e7..97dad7dea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ members = [ default-members = ["hugr", "hugr-core", "hugr-passes", "hugr-cli", "hugr-model"] [workspace.package] -rust-version = "1.75" +rust-version = "1.85" edition = "2021" homepage = "https://github.com/CQCL/hugr" repository = "https://github.com/CQCL/hugr" @@ -58,7 +58,7 @@ regex = "1.10.6" regex-syntax = "0.8.3" rstest = "0.24.0" semver = "1.0.26" -serde = "1.0.195" +serde = "1.0.219" serde_json = "1.0.140" serde_yaml = "0.9.34" smol_str = "0.3.1" @@ -87,8 +87,8 @@ zstd = "0.13.2" # These public dependencies usually require breaking changes downstream, so we # try to be as permissive as possible. pyo3 = ">= 0.23.4, < 0.25" -portgraph = { version = ">= 0.13.3, < 0.15" } -petgraph = { version = ">= 0.7.1, < 0.9", default-features = false } +portgraph = { version = "0.14.1" } +petgraph = { version = ">= 0.8.1, < 0.9", default-features = false } [profile.dev.package] insta.opt-level = 3 diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 4659f96c7..d9f19ed64 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -28,10 +28,10 @@ shell by setting up [direnv](https://devenv.sh/automatic-shell-activation/). To setup the environment manually you will need: -- Just: https://just.systems/ -- Rust `>=1.75`: https://www.rust-lang.org/tools/install -- uv `>=0.3`: docs.astral.sh/uv/getting-started/installation -- Optional: capnproto `>=1.0`: https://capnproto.org/install.html +- Just: +- Rust `>=1.85`: +- uv `>=0.3`: +- Optional: capnproto `>=1.0`: Required when modifying the `hugr-model` serialization schema. - Optional: llvm `== 14.0`. The "llvm" feature (backed by the sub-crate `hugr-llvm`) requires LLVM installed. We use the rust bindings diff --git a/hugr-cli/README.md b/hugr-cli/README.md index 277628d2b..dba9900e2 100644 --- a/hugr-cli/README.md +++ b/hugr-cli/README.md @@ -64,7 +64,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-cli/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-cli [crates]: https://img.shields.io/crates/v/hugr-cli [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-core/Cargo.toml b/hugr-core/Cargo.toml index 22e5390fc..8da686ee8 100644 --- a/hugr-core/Cargo.toml +++ b/hugr-core/Cargo.toml @@ -17,9 +17,7 @@ categories = ["compilers"] workspace = true [features] -extension_inference = [] declarative = ["serde_yaml"] -model_unstable = ["hugr-model"] zstd = ["dep:zstd"] [lib] @@ -27,10 +25,9 @@ bench = false [[test]] name = "model" -required-features = ["model_unstable"] [dependencies] -hugr-model = { version = "0.19.0", path = "../hugr-model", optional = true } +hugr-model = { version = "0.19.0", path = "../hugr-model" } cgmath = { workspace = true, features = ["serde"] } delegate = { workspace = true } diff --git a/hugr-core/README.md b/hugr-core/README.md index 46cafe16f..0e15305f1 100644 --- a/hugr-core/README.md +++ b/hugr-core/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-core -=============== +# hugr-core [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-core) @@ -15,15 +14,8 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. -- `model_unstable` - Import and export from the representation defined in the `hugr-model` crate. - Unstable and subject to change. Not enabled by default. ## Recent Changes @@ -38,10 +30,10 @@ See [DEVELOPMENT.md](https://github.com/CQCL/hugr/blob/main/DEVELOPMENT.md) for This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http://www.apache.org/licenses/LICENSE-2.0). - [API documentation here]: https://docs.rs/hugr-core/ - [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg - [crates]: https://img.shields.io/crates/v/hugr-core - [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov - [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md +[API documentation here]: https://docs.rs/hugr-core/ +[build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main +[msrv]: https://img.shields.io/crates/msrv/hugr-core +[crates]: https://img.shields.io/crates/v/hugr-core +[codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov +[LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE +[CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-core/CHANGELOG.md diff --git a/hugr-core/src/builder.rs b/hugr-core/src/builder.rs index 056690e0a..9f7a219a7 100644 --- a/hugr-core/src/builder.rs +++ b/hugr-core/src/builder.rs @@ -42,7 +42,7 @@ //! let _dfg_handle = { //! let mut dfg = module_builder.define_function( //! "main", -//! Signature::new_endo(bool_t()).with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(bool_t()), //! )?; //! //! // Get the wires from the function inputs. @@ -59,8 +59,7 @@ //! let _circuit_handle = { //! let mut dfg = module_builder.define_function( //! "circuit", -//! Signature::new_endo(vec![bool_t(), bool_t()]) -//! .with_extension_delta(logic::EXTENSION_ID), +//! Signature::new_endo(vec![bool_t(), bool_t()]), //! )?; //! let mut circuit = dfg.as_circuit(dfg.input_wires()); //! @@ -89,7 +88,7 @@ use thiserror::Error; use crate::extension::simple_op::OpLoadError; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::hugr::ValidationError; use crate::ops::handle::{BasicBlockID, CfgID, ConditionalID, DfgID, FuncID, TailLoopID}; use crate::ops::{NamedOp, OpType}; @@ -123,16 +122,14 @@ pub use conditional::{CaseBuilder, ConditionalBuilder}; mod circuit; pub use circuit::{CircuitBuildError, CircuitBuilder}; -/// Return a FunctionType with the same input and output types (specified) -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the same input and output types (specified). pub fn endo_sig(types: impl Into) -> Signature { - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED) + Signature::new_endo(types) } -/// Return a FunctionType with the specified input and output types -/// whose extension delta, when used in a non-FuncDefn container, will be inferred. +/// Return a FunctionType with the specified input and output types. pub fn inout_sig(inputs: impl Into, outputs: impl Into) -> Signature { - Signature::new(inputs, outputs).with_extension_delta(TO_BE_INFERRED) + Signature::new(inputs, outputs) } #[derive(Debug, Clone, PartialEq, Error)] diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index f1613895d..ba366c117 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -20,7 +20,7 @@ use crate::{ types::EdgeKind, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; +use crate::extension::ExtensionRegistry; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeRow}; use itertools::Itertools; @@ -119,7 +119,7 @@ pub trait Container { } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult { + fn add_hugr_view(&mut self, child: &H) -> InsertionResult { let parent = self.container_node(); self.hugr_mut().insert_from_view(parent, child) } @@ -153,7 +153,7 @@ pub trait Container { where ExtensionRegistry: Extend, { - self.hugr_mut().extensions_mut().extend(registry); + self.hugr_mut().use_extensions(registry); } } @@ -319,10 +319,7 @@ pub trait Dataflow: Container { inputs: impl IntoIterator, ) -> Result, BuildError> { let (types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); - self.dfg_builder( - Signature::new_endo(types).with_extension_delta(TO_BE_INFERRED), - input_wires, - ) + self.dfg_builder(Signature::new_endo(types), input_wires) } /// Return a builder for a [`crate::ops::CFG`] node, @@ -330,7 +327,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The Extension delta will be inferred. /// /// # Errors /// @@ -340,27 +336,6 @@ pub trait Dataflow: Container { &mut self, inputs: impl IntoIterator, output_types: TypeRow, - ) -> Result, BuildError> { - self.cfg_builder_exts(inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::CFG`] node, - /// i.e. a nested controlflow subgraph. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [cfg_builder](Self::cfg_builder) may be used to infer it. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the CFG node. - fn cfg_builder_exts( - &mut self, - inputs: impl IntoIterator, - output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, input_wires): (Vec, Vec) = inputs.into_iter().unzip(); @@ -369,8 +344,7 @@ pub trait Dataflow: Container { let (cfg_node, _) = add_node_with_wires( self, ops::CFG { - signature: Signature::new(inputs.clone(), output_types.clone()) - .with_extension_delta(extension_delta), + signature: Signature::new(inputs.clone(), output_types.clone()), }, input_wires, )?; @@ -449,7 +423,6 @@ pub trait Dataflow: Container { /// The `inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// The extension delta will be inferred. /// /// # Errors /// @@ -461,27 +434,6 @@ pub trait Dataflow: Container { just_inputs: impl IntoIterator, inputs_outputs: impl IntoIterator, just_out_types: TypeRow, - ) -> Result, BuildError> { - self.tail_loop_builder_exts(just_inputs, inputs_outputs, just_out_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::TailLoop`] node. - /// The `inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. - /// The `output_types` are the types of the outputs. - /// `extension_delta` explicitly specified. Alternatively - /// [tail_loop_builder](Self::tail_loop_builder) may be used to infer it. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the [`ops::TailLoop`] node. - fn tail_loop_builder_exts( - &mut self, - just_inputs: impl IntoIterator, - inputs_outputs: impl IntoIterator, - just_out_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let (input_types, mut input_wires): (Vec, Vec) = just_inputs.into_iter().unzip(); @@ -493,7 +445,6 @@ pub trait Dataflow: Container { just_inputs: input_types.into(), just_outputs: just_out_types, rest: rest_types.into(), - extension_delta: extension_delta.into(), }; // TODO: Make input extensions a parameter let (loop_node, _) = add_node_with_wires(self, tail_loop.clone(), input_wires)?; @@ -507,41 +458,17 @@ pub trait Dataflow: Container { /// /// The `other_inputs` must be an iterable over pairs of the type of the input and /// the corresponding wire. - /// The `output_types` are the types of the outputs. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error when building - /// the Conditional node. - fn conditional_builder( - &mut self, - sum_input: (impl IntoIterator, Wire), - other_inputs: impl IntoIterator, - output_types: TypeRow, - ) -> Result, BuildError> { - self.conditional_builder_exts(sum_input, other_inputs, output_types, TO_BE_INFERRED) - } - - /// Return a builder for a [`crate::ops::Conditional`] node. - /// `sum_rows` and `sum_wire` define the type of the Sum - /// variants and the wire carrying the Sum respectively. - /// - /// The `other_inputs` must be an iterable over pairs of the type of the input and - /// the corresponding wire. /// The `output_types` are the types of the outputs. - /// `extension_delta` is explicitly specified. Alternatively - /// [conditional_builder](Self::conditional_builder) may be used to infer it. /// /// # Errors /// /// This function will return an error if there is an error when building /// the Conditional node. - fn conditional_builder_exts( + fn conditional_builder( &mut self, (sum_rows, sum_wire): (impl IntoIterator, Wire), other_inputs: impl IntoIterator, output_types: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let mut input_wires = vec![sum_wire]; let (input_types, rest_input_wires): (Vec, Vec) = @@ -558,7 +485,6 @@ pub trait Dataflow: Container { sum_rows, other_inputs: inputs, outputs: output_types, - extension_delta: extension_delta.into(), }, input_wires, )?; diff --git a/hugr-core/src/builder/cfg.rs b/hugr-core/src/builder/cfg.rs index 81c7d7269..0aadc047b 100644 --- a/hugr-core/src/builder/cfg.rs +++ b/hugr-core/src/builder/cfg.rs @@ -5,9 +5,8 @@ use super::{ BasicBlockID, BuildError, CfgID, Container, Dataflow, HugrBuilder, Wire, }; -use crate::extension::TO_BE_INFERRED; use crate::ops::{self, handle::NodeHandle, DataflowBlock, DataflowParent, ExitBlock, OpType}; -use crate::{extension::ExtensionSet, types::Signature}; +use crate::types::Signature; use crate::{hugr::views::HugrView, types::TypeRow}; use crate::Node; @@ -106,7 +105,6 @@ use crate::{hugr::HugrMut, type_row, Hugr}; /// let hugr = cfg_builder.finish_hugr()?; /// Ok(hugr) /// }; -/// #[cfg(not(feature = "extension_inference"))] /// assert!(make_cfg().is_ok()); /// ``` #[derive(Debug, PartialEq)] @@ -157,10 +155,7 @@ impl CFGBuilder { } impl HugrBuilder for CFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -192,7 +187,7 @@ impl + AsRef> CFGBuilder { /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. + /// specified by `sum_rows`. /// /// # Errors /// @@ -203,36 +198,12 @@ impl + AsRef> CFGBuilder { sum_rows: impl IntoIterator, other_outputs: TypeRow, ) -> Result, BuildError> { - self.block_builder_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and the variants of the branching Sum value - /// specified by `sum_rows`. Extension delta will be inferred. - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn block_builder_exts( - &mut self, - inputs: TypeRow, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - false, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, false) } fn any_block_builder( &mut self, inputs: TypeRow, - extension_delta: ExtensionSet, sum_rows: impl IntoIterator, other_outputs: TypeRow, entry: bool, @@ -242,7 +213,6 @@ impl + AsRef> CFGBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta, }); let parent = self.container_node(); let block_n = if entry { @@ -257,9 +227,9 @@ impl + AsRef> CFGBuilder { BlockBuilder::create(self.hugr_mut(), block_n) } - /// Return a builder for a non-entry [`DataflowBlock`] child graph with `inputs` - /// and `outputs` and `extension_delta` explicitly specified, plus a UnitSum type - /// (a Sum of `n_cases` unit types) to select the successor. + /// Return a builder for a non-entry [`DataflowBlock`] child graph with + /// `inputs` and `outputs` , plus a UnitSum type (a Sum of `n_cases` unit + /// types) to select the successor. /// /// # Errors /// @@ -269,17 +239,15 @@ impl + AsRef> CFGBuilder { signature: Signature, n_cases: usize, ) -> Result, BuildError> { - self.block_builder_exts( + self.block_builder( signature.input, vec![type_row![]; n_cases], signature.output, - signature.runtime_reqs, ) } /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs` /// and the variants of the branching Sum value specified by `sum_rows`. - /// Extension delta will be inferred. /// /// # Errors /// @@ -288,35 +256,12 @@ impl + AsRef> CFGBuilder { &mut self, sum_rows: impl IntoIterator, other_outputs: TypeRow, - ) -> Result, BuildError> { - self.entry_builder_exts(sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Return a builder for the entry [`DataflowBlock`] child graph with `outputs`, - /// the variants of the branching Sum value specified by `sum_rows`, and - /// `extension_delta` explicitly specified. ([entry_builder](Self::entry_builder) - /// may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if an entry block has already been built. - pub fn entry_builder_exts( - &mut self, - sum_rows: impl IntoIterator, - other_outputs: TypeRow, - extension_delta: impl Into, ) -> Result, BuildError> { let inputs = self .inputs .take() .ok_or(BuildError::EntryBuiltError(self.cfg_node))?; - self.any_block_builder( - inputs, - extension_delta.into(), - sum_rows, - other_outputs, - true, - ) + self.any_block_builder(inputs, sum_rows, other_outputs, true) } /// Return a builder for the entry [`DataflowBlock`] child graph with @@ -333,22 +278,6 @@ impl + AsRef> CFGBuilder { self.entry_builder(vec![type_row![]; n_cases], outputs) } - /// Return a builder for the entry [`DataflowBlock`] child graph with - /// `outputs` and a Sum of `n_cases` unit types, and explicit `extension_delta`. - /// ([simple_entry_builder](Self::simple_entry_builder) may be used to infer.) - /// - /// # Errors - /// - /// This function will return an error if there is an error adding the node. - pub fn simple_entry_builder_exts( - &mut self, - outputs: TypeRow, - n_cases: usize, - extension_delta: impl Into, - ) -> Result, BuildError> { - self.entry_builder_exts(vec![type_row![]; n_cases], outputs, extension_delta) - } - /// Returns the exit block of this [`CFGBuilder`]. pub fn exit_block(&self) -> BasicBlockID { self.exit_node.into() @@ -412,23 +341,10 @@ impl + AsRef> BlockBuilder { impl BlockBuilder { /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// Extension delta will be inferred. pub fn new( inputs: impl Into, sum_rows: impl IntoIterator, other_outputs: impl Into, - ) -> Result { - Self::new_exts(inputs, sum_rows, other_outputs, TO_BE_INFERRED) - } - - /// Initialize a [`DataflowBlock`] rooted HUGR builder. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. - pub fn new_exts( - inputs: impl Into, - sum_rows: impl IntoIterator, - other_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let inputs = inputs.into(); let sum_rows: Vec<_> = sum_rows.into_iter().collect(); @@ -437,7 +353,6 @@ impl BlockBuilder { inputs: inputs.clone(), other_outputs: other_outputs.clone(), sum_rows, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); @@ -507,11 +422,7 @@ pub(crate) mod test { ) -> Result<(), BuildError> { let usize_row: TypeRow = vec![usize_t()].into(); let sum2_variants = vec![usize_row.clone(), usize_row]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum2_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum2_variants.clone(), type_row![])?; let entry = { let [inw] = entry_b.input_wires_arr(); @@ -537,11 +448,7 @@ pub(crate) mod test { let sum_tuple_const = cfg_builder.add_constant(ops::Value::unary_unit_sum()); let sum_variants = vec![type_row![]]; - let mut entry_b = cfg_builder.entry_builder_exts( - sum_variants.clone(), - type_row![], - ExtensionSet::new(), - )?; + let mut entry_b = cfg_builder.entry_builder(sum_variants.clone(), type_row![])?; let [inw] = entry_b.input_wires_arr(); let entry = { let sum = entry_b.load_const(&sum_tuple_const); diff --git a/hugr-core/src/builder/circuit.rs b/hugr-core/src/builder/circuit.rs index 01f5e3e45..eb48b4fbc 100644 --- a/hugr-core/src/builder/circuit.rs +++ b/hugr-core/src/builder/circuit.rs @@ -245,8 +245,8 @@ mod test { use crate::builder::{Container, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::extension::{ExtensionId, ExtensionSet}; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; + use crate::extension::ExtensionId; + use crate::std_extensions::arithmetic::float_types::ConstF64; use crate::utils::test_quantum_extension::{ self, cx_gate, h_gate, measure, q_alloc, q_discard, rz_f64, }; @@ -260,10 +260,7 @@ mod test { #[test] fn simple_linear() { let build_res = build_main( - Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new(vec![qb_t(), qb_t()], vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires = f_build.input_wires().map(Some).collect(); @@ -314,11 +311,7 @@ mod test { Signature::new( vec![qb_t(), qb_t(), usize_t()], vec![qb_t(), qb_t(), bool_t()], - ) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - my_ext_name, - ])), + ), ) .unwrap(); @@ -351,38 +344,33 @@ mod test { #[test] fn ancillae() { - let build_res = build_main( - Signature::new_endo(qb_t()) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .into(), - |mut f_build| { - let mut circ = f_build.as_circuit(f_build.input_wires()); - assert_eq!(circ.n_wires(), 1); + let build_res = build_main(Signature::new_endo(qb_t()).into(), |mut f_build| { + let mut circ = f_build.as_circuit(f_build.input_wires()); + assert_eq!(circ.n_wires(), 1); - let [q0] = circ.tracked_units_arr(); - let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; - let ancilla = circ.track_wire(ancilla); + let [q0] = circ.tracked_units_arr(); + let [ancilla] = circ.append_with_outputs_arr(q_alloc(), [] as [CircuitUnit; 0])?; + let ancilla = circ.track_wire(ancilla); - assert_ne!(ancilla, 0); - assert_eq!(circ.n_wires(), 2); - assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); + assert_ne!(ancilla, 0); + assert_eq!(circ.n_wires(), 2); + assert_eq!(circ.tracked_units_arr(), [q0, ancilla]); - circ.append(cx_gate(), [q0, ancilla])?; - let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; + circ.append(cx_gate(), [q0, ancilla])?; + let [_bit] = circ.append_with_outputs_arr(measure(), [q0])?; - let q0 = circ.untrack_wire(q0)?; + let q0 = circ.untrack_wire(q0)?; - assert_eq!(circ.tracked_units_arr(), [ancilla]); + assert_eq!(circ.tracked_units_arr(), [ancilla]); - circ.append_and_consume(q_discard(), [q0])?; + circ.append_and_consume(q_discard(), [q0])?; - let outs = circ.finish(); + let outs = circ.finish(); - assert_eq!(outs.len(), 1); + assert_eq!(outs.len(), 1); - f_build.finish_with_outputs(outs) - }, - ); + f_build.finish_with_outputs(outs) + }); assert_matches!(build_res, Ok(_)); } diff --git a/hugr-core/src/builder/conditional.rs b/hugr-core/src/builder/conditional.rs index 0404abaf3..73670526c 100644 --- a/hugr-core/src/builder/conditional.rs +++ b/hugr-core/src/builder/conditional.rs @@ -1,6 +1,4 @@ -use crate::extension::TO_BE_INFERRED; use crate::hugr::views::HugrView; -use crate::ops::dataflow::DataflowOpTrait; use crate::types::{Signature, TypeRow}; use crate::ops; @@ -16,7 +14,7 @@ use super::{ }; use crate::Node; -use crate::{extension::ExtensionSet, hugr::HugrMut, Hugr}; +use crate::{hugr::HugrMut, Hugr}; use std::collections::HashSet; @@ -107,7 +105,6 @@ impl + AsRef> ConditionalBuilder { .clone() .try_into() .expect("Parent node does not have Conditional optype."); - let extension_delta = cond.signature().runtime_reqs.clone(); let inputs = cond .case_input_row(case) .ok_or(ConditionalBuildError::NotCase { conditional, case })?; @@ -118,8 +115,7 @@ impl + AsRef> ConditionalBuilder { let outputs = cond.outputs; let case_op = ops::Case { - signature: Signature::new(inputs.clone(), outputs.clone()) - .with_extension_delta(extension_delta.clone()), + signature: Signature::new(inputs.clone(), outputs.clone()), }; let case_node = // add case before any existing subsequent cases @@ -134,7 +130,7 @@ impl + AsRef> ConditionalBuilder { let dfg_builder = DFGBuilder::create_with_io( self.hugr_mut(), case_node, - Signature::new(inputs, outputs).with_extension_delta(extension_delta), + Signature::new(inputs, outputs), )?; Ok(CaseBuilder::from_dfg_builder(dfg_builder)) @@ -142,33 +138,18 @@ impl + AsRef> ConditionalBuilder { } impl HugrBuilder for ConditionalBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } } impl ConditionalBuilder { - /// Initialize a Conditional rooted HUGR builder, extension delta will be inferred. + /// Initialize a Conditional rooted HUGR builder. pub fn new( sum_rows: impl IntoIterator, other_inputs: impl Into, outputs: impl Into, - ) -> Result { - Self::new_exts(sum_rows, other_inputs, outputs, TO_BE_INFERRED) - } - - /// Initialize a Conditional rooted HUGR builder, - /// `extension_delta` explicitly specified. Alternatively, - /// [new](Self::new) may be used to infer it. - pub fn new_exts( - sum_rows: impl IntoIterator, - other_inputs: impl Into, - outputs: impl Into, - extension_delta: impl Into, ) -> Result { let sum_rows: Vec<_> = sum_rows.into_iter().collect(); let other_inputs = other_inputs.into(); @@ -181,7 +162,6 @@ impl ConditionalBuilder { sum_rows, other_inputs, outputs, - extension_delta: extension_delta.into(), }; let base = Hugr::new(op); let conditional_node = base.root(); @@ -225,12 +205,8 @@ mod test { #[test] fn basic_conditional() -> Result<(), BuildError> { - let mut conditional_b = ConditionalBuilder::new_exts( - [type_row![], type_row![]], - vec![usize_t()], - vec![usize_t()], - ExtensionSet::new(), - )?; + let mut conditional_b = + ConditionalBuilder::new([type_row![], type_row![]], vec![usize_t()], vec![usize_t()])?; n_identity(conditional_b.case_builder(1)?)?; n_identity(conditional_b.case_builder(0)?)?; diff --git a/hugr-core/src/builder/dataflow.rs b/hugr-core/src/builder/dataflow.rs index ebad52085..4e66f857f 100644 --- a/hugr-core/src/builder/dataflow.rs +++ b/hugr-core/src/builder/dataflow.rs @@ -82,10 +82,7 @@ impl DFGBuilder { } impl HugrBuilder for DFGBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.base.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.base.validate()?; Ok(self.base) } @@ -174,9 +171,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().input.clone(); - self.hugr_mut() - .replace_op(inp_node, Input { types }) - .unwrap(); + self.hugr_mut().replace_op(inp_node, Input { types }); let mut new_port = self.hugr_mut().add_ports(inp_node, Direction::Outgoing, 1); let new_port = new_port.next().unwrap(); @@ -211,9 +206,7 @@ impl FunctionBuilder { // Update the inner input node let types = new_optype.signature.body().output.clone(); - self.hugr_mut() - .replace_op(out_node, Output { types }) - .unwrap(); + self.hugr_mut().replace_op(out_node, Output { types }); let mut new_port = self.hugr_mut().add_ports(out_node, Direction::Incoming, 1); let new_port = new_port.next().unwrap(); @@ -250,15 +243,13 @@ impl FunctionBuilder { .expect("FunctionBuilder node must be a FuncDefn"); let signature = old_optype.inner_signature().into_owned(); let name = old_optype.name.clone(); - self.hugr_mut() - .replace_op( - parent, - ops::FuncDefn { - signature: f(signature).into(), - name, - }, - ) - .expect("Could not replace FunctionBuilder operation"); + self.hugr_mut().replace_op( + parent, + ops::FuncDefn { + signature: f(signature).into(), + name, + }, + ); self.hugr().get_optype(parent).as_func_defn().unwrap() } @@ -424,19 +415,15 @@ pub(crate) mod test { #[test] fn simple_inter_graph_edge() { let builder = || -> Result { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let [i1] = f_build.input_wires_arr(); let noop = f_build.add_dataflow_op(Noop(bool_t()), [i1])?; let i1 = noop.out_wire(0); - let mut nested = f_build.dfg_builder( - Signature::new(type_row![], vec![bool_t()]).with_prelude(), - [], - )?; + let mut nested = + f_build.dfg_builder(Signature::new(type_row![], vec![bool_t()]), [])?; let id = nested.add_dataflow_op(Noop(bool_t()), [i1])?; @@ -451,10 +438,8 @@ pub(crate) mod test { #[test] fn add_inputs_outputs() { let builder = || -> Result<(Hugr, Node), BuildError> { - let mut f_build = FunctionBuilder::new( - "main", - Signature::new(vec![bool_t()], vec![bool_t()]).with_prelude(), - )?; + let mut f_build = + FunctionBuilder::new("main", Signature::new(vec![bool_t()], vec![bool_t()]))?; let f_node = f_build.container_node(); let [i0] = f_build.input_wires_arr(); @@ -512,8 +497,8 @@ pub(crate) mod test { #[rstest] fn dfg_hugr(simple_dfg_hugr: Hugr) { - assert_eq!(simple_dfg_hugr.node_count(), 3); - assert_matches!(simple_dfg_hugr.root_type().tag(), OpTag::Dfg); + assert_eq!(simple_dfg_hugr.num_nodes(), 3); + assert_matches!(simple_dfg_hugr.root_optype().tag(), OpTag::Dfg); } #[test] @@ -539,7 +524,7 @@ pub(crate) mod test { }; let hugr = module_builder.finish_hugr()?; - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); assert_eq!(hugr.get_metadata(hugr.root(), "x"), None); assert_eq!(hugr.get_metadata(dfg_node, "x").cloned(), Some(json!(42))); diff --git a/hugr-core/src/builder/module.rs b/hugr-core/src/builder/module.rs index 18390926e..a77f01e5f 100644 --- a/hugr-core/src/builder/module.rs +++ b/hugr-core/src/builder/module.rs @@ -50,10 +50,7 @@ impl Default for ModuleBuilder { } impl HugrBuilder for ModuleBuilder { - fn finish_hugr(mut self) -> Result { - if cfg!(feature = "extension_inference") { - self.0.infer_extensions(false)?; - } + fn finish_hugr(self) -> Result { self.0.validate()?; Ok(self.0) } @@ -83,8 +80,7 @@ impl + AsRef> ModuleBuilder { .clone(); let body = signature.body().clone(); self.hugr_mut() - .replace_op(f_node, ops::FuncDefn { name, signature }) - .expect("Replacing a FuncDecl node with a FuncDefn should always be valid"); + .replace_op(f_node, ops::FuncDefn { name, signature }); let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?; Ok(FunctionBuilder::from_dfg_builder(db)) diff --git a/hugr-core/src/builder/tail_loop.rs b/hugr-core/src/builder/tail_loop.rs index fd6fb03b8..2baa0bcd5 100644 --- a/hugr-core/src/builder/tail_loop.rs +++ b/hugr-core/src/builder/tail_loop.rs @@ -1,4 +1,3 @@ -use crate::extension::{ExtensionSet, TO_BE_INFERRED}; use crate::ops::{self, DataflowOpTrait}; use crate::hugr::views::HugrView; @@ -72,29 +71,15 @@ impl + AsRef> TailLoopBuilder { impl TailLoopBuilder { /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// Extension delta will be inferred. pub fn new( just_inputs: impl Into, inputs_outputs: impl Into, just_outputs: impl Into, - ) -> Result { - Self::new_exts(just_inputs, inputs_outputs, just_outputs, TO_BE_INFERRED) - } - - /// Initialize new builder for a [`ops::TailLoop`] rooted HUGR. - /// `extension_delta` is explicitly specified; alternatively, [new](Self::new) - /// may be used to infer it. - pub fn new_exts( - just_inputs: impl Into, - inputs_outputs: impl Into, - just_outputs: impl Into, - extension_delta: impl Into, ) -> Result { let tail_loop = ops::TailLoop { just_inputs: just_inputs.into(), just_outputs: just_outputs.into(), rest: inputs_outputs.into(), - extension_delta: extension_delta.into(), }; let base = Hugr::new(tail_loop.clone()); let root = base.root(); @@ -109,7 +94,7 @@ mod test { use crate::extension::prelude::bool_t; use crate::{ builder::{DataflowSubContainer, HugrBuilder, ModuleBuilder, SubContainer}, - extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}, + extension::prelude::{usize_t, ConstUsize}, hugr::ValidationError, ops::Value, type_row, @@ -120,8 +105,7 @@ mod test { #[test] fn basic_loop() -> Result<(), BuildError> { let build_result: Result = { - let mut loop_b = - TailLoopBuilder::new_exts(vec![], vec![bool_t()], vec![usize_t()], PRELUDE_ID)?; + let mut loop_b = TailLoopBuilder::new(vec![], vec![bool_t()], vec![usize_t()])?; let [i1] = loop_b.input_wires_arr(); let const_wire = loop_b.add_load_value(ConstUsize::new(1)); @@ -138,10 +122,8 @@ mod test { fn loop_with_conditional() -> Result<(), BuildError> { let build_result = { let mut module_builder = ModuleBuilder::new(); - let mut fbuild = module_builder.define_function( - "main", - Signature::new(vec![bool_t()], vec![usize_t()]).with_prelude(), - )?; + let mut fbuild = module_builder + .define_function("main", Signature::new(vec![bool_t()], vec![usize_t()]))?; let _fdef = { let [b1] = fbuild.input_wires_arr(); let loop_id = { diff --git a/hugr-core/src/core.rs b/hugr-core/src/core.rs index 03e009bef..cc9da77ab 100644 --- a/hugr-core/src/core.rs +++ b/hugr-core/src/core.rs @@ -83,7 +83,7 @@ pub struct Wire(N, OutgoingPort); impl Node { /// Returns the node as a portgraph `NodeIndex`. #[inline] - pub(crate) fn pg_index(self) -> portgraph::NodeIndex { + pub(crate) fn into_portgraph(self) -> portgraph::NodeIndex { self.index } } diff --git a/hugr-core/src/envelope.rs b/hugr-core/src/envelope.rs index 35ea9c85f..24c348b78 100644 --- a/hugr-core/src/envelope.rs +++ b/hugr-core/src/envelope.rs @@ -55,7 +55,6 @@ use std::io::Write; #[allow(unused_imports)] use itertools::Itertools as _; -#[cfg(feature = "model_unstable")] use crate::import::ImportError; /// Read a HUGR envelope from a reader. @@ -197,19 +196,16 @@ pub enum EnvelopeError { source: PackageEncodingError, }, /// Error importing a HUGR from a hugr-model payload. - #[cfg(feature = "model_unstable")] ModelImport { /// The source error. source: ImportError, }, /// Error reading a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelRead { /// The source error. source: hugr_model::v0::binary::ReadError, }, /// Error writing a HUGR model payload. - #[cfg(feature = "model_unstable")] ModelWrite { /// The source error. source: hugr_model::v0::binary::WriteError, @@ -225,17 +221,9 @@ fn read_impl( match header.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => Ok(Package::from_json_reader(payload, registry)?), - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { decode_model(payload, registry, header.format) } - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - Err(EnvelopeError::FormatUnsupported { - format: header.format, - feature: Some("model_unstable"), - }) - } } } @@ -246,7 +234,6 @@ fn read_impl( /// - `extension_registry`: An extension registry with additional extensions to use when /// decoding the HUGR, if they are not already included in the package. /// - `format`: The format of the payload. -#[cfg(feature = "model_unstable")] fn decode_model( mut stream: impl BufRead, extension_registry: &ExtensionRegistry, @@ -286,22 +273,13 @@ fn write_impl( match config.format { #[allow(deprecated)] EnvelopeFormat::PackageJson => package.to_json_writer(writer)?, - #[cfg(feature = "model_unstable")] EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { encode_model(writer, package, config.format)? } - #[cfg(not(feature = "model_unstable"))] - EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => { - return Err(EnvelopeError::FormatUnsupported { - format: config.format, - feature: Some("model_unstable"), - }) - } } Ok(()) } -#[cfg(feature = "model_unstable")] fn encode_model( mut writer: impl Write, package: &Package, @@ -391,7 +369,6 @@ mod tests { //#[case::empty(Package::default())] // Not currently supported #[case::simple(simple_package())] //#[case::multi(multi_module_package())] // Not currently supported - #[cfg(feature = "model_unstable")] fn module_exts_roundtrip(#[case] package: Package) { let mut buffer = Vec::new(); let config = EnvelopeConfig { @@ -417,15 +394,7 @@ mod tests { format: EnvelopeFormat::Model, zstd: None, }; - let res = package.store(&mut buffer, config); - - match cfg!(feature = "model_unstable") { - true => res.unwrap(), - false => { - assert_matches!(res, Err(EnvelopeError::FormatUnsupported { .. })); - return; - } - } + package.store(&mut buffer, config).unwrap(); let (decoded_config, new_package) = read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap(); diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index 09ccf944c..42e04629b 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -1,4 +1,5 @@ //! Exporting HUGR graphs to their `hugr-model` representation. +use crate::hugr::internal::HugrInternals; use crate::{ extension::{ExtensionId, OpDef, SignatureFunc}, hugr::IdentList, @@ -94,7 +95,7 @@ struct Context<'a> { impl<'a> Context<'a> { pub fn new(hugr: &'a Hugr, bump: &'a Bump) -> Self { let mut module = table::Module::default(); - module.nodes.reserve(hugr.node_count()); + module.nodes.reserve(hugr.num_nodes()); let links = Links::new(hugr); Self { @@ -831,7 +832,6 @@ impl<'a> Context<'a> { ); self.make_term(table::Term::List(parts)) } - TypeArg::Extensions { .. } => self.make_term_apply("compat.ext_set", &[]), TypeArg::Variable { v } => self.export_type_arg_var(v), } } @@ -938,7 +938,6 @@ impl<'a> Context<'a> { let types = self.make_term(table::Term::List(parts)); self.make_term_apply(model::CORE_TUPLE_TYPE, &[types]) } - TypeParam::Extensions => self.make_term_apply("compat.ext_set_type", &[]), } } @@ -999,7 +998,7 @@ impl<'a> Context<'a> { let outer_hugr = std::mem::replace(&mut self.hugr, hugr); let outer_node_to_id = std::mem::take(&mut self.node_to_id); - let region = match hugr.root_type() { + let region = match hugr.root_optype() { OpType::DFG(_) => self.export_dfg(hugr.root(), model::ScopeClosure::Closed), _ => panic!("Value::Function root must be a DFG"), }; @@ -1031,7 +1030,7 @@ impl<'a> Context<'a> { } pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] { - let metadata_map = self.hugr.get_node_metadata(node); + let metadata_map = self.hugr.node_metadata_map(node); let has_order_edges = { fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { @@ -1049,13 +1048,11 @@ impl<'a> Context<'a> { .any(|(other, _)| is_relevant_node(self.hugr, other)) }; - let meta_capacity = metadata_map.map_or(0, |map| map.len()) + has_order_edges as usize; + let meta_capacity = metadata_map.len() + has_order_edges as usize; let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump); - if let Some(metadata_map) = metadata_map { - for (name, value) in metadata_map { - meta.push(self.export_json_meta(name, value)); - } + for (name, value) in metadata_map { + meta.push(self.export_json_meta(name, value)); } if has_order_edges { @@ -1176,19 +1173,15 @@ mod test { use crate::{ builder::{Dataflow, DataflowSubContainer}, extension::prelude::qb_t, - std_extensions::arithmetic::float_types, types::Signature, - utils::test_quantum_extension::{self, cx_gate, h_gate}, + utils::test_quantum_extension::{cx_gate, h_gate}, Hugr, }; #[fixture] fn test_simple_circuit() -> Hugr { crate::builder::test::build_main( - Signature::new_endo(vec![qb_t(), qb_t()]) - .with_extension_delta(test_quantum_extension::EXTENSION_ID) - .with_extension_delta(float_types::EXTENSION_ID) - .into(), + Signature::new_endo(vec![qb_t(), qb_t()]).into(), |mut f_build| { let wires: Vec<_> = f_build.input_wires().collect(); let mut linear = f_build.as_circuit(wires); diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..4300c74ad 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,12 +19,11 @@ use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; -use crate::ops::constant::{ValueName, ValueNameRef}; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::ops::{self, OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; -use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; +use crate::types::{CustomType, TypeBound, TypeName}; use crate::types::{Signature, TypeNameRef}; mod const_fold; @@ -378,6 +377,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] @@ -496,37 +496,6 @@ impl CustomConcrete for CustomType { } } -/// A constant value provided by a extension. -/// Must be an instance of a type available to the extension. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct ExtensionValue { - extension: ExtensionId, - name: ValueName, - typed_value: ops::Value, -} - -impl ExtensionValue { - /// Returns a reference to the typed value of this [`ExtensionValue`]. - pub fn typed_value(&self) -> &ops::Value { - &self.typed_value - } - - /// Returns a mutable reference to the typed value of this [`ExtensionValue`]. - pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value { - &mut self.typed_value - } - - /// Returns a reference to the name of this [`ExtensionValue`]. - pub fn name(&self) -> &str { - self.name.as_str() - } - - /// Returns a reference to the extension this [`ExtensionValue`] belongs to. - pub fn extension(&self) -> &ExtensionId { - &self.extension - } -} - /// A unique identifier for a extension. /// /// The actual [`Extension`] is stored externally. @@ -578,12 +547,8 @@ pub struct Extension { pub version: Version, /// Unique identifier for the extension. pub name: ExtensionId, - /// Runtime dependencies this extension has on other extensions. - pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, - /// Static values defined by this extension. - values: BTreeMap, /// Operation declarations with serializable definitions. // Note: serde will serialize this because we configure with `features=["rc"]`. // That will clone anything that has multiple references, but each @@ -605,9 +570,7 @@ impl Extension { Self { name, version, - runtime_reqs: Default::default(), types: Default::default(), - values: Default::default(), operations: Default::default(), } } @@ -663,12 +626,6 @@ impl Extension { } } - /// Extend the runtime requirements of this extension with another set of extensions. - pub fn add_requirements(&mut self, runtime_reqs: impl Into) { - let reqs = mem::take(&mut self.runtime_reqs); - self.runtime_reqs = reqs.union(runtime_reqs.into()); - } - /// Allows read-only access to the operations in this Extension pub fn get_op(&self, name: &OpNameRef) -> Option<&Arc> { self.operations.get(name) @@ -679,11 +636,6 @@ impl Extension { self.types.get(type_name) } - /// Allows read-only access to the values in this Extension - pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> { - self.values.get(value_name) - } - /// Returns the name of the extension. pub fn name(&self) -> &ExtensionId { &self.name @@ -704,25 +656,6 @@ impl Extension { self.types.iter() } - /// Add a named static value to the extension. - pub fn add_value( - &mut self, - name: impl Into, - typed_value: ops::Value, - ) -> Result<&mut ExtensionValue, ExtensionBuildError> { - let extension_value = ExtensionValue { - extension: self.name.clone(), - name: name.into(), - typed_value, - }; - match self.values.entry(extension_value.name.clone()) { - btree_map::Entry::Occupied(_) => { - Err(ExtensionBuildError::ValueExists(extension_value.name)) - } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), - } - } - /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension. pub fn instantiate_extension_op( &self, @@ -783,9 +716,6 @@ pub enum ExtensionBuildError { /// Existing [`TypeDef`] #[error("Extension already has an type called {0}.")] TypeDefExists(TypeName), - /// Existing [`ExtensionValue`] - #[error("Extension already has an extension value called {0}.")] - ValueExists(ValueName), } /// A set of extensions identified by their unique [`ExtensionId`]. @@ -795,14 +725,6 @@ pub enum ExtensionBuildError { #[display("[{}]", _0.iter().join(", "))] pub struct ExtensionSet(BTreeSet); -/// A special ExtensionId which indicates that the delta of a non-Function -/// container node should be computed by extension inference. -/// -/// See [`infer_extensions`] which lists the container nodes to which this can be applied. -/// -/// [`infer_extensions`]: crate::hugr::Hugr::infer_extensions -pub const TO_BE_INFERRED: ExtensionId = ExtensionId::new_unchecked(".TO_BE_INFERRED"); - impl ExtensionSet { /// Creates a new empty extension set. pub const fn new() -> Self { @@ -814,14 +736,6 @@ impl ExtensionSet { self.0.insert(extension.clone()); } - /// Adds a type var (which must have been declared as a [TypeParam::Extensions]) to this set - pub fn insert_type_var(&mut self, idx: usize) { - // Represent type vars as string representation of variable index. - // This is not a legal IdentList or ExtensionId so should not conflict. - self.0 - .insert(ExtensionId::new_unchecked(idx.to_string().as_str())); - } - /// Returns `true` if the set contains the given extension. pub fn contains(&self, extension: &ExtensionId) -> bool { self.0.contains(extension) @@ -844,14 +758,6 @@ impl ExtensionSet { set } - /// An ExtensionSet containing a single type variable - /// (which must have been declared as a [TypeParam::Extensions]) - pub fn type_var(idx: usize) -> Self { - let mut set = Self::new(); - set.insert_type_var(idx); - set - } - /// Returns the union of two extension sets. pub fn union(mut self, other: Self) -> Self { self.0.extend(other.0); @@ -882,22 +788,6 @@ impl ExtensionSet { pub fn is_empty(&self) -> bool { self.0.is_empty() } - - pub(crate) fn validate(&self, params: &[TypeParam]) -> Result<(), SignatureError> { - self.iter() - .filter_map(as_typevar) - .try_for_each(|var_idx| check_typevar_decl(params, var_idx, &TypeParam::Extensions)) - } - - pub(crate) fn substitute(&self, t: &Substitution) -> Self { - Self::from_iter(self.0.iter().flat_map(|e| match as_typevar(e) { - None => vec![e.clone()], - Some(i) => match t.apply_var(i, &TypeParam::Extensions) { - TypeArg::Extensions{es} => es.iter().cloned().collect::>(), - _ => panic!("value for type var was not extension set - type scheme should be validated first"), - }, - })) - } } impl From for ExtensionSet { @@ -924,16 +814,6 @@ impl<'a> IntoIterator for &'a ExtensionSet { } } -fn as_typevar(e: &ExtensionId) -> Option { - // Type variables are represented as radix-10 numbers, which are illegal - // as standard ExtensionIds. Hence if an ExtensionId starts with a digit, - // we assume it must be a type variable, and fail fast if it isn't. - match e.chars().next() { - Some(c) if c.is_ascii_digit() => Some(str::parse(e).unwrap()), - _ => None, - } -} - impl FromIterator for ExtensionSet { fn from_iter>(iter: I) -> Self { Self(BTreeSet::from_iter(iter)) @@ -1028,16 +908,8 @@ pub mod test { type Strategy = BoxedStrategy; fn arbitrary_with(_: Self::Parameters) -> Self::Strategy { - ( - hash_set(0..10usize, 0..3), - hash_set(any::(), 0..3), - ) - .prop_map(|(vars, extensions)| { - ExtensionSet::union_over( - std::iter::once(extensions.into_iter().collect::()) - .chain(vars.into_iter().map(ExtensionSet::type_var)), - ) - }) + hash_set(any::(), 0..3) + .prop_map(|extensions| extensions.into_iter().collect::()) .boxed() } } diff --git a/hugr-core/src/extension/declarative.rs b/hugr-core/src/extension/declarative.rs index 64092981f..14995db27 100644 --- a/hugr-core/src/extension/declarative.rs +++ b/hugr-core/src/extension/declarative.rs @@ -149,9 +149,14 @@ impl ExtensionDeclaration { /// Create an [`Extension`] from this declaration. pub fn make_extension( &self, - imports: &ExtensionSet, + _imports: &ExtensionSet, ctx: DeclarationContext<'_>, ) -> Result, ExtensionDeclarationError> { + // TODO: The imports were previously used as runtime extension + // requirements for the constructed extension. Now that runtime + // extension requirements are removed, they are no longer recorded + // anywhere in the `Extension`. + Extension::try_new_arc( self.name.clone(), // TODO: Get the version as a parameter. @@ -164,7 +169,6 @@ impl ExtensionDeclaration { for o in &self.operations { o.register(ext, ctx, extension_ref)?; } - ext.add_requirements(imports.clone()); Ok(()) }, diff --git a/hugr-core/src/extension/declarative/signature.rs b/hugr-core/src/extension/declarative/signature.rs index b84d56853..e2300956b 100644 --- a/hugr-core/src/extension/declarative/signature.rs +++ b/hugr-core/src/extension/declarative/signature.rs @@ -12,7 +12,7 @@ use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use crate::extension::prelude::PRELUDE_ID; -use crate::extension::{ExtensionSet, SignatureFunc, TypeDef}; +use crate::extension::{SignatureFunc, TypeDef}; use crate::types::type_param::TypeParam; use crate::types::{CustomType, FuncValueType, PolyFuncTypeRV, Type, TypeRowRV}; use crate::Extension; @@ -26,10 +26,6 @@ pub(super) struct SignatureDeclaration { inputs: Vec, /// The outputs of the operation. outputs: Vec, - /// A set of extensions invoked while running this operation. - #[serde(default)] - #[serde(skip_serializing_if = "crate::utils::is_default")] - extensions: ExtensionSet, } impl SignatureDeclaration { @@ -53,7 +49,6 @@ impl SignatureDeclaration { let body = FuncValueType { input: make_type_row(&self.inputs)?, output: make_type_row(&self.outputs)?, - runtime_reqs: self.extensions.clone(), }; let poly_func = PolyFuncTypeRV::new(op_params, body); diff --git a/hugr-core/src/extension/op_def.rs b/hugr-core/src/extension/op_def.rs index d5c9a5b5d..48eef663f 100644 --- a/hugr-core/src/extension/op_def.rs +++ b/hugr-core/src/extension/op_def.rs @@ -244,11 +244,7 @@ impl SignatureFunc { // TODO raise warning: https://github.com/CQCL/hugr/issues/1432 SignatureFunc::MissingValidateFunc(ts) => (ts, args), }; - let mut res = pf.instantiate(args)?; - - // Automatically add the extensions where the operation is defined to - // the runtime requirements of the op. - res.runtime_reqs.insert(def.extension.clone()); + let res = pf.instantiate(args)?; // If there are any row variables left, this will fail with an error: res.try_into() @@ -722,8 +718,7 @@ pub(super) mod test { Ok(Signature::new( vec![usize_t(); 3], vec![Type::new_tuple(vec![usize_t(); 3])] - ) - .with_extension_delta(EXT_ID)) + )) ); assert_eq!(def.validate_args(&args, &[]), Ok(())); @@ -733,10 +728,10 @@ pub(super) mod test { let args = [TypeArg::BoundedNat { n: 3 }, tyvar.clone().into()]; assert_eq!( def.compute_signature(&args), - Ok( - Signature::new(tyvars.clone(), vec![Type::new_tuple(tyvars)]) - .with_extension_delta(EXT_ID) - ) + Ok(Signature::new( + tyvars.clone(), + vec![Type::new_tuple(tyvars)] + )) ); def.validate_args(&args, &[TypeBound::Copyable.into()]) .unwrap(); @@ -787,14 +782,11 @@ pub(super) mod test { ), extension_ref, )?; - let tv = Type::new_var_use(1, TypeBound::Copyable); + let tv = Type::new_var_use(0, TypeBound::Copyable); let args = [TypeArg::Type { ty: tv.clone() }]; - let decls = [TypeParam::Extensions, TypeBound::Copyable.into()]; + let decls = [TypeBound::Copyable.into()]; def.validate_args(&args, &decls).unwrap(); - assert_eq!( - def.compute_signature(&args), - Ok(Signature::new_endo(tv).with_extension_delta(EXT_ID)) - ); + assert_eq!(def.compute_signature(&args), Ok(Signature::new_endo(tv))); // But not with an external row variable let arg: TypeArg = TypeRV::new_row_var_use(0, TypeBound::Copyable).into(); assert_eq!( @@ -811,36 +803,6 @@ pub(super) mod test { Ok(()) } - #[test] - fn instantiate_extension_delta() -> Result<(), Box> { - use crate::extension::prelude::bool_t; - - let _ext = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![TypeParam::Extensions]; - let db_set = ExtensionSet::type_var(0); - let fun_ty = Signature::new_endo(bool_t()).with_extension_delta(db_set); - - let def = ext.add_op( - "SimpleOp".into(), - "".into(), - PolyFuncTypeRV::new(params.clone(), fun_ty), - extension_ref, - )?; - - // Concrete extension set - let es = ExtensionSet::singleton(EXT_ID); - let exp_fun_ty = Signature::new_endo(bool_t()).with_extension_delta(es.clone()); - let args = [TypeArg::Extensions { es }]; - - def.validate_args(&args, ¶ms).unwrap(); - assert_eq!(def.compute_signature(&args), Ok(exp_fun_ty)); - - Ok(()) - })?; - - Ok(()) - } - mod proptest { use std::sync::Weak; diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index f88a84a0d..b1e78baf8 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -11,7 +11,7 @@ use crate::extension::simple_op::{ try_from_name, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; use crate::extension::{ - ConstFold, ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDefBound, + ConstFold, ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDefBound, }; use crate::ops::constant::{CustomCheckFailure, CustomConst, ValueName}; use crate::ops::OpName; @@ -245,10 +245,6 @@ impl CustomConst for ConstString { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { string_type() } @@ -438,10 +434,6 @@ impl CustomConst for ConstUsize { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } - fn get_type(&self) -> Type { usize_t() } @@ -495,9 +487,6 @@ impl CustomConst for ConstError { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { error_type() } @@ -555,9 +544,6 @@ impl CustomConst for ConstExternalSymbol { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(PRELUDE_ID) - } fn get_type(&self) -> Type { self.typ.clone() } @@ -1068,7 +1054,7 @@ mod test { let optype: OpType = op.clone().into(); assert_eq!( optype.dataflow_signature().unwrap().as_ref(), - &Signature::new_endo(type_row![Type::UNIT]).with_prelude() + &Signature::new_endo(type_row![Type::UNIT]) ); let new_op = Barrier::from_extension_op(optype.as_extension_op().unwrap()).unwrap(); @@ -1121,10 +1107,6 @@ mod test { assert!(error_val.validate().is_ok()); - assert_eq!( - error_val.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(error_val.equal_consts(&ConstError::new(2, "my message"))); assert!(!error_val.equal_consts(&ConstError::new(3, "my message"))); @@ -1181,10 +1163,6 @@ mod test { let string_const: ConstString = ConstString::new("Lorem ipsum".into()); assert_eq!(string_const.name(), "ConstString(\"Lorem ipsum\")"); assert!(string_const.validate().is_ok()); - assert_eq!( - string_const.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(string_const.equal_consts(&ConstString::new("Lorem ipsum".into()))); assert!(!string_const.equal_consts(&ConstString::new("Lorem ispum".into()))); } @@ -1206,10 +1184,6 @@ mod test { assert_eq!(subject.get_type(), Type::UNIT); assert_eq!(subject.name(), "@foo"); assert!(subject.validate().is_ok()); - assert_eq!( - subject.extension_reqs(), - ExtensionSet::singleton(PRELUDE_ID) - ); assert!(subject.equal_consts(&ConstExternalSymbol::new("foo", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("bar", Type::UNIT, false))); assert!(!subject.equal_consts(&ConstExternalSymbol::new("foo", string_type(), false))); diff --git a/hugr-core/src/extension/prelude/unwrap_builder.rs b/hugr-core/src/extension/prelude/unwrap_builder.rs index 06b4e3939..3817d65c8 100644 --- a/hugr-core/src/extension/prelude/unwrap_builder.rs +++ b/hugr-core/src/extension/prelude/unwrap_builder.rs @@ -111,10 +111,8 @@ mod tests { #[test] fn test_build_unwrap() { - let mut builder = DFGBuilder::new( - Signature::new(Type::from(option_type(bool_t())), bool_t()).with_prelude(), - ) - .unwrap(); + let mut builder = + DFGBuilder::new(Signature::new(Type::from(option_type(bool_t())), bool_t())).unwrap(); let [opt] = builder.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution.rs b/hugr-core/src/extension/resolution.rs index 90eae9422..a08cbfb38 100644 --- a/hugr-core/src/extension/resolution.rs +++ b/hugr-core/src/extension/resolution.rs @@ -9,10 +9,6 @@ //! HUGR nodes and wire types. This is computed from the union of all extension //! required across the HUGR. //! -//! This is distinct from _runtime_ extension requirements, which are defined -//! more granularly in each function signature by the `runtime_reqs` -//! field. See the `extension_inference` feature and related modules for that. -//! //! Note: These procedures are only temporary until `hugr-model` is stabilized. //! Once that happens, hugrs will no longer be directly deserialized using serde //! but instead will be created by the methods in `crate::import`. As these diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 61adc1dea..05c0faf69 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::{resolve_signature_exts, resolve_value_exts}; +use super::types_mut::resolve_signature_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -59,14 +59,7 @@ impl Extension { for type_def in self.types.values_mut() { resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?; } - for val in self.values.values_mut() { - resolve_value_exts( - None, - val.typed_value_mut(), - extensions, - &mut used_extensions, - )?; - } + let ops = mem::take(&mut self.operations); for (op_id, mut op_def) in ops { // TODO: We should be able to clone the definition if needed by using `make_mut`, diff --git a/hugr-core/src/extension/resolution/test.rs b/hugr-core/src/extension/resolution/test.rs index 19373b04c..f3ae229ec 100644 --- a/hugr-core/src/extension/resolution/test.rs +++ b/hugr-core/src/extension/resolution/test.rs @@ -11,7 +11,7 @@ use crate::builder::{ Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; -use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, usize_custom_t, usize_t, ConstUsize}; use crate::extension::resolution::WeakExtensionRegistry; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionCollectionError, @@ -28,7 +28,7 @@ use crate::std_extensions::arithmetic::int_types::{self, int_type}; use crate::std_extensions::collections::list::ListValue; use crate::types::type_param::TypeParam; use crate::types::{PolyFuncType, Signature, Type, TypeArg, TypeBound}; -use crate::{std_extensions, type_row, Extension, Hugr, HugrView}; +use crate::{type_row, Extension, Hugr, HugrView}; #[rstest] #[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())] @@ -158,17 +158,7 @@ fn check_extension_resolution(mut hugr: Hugr) { /// Build a small hugr using the float types extension and check that the extensions are resolved. #[rstest] fn resolve_hugr_extensions_simple() { - let mut build = DFGBuilder::new( - Signature::new(vec![], vec![float64_type()]).with_extension_delta( - [ - PRELUDE_ID.to_owned(), - std_extensions::arithmetic::float_types::EXTENSION_ID.to_owned(), - ] - .into_iter() - .collect::(), - ), - ) - .unwrap(); + let mut build = DFGBuilder::new(Signature::new(vec![], vec![float64_type()])).unwrap(); // A constant op using a non-prelude extension. let f_const = build.add_load_const(Value::extension(ConstF64::new(f64::consts::PI))); @@ -218,7 +208,7 @@ fn resolve_hugr_extensions() { let (ext_b, op_b) = make_extension("dummy.b", "op_b"); let (ext_c, op_c) = make_extension("dummy.c", "op_c"); let (ext_d, op_d) = make_extension("dummy.d", "op_d"); - let (ext_e, op_e) = make_extension("dummy.e", "op_e"); + let (_ext_e, op_e) = make_extension("dummy.e", "op_e"); let mut module = ModuleBuilder::new(); @@ -234,18 +224,7 @@ fn resolve_hugr_extensions() { let mut func = module .define_function( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ - ext_a.name(), - ext_b.name(), - ext_c.name(), - ext_d.name(), - ext_e.name(), - ] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [func_i0, func_i1] = func.input_wires_arr(); @@ -368,11 +347,7 @@ fn resolve_call() { let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap(); let mut func = module - .define_function( - "caller_fn", - Signature::new(vec![], vec![bool_t()]) - .with_extension_delta(ExtensionSet::from_iter(expected_exts.clone())), - ) + .define_function("caller_fn", Signature::new(vec![], vec![bool_t()])) .unwrap(); let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap(); let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap(); @@ -390,15 +365,10 @@ fn resolve_call() { /// Fail when collecting extensions but the weak pointers are not resolved. #[rstest] fn dropped_weak_extensions() { - let (ext_a, op_a) = make_extension("dummy.a", "op_a"); + let (_ext_a, op_a) = make_extension("dummy.a", "op_a"); let mut func = FunctionBuilder::new( "dummy_fn", - Signature::new(vec![float64_type(), bool_t()], vec![]).with_extension_delta( - [ext_a.name()] - .into_iter() - .cloned() - .collect::(), - ), + Signature::new(vec![float64_type(), bool_t()], vec![]), ) .unwrap(); let [_func_i0, func_i1] = func.input_wires_arr(); diff --git a/hugr-core/src/extension/resolution/types.rs b/hugr-core/src/extension/resolution/types.rs index 6094f0aee..28bd6a12b 100644 --- a/hugr-core/src/extension/resolution/types.rs +++ b/hugr-core/src/extension/resolution/types.rs @@ -131,8 +131,6 @@ pub(crate) fn collect_signature_exts( used_extensions: &mut WeakExtensionRegistry, missing_extensions: &mut ExtensionSet, ) { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that we do not be require to be defined. collect_type_row_exts(&signature.input, used_extensions, missing_extensions); collect_type_row_exts(&signature.output, used_extensions, missing_extensions); } diff --git a/hugr-core/src/extension/resolution/types_mut.rs b/hugr-core/src/extension/resolution/types_mut.rs index d70d6b861..af5803eff 100644 --- a/hugr-core/src/extension/resolution/types_mut.rs +++ b/hugr-core/src/extension/resolution/types_mut.rs @@ -124,8 +124,6 @@ pub(super) fn resolve_signature_exts( extensions: &WeakExtensionRegistry, used_extensions: &mut WeakExtensionRegistry, ) -> Result<(), ExtensionResolutionError> { - // Note that we do not include the signature's `runtime_reqs` here, as those refer - // to _runtime_ requirements that may not be currently present. resolve_type_row_exts(node, &mut signature.input, extensions, used_extensions)?; resolve_type_row_exts(node, &mut signature.output, extensions, used_extensions)?; Ok(()) diff --git a/hugr-core/src/hugr.rs b/hugr-core/src/hugr.rs index 2789ae056..16152b298 100644 --- a/hugr-core/src/hugr.rs +++ b/hugr-core/src/hugr.rs @@ -4,7 +4,7 @@ pub mod hugrmut; pub(crate) mod ident; pub mod internal; -pub mod rewrite; +pub mod patch; pub mod serialize; pub mod validate; pub mod views; @@ -17,20 +17,20 @@ pub(crate) use self::hugrmut::HugrMut; pub use self::validate::ValidationError; pub use ident::{IdentList, InvalidIdentifier}; -pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError}; +pub use patch::{Patch, SimpleReplacement, SimpleReplacementError}; use portgraph::multiportgraph::MultiPortGraph; use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap}; use thiserror::Error; -pub use self::views::{HugrView, RootTagged}; +pub use self::views::HugrView; use crate::core::NodeIndex; use crate::extension::resolution::{ resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED}; -use crate::ops::{OpTag, OpTrait}; +use crate::extension::{ExtensionRegistry, ExtensionSet}; +use crate::ops::OpTag; pub use crate::ops::{OpType, DEFAULT_OPTYPE}; use crate::{Direction, Node}; @@ -89,13 +89,29 @@ impl Hugr { Self::with_capacity(root_node.into(), 0, 0) } + /// Create a new Hugr, with a single root node and preallocated capacity. + pub fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { + let mut graph = MultiPortGraph::with_capacity(nodes, ports); + let hierarchy = Hierarchy::new(); + let mut op_types = UnmanagedDenseMap::with_capacity(nodes); + let root = graph.add_node(root_node.input_count(), root_node.output_count()); + let extensions = root_node.used_extensions(); + op_types[root] = root_node; + + Self { + graph, + hierarchy, + root, + op_types, + metadata: UnmanagedDenseMap::with_capacity(nodes), + extensions: extensions.unwrap_or_default(), + } + } + /// Load a Hugr from a json reader. /// /// Validates the Hugr against the provided extension registry, ensuring all /// operations are resolved. - /// - /// If the feature `extension_inference` is enabled, we will ensure every function - /// correctly specifies the extensions required by its contained ops. pub fn load_json( reader: impl Read, extension_registry: &ExtensionRegistry, @@ -103,87 +119,11 @@ impl Hugr { let mut hugr: Hugr = serde_json::from_reader(reader)?; hugr.resolve_extension_defs(extension_registry)?; - hugr.validate_no_extensions()?; - - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - hugr.validate_extensions()?; - } + hugr.validate()?; Ok(hugr) } - /// Infers an extension-delta for any non-function container node - /// whose current [extension_delta] contains [TO_BE_INFERRED]. The inferred delta - /// will be the smallest delta compatible with its children and that includes any - /// other [ExtensionId]s in the current delta. - /// - /// If `remove` is true, for such container nodes *without* [TO_BE_INFERRED], - /// ExtensionIds are removed from the delta if they are *not* used by any child node. - /// - /// The non-function container nodes are: - /// [Case], [CFG], [Conditional], [DataflowBlock], [DFG], [TailLoop] - /// - /// [Case]: crate::ops::Case - /// [CFG]: crate::ops::CFG - /// [Conditional]: crate::ops::Conditional - /// [DataflowBlock]: crate::ops::DataflowBlock - /// [DFG]: crate::ops::DFG - /// [TailLoop]: crate::ops::TailLoop - /// [extension_delta]: crate::ops::OpType::extension_delta - /// [ExtensionId]: crate::extension::ExtensionId - pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> { - fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> { - match optype { - OpType::DFG(dfg) => Some(&mut dfg.signature.runtime_reqs), - OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta), - OpType::TailLoop(tl) => Some(&mut tl.extension_delta), - OpType::CFG(cfg) => Some(&mut cfg.signature.runtime_reqs), - OpType::Conditional(c) => Some(&mut c.extension_delta), - OpType::Case(c) => Some(&mut c.signature.runtime_reqs), - //OpType::Lift(_) // Not ATM: only a single element, and we expect Lift to be removed - //OpType::FuncDefn(_) // Not at present due to the possibility of recursion - _ => None, - } - } - fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result { - let mut child_sets = h - .children(node) - .collect::>() // Avoid borrowing h over recursive call - .into_iter() - .map(|ch| Ok((ch, infer(h, ch, remove)?))) - .collect::, _>>()?; - - let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else { - return Ok(h.get_optype(node).extension_delta()); - }; - if es.contains(&TO_BE_INFERRED) { - // Do not remove anything from current delta - any other elements are a lower bound - child_sets.push((node, es.clone())); // "child_sets" now misnamed but we discard fst - } else if remove { - child_sets.iter().try_for_each(|(ch, ch_exts)| { - if !es.is_superset(ch_exts) { - return Err(ExtensionError { - parent: node, - parent_extensions: es.clone(), - child: *ch, - child_extensions: ch_exts.clone(), - }); - } - Ok(()) - })?; - } else { - return Ok(es.clone()); // Can't neither add nor remove, so nothing to do - } - let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e)); - *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged); - - Ok(es.clone()) - } - infer(self, self.root(), remove)?; - Ok(()) - } - /// Given a Hugr that has been deserialized, collect all extensions used to /// define the HUGR while resolving all [`OpType::OpaqueOp`] operations into /// [`OpType::ExtensionOp`]s and updating the extension pointer in all @@ -195,11 +135,6 @@ impl Hugr { /// to define the HUGR nodes and wire types. This is computed from the union /// of all extension required across the HUGR. /// - /// This is distinct from _runtime_ extension requirements computed in - /// [`Hugr::infer_extensions`], which are computed more granularly in each - /// function signature by the `runtime_reqs` field and define the set - /// of capabilities required by the runtime to execute each function. - /// /// Updates the internal extension registry with the extensions used in the /// definition. /// @@ -260,31 +195,6 @@ impl Hugr { /// Internal API for HUGRs, not intended for use by users. impl Hugr { - /// Create a new Hugr, with a single root node and preallocated capacity. - pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self { - let mut graph = MultiPortGraph::with_capacity(nodes, ports); - let hierarchy = Hierarchy::new(); - let mut op_types = UnmanagedDenseMap::with_capacity(nodes); - let root = graph.add_node(root_node.input_count(), root_node.output_count()); - let extensions = root_node.used_extensions(); - op_types[root] = root_node; - - Self { - graph, - hierarchy, - root, - op_types, - metadata: UnmanagedDenseMap::with_capacity(nodes), - extensions: extensions.unwrap_or_default(), - } - } - - /// Set the root node of the hugr. - pub(crate) fn set_root(&mut self, root: Node) { - self.hierarchy.detach(self.root); - self.root = root.pg_index(); - } - /// Add a node to the graph. pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node { let node = self @@ -322,7 +232,7 @@ impl Hugr { /// preserve the indices. pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) { // Generate the ordered list of nodes - let mut ordered = Vec::with_capacity(self.node_count()); + let mut ordered = Vec::with_capacity(self.num_nodes()); let root = self.root(); ordered.extend(self.as_mut().canonical_order(root)); @@ -339,8 +249,8 @@ impl Hugr { let target: Node = portgraph::NodeIndex::new(position).into(); if target != source { - let pg_target = target.pg_index(); - let pg_source = source.pg_index(); + let pg_target = target.into_portgraph(); + let pg_source = source.into_portgraph(); self.graph.swap_nodes(pg_target, pg_source); self.op_types.swap(pg_target, pg_source); self.hierarchy.swap_nodes(pg_target, pg_source); @@ -367,13 +277,10 @@ pub struct ExtensionError { } /// Errors that can occur while manipulating a Hugr. -/// -/// TODO: Better descriptions, not just re-exporting portgraph errors. #[derive(Debug, Clone, PartialEq, Eq, Error)] #[non_exhaustive] pub enum HugrError { /// The node was not of the required [OpTag] - /// (e.g. to conform to the [RootTagged::RootHandle] of a [HugrView]) #[error("Invalid tag: required a tag in {required} but found {actual}")] #[allow(missing_docs)] InvalidTag { required: OpTag, actual: OpTag }, @@ -402,73 +309,13 @@ pub enum LoadHugrError { #[cfg(test)] mod test { - use std::sync::Arc; use std::{fs::File, io::BufReader}; - use super::internal::HugrMutInternals; - #[cfg(feature = "extension_inference")] - use super::ValidationError; - use super::{ExtensionError, Hugr, HugrMut, HugrView, Node}; - use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED}; - use crate::ops::{ExtensionOp, OpName}; - use crate::types::type_param::TypeParam; - use crate::types::{ - FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV, TypeRow, - }; - - use crate::{const_extension_ids, ops, test_file, type_row, Extension}; - use cool_asserts::assert_matches; - use lazy_static::lazy_static; - use rstest::rstest; - - const_extension_ids! { - pub(crate) const LIFT_EXT_ID: ExtensionId = "LIFT_EXT_ID"; - } - lazy_static! { - /// Tests only extension holding an Op that can add arbitrary extensions to a row. - pub(crate) static ref LIFT_EXT: Arc = { - Extension::new_arc( - LIFT_EXT_ID, - hugr::extension::Version::new(0, 0, 0), - |ext, extension_ref| { - ext.add_op( - OpName::new_inline("Lift"), - "".into(), - PolyFuncTypeRV::new( - vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)], - FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any)) - .with_extension_delta(ExtensionSet::type_var(0)), - ), - extension_ref, - ) - .unwrap(); - }, - ) - }; - } + use super::{Hugr, HugrView}; + use crate::extension::PRELUDE_REGISTRY; - pub(crate) fn lift_op( - type_row: impl Into, - extensions: impl Into, - ) -> ExtensionOp { - LIFT_EXT - .instantiate_extension_op( - "Lift", - [ - TypeArg::Extensions { - es: extensions.into(), - }, - TypeArg::Sequence { - elems: type_row - .into() - .iter() - .map(|t| TypeArg::Type { ty: t.clone() }) - .collect(), - }, - ], - ) - .unwrap() - } + use crate::test_file; + use cool_asserts::assert_matches; #[test] fn impls_send_and_sync() { @@ -531,164 +378,4 @@ mod test { ); assert_matches!(&hugr, Ok(_)); } - - const_extension_ids! { - const XA: ExtensionId = "EXT_A"; - const XB: ExtensionId = "EXT_B"; - } - - #[rstest] - #[case([], XA.into())] - #[case([XA], XA.into())] - #[case([XB], ExtensionSet::from_iter([XA, XB]))] - - fn infer_single_delta( - #[case] parent: impl IntoIterator, - #[values(true, false)] remove: bool, // makes no difference when inferring - #[case] result: ExtensionSet, - ) { - let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into()); - let (mut h, _) = build_ext_dfg(parent); - h.infer_extensions(remove).unwrap(); - assert_eq!(h, build_ext_dfg(result.union(LIFT_EXT_ID.into())).0); - } - - #[test] - fn infer_removes_from_delta() { - let parent = ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]); - let mut h = build_ext_dfg(parent.clone()).0; - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - h.infer_extensions(true).unwrap(); - assert_eq!( - h, - build_ext_dfg(ExtensionSet::from_iter([XA, LIFT_EXT_ID])).0 - ); - } - - #[test] - fn infer_bad_remove() { - let (mut h, mid) = build_ext_dfg(XB.into()); - let backup = h.clone(); - h.infer_extensions(false).unwrap(); - assert_eq!(h, backup); // did nothing - let val_res = h.validate(); - let expected_err = ExtensionError { - parent: h.root(), - parent_extensions: XB.into(), - child: mid, - child_extensions: ExtensionSet::from_iter([XA, LIFT_EXT_ID]), - }; - #[cfg(feature = "extension_inference")] - assert_eq!( - val_res, - Err(ValidationError::ExtensionError(expected_err.clone())) - ); - #[cfg(not(feature = "extension_inference"))] - assert!(val_res.is_ok()); - - let inf_res = h.infer_extensions(true); - assert_eq!(inf_res, Err(expected_err)); - } - - fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let mut h = Hugr::new(ops::DFG { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent.clone()), - }); - let root = h.root(); - let mid = add_inliftout(&mut h, root, ty); - (h, mid) - } - - fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node { - let inp = h.add_node_with_parent( - p, - ops::Input { - types: ty.clone().into(), - }, - ); - let out = h.add_node_with_parent( - p, - ops::Output { - types: ty.clone().into(), - }, - ); - let mid = h.add_node_with_parent(p, lift_op(ty, XA)); - h.connect(inp, 0, mid, 0); - h.connect(mid, 0, out, 0); - mid - } - - #[rstest] - // Base case success: delta inferred for parent equals grandparent. - #[case([XA], [TO_BE_INFERRED], true, [XA])] - // Success: delta inferred for parent is subset of grandparent - #[case([XA, XB], [TO_BE_INFERRED], true, [XA])] - // Base case failure: infers [XA] for parent but grandparent has disjoint set - #[case([XB], [TO_BE_INFERRED], false, [XA])] - // Failure: as previous, but extra "lower bound" on parent that has no effect - #[case([XB], [XA, TO_BE_INFERRED], false, [XA])] - // Failure: grandparent ok wrt. child but parent specifies extra lower-bound XB - #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])] - // Success: grandparent includes extra XB required for parent's "lower bound" - #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])] - // Success: grandparent is also inferred so can include 'extra' XB from parent - #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])] - // No inference: extraneous XB in parent is removed so all become [XA]. - #[case([XA], [XA, XB], true, [XA])] - fn infer_three_generations( - #[case] grandparent: impl IntoIterator, - #[case] parent: impl IntoIterator, - #[case] success: bool, - #[case] result: impl IntoIterator, - ) { - let ty = Type::new_function(Signature::new_endo(type_row![])); - let grandparent = ExtensionSet::from_iter(grandparent).union(LIFT_EXT_ID.into()); - let parent = ExtensionSet::from_iter(parent).union(LIFT_EXT_ID.into()); - let result = ExtensionSet::from_iter(result).union(LIFT_EXT_ID.into()); - let root_ty = ops::Conditional { - sum_rows: vec![type_row![]], - other_inputs: ty.clone().into(), - outputs: ty.clone().into(), - extension_delta: grandparent.clone(), - }; - let mut h = Hugr::new(root_ty.clone()); - let p = h.add_node_with_parent( - h.root(), - ops::Case { - signature: Signature::new_endo(ty.clone()).with_extension_delta(parent), - }, - ); - add_inliftout(&mut h, p, ty.clone()); - assert!(h.validate_extensions().is_err()); - let backup = h.clone(); - let inf_res = h.infer_extensions(true); - if success { - assert!(inf_res.is_ok()); - let expected_p = ops::Case { - signature: Signature::new_endo(ty).with_extension_delta(result.clone()), - }; - let mut expected = backup; - expected.replace_op(p, expected_p).unwrap(); - let expected_gp = ops::Conditional { - extension_delta: result, - ..root_ty - }; - expected.replace_op(h.root(), expected_gp).unwrap(); - - assert_eq!(h, expected); - } else { - assert_eq!( - inf_res, - Err(ExtensionError { - parent: h.root(), - parent_extensions: grandparent, - child: p, - child_extensions: result - }) - ); - } - } } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f3ef094be..7805d3c67 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,22 +1,24 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet, VecDeque}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, PortMut, PortView, SecondaryMap}; +use crate::core::HugrNode; use crate::extension::ExtensionRegistry; +use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrView, Node, OpType, RootTagged}; -use crate::hugr::{NodeMetadata, Rewrite}; +use crate::hugr::{HugrView, Node, OpType}; +use crate::hugr::{NodeMetadata, Patch}; use crate::ops::OpTrait; use crate::types::Substitution; use crate::{Extension, Hugr, IncomingPort, OutgoingPort, Port, PortIndex}; use super::internal::HugrMutInternals; -use super::NodeMetadataMap; +use super::views::{panic_invalid_node, panic_invalid_non_root, panic_invalid_port}; /// Functions for low-level building of a HUGR. pub trait HugrMut: HugrMutInternals { @@ -25,17 +27,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph. - fn get_metadata_mut(&mut self, node: Node, key: impl AsRef) -> &mut NodeMetadata { - panic_invalid_node(self, node); - let node_meta = self - .hugr_mut() - .metadata - .get_mut(node.pg_index()) - .get_or_insert_with(Default::default); - node_meta - .entry(key.as_ref()) - .or_insert(serde_json::Value::Null) - } + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata; /// Sets a metadata value associated with a node. /// @@ -44,44 +36,17 @@ pub trait HugrMut: HugrMutInternals { /// If the node is not in the graph. fn set_metadata( &mut self, - node: Node, + node: Self::Node, key: impl AsRef, metadata: impl Into, - ) { - let entry = self.get_metadata_mut(node, key); - *entry = metadata.into(); - } + ); /// Remove a metadata entry associated with a node. /// /// # Panics /// /// If the node is not in the graph. - fn remove_metadata(&mut self, node: Node, key: impl AsRef) { - panic_invalid_node(self, node); - let node_meta = self.hugr_mut().metadata.get_mut(node.pg_index()); - if let Some(node_meta) = node_meta { - node_meta.remove(key.as_ref()); - } - } - - /// Retrieve the complete metadata map for a node. - fn take_node_metadata(&mut self, node: Self::Node) -> Option { - if !self.valid_node(node) { - return None; - } - self.hugr_mut().metadata.take(node.pg_index()) - } - - /// Overwrite the complete metadata map for a node. - /// - /// # Panics - /// - /// If the node is not in the graph. - fn overwrite_node_metadata(&mut self, node: Node, metadata: Option) { - panic_invalid_node(self, node); - self.hugr_mut().metadata.set(node.pg_index(), metadata); - } + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); /// Add a node to the graph with a parent in the hierarchy. /// @@ -90,11 +55,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the parent is not in the graph. - #[inline] - fn add_node_with_parent(&mut self, parent: Node, op: impl Into) -> Node { - panic_invalid_node(self, parent); - self.hugr_mut().add_node_with_parent(parent, op) - } + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; /// Add a node to the graph as the previous sibling of another node. /// @@ -103,11 +64,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_before(sibling, nodetype) - } + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; /// Add a node to the graph as the next sibling of another node. /// @@ -116,11 +73,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the sibling is not in the graph, or if the sibling is the root node. - #[inline] - fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { - panic_invalid_non_root(self, sibling); - self.hugr_mut().add_node_after(sibling, op) - } + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; /// Remove a node from the graph and return the node weight. /// Note that if the node has children, they are not removed; this leaves @@ -129,24 +82,14 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the node is the root node. - #[inline] - fn remove_node(&mut self, node: Node) -> OpType { - panic_invalid_non_root(self, node); - self.hugr_mut().remove_node(node) - } + fn remove_node(&mut self, node: Self::Node) -> OpType; /// Remove a node from the graph, along with all its descendants in the hierarchy. /// /// # Panics /// /// If the node is not in the graph, or is the root (this would leave an empty Hugr). - fn remove_subtree(&mut self, node: Node) { - panic_invalid_non_root(self, node); - while let Some(ch) = self.first_child(node) { - self.remove_subtree(ch) - } - self.hugr_mut().remove_node(node); - } + fn remove_subtree(&mut self, node: Self::Node); /// Copies the strict descendants of `root` to under the `new_parent`, optionally applying a /// [Substitution] to the [OpType]s of the copied nodes. @@ -162,32 +105,23 @@ pub trait HugrMut: HugrMutInternals { /// correspondingly for `Dom` edges) fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { - panic_invalid_node(self, root); - panic_invalid_node(self, new_parent); - self.hugr_mut().copy_descendants(root, new_parent, subst) - } + ) -> BTreeMap; /// Connect two nodes at the given ports. /// /// # Panics /// /// If either node is not in the graph or if the ports are invalid. - #[inline] fn connect( &mut self, - src: Node, + src: Self::Node, src_port: impl Into, - dst: Node, + dst: Self::Node, dst_port: impl Into, - ) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().connect(src, src_port, dst, dst_port); - } + ); /// Disconnects all edges from the given port. /// @@ -196,11 +130,7 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - #[inline] - fn disconnect(&mut self, node: Node, port: impl Into) { - panic_invalid_node(self, node); - self.hugr_mut().disconnect(node, port); - } + fn disconnect(&mut self, node: Self::Node, port: impl Into); /// Adds a non-dataflow edge between two nodes. The kind is given by the /// operation's [`OpTrait::other_input`] or [`OpTrait::other_output`]. @@ -213,35 +143,27 @@ pub trait HugrMut: HugrMutInternals { /// # Panics /// /// If the node is not in the graph, or if the port is invalid. - fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { - panic_invalid_node(self, src); - panic_invalid_node(self, dst); - self.hugr_mut().add_other_edge(src, dst) - } + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (OutgoingPort, IncomingPort); - /// Insert another hugr into this one, under a given root node. + /// Insert another hugr into this one, under a given parent node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_hugr(root, other) - } + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult; - /// Copy another hugr into this one, under a given root node. + /// Copy another hugr into this one, under a given parent node. /// /// # Panics /// /// If the root node is not in the graph. - #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { - panic_invalid_node(self, root); - self.hugr_mut().insert_from_view(root, other) - } + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult; - /// Copy a subgraph from another hugr into this one, under a given root node. + /// Copy a subgraph from another hugr into this one, under a given parent node. /// /// Sibling order is not preserved. /// @@ -255,18 +177,15 @@ pub trait HugrMut: HugrMutInternals { // TODO: Try to preserve the order when possible? We cannot always ensure // it, since the subgraph may have arbitrary nodes without including their // parent. - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { - panic_invalid_node(self, root); - self.hugr_mut().insert_subgraph(root, other, subgraph) - } + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap; - /// Applies a rewrite to the graph. - fn apply_rewrite(&mut self, rw: impl Rewrite) -> Result + /// Applies a patch to the graph. + fn apply_patch(&mut self, rw: impl Patch) -> Result where Self: Sized, { @@ -279,9 +198,7 @@ pub trait HugrMut: HugrMutInternals { /// These can be queried using [`HugrView::extensions`]. /// /// See [`ExtensionRegistry::register_updated`] for more information. - fn use_extension(&mut self, extension: impl Into>) { - self.hugr_mut().extensions.register_updated(extension); - } + fn use_extension(&mut self, extension: impl Into>); /// Extend the set of extensions used by the hugr with the extensions in the /// registry. @@ -294,69 +211,103 @@ pub trait HugrMut: HugrMutInternals { /// See [`ExtensionRegistry::register_updated`] for more information. fn use_extensions(&mut self, registry: impl IntoIterator) where - ExtensionRegistry: Extend, - { - self.hugr_mut().extensions.extend(registry); - } - - /// Returns a mutable reference to the extension registry for this hugr. - fn extensions_mut(&mut self) -> &mut ExtensionRegistry { - &mut self.hugr_mut().extensions - } + ExtensionRegistry: Extend; } /// Records the result of inserting a Hugr or view /// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]. -pub struct InsertionResult { +/// +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. +pub struct InsertionResult { /// The node, after insertion, that was the root of the inserted Hugr. /// /// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root] - pub new_root: Node, + pub new_root: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. - pub node_map: HashMap, + pub node_map: HashMap, } -fn translate_indices( +/// Translate a portgraph node index map into a map from nodes in the source +/// HUGR to nodes in the target HUGR. +/// +/// This is as a helper in `insert_hugr` and `insert_subgraph`, where the source +/// HUGR may be an arbitrary `HugrView` with generic node types. +fn translate_indices( + mut source_node: impl FnMut(portgraph::NodeIndex) -> N, + mut target_node: impl FnMut(portgraph::NodeIndex) -> Node, node_map: HashMap, -) -> impl Iterator { - node_map.into_iter().map(|(k, v)| (k.into(), v.into())) +) -> impl Iterator { + node_map + .into_iter() + .map(move |(k, v)| (source_node(k), target_node(v))) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. -impl + AsMut> HugrMut for T { +impl HugrMut for Hugr { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut NodeMetadata { + panic_invalid_node(self, node); + self.node_metadata_map_mut(node) + .entry(key.as_ref()) + .or_insert(serde_json::Value::Null) + } + + fn set_metadata( + &mut self, + node: Self::Node, + key: impl AsRef, + metadata: impl Into, + ) { + let entry = self.get_metadata_mut(node, key); + *entry = metadata.into(); + } + + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef) { + panic_invalid_node(self, node); + let node_meta = self.node_metadata_map_mut(node); + node_meta.remove(key.as_ref()); + } + fn add_node_with_parent(&mut self, parent: Node, node: impl Into) -> Node { let node = self.as_mut().add_node(node.into()); - self.as_mut() - .hierarchy - .push_child(node.pg_index(), parent.pg_index()) + self.hierarchy + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } fn add_node_before(&mut self, sibling: Node, nodetype: impl Into) -> Node { let node = self.as_mut().add_node(nodetype.into()); - self.as_mut() - .hierarchy - .insert_before(node.pg_index(), sibling.pg_index()) + self.hierarchy + .insert_before(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } fn add_node_after(&mut self, sibling: Node, op: impl Into) -> Node { let node = self.as_mut().add_node(op.into()); - self.as_mut() - .hierarchy - .insert_after(node.pg_index(), sibling.pg_index()) + self.hierarchy + .insert_after(node.into_portgraph(), sibling.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); node } fn remove_node(&mut self, node: Node) -> OpType { panic_invalid_non_root(self, node); - self.as_mut().hierarchy.remove(node.pg_index()); - self.as_mut().graph.remove_node(node.pg_index()); - self.as_mut().op_types.take(node.pg_index()) + self.hierarchy.remove(node.into_portgraph()); + self.graph.remove_node(node.into_portgraph()); + self.op_types.take(node.into_portgraph()) + } + + fn remove_subtree(&mut self, node: Node) { + panic_invalid_non_root(self, node); + let mut queue = VecDeque::new(); + queue.push_back(node); + while let Some(n) = queue.pop_front() { + queue.extend(self.children(n)); + self.remove_node(n); + } } fn connect( @@ -370,12 +321,11 @@ impl + AsMut> HugrMut for T let dst_port = dst_port.into(); panic_invalid_port(self, src, src_port); panic_invalid_port(self, dst, dst_port); - self.as_mut() - .graph + self.graph .link_nodes( - src.pg_index(), + src.into_portgraph(), src_port.index(), - dst.pg_index(), + dst.into_portgraph(), dst_port.index(), ) .expect("The ports should exist at this point."); @@ -386,11 +336,10 @@ impl + AsMut> HugrMut for T let offset = port.pg_offset(); panic_invalid_port(self, node, port); let port = self - .as_mut() .graph - .port_index(node.pg_index(), offset) + .port_index(node.into_portgraph(), offset) .expect("The port should exist at this point."); - self.as_mut().graph.unlink_port(port); + self.graph.unlink_port(port); } fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) { @@ -406,90 +355,127 @@ impl + AsMut> HugrMut for T (src_port, dst_port) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { - let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); + fn insert_hugr( + &mut self, + root: Self::Node, + mut other: Hugr, + ) -> InsertionResult { + let (new_root, node_map) = insert_hugr_internal(self, root, &other); // Update the optypes and metadata, taking them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { let optype = other.op_types.take(node); - self.as_mut().op_types.set(new_node, optype); + self.op_types.set(new_node, optype); let meta = other.metadata.take(node); - self.as_mut().metadata.set(new_node, meta); + self.metadata.set(new_node, meta); } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.root().pg_index()) + Some(&new_root.into_portgraph()), + node_map.get(&other.root().into_portgraph()) ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { - let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { + let (new_root, node_map) = insert_hugr_internal(self, root, other); // Update the optypes and metadata, copying them from the other graph. // // No need to compute each node's extensions here, as we merge `other.extensions` directly. for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); + self.op_types.set(new_node, nodetype.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } } debug_assert_eq!( - Some(&new_root.pg_index()), - node_map.get(&other.get_pg_index(other.root())) + Some(&new_root.into_portgraph()), + node_map.get(&other.to_portgraph_node(other.root())) ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect(), } } - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { // Create a portgraph view with the explicit list of nodes defined by the subgraph. - let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> = + let context: HashSet = subgraph + .nodes() + .iter() + .map(|&n| other.to_portgraph_node(n)) + .collect(); + let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( other.portgraph(), - |node, ctx| ctx.contains(&node.into()), - subgraph.nodes(), + |node, ctx| ctx.contains(&node), + context, ); - let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); + let node_map = insert_subgraph_internal(self, root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. for (&node, &new_node) in node_map.iter() { - let nodetype = other.get_optype(other.get_node(node)); - self.as_mut().op_types.set(new_node, nodetype.clone()); - let meta = other.base_hugr().metadata.get(node); - self.as_mut().metadata.set(new_node, meta.clone()); + let node = other.from_portgraph_node(node); + let nodetype = other.get_optype(node); + self.op_types.set(new_node, nodetype.clone()); + let meta = other.node_metadata_map(node); + if !meta.is_empty() { + self.metadata.set(new_node, Some(meta.clone())); + } // Add the required extensions to the registry. if let Ok(exts) = nodetype.used_extensions() { self.use_extensions(exts); } } - translate_indices(node_map).collect() + translate_indices( + |n| other.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, + ) + .collect() } fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { - let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); + ) -> BTreeMap { + let mut descendants = self.hierarchy.descendants(root.into_portgraph()); let root2 = descendants.next(); - debug_assert_eq!(root2, Some(root.pg_index())); + debug_assert_eq!(root2, Some(root.into_portgraph())); let nodes = Vec::from_iter(descendants); + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"); let node_map = translate_indices( - portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) - .copy_in_parent() - .expect("Is a MultiPortGraph"), + |n| self.from_portgraph_node(n), + |n| self.from_portgraph_node(n), + node_map, ) .collect::>(); @@ -506,12 +492,25 @@ impl + AsMut> HugrMut for T (None, op) => op.clone(), (Some(subst), op) => op.substitute(subst), }; - self.as_mut().op_types.set(new_node.pg_index(), new_optype); - let meta = self.base_hugr().metadata.get(node.pg_index()).clone(); - self.as_mut().metadata.set(new_node.pg_index(), meta); + self.op_types.set(new_node.into_portgraph(), new_optype); + let meta = self.metadata.get(node.into_portgraph()).clone(); + self.metadata.set(new_node.into_portgraph(), meta); } node_map } + + #[inline] + fn use_extension(&mut self, extension: impl Into>) { + self.extensions_mut().register_updated(extension); + } + + #[inline] + fn use_extensions(&mut self, registry: impl IntoIterator) + where + ExtensionRegistry: Extend, + { + self.extensions_mut().extend(registry); + } } /// Internal implementation of `insert_hugr` and `insert_view` methods for @@ -531,18 +530,20 @@ fn insert_hugr_internal( .graph .insert_graph(&other.portgraph()) .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}")); - let other_root = node_map[&other.get_pg_index(other.root())]; + let other_root = node_map[&other.to_portgraph_node(other.root())]; // Update hierarchy and optypes hugr.hierarchy - .push_child(other_root, root.pg_index()) + .push_child(other_root, root.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); for (&node, &new_node) in node_map.iter() { - other.children(other.get_node(node)).for_each(|child| { - hugr.hierarchy - .push_child(node_map[&other.get_pg_index(child)], new_node) - .expect("Inserting a newly-created node into the hierarchy should never fail."); - }); + other + .children(other.from_portgraph_node(node)) + .for_each(|child| { + hugr.hierarchy + .push_child(node_map[&other.to_portgraph_node(child)], new_node) + .expect("Inserting a newly-created node into the hierarchy should never fail."); + }); } // Merge the extension sets. @@ -563,10 +564,10 @@ fn insert_hugr_internal( /// sibling order in the hierarchy. This is due to the subgraph not necessarily /// having a single root, so the logic for reconstructing the hierarchy is not /// able to just do a BFS. -fn insert_subgraph_internal( +fn insert_subgraph_internal( hugr: &mut Hugr, root: Node, - other: &impl HugrView, + other: &impl HugrView, portgraph: &impl portgraph::LinkView, ) -> HashMap { let node_map = hugr @@ -578,9 +579,9 @@ fn insert_subgraph_internal( // update the hierarchy with their new id. for (&node, &new_node) in node_map.iter() { let new_parent = other - .get_parent(other.get_node(node)) - .and_then(|parent| node_map.get(&other.get_pg_index(parent)).copied()) - .unwrap_or(root.pg_index()); + .get_parent(other.from_portgraph_node(node)) + .and_then(|parent| node_map.get(&other.to_portgraph_node(parent)).copied()) + .unwrap_or(root.into_portgraph()); hugr.hierarchy .push_child(new_node, new_parent) .expect("Inserting a newly-created node into the hierarchy should never fail."); @@ -589,48 +590,6 @@ fn insert_subgraph_internal( node_map } -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { - if !hugr.valid_node(node) { - panic!( - "Received an invalid node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); - } -} - -/// Panic if [`HugrView::valid_non_root`] fails. -#[track_caller] -pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { - if !hugr.valid_non_root(node) { - panic!( - "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); - } -} - -/// Panic if [`HugrView::valid_node`] fails. -#[track_caller] -pub(super) fn panic_invalid_port( - hugr: &H, - node: Node, - port: impl Into, -) { - let port = port.into(); - if hugr - .portgraph() - .port_index(node.pg_index(), port.pg_offset()) - .is_none() - { - panic!( - "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}", - hugr.mermaid_string() - ); - } -} - #[cfg(test)] mod test { use crate::extension::PRELUDE; @@ -655,9 +614,7 @@ mod test { module, ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]) - .with_prelude() - .into(), + signature: Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]).into(), }, ); @@ -714,14 +671,14 @@ mod test { fd }); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 7); + assert_eq!(hugr.num_nodes(), 7); hugr.remove_subtree(foo); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); hugr.remove_subtree(bar); hugr.validate().unwrap(); - assert_eq!(hugr.node_count(), 1); + assert_eq!(hugr.num_nodes(), 1); } } diff --git a/hugr-core/src/hugr/internal.rs b/hugr-core/src/hugr/internal.rs index 6dab3adc0..09f234de0 100644 --- a/hugr-core/src/hugr/internal.rs +++ b/hugr-core/src/hugr/internal.rs @@ -1,20 +1,18 @@ //! Internal traits, not exposed in the public `hugr` API. -use std::borrow::Cow; use std::ops::Range; -use std::rc::Rc; -use std::sync::Arc; +use std::sync::OnceLock; -use delegate::delegate; use itertools::Itertools; use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView}; -use crate::ops::handle::NodeHandle; -use crate::ops::{OpTag, OpTrait}; +use crate::extension::ExtensionRegistry; use crate::{Direction, Hugr, Node}; -use super::hugrmut::{panic_invalid_node, panic_invalid_non_root}; -use super::{HugrError, OpType, RootTagged}; +use super::views::{panic_invalid_node, panic_invalid_non_root}; +use super::HugrView; +use super::{NodeMetadataMap, OpType}; +use crate::ops::handle::NodeHandle; /// Trait for accessing the internals of a Hugr(View). /// @@ -22,7 +20,7 @@ use super::{HugrError, OpType, RootTagged}; /// view. pub trait HugrInternals { /// The underlying portgraph view type. - type Portgraph<'p>: LinkView + Clone + 'p + type Portgraph<'p>: LinkView + Clone + 'p where Self: 'p; @@ -32,24 +30,39 @@ pub trait HugrInternals { /// Returns a reference to the underlying portgraph. fn portgraph(&self) -> Self::Portgraph<'_>; + /// Returns a flat portgraph view of a region in the HUGR. + /// + /// This is a subgraph of [`HugrInternals::portgraph`], with a flat hierarchy. + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_>; + /// Returns the portgraph [Hierarchy](portgraph::Hierarchy) of the graph /// returned by [`HugrInternals::portgraph`]. - #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.base_hugr().hierarchy) - } - - /// Returns the Hugr at the base of a chain of views. - fn base_hugr(&self) -> &Hugr; - - /// Return the root node of this view. - fn root_node(&self) -> Self::Node; + fn hierarchy(&self) -> &portgraph::Hierarchy; /// Convert a node to a portgraph node index. - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex; /// Convert a portgraph node index to a node. - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; + #[allow(clippy::wrong_self_convention)] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap; + + /// Returns the Hugr at the base of a chain of views. + // TODO: This will be removed in a future PR. + #[deprecated( + since = "0.16.0", + note = "This method will be removed in a future PR. Use the individual HugrInternals methods instead." + )] + fn base_hugr(&self) -> &Hugr; } impl HugrInternals for Hugr { @@ -66,159 +79,65 @@ impl HugrInternals for Hugr { } #[inline] - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy> { - Cow::Borrowed(&self.hierarchy) + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion<'_, impl LinkView + Clone + '_> { + let pg = self.portgraph(); + let root = self.to_portgraph_node(parent); + portgraph::view::FlatRegion::new_without_root(pg, &self.hierarchy, root) } #[inline] - fn base_hugr(&self) -> &Hugr { - self + fn hierarchy(&self) -> &portgraph::Hierarchy { + &self.hierarchy } #[inline] - fn root_node(&self) -> Self::Node { - self.root.into() + fn base_hugr(&self) -> &Hugr { + self } - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex { - node.pg_index() + #[inline] + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + node.node().into_portgraph() } - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node { + #[inline] + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { index.into() } -} - -impl HugrInternals for &T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for &mut T { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} -impl HugrInternals for Rc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Arc { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} - -impl HugrInternals for Box { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to (**self) { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } + #[inline] + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + static EMPTY: OnceLock = OnceLock::new(); + panic_invalid_node(self, node); + let map = self.metadata.get(node.into_portgraph()).as_ref(); + map.unwrap_or(EMPTY.get_or_init(Default::default)) } } -impl HugrInternals for Cow<'_, T> { - type Portgraph<'p> - = T::Portgraph<'p> - where - Self: 'p; - type Node = T::Node; - - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Self::Node; - fn get_pg_index(&self, node: Self::Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Self::Node; - } - } -} /// Trait for accessing the mutable internals of a Hugr(Mut). /// /// Specifically, this trait lets you apply arbitrary modifications that may /// invalidate the HUGR. -pub trait HugrMutInternals: RootTagged { - /// Returns the Hugr at the base of a chain of views. - fn hugr_mut(&mut self) -> &mut Hugr; +pub trait HugrMutInternals: HugrView { + /// Set the node at the root of the HUGR hierarchy. + /// + /// Any node not reachable from this root should be deleted from the HUGR + /// after this call. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn set_root(&mut self, root: Self::Node); /// Set the number of ports on a node. This may invalidate the node's `PortIndex`. /// /// # Panics /// /// If the node is not in the graph. - fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - panic_invalid_node(self, node); - self.hugr_mut().set_num_ports(node, incoming, outgoing) - } + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); /// Alter the number of ports on a node and returns a range with the new /// port offsets, if any. This may invalidate the node's `PortIndex`. @@ -231,10 +150,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If the node is not in the graph. - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().add_ports(node, direction, amount) - } + fn add_ports(&mut self, node: Self::Node, direction: Direction, amount: isize) -> Range; /// Insert `amount` new ports for a node, starting at `index`. The /// `direction` parameter specifies whether to add ports to the incoming or @@ -247,14 +163,11 @@ pub trait HugrMutInternals: RootTagged { /// If the node is not in the graph. fn insert_ports( &mut self, - node: Node, + node: Self::Node, direction: Direction, index: usize, amount: usize, - ) -> Range { - panic_invalid_node(self, node); - self.hugr_mut().insert_ports(node, direction, index, amount) - } + ) -> Range; /// Sets the parent of a node. /// @@ -263,11 +176,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either the node or the parent is not in the graph. - fn set_parent(&mut self, node: Node, parent: Node) { - panic_invalid_node(self, parent); - panic_invalid_non_root(self, node); - self.hugr_mut().set_parent(node, parent); - } + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); /// Move a node in the hierarchy to be the subsequent sibling of another /// node. @@ -279,11 +188,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_after_sibling(&mut self, node: Node, after: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, after); - self.hugr_mut().move_after_sibling(node, after); - } + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); /// Move a node in the hierarchy to be the prior sibling of another node. /// @@ -294,11 +199,7 @@ pub trait HugrMutInternals: RootTagged { /// # Panics /// /// If either node is not in the graph, or if it is a root. - fn move_before_sibling(&mut self, node: Node, before: Node) { - panic_invalid_non_root(self, node); - panic_invalid_non_root(self, before); - self.hugr_mut().move_before_sibling(node, before) - } + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); /// Replace the OpType at node and return the old OpType. /// In general this invalidates the ports, which may need to be resized to @@ -306,56 +207,70 @@ pub trait HugrMutInternals: RootTagged { /// /// Returns the old OpType. /// - /// TODO: Add a version which ignores input extensions - /// - /// # Errors - /// - /// Returns a [`HugrError::InvalidTag`] if this would break the bound - /// (`Self::RootHandle`) on the root node's OpTag. - /// /// # Panics /// /// If the node is not in the graph. - fn replace_op(&mut self, node: Node, op: impl Into) -> Result { - panic_invalid_node(self, node); - let op = op.into(); - if node == self.root() && !Self::RootHandle::TAG.is_superset(op.tag()) { - return Err(HugrError::InvalidTag { - required: Self::RootHandle::TAG, - actual: op.tag(), - }); - } - self.hugr_mut().replace_op(node, op) - } + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> OpType; /// Gets a mutable reference to the optype. /// /// Changing this may invalidate the ports, which may need to be resized to /// match the OpType signature. /// - /// Will panic for the root node unless [`Self::RootHandle`](RootTagged::RootHandle) - /// is [OpTag::Any], as mutation could invalidate the bound. - fn optype_mut(&mut self, node: Node) -> &mut OpType { - if Self::RootHandle::TAG.is_superset(OpTag::Any) { - panic_invalid_node(self, node); - } else { - panic_invalid_non_root(self, node); - } - self.hugr_mut().op_types.get_mut(node.pg_index()) - } + /// Mutating the root node operation may invalidate the root tag. + /// + /// Mutating the module root into a non-module operation will invalidate the hugr. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType; + + /// Returns a metadata entry associated with a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap; + + /// Returns a mutable reference to the extension registry for this HUGR. + /// + /// This set contains all extensions required to define the operations and + /// types in the HUGR. + fn extensions_mut(&mut self) -> &mut ExtensionRegistry; } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. -impl + AsMut> HugrMutInternals for T { - fn hugr_mut(&mut self) -> &mut Hugr { - self.as_mut() +impl HugrMutInternals for Hugr { + fn set_root(&mut self, root: Node) { + self.hierarchy.detach(self.root); + self.root = root.into_portgraph(); } #[inline] fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) { - self.hugr_mut() - .graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}) + panic_invalid_node(self, node); + self.graph + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}) + } + + fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { + panic_invalid_node(self, node); + let mut incoming = self.graph.num_inputs(node.into_portgraph()); + let mut outgoing = self.graph.num_outputs(node.into_portgraph()); + let increment = |num: &mut usize| { + let new = num.saturating_add_signed(amount); + let range = *num..new; + *num = new; + range + }; + let range = match direction { + Direction::Incoming => increment(&mut incoming), + Direction::Outgoing => increment(&mut outgoing), + }; + self.graph + .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {}); + range } fn insert_ports( @@ -365,28 +280,26 @@ impl + AsMut> HugrMutInterna index: usize, amount: usize, ) -> Range { - let old_num_ports = self.base_hugr().graph.num_ports(node.pg_index(), direction); + panic_invalid_node(self, node); + let old_num_ports = self.graph.num_ports(node.into_portgraph(), direction); self.add_ports(node, direction, amount as isize); for swap_from_port in (index..old_num_ports).rev() { let swap_to_port = swap_from_port + amount; let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| { - self.base_hugr() - .graph - .port_index(node.pg_index(), PortOffset::new(direction, p)) + self.graph + .port_index(node.into_portgraph(), PortOffset::new(direction, p)) .unwrap() }); let linked_ports = self - .base_hugr() .graph .port_links(from_port_index) .map(|(_, to_subport)| to_subport.port()) .collect_vec(); - self.hugr_mut().graph.unlink_port(from_port_index); + self.graph.unlink_port(from_port_index); for linked_port_index in linked_ports { let _ = self - .hugr_mut() .graph .link_ports(to_port_index, linked_port_index) .expect("Ports exist"); @@ -395,52 +308,53 @@ impl + AsMut> HugrMutInterna index..index + amount } - fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range { - let mut incoming = self.hugr_mut().graph.num_inputs(node.pg_index()); - let mut outgoing = self.hugr_mut().graph.num_outputs(node.pg_index()); - let increment = |num: &mut usize| { - let new = num.saturating_add_signed(amount); - let range = *num..new; - *num = new; - range - }; - let range = match direction { - Direction::Incoming => increment(&mut incoming), - Direction::Outgoing => increment(&mut outgoing), - }; - self.hugr_mut() - .graph - .set_num_ports(node.pg_index(), incoming, outgoing, |_, _| {}); - range - } - fn set_parent(&mut self, node: Node, parent: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy - .push_child(node.pg_index(), parent.pg_index()) + panic_invalid_node(self, parent); + panic_invalid_node(self, node); + self.hierarchy.detach(node.into_portgraph()); + self.hierarchy + .push_child(node.into_portgraph(), parent.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_after_sibling(&mut self, node: Node, after: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy - .insert_after(node.pg_index(), after.pg_index()) + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, after); + self.hierarchy.detach(node.into_portgraph()); + self.hierarchy + .insert_after(node.into_portgraph(), after.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } fn move_before_sibling(&mut self, node: Node, before: Node) { - self.hugr_mut().hierarchy.detach(node.pg_index()); - self.hugr_mut() - .hierarchy - .insert_before(node.pg_index(), before.pg_index()) + panic_invalid_non_root(self, node); + panic_invalid_non_root(self, before); + self.hierarchy.detach(node.into_portgraph()); + self.hierarchy + .insert_before(node.into_portgraph(), before.into_portgraph()) .expect("Inserting a newly-created node into the hierarchy should never fail."); } - fn replace_op(&mut self, node: Node, op: impl Into) -> Result { - // We know RootHandle=Node here so no need to check - Ok(std::mem::replace(self.optype_mut(node), op.into())) + fn replace_op(&mut self, node: Node, op: impl Into) -> OpType { + panic_invalid_node(self, node); + std::mem::replace(self.optype_mut(node), op.into()) + } + + fn optype_mut(&mut self, node: Self::Node) -> &mut OpType { + panic_invalid_node(self, node); + let node = self.to_portgraph_node(node); + self.op_types.get_mut(node) + } + + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap { + panic_invalid_node(self, node); + self.metadata + .get_mut(node.into_portgraph()) + .get_or_insert_with(Default::default) + } + + fn extensions_mut(&mut self) -> &mut ExtensionRegistry { + &mut self.extensions } } @@ -458,8 +372,7 @@ mod test { #[test] fn insert_ports() { let (nop, mut hugr) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(Type::UNIT).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(Type::UNIT)).unwrap(); let [nop_in] = builder.input_wires_arr(); let nop = builder .add_dataflow_op(Noop::new(Type::UNIT), [nop_in]) diff --git a/hugr-core/src/hugr/patch.rs b/hugr-core/src/hugr/patch.rs new file mode 100644 index 000000000..1744ce760 --- /dev/null +++ b/hugr-core/src/hugr/patch.rs @@ -0,0 +1,169 @@ +//! Rewrite operations on the HUGR - replacement, outlining, etc. + +pub mod consts; +pub mod inline_call; +pub mod inline_dfg; +pub mod insert_identity; +pub mod outline_cfg; +mod port_types; +pub mod replace; +pub mod simple_replace; + +use crate::{Hugr, HugrView}; +pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; +pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; + +use super::HugrMut; + +/// Verify that a patch application would succeed. +pub trait PatchVerification { + /// The type of Error with which this Rewrite may fail + type Error: std::error::Error; + + /// The node type of the HugrView that this patch applies to. + type Node; + + /// Checks whether the rewrite would succeed on the specified Hugr. + /// If this call succeeds, [Patch::apply] should also succeed on the same + /// `h` If this calls fails, [Patch::apply] would fail with the same + /// error. + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; + + /// Returns a set of nodes referenced by the rewrite. Modifying any of these + /// nodes will invalidate it. + /// + /// Two `impl Rewrite`s can be composed if their invalidation sets are + /// disjoint. + fn invalidation_set(&self) -> impl Iterator; +} + +/// A patch that can be applied to a mutable Hugr of type `H`. +/// +/// ### When to use +/// +/// Use this trait whenever possible in bounds for the most generality. Note +/// that this will require specifying which type `H` the patch applies to. +/// +/// ### When to implement +/// +/// For patches that work on any `H: HugrMut`, prefer implementing [`PatchHugrMut`] instead. This +/// will automatically implement this trait. +pub trait Patch: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [Patch::apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [Patch::UNCHANGED_ON_FAILURE] is true, then `h` must be unchanged if + /// Err is returned. See also [PatchVerification::verify] + /// + /// # Panics + /// + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply(self, h: &mut H) -> Result; +} + +/// A patch that can be applied to any [`HugrMut`]. +/// +/// This trait is a generalisation of [`Patch`] in that it guarantees that +/// the patch can be applied to any type implementing [`HugrMut`]. +/// +/// ### When to use +/// +/// Prefer using the more general [`Patch`] trait in bounds where the +/// type `H` is known. Resort to this trait if patches must be applicable to +/// any [`HugrMut`] instance. +/// +/// ### When to implement +/// +/// Always implement this trait when possible, to define how a patch is applied +/// to any type implementing [`HugrMut`]. A blanket implementation ensures that +/// any type implementing this trait also implements [`Patch`]. +pub trait PatchHugrMut: PatchVerification { + /// The type returned on successful application of the rewrite. + type Outcome; + + /// If `true`, [self.apply]'s of this rewrite guarantee that they do not + /// mutate the Hugr when they return an Err. If `false`, there is no + /// guarantee; the Hugr should be assumed invalid when Err is returned. + const UNCHANGED_ON_FAILURE: bool; + + /// Mutate the specified Hugr, or fail with an error. + /// + /// Returns [`Self::Outcome`] if successful. + /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if + /// Err is returned. See also [self.verify] + /// # Panics + /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that + /// is, implementations may begin with `assert!(h.validate())`, with + /// `debug_assert!(h.validate())` being preferred. + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result; +} + +impl> Patch for R { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = R::UNCHANGED_ON_FAILURE; + + fn apply(self, h: &mut H) -> Result { + self.apply_hugr_mut(h) + } +} + +/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) +pub struct Transactional { + underlying: R, +} + +impl PatchVerification for Transactional { + type Error = R::Error; + type Node = R::Node; + + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + self.underlying.verify(h) + } + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + self.underlying.invalidation_set() + } +} + +// Note we might like to constrain R to Rewrite but +// this is not yet supported, https://github.com/rust-lang/rust/issues/92827 +impl PatchHugrMut for Transactional { + type Outcome = R::Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { + if R::UNCHANGED_ON_FAILURE { + return self.underlying.apply_hugr_mut(h); + } + // Try to backup just the contents of this HugrMut. + let mut backup = Hugr::new(h.root_optype().clone()); + backup.insert_from_view(backup.root(), h); + let r = self.underlying.apply_hugr_mut(h); + if r.is_err() { + // Try to restore backup. + h.replace_op(h.root(), backup.root_optype().clone()); + while let Some(child) = h.first_child(h.root()) { + h.remove_node(child); + } + h.insert_hugr(h.root(), backup); + } + r + } +} diff --git a/hugr-core/src/hugr/rewrite/consts.rs b/hugr-core/src/hugr/patch/consts.rs similarity index 71% rename from hugr-core/src/hugr/rewrite/consts.rs rename to hugr-core/src/hugr/patch/consts.rs index c112dfc57..4ddd0b476 100644 --- a/hugr-core/src/hugr/rewrite/consts.rs +++ b/hugr-core/src/hugr/patch/consts.rs @@ -2,11 +2,11 @@ use std::iter; -use crate::{hugr::HugrMut, HugrView, Node}; +use crate::{core::HugrNode, hugr::HugrMut, HugrView, Node}; use itertools::Itertools; use thiserror::Error; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Remove a [`crate::ops::LoadConstant`] node with no consumers. #[derive(Debug, Clone)] @@ -15,24 +15,20 @@ pub struct RemoveLoadConstant(pub N); /// Error from an [`RemoveConst`] or [`RemoveLoadConstant`] operation. #[derive(Debug, Clone, Error, PartialEq, Eq)] #[non_exhaustive] -pub enum RemoveError { +pub enum RemoveError { /// Invalid node. #[error("Node is invalid (either not in HUGR or not correct operation).")] - InvalidNode(Node), + InvalidNode(N), /// Node in use. #[error("Node: {0} has non-zero outgoing connections.")] - ValueUsed(Node), + ValueUsed(N), } -impl Rewrite for RemoveLoadConstant { - type Error = RemoveError; +impl PatchVerification for RemoveLoadConstant { + type Error = RemoveError; + type Node = N; - // The Const node the LoadConstant was connected to. - type ApplyResult = Node; - - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_load_constant()) { @@ -50,7 +46,18 @@ impl Rewrite for RemoveLoadConstant { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveLoadConstant { + /// The [`Const`](crate::ops::Const) node the [`LoadConstant`](crate::ops::LoadConstant) was + /// connected to. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let source = h @@ -62,25 +69,17 @@ impl Rewrite for RemoveLoadConstant { Ok(source) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } /// Remove a [`crate::ops::Const`] node with no outputs. #[derive(Debug, Clone)] -pub struct RemoveConst(pub Node); - -impl Rewrite for RemoveConst { - type Error = RemoveError; +pub struct RemoveConst(pub N); - // The parent of the Const node. - type ApplyResult = Node; +impl PatchVerification for RemoveConst { + type Node = N; + type Error = RemoveError; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let node = self.0; if (!h.contains_node(node)) || (!h.get_optype(node).is_const()) { @@ -94,7 +93,18 @@ impl Rewrite for RemoveConst { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.0) + } +} + +impl PatchHugrMut for RemoveConst { + // The parent of the Const node. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.verify(h)?; let node = self.0; let parent = h @@ -104,17 +114,12 @@ impl Rewrite for RemoveConst { Ok(parent) } - - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.0) - } } #[cfg(test)] mod test { use super::*; - use crate::extension::prelude::PRELUDE_ID; use crate::{ builder::{Container, Dataflow, HugrBuilder, ModuleBuilder, SubContainer}, extension::prelude::ConstUsize, @@ -127,10 +132,7 @@ mod test { let mut build = ModuleBuilder::new(); let con_node = build.add_constant(Value::extension(ConstUsize::new(2))); - let mut dfg_build = build.define_function( - "main", - Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID.clone()), - )?; + let mut dfg_build = build.define_function("main", Signature::new_endo(type_row![]))?; let load_1 = dfg_build.load_const(&con_node); let load_2 = dfg_build.load_const(&con_node); let tup = dfg_build.make_tuple([load_1, load_2])?; @@ -138,16 +140,16 @@ mod test { let mut h = build.finish_hugr()?; // nodes are Module, Function, Input, Output, Const, LoadConstant*2, MakeTuple - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); let tup_node = tup.node(); // can't remove invalid node assert_eq!( - h.apply_rewrite(RemoveConst(tup_node)), + h.apply_patch(RemoveConst(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); assert_eq!( - h.apply_rewrite(RemoveLoadConstant(tup_node)), + h.apply_patch(RemoveLoadConstant(tup_node)), Err(RemoveError::InvalidNode(tup_node)) ); let load_1_node = load_1.node(); @@ -170,7 +172,7 @@ mod test { // can't remove nodes in use assert_eq!( - h.apply_rewrite(remove_1.clone()), + h.apply_patch(remove_1.clone()), Err(RemoveError::ValueUsed(load_1_node)) ); @@ -178,22 +180,22 @@ mod test { h.remove_node(tup_node); // remove first load - let reported_con_node = h.apply_rewrite(remove_1)?; + let reported_con_node = h.apply_patch(remove_1)?; assert_eq!(reported_con_node, con_node); // still can't remove const, in use by second load assert_eq!( - h.apply_rewrite(remove_con.clone()), + h.apply_patch(remove_con.clone()), Err(RemoveError::ValueUsed(con_node)) ); // remove second use - let reported_con_node = h.apply_rewrite(remove_2)?; + let reported_con_node = h.apply_patch(remove_2)?; assert_eq!(reported_con_node, con_node); // remove const - assert_eq!(h.apply_rewrite(remove_con)?, h.root()); + assert_eq!(h.apply_patch(remove_con)?, h.root()); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); assert!(h.validate().is_ok()); Ok(()) } diff --git a/hugr-core/src/hugr/rewrite/inline_call.rs b/hugr-core/src/hugr/patch/inline_call.rs similarity index 85% rename from hugr-core/src/hugr/rewrite/inline_call.rs rename to hugr-core/src/hugr/patch/inline_call.rs index 9af9cd70a..5f31fbc79 100644 --- a/hugr-core/src/hugr/rewrite/inline_call.rs +++ b/hugr-core/src/hugr/patch/inline_call.rs @@ -2,40 +2,41 @@ //! into a DFG which replaces the Call node. use derive_more::{Display, Error}; +use crate::core::HugrNode; use crate::ops::{DataflowParent, OpType, DFG}; use crate::types::Substitution; use crate::{Direction, HugrView, Node}; -use super::{HugrMut, Rewrite}; +use super::{HugrMut, PatchHugrMut, PatchVerification}; /// Rewrite to inline a [Call](OpType::Call) to a known [FuncDefn](OpType::FuncDefn) -pub struct InlineCall(Node); +pub struct InlineCall(N); /// Error in performing [InlineCall] rewrite. #[derive(Clone, Debug, Display, Error, PartialEq)] #[non_exhaustive] -pub enum InlineCallError { +pub enum InlineCallError { /// The specified Node was not a [Call](OpType::Call) #[display("Node to inline {_0} expected to be a Call but actually {_1}")] - NotCallNode(Node, OpType), + NotCallNode(N, OpType), /// The node was a Call, but the target was not a [FuncDefn](OpType::FuncDefn) /// - presumably a [FuncDecl](OpType::FuncDecl), if the Hugr is valid. #[display("Call targetted node {_0} which must be a FuncDefn but was {_1}")] - CallTargetNotFuncDefn(Node, OpType), + CallTargetNotFuncDefn(N, OpType), } -impl InlineCall { +impl InlineCall { /// Create a new instance that will inline the specified node /// (i.e. that should be a [Call](OpType::Call)) - pub fn new(node: Node) -> Self { + pub fn new(node: N) -> Self { Self(node) } } -impl Rewrite for InlineCall { - type ApplyResult = (); - type Error = InlineCallError; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { +impl PatchVerification for InlineCall { + type Error = InlineCallError; + type Node = N; + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { let call_ty = h.get_optype(self.0); if !call_ty.is_call() { return Err(InlineCallError::NotCallNode(self.0, call_ty.clone())); @@ -51,7 +52,14 @@ impl Rewrite for InlineCall { Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + fn invalidation_set(&self) -> impl Iterator { + Some(self.0).into_iter() + } +} + +impl PatchHugrMut for InlineCall { + type Outcome = (); + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result<(), Self::Error> { self.verify(h)?; // Now we know we have a Call to a FuncDefn. let orig_func = h.static_source(self.0).unwrap(); @@ -75,7 +83,6 @@ impl Rewrite for InlineCall { let ty_args = h .replace_op(self.0, new_op) - .unwrap() .as_call() .unwrap() .type_args @@ -99,10 +106,6 @@ impl Rewrite for InlineCall { /// Failure only occurs if the node is not a Call, or the target not a FuncDefn. /// (Any later failure means an invalid Hugr and `panic`.) const UNCHANGED_ON_FAILURE: bool = true; - - fn invalidation_set(&self) -> impl Iterator { - Some(self.0).into_iter() - } } #[cfg(test)] @@ -116,13 +119,10 @@ mod test { ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::RootChecked; - use crate::ops::handle::{FuncID, ModuleRootID, NodeHandle}; + use crate::ops::handle::{FuncID, NodeHandle}; use crate::ops::{Input, OpType, Value}; - use crate::std_extensions::arithmetic::{ - int_ops::{self, IntOpDef}, - int_types::{self, ConstInt, INT_TYPES}, - }; + use crate::std_extensions::arithmetic::int_types::INT_TYPES; + use crate::std_extensions::arithmetic::{int_ops::IntOpDef, int_types::ConstInt}; use crate::types::{PolyFuncType, Signature, Type, TypeBound}; use crate::{HugrView, Node}; @@ -143,9 +143,7 @@ mod test { fn test_inline() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); let cst3 = mb.add_constant(Value::from(ConstInt::new_u(4, 3)?)); - let sig = Signature::new_endo(INT_TYPES[4].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[4].clone()); let func = { let mut fb = mb.define_function("foo", sig.clone())?; let c1 = fb.load_const(&cst3); @@ -178,10 +176,7 @@ mod test { .count(), 1 ); - RootChecked::<_, ModuleRootID>::try_new(&mut hugr) - .unwrap() - .apply_rewrite(InlineCall(call1.node())) - .unwrap(); + hugr.apply_patch(InlineCall(call1.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).collect_vec(), [call2]); assert_eq!(calls(&hugr), [call2]); @@ -194,7 +189,7 @@ mod test { .count(), 1 ); - hugr.apply_rewrite(InlineCall(call2.node())).unwrap(); + hugr.apply_patch(InlineCall(call2.node())).unwrap(); hugr.validate().unwrap(); assert_eq!(hugr.output_neighbours(func.node()).next(), None); assert_eq!(calls(&hugr), []); @@ -206,9 +201,7 @@ mod test { #[test] fn test_recursion() -> Result<(), Box> { let mut mb = ModuleBuilder::new(); - let sig = Signature::new_endo(INT_TYPES[5].clone()) - .with_extension_delta(int_ops::EXTENSION_ID) - .with_extension_delta(int_types::EXTENSION_ID); + let sig = Signature::new_endo(INT_TYPES[5].clone()); let (func, rec_call) = { let mut fb = mb.define_function("foo", sig.clone())?; let cst1 = fb.add_load_value(ConstInt::new_u(5, 1)?); @@ -229,7 +222,7 @@ mod test { let func = func.node(); let mut call = call.node(); for i in 2..10 { - hugr.apply_rewrite(InlineCall(call))?; + hugr.apply_patch(InlineCall(call))?; hugr.validate().unwrap(); assert_eq!(extension_ops(&hugr).len(), i); let v = calls(&hugr); @@ -268,7 +261,7 @@ mod test { let h = modb.finish_hugr().unwrap(); let mut h2 = h.clone(); assert_eq!( - h2.apply_rewrite(InlineCall(call.node())), + h2.apply_patch(InlineCall(call.node())), Err(InlineCallError::CallTargetNotFuncDefn( decl.node(), h.get_optype(decl.node()).clone() @@ -281,7 +274,7 @@ mod test { .try_into() .unwrap(); assert_eq!( - h2.apply_rewrite(InlineCall(inp)), + h2.apply_patch(InlineCall(inp)), Err(InlineCallError::NotCallNode( inp, Input { @@ -295,10 +288,7 @@ mod test { #[test] fn test_polymorphic() -> Result<(), Box> { let tuple_ty = Type::new_tuple(vec![usize_t(); 2]); - let mut fb = FunctionBuilder::new( - "mkpair", - Signature::new(usize_t(), tuple_ty.clone()).with_prelude(), - )?; + let mut fb = FunctionBuilder::new("mkpair", Signature::new(usize_t(), tuple_ty.clone()))?; let inner = fb.define_function( "id", PolyFuncType::new( @@ -318,7 +308,7 @@ mod test { hugr.output_neighbours(inner.node()).collect::>(), [call1.node(), call2.node()] ); - hugr.apply_rewrite(InlineCall::new(call1.node()))?; + hugr.apply_patch(InlineCall::new(call1.node()))?; assert_eq!( hugr.output_neighbours(inner.node()).collect::>(), diff --git a/hugr-core/src/hugr/rewrite/inline_dfg.rs b/hugr-core/src/hugr/patch/inline_dfg.rs similarity index 94% rename from hugr-core/src/hugr/rewrite/inline_dfg.rs rename to hugr-core/src/hugr/patch/inline_dfg.rs index a8a09e0cc..c7356f8e0 100644 --- a/hugr-core/src/hugr/rewrite/inline_dfg.rs +++ b/hugr-core/src/hugr/patch/inline_dfg.rs @@ -2,7 +2,7 @@ //! of the DFG except Input+Output into the DFG's parent, //! and deleting the DFG along with its Input + Output -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use crate::ops::handle::{DfgID, NodeHandle}; use crate::{IncomingPort, Node, OutgoingPort, PortIndex}; @@ -21,12 +21,10 @@ pub enum InlineDFGError { NoParent, } -impl Rewrite for InlineDFG { - /// Returns the removed nodes: the DFG, and its Input and Output children. - type ApplyResult = [Node; 3]; +impl PatchVerification for InlineDFG { type Error = InlineDFGError; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = Node; fn verify(&self, h: &impl crate::HugrView) -> Result<(), Self::Error> { let n = self.0.node(); @@ -39,7 +37,21 @@ impl Rewrite for InlineDFG { Ok(()) } - fn apply(self, h: &mut impl crate::hugr::HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + [self.0.node()].into_iter() + } +} + +impl PatchHugrMut for InlineDFG { + /// The removed nodes: the DFG, and its Input and Output children. + type Outcome = [Node; 3]; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( + self, + h: &mut impl crate::hugr::HugrMut, + ) -> Result { self.verify(h)?; let n = self.0.node(); let (oth_in, oth_out) = { @@ -120,10 +132,6 @@ impl Rewrite for InlineDFG { h.remove_node(n); Ok([n, input, output]) } - - fn invalidation_set(&self) -> impl Iterator { - [self.0.node()].into_iter() - } } #[cfg(test)] @@ -137,8 +145,6 @@ mod test { SubContainer, }; use crate::extension::prelude::qb_t; - use crate::extension::ExtensionSet; - use crate::hugr::rewrite::inline_dfg::InlineDFGError; use crate::hugr::HugrMut; use crate::ops::handle::{DfgID, NodeHandle}; use crate::ops::{OpType, Value}; @@ -167,6 +173,8 @@ mod test { #[case(true)] #[case(false)] fn inline_add_load_const(#[case] nonlocal: bool) -> Result<(), Box> { + use crate::hugr::patch::inline_dfg::InlineDFGError; + let int_ty = &int_types::INT_TYPES[6]; let mut outer = DFGBuilder::new(inout_sig(vec![int_ty.clone(); 2], vec![int_ty.clone()]))?; @@ -208,13 +216,13 @@ mod test { // Check we can't inline the outer DFG let mut h = outer.clone(); assert_eq!( - h.apply_rewrite(InlineDFG(DfgID::from(h.root()))), + h.apply_patch(InlineDFG(DfgID::from(h.root()))), Err(InlineDFGError::NoParent) ); assert_eq!(h, outer); // unchanged } - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; assert_eq!(outer.nodes().count(), 7); assert_eq!(find_dfgs(&outer), vec![outer.root()]); @@ -270,7 +278,7 @@ mod test { ] ); - h.apply_rewrite(InlineDFG(*swap.handle()))?; + h.apply_patch(InlineDFG(*swap.handle()))?; assert_eq!(find_dfgs(&h), vec![h.root()]); assert_eq!(h.nodes().count(), 5); // Dfg+I+O let mut ops = extension_ops(&h); @@ -326,12 +334,8 @@ mod test { .add_dataflow_op(test_quantum_extension::measure(), r.outputs())? .outputs_arr(); // Node using the boolean. Here we just select between two empty computations. - let mut if_n = inner.conditional_builder_exts( - ([type_row![], type_row![]], b), - [], - type_row![], - ExtensionSet::new(), - )?; + let mut if_n = + inner.conditional_builder(([type_row![], type_row![]], b), [], type_row![])?; if_n.case_builder(0)?.finish_with_outputs([])?; if_n.case_builder(1)?.finish_with_outputs([])?; let if_n = if_n.finish_sub_container()?; @@ -346,7 +350,7 @@ mod test { )?; let mut outer = outer.finish_hugr_with_outputs(cx.outputs())?; - outer.apply_rewrite(InlineDFG(*inner.handle()))?; + outer.apply_patch(InlineDFG(*inner.handle()))?; outer.validate()?; let order_neighbours = |n, d| { let p = outer.get_optype(n).other_port(d).unwrap(); diff --git a/hugr-core/src/hugr/rewrite/insert_identity.rs b/hugr-core/src/hugr/patch/insert_identity.rs similarity index 82% rename from hugr-core/src/hugr/rewrite/insert_identity.rs rename to hugr-core/src/hugr/patch/insert_identity.rs index 2114be8fd..c1f959ccd 100644 --- a/hugr-core/src/hugr/rewrite/insert_identity.rs +++ b/hugr-core/src/hugr/patch/insert_identity.rs @@ -2,6 +2,7 @@ use std::iter; +use crate::core::HugrNode; use crate::extension::prelude::Noop; use crate::hugr::{HugrMut, Node}; use crate::ops::{OpTag, OpTrait}; @@ -9,22 +10,22 @@ use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{HugrView, IncomingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; use thiserror::Error; /// Specification of a identity-insertion operation. #[derive(Debug, Clone)] -pub struct IdentityInsertion { +pub struct IdentityInsertion { /// The node following the identity to be inserted. - pub post_node: Node, + pub post_node: N, /// The port following the identity to be inserted. pub post_port: IncomingPort, } -impl IdentityInsertion { +impl IdentityInsertion { /// Create a new [`IdentityInsertion`] specification. - pub fn new(post_node: Node, post_port: IncomingPort) -> Self { + pub fn new(post_node: N, post_port: IncomingPort) -> Self { Self { post_node, post_port, @@ -47,11 +48,10 @@ pub enum IdentityInsertionError { InvalidPortKind(Option), } -impl Rewrite for IdentityInsertion { +impl PatchVerification for IdentityInsertion { type Error = IdentityInsertionError; - /// The inserted node. - type ApplyResult = Node; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = N; + fn verify(&self, _h: &impl HugrView) -> Result<(), IdentityInsertionError> { /* Assumptions: @@ -65,7 +65,23 @@ impl Rewrite for IdentityInsertion { unimplemented!() } - fn apply(self, h: &mut impl HugrMut) -> Result { + + #[inline] + fn invalidation_set(&self) -> impl Iterator { + iter::once(self.post_node) + } +} + +impl PatchHugrMut for IdentityInsertion { + /// The inserted node. + type Outcome = N; + + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { let kind = h.get_optype(self.post_node).port_kind(self.post_port); let Some(EdgeKind::Value(ty)) = kind else { return Err(IdentityInsertionError::InvalidPortKind(kind)); @@ -88,11 +104,6 @@ impl Rewrite for IdentityInsertion { h.connect(new_node, 0, self.post_node, self.post_port); Ok(new_node) } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - iter::once(self.post_node) - } } #[cfg(test)] @@ -107,7 +118,7 @@ mod tests { fn correct_insertion(dfg_hugr: Hugr) { let mut h = dfg_hugr; - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); let final_node = h .input_neighbours(h.get_io(h.root()).unwrap()[1]) @@ -118,9 +129,9 @@ mod tests { let rw = IdentityInsertion::new(final_node, final_node_port); - let noop_node = h.apply_rewrite(rw).unwrap(); + let noop_node = h.apply_patch(rw).unwrap(); - assert_eq!(h.node_count(), 7); + assert_eq!(h.num_nodes(), 7); let noop: Noop = h.get_optype(noop_node).cast().unwrap(); diff --git a/hugr-core/src/hugr/rewrite/outline_cfg.rs b/hugr-core/src/hugr/patch/outline_cfg.rs similarity index 86% rename from hugr-core/src/hugr/rewrite/outline_cfg.rs rename to hugr-core/src/hugr/patch/outline_cfg.rs index 7294bfcad..b9cafed9e 100644 --- a/hugr-core/src/hugr/rewrite/outline_cfg.rs +++ b/hugr-core/src/hugr/patch/outline_cfg.rs @@ -1,23 +1,21 @@ -//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection of an existing CFG +//! Rewrite for inserting a CFG-node into the hierarchy containing a subsection +//! of an existing CFG use std::collections::HashSet; use itertools::Itertools; use thiserror::Error; use crate::builder::{BlockBuilder, Container, Dataflow, SubContainer}; -use crate::extension::ExtensionSet; -use crate::hugr::internal::HugrMutInternals; -use crate::hugr::rewrite::Rewrite; -use crate::hugr::views::sibling::SiblingMut; use crate::hugr::{HugrMut, HugrView}; use crate::ops; use crate::ops::controlflow::BasicBlock; -use crate::ops::dataflow::DataflowOpTrait; -use crate::ops::handle::{BasicBlockID, CfgID, NodeHandle}; +use crate::ops::handle::NodeHandle; use crate::ops::{DataflowBlock, OpType}; use crate::PortIndex; use crate::{type_row, Node}; +use super::{PatchHugrMut, PatchVerification}; + /// Moves part of a Control-flow Sibling Graph into a new CFG-node /// that is the only child of a new Basic Block in the original CSG. pub struct OutlineCfg { @@ -33,12 +31,11 @@ impl OutlineCfg { } /// Compute the entry and exit nodes of the CFG which contains - /// [`self.blocks`], along with the output neighbour its parent graph and - /// the combined extension_deltas of all of the blocks. - fn compute_entry_exit_outside_extensions( + /// [`self.blocks`], along with the output neighbour its parent graph. + fn compute_entry_exit( &self, h: &impl HugrView, - ) -> Result<(Node, Node, Node, ExtensionSet), OutlineCfgError> { + ) -> Result<(Node, Node, Node), OutlineCfgError> { let cfg_n = match self .blocks .iter() @@ -50,13 +47,12 @@ impl OutlineCfg { _ => return Err(OutlineCfgError::NotSiblings), }; let o = h.get_optype(cfg_n); - let OpType::CFG(o) = o else { + let OpType::CFG(_) = o else { return Err(OutlineCfgError::ParentNotCfg(cfg_n, o.clone())); }; let cfg_entry = h.children(cfg_n).next().unwrap(); let mut entry = None; let mut exit_succ = None; - let mut extension_delta = ExtensionSet::new(); for &n in self.blocks.iter() { if n == cfg_entry || h.input_neighbours(n) @@ -71,7 +67,6 @@ impl OutlineCfg { } } } - extension_delta = extension_delta.union(o.signature().runtime_reqs.clone()); let external_succs = h.output_neighbours(n).filter(|s| !self.blocks.contains(s)); match external_succs.at_most_one() { Ok(None) => (), // No external successors @@ -87,28 +82,38 @@ impl OutlineCfg { }; } match (entry, exit_succ) { - (Some(e), Some((x, o))) => Ok((e, x, o, extension_delta)), + (Some(e), Some((x, o))) => Ok((e, x, o)), (None, _) => Err(OutlineCfgError::NoEntryNode), (_, None) => Err(OutlineCfgError::NoExitNode), } } } -impl Rewrite for OutlineCfg { +impl PatchVerification for OutlineCfg { type Error = OutlineCfgError; + type Node = Node; + fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { + self.compute_entry_exit(h)?; + Ok(()) + } + + fn invalidation_set(&self) -> impl Iterator { + self.blocks.iter().copied() + } +} + +impl PatchHugrMut for OutlineCfg { /// The newly-created basic block, and the [CFG] node inside it /// /// [CFG]: OpType::CFG - type ApplyResult = (Node, Node); + type Outcome = [Node; 2]; const UNCHANGED_ON_FAILURE: bool = true; - fn verify(&self, h: &impl HugrView) -> Result<(), OutlineCfgError> { - self.compute_entry_exit_outside_extensions(h)?; - Ok(()) - } - fn apply(self, h: &mut impl HugrMut) -> Result<(Node, Node), OutlineCfgError> { - let (entry, exit, outside, extension_delta) = - self.compute_entry_exit_outside_extensions(h)?; + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result<[Node; 2], OutlineCfgError> { + let (entry, exit, outside) = self.compute_entry_exit(h)?; // 1. Compute signature // These panic()s only happen if the Hugr would not have passed validate() let OpType::DataflowBlock(DataflowBlock { inputs, .. }) = h.get_optype(entry) else { @@ -125,17 +130,10 @@ impl Rewrite for OutlineCfg { // 2. new_block contains input node, sub-cfg, exit node all connected let (new_block, cfg_node) = { - let mut new_block_bldr = BlockBuilder::new_exts( - inputs.clone(), - vec![type_row![]], - outputs.clone(), - extension_delta.clone(), - ) - .unwrap(); + let mut new_block_bldr = + BlockBuilder::new(inputs.clone(), vec![type_row![]], outputs.clone()).unwrap(); let wires_in = inputs.iter().cloned().zip(new_block_bldr.input_wires()); - let cfg = new_block_bldr - .cfg_builder_exts(wires_in, outputs, extension_delta) - .unwrap(); + let cfg = new_block_bldr.cfg_builder(wires_in, outputs).unwrap(); let cfg = cfg.finish_sub_container().unwrap(); let unit_sum = new_block_bldr.add_constant(ops::Value::unary_unit_sum()); let pred_wire = new_block_bldr.load_const(&unit_sum); @@ -185,8 +183,19 @@ impl Rewrite for OutlineCfg { let inner_exit = { // These operations do not fit within any CSG/SiblingMut // so we need to access the Hugr directly. - let h = h.hugr_mut(); - let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // + // TODO: This is a temporary hack that won't be needed once Hugr Root Pointers get implemented. + // The commented line below are the correct ones, but they don't work yet. + // https://github.com/CQCL/hugr/issues/2029 + let hierarchy = h.hierarchy(); + let inner_exit = hierarchy + .children(h.to_portgraph_node(cfg_node)) + .exactly_one() + .ok() + .unwrap(); + let inner_exit = h.from_portgraph_node(inner_exit); + //let inner_exit = h.children(cfg_node).exactly_one().ok().unwrap(); + // Entry node must be first h.move_before_sibling(entry, inner_exit); // And remaining nodes @@ -200,18 +209,9 @@ impl Rewrite for OutlineCfg { }; // 4(b). Reconnect exit edge to the new exit node within the inner CFG - // Use nested SiblingMut's in case the outer `h` is only a SiblingMut itself. - let mut in_bb_view: SiblingMut<'_, BasicBlockID> = - SiblingMut::try_new(h, new_block).unwrap(); - let mut in_cfg_view: SiblingMut<'_, CfgID> = - SiblingMut::try_new(&mut in_bb_view, cfg_node).unwrap(); - in_cfg_view.connect(exit, exit_port, inner_exit, 0); - - Ok((new_block, cfg_node)) - } + h.connect(exit, exit_port, inner_exit, 0); - fn invalidation_set(&self) -> impl Iterator { - self.blocks.iter().copied() + Ok([new_block, cfg_node]) } } @@ -252,10 +252,9 @@ mod test { HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::usize_t; - use crate::hugr::views::sibling::SiblingMut; use crate::hugr::HugrMut; use crate::ops::constant::Value; - use crate::ops::handle::{BasicBlockID, CfgID, ConstID, NodeHandle}; + use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; use crate::types::Signature; use crate::{Hugr, HugrView, Node}; use cool_asserts::assert_matches; @@ -357,22 +356,22 @@ mod test { } = cond_then_loop_cfg; let backup = h.clone(); - let r = h.apply_rewrite(OutlineCfg::new([tail])); + let r = h.apply_patch(OutlineCfg::new([tail])); assert_matches!(r, Err(OutlineCfgError::MultipleExitEdges(_, _))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right])); assert_matches!(r, Err(OutlineCfgError::MultipleExitNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from_iter([left, right]))); assert_eq!(h, backup); - let r = h.apply_rewrite(OutlineCfg::new([left, right, merge])); + let r = h.apply_patch(OutlineCfg::new([left, right, merge])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([left, right]))); assert_eq!(h, backup); // The entry node implicitly has an extra incoming edge - let r = h.apply_rewrite(OutlineCfg::new([entry, left, right, merge, head])); + let r = h.apply_patch(OutlineCfg::new([entry, left, right, merge, head])); assert_matches!(r, Err(OutlineCfgError::MultipleEntryNodes(a,b)) => assert_eq!(HashSet::from([a,b]), HashSet::from([entry, head]))); assert_eq!(h, backup); @@ -457,11 +456,7 @@ mod test { h.output_neighbours(tail).collect::>(), HashSet::from([head, exit_node]) ); - outline_cfg_check_parents( - &mut SiblingMut::<'_, CfgID>::try_new(&mut h, cfg).unwrap(), - cfg, - vec![head, tail], - ); + outline_cfg_check_parents(&mut h, cfg, vec![head, tail]); h.validate().unwrap(); } @@ -491,19 +486,20 @@ mod test { } fn outline_cfg_check_parents( - h: &mut impl HugrMut, + h: &mut impl HugrMut, cfg: Node, blocks: Vec, ) -> (Node, Node, Node) { let mut other_blocks = h.children(cfg).collect::>(); assert!(blocks.iter().all(|b| other_blocks.remove(b))); - let (new_block, new_cfg) = h.apply_rewrite(OutlineCfg::new(blocks.clone())).unwrap(); + let [new_block, new_cfg] = h.apply_patch(OutlineCfg::new(blocks.clone())).unwrap(); for n in other_blocks { assert_eq!(h.get_parent(n), Some(cfg)) } assert_eq!(h.get_parent(new_block), Some(cfg)); assert!(h.get_optype(new_block).is_dataflow_block()); + #[allow(deprecated)] let b = h.base_hugr(); // To cope with `h` potentially being a SiblingMut assert_eq!(b.get_parent(new_cfg), Some(new_block)); for n in blocks { diff --git a/hugr-core/src/hugr/rewrite/port_types.rs b/hugr-core/src/hugr/patch/port_types.rs similarity index 100% rename from hugr-core/src/hugr/rewrite/port_types.rs rename to hugr-core/src/hugr/patch/port_types.rs diff --git a/hugr-core/src/hugr/rewrite/replace.rs b/hugr-core/src/hugr/patch/replace.rs similarity index 70% rename from hugr-core/src/hugr/rewrite/replace.rs rename to hugr-core/src/hugr/patch/replace.rs index 55c07d680..606733543 100644 --- a/hugr-core/src/hugr/rewrite/replace.rs +++ b/hugr-core/src/hugr/patch/replace.rs @@ -5,29 +5,33 @@ use std::collections::{HashMap, HashSet, VecDeque}; use itertools::Itertools; use thiserror::Error; +use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; +use crate::hugr::views::check_valid_non_root; use crate::hugr::HugrMut; use crate::ops::{OpTag, OpTrait}; use crate::types::EdgeKind; use crate::{Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort}; -use super::Rewrite; +use super::{PatchHugrMut, PatchVerification}; /// Specifies how to create a new edge. #[derive(Clone, Debug, PartialEq, Eq)] -pub struct NewEdgeSpec { - /// The source of the new edge. For [Replacement::mu_inp] and [Replacement::mu_new], this is in the - /// existing Hugr; for edges in [Replacement::mu_out] this is in the [Replacement::replacement] - pub src: Node, - /// The target of the new edge. For [Replacement::mu_inp], this is in the [Replacement::replacement]; - /// for edges in [Replacement::mu_out] and [Replacement::mu_new], this is in the existing Hugr. - pub tgt: Node, +pub struct NewEdgeSpec { + /// The source of the new edge. For [Replacement::mu_inp] and + /// [Replacement::mu_new], this is in the existing Hugr; for edges in + /// [Replacement::mu_out] this is in the [Replacement::replacement] + pub src: SrcNode, + /// The target of the new edge. For [Replacement::mu_inp], this is in the + /// [Replacement::replacement]; for edges in [Replacement::mu_out] and + /// [Replacement::mu_new], this is in the existing Hugr. + pub tgt: TgtNode, /// The kind of edge to create, and any port specifiers required pub kind: NewEdgeKind, } /// Describes an edge that should be created between two nodes already given -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum NewEdgeKind { /// An [EdgeKind::StateOrder] edge (between DFG nodes only) Order, @@ -54,40 +58,47 @@ pub enum NewEdgeKind { /// Specification of a `Replace` operation #[derive(Debug, Clone, PartialEq)] -pub struct Replacement { +pub struct Replacement { /// The nodes to remove from the existing Hugr (known as Gamma). - /// These must all have a common parent (i.e. be siblings). Called "S" in the spec. - /// Must be non-empty - otherwise there is no parent under which to place [Self::replacement], - /// and there would be no possible [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. - pub removal: Vec, - /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), whose root - /// is the same type as the root of [Self::replacement]. "G" in the spec. + /// These must all have a common parent (i.e. be siblings). Called "S" in + /// the spec. Must be non-empty - otherwise there is no parent under + /// which to place [Self::replacement], and there would be no possible + /// [Self::mu_inp], [Self::mu_out] or [Self::adoptions]. + pub removal: Vec, + /// A hugr (not necessarily valid, as it may be missing edges and/or nodes), + /// whose root is the same type as the root of [Self::replacement]. "G" + /// in the spec. pub replacement: Hugr, - /// Describes how parts of the Hugr that would otherwise be removed should instead be preserved but - /// with new parents amongst the newly-inserted nodes. This is a Map from container nodes in - /// [Self::replacement] that have no children, to container nodes that are descended from [Self::removal]. - /// The keys are the new parents for the children of the values. Note no value may be ancestor or - /// descendant of another. This is "B" in the spec; "R" is the set of descendants of [Self::removal] - /// that are not descendants of values here. - pub adoptions: HashMap, - /// Edges from nodes in the existing Hugr that are not removed ([NewEdgeSpec::src] in Gamma\R) - /// to inserted nodes ([NewEdgeSpec::tgt] in [Self::replacement]). - pub mu_inp: Vec, - /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to existing nodes not removed - /// ([NewEdgeSpec::tgt] in Gamma \ R). - pub mu_out: Vec, - /// Edges to add between existing nodes (both [NewEdgeSpec::src] and [NewEdgeSpec::tgt] in Gamma \ R). - /// For example, in cases where the source had an edge to a removed node, and the target had an - /// edge from a removed node, this would allow source to be directly connected to target. - pub mu_new: Vec, + /// Describes how parts of the Hugr that would otherwise be removed should + /// instead be preserved but with new parents amongst the newly-inserted + /// nodes. This is a Map from container nodes in [Self::replacement] + /// that have no children, to container nodes that are descended from + /// [Self::removal]. The keys are the new parents for the children of + /// the values. Note no value may be ancestor or descendant of another. + /// This is "B" in the spec; "R" is the set of descendants of + /// [Self::removal] that are not descendants of values here. + pub adoptions: HashMap, + /// Edges from nodes in the existing Hugr that are not removed + /// ([NewEdgeSpec::src] in Gamma\R) to inserted nodes + /// ([NewEdgeSpec::tgt] in [Self::replacement]). + pub mu_inp: Vec>, + /// Edges from inserted nodes ([NewEdgeSpec::src] in [Self::replacement]) to + /// existing nodes not removed ([NewEdgeSpec::tgt] in Gamma \ R). + pub mu_out: Vec>, + /// Edges to add between existing nodes (both [NewEdgeSpec::src] and + /// [NewEdgeSpec::tgt] in Gamma \ R). For example, in cases where the + /// source had an edge to a removed node, and the target had an + /// edge from a removed node, this would allow source to be directly + /// connected to target. + pub mu_new: Vec>, } -impl NewEdgeSpec { - fn check_src( +impl NewEdgeSpec { + fn check_src( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.src); let ok = match self.kind { NewEdgeKind::Order => optype.other_output() == Some(EdgeKind::StateOrder), @@ -103,13 +114,14 @@ impl NewEdgeSpec { } }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Outgoing, err_spec(self.clone()))) } - fn check_tgt( + + fn check_tgt( &self, - h: &impl HugrView, - err_spec: &NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + err_spec: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { let optype = h.get_optype(self.tgt); let ok = match self.kind { NewEdgeKind::Order => optype.other_input() == Some(EdgeKind::StateOrder), @@ -126,18 +138,20 @@ impl NewEdgeSpec { ), }; ok.then_some(()) - .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec.clone())) + .ok_or_else(|| ReplaceError::BadEdgeKind(Direction::Incoming, err_spec(self.clone()))) } +} +impl NewEdgeSpec { fn check_existing_edge( &self, - h: &impl HugrView, - legal_src_ancestors: &HashSet, - err_edge: impl Fn() -> NewEdgeSpec, - ) -> Result<(), ReplaceError> { + h: &impl HugrView, + legal_src_ancestors: &HashSet, + err_edge: impl Fn(Self) -> WhichEdgeSpec, + ) -> Result<(), ReplaceError> { if let NewEdgeKind::Static { tgt_pos, .. } | NewEdgeKind::Value { tgt_pos, .. } = self.kind { - let descends_from_legal = |mut descendant: Node| -> bool { + let descends_from_legal = |mut descendant: HostNode| -> bool { while !legal_src_ancestors.contains(&descendant) { let Some(p) = h.get_parent(descendant) else { return false; @@ -150,15 +164,18 @@ impl NewEdgeSpec { .single_linked_output(self.tgt, tgt_pos) .is_some_and(|(src_n, _)| descends_from_legal(src_n)); if !found_incoming { - return Err(ReplaceError::NoRemovedEdge(err_edge())); + return Err(ReplaceError::NoRemovedEdge(err_edge(self.clone()))); }; }; Ok(()) } } -impl Replacement { - fn check_parent(&self, h: &impl HugrView) -> Result { +impl Replacement { + fn check_parent( + &self, + h: &impl HugrView, + ) -> Result> { let parent = self .removal .iter() @@ -168,10 +185,11 @@ impl Replacement { .map_err(|ex_one| ReplaceError::MultipleParents(ex_one.flatten().collect()))? .ok_or(ReplaceError::CantReplaceRoot)?; // If no parent - // Check replacement parent is of same tag. Note we do not require exact equality - // of OpType/Signature, e.g. to ease changing of Input/Output node signatures too. + // Check replacement parent is of same tag. Note we do not require exact + // equality of OpType/Signature, e.g. to ease changing of Input/Output + // node signatures too. let removed = h.get_optype(parent).tag(); - let replacement = self.replacement.root_type().tag(); + let replacement = self.replacement.root_optype().tag(); if removed != replacement { return Err(ReplaceError::WrongRootNodeTag { removed, @@ -183,8 +201,8 @@ impl Replacement { fn get_removed_nodes( &self, - h: &impl HugrView, - ) -> Result, ReplaceError> { + h: &impl HugrView, + ) -> Result, ReplaceError> { // Check the keys of the transfer map too, the values we'll use imminently self.adoptions.keys().try_for_each(|&n| { (self.replacement.contains_node(n) @@ -193,7 +211,7 @@ impl Replacement { .then_some(()) .ok_or(ReplaceError::InvalidAdoptingParent(n)) })?; - let mut transferred: HashSet = self.adoptions.values().copied().collect(); + let mut transferred: HashSet = self.adoptions.values().copied().collect(); if transferred.len() != self.adoptions.values().len() { return Err(ReplaceError::AdopteesNotSeparateDescendants( self.adoptions @@ -221,97 +239,149 @@ impl Replacement { Ok(removed) } } -impl Rewrite for Replacement { - type Error = ReplaceError; - - /// Map from Node in replacement to corresponding Node in the result Hugr - type ApplyResult = HashMap; - const UNCHANGED_ON_FAILURE: bool = false; +impl PatchVerification for Replacement { + type Error = ReplaceError; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { + fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { self.check_parent(h)?; let removed = self.get_removed_nodes(h)?; // Edge sources... - for e in self.mu_inp.iter().chain(self.mu_new.iter()) { + for e in self.mu_inp.iter() { if !h.contains_node(e.src) || removed.contains(&e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )); } - e.check_src(h, e)?; + e.check_src(h, WhichEdgeSpec::HostToRepl)?; } - self.mu_out - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.src) { - true => e.check_src(&self.replacement, e), + for e in self.mu_new.iter() { + if !h.contains_node(e.src) || removed.contains(&e.src) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToHost(e.clone()), + )); + } + e.check_src(h, WhichEdgeSpec::HostToHost)?; + } + self.mu_out.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.src) { + true => e.check_src(&self.replacement, WhichEdgeSpec::ReplToHost), false => Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::ReplToHost(e.clone()), )), - })?; + } + })?; // Edge targets... - self.mu_inp - .iter() - .try_for_each(|e| match self.replacement.valid_non_root(e.tgt) { - true => e.check_tgt(&self.replacement, e), + self.mu_inp.iter().try_for_each(|e| { + match check_valid_non_root(&self.replacement, e.tgt) { + true => e.check_tgt(&self.replacement, WhichEdgeSpec::HostToRepl), false => Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Replacement, - e.clone(), + WhichEdgeSpec::HostToRepl(e.clone()), )), - })?; - for e in self.mu_out.iter().chain(self.mu_new.iter()) { + } + })?; + for e in self.mu_out.iter() { + if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { + return Err(ReplaceError::BadEdgeSpec( + Direction::Incoming, + WhichEdgeSpec::ReplToHost(e.clone()), + )); + } + e.check_tgt(h, WhichEdgeSpec::ReplToHost)?; + // The descendant check is to allow the case where the old edge is nonlocal + // from a part of the Hugr being moved (which may require changing source, + // depending on where the transplanted portion ends up). While this subsumes + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::ReplToHost)?; + } + for e in self.mu_new.iter() { if !h.contains_node(e.tgt) || removed.contains(&e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - e.clone(), + WhichEdgeSpec::HostToHost(e.clone()), )); } - e.check_tgt(h, e)?; + e.check_tgt(h, WhichEdgeSpec::HostToHost)?; // The descendant check is to allow the case where the old edge is nonlocal // from a part of the Hugr being moved (which may require changing source, // depending on where the transplanted portion ends up). While this subsumes - // the first "removed.contains" check, we'll keep that as a common-case fast-path. - e.check_existing_edge(h, &removed, || e.clone())?; + // the first "removed.contains" check, we'll keep that as a common-case + // fast-path. + e.check_existing_edge(h, &removed, WhichEdgeSpec::HostToHost)?; } Ok(()) } - fn apply(self, h: &mut impl HugrMut) -> Result { + fn invalidation_set(&self) -> impl Iterator { + self.removal.iter().copied() + } +} + +impl PatchHugrMut for Replacement { + /// Map from Node in replacement to corresponding Node in the result Hugr + type Outcome = HashMap; + + const UNCHANGED_ON_FAILURE: bool = false; + + fn apply_hugr_mut( + self, + h: &mut impl HugrMut, + ) -> Result { let parent = self.check_parent(h)?; // Calculate removed nodes here. (Does not include transfers, so enumerates only - // nodes we are going to remove, individually, anyway; so no *asymptotic* speed penalty) + // nodes we are going to remove, individually, anyway; so no *asymptotic* speed + // penalty) let to_remove = self.get_removed_nodes(h)?; - // 1. Add all the new nodes. Note this includes replacement.root(), which we don't want. + // 1. Add all the new nodes. Note this includes replacement.root(), which we + // don't want. // TODO what would an error here mean? e.g. malformed self.replacement?? let InsertionResult { new_root, node_map } = h.insert_hugr(parent, self.replacement); // 2. Add new edges from existing to copied nodes according to mu_in - let translate_idx = |n| node_map.get(&n).copied().ok_or(WhichHugr::Replacement); - let kept = |n| { - let keep = !to_remove.contains(&n); - keep.then_some(n).ok_or(WhichHugr::Retained) - }; - transfer_edges(h, self.mu_inp.iter(), kept, translate_idx, None)?; + let translate_idx = |n| node_map.get(&n).copied(); + let kept = |n| (!to_remove.contains(&n)).then_some(n); + transfer_edges( + h, + self.mu_inp.iter(), + kept, + translate_idx, + WhichEdgeSpec::HostToRepl, + None, + )?; // 3. Add new edges from copied to existing nodes according to mu_out, // replacing existing value/static edges incoming to targets - transfer_edges(h, self.mu_out.iter(), translate_idx, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_out.iter(), + translate_idx, + kept, + WhichEdgeSpec::ReplToHost, + Some(&to_remove), + )?; // 4. Add new edges between existing nodes according to mu_new, // replacing existing value/static edges incoming to targets. - transfer_edges(h, self.mu_new.iter(), kept, kept, Some(&to_remove))?; + transfer_edges( + h, + self.mu_new.iter(), + kept, + kept, + WhichEdgeSpec::HostToHost, + Some(&to_remove), + )?; // 5. Put newly-added copies into correct places in hierarchy // (these will be correct places after removing nodes) let mut remove_top_sibs = self.removal.iter(); - for new_node in h.children(new_root).collect::>().into_iter() { + for new_node in h.children(new_root).collect::>().into_iter() { if let Some(top_sib) = remove_top_sibs.next() { h.move_before_sibling(new_node, *top_sib); } else { @@ -336,51 +406,53 @@ impl Rewrite for Replacement { }); Ok(node_map) } - - fn invalidation_set(&self) -> impl Iterator { - self.removal.iter().copied() - } } -fn transfer_edges<'a>( - h: &mut impl HugrMut, - edges: impl Iterator, - trans_src: impl Fn(Node) -> Result, - trans_tgt: impl Fn(Node) -> Result, - legal_src_ancestors: Option<&HashSet>, -) -> Result<(), ReplaceError> { +fn transfer_edges<'a, SrcNode, TgtNode, HostNode>( + h: &mut impl HugrMut, + edges: impl Iterator>, + trans_src: impl Fn(SrcNode) -> Option, + trans_tgt: impl Fn(TgtNode) -> Option, + err_spec: impl Fn(NewEdgeSpec) -> WhichEdgeSpec, + legal_src_ancestors: Option<&HashSet>, +) -> Result<(), ReplaceError> +where + SrcNode: 'a + HugrNode, + TgtNode: 'a + HugrNode, + HostNode: 'a + HugrNode, +{ for oe in edges { + let err_spec = err_spec(oe.clone()); let e = NewEdgeSpec { // Translation can only fail for Nodes that are supposed to be in the replacement src: trans_src(oe.src) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Outgoing, h, oe.clone()))?, + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Outgoing, err_spec.clone()))?, tgt: trans_tgt(oe.tgt) - .map_err(|h| ReplaceError::BadEdgeSpec(Direction::Incoming, h, oe.clone()))?, - ..oe.clone() + .ok_or_else(|| ReplaceError::BadEdgeSpec(Direction::Incoming, err_spec.clone()))?, + kind: oe.kind, }; - if !h.valid_node(e.src) { + if !h.contains_node(e.src) { return Err(ReplaceError::BadEdgeSpec( Direction::Outgoing, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); } - if !h.valid_node(e.tgt) { + if !h.contains_node(e.tgt) { return Err(ReplaceError::BadEdgeSpec( Direction::Incoming, - WhichHugr::Retained, - oe.clone(), + err_spec.clone(), )); }; - e.check_src(h, oe)?; - e.check_tgt(h, oe)?; + let err_spec = |_| err_spec.clone(); + e.check_src(h, err_spec)?; + e.check_tgt(h, err_spec)?; match e.kind { NewEdgeKind::Order => { h.add_other_edge(e.src, e.tgt); } NewEdgeKind::Value { src_pos, tgt_pos } | NewEdgeKind::Static { src_pos, tgt_pos } => { if let Some(legal_src_ancestors) = legal_src_ancestors { - e.check_existing_edge(h, legal_src_ancestors, || oe.clone())?; + e.check_existing_edge(h, legal_src_ancestors, err_spec)?; h.disconnect(e.tgt, tgt_pos); } h.connect(e.src, src_pos, e.tgt, tgt_pos); @@ -394,14 +466,14 @@ fn transfer_edges<'a>( /// Error in a [`Replacement`] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[non_exhaustive] -pub enum ReplaceError { +pub enum ReplaceError { /// The node(s) to replace had no parent i.e. were root(s). // (Perhaps if there is only one node to replace we should be able to?) #[error("Cannot replace the root node of the Hugr")] CantReplaceRoot, /// The nodes to replace did not have a unique common parent #[error("Removed nodes had different parents {0:?}")] - MultipleParents(Vec), + MultipleParents(Vec), /// Replacement root node had different tag from parent of removed nodes #[error("Expected replacement root with tag {removed} but found {replacement}")] WrongRootNodeTag { @@ -410,40 +482,47 @@ pub enum ReplaceError { /// The tag of the root in the replacement Hugr replacement: OpTag, }, - /// Keys in [Replacement::adoptions] were not valid container nodes in [Replacement::replacement] + /// Keys in [Replacement::adoptions] were not valid container nodes in + /// [Replacement::replacement] #[error("Node {0} was not an empty container node in the replacement")] InvalidAdoptingParent(Node), - /// Some values in [Replacement::adoptions] were either descendants of other values, or not - /// descendants of the [Replacement::removal]. The nodes are indicated on a best-effort basis. + /// Some values in [Replacement::adoptions] were either descendants of other + /// values, or not descendants of the [Replacement::removal]. The nodes + /// are indicated on a best-effort basis. #[error("Nodes not free to be moved into new locations: {0:?}")] - AdopteesNotSeparateDescendants(Vec), + AdopteesNotSeparateDescendants(Vec), /// A node at one end of a [NewEdgeSpec] was not found - #[error("{0:?} end of edge {2:?} not found in {1}")] - BadEdgeSpec(Direction, WhichHugr, NewEdgeSpec), - /// The target of the edge was found, but there was no existing edge to replace + #[error("{0:?} end of edge {1:?} not found in {which_hugr}", which_hugr = .1.which_hugr(*.0))] + BadEdgeSpec(Direction, WhichEdgeSpec), + /// The target of the edge was found, but there was no existing edge to + /// replace #[error("Target of edge {0:?} did not have a corresponding incoming edge being removed")] - NoRemovedEdge(NewEdgeSpec), + NoRemovedEdge(WhichEdgeSpec), /// The [NewEdgeKind] was not applicable for the source/target node(s) #[error("The edge kind was not applicable to the {0:?} node: {1:?}")] - BadEdgeKind(Direction, NewEdgeSpec), + BadEdgeKind(Direction, WhichEdgeSpec), } -/// A Hugr or portion thereof that is part of the [Replacement] +/// The three kinds of [NewEdgeSpec] that may appear in a [ReplaceError] #[derive(Clone, Debug, PartialEq, Eq)] -pub enum WhichHugr { - /// The newly-inserted nodes, i.e. the [Replacement::replacement] - Replacement, - /// Nodes in the existing Hugr that are not [Replacement::removal] - /// (or are on the RHS of an entry in [Replacement::adoptions]) - Retained, +pub enum WhichEdgeSpec { + /// An edge from the host Hugr into the replacement, i.e. + /// [Replacement::mu_inp] + HostToRepl(NewEdgeSpec), + /// An edge from the replacement to the host, i.e. [Replacement::mu_out] + ReplToHost(NewEdgeSpec), + /// An edge between two nodes in the host (bypassing the replacement), + /// i.e. [Replacement::mu_new] + HostToHost(NewEdgeSpec), } -impl std::fmt::Display for WhichHugr { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(match self { - Self::Replacement => "replacement Hugr", - Self::Retained => "retained portion of Hugr", - }) +impl WhichEdgeSpec { + fn which_hugr(&self, d: Direction) -> &str { + match (self, d) { + (Self::HostToRepl(_), Direction::Incoming) + | (Self::ReplToHost(_), Direction::Outgoing) => "replacement Hugr", + _ => "retained portion of Hugr", + } } } @@ -461,8 +540,8 @@ mod test { use crate::extension::prelude::{bool_t, usize_t}; use crate::extension::{ExtensionRegistry, PRELUDE}; use crate::hugr::internal::HugrMutInternals; - use crate::hugr::rewrite::replace::WhichHugr; - use crate::hugr::{HugrMut, Rewrite}; + use crate::hugr::patch::PatchVerification; + use crate::hugr::{HugrMut, Patch}; use crate::ops::custom::ExtensionOp; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{BasicBlockID, ConstID, NodeHandle}; @@ -472,7 +551,7 @@ mod test { use crate::utils::{depth, test_quantum_extension}; use crate::{type_row, Direction, Extension, Hugr, HugrView, OutgoingPort}; - use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement}; + use super::{NewEdgeKind, NewEdgeSpec, ReplaceError, Replacement, WhichEdgeSpec}; #[test] #[ignore] // FIXME: This needs a rewrite now that `pop` returns an optional value -.-' @@ -519,7 +598,8 @@ mod test { } // Replacement: one BB with two DFGs inside. - // Use Hugr rather than Builder because DFGs must be empty (not even Input/Output). + // Use Hugr rather than Builder because it must be empty (not even + // Input/Output). let mut replacement = Hugr::new(ops::CFG { signature: Signature::new_endo(just_list.clone()), }); @@ -529,21 +609,18 @@ mod test { inputs: vec![listy.clone()].into(), sum_rows: vec![type_row![]], other_outputs: vec![listy.clone()].into(), - extension_delta: list::EXTENSION_ID.into(), }, ); let r_df1 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(vec![listy.clone()], simple_unary_plus(intermed.clone())), }, ); let r_df2 = replacement.add_node_with_parent( r_bb, DFG { - signature: Signature::new(intermed, simple_unary_plus(just_list.clone())) - .with_extension_delta(list::EXTENSION_ID), + signature: Signature::new(intermed, simple_unary_plus(just_list.clone())), }, ); [0, 1] @@ -568,7 +645,7 @@ mod test { replacement.connect(r_df2, 1, out, 1); } - h.apply_rewrite(Replacement { + h.apply_patch(Replacement { removal: vec![entry.node(), bb2.node()], replacement, adoptions: HashMap::from([(r_df1.node(), entry.node()), (r_df2.node(), bb2.node())]), @@ -626,7 +703,7 @@ mod test { }, op_sig.input() ); - h.simple_entry_builder_exts(op_sig.output.clone(), 1, op_sig.runtime_reqs.clone())? + h.simple_entry_builder(op_sig.output.clone(), 1)? } else { h.simple_block_builder(op_sig.into_owned(), 1)? }; @@ -653,25 +730,20 @@ mod test { ext.add_op("baz".into(), "".to_string(), utou.clone(), extension_ref) .unwrap(); }); - let ext_name = ext.name().clone(); let foo = ext.instantiate_extension_op("foo", []).unwrap(); let bar = ext.instantiate_extension_op("bar", []).unwrap(); let baz = ext.instantiate_extension_op("baz", []).unwrap(); let mut registry = test_quantum_extension::REG.clone(); registry.register(ext).unwrap(); - let mut h = DFGBuilder::new( - Signature::new(vec![usize_t(), bool_t()], vec![usize_t()]) - .with_extension_delta(ext_name.clone()), - ) - .unwrap(); + let mut h = + DFGBuilder::new(Signature::new(vec![usize_t(), bool_t()], vec![usize_t()])).unwrap(); let [i, b] = h.input_wires_arr(); let mut cond = h - .conditional_builder_exts( + .conditional_builder( (vec![type_row![]; 2], b), [(usize_t(), i)], vec![usize_t()].into(), - ext_name.clone(), ) .unwrap(); let mut case1 = cond.case_builder(0).unwrap(); @@ -679,12 +751,7 @@ mod test { let case1 = case1.finish_with_outputs(foo.outputs()).unwrap().node(); let mut case2 = cond.case_builder(1).unwrap(); let bar = case2.add_dataflow_op(bar, case2.input_wires()).unwrap(); - let mut baz_dfg = case2 - .dfg_builder( - utou.clone().with_extension_delta(ext_name.clone()), - bar.outputs(), - ) - .unwrap(); + let mut baz_dfg = case2.dfg_builder(utou.clone(), bar.outputs()).unwrap(); let baz = baz_dfg.add_dataflow_op(baz, baz_dfg.input_wires()).unwrap(); let baz_dfg = baz_dfg.finish_with_outputs(baz.outputs()).unwrap(); let case2 = case2.finish_with_outputs(baz_dfg.outputs()).unwrap().node(); @@ -731,8 +798,7 @@ mod test { // Root node type needs to be that of common parent of the removed nodes: let mut rep2 = rep.clone(); rep2.replacement - .replace_op(rep2.replacement.root(), h.root_type().clone()) - .unwrap(); + .replace_op(rep2.replacement.root(), h.root_optype().clone()); assert_eq!( check_same_errors(rep2), ReplaceError::WrongRootNodeTag { @@ -788,7 +854,10 @@ mod test { mu_inp: vec![edge_from_removed.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Retained, edge_from_removed) + ReplaceError::BadEdgeSpec( + Direction::Outgoing, + WhichEdgeSpec::HostToRepl(edge_from_removed) + ) ); let bad_out_edge = NewEdgeSpec { src: h.nodes().max().unwrap(), // not valid in replacement @@ -800,7 +869,7 @@ mod test { mu_out: vec![bad_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichHugr::Replacement, bad_out_edge) + ReplaceError::BadEdgeSpec(Direction::Outgoing, WhichEdgeSpec::ReplToHost(bad_out_edge),) ); let bad_order_edge = NewEdgeSpec { src: cond.node(), @@ -812,7 +881,7 @@ mod test { mu_new: vec![bad_order_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, bad_order_edge) + ReplaceError::BadEdgeKind(_, e) => assert_eq!(e, WhichEdgeSpec::HostToHost(bad_order_edge)) ); let op = OutgoingPort::from(0); let (tgt, ip) = h.linked_inputs(cond.node(), op).next().unwrap(); @@ -829,7 +898,7 @@ mod test { mu_out: vec![new_out_edge.clone()], ..rep.clone() }), - ReplaceError::BadEdgeKind(Direction::Outgoing, new_out_edge) + ReplaceError::BadEdgeKind(Direction::Outgoing, WhichEdgeSpec::ReplToHost(new_out_edge)) ); } } diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/patch/simple_replace.rs similarity index 90% rename from hugr-core/src/hugr/rewrite/simple_replace.rs rename to hugr-core/src/hugr/patch/simple_replace.rs index cf7f2922a..245a3cdc0 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/patch/simple_replace.rs @@ -4,9 +4,8 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; -pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; -use crate::hugr::{HugrMut, HugrView, Rewrite}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::{OpTag, OpTrait, OpType}; use crate::{Hugr, IncomingPort, Node, OutgoingPort}; @@ -15,7 +14,7 @@ use itertools::Itertools; use thiserror::Error; use super::inline_dfg::InlineDFGError; -use super::{BoundaryPort, HostPort, ReplacementPort}; +use super::{BoundaryPort, HostPort, PatchHugrMut, PatchVerification, ReplacementPort}; /// Specification of a simple replacement operation. /// @@ -29,7 +28,8 @@ pub struct SimpleReplacement { /// A hugr with DFG root (consisting of replacement nodes). replacement: Hugr, /// A map from (target ports of edges from the Input node of `replacement`) - /// to (target ports of edges from nodes not in `subgraph` to nodes in `subgraph`). + /// to (target ports of edges from nodes not in `subgraph` to nodes in + /// `subgraph`). nu_inp: HashMap<(Node, IncomingPort), (HostNode, IncomingPort)>, /// A map from (target ports of edges from nodes in `subgraph` to nodes not /// in `subgraph`) to (input ports of the Output node of `replacement`). @@ -126,7 +126,8 @@ impl SimpleReplacement { }) .map( |(&(rep_inp_node, rep_inp_port), (rem_inp_node, rem_inp_port))| { - // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, n_inp_port) + // add edge from predecessor of (s_inp_node, s_inp_port) to (new_inp_node, + // n_inp_port) let (rem_inp_pred_node, rem_inp_pred_port) = host .single_linked_output(*rem_inp_node, *rem_inp_port) .unwrap(); @@ -159,8 +160,9 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p] such that the predecessor of q is not an Input port, - // there will be an edge from (the new copy of) the predecessor of q to p. + // For each q = self.nu_out[p] such that the predecessor of q is not an Input + // port, there will be an edge from (the new copy of) the predecessor of + // q to p. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), rep_out_port)| { @@ -197,8 +199,8 @@ impl SimpleReplacement { > + 'a { let [_, replacement_output_node] = self.get_replacement_io().expect("replacement is a DFG"); - // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the predecessor of p0 - // to p1. + // For each q = self.nu_out[p1], p0 = self.nu_inp[q], add an edge from the + // predecessor of p0 to p1. self.nu_out .iter() .filter_map(move |(&(rem_out_node, rem_out_port), &rep_out_port)| { @@ -246,8 +248,9 @@ impl SimpleReplacement { /// Get all edges that the replacement would add between `host` and /// `self.replacement`. /// - /// This is equivalent to chaining the results of [`Self::incoming_boundary`], - /// [`Self::outgoing_boundary`], and [`Self::host_to_host_boundary`]. + /// This is equivalent to chaining the results of + /// [`Self::incoming_boundary`], [`Self::outgoing_boundary`], and + /// [`Self::host_to_host_boundary`]. /// /// This panics if self.replacement is not a DFG. pub fn all_boundary_edges<'a>( @@ -275,16 +278,35 @@ impl SimpleReplacement { } } -impl Rewrite for SimpleReplacement { +impl PatchVerification for SimpleReplacement { type Error = SimpleReplacementError; - type ApplyResult = Vec<(Node, OpType)>; - const UNCHANGED_ON_FAILURE: bool = true; + type Node = HostNode; - fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { + fn verify(&self, h: &impl HugrView) -> Result<(), SimpleReplacementError> { self.is_valid_rewrite(h) } - fn apply(self, h: &mut impl HugrMut) -> Result { + #[inline] + fn invalidation_set(&self) -> impl Iterator { + let subcirc = self.subgraph.nodes().iter().copied(); + let out_neighs = self.nu_out.keys().map(|key| key.0); + subcirc.chain(out_neighs) + } +} + +/// Result of applying a [`SimpleReplacement`]. +pub struct Outcome { + /// Map from Node in replacement to corresponding Node in the result Hugr + pub node_map: HashMap, + /// Nodes removed from the result Hugr and their weights + pub removed_nodes: HashMap, +} + +impl PatchHugrMut for SimpleReplacement { + type Outcome = Outcome; + const UNCHANGED_ON_FAILURE: bool = true; + + fn apply_hugr_mut(self, h: &mut impl HugrMut) -> Result { self.is_valid_rewrite(h)?; let parent = self.subgraph.get_parent(h); @@ -305,13 +327,10 @@ impl Rewrite for SimpleReplacement { } = self; // 2. Insert the replacement as a whole. - let InsertionResult { - new_root, - node_map: index_map, - } = h.insert_hugr(parent, replacement); + let InsertionResult { new_root, node_map } = h.insert_hugr(parent, replacement); // remove the Input and Output nodes from the replacement graph - let replace_children = h.children(new_root).collect::>(); + let replace_children = h.children(new_root).collect::>(); for &io in &replace_children[..2] { h.remove_node(io); } @@ -324,24 +343,22 @@ impl Rewrite for SimpleReplacement { // 3. Insert all boundary edges. for (src, tgt) in boundary_edges { - let (src_node, src_port) = src.map_replacement(&index_map); - let (tgt_node, tgt_port) = tgt.map_replacement(&index_map); + let (src_node, src_port) = src.map_replacement(&node_map); + let (tgt_node, tgt_port) = tgt.map_replacement(&node_map); h.connect(src_node, src_port, tgt_node, tgt_port); } // 4. Remove all nodes in subgraph and edges between them. - Ok(subgraph + let removed_nodes = subgraph .nodes() .iter() .map(|&node| (node, h.remove_node(node))) - .collect()) - } + .collect(); - #[inline] - fn invalidation_set(&self) -> impl Iterator { - let subcirc = self.subgraph.nodes().iter().copied(); - let out_neighs = self.nu_out.keys().map(|key| key.0); - subcirc.chain(out_neighs) + Ok(Outcome { + node_map, + removed_nodes, + }) } } @@ -364,9 +381,10 @@ pub enum SimpleReplacementError { } #[cfg(test)] -pub(in crate::hugr::rewrite) mod test { +pub(in crate::hugr::patch) mod test { use itertools::Itertools; use rstest::{fixture, rstest}; + use std::collections::{HashMap, HashSet}; use crate::builder::test::n_identity; @@ -375,9 +393,9 @@ pub(in crate::hugr::rewrite) mod test { DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::{bool_t, qb_t}; - use crate::extension::ExtensionSet; + use crate::hugr::patch::PatchVerification; use crate::hugr::views::{HugrView, SiblingSubgraph}; - use crate::hugr::{Hugr, HugrMut, Rewrite}; + use crate::hugr::{Hugr, HugrMut, Patch}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::NodeHandle; use crate::ops::OpTag; @@ -385,7 +403,7 @@ pub(in crate::hugr::rewrite) mod test { use crate::std_extensions::logic::test::and_op; use crate::std_extensions::logic::LogicOp; use crate::types::{Signature, Type}; - use crate::utils::test_quantum_extension::{cx_gate, h_gate, EXTENSION_ID}; + use crate::utils::test_quantum_extension::{cx_gate, h_gate}; use crate::{IncomingPort, Node}; use super::SimpleReplacement; @@ -402,12 +420,8 @@ pub(in crate::hugr::rewrite) mod test { fn make_hugr() -> Result { let mut module_builder = ModuleBuilder::new(); let _f_id = { - let just_q: ExtensionSet = EXTENSION_ID.into(); - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(just_q.clone()), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]))?; let [qb0, qb1, qb2] = func_builder.input_wires_arr(); @@ -433,7 +447,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn simple_hugr() -> Hugr { + pub(in crate::hugr::patch) fn simple_hugr() -> Hugr { make_hugr().unwrap() } /// Creates a hugr with a DFG root like the following: @@ -443,7 +457,7 @@ pub(in crate::hugr::rewrite) mod test { /// ┤ H ├┤ X ├ /// └───┘└───┘ fn make_dfg_hugr() -> Result { - let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]).with_prelude())?; + let mut dfg_builder = DFGBuilder::new(endo_sig(vec![qb_t(), qb_t()]))?; let [wire0, wire1] = dfg_builder.input_wires_arr(); let wire2 = dfg_builder.add_dataflow_op(h_gate(), vec![wire0])?; let wire3 = dfg_builder.add_dataflow_op(h_gate(), vec![wire1])?; @@ -453,7 +467,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr() -> Hugr { make_dfg_hugr().unwrap() } @@ -473,7 +487,7 @@ pub(in crate::hugr::rewrite) mod test { } #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr2() -> Hugr { + pub(in crate::hugr::patch) fn dfg_hugr2() -> Hugr { make_dfg_hugr2().unwrap() } @@ -485,11 +499,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ ┌─────────┐ /// └────┤ (2) NOT ├── /// └─────────┘ - /// This can be replaced with an empty hugr coping the input to both outputs. + /// This can be replaced with an empty hugr coping the input to both + /// outputs. /// /// Returns the hugr and the nodes of the NOT gates, in order. #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_copy_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -516,11 +531,12 @@ pub(in crate::hugr::rewrite) mod test { /// └─────────┘ │ /// └───────────────── /// - /// This can be replaced with a single NOT op, coping the input to the first output. + /// This can be replaced with a single NOT op, coping the input to the first + /// output. /// /// Returns the hugr and the nodes of the NOT ops, in order. #[fixture] - pub(in crate::hugr::rewrite) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { + pub(in crate::hugr::patch) fn dfg_hugr_half_not_bools() -> (Hugr, Vec) { let mut dfg_builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t(), bool_t()])).unwrap(); let [b] = dfg_builder.input_wires_arr(); @@ -682,7 +698,7 @@ pub(in crate::hugr::rewrite) mod test { nu_inp, nu_out, }; - h.apply_rewrite(r).unwrap(); + h.apply_patch(r).unwrap(); // Expect [DFG] to be replaced with: // ┌───┐┌───┐ // ┤ H ├┤ H ├ @@ -736,7 +752,7 @@ pub(in crate::hugr::rewrite) mod test { }) .map(|p| ((output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), replacement, inputs, @@ -745,7 +761,7 @@ pub(in crate::hugr::rewrite) mod test { .unwrap(); // They should be the same, up to node indices - assert_eq!(h.edge_count(), orig.edge_count()); + assert_eq!(h.num_edges(), orig.num_edges()); } #[test] @@ -788,7 +804,7 @@ pub(in crate::hugr::rewrite) mod test { .map(|p| ((repl_output, p), p)) .collect(); - h.apply_rewrite(SimpleReplacement::new( + h.apply_patch(SimpleReplacement::new( SiblingSubgraph::try_from_nodes(removal, &h).unwrap(), repl, inputs, @@ -797,11 +813,11 @@ pub(in crate::hugr::rewrite) mod test { .unwrap(); // Nothing changed - assert_eq!(h.node_count(), orig.node_count()); + assert_eq!(h.num_nodes(), orig.num_nodes()); } - /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the input - /// directly to the outputs. + /// Remove all the NOT gates in [`dfg_hugr_copy_bools`] by connecting the + /// input directly to the outputs. /// /// https://github.com/CQCL/hugr/issues/1190 #[rstest] @@ -822,8 +838,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0, output_not_1], &hugr) .unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -836,8 +853,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -854,11 +871,11 @@ pub(in crate::hugr::rewrite) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 3); + assert_eq!(hugr.num_nodes(), 3); } - /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting the input - /// directly to the output. + /// Remove one of the NOT ops in [`dfg_hugr_half_not_bools`] by connecting + /// the input directly to the output. /// /// https://github.com/CQCL/hugr/issues/1323 #[rstest] @@ -880,8 +897,9 @@ pub(in crate::hugr::rewrite) mod test { let subgraph = SiblingSubgraph::try_from_nodes(vec![input_not, output_not_0], &hugr).unwrap(); - // A map from (target ports of edges from the Input node of `replacement`) to (target ports of - // edges from nodes not in `removal` to nodes in `removal`). + // A map from (target ports of edges from the Input node of `replacement`) to + // (target ports of edges from nodes not in `removal` to nodes in + // `removal`). let nu_inp = [ ( (repl_output, IncomingPort::from(0)), @@ -894,8 +912,8 @@ pub(in crate::hugr::rewrite) mod test { ] .into_iter() .collect(); - // A map from (target ports of edges from nodes in `removal` to nodes not in `removal`) to - // (input ports of the Output node of `replacement`). + // A map from (target ports of edges from nodes in `removal` to nodes not in + // `removal`) to (input ports of the Output node of `replacement`). let nu_out = [ ((output, IncomingPort::from(0)), IncomingPort::from(0)), ((output, IncomingPort::from(1)), IncomingPort::from(1)), @@ -912,7 +930,7 @@ pub(in crate::hugr::rewrite) mod test { rewrite.apply(&mut hugr).unwrap_or_else(|e| panic!("{e}")); assert_eq!(hugr.validate(), Ok(())); - assert_eq!(hugr.node_count(), 4); + assert_eq!(hugr.num_nodes(), 4); } #[rstest] @@ -951,17 +969,17 @@ pub(in crate::hugr::rewrite) mod test { let rewrite = SimpleReplacement::new(subgraph, replacement, nu_inp, nu_out); - assert_eq!(h.node_count(), 4); + assert_eq!(h.num_nodes(), 4); rewrite.apply(&mut h).unwrap_or_else(|e| panic!("{e}")); h.validate().unwrap_or_else(|e| panic!("{e}")); - assert_eq!(h.node_count(), 6); + assert_eq!(h.num_nodes(), 6); } - use crate::hugr::rewrite::replace::Replacement; + use crate::hugr::patch::replace::Replacement; fn to_replace(h: &impl HugrView, s: SimpleReplacement) -> Replacement { - use crate::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec}; + use crate::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec}; let mut replacement = s.replacement; let (in_, out) = replacement @@ -1018,10 +1036,10 @@ pub(in crate::hugr::rewrite) mod test { } fn apply_simple(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(rw).unwrap(); + h.apply_patch(rw).unwrap(); } fn apply_replace(h: &mut Hugr, rw: SimpleReplacement) { - h.apply_rewrite(to_replace(h, rw)).unwrap(); + h.apply_patch(to_replace(h, rw)).unwrap(); } } diff --git a/hugr-core/src/hugr/rewrite.rs b/hugr-core/src/hugr/rewrite.rs deleted file mode 100644 index 7c4374b65..000000000 --- a/hugr-core/src/hugr/rewrite.rs +++ /dev/null @@ -1,92 +0,0 @@ -//! Rewrite operations on the HUGR - replacement, outlining, etc. - -pub mod consts; -pub mod inline_call; -pub mod inline_dfg; -pub mod insert_identity; -pub mod outline_cfg; -mod port_types; -pub mod replace; -pub mod simple_replace; - -use crate::{Hugr, HugrView, Node}; -pub use port_types::{BoundaryPort, HostPort, ReplacementPort}; -pub use simple_replace::{SimpleReplacement, SimpleReplacementError}; - -use super::HugrMut; - -/// An operation that can be applied to mutate a Hugr -pub trait Rewrite { - /// The type of Error with which this Rewrite may fail - type Error: std::error::Error; - /// The type returned on successful application of the rewrite. - type ApplyResult; - - /// If `true`, [self.apply]'s of this rewrite guarantee that they do not mutate the Hugr when they return an Err. - /// If `false`, there is no guarantee; the Hugr should be assumed invalid when Err is returned. - const UNCHANGED_ON_FAILURE: bool; - - /// Checks whether the rewrite would succeed on the specified Hugr. - /// If this call succeeds, [self.apply] should also succeed on the same `h` - /// If this calls fails, [self.apply] would fail with the same error. - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error>; - - /// Mutate the specified Hugr, or fail with an error. - /// Returns [`Self::ApplyResult`] if successful. - /// If [self.unchanged_on_failure] is true, then `h` must be unchanged if Err is returned. - /// See also [self.verify] - /// # Panics - /// May panic if-and-only-if `h` would have failed [Hugr::validate]; that is, - /// implementations may begin with `assert!(h.validate())`, with `debug_assert!(h.validate())` - /// being preferred. - fn apply(self, h: &mut impl HugrMut) -> Result; - - /// Returns a set of nodes referenced by the rewrite. Modifying any of these - /// nodes will invalidate it. - /// - /// Two `impl Rewrite`s can be composed if their invalidation sets are - /// disjoint. - fn invalidation_set(&self) -> impl Iterator; -} - -/// Wraps any rewrite into a transaction (i.e. that has no effect upon failure) -pub struct Transactional { - underlying: R, -} - -// Note we might like to constrain R to Rewrite but this -// is not yet supported, https://github.com/rust-lang/rust/issues/92827 -impl Rewrite for Transactional { - type Error = R::Error; - type ApplyResult = R::ApplyResult; - const UNCHANGED_ON_FAILURE: bool = true; - - fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> { - self.underlying.verify(h) - } - - fn apply(self, h: &mut impl HugrMut) -> Result { - if R::UNCHANGED_ON_FAILURE { - return self.underlying.apply(h); - } - // Try to backup just the contents of this HugrMut. - let mut backup = Hugr::new(h.root_type().clone()); - backup.insert_from_view(backup.root(), h); - let r = self.underlying.apply(h); - if r.is_err() { - // Try to restore backup. - h.replace_op(h.root(), backup.root_type().clone()) - .expect("The root replacement should always match the old root type"); - while let Some(child) = h.first_child(h.root()) { - h.remove_node(child); - } - h.insert_from_view(h.root(), &backup); - } - r - } - - #[inline] - fn invalidation_set(&self) -> impl Iterator { - self.underlying.invalidation_set() - } -} diff --git a/hugr-core/src/hugr/serialize.rs b/hugr-core/src/hugr/serialize.rs index 906084d55..5e4922157 100644 --- a/hugr-core/src/hugr/serialize.rs +++ b/hugr-core/src/hugr/serialize.rs @@ -157,13 +157,13 @@ impl TryFrom<&Hugr> for SerHugrLatest { fn try_from(hugr: &Hugr) -> Result { // We compact the operation nodes during the serialization process, // and ignore the copy nodes. - let mut node_rekey: HashMap = HashMap::with_capacity(hugr.node_count()); + let mut node_rekey: HashMap = HashMap::with_capacity(hugr.num_nodes()); for (order, node) in hugr.canonical_order(hugr.root()).enumerate() { node_rekey.insert(node, portgraph::NodeIndex::new(order).into()); } - let mut nodes = vec![None; hugr.node_count()]; - let mut metadata = vec![None; hugr.node_count()]; + let mut nodes = vec![None; hugr.num_nodes()]; + let mut metadata = vec![None; hugr.num_nodes()]; for n in hugr.nodes() { let parent = node_rekey[&hugr.get_parent(n).unwrap_or(n)]; let opt = hugr.get_optype(n); @@ -172,7 +172,7 @@ impl TryFrom<&Hugr> for SerHugrLatest { parent, op: opt.clone(), }); - metadata[new_node].clone_from(hugr.metadata.get(n.pg_index())); + metadata[new_node].clone_from(hugr.metadata.get(n.into_portgraph())); } let nodes = nodes .into_iter() @@ -251,7 +251,7 @@ impl TryFrom for Hugr { } let unwrap_offset = |node: Node, offset, dir, hugr: &Hugr| -> Result { - if !hugr.graph.contains_node(node.pg_index()) { + if !hugr.graph.contains_node(node.into_portgraph()) { return Err(HUGRSerializationError::UnknownEdgeNode { node }); } let offset = match offset { diff --git a/hugr-core/src/hugr/serialize/test.rs b/hugr-core/src/hugr/serialize/test.rs index 49a7b9321..6848062b7 100644 --- a/hugr-core/src/hugr/serialize/test.rs +++ b/hugr-core/src/hugr/serialize/test.rs @@ -6,10 +6,10 @@ use crate::builder::{ DataflowSubContainer, HugrBuilder, ModuleBuilder, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; use crate::extension::simple_op::MakeRegisteredOp; +use crate::extension::test::SimpleOpDef; use crate::extension::ExtensionRegistry; -use crate::extension::{test::SimpleOpDef, ExtensionSet}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::validate::ValidationError; use crate::ops::custom::{ExtensionOp, OpaqueOp, OpaqueOpError}; @@ -300,7 +300,7 @@ fn weighted_hugr_ser() { let t_row = vec![Type::new_sum([vec![usize_t()], vec![qb_t()]])]; let mut f_build = module_builder - .define_function("main", Signature::new(t_row.clone(), t_row).with_prelude()) + .define_function("main", Signature::new(t_row.clone(), t_row)) .unwrap(); let outputs = f_build @@ -324,7 +324,7 @@ fn weighted_hugr_ser() { #[test] fn dfg_roundtrip() -> Result<(), Box> { let tp: Vec = vec![bool_t(); 2]; - let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp).with_prelude())?; + let mut dfg = DFGBuilder::new(Signature::new(tp.clone(), tp))?; let mut params: [_; 2] = dfg.input_wires_arr(); for p in params.iter_mut() { *p = dfg @@ -390,8 +390,8 @@ fn opaque_ops() -> Result<(), Box> { #[test] fn function_type() -> Result<(), Box> { - let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()]).with_prelude()); - let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]).with_prelude())?; + let fn_ty = Type::new_function(Signature::new_endo(vec![bool_t()])); + let mut bldr = DFGBuilder::new(Signature::new_endo(vec![fn_ty.clone()]))?; let op = bldr.add_dataflow_op(Noop(fn_ty), bldr.input_wires())?; let h = bldr.finish_hugr_with_outputs(op.outputs())?; @@ -482,10 +482,8 @@ fn roundtrip_value(#[case] value: Value) { } fn polyfunctype1() -> PolyFuncType { - let mut extension_set = ExtensionSet::new(); - extension_set.insert_type_var(1); - let function_type = Signature::new_endo(type_row![]).with_extension_delta(extension_set); - PolyFuncType::new([TypeParam::max_nat(), TypeParam::Extensions], function_type) + let function_type = Signature::new_endo(type_row![]); + PolyFuncType::new([TypeParam::max_nat()], function_type) } fn polyfunctype2() -> PolyFuncTypeRV { @@ -541,7 +539,7 @@ fn roundtrip_polyfunctype_varlen(#[case] poly_func_type: PolyFuncTypeRV) { #[case(ops::Const::new(Value::function(crate::builder::test::simple_dfg_hugr()).unwrap()))] #[case(ops::Input::new(vec![Type::new_var_use(3,TypeBound::Copyable)]))] #[case(ops::Output::new(vec![Type::new_function(FuncValueType::new_endo(type_row![]))]))] -#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}, TypeArg::Extensions{ es: ExtensionSet::singleton(PRELUDE_ID)} ]).unwrap())] +#[case(ops::Call::try_new(polyfunctype1(), [TypeArg::BoundedNat{n: 1}]).unwrap())] #[case(ops::CallIndirect { signature : Signature::new_endo(vec![bool_t()]) })] fn roundtrip_optype(#[case] optype: impl Into + std::fmt::Debug) { check_testing_roundtrip(NodeSer { diff --git a/hugr-core/src/hugr/serialize/upgrade.rs b/hugr-core/src/hugr/serialize/upgrade.rs index 2741b6175..ac1ac1eea 100644 --- a/hugr-core/src/hugr/serialize/upgrade.rs +++ b/hugr-core/src/hugr/serialize/upgrade.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Debug, Error)] +#[non_exhaustive] pub enum UpgradeError { #[error(transparent)] Deserialize(#[from] serde_json::Error), diff --git a/hugr-core/src/hugr/serialize/upgrade/test.rs b/hugr-core/src/hugr/serialize/upgrade/test.rs index 5e1d3ee51..e3aa4740b 100644 --- a/hugr-core/src/hugr/serialize/upgrade/test.rs +++ b/hugr-core/src/hugr/serialize/upgrade/test.rs @@ -55,7 +55,6 @@ pub fn hugr_with_named_op() -> Hugr { #[rstest] #[case("empty_hugr", empty_hugr())] #[case("hugr_with_named_op", hugr_with_named_op())] -#[cfg_attr(feature = "extension_inference", ignore = "Fails extension inference")] fn serial_upgrade(#[case] name: String, #[case] hugr: Hugr) { let path = TEST_CASE_DIR.join(format!("{}.json", name)); if !path.exists() { diff --git a/hugr-core/src/hugr/validate.rs b/hugr-core/src/hugr/validate.rs index 3b04ccd86..31b34b044 100644 --- a/hugr-core/src/hugr/validate.rs +++ b/hugr-core/src/hugr/validate.rs @@ -9,18 +9,18 @@ use petgraph::visit::{Topo, Walker}; use portgraph::{LinkView, PortView}; use thiserror::Error; -use crate::extension::{SignatureError, TO_BE_INFERRED}; +use crate::extension::SignatureError; use crate::ops::constant::ConstTypeError; use crate::ops::custom::{ExtensionOp, OpaqueOpError}; use crate::ops::validate::{ChildrenEdgeData, ChildrenValidationError, EdgeValidationError}; -use crate::ops::{FuncDefn, NamedOp, OpName, OpParent, OpTag, OpTrait, OpType, ValidateOp}; +use crate::ops::{FuncDefn, NamedOp, OpName, OpTag, OpTrait, OpType, ValidateOp}; use crate::types::type_param::TypeParam; use crate::types::EdgeKind; use crate::{Direction, Hugr, Node, Port}; use super::internal::HugrInternals; -use super::views::{HierarchyView, HugrView, SiblingGraph}; +use super::views::HugrView; use super::ExtensionError; /// Structure keeping track of pre-computed information used in the validation @@ -31,72 +31,19 @@ use super::ExtensionError; struct ValidationContext<'a> { hugr: &'a Hugr, /// Dominator tree for each CFG region, using the container node as index. - dominators: HashMap>, + dominators: HashMap>, } impl Hugr { - /// Check the validity of the HUGR, assuming that it has no open extension - /// variables. - /// TODO: Add a version of validation which allows for open extension - /// variables (see github issue #457) + /// Check the validity of the HUGR. pub fn validate(&self) -> Result<(), ValidationError> { - self.validate_no_extensions()?; - if cfg!(feature = "extension_inference") { - self.validate_extensions()?; - } - Ok(()) - } - - /// Check the validity of the HUGR, but don't check consistency of extension - /// requirements between connected nodes or between parents and children. - pub fn validate_no_extensions(&self) -> Result<(), ValidationError> { let mut validator = ValidationContext::new(self); validator.validate() } - - /// Validate extensions, i.e. that extension deltas from parent nodes are reflected in their children. - pub fn validate_extensions(&self) -> Result<(), ValidationError> { - for parent in self.nodes() { - let parent_op = self.get_optype(parent); - if parent_op.extension_delta().contains(&TO_BE_INFERRED) { - return Err(ValidationError::ExtensionsNotInferred { node: parent }); - } - let parent_extensions = match parent_op.inner_function_type() { - Some(s) => s.runtime_reqs.clone(), - None => match parent_op.tag() { - OpTag::Cfg | OpTag::Conditional => parent_op.extension_delta(), - // ModuleRoot holds but does not execute its children, so allow any extensions - OpTag::ModuleRoot => continue, - _ => { - assert!(self.children(parent).next().is_none(), - "Unknown parent node type {} - not a DataflowParent, Module, Cfg or Conditional", - parent_op); - continue; - } - }, - }; - for child in self.children(parent) { - let child_extensions = self.get_optype(child).extension_delta(); - if !parent_extensions.is_superset(&child_extensions) { - return Err(ExtensionError { - parent, - parent_extensions, - child, - child_extensions, - } - .into()); - } - } - } - Ok(()) - } } impl<'a> ValidationContext<'a> { /// Create a new validation context. - // Allow unused "extension_closure" variable for when - // the "extension_inference" feature is disabled. - #[allow(unused_variables)] pub fn new(hugr: &'a Hugr) -> Self { let dominators = HashMap::new(); Self { hugr, dominators } @@ -138,10 +85,10 @@ impl<'a> ValidationContext<'a> { /// /// The results of this computation should be cached in `self.dominators`. /// We don't do it here to avoid mutable borrows. - fn compute_dominator(&self, parent: Node) -> Dominators { - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); + fn compute_dominator(&self, parent: Node) -> Dominators { + let region = self.hugr.region_portgraph(parent); let entry_node = self.hugr.children(parent).next().unwrap(); - dominators::simple_fast(®ion.as_petgraph(), entry_node) + dominators::simple_fast(®ion, entry_node.into_portgraph()) } /// Check the constraints on a single node. @@ -163,7 +110,7 @@ impl<'a> ValidationContext<'a> { for dir in Direction::BOTH { // Check that we have the correct amount of ports and edges. - let num_ports = self.hugr.graph.num_ports(node.pg_index(), dir); + let num_ports = self.hugr.graph.num_ports(node.into_portgraph(), dir); if num_ports != op_type.port_count(dir) { return Err(ValidationError::WrongNumberOfPorts { node, @@ -316,7 +263,7 @@ impl<'a> ValidationContext<'a> { fn validate_children(&self, node: Node, op_type: &OpType) -> Result<(), ValidationError> { let flags = op_type.validity_flags(); - if self.hugr.hierarchy().child_count(node.pg_index()) > 0 { + if self.hugr.hierarchy().child_count(node.into_portgraph()) > 0 { if flags.allowed_children.is_empty() { return Err(ValidationError::NonContainerWithChildren { node, @@ -352,7 +299,8 @@ impl<'a> ValidationContext<'a> { } } // Additional validations running over the full list of children optypes - let children_optypes = all_children.map(|c| (c.pg_index(), self.hugr.get_optype(c))); + let children_optypes = + all_children.map(|c| (c.into_portgraph(), self.hugr.get_optype(c))); if let Err(source) = op_type.validate_op_children(children_optypes) { return Err(ValidationError::InvalidChildren { parent: node, @@ -363,9 +311,9 @@ impl<'a> ValidationContext<'a> { // Additional validations running over the edges of the contained graph if let Some(edge_check) = flags.edge_check { - for source in self.hugr.hierarchy().children(node.pg_index()) { + for source in self.hugr.hierarchy().children(node.into_portgraph()) { for target in self.hugr.graph.output_neighbours(source) { - if self.hugr.hierarchy.parent(target) != Some(node.pg_index()) { + if self.hugr.hierarchy.parent(target) != Some(node.into_portgraph()) { continue; } let source_op = self.hugr.get_optype(source.into()); @@ -411,16 +359,16 @@ impl<'a> ValidationContext<'a> { /// Inter-graph edges are ignored. Only internal dataflow, constant, or /// state order edges are considered. fn validate_children_dag(&self, parent: Node, op_type: &OpType) -> Result<(), ValidationError> { - if !self.hugr.hierarchy.has_children(parent.pg_index()) { + if !self.hugr.hierarchy.has_children(parent.into_portgraph()) { // No children, nothing to do return Ok(()); }; - let region: SiblingGraph = SiblingGraph::try_new(self.hugr, parent).unwrap(); - let postorder = Topo::new(®ion.as_petgraph()); + let region = self.hugr.region_portgraph(parent); + let postorder = Topo::new(®ion); let nodes_visited = postorder - .iter(®ion.as_petgraph()) - .filter(|n| *n != parent) + .iter(®ion) + .filter(|n| *n != parent.into_portgraph()) .count(); let node_count = self.hugr.children(parent).count(); if nodes_visited != node_count { @@ -500,7 +448,7 @@ impl<'a> ValidationContext<'a> { // Must have an order edge. self.hugr .graph - .get_connections(from.pg_index(), ancestor.pg_index()) + .get_connections(from.into_portgraph(), ancestor.into_portgraph()) .find(|&(p, _)| { let offset = self.hugr.graph.port_offset(p).unwrap(); from_optype.port_kind(offset) == Some(EdgeKind::StateOrder) @@ -537,8 +485,8 @@ impl<'a> ValidationContext<'a> { } }; if !dominator_tree - .dominators(ancestor) - .is_some_and(|mut ds| ds.any(|n| n == from_parent)) + .dominators(ancestor.into_portgraph()) + .is_some_and(|mut ds| ds.any(|n| n == from_parent.into_portgraph())) { return Err(InterGraphEdgeError::NonDominatedAncestor { from, @@ -616,7 +564,12 @@ impl<'a> ValidationContext<'a> { // Root nodes are ignored, as they cannot have connected edges. if node != self.hugr.root() { for dir in Direction::BOTH { - for (i, port_index) in self.hugr.graph.ports(node.pg_index(), dir).enumerate() { + for (i, port_index) in self + .hugr + .graph + .ports(node.into_portgraph(), dir) + .enumerate() + { let port = Port::new(dir, i); self.validate_port(node, port, port_index, op_type, var_decls)?; } diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ecb417ec5..7fec75bce 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -11,8 +11,8 @@ use crate::builder::{ FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; use crate::extension::prelude::Noop; -use crate::extension::prelude::{bool_t, qb_t, usize_t, PRELUDE_ID}; -use crate::extension::{Extension, ExtensionRegistry, ExtensionSet, TypeDefBound, PRELUDE}; +use crate::extension::prelude::{bool_t, qb_t, usize_t}; +use crate::extension::{Extension, ExtensionRegistry, TypeDefBound, PRELUDE}; use crate::hugr::internal::HugrMutInternals; use crate::hugr::HugrMut; use crate::ops::dataflow::IOTrait; @@ -20,7 +20,6 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::LogicOp; -use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -36,9 +35,7 @@ use crate::{ fn make_simple_hugr(copies: usize) -> (Hugr, Node) { let def_op: OpType = ops::FuncDefn { name: "main".into(), - signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]) - .with_prelude() - .into(), + signature: Signature::new(vec![bool_t()], vec![bool_t(); copies]).into(), } .into(); @@ -104,7 +101,7 @@ fn invalid_root() { ); // Fix the root - b.root = module.pg_index(); + b.root = module.into_portgraph(); b.remove_node(root); assert_eq!(b.validate(), Ok(())); } @@ -120,7 +117,7 @@ fn leaf_root() { #[test] fn dfg_root() { let dfg_op: OpType = ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_prelude(), + signature: Signature::new_endo(vec![bool_t()]), } .into(); @@ -143,7 +140,7 @@ fn children_restrictions() { let root = b.root(); let (_input, copy, _output) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -186,52 +183,46 @@ fn df_children_restrictions() { let (mut b, def) = make_simple_hugr(2); let (_input, output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); // Replace the output operation of the df subgraph with a copy - b.replace_op(output, Noop(usize_t())).unwrap(); + b.replace_op(output, Noop(usize_t())); assert_matches!( b.validate(), Err(ValidationError::InvalidInitialChild { parent, .. }) => assert_eq!(parent, def) ); // Revert it back to an output, but with the wrong number of ports - b.replace_op(output, ops::Output::new(vec![bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::IOSignatureMismatch { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, output.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, output.into_portgraph())} ); - b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(output, ops::Output::new(vec![bool_t(), bool_t()])); // After fixing the output back, replace the copy with an output op - b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])) - .unwrap(); + b.replace_op(copy, ops::Output::new(vec![bool_t(), bool_t()])); assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalIOChildren { child, .. }, .. }) - => {assert_eq!(parent, def); assert_eq!(child, copy.pg_index())} + => {assert_eq!(parent, def); assert_eq!(child, copy.into_portgraph())} ); } #[test] fn test_ext_edge() { - let mut h = closed_dfg_root_hugr( - Signature::new(vec![bool_t(), bool_t()], vec![bool_t()]) - .with_extension_delta(TO_BE_INFERRED), - ); + let mut h = closed_dfg_root_hugr(Signature::new(vec![bool_t(), bool_t()], vec![bool_t()])); let [input, output] = h.get_io(h.root()).unwrap(); // Nested DFG bool_t() -> bool_t() let sub_dfg = h.add_node_with_parent( h.root(), ops::DFG { - signature: Signature::new_endo(vec![bool_t()]).with_extension_delta(TO_BE_INFERRED), + signature: Signature::new_endo(vec![bool_t()]), }, ); // this Xor has its 2nd input unconnected @@ -258,7 +249,6 @@ fn test_ext_edge() { ); //Order edge. This will need metadata indicating its purpose. h.add_other_edge(input, sub_dfg); - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -293,8 +283,7 @@ fn no_ext_edge_into_func() -> Result<(), Box> { #[test] fn test_local_const() { - let mut h = - closed_dfg_root_hugr(Signature::new_endo(bool_t()).with_extension_delta(TO_BE_INFERRED)); + let mut h = closed_dfg_root_hugr(Signature::new_endo(bool_t())); let [input, output] = h.get_io(h.root()).unwrap(); let and = h.add_node_with_parent(h.root(), and_op()); h.connect(input, 0, and, 0); @@ -307,12 +296,7 @@ fn test_local_const() { port_kind: EdgeKind::Value(bool_t()) }) ); - let const_op: ops::Const = logic::EXTENSION - .get_value(&logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone() - .into(); + let const_op: ops::Const = ops::Value::from_bool(true).into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); @@ -321,7 +305,6 @@ fn test_local_const() { h.connect(lcst, 0, and, 1); assert_eq!(h.static_source(lcst), Some(cst)); // There is no edge from Input to LoadConstant, but that's OK: - h.infer_extensions(false).unwrap(); h.validate().unwrap(); } @@ -558,11 +541,7 @@ fn no_polymorphic_consts() -> Result<(), Box> { reg.validate()?; let mut def = FunctionBuilder::new( "myfunc", - PolyFuncType::new( - [BOUND], - Signature::new(vec![], vec![list_of_var.clone()]) - .with_extension_delta(list::EXTENSION_ID), - ), + PolyFuncType::new([BOUND], Signature::new(vec![], vec![list_of_var.clone()])), )?; let empty_list = Value::extension(list::ListValue::new_empty(Type::new_var_use( 0, @@ -655,7 +634,7 @@ fn row_variables() -> Result<(), Box> { "id", PolyFuncType::new( [TypeParam::new_list(TypeBound::Any)], - Signature::new(inner_ft.clone(), ft_usz).with_extension_delta(e.name.clone()), + Signature::new(inner_ft.clone(), ft_usz), ), )?; // All the wires here are carrying higher-order Function values @@ -677,19 +656,15 @@ fn row_variables() -> Result<(), Box> { #[test] fn test_polymorphic_call() -> Result<(), Box> { + // TODO: This tests a function call that is polymorphic in an extension set. + // Should this be rewritten to be polymorphic in something else or removed? + let e = Extension::try_new_test_arc(EXT_ID, |ext, extension_ref| { - let params: Vec = vec![ - TypeBound::Any.into(), - TypeParam::Extensions, - TypeBound::Any.into(), - ]; - let evaled_fn = Type::new_function( - Signature::new( - Type::new_var_use(0, TypeBound::Any), - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), - ); + let params: Vec = vec![TypeBound::Any.into(), TypeBound::Any.into()]; + let evaled_fn = Type::new_function(Signature::new( + Type::new_var_use(0, TypeBound::Any), + Type::new_var_use(1, TypeBound::Any), + )); // Single-input/output version of the higher-order "eval" operation, with extension param. // Note the extension-delta of the eval node includes that of the input function. ext.add_op( @@ -699,9 +674,8 @@ fn test_polymorphic_call() -> Result<(), Box> { params.clone(), Signature::new( vec![evaled_fn, Type::new_var_use(0, TypeBound::Any)], - Type::new_var_use(2, TypeBound::Any), - ) - .with_extension_delta(ExtensionSet::type_var(1)), + Type::new_var_use(1, TypeBound::Any), + ), ), extension_ref, )?; @@ -709,27 +683,23 @@ fn test_polymorphic_call() -> Result<(), Box> { Ok(()) })?; - fn utou(e: impl Into) -> Type { - Type::new_function(Signature::new_endo(usize_t()).with_extension_delta(e.into())) + fn utou() -> Type { + Type::new_function(Signature::new_endo(usize_t())) } let int_pair = Type::new_tuple(vec![usize_t(); 2]); - // Root DFG: applies a function int--PRELUDE-->int to each element of a pair of two ints + // Root DFG: applies a function int-->int to each element of a pair of two ints let mut d = DFGBuilder::new(inout_sig( - vec![utou(PRELUDE_ID), int_pair.clone()], + vec![utou(), int_pair.clone()], vec![int_pair.clone()], ))?; - // ....by calling a function parametrized (int--e-->int, int_pair) -> int_pair + // ....by calling a function (int-->int, int_pair) -> int_pair let f = { - let es = ExtensionSet::type_var(0); let mut f = d.define_function( "two_ints", PolyFuncType::new( - vec![TypeParam::Extensions], - Signature::new(vec![utou(es.clone()), int_pair.clone()], int_pair.clone()) - .with_extension_delta(EXT_ID) - .with_prelude() - .with_extension_delta(es.clone()), + vec![], + Signature::new(vec![utou(), int_pair.clone()], int_pair.clone()), ), )?; let [func, tup] = f.input_wires_arr(); @@ -740,14 +710,7 @@ fn test_polymorphic_call() -> Result<(), Box> { )?; let mut cc = c.case_builder(0)?; let [i1, i2] = cc.input_wires_arr(); - let op = e.instantiate_extension_op( - "eval", - vec![ - usize_t().into(), - TypeArg::Extensions { es }, - usize_t().into(), - ], - )?; + let op = e.instantiate_extension_op("eval", vec![usize_t().into(), usize_t().into()])?; let [f1] = cc.add_dataflow_op(op.clone(), [func, i1])?.outputs_arr(); let [f2] = cc.add_dataflow_op(op, [func, i2])?.outputs_arr(); cc.finish_with_outputs([f1, f2])?; @@ -757,18 +720,10 @@ fn test_polymorphic_call() -> Result<(), Box> { }; let [func, tup] = d.input_wires_arr(); - let call = d.call( - f.handle(), - &[TypeArg::Extensions { - es: ExtensionSet::singleton(PRELUDE_ID), - }], - [func, tup], - )?; + let call = d.call(f.handle(), &[], [func, tup])?; let h = d.finish_hugr_with_outputs(call.outputs())?; let call_ty = h.get_optype(call.node()).dataflow_signature().unwrap(); - let exp_fun_ty = Signature::new(vec![utou(PRELUDE_ID), int_pair.clone()], int_pair) - .with_extension_delta(EXT_ID) - .with_prelude(); + let exp_fun_ty = Signature::new(vec![utou(), int_pair.clone()], int_pair); assert_eq!(call_ty.as_ref(), &exp_fun_ty); Ok(()) } @@ -800,7 +755,7 @@ fn cfg_children_restrictions() { let (mut b, def) = make_simple_hugr(1); let (_input, _output, copy) = b .hierarchy - .children(def.pg_index()) + .children(def.into_portgraph()) .map_into() .collect_tuple() .unwrap(); @@ -812,8 +767,7 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![bool_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::ContainerWithoutChildren { .. }) @@ -827,7 +781,6 @@ fn cfg_children_restrictions() { inputs: vec![bool_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![bool_t()].into(), - extension_delta: ExtensionSet::new(), }, ); let const_op: ops::Const = ops::Value::unit_sum(0, 1).unwrap().into(); @@ -865,7 +818,7 @@ fn cfg_children_restrictions() { assert_matches!( b.validate(), Err(ValidationError::InvalidChildren { parent, source: ChildrenValidationError::InternalExitChildren { child, .. }, .. }) - => {assert_eq!(parent, cfg); assert_eq!(child, exit2.pg_index())} + => {assert_eq!(parent, cfg); assert_eq!(child, exit2.into_portgraph())} ); b.remove_node(exit2); @@ -875,28 +828,23 @@ fn cfg_children_restrictions() { ops::CFG { signature: Signature::new(vec![qb_t()], vec![bool_t()]), }, - ) - .unwrap(); + ); b.replace_op( block, ops::DataflowBlock { inputs: vec![qb_t()].into(), sum_rows: vec![type_row![]], other_outputs: vec![qb_t()].into(), - extension_delta: ExtensionSet::new(), }, - ) - .unwrap(); - let mut block_children = b.hierarchy.children(block.pg_index()); + ); + let mut block_children = b.hierarchy.children(block.into_portgraph()); let block_input = block_children.next().unwrap().into(); let block_output = block_children.next_back().unwrap().into(); - b.replace_op(block_input, ops::Input::new(vec![qb_t()])) - .unwrap(); + b.replace_op(block_input, ops::Input::new(vec![qb_t()])); b.replace_op( block_output, ops::Output::new(vec![Type::new_unit_sum(1), qb_t()]), - ) - .unwrap(); + ); assert_matches!( b.validate(), Err(ValidationError::InvalidEdges { parent, source: EdgeValidationError::CFGEdgeSignatureMismatch { .. }, .. }) @@ -913,8 +861,7 @@ fn cfg_connections() -> Result<(), Box> { let mut hugr = CFGBuilder::new(Signature::new_endo(usize_t()))?; let unary_pred = hugr.add_constant(Value::unary_unit_sum()); - let mut entry = - hugr.simple_entry_builder_exts(vec![usize_t()].into(), 1, ExtensionSet::new())?; + let mut entry = hugr.simple_entry_builder(vec![usize_t()].into(), 1)?; let p = entry.load_const(&unary_pred); let ins = entry.input_wires(); let entry = entry.finish_with_outputs(p, ins)?; @@ -958,219 +905,3 @@ fn cfg_entry_io_bug() -> Result<(), Box> { Ok(()) } - -#[cfg(feature = "extension_inference")] -mod extension_tests { - use self::ops::handle::{BasicBlockID, TailLoopID}; - use rstest::rstest; - - use super::*; - use crate::builder::handle::Outputs; - use crate::builder::{BlockBuilder, BuildHandle, CFGBuilder, DFGWrapper, TailLoopBuilder}; - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::hugr::test::{lift_op, LIFT_EXT_ID}; - use crate::macros::const_extension_ids; - use crate::Wire; - const_extension_ids! { - const XA: ExtensionId = "A"; - const XB: ExtensionId = "BOOL_EXT"; - } - - #[rstest] - #[case::d1(|signature| ops::DFG {signature}.into())] - #[case::f1(|sig: Signature| ops::FuncDefn {name: "foo".to_string(), signature: sig.into()}.into())] - #[case::c1(|signature| ops::Case {signature}.into())] - fn parent_extension_mismatch( - #[case] parent_f: impl Fn(Signature) -> OpType, - #[values(ExtensionSet::new(), XA.into())] parent_extensions: ExtensionSet, - ) { - // Child graph adds extension "XB", but the parent (in all cases) - // declares a different delta, causing a mismatch. - - let parent = parent_f( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - ); - let mut hugr = Hugr::new(parent); - - let input = hugr.add_node_with_parent( - hugr.root(), - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - hugr.root(), - ops::Output { - types: vec![usize_t()].into(), - }, - ); - - let lift = hugr.add_node_with_parent(hugr.root(), lift_op(usize_t(), XB)); - - hugr.connect(input, 0, lift, 0); - hugr.connect(lift, 0, output, 0); - - let result = hugr.validate(); - assert_eq!( - result, - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child: lift, - child_extensions: ExtensionSet::from_iter([LIFT_EXT_ID, XB]), - })) - ); - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB]), true)] - fn cfg_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) -> Result<(), BuildError> { - let mut cfg = CFGBuilder::new( - Signature::new_endo(usize_t()).with_extension_delta(parent_extensions.clone()), - )?; - let mut bb = cfg.simple_entry_builder_exts(usize_t().into(), 1, XB)?; - let pred = bb.add_load_value(Value::unary_unit_sum()); - let inputs = bb.input_wires(); - let blk = bb.finish_with_outputs(pred, inputs)?; - let exit = cfg.exit_block(); - cfg.branch(&blk, 0, &exit)?; - let root = cfg.hugr().root(); - let res = cfg.finish_hugr(); - if success { - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(ValidationError::ExtensionError(ExtensionError { - parent: root, - parent_extensions, - child: blk.node(), - child_extensions: XB.into() - })) - ); - } - Ok(()) - } - - #[rstest] - #[case(XA.into(), false)] - #[case(ExtensionSet::new(), false)] - #[case(ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]), true)] - fn conditional_extension_mismatch( - #[case] parent_extensions: ExtensionSet, - #[case] success: bool, - ) { - // Child graph adds extension "XB", but the parent - // declares a different delta, in same cases causing a mismatch. - let parent = ops::Conditional { - sum_rows: vec![type_row![], type_row![]], - other_inputs: vec![usize_t()].into(), - outputs: vec![usize_t()].into(), - extension_delta: parent_extensions.clone(), - }; - let mut hugr = Hugr::new(parent); - - // First case with no delta should be ok in all cases. Second one may not be. - let [_, child] = [None, Some(XB)].map(|case_ext| { - let case_exts = if let Some(ex) = &case_ext { - ExtensionSet::from_iter([ex.clone(), LIFT_EXT_ID]) - } else { - ExtensionSet::new() - }; - let case = hugr.add_node_with_parent( - hugr.root(), - ops::Case { - signature: Signature::new_endo(usize_t()).with_extension_delta(case_exts), - }, - ); - - let input = hugr.add_node_with_parent( - case, - ops::Input { - types: vec![usize_t()].into(), - }, - ); - let output = hugr.add_node_with_parent( - case, - ops::Output { - types: vec![usize_t()].into(), - }, - ); - let res = match case_ext { - None => input, - Some(new_ext) => { - let lift = hugr.add_node_with_parent(case, lift_op(usize_t(), new_ext)); - hugr.connect(input, 0, lift, 0); - lift - } - }; - hugr.connect(res, 0, output, 0); - case - }); - // case is the last-assigned child, i.e. the one that requires 'XB' - let result = hugr.validate(); - let expected = if success { - Ok(()) - } else { - Err(ValidationError::ExtensionError(ExtensionError { - parent: hugr.root(), - parent_extensions, - child, - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]), - })) - }; - assert_eq!(result, expected); - } - - #[rstest] - #[case(make_bb, |bb: &mut DFGWrapper<_,_>, outs| bb.make_tuple(outs))] - #[case(make_tailloop, |tl: &mut DFGWrapper<_,_>, outs| tl.make_break(tl.loop_signature().unwrap().clone(), outs))] - fn bb_extension_mismatch( - #[case] dfg_fn: impl Fn(Type, ExtensionSet) -> DFGWrapper, - #[case] make_pred: impl Fn(&mut DFGWrapper, Outputs) -> Result, - // last one includes prelude because `MakeTuple` is in prelude - #[values((ExtensionSet::from_iter([XA,LIFT_EXT_ID]), false), (LIFT_EXT_ID.into(), false), (ExtensionSet::from_iter([XA,XB,LIFT_EXT_ID,PRELUDE_ID]), true))] - parent_exts_success: (ExtensionSet, bool), - ) -> Result<(), BuildError> { - let (parent_extensions, success) = parent_exts_success; - let mut dfg = dfg_fn(usize_t(), parent_extensions.clone()); - let lift = dfg.add_dataflow_op(lift_op(usize_t(), XB), dfg.input_wires())?; - let pred = make_pred(&mut dfg, lift.outputs())?; - let root = dfg.hugr().root(); - let res = dfg.finish_hugr_with_outputs([pred]); - if success { - if res.is_err() { - dbg!(&res); - } - assert!(res.is_ok()) - } else { - assert_eq!( - res, - Err(BuildError::InvalidHUGR(ValidationError::ExtensionError( - ExtensionError { - parent: root, - parent_extensions, - child: lift.node(), - child_extensions: ExtensionSet::from_iter([XB, LIFT_EXT_ID]) - } - ))) - ); - } - Ok(()) - } - - fn make_bb(t: Type, es: ExtensionSet) -> DFGWrapper { - BlockBuilder::new_exts(t.clone(), vec![t.into()], type_row![], es).unwrap() - } - - fn make_tailloop(t: Type, es: ExtensionSet) -> DFGWrapper> { - let row = TypeRow::from(t); - TailLoopBuilder::new_exts(row.clone(), type_row![], row, es).unwrap() - } -} diff --git a/hugr-core/src/hugr/views.rs b/hugr-core/src/hugr/views.rs index 09805d1f8..ea414c376 100644 --- a/hugr-core/src/hugr/views.rs +++ b/hugr-core/src/hugr/views.rs @@ -16,7 +16,7 @@ use std::borrow::Cow; pub use self::petgraph::PetgraphWrapper; use self::render::RenderConfig; pub use descendants::DescendantsGraph; -pub use root_checked::RootChecked; +pub use root_checked::{check_tag, RootCheckable, RootChecked}; pub use sibling::SiblingGraph; pub use sibling_subgraph::SiblingSubgraph; @@ -24,12 +24,9 @@ use itertools::Itertools; use portgraph::render::{DotFormat, MermaidFormat}; use portgraph::{LinkView, PortView}; -use super::internal::HugrInternals; -use super::{ - Hugr, HugrError, HugrMut, Node, NodeMetadata, NodeMetadataMap, ValidationError, DEFAULT_OPTYPE, -}; +use super::internal::{HugrInternals, HugrMutInternals}; +use super::{Hugr, HugrError, HugrMut, Node, NodeMetadata, ValidationError}; use crate::extension::ExtensionRegistry; -use crate::ops::handle::NodeHandle; use crate::ops::{OpParent, OpTag, OpTrait, OpType}; use crate::types::{EdgeKind, PolyFuncType, Signature, Type}; @@ -41,85 +38,67 @@ use itertools::Either; /// For end users we intend this to be superseded by region-specific APIs. pub trait HugrView: HugrInternals { /// Return the root node of this view. - #[inline] - fn root(&self) -> Self::Node { - self.root_node() - } + fn root(&self) -> Self::Node; - /// Return the type of the HUGR root node. + /// Return the optype of the HUGR root node. #[inline] - fn root_type(&self) -> &OpType { + fn root_optype(&self) -> &OpType { let node_type = self.get_optype(self.root()); - // Sadly no way to do this at present - // debug_assert!(Self::RootHandle::can_hold(node_type.tag())); node_type } - /// Returns whether the node exists. + /// Returns `true` if the node exists in the HUGR. fn contains_node(&self, node: Self::Node) -> bool; - /// Validates that a node is valid in the graph. - #[inline] - fn valid_node(&self, node: Self::Node) -> bool { - self.contains_node(node) - } - - /// Validates that a node is a valid root descendant in the graph. - /// - /// To include the root node use [`HugrView::valid_node`] instead. - #[inline] - fn valid_non_root(&self, node: Self::Node) -> bool { - self.root() != node && self.valid_node(node) - } - /// Returns the parent of a node. - #[inline] - fn get_parent(&self, node: Self::Node) -> Option { - if !self.valid_non_root(node) { - return None; - }; - self.base_hugr() - .hierarchy - .parent(self.get_pg_index(node)) - .map(|index| self.get_node(index)) - } - - /// Returns the operation type of a node. - #[inline] - fn get_optype(&self, node: Self::Node) -> &OpType { - match self.contains_node(node) { - true => self.base_hugr().op_types.get(self.get_pg_index(node)), - false => &DEFAULT_OPTYPE, - } - } + fn get_parent(&self, node: Self::Node) -> Option; /// Returns the metadata associated with a node. #[inline] fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata> { match self.contains_node(node) { - true => self.get_node_metadata(node)?.get(key.as_ref()), + true => self.node_metadata_map(node).get(key.as_ref()), false => None, } } - /// Retrieve the complete metadata map for a node. - fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap> { - if !self.valid_node(node) { - return None; - } - self.base_hugr() - .metadata - .get(self.get_pg_index(node)) - .as_ref() - } + /// Returns the operation type of a node. + /// + /// # Panics + /// + /// If the node is not in the graph. + fn get_optype(&self, node: Self::Node) -> &OpType; + + /// Returns the number of nodes in the HUGR. + fn num_nodes(&self) -> usize; - /// Returns the number of nodes in the hugr. - fn node_count(&self) -> usize; + /// Returns the number of edges in the HUGR. + fn num_edges(&self) -> usize; + + /// Number of ports in node for a given direction. + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; + + /// Number of inputs to a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. + #[inline] + fn num_inputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Incoming) + } - /// Returns the number of edges in the hugr. - fn edge_count(&self) -> usize; + /// Number of outputs from a node. + /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. + #[inline] + fn num_outputs(&self, node: Self::Node) -> usize { + self.num_ports(node, Direction::Outgoing) + } - /// Iterates over the nodes in the port graph. + /// Iterates over the all the nodes in the HUGR. + /// + /// This iterator returns every node in the HUGR, including those that are + /// not descendants from the root node. + /// + /// See [`HugrView::descendants`] and [`HugrView::children`] for more specific + /// iterators. fn nodes(&self) -> impl Iterator + Clone; /// Iterator over ports of node in a given direction. @@ -261,26 +240,15 @@ pub trait HugrView: HugrInternals { self.linked_ports(node, port).next().is_some() } - /// Number of ports in node for a given direction. - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; - - /// Number of inputs to a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Incoming)`. - #[inline] - fn num_inputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Incoming) - } - - /// Number of outputs from a node. - /// Shorthand for [`num_ports`][HugrView::num_ports]`(node, Direction::Outgoing)`. - #[inline] - fn num_outputs(&self, node: Self::Node) -> usize { - self.num_ports(node, Direction::Outgoing) - } - - /// Return iterator over the direct children of node. + /// Returns an iterator over the direct children of node. fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + /// Returns an iterator over all the descendants of a node, + /// including the node itself. + /// + /// Yields the node itself first, followed by its children in breath-first order. + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + /// Returns the first child of the specified node (if it is a parent). /// Useful because `x.children().next()` leaves x borrowed. fn first_child(&self, node: Self::Node) -> Option { @@ -335,13 +303,13 @@ pub trait HugrView: HugrInternals { /// In contrast to [`poly_func_type`][HugrView::poly_func_type], this /// method always return a concrete [`Signature`]. fn inner_function_type(&self) -> Option> { - self.root_type().inner_function_type() + self.root_optype().inner_function_type() } /// Returns the function type defined by this HUGR, i.e. `Some` iff the root is /// a [`FuncDecl`][crate::ops::FuncDecl] or [`FuncDefn`][crate::ops::FuncDefn]. fn poly_func_type(&self) -> Option { - match self.root_type() { + match self.root_optype() { OpType::FuncDecl(decl) => Some(decl.signature.clone()), OpType::FuncDefn(defn) => Some(defn.signature.clone()), _ => None, @@ -364,13 +332,7 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string(&self) -> String { - self.mermaid_string_with_config(RenderConfig { - node_indices: true, - port_offsets_in_edges: true, - type_labels_in_edges: true, - }) - } + fn mermaid_string(&self) -> String; /// Return the mermaid representation of the underlying hierarchical graph. /// @@ -379,35 +341,14 @@ pub trait HugrView: HugrInternals { /// /// For a more detailed representation, use the [`HugrView::dot_string`] /// format instead. - fn mermaid_string_with_config(&self, config: RenderConfig) -> String { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - graph - .mermaid_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + fn mermaid_string_with_config(&self, config: RenderConfig) -> String; /// Return the graphviz representation of the underlying graph and hierarchy side by side. /// /// For a simpler representation, use the [`HugrView::mermaid_string`] format instead. fn dot_string(&self) -> String where - Self: Sized, - { - let hugr = self.base_hugr(); - let graph = self.portgraph(); - let config = RenderConfig::default(); - graph - .dot_format() - .with_hierarchy(&hugr.hierarchy) - .with_node_style(render::node_style(self, config)) - .with_port_style(render::port_style(self, config)) - .with_edge_style(render::edge_style(self, config)) - .finish() - } + Self: Sized; /// If a node has a static input, return the source node. fn static_source(&self, node: Self::Node) -> Option { @@ -454,42 +395,19 @@ pub trait HugrView: HugrInternals { /// Returns the set of extensions used by the HUGR. /// - /// This set may contain extensions that are no longer required by the HUGR. - fn extensions(&self) -> &ExtensionRegistry { - &self.base_hugr().extensions - } + /// This set contains all extensions required to define the operations and + /// types in the HUGR. + fn extensions(&self) -> &ExtensionRegistry; /// Check the validity of the underlying HUGR. - /// - /// This includes checking consistency of extension requirements between - /// connected nodes and between parents and children. - /// See [`HugrView::validate_no_extensions`] for a version that doesn't check - /// extension requirements. fn validate(&self) -> Result<(), ValidationError> { + #[allow(deprecated)] self.base_hugr().validate() } - - /// Check the validity of the underlying HUGR, but don't check consistency - /// of extension requirements between connected nodes or between parents and - /// children. - /// - /// For a more thorough check, use [`HugrView::validate`]. - fn validate_no_extensions(&self) -> Result<(), ValidationError> { - self.base_hugr().validate_no_extensions() - } -} - -/// Trait for views that provides a guaranteed bound on the type of the root node. -pub trait RootTagged: HugrView { - /// The kind of handle that can be used to refer to the root node. - /// - /// The handle is guaranteed to be able to contain the operation returned by - /// [`HugrView::root_type`]. - type RootHandle: NodeHandle; } /// A common trait for views of a HUGR hierarchical subgraph. -pub trait HierarchyView<'a>: RootTagged + Sized { +pub trait HierarchyView<'a>: HugrView + Sized { /// Create a hierarchical view of a HUGR given a root node. /// /// # Errors @@ -515,30 +433,6 @@ pub trait ExtractHugr: HugrView + Sized { } } -fn check_tag( - hugr: &impl HugrView, - node: N, -) -> Result<(), HugrError> { - let actual = hugr.get_optype(node).tag(); - let required = Required::TAG; - if !required.is_superset(actual) { - return Err(HugrError::InvalidTag { required, actual }); - } - Ok(()) -} - -impl RootTagged for Hugr { - type RootHandle = Node; -} - -impl RootTagged for &Hugr { - type RootHandle = Node; -} - -impl RootTagged for &mut Hugr { - type RootHandle = Node; -} - // Explicit implementation to avoid cloning the Hugr. impl ExtractHugr for Hugr { fn extract_hugr(self) -> Hugr { @@ -560,18 +454,48 @@ impl ExtractHugr for &mut Hugr { impl HugrView for Hugr { #[inline] - fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(node.pg_index()) + fn root(&self) -> Self::Node { + self.root.into() + } + + #[inline] + fn contains_node(&self, node: Self::Node) -> bool { + self.graph.contains_node(node.into_portgraph()) + } + + #[inline] + fn get_parent(&self, node: Self::Node) -> Option { + if !check_valid_non_root(self, node) { + return None; + }; + self.hierarchy + .parent(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + #[inline] + fn get_optype(&self, node: Node) -> &OpType { + // TODO: This currently fails because some methods get the optype of + // e.g. a parent outside a region view. We should be able to re-enable + // this once we add hugr entrypoints. + //panic_invalid_node(self, node); + self.op_types.get(self.to_portgraph_node(node)) + } + + #[inline] + fn num_nodes(&self) -> usize { + self.portgraph().node_count() } #[inline] - fn node_count(&self) -> usize { - self.graph.node_count() + fn num_edges(&self) -> usize { + self.portgraph().link_count() } #[inline] - fn edge_count(&self) -> usize { - self.graph.link_count() + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.portgraph() + .num_ports(self.to_portgraph_node(node), dir) } #[inline] @@ -581,12 +505,16 @@ impl HugrView for Hugr { #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.port_offsets(node.pg_index(), dir).map_into() + self.graph + .port_offsets(node.into_portgraph(), dir) + .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_port_offsets(node.pg_index()).map_into() + self.graph + .all_port_offsets(node.into_portgraph()) + .map_into() } #[inline] @@ -599,7 +527,7 @@ impl HugrView for Hugr { let port = self .graph - .port_index(node.pg_index(), port.pg_offset()) + .port_index(node.into_portgraph(), port.pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port = link.port(); @@ -612,30 +540,72 @@ impl HugrView for Hugr { #[inline] fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(node.pg_index(), other.pg_index()) + .get_connections(node.into_portgraph(), other.into_portgraph()) .map(|(p1, p2)| { [p1, p2].map(|link| self.graph.port_offset(link.port()).unwrap().into()) }) } #[inline] - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(node.pg_index(), dir) + fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone { + self.hierarchy + .children(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] - fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - self.hierarchy.children(node.pg_index()).map_into() + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + self.hierarchy + .descendants(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.graph.neighbours(node.pg_index(), dir).map_into() + self.graph.neighbours(node.into_portgraph(), dir).map_into() } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { - self.graph.all_neighbours(node.pg_index()).map_into() + self.graph.all_neighbours(node.into_portgraph()).map_into() + } + + fn mermaid_string(&self) -> String { + self.mermaid_string_with_config(RenderConfig { + node_indices: true, + port_offsets_in_edges: true, + type_labels_in_edges: true, + }) + } + + fn mermaid_string_with_config(&self, config: RenderConfig) -> String { + let graph = self.portgraph(); + graph + .mermaid_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + fn dot_string(&self) -> String + where + Self: Sized, + { + let graph = self.portgraph(); + let config = RenderConfig::default(); + graph + .dot_format() + .with_hierarchy(&self.hierarchy) + .with_node_style(render::node_style(self, config)) + .with_port_style(render::port_style(self, config)) + .with_edge_style(render::edge_style(self, config)) + .finish() + } + + #[inline] + fn extensions(&self) -> &ExtensionRegistry { + &self.extensions } } @@ -664,7 +634,7 @@ where hugr: &impl HugrView, ) -> impl Iterator { self.filter(move |(n, p)| { - let kind = hugr.get_optype(*n).port_kind(*p); + let kind = HugrView::get_optype(hugr, *n).port_kind(*p); predicate(kind) }) } @@ -676,3 +646,47 @@ where P: Into + Copy, { } + +/// Returns `true` if the node exists in the graph and is not the module at the hierarchy root. +pub(super) fn check_valid_non_root(hugr: &H, node: H::Node) -> bool { + hugr.contains_node(node) && node != hugr.root() +} + +/// Panic if [`HugrView::contains_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_node(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !hugr.contains_node(node) { + panic!("Received an invalid node {node}.",); + } +} + +/// Panic if [`check_valid_non_root`] fails. +#[track_caller] +pub(super) fn panic_invalid_non_root(hugr: &H, node: H::Node) { + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if !check_valid_non_root(hugr, node) { + panic!("Received an invalid non-root node {node}.",); + } +} + +/// Panic if [`HugrView::valid_node`] fails. +#[track_caller] +pub(super) fn panic_invalid_port( + hugr: &H, + node: Node, + port: impl Into, +) { + let port = port.into(); + // TODO: When stacking hugr wrappers, this gets called for every layer. + // Should we `cfg!(debug_assertions)` this? Benchmark and see if it matters. + if hugr + .portgraph() + .port_index(node.into_portgraph(), port.pg_offset()) + .is_none() + { + panic!("Received an invalid port {port} for node {node} while mutating a HUGR"); + } +} diff --git a/hugr-core/src/hugr/views/descendants.rs b/hugr-core/src/hugr/views/descendants.rs index 6f87027ef..13dfde8f7 100644 --- a/hugr-core/src/hugr/views/descendants.rs +++ b/hugr-core/src/hugr/views/descendants.rs @@ -8,7 +8,7 @@ use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type RegionGraph<'g> = portgraph::view::Region<'g, &'g MultiPortGraph>; @@ -41,37 +41,44 @@ pub struct DescendantsGraph<'g, Root = Node> { _phantom: std::marker::PhantomData, } impl HugrView for DescendantsGraph<'_, Root> { + #[inline] + fn root(&self) -> Self::Node { + self.root + } + #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] - fn node_count(&self) -> usize { + fn num_nodes(&self) -> usize { self.graph.node_count() } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { self.graph.link_count() } #[inline] fn nodes(&self) -> impl Iterator + Clone { - self.graph.nodes_iter().map(|index| self.get_node(index)) + self.graph + .nodes_iter() + .map(|index| self.from_portgraph_node(index)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -82,19 +89,19 @@ impl HugrView for DescendantsGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let port: PortIndex = link.into(); let node = self.graph.port_node(port).unwrap(); let offset = self.graph.port_offset(port).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| { [p1, p2].map(|link| { let offset = self.graph.port_offset(link).unwrap(); @@ -105,34 +112,47 @@ impl HugrView for DescendantsGraph<'_, Root> { #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn children(&self, node: Node) -> impl DoubleEndedIterator + Clone { - let children = match self.graph.contains_node(self.get_pg_index(node)) { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + let hierarchy = self.hierarchy(); + let children = match self.graph.contains_node(self.to_portgraph_node(node)) { + true => hierarchy.children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|index| self.get_node(index)) + children.map(move |index| { + let _ = hierarchy; + self.from_portgraph_node(index) + }) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|index| self.get_node(index)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|index| self.from_portgraph_node(index)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|index| self.get_node(index)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|index| self.from_portgraph_node(index)) + } + + delegate::delegate! { + to (&self.hugr) { + fn get_parent(&self, node: Self::Node) -> Option; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + } } -} -impl RootTagged for DescendantsGraph<'_, Root> { - type RootHandle = Root; } impl<'a, Root> HierarchyView<'a> for DescendantsGraph<'a, Root> @@ -140,11 +160,12 @@ where Root: NodeHandle, { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { - check_tag::(hugr, root)?; + check_tag::(hugr, root)?; + #[allow(deprecated)] let hugr = hugr.base_hugr(); Ok(Self { root, - graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: RegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.to_portgraph_node(root)), hugr, _phantom: std::marker::PhantomData, }) @@ -169,23 +190,38 @@ where &self.graph } - fn base_hugr(&self) -> &Hugr { - self.hugr + #[inline] + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } #[inline] - fn root_node(&self) -> Node { - self.root + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) + } + + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap { + self.hugr.node_metadata_map(node) + } + + fn base_hugr(&self) -> &Hugr { + self.hugr } } @@ -200,7 +236,7 @@ pub(super) mod test { use crate::{ builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}, types::Signature, - utils::test_quantum_extension::{h_gate, EXTENSION_ID}, + utils::test_quantum_extension::h_gate, }; use super::*; @@ -213,10 +249,8 @@ pub(super) mod test { let mut module_builder = ModuleBuilder::new(); let (f_id, inner_id) = { - let mut func_builder = module_builder.define_function( - "main", - Signature::new_endo(vec![usize_t(), qb_t()]).with_extension_delta(EXTENSION_ID), - )?; + let mut func_builder = module_builder + .define_function("main", Signature::new_endo(vec![usize_t(), qb_t()]))?; let [int, qb] = func_builder.input_wires_arr(); @@ -244,7 +278,7 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 7); + assert_eq!(region.num_nodes(), 7); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) || hugr.get_parent(n) == Some(inner))); @@ -252,11 +286,7 @@ pub(super) mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); let inner_region: DescendantsGraph = DescendantsGraph::try_new(&hugr, inner)?; @@ -264,8 +294,8 @@ pub(super) mod test { inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -314,8 +344,8 @@ pub(super) mod test { let region: DescendantsGraph = DescendantsGraph::try_new(&hugr, def)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/impls.rs b/hugr-core/src/hugr/views/impls.rs index 2cfc70104..9be352b5e 100644 --- a/hugr-core/src/hugr/views/impls.rs +++ b/hugr-core/src/hugr/views/impls.rs @@ -1,119 +1,241 @@ +//! Implementation of the core hugr traits for different wrappers of a `Hugr`. + use std::{borrow::Cow, rc::Rc, sync::Arc}; -use delegate::delegate; -use itertools::Either; +use super::HugrView; +use crate::hugr::internal::{HugrInternals, HugrMutInternals}; +use crate::hugr::HugrMut; -use super::{render::RenderConfig, HugrView, RootChecked}; -use crate::{ - extension::ExtensionRegistry, - hugr::{NodeMetadata, NodeMetadataMap, ValidationError}, - ops::OpType, - types::{PolyFuncType, Signature, Type}, - Direction, Hugr, IncomingPort, OutgoingPort, Port, -}; +macro_rules! hugr_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn portgraph(&self) -> Self::Portgraph<'_>; + fn region_portgraph(&self, parent: Self::Node) -> portgraph::view::FlatRegion<'_, impl portgraph::view::LinkView + Clone + '_>; + fn hierarchy(&self) -> &portgraph::Hierarchy; + fn to_portgraph_node(&self, node: impl crate::ops::handle::NodeHandle) -> portgraph::NodeIndex; + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node; + fn node_metadata_map(&self, node: Self::Node) -> &crate::hugr::NodeMetadataMap; + #[allow(deprecated)] + fn base_hugr(&self) -> &crate::Hugr; + } + } + }; +} +pub(crate) use hugr_internal_methods; macro_rules! hugr_view_methods { // The extra ident here is because invocations of the macro cannot pass `self` as argument ($arg:ident, $e:expr) => { - delegate! { + delegate::delegate! { to ({let $arg=self; $e}) { fn root(&self) -> Self::Node; - fn root_type(&self) -> &OpType; + fn root_optype(&self) -> &crate::ops::OpType; fn contains_node(&self, node: Self::Node) -> bool; - fn valid_node(&self, node: Self::Node) -> bool; - fn valid_non_root(&self, node: Self::Node) -> bool; fn get_parent(&self, node: Self::Node) -> Option; - fn get_optype(&self, node: Self::Node) -> &OpType; - fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&NodeMetadata>; - fn get_node_metadata(&self, node: Self::Node) -> Option<&NodeMetadataMap>; - fn node_count(&self) -> usize; - fn edge_count(&self) -> usize; - fn nodes(&self) -> impl Iterator + Clone; - fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; - fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; - fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; - fn linked_ports( - &self, - node: Self::Node, - port: impl Into, - ) -> impl Iterator + Clone; - fn all_linked_ports( - &self, - node: Self::Node, - dir: Direction, - ) -> Either< - impl Iterator, - impl Iterator, - >; - fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; - fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; - fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, Port)>; - fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, OutgoingPort)>; - fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, IncomingPort)>; - fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; - fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; - fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; - fn num_ports(&self, node: Self::Node, dir: Direction) -> usize; + fn get_metadata(&self, node: Self::Node, key: impl AsRef) -> Option<&crate::hugr::NodeMetadata>; + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType; + fn num_nodes(&self) -> usize; + fn num_edges(&self) -> usize; + fn num_ports(&self, node: Self::Node, dir: crate::Direction) -> usize; fn num_inputs(&self, node: Self::Node) -> usize; fn num_outputs(&self, node: Self::Node) -> usize; + fn nodes(&self) -> impl Iterator + Clone; + fn node_ports(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; + fn node_outputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn node_inputs(&self, node: Self::Node) -> impl Iterator + Clone; + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone; + fn linked_ports(&self, node: Self::Node, port: impl Into) -> impl Iterator + Clone; + fn all_linked_ports(&self, node: Self::Node, dir: crate::Direction) -> itertools::Either, impl Iterator>; + fn all_linked_outputs(&self, node: Self::Node) -> impl Iterator; + fn all_linked_inputs(&self, node: Self::Node) -> impl Iterator; + fn single_linked_port(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::Port)>; + fn single_linked_output(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::OutgoingPort)>; + fn single_linked_input(&self, node: Self::Node, port: impl Into) -> Option<(Self::Node, crate::IncomingPort)>; + fn linked_outputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn linked_inputs(&self, node: Self::Node, port: impl Into) -> impl Iterator; + fn node_connections(&self, node: Self::Node, other: Self::Node) -> impl Iterator + Clone; + fn is_linked(&self, node: Self::Node, port: impl Into) -> bool; fn children(&self, node: Self::Node) -> impl DoubleEndedIterator + Clone; + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone; fn first_child(&self, node: Self::Node) -> Option; - fn neighbours(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone; + fn neighbours(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator + Clone; fn input_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn output_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone; - fn get_io(&self, node: Self::Node) -> Option<[Self::Node; 2]>; - fn inner_function_type(&self) -> Option>; - fn poly_func_type(&self) -> Option; - // TODO: cannot use delegate here. `PetgraphWrapper` is a thin - // wrapper around `Self`, so falling back to the default impl - // should be harmless. - // fn as_petgraph(&self) -> PetgraphWrapper<'_, Self>; fn mermaid_string(&self) -> String; - fn mermaid_string_with_config(&self, config: RenderConfig) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; fn dot_string(&self) -> String; fn static_source(&self, node: Self::Node) -> Option; - fn static_targets(&self, node: Self::Node) -> Option>; - fn signature(&self, node: Self::Node) -> Option>; - fn value_types(&self, node: Self::Node, dir: Direction) -> impl Iterator; - fn in_value_types(&self, node: Self::Node) -> impl Iterator; - fn out_value_types(&self, node: Self::Node) -> impl Iterator; - fn extensions(&self) -> &ExtensionRegistry; - fn validate(&self) -> Result<(), ValidationError>; - fn validate_no_extensions(&self) -> Result<(), ValidationError>; + fn static_targets(&self, node: Self::Node) -> Option>; + fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator; + fn extensions(&self) -> &crate::extension::ExtensionRegistry; + fn validate(&self) -> Result<(), crate::hugr::ValidationError>; } } } } +pub(crate) use hugr_view_methods; + +macro_rules! hugr_mut_internal_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn replace_op(&mut self, node: Self::Node, op: impl Into) -> crate::ops::OpType; + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } + } + }; +} +pub(crate) use hugr_mut_internal_methods; + +macro_rules! hugr_mut_methods { + // The extra ident here is because invocations of the macro cannot pass `self` as argument + ($arg:ident, $e:expr) => { + delegate::delegate! { + to ({let $arg=self; $e}) { + fn get_metadata_mut(&mut self, node: Self::Node, key: impl AsRef) -> &mut crate::hugr::NodeMetadata; + fn set_metadata(&mut self, node: Self::Node, key: impl AsRef, metadata: impl Into); + fn remove_metadata(&mut self, node: Self::Node, key: impl AsRef); + fn add_node_with_parent(&mut self, parent: Self::Node, op: impl Into) -> Self::Node; + fn add_node_before(&mut self, sibling: Self::Node, nodetype: impl Into) -> Self::Node; + fn add_node_after(&mut self, sibling: Self::Node, op: impl Into) -> Self::Node; + fn remove_node(&mut self, node: Self::Node) -> crate::ops::OpType; + fn remove_subtree(&mut self, node: Self::Node); + fn copy_descendants(&mut self, root: Self::Node, new_parent: Self::Node, subst: Option) -> std::collections::BTreeMap; + fn connect(&mut self, src: Self::Node, src_port: impl Into, dst: Self::Node, dst_port: impl Into); + fn disconnect(&mut self, node: Self::Node, port: impl Into); + fn add_other_edge(&mut self, src: Self::Node, dst: Self::Node) -> (crate::OutgoingPort, crate::IncomingPort); + fn insert_hugr(&mut self, root: Self::Node, other: crate::Hugr) -> crate::hugr::hugrmut::InsertionResult; + fn insert_from_view(&mut self, root: Self::Node, other: &Other) -> crate::hugr::hugrmut::InsertionResult; + fn insert_subgraph(&mut self, root: Self::Node, other: &Other, subgraph: &crate::hugr::views::SiblingSubgraph) -> std::collections::HashMap; + fn use_extension(&mut self, extension: impl Into>); + fn use_extensions(&mut self, registry: impl IntoIterator) where crate::extension::ExtensionRegistry: Extend; + } + } + }; +} +pub(crate) use hugr_mut_methods; + +// -------- Immutable borrow +impl HugrInternals for &T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, *this} +} impl HugrView for &T { hugr_view_methods! {this, *this} } +// -------- Mutable borrow +impl HugrInternals for &mut T { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! {this, &**this} +} impl HugrView for &mut T { hugr_view_methods! {this, &**this} } +impl HugrMutInternals for &mut T { + hugr_mut_internal_methods! {this, &mut **this} +} +impl HugrMut for &mut T { + hugr_mut_methods! {this, &mut **this} +} + +// -------- Rc +impl HugrInternals for Rc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Rc { hugr_view_methods! {this, this.as_ref()} } +// -------- Arc +impl HugrInternals for Arc { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Arc { hugr_view_methods! {this, this.as_ref()} } +// -------- Box +impl HugrInternals for Box { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Box { hugr_view_methods! {this, this.as_ref()} } +impl HugrMutInternals for Box { + hugr_mut_internal_methods! {this, this.as_mut()} +} +impl HugrMut for Box { + hugr_mut_methods! {this, this.as_mut()} +} +// -------- Cow +impl HugrInternals for Cow<'_, T> { + type Portgraph<'p> + = T::Portgraph<'p> + where + Self: 'p; + type Node = T::Node; + + hugr_internal_methods! {this, this.as_ref()} +} impl HugrView for Cow<'_, T> { hugr_view_methods! {this, this.as_ref()} } - -impl, Root> HugrView for RootChecked { - hugr_view_methods! {this, this.as_ref()} +impl HugrMutInternals for Cow<'_, T> +where + T: HugrMutInternals + ToOwned, + ::Owned: HugrMutInternals, +{ + hugr_mut_internal_methods! {this, this.to_mut()} +} +impl HugrMut for Cow<'_, T> +where + T: HugrMut + ToOwned, + ::Owned: HugrMut, +{ + hugr_mut_methods! {this, this.to_mut()} } #[cfg(test)] diff --git a/hugr-core/src/hugr/views/petgraph.rs b/hugr-core/src/hugr/views/petgraph.rs index 17c3e0062..22da47f0a 100644 --- a/hugr-core/src/hugr/views/petgraph.rs +++ b/hugr-core/src/hugr/views/petgraph.rs @@ -55,7 +55,7 @@ where T: HugrView, { fn node_count(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } } @@ -64,15 +64,15 @@ where T: HugrView, { fn node_bound(&self) -> usize { - HugrView::node_count(self.hugr) + HugrView::num_nodes(self.hugr) } fn to_index(&self, ix: Self::NodeId) -> usize { - self.hugr.get_pg_index(ix).into() + self.hugr.to_portgraph_node(ix).into() } fn from_index(&self, ix: usize) -> Self::NodeId { - self.hugr.get_node(portgraph::NodeIndex::new(ix)) + self.hugr.from_portgraph_node(portgraph::NodeIndex::new(ix)) } } @@ -81,7 +81,7 @@ where T: HugrView, { fn edge_count(&self) -> usize { - HugrView::edge_count(self.hugr) + HugrView::num_edges(self.hugr) } } @@ -233,7 +233,7 @@ mod test { assert_eq!(wrapper.node_bound(), 5); assert_eq!(wrapper.edge_count(), 7); - let cx1_index = cx1.node().pg_index().index(); + let cx1_index = cx1.node().into_portgraph().index(); assert_eq!(wrapper.to_index(cx1.node()), cx1_index); assert_eq!(wrapper.from_index(cx1_index), cx1.node()); diff --git a/hugr-core/src/hugr/views/render.rs b/hugr-core/src/hugr/views/render.rs index ecb8549c0..43530e4c1 100644 --- a/hugr-core/src/hugr/views/render.rs +++ b/hugr-core/src/hugr/views/render.rs @@ -36,7 +36,7 @@ pub(super) fn node_style( config: RenderConfig, ) -> Box NodeStyle + '_> { fn node_name(h: &H, n: NodeIndex) -> String { - match h.get_optype(h.get_node(n)) { + match h.get_optype(h.from_portgraph_node(n)) { OpType::FuncDecl(f) => format!("FuncDecl: \"{}\"", f.name), OpType::FuncDefn(f) => format!("FuncDefn: \"{}\"", f.name), op => op.name().to_string(), @@ -45,14 +45,14 @@ pub(super) fn node_style( if config.node_indices { Box::new(move |n| { - NodeStyle::Box(format!( + NodeStyle::boxed(format!( "({ni}) {name}", ni = n.index(), name = node_name(h, n) )) }) } else { - Box::new(move |n| NodeStyle::Box(node_name(h, n))) + Box::new(move |n| NodeStyle::boxed(node_name(h, n))) } } @@ -64,7 +64,7 @@ pub(super) fn port_style( let graph = h.portgraph(); Box::new(move |port| { let node = graph.port_node(port).unwrap(); - let optype = h.get_optype(h.get_node(node)); + let optype = h.get_optype(h.from_portgraph_node(node)); let offset = graph.port_offset(port).unwrap(); match optype.port_kind(offset).unwrap() { EdgeKind::Function(pf) => PortStyle::new(html_escape::encode_text(&format!("{}", pf))), @@ -95,7 +95,7 @@ pub(super) fn edge_style( let graph = h.portgraph(); Box::new(move |src, tgt| { let src_node = graph.port_node(src).unwrap(); - let src_optype = h.get_optype(h.get_node(src_node)); + let src_optype = h.get_optype(h.from_portgraph_node(src_node)); let src_offset = graph.port_offset(src).unwrap(); let tgt_offset = graph.port_offset(tgt).unwrap(); diff --git a/hugr-core/src/hugr/views/root_checked.rs b/hugr-core/src/hugr/views/root_checked.rs index ba214241a..50c9bcf44 100644 --- a/hugr-core/src/hugr/views/root_checked.rs +++ b/hugr-core/src/hugr/views/root_checked.rs @@ -1,22 +1,28 @@ -use std::borrow::Cow; use std::marker::PhantomData; -use delegate::delegate; -use portgraph::MultiPortGraph; - -use crate::hugr::internal::{HugrInternals, HugrMutInternals}; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::HugrError; use crate::ops::handle::NodeHandle; +use crate::ops::{OpTag, OpTrait}; use crate::{Hugr, Node}; -use super::{check_tag, RootTagged}; +use super::HugrView; -/// A view of the whole Hugr. -/// (Just provides static checking of the type of the root node) +/// A wrapper over a Hugr that ensures the root node optype is of the required +/// [`OpTag`]. #[derive(Clone)] -pub struct RootChecked(H, PhantomData); +pub struct RootChecked(H, PhantomData); + +impl> RootChecked { + /// A tag that can contain the operation of the hugr root node. + const TAG: OpTag = Handle::TAG; + + /// Returns the most specific tag that can be applied to the root node. + pub fn tag(&self) -> OpTag { + let tag = self.0.get_optype(self.0.root()).tag(); + debug_assert!(Self::TAG.is_superset(tag)); + tag + } -impl, Root: NodeHandle> RootChecked { /// Create a hierarchical view of a whole HUGR /// /// # Errors @@ -24,82 +30,80 @@ impl, Root: NodeHandle> RootChecked { /// /// [`OpTag`]: crate::ops::OpTag pub fn try_new(hugr: H) -> Result { - if !H::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: H::RootHandle::TAG, - actual: Root::TAG, - }); - } - check_tag::(&hugr, hugr.root())?; + Self::check(&hugr)?; Ok(Self(hugr, PhantomData)) } -} -impl RootChecked { - /// Extracts the underlying (owned) Hugr - pub fn into_hugr(self) -> Hugr { - self.0 + /// Check if a Hugr is valid for the given [`OpTag`]. + /// + /// To check arbitrary nodes, use [`check_tag`]. + pub fn check(hugr: &H) -> Result<(), HugrError> { + check_tag::(hugr, hugr.root())?; + Ok(()) } -} -impl RootChecked<&mut Hugr, Root> { - /// Allows immutably borrowing the underlying mutable reference - pub fn borrow(&self) -> RootChecked<&Hugr, Root> { - RootChecked(&*self.0, PhantomData) + /// Returns a reference to the underlying Hugr. + pub fn hugr(&self) -> &H { + &self.0 } -} - -impl, Root> HugrInternals for RootChecked { - type Portgraph<'p> - = &'p MultiPortGraph - where - Self: 'p; - type Node = Node; - delegate! { - to self.as_ref() { - fn portgraph(&self) -> Self::Portgraph<'_>; - fn hierarchy(&self) -> Cow<'_, portgraph::Hierarchy>; - fn base_hugr(&self) -> &Hugr; - fn root_node(&self) -> Node; - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex; - fn get_node(&self, index: portgraph::NodeIndex) -> Node; - } + /// Extracts the underlying Hugr + pub fn into_hugr(self) -> H { + self.0 } -} -impl, Root: NodeHandle> RootTagged for RootChecked { - type RootHandle = Root; + /// Returns a wrapper over a reference to the underlying Hugr. + pub fn as_ref(&self) -> RootChecked<&H, Handle> { + RootChecked(&self.0, PhantomData) + } } -impl, Root> AsRef for RootChecked { +impl, Handle> AsRef for RootChecked { fn as_ref(&self) -> &Hugr { self.0.as_ref() } } -impl, Root> HugrMutInternals for RootChecked -where - Root: NodeHandle, -{ - #[inline(always)] - fn hugr_mut(&mut self) -> &mut Hugr { - self.0.hugr_mut() +/// A trait for types that can be checked for a specific [`OpTag`] at their root node. +/// +/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`]. +pub trait RootCheckable>: Sized { + /// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`]. + /// + /// If `Self` is already a [`RootChecked`], it is a no-op. + fn try_into_checked(self) -> Result, HugrError>; +} +impl> RootCheckable for H { + fn try_into_checked(self) -> Result, HugrError> { + RootChecked::try_new(self) + } +} +impl> RootCheckable for RootChecked { + fn try_into_checked(self) -> Result, HugrError> { + Ok(self) } } -impl, Root: NodeHandle> HugrMut for RootChecked {} +/// Check that the node in a HUGR can be represented by the required tag. +pub fn check_tag, N>( + hugr: &impl HugrView, + node: N, +) -> Result<(), HugrError> { + let actual = hugr.get_optype(node).tag(); + let required = Required::TAG; + if !required.is_superset(actual) { + return Err(HugrError::InvalidTag { required, actual }); + } + Ok(()) +} #[cfg(test)] mod test { use super::RootChecked; - use crate::extension::prelude::MakeTuple; - use crate::extension::ExtensionSet; - use crate::hugr::internal::HugrMutInternals; - use crate::hugr::{HugrError, HugrMut}; - use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID}; - use crate::ops::{DataflowBlock, OpTag, OpType}; - use crate::{ops, type_row, types::Signature, Hugr, HugrView}; + use crate::hugr::HugrError; + use crate::ops::handle::{CfgID, DfgID}; + use crate::ops::{OpTag, OpType}; + use crate::{ops, types::Signature, Hugr}; #[test] fn root_checked() { @@ -108,7 +112,7 @@ mod test { } .into(); let mut h = Hugr::new(root_type.clone()); - let cfg_v = RootChecked::<&Hugr, CfgID>::try_new(&h); + let cfg_v = RootChecked::<_, CfgID>::check(&h); assert_eq!( cfg_v.err(), Some(HugrError::InvalidTag { @@ -116,46 +120,9 @@ mod test { actual: OpTag::Dfg }) ); - let mut dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); - // That is a HugrMutInternal, so we can try: - let root = dfg_v.root(); - let bb: OpType = DataflowBlock { - inputs: type_row![], - other_outputs: type_row![], - sum_rows: vec![type_row![]], - extension_delta: ExtensionSet::new(), - } - .into(); - let r = dfg_v.replace_op(root, bb.clone()); - assert_eq!( - r, - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: ops::OpTag::DataflowBlock - }) - ); - // That didn't do anything: - assert_eq!(dfg_v.get_optype(root), &root_type); - - // Make a RootChecked that allows any DataflowParent - // We won't be able to do this by widening the bound: - assert_eq!( - RootChecked::<_, DataflowParentID>::try_new(dfg_v).err(), - Some(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::DataflowParent - }) - ); - - let mut dfp_v = RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut h).unwrap(); - let r = dfp_v.replace_op(root, bb.clone()); - assert_eq!(r, Ok(root_type)); - assert_eq!(dfp_v.get_optype(root), &bb); - // Just check we can create a nested instance (narrowing the bound) - let mut bb_v = RootChecked::<_, BasicBlockID>::try_new(dfp_v).unwrap(); - - // And it's a HugrMut: - let nodetype = MakeTuple(type_row![]); - bb_v.add_node_with_parent(bb_v.root(), nodetype); + // This should succeed + let dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap(); + assert!(OpTag::Dfg.is_superset(dfg_v.tag())); + assert_eq!(dfg_v.as_ref().tag(), dfg_v.tag()); } } diff --git a/hugr-core/src/hugr/views/sibling.rs b/hugr-core/src/hugr/views/sibling.rs index f93b14cb4..fa8378c7a 100644 --- a/hugr-core/src/hugr/views/sibling.rs +++ b/hugr-core/src/hugr/views/sibling.rs @@ -6,11 +6,12 @@ use itertools::{Either, Itertools}; use portgraph::{LinkView, MultiPortGraph, PortView}; use crate::hugr::internal::HugrMutInternals; -use crate::hugr::{HugrError, HugrMut}; +use crate::hugr::{HugrError, HugrMut, NodeMetadataMap}; use crate::ops::handle::NodeHandle; +use crate::ops::OpTrait; use crate::{Direction, Hugr, Node, Port}; -use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged}; +use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView}; type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>; @@ -50,15 +51,19 @@ pub struct SiblingGraph<'g, Root = Node> { macro_rules! impl_base_members { () => { #[inline] - fn node_count(&self) -> usize { - self.base_hugr() - .hierarchy - .child_count(self.get_pg_index(self.root)) + fn root(&self) -> Self::Node { + self.root + } + + #[inline] + fn num_nodes(&self) -> usize { + self.hierarchy() + .child_count(self.to_portgraph_node(self.root)) + 1 } #[inline] - fn edge_count(&self) -> usize { + fn num_edges(&self) -> usize { // Faster implementation than filtering all the nodes in the internal graph. self.nodes() .map(|n| self.output_neighbours(n).count()) @@ -69,10 +74,9 @@ macro_rules! impl_base_members { fn nodes(&self) -> impl Iterator + Clone { // Faster implementation than filtering all the nodes in the internal graph. let children = self - .base_hugr() - .hierarchy - .children(self.get_pg_index(self.root)) - .map(|n| self.get_node(n)); + .hierarchy() + .children(self.to_portgraph_node(self.root)) + .map(|n| self.from_portgraph_node(n)); iter::once(self.root).chain(children) } @@ -82,10 +86,41 @@ macro_rules! impl_base_members { ) -> impl DoubleEndedIterator + Clone { // Same as SiblingGraph let children = match node == self.root { - true => self.base_hugr().hierarchy.children(self.get_pg_index(node)), + true => self.hierarchy().children(self.to_portgraph_node(node)), false => portgraph::hierarchy::Children::default(), }; - children.map(|n| self.get_node(n)) + children.map(|n| self.from_portgraph_node(n)) + } + + fn get_optype(&self, node: Self::Node) -> &crate::ops::OpType { + self.hugr.get_optype(node) + } + + fn extensions(&self) -> &crate::extension::ExtensionRegistry { + self.hugr.extensions() + } + + fn get_parent(&self, node: Self::Node) -> Option { + match self.hugr.get_parent(node) { + Some(parent) if parent == self.root => Some(self.root), + _ => None, + } + } + + fn descendants(&self, node: Self::Node) -> impl Iterator + Clone { + if node == self.root { + Either::Left(self.hugr.descendants(node)) + } else { + Either::Right(iter::empty()) + } + } + + delegate::delegate! { + to (&self.hugr) { + fn mermaid_string(&self) -> String; + fn mermaid_string_with_config(&self, config: crate::hugr::views::render::RenderConfig) -> String; + fn dot_string(&self) -> String; + } } }; } @@ -95,20 +130,20 @@ impl HugrView for SiblingGraph<'_, Root> { #[inline] fn contains_node(&self, node: Node) -> bool { - self.graph.contains_node(self.get_pg_index(node)) + self.graph.contains_node(self.to_portgraph_node(node)) } #[inline] fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .port_offsets(self.get_pg_index(node), dir) + .port_offsets(self.to_portgraph_node(node), dir) .map_into() } #[inline] fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_port_offsets(self.get_pg_index(node)) + .all_port_offsets(self.to_portgraph_node(node)) .map_into() } @@ -119,50 +154,52 @@ impl HugrView for SiblingGraph<'_, Root> { ) -> impl Iterator + Clone { let port = self .graph - .port_index(self.get_pg_index(node), port.into().pg_offset()) + .port_index(self.to_portgraph_node(node), port.into().pg_offset()) .unwrap(); self.graph.port_links(port).map(|(_, link)| { let node = self.graph.port_node(link).unwrap(); let offset = self.graph.port_offset(link).unwrap(); - (self.get_node(node), offset.into()) + (self.from_portgraph_node(node), offset.into()) }) } fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { self.graph - .get_connections(self.get_pg_index(node), self.get_pg_index(other)) + .get_connections(self.to_portgraph_node(node), self.to_portgraph_node(other)) .map(|(p1, p2)| [p1, p2].map(|link| self.graph.port_offset(link).unwrap().into())) } #[inline] fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.graph.num_ports(self.get_pg_index(node), dir) + self.graph.num_ports(self.to_portgraph_node(node), dir) } #[inline] fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { self.graph - .neighbours(self.get_pg_index(node), dir) - .map(|n| self.get_node(n)) + .neighbours(self.to_portgraph_node(node), dir) + .map(|n| self.from_portgraph_node(n)) } #[inline] fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { self.graph - .all_neighbours(self.get_pg_index(node)) - .map(|n| self.get_node(n)) + .all_neighbours(self.to_portgraph_node(node)) + .map(|n| self.from_portgraph_node(n)) } } -impl RootTagged for SiblingGraph<'_, Root> { - type RootHandle = Root; -} impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> { fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self { + #[allow(deprecated)] let hugr = hugr.base_hugr(); Self { root, - graph: FlatRegionGraph::new(&hugr.graph, &hugr.hierarchy, hugr.get_pg_index(root)), + graph: FlatRegionGraph::new_with_root( + &hugr.graph, + &hugr.hierarchy, + hugr.to_portgraph_node(root), + ), hugr, _phantom: std::marker::PhantomData, } @@ -175,7 +212,7 @@ where { fn try_new(hugr: &'a impl HugrView, root: Node) -> Result { assert!( - hugr.valid_node(root), + hugr.contains_node(root), "Cannot create a sibling graph from an invalid node {}.", root ); @@ -201,24 +238,40 @@ where &self.graph } + #[inline] + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) + } + + #[inline] + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() + } + #[inline] fn base_hugr(&self) -> &Hugr { self.hugr } #[inline] - fn root_node(&self) -> Node { - self.root + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Node { + self.hugr.from_portgraph_node(index) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) } } @@ -233,101 +286,120 @@ where /// [HugrView] methods may be slower than for an immutable [SiblingGraph] /// as the latter may cache information about the graph connectivity, /// whereas (in order to ease mutation) this does not. -pub struct SiblingMut<'g, Root = Node> { +pub struct SiblingMut<'g, H: HugrView, Root = Node> { /// The chosen root node. - root: Node, + root: H::Node, /// The rest of the HUGR. - hugr: &'g mut Hugr, + hugr: &'g mut H, /// The operation type of the root node. _phantom: std::marker::PhantomData, } -impl<'g, Root: NodeHandle> SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> SiblingMut<'g, H, Root> { /// Create a new SiblingMut from a base. /// Equivalent to [HierarchyView::try_new] but takes a *mutable* reference. - pub fn try_new(hugr: &'g mut Base, root: Node) -> Result { - if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) { - return Err(HugrError::InvalidTag { - required: Base::RootHandle::TAG, - actual: Root::TAG, - }); - } + pub fn try_new(hugr: &'g mut H, root: H::Node) -> Result { check_tag::(hugr, root)?; Ok(Self { - hugr: hugr.hugr_mut(), + hugr, root, _phantom: std::marker::PhantomData, }) } } -impl ExtractHugr for SiblingMut<'_, Root> {} +impl> ExtractHugr for SiblingMut<'_, H, Root> {} -impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> { +impl<'g, H: HugrMut, Root: NodeHandle> HugrInternals for SiblingMut<'g, H, Root> { type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p; - type Node = Node; + type Node = H::Node; + #[inline] fn portgraph(&self) -> Self::Portgraph<'_> { - FlatRegionGraph::new( + FlatRegionGraph::new_with_root( + #[allow(deprecated)] &self.base_hugr().graph, - &self.base_hugr().hierarchy, - self.root.pg_index(), + self.hierarchy(), + self.to_portgraph_node(self.root), ) } - fn base_hugr(&self) -> &Hugr { - self.hugr + #[inline] + fn region_portgraph( + &self, + parent: Self::Node, + ) -> portgraph::view::FlatRegion< + '_, + impl portgraph::view::LinkView + Clone + '_, + > { + self.hugr.region_portgraph(parent) } - fn root_node(&self) -> Node { - self.root + #[inline] + fn hierarchy(&self) -> &portgraph::Hierarchy { + self.hugr.hierarchy() + } + + #[inline] + fn to_portgraph_node(&self, node: impl NodeHandle) -> portgraph::NodeIndex { + self.hugr.to_portgraph_node(node) } #[inline] - fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex { - self.hugr.get_pg_index(node) + fn from_portgraph_node(&self, index: portgraph::NodeIndex) -> Self::Node { + self.hugr.from_portgraph_node(index) } #[inline] - fn get_node(&self, index: portgraph::NodeIndex) -> Node { - self.hugr.get_node(index) + fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap { + self.hugr.node_metadata_map(node) + } + + #[inline] + fn base_hugr(&self) -> &Hugr { + #[allow(deprecated)] + self.hugr.base_hugr() } } -impl HugrView for SiblingMut<'_, Root> { +impl> HugrView for SiblingMut<'_, H, Root> { impl_base_members! {} - fn contains_node(&self, node: Node) -> bool { + fn contains_node(&self, node: H::Node) -> bool { // Don't call self.get_parent(). That requires valid_node(node) // which infinitely-recurses back here. - node == self.root || self.base_hugr().get_parent(node) == Some(self.root) + node == self.root || self.hugr.get_parent(node) == Some(self.root) } - fn node_ports(&self, node: Node, dir: Direction) -> impl Iterator + Clone { - self.base_hugr().node_ports(node, dir) + fn node_ports(&self, node: Self::Node, dir: Direction) -> impl Iterator + Clone { + self.hugr.node_ports(node, dir) } - fn all_node_ports(&self, node: Node) -> impl Iterator + Clone { - self.base_hugr().all_node_ports(node) + fn all_node_ports(&self, node: Self::Node) -> impl Iterator + Clone { + self.hugr.all_node_ports(node) } fn linked_ports( &self, - node: Node, + node: Self::Node, port: impl Into, - ) -> impl Iterator + Clone { + ) -> impl Iterator + Clone { self.hugr .linked_ports(node, port) .filter(|(n, _)| self.contains_node(*n)) } - fn node_connections(&self, node: Node, other: Node) -> impl Iterator + Clone { + fn node_connections( + &self, + node: Self::Node, + other: Self::Node, + ) -> impl Iterator + Clone { match self.contains_node(node) && self.contains_node(other) { // The nodes are not in the sibling graph false => Either::Left(iter::empty()), @@ -336,34 +408,64 @@ impl HugrView for SiblingMut<'_, Root> { } } - fn num_ports(&self, node: Node, dir: Direction) -> usize { - self.base_hugr().num_ports(node, dir) + fn num_ports(&self, node: Self::Node, dir: Direction) -> usize { + self.hugr.num_ports(node, dir) } - fn neighbours(&self, node: Node, dir: Direction) -> impl Iterator + Clone { + fn neighbours( + &self, + node: Self::Node, + dir: Direction, + ) -> impl Iterator + Clone { self.hugr .neighbours(node, dir) .filter(|n| self.contains_node(*n)) } - fn all_neighbours(&self, node: Node) -> impl Iterator + Clone { + fn all_neighbours(&self, node: Self::Node) -> impl Iterator + Clone { self.hugr .all_neighbours(node) .filter(|n| self.contains_node(*n)) } } -impl RootTagged for SiblingMut<'_, Root> { - type RootHandle = Root; -} - -impl HugrMutInternals for SiblingMut<'_, Root> { - fn hugr_mut(&mut self) -> &mut Hugr { - self.hugr +impl> HugrMutInternals for SiblingMut<'_, H, Root> { + fn replace_op( + &mut self, + node: Self::Node, + op: impl Into, + ) -> crate::ops::OpType { + let op = op.into(); + // Note: `SiblingMut` will be removed in a subsequent PR, so we just panic here for now. + if node == self.root() && !Root::TAG.is_superset(op.tag()) { + let err = HugrError::InvalidTag { + required: Root::TAG, + actual: op.tag(), + }; + panic!("{err}"); + } + self.hugr.replace_op(node, op) + } + + delegate::delegate! { + to (&mut *self.hugr) { + fn set_root(&mut self, root: Self::Node); + fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize); + fn add_ports(&mut self, node: Self::Node, direction: crate::Direction, amount: isize) -> std::ops::Range; + fn insert_ports(&mut self, node: Self::Node, direction: crate::Direction, index: usize, amount: usize) -> std::ops::Range; + fn set_parent(&mut self, node: Self::Node, parent: Self::Node); + fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node); + fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node); + fn optype_mut(&mut self, node: Self::Node) -> &mut crate::ops::OpType; + fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut crate::hugr::NodeMetadataMap; + fn extensions_mut(&mut self) -> &mut crate::extension::ExtensionRegistry; + } } } -impl HugrMut for SiblingMut<'_, Root> {} +impl> HugrMut for SiblingMut<'_, H, Root> { + super::impls::hugr_mut_methods! {this, &mut *this.hugr} +} #[cfg(test)] mod test { @@ -374,11 +476,10 @@ mod test { use crate::builder::test::simple_dfg_hugr; use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder}; use crate::extension::prelude::{qb_t, usize_t}; - use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID}; + use crate::ops::handle::{CfgID, DfgID, FuncID}; + use crate::ops::OpType; use crate::ops::{dataflow::IOTrait, Input, OpTag, Output}; - use crate::ops::{OpTrait, OpType}; use crate::types::Signature; - use crate::utils::test_quantum_extension::EXTENSION_ID; use crate::IncomingPort; use super::super::descendants::test::make_module_hgr; @@ -396,7 +497,7 @@ mod test { { let def_io = region.get_io(def).unwrap(); - assert_eq!(region.node_count(), 5); + assert_eq!(region.num_nodes(), 5); assert_eq!(region.portgraph().node_count(), 5); assert!(region.nodes().all(|n| n == def || hugr.get_parent(n) == Some(def) @@ -405,19 +506,15 @@ mod test { assert_eq!( region.poly_func_type(), - Some( - Signature::new_endo(vec![usize_t(), qb_t()]) - .with_extension_delta(EXTENSION_ID) - .into() - ) + Some(Signature::new_endo(vec![usize_t(), qb_t()]).into()) ); assert_eq!( inner_region.inner_function_type().map(Cow::into_owned), Some(Signature::new(vec![usize_t()], vec![usize_t()])) ); - assert_eq!(inner_region.node_count(), 3); - assert_eq!(inner_region.edge_count(), 1); + assert_eq!(inner_region.num_nodes(), 3); + assert_eq!(inner_region.num_edges(), 1); assert_eq!(inner_region.children(inner).count(), 2); assert_eq!(inner_region.children(hugr.root()).count(), 0); assert_eq!( @@ -475,7 +572,7 @@ mod test { let mut def_region_hugr = hugr.clone(); let mut inner_region_hugr = hugr.clone(); - test_properties::( + test_properties::>( &hugr, def, inner, @@ -526,7 +623,7 @@ mod test { let root = simple_dfg_hugr.root(); let signature = simple_dfg_hugr.inner_function_type().unwrap().into_owned(); - let sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root); + let sib_mut = SiblingMut::<_, CfgID>::try_new(&mut simple_dfg_hugr, root); assert_eq!( sib_mut.err(), Some(HugrError::InvalidTag { @@ -535,45 +632,13 @@ mod test { }) ); - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); let bad_nodetype: OpType = crate::ops::CFG { signature }.into(); - assert_eq!( - sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Cfg - }) - ); - // In contrast, performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation - simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap(); + // Performing this on the Hugr (where the allowed root type is 'Any') is only detected by validation + simple_dfg_hugr.replace_op(root, bad_nodetype); assert!(simple_dfg_hugr.validate().is_err()); } - #[rstest] - fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) { - let root = simple_dfg_hugr.root(); - let case_nodetype = crate::ops::Case { - signature: simple_dfg_hugr - .root_type() - .dataflow_signature() - .unwrap() - .into_owned(), - }; - let mut sib_mut = SiblingMut::::try_new(&mut simple_dfg_hugr, root).unwrap(); - // As expected, we cannot replace the root with a Case - assert_eq!( - sib_mut.replace_op(root, case_nodetype), - Err(HugrError::InvalidTag { - required: OpTag::Dfg, - actual: OpTag::Case - }) - ); - - let nested_sib_mut = SiblingMut::::try_new(&mut sib_mut, root); - assert!(nested_sib_mut.is_err()); - } - #[rstest] fn extract_hugr() -> Result<(), Box> { let (hugr, _def, inner) = make_module_hgr()?; @@ -584,8 +649,8 @@ mod test { let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?; - assert_eq!(region.node_count(), extracted.node_count()); - assert_eq!(region.root_type(), extracted.root_type()); + assert_eq!(region.num_nodes(), extracted.num_nodes()); + assert_eq!(region.root_optype(), extracted.root_optype()); Ok(()) } diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index a0bf1a3da..b2eba044e 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -21,14 +21,15 @@ use thiserror::Error; use crate::builder::{Container, FunctionBuilder}; use crate::core::HugrNode; -use crate::extension::ExtensionSet; -use crate::hugr::{HugrMut, HugrView, RootTagged}; +use crate::hugr::{HugrMut, HugrView}; use crate::ops::dataflow::DataflowOpTrait; use crate::ops::handle::{ContainerHandle, DataflowOpID}; use crate::ops::{NamedOp, OpTag, OpTrait, OpType}; use crate::types::{Signature, Type}; use crate::{Hugr, IncomingPort, Node, OutgoingPort, Port, SimpleReplacement}; +use super::root_checked::RootCheckable; + /// A non-empty convex subgraph of a HUGR sibling graph. /// /// A HUGR region in which all nodes share the same parent. Unlike @@ -95,11 +96,18 @@ impl SiblingSubgraph { /// /// This will return an [`InvalidSubgraph::EmptySubgraph`] error if the /// subgraph is empty. - pub fn try_new_dataflow_subgraph(dfg_graph: &H) -> Result> + pub fn try_new_dataflow_subgraph<'h, H, Root>( + dfg_graph: impl RootCheckable<&'h H, Root>, + ) -> Result> where - H: Clone + RootTagged, - Root: ContainerHandle, + H: 'h + Clone + HugrView, + Root: ContainerHandle, { + let Ok(dfg_graph) = dfg_graph.try_into_checked() else { + return Err(InvalidSubgraph::NonDataflowRegion); + }; + let dfg_graph = dfg_graph.into_hugr(); + let parent = dfg_graph.root(); let nodes = dfg_graph.children(parent).skip(2).collect_vec(); let (inputs, outputs) = get_input_output_ports(dfg_graph); @@ -185,7 +193,7 @@ impl SiblingSubgraph { let subpg = Subgraph::new_subgraph(pg.clone(), make_boundary(hugr, &inputs, &outputs)); let nodes = subpg .nodes_iter() - .map(|index| hugr.get_node(index)) + .map(|index| hugr.from_portgraph_node(index)) .collect_vec(); validate_subgraph(hugr, &nodes, &inputs, &outputs)?; @@ -340,11 +348,7 @@ impl SiblingSubgraph { sig.port_type(p).cloned().expect("must be dataflow edge") }) .collect_vec(); - Signature::new(input, output).with_extension_delta(ExtensionSet::union_over( - self.nodes - .iter() - .map(|n| hugr.get_optype(*n).extension_delta()), - )) + Signature::new(input, output) } /// The parent of the sibling subgraph. @@ -446,16 +450,14 @@ impl SiblingSubgraph { nu_out, )) } -} -impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, - hugr: &impl HugrView, + hugr: &impl HugrView, name: impl Into, ) -> Hugr { let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap(); @@ -518,7 +520,7 @@ fn make_boundary<'a, N: HugrNode>( ) -> Boundary { let to_pg_index = |n: N, p: Port| { hugr.portgraph() - .port_index(hugr.get_pg_index(n), p.pg_offset()) + .port_index(hugr.to_portgraph_node(n), p.pg_offset()) .unwrap() }; Boundary::new( @@ -800,6 +802,9 @@ pub enum InvalidSubgraph { /// An invalid boundary port was found. #[error("Invalid boundary port.")] InvalidBoundary(#[from] InvalidSubgraphBoundary), + /// The hugr region is not a dataflow graph. + #[error("SiblingSubgraphs may only be defined on dataflow regions.")] + NonDataflowRegion, } /// Errors that can occur while constructing a [`SiblingSubgraph`]. @@ -828,12 +833,12 @@ mod tests { use cool_asserts::assert_matches; use crate::builder::inout_sig; - use crate::hugr::Rewrite; + use crate::hugr::Patch; use crate::ops::Const; - use crate::std_extensions::arithmetic::float_types::{self, ConstF64}; - use crate::std_extensions::logic::{self, LogicOp}; + use crate::std_extensions::arithmetic::float_types::ConstF64; + use crate::std_extensions::logic::LogicOp; use crate::type_row; - use crate::utils::test_quantum_extension::{self, cx_gate, rz_f64}; + use crate::utils::test_quantum_extension::{cx_gate, rz_f64}; use crate::{ builder::{ BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, @@ -879,12 +884,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) - .with_extension_delta(ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ])) - .into(), + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -903,12 +903,7 @@ mod tests { /// A bool to bool hugr with three subsequent NOT gates. fn build_3not_hugr() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(vec![bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(vec![bool_t()]).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let outs1 = dfg.add_dataflow_op(LogicOp::Not, dfg.input_wires())?; @@ -927,9 +922,7 @@ mod tests { let mut mod_builder = ModuleBuilder::new(); let func = mod_builder.declare( "test", - Signature::new(bool_t(), vec![bool_t(), bool_t()]) - .with_extension_delta(logic::EXTENSION_ID) - .into(), + Signature::new(bool_t(), vec![bool_t(), bool_t()]).into(), )?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; @@ -947,12 +940,7 @@ mod tests { /// A HUGR with a copy fn build_hugr_classical() -> Result<(Hugr, Node), BuildError> { let mut mod_builder = ModuleBuilder::new(); - let func = mod_builder.declare( - "test", - Signature::new_endo(bool_t()) - .with_extension_delta(logic::EXTENSION_ID) - .into(), - )?; + let func = mod_builder.declare("test", Signature::new_endo(bool_t()).into())?; let func_id = { let mut dfg = mod_builder.define_declaration(&func)?; let in_wire = dfg.input_wires().exactly_one().unwrap(); @@ -987,7 +975,7 @@ mod tests { fn construct_simple_replacement() -> Result<(), InvalidSubgraph> { let (mut hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; let empty_dfg = { let builder = @@ -1000,9 +988,9 @@ mod tests { assert_eq!(rep.subgraph().nodes().len(), 4); - assert_eq!(hugr.node_count(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out - hugr.apply_rewrite(rep).unwrap(); - assert_eq!(hugr.node_count(), 4); // Module + Def + In + Out + assert_eq!(hugr.num_nodes(), 8); // Module + Def + In + CX + Rz + Const + LoadConst + Out + hugr.apply_patch(rep).unwrap(); + assert_eq!(hugr.num_nodes(), 4); // Module + Def + In + Out Ok(()) } @@ -1011,15 +999,10 @@ mod tests { fn test_signature() -> Result<(), InvalidSubgraph> { let (hugr, dfg) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, dfg).unwrap(); - let sub = SiblingSubgraph::try_new_dataflow_subgraph(&func)?; + let sub = SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func)?; assert_eq!( sub.signature(&func), - Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]).with_extension_delta( - ExtensionSet::from_iter([ - test_quantum_extension::EXTENSION_ID, - float_types::EXTENSION_ID, - ]) - ) + Signature::new_endo(vec![qb_t(), qb_t(), qb_t()]) ); Ok(()) } @@ -1048,7 +1031,7 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); assert_eq!( - SiblingSubgraph::try_new_dataflow_subgraph(&func) + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func) .unwrap() .nodes() .len(), @@ -1164,7 +1147,8 @@ mod tests { let (hugr, func_root) = build_hugr_classical().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let func = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let func = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let func_defn = hugr.get_optype(func_root).as_func_defn().unwrap(); assert_eq!(func_defn.signature, func.signature(&func_graph).into()); } @@ -1174,7 +1158,8 @@ mod tests { let (hugr, func_root) = build_hugr().unwrap(); let func_graph: SiblingGraph<'_, FuncID> = SiblingGraph::try_new(&hugr, func_root).unwrap(); - let subgraph = SiblingSubgraph::try_new_dataflow_subgraph(&func_graph).unwrap(); + let subgraph = + SiblingSubgraph::try_new_dataflow_subgraph::<_, FuncID>(&func_graph).unwrap(); let extracted = subgraph.extract_subgraph(&hugr, "region"); extracted.validate().unwrap(); @@ -1199,19 +1184,14 @@ mod tests { let outw = [outw1].into_iter().chain(outw2); let h = builder.finish_hugr_with_outputs(outw).unwrap(); let view = SiblingGraph::::try_new(&h, h.root()).unwrap(); - let subg = SiblingSubgraph::try_new_dataflow_subgraph(&view).unwrap(); + let subg = SiblingSubgraph::try_new_dataflow_subgraph::<_, DfgID>(&view).unwrap(); assert_eq!(subg.nodes().len(), 2); } #[test] fn test_unconnected() { // test a replacement on a subgraph with a discarded output - let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - // .with_prelude() - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded @@ -1222,11 +1202,7 @@ mod tests { assert_eq!(subg.nodes().len(), 1); // TODO create a valid replacement let replacement = { - let mut rep_b = DFGBuilder::new( - Signature::new_endo(bool_t()) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut rep_b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let inw = rep_b.input_wires().exactly_one().unwrap(); let not_n = rep_b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); @@ -1241,11 +1217,7 @@ mod tests { #[test] fn single_node_subgraph() { // A hugr with a single NOT operation, with disconnected output. - let mut b = DFGBuilder::new( - Signature::new(bool_t(), type_row![]) - .with_extension_delta(crate::std_extensions::logic::EXTENSION_ID), - ) - .unwrap(); + let mut b = DFGBuilder::new(Signature::new(bool_t(), type_row![])).unwrap(); let inw = b.input_wires().exactly_one().unwrap(); let not_n = b.add_dataflow_op(LogicOp::Not, [inw]).unwrap(); // Unconnected output, discarded diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 642c84c41..ce5971364 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use crate::{ - extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError}, + extension::{ExtensionId, ExtensionRegistry, SignatureError}, hugr::{HugrMut, NodeMetadata}, ops::{ constant::{CustomConst, CustomSerialized, OpaqueValue}, @@ -35,6 +35,7 @@ use thiserror::Error; /// Error during import. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and @@ -75,6 +76,7 @@ pub enum ImportError { /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] @@ -789,7 +791,6 @@ impl<'a> Context<'a> { just_inputs, just_outputs, rest, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -817,7 +818,6 @@ impl<'a> Context<'a> { sum_rows, other_inputs, outputs, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -885,7 +885,6 @@ impl<'a> Context<'a> { inputs: types.clone(), other_outputs: TypeRow::default(), sum_rows: vec![types.clone()], - extension_delta: ExtensionSet::default(), }), ); @@ -986,7 +985,6 @@ impl<'a> Context<'a> { inputs, other_outputs, sum_rows, - extension_delta: ExtensionSet::new(), }); let node = self.make_node(node_id, optype, parent)?; @@ -1489,7 +1487,7 @@ impl<'a> Context<'a> { let runtime_type = self.import_type(runtime_type)?; let value: serde_json::Value = serde_json::from_str(json) .map_err(|_| table::ModelError::TypeError(term_id))?; - let custom_const = CustomSerialized::new(runtime_type, value, ExtensionSet::new()); + let custom_const = CustomSerialized::new(runtime_type, value); let opaque_value = OpaqueValue::new(custom_const); return Ok(Value::Extension { e: opaque_value }); } diff --git a/hugr-core/src/lib.rs b/hugr-core/src/lib.rs index e32b623f2..e5f57d2a8 100644 --- a/hugr-core/src/lib.rs +++ b/hugr-core/src/lib.rs @@ -12,11 +12,9 @@ pub mod builder; pub mod core; pub mod envelope; -#[cfg(feature = "model_unstable")] pub mod export; pub mod extension; pub mod hugr; -#[cfg(feature = "model_unstable")] pub mod import; pub mod macros; pub mod ops; diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 0c7d3bb3f..5b5dbc420 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,17 +9,19 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::core::HugrNode; use crate::extension::resolution::{ collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, }; use std::borrow::Cow; use crate::extension::simple_op::MakeExtensionOp; -use crate::extension::{ExtensionId, ExtensionRegistry, ExtensionSet}; +use crate::extension::{ExtensionId, ExtensionRegistry}; use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; +use handle::NodeHandle; use paste::paste; use portgraph::NodeIndex; @@ -41,7 +43,6 @@ pub use tag::OpTag; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] /// The concrete operation types for a node in the HUGR. -// TODO: Link the NodeHandles to the OpType. #[non_exhaustive] #[allow(missing_docs)] #[serde(tag = "op")] @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone { /// Tag identifying the operation. fn tag(&self) -> OpTag; + /// Tries to create a specific [`NodeHandle`] for a node with this operation + /// type. + /// + /// Fails if the operation's [`OpTrait::tag`] does not match the + /// [`NodeHandle::TAG`] of the requested handle. + fn try_node_handle(&self, node: N) -> Option + where + N: HugrNode, + H: NodeHandle + From, + { + H::TAG.is_superset(self.tag()).then(|| node.into()) + } + /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. @@ -384,12 +398,6 @@ pub trait OpTrait: Sized + Clone { None } - /// The delta between the input extensions specified for a node, - /// and the output extensions calculated for that node - fn extension_delta(&self) -> ExtensionSet { - ExtensionSet::new() - } - /// The edge kind for the non-dataflow inputs of the operation, /// not described by the signature. /// diff --git a/hugr-core/src/ops/constant.rs b/hugr-core/src/ops/constant.rs index 794e6eaaa..18f3974d4 100644 --- a/hugr-core/src/ops/constant.rs +++ b/hugr-core/src/ops/constant.rs @@ -8,7 +8,6 @@ use std::hash::{Hash, Hasher}; use super::{NamedOp, OpName, OpTrait, StaticTag}; use super::{OpTag, OpType}; -use crate::extension::ExtensionSet; use crate::types::{CustomType, EdgeKind, Signature, SumType, SumTypeError, Type, TypeRow}; use crate::{Hugr, HugrView}; @@ -81,10 +80,6 @@ impl OpTrait for Const { "Constant value" } - fn extension_delta(&self) -> ExtensionSet { - self.value().extension_reqs() - } - fn tag(&self) -> OpTag { ::TAG } @@ -251,7 +246,6 @@ pub enum Value { /// use serde_json::json; /// /// let expected_json = json!({ -/// "extensions": ["prelude"], /// "typ": usize_t(), /// "value": {'c': "ConstUsize", 'v': 1} /// }); @@ -259,9 +253,8 @@ pub enum Value { /// assert_eq!(&serde_json::to_value(&ev).unwrap(), &expected_json); /// assert_eq!(ev, serde_json::from_value(expected_json).unwrap()); /// -/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null, ExtensionSet::default())); +/// let ev = OpaqueValue::new(CustomSerialized::new(usize_t().clone(), serde_json::Value::Null)); /// let expected_json = json!({ -/// "extensions": [], /// "typ": usize_t(), /// "value": null /// }); @@ -297,8 +290,6 @@ impl OpaqueValue { pub fn get_type(&self) -> Type; /// An identifier of the internal [`CustomConst`]. pub fn name(&self) -> ValueName; - /// The extension(s) defining the internal [`CustomConst`]. - pub fn extension_reqs(&self) -> ExtensionSet; } } } @@ -364,7 +355,7 @@ pub enum ConstTypeError { /// Hugrs (even functions) inside Consts must be monomorphic fn mono_fn_type(h: &Hugr) -> Result, ConstTypeError> { let err = || ConstTypeError::NotMonomorphicFunction { - hugr_root_type: h.root_type().clone(), + hugr_root_type: h.root_optype().clone(), }; if let Some(pf) = h.poly_func_type() { match pf.try_into() { @@ -523,17 +514,6 @@ impl Value { .into() } - /// The extensions required by a [`Value`] - pub fn extension_reqs(&self) -> ExtensionSet { - match self { - Self::Extension { e } => e.extension_reqs().clone(), - Self::Function { .. } => ExtensionSet::new(), // no extensions required to load Hugr (only to run) - Self::Sum(Sum { values, .. }) => { - ExtensionSet::union_over(values.iter().map(|x| x.extension_reqs())) - } - } - } - /// Check the value. pub fn validate(&self) -> Result<(), ConstTypeError> { match self { @@ -631,10 +611,6 @@ pub(crate) mod test { format!("CustomTestValue({:?})", self.0).into() } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(self.0.extension().clone()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -849,8 +825,7 @@ pub(crate) mod test { // Dummy extension reference. &Weak::default(), ); - let json_const: Value = - CustomSerialized::new(typ_int.clone(), 6.into(), ex_id.clone()).into(); + let json_const: Value = CustomSerialized::new(typ_int.clone(), 6.into()).into(); let classic_t = Type::new_extension(typ_int.clone()); assert_matches!(classic_t.least_upper_bound(), TypeBound::Copyable); assert_eq!(json_const.get_type(), classic_t); diff --git a/hugr-core/src/ops/constant/custom.rs b/hugr-core/src/ops/constant/custom.rs index 985e15594..6ff1b67aa 100644 --- a/hugr-core/src/ops/constant/custom.rs +++ b/hugr-core/src/ops/constant/custom.rs @@ -13,7 +13,6 @@ use thiserror::Error; use crate::extension::resolution::{ resolve_type_extensions, ExtensionResolutionError, WeakExtensionRegistry, }; -use crate::extension::ExtensionSet; use crate::macros::impl_box_clone; use crate::types::{CustomCheckFailure, Type}; use crate::IncomingPort; @@ -44,7 +43,6 @@ use super::{Value, ValueName}; /// #[typetag::serde] /// impl CustomConst for CC { /// fn name(&self) -> ValueName { "CC".into() } -/// fn extension_reqs(&self) -> ExtensionSet { ExtensionSet::singleton(int_types::EXTENSION_ID) } /// fn get_type(&self) -> Type { int_types::INT_TYPES[5].clone() } /// } /// @@ -61,13 +59,6 @@ pub trait CustomConst: /// An identifier for the constant. fn name(&self) -> ValueName; - /// The extension(s) defining the custom constant - /// (a set to allow, say, a [List] of [USize]) - /// - /// [List]: crate::std_extensions::collections::list::LIST_TYPENAME - /// [USize]: crate::extension::prelude::usize_t - fn extension_reqs(&self) -> ExtensionSet; - /// Check the value. fn validate(&self) -> Result<(), CustomCheckFailure> { Ok(()) @@ -185,7 +176,6 @@ impl_box_clone!(CustomConst, CustomConstBoxClone); pub struct CustomSerialized { typ: Type, value: serde_json::Value, - extensions: ExtensionSet, } #[derive(Debug, Error)] @@ -206,15 +196,10 @@ pub struct DeserializeError { impl CustomSerialized { /// Creates a new [`CustomSerialized`]. - pub fn new( - typ: impl Into, - value: serde_json::Value, - exts: impl Into, - ) -> Self { + pub fn new(typ: impl Into, value: serde_json::Value) -> Self { Self { typ: typ.into(), value, - extensions: exts.into(), } } @@ -240,7 +225,6 @@ impl CustomSerialized { err, payload: cc.clone_box(), })?, - cc.extension_reqs(), ), }) } @@ -259,10 +243,10 @@ impl CustomSerialized { match cc.downcast::() { Ok(x) => Ok(*x), Err(cc) => { - let (typ, extension_reqs) = (cc.get_type(), cc.extension_reqs()); + let typ = cc.get_type(); let value = serialize_custom_const(cc.as_ref()) .map_err(|err| SerializeError { err, payload: cc })?; - Ok(Self::new(typ, value, extension_reqs)) + Ok(Self::new(typ, value)) } } } @@ -313,9 +297,6 @@ impl CustomConst for CustomSerialized { Some(self) == other.downcast_ref() } - fn extension_reqs(&self) -> ExtensionSet { - self.extensions.clone() - } fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -437,11 +418,8 @@ mod test { // check serialize_custom_const assert_eq!(expected_json, serialize_custom_const(&example.cc).unwrap()); - let expected_custom_serialized = CustomSerialized::new( - example.cc.get_type(), - expected_json, - example.cc.extension_reqs(), - ); + let expected_custom_serialized = + CustomSerialized::new(example.cc.get_type(), expected_json); // check all the try_from/try_into/into variations assert_eq!( @@ -494,11 +472,7 @@ mod test { let inner = example_custom_serialized().1; ( inner.clone(), - CustomSerialized::new( - inner.get_type(), - serialize_custom_const(&inner).unwrap(), - inner.extension_reqs(), - ), + CustomSerialized::new(inner.get_type(), serialize_custom_const(&inner).unwrap()), ) } @@ -545,7 +519,6 @@ mod proptest { use ::proptest::prelude::*; use crate::{ - extension::ExtensionSet, ops::constant::CustomSerialized, proptest::{any_serde_json_value, any_string}, types::Type, @@ -556,7 +529,6 @@ mod proptest { type Strategy = BoxedStrategy; fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy { let typ = any::(); - let extensions = any::(); // here we manually construct a serialized `dyn CustomConst`. // The "c" and "v" come from the `typetag::serde` annotation on // `trait CustomConst`. @@ -570,12 +542,8 @@ mod proptest { .collect::>() .into() }); - (typ, value, extensions) - .prop_map(|(typ, value, extensions)| CustomSerialized { - typ, - value, - extensions, - }) + (typ, value) + .prop_map(|(typ, value)| CustomSerialized { typ, value }) .boxed() } } diff --git a/hugr-core/src/ops/controlflow.rs b/hugr-core/src/ops/controlflow.rs index 49728980f..07c04f5c4 100644 --- a/hugr-core/src/ops/controlflow.rs +++ b/hugr-core/src/ops/controlflow.rs @@ -2,7 +2,6 @@ use std::borrow::Cow; -use crate::extension::ExtensionSet; use crate::types::{EdgeKind, Signature, Type, TypeRow}; use crate::Direction; @@ -20,8 +19,6 @@ pub struct TailLoop { pub just_outputs: TypeRow, /// Types that are appended to both input and output pub rest: TypeRow, - /// Extension requirements to execute the body - pub extension_delta: ExtensionSet, } impl_op_name!(TailLoop); @@ -37,9 +34,7 @@ impl DataflowOpTrait for TailLoop { // TODO: Store a cached signature let [inputs, outputs] = [&self.just_inputs, &self.just_outputs].map(|row| row.extend(self.rest.iter())); - Cow::Owned( - Signature::new(inputs, outputs).with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, outputs)) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -47,7 +42,6 @@ impl DataflowOpTrait for TailLoop { just_inputs: self.just_inputs.substitute(subst), just_outputs: self.just_outputs.substitute(subst), rest: self.rest.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -80,10 +74,10 @@ impl TailLoop { impl DataflowParent for TailLoop { fn inner_signature(&self) -> Cow<'_, Signature> { // TODO: Store a cached signature - Cow::Owned( - Signature::new(self.body_input_row(), self.body_output_row()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.body_input_row(), + self.body_output_row(), + )) } } @@ -97,8 +91,6 @@ pub struct Conditional { pub other_inputs: TypeRow, /// Output types pub outputs: TypeRow, - /// Extensions used to produce the outputs - pub extension_delta: ExtensionSet, } impl_op_name!(Conditional); @@ -115,10 +107,7 @@ impl DataflowOpTrait for Conditional { inputs .to_mut() .insert(0, Type::new_sum(self.sum_rows.clone())); - Cow::Owned( - Signature::new(inputs, self.outputs.clone()) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new(inputs, self.outputs.clone())) } fn substitute(&self, subst: &crate::types::Substitution) -> Self { @@ -126,7 +115,6 @@ impl DataflowOpTrait for Conditional { sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), other_inputs: self.other_inputs.substitute(subst), outputs: self.outputs.substitute(subst), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -174,7 +162,6 @@ pub struct DataflowBlock { pub inputs: TypeRow, pub other_outputs: TypeRow, pub sum_rows: Vec, - pub extension_delta: ExtensionSet, } #[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)] @@ -213,10 +200,10 @@ impl DataflowParent for DataflowBlock { let sum_type = Type::new_sum(self.sum_rows.clone()); let mut node_outputs = vec![sum_type]; node_outputs.extend_from_slice(&self.other_outputs); - Cow::Owned( - Signature::new(self.inputs.clone(), TypeRow::from(node_outputs)) - .with_extension_delta(self.extension_delta.clone()), - ) + Cow::Owned(Signature::new( + self.inputs.clone(), + TypeRow::from(node_outputs), + )) } } @@ -237,10 +224,6 @@ impl OpTrait for DataflowBlock { Some(EdgeKind::ControlFlow) } - fn extension_delta(&self) -> ExtensionSet { - self.extension_delta.clone() - } - fn non_df_port_count(&self, dir: Direction) -> usize { match dir { Direction::Incoming => 1, @@ -253,7 +236,6 @@ impl OpTrait for DataflowBlock { inputs: self.inputs.substitute(subst), other_outputs: self.other_outputs.substitute(subst), sum_rows: self.sum_rows.iter().map(|r| r.substitute(subst)).collect(), - extension_delta: self.extension_delta.substitute(subst), } } } @@ -343,10 +325,6 @@ impl OpTrait for Case { "A case node inside a conditional" } - fn extension_delta(&self) -> ExtensionSet { - self.signature.runtime_reqs.clone() - } - fn tag(&self) -> OpTag { ::TAG } @@ -373,10 +351,7 @@ impl Case { #[cfg(test)] mod test { use crate::{ - extension::{ - prelude::{qb_t, usize_t, PRELUDE_ID}, - ExtensionSet, - }, + extension::prelude::{qb_t, usize_t}, ops::{Conditional, DataflowOpTrait, DataflowParent}, types::{Signature, Substitution, Type, TypeArg, TypeBound, TypeRV}, }; @@ -391,19 +366,12 @@ mod test { inputs: vec![usize_t(), tv0.clone()].into(), other_outputs: vec![tv0.clone()].into(), sum_rows: vec![usize_t().into(), vec![qb_t(), tv0.clone()].into()], - extension_delta: ExtensionSet::type_var(1), }; - let dfb2 = dfb.substitute(&Substitution::new(&[ - qb_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let dfb2 = dfb.substitute(&Substitution::new(&[qb_t().into()])); let st = Type::new_sum(vec![vec![usize_t()], vec![qb_t(); 2]]); assert_eq!( dfb2.inner_signature(), Signature::new(vec![usize_t(), qb_t()], vec![st, qb_t()]) - .with_extension_delta(PRELUDE_ID) ); } @@ -414,7 +382,6 @@ mod test { sum_rows: vec![usize_t().into(), tv1.clone().into()], other_inputs: vec![Type::new_tuple(TypeRV::new_row_var_use(0, TypeBound::Any))].into(), outputs: vec![usize_t(), tv1].into(), - extension_delta: ExtensionSet::new(), }; let cond2 = cond.substitute(&Substitution::new(&[ TypeArg::Sequence { @@ -439,21 +406,14 @@ mod test { just_inputs: vec![qb_t(), tv0.clone()].into(), just_outputs: vec![tv0.clone(), qb_t()].into(), rest: vec![tv0.clone()].into(), - extension_delta: ExtensionSet::type_var(1), }; - let tail2 = tail_loop.substitute(&Substitution::new(&[ - usize_t().into(), - TypeArg::Extensions { - es: PRELUDE_ID.into(), - }, - ])); + let tail2 = tail_loop.substitute(&Substitution::new(&[usize_t().into()])); assert_eq!( tail2.signature(), Signature::new( vec![qb_t(), usize_t(), usize_t()], vec![usize_t(), qb_t(), usize_t()] ) - .with_extension_delta(PRELUDE_ID) ); } } diff --git a/hugr-core/src/ops/custom.rs b/hugr-core/src/ops/custom.rs index 6b907c947..5f5a13427 100644 --- a/hugr-core/src/ops/custom.rs +++ b/hugr-core/src/ops/custom.rs @@ -233,7 +233,6 @@ impl OpaqueOp { args: impl Into>, signature: Signature, ) -> Self { - let signature = signature.with_extension_delta(extension.clone()); Self { extension, name: name.into(), @@ -382,10 +381,7 @@ mod test { assert_eq!(op.name(), "res.op"); assert_eq!(DataflowOpTrait::description(&op), "desc"); assert_eq!(op.args(), &[TypeArg::Type { ty: usize_t() }]); - assert_eq!( - op.signature().as_ref(), - &sig.with_extension_delta(op.extension().clone()) - ); + assert_eq!(op.signature().as_ref(), &sig); } #[test] diff --git a/hugr-core/src/ops/dataflow.rs b/hugr-core/src/ops/dataflow.rs index c63c44b87..ba8f81c0c 100644 --- a/hugr-core/src/ops/dataflow.rs +++ b/hugr-core/src/ops/dataflow.rs @@ -4,7 +4,7 @@ use std::borrow::Cow; use super::{impl_op_name, OpTag, OpTrait}; -use crate::extension::{ExtensionSet, SignatureError}; +use crate::extension::SignatureError; use crate::ops::StaticTag; use crate::types::{EdgeKind, PolyFuncType, Signature, Substitution, Type, TypeArg, TypeRow}; use crate::{type_row, IncomingPort}; @@ -151,15 +151,15 @@ impl OpTrait for T { fn description(&self) -> &str { DataflowOpTrait::description(self) } + fn tag(&self) -> OpTag { T::TAG } + fn dataflow_signature(&self) -> Option> { Some(DataflowOpTrait::signature(self)) } - fn extension_delta(&self) -> ExtensionSet { - DataflowOpTrait::signature(self).runtime_reqs.clone() - } + fn other_input(&self) -> Option { DataflowOpTrait::other_input(self) } diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..a5a3c294a 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,4 +1,5 @@ //! Handles to nodes in HUGR. +use crate::core::HugrNode; use crate::types::{Type, TypeBound}; use crate::Node; @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag}; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. -pub trait NodeHandle: Clone { +pub trait NodeHandle: Clone { /// The most specific operation tag associated with the handle. const TAG: OpTag; /// Index of underlying node. - fn node(&self) -> Node; + fn node(&self) -> N; /// Operation tag for the handle. #[inline] @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone { } /// Cast the handle to a different more general tag. - fn try_cast>(&self) -> Option { + fn try_cast + From>(&self) -> Option { T::TAG.is_superset(Self::TAG).then(|| self.node().into()) } @@ -36,30 +37,30 @@ pub trait NodeHandle: Clone { /// Trait for handles that contain children. /// /// The allowed children handles are defined by the associated type. -pub trait ContainerHandle: NodeHandle { +pub trait ContainerHandle: NodeHandle { /// Handle type for the children of this node. - type ChildrenHandle: NodeHandle; + type ChildrenHandle: NodeHandle; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowOp](crate::ops::dataflow). -pub struct DataflowOpID(Node); +pub struct DataflowOpID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DFG](crate::ops::DFG) node. -pub struct DfgID(Node); +pub struct DfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [CFG](crate::ops::CFG) node. -pub struct CfgID(Node); +pub struct CfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a module [Module](crate::ops::Module) node. -pub struct ModuleRootID(Node); +pub struct ModuleRootID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [module op](crate::ops::module) node. -pub struct ModuleID(Node); +pub struct ModuleID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [def](crate::ops::OpType::FuncDefn) @@ -67,7 +68,7 @@ pub struct ModuleID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct FuncID(Node); +pub struct FuncID(N); #[derive(Debug, Clone, PartialEq, Eq)] /// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn) @@ -75,15 +76,15 @@ pub struct FuncID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct AliasID { - node: Node, +pub struct AliasID { + node: N, name: SmolStr, bound: TypeBound, } -impl AliasID { +impl AliasID { /// Construct new AliasID - pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self { + pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { Self { node, name, bound } } @@ -99,27 +100,27 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node); +pub struct ConstID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. -pub struct BasicBlockID(Node); +pub struct BasicBlockID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Case](crate::ops::Case) node. -pub struct CaseID(Node); +pub struct CaseID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [TailLoop](crate::ops::TailLoop) node. -pub struct TailLoopID(Node); +pub struct TailLoopID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Conditional](crate::ops::Conditional) node. -pub struct ConditionalID(Node); +pub struct ConditionalID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a dataflow container node. -pub struct DataflowParentID(Node); +pub struct DataflowParentID(N); /// Implements the `NodeHandle` trait for a tuple struct that contains just a /// NodeIndex. Takes the name of the struct, and the corresponding OpTag. @@ -131,11 +132,11 @@ macro_rules! impl_nodehandle { impl_nodehandle!($name, $tag, 0); }; ($name:ident, $tag:expr, $node_attr:tt) => { - impl NodeHandle for $name { + impl NodeHandle for $name { const TAG: OpTag = $tag; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.$node_attr } } @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const); impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock); -impl NodeHandle for FuncID { +impl NodeHandle for FuncID { const TAG: OpTag = OpTag::Function; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.0 } } -impl NodeHandle for AliasID { +impl NodeHandle for AliasID { const TAG: OpTag = OpTag::Alias; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.node } } -impl NodeHandle for Node { +impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { *self } } /// Implements the `ContainerHandle` trait, with the given child handle type. macro_rules! impl_containerHandle { - ($name:path, $children:ident) => { - impl ContainerHandle for $name { - type ChildrenHandle = $children; + ($name:ident, $children:ident) => { + impl ContainerHandle for $name { + type ChildrenHandle = $children; } }; } @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID); impl_containerHandle!(ModuleRootID, ModuleID); impl_containerHandle!(CfgID, BasicBlockID); impl_containerHandle!(BasicBlockID, DataflowOpID); -impl_containerHandle!(FuncID, DataflowOpID); -impl_containerHandle!(AliasID, DataflowOpID); +impl ContainerHandle for FuncID { + type ChildrenHandle = DataflowOpID; +} +impl ContainerHandle for AliasID { + type ChildrenHandle = DataflowOpID; +} diff --git a/hugr-core/src/package.rs b/hugr-core/src/package.rs index 1b96f1ebd..a7c48b3a2 100644 --- a/hugr-core/src/package.rs +++ b/hugr-core/src/package.rs @@ -224,9 +224,6 @@ impl Package { // As a fallback, try to load a hugr json. if let Ok(mut hugr) = serde_json::from_value::(val) { hugr.resolve_extension_defs(extension_registry)?; - if cfg!(feature = "extension_inference") { - hugr.infer_extensions(false)?; - } return Ok(Package::from_hugr(hugr)?); } @@ -353,8 +350,7 @@ fn to_module_hugr(mut hugr: Hugr) -> Result { name: "main".to_string(), signature: signature.into_owned().into(), }, - ) - .expect("Hugr accepts any root node"); + ); // Wrap it in a module. let new_root = hugr.add_node(Module::new().into()); diff --git a/hugr-core/src/std_extensions/arithmetic/conversions.rs b/hugr-core/src/std_extensions/arithmetic/conversions.rs index abeb61ab0..ea1004d92 100644 --- a/hugr-core/src/std_extensions/arithmetic/conversions.rs +++ b/hugr-core/src/std_extensions/arithmetic/conversions.rs @@ -8,7 +8,7 @@ use crate::extension::prelude::sum_with_error; use crate::extension::prelude::{bool_t, string_type, usize_t}; use crate::extension::simple_op::{HasConcrete, HasDef}; use crate::extension::simple_op::{MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError}; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc}; use crate::ops::OpName; use crate::ops::{custom::ExtensionOp, NamedOp}; use crate::std_extensions::arithmetic::int_ops::int_polytype; @@ -167,12 +167,6 @@ lazy_static! { /// Extension for conversions between integers and floats. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements( - ExtensionSet::from_iter(vec![ - super::int_types::EXTENSION_ID, - super::float_types::EXTENSION_ID, - ])); - ConvertOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_ops.rs b/hugr-core/src/std_extensions/arithmetic/float_ops.rs index 08b478535..f61353528 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_ops.rs @@ -9,7 +9,7 @@ use crate::{ extension::{ prelude::{bool_t, string_type}, simple_op::{MakeOpDef, MakeRegisteredOp, OpLoadError}, - ExtensionId, ExtensionSet, OpDef, SignatureFunc, + ExtensionId, OpDef, SignatureFunc, }, types::Signature, Extension, @@ -111,7 +111,6 @@ lazy_static! { /// Extension for basic float operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); FloatOps::load_all_ops(extension, extension_ref).unwrap(); }) }; diff --git a/hugr-core/src/std_extensions/arithmetic/float_types.rs b/hugr-core/src/std_extensions/arithmetic/float_types.rs index 3122bf30f..b5a741953 100644 --- a/hugr-core/src/std_extensions/arithmetic/float_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/float_types.rs @@ -5,7 +5,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::{TryHash, ValueName}; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{CustomType, Type, TypeBound}, Extension, @@ -65,7 +65,6 @@ impl std::ops::Deref for ConstF64 { impl ConstF64 { /// Name of the constructor for creating constant 64bit floats. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.float.const_f64"; /// Create a new [`ConstF64`] @@ -98,10 +97,6 @@ impl CustomConst for ConstF64 { fn equal_consts(&self, _: &dyn CustomConst) -> bool { false } - - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } } lazy_static! { diff --git a/hugr-core/src/std_extensions/arithmetic/int_ops.rs b/hugr-core/src/std_extensions/arithmetic/int_ops.rs index d0ae7baa7..69939d4e1 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_ops.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_ops.rs @@ -14,7 +14,7 @@ use crate::types::{FuncValueType, PolyFuncTypeRV, TypeRowRV}; use crate::utils::collect_array; use crate::{ - extension::{ExtensionId, ExtensionSet, SignatureError}, + extension::{ExtensionId, SignatureError}, types::{type_param::TypeArg, Type}, Extension, }; @@ -252,7 +252,6 @@ lazy_static! { /// Extension for basic integer operations. pub static ref EXTENSION: Arc = { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { - extension.add_requirements(ExtensionSet::singleton(super::int_types::EXTENSION_ID)); IntOpDef::load_all_ops(extension, extension_ref).unwrap(); }) }; @@ -377,7 +376,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new(int_type(3), int_type(4)).with_extension_delta(EXTENSION_ID) + &Signature::new(int_type(3), int_type(4)) ); assert_eq!( IntOpDef::iwiden_s @@ -386,7 +385,7 @@ mod test { .unwrap() .signature() .as_ref(), - &Signature::new_endo(int_type(3)).with_extension_delta(EXTENSION_ID) + &Signature::new_endo(int_type(3)) ); assert_eq!( IntOpDef::inarrow_s @@ -396,7 +395,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(3), sum_ty_with_err(int_type(3))) - .with_extension_delta(EXTENSION_ID) ); assert!( IntOpDef::iwiden_u @@ -414,7 +412,6 @@ mod test { .signature() .as_ref(), &Signature::new(int_type(2), sum_ty_with_err(int_type(1))) - .with_extension_delta(EXTENSION_ID) ); assert!(IntOpDef::inarrow_u diff --git a/hugr-core/src/std_extensions/arithmetic/int_types.rs b/hugr-core/src/std_extensions/arithmetic/int_types.rs index e5d625695..022f4d61e 100644 --- a/hugr-core/src/std_extensions/arithmetic/int_types.rs +++ b/hugr-core/src/std_extensions/arithmetic/int_types.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::ops::constant::ValueName; use crate::types::TypeName; use crate::{ - extension::{ExtensionId, ExtensionSet}, + extension::ExtensionId, ops::constant::CustomConst, types::{ type_param::{TypeArg, TypeArgError, TypeParam}, @@ -105,7 +105,6 @@ pub struct ConstInt { impl ConstInt { /// Name of the constructor for creating constant integers. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "arithmetic.int.const"; /// Create a new [`ConstInt`] with a given width and unsigned value @@ -185,10 +184,6 @@ impl CustomConst for ConstInt { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::singleton(EXTENSION_ID) - } - fn get_type(&self) -> Type { int_type(type_arg(self.log_width)) } diff --git a/hugr-core/src/std_extensions/collections/array.rs b/hugr-core/src/std_extensions/collections/array.rs index 0332ff351..2e7ee5b75 100644 --- a/hugr-core/src/std_extensions/collections/array.rs +++ b/hugr-core/src/std_extensions/collections/array.rs @@ -17,7 +17,7 @@ use crate::extension::resolution::{ WeakExtensionRegistry, }; use crate::extension::simple_op::{MakeOpDef, MakeRegisteredOp}; -use crate::extension::{ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound}; +use crate::extension::{ExtensionId, SignatureError, TypeDef, TypeDefBound}; use crate::ops::constant::{maybe_hash_values, CustomConst, TryHash, ValueName}; use crate::ops::{ExtensionOp, OpName, Value}; use crate::types::type_param::{TypeArg, TypeParam}; @@ -45,7 +45,6 @@ pub struct ArrayValue { impl ArrayValue { /// Name of the constructor for creating constant arrays. - #[cfg_attr(not(feature = "model_unstable"), allow(dead_code))] pub(crate) const CTR_NAME: &'static str = "collections.array.const"; /// Create a new [CustomConst] for an array of values of type `typ`. @@ -144,11 +143,6 @@ impl CustomConst for ArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.values.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/array/array_repeat.rs b/hugr-core/src/std_extensions/collections/array/array_repeat.rs index 544866970..a31505cb2 100644 --- a/hugr-core/src/std_extensions/collections/array/array_repeat.rs +++ b/hugr-core/src/std_extensions/collections/array/array_repeat.rs @@ -6,7 +6,7 @@ use std::sync::{Arc, Weak}; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncValueType, PolyFuncTypeRV, Signature, Type, TypeBound}; @@ -42,16 +42,10 @@ impl FromStr for ArrayRepeatDef { impl ArrayRepeatDef { /// To avoid recursion when defining the extension, take the type definition as an argument. fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { - let params = vec![ - TypeParam::max_nat(), - TypeBound::Any.into(), - TypeParam::Extensions, - ]; + let params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t = Type::new_var_use(1, TypeBound::Any); - let es = ExtensionSet::type_var(2); - let func = - Type::new_function(Signature::new(vec![], vec![t.clone()]).with_extension_delta(es)); + let func = Type::new_function(Signature::new(vec![], vec![t.clone()])); let array_ty = instantiate_array(array_def, n, t).expect("Array type instantiation failed"); PolyFuncTypeRV::new(params, FuncValueType::new(vec![func], array_ty)).into() } @@ -109,18 +103,12 @@ pub struct ArrayRepeat { pub elem_ty: Type, /// Size of the array. pub size: u64, - /// The extensions required by the function that generates the array elements. - pub extension_reqs: ExtensionSet, } impl ArrayRepeat { /// Creates a new array repeat op. - pub fn new(elem_ty: Type, size: u64, extension_reqs: ExtensionSet) -> Self { - ArrayRepeat { - elem_ty, - size, - extension_reqs, - } + pub fn new(elem_ty: Type, size: u64) -> Self { + ArrayRepeat { elem_ty, size } } } @@ -143,9 +131,6 @@ impl MakeExtensionOp for ArrayRepeat { vec![ TypeArg::BoundedNat { n: self.size }, self.elem_ty.clone().into(), - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -169,8 +154,8 @@ impl HasConcrete for ArrayRepeatDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty }, TypeArg::Extensions { es }] => { - Ok(ArrayRepeat::new(ty.clone(), *n, es.clone())) + [TypeArg::BoundedNat { n }, TypeArg::Type { ty }] => { + Ok(ArrayRepeat::new(ty.clone(), *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -179,7 +164,7 @@ impl HasConcrete for ArrayRepeatDef { #[cfg(test)] mod tests { - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::qb_t, ops::{OpTrait, OpType}, @@ -190,7 +175,7 @@ mod tests { #[test] fn test_repeat_def() { - let op = ArrayRepeat::new(qb_t(), 2, ExtensionSet::singleton(EXTENSION_ID)); + let op = ArrayRepeat::new(qb_t(), 2); let optype: OpType = op.clone().into(); let new_op: ArrayRepeat = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -200,8 +185,7 @@ mod tests { fn test_repeat() { let size = 2; let element_ty = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayRepeat::new(element_ty.clone(), size, es.clone()); + let op = ArrayRepeat::new(element_ty.clone(), size); let optype: OpType = op.into(); @@ -210,10 +194,7 @@ mod tests { assert_eq!( sig.io(), ( - &vec![Type::new_function( - Signature::new(vec![], vec![qb_t()]).with_extension_delta(es) - )] - .into(), + &vec![Type::new_function(Signature::new(vec![], vec![qb_t()]))].into(), &vec![array_type(size, element_ty.clone())].into(), ) ); diff --git a/hugr-core/src/std_extensions/collections/array/array_scan.rs b/hugr-core/src/std_extensions/collections/array/array_scan.rs index 86a0fe94e..8064a73d0 100644 --- a/hugr-core/src/std_extensions/collections/array/array_scan.rs +++ b/hugr-core/src/std_extensions/collections/array/array_scan.rs @@ -8,7 +8,7 @@ use itertools::Itertools; use crate::extension::simple_op::{ HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }; -use crate::extension::{ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef}; +use crate::extension::{ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef}; use crate::ops::{ExtensionOp, NamedOp, OpName}; use crate::types::type_param::{TypeArg, TypeParam}; use crate::types::{FuncTypeBase, PolyFuncTypeRV, RowVariable, Type, TypeBound, TypeRV}; @@ -51,13 +51,11 @@ impl ArrayScanDef { TypeBound::Any.into(), TypeBound::Any.into(), TypeParam::new_list(TypeBound::Any), - TypeParam::Extensions, ]; let n = TypeArg::new_var_use(0, TypeParam::max_nat()); let t1 = Type::new_var_use(1, TypeBound::Any); let t2 = Type::new_var_use(2, TypeBound::Any); let s = TypeRV::new_row_var_use(3, TypeBound::Any); - let es = ExtensionSet::type_var(4); PolyFuncTypeRV::new( params, FuncTypeBase::::new( @@ -65,13 +63,10 @@ impl ArrayScanDef { instantiate_array(array_def, n.clone(), t1.clone()) .expect("Array type instantiation failed") .into(), - Type::new_function( - FuncTypeBase::::new( - vec![t1.into(), s.clone()], - vec![t2.clone().into(), s.clone()], - ) - .with_extension_delta(es), - ) + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) .into(), s.clone(), ], @@ -145,25 +140,16 @@ pub struct ArrayScan { pub acc_tys: Vec, /// Size of the array. pub size: u64, - /// The extensions required by the scan function. - pub extension_reqs: ExtensionSet, } impl ArrayScan { /// Creates a new array scan op. - pub fn new( - src_ty: Type, - tgt_ty: Type, - acc_tys: Vec, - size: u64, - extension_reqs: ExtensionSet, - ) -> Self { + pub fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { ArrayScan { src_ty, tgt_ty, acc_tys, size, - extension_reqs, } } } @@ -191,9 +177,6 @@ impl MakeExtensionOp for ArrayScan { TypeArg::Sequence { elems: self.acc_tys.clone().into_iter().map_into().collect(), }, - TypeArg::Extensions { - es: self.extension_reqs.clone(), - }, ] } } @@ -217,7 +200,7 @@ impl HasConcrete for ArrayScanDef { fn instantiate(&self, type_args: &[TypeArg]) -> Result { match type_args { - [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }, TypeArg::Extensions { es }] => + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => { let acc_tys: Result<_, OpLoadError> = acc_tys .iter() @@ -226,13 +209,7 @@ impl HasConcrete for ArrayScanDef { _ => Err(SignatureError::InvalidTypeArgs.into()), }) .collect(); - Ok(ArrayScan::new( - src_ty.clone(), - tgt_ty.clone(), - acc_tys?, - *n, - es.clone(), - )) + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) } _ => Err(SignatureError::InvalidTypeArgs.into()), } @@ -243,7 +220,7 @@ impl HasConcrete for ArrayScanDef { mod tests { use crate::extension::prelude::usize_t; - use crate::std_extensions::collections::array::{array_type, EXTENSION_ID}; + use crate::std_extensions::collections::array::array_type; use crate::{ extension::prelude::{bool_t, qb_t}, ops::{OpTrait, OpType}, @@ -254,13 +231,7 @@ mod tests { #[test] fn test_scan_def() { - let op = ArrayScan::new( - bool_t(), - qb_t(), - vec![usize_t()], - 2, - ExtensionSet::singleton(EXTENSION_ID), - ); + let op = ArrayScan::new(bool_t(), qb_t(), vec![usize_t()], 2); let optype: OpType = op.clone().into(); let new_op: ArrayScan = optype.cast().unwrap(); assert_eq!(new_op, op); @@ -271,9 +242,8 @@ mod tests { let size = 2; let src_ty = qb_t(); let tgt_ty = bool_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); - let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size, es.clone()); + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -282,9 +252,7 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new(vec![src_ty], vec![tgt_ty.clone()]).with_extension_delta(es) - ) + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) ] .into(), &vec![array_type(size, tgt_ty)].into(), @@ -299,14 +267,12 @@ mod tests { let tgt_ty = bool_t(); let acc_ty1 = usize_t(); let acc_ty2 = qb_t(); - let es = ExtensionSet::singleton(EXTENSION_ID); let op = ArrayScan::new( src_ty.clone(), tgt_ty.clone(), vec![acc_ty1.clone(), acc_ty2.clone()], size, - es.clone(), ); let optype: OpType = op.into(); let sig = optype.dataflow_signature().unwrap(); @@ -316,13 +282,10 @@ mod tests { ( &vec![ array_type(size, src_ty.clone()), - Type::new_function( - Signature::new( - vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], - vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] - ) - .with_extension_delta(es) - ), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), acc_ty1.clone(), acc_ty2.clone() ] diff --git a/hugr-core/src/std_extensions/collections/array/op_builder.rs b/hugr-core/src/std_extensions/collections/array/op_builder.rs index 46338dd43..623443347 100644 --- a/hugr-core/src/std_extensions/collections/array/op_builder.rs +++ b/hugr-core/src/std_extensions/collections/array/op_builder.rs @@ -213,9 +213,7 @@ impl ArrayOpBuilder for D {} #[cfg(test)] mod test { - use crate::extension::prelude::PRELUDE_ID; - use crate::extension::ExtensionSet; - use crate::std_extensions::collections::array::{self, array_type}; + use crate::std_extensions::collections::array::array_type; use crate::{ builder::{DFGBuilder, HugrBuilder}, extension::prelude::{either_type, option_type, usize_t, ConstUsize, UnwrapBuilder as _}, @@ -229,11 +227,7 @@ mod test { #[rstest::fixture] #[default(DFGBuilder)] fn all_array_ops( - #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW) - .with_extension_delta(ExtensionSet::from_iter([ - PRELUDE_ID, - array::EXTENSION_ID - ]))).unwrap())] + #[default(DFGBuilder::new(Signature::new_endo(Type::EMPTY_TYPEROW)).unwrap())] mut builder: B, ) -> B { let us0 = builder.add_load_value(ConstUsize::new(0)); diff --git a/hugr-core/src/std_extensions/collections/list.rs b/hugr-core/src/std_extensions/collections/list.rs index 98804bab0..3ffb4d9a0 100644 --- a/hugr-core/src/std_extensions/collections/list.rs +++ b/hugr-core/src/std_extensions/collections/list.rs @@ -25,7 +25,7 @@ use crate::types::{TypeName, TypeRowRV}; use crate::{ extension::{ simple_op::{MakeExtensionOp, OpLoadError}, - ExtensionId, ExtensionSet, SignatureError, TypeDef, TypeDefBound, + ExtensionId, SignatureError, TypeDef, TypeDefBound, }, ops::constant::CustomConst, ops::{custom::ExtensionOp, NamedOp}, @@ -126,11 +126,6 @@ impl CustomConst for ListValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.0.iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, diff --git a/hugr-core/src/std_extensions/collections/static_array.rs b/hugr-core/src/std_extensions/collections/static_array.rs index 9d2259e0b..05e5651a1 100644 --- a/hugr-core/src/std_extensions/collections/static_array.rs +++ b/hugr-core/src/std_extensions/collections/static_array.rs @@ -28,7 +28,7 @@ use crate::{ try_from_name, HasConcrete, HasDef, MakeExtensionOp, MakeOpDef, MakeRegisteredOp, OpLoadError, }, - ExtensionId, ExtensionSet, OpDef, SignatureError, SignatureFunc, TypeDef, + ExtensionId, OpDef, SignatureError, SignatureFunc, TypeDef, }, ops::{ constant::{maybe_hash_values, CustomConst, TryHash, ValueName}, @@ -128,11 +128,6 @@ impl CustomConst for StaticArrayValue { crate::ops::constant::downcast_equal_consts(self, other) } - fn extension_reqs(&self) -> ExtensionSet { - ExtensionSet::union_over(self.get_contents().iter().map(Value::extension_reqs)) - .union(EXTENSION_ID.into()) - } - fn update_extensions( &mut self, extensions: &WeakExtensionRegistry, @@ -404,7 +399,7 @@ impl StaticArrayOpBuilder for T {} mod test { use crate::{ builder::{DFGBuilder, DataflowHugr as _}, - extension::prelude::{qb_t, ConstUsize, PRELUDE_ID}, + extension::prelude::{qb_t, ConstUsize}, type_row, }; @@ -419,10 +414,10 @@ mod test { #[test] fn all_ops() { let _ = { - let mut builder = DFGBuilder::new( - Signature::new(type_row![], Type::from(option_type(usize_t()))) - .with_extension_delta(ExtensionSet::from_iter([PRELUDE_ID, EXTENSION_ID])), - ) + let mut builder = DFGBuilder::new(Signature::new( + type_row![], + Type::from(option_type(usize_t())), + )) .unwrap(); let array = builder.add_load_value( StaticArrayValue::try_new( diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..20977cb51 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -124,13 +124,6 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); fn extension() -> Arc { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { LogicOp::load_all_ops(extension, extension_ref).unwrap(); - - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); }) } @@ -172,12 +165,9 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use std::sync::Arc; - use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp}; use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, ops::{NamedOp, Value}, Extension, }; @@ -207,18 +197,6 @@ pub(crate) mod test { } } - #[test] - fn test_values() { - let r: Arc = extension(); - let false_val = r.get_value(&FALSE_NAME).unwrap(); - let true_val = r.get_value(&TRUE_NAME).unwrap(); - - for v in [false_val, true_val] { - let simpl = v.typed_value().get_type(); - assert_eq!(simpl, bool_t()); - } - } - /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] pub(crate) fn and_op() -> LogicOp { LogicOp::And diff --git a/hugr-core/src/std_extensions/ptr.rs b/hugr-core/src/std_extensions/ptr.rs index fc0b1bbb4..6d77ae52d 100644 --- a/hugr-core/src/std_extensions/ptr.rs +++ b/hugr-core/src/std_extensions/ptr.rs @@ -268,10 +268,7 @@ pub(crate) mod test { let in_row = vec![bool_t(), float64_type()]; let hugr = { - let mut builder = DFGBuilder::new( - Signature::new(in_row.clone(), type_row![]).with_extension_delta(EXTENSION_ID), - ) - .unwrap(); + let mut builder = DFGBuilder::new(Signature::new(in_row.clone(), type_row![])).unwrap(); let in_wires: [Wire; 2] = builder.input_wires_arr(); for (ty, w) in in_row.into_iter().zip(in_wires.iter()) { diff --git a/hugr-core/src/types/poly_func.rs b/hugr-core/src/types/poly_func.rs index 67bc7fbf5..885b6bae8 100644 --- a/hugr-core/src/types/poly_func.rs +++ b/hugr-core/src/types/poly_func.rs @@ -277,7 +277,6 @@ pub(crate) mod test { let list_def = list::EXTENSION.get_type(&list::LIST_TYPENAME).unwrap(); let body_type = Signature::new_endo(Type::new_extension(list_def.instantiate([tv])?)); for decl in [ - TypeParam::Extensions, TypeParam::List { param: Box::new(TypeParam::max_nat()), }, diff --git a/hugr-core/src/types/signature.rs b/hugr-core/src/types/signature.rs index 28c39fa08..78965f1b6 100644 --- a/hugr-core/src/types/signature.rs +++ b/hugr-core/src/types/signature.rs @@ -37,8 +37,6 @@ pub struct FuncTypeBase { /// Value outputs of the function. #[cfg_attr(test, proptest(strategy = "any_with::>(params)"))] pub output: TypeRowBase, - /// The extensions the function specifies as required at runtime. - pub runtime_reqs: ExtensionSet, } /// The concept of "signature" in the spec - the edges required to/from a node @@ -55,22 +53,10 @@ pub type Signature = FuncTypeBase; pub type FuncValueType = FuncTypeBase; impl FuncTypeBase { - /// Builder method, add runtime_reqs to a FunctionType - pub fn with_extension_delta(mut self, rs: impl Into) -> Self { - self.runtime_reqs = self.runtime_reqs.union(rs.into()); - self - } - - /// Shorthand for adding the prelude extension to a FunctionType. - pub fn with_prelude(self) -> Self { - self.with_extension_delta(crate::extension::prelude::PRELUDE_ID) - } - pub(crate) fn substitute(&self, tr: &Substitution) -> Self { Self { input: self.input.substitute(tr), output: self.output.substitute(tr), - runtime_reqs: self.runtime_reqs.substitute(tr), } } @@ -79,7 +65,6 @@ impl FuncTypeBase { Self { input: input.into(), output: output.into(), - runtime_reqs: ExtensionSet::new(), } } @@ -117,19 +102,10 @@ impl FuncTypeBase { pub(super) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> { self.input.validate(var_decls)?; - self.output.validate(var_decls)?; - self.runtime_reqs.validate(var_decls) + self.output.validate(var_decls) } /// Returns a registry with the concrete extensions used by this signature. - /// - /// Note that extension type parameters are not included, as they have not - /// been instantiated yet. - /// - /// This method only returns extensions actually used by the types in the - /// signature. The extension deltas added via [`Self::with_extension_delta`] - /// refer to _runtime_ extensions, which may not be in all places that - /// manipulate a HUGR. pub fn used_extensions(&self) -> Result { let mut used = WeakExtensionRegistry::default(); let mut missing = ExtensionSet::new(); @@ -167,7 +143,6 @@ impl Default for FuncTypeBase { Self { input: Default::default(), output: Default::default(), - runtime_reqs: Default::default(), } } } @@ -290,9 +265,6 @@ impl Display for FuncTypeBase { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { self.input.fmt(f)?; f.write_str(" -> ")?; - if !self.runtime_reqs.is_empty() { - self.runtime_reqs.fmt(f)?; - } self.output.fmt(f) } } @@ -303,7 +275,7 @@ impl TryFrom for Signature { fn try_from(value: FuncValueType) -> Result { let input: TypeRow = value.input.try_into()?; let output: TypeRow = value.output.try_into()?; - Ok(Self::new(input, output).with_extension_delta(value.runtime_reqs)) + Ok(Self::new(input, output)) } } @@ -312,16 +284,13 @@ impl From for FuncValueType { Self { input: value.input.into(), output: value.output.into(), - runtime_reqs: value.runtime_reqs, } } } impl PartialEq> for FuncTypeBase { fn eq(&self, other: &FuncTypeBase) -> bool { - self.input == other.input - && self.output == other.output - && self.runtime_reqs == other.runtime_reqs + self.input == other.input && self.output == other.output } } diff --git a/hugr-core/src/types/type_param.rs b/hugr-core/src/types/type_param.rs index db2efecc6..e8fa28346 100644 --- a/hugr-core/src/types/type_param.rs +++ b/hugr-core/src/types/type_param.rs @@ -15,7 +15,6 @@ use super::{ check_typevar_decl, NoRV, RowVariable, Substitution, Transformable, Type, TypeBase, TypeBound, TypeTransformer, }; -use crate::extension::ExtensionSet; use crate::extension::SignatureError; /// The upper non-inclusive bound of a [`TypeParam::BoundedNat`] @@ -92,10 +91,6 @@ pub enum TypeParam { /// The [TypeParam]s contained in the tuple. params: Vec, }, - /// Argument is a [TypeArg::Extensions]. A set of [ExtensionId]s. - /// - /// [ExtensionId]: crate::extension::ExtensionId - Extensions, } impl TypeParam { @@ -131,7 +126,6 @@ impl TypeParam { (TypeParam::Tuple { params: es1 }, TypeParam::Tuple { params: es2 }) => { es1.len() == es2.len() && es1.iter().zip(es2).all(|(e1, e2)| e1.contains(e2)) } - (TypeParam::Extensions, TypeParam::Extensions) => true, _ => false, } } @@ -184,18 +178,9 @@ pub enum TypeArg { /// List of element types elems: Vec, }, - /// Instance of [TypeParam::Extensions], providing the extension ids. - #[display("Exts({})", { - use itertools::Itertools as _; - es.iter().map(|t|t.to_string()).join(",") - })] - Extensions { - #[allow(missing_docs)] - es: ExtensionSet, - }, /// Variable (used in type schemes or inside polymorphic functions), /// but not a [TypeArg::Type] (not even a row variable i.e. [TypeParam::List] of type) - /// nor [TypeArg::Extensions] - see [TypeArg::new_var_use] + /// - see [TypeArg::new_var_use] #[display("{v}")] Variable { #[allow(missing_docs)] @@ -239,14 +224,7 @@ impl From> for TypeArg { } } -impl From for TypeArg { - fn from(es: ExtensionSet) -> Self { - Self::Extensions { es } - } -} - -/// Variable in a TypeArg, that is neither a [TypeArg::Extensions] -/// nor a single [TypeArg::Type] (i.e. not a [Type::new_var_use] +/// Variable in a TypeArg, that is not a single [TypeArg::Type] (i.e. not a [Type::new_var_use] /// - it might be a [Type::new_row_var_use]). #[derive( Clone, Debug, PartialEq, Eq, Hash, serde::Deserialize, serde::Serialize, derive_more::Display, @@ -270,10 +248,6 @@ impl TypeArg { // as a TypeArg::Type because the latter stores a Type i.e. only a single type, // not a RowVariable. TypeParam::Type { b } => Type::new_var_use(idx, b).into(), - // Prevent TypeArg::Variable(idx, TypeParam::Extensions) - TypeParam::Extensions => TypeArg::Extensions { - es: ExtensionSet::type_var(idx), - }, _ => TypeArg::Variable { v: TypeArgVariable { idx, @@ -314,7 +288,6 @@ impl TypeArg { TypeArg::Type { ty } => ty.validate(var_decls), TypeArg::BoundedNat { .. } | TypeArg::String { .. } => Ok(()), TypeArg::Sequence { elems } => elems.iter().try_for_each(|a| a.validate(var_decls)), - TypeArg::Extensions { es: _ } => Ok(()), TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => { @@ -362,9 +335,6 @@ impl TypeArg { }; TypeArg::Sequence { elems } } - TypeArg::Extensions { es } => TypeArg::Extensions { - es: es.substitute(t), - }, TypeArg::Variable { v: TypeArgVariable { idx, cached_decl }, } => t.apply_var(*idx, cached_decl), @@ -377,10 +347,9 @@ impl Transformable for TypeArg { match self { TypeArg::Type { ty } => ty.transform(tr), TypeArg::Sequence { elems } => elems.transform(tr), - TypeArg::BoundedNat { .. } - | TypeArg::String { .. } - | TypeArg::Extensions { .. } - | TypeArg::Variable { .. } => Ok(false), + TypeArg::BoundedNat { .. } | TypeArg::String { .. } | TypeArg::Variable { .. } => { + Ok(false) + } } } } @@ -449,7 +418,6 @@ pub fn check_type_arg(arg: &TypeArg, param: &TypeParam) -> Result<(), TypeArgErr } (TypeArg::String { .. }, TypeParam::String) => Ok(()), - (TypeArg::Extensions { .. }, TypeParam::Extensions) => Ok(()), _ => Err(TypeArgError::TypeMismatch { arg: arg.clone(), param: param.clone(), @@ -659,7 +627,6 @@ mod test { use proptest::prelude::*; use super::super::{TypeArg, TypeArgVariable, TypeParam, UpperBound}; - use crate::extension::ExtensionSet; use crate::proptest::RecursionDepth; use crate::types::{Type, TypeBound}; @@ -680,7 +647,6 @@ mod test { use prop::collection::vec; use prop::strategy::Union; let mut strat = Union::new([ - Just(Self::Extensions).boxed(), Just(Self::String).boxed(), any::().prop_map(|b| Self::Type { b }).boxed(), any::() @@ -711,9 +677,6 @@ mod test { let mut strat = Union::new([ any::().prop_map(|n| Self::BoundedNat { n }).boxed(), any::().prop_map(|arg| Self::String { arg }).boxed(), - any::() - .prop_map(|es| Self::Extensions { es }) - .boxed(), any_with::(depth) .prop_map(|ty| Self::Type { ty }) .boxed(), diff --git a/hugr-llvm/Cargo.toml b/hugr-llvm/Cargo.toml index 677a82a31..bdfc63f5a 100644 --- a/hugr-llvm/Cargo.toml +++ b/hugr-llvm/Cargo.toml @@ -23,7 +23,7 @@ llvm14-0 = ["inkwell/llvm14-0"] [dependencies] -inkwell = { version = "0.5.0", default-features = false } +inkwell = { version = "0.6.0", default-features = false } hugr-core = { path = "../hugr-core", version = "0.15.3" } anyhow = "1.0.98" itertools.workspace = true diff --git a/hugr-llvm/README.md b/hugr-llvm/README.md index 5fd2d3239..988a650dd 100644 --- a/hugr-llvm/README.md +++ b/hugr-llvm/README.md @@ -25,14 +25,14 @@ version will only change on major releases. ## Developing hugr-llvm -See [DEVELOPMENT](DEVELOPMENT.md) for instructions on setting up the development environment. +See [DEVELOPMENT](../DEVELOPMENT.md) for instructions on setting up the development environment. ## License This project is licensed under Apache License, Version 2.0 ([LICENCE](LICENCE) or ). [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-llvm [hugr]: https://lib.rs/crates/hugr [inkwell]: https://thedan64.github.io/inkwell/inkwell/index.html [llvm-sys]: https://crates.io/crates/llvm-sys diff --git a/hugr-llvm/src/emit/ops.rs b/hugr-llvm/src/emit/ops.rs index 9cb6f9b10..76bb2bb09 100644 --- a/hugr-llvm/src/emit/ops.rs +++ b/hugr-llvm/src/emit/ops.rs @@ -5,7 +5,6 @@ use hugr_core::ops::{ }; use hugr_core::Node; use hugr_core::{ - hugr::views::SiblingGraph, types::{SumType, Type, TypeEnum}, HugrView, NodeIndex, }; @@ -71,34 +70,33 @@ where debug_assert!(i.out_value_types().count() == self.inputs.as_ref().unwrap().len()); debug_assert!(o.in_value_types().count() == self.outputs.as_ref().unwrap().len()); - let region: SiblingGraph = node.try_new_hierarchy_view().unwrap(); - Topo::new(®ion.as_petgraph()) - .iter(®ion.as_petgraph()) - .filter(|x| (*x != node.node())) - .map(|x| node.hugr().fat_optype(x)) - .try_for_each(|node| { - let inputs_rmb = context.node_ins_rmb(node)?; - let inputs = inputs_rmb.read(context.builder(), [])?; - let outputs = context.node_outs_rmb(node)?.promise(); - match node.as_ref() { - OpType::Input(_) => { - let i = self.take_input()?; - outputs.finish(context.builder(), i) - } - OpType::Output(_) => { - let o = self.take_output()?; - o.finish(context.builder(), inputs) - } - _ => emit_optype( - context, - EmitOpArgs { - node, - inputs, - outputs, - }, - ), + let region_graph = node.hugr().region_portgraph(node.node()); + let topo = Topo::new(®ion_graph); + for n in topo.iter(®ion_graph) { + let node = node.hugr().fat_optype(node.hugr().from_portgraph_node(n)); + let inputs_rmb = context.node_ins_rmb(node)?; + let inputs = inputs_rmb.read(context.builder(), [])?; + let outputs = context.node_outs_rmb(node)?.promise(); + match node.as_ref() { + OpType::Input(_) => { + let i = self.take_input()?; + outputs.finish(context.builder(), i)?; } - }) + OpType::Output(_) => { + let o = self.take_output()?; + o.finish(context.builder(), inputs)?; + } + _ => emit_optype( + context, + EmitOpArgs { + node, + inputs, + outputs, + }, + )?, + } + } + Ok(()) } } diff --git a/hugr-llvm/src/emit/ops/cfg.rs b/hugr-llvm/src/emit/ops/cfg.rs index 4d62350be..12f22d2f7 100644 --- a/hugr-llvm/src/emit/ops/cfg.rs +++ b/hugr-llvm/src/emit/ops/cfg.rs @@ -219,7 +219,7 @@ impl<'c, 'hugr, H: HugrView> CfgEmitter<'c, 'hugr, H> { mod test { use hugr_core::builder::{Dataflow, DataflowSubContainer, SubContainer}; use hugr_core::extension::prelude::{self, bool_t}; - use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; + use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::Value; use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; use hugr_core::type_row; @@ -239,7 +239,6 @@ mod test { llvm_ctx.add_extensions(CodegenExtsBuilder::add_default_int_extensions); let t1 = INT_TYPES[0].clone(); let t2 = INT_TYPES[1].clone(); - let es = ExtensionSet::from_iter([int_types::EXTENSION_ID, prelude::PRELUDE_ID]); let hugr = SimpleHugrConfig::new() .with_ins(vec![t1.clone(), t2.clone()]) .with_outs(t2.clone()) @@ -250,11 +249,7 @@ mod test { .finish(|mut builder| { let [in1, in2] = builder.input_wires_arr(); let mut cfg_builder = builder - .cfg_builder_exts( - [(t1.clone(), in1), (t2.clone(), in2)], - t2.clone().into(), - es.clone(), - ) + .cfg_builder([(t1.clone(), in1), (t2.clone(), in2)], t2.clone().into()) .unwrap(); // entry block takes (t1,t2) and unconditionally branches to b1 with no other outputs diff --git a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap index b3283ee1b..124f36b53 100644 --- a/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap +++ b/hugr-llvm/src/emit/snapshots/hugr_llvm__emit__test__test_fns__emit_hugr_call_indirect@pre-mem2reg@llvm14.snap @@ -20,14 +20,14 @@ entry_block: ; preds = %alloca_block define i1 @_hl.main_unary.6(i1 %0) { alloca_block: %"0" = alloca i1, align 1 - %"7_0" = alloca i1, align 1 %"9_0" = alloca i1 (i1)*, align 8 + %"7_0" = alloca i1, align 1 %"10_0" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block - store i1 %0, i1* %"7_0", align 1 store i1 (i1)* @_hl.main_unary.6, i1 (i1)** %"9_0", align 8 + store i1 %0, i1* %"7_0", align 1 %"9_01" = load i1 (i1)*, i1 (i1)** %"9_0", align 8 %"7_02" = load i1, i1* %"7_0", align 1 %1 = call i1 %"9_01"(i1 %"7_02") @@ -42,17 +42,17 @@ define { i1, i1 } @_hl.main_binary.11(i1 %0, i1 %1) { alloca_block: %"0" = alloca i1, align 1 %"1" = alloca i1, align 1 + %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"12_0" = alloca i1, align 1 %"12_1" = alloca i1, align 1 - %"14_0" = alloca { i1, i1 } (i1, i1)*, align 8 %"15_0" = alloca i1, align 1 %"15_1" = alloca i1, align 1 br label %entry_block entry_block: ; preds = %alloca_block + store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 store i1 %0, i1* %"12_0", align 1 store i1 %1, i1* %"12_1", align 1 - store { i1, i1 } (i1, i1)* @_hl.main_binary.11, { i1, i1 } (i1, i1)** %"14_0", align 8 %"14_01" = load { i1, i1 } (i1, i1)*, { i1, i1 } (i1, i1)** %"14_0", align 8 %"12_02" = load i1, i1* %"12_0", align 1 %"12_13" = load i1, i1* %"12_1", align 1 diff --git a/hugr-llvm/src/emit/test.rs b/hugr-llvm/src/emit/test.rs index 3f6977a8c..d53d2ef0c 100644 --- a/hugr-llvm/src/emit/test.rs +++ b/hugr-llvm/src/emit/test.rs @@ -4,13 +4,8 @@ use anyhow::{anyhow, Result}; use hugr_core::builder::{ BuildHandle, Container, DFGWrapper, HugrBuilder, ModuleBuilder, SubContainer, }; -use hugr_core::extension::prelude::PRELUDE_ID; -use hugr_core::extension::{ExtensionRegistry, ExtensionSet}; +use hugr_core::extension::ExtensionRegistry; use hugr_core::ops::handle::FuncID; -use hugr_core::std_extensions::arithmetic::{ - conversions, float_ops, float_types, int_ops, int_types, -}; -use hugr_core::std_extensions::{collections, logic}; use hugr_core::types::TypeRow; use hugr_core::{Hugr, HugrView, Node}; use inkwell::module::Module; @@ -150,23 +145,7 @@ impl SimpleHugrConfig { ) -> Hugr { let mut mod_b = ModuleBuilder::new(); let func_b = mod_b - .define_function( - "main", - HugrFuncType::new(self.ins, self.outs).with_extension_delta( - ExtensionSet::from_iter([ - PRELUDE_ID, - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - float_types::EXTENSION_ID, - float_ops::EXTENSION_ID, - conversions::EXTENSION_ID, - logic::EXTENSION_ID, - collections::array::EXTENSION_ID, - collections::list::EXTENSION_ID, - collections::static_array::EXTENSION_ID, - ]), - ), - ) + .define_function("main", HugrFuncType::new(self.ins, self.outs)) .unwrap(); make(func_b, &self.extensions); @@ -265,7 +244,7 @@ mod test_fns { use hugr_core::ops::{CallIndirect, Tag, Value}; use hugr_core::std_extensions::arithmetic::int_ops::{self}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; + use hugr_core::std_extensions::arithmetic::int_types::{self, ConstInt}; use hugr_core::std_extensions::STD_REG; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{type_row, Hugr}; diff --git a/hugr-llvm/src/extension/collections/array.rs b/hugr-llvm/src/extension/collections/array.rs index 55dcecefc..0216e9014 100644 --- a/hugr-llvm/src/extension/collections/array.rs +++ b/hugr-llvm/src/extension/collections/array.rs @@ -708,7 +708,6 @@ pub fn emit_scan_op<'c, H: HugrView>( mod test { use hugr_core::builder::Container as _; use hugr_core::extension::prelude::either_type; - use hugr_core::extension::ExtensionSet; use hugr_core::ops::Tag; use hugr_core::std_extensions::collections::array::{self, array_type, ArrayRepeat, ArrayScan}; use hugr_core::std_extensions::STD_REG; @@ -854,16 +853,6 @@ mod test { ]) } - fn exec_extension_set() -> ExtensionSet { - ExtensionSet::from_iter([ - int_types::EXTENSION_ID, - int_ops::EXTENSION_ID, - logic::EXTENSION_ID, - prelude::PRELUDE_ID, - array::EXTENSION_ID, - ]) - } - #[rstest] #[case(0, 1)] #[case(1, 2)] @@ -1223,16 +1212,12 @@ mod test { .with_extensions(exec_registry()) .finish(|mut builder| { let mut func = builder - .define_function( - "foo", - Signature::new(vec![], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), - ) + .define_function("foo", Signature::new(vec![], vec![int_ty.clone()])) .unwrap(); let v = func.add_load_value(ConstInt::new_u(6, value).unwrap()); let func_id = func.finish_with_outputs(vec![v]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let repeat = ArrayRepeat::new(int_ty.clone(), size, exec_extension_set()); + let repeat = ArrayRepeat::new(int_ty.clone(), size); let arr = builder .add_dataflow_op(repeat, vec![func_v]) .unwrap() @@ -1280,8 +1265,7 @@ mod test { let mut func = builder .define_function( "foo", - Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]) - .with_extension_delta(exec_extension_set()), + Signature::new(vec![int_ty.clone()], vec![int_ty.clone()]), ) .unwrap(); let [elem] = func.input_wires_arr(); @@ -1289,13 +1273,7 @@ mod test { let out = func.add_iadd(6, elem, delta).unwrap(); let func_id = func.finish_with_outputs(vec![out]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - int_ty.clone(), - vec![], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), int_ty.clone(), vec![], size); let mut arr = builder .add_dataflow_op(scan, [arr, func_v]) .unwrap() @@ -1357,8 +1335,7 @@ mod test { Signature::new( vec![int_ty.clone(), int_ty.clone()], vec![Type::UNIT, int_ty.clone()], - ) - .with_extension_delta(exec_extension_set()), + ), ) .unwrap(); let [elem, acc] = func.input_wires_arr(); @@ -1369,13 +1346,7 @@ mod test { .out_wire(0); let func_id = func.finish_with_outputs(vec![unit, acc]).unwrap(); let func_v = builder.load_func(func_id.handle(), &[]).unwrap(); - let scan = ArrayScan::new( - int_ty.clone(), - Type::UNIT, - vec![int_ty.clone()], - size, - exec_extension_set(), - ); + let scan = ArrayScan::new(int_ty.clone(), Type::UNIT, vec![int_ty.clone()], size); let zero = builder.add_load_value(ConstInt::new_u(6, 0).unwrap()); let sum = builder .add_dataflow_op(scan, [arr, func_v, zero]) diff --git a/hugr-llvm/src/extension/collections/static_array.rs b/hugr-llvm/src/extension/collections/static_array.rs index 7d3ac5f5c..7f59bff82 100644 --- a/hugr-llvm/src/extension/collections/static_array.rs +++ b/hugr-llvm/src/extension/collections/static_array.rs @@ -58,6 +58,7 @@ fn value_is_const<'c>(value: impl BasicValue<'c>) -> bool { BasicValueEnum::PointerValue(v) => v.is_const(), BasicValueEnum::StructValue(v) => v.is_const(), BasicValueEnum::VectorValue(v) => v.is_const(), + BasicValueEnum::ScalableVectorValue(v) => v.is_const(), } } @@ -109,6 +110,13 @@ fn const_array<'c>( .collect_vec() .as_slice(), ), + BasicTypeEnum::ScalableVectorType(t) => t.const_array( + values + .into_iter() + .map(|x| x.as_basic_value_enum().into_scalable_vector_value()) + .collect_vec() + .as_slice(), + ), } } diff --git a/hugr-llvm/src/sum.rs b/hugr-llvm/src/sum.rs index c2b9a0475..381e09469 100644 --- a/hugr-llvm/src/sum.rs +++ b/hugr-llvm/src/sum.rs @@ -47,6 +47,7 @@ fn basic_type_undef<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_undef().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_undef().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_undef().as_basic_value_enum(), } } @@ -60,6 +61,7 @@ fn basic_type_poison<'c>(t: impl BasicType<'c>) -> BasicValueEnum<'c> { BasicTypeEnum::PointerType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::StructType(t) => t.get_poison().as_basic_value_enum(), BasicTypeEnum::VectorType(t) => t.get_poison().as_basic_value_enum(), + BasicTypeEnum::ScalableVectorType(t) => t.get_poison().as_basic_value_enum(), } } diff --git a/hugr-llvm/src/sum/layout.rs b/hugr-llvm/src/sum/layout.rs index fd67a3240..d016de851 100644 --- a/hugr-llvm/src/sum/layout.rs +++ b/hugr-llvm/src/sum/layout.rs @@ -45,6 +45,9 @@ fn size_of_type<'c>(t: impl BasicType<'c>) -> Option { BasicTypeEnum::PointerType(t) => t.size_of().get_zero_extended_constant(), BasicTypeEnum::StructType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), BasicTypeEnum::VectorType(t) => t.size_of().and_then(|x| x.get_zero_extended_constant()), + BasicTypeEnum::ScalableVectorType(t) => { + t.size_of().and_then(|x| x.get_zero_extended_constant()) + } } } diff --git a/hugr-llvm/src/utils/fat.rs b/hugr-llvm/src/utils/fat.rs index dec866b4e..5deeb4bf0 100644 --- a/hugr-llvm/src/utils/fat.rs +++ b/hugr-llvm/src/utils/fat.rs @@ -47,7 +47,7 @@ where /// Note that while we do check the type of the node's `get_optype`, we /// do not verify that it is actually equal to `ot`. pub fn new(hugr: &'hugr H, node: H::Node, #[allow(unused)] ot: &OT) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); assert!(TryInto::<&OT>::try_into(hugr.get_optype(node)).is_ok()); // We don't actually check `ot == hugr.get_optype(node)` so as to not require OT: PartialEq` Self { @@ -63,7 +63,7 @@ where /// If the node is invalid, or if its `get_optype` is not `OT`, returns /// `None`. pub fn try_new(hugr: &'hugr H, node: H::Node) -> Option { - (hugr.valid_node(node)).then_some(())?; + (hugr.contains_node(node)).then_some(())?; Some(Self::new( hugr, node, @@ -99,7 +99,7 @@ impl<'hugr, H: HugrView + ?Sized> FatNode<'hugr, OpType, H, H::Node> { /// /// Panics if the node is not valid in the [Hugr]. pub fn new_optype(hugr: &'hugr H, node: H::Node) -> Self { - assert!(hugr.valid_node(node)); + assert!(hugr.contains_node(node)); FatNode::new(hugr, node, hugr.get_optype(node)) } diff --git a/hugr-llvm/src/utils/inline_constant_functions.rs b/hugr-llvm/src/utils/inline_constant_functions.rs index 28e664b97..1b0931bd2 100644 --- a/hugr-llvm/src/utils/inline_constant_functions.rs +++ b/hugr-llvm/src/utils/inline_constant_functions.rs @@ -11,12 +11,12 @@ fn const_fn_name(konst_n: Node) -> String { format!("const_fun_{}", konst_n.index()) } -pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { +pub fn inline_constant_functions(hugr: &mut impl HugrMut) -> Result<()> { while inline_constant_functions_impl(hugr)? {} Ok(()) } -fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { +fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { let mut const_funs = vec![]; for n in hugr.nodes() { @@ -69,7 +69,7 @@ fn inline_constant_functions_impl(hugr: &mut impl HugrMut) -> Result { hugr.insert_hugr(func_node, func_hugr); for lcn in load_constant_ns { - hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?)?; + hugr.replace_op(lcn, LoadFunction::try_new(polysignature.clone(), [])?); } any_changes = true; } diff --git a/hugr-model/README.md b/hugr-model/README.md index 0ea6fdf8f..be93253eb 100644 --- a/hugr-model/README.md +++ b/hugr-model/README.md @@ -30,7 +30,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-model/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-model [crates]: https://img.shields.io/crates/v/hugr-core [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..c9be8896b 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -362,6 +362,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ResolveError { /// Unknown variable. #[error("unknown var: {0}")] diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..55a4b9889 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -456,6 +456,7 @@ pub struct VarId(pub NodeId, pub VarIndex); /// Errors that can occur when traversing and interpreting the model. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ModelError { /// There is a reference to a node that does not exist. #[error("node not found: {0}")] diff --git a/hugr-passes/Cargo.toml b/hugr-passes/Cargo.toml index 7a2c50367..e4e4f9087 100644 --- a/hugr-passes/Cargo.toml +++ b/hugr-passes/Cargo.toml @@ -26,9 +26,6 @@ paste = { workspace = true } thiserror = { workspace = true } petgraph = { workspace = true } -[features] -extension_inference = ["hugr-core/extension_inference"] - [dev-dependencies] rstest = { workspace = true } proptest = { workspace = true } diff --git a/hugr-passes/README.md b/hugr-passes/README.md index b9552fe75..c2bca2124 100644 --- a/hugr-passes/README.md +++ b/hugr-passes/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr-passes -=============== +# hugr-passes [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr-passes) @@ -29,13 +28,6 @@ cargo add hugr-passes Please read the [API documentation here][]. -## Experimental Features - -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - ## Recent Changes See [CHANGELOG][] for a list of changes. The minimum supported rust @@ -51,8 +43,8 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr-passes/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr-passes [crates]: https://img.shields.io/crates/v/hugr-passes [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE - [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md + [CHANGELOG]: https://github.com/CQCL/hugr/blob/main/hugr-passes/CHANGELOG.md \ No newline at end of file diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..faf92b8a7 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,359 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::core::HugrNode; +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Node: HugrNode; + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. + fn then, E: ErrorCombiner>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Node = P1::Node; + type Error = E; + type Result = (P1::Result, P2::Result); + + fn run( + &self, + hugr: &mut impl HugrMut, + ) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Node = P::Node; + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P); + +impl ValidatingPass

{ + pub fn new(underlying: P) -> Self { + Self(underlying) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + hugr.validate() + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Node = P::Node; + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl< + A: ComposablePass, + B: ComposablePass, + E: ErrorCombiner, + > ComposablePass for IfThen +{ + type Node = A::Node; + type Error = E; + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple}; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..b406ae894 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -25,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -38,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -53,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -90,9 +76,20 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Node = Node; + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -102,7 +99,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -112,7 +109,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -131,7 +128,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) @@ -168,23 +165,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -192,7 +176,7 @@ impl ConstantFoldPass { /// /// [FuncDefn]: hugr_core::ops::OpType::FuncDefn /// [Module]: hugr_core::ops::OpType::Module -pub fn constant_fold_pass(h: &mut H) { +pub fn constant_fold_pass>(h: &mut H) { let c = ConstantFoldPass::default(); let c = if h.get_optype(h.root()).is_module() { let no_inputs: [(IncomingPort, _); 0] = []; @@ -202,63 +186,38 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); - -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..dcdc4df0a 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -3,7 +3,6 @@ use std::collections::HashSet; use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::Const; -use hugr_core::std_extensions::arithmetic::{int_ops, int_types}; use itertools::Itertools; use lazy_static::lazy_static; use rstest::rstest; @@ -32,6 +31,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); @@ -161,7 +159,7 @@ fn test_big() { .unwrap(); let mut h = build.finish_hugr_with_outputs(to_int.outputs()).unwrap(); - assert_eq!(h.node_count(), 8); + assert_eq!(h.num_nodes(), 8); constant_fold_pass(&mut h); @@ -334,7 +332,7 @@ fn test_const_fold_to_nonfinite() { assert_fully_folded_with(&h0, |v| { v.get_custom_value::().unwrap().value() == 1.0 }); - assert_eq!(h0.node_count(), 5); + assert_eq!(h0.num_nodes(), 5); // HUGR computing 1.0 / 0.0 let mut build = DFGBuilder::new(noargfn(vec![float64_type()])).unwrap(); @@ -343,7 +341,7 @@ fn test_const_fold_to_nonfinite() { let x2 = build.add_dataflow_op(FloatOps::fdiv, [x0, x1]).unwrap(); let mut h1 = build.finish_hugr_with_outputs(x2.outputs()).unwrap(); constant_fold_pass(&mut h1); - assert_eq!(h1.node_count(), 8); + assert_eq!(h1.num_nodes(), 8); } #[test] @@ -1363,7 +1361,7 @@ fn test_tail_loop_unknown() { constant_fold_pass(&mut h); // Must keep the loop, even though we know the output, in case the output doesn't happen - assert_eq!(h.node_count(), 12); + assert_eq!(h.num_nodes(), 12); let tl = h .nodes() .filter(|n| h.get_optype(*n).is_tail_loop()) @@ -1596,9 +1594,7 @@ fn test_module() -> Result<(), Box> { let ad2 = mb.add_alias_def("unused2", INT_TYPES[3].clone())?; let mut main = mb.define_function( "main", - Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]) - .with_extension_delta(int_types::EXTENSION_ID) - .with_extension_delta(int_ops::EXTENSION_ID), + Signature::new(type_row![], vec![INT_TYPES[5].clone(); 2]), )?; let lc7 = main.load_const(&c7); let lc17 = main.load_const(&c17); diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +121,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). -pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,19 +3,22 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; + +type NodeInputs = Vec<(IncomingPort, PV)>; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -25,10 +28,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -135,10 +135,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -155,9 +157,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -322,6 +324,37 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + let tgt = load_func(v); + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + input_child(func, inp), + in_wire_value(call, p, v) + if p.index() > 0; + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + output_child(func, outp), + in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value @@ -337,13 +370,58 @@ pub(super) fn run_datalog( } } +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. @@ -362,8 +440,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -380,10 +457,10 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { @@ -401,6 +478,54 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive + } +} + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..f7b8a171c 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +51,25 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded + pub func_node: N, + /// The type arguments provided when loading + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +85,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +163,33 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,11 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - { + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( + self, + typ: &Type, + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -185,22 +228,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +245,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +255,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +292,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +311,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -312,33 +352,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,47 +414,27 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> - where - V: TryInto, - Sum: TryInto, - { + pub fn try_into_concrete>( + self, + typ: &Type, + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? - .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +446,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also join the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +473,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also meet the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +500,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +511,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -509,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -537,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -605,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,17 +1,19 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +23,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,13 +86,11 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + #[allow(clippy::type_complexity)] + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom + TryFrom, Error = SE>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr @@ -116,7 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..a67556ce1 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,16 +1,15 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; -use hugr_core::types::TypeRow; +use hugr_core::ops::{CallIndirect, TailLoop}; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, - extension::{ - prelude::{bool_t, UnpackTuple}, - ExtensionSet, - }, + extension::prelude::{bool_t, UnpackTuple}, ops::{handle::NodeHandle, DataflowOpTrait, Tag, Value}, type_row, types::{Signature, SumType, Type}, @@ -19,7 +18,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +37,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -159,12 +173,7 @@ fn test_tail_loop_two_iters() { let false_w = builder.add_load_value(Value::false_val()); let tlb = builder - .tail_loop_builder_exts( - [], - [(bool_t(), false_w), (bool_t(), true_w)], - type_row![], - ExtensionSet::new(), - ) + .tail_loop_builder([], [(bool_t(), false_w), (bool_t(), true_w)], type_row![]) .unwrap(); assert_eq!( tlb.loop_signature().unwrap().signature().as_ref(), @@ -295,9 +304,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); + assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); @@ -547,3 +554,78 @@ fn test_module() { ); } } + +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + let run = |which| { + Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], + ) + }; + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + + // 2. Test with `which` selecting second function -> both passthrough + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); + + //3. Test with `which` selecting first function -> alias + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..69bcfabf6 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -173,6 +145,7 @@ impl DeadCodeElimPass { if let Some(res) = cache.get(&n) { return *res; } + #[allow(deprecated)] let res = match self.preserve_callback.as_ref()(h.base_hugr(), n) { PreserveNode::MustKeep => true, PreserveNode::CanRemoveIgnoringChildren => false, @@ -185,24 +158,41 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Node = Node; + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; use hugr_core::builder::{CFGBuilder, Container, Dataflow, DataflowSubContainer, HugrBuilder}; - use hugr_core::extension::prelude::{usize_t, ConstUsize, PRELUDE_ID}; + use hugr_core::extension::prelude::{usize_t, ConstUsize}; use hugr_core::ops::{handle::NodeHandle, OpTag, OpTrait}; use hugr_core::types::Signature; use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] fn test_cfg_callback() { - let mut cb = - CFGBuilder::new(Signature::new_endo(type_row![]).with_extension_delta(PRELUDE_ID)) - .unwrap(); + let mut cb = CFGBuilder::new(Signature::new_endo(type_row![])).unwrap(); let cst_unused = cb.add_constant(Value::from(ConstUsize::new(3))); let cst_used_in_dfg = cb.add_constant(Value::from(ConstUsize::new(5))); let cst_used = cb.add_constant(Value::unary_unit_sum()); diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..d1714eac9 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,33 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Node = Node; + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -116,18 +126,13 @@ impl RemoveDeadFuncsPass { /// [LoadFunction]: hugr_core::ops::OpType::LoadFunction /// [Module]: hugr_core::ops::OpType::Module pub fn remove_dead_funcs( - h: &mut impl HugrMut, + h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +147,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +187,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/force_order.rs b/hugr-passes/src/force_order.rs index 689479b95..cbb637b2a 100644 --- a/hugr-passes/src/force_order.rs +++ b/hugr-passes/src/force_order.rs @@ -2,11 +2,7 @@ use std::{cmp::Reverse, collections::BinaryHeap, iter}; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, HierarchyView, SiblingGraph}, - HugrError, - }, + hugr::{hugrmut::HugrMut, HugrError}, ops::{NamedOp, OpTag, OpTrait}, types::EdgeKind, HugrView as _, Node, @@ -36,7 +32,7 @@ use petgraph::{ /// there is no path from `n2` to `n1` (otherwise this would invalidate `hugr`). /// Nodes of equal rank will be ordered arbitrarily, although that arbitrary /// order is deterministic. -pub fn force_order( +pub fn force_order>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> i64, @@ -46,39 +42,47 @@ pub fn force_order( /// As [force_order], but allows a generic [Ord] choice for the result of the /// `rank` function. -pub fn force_order_by_key( +pub fn force_order_by_key, K: Ord>( hugr: &mut H, root: Node, rank: impl Fn(&H, Node) -> K, ) -> Result<(), HugrError> { - let dataflow_parents = DescendantsGraph::::try_new(hugr, root)? - .nodes() + let dataflow_parents = hugr + .descendants(root) .filter(|n| hugr.get_optype(*n).tag() <= OpTag::DataflowParent) .collect_vec(); for dp in dataflow_parents { // we filter out the input and output nodes from the topological sort let [i, o] = hugr.get_io(dp).unwrap(); - let rank = |n| rank(hugr, n); - let sg = SiblingGraph::::try_new(hugr, dp)?; - let petgraph = NodeFiltered::from_fn(sg.as_petgraph(), |x| x != dp && x != i && x != o); - let ordered_nodes = ForceOrder::new(&petgraph, &rank) - .iter(&petgraph) - .filter(|&x| { - let expected_edge = Some(EdgeKind::StateOrder); - let optype = hugr.get_optype(x); - if optype.other_input() == expected_edge || optype.other_output() == expected_edge { - assert_eq!( - optype.other_input(), - optype.other_output(), - "Optype does not have both input and output order edge: {}", - optype.name() - ); - true - } else { - false - } - }) - .collect_vec(); + let ordered_nodes = { + let rank = |n| rank(hugr, hugr.from_portgraph_node(n)); + let sg = hugr.region_portgraph(dp); + let petgraph = NodeFiltered::from_fn(&sg, |x| { + let x = hugr.from_portgraph_node(x); + x != dp && x != i && x != o + }); + ForceOrder::new(&petgraph, &rank) + .iter(&petgraph) + .map(|x| hugr.from_portgraph_node(x)) + .filter(|&x| { + let expected_edge = Some(EdgeKind::StateOrder); + let optype = hugr.get_optype(x); + if optype.other_input() == expected_edge + || optype.other_output() == expected_edge + { + assert_eq!( + optype.other_input(), + optype.other_output(), + "Optype does not have both input and output order edge: {}", + optype.name() + ); + true + } else { + false + } + }) + .collect_vec() + }; // we iterate over the topologically sorted nodes, prepending the input // node and suffixing the output node. @@ -275,7 +279,7 @@ mod test { .iter(&hugr.as_petgraph()) .filter(|n| rank_map.contains_key(n)) .collect_vec(); - hugr.validate_no_extensions().unwrap(); + hugr.validate().unwrap(); topo_sorted } diff --git a/hugr-passes/src/half_node.rs b/hugr-passes/src/half_node.rs index ca0d9880e..7f332209f 100644 --- a/hugr-passes/src/half_node.rs +++ b/hugr-passes/src/half_node.rs @@ -3,12 +3,10 @@ use std::hash::Hash; use super::nest_cfgs::CfgNodeMap; use hugr_core::hugr::internal::HugrInternals; -use hugr_core::hugr::RootTagged; - +use hugr_core::hugr::views::RootCheckable; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{OpTag, OpTrait}; - -use hugr_core::{Direction, Node}; +use hugr_core::{Direction, HugrView, Node}; /// We provide a view of a cfg where every node has at most one of /// (multiple predecessors, multiple successors). @@ -32,9 +30,12 @@ struct HalfNodeView { exit: H::Node, } -impl> HalfNodeView { +impl HalfNodeView { #[allow(unused)] - pub(crate) fn new(h: H) -> Self { + pub(crate) fn new(h: impl RootCheckable>) -> Self { + let checked = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = checked.into_hugr(); + let (entry, exit) = { let mut children = h.children(h.root()); (children.next().unwrap(), children.next().unwrap()) @@ -64,7 +65,7 @@ impl> HalfNodeView { } } -impl> CfgNodeMap> for HalfNodeView { +impl CfgNodeMap> for HalfNodeView { fn entry_node(&self) -> HalfNode { HalfNode::N(self.entry) } @@ -98,7 +99,6 @@ mod test { use super::super::nest_cfgs::{test::*, EdgeClassifier}; use super::{HalfNode, HalfNodeView}; use hugr_core::builder::BuildError; - use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::NodeHandle; use itertools::Itertools; @@ -118,7 +118,7 @@ mod test { // \---<---<---<---<---<---<---<---<---<---/ // Allowing to identify two nested regions (and fixing the problem with an IdentityCfgMap on the same example) - let v = HalfNodeView::new(RootChecked::try_new(&h).unwrap()); + let v = HalfNodeView::new(&h); let edge_classes = EdgeClassifier::get_edge_classes(&v); let HalfNodeView { h: _, entry, exit } = v; diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da47..83ff71b67 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 09e02c41d..334127bab 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -1,23 +1,20 @@ use hugr_core::{ - hugr::{hugrmut::HugrMut, views::SiblingSubgraph, HugrError}, + hugr::{hugrmut::HugrMut, views::SiblingSubgraph}, ops::OpType, Hugr, Node, }; +use itertools::Itertools; use thiserror::Error; /// Replace all operations in a HUGR according to a mapping. /// New operations must match the signature of the old operations. /// /// Returns a list of the replaced nodes and their old operations. -/// -/// # Errors -/// -/// Returns a [`HugrError`] if any replacement fails. pub fn replace_many_ops>( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, mapping: impl Fn(&OpType) -> Option, -) -> Result, HugrError> { +) -> Vec<(Node, OpType)> { let replacements = hugr .nodes() .filter_map(|node| { @@ -28,13 +25,17 @@ pub fn replace_many_ops>( replacements .into_iter() - .map(|(node, new_op)| hugr.replace_op(node, new_op).map(|old_op| (node, old_op))) + .map(|(node, new_op)| { + let old_op = hugr.replace_op(node, new_op); + (node, old_op) + }) .collect() } /// Errors produced by the [`lower_ops`] function. #[derive(Debug, Error)] #[error(transparent)] +#[non_exhaustive] pub enum LowerError { /// Invalid subgraph. #[error("Subgraph formed by node is invalid: {0}")] @@ -53,7 +54,7 @@ pub enum LowerError { /// /// Returns a [`LowerError`] if the lowered HUGR is invalid or if any rewrite fails. pub fn lower_ops( - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, lowering: impl Fn(&OpType) -> Option, ) -> Result, LowerError> { let replacements = hugr @@ -69,9 +70,11 @@ pub fn lower_ops( .map(|(node, replacement)| { let subcirc = SiblingSubgraph::from_node(node, hugr); let rw = subcirc.create_simple_replacement(hugr, replacement)?; - let mut repls = hugr.apply_rewrite(rw)?; - debug_assert_eq!(repls.len(), 1); - Ok(repls.remove(0)) + let removed_nodes = hugr.apply_patch(rw)?.removed_nodes; + Ok(removed_nodes + .into_iter() + .exactly_one() + .expect("removed exactly one node")) }) .collect() } @@ -91,7 +94,7 @@ mod test { #[fixture] fn noop_hugr() -> Hugr { - let mut b = DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut b = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let out = b .add_dataflow_op(Noop::new(bool_t()), [b.input_wires().next().unwrap()]) .unwrap() @@ -116,8 +119,7 @@ mod test { } else { None } - }) - .unwrap(); + }); assert_eq!(replaced.len(), 1); let (n, op) = replaced.remove(0); @@ -139,6 +141,6 @@ mod test { }); assert_eq!(lowered.unwrap().len(), 1); - assert_eq!(h.node_count(), 3); // DFG, input, output + assert_eq!(h.num_nodes(), 3); // DFG, input, output } } diff --git a/hugr-passes/src/merge_bbs.rs b/hugr-passes/src/merge_bbs.rs index aeabc26ce..2c739c3d7 100644 --- a/hugr-passes/src/merge_bbs.rs +++ b/hugr-passes/src/merge_bbs.rs @@ -4,11 +4,11 @@ use std::collections::HashMap; use hugr_core::extension::prelude::UnpackTuple; use hugr_core::hugr::hugrmut::HugrMut; +use hugr_core::hugr::views::RootCheckable; use itertools::Itertools; -use hugr_core::hugr::rewrite::inline_dfg::InlineDFG; -use hugr_core::hugr::rewrite::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; -use hugr_core::hugr::RootTagged; +use hugr_core::hugr::patch::inline_dfg::InlineDFG; +use hugr_core::hugr::patch::replace::{NewEdgeKind, NewEdgeSpec, Replacement}; use hugr_core::ops::handle::CfgID; use hugr_core::ops::{DataflowBlock, DataflowParent, Input, Output, DFG}; use hugr_core::{Hugr, HugrView, Node}; @@ -16,8 +16,14 @@ use hugr_core::{Hugr, HugrView, Node}; /// Merge any basic blocks that are direct children of the specified CFG /// i.e. where a basic block B has a single successor B' whose only predecessor /// is B, B and B' can be combined. -pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { - let mut worklist = cfg.nodes().collect::>(); +pub fn merge_basic_blocks<'h, H>(cfg: impl RootCheckable<&'h mut H, CfgID>) +where + H: 'h + HugrMut, +{ + let checked = cfg.try_into_checked().expect("Hugr must be a CFG region"); + let cfg = checked.into_hugr(); + + let mut worklist = cfg.children(cfg.root()).collect::>(); while let Some(n) = worklist.pop() { // Consider merging n with its successor let Ok(succ) = cfg.output_neighbours(n).exactly_one() else { @@ -33,20 +39,18 @@ pub fn merge_basic_blocks(cfg: &mut impl HugrMut) { continue; }; let (rep, merge_bb, dfgs) = mk_rep(cfg, n, succ); - let node_map = cfg.hugr_mut().apply_rewrite(rep).unwrap(); + let node_map = cfg.apply_patch(rep).unwrap(); let merged_bb = *node_map.get(&merge_bb).unwrap(); for dfg_id in dfgs { let n_id = *node_map.get(&dfg_id).unwrap(); - cfg.hugr_mut() - .apply_rewrite(InlineDFG(n_id.into())) - .unwrap(); + cfg.apply_patch(InlineDFG(n_id.into())).unwrap(); } worklist.push(merged_bb); } } fn mk_rep( - cfg: &impl RootTagged, + cfg: &impl HugrView, pred: Node, succ: Node, ) -> (Replacement, Node, [Node; 2]) { @@ -55,17 +59,13 @@ fn mk_rep( let succ_sig = succ_ty.inner_signature(); // Make a Hugr with just a single CFG root node having the same signature. - let mut replacement: Hugr = Hugr::new(cfg.root_type().clone()); + let mut replacement: Hugr = Hugr::new(cfg.root_optype().clone()); let merged = replacement.add_node_with_parent(replacement.root(), { - let mut merged_block = DataflowBlock { + DataflowBlock { inputs: pred_ty.inputs.clone(), ..succ_ty.clone() - }; - merged_block.extension_delta = merged_block - .extension_delta - .union(pred_ty.extension_delta.clone()); - merged_block + } }); let input = replacement.add_node_with_parent( merged, @@ -165,9 +165,7 @@ mod test { use hugr_core::builder::{endo_sig, inout_sig, CFGBuilder, DFGWrapper, Dataflow, HugrBuilder}; use hugr_core::extension::prelude::{qb_t, usize_t, ConstUsize}; - use hugr_core::hugr::views::sibling::SiblingMut; use hugr_core::ops::constant::Value; - use hugr_core::ops::handle::CfgID; use hugr_core::ops::{LoadConstant, OpTrait, OpType}; use hugr_core::types::{Signature, Type, TypeRow}; use hugr_core::{const_extension_ids, type_row, Extension, Hugr, HugrView, Wire}; @@ -223,7 +221,7 @@ mod test { let e = extension(); let tst_op = e.instantiate_extension_op("Test", [])?; let mut h = CFGBuilder::new(inout_sig(loop_variants.clone(), exit_types.clone()))?; - let mut no_b1 = h.simple_entry_builder_exts(loop_variants.clone(), 1, PRELUDE_ID)?; + let mut no_b1 = h.simple_entry_builder(loop_variants.clone(), 1)?; let n = no_b1.add_dataflow_op(Noop::new(qb_t()), no_b1.input_wires())?; let br = unary_unit_sum(&mut no_b1); let no_b1 = no_b1.finish_with_outputs(br, n.outputs())?; @@ -254,7 +252,7 @@ mod test { let mut h = h.finish_hugr()?; let r = h.root(); - merge_basic_blocks(&mut SiblingMut::::try_new(&mut h, r)?); + merge_basic_blocks(&mut h); h.validate().unwrap(); assert_eq!(r, h.root()); assert!(matches!(h.get_optype(r), OpType::CFG(_))); @@ -348,8 +346,7 @@ mod test { h.branch(&bb3, 0, &h.exit_block())?; let mut h = h.finish_hugr()?; - let root = h.root(); - merge_basic_blocks(&mut SiblingMut::try_new(&mut h, root)?); + merge_basic_blocks(&mut h); h.validate()?; // Should only be one BB left diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..cfe2c9514 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,7 +1,7 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, - ops::Deref, }; use hugr_core::{ @@ -12,7 +12,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +32,10 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize( + hugr: &mut impl HugrMut, +) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -71,7 +57,7 @@ pub fn remove_polyfuncs(mut h: Hugr) -> Hugr { since = "0.14.1", note = "Use hugr_passes::RemoveDeadFuncsPass instead" )] -fn remove_polyfuncs_ref(h: &mut impl HugrMut) { +fn remove_polyfuncs_ref(h: &mut impl HugrMut) { let mut pfs_to_delete = Vec::new(); let mut to_scan = Vec::from_iter(h.children(h.root())); while let Some(n) = to_scan.pop() { @@ -107,7 +93,7 @@ type Instantiations = HashMap, Node>>; /// Optionally copies the subtree into a new location whilst applying a substitution. /// The subtree should be monomorphic after the substitution (if provided) has been applied. fn mono_scan( - h: &mut impl HugrMut, + h: &mut impl HugrMut, parent: Node, mut subst_into: Option<&mut Instantiating>, cache: &mut Instantiations, @@ -170,12 +156,12 @@ fn mono_scan( h.disconnect(ch, fn_inp); // No-op if copying+substituting h.connect(new_tgt, fn_out, ch, fn_inp); - h.replace_op(ch, new_op).unwrap(); + h.replace_op(ch, new_op); } } fn instantiate( - h: &mut impl HugrMut, + h: &mut impl HugrMut, poly_func: Node, type_args: Vec, mono_sig: Signature, @@ -191,7 +177,7 @@ fn instantiate( name: mangle_inner_func(&outer_name, &fd.name), signature: fd.signature.clone(), }; - h.replace_op(n, fd).unwrap(); + h.replace_op(n, fd); h.move_after_sibling(n, poly_func); } else { to_scan.extend(h.children(n)) @@ -254,8 +240,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +255,26 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Node = Node; + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -327,10 +299,6 @@ fn write_type_arg_str(arg: &TypeArg, f: &mut std::fmt::Formatter<'_>) -> std::fm TypeArg::BoundedNat { n } => f.write_fmt(format_args!("n({n})")), TypeArg::String { arg } => f.write_fmt(format_args!("s({})", escape_dollar(arg))), TypeArg::Sequence { elems } => f.write_fmt(format_args!("seq({})", TypeArgsList(elems))), - TypeArg::Extensions { es } => f.write_fmt(format_args!( - "es({})", - es.iter().map(|x| x.deref()).join(",") - )), // We are monomorphizing. We will never monomorphize to a signature // containing a variable. TypeArg::Variable { .. } => panic!("type_arg_str variable: {arg}"), @@ -365,6 +333,7 @@ mod test { use std::iter; use hugr_core::extension::simple_op::MakeRegisteredOp as _; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; use hugr_core::std_extensions::collections; use hugr_core::std_extensions::collections::array::{array_type_parametric, ArrayOpDef}; use hugr_core::types::type_param::TypeParam; @@ -374,22 +343,16 @@ mod test { Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, }; - use hugr_core::extension::prelude::{ - usize_t, ConstUsize, UnpackTuple, UnwrapBuilder, PRELUDE_ID, - }; - use hugr_core::extension::ExtensionSet; + use hugr_core::extension::prelude::{usize_t, ConstUsize, UnpackTuple, UnwrapBuilder}; use hugr_core::ops::handle::{FuncID, NodeHandle}; use hugr_core::ops::{CallIndirect, DataflowOpTrait as _, FuncDefn, Tag}; - use hugr_core::std_extensions::arithmetic::int_types::{self, INT_TYPES}; - use hugr_core::types::{ - PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum, TypeRow, - }; + use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeEnum}; use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -399,10 +362,6 @@ mod test { Type::new_tuple(vec![ty.clone(), ty.clone(), ty]) } - fn prelusig(ins: impl Into, outs: impl Into) -> Signature { - Signature::new(ins, outs).with_extension_delta(PRELUDE_ID) - } - #[test] fn test_null() { let dfg_builder = @@ -410,7 +369,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -438,7 +397,7 @@ mod test { }; let tr = { - let sig = prelusig(tv0(), Type::new_tuple(vec![tv0(); 3])); + let sig = Signature::new(tv0(), Type::new_tuple(vec![tv0(); 3])); let mut fb = mb.define_function( "triple", PolyFuncType::new([TypeBound::Copyable.into()], sig), @@ -455,7 +414,7 @@ mod test { }; let mn = { let outs = vec![triple_type(usize_t()), triple_type(pair_type(usize_t()))]; - let mut fb = mb.define_function("main", prelusig(usize_t(), outs))?; + let mut fb = mb.define_function("main", Signature::new(usize_t(), outs))?; let [elem] = fb.input_wires_arr(); let [res1] = fb .call(tr.handle(), &[usize_t().into()], [elem])? @@ -472,7 +431,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +452,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -520,37 +479,30 @@ mod test { let n: u64 = 5; let mut outer = FunctionBuilder::new( "mainish", - prelusig( + Signature::new( array_type_parametric(sa(n), array_type_parametric(sa(2), usize_t()).unwrap()) .unwrap(), vec![usize_t(); 2], - ) - .with_extension_delta(collections::array::EXTENSION_ID), + ), ) .unwrap(); let arr2u = || array_type_parametric(sa(2), usize_t()).unwrap(); let pf1t = PolyFuncType::new( [TypeParam::max_nat()], - prelusig(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(array_type_parametric(sv(0), arr2u()).unwrap(), usize_t()), ); let mut pf1 = outer.define_function("pf1", pf1t).unwrap(); let pf2t = PolyFuncType::new( [TypeParam::max_nat(), TypeBound::Copyable.into()], - prelusig(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)) - .with_extension_delta(collections::array::EXTENSION_ID), + Signature::new(vec![array_type_parametric(sv(0), tv(1)).unwrap()], tv(1)), ); let mut pf2 = pf1.define_function("pf2", pf2t).unwrap(); let mono_func = { let mut fb = pf2 - .define_function( - "get_usz", - prelusig(vec![], usize_t()) - .with_extension_delta(collections::array::EXTENSION_ID), - ) + .define_function("get_usz", Signature::new(vec![], usize_t())) .unwrap(); let cst0 = fb.add_load_value(ConstUsize::new(1)); fb.finish_with_outputs([cst0]).unwrap() @@ -601,7 +553,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +614,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +671,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); @@ -733,8 +685,6 @@ mod test { #[case::string(vec!["arg".into()], "$foo$$s(arg)")] #[case::dollar_string(vec!["$arg".into()], "$foo$$s(\\$arg)")] #[case::sequence(vec![vec![0.into(), Type::UNIT.into()].into()], "$foo$$seq($n(0)$t(Unit))")] - #[case::extensionset(vec![ExtensionSet::from_iter([PRELUDE_ID,int_types::EXTENSION_ID]).into()], - "$foo$$es(arithmetic.int.types,prelude)")] // alphabetic ordering of extension names #[should_panic] #[case::typeargvariable(vec![TypeArg::new_var_use(1, TypeParam::String)], "$foo$$v(1)")] diff --git a/hugr-passes/src/nest_cfgs.rs b/hugr-passes/src/nest_cfgs.rs index 9baf250f9..3c15ca6f2 100644 --- a/hugr-passes/src/nest_cfgs.rs +++ b/hugr-passes/src/nest_cfgs.rs @@ -44,14 +44,14 @@ use std::hash::Hash; use itertools::Itertools; use thiserror::Error; -use hugr_core::hugr::rewrite::outline_cfg::OutlineCfg; +use hugr_core::hugr::patch::outline_cfg::OutlineCfg; use hugr_core::hugr::views::sibling::SiblingMut; -use hugr_core::hugr::views::{HierarchyView, HugrView, SiblingGraph}; -use hugr_core::hugr::{hugrmut::HugrMut, Rewrite, RootTagged}; +use hugr_core::hugr::views::{HierarchyView, HugrView, RootCheckable, SiblingGraph}; +use hugr_core::hugr::{hugrmut::HugrMut, Patch}; use hugr_core::ops::handle::{BasicBlockID, CfgID}; use hugr_core::ops::OpTag; use hugr_core::ops::OpTrait; -use hugr_core::{Direction, Hugr}; +use hugr_core::{Direction, Hugr, Node}; /// A "view" of a CFG in a Hugr which allows basic blocks in the underlying CFG to be split into /// multiple blocks in the view (or merged together). @@ -155,7 +155,7 @@ pub fn transform_cfg_to_nested( pub fn transform_all_cfgs(h: &mut Hugr) { let mut node_stack = Vec::from([h.root()]); while let Some(n) = node_stack.pop() { - if let Ok(s) = SiblingMut::::try_new(h, n) { + if let Ok(s) = SiblingMut::<_, CfgID>::try_new(h, n) { transform_cfg_to_nested(&mut IdentityCfgMap::new(s)); } node_stack.extend(h.children(n)) @@ -219,9 +219,12 @@ pub struct IdentityCfgMap { entry: H::Node, exit: H::Node, } -impl> IdentityCfgMap { +impl IdentityCfgMap { /// Creates an [IdentityCfgMap] for the specified CFG - pub fn new(h: H) -> Self { + pub fn new(h: impl RootCheckable>) -> Self { + let h = h.try_into_checked().expect("Hugr must be a CFG region"); + let h = h.into_hugr(); + // Panic if malformed enough not to have two children let (entry, exit) = h.children(h.root()).take(2).collect_tuple().unwrap(); debug_assert_eq!(h.get_optype(exit).tag(), OpTag::BasicBlockExit); @@ -246,7 +249,7 @@ impl CfgNodeMap for IdentityCfgMap { } } -impl CfgNester for IdentityCfgMap { +impl> CfgNester for IdentityCfgMap { fn nest_sese_region( &mut self, entry_edge: (H::Node, H::Node), @@ -257,7 +260,7 @@ impl CfgNester for IdentityCfgMap { assert!([entry_edge.0, entry_edge.1, exit_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); - let (new_block, new_cfg) = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); + let [new_block, new_cfg] = OutlineCfg::new(blocks).apply(&mut self.h).unwrap(); debug_assert!([entry_edge.0, exit_edge.1] .iter() .all(|n| self.h.get_parent(*n) == Some(self.h.root()))); @@ -574,9 +577,9 @@ pub(crate) mod test { use hugr_core::builder::{ endo_sig, BuildError, CFGBuilder, Container, DataflowSubContainer, HugrBuilder, }; - use hugr_core::extension::{prelude::usize_t, ExtensionSet}; + use hugr_core::extension::prelude::usize_t; - use hugr_core::hugr::rewrite::insert_identity::{IdentityInsertion, IdentityInsertionError}; + use hugr_core::hugr::patch::insert_identity::{IdentityInsertion, IdentityInsertionError}; use hugr_core::hugr::views::RootChecked; use hugr_core::ops::handle::{ConstID, NodeHandle}; use hugr_core::ops::Value; @@ -609,11 +612,7 @@ pub(crate) mod test { let const_unit = cfg_builder.add_constant(Value::unary_unit_sum()); let entry = n_identity( - cfg_builder.simple_entry_builder_exts( - vec![usize_t()].into(), - 1, - ExtensionSet::new(), - )?, + cfg_builder.simple_entry_builder(vec![usize_t()].into(), 1)?, &const_unit, )?; let (split, merge) = build_if_then_else_merge(&mut cfg_builder, &pred_const, &const_unit)?; @@ -636,7 +635,7 @@ pub(crate) mod test { let rc = RootChecked::<_, CfgID>::try_new(&mut h).unwrap(); let (entry, exit) = (entry.node(), exit.node()); let (split, merge, head, tail) = (split.node(), merge.node(), head.node(), tail.node()); - let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.borrow())); + let edge_classes = EdgeClassifier::get_edge_classes(&IdentityCfgMap::new(rc.as_ref())); let [&left, &right] = edge_classes .keys() .filter(|(s, _)| *s == split) @@ -734,7 +733,7 @@ pub(crate) mod test { // There's no need to use a view of a region here but we do so just to check // that we *can* (as we'll need to for "real" module Hugr's) - let v = IdentityCfgMap::new(SiblingGraph::try_new(&h, h.root()).unwrap()); + let v = IdentityCfgMap::new(SiblingGraph::::try_new(&h, h.root()).unwrap()); let edge_classes = EdgeClassifier::get_edge_classes(&v); let IdentityCfgMap { h: _, entry, exit } = v; let [&left, &right] = edge_classes @@ -760,7 +759,7 @@ pub(crate) mod test { // Again, there's no need for a view of a region here, but check that the // transformation still works when we can only directly mutate the top level let root = h.root(); - let m = SiblingMut::::try_new(&mut h, root).unwrap(); + let m = SiblingMut::<_, CfgID>::try_new(&mut h, root).unwrap(); transform_cfg_to_nested(&mut IdentityCfgMap::new(m)); h.validate().unwrap(); assert_eq!(1, depth(&h, entry)); @@ -827,7 +826,7 @@ pub(crate) mod test { let rw = IdentityInsertion::new(final_node, final_node_input); - let apply_result = h.apply_rewrite(rw); + let apply_result = h.apply_patch(rw); assert_eq!( apply_result, Err(IdentityInsertionError::InvalidPortKind(Some( diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fca74657b..a2219d14f 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -23,6 +23,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator { #[error("Found {} nonlocal edges", .0.len())] Edges(Vec<(N, IncomingPort)>), @@ -53,8 +54,7 @@ mod test { #[test] fn ensures_no_nonlocal_edges() { let hugr = { - let mut builder = - DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let [out_w] = builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) @@ -68,12 +68,11 @@ mod test { #[test] fn find_nonlocal_edges() { let (hugr, edge) = { - let mut builder = - DFGBuilder::new(Signature::new_endo(bool_t()).with_prelude()).unwrap(); + let mut builder = DFGBuilder::new(Signature::new_endo(bool_t())).unwrap(); let [in_w] = builder.input_wires_arr(); let ([out_w], edge) = { let mut dfg_builder = builder - .dfg_builder(Signature::new(type_row![], bool_t()).with_prelude(), []) + .dfg_builder(Signature::new(type_row![], bool_t()), []) .unwrap(); let noop = dfg_builder .add_dataflow_op(Noop::new(bool_t()), [in_w]) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..45bc25bcf 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,18 +15,19 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -45,21 +46,41 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to an existing function. + Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType + pub fn add_hugr( + self, + hugr: &mut impl HugrMut, + parent: Node, + ) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args)?; + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + Ok(n) + } } } @@ -72,10 +93,15 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -88,19 +114,57 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args)?; + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, - NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::CompoundOp(hugr) => hugr.root_optype(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } + }; + Ok(Call::try_new(func_sig, type_args)?) +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// @@ -143,7 +207,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,11 +247,11 @@ pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -203,16 +266,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,37 +379,11 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { + fn change_node( + &self, + hugr: &mut impl HugrMut, + n: Node, + ) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), @@ -410,8 +440,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -422,7 +455,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { @@ -472,11 +507,41 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Node = Node; + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -526,35 +591,30 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::hugr::{IdentList, ValidationError}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; + use hugr_core::std_extensions::collections::{ + array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + list::{list_type, list_type_def, ListOp, ListValue}, }; - use hugr_core::std_extensions::collections::list::{ - list_type, list_type_def, ListOp, ListValue, - }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -615,30 +675,37 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(Signature::new( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.set_outputs([res]).unwrap(); + dfb + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -654,7 +721,13 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) + }); lw } @@ -672,8 +745,7 @@ mod test { let inps = fb.input_wires(); let id = fb.finish_with_outputs(inps).unwrap(); - let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()) - .with_extension_delta(ext.name.clone()); + let sig = Signature::new(vec![i64_t(), c_int.clone(), c_bool.clone()], bool_t()); let mut fb = mb.define_function("main", sig).unwrap(); let [idx, indices, bools] = fb.input_wires_arr(); let [indices] = fb @@ -939,7 +1011,7 @@ mod test { // list -> read -> usz just becomes list -> read -> qb // list> -> read> -> opt becomes list -> get -> opt assert_eq!( - h.root_type().dataflow_signature().unwrap().io(), + h.root_optype().dataflow_signature().unwrap().io(), ( &vec![list_type(qb_t()); 2].into(), &vec![qb_t(), option_type(qb_t()).into()].into() @@ -979,13 +1051,64 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. - })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); + } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..573188340 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -3,7 +3,6 @@ use hugr_core::builder::{endo_sig, inout_sig, DFGBuilder, Dataflow, DataflowHugr}; use hugr_core::extension::prelude::{option_type, UnwrapBuilder}; -use hugr_core::extension::ExtensionSet; use hugr_core::ops::{constant::OpaqueValue, Value}; use hugr_core::ops::{OpTrait, OpType, Tag}; use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; @@ -13,8 +12,8 @@ use hugr_core::std_extensions::collections::array::{ array_type, ArrayOpDef, ArrayRepeat, ArrayScan, ArrayValue, }; use hugr_core::std_extensions::collections::list::ListValue; +use hugr_core::type_row; use hugr_core::types::{SumType, Transformable, Type, TypeArg}; -use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; use super::{ @@ -67,10 +66,6 @@ pub fn array_const( Ok(Some(ArrayValue::new(elem_t, vals).into())) } -fn runtime_reqs(h: &Hugr) -> ExtensionSet { - h.signature(h.root()).unwrap().runtime_reqs.clone() -} - /// Handler for copying/discarding arrays if their elements have become linear. /// Included in [ReplaceTypes::default] and [DelegatingLinearizer::default]. /// @@ -92,12 +87,12 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; // Now array.scan that over the input array to get an array of unit (which can be discarded) - let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n, runtime_reqs(&map_fn)); + let array_scan = ArrayScan::new(ty.clone(), Type::UNIT, vec![], *n); let in_type = array_type(*n, ty.clone()); return Ok(NodeTemplate::CompoundOp(Box::new({ let mut dfb = DFGBuilder::new(inout_sig(in_type, type_row![])).unwrap(); @@ -131,8 +126,7 @@ pub fn linearize_array( .unwrap(); dfb.finish_hugr_with_outputs(none.outputs()).unwrap() }; - let repeats = - vec![ArrayRepeat::new(option_ty.clone(), *n, runtime_reqs(&fn_none)); num_new]; + let repeats = vec![ArrayRepeat::new(option_ty.clone(), *n); num_new]; let fn_none = dfb.add_load_value(Value::function(fn_none).unwrap()); repeats .into_iter() @@ -162,7 +156,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly @@ -212,7 +206,6 @@ pub fn linearize_array( .chain(vec![option_array; num_new]) .collect(), *n, - runtime_reqs(©_elem), ); let copy_elem = dfb.add_load_value(Value::function(copy_elem).unwrap()); @@ -240,13 +233,7 @@ pub fn linearize_array( dfb.finish_hugr_with_outputs([val]).unwrap() }; - let unwrap_scan = ArrayScan::new( - option_ty.clone(), - ty.clone(), - vec![], - *n, - runtime_reqs(&unwrap_elem), - ); + let unwrap_scan = ArrayScan::new(option_ty.clone(), ty.clone(), vec![], *n); let unwrap_elem = dfb.add_load_value(Value::function(unwrap_elem).unwrap()); let out_arrays = std::iter::once(out_array1) diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 7b83717d0..81324dbee 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,10 +1,8 @@ -use std::borrow::Cow; -use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -50,7 +48,7 @@ pub trait Linearizer { /// if `src` is not a valid Wire (does not identify a dataflow out-port) fn insert_copy_discard( &self, - hugr: &mut impl HugrMut, + hugr: &mut impl HugrMut, src: Wire, targets: &[(Node, IncomingPort)], ) -> Result<(), LinearizeError> { @@ -76,9 +74,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -133,8 +133,9 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), @@ -162,6 +163,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -184,8 +189,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). pub fn register_simple( &mut self, cty: CustomType, @@ -229,18 +236,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { @@ -271,7 +272,7 @@ impl Linearizer for DelegatingLinearizer { let mut elems_for_copy = vec![vec![]; num_outports]; for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { let inp_copies = if ty.copyable() { - repeat(inp).take(num_outports).collect::>() + std::iter::repeat_n(inp, num_outports).collect::>() } else { self.copy_discard_op(ty, num_outports)? .add(&mut case_b, [inp]) @@ -352,7 +353,10 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -376,7 +380,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; @@ -618,10 +622,7 @@ mod test { NodeTemplate::SingleOp(copy3.clone()), NodeTemplate::SingleOp(discard.clone().into()), ); - let sig3 = Some( - Signature::new(lin_t.clone(), vec![lin_t.clone(); 3]) - .with_extension_delta(ext.name().clone()), - ); + let sig3 = Some(Signature::new(lin_t.clone(), vec![lin_t.clone(); 3])); assert_eq!( bad_copy, Err(LinearizeError::WrongSignature { @@ -767,4 +768,64 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function("drop", Signature::new(lin_t.clone(), type_row![])) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::UnexpectedType { node, .. } + ) + )) if nested_t == lin_t && node == discard_fn + )); + } } diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd1..1c9be1c75 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Node = Node; + type Error = UntupleError; + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_patch(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -258,7 +247,7 @@ fn remove_pack_unpack<'h, T: HugrView>( .add_dataflow_op(op, replacement.input_wires()) .unwrap() .outputs_arr(); - outputs.extend(std::iter::repeat(tuple).take(num_other_outputs)) + outputs.extend(std::iter::repeat_n(tuple, num_other_outputs)) } // These should never fail, as we are defining the replacement ourselves. @@ -289,9 +278,7 @@ mod test { /// These can be removed entirely. #[fixture] fn unused_pack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![]).with_prelude()) - .unwrap(); + let mut h = DFGBuilder::new(Signature::new(vec![bool_t(), bool_t()], vec![])).unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -306,8 +293,7 @@ mod test { /// These can be removed entirely. #[fixture] fn simple_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -326,8 +312,7 @@ mod test { /// we just remove everything. #[fixture] fn ordered_pack_unpack() -> Hugr { - let mut h = - DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()]).with_prelude()).unwrap(); + let mut h = DFGBuilder::new(Signature::new_endo(vec![qb_t(), bool_t()])).unwrap(); let mut inps = h.input_wires(); let qb1 = inps.next().unwrap(); let b2 = inps.next().unwrap(); @@ -349,13 +334,10 @@ mod test { /// These can be removed entirely. #[fixture] fn multi_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![bool_t(), bool_t(), bool_t(), bool_t()], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![bool_t(), bool_t(), bool_t(), bool_t()], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); @@ -380,17 +362,14 @@ mod test { /// The unpack operation can be removed, but the pack operation cannot. #[fixture] fn partial_unpack() -> Hugr { - let mut h = DFGBuilder::new( - Signature::new( - vec![bool_t(), bool_t()], - vec![ - bool_t(), - bool_t(), - Type::new_tuple(vec![bool_t(), bool_t()]), - ], - ) - .with_prelude(), - ) + let mut h = DFGBuilder::new(Signature::new( + vec![bool_t(), bool_t()], + vec![ + bool_t(), + bool_t(), + Type::new_tuple(vec![bool_t(), bool_t()]), + ], + )) .unwrap(); let mut inps = h.input_wires(); let b1 = inps.next().unwrap(); @@ -421,7 +400,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 5f53f403c..90d338faf 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -16,15 +16,13 @@ use hugr_core::HugrView; pub enum ValidationLevel { /// Do no verification. None, - /// Validate using [HugrView::validate_no_extensions]. This is useful when you - /// do not expect valid Extension annotations on Nodes. - WithoutExtensions, /// Validate using [HugrView::validate]. - WithExtensions, + Validate, } #[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] InputError { @@ -43,8 +41,7 @@ pub enum ValidatePassError { impl Default for ValidationLevel { fn default() -> Self { if cfg!(test) { - // Many tests fail when run with Self::WithExtensions - Self::WithoutExtensions + Self::Validate } else { Self::None } @@ -85,8 +82,7 @@ impl ValidationLevel { { match self { ValidationLevel::None => Ok(()), - ValidationLevel::WithoutExtensions => hugr.validate_no_extensions(), - ValidationLevel::WithExtensions => hugr.validate(), + ValidationLevel::Validate => hugr.validate(), } .map_err(|err| mk_err(err, hugr.mermaid_string()).into()) } diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 429bdd785..3bb377ed5 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -8,7 +8,6 @@ from hugr.hugr.base import Hugr from hugr.utils import deser_it -from .ops import Value from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, @@ -20,7 +19,6 @@ ) if TYPE_CHECKING: - from .ops import Value from .serial_hugr import SerialHugr @@ -62,20 +60,6 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: ) -class ExtensionValue(ConfiguredBaseModel): - extension: ExtensionId - name: str - typed_value: Value - - def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: - return extension.add_extension_value( - ext.ExtensionValue( - name=self.name, - val=self.typed_value.deserialize(), - ) - ) - - # -------------------------------------- # --------------- OpDef ---------------- # -------------------------------------- @@ -102,9 +86,7 @@ class OpDef(ConfiguredBaseModel, populate_by_name=True): def deserialize(self, extension: ext.Extension) -> ext.OpDef: signature = ext.OpDefSig( - self.signature.deserialize().with_runtime_reqs([extension.name]) - if self.signature - else None, + self.signature.deserialize() if self.signature else None, self.binary, ) @@ -122,9 +104,7 @@ def deserialize(self, extension: ext.Extension) -> ext.OpDef: class Extension(ConfiguredBaseModel): version: SemanticVersion name: ExtensionId - runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] - values: dict[str, ExtensionValue] operations: dict[str, OpDef] @classmethod @@ -135,7 +115,6 @@ def deserialize(self) -> ext.Extension: e = ext.Extension( version=self.version, # type: ignore[arg-type] name=self.name, - runtime_reqs=self.runtime_reqs, ) for k, t in self.types.items(): @@ -146,10 +125,6 @@ def deserialize(self) -> ext.Extension: assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for k, v in self.values.items(): - assert k == v.name, "Value name must match key" - e.add_extension_value(v.deserialize(e)) - return e diff --git a/hugr-py/src/hugr/_serialization/ops.py b/hugr-py/src/hugr/_serialization/ops.py index 48b4e6b87..28a1daf5e 100644 --- a/hugr-py/src/hugr/_serialization/ops.py +++ b/hugr-py/src/hugr/_serialization/ops.py @@ -206,7 +206,6 @@ class DataflowBlock(BaseOp): inputs: TypeRow = Field(default_factory=list) other_outputs: TypeRow = Field(default_factory=list) sum_rows: list[TypeRow] - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: num_cases = len(out_types) @@ -384,13 +383,11 @@ class DFG(DataflowOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.DFG: sig = self.signature.deserialize() - return ops.DFG(sig.input, sig.output, sig.runtime_reqs) + return ops.DFG(sig.input, sig.output) # ------------------------------------------------ @@ -407,8 +404,6 @@ class Conditional(DataflowOp): sum_rows: list[TypeRow] = Field( description="The possible rows of the Sum input", default_factory=list ) - # Extensions used to produce the outputs - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: # First port is a predicate, i.e. a sum of tuple types. We need to unpack @@ -442,9 +437,7 @@ class Case(BaseOp): signature: FunctionType = Field(default_factory=FunctionType.empty) def insert_child_dfg_signature(self, inputs: TypeRow, outputs: TypeRow) -> None: - self.signature = stys.FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) - ) + self.signature = stys.FunctionType(input=list(inputs), output=list(outputs)) def deserialize(self) -> ops.Case: sig = self.signature.deserialize() @@ -455,11 +448,12 @@ class TailLoop(DataflowOp): """Tail-controlled loop.""" op: Literal["TailLoop"] = "TailLoop" - just_inputs: TypeRow = Field(default_factory=list) # Types that are only input - just_outputs: TypeRow = Field(default_factory=list) # Types that are only output + # Types that are only input + just_inputs: TypeRow = Field(default_factory=list) + # Types that are only output + just_outputs: TypeRow = Field(default_factory=list) # Types that are appended to both input and output: rest: TypeRow = Field(default_factory=list) - extension_delta: ExtensionSet = Field(default_factory=list) def insert_port_types(self, in_types: TypeRow, out_types: TypeRow) -> None: assert in_types == out_types @@ -472,7 +466,6 @@ def deserialize(self) -> ops.TailLoop: just_inputs=deser_it(self.just_inputs), _just_outputs=deser_it(self.just_outputs), rest=deser_it(self.rest), - extension_delta=self.extension_delta, ) @@ -484,7 +477,8 @@ class CFG(DataflowOp): def insert_port_types(self, inputs: TypeRow, outputs: TypeRow) -> None: self.signature = FunctionType( - input=list(inputs), output=list(outputs), runtime_reqs=ExtensionSet([]) + input=list(inputs), + output=list(outputs), ) def deserialize(self) -> ops.CFG: diff --git a/hugr-py/src/hugr/_serialization/tys.py b/hugr-py/src/hugr/_serialization/tys.py index 4a0a0e75b..c00a73375 100644 --- a/hugr-py/src/hugr/_serialization/tys.py +++ b/hugr-py/src/hugr/_serialization/tys.py @@ -110,23 +110,11 @@ def deserialize(self) -> tys.TupleParam: return tys.TupleParam(params=deser_it(self.params)) -class ExtensionsParam(BaseTypeParam): - tp: Literal["Extensions"] = "Extensions" - - def deserialize(self) -> tys.ExtensionsParam: - return tys.ExtensionsParam() - - class TypeParam(RootModel): """A type parameter.""" root: Annotated[ - TypeTypeParam - | BoundedNatParam - | StringParam - | ListParam - | TupleParam - | ExtensionsParam, + TypeTypeParam | BoundedNatParam | StringParam | ListParam | TupleParam, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tp") @@ -178,14 +166,6 @@ def deserialize(self) -> tys.SequenceArg: return tys.SequenceArg(elems=deser_it(self.elems)) -class ExtensionsArg(BaseTypeArg): - tya: Literal["Extensions"] = "Extensions" - es: ExtensionSet - - def deserialize(self) -> tys.ExtensionsArg: - return tys.ExtensionsArg(extensions=self.es) - - class VariableArg(BaseTypeArg): tya: Literal["Variable"] = "Variable" idx: int @@ -199,12 +179,7 @@ class TypeArg(RootModel): """A type argument.""" root: Annotated[ - TypeTypeArg - | BoundedNatArg - | StringArg - | SequenceArg - | ExtensionsArg - | VariableArg, + TypeTypeArg | BoundedNatArg | StringArg | SequenceArg | VariableArg, WrapValidator(_json_custom_error_validator), ] = Field(discriminator="tya") @@ -307,18 +282,15 @@ class FunctionType(BaseType): input: TypeRow # Value inputs of the function. output: TypeRow # Value outputs of the function. - # The extension requirements which are added by the operation - runtime_reqs: ExtensionSet = Field(default_factory=ExtensionSet) @classmethod def empty(cls) -> FunctionType: - return FunctionType(input=[], output=[], runtime_reqs=[]) + return FunctionType(input=[], output=[]) def deserialize(self) -> tys.FunctionType: return tys.FunctionType( input=deser_it(self.input), output=deser_it(self.output), - runtime_reqs=self.runtime_reqs, ) model_config = ConfigDict( diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 494ea3c69..fd59da0fc 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -8,7 +8,7 @@ from semver import Version import hugr._serialization.extension as ext_s -from hugr import ops, tys, val +from hugr import ops, tys from hugr.utils import ser_it __all__ = [ @@ -18,7 +18,6 @@ "FixedHugr", "OpDefSig", "OpDef", - "ExtensionValue", "Extension", "Version", ] @@ -236,33 +235,9 @@ def instantiate( concrete_signature: Concrete function type of the operation, only required if the operation is polymorphic. """ - # Add the extension where the operation is defined as a runtime requirement. - # We don't store this in the json definition as it is redundant information. - if concrete_signature is not None: - concrete_signature = concrete_signature.with_runtime_reqs( - [self.get_extension().name] - ) - return ops.ExtOp(self, concrete_signature, list(args or [])) -@dataclass -class ExtensionValue(ExtensionObject): - """A value defined in an :class:`Extension`.""" - - #: The name of the value. - name: str - #: Value payload. - val: val.Value - - def _to_serial(self) -> ext_s.ExtensionValue: - return ext_s.ExtensionValue( - extension=self.get_extension().name, - name=self.name, - typed_value=self.val._to_serial_root(), - ) - - T = TypeVar("T", bound=ops.RegisteredOp) @@ -274,12 +249,8 @@ class Extension: name: ExtensionId #: The version of the extension. version: Version - #: Extensions required by this extension at runtime, identified by name. - runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) - #: Values defined in the extension. - values: dict[str, ExtensionValue] = field(default_factory=dict) #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @@ -293,9 +264,7 @@ def _to_serial(self) -> ext_s.Extension: return ext_s.Extension( name=self.name, version=self.version, # type: ignore[arg-type] - runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, - values={k: v._to_serial() for k, v in self.values.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -324,12 +293,6 @@ def add_op_def(self, op_def: OpDef) -> OpDef: Returns: The added operation definition, now associated with the extension. """ - if op_def.signature.poly_func is not None: - # Ensure the op def signature has the extension as a requirement - op_def.signature.poly_func = op_def.signature.poly_func.with_runtime_reqs( - [self.name] - ) - op_def._extension = self self.operations[op_def.name] = op_def return self.operations[op_def.name] @@ -347,19 +310,6 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef: self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: - """Add a value to the extension. - - Args: - extension_value: The value to add. - - Returns: - The added value, now associated with the extension. - """ - extension_value._extension = self - self.values[extension_value.name] = extension_value - return self.values[extension_value.name] - @dataclass class OperationNotFound(NotFound): """Operation not found in extension.""" @@ -406,12 +356,6 @@ def get_type(self, name: str) -> TypeDef: class ValueNotFound(NotFound): """Value not found in extension.""" - def get_value(self, name: str) -> ExtensionValue: - try: - return self.values[name] - except KeyError as e: - raise self.ValueNotFound(name) from e - T = TypeVar("T", bound=ops.RegisteredOp) def register_op( diff --git a/hugr-py/src/hugr/ops.py b/hugr-py/src/hugr/ops.py index 1555dab4d..b6030b6a0 100644 --- a/hugr-py/src/hugr/ops.py +++ b/hugr-py/src/hugr/ops.py @@ -456,7 +456,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=self.types, output=[tys.Tuple(*self.types)], - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -499,7 +498,6 @@ def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType( input=[tys.Tuple(*self.types)], output=self.types, - runtime_reqs=["prelude"], ) def type_args(self) -> list[tys.TypeArg]: @@ -632,7 +630,6 @@ class DFG(DfParentOp, DataflowOp): #: Inputs types of the operation. inputs: tys.TypeRow _outputs: tys.TypeRow | None = field(default=None, repr=False) - _extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def outputs(self) -> tys.TypeRow: @@ -650,7 +647,7 @@ def signature(self) -> tys.FunctionType: Raises: IncompleteOp: If the outputs have not been set. """ - return tys.FunctionType(self.inputs, self.outputs, self._extension_delta) + return tys.FunctionType(self.inputs, self.outputs) @property def num_out(self) -> int: @@ -729,7 +726,6 @@ class DataflowBlock(DfParentOp): inputs: tys.TypeRow _sum: tys.Sum | None = None _other_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list) @property def sum_ty(self) -> tys.Sum: @@ -762,7 +758,6 @@ def _to_serial(self, parent: Node) -> sops.DataflowBlock: inputs=ser_it(self.inputs), sum_rows=list(map(ser_it, self.sum_ty.variant_rows)), other_outputs=ser_it(self.other_outputs), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -993,7 +988,6 @@ class TailLoop(DfParentOp, DataflowOp): #: Types that are appended to both inputs and outputs of the graph. rest: tys.TypeRow _just_outputs: tys.TypeRow | None = field(default=None, repr=False) - extension_delta: tys.ExtensionSet = field(default_factory=list, repr=False) @property def just_outputs(self) -> tys.TypeRow: @@ -1014,7 +1008,6 @@ def _to_serial(self, parent: Node) -> sops.TailLoop: just_inputs=ser_it(self.just_inputs), just_outputs=ser_it(self.just_outputs), rest=ser_it(self.rest), - extension_delta=self.extension_delta, ) def inner_signature(self) -> tys.FunctionType: @@ -1334,13 +1327,11 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def outer_signature(self) -> tys.FunctionType: return tys.FunctionType.endo( [self.type_], - runtime_reqs=["prelude"], ) def _set_in_types(self, types: tys.TypeRow) -> None: diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 9c0054354..3c2fb983c 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -1,12 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", @@ -37,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -72,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -116,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -160,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -193,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -224,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -257,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -301,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -345,8 +332,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -437,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -498,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 31ccaaa59..60180ec84 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -1,11 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", @@ -31,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -68,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -135,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -170,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -200,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -235,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -270,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -305,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -340,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -377,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -414,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -451,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -486,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -516,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -553,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -583,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -620,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -650,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 56e35c50b..33db43f5b 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", @@ -14,6 +13,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 62d0a6663..e8e6fdca8 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -1,11 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", @@ -54,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -123,8 +118,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -278,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -364,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -433,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -502,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -612,8 +601,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -722,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -807,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -892,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -950,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1008,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1066,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1124,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1182,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1240,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1298,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1356,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1414,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1483,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1552,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1621,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1690,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1776,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1862,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1931,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2000,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2069,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2143,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2217,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2275,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2328,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2381,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2450,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2519,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2588,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2657,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2710,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2779,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2848,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2917,8 +2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2970,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3027,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3084,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3153,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 60cf69f63..0b77d2e55 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", @@ -19,6 +18,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 21e405151..fba222793 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -25,7 +24,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", @@ -61,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -127,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -167,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -183,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -214,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -244,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -300,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -341,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -466,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -579,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 0fbafc638..de9736e4e 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -21,7 +20,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", @@ -71,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -152,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -208,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -275,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -333,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -414,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index e4669f671..cde35e063 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -19,7 +18,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", @@ -69,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -109,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ad9f02019..45cd7f606 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -1,36 +1,7 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", @@ -57,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -88,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -114,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -145,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -176,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index e11ba2388..7cf1d02c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -44,7 +43,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", @@ -74,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -116,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -147,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -189,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -237,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -260,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -308,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -330,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 18b1f26b6..d701fff53 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -19,7 +18,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", @@ -57,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -99,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -140,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/hugr-py/src/hugr/std/int.py b/hugr-py/src/hugr/std/int.py index 4c1d0cdeb..27432a3d5 100644 --- a/hugr-py/src/hugr/std/int.py +++ b/hugr-py/src/hugr/std/int.py @@ -93,7 +93,7 @@ def type_args(self) -> list[tys.TypeArg]: def cached_signature(self) -> tys.FunctionType | None: row: list[tys.Type] = [int_t(self.width)] * 2 - return tys.FunctionType.endo(row, runtime_reqs=[INT_OPS_EXTENSION.name]) + return tys.FunctionType.endo(row) @classmethod def from_ext(cls, custom: ExtOp) -> Self | None: diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index fbaadf7d3..8411f19bf 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -188,21 +188,6 @@ def to_model(self) -> model.Term: return model.Apply("core.tuple", [item_types]) -@dataclass(frozen=True) -class ExtensionsParam(TypeParam): - """An extension set parameter.""" - - def _to_serial(self) -> stys.ExtensionsParam: - return stys.ExtensionsParam() - - def __str__(self) -> str: - return "Extensions" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. - return model.Apply("compat.ext_set_type") - - # ------------------------------------------ # --------------- TypeArg ------------------ # ------------------------------------------ @@ -280,23 +265,6 @@ def to_model(self) -> model.Term: return model.List([elem.to_model() for elem in self.elems]) -@dataclass(frozen=True) -class ExtensionsArg(TypeArg): - """Type argument for an :class:`ExtensionsParam`.""" - - extensions: ExtensionSet - - def _to_serial(self) -> stys.ExtensionsArg: - return stys.ExtensionsArg(es=self.extensions) - - def __str__(self) -> str: - return f"Extensions({comma_sep_str(self.extensions)})" - - def to_model(self) -> model.Term: - # Since extension sets will be deprecated, this is just a placeholder. - return model.Apply("compat.ext_set") - - @dataclass(frozen=True) class VariableArg(TypeArg): """A type argument variable.""" @@ -518,7 +486,6 @@ class FunctionType(Type): input: TypeRow output: TypeRow - runtime_reqs: ExtensionSet = field(default_factory=ExtensionSet) def type_bound(self) -> TypeBound: return TypeBound.Copyable @@ -527,7 +494,6 @@ def _to_serial(self) -> stys.FunctionType: return stys.FunctionType( input=ser_it(self.input), output=ser_it(self.output), - runtime_reqs=self.runtime_reqs, ) @classmethod @@ -541,16 +507,14 @@ def empty(cls) -> FunctionType: return cls(input=[], output=[]) @classmethod - def endo( - cls, tys: TypeRow, runtime_reqs: ExtensionSet | None = None - ) -> FunctionType: + def endo(cls, tys: TypeRow) -> FunctionType: """Function type with the same input and output types. Example: >>> FunctionType.endo([Qubit]) FunctionType([Qubit], [Qubit]) """ - return cls(input=tys, output=tys, runtime_reqs=runtime_reqs or ExtensionSet()) + return cls(input=tys, output=tys) def flip(self) -> FunctionType: """Return a new function type with input and output types swapped. @@ -569,17 +533,8 @@ def resolve(self, registry: ext.ExtensionRegistry) -> FunctionType: return FunctionType( input=[ty.resolve(registry) for ty in self.input], output=[ty.resolve(registry) for ty in self.output], - runtime_reqs=self.runtime_reqs, ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> FunctionType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - exts = set(self.runtime_reqs) - exts = exts.union(runtime_reqs) - return FunctionType(self.input, self.output, [*exts]) - def __str__(self) -> str: return f"{comma_sep_str(self.input)} -> {comma_sep_str(self.output)}" @@ -614,15 +569,6 @@ def resolve(self, registry: ext.ExtensionRegistry) -> PolyFuncType: body=self.body.resolve(registry), ) - def with_runtime_reqs(self, runtime_reqs: ExtensionSet) -> PolyFuncType: - """Adds a list of extension requirements to the function type, and - returns the new signature. - """ - return PolyFuncType( - params=self.params, - body=self.body.with_runtime_reqs(runtime_reqs), - ) - def __str__(self) -> str: return f"∀ {comma_sep_str(self.params)}. {self.body!s}" diff --git a/hugr-py/tests/serialization/test_extension.py b/hugr-py/tests/serialization/test_extension.py index cf595319a..7f1ea28bf 100644 --- a/hugr-py/tests/serialization/test_extension.py +++ b/hugr-py/tests/serialization/test_extension.py @@ -25,7 +25,6 @@ { "version": "0.1.0", "name": "ext", - "runtime_reqs": [], "types": { "foo": { "extension": "ext", @@ -64,8 +63,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "lower_funcs": [] @@ -99,7 +97,6 @@ def test_extension(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={"foo": type_def}, values={}, operations={"New": op_def}, @@ -121,7 +118,6 @@ def test_package(): ext = Extension( version=SemanticVersion(0, 1, 0), name="ext", - runtime_reqs=set(), types={}, values={}, operations={}, diff --git a/hugr-py/tests/test_custom.py b/hugr-py/tests/test_custom.py index 3018bf863..48f57de7a 100644 --- a/hugr-py/tests/test_custom.py +++ b/hugr-py/tests/test_custom.py @@ -37,7 +37,7 @@ def type_args(self) -> list[tys.TypeArg]: return [tys.StringArg(self.tag)] def cached_signature(self) -> tys.FunctionType | None: - return tys.FunctionType.endo([], runtime_reqs=[STRINGLY_EXT.name]) + return tys.FunctionType.endo([]) @classmethod def from_ext(cls, custom: ops.ExtOp) -> "StringlyOp": diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py index e2c6d7d51..33bf55561 100644 --- a/hugr-py/tests/test_tys.py +++ b/hugr-py/tests/test_tys.py @@ -14,8 +14,6 @@ BoundedNatArg, BoundedNatParam, Either, - ExtensionsArg, - ExtensionsParam, ExtType, FunctionType, ListParam, @@ -95,7 +93,6 @@ def test_tys_sum_str(ty: Type, string: str, repr_str: str): "(Any, Nat(3))", ), (ListParam(StringParam()), "[String]"), - (ExtensionsParam(), "Extensions"), ], ) def test_params_str(param: TypeParam, string: str): @@ -113,7 +110,6 @@ def test_params_str(param: TypeParam, string: str): "(Type(Qubit), 3)", ), (VariableArg(2, StringParam()), "$2"), - (ExtensionsArg(["A", "B"]), "Extensions(A, B)"), ], ) def test_args_str(arg: TypeArg, string: str): diff --git a/hugr/Cargo.toml b/hugr/Cargo.toml index c0439a960..3385df9f1 100644 --- a/hugr/Cargo.toml +++ b/hugr/Cargo.toml @@ -24,15 +24,13 @@ path = "src/lib.rs" [features] default = ["zstd"] -extension_inference = ["hugr-core/extension_inference"] declarative = ["hugr-core/declarative"] -model_unstable = ["hugr-core/model_unstable", "hugr-model"] llvm = ["hugr-llvm/llvm14-0"] llvm-test = ["hugr-llvm/llvm14-0", "hugr-llvm/test-utils"] zstd = ["hugr-core/zstd"] [dependencies] -hugr-model = { path = "../hugr-model", optional = true, version = "0.19.0" } +hugr-model = { path = "../hugr-model", version = "0.19.0" } hugr-core = { path = "../hugr-core", version = "0.15.3" } hugr-passes = { path = "../hugr-passes", version = "0.15.3" } hugr-llvm = { path = "../hugr-llvm", version = "0.15.3", optional = true } diff --git a/hugr/README.md b/hugr/README.md index 6ecfc405b..83a2cc501 100644 --- a/hugr/README.md +++ b/hugr/README.md @@ -1,7 +1,6 @@ ![](/hugr/assets/hugr_logo.svg) -hugr -=============== +# hugr [![build_status][]](https://github.com/CQCL/hugr/actions) [![crates][]](https://crates.io/crates/hugr) @@ -29,10 +28,6 @@ Please read the [API documentation here][]. ## Experimental Features -- `extension_inference`: - Experimental feature which allows automatic inference of which extra extensions - are required at runtime by a HUGR when validating it. - Not enabled by default. - `declarative`: Experimental support for declaring extensions in YAML files, support is limited. @@ -51,7 +46,7 @@ This project is licensed under Apache License, Version 2.0 ([LICENSE][] or http: [API documentation here]: https://docs.rs/hugr/ [build_status]: https://github.com/CQCL/hugr/actions/workflows/ci-rs.yml/badge.svg?branch=main - [msrv]: https://img.shields.io/badge/rust-1.75.0%2B-blue.svg + [msrv]: https://img.shields.io/crates/msrv/hugr [crates]: https://img.shields.io/crates/v/hugr [codecov]: https://img.shields.io/codecov/c/gh/CQCL/hugr?logo=codecov [LICENSE]: https://github.com/CQCL/hugr/blob/main/LICENCE diff --git a/hugr/benches/benchmarks/hugr.rs b/hugr/benches/benchmarks/hugr.rs index 49d73d58e..3635c8d09 100644 --- a/hugr/benches/benchmarks/hugr.rs +++ b/hugr/benches/benchmarks/hugr.rs @@ -24,10 +24,8 @@ impl Serializer for JsonSer { } } -#[cfg(feature = "model_unstable")] struct CapnpSer; -#[cfg(feature = "model_unstable")] impl Serializer for CapnpSer { fn serialize(&self, hugr: &Hugr) -> Vec { let bump = bumpalo::Bump::new(); @@ -90,20 +88,17 @@ fn bench_serialization(c: &mut Criterion) { } group.finish(); - #[cfg(feature = "model_unstable")] - { - let mut group = c.benchmark_group("circuit_roundtrip/capnp"); - group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - for size in [0, 1, 10, 100, 1000].iter() { - group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { - let h = circuit(size).0; - b.iter(|| { - black_box(roundtrip(&h, CapnpSer)); - }); + let mut group = c.benchmark_group("circuit_roundtrip/capnp"); + group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); + for size in [0, 1, 10, 100, 1000].iter() { + group.bench_with_input(BenchmarkId::from_parameter(size), size, |b, &size| { + let h = circuit(size).0; + b.iter(|| { + black_box(roundtrip(&h, CapnpSer)); }); - } - group.finish(); + }); } + group.finish(); } criterion_group! { diff --git a/hugr/benches/benchmarks/hugr/examples.rs b/hugr/benches/benchmarks/hugr/examples.rs index 0ece1eefb..2b7676439 100644 --- a/hugr/benches/benchmarks/hugr/examples.rs +++ b/hugr/benches/benchmarks/hugr/examples.rs @@ -7,9 +7,8 @@ use hugr::builder::{ HugrBuilder, ModuleBuilder, }; use hugr::extension::prelude::{bool_t, qb_t, usize_t}; -use hugr::extension::ExtensionSet; use hugr::ops::OpName; -use hugr::std_extensions::arithmetic::float_types::{self, float64_type, ConstF64}; +use hugr::std_extensions::arithmetic::float_types::{float64_type, ConstF64}; use hugr::types::Signature; use hugr::{type_row, CircuitUnit, Extension, Hugr, Node}; use lazy_static::lazy_static; @@ -97,11 +96,7 @@ pub fn circuit(layers: usize) -> (Hugr, Vec) { let h_gate = QUANTUM_EXT.instantiate_extension_op("H", []).unwrap(); let cx_gate = QUANTUM_EXT.instantiate_extension_op("CX", []).unwrap(); let rz = QUANTUM_EXT.instantiate_extension_op("Rz", []).unwrap(); - let signature = - Signature::new_endo(vec![qb_t(), qb_t()]).with_extension_delta(ExtensionSet::from_iter([ - QUANTUM_EXT.name().clone(), - float_types::EXTENSION_ID, - ])); + let signature = Signature::new_endo(vec![qb_t(), qb_t()]); let mut module_builder = ModuleBuilder::new(); let mut f_build = module_builder.define_function("main", signature).unwrap(); diff --git a/hugr/src/hugr.rs b/hugr/src/hugr.rs index 88c8c8df0..0bd8f64ff 100644 --- a/hugr/src/hugr.rs +++ b/hugr/src/hugr.rs @@ -2,7 +2,7 @@ // Exports everything except the `internal` module. pub use hugr_core::hugr::{ - hugrmut, rewrite, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, - InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Rewrite, RootTagged, + hugrmut, patch, serialize, validate, views, Hugr, HugrError, HugrView, IdentList, + InvalidIdentifier, LoadHugrError, NodeMetadata, NodeMetadataMap, OpType, Patch, SimpleReplacement, SimpleReplacementError, ValidationError, DEFAULT_OPTYPE, }; diff --git a/justfile b/justfile index 61173375b..d7e3f81f2 100644 --- a/justfile +++ b/justfile @@ -23,7 +23,7 @@ test-rust: HUGR_TEST_SCHEMA=1 cargo test \ --workspace \ --exclude 'hugr-py' \ - --features 'hugr/extension_inference hugr/declarative hugr/model_unstable hugr/llvm hugr/llvm-test hugr/zstd' + --features 'hugr/declarative hugr/llvm hugr/llvm-test hugr/zstd' # Run all python tests. test-python: uv run maturin develop --uv diff --git a/release-plz.toml b/release-plz.toml index 091ca3795..4bc9f7104 100644 --- a/release-plz.toml +++ b/release-plz.toml @@ -63,9 +63,7 @@ version_group = "hugr" [[package]] name = "hugr-model" release = true -# Use a separate version group while the dependency is `-unstable`, -# to avoid breaking releases of the main package. -version_group = "hugr-model" +version_group = "hugr" [[package]] name = "hugr-llvm" diff --git a/specification/hugr.md b/specification/hugr.md index 6204e0e4f..3bd22b8ef 100644 --- a/specification/hugr.md +++ b/specification/hugr.md @@ -891,71 +891,6 @@ See [Declarative Format](#declarative-format) for more examples. Note that since a row variable does not have kind Type, it cannot be used as the type of an edge. -### Extension Tracking - -The type of `Function` includes a set of [extensions](#extension-system) which are required to execute the graph. -Similarly, every dataflow node in the HUGR has a set of extensions required to execute the node (computed from its operation), -also known as the "delta". The delta of any node must be a subset of its parent's delta, -except for FuncDefn's: -* the delta of any child of a FuncDefn must be a subset of the extensions in the FuncDefn's *type* -* the FuncDefn itself has no delta (trivially a subset of any parent): this reflects that the extensions -are not needed to *know* the FuncDefn, only to *execute* it -(by a Call node, whose delta is taken from the called FuncDefn's *type*). - -Keeping track of the extension requirements like this allows extension designers -and third-party tooling to control how/where a module is run. - -Concretely, if a plugin writer adds an extension -*X*, then some function from -a plugin needs to provide a mechanism to convert the -*X* to some other extension -requirement before it can interface with other plugins which don't know -about *X*. - -A runtime could have access to means of -running different extensions. By the same mechanism, the runtime can reason -about where to run different parts of the graph by inspecting their -extension requirements. - -Special operations **lift** and **liftGraph** can add extension requirements: -* `lift>` is a node with input and output rows `R` and extension-delta `{E}` -* `liftGraph, E: ExtensionSet, O: List>` has one input -$ \vec{I}^{\underrightarrow{\;E\;}}\vec{O} $ and one output $ \vec{I}^{\underrightarrow{\;E \cup N\;}}\vec{O}$. -That is, given a graph, it adds extensions $N$ to the requirements of the graph. - -The latter is useful for higher-order operations such as conditionally selecting -one function or another, where the output must have a consistent type (including -the extension-requirements of the function). - -### Rewriting Extension Requirements - -Extension requirements help denote different runtime capabilities. -For example, a quantum computer may not be able to handle arithmetic -while running a circuit, so its use is tracked in the function type so that -rewrites can be performed which remove the arithmetic. - -Simple circuits may look something like: - -```haskell -Function[Quantum](Array(5, Q), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -A circuit built using a higher-order extension to manage control flow -could then look like: - -```haskell -Function[Quantum, HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))) -``` - -So the compiler would need to perform some graph transformation pass to turn the -graph-based control flow into a CFG node that a quantum computer could -run, which removes the `HigherOrder` extension requirement. - -```haskell -precompute :: Function[](Function[Quantum,HigherOrder](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit))), - Function[Quantum](Array(5, Qubit), (ms: Array(5, Qubit), results: Array(5, Bit)))) -``` - ## Extension System ### Goals and constraints diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9e7d8c40c..02889a3f4 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -517,13 +495,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -535,9 +506,7 @@ "required": [ "version", "name", - "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,64 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -772,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1571,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1684,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1705,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1783,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1806,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 6f436f969..558f64c57 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -517,13 +495,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -535,9 +506,7 @@ "required": [ "version", "name", - "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,64 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -772,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1571,13 +1475,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1684,7 +1581,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1705,9 +1601,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1783,7 +1676,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1806,9 +1698,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index bc067d40e..f534a3cbd 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -517,13 +495,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -535,9 +506,7 @@ "required": [ "version", "name", - "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,64 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, - "ExtensionsArg": { - "additionalProperties": true, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": true, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -772,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1570,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1762,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1783,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1861,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1884,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 47c9778d3..eb3fcff0f 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -277,13 +277,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -428,13 +421,6 @@ }, "title": "Sum Rows", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -502,14 +488,6 @@ "title": "Name", "type": "string" }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array", - "uniqueItems": true - }, "types": { "additionalProperties": { "$ref": "#/$defs/TypeDef" @@ -517,13 +495,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -535,9 +506,7 @@ "required": [ "version", "name", - "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,64 +558,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, - "ExtensionsArg": { - "additionalProperties": false, - "properties": { - "tya": { - "const": "Extensions", - "default": "Extensions", - "title": "Tya", - "type": "string" - }, - "es": { - "items": { - "type": "string" - }, - "title": "Es", - "type": "array" - } - }, - "required": [ - "es" - ], - "title": "ExtensionsArg", - "type": "object" - }, - "ExtensionsParam": { - "additionalProperties": false, - "properties": { - "tp": { - "const": "Extensions", - "default": "Extensions", - "title": "Tp", - "type": "string" - } - }, - "title": "ExtensionsParam", - "type": "object" - }, "FixedHugr": { "properties": { "extensions": { @@ -772,13 +683,6 @@ }, "title": "Output", "type": "array" - }, - "runtime_reqs": { - "items": { - "type": "string" - }, - "title": "Runtime Reqs", - "type": "array" } }, "required": [ @@ -1570,13 +1474,6 @@ }, "title": "Rest", "type": "array" - }, - "extension_delta": { - "items": { - "type": "string" - }, - "title": "Extension Delta", - "type": "array" } }, "required": [ @@ -1762,7 +1659,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatArg", - "Extensions": "#/$defs/ExtensionsArg", "Sequence": "#/$defs/SequenceArg", "String": "#/$defs/StringArg", "Type": "#/$defs/TypeTypeArg", @@ -1783,9 +1679,6 @@ { "$ref": "#/$defs/SequenceArg" }, - { - "$ref": "#/$defs/ExtensionsArg" - }, { "$ref": "#/$defs/VariableArg" } @@ -1861,7 +1754,6 @@ "discriminator": { "mapping": { "BoundedNat": "#/$defs/BoundedNatParam", - "Extensions": "#/$defs/ExtensionsParam", "List": "#/$defs/ListParam", "String": "#/$defs/StringParam", "Tuple": "#/$defs/TupleParam", @@ -1884,9 +1776,6 @@ }, { "$ref": "#/$defs/TupleParam" - }, - { - "$ref": "#/$defs/ExtensionsParam" } ], "required": [ diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 9c0054354..3c2fb983c 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -1,12 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.conversions", - "runtime_reqs": [ - "arithmetic.float.types", - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", @@ -37,8 +32,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -72,8 +66,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -116,8 +109,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -160,8 +152,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -193,8 +184,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -224,8 +214,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -257,8 +246,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -301,8 +289,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -345,8 +332,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -376,8 +362,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -437,8 +422,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -498,8 +482,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 31ccaaa59..60180ec84 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -1,11 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.float", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", @@ -31,8 +27,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -68,8 +63,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -98,8 +92,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -135,8 +128,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -170,8 +162,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -200,8 +191,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -235,8 +225,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -270,8 +259,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -305,8 +293,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -340,8 +327,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -377,8 +363,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -414,8 +399,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -451,8 +435,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -486,8 +469,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -516,8 +498,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -553,8 +534,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -583,8 +563,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -620,8 +599,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -650,8 +628,7 @@ "args": [], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 56e35c50b..33db43f5b 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.float.types", - "runtime_reqs": [], "types": { "float64": { "extension": "arithmetic.float.types", @@ -14,6 +13,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 62d0a6663..e8e6fdca8 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -1,11 +1,7 @@ { "version": "0.1.0", "name": "arithmetic.int", - "runtime_reqs": [ - "arithmetic.int.types" - ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", @@ -54,8 +50,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -123,8 +118,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -192,8 +186,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -278,8 +271,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -364,8 +356,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -433,8 +424,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -502,8 +492,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -612,8 +601,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -722,8 +710,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -807,8 +794,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -892,8 +878,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -950,8 +935,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1008,8 +992,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1066,8 +1049,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1124,8 +1106,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1182,8 +1163,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1240,8 +1220,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1298,8 +1277,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1356,8 +1334,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1414,8 +1391,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1483,8 +1459,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1552,8 +1527,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1621,8 +1595,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1690,8 +1663,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1776,8 +1748,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1862,8 +1833,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -1931,8 +1901,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2000,8 +1969,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2069,8 +2037,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2143,8 +2110,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2217,8 +2183,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -2275,8 +2240,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2328,8 +2292,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2381,8 +2344,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2450,8 +2412,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2519,8 +2480,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2588,8 +2548,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2657,8 +2616,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2710,8 +2668,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2779,8 +2736,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2848,8 +2804,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2917,8 +2872,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -2970,8 +2924,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -3027,8 +2980,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3084,8 +3036,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": true @@ -3153,8 +3104,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 60cf69f63..0b77d2e55 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "arithmetic.int.types", - "runtime_reqs": [], "types": { "int": { "extension": "arithmetic.int.types", @@ -19,6 +18,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 21e405151..fba222793 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.array", - "runtime_reqs": [], "types": { "array": { "extension": "collections.array", @@ -25,7 +24,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", @@ -61,8 +59,7 @@ "bound": "A" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false @@ -127,8 +124,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -167,9 +163,6 @@ { "tp": "Type", "b": "A" - }, - { - "tp": "Extensions" } ], "body": { @@ -183,9 +176,6 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [ - "2" ] } ], @@ -214,8 +204,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -244,9 +233,6 @@ "tp": "Type", "b": "A" } - }, - { - "tp": "Extensions" } ], "body": { @@ -300,9 +286,6 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [ - "4" ] }, { @@ -341,8 +324,7 @@ "i": 3, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -466,8 +448,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -579,8 +560,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 0fbafc638..de9736e4e 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.list", - "runtime_reqs": [], "types": { "List": { "extension": "collections.list", @@ -21,7 +20,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", @@ -71,8 +69,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -152,8 +149,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -208,8 +204,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -275,8 +270,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -333,8 +327,7 @@ ], "bound": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -414,8 +407,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index e4669f671..cde35e063 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "collections.static_array", - "runtime_reqs": [], "types": { "static_array": { "extension": "collections.static_array", @@ -19,7 +18,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", @@ -69,8 +67,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -109,8 +106,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ad9f02019..45cd7f606 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -1,36 +1,7 @@ { "version": "0.1.0", "name": "logic", - "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", @@ -57,8 +28,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -88,8 +58,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -114,8 +83,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -145,8 +113,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -176,8 +143,7 @@ "s": "Unit", "size": 2 } - ], - "runtime_reqs": [] + ] } }, "binary": false diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index e11ba2388..7cf1d02c7 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -1,7 +1,6 @@ { "version": "0.2.0", "name": "prelude", - "runtime_reqs": [], "types": { "error": { "extension": "prelude", @@ -44,7 +43,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", @@ -74,8 +72,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -116,8 +113,7 @@ ] ] } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -147,8 +143,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -189,8 +184,7 @@ "i": 0, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -237,8 +231,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -260,8 +253,7 @@ { "t": "I" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -308,8 +300,7 @@ "i": 1, "b": "A" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -330,8 +321,7 @@ "bound": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 18b1f26b6..d701fff53 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -1,7 +1,6 @@ { "version": "0.1.0", "name": "ptr", - "runtime_reqs": [], "types": { "ptr": { "extension": "ptr", @@ -19,7 +18,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", @@ -57,8 +55,7 @@ ], "bound": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -99,8 +96,7 @@ "i": 0, "b": "C" } - ], - "runtime_reqs": [] + ] } }, "binary": false @@ -140,8 +136,7 @@ "b": "C" } ], - "output": [], - "runtime_reqs": [] + "output": [] } }, "binary": false diff --git a/uv.lock b/uv.lock index 130657231..4f7d6012a 100644 --- a/uv.lock +++ b/uv.lock @@ -277,7 +277,7 @@ wheels = [ [[package]] name = "hugr" -version = "0.11.4" +version = "0.11.5" source = { editable = "hugr-py" } dependencies = [ { name = "graphviz" },