diff --git a/Cargo.lock b/Cargo.lock index 37a2d04..11ea648 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "sqlite3-editor" -version = "1.0.190" +version = "1.0.191" dependencies = [ "base64", "clap", diff --git a/Cargo.toml b/Cargo.toml index d06881d..a9c0def 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlite3-editor" -version = "1.0.190" +version = "1.0.191" edition = "2021" [features] diff --git a/src/code_lens.rs b/src/code_lens.rs index 596cf0c..9246a6f 100644 --- a/src/code_lens.rs +++ b/src/code_lens.rs @@ -3,6 +3,7 @@ use sqlparser::keywords::Keyword; use sqlparser::tokenizer::{Token, Word}; use crate::keywords::START_OF_STATEMENT_KEYWORDS_UNSUPPORTED_BY_SQLPARSER; +use crate::list_placeholders::{list_placeholders, Placeholder}; use crate::parse_cte::parse_cte; use crate::split_statements::{get_text_range, split_sqlite_statements}; use crate::sqlite3::escape_sql_identifier; @@ -24,6 +25,8 @@ pub struct CodeLens { pub start: ZeroIndexedLocation, pub end: ZeroIndexedLocation, pub stmt_executed: String, + pub cte_identifier: Option, + pub placeholders: Vec, } /// Returns a list of code lenses for the given SQL input. @@ -43,22 +46,33 @@ pub fn code_lens(sql: &str) -> Vec { if let Some(cte) = cte { for entry in cte.entries { let with_clause = get_text_range(&lines, &stmt.real_start, &cte.body_start); + let cte_ident = escape_sql_identifier(&get_text_range(&lines, &entry.ident_start, &entry.ident_end)); let select_stmt = format!( "{}SELECT * FROM {}", if with_clause.ends_with(' ') { "" } else { " " }, - escape_sql_identifier(&get_text_range(&lines, &entry.ident_start, &entry.ident_end)) + cte_ident.clone() ); + let stmt_executed = with_clause + &select_stmt; + let Ok((cte_stmt, _)) = split_sqlite_statements(&stmt_executed) else { + continue; + }; + let Some(cte_stmt) = cte_stmt.first() else { + continue; + }; code_lens.push(CodeLens { kind: CodeLensKind::Select, - stmt_executed: with_clause + &select_stmt, + placeholders: list_placeholders(cte_stmt), + stmt_executed, start: entry.ident_start, end: entry.ident_end, + cte_identifier: Some(cte_ident), }) } cte_end = cte.body_start; } let mut kind: Option = None; + let placeholders = list_placeholders(&stmt); for token in stmt.real_tokens { if token.start < cte_end { continue; @@ -120,6 +134,8 @@ pub fn code_lens(sql: &str) -> Vec { start: stmt.real_start, end: stmt.real_end, stmt_executed: stmt.real_text, + cte_identifier: None, + placeholders, }) } } diff --git a/src/code_lens_test.rs b/src/code_lens_test.rs index 81db619..1688700 100644 --- a/src/code_lens_test.rs +++ b/src/code_lens_test.rs @@ -1,5 +1,6 @@ use crate::{ code_lens::{code_lens, CodeLens, CodeLensKind}, + list_placeholders::{Placeholder, PlaceholderRange}, tokenize::ZeroIndexedLocation, }; @@ -13,18 +14,24 @@ fn test_select() { start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 9), stmt_executed: "SELECT 1;".to_owned(), + cte_identifier: None, + placeholders: vec![], }, CodeLens { kind: CodeLensKind::Select, start: ZeroIndexedLocation::new(0, 10), end: ZeroIndexedLocation::new(0, 19), stmt_executed: "SELECT 2;".to_owned(), + cte_identifier: None, + placeholders: vec![], }, CodeLens { kind: CodeLensKind::Select, start: ZeroIndexedLocation::new(0, 20), end: ZeroIndexedLocation::new(0, 30), stmt_executed: "VALUES(3);".to_owned(), + cte_identifier: None, + placeholders: vec![], }, ] ); @@ -40,12 +47,16 @@ fn test_with_clause() { start: ZeroIndexedLocation::new(0, 5), end: ZeroIndexedLocation::new(0, 6), stmt_executed: "WITH a AS (SELECT 1) SELECT * FROM `a`".to_owned(), + cte_identifier: Some("`a`".to_owned()), + placeholders: vec![], }, CodeLens { kind: CodeLensKind::Select, start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 30), stmt_executed: "WITH a AS (SELECT 1) SELECT 2;".to_owned(), + cte_identifier: None, + placeholders: vec![], }, ] ); @@ -66,12 +77,16 @@ fn test_other() { start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 13), stmt_executed: "DROP TABLE t;".to_owned(), + cte_identifier: None, + placeholders: vec![], }, CodeLens { kind: CodeLensKind::Other, start: ZeroIndexedLocation::new(0, 14), end: ZeroIndexedLocation::new(0, 31), stmt_executed: "ATTACH 'db' as db".to_owned(), + cte_identifier: None, + placeholders: vec![], } ] ); @@ -86,6 +101,8 @@ fn test_begin_end() { start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 31), stmt_executed: "BEGIN; SELECT 1; SELECT 2; END;".to_owned(), + cte_identifier: None, + placeholders: vec![], }] ); } @@ -99,6 +116,8 @@ fn test_pragma() { start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 22), stmt_executed: "PRAGMA analysis_limit;".to_owned(), + cte_identifier: None, + placeholders: vec![], }] ); } @@ -112,6 +131,8 @@ fn test_vacuum() { start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 7), stmt_executed: "VACUUM;".to_owned(), + cte_identifier: None, + placeholders: vec![], }] ); } @@ -126,13 +147,77 @@ fn test_with_update() { start: ZeroIndexedLocation::new(0, 5), end: ZeroIndexedLocation::new(0, 6), stmt_executed: "WITH x AS (SELECT 1) SELECT * FROM `x`".to_owned(), + cte_identifier: Some("`x`".to_owned()), + placeholders: vec![], }, CodeLens { kind: CodeLensKind::Other, start: ZeroIndexedLocation::new(0, 0), end: ZeroIndexedLocation::new(0, 40), stmt_executed: "WITH x AS (SELECT 1) UPDATE t SET a = 1;".to_owned(), + cte_identifier: None, + placeholders: vec![], } ] ); } + +#[test] +fn test_placeholders() { + assert_eq!( + code_lens("SELECT ?, :a;"), + [CodeLens { + kind: CodeLensKind::Select, + start: ZeroIndexedLocation::new(0, 0), + end: ZeroIndexedLocation::new(0, 13), + stmt_executed: "SELECT ?, :a;".to_owned(), + cte_identifier: None, + placeholders: vec![ + Placeholder { + name: None, + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 7), + end: ZeroIndexedLocation::new(0, 8), + }], + }, + Placeholder { + name: Some(":a".to_owned()), + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 10), + end: ZeroIndexedLocation::new(0, 12), + }] + } + ], + },] + ); +} + +#[test] +fn test_placeholders_with_prefix() { + assert_eq!( + code_lens("-- comment\n SELECT ?, :a;"), + [CodeLens { + kind: CodeLensKind::Select, + start: ZeroIndexedLocation::new(1, 1), + end: ZeroIndexedLocation::new(1, 14), + stmt_executed: "SELECT ?, :a;".to_owned(), + cte_identifier: None, + placeholders: vec![ + Placeholder { + name: None, + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 7), + end: ZeroIndexedLocation::new(0, 8), + }], + }, + Placeholder { + name: Some(":a".to_owned()), + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 10), + end: ZeroIndexedLocation::new(0, 12), + }] + } + ], + },] + ); +} diff --git a/src/error.rs b/src/error.rs index 014456b..5fa3ce5 100644 --- a/src/error.rs +++ b/src/error.rs @@ -37,6 +37,11 @@ pub enum Error { query: String, params: Vec, }, + InvalidNumberOfParameters { + query: String, + placeholders: Vec>, + params: Vec, + }, Other { message: String, query: Option, @@ -166,6 +171,23 @@ impl std::fmt::Display for Error { ) } } + Self::InvalidNumberOfParameters { + query, + placeholders, + params, + } => { + write!( + f, + "Invalid number of parameters.\n{}\nParameters: {}\nPlaceholders: [{}]", + Self::format_query(query), + Self::format_params(params), + placeholders + .iter() + .map(|v| v.clone().unwrap_or("?".to_owned())) + .collect::>() + .join(", ") + ) + } Self::Other { message, query, params } => { write!( f, diff --git a/src/list_placeholders.rs b/src/list_placeholders.rs new file mode 100644 index 0000000..076c987 --- /dev/null +++ b/src/list_placeholders.rs @@ -0,0 +1,145 @@ +use lazy_static::lazy_static; +use serde::{Deserialize, Serialize}; +use sqlparser::{ + keywords::Keyword, + tokenizer::{Token, Word}, +}; + +use crate::{ + split_statements::SplittedStatement, + tokenize::{TokenWithRangeLocation, ZeroIndexedLocation}, +}; + +lazy_static! { + static ref QUESTION_NUMBER: regex::Regex = regex::Regex::new(r"^\?\d+$").unwrap(); + static ref NUMBER: regex::Regex = regex::Regex::new(r"^.+$").unwrap(); +} + +#[derive(ts_rs::TS, Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[ts(export)] +pub struct PlaceholderRange { + pub start: ZeroIndexedLocation, + pub end: ZeroIndexedLocation, +} + +impl PlaceholderRange { + fn new(token: &TokenWithRangeLocation) -> Self { + Self { + start: token.start.clone(), + end: token.end.clone(), + } + } +} + +#[derive(ts_rs::TS, Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[ts(export)] +pub struct Placeholder { + /// None for "?", Some for "?NNN", ":VVV", etc. + pub name: Option, + /// Empty for placeholders implicitly inserted by "?NNN". + pub ranges_relative_to_stmt: Vec, +} + +pub fn list_placeholders(stmt: &SplittedStatement) -> Vec { + let mut previous_colon_or_at_sign: Option = None; + let mut result: Vec = vec![]; + let mut is_previous_placeholder_unfinished = false; + + for token in &stmt.real_tokens { + if is_previous_placeholder_unfinished { + match &token.token { + Token::Word(Word { + value: s, + quote_style: None, + keyword: Keyword::NoKeyword, + }) + | Token::Number(s, /* "L" suffix */ false) => { + // Extend the last placeholder + result.last_mut().unwrap().name = + Some(result.last().unwrap().name.to_owned().unwrap() + s.as_str()); + result + .last_mut() + .unwrap() + .ranges_relative_to_stmt + .last_mut() + .unwrap() + .end = token.end.clone(); + } + _ => { + is_previous_placeholder_unfinished = false; + } + } + } + match &token.token { + // https://www.sqlite.org/c3ref/bind_blob.html + // ? + Token::Placeholder(p) if p == "?" => { + result.push(Placeholder { + name: None, + ranges_relative_to_stmt: vec![PlaceholderRange::new(token)], + }); + } + // ?NNN + Token::Placeholder(s) if QUESTION_NUMBER.is_match(s) => { + if let Ok(n) = s[1..].parse::().map(|v| v - 1) { + while result.len() < n + 1 { + result.push(Placeholder { + name: None, + ranges_relative_to_stmt: vec![], + }); + } + if result[n].name.is_none() { + result[n].name = Some(s.to_owned()); + result[n].ranges_relative_to_stmt.push(PlaceholderRange::new(token)); + } + } + } + // :VVV, @VVV, $VVV + Token::Word(Word { + value: s, + quote_style: None, + keyword: Keyword::NoKeyword, + }) => { + if s.starts_with(":") || s.starts_with("@") || s.starts_with("$") { + result.push(Placeholder { + name: Some(s.to_owned()), + ranges_relative_to_stmt: vec![PlaceholderRange::new(token)], + }); + } else if let Some(sign) = previous_colon_or_at_sign { + let mut range = PlaceholderRange::new(token); + range.start.column = range.start.column.saturating_sub(1); + result.push(Placeholder { + name: Some(sign + s.as_str()), + ranges_relative_to_stmt: vec![range], + }); + } + is_previous_placeholder_unfinished = true; + } + Token::Number(s, /* "L" suffix */ false) => { + if let Some(sign) = previous_colon_or_at_sign { + let mut range = PlaceholderRange::new(token); + range.start.column = range.start.column.saturating_sub(1); + result.push(Placeholder { + name: Some(sign + s.as_str()), + ranges_relative_to_stmt: vec![range], + }); + } + is_previous_placeholder_unfinished = true; + } + _ => {} + } + previous_colon_or_at_sign = match token.token { + Token::Colon => Some(":".to_owned()), + Token::AtSign => Some("@".to_owned()), + _ => None, + }; + } + + for placeholder in result.iter_mut() { + for range in placeholder.ranges_relative_to_stmt.iter_mut() { + range.start -= &stmt.real_start; + range.end -= &stmt.real_start; + } + } + result +} diff --git a/src/list_placeholders_test.rs b/src/list_placeholders_test.rs new file mode 100644 index 0000000..fa91987 --- /dev/null +++ b/src/list_placeholders_test.rs @@ -0,0 +1,60 @@ +use crate::{ + list_placeholders::{list_placeholders, Placeholder, PlaceholderRange}, + split_statements::split_sqlite_statements, + tokenize::ZeroIndexedLocation, +}; + +fn list_placeholders_with_sqlite_api(sql: &str) -> rusqlite::Result>> { + let conn = rusqlite::Connection::open_in_memory()?; + let stmt = conn.prepare(sql)?; + Ok((1..=stmt.parameter_count()) + .map(|i| stmt.parameter_name(i).map(|v| v.to_owned())) + .collect()) +} + +fn compare(sql: &str) { + let expected = list_placeholders_with_sqlite_api(sql).unwrap(); + let actual = list_placeholders(&split_sqlite_statements(sql).unwrap().0[0]); + assert_eq!(expected.len(), actual.len()); + for i in 0..expected.len() { + assert_eq!(actual[i].name, expected[i]); + } +} + +#[test] +pub fn test_list_placeholders_without_comparing_to_sqlite_api_output() { + assert_eq!( + list_placeholders(&split_sqlite_statements("SELECT ?, $a").unwrap().0[0]), + vec![ + Placeholder { + name: None, + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 7), + end: ZeroIndexedLocation::new(0, 8), + }], + }, + Placeholder { + name: Some("$a".to_owned()), + ranges_relative_to_stmt: vec![PlaceholderRange { + start: ZeroIndexedLocation::new(0, 10), + end: ZeroIndexedLocation::new(0, 12), + }], + }, + ] + ); +} + +#[test] +pub fn test_no_placeholder() { + compare("SELECT 1"); +} + +#[test] +pub fn test_placeholder_reuse() { + compare("SELECT ?, ?, @a, ?2, ?3"); +} + +#[test] +pub fn test_everything() { + compare("WITH x AS (SELECT @a) SELECT ?, ?, ?10, :10, @10, $10, :aa, @aa, $aa, ?12, ?, :1a1"); +} diff --git a/src/main.rs b/src/main.rs index 1d61b12..f59bd03 100644 --- a/src/main.rs +++ b/src/main.rs @@ -35,6 +35,9 @@ mod export_test; mod find; #[cfg(test)] mod import_test; +mod list_placeholders; +#[cfg(test)] +mod list_placeholders_test; mod literal; #[cfg(test)] mod literal_test; diff --git a/src/sqlite3.rs b/src/sqlite3.rs index bb505b4..326fbac 100644 --- a/src/sqlite3.rs +++ b/src/sqlite3.rs @@ -358,6 +358,9 @@ pub struct QueryOptions { /// A statement that is executed before the main statement. It shares the transaction and the parameters with the main statement, and requires ExecMode::ReadWrite. pub pre_stmt: Option, + + /// Checks the list of placeholders returned by list_placeholders.rs. + pub check_placeholders: Option>>, } #[derive(ts_rs::TS, Serialize, Debug, Clone)] @@ -544,6 +547,31 @@ impl SQLite3 { .or_else(|err| Error::new_query_error(err, query, params))?; // Bind parameters + if params.len() != stmt.parameter_count() { + return Err(Error::InvalidNumberOfParameters { + query: query.to_string(), + placeholders: (1..=stmt.parameter_count()) + .map(|i| stmt.parameter_name(i).map(|v| v.to_owned())) + .collect(), + params: params.to_vec(), + }); + } + if let Some(placeholders) = options.check_placeholders { + if placeholders.len() != stmt.parameter_count() + || (0..placeholders.len()).any(|i| placeholders[i].as_deref() != stmt.parameter_name(i + 1)) + { + return Error::new_other_error( + format!( + "Failed to parse the SQL statement: {placeholders:?} != {:?}", + (0..placeholders.len()) + .map(|i| stmt.parameter_name(i + 1)) + .collect::>() + ), + Some(query.to_owned()), + Some(params), + ); + } + } for (i, param) in params.iter().enumerate() { stmt.raw_bind_parameter(i + 1, param) .or_else(|err| Error::new_query_error(err, query, params))?; diff --git a/src/sqlite3_test.rs b/src/sqlite3_test.rs index 76f1d62..5b62cb9 100644 --- a/src/sqlite3_test.rs +++ b/src/sqlite3_test.rs @@ -862,6 +862,29 @@ fn test_query_error() { ); } +#[test] +fn test_invalid_number_of_parameters() { + let mut db = SQLite3::connect(":memory:", false, &None::<&str>).unwrap(); + let mut w = Cursor::new(Vec::::new()); + assert_eq!( + format!( + "{}", + db.handle( + &mut w, + "SELECT ?, ?, @a", + &[Literal::I64(1)], + crate::request_type::QueryMode::ReadOnly, + QueryOptions::default(), + ) + .unwrap_err() + ), + "Invalid number of parameters. +Query: SELECT ?, ?, @a +Parameters: [1] +Placeholders: [?, ?, @a]", + ); +} + #[test] fn test_transaction_success() { let mut db = SQLite3::connect(":memory:", false, &None::<&str>).unwrap(); diff --git a/src/tokenize.rs b/src/tokenize.rs index a78c1e8..d84e889 100644 --- a/src/tokenize.rs +++ b/src/tokenize.rs @@ -12,6 +12,30 @@ pub struct ZeroIndexedLocation { pub column: usize, } +impl std::ops::Sub<&Self> for ZeroIndexedLocation { + type Output = Self; + fn sub(self, rhs: &Self) -> Self::Output { + Self { + line: self.line.saturating_sub(rhs.line), + column: if self.line == rhs.line { + self.column.saturating_sub(rhs.column) + } else { + self.column + }, + } + } +} + +impl std::ops::SubAssign<&Self> for ZeroIndexedLocation { + fn sub_assign(&mut self, rhs: &Self) { + if self.line == rhs.line { + self.column = self.column.saturating_sub(rhs.column); + } + // Update line after comparing it to rhs.line + self.line = self.line.saturating_sub(rhs.line); + } +} + impl std::cmp::Ord for ZeroIndexedLocation { fn cmp(&self, other: &Self) -> std::cmp::Ordering { self.line.cmp(&other.line).then_with(|| self.column.cmp(&other.column))