Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a new embedding by PUT #8

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 127 additions & 4 deletions src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,15 @@ pub enum Error {

#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Db {
/// Collections in the database
pub collections: HashMap<String, Collection>,
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
pub struct SimilarityResult {
/// Similarity score
score: f32,
/// Matching embedding
embedding: Embedding,
}

Expand All @@ -55,17 +58,51 @@ pub struct Collection {
}

impl Collection {
pub fn get_similarity(&self, query: &[f32], k: usize) -> Vec<SimilarityResult> {
pub fn list(&self) -> Vec<String> {
self
.embeddings
.iter()
.map(|e| e.id.to_owned())
.collect()
}

pub fn get(&self, id: &str) -> Option<&Embedding> {
self
.embeddings
.iter()
.find(|e| e.id == id)
}

pub fn get_by_metadata(&self, filter: &[HashMap<String, String>], k: usize) -> Vec<Embedding> {
self
.embeddings
.iter()
.filter_map(|embedding| {
if match_embedding(embedding, filter) {
Some(embedding.clone())
} else {
None
}
})
.take(k)
.collect()
}

pub fn get_by_metadata_and_similarity(&self, filter: &[HashMap<String, String>], query: &[f32], k: usize) -> Vec<SimilarityResult> {
let memo_attr = get_cache_attr(self.distance, query);
let distance_fn = get_distance_fn(self.distance);

let scores = self
.embeddings
.par_iter()
.enumerate()
.map(|(index, embedding)| {
let score = distance_fn(&embedding.vector, query, memo_attr);
ScoreIndex { score, index }
.filter_map(|(index, embedding)| {
if match_embedding(embedding, filter) {
let score = distance_fn(&embedding.vector, query, memo_attr);
Some(ScoreIndex { score, index })
} else {
None
}
})
.collect::<Vec<_>>();

Expand All @@ -88,12 +125,86 @@ impl Collection {
})
.collect()
}

pub fn delete(&mut self, id: &str) -> bool {
let index_opt = self.embeddings
.iter()
.position(|e| e.id == id);

match index_opt {
None => false,
Some(index) => { self.embeddings.remove(index); true }
}
}

pub fn delete_by_metadata(&mut self, filter: &[HashMap<String, String>]) {
if filter.len() == 0 {
self.embeddings.clear();
return
}

let indexes = self
.embeddings
.par_iter()
.enumerate()
.filter_map(|(index, embedding)| {
if match_embedding(embedding, filter) {
Some(index)
} else {
None
}
})
.collect::<Vec<_>>();

for index in indexes {
self.embeddings.remove(index);
}
}
}

fn match_embedding(embedding: &Embedding, filter: &[HashMap<String, String>]) -> bool {
// an empty filter matches any embedding
if filter.len() == 0 {
return true
}

match &embedding.metadata {
// no metadata in an embedding cannot be matched by a not empty filter
None => false,
Some(metadata) => {
// enumerate criteria with OR semantics; look for the first one matching
for criteria in filter {
let mut matches = true;
// enumerate entries with AND semantics; look for the first one failing
for (key, expected) in criteria {
let found = match metadata.get(key) {
None => false,
Some(actual) => actual == expected
};
// a not matching entry means the whole embedding not matching
if !found {
matches = false;
break
}
}
// all entries matching mean the whole embedding matching
if matches {
return true
}
}
// no match found
false
}
}
}

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, JsonSchema)]
pub struct Embedding {
/// Unique identifier
pub id: String,
/// Vector computed from a text chunk
pub vector: Vec<f32>,
/// Metadata about the source text
pub metadata: Option<HashMap<String, String>>,
}

Expand Down Expand Up @@ -171,6 +282,18 @@ impl Db {
self.collections.get(name)
}

pub fn get_collection_mut(&mut self, name: &str) -> Option<&mut Collection> {
self.collections.get_mut(name)
}

pub fn list(&self) -> Vec<String> {
self
.collections
.keys()
.map(|name| name.to_owned())
.collect()
}

fn load_from_store() -> anyhow::Result<Self> {
if !STORE_PATH.exists() {
tracing::debug!("Creating database store");
Expand Down
153 changes: 148 additions & 5 deletions src/routes/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use aide::axum::{
use axum::{extract::Path, http::StatusCode, Extension};
use axum_jsonschema::Json;
use schemars::JsonSchema;
use std::time::Instant;
use std::{
collections::HashMap,
time::Instant,
};

use crate::{
db::{self, Collection, DbExtension, Embedding, Error as DbError, SimilarityResult},
Expand All @@ -17,14 +20,33 @@ pub fn handler() -> ApiRouter {
ApiRouter::new().nest(
"/collections",
ApiRouter::new()
.api_route("/", get(get_collections))
.api_route("/:collection_name", put(create_collection))
.api_route("/:collection_name", post(query_collection))
.api_route("/:collection_name", get(get_collection_info))
.api_route("/:collection_name", delete(delete_collection))
.api_route("/:collection_name/insert", post(insert_into_collection)),
.api_route("/:collection_name/embeddings", get(get_embeddings))
.api_route("/:collection_name/embeddings", post(query_embeddings))
.api_route("/:collection_name/embeddings", delete(delete_embeddings))
.api_route("/:collection_name/embeddings/:embedding_id", put(insert_into_collection))
.api_route("/:collection_name/embeddings/:embedding_id", get(get_embedding))
.api_route("/:collection_name/embeddings/:embedding_id", delete(delete_embedding)),
)
}

/// Get collection names
async fn get_collections(
Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
tracing::trace!("Getting collection names");

let db = db.read().await;

let results = db.list();

Ok(Json(results))
}

/// Create a new collection
async fn create_collection(
Path(collection_name): Path<String>,
Expand Down Expand Up @@ -54,6 +76,8 @@ async fn create_collection(
struct QueryCollectionQuery {
/// Vector to query with
query: Vec<f32>,
/// Metadata to filter with
filter: Option<Vec<HashMap<String, String>>>,
/// Number of results to return
k: Option<usize>,
}
Expand All @@ -77,7 +101,7 @@ async fn query_collection(
}

let instant = Instant::now();
let results = collection.get_similarity(&req.query, req.k.unwrap_or(1));
let results = collection.get_by_metadata_and_similarity(&req.filter.unwrap_or_default(), &req.query, req.k.unwrap_or(1));
drop(db);

tracing::trace!("Query to {collection_name} took {:?}", instant.elapsed());
Expand Down Expand Up @@ -138,16 +162,29 @@ async fn delete_collection(
}
}

#[derive(Debug, serde::Deserialize, JsonSchema)]
struct EmbeddingData {
/// Vector computed from a text chunk
vector: Vec<f32>,
/// Metadata about the source text
metadata: Option<HashMap<String, String>>,
}

/// Insert a vector into a collection
async fn insert_into_collection(
Path(collection_name): Path<String>,
Path((collection_name, embedding_id)): Path<(String, String)>,
Extension(db): DbExtension,
Json(embedding): Json<Embedding>,
Json(embedding_data): Json<EmbeddingData>,
) -> Result<StatusCode, HTTPError> {
tracing::trace!("Inserting into collection {collection_name}");

let mut db = db.write().await;

let embedding = Embedding {
id: embedding_id,
vector: embedding_data.vector,
metadata: embedding_data.metadata,
};
let insert_result = db.insert_into_collection(&collection_name, embedding);
drop(db);

Expand All @@ -165,3 +202,109 @@ async fn insert_into_collection(
.with_status(StatusCode::BAD_REQUEST)),
}
}

/// Query embeddings in a collection
async fn get_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
) -> Result<Json<Vec<String>>, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let results = collection.list();
drop(db);

Ok(Json(results))
}

#[derive(Debug, serde::Deserialize, JsonSchema)]
struct EmbeddingsQuery {
/// Metadata to filter with
filter: Vec<HashMap<String, String>>,
/// Number of results to return
k: Option<usize>,
}

/// Query embeddings in a collection
async fn query_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
Json(req): Json<EmbeddingsQuery>,
) -> Result<Json<Vec<Embedding>>, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let instant = Instant::now();
let results = collection.get_by_metadata(&req.filter, req.k.unwrap_or(1));
drop(db);

tracing::trace!("Query embeddings from {collection_name} took {:?}", instant.elapsed());
Ok(Json(results))
}

/// Delete embeddings in a collection
async fn delete_embeddings(
Path(collection_name): Path<String>,
Extension(db): DbExtension,
Json(req): Json<EmbeddingsQuery>,
) -> Result<StatusCode, HTTPError> {
tracing::trace!("Querying embeddings from collection {collection_name}");

let mut db = db.write().await;
let collection = db
.get_collection_mut(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

collection.delete_by_metadata(&req.filter);
drop(db);

Ok(StatusCode::NO_CONTENT)
}

/// Get an embedding from a collection
async fn get_embedding(
Path((collection_name, embedding_id)): Path<(String, String)>,
Extension(db): DbExtension,
) -> Result<Json<Embedding>, HTTPError> {
tracing::trace!("Getting {embedding_id} from collection {collection_name}");

let db = db.read().await;
let collection = db
.get_collection(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let embedding = collection
.get(&embedding_id)
.ok_or_else(|| HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND))?;

Ok(Json(embedding.to_owned()))
}

/// Delete an embedding from a collection
async fn delete_embedding(
Path((collection_name, embedding_id)): Path<(String, String)>,
Extension(db): DbExtension,
) -> Result<StatusCode, HTTPError> {
tracing::trace!("Removing embedding {embedding_id} from collection {collection_name}");

let mut db = db.write().await;
let collection = db
.get_collection_mut(&collection_name)
.ok_or_else(|| HTTPError::new("Collection not found").with_status(StatusCode::NOT_FOUND))?;

let delete_result = collection.delete(&embedding_id);
drop(db);

match delete_result {
true => Ok(StatusCode::NO_CONTENT),
false => Err(HTTPError::new("Embedding not found").with_status(StatusCode::NOT_FOUND)),
}
}