diff --git a/docs/user-guides/configuration-guide.md b/docs/user-guides/configuration-guide.md index 803c4a519..a70c160b8 100644 --- a/docs/user-guides/configuration-guide.md +++ b/docs/user-guides/configuration-guide.md @@ -538,6 +538,7 @@ The following tables lists the supported embedding providers: | OpenAI | `openai` | `text-embedding-ada-002`, etc. | | SentenceTransformers | `SentenceTransformers` | `all-MiniLM-L6-v2`, etc. | | NVIDIA AI Endpoints | `nvidia_ai_endpoints` | `nv-embed-v1`, etc. | +| Cohere | `cohere` | `embed-multilingual-v3.0`, etc. | ```{note} You can use any of the supported models for any of the supported embedding providers. diff --git a/nemoguardrails/embeddings/providers/__init__.py b/nemoguardrails/embeddings/providers/__init__.py index c9a8f2896..d8f3d6f16 100644 --- a/nemoguardrails/embeddings/providers/__init__.py +++ b/nemoguardrails/embeddings/providers/__init__.py @@ -18,7 +18,7 @@ from typing import Optional, Type -from . import fastembed, nim, openai, sentence_transformers +from . import cohere, fastembed, nim, openai, sentence_transformers from .base import EmbeddingModel from .registry import EmbeddingProviderRegistry @@ -68,6 +68,7 @@ def register_embedding_provider( register_embedding_provider(sentence_transformers.SentenceTransformerEmbeddingModel) register_embedding_provider(nim.NIMEmbeddingModel) register_embedding_provider(nim.NVIDIAAIEndpointsEmbeddingModel) +register_embedding_provider(cohere.CohereEmbeddingModel) def init_embedding_model( diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py new file mode 100644 index 000000000..c6171daac --- /dev/null +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
import asyncio
from contextvars import ContextVar
from typing import List

from .base import EmbeddingModel

# Context variable intended to hold a Cohere async client scoped to the asyncio
# loop that created it: the client caches loop-bound state and fails if the loop
# changes. It is currently unused because `encode_async` falls back to a
# thread-based implementation (see the note in that method).
async_client_var: ContextVar = ContextVar("async_client", default=None)


class CohereEmbeddingModel(EmbeddingModel):
    """Embedding model using the Cohere API.

    To use, you must have either:
        1. The ``COHERE_API_KEY`` environment variable set with your API key, or
        2. Pass your API key using the ``api_key`` kwarg to the Cohere constructor.

    Args:
        embedding_model (str): The name of the embedding model.
        input_type (str): The type of input for the embedding model, default is
            "search_document". One of "search_document", "search_query",
            "classification", "clustering", "image".
        **kwargs: Forwarded unchanged to ``cohere.Client`` (e.g. ``api_key``).

    Attributes:
        model (str): The name of the embedding model.
        input_type (str): The input type sent with every embedding request.
        client: The synchronous ``cohere.Client`` used for all requests.
        embedding_size_dict (dict): Known embedding sizes keyed by model name.
        embedding_size (int): The size of the embeddings.

    Raises:
        ImportError: If the ``cohere`` package is not installed.

    Methods:
        encode: Encode a list of documents into embeddings.
        encode_async: Same as ``encode``, awaitable from an event loop.
    """

    engine_name = "cohere"

    def __init__(
        self,
        embedding_model: str,
        input_type: str = "search_document",
        **kwargs,
    ):
        # Imported lazily so the package is only required when this provider
        # is actually instantiated.
        try:
            import cohere
        except ImportError as exc:
            raise ImportError(
                "Could not import cohere, please install it with "
                "`pip install cohere`."
            ) from exc

        self.model = embedding_model
        self.input_type = input_type
        self.client = cohere.Client(**kwargs)

        self.embedding_size_dict = {
            "embed-v4.0": 1536,
            "embed-english-v3.0": 1024,
            "embed-english-light-v3.0": 384,
            "embed-multilingual-v3.0": 1024,
            "embed-multilingual-light-v3.0": 384,
        }

        known_size = self.embedding_size_dict.get(self.model)
        if known_size is None:
            # Unknown model: perform a first encoding to get the embedding size.
            known_size = len(self.encode(["test"])[0])
        self.embedding_size = known_size

    async def encode_async(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings without blocking the loop.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings.
        """
        # NOTE: A native implementation based on the Cohere async client has some
        # edge cases because of httpx and async and returns "Event loop is
        # closed." errors. Falling back to a thread-based implementation for now:
        # the blocking `encode` call runs in the loop's default executor.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self.encode, documents)

    def encode(self, documents: List[str]) -> List[List[float]]:
        """Encode a list of documents into embeddings.

        Args:
            documents (List[str]): The list of documents to be encoded.

        Returns:
            List[List[float]]: The encoded embeddings. An empty input yields an
            empty list without contacting the API.
        """
        if not documents:
            # Nothing to embed, so skip the API request entirely.
            return []

        # Make embedding request to Cohere API
        response = self.client.embed(
            texts=documents, model=self.model, input_type=self.input_type
        )
        return response.embeddings
import os

import pytest

from nemoguardrails import LLMRails, RailsConfig

try:
    from nemoguardrails.embeddings.providers.cohere import CohereEmbeddingModel
except ImportError:
    # Ignore this if running in test environment when cohere not installed.
    CohereEmbeddingModel = None

CONFIGS_FOLDER = os.path.join(os.path.dirname(__file__), ".", "test_configs")

# All tests below hit live services and only run when LIVE_TEST is set.
LIVE_TEST_MODE = os.environ.get("LIVE_TEST")


@pytest.fixture
def app():
    """Load the configuration where we replace FastEmbed with Cohere."""
    config = RailsConfig.from_path(
        os.path.join(CONFIGS_FOLDER, "with_cohere_embeddings")
    )

    return LLMRails(config)


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_custom_llm_registration(app):
    """The flows index must be backed by the Cohere embedding model."""
    assert isinstance(
        app.llm_generation_actions.flows_index._model, CohereEmbeddingModel
    )


# Named distinctly from the synchronous variant below: two module-level
# functions with the same name would leave only the last one collected.
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_live_query_async():
    """Generate a reply through the async API using Cohere embeddings."""
    config = RailsConfig.from_path(
        os.path.join(CONFIGS_FOLDER, "with_cohere_embeddings")
    )
    app = LLMRails(config)

    result = await app.generate_async(
        messages=[{"role": "user", "content": "tell me what you can do"}]
    )

    assert result == {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }


# Plain synchronous test: no asyncio marker, since it is not a coroutine.
@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_live_query(app):
    """Generate a reply through the sync API using Cohere embeddings."""
    result = app.generate(
        messages=[{"role": "user", "content": "tell me what you can do"}]
    )

    assert result == {
        "role": "assistant",
        "content": "I am an AI assistant that helps answer questions.",
    }


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
def test_sync_embeddings():
    """The synchronous encoder returns 1024-dimensional vectors."""
    model = CohereEmbeddingModel("embed-multilingual-v3.0")

    result = model.encode(["test"])

    assert len(result[0]) == 1024


@pytest.mark.skipif(not LIVE_TEST_MODE, reason="Not in live mode.")
@pytest.mark.asyncio
async def test_async_embeddings():
    """The asynchronous encoder returns 1024-dimensional vectors."""
    model = CohereEmbeddingModel("embed-multilingual-v3.0")

    result = await model.encode_async(["test"])

    assert len(result[0]) == 1024