Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new Connection API to install SQLite update hooks. #1839

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions libsql/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ use crate::statement::Statement;
use crate::transaction::Transaction;
use crate::{Result, TransactionBehavior};

#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Op {
Insert = 0,
Delete = 1,
Update = 2,
}

#[async_trait::async_trait]
pub(crate) trait Conn {
async fn execute(&self, sql: &str, params: Params) -> Result<u64>;
Expand Down Expand Up @@ -38,6 +45,10 @@ pub(crate) trait Conn {
fn load_extension(&self, _dylib_path: &Path, _entry_point: Option<&str>) -> Result<()> {
Err(crate::Error::LoadExtensionNotSupported)
}

fn add_update_hook(&self, _cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>) -> Result<()> {
Err(crate::Error::UpdateHookNotSupported)
}
}

/// A set of rows returned from `execute_batch`/`execute_transactional_batch`. It is essentially
Expand Down Expand Up @@ -244,6 +255,13 @@ impl Connection {
) -> Result<()> {
self.conn.load_extension(dylib_path.as_ref(), entry_point)
}

pub fn add_update_hook(
&self,
cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>,
) -> Result<()> {
self.conn.add_update_hook(cb)
}
}

impl fmt::Debug for Connection {
Expand Down
2 changes: 2 additions & 0 deletions libsql/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ pub enum Error {
SyncNotSupported(String), // Not in rusqlite
#[error("Loading extension is only supported in local databases.")]
LoadExtensionNotSupported, // Not in rusqlite
#[error("Update hooks are only supported in local databases.")]
UpdateHookNotSupported, // Not in rusqlite
#[error("Column not found: {0}")]
ColumnNotFound(i32), // Not in rusqlite
#[error("Hrana: `{0}`")]
Expand Down
2 changes: 1 addition & 1 deletion libsql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ cfg_hrana! {
}

pub use self::{
connection::{BatchRows, Connection},
connection::{BatchRows, Connection, Op},
database::{Builder, Database},
load_extension_guard::LoadExtensionGuard,
rows::{Column, Row, Rows},
Expand Down
49 changes: 48 additions & 1 deletion libsql/src/local/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@

use crate::local::rows::BatchedRows;
use crate::params::Params;
use crate::{connection::BatchRows, errors};
use crate::{
connection::{BatchRows, Op},
errors,
};

use super::{Database, Error, Result, Rows, RowsFuture, Statement, Transaction};

Expand All @@ -11,6 +14,10 @@ use crate::TransactionBehavior;
use libsql_sys::ffi;
use std::{ffi::c_int, fmt, path::Path, sync::Arc};

struct Container {
cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>,
}

/// A connection to a libSQL database.
#[derive(Clone)]
pub struct Connection {
Expand Down Expand Up @@ -384,6 +391,24 @@ impl Connection {
})
}

/// Installs update hook
pub fn add_update_hook(&self, cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>) {
let c = Box::new(Container { cb });
let ptr: *mut Container = std::ptr::from_mut(Box::leak(c));

let old_data = unsafe {
ffi::sqlite3_update_hook(
self.raw,
Some(update_hook_cb),
ptr as *mut ::std::os::raw::c_void,
)
};

if !old_data.is_null() {
let _ = unsafe { Box::from_raw(old_data as *mut Container) };
}
}

pub fn enable_load_extension(&self, onoff: bool) -> Result<()> {
// SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION configration verb accepts 2 additional parameters: an on/off flag and a pointer to an c_int where new state of the parameter will be written (or NULL if reporting back the setting is not needed)
// See: https://sqlite.org/c3ref/c_dbconfig_defensive.html#sqlitedbconfigenableloadextension
Expand Down Expand Up @@ -489,3 +514,25 @@ impl fmt::Debug for Connection {
f.debug_struct("Connection").finish()
}
}

#[no_mangle]
extern "C" fn update_hook_cb(
data: *mut ::std::os::raw::c_void,
op: ::std::os::raw::c_int,
db_name: *const ::std::os::raw::c_char,
table_name: *const ::std::os::raw::c_char,
row_id: i64,
) {
let db = unsafe { std::ffi::CStr::from_ptr(db_name).to_string_lossy() };
let table = unsafe { std::ffi::CStr::from_ptr(table_name).to_string_lossy() };

let c = unsafe { &mut *(data as *mut Container) };
let o = match op {
9 => Op::Delete,
18 => Op::Insert,
23 => Op::Update,
_ => unreachable!("Unknown operation {op}"),
};

(*c.cb)(o, &db, &table, row_id);
}
7 changes: 5 additions & 2 deletions libsql/src/local/impls.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
use std::sync::Arc;
use std::{fmt, path::Path};

use crate::connection::BatchRows;
use crate::connection::{Conn, BatchRows, Op};
use crate::{
connection::Conn,
params::Params,
rows::{ColumnsInner, RowInner, RowsInner},
statement::Stmt,
Expand Down Expand Up @@ -79,6 +78,10 @@ impl Conn for LibsqlConnection {
fn load_extension(&self, dylib_path: &Path, entry_point: Option<&str>) -> Result<()> {
self.conn.load_extension(dylib_path, entry_point)
}

fn add_update_hook(&self, cb: Box<dyn Fn(Op, &str, &str, i64) + Send + Sync>) -> Result<()> {
Ok(self.conn.add_update_hook(cb))
}
}

impl Drop for LibsqlConnection {
Expand Down
74 changes: 73 additions & 1 deletion libsql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ use futures::{StreamExt, TryStreamExt};
use libsql::{
named_params, params,
params::{IntoParams, IntoValue},
Connection, Database, Value,
Connection, Database, Op, Value,
};
use rand::distributions::Uniform;
use rand::prelude::*;
use std::collections::HashSet;
use std::sync::{Arc, Mutex};

async fn setup() -> Connection {
let db = Database::open(":memory:").unwrap();
Expand All @@ -27,6 +28,77 @@ async fn enable_disable_extension() {
conn.load_extension_disable().unwrap();
}

#[tokio::test]
async fn add_update_hook() {
let conn = setup().await;

#[derive(PartialEq, Debug)]
struct Data {
op: Op,
db: String,
table: String,
row_id: i64,
}

let d = Arc::new(Mutex::new(None::<Data>));

let d_clone = d.clone();
conn.add_update_hook(Box::new(move |op, db, table, row_id| {
*d_clone.lock().unwrap() = Some(Data {
op,
db: db.to_string(),
table: table.to_string(),
row_id,
});
}))
.unwrap();

let _ = conn
.execute("INSERT INTO users (id, name) VALUES (2, 'Alice')", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Insert,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);

let _ = conn
.execute("UPDATE users SET name = 'Bob' WHERE id = 2", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Update,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);

let _ = conn
.execute("DELETE FROM users WHERE id = 2", ())
.await
.unwrap();

assert_eq!(
*d.lock().unwrap().as_ref().unwrap(),
Data {
op: Op::Delete,
db: "main".to_string(),
table: "users".to_string(),
row_id: 1,
}
);
}

#[tokio::test]
async fn connection_drops_before_statements() {
let db = Database::open(":memory:").unwrap();
Expand Down