Add opt-in retries for batch queries in RangeQuerySetWrapper.

What this patch does:
  * src/sentry/utils/query.py: RangeQuerySetWrapper.__init__ accepts two new
    keyword arguments, query_timeout_retries (default None) and
    retry_delay_seconds (default 0.5), and stores them on the instance. In
    __iter__, when query_timeout_retries is not None, the per-batch
    list(results_qs[0 : self.step]) call is run through a
    ConditionalRetryPolicy whose test function returns true while
    attempt <= query_timeout_retries and the exception is an
    OperationalError, and whose delay function returns retry_delay_seconds.
    When query_timeout_retries is None the batch is fetched directly, as
    before.
  * src/sentry/db/deletion.py: the wrapper built in iterator() is created
    with query_timeout_retries=10.
  * tests/sentry/utils/test_query.py: four tests covering success after
    failures, exhausted retries, non-OperationalError exceptions, and the
    query_timeout_retries=None path.

NOTE(review): the retry predicate matches every OperationalError, not only
statement timeouts, although the parameter is named query_timeout_retries.
Confirm that this is intended.
NOTE(review): the total number of attempts depends on how
ConditionalRetryPolicy numbers attempts, which is not visible in this patch.
Confirm against sentry.utils.retries.
NOTE(review): the fail_once helper in the last test raises on every call;
the name is misleading.
NOTE(review): leading whitespace on context lines was restored from Python
syntax. If a hunk does not apply, retry with `git apply --ignore-whitespace`.

diff --git a/src/sentry/db/deletion.py b/src/sentry/db/deletion.py
index 4cfe17ecb14ed6..79a73908d77e61 100644
--- a/src/sentry/db/deletion.py
+++ b/src/sentry/db/deletion.py
@@ -122,6 +122,7 @@ def iterator(
             order_by=order_field,
             override_unique_safety_check=True,
             result_value_getter=lambda item: item[1],
+            query_timeout_retries=10,
         )
 
         for batch in itertools.batched(wrapper, chunk_size):
diff --git a/src/sentry/utils/query.py b/src/sentry/utils/query.py
index bbab201dbffe98..05884729ce1745 100644
--- a/src/sentry/utils/query.py
+++ b/src/sentry/utils/query.py
@@ -12,9 +12,11 @@
 from django.db.models.query_utils import Q
 from django.db.models.sql.constants import ROW_COUNT
 from django.db.models.sql.subqueries import DeleteQuery
+from django.db.utils import OperationalError
 
 from sentry.db.models.base import Model
 from sentry.services import eventstore
+from sentry.utils.retries import ConditionalRetryPolicy
 
 if TYPE_CHECKING:
     from sentry.services.eventstore.models import Event
@@ -109,6 +111,8 @@ def __init__[M: Model](
         callbacks: Sequence[Callable[[list[V]], None]] = (),
         result_value_getter: Callable[[V], int] | None = None,
         override_unique_safety_check: bool = False,
+        query_timeout_retries: int | None = None,
+        retry_delay_seconds: float = 0.5,
     ):
         # Support for slicing
         if queryset.query.low_mark == 0 and not (
@@ -132,6 +136,8 @@ def __init__[M: Model](
         self.order_by = order_by
         self.callbacks = callbacks
         self.result_value_getter = result_value_getter
+        self.query_timeout_retries = query_timeout_retries
+        self.retry_delay_seconds = retry_delay_seconds
 
         order_by_col = queryset.model._meta.get_field(order_by if order_by != "pk" else "id")
         if not override_unique_safety_check and (
@@ -176,7 +182,16 @@ def __iter__(self) -> Iterator[V]:
             else:
                 results_qs = queryset.filter(**{"%s__gte" % self.order_by: cur_value})
 
-            results = list(results_qs[0 : self.step])
+            if self.query_timeout_retries is not None:
+                retries = self.query_timeout_retries
+                retry_policy = ConditionalRetryPolicy(
+                    test_function=lambda attempt, exc: attempt <= retries
+                    and isinstance(exc, OperationalError),
+                    delay_function=lambda i: self.retry_delay_seconds,
+                )
+                results = retry_policy(lambda: list(results_qs[0 : self.step]))
+            else:
+                results = list(results_qs[0 : self.step])
 
             for cb in self.callbacks:
                 cb(results)
diff --git a/tests/sentry/utils/test_query.py b/tests/sentry/utils/test_query.py
index a9ffbffe5489a6..70f251171f6688 100644
--- a/tests/sentry/utils/test_query.py
+++ b/tests/sentry/utils/test_query.py
@@ -1,5 +1,8 @@
+from unittest.mock import patch
+
 import pytest
 from django.db import connections
+from django.db.utils import OperationalError
 
 from sentry.db.models.query import in_iexact
 from sentry.models.commit import Commit
@@ -78,6 +81,96 @@ def test_wrapper_over_values_list(self) -> None:
         qs = User.objects.all().values_list("id")
         assert list(qs) == list(self.range_wrapper(qs, result_value_getter=lambda r: r[0]))
 
+    def test_retry_on_operational_error_success_after_failures(self) -> None:
+        """Test that with query_timeout_retries=3, after 2 errors and 1 success it works."""
+        total = 5
+        for _ in range(total):
+            self.create_user()
+
+        qs = User.objects.all()
+        batch_attempts: list[int] = []
+        current_batch_count = 0
+        original_getitem = type(qs).__getitem__
+
+        def mock_getitem(self, slice_obj):
+            nonlocal current_batch_count
+            current_batch_count += 1
+            if len(batch_attempts) == 0 and current_batch_count <= 2:
+                raise OperationalError("canceling statement due to user request")
+            if len(batch_attempts) == 0 and current_batch_count == 3:
+                batch_attempts.append(current_batch_count)
+            return original_getitem(self, slice_obj)
+
+        with patch.object(type(qs), "__getitem__", mock_getitem):
+            results = list(
+                self.range_wrapper(qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01)
+            )
+
+        assert len(results) == total
+        assert batch_attempts[0] == 3
+
+    def test_retry_exhausted_raises_exception(self) -> None:
+        """Test that after exhausting retries, the OperationalError is raised."""
+        total = 5
+        for _ in range(total):
+            self.create_user()
+
+        qs = User.objects.all()
+
+        def always_fail(self, slice_obj):
+            raise OperationalError("canceling statement due to user request")
+
+        with patch.object(type(qs), "__getitem__", always_fail):
+            with pytest.raises(OperationalError, match="canceling statement due to user request"):
+                list(
+                    self.range_wrapper(
+                        qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01
+                    )
+                )
+
+    def test_retry_does_not_catch_other_exceptions(self) -> None:
+        """Test that non-OperationalError exceptions are not retried."""
+        total = 5
+        for _ in range(total):
+            self.create_user()
+
+        qs = User.objects.all()
+
+        attempt_count = {"count": 0}
+
+        def raise_value_error(self, slice_obj):
+            attempt_count["count"] += 1
+            raise ValueError("Some other error")
+
+        with patch.object(type(qs), "__getitem__", raise_value_error):
+            with pytest.raises(ValueError, match="Some other error"):
+                list(
+                    self.range_wrapper(
+                        qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01
+                    )
+                )
+        assert attempt_count["count"] == 1
+
+    def test_no_retry_when_query_timeout_retries_is_none(self) -> None:
+        """Test that when query_timeout_retries is None, no retry logic is applied."""
+        total = 5
+        for _ in range(total):
+            self.create_user()
+
+        qs = User.objects.all()
+
+        attempt_count = {"count": 0}
+
+        def fail_once(self, slice_obj):
+            attempt_count["count"] += 1
+            raise OperationalError("canceling statement due to user request")
+
+        with patch.object(type(qs), "__getitem__", fail_once):
+            with pytest.raises(OperationalError, match="canceling statement due to user request"):
+                list(self.range_wrapper(qs, step=10, query_timeout_retries=None))
+
+        assert attempt_count["count"] == 1
+
 
 @no_silo_test
 class RangeQuerySetWrapperWithProgressBarTest(RangeQuerySetWrapperTest):