Add a Cohere embedding provider to the ai-cache plugin and move the
provider-specific settings (apiKey, textin*) out of ProviderConfig into
per-provider config that is read through a new providerInitializer.InitConfig
hook.

Review notes:
* Line breaks and indentation were reconstructed following gofmt conventions.
  Blank context lines are inferred from the hunk line counts, and where column
  alignment depends on lines outside a hunk single spaces are kept, so the
  whitespace of context lines should be confirmed against the target files.
* cohere.go differs from the originally submitted blob (canonical "Bearer"
  scheme spelling, FloatTypeEmbedding field name, response handling that keeps
  its state local to the callback, doc comments), so the blob id on its
  "index" line no longer matches the content.

diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go b/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
new file mode 100644
index 0000000000..d952d2ad2c
--- /dev/null
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/cohere.go
@@ -0,0 +1,179 @@
+package embedding
+
+import (
+	"encoding/json"
+	"errors"
+	"fmt"
+	"net/http"
+	"strconv"
+
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
+)
+
+// Defaults applied when the plugin configuration leaves a value unset.
+const (
+	COHERE_DOMAIN             = "api.cohere.com"
+	COHERE_PORT               = 443
+	COHERE_DEFAULT_MODEL_NAME = "embed-english-v2.0"
+	COHERE_ENDPOINT           = "/v2/embed"
+)
+
+// cohereProviderInitializer is the providerInitializer for the "cohere" type.
+type cohereProviderInitializer struct {
+}
+
+// cohereConfig holds the Cohere-specific settings parsed by InitConfig.
+// NOTE(review): it is package-level state, so every InitConfig call
+// overwrites the value seen by all Cohere providers.
+var cohereConfig cohereProviderConfig
+
+type cohereProviderConfig struct {
+	// @Title zh-CN 文本特征提取服务 API Key
+	// @Description zh-CN 文本特征提取服务 API Key
+	apiKey string
+}
+
+// InitConfig reads the "apiKey" field from the provider JSON into cohereConfig.
+func (c *cohereProviderInitializer) InitConfig(json gjson.Result) {
+	cohereConfig.apiKey = json.Get("apiKey").String()
+}
+
+// ValidateConfig returns an error when no API key has been configured.
+func (c *cohereProviderInitializer) ValidateConfig() error {
+	if cohereConfig.apiKey == "" {
+		return errors.New("[Cohere] apiKey is required")
+	}
+	return nil
+}
+
+// CreateProvider builds a CohereProvider, falling back to COHERE_PORT and
+// COHERE_DOMAIN when the service port or host is not set.
+func (t *cohereProviderInitializer) CreateProvider(c ProviderConfig) (Provider, error) {
+	if c.servicePort == 0 {
+		c.servicePort = COHERE_PORT
+	}
+	if c.serviceHost == "" {
+		c.serviceHost = COHERE_DOMAIN
+	}
+	return &CohereProvider{
+		config: c,
+		client: wrapper.NewClusterClient(wrapper.FQDNCluster{
+			FQDN: c.serviceName,
+			Host: c.serviceHost,
+			Port: int64(c.servicePort),
+		}),
+	}, nil
+}
+
+// cohereResponse is the subset of the embed response body that is decoded.
+type cohereResponse struct {
+	Embeddings cohereEmbeddings `json:"embeddings"`
+}
+
+// cohereEmbeddings carries the vectors returned under the "float" key.
+type cohereEmbeddings struct {
+	FloatTypeEmbedding [][]float64 `json:"float"`
+}
+
+// cohereEmbeddingRequest is the JSON body sent to COHERE_ENDPOINT.
+type cohereEmbeddingRequest struct {
+	Texts          []string `json:"texts"`
+	Model          string   `json:"model"`
+	InputType      string   `json:"input_type"`
+	EmbeddingTypes []string `json:"embedding_types"`
+}
+
+// CohereProvider requests text embeddings from the Cohere embed API.
+type CohereProvider struct {
+	config ProviderConfig
+	client wrapper.HttpClient
+}
+
+// GetProviderType returns PROVIDER_TYPE_COHERE.
+func (t *CohereProvider) GetProviderType() string {
+	return PROVIDER_TYPE_COHERE
+}
+
+// constructParameters returns the request path, headers and JSON body for an
+// embed call over texts, using COHERE_DEFAULT_MODEL_NAME when no model is set.
+func (t *CohereProvider) constructParameters(texts []string, log wrapper.Log) (string, [][2]string, []byte, error) {
+	model := t.config.model
+
+	if model == "" {
+		model = COHERE_DEFAULT_MODEL_NAME
+	}
+	data := cohereEmbeddingRequest{
+		Texts:          texts,
+		Model:          model,
+		InputType:      "search_document",
+		EmbeddingTypes: []string{"float"},
+	}
+
+	requestBody, err := json.Marshal(data)
+	if err != nil {
+		log.Errorf("failed to marshal request data: %v", err)
+		return "", nil, nil, err
+	}
+
+	// Use the same "Bearer" spelling as the other providers in this package.
+	headers := [][2]string{
+		{"Authorization", fmt.Sprintf("Bearer %s", cohereConfig.apiKey)},
+		{"Content-Type", "application/json"},
+	}
+
+	return COHERE_ENDPOINT, headers, requestBody, nil
+}
+
+// parseTextEmbedding decodes a JSON response body into a cohereResponse.
+func (t *CohereProvider) parseTextEmbedding(responseBody []byte) (*cohereResponse, error) {
+	var resp cohereResponse
+	err := json.Unmarshal(responseBody, &resp)
+	if err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// GetEmbedding requests an embedding for queryString and passes the first
+// returned float vector to callback.
+//
+// The returned error only covers building and dispatching the request.
+// Problems found in the response (non-200 status, undecodable body, no
+// vectors) are reported through callback with a nil embedding.
+func (t *CohereProvider) GetEmbedding(
+	queryString string,
+	ctx wrapper.HttpContext,
+	log wrapper.Log,
+	callback func(emb []float64, err error)) error {
+	embUrl, embHeaders, embRequestBody, err := t.constructParameters([]string{queryString}, log)
+	if err != nil {
+		log.Errorf("failed to construct parameters: %v", err)
+		return err
+	}
+
+	return t.client.Post(embUrl, embHeaders, embRequestBody,
+		func(statusCode int, responseHeaders http.Header, responseBody []byte) {
+			// All response state stays local to this closure so that it never
+			// writes to variables owned by the enclosing call.
+			if statusCode != http.StatusOK {
+				callback(nil, errors.New("failed to get embedding due to status code: "+strconv.Itoa(statusCode)))
+				return
+			}
+
+			log.Debugf("get embedding response: %d, %s", statusCode, responseBody)
+
+			resp, err := t.parseTextEmbedding(responseBody)
+			if err != nil {
+				callback(nil, fmt.Errorf("failed to parse response: %v", err))
+				return
+			}
+
+			if len(resp.Embeddings.FloatTypeEmbedding) == 0 {
+				callback(nil, errors.New("no embedding found in response"))
+				return
+			}
+
+			callback(resp.Embeddings.FloatTypeEmbedding[0], nil)
+		}, t.config.timeout)
+}
diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
index 35c897cce5..f31a8d17b8 100644
--- a/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/dashscope.go
@@ -8,6 +8,7 @@ import (
 	"strconv"
 
 	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
 )
 
 const (
@@ -17,11 +18,22 @@ const (
 	DASHSCOPE_ENDPOINT = "/api/v1/services/embeddings/text-embedding/text-embedding"
 )
 
+var dashScopeConfig dashScopeProviderConfig
+
 type dashScopeProviderInitializer struct {
 }
+type dashScopeProviderConfig struct {
+	// @Title zh-CN 文本特征提取服务 API Key
+	// @Description zh-CN 文本特征提取服务 API Key
+	apiKey string
+}
+
+func (c *dashScopeProviderInitializer) InitConfig(json gjson.Result) {
+	dashScopeConfig.apiKey = json.Get("apiKey").String()
+}
 
-func (d *dashScopeProviderInitializer) ValidateConfig(config ProviderConfig) error {
-	if config.apiKey == "" {
+func (c *dashScopeProviderInitializer) ValidateConfig() error {
+	if dashScopeConfig.apiKey == "" {
 		return errors.New("[DashScope] apiKey is required")
 	}
 	return nil
@@ -114,14 +126,14 @@ func (d *DSProvider) constructParameters(texts []string, log wrapper.Log) (strin
 		return "", nil, nil, err
 	}
 
-	if d.config.apiKey == "" {
+	if dashScopeConfig.apiKey == "" {
 		err := errors.New("dashScopeKey is empty")
 		log.Errorf("failed to construct headers: %v", err)
 		return "", nil, nil, err
 	}
 
 	headers := [][2]string{
-		{"Authorization", "Bearer " + d.config.apiKey},
+		{"Authorization", "Bearer " + dashScopeConfig.apiKey},
 		{"Content-Type", "application/json"},
 	}
 
diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/openai.go b/plugins/wasm-go/extensions/ai-cache/embedding/openai.go
index 6b251ab341..04c1d8cdd1 100644
--- a/plugins/wasm-go/extensions/ai-cache/embedding/openai.go
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/openai.go
@@ -4,8 +4,10 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
 	"net/http"
+
+	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
 )
 
 const (
@@ -18,9 +20,21 @@ const (
 type openAIProviderInitializer struct {
 }
 
-func (t *openAIProviderInitializer) ValidateConfig(config ProviderConfig) error {
-	if config.apiKey == "" {
-		return errors.New("[OpenAI] embedding service ApiKey is required")
+var openAIConfig openAIProviderConfig
+
+type openAIProviderConfig struct {
+	// @Title zh-CN 文本特征提取服务 API Key
+	// @Description zh-CN 文本特征提取服务 API Key
+	apiKey string
+}
+
+func (c *openAIProviderInitializer) InitConfig(json gjson.Result) {
+	openAIConfig.apiKey = json.Get("apiKey").String()
+}
+
+func (c *openAIProviderInitializer) ValidateConfig() error {
+	if openAIConfig.apiKey == "" {
+		return errors.New("[openAI] apiKey is required")
 	}
 	return nil
 }
@@ -97,7 +111,7 @@ func (t *OpenAIProvider) constructParameters(text string, log wrapper.Log) (stri
 	}
 
 	headers := [][2]string{
-		{"Authorization", fmt.Sprintf("Bearer %s", t.config.apiKey)},
+		{"Authorization", fmt.Sprintf("Bearer %s", openAIConfig.apiKey)},
 		{"Content-Type", "application/json"},
 	}
 
diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
index 18c9860968..608f50ad54 100644
--- a/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/provider.go
@@ -10,11 +10,13 @@ import (
 const (
 	PROVIDER_TYPE_DASHSCOPE = "dashscope"
 	PROVIDER_TYPE_TEXTIN    = "textin"
+	PROVIDER_TYPE_COHERE    = "cohere"
 	PROVIDER_TYPE_OPENAI    = "openai"
 )
 
 type providerInitializer interface {
-	ValidateConfig(ProviderConfig) error
+	InitConfig(json gjson.Result)
+	ValidateConfig() error
 	CreateProvider(ProviderConfig) (Provider, error)
 }
 
@@ -22,6 +24,7 @@ var (
 	providerInitializers = map[string]providerInitializer{
 		PROVIDER_TYPE_DASHSCOPE: &dashScopeProviderInitializer{},
 		PROVIDER_TYPE_TEXTIN:    &textInProviderInitializer{},
+		PROVIDER_TYPE_COHERE:    &cohereProviderInitializer{},
 		PROVIDER_TYPE_OPENAI:    &openAIProviderInitializer{},
 	}
 )
@@ -39,35 +42,26 @@ type ProviderConfig struct {
 	// @Title zh-CN 文本特征提取服务端口
 	// @Description zh-CN 文本特征提取服务端口
 	servicePort int64
-	// @Title zh-CN 文本特征提取服务 API Key
-	// @Description zh-CN 文本特征提取服务 API Key
-	apiKey string
-	//@Title zh-CN TextIn x-ti-app-id
-	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
-	textinAppId string
-	//@Title zh-CN TextIn x-ti-secret-code
-	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
-	textinSecretCode string
-	//@Title zh-CN TextIn request matryoshka_dim
-	// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
-	textinMatryoshkaDim int
 	// @Title zh-CN 文本特征提取服务超时时间
 	// @Description zh-CN 文本特征提取服务超时时间
 	timeout uint32
 	// @Title zh-CN 文本特征提取服务使用的模型
 	// @Description zh-CN 用于文本特征提取的模型名称, 在 DashScope 中默认为 "text-embedding-v1"
 	model string
+
+	initializer providerInitializer
 }
 
 func (c *ProviderConfig) FromJson(json gjson.Result) {
 	c.typ = json.Get("type").String()
+	i, has := providerInitializers[c.typ]
+	if has {
+		i.InitConfig(json)
+		c.initializer = i
+	}
 	c.serviceName = json.Get("serviceName").String()
 	c.serviceHost = json.Get("serviceHost").String()
 	c.servicePort = json.Get("servicePort").Int()
-	c.apiKey = json.Get("apiKey").String()
-	c.textinAppId = json.Get("textinAppId").String()
-	c.textinSecretCode = json.Get("textinSecretCode").String()
-	c.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
 	c.timeout = uint32(json.Get("timeout").Int())
 	c.model = json.Get("model").String()
 	if c.timeout == 0 {
@@ -82,11 +76,10 @@ func (c *ProviderConfig) Validate() error {
 	if c.typ == "" {
 		return errors.New("embedding service type is required")
 	}
-	initializer, has := providerInitializers[c.typ]
-	if !has {
+	if c.initializer == nil {
 		return errors.New("unknown embedding service provider type: " + c.typ)
 	}
-	if err := initializer.ValidateConfig(*c); err != nil {
+	if err := c.initializer.ValidateConfig(); err != nil {
 		return err
 	}
 	return nil
diff --git a/plugins/wasm-go/extensions/ai-cache/embedding/textin.go b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go
index 9bc474041c..5ff29f1af2 100644
--- a/plugins/wasm-go/extensions/ai-cache/embedding/textin.go
+++ b/plugins/wasm-go/extensions/ai-cache/embedding/textin.go
@@ -8,6 +8,7 @@ import (
 	"strconv"
 
 	"github.com/alibaba/higress/plugins/wasm-go/pkg/wrapper"
+	"github.com/tidwall/gjson"
 )
 
 const (
@@ -20,14 +21,34 @@ const (
 type textInProviderInitializer struct {
 }
 
-func (t *textInProviderInitializer) ValidateConfig(config ProviderConfig) error {
-	if config.textinAppId == "" {
-		return errors.New("embedding service TextIn App ID is required")
+var textInConfig textInProviderConfig
+
+type textInProviderConfig struct {
+	//@Title zh-CN TextIn x-ti-app-id
+	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
+	textinAppId string
+	//@Title zh-CN TextIn x-ti-secret-code
+	// @Description zh-CN 仅适用于 TextIn 服务。参考 https://www.textin.com/document/acge_text_embedding
+	textinSecretCode string
+	//@Title zh-CN TextIn request matryoshka_dim
+	// @Description zh-CN 仅适用于 TextIn 服务, 指定返回的向量维度。参考 https://www.textin.com/document/acge_text_embedding
+	textinMatryoshkaDim int
+}
+
+func (c *textInProviderInitializer) InitConfig(json gjson.Result) {
+	textInConfig.textinAppId = json.Get("textinAppId").String()
+	textInConfig.textinSecretCode = json.Get("textinSecretCode").String()
+	textInConfig.textinMatryoshkaDim = int(json.Get("textinMatryoshkaDim").Int())
+}
+
+func (c *textInProviderInitializer) ValidateConfig() error {
+	if textInConfig.textinAppId == "" {
+		return errors.New("textinAppId is required")
 	}
-	if config.textinSecretCode == "" {
-		return errors.New("embedding service TextIn Secret Code is required")
+	if textInConfig.textinSecretCode == "" {
+		return errors.New("textinSecretCode is required")
 	}
-	if config.textinMatryoshkaDim == 0 {
+	if textInConfig.textinMatryoshkaDim == 0 {
 		return errors.New("embedding service TextIn Matryoshka Dim is required")
 	}
 	return nil
@@ -62,7 +83,7 @@ type TextInResponse struct {
 }
 
 type TextInResult struct {
-	Embeddings [][]float64 `json:"embedding"`
+	Embeddings    [][]float64 `json:"embedding"`
 	MatryoshkaDim int         `json:"matryoshka_dim"`
 }
 
@@ -80,7 +101,7 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
 
 	data := TextInEmbeddingRequest{
 		Input:         texts,
-		MatryoshkaDim: t.config.textinMatryoshkaDim,
+		MatryoshkaDim: textInConfig.textinMatryoshkaDim,
 	}
 
 	requestBody, err := json.Marshal(data)
@@ -89,20 +110,20 @@ func (t *TIProvider) constructParameters(texts []string, log wrapper.Log) (strin
 		return "", nil, nil, err
 	}
 
-	if t.config.textinAppId == "" {
+	if textInConfig.textinAppId == "" {
 		err := errors.New("textinAppId is empty")
 		log.Errorf("failed to construct headers: %v", err)
 		return "", nil, nil, err
 	}
-	if t.config.textinSecretCode == "" {
+	if textInConfig.textinSecretCode == "" {
 		err := errors.New("textinSecretCode is empty")
 		log.Errorf("failed to construct headers: %v", err)
 		return "", nil, nil, err
 	}
 
 	headers := [][2]string{
-		{"x-ti-app-id", t.config.textinAppId},
-		{"x-ti-secret-code", t.config.textinSecretCode},
+		{"x-ti-app-id", textInConfig.textinAppId},
+		{"x-ti-secret-code", textInConfig.textinSecretCode},
 		{"Content-Type", "application/json"},
 	}
 