-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #54 from philippgille/use-max-heap
Use max heap for building query result
- Loading branch information
Showing
4 changed files
with
93 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,12 @@ | ||
package chromem | ||
|
||
import ( | ||
"cmp" | ||
"container/heap" | ||
"context" | ||
"fmt" | ||
"runtime" | ||
"slices" | ||
"strings" | ||
"sync" | ||
) | ||
|
@@ -15,6 +18,70 @@ type docSim struct { | |
similarity float32 | ||
} | ||
|
||
// docMaxHeap is a max-heap of docSims, based on similarity. | ||
// See https://pkg.go.dev/container/[email protected]#example-package-IntHeap | ||
type docMaxHeap []docSim | ||
|
||
func (h docMaxHeap) Len() int { return len(h) } | ||
func (h docMaxHeap) Less(i, j int) bool { return h[i].similarity < h[j].similarity } | ||
func (h docMaxHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } | ||
|
||
func (h *docMaxHeap) Push(x any) { | ||
// Push and Pop use pointer receivers because they modify the slice's length, | ||
// not just its contents. | ||
*h = append(*h, x.(docSim)) | ||
} | ||
|
||
func (h *docMaxHeap) Pop() any { | ||
old := *h | ||
n := len(old) | ||
x := old[n-1] | ||
*h = old[0 : n-1] | ||
return x | ||
} | ||
|
||
// maxDocSims manages a max-heap of docSims with a fixed size, keeping the n highest | ||
// similarities. It's safe for concurrent use, but not the result of values(). | ||
// In our benchmarks this was faster than sorting a slice of docSims at the end. | ||
type maxDocSims struct { | ||
h docMaxHeap | ||
lock sync.RWMutex | ||
size int | ||
} | ||
|
||
// newMaxDocSims creates a new nMaxDocs with a fixed size. | ||
func newMaxDocSims(size int) *maxDocSims { | ||
return &maxDocSims{ | ||
h: make(docMaxHeap, 0, size), | ||
size: size, | ||
} | ||
} | ||
|
||
// add inserts a new docSim into the heap, keeping only the top n similarities. | ||
func (mds *maxDocSims) add(doc docSim) { | ||
mds.lock.Lock() | ||
defer mds.lock.Unlock() | ||
if mds.h.Len() < mds.size { | ||
heap.Push(&mds.h, doc) | ||
} else if mds.h.Len() > 0 && mds.h[0].similarity < doc.similarity { | ||
// Replace the smallest similarity if the new doc's similarity is higher | ||
heap.Pop(&mds.h) | ||
heap.Push(&mds.h, doc) | ||
} | ||
} | ||
|
||
// values returns the docSims in the heap, sorted by similarity (descending). | ||
// The call itself is safe for concurrent use with add(), but the result isn't. | ||
// Only work with the result after all calls to add() have finished. | ||
func (d *maxDocSims) values() []docSim { | ||
d.lock.RLock() | ||
defer d.lock.RUnlock() | ||
slices.SortFunc(d.h, func(i, j docSim) int { | ||
return cmp.Compare(j.similarity, i.similarity) | ||
}) | ||
return d.h | ||
} | ||
|
||
// filterDocs filters a map of documents by metadata and content. | ||
// It does this concurrently. | ||
func filterDocs(docs map[string]*Document, where, whereDocument map[string]string) []*Document { | ||
|
@@ -95,9 +162,8 @@ func documentMatchesFilters(document *Document, where, whereDocument map[string] | |
return true | ||
} | ||
|
||
func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Document) ([]docSim, error) { | ||
similarities := make([]docSim, 0, len(docs)) | ||
similaritiesLock := sync.Mutex{} | ||
func getMostSimilarDocs(ctx context.Context, queryVectors []float32, docs []*Document, n int) ([]docSim, error) { | ||
nMaxDocs := newMaxDocSims(n) | ||
|
||
// Determine concurrency. Use number of docs or CPUs, whichever is smaller. | ||
numCPUs := runtime.NumCPU() | ||
|
@@ -152,10 +218,7 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu | |
return | ||
} | ||
|
||
similaritiesLock.Lock() | ||
// We don't defer the unlock because we want to unlock much earlier. | ||
similarities = append(similarities, docSim{docID: doc.ID, similarity: sim}) | ||
similaritiesLock.Unlock() | ||
nMaxDocs.add(docSim{docID: doc.ID, similarity: sim}) | ||
} | ||
}(docs[start:end]) | ||
} | ||
|
@@ -166,5 +229,5 @@ func calcDocSimilarity(ctx context.Context, queryVectors []float32, docs []*Docu | |
return nil, sharedErr | ||
} | ||
|
||
return similarities, nil | ||
return nMaxDocs.values(), nil | ||
} |