diff --git a/.github/workflows/sqlx.yml b/.github/workflows/sqlx.yml index d583f33d64..229c248549 100644 --- a/.github/workflows/sqlx.yml +++ b/.github/workflows/sqlx.yml @@ -32,7 +32,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] steps: - uses: actions/checkout@v2 @@ -48,7 +49,7 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-check-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: @@ -56,21 +57,22 @@ jobs: args: > --manifest-path sqlx-core/Cargo.toml --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }} + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - uses: actions-rs/cargo@v1 with: command: check args: > --no-default-features - --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }},macros + --features offline,all-databases,all-types,migrate,runtime-${{ matrix.runtime }}-${{ matrix.tls }},macros test: name: Unit Test runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] steps: - uses: actions/checkout@v2 @@ -93,7 +95,7 @@ jobs: command: test args: > --manifest-path sqlx-core/Cargo.toml - --features offline,all-databases,all-types,runtime-${{ matrix.runtime }} + --features offline,all-databases,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} cli: name: CLI Binaries @@ -148,7 +150,8 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, 
tokio-rustls, actix-rustls] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -165,14 +168,14 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-sqlite-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-sqlite-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: test args: > --no-default-features - --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }} + --features any,macros,migrate,sqlite,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} -- --test-threads=1 env: @@ -183,8 +186,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - postgres: [12, 10, 9_6, 9_5] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + postgres: [13, 9_6] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -201,23 +205,24 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-postgres-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features postgres,all-types,runtime-${{ matrix.runtime }} + --features postgres,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - - run: docker-compose -f tests/docker-compose.yml run -d -p 5432:5432 postgres_${{ matrix.postgres }} - - run: sleep 10 + - run: | + docker-compose -f tests/docker-compose.yml run -d -p 5432:5432 --name postgres_${{ matrix.postgres }} postgres_${{ matrix.postgres }} + docker exec postgres_${{ matrix.postgres }} bash -c "until pg_isready; do sleep 1; done" - uses: actions-rs/cargo@v1 with: command: test args: > --no-default-features - --features any,postgres,macros,all-types,runtime-${{ matrix.runtime }} + 
--features any,postgres,macros,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx @@ -226,7 +231,7 @@ jobs: command: test args: > --no-default-features - --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,postgres,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: postgres://postgres:password@localhost:5432/sqlx?sslmode=verify-ca&sslrootcert=.%2Ftests%2Fcerts%2Fca.crt @@ -235,8 +240,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mysql: [8, 5_7, 5_6] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mysql: [8, 5_6] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -253,13 +259,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }} + --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mysql_${{ matrix.mysql }} - run: sleep 60 @@ -269,7 +275,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -278,8 +284,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mariadb: [10_5, 10_4, 10_3, 10_2, 10_1] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mariadb: [10_6, 10_2] + runtime: 
[async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -297,13 +304,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mysql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mysql,all-types,runtime-${{ matrix.runtime }} + --features mysql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 3306:3306 mariadb_${{ matrix.mariadb }} - run: sleep 30 @@ -313,7 +320,7 @@ jobs: command: test args: > --no-default-features - --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mysql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mysql://root:password@localhost:3306/sqlx @@ -322,8 +329,9 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - mssql: [2019] - runtime: [async-std-native-tls, tokio-native-tls, actix-native-tls, async-std-rustls, tokio-rustls, actix-rustls] + mssql: [2019, 2017] + runtime: [async-std, tokio, actix] + tls: [native-tls, rustls] needs: check steps: - uses: actions/checkout@v2 @@ -340,13 +348,13 @@ jobs: ~/.cargo/registry ~/.cargo/git target - key: ${{ runner.os }}-mssql-${{ matrix.runtime }}-${{ hashFiles('**/Cargo.lock') }} + key: ${{ runner.os }}-mssql-${{ matrix.runtime }}-${{ matrix.tls }}-${{ hashFiles('**/Cargo.lock') }} - uses: actions-rs/cargo@v1 with: command: build args: > - --features mssql,all-types,runtime-${{ matrix.runtime }} + --features mssql,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} - run: docker-compose -f tests/docker-compose.yml run -d -p 1433:1433 mssql_${{ matrix.mssql }} - run: sleep 80 # MSSQL takes a "bit" to startup @@ -356,6 +364,6 @@ jobs: command: test args: > --no-default-features - --features 
any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }} + --features any,mssql,macros,migrate,all-types,runtime-${{ matrix.runtime }}-${{ matrix.tls }} env: DATABASE_URL: mssql://sa:Password123!@localhost/sqlx diff --git a/CHANGELOG.md b/CHANGELOG.md index 1189e2d2bc..ca19f348ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,79 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## 0.5.9 - 2021-10-01 + +A hotfix release to address the issue of the `sqlx` crate itself still depending on older versions of `sqlx-core` and +`sqlx-macros`. + +No other changes from `0.5.8`. + +## 0.5.8 - 2021-10-01 (Yanked; use 0.5.9) + +[A total of 24 pull requests][0.5.8-prs] were merged this release cycle! Some highlights: + +* [[#1289]] Support the `immutable` option on SQLite connections [[@djmarcin]] +* [[#1295]] Support custom initial options for SQLite [[@ghassmo]] + * Allows specifying custom `PRAGMA`s and overriding those set by SQLx. +* [[#1345]] Initial support for Postgres `COPY FROM/TO`[[@montanalow], [@abonander]] +* [[#1439]] Handle multiple waiting results correctly in MySQL [[@eagletmt]] + +[#1289]: https://github.com/launchbadge/sqlx/pull/1289 +[#1295]: https://github.com/launchbadge/sqlx/pull/1295 +[#1345]: https://github.com/launchbadge/sqlx/pull/1345 +[#1439]: https://github.com/launchbadge/sqlx/pull/1439 +[0.5.8-prs]: https://github.com/launchbadge/sqlx/pulls?q=is%3Apr+is%3Amerged+merged%3A2021-08-21..2021-10-01 + +## 0.5.7 - 2021-08-20 + +* [[#1392]] use `resolve_path` when getting path for `include_str!()` [[@abonander]] + * Fixes a regression introduced by [[#1332]]. +* [[#1393]] avoid recursively spawning tasks in `PgListener::drop()` [[@abonander]] + * Fixes a panic that occurs when `PgListener` is dropped in `async fn main()`. 
+ +[#1392]: https://github.com/launchbadge/sqlx/pull/1392 +[#1393]: https://github.com/launchbadge/sqlx/pull/1393 + +## 0.5.6 - 2021-08-16 + +A large bugfix release, including but not limited to: + +* [[#1329]] Implement `MACADDR` type for Postgres [[@nomick]] +* [[#1363]] Fix `PortalSuspended` for array of composite types in Postgres [[@AtkinsChang]] +* [[#1320]] Reimplement `sqlx::Pool` internals using `futures-intrusive` [[@abonander]] + * This addresses a number of deadlocks/stalls on acquiring connections from the pool. +* [[#1332]] Macros: tell the compiler about external files/env vars to watch [[@abonander]] + * Includes `sqlx build-script` to create a `build.rs` to watch `migrations/` for changes. + * Nightly users can try `RUSTFLAGS=--cfg sqlx_macros_unstable` to tell the compiler + to watch `migrations/` for changes instead of using a build script. + * See the new section in the docs for `sqlx::migrate!()` for details. +* [[#1351]] Fix a few sources of segfaults/errors in SQLite driver [[@abonander]] + * Includes contributions from [[@link2ext]] and [[@madadam]]. +* [[#1323]] Keep track of column typing in SQLite EXPLAIN parsing [[@marshoepial]] + * This fixes errors in the macros when using `INSERT/UPDATE/DELETE ... RETURNING ...` in SQLite. + +[A total of 25 pull requests][0.5.6-prs] were merged this release cycle! 
+ +[#1329]: https://github.com/launchbadge/sqlx/pull/1329 +[#1363]: https://github.com/launchbadge/sqlx/pull/1363 +[#1320]: https://github.com/launchbadge/sqlx/pull/1320 +[#1332]: https://github.com/launchbadge/sqlx/pull/1332 +[#1351]: https://github.com/launchbadge/sqlx/pull/1351 +[#1323]: https://github.com/launchbadge/sqlx/pull/1323 +[0.5.6-prs]: https://github.com/launchbadge/sqlx/pulls?q=is%3Apr+is%3Amerged+merged%3A2021-05-24..2021-08-17 + +## 0.5.5 - 2021-05-24 + +- [[#1242]] Fix infinite loop at compile time when using query macros [[@toshokan]] + +[#1242]: https://github.com/launchbadge/sqlx/pull/1242 + +## 0.5.4 - 2021-05-22 + +- [[#1235]] Fix compilation with rustls from an eager update to webpki [[@ETCaton]] + +[#1235]: https://github.com/launchbadge/sqlx/pull/1235 + ## 0.5.3 - 2021-05-21 - [[#1211]] Even more tweaks and fixes to the Pool internals [[@abonander]] @@ -911,3 +984,14 @@ Fix docs.rs build by enabling a runtime feature in the docs.rs metadata in `Carg [@guylapid]: https://github.com/guylapid [@natproach]: https://github.com/NatPRoach [@feikesteenbergen]: https://github.com/feikesteenbergen +[@etcaton]: https://github.com/ETCaton +[@toshokan]: https://github.com/toshokan +[@nomick]: https://github.com/nomick +[@marshoepial]: https://github.com/marshoepial +[@link2ext]: https://github.com/link2ext +[@madadam]: https://github.com/madadam +[@AtkinsChang]: https://github.com/AtkinsChang +[@djmarcin]: https://github.com/djmarcin +[@ghassmo]: https://github.com/ghassmo +[@eagletmt]: https://github.com/eagletmt +[@montanalow]: https://github.com/montanalow \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 82aa0caeed..bf275b32f2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1,5 +1,7 @@ # This file is automatically @generated by Cargo. # It is not intended for manual editing. 
+version = 3 + [[package]] name = "actix-rt" version = "2.2.0" @@ -206,7 +208,7 @@ dependencies = [ "pin-project-lite", "pin-utils", "slab", - "wasm-bindgen-futures", + "wasm-bindgen-futures 0.4.24", ] [[package]] @@ -226,6 +228,17 @@ dependencies = [ "syn", ] +[[package]] +name = "async_io_stream" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d5ad740b7193a31e80950ab7fece57c38d426fcd23a729d9d7f4cf15bb63f94" +dependencies = [ + "futures 0.3.15", + "pharos", + "rustc_version 0.3.3", +] + [[package]] name = "atoi" version = "0.4.0" @@ -346,12 +359,6 @@ dependencies = [ "serde", ] -[[package]] -name = "build_const" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4ae4235e6dac0694637c763029ecea1a2ec9e4e06ec2729bd21ba4d9c863eb7" - [[package]] name = "bumpalo" version = "3.6.1" @@ -490,6 +497,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "console_error_panic_hook" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8d976903543e0c48546a91908f21588a680a8c8f984df9a5d69feccb2b2a211" +dependencies = [ + "cfg-if 0.1.10", + "wasm-bindgen", +] + [[package]] name = "const_fn" version = "0.4.8" @@ -523,13 +540,19 @@ dependencies = [ [[package]] name = "crc" -version = "1.8.1" +version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d663548de7f5cca343f1e0a48d14dcfb0e9eb4e079ec58883b7251539fa10aeb" +checksum = "10c2722795460108a7872e1cd933a85d6ec38abc4baecad51028f702da28889f" dependencies = [ - "build_const", + "crc-catalog", ] +[[package]] +name = "crc-catalog" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccaeedb56da03b09f598226e25e80088cb4cd25f316e6e4df7d695f0feeb1403" + [[package]] name = "criterion" version = "0.3.4" @@ -623,9 +646,9 @@ dependencies = [ [[package]] name = "crypto-mac" -version = "0.10.0" +version = "0.11.1" source 
= "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4857fd85a0c34b3c3297875b747c1e02e06b6a0ea32dd892d8192b9ce0813ea6" +checksum = "b1d1a86f49236c215f271d40892d5fc950490551400b02ef360692c29815c714" dependencies = [ "generic-array", "subtle", @@ -663,18 +686,6 @@ dependencies = [ "syn", ] -[[package]] -name = "dialoguer" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c9dd058f8b65922819fabb4a41e7d1964e56344042c26efbccd465202c23fa0c" -dependencies = [ - "console", - "lazy_static", - "tempfile", - "zeroize", -] - [[package]] name = "difference" version = "2.0.0" @@ -699,6 +710,16 @@ dependencies = [ "dirs-sys", ] +[[package]] +name = "dirs-next" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf36e65a80337bea855cd4ef9b8401ffce06a7baedf2e85ec467b1ac3f6e82b6" +dependencies = [ + "cfg-if 1.0.0", + "dirs-sys-next", +] + [[package]] name = "dirs-sys" version = "0.3.6" @@ -710,6 +731,17 @@ dependencies = [ "winapi", ] +[[package]] +name = "dirs-sys-next" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ebda144c4fe02d1f7ea1a7d9641b6fc6b580adcfa024ae48797ecdeb6825b4d" +dependencies = [ + "libc", + "redox_users", + "winapi", +] + [[package]] name = "discard" version = "1.0.4" @@ -839,6 +871,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fed34cd105917e91daa4da6b3728c47b068749d6a62c59811f06ed2ac71d9da7" +[[package]] +name = "futures" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" + [[package]] name = "futures" version = "0.3.15" @@ -864,12 +902,27 @@ dependencies = [ "futures-sink", ] +[[package]] +name = "futures-channel-preview" +version = "0.3.0-alpha.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"d5e5f4df964fa9c1c2f8bddeb5c3611631cacd93baf810fc8bb2fb4b495c263a" +dependencies = [ + "futures-core-preview", +] + [[package]] name = "futures-core" version = "0.3.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0402f765d8a89a26043b889b26ce3c4679d268fa6bb22cd7c6aad98340e179d1" +[[package]] +name = "futures-core-preview" +version = "0.3.0-alpha.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35b6263fb1ef523c3056565fa67b1d16f0a8604ff12b11b08c25f28a734c60a" + [[package]] name = "futures-executor" version = "0.3.15" @@ -881,6 +934,17 @@ dependencies = [ "futures-util", ] +[[package]] +name = "futures-intrusive" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62007592ac46aa7c2b6416f7deb9a8a8f63a01e0f1d6e1787d5630170db2b63e" +dependencies = [ + "futures-core", + "lock_api", + "parking_lot", +] + [[package]] name = "futures-io" version = "0.3.15" @@ -948,6 +1012,17 @@ dependencies = [ "slab", ] +[[package]] +name = "futures-util-preview" +version = "0.3.0-alpha.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce968633c17e5f97936bd2797b6e38fb56cf16a7422319f7ec2e30d3c470e8d" +dependencies = [ + "futures-core-preview", + "pin-utils", + "slab", +] + [[package]] name = "generic-array" version = "0.14.4" @@ -965,15 +1040,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fcd999463524c52659517fe2cea98493cfe485d10565e7b0fb07dbba7ad2753" dependencies = [ "cfg-if 1.0.0", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] name = "git2" -version = "0.13.19" +version = "0.13.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17929de7239dea9f68aa14f94b2ab4974e7b24c1314275ffcc12a7758172fa18" +checksum = "d9831e983241f8c5591ed53f17d874833e2fa82cac2625f3888c50cbfe136cba" dependencies = [ "bitflags", "libc", @@ -1007,12 +1084,6 @@ version = "1.7.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "62aca2aba2d62b4a7f5b33f3712cb1b0692779a56fb510499d5c0aa594daeaf3" -[[package]] -name = "hashbrown" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7afe4a420e3fe79967a00898cc1f4db7c8a49a9333a29f8a4bd76a253d5cd04" - [[package]] name = "hashbrown" version = "0.11.2" @@ -1028,7 +1099,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7249a3129cbc1ffccd74857f81464a323a152173cdb134e0fd81bc803b29facf" dependencies = [ - "hashbrown 0.11.2", + "hashbrown", ] [[package]] @@ -1057,9 +1128,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hmac" -version = "0.10.1" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c1441c6b1e930e2817404b5046f1f989899143a12bf92de603b69f4e0aee1e15" +checksum = "2a2a2320eb7ec0ebe8da8f744d7812d9fc4cb4d09344ac01898dbcb6a20ae69b" dependencies = [ "crypto-mac", "digest", @@ -1093,12 +1164,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.6.2" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "824845a0bf897a9042383849b02c1bc219c2383772efcd5c6f9766fa4b81aef3" +checksum = "bc633605454125dec4b66843673f01c7df2b89479b32e0ed634e43a91cff62a5" dependencies = [ "autocfg 1.0.1", - "hashbrown 0.9.1", + "hashbrown", ] [[package]] @@ -1165,7 +1236,7 @@ dependencies = [ "anyhow", "async-std", "dotenv", - "futures", + "futures 0.3.15", "paw", "serde", "serde_json", @@ -1212,9 +1283,9 @@ checksum = "18794a8ad5b29321f790b55d93dfba91e125cb1a9edbd4f8e3150acc771c1a5e" [[package]] name = "libgit2-sys" -version = "0.12.20+1.1.0" +version = "0.12.21+1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e2f09917e00b9ad194ae72072bb5ada2cca16d8171a43e91ddba2afbb02664b" +checksum = 
"86271bacd72b2b9e854c3dcfb82efd538f15f870e4c11af66900effb462f6825" dependencies = [ "cc", "libc", @@ -1230,9 +1301,9 @@ checksum = "c7d73b3f436185384286bd8098d17ec07c9a7d2388a6599f824d8502b529702a" [[package]] name = "libsqlite3-sys" -version = "0.22.2" +version = "0.23.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290b64917f8b0cb885d9de0f9959fe1f775d7fa12f1da2db9001c1c8ab60f89d" +checksum = "abd5850c449b40bacb498b2bbdfaff648b1b055630073ba8db499caf2d0ea9f2" dependencies = [ "cc", "pkg-config", @@ -1270,6 +1341,16 @@ dependencies = [ "value-bag", ] +[[package]] +name = "mac_address" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9bb26482176bddeea173ceaa2acec85146d20cdcc631eafaf9d605d3d4fc23" +dependencies = [ + "nix 0.19.1", + "winapi", +] + [[package]] name = "maplit" version = "1.0.2" @@ -1375,6 +1456,30 @@ dependencies = [ "tempfile", ] +[[package]] +name = "nix" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83450fe6a6142ddd95fb064b746083fc4ef1705fe81f64a64e1d4b39f54a1055" +dependencies = [ + "bitflags", + "cc", + "cfg-if 0.1.10", + "libc", +] + +[[package]] +name = "nix" +version = "0.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2ccba0cfe4fdf15982d1674c69b1fd80bad427d293849982668dfe454bd61f2" +dependencies = [ + "bitflags", + "cc", + "cfg-if 1.0.0", + "libc", +] + [[package]] name = "nom" version = "6.1.2" @@ -1641,6 +1746,16 @@ dependencies = [ "ucd-trie", ] +[[package]] +name = "pharos" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e5f09cc3d3cb227536487dfe7fb8c246eccc6c8a32d03daa30ba4cf3212917" +dependencies = [ + "futures 0.3.15", + "rustc_version 0.2.3", +] + [[package]] name = "pin-project-lite" version = "0.2.6" @@ -1773,13 +1888,22 @@ checksum = "bc881b2c22681370c6a780e47af9840ef841837bc98118431d4e1868bd0c1086" [[package]] 
name = "proc-macro2" -version = "1.0.27" +version = "1.0.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0d8caf72986c1a598726adc988bb5984792ef84f5ee5aa50209145ee8077038" +checksum = "5c7ed8b8c7b886ea3ed7dde405212185f423ab44682667c8c6dd14aa1d9f6612" dependencies = [ "unicode-xid", ] +[[package]] +name = "promptly" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b99cfb0289110d969dd21637cfbc922584329bc9e5037c5e576325c615658509" +dependencies = [ + "rustyline", +] + [[package]] name = "quick-error" version = "1.2.3" @@ -2010,6 +2134,25 @@ dependencies = [ "webpki", ] +[[package]] +name = "rustyline" +version = "6.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0d5e7b0219a3eadd5439498525d4765c59b7c993ef0c12244865cd2d988413" +dependencies = [ + "cfg-if 0.1.10", + "dirs-next", + "libc", + "log", + "memchr", + "nix 0.18.0", + "scopeguard", + "unicode-segmentation", + "unicode-width", + "utf8parse", + "winapi", +] + [[package]] name = "ryu" version = "1.0.5" @@ -2035,6 +2178,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "scoped-tls" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea6a9290e3c9cf0f18145ef7ffa62d68ee0bf5fcd651017e586dc7fd5da448c2" + [[package]] name = "scopeguard" version = "1.1.0" @@ -2107,6 +2256,12 @@ dependencies = [ "pest", ] +[[package]] +name = "send_wrapper" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "930c0acf610d3fdb5e2ab6213019aaa04e227ebe9547b0649ba599b16d788bd7" + [[package]] name = "serde" version = "1.0.126" @@ -2255,13 +2410,13 @@ dependencies = [ [[package]] name = "sqlx" -version = "0.5.3" +version = "0.5.9" dependencies = [ "anyhow", "async-std", "dotenv", "env_logger 0.8.3", - "futures", + "futures 0.3.15", "paste", "serde", "serde_json", @@ -2288,18 +2443,19 @@ dependencies = [ [[package]] name = "sqlx-cli" 
-version = "0.5.3" +version = "0.5.9" dependencies = [ "anyhow", "async-trait", "chrono", "clap 3.0.0-beta.2", + "clap_derive", "console", - "dialoguer", "dotenv", - "futures", + "futures 0.3.15", "glob", "openssl", + "promptly", "remove_dir_all 0.7.0", "serde", "serde_json", @@ -2310,7 +2466,7 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.5.3" +version = "0.5.9" dependencies = [ "ahash", "atoi", @@ -2332,17 +2488,21 @@ dependencies = [ "encoding_rs", "futures-channel", "futures-core", + "futures-intrusive", "futures-util", "generic-array", + "getrandom", "git2", "hashlink", "hex", "hmac", + "indexmap", "ipnetwork", "itoa", "libc", "libsqlite3-sys", "log", + "mac_address", "md-5", "memchr", "num-bigint 0.3.2", @@ -2378,7 +2538,7 @@ version = "0.1.0" dependencies = [ "anyhow", "async-std", - "futures", + "futures 0.3.15", "paw", "sqlx", "structopt", @@ -2389,7 +2549,7 @@ name = "sqlx-example-postgres-listen" version = "0.1.0" dependencies = [ "async-std", - "futures", + "futures 0.3.15", "sqlx", ] @@ -2401,7 +2561,7 @@ dependencies = [ "async-std", "async-trait", "dotenv", - "futures", + "futures 0.3.15", "mockall", "paw", "sqlx", @@ -2415,19 +2575,28 @@ dependencies = [ "anyhow", "async-std", "dotenv", - "futures", + "futures 0.3.15", "paw", "sqlx", "structopt", ] +[[package]] +name = "sqlx-example-postgres-transaction" +version = "0.1.0" +dependencies = [ + "async-std", + "futures", + "sqlx", +] + [[package]] name = "sqlx-example-sqlite-todos" version = "0.1.0" dependencies = [ "anyhow", "async-std", - "futures", + "futures 0.3.15", "paw", "sqlx", "structopt", @@ -2435,11 +2604,10 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.5.3" +version = "0.5.9" dependencies = [ "dotenv", "either", - "futures", "heck", "hex", "once_cell", @@ -2456,17 +2624,23 @@ dependencies = [ [[package]] name = "sqlx-rt" -version = "0.5.3" +version = "0.5.9" dependencies = [ "actix-rt", "async-native-tls", "async-rustls", "async-std", + 
"async_io_stream", + "futures-util", "native-tls", "once_cell", "tokio", "tokio-native-tls", "tokio-rustls", + "wasm-bindgen", + "wasm-bindgen-futures 0.3.27", + "web-sys", + "ws_stream_wasm", ] [[package]] @@ -2481,6 +2655,24 @@ dependencies = [ "tokio", ] +[[package]] +name = "sqlx-wasm-test" +version = "0.1.0" +dependencies = [ + "futures 0.3.15", + "instant", + "paste", + "serde", + "serde_json", + "sqlx", + "time 0.2.26", + "wasm-bindgen", + "wasm-bindgen-futures 0.3.27", + "wasm-bindgen-test", + "web-sys", + "ws_stream_wasm", +] + [[package]] name = "standback" version = "0.2.17" @@ -2600,9 +2792,9 @@ checksum = "1e81da0851ada1f3e9d4312c704aa4f8806f0f9d69faaf8df2f3464b4a9437c2" [[package]] name = "syn" -version = "1.0.72" +version = "1.0.74" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1e8cdbefb79a9a5a65e0db8b47b723ee907b7c7f8496c76a1770b5c310bab82" +checksum = "1873d832550d4588c3dbc20f01361ab00bfe741048f71e3fecf145a7cc18b29c" dependencies = [ "proc-macro2", "quote", @@ -2935,6 +3127,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8parse" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "936e4b492acfd135421d8dca4b1aa80a7bfc26e702ef3af710e0752684df5372" + [[package]] name = "uuid" version = "0.8.2" @@ -3017,6 +3215,22 @@ dependencies = [ "wasm-bindgen-shared", ] +[[package]] +name = "wasm-bindgen-futures" +version = "0.3.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83420b37346c311b9ed822af41ec2e82839bfe99867ec6c54e2da43b7538771c" +dependencies = [ + "cfg-if 0.1.10", + "futures 0.1.31", + "futures-channel-preview", + "futures-util-preview", + "js-sys", + "lazy_static", + "wasm-bindgen", + "web-sys", +] + [[package]] name = "wasm-bindgen-futures" version = "0.4.24" @@ -3058,6 +3272,30 @@ version = "0.2.74" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"d7cff876b8f18eed75a66cf49b65e7f967cb354a7aa16003fb55dbfd25b44b4f" +[[package]] +name = "wasm-bindgen-test" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cab416a9b970464c2882ed92d55b0c33046b08e0bdc9d59b3b718acd4e1bae8" +dependencies = [ + "console_error_panic_hook", + "js-sys", + "scoped-tls", + "wasm-bindgen", + "wasm-bindgen-futures 0.4.24", + "wasm-bindgen-test-macro", +] + +[[package]] +name = "wasm-bindgen-test-macro" +version = "0.3.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd4543fc6cf3541ef0d98bf720104cc6bd856d7eba449fd2aa365ef4fed0e782" +dependencies = [ + "proc-macro2", + "quote", +] + [[package]] name = "web-sys" version = "0.3.51" @@ -3137,6 +3375,24 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" +[[package]] +name = "ws_stream_wasm" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9528a8ec27f348ae922b166e86670ef8be9f0427642d1f8725289f07e5207cd7" +dependencies = [ + "async_io_stream", + "futures 0.3.15", + "js-sys", + "pharos", + "rustc_version 0.3.3", + "send_wrapper", + "thiserror", + "wasm-bindgen", + "wasm-bindgen-futures 0.4.24", + "web-sys", +] + [[package]] name = "wyz" version = "0.2.0" diff --git a/Cargo.toml b/Cargo.toml index 6234ea8d49..a27f8848e4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "sqlx-rt", "sqlx-macros", "sqlx-test", + "sqlx-wasm-test", "sqlx-cli", "sqlx-bench", "examples/mysql/todos", @@ -12,12 +13,13 @@ members = [ "examples/postgres/listen", "examples/postgres/todos", "examples/postgres/mockable-todos", + "examples/postgres/transaction", "examples/sqlite/todos", ] [package] name = "sqlx" -version = "0.5.3" +version = "0.5.9" license = "MIT OR Apache-2.0" readme = "README.md" repository = "https://github.com/launchbadge/sqlx" @@ -59,6 +61,7 @@ 
all-types = [ "time", "chrono", "ipnetwork", + "mac_address", "uuid", "bit-vec", "bstr", @@ -121,6 +124,7 @@ bigdecimal = ["sqlx-core/bigdecimal", "sqlx-macros/bigdecimal"] decimal = ["sqlx-core/decimal", "sqlx-macros/decimal"] chrono = ["sqlx-core/chrono", "sqlx-macros/chrono"] ipnetwork = ["sqlx-core/ipnetwork", "sqlx-macros/ipnetwork"] +mac_address = ["sqlx-core/mac_address", "sqlx-macros/mac_address"] uuid = ["sqlx-core/uuid", "sqlx-macros/uuid"] json = ["sqlx-core/json", "sqlx-macros/json"] time = ["sqlx-core/time", "sqlx-macros/time"] @@ -129,8 +133,8 @@ bstr = ["sqlx-core/bstr"] git2 = ["sqlx-core/git2"] [dependencies] -sqlx-core = { version = "0.5.3", path = "sqlx-core", default-features = false } -sqlx-macros = { version = "0.5.3", path = "sqlx-macros", default-features = false, optional = true } +sqlx-core = { version = "0.5.9", path = "sqlx-core", default-features = false } +sqlx-macros = { version = "0.5.9", path = "sqlx-macros", default-features = false, optional = true } [dev-dependencies] anyhow = "1.0.31" @@ -253,6 +257,26 @@ name = "postgres-derives" path = "tests/postgres/derives.rs" required-features = ["postgres", "macros"] +[[test]] +name = "pg-deletes-bench" +path = "tests/postgres/deletes_custom_bench.rs" +required-features = ["postgres"] + +[[test]] +name = "pg-updates-bench" +path = "tests/postgres/updates_custom_bench.rs" +required-features = ["postgres"] + +[[test]] +name = "pg-inserts-bench" +path = "tests/postgres/inserts_custom_bench.rs" +required-features = ["postgres"] + +[[test]] +name = "pg-selects-bench" +path = "tests/postgres/selects_custom_bench.rs" +required-features = ["postgres"] + # # Microsoft SQL Server (MSSQL) # diff --git a/FAQ.md b/FAQ.md new file mode 100644 index 0000000000..82ec4f4dde --- /dev/null +++ b/FAQ.md @@ -0,0 +1,200 @@ +SQLx Frequently Asked Questions +=============================== + +---------------------------------------------------------------- +### How can I do a `SELECT ... 
WHERE foo IN (...)` query? + + +In 0.6 SQLx will support binding arrays as a comma-separated list for every database, +but unfortunately there's no general solution for that currently in SQLx itself. +You would need to manually generate the query, at which point it +cannot be used with the macros. + +However, **in Postgres** you can work around this limitation by binding the arrays directly and using `= ANY()`: + +```rust +let db: PgPool = /* ... */; +let foo_ids: Vec = vec![/* ... */]; + +let foos = sqlx::query!( + "SELECT * FROM foo WHERE id = ANY($1)", + // a bug of the parameter typechecking code requires all array parameters to be slices + &foo_ids[..] +) + .fetch_all(&db) + .await?; +``` + +Even when SQLx gains generic placeholder expansion for arrays, this will still be the optimal way to do it for Postgres, +as comma-expansion means each possible length of the array generates a different query +(and represents a combinatorial explosion if more than one array is used). + +Note that you can use any operator that returns a boolean, but beware that `!= ANY($1)` is **not equivalent** to `NOT IN (...)` as it effectively works like this: + +`lhs != ANY(rhs) -> false OR lhs != rhs[0] OR lhs != rhs[1] OR ... lhs != rhs[length(rhs) - 1]` + +The equivalent of `NOT IN (...)` would be `!= ALL($1)`: + +`lhs != ALL(rhs) -> true AND lhs != rhs[0] AND lhs != rhs[1] AND ... lhs != rhs[length(rhs) - 1]` + +Note that `ANY` using any operator and passed an empty array will return `false`, thus the leading `false OR ...`. +Meanwhile, `ALL` with any operator and passed an empty array will return `true`, thus the leading `true AND ...`. + +See also: [Postgres Manual, Section 9.24: Row and Array Comparisons](https://www.postgresql.org/docs/current/functions-comparisons.html) + +----- +### How can I bind an array to a `VALUES()` clause? How can I do bulk inserts? + +Like the above, SQLx currently does not support this in the general case right now but will in 0.6. 
+ +However, **Postgres** also has a feature to save the day here! You can pass an array to `UNNEST()` and +it will treat it as a temporary table: + +```rust +let foo_texts: Vec = vec![/* ... */]; + +sqlx::query!( + // because `UNNEST()` is a generic function, Postgres needs the cast on the parameter here + // in order to know what type to expect there when preparing the query + "INSERT INTO foo(text_column) SELECT * FROM UNNEST($1::text[])", + &foo_texts[..] +) + .execute(&db) + .await?; +``` + +`UNNEST()` can also take more than one array, in which case it'll treat each array as a column in the temporary table: + +```rust +// this solution currently requires each column to be its own vector +// in 0.6 we're aiming to allow binding iterators directly as arrays +// so you can take a vector of structs and bind iterators mapping to each field +let foo_texts: Vec = vec![/* ... */]; +let foo_bools: Vec = vec![/* ... */]; +let foo_ints: Vec = vec![/* ... */]; + +sqlx::query!( + " + INSERT INTO foo(text_column, bool_column, int_column) + SELECT * FROM UNNEST($1::text[], $2::bool[], $3::int8[]]) + ", + &foo_texts[..], + &foo_bools[..], + &foo_ints[..] +) + .execute(&db) + .await?; +``` + +Again, even with comma-expanded lists in 0.6 this will likely still be the most performant way to run bulk inserts +with Postgres--at least until we get around to implementing an interface for `COPY FROM STDIN`, though +this solution with `UNNEST()` will still be more flexible as you can use it in queries that are more complex +than just inserting into a table. + +Note that if some vectors are shorter than others, `UNNEST` will fill the corresponding columns with `NULL`s +to match the longest vector. 
+ +For example, if `foo_texts` is length 5, `foo_bools` is length 4, `foo_ints` is length 3, the resulting table will +look like this: + +| Row # | `text_column` | `bool_column` | `int_column` | +| ----- | -------------- | -------------- | ------------- | +| 1 | `foo_texts[0]` | `foo_bools[0]` | `foo_ints[0]` | +| 2 | `foo_texts[1]` | `foo_bools[1]` | `foo_ints[1]` | +| 3 | `foo_texts[2]` | `foo_bools[2]` | `foo_ints[2]` | +| 4 | `foo_texts[3]` | `foo_bools[3]` | `NULL` | +| 5 | `foo_texts[4]` | `NULL` | `NULL` | + +See Also: +* [Postgres Manual, Section 7.2.1.4: Table Functions](https://www.postgresql.org/docs/current/queries-table-expressions.html#QUERIES-TABLEFUNCTIONS) +* [Postgres Manual, Section 9.19: Array Functions and Operators](https://www.postgresql.org/docs/current/functions-array.html) + +---- +### How do I compile with the macros without needing a database, e.g. in CI? + +The macros support an offline mode which saves data for existing queries to a JSON file, +so the macros can just read that file instead of talking to a database. + +See the following: + +* [the docs for `query!()`](https://docs.rs/sqlx/0.5.5/sqlx/macro.query.html#offline-mode-requires-the-offline-feature) +* [the README for `sqlx-cli`](sqlx-cli/README.md#enable-building-in-offline-mode-with-query) + +To keep `sqlx-data.json` up-to-date you need to run `cargo sqlx prepare` before every commit that +adds or changes a query; you can do this with a Git pre-commit hook: + +```shell +$ echo "cargo sqlx prepare > /dev/null 2>&1; git add sqlx-data.json > /dev/null" > .git/hooks/pre-commit +``` + +Note that this may make committing take some time as it'll cause your project to be recompiled, and +as an ergonomic choice it does _not_ block committing if `cargo sqlx prepare` fails. + +We're working on a way for the macros to save their data to the filesystem automatically which should be part of SQLx 0.6, +so your pre-commit hook would then just need to stage the changed files. 
+ +---- + +### How do the query macros work under the hood? + +The macros work by talking to the database at compile time. When a(n) SQL client asks to create a prepared statement +from a query string, the response from the server typically includes information about the following: + +* the number of bind parameters, and their expected types if the database is capable of inferring that +* the number, names and types of result columns, as well as the original table and columns names before aliasing + +In MySQL/MariaDB, we also get boolean flag signaling if a column is `NOT NULL`, however +in Postgres and SQLite, we need to do a bit more work to determine whether a column can be `NULL` or not. + +After preparing, the Postgres driver will first look up the result columns in their source table and check if they have +a `NOT NULL` constraint. Then, it will execute `EXPLAIN (VERBOSE, FORMAT JSON) ` to determine which columns +come from half-open joins (LEFT and RIGHT joins), which makes a normally `NOT NULL` column nullable. Since the +`EXPLAIN VERBOSE` format is not stable or completely documented, this inference isn't perfect. However, it does err on +the side of producing false-positives (marking a column nullable when it's `NOT NULL`) to avoid errors at runtime. + +If you do encounter false-positives please feel free to open an issue; make sure to include your query, any relevant +schema as well as the output of `EXPLAIN (VERBOSE, FORMAT JSON) ` which will make this easier to debug. + +The SQLite driver will pull the bytecode of the prepared statement and step through it to find any instructions +that produce a null value for any column in the output. + +--- +### Why can't SQLx just look at my database schema/migrations and parse the SQL itself? + +Take a moment and think of the effort that would be required to do that. 
+ +To implement this for a single database driver, SQLx would need to: + +* know how to parse SQL, and not just standard SQL but the specific dialect of that particular database +* know how to analyze and typecheck SQL queries in the context of the original schema +* if inferring schema from migrations it would need to simulate all the schema-changing effects of those migrations + +This is effectively reimplementing a good chunk of the database server's frontend, + +_and_ maintaining and ensuring correctness of that reimplementation, + +including bugs and idiosyncrasies, + +for the foreseeable future, + +for _every_ database we intend to support. + +Even Sisyphus would pity us. + +---- + +### Why does my project using sqlx query macros not build on docs.rs? + +Docs.rs doesn't have access to your database, so it needs to be provided a `sqlx-data.json` file and be instructed to set the `SQLX_OFFLINE` environment variable to true while compiling your project. Luckily for us, docs.rs creates a `DOCS_RS` environment variable that we can access in a custom build script to achieve this functionality. + +To do so, first, make sure that you have run `cargo sqlx prepare` to generate a `sqlx-data.json` file in your project. + +Next, create a file called `build.rs` in the root of your project directory (at the same level as `Cargo.toml`). Add the following code to it: +```rs +fn main() { + // When building in docs.rs, we want to set SQLX_OFFLINE mode to true + if std::env::var_os("DOCS_RS").is_some() { + println!("cargo:rustc-env=SQLX_OFFLINE=true"); + } +} +``` diff --git a/README.md b/README.md index df8a0e056d..fe89e37a1d 100644 --- a/README.md +++ b/README.md @@ -50,7 +50,13 @@
- Built with ❤️ by The LaunchBadge team + Built with ❤️ by The LaunchBadge team +
+ +
+ +
+
Have a question? Be sure to check the FAQ first!

@@ -59,20 +65,23 @@ SQLx is an async, pure Rust SQL crate featuring compile-time check - **Truly Asynchronous**. Built from the ground-up using async/await for maximum concurrency. -- **Type-safe SQL** (if you want it) without DSLs. Use the `query!()` macro to check your SQL and bind parameters at - compile time. (You can still use dynamic SQL queries if you like.) +- **Compile-time checked queries** (if you want). See [SQLx is not an ORM](#sqlx-is-not-an-orm). - **Database Agnostic**. Support for [PostgreSQL], [MySQL], [SQLite], and [MSSQL]. - **Pure Rust**. The Postgres and MySQL/MariaDB drivers are written in pure Rust using **zero** unsafe†† code. -* **Runtime Agnostic**. Works on different runtimes ([async-std](https://crates.io/crates/async-std) / [tokio](https://crates.io/crates/tokio) / [actix](https://crates.io/crates/actix-rt)) and TLS backends ([native-tls](https://crates.io/crates/native-tls), [rustls](https://crates.io/crates/rustls)). +- **Runtime Agnostic**. Works on different runtimes ([`async-std`] / [`tokio`] / [`actix`]) and TLS backends ([`native-tls`], [`rustls`]). -† The SQLite driver uses the libsqlite3 C library as SQLite is an embedded database (the only way -we could be pure Rust for SQLite is by porting _all_ of SQLite to Rust). + -†† SQLx uses `#![forbid(unsafe_code)]` unless the `sqlite` feature is enabled. As the SQLite driver interacts -with C, those interactions are `unsafe`. +† The SQLite driver uses the libsqlite3 C library as SQLite is an embedded database (the only way +we could be pure Rust for SQLite is by porting _all_ of SQLite to Rust). + +†† SQLx uses `#![forbid(unsafe_code)]` unless the `sqlite` feature is enabled. As the SQLite driver interacts +with C, those interactions are `unsafe`. 
+ + [postgresql]: http://postgresql.org/ [sqlite]: https://sqlite.org/ @@ -108,6 +117,8 @@ SQLx is compatible with the [`async-std`], [`tokio`] and [`actix`] runtimes; and [`async-std`]: https://github.com/async-rs/async-std [`tokio`]: https://github.com/tokio-rs/tokio [`actix`]: https://github.com/actix/actix-net +[`native-tls`]: https://crates.io/crates/native-tls +[`rustls`]: https://crates.io/crates/rustls ```toml # Cargo.toml @@ -118,11 +129,11 @@ sqlx = { version = "0.5", features = [ "runtime-tokio-rustls" ] } sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls" ] } ``` -The runtime and TLS backend not being separate feature sets to select is a workaround for a [Cargo issue](https://github.com/rust-lang/cargo/issues/3494). +The runtime and TLS backend not being separate feature sets to select is a workaround for a [Cargo issue](https://github.com/rust-lang/cargo/issues/3494). #### Cargo Feature Flags -- `runtime-async-std-native-tls` (on by default): Use the `async-std` runtime and `native-tls` TLS backend. +- `runtime-async-std-native-tls`: Use the `async-std` runtime and `native-tls` TLS backend. - `runtime-async-std-rustls`: Use the `async-std` runtime and `rustls` TLS backend. @@ -168,14 +179,49 @@ sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls" ] } - `tls`: Add support for TLS connections. +- `offline`: Enables building the macros in offline mode when a live database is not available (such as CI). + - Requires `sqlx-cli` installed to use. See [sqlx-cli/README.md][readme-offline]. + +[readme-offline]: sqlx-cli/README.md#enable-building-in-offline-mode-with-query + +## SQLx is not an ORM! + +SQLx supports **compile-time checked queries**. It does not, however, do this by providing a Rust +API or DSL (domain-specific language) for building queries. Instead, it provides macros that take +regular SQL as an input and ensure that it is valid for your database. 
The way this works is that +SQLx connects to your development DB at compile time to have the database itself verify (and return +some info on) your SQL queries. This has some potentially surprising implications: + +- Since SQLx never has to parse the SQL string itself, any syntax that the development DB accepts + can be used (including things added by database extensions) +- Due to the different amount of information databases let you retrieve about queries, the extent of + SQL verification you get from the query macros depends on the database + +**If you are looking for an (asynchronous) ORM,** you can check out [`ormx`], which is built on top +of SQLx. + +[`ormx`]: https://crates.io/crates/ormx + ## Usage +See the `examples/` folder for more in-depth usage. + ### Quickstart ```toml [dependencies] -sqlx = { version = "0.4.1", features = [ "postgres" ] } -async-std = { version = "1.6", features = [ "attributes" ] } +# PICK ONE: +# Async-std: +sqlx = { version = "0.5", features = [ "runtime-async-std-native-tls", "postgres" ] } +async-std = { version = "1", features = [ "attributes" ] } + +# Tokio: +sqlx = { version = "0.5", features = [ "runtime-tokio-native-tls" , "postgres" ] } +tokio = { version = "1", features = ["full"] } + +# Actix-web: +sqlx = { version = "0.5", features = [ "runtime-actix-native-tls" , "postgres" ] } +actix-web = "3" ``` ```rust @@ -185,6 +231,7 @@ use sqlx::postgres::PgPoolOptions; #[async_std::main] // or #[tokio::main] +// or #[actix_web::main] async fn main() -> Result<(), sqlx::Error> { // Create a connection pool // for MySQL, use MySqlPoolOptions::new() @@ -194,7 +241,7 @@ async fn main() -> Result<(), sqlx::Error> { .max_connections(5) .connect("postgres://postgres:password@localhost/test").await?; - // Make a simple query to return the given parameter + // Make a simple query to return the given parameter (use a question mark `?` instead of `$1` for MySQL) let row: (i64,) = sqlx::query_as("SELECT $1") .bind(150_i64) 
.fetch_one(&pool).await?; @@ -329,15 +376,17 @@ Differences from `query()`: queries against; the database does not have to contain any data but must be the same kind (MySQL, Postgres, etc.) and have the same schema as the database you will be connecting to at runtime. - For convenience, you can use a .env file to set DATABASE_URL so that you don't have to pass it every time: + For convenience, you can use [a `.env` file][dotenv] to set DATABASE_URL so that you don't have to pass it every time: ``` DATABASE_URL=mysql://localhost/my_database ``` +[dotenv]: https://github.com/dotenv-rs/dotenv#examples + The biggest downside to `query!()` is that the output type cannot be named (due to Rust not -officially supporting anonymous records). To address that, there is a `query_as!()` macro that is identical -except that you can name the output type. +officially supporting anonymous records). To address that, there is a `query_as!()` macro that is +mostly identical except that you can name the output type. ```rust // no traits are needed @@ -359,6 +408,11 @@ WHERE organization = ? // countries[0].count ``` +To avoid the need of having a development database around to compile the project even when no +modifications (to the database-accessing parts of the code) are done, you can enable "offline mode" +to cache the results of the SQL query analysis using the `sqlx` command-line tool. See +[sqlx-cli/README.md](./sqlx-cli/README.md#enable-building-in-offline-mode-with-query). + ## Safety This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in 100% Safe Rust. 
diff --git a/examples/postgres/transaction/Cargo.toml b/examples/postgres/transaction/Cargo.toml new file mode 100644 index 0000000000..a51d4d37f8 --- /dev/null +++ b/examples/postgres/transaction/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "sqlx-example-postgres-transaction" +version = "0.1.0" +edition = "2018" +workspace = "../../../" + +[dependencies] +async-std = { version = "1.8.0", features = [ "attributes", "unstable" ] } +sqlx = { path = "../../../", features = [ "postgres", "tls", "runtime-async-std-native-tls" ] } +futures = "0.3.1" diff --git a/examples/postgres/transaction/README.md b/examples/postgres/transaction/README.md new file mode 100644 index 0000000000..2cfc1907c3 --- /dev/null +++ b/examples/postgres/transaction/README.md @@ -0,0 +1,18 @@ +# Postgres Transaction Example + +A simple example demonstrating how to obtain and roll back a transaction with postgres. + +## Usage + +Declare the database URL. This example does not include any reading or writing of data. + +``` +export DATABASE_URL="postgres://postgres@localhost/postgres" +``` + +Run. 
+ +``` +cargo run +``` + diff --git a/examples/postgres/transaction/migrations/20200718111257_todos.sql b/examples/postgres/transaction/migrations/20200718111257_todos.sql new file mode 100644 index 0000000000..6599f8c10a --- /dev/null +++ b/examples/postgres/transaction/migrations/20200718111257_todos.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS todos +( + id BIGSERIAL PRIMARY KEY, + description TEXT NOT NULL, + done BOOLEAN NOT NULL DEFAULT FALSE +); diff --git a/examples/postgres/transaction/src/main.rs b/examples/postgres/transaction/src/main.rs new file mode 100644 index 0000000000..50539fe2ac --- /dev/null +++ b/examples/postgres/transaction/src/main.rs @@ -0,0 +1,37 @@ +use sqlx::query; + +#[async_std::main] +async fn main() -> Result<(), Box> { + let conn_str = + std::env::var("DATABASE_URL").expect("Env var DATABASE_URL is required for this example."); + let pool = sqlx::PgPool::connect(&conn_str).await?; + + let mut transaction = pool.begin().await?; + + let test_id = 1; + query!( + r#"INSERT INTO todos (id, description) + VALUES ( $1, $2 ) + "#, + test_id, + "test todo" + ) + .execute(&mut transaction) + .await?; + + // check that inserted todo can be fetched + let _ = query!(r#"SELECT FROM todos WHERE id = $1"#, test_id) + .fetch_one(&mut transaction) + .await?; + + transaction.rollback(); + + // check that inserted todo is now gone + let inserted_todo = query!(r#"SELECT FROM todos WHERE id = $1"#, test_id) + .fetch_one(&pool) + .await; + + assert!(inserted_todo.is_err()); + + Ok(()) +} diff --git a/prep-release.sh b/prep-release.sh new file mode 100755 index 0000000000..79c38330f1 --- /dev/null +++ b/prep-release.sh @@ -0,0 +1,16 @@ +#!/usr/bin/env sh +set -ex + +VERSION=$1 + +if [ -z "$VERSION" ] +then + echo "USAGE: ./prep-release.sh " + exit 1 +fi + +cargo set-version -p sqlx-rt "$VERSION" +cargo set-version -p sqlx-core "$VERSION" +cargo set-version -p sqlx-macros "$VERSION" +cargo set-version -p sqlx "$VERSION" +cargo set-version -p sqlx-cli 
"$VERSION" \ No newline at end of file diff --git a/sqlx-bench/Cargo.toml b/sqlx-bench/Cargo.toml index 0b288ca892..167311ac53 100644 --- a/sqlx-bench/Cargo.toml +++ b/sqlx-bench/Cargo.toml @@ -45,3 +45,8 @@ sqlx-rt = { version = "0.5", path = "../sqlx-rt", default-features = false } name = "pg_pool" harness = false required-features = ["postgres"] + +[[bench]] +name = "wasm_querying" +harness = false +required-features = ["postgres"] diff --git a/sqlx-bench/benches/wasm_querying.rs b/sqlx-bench/benches/wasm_querying.rs new file mode 100644 index 0000000000..0c6e89facd --- /dev/null +++ b/sqlx-bench/benches/wasm_querying.rs @@ -0,0 +1,26 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use sqlx::Row; +use sqlx::{postgres::PgRow, Connection}; +use sqlx::{Database, PgConnection, Postgres}; +use sqlx_rt::spawn; + +const URL: &str = "postgresql://paul:pass123@127.0.0.1:8080/jetasap_dev"; + +fn select() { + spawn(async { + let mut conn = ::Connection::connect(URL) + .await + .unwrap(); + + let airports = sqlx::query("select * from airports") + .fetch_all(&mut conn) + .await; + }); +} + +fn criterion_benchmark(c: &mut Criterion) { + c.bench_function("fib 20", |b| b.iter(|| select())); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/sqlx-cli/Cargo.toml b/sqlx-cli/Cargo.toml index 6b9f1549fd..a31483a872 100644 --- a/sqlx-cli/Cargo.toml +++ b/sqlx-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlx-cli" -version = "0.5.3" +version = "0.5.9" description = "Command-line utility for SQLx, the Rust SQL toolkit." 
edition = "2018" readme = "README.md" @@ -27,20 +27,24 @@ path = "src/bin/cargo-sqlx.rs" [dependencies] dotenv = "0.15" tokio = { version = "1.0.1", features = ["macros", "rt", "rt-multi-thread"] } -sqlx = { version = "0.5.3", path = "..", default-features = false, features = [ +sqlx = { version = "0.5.9", path = "..", default-features = false, features = [ "runtime-async-std-native-tls", "migrate", "any", "offline", ] } futures = "0.3" +# FIXME: we need to fix both of these versions until Clap 3.0 proper is released, then we can drop `clap_derive` +# https://github.com/launchbadge/sqlx/issues/1378 +# https://github.com/clap-rs/clap/issues/2705 clap = "=3.0.0-beta.2" +clap_derive = "=3.0.0-beta.2" chrono = "0.4" anyhow = "1.0" url = { version = "2.1.1", default-features = false } async-trait = "0.1.30" console = "0.14.1" -dialoguer = "0.8.0" +promptly = "0.3.0" serde_json = "1.0.53" serde = { version = "1.0.110", features = ["derive"] } glob = "0.3.0" diff --git a/sqlx-cli/README.md b/sqlx-cli/README.md index 0c47941c4e..2d64cf8d97 100644 --- a/sqlx-cli/README.md +++ b/sqlx-cli/README.md @@ -13,6 +13,9 @@ $ cargo install sqlx-cli # only for postgres $ cargo install sqlx-cli --no-default-features --features postgres + +# use vendored OpenSSL (build from source) +$ cargo install sqlx-cli --features openssl-vendored ``` ### Usage @@ -49,17 +52,53 @@ $ sqlx migrate run Compares the migration history of the running database against the `migrations/` folder and runs any scripts that are still pending. 
-#### Enable building in "offline" mode with `query!()` +#### Reverting Migrations + +If you would like to create _reversible_ migrations with corresponding "up" and "down" scripts, you use the `-r` flag when creating new migrations: + +```bash +$ sqlx migrate add -r +Creating migrations/20211001154420_.up.sql +Creating migrations/20211001154420_.down.sql +``` + +After that, you can run these as above: + +```bash +$ sqlx migrate run +Applied migrations/20211001154420 (32.517835ms) +``` + +And reverts work as well: + +```bash +$ sqlx migrate revert +Applied 20211001154420/revert +``` + +**Note**: attempting to mix "simple" migrations with reversible migrations with result in an error. + +```bash +$ sqlx migrate add +Creating migrations/20211001154420_.sql + +$ sqlx migrate add -r +error: cannot mix reversible migrations with simple migrations. All migrations should be reversible or simple migrations +``` + +#### Enable building in "offline mode" with `query!()` Note: must be run as `cargo sqlx`. ```bash cargo sqlx prepare ``` -Saves query data to `sqlx-data.json` in the current directory; check this file into version control -and an active database connection will no longer be needed to build your project. -Has no effect unless the `offline` feature of `sqlx` is enabled in your project. Omitting that feature is the most likely cause if you get a `sqlx-data.json` file that looks like this: +Saves query metadata to `sqlx-data.json` in the current directory; check this file into version +control and an active database connection will no longer be needed to build your project. + +Has no effect unless the `offline` feature of `sqlx` is enabled in your project. Omitting that +feature is the most likely cause if you get a `sqlx-data.json` file that looks like this: ```json { @@ -67,10 +106,12 @@ Has no effect unless the `offline` feature of `sqlx` is enabled in your project. 
} ``` ----- +--- + ```bash cargo sqlx prepare --check ``` + Exits with a nonzero exit status if the data in `sqlx-data.json` is out of date with the current database schema and queries in the project. Intended for use in Continuous Integration. @@ -79,3 +120,6 @@ database schema and queries in the project. Intended for use in Continuous Integ To make sure an accidentally-present `DATABASE_URL` environment variable or `.env` file does not result in `cargo build` (trying to) access the database, you can set the `SQLX_OFFLINE` environment variable to `true`. + +If you want to make this the default, just add it to your `.env` file. `cargo sqlx prepare` will +still do the right thing and connect to the database. diff --git a/sqlx-cli/src/bin/cargo-sqlx.rs b/sqlx-cli/src/bin/cargo-sqlx.rs index bfae83e995..a924af4244 100644 --- a/sqlx-cli/src/bin/cargo-sqlx.rs +++ b/sqlx-cli/src/bin/cargo-sqlx.rs @@ -1,5 +1,6 @@ use clap::{crate_version, AppSettings, FromArgMatches, IntoApp}; use console::style; +use dotenv::dotenv; use sqlx_cli::Opt; use std::{env, process}; @@ -9,6 +10,7 @@ async fn main() { // so we want to notch out that superfluous "sqlx" let args = env::args_os().skip(2); + dotenv().ok(); let matches = Opt::into_app() .version(crate_version!()) .bin_name("cargo sqlx") diff --git a/sqlx-cli/src/bin/sqlx.rs b/sqlx-cli/src/bin/sqlx.rs index 0d18278577..e413581bb9 100644 --- a/sqlx-cli/src/bin/sqlx.rs +++ b/sqlx-cli/src/bin/sqlx.rs @@ -1,9 +1,11 @@ use clap::{crate_version, FromArgMatches, IntoApp}; use console::style; +use dotenv::dotenv; use sqlx_cli::Opt; #[tokio::main] async fn main() { + dotenv().ok(); let matches = Opt::into_app().version(crate_version!()).get_matches(); // no special handling here diff --git a/sqlx-cli/src/database.rs b/sqlx-cli/src/database.rs index 6babb21a36..7521b1fb68 100644 --- a/sqlx-cli/src/database.rs +++ b/sqlx-cli/src/database.rs @@ -1,6 +1,6 @@ use crate::migrate; use console::style; -use dialoguer::Confirm; +use promptly::{prompt, 
ReadlineError}; use sqlx::any::Any; use sqlx::migrate::MigrateDatabase; @@ -13,16 +13,7 @@ pub async fn create(uri: &str) -> anyhow::Result<()> { } pub async fn drop(uri: &str, confirm: bool) -> anyhow::Result<()> { - if confirm - && !Confirm::new() - .with_prompt(format!( - "\nAre you sure you want to drop the database at {}?", - style(uri).cyan() - )) - .wait_for_newline(true) - .default(false) - .interact()? - { + if confirm && !ask_to_continue(uri) { return Ok(()); } @@ -42,3 +33,28 @@ pub async fn setup(migration_source: &str, uri: &str) -> anyhow::Result<()> { create(uri).await?; migrate::run(migration_source, uri, false, false).await } + +fn ask_to_continue(uri: &str) -> bool { + loop { + let r: Result = + prompt(format!("Drop database at {}? (y/n)", style(uri).cyan())); + match r { + Ok(response) => { + if response == "n" || response == "N" { + return false; + } else if response == "y" || response == "Y" { + return true; + } else { + println!( + "Response not recognized: {}\nPlease type 'y' or 'n' and press enter.", + response + ); + } + } + Err(e) => { + println!("{}", e); + return false; + } + } + } +} diff --git a/sqlx-cli/src/lib.rs b/sqlx-cli/src/lib.rs index 5dd4aeefc6..d02f4fa3b4 100644 --- a/sqlx-cli/src/lib.rs +++ b/sqlx-cli/src/lib.rs @@ -1,7 +1,6 @@ +use anyhow::Result; + use crate::opt::{Command, DatabaseCommand, MigrateCommand}; -use anyhow::anyhow; -use dotenv::dotenv; -use std::env; mod database; // mod migration; @@ -12,15 +11,7 @@ mod prepare; pub use crate::opt::Opt; -pub async fn run(opt: Opt) -> anyhow::Result<()> { - dotenv().ok(); - - let database_url = match opt.database_url { - Some(db_url) => db_url, - None => env::var("DATABASE_URL") - .map_err(|_| anyhow!("The DATABASE_URL environment variable must be set"))?, - }; - +pub async fn run(opt: Opt) -> Result<()> { match opt.command { Command::Migrate(migrate) => match migrate.command { MigrateCommand::Add { @@ -30,33 +21,47 @@ pub async fn run(opt: Opt) -> anyhow::Result<()> { 
MigrateCommand::Run { dry_run, ignore_missing, + database_url, } => migrate::run(&migrate.source, &database_url, dry_run, ignore_missing).await?, MigrateCommand::Revert { dry_run, ignore_missing, + database_url, } => migrate::revert(&migrate.source, &database_url, dry_run, ignore_missing).await?, - MigrateCommand::Info => migrate::info(&migrate.source, &database_url).await?, + MigrateCommand::Info { database_url } => { + migrate::info(&migrate.source, &database_url).await? + } + MigrateCommand::BuildScript { force } => migrate::build_script(&migrate.source, force)?, }, Command::Database(database) => match database.command { - DatabaseCommand::Create => database::create(&database_url).await?, - DatabaseCommand::Drop { yes } => database::drop(&database_url, !yes).await?, - DatabaseCommand::Reset { yes, source } => { - database::reset(&source, &database_url, !yes).await? + DatabaseCommand::Create { database_url } => database::create(&database_url).await?, + DatabaseCommand::Drop { yes, database_url } => { + database::drop(&database_url, !yes).await? 
} - DatabaseCommand::Setup { source } => database::setup(&source, &database_url).await?, + DatabaseCommand::Reset { + yes, + source, + database_url, + } => database::reset(&source, &database_url, !yes).await?, + DatabaseCommand::Setup { + source, + database_url, + } => database::setup(&source, &database_url).await?, }, Command::Prepare { check: false, merged, args, + database_url, } => prepare::run(&database_url, merged, args)?, Command::Prepare { check: true, merged, args, + database_url, } => prepare::check(&database_url, merged, args)?, }; diff --git a/sqlx-cli/src/migrate.rs b/sqlx-cli/src/migrate.rs index 20d61f1985..523cf83fa4 100644 --- a/sqlx-cli/src/migrate.rs +++ b/sqlx-cli/src/migrate.rs @@ -42,6 +42,11 @@ pub async fn add( ) -> anyhow::Result<()> { fs::create_dir_all(migration_source).context("Unable to create migrations directory")?; + // if the migrations directory is empty + let has_existing_migrations = fs::read_dir(migration_source) + .map(|mut dir| dir.next().is_some()) + .unwrap_or(false); + let migrator = Migrator::new(Path::new(migration_source)).await?; // This checks if all existing migrations are of the same type as the reverisble flag passed for migration in migrator.iter() { @@ -74,6 +79,31 @@ pub async fn add( )?; } + if !has_existing_migrations { + let quoted_source = if migration_source != "migrations" { + format!("{:?}", migration_source) + } else { + "".to_string() + }; + + print!( + r#" +Congratulations on creating your first migration! + +Did you know you can embed your migrations in your application binary? +On startup, after creating your database connection or pool, add: + +sqlx::migrate!({}).run(<&your_pool OR &mut your_connection>).await?; + +Note that the compiler won't pick up new migrations if no Rust source files have changed. +You can create a Cargo build script to work around this with `sqlx migrate build-script`. 
+ +See: https://docs.rs/sqlx/0.5/sqlx/macro.migrate.html +"#, + quoted_source + ); + } + Ok(()) } @@ -245,3 +275,30 @@ pub async fn revert( Ok(()) } + +pub fn build_script(migration_source: &str, force: bool) -> anyhow::Result<()> { + anyhow::ensure!( + Path::new("Cargo.toml").exists(), + "must be run in a Cargo project root" + ); + + anyhow::ensure!( + (force || !Path::new("build.rs").exists()), + "build.rs already exists; use --force to overwrite" + ); + + let contents = format!( + r#"// generated by `sqlx migrate build-script` +fn main() {{ + // trigger recompilation when a new migration is added + println!("cargo:rerun-if-changed={}"); +}}"#, + migration_source + ); + + fs::write("build.rs", contents)?; + + println!("Created `build.rs`; be sure to check it into version control!"); + + Ok(()) +} diff --git a/sqlx-cli/src/opt.rs b/sqlx-cli/src/opt.rs index 8d912668bf..20243a5e91 100644 --- a/sqlx-cli/src/opt.rs +++ b/sqlx-cli/src/opt.rs @@ -4,9 +4,6 @@ use clap::Clap; pub struct Opt { #[clap(subcommand)] pub command: Command, - - #[clap(short = 'D', long)] - pub database_url: Option, } #[derive(Clap, Debug)] @@ -36,6 +33,10 @@ pub enum Command { /// Arguments to be passed to `cargo rustc ...`. #[clap(last = true)] args: Vec, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, #[clap(alias = "mig")] @@ -52,7 +53,11 @@ pub struct DatabaseOpt { #[derive(Clap, Debug)] pub enum DatabaseCommand { /// Creates the database specified in your DATABASE_URL. - Create, + Create { + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, + }, /// Drops the database specified in your DATABASE_URL. Drop { @@ -60,6 +65,10 @@ pub enum DatabaseCommand { /// your database. 
#[clap(short)] yes: bool, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, /// Drops the database specified in your DATABASE_URL, re-creates it, and runs any pending migrations. @@ -72,6 +81,10 @@ pub enum DatabaseCommand { /// Path to folder containing migrations. #[clap(long, default_value = "migrations")] source: String, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, /// Creates the database specified in your DATABASE_URL and runs any pending migrations. @@ -79,6 +92,10 @@ pub enum DatabaseCommand { /// Path to folder containing migrations. #[clap(long, default_value = "migrations")] source: String, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, } @@ -115,6 +132,10 @@ pub enum MigrateCommand { /// Ignore applied migrations that missing in the resolved migrations #[clap(long)] ignore_missing: bool, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, /// Revert the latest migration with a down file. @@ -126,8 +147,25 @@ pub enum MigrateCommand { /// Ignore applied migrations that missing in the resolved migrations #[clap(long)] ignore_missing: bool, + + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, short = 'D', env)] + database_url: String, }, /// List all available migrations. - Info, + Info { + /// Location of the DB, by default will be read from the DATABASE_URL env var + #[clap(long, env)] + database_url: String, + }, + + /// Generate a `build.rs` to trigger recompilation when a new migration is added. + /// + /// Must be run in a Cargo project root. + BuildScript { + /// Overwrite the build script if it already exists. 
+ #[clap(long)] + force: bool, + }, } diff --git a/sqlx-core/Cargo.toml b/sqlx-core/Cargo.toml index 7d4ade7798..a6bd78f37d 100644 --- a/sqlx-core/Cargo.toml +++ b/sqlx-core/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlx-core" -version = "0.5.3" +version = "0.5.9" repository = "https://github.com/launchbadge/sqlx" description = "Core of SQLx, the rust SQL toolkit. Not intended to be used directly." license = "MIT OR Apache-2.0" @@ -54,6 +54,7 @@ all-types = [ "bigdecimal", "decimal", "ipnetwork", + "mac_address", "json", "uuid", "bit-vec", @@ -100,7 +101,7 @@ offline = ["serde", "either/serde"] [dependencies] ahash = "0.7.2" atoi = "0.4.0" -sqlx-rt = { path = "../sqlx-rt", version = "0.5.3" } +sqlx-rt = { path = "../sqlx-rt", version = "0.5.9"} base64 = { version = "0.13.0", default-features = false, optional = true, features = ["std"] } bigdecimal_ = { version = "0.2.0", optional = true, package = "bigdecimal" } rust_decimal = { version = "1.8.1", optional = true } @@ -109,7 +110,7 @@ bitflags = { version = "1.2.1", default-features = false } bytes = "1.0.0" byteorder = { version = "1.3.4", default-features = false, features = ["std"] } chrono = { version = "0.4.11", default-features = false, features = ["clock"], optional = true } -crc = { version = "1.8.1", optional = true } +crc = { version = "2.0.0", optional = true } crossbeam-queue = "0.3.1" crossbeam-channel = "0.5.0" crossbeam-utils = { version = "0.8.1", default-features = false } @@ -119,14 +120,16 @@ encoding_rs = { version = "0.8.23", optional = true } either = "1.5.3" futures-channel = { version = "0.3.5", default-features = false, features = ["sink", "alloc", "std"] } futures-core = { version = "0.3.5", default-features = false } -futures-util = { version = "0.3.5", features = ["sink"] } +futures-intrusive = "0.4.0" +futures-util = { version = "0.3.5", default-features = false, features = ["alloc", "sink"] } generic-array = { version = "0.14.4", default-features = false, optional = true } hex = 
"0.4.2" -hmac = { version = "0.10.1", default-features = false, optional = true } +hmac = { version = "0.11.0", default-features = false, optional = true } itoa = "0.4.5" ipnetwork = { version = "0.17.0", default-features = false, optional = true } +mac_address = { version = "1.1", default-features = false, optional = true } libc = "0.2.71" -libsqlite3-sys = { version = "0.22.0", optional = true, default-features = false, features = [ +libsqlite3-sys = { version = "0.23.1", optional = true, default-features = false, features = [ "pkg-config", "vcpkg", "bundled", @@ -158,5 +161,9 @@ webpki-roots = { version = "0.21.0", optional = true } whoami = "1.0.1" stringprep = "0.1.2" bstr = { version = "0.2.14", default-features = false, features = ["std"], optional = true } -git2 = { version = "0.13.12", default-features = false, optional = true } +git2 = { version = "0.13.20", default-features = false, optional = true } hashlink = "0.7.0" +indexmap = "1.6.2" + +[target.'cfg(target_arch = "wasm32")'.dependencies] +getrandom = { version = "0.2.2", features = ["js"] } diff --git a/sqlx-core/src/any/kind.rs b/sqlx-core/src/any/kind.rs index 8d5454ed45..b3278a9650 100644 --- a/sqlx-core/src/any/kind.rs +++ b/sqlx-core/src/any/kind.rs @@ -1,7 +1,7 @@ use crate::error::Error; use std::str::FromStr; -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] pub enum AnyKind { #[cfg(feature = "postgres")] Postgres, diff --git a/sqlx-core/src/any/migrate.rs b/sqlx-core/src/any/migrate.rs index 1825ff939b..04fa659b74 100644 --- a/sqlx-core/src/any/migrate.rs +++ b/sqlx-core/src/any/migrate.rs @@ -223,7 +223,10 @@ impl Migrate for AnyConnection { AnyConnectionKind::MySql(conn) => conn.revert(migration), #[cfg(feature = "mssql")] - AnyConnectionKind::Mssql(_conn) => unimplemented!(), + AnyConnectionKind::Mssql(_conn) => { + let _ = migration; + unimplemented!() + } } } } diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index a5f794820d..9dd7fea9f8 100644 --- 
a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -1,5 +1,7 @@ //! Generic database driver with the specific driver selected at runtime. +use crate::executor::Executor; + #[macro_use] mod decode; @@ -45,6 +47,10 @@ pub type AnyPool = crate::pool::Pool; pub type AnyPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Any>`][Executor]. +pub trait AnyExecutor<'c>: Executor<'c, Database = Any> {} +impl<'c, T: Executor<'c, Database = Any>> AnyExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(AnyArguments<'q>); impl_executor_for_pool_connection!(Any, AnyConnection, AnyRow); diff --git a/sqlx-core/src/arguments.rs b/sqlx-core/src/arguments.rs index 8867176264..2e76433a33 100644 --- a/sqlx-core/src/arguments.rs +++ b/sqlx-core/src/arguments.rs @@ -5,6 +5,7 @@ use crate::encode::Encode; use crate::types::Type; /// A tuple of arguments to be sent to the database. +#[cfg(not(target_arch = "wasm32"))] pub trait Arguments<'q>: Send + Sized + Default { type Database: Database; @@ -18,10 +19,30 @@ pub trait Arguments<'q>: Send + Sized + Default { T: 'q + Send + Encode<'q, Self::Database> + Type; } +#[cfg(target_arch = "wasm32")] +pub trait Arguments<'q>: Sized + Default { + type Database: Database; + + /// Reserves the capacity for at least `additional` more values (of `size` total bytes) to + /// be added to the arguments without a reallocation. + fn reserve(&mut self, additional: usize, size: usize); + + /// Add the value to the end of the arguments. 
+ fn add(&mut self, value: T) + where + T: 'q + Encode<'q, Self::Database> + Type; +} + +#[cfg(not(target_arch = "wasm32"))] pub trait IntoArguments<'q, DB: HasArguments<'q>>: Sized + Send { fn into_arguments(self) -> >::Arguments; } +#[cfg(target_arch = "wasm32")] +pub trait IntoArguments<'q, DB: HasArguments<'q>>: Sized { + fn into_arguments(self) -> >::Arguments; +} + // NOTE: required due to lack of lazy normalization #[allow(unused_macros)] macro_rules! impl_into_arguments_for_arguments { diff --git a/sqlx-core/src/connection.rs b/sqlx-core/src/connection.rs index 402a198c9e..07dc191e1b 100644 --- a/sqlx-core/src/connection.rs +++ b/sqlx-core/src/connection.rs @@ -1,13 +1,20 @@ use crate::database::{Database, HasStatementCache}; use crate::error::Error; use crate::transaction::Transaction; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; + +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + use log::LevelFilter; use std::fmt::Debug; use std::str::FromStr; use std::time::Duration; /// Represents a single database connection. +#[cfg(not(target_arch = "wasm32"))] pub trait Connection: Send { type Database: Database; @@ -125,6 +132,123 @@ pub trait Connection: Send { } } +#[cfg(target_arch = "wasm32")] +pub trait Connection { + type Database: Database; + + type Options: ConnectOptions; + + /// Explicitly close this database connection. + /// + /// This method is **not required** for safe and consistent operation. However, it is + /// recommended to call it instead of letting a connection `drop` as the database backend + /// will be faster at cleaning up resources. + fn close(self) -> BoxFuture<'static, Result<(), Error>>; + + /// Checks if a connection to the database is still valid. + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>>; + + /// Begin a new transaction or establish a savepoint within the active transaction. 
+ /// + /// Returns a [`Transaction`] for controlling and tracking the new transaction. + fn begin(&mut self) -> BoxFuture<'_, Result, Error>> + where + Self: Sized; + + /// Execute the function inside a transaction. + /// + /// If the function returns an error, the transaction will be rolled back. If it does not + /// return an error, the transaction will be committed. + /// + /// # Example + /// + /// ```rust + /// use sqlx_core::connection::Connection; + /// use sqlx_core::error::Error; + /// use sqlx_core::executor::Executor; + /// use sqlx_core::postgres::{PgConnection, PgRow}; + /// use sqlx_core::query::query; + /// + /// # pub async fn _f(conn: &mut PgConnection) -> Result, Error> { + /// conn.transaction(|conn|Box::pin(async move { + /// query("select * from ..").fetch_all(conn).await + /// })).await + /// # } + /// ``` + fn transaction(&mut self, callback: F) -> BoxFuture<'_, Result> + where + for<'c> F: FnOnce(&'c mut Transaction<'_, Self::Database>) -> BoxFuture<'c, Result> + + 'static + + Send + + Sync, + Self: Sized, + R: Send, + E: From + Send, + { + Box::pin(async move { + let mut transaction = self.begin().await?; + let ret = callback(&mut transaction).await; + + match ret { + Ok(ret) => { + transaction.commit().await?; + + Ok(ret) + } + Err(err) => { + transaction.rollback().await?; + + Err(err) + } + } + }) + } + + /// The number of statements currently cached in the connection. + fn cached_statements_size(&self) -> usize + where + Self::Database: HasStatementCache, + { + 0 + } + + /// Removes all statements from the cache, closing them on the server if + /// needed. + fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> + where + Self::Database: HasStatementCache, + { + Box::pin(async move { Ok(()) }) + } + + #[doc(hidden)] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>>; + + #[doc(hidden)] + fn should_flush(&self) -> bool; + + /// Establish a new database connection. 
+ /// + /// A value of [`Options`][Self::Options] is parsed from the provided connection string. This parsing + /// is database-specific. + #[inline] + fn connect(url: &str) -> BoxFuture<'static, Result> + where + Self: Sized, + { + let options = url.parse(); + Box::pin(async move { Ok(Self::connect_with(&options?).await?) }) + } + + /// Establish a new database connection with the provided options. + fn connect_with(options: &Self::Options) -> BoxFuture<'_, Result> + where + Self: Sized, + { + options.connect() + } +} + #[derive(Clone, Debug)] pub(crate) struct LogSettings { pub(crate) statements_level: LevelFilter, @@ -152,6 +276,7 @@ impl LogSettings { } } +#[cfg(not(target_arch = "wasm32"))] pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug { type Connection: Connection + ?Sized; @@ -173,3 +298,26 @@ pub trait ConnectOptions: 'static + Send + Sync + FromStr + Debug { .log_slow_statements(LevelFilter::Off, Duration::default()) } } + +#[cfg(target_arch = "wasm32")] +pub trait ConnectOptions: 'static + FromStr + Debug { + type Connection: Connection + ?Sized; + + /// Establish a new database connection with the options specified by `self`. + fn connect(&self) -> BoxFuture<'_, Result> + where + Self::Connection: Sized; + + /// Log executed statements with the specified `level` + fn log_statements(&mut self, level: LevelFilter) -> &mut Self; + + /// Log executed statements with a duration above the specified `duration` + /// at the specified `level`. + fn log_slow_statements(&mut self, level: LevelFilter, duration: Duration) -> &mut Self; + + /// Entirely disables statement logging (both slow and regular). 
+ fn disable_statement_logging(&mut self) -> &mut Self { + self.log_statements(LevelFilter::Off) + .log_slow_statements(LevelFilter::Off, Duration::default()) + } +} diff --git a/sqlx-core/src/database.rs b/sqlx-core/src/database.rs index e1788597fb..ea5fadf23d 100644 --- a/sqlx-core/src/database.rs +++ b/sqlx-core/src/database.rs @@ -68,6 +68,7 @@ use crate::value::{Value, ValueRef}; /// /// This trait encapsulates a complete set of traits that implement a driver for a /// specific database (e.g., MySQL, PostgreSQL). +#[cfg(not(target_arch = "wasm32"))] pub trait Database: 'static + Sized @@ -100,6 +101,38 @@ pub trait Database: type Value: Value + 'static; } +#[cfg(target_arch = "wasm32")] +pub trait Database: + 'static + + Sized + + Debug + + for<'r> HasValueRef<'r, Database = Self> + + for<'q> HasArguments<'q, Database = Self> + + for<'q> HasStatement<'q, Database = Self> +{ + /// The concrete `Connection` implementation for this database. + type Connection: Connection; + + /// The concrete `TransactionManager` implementation for this database. + type TransactionManager: TransactionManager; + + /// The concrete `Row` implementation for this database. + type Row: Row; + + /// The concrete `QueryResult` implementation for this database. + type QueryResult: 'static + Sized + Sync + Default + Extend; + + /// The concrete `Column` implementation for this database. + type Column: Column; + + /// The concrete `TypeInfo` implementation for this database. + type TypeInfo: TypeInfo; + + /// The concrete type used to hold an owned copy of the not-yet-decoded value that was + /// received from the database. + type Value: Value + 'static; +} + /// Associate [`Database`] with a [`ValueRef`](crate::value::ValueRef) of a generic lifetime. /// /// --- diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 5b2f4780a0..245ffff178 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -36,7 +36,7 @@ pub enum Error { /// Error returned from the database. 
#[error("error returned from database: {0}")] - Database(Box), + Database(#[source] Box), /// Error communicating with the database backend. #[error("error communicating with the server: {0}")] @@ -100,11 +100,13 @@ pub enum Error { #[error("attempted to communicate with a crashed background worker")] WorkerCrashed, - #[cfg(feature = "migrate")] + #[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] #[error("{0}")] Migrate(#[source] Box), } +impl StdError for Box {} + impl Error { pub fn into_database_error(self) -> Option> { match self { @@ -235,7 +237,7 @@ where } } -#[cfg(feature = "migrate")] +#[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] impl From for Error { #[inline] fn from(error: crate::migrate::MigrateError) -> Self { diff --git a/sqlx-core/src/executor.rs b/sqlx-core/src/executor.rs index 2b0e27c219..c73892fcad 100644 --- a/sqlx-core/src/executor.rs +++ b/sqlx-core/src/executor.rs @@ -2,8 +2,17 @@ use crate::database::{Database, HasArguments, HasStatement}; use crate::describe::Describe; use crate::error::Error; use either::Either; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + use futures_util::{future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use std::fmt::Debug; @@ -22,6 +31,7 @@ use std::fmt::Debug; /// * [`&mut PoolConnection`](super::pool::PoolConnection) /// * [`&mut Connection`](super::connection::Connection) /// +#[cfg(not(target_arch = "wasm32"))] pub trait Executor<'c>: Send + Debug + Sized { type Database: Database; @@ -175,6 +185,160 @@ pub trait Executor<'c>: Send + Debug + Sized { 'c: 'e; } +#[cfg(target_arch = "wasm32")] +pub trait Executor<'c>: Debug + Sized { + type Database: Database; + + /// Execute the query and return 
the total number of rows affected. + fn execute<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result<::QueryResult, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.execute_many(query).try_collect().boxed_local() + } + + /// Execute multiple queries and return the rows affected from each query, in a stream. + fn execute_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxStream<'e, Result<::QueryResult, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.fetch_many(query) + .try_filter_map(|step| async move { + Ok(match step { + Either::Left(rows) => Some(rows), + Either::Right(_) => None, + }) + }) + .boxed_local() + } + + /// Execute the query and return the generated results as a stream. + fn fetch<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxStream<'e, Result<::Row, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.fetch_many(query) + .try_filter_map(|step| async move { + Ok(match step { + Either::Left(_) => None, + Either::Right(row) => Some(row), + }) + }) + .boxed_local() + } + + /// Execute multiple queries and return the generated results as a stream + /// from each query, in a stream. + fn fetch_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxStream< + 'e, + Result< + Either<::QueryResult, ::Row>, + Error, + >, + > + where + 'c: 'e, + E: Execute<'q, Self::Database>; + + /// Execute the query and return all the generated results, collected into a [`Vec`]. + fn fetch_all<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result::Row>, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.fetch(query).try_collect().boxed_local() + } + + /// Execute the query and returns exactly one row. 
+ fn fetch_one<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result<::Row, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.fetch_optional(query) + .and_then(|row| match row { + Some(row) => future::ok(row), + None => future::err(Error::RowNotFound), + }) + .boxed_local() + } + + /// Execute the query and returns at most one row. + fn fetch_optional<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result::Row>, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>; + + /// Prepare the SQL query to inspect the type information of its parameters + /// and results. + /// + /// Be advised that when using the `query`, `query_as`, or `query_scalar` functions, the query + /// is transparently prepared and executed. + /// + /// This explicit API is provided to allow access to the statement metadata available after + /// it prepared but before the first row is returned. + #[inline] + fn prepare<'e, 'q: 'e>( + self, + query: &'q str, + ) -> BoxFuture<'e, Result<>::Statement, Error>> + where + 'c: 'e, + { + self.prepare_with(query, &[]) + } + + /// Prepare the SQL query, with parameter type information, to inspect the + /// type information about its parameters and results. + /// + /// Only some database drivers (PostgreSQL, MSSQL) can take advantage of + /// this extra information to influence parameter type inference. + fn prepare_with<'e, 'q: 'e>( + self, + sql: &'q str, + parameters: &'e [::TypeInfo], + ) -> BoxFuture<'e, Result<>::Statement, Error>> + where + 'c: 'e; + + /// Describe the SQL query and return type information about its parameters + /// and results. + /// + /// This is used by compile-time verification in the query macros to + /// power their type inference. + #[doc(hidden)] + fn describe<'e, 'q: 'e>( + self, + sql: &'q str, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e; +} + /// A type that may be executed against a database connection. 
/// /// Implemented for the following: @@ -182,6 +346,7 @@ pub trait Executor<'c>: Send + Debug + Sized { /// * [`&str`](std::str) /// * [`Query`](super::query::Query) /// +#[cfg(not(target_arch = "wasm32"))] pub trait Execute<'q, DB: Database>: Send + Sized { /// Gets the SQL that will be executed. fn sql(&self) -> &'q str; @@ -200,6 +365,25 @@ pub trait Execute<'q, DB: Database>: Send + Sized { fn persistent(&self) -> bool; } +#[cfg(target_arch = "wasm32")] +pub trait Execute<'q, DB: Database>: Sized { + /// Gets the SQL that will be executed. + fn sql(&self) -> &'q str; + + /// Gets the previously cached statement, if available. + fn statement(&self) -> Option<&>::Statement>; + + /// Returns the arguments to be bound against the query string. + /// + /// Returning `None` for `Arguments` indicates to use a "simple" query protocol and to not + /// prepare the query. Returning `Some(Default::default())` is an empty arguments object that + /// will be prepared (and cached) before execution. + fn take_arguments(&mut self) -> Option<>::Arguments>; + + /// Returns `true` if the statement should be cached. 
+ fn persistent(&self) -> bool; +} + // NOTE: `Execute` is explicitly not implemented for String and &String to make it slightly more // involved to write `conn.execute(format!("SELECT {}", val))` impl<'q, DB: Database> Execute<'q, DB> for &'q str { diff --git a/sqlx-core/src/ext/async_stream.rs b/sqlx-core/src/ext/async_stream.rs index a392609602..1f24732da2 100644 --- a/sqlx-core/src/ext/async_stream.rs +++ b/sqlx-core/src/ext/async_stream.rs @@ -3,7 +3,12 @@ use std::pin::Pin; use std::task::{Context, Poll}; use futures_channel::mpsc; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + use futures_core::stream::Stream; use futures_util::{pin_mut, FutureExt, SinkExt}; @@ -14,6 +19,7 @@ pub struct TryAsyncStream<'a, T> { future: BoxFuture<'a, Result<(), Error>>, } +#[cfg(not(target_arch = "wasm32"))] impl<'a, T> TryAsyncStream<'a, T> { pub fn new(f: F) -> Self where @@ -38,6 +44,31 @@ impl<'a, T> TryAsyncStream<'a, T> { } } +#[cfg(target_arch = "wasm32")] +impl<'a, T> TryAsyncStream<'a, T> { + pub fn new(f: F) -> Self + where + F: FnOnce(mpsc::Sender>) -> Fut, + Fut: 'a + Future>, + T: 'a, + { + let (mut sender, receiver) = mpsc::channel(0); + + let future = f(sender.clone()); + let future = async move { + if let Err(error) = future.await { + let _ = sender.send(Err(error)).await; + } + + Ok(()) + } + .fuse() + .boxed_local(); + + Self { future, receiver } + } +} + impl<'a, T> Stream for TryAsyncStream<'a, T> { type Item = Result; @@ -62,9 +93,9 @@ macro_rules! try_stream { ($($block:tt)*) => { crate::ext::async_stream::TryAsyncStream::new(move |mut sender| async move { macro_rules! 
r#yield { - ($v:expr) => { + ($v:expr) => {{ let _ = futures_util::sink::SinkExt::send(&mut sender, Ok($v)).await; - } + }} } $($block)* diff --git a/sqlx-core/src/io/buf_stream.rs b/sqlx-core/src/io/buf_stream.rs index 6b5b55a4ae..8f376cbfb0 100644 --- a/sqlx-core/src/io/buf_stream.rs +++ b/sqlx-core/src/io/buf_stream.rs @@ -15,7 +15,7 @@ pub struct BufStream where S: AsyncRead + AsyncWrite + Unpin, { - stream: S, + pub(crate) stream: S, // writes with `write` to the underlying stream are buffered // this can be flushed with `flush` diff --git a/sqlx-core/src/lib.rs b/sqlx-core/src/lib.rs index 9e50d0d842..963dc46904 100644 --- a/sqlx-core/src/lib.rs +++ b/sqlx-core/src/lib.rs @@ -48,6 +48,7 @@ pub mod types; #[macro_use] pub mod query; +#[cfg(not(target_arch = "wasm32"))] #[macro_use] pub mod acquire; @@ -63,6 +64,7 @@ pub mod describe; pub mod executor; pub mod from_row; mod io; +#[cfg(not(target_arch = "wasm32"))] mod logger; mod net; pub mod query_as; @@ -71,7 +73,7 @@ pub mod row; pub mod type_info; pub mod value; -#[cfg(feature = "migrate")] +#[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] pub mod migrate; #[cfg(all( diff --git a/sqlx-core/src/mssql/mod.rs b/sqlx-core/src/mssql/mod.rs index 068e77a750..ed3b325871 100644 --- a/sqlx-core/src/mssql/mod.rs +++ b/sqlx-core/src/mssql/mod.rs @@ -1,5 +1,7 @@ //! Microsoft SQL (MSSQL) database driver. +use crate::executor::Executor; + mod arguments; mod column; mod connection; @@ -32,6 +34,10 @@ pub use value::{MssqlValue, MssqlValueRef}; /// An alias for [`Pool`][crate::pool::Pool], specialized for MSSQL. pub type MssqlPool = crate::pool::Pool; +/// An alias for [`Executor<'_, Database = Mssql>`][Executor]. 
+pub trait MssqlExecutor<'c>: Executor<'c, Database = Mssql> {} +impl<'c, T: Executor<'c, Database = Mssql>> MssqlExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(MssqlArguments); impl_executor_for_pool_connection!(Mssql, MssqlConnection, MssqlRow); diff --git a/sqlx-core/src/mssql/types/str.rs b/sqlx-core/src/mssql/types/str.rs index 4902d783be..048dd84cd3 100644 --- a/sqlx-core/src/mssql/types/str.rs +++ b/sqlx-core/src/mssql/types/str.rs @@ -5,6 +5,7 @@ use crate::mssql::io::MssqlBufMutExt; use crate::mssql::protocol::type_info::{Collation, CollationFlags, DataType, TypeInfo}; use crate::mssql::{Mssql, MssqlTypeInfo, MssqlValueRef}; use crate::types::Type; +use std::borrow::Cow; impl Type for str { fn type_info() -> MssqlTypeInfo { @@ -81,3 +82,33 @@ impl Decode<'_, Mssql> for String { .into_owned()) } } + +impl Encode<'_, Mssql> for Cow<'_, str> { + fn produces(&self) -> Option { + match self { + Cow::Borrowed(str) => <&str as Encode>::produces(str), + Cow::Owned(str) => <&str as Encode>::produces(&(str.as_ref())), + } + } + + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode_by_ref(str, buf), + Cow::Owned(str) => <&str as Encode>::encode_by_ref(&(str.as_ref()), buf), + } + } +} + +impl<'r> Decode<'r, Mssql> for Cow<'r, str> { + fn decode(value: MssqlValueRef<'r>) -> Result { + Ok(Cow::Owned( + value + .type_info + .0 + .encoding()? + .decode_without_bom_handling(value.as_bytes()?) 
+ .0 + .into_owned(), + )) + } +} diff --git a/sqlx-core/src/mysql/connection/executor.rs b/sqlx-core/src/mysql/connection/executor.rs index 012714d710..9cb0690bda 100644 --- a/sqlx-core/src/mysql/connection/executor.rs +++ b/sqlx-core/src/mysql/connection/executor.rs @@ -4,7 +4,7 @@ use crate::error::Error; use crate::executor::{Execute, Executor}; use crate::ext::ustr::UStr; use crate::logger::QueryLogger; -use crate::mysql::connection::stream::Busy; +use crate::mysql::connection::stream::Waiting; use crate::mysql::io::MySqlBufExt; use crate::mysql::protocol::response::Status; use crate::mysql::protocol::statement::{ @@ -93,7 +93,7 @@ impl MySqlConnection { let mut logger = QueryLogger::new(sql, self.log_settings.clone()); self.stream.wait_until_ready().await?; - self.stream.busy = Busy::Result; + self.stream.waiting.push_back(Waiting::Result); Ok(Box::pin(try_stream! { // make a slot for the shared column data @@ -146,12 +146,12 @@ impl MySqlConnection { continue; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } // otherwise, this first packet is the start of the result-set metadata, - self.stream.busy = Busy::Row; + *self.stream.waiting.front_mut().unwrap() = Waiting::Row; let num_columns = packet.get_uint_lenenc() as usize; // column count @@ -179,11 +179,11 @@ impl MySqlConnection { if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { // more result sets exist, continue to the next one - self.stream.busy = Busy::Result; + *self.stream.waiting.front_mut().unwrap() = Waiting::Result; break; } - self.stream.busy = Busy::NotBusy; + self.stream.waiting.pop_front(); return Ok(()); } diff --git a/sqlx-core/src/mysql/connection/mod.rs b/sqlx-core/src/mysql/connection/mod.rs index 509426a63d..4ade06beeb 100644 --- a/sqlx-core/src/mysql/connection/mod.rs +++ b/sqlx-core/src/mysql/connection/mod.rs @@ -16,7 +16,7 @@ mod executor; mod stream; mod tls; -pub(crate) use stream::{Busy, MySqlStream}; +pub(crate) use 
stream::{MySqlStream, Waiting}; const MAX_PACKET_SIZE: u32 = 1024; diff --git a/sqlx-core/src/mysql/connection/stream.rs b/sqlx-core/src/mysql/connection/stream.rs index 8b2f453608..e43cf253c6 100644 --- a/sqlx-core/src/mysql/connection/stream.rs +++ b/sqlx-core/src/mysql/connection/stream.rs @@ -1,3 +1,4 @@ +use std::collections::VecDeque; use std::ops::{Deref, DerefMut}; use bytes::{Buf, Bytes}; @@ -16,15 +17,13 @@ pub struct MySqlStream { pub(crate) server_version: (u16, u16, u16), pub(super) capabilities: Capabilities, pub(crate) sequence_id: u8, - pub(crate) busy: Busy, + pub(crate) waiting: VecDeque, pub(crate) charset: CharSet, pub(crate) collation: Collation, } #[derive(Debug, PartialEq, Eq)] -pub(crate) enum Busy { - NotBusy, - +pub(crate) enum Waiting { // waiting for a result set Result, @@ -65,7 +64,7 @@ impl MySqlStream { } Ok(Self { - busy: Busy::NotBusy, + waiting: VecDeque::new(), capabilities, server_version: (0, 0, 0), sequence_id: 0, @@ -80,32 +79,32 @@ impl MySqlStream { self.stream.flush().await?; } - while self.busy != Busy::NotBusy { - while self.busy == Busy::Row { + while !self.waiting.is_empty() { + while self.waiting.front() == Some(&Waiting::Row) { let packet = self.recv_packet().await?; if packet[0] == 0xfe && packet.len() < 9 { let eof = packet.eof(self.capabilities)?; - self.busy = if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - Busy::Result + if eof.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { + *self.waiting.front_mut().unwrap() = Waiting::Result; } else { - Busy::NotBusy + self.waiting.pop_front(); }; } } - while self.busy == Busy::Result { + while self.waiting.front() == Some(&Waiting::Result) { let packet = self.recv_packet().await?; if packet[0] == 0x00 || packet[0] == 0xff { let ok = packet.ok()?; if !ok.status.contains(Status::SERVER_MORE_RESULTS_EXISTS) { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); } } else { - self.busy = Busy::Row; + *self.waiting.front_mut().unwrap() = Waiting::Row; 
self.skip_result_metadata(packet).await?; } } @@ -150,7 +149,7 @@ impl MySqlStream { // TODO: packet joining if payload[0] == 0xff { - self.busy = Busy::NotBusy; + self.waiting.pop_front(); // instead of letting this packet be looked at everywhere, we check here // and emit a proper Error diff --git a/sqlx-core/src/mysql/migrate.rs b/sqlx-core/src/mysql/migrate.rs index 248fd6298e..c3898e9fc2 100644 --- a/sqlx-core/src/mysql/migrate.rs +++ b/sqlx-core/src/mysql/migrate.rs @@ -8,7 +8,6 @@ use crate::mysql::{MySql, MySqlConnectOptions, MySqlConnection}; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; -use crc::crc32; use futures_core::future::BoxFuture; use std::str::FromStr; use std::time::Duration; @@ -266,9 +265,10 @@ async fn current_database(conn: &mut MySqlConnection) -> Result String { + const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll format!( "{:x}", - 0x3d32ad9e * (crc32::checksum_ieee(database_name.as_bytes()) as i64) + 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) ) } diff --git a/sqlx-core/src/mysql/mod.rs b/sqlx-core/src/mysql/mod.rs index dc7f969936..e108e8591f 100644 --- a/sqlx-core/src/mysql/mod.rs +++ b/sqlx-core/src/mysql/mod.rs @@ -1,5 +1,7 @@ //! **MySQL** database driver. +use crate::executor::Executor; + mod arguments; mod collation; mod column; @@ -39,6 +41,10 @@ pub type MySqlPool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for MySQL. pub type MySqlPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = MySql>`][Executor]. 
+pub trait MySqlExecutor<'c>: Executor<'c, Database = MySql> {} +impl<'c, T: Executor<'c, Database = MySql>> MySqlExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(MySqlArguments); impl_executor_for_pool_connection!(MySql, MySqlConnection, MySqlRow); diff --git a/sqlx-core/src/mysql/transaction.rs b/sqlx-core/src/mysql/transaction.rs index b62fc143b5..97cb121d0e 100644 --- a/sqlx-core/src/mysql/transaction.rs +++ b/sqlx-core/src/mysql/transaction.rs @@ -2,7 +2,7 @@ use futures_core::future::BoxFuture; use crate::error::Error; use crate::executor::Executor; -use crate::mysql::connection::Busy; +use crate::mysql::connection::Waiting; use crate::mysql::protocol::text::Query; use crate::mysql::{MySql, MySqlConnection}; use crate::transaction::{ @@ -57,7 +57,7 @@ impl TransactionManager for MySqlTransactionManager { let depth = conn.transaction_depth; if depth > 0 { - conn.stream.busy = Busy::Result; + conn.stream.waiting.push_back(Waiting::Result); conn.stream.sequence_id = 0; conn.stream .write_packet(Query(&*rollback_ansi_transaction_sql(depth))); diff --git a/sqlx-core/src/mysql/types/chrono.rs b/sqlx-core/src/mysql/types/chrono.rs index 5a261804bf..76e8b2985d 100644 --- a/sqlx-core/src/mysql/types/chrono.rs +++ b/sqlx-core/src/mysql/types/chrono.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use bytes::Buf; -use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; +use chrono::{DateTime, Datelike, Local, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; use crate::decode::Decode; use crate::encode::{Encode, IsNull}; @@ -21,12 +21,14 @@ impl Type for DateTime { } } +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). impl Encode<'_, MySql> for DateTime { fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { Encode::::encode(&self.naive_utc(), buf) } } +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). 
impl<'r> Decode<'r, MySql> for DateTime { fn decode(value: MySqlValueRef<'r>) -> Result { let naive: NaiveDateTime = Decode::::decode(value)?; @@ -35,6 +37,30 @@ impl<'r> Decode<'r, MySql> for DateTime { } } +impl Type for DateTime { + fn type_info() -> MySqlTypeInfo { + MySqlTypeInfo::binary(ColumnType::Timestamp) + } + + fn compatible(ty: &MySqlTypeInfo) -> bool { + matches!(ty.r#type, ColumnType::Datetime | ColumnType::Timestamp) + } +} + +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). +impl Encode<'_, MySql> for DateTime { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + Encode::::encode(&self.naive_utc(), buf) + } +} + +/// Note: assumes the connection's `time_zone` is set to `+00:00` (UTC). +impl<'r> Decode<'r, MySql> for DateTime { + fn decode(value: MySqlValueRef<'r>) -> Result { + Ok( as Decode<'r, MySql>>::decode(value)?.with_timezone(&Local)) + } +} + impl Type for NaiveTime { fn type_info() -> MySqlTypeInfo { MySqlTypeInfo::binary(ColumnType::Time) diff --git a/sqlx-core/src/mysql/types/str.rs b/sqlx-core/src/mysql/types/str.rs index 19e3de62c9..076858901b 100644 --- a/sqlx-core/src/mysql/types/str.rs +++ b/sqlx-core/src/mysql/types/str.rs @@ -5,6 +5,7 @@ use crate::mysql::io::MySqlBufMutExt; use crate::mysql::protocol::text::{ColumnFlags, ColumnType}; use crate::mysql::{MySql, MySqlTypeInfo, MySqlValueRef}; use crate::types::Type; +use std::borrow::Cow; const COLLATE_UTF8_GENERAL_CI: u16 = 33; const COLLATE_UTF8_UNICODE_CI: u16 = 192; @@ -80,3 +81,18 @@ impl Decode<'_, MySql> for String { <&str as Decode>::decode(value).map(ToOwned::to_owned) } } + +impl Encode<'_, MySql> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut Vec) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + +impl<'r> Decode<'r, MySql> for Cow<'r, str> { + fn decode(value: MySqlValueRef<'r>) -> Result { + value.as_str().map(Cow::Borrowed) + 
} +} diff --git a/sqlx-core/src/net/mod.rs b/sqlx-core/src/net/mod.rs index 6b8371ef50..7f712a6869 100644 --- a/sqlx-core/src/net/mod.rs +++ b/sqlx-core/src/net/mod.rs @@ -1,16 +1,20 @@ mod socket; + +#[cfg(not(target_arch = "wasm32"))] mod tls; pub use socket::Socket; + +#[cfg(not(target_arch = "wasm32"))] pub use tls::{CertificateInput, MaybeTlsStream}; -#[cfg(feature = "_rt-async-std")] +#[cfg(any(feature = "_rt-async-std", target_arch = "wasm32"))] type PollReadBuf<'a> = [u8]; #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] type PollReadBuf<'a> = sqlx_rt::ReadBuf<'a>; -#[cfg(feature = "_rt-async-std")] +#[cfg(any(feature = "_rt-async-std", target_arch = "wasm32"))] type PollReadOut = usize; #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] diff --git a/sqlx-core/src/net/socket.rs b/sqlx-core/src/net/socket.rs index 06d5575c01..8453ac4e98 100644 --- a/sqlx-core/src/net/socket.rs +++ b/sqlx-core/src/net/socket.rs @@ -5,22 +5,33 @@ use std::path::Path; use std::pin::Pin; use std::task::{Context, Poll}; +#[cfg(not(target_arch = "wasm32"))] use sqlx_rt::{AsyncRead, AsyncWrite, TcpStream}; +#[cfg(target_arch = "wasm32")] +use sqlx_rt::{AsyncRead, AsyncWrite, IoStream, WsMeta, WsStreamIo}; +#[cfg(target_arch = "wasm32")] +type WSIoStream = IoStream>; + #[derive(Debug)] pub enum Socket { + #[cfg(not(target_arch = "wasm32"))] Tcp(TcpStream), - #[cfg(unix)] + #[cfg(all(unix, not(target_arch = "wasm32")))] Unix(sqlx_rt::UnixStream), + + #[cfg(target_arch = "wasm32")] + WS((WsMeta, WSIoStream)), } impl Socket { + #[cfg(not(target_arch = "wasm32"))] pub async fn connect_tcp(host: &str, port: u16) -> io::Result { TcpStream::connect((host, port)).await.map(Socket::Tcp) } - #[cfg(unix)] + #[cfg(all(unix, not(target_arch = "wasm32")))] pub async fn connect_uds(path: impl AsRef) -> io::Result { sqlx_rt::UnixStream::connect(path.as_ref()) .await @@ -35,15 +46,23 @@ impl Socket { )) } + #[cfg(target_arch = "wasm32")] + pub async fn connect_ws(url: impl AsRef) -> 
io::Result { + WsMeta::connect(url, None) + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "can't connect to ws stream")) + .map(|(m, s)| Socket::WS((m, s.into_io()))) + } + pub async fn shutdown(&mut self) -> io::Result<()> { - #[cfg(feature = "_rt-async-std")] + #[cfg(all(feature = "_rt-async-std", not(target_arch = "wasm32")))] { use std::net::Shutdown; match self { Socket::Tcp(s) => s.shutdown(Shutdown::Both), - #[cfg(unix)] + #[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => s.shutdown(Shutdown::Both), } } @@ -59,6 +78,15 @@ impl Socket { Socket::Unix(s) => s.shutdown().await, } } + + #[cfg(target_arch = "wasm32")] + { + let Socket::WS((m, _)) = self; + m.close() + .await + .map_err(|_| io::Error::new(io::ErrorKind::Other, "error closing ws stream")) + .map(|_| ()) + } } } @@ -69,9 +97,13 @@ impl AsyncRead for Socket { buf: &mut super::PollReadBuf<'_>, ) -> Poll> { match &mut *self { + #[cfg(not(target_arch = "wasm32"))] Socket::Tcp(s) => Pin::new(s).poll_read(cx, buf), - #[cfg(unix)] + #[cfg(target_arch = "wasm32")] + Socket::WS((_, s)) => Pin::new(s).poll_read(cx, buf), + + #[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => Pin::new(s).poll_read(cx, buf), } } @@ -84,18 +116,28 @@ impl AsyncWrite for Socket { buf: &[u8], ) -> Poll> { match &mut *self { + #[cfg(not(target_arch = "wasm32"))] Socket::Tcp(s) => Pin::new(s).poll_write(cx, buf), - #[cfg(unix)] + #[cfg(target_arch = "wasm32")] + Socket::WS((_, s)) => Pin::new(s).poll_write(cx, buf), + + #[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => Pin::new(s).poll_write(cx, buf), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { + #[cfg(not(target_arch = "wasm32"))] Socket::Tcp(s) => Pin::new(s).poll_flush(cx), - #[cfg(unix)] + #[cfg(target_arch = "wasm32")] + Socket::WS((_, s)) => Pin::new(s) + .poll_flush(cx) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "error flushing ws stream")), + + 
#[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => Pin::new(s).poll_flush(cx), } } @@ -103,19 +145,32 @@ impl AsyncWrite for Socket { #[cfg(any(feature = "_rt-actix", feature = "_rt-tokio"))] fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { + #[cfg(not(target_arch = "wasm32"))] Socket::Tcp(s) => Pin::new(s).poll_shutdown(cx), - #[cfg(unix)] + #[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => Pin::new(s).poll_shutdown(cx), } } - #[cfg(feature = "_rt-async-std")] + #[cfg(all(feature = "_rt-async-std", not(target_arch = "wasm32")))] fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { Socket::Tcp(s) => Pin::new(s).poll_close(cx), - #[cfg(unix)] + #[cfg(all(unix, not(target_arch = "wasm32")))] + Socket::Unix(s) => Pin::new(s).poll_close(cx), + } + } + + #[cfg(target_arch = "wasm32")] + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + Socket::WS((_, s)) => Pin::new(s) + .poll_close(cx) + .map_err(|_| io::Error::new(io::ErrorKind::Other, "error closing ws stream")), + + #[cfg(all(unix, not(target_arch = "wasm32")))] Socket::Unix(s) => Pin::new(s).poll_close(cx), } } diff --git a/sqlx-core/src/pool/connection.rs b/sqlx-core/src/pool/connection.rs index 732c1a8c92..88864566c1 100644 --- a/sqlx-core/src/pool/connection.rs +++ b/sqlx-core/src/pool/connection.rs @@ -1,13 +1,17 @@ -use super::inner::{DecrementSizeGuard, SharedPool}; -use crate::connection::Connection; -use crate::database::Database; -use crate::error::Error; -use sqlx_rt::spawn; use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; use std::sync::Arc; use std::time::Instant; +use futures_intrusive::sync::SemaphoreReleaser; + +use crate::connection::Connection; +use crate::database::Database; +use crate::error::Error; + +use super::inner::{DecrementSizeGuard, SharedPool}; +use std::future::Future; + /// A connection managed by a 
[`Pool`][crate::pool::Pool]. /// /// Will be returned to the pool on-drop. @@ -28,8 +32,8 @@ pub(super) struct Idle { /// RAII wrapper for connections being handled by functions that may drop them pub(super) struct Floating<'p, C> { - inner: C, - guard: DecrementSizeGuard<'p>, + pub(super) inner: C, + pub(super) guard: DecrementSizeGuard<'p>, } const DEREF_ERR: &str = "(bug) connection already released to pool"; @@ -57,43 +61,85 @@ impl DerefMut for PoolConnection { impl PoolConnection { /// Explicitly release a connection from the pool - pub fn release(mut self) -> DB::Connection { + #[deprecated = "renamed to `.detach()` for clarity"] + pub fn release(self) -> DB::Connection { + self.detach() + } + + /// Detach this connection from the pool, allowing it to open a replacement. + /// + /// Note that if your application uses a single shared pool, this + /// effectively lets the application exceed the `max_connections` setting. + /// + /// If you want the pool to treat this connection as permanently checked-out, + /// use [`.leak()`][Self::leak] instead. + pub fn detach(mut self) -> DB::Connection { self.live .take() .expect("PoolConnection double-dropped") .float(&self.pool) .detach() } + + /// Detach this connection from the pool, treating it as permanently checked-out. + /// + /// This effectively will reduce the maximum capacity of the pool by 1 every time it is used. + /// + /// If you don't want to impact the pool's capacity, use [`.detach()`][Self::detach] instead. + pub fn leak(mut self) -> DB::Connection { + self.live.take().expect("PoolConnection double-dropped").raw + } + + /// Test the connection to make sure it is still live before returning it to the pool. + /// + /// This effectively runs the drop handler eagerly instead of spawning a task to do it. 
+ pub(crate) fn return_to_pool(&mut self) -> impl Future + Send + 'static { + // we want these to happen synchronously so the drop handler doesn't try to spawn a task anyway + // this also makes the returned future `'static` + let live = self.live.take(); + let pool = self.pool.clone(); + + async move { + let mut floating = if let Some(live) = live { + live.float(&pool) + } else { + return; + }; + + // test the connection on-release to ensure it is still viable + // if an Executor future/stream is dropped during an `.await` call, the connection + // is likely to be left in an inconsistent state, in which case it should not be + // returned to the pool; also of course, if it was dropped due to an error + // this is simply a band-aid as SQLx-next (0.6) connections should be able + // to recover from cancellations + if let Err(e) = floating.raw.ping().await { + log::warn!( + "error occurred while testing the connection on-release: {}", + e + ); + + // we now consider the connection to be broken; just drop it to close + // trying to close gracefully might cause something weird to happen + drop(floating); + } else { + // if the connection is still viable, release it to the pool + pool.release(floating); + } + } + } } /// Returns the connection to the [`Pool`][crate::pool::Pool] it was checked-out from. 
impl Drop for PoolConnection { fn drop(&mut self) { - if let Some(live) = self.live.take() { - let pool = self.pool.clone(); - spawn(async move { - let mut floating = live.float(&pool); - - // test the connection on-release to ensure it is still viable - // if an Executor future/stream is dropped during an `.await` call, the connection - // is likely to be left in an inconsistent state, in which case it should not be - // returned to the pool; also of course, if it was dropped due to an error - // this is simply a band-aid as SQLx-next (0.6) connections should be able - // to recover from cancellations - if let Err(e) = floating.raw.ping().await { - log::warn!( - "error occurred while testing the connection on-release: {}", - e - ); - - // we now consider the connection to be broken; just drop it to close - // trying to close gracefully might cause something weird to happen - drop(floating); - } else { - // if the connection is still viable, release it to th epool - pool.release(floating); - } - }); + if self.live.is_some() { + #[cfg(not(feature = "_rt-async-std"))] + if let Ok(handle) = sqlx_rt::Handle::try_current() { + handle.spawn(self.return_to_pool()); + } + + #[cfg(feature = "_rt-async-std")] + sqlx_rt::spawn(self.return_to_pool()); } } } @@ -102,7 +148,8 @@ impl Live { pub fn float(self, pool: &SharedPool) -> Floating<'_, Self> { Floating { inner: self, - guard: DecrementSizeGuard::new(pool), + // create a new guard from a previously leaked permit + guard: DecrementSizeGuard::new_permit(pool), } } @@ -128,13 +175,6 @@ impl DerefMut for Idle { } } -impl<'s, C> Floating<'s, C> { - pub fn into_leakable(self) -> C { - self.guard.cancel(); - self.inner - } -} - impl<'s, DB: Database> Floating<'s, Live> { pub fn new_live(conn: DB::Connection, guard: DecrementSizeGuard<'s>) -> Self { Self { @@ -161,6 +201,11 @@ impl<'s, DB: Database> Floating<'s, Live> { } } + pub async fn close(self) -> Result<(), Error> { + // `guard` is dropped as intended + 
self.inner.raw.close().await + } + pub fn detach(self) -> DB::Connection { self.inner.raw } @@ -174,10 +219,14 @@ } impl<'s, DB: Database> Floating<'s, Idle> { - pub fn from_idle(idle: Idle, pool: &'s SharedPool) -> Self { + pub fn from_idle( + idle: Idle, + pool: &'s SharedPool, + permit: SemaphoreReleaser<'s>, + ) -> Self { Self { inner: idle, - guard: DecrementSizeGuard::new(pool), + guard: DecrementSizeGuard::from_permit(pool, permit), } } @@ -192,9 +241,12 @@ } } - pub async fn close(self) -> Result<(), Error> { + pub async fn close(self) -> DecrementSizeGuard<'s> { // `guard` is dropped as intended - self.inner.live.raw.close().await + if let Err(e) = self.inner.live.raw.close().await { + log::debug!("error occurred while closing the pool connection: {}", e); + } + self.guard } } diff --git a/sqlx-core/src/pool/inner.rs b/sqlx-core/src/pool/inner.rs index f9e5df43b3..d67cdfc0f1 100644 --- a/sqlx-core/src/pool/inner.rs +++ b/sqlx-core/src/pool/inner.rs @@ -4,23 +4,28 @@ use crate::connection::Connection; use crate::database::Database; use crate::error::Error; use crate::pool::{deadline_as_timeout, PoolOptions}; -use crossbeam_queue::{ArrayQueue, SegQueue}; -use futures_core::task::{Poll, Waker}; -use futures_util::future; +use crossbeam_queue::ArrayQueue; + +use futures_intrusive::sync::{Semaphore, SemaphoreReleaser}; + use std::cmp; use std::mem; use std::ptr; use std::sync::atomic::{AtomicBool, AtomicU32, Ordering}; -use std::sync::{Arc, Weak}; -use std::task::Context; +use std::sync::Arc; + use std::time::{Duration, Instant}; -type Waiters = SegQueue>; +/// The number of permits to release to wake all waiters, such as on `SharedPool::close()`. +/// +/// This should be large enough to realistically wake all tasks waiting on the pool without +/// potentially overflowing the permits count in the semaphore itself.
+const WAKE_ALL_PERMITS: usize = usize::MAX / 2; pub(crate) struct SharedPool { pub(super) connect_options: ::Options, pub(super) idle_conns: ArrayQueue>, - waiters: Waiters, + pub(super) semaphore: Semaphore, pub(super) size: AtomicU32, is_closed: AtomicBool, pub(super) options: PoolOptions, @@ -31,10 +36,18 @@ impl SharedPool { options: PoolOptions, connect_options: ::Options, ) -> Arc { + let capacity = options.max_connections as usize; + + // ensure the permit count won't overflow if we release `WAKE_ALL_PERMITS` + // this assert should never fire on 64-bit targets as `max_connections` is a u32 + let _ = capacity + .checked_add(WAKE_ALL_PERMITS) + .expect("max_connections exceeds max capacity of the pool"); + let pool = Self { connect_options, - idle_conns: ArrayQueue::new(options.max_connections as usize), - waiters: SegQueue::new(), + idle_conns: ArrayQueue::new(capacity), + semaphore: Semaphore::new(options.fair, capacity), size: AtomicU32::new(0), is_closed: AtomicBool::new(false), options, @@ -61,148 +74,133 @@ impl SharedPool { } pub(super) async fn close(&self) { - self.is_closed.store(true, Ordering::Release); - while let Some(waker) = self.waiters.pop() { - if let Some(waker) = waker.upgrade() { - waker.wake(); - } + let already_closed = self.is_closed.swap(true, Ordering::AcqRel); + + if !already_closed { + // if we were the one to mark this closed, release enough permits to wake all waiters + // we can't just do `usize::MAX` because that would overflow + // and we can't do this more than once cause that would _also_ overflow + self.semaphore.release(WAKE_ALL_PERMITS); } - // ensure we wait until the pool is actually closed - while self.size() > 0 { - if let Some(idle) = self.idle_conns.pop() { - if let Err(e) = Floating::from_idle(idle, self).close().await { - log::warn!("error occurred while closing the pool connection: {}", e); - } - } + // wait for all permits to be released + let _permits = self + .semaphore + .acquire(WAKE_ALL_PERMITS + 
(self.options.max_connections as usize)) + .await; - // yield to avoid starving the executor - sqlx_rt::yield_now().await; + while let Some(idle) = self.idle_conns.pop() { + let _ = idle.live.float(self).close().await; } } #[inline] - pub(super) fn try_acquire(&self) -> Option>> { - // don't cut in line - if self.options.fair && !self.waiters.is_empty() { + pub(super) fn try_acquire(&self) -> Option>> { + if self.is_closed() { return None; } - Some(self.pop_idle()?.into_live()) + + let permit = self.semaphore.try_acquire(1)?; + self.pop_idle(permit).ok() } - fn pop_idle(&self) -> Option>> { - if self.is_closed.load(Ordering::Acquire) { - return None; + fn pop_idle<'a>( + &'a self, + permit: SemaphoreReleaser<'a>, + ) -> Result>, SemaphoreReleaser<'a>> { + if let Some(idle) = self.idle_conns.pop() { + Ok(Floating::from_idle(idle, self, permit)) + } else { + Err(permit) } - - Some(Floating::from_idle(self.idle_conns.pop()?, self)) } pub(super) fn release(&self, mut floating: Floating<'_, Live>) { if let Some(test) = &self.options.after_release { if !test(&mut floating.raw) { - // drop the connection and do not return to the pool + // drop the connection and do not return it to the pool return; } } - let is_ok = self - .idle_conns - .push(floating.into_idle().into_leakable()) - .is_ok(); + let Floating { inner: idle, guard } = floating.into_idle(); - if !is_ok { + if !self.idle_conns.push(idle).is_ok() { panic!("BUG: connection queue overflow in release()"); } - wake_one(&self.waiters); + // NOTE: we need to make sure we drop the permit *after* we push to the idle queue + // don't decrease the size + guard.release_permit(); } /// Try to atomically increment the pool size for a new connection. /// /// Returns `None` if we are at max_connections or if the pool is closed. 
- pub(super) fn try_increment_size(&self) -> Option> { - if self.is_closed() { - return None; - } - - let mut size = self.size(); - - while size < self.options.max_connections { - match self - .size - .compare_exchange(size, size + 1, Ordering::AcqRel, Ordering::Acquire) - { - Ok(_) => return Some(DecrementSizeGuard::new(self)), - Err(new_size) => size = new_size, - } + pub(super) fn try_increment_size<'a>( + &'a self, + permit: SemaphoreReleaser<'a>, + ) -> Result, SemaphoreReleaser<'a>> { + match self + .size + .fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| { + size.checked_add(1) + .filter(|size| size <= &self.options.max_connections) + }) { + // we successfully incremented the size + Ok(_) => Ok(DecrementSizeGuard::from_permit(self, permit)), + // the pool is at max capacity + Err(_) => Err(permit), } - - None } #[allow(clippy::needless_lifetimes)] pub(super) async fn acquire<'s>(&'s self) -> Result>, Error> { - let start = Instant::now(); - let deadline = start + self.options.connect_timeout; - let mut waited = !self.options.fair; - - // the strong ref of the `Weak` that we push to the queue - // initialized during the `timeout()` call below - // as long as we own this, we keep our place in line - let mut waiter: Option> = None; - - // Unless the pool has been closed ... - while !self.is_closed() { - // Don't cut in line unless no one is waiting - if waited || self.waiters.is_empty() { - // Attempt to immediately acquire a connection. This will return Some - // if there is an idle connection in our channel. 
- if let Some(conn) = self.pop_idle() { - if let Some(live) = check_conn(conn, &self.options).await { - return Ok(live); - } - } + if self.is_closed() { + return Err(Error::PoolClosed); + } - // check if we can open a new connection - if let Some(guard) = self.try_increment_size() { - // pool has slots available; open a new connection - return self.connection(deadline, guard).await; - } - } + let deadline = Instant::now() + self.options.connect_timeout; - if let Some(ref waiter) = waiter { - // return the waiter to the queue, note that this does put it to the back - // of the queue when it should ideally stay at the front - self.waiters.push(Arc::downgrade(&waiter.inner)); - } + sqlx_rt::timeout( + self.options.connect_timeout, + async { + loop { + let permit = self.semaphore.acquire(1).await; - sqlx_rt::timeout( - // Returns an error if `deadline` passes - deadline_as_timeout::(deadline)?, - // `poll_fn` gets us easy access to a `Waker` that we can push to our queue - future::poll_fn(|cx| -> Poll<()> { - let waiter = waiter.get_or_insert_with(|| Waiter::push_new(cx, &self.waiters)); - - if waiter.is_woken() { - waiter.actually_woke = true; - Poll::Ready(()) - } else { - Poll::Pending + if self.is_closed() { + return Err(Error::PoolClosed); } - }), - ) - .await - .map_err(|_| Error::PoolTimedOut)?; - if let Some(ref mut waiter) = waiter { - waiter.reset(); + // First attempt to pop a connection from the idle queue. + let guard = match self.pop_idle(permit) { + + // Then, check that we can use it... + Ok(conn) => match check_conn(conn, &self.options).await { + + // All good! 
+ Ok(live) => return Ok(live), + + // if the connection isn't usable for one reason or another, + // we get the `DecrementSizeGuard` back to open a new one + Err(guard) => guard, + }, + Err(permit) => if let Ok(guard) = self.try_increment_size(permit) { + // we can open a new connection + guard + } else { + log::debug!("woke but was unable to acquire idle connection or open new one; retrying"); + continue; + } + }; + + // Attempt to connect... + return self.connection(deadline, guard).await; + } } - - waited = true; - } - - Err(Error::PoolClosed) + ) + .await + .map_err(|_| Error::PoolTimedOut)? } pub(super) async fn connection<'s>( @@ -277,14 +275,13 @@ fn is_beyond_idle(idle: &Idle, options: &PoolOptions) -> b async fn check_conn<'s: 'p, 'p, DB: Database>( mut conn: Floating<'s, Idle>, options: &'p PoolOptions, -) -> Option>> { +) -> Result>, DecrementSizeGuard<'s>> { // If the connection we pulled has expired, close the connection and // immediately create a new connection if is_beyond_lifetime(&conn, options) { // we're closing the connection either way // close the connection but don't really care about the result - let _ = conn.close().await; - return None; + return Err(conn.close().await); } else if options.test_before_acquire { // Check that the connection is still live if let Err(e) = conn.ping().await { @@ -293,18 +290,18 @@ async fn check_conn<'s: 'p, 'p, DB: Database>( // the error itself here isn't necessarily unexpected so WARN is too strong log::info!("ping on idle connection returned error: {}", e); // connection is broken so don't try to close nicely - return None; + return Err(conn.close().await); } } else if let Some(test) = &options.before_acquire { match test(&mut conn.live.raw).await { Ok(false) => { // connection was rejected by user-defined hook - return None; + return Err(conn.close().await); } Err(error) => { log::info!("in `before_acquire`: {}", error); - return None; + return Err(conn.close().await); } Ok(true) => {} @@ -312,7 +309,7 @@ 
async fn check_conn<'s: 'p, 'p, DB: Database>( } // No need to re-connect; connection is alive or we don't care - Some(conn.into_live()) + Ok(conn.into_live()) } /// if `max_lifetime` or `idle_timeout` is set, spawn a task that reaps senescent connections @@ -329,11 +326,9 @@ fn spawn_reaper(pool: &Arc>) { sqlx_rt::spawn(async move { while !pool.is_closed() { - // only reap idle connections when no tasks are waiting - if pool.waiters.is_empty() { + if !pool.idle_conns.is_empty() { do_reap(&pool).await; } - sqlx_rt::sleep(period).await; } }); @@ -346,7 +341,7 @@ async fn do_reap(pool: &SharedPool) { // collect connections to reap let (reap, keep) = (0..max_reaped) // only connections waiting in the queue - .filter_map(|_| pool.pop_idle()) + .filter_map(|_| pool.try_acquire()) .partition::, _>(|conn| { is_beyond_idle(conn, &pool.options) || is_beyond_lifetime(conn, &pool.options) }); @@ -361,38 +356,44 @@ async fn do_reap(pool: &SharedPool) { } } -fn wake_one(waiters: &Waiters) { - while let Some(weak) = waiters.pop() { - if let Some(waiter) = weak.upgrade() { - if waiter.wake() { - return; - } - } - } -} - /// RAII guard returned by `Pool::try_increment_size()` and others. /// /// Will decrement the pool size if dropped, to avoid semantically "leaking" connections /// (where the pool thinks it has more connections than it does). pub(in crate::pool) struct DecrementSizeGuard<'a> { size: &'a AtomicU32, - waiters: &'a Waiters, + semaphore: &'a Semaphore, dropped: bool, } impl<'a> DecrementSizeGuard<'a> { - pub fn new(pool: &'a SharedPool) -> Self { + /// Create a new guard that will release a semaphore permit on-drop. 
+ pub fn new_permit(pool: &'a SharedPool) -> Self { Self { size: &pool.size, - waiters: &pool.waiters, + semaphore: &pool.semaphore, dropped: false, } } + pub fn from_permit( + pool: &'a SharedPool, + mut permit: SemaphoreReleaser<'a>, + ) -> Self { + // here we effectively take ownership of the permit + permit.disarm(); + Self::new_permit(pool) + } + /// Return `true` if the internal references point to the same fields in `SharedPool`. pub fn same_pool(&self, pool: &'a SharedPool) -> bool { - ptr::eq(self.size, &pool.size) && ptr::eq(self.waiters, &pool.waiters) + ptr::eq(self.size, &pool.size) + } + + /// Release the semaphore permit without decreasing the pool size. + fn release_permit(self) { + self.semaphore.release(1); + self.cancel(); } pub fn cancel(self) { @@ -405,73 +406,8 @@ impl Drop for DecrementSizeGuard<'_> { assert!(!self.dropped, "double-dropped!"); self.dropped = true; self.size.fetch_sub(1, Ordering::SeqCst); - wake_one(&self.waiters); - } -} - -struct WaiterInner { - woken: AtomicBool, - waker: Waker, -} - -impl WaiterInner { - /// Wake this waiter if it has not previously been woken. - /// - /// Return `true` if this waiter was newly woken, or `false` if it was already woken. 
- fn wake(&self) -> bool { - // if we were the thread to flip this boolean from false to true - if let Ok(_) = self - .woken - .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire) - { - self.waker.wake_by_ref(); - return true; - } - false - } -} - -struct Waiter<'a> { - inner: Arc, - queue: &'a Waiters, - actually_woke: bool, -} - -impl<'a> Waiter<'a> { - fn push_new(cx: &mut Context<'_>, queue: &'a Waiters) -> Self { - let inner = Arc::new(WaiterInner { - woken: AtomicBool::new(false), - waker: cx.waker().clone(), - }); - - queue.push(Arc::downgrade(&inner)); - - Self { - inner, - queue, - actually_woke: false, - } - } - - fn is_woken(&self) -> bool { - self.inner.woken.load(Ordering::Acquire) - } - - fn reset(&mut self) { - self.inner - .woken - .compare_exchange(true, false, Ordering::AcqRel, Ordering::Acquire) - .ok(); - self.actually_woke = false; - } -} - -impl Drop for Waiter<'_> { - fn drop(&mut self) { - // if we didn't actually wake to get a connection, wake the next task instead - if self.is_woken() && !self.actually_woke { - wake_one(self.queue); - } + // and here we release the permit we got on construction + self.semaphore.release(1); } } diff --git a/sqlx-core/src/pool/maybe.rs b/sqlx-core/src/pool/maybe.rs index 43c7f3457d..3064ba77f1 100644 --- a/sqlx-core/src/pool/maybe.rs +++ b/sqlx-core/src/pool/maybe.rs @@ -1,10 +1,13 @@ use crate::database::Database; + +#[cfg(not(target_arch = "wasm32"))] use crate::pool::PoolConnection; use std::ops::{Deref, DerefMut}; pub(crate) enum MaybePoolConnection<'c, DB: Database> { #[allow(dead_code)] Connection(&'c mut DB::Connection), + #[cfg(not(target_arch = "wasm32"))] PoolConnection(PoolConnection), } @@ -15,6 +18,7 @@ impl<'c, DB: Database> Deref for MaybePoolConnection<'c, DB> { fn deref(&self) -> &Self::Target { match self { MaybePoolConnection::Connection(v) => v, + #[cfg(not(target_arch = "wasm32"))] MaybePoolConnection::PoolConnection(v) => v, } } @@ -25,6 +29,7 @@ impl<'c, DB: Database> 
DerefMut for MaybePoolConnection<'c, DB> { fn deref_mut(&mut self) -> &mut Self::Target { match self { MaybePoolConnection::Connection(v) => v, + #[cfg(not(target_arch = "wasm32"))] MaybePoolConnection::PoolConnection(v) => v, } } @@ -33,6 +38,7 @@ impl<'c, DB: Database> DerefMut for MaybePoolConnection<'c, DB> { #[allow(unused_macros)] macro_rules! impl_into_maybe_pool { ($DB:ident, $C:ident) => { + #[cfg(not(target_arch = "wasm32"))] impl<'c> From> for crate::pool::MaybePoolConnection<'c, $DB> { diff --git a/sqlx-core/src/pool/mod.rs b/sqlx-core/src/pool/mod.rs index 10b9c17335..826e6534c6 100644 --- a/sqlx-core/src/pool/mod.rs +++ b/sqlx-core/src/pool/mod.rs @@ -54,28 +54,43 @@ //! [`Pool::acquire`] or //! [`Pool::begin`]. +#[cfg(not(target_arch = "wasm32"))] use self::inner::SharedPool; +#[cfg(not(target_arch = "wasm32"))] use crate::connection::Connection; +#[cfg(not(target_arch = "wasm32"))] use crate::database::Database; +#[cfg(not(target_arch = "wasm32"))] use crate::error::Error; +#[cfg(not(target_arch = "wasm32"))] use crate::transaction::Transaction; +#[cfg(not(target_arch = "wasm32"))] use std::fmt; +#[cfg(not(target_arch = "wasm32"))] use std::future::Future; +#[cfg(not(target_arch = "wasm32"))] use std::sync::Arc; +#[cfg(not(target_arch = "wasm32"))] use std::time::{Duration, Instant}; +#[cfg(not(target_arch = "wasm32"))] #[macro_use] mod executor; #[macro_use] mod maybe; +#[cfg(not(target_arch = "wasm32"))] mod connection; +#[cfg(not(target_arch = "wasm32"))] mod inner; +#[cfg(not(target_arch = "wasm32"))] mod options; +#[cfg(not(target_arch = "wasm32"))] pub use self::connection::PoolConnection; pub(crate) use self::maybe::MaybePoolConnection; +#[cfg(not(target_arch = "wasm32"))] pub use self::options::PoolOptions; /// An asynchronous pool of SQLx database connections. @@ -97,11 +112,23 @@ pub use self::options::PoolOptions; /// /// Calls to `acquire()` are fair, i.e. fulfilled on a first-come, first-serve basis. 
/// -/// `Pool` is `Send`, `Sync` and `Clone`, so it should be created once at the start of your -/// application/daemon/web server/etc. and then shared with all tasks throughout its lifetime. How -/// best to accomplish this depends on your program architecture. +/// `Pool` is `Send`, `Sync` and `Clone`. It is intended to be created once at the start of your +/// application/daemon/web server/etc. and then shared with all tasks throughout the process' +/// lifetime. How best to accomplish this depends on your program architecture. /// -/// In Actix-Web, you can share a single pool with all request handlers using [web::Data]. +/// In Actix-Web, for example, you can share a single pool with all request handlers using [web::Data]. +/// +/// Cloning `Pool` is cheap as it is simply a reference-counted handle to the inner pool state. +/// When the last remaining handle to the pool is dropped, the connections owned by the pool are +/// immediately closed (also by dropping). `PoolConnection` returned by [Pool::acquire] and +/// `Transaction` returned by [Pool::begin] both implicitly hold a reference to the pool for +/// their lifetimes. +/// +/// If you prefer to explicitly shutdown the pool and gracefully close its connections (which +/// depending on the database type, may include sending a message to the database server that the +/// connection is being closed), you can call [Pool::close] which causes all waiting and subsequent +/// calls to [Pool::acquire] to return [Error::PoolClosed], and waits until all connections have +/// been returned to the pool and gracefully closed. 
/// /// Type aliases are provided for each database to make it easier to sprinkle `Pool` through /// your codebase: @@ -111,7 +138,7 @@ pub use self::options::PoolOptions; /// * [PgPool][crate::postgres::PgPool] (PostgreSQL) /// * [SqlitePool][crate::sqlite::SqlitePool] (SQLite) /// -/// [web::Data]: https://docs.rs/actix-web/2.0.0/actix_web/web/struct.Data.html +/// [web::Data]: https://docs.rs/actix-web/3/actix_web/web/struct.Data.html /// /// ### Why Use a Pool? /// @@ -213,8 +240,11 @@ pub use self::options::PoolOptions; /// /// Depending on the database server, a connection will have caches for all kinds of other data as /// well and queries will generally benefit from these caches being "warm" (populated with data). + +#[cfg(not(target_arch = "wasm32"))] pub struct Pool(pub(crate) Arc>); +#[cfg(not(target_arch = "wasm32"))] impl Pool { /// Creates a new connection pool with a default pool configuration and /// the given connection URI; and, immediately establishes one connection. @@ -256,7 +286,9 @@ impl Pool { /// /// Returns `None` immediately if there are no idle connections available in the pool. pub fn try_acquire(&self) -> Option> { - self.0.try_acquire().map(|conn| conn.attach(&self.0)) + self.0 + .try_acquire() + .map(|conn| conn.into_live().attach(&self.0)) } /// Retrieves a new connection and immediately begins a new transaction. @@ -276,10 +308,29 @@ impl Pool { } } - /// Ends the use of a connection pool. Prevents any new connections - /// and will close all active connections when they are returned to the pool. + /// Shut down the connection pool, waiting for all connections to be gracefully closed. + /// + /// Upon `.await`ing this call, any currently waiting or subsequent calls to [Pool::acquire] and + /// the like will immediately return [Error::PoolClosed] and no new connections will be opened. 
+ /// + /// Any connections currently idle in the pool will be immediately closed, including sending + /// a graceful shutdown message to the database server, if applicable. + /// + /// Checked-out connections are unaffected, but will be closed in the same manner when they are + /// returned to the pool. + /// + /// Does not resolve until all connections are returned to the pool and gracefully closed. + /// + /// ### Note: `async fn` + /// Because this is an `async fn`, the pool will *not* be marked as closed unless the + /// returned future is polled at least once. /// - /// Does not resolve until all connections are closed. + /// If you want to close the pool but don't want to wait for all connections to be gracefully + /// closed, you can do `pool.close().now_or_never()`, which polls the future exactly once + /// with a no-op waker. + // TODO: I don't want to change the signature right now in case it turns out to be a + // breaking change, but this probably should eagerly mark the pool as closed and then the + // returned future only needs to be awaited to gracefully close the connections. pub async fn close(&self) { self.0.close().await; } @@ -305,12 +356,14 @@ impl Pool { } /// Returns a new [Pool] tied to the same shared connection pool. 
+#[cfg(not(target_arch = "wasm32"))] impl Clone for Pool { fn clone(&self) -> Self { Self(Arc::clone(&self.0)) } } +#[cfg(not(target_arch = "wasm32"))] impl fmt::Debug for Pool { fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { fmt.debug_struct("Pool") @@ -325,6 +378,7 @@ impl fmt::Debug for Pool { /// get the time between the deadline and now and use that as our timeout /// /// returns `Error::PoolTimedOut` if the deadline is in the past +#[cfg(not(target_arch = "wasm32"))] fn deadline_as_timeout(deadline: Instant) -> Result { deadline .checked_duration_since(Instant::now()) diff --git a/sqlx-core/src/pool/options.rs b/sqlx-core/src/pool/options.rs index a1b07f3721..32313808ff 100644 --- a/sqlx-core/src/pool/options.rs +++ b/sqlx-core/src/pool/options.rs @@ -231,19 +231,13 @@ impl PoolOptions { async fn init_min_connections(pool: &SharedPool) -> Result<(), Error> { for _ in 0..cmp::max(pool.options.min_connections, 1) { let deadline = Instant::now() + pool.options.connect_timeout; + let permit = pool.semaphore.acquire(1).await; // this guard will prevent us from exceeding `max_size` - if let Some(guard) = pool.try_increment_size() { + if let Ok(guard) = pool.try_increment_size(permit) { // [connect] will raise an error when past deadline let conn = pool.connection(deadline, guard).await?; - let is_ok = pool - .idle_conns - .push(conn.into_idle().into_leakable()) - .is_ok(); - - if !is_ok { - panic!("BUG: connection queue overflow in init_min_connections"); - } + pool.release(conn); } } diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index 097058cef2..a14e2a1a69 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -8,7 +8,10 @@ use crate::query_as::query_as; use crate::query_scalar::{query_scalar, query_scalar_with}; use crate::types::Json; use crate::HashMap; +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; 
+#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; use std::convert::TryFrom; use std::fmt::Write; use std::sync::Arc; @@ -399,13 +402,16 @@ SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 .fetch_all(&mut *self) .await?; - // patch up our null inference with data from EXPLAIN - let nullable_patch = self - .nullables_from_explain(stmt_id, meta.parameters.len()) - .await?; + // if it's cockroachdb skip this step #1248 + if !self.stream.parameter_statuses.contains_key("crdb_version") { + // patch up our null inference with data from EXPLAIN + let nullable_patch = self + .nullables_from_explain(stmt_id, meta.parameters.len()) + .await?; - for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { - *nullable = patch.or(*nullable); + for (nullable, patch) in nullables.iter_mut().zip(nullable_patch) { + *nullable = patch.or(*nullable); + } } Ok(nullables) diff --git a/sqlx-core/src/postgres/connection/establish.rs b/sqlx-core/src/postgres/connection/establish.rs index 59e3727c24..7bf760a2d5 100644 --- a/sqlx-core/src/postgres/connection/establish.rs +++ b/sqlx-core/src/postgres/connection/establish.rs @@ -3,7 +3,10 @@ use crate::HashMap; use crate::common::StatementCache; use crate::error::Error; use crate::io::Decode; -use crate::postgres::connection::{sasl, stream::PgStream, tls}; +#[cfg(not(target_arch = "wasm32"))] +use crate::postgres::connection::tls; +use crate::postgres::connection::{sasl, stream::PgStream}; + use crate::postgres::message::{ Authentication, BackendKeyData, MessageFormat, Password, ReadyForQuery, Startup, }; @@ -17,6 +20,7 @@ impl PgConnection { let mut stream = PgStream::connect(options).await?; // Upgrade to TLS if we were asked to and the server supports it + #[cfg(not(target_arch = "wasm32"))] tls::maybe_upgrade(&mut stream, options).await?; // To begin a session, a frontend opens a connection to the server diff --git a/sqlx-core/src/postgres/connection/executor.rs 
b/sqlx-core/src/postgres/connection/executor.rs index ee2ab76802..33d3948989 100644 --- a/sqlx-core/src/postgres/connection/executor.rs +++ b/sqlx-core/src/postgres/connection/executor.rs @@ -1,6 +1,7 @@ use crate::describe::Describe; use crate::error::Error; use crate::executor::{Execute, Executor}; +#[cfg(not(target_arch = "wasm32"))] use crate::logger::QueryLogger; use crate::postgres::message::{ self, Bind, Close, CommandComplete, DataRow, MessageFormat, ParameterDescription, Parse, Query, @@ -13,8 +14,17 @@ use crate::postgres::{ PgValueFormat, Postgres, }; use either::Either; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + use futures_core::Stream; use futures_util::{pin_mut, TryStreamExt}; use std::{borrow::Cow, sync::Arc}; @@ -199,6 +209,7 @@ impl PgConnection { persistent: bool, metadata_opt: Option>, ) -> Result, Error>> + 'e, Error> { + #[cfg(not(target_arch = "wasm32"))] let mut logger = QueryLogger::new(query, self.log_settings.clone()); // before we continue, wait until we are "ready" to accept more queries @@ -218,6 +229,10 @@ impl PgConnection { // patch holes created during encoding arguments.apply_patches(self, &metadata.parameters).await?; + // apply patches use fetch_optional thaht may produce `PortalSuspended` message, + // consume messages til `ReadyForQuery` before bind and execute + self.wait_until_ready().await?; + // bind to attach the arguments to the statement and create a portal self.stream.write(Bind { portal: None, @@ -297,6 +312,7 @@ impl PgConnection { } MessageFormat::DataRow => { + #[cfg(not(target_arch = "wasm32"))] logger.increment_rows(); // one of the set of rows returned by a SELECT, FETCH, etc query diff --git 
a/sqlx-core/src/postgres/connection/mod.rs b/sqlx-core/src/postgres/connection/mod.rs index e0238f5911..09b2c02bd2 100644 --- a/sqlx-core/src/postgres/connection/mod.rs +++ b/sqlx-core/src/postgres/connection/mod.rs @@ -2,7 +2,12 @@ use std::fmt::{self, Debug, Formatter}; use std::sync::Arc; use crate::HashMap; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + use futures_util::{FutureExt, TryFutureExt}; use crate::common::StatementCache; @@ -11,7 +16,6 @@ use crate::error::Error; use crate::executor::Executor; use crate::ext::ustr::UStr; use crate::io::Decode; -use crate::postgres::connection::stream::PgStream; use crate::postgres::message::{ Close, Message, MessageFormat, ReadyForQuery, Terminate, TransactionStatus, }; @@ -19,11 +23,15 @@ use crate::postgres::statement::PgStatementMetadata; use crate::postgres::{PgConnectOptions, PgTypeInfo, Postgres}; use crate::transaction::Transaction; +pub use self::stream::PgStream; + pub(crate) mod describe; mod establish; mod executor; mod sasl; mod stream; + +#[cfg(not(target_arch = "wasm32"))] mod tls; /// A connection to a PostgreSQL database. 
@@ -66,7 +74,7 @@ pub struct PgConnection { impl PgConnection { // will return when the connection is ready for another query - async fn wait_until_ready(&mut self) -> Result<(), Error> { + pub(in crate::postgres) async fn wait_until_ready(&mut self) -> Result<(), Error> { if !self.stream.wbuf.is_empty() { self.stream.flush().await?; } @@ -128,10 +136,16 @@ impl Connection for PgConnection { }) } + #[cfg(not(target_arch = "wasm32"))] fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { // By sending a comment we avoid an error if the connection was in the middle of a rowset self.execute("/* SQLx ping */").map_ok(|_| ()).boxed() } + #[cfg(target_arch = "wasm32")] + fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { + // By sending a comment we avoid an error if the connection was in the middle of a rowset + self.execute("/* SQLx ping */").map_ok(|_| ()).boxed_local() + } fn begin(&mut self) -> BoxFuture<'_, Result, Error>> where @@ -168,12 +182,35 @@ impl Connection for PgConnection { } #[doc(hidden)] + #[cfg(not(target_arch = "wasm32"))] fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { self.wait_until_ready().boxed() } + #[cfg(target_arch = "wasm32")] + fn flush(&mut self) -> BoxFuture<'_, Result<(), Error>> { + self.wait_until_ready().boxed_local() + } #[doc(hidden)] fn should_flush(&self) -> bool { !self.stream.wbuf.is_empty() } } + +pub trait PgConnectionInfo { + /// the version number of the server in `libpq` format + fn server_version_num(&self) -> Option; +} + +impl PgConnectionInfo for PgConnection { + fn server_version_num(&self) -> Option { + self.stream.server_version_num + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl PgConnectionInfo for crate::pool::PoolConnection { + fn server_version_num(&self) -> Option { + self.stream.server_version_num + } +} diff --git a/sqlx-core/src/postgres/connection/sasl.rs b/sqlx-core/src/postgres/connection/sasl.rs index 905afe974c..809c8ea170 100644 --- a/sqlx-core/src/postgres/connection/sasl.rs 
+++ b/sqlx-core/src/postgres/connection/sasl.rs @@ -98,7 +98,7 @@ pub(crate) async fn authenticate( )?; // ClientKey := HMAC(SaltedPassword, "Client Key") - let mut mac = Hmac::::new_varkey(&salted_password).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Client Key"); let client_key = mac.finalize().into_bytes(); @@ -122,7 +122,7 @@ pub(crate) async fn authenticate( ); // ClientSignature := HMAC(StoredKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&stored_key).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&stored_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); let client_signature = mac.finalize().into_bytes(); @@ -135,13 +135,13 @@ pub(crate) async fn authenticate( .collect(); // ServerKey := HMAC(SaltedPassword, "Server Key") - let mut mac = Hmac::::new_varkey(&salted_password).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&salted_password).map_err(Error::protocol)?; mac.update(b"Server Key"); let server_key = mac.finalize().into_bytes(); // ServerSignature := HMAC(ServerKey, AuthMessage) - let mut mac = Hmac::::new_varkey(&server_key).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(&server_key).map_err(Error::protocol)?; mac.update(&auth_message.as_bytes()); // client-final-message = client-final-message-without-proof "," proof @@ -197,7 +197,7 @@ fn gen_nonce() -> String { // Hi(str, salt, i): fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error> { - let mut mac = Hmac::::new_varkey(s.as_bytes()).map_err(Error::protocol)?; + let mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(&salt); mac.update(&1u32.to_be_bytes()); @@ -206,7 +206,7 @@ fn hi<'a>(s: &'a str, salt: &'a [u8], iter_count: u32) -> Result<[u8; 32], Error let mut hi = u; for _ in 1..iter_count { - let mut mac = Hmac::::new_varkey(s.as_bytes()).map_err(Error::protocol)?; + let 
mut mac = Hmac::::new_from_slice(s.as_bytes()).map_err(Error::protocol)?; mac.update(u.as_slice()); u = mac.finalize().into_bytes(); hi = hi.iter().zip(u.iter()).map(|(&a, &b)| a ^ b).collect(); diff --git a/sqlx-core/src/postgres/connection/stream.rs b/sqlx-core/src/postgres/connection/stream.rs index f61bda62b8..2b3d8c3b70 100644 --- a/sqlx-core/src/postgres/connection/stream.rs +++ b/sqlx-core/src/postgres/connection/stream.rs @@ -1,4 +1,6 @@ +use std::collections::BTreeMap; use std::ops::{Deref, DerefMut}; +use std::str::FromStr; use bytes::{Buf, Bytes}; use futures_channel::mpsc::UnboundedSender; @@ -7,8 +9,13 @@ use log::Level; use crate::error::Error; use crate::io::{BufStream, Decode, Encode}; -use crate::net::{MaybeTlsStream, Socket}; -use crate::postgres::message::{Message, MessageFormat, Notice, Notification}; + +#[cfg(not(target_arch = "wasm32"))] +use crate::net::MaybeTlsStream; + +use crate::net::Socket; + +use crate::postgres::message::{Message, MessageFormat, Notice, Notification, ParameterStatus}; use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity}; // the stream is a separate type from the connection to uphold the invariant where an instantiated @@ -21,27 +28,53 @@ use crate::postgres::{PgConnectOptions, PgDatabaseError, PgSeverity}; // is fully prepared to receive queries pub struct PgStream { + #[cfg(not(target_arch = "wasm32"))] inner: BufStream>, - + #[cfg(target_arch = "wasm32")] + inner: BufStream, // buffer of unreceived notification messages from `PUBLISH` // this is set when creating a PgListener and only written to if that listener is // re-used for query execution in-between receiving messages pub(crate) notifications: Option>, + + pub(crate) parameter_statuses: BTreeMap, + + pub(crate) server_version_num: Option, } impl PgStream { pub(super) async fn connect(options: &PgConnectOptions) -> Result { - let socket = match options.fetch_socket() { - Some(ref path) => Socket::connect_uds(path).await?, - None => 
Socket::connect_tcp(&options.host, options.port).await?, - }; - - let inner = BufStream::new(MaybeTlsStream::Raw(socket)); + #[cfg(target_arch = "wasm32")] + { + let socket = match options.fetch_socket() { + Some(ref path) => Socket::connect_ws(path).await?, + None => return Err(Error::Configuration("no ws url set".into())), + }; + let inner = BufStream::new(socket); + + Ok(Self { + inner, + notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, + }) + } - Ok(Self { - inner, - notifications: None, - }) + #[cfg(not(target_arch = "wasm32"))] + { + let socket = match options.fetch_socket() { + Some(ref path) => Socket::connect_uds(path).await?, + None => Socket::connect_tcp(&options.host, options.port).await?, + }; + let inner = BufStream::new(MaybeTlsStream::Raw(socket)); + + Ok(Self { + inner, + notifications: None, + parameter_statuses: BTreeMap::default(), + server_version_num: None, + }) + } } pub(crate) async fn send<'en, T>(&mut self, message: T) -> Result<(), Error> @@ -108,7 +141,18 @@ impl PgStream { // informs the frontend about the current (initial) // setting of backend parameters - // we currently have no use for that data so we promptly ignore this message + let ParameterStatus { name, value } = message.decode()?; + // TODO: handle `client_encoding`, `DateStyle` change + + match name.as_str() { + "server_version" => { + self.server_version_num = parse_server_version(&value); + } + _ => { + self.parameter_statuses.insert(name, value); + } + } + continue; } @@ -151,7 +195,10 @@ impl PgStream { } impl Deref for PgStream { + #[cfg(not(target_arch = "wasm32"))] type Target = BufStream>; + #[cfg(target_arch = "wasm32")] + type Target = BufStream; #[inline] fn deref(&self) -> &Self::Target { @@ -165,3 +212,68 @@ impl DerefMut for PgStream { &mut self.inner } } + +// reference: +// https://github.com/postgres/postgres/blob/6feebcb6b44631c3dc435e971bd80c2dd218a5ab/src/interfaces/libpq/fe-exec.c#L1030-L1065 +fn 
parse_server_version(s: &str) -> Option { + let mut parts = Vec::::with_capacity(3); + + let mut from = 0; + let mut chs = s.char_indices().peekable(); + while let Some((i, ch)) = chs.next() { + match ch { + '.' => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + from = i + 1; + } else { + break; + } + } + _ if ch.is_digit(10) => { + if chs.peek().is_none() { + if let Ok(num) = u32::from_str(&s[from..]) { + parts.push(num); + } + break; + } + } + _ => { + if let Ok(num) = u32::from_str(&s[from..i]) { + parts.push(num); + } + break; + } + }; + } + + let version_num = match parts.as_slice() { + [major, minor, rev] => (100 * major + minor) * 100 + rev, + [major, minor] if *major >= 10 => 100 * 100 * major + minor, + [major, minor] => (100 * major + minor) * 100, + [major] => 100 * 100 * major, + _ => return None, + }; + + Some(version_num) +} + +#[cfg(test)] +mod tests { + use super::parse_server_version; + + #[test] + fn test_parse_server_version_num() { + // old style + assert_eq!(parse_server_version("9.6.1"), Some(90601)); + // new style + assert_eq!(parse_server_version("10.1"), Some(100001)); + // old style without minor version + assert_eq!(parse_server_version("9.6devel"), Some(90600)); + // new style without minor version, e.g. 
*/ + assert_eq!(parse_server_version("10devel"), Some(100000)); + assert_eq!(parse_server_version("13devel87"), Some(130000)); + // unknown + assert_eq!(parse_server_version("unknown"), None); + } +} diff --git a/sqlx-core/src/postgres/copy.rs b/sqlx-core/src/postgres/copy.rs new file mode 100644 index 0000000000..ddff8ce5d0 --- /dev/null +++ b/sqlx-core/src/postgres/copy.rs @@ -0,0 +1,342 @@ +use crate::error::{Error, Result}; +use crate::ext::async_stream::TryAsyncStream; +#[cfg(not(target_arch = "wasm32"))] +use crate::pool::{Pool, PoolConnection}; +use crate::postgres::connection::PgConnection; +use crate::postgres::message::{ + CommandComplete, CopyData, CopyDone, CopyFail, CopyResponse, MessageFormat, Query, +}; +#[cfg(not(target_arch = "wasm32"))] +use crate::postgres::Postgres; +use bytes::{BufMut, Bytes}; + +#[cfg(not(target_arch = "wasm32"))] +use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + +use smallvec::alloc::borrow::Cow; +use sqlx_rt::{AsyncRead, AsyncReadExt, AsyncWriteExt}; +use std::convert::TryFrom; +use std::ops::{Deref, DerefMut}; + +impl PgConnection { + /// Issue a `COPY FROM STDIN` statement and transition the connection to streaming data + /// to Postgres. This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. 
+ pub async fn copy_in_raw(&mut self, statement: &str) -> Result> { + PgCopyIn::begin(self, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and transition the connection to streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + #[allow(clippy::needless_lifetimes)] + pub async fn copy_out_raw<'c>( + &'c mut self, + statement: &str, + ) -> Result>> { + pg_begin_copy_out(self, statement).await + } +} + +#[cfg(not(target_arch = "wasm32"))] +impl Pool { + /// Issue a `COPY FROM STDIN` statement and begin streaming data to Postgres. + /// This is a more efficient way to import data into Postgres as compared to + /// `INSERT` but requires one of a few specific data formats (text/CSV/binary). + /// + /// A single connection will be checked out for the duration. + /// + /// If `statement` is anything other than a `COPY ... FROM STDIN ...` command, an error is + /// returned. 
+ /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + /// + /// ### Note + /// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection + /// will return an error the next time it is used. + pub async fn copy_in_raw(&self, statement: &str) -> Result>> { + PgCopyIn::begin(self.acquire().await?, statement).await + } + + /// Issue a `COPY TO STDOUT` statement and begin streaming data + /// from Postgres. This is a more efficient way to export data from Postgres but + /// arrives in chunks of one of a few data formats (text/CSV/binary). + /// + /// If `statement` is anything other than a `COPY ... TO STDOUT ...` command, + /// an error is returned. + /// + /// Note that once this process has begun, unless you read the stream to completion, + /// it can only be canceled in two ways: + /// + /// 1. by closing the connection, or: + /// 2. by using another connection to kill the server process that is sending the data as shown + /// [in this StackOverflow answer](https://stackoverflow.com/a/35319598). + /// + /// If you don't read the stream to completion, the next time the connection is used it will + /// need to read and discard all the remaining queued data, which could take some time. + /// + /// Command examples and accepted formats for `COPY` data are shown here: + /// https://www.postgresql.org/docs/current/sql-copy.html + pub async fn copy_out_raw(&self, statement: &str) -> Result>> { + pg_begin_copy_out(self.acquire().await?, statement).await + } +} + +/// A connection in streaming `COPY FROM STDIN` mode. +/// +/// Created by [PgConnection::copy_in_raw] or [Pool::copy_out_raw]. +/// +/// ### Note +/// [PgCopyIn::finish] or [PgCopyIn::abort] *must* be called when finished or the connection +/// will return an error the next time it is used. 
+#[must_use = "connection will error on next use if `.finish()` or `.abort()` is not called"] +pub struct PgCopyIn> { + conn: Option, + response: CopyResponse, +} + +impl> PgCopyIn { + async fn begin(mut conn: C, statement: &str) -> Result { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let response: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyInResponse) + .await?; + + Ok(PgCopyIn { + conn: Some(conn), + response, + }) + } + + /// Returns `true` if Postgres is expecting data in text or CSV format. + pub fn is_textual(&self) -> bool { + self.response.format == 0 + } + + /// Returns the number of columns expected in the input. + pub fn num_columns(&self) -> usize { + assert_eq!( + self.response.num_columns as usize, + self.response.format_codes.len(), + "num_columns does not match format_codes.len()" + ); + self.response.format_codes.len() + } + + /// Check if a column is expecting data in text format (`true`) or binary format (`false`). + /// + /// ### Panics + /// If `column` is out of range according to [`.num_columns()`][Self::num_columns]. + pub fn column_is_textual(&self, column: usize) -> bool { + self.response.format_codes[column] == 0 + } + + /// Send a chunk of `COPY` data. + /// + /// If you're copying data from an `AsyncRead`, maybe consider [Self::read_from] instead. + pub async fn send(&mut self, data: impl Deref) -> Result<&mut Self> { + self.conn + .as_deref_mut() + .expect("send_data: conn taken") + .stream + .send(CopyData(data)) + .await?; + + Ok(self) + } + + /// Copy data directly from `source` to the database without requiring an intermediate buffer. + /// + /// `source` will be read to the end. + /// + /// ### Note + /// You must still call either [Self::finish] or [Self::abort] to complete the process. 
+ pub async fn read_from(&mut self, mut source: impl AsyncRead + Unpin) -> Result<&mut Self> { + // this is a separate guard from WriteAndFlush so we can reuse the buffer without zeroing + struct BufGuard<'s>(&'s mut Vec); + + impl Drop for BufGuard<'_> { + fn drop(&mut self) { + self.0.clear() + } + } + + let conn: &mut PgConnection = self.conn.as_deref_mut().expect("copy_from: conn taken"); + + // flush any existing messages in the buffer and clear it + conn.stream.flush().await?; + + { + let buf_stream = &mut *conn.stream; + let stream = &mut buf_stream.stream; + + // ensures the buffer isn't left in an inconsistent state + let mut guard = BufGuard(&mut buf_stream.wbuf); + + let buf: &mut Vec = &mut guard.0; + buf.push(b'd'); // CopyData format code + buf.resize(5, 0); // reserve space for the length + + loop { + let read = match () { + // Tokio lets us read into the buffer without zeroing first + #[cfg(any(feature = "runtime-tokio", feature = "runtime-actix"))] + _ if buf.len() != buf.capacity() => { + // in case we have some data in the buffer, which can occur + // if the previous write did not fill the buffer + buf.truncate(5); + source.read_buf(buf).await? + } + _ => { + // should be a no-op unless len != capacity + buf.resize(buf.capacity(), 0); + source.read(&mut buf[5..]).await? + } + }; + + if read == 0 { + break; + } + + let read32 = u32::try_from(read) + .map_err(|_| err_protocol!("number of bytes read exceeds 2^32: {}", read))?; + + (&mut buf[1..]).put_u32(read32 + 4); + + stream.write_all(&buf[..read + 5]).await?; + stream.flush().await?; + } + } + + Ok(self) + } + + /// Signal that the `COPY` process should be aborted and any data received should be discarded. + /// + /// The given message can be used for indicating the reason for the abort in the database logs. + /// + /// The server is expected to respond with an error, so only _unexpected_ errors are returned. 
+ pub async fn abort(mut self, msg: impl Into) -> Result<()> { + let mut conn = self + .conn + .take() + .expect("PgCopyIn::fail_with: conn taken illegally"); + + conn.stream.send(CopyFail::new(msg)).await?; + + match conn.stream.recv().await { + Ok(msg) => Err(err_protocol!( + "fail_with: expected ErrorResponse, got: {:?}", + msg.format + )), + Err(Error::Database(e)) => { + match e.code() { + Some(Cow::Borrowed("57014")) => { + // postgres abort received error code + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + Ok(()) + } + _ => Err(Error::Database(e)), + } + } + Err(e) => Err(e), + } + } + + /// Signal that the `COPY` process is complete. + /// + /// The number of rows affected is returned. + pub async fn finish(mut self) -> Result { + let mut conn = self + .conn + .take() + .expect("CopyWriter::finish: conn taken illegally"); + + conn.stream.send(CopyDone).await?; + let cc: CommandComplete = conn + .stream + .recv_expect(MessageFormat::CommandComplete) + .await?; + + conn.stream + .recv_expect(MessageFormat::ReadyForQuery) + .await?; + + Ok(cc.rows_affected()) + } +} + +impl> Drop for PgCopyIn { + fn drop(&mut self) { + if let Some(mut conn) = self.conn.take() { + conn.stream.write(CopyFail::new( + "PgCopyIn dropped without calling finish() or fail()", + )); + } + } +} + +async fn pg_begin_copy_out<'c, C: DerefMut + Send + 'c>( + mut conn: C, + statement: &str, +) -> Result>> { + conn.wait_until_ready().await?; + conn.stream.send(Query(statement)).await?; + + let _: CopyResponse = conn + .stream + .recv_expect(MessageFormat::CopyOutResponse) + .await?; + + let stream: TryAsyncStream<'c, Bytes> = try_stream! 
{ + loop { + let msg = conn.stream.recv().await?; + match msg.format { + MessageFormat::CopyData => r#yield!(msg.decode::>()?.0), + MessageFormat::CopyDone => { + let _ = msg.decode::()?; + conn.stream.recv_expect(MessageFormat::CommandComplete).await?; + conn.stream.recv_expect(MessageFormat::ReadyForQuery).await?; + return Ok(()) + }, + _ => return Err(err_protocol!("unexpected message format during copy out: {:?}", msg.format)) + } + } + }; + + Ok(Box::pin(stream)) +} diff --git a/sqlx-core/src/postgres/listener.rs b/sqlx-core/src/postgres/listener.rs index 82c0460f4b..36a1cb5f5c 100644 --- a/sqlx-core/src/postgres/listener.rs +++ b/sqlx-core/src/postgres/listener.rs @@ -260,10 +260,23 @@ impl PgListener { impl Drop for PgListener { fn drop(&mut self) { if let Some(mut conn) = self.connection.take() { - // Unregister any listeners before returning the connection to the pool. - sqlx_rt::spawn(async move { + let fut = async move { let _ = conn.execute("UNLISTEN *").await; - }); + + // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task + // otherwise, it may trigger a panic if this task is dropped because the runtime is going away: + // https://github.com/launchbadge/sqlx/issues/1389 + conn.return_to_pool().await; + }; + + // Unregister any listeners before returning the connection to the pool. 
+ #[cfg(not(feature = "_rt-async-std"))] + if let Ok(handle) = sqlx_rt::Handle::try_current() { + handle.spawn(fut); + } + + #[cfg(feature = "_rt-async-std")] + sqlx_rt::spawn(fut); } } } diff --git a/sqlx-core/src/postgres/message/copy.rs b/sqlx-core/src/postgres/message/copy.rs new file mode 100644 index 0000000000..58553d431b --- /dev/null +++ b/sqlx-core/src/postgres/message/copy.rs @@ -0,0 +1,96 @@ +use crate::error::Result; +use crate::io::{BufExt, BufMutExt, Decode, Encode}; +use bytes::{Buf, BufMut, Bytes}; +use std::ops::Deref; + +/// The same structure is sent for both `CopyInResponse` and `CopyOutResponse` +pub struct CopyResponse { + pub format: i8, + pub num_columns: i16, + pub format_codes: Vec, +} + +pub struct CopyData(pub B); + +pub struct CopyFail { + pub message: String, +} + +pub struct CopyDone; + +impl Decode<'_> for CopyResponse { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let format = buf.get_i8(); + let num_columns = buf.get_i16(); + + let format_codes = (0..num_columns).map(|_| buf.get_i16()).collect(); + + Ok(CopyResponse { + format, + num_columns, + format_codes, + }) + } +} + +impl Decode<'_> for CopyData { + fn decode_with(buf: Bytes, _: ()) -> Result { + // well.. 
that was easy + Ok(CopyData(buf)) + } +} + +impl> Encode<'_> for CopyData { + fn encode_with(&self, buf: &mut Vec, _context: ()) { + buf.push(b'd'); + buf.put_u32(self.0.len() as u32 + 4); + buf.extend_from_slice(&self.0); + } +} + +impl Decode<'_> for CopyFail { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + Ok(CopyFail { + message: buf.get_str_nul()?, + }) + } +} + +impl Encode<'_> for CopyFail { + fn encode_with(&self, buf: &mut Vec, _: ()) { + let len = 4 + self.message.len() + 1; + + buf.push(b'f'); // to pay respects + buf.put_u32(len as u32); + buf.put_str_nul(&self.message); + } +} + +impl CopyFail { + pub fn new(msg: impl Into) -> CopyFail { + CopyFail { + message: msg.into(), + } + } +} + +impl Decode<'_> for CopyDone { + fn decode_with(buf: Bytes, _: ()) -> Result { + if buf.is_empty() { + Ok(CopyDone) + } else { + Err(err_protocol!( + "expected no data for CopyDone, got: {:?}", + buf + )) + } + } +} + +impl Encode<'_> for CopyDone { + fn encode_with(&self, buf: &mut Vec, _: ()) { + buf.reserve(4); + buf.push(b'c'); + buf.put_u32(4); + } +} diff --git a/sqlx-core/src/postgres/message/mod.rs b/sqlx-core/src/postgres/message/mod.rs index 6c8d1f3023..1261bff339 100644 --- a/sqlx-core/src/postgres/message/mod.rs +++ b/sqlx-core/src/postgres/message/mod.rs @@ -8,12 +8,14 @@ mod backend_key_data; mod bind; mod close; mod command_complete; +mod copy; mod data_row; mod describe; mod execute; mod flush; mod notification; mod parameter_description; +mod parameter_status; mod parse; mod password; mod query; @@ -31,12 +33,14 @@ pub use backend_key_data::BackendKeyData; pub use bind::Bind; pub use close::Close; pub use command_complete::CommandComplete; +pub use copy::{CopyData, CopyDone, CopyFail, CopyResponse}; pub use data_row::DataRow; pub use describe::Describe; pub use execute::Execute; pub use flush::Flush; pub use notification::Notification; pub use parameter_description::ParameterDescription; +pub use parameter_status::ParameterStatus; pub use 
parse::Parse; pub use password::Password; pub use query::Query; @@ -57,6 +61,10 @@ pub enum MessageFormat { BindComplete, CloseComplete, CommandComplete, + CopyData, + CopyDone, + CopyInResponse, + CopyOutResponse, DataRow, EmptyQueryResponse, ErrorResponse, @@ -96,6 +104,10 @@ impl MessageFormat { b'2' => MessageFormat::BindComplete, b'3' => MessageFormat::CloseComplete, b'C' => MessageFormat::CommandComplete, + b'd' => MessageFormat::CopyData, + b'c' => MessageFormat::CopyDone, + b'G' => MessageFormat::CopyInResponse, + b'H' => MessageFormat::CopyOutResponse, b'D' => MessageFormat::DataRow, b'E' => MessageFormat::ErrorResponse, b'I' => MessageFormat::EmptyQueryResponse, diff --git a/sqlx-core/src/postgres/message/parameter_status.rs b/sqlx-core/src/postgres/message/parameter_status.rs new file mode 100644 index 0000000000..ffd0ef1b60 --- /dev/null +++ b/sqlx-core/src/postgres/message/parameter_status.rs @@ -0,0 +1,62 @@ +use bytes::Bytes; + +use crate::error::Error; +use crate::io::{BufExt, Decode}; + +#[derive(Debug)] +pub struct ParameterStatus { + pub name: String, + pub value: String, +} + +impl Decode<'_> for ParameterStatus { + fn decode_with(mut buf: Bytes, _: ()) -> Result { + let name = buf.get_str_nul()?; + let value = buf.get_str_nul()?; + + Ok(Self { name, value }) + } +} + +#[test] +fn test_decode_parameter_status() { + const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; + + let m = ParameterStatus::decode(DATA.into()).unwrap(); + + assert_eq!(&m.name, "client_encoding"); + assert_eq!(&m.value, "UTF8") +} + +#[test] +fn test_decode_empty_parameter_status() { + const DATA: &[u8] = b"\x00\x00"; + + let m = ParameterStatus::decode(DATA.into()).unwrap(); + + assert!(m.name.is_empty()); + assert!(m.value.is_empty()); +} + +#[cfg(all(test, not(debug_assertions)))] +#[bench] +fn bench_decode_parameter_status(b: &mut test::Bencher) { + const DATA: &[u8] = b"client_encoding\x00UTF8\x00"; + + b.iter(|| { + 
ParameterStatus::decode(test::black_box(Bytes::from_static(DATA))).unwrap(); + }); +} + +#[test] +fn test_decode_parameter_status_response() { + const PARAMETER_STATUS_RESPONSE: &[u8] = b"crdb_version\0CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)\0"; + + let message = ParameterStatus::decode(Bytes::from(PARAMETER_STATUS_RESPONSE)).unwrap(); + + assert_eq!(message.name, "crdb_version"); + assert_eq!( + message.value, + "CockroachDB CCL v21.1.0 (x86_64-unknown-linux-gnu, built 2021/05/17 13:49:40, go1.15.11)" + ); +} diff --git a/sqlx-core/src/postgres/migrate.rs b/sqlx-core/src/postgres/migrate.rs index 142918a69e..13bd2e3694 100644 --- a/sqlx-core/src/postgres/migrate.rs +++ b/sqlx-core/src/postgres/migrate.rs @@ -8,7 +8,6 @@ use crate::postgres::{PgConnectOptions, PgConnection, Postgres}; use crate::query::query; use crate::query_as::query_as; use crate::query_scalar::query_scalar; -use crc::crc32; use futures_core::future::BoxFuture; use std::str::FromStr; use std::time::Duration; @@ -25,9 +24,9 @@ fn parse_for_maintenance(uri: &str) -> Result<(PgConnectOptions, String), Error> .to_owned(); // switch us to the maintenance database - // use `postgres` _unless_ the current user is postgres, in which case, use `template1` + // use `postgres` _unless_ the database is postgres, in which case, use `template1` // this matches the behavior of the `createdb` util - options.database = if options.username == "postgres" { + options.database = if database == "postgres" { Some("template1".into()) } else { Some("postgres".into()) @@ -281,6 +280,7 @@ async fn current_database(conn: &mut PgConnection) -> Result i64 { + const CRC_IEEE: crc::Crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC); // 0x3d32ad9e chosen by fair dice roll - 0x3d32ad9e * (crc32::checksum_ieee(database_name.as_bytes()) as i64) + 0x3d32ad9e * (CRC_IEEE.checksum(database_name.as_bytes()) as i64) } diff --git a/sqlx-core/src/postgres/mod.rs b/sqlx-core/src/postgres/mod.rs 
index bd72bc6e86..8482315d23 100644 --- a/sqlx-core/src/postgres/mod.rs +++ b/sqlx-core/src/postgres/mod.rs @@ -1,12 +1,20 @@ //! **PostgreSQL** database driver. +use crate::executor::Executor; + mod arguments; mod column; mod connection; +mod copy; mod database; mod error; mod io; + +#[cfg(not(target_arch = "wasm32"))] mod listener; +#[cfg(target_arch = "wasm32")] +mod ws_listener; + mod message; mod options; mod query_result; @@ -17,15 +25,21 @@ mod type_info; pub mod types; mod value; -#[cfg(feature = "migrate")] +#[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] mod migrate; pub use arguments::{PgArgumentBuffer, PgArguments}; pub use column::PgColumn; -pub use connection::PgConnection; +pub use connection::{PgConnection, PgConnectionInfo}; +pub use copy::PgCopyIn; pub use database::Postgres; pub use error::{PgDatabaseError, PgErrorPosition}; + +#[cfg(not(target_arch = "wasm32"))] pub use listener::{PgListener, PgNotification}; +#[cfg(target_arch = "wasm32")] +pub use ws_listener::PgListener; + pub use message::PgSeverity; pub use options::{PgConnectOptions, PgSslMode}; pub use query_result::PgQueryResult; @@ -36,14 +50,24 @@ pub use type_info::{PgTypeInfo, PgTypeKind}; pub use value::{PgValue, PgValueFormat, PgValueRef}; /// An alias for [`Pool`][crate::pool::Pool], specialized for Postgres. +#[cfg(not(target_arch = "wasm32"))] pub type PgPool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for Postgres. +#[cfg(not(target_arch = "wasm32"))] pub type PgPoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Postgres>`][Executor]. 
+pub trait PgExecutor<'c>: Executor<'c, Database = Postgres> {} +impl<'c, T: Executor<'c, Database = Postgres>> PgExecutor<'c> for T {} + impl_into_arguments_for_arguments!(PgArguments); + +#[cfg(not(target_arch = "wasm32"))] impl_executor_for_pool_connection!(Postgres, PgConnection, PgRow); impl_executor_for_transaction!(Postgres, PgRow); + +#[cfg(not(target_arch = "wasm32"))] impl_acquire!(Postgres, PgConnection); impl_column_index_for_row!(PgRow); impl_column_index_for_statement!(PgStatement); diff --git a/sqlx-core/src/postgres/options/connect.rs b/sqlx-core/src/postgres/options/connect.rs index 5c98598dd6..c9f4b02344 100644 --- a/sqlx-core/src/postgres/options/connect.rs +++ b/sqlx-core/src/postgres/options/connect.rs @@ -1,7 +1,12 @@ use crate::connection::ConnectOptions; use crate::error::Error; use crate::postgres::{PgConnectOptions, PgConnection}; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; + use log::LevelFilter; use std::time::Duration; diff --git a/sqlx-core/src/postgres/options/mod.rs b/sqlx-core/src/postgres/options/mod.rs index 1770959fdf..37c85a768e 100644 --- a/sqlx-core/src/postgres/options/mod.rs +++ b/sqlx-core/src/postgres/options/mod.rs @@ -5,7 +5,12 @@ mod connect; mod parse; mod pgpass; mod ssl_mode; -use crate::{connection::LogSettings, net::CertificateInput}; + +#[cfg(not(target_arch = "wasm32"))] +use crate::net::CertificateInput; + +use crate::connection::LogSettings; + pub use ssl_mode::PgSslMode; /// Options and flags which can be used to configure a PostgreSQL connection. 
@@ -81,10 +86,14 @@ pub struct PgConnectOptions { pub(crate) password: Option, pub(crate) database: Option, pub(crate) ssl_mode: PgSslMode, + #[cfg(not(target_arch = "wasm32"))] pub(crate) ssl_root_cert: Option, pub(crate) statement_cache_capacity: usize, pub(crate) application_name: Option, pub(crate) log_settings: LogSettings, + + #[cfg(target_arch = "wasm32")] + pub(crate) ws_url: Option, } impl Default for PgConnectOptions { @@ -137,6 +146,8 @@ impl PgConnectOptions { username, password, database, + + #[cfg(not(target_arch = "wasm32"))] ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from), ssl_mode: var("PGSSLMODE") .ok() @@ -145,6 +156,8 @@ impl PgConnectOptions { statement_cache_capacity: 100, application_name: var("PGAPPNAME").ok(), log_settings: Default::default(), + #[cfg(target_arch = "wasm32")] + ws_url: None, } } @@ -240,6 +253,14 @@ impl PgConnectOptions { self } + /// Sets the websocket url. + /// + #[cfg(target_arch = "wasm32")] + pub fn ws_url(mut self) -> Self { + self.ws_url = Some(format!("ws://{}:{}", self.host, self.port)); + self + } + /// Sets whether or with what priority a secure SSL TCP/IP connection will be negotiated /// with the server. /// @@ -273,6 +294,7 @@ impl PgConnectOptions { /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_root_cert("./ca-certificate.crt"); /// ``` + #[cfg(not(target_arch = "wasm32"))] pub fn ssl_root_cert(mut self, cert: impl AsRef) -> Self { self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf())); self @@ -289,6 +311,7 @@ impl PgConnectOptions { /// .ssl_mode(PgSslMode::VerifyCa) /// .ssl_root_cert_from_pem(vec![]); /// ``` + #[cfg(not(target_arch = "wasm32"))] pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec) -> Self { self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate)); self @@ -322,16 +345,24 @@ impl PgConnectOptions { /// We try using a socket if hostname starts with `/` or if socket parameter /// is specified. 
pub(crate) fn fetch_socket(&self) -> Option { - match self.socket { - Some(ref socket) => { - let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port); - Some(full_path) - } - None if self.host.starts_with('/') => { - let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port); - Some(full_path) + #[cfg(target_arch = "wasm32")] + { + self.ws_url.as_ref().cloned() + } + + #[cfg(not(target_arch = "wasm32"))] + { + match self.socket { + Some(ref socket) => { + let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port); + Some(full_path) + } + None if self.host.starts_with('/') => { + let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port); + Some(full_path) + } + _ => None, } - _ => None, } } } diff --git a/sqlx-core/src/postgres/options/parse.rs b/sqlx-core/src/postgres/options/parse.rs index 5c5cd71ee8..8971283a0d 100644 --- a/sqlx-core/src/postgres/options/parse.rs +++ b/sqlx-core/src/postgres/options/parse.rs @@ -25,6 +25,11 @@ impl FromStr for PgConnectOptions { options = options.port(port); } + #[cfg(target_arch = "wasm32")] + { + options = options.ws_url(); + } + let username = url.username(); if !username.is_empty() { options = options.username( @@ -53,6 +58,7 @@ impl FromStr for PgConnectOptions { options = options.ssl_mode(value.parse().map_err(Error::config)?); } + #[cfg(not(target_arch = "wasm32"))] "sslrootcert" | "ssl-root-cert" | "ssl-ca" => { options = options.ssl_root_cert(&*value); } diff --git a/sqlx-core/src/postgres/transaction.rs b/sqlx-core/src/postgres/transaction.rs index efb11b8223..e88bec3386 100644 --- a/sqlx-core/src/postgres/transaction.rs +++ b/sqlx-core/src/postgres/transaction.rs @@ -1,4 +1,7 @@ +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; use crate::error::Error; use crate::executor::Executor; diff --git a/sqlx-core/src/postgres/type_info.rs b/sqlx-core/src/postgres/type_info.rs 
index 6f85364a85..37c018f798 100644 --- a/sqlx-core/src/postgres/type_info.rs +++ b/sqlx-core/src/postgres/type_info.rs @@ -198,6 +198,8 @@ impl PgTypeInfo { .contains(self) { Some("ipnetwork") + } else if [PgTypeInfo::MACADDR].contains(self) { + Some("mac_address") } else if [PgTypeInfo::NUMERIC, PgTypeInfo::NUMERIC_ARRAY].contains(self) { Some("bigdecimal") } else { @@ -740,8 +742,11 @@ impl PgType { PgType::Custom(ty) => &ty.kind, - PgType::DeclareWithOid(_) | PgType::DeclareWithName(_) => { - unreachable!("(bug) use of unresolved type declaration [kind]") + PgType::DeclareWithOid(oid) => { + unreachable!("(bug) use of unresolved type declaration [oid={}]", oid); + } + PgType::DeclareWithName(name) => { + unreachable!("(bug) use of unresolved type declaration [name={}]", name); } } } diff --git a/sqlx-core/src/postgres/types/decimal.rs b/sqlx-core/src/postgres/types/decimal.rs index e206b86b04..61ca06fcb7 100644 --- a/sqlx-core/src/postgres/types/decimal.rs +++ b/sqlx-core/src/postgres/types/decimal.rs @@ -88,7 +88,8 @@ impl TryFrom<&'_ Decimal> for PgNumeric { type Error = BoxDynError; fn try_from(decimal: &Decimal) -> Result { - if decimal.is_zero() { + // `Decimal` added `is_zero()` as an inherent method in a more recent version + if Zero::is_zero(decimal) { return Ok(PgNumeric::Number { sign: PgNumericSign::Positive, scale: 0, diff --git a/sqlx-core/src/postgres/types/interval.rs b/sqlx-core/src/postgres/types/interval.rs index 8a5307adac..42805a56b5 100644 --- a/sqlx-core/src/postgres/types/interval.rs +++ b/sqlx-core/src/postgres/types/interval.rs @@ -148,9 +148,32 @@ impl TryFrom for PgInterval { /// Convert a `chrono::Duration` to a `PgInterval`. /// /// This returns an error if there is a loss of precision using nanoseconds or if there is a - /// microsecond or nanosecond overflow. + /// nanosecond overflow. 
fn try_from(value: chrono::Duration) -> Result { - value.to_std()?.try_into() + value + .num_nanoseconds() + .map_or::, _>( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |nanoseconds| { + if nanoseconds % 1000 != 0 { + return Err( + "PostgreSQL `INTERVAL` does not support nanoseconds precision".into(), + ); + } + Ok(()) + }, + )?; + + value.num_microseconds().map_or( + Err("Overflow has occurred for PostgreSQL `INTERVAL`".into()), + |microseconds| { + Ok(Self { + months: 0, + days: 0, + microseconds: microseconds, + }) + }, + ) } } @@ -283,6 +306,7 @@ fn test_encode_interval() { #[test] fn test_pginterval_std() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -292,11 +316,18 @@ fn test_pginterval_std() { &PgInterval::try_from(std::time::Duration::from_micros(27_000)).unwrap(), &interval ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(std::time::Duration::from_nanos(27_000_001)).is_err()); + + // Case when microsecond overflow occurs + assert!(PgInterval::try_from(std::time::Duration::from_secs(20_000_000_000_000)).is_err()); } #[test] #[cfg(feature = "chrono")] fn test_pginterval_chrono() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -306,11 +337,31 @@ fn test_pginterval_chrono() { &PgInterval::try_from(chrono::Duration::microseconds(27_000)).unwrap(), &interval ); + + // Case for negative duration + let interval = PgInterval { + days: 0, + months: 0, + microseconds: -27_000, + }; + assert_eq!( + &PgInterval::try_from(chrono::Duration::microseconds(-27_000)).unwrap(), + &interval + ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(chrono::Duration::nanoseconds(27_000_001)).is_err()); + assert!(PgInterval::try_from(chrono::Duration::nanoseconds(-27_000_001)).is_err()); + + // Case when nanosecond overflow occurs + assert!(PgInterval::try_from(chrono::Duration::seconds(10_000_000_000)).is_err()); + 
assert!(PgInterval::try_from(chrono::Duration::seconds(-10_000_000_000)).is_err()); } #[test] #[cfg(feature = "time")] fn test_pginterval_time() { + // Case for positive duration let interval = PgInterval { days: 0, months: 0, @@ -320,4 +371,23 @@ fn test_pginterval_time() { &PgInterval::try_from(time::Duration::microseconds(27_000)).unwrap(), &interval ); + + // Case for negative duration + let interval = PgInterval { + days: 0, + months: 0, + microseconds: -27_000, + }; + assert_eq!( + &PgInterval::try_from(time::Duration::microseconds(-27_000)).unwrap(), + &interval + ); + + // Case when precision loss occurs + assert!(PgInterval::try_from(time::Duration::nanoseconds(27_000_001)).is_err()); + assert!(PgInterval::try_from(time::Duration::nanoseconds(-27_000_001)).is_err()); + + // Case when microsecond overflow occurs + assert!(PgInterval::try_from(time::Duration::seconds(10_000_000_000_000)).is_err()); + assert!(PgInterval::try_from(time::Duration::seconds(-10_000_000_000_000)).is_err()); } diff --git a/sqlx-core/src/postgres/types/ipnetwork.rs b/sqlx-core/src/postgres/types/ipnetwork.rs index 5d579e8648..84611814b2 100644 --- a/sqlx-core/src/postgres/types/ipnetwork.rs +++ b/sqlx-core/src/postgres/types/ipnetwork.rs @@ -38,6 +38,10 @@ impl Type for [IpNetwork] { fn type_info() -> PgTypeInfo { PgTypeInfo::INET_ARRAY } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::CIDR_ARRAY || *ty == PgTypeInfo::INET_ARRAY + } } impl Type for Vec { diff --git a/sqlx-core/src/postgres/types/mac_address.rs b/sqlx-core/src/postgres/types/mac_address.rs new file mode 100644 index 0000000000..37bd543217 --- /dev/null +++ b/sqlx-core/src/postgres/types/mac_address.rs @@ -0,0 +1,63 @@ +use mac_address::MacAddress; + +use std::convert::TryInto; + +use crate::decode::Decode; +use crate::encode::{Encode, IsNull}; +use crate::error::BoxDynError; +use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueFormat, PgValueRef, Postgres}; +use crate::types::Type; + +impl 
Type for MacAddress { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR + } + + fn compatible(ty: &PgTypeInfo) -> bool { + *ty == PgTypeInfo::MACADDR + } +} + +impl Type for [MacAddress] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::MACADDR_ARRAY + } +} + +impl Type for Vec { + fn type_info() -> PgTypeInfo { + <[MacAddress] as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <[MacAddress] as Type>::compatible(ty) + } +} + +impl Encode<'_, Postgres> for MacAddress { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + buf.extend_from_slice(&self.bytes()); // write just the address + IsNull::No + } + + fn size_hint(&self) -> usize { + 6 + } +} + +impl Decode<'_, Postgres> for MacAddress { + fn decode(value: PgValueRef<'_>) -> Result { + let bytes = match value.format() { + PgValueFormat::Binary => value.as_bytes()?, + PgValueFormat::Text => { + return Ok(value.as_str()?.parse()?); + } + }; + + if bytes.len() == 6 { + return Ok(MacAddress::new(bytes.try_into().unwrap())); + } + + Err("invalid data received when expecting an MACADDR".into()) + } +} diff --git a/sqlx-core/src/postgres/types/mod.rs b/sqlx-core/src/postgres/types/mod.rs index 3827d9dc6e..066a8c2309 100644 --- a/sqlx-core/src/postgres/types/mod.rs +++ b/sqlx-core/src/postgres/types/mod.rs @@ -73,6 +73,14 @@ //! |---------------------------------------|------------------------------------------------------| //! | `ipnetwork::IpNetwork` | INET, CIDR | //! +//! ### [`mac_address`](https://crates.io/crates/mac_address) +//! +//! Requires the `mac_address` Cargo feature flag. +//! +//! | Rust type | Postgres type(s) | +//! |---------------------------------------|------------------------------------------------------| +//! | `mac_address::MacAddress` | MACADDR | +//! //! ### [`bit-vec`](https://crates.io/crates/bit-vec) //! //! Requires the `bit-vec` Cargo feature flag. 
@@ -194,6 +202,9 @@ mod json; #[cfg(feature = "ipnetwork")] mod ipnetwork; +#[cfg(feature = "mac_address")] +mod mac_address; + #[cfg(feature = "bit-vec")] mod bit_vec; diff --git a/sqlx-core/src/postgres/types/money.rs b/sqlx-core/src/postgres/types/money.rs index 2ae47dcd63..f327726710 100644 --- a/sqlx-core/src/postgres/types/money.rs +++ b/sqlx-core/src/postgres/types/money.rs @@ -20,46 +20,102 @@ use std::{ /// /// Reading `MONEY` value in text format is not supported and will cause an error. /// +/// ### `locale_frac_digits` +/// This parameter corresponds to the number of digits after the decimal separator. +/// +/// This value must match what Postgres is expecting for the locale set in the database +/// or else the decimal value you see on the client side will not match the `money` value +/// on the server side. +/// +/// **For _most_ locales, this value is `2`.** +/// +/// If you're not sure what locale your database is set to or how many decimal digits it specifies, +/// you can execute `SHOW lc_monetary;` to get the locale name, and then look it up in this list +/// (you can ignore the `.utf8` prefix): +/// https://lh.2xlibre.net/values/frac_digits/ +/// +/// If that link is dead and you're on a POSIX-compliant system (Unix, FreeBSD) you can also execute: +/// +/// ```sh +/// $ LC_MONETARY= locale -k frac_digits +/// ``` +/// +/// And the value you want is `N` in `frac_digits=N`. If you have shell access to the database +/// server you should execute it there as available locales may differ between machines. 
+/// +/// Note that if `frac_digits` for the locale is outside the range `[0, 10]`, Postgres assumes +/// it's a sentinel value and defaults to 2: +/// https://github.com/postgres/postgres/blob/master/src/backend/utils/adt/cash.c#L114-L123 +/// /// [`MONEY`]: https://www.postgresql.org/docs/current/datatype-money.html #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct PgMoney(pub i64); +pub struct PgMoney( + /// The raw integer value sent over the wire; for locales with `frac_digits=2` (i.e. most + /// of them), this will be the value in whole cents. + /// + /// E.g. for `select '$123.45'::money` with a locale of `en_US` (`frac_digits=2`), + /// this will be `12345`. + /// + /// If the currency of your locale does not have fractional units, e.g. Yen, then this will + /// just be the units of the currency. + /// + /// See the type-level docs for an explanation of `locale_frac_units`. + pub i64, +); impl PgMoney { - /// Convert the money value into a [`BigDecimal`] using the correct - /// precision defined in the PostgreSQL settings. The default precision is - /// two. + /// Convert the money value into a [`BigDecimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. /// /// [`BigDecimal`]: crate::types::BigDecimal #[cfg(feature = "bigdecimal")] - pub fn to_bigdecimal(self, scale: i64) -> bigdecimal::BigDecimal { + pub fn to_bigdecimal(self, locale_frac_digits: i64) -> bigdecimal::BigDecimal { let digits = num_bigint::BigInt::from(self.0); - bigdecimal::BigDecimal::new(digits, scale) + bigdecimal::BigDecimal::new(digits, locale_frac_digits) } - /// Convert the money value into a [`Decimal`] using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert the money value into a [`Decimal`] using `locale_frac_digits`. + /// + /// See the type-level docs for an explanation of `locale_frac_digits`. 
/// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn to_decimal(self, scale: u32) -> rust_decimal::Decimal { - rust_decimal::Decimal::new(self.0, scale) + pub fn to_decimal(self, locale_frac_digits: u32) -> rust_decimal::Decimal { + rust_decimal::Decimal::new(self.0, locale_frac_digits) } - /// Convert a [`Decimal`] value into money using the correct precision - /// defined in the PostgreSQL settings. The default precision is two. + /// Convert a [`Decimal`] value into money using `locale_frac_digits`. /// - /// Conversion may involve a loss of precision. + /// See the type-level docs for an explanation of `locale_frac_digits`. + /// + /// Note that `Decimal` has 96 bits of precision, but `PgMoney` only has 63 plus the sign bit. + /// If the value is larger than 63 bits it will be truncated. /// /// [`Decimal`]: crate::types::Decimal #[cfg(feature = "decimal")] - pub fn from_decimal(decimal: rust_decimal::Decimal, scale: u32) -> Self { - let cents = (decimal * rust_decimal::Decimal::new(10i64.pow(scale), 0)).round(); + pub fn from_decimal(mut decimal: rust_decimal::Decimal, locale_frac_digits: u32) -> Self { + use std::convert::TryFrom; + + // this is all we need to convert to our expected locale's `frac_digits` + decimal.rescale(locale_frac_digits); + + /// a mask to bitwise-AND with an `i64` to zero the sign bit + const SIGN_MASK: i64 = i64::MAX; + + let is_negative = decimal.is_sign_negative(); + let serialized = decimal.serialize(); - let mut buf: [u8; 8] = [0; 8]; - buf.copy_from_slice(¢s.serialize()[4..12]); + // interpret bytes `4..12` as an i64, ignoring the sign bit + // this is where truncation occurs + let value = i64::from_le_bytes( + *<&[u8; 8]>::try_from(&serialized[4..12]) + .expect("BUG: slice of serialized should be 8 bytes"), + ) & SIGN_MASK; // zero out the sign bit - Self(i64::from_le_bytes(buf)) + // negate if necessary + Self(if is_negative { -value } else { value }) } /// Convert a 
[`BigDecimal`](crate::types::BigDecimal) value into money using the correct precision @@ -67,12 +123,14 @@ impl PgMoney { #[cfg(feature = "bigdecimal")] pub fn from_bigdecimal( decimal: bigdecimal::BigDecimal, - scale: u32, + locale_frac_digits: u32, ) -> Result { use bigdecimal::ToPrimitive; - let multiplier = - bigdecimal::BigDecimal::new(num_bigint::BigInt::from(10i128.pow(scale)), 0); + let multiplier = bigdecimal::BigDecimal::new( + num_bigint::BigInt::from(10i128.pow(locale_frac_digits)), + 0, + ); let cents = decimal * multiplier; @@ -277,9 +335,25 @@ mod tests { #[test] #[cfg(feature = "decimal")] fn conversion_from_decimal_works() { - let dec = rust_decimal::Decimal::new(12345, 2); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(12345, 2), 2) + ); - assert_eq!(PgMoney(12345), PgMoney::from_decimal(dec, 2)); + assert_eq!( + PgMoney(12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12345), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123450, 3), 2) + ); + + assert_eq!( + PgMoney(-12300), + PgMoney::from_decimal(rust_decimal::Decimal::new(-123, 0), 2) + ); } #[test] diff --git a/sqlx-core/src/postgres/types/range.rs b/sqlx-core/src/postgres/types/range.rs index 760249f79c..59f689d9c0 100644 --- a/sqlx-core/src/postgres/types/range.rs +++ b/sqlx-core/src/postgres/types/range.rs @@ -142,6 +142,17 @@ impl Type for PgRange { } } +#[cfg(feature = "decimal")] +impl Type for PgRange { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE + } + + fn compatible(ty: &PgTypeInfo) -> bool { + range_compatible::(ty) + } +} + #[cfg(feature = "chrono")] impl Type for PgRange { fn type_info() -> PgTypeInfo { @@ -227,6 +238,13 @@ impl Type for [PgRange] { } } +#[cfg(feature = "decimal")] +impl Type for [PgRange] { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + #[cfg(feature = "chrono")] impl Type for [PgRange] { fn type_info() -> PgTypeInfo { @@ -288,6 
+306,13 @@ impl Type for Vec> { } } +#[cfg(feature = "decimal")] +impl Type for Vec> { + fn type_info() -> PgTypeInfo { + PgTypeInfo::NUM_RANGE_ARRAY + } +} + #[cfg(feature = "chrono")] impl Type for Vec> { fn type_info() -> PgTypeInfo { diff --git a/sqlx-core/src/postgres/types/str.rs b/sqlx-core/src/postgres/types/str.rs index 3607a4b898..7a721569d6 100644 --- a/sqlx-core/src/postgres/types/str.rs +++ b/sqlx-core/src/postgres/types/str.rs @@ -4,6 +4,7 @@ use crate::error::BoxDynError; use crate::postgres::types::array_compatible; use crate::postgres::{PgArgumentBuffer, PgTypeInfo, PgValueRef, Postgres}; use crate::types::Type; +use std::borrow::Cow; impl Type for str { fn type_info() -> PgTypeInfo { @@ -22,6 +23,16 @@ impl Type for str { } } +impl Type for Cow<'_, str> { + fn type_info() -> PgTypeInfo { + <&str as Type>::type_info() + } + + fn compatible(ty: &PgTypeInfo) -> bool { + <&str as Type>::compatible(ty) + } +} + impl Type for [&'_ str] { fn type_info() -> PgTypeInfo { PgTypeInfo::TEXT_ARRAY @@ -50,6 +61,15 @@ impl Encode<'_, Postgres> for &'_ str { } } +impl Encode<'_, Postgres> for Cow<'_, str> { + fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { + match self { + Cow::Borrowed(str) => <&str as Encode>::encode(*str, buf), + Cow::Owned(str) => <&str as Encode>::encode(&**str, buf), + } + } +} + impl Encode<'_, Postgres> for String { fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> IsNull { <&str as Encode>::encode(&**self, buf) @@ -62,6 +82,12 @@ impl<'r> Decode<'r, Postgres> for &'r str { } } +impl<'r> Decode<'r, Postgres> for Cow<'r, str> { + fn decode(value: PgValueRef<'r>) -> Result { + Ok(Cow::Borrowed(value.as_str()?)) + } +} + impl Type for String { fn type_info() -> PgTypeInfo { <&str as Type>::type_info() diff --git a/sqlx-core/src/postgres/ws_listener.rs b/sqlx-core/src/postgres/ws_listener.rs new file mode 100644 index 0000000000..4fe9c99b18 --- /dev/null +++ b/sqlx-core/src/postgres/ws_listener.rs @@ -0,0 +1,355 @@ +use 
crate::describe::Describe; +use crate::executor::{Execute, Executor}; +use crate::postgres::message::{MessageFormat, Notification}; +use crate::postgres::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres}; +use crate::{connection::Connection, error::Error}; +use either::Either; +use futures_channel::mpsc; +use futures_core::future::LocalBoxFuture as BoxFuture; +use futures_core::stream::LocalBoxStream as BoxStream; +use futures_core::stream::Stream; +use std::fmt::{self, Debug}; +use std::io; +use std::str::from_utf8; + +/// Represents a connection to a Postgres db over a websocket connection +pub struct PgListener { + connection: Option, + buffer_rx: mpsc::UnboundedReceiver, + buffer_tx: Option>, + channels: Vec, + url: String, +} + +/// An asynchronous notification from Postgres. +pub struct PgNotification(Notification); + +impl PgListener { + /// Connects to a PG instance over a websocket connection + pub async fn connect(url: &str) -> Result { + let mut connection = PgConnection::connect(url).await?; + let (sender, receiver) = mpsc::unbounded(); + connection.stream.notifications = Some(sender); + + Ok(Self { + connection: Some(connection), + buffer_rx: receiver, + buffer_tx: None, + channels: Vec::new(), + url: url.into(), + }) + } + + /// Starts listening for notifications on a channel. + /// The channel name is quoted here to ensure case sensitivity. + pub async fn listen(&mut self, channel: &str) -> Result<(), Error> { + self.connection() + .execute(&*format!(r#"LISTEN "{}""#, ident(channel))) + .await?; + + self.channels.push(channel.to_owned()); + + Ok(()) + } + + /// Starts listening for notifications on all channels. 
+ pub async fn listen_all( + &mut self, + channels: impl IntoIterator, + ) -> Result<(), Error> { + let beg = self.channels.len(); + self.channels.extend(channels.into_iter().map(|s| s.into())); + + self.connection + .as_mut() + .unwrap() + .execute(&*build_listen_all_query(&self.channels[beg..])) + .await?; + + Ok(()) + } + + /// Stops listening for notifications on a channel. + /// The channel name is quoted here to ensure case sensitivity. + pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> { + self.connection() + .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel))) + .await?; + + if let Some(pos) = self.channels.iter().position(|s| s == channel) { + self.channels.remove(pos); + } + + Ok(()) + } + + /// Stops listening for notifications on all channels. + pub async fn unlisten_all(&mut self) -> Result<(), Error> { + self.connection().execute("UNLISTEN *").await?; + + self.channels.clear(); + + Ok(()) + } + + #[inline] + async fn connect_if_needed(&mut self) -> Result<(), Error> { + if self.connection.is_none() { + let mut connection = PgConnection::connect(&self.url).await?; + connection.stream.notifications = self.buffer_tx.take(); + + connection + .execute(&*build_listen_all_query(&self.channels)) + .await?; + + self.connection = Some(connection); + } + + Ok(()) + } + + #[inline] + fn connection(&mut self) -> &mut PgConnection { + self.connection.as_mut().unwrap() + } + + /// Receives the next notification available from any of the subscribed channels. + /// + /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next + /// call to `recv()`, and should be entirely transparent (as long as it was just an + /// intermittent network failure or long-lived connection reaper). + /// + /// As notifications are transient, any received while the connection was lost, will not + /// be returned. 
If you'd prefer the reconnection to be explicit and have a chance to + /// do something before, please see [`try_recv`](Self::try_recv). + /// + /// # Example + /// + /// ```rust,no_run + /// # use sqlx_core::postgres::PgListener; + /// # use sqlx_core::error::Error; + /// # + /// # #[cfg(feature = "_rt-async-std")] + /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move { + /// # let mut listener = PgListener::connect("postgres:// ...").await?; + /// loop { + /// // ask for next notification, re-connecting (transparently) if needed + /// let notification = listener.recv().await?; + /// + /// // handle notification, do something interesting + /// } + /// # Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn recv(&mut self) -> Result { + loop { + if let Some(notification) = self.try_recv().await? { + return Ok(notification); + } + } + } + + /// Receives the next notification available from any of the subscribed channels. + /// + /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is + /// reconnected on the next call to `try_recv()`. + /// + /// # Example + /// + /// ```rust,no_run + /// # use sqlx_core::postgres::PgListener; + /// # use sqlx_core::error::Error; + /// # + /// # #[cfg(feature = "_rt-async-std")] + /// # sqlx_rt::block_on::<_, Result<(), Error>>(async move { + /// # let mut listener = PgListener::connect("postgres:// ...").await?; + /// loop { + /// // start handling notifications, connecting if needed + /// while let Some(notification) = listener.try_recv().await? 
{ + /// // handle notification + /// } + /// + /// // connection lost, do something interesting + /// } + /// # Ok(()) + /// # }).unwrap(); + /// ``` + pub async fn try_recv(&mut self) -> Result, Error> { + // Flush the buffer first, if anything + // This would only fill up if this listener is used as a connection + if let Ok(Some(notification)) = self.buffer_rx.try_next() { + return Ok(Some(PgNotification(notification))); + } + + loop { + // Ensure we have an active connection to work with. + self.connect_if_needed().await?; + + let message = match self.connection().stream.recv_unchecked().await { + Ok(message) => message, + + // The connection is dead, ensure that it is dropped, + // update self state, and loop to try again. + Err(Error::Io(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + self.buffer_tx = self.connection().stream.notifications.take(); + self.connection = None; + + // lost connection + return Ok(None); + } + + // Forward other errors + Err(error) => { + return Err(error); + } + }; + + match message.format { + // We've received an async notification, return it. + MessageFormat::NotificationResponse => { + return Ok(Some(PgNotification(message.decode()?))); + } + + // Mark the connection as ready for another query + MessageFormat::ReadyForQuery => { + self.connection().pending_ready_for_query_count -= 1; + } + + // Ignore unexpected messages + _ => {} + } + } + } + + /// Consume this listener, returning a `Stream` of notifications. + /// + /// The backing connection will be automatically reconnected should it be lost. + /// + /// This has the same potential drawbacks as [`recv`](PgListener::recv). + /// + pub fn into_stream(mut self) -> impl Stream> + Unpin { + Box::pin(try_stream! 
{ + loop { + r#yield!(self.recv().await?); + } + }) + } +} + +impl<'c> Executor<'c> for &'c mut PgListener { + type Database = Postgres; + + fn fetch_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxStream<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.connection().fetch_many(query) + } + + fn fetch_optional<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + E: Execute<'q, Self::Database>, + { + self.connection().fetch_optional(query) + } + + fn prepare_with<'e, 'q: 'e>( + self, + query: &'q str, + parameters: &'e [PgTypeInfo], + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + self.connection().prepare_with(query, parameters) + } + + #[doc(hidden)] + fn describe<'e, 'q: 'e>( + self, + query: &'q str, + ) -> BoxFuture<'e, Result, Error>> + where + 'c: 'e, + { + self.connection().describe(query) + } +} + +impl PgNotification { + /// The process ID of the notifying backend process. + #[inline] + pub fn process_id(&self) -> u32 { + self.0.process_id + } + + /// The channel that the notify has been raised on. This can be thought + /// of as the message topic. + #[inline] + pub fn channel(&self) -> &str { + from_utf8(&self.0.channel).unwrap() + } + + /// The payload of the notification. An empty payload is received as an + /// empty string. + #[inline] + pub fn payload(&self) -> &str { + from_utf8(&self.0.payload).unwrap() + } +} + +impl Debug for PgListener { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgListener").finish() + } +} + +impl Debug for PgNotification { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("PgNotification") + .field("process_id", &self.process_id()) + .field("channel", &self.channel()) + .field("payload", &self.payload()) + .finish() + } +} + +fn ident(mut name: &str) -> String { + // If the input string contains a NUL byte, we should truncate the + // identifier. 
+ if let Some(index) = name.find('\0') { + name = &name[..index]; + } + + // Any double quotes must be escaped + name.replace('"', "\"\"") +} + +fn build_listen_all_query(channels: impl IntoIterator>) -> String { + channels.into_iter().fold(String::new(), |mut acc, chan| { + acc.push_str(r#"LISTEN ""#); + acc.push_str(&ident(chan.as_ref())); + acc.push_str(r#"";"#); + acc + }) +} + +#[test] +fn test_build_listen_all_query_with_single_channel() { + let output = build_listen_all_query(&["test"]); + assert_eq!(output.as_str(), r#"LISTEN "test";"#); +} + +#[test] +fn test_build_listen_all_query_with_multiple_channels() { + let output = build_listen_all_query(&["channel.0", "channel.1"]); + assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#); +} diff --git a/sqlx-core/src/query.rs b/sqlx-core/src/query.rs index b3e30dc52c..ad01f24e4c 100644 --- a/sqlx-core/src/query.rs +++ b/sqlx-core/src/query.rs @@ -1,7 +1,12 @@ use std::marker::PhantomData; use either::Either; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + use futures_util::{future, StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::{Arguments, IntoArguments}; @@ -36,6 +41,7 @@ pub struct Map<'q, DB: Database, F, A> { mapper: F, } +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, A> Execute<'q, DB> for Query<'q, DB, A> where DB: Database, @@ -67,6 +73,38 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, A> Execute<'q, DB> for Query<'q, DB, A> +where + DB: Database, + A: IntoArguments<'q, DB>, +{ + #[inline] + fn sql(&self) -> &'q str { + match self.statement { + Either::Right(ref statement) => statement.sql(), + Either::Left(sql) => sql, + } + } + + fn statement(&self) -> Option<&>::Statement> { + match self.statement { + Either::Right(ref statement) => Some(&statement), + Either::Left(_) => None, + } + } + + #[inline] + fn take_arguments(&mut self) -> 
Option<>::Arguments> { + self.arguments.take().map(IntoArguments::into_arguments) + } + + #[inline] + fn persistent(&self) -> bool { + self.persistent + } +} + impl<'q, DB: Database> Query<'q, DB, >::Arguments> { /// Bind a value for use with this SQL query. /// @@ -76,6 +114,7 @@ impl<'q, DB: Database> Query<'q, DB, >::Arguments> { /// /// There is no validation that the value is of the type expected by the query. Most SQL /// flavors will perform type coercion (Postgres will return a database error). + #[cfg(not(target_arch = "wasm32"))] pub fn bind + Type>(mut self, value: T) -> Self { if let Some(arguments) = &mut self.arguments { arguments.add(value); @@ -83,6 +122,15 @@ impl<'q, DB: Database> Query<'q, DB, >::Arguments> { self } + + #[cfg(target_arch = "wasm32")] + pub fn bind + Type>(mut self, value: T) -> Self { + if let Some(arguments) = &mut self.arguments { + arguments.add(value); + } + + self + } } impl<'q, DB, A> Query<'q, DB, A> @@ -103,6 +151,7 @@ where } } +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, A: Send> Query<'q, DB, A> where DB: Database, @@ -227,6 +276,129 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, A> Query<'q, DB, A> +where + DB: Database, + A: 'q + IntoArguments<'q, DB>, +{ + /// Map each row in the result to another type. + /// + /// See [`try_map`](Query::try_map) for a fallible version of this method. + /// + /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using + /// a [`FromRow`](super::from_row::FromRow) implementation. + #[inline] + pub fn map(self, mut f: F) -> Map<'q, DB, impl FnMut(DB::Row) -> Result, A> + where + F: FnMut(DB::Row) -> O, + O: Unpin, + { + self.try_map(move |row| Ok(f(row))) + } + + /// Map each row in the result to another type. + /// + /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using + /// a [`FromRow`](super::from_row::FromRow) implementation. 
+ #[inline] + pub fn try_map(self, f: F) -> Map<'q, DB, F, A> + where + F: FnMut(DB::Row) -> Result, + O: Unpin, + { + Map { + inner: self, + mapper: f, + } + } + + /// Execute the query and return the total number of rows affected. + #[inline] + pub async fn execute<'e, 'c: 'e, E>(self, executor: E) -> Result + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.execute(self).await + } + + /// Execute multiple queries and return the rows affected from each query, in a stream. + #[inline] + pub async fn execute_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.execute_many(self) + } + + /// Execute the query and return the generated results as a stream. + #[inline] + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.fetch(self) + } + + /// Execute multiple queries and return the generated results as a stream + /// from each query, in a stream. + #[inline] + pub fn fetch_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result, Error>> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.fetch_many(self) + } + + /// Execute the query and return all the generated results, collected into a [`Vec`]. + #[inline] + pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.fetch_all(self).await + } + + /// Execute the query and returns exactly one row. + #[inline] + pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.fetch_one(self).await + } + + /// Execute the query and returns at most one row. 
+ #[inline] + pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + A: 'e, + E: Executor<'c, Database = DB>, + { + executor.fetch_optional(self).await + } +} + +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, F: Send, A: Send> Execute<'q, DB> for Map<'q, DB, F, A> where DB: Database, @@ -253,6 +425,34 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, F, A> Execute<'q, DB> for Map<'q, DB, F, A> +where + DB: Database, + A: IntoArguments<'q, DB>, +{ + #[inline] + fn sql(&self) -> &'q str { + self.inner.sql() + } + + #[inline] + fn statement(&self) -> Option<&>::Statement> { + self.inner.statement() + } + + #[inline] + fn take_arguments(&mut self) -> Option<>::Arguments> { + self.inner.take_arguments() + } + + #[inline] + fn persistent(&self) -> bool { + self.inner.arguments.is_some() + } +} + +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, F, O, A> Map<'q, DB, F, A> where DB: Database, @@ -394,6 +594,142 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, F, O, A> Map<'q, DB, F, A> +where + DB: Database, + F: FnMut(DB::Row) -> Result, + O: Unpin, + A: 'q + IntoArguments<'q, DB>, +{ + /// Map each row in the result to another type. + /// + /// See [`try_map`](Map::try_map) for a fallible version of this method. + /// + /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using + /// a [`FromRow`](super::from_row::FromRow) implementation. + #[inline] + pub fn map(self, mut g: G) -> Map<'q, DB, impl FnMut(DB::Row) -> Result, A> + where + G: FnMut(O) -> P, + P: Unpin, + { + self.try_map(move |data| Ok(g(data))) + } + + /// Map each row in the result to another type. + /// + /// The [`query_as`](super::query_as::query_as) method will construct a mapped query using + /// a [`FromRow`](super::from_row::FromRow) implementation. 
+ #[inline] + pub fn try_map(self, mut g: G) -> Map<'q, DB, impl FnMut(DB::Row) -> Result, A> + where + G: FnMut(O) -> Result, + P: Unpin, + { + let mut f = self.mapper; + Map { + inner: self.inner, + mapper: move |row| f(row).and_then(|o| g(o)), + } + } + + /// Execute the query and return the generated results as a stream. + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + F: 'e, + O: 'e, + { + self.fetch_many(executor) + .try_filter_map(|step| async move { + Ok(match step { + Either::Left(_) => None, + Either::Right(o) => Some(o), + }) + }) + .boxed_local() + } + + /// Execute multiple queries and return the generated results as a stream + /// from each query, in a stream. + pub fn fetch_many<'e, 'c: 'e, E>( + mut self, + executor: E, + ) -> BoxStream<'e, Result, Error>> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + F: 'e, + O: 'e, + { + Box::pin(try_stream! { + let mut s = executor.fetch_many(self.inner); + + while let Some(v) = s.try_next().await? { + r#yield!(match v { + Either::Left(v) => Either::Left(v), + Either::Right(row) => { + Either::Right((self.mapper)(row)?) + } + }); + } + + Ok(()) + }) + } + + /// Execute the query and return all the generated results, collected into a [`Vec`]. + pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + F: 'e, + O: 'e, + { + self.fetch(executor).try_collect().await + } + + /// Execute the query and returns exactly one row. + pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + F: 'e, + O: 'e, + { + self.fetch_optional(executor) + .and_then(|row| match row { + Some(row) => future::ok(row), + None => future::err(Error::RowNotFound), + }) + .await + } + + /// Execute the query and returns at most one row. 
+ pub async fn fetch_optional<'e, 'c: 'e, E>(mut self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + F: 'e, + O: 'e, + { + let row = executor.fetch_optional(self.inner).await?; + + if let Some(row) = row { + (self.mapper)(row).map(Some) + } else { + Ok(None) + } + } +} + // Make a SQL query from a statement. pub(crate) fn query_statement<'q, DB>( statement: &'q >::Statement, diff --git a/sqlx-core/src/query_as.rs b/sqlx-core/src/query_as.rs index 62406c21e0..e23d3f2581 100644 --- a/sqlx-core/src/query_as.rs +++ b/sqlx-core/src/query_as.rs @@ -1,11 +1,16 @@ use std::marker::PhantomData; use either::Either; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + use futures_util::{StreamExt, TryStreamExt}; use crate::arguments::IntoArguments; -use crate::database::{Database, HasArguments, HasStatement}; +use crate::database::{Database, HasArguments, HasStatement, HasStatementCache}; use crate::encode::Encode; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -21,6 +26,7 @@ pub struct QueryAs<'q, DB: Database, O, A> { pub(crate) output: PhantomData, } +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, O: Send, A: Send> Execute<'q, DB> for QueryAs<'q, DB, O, A> where DB: Database, @@ -47,18 +53,71 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, O: Send, A> Execute<'q, DB> for QueryAs<'q, DB, O, A> +where + DB: Database, + A: 'q + IntoArguments<'q, DB>, +{ + #[inline] + fn sql(&self) -> &'q str { + self.inner.sql() + } + + #[inline] + fn statement(&self) -> Option<&>::Statement> { + self.inner.statement() + } + + #[inline] + fn take_arguments(&mut self) -> Option<>::Arguments> { + self.inner.take_arguments() + } + + #[inline] + fn persistent(&self) -> bool { + self.inner.persistent() + } +} + impl<'q, DB: Database, O> QueryAs<'q, DB, O, >::Arguments> { /// Bind a value for use 
with this SQL query. /// /// See [`Query::bind`](Query::bind). + #[cfg(not(target_arch = "wasm32"))] pub fn bind + Type>(mut self, value: T) -> Self { self.inner = self.inner.bind(value); self } + + #[cfg(target_arch = "wasm32")] + pub fn bind + Type>(mut self, value: T) -> Self { + self.inner = self.inner.bind(value); + self + } +} + +impl<'q, DB, O, A> QueryAs<'q, DB, O, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.inner = self.inner.persistent(value); + self + } } // FIXME: This is very close, nearly 1:1 with `Map` // noinspection DuplicatedCode +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, O, A> QueryAs<'q, DB, O, A> where DB: Database, @@ -151,6 +210,99 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, O, A> QueryAs<'q, DB, O, A> +where + DB: Database, + A: 'q + IntoArguments<'q, DB>, + O: Unpin + for<'r> FromRow<'r, DB::Row>, +{ + /// Execute the query and return the generated results as a stream. + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + self.fetch_many(executor) + .try_filter_map(|step| async move { Ok(step.right()) }) + .boxed_local() + } + + /// Execute multiple queries and return the generated results as a stream + /// from each query, in a stream. + pub fn fetch_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result, Error>> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + Box::pin(try_stream! { + let mut s = executor.fetch_many(self.inner); + + while let Some(v) = s.try_next().await? 
{ + r#yield!(match v { + Either::Left(v) => Either::Left(v), + Either::Right(row) => Either::Right(O::from_row(&row)?), + }); + } + + Ok(()) + }) + } + + /// Execute the query and return all the generated results, collected into a [`Vec`]. + #[inline] + pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + self.fetch(executor).try_collect().await + } + + /// Execute the query and returns exactly one row. + pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + self.fetch_optional(executor) + .await + .and_then(|row| row.ok_or(Error::RowNotFound)) + } + + /// Execute the query and returns at most one row. + pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + let row = executor.fetch_optional(self.inner).await?; + if let Some(row) = row { + O::from_row(&row).map(Some) + } else { + Ok(None) + } + } +} + /// Make a SQL query that is mapped to a concrete type /// using [`FromRow`]. 
#[inline] diff --git a/sqlx-core/src/query_scalar.rs b/sqlx-core/src/query_scalar.rs index 7e958a7b89..8898ccf654 100644 --- a/sqlx-core/src/query_scalar.rs +++ b/sqlx-core/src/query_scalar.rs @@ -1,9 +1,14 @@ use either::Either; + +#[cfg(not(target_arch = "wasm32"))] use futures_core::stream::BoxStream; +#[cfg(target_arch = "wasm32")] +use futures_core::stream::LocalBoxStream as BoxStream; + use futures_util::{StreamExt, TryFutureExt, TryStreamExt}; use crate::arguments::IntoArguments; -use crate::database::{Database, HasArguments, HasStatement}; +use crate::database::{Database, HasArguments, HasStatement, HasStatementCache}; use crate::encode::Encode; use crate::error::Error; use crate::executor::{Execute, Executor}; @@ -20,6 +25,7 @@ pub struct QueryScalar<'q, DB: Database, O, A> { inner: QueryAs<'q, DB, (O,), A>, } +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB: Database, O: Send, A: Send> Execute<'q, DB> for QueryScalar<'q, DB, O, A> where A: 'q + IntoArguments<'q, DB>, @@ -44,18 +50,69 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB: Database, O: Send, A> Execute<'q, DB> for QueryScalar<'q, DB, O, A> +where + A: 'q + IntoArguments<'q, DB>, +{ + #[inline] + fn sql(&self) -> &'q str { + self.inner.sql() + } + + fn statement(&self) -> Option<&>::Statement> { + self.inner.statement() + } + + #[inline] + fn take_arguments(&mut self) -> Option<>::Arguments> { + self.inner.take_arguments() + } + + #[inline] + fn persistent(&self) -> bool { + self.inner.persistent() + } +} + impl<'q, DB: Database, O> QueryScalar<'q, DB, O, >::Arguments> { /// Bind a value for use with this SQL query. /// /// See [`Query::bind`](crate::query::Query::bind). 
+ #[cfg(not(target_arch = "wasm32"))] pub fn bind + Type>(mut self, value: T) -> Self { self.inner = self.inner.bind(value); self } + + #[cfg(target_arch = "wasm32")] + pub fn bind + Type>(mut self, value: T) -> Self { + self.inner = self.inner.bind(value); + self + } +} + +impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> +where + DB: Database + HasStatementCache, +{ + /// If `true`, the statement will get prepared once and cached to the + /// connection's statement cache. + /// + /// If queried once with the flag set to `true`, all subsequent queries + /// matching the one with the flag will use the cached statement until the + /// cache is cleared. + /// + /// Default: `true`. + pub fn persistent(mut self, value: bool) -> Self { + self.inner = self.inner.persistent(value); + self + } } // FIXME: This is very close, nearly 1:1 with `Map` // noinspection DuplicatedCode +#[cfg(not(target_arch = "wasm32"))] impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> where DB: Database, @@ -140,6 +197,91 @@ where } } +#[cfg(target_arch = "wasm32")] +impl<'q, DB, O, A> QueryScalar<'q, DB, O, A> +where + DB: Database, + O: Unpin, + A: 'q + IntoArguments<'q, DB>, + (O,): Unpin + for<'r> FromRow<'r, DB::Row>, +{ + /// Execute the query and return the generated results as a stream. + #[inline] + pub fn fetch<'e, 'c: 'e, E>(self, executor: E) -> BoxStream<'e, Result> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + A: 'e, + O: 'e, + { + self.inner.fetch(executor).map_ok(|it| it.0).boxed_local() + } + + /// Execute multiple queries and return the generated results as a stream + /// from each query, in a stream. 
+ #[inline] + pub fn fetch_many<'e, 'c: 'e, E>( + self, + executor: E, + ) -> BoxStream<'e, Result, Error>> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + A: 'e, + O: 'e, + { + self.inner + .fetch_many(executor) + .map_ok(|v| v.map_right(|it| it.0)) + .boxed_local() + } + + /// Execute the query and return all the generated results, collected into a [`Vec`]. + #[inline] + pub async fn fetch_all<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + (O,): 'e, + A: 'e, + { + self.inner + .fetch(executor) + .map_ok(|it| it.0) + .try_collect() + .await + } + + /// Execute the query and returns exactly one row. + #[inline] + pub async fn fetch_one<'e, 'c: 'e, E>(self, executor: E) -> Result + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + self.inner.fetch_one(executor).map_ok(|it| it.0).await + } + + /// Execute the query and returns at most one row. + #[inline] + pub async fn fetch_optional<'e, 'c: 'e, E>(self, executor: E) -> Result, Error> + where + 'q: 'e, + E: 'e + Executor<'c, Database = DB>, + DB: 'e, + O: 'e, + A: 'e, + { + Ok(self.inner.fetch_optional(executor).await?.map(|it| it.0)) + } +} + /// Make a SQL query that is mapped to a single concrete type /// using [`FromRow`]. 
#[inline] diff --git a/sqlx-core/src/sqlite/connection/describe.rs b/sqlx-core/src/sqlite/connection/describe.rs index 8bc9f9ceed..cb86e7e024 100644 --- a/sqlx-core/src/sqlite/connection/describe.rs +++ b/sqlx-core/src/sqlite/connection/describe.rs @@ -64,7 +64,7 @@ pub(super) fn describe<'c: 'e, 'q: 'e, 'e>( // fallback to [column_decltype] if !stepped && stmt.read_only() { stepped = true; - let _ = conn.worker.step(*stmt).await; + let _ = conn.worker.step(stmt).await; } let mut ty = stmt.column_type_info(col); diff --git a/sqlx-core/src/sqlite/connection/establish.rs b/sqlx-core/src/sqlite/connection/establish.rs index 20206a4388..ce8105a652 100644 --- a/sqlx-core/src/sqlite/connection/establish.rs +++ b/sqlx-core/src/sqlite/connection/establish.rs @@ -7,8 +7,8 @@ use crate::{ }; use libsqlite3_sys::{ sqlite3_busy_timeout, sqlite3_extended_result_codes, sqlite3_open_v2, SQLITE_OK, - SQLITE_OPEN_CREATE, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, SQLITE_OPEN_PRIVATECACHE, - SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, + SQLITE_OPEN_CREATE, SQLITE_OPEN_FULLMUTEX, SQLITE_OPEN_MEMORY, SQLITE_OPEN_NOMUTEX, + SQLITE_OPEN_PRIVATECACHE, SQLITE_OPEN_READONLY, SQLITE_OPEN_READWRITE, SQLITE_OPEN_SHAREDCACHE, }; use sqlx_rt::blocking; use std::io; @@ -29,13 +29,15 @@ pub(crate) async fn establish(options: &SqliteConnectOptions) -> Result Result Result Result Result( +async fn prepare<'a>( + worker: &mut StatementWorker, statements: &'a mut StatementCache, statement: &'a mut Option, query: &str, @@ -39,7 +40,7 @@ fn prepare<'a>( if exists { // as this statement has been executed before, we reset before continuing // this also causes any rows that are from the statement to be inflated - statement.reset(); + statement.reset(worker).await?; } Ok(statement) @@ -61,19 +62,25 @@ fn bind( /// A structure holding sqlite statement handle and resetting the /// statement when it is dropped. 
-struct StatementResetter { - handle: StatementHandle, +struct StatementResetter<'a> { + handle: Arc, + worker: &'a mut StatementWorker, } -impl StatementResetter { - fn new(handle: StatementHandle) -> Self { - Self { handle } +impl<'a> StatementResetter<'a> { + fn new(worker: &'a mut StatementWorker, handle: &Arc) -> Self { + Self { + worker, + handle: Arc::clone(handle), + } } } -impl Drop for StatementResetter { +impl Drop for StatementResetter<'_> { fn drop(&mut self) { - self.handle.reset(); + // this method is designed to eagerly send the reset command + // so we don't need to await or spawn it + let _ = self.worker.reset(&self.handle); } } @@ -103,7 +110,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let stmt = prepare(statements, statement, sql, persistent)?; + let stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -113,7 +120,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // is dropped. `StatementResetter` will reliably reset the // statement even if the stream returned from `fetch_many` // is dropped early. 
- let _resetter = StatementResetter::new(*stmt); + let resetter = StatementResetter::new(worker, stmt); // bind values to the statement num_arguments += bind(stmt, &arguments, num_arguments)?; @@ -125,7 +132,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - let s = worker.step(*stmt).await?; + let s = resetter.worker.step(stmt).await?; match s { Either::Left(changes) => { @@ -145,7 +152,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { Either::Right(()) => { let (row, weak_values_ref) = SqliteRow::current( - *stmt, + stmt.to_ref(conn.to_ref()), columns, column_names ); @@ -188,7 +195,7 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { } = self; // prepare statement object (or checkout from cache) - let virtual_stmt = prepare(statements, statement, sql, persistent)?; + let virtual_stmt = prepare(worker, statements, statement, sql, persistent).await?; // keep track of how many arguments we have bound let mut num_arguments = 0; @@ -205,18 +212,21 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { // invoke [sqlite3_step] on the dedicated worker thread // this will move us forward one row or finish the statement - match worker.step(*stmt).await? { + match worker.step(stmt).await? { Either::Left(_) => (), Either::Right(()) => { - let (row, weak_values_ref) = - SqliteRow::current(*stmt, columns, column_names); + let (row, weak_values_ref) = SqliteRow::current( + stmt.to_ref(self.handle.to_ref()), + columns, + column_names, + ); *last_row_values = Some(weak_values_ref); logger.increment_rows(); - virtual_stmt.reset(); + virtual_stmt.reset(worker).await?; return Ok(Some(row)); } } @@ -238,11 +248,12 @@ impl<'c> Executor<'c> for &'c mut SqliteConnection { handle: ref mut conn, ref mut statements, ref mut statement, + ref mut worker, .. 
} = self; // prepare statement object (or checkout from cache) - let statement = prepare(statements, statement, sql, true)?; + let statement = prepare(worker, statements, statement, sql, true).await?; let mut parameters = 0; let mut columns = None; diff --git a/sqlx-core/src/sqlite/connection/explain.rs b/sqlx-core/src/sqlite/connection/explain.rs index 3797a3d0bb..14df95e6ac 100644 --- a/sqlx-core/src/sqlite/connection/explain.rs +++ b/sqlx-core/src/sqlite/connection/explain.rs @@ -17,6 +17,13 @@ const SQLITE_AFF_REAL: u8 = 0x45; /* 'E' */ const OP_INIT: &str = "Init"; const OP_GOTO: &str = "Goto"; const OP_COLUMN: &str = "Column"; +const OP_MAKE_RECORD: &str = "MakeRecord"; +const OP_INSERT: &str = "Insert"; +const OP_IDX_INSERT: &str = "IdxInsert"; +const OP_OPEN_READ: &str = "OpenRead"; +const OP_OPEN_WRITE: &str = "OpenWrite"; +const OP_OPEN_EPHEMERAL: &str = "OpenEphemeral"; +const OP_OPEN_AUTOINDEX: &str = "OpenAutoindex"; const OP_AGG_STEP: &str = "AggStep"; const OP_FUNCTION: &str = "Function"; const OP_MOVE: &str = "Move"; @@ -34,6 +41,7 @@ const OP_BLOB: &str = "Blob"; const OP_VARIABLE: &str = "Variable"; const OP_COUNT: &str = "Count"; const OP_ROWID: &str = "Rowid"; +const OP_NEWROWID: &str = "NewRowid"; const OP_OR: &str = "Or"; const OP_AND: &str = "And"; const OP_BIT_AND: &str = "BitAnd"; @@ -48,6 +56,21 @@ const OP_REMAINDER: &str = "Remainder"; const OP_CONCAT: &str = "Concat"; const OP_RESULT_ROW: &str = "ResultRow"; +#[derive(Debug, Clone, Eq, PartialEq)] +enum RegDataType { + Single(DataType), + Record(Vec), +} + +impl RegDataType { + fn map_to_datatype(self) -> DataType { + match self { + RegDataType::Single(d) => d, + RegDataType::Record(_) => DataType::Null, //If we're trying to coerce to a regular Datatype, we can assume a Record is invalid for the context + } + } +} + #[allow(clippy::wildcard_in_or_patterns)] fn affinity_to_type(affinity: u8) -> DataType { match affinity { @@ -73,13 +96,19 @@ fn opcode_to_type(op: &str) -> DataType { } } 
+// Opcode Reference: https://sqlite.org/opcode.html pub(super) async fn explain( conn: &mut SqliteConnection, query: &str, ) -> Result<(Vec, Vec>), Error> { - let mut r = HashMap::::with_capacity(6); + // Registers + let mut r = HashMap::::with_capacity(6); + // Map between pointer and register let mut r_cursor = HashMap::>::with_capacity(6); + // Rows that pointers point to + let mut p = HashMap::>::with_capacity(6); + // Nullable columns let mut n = HashMap::::with_capacity(6); let program = @@ -119,15 +148,52 @@ pub(super) async fn explain( } OP_COLUMN => { - r_cursor.entry(p1).or_default().push(p3); + //Get the row stored at p1, or NULL; get the column stored at p2, or NULL + if let Some(record) = p.get(&p1) { + if let Some(col) = record.get(&p2) { + // insert into p3 the datatype of the col + r.insert(p3, RegDataType::Single(*col)); + // map between pointer p1 and register p3 + r_cursor.entry(p1).or_default().push(p3); + } else { + r.insert(p3, RegDataType::Single(DataType::Null)); + } + } else { + r.insert(p3, RegDataType::Single(DataType::Null)); + } + } + + OP_MAKE_RECORD => { + // p3 = Record([p1 .. 
p1 + p2]) + let mut record = Vec::with_capacity(p2 as usize); + for reg in p1..p1 + p2 { + record.push( + r.get(®) + .map(|d| d.clone().map_to_datatype()) + .unwrap_or(DataType::Null), + ); + } + r.insert(p3, RegDataType::Record(record)); + } + + OP_INSERT | OP_IDX_INSERT => { + if let Some(RegDataType::Record(record)) = r.get(&p2) { + if let Some(row) = p.get_mut(&p1) { + // Insert the record into wherever pointer p1 is + *row = (0..).zip(record.iter().copied()).collect(); + } + } + //Noop if the register p2 isn't a record, or if pointer p1 does not exist + } - // r[p3] = - r.insert(p3, DataType::Null); + OP_OPEN_READ | OP_OPEN_WRITE | OP_OPEN_EPHEMERAL | OP_OPEN_AUTOINDEX => { + //Create a new pointer which is referenced by p1 + p.insert(p1, HashMap::with_capacity(6)); } OP_VARIABLE => { // r[p2] = - r.insert(p2, DataType::Null); + r.insert(p2, RegDataType::Single(DataType::Null)); n.insert(p3, true); } @@ -136,7 +202,7 @@ pub(super) async fn explain( match from_utf8(p4).map_err(Error::protocol)? 
{ "last_insert_rowid(0)" => { // last_insert_rowid() -> INTEGER - r.insert(p3, DataType::Int64); + r.insert(p3, RegDataType::Single(DataType::Int64)); n.insert(p3, n.get(&p3).copied().unwrap_or(false)); } @@ -145,9 +211,9 @@ pub(super) async fn explain( } OP_NULL_ROW => { - // all values of cursor X are potentially nullable - for column in &r_cursor[&p1] { - n.insert(*column, true); + // all registers that map to cursor X are potentially nullable + for register in &r_cursor[&p1] { + n.insert(*register, true); } } @@ -156,9 +222,9 @@ pub(super) async fn explain( if p4.starts_with("count(") { // count(_) -> INTEGER - r.insert(p3, DataType::Int64); + r.insert(p3, RegDataType::Single(DataType::Int64)); n.insert(p3, n.get(&p3).copied().unwrap_or(false)); - } else if let Some(v) = r.get(&p2).copied() { + } else if let Some(v) = r.get(&p2).cloned() { // r[p3] = AGG ( r[p2] ) r.insert(p3, v); let val = n.get(&p2).copied().unwrap_or(true); @@ -169,13 +235,13 @@ pub(super) async fn explain( OP_CAST => { // affinity(r[p1]) if let Some(v) = r.get_mut(&p1) { - *v = affinity_to_type(p2 as u8); + *v = RegDataType::Single(affinity_to_type(p2 as u8)); } } OP_COPY | OP_MOVE | OP_SCOPY | OP_INT_COPY => { // r[p2] = r[p1] - if let Some(v) = r.get(&p1).copied() { + if let Some(v) = r.get(&p1).cloned() { r.insert(p2, v); if let Some(null) = n.get(&p1).copied() { @@ -184,15 +250,16 @@ pub(super) async fn explain( } } - OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID => { + OP_OR | OP_AND | OP_BLOB | OP_COUNT | OP_REAL | OP_STRING8 | OP_INTEGER | OP_ROWID + | OP_NEWROWID => { // r[p2] = - r.insert(p2, opcode_to_type(&opcode)); + r.insert(p2, RegDataType::Single(opcode_to_type(&opcode))); n.insert(p2, n.get(&p2).copied().unwrap_or(false)); } OP_NOT => { // r[p2] = NOT r[p1] - if let Some(a) = r.get(&p1).copied() { + if let Some(a) = r.get(&p1).cloned() { r.insert(p2, a); let val = n.get(&p1).copied().unwrap_or(true); n.insert(p2, val); @@ -202,9 +269,16 
@@ pub(super) async fn explain( OP_BIT_AND | OP_BIT_OR | OP_SHIFT_LEFT | OP_SHIFT_RIGHT | OP_ADD | OP_SUBTRACT | OP_MULTIPLY | OP_DIVIDE | OP_REMAINDER | OP_CONCAT => { // r[p3] = r[p1] + r[p2] - match (r.get(&p1).copied(), r.get(&p2).copied()) { + match (r.get(&p1).cloned(), r.get(&p2).cloned()) { (Some(a), Some(b)) => { - r.insert(p3, if matches!(a, DataType::Null) { b } else { a }); + r.insert( + p3, + if matches!(a, RegDataType::Single(DataType::Null)) { + b + } else { + a + }, + ); } (Some(v), None) => { @@ -252,7 +326,11 @@ pub(super) async fn explain( if let Some(result) = result { for i in result { - output.push(SqliteTypeInfo(r.remove(&i).unwrap_or(DataType::Null))); + output.push(SqliteTypeInfo( + r.remove(&i) + .map(|d| d.map_to_datatype()) + .unwrap_or(DataType::Null), + )); nullable.push(n.remove(&i)); } } diff --git a/sqlx-core/src/sqlite/connection/handle.rs b/sqlx-core/src/sqlite/connection/handle.rs index 6aa8f37667..c714fcc5f4 100644 --- a/sqlx-core/src/sqlite/connection/handle.rs +++ b/sqlx-core/src/sqlite/connection/handle.rs @@ -3,11 +3,24 @@ use std::ptr::NonNull; use libsqlite3_sys::{sqlite3, sqlite3_close, SQLITE_OK}; use crate::sqlite::SqliteError; +use std::sync::Arc; /// Managed handle to the raw SQLite3 database handle. -/// The database handle will be closed when this is dropped. +/// The database handle will be closed when this is dropped and no `ConnectionHandleRef`s exist. #[derive(Debug)] -pub(crate) struct ConnectionHandle(pub(super) NonNull); +pub(crate) struct ConnectionHandle(Arc); + +/// A wrapper around `ConnectionHandle` which only exists for a `StatementWorker` to own +/// which prevents the `sqlite3` handle from being finalized while it is running `sqlite3_step()` +/// or `sqlite3_reset()`. +/// +/// Note that this does *not* actually give access to the database handle! +#[derive(Clone, Debug)] +pub(crate) struct ConnectionHandleRef(Arc); + +// Wrapper for `*mut sqlite3` which finalizes the handle on-drop. 
+#[derive(Debug)] +struct HandleInner(NonNull); // A SQLite3 handle is safe to send between threads, provided not more than // one is accessing it at the same time. This is upheld as long as [SQLITE_CONFIG_MULTITHREAD] is @@ -20,19 +33,32 @@ pub(crate) struct ConnectionHandle(pub(super) NonNull); unsafe impl Send for ConnectionHandle {} +// SAFETY: `Arc` normally only implements `Send` where `T: Sync` because it allows +// concurrent access. +// +// However, in this case we're only using `Arc` to prevent the database handle from being +// finalized while the worker still holds a statement handle; `ConnectionHandleRef` thus +// should *not* actually provide access to the database handle. +unsafe impl Send for ConnectionHandleRef {} + impl ConnectionHandle { #[inline] pub(super) unsafe fn new(ptr: *mut sqlite3) -> Self { - Self(NonNull::new_unchecked(ptr)) + Self(Arc::new(HandleInner(NonNull::new_unchecked(ptr)))) } #[inline] pub(crate) fn as_ptr(&self) -> *mut sqlite3 { - self.0.as_ptr() + self.0 .0.as_ptr() + } + + #[inline] + pub(crate) fn to_ref(&self) -> ConnectionHandleRef { + ConnectionHandleRef(Arc::clone(&self.0)) } } -impl Drop for ConnectionHandle { +impl Drop for HandleInner { fn drop(&mut self) { unsafe { // https://sqlite.org/c3ref/close.html diff --git a/sqlx-core/src/sqlite/connection/mod.rs b/sqlx-core/src/sqlite/connection/mod.rs index 92926beef4..e001f08fa3 100644 --- a/sqlx-core/src/sqlite/connection/mod.rs +++ b/sqlx-core/src/sqlite/connection/mod.rs @@ -17,7 +17,7 @@ mod executor; mod explain; mod handle; -pub(crate) use handle::ConnectionHandle; +pub(crate) use handle::{ConnectionHandle, ConnectionHandleRef}; /// A connection to a [Sqlite] database. 
pub struct SqliteConnection { @@ -62,9 +62,15 @@ impl Connection for SqliteConnection { type Options = SqliteConnectOptions; - fn close(self) -> BoxFuture<'static, Result<(), Error>> { - // nothing explicit to do; connection will close in drop - Box::pin(future::ok(())) + fn close(mut self) -> BoxFuture<'static, Result<(), Error>> { + Box::pin(async move { + let shutdown = self.worker.shutdown(); + // Drop the statement worker and any outstanding statements, which should + // cover all references to the connection handle outside of the worker thread + drop(self); + // Ensure the worker thread has terminated + shutdown.await + }) } fn ping(&mut self) -> BoxFuture<'_, Result<(), Error>> { @@ -104,8 +110,7 @@ impl Connection for SqliteConnection { impl Drop for SqliteConnection { fn drop(&mut self) { - // before the connection handle is dropped, - // we must explicitly drop the statements as the drop-order in a struct is undefined + // explicitly drop statements before the connection handle is dropped self.statements.clear(); self.statement.take(); } diff --git a/sqlx-core/src/sqlite/mod.rs b/sqlx-core/src/sqlite/mod.rs index 6b31ff02b5..5be8cbfd92 100644 --- a/sqlx-core/src/sqlite/mod.rs +++ b/sqlx-core/src/sqlite/mod.rs @@ -5,6 +5,8 @@ // invariants. #![allow(unsafe_code)] +use crate::executor::Executor; + mod arguments; mod column; mod connection; @@ -43,6 +45,10 @@ pub type SqlitePool = crate::pool::Pool; /// An alias for [`PoolOptions`][crate::pool::PoolOptions], specialized for SQLite. pub type SqlitePoolOptions = crate::pool::PoolOptions; +/// An alias for [`Executor<'_, Database = Sqlite>`][Executor]. 
+pub trait SqliteExecutor<'c>: Executor<'c, Database = Sqlite> {} +impl<'c, T: Executor<'c, Database = Sqlite>> SqliteExecutor<'c> for T {} + // NOTE: required due to the lack of lazy normalization impl_into_arguments_for_arguments!(SqliteArguments<'q>); impl_executor_for_pool_connection!(Sqlite, SqliteConnection, SqliteRow); diff --git a/sqlx-core/src/sqlite/options/connect.rs b/sqlx-core/src/sqlite/options/connect.rs index 6c29120a2f..cbd465ec31 100644 --- a/sqlx-core/src/sqlite/options/connect.rs +++ b/sqlx-core/src/sqlite/options/connect.rs @@ -18,20 +18,12 @@ impl ConnectOptions for SqliteConnectOptions { let mut conn = establish(self).await?; // send an initial sql statement comprised of options - // - // page_size must be set before any other action on the database. - // - // Note that locking_mode should be set before journal_mode; see - // https://www.sqlite.org/wal.html#use_of_wal_without_shared_memory . - let init = format!( - "PRAGMA page_size = {}; PRAGMA locking_mode = {}; PRAGMA journal_mode = {}; PRAGMA foreign_keys = {}; PRAGMA synchronous = {}; PRAGMA auto_vacuum = {}", - self.page_size, - self.locking_mode.as_str(), - self.journal_mode.as_str(), - if self.foreign_keys { "ON" } else { "OFF" }, - self.synchronous.as_str(), - self.auto_vacuum.as_str(), - ); + let mut init = String::new(); + + for (key, value) in self.pragmas.iter() { + use std::fmt::Write; + write!(init, "PRAGMA {} = {}; ", key, value).ok(); + } conn.execute(&*init).await?; diff --git a/sqlx-core/src/sqlite/options/mod.rs b/sqlx-core/src/sqlite/options/mod.rs index ba50bc05d6..9db122f355 100644 --- a/sqlx-core/src/sqlite/options/mod.rs +++ b/sqlx-core/src/sqlite/options/mod.rs @@ -14,6 +14,8 @@ pub use locking_mode::SqliteLockingMode; use std::{borrow::Cow, time::Duration}; pub use synchronous::SqliteSynchronous; +use indexmap::IndexMap; + /// Options and flags which can be used to configure a SQLite connection. 
/// /// A value of `SqliteConnectOptions` can be parsed from a connection URI, @@ -53,16 +55,13 @@ pub struct SqliteConnectOptions { pub(crate) in_memory: bool, pub(crate) read_only: bool, pub(crate) create_if_missing: bool, - pub(crate) journal_mode: SqliteJournalMode, - pub(crate) locking_mode: SqliteLockingMode, - pub(crate) foreign_keys: bool, pub(crate) shared_cache: bool, pub(crate) statement_cache_capacity: usize, pub(crate) busy_timeout: Duration, pub(crate) log_settings: LogSettings, - pub(crate) synchronous: SqliteSynchronous, - pub(crate) auto_vacuum: SqliteAutoVacuum, - pub(crate) page_size: u32, + pub(crate) immutable: bool, + pub(crate) pragmas: IndexMap, Cow<'static, str>>, + pub(crate) serialized: bool, } impl Default for SqliteConnectOptions { @@ -73,21 +72,45 @@ impl Default for SqliteConnectOptions { impl SqliteConnectOptions { pub fn new() -> Self { + // set default pragmas + let mut pragmas: IndexMap, Cow<'static, str>> = IndexMap::new(); + + let locking_mode: SqliteLockingMode = Default::default(); + let auto_vacuum: SqliteAutoVacuum = Default::default(); + + // page_size must be set before any other action on the database. + pragmas.insert("page_size".into(), "4096".into()); + + // Note that locking_mode should be set before journal_mode; see + // https://www.sqlite.org/wal.html#use_of_wal_without_shared_memory . 
+ pragmas.insert("locking_mode".into(), locking_mode.as_str().into()); + + pragmas.insert( + "journal_mode".into(), + SqliteJournalMode::Wal.as_str().into(), + ); + + pragmas.insert("foreign_keys".into(), "ON".into()); + + pragmas.insert( + "synchronous".into(), + SqliteSynchronous::Full.as_str().into(), + ); + + pragmas.insert("auto_vacuum".into(), auto_vacuum.as_str().into()); + Self { filename: Cow::Borrowed(Path::new(":memory:")), in_memory: false, read_only: false, create_if_missing: false, - foreign_keys: true, shared_cache: false, statement_cache_capacity: 100, - journal_mode: SqliteJournalMode::Wal, - locking_mode: Default::default(), busy_timeout: Duration::from_secs(5), log_settings: Default::default(), - synchronous: SqliteSynchronous::Full, - auto_vacuum: Default::default(), - page_size: 4096, + immutable: false, + pragmas, + serialized: false, } } @@ -101,7 +124,10 @@ impl SqliteConnectOptions { /// /// By default, this is enabled. pub fn foreign_keys(mut self, on: bool) -> Self { - self.foreign_keys = on; + self.pragmas.insert( + "foreign_keys".into(), + (if on { "ON" } else { "OFF" }).into(), + ); self } @@ -118,7 +144,8 @@ impl SqliteConnectOptions { /// The default journal mode is WAL. For most use cases this can be significantly faster but /// there are [disadvantages](https://www.sqlite.org/wal.html). pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self { - self.journal_mode = mode; + self.pragmas + .insert("journal_mode".into(), mode.as_str().into()); self } @@ -126,7 +153,8 @@ impl SqliteConnectOptions { /// /// The default locking mode is NORMAL. pub fn locking_mode(mut self, mode: SqliteLockingMode) -> Self { - self.locking_mode = mode; + self.pragmas + .insert("locking_mode".into(), mode.as_str().into()); self } @@ -171,7 +199,8 @@ impl SqliteConnectOptions { /// The default synchronous settings is FULL. However, if durability is not a concern, /// then NORMAL is normally all one needs in WAL mode. 
pub fn synchronous(mut self, synchronous: SqliteSynchronous) -> Self { - self.synchronous = synchronous; + self.pragmas + .insert("synchronous".into(), synchronous.as_str().into()); self } @@ -179,7 +208,8 @@ impl SqliteConnectOptions { /// /// The default auto_vacuum setting is NONE. pub fn auto_vacuum(mut self, auto_vacuum: SqliteAutoVacuum) -> Self { - self.auto_vacuum = auto_vacuum; + self.pragmas + .insert("auto_vacuum".into(), auto_vacuum.as_str().into()); self } @@ -187,7 +217,33 @@ impl SqliteConnectOptions { /// /// The default page_size setting is 4096. pub fn page_size(mut self, page_size: u32) -> Self { - self.page_size = page_size; + self.pragmas + .insert("page_size".into(), page_size.to_string().into()); + self + } + + /// Sets custom initial pragma for the database connection. + pub fn pragma(mut self, key: K, value: V) -> Self + where + K: Into>, + V: Into>, + { + self.pragmas.insert(key.into(), value.into()); + self + } + + pub fn immutable(mut self, immutable: bool) -> Self { + self.immutable = immutable; + self + } + + /// Sets the [threading mode](https://www.sqlite.org/threadsafe.html) for the database connection. + /// + /// The default setting is `false` corersponding to using `OPEN_NOMUTEX`, if `true` then `OPEN_FULLMUTEX`. + /// + /// See [open](https://www.sqlite.org/c3ref/open.html) for more details. 
+ pub fn serialized(mut self, serialized: bool) -> Self { + self.serialized = serialized; self } } diff --git a/sqlx-core/src/sqlite/options/parse.rs b/sqlx-core/src/sqlite/options/parse.rs index 7c21adf469..f677df62b6 100644 --- a/sqlx-core/src/sqlite/options/parse.rs +++ b/sqlx-core/src/sqlite/options/parse.rs @@ -94,6 +94,20 @@ impl FromStr for SqliteConnectOptions { } }, + "immutable" => match &*value { + "true" | "1" => { + options.immutable = true; + } + "false" | "0" => { + options.immutable = false; + } + _ => { + return Err(Error::Configuration( + format!("unknown value {:?} for `immutable`", value).into(), + )); + } + }, + _ => { return Err(Error::Configuration( format!( diff --git a/sqlx-core/src/sqlite/row.rs b/sqlx-core/src/sqlite/row.rs index 9f14ca58f0..4199915fe1 100644 --- a/sqlx-core/src/sqlite/row.rs +++ b/sqlx-core/src/sqlite/row.rs @@ -11,7 +11,7 @@ use crate::column::ColumnIndex; use crate::error::Error; use crate::ext::ustr::UStr; use crate::row::Row; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementHandleRef}; use crate::sqlite::{Sqlite, SqliteColumn, SqliteValue, SqliteValueRef}; /// Implementation of [`Row`] for SQLite. @@ -23,7 +23,7 @@ pub struct SqliteRow { // IF the user drops the Row before iterating the stream (so // nearly all of our internal stream iterators), the executor moves on; otherwise, // it actually inflates this row with a list of owned sqlite3 values. 
- pub(crate) statement: StatementHandle, + pub(crate) statement: StatementHandleRef, pub(crate) values: Arc>, pub(crate) num_values: usize, @@ -48,7 +48,7 @@ impl SqliteRow { // returns a weak reference to an atomic list where the executor should inflate if its going // to increment the statement with [step] pub(crate) fn current( - statement: StatementHandle, + statement: StatementHandleRef, columns: &Arc>, column_names: &Arc>, ) -> (Self, Weak>) { diff --git a/sqlx-core/src/sqlite/statement/handle.rs b/sqlx-core/src/sqlite/statement/handle.rs index d1af117a7d..27e7b59020 100644 --- a/sqlx-core/src/sqlite/statement/handle.rs +++ b/sqlx-core/src/sqlite/statement/handle.rs @@ -1,5 +1,6 @@ use std::ffi::c_void; use std::ffi::CStr; + use std::os::raw::{c_char, c_int}; use std::ptr; use std::ptr::NonNull; @@ -9,21 +10,34 @@ use std::str::{from_utf8, from_utf8_unchecked}; use libsqlite3_sys::{ sqlite3, sqlite3_bind_blob64, sqlite3_bind_double, sqlite3_bind_int, sqlite3_bind_int64, sqlite3_bind_null, sqlite3_bind_parameter_count, sqlite3_bind_parameter_name, - sqlite3_bind_text64, sqlite3_changes, sqlite3_column_blob, sqlite3_column_bytes, - sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_decltype, - sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, sqlite3_column_name, - sqlite3_column_origin_name, sqlite3_column_table_name, sqlite3_column_type, - sqlite3_column_value, sqlite3_db_handle, sqlite3_reset, sqlite3_sql, sqlite3_stmt, - sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, SQLITE_OK, - SQLITE_TRANSIENT, SQLITE_UTF8, + sqlite3_bind_text64, sqlite3_changes, sqlite3_clear_bindings, sqlite3_column_blob, + sqlite3_column_bytes, sqlite3_column_count, sqlite3_column_database_name, + sqlite3_column_decltype, sqlite3_column_double, sqlite3_column_int, sqlite3_column_int64, + sqlite3_column_name, sqlite3_column_origin_name, sqlite3_column_table_name, + sqlite3_column_type, sqlite3_column_value, sqlite3_db_handle, 
sqlite3_finalize, sqlite3_sql, + sqlite3_stmt, sqlite3_stmt_readonly, sqlite3_table_column_metadata, sqlite3_value, + SQLITE_MISUSE, SQLITE_OK, SQLITE_TRANSIENT, SQLITE_UTF8, }; use crate::error::{BoxDynError, Error}; +use crate::sqlite::connection::ConnectionHandleRef; use crate::sqlite::type_info::DataType; use crate::sqlite::{SqliteError, SqliteTypeInfo}; - -#[derive(Debug, Copy, Clone)] -pub(crate) struct StatementHandle(pub(super) NonNull); +use std::ops::Deref; +use std::sync::Arc; + +#[derive(Debug)] +pub(crate) struct StatementHandle(NonNull); + +// wrapper for `Arc` which also holds a reference to the `ConnectionHandle` +#[derive(Clone, Debug)] +pub(crate) struct StatementHandleRef { + // NOTE: the ordering of fields here determines the drop order: + // https://doc.rust-lang.org/reference/destructors.html#destructors + // the statement *must* be dropped before the connection + statement: Arc, + connection: ConnectionHandleRef, +} // access to SQLite3 statement handles are safe to send and share between threads // as long as the `sqlite3_step` call is serialized. @@ -32,6 +46,14 @@ unsafe impl Send for StatementHandle {} unsafe impl Sync for StatementHandle {} impl StatementHandle { + pub(super) fn new(ptr: NonNull) -> Self { + Self(ptr) + } + + pub(crate) fn as_ptr(&self) -> *mut sqlite3_stmt { + self.0.as_ptr() + } + #[inline] pub(super) unsafe fn db_handle(&self) -> *mut sqlite3 { // O(c) access to the connection handle for this statement handle @@ -280,7 +302,44 @@ impl StatementHandle { Ok(from_utf8(self.column_blob(index))?) 
} - pub(crate) fn reset(&self) { - unsafe { sqlite3_reset(self.0.as_ptr()) }; + pub(crate) fn clear_bindings(&self) { + unsafe { sqlite3_clear_bindings(self.0.as_ptr()) }; + } + + pub(crate) fn to_ref( + self: &Arc, + conn: ConnectionHandleRef, + ) -> StatementHandleRef { + StatementHandleRef { + statement: Arc::clone(self), + connection: conn, + } + } +} + +impl Drop for StatementHandle { + fn drop(&mut self) { + // SAFETY: we have exclusive access to the `StatementHandle` here + unsafe { + // https://sqlite.org/c3ref/finalize.html + let status = sqlite3_finalize(self.0.as_ptr()); + if status == SQLITE_MISUSE { + // Panic in case of detected misuse of SQLite API. + // + // sqlite3_finalize returns it at least in the + // case of detected double free, i.e. calling + // sqlite3_finalize on already finalized + // statement. + panic!("Detected sqlite3_finalize misuse."); + } + } + } +} + +impl Deref for StatementHandleRef { + type Target = StatementHandle; + + fn deref(&self) -> &Self::Target { + &self.statement } } diff --git a/sqlx-core/src/sqlite/statement/mod.rs b/sqlx-core/src/sqlite/statement/mod.rs index dec11dcc17..97ca9f8685 100644 --- a/sqlx-core/src/sqlite/statement/mod.rs +++ b/sqlx-core/src/sqlite/statement/mod.rs @@ -12,7 +12,7 @@ mod handle; mod r#virtual; mod worker; -pub(crate) use handle::StatementHandle; +pub(crate) use handle::{StatementHandle, StatementHandleRef}; pub(crate) use r#virtual::VirtualStatement; pub(crate) use worker::StatementWorker; diff --git a/sqlx-core/src/sqlite/statement/virtual.rs b/sqlx-core/src/sqlite/statement/virtual.rs index 0063e06508..3da6d33d64 100644 --- a/sqlx-core/src/sqlite/statement/virtual.rs +++ b/sqlx-core/src/sqlite/statement/virtual.rs @@ -3,13 +3,12 @@ use crate::error::Error; use crate::ext::ustr::UStr; use crate::sqlite::connection::ConnectionHandle; -use crate::sqlite::statement::StatementHandle; +use crate::sqlite::statement::{StatementHandle, StatementWorker}; use crate::sqlite::{SqliteColumn, 
SqliteError, SqliteRow, SqliteValue}; use crate::HashMap; use bytes::{Buf, Bytes}; use libsqlite3_sys::{ - sqlite3, sqlite3_clear_bindings, sqlite3_finalize, sqlite3_prepare_v3, sqlite3_reset, - sqlite3_stmt, SQLITE_MISUSE, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, + sqlite3, sqlite3_prepare_v3, sqlite3_stmt, SQLITE_OK, SQLITE_PREPARE_PERSISTENT, }; use smallvec::SmallVec; use std::i32; @@ -31,7 +30,7 @@ pub(crate) struct VirtualStatement { // underlying sqlite handles for each inner statement // a SQL query string in SQLite is broken up into N statements // we use a [`SmallVec`] to optimize for the most likely case of a single statement - pub(crate) handles: SmallVec<[StatementHandle; 1]>, + pub(crate) handles: SmallVec<[Arc; 1]>, // each set of columns pub(crate) columns: SmallVec<[Arc>; 1]>, @@ -92,7 +91,7 @@ fn prepare( query.advance(n); if let Some(handle) = NonNull::new(statement_handle) { - return Ok(Some(StatementHandle(handle))); + return Ok(Some(StatementHandle::new(handle))); } } @@ -126,7 +125,7 @@ impl VirtualStatement { conn: &mut ConnectionHandle, ) -> Result< Option<( - &StatementHandle, + &Arc, &mut Arc>, &Arc>, &mut Option>>, @@ -159,7 +158,7 @@ impl VirtualStatement { column_names.insert(name, i); } - self.handles.push(statement); + self.handles.push(Arc::new(statement)); self.columns.push(Arc::new(columns)); self.column_names.push(Arc::new(column_names)); self.last_row_values.push(None); @@ -177,20 +176,20 @@ impl VirtualStatement { ))) } - pub(crate) fn reset(&mut self) { + pub(crate) async fn reset(&mut self, worker: &mut StatementWorker) -> Result<(), Error> { self.index = 0; for (i, handle) in self.handles.iter().enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - unsafe { - // Reset A Prepared Statement Object - // https://www.sqlite.org/c3ref/reset.html - // https://www.sqlite.org/c3ref/clear_bindings.html - sqlite3_reset(handle.0.as_ptr()); - sqlite3_clear_bindings(handle.0.as_ptr()); - } + 
// Reset A Prepared Statement Object + // https://www.sqlite.org/c3ref/reset.html + // https://www.sqlite.org/c3ref/clear_bindings.html + worker.reset(handle).await?; + handle.clear_bindings(); } + + Ok(()) } } @@ -198,20 +197,6 @@ impl Drop for VirtualStatement { fn drop(&mut self) { for (i, handle) in self.handles.drain(..).enumerate() { SqliteRow::inflate_if_needed(&handle, &self.columns[i], self.last_row_values[i].take()); - - unsafe { - // https://sqlite.org/c3ref/finalize.html - let status = sqlite3_finalize(handle.0.as_ptr()); - if status == SQLITE_MISUSE { - // Panic in case of detected misuse of SQLite API. - // - // sqlite3_finalize returns it at least in the - // case of detected double free, i.e. calling - // sqlite3_finalize on already finalized - // statement. - panic!("Detected sqlite3_finalize misuse."); - } - } } } } diff --git a/sqlx-core/src/sqlite/statement/worker.rs b/sqlx-core/src/sqlite/statement/worker.rs index 8b1d229978..5a06f637b0 100644 --- a/sqlx-core/src/sqlite/statement/worker.rs +++ b/sqlx-core/src/sqlite/statement/worker.rs @@ -3,9 +3,14 @@ use crate::sqlite::statement::StatementHandle; use crossbeam_channel::{unbounded, Sender}; use either::Either; use futures_channel::oneshot; -use libsqlite3_sys::{sqlite3_step, SQLITE_DONE, SQLITE_ROW}; +use std::sync::{Arc, Weak}; use std::thread; +use crate::sqlite::connection::ConnectionHandleRef; + +use libsqlite3_sys::{sqlite3_reset, sqlite3_step, SQLITE_DONE, SQLITE_ROW}; +use std::future::Future; + // Each SQLite connection has a dedicated thread. 
// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce @@ -18,31 +23,70 @@ pub(crate) struct StatementWorker { enum StatementWorkerCommand { Step { - statement: StatementHandle, + statement: Weak, tx: oneshot::Sender, Error>>, }, + Reset { + statement: Weak, + tx: oneshot::Sender<()>, + }, + Shutdown { + tx: oneshot::Sender<()>, + }, } impl StatementWorker { - pub(crate) fn new() -> Self { + pub(crate) fn new(conn: ConnectionHandleRef) -> Self { let (tx, rx) = unbounded(); thread::spawn(move || { for cmd in rx { match cmd { StatementWorkerCommand::Step { statement, tx } => { - let status = unsafe { sqlite3_step(statement.0.as_ptr()) }; + let statement = if let Some(statement) = statement.upgrade() { + statement + } else { + // statement is already finalized, the sender shouldn't be expecting a response + continue; + }; - let resp = match status { + // SAFETY: only the `StatementWorker` calls this function + let status = unsafe { sqlite3_step(statement.as_ptr()) }; + let result = match status { SQLITE_ROW => Ok(Either::Right(())), SQLITE_DONE => Ok(Either::Left(statement.changes())), _ => Err(statement.last_error().into()), }; - let _ = tx.send(resp); + let _ = tx.send(result); + } + StatementWorkerCommand::Reset { statement, tx } => { + if let Some(statement) = statement.upgrade() { + // SAFETY: this must be the only place we call `sqlite3_reset` + unsafe { sqlite3_reset(statement.as_ptr()) }; + + // `sqlite3_reset()` always returns either `SQLITE_OK` + // or the last error code for the statement, + // which should have already been handled; + // so it's assumed the return value is safe to ignore. 
+ // + // https://www.sqlite.org/c3ref/reset.html + + let _ = tx.send(()); + } + } + StatementWorkerCommand::Shutdown { tx } => { + // drop the connection reference before sending confirmation + // and ending the command loop + drop(conn); + let _ = tx.send(()); + return; } } } + + // SAFETY: we need to make sure a strong ref to `conn` always outlives anything in `rx` + drop(conn); }); Self { tx } @@ -50,14 +94,68 @@ impl StatementWorker { pub(crate) async fn step( &mut self, - statement: StatementHandle, + statement: &Arc, ) -> Result, Error> { let (tx, rx) = oneshot::channel(); self.tx - .send(StatementWorkerCommand::Step { statement, tx }) + .send(StatementWorkerCommand::Step { + statement: Arc::downgrade(statement), + tx, + }) .map_err(|_| Error::WorkerCrashed)?; rx.await.map_err(|_| Error::WorkerCrashed)? } + + /// Send a command to the worker to execute `sqlite3_reset()` next. + /// + /// This method is written to execute the sending of the command eagerly so + /// you do not need to await the returned future unless you want to. + /// + /// The only error is `WorkerCrashed` as `sqlite3_reset()` returns the last error + /// in the statement execution which should have already been handled from `step()`. + pub(crate) fn reset( + &mut self, + statement: &Arc, + ) -> impl Future> { + // execute the sending eagerly so we don't need to spawn the future + let (tx, rx) = oneshot::channel(); + + let send_res = self + .tx + .send(StatementWorkerCommand::Reset { + statement: Arc::downgrade(statement), + tx, + }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } + + /// Send a command to the worker to shut down the processing thread. + /// + /// A `WorkerCrashed` error may be returned if the thread has already stopped. + /// Subsequent calls to `step()`, `reset()`, or this method will fail with + /// `WorkerCrashed`. Ensure that any associated statements are dropped first. 
+ pub(crate) fn shutdown(&mut self) -> impl Future> { + let (tx, rx) = oneshot::channel(); + + let send_res = self + .tx + .send(StatementWorkerCommand::Shutdown { tx }) + .map_err(|_| Error::WorkerCrashed); + + async move { + send_res?; + + // wait for the response + rx.await.map_err(|_| Error::WorkerCrashed) + } + } } diff --git a/sqlx-core/src/sqlite/types/chrono.rs b/sqlx-core/src/sqlite/types/chrono.rs index cd01c3bde2..1ebb2c4f45 100644 --- a/sqlx-core/src/sqlite/types/chrono.rs +++ b/sqlx-core/src/sqlite/types/chrono.rs @@ -76,7 +76,7 @@ impl Encode<'_, Sqlite> for NaiveDate { impl Encode<'_, Sqlite> for NaiveTime { fn encode_by_ref(&self, buf: &mut Vec>) -> IsNull { - Encode::::encode(self.format("%T%.f%").to_string(), buf) + Encode::::encode(self.format("%T%.f").to_string(), buf) } } @@ -179,9 +179,11 @@ impl<'r> Decode<'r, Sqlite> for NaiveTime { // Loop over common time patterns, inspired by Diesel // https://github.com/diesel-rs/diesel/blob/93ab183bcb06c69c0aee4a7557b6798fd52dd0d8/diesel/src/sqlite/types/date_and_time/chrono.rs#L29-L47 + #[rustfmt::skip] // don't like how rustfmt mangles the comments let sqlite_time_formats = &[ // Most likely format - "%T.f", // Other formats in order of appearance in docs + "%T.f", "%T%.f", + // Other formats in order of appearance in docs "%R", "%RZ", "%T%.fZ", "%R%:z", "%T%.f%:z", ]; diff --git a/sqlx-core/src/sqlite/types/str.rs b/sqlx-core/src/sqlite/types/str.rs index 6a3ed533b1..086597ef10 100644 --- a/sqlx-core/src/sqlite/types/str.rs +++ b/sqlx-core/src/sqlite/types/str.rs @@ -52,3 +52,23 @@ impl<'r> Decode<'r, Sqlite> for String { value.text().map(ToOwned::to_owned) } } + +impl<'q> Encode<'q, Sqlite> for Cow<'q, str> { + fn encode(self, args: &mut Vec>) -> IsNull { + args.push(SqliteArgumentValue::Text(self)); + + IsNull::No + } + + fn encode_by_ref(&self, args: &mut Vec>) -> IsNull { + args.push(SqliteArgumentValue::Text(self.clone())); + + IsNull::No + } +} + +impl<'r> Decode<'r, Sqlite> for Cow<'r, str> { 
+ fn decode(value: SqliteValueRef<'r>) -> Result { + value.text().map(Cow::Borrowed) + } +} diff --git a/sqlx-core/src/transaction.rs b/sqlx-core/src/transaction.rs index 76d5ac85ba..77b2a3570f 100644 --- a/sqlx-core/src/transaction.rs +++ b/sqlx-core/src/transaction.rs @@ -2,7 +2,10 @@ use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::ops::{Deref, DerefMut}; +#[cfg(not(target_arch = "wasm32"))] use futures_core::future::BoxFuture; +#[cfg(target_arch = "wasm32")] +use futures_core::future::LocalBoxFuture as BoxFuture; use crate::database::Database; use crate::error::Error; @@ -95,6 +98,7 @@ where } // NOTE: required due to lack of lazy normalization +#[cfg(not(target_arch = "wasm32"))] #[allow(unused_macros)] macro_rules! impl_executor_for_transaction { ($DB:ident, $Row:ident) => { @@ -165,6 +169,77 @@ macro_rules! impl_executor_for_transaction { }; } +#[cfg(target_arch = "wasm32")] +#[allow(unused_macros)] +macro_rules! impl_executor_for_transaction { + ($DB:ident, $Row:ident) => { + impl<'c, 't> crate::executor::Executor<'t> + for &'t mut crate::transaction::Transaction<'c, $DB> + { + type Database = $DB; + + fn fetch_many<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> futures_core::stream::LocalBoxStream< + 'e, + Result< + either::Either<<$DB as crate::database::Database>::QueryResult, $Row>, + crate::error::Error, + >, + > + where + 't: 'e, + E: crate::executor::Execute<'q, Self::Database>, + { + (&mut **self).fetch_many(query) + } + + fn fetch_optional<'e, 'q: 'e, E: 'q>( + self, + query: E, + ) -> futures_core::future::LocalBoxFuture<'e, Result, crate::error::Error>> + where + 't: 'e, + E: crate::executor::Execute<'q, Self::Database>, + { + (&mut **self).fetch_optional(query) + } + + fn prepare_with<'e, 'q: 'e>( + self, + sql: &'q str, + parameters: &'e [::TypeInfo], + ) -> futures_core::future::LocalBoxFuture< + 'e, + Result< + >::Statement, + crate::error::Error, + >, + > + where + 't: 'e, + { + (&mut **self).prepare_with(sql, 
parameters) + } + + #[doc(hidden)] + fn describe<'e, 'q: 'e>( + self, + query: &'q str, + ) -> futures_core::future::LocalBoxFuture< + 'e, + Result, crate::error::Error>, + > + where + 't: 'e, + { + (&mut **self).describe(query) + } + } + }; +} + impl<'c, DB> Debug for Transaction<'c, DB> where DB: Database, diff --git a/sqlx-core/src/types/mod.rs b/sqlx-core/src/types/mod.rs index 600daf0fdd..2bf4e3b5d2 100644 --- a/sqlx-core/src/types/mod.rs +++ b/sqlx-core/src/types/mod.rs @@ -75,6 +75,13 @@ pub mod ipnetwork { pub use ipnetwork::{IpNetwork, Ipv4Network, Ipv6Network}; } +#[cfg(feature = "mac_address")] +#[cfg_attr(docsrs, doc(cfg(feature = "mac_address")))] +pub mod mac_address { + #[doc(no_inline)] + pub use mac_address::MacAddress; +} + #[cfg(feature = "json")] pub use json::Json; diff --git a/sqlx-macros/Cargo.toml b/sqlx-macros/Cargo.toml index 91ae1f6071..099049ea90 100644 --- a/sqlx-macros/Cargo.toml +++ b/sqlx-macros/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlx-macros" -version = "0.5.3" +version = "0.5.9" repository = "https://github.com/launchbadge/sqlx" description = "Macros for SQLx, the rust SQL toolkit. Not intended to be used directly." 
license = "MIT OR Apache-2.0" @@ -72,20 +72,20 @@ decimal = ["sqlx-core/decimal"] chrono = ["sqlx-core/chrono"] time = ["sqlx-core/time"] ipnetwork = ["sqlx-core/ipnetwork"] +mac_address = ["sqlx-core/mac_address"] uuid = ["sqlx-core/uuid"] bit-vec = ["sqlx-core/bit-vec"] json = ["sqlx-core/json", "serde_json"] [dependencies] dotenv = { version = "0.15.0", default-features = false } -futures = { version = "0.3.4", default-features = false, features = ["executor"] } hex = { version = "0.4.2", optional = true } heck = "0.3.1" either = "1.5.3" once_cell = "1.5.2" proc-macro2 = { version = "1.0.9", default-features = false } -sqlx-core = { version = "0.5.3", default-features = false, path = "../sqlx-core" } -sqlx-rt = { version = "0.5.3", default-features = false, path = "../sqlx-rt" } +sqlx-core = { version = "0.5.9", default-features = false, path = "../sqlx-core" } +sqlx-rt = { version = "0.5.9", default-features = false, path = "../sqlx-rt" } serde = { version = "1.0.111", features = ["derive"], optional = true } serde_json = { version = "1.0.30", features = ["preserve_order"], optional = true } sha2 = { version = "0.9.1", optional = true } diff --git a/sqlx-macros/src/database/postgres.rs b/sqlx-macros/src/database/postgres.rs index 05f0a88bd6..5330bb3cd9 100644 --- a/sqlx-macros/src/database/postgres.rs +++ b/sqlx-macros/src/database/postgres.rs @@ -60,6 +60,9 @@ impl_database_ext! { #[cfg(feature = "ipnetwork")] sqlx::types::ipnetwork::IpNetwork, + #[cfg(feature = "mac_address")] + sqlx::types::mac_address::MacAddress, + #[cfg(feature = "json")] serde_json::Value, @@ -113,6 +116,9 @@ impl_database_ext! 
{ #[cfg(feature = "ipnetwork")] Vec | &[sqlx::types::ipnetwork::IpNetwork], + #[cfg(feature = "mac_address")] + Vec | &[sqlx::types::mac_address::MacAddress], + #[cfg(feature = "json")] Vec | &[serde_json::Value], diff --git a/sqlx-macros/src/derives/attributes.rs b/sqlx-macros/src/derives/attributes.rs index bdf2812999..202b6b5a17 100644 --- a/sqlx-macros/src/derives/attributes.rs +++ b/sqlx-macros/src/derives/attributes.rs @@ -15,7 +15,7 @@ macro_rules! assert_attribute { macro_rules! fail { ($t:expr, $m:expr) => { - return Err(syn::Error::new_spanned($t, $m)); + return Err(syn::Error::new_spanned($t, $m)) }; } @@ -216,8 +216,6 @@ pub fn check_transparent_attributes( field ); - assert_attribute!(attributes.repr.is_none(), "unexpected #[repr(..)]", input); - let ch_attributes = parse_child_attributes(&field.attrs)?; assert_attribute!( diff --git a/sqlx-macros/src/lib.rs b/sqlx-macros/src/lib.rs index 8956aa745b..8a4ea4a248 100644 --- a/sqlx-macros/src/lib.rs +++ b/sqlx-macros/src/lib.rs @@ -2,6 +2,10 @@ not(any(feature = "postgres", feature = "mysql", feature = "offline")), allow(dead_code, unused_macros, unused_imports) )] +#![cfg_attr( + any(sqlx_macros_unstable, procmacro2_semver_exempt), + feature(track_path, proc_macro_tracked_env) +)] extern crate proc_macro; use proc_macro::TokenStream; @@ -17,7 +21,7 @@ mod database; mod derives; mod query; -#[cfg(feature = "migrate")] +#[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] mod migrate; #[proc_macro] @@ -74,7 +78,7 @@ pub fn derive_from_row(input: TokenStream) -> TokenStream { } } -#[cfg(feature = "migrate")] +#[cfg(all(feature = "migrate", not(target_arch = "wasm32")))] #[proc_macro] pub fn migrate(input: TokenStream) -> TokenStream { use syn::LitStr; diff --git a/sqlx-macros/src/migrate.rs b/sqlx-macros/src/migrate.rs index f10bc22318..018ba1b41e 100644 --- a/sqlx-macros/src/migrate.rs +++ b/sqlx-macros/src/migrate.rs @@ -24,7 +24,7 @@ struct QuotedMigration { version: i64, description: String, 
migration_type: QuotedMigrationType, - sql: String, + path: String, checksum: Vec, } @@ -34,7 +34,7 @@ impl ToTokens for QuotedMigration { version, description, migration_type, - sql, + path, checksum, } = &self; @@ -43,7 +43,8 @@ impl ToTokens for QuotedMigration { version: #version, description: ::std::borrow::Cow::Borrowed(#description), migration_type: #migration_type, - sql: ::std::borrow::Cow::Borrowed(#sql), + // this tells the compiler to watch this path for changes + sql: ::std::borrow::Cow::Borrowed(include_str!(#path)), checksum: ::std::borrow::Cow::Borrowed(&[ #(#checksum),* ]), @@ -59,7 +60,7 @@ pub(crate) fn expand_migrator_from_dir(dir: LitStr) -> crate::Result crate::Result crate::Result( use ::sqlx::ty_match::{WrapSameExt as _, MatchBorrowExt as _}; // evaluate the expression only once in case it contains moves - let _expr = ::sqlx::ty_match::dupe_value(#name); + let expr = ::sqlx::ty_match::dupe_value(#name); - // if `_expr` is `Option`, get `Option<$ty>`, otherwise `$ty` - let ty_check = ::sqlx::ty_match::WrapSame::<#param_ty, _>::new(&_expr).wrap_same(); + // if `expr` is `Option`, get `Option<$ty>`, otherwise `$ty` + let ty_check = ::sqlx::ty_match::WrapSame::<#param_ty, _>::new(&expr).wrap_same(); - // if `_expr` is `&str`, convert `String` to `&str` - let (mut _ty_check, match_borrow) = ::sqlx::ty_match::MatchBorrow::new(ty_check, &_expr); + // if `expr` is `&str`, convert `String` to `&str` + let (mut _ty_check, match_borrow) = ::sqlx::ty_match::MatchBorrow::new(ty_check, &expr); _ty_check = match_borrow.match_borrow(); diff --git a/sqlx-macros/src/query/data.rs b/sqlx-macros/src/query/data.rs index ee123b503f..9e00e1e81a 100644 --- a/sqlx-macros/src/query/data.rs +++ b/sqlx-macros/src/query/data.rs @@ -61,7 +61,7 @@ pub mod offline { /// Find and deserialize the data table for this query from a shared `sqlx-data.json` /// file. The expected structure is a JSON map keyed by the SHA-256 hash of queries in hex. 
pub fn from_data_file(path: impl AsRef, query: &str) -> crate::Result { - serde_json::Deserializer::from_reader(BufReader::new( + let this = serde_json::Deserializer::from_reader(BufReader::new( File::open(path.as_ref()).map_err(|e| { format!("failed to open path {}: {}", path.as_ref().display(), e) })?, @@ -69,8 +69,22 @@ pub mod offline { .deserialize_map(DataFileVisitor { query, hash: hash_string(query), - }) - .map_err(Into::into) + })?; + + #[cfg(procmacro2_semver_exempt)] + { + let path = path.as_ref().canonicalize()?; + let path = path.to_str().ok_or_else(|| { + format!( + "sqlx-data.json path cannot be represented as a string: {:?}", + path + ) + })?; + + proc_macro::tracked_path::path(path); + } + + Ok(this) + } } diff --git a/sqlx-macros/src/query/input.rs b/sqlx-macros/src/query/input.rs index 86627d60b1..f3bce4a333 100644 --- a/sqlx-macros/src/query/input.rs +++ b/sqlx-macros/src/query/input.rs @@ -8,7 +8,7 @@ use syn::{ExprArray, Type}; /// Macro input shared by `query!()` and `query_file!()` pub struct QueryMacroInput { - pub(super) src: String, + pub(super) sql: String, #[cfg_attr(not(feature = "offline"), allow(dead_code))] pub(super) src_span: Span, @@ -18,6 +18,8 @@ pub struct QueryMacroInput { pub(super) arg_exprs: Vec, pub(super) checked: bool, + + pub(super) file_path: Option, } enum QuerySrc { @@ -94,12 +96,15 @@ impl Parse for QueryMacroInput { let arg_exprs = args.unwrap_or_default(); + let file_path = src.file_path(src_span)?; + Ok(QueryMacroInput { - src: src.resolve(src_span)?, + sql: src.resolve(src_span)?, src_span, record_type, arg_exprs, checked, + file_path, }) } } @@ -112,6 +117,27 @@ impl QuerySrc { QuerySrc::File(file) => read_file_src(&file, source_span), } } + + fn file_path(&self, source_span: Span) -> syn::Result> { + if let QuerySrc::File(ref file) = *self { + let path = crate::common::resolve_path(file, source_span)? 
+ .canonicalize() + .map_err(|e| syn::Error::new(source_span, e))?; + + Ok(Some( + path.to_str() + .ok_or_else(|| { + syn::Error::new( + source_span, + "query file path cannot be represented as a string", + ) + })? + .to_string(), + )) + } else { + Ok(None) + } + } } fn read_file_src(source: &str, source_span: Span) -> syn::Result { diff --git a/sqlx-macros/src/query/mod.rs b/sqlx-macros/src/query/mod.rs index d114f54528..58c5dc5f34 100644 --- a/sqlx-macros/src/query/mod.rs +++ b/sqlx-macros/src/query/mod.rs @@ -1,4 +1,6 @@ use std::path::PathBuf; +#[cfg(feature = "offline")] +use std::sync::{Arc, Mutex}; use once_cell::sync::Lazy; use proc_macro2::TokenStream; @@ -10,6 +12,8 @@ use quote::{format_ident, quote}; use sqlx_core::connection::Connection; use sqlx_core::database::Database; use sqlx_core::{column::Column, describe::Describe, type_info::TypeInfo}; + +#[cfg(not(target_arch = "wasm32"))] use sqlx_rt::block_on; use crate::database::DatabaseExt; @@ -17,77 +21,93 @@ use crate::query::data::QueryData; use crate::query::input::RecordType; use either::Either; +#[cfg(target_arch = "wasm32")] +use {futures::channel::oneshot, sqlx_rt::spawn}; + mod args; mod data; mod input; mod output; struct Metadata { + #[allow(unused)] manifest_dir: PathBuf, offline: bool, database_url: Option, #[cfg(feature = "offline")] target_dir: PathBuf, #[cfg(feature = "offline")] - workspace_root: PathBuf, + workspace_root: Arc>>, +} + +#[cfg(feature = "offline")] +impl Metadata { + pub fn workspace_root(&self) -> PathBuf { + let mut root = self.workspace_root.lock().unwrap(); + if root.is_none() { + use serde::Deserialize; + use std::process::Command; + + let cargo = env("CARGO").expect("`CARGO` must be set"); + + let output = Command::new(&cargo) + .args(&["metadata", "--format-version=1"]) + .current_dir(&self.manifest_dir) + .env_remove("__CARGO_FIX_PLZ") + .output() + .expect("Could not fetch metadata"); + + #[derive(Deserialize)] + struct CargoMetadata { + workspace_root: PathBuf, + 
} + + let metadata: CargoMetadata = + serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output"); + + *root = Some(metadata.workspace_root); + } + root.clone().unwrap() + } } // If we are in a workspace, lookup `workspace_root` since `CARGO_MANIFEST_DIR` won't // reflect the workspace dir: https://github.com/rust-lang/cargo/issues/3946 static METADATA: Lazy = Lazy::new(|| { - use std::env; - - let manifest_dir = env::var("CARGO_MANIFEST_DIR") + let manifest_dir: PathBuf = env("CARGO_MANIFEST_DIR") .expect("`CARGO_MANIFEST_DIR` must be set") .into(); #[cfg(feature = "offline")] - let target_dir = - env::var_os("CARGO_TARGET_DIR").map_or_else(|| "target".into(), |dir| dir.into()); + let target_dir = env("CARGO_TARGET_DIR").map_or_else(|_| "target".into(), |dir| dir.into()); // If a .env file exists at CARGO_MANIFEST_DIR, load environment variables from this, // otherwise fallback to default dotenv behaviour. - let env_path = METADATA.manifest_dir.join(".env"); - if env_path.exists() { + let env_path = manifest_dir.join(".env"); + + #[cfg_attr(not(procmacro2_semver_exempt), allow(unused_variables))] + let env_path = if env_path.exists() { let res = dotenv::from_path(&env_path); if let Err(e) = res { panic!("failed to load environment from {:?}, {}", env_path, e); } + + Some(env_path) } else { - let _ = dotenv::dotenv(); + dotenv::dotenv().ok() + }; + + // tell the compiler to watch the `.env` for changes, if applicable + #[cfg(procmacro2_semver_exempt)] + if let Some(env_path) = env_path.as_ref().and_then(|path| path.to_str()) { + proc_macro::tracked_path::path(env_path); } - // TODO: Switch to `var_os` after feature(osstring_ascii) is stable. 
- // Stabilization PR: https://github.com/rust-lang/rust/pull/80193 - let offline = env::var("SQLX_OFFLINE") + let offline = env("SQLX_OFFLINE") .map(|s| s.eq_ignore_ascii_case("true") || s == "1") .unwrap_or(false); - let database_url = env::var("DATABASE_URL").ok(); - - #[cfg(feature = "offline")] - let workspace_root = { - use serde::Deserialize; - use std::process::Command; - - let cargo = env::var_os("CARGO").expect("`CARGO` must be set"); - - let output = Command::new(&cargo) - .args(&["metadata", "--format-version=1"]) - .current_dir(&manifest_dir) - .output() - .expect("Could not fetch metadata"); - - #[derive(Deserialize)] - struct CargoMetadata { - workspace_root: PathBuf, - } - - let metadata: CargoMetadata = - serde_json::from_slice(&output.stdout).expect("Invalid `cargo metadata` output"); - - metadata.workspace_root - }; + let database_url = env("DATABASE_URL").ok(); Metadata { manifest_dir, @@ -96,7 +116,7 @@ static METADATA: Lazy = Lazy::new(|| { #[cfg(feature = "offline")] target_dir, #[cfg(feature = "offline")] - workspace_root, + workspace_root: Arc::new(Mutex::new(None)), } }); @@ -111,18 +131,20 @@ pub fn expand_input(input: QueryMacroInput) -> crate::Result { #[cfg(feature = "offline")] _ => { let data_file_path = METADATA.manifest_dir.join("sqlx-data.json"); - let workspace_data_file_path = METADATA.workspace_root.join("sqlx-data.json"); if data_file_path.exists() { expand_from_file(input, data_file_path) - } else if workspace_data_file_path.exists() { - expand_from_file(input, workspace_data_file_path) } else { - Err( - "`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \ + let workspace_data_file_path = METADATA.workspace_root().join("sqlx-data.json"); + if workspace_data_file_path.exists() { + expand_from_file(input, workspace_data_file_path) + } else { + Err( + "`DATABASE_URL` must be set, or `cargo sqlx prepare` must have been run \ and sqlx-data.json must exist, to use query macros" - .into(), - ) + .into(), + ) + } 
} } @@ -147,16 +169,35 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { let data = block_on(async { let mut conn = sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.src).await + QueryData::from_db(&mut conn, &input.sql).await })?; expand_with_data(input, data, false) }, + #[cfg(all(feature = "postgres", target_arch = "wasm32"))] + "postgres" | "postgresql" => { + let (tx, mut rx) = oneshot::channel(); + let src = input.src.clone(); + spawn(async move { + let mut conn = match sqlx_core::postgres::PgConnection::connect(db_url.as_str()).await { + Ok(conn) => conn, + _ => return + }; + let _ = tx.send(QueryData::from_db(&mut conn, &src).await); + }); + + if let Some(Ok(data)) = rx.try_recv()? { + expand_with_data(input, data, false) + } else { + Err("unable to connect to database".into()) + } + }, + #[cfg(not(feature = "postgres"))] "postgres" | "postgresql" => Err("database URL has the scheme of a PostgreSQL database but the `postgres` feature is not enabled".into()), @@ -164,7 +205,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { let data = block_on(async { let mut conn = sqlx_core::mssql::MssqlConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.src).await + QueryData::from_db(&mut conn, &input.sql).await })?; expand_with_data(input, data, false) @@ -177,7 +218,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { let data = block_on(async { let mut conn = sqlx_core::mysql::MySqlConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut conn, &input.src).await + QueryData::from_db(&mut conn, &input.sql).await })?; expand_with_data(input, data, false) @@ -190,7 +231,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result { let data = block_on(async { let mut conn = sqlx_core::sqlite::SqliteConnection::connect(db_url.as_str()).await?; - QueryData::from_db(&mut 
conn, &input.src).await + QueryData::from_db(&mut conn, &input.sql).await })?; expand_with_data(input, data, false) @@ -207,7 +248,7 @@ fn expand_from_db(input: QueryMacroInput, db_url: &str) -> crate::Result crate::Result { use data::offline::DynQueryData; - let query_data = DynQueryData::from_data_file(file, &input.src)?; + let query_data = DynQueryData::from_data_file(file, &input.sql)?; assert!(!query_data.db_name.is_empty()); match &*query_data.db_name { @@ -288,7 +329,7 @@ where .all(|it| it.type_info().is_void()) { let db_path = DB::db_path(); - let sql = &input.src; + let sql = &input.sql; quote! { ::sqlx::query_with::<#db_path, _>(#sql, #query_args) @@ -303,7 +344,8 @@ where for rust_col in &columns { if rust_col.type_.is_wildcard() { return Err( - "columns may not have wildcard overrides in `query!()` or `query_as!()" + "wildcard overrides are only allowed with an explicit record type, \ + e.g. `query_as!()` and its variants" .into(), ); } @@ -313,6 +355,7 @@ where |&output::RustColumn { ref ident, ref type_, + .. }| quote!(#ident: #type_,), ); @@ -367,3 +410,16 @@ where Ok(ret_tokens) } + +/// Get the value of an environment variable, telling the compiler about it if applicable. 
+fn env(name: &str) -> Result { + #[cfg(procmacro2_semver_exempt)] + { + proc_macro::tracked_env::var(name) + } + + #[cfg(not(procmacro2_semver_exempt))] + { + std::env::var(name) + } +} diff --git a/sqlx-macros/src/query/output.rs b/sqlx-macros/src/query/output.rs index e7b482e044..f7d56646dd 100644 --- a/sqlx-macros/src/query/output.rs +++ b/sqlx-macros/src/query/output.rs @@ -14,6 +14,7 @@ use syn::Token; pub struct RustColumn { pub(super) ident: Ident, + pub(super) var_name: Ident, pub(super) type_: ColumnType, } @@ -114,6 +115,9 @@ fn column_to_rust(describe: &Describe, i: usize) -> crate:: }; Ok(RustColumn { + // prefix the variable name we use in `quote_query_as!()` so it doesn't conflict + // https://github.com/launchbadge/sqlx/issues/1322 + var_name: quote::format_ident!("sqlx_query_as_{}", decl.ident), ident: decl.ident, type_, }) @@ -129,7 +133,7 @@ pub fn quote_query_as( |( i, &RustColumn { - ref ident, + ref var_name, ref type_, .. }, @@ -140,24 +144,32 @@ pub fn quote_query_as( // binding to a `let` avoids confusing errors about // "try expression alternatives have incompatible types" // it doesn't seem to hurt inference in the other branches - let #ident = row.try_get_unchecked::<#type_, _>(#i)?; + let #var_name = row.try_get_unchecked::<#type_, _>(#i)?; }, // type was overridden to be a wildcard so we fallback to the runtime check - (true, ColumnType::Wildcard) => quote! ( let #ident = row.try_get(#i)?; ), + (true, ColumnType::Wildcard) => quote! ( let #var_name = row.try_get(#i)?; ), (true, ColumnType::OptWildcard) => { - quote! ( let #ident = row.try_get::<::std::option::Option<_>, _>(#i)?; ) + quote! 
( let #var_name = row.try_get::<::std::option::Option<_>, _>(#i)?; ) } // macro is the `_unchecked!()` variant so this will die in decoding if it's wrong - (false, _) => quote!( let #ident = row.try_get_unchecked(#i)?; ), + (false, _) => quote!( let #var_name = row.try_get_unchecked(#i)?; ), } }, ); let ident = columns.iter().map(|col| &col.ident); + let var_name = columns.iter().map(|col| &col.var_name); let db_path = DB::db_path(); let row_path = DB::row_path(); - let sql = &input.src; + + // if this query came from a file, use `include_str!()` to tell the compiler where it came from + let sql = if let Some(ref path) = &input.file_path { + quote::quote_spanned! { input.src_span => include_str!(#path) } + } else { + let sql = &input.sql; + quote! { #sql } + }; quote! { ::sqlx::query_with::<#db_path, _>(#sql, #bind_args).try_map(|row: #row_path| { @@ -165,7 +177,7 @@ pub fn quote_query_as( #(#instantiations)* - Ok(#out_ty { #(#ident: #ident),* }) + Ok(#out_ty { #(#ident: #var_name),* }) }) } } @@ -200,7 +212,7 @@ pub fn quote_query_scalar( }; let db = DB::db_path(); - let query = &input.src; + let query = &input.sql; Ok(quote! { ::sqlx::query_scalar_with::<#db, #ty, _>(#query, #bind_args) diff --git a/sqlx-rt/Cargo.toml b/sqlx-rt/Cargo.toml index 75fa4bc9c4..5d9d8d523a 100644 --- a/sqlx-rt/Cargo.toml +++ b/sqlx-rt/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlx-rt" -version = "0.5.3" +version = "0.5.9" repository = "https://github.com/launchbadge/sqlx" license = "MIT OR Apache-2.0" description = "Runtime abstraction used by SQLx, the Rust SQL toolkit. Not intended to be used directly." 
@@ -40,6 +40,15 @@ tokio-rustls = { version = "0.22.0", optional = true } native-tls = { version = "0.2.4", optional = true } once_cell = { version = "1.4", features = ["std"], optional = true } + +[target.'cfg(target_arch = "wasm32")'.dependencies] +wasm-bindgen = { version = "0.2.71" } +wasm-bindgen-futures = { version = "^0.3", features = ["futures_0_3"] } +futures-util = { version = "0.3.5", features = ["sink", "io"] } +ws_stream_wasm = { version = "0.7" } +async_io_stream = { version = "0.3.1" } +web-sys = { version = "*" } + [dependencies.tokio] version = "1.0.1" features = ["fs", "net", "rt", "rt-multi-thread", "time", "io-util"] diff --git a/sqlx-rt/src/lib.rs b/sqlx-rt/src/lib.rs index 9c9a68a534..ed28d79c73 100644 --- a/sqlx-rt/src/lib.rs +++ b/sqlx-rt/src/lib.rs @@ -37,7 +37,7 @@ pub use native_tls; ))] pub use tokio::{ self, fs, io::AsyncRead, io::AsyncReadExt, io::AsyncWrite, io::AsyncWriteExt, io::ReadBuf, - net::TcpStream, task::spawn, task::yield_now, time::sleep, time::timeout, + net::TcpStream, runtime::Handle, task::spawn, task::yield_now, time::sleep, time::timeout, }; #[cfg(all( @@ -105,7 +105,8 @@ pub use tokio_rustls::{client::TlsStream, TlsConnector}; #[macro_export] macro_rules! blocking { ($($expr:tt)*) => { - $crate::tokio::task::block_in_place(move || { $($expr)* }) + $crate::tokio::task::spawn_blocking(move || { $($expr)* }) + .await.expect("Blocking task failed to complete.") }; } @@ -137,6 +138,7 @@ macro_rules! 
blocking { #[cfg(all( feature = "_rt-async-std", not(any(feature = "_rt-actix", feature = "_rt-tokio")), + not(target_arch = "wasm32") ))] pub use async_std::{ self, fs, future::timeout, io::prelude::ReadExt as AsyncReadExt, @@ -193,3 +195,19 @@ pub use async_native_tls::{TlsConnector, TlsStream}; )), ))] pub use async_rustls::{client::TlsStream, TlsConnector}; + +// +// wasm-bindgen +// +#[cfg(target_arch = "wasm32")] +pub use { + async_io_stream::IoStream, + futures_util::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + pin_mut, + sink::Sink, + }, + wasm_bindgen_futures::futures_0_3::spawn_local as spawn, + web_sys::console, + ws_stream_wasm::{WsMeta, WsStream, WsStreamIo}, +}; diff --git a/sqlx-test/src/lib.rs b/sqlx-test/src/lib.rs index bd4d090cb7..da70155e5f 100644 --- a/sqlx-test/src/lib.rs +++ b/sqlx-test/src/lib.rs @@ -197,7 +197,8 @@ macro_rules! __test_prepared_type { #[macro_export] macro_rules! MySql_query_for_test_prepared_type { () => { - "SELECT {0} <=> ?, {0}, ?" + // MySQL 8.0.27 changed `<=>` to return an unsigned integer + "SELECT CAST({0} <=> ? AS SIGNED INTEGER), {0}, ?" }; } @@ -221,3 +222,125 @@ macro_rules! Postgres_query_for_test_prepared_type { "SELECT ({0} is not distinct from $1)::int4, {0}, $2" }; } + +#[macro_export] +macro_rules! 
time_delete_query { + ($n:expr, $count:literal) => { + let mut conn = new::().await.unwrap(); + + conn.execute("create temp table bench_deletes (id integer, descr text, primary key(id))") + .await; + + conn.execute("create bitmap index id_idx on bench_deletes (id)") + .await; + + let _ = sqlx::query(&format!( + "insert into bench_deletes (id, descr) select generate_series(1,{}) AS id, md5(random()::text) AS descr", + $count + )) + + .execute(&mut conn) + .await; + + let start = Instant::now(); + for _ in 0..3u8 { + for i in 1..$count { + let _ = sqlx::query(&format!( + "delete from bench_deletes where id = {}", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = Instant::now(); + + println!("{}: Avg time is {}", $n, end.duration_since(start).as_millis() / 3u128); + }; +} + +#[macro_export] +macro_rules! time_update_query { + ($n:expr, $count:literal) => { + let mut conn = new::().await.unwrap(); + + conn.execute("create temp table bench_updates (id integer, descr text, primary key(id))") + .await; + + conn.execute("create bitmap index id_idx on bench_updates (id)") + .await; + + let _ = sqlx::query(&format!( + "insert into bench_updates (id, descr) select generate_series(1,{}) AS id, md5(random()::text) AS descr", + $count + )) + .execute(&mut conn) + .await; + + let start = Instant::now(); + for _ in 0..3u8 { + for i in 1..$count { + let _ = sqlx::query(&format!( + "update bench_updates set descr = md5(random()::text) where id = {}", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = Instant::now(); + println!("{}: Avg time is {}", $n, end.duration_since(start).as_millis() / 3u128); + }; +} + +#[macro_export] +macro_rules! 
time_insert_query { + ($n:expr, $count:literal) => { + let mut conn = new::().await.unwrap(); + conn.execute("create temp table bench_inserts (id integer, descr text)") + .await; + + let start = Instant::now(); + + for _ in 0..3u8 { + for i in 0..$count { + let _ = sqlx::query(&format!( + "insert into bench_inserts (id, desc) values ({}, md5(random()::text))", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = Instant::now(); + println!( + "{}: Avg time is {}", + $n, + end.duration_since(start).as_millis() / 3u128 + ); + }; +} + +#[macro_export] +macro_rules! time_query { + ($n:expr, $q:expr) => { + let mut conn = new::().await.unwrap(); + + let start = Instant::now(); + + for _ in 0..3u8 { + let _ = sqlx::query($q).fetch_all(&mut conn).await; + } + + let end = Instant::now(); + + println!( + "{}: Avg time is {}", + $n, + end.duration_since(start).as_millis() / 3u128 + ); + }; +} diff --git a/sqlx-wasm-test/.cargo/config b/sqlx-wasm-test/.cargo/config new file mode 100644 index 0000000000..4ec2f3b862 --- /dev/null +++ b/sqlx-wasm-test/.cargo/config @@ -0,0 +1,2 @@ +[target.wasm32-unknown-unknown] +runner = 'wasm-bindgen-test-runner' diff --git a/sqlx-wasm-test/.gitignore b/sqlx-wasm-test/.gitignore new file mode 100644 index 0000000000..96ef6c0b94 --- /dev/null +++ b/sqlx-wasm-test/.gitignore @@ -0,0 +1,2 @@ +/target +Cargo.lock diff --git a/sqlx-wasm-test/Cargo.toml b/sqlx-wasm-test/Cargo.toml new file mode 100644 index 0000000000..8697ae7d69 --- /dev/null +++ b/sqlx-wasm-test/Cargo.toml @@ -0,0 +1,103 @@ +[package] +name = "sqlx-wasm-test" +version = "0.1.0" +authors = ["abhi"] +edition = "2018" + +[lib] +crate-type = ["cdylib", "rlib"] +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +sqlx = { path = "..", features = ["postgres", "decimal", "bigdecimal", "time", "chrono", "bit-vec", "ipnetwork", "uuid", "json"] } +wasm-bindgen-futures = { version = "^0.3", features = ["futures_0_3"] } 
+wasm-bindgen = { version = "0.2.73" } +ws_stream_wasm = "0.7" +wasm-bindgen-test = "0.3.0" +instant = "0.1.9" +web-sys = { version = "0.3.50", features = ["console", "Performance"] } +futures = "0.3.14" +serde = "1.0.117" +serde_json = { version = "1.0.51", features = ["raw_value"] } +paste = "1.0.1" +time = { version = "0.2.26" } + +[[test]] +name = "selects" +path = "src/selects_bench.rs" + +[[test]] +name = "inserts" +path = "src/inserts_bench.rs" + +[[test]] +name = "updates" +path = "src/updates_bench.rs" + +[[test]] +name = "deletes" +path = "src/deletes_bench.rs" + +[[test]] +name = "pg_wasm_tests" +path = "src/pg_wasm_tests.rs" + +[[test]] +name = "pg_types_tests" +path = "src/pg_types_tests.rs" + +[[test]] +name = "pg_types_tests_2" +path = "src/pg_types_tests_2.rs" + +[[test]] +name = "pg_types_tests_3" +path = "src/pg_types_tests_3.rs" + +[[test]] +name = "pg_types_tests_4" +path = "src/pg_types_tests_4.rs" + +[[test]] +name = "pg_types_tests_uuid" +path = "src/pg_types_tests_uuid.rs" + +[[test]] +name = "pg_types_tests_ipnetwork" +path = "src/pg_types_tests_ipnetwork.rs" + +[[test]] +name = "pg_types_tests_bitvec" +path = "src/pg_types_tests_bitvec.rs" + +[[test]] +name = "pg_types_tests_chrono" +path = "src/pg_types_tests_chrono.rs" + +[[test]] +name = "pg_types_tests_time" +path = "src/pg_types_tests_time.rs" + +[[test]] +name = "pg_types_tests_json" +path = "src/pg_types_tests_json.rs" + +[[test]] +name = "pg_types_tests_bigdecimal" +path = "src/pg_types_tests_bigdecimal.rs" + +[[test]] +name = "pg_types_tests_decimal" +path = "src/pg_types_tests_decimal.rs" + +[[test]] +name = "pg_types_tests_money" +path = "src/pg_types_tests_money.rs" + +[[test]] +name = "pg_types_tests_range" +path = "src/pg_types_tests_range.rs" + +[[test]] +name = "pg_types_tests_interval" +path = "src/pg_types_tests_interval.rs" diff --git a/sqlx-wasm-test/README.md b/sqlx-wasm-test/README.md new file mode 100644 index 0000000000..cbee8d1229 --- /dev/null +++ 
b/sqlx-wasm-test/README.md @@ -0,0 +1,13 @@ +# Setup +1. Make sure postgres is installed and listening at port 5432. +2. Start a websocket-tcp proxy using [websocat](https://github.com/vi/websocat) + `$ websocat --binary ws-l:127.0.0.1:8080 tcp:127.0.0.1:5432` + +# Running +From the root folder of this crate: +1. `wasm-pack test --firefox -- --test` +2. Launch Firefox and navigate to [http://127.0.0.1:8000](http://127.0.0.1:8000) + +Corresponding native queries' benchmarking is done in `../tests/postgres/*_custom_bench.rs` files and they can be run by executing (insert benchmarking for example) - +`$ cargo test --no-default-features --features postgres,runtime-async-std-rustls --test pg-inserts-bench -- --test-threads=1 --nocapture` +from the root of this repo. diff --git a/sqlx-wasm-test/setup.sql b/sqlx-wasm-test/setup.sql new file mode 100644 index 0000000000..3d738e449b --- /dev/null +++ b/sqlx-wasm-test/setup.sql @@ -0,0 +1,3 @@ +CREATE DOMAIN month_id AS INT2 CHECK (1 <= value AND value <= 12); +CREATE TYPE year_month AS (year INT4, month month_id); +CREATE DOMAIN winter_year_month AS year_month CHECK ((value).month <= 3); diff --git a/sqlx-wasm-test/src/deletes_bench.rs b/sqlx-wasm-test/src/deletes_bench.rs new file mode 100644 index 0000000000..90aec3b9db --- /dev/null +++ b/sqlx-wasm-test/src/deletes_bench.rs @@ -0,0 +1,18 @@ +use sqlx::Executor; +use sqlx_wasm_test::time_delete_query; +use wasm_bindgen_test::*; + +#[wasm_bindgen_test] +async fn deletes_query_small() { + time_delete_query!("small", 100u32); +} + +#[wasm_bindgen_test] +async fn deletes_query_medium() { + time_delete_query!("medium", 1000u32); +} + +#[wasm_bindgen_test] +async fn deletes_query_large() { + time_delete_query!("large", 10000u32); +} diff --git a/sqlx-wasm-test/src/inserts_bench.rs b/sqlx-wasm-test/src/inserts_bench.rs new file mode 100644 index 0000000000..173880206d --- /dev/null +++ b/sqlx-wasm-test/src/inserts_bench.rs @@ -0,0 +1,18 @@ +use sqlx::Executor; +use 
sqlx_wasm_test::time_insert_query; +use wasm_bindgen_test::*; + +#[wasm_bindgen_test] +async fn insert_query_small() { + time_insert_query!("small", 100u32); +} + +#[wasm_bindgen_test] +async fn insert_query_medium() { + time_insert_query!("medium", 1000u32); +} + +#[wasm_bindgen_test] +async fn insert_query_large() { + time_insert_query!("large", 10000u32); +} diff --git a/sqlx-wasm-test/src/lib.rs b/sqlx-wasm-test/src/lib.rs new file mode 100644 index 0000000000..c06c0ad15c --- /dev/null +++ b/sqlx-wasm-test/src/lib.rs @@ -0,0 +1,286 @@ +#![feature(test)] + +extern crate test; + +wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + +use sqlx::Connection; +use sqlx::{Database, PgConnection, Postgres}; + +pub const URL: &str = "postgresql://paul:pass123@127.0.0.1:8080/jetasap_dev"; + +pub async fn new() -> PgConnection { + ::Connection::connect(URL) + .await + .unwrap() +} + +#[macro_export] +macro_rules! time_query { + ($n:expr, $q:expr) => { + let mut conn = sqlx_wasm_test::new().await; + + let performance = web_sys::window().unwrap().performance().unwrap(); + let start = performance.now(); + + for _ in 0..3u8 { + let _ = sqlx::query($q).fetch_all(&mut conn).await; + } + + let end = performance.now(); + web_sys::console::log_1(&format!("{}: Avg time is {}", $n, (end - start) / 3f64).into()); + }; +} + +#[macro_export] +macro_rules! 
time_insert_query { + ($n:expr, $count:literal) => { + let mut conn = sqlx_wasm_test::new().await; + let _ = conn + .execute("create temp table bench_inserts (id integer, descr text)") + .await; + + let performance = web_sys::window().unwrap().performance().unwrap(); + let start = performance.now(); + + for _ in 0..3u8 { + for i in 0..$count { + let _ = sqlx::query(&format!( + "insert into bench_inserts (id, desc) values ({}, md5(random()::text))", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = performance.now(); + web_sys::console::log_1(&format!("{}: Avg time is {}", $n, (end - start) / 3f64).into()); + }; +} + +#[macro_export] +macro_rules! time_update_query { + ($n:expr, $count:literal) => { + let mut conn = sqlx_wasm_test::new().await; + let _ = conn.execute("create temp table bench_updates (id integer, descr text, primary key(id))") + .await; + let _ = conn.execute("create bitmap index id_idx on bench_updates (id)") + .await; + + let _ = sqlx::query(&format!( + "insert into bench_updates (id, descr) select generate_series(1,{}) AS id, md5(random()::text) AS descr", + $count + )) + .execute(&mut conn) + .await; + + let performance = web_sys::window().unwrap().performance().unwrap(); + let start = performance.now(); + + for _ in 0..3u8 { + for i in 1..$count { + let _ = sqlx::query(&format!( + "update bench_updates set descr = md5(random()::text) where id = {}", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = performance.now(); + web_sys::console::log_1(&format!("{}: Avg time is {}", $n, (end - start) / 3f64).into()); + }; +} + +#[macro_export] +macro_rules! 
time_delete_query { + ($n:expr, $count:literal) => { + let mut conn = sqlx_wasm_test::new().await; + let _ = conn.execute("create temp table bench_deletes (id integer, descr text, primary key(id))") + .await; + + let _ = conn.execute("create bitmap index id_idx on bench_deletes (id)") + .await; + + let _ = sqlx::query(&format!( + "insert into bench_deletes (id, descr) select generate_series(1,{}) AS id, md5(random()::text) AS descr", + $count + )) + + .execute(&mut conn) + .await; + let performance = web_sys::window().unwrap().performance().unwrap(); + let start = performance.now(); + + for _ in 0..3u8 { + for i in 1..$count { + let _ = sqlx::query(&format!( + "delete from bench_deletes where id = {}", + i + )) + .execute(&mut conn) + .await; + } + } + + let end = performance.now(); + web_sys::console::log_1(&format!("{}: Avg time is {}", $n, (end - start) / 3f64).into()); + }; +} + +#[macro_export] +macro_rules! test_type { + ($name:ident<$ty:ty>($db:ident, $sql:literal, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::__test_prepared_type!($name<$ty>($db, $sql, $($text == $value),+)); + $crate::test_unprepared_type!($name<$ty>($db, $($text == $value),+)); + }; + + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! { + $crate::__test_prepared_type!($name<$ty>($db, $crate::[< $db _query_for_test_prepared_type >]!(), $($text == $value),+)); + $crate::test_unprepared_type!($name<$ty>($db, $($text == $value),+)); + } + }; + + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::test_type!($name<$name>($db, $($text == $value),+)); + }; +} + +#[macro_export] +macro_rules! 
test_decode_type { + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::__test_prepared_decode_type!($name<$ty>($db, $($text == $value),+)); + $crate::test_unprepared_type!($name<$ty>($db, $($text == $value),+)); + }; + + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::test_decode_type!($name<$name>($db, $($text == $value),+)); + }; +} + +#[macro_export] +macro_rules! test_unprepared_type { + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! { + #[wasm_bindgen_test::wasm_bindgen_test] + async fn [< test_unprepared_type_ $name >] () { + use sqlx::prelude::*; + use futures::TryStreamExt; + + let mut conn = sqlx_wasm_test::new().await; + + $( + let query = format!("SELECT {}", $text); + let mut s = conn.fetch(&*query); + let row = s.try_next().await.unwrap().unwrap(); + let rec = row.try_get::<$ty, _>(0).unwrap(); + + assert!($value == rec); + + drop(s); + )+ + } + } + } +} + +#[macro_export] +macro_rules! __test_prepared_type { + ($name:ident<$ty:ty>($db:ident, $sql:expr, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! 
{ + #[wasm_bindgen_test::wasm_bindgen_test] + async fn [< test_prepared_type_ $name >] () { + use sqlx::Row; + + let mut conn = sqlx_wasm_test::new().await; + + $( + let query = format!($sql, $text); + + let row = sqlx::query(&query) + .bind($value) + .bind($value) + .fetch_one(&mut conn) + .await.unwrap(); + + let matches: i32 = row.try_get(0).unwrap(); + let returned: $ty = row.try_get(1).unwrap(); + let round_trip: $ty = row.try_get(2).unwrap(); + + assert!(matches != 0, + "[1] DB value mismatch; given value: {:?}\n\ + as returned: {:?}\n\ + round-trip: {:?}", + $value, returned, round_trip); + + assert_eq!($value, returned, + "[2] DB value mismatch; given value: {:?}\n\ + as returned: {:?}\n\ + round-trip: {:?}", + $value, returned, round_trip); + + assert_eq!($value, round_trip, + "[3] DB value mismatch; given value: {:?}\n\ + as returned: {:?}\n\ + round-trip: {:?}", + $value, returned, round_trip); + )+ + } + } + }; +} + +// Test type encoding and decoding +#[macro_export] +macro_rules! test_prepared_type { + ($name:ident<$ty:ty>($db:ident, $sql:literal, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::__test_prepared_type!($name<$ty>($db, $sql, $($text == $value),+)); + }; + + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! { + $crate::__test_prepared_type!($name<$ty>($db, $crate::[< $db _query_for_test_prepared_type >]!(), $($text == $value),+)); + } + }; + + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::__test_prepared_type!($name<$name>($db, $($text == $value),+)); + }; +} + +// Test type decoding only for the prepared query API +#[macro_export] +macro_rules! __test_prepared_decode_type { + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! 
{ + #[wasm_bindgen_test::wasm_bindgen_test] + async fn [< test_prepared_decode_type_ $name >] () { + use sqlx::Row; + + let mut conn = sqlx_wasm_test::new().await; + + $( + let query = format!("SELECT {}", $text); + + let row = sqlx::query(&query) + .fetch_one(&mut conn) + .await.unwrap(); + + let rec: $ty = row.try_get(0).unwrap(); + + assert!($value == rec); + )+ + } + } + }; +} + +#[macro_export] +macro_rules! Postgres_query_for_test_prepared_type { + () => { + "SELECT ({0} is not distinct from $1)::int4, {0}, $2" + }; +} diff --git a/sqlx-wasm-test/src/pg_types_tests.rs b/sqlx-wasm-test/src/pg_types_tests.rs new file mode 100644 index 0000000000..c7cf06014f --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests.rs @@ -0,0 +1,41 @@ +use sqlx_wasm_test::test_type; + +test_type!(i8( + Postgres, + "0::\"char\"" == 0_i8, + "120::\"char\"" == 120_i8, +)); + +test_type!(u32(Postgres, "325235::oid" == 325235_u32,)); + +test_type!(i16( + Postgres, + "-2144::smallint" == -2144_i16, + "821::smallint" == 821_i16, +)); + +test_type!(i32( + Postgres, + "94101::int" == 94101_i32, + "-5101::int" == -5101_i32 +)); + +test_type!(i32_vec>(Postgres, + "'{5,10,50,100}'::int[]" == vec![5_i32, 10, 50, 100], + "'{1050}'::int[]" == vec![1050_i32], + "'{}'::int[]" == Vec::::new(), + "'{1,3,-5}'::int[]" == vec![1_i32, 3, -5] +)); + +test_type!(i64(Postgres, "9358295312::bigint" == 9358295312_i64)); + +test_type!(f32(Postgres, "9419.122::real" == 9419.122_f32)); + +test_type!(f64( + Postgres, + "939399419.1225182::double precision" == 939399419.1225182_f64 +)); + +test_type!(f64_vec>(Postgres, + "'{939399419.1225182,-12.0}'::float8[]" == vec![939399419.1225182_f64, -12.0] +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_2.rs b/sqlx-wasm-test/src/pg_types_tests_2.rs new file mode 100644 index 0000000000..04899da96d --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_2.rs @@ -0,0 +1,32 @@ +use sqlx_wasm_test::{test_prepared_type, test_type}; + +// BYTEA cannot be decoded by-reference 
from a simple query as postgres sends it as hex +test_prepared_type!(byte_slice<&[u8]>(Postgres, + "E'\\\\xDEADBEEF'::bytea" + == &[0xDE_u8, 0xAD, 0xBE, 0xEF][..], + "E'\\\\x0000000052'::bytea" + == &[0_u8, 0, 0, 0, 0x52][..] +)); + +test_type!(str<&str>(Postgres, + "'this is foo'" == "this is foo", + "''" == "", + "'identifier'::name" == "identifier", + "'five'::char(4)" == "five", + "'more text'::varchar" == "more text", +)); + +test_type!(string(Postgres, + "'this is foo'" == format!("this is foo"), +)); + +test_type!(string_vec>(Postgres, + "array['one','two','three']::text[]" + == vec!["one","two","three"], + + "array['', '\"']::text[]" + == vec!["", "\""], + + "array['Hello, World', '', 'Goodbye']::text[]" + == vec!["Hello, World", "", "Goodbye"] +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_3.rs b/sqlx-wasm-test/src/pg_types_tests_3.rs new file mode 100644 index 0000000000..b1f87a8c6c --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_3.rs @@ -0,0 +1,27 @@ +use sqlx_wasm_test::test_type; + +test_type!(null>(Postgres, + "NULL::int2" == None:: +)); + +test_type!(null_vec>>(Postgres, + "array[10,NULL,50]::int2[]" == vec![Some(10_i16), None, Some(50)], +)); + +test_type!(bool(Postgres, + "false::boolean" == false, + "true::boolean" == true +)); + +test_type!(bool_vec>(Postgres, + "array[true,false,true]::bool[]" == vec![true, false, true], +)); + +test_type!(byte_vec>(Postgres, + "E'\\\\xDEADBEEF'::bytea" + == vec![0xDE_u8, 0xAD, 0xBE, 0xEF], + "E'\\\\x'::bytea" + == Vec::::new(), + "E'\\\\x0000000052'::bytea" + == vec![0_u8, 0, 0, 0, 0x52] +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_4.rs b/sqlx-wasm-test/src/pg_types_tests_4.rs new file mode 100644 index 0000000000..5525ab8ed7 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_4.rs @@ -0,0 +1,18 @@ +use sqlx_wasm_test::test_decode_type; + +test_decode_type!(bool_tuple<(bool,)>(Postgres, "row(true)" == (true,))); + +test_decode_type!(num_tuple<(i32, i64, f64,)>(Postgres, 
"row(10,515::int8,3.124::float8)" == (10,515,3.124))); + +test_decode_type!(empty_tuple<()>(Postgres, "row()" == ())); + +test_decode_type!(string_tuple<(String, String, String)>(Postgres, + "row('one','two','three')" + == ("one".to_string(), "two".to_string(), "three".to_string()), + + "row('', '\"', '\"\"\"\"\"\"')" + == ("".to_string(), "\"".to_string(), "\"\"\"\"\"\"".to_string()), + + "row('Hello, World', '', 'Goodbye')" + == ("Hello, World".to_string(), "".to_string(), "Goodbye".to_string()) +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_bigdecimal.rs b/sqlx-wasm-test/src/pg_types_tests_bigdecimal.rs new file mode 100644 index 0000000000..e562733b9c --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_bigdecimal.rs @@ -0,0 +1,22 @@ +use sqlx_wasm_test::test_type; + +test_type!(bigdecimal(Postgres, + + // https://github.com/launchbadge/sqlx/issues/283 + "0::numeric" == "0".parse::().unwrap(), + + "1::numeric" == "1".parse::().unwrap(), + "10000::numeric" == "10000".parse::().unwrap(), + "0.1::numeric" == "0.1".parse::().unwrap(), + "0.01::numeric" == "0.01".parse::().unwrap(), + "0.012::numeric" == "0.012".parse::().unwrap(), + "0.0123::numeric" == "0.0123".parse::().unwrap(), + "0.01234::numeric" == "0.01234".parse::().unwrap(), + "0.012345::numeric" == "0.012345".parse::().unwrap(), + "0.0123456::numeric" == "0.0123456".parse::().unwrap(), + "0.01234567::numeric" == "0.01234567".parse::().unwrap(), + "0.012345678::numeric" == "0.012345678".parse::().unwrap(), + "0.0123456789::numeric" == "0.0123456789".parse::().unwrap(), + "12.34::numeric" == "12.34".parse::().unwrap(), + "12345.6789::numeric" == "12345.6789".parse::().unwrap(), +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_bitvec.rs b/sqlx-wasm-test/src/pg_types_tests_bitvec.rs new file mode 100644 index 0000000000..8431b0ce9c --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_bitvec.rs @@ -0,0 +1,25 @@ +use sqlx_wasm_test::test_type; + +test_type!(bitvec( + Postgres, + // A full byte VARBIT 
+ "B'01101001'" == sqlx::types::BitVec::from_bytes(&[0b0110_1001]), + // A VARBIT value missing five bits from a byte + "B'110'" == { + let mut bit_vec = sqlx::types::BitVec::with_capacity(4); + bit_vec.push(true); + bit_vec.push(true); + bit_vec.push(false); + bit_vec + }, + // A BIT value + "B'01101'::bit(5)" == { + let mut bit_vec = sqlx::types::BitVec::with_capacity(5); + bit_vec.push(false); + bit_vec.push(true); + bit_vec.push(true); + bit_vec.push(false); + bit_vec.push(true); + bit_vec + }, +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_chrono.rs b/sqlx-wasm-test/src/pg_types_tests_chrono.rs new file mode 100644 index 0000000000..81f18b6cb8 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_chrono.rs @@ -0,0 +1,54 @@ +use sqlx::types::chrono::{ + DateTime, FixedOffset, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc, +}; +use sqlx_wasm_test::test_type; + +type PgTimeTz = sqlx::postgres::types::PgTimeTz; + +test_type!(chrono_date(Postgres, + "DATE '2001-01-05'" == NaiveDate::from_ymd(2001, 1, 5), + "DATE '2050-11-23'" == NaiveDate::from_ymd(2050, 11, 23) +)); + +test_type!(chrono_time(Postgres, + "TIME '05:10:20.115100'" == NaiveTime::from_hms_micro(5, 10, 20, 115100) +)); + +test_type!(chrono_date_time(Postgres, + "'2019-01-02 05:10:20'::timestamp" == NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20) +)); + +test_type!(chrono_date_time_vec>(Postgres, + "array['2019-01-02 05:10:20']::timestamp[]" + == vec![NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20)] +)); + +test_type!(chrono_date_time_tz_utc>(Postgres, + "TIMESTAMPTZ '2019-01-02 05:10:20.115100'" + == DateTime::::from_utc( + NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), + Utc, + ) +)); + +test_type!(chrono_date_time_tz>(Postgres, + "TIMESTAMPTZ '2019-01-02 05:10:20.115100+06:30'" + == FixedOffset::east(60 * 60 * 6 + 1800).ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100) +)); + +test_type!(chrono_date_time_tz_vec>>(Postgres, + "array['2019-01-02 
05:10:20.115100']::timestamptz[]" + == vec![ + DateTime::::from_utc( + NaiveDate::from_ymd(2019, 1, 2).and_hms_micro(5, 10, 20, 115100), + Utc, + ) + ] +)); + +test_type!(chrono_time_tz(Postgres, + "TIMETZ '05:10:20.115100+00'" == PgTimeTz { time: NaiveTime::from_hms_micro(5, 10, 20, 115100), offset: FixedOffset::east(0) }, + "TIMETZ '05:10:20.115100+06:30'" == PgTimeTz { time: NaiveTime::from_hms_micro(5, 10, 20, 115100), offset: FixedOffset::east(60 * 60 * 6 + 1800) }, + "TIMETZ '05:10:20.115100-05'" == PgTimeTz { time: NaiveTime::from_hms_micro(5, 10, 20, 115100), offset: FixedOffset::west(60 * 60 * 5) }, + "TIMETZ '05:10:20+02'" == PgTimeTz { time: NaiveTime::from_hms(5, 10, 20), offset: FixedOffset::east(60 * 60 * 2 )} +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_decimal.rs b/sqlx-wasm-test/src/pg_types_tests_decimal.rs new file mode 100644 index 0000000000..6534b94753 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_decimal.rs @@ -0,0 +1,13 @@ +use sqlx_wasm_test::test_type; + +use std::str::FromStr; + +test_type!(decimal(Postgres, + "0::numeric" == sqlx::types::Decimal::from_str("0").unwrap(), + "1::numeric" == sqlx::types::Decimal::from_str("1").unwrap(), + "10000::numeric" == sqlx::types::Decimal::from_str("10000").unwrap(), + "0.1::numeric" == sqlx::types::Decimal::from_str("0.1").unwrap(), + "0.01234::numeric" == sqlx::types::Decimal::from_str("0.01234").unwrap(), + "12.34::numeric" == sqlx::types::Decimal::from_str("12.34").unwrap(), + "12345.6789::numeric" == sqlx::types::Decimal::from_str("12345.6789").unwrap(), +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_interval.rs b/sqlx-wasm-test/src/pg_types_tests_interval.rs new file mode 100644 index 0000000000..fa2c4b11f2 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_interval.rs @@ -0,0 +1,30 @@ +use sqlx::postgres::types::PgInterval; +use sqlx_wasm_test::test_prepared_type; + +test_prepared_type!(interval( + Postgres, + "INTERVAL '1h'" + == PgInterval { + months: 0, + days: 0, + 
microseconds: 3_600_000_000 + }, + "INTERVAL '-1 hours'" + == PgInterval { + months: 0, + days: 0, + microseconds: -3_600_000_000 + }, + "INTERVAL '3 months 12 days 1h 15 minutes 10 second '" + == PgInterval { + months: 3, + days: 12, + microseconds: (3_600 + 15 * 60 + 10) * 1_000_000 + }, + "INTERVAL '03:10:20.116100'" + == PgInterval { + months: 0, + days: 0, + microseconds: (3 * 3_600 + 10 * 60 + 20) * 1_000_000 + 116100 + }, +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_ipnetwork.rs b/sqlx-wasm-test/src/pg_types_tests_ipnetwork.rs new file mode 100644 index 0000000000..71e18bff8b --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_ipnetwork.rs @@ -0,0 +1,52 @@ +use sqlx_wasm_test::test_unprepared_type; + +macro_rules! test_type { + ($name:ident<$ty:ty>($db:ident, $sql:literal, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::test_unprepared_type!($name<$ty>($db, $($text == $value),+)); + }; + + ($name:ident<$ty:ty>($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + paste::item! 
{ + $crate::test_unprepared_type!($name<$ty>($db, $($text == $value),+)); + } + }; + + ($name:ident($db:ident, $($text:literal == $value:expr),+ $(,)?)) => { + $crate::test_type!($name<$name>($db, $($text == $value),+)); + }; +} + +test_type!(ipnetwork(Postgres, + "'127.0.0.1'::inet" + == "127.0.0.1" + .parse::() + .unwrap(), + "'8.8.8.8/24'::inet" + == "8.8.8.8/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0'::inet" + == "::ffff:1.2.3.0" + .parse::() + .unwrap(), + "'2001:4f8:3:ba::/64'::inet" + == "2001:4f8:3:ba::/64" + .parse::() + .unwrap(), + "'192.168'::cidr" + == "192.168.0.0/24" + .parse::() + .unwrap(), + "'::ffff:1.2.3.0/120'::cidr" + == "::ffff:1.2.3.0/120" + .parse::() + .unwrap(), +)); + +test_type!(ipnetwork_vec>(Postgres, + "'{127.0.0.1,8.8.8.8/24}'::inet[]" + == vec![ + "127.0.0.1".parse::().unwrap(), + "8.8.8.8/24".parse::().unwrap() + ] +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_json.rs b/sqlx-wasm-test/src/pg_types_tests_json.rs new file mode 100644 index 0000000000..8ff6c26231 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_json.rs @@ -0,0 +1,80 @@ +use serde_json::value::RawValue as JsonRawValue; +use serde_json::{json, Value as JsonValue}; +use sqlx::postgres::PgRow; +use sqlx::types::Json; +use sqlx::{Executor, Row}; +use sqlx_wasm_test::test_type; + +// When testing JSON, coerce to JSONB for `=` comparison as `JSON = JSON` is not +// supported in PostgreSQL + +test_type!(json( + Postgres, + "SELECT ({0}::jsonb is not distinct from $1::jsonb)::int4, {0} as _2, $2 as _3", + "'\"Hello, World\"'::json" == json!("Hello, World"), + "'\"😎\"'::json" == json!("😎"), + "'\"🙋‍♀️\"'::json" == json!("🙋‍♀️"), + "'[\"Hello\", \"World!\"]'::json" == json!(["Hello", "World!"]) +)); + +test_type!(json_array>( + Postgres, + "SELECT ({0}::jsonb[] is not distinct from $1::jsonb[])::int4, {0} as _2, $2 as _3", + "array['\"😎\"'::json, '\"🙋‍♀️\"'::json]::json[]" == vec![json!("😎"), json!("🙋‍♀️")], +)); + +test_type!(jsonb( + Postgres, + "'\"Hello, 
World\"'::jsonb" == json!("Hello, World"), + "'\"😎\"'::jsonb" == json!("😎"), + "'\"🙋‍♀️\"'::jsonb" == json!("🙋‍♀️"), + "'[\"Hello\", \"World!\"]'::jsonb" == json!(["Hello", "World!"]) +)); + +test_type!(jsonb_array>( + Postgres, + "array['\"😎\"'::jsonb, '\"🙋‍♀️\"'::jsonb]::jsonb[]" == vec![json!("😎"), json!("🙋‍♀️")], +)); + +#[derive(serde::Deserialize, serde::Serialize, Debug, PartialEq)] +struct Friend { + name: String, + age: u32, +} + +test_type!(json_struct>(Postgres, + "'{\"name\":\"Joe\",\"age\":33}'::jsonb" == Json(Friend { name: "Joe".to_string(), age: 33 }) +)); + +test_type!(json_struct_vec>>(Postgres, + "array['{\"name\":\"Joe\",\"age\":33}','{\"name\":\"Bob\",\"age\":22}']::jsonb[]" + == vec![ + Json(Friend { name: "Joe".to_string(), age: 33 }), + Json(Friend { name: "Bob".to_string(), age: 22 }), + ] +)); + +#[wasm_bindgen_test::wasm_bindgen_test] +async fn test_json_raw_value() { + let mut conn = sqlx_wasm_test::new().await; + + // unprepared, text API + let row: PgRow = conn + .fetch_one("SELECT '{\"hello\": \"world\"}'::jsonb") + .await + .unwrap(); + + let value: &JsonRawValue = row.try_get(0).unwrap(); + + assert_eq!(value.get(), "{\"hello\": \"world\"}"); + + // prepared, binary API + let row: PgRow = conn + .fetch_one(sqlx::query("SELECT '{\"hello\": \"world\"}'::jsonb")) + .await + .unwrap(); + + let value: &JsonRawValue = row.try_get(0).unwrap(); + + assert_eq!(value.get(), "{\"hello\": \"world\"}"); +} diff --git a/sqlx-wasm-test/src/pg_types_tests_money.rs b/sqlx-wasm-test/src/pg_types_tests_money.rs new file mode 100644 index 0000000000..a5dc1f7144 --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_money.rs @@ -0,0 +1,9 @@ +use sqlx_wasm_test::test_prepared_type; + +use sqlx::postgres::types::PgMoney; + +test_prepared_type!(money(Postgres, "123.45::money" == PgMoney(12345))); + +test_prepared_type!(money_vec>(Postgres, + "array[123.45,420.00,666.66]::money[]" == vec![PgMoney(12345), PgMoney(42000), PgMoney(66666)], +)); diff --git 
a/sqlx-wasm-test/src/pg_types_tests_range.rs b/sqlx-wasm-test/src/pg_types_tests_range.rs new file mode 100644 index 0000000000..120c76529a --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_range.rs @@ -0,0 +1,27 @@ +use sqlx::postgres::types::PgRange; +use sqlx_wasm_test::test_type; +use std::ops::Bound; + +const EXC2: Bound = Bound::Excluded(2); +const EXC3: Bound = Bound::Excluded(3); +const INC1: Bound = Bound::Included(1); +const INC2: Bound = Bound::Included(2); +const UNB: Bound = Bound::Unbounded; + +test_type!(int4range>(Postgres, + "'(,)'::int4range" == PgRange::from((UNB, UNB)), + "'(,]'::int4range" == PgRange::from((UNB, UNB)), + "'(,2)'::int4range" == PgRange::from((UNB, EXC2)), + "'(,2]'::int4range" == PgRange::from((UNB, EXC3)), + "'(1,)'::int4range" == PgRange::from((INC2, UNB)), + "'(1,]'::int4range" == PgRange::from((INC2, UNB)), + "'(1,2]'::int4range" == PgRange::from((INC2, EXC3)), + "'[,)'::int4range" == PgRange::from((UNB, UNB)), + "'[,]'::int4range" == PgRange::from((UNB, UNB)), + "'[,2)'::int4range" == PgRange::from((UNB, EXC2)), + "'[,2]'::int4range" == PgRange::from((UNB, EXC3)), + "'[1,)'::int4range" == PgRange::from((INC1, UNB)), + "'[1,]'::int4range" == PgRange::from((INC1, UNB)), + "'[1,2)'::int4range" == PgRange::from((INC1, EXC2)), + "'[1,2]'::int4range" == PgRange::from((INC1, EXC3)), +)); diff --git a/sqlx-wasm-test/src/pg_types_tests_time.rs b/sqlx-wasm-test/src/pg_types_tests_time.rs new file mode 100644 index 0000000000..7f1f94f34b --- /dev/null +++ b/sqlx-wasm-test/src/pg_types_tests_time.rs @@ -0,0 +1,39 @@ +use sqlx_wasm_test::{test_prepared_type, test_type}; + +use sqlx::types::time::{Date, OffsetDateTime, PrimitiveDateTime, Time, UtcOffset}; +use time::{date, time}; + +type PgTimeTz = sqlx::postgres::types::PgTimeTz; + +test_type!(time_date( + Postgres, + "DATE '2001-01-05'" == date!(2001 - 1 - 5), + "DATE '2050-11-23'" == date!(2050 - 11 - 23) +)); + +test_type!(time_time