diff --git a/Cargo.lock b/Cargo.lock index 083d0c0..fc38bd6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -564,7 +564,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "sqlite3-editor" -version = "1.0.194" +version = "1.0.196" dependencies = [ "base64", "clap", diff --git a/Cargo.toml b/Cargo.toml index 32f7036..83e2d52 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlite3-editor" -version = "1.0.194" +version = "1.0.196" edition = "2021" [features] diff --git a/src/cache/pager.rs b/src/cache/pager.rs index cab63b7..44273e5 100644 --- a/src/cache/pager.rs +++ b/src/cache/pager.rs @@ -1,6 +1,7 @@ use std::{rc::Rc, time::Duration}; use crate::{ + columnar_buffer::ColumnarBuffer, error::Error, literal::Literal, sqlite3::{write_value_ref_into_msgpack, InvalidUTF8}, @@ -141,9 +142,9 @@ impl Pager { params[len - 1] = Literal::I64(offset_with_margin.try_into().unwrap()); // Forward run: Fetch the queried area and cache records after that - let mut col_buf: Vec>; - let mut n_rows: u32 = 0; + let mut col_buf = ColumnarBuffer::default(); let columns: Vec; + let mut n_rows: u32 = 0; let mut end_margin_size = 0; { // Prepare @@ -157,16 +158,7 @@ impl Pager { .or_else(|err| Error::new_query_error(err, query, ¶ms))?; } - // List columns - columns = stmt - .column_names() - .into_iter() - .map(|v| v.to_owned()) - .collect::>(); - col_buf = vec![vec![]; columns.len()]; - let cache_size_prev = cache_entry.total_size_bytes(); - cache_entry.set_columns_if_not_set_yet(columns.clone()); // Fetch records let mut current_offset = offset_with_margin; @@ -201,15 +193,23 @@ impl Pager { } let mut cache_record = vec![]; - for (i, col_buf_i) in col_buf.iter_mut().enumerate() { - let mut w = vec![]; - write_value_ref_into_msgpack(&mut w, row.get_ref_unwrap(i), &mut on_invalid_utf8) - .expect("Failed to write msgpack"); - if !is_margin { - col_buf_i.extend(&w); + // NOTE: We need to call `stmt.column_count()` after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53), + // but since the borrow checker prevents us from calling `stmt.column_count()` while `row` is alive, + // we rely on `rusqlite::Error::InvalidColumnIndex` returned from `row.get_ref(i)` to check the number of columns. + for i in 0usize..=usize::MAX { + match row.get_ref(i) { + Ok(value) => { + let mut w = vec![]; + write_value_ref_into_msgpack(&mut w, value, &mut on_invalid_utf8) + .expect("Failed to write msgpack"); + if !is_margin { + col_buf.get_column(i).extend(&w); + } + cache_record.push(w); + } + Err(rusqlite::Error::InvalidColumnIndex(_)) => break, + Err(err) => return Error::new_query_error(err, query, ¶ms), } - - cache_record.push(w); } cache_entry.insert(current_offset, &cache_record); @@ -227,6 +227,16 @@ impl Pager { Err(err) => Error::new_query_error(err, query, ¶ms)?, } } + + drop(rows); + + // NOTE: We need to call `stmt.column_names()` after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53) + columns = stmt + .column_names() + .into_iter() + .map(|v| v.to_owned()) + .collect::>(); + cache_entry.set_columns_if_not_set_yet(columns.clone()); } // Backward run: cache `end_margin_size` records before the queried area @@ -243,7 +253,7 @@ impl Pager { .prepare(query) .or_else(|err| Error::new_query_error(err, query, ¶ms))?; - // Bind parameters + // Bind parametersnew_other_error for (i, param) in params.iter().enumerate() { stmt.raw_bind_parameter(i + 1, param) .or_else(|err| Error::new_query_error(err, query, ¶ms))?; @@ -258,8 +268,18 @@ impl Pager { let mut cache_record = vec![]; for i in 0..columns.len() { let mut w = vec![]; - write_value_ref_into_msgpack(&mut w, row.get_ref_unwrap(i), &mut on_invalid_utf8) - .expect("Failed to write msgpack"); + write_value_ref_into_msgpack( + &mut w, + row.get_ref(i).or_else(|err| { + Error::new_other_error( + format!("Error while caching backwards, possibly due to the database schema being updated during the process: {err:?}"), + Some(query.to_string()), + Some(¶ms), + ) + })?, + &mut on_invalid_utf8, + ) + .expect("Failed to write msgpack"); cache_record.push(w); } cache_entry.insert(current_offset, &cache_record); @@ -275,7 +295,11 @@ impl Pager { } } - Ok(Some(Records::new(col_buf, n_rows, Rc::new(columns)))) + Ok(Some(Records::new( + col_buf.finish(columns.len()), + n_rows, + Rc::new(columns), + ))) } } diff --git a/src/cache/pager_cache.rs b/src/cache/pager_cache.rs index 16a4478..5d61398 100644 --- a/src/cache/pager_cache.rs +++ b/src/cache/pager_cache.rs @@ -13,7 +13,7 @@ impl PagerCache { Self { cache: vec![] } } - /// Returns the cache entry that is associated to (query, params). + /// Returns the cache entry that is associated to (query, params[:-2]). /// Inserts an entry if it does not exist. pub(super) fn entry(&mut self, query: &str, params: &[Literal]) -> Rc> { let params = ¶ms[0..(params.len() - 2)]; diff --git a/src/column_origin.rs b/src/column_origin.rs index fc30faa..56bd17a 100644 --- a/src/column_origin.rs +++ b/src/column_origin.rs @@ -1,6 +1,6 @@ use rusqlite::ffi::{ sqlite3, sqlite3_column_count, sqlite3_column_database_name, sqlite3_column_name, sqlite3_column_origin_name, - sqlite3_column_table_name, sqlite3_errmsg, sqlite3_finalize, sqlite3_prepare_v2, sqlite3_stmt, + sqlite3_column_table_name, sqlite3_errmsg, sqlite3_finalize, sqlite3_prepare_v2, sqlite3_step, sqlite3_stmt, }; use serde::{Deserialize, Serialize}; use std::collections::HashMap; @@ -49,6 +49,9 @@ pub fn column_origin(db: *mut sqlite3, query: &str) -> Result::new(); + // NOTE: We need to call `sqlite3_column_count()` and `sqlite3_column_name()` after `sqlite3_step()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53) + unsafe { sqlite3_step(stmt) }; + let column_count: usize = unsafe { sqlite3_column_count(stmt).try_into().unwrap() }; for i in 0..column_count { let Some(column_name) = ptr_to_string(unsafe { sqlite3_column_name(stmt, i.try_into().unwrap()) }) else { diff --git a/src/columnar_buffer.rs b/src/columnar_buffer.rs new file mode 100644 index 0000000..7566cc2 --- /dev/null +++ b/src/columnar_buffer.rs @@ -0,0 +1,20 @@ +#[derive(Default)] +pub struct ColumnarBuffer { + // column index -> a msgpack containing all the values in the column + columns: Vec>, +} + +impl ColumnarBuffer { + pub fn get_column(&mut self, col_index: usize) -> &mut Vec { + while self.columns.len() <= col_index { + self.columns.resize(col_index + 1, vec![]); + } + &mut self.columns[col_index] + } + + pub fn finish(mut self, len: usize) -> Vec> { + // resize for when there are no records + self.columns.resize(len, vec![]); + self.columns + } +} diff --git a/src/export.rs b/src/export.rs index e173808..3b44fad 100644 --- a/src/export.rs +++ b/src/export.rs @@ -65,6 +65,7 @@ pub fn export_csv( Error::new_other_error("The delimiter needs to be a single character.", None, None)?; } + // TODO: `stmt.column_count()` and `stmt.column_names()` should be called after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53). let column_count = stmt.column_count(); let column_names = stmt .column_names() @@ -112,6 +113,7 @@ pub fn export_json( .prepare(query) .or_else(|err| Error::new_query_error(err, query, &[]))?; + // TODO: `stmt.column_names()` should be called after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53). let column_names = stmt .column_names() .into_iter() @@ -273,6 +275,7 @@ fn write_table_data( .prepare(query) .or_else(|err| Error::new_query_error(err, query, &[]))?; + // TODO: `stmt.column_count()` and `stmt.column_names()` should be called after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53). let column_count = stmt.column_count(); let column_names = stmt .column_names() diff --git a/src/literal.rs b/src/literal.rs index afdf1f2..c10257e 100644 --- a/src/literal.rs +++ b/src/literal.rs @@ -22,7 +22,7 @@ impl<'de> Deserialize<'de> for Blob { { struct MyBlobVisitor; - impl<'de> serde::de::Visitor<'de> for MyBlobVisitor { + impl serde::de::Visitor<'_> for MyBlobVisitor { type Value = Blob; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { diff --git a/src/main.rs b/src/main.rs index ee9832f..cc2c30b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ use std::{ str::FromStr, sync::{Arc, Mutex}, }; +mod columnar_buffer; mod completion; #[cfg(test)] mod completion_test; diff --git a/src/sqlite3.rs b/src/sqlite3.rs index 326fbac..400679c 100644 --- a/src/sqlite3.rs +++ b/src/sqlite3.rs @@ -1,6 +1,7 @@ use crate::{ cache::{Pager, Records}, column_origin::{column_origin, ColumnOrigin}, + columnar_buffer::ColumnarBuffer, find::{ find_widget_compare, find_widget_compare_c, find_widget_compare_r, find_widget_compare_r_c, find_widget_compare_r_w, find_widget_compare_r_w_c, find_widget_compare_w, find_widget_compare_w_c, @@ -577,26 +578,28 @@ impl SQLite3 { .or_else(|err| Error::new_query_error(err, query, params))?; } - // List columns - let columns = stmt - .column_names() - .into_iter() - .map(|v| v.to_owned()) - .collect::>(); - // Fetch records - let mut col_buf: Vec> = vec![vec![]; columns.len()]; + let mut col_buf = ColumnarBuffer::default(); let mut n_rows: u32 = 0; let mut rows = stmt.raw_query(); loop { match rows.next() { Ok(Some(row)) => { - for (i, col_buf_i) in col_buf.iter_mut().enumerate() { - write_value_ref_into_msgpack(col_buf_i, row.get_ref_unwrap(i), |err| { - warnings.push(err.with(query)) - }) - .expect("Failed to write msgpack"); + // NOTE: We need to call `stmt.column_count()` after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53), + // but since the borrow checker prevents us from calling `stmt.column_count()` while `row` is alive, + // we rely on `rusqlite::Error::InvalidColumnIndex` returned from `row.get_ref(i)` to check the number of columns. + for i in 0usize..=usize::MAX { + match row.get_ref(i) { + Ok(value) => { + write_value_ref_into_msgpack(&mut col_buf.get_column(i), value, |err| { + warnings.push(err.with(query)) + }) + .expect("Failed to write msgpack"); + } + Err(rusqlite::Error::InvalidColumnIndex(_)) => break, + Err(err) => return Error::new_query_error(err, query, params), + } } n_rows += 1; } @@ -606,6 +609,14 @@ impl SQLite3 { } drop(rows); + + // NOTE: We need to call `stmt.column_names()` after `rows.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53) + let columns = stmt + .column_names() + .into_iter() + .map(|v| v.to_owned()) + .collect::>(); + drop(stmt); if let Some(changes) = options.changes { @@ -627,7 +638,7 @@ impl SQLite3 { } tx.commit().or_else(|err| Error::new_query_error(err, query, params))?; - Records::new(col_buf, n_rows, Rc::new(columns)) + Records::new(col_buf.finish(columns.len()), n_rows, Rc::new(columns)) }; // Pack the result into a msgpack @@ -1199,16 +1210,20 @@ JOIN main.pragma_table_info("table_name") p"#, let column_origins = column_origin( unsafe { self.con.handle() }, - // \n is to handle comments, e.g. customQuery = "SELECT ... FROM ... -- comments" + // \n is to handle line comments, e.g. query = "SELECT a FROM b -- comments" &format!("SELECT * FROM ({query}\n) LIMIT 0"), ) .unwrap_or_default(); - let stmt = format!("SELECT * FROM ({query}\n) LIMIT 0"); - let column_names = self + let stmt_str = format!("SELECT * FROM ({query}\n) LIMIT 0"); + let mut stmt = self .con - .prepare(&stmt) - .or_else(|err| Error::new_query_error(err, &stmt, &[]))? + .prepare(&stmt_str) + .or_else(|err| Error::new_query_error(err, &stmt_str, &[]))?; + + // NOTE: We need to call `stmt.column_names()` after `.next()` (see https://github.com/rusqlite/rusqlite/blob/b7309f2dca70716fee44c85082c585b330edb073/src/column.rs#L51-L53) + let _ = stmt.raw_query().next(); + let column_names = stmt .column_names() .into_iter() .map(|v| v.to_owned())