Skip to content

Commit

Permalink
feat: simplified the batching, added faster set many implementation a…
Browse files Browse the repository at this point in the history
…nd tests
  • Loading branch information
maaasyn committed Oct 30, 2024
1 parent 5c0fe03 commit 44012a2
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 30 deletions.
55 changes: 25 additions & 30 deletions src/store/stores/sqlite.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use async_trait::async_trait;
use sqlx::Error;
use sqlx::{sqlite::SqliteConnectOptions, Pool, Row, Sqlite, SqlitePool};
use std::cmp::min;
use std::collections::HashMap;
use tokio::sync::Mutex;

Expand Down Expand Up @@ -71,7 +70,6 @@ impl Store for SQLiteStore {
.fetch_optional(&*pool)
.await?;

// Extract the value from the row, if it exists
if let Some(row) = row {
let value: String = row.try_get("value")?;
Ok(Some(value))
Expand All @@ -84,22 +82,16 @@ impl Store for SQLiteStore {
let pool = self.db.lock().await;
let mut map = HashMap::new();

let total_keys = keys.len();
let mut offset = 0;

while offset < total_keys {
let end = min(offset + MAX_VARIABLE_NUMBER, total_keys);
let key_slice = &keys[offset..end];

let placeholders = key_slice.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
for key_chunk in keys.chunks(MAX_VARIABLE_NUMBER) {
let placeholders = key_chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
let query_statement = format!(
"SELECT key, value FROM store WHERE key IN ({})",
placeholders
);

let mut query = sqlx::query(&query_statement);

for key in key_slice {
for key in key_chunk {
query = query.bind(*key);
}

Expand All @@ -109,8 +101,6 @@ impl Store for SQLiteStore {
let value: String = row.get("value");
map.insert(key, value);
}

offset = end;
}

Ok(map)
Expand All @@ -131,12 +121,25 @@ impl Store for SQLiteStore {
let pool = self.db.lock().await;
let mut transaction = pool.begin().await?;

for (key, value) in entries.iter() {
sqlx::query("INSERT OR REPLACE INTO store (key, value) VALUES (?, ?)")
.bind(key)
.bind(value)
.execute(&mut *transaction)
.await?;
for entry_chunk in entries
.iter()
.collect::<Vec<_>>()
.chunks(MAX_VARIABLE_NUMBER)
{
let mut query = String::from("INSERT OR REPLACE INTO store (key, value) VALUES ");
let placeholders = entry_chunk
.iter()
.map(|_| "(?, ?)")
.collect::<Vec<_>>()
.join(", ");
query.push_str(&placeholders);

let mut sqlx_query = sqlx::query(&query);
for (key, value) in entry_chunk {
sqlx_query = sqlx_query.bind(key).bind(value);
}

sqlx_query.execute(&mut *transaction).await?;
}

transaction.commit().await?;
Expand All @@ -156,25 +159,17 @@ impl Store for SQLiteStore {
async fn delete_many(&self, keys: Vec<&str>) -> Result<(), StoreError> {
let pool = self.db.lock().await;

let total_keys = keys.len();
let mut offset = 0;

while offset < total_keys {
let end = min(offset + MAX_VARIABLE_NUMBER, total_keys);
let key_slice = &keys[offset..end];

let placeholders = key_slice.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
for key_chunk in keys.chunks(MAX_VARIABLE_NUMBER) {
let placeholders = key_chunk.iter().map(|_| "?").collect::<Vec<_>>().join(", ");
let query_statement = format!("DELETE FROM store WHERE key IN ({})", placeholders);

let mut query = sqlx::query(&query_statement);

for key in key_slice {
for key in key_chunk {
query = query.bind(*key);
}

query.execute(&*pool).await?;

offset = end;
}

Ok(())
Expand Down
22 changes: 22 additions & 0 deletions tests/store/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,25 @@ async fn test_get_some_in_store_table() {
.unwrap();
assert_eq!(value.unwrap(), "value1".to_string());
}

#[tokio::test]
async fn test_batch_insertion_and_retrieval() {
let store = InMemoryStore::default();
let store = Arc::new(store);

let mut entries = HashMap::new();
for i in 0..10_000 {
entries.insert(format!("key{}", i), format!("value{}", i));
}

store.set_many(entries.clone()).await.unwrap();

let keys: Vec<_> = entries.keys().map(|k| k.as_str()).collect();
let values = store.get_many(keys).await.unwrap();

for i in 0..10_000 {
let key = format!("key{}", i);
let value = format!("value{}", i);
assert_eq!(values.get(&key), Some(&value));
}
}
25 changes: 25 additions & 0 deletions tests/store/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,28 @@ async fn test_get_some_in_store_table() {
.unwrap();
assert_eq!(value.unwrap(), "value1".to_string());
}

#[tokio::test]
async fn test_batch_insertion_and_retrieval() {
let store = SQLiteStore::new(":memory:", None, Some("test"))
.await
.unwrap();

let store = Arc::new(store);

let mut entries = HashMap::new();
for i in 0..10_000 {
entries.insert(format!("key{}", i), format!("value{}", i));
}

store.set_many(entries.clone()).await.unwrap();

let keys: Vec<_> = entries.keys().map(|k| k.as_str()).collect();
let values = store.get_many(keys).await.unwrap();

for i in 0..10_000 {
let key = format!("key{}", i);
let value = format!("value{}", i);
assert_eq!(values.get(&key), Some(&value));
}
}

0 comments on commit 44012a2

Please sign in to comment.