diff --git a/src/reader.rs b/src/reader.rs index 9bd4fe1..525baa7 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -39,6 +39,7 @@ const LINEAR_SEARCH_THRESHOLD: u64 = 0; pub struct QueryBuilder<'a, D: Distance> { reader: &'a Reader<'a, D>, candidates: Option<&'a RoaringBitmap>, + filter: Option bool + 'a>>, count: usize, ef: usize, } @@ -101,6 +102,22 @@ impl<'a, D: Distance> QueryBuilder<'a, D> { self } + /// Specify a function to be used to filter items. + /// The function should accept (ItemId, Distance) and should return a boolean. + /// A return value of `false` indicates the item should be filtered. + /// + /// # Examples + /// + /// ```no_run + /// # use hannoy::{Reader, distances::Euclidean}; + /// # let (reader, rtxn): (Reader, heed::RoTxn) = todo!(); + /// reader.nns(20).filter(|id, distance| id % 2 == 0).by_item(&rtxn, 6); + /// ``` + pub fn filter bool + 'a>(&mut self, filter: F) -> &mut Self { + self.filter = Some(Box::new(filter)); + self + } + /// Specify a search buffer size from which the closest elements are returned. Increasing this /// value improves the search relevancy but increases latency as more neighbours need to be /// searched. @@ -334,7 +351,7 @@ impl<'t, D: Distance> Reader<'t, D> { /// /// You must provide the number of items you want to receive. pub fn nns(&self, count: usize) -> QueryBuilder { - QueryBuilder { reader: self, candidates: None, count, ef: DEFAULT_EF_SEARCH } + QueryBuilder { reader: self, candidates: None, filter: None, count, ef: DEFAULT_EF_SEARCH } } /// Get a generic read node from the database using the version of the database found while creating the reader. @@ -437,6 +454,12 @@ impl<'t, D: Distance> Reader<'t, D> { let mut nns = Vec::with_capacity(opt.count); while let Some((OrderedFloat(f), id)) = neighbours.pop_min() { + if let Some(filter) = &opt.filter { + if !filter(id, f) { + continue; + } + } + if opt.candidates.is_none_or(|candidates| candidates.contains(id)) { nns.push((id, f)); }