-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathembedding.go
More file actions
104 lines (88 loc) · 2.5 KB
/
embedding.go
File metadata and controls
104 lines (88 loc) · 2.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package syzgydb
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"net/http"
"strings"
"sync"
)
// maxCacheSize is the capacity handed to newLRUCache when the shared
// embedding cache is created.
const maxCacheSize = 100

// embeddingCache holds previously computed embeddings keyed by their input
// text. Within this file every access to it happens while cacheMutex is held.
var embeddingCache = newLRUCache(maxCacheSize)

// cacheMutex serializes access to embeddingCache.
var cacheMutex sync.RWMutex

// EmbedTextFunc is the signature of an embedding function: it takes a batch
// of texts plus a flag saying whether the cache may be used, and returns one
// embedding vector per text or an error.
type EmbedTextFunc func(text []string, useCache bool) ([][]float64, error)

// embedText is the embedding function used by default; it starts out as
// EmbedText. Presumably it is a variable so it can be replaced (e.g. in
// tests) — confirm against callers.
var embedText = EmbedTextFunc(EmbedText)
// EmbedText connects to the configured Ollama server and runs the configured
// text model (globalConfig.TextModel) to generate one embedding per entry in
// texts, using a single POST to <globalConfig.OllamaServer>/api/embed.
//
// When useCache is true, embeddingCache is consulted first: if every text
// already has a cached embedding, those are returned without contacting the
// server. Otherwise all texts are sent to the server and, on success, each
// returned embedding is stored in the cache under its text.
//
// The returned slice is index-aligned with texts. An error is returned when
// the request cannot be encoded or sent, the server replies with a non-200
// status, the response cannot be decoded, the response contains no
// embeddings, or the number of embeddings differs from len(texts).
func EmbedText(texts []string, useCache bool) ([][]float64, error) {
	// Serve entirely from the cache when every text is present.
	if useCache {
		cachedEmbeddings := make([][]float64, len(texts))
		allCached := true
		// A full lock (not RLock) is taken because an LRU lookup usually
		// updates recency bookkeeping, i.e. it writes.
		// NOTE(review): confirm against newLRUCache's get; if get is
		// strictly read-only, RLock would be sufficient.
		cacheMutex.Lock()
		for i, text := range texts {
			embedding, found := embeddingCache.get(text)
			if !found {
				allCached = false
				break
			}
			cachedEmbeddings[i] = embedding
		}
		cacheMutex.Unlock()
		if allCached {
			return cachedEmbeddings, nil
		}
	}

	// Prepare the request payload.
	payload := map[string]interface{}{
		"model": globalConfig.TextModel,
		"input": texts,
	}
	payloadBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request payload: %w", err)
	}

	// Construct the request URL, defaulting to plain HTTP when no scheme is
	// configured. A trailing slash is trimmed so "host:port/" does not yield
	// "host:port//api/embed".
	url := globalConfig.OllamaServer
	if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
		url = "http://" + url
	}
	url = strings.TrimSuffix(url, "/") + "/api/embed"

	// Log a summary only: dumping the payload would write every input text
	// to the log on each call.
	log.Printf("Sending %d text(s) to %v using model %v", len(texts), url, globalConfig.TextModel)

	resp, err := http.Post(url, "application/json", bytes.NewBuffer(payloadBytes))
	if err != nil {
		return nil, fmt.Errorf("failed to connect to Ollama server: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message, so a read
		// failure here is deliberately ignored.
		bodyBytes, _ := ioutil.ReadAll(resp.Body)
		return nil, fmt.Errorf("failed to get embedding: %s: %s", resp.Status, string(bodyBytes))
	}

	var response struct {
		Embeddings [][]float64 `json:"embeddings"`
	}
	if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
		return nil, fmt.Errorf("failed to decode response: %w", err)
	}

	if len(response.Embeddings) == 0 {
		return nil, fmt.Errorf("no embeddings found in response")
	}
	// Results are index-aligned with texts (callers and the cache loop below
	// rely on this), so a short or long response must be rejected rather than
	// panicking on an out-of-range index or returning misaligned data.
	if len(response.Embeddings) != len(texts) {
		return nil, fmt.Errorf("expected %d embeddings in response, got %d", len(texts), len(response.Embeddings))
	}

	// Store the new embeddings in the cache.
	if useCache {
		cacheMutex.Lock()
		for i, text := range texts {
			embeddingCache.put(text, response.Embeddings[i])
		}
		cacheMutex.Unlock()
	}

	return response.Embeddings, nil
}