From 7aab46783c49a195f6c59f958f2fe202f7b3f15d Mon Sep 17 00:00:00 2001 From: Ferdinand Prantl Date: Wed, 27 Dec 2023 19:43:36 +0100 Subject: [PATCH 1/3] Support filtering with metadata (#1) Add a simple metadata filter based on equality of key-value pairs. * An undefined filter or an empty array matches all embeddings. * An empty object in the array matches any embedding. * If an object in the array is not empty, it will match an embedding, if all its key-value pairs are found in the embedding's metadata. * If there are more than one objects in the array, the whole filter will match an ambedding, if at least one object matches the embedding's metadata. --- src/db.rs | 48 ++++++++++++++++++++++++++++++++++++---- src/routes/collection.rs | 9 ++++++-- 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/src/db.rs b/src/db.rs index 423ece3..2aa42bd 100644 --- a/src/db.rs +++ b/src/db.rs @@ -55,7 +55,7 @@ pub struct Collection { } impl Collection { - pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec { + pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap], query: &[f32], k: usize) -> Vec { let memo_attr = get_cache_attr(self.distance, query); let distance_fn = get_distance_fn(self.distance); @@ -63,9 +63,13 @@ impl Collection { .embeddings .par_iter() .enumerate() - .map(|(index, embedding)| { - let score = distance_fn(&embedding.vector, query, memo_attr); - ScoreIndex { score, index } + .filter_map(|(index, embedding)| { + if match_embedding(embedding, filter) { + let score = distance_fn(&embedding.vector, query, memo_attr); + Some(ScoreIndex { score, index }) + } else { + None + } }) .collect::>(); @@ -90,6 +94,42 @@ impl Collection { } } +fn match_embedding(embedding: &Embedding, filter: &[HashMap]) -> bool { + // an empty filter matches any embedding + if filter.len() == 0 { + return true + } + + match &embedding.metadata { + // no metadata in an embedding cannot be matched by a not empty filter + None => false, + Some(metadata) => { + // enumerate criteria with OR semantics; look for the first one matching + for criteria in filter { + let mut matches = true; + // enumerate entries with AND semantics; look for the first one failing + for (key, expected) in criteria { + let found = match metadata.get(key) { + None => false, + Some(actual) => actual == expected + }; + // a not matching entry means the whole embedding not matching + if !found { + matches = false; + break + } + } + // all entries matching mean the whole embedding matching + if matches { + return true + } + } + // no match found + false + } + } +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)] pub struct Embedding { pub id: String, diff --git a/src/routes/collection.rs b/src/routes/collection.rs index c0f8f47..8aa7633 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -5,7 +5,10 @@ use aide::axum::{ use axum::{extract::Path, http::StatusCode, Extension}; use axum_jsonschema::Json; use schemars::JsonSchema; -use std::time::Instant; +use std::{ + collections::HashMap, + time::Instant, +}; use crate::{ db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult}, @@ -54,6 +57,8 @@ async fn create_collection( struct QueryCollectionQuery { /// Vector to query with query: Vec, + /// Metadata to filter with + filter: Option>>, /// Number of results to return k: Option, } @@ -77,7 +82,7 @@ async fn query_collection( } let instant = Instant::now(); - let results = collection.get_similarity(&req.query, req.k.unwrap_or(1)); + let results = collection.get_by_metadata_and_similarity(&req.filter.unwrap_or_default(), &req.query, req.k.unwrap_or(1)); drop(db); tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed()); From bf03142b6d339e63b03e4a84e896cc2fcd27029f Mon Sep 17 00:00:00 2001 From: Ferdinand Prantl Date: Wed, 27 Dec 2023 19:44:58 +0100 Subject: [PATCH 2/3] Add endpoints for complete management of embeddings (#2) * GET /collections - list collection names * GET `/collections/:collection_name/embeddings` - get embedding identifiers * POST /collections/:collection_name/embeddings - filter embeddings with metadata * DELETE /collections/:collection_name/embeddings - delete embeddings by metadata * GET /collections/:collection_name/embeddings/:embedding_id - get an embedding * DELETE /collections/:collection_name/embeddings/:embedding_id - delete an embedding --- src/db.rs | 79 ++++++++++++++++++++++- src/routes/collection.rs | 131 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 208 insertions(+), 2 deletions(-) diff --git a/src/db.rs b/src/db.rs index 2aa42bd..fc22b83 100644 --- a/src/db.rs +++ b/src/db.rs @@ -55,6 +55,36 @@ pub struct Collection { } impl Collection { + pub fn list(&self) -> Vec { + self + .embeddings + .iter() + .map(|e| e.id.to_owned()) + .collect() + } + + pub fn get(&self, id: &str) -> Option<&Embedding> { + self + .embeddings + .iter() + .find(|e| e.id == id) + } + + pub fn get_by_metadata(&self, filter: &[HashMap], k: usize) -> Vec { + self + .embeddings + .iter() + .filter_map(|embedding| { + if match_embedding(embedding, filter) { + Some(embedding.clone()) + } else { + None + } + }) + .take(k) + .collect() + } + pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap], query: &[f32], k: usize) -> Vec { let memo_attr = get_cache_attr(self.distance, query); let distance_fn = get_distance_fn(self.distance); @@ -92,6 +122,41 @@ impl Collection { }) .collect() } + + pub fn delete(&mut self, id: &str) -> bool { + let index_opt = self.embeddings + .iter() + .position(|e| e.id == id); + + match index_opt { + None => false, + Some(index) => { self.embeddings.remove(index); true } + } + } + + pub fn delete_by_metadata(&mut self, filter: &[HashMap]) { + if filter.len() == 0 { + self.embeddings.clear(); + return + } + + let indexes = self + .embeddings + .par_iter() + .enumerate() + .filter_map(|(index, embedding)| { + if match_embedding(embedding, filter) { + Some(index) + } else { + None + } + }) + .collect::>(); + + for index in indexes { + self.embeddings.remove(index); + } + } } fn match_embedding(embedding: &Embedding, filter: &[HashMap]) -> bool { @@ -104,7 +169,7 @@ fn match_embedding(embedding: &Embedding, filter: &[HashMap]) -> // no metadata in an embedding cannot be matched by a not empty filter None => false, Some(metadata) => { - // enumerate criteria with OR semantics; look for the first one matching + // enumerate criteria with OR semantics; look for the first one matching for criteria in filter { let mut matches = true; // enumerate entries with AND semantics; look for the first one failing @@ -211,6 +276,18 @@ impl Db { self.collections.get(name) } + pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> { + self.collections.get_mut(name) + } + + pub fn list(&self) -> Vec { + self + .collections + .keys() + .map(|name| name.to_owned()) + .collect() + } + fn load_from_store() -> anyhow::Result { if !STORE_PATH.exists() { tracing::debug!("Creating database store"); diff --git a/src/routes/collection.rs b/src/routes/collection.rs index 8aa7633..e4c1f15 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -20,14 +20,33 @@ pub fn handler() -> ApiRouter { ApiRouter::new().nest( "/collections", ApiRouter::new() + .api_route("/", get(get_collections)) .api_route("/:collection_name", put(create_collection)) .api_route("/:collection_name", post(query_collection)) .api_route("/:collection_name", get(get_collection_info)) .api_route("/:collection_name", delete(delete_collection)) - .api_route("/:collection_name/insert", post(insert_into_collection)), + .api_route("/:collection_name/insert", post(insert_into_collection)) + .api_route("/:collection_name/embeddings", get(get_embeddings)) + .api_route("/:collection_name/embeddings", post(query_embeddings)) + .api_route("/:collection_name/embeddings", delete(delete_embeddings)) + .api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding)) + .api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)), ) } +/// Get collection names +async fn get_collections( + Extension(db): DbExtension, +) -> Result>, HTTPError> { + tracing::trace!("Getting collection names"); + + let db = db.read().await; + + let results = db.list(); + + Ok(Json(results)) +} + /// Create a new collection async fn create_collection( Path(collection_name): Path, @@ -170,3 +189,113 @@ async fn insert_into_collection( .with_status(StatusCode::BAD_REQUEST)), } } + +/// Query embeddings in a collection +async fn get_embeddings( + Path(collection_name): Path, + Extension(db): DbExtension, +) -> Result>, HTTPError> { + tracing::trace!("Querying embeddings from collection {collection_name}"); + + let db = db.read().await; + let collection = db + .get_collection(&collection_name) + .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?; + + let results = collection.list(); + drop(db); + + Ok(Json(results)) +} + +#[derive(Debug, serde::Deserialize, JsonSchema)] +struct EmbeddingsQuery { + /// Metadata to filter with + filter: Vec>, + /// Number of results to return + k: Option, +} + +/// Query embeddings in a collection +async fn query_embeddings( + Path(collection_name): Path, + Extension(db): DbExtension, + Json(req): Json, +) -> Result>, HTTPError> { + tracing::trace!("Querying embeddings from collection {collection_name}"); + + let db = db.read().await; + let collection = db + .get_collection(&collection_name) + .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?; + + let instant = Instant::now(); + let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1)); + drop(db); + + tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed()); + Ok(Json(results)) +} + +/// Delete embeddings in a collection +async fn delete_embeddings( + Path(collection_name): Path, + Extension(db): DbExtension, + Json(req): Json, +) -> Result { + tracing::trace!("Querying embeddings from collection {collection_name}"); + + let mut db = db.write().await; + let collection = db + .get_collection_mut(&collection_name) + .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?; + + collection.delete_by_metadata(&req.filter); + drop(db); + + Ok(StatusCode::NO_CONTENT) +} + +/// Get an embedding from a collection +async fn get_embedding( + Path((collection_name, embedding_id)): Path<(String, String)>, + Extension(db): DbExtension, +) -> Result, HTTPError> { + tracing::trace!("Getting {embedding_id} from collection {collection_name}"); + + if embedding_id.len() == 0 { + return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST)); + } + + let db = db.read().await; + let collection = db + .get_collection(&collection_name) + .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?; + + let embedding = collection + .get(&embedding_id) + .ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?; + + Ok(Json(embedding.to_owned())) +} + +/// Delete an embedding from a collection +async fn delete_embedding( + Path((collection_name, embedding_id)): Path<(String, String)>, + Extension(db): DbExtension, +) -> Result { + tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}"); + + let mut db = db.write().await; + let collection = db + .get_collection_mut(&collection_name) + .ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?; + + let delete_result = collection.delete(&embedding_id); + drop(db); + + match delete_result { + true => Ok(StatusCode::NO_CONTENT), + false => Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND)), + } +} From 141d33ea75429b74b185310ae33ac633eacf8788 Mon Sep 17 00:00:00 2001 From: Ferdinand Prantl Date: Wed, 27 Dec 2023 19:38:58 +0100 Subject: [PATCH 3/3] Create a new embedding by PUT Follow the unusual patter from PUT /collections/:collection_name: objects are created by PUT /place/ Instead of POST /collections/:collection_name/insert, use PUT /collections/:collection_name/embeddings/:embedding_id. --- src/db.rs | 6 ++++++ src/routes/collection.rs | 23 ++++++++++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/src/db.rs b/src/db.rs index fc22b83..a819f4d 100644 --- a/src/db.rs +++ b/src/db.rs @@ -34,12 +34,15 @@ pub enum Error { #[derive(Debug, serde::Serialize, serde::Deserialize)] pub struct Db { + /// Collections in the database pub collections: HashMap, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)] pub struct SimilarityResult { + /// Similarity score score: f32, + /// Matching embedding embedding: Embedding, } @@ -197,8 +200,11 @@ fn match_embedding(embedding: &Embedding, filter: &[HashMap]) -> #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)] pub struct Embedding { + /// Unique identifier pub id: String, + /// Vector computed from a text chunk pub vector: Vec, + /// Metadata about the source text pub metadata: Option>, } diff --git a/src/routes/collection.rs b/src/routes/collection.rs index e4c1f15..888b604 100644 --- a/src/routes/collection.rs +++ b/src/routes/collection.rs @@ -25,10 +25,10 @@ pub fn handler() -> ApiRouter { .api_route("/:collection_name", post(query_collection)) .api_route("/:collection_name", get(get_collection_info)) .api_route("/:collection_name", delete(delete_collection)) - .api_route("/:collection_name/insert", post(insert_into_collection)) .api_route("/:collection_name/embeddings", get(get_embeddings)) .api_route("/:collection_name/embeddings", post(query_embeddings)) .api_route("/:collection_name/embeddings", delete(delete_embeddings)) + .api_route("/:collection_name/embeddings/:embedding_id", put(insert_into_collection)) .api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding)) .api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)), ) @@ -162,16 +162,29 @@ async fn delete_collection( } } +#[derive(Debug, serde::Deserialize, JsonSchema)] +struct EmbeddingData { + /// Vector computed from a text chunk + vector: Vec, + /// Metadata about the source text + metadata: Option>, +} + /// Insert a vector into a collection async fn insert_into_collection( - Path(collection_name): Path, + Path((collection_name, embedding_id)): Path<(String, String)>, Extension(db): DbExtension, - Json(embedding): Json, + Json(embedding_data): Json, ) -> Result { tracing::trace!("Inserting into collection {collection_name}"); let mut db = db.write().await; + let embedding = Embedding { + id: embedding_id, + vector: embedding_data.vector, + metadata: embedding_data.metadata, + }; let insert_result = db.insert_into_collection(&collection_name, embedding); drop(db); @@ -263,10 +276,6 @@ async fn get_embedding( ) -> Result, HTTPError> { tracing::trace!("Getting {embedding_id} from collection {collection_name}"); - if embedding_id.len() == 0 { - return Err(HTTPError::new("Embedding identifier empty").with_status(StatusCode::BAD_REQUEST)); - } - let db = db.read().await; let collection = db .get_collection(&collection_name)