diff --git a/Cargo.lock b/Cargo.lock
index 6e60dc34..9ebb7b8a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -209,7 +209,7 @@ checksum = "3109e49b1e4909e9db6515a30c633684d68cdeaa252f215214cb4fa1a5bfee2c"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
  "synstructure",
 ]
 
@@ -221,7 +221,7 @@ checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -241,7 +241,7 @@ checksum = "e539d3fca749fcee5236ab05e93a52867dd549cc157c8cb7f99595f3cedffdb5"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -434,7 +434,7 @@ dependencies = [
  "regex",
  "rustc-hash",
  "shlex",
- "syn 2.0.100",
+ "syn 2.0.101",
  "which",
 ]
 
@@ -527,7 +527,7 @@ dependencies = [
  "proc-macro-crate",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -650,9 +650,9 @@ dependencies = [
 
 [[package]]
 name = "cipherstash-client"
-version = "0.22.2"
+version = "0.23.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0168885dfe967dc3884a51e415ce15c82343b7fd8be36b08d39f25f4fc727064"
+checksum = "3165e25e8e2cd848f80ed24f127e0e4c869860c9840dc1a723951b0bcd24b655"
 dependencies = [
  "aes-gcm-siv",
  "anyhow",
@@ -786,6 +786,7 @@ dependencies = [
 name = "cipherstash-proxy-integration"
 version = "0.1.0"
 dependencies = [
+ "bytes",
  "chrono",
  "cipherstash-client",
  "cipherstash-config",
@@ -853,7 +854,7 @@ dependencies = [
  "heck",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -931,6 +932,19 @@ dependencies = [
  "winnow 0.7.4",
 ]
 
+[[package]]
+name = "const-hex"
+version = "1.14.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "83e22e0ed40b96a48d3db274f72fd365bd78f67af39b6bbd47e8a15e1c6207ff"
+dependencies = [
+ "cfg-if",
+ "cpufeatures",
+ "hex",
+ "proptest",
+ "serde",
+]
+
 [[package]]
 name = "const-oid"
 version = "0.9.6"
@@ -1065,7 +1079,7 @@ dependencies = [
  "proc-macro2",
  "quote",
  "strsim",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1076,7 +1090,7 @@ checksum = "d336a2a514f6ccccaa3e09b02d41d35330c07ddf03a62165fcec10bb561c7806"
 dependencies = [
  "darling_core",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1128,7 +1142,7 @@ checksum = "8034092389675178f570469e6c3b0465d3d30b4505c294a6550db47f3c17ad18"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1157,7 +1171,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
  "unicode-xid",
 ]
 
@@ -1188,7 +1202,7 @@ dependencies = [
  "dsl_auto_type",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1197,7 +1211,7 @@ version = "0.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25"
 dependencies = [
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1245,7 +1259,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
 dependencies = [
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1259,7 +1273,7 @@ dependencies = [
  "heck",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1271,7 +1285,7 @@ dependencies = [
  "darling",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
+ "syn 2.0.101",
 ]
 
 [[package]]
@@ -1283,7 +1297,7 @@ dependencies = [
  "darling",
  "proc-macro2",
  "quote",
- "syn 2.0.100",
2.0.100", + "syn 2.0.101", ] [[package]] @@ -1295,7 +1309,7 @@ dependencies = [ "darling", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1321,6 +1335,7 @@ dependencies = [ "pretty_assertions", "sqltk", "thiserror 2.0.12", + "topological-sort", "tracing", "tracing-subscriber", "vec1", @@ -1330,9 +1345,10 @@ dependencies = [ name = "eql-mapper-macros" version = "2.0.0" dependencies = [ + "pretty_assertions", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1520,7 +1536,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1947,7 +1963,7 @@ checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -1985,7 +2001,7 @@ checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2295,7 +2311,7 @@ checksum = "bf45bf44ab49be92fd1227a3be6fc6f617f1a337c06af54981048574d8783147" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2518,7 +2534,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2686,7 +2702,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2743,7 +2759,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2812,7 +2828,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" dependencies = [ "proc-macro2", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -2826,13 +2842,29 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.94" +version = "1.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31971752e70b8b2686d7e46ec17fb38dad4051d94024c88df49b667caea9c84" +checksum = "02b3e5e68a3a1a02aad3ec490a98007cbc13c37cbe84a3cd7b8e406d76e7f778" dependencies = [ "unicode-ident", ] +[[package]] +name = "proptest" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fcdab19deb5195a31cf7726a210015ff1496ba1464fd42cb4f537b8b01b471f" +dependencies = [ + "bitflags 2.9.0", + "lazy_static", + "num-traits", + "rand 0.9.0", + "rand_chacha 0.9.0", + "rand_xorshift", + "regex-syntax 0.8.5", + "unarray", +] + [[package]] name = "psm" version = "0.1.26" @@ -2958,6 +2990,15 @@ dependencies = [ "getrandom 0.3.2", ] +[[package]] +name = "rand_xorshift" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "513962919efc330f829edb2535844d1b912b0fbe2ca165d613e4e8788bb05a5a" +dependencies = [ + "rand_core 0.9.3", +] + [[package]] name = "rand_xoshiro" version = "0.6.0" @@ -3014,7 +3055,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76009fbe0614077fc1a2ce255e3a1881a2e3a3527097d5dc6d8212c585e7e38b" dependencies = [ "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3521,7 +3562,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3781,9 +3822,9 @@ dependencies = [ 
[[package]] name = "syn" -version = "2.0.100" +version = "2.0.101" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "8ce2b7fc941b3a24138a0a7cf8e858bfc6a992e7978a068a5c760deb0ed43caf" dependencies = [ "proc-macro2", "quote", @@ -3807,7 +3848,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3884,7 +3925,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3895,7 +3936,7 @@ checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -3982,7 +4023,7 @@ checksum = "2d2e76690929402faae40aebdda620a2c0e25dd6d3b9afe48867dfd95991f4bd" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4011,7 +4052,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4124,6 +4165,12 @@ dependencies = [ "winnow 0.7.4", ] +[[package]] +name = "topological-sort" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d" + [[package]] name = "tower" version = "0.5.2" @@ -4172,7 +4219,7 @@ checksum = "395ae124c09f9e6918a2310af6038fba074bcf474ac352496d5910dd59a2226d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -4239,6 +4286,12 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +[[package]] +name = "unarray" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eaea85b334db583fe3274d12b4cd1880032beab409c0d774be044d4480ab9a94" + [[package]] name = "unicode-bidi" version = "0.3.18" @@ -4479,7 +4532,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "wasm-bindgen-shared", ] @@ -4514,7 +4567,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -4999,7 +5052,7 @@ checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "synstructure", ] @@ -5029,7 +5082,7 @@ checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5040,7 +5093,7 @@ checksum = "a996a8f63c5c4448cd959ac1bab0aaa3306ccfd060472f85943ee0750f0169be" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] @@ -5060,7 +5113,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", "synstructure", ] @@ -5081,18 +5134,19 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] [[package]] name = "zerokms-protocol" -version = 
"0.4.0" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01a9d0d8103cfa862b451f2c35144301df25a233f7fae041666b890a1578c3b1" +checksum = "af31358bcf35336b9990ce1d6f671f66cca00385fb21f55db155e38b7ec666cd" dependencies = [ "async-trait", "base64", "cipherstash-config", + "const-hex", "fake 2.10.0", "opaque-debug", "rand 0.8.5", @@ -5122,5 +5176,5 @@ checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.100", + "syn 2.0.101", ] diff --git a/Cargo.toml b/Cargo.toml index 693dd175..41389427 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,13 @@ edition = "2021" [profile.dev] incremental = true debug = true +opt-level = 0 +split-debuginfo = "unpacked" # or "unpacked" on macOS + +[profile.dev.package.sqltk] +opt-level = 0 +debug = true +split-debuginfo = "unpacked" # or "unpacked" on macOS # [profile.dev.package]# aws-lc-sys.opt-level = 3 # proc-macro2.opt-level = 3 @@ -17,8 +24,8 @@ debug = true # sqlparser.opt-level = 3 # syn.opt-level = 3 -[profile.dev.build-override] -opt-level = 3 +# [profile.dev.build-override] +# opt-level = 3 [profile.test] incremental = true @@ -36,6 +43,7 @@ debug = true [workspace.dependencies] sqltk = { version = "0.10.0" } +cipherstash-client = "0.23.0" thiserror = "2.0.9" tokio = { version = "1.44.2", features = ["full"] } tracing = "0.1" diff --git a/README.md b/README.md index 2b32a336..d0759b94 100644 --- a/README.md +++ b/README.md @@ -29,17 +29,19 @@ [Read the announcement](https://cipherstash.com/blog/introducing-proxy) +CipherStash Proxy provides transparent, *searchable* encryption for your existing Postgres database. -CipherStash Proxy provides a transparent proxy to your existing Postgres database. +CipherStash Proxy: +* Automatically encrypts and decrypts data with zero changes to SQL +* Supports queries over *encrypted* values: + - equality + - comparison + - ordering + - grouping +* Is written in Rust for high performance and strongly-typed mapping of SQL statements. +* Manages keys using CipherStash ZeroKMS, offering up to 14x the performance of AWS KMS -Proxy: -* Automatically encrypts and decrypts the columns you specify -* Supports most query types over encrypted values -* Runs in a Docker container -* Is written in Rust and uses a formal type system for SQL mapping -* Works with CipherStash ZeroKMS and offers up to 14x the performance of AWS KMS - -Behind the scenes, it uses the [Encrypt Query Language](https://github.com/cipherstash/encrypt-query-language/) to index and search encrypted data. +Behind the scenes, CipherStash Proxy uses the [Encrypt Query Language](https://github.com/cipherstash/encrypt-query-language/) to index and search encrypted data. ## Table of contents @@ -54,7 +56,7 @@ Behind the scenes, it uses the [Encrypt Query Language](https://github.com/ciphe > [!IMPORTANT] > **Prerequisites:** Before you start you need to have this software installed: > - [Docker](https://www.docker.com/) — see Docker's [documentation for installing](https://docs.docker.com/get-started/get-docker/) - + Get up and running in local dev in < 5 minutes: diff --git a/mise.toml b/mise.toml index 191cb746..809b6207 100644 --- a/mise.toml +++ b/mise.toml @@ -31,7 +31,7 @@ CS_PROXY__HOST = "proxy" # Misc DOCKER_CLI_HINTS = "false" # Please don't show us What's Next. 
diff --git a/mise.toml b/mise.toml
index 191cb746..809b6207 100644
--- a/mise.toml
+++ b/mise.toml
@@ -31,7 +31,7 @@ CS_PROXY__HOST = "proxy"
 
 # Misc
 DOCKER_CLI_HINTS = "false" # Please don't show us What's Next.
-CS_EQL_VERSION = "eql-2.0.4"
+CS_EQL_VERSION = "eql-2.0.6"
 
 [tools]
 "cargo:cargo-binstall" = "latest"
diff --git a/packages/cipherstash-proxy-integration/Cargo.toml b/packages/cipherstash-proxy-integration/Cargo.toml
index ea0a4533..7cde1e62 100644
--- a/packages/cipherstash-proxy-integration/Cargo.toml
+++ b/packages/cipherstash-proxy-integration/Cargo.toml
@@ -4,13 +4,22 @@ version = "0.1.0"
 edition = "2021"
 
 [dependencies]
+bytes = "1.10.1"
+cipherstash-client = { workspace = true, features = ["tokio"] }
+cipherstash-config = "0.2.3"
 cipherstash-proxy = { path = "../cipherstash-proxy/" }
 chrono = { version = "0.4.39", features = ["clock"] }
+clap = "4.5.32"
+fake = { version = "4", features = ["chrono", "derive"] }
+hex = "0.4.3"
+postgres-types = { version = "0.2.9", features = ["derive"] }
 rand = "0.9"
 recipher = "0.1.3"
 rustls = { version = "0.23.20", default-features = false, features = ["std"] }
 serde = "1.0"
 serde_json = "1.0"
+tap = "1.0.1"
 temp-env = "0.3.6"
 tokio = { workspace = true }
 tokio-postgres = { version = "0.7", features = [
@@ -21,14 +30,5 @@ tokio-postgres-rustls = "0.13.0"
 tokio-rustls = "0.26.0"
 tracing = { workspace = true }
 tracing-subscriber = { workspace = true }
-webpki-roots = "1.0"
-
-[dev-dependencies]
-cipherstash-client = { version = "0.22.0", features = ["tokio"] }
-cipherstash-config = "0.2.3"
-clap = "4.5.32"
-fake = { version = "4", features = ["chrono", "derive"] }
-hex = "0.4.3"
-postgres-types = { version = "0.2.9", features = ["derive"] }
-tap = "1.0.1"
 uuid = { version = "1.11.0", features = ["serde", "v4"] }
+webpki-roots = "1.0"
diff --git a/packages/cipherstash-proxy-integration/src/common.rs b/packages/cipherstash-proxy-integration/src/common.rs
index 61402594..3eab32aa 100644
--- a/packages/cipherstash-proxy-integration/src/common.rs
+++ b/packages/cipherstash-proxy-integration/src/common.rs
@@ -5,6 +5,7 @@ use rustls::{
     client::danger::ServerCertVerifier, crypto::aws_lc_rs::default_provider,
     pki_types::CertificateDer, ClientConfig,
 };
+use serde_json::Value;
 use std::sync::{Arc, Once};
 use tokio_postgres::{types::ToSql, Client, NoTls};
 use tracing_subscriber::{filter::Directive, EnvFilter, FmtSubscriber};
@@ -105,7 +106,7 @@ pub async fn connect_with_tls(port: u16) -> Client {
 
     tokio::spawn(async move {
         if let Err(e) = connection.await {
-            eprintln!("connection error: {}", e);
+            eprintln!("connection error: {e}");
         }
     });
 
     client
@@ -117,7 +118,7 @@ pub async fn connect(port: u16) -> Client {
 
     tokio::spawn(async move {
         if let Err(e) = connection.await {
-            eprintln!("connection error: {}", e);
+            eprintln!("connection error: {e}");
         }
     });
 
@@ -175,7 +176,18 @@ where
     rows.iter()
         .filter_map(|row| {
             if let tokio_postgres::SimpleQueryMessage::Row(r) = row {
-                r.get(0).and_then(|val| val.parse::<T>().ok())
+                r.get(0).and_then(|val| {
+                    // Simple-query results arrive as text, and PostgreSQL renders
+                    // booleans as "t" / "f", which str::parse::<bool> rejects.
+                    // Map them to "true" / "false" before parsing into T.
+                    match val {
+                        "t" => "true".parse::<T>().ok(),
+                        "f" => "false".parse::<T>().ok(),
+                        _ => val.parse::<T>().ok(),
+                    }
+                })
             } else {
                 None
             }
@@ -199,6 +211,33 @@ pub async fn simple_query_with_null(sql: &str) -> Vec<Option<String>> {
         .collect()
}
 
+pub async fn insert(sql: &str, params: &[&(dyn ToSql + Sync)]) {
+    let client = connect_with_tls(PROXY).await;
+    client.query(sql, params).await.unwrap();
+}
+
+pub async fn insert_jsonb() -> Value {
+    let id = random_id();
+
+    let encrypted_jsonb = serde_json::json!({
+        "id": id,
+        "string": "hello",
+        "number": 42,
+        "nested": {
+            "number": 1815,
+            "string": "world",
+        },
+        "array_string": ["hello", "world"],
+        "array_number": [42, 84],
+    });
+
+    let sql = "INSERT INTO encrypted (id, encrypted_jsonb) VALUES ($1, $2)";
+
+    insert(sql, &[&id, &encrypted_jsonb]).await;
+
+    encrypted_jsonb
+}
+
 ///
 /// Configure the client TLS settings.
 /// These are the settings for connecting to the database with TLS.
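The "t"/"f" rewrite added to `simple_query` above is easy to trip over, so here is a hedged sketch of the underlying behaviour. It is not part of the diff; it assumes a `tokio_postgres::Client` such as the one returned by this file's `connect_with_tls` helper.

```rust
use tokio_postgres::{Client, SimpleQueryMessage};

// Why simple-query results need the "t"/"f" mapping: the simple protocol returns
// every column as text, and Postgres renders booleans as "t"/"f".
async fn simple_query_bool(client: &Client) -> Result<(), tokio_postgres::Error> {
    for message in client.simple_query("SELECT true").await? {
        if let SimpleQueryMessage::Row(row) = message {
            assert_eq!(row.get(0), Some("t")); // not "true"
            assert!("t".parse::<bool>().is_err()); // str::parse::<bool> accepts only "true"/"false"
        }
    }
    Ok(())
}
```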
diff --git a/packages/cipherstash-proxy-integration/src/insert/insert_with_literal.rs b/packages/cipherstash-proxy-integration/src/insert/insert_with_literal.rs
index a78b3f51..e681f714 100644
--- a/packages/cipherstash-proxy-integration/src/insert/insert_with_literal.rs
+++ b/packages/cipherstash-proxy-integration/src/insert/insert_with_literal.rs
@@ -9,6 +9,10 @@ mod tests {
 
     macro_rules! test_insert_with_literal {
         ($name: ident, $type: ident, $pg_type: ident) => {
+            test_insert_with_literal!($name, $type, $pg_type, false);
+        };
+
+        ($name: ident, $type: ident, $pg_type: ident, $cast: expr) => {
             #[tokio::test]
             pub async fn $name() {
                 trace();
@@ -22,8 +26,14 @@ mod tests {
 
                 let expected = vec![encrypted_val.clone()];
 
-                let insert_sql = format!("INSERT INTO encrypted (id, {encrypted_col}) VALUES ($1, '{encrypted_val}')");
-                let select_sql = format!("SELECT {encrypted_col} FROM encrypted WHERE id = $1");
+                let cast_to_type: &str = if $cast {
+                    &format!("::{}", stringify!($pg_type))
+                } else {
+                    ""
+                };
+
+                let insert_sql = format!("INSERT INTO encrypted (id, {encrypted_col}) VALUES ($1, '{encrypted_val}'{cast_to_type})");
+                let select_sql = format!("SELECT {encrypted_col}{cast_to_type} FROM encrypted WHERE id = $1");
 
                 execute_query(&insert_sql, &[&id]).await;
                 let actual = query_by::<$type>(&select_sql, &id).await;
 
                 assert_eq!(expected, actual);
             }
         };
     }
 
@@ -36,6 +46,10 @@ mod tests {
     macro_rules! test_insert_simple_query_with_literal {
         ($name: ident, $type: ident, $pg_type: ident) => {
+            test_insert_simple_query_with_literal!($name, $type, $pg_type, false);
+        };
+
+        ($name: ident, $type: ident, $pg_type: ident, $cast: expr) => {
             #[tokio::test]
             pub async fn $name() {
                 trace();
@@ -48,8 +62,14 @@ mod tests {
                 let encrypted_col = format!("encrypted_{}", stringify!($pg_type));
                 let encrypted_val = crate::value_for_type!($type, random_limited());
 
-                let insert_sql = format!("INSERT INTO encrypted (id, {encrypted_col}) VALUES ({id}, '{encrypted_val}')");
-                let select_sql = format!("SELECT {encrypted_col} FROM encrypted WHERE id = {id}");
+                let cast_to_type: &str = if $cast {
+                    &format!("::{}", stringify!($pg_type))
+                } else {
+                    ""
+                };
+
+                let insert_sql = format!("INSERT INTO encrypted (id, {encrypted_col}) VALUES ({id}, '{encrypted_val}'{cast_to_type})");
+                let select_sql = format!("SELECT {encrypted_col}{cast_to_type} FROM encrypted WHERE id = {id}");
 
                 let expected = vec![encrypted_val];
 
@@ -69,7 +89,7 @@ mod tests {
     test_insert_with_literal!(insert_with_literal_bool, bool, bool);
     test_insert_with_literal!(insert_with_literal_text, String, text);
     test_insert_with_literal!(insert_with_literal_date, NaiveDate, date);
-    test_insert_with_literal!(insert_with_literal_jsonb, Value, jsonb);
+    test_insert_with_literal!(insert_with_literal_jsonb, Value, jsonb, true);
 
     test_insert_simple_query_with_literal!(insert_simple_query_with_literal_int2, i16, int2);
     test_insert_simple_query_with_literal!(insert_simple_query_with_literal_int4, i32, int4);
@@ -78,7 +98,12 @@ mod tests {
     test_insert_simple_query_with_literal!(insert_simple_query_with_literal_bool, bool, bool);
     test_insert_simple_query_with_literal!(insert_simple_query_with_literal_text, String, text);
     test_insert_simple_query_with_literal!(insert_simple_query_with_literal_date, NaiveDate, date);
-    test_insert_simple_query_with_literal!(insert_simple_query_with_literal_jsonb, Value, jsonb);
+    test_insert_simple_query_with_literal!(
+        insert_simple_query_with_literal_jsonb,
+        Value,
+        jsonb,
+        true
+    );
 
     // -----------------------------------------------------------------
diff --git a/packages/cipherstash-proxy-integration/src/lib.rs b/packages/cipherstash-proxy-integration/src/lib.rs
index f0f1c612..83521e5a 100644
--- a/packages/cipherstash-proxy-integration/src/lib.rs
+++ b/packages/cipherstash-proxy-integration/src/lib.rs
@@ -18,6 +18,7 @@ mod pipeline;
 mod schema_change;
 mod select;
 mod simple_protocol;
+mod support;
 
 #[macro_export]
 macro_rules! value_for_type {
value_for_type { diff --git a/packages/cipherstash-proxy-integration/src/map_literals.rs b/packages/cipherstash-proxy-integration/src/map_literals.rs index 53610a75..24e49526 100644 --- a/packages/cipherstash-proxy-integration/src/map_literals.rs +++ b/packages/cipherstash-proxy-integration/src/map_literals.rs @@ -55,12 +55,12 @@ mod tests { let encrypted_jsonb = serde_json::json!({"key": "value"}); let sql = format!( - "INSERT INTO encrypted (id, encrypted_jsonb) VALUES ($1, '{encrypted_jsonb}')", + "INSERT INTO encrypted (id, encrypted_jsonb) VALUES ($1, '{encrypted_jsonb}'::jsonb)", ); client.query(&sql, &[&id]).await.unwrap(); - let sql = "SELECT id, encrypted_jsonb FROM encrypted WHERE id = $1"; + let sql = "SELECT id, encrypted_jsonb::jsonb FROM encrypted WHERE id = $1"; let rows = client.query(sql, &[&id]).await.unwrap(); assert_eq!(rows.len(), 1); diff --git a/packages/cipherstash-proxy-integration/src/migrate/mod.rs b/packages/cipherstash-proxy-integration/src/migrate/mod.rs index 41fdcdc9..d120300f 100644 --- a/packages/cipherstash-proxy-integration/src/migrate/mod.rs +++ b/packages/cipherstash-proxy-integration/src/migrate/mod.rs @@ -47,7 +47,7 @@ mod tests { let config = match TandemConfig::load(&args) { Ok(config) => config, Err(err) => { - eprintln!("Configuration Error: {}", err); + eprintln!("Configuration Error: {err}"); panic!(); } }; diff --git a/packages/cipherstash-proxy-integration/src/select/indexing.rs b/packages/cipherstash-proxy-integration/src/select/indexing.rs new file mode 100644 index 00000000..2b1c23ed --- /dev/null +++ b/packages/cipherstash-proxy-integration/src/select/indexing.rs @@ -0,0 +1,54 @@ +#[cfg(test)] +mod tests { + use crate::common::{ + connect_with_tls, insert, query_by, random_id, simple_query, trace, PROXY, + }; + use tokio_postgres::types::{FromSql, ToSql}; + use tracing::info; + + #[derive(Debug, ToSql, FromSql, PartialEq)] + #[postgres(name = "domain_type_with_check")] + pub struct Domain(String); + + /// + /// Tests insertion of custom domain type + /// + #[tokio::test] + async fn select_with_index() { + trace(); + + // let id = random_id(); + // let encrypted_val = Domain("ZZ".to_string()); + + // CREATE INDEX ON encrypted (e eql_v2.encrypted_operator_class); + // SELECT ore.e FROM ore WHERE id = 42 INTO ore_term; + + for n in 1..=10 { + let id = random_id(); + + let encrypted_text = format!("hello_{}", n); + + let sql = format!("INSERT INTO encrypted (id, encrypted_text) VALUES ($1, $2)"); + insert(&sql, &[&id, &encrypted_text]).await; + } + + let client = connect_with_tls(PROXY).await; + + let sql = "CREATE INDEX ON encrypted (encrypted_text eql_v2.encrypted_operator_class)"; + let _ = client.simple_query(sql).await; + + // let sql = + // "EXPLAIN ANALYZE SELECT encrypted_text FROM encrypted WHERE encrypted_text <= '{\"hm\": \"abc\"}'::jsonb::eql_v2_encrypted"; + // let result = simple_query::(sql).await; + + let sql = "EXPLAIN ANALYZE SELECT encrypted_text FROM encrypted WHERE encrypted_text <= $1"; + + let encrypted_text = "hello_10".to_string(); + let result = query_by::(sql, &encrypted_text).await; + + info!("Result: {:?}", result); + + // let expected = vec![encrypted_val]; + // assert_eq!(expected, result); + } +} diff --git a/packages/cipherstash-proxy-integration/src/select/jsonb_path_exists.rs b/packages/cipherstash-proxy-integration/src/select/jsonb_path_exists.rs new file mode 100644 index 00000000..089360cb --- /dev/null +++ b/packages/cipherstash-proxy-integration/src/select/jsonb_path_exists.rs @@ -0,0 +1,89 @@ 
+#[cfg(test)]
+mod tests {
+    use crate::common::{clear, insert_jsonb, query_by, simple_query, trace};
+    use crate::support::assert::assert_expected;
+    use crate::support::json_path::JsonPath;
+
+    async fn select_jsonb(selector: &str, value: bool) {
+        let selector = JsonPath::new(selector);
+        let expected = vec![value];
+
+        let sql = "SELECT jsonb_path_exists(encrypted_jsonb, $1) FROM encrypted";
+        let actual = query_by::<bool>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql = format!("SELECT jsonb_path_exists(encrypted_jsonb, '{selector}') FROM encrypted");
+        let actual = simple_query::<bool>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_exists_number() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.number", true).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_exists_string() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.nested.string", true).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_exists_value() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.nested", true).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_exists_with_unknown_selector() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.vtha", false).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_exists_with_alias() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        let selector = JsonPath::new("$.nested");
+        let expected = vec![true];
+
+        let sql = "SELECT jsonb_path_exists(encrypted_jsonb, $1) as selected FROM encrypted";
+        let actual = query_by::<bool>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql = format!(
+            "SELECT jsonb_path_exists(encrypted_jsonb, '{selector}') as selected FROM encrypted"
+        );
+        let actual = simple_query::<bool>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+}
diff --git a/packages/cipherstash-proxy-integration/src/select/jsonb_path_query.rs b/packages/cipherstash-proxy-integration/src/select/jsonb_path_query.rs
new file mode 100644
index 00000000..40020d5e
--- /dev/null
+++ b/packages/cipherstash-proxy-integration/src/select/jsonb_path_query.rs
@@ -0,0 +1,122 @@
+#[cfg(test)]
+mod tests {
+    use crate::common::{clear, insert_jsonb, query_by, simple_query, trace};
+    use crate::support::assert::assert_expected;
+    use crate::support::json_path::JsonPath;
+    use serde::de::DeserializeOwned;
+    use serde_json::Value;
+
+    async fn select_jsonb<T>(selector: &str, value: T)
+    where
+        T: DeserializeOwned,
+        serde_json::Value: From<T>,
+    {
+        let selector = JsonPath::new(selector);
+        let value = Value::from(value);
+
+        let expected = vec![value];
+
+        let sql = "SELECT jsonb_path_query(encrypted_jsonb, $1) FROM encrypted";
+        let actual = query_by::<Value>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql = format!("SELECT jsonb_path_query(encrypted_jsonb, '{selector}') FROM encrypted");
+        let actual = simple_query::<Value>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_number() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.number", 42).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_string() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.nested.string", "world".to_string()).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_value() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        let v = serde_json::json!({
+            "number": 1815,
+            "string": "world",
+        });
+
+        select_jsonb("$.nested", v).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_with_unknown() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        let selector = JsonPath::new("$.vtha");
+
+        let expected = vec![];
+
+        let sql = "SELECT jsonb_path_query(encrypted_jsonb, $1) as selected FROM encrypted";
+        let actual = query_by::<Value>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql = format!(
+            "SELECT jsonb_path_query(encrypted_jsonb, '{selector}') as selected FROM encrypted"
+        );
+        let actual = simple_query::<Value>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_with_alias() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        let value = serde_json::json!({
+            "number": 1815,
+            "string": "world",
+        });
+
+        let selector = JsonPath::new("$.nested");
+
+        let expected = vec![value];
+
+        let sql = "SELECT jsonb_path_query(encrypted_jsonb, $1) as selected FROM encrypted";
+        let actual = query_by::<Value>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql = format!(
+            "SELECT jsonb_path_query(encrypted_jsonb, '{selector}') as selected FROM encrypted"
+        );
+        let actual = simple_query::<Value>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+}
diff --git a/packages/cipherstash-proxy-integration/src/select/jsonb_path_query_first.rs b/packages/cipherstash-proxy-integration/src/select/jsonb_path_query_first.rs
new file mode 100644
index 00000000..9c2a3039
--- /dev/null
+++ b/packages/cipherstash-proxy-integration/src/select/jsonb_path_query_first.rs
@@ -0,0 +1,141 @@
+#[cfg(test)]
+mod tests {
+    use crate::common::{
+        clear, insert_jsonb, query_by, simple_query, simple_query_with_null, trace,
+    };
+    use crate::support::assert::assert_expected;
+    use crate::support::json_path::JsonPath;
+    use serde::de::DeserializeOwned;
+    use serde_json::Value;
+
+    async fn select_jsonb<T>(selector: &str, value: T)
+    where
+        T: DeserializeOwned,
+        serde_json::Value: From<T>,
+    {
+        let selector = JsonPath::new(selector);
+        let value = Value::from(value);
+
+        let expected = vec![value];
+
+        let sql = "SELECT jsonb_path_query_first(encrypted_jsonb, $1) FROM encrypted";
+        let actual = query_by::<Value>(sql, &selector).await;
+
+        assert_expected(&expected, &actual);
+
+        let sql =
+            format!("SELECT jsonb_path_query_first(encrypted_jsonb, '{selector}') FROM encrypted");
+        let actual = simple_query::<Value>(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_first_string() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.array_string[*]", "hello".to_string()).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_first_number() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        select_jsonb("$.array_number[*]", 42).await;
+    }
+
+    #[tokio::test]
+    async fn select_jsonb_path_query_first_with_unknown() {
+        trace();
+
+        clear().await;
+
+        insert_jsonb().await;
+
+        let selector = JsonPath::new("$.vtha");
+
+        let sql = "SELECT jsonb_path_query_first(encrypted_jsonb, $1) as selected FROM encrypted";
+        let actual = query_by::<Option<Value>>(sql, &selector).await;
+
+        let expected = vec![None];
+        assert_expected(&expected, &actual);
+
+        let sql = format!(
+            "SELECT jsonb_path_query_first(encrypted_jsonb, '{selector}') as selected FROM encrypted"
+        );
+
+        // Expected again, typed for the Vec<Option<String>> returned by simple_query_with_null
+        let expected = vec![None];
+        let actual = simple_query_with_null(&sql).await;
+
+        assert_expected(&expected, &actual);
+    }
+
+    // #[tokio::test]
+    // async fn select_jsonb_path_query_first_string() {
+    //     trace();
+
+    //     clear().await;
+
+    //     insert_jsonb().await;
+    //     insert_jsonb().await;
+
+    //     select_jsonb("$.nested.string", "world".to_string()).await;
+    // }
+
+    // #[tokio::test]
+    // async fn select_jsonb_path_query_first_value() {
+    //     trace();
+
+    //     clear().await;
+
+    //     insert_jsonb().await;
+    //     insert_jsonb().await;
+
+    //     let v = serde_json::json!({
+    //         "number": 1815,
+    //         "string": "world",
+    //     });
+
+    //     select_jsonb("$.nested", v).await;
+    // }
+
+    // #[tokio::test]
+    // async fn select_jsonb_path_query_first_with_alias() {
+    //     trace();
+
+    //     clear().await;
+
+    //     insert_jsonb().await;
+    //     insert_jsonb().await;
+
+    //     let value = serde_json::json!({
+    //         "number": 1815,
+    //         "string": "world",
+    //     });
+
+    //     let selector = JsonPath::new("$.nested");
+
+    //     let expected = vec![value];
+
+    //     let sql = "SELECT jsonb_path_query_first(encrypted_jsonb, $1) as selected FROM encrypted";
+    //     let actual = query_by::<Value>(sql, &selector).await;
+
+    //     assert_expected(&expected, &actual);
+
+    //     let sql = format!(
+    //         "SELECT jsonb_path_query_first(encrypted_jsonb, '{selector}') as selected FROM encrypted"
+    //     );
+    //     let actual = simple_query::<Value>(&sql).await;
+
+    //     assert_expected(&expected, &actual);
+    // }
+}
diff --git a/packages/cipherstash-proxy-integration/src/select/mod.rs b/packages/cipherstash-proxy-integration/src/select/mod.rs
index d99c45c6..698277d8 100644
--- a/packages/cipherstash-proxy-integration/src/select/mod.rs
+++ b/packages/cipherstash-proxy-integration/src/select/mod.rs
@@ -1,4 +1,7 @@
 mod group_by;
+mod jsonb_path_exists;
+mod jsonb_path_query;
+mod jsonb_path_query_first;
 mod order_by;
 mod order_by_with_null;
 mod pg_catalog;
diff --git a/packages/cipherstash-proxy-integration/src/select/order_by_with_null.rs b/packages/cipherstash-proxy-integration/src/select/order_by_with_null.rs
index faec29cf..5d4209f4 100644
--- a/packages/cipherstash-proxy-integration/src/select/order_by_with_null.rs
+++ b/packages/cipherstash-proxy-integration/src/select/order_by_with_null.rs
@@ -10,7 +10,7 @@ mod tests {
         T: ToSql + Sync + Send + 'static,
     {
         let id = random_id();
-        let sql = format!("INSERT INTO encrypted (id, {}) VALUES ($1, $2)", col);
+        let sql = format!("INSERT INTO encrypted (id, {col}) VALUES ($1, $2)");
         execute_query(&sql, &[&id, &val]).await;
     }
diff --git a/packages/cipherstash-proxy-integration/src/simple_protocol/map_literals.rs b/packages/cipherstash-proxy-integration/src/simple_protocol/map_literals.rs
index aee5d0ec..f4d4fd57 100644
--- a/packages/cipherstash-proxy-integration/src/simple_protocol/map_literals.rs
+++ b/packages/cipherstash-proxy-integration/src/simple_protocol/map_literals.rs
@@ -22,7 +22,7 @@ mod tests {
         if let Row(r) = &rows[1] {
             assert_eq!(Some("plain"), r.get(1));
         } else {
-            panic!("Unexpected query results: {:?}", rows);
+            panic!("Unexpected query results: {rows:?}");
         }
     }
 
@@ -43,7 +43,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         let sql = format!("SELECT id, encrypted_text FROM encrypted WHERE id = {id}");
@@ -83,7 +83,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         let sql = format!("SELECT id, encrypted_int2 FROM encrypted WHERE id = {id}");
@@ -124,7 +124,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         let sql = format!("SELECT id, encrypted_date FROM encrypted WHERE id = {id}");
@@ -166,7 +166,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         let sql = format!("SELECT id, encrypted_date FROM encrypted WHERE id = {id}");
@@ -217,7 +217,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         let sql = format!("SELECT id, encrypted_int4 FROM encrypted WHERE id = {id}");
diff --git a/packages/cipherstash-proxy-integration/src/simple_protocol/map_nulls.rs b/packages/cipherstash-proxy-integration/src/simple_protocol/map_nulls.rs
index effd7bcf..78d93a8e 100644
--- a/packages/cipherstash-proxy-integration/src/simple_protocol/map_nulls.rs
+++ b/packages/cipherstash-proxy-integration/src/simple_protocol/map_nulls.rs
@@ -25,7 +25,7 @@ mod tests {
         if let Row(r) = &rows[1] {
             assert_eq!(encrypted_text, r.get(1));
         } else {
-            panic!("Unexpected query results: {:?}", rows);
+            panic!("Unexpected query results: {rows:?}");
         }
 
         let encrypted_int4: Option<&str> = None;
@@ -41,7 +41,7 @@ mod tests {
         if let Row(r) = &rows[1] {
             assert_eq!(encrypted_int4, r.get(1));
         } else {
-            panic!("Unexpected query results: {:?}", rows);
+            panic!("Unexpected query results: {rows:?}");
         }
     }
 
@@ -69,7 +69,7 @@ mod tests {
             assert!(r.get(3).is_none());
             assert!(r.get(4).is_none());
         } else {
-            panic!("Unexpected query results: {:?}", rows);
+            panic!("Unexpected query results: {rows:?}");
         }
 
         let sql = format!("UPDATE encrypted SET encrypted_float8 = NULL WHERE id = {id}");
@@ -87,7 +87,7 @@ mod tests {
             assert!(r.get(4).is_none());
             assert!(r.get(5).is_none());
         } else {
-            panic!("Unexpected query results: {:?}", rows);
+            panic!("Unexpected query results: {rows:?}");
         }
     }
 }
diff --git a/packages/cipherstash-proxy-integration/src/simple_protocol/multiple_statements.rs b/packages/cipherstash-proxy-integration/src/simple_protocol/multiple_statements.rs
index 3b559bdf..d1a20fa9 100644
--- a/packages/cipherstash-proxy-integration/src/simple_protocol/multiple_statements.rs
+++ b/packages/cipherstash-proxy-integration/src/simple_protocol/multiple_statements.rs
@@ -31,7 +31,7 @@ mod tests {
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         // Check each Row by ID
@@ -75,7 +75,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         // Build SQL string containing multiple statements;
@@ -92,7 +92,7 @@ mod tests {
 
         // CommandComplete does not implement PartialEq, so no equality check with ==
         match &insert_result[0] {
             CommandComplete(n) => assert_eq!(1, *n),
-            _unexpected => panic!("unexpected insert result: {:?}", insert_result),
+            _unexpected => panic!("unexpected insert result: {insert_result:?}"),
         }
 
         // Check each Row by ID
diff --git a/packages/cipherstash-proxy-integration/src/support/assert.rs b/packages/cipherstash-proxy-integration/src/support/assert.rs
index 50f03107..0edba0cf 100644
--- a/packages/cipherstash-proxy-integration/src/support/assert.rs
+++ b/packages/cipherstash-proxy-integration/src/support/assert.rs
@@ -1,19 +1,9 @@
 pub fn assert_expected<T>(expected: &[T], actual: &[T])
 where
-    T: std::fmt::Display + PartialEq + std::fmt::Debug,
+    T: PartialEq + std::fmt::Debug,
 {
     assert_eq!(expected.len(), actual.len());
     for (e, a) in expected.iter().zip(actual) {
         assert_eq!(e, a);
     }
 }
-
-pub fn assert_expected_as_string<T>(expected: &[T], actual: &[String])
-where
-    T: std::fmt::Display + PartialEq + std::fmt::Debug,
-{
-    assert_eq!(expected.len(), actual.len());
-    for (e, a) in expected.iter().zip(actual) {
-        assert_eq!(e.to_string(), *a);
-    }
-}
diff --git a/packages/cipherstash-proxy-integration/src/support/json_path.rs b/packages/cipherstash-proxy-integration/src/support/json_path.rs
index f6576f33..a1556274 100644
--- a/packages/cipherstash-proxy-integration/src/support/json_path.rs
+++ b/packages/cipherstash-proxy-integration/src/support/json_path.rs
@@ -1,3 +1,5 @@
+#![allow(dead_code)]
+
 use bytes::BytesMut;
 use postgres_types::{Format, ToSql, Type};
 use std::{
diff --git a/packages/cipherstash-proxy-integration/src/support/mod.rs b/packages/cipherstash-proxy-integration/src/support/mod.rs
index d5bcd89d..ba363d24 100644
--- a/packages/cipherstash-proxy-integration/src/support/mod.rs
+++ b/packages/cipherstash-proxy-integration/src/support/mod.rs
@@ -1,2 +1,4 @@
+#[cfg(test)]
 pub mod assert;
+
 pub mod json_path;
diff --git a/packages/cipherstash-proxy/Cargo.toml b/packages/cipherstash-proxy/Cargo.toml
index 18edd4a2..1771317c 100644
--- a/packages/cipherstash-proxy/Cargo.toml
+++ b/packages/cipherstash-proxy/Cargo.toml
@@ -8,7 +8,7 @@ bigdecimal = { version = "0.4.6", features = ["serde-json"] }
 arc-swap = "1.7.1"
 bytes = { version = "1.9", default-features = false }
 chrono = { version = "0.4.39", features = ["clock"] }
-cipherstash-client = { version = "0.22.1", features = ["tokio"] }
+cipherstash-client = { workspace = true, features = ["tokio"] }
 clap = { version = "4.5.31", features = ["derive", "env"] }
 config = { version = "0.15", features = [
     "async",
diff --git a/packages/cipherstash-proxy/src/config/log.rs b/packages/cipherstash-proxy/src/config/log.rs
index 7108474a..efbff5b2 100644
--- a/packages/cipherstash-proxy/src/config/log.rs
+++ b/packages/cipherstash-proxy/src/config/log.rs
@@ -104,7 +104,7 @@ impl Display for LogLevel {
             LogLevel::Debug => "debug",
             LogLevel::Trace => "trace",
         };
-        write!(f, "{}", s)
+        write!(f, "{s}")
     }
 }
diff --git a/packages/cipherstash-proxy/src/config/tandem.rs b/packages/cipherstash-proxy/src/config/tandem.rs
index 08c8fe64..70d24ff8 100644
--- a/packages/cipherstash-proxy/src/config/tandem.rs
+++ b/packages/cipherstash-proxy/src/config/tandem.rs
@@ -314,7 +314,7 @@ impl TandemConfig {
             stack_size
                 .parse()
                 .inspect_err(|err| {
-                    println!("Could not parse env var RUST_MIN_STACK: {}", err);
+                    println!("Could not parse env var RUST_MIN_STACK: {err}");
                     println!("Using the default thread stack size");
                 })
                 .unwrap_or(DEFAULT_THREAD_STACK_SIZE);
diff --git a/packages/cipherstash-proxy/src/encrypt/mod.rs b/packages/cipherstash-proxy/src/encrypt/mod.rs
index 03b36f37..3f8c74d9 100644
--- a/packages/cipherstash-proxy/src/encrypt/mod.rs
+++ b/packages/cipherstash-proxy/src/encrypt/mod.rs
@@ -10,21 +10,24 @@ use crate::{
     postgresql::Column, Identifier, EQL_SCHEMA_VERSION,
 };
-use cipherstash_client::config::{ConfigError, ZeroKMSConfigWithClientKey};
 use cipherstash_client::{
     config::EnvSource,
     credentials::{auto_refresh::AutoRefresh, ServiceCredentials},
     encryption::{
         self, Encrypted, EncryptedEntry, EncryptedSteVecTerm, IndexTerm, Plaintext,
-        PlaintextTarget, ReferencedPendingPipeline,
+        PlaintextTarget, Queryable, ReferencedPendingPipeline,
     },
     schema::ColumnConfig,
     ConsoleConfig, CtsConfig, ZeroKMSConfig,
 };
+use cipherstash_client::{
+    config::{ConfigError, ZeroKMSConfigWithClientKey},
+    encryption::QueryOp,
+};
 use config::EncryptConfigManager;
 use schema::SchemaManager;
 use std::{sync::Arc, vec};
-use tracing::{debug, warn};
+use tracing::{debug, info, warn};
 
 /// SQL Statement for loading encrypt configuration from database
 const ENCRYPT_CONFIG_QUERY: &str = include_str!("./sql/select_config.sql");
@@ -53,8 +56,9 @@ pub struct Encrypt {
 impl Encrypt {
     pub async fn init(config: TandemConfig) -> Result<Encrypt, Error> {
         let cipher = Arc::new(init_cipher(&config).await?);
-        let schema = SchemaManager::init(&config.database).await?;
         let encrypt_config = EncryptConfigManager::init(&config.database).await?;
+        // TODO: populate EqlTraitImpls based on config
+        let schema = SchemaManager::init(&config.database).await?;
 
         let eql_version = {
             let client = connect::database(&config.database).await?;
@@ -93,12 +97,20 @@ impl Encrypt {
         columns: &[Option<Column>],
     ) -> Result<Vec<Option<EqlEncrypted>>, Error> {
         let mut pipeline = ReferencedPendingPipeline::new(self.cipher.clone());
+        let mut index_term_plaintexts = vec![None; columns.len()];
 
         for (idx, item) in plaintexts.into_iter().zip(columns.iter()).enumerate() {
             match item {
                 (Some(plaintext), Some(column)) => {
-                    let encryptable = PlaintextTarget::new(plaintext, column.config.clone());
-                    pipeline.add_with_ref::<PlaintextTarget>(encryptable, idx)?;
+                    info!(target: ENCRYPT, msg = "ENCRYPT", idx, ?column, ?plaintext);
+
+                    if column.is_encryptable() {
+                        let encryptable = PlaintextTarget::new(plaintext, column.config.clone());
+                        pipeline.add_with_ref::<PlaintextTarget>(encryptable, idx)?;
+                    } else {
+                        info!(target: ENCRYPT, msg = "Add to index_term_plaintexts", idx, ?column);
+                        index_term_plaintexts[idx] = Some(plaintext);
+                    }
                 }
                 (None, Some(column)) => {
                     // Parameter is NULL
@@ -119,18 +131,32 @@ impl Encrypt {
         }
 
         let mut encrypted_eql = vec![];
-        if !pipeline.is_empty() {
-            let mut result = pipeline.encrypt(None).await?;
-
-            for (idx, opt) in columns.iter().enumerate() {
-                let mut encrypted = None;
-                if let Some(col) = opt {
-                    if let Some(e) = result.remove(idx) {
-                        encrypted = Some(to_eql_encrypted(e, &col.identifier)?);
-                    }
+        }
+
+        let mut result = pipeline.encrypt(None).await?;
+
+        for (idx, opt) in columns.iter().enumerate() {
+            let mut encrypted = None;
+
+            if let Some(column) = opt {
+                if let Some(e) = result.remove(idx) {
+                    encrypted = Some(to_eql_encrypted(e, &column.identifier)?);
+                } else if let Some(plaintext) = index_term_plaintexts[idx].clone() {
+                    let index = column.config.clone().into_ste_vec_index().unwrap();
+                    let op = QueryOp::SteVecSelector;
+
+                    let index_term = (index, plaintext).build_queryable(self.cipher.clone(), op)?;
+
+                    encrypted = Some(to_eql_encrypted_from_index_term(
+                        index_term,
+                        &column.identifier,
+                    )?);
                 }
-                encrypted_eql.push(encrypted);
             }
+
+            info!(target: ENCRYPT, msg = "encrypted_eql", idx, ?opt, ?encrypted);
+
+            encrypted_eql.push(encrypted);
         }
 
         Ok(encrypted_eql)
@@ -153,7 +179,7 @@ impl Encrypt {
         let (indices, encrypted): (Vec<_>, Vec<_>) = ciphertexts
             .into_iter()
             .enumerate()
-            .filter_map(|(idx, eql)| Some((idx, eql?.body.ciphertext)))
+            .filter_map(|(idx, eql)| Some((idx, eql?.body.ciphertext.unwrap())))
            .collect::<_>();
 
         // Decrypt the ciphertexts
@@ -235,6 +261,37 @@ async fn init_cipher(config: &TandemConfig)
     }
 }
 
+fn to_eql_encrypted_from_index_term(
+    index_term: IndexTerm,
+    identifier: &Identifier,
+) -> Result<eql::EqlEncrypted, Error> {
+    debug!(target: ENCRYPT, msg = "Encrypted to EQL", ?identifier);
+
+    let selector = match index_term {
+        IndexTerm::SteVecSelector(s) => Some(hex::encode(s.as_bytes())),
+        _ => return Err(EncryptError::InvalidIndexTerm.into()),
+    };
+
+    Ok(eql::EqlEncrypted {
+        identifier: identifier.to_owned(),
+        version: EQL_SCHEMA_VERSION,
+        body: EqlEncryptedBody {
+            ciphertext: None,
+            indexes: EqlEncryptedIndexes {
+                bloom_filter: None,
+                ore_block_u64_8_256: None,
+                hmac_256: None,
+                blake3: None,
+                ore_cllw_u64_8: None,
+                ore_cllw_var_8: None,
+                selector,
+                ste_vec_index: None,
+            },
+            is_array_item: None,
+        },
+    })
+}
+
 fn to_eql_encrypted(
     encrypted: Encrypted,
     identifier: &Identifier,
@@ -287,7 +344,7 @@ fn to_eql_encrypted(
         identifier: identifier.to_owned(),
         version: EQL_SCHEMA_VERSION,
         body: EqlEncryptedBody {
-            ciphertext,
+            ciphertext: Some(ciphertext),
             indexes: EqlEncryptedIndexes {
                 bloom_filter: match_index,
                 ore_block_u64_8_256: ore_index,
@@ -333,7 +390,7 @@ fn to_eql_encrypted(
             };
 
             eql::EqlEncryptedBody {
-                ciphertext: record,
+                ciphertext: Some(record),
                 indexes,
                 is_array_item: Some(parent_is_array),
             }
@@ -347,7 +404,7 @@ fn to_eql_encrypted(
         identifier: identifier.to_owned(),
         version: EQL_SCHEMA_VERSION,
         body: EqlEncryptedBody {
-            ciphertext: ciphertext.clone(),
+            ciphertext: Some(ciphertext.clone()),
             indexes: EqlEncryptedIndexes {
                 bloom_filter: None,
                 ore_block_u64_8_256: None,
diff --git a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs
index d080bb96..a638b911 100644
--- a/packages/cipherstash-proxy/src/encrypt/schema/manager.rs
+++ b/packages/cipherstash-proxy/src/encrypt/schema/manager.rs
@@ -3,6 +3,7 @@ use crate::encrypt::{AGGREGATE_QUERY, SCHEMA_QUERY};
 use crate::error::Error;
 use crate::{connect, log::SCHEMA};
 use arc_swap::ArcSwap;
+use eql_mapper::{self, EqlTraits};
 use eql_mapper::{Column, Schema, Table};
 use sqltk::parser::ast::Ident;
 use std::sync::Arc;
@@ -141,7 +142,10 @@ pub async fn load_schema(config: &DatabaseConfig) -> Result<Schema, Error> {
             let column = match column_type_name.as_deref() {
                 Some("eql_v2_encrypted") => {
                     debug!(target: SCHEMA, msg = "eql_v2_encrypted column", table = table_name, column = col);
-                    Column::eql(ident)
+
+                    // TODO - map config to the set of implemented traits
+                    let eql_traits = EqlTraits::all();
+                    Column::eql(ident, eql_traits)
                 }
                 _ => Column::native(ident),
             };
diff --git a/packages/cipherstash-proxy/src/eql/mod.rs b/packages/cipherstash-proxy/src/eql/mod.rs
index 222847ef..88cf9e56 100644
--- a/packages/cipherstash-proxy/src/eql/mod.rs
+++ b/packages/cipherstash-proxy/src/eql/mod.rs
@@ -1,4 +1,4 @@
-use cipherstash_client::zerokms::{encrypted_record, EncryptedRecord};
+use cipherstash_client::zerokms::EncryptedRecord;
 use serde::{Deserialize, Serialize};
 use sqltk::parser::ast::Ident;
 
@@ -75,8 +75,15 @@ pub struct EqlEncrypted {
 
 #[derive(Debug, Deserialize, Serialize)]
 pub struct EqlEncryptedBody {
-    #[serde(rename = "c", with = "encrypted_record::formats::mp_base85")]
-    pub(crate) ciphertext: EncryptedRecord,
+    #[serde(
+        rename = "c",
+        default,
+        with = "formats::mp_base85",
+        skip_serializing_if = "Option::is_none"
+    )]
+    pub(crate) ciphertext: Option<EncryptedRecord>,
 
     #[serde(flatten)]
     pub(crate) indexes: EqlEncryptedIndexes,
@@ -85,6 +92,63 @@ pub struct EqlEncryptedBody {
     pub(crate) is_array_item: Option<bool>,
 }
 
+pub mod formats {
+    pub mod mp_base85 {
+        use super::super::*;
+        use serde::Deserialize;
+
+        /// Serialize an optional `EncryptedRecord` in the mp_base85 text format.
+        /// `None` is skipped entirely via `skip_serializing_if` on the field.
+        pub fn serialize<S>(
+            ciphertext: &Option<EncryptedRecord>,
+            serializer: S,
+        ) -> Result<S::Ok, S::Error>
+        where
+            S: serde::Serializer,
+        {
+            match ciphertext {
+                Some(record) => {
+                    let s = record.to_mp_base85().map_err(serde::ser::Error::custom)?;
+                    serializer.serialize_some(&s)
+                }
+                None => serializer.serialize_none(),
+            }
+        }
+
+        /// Deserialize an optional mp_base85 string back into an `EncryptedRecord`.
+        pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<EncryptedRecord>, D::Error>
+        where
+            D: serde::Deserializer<'de>,
+        {
+            let s: Option<String> = Option::deserialize(deserializer)?;
+            if let Some(s) = s {
+                Ok(EncryptedRecord::from_mp_base85(&s).ok())
+            } else {
+                Ok(None)
+            }
+        }
+    }
+}
+
 ///
 /// EqlEncryptedIndexes
 /// - null values should not be serialized
@@ -152,17 +216,18 @@ mod tests {
     pub fn ciphertext_json() {
         let expected = Identifier::new("table", "column");
 
+        let ciphertext = Some(EncryptedRecord {
+            iv: Iv::default(),
+            ciphertext: vec![1; 32],
+            tag: vec![1; 16],
+            descriptor: "ciphertext".to_string(),
+            dataset_id: Some(Uuid::new_v4()),
+        });
         let ct = EqlEncrypted {
             identifier: expected.clone(),
             version: 1,
             body: EqlEncryptedBody {
-                ciphertext: EncryptedRecord {
-                    iv: Iv::default(),
-                    ciphertext: vec![1; 32],
-                    tag: vec![1; 16],
-                    descriptor: "ciphertext".to_string(),
-                    dataset_id: Some(Uuid::new_v4()),
-                },
+                ciphertext,
                 indexes: EqlEncryptedIndexes {
                     ore_block_u64_8_256: None,
                     bloom_filter: None,
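A quick sketch of what the optional-ciphertext serialization above does at its boundary case. This is illustrative and not part of the diff: it uses the types exactly as declared in `eql/mod.rs`, and the all-`None` index block mirrors the construction in `to_eql_encrypted_from_index_term`; the assertion relies on the `skip_serializing_if` attribute.

```rust
// Index-term-only payloads (e.g. an ste_vec selector) carry no ciphertext,
// so the "c" key must disappear from the serialized EQL JSON entirely.
fn ciphertext_is_omitted_when_none() {
    let body = EqlEncryptedBody {
        ciphertext: None,
        indexes: EqlEncryptedIndexes {
            bloom_filter: None,
            ore_block_u64_8_256: None,
            hmac_256: None,
            blake3: None,
            ore_cllw_u64_8: None,
            ore_cllw_var_8: None,
            selector: Some("deadbeef".to_string()),
            ste_vec_index: None,
        },
        is_array_item: None,
    };

    let json = serde_json::to_value(&body).unwrap();
    assert!(json.get("c").is_none()); // skipped, not serialized as null
}
```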
diff --git a/packages/cipherstash-proxy/src/error.rs b/packages/cipherstash-proxy/src/error.rs
index 5535eab0..a63fbb0e 100644
--- a/packages/cipherstash-proxy/src/error.rs
+++ b/packages/cipherstash-proxy/src/error.rs
@@ -218,6 +218,9 @@ pub enum EncryptError {
     #[error("Column configuration for column '{column}' in table '{table}' does not match the encrypted column. For help visit {}#encrypt-column-config-mismatch", ERROR_DOC_BASE_URL)]
     ColumnConfigurationMismatch { table: String, column: String },
 
+    #[error("Invalid index term")]
+    InvalidIndexTerm,
+
     /// This should in practice be unreachable
     #[error("Missing encrypt configuration for column type `{plaintext_type}`. For help visit {}#encrypt-missing-encrypt-configuration", ERROR_DOC_BASE_URL)]
     MissingEncryptConfiguration { plaintext_type: String },
@@ -382,6 +385,6 @@ mod tests {
         let error = MappingError::Internal("unexpected bug encounterd".to_string());
 
         let message = error.to_string();
-        assert_eq!(format!("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. Please visit {}#mapping-internal-error for more information.", ERROR_DOC_BASE_URL), message);
+        assert_eq!(format!("Statement encountered an internal error. This may be a bug in the statement mapping module of CipherStash Proxy. Please visit {ERROR_DOC_BASE_URL}#mapping-internal-error for more information."), message);
     }
 }
diff --git a/packages/cipherstash-proxy/src/main.rs b/packages/cipherstash-proxy/src/main.rs
index 849c29dd..7d214d70 100644
--- a/packages/cipherstash-proxy/src/main.rs
+++ b/packages/cipherstash-proxy/src/main.rs
@@ -19,7 +19,7 @@ fn main() -> Result<(), Box<dyn std::error::Error>> {
     let config = match TandemConfig::load(&args) {
         Ok(config) => config,
         Err(err) => {
-            eprintln!("{}", err);
+            eprintln!("{err}");
             std::process::exit(exitcode::CONFIG);
         }
     };
diff --git a/packages/cipherstash-proxy/src/postgresql/backend.rs b/packages/cipherstash-proxy/src/postgresql/backend.rs
index 41af7514..37267ac1 100644
--- a/packages/cipherstash-proxy/src/postgresql/backend.rs
+++ b/packages/cipherstash-proxy/src/postgresql/backend.rs
@@ -361,7 +361,12 @@ where
         let param_types = statement
             .param_columns
             .iter()
-            .map(|col| col.as_ref().map(|col| col.postgres_type.clone()))
+            .map(|col| {
+                col.as_ref().map(|col| {
+                    debug!(target: MAPPER, client_id = self.context.client_id, ColumnConfig = ?col);
+                    col.postgres_type.clone()
+                })
+            })
             .collect::<Vec<_>>();
 
         debug!(target: MAPPER, client_id = self.context.client_id, param_types = ?param_types);
diff --git a/packages/cipherstash-proxy/src/postgresql/context/column.rs b/packages/cipherstash-proxy/src/postgresql/context/column.rs
index 20155ca8..e1cdc3b1 100644
--- a/packages/cipherstash-proxy/src/postgresql/context/column.rs
+++ b/packages/cipherstash-proxy/src/postgresql/context/column.rs
@@ -11,8 +11,14 @@ pub struct Column {
 }
 
 impl Column {
-    pub fn new(identifier: Identifier, config: ColumnConfig) -> Column {
-        let postgres_type = column_type_to_postgres_type(&config.cast_type);
+    pub fn new(
+        identifier: Identifier,
+        config: ColumnConfig,
+        postgres_type: Option<Type>,
+    ) -> Column {
+        let postgres_type =
+            postgres_type.unwrap_or(column_type_to_postgres_type(&config.cast_type));
+
         Column {
             identifier,
             config,
@@ -43,6 +49,10 @@ impl Column {
     pub fn is_param_type(&self, param_type: &Type) -> bool {
         param_type == &self.postgres_type
     }
+
+    pub fn is_encryptable(&self) -> bool {
+        self.postgres_type != postgres_types::Type::JSONPATH
+    }
 }
 
 fn column_type_to_postgres_type(col_type: &ColumnType) -> postgres_types::Type {
diff --git a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
index 7850741f..5bfdf36d 100644
--- a/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
+++ b/packages/cipherstash-proxy/src/postgresql/data/from_sql.rs
@@ -156,11 +156,11 @@ fn text_from_sql(
_) => { unimplemented!("TIMESTAMPTZ") } - (&Type::TEXT | &Type::JSONB, ColumnType::JsonB) => { - serde_json::from_str::<serde_json::Value>(val) - .map_err(|_| MappingError::CouldNotParseParameter) - .map(Plaintext::new) - } + // For JSONB columns, TEXT and JSONPATH values are treated as strings + (&Type::TEXT | &Type::JSONPATH, ColumnType::JsonB) => Ok(Plaintext::new(val)), + (&Type::JSONB, ColumnType::JsonB) => serde_json::from_str::<serde_json::Value>(val) + .map_err(|_| MappingError::CouldNotParseParameter) + .map(Plaintext::new), (ty, _) => Err(MappingError::UnsupportedParameterType { name: ty.name().to_owned(), oid: ty.oid(), @@ -177,6 +177,9 @@ fn binary_from_sql( col_type: ColumnType, ) -> Result<Plaintext, MappingError> { match (pg_type, col_type) { + (&Type::TEXT, ColumnType::Utf8Str) => { + parse_bytes_from_sql::<String>(bytes, pg_type).map(Plaintext::new) + } (&Type::BOOL, ColumnType::Boolean) => { parse_bytes_from_sql::<bool>(bytes, pg_type).map(Plaintext::new) } @@ -189,10 +192,6 @@ (&Type::INT2, ColumnType::SmallInt) => { parse_bytes_from_sql::<i16>(bytes, pg_type).map(Plaintext::new) } - (&Type::TEXT, ColumnType::Utf8Str) => { - parse_bytes_from_sql::<String>(bytes, pg_type).map(Plaintext::new) - } - - // INT4 and INT2 can be converted to Int plaintext (&Type::INT4, ColumnType::Int) => { parse_bytes_from_sql::<i32>(bytes, pg_type).map(Plaintext::new) } @@ -229,6 +228,10 @@ parse_bytes_from_sql::(bytes, pg_type).map(Plaintext::new) } + // For JSONB columns, JSONPATH values are treated as strings + (&Type::JSONPATH, ColumnType::JsonB) => { + parse_bytes_from_sql::<String>(bytes, pg_type).map(Plaintext::new) + } (&Type::JSON | &Type::JSONB | &Type::BYTEA, ColumnType::JsonB) => { + parse_bytes_from_sql::<serde_json::Value>(bytes, pg_type).map(Plaintext::new) } diff --git a/packages/cipherstash-proxy/src/postgresql/frontend.rs b/packages/cipherstash-proxy/src/postgresql/frontend.rs index 51a134aa..e6c10a63 100644 --- a/packages/cipherstash-proxy/src/postgresql/frontend.rs +++ b/packages/cipherstash-proxy/src/postgresql/frontend.rs @@ -26,9 +26,10 @@ use crate::prometheus::{ use crate::EqlEncrypted; use bytes::BytesMut; use cipherstash_client::encryption::Plaintext; -use eql_mapper::{self, EqlMapperError, EqlValue, TableColumn, TypeCheckedStatement}; +use eql_mapper::{self, EqlMapperError, EqlTerm, TableColumn, TypeCheckedStatement}; use metrics::{counter, histogram}; use pg_escape::quote_literal; +use postgres_types::Type; use serde::Serialize; use sqltk::parser::ast::{self, Value}; use sqltk::parser::dialect::PostgreSqlDialect; @@ -376,7 +377,7 @@ where return Ok(vec![]); } - let plaintexts = literals_to_plaintext(&literal_values, literal_columns)?; + let plaintexts = literals_to_plaintext(literal_values, literal_columns)?; let start = Instant::now(); @@ -814,25 +815,28 @@ where typed_statement: &eql_mapper::TypeCheckedStatement<'_>, ) -> Result<Vec<Option<Column>>, Error> { let mut projection_columns = vec![]; - if let eql_mapper::Projection::WithColumns(columns) = &typed_statement.projection { - for col in columns { - let eql_mapper::ProjectionColumn { ty, .. } = col; - let configured_column = match ty { - eql_mapper::Value::Eql(EqlValue(TableColumn { table, column })) => { - let identifier: Identifier = Identifier::from((table, column)); - debug!( target: MAPPER, client_id = self.context.client_id, msg = "Configured column", column = ?identifier ); - self.get_column(identifier)? - } - _ => None, - }; - projection_columns.push(configured_column) - } + + for col in typed_statement.projection.columns() { + let eql_mapper::ProjectionColumn { ty, ..
} = col; + let configured_column = match ty { + eql_mapper::Value::Eql(eql_term) => { + let TableColumn { table, column } = eql_term.table_column(); + let identifier: Identifier = Identifier::from((table, column)); + + debug!( + target: MAPPER, + client_id = self.context.client_id, + msg = "Configured column", + column = ?identifier, + ?eql_term, + ); + self.get_column(identifier, eql_term)? + } + _ => None, + }; + projection_columns.push(configured_column) } + Ok(projection_columns) } @@ -844,6 +848,7 @@ where /// /// Preserves the ordering and semantics of the projection to reduce the complexity of positional encryption. /// + /// fn get_param_columns( &self, typed_statement: &eql_mapper::TypeCheckedStatement<'_>, @@ -852,17 +857,19 @@ where for param in typed_statement.params.iter() { let configured_column = match param { - (_, eql_mapper::Value::Eql(EqlValue(TableColumn { table, column }))) => { + (_, eql_mapper::Value::Eql(eql_term)) => { + let TableColumn { table, column } = eql_term.table_column(); let identifier = Identifier::from((table, column)); debug!( target: MAPPER, client_id = self.context.client_id, msg = "Encrypted parameter", - column = ?identifier + column = ?identifier, + ?eql_term, ); - self.get_column(identifier)? + self.get_column(identifier, eql_term)? } _ => None, }; @@ -878,21 +885,20 @@ where ) -> Result>, Error> { let mut literal_columns = vec![]; - for (eql_value, _) in typed_statement.literals.iter() { - match eql_value { - EqlValue(TableColumn { table, column }) => { - let identifier = Identifier::from((table, column)); - debug!( - target: MAPPER, - client_id = self.context.client_id, - msg = "Encrypted literal", - identifier = ?identifier - ); - let col = self.get_column(identifier)?; - if col.is_some() { - literal_columns.push(col); - } - } + for (eql_term, _) in typed_statement.literals.iter() { + let TableColumn { table, column } = eql_term.table_column(); + let identifier = Identifier::from((table, column)); + + debug!( + target: MAPPER, + client_id = self.context.client_id, + msg = "Encrypted literal", + column = ?identifier, + ?eql_term, + ); + let col = self.get_column(identifier, eql_term)?; + if col.is_some() { + literal_columns.push(col); } } @@ -903,7 +909,11 @@ where /// Get the column configuration for the Identifier /// Returns `EncryptError::UnknownColumn` if configuration cannot be found for the Identified column /// if mapping enabled, and None if mapping is disabled. It'll log a warning either way. 
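The `eql_term` argument threaded through these three collectors feeds the `get_column` change in the next hunk: a JSON path term forces the parameter's wire type to `jsonpath` rather than the type derived from the column's configured cast type. A minimal sketch of that decision (hypothetical free function; the proxy does this inline in `get_column`):

```rust
use postgres_types::Type;

// Sketch only: a JsonPath EQL term travels to PostgreSQL as a `jsonpath`
// value, so it must not inherit the encrypted column's parameter type.
fn param_type_override(eql_term: &eql_mapper::EqlTerm) -> Option<Type> {
    match eql_term {
        eql_mapper::EqlTerm::JsonPath(_) => Some(Type::JSONPATH),
        _ => None, // fall back to the cast-type-derived postgres type
    }
}
```

`Column::new` then applies the override via `unwrap_or`, as shown in the column.rs hunk earlier in this diff.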
- fn get_column(&self, identifier: Identifier) -> Result, Error> { + fn get_column( + &self, + identifier: Identifier, + eql_term: &EqlTerm, + ) -> Result, Error> { match self.encrypt.get_column_config(&identifier) { Some(config) => { debug!( @@ -912,7 +922,16 @@ where msg = "Configured column", column = ?identifier ); - Ok(Some(Column::new(identifier, config))) + + // IndexTerm::SteVecSelector + + let postgres_type = if matches!(eql_term, EqlTerm::JsonPath(_)) { + Some(Type::JSONPATH) + } else { + None + }; + + Ok(Some(Column::new(identifier, config, postgres_type))) } None => { warn!( @@ -955,7 +974,7 @@ where // This *should* be sufficient for escaping error messages as we're only // using the string literal, and not identifiers - let quoted_error = quote_literal(format!("{}", err).as_str()); + let quoted_error = quote_literal(format!("{err}").as_str()); let content = format!("DO $$ BEGIN RAISE EXCEPTION {quoted_error}; END; $$;"); debug!( @@ -973,13 +992,13 @@ where } fn literals_to_plaintext( - literals: &Vec<&ast::Value>, + literals: &Vec<(EqlTerm, &ast::Value)>, literal_columns: &Vec>, ) -> Result>, Error> { let plaintexts = literals .iter() .zip(literal_columns) - .map(|(val, col)| match col { + .map(|((_, val), col)| match col { Some(col) => literal_from_sql(val, col.cast_type()).map_err(|err| { debug!( target: MAPPER, diff --git a/packages/cipherstash-proxy/src/postgresql/handler.rs b/packages/cipherstash-proxy/src/postgresql/handler.rs index 693c48bf..62bfcbe5 100644 --- a/packages/cipherstash-proxy/src/postgresql/handler.rs +++ b/packages/cipherstash-proxy/src/postgresql/handler.rs @@ -305,7 +305,7 @@ pub fn md5_hash(username: &[u8], password: &[u8], salt: &[u8; 4]) -> String { md5.update(password); md5.update(username); let output = md5.finalize_reset(); - md5.update(format!("{:x}", output)); + md5.update(format!("{output:x}")); md5.update(salt); format!("md5{:x}", md5.finalize()) } diff --git a/packages/cipherstash-proxy/src/postgresql/messages/authentication/auth.rs b/packages/cipherstash-proxy/src/postgresql/messages/authentication/auth.rs index e87a4849..c96f9e11 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/authentication/auth.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/authentication/auth.rs @@ -298,7 +298,7 @@ impl Display for SaslMechanism { SaslMechanism::ScramSha256 => SCRAM_SHA_256.to_owned(), SaslMechanism::ScramSha256Plus => SCRAM_SHA_256_PLUS.to_owned(), }; - write!(f, "{}", s) + write!(f, "{s}") } } diff --git a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs index 1471e97d..1798b0c2 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/bind.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/bind.rs @@ -186,7 +186,7 @@ impl BindParam { impl Display for BindParam { fn fmt(&self, f: &mut Formatter) -> fmt::Result { let s = String::from_utf8_lossy(&self.bytes).to_string(); - write!(f, "{}", s) + write!(f, "{s}") } } diff --git a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs index 5190d911..ac084058 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/data_row.rs @@ -7,7 +7,7 @@ use crate::{ }; use bytes::{Buf, BufMut, BytesMut}; use std::io::Cursor; -use tracing::{debug, error}; +use tracing::{debug, error, info}; #[derive(Debug, Clone)] pub struct DataRow { @@ -26,6 +26,7 @@ impl 
DataRow { ) -> Vec> { let mut result = vec![]; for (data_column, column_config) in self.columns.iter_mut().zip(column_configuration) { + info!(target: DECRYPT, ?column_config, ?data_column); let encrypted = column_config .as_ref() .filter(|_| data_column.is_not_null()) @@ -34,6 +35,7 @@ impl DataRow { .try_into() .inspect_err(|err| match err { Error::Encrypt(EncryptError::ColumnIsNull) => { + debug!(target: DECRYPT, msg ="ColumnIsNull", ?config); // Not an error, as you were data_column.set_null(); } @@ -179,6 +181,8 @@ impl TryFrom<&mut DataColumn> for eql::EqlEncrypted { fn try_from(col: &mut DataColumn) -> Result { if let Some(bytes) = &col.bytes { + info!(target: DECRYPT, ?bytes); + if &bytes[0..=1] == b"(\"" { // Text encoding // Encrypted record is in the form ("{}") @@ -198,6 +202,7 @@ impl TryFrom<&mut DataColumn> for eql::EqlEncrypted { } } } else { + // BINARY ENCODING // 12 bytes for the binary rowtype header // plus 1 byte for the jsonb header (value of 1) // [Int32] Number of fields (N) @@ -220,7 +225,10 @@ impl TryFrom<&mut DataColumn> for eql::EqlEncrypted { let sliced = &bytes[start..]; match serde_json::from_slice(sliced) { - Ok(e) => return Ok(e), + Ok(e) => { + info!(target: DECRYPT, ?e); + return Ok(e); + } Err(err) => { debug!(target: DECRYPT, error = err.to_string()); return Err(err.into()); @@ -252,7 +260,7 @@ mod tests { fn column_config(column: &str) -> Option { let identifier = Identifier::new("encrypted", column); let config = ColumnConfig::build("column".to_string()).casts_as(ColumnType::SmallInt); - let column = Column::new(identifier, config); + let column = Column::new(identifier, config, None); Some(column) } diff --git a/packages/cipherstash-proxy/src/postgresql/messages/param_description.rs b/packages/cipherstash-proxy/src/postgresql/messages/param_description.rs index 75bc1194..56ba751e 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/param_description.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/param_description.rs @@ -1,11 +1,13 @@ use super::BackendCode; use crate::{ error::{Error, ProtocolError}, + log::MAPPER, SIZE_I16, SIZE_I32, }; use bytes::{Buf, BufMut, BytesMut}; use postgres_types::Type; use std::io::Cursor; +use tracing::debug; /// /// Describe b't' (Backend) message. 
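A note on the binary decoding path in the data_row.rs hunk above: it peels a fixed-width header off the composite value before handing the remainder to `serde_json::from_slice`. A standalone sketch of that offset arithmetic (hypothetical helper name; `bytes` is assumed to hold a binary-encoded single-field composite wrapping a `jsonb` value):

```rust
/// Skip the composite-row header and the jsonb version byte to reach the
/// JSON payload, per the layout documented in the comment above:
/// [Int32 field count][Int32 OID][Int32 value length] = 12 bytes,
/// then one jsonb version byte (always 1).
fn jsonb_payload(bytes: &[u8]) -> Option<&[u8]> {
    const HEADER_LEN: usize = 4 + 4 + 4 + 1; // 12-byte rowtype header + version byte
    bytes.get(HEADER_LEN..)
}
```

`serde_json::from_slice` can then deserialize the returned slice, which is what the `TryFrom<&mut DataColumn>` impl does with `sliced`.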
@@ -34,6 +36,8 @@ pub struct ParamDescription { impl ParamDescription { pub fn map_types(&mut self, mapped_types: &[Option]) { + debug!(target: MAPPER, ?mapped_types); + for (idx, t) in mapped_types.iter().enumerate() { if let Some(t) = t { self.types[idx] = t.oid() as i32; diff --git a/packages/cipherstash-proxy/src/postgresql/messages/parse.rs b/packages/cipherstash-proxy/src/postgresql/messages/parse.rs index c54d0cfa..68fa394a 100644 --- a/packages/cipherstash-proxy/src/postgresql/messages/parse.rs +++ b/packages/cipherstash-proxy/src/postgresql/messages/parse.rs @@ -157,7 +157,7 @@ mod tests { let config = ColumnConfig::build("column".to_string()).casts_as(ColumnType::SmallInt); - let column = Column::new(identifier, config); + let column = Column::new(identifier, config, None); let columns = vec![None, Some(column)]; parse.rewrite_param_types(&columns); diff --git a/packages/cipherstash-proxy/src/prometheus.rs b/packages/cipherstash-proxy/src/prometheus.rs index b85cfacb..cde700d4 100644 --- a/packages/cipherstash-proxy/src/prometheus.rs +++ b/packages/cipherstash-proxy/src/prometheus.rs @@ -35,7 +35,7 @@ pub const SERVER_BYTES_SENT_TOTAL: &str = "cipherstash_proxy_server_bytes_sent_t pub const SERVER_BYTES_RECEIVED_TOTAL: &str = "cipherstash_proxy_server_bytes_received_total"; pub fn start(host: String, port: u16) -> Result<(), Error> { - let address = format!("{}:{}", host, port); + let address = format!("{host}:{port}"); let socket_address: SocketAddr = address.parse().unwrap(); debug!(target: DEVELOPMENT, msg = "Starting Prometheus exporter", port); diff --git a/packages/eql-mapper-macros/Cargo.toml b/packages/eql-mapper-macros/Cargo.toml index 6ceb98c2..1cb0dc12 100644 --- a/packages/eql-mapper-macros/Cargo.toml +++ b/packages/eql-mapper-macros/Cargo.toml @@ -1,7 +1,9 @@ [package] name = "eql-mapper-macros" +description = "Macros to reduce boilerplate in the implementation of eql-mapper" version.workspace = true edition.workspace = true +publish = false [lib] proc-macro = true @@ -9,4 +11,7 @@ proc-macro = true [dependencies] syn = { version = "2.0", features = ["full"] } quote = "1.0" -proc-macro2 = "1.0" \ No newline at end of file +proc-macro2 = "1.0" + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/packages/eql-mapper-macros/README.md b/packages/eql-mapper-macros/README.md new file mode 100644 index 00000000..e9af3678 --- /dev/null +++ b/packages/eql-mapper-macros/README.md @@ -0,0 +1,3 @@ +# eql-mapper-macros + +This crate is a private implementation detail of eql-mapper. \ No newline at end of file diff --git a/packages/eql-mapper-macros/src/lib.rs b/packages/eql-mapper-macros/src/lib.rs index 79bb5872..aa5f15eb 100644 --- a/packages/eql-mapper-macros/src/lib.rs +++ b/packages/eql-mapper-macros/src/lib.rs @@ -1,115 +1,166 @@ -use proc_macro::TokenStream; +//! Defines macros specifically for reducing the amount of boilerplate in `eql-mapper`. + +mod trace_infer; use quote::{quote, ToTokens}; -use syn::{ - parse::Parse, parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItem, ImplItemFn, - ItemImpl, Pat, PatType, Signature, Type, TypePath, TypeReference, +use trace_infer::*; +mod parse_type_decl; + +use proc_macro::TokenStream; + +use crate::parse_type_decl::{ + BinaryOpDecls, ConcreteTyArgs, FunctionDecls, ShallowInitTypes, TVar, TypeDecl, TypeEnvDecl, }; -/// This macro generates consistently defined `#[tracing::instrument]` attributes for `InferType::infer_enter` & -/// `InferType::infer_enter` implementations on `TypeInferencer`. 
+/// Generates `#[tracing::instrument]` attributes for `InferType::infer_enter` & `InferType::infer_exit` +/// implementations on `TypeInferencer`. /// /// This attribute MUST be defined on the trait `impl` itself (not the trait method impls). #[proc_macro_attribute] pub fn trace_infer(_attr: TokenStream, item: TokenStream) -> TokenStream { - let mut input = parse_macro_input!(item as ItemImpl); - - for item in &mut input.items { - if let ImplItem::Fn(ImplItemFn { - attrs, - sig: - Signature { - ident: method, - inputs, - .. - }, - .. - }) = item - { - let node_ident_and_type: Option<(&Ident, &Type)> = - if let Some(FnArg::Typed(PatType { - ty: node_ty, pat, .. - })) = inputs.get(1) - { - if let Pat::Ident(pat_ident) = &**pat { - Some((&pat_ident.ident, node_ty)) - } else { - None - } - } else { - None - }; - - let vec_ident: Ident = parse_quote!(Vec); + trace_infer_(_attr, item) } - match node_ident_and_type { - Some((node_ident, node_ty)) => { - let (formatter, node_ty_abbrev) = match node_ty { - Type::Reference(TypeReference { elem, .. }) => match &**elem { - Type::Path(TypePath { path, .. }) => { - let last_segment = path.segments.last().unwrap(); - let last_segment_ident = &last_segment.ident; - let last_segment_arguments = if last_segment.arguments.is_empty() { - None - } else { - let args = &last_segment.arguments; - Some(quote!(<#args>)) - }; - match last_segment_ident { - ident if vec_ident == *ident => { - (quote!(crate::FmtAstVec), quote!(#last_segment_ident #last_segment_arguments)) - } - _ => (quote!(crate::FmtAst), quote!(#last_segment_ident #last_segment_arguments)) - } - }, - _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>") - }, - _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>") - }; +/// Parses a `;`-separated block of binary operator declarations, like this: +/// +/// ```ignore +/// let ops: Vec<BinaryOpDecl> = binary_operators! { +/// (T = T) -> Native where T: Eq; +/// (T -> <T as JsonLike>::Accessor) -> T where T: JsonLike; +/// (T <@ T) -> Native where T: Contain; +/// (T ~~ <T as TokenMatch>::Tokenized) -> Native where T: TokenMatch; +/// // ... +/// }; +/// ``` +#[proc_macro] +pub fn binary_operators(tokens: TokenStream) -> TokenStream { + let binops = syn::parse_macro_input!(tokens as BinaryOpDecls); + binops.to_token_stream().into() +} - let node_ty_abbrev = node_ty_abbrev - .to_token_stream() - .to_string() - .replace(" ", ""); +/// Parses a `;`-separated block of function declarations, like this: +/// +/// ```ignore +/// let items: Vec<FunctionDecl> = functions! { +/// pg_catalog.count(T) -> Native; +/// pg_catalog.min(T) -> T where T: Ord; +/// pg_catalog.max(T) -> T where T: Ord; +/// pg_catalog.jsonb_path_query(J, <J as JsonLike>::Path) -> J where J: JsonLike; +/// }; +/// ``` +#[proc_macro] +pub fn functions(tokens: TokenStream) -> TokenStream { + let functions = syn::parse_macro_input!(tokens as FunctionDecls); + functions.to_token_stream().into() +} - let target = format!("eql-mapper::{}", method.to_string().to_uppercase()); +/// Builds a [`TypeDecl`] from type declaration syntax. Useful for avoiding boilerplate, especially in tests. +/// +/// The generated code is guaranteed not to panic.
+/// +/// ```ignore +/// let eql_ty: TypeDecl = ty!(EQL(customer.email)); +/// let native: TypeDecl = ty!(Native); +/// let projection: TypeDecl = ty!({Native(customer.id) as id, EQL(customer.email: Eq) as email}); +/// let array: TypeDecl = ty!([EQL(customer.email: Eq)]); +/// ``` +#[proc_macro] +pub fn ty(tokens: TokenStream) -> TokenStream { + let type_decl = syn::parse_macro_input!(tokens as TypeDecl); + type_decl.to_token_stream().into() +} - let attr: TracingInstrumentAttr = syn::parse2(quote! { - #[tracing::instrument( - target = #target, - level = "trace", - skip(self, #node_ident), - fields( - ast_ty = #node_ty_abbrev, - ast = %#formatter(#node_ident), - ), - ret(Debug) - )] - }) - .unwrap(); - attrs.push(attr.attr); - } - None => { - return quote!(compile_error!( - "could not determine name of node argumemt in Infer impl" - )) - .to_token_stream() - .into(); - } - } - } +/// Builds a concrete type from type declaration syntax. Useful for avoiding boilerplate, especially in tests. +/// +/// WARNING: this macro generates code that will panic if type instantiation fails so limit its usage to setting up +/// tests. +/// +/// ```ignore +/// let eql_ty: crate::Type = concrete_ty!(EQL(customer.email)); +/// let native: crate::Type = concrete_ty!(Native); +/// let projection: crate::Type = concrete_ty!({Native(customer.id) as id, EQL(customer.email: Eq) as email}); +/// let projection: crate::Projection = concrete_ty!({Native(customer.id) as id, EQL(customer.email: Eq) as email} as crate::Projection); +/// let array: crate::Type = concrete_ty!([EQL(customer.email: Eq)]); +/// ``` +#[proc_macro] +pub fn concrete_ty(tokens: TokenStream) -> TokenStream { + let args = syn::parse_macro_input!(tokens as ConcreteTyArgs); + let type_decl = &args.ty_decl; + if let Some(ty_as) = &args.ty_as { + quote! {{ + let mut unifier = crate::inference::unifier::Unifier::new( + std::rc::Rc::new(std::cell::RefCell::new(crate::inference::TypeRegistry::new())) + ); + let ty_as: #ty_as = #type_decl.instantiate_concrete().unwrap().resolved_as(&mut unifier).unwrap(); + ty_as + }}.into() + } else { + quote! {{ + let mut unifier = crate::inference::unifier::Unifier::new( + std::rc::Rc::new(std::cell::RefCell::new(crate::inference::TypeRegistry::new())) + ); + use crate::inference::unifier::ResolveType; + #type_decl.instantiate_concrete().unwrap().resolve_type(&mut unifier).unwrap() + }} + .into() } } - input.to_token_stream().into() +/// Parses a list of pseudo-Rust let bindings where the right hand of the `=` is type declaration syntax (i.e. can be +/// parsed with [`macro@ty`]) and assigns an initialised `Arc<Type>` to each binding. +/// +/// WARNING: this macro generates code that will panic if type instantiation fails so it is recommended to limit its +/// usage to setting up tests. +/// +/// The type declarations are immediately converted to `Arc<Type>` values using `InstantiateType::instantiate_shallow` +/// and assigned to a local variable binding in the current scope. +/// +/// ```ignore +/// let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new())); +/// +/// shallow_init_types!
{&mut unifier, { +/// let lhs = T; +/// let rhs = Native; +/// let expected = Native; +/// }}; +/// +/// let actual = unifier.unify(lhs, rhs).unwrap(); +/// assert_eq!(actual, expected); +/// ``` +#[proc_macro] +pub fn shallow_init_types(tokens: TokenStream) -> TokenStream { + let shallow_init_types = syn::parse_macro_input!(tokens as ShallowInitTypes); + shallow_init_types.to_token_stream().into() } -struct TracingInstrumentAttr { - attr: Attribute, +/// Shortcut for creating a named type variable. Does not save much boilerplate but is easier on the eye. +/// +/// ```ignore +/// // this: +/// let var: TVar = tvar!(A); +/// +/// // is sugar for this: +/// let var: TVar = TVar("A".into()); +/// ``` +#[proc_macro] +pub fn tvar(tokens: TokenStream) -> TokenStream { + let tvar = syn::parse_macro_input!(tokens as TVar); + tvar.to_token_stream().into() } -impl Parse for TracingInstrumentAttr { - fn parse(input: syn::parse::ParseStream) -> syn::Result { - Ok(Self { - attr: Attribute::parse_outer(input)?.first().unwrap().clone(), - }) - } +/// Builds a type environment from a set of `;`-separated type equations. This helps to reduce boilerplate in tests. +/// +/// The left hand side of the equation is always a type variable, the right hand side is any type declaration. +/// +/// ```ignore +/// let env = type_env! { +/// P = {A as id, B as name, C as email}; +/// A = Native(customer.id); +/// B = EQL(customer.name: Eq); +/// C = EQL(customer.email: Eq); +/// }; +/// ``` +#[proc_macro] +pub fn type_env(tokens: TokenStream) -> TokenStream { + let env = syn::parse_macro_input!(tokens as TypeEnvDecl); + env.to_token_stream().into() } diff --git a/packages/eql-mapper-macros/src/parse_type_decl.rs b/packages/eql-mapper-macros/src/parse_type_decl.rs new file mode 100644 index 00000000..ab3e12e6 --- /dev/null +++ b/packages/eql-mapper-macros/src/parse_type_decl.rs @@ -0,0 +1,825 @@ +use proc_macro2::token_stream::TokenStream; +use quote::{quote, ToTokens, TokenStreamExt}; +use syn::{ + braced, bracketed, parenthesized, + parse::{Parse, ParseStream}, + punctuated::Punctuated, + token::{self}, + Ident, Token, TypePath, +}; + +mod kw { + syn::custom_keyword!(Accessor); + syn::custom_keyword!(Contain); + syn::custom_keyword!(EQL); + syn::custom_keyword!(Eq); + syn::custom_keyword!(Full); + syn::custom_keyword!(JsonLike); + syn::custom_keyword!(Native); + syn::custom_keyword!(Only); + syn::custom_keyword!(Ord); + syn::custom_keyword!(Partial); + syn::custom_keyword!(Path); + syn::custom_keyword!(SetOf); + syn::custom_keyword!(TokenMatch); +} + +/// Generates a newtype wrapper struct around a `TokenStream` and a implements `ToTokens` for it. +/// The newtype wrapper allows a `syn::parse::Parse` implementation to be attached to it. +macro_rules! 
tokens_of { + ($ident:ident) => { + pub(super) struct $ident(TokenStream); + + impl ToTokens for $ident { + fn to_tokens(&self, tokens: &mut TokenStream) { + self.0.to_tokens(tokens); + } + } + }; +} + +tokens_of!(ArrayDecl); +tokens_of!(AssociatedTypeDecl); +tokens_of!(BinaryOpDecl); +tokens_of!(BoundsDecl); +tokens_of!(EqlTerm); +tokens_of!(EqlTrait); +tokens_of!(EqlTraits); +tokens_of!(FunctionDecl); +tokens_of!(NativeDecl); +tokens_of!(ProjectionColumnDecl); +tokens_of!(ProjectionDecl); +tokens_of!(SetOfDecl); +tokens_of!(SqltkBinOp); +tokens_of!(TVar); +tokens_of!(TableColumn); +tokens_of!(TypeEquation); +tokens_of!(TypeEnvDecl); +tokens_of!(TypeDecl); +tokens_of!(VarDecl); + +impl Parse for TVar { + fn parse(input: ParseStream) -> syn::Result { + let ident = Ident::parse(input)?.to_string(); + Ok(Self(quote! { + crate::inference::unifier::TVar(#ident.to_string()) + })) + } +} + +impl Parse for VarDecl { + fn parse(input: ParseStream) -> syn::Result { + let ident = Ident::parse(input)?.to_string(); + if input.peek(Token![:]) { + let _: Token![:] = input.parse()?; + let bounds = EqlTraits::parse(input)?; + Ok(Self(quote! { + crate::inference::unifier::VarDecl { + tvar: crate::inference::unifier::TVar(#ident.to_string()), + bounds: #bounds, + } + })) + } else { + Ok(Self(quote! { + crate::inference::unifier::VarDecl { + tvar: crate::inference::unifier::TVar(#ident.to_string()), + bounds: crate::inference::unifier::EqlTraits::default(), + } + })) + } + } +} + +impl Parse for EqlTraits { + fn parse(input: ParseStream) -> syn::Result { + let mut traits: Vec = Vec::new(); + + loop { + traits.push(EqlTrait::parse(input)?); + + if !input.peek(token::Plus) { + break; + } + + token::Plus::parse(input)?; + } + + Ok(Self(quote!( + crate::inference::unifier::EqlTraits::from_iter(vec![#(#traits),*]) + ))) + } +} + +impl Parse for BoundsDecl { + fn parse(input: ParseStream) -> syn::Result { + let mut traits: Vec = Vec::new(); + + let tvar = TVar::parse(input)?; + + let _: token::Colon = input.parse()?; + + loop { + traits.push(EqlTrait::parse(input)?); + + if !input.peek(token::Plus) { + break; + } + + let _: token::Plus = input.parse()?; + } + + Ok(Self(quote! 
{ + crate::inference::unifier::BoundsDecl( + #tvar, + crate::inference::unifier::EqlTraits::from_iter(vec![#(#traits),*]) + ) + })) + } +} + +impl Parse for EqlTrait { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(kw::Eq) { + kw::Eq::parse(input)?; + return Ok(Self(quote!(crate::inference::unifier::EqlTrait::Eq))); + } + + if input.peek(kw::Ord) { + kw::Ord::parse(input)?; + return Ok(Self(quote!(crate::inference::unifier::EqlTrait::Ord))); + } + + if input.peek(kw::TokenMatch) { + kw::TokenMatch::parse(input)?; + return Ok(Self(quote!( + crate::inference::unifier::EqlTrait::TokenMatch + ))); + } + + if input.peek(kw::JsonLike) { + kw::JsonLike::parse(input)?; + return Ok(Self(quote!(crate::inference::unifier::EqlTrait::JsonLike))); + } + + if input.peek(kw::Contain) { + kw::Contain::parse(input)?; + return Ok(Self(quote!(crate::inference::unifier::EqlTrait::Contain))); + } + + Err(syn::Error::new( + input.span(), + format!( + "Expected Eq, Ord, TokenMatch or JsonLike while parsing EqlTrait; got: {}", + input.cursor().token_stream() + ), + )) + } +} + +impl Parse for AssociatedTypeDecl { + fn parse(input: ParseStream) -> syn::Result { + let _: token::Lt = input.parse()?; + let impl_tvar = TVar::parse(input)?; + let _: token::As = input.parse()?; + let as_eql_trait = EqlTrait::parse(input)?; + let _: token::Gt = input.parse()?; + let _: token::PathSep = input.parse()?; + let type_name_ident = input.parse::()?; + let type_name = type_name_ident.to_string(); + + Ok(Self(quote! { + crate::inference::unifier::AssociatedTypeDecl { + impl_decl: Box::new(crate::inference::unifier::TypeDecl::Var( + crate::inference::unifier::VarDecl{ + tvar: #impl_tvar, + bounds: crate::inference::unifier::EqlTraits::none() + } + )), + as_eql_trait: #as_eql_trait, + type_name: #type_name, + } + })) + } +} + +impl Parse for NativeDecl { + fn parse(input: ParseStream) -> syn::Result { + let _: kw::Native = input.parse()?; + if input.peek(token::Paren) { + let content; + parenthesized!(content in input); + let table_column = TableColumn::parse(&content)?; + + Ok(Self( + quote!(crate::inference::unifier::NativeDecl(Some(#table_column))), + )) + } else { + Ok(Self(quote!(crate::inference::unifier::NativeDecl(None)))) + } + } +} + +impl Parse for ProjectionColumnDecl { + fn parse(input: ParseStream) -> syn::Result { + let spec = TypeDecl::parse(input)?; + if input.peek(token::As) { + let _: token::As = input.parse()?; + let alias = Ident::parse(input)?; + let alias = alias.to_string(); + Ok(Self( + quote!(crate::inference::unifier::ProjectionColumnDecl(Box::new(#spec), Some(#alias.into()))), + )) + } else { + Ok(Self( + quote!(crate::inference::unifier::ProjectionColumnDecl(Box::new(#spec), None)), + )) + } + } +} + +impl Parse for ProjectionDecl { + fn parse(input: ParseStream) -> syn::Result { + let content; + braced!(content in input); + + let mut specs: Vec = Vec::new(); + + loop { + specs.push(ProjectionColumnDecl::parse(&content)?); + + if !content.peek(token::Comma) { + break; + } + + token::Comma::parse(&content)?; + } + + Ok(Self(quote!(crate::inference::unifier::ProjectionDecl( + Vec::from_iter(vec![#(#specs,)*]) + )))) + } +} + +impl Parse for SetOfDecl { + fn parse(input: ParseStream) -> syn::Result { + let _: kw::SetOf = input.parse()?; + let _: Token![<] = input.parse()?; + let type_decl = TypeDecl::parse(input)?; + let _: Token![>] = input.parse()?; + + Ok(Self( + quote!(crate::inference::unifier::SetOfDecl(Box::new(#type_decl))), + )) + } +} + +impl Parse for ArrayDecl { + fn 
parse(input: ParseStream) -> syn::Result { + let content; + bracketed!(content in input); + + let type_spec = TypeDecl::parse(&content)?; + + Ok(Self( + quote!(crate::inference::unifier::ArrayDecl(Box::new(#type_spec))), + )) + } +} + +impl Parse for EqlTerm { + fn parse(input: ParseStream) -> syn::Result { + let _: kw::EQL = input.parse()?; + + let content; + parenthesized!(content in input); + + let table = Ident::parse(&content)?; + let table = table.to_string(); + let _: token::Dot = content.parse()?; + let column = Ident::parse(&content)?; + let column = column.to_string(); + + if content.peek(token::Colon) { + let _: token::Colon = content.parse()?; + let bounds = EqlTraits::parse(&content)?; + + Ok(Self(quote! { + crate::inference::unifier::EqlTerm::Full( + crate::inference::unifier::EqlValue( + crate::inference::unifier::TableColumn { + table: #table.into(), + column: #column.into(), + }, + #bounds, + ), + ) + })) + } else { + Ok(Self(quote! { + crate::inference::unifier::EqlTerm::Full( + crate::inference::unifier::EqlValue( + crate::inference::unifier::TableColumn { + table: #table, + column: #column + }, + ), + crate::inference::unifier::EqlTraits::none(), + ) + })) + } + } +} + +impl Parse for TypeDecl { + fn parse(input: ParseStream) -> syn::Result { + if AssociatedTypeDecl::parse(&input.fork()).is_ok() { + let inner = AssociatedTypeDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::AssociatedType(#inner) + })); + } + + if SetOfDecl::parse(&input.fork()).is_ok() { + let inner = SetOfDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::SetOf(#inner) + })); + } + + if NativeDecl::parse(&input.fork()).is_ok() { + let inner = NativeDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::Native(#inner) + })); + } + + if EqlTerm::parse(&input.fork()).is_ok() { + let inner = EqlTerm::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::Eql(#inner) + })); + } + + if VarDecl::parse(&input.fork()).is_ok() { + let inner = VarDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::Var(#inner) + })); + } + + if ArrayDecl::parse(&input.fork()).is_ok() { + let inner = ArrayDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::Array(#inner) + })); + } + + if ProjectionDecl::parse(&input.fork()).is_ok() { + let inner = ProjectionDecl::parse(input)?; + return Ok(Self(quote! { + crate::inference::unifier::TypeDecl::Projection(#inner) + })); + } + + Err(syn::Error::new( + input.span(), + "could not parse as TypeDecl".to_string(), + )) + } +} + +impl Parse for FunctionDecl { + fn parse(input: ParseStream) -> syn::Result { + let schema = Ident::parse(input)?; + let schema = schema.to_string(); + let _: token::Dot = input.parse()?; + let function_name = Ident::parse(input)?; + let function_name = function_name.to_string(); + + let generic_args = if input.peek(token::Lt) { + token::Lt::parse(input)?; + let args = Punctuated::::parse_separated_nonempty(input)? + .into_iter() + .collect(); + token::Gt::parse(input)?; + args + } else { + Vec::new() + }; + + let content; + parenthesized!(content in input); + + let args: Vec = + Punctuated::::parse_separated_nonempty(&content)? 
+ .into_iter() + .collect(); + + let _: token::RArrow = input.parse()?; + + let ret = TypeDecl::parse(input)?; + + let bounds: Vec<_> = if input.peek(token::Where) { + let _: token::Where = input.parse()?; + let boundeds = Punctuated::::parse_separated_nonempty(input)?; + boundeds.into_iter().collect() + } else { + vec![] + }; + + Ok(Self(quote! { + crate::inference::unifier::FunctionDecl { + name: sqltk::parser::ast::ObjectName(vec![ + sqltk::parser::ast::ObjectNamePart::Identifier(sqltk::parser::ast::Ident::new(#schema)), + sqltk::parser::ast::ObjectNamePart::Identifier(sqltk::parser::ast::Ident::new(#function_name)), + ]), + inner: crate::inference::unifier::FunctionSignatureDecl::new( + vec![#(#generic_args),*], + vec![#(#bounds),*], + vec![#(#args),*], + #ret, + ).expect("FunctionSignatureDecl creation failed due to a type error"), + } + })) + } +} + +impl Parse for TableColumn { + fn parse(input: ParseStream) -> syn::Result { + let table = Ident::parse(input)?; + let table = table.to_string(); + let _: token::Dot = input.parse()?; + let column = Ident::parse(input)?; + let column = column.to_string(); + + Ok(Self(quote! { + crate::TableColumn { + table: sqltk::parser::ast::Ident::new(#table), + column: sqltk::parser::ast::Ident::new(#column), + } + })) + } +} + +impl Parse for SqltkBinOp { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(token::RArrow) { + let _: token::RArrow = input.parse()?; + if input.peek(token::Gt) { + let _: token::Gt = input.parse()?; + return Ok(Self(quote!( + ::sqltk::parser::ast::BinaryOperator::LongArrow + ))); + } else { + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::Arrow))); + } + } + + if input.peek(token::At) { + let _: token::At = input.parse()?; + let _: token::Gt = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::AtArrow))); + } + + if input.peek(token::Le) { + let _: token::Le = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::LtEq))); + } + + if input.peek(token::Lt) { + let _: token::Lt = input.parse()?; + if input.peek(token::At) { + let _: token::At = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::ArrowAt))); + } else if input.peek(token::Gt) { + let _: token::Gt = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::NotEq))); + } + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::Lt))); + } + + if input.peek(token::Ge) { + let _: token::Ge = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::GtEq))); + } + + if input.peek(token::Eq) { + let _: token::Eq = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::Eq))); + } + + if input.peek(token::Gt) { + let _: token::Gt = input.parse()?; + return Ok(Self(quote!(::sqltk::parser::ast::BinaryOperator::Gt))); + } + + if input.peek(token::Tilde) { + let _: token::Tilde = input.parse()?; + if input.peek(token::Tilde) { + let _: token::Tilde = input.parse()?; + if input.peek(token::Star) { + let _: token::Star = input.parse()?; + return Ok(Self(quote!( + ::sqltk::parser::ast::BinaryOperator::PGILikeMatch + ))); + } else { + return Ok(Self(quote!( + ::sqltk::parser::ast::BinaryOperator::PGLikeMatch + ))); + } + } + } + + if input.peek(token::Not) { + let _: token::Not = input.parse()?; + if input.peek(token::Tilde) { + let _: token::Tilde = input.parse()?; + if input.peek(token::Tilde) { + let _: token::Tilde = input.parse()?; + if input.peek(token::Star) { + let _: token::Star = input.parse()?; + 
return Ok(Self(quote!( + ::sqltk::parser::ast::BinaryOperator::PGNotILikeMatch + ))); + } else { + return Ok(Self(quote!( + ::sqltk::parser::ast::BinaryOperator::PGNotLikeMatch + ))); + } + } + } + } + + Err(syn::Error::new( + input.span(), + "Expected an operator corresponding to one of the EQL traits Eq, Ord, TokenMatch or JsonLike".to_string(), + )) + } +} + +impl Parse for BinaryOpDecl { + fn parse(input: ParseStream) -> syn::Result { + let generic_args = if input.peek(token::Lt) { + token::Lt::parse(input)?; + let args = Punctuated::::parse_separated_nonempty(input)? + .into_iter() + .collect(); + token::Gt::parse(input)?; + args + } else { + Vec::new() + }; + + let content; + parenthesized!(content in input); + let lhs = TypeDecl::parse(&content)?; + let op = SqltkBinOp::parse(&content)?; + let rhs = TypeDecl::parse(&content)?; + + let _: token::RArrow = input.parse()?; + let ret = TypeDecl::parse(input)?; + + let bounds: Vec<_> = if input.peek(token::Where) { + let _: token::Where = input.parse()?; + let boundeds = Punctuated::::parse_separated_nonempty(input)?; + boundeds.into_iter().collect() + } else { + vec![] + }; + + Ok(Self(quote! { + crate::inference::unifier::BinaryOpDecl { + op: #op, + inner: crate::inference::unifier::FunctionSignatureDecl::new( + vec![#(#generic_args),*], + vec![#(#bounds),*], + vec![#lhs, #rhs], + #ret, + ).expect("FunctionSignatureDecl creation failed due to a type error"), + } + })) + } +} + +impl Parse for TypeEquation { + fn parse(input: ParseStream) -> syn::Result { + let tvar = TVar::parse(input)?; + let _ = token::Eq::parse(input)?; + let type_decl = TypeDecl::parse(input)?; + + Ok(Self(quote! { + env.add_decl(#tvar, #type_decl); + })) + } +} + +impl Parse for TypeEnvDecl { + fn parse(input: ParseStream) -> syn::Result { + let decls: Vec<_> = Punctuated::::parse_terminated(input)? + .into_iter() + .collect(); + + Ok(Self(quote! { + { + let mut env = crate::inference::unifier::TypeEnv::new(); + #( #decls )* + env + } + })) + } +} + +pub(crate) struct BinaryOpDecls { + ops: Vec, +} + +impl ToTokens for BinaryOpDecls { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let ops = &self.ops; + tokens.append_all(quote!(vec![#(#ops),*])); + } +} + +impl Parse for BinaryOpDecls { + fn parse(input: ParseStream) -> syn::Result { + let ops = Punctuated::::parse_terminated(input)? + .into_iter() + .collect(); + Ok(Self { ops }) + } +} + +pub(crate) struct FunctionDecls { + ops: Vec, +} + +impl ToTokens for FunctionDecls { + fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) { + let ops = &self.ops; + tokens.append_all(quote!(vec![#(#ops),*])); + } +} + +impl Parse for FunctionDecls { + fn parse(input: ParseStream) -> syn::Result { + let ops = Punctuated::::parse_terminated(input)? + .into_iter() + .collect(); + Ok(Self { ops }) + } +} + +pub(crate) struct ShallowInitTypes { + pub(crate) unifier: syn::Expr, + pub(crate) bindings: Vec, +} + +impl Parse for ShallowInitTypes { + fn parse(input: ParseStream) -> syn::Result { + let unifier = syn::Expr::parse(input)?; + let _: Token![,] = input.parse()?; + let content; + braced!(content in input); + let bindings: Vec = Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(); + Ok(Self { unifier, bindings }) + } +} + +impl ToTokens for ShallowInitTypes { + fn to_tokens(&self, tokens: &mut TokenStream) { + let unifier = &self.unifier; + for binding in self.bindings.iter() { + tokens.append_all(quote! 
{ + #binding.instantiate_shallow(#unifier).unwrap(); + }); + } + } +} + +pub(crate) struct ConcreteTyArgs { + pub(crate) ty_decl: TypeDecl, + pub(crate) ty_as: Option, +} + +impl Parse for ConcreteTyArgs { + fn parse(input: ParseStream) -> syn::Result { + let ty_decl = input.parse()?; + let ty_as = if input.peek(Token![as]) { + let _: Token![as] = input.parse()?; + let ty: TypePath = input.parse()?; + Some(ty) + } else { + None + }; + + Ok(Self { ty_decl, ty_as }) + } +} + +pub(crate) struct Binding { + pub(crate) var: Ident, + pub(crate) type_decl: TypeDecl, +} + +impl Parse for Binding { + fn parse(input: ParseStream) -> syn::Result { + let _ = syn::token::Let::parse(input)?; + let var = Ident::parse(input)?; + let _: Token![=] = input.parse()?; + let type_decl = TypeDecl::parse(input)?; + + Ok(Self { var, type_decl }) + } +} + +impl ToTokens for Binding { + fn to_tokens(&self, tokens: &mut TokenStream) { + let var = &self.var; + let type_decl = &self.type_decl; + + tokens.append_all(quote! { + let #var = #type_decl + }); + } +} + +#[cfg(test)] +mod test { + use pretty_assertions::assert_eq; + use quote::quote; + use syn::parse2; + + use crate::parse_type_decl::{AssociatedTypeDecl, BinaryOpDecl, TVar}; + + #[test] + fn parse_tvar() { + let parsed: TVar = parse2(quote!(T)).unwrap(); + + assert_eq!( + parsed.0.to_string(), + quote!(crate::inference::unifier::TVar("T".to_string())).to_string() + ); + } + + #[test] + fn parse_associated_type() { + let parsed: AssociatedTypeDecl = parse2(quote!(::Accessor)).unwrap(); + + assert_eq!( + parsed.0.to_string(), + quote!(crate::inference::unifier::AssociatedTypeDecl { + impl_decl: Box::new(crate::inference::unifier::TypeDecl::Var( + crate::inference::unifier::VarDecl { + tvar: crate::inference::unifier::TVar("T".to_string()), + bounds: crate::inference::unifier::EqlTraits::none() + } + )), + as_eql_trait: crate::inference::unifier::EqlTrait::JsonLike, + type_name: "Accessor", + }) + .to_string() + ); + } + + #[test] + fn parse_binary_operators() { + let parsed: BinaryOpDecl = parse2(quote!((T = T) -> Native where T: Eq)).unwrap(); + + assert_eq!( + parsed.0.to_string(), + quote!(crate::inference::unifier::BinaryOpDecl { + op: ::sqltk::parser::ast::BinaryOperator::Eq, + inner: crate::inference::unifier::FunctionSignatureDecl::new( + vec![crate::inference::unifier::TVar("T".to_string())], + vec![crate::inference::unifier::BoundsDecl( + crate::inference::unifier::TVar("T".to_string()), + crate::inference::unifier::EqlTraits::from_iter(vec![ + crate::inference::unifier::EqlTrait::Eq + ]) + )], + vec![ + crate::inference::unifier::TypeDecl::Var( + crate::inference::unifier::VarDecl { + tvar: crate::inference::unifier::TVar("T".to_string()), + bounds: crate::inference::unifier::EqlTraits::default(), + } + ), + crate::inference::unifier::TypeDecl::Var( + crate::inference::unifier::VarDecl { + tvar: crate::inference::unifier::TVar("T".to_string()), + bounds: crate::inference::unifier::EqlTraits::default(), + } + ) + ], + crate::inference::unifier::TypeDecl::Native( + crate::inference::unifier::NativeDecl(None) + ), + ) + .expect("FunctionSignatureDecl creation failed due to a type error"), + }) + .to_string() + ); + } +} diff --git a/packages/eql-mapper-macros/src/trace_infer.rs b/packages/eql-mapper-macros/src/trace_infer.rs new file mode 100644 index 00000000..21592b7c --- /dev/null +++ b/packages/eql-mapper-macros/src/trace_infer.rs @@ -0,0 +1,110 @@ +use proc_macro::TokenStream; +use quote::{quote, ToTokens}; +use syn::{ + parse::Parse, 
parse_macro_input, parse_quote, Attribute, FnArg, Ident, ImplItem, ImplItemFn, + ItemImpl, Pat, PatType, Signature, Type, TypePath, TypeReference, +}; + +pub(super) fn trace_infer_(_attr: TokenStream, item: TokenStream) -> TokenStream { + let mut input = parse_macro_input!(item as ItemImpl); + + for item in &mut input.items { + if let ImplItem::Fn(ImplItemFn { + attrs, + sig: + Signature { + ident: method, + inputs, + .. + }, + .. + }) = item + { + let node_ident_and_type: Option<(&Ident, &Type)> = + if let Some(FnArg::Typed(PatType { + ty: node_ty, pat, .. + })) = inputs.get(1) + { + if let Pat::Ident(pat_ident) = &**pat { + Some((&pat_ident.ident, node_ty)) + } else { + None + } + } else { + None + }; + + let vec_ident: Ident = parse_quote!(Vec); + + match node_ident_and_type { + Some((node_ident, node_ty)) => { + let (formatter, node_ty_abbrev) = match node_ty { + Type::Reference(TypeReference { elem, .. }) => match &**elem { + Type::Path(TypePath { path, .. }) => { + let last_segment = path.segments.last().unwrap(); + let last_segment_ident = &last_segment.ident; + let last_segment_arguments = if last_segment.arguments.is_empty() { + None + } else { + let args = &last_segment.arguments; + Some(quote!(<#args>)) + }; + match last_segment_ident { + ident if vec_ident == *ident => { + (quote!(crate::FmtAstVec), quote!(#last_segment_ident #last_segment_arguments)) + } + _ => (quote!(crate::FmtAst), quote!(#last_segment_ident #last_segment_arguments)) + } + }, + _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>") + }, + _ => unreachable!("Infer::infer_enter/infer_exit has sig: infer_..(&mut self, delete: &'ast N) -> Result<(), TypeError>") + }; + + let node_ty_abbrev = node_ty_abbrev + .to_token_stream() + .to_string() + .replace(" ", ""); + + let target = format!("eql-mapper::{}", method.to_string().to_uppercase()); + + let attr: TracingInstrumentAttr = syn::parse2(quote! 
{ + #[tracing::instrument( + target = #target, + level = "trace", + skip(self, #node_ident), + fields( + ast_ty = #node_ty_abbrev, + ast = %#formatter(#node_ident), + ), + err(Debug) + )] + }) + .unwrap(); + attrs.push(attr.attr); + } + None => { + return quote!(compile_error!( + "could not determine name of node argument in Infer impl" + )) + .to_token_stream() + .into(); + } + } + } + + input.to_token_stream().into() +} + +struct TracingInstrumentAttr { + attr: Attribute, +} + +impl Parse for TracingInstrumentAttr { + fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> { + Ok(Self { + attr: Attribute::parse_outer(input)?.first().unwrap().clone(), + }) + } +} diff --git a/packages/eql-mapper/Cargo.toml b/packages/eql-mapper/Cargo.toml index 3c63989e..339bc7ec 100644 --- a/packages/eql-mapper/Cargo.toml +++ b/packages/eql-mapper/Cargo.toml @@ -12,7 +12,7 @@ authors = [ [dependencies] eql-mapper-macros = { path = "../eql-mapper-macros" } -derive_more = { version = "^1.0", features = ["display", "constructor"] } +derive_more = { version = "^1.0", features = ["display", "constructor", "deref", "deref_mut"] } impl-trait-for-tuples = "0.2.3" itertools = "^0.13" sqltk = { workspace = true } @@ -20,6 +20,7 @@ thiserror = { workspace = true } tracing = { workspace = true } tracing-subscriber = { workspace = true } vec1 = "1.12.1" +topological-sort = "0.2.2" [dev-dependencies] pretty_assertions = "^1.0" diff --git a/packages/eql-mapper/src/display_helpers.rs b/packages/eql-mapper/src/display_helpers.rs index ca45080f..a5375cf7 100644 --- a/packages/eql-mapper/src/display_helpers.rs +++ b/packages/eql-mapper/src/display_helpers.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, fmt::{Debug, Display}, + sync::Arc, }; use sqltk::parser::ast::{ @@ -8,7 +9,7 @@ use sqltk::parser::ast::{ }; use sqltk::NodeKey; -use crate::{EqlValue, Param, Type}; +use crate::{unifier::EqlTerm, Param, Type}; #[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)] pub struct Fmt<T>(pub(crate) T); @@ -76,6 +77,19 @@ impl Display for Fmt<&HashMap<NodeKey<'_>, Type>> { } } +impl Display for Fmt<&[Arc<Type>]> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("[")?; + for (idx, ty) in self.0.iter().enumerate() { + f.write_fmt(format_args!("{ty}"))?; + if idx < self.0.len() - 1 { + f.write_str(", ")?; + } + } + f.write_str("]") + } +} + impl Display for FmtAstVec<&Vec> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("![")?; @@ -101,19 +115,19 @@ impl Display for Fmt<&Vec<(Param, crate::Value)>> { let formatted = self .0 .iter() - .map(|(p, v)| format!("{}: {}", p, v)) + .map(|(p, v)| format!("{p}: {v}")) .collect::<Vec<_>>() .join(", "); f.write_str(&formatted) } } -impl Display for Fmt<&Vec<(EqlValue, &sqltk::parser::ast::Value)>> { +impl Display for Fmt<&Vec<(EqlTerm, &sqltk::parser::ast::Value)>> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { let formatted = self .0 .iter() - .map(|(e, n)| format!("{}: {}", n, e)) + .map(|(e, n)| format!("{e}: {n}")) .collect::<Vec<_>>() .join(", "); f.write_str(&formatted) diff --git a/packages/eql-mapper/src/eql_mapper.rs b/packages/eql-mapper/src/eql_mapper.rs index 60a19262..495bcd5c 100644 --- a/packages/eql-mapper/src/eql_mapper.rs +++ b/packages/eql-mapper/src/eql_mapper.rs @@ -1,9 +1,9 @@ use super::importer::{ImportError, Importer}; use crate::{ inference::{TypeError, TypeInferencer}, - unifier::{EqlValue, Unifier}, - DepMut, Fmt, Param, ParamError, ScopeError, ScopeTracker, TableResolver, Type, - TypeCheckedStatement,
TypeRegistry, Value, + unifier::{EqlTerm, Unifier}, + DepMut, Param, ParamError, ScopeError, ScopeTracker, TableResolver, Type, TypeCheckedStatement, + TypeRegistry, Value, }; use sqltk::parser::ast::{self as ast, Statement}; use sqltk::{Break, NodeKey, Visitable, Visitor}; @@ -147,6 +147,12 @@ impl<'ast> EqlMapper<'ast> { let _guard = span_begin.enter(); + // let _ = self.unifier.borrow_mut().resolve_unresolved_type_vars(); + let _ = self + .unifier + .borrow_mut() + .resolve_unresolved_associated_types(); + let _ = self.unifier.borrow_mut().resolve_unresolved_value_nodes(); let projection = self.projection_type(statement); @@ -159,15 +165,15 @@ impl<'ast> EqlMapper<'ast> { match combine_results() { Ok((projection, params, literals, node_types)) => { - event!( - target: "eql-mapper::EVENT_RESOLVE_OK", - parent: &span_begin, - Level::TRACE, - projection = %&projection, - params = %Fmt(¶ms), - literals = %Fmt(&literals), - node_types = %Fmt(&node_types) - ); + // event!( + // target: "eql-mapper::EVENT_RESOLVE_OK", + // parent: &span_begin, + // Level::TRACE, + // projection = %&projection, + // params = %Fmt(¶ms), + // literals = %Fmt(&literals), + // node_types = %Fmt(&node_types) + // ); Ok(TypeCheckedStatement::new( statement, @@ -222,10 +228,10 @@ impl<'ast> EqlMapper<'ast> { .into_iter() .map(|(p, ty)| -> Result<(Param, Value), EqlMapperError> { match ty.resolved(&mut self.unifier.borrow_mut())? { - Type::Value(value) => Ok((p, value)), + Type::Value(value) if value.contains_eql() => Ok((p, value)), + Type::Value(value) if !value.contains_eql() => Ok((p, value)), other => Err(TypeError::Expected(format!( - "expected param '{}' to resolve to a scalar type but got '{}'", - p, other + "expected param '{p}' to resolve to a scalar type but got '{other}'" )))?, } }) @@ -235,21 +241,21 @@ impl<'ast> EqlMapper<'ast> { } /// Asks the [`TypeInferencer`] for a hashmap of literal types, validating that they are all `Value` types. - fn literal_types(&self) -> Result, EqlMapperError> { - let iter = { + fn literal_types(&self) -> Result, EqlMapperError> { + let literals = { let registry = self.registry.borrow(); registry .get_nodes_and_types::() .into_iter() .filter(|(node, _)| !matches!(node, ast::Value::Placeholder(_))) }; - let literal_nodes: Vec<(EqlValue, &'ast ast::Value)> = iter + + let literal_nodes: Vec<(EqlTerm, &'ast ast::Value)> = literals .map( - |(node, ty)| -> Result, TypeError> { - if let crate::Type::Value(crate::Value::Eql(eql_value)) = - &ty.resolved(&mut self.unifier.borrow_mut())? 
- { - return Ok(Some((eql_value.clone(), node))); + |(node, ty)| -> Result, TypeError> { + let resolved_ty = ty.resolved(&mut self.unifier.borrow_mut())?; + if let crate::Type::Value(crate::Value::Eql(eql_term)) = &resolved_ty { + return Ok(Some((eql_term.clone(), node))); } Ok(None) }, diff --git a/packages/eql-mapper/src/importer.rs b/packages/eql-mapper/src/importer.rs index bdbe084a..9d744ee6 100644 --- a/packages/eql-mapper/src/importer.rs +++ b/packages/eql-mapper/src/importer.rs @@ -1,10 +1,7 @@ use crate::{ - inference::{ - unifier::{Constructor, Type}, - TypeError, TypeRegistry, - }, + inference::{unifier::Type, TypeError, TypeRegistry}, model::{SchemaError, TableResolver}, - unifier::{Projection, ProjectionColumns}, + unifier::{Projection, Value}, Relation, ScopeError, ScopeTracker, }; use sqltk::parser::ast::{ @@ -46,14 +43,11 @@ impl<'ast> Importer<'ast> { { let table = self.table_resolver.resolve_table(table_name)?; - let cols = ProjectionColumns::new_from_schema_table(table.clone()); + let projection = Projection::new_from_schema_table(table.clone()); self.scope_tracker.borrow_mut().add_relation(Relation { name: table_alias.clone(), - projection_type: Type::Constructor(Constructor::Projection( - Projection::WithColumns(cols), - )) - .into(), + projection_type: Type::Value(Value::Projection(projection)).into(), })?; Ok(()) @@ -114,14 +108,11 @@ impl<'ast> Importer<'ast> { if scope_tracker.resolve_relation(name).is_err() { let table = self.table_resolver.resolve_table(name)?; - let cols = ProjectionColumns::new_from_schema_table(table.clone()); + let projection = Projection::new_from_schema_table(table.clone()); scope_tracker.add_relation(Relation { name: record_as.cloned().ok(), - projection_type: Type::Constructor(Constructor::Projection( - Projection::WithColumns(cols), - )) - .into(), + projection_type: Type::Value(Value::Projection(projection)).into(), })?; } } diff --git a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs index 063710af..fc18a555 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/expr.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/expr.rs @@ -1,36 +1,37 @@ use crate::{ + get_sql_binop_rule, inference::{unifier::Type, InferType, TypeError}, SqlIdent, TypeInferencer, }; use eql_mapper_macros::trace_infer; -use sqltk::parser::ast::{AccessExpr, Array, BinaryOperator, Expr, Ident, Subscript}; +use sqltk::parser::ast::{AccessExpr, Array, Expr, Ident, Subscript}; #[trace_infer] impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { - fn infer_exit(&mut self, this_expr: &'ast Expr) -> Result<(), TypeError> { - match this_expr { + fn infer_exit(&mut self, return_val: &'ast Expr) -> Result<(), TypeError> { + match return_val { // Resolve an identifier using the scope, except if it happens to to be the DEFAULT keyword // in which case we resolve it to a fresh type variable. Expr::Identifier(ident) => { // sqltk_parser treats the `DEFAULT` keyword in expression position as an identifier. 
if SqlIdent(ident) == SqlIdent(&Ident::new("default")) { - self.unify_node_with_type(this_expr, self.fresh_tvar())?; + self.unify_node_with_type(return_val, self.fresh_tvar())?; } else { - self.unify_node_with_type(this_expr, self.resolve_ident(ident)?)?; + self.unify_node_with_type(return_val, self.resolve_ident(ident)?)?; }; } Expr::CompoundIdentifier(idents) => { - self.unify_node_with_type(this_expr, self.resolve_compound_ident(idents)?)?; + self.unify_node_with_type(return_val, self.resolve_compound_ident(idents)?)?; } Expr::Wildcard(_) => { - self.unify_node_with_type(this_expr, self.resolve_wildcard()?)?; + self.unify_node_with_type(return_val, self.resolve_wildcard()?)?; } Expr::QualifiedWildcard(object_name, _) => { self.unify_node_with_type( - this_expr, + return_val, self.resolve_qualified_wildcard(object_name)?, )?; } @@ -50,13 +51,13 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | Expr::IsUnknown(expr) | Expr::IsNotUnknown(expr) => { self.unify_node_with_type( - this_expr, - self.unify(self.get_node_type(&**expr), Type::any_native())?, + return_val, + self.unify(self.get_node_type(&**expr), Type::native())?, )?; } Expr::IsDistinctFrom(a, b) | Expr::IsNotDistinctFrom(a, b) => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; self.unify_nodes(&**a, &**b)?; } @@ -65,7 +66,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { list, negated: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; self.unify_node_with_type( &**expr, list.iter().try_fold(self.get_node_type(&**expr), |a, b| { @@ -79,7 +80,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { subquery, negated: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; let ty = Type::projection(&[(self.get_node_type(&**expr), None)]); self.unify_node_with_type(&**subquery, ty)?; } @@ -94,91 +95,12 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { low, high, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; self.unify_node_with_type(&**high, self.unify_nodes(&**expr, &**low)?)?; } Expr::BinaryOp { left, op, right } => { - match op { - // Operators resolve to boolean (native) - // The left and right need to resolve to the same type - BinaryOperator::And - | BinaryOperator::Eq - | BinaryOperator::Gt - | BinaryOperator::GtEq - | BinaryOperator::Lt - | BinaryOperator::LtEq - | BinaryOperator::NotEq - | BinaryOperator::Or => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_nodes(&**left, &**right)?; - } - BinaryOperator::Plus - | BinaryOperator::Minus - | BinaryOperator::Multiply - | BinaryOperator::Divide - | BinaryOperator::Modulo - | BinaryOperator::StringConcat - | BinaryOperator::Spaceship - | BinaryOperator::Xor - | BinaryOperator::BitwiseOr - | BinaryOperator::BitwiseAnd - | BinaryOperator::BitwiseXor - | BinaryOperator::DuckIntegerDivide - | BinaryOperator::MyIntegerDivide - | BinaryOperator::Custom(_) - | BinaryOperator::PGBitwiseXor - | BinaryOperator::PGBitwiseShiftLeft - | BinaryOperator::PGBitwiseShiftRight - | BinaryOperator::PGExp - | BinaryOperator::PGOverlap - | BinaryOperator::PGRegexMatch - | BinaryOperator::PGRegexIMatch - | BinaryOperator::PGRegexNotMatch - | BinaryOperator::PGRegexNotIMatch - | BinaryOperator::PGLikeMatch - | 
BinaryOperator::PGILikeMatch - | BinaryOperator::PGNotLikeMatch - | BinaryOperator::PGNotILikeMatch - | BinaryOperator::PGStartsWith - | BinaryOperator::PGCustomBinaryOperator(_) => { - // EQL columns don't support these operators, so we only care that the output and inputs unify to a native type. - self.unify_node_with_type(&**left, Type::any_native())?; - self.unify_node_with_type(&**right, Type::any_native())?; - self.unify_node_with_type(this_expr, Type::any_native())?; - } - - // JSON(B) operators. - // Left side is JSON(B) and must unify to Scalar::Native, or Scalar::Encrypted(_). - BinaryOperator::Arrow - | BinaryOperator::LongArrow - | BinaryOperator::HashArrow - | BinaryOperator::HashLongArrow - | BinaryOperator::AtAt - | BinaryOperator::HashMinus // TODO do not support for EQL - | BinaryOperator::AtQuestion - | BinaryOperator::Question - | BinaryOperator::QuestionAnd - | BinaryOperator::QuestionPipe => { - self.unify_node_with_type(this_expr, self.unify_nodes(&**left, &**right)?)?; - } - - // JSON(B)/Array containment operators (@> and <@) - // Both sides must unify to the same type. - BinaryOperator::AtArrow | BinaryOperator::ArrowAt => { - self.unify_node_with_type(this_expr, self.unify_nodes(&**left, &**right)?)?; - } - - BinaryOperator::Overlaps| - BinaryOperator::DoubleHash| - BinaryOperator::LtDashGt - | BinaryOperator::AndLt | BinaryOperator::AndGt | BinaryOperator::LtLtPipe | - BinaryOperator::PipeGtGt | BinaryOperator::AndLtPipe| BinaryOperator::PipeAndGt | - - BinaryOperator::LtCaret | BinaryOperator::GtCaret | BinaryOperator::QuestionHash | - BinaryOperator::QuestionDash | BinaryOperator::QuestionDashPipe | BinaryOperator::QuestionDoublePipe | - BinaryOperator::At | BinaryOperator::TildeEq |BinaryOperator::Assignment=> { self.unify_node_with_type(this_expr, Type::any_native())?; } - } + get_sql_binop_rule(op).apply_constraints(self, left, right, return_val)?; } //customer_name LIKE 'A%'; @@ -196,7 +118,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { escape_char: _, any: false, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; self.unify_nodes(&**expr, &**pattern)?; } @@ -212,8 +134,8 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { pattern, escape_char: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_nodes_with_type(&**expr, &**pattern, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_nodes_with_type(&**expr, &**pattern, Type::native())?; } Expr::RLike { .. } => Err(TypeError::UnsupportedSqlFeature( @@ -231,7 +153,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { compare_op: _, right, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; self.unify_nodes(&**left, &**right)?; } @@ -240,16 +162,16 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { | Expr::UnaryOp { expr, .. } | Expr::Convert { expr, .. } | Expr::Cast { expr, .. 
} => { - self.unify_nodes_with_type(this_expr, &**expr, Type::any_native())?; + self.unify_nodes_with_type(return_val, &**expr, Type::native())?; } Expr::AtTimeZone { timestamp, time_zone, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**timestamp, Type::any_native())?; - self.unify_node_with_type(&**time_zone, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**timestamp, Type::native())?; + self.unify_node_with_type(&**time_zone, Type::native())?; } Expr::Extract { @@ -257,13 +179,13 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { syntax: _, expr, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**expr, Type::native())?; } Expr::Position { expr, r#in } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_nodes_with_type(&**expr, &**r#in, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_nodes_with_type(&**expr, &**r#in, Type::native())?; } Expr::Substring { @@ -273,13 +195,13 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { special: _, shorthand: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**expr, Type::native())?; if let Some(expr) = substring_from { - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(&**expr, Type::native())?; } if let Some(expr) = substring_for { - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(&**expr, Type::native())?; } } @@ -291,16 +213,16 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { trim_what, trim_characters, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**expr, Type::native())?; if let Some(trim_where) = trim_where { - self.unify_node_with_type(trim_where, Type::any_native())?; + self.unify_node_with_type(trim_where, Type::native())?; } if let Some(trim_what) = trim_what { - self.unify_node_with_type(&**trim_what, Type::any_native())?; + self.unify_node_with_type(&**trim_what, Type::native())?; } if let Some(trim_characters) = trim_characters { - self.unify_all_with_type(trim_characters, Type::any_native())?; + self.unify_all_with_type(trim_characters, Type::native())?; } } @@ -310,39 +232,39 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { overlay_from, overlay_for, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**expr, Type::any_native())?; - self.unify_node_with_type(&**overlay_what, Type::any_native())?; - self.unify_node_with_type(&**overlay_from, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**expr, Type::native())?; + self.unify_node_with_type(&**overlay_what, Type::native())?; + self.unify_node_with_type(&**overlay_from, Type::native())?; if let Some(overlay_for) = overlay_for { - self.unify_node_with_type(&**overlay_for, Type::any_native())?; + self.unify_node_with_type(&**overlay_for, Type::native())?; } } Expr::Collate { expr, collation: 
_ } => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&**expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&**expr, Type::native())?; } // The current `Expr` shares the same type hole as the sub-expression Expr::Nested(expr) => { - self.unify_nodes(this_expr, &**expr)?; + self.unify_nodes(return_val, &**expr)?; } Expr::Value(value) => { - self.unify_nodes(this_expr, value)?; + self.unify_nodes(return_val, value)?; } Expr::TypedString { data_type: _, value: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; } // The return type of this function and the return type of this expression must be the same type. Expr::Function(function) => { - self.unify_node_with_type(this_expr, self.get_node_type(function))?; + self.unify_node_with_type(return_val, self.get_node_type(function))?; } // When operand is Some(operand), all conditions must be of the same type as the operand and much support equality @@ -360,7 +282,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { Some(operand) => { for cond_when in conditions { self.unify_nodes_with_type( - this_expr, + return_val, &**operand, self.unify_node_with_type(&cond_when.condition, self.fresh_tvar())?, )?; @@ -368,7 +290,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { } None => { for cond_when in conditions { - self.unify_node_with_type(&cond_when.condition, Type::any_native())?; + self.unify_node_with_type(&cond_when.condition, Type::native())?; } } } @@ -381,18 +303,18 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { self.unify_node_with_type(else_result, result_ty.clone())?; }; - self.unify_node_with_type(this_expr, result_ty)?; + self.unify_node_with_type(return_val, result_ty)?; } Expr::Exists { subquery: _, negated: _, } => { - self.unify_node_with_type(this_expr, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; } Expr::Subquery(subquery) => { - self.unify_nodes(this_expr, &**subquery)?; + self.unify_nodes(return_val, &**subquery)?; } // unsupported SQL features @@ -435,16 +357,16 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { AccessExpr::Subscript(Subscript::Index { index }) => { access_ty = self.fresh_tvar(); root_ty = Type::array(access_ty.clone()); - self.unify_node_with_type(index, Type::any_native())?; + self.unify_node_with_type(index, Type::native())?; } AccessExpr::Subscript(Subscript::Slice { lower_bound, upper_bound, stride, }) => { - self.unify_node_with_type(lower_bound, Type::any_native())?; - self.unify_node_with_type(upper_bound, Type::any_native())?; - self.unify_node_with_type(stride, Type::any_native())?; + self.unify_node_with_type(lower_bound, Type::native())?; + self.unify_node_with_type(upper_bound, Type::native())?; + self.unify_node_with_type(stride, Type::native())?; access_ty = self.fresh_tvar(); root_ty = Type::array(access_ty.clone()); } @@ -456,7 +378,7 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { } } - self.unify_node_with_type(this_expr, access_ty)?; + self.unify_node_with_type(return_val, access_ty)?; self.unify_node_with_type(&**root, root_ty)?; } @@ -464,13 +386,13 @@ impl<'ast> InferType<'ast, Expr> for TypeInferencer<'ast> { // Constrain all elements of the array to be the same type. 
let elem_ty = self.unify_all_with_type(elem, self.fresh_tvar())?; let array_ty = Type::array(elem_ty); - self.unify_node_with_type(this_expr, array_ty)?; + self.unify_node_with_type(return_val, array_ty)?; } // interval is unmapped, value is unmapped Expr::Interval(interval) => { - self.unify_node_with_type(this_expr, Type::any_native())?; - self.unify_node_with_type(&*interval.value, Type::any_native())?; + self.unify_node_with_type(return_val, Type::native())?; + self.unify_node_with_type(&*interval.value, Type::native())?; } // mysql specific diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function.rs b/packages/eql-mapper/src/inference/infer_type_impls/function.rs index f3b25969..509c3dcd 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/function.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/function.rs @@ -1,9 +1,7 @@ use eql_mapper_macros::trace_infer; use sqltk::parser::ast::{Function, FunctionArguments}; -use crate::{ - get_sql_function_def, inference::infer_type::InferType, FunctionSig, TypeError, TypeInferencer, -}; +use crate::{get_sql_function, inference::infer_type::InferType, TypeError, TypeInferencer}; /// Looks up the function signature. /// @@ -19,20 +17,6 @@ impl<'ast> InferType<'ast, Function> for TypeInferencer<'ast> { )); } - let Function { name, args, .. } = function; - - match get_sql_function_def(name, args) { - Some(sql_fn) => { - sql_fn - .sig - .instantiate(&*self) - .apply_constraints(self, function)?; - } - None => { - FunctionSig::instantiate_native(function).apply_constraints(self, function)?; - } - } - - Ok(()) + get_sql_function(&function.name).apply_constraints(self, function) } } diff --git a/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs index dfde4b63..49e858e4 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/function_arg_expr.rs @@ -7,15 +7,16 @@ use crate::{inference::infer_type::InferType, unifier::Type, TypeError, TypeInfe impl<'ast> InferType<'ast, FunctionArgExpr> for TypeInferencer<'ast> { fn infer_exit(&mut self, farg_expr: &'ast FunctionArgExpr) -> Result<(), TypeError> { let farg_expr_ty = self.get_node_type(farg_expr); + match farg_expr { FunctionArgExpr::Expr(expr) => { self.unify(farg_expr_ty, self.get_node_type(expr))?; } - // COUNT(*) is the only function in SQL (that I can find) that accepts a wildcard as an argument. And it is - // *not* an expression - it is special case syntax that means "count all rows". If we see this syntax, we - // resolve the FunctionArgExpr type as Native. + // `COUNT(*)` is a special case in SQL. The `*` is NOT an expression - which would normally expand into a + // projection. `COUNT(*)` merely means "count all rows". As such, we should not attempt to resolve it as + // anything other than Native. 
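+            // Illustrative example (not part of this change): in `SELECT count(*) FROM some_table`,
+            // the `*` argument unifies directly with `Type::native()` even if the table contains
+            // EQL columns, because the wildcard never expands into a projection here.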
FunctionArgExpr::QualifiedWildcard(_) | FunctionArgExpr::Wildcard => { - self.unify(farg_expr_ty, Type::any_native())?; + self.unify(farg_expr_ty, Type::native())?; } }; diff --git a/packages/eql-mapper/src/inference/infer_type_impls/insert_statement.rs b/packages/eql-mapper/src/inference/infer_type_impls/insert_statement.rs index c7b7159e..81562849 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/insert_statement.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/insert_statement.rs @@ -1,12 +1,8 @@ use std::sync::Arc; use crate::{ - inference::{ - type_error::TypeError, - unifier::{Constructor, Type}, - InferType, - }, - unifier::{EqlValue, NativeValue, Value}, + inference::{type_error::TypeError, unifier::Type, InferType}, + unifier::{EqlTerm, EqlValue, NativeValue, Value}, ColumnKind, TableColumn, TypeInferencer, }; use eql_mapper_macros::trace_infer; @@ -46,16 +42,14 @@ impl<'ast> InferType<'ast, Insert> for TypeInferencer<'ast> { column: stc.column.clone(), }; - let value_ty = if stc.kind == ColumnKind::Native { - Value::Native(NativeValue(Some(tc.clone()))) - } else { - Value::Eql(EqlValue(tc.clone())) + let value_ty = match &stc.kind { + ColumnKind::Native => Value::Native(NativeValue(Some(tc.clone()))), + ColumnKind::Eql(features) => { + Value::Eql(EqlTerm::Full(EqlValue(tc.clone(), *features))) + } }; - ( - Arc::new(Type::Constructor(Constructor::Value(value_ty))), - Some(tc.column.clone()), - ) + (Arc::new(Type::Value(value_ty)), Some(tc.column.clone())) }) .collect::>(), ); diff --git a/packages/eql-mapper/src/inference/infer_type_impls/select_items.rs b/packages/eql-mapper/src/inference/infer_type_impls/select_items.rs index d6bc22ae..7d3108de 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/select_items.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/select_items.rs @@ -5,7 +5,7 @@ use sqltk::parser::ast::{ use crate::{ inference::{type_error::TypeError, unifier::Type, InferType}, - unifier::{Constructor, Projection, ProjectionColumn}, + unifier::{Projection, ProjectionColumn, Value}, TypeInferencer, }; @@ -83,7 +83,7 @@ impl<'ast> InferType<'ast, Vec> for TypeInferencer<'ast> { self.unify_node_with_type( select_items, - Type::Constructor(Constructor::Projection(Projection::new(projection_columns))), + Type::Value(Value::Projection(Projection::new(projection_columns))), )?; Ok(()) diff --git a/packages/eql-mapper/src/inference/infer_type_impls/values.rs b/packages/eql-mapper/src/inference/infer_type_impls/values.rs index 7dcafea1..a90f19ef 100644 --- a/packages/eql-mapper/src/inference/infer_type_impls/values.rs +++ b/packages/eql-mapper/src/inference/infer_type_impls/values.rs @@ -1,5 +1,5 @@ use crate::{ - inference::type_error::TypeError, inference::unifier::Type, inference::InferType, + inference::{type_error::TypeError, unifier::Type, InferType}, TypeInferencer, }; use eql_mapper_macros::trace_infer; @@ -28,8 +28,8 @@ impl<'ast> InferType<'ast, Values> for TypeInferencer<'ast> { .collect::>(); for row in values.rows.iter() { - for (idx, val) in row.iter().enumerate() { - self.unify(self.get_node_type(val), column_types[idx].clone())?; + for (idx, expr) in row.iter().enumerate() { + self.unify(self.get_node_type(expr), column_types[idx].clone())?; } } diff --git a/packages/eql-mapper/src/inference/mod.rs b/packages/eql-mapper/src/inference/mod.rs index f84f5b8d..05fbb1eb 100644 --- a/packages/eql-mapper/src/inference/mod.rs +++ b/packages/eql-mapper/src/inference/mod.rs @@ -2,8 +2,7 @@ mod infer_type; mod 
infer_type_impls; mod registry; mod sequence; -mod sql_fn_macros; -mod sql_functions; +mod sql_types; mod type_error; pub mod unifier; @@ -23,7 +22,7 @@ use crate::{ScopeError, ScopeTracker, TableResolver}; pub(crate) use registry::*; pub(crate) use sequence::*; -pub(crate) use sql_functions::*; +pub(crate) use sql_types::*; pub(crate) use type_error::*; /// [`Visitor`] implementation that performs type inference on AST nodes. @@ -127,9 +126,9 @@ impl<'ast> TypeInferencer<'ast> { match self.unify(self.get_node_type(lhs), self.get_node_type(rhs)) { Ok(unified) => Ok(unified), Err(err) => Err(TypeError::OnNodes( - format!("{:?}", lhs), + format!("{lhs:?}"), self.get_node_type(lhs), - format!("{:?}", rhs), + format!("{rhs:?}"), self.get_node_type(rhs), err.to_string(), )), diff --git a/packages/eql-mapper/src/inference/registry.rs b/packages/eql-mapper/src/inference/registry.rs index 45b39742..7f48459e 100644 --- a/packages/eql-mapper/src/inference/registry.rs +++ b/packages/eql-mapper/src/inference/registry.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, marker::PhantomData, sync::Arc}; +use std::{collections::HashMap, fmt::Debug, marker::PhantomData, sync::Arc}; use sqltk::{AsNodeKey, NodeKey}; use tracing::{span, Level}; @@ -8,7 +8,10 @@ use crate::{ Param, ParamError, }; -use super::Sequence; +use super::{ + unifier::{EqlTraits, Var}, + Sequence, +}; /// `TypeRegistry` maintains an association between `sqltk_parser` AST nodes and the node's inferred [`Type`]. #[derive(Debug)] @@ -38,8 +41,9 @@ impl<'ast> TypeRegistry<'ast> { } } - pub(crate) fn get_nodes_and_types(&self) -> Vec<(&'ast N, Arc)> { - self.node_types + pub(crate) fn get_nodes_and_types(&self) -> Vec<(&'ast N, Arc)> { + let result = self + .node_types .iter() .filter_map(|(key, tvar)| { key.get_as::().map(|n| { @@ -48,11 +52,13 @@ impl<'ast> TypeRegistry<'ast> { self.substitutions .get(tvar) .cloned() - .unwrap_or(Arc::new(Type::Var(*tvar))), + .unwrap_or(Arc::new(Type::Var(Var(*tvar, EqlTraits::default())))), ) }) }) - .collect() + .collect(); + + result } pub(crate) fn get_type(&self, tvar: TypeVar) -> Option> { @@ -77,7 +83,7 @@ impl<'ast> TypeRegistry<'ast> { self.substitutions .get(tvar) .cloned() - .unwrap_or(Arc::new(Type::Var(*tvar))), + .unwrap_or(Arc::new(Type::Var(Var(*tvar, EqlTraits::none())))), ) }) .collect() @@ -105,7 +111,7 @@ impl<'ast> TypeRegistry<'ast> { self.substitutions .get(tvar) .cloned() - .unwrap_or(Arc::new(Type::Var(*tvar))), + .unwrap_or(Arc::new(Type::Var(Var(*tvar, EqlTraits::none())))), ) }) .collect() @@ -123,7 +129,7 @@ impl<'ast> TypeRegistry<'ast> { self.substitutions .get(&tvar) .cloned() - .unwrap_or(Arc::new(Type::Var(tvar))) + .unwrap_or(Arc::new(Type::Var(Var(tvar, EqlTraits::none())))) }) } @@ -141,7 +147,7 @@ impl<'ast> TypeRegistry<'ast> { None => { let tvar = self.fresh_tvar(); self.node_types.insert(node.as_node_key(), tvar); - Type::Var(tvar).into() + Type::Var(Var(tvar, EqlTraits::none())).into() } } } @@ -150,11 +156,11 @@ impl<'ast> TypeRegistry<'ast> { /// associated `Type` then a fresh [`Type::Var`] will be assigned. 
fn get_or_init_param_type(&mut self, param: &'ast String) -> Arc { match self.param_types.get(¶m).cloned() { - Some(tvar) => Type::Var(tvar).into(), + Some(tvar) => Type::Var(Var(tvar, EqlTraits::none())).into(), None => { let tvar = self.fresh_tvar(); self.param_types.insert(param, tvar); - Type::Var(tvar).into() + Type::Var(Var(tvar, EqlTraits::none())).into() } } } diff --git a/packages/eql-mapper/src/inference/sql_fn_macros.rs b/packages/eql-mapper/src/inference/sql_fn_macros.rs deleted file mode 100644 index d0acec86..00000000 --- a/packages/eql-mapper/src/inference/sql_fn_macros.rs +++ /dev/null @@ -1,55 +0,0 @@ -#[macro_export] -macro_rules! to_kind { - (NATIVE) => { - $crate::Kind::Native - }; - ($generic:ident) => { - $crate::Kind::Generic(stringify!($generic)) - }; -} - -#[macro_export] -macro_rules! sql_fn_args { - (()) => { vec![] }; - - (($arg:ident)) => { vec![$crate::to_kind!($arg)] }; - - (($arg:ident $(,$rest:ident)*)) => { - vec![$crate::to_kind!($arg) $(, $crate::to_kind!($rest))*] - }; -} - -#[macro_export] -macro_rules! sql_fn { - ($name:ident $args:tt -> $return_kind:ident, rewrite) => { - $crate::SqlFunction::new( - stringify!($name), - FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), - $crate::RewriteRule::AsEqlFunction, - ) - }; - - ($name:ident $args:tt -> $return_kind:ident) => { - $crate::SqlFunction::new( - stringify!($name), - FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), - $crate::RewriteRule::Ignore, - ) - }; - - ($schema:ident . $name:ident $args:tt -> $return_kind:ident, rewrite) => { - $crate::SqlFunction::new( - stringify!($schema.$name), - FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), - $crate::RewriteRule::AsEqlFunction, - ) - }; - - ($schema:ident . $name:ident $args:tt -> $return_kind:ident) => { - $crate::SqlFunction::new( - stringify!($schema.$name), - FunctionSig::new($crate::sql_fn_args!($args), $crate::to_kind!($return_kind)), - $crate::RewriteRule::Ignore, - ) - }; -} diff --git a/packages/eql-mapper/src/inference/sql_functions.rs b/packages/eql-mapper/src/inference/sql_functions.rs deleted file mode 100644 index ba8ad411..00000000 --- a/packages/eql-mapper/src/inference/sql_functions.rs +++ /dev/null @@ -1,246 +0,0 @@ -use std::{ - collections::{HashMap, HashSet}, - sync::{Arc, LazyLock}, -}; - -use sqltk::parser::ast::{ - Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, ObjectName, ObjectNamePart, -}; - -use itertools::Itertools; - -use crate::{sql_fn, unifier::Type, TypeInferencer}; - -use super::TypeError; - -/// The identifier and type signature of a SQL function. -/// -/// See [`SQL_FUNCTION_SIGNATURES`]. -#[derive(Debug)] -pub(crate) struct SqlFunction { - pub(crate) name: ObjectName, - pub(crate) sig: FunctionSig, - pub(crate) rewrite_rule: RewriteRule, -} - -#[derive(Debug)] -pub(crate) enum RewriteRule { - Ignore, - AsEqlFunction, -} - -/// A representation of the type of an argument or return type in a SQL function. -#[derive(Debug, Clone, Copy, Eq, PartialEq)] -pub(crate) enum Kind { - /// A type that must be a native type - Native, - - /// A type that can be a native or EQL type. The `str` is the generic variable name. - Generic(&'static str), -} - -/// The type signature of a SQL functon (excluding its name). 
-#[derive(Debug, Clone)] -pub(crate) struct FunctionSig { - args: Vec, - return_type: Kind, - generics: HashSet<&'static str>, -} - -/// A function signature but filled in with fresh type variables that correspond with the [`Kind`] or each argument and -/// return type. -#[derive(Debug, Clone)] -pub(crate) struct InstantiatedSig { - args: Vec>, - return_type: Arc, -} - -impl FunctionSig { - fn new(args: Vec, return_type: Kind) -> Self { - let mut generics: HashSet<&'static str> = HashSet::new(); - - for arg in &args { - if let Kind::Generic(generic) = arg { - generics.insert(*generic); - } - } - - if let Kind::Generic(generic) = return_type { - generics.insert(generic); - } - - Self { - args, - return_type, - generics, - } - } - - /// Checks if `self` is applicable to a particular piece of SQL function invocation syntax. - pub(crate) fn is_applicable_to_args(&self, fn_args_syntax: &FunctionArguments) -> bool { - match fn_args_syntax { - FunctionArguments::None => self.args.is_empty(), - FunctionArguments::Subquery(_) => self.args.len() == 1, - FunctionArguments::List(fn_args) => self.args.len() == fn_args.args.len(), - } - } - - /// Creates an [`InstantiatedSig`] from `self`, filling in the [`Kind`]s with fresh type variables. - pub(crate) fn instantiate(&self, inferencer: &TypeInferencer<'_>) -> InstantiatedSig { - let mut generics: HashMap<&'static str, Arc> = HashMap::new(); - - for generic in self.generics.iter() { - generics.insert(generic, inferencer.fresh_tvar()); - } - - InstantiatedSig { - args: self - .args - .iter() - .map(|kind| match kind { - Kind::Native => Arc::new(Type::any_native()), - Kind::Generic(generic) => generics[generic].clone(), - }) - .collect(), - - return_type: match self.return_type { - Kind::Native => Arc::new(Type::any_native()), - Kind::Generic(generic) => generics[generic].clone(), - }, - } - } - - /// For functions that do not have special case handling we synthesise an [`InstatiatedSig`] from the SQL function - /// invocation synta where all arguments and the return types are native. - pub(crate) fn instantiate_native(function: &Function) -> InstantiatedSig { - let arg_count = match &function.args { - FunctionArguments::None => 0, - FunctionArguments::Subquery(_) => 1, - FunctionArguments::List(args) => args.args.len(), - }; - - let args: Vec> = (0..arg_count) - .map(|_| Arc::new(Type::any_native())) - .collect(); - - InstantiatedSig { - args, - return_type: Arc::new(Type::any_native()), - } - } -} - -impl InstantiatedSig { - /// Applies the type constraints of the function to to the AST. 
- pub(crate) fn apply_constraints<'ast>( - &self, - inferencer: &mut TypeInferencer<'ast>, - function: &'ast Function, - ) -> Result<(), TypeError> { - inferencer.unify_node_with_type(function, self.return_type.clone())?; - - match &function.args { - FunctionArguments::None => { - if self.args.is_empty() { - Ok(()) - } else { - Err(TypeError::Conflict(format!( - "expected {} args to function {}; got 0", - self.args.len(), - &function.name, - ))) - } - } - - FunctionArguments::Subquery(query) => { - if self.args.len() == 1 { - inferencer.unify_node_with_type(&**query, self.args[0].clone())?; - Ok(()) - } else { - Err(TypeError::Conflict(format!( - "expected {} args to function {}; got 0", - self.args.len(), - &function.name, - ))) - } - } - - FunctionArguments::List(args) => { - for (sig_arg, fn_arg) in self.args.iter().zip(args.args.iter()) { - let farg_expr = get_function_arg_expr(fn_arg); - inferencer.unify_node_with_type(farg_expr, sig_arg.clone())?; - } - - Ok(()) - } - } - } -} - -fn get_function_arg_expr(fn_arg: &FunctionArg) -> &FunctionArgExpr { - match fn_arg { - FunctionArg::Named { arg, .. } => arg, - FunctionArg::ExprNamed { arg, .. } => arg, - FunctionArg::Unnamed(arg) => arg, - } -} - -impl SqlFunction { - fn new(ident: &str, sig: FunctionSig, rewrite_rule: RewriteRule) -> Self { - Self { - name: ObjectName(vec![ObjectNamePart::Identifier(Ident::new(ident))]), - sig, - rewrite_rule, - } - } -} - -/// SQL functions that are handled with special case type checking rules. -static SQL_FUNCTIONS: LazyLock>> = LazyLock::new(|| { - // Notation: a single uppercase letter denotes an unknown type. Matching letters in a signature will be assigned - // *the same type variable* and thus must resolve to the same type. (🙏 Haskell) - // - // Eventually we should type check EQL types against their configured indexes instead of leaving that to the EQL - // extension in the database. I can imagine supporting type bounds in signatures here, such as: `T: Eq` - let sql_fns = vec![ - // TODO: when search_path support is added to the resolver we should change these - // to their fully-qualified names. - sql_fn!(count(T) -> NATIVE), - sql_fn!(min(T) -> T, rewrite), - sql_fn!(max(T) -> T, rewrite), - sql_fn!(jsonb_path_query(T, T) -> T, rewrite), - sql_fn!(jsonb_path_query_first(T, T) -> T, rewrite), - sql_fn!(jsonb_path_exists(T, T) -> T, rewrite), - sql_fn!(jsonb_array_length(T) -> T, rewrite), - sql_fn!(jsonb_array_elements(T) -> T, rewrite), - sql_fn!(jsonb_array_elements_text(T) -> T, rewrite), - // These are typings for when customer SQL already contains references to EQL functions. - // They must be type checked but not rewritten. 
- sql_fn!(eql_v2.min(T) -> T), - sql_fn!(eql_v2.max(T) -> T), - sql_fn!(eql_v2.jsonb_path_query(T, T) -> T), - sql_fn!(eql_v2.jsonb_path_query_first(T, T) -> T), - sql_fn!(eql_v2.jsonb_path_exists(T, T) -> T), - sql_fn!(eql_v2.jsonb_array_length(T) -> T), - sql_fn!(eql_v2.jsonb_array_elements(T) -> T), - sql_fn!(eql_v2.jsonb_array_elements_text(T) -> T), - ]; - - let mut sql_fns_by_name: HashMap> = HashMap::new(); - - for (key, chunk) in &sql_fns.into_iter().chunk_by(|sql_fn| sql_fn.name.clone()) { - sql_fns_by_name.insert(key.clone(), chunk.into_iter().collect()); - } - - sql_fns_by_name -}); - -pub(crate) fn get_sql_function_def( - fn_name: &ObjectName, - args: &FunctionArguments, -) -> Option<&'static SqlFunction> { - let sql_fns = SQL_FUNCTIONS.get(fn_name)?; - sql_fns - .iter() - .find(|sql_fn| sql_fn.sig.is_applicable_to_args(args)) -} diff --git a/packages/eql-mapper/src/inference/sql_types/mod.rs b/packages/eql-mapper/src/inference/sql_types/mod.rs new file mode 100644 index 00000000..665c8679 --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_types/mod.rs @@ -0,0 +1,7 @@ +mod sql_binary_operator_types; +mod sql_decls; +mod sql_function_types; + +pub(crate) use sql_binary_operator_types::*; +pub(crate) use sql_decls::*; +pub(crate) use sql_function_types::*; diff --git a/packages/eql-mapper/src/inference/sql_types/sql_binary_operator_types.rs b/packages/eql-mapper/src/inference/sql_types/sql_binary_operator_types.rs new file mode 100644 index 00000000..e5a6f94b --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_types/sql_binary_operator_types.rs @@ -0,0 +1,49 @@ +use sqltk::parser::ast::Expr; + +use crate::{ + unifier::{BinaryOpDecl, Type}, + TypeError, TypeInferencer, +}; + +/// A rule for determining how to apply typing rules to a SQL binary operator expression. +#[derive(Debug)] +pub(crate) enum SqlBinaryOp { + /// An explicit predefined rule for handling EQL types in the expression. + Explicit(&'static BinaryOpDecl), + + /// The fallback rule for when there is no explicit rule for a given operator. This rule will force the left and + /// right expressions of the operator and its return value to resolve to [`Type::native()`]. + Fallback, +} + +impl SqlBinaryOp { + pub(crate) fn apply_constraints<'ast>( + &self, + inferencer: &mut TypeInferencer<'ast>, + lhs: &'ast Expr, + rhs: &'ast Expr, + return_val: &'ast Expr, + ) -> Result<(), TypeError> { + match self { + SqlBinaryOp::Explicit(rule) => { + let lhs_ty = inferencer.get_node_type(lhs); + let rhs_ty = inferencer.get_node_type(rhs); + let ret_ty = inferencer.get_node_type(return_val); + + rule.inner.apply( + &mut inferencer.unifier.borrow_mut(), + &[lhs_ty, rhs_ty], + ret_ty, + )?; + } + + SqlBinaryOp::Fallback => { + inferencer.unify_node_with_type(lhs, Type::native())?; + inferencer.unify_node_with_type(rhs, Type::native())?; + inferencer.unify_node_with_type(return_val, Type::native())?; + } + } + + Ok(()) + } +} diff --git a/packages/eql-mapper/src/inference/sql_types/sql_decls.rs b/packages/eql-mapper/src/inference/sql_types/sql_decls.rs new file mode 100644 index 00000000..ed50d257 --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_types/sql_decls.rs @@ -0,0 +1,111 @@ +use std::{collections::HashMap, sync::LazyLock}; + +use eql_mapper_macros::{binary_operators, functions}; +use sqltk::parser::ast::{BinaryOperator, Ident, ObjectName, ObjectNamePart}; + +use crate::unifier::{BinaryOpDecl, FunctionDecl}; + +use super::{SqlBinaryOp, SqlFunction}; + +/// SQL operators that can accept EQL types. 
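+///
+/// How to read the declarations below (an explanatory gloss, following the notation of the
+/// deleted `sql_functions.rs`): an uppercase letter such as `T` denotes a single unknown type,
+/// every occurrence of the same letter must unify to the same type, and `where T: Eq` requires
+/// the resolved type to implement that EQL trait. So `(T = T) -> Native where T: Eq` means both
+/// operands of `=` share one type supporting equality and the comparison returns a native boolean.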
+static SQL_BINARY_OPERATORS: LazyLock<HashMap<BinaryOperator, BinaryOpDecl>> =
+    LazyLock::new(|| {
+        let ops = binary_operators! {
+            (T = T) -> Native where T: Eq;
+            (T <> T) -> Native where T: Eq;
+            (T <= T) -> Native where T: Ord;
+            (T >= T) -> Native where T: Ord;
+            (T < T) -> Native where T: Ord;
+            (T > T) -> Native where T: Ord;
+            (T -> ::Accessor) -> T where T: JsonLike;
+            (T ->> ::Accessor) -> T where T: JsonLike;
+            (T @> T) -> Native where T: Contain;
+            (T <@ T) -> Native where T: Contain;
+            (T ~~ ::Tokenized) -> Native where T: TokenMatch; // LIKE
+            (T !~~ ::Tokenized) -> Native where T: TokenMatch; // NOT LIKE
+            (T ~~* ::Tokenized) -> Native where T: TokenMatch; // ILIKE
+            (T !~~* ::Tokenized) -> Native where T: TokenMatch; // NOT ILIKE
+        };
+        ops.into_iter()
+            .map(|binary_op_spec| (binary_op_spec.op.clone(), binary_op_spec))
+            .collect::<HashMap<_, _>>()
+    });
+
+pub(crate) fn get_sql_binop_rule(op: &BinaryOperator) -> SqlBinaryOp {
+    SQL_BINARY_OPERATORS
+        .get(op)
+        .map(SqlBinaryOp::Explicit)
+        .unwrap_or(SqlBinaryOp::Fallback)
+}
+
+/// SQL functions that are handled with special case type checking rules for EQL.
+static SQL_FUNCTION_TYPES: LazyLock<HashMap<ObjectName, FunctionDecl>> = LazyLock::new(|| {
+    // # SQL function declarations.
+    //
+    // `Native` automatically satisfies *all* trait bounds. This is the trick that keeps the complexity of EQL Mapper's
+    // type system simple enough to be tractable for a small team of engineers. It is a *safe* strategy because even
+    // though EQL Mapper will not catch a type error with incorrect use of native database types, Postgres will.
+    //
+    // The Postgres versions of `count`, `min`, `max` etc are defined in the `pg_catalog` namespace. `pg_catalog` is
+    // prepended to the `search_path` by Postgres. When resolving the names of registered unqualified functions in
+    // this list, `pg_catalog` is assumed to be the schema. Additionally, functions in `pg_catalog` will be
+    // rewritten to their EQL counterpart by the EQL Mapper.
+
+    let items = functions!
{ + pg_catalog.count(T) -> Native; + pg_catalog.min(T) -> T where T: Ord; + pg_catalog.max(T) -> T where T: Ord; + pg_catalog.jsonb_path_query(T, ::Path) -> T where T: JsonLike; + pg_catalog.jsonb_path_query_first(T, ::Path) -> T where T: JsonLike; + pg_catalog.jsonb_path_exists(T, ::Path) -> Native where T: JsonLike; + pg_catalog.jsonb_array_length(T) -> Native where T: JsonLike; + pg_catalog.jsonb_array_elements(T) -> SetOf where T: JsonLike; + pg_catalog.jsonb_array_elements_text(T) -> SetOf where T: JsonLike; + eql_v2.min(T) -> T where T: Ord; + eql_v2.max(T) -> T where T: Ord; + eql_v2.jsonb_path_query(T, ::Path) -> T where T: JsonLike; + eql_v2.jsonb_path_query_first(T, ::Path) -> T where T: JsonLike; + eql_v2.jsonb_path_exists(T, ::Path) -> Native where T: JsonLike; + eql_v2.jsonb_array_length(T) -> Native where T: JsonLike; + eql_v2.jsonb_array_elements(T) -> SetOf where T: JsonLike; + eql_v2.jsonb_array_elements_text(T) -> SetOf where T: JsonLike; + }; + + HashMap::from_iter( + items + .into_iter() + .map(|rule: FunctionDecl| (rule.name.clone(), rule)), + ) +}); + +pub(crate) fn get_sql_function(fn_name: &ObjectName) -> SqlFunction { + // FIXME: this is a hack and we need proper schema resolution logic + let fully_qualified_fn_name = if fn_name.0.len() == 1 { + &ObjectName(vec![ + ObjectNamePart::Identifier(Ident::new("pg_catalog")), + fn_name.0[0].clone(), + ]) + } else { + fn_name + }; + + SQL_FUNCTION_TYPES + .get(fully_qualified_fn_name) + .map(SqlFunction::Explicit) + .unwrap_or(SqlFunction::Fallback) +} + +#[cfg(test)] +mod tests { + use crate::inference::sql_types::sql_decls::{SQL_BINARY_OPERATORS, SQL_FUNCTION_TYPES}; + + #[test] + fn binops_load_properly() { + let _ = &*SQL_BINARY_OPERATORS; + } + + #[test] + fn sqlfns_load_properly() { + let _ = &*SQL_FUNCTION_TYPES; + } +} diff --git a/packages/eql-mapper/src/inference/sql_types/sql_function_types.rs b/packages/eql-mapper/src/inference/sql_types/sql_function_types.rs new file mode 100644 index 00000000..bec4f39e --- /dev/null +++ b/packages/eql-mapper/src/inference/sql_types/sql_function_types.rs @@ -0,0 +1,141 @@ +use std::sync::{Arc, LazyLock}; + +use sqltk::parser::ast::{ + Function, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, ObjectNamePart, +}; + +use crate::{ + unifier::{FunctionDecl, Type, Unifier}, + TypeError, TypeInferencer, +}; + +/// Either explicit typing rules for a function that supports EQL, or a fallback where the typing rules force all +/// function argument types and the return type to be native. +#[derive(Debug)] +pub(crate) enum SqlFunction { + Explicit(&'static FunctionDecl), + Fallback, +} + +static PG_CATALOG: LazyLock = + LazyLock::new(|| ObjectNamePart::Identifier(Ident::new("pg_catalog"))); + +impl SqlFunction { + pub(crate) fn should_rewrite(&self) -> bool { + match self { + SqlFunction::Explicit(function_spec) => function_spec.name.0[0] == *PG_CATALOG, + SqlFunction::Fallback => false, + } + } +} + +fn get_function_arg_expr(fn_arg: &FunctionArg) -> &FunctionArgExpr { + match fn_arg { + FunctionArg::Named { arg, .. } => arg, + FunctionArg::ExprNamed { arg, .. 
} => arg,
+        FunctionArg::Unnamed(arg) => arg,
+    }
+}
+
+impl SqlFunction {
+    pub(crate) fn apply_constraints<'ast>(
+        &self,
+        inferencer: &mut TypeInferencer<'ast>,
+        function: &'ast Function,
+    ) -> Result<(), TypeError> {
+        let ret_type = inferencer.get_node_type(function);
+        match self {
+            SqlFunction::Explicit(rule) => {
+                match &function.args {
+                    FunctionArguments::None => {
+                        rule.inner
+                            .apply(&mut inferencer.unifier.borrow_mut(), &[], ret_type)?
+                    }
+                    FunctionArguments::Subquery(query) => {
+                        let node_type = inferencer.get_node_type(&**query);
+                        rule.inner.apply(
+                            &mut inferencer.unifier.borrow_mut(),
+                            &[node_type],
+                            ret_type,
+                        )?
+                    }
+                    FunctionArguments::List(list) => {
+                        let args: Vec<Arc<Type>> = list
+                            .args
+                            .iter()
+                            .map(|arg| inferencer.get_node_type(get_function_arg_expr(arg)))
+                            .collect();
+                        rule.inner
+                            .apply(&mut inferencer.unifier.borrow_mut(), &args, ret_type)?
+                    }
+                };
+
+                Ok(())
+            }
+            SqlFunction::Fallback => {
+                match &function.args {
+                    FunctionArguments::None => NativeFunction::new(0).apply_constraints(
+                        &mut inferencer.unifier.borrow_mut(),
+                        &[],
+                        ret_type,
+                    )?,
+                    FunctionArguments::Subquery(query) => {
+                        let query_type = &[inferencer.get_node_type(&**query)];
+                        NativeFunction::new(1).apply_constraints(
+                            &mut inferencer.unifier.borrow_mut(),
+                            query_type,
+                            ret_type,
+                        )?
+                    }
+                    FunctionArguments::List(list) => {
+                        let args: Vec<Arc<Type>> = list
+                            .args
+                            .iter()
+                            .map(|arg| inferencer.get_node_type(get_function_arg_expr(arg)))
+                            .collect();
+                        NativeFunction::new(args.len() as u8).apply_constraints(
+                            &mut inferencer.unifier.borrow_mut(),
+                            &args,
+                            ret_type,
+                        )?
+                    }
+                };
+
+                Ok(())
+            }
+        }
+    }
+}
+
+pub(crate) struct NativeFunction {
+    arg_count: u8,
+}
+
+impl NativeFunction {
+    pub fn new(arg_count: u8) -> Self {
+        Self { arg_count }
+    }
+
+    pub(crate) fn apply_constraints(
+        &self,
+        unifier: &mut Unifier<'_>,
+        args: &[Arc<Type>],
+        ret: Arc<Type>,
+    ) -> Result<(), TypeError> {
+        if args.len() != self.arg_count as usize {
+            return Err(TypeError::Expected(format!(
+                "expected {} function arguments but got {}",
+                self.arg_count,
+                args.len()
+            )));
+        }
+
+        for arg in args.iter() {
+            unifier.unify(arg.clone(), Type::native().into())?;
+        }
+
+        unifier.unify(ret.clone(), Type::native().into())?;
+
+        Ok(())
+    }
+}
diff --git a/packages/eql-mapper/src/inference/type_error.rs b/packages/eql-mapper/src/inference/type_error.rs
index 6b4a0c4d..26cf258b 100644
--- a/packages/eql-mapper/src/inference/type_error.rs
+++ b/packages/eql-mapper/src/inference/type_error.rs
@@ -2,6 +2,8 @@ use std::sync::Arc;

 use crate::{unifier::Type, SchemaError, ScopeError};

+use super::unifier::EqlTraits;
+
 #[derive(Debug, PartialEq, Eq, thiserror::Error)]
 pub enum TypeError {
     #[error("SQL feature {} is not supported", _0)]
@@ -13,18 +15,15 @@ pub enum TypeError {
     #[error("{}", _0)]
     Conflict(String),

+    #[error("Type `{}` does not satisfy bounds `{}`", _0, _1)]
+    UnsatisfiedBounds(Arc<Type>, EqlTraits),
+
     #[error("unified type contains unresolved type variable: {}", _0)]
     Incomplete(String),

     #[error("{}", _0)]
     Expected(String),

-    #[error("Expected param count to be {}, but got {}", _0, _1)]
-    ParamCount(usize, usize),
-
-    #[error("{}", _0)]
-    FunctionCall(String),
-
     #[error("{}", _0)]
     ScopeError(#[from] ScopeError),

@@ -40,4 +39,10 @@ pub enum TypeError {
         _4
     )]
     OnNodes(String, Arc<Type>, String, Arc<Type>, String),
+
+    #[error("Cannot parse placeholder syntax '{}'", _0)]
+    ParamSyntax(String),
+
+    #[error("{}", _0)]
+    TypeSignature(String),
+}
diff --git a/packages/eql-mapper/src/inference/unifier/eql_traits.rs
b/packages/eql-mapper/src/inference/unifier/eql_traits.rs new file mode 100644 index 00000000..f9a2e824 --- /dev/null +++ b/packages/eql-mapper/src/inference/unifier/eql_traits.rs @@ -0,0 +1,407 @@ +use std::sync::Arc; + +use derive_more::derive::{Deref, Display}; + +use crate::{ + unifier::{AssociatedTypeSelector, SetOf}, + TypeError, +}; + +use super::{Array, EqlTerm, EqlValue, Projection, Type, Value, Var}; + +/// Represents the supported operations on an EQL type +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Display, Hash)] +pub enum EqlTrait { + #[display("Eq")] + Eq, + #[display("Ord")] + Ord, + #[display("TokenMatch")] + TokenMatch, + #[display("JsonLike")] + JsonLike, + #[display("Contain")] + Contain, +} + +#[derive(Debug, Deref)] +pub(crate) struct EqlTraitAssociatedTypes(#[deref] pub(crate) &'static [&'static str]); + +const ASSOC_TYPES_EQ: &EqlTraitAssociatedTypes = &EqlTraitAssociatedTypes(&["Only"]); + +const ASSOC_TYPES_ORD: &EqlTraitAssociatedTypes = &EqlTraitAssociatedTypes(&["Only"]); + +const ASSOC_TYPES_TOKEN_MATCH: &EqlTraitAssociatedTypes = &EqlTraitAssociatedTypes(&["Tokenized"]); + +const ASSOC_TYPES_JSON_LIKE: &EqlTraitAssociatedTypes = + &EqlTraitAssociatedTypes(&["Path", "Accessor"]); + +const ASSOC_TYPES_CONTAIN: &EqlTraitAssociatedTypes = &EqlTraitAssociatedTypes(&["Only"]); + +impl EqlTrait { + pub(crate) const fn associated_type_names(&self) -> &'static EqlTraitAssociatedTypes { + match self { + EqlTrait::Eq => ASSOC_TYPES_EQ, + EqlTrait::Ord => ASSOC_TYPES_ORD, + EqlTrait::TokenMatch => ASSOC_TYPES_TOKEN_MATCH, + EqlTrait::JsonLike => ASSOC_TYPES_JSON_LIKE, + EqlTrait::Contain => ASSOC_TYPES_CONTAIN, + } + } + + pub(crate) fn has_associated_type(&self, assoc_type_name: &str) -> bool { + self.associated_type_names().contains(&assoc_type_name) + } + + pub(crate) fn resolve_associated_type( + &self, + ty: Arc, + selector: &AssociatedTypeSelector, + ) -> Result, TypeError> { + ty.clone() + .must_implement(&EqlTraits::from(selector.eql_trait))?; + + match &*ty { + // Native satisfies all associated type bounds + Type::Value(Value::Native(_)) => { + match (self, selector.type_name) { + (EqlTrait::Eq, "Only") + | (EqlTrait::Ord, "Only") + | (EqlTrait::TokenMatch, "Tokenized") + | (EqlTrait::JsonLike, "Accessor") + | (EqlTrait::JsonLike, "Path") + | (EqlTrait::Contain, "Only") => Ok(ty.clone()), + (_, unknown_associated_type) => Err(TypeError::InternalError(format!( + "Unknown associated type {self}::{unknown_associated_type}" + ))), + } + } + Type::Value(Value::Eql(EqlTerm::Full(eql_col))) + | Type::Value(Value::Eql(EqlTerm::Partial(eql_col, _))) => { + match (self, selector.type_name) { + (EqlTrait::Eq, "Only") => { + Ok(Arc::new(Type::Value(Value::Eql( + EqlTerm::Partial(eql_col.clone(), EqlTraits::from(EqlTrait::Eq)), + )))) + } + (EqlTrait::Ord, "Only") => { + Ok(Arc::new(Type::Value(Value::Eql( + EqlTerm::Partial(eql_col.clone(), EqlTraits::from(EqlTrait::Ord)), + )))) + } + (EqlTrait::TokenMatch, "Tokenized") => Ok(Arc::new(Type::Value(Value::Eql(EqlTerm::Tokenized(eql_col.clone()))), + )), + (EqlTrait::JsonLike, "Accessor") => { + Ok(Arc::new(Type::Value( + Value::Eql(EqlTerm::JsonAccessor(eql_col.clone())), + ))) + } + (EqlTrait::JsonLike, "Path") => Ok(Arc::new(Type::Value(Value::Eql(EqlTerm::JsonPath(eql_col.clone()))), + )), + (EqlTrait::Contain, "Only") => { + Ok(Arc::new(Type::Value(Value::Eql( + EqlTerm::Partial(eql_col.clone(), EqlTraits::from(EqlTrait::Contain)), + )))) + } + (_, unknown_associated_type) => 
Err(TypeError::InternalError(format!( + "Unknown associated type {self}::{unknown_associated_type}" + ))), + } + } + _ => Err(TypeError::InternalError(format!( + "associated type can only be resolved on Value::Native or Value::Eql types; got {ty}", + ))), + } + } +} + +/// Represents the set of "traits" implemented by a [`crate::Type`]. +/// +/// EQL types _and_ native types are tested against the bounds, but the trick is that native types *always* satisfy all +/// of the bounds (we let the database do its job - it will shout loudly when an expression has been used incorrectly). +/// +/// EQL types _must_ implement every individually required bound. This information will eventually let us produce good +/// error messages, but implemented bounds are exposed to consumers [`crate::TypeCheckedStatement`] in order to inform +/// how to encrypt literals and params whether for storage or querying. +/// +/// Two [`EqlTraits`] values always successfully unify by merging their flags. +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default, Hash)] +pub struct EqlTraits { + /// The type implements equality between its values using the `=` operator. + pub eq: bool, + + /// The type implements comparison of its values using `>`, `>=`, `=`, `<=`, `<`. + /// `ord` implies `eq`. + pub ord: bool, + + /// The type implements textual substring search using `LIKE`. + pub token_match: bool, + + /// The type implements field selection (e.g. `->` & `->>`) + pub json_like: bool, + + /// The type implements containment checking (e.g. `@>` and `<@`) + pub contain: bool, +} + +/// An [`EqlTraits`] with all trait flags set to `true`. +pub const ALL_TRAITS: EqlTraits = EqlTraits { + eq: true, + ord: true, + token_match: true, + json_like: true, + contain: true, +}; + +impl From for EqlTraits { + fn from(eql_trait: EqlTrait) -> Self { + let mut traits = EqlTraits::default(); + traits.add_mut(eql_trait); + traits + } +} + +impl FromIterator for EqlTraits { + fn from_iter>(iter: T) -> Self { + let mut traits = EqlTraits::default(); + for t in iter { + traits.add_mut(t) + } + traits + } +} + +impl EqlTraits { + /// An `EqlTraits` with all traits flags set to `false`. + pub fn none() -> Self { + Self::default() + } + + /// An `EqlTraits` with all traits flags set to `true`. 
+ pub fn all() -> Self { + ALL_TRAITS + } + + pub(crate) fn add_mut(&mut self, eql_trait: EqlTrait) { + match eql_trait { + EqlTrait::Eq => self.eq = true, + EqlTrait::Ord => { + self.ord = true; + self.eq = true; // implied by Ord + } + EqlTrait::TokenMatch => self.token_match = true, + EqlTrait::JsonLike => { + self.json_like = true; + } + EqlTrait::Contain => self.contain = true, + } + } + + pub(crate) fn union(&self, other: &Self) -> Self { + EqlTraits { + eq: self.eq || other.eq, + ord: self.ord || other.ord, + token_match: self.token_match || other.token_match, + json_like: self.json_like || other.json_like, + contain: self.contain || other.contain, + } + } + + pub(crate) fn intersection(&self, other: &Self) -> Self { + EqlTraits { + eq: self.eq && other.eq, + ord: self.ord && other.ord, + token_match: self.token_match && other.token_match, + json_like: self.json_like && other.json_like, + contain: self.contain && other.contain, + } + } + + pub(crate) fn difference(&self, other: &Self) -> Self { + EqlTraits { + eq: self.eq ^ other.eq, + ord: self.ord ^ other.ord, + token_match: self.token_match ^ other.token_match, + json_like: self.json_like ^ other.json_like, + contain: self.contain ^ other.contain, + } + } +} + +impl std::fmt::Display for EqlTraits { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + const EQ: &str = "Eq"; + const ORD: &str = "Ord"; + const TOKEN_MATCH: &str = "TokenMatch"; + const CONTAIN: &str = "Contain"; + const JSON_LIKE: &str = "JsonLike"; + + let mut traits: Vec<&'static str> = Vec::new(); + if self.eq { + traits.push(EQ) + } + if self.ord { + traits.push(ORD) + } + if self.token_match { + traits.push(TOKEN_MATCH) + } + if self.contain { + traits.push(CONTAIN) + } + if self.json_like { + traits.push(JSON_LIKE) + } + + f.write_str(&traits.join("+"))?; + + Ok(()) + } +} + +impl Type { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + match self { + Type::Value(value) => value.effective_bounds(), + Type::Var(Var(_, bounds)) => *bounds, + Type::Associated(associated_type) => associated_type.resolved_ty.effective_bounds(), + } + } +} + +impl Value { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + match self { + Value::Eql(eql_term) => eql_term.effective_bounds(), + Value::Native(_) => ALL_TRAITS, + Value::Array(ty) => ty.effective_bounds(), + Value::Projection(projection) => projection.effective_bounds(), + Value::SetOf(set_of) => set_of.effective_bounds(), + } + } +} + +impl Array { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + let Array(element_ty) = self; + element_ty.effective_bounds() + } +} + +impl SetOf { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + let SetOf(some_ty) = self; + some_ty.effective_bounds() + } +} + +impl Projection { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + if let Some((first, rest)) = self.columns().split_first() { + let mut acc = first.ty.effective_bounds(); + for col in rest { + acc = acc.intersection(&col.ty.effective_bounds()) + } + acc + } else { + EqlTraits::none() + } + } +} + +impl EqlTerm { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + match self { + EqlTerm::Full(eql_value) => eql_value.effective_bounds(), + EqlTerm::Partial(_, bounds) => *bounds, + EqlTerm::JsonAccessor(_) => EqlTraits::none(), + EqlTerm::JsonPath(_) => EqlTraits::none(), + EqlTerm::Tokenized(_) => EqlTraits::none(), + } + } +} + +impl EqlValue { + pub(crate) fn effective_bounds(&self) -> EqlTraits { + self.trait_impls() + } +} + +/* + TODO: the following represents how I 
would eventually like to define the traits. + + /// `COUNT` has to be declared in order to work with EQL types. + function pg_catalog.count(T) -> Native; + + /// Trait that corresponds to equality tests in SQL. + eqltrait Eq { + /// The most minimal encoding of `Self` that can still be used by EQL (the database extension) to perform + /// equality tests. The purpose of `Partial` is to avoid generating all of the non-`Eq` search terms of + /// `Self` if they are not going to be used. + type Only; + + binop (Self = Self) -> Native; + binop (Self <> Self) -> Native; + } + + /// Trait that corresponds to comparison tests in SQL. + eqltrait Ord: Eq { + /// The most minimal encoding of `Self` that can still be used by EQL (the database extension) to perform + /// comparison tests. The purpose of `Only` is to avoid generating all of the non-`Ord` search terms of + /// `Self` if they are not going to be used. + type Only; + + binop (Self <= Self) -> Native; + binop (Self >= Self) -> Native; + binop (Self < Self) -> Native; + binop (Self > Self) -> Native; + + fn pg_catalog.min(Self) -> Self; + fn pg_catalog.max(Self) -> Self; + } + + /// Trait that corresponds to containment testing operations in SQL. + eqltrait Contain { + type Only; + + binop (Self @> Self) -> Native; + binop (Self <@ Self) -> Native; + } + + /// Trait that corresponds to JSON/B operations in SQL. + eqltrait JsonLike { + /// A term that can select a field by name or an array element by index on `Self`. + type Accessor; + + /// A term that can be used to match an entire JSON path on `Self`. + type Path; + + binop (Self -> Self::Accessor) -> Self; + binop (Self ->> Self::Accessor) -> Self; + + fn pg_catalog.jsonb_path_query(Self, Self::Path) -> Self; + fn pg_catalog.jsonb_path_query_first(Self, Self::Path) -> Self; + fn pg_catalog.jsonb_path_exists(Self, Self::Path) -> Native; + fn pg_catalog.jsonb_array_length(Self) -> Native; + fn pg_catalog.jsonb_array_elements(Self) -> SetOf; + fn pg_catalog.jsonb_array_elements_text(Self) -> SetOf; + } + + /// Trait that corresponds to LIKE operations in SQL. + eqltrait TokenMatch { + type Tokenized; + + binop (Self ~~ Self::Tokenized) -> Native; + binop (Self !~~ Self::Tokenized) -> Native; + + LIKE { expr: Self, pattern: Self::Tokenized, .. } -> Native; + } + + /// Trait that corresponds to LIKE & ILIKE operations in SQL. + eqltrait TokenMatchCaseInsensitive: TokenMatch { + binop (Self ~~* Self::Tokenized) -> Native; + binop (Self !~~* Self::Tokenized) -> Native; + + ILIKE { expr: Self, pattern: Self::Tokenized, .. 
} -> Native; + } +*/ diff --git a/packages/eql-mapper/src/inference/unifier/instantiated_type_env.rs b/packages/eql-mapper/src/inference/unifier/instantiated_type_env.rs new file mode 100644 index 00000000..99836e7c --- /dev/null +++ b/packages/eql-mapper/src/inference/unifier/instantiated_type_env.rs @@ -0,0 +1,49 @@ +use std::{collections::HashMap, fmt::Display, sync::Arc}; + +use crate::{ + unifier::{TVar, Type}, + TypeError, +}; + +#[derive(Debug, Clone, Default)] +pub(crate) struct InstantiatedTypeEnv { + types: HashMap>, +} + +impl InstantiatedTypeEnv { + pub(crate) fn new() -> Self { + Self::default() + } + + pub(crate) fn add_type(&mut self, tvar: TVar, ty: Arc) -> Result<(), TypeError> { + if self.types.insert(tvar.clone(), ty).is_none() { + Ok(()) + } else { + Err(TypeError::InternalError(format!( + "named type {tvar} already initialised in {self}" + ))) + } + } + + pub(crate) fn get_type(&self, tvar: &TVar) -> Result, TypeError> { + match self.types.get(tvar).cloned() { + Some(ty) => Ok(ty), + None => Err(TypeError::InternalError(format!( + "type for tvar {tvar} not found" + ))), + } + } +} + +impl Display for InstantiatedTypeEnv { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("InstantiatedTypeEnv{ ")?; + for (idx, (tvar, spec)) in self.types.iter().enumerate() { + f.write_fmt(format_args!("{tvar} => {spec}"))?; + if idx < self.types.len() - 1 { + f.write_str(", ")?; + } + } + f.write_str("}") + } +} diff --git a/packages/eql-mapper/src/inference/unifier/mod.rs b/packages/eql-mapper/src/inference/unifier/mod.rs index fb523356..864fdb9b 100644 --- a/packages/eql-mapper/src/inference/unifier/mod.rs +++ b/packages/eql-mapper/src/inference/unifier/mod.rs @@ -1,19 +1,34 @@ use std::{cell::RefCell, collections::HashMap, rc::Rc, sync::Arc}; +mod eql_traits; +mod instantiated_type_env; +mod resolve_type; +mod type_decl; +mod type_env; mod types; +mod unify_types; use crate::inference::TypeError; +pub use eql_traits::*; +pub(crate) use type_decl::*; + +use unify_types::UnifyTypes; + use sqltk::AsNodeKey; pub(crate) use types::*; -pub use types::{EqlValue, NativeValue, TableColumn}; +pub(crate) use type_env::*; +pub use types::{EqlTerm, EqlValue, NativeValue, TableColumn}; use super::TypeRegistry; -use tracing::{event, instrument, Level}; +use tracing::{event, instrument, Level, Span}; -/// Implements the type unification algorithm and maintains an association of type variables with the type that they -/// point to. +/// Implements the type unification algorithm. +/// +/// Type unification is the process of determining a type variable substitution that makes two type expressions +/// identical. It involves solving equations between types, by recursively comparing their structure and binding type +/// variables to concrete types or other variables. 
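+///
+/// A small illustration (not code from this changeset): unifying a type variable with
+/// `Native` records a substitution from that variable to `Native`; two projections of
+/// equal length unify column-by-column; and two different EQL column types fail to unify
+/// with a [`TypeError::Conflict`].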
#[derive(Debug)] pub struct Unifier<'ast> { registry: Rc>>, @@ -28,7 +43,11 @@ impl<'ast> Unifier<'ast> { } pub(crate) fn fresh_tvar(&self) -> Arc { - Type::Var(self.registry.borrow_mut().fresh_tvar()).into() + self.fresh_bounded_tvar(EqlTraits::none()) + } + + pub(crate) fn fresh_bounded_tvar(&self, bounds: EqlTraits) -> Arc { + Type::Var(Var(self.registry.borrow_mut().fresh_tvar(), bounds)).into() } pub(crate) fn get_substitutions(&self) -> HashMap> { @@ -71,7 +90,30 @@ impl<'ast> Unifier<'ast> { .collect(); for (_, ty) in unresolved_value_nodes { - self.unify(ty, Type::any_native().into())?; + self.unify(ty, Type::native().into())?; + } + + Ok(()) + } + + pub(crate) fn resolve_unresolved_associated_types(&mut self) -> Result<(), TypeError> { + let unresolved_associated_types: Vec<_> = self + .registry + .borrow() + .get_nodes_and_types::() + .into_iter() + .map(|(node, ty)| (node, ty.follow_tvars(&*self))) + .filter_map(|(node, ty)| { + if let Type::Associated(associated) = &*ty { + Some((node, associated.clone())) + } else { + None + } + }) + .collect(); + + for (_node, associated_ty) in unresolved_associated_types { + associated_ty.resolve_selector_target(self)?; } Ok(()) @@ -101,248 +143,115 @@ impl<'ast> Unifier<'ast> { target = "eql-mapper::UNIFY", skip(self), level = "trace", - ret(Display), err(Debug), fields( lhs = %lhs, rhs = %rhs, + return = tracing::field::Empty, ) )] pub(crate) fn unify(&mut self, lhs: Arc, rhs: Arc) -> Result, TypeError> { - use types::Constructor::*; - use types::Value::*; - - let lhs: Arc = lhs; - let rhs: Arc = rhs; - - // Short-circuit the unification when lhs & rhs are equal. - if lhs == rhs { - return Ok(lhs.clone()); - } - - let unification = match (&*lhs, &*rhs) { - // Two projections unify if they have the same number of columns and all of the paired column types also - // unify. - (Type::Constructor(Projection(_)), Type::Constructor(Projection(_))) => { - self.unify_projections(lhs, rhs) - } + let span = Span::current(); + + let result = (|| { + // Short-circuit the unification when lhs & rhs are equal. + if lhs == rhs { + Ok(lhs.clone()) + } else { + match (&*lhs, &*rhs) { + (Type::Value(lhs_c), Type::Value(rhs_c)) => self.unify_types(lhs_c, rhs_c), + + (Type::Var(var), Type::Value(value)) | (Type::Value(value), Type::Var(var)) => { + self.unify_types(value, var) + } - // Two arrays unify if the types of their element types unify. 
- ( - Type::Constructor(Value(Array(lhs_element_ty))), - Type::Constructor(Value(Array(rhs_element_ty))), - ) => { - let unified_element_ty = - self.unify(lhs_element_ty.clone(), rhs_element_ty.clone())?; - let unified_array_ty = Type::Constructor(Value(Array(unified_element_ty))); - Ok(unified_array_ty.into()) - } + (Type::Var(lhs_v), Type::Var(rhs_v)) => self.unify_types(lhs_v, rhs_v), - // A Value can unify with a single column projection - (Type::Constructor(Value(_)), Type::Constructor(Projection(projection))) => { - let projection = projection.flatten(); - let len = projection.len(); - if len == 1 { - self.unify_value_type_with_one_col_projection(lhs, projection[0].ty.clone()) - } else { - Err(TypeError::Conflict( - "cannot unify value type with projection of more than one column" - .to_string(), - )) - } - } + (Type::Value(value), Type::Associated(associated_type)) + | (Type::Associated(associated_type), Type::Value(value)) => { + self.unify_types(associated_type, value) + } - (Type::Constructor(Projection(projection)), Type::Constructor(Value(_))) => { - let projection = projection.flatten(); - let len = projection.len(); - if len == 1 { - self.unify_value_type_with_one_col_projection(rhs, projection[0].ty.clone()) - } else { - Err(TypeError::Conflict( - "cannot unify value type with projection of more than one column" - .to_string(), - )) - } - } + (Type::Var(var), Type::Associated(associated_type)) + | (Type::Associated(associated_type), Type::Var(var)) => { + self.unify_types(associated_type, var) + } - // All native types are considered equal in the type system. However, for improved test readability the - // unifier favours a `NativeValue(Some(_))` over a `NativeValue(None)` because `NativeValue(Some(_))` - // carries more information. In a tie, the left hand side wins. - ( - Type::Constructor(Value(Native(native_lhs))), - Type::Constructor(Value(Native(native_rhs))), - ) => match (native_lhs, native_rhs) { - (NativeValue(Some(_)), NativeValue(Some(_))) => Ok(lhs), - (NativeValue(Some(_)), NativeValue(None)) => Ok(lhs), - (NativeValue(None), NativeValue(Some(_))) => Ok(rhs), - _ => Ok(lhs), - }, - - (Type::Constructor(Value(Eql(_))), Type::Constructor(Value(Eql(_)))) => { - if lhs == rhs { - Ok(lhs) - } else { - Err(TypeError::Conflict(format!( - "cannot unify different EQL types: {} and {}", - lhs, rhs - ))) + (Type::Associated(lhs_assoc), Type::Associated(rhs_assoc)) => { + self.unify_types(lhs_assoc, rhs_assoc) + } } } + })(); - // A constructor resolves with a type variable if either: - // 1. the type variable does not already refer to a constructor (transitively), or - // 2. it does refer to a constructor and the two constructors unify - (_, Type::Var(tvar)) => self.unify_with_type_var(lhs, *tvar), - - // A constructor resolves with a type variable if either: - // 1. the type variable does not already refer to a constructor (transitively), or - // 2. it does refer to a constructor and the two constructors unify - (Type::Var(tvar), _) => self.unify_with_type_var(rhs, *tvar), - - // Any other combination of types is a type error. 
-            (lhs, rhs) => Err(TypeError::Conflict(format!(
-                "type {} cannot be unified with {}",
-                lhs, rhs
-            ))),
-        };
-
-        match unification {
-            Ok(ty) => {
-                event!(
-                    name: "UNIFY::OK",
-                    target: "eql-mapper::EVENT_UNIFY_OK",
-                    Level::TRACE,
-                    ty = %ty,
-                );
-
-                Ok(ty)
-            }
-            Err(err) => {
-                event!(
-                    name: "UNIFY::ERR",
-                    target: "eql-mapper::EVENT_UNIFY_ERR",
-                    Level::TRACE,
-                    err = ?&err
-                );
-
-                Err(err)
-            }
+        if let Ok(ref val) = result {
+            span.record("return", tracing::field::display(val));
         }
+
+        result
     }
 
     /// Unifies a type with a type variable.
     ///
     /// Attempts to unify the type with whatever the type variable is pointing to.
-    ///
-    /// After successful unification `ty_rc` and `tvar_rc` will refer to the same allocation.
     fn unify_with_type_var(
         &mut self,
         ty: Arc<Type>,
         tvar: TypeVar,
+        tvar_bounds: &EqlTraits,
     ) -> Result<Arc<Type>, TypeError> {
-        let sub_ty = {
-            let registry = &*self.registry.borrow();
-            registry.get_type(tvar)
-        };
-
-        let unified_ty: Arc<Type> = match sub_ty {
-            Some(sub_ty) => self.unify(ty, sub_ty)?,
-            None => ty,
-        };
-
-        self.substitute(tvar, unified_ty.clone());
-
-        Ok(unified_ty)
-    }
-
-    /// Unifies two projection types.
-    fn unify_projections(
-        &mut self,
-        lhs: Arc<Type>,
-        rhs: Arc<Type>,
-    ) -> Result<Arc<Type>, TypeError> {
-        match (&*lhs, &*rhs) {
-            (
-                Type::Constructor(Constructor::Projection(lhs_projection)),
-                Type::Constructor(Constructor::Projection(rhs_projection)),
-            ) => {
-                let lhs_projection = lhs_projection.flatten();
-                let rhs_projection = rhs_projection.flatten();
-
-                if lhs_projection.len() == rhs_projection.len() {
-                    let mut cols: Vec<ProjectionColumn> = Vec::with_capacity(lhs_projection.len());
-
-                    for (lhs_col, rhs_col) in lhs_projection
-                        .columns()
-                        .iter()
-                        .zip(rhs_projection.columns())
-                    {
-                        let unified_ty = self.unify(lhs_col.ty.clone(), rhs_col.ty.clone())?;
-                        cols.push(ProjectionColumn::new(unified_ty, lhs_col.alias.clone()));
+        let unified = match self.get_type(tvar) {
+            Some(sub_ty) => {
+                self.satisfy_bounds(&sub_ty, tvar_bounds)?;
+                self.unify(ty, sub_ty)?
+            }
+            None => {
+                if let Type::Var(Var(_, ty_bounds)) = &*ty {
+                    if ty_bounds != tvar_bounds {
+                        self.fresh_bounded_tvar(tvar_bounds.union(ty_bounds))
+                    } else {
+                        ty.clone()
                     }
-
-                    let unified_ty =
-                        Type::Constructor(Constructor::Projection(Projection::new(cols)));
-
-                    Ok(unified_ty.into())
                 } else {
-                    Err(TypeError::Conflict(format!(
-                        "cannot unify projections {} and {} because they have different numbers of columns",
-                        lhs, rhs
-                    )))
+                    ty.clone()
                 }
             }
-            (_, _) => Err(TypeError::InternalError(
-                "unify_projections expected projection types".to_string(),
-            )),
-        }
+        };
+
+        Ok(self.substitute(tvar, unified))
     }
 
-    fn unify_value_type_with_one_col_projection(
-        &mut self,
-        value_ty: Arc<Type>,
-        projection_ty: Arc<Type>,
-    ) -> Result<Arc<Type>, TypeError> {
-        match (&*value_ty, &*projection_ty) {
-            (
-                Type::Constructor(Constructor::Value(Value::Eql(lhs))),
-                Type::Constructor(Constructor::Value(Value::Eql(rhs))),
-            ) if lhs == rhs => Ok(value_ty.clone()),
-            (
-                Type::Constructor(Constructor::Value(Value::Native(lhs))),
-                Type::Constructor(Constructor::Value(Value::Native(rhs))),
-            ) => match (lhs, rhs) {
-                (NativeValue(Some(_)), NativeValue(Some(_))) => Ok(value_ty.clone()),
-                (NativeValue(Some(_)), NativeValue(None)) => Ok(value_ty.clone()),
-                (NativeValue(None), NativeValue(Some(_))) => Ok(projection_ty.clone()),
-                _ => Ok(value_ty.clone()),
-            },
-            (
-                Type::Constructor(Constructor::Value(Value::Array(lhs))),
-                Type::Constructor(Constructor::Value(Value::Array(rhs))),
-            ) => {
-                let unified_element_ty = self.unify(lhs.clone(), rhs.clone())?;
-                let unified_array_ty =
-                    Type::Constructor(Constructor::Value(Value::Array(unified_element_ty)));
-                Ok(unified_array_ty.into())
-            }
-            (Type::Constructor(Constructor::Value(Value::Eql(_))), Type::Var(tvar)) => {
-                self.unify_with_type_var(value_ty.clone(), *tvar)
-            }
-            (Type::Var(tvar), Type::Constructor(Constructor::Value(Value::Eql(_)))) => {
-                self.unify_with_type_var(projection_ty.clone(), *tvar)
-            }
-            _ => Err(TypeError::Conflict(format!(
-                "value type {} cannot be unified with single column projection of {}",
-                value_ty, projection_ty
-            ))),
+    /// Prove that `ty` satisfies `bounds`.
+    ///
+    /// If `ty` is a [`Type::Var`] this test always passes.
+    ///
+    /// # Rules
+    ///
+    /// 1. Native types satisfy all possible bounds.
+    /// 2. EQL types satisfy bounds that they implement.
+    /// 3. Arrays satisfy all bounds of their element type.
+    /// 4. Projections satisfy the intersection of the bounds of their columns.
+    ///    a. However, empty projections satisfy all possible bounds.
+    /// 5. Type variables satisfy all bounds that they carry.
+    fn satisfy_bounds(&mut self, ty: &Type, bounds: &EqlTraits) -> Result<(), TypeError> {
+        if let Type::Var(_) = ty {
+            return Ok(());
+        }
+
+        if &bounds.intersection(&ty.effective_bounds()) == bounds {
+            Ok(())
+        } else {
+            Err(TypeError::UnsatisfiedBounds(
+                Arc::new(ty.clone()),
+                bounds.difference(&ty.effective_bounds()),
+            ))
         }
     }
 }
 
 pub(crate) mod test_util {
     use sqltk::parser::ast::{
-        Delete, Expr, Function, FunctionArguments, Insert, Query, Select, SelectItem, SetExpr,
+        Delete, Expr, Function, FunctionArgExpr, Insert, Query, Select, SelectItem, SetExpr,
         Statement, Value, Values,
     };
     use sqltk::{AsNodeKey, Break, Visitable, Visitor};
@@ -367,7 +276,7 @@ pub(crate) mod test_util {
     /// Dumps the type information for a specific AST node to STDERR.
     ///
     /// Useful when debugging tests.
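+    ///
+    /// A sketch of intended usage (illustrative; `tcx` and `expr` are hypothetical names for a test context
+    /// and an `&'ast Expr` captured during a test):
+    ///
+    /// ```ignore
+    /// tcx.dump_node(expr); // prints the node's resolved and unresolved types to STDERR
+    /// ```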
-    pub(crate) fn dump_node(&self, node: &'ast N) {
+    pub(crate) fn dump_node(&self, node: &'ast N) {
         let root_ty = self.get_node_type(node).clone();
         let found_ty = root_ty.clone().follow_tvars(self);
         let ast_ty = type_name::<N>();
@@ -431,7 +340,7 @@ pub(crate) mod test_util {
             self.0.dump_node(node);
         }
 
-        if let Some(node) = node.downcast_ref::<FunctionArguments>() {
+        if let Some(node) = node.downcast_ref::<FunctionArgExpr>() {
             self.0.dump_node(node);
         }
 
@@ -454,135 +363,118 @@ pub(crate) mod test_util {
 
 #[cfg(test)]
 mod test {
-    use std::sync::Arc;
+    use eql_mapper_macros::shallow_init_types;
 
-    use crate::unifier::{Constructor::*, NativeValue, ProjectionColumn, Type, TypeVar, Value::*};
-    use crate::unifier::{ProjectionColumns, Unifier};
+    use crate::unifier::Unifier;
+    use crate::unifier::{EqlTraits, InstantiateType};
     use crate::{DepMut, TypeRegistry};
 
     #[test]
     fn eq_native() {
         let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Constructor(Value(Native(NativeValue(None)))).into();
-        let rhs: Arc<_> = Type::Constructor(Value(Native(NativeValue(None)))).into();
+        shallow_init_types! {&mut unifier, {
+            let lhs = Native;
+            let rhs = Native;
+        }};
 
         assert_eq!(unifier.unify(lhs.clone(), rhs), Ok(lhs));
     }
 
-    #[ignore = "this is addressed in unmerged PR"]
     #[test]
-    fn eq_never() {
+    fn constructor_with_var() {
         let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::Empty)).into();
-        let rhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::Empty)).into();
+        shallow_init_types! { &mut unifier, {
+            let lhs = Native;
+            let rhs = T;
+        }};
 
         assert_eq!(unifier.unify(lhs.clone(), rhs), Ok(lhs));
     }
 
     #[test]
-    fn constructor_with_var() {
+    fn var_with_constructor() {
         let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Constructor(Value(Native(NativeValue(None)))).into();
-        let rhs: Arc<_> = Type::Var(TypeVar(0)).into();
+        shallow_init_types! {&mut unifier, {
+            let lhs = T;
+            let rhs = Native;
+            let expected = Native;
+        }};
 
-        assert_eq!(unifier.unify(lhs.clone(), rhs), Ok(lhs));
+        let actual = unifier.unify(lhs, rhs).unwrap();
+        assert_eq!(actual, expected);
     }
 
     #[test]
-    fn var_with_constructor() {
+    fn projections_without_wildcards() {
        let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Var(TypeVar(0)).into();
-        let rhs: Arc<_> = Type::Constructor(Value(Native(NativeValue(None)))).into();
+        shallow_init_types! {&mut unifier, {
+            let lhs = {Native, T};
+            let rhs = {U, Native};
+            let expected = {Native, Native};
+        }};
+
+        let actual = unifier.unify(lhs, rhs).unwrap();
 
-        assert_eq!(unifier.unify(lhs, rhs.clone()), Ok(rhs));
+        assert_eq!(actual, expected);
     }
 
     #[test]
-    fn projections_without_wildcards() {
+    #[ignore = "this scenario cannot happen during unification because wildcards will have been expanded before the projections are unified"]
+    // Leaving this test here as a reminder in case the above assertion proves to be false.
+    fn projections_with_wildcards() {
         let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-            ProjectionColumns(vec![
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-                ProjectionColumn::new(Type::Var(TypeVar(0)), None),
-            ]),
-        )))
-        .into();
-
-        let rhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-            ProjectionColumns(vec![
-                ProjectionColumn::new(Type::Var(TypeVar(1)), None),
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-            ]),
-        )))
-        .into();
+        shallow_init_types! {&mut unifier, {
+            let lhs = {Native, Native};
+            // rhs is a single projection that contains a projection column that contains a projection with two
+            // projection columns. This is how wildcard expansion is represented at the type level.
+            let rhs = {{Native, Native}};
+            let expected = {Native, Native};
+        }};
 
-        let unified = unifier.unify(lhs, rhs).unwrap();
+        let actual = unifier.unify(lhs, rhs).unwrap();
 
-        assert_eq!(
-            *unified,
-            Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-                ProjectionColumns(vec![
-                    ProjectionColumn::new(
-                        Type::Constructor(Value(Native(NativeValue(None)))),
-                        None
-                    ),
-                    ProjectionColumn::new(
-                        Type::Constructor(Value(Native(NativeValue(None)))),
-                        None
-                    ),
-                ])
-            )))
-        );
+        assert_eq!(actual, expected);
     }
 
     #[test]
-    fn projections_with_wildcards() {
+    fn type_var_bounds_are_unified() {
         let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
 
-        let lhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-            ProjectionColumns(vec![
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-            ]),
-        )))
-        .into();
-
-        let cols: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-            ProjectionColumns(vec![
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-                ProjectionColumn::new(Type::Constructor(Value(Native(NativeValue(None)))), None),
-            ]),
-        )))
-        .into();
-
-        // The RHS is a single projection that contains a projection column that contains a projection with two
-        // projection columns. This is how wildcard expansions is represented at the type level.
-        let rhs: Arc<_> = Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-            ProjectionColumns(vec![ProjectionColumn::new(cols, None)]),
-        )))
-        .into();
+        shallow_init_types! {&mut unifier, {
+            let lhs = T;
+            let rhs = U;
+        }};
 
         let unified = unifier.unify(lhs, rhs).unwrap();
+        assert_eq!(unified.effective_bounds(), EqlTraits::default());
 
-        assert_eq!(
-            *unified,
-            Type::Constructor(Projection(crate::unifier::Projection::WithColumns(
-                ProjectionColumns(vec![
-                    ProjectionColumn::new(
-                        Type::Constructor(Value(Native(NativeValue(None)))),
-                        None
-                    ),
-                    ProjectionColumn::new(
-                        Type::Constructor(Value(Native(NativeValue(None)))),
-                        None
-                    ),
-                ])
-            )))
-        );
+        let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
+
+        shallow_init_types! {&mut unifier, {
+            let lhs = T;
+            let rhs = U: Eq;
+            let expected = V: Eq;
+        }};
+
+        let actual = unifier.unify(lhs, rhs).unwrap();
+
+        assert_eq!(actual.effective_bounds(), expected.effective_bounds());
+
+        let mut unifier = Unifier::new(DepMut::new(TypeRegistry::new()));
+
+        shallow_init_types!
{&mut unifier, {
+            let lhs = T: Eq;
+            let rhs = U;
+            let expected = V: Eq;
+        }};
+
+        let actual = unifier.unify(lhs, rhs).unwrap();
+
+        assert_eq!(actual.effective_bounds(), expected.effective_bounds());
     }
 }
diff --git a/packages/eql-mapper/src/inference/unifier/resolve_type.rs b/packages/eql-mapper/src/inference/unifier/resolve_type.rs
new file mode 100644
index 00000000..f23709b6
--- /dev/null
+++ b/packages/eql-mapper/src/inference/unifier/resolve_type.rs
@@ -0,0 +1,104 @@
+use crate::TypeError;
+
+use super::{Array, NativeValue, Projection, SetOf, Type, Unifier, Value, Var};
+
+/// A trait for resolving all type variables contained in a [`crate::unifier::Type`] and converting the successfully
+/// resolved type into the publicly exported [`crate::Type`] representation, which is identical except for the
+/// absence of type variables.
+pub(crate) trait ResolveType {
+    /// The corresponding type for `Self` in `crate::Type::..`, e.g. when `Self` is `crate::unifier::Type` then
+    /// `Self::Output` is `crate::Type`.
+    type Output;
+
+    /// Recursively resolves all type variables found in `self` and, if successful, builds and returns `Ok(Self::Output)`.
+    ///
+    /// Returns a [`TypeError`] if there are any unresolved type variables.
+    fn resolve_type(&self, unifier: &mut Unifier<'_>) -> Result<Self::Output, TypeError>;
+}
+
+impl ResolveType for Type {
+    type Output = crate::Type;
+
+    fn resolve_type(&self, unifier: &mut Unifier<'_>) -> Result<Self::Output, TypeError> {
+        match self {
+            Type::Value(constructor) => Ok(constructor.resolve_type(unifier)?.into()),
+
+            Type::Var(Var(type_var, _)) => {
+                if let Some(sub_ty) = unifier.get_type(*type_var) {
+                    return sub_ty.resolved(unifier);
+                }
+
+                Err(TypeError::Incomplete(format!(
+                    "there are no substitutions for type var {type_var}"
+                )))
+            }
+
+            Type::Associated(associated) => {
+                if let Some(constructor) = associated.resolve_selector_target(unifier)? {
+                    Ok(constructor.resolve_type(unifier)?)
+                } else {
+                    Err(TypeError::InternalError(format!(
+                        "could not resolve associated type {associated}"
+                    )))
+                }
+            }
+        }
+    }
+}
+
+impl ResolveType for SetOf {
+    type Output = crate::SetOf;
+
+    fn resolve_type(&self, unifier: &mut Unifier<'_>) -> Result<Self::Output, TypeError> {
+        Ok(crate::SetOf(Box::new(self.0.resolve_type(unifier)?)))
+    }
+}
+
+impl ResolveType for Value {
+    type Output = crate::Value;
+
+    fn resolve_type(&self, unifier: &mut Unifier<'_>) -> Result<Self::Output, TypeError> {
+        match self {
+            Value::Eql(eql_term) => Ok(crate::Value::Eql(eql_term.clone())),
+            Value::Native(NativeValue(Some(native_col))) => {
+                Ok(crate::Value::Native(NativeValue(Some(native_col.clone()))))
+            }
+            Value::Native(NativeValue(None)) => Ok(crate::Value::Native(NativeValue(None))),
+            Value::Array(Array(element_ty)) => {
+                let resolved = element_ty.resolve_type(unifier)?;
+                Ok(crate::Value::Array(crate::Array(resolved.into())))
+            }
+            Value::Projection(projection) => {
+                Ok(crate::Value::Projection(projection.resolve_type(unifier)?))
+            }
+
+            Value::SetOf(set_of) => {
+                let resolved = set_of.resolve_type(unifier)?;
+                Ok(crate::Value::SetOf(resolved))
+            }
+        }
+    }
+}
+
+impl ResolveType for Projection {
+    type Output = crate::Projection;
+
+    fn resolve_type(&self, unifier: &mut Unifier<'_>) -> Result<Self::Output, TypeError> {
+        let resolved_cols = self.columns().iter().try_fold(
+            vec![],
+            |mut acc, col| -> Result<Vec<crate::ProjectionColumn>, TypeError> {
+                let alias = col.alias.clone();
+                if let Type::Value(Value::Projection(projection)) = &*col.ty {
+                    let resolved = projection.resolve_type(unifier)?;
+                    acc.extend(resolved.0.into_iter());
+                } else {
+                    let crate::Type::Value(value) = col.ty.resolve_type(unifier)?;
+                    acc.push(crate::ProjectionColumn { ty: value, alias });
+                }
+                Ok(acc)
+            },
+        )?;
+
+        Ok(crate::Projection(resolved_cols))
+    }
+}
diff --git a/packages/eql-mapper/src/inference/unifier/type_decl.rs b/packages/eql-mapper/src/inference/unifier/type_decl.rs
new file mode 100644
index 00000000..040838a5
--- /dev/null
+++ b/packages/eql-mapper/src/inference/unifier/type_decl.rs
@@ -0,0 +1,649 @@
+//! # Symbolic type declarations
+//!
+//! [`TypeDecl`] and its variants provide the means to write type declarations symbolically, which is much more pleasant
+//! than constructing [`Type`] variants manually.
+//!
+//! This makes it much simpler (and therefore less of a chore) to write tests for the [`Unifier`] and to declare
+//! EQL-compatible functions and operators.
+//!
+//! The [`eql_mapper_macros`] crate provides macros that implement a mini-DSL for declaring types and macros for
+//! instantiating them (converting to [`Type`]s for unification).
+//!
+//! Here is an example of the `type_env` macro in use:
+//!
+//! ```ignore
+//! use eql_mapper_macros::type_env;
+//!
+//! let env = type_env! {
+//!     P = {A as id, B as name, C as email};
+//!     A = Native(customer.id);
+//!     B = EQL(customer.name: Eq);
+//!     C = EQL(customer.email: Eq);
+//! };
+//! ```
+
+use std::{fmt::Display, sync::Arc};
+
+use derive_more::derive::Display;
+use sqltk::parser::ast::{BinaryOperator, Ident, ObjectName};
+use tracing::{event, instrument, Level};
+
+use crate::{
+    unifier::{instantiated_type_env::InstantiatedTypeEnv, AssociatedTypeSelector},
+    EqlTrait, Fmt, TypeError,
+};
+
+use super::{
+    AssociatedType, EqlTerm, EqlTraits, NativeValue, Projection, ProjectionColumn, TableColumn,
+    Type, TypeEnv, Unifier, Value,
+};
+
+/// A `TypeDecl` is a symbolic representation of a [`Type`].
Multiple type declarations can be added to a [`TypeEnv`]
+/// and when the type environment is "instantiated" with [`TypeEnv::instantiate`], they become `Arc<Type>` values for
+/// use in the [`Unifier`].
+///
+/// SQL functions & operators also have type declaration syntax but they do not have corresponding [`TypeDecl`]
+/// variants because functions and operators are not first class values in SQL. See [`FunctionDecl`] & [`BinaryOpDecl`].
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Display)]
+pub(crate) enum TypeDecl {
+    /// A type variable. See [`crate::unifier::Var`].
+    #[display("{}", _0)]
+    Var(VarDecl),
+
+    /// A native type with an optional table-column. See [`crate::unifier::NativeValue`].
+    #[display("{}", _0)]
+    Native(NativeDecl),
+
+    /// An EQL column with zero or more [`EqlTrait`] implementations. See [`crate::unifier::EqlTerm`].
+    #[display("{}", _0)]
+    #[allow(unused)]
+    Eql(EqlTerm),
+
+    /// An array. See [`crate::unifier::Array`].
+    #[display("{}", _0)]
+    #[allow(unused)]
+    Array(ArrayDecl),
+
+    /// A projection. See [`crate::unifier::Projection`]
+    #[display("{}", _0)]
+    #[allow(unused)]
+    Projection(ProjectionDecl),
+
+    /// An associated type. See [`crate::unifier::AssociatedType`].
+    #[display("{}", _0)]
+    AssociatedType(AssociatedTypeDecl),
+
+    /// A `setof` type. See [`crate::unifier::SetOf`].
+    #[display("{}", _0)]
+    SetOf(SetOfDecl),
+}
+
+impl TypeDecl {
+    /// Recursively finds all of the [`TVar`]s that `self` uses within its definition. This information is used to
+    /// control the order of type instantiation in a [`TypeEnv`].
+    pub(crate) fn depends_on(&self) -> Vec<&TVar> {
+        match self {
+            TypeDecl::Var(VarDecl { tvar, .. }) => {
+                if !tvar.0.starts_with("$") {
+                    vec![tvar]
+                } else {
+                    vec![]
+                }
+            }
+            TypeDecl::Native(_) => vec![],
+            TypeDecl::Eql(_) => vec![],
+            TypeDecl::Array(ArrayDecl(decl)) => decl.depends_on(),
+            TypeDecl::Projection(ProjectionDecl(cols)) => {
+                cols.iter().flat_map(|col| col.0.depends_on()).collect()
+            }
+            TypeDecl::AssociatedType(AssociatedTypeDecl { impl_decl, .. }) => {
+                impl_decl.depends_on()
+            }
+            TypeDecl::SetOf(SetOfDecl(decl)) => decl.depends_on(),
+        }
+    }
+}
+
+/// Trait for instantiating a [`Type`] from a [`TypeDecl`].
+///
+/// # Instantiation modes
+///
+/// - "in-env" substitutes type variables with concrete types by looking up already initialised types via type variables
+///   in the environment.
+///
+/// - "shallow" does not resolve type variables to concrete types; instead it instantiates them as fresh, unique
+///   [`Type::Var`] values generated by the [`Unifier`].
+///
+/// - "concrete" initialises concrete types only and will fail with an error if a [`TypeDecl`] contains a type variable.
+pub(crate) trait InstantiateType {
+    /// Instantiates a [`Type`] to be used in an [`InstantiatedTypeEnv`].
+    ///
+    /// This method is called by [`TypeEnv::instantiate`], which controls the initialisation order, thus guaranteeing
+    /// that dependencies will already be initialised and available from `env`.
+    ///
+    /// The [`TVar`]s depended upon by `self` will have already been initialised and can be accessed via `env`.
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError>;
+
+    /// Instantiates a [`Type`] without looking up dependencies in an [`InstantiatedTypeEnv`].
+    ///
+    /// Every dependency is initialised to a fresh [`Type::Var`].
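+    ///
+    /// For example (illustrative, not taken from the implementation): shallow-instantiating the declaration
+    /// `[T]` yields `Array[$n]`, where `$n` is a fresh unifier type variable rather than the type bound to `T`
+    /// in any environment.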
+    #[allow(unused)]
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError>;
+
+    /// Initialises a concrete [`Type`] from a [`TypeDecl`] variant that contains no [`TypeDecl::Var`]s. If a
+    /// `TypeDecl::Var` is encountered then a [`TypeError`] will be returned.
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError>;
+}
+
+impl InstantiateType for TypeDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        match self {
+            TypeDecl::Var(var_decl) => var_decl.instantiate_in_env(unifier, env),
+            TypeDecl::Native(native_decl) => native_decl.instantiate_in_env(unifier, env),
+            TypeDecl::Eql(eql_term) => eql_term.instantiate_in_env(unifier, env),
+            TypeDecl::Array(array_decl) => array_decl.instantiate_in_env(unifier, env),
+            TypeDecl::Projection(projection_decl) => {
+                projection_decl.instantiate_in_env(unifier, env)
+            }
+            TypeDecl::AssociatedType(associated_type_decl) => {
+                associated_type_decl.instantiate_in_env(unifier, env)
+            }
+            TypeDecl::SetOf(setof_decl) => setof_decl.instantiate_in_env(unifier, env),
+        }
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        match self {
+            TypeDecl::Var(decl) => decl.instantiate_concrete(),
+            TypeDecl::Native(decl) => decl.instantiate_concrete(),
+            TypeDecl::Eql(decl) => decl.instantiate_concrete(),
+            TypeDecl::Array(decl) => decl.instantiate_concrete(),
+            TypeDecl::Projection(decl) => decl.instantiate_concrete(),
+            TypeDecl::AssociatedType(decl) => decl.instantiate_concrete(),
+            TypeDecl::SetOf(decl) => decl.instantiate_concrete(),
+        }
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        match self {
+            TypeDecl::Var(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::Native(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::Eql(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::Array(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::Projection(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::AssociatedType(decl) => decl.instantiate_shallow(unifier),
+            TypeDecl::SetOf(decl) => decl.instantiate_shallow(unifier),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Display)]
+#[display("[{}]", _0)]
+pub(crate) struct ArrayDecl(pub(crate) Box<TypeDecl>);
+
+impl InstantiateType for ArrayDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::array(self.0.instantiate_in_env(unifier, env)?))
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::array(self.0.instantiate_concrete()?))
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::array(self.0.instantiate_shallow(unifier)?))
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct VarDecl {
+    pub(crate) tvar: TVar,
+    pub(crate) bounds: EqlTraits,
+}
+
+impl Display for VarDecl {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_fmt(format_args!("{}", self.tvar))?;
+        if self.bounds != EqlTraits::none() {
+            f.write_fmt(format_args!(": {}", self.bounds))?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Display)]
+pub(crate) struct TVar(#[display("${}", _0)] pub(crate) String);
+
+impl InstantiateType for VarDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        if !self.tvar.0.starts_with("$") {
env.get_type(&self.tvar)
+        } else {
+            Ok(unifier.fresh_bounded_tvar(self.bounds))
+        }
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        Err(TypeError::InternalError(
+            "tried to build a concrete type from a typedecl containing a type variable".into(),
+        ))
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        Ok(unifier.fresh_bounded_tvar(self.bounds))
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct ProjectionDecl(pub(crate) Vec<ProjectionColumnDecl>);
+
+impl Display for ProjectionDecl {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("{")?;
+        for (idx, col) in self.0.iter().enumerate() {
+            f.write_fmt(format_args!("{col}"))?;
+            if idx < self.0.len() - 1 {
+                f.write_str(",")?;
+            }
+        }
+        f.write_str("}")
+    }
+}
+
+impl InstantiateType for ProjectionDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        Ok(Arc::new(Type::Value(Value::Projection(Projection(
+            self.0
+                .iter()
+                .map(|col_decl| -> Result<_, TypeError> {
+                    Ok(ProjectionColumn::new(
+                        col_decl.0.instantiate_in_env(unifier, env)?,
+                        col_decl.1.clone(),
+                    ))
+                })
+                .collect::<Result<Vec<_>, _>>()?,
+        )))))
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        Ok(Arc::new(Type::Value(Value::Projection(Projection(
+            self.0
+                .iter()
+                .map(|col_decl| -> Result<_, TypeError> {
+                    Ok(ProjectionColumn::new(
+                        col_decl.0.instantiate_concrete()?,
+                        col_decl.1.clone(),
+                    ))
+                })
+                .collect::<Result<Vec<_>, _>>()?,
+        )))))
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        Ok(Arc::new(Type::Value(Value::Projection(Projection(
+            self.0
+                .iter()
+                .map(|col_decl| -> Result<_, TypeError> {
+                    Ok(ProjectionColumn::new(
+                        col_decl.0.instantiate_shallow(unifier)?,
+                        col_decl.1.clone(),
+                    ))
+                })
+                .collect::<Result<Vec<_>, _>>()?,
+        )))))
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct ProjectionColumnDecl(
+    pub(crate) Box<TypeDecl>,
+    pub(crate) Option<Ident>,
+);
+
+impl Display for ProjectionColumnDecl {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_fmt(format_args!("{}", self.0))?;
+        if let Some(alias) = &self.1 {
+            f.write_fmt(format_args!(" as {alias}"))?;
+        }
+        Ok(())
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, Display)]
+#[display("<{} as {}>::{}", impl_decl, as_eql_trait, type_name)]
+pub(crate) struct AssociatedTypeDecl {
+    pub(crate) impl_decl: Box<TypeDecl>,
+    pub(crate) as_eql_trait: EqlTrait,
+    pub(crate) type_name: &'static str,
+}
+
+impl InstantiateType for AssociatedTypeDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        let impl_ty = self.impl_decl.instantiate_in_env(unifier, env)?;
+        let resolved_ty = unifier.fresh_tvar();
+
+        Ok(Arc::new(Type::Associated(AssociatedType {
+            impl_ty,
+            resolved_ty,
+            selector: AssociatedTypeSelector::new(self.as_eql_trait, self.type_name)?,
+        })))
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        let impl_ty = self.impl_decl.instantiate_concrete()?;
+        let selector = AssociatedTypeSelector::new(self.as_eql_trait, self.type_name)?;
+        let resolved_ty = selector.resolve(impl_ty.clone())?;
+
+        Ok(Arc::new(Type::Associated(AssociatedType {
+            impl_ty,
+            resolved_ty,
+            selector,
+        })))
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        let impl_ty = self.impl_decl.instantiate_shallow(unifier)?;
+        let
selector = AssociatedTypeSelector::new(self.as_eql_trait, self.type_name)?;
+        let resolved_ty = selector.resolve(impl_ty.clone())?;
+
+        Ok(Arc::new(Type::Associated(AssociatedType {
+            impl_ty,
+            resolved_ty,
+            selector,
+        })))
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct SetOfDecl(pub Box<TypeDecl>);
+
+impl Display for SetOfDecl {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_fmt(format_args!("SetOf<{}>", self.0))
+    }
+}
+
+impl InstantiateType for SetOfDecl {
+    fn instantiate_in_env(
+        &self,
+        unifier: &mut Unifier<'_>,
+        env: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::set_of(self.0.instantiate_in_env(unifier, env)?).into())
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::set_of(self.0.instantiate_concrete()?).into())
+    }
+
+    fn instantiate_shallow(&self, unifier: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::set_of(self.0.instantiate_shallow(unifier)?).into())
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct BoundsDecl(pub(crate) TVar, pub(crate) EqlTraits);
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct FunctionDecl {
+    /// The function name.
+    pub(crate) name: ObjectName,
+    /// The declaration of this function.
+    pub(crate) inner: FunctionSignatureDecl,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct BinaryOpDecl {
+    /// The binary operator.
+    pub(crate) op: BinaryOperator,
+    /// The declaration of this binary operator as a 2-argument function.
+    pub(crate) inner: FunctionSignatureDecl,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct FunctionSignatureDecl {
+    /// The type variables for the args.
+    arg_tvars: Vec<TVar>,
+
+    /// The type variable for the return type.
+    ret_tvar: TVar,
+
+    /// The type environment.
+    type_env: TypeEnv,
+}
+
+impl FunctionSignatureDecl {
+    /// Creates a new `FunctionSignatureDecl` from its generic args, their bounds, the argument type declarations
+    /// and the return type declaration. The generic args are local to this function definition and are the ONLY
+    /// type variables allowed to be referenced in `arg_decls`, `ret_decl` and `generic_bounds`.
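+    ///
+    /// A hedged, illustrative sketch of direct construction (in practice the `eql_mapper_macros` DSL generates
+    /// these declarations; the tvar names here are hypothetical). A signature like `fn eq<T: Eq>(a: T, b: T) -> Native`
+    /// could be declared as:
+    ///
+    /// ```ignore
+    /// let decl = FunctionSignatureDecl::new(
+    ///     vec![TVar("T".into())],
+    ///     vec![BoundsDecl(TVar("T".into()), EqlTraits::from(EqlTrait::Eq))],
+    ///     vec![
+    ///         TypeDecl::Var(VarDecl { tvar: TVar("T".into()), bounds: EqlTraits::from(EqlTrait::Eq) }),
+    ///         TypeDecl::Var(VarDecl { tvar: TVar("T".into()), bounds: EqlTraits::from(EqlTrait::Eq) }),
+    ///     ],
+    ///     TypeDecl::Native(NativeDecl(None)),
+    /// )?;
+    /// ```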
+    pub(crate) fn new(
+        generic_args: Vec<TVar>,
+        generic_bounds: Vec<BoundsDecl>,
+        arg_decls: Vec<TypeDecl>,
+        ret_decl: TypeDecl,
+    ) -> Result<Self, TypeError> {
+        Self::check_no_undeclared_generic_args(
+            &generic_args,
+            &generic_bounds,
+            &arg_decls,
+            &ret_decl,
+        )?;
+
+        let (type_env, (arg_tvars, ret_tvar)) = TypeEnv::build(|type_env| {
+            for tvar in generic_args {
+                let decl = match generic_bounds.iter().find(|bound| bound.0 == tvar) {
+                    Some(bounds) => TypeDecl::Var(VarDecl {
+                        tvar: type_env.fresh_tvar(),
+                        bounds: bounds.1,
+                    }),
+                    None => TypeDecl::Var(VarDecl {
+                        tvar: type_env.fresh_tvar(),
+                        bounds: EqlTraits::none(),
+                    }),
+                };
+
+                type_env.add_decl(tvar, decl);
+            }
+
+            let arg_tvars = arg_decls
+                .into_iter()
+                .map(|decl| -> Result<TVar, TypeError> { type_env.add_decl_with_indirection(decl) })
+                .collect::<Result<Vec<_>, _>>()?;
+
+            let ret_tvar = type_env.add_decl_with_indirection(ret_decl)?;
+
+            Ok((arg_tvars, ret_tvar))
+        })?;
+
+        Ok(Self {
+            arg_tvars,
+            ret_tvar,
+            type_env,
+        })
+    }
+
+    #[instrument(
+        target = "eql-mapper::UNIFY",
+        skip(self, unifier),
+        level = "trace",
+        err(Debug),
+        fields(
+            args = %Fmt(args),
+            ret = %ret,
+            return = tracing::field::Empty,
+        )
+    )]
+    pub(crate) fn apply(
+        &self,
+        unifier: &mut Unifier<'_>,
+        args: &[Arc<Type>],
+        ret: Arc<Type>,
+    ) -> Result<InstantiatedTypeEnv, TypeError> {
+        let span = tracing::Span::current();
+
+        let result = (|| {
+            if args.len() != self.arg_tvars.len() {
+                return Err(TypeError::Expected(format!(
+                    "incorrect number of arguments; got {}, expected {}",
+                    args.len(),
+                    self.arg_tvars.len()
+                )));
+            }
+
+            event!(
+                target: "eql-mapper::TYPE_ENV",
+                Level::TRACE,
+                type_env = %self.type_env
+            );
+
+            let instantiated_env = self.type_env.instantiate(unifier)?;
+
+            event!(
+                target: "eql-mapper::TYPE_ENV",
+                Level::TRACE,
+                instantiated = %instantiated_env
+            );
+
+            for (arg, arg_tvar) in args.iter().zip(self.arg_tvars.iter()) {
+                unifier.unify(arg.clone(), instantiated_env.get_type(arg_tvar)?)?;
+            }
+
+            unifier.unify(ret.clone(), instantiated_env.get_type(&self.ret_tvar)?)?;
+
+            event!(
+                target: "eql-mapper::TYPE_ENV",
+                Level::TRACE,
+                env_after_args_and_ret_unified = %instantiated_env
+            );
+
+            Ok(instantiated_env)
+        })();
+
+        if let Ok(ref instantiated_env) = result {
+            span.record("return", tracing::field::display(instantiated_env));
+        }
+
+        result
+    }
+
+    fn check_no_undeclared_generic_args<'a>(
+        generic_args: &'a [TVar],
+        generic_bounds: &'a Vec<BoundsDecl>,
+        args: &'a [TypeDecl],
+        ret: &'a TypeDecl,
+    ) -> Result<(), TypeError> {
+        let is_vardecl = |arg: &'a TypeDecl| {
+            if let TypeDecl::Var(VarDecl { tvar, ..
}) = arg {
+                Some(tvar)
+            } else {
+                None
+            }
+        };
+
+        let check_known = |tvar: &TVar| -> Result<(), TypeError> {
+            if generic_args.contains(tvar) {
+                Ok(())
+            } else {
+                Err(TypeError::InternalError(format!(
+                    "use of undeclared type var '{tvar}'"
+                )))
+            }
+        };
+
+        args.iter()
+            .filter_map(is_vardecl)
+            .try_fold((), |_, tvar| check_known(tvar))?;
+
+        if let Some(tvar) = is_vardecl(ret) {
+            check_known(tvar)?;
+        }
+
+        for bound in generic_bounds {
+            if !generic_args.contains(&bound.0) {
+                return Err(TypeError::InternalError(format!(
+                    "generic bounds reference an undefined type variable '{}'",
+                    bound.0
+                )));
+            }
+        }
+
+        Ok(())
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
+pub(crate) struct NativeDecl(pub(crate) Option<TableColumn>);
+
+impl Display for NativeDecl {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match &self.0 {
+            Some(tc) => f.write_fmt(format_args!("Native({tc})")),
+            None => f.write_fmt(format_args!("Native")),
+        }
+    }
+}
+
+impl InstantiateType for NativeDecl {
+    fn instantiate_in_env(
+        &self,
+        _: &mut Unifier<'_>,
+        _: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        self.instantiate_concrete()
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        match &self.0 {
+            Some(tc) => Ok(Arc::new(Type::Value(Value::Native(NativeValue(Some(
+                tc.clone(),
+            )))))),
+            None => Ok(Arc::new(Type::Value(Value::Native(NativeValue(None))))),
+        }
+    }
+
+    fn instantiate_shallow(&self, _: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        self.instantiate_concrete()
+    }
+}
+
+impl InstantiateType for EqlTerm {
+    fn instantiate_in_env(
+        &self,
+        _: &mut Unifier<'_>,
+        _: &InstantiatedTypeEnv,
+    ) -> Result<Arc<Type>, TypeError> {
+        self.instantiate_concrete()
+    }
+
+    fn instantiate_concrete(&self) -> Result<Arc<Type>, TypeError> {
+        Ok(Arc::new(Type::Value(Value::Eql(self.clone()))))
+    }
+
+    fn instantiate_shallow(&self, _: &mut Unifier<'_>) -> Result<Arc<Type>, TypeError> {
+        self.instantiate_concrete()
+    }
+}
diff --git a/packages/eql-mapper/src/inference/unifier/type_env.rs b/packages/eql-mapper/src/inference/unifier/type_env.rs
new file mode 100644
index 00000000..e5a41532
--- /dev/null
+++ b/packages/eql-mapper/src/inference/unifier/type_env.rs
@@ -0,0 +1,283 @@
+//! Type definitions for constructing a [`TypeEnv`] and subsequently an [`InstantiatedTypeEnv`] from [`TypeDecl`]s.
+//!
+//! A `TypeEnv` is an environment containing `TypeDecl`s. A `TypeDecl` is a mirror of [`Type`] but works symbolically
+//! and supports defining types with dedicated syntax so that constraints can be built declaratively rather
+//! than programmatically.
+#![allow(unused)]
+
+use std::cell::RefCell;
+use std::collections::HashSet;
+use std::fmt::Display;
+use std::hash::Hash;
+use std::rc::Rc;
+use std::{collections::HashMap, sync::Arc};
+
+use derive_more::derive::Deref;
+use sqltk::parser::ast::{Top, WindowFrameBound};
+use topological_sort::TopologicalSort;
+use tracing::{event, instrument, Level};
+
+use crate::unifier::instantiated_type_env::InstantiatedTypeEnv;
+use crate::{TypeError, TypeRegistry};
+
+use super::{
+    ArrayDecl, EqlTraits, ProjectionColumnDecl, ProjectionDecl, Type, TypeDecl, TypeVar, Unifier,
+    VarDecl,
+};
+use super::{InstantiateType, TVar};
+
+/// A collection of [`TypeDecl`]s.
+#[derive(Debug, Clone, Eq, PartialEq)]
+pub(crate) struct TypeEnv {
+    /// The [`TypeDecl`]s in the environment.
+    decls: HashMap<TVar, TypeDecl>,
+    tvar_counter: usize,
+}
+
+impl TypeEnv {
+    pub(crate) fn new() -> Self {
+        Self {
            decls: HashMap::new(),
+            tvar_counter: 0,
+        }
+    }
+
+    /// Builds a [`TypeEnv`] and returns it.
+    ///
+    /// After the supplied closure returns, this method clones the resulting `TypeEnv` and attempts to instantiate it in
+    /// order to verify that it is well-formed. If instantiation is successful then the *uninstantiated* `TypeEnv` is
+    /// returned.
+    ///
+    /// This can be used as a template for initialising [`crate::inference::SqlBinaryOp`] and
+    /// [`crate::inference::SqlFunction`] environments during unification.
+    #[instrument(
+        target = "eql-mapper::TYPE_ENV",
+        skip(f),
+        level = "trace",
+        err(Debug),
+        fields(
+            return = tracing::field::Empty,
+        )
+    )]
+    pub(crate) fn build<Out, F>(mut f: F) -> Result<(Self, Out), TypeError>
+    where
+        F: FnOnce(&mut TypeEnv) -> Result<Out, TypeError>,
+    {
+        let span = tracing::Span::current();
+
+        let result = (|| {
+            let mut type_env = TypeEnv::new();
+            let out = f(&mut type_env)?;
+            let cloned = type_env.clone();
+            let mut unifier = Unifier::new(Rc::new(RefCell::new(TypeRegistry::new())));
+            match cloned.instantiate(&mut unifier) {
+                Ok(_) => Ok((type_env, out)),
+                Err(err) => Err(err),
+            }
+        })();
+
+        if let Ok((ref env, _)) = result {
+            span.record("return", tracing::field::display(env));
+        }
+
+        result
+    }
+
+    pub(crate) fn fresh_tvar(&mut self) -> TVar {
+        let tvar = TVar(format!("${}", self.tvar_counter));
+        self.tvar_counter += 1;
+        tvar
+    }
+
+    pub(crate) fn add_decl(&mut self, tvar: TVar, spec: TypeDecl) -> TVar {
+        self.decls.insert(tvar.clone(), spec);
+        tvar
+    }
+
+    pub(crate) fn add_decl_with_indirection(&mut self, spec: TypeDecl) -> Result<TVar, TypeError> {
+        match spec {
+            TypeDecl::Var(VarDecl { tvar, .. }) => {
+                self.get_decl(&tvar)?;
+                Ok(tvar.clone())
+            }
+            _ => {
+                let tvar = self.fresh_tvar();
+                Ok(self.add_decl(tvar, spec))
+            }
+        }
+    }
+
+    pub(crate) fn get_decl(&self, tvar: &TVar) -> Result<&TypeDecl, TypeError> {
+        match self.decls.get(tvar) {
+            Some(spec) => Ok(spec),
+            None => Err(TypeError::InternalError(format!(
+                "unknown typespec {tvar} in type env"
+            ))),
+        }
+    }
+
+    pub(crate) fn get_bounds(&self, tvar: &TVar) -> Result<EqlTraits, TypeError> {
+        match self.decls.get(tvar) {
+            Some(TypeDecl::Var(VarDecl { bounds, .. })) => Ok(*bounds),
+            Some(_) => Ok(EqlTraits::none()),
+            None => Err(TypeError::InternalError(format!(
+                "tvar {tvar} not found in type env"
+            ))),
+        }
+    }
+
+    /// Builds an [`InstantiatedTypeEnv`] or fails with a [`TypeError`].
+    ///
+    /// 1. All referenced type arguments must be defined in the env.
+    /// 2. All trait bounds must unify (e.g. the type argument to [`EqlTrait::JsonAccessor`] must have an
+    ///    `EqlTrait::Json` bound).
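+    ///
+    /// A sketch of usage, mirroring the tests at the bottom of this file (`type_env!` and `tvar!` come from
+    /// `eql_mapper_macros`; this is illustrative, not a verbatim excerpt):
+    ///
+    /// ```ignore
+    /// let env = type_env! {
+    ///     A = [E];
+    ///     E = Native;
+    /// };
+    /// let instance = env.instantiate(&mut unifier)?;
+    /// let array_ty = instance.get_type(&tvar!(A))?;
+    /// ```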
+    pub(crate) fn instantiate(
+        &self,
+        unifier: &mut Unifier<'_>,
+    ) -> Result<InstantiatedTypeEnv, TypeError> {
+        event!(
+            target: "eql-mapper::TYPE_ENV",
+            Level::TRACE,
+            type_env = %self,
+        );
+
+        let mut tvars = self.tvars_in_order_of_initialisation();
+
+        let mut new_env = InstantiatedTypeEnv::new();
+
+        while let Some(tvar) = tvars.pop() {
+            let spec = self
+                .decls
+                .get(tvar)
+                .ok_or(TypeError::InternalError(format!(
+                    "expected typespec for tvar {tvar} to be in the typeenv"
+                )))?;
+
+            let ty = spec.instantiate_in_env(unifier, &new_env)?;
+            new_env.add_type(tvar.clone(), ty)?;
+        }
+
+        Ok(new_env)
+    }
+
+    fn tvars_in_order_of_initialisation(&self) -> TopologicalSort<&TVar> {
+        let mut topo = TopologicalSort::<&TVar>::new();
+
+        for (tvar, spec) in self.decls.iter() {
+            topo.insert(tvar);
+
+            let dependencies = spec.depends_on();
+
+            for dep in dependencies {
+                topo.add_dependency(dep, tvar);
+            }
+        }
+        topo
+    }
+}
+
+impl Display for TypeEnv {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("TypeEnv{ ")?;
+        for (idx, (tvar, spec)) in self.decls.iter().enumerate() {
+            f.write_fmt(format_args!("{tvar} => {spec}"))?;
+            if idx < self.decls.len() - 1 {
+                f.write_str(", ")?;
+            }
+        }
+        f.write_str(" }")?;
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use std::{cell::RefCell, rc::Rc, sync::Arc};
+
+    use crate::{
+        test_helpers,
+        unifier::{
+            Array, AssociatedType, EqlTerm, EqlTrait, EqlTraits, EqlValue, InstantiateType, Type,
+            Unifier, Value,
+        },
+        NativeValue, TableColumn, TypeError, TypeRegistry,
+    };
+
+    use super::TypeEnv;
+    use eql_mapper_macros::{tvar, ty, type_env};
+    use pretty_assertions::assert_eq;
+
+    fn make_unifier<'a>() -> Unifier<'a> {
+        Unifier::new(Rc::new(RefCell::new(TypeRegistry::new())))
+    }
+
+    #[test]
+    fn build_env_with_array() -> Result<(), TypeError> {
+        let env = type_env! {
+            A = [E];
+            E = T;
+            T = Native;
+        };
+
+        let mut unifier = make_unifier();
+        let instance = env.instantiate(&mut unifier).unwrap();
+
+        let array_ty = instance.get_type(&tvar!(A))?;
+
+        assert_eq!(&*array_ty, &*ty!([Native]).instantiate_concrete()?);
+
+        Ok(())
+    }
+
+    #[test]
+    fn build_env_with_projection() -> Result<(), TypeError> {
+        let env = type_env! {
+            P = {A as id, B as name, C as email};
+            A = Native(customer.id);
+            B = EQL(customer.name: Eq);
+            C = EQL(customer.email: Eq);
+        };
+
+        let mut unifier = make_unifier();
+        let instance = env.instantiate(&mut unifier).unwrap();
+
+        assert_eq!(
+            &*instance.get_type(&tvar!(P))?,
+            &*ty!({
+                Native(customer.id) as id,
+                EQL(customer.name: Eq) as name,
+                EQL(customer.email: Eq) as email}
+            )
+            .instantiate_concrete()?
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn build_env_with_associated_type() -> Result<(), TypeError> {
+        let env = type_env! {
+            E = EQL(customer.name: JsonLike);
+            A = <E as JsonLike>::Accessor;
+        };
+
+        let mut unifier = make_unifier();
+        let instance = env.instantiate(&mut unifier).unwrap();
+
+        if let Type::Associated(associated) = &*instance.get_type(&tvar!(A))?
{
+            assert_eq!(
+                associated.resolve_selector_target(&mut unifier)?.as_deref(),
+                Some(&Type::Value(Value::Eql(EqlTerm::JsonAccessor(EqlValue(
+                    TableColumn {
+                        table: "customer".into(),
+                        column: "name".into()
+                    },
+                    EqlTraits::from(EqlTrait::JsonLike)
+                ),))))
+            );
+        } else {
+            panic!("expected associated type");
+        }
+
+        Ok(())
+    }
+}
diff --git a/packages/eql-mapper/src/inference/unifier/types.rs b/packages/eql-mapper/src/inference/unifier/types.rs
index 4c528eba..e27863f1 100644
--- a/packages/eql-mapper/src/inference/unifier/types.rs
+++ b/packages/eql-mapper/src/inference/unifier/types.rs
@@ -5,28 +5,35 @@ use sqltk::parser::ast::Ident;
 
 use crate::{ColumnKind, Table, TypeError};
 
-use super::Unifier;
+use super::{resolve_type::ResolveType, EqlTrait, EqlTraits, Unifier};
 
-/// The type of an expression in a SQL statement or the type of a table column from the database schema.
+/// The [`Type`] enum represents the types used by the [`Unifier`] to represent the SQL & EQL types returned by
+/// expressions, projection-producing statements, built-in database functions & operators, EQL functions & operators and
+/// table columns.
 ///
-/// An expression can be:
+/// A value of [`Type`] is either a [`Value`] (a fully or partially resolved type), a [`Var`] (a placeholder
+/// for an unresolved type) or an [`AssociatedType`] (an associated type).
 ///
-/// - a [`sqltk::parser::ast::Expr`] node
-/// - a [`sqltk::parser::ast::Statement`] or any other SQL AST node that produces a projection.
-///
-/// A `Type` is either a [`Constructor`] (fully or partially known type) or a [`TypeVar`] (a placeholder for an unknown type).
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
-#[display("{self}")]
+/// After successful unification of all of the types in a SQL statement, the types are converted into the publicly
+/// exported [`crate::Type`] type, which is a mirror of this enum but without type variables, which makes it more
+/// ergonomic to consume.
+#[derive(Debug, PartialEq, PartialOrd, Ord, Eq, Clone, Display, Hash)]
 pub enum Type {
-    /// A specific type constructor with zero or more generic parameters.
+    /// A value type.
     #[display("{}", _0)]
-    Constructor(Constructor),
+    Value(Value),
 
-    /// A type variable representing a placeholder for an unknown type.
+    /// A type representing a placeholder for an unresolved type.
     #[display("{}", _0)]
-    Var(TypeVar),
+    Var(Var),
+
+    /// An associated type declared in an [`EqlTrait`] and implemented by a type that implements the `EqlTrait`.
+    #[display("{}", _0)]
+    Associated(AssociatedType),
 }
 
+// Statically assert that `Type` is `Send + Sync`. If `Type` did not implement `Send` and/or `Sync` this crate would
+// fail to compile anyway but the error message is very obtuse. A failure here makes it obvious.
 const _: () = {
     fn assert_send<T: Send>() {}
     fn assert_sync<T: Sync>() {}
@@ -37,111 +44,72 @@ const _: () = {
     }
 };
 
-/// A `Constructor` is what is known about a [`Type`].
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
-pub enum Constructor {
-    /// An EQL type, an opaque "database native" type or an array type.
-    #[display("{}", _0)]
-    Value(Value),
-
-    /// A projection is a type with a fixed number of columns each of which has a type and optional alias.
-    #[display("{}", _0)]
-    Projection(Projection),
+/// An associated type.
+///
+/// This is a type of the form `T::A`. `T` is the type that implements a trait that defines the associated type. `A` is
+/// the associated type.
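+///
+/// For example (illustrative, borrowing the declaration syntax used in the `type_env!` tests):
+/// `<E as JsonLike>::Accessor` selects the `Accessor` associated type of the `JsonLike` EQL trait as
+/// implemented by the type bound to `E`.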
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, derive_more::Display)]
+#[display("<{} as {}>::{}", impl_ty, selector.eql_trait, selector.type_name)]
+pub struct AssociatedType {
+    /// A value that can resolve the concrete `A` when given a concrete `T`.
+    pub selector: AssociatedTypeSelector,
+
+    /// The type that implements the trait and will have defined an associated type. In `T::A` `impl_ty` is the `T`.
+    pub impl_ty: Arc<Type>,
+
+    /// The associated type itself. In `T::A` `resolved_ty` is the `A`.
+    pub resolved_ty: Arc<Type>,
 }
 
-impl Constructor {
-    fn resolve(&self, unifier: &mut Unifier<'_>) -> Result<crate::Type, TypeError> {
-        match self {
-            Constructor::Value(value) => match value {
-                Value::Eql(eql_col) => Ok(crate::Type::Value(crate::Value::Eql(eql_col.clone()))),
-                Value::Native(NativeValue(Some(native_col))) => Ok(crate::Type::Value(
-                    crate::Value::Native(NativeValue(Some(native_col.clone()))),
-                )),
-                Value::Native(NativeValue(None)) => {
-                    Ok(crate::Type::Value(crate::Value::Native(NativeValue(None))))
-                }
-                Value::Array(element_ty) => {
-                    let resolved = element_ty.resolved(unifier)?;
-                    Ok(crate::Type::Value(crate::Value::Array(resolved.into())))
-                }
-            },
-            Constructor::Projection(projection) => {
-                Ok(crate::Type::Projection(projection.resolve(unifier)?))
+impl AssociatedType {
+    /// Tries to resolve the concrete associated type.
+    ///
+    /// If the parent type that the associated type is attached to is not yet resolved then this method will return
+    /// `Ok(None)`.
+    pub(crate) fn resolve_selector_target(
+        &self,
+        unifier: &mut Unifier<'_>,
+    ) -> Result<Option<Arc<Type>>, TypeError> {
+        let impl_ty = self.impl_ty.clone().follow_tvars(unifier);
+        if let Type::Value(_) = &*impl_ty {
+            // The type that implements the EqlTrait is now known, so resolve the selector.
+            let ty: Arc<Type> = self.selector.resolve(impl_ty.clone())?;
+            Ok(Some(unifier.unify(self.resolved_ty.clone(), ty.clone())?))
+        } else {
+            Ok(None)
         }
     }
 }
 
-impl Projection {
-    fn resolve(&self, unifier: &mut Unifier<'_>) -> Result<crate::Projection, TypeError> {
-        use itertools::Itertools;
-
-        let resolved_cols = self
-            .flatten()
-            .columns()
-            .iter()
-            .map(|col| -> Result<Vec<crate::ProjectionColumn>, TypeError> {
-                let alias = col.alias.clone();
-                match &*col.ty {
-                    Type::Constructor(constructor) => match constructor {
-                        Constructor::Value(Value::Eql(eql_col)) => {
-                            Ok(vec![crate::ProjectionColumn {
-                                ty: crate::Value::Eql(eql_col.clone()),
-                                alias,
-                            }])
-                        }
-                        Constructor::Value(Value::Native(native_col)) => {
-                            Ok(vec![crate::ProjectionColumn {
-                                ty: crate::Value::Native(native_col.clone()),
-                                alias,
-                            }])
-                        }
-                        Constructor::Value(Value::Array(array_ty)) => {
-                            match array_ty.resolved(unifier)? {
-                                elem_ty @ crate::Type::Value(_) => {
-                                    Ok(vec![crate::ProjectionColumn {
-                                        ty: crate::Value::Array(elem_ty.into()),
-                                        alias,
-                                    }])
-                                }
-                                crate::Type::Projection(_) => {
-                                    Err(TypeError::InternalError("projection type as array element".to_string()))
-                                }
-                            }
-                        }
-                        Constructor::Projection(_) => {
-                            Err(TypeError::InternalError("projection type as projection column; projections should be flattened during final resolution".to_string()))
-                        }
-                    },
-                    Type::Var(tvar) => {
-                        let ty = unifier.get_type(*tvar).ok_or(
-                            TypeError::InternalError(format!("could not resolve type variable '{}'", tvar)))?;
-                        if let Type::Constructor(Constructor::Projection(projection)) = &*ty {
-                            match projection.resolve(unifier)? {
-                                crate::Projection::WithColumns(projection_columns) => Ok(projection_columns),
-                                crate::Projection::Empty => Ok(vec![]),
-                            }
-                        } else {
-                            match ty.resolved(unifier)?
{
-                                crate::Type::Value(value) => Ok(vec![crate::ProjectionColumn { ty: value, alias }]),
-                                crate::Type::Projection(_) => Err(TypeError::InternalError("unexpected projection".to_string())),
-                            }
-                        }
-
-                    },
-                }
-            })
-            .flatten_ok()
-            .collect::<Result<Vec<_>, _>>()?;
-
-        if resolved_cols.is_empty() {
-            Ok(crate::Projection::Empty)
+/// A type variable with trait bounds.
+///
+/// Type variables represent an unresolved type. Unification of a concrete type with a type variable will succeed if the
+/// concrete type implements all of the bounds on the type variable. The concrete type is allowed to implement a set of
+/// traits that exceeds the requirements of the bounds on the type variable.
+#[derive(Debug, PartialEq, PartialOrd, Ord, Eq, Clone, Hash)]
+pub struct Var(pub TypeVar, pub EqlTraits);
+
+impl Display for Var {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        if self.1 != EqlTraits::none() {
+            f.write_fmt(format_args!("{}: {}", self.0, self.1))
         } else {
-            Ok(crate::Projection::WithColumns(resolved_cols))
+            f.write_fmt(format_args!("{}", self.0))
         }
     }
 }
 
+/// Represents a SQL `setof` type. Functions such as `jsonb_array_elements` return a `setof`.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
+pub struct SetOf(pub Arc<Type>);
+
+impl SetOf {
+    pub(crate) fn inner_ty(&self) -> Arc<Type> {
+        self.0.clone()
+    }
+}
+
+/// The type of a SQL expression.
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
 pub enum Value {
     /// An encrypted type from a particular table-column in the schema.
@@ -149,7 +117,7 @@ pub enum Value {
     /// An encrypted column never shares a type with another encrypted column - which is why it is sufficient to
     /// identify the type by its table & column names.
     #[display("{}", _0)]
-    Eql(EqlValue),
+    Eql(EqlTerm),
 
     /// A native database type that carries its table & column name. `NativeValue(None)` & `NativeValue(Some(_))`
     /// will successfully unify with each other - they are the same type as far as the type system is concerned.
@@ -159,7 +127,88 @@ pub enum Value {
     /// An array type that is parameterized by an element type.
     #[display("Array[{}]", _0)]
-    Array(Arc<Type>),
+    Array(Array),
+
+    /// A projection is a type with a fixed number of columns each of which has a type and optional alias.
+    #[display("{}", _0)]
+    Projection(Projection),
+
+    /// In PostgreSQL, SETOF is a special return type used in functions to indicate that the function returns a set of
+    /// rows rather than a single value. It allows a function to behave like a table or subquery in SQL, producing
+    /// multiple rows as output.
+    #[display("{}", _0)]
+    SetOf(SetOf),
+}
+
+/// An array of some type.
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
+pub struct Array(pub Arc<Type>);
+
+/// An `EqlTerm` is a type associated with a particular EQL type, i.e. an [`EqlValue`].
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
+pub enum EqlTerm {
+    /// This type represents the entire EQL payload (ciphertext + all encrypted search terms). It is suitable both for
+    /// `INSERT`ing new records and for querying against.
+    #[display("EQL:Full({})", _0)]
+    Full(EqlValue),
+
+    /// This type represents an EQL payload with exactly the encrypted search terms required in order to satisfy its
+    /// [`Bounds`].
+    #[display("EQL:Partial({}: {})", _0, _1)]
+    Partial(EqlValue, EqlTraits),
+
+    /// A JSON field or array index.
The inferred type of the right hand side of the `->` operator when the
+    /// left hand side is an [`EqlValue`] that implements the EQL trait `JsonLike`.
+    JsonAccessor(EqlValue),
+
+    /// A JSON path. The inferred type of the second argument to functions such as `jsonb_path_query` when the first
+    /// argument is an [`EqlValue`] that implements the EQL trait `JsonLike`.
+    JsonPath(EqlValue),
+
+    /// A text value that can be used as the right hand side of `LIKE` or `ILIKE` when the left hand side is an
+    /// [`EqlValue`] that implements the EQL trait `TokenMatch`.
+    Tokenized(EqlValue),
+}
+
+impl EqlTerm {
+    pub fn table_column(&self) -> &TableColumn {
+        match self {
+            EqlTerm::Full(eql_value)
+            | EqlTerm::Partial(eql_value, _)
+            | EqlTerm::JsonAccessor(eql_value)
+            | EqlTerm::JsonPath(eql_value)
+            | EqlTerm::Tokenized(eql_value) => eql_value.table_column(),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Hash, PartialEq, Eq, PartialOrd, Ord, derive_more::Display)]
+#[display("{eql_trait}::{type_name}")]
+pub struct AssociatedTypeSelector {
+    pub eql_trait: EqlTrait,
+    pub type_name: &'static str,
+}
+
+impl AssociatedTypeSelector {
+    pub(crate) fn new(
+        eql_trait: EqlTrait,
+        associated_type_name: &'static str,
+    ) -> Result<Self, TypeError> {
+        if eql_trait.has_associated_type(associated_type_name) {
+            Ok(Self {
+                eql_trait,
+                type_name: associated_type_name,
+            })
+        } else {
+            Err(TypeError::InternalError(format!(
+                "Trait {eql_trait} does not define associated type {associated_type_name}"
+            )))
+        }
+    }
+
+    pub(crate) fn resolve(&self, ty: Arc<Type>) -> Result<Arc<Type>, TypeError> {
+        Ok(self.eql_trait.resolve_associated_type(ty, self)?.clone())
+    }
 }
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
@@ -171,10 +220,10 @@ pub struct TableColumn {
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
 #[display("EQL({})", _0)]
-pub struct EqlValue(pub TableColumn);
+pub struct EqlValue(pub TableColumn, pub EqlTraits);
 
 #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
-#[display("NATIVE{}", _0.as_ref().map(|tc| format!("({})", tc)).unwrap_or(String::from("")))]
+#[display("{}", _0.as_ref().map(|tc| format!("({tc})")).unwrap_or(String::from("")))]
 pub struct NativeValue(pub Option<TableColumn>);
 
 /// A column from a projection.
@@ -195,78 +244,85 @@ pub struct TypeVar(pub usize);
 
 impl From<TypeVar> for Type {
     fn from(tvar: TypeVar) -> Self {
-        Type::Var(tvar)
+        Type::Var(Var(tvar, EqlTraits::none()))
     }
 }
 
 impl Type {
-    /// Creates a `Type` containing an empty projection
-    pub(crate) fn empty_projection() -> Type {
-        Type::Constructor(Constructor::Projection(Projection::Empty))
+    /// Creates a `Type::Value(Value::Projection(...))` with no columns.
+    pub(crate) const fn empty_projection() -> Type {
+        Type::Value(Value::Projection(Projection(vec![])))
+    }
+
+    /// Creates a `Type::Value(Value::Native(NativeValue(None)))`.
+    pub(crate) const fn native() -> Type {
+        Type::Value(Value::Native(NativeValue(None)))
     }
 
-    /// Creates a `Type` containing a `Constructor::Scalar(Scalar::Native(NativeValue(None)))`.
-    pub(crate) fn any_native() -> Type {
-        Type::Constructor(Constructor::Value(Value::Native(NativeValue(None))))
+    /// Creates a `Type::Value(Value::SetOf(ty))`.
+    pub(crate) const fn set_of(ty: Arc<Type>) -> Type {
+        Type::Value(Value::SetOf(SetOf(ty)))
    }
 
-    /// Creates a `Type` containing a `Constructor::Projection`.
+    /// Creates a `Type::Value(Value::Projection(Projection(columns)))`.
     pub(crate) fn projection(columns: &[(Arc<Type>, Option<Ident>)]) -> Type {
-        if columns.is_empty() {
-            Type::Constructor(Constructor::Projection(Projection::Empty))
-        } else {
-            Type::Constructor(Constructor::Projection(Projection::WithColumns(
-                ProjectionColumns(
-                    columns
-                        .iter()
-                        .map(|(c, n)| ProjectionColumn::new(c.clone(), n.clone()))
-                        .collect(),
-                ),
-            )))
-        }
+        Type::Value(Value::Projection(Projection(
+            columns
+                .iter()
+                .map(|(c, n)| ProjectionColumn::new(c.clone(), n.clone()))
+                .collect(),
+        )))
     }

-    /// Creates a `Type` containing a `Constructor::Array`.
+    /// Creates a `Type::Value(Value::Array(element_ty))`.
     pub(crate) fn array(element_ty: impl Into<Arc<Type>>) -> Arc<Type> {
-        Type::Constructor(Constructor::Value(Value::Array(element_ty.into()))).into()
+        Type::Value(Value::Array(Array(element_ty.into()))).into()
     }

-    /// Follows [`Type::Var`] types until a [`Type::Constructor`] is reached. Aborts and returns the last resolved type
-    /// when either a type variable has no substitution or it resolves to a constructor is found.
     pub(crate) fn follow_tvars(self: Arc<Self>, unifier: &Unifier<'_>) -> Arc<Type> {
-        let mut current_ty = self;
-
-        loop {
-            match &*current_ty {
-                Type::Constructor(Constructor::Projection(Projection::WithColumns(
-                    ProjectionColumns(cols),
-                ))) => {
-                    let cols = cols
-                        .iter()
-                        .map(|col| ProjectionColumn {
-                            ty: col.ty.clone().follow_tvars(unifier),
-                            alias: col.alias.clone(),
-                        })
-                        .collect();
-                    return Arc::new(Type::Constructor(Constructor::Projection(
-                        Projection::WithColumns(ProjectionColumns(cols)),
-                    )));
-                }
-                Type::Constructor(Constructor::Projection(Projection::Empty)) => return current_ty,
-                Type::Constructor(Constructor::Value(Value::Array(ty))) => {
-                    return Arc::new(Type::Constructor(Constructor::Value(Value::Array(
-                        ty.clone().follow_tvars(unifier),
-                    ))))
-                }
-                Type::Constructor(Constructor::Value(_)) => return current_ty,
-                Type::Var(tvar) => {
-                    if let Some(ty) = unifier.get_type(*tvar) {
-                        current_ty = ty.follow_tvars(unifier);
-                    } else {
-                        return current_ty;
-                    }
+        match &*self.clone() {
+            Type::Value(Value::Projection(Projection(cols))) => {
+                let cols = cols
+                    .iter()
+                    .map(|col| ProjectionColumn {
+                        ty: col.ty.clone().follow_tvars(unifier),
+                        alias: col.alias.clone(),
+                    })
+                    .collect();
+                Projection(cols).into()
+            }
+
+            Type::Value(Value::Array(Array(ty))) => Arc::new(Type::Value(Value::Array(Array(
+                ty.clone().follow_tvars(unifier),
+            )))),
+
+            Type::Value(Value::SetOf(SetOf(ty))) => ty.clone().follow_tvars(unifier),
+
+            Type::Value(_) => self,
+
+            Type::Var(Var(tvar, _)) => {
+                if let Some(ty) = unifier.get_type(*tvar) {
+                    ty.follow_tvars(unifier)
+                } else {
+                    self
                 }
             }
+
+            Type::Associated(AssociatedType {
+                impl_ty,
+                resolved_ty,
+                selector,
+            }) => {
+                let impl_ty = impl_ty.clone().follow_tvars(unifier);
+                let resolved_ty = resolved_ty.clone().follow_tvars(unifier);
+
+                Type::Associated(AssociatedType {
+                    impl_ty,
+                    resolved_ty,
+                    selector: selector.clone(),
+                })
+                .into()
+            }
         }
     }

@@ -276,40 +332,59 @@ impl Type {
     ///
     /// Fails with a [`TypeError`] if the stored `Type` cannot be fully resolved.
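+    ///
+    /// Editor's sketch (illustrative only; assumes a `unifier` in scope with no
+    /// substitutions recorded):
+    ///
+    /// ```ignore
+    /// let ty = Type::from(TypeVar(0));
+    /// // An unsubstituted type variable cannot be fully resolved:
+    /// assert!(ty.resolved(&mut unifier).is_err());
+    /// ```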
     pub fn resolved(&self, unifier: &mut Unifier<'_>) -> Result<crate::Type, TypeError> {
-        match self {
-            Type::Constructor(constructor) => constructor.resolve(unifier),
-            Type::Var(type_var) => {
-                if let Some(sub_ty) = unifier.get_type(*type_var) {
-                    return sub_ty.resolved(unifier);
-                }
-
-                Err(TypeError::Incomplete(format!(
-                    "there are no substitutions for type var {}",
-                    type_var
-                )))
-            }
-        }
+        self.resolve_type(unifier)
     }

     pub(crate) fn resolved_as<T: Clone + 'static>(
         &self,
         unifier: &mut Unifier<'_>,
     ) -> Result<T, TypeError> {
-        let resolved_ty: crate::Type = self.resolved(unifier)?;
+        let resolved_ty: crate::Type = self.resolve_type(unifier)?;

         let result = match &resolved_ty {
-            crate::Type::Projection(projection) => {
+            crate::Type::Value(crate::Value::Projection(projection)) => {
                 if let Some(t) = (projection as &dyn std::any::Any).downcast_ref::<T>() {
                     return Ok(t.clone());
                 }
                 Err(())
             }
-            crate::Type::Value(value) => {
-                if let Some(t) = (value as &dyn std::any::Any).downcast_ref::<T>() {
+            crate::Type::Value(crate::Value::SetOf(ty)) => {
+                if let Some(t) = (ty as &dyn std::any::Any).downcast_ref::<T>() {
                     return Ok(t.clone());
                 }
+                Err(())
+            }
+            crate::Type::Value(value) => {
+                match value {
+                    crate::Value::Eql(maybe_t) => {
+                        if let Some(t) = (maybe_t as &dyn std::any::Any).downcast_ref::<T>() {
+                            return Ok(t.clone());
+                        }
+                    }
+                    crate::Value::Native(maybe_t) => {
+                        if let Some(t) = (maybe_t as &dyn std::any::Any).downcast_ref::<T>() {
+                            return Ok(t.clone());
+                        }
+                    }
+                    crate::Value::Array(maybe_t) => {
+                        if let Some(t) = (maybe_t as &dyn std::any::Any).downcast_ref::<T>() {
+                            return Ok(t.clone());
+                        }
+                    }
+                    crate::Value::Projection(maybe_t) => {
+                        if let Some(t) = (maybe_t as &dyn std::any::Any).downcast_ref::<T>() {
+                            return Ok(t.clone());
+                        }
+                    }
+                    crate::Value::SetOf(maybe_t) => {
+                        if let Some(t) = (maybe_t as &dyn std::any::Any).downcast_ref::<T>() {
+                            return Ok(t.clone());
+                        }
+                    }
+                }
+
                 Err(())
             }
         };

@@ -322,61 +397,88 @@ impl Type {
             ))
         })
     }
+
+    pub(crate) fn must_implement(&self, bounds: &EqlTraits) -> Result<(), TypeError> {
+        if self.effective_bounds().intersection(bounds) == *bounds {
+            Ok(())
+        } else {
+            Err(TypeError::UnsatisfiedBounds(
+                Arc::new(self.clone()),
+                self.effective_bounds().difference(bounds),
+            ))
+        }
+    }
 }

-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
-#[display("PROJ[{}]", _0.iter().map(|pc| pc.to_string()).collect::<Vec<String>>().join(", "))]
-pub struct ProjectionColumns(pub(crate) Vec<ProjectionColumn>);
+impl EqlValue {
+    pub fn table_column(&self) -> &TableColumn {
+        &self.0
+    }
+
+    pub fn trait_impls(&self) -> EqlTraits {
+        self.1
+    }
+}

 /// The type of an [`sqltk::parser::ast::Expr`] or [`sqltk::parser::ast::Statement`] that returns a projection.
 ///
 /// It represents an ordered list of zero or more optionally aliased column types.
-#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Display, Hash)]
-pub enum Projection {
-    /// A projection with columns
-    #[display("{}", _0)]
-    WithColumns(ProjectionColumns),
-
-    /// A projection without columns.
-    ///
-    /// An `INSERT`, `UPDATE` or `DELETE` statement without a `RETURNING` clause will have an empty projection.
-    ///
-    /// Also statements such as `SELECT FROM users` where there are no selected columns or wildcards will have an empty
-    /// projection.
-    #[display("PROJ[]")]
-    Empty,
+///
+/// An `INSERT`, `UPDATE` or `DELETE` statement without a `RETURNING` clause will have an empty projection.
+///
+/// Also statements such as `SELECT FROM users` where there are no selected columns or wildcards will have an empty
+/// projection.
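+///
+/// Editor's note (illustrative): with the old `Projection::Empty` variant gone, the empty
+/// case is simply a projection with zero columns:
+///
+/// ```ignore
+/// assert_eq!(
+///     Type::empty_projection(),
+///     Type::Value(Value::Projection(Projection(vec![])))
+/// );
+/// ```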
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Hash)]
+pub struct Projection(pub Vec<ProjectionColumn>);
+
+impl Display for Projection {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.write_str("{")?;
+        for (idx, col) in self.0.iter().enumerate() {
+            col.fmt(f)?;
+            if idx < self.0.len() - 1 {
+                f.write_str(", ")?;
+            }
+        }
+        f.write_str("}")
+    }
 }

 impl Projection {
     pub fn new(columns: Vec<ProjectionColumn>) -> Self {
-        if columns.is_empty() {
-            Projection::Empty
-        } else {
-            Projection::WithColumns(ProjectionColumns(Vec::from_iter(columns.iter().cloned())))
-        }
+        Self(columns)
     }

-    pub(crate) fn flatten(&self) -> Self {
-        match self {
-            Projection::WithColumns(projection_columns) => {
-                Projection::WithColumns(projection_columns.flatten())
-            }
-            Projection::Empty => Projection::Empty,
-        }
+    pub(crate) fn new_from_schema_table(table: Arc<Table>) -> Self {
+        Self(
+            table
+                .columns
+                .iter()
+                .map(|col| {
+                    let tc = TableColumn {
+                        table: table.name.clone(),
+                        column: col.name.clone(),
+                    };
+
+                    let value_ty = match &col.kind {
+                        ColumnKind::Native => Type::Value(Value::Native(NativeValue(Some(tc)))),
+                        ColumnKind::Eql(features) => {
+                            Type::Value(Value::Eql(EqlTerm::Full(EqlValue(tc, *features))))
+                        }
+                    };
+
+                    ProjectionColumn::new(value_ty, Some(col.name.clone()))
+                })
+                .collect(),
+        )
     }

     pub(crate) fn len(&self) -> usize {
-        match self {
-            Projection::WithColumns(projection_columns) => projection_columns.len(),
-            Projection::Empty => 0,
-        }
+        self.0.len()
     }

     pub(crate) fn columns(&self) -> &[ProjectionColumn] {
-        match self {
-            Projection::WithColumns(projection_columns) => projection_columns.0.as_slice(),
-            Projection::Empty => &[],
-        }
+        &self.0
     }
 }

@@ -384,40 +486,16 @@ impl Index<usize> for Projection {
     type Output = ProjectionColumn;

     fn index(&self, index: usize) -> &Self::Output {
-        match self {
-            Projection::WithColumns(projection_columns) => &projection_columns.0[index],
-            Projection::Empty => panic!("cannot index into an empty projection"),
-        }
-    }
-}
-
-impl ProjectionColumns {
-    pub(crate) fn len(&self) -> usize {
-        self.0.len()
-    }
-
-    pub(crate) fn flatten(&self) -> Self {
-        ProjectionColumns(self.flatten_impl(Vec::with_capacity(self.len())))
-    }
-
-    fn flatten_impl(&self, mut output: Vec<ProjectionColumn>) -> Vec<ProjectionColumn> {
-        for ProjectionColumn { ty, alias } in &self.0 {
-            match &**ty {
-                Type::Constructor(Constructor::Projection(Projection::WithColumns(nested))) => {
-                    output = nested.flatten_impl(output);
-                }
-                _ => output.push(ProjectionColumn::new(ty.clone(), alias.clone())),
-            }
-        }
-        output
+        &self.0[index]
     }
 }

 impl ProjectionColumn {
     /// Returns a new `ProjectionColumn` with type `ty` and optional `alias`.
     pub(crate) fn new(ty: impl Into<Arc<Type>>, alias: Option<Ident>) -> Self {
+        let ty: Arc<Type> = ty.into();
         Self {
-            ty: ty.into(),
+            ty: ty.clone(),
             alias,
         }
     }
@@ -430,27 +508,62 @@ impl ProjectionColumn {
     }
 }

-impl ProjectionColumns {
-    pub(crate) fn new_from_schema_table(table: Arc<Table>) -> Self {
-        ProjectionColumns(
-            table
-                .columns
-                .iter()
-                .map(|col| {
-                    let tc = TableColumn {
-                        table: table.name.clone(),
-                        column: col.name.clone(),
-                    };
+macro_rules! impl_from_for_arc_type {
+    ($ty:ty) => {
+        impl From<$ty> for Arc<Type> {
+            fn from(value: $ty) -> Self {
+                Arc::new(Type::from(value))
+            }
+        }
+    };
+}

-                    let value_ty = if col.kind == ColumnKind::Native {
-                        Type::Constructor(Constructor::Value(Value::Native(NativeValue(Some(tc)))))
-                    } else {
-                        Type::Constructor(Constructor::Value(Value::Eql(EqlValue(tc))))
-                    };
+impl_from_for_arc_type!(NativeValue);
+impl_from_for_arc_type!(Projection);
+impl_from_for_arc_type!(Var);
+impl_from_for_arc_type!(EqlTerm);
+impl_from_for_arc_type!(Value);
+impl_from_for_arc_type!(Array);
+impl_from_for_arc_type!(AssociatedType);
+
+impl From<AssociatedType> for Type {
+    fn from(associated: AssociatedType) -> Self {
+        Type::Associated(associated)
+    }
+}

-                    ProjectionColumn::new(value_ty, Some(col.name.clone()))
-                })
-                .collect(),
-        )
+impl From<Value> for Type {
+    fn from(value: Value) -> Self {
+        Type::Value(value)
+    }
+}
+
+impl From<EqlTerm> for Type {
+    fn from(eql_term: EqlTerm) -> Self {
+        Type::Value(Value::Eql(eql_term))
+    }
+}
+
+impl From<Var> for Type {
+    fn from(var: Var) -> Self {
+        Type::Var(var)
+    }
+}
+
+impl From<Projection> for Type {
+    fn from(projection: Projection) -> Self {
+        Type::Value(Value::Projection(projection))
+    }
+}
+
+impl From<NativeValue> for Type {
+    fn from(native: NativeValue) -> Self {
+        Type::Value(Value::Native(native))
+    }
+}
+
+impl From<Array> for Type {
+    fn from(array: Array) -> Self {
+        Type::Value(Value::Array(array))
     }
 }
diff --git a/packages/eql-mapper/src/inference/unifier/unify_types.rs b/packages/eql-mapper/src/inference/unifier/unify_types.rs
new file mode 100644
index 00000000..44af5732
--- /dev/null
+++ b/packages/eql-mapper/src/inference/unifier/unify_types.rs
@@ -0,0 +1,260 @@
+//! The [`UnifyTypes`] trait definition and all of the implementations.
+//!
+//! The entry point for [`Type`] unification is [`Unifier::unify`], which is an inherent method on the [`Unifier`] itself
+//! and not part of the `UnifyTypes` trait.
+
+use std::sync::Arc;
+
+use crate::{unifier::SetOf, TypeError};
+
+use super::{
+    Array, AssociatedType, EqlTerm, NativeValue, Projection, ProjectionColumn, Type, Unifier,
+    Value, Var,
+};
+
+/// Trait for unifying two types.
+///
+/// The `Lhs` and `Rhs` type arguments are independently specifiable because some different base types (such as `Var` +
+/// `Constructor` and `Value` + `Projection`) can be unified.
+pub(super) trait UnifyTypes<Lhs, Rhs> {
+    /// Try to unify types `lhs` & `rhs` to produce a new [`Type`].
+    fn unify_types(&mut self, lhs: &Lhs, rhs: &Rhs) -> Result<Arc<Type>, TypeError>;
+}
+
+impl UnifyTypes<SetOf, SetOf> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &SetOf, rhs: &SetOf) -> Result<Arc<Type>, TypeError> {
+        Ok(Type::set_of(self.unify(lhs.inner_ty(), rhs.inner_ty())?).into())
+    }
+}
+
+// A Value can be unified with a single-column Projection.
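+//
+// Editor's sketch (illustrative, assumes a `unifier: Unifier` in scope): a one-column
+// projection unwraps to its column's type; a wider projection is a `TypeError::Conflict`.
+//
+//     let v: Arc<Type> = NativeValue(None).into();
+//     let p: Arc<Type> = Projection(vec![ProjectionColumn::new(v.clone(), None)]).into();
+//     assert!(unifier.unify(v, p).is_ok()); // resolves to the single column's value type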
+impl UnifyTypes<Value, Projection> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Value, rhs: &Projection) -> Result<Arc<Type>, TypeError> {
+        let len = rhs.len();
+        if len == 1 {
+            self.unify_types(lhs, &rhs[0].ty)
+        } else {
+            Err(TypeError::Conflict(format!(
+                "cannot unify value type {lhs} with projection {rhs} because it has {len} columns (expected exactly 1)"
+            )))
+        }
+    }
+}
+
+impl UnifyTypes<Value, Arc<Type>> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Value, rhs: &Arc<Type>) -> Result<Arc<Type>, TypeError> {
+        self.unify(lhs.clone().into(), rhs.clone())
+    }
+}
+
+impl UnifyTypes<Value, Value> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Value, rhs: &Value) -> Result<Arc<Type>, TypeError> {
+        match (lhs, rhs) {
+            (Value::Eql(lhs), Value::Eql(rhs)) => self.unify_types(lhs, rhs),
+
+            (Value::Native(lhs), Value::Native(rhs)) => self.unify_types(lhs, rhs),
+
+            (Value::Array(lhs), Value::Array(rhs)) => self.unify_types(lhs, rhs),
+
+            (Value::Projection(lhs), Value::Projection(rhs)) => self.unify_types(lhs, rhs),
+
+            (Value::SetOf(lhs), Value::SetOf(rhs)) => self.unify_types(lhs, rhs),
+
+            // Special case: a value can be unified with a single-column projection (producing a value).
+            (value, Value::Projection(projection)) | (Value::Projection(projection), value) => {
+                self.unify_types(value, projection)
+            }
+
+            (lhs, rhs) => Err(TypeError::Conflict(format!(
+                "cannot unify values {lhs} and {rhs}"
+            ))),
+        }
+    }
+}
+
+impl UnifyTypes<Array, Array> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Array, rhs: &Array) -> Result<Arc<Type>, TypeError> {
+        let Array(lhs_element_ty) = lhs;
+        let Array(rhs_element_ty) = rhs;
+
+        self.unify(lhs_element_ty.clone(), rhs_element_ty.clone())
+    }
+}
+
+impl UnifyTypes<EqlTerm, EqlTerm> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &EqlTerm, rhs: &EqlTerm) -> Result<Arc<Type>, TypeError> {
+        match (lhs, rhs) {
+            (EqlTerm::Full(lhs), EqlTerm::Full(rhs)) if lhs == rhs => {
+                Ok(EqlTerm::Full(lhs.clone()).into())
+            }
+
+            (EqlTerm::Partial(lhs_eql, lhs_bounds), EqlTerm::Partial(rhs_eql, rhs_bounds))
+                if lhs_eql == rhs_eql =>
+            {
+                Ok(EqlTerm::Partial(lhs_eql.clone(), lhs_bounds.union(rhs_bounds)).into())
+            }
+
+            (EqlTerm::Full(full), EqlTerm::Partial(partial, _))
+            | (EqlTerm::Partial(partial, _), EqlTerm::Full(full))
+                if full == partial =>
+            {
+                Ok(EqlTerm::Full(full.clone()).into())
+            }
+
+            (EqlTerm::JsonAccessor(lhs), EqlTerm::JsonAccessor(rhs)) if lhs == rhs => {
+                Ok(EqlTerm::JsonAccessor(lhs.clone()).into())
+            }
+
+            (EqlTerm::JsonPath(lhs), EqlTerm::JsonPath(rhs)) if lhs == rhs => {
+                Ok(EqlTerm::JsonPath(lhs.clone()).into())
+            }
+
+            (EqlTerm::Tokenized(lhs), EqlTerm::Tokenized(rhs)) if lhs == rhs => {
+                Ok(EqlTerm::Tokenized(lhs.clone()).into())
+            }
+
+            (_, _) => Err(TypeError::Conflict(format!(
+                "cannot unify EQL terms {lhs} and {rhs}"
+            ))),
+        }
+    }
+}
+
+impl UnifyTypes<Value, Var> for Unifier<'_> {
+    fn unify_types(
+        &mut self,
+        lhs: &Value,
+        Var(tvar, bounds): &Var,
+    ) -> Result<Arc<Type>, TypeError> {
+        self.unify_with_type_var(Type::Value(lhs.clone()).into(), *tvar, bounds)
+    }
+}
+
+impl UnifyTypes<AssociatedType, Var> for Unifier<'_> {
+    fn unify_types(
+        &mut self,
+        associated: &AssociatedType,
+        var: &Var,
+    ) -> Result<Arc<Type>, TypeError> {
+        if let Some(resolved_ty) = associated.resolve_selector_target(self)? {
+            self.unify(resolved_ty, var.clone().into())
+        } else {
+            Ok(AssociatedType {
+                impl_ty: associated.impl_ty.clone(),
+                selector: associated.selector.clone(),
+                resolved_ty: self.unify(associated.resolved_ty.clone(), var.clone().into())?,
+            }
+            .into())
+        }
+    }
+}
+
+impl UnifyTypes<AssociatedType, AssociatedType> for Unifier<'_> {
+    fn unify_types(
+        &mut self,
+        lhs: &AssociatedType,
+        rhs: &AssociatedType,
+    ) -> Result<Arc<Type>, TypeError> {
+        Ok(AssociatedType {
+            impl_ty: self.unify(lhs.impl_ty.clone(), rhs.impl_ty.clone())?,
+            selector: if lhs.selector == rhs.selector {
+                lhs.selector.clone()
+            } else {
+                Err(TypeError::Conflict(format!(
+                    "cannot unify associated types {lhs} and {rhs}"
+                )))?
+            },
+            resolved_ty: self.unify(lhs.resolved_ty.clone(), rhs.resolved_ty.clone())?,
+        }
+        .into())
+    }
+}
+
+impl UnifyTypes<AssociatedType, Value> for Unifier<'_> {
+    fn unify_types(
+        &mut self,
+        assoc: &AssociatedType,
+        value: &Value,
+    ) -> Result<Arc<Type>, TypeError> {
+        // If the associated type is resolved then unify the resolved value with the value arg, else unify to a
+        // new associated type where the unresolved type is unified with the value.
+
+        if let Some(resolved_value) = assoc.resolve_selector_target(self)? {
+            self.unify(value.clone().into(), resolved_value)
+        } else {
+            Ok(AssociatedType {
+                impl_ty: assoc.impl_ty.clone(),
+                selector: assoc.selector.clone(),
+                resolved_ty: self.unify(assoc.resolved_ty.clone(), value.clone().into())?,
+            }
+            .into())
+        }
+    }
+}
+
+impl UnifyTypes<Var, Var> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Var, rhs: &Var) -> Result<Arc<Type>, TypeError> {
+        let Var(lhs_tvar, lhs_bounds) = lhs;
+        let Var(rhs_tvar, rhs_bounds) = rhs;
+
+        match (self.get_type(*lhs_tvar), self.get_type(*rhs_tvar)) {
+            (None, None) => {
+                let merged_bounds = lhs_bounds.union(rhs_bounds);
+                let unified = self.fresh_bounded_tvar(merged_bounds);
+                self.substitute(*lhs_tvar, unified.clone());
+                self.substitute(*rhs_tvar, unified.clone());
+                Ok(unified)
+            }
+
+            (None, Some(rhs)) => {
+                self.satisfy_bounds(&rhs, lhs_bounds)?;
+                self.substitute(*lhs_tvar, rhs.clone());
+                Ok(rhs)
+            }
+
+            (Some(lhs), None) => {
+                self.satisfy_bounds(&lhs, rhs_bounds)?;
+                self.substitute(*rhs_tvar, lhs.clone());
+                Ok(lhs)
+            }
+
+            (Some(lhs), Some(rhs)) => self.unify(lhs, rhs),
+        }
+    }
+}
+
+impl UnifyTypes<Projection, Projection> for Unifier<'_> {
+    fn unify_types(&mut self, lhs: &Projection, rhs: &Projection) -> Result<Arc<Type>, TypeError> {
+        if lhs.len() == rhs.len() {
+            let mut cols: Vec<ProjectionColumn> = Vec::with_capacity(lhs.len());
+
+            for (lhs_col, rhs_col) in lhs.columns().iter().zip(rhs.columns()) {
+                let unified_ty = self.unify(lhs_col.ty.clone(), rhs_col.ty.clone())?;
+                cols.push(ProjectionColumn::new(unified_ty, lhs_col.alias.clone()));
+            }
+
+            Ok(Projection::new(cols).into())
+        } else {
+            Err(TypeError::Conflict(format!(
+                "cannot unify projections {lhs} and {rhs} because they have different numbers of columns"
+            )))
+        }
+    }
+}
+
+impl UnifyTypes<NativeValue, NativeValue> for Unifier<'_> {
+    fn unify_types(
+        &mut self,
+        lhs: &NativeValue,
+        rhs: &NativeValue,
+    ) -> Result<Arc<Type>, TypeError> {
+        match (lhs, rhs) {
+            (NativeValue(Some(_)), NativeValue(Some(_)))
+            | (NativeValue(Some(_)), NativeValue(None)) => Ok(Type::from(lhs.clone()).into()),
+
+            (NativeValue(None), NativeValue(Some(_))) => Ok(Type::from(rhs.clone()).into()),
+
+            _ => Ok(Type::from(lhs.clone()).into()),
+        }
+    }
+}
diff --git a/packages/eql-mapper/src/lib.rs b/packages/eql-mapper/src/lib.rs
index d8d832e9..38a7500c 100644
--- a/packages/eql-mapper/src/lib.rs
+++ b/packages/eql-mapper/src/lib.rs
@@ -20,7 +20,7 @@ pub use eql_mapper::*;
 pub use model::*;
 pub use param::*;
 pub
use type_checked_statement::*;
-pub use unifier::{EqlValue, NativeValue, TableColumn};
+pub use unifier::{EqlTerm, EqlTrait, EqlTraits, EqlValue, NativeValue, TableColumn};

 pub(crate) use dep::*;
 pub(crate) use inference::*;
@@ -29,26 +29,19 @@ pub(crate) use transformation_rules::*;

 #[cfg(test)]
 mod test {
-    use super::test_helpers::*;
-    use super::type_check;
-    use crate::col;
-    use crate::projection;
-    use crate::test_helpers;
-    use crate::Param;
-    use crate::Schema;
-    use crate::TableResolver;
+    use super::{test_helpers::*, type_check};
     use crate::{
-        schema, unifier::EqlValue, unifier::NativeValue, Projection, ProjectionColumn, TableColumn,
-        Value,
+        projection, schema, test_helpers,
+        unifier::{EqlTerm, EqlTrait, EqlTraits, EqlValue, InstantiateType, NativeValue},
+        Param, Projection, ProjectionColumn, Schema, TableColumn, TableResolver, Value,
     };
+    use eql_mapper_macros::concrete_ty;
     use pretty_assertions::assert_eq;
-    use sqltk::parser::ast::Ident;
-    use sqltk::parser::ast::Statement;
-    use sqltk::parser::ast::{self as ast};
-    use sqltk::AsNodeKey;
-    use sqltk::NodeKey;
-    use std::collections::HashMap;
-    use std::sync::Arc;
+    use sqltk::{
+        parser::ast::{self as ast, Ident, Statement},
+        AsNodeKey, NodeKey,
+    };
+    use std::{collections::HashMap, sync::Arc};
     use tracing::error;

     fn resolver(schema: Schema) -> Arc<TableResolver> {
@@ -74,7 +67,7 @@
             Ok(typed) => {
                 assert_eq!(
                     typed.projection,
-                    projection![(NATIVE(users.email) as email)]
+                    concrete_ty!({ Native(users.email) as email } as crate::Projection)
                 )
             }
             Err(err) => panic!("type check failed: {err}"),
@@ -88,7 +81,7 @@
             tables: {
                 users: {
                     id,
-                    email (EQL),
+                    email (EQL: Eq),
                     first_name,
                 }
             }
@@ -98,17 +91,24 @@

         match type_check(schema, &statement) {
             Ok(typed) => {
-                assert_eq!(typed.projection, projection![(EQL(users.email) as email)]);
-
-                eprintln!("TYPED LITS: {:#?}", typed.literals);
+                assert_eq!(
+                    typed.projection,
+                    concrete_ty!
{{EQL(users.email: Eq) as email} as crate::Projection} + ); - assert!(typed.literals.contains(&( - EqlValue(TableColumn { - table: id("users"), - column: id("email") - }), - &ast::Value::SingleQuotedString("hello@cipherstash.com".into()) - ))); + assert_eq!( + typed.literals, + vec![( + EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("email"), + }, + EqlTraits::from(EqlTrait::Eq) + ),), + &ast::Value::SingleQuotedString("hello@cipherstash.com".into()), + )] + ); } Err(err) => panic!("type check failed: {err}"), } @@ -132,11 +132,14 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { assert!(typed.literals.contains(&( - EqlValue(TableColumn { - table: id("users"), - column: id("email") - }), - &ast::Value::SingleQuotedString("hello@cipherstash.com".into()) + EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("email") + }, + EqlTraits::default() + )), + &ast::Value::SingleQuotedString("hello@cipherstash.com".into()), ))); } Err(err) => panic!("type check failed: {err}"), @@ -161,11 +164,14 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { assert!(typed.literals.contains(&( - EqlValue(TableColumn { - table: id("users"), - column: id("email") - }), - &ast::Value::SingleQuotedString("hello@cipherstash.com".into()) + EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("email") + }, + EqlTraits::default() + )), + &ast::Value::SingleQuotedString("hello@cipherstash.com".into()), ))); } Err(err) => panic!("type check failed: {err}"), @@ -191,11 +197,14 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { assert!(typed.literals.contains(&( - EqlValue(TableColumn { - table: id("users"), - column: id("email") - }), - &ast::Value::SingleQuotedString("hello@cipherstash.com".into()) + EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("email") + }, + EqlTraits::default() + )), + &ast::Value::SingleQuotedString("hello@cipherstash.com".into()), ))); } Err(err) => panic!("type check failed: {err}"), @@ -219,14 +228,14 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { - let v = Value::Native(NativeValue(Some(TableColumn { + let v: Value = Value::Native(NativeValue(Some(TableColumn { table: id("users"), column: id("id"), }))); - let (_, param_value) = typed.params.first().unwrap(); + let (_, value) = typed.params.first().unwrap(); - assert_eq!(param_value, &v); + assert_eq!(value, &v); assert_eq!( typed.projection, @@ -450,7 +459,54 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), + }; + + assert_eq!( + typed.projection, + projection![ + (NATIVE(users.id) as id), + (EQL(users.email) as email), + (NATIVE(todo_lists.id) as id), + (NATIVE(todo_lists.owner_id) as owner_id), + (EQL(todo_lists.secret) as secret) + ] + ); + } + + #[test] + fn wildcard_expansion_2() { + // init_tracing(); + let schema = resolver(schema! 
{ + tables: { + users: { + id, + email (EQL), + } + todo_lists: { + id, + owner_id, + secret (EQL), + } + } + }); + + let statement = parse( + r#" + select * from ( + select + u.*, + tl.* + from + users as u + inner join todo_lists as tl on tl.owner_id = u.id + ) + "#, + ); + + let typed = match type_check(schema, &statement) { + Ok(typed) => typed, + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -482,17 +538,23 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { - let a = Value::Eql(EqlValue(TableColumn { - table: id("users"), - column: id("email"), - })); + let a = Value::Eql(EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("email"), + }, + EqlTraits::default(), + ))); - let b = Value::Eql(EqlValue(TableColumn { - table: id("users"), - column: id("first_name"), - })); + let b = Value::Eql(EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("first_name"), + }, + EqlTraits::default(), + ))); - assert_eq!(typed.params, vec![(Param(1), a), (Param(2), b)]); + assert_eq!(typed.params, vec![(Param(1), a,), (Param(2), b,)]); assert_eq!( typed.projection, @@ -514,8 +576,8 @@ mod test { tables: { users: { id, - salary (EQL), - age (EQL), + salary (EQL: Ord), + age (EQL: Ord), } } }); @@ -524,24 +586,30 @@ mod test { match type_check(schema, &statement) { Ok(typed) => { - let a = Value::Eql(EqlValue(TableColumn { - table: id("users"), - column: id("salary"), - })); + let a = Value::Eql(EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("salary"), + }, + EqlTraits::from(EqlTrait::Ord), + ))); - let b = Value::Eql(EqlValue(TableColumn { - table: id("users"), - column: id("age"), - })); + let b = Value::Eql(EqlTerm::Full(EqlValue( + TableColumn { + table: id("users"), + column: id("age"), + }, + EqlTraits::from(EqlTrait::Ord), + ))); - assert_eq!(typed.params, vec![(Param(1), a), (Param(2), b)]); + assert_eq!(typed.params, vec![(Param(1), a,), (Param(2), b,)]); assert_eq!( typed.projection, projection![ (NATIVE(users.id) as id), - (EQL(users.salary) as salary), - (EQL(users.age) as age) + (EQL(users.salary: Ord) as salary), + (EQL(users.age: Ord) as age) ] ); } @@ -557,7 +625,7 @@ mod test { id, first_name, last_name, - salary (EQL), + salary (EQL: Ord), } } }); @@ -577,7 +645,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -585,7 +653,7 @@ mod test { projection![ (NATIVE(employees.first_name) as first_name), (NATIVE(employees.last_name) as last_name), - (EQL(employees.salary) as salary) + (EQL(employees.salary: Ord) as salary) ] ); } @@ -601,7 +669,7 @@ mod test { first_name, last_name, department_name, - salary (EQL), + salary (EQL: Ord), } } }); @@ -621,7 +689,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -630,7 +698,7 @@ mod test { (NATIVE(employees.first_name) as first_name), (NATIVE(employees.last_name) as last_name), (NATIVE(employees.department_name) as department_name), - (EQL(employees.salary) as salary), + (EQL(employees.salary: Ord) as salary), (NATIVE as rank) ] ); @@ -668,7 +736,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check 
failed: {err:#?}"), }; assert_eq!( @@ -719,7 +787,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, Err(err) => { - panic!("type check failed: {:#?}", err) + panic!("type check failed: {err:#?}") } }; @@ -745,7 +813,7 @@ mod test { id, department, age, - salary (EQL), + salary (EQL: Ord), } } }); @@ -762,14 +830,14 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( typed.projection, projection![ (NATIVE(employees.age) as max), - (EQL(employees.salary) as min) + (EQL(employees.salary: Ord) as min) ] ); } @@ -799,10 +867,10 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; - assert_eq!(typed.projection, Projection::Empty); + assert_eq!(typed.projection, Projection(vec![])); } #[test] @@ -831,7 +899,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -870,10 +938,10 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; - assert_eq!(typed.projection, Projection::Empty); + assert_eq!(typed.projection, Projection(vec![])); } #[test] @@ -900,7 +968,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -926,7 +994,7 @@ mod test { name, department, age, - salary (EQL), + salary (EQL: Ord), } } }); @@ -939,10 +1007,10 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; - assert_eq!(typed.projection, Projection::Empty); + assert_eq!(typed.projection, Projection(vec![])); } #[test] @@ -956,7 +1024,7 @@ mod test { name, department, age, - salary (EQL), + salary (EQL: Ord), } } }); @@ -969,7 +1037,7 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( @@ -979,7 +1047,7 @@ mod test { (NATIVE(employees.name) as name), (NATIVE(employees.department) as department), (NATIVE(employees.age) as age), - (EQL(employees.salary) as salary) + (EQL(employees.salary: Ord) as salary) ] ); } @@ -995,7 +1063,7 @@ mod test { name, department, age, - salary (EQL), + salary (EQL: Ord), } } }); @@ -1008,17 +1076,20 @@ mod test { let typed = match type_check(schema.clone(), &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( typed.literals, vec![( - EqlValue(TableColumn { - table: id("employees"), - column: id("salary") - }), - &ast::Value::Number(200000.into(), false) + EqlTerm::Full(EqlValue( + TableColumn { + table: id("employees"), + column: id("salary") + }, + EqlTraits::from(EqlTrait::Ord) + ),), + &ast::Value::Number(200000.into(), false), )] ); @@ -1030,7 +1101,7 @@ mod test { transformed_statement.to_string(), "SELECT * FROM employees 
WHERE salary > 'ENCRYPTED'::JSONB::eql_v2_encrypted" ), - Err(err) => panic!("statement transformation failed: {}", err), + Err(err) => panic!("statement transformation failed: {err}"), }; } @@ -1055,16 +1126,19 @@ mod test { let typed = match type_check(schema.clone(), &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( typed.literals, vec![( - EqlValue(TableColumn { - table: id("employees"), - column: id("salary") - }), + EqlTerm::Full(EqlValue( + TableColumn { + table: id("employees"), + column: id("salary") + }, + EqlTraits::default() + )), &ast::Value::Number(20000.into(), false) )] ); @@ -1077,7 +1151,7 @@ mod test { transformed_statement.to_string(), "INSERT INTO employees (salary) VALUES ('ENCRYPTED'::JSONB::eql_v2_encrypted)" ), - Err(err) => panic!("statement transformation failed: {}", err), + Err(err) => panic!("statement transformation failed: {err}"), }; } @@ -1091,7 +1165,7 @@ mod test { id, department_id, name, - salary (EQL), + salary (EQL: Ord), } } }); @@ -1123,18 +1197,18 @@ mod test { let typed = match type_check(schema, &statement) { Ok(typed) => typed, - Err(err) => panic!("type check failed: {:#?}", err), + Err(err) => panic!("type check failed: {err:#?}"), }; assert_eq!( typed.projection, projection![ - (EQL(employees.salary) as min_salary), - (EQL(employees.salary) as y), + (EQL(employees.salary: Ord) as min_salary), + (EQL(employees.salary: Ord) as y), (NATIVE(employees.id) as id), (NATIVE(employees.department_id) as department_id), (NATIVE(employees.name) as name), - (EQL(employees.salary) as salary) + (EQL(employees.salary: Ord) as salary) ] ); } @@ -1169,7 +1243,7 @@ mod test { projection_type(&parse("select $1")), projection_type(&parse("select t from (select $1 as t)")), projection_type(&parse("select * from (select $1)")), - Projection::WithColumns(vec![ProjectionColumn { + Projection(vec![ProjectionColumn { alias: None, ty: Value::Native(NativeValue(None)), }]), @@ -1404,26 +1478,39 @@ mod test { tables: { employees: { id, - eql_col (EQL), + eql_col (EQL: JsonLike), native_col, } } }); - let statement = parse(" - SELECT jsonb_path_query(eql_col, '$.secret'), jsonb_path_query(native_col, '$.not-secret') FROM employees - "); + let statement = parse( + " + SELECT + jsonb_path_exists(eql_col, '$.another-secret'), + jsonb_path_query(eql_col, '$.secret'), + jsonb_path_query(native_col, '$.not-secret') + FROM employees + ", + ); match type_check(schema, &statement) { Ok(typed) => { match typed.transform(test_helpers::dummy_encrypted_json_selector( &statement, - ast::Value::SingleQuotedString("$.secret".into()), + vec![ + ast::Value::SingleQuotedString("$.secret".into()), + ast::Value::SingleQuotedString("$.another-secret".into()), + ], )) { Ok(statement) => { assert_eq!( statement.to_string(), - "SELECT eql_v2.jsonb_path_query(eql_col, ''::JSONB::eql_v2_encrypted), jsonb_path_query(native_col, '$.not-secret') FROM employees" + "SELECT \ + eql_v2.jsonb_path_exists(eql_col, ''::JSONB::eql_v2_encrypted), \ + eql_v2.jsonb_path_query(eql_col, ''::JSONB::eql_v2_encrypted), \ + jsonb_path_query(native_col, '$.not-secret') \ + FROM employees" ); } Err(err) => panic!("transformation failed: {err}"), @@ -1447,6 +1534,7 @@ mod test { #[test] fn jsonb_operator_arrow() { + // init_tracing(); test_jsonb_operator("->"); } @@ -1456,36 +1544,31 @@ mod test { } #[test] - fn jsonb_operator_hash_arrow() { - test_jsonb_operator("#>"); - } - - #[test] - fn 
jsonb_operator_hash_long_arrow() {
-        test_jsonb_operator("#>>");
-    }
-
-    #[test]
+    #[ignore = "@@ is unimplemented"]
     fn jsonb_operator_hash_at_at() {
         test_jsonb_operator("@@");
     }

     #[test]
+    #[ignore = "@? is unimplemented"]
     fn jsonb_operator_at_question() {
         test_jsonb_operator("@?");
     }

     #[test]
+    #[ignore = "? is unimplemented"]
     fn jsonb_operator_question() {
         test_jsonb_operator("?");
     }

     #[test]
+    #[ignore = "?& is unimplemented"]
     fn jsonb_operator_question_and() {
         test_jsonb_operator("?&");
     }

     #[test]
+    #[ignore = "?| is unimplemented"]
     fn jsonb_operator_question_pipe() {
         test_jsonb_operator("?|");
     }
@@ -1514,15 +1597,12 @@
         );
     }

-    // TODO: do we need to check that the RHS of JSON operators MUST be a Value node
-    // and not an arbitrary expression?
-
     fn test_jsonb_function(fn_name: &str, args: Vec<ast::Expr>) {
         let schema = resolver(schema! {
             tables: {
                 patients: {
                     id,
-                    notes (EQL),
+                    notes (EQL: JsonLike),
                 }
             }
         });
@@ -1534,8 +1614,7 @@
             .join(", ");

         let statement = parse(&format!(
-            "SELECT id, {}({}) AS meds FROM patients",
-            fn_name, args_in
+            "SELECT id, {fn_name}({args_in}) AS meds FROM patients"
         ));

         let args_encrypted = args
@@ -1546,7 +1625,7 @@
                     value: ast::Value::SingleQuotedString(s),
                     span: _,
                 }) => {
-                    format!("''::JSONB::eql_v2_encrypted", s)
+                    format!("''::JSONB::eql_v2_encrypted")
                 }
                 _ => panic!("unsupported expr type in test util"),
             })
@@ -1559,7 +1638,7 @@
             if let ast::Expr::Value(ast::ValueWithSpan { value, .. }) = arg {
                 encrypted_literals.extend(test_helpers::dummy_encrypted_json_selector(
                     &statement,
-                    value.clone(),
+                    vec![value.clone()],
                ));
             }
         }
@@ -1587,22 +1666,21 @@
             tables: {
                 patients: {
                     id,
-                    notes (EQL),
+                    notes (EQL: JsonLike + Contain),
                }
            }
        });

         let statement = parse(&format!(
-            "SELECT id, notes {} 'medications' AS meds FROM patients",
-            op
+            "SELECT id, notes {op} 'medications' AS meds FROM patients",
         ));

         match type_check(schema, &statement) {
             Ok(typed) => {
-                match typed.transform(test_helpers::dummy_encrypted_json_selector(&statement, ast::Value::SingleQuotedString("medications".to_owned()))) {
+                match typed.transform(test_helpers::dummy_encrypted_json_selector(&statement, vec![ast::Value::SingleQuotedString("medications".to_owned())])) {
                     Ok(statement) => assert_eq!(
                         statement.to_string(),
-                        format!("SELECT id, notes {} ''::JSONB::eql_v2_encrypted AS meds FROM patients", op)
+                        format!("SELECT id, notes {op} ''::JSONB::eql_v2_encrypted AS meds FROM patients")
                     ),
                     Err(err) => panic!("transformation failed: {err}"),
                 }
@@ -1611,6 +1689,42 @@
         }
     }

+    #[test]
+    fn eql_term_partial_is_unified_with_eql_term_whole() {
+        // init_tracing();
+        let schema = resolver(schema! {
+            tables: {
+                patients: {
+                    id,
+                    email (EQL: Eq),
+                }
+            }
+        });
+
+        // let statement = parse(
+        //     "SELECT id, email FROM patients WHERE email = 'alice@example.com'"
+        // );
+
+        let statement = parse(
+            "
+            SELECT id, email FROM patients AS p
+            INNER JOIN (
+                SELECT 'alice@example.com' AS selector
+            ) AS selectors
+            WHERE p.email = selectors.selector
+            ",
+        );
+
+        let typed = type_check(schema, &statement)
+            .map_err(|err| err.to_string())
+            .unwrap();
+
+        assert_eq!(
+            typed.projection,
+            projection![(NATIVE(patients.id) as id), (EQL(patients.email: Eq) as email)]
+        );
+    }
+
     #[test]
     fn select_with_multiple_joins() {
         // init_tracing();
@@ -1719,4 +1833,97 @@
             Err(err) => panic!("type check failed: {err}"),
         }
     }
+
+    #[test]
+    fn jsonb_path_query_param_to_eql() {
+        // init_tracing();
+        let schema = resolver(schema!
{ + tables: { + patients: { + id, + notes (EQL: JsonLike), + } + } + }); + + let statement = parse("SELECT eql_v2.jsonb_path_query(notes, $1) as notes FROM patients"); + + let typed = type_check(schema, &statement) + .map_err(|err| err.to_string()) + .unwrap(); + + assert_eq!( + typed.projection, + projection![(EQL(patients.notes: JsonLike) as notes)] + ); + } + + #[test] + fn ensure_eql_mapper_does_not_choke_on_elixir_ecto_schema_metadata_query() { + // init_tracing(); + let schema = resolver(schema! { + tables: { + pg_attribute: { + attrelid, + attnum, + atttypid, + attisdropped, + } + pg_type: { + oid, + typname, + typsend, + typreceive, + typoutput, + typinput, + typbasetype, + typrelid, + typelem, + } + pg_range: { + rngtypid, + rngmultitypid, + rngsubtype, + } + } + }); + + let statement = parse( + "SELECT + t.oid, + t.typname, + t.typsend, + t.typreceive, + t.typoutput, + t.typinput, + coalesce(d.typelem, t.typelem), + coalesce(r.rngsubtype, 0), + ARRAY( + SELECT + a.atttypid + FROM + pg_attribute AS a + WHERE + a.attrelid = t.typrelid + AND a.attnum > 0 + AND NOT a.attisdropped + ORDER BY a.attnum + ) FROM pg_type AS t + LEFT JOIN pg_type AS d ON t.typbasetype = d.oid + LEFT JOIN pg_range AS r ON r.rngtypid = t.oid OR r.rngmultitypid = t.oid OR ( + t.typbasetype <> 0 + AND r.rngtypid = t.typbasetype + ) + WHERE + (t.typrelid = 0) + AND (t.typelem = 0 OR NOT EXISTS ( + SELECT 1 FROM pg_type AS s + WHERE s.typrelid <> 0 AND s.oid = t.typelem + ))", + ); + + type_check(schema, &statement) + .map_err(|err| err.to_string()) + .unwrap(); + } } diff --git a/packages/eql-mapper/src/model/mod.rs b/packages/eql-mapper/src/model/mod.rs index 4a2082c0..f2f1be35 100644 --- a/packages/eql-mapper/src/model/mod.rs +++ b/packages/eql-mapper/src/model/mod.rs @@ -1,4 +1,3 @@ -mod provenance; mod relation; mod schema; mod schema_delta; @@ -6,7 +5,6 @@ mod sql_ident; mod table_resolver; mod type_system; -pub use provenance::*; pub use schema::*; pub use schema_delta::*; pub use sql_ident::*; diff --git a/packages/eql-mapper/src/model/provenance.rs b/packages/eql-mapper/src/model/provenance.rs deleted file mode 100644 index e6af2ec9..00000000 --- a/packages/eql-mapper/src/model/provenance.rs +++ /dev/null @@ -1,35 +0,0 @@ -use crate::{model::schema::Table, Projection, TableColumn}; - -#[derive(Debug, Clone, Eq, PartialEq)] -pub enum Provenance { - Select(SelectProvenance), - Insert(InsertProvenance), - Update(UpdateProvenance), - Delete(DeleteProvenance), -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct SelectProvenance { - pub projection: Projection, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct InsertProvenance { - pub into_table: Table, - pub returning: Option, - pub columns_written: Vec, - pub source_projection: Option, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct UpdateProvenance { - pub update_table: Table, - pub returning: Option, - pub columns_written: Vec, -} - -#[derive(Debug, Clone, Eq, PartialEq)] -pub struct DeleteProvenance { - pub from_table: Table, - pub returning: Option, -} diff --git a/packages/eql-mapper/src/model/schema.rs b/packages/eql-mapper/src/model/schema.rs index e9047ddd..ce9c6e73 100644 --- a/packages/eql-mapper/src/model/schema.rs +++ b/packages/eql-mapper/src/model/schema.rs @@ -3,7 +3,7 @@ //! Column type information is unused currently. 
use super::sql_ident::*; -use crate::iterator_ext::IteratorExt; +use crate::{iterator_ext::IteratorExt, unifier::EqlTraits}; use core::fmt::Debug; use derive_more::Display; use sqltk::parser::ast::{Ident, ObjectName, ObjectNamePart}; @@ -41,14 +41,14 @@ pub struct Column { #[derive(Debug, Clone, Copy, PartialEq, Eq, Display, Hash)] pub enum ColumnKind { Native, - Eql, + Eql(EqlTraits), } impl Column { - pub fn eql(name: Ident) -> Self { + pub fn eql(name: Ident, features: EqlTraits) -> Self { Self { name, - kind: ColumnKind::Eql, + kind: ColumnKind::Eql(features), } } @@ -156,7 +156,7 @@ impl Schema { } else { Err(SchemaError::ColumnNotFound( format!("{table_name}"), - format!("{}", column_name), + format!("{column_name}"), )) } } @@ -192,6 +192,45 @@ impl Table { } } +#[macro_export] +macro_rules! to_eql_trait_impls { + ($($indexes:ident)*) => { + { + #[allow(unused_mut)] + let mut impls = $crate::unifier::EqlTraits::default(); + $crate::to_eql_trait_impls!(@flags impls $($indexes)*); + impls + } + }; + + (@flags $impls:ident Eq $($indexes:ident)*) => { + $impls.add_mut(EqlTrait::Eq); + $crate::to_eql_trait_impls!(@flags $impls $($indexes)*); + }; + + (@flags $impls:ident Ord $($indexes:ident)*) => { + $impls.add_mut(EqlTrait::Ord); + $crate::to_eql_trait_impls!(@flags $impls $($indexes)*); + }; + + (@flags $impls:ident TokenMatch $($indexes:ident)*) => { + $impls.add_mut(EqlTrait::TokenMatch); + $crate::to_eql_trait_impls!(@flags $impls $($indexes)*); + }; + + (@flags $impls:ident JsonLike $($indexes:ident)*) => { + $impls.add_mut(EqlTrait::JsonLike); + $crate::to_eql_trait_impls!(@flags $impls $($indexes)*); + }; + + (@flags $impls:ident Contain $($indexes:ident)*) => { + $impls.add_mut(EqlTrait::Contain); + $crate::to_eql_trait_impls!(@flags $impls $($indexes)*); + }; + + (@flags $impls:ident) => {} +} + /// A DSL to create a [`Schema`] for testing purposes. // #[cfg(test)] #[macro_export] @@ -235,11 +274,21 @@ macro_rules! schema { (@add_columns $table:ident $( $column_name:ident $(($($options:tt)+))? , )* ) => { $( schema!(@add_column $table $column_name $(($($options)*))? ); )* }; - (@add_column $table:ident $column_name:ident (EQL) ) => { + (@add_column $table:ident $column_name:ident (EQL $(: $trait_:ident $(+ $trait_rest:ident)*)?) ) => { $table.add_column(std::sync::Arc::new($crate::model::Column::eql( - ::sqltk::parser::ast::Ident::new(stringify!($column_name)) + ::sqltk::parser::ast::Ident::new(stringify!($column_name)), + $crate::to_eql_trait_impls!($($trait_ $($trait_rest)*)?) 
        )));
    };

+    (@add_column $table:ident $column_name:ident (PK) ) => {
+        $table.add_column(
+            std::sync::Arc::new(
+                $crate::model::Column::native(
+                    ::sqltk::parser::ast::Ident::new(stringify!($column_name))
+                )
+            ),
+        );
+    };

     (@add_column $table:ident $column_name:ident () ) => {
         $table.add_column(
             std::sync::Arc::new(
diff --git a/packages/eql-mapper/src/model/schema_delta.rs b/packages/eql-mapper/src/model/schema_delta.rs
index 8f30a5a9..7c17fdbf 100644
--- a/packages/eql-mapper/src/model/schema_delta.rs
+++ b/packages/eql-mapper/src/model/schema_delta.rs
@@ -363,6 +363,7 @@ mod test {
     use crate::{
         schema,
         test_helpers::{id, object_name, parse},
+        unifier::EqlTraits,
         ColumnKind, SchemaError, SchemaTableColumn, TableResolver,
     };

@@ -443,7 +444,7 @@
             Ok(SchemaTableColumn {
                 table: id("users"),
                 column: id("primary_email"),
-                kind: ColumnKind::Eql
+                kind: ColumnKind::Eql(EqlTraits::default())
             })
         )
     }
@@ -475,7 +476,7 @@
             Ok(SchemaTableColumn {
                 table: id("app_users"),
                 column: id("email"),
-                kind: ColumnKind::Eql
+                kind: ColumnKind::Eql(EqlTraits::default())
             })
         )
     }
@@ -525,7 +526,7 @@
             Ok(SchemaTableColumn {
                 table: id("users"),
                 column: id("email"),
-                kind: ColumnKind::Eql
+                kind: ColumnKind::Eql(EqlTraits::default())
             })
         );

diff --git a/packages/eql-mapper/src/model/sql_ident.rs b/packages/eql-mapper/src/model/sql_ident.rs
index aa5a0a3b..8f57beff 100644
--- a/packages/eql-mapper/src/model/sql_ident.rs
+++ b/packages/eql-mapper/src/model/sql_ident.rs
@@ -10,7 +10,7 @@ use sqltk::parser::ast::Ident;
 /// quoted or not.
 ///
 /// For an "official" explanation of how SQL identifiers work (at least with respect to Postgres), see
-/// [https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS].
+/// [<https://www.postgresql.org/docs/14/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS>].
 ///
 /// SQL is wild, hey!
 #[derive(Debug, Clone, Display)]
diff --git a/packages/eql-mapper/src/model/type_system.rs b/packages/eql-mapper/src/model/type_system.rs
index a4ef7fe4..21c6865c 100644
--- a/packages/eql-mapper/src/model/type_system.rs
+++ b/packages/eql-mapper/src/model/type_system.rs
@@ -5,33 +5,36 @@
 //! `eql_mapper`'s internal representation of the type system contains additional implementation details which would not
 //! be pleasant for public consumption.

-use crate::unifier::{EqlValue, NativeValue};
+use crate::unifier::{EqlTerm, NativeValue};
 use derive_more::Display;
 use sqltk::parser::ast::Ident;

 /// The resolved type of a [`sqltk::parser::ast::Expr`] node.
 #[derive(Debug, Clone, PartialEq, Eq, Display)]
-#[display("{self}")]
 pub enum Type {
     /// A value type (an EQL type, native database type or an array type)
     #[display("{}", _0)]
     Value(Value),
+}

-    /// A projection type that is parameterized by a list of projection column types.
-    #[display("{}", _0)]
-    Projection(Projection),
+#[derive(Debug, Clone, PartialEq, Eq, Display)]
+#[display("SetOf<{}>", _0)]
+pub struct SetOf(pub Box<Type>);
+impl SetOf {
+    fn contains_eql(&self) -> bool {
+        self.0.contains_eql()
+    }
 }

 /// A value type (an EQL type, native database type or an array type)
 #[derive(Debug, Clone, PartialEq, Eq, Display)]
-#[display("{self}")]
 pub enum Value {
     /// An encrypted type from a particular table-column in the schema.
     ///
     /// An encrypted column never shares a type with another encrypted column - which is why it is sufficient to
     /// identify the type by its table & column names.
     #[display("{}", _0)]
-    Eql(EqlValue),
+    Eql(EqlTerm),

     /// A native database type.
     #[display("{}", _0)]
@@ -39,34 +42,69 @@ pub enum Value {

     /// An array type that is parameterized by an element type.
     #[display("Array[{}]", _0)]
-    Array(Box<Type>),
+    Array(Array),
+
+    /// A projection type that is parameterized by a list of projection column types.
+    #[display("{}", _0)]
+    Projection(Projection),
+
+    /// In PostgreSQL, SETOF is a special return type used in functions to indicate that the function returns a set of
+    /// rows rather than a single value. It allows a function to behave like a table or subquery in SQL, producing
+    /// multiple rows as output.
+    #[display("{}", _0)]
+    SetOf(SetOf),
 }

+#[derive(Debug, Clone, PartialEq, Eq, Display)]
+pub struct Array(pub Box<Type>);
+
 /// A projection type that is parameterized by a list of projection column types.
 #[derive(Debug, Clone, PartialEq, Eq, Display)]
-#[display("{self}")]
-pub enum Projection {
-    #[display("PROJ[{}]", _0.iter().map(|pc| pc.to_string()).collect::<Vec<String>>().join(", "))]
-    WithColumns(Vec<ProjectionColumn>),
+#[display("{{{}}}", _0.iter().map(|pc| pc.to_string()).collect::<Vec<String>>().join(", "))]
+pub struct Projection(pub Vec<ProjectionColumn>);
+
+impl Type {
+    pub fn contains_eql(&self) -> bool {
+        match self {
+            Type::Value(value) => value.contains_eql(),
+        }
+    }
+}

-    #[display("PROJ[]")]
-    Empty,
+impl Value {
+    pub fn contains_eql(&self) -> bool {
+        match self {
+            Value::Eql(_) => true,
+            Value::Native(_) => false,
+            Value::Array(inner) => inner.contains_eql(),
+            Value::Projection(projection) => projection.contains_eql(),
+            Value::SetOf(set_of) => set_of.contains_eql(),
+        }
+    }
+}
+
+impl Array {
+    pub fn contains_eql(&self) -> bool {
+        let Array(element_ty) = self;
+        element_ty.contains_eql()
+    }
 }

 impl Projection {
     pub fn new(columns: Vec<ProjectionColumn>) -> Self {
-        if columns.is_empty() {
-            Projection::Empty
-        } else {
-            Projection::WithColumns(columns)
-        }
+        Self(columns)
     }

     pub fn type_at_col_index(&self, index: usize) -> Option<&Value> {
-        match self {
-            Projection::WithColumns(cols) => cols.get(index).map(|col| &col.ty),
-            Projection::Empty => None,
-        }
+        self.0.get(index).map(|col| &col.ty)
+    }
+
+    pub fn contains_eql(&self) -> bool {
+        self.0.iter().any(|col| col.ty.contains_eql())
+    }
+
+    pub fn columns(&self) -> &[ProjectionColumn] {
+        &self.0
     }
 }

@@ -84,8 +122,38 @@ pub struct ProjectionColumn {
 impl ProjectionColumn {
     fn render_alias(&self) -> String {
         match &self.alias {
-            Some(name) => format!(": {}", name),
+            Some(name) => format!(": {name}"),
             None => String::from(""),
         }
     }
 }
+
+impl From<Value> for Type {
+    fn from(value: Value) -> Self {
+        Type::Value(value)
+    }
+}
+
+impl From<Array> for Type {
+    fn from(array: Array) -> Self {
+        Type::Value(Value::Array(array))
+    }
+}
+
+impl From<EqlTerm> for Type {
+    fn from(eql_term: EqlTerm) -> Self {
+        Type::Value(Value::Eql(eql_term))
+    }
+}
+
+impl From<Projection> for Type {
+    fn from(projection: Projection) -> Self {
+        Type::Value(Value::Projection(projection))
+    }
+}
+
+impl From<NativeValue> for Type {
+    fn from(native: NativeValue) -> Self {
+        Type::Value(Value::Native(native))
+    }
+}
diff --git a/packages/eql-mapper/src/scope_tracker.rs b/packages/eql-mapper/src/scope_tracker.rs
index 5c6c4830..a0b6bd13 100644
--- a/packages/eql-mapper/src/scope_tracker.rs
+++ b/packages/eql-mapper/src/scope_tracker.rs
@@ -1,9 +1,9 @@
 //! Types for representing and maintaining a lexical scope during AST traversal.
-use crate::inference::unifier::{Constructor, ProjectionColumn, Type};
+use crate::inference::unifier::{ProjectionColumn, Type};
 use crate::inference::TypeError;
 use crate::iterator_ext::IteratorExt;
 use crate::model::SqlIdent;
-use crate::unifier::{Projection, ProjectionColumns};
+use crate::unifier::{Projection, Value};
 use crate::Relation;
 use sqltk::parser::ast::{Ident, ObjectName, ObjectNamePart, Query, Statement};
 use sqltk::{into_control_flow, Break, Visitable, Visitor};
@@ -115,7 +115,7 @@ impl<'ast> Scope<'ast> {
             .map(|r| ProjectionColumn::new(r.projection_type.clone(), None))
             .collect();

-        Ok(Type::Constructor(Constructor::Projection(Projection::new(columns))).into())
+        Ok(Type::Value(Value::Projection(Projection::new(columns))).into())
     }
 }

@@ -169,7 +169,7 @@ impl<'ast> Scope<'ast> {
                 |mut acc, columns| {
                     columns
                         .map(|columns| {
-                            acc.extend(columns.0.iter().cloned());
+                            acc.extend(columns.iter().cloned());
                             acc
                         })
                         .map_err(|err| ScopeError::TypeError(Box::new(err)))
@@ -185,8 +185,7 @@ impl<'ast> Scope<'ast> {
             Ok(None) => match &self.parent {
                 Some(parent) => parent.borrow().resolve_ident(ident),
                 None => Err(ScopeError::NoMatch(format!(
-                    "identifier {} not found in scope",
-                    ident
+                    "identifier {ident} not found in scope"
                 ))),
             },
         }
@@ -212,24 +211,20 @@ impl<'ast> Scope<'ast> {
                 let columns = self
                     .try_match_projection(named_relation.projection_type.clone())
                     .map_err(|err| ScopeError::TypeError(Box::new(err)))?;
-                let mut columns = columns.0.iter();
+                let mut columns = columns.iter();

                 match columns.try_find_unique(&|column| {
                     column.alias.as_ref().map(SqlIdent::from).as_ref() == Some(&second_ident)
                 }) {
                     Ok(Some(projection_column)) => Ok(projection_column.ty.clone()),
-                    Ok(None) | Err(_) => Err(ScopeError::NoMatch(format!(
-                        "{}.{}",
-                        first_ident, second_ident
-                    ))),
+                    Ok(None) | Err(_) => {
+                        Err(ScopeError::NoMatch(format!("{first_ident}.{second_ident}")))
+                    }
                 }
             }
             Ok(None) | Err(_) => match &self.parent {
                 Some(parent) => parent.borrow().resolve_compound_ident(idents),
-                None => Err(ScopeError::NoMatch(format!(
-                    "{}.{}",
-                    first_ident, second_ident
-                ))),
+                None => Err(ScopeError::NoMatch(format!("{first_ident}.{second_ident}"))),
             },
         }
     }
@@ -261,11 +256,11 @@ impl<'ast> Scope<'ast> {
         }
     }

-    fn try_match_projection(&self, ty: Arc<Type>) -> Result<ProjectionColumns, TypeError> {
+    fn try_match_projection(&self, ty: Arc<Type>) -> Result<Vec<ProjectionColumn>, TypeError> {
         match &*ty {
-            Type::Constructor(Constructor::Projection(projection)) => Ok(ProjectionColumns(
-                Vec::from_iter(projection.columns().iter().cloned()),
-            )),
+            Type::Value(Value::Projection(projection)) => {
+                Ok(Vec::from_iter(projection.columns().iter().cloned()))
+            }
             other => Err(TypeError::Expected(format!(
                 "expected projection but got: {other}"
             ))),
diff --git a/packages/eql-mapper/src/test_helpers.rs b/packages/eql-mapper/src/test_helpers.rs
index 29cdac56..839b49b5 100644
--- a/packages/eql-mapper/src/test_helpers.rs
+++ b/packages/eql-mapper/src/test_helpers.rs
@@ -22,6 +22,7 @@ pub(crate) fn init_tracing() {
     tracing_subscriber::fmt()
         .with_max_level(tracing::Level::TRACE)
         .with_span_events(FmtSpan::ACTIVE)
+        // .with_env_filter(EnvFilter::new("eql-mapper::UNIFY=trace,eql-mapper::EVENT_SUBSTITUTE=trace,eql-mapper::INFER_EXIT=trace,eql-mapper::TYPE_ENV=trace"))
         .with_file(true)
         .event_format(format().pretty())
        .pretty()
@@ -52,16 +53,21 @@ pub(crate) fn get_node_key_of_json_selector<'ast>(

 pub(crate) fn dummy_encrypted_json_selector(
     statement: &Statement,
-    selector: Value,
+    selectors: Vec<Value>,
 ) -> HashMap<NodeKey<'_>, ast::Value> {
-    if let Value::SingleQuotedString(s) = &selector {
-        HashMap::from_iter(vec![(
-            get_node_key_of_json_selector(statement, &selector),
-            ast::Value::SingleQuotedString(format!("", s)),
-        )])
-    } else {
-        panic!("dummy_encrypted_json_selector only works on Value::SingleQuotedString")
+    let mut dummy_encrypted_values = HashMap::new();
+    for selector in selectors.into_iter() {
+        if let Value::SingleQuotedString(s) = &selector {
+            dummy_encrypted_values.insert(
+                get_node_key_of_json_selector(statement, &selector),
+                ast::Value::SingleQuotedString(format!("")),
+            );
+        } else {
+            panic!("dummy_encrypted_json_selector only works on Value::SingleQuotedString")
+        }
     }
+
+    dummy_encrypted_values
 }

 /// Utility for finding the [`NodeKey`] of a [`Value`] node in `statement` by providing a `matching` equal node to search for.
@@ -134,42 +140,50 @@ macro_rules! col {
         }
     };

-    ((EQL($table:ident . $column:ident))) => {
+    ((EQL($table:ident . $column:ident $(: $($eql_traits:ident)*)?))) => {
         ProjectionColumn {
-            ty: Value::Eql(EqlValue::from((stringify!($table), stringify!($column)))),
+            ty: Value::Eql(EqlTerm::Full(EqlValue(TableColumn {
+                table: id(stringify!($table)),
+                column: id(stringify!($column)),
+            }, $crate::to_eql_traits!($($($eql_traits)*)?)))),
             alias: None,
         }
     };

-    ((EQL($table:ident . $column:ident) as $alias:ident)) => {
+    ((EQL($table:ident . $column:ident $(: $($eql_traits:ident)*)?) as $alias:ident)) => {
         ProjectionColumn {
-            ty: Value::Eql(EqlValue(TableColumn {
+            ty: Value::Eql(EqlTerm::Full(EqlValue(TableColumn {
                 table: id(stringify!($table)),
                 column: id(stringify!($column)),
-            })),
+            }, $crate::to_eql_traits!($($($eql_traits)*)?)))),
             alias: Some(id(stringify!($alias))),
         }
     };
 }

+#[macro_export]
+macro_rules! to_eql_traits {
+    () => { $crate::unifier::EqlTraits::default() };
+
+    ($($traits:ident)*) => {
+        EqlTraits::from_iter(vec![$($crate::unifier::EqlTrait::$traits,)*])
+    };
+}
+
 #[macro_export]
 macro_rules!
projection { - [$($column:tt),*] => { Projection::new(vec![$(col!($column)),*]) }; + [$($column:tt),*] => { Projection::new(vec![$($crate::col!($column)),*]) }; } pub fn ignore_aliases(t: &Projection) -> Projection { - match t { - Projection::WithColumns(columns) => Projection::WithColumns( - columns - .iter() - .map(|pc| ProjectionColumn { - ty: pc.ty.clone(), - alias: None, - }) - .collect(), - ), - Projection::Empty => Projection::Empty, - } + Projection( + t.0.iter() + .map(|pc| ProjectionColumn { + ty: pc.ty.clone(), + alias: None, + }) + .collect(), + ) } pub fn assert_transitive_eq(items: &[T]) { diff --git a/packages/eql-mapper/src/transformation_rules/rewrite_standard_sql_fns_on_eql_types.rs b/packages/eql-mapper/src/transformation_rules/rewrite_standard_sql_fns_on_eql_types.rs index f82e28bc..eb31d278 100644 --- a/packages/eql-mapper/src/transformation_rules/rewrite_standard_sql_fns_on_eql_types.rs +++ b/packages/eql-mapper/src/transformation_rules/rewrite_standard_sql_fns_on_eql_types.rs @@ -1,10 +1,12 @@ use std::mem; use std::{collections::HashMap, sync::Arc}; -use sqltk::parser::ast::{Expr, Function, Ident, ObjectName, ObjectNamePart}; +use sqltk::parser::ast::{ + Expr, Function, FunctionArg, FunctionArguments, Ident, ObjectName, ObjectNamePart, +}; use sqltk::{AsNodeKey, NodeKey, NodePath, Visitable}; -use crate::{get_sql_function_def, EqlMapperError, RewriteRule, SqlFunction, Type, Value}; +use crate::{get_sql_function, EqlMapperError, Type, Value}; use super::TransformationRule; @@ -17,6 +19,32 @@ impl<'ast> RewriteStandardSqlFnsOnEqlTypes<'ast> { pub fn new(node_types: Arc, Type>>) -> Self { Self { node_types } } + + /// Returns `true` if at least one argument and/or return type is an EQL type. + fn uses_eql_type(&self, function: &Function) -> bool { + if matches!( + self.node_types.get(&function.as_node_key()), + Some(Type::Value(Value::Eql(_))) + ) { + return true; + } + + match &function.args { + FunctionArguments::None => false, + FunctionArguments::Subquery(query) => matches!( + self.node_types.get(&query.as_node_key()), + Some(Type::Value(Value::Eql(_))) + ), + FunctionArguments::List(list) => list.args.iter().any(|arg| match arg { + FunctionArg::Named { arg, .. } + | FunctionArg::ExprNamed { arg, .. } + | FunctionArg::Unnamed(arg) => matches!( + self.node_types.get(&arg.as_node_key()), + Some(Type::Value(Value::Eql(_))) + ), + }), + } + } } impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> { @@ -26,23 +54,11 @@ impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> { target_node: &mut N, ) -> Result { if self.would_edit(node_path, target_node) { - if let Some((_expr, function)) = node_path.last_2_as::() { - if matches!( - self.node_types.get(&function.as_node_key()), - Some(Type::Value(Value::Eql(_))) - ) { - if let Some(SqlFunction { - rewrite_rule: RewriteRule::AsEqlFunction, - .. 
 
 impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> {
@@ -26,23 +54,11 @@ impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> {
     fn apply<N: Visitable>(
         &mut self,
         node_path: &NodePath<'ast>,
         target_node: &mut N,
     ) -> Result<bool, EqlMapperError> {
         if self.would_edit(node_path, target_node) {
-            if let Some((_expr, function)) = node_path.last_2_as::<Expr, Function>() {
-                if matches!(
-                    self.node_types.get(&function.as_node_key()),
-                    Some(Type::Value(Value::Eql(_)))
-                ) {
-                    if let Some(SqlFunction {
-                        rewrite_rule: RewriteRule::AsEqlFunction,
-                        ..
-                    }) = get_sql_function_def(&function.name, &function.args)
-                    {
-                        let function = target_node.downcast_mut::<Function>().unwrap();
-                        let mut existing_name = mem::take(&mut function.name.0);
-                        existing_name.insert(0, ObjectNamePart::Identifier(Ident::new("eql_v2")));
-                        function.name = ObjectName(existing_name);
-                    }
-                }
-            }
+            let function = target_node.downcast_mut::<Function>().unwrap();
+            let mut existing_name = mem::take(&mut function.name.0);
+            existing_name.insert(0, ObjectNamePart::Identifier(Ident::new("eql_v2")));
+            function.name = ObjectName(existing_name);
+
+            return Ok(true);
         }
 
         Ok(false)
@@ -50,18 +66,8 @@ impl<'ast> TransformationRule<'ast> for RewriteStandardSqlFnsOnEqlTypes<'ast> {
 
     fn would_edit<N: Visitable>(&mut self, node_path: &NodePath<'ast>, _target_node: &N) -> bool {
         if let Some((_expr, function)) = node_path.last_2_as::<Expr, Function>() {
-            if matches!(
-                self.node_types.get(&function.as_node_key()),
-                Some(Type::Value(Value::Eql(_)))
-            ) {
-                if let Some(SqlFunction {
-                    rewrite_rule: RewriteRule::AsEqlFunction,
-                    ..
-                }) = get_sql_function_def(&function.name, &function.args)
-                {
-                    return true;
-                }
-            }
+            return get_sql_function(&function.name).should_rewrite()
+                && self.uses_eql_type(function);
         }
 
         false
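The rewrite that `apply` now performs unconditionally once `would_edit` has matched is a plain schema-qualification of the function name. A minimal, self-contained sketch of the `mem::take` + `insert(0, ..)` move, with a simple `Vec<String>` standing in for sqltk's `ObjectName`/`ObjectNamePart` types:

```rust
use std::mem;

// Plain stand-in for sqltk's `ObjectName`: an ordered list of dotted name parts.
#[derive(Debug, PartialEq)]
struct ObjectName(Vec<String>);

// Mirrors the rule's rewrite: take the existing parts out, push the schema
// qualifier onto the front, then rebuild the name.
fn qualify(name: &mut ObjectName, schema: &str) {
    let mut parts = mem::take(&mut name.0);
    parts.insert(0, schema.to_string());
    *name = ObjectName(parts);
}

fn main() {
    let mut name = ObjectName(vec!["jsonb_path_query".to_string()]);
    qualify(&mut name, "eql_v2");
    assert_eq!(
        name,
        ObjectName(vec!["eql_v2".to_string(), "jsonb_path_query".to_string()])
    );
}
```

Taking the vector out first lets the existing parts be reused without cloning, which is what the real rule does with `function.name.0`.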
diff --git a/packages/eql-mapper/src/type_checked_statement.rs b/packages/eql-mapper/src/type_checked_statement.rs
index 418b09d5..8fd56088 100644
--- a/packages/eql-mapper/src/type_checked_statement.rs
+++ b/packages/eql-mapper/src/type_checked_statement.rs
@@ -3,16 +3,14 @@ use std::{collections::HashMap, sync::Arc};
 use sqltk::parser::ast::{self, Statement};
 use sqltk::{AsNodeKey, NodeKey, Transformable};
 
+use crate::unifier::EqlTerm;
 use crate::{
-    CastLiteralsAsEncrypted, CastParamsAsEncrypted, DryRunnable, EqlMapperError, EqlValue,
+    CastLiteralsAsEncrypted, CastParamsAsEncrypted, DryRunnable, EqlMapperError,
     FailOnPlaceholderChange, Param, PreserveEffectiveAliases, Projection,
     RewriteStandardSqlFnsOnEqlTypes, TransformationRule, Type, Value,
 };
 
 /// A `TypeCheckedStatement` is returned from a successful call to [`crate::type_check`].
-///
-/// It stores a reference to the type-checked [`Statement`], the type of the
-///
 #[derive(Debug)]
 pub struct TypeCheckedStatement<'ast> {
     /// A reference to the original unmodified [`Statement`].
@@ -24,24 +22,24 @@ pub struct TypeCheckedStatement<'ast> {
     /// The types of all params discovered from [`Value::Placeholder`] nodes in the SQL statement.
     pub params: Vec<(Param, Value)>,
 
-    /// The type ([`EqlValue`]) and reference to an [`ast::Value`] nodes of all EQL literals from the SQL statement.
-    pub literals: Vec<(EqlValue, &'ast ast::Value)>,
+    /// The type ([`EqlTerm`]) and a reference to the [`ast::Value`] node of every EQL literal in the SQL statement.
+    pub literals: Vec<(EqlTerm, &'ast ast::Value)>,
 
     /// A [`HashMap`] of AST node (using [`NodeKey`] as the key) to [`Type`]. The map contains a `Type` for every node
     /// in the AST where the node type is one of: [`Statement`], [`Query`], [`Insert`], [`Delete`], [`Expr`],
     /// [`SetExpr`], [`Select`], [`SelectItem`], [`Vec<SelectItem>`], [`Function`], [`Values`], [`Value`].
     ///
-    /// [`Query`]: sqlparser::ast::Query
-    /// [`Insert`]: sqlparser::ast::Insert
-    /// [`Delete`]: sqlparser::ast::Delete
-    /// [`Expr`]: sqlparser::ast::Expr
-    /// [`SetExpr`]: sqlparser::ast::SetExpr
-    /// [`Select`]: sqlparser::ast::Select
-    /// [`SelectItem`]: sqlparser::ast::SelectItem
-    /// [`Function`]: sqlparser::ast::Function
-    /// [`FunctionArgExpr`]: sqlparser::ast::FunctionArgExpr
-    /// [`Values`]: sqlparser::ast::Values
-    /// [`Value`]: sqlparser::ast::Value
+    /// [`Query`]: sqltk::parser::ast::Query
+    /// [`Insert`]: sqltk::parser::ast::Insert
+    /// [`Delete`]: sqltk::parser::ast::Delete
+    /// [`Expr`]: sqltk::parser::ast::Expr
+    /// [`SetExpr`]: sqltk::parser::ast::SetExpr
+    /// [`Select`]: sqltk::parser::ast::Select
+    /// [`SelectItem`]: sqltk::parser::ast::SelectItem
+    /// [`Function`]: sqltk::parser::ast::Function
+    /// [`FunctionArgExpr`]: sqltk::parser::ast::FunctionArgExpr
+    /// [`Values`]: sqltk::parser::ast::Values
+    /// [`Value`]: sqltk::parser::ast::Value
     pub node_types: Arc<HashMap<NodeKey<'ast>, Type>>,
 }
 
@@ -50,7 +48,7 @@ impl<'ast> TypeCheckedStatement<'ast> {
         statement: &'ast Statement,
         projection: Projection,
         params: Vec<(Param, Value)>,
-        literals: Vec<(EqlValue, &'ast ast::Value)>,
+        literals: Vec<(EqlTerm, &'ast ast::Value)>,
         node_types: Arc<HashMap<NodeKey<'ast>, Type>>,
     ) -> Self {
         Self {
@@ -64,9 +62,7 @@ impl<'ast> TypeCheckedStatement<'ast> {
 
     /// Returns `true` if one or more SQL param placeholders in the body have an EQL type, otherwise returns `false`.
     pub fn params_contain_eql(&self) -> bool {
-        self.params
-            .iter()
-            .any(|p| matches!(p.1, Value::Eql(EqlValue(_))))
+        self.params.iter().any(|p| matches!(p.1, Value::Eql(_)))
     }
 
     /// Tests if a statement transformation is required. This works by executing all of the transformation rules but
@@ -97,11 +93,8 @@ impl<'ast> TypeCheckedStatement<'ast> {
         self.statement.apply_transform(&mut transformer)
     }
 
-    pub fn literal_values(&self) -> Vec<&sqltk::parser::ast::Value> {
-        self.literals
-            .iter()
-            .map(|(_, value)| *value)
-            .collect::<Vec<_>>()
+    pub fn literal_values(&self) -> &Vec<(EqlTerm, &'ast sqltk::parser::ast::Value)> {
+        &self.literals
     }
 
     fn dummy_encrypted_literals(&self) -> HashMap<NodeKey<'ast>, ast::Value> {
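With `literal_values` now returning `(EqlTerm, &Value)` pairs rather than bare AST values, callers can branch on the term when deciding how to handle each literal. A hedged sketch of a consumer; only the `Full` variant is attested in this diff, and `JsonPath` is invented purely for illustration:

```rust
// Illustrative stand-ins for the crate's `EqlTerm` and `sqltk::parser::ast::Value`.
#[derive(Debug)]
enum EqlTerm {
    Full(String),     // mirrors the `EqlTerm::Full` seen in the diff above
    JsonPath(String), // hypothetical variant, for illustration only
}

#[derive(Debug)]
enum AstValue {
    SingleQuotedString(String),
}

// A caller receiving `(EqlTerm, &Value)` pairs can pick a strategy per term
// instead of treating every literal identically.
fn plan_encryption(literals: &[(EqlTerm, AstValue)]) {
    for (term, value) in literals {
        match term {
            EqlTerm::Full(col) => println!("encrypt {value:?} for column {col}"),
            EqlTerm::JsonPath(col) => println!("encode selector {value:?} for {col}"),
        }
    }
}

fn main() {
    plan_encryption(&[
        (
            EqlTerm::Full("users.email".into()),
            AstValue::SingleQuotedString("alice@example.com".into()),
        ),
        (
            EqlTerm::JsonPath("users.profile".into()),
            AstValue::SingleQuotedString("$.address.city".into()),
        ),
    ]);
}
```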
diff --git a/tests/benchmark/sql/benchmark-schema.sql b/tests/benchmark/sql/benchmark-schema.sql
index 7a5be642..191da2db 100644
--- a/tests/benchmark/sql/benchmark-schema.sql
+++ b/tests/benchmark/sql/benchmark-schema.sql
@@ -19,6 +19,8 @@ SELECT eql_v2.add_column(
   'email'
 );
 
-SELECT eql_v2.encrypt();
-SELECT eql_v2.activate();
+-- SELECT eql_v2.encrypt();
+-- SELECT eql_v2.activate();
+SELECT eql_v2.migrate_config();
+SELECT eql_v2.activate_config();
 
diff --git a/tests/integration/elixir_test/lib/elixir_test/encrypted.ex b/tests/integration/elixir_test/lib/elixir_test/encrypted.ex
index 714d24f3..6e35ca86 100644
--- a/tests/integration/elixir_test/lib/elixir_test/encrypted.ex
+++ b/tests/integration/elixir_test/lib/elixir_test/encrypted.ex
@@ -2,8 +2,8 @@ defmodule ElixirTest.Encrypted do
   use Ecto.Schema
   import Ecto.Changeset
 
-  @primary_key {:id, :id, autogenerate: false}
-  schema "encrypted" do
+  @primary_key {:id, :id, autogenerate: true}
+  schema "encrypted_elixir" do
     field(:plaintext, :string)
     field(:plaintext_date, :date)
     field(:encrypted_text, :string)
@@ -19,7 +19,6 @@ defmodule ElixirTest.Encrypted do
   def changeset(encrypted, attrs) do
     encrypted
     |> cast(attrs, [
-      :id,
       :plaintext,
      :plaintext_date,
       :encrypted_text,
diff --git a/tests/integration/elixir_test/test/elixir_test_test.exs b/tests/integration/elixir_test/test/elixir_test_test.exs
index eb2bf8b3..7df41639 100644
--- a/tests/integration/elixir_test/test/elixir_test_test.exs
+++ b/tests/integration/elixir_test/test/elixir_test_test.exs
@@ -1,5 +1,5 @@
 defmodule ElixirTestTest do
-  use ExUnit.Case
+  use ExUnit.Case, async: false
   doctest ElixirTest
   alias ElixirTest.Encrypted
   alias ElixirTest.Repo
@@ -7,8 +7,6 @@ defmodule ElixirTestTest do
   setup do
     :ok = Ecto.Adapters.SQL.Sandbox.checkout(Repo)
-
-    max_id = Repo.aggregate(Encrypted, :max, :id) || 1
-    %{next_id: max_id + 1}
   end
 
   test "db connection test" do
@@ -17,9 +15,9 @@ defmodule ElixirTestTest do
     assert result.rows == [[1]]
   end
 
-  test "plaintext save and load", %{next_id: next_id} do
+  test "plaintext save and load" do
     {:ok, result} =
-      %Encrypted{id: next_id, plaintext: "plaintext content", plaintext_date: ~D[2025-06-02]}
+      %Encrypted{plaintext: "plaintext content", plaintext_date: ~D[2025-06-02]}
       |> Repo.insert()
 
     fetched = Encrypted |> Repo.get(result.id)
@@ -28,9 +26,9 @@ defmodule ElixirTestTest do
     assert fetched.plaintext_date == ~D[2025-06-02]
   end
 
-  test "encrypted text save and load", %{next_id: next_id} do
+  test "encrypted text save and load" do
     {:ok, result} =
-      %Encrypted{id: next_id, encrypted_text: "encrypted text content"}
+      %Encrypted{encrypted_text: "encrypted text content"}
       |> Repo.insert()
 
     fetched = Encrypted |> Repo.get(result.id)
@@ -38,10 +36,9 @@ defmodule ElixirTestTest do
     assert fetched.encrypted_text == "encrypted text content"
   end
 
-  test "encrypted fields save and load", %{next_id: next_id} do
+  test "encrypted fields save and load" do
     {:ok, result} =
       %Encrypted{
-        id: next_id,
         encrypted_bool: false,
         encrypted_int2: 2,
         encrypted_int4: 4,
@@ -52,6 +49,7 @@ defmodule ElixirTestTest do
       }
       |> Repo.insert()
+
     fetched = Encrypted |> Repo.get(result.id)
 
     assert !fetched.encrypted_bool
@@ -63,136 +61,129 @@ defmodule ElixirTestTest do
     assert fetched.encrypted_jsonb == %{"top" => %{"array" => [1, 2, 3]}}
   end
 
-  test "find by exact text", %{next_id: next_id} do
+  test "find by exact text" do
     {2, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_text: "encrypted text content"},
-        %{id: next_id + 1, encrypted_text: "some other encrypted text"}
+        %{encrypted_text: "encrypted text content"},
+        %{encrypted_text: "some other encrypted text"}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: e.encrypted_text == "encrypted text content",
         select: [e.encrypted_text]
-      )
 
     fetched = Repo.all(q)
     assert Enum.at(fetched, 0) == ["encrypted text content"]
   end
 
-  test "find by text match", %{next_id: next_id} do
+  test "find by text match" do
     {2, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_text: "encrypted text content"},
-        %{id: next_id + 1, encrypted_text: "some other encrypted text"}
+        %{encrypted_text: "encrypted text content"},
+        %{encrypted_text: "some other encrypted text"}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: like(e.encrypted_text, "text cont"),
         select: [e.encrypted_text]
-      )
 
     fetched = Repo.all(q)
     assert Enum.at(fetched, 0) == ["encrypted text content"]
   end
 
-  test "find by float value - currently not supported", %{next_id: next_id} do
+  test "find by float value - currently not supported" do
     {2, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_float8: 0.0},
-        %{id: next_id + 1, encrypted_float8: 7.5}
+        %{encrypted_float8: 0.0},
+        %{encrypted_float8: 7.5}
      ])
 
     # Ecto appends an explicit cast to `7.5`, making it `7.5::float`, which causes
     # the "operator does not exist" error
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: e.encrypted_float8 == 7.5,
         select: [e.id, e.encrypted_float8]
-      )
 
     assert_raise(Postgrex.Error, fn -> Repo.all(q) end)
   end
 
-  test "find by float value", %{next_id: next_id} do
+  test "find by float value" do
     {2, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_float8: 0.0},
-        %{id: next_id + 1, encrypted_float8: 7.5}
+        %{encrypted_float8: 0.0},
+        %{encrypted_float8: 7.5}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: fragment("? = 7.5", e.encrypted_float8),
-        select: [e.id, e.encrypted_float8]
-      )
+        select: [e.encrypted_float8]
 
     fetched = Repo.all(q)
-    assert Enum.at(fetched, 0) == [next_id + 1, 7.5]
+    assert Enum.at(fetched, 0) == [7.5]
   end
 
-  test "find by float value gt", %{next_id: next_id} do
+  test "find by float value gt" do
     {2, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_float8: 0.0},
-        %{id: next_id + 1, encrypted_float8: 7.5}
+        %{encrypted_float8: 0.0},
+        %{encrypted_float8: 7.5}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: fragment("? > 3.0", e.encrypted_float8),
-        select: [e.id, e.encrypted_float8]
-      )
+        select: [e.encrypted_float8]
 
     fetched = Repo.all(q)
-    assert Enum.at(fetched, 0) == [next_id + 1, 7.5]
+    assert Enum.at(fetched, 0) == [7.5]
   end
 
-  test "order by integer", %{next_id: next_id} do
+  test "order by integer" do
     {3, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_int2: 7},
-        %{id: next_id + 1, encrypted_int2: 9},
-        %{id: next_id + 2, encrypted_int2: 0}
+        %{encrypted_int2: 7},
+        %{encrypted_int2: 9},
+        %{encrypted_int2: 0}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         order_by: e.encrypted_int2,
         select: [e.encrypted_int2]
-      )
 
     fetched = Repo.all(q) |> List.flatten()
     assert fetched == [0, 7, 9]
   end
 
-  test "find by text and float", %{next_id: next_id} do
+  test "find by text and float" do
     {3, _} =
       Encrypted
       |> Repo.insert_all([
-        %{id: next_id, encrypted_text: "encrypted text content", encrypted_float8: 1.0},
-        %{id: next_id + 1, encrypted_text: "encrypted text content", encrypted_float8: 3.0},
-        %{id: next_id + 2, encrypted_text: "some other encrypted text", encrypted_float8: 5.0}
+        %{encrypted_text: "encrypted text content", encrypted_float8: 1.0},
+        %{encrypted_text: "encrypted text content", encrypted_float8: 3.0},
+        %{encrypted_text: "some other encrypted text", encrypted_float8: 5.0}
       ])
 
     q =
-      from(e in "encrypted",
+      from e in Encrypted,
         where: like(e.encrypted_text, "text cont"),
         where: fragment("? > 2.0", e.encrypted_float8),
         select: [e.encrypted_text, e.encrypted_float8]
-      )
 
     fetched = Repo.all(q)
diff --git a/tests/sql/schema.sql b/tests/sql/schema.sql
index 8f98b933..ef2ca0f3 100644
--- a/tests/sql/schema.sql
+++ b/tests/sql/schema.sql
@@ -160,5 +160,151 @@ SELECT eql_v2.add_search_config(
 
 SELECT eql_v2.add_encrypted_constraint('encrypted', 'encrypted_text');
 
-SELECT eql_v2.encrypt();
-SELECT eql_v2.activate();
+SELECT eql_v2.migrate_config();
+SELECT eql_v2.activate_config();
+
+-- This is the exact same schema as above, but using a database-generated primary key.
+-- It is required to remove flake from the Elixir integration test suite.
+-- TODO: port the rest of our integration tests to this schema.
+DROP TABLE IF EXISTS encrypted_elixir;
+CREATE TABLE encrypted_elixir (
+    id serial,
+    plaintext text,
+    plaintext_date date,
+    plaintext_domain domain_type_with_check,
+    encrypted_text eql_v2_encrypted,
+    encrypted_bool eql_v2_encrypted,
+    encrypted_int2 eql_v2_encrypted,
+    encrypted_int4 eql_v2_encrypted,
+    encrypted_int8 eql_v2_encrypted,
+    encrypted_float8 eql_v2_encrypted,
+    encrypted_date eql_v2_encrypted,
+    encrypted_jsonb eql_v2_encrypted,
+    PRIMARY KEY(id)
+);
+
+DROP TABLE IF EXISTS unconfigured_elixir;
+CREATE TABLE unconfigured_elixir (
+    id serial,
+    encrypted_unconfigured eql_v2_encrypted,
+    PRIMARY KEY(id)
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_text',
+  'unique',
+  'text'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_text',
+  'match',
+  'text'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_text',
+  'ore',
+  'text'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_bool',
+  'unique',
+  'boolean'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_bool',
+  'ore',
+  'boolean'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int2',
+  'unique',
+  'small_int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int2',
+  'ore',
+  'small_int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int4',
+  'unique',
+  'int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int4',
+  'ore',
+  'int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int8',
+  'unique',
+  'big_int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_int8',
+  'ore',
+  'big_int'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_float8',
+  'unique',
+  'double'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_float8',
+  'ore',
+  'double'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_date',
+  'unique',
+  'date'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_date',
+  'ore',
+  'date'
+);
+
+SELECT eql_v2.add_search_config(
+  'encrypted_elixir',
+  'encrypted_jsonb',
+  'ste_vec',
+  'jsonb',
+  '{"prefix": "encrypted/encrypted_jsonb"}'
+);
+
+SELECT eql_v2.add_encrypted_constraint('encrypted_elixir', 'encrypted_text');
+
+SELECT eql_v2.migrate_config();
+SELECT eql_v2.activate_config();
\ No newline at end of file