Skip to content

Commit aa48b38

Browse files
committed
feat(mistralai): enable tenacity retries opt-out
1 parent 1aeff43 commit aa48b38

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

libs/partners/mistralai/langchain_mistralai/embeddings.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
import warnings
4-
from collections.abc import Iterable
4+
from collections.abc import Callable, Iterable
55

66
import httpx
77
from httpx import Response
@@ -16,6 +16,7 @@
1616
SecretStr,
1717
model_validator,
1818
)
19+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
1920
from tokenizers import Tokenizer # type: ignore[import]
2021
from typing_extensions import Self
2122

@@ -133,7 +134,9 @@ class MistralAIEmbeddings(BaseModel, Embeddings):
133134
default_factory=secret_from_env("MISTRAL_API_KEY", default=""),
134135
)
135136
endpoint: str = "https://api.mistral.ai/v1/"
137+
max_retries: int | None = 5
136138
timeout: int = 120
139+
wait_time: int | None = 30
137140
max_concurrent_requests: int = 64
138141
tokenizer: Tokenizer = Field(default=None)
139142

@@ -210,6 +213,18 @@ def _get_batches(self, texts: list[str]) -> Iterable[list[str]]:
210213
if batch:
211214
yield batch
212215

216+
def _retry(self, func: Callable) -> Callable:
217+
if self.max_retries is None or self.wait_time is None:
218+
return func
219+
220+
return retry(
221+
retry=retry_if_exception_type(
222+
(httpx.TimeoutException, httpx.HTTPStatusError)
223+
),
224+
wait=wait_fixed(self.wait_time),
225+
stop=stop_after_attempt(self.max_retries),
226+
)(func)
227+
213228
def embed_documents(self, texts: list[str]) -> list[list[float]]:
214229
"""Embed a list of document texts.
215230
@@ -223,6 +238,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
223238
try:
224239
batch_responses = []
225240

241+
@self._retry
226242
def _embed_batch(batch: list[str]) -> Response:
227243
response = self.client.post(
228244
url="/embeddings",
@@ -257,6 +273,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
257273
"""
258274
try:
259275

276+
@self._retry
260277
async def _aembed_batch(batch: list[str]) -> Response:
261278
response = await self.async_client.post(
262279
url="/embeddings",

libs/partners/mistralai/tests/integration_tests/test_embeddings.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import httpx
66
import pytest
7+
import tenacity
78

89
from langchain_mistralai import MistralAIEmbeddings
910

@@ -34,10 +35,25 @@ async def test_mistralai_embedding_documents_async() -> None:
3435
assert len(output[0]) == 1024
3536

3637

38+
async def test_mistralai_embedding_documents_tenacity_error_async() -> None:
39+
"""Test MistralAI embeddings for documents."""
40+
documents = ["foo bar", "test document"]
41+
embedding = MistralAIEmbeddings(max_retries=0)
42+
mock_response = httpx.Response(
43+
status_code=400,
44+
request=httpx.Request("POST", url=embedding.async_client.base_url),
45+
)
46+
with (
47+
patch.object(embedding.async_client, "post", return_value=mock_response),
48+
pytest.raises(tenacity.RetryError),
49+
):
50+
await embedding.aembed_documents(documents)
51+
52+
3753
async def test_mistralai_embedding_documents_http_error_async() -> None:
3854
"""Test MistralAI embeddings for documents."""
3955
documents = ["foo bar", "test document"]
40-
embedding = MistralAIEmbeddings()
56+
embedding = MistralAIEmbeddings(max_retries=None)
4157
mock_response = httpx.Response(
4258
status_code=400,
4359
request=httpx.Request("POST", url=embedding.async_client.base_url),

0 commit comments

Comments
 (0)