From c413523737c4421e9680dc3d28fdf4407656a9ca Mon Sep 17 00:00:00 2001 From: yy0931 <54441600+yy0931@users.noreply.github.com> Date: Sun, 23 Feb 2025 21:52:14 +0900 Subject: [PATCH] v1.0.204 --- Cargo.lock | 66 +++++++------- Cargo.toml | 2 +- src/main.rs | 55 +++++++---- src/main_test.rs | 231 ++++++++++++++++++++++++++++++++++++++++++++++- src/sqlite3.rs | 33 ++++++- 5 files changed, 331 insertions(+), 56 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 7626cbc..58172f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -117,9 +117,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cc" -version = "1.2.13" +version = "1.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7777341816418c02e033934a09f20dc0ccaf65a5201ef8a450ae0105a573fda" +checksum = "c736e259eea577f443d5c86c304f9f4ae0295c43f3ba05c21f1d66b5f06001af" dependencies = [ "shlex", ] @@ -132,9 +132,9 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.5.28" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e77c3243bd94243c03672cb5154667347c457ca271254724f9f393aee1c05ff" +checksum = "92b7b18d71fad5313a1e320fa9897994228ce274b60faa4d694fe0ea89cd9e6d" dependencies = [ "clap_builder", "clap_derive", @@ -142,9 +142,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.27" +version = "4.5.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b26884eb4b57140e4d2d93652abfa49498b938b3c9179f9fc487b0acc3edad7" +checksum = "a35db2071778a7344791a4fb4f95308b5673d219dee3ae348b86642574ecc90c" dependencies = [ "anstream", "anstyle", @@ -205,9 +205,9 @@ dependencies = [ [[package]] name = "csv-core" -version = "0.1.11" +version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +checksum = "7d02f3b0da4c6504f86e9cd789d8dbafab48c2321be74e9987593de5a894d93d" dependencies = [ "memchr", ] @@ -313,14 +313,14 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.169" +version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" +checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libsqlite3-sys" version = "0.26.0" -source = "git+https://github.com/yy0931/rusqlite.git?branch=master#d7da2970237a5c9eb8861f632b03a66f1be875dc" +source = "git+https://github.com/yy0931/rusqlite.git?branch=master#492a665e5ef22474fd949ed0d55d40aec7621f3d" dependencies = [ "cc", "openssl-sys", @@ -336,9 +336,9 @@ checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "log" -version = "0.4.25" +version = "0.4.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +checksum = "30bde2b3dc3671ae49d8e2e9f044c7c005836e7a023ee57cffa25ab82764bb9e" [[package]] name = "memchr" @@ -348,9 +348,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "miniz_oxide" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8402cab7aefae129c6977bb0ff1b8fd9a04eb5b51efc50a70bea51cda0c7924" +checksum = "8e3e04debbb59698c15bacbb6d93584a8c0ca9cc3213cb423d31f760d8843ce5" dependencies = [ "adler2", ] @@ -372,18 +372,18 @@ checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e" [[package]] name = "openssl-src" -version = "300.4.1+3.4.0" +version = "300.4.2+3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faa4eac4138c62414b5622d1b31c5c304f34b406b013c079c2bbc652fdd6678c" +checksum = "168ce4e058f975fe43e89d9ccf78ca668601887ae736090aacc23ae353c298e2" dependencies = [ "cc", ] [[package]] name = "openssl-sys" -version = "0.9.105" +version = "0.9.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b22d5b84be05a8d6947c7cb71f7c849aa0f112acd4bf51c2a7c1c988ac0a9dc" +checksum = "8bb61ea9811cc39e3c2069f40b8b8e2e70d8569b361f879786cc7ed48b777cdd" dependencies = [ "cc", "libc", @@ -482,7 +482,7 @@ dependencies = [ [[package]] name = "rusqlite" version = "0.29.0" -source = "git+https://github.com/yy0931/rusqlite.git?branch=master#d7da2970237a5c9eb8861f632b03a66f1be875dc" +source = "git+https://github.com/yy0931/rusqlite.git?branch=master#492a665e5ef22474fd949ed0d55d40aec7621f3d" dependencies = [ "bitflags", "fallible-iterator", @@ -524,9 +524,9 @@ checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd" [[package]] name = "serde" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02fc4265df13d6fa1d00ecff087228cc0a2b5f3c0e87e258d8b94a156e984c70" +checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60" dependencies = [ "serde_derive", ] @@ -542,9 +542,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.217" +version = "1.0.218" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5a9bf7cf98d04a2b28aead066b7496853d4779c9cc183c440dbac457641e19a0" +checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b" dependencies = [ "proc-macro2", "quote", @@ -553,9 +553,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.138" +version = "1.0.139" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949" +checksum = "44f86c3acccc9c65b153fe1b85a3be07fe5515274ec9f0653b4a0875731c72a6" dependencies = [ "itoa", "memchr", @@ -571,13 +571,13 @@ checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "smallvec" -version = "1.13.2" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +checksum = "7fcf8323ef1faaee30a44a340193b1ac6814fd9b7b4e88e9d4519a3e4abe1cfd" [[package]] name = "sqlite3-editor" -version = "1.0.203" +version = "1.0.204" dependencies = [ "base64", "clap", @@ -626,9 +626,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.16.0" +version = "3.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91" +checksum = "22e5a0acb1f3f55f65cc4a866c361b2fb2a0ff6366785ae6fbb5f85df07ba230" dependencies = [ "cfg-if", "fastrand", @@ -692,9 +692,9 @@ dependencies = [ [[package]] name = "unicode-ident" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034" +checksum = "00e2473a93778eb0bad35909dff6a10d28e63f792f16ed15e404fca9d5eeedbe" [[package]] name = "utf8parse" diff --git a/Cargo.toml b/Cargo.toml index c160a5b..e84b514 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlite3-editor" -version = "1.0.203" +version = "1.0.204" edition = "2021" [features] diff --git a/src/main.rs b/src/main.rs index a185a83..d09021d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -209,7 +209,7 @@ impl From<(String, i64, i64)> for CompletionQuery { #[derive(ts_rs::TS, Debug, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] #[ts(export, rename_all = "snake_case")] -enum ServerCommand { +pub enum ServerCommand { Interrupt, Close, TryReconnect, @@ -222,10 +222,33 @@ enum ServerCommand { Completion, } +pub trait ReadCommand { + fn read_command(&mut self) -> Option; +} + +impl ReadCommand for T { + fn read_command(&mut self) -> Option { + loop { + let mut command_str = String::new(); + match self.read_line(&mut command_str) { + Err(_) => return None, + Ok(0) => return None, + _ => {} + } + + if let Ok(command) = ServerCommand::deserialize( + serde::de::value::StrDeserializer::::new(command_str.trim()), + ) { + return Some(command); + } + } + } +} + fn cli(args: Args, stdin: F, mut stdout: &mut O, mut stderr: &mut E) -> i32 where - F: Fn() -> I + std::marker::Send + 'static, - I: Read + BufRead, + F: FnOnce() -> I + std::marker::Send + 'static, + I: ReadCommand, O: Write, E: Write, { @@ -308,28 +331,19 @@ where std::thread::spawn(move || { let mut stdin = stdin(); loop { - let mut command_str = String::new(); - match stdin.read_line(&mut command_str) { - Err(_) => return, - Ok(0) => return, - _ => {} - } - - match ServerCommand::deserialize( - serde::de::value::StrDeserializer::::new(command_str.trim()), - ) { - Ok(ServerCommand::Interrupt) => { + match stdin.read_command() { + Some(ServerCommand::Interrupt) => { interrupt_handle.lock().unwrap().interrupt(); } - Ok(ServerCommand::Resume) => { + Some(ServerCommand::Resume) => { resume_command_sender.send(()).unwrap(); } - Ok(command) => { + Some(command) => { if command_sender.send(command).is_err() { return; } } - _ => {} + None => return, } } }) @@ -400,6 +414,10 @@ where ServerCommand::DisconnectTemporarily => { drop(db); + // Send the response to DisconnectTemporarily + write_named(&mut w, &None::<&i64>).expect("Failed to write the result."); + finish(&mut stdout, &mut w, error::ErrorCode::Success); + if resume_command_receiver.recv().is_err() { return 0; } @@ -407,8 +425,9 @@ where match sqlite3::SQLite3::connect(&database_filepath, READ_ONLY, &sql_cipher_key) { Ok(new_db) => { db = new_db; - *interrupt_handle.lock().unwrap() = db.get_interrupt_handle(); + + // Send the response to Resume write_named(&mut w, &None::<&i64>).expect("Failed to write the result."); finish(&mut stdout, &mut w, error::ErrorCode::Success); } diff --git a/src/main_test.rs b/src/main_test.rs index 3a886ad..7b2bc0a 100644 --- a/src/main_test.rs +++ b/src/main_test.rs @@ -1,8 +1,17 @@ -use std::io::Cursor; +use std::{ + io::{Cursor, Write}, + sync::mpsc::{Receiver, Sender}, + thread::JoinHandle, +}; use tempfile::NamedTempFile; -use crate::{cli, Args, ExportFormat, Query}; +use crate::{ + cli, + request_type::{QueryMode, Request}, + sqlite3::QueryOptions, + Args, ExportFormat, Query, ReadCommand, ServerCommand, +}; #[test] fn test_parse_query() { @@ -94,3 +103,221 @@ fn test_export_json() { "[{\"x\":1,\"y\":2},{\"x\":3,\"y\":4}]" ); } + +fn wait_ms(ms: u64) { + std::thread::sleep(std::time::Duration::from_millis(ms)); +} +struct SenderWriter { + sender: std::sync::mpsc::Sender>, + buf: Vec, +} + +impl SenderWriter { + fn new(sender: Sender>) -> Self { + Self { + sender, + buf: Vec::new(), + } + } +} + +impl Write for SenderWriter { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + self.buf.extend(buf); + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.sender.send(std::mem::take(&mut self.buf)).unwrap(); + Ok(()) + } +} + +struct ReceiverReader { + receiver: std::sync::mpsc::Receiver, +} + +impl ReceiverReader { + fn new(receiver: std::sync::mpsc::Receiver) -> Self { + ReceiverReader { receiver } + } +} + +impl ReadCommand for ReceiverReader { + fn read_command(&mut self) -> Option { + self.receiver.recv().ok() + } +} + +struct ServerTestBench { + server_thread: JoinHandle<()>, + #[allow(dead_code)] + database_filepath: NamedTempFile, + request_body_filepath: NamedTempFile, + response_body_filepath: NamedTempFile, + stdin_sender: Sender, + stdout_receiver: Receiver>, + #[allow(dead_code)] + stderr_receiver: Receiver>, +} + +impl ServerTestBench { + fn new() -> Self { + let database_filepath = NamedTempFile::new().unwrap(); + let request_body_filepath = NamedTempFile::new().unwrap(); + let response_body_filepath = NamedTempFile::new().unwrap(); + let (stdin_sender, stdin_receiver) = std::sync::mpsc::channel::(); + let (stdout_sender, stdout_receiver) = std::sync::mpsc::channel::>(); + let mut stdout_writer = SenderWriter::new(stdout_sender); + let (stderr_sender, stderr_receiver) = std::sync::mpsc::channel::>(); + let mut stderr_writer = SenderWriter::new(stderr_sender); + + Self { + server_thread: { + let database_filepath = database_filepath.path().to_str().unwrap().to_owned(); + let request_body_filepath = request_body_filepath.path().to_owned(); + let response_body_filepath = response_body_filepath.path().to_owned(); + std::thread::spawn(move || { + assert_eq!( + cli( + Args { + command: crate::Commands::Server { + database_filepath, + request_body_filepath, + response_body_filepath, + sql_cipher_key: None, + }, + }, + move || { ReceiverReader::new(stdin_receiver) }, + &mut stdout_writer, + &mut stderr_writer, + ), + 0 + ); + }) + }, + database_filepath, + request_body_filepath, + response_body_filepath, + stdin_sender, + stdout_receiver, + stderr_receiver, + } + } + + fn send_stdin(&self, command: ServerCommand) { + self.stdin_sender.send(command).unwrap(); + } + + fn recv_stdout(&self) -> String { + String::from_utf8(self.stdout_receiver.recv().unwrap()).unwrap() + } + + #[allow(dead_code)] + fn recv_stderr(&self) -> String { + String::from_utf8(self.stderr_receiver.recv().unwrap()).unwrap() + } + + fn write_request_body(&self, data: &T) + where + T: serde::Serialize + ?Sized, + { + std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(&self.request_body_filepath) + .unwrap() + .write_all(&rmp_serde::to_vec(&data).unwrap()) + .unwrap(); + } + + fn read_response_body(&self) -> String { + std::fs::read_to_string(&self.response_body_filepath).unwrap() + } +} + +impl Drop for ServerTestBench { + fn drop(&mut self) { + // .send() returns an Error if the server is already shut down. + let _ = self.stdin_sender.send(ServerCommand::Close); + + // sever_thread.join() without taking its ownership + while !self.server_thread.is_finished() { + wait_ms(1); + } + } +} + +#[test] +fn test_server_close() { + let test_bench = ServerTestBench::new(); + wait_ms(100); + assert!(!test_bench.server_thread.is_finished()); + test_bench.send_stdin(ServerCommand::Close); + wait_ms(100); + assert!(test_bench.server_thread.is_finished()); +} + +#[test] +fn test_server_handle() { + let test_bench = ServerTestBench::new(); + test_bench.write_request_body(&Request { + query: "SELECT 1".to_owned(), + params: vec![], + mode: QueryMode::ReadOnly, + options: QueryOptions::default(), + }); + test_bench.send_stdin(ServerCommand::Handle); + assert_eq!(test_bench.recv_stdout(), "Success\n"); +} + +#[test] +#[cfg(not(feature = "sqlcipher"))] // sleep() does not support interruption under feature="sqlcipher" +fn test_server_interrupt() { + let test_bench = ServerTestBench::new(); + + test_bench.write_request_body(&Request { + query: "EDITOR_PRAGMA add_sleep_fn".to_owned(), + params: vec![], + mode: QueryMode::ReadOnly, + options: QueryOptions::default(), + }); + test_bench.send_stdin(ServerCommand::Handle); + assert_eq!(test_bench.recv_stdout(), "Success\n"); + + test_bench.write_request_body(&Request { + query: "SELECT sleep(5000)".to_owned(), + params: vec![], + mode: QueryMode::ReadOnly, + options: QueryOptions::default(), + }); + test_bench.send_stdin(ServerCommand::Handle); + wait_ms(500); + test_bench.send_stdin(ServerCommand::Interrupt); + assert_eq!(test_bench.recv_stdout(), "OtherError\n"); + assert_eq!( + test_bench.read_response_body(), + "interrupted\nQuery: SELECT sleep(5000)\nParameters: []END" + ); +} + +#[test] +fn test_server_disconnect_temporarily() { + let test_bench = ServerTestBench::new(); + + test_bench.send_stdin(ServerCommand::DisconnectTemporarily); + assert_eq!(test_bench.recv_stdout(), "Success\n"); + + test_bench.send_stdin(ServerCommand::Resume); + assert_eq!(test_bench.recv_stdout(), "Success\n"); + + test_bench.write_request_body(&Request { + query: "SELECT 1".to_owned(), + params: vec![], + mode: QueryMode::ReadOnly, + options: QueryOptions::default(), + }); + test_bench.send_stdin(ServerCommand::Handle); + assert_eq!(test_bench.recv_stdout(), "Success\n"); +} diff --git a/src/sqlite3.rs b/src/sqlite3.rs index 8f5bacd..69dbd00 100644 --- a/src/sqlite3.rs +++ b/src/sqlite3.rs @@ -22,6 +22,22 @@ use std::{ time::Duration, }; +struct SendSqliteHandle(*mut rusqlite::ffi::sqlite3); + +impl SendSqliteHandle { + #[cfg(feature = "sqlcipher")] + fn is_interrupted(&self) -> bool { + false // sqlite3_is_interrupted() is not defined under feature="sqlcipher". + } + + #[cfg(not(feature = "sqlcipher"))] + fn is_interrupted(&self) -> bool { + unsafe { rusqlite::ffi::sqlite3_is_interrupted(self.0) != 0 } + } +} + +unsafe impl Send for SendSqliteHandle {} + #[derive(Debug, Clone)] struct StringError(String); @@ -1373,9 +1389,22 @@ JOIN main.pragma_table_info("table_name") p"#, write_editor_pragma(w, (self.load_extensions(&extensions)?, vec![]), start_time) } "EDITOR_PRAGMA add_sleep_fn" => { + let handle = SendSqliteHandle(unsafe { self.con.handle() }); self.con - .create_scalar_function("sleep", 1, FunctionFlags::SQLITE_UTF8, |ms| { - std::thread::sleep(Duration::from_millis(ms.get(0)?)); + .create_scalar_function("sleep", 1, FunctionFlags::SQLITE_UTF8, move |ms| { + let total_duration = Duration::from_millis(ms.get::(0)? as u64); + let start = std::time::Instant::now(); + while start.elapsed() < total_duration { + std::thread::sleep(Duration::from_millis(1)); + + // Support sqlite3_interrupt() + if handle.is_interrupted() { + return Err(rusqlite::Error::SqliteFailure( + rusqlite::ffi::Error::new(rusqlite::ffi::SQLITE_INTERRUPT), + Some("interrupted".to_owned()), + )); + } + } Ok(0) }) .unwrap();