Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 43 additions & 5 deletions langextract/providers/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

import concurrent.futures
import dataclasses
import random
import time
from typing import Any, Final, Iterator, Sequence

from absl import logging
Expand Down Expand Up @@ -67,6 +69,7 @@ class GeminiLanguageModel(base_model.BaseLanguageModel): # pylint: disable=too-
format_type: data.FormatType = data.FormatType.JSON
temperature: float = 0.0
max_workers: int = 10
max_retries: int = 5
fence_output: bool = False
_extra_kwargs: dict[str, Any] = dataclasses.field(
default_factory=dict, repr=False, compare=False
Expand Down Expand Up @@ -104,6 +107,7 @@ def __init__(
format_type: data.FormatType = data.FormatType.JSON,
temperature: float = 0.0,
max_workers: int = 10,
max_retries: int = 5,
fence_output: bool = False,
**kwargs,
) -> None:
Expand All @@ -121,6 +125,7 @@ def __init__(
format_type: Output format (JSON or YAML).
temperature: Sampling temperature.
max_workers: Maximum number of parallel API calls.
max_retries: Maximum number of retries for rate limit (429) errors.
fence_output: Whether to wrap output in markdown fences (ignored,
Gemini handles this based on schema).
**kwargs: Additional Gemini API parameters. Only allowlisted keys are
Expand Down Expand Up @@ -148,6 +153,7 @@ def __init__(
self.format_type = format_type
self.temperature = temperature
self.max_workers = max_workers
self.max_retries = max_retries
self.fence_output = fence_output

# Extract batch config before we filter kwargs into _extra_kwargs
Expand Down Expand Up @@ -214,15 +220,47 @@ def _process_single_prompt(
config.setdefault('response_mime_type', 'application/json')
config.setdefault('response_schema', self.gemini_schema.schema_dict)

response = self._client.models.generate_content(
model=self.model_id, contents=prompt, config=config
)
base_delay = 1.0 # seconds
max_delay = 120.0 # seconds

return core_types.ScoredOutput(score=1.0, output=response.text)
for attempt in range(self.max_retries + 1):
try:
response = self._client.models.generate_content(
model=self.model_id, contents=prompt, config=config
)
return core_types.ScoredOutput(score=1.0, output=response.text)

except Exception as e:
# Check for 429 RESOURCE_EXHAUSTED
is_rate_limit = False
error_message = str(e)
if "429" in error_message or "RESOURCE_EXHAUSTED" in error_message:
is_rate_limit = True

if is_rate_limit and attempt < self.max_retries:
delay = min(max_delay, base_delay * (2**attempt))
jitter = random.uniform(0, 0.1 * delay)
sleep_time = delay + jitter
logging.warning(
"Gemini API rate limit hit (429). Retrying in %.2fs (attempt %d/%d)",
sleep_time,
attempt + 1,
self.max_retries,
)
time.sleep(sleep_time)
continue

raise exceptions.InferenceRuntimeError(
f"Gemini API error: {error_message}", original=e
) from e

# This should technically be unreachable due to the raise in the loop
raise exceptions.InferenceRuntimeError("Gemini API error: Maximum retries exceeded")
except Exception as e:
if isinstance(e, exceptions.InferenceRuntimeError):
raise
raise exceptions.InferenceRuntimeError(
f'Gemini API error: {str(e)}', original=e
f"Gemini API error: {str(e)}", original=e
) from e

def infer(
Expand Down
99 changes: 99 additions & 0 deletions tests/test_gemini_backoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Copyright 2025 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for Gemini provider exponential backoff."""

from unittest import mock
from absl.testing import absltest
from langextract.core import exceptions
from langextract.providers import gemini

class TestGeminiBackoff(absltest.TestCase):
  """Unit tests for the 429 exponential-backoff logic in the Gemini provider."""

  def _build_model(self, mock_client_factory, max_retries):
    # Wire a mock genai client into a fresh model. Returns (model, mock_client)
    # so each test can program the client's generate_content behavior.
    fake_client = mock.Mock()
    mock_client_factory.return_value = fake_client
    model = gemini.GeminiLanguageModel(
        api_key="test-key", max_retries=max_retries
    )
    return model, fake_client

  @mock.patch("google.genai.Client")
  @mock.patch("time.sleep")  # Patched so retries do not actually wait.
  def test_gemini_retry_on_429(self, mock_sleep, mock_client_class):
    """A single 429 failure is retried and the second attempt succeeds."""
    model, fake_client = self._build_model(mock_client_class, max_retries=3)

    ok_response = mock.Mock()
    ok_response.text = '{"result": "success"}'
    # First call raises a rate-limit error, second call returns normally.
    fake_client.models.generate_content.side_effect = [
        Exception("429 RESOURCE_EXHAUSTED"),
        ok_response,
    ]

    outputs = list(model.infer(["Test prompt"]))

    self.assertLen(outputs, 1)
    self.assertEqual(outputs[0][0].output, '{"result": "success"}')
    self.assertEqual(fake_client.models.generate_content.call_count, 2)
    mock_sleep.assert_called_once()

  @mock.patch("google.genai.Client")
  @mock.patch("time.sleep")  # Patched so retries do not actually wait.
  def test_gemini_max_retries_exceeded(self, mock_sleep, mock_client_class):
    """Persistent 429s exhaust the retry budget and surface a runtime error."""
    model, fake_client = self._build_model(mock_client_class, max_retries=2)
    # Every call fails with the same rate-limit error.
    fake_client.models.generate_content.side_effect = Exception(
        "429 RESOURCE_EXHAUSTED"
    )

    with self.assertRaises(exceptions.InferenceRuntimeError) as ctx:
      list(model.infer(["Test prompt"]))

    message = str(ctx.exception)
    self.assertIn("Gemini API error", message)
    self.assertIn("429", message)
    # One initial attempt plus two retries.
    self.assertEqual(fake_client.models.generate_content.call_count, 3)
    self.assertEqual(mock_sleep.call_count, 2)

  @mock.patch("google.genai.Client")
  @mock.patch("time.sleep")  # Patched to prove no backoff happens here.
  def test_gemini_no_retry_on_other_errors(self, mock_sleep, mock_client_class):
    """Non-rate-limit errors propagate immediately without any retry."""
    model, fake_client = self._build_model(mock_client_class, max_retries=3)
    fake_client.models.generate_content.side_effect = Exception(
        "500 Internal Server Error"
    )

    with self.assertRaises(exceptions.InferenceRuntimeError) as ctx:
      list(model.infer(["Test prompt"]))

    self.assertIn("500", str(ctx.exception))
    self.assertEqual(fake_client.models.generate_content.call_count, 1)
    mock_sleep.assert_not_called()

# Allow running this test module directly; absltest supplies the CLI runner.
if __name__ == "__main__":
  absltest.main()
Loading