diff --git a/Cargo.lock b/Cargo.lock index de11a62..7626cbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -577,7 +577,7 @@ checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "sqlite3-editor" -version = "1.0.202" +version = "1.0.203" dependencies = [ "base64", "clap", diff --git a/Cargo.toml b/Cargo.toml index 1faa399..c160a5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "sqlite3-editor" -version = "1.0.202" +version = "1.0.203" edition = "2021" [features] diff --git a/src/main.rs b/src/main.rs index cc2c30b..a185a83 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,6 @@ use std::{ fs::File, io::{BufRead, Read, Seek, SeekFrom, Write}, path::PathBuf, - str::FromStr, sync::{Arc, Mutex}, }; mod columnar_buffer; @@ -80,27 +79,36 @@ struct Args { command: Commands, } -#[derive(clap::ValueEnum, Clone, Debug, PartialEq, Eq)] +#[derive(ts_rs::TS, clap::ValueEnum, Clone, Debug, PartialEq, Eq)] +#[ts(export)] pub enum ImportFormat { #[clap(name = "csv")] + #[ts(rename = "csv")] CSV, #[clap(name = "tsv")] + #[ts(rename = "tsv")] TSV, #[clap(name = "json")] + #[ts(rename = "json")] JSON, } -#[derive(clap::ValueEnum, Clone, Debug, PartialEq, Eq)] +#[derive(ts_rs::TS, clap::ValueEnum, Clone, Debug, PartialEq, Eq)] +#[ts(export)] pub enum ExportFormat { #[clap(name = "csv")] + #[ts(rename = "csv")] CSV, #[clap(name = "json")] + #[ts(rename = "json")] JSON, #[clap(name = "xlsx")] + #[ts(rename = "xlsx")] XLSX, } -#[derive(Subcommand)] +#[derive(ts_rs::TS, Subcommand)] +#[ts(export, rename_all = "kebab-case")] enum Commands { Version {}, FunctionList {}, @@ -158,6 +166,13 @@ enum Commands { #[arg(long)] sql_cipher_key: Option, }, + CopyFile { + #[arg(long, required = true)] + src: PathBuf, + + #[arg(long, required = true)] + dst: PathBuf, + }, } /// Structure representing a database query @@ -191,11 +206,15 @@ impl From<(String, i64, i64)> for CompletionQuery { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(ts_rs::TS, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[ts(export, rename_all = "snake_case")] enum ServerCommand { Interrupt, Close, TryReconnect, + DisconnectTemporarily, + Resume, Handle, SemanticHighlight, CodeLens, @@ -203,23 +222,6 @@ enum ServerCommand { Completion, } -impl FromStr for ServerCommand { - type Err = (); - fn from_str(s: &str) -> Result { - Ok(match s.trim() { - "interrupt" => Self::Interrupt, - "close" => Self::Close, - "try_reconnect" => Self::TryReconnect, - "handle" => Self::Handle, - "semantic_highlight" => Self::SemanticHighlight, - "code_lens" => Self::CodeLens, - "check_syntax" => Self::CheckSyntax, - "completion" => Self::Completion, - _ => return Err(()), - }) - } -} - fn cli(args: Args, stdin: F, mut stdout: &mut O, mut stderr: &mut E) -> i32 where F: Fn() -> I + std::marker::Send + 'static, @@ -299,6 +301,7 @@ where }; let (command_sender, command_receiver) = std::sync::mpsc::channel::(); + let (resume_command_sender, resume_command_receiver) = std::sync::mpsc::channel::<()>(); let interrupt_handle = Arc::new(Mutex::new(db.get_interrupt_handle())); let _thread = { let interrupt_handle = Arc::clone(&interrupt_handle); @@ -311,10 +314,16 @@ where Ok(0) => return, _ => {} } - match command_str.parse::() { + + match ServerCommand::deserialize( + serde::de::value::StrDeserializer::::new(command_str.trim()), + ) { Ok(ServerCommand::Interrupt) => { interrupt_handle.lock().unwrap().interrupt(); } + Ok(ServerCommand::Resume) => { + resume_command_sender.send(()).unwrap(); + } Ok(command) => { if command_sender.send(command).is_err() { return; @@ -388,6 +397,30 @@ where } } + ServerCommand::DisconnectTemporarily => { + drop(db); + + if resume_command_receiver.recv().is_err() { + return 0; + } + + match sqlite3::SQLite3::connect(&database_filepath, READ_ONLY, &sql_cipher_key) { + Ok(new_db) => { + db = new_db; + + *interrupt_handle.lock().unwrap() = db.get_interrupt_handle(); + write_named(&mut w, &None::<&i64>).expect("Failed to write the result."); + finish(&mut stdout, &mut w, error::ErrorCode::Success); + } + Err(err) => { + w.truncate_all(); + write!(w, "{err}").expect("Failed to write an error message."); + finish(&mut stdout, &mut w, err.code()); + return 1; + } + } + } + // Handle the request ServerCommand::Handle => { // Deserialize the request @@ -569,6 +602,22 @@ where return 1; } } + Commands::CopyFile { src, dst } => { + // Read + let data = match std::fs::read(&src) { + Err(err) => { + writeln!(&mut stderr, "Failed to read {src:?}: {err}").unwrap(); + return 1; + } + Ok(data) => data, + }; + + // Write + if let Err(err) = std::fs::write(&dst, data) { + writeln!(&mut stderr, "Failed to write {src:?}: {err}").unwrap(); + return 1; + } + } } 0