Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sentry/db/deletion.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def iterator(
order_by=order_field,
override_unique_safety_check=True,
result_value_getter=lambda item: item[1],
query_timeout_retries=10,
)

for batch in itertools.batched(wrapper, chunk_size):
Expand Down
17 changes: 16 additions & 1 deletion src/sentry/utils/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
from django.db.models.query_utils import Q
from django.db.models.sql.constants import ROW_COUNT
from django.db.models.sql.subqueries import DeleteQuery
from django.db.utils import OperationalError

from sentry.db.models.base import Model
from sentry.services import eventstore
from sentry.utils.retries import ConditionalRetryPolicy

if TYPE_CHECKING:
from sentry.services.eventstore.models import Event
Expand Down Expand Up @@ -109,6 +111,8 @@ def __init__[M: Model](
callbacks: Sequence[Callable[[list[V]], None]] = (),
result_value_getter: Callable[[V], int] | None = None,
override_unique_safety_check: bool = False,
query_timeout_retries: int | None = None,
retry_delay_seconds: float = 0.5,
):
# Support for slicing
if queryset.query.low_mark == 0 and not (
Expand All @@ -132,6 +136,8 @@ def __init__[M: Model](
self.order_by = order_by
self.callbacks = callbacks
self.result_value_getter = result_value_getter
self.query_timeout_retries = query_timeout_retries
self.retry_delay_seconds = retry_delay_seconds

order_by_col = queryset.model._meta.get_field(order_by if order_by != "pk" else "id")
if not override_unique_safety_check and (
Expand Down Expand Up @@ -176,7 +182,16 @@ def __iter__(self) -> Iterator[V]:
else:
results_qs = queryset.filter(**{"%s__gte" % self.order_by: cur_value})

results = list(results_qs[0 : self.step])
if self.query_timeout_retries is not None:
retries = self.query_timeout_retries
retry_policy = ConditionalRetryPolicy(
test_function=lambda attempt, exc: attempt <= retries
and isinstance(exc, OperationalError),
delay_function=lambda i: self.retry_delay_seconds,
)
results = retry_policy(lambda: list(results_qs[0 : self.step]))
else:
results = list(results_qs[0 : self.step])

for cb in self.callbacks:
cb(results)
Expand Down
93 changes: 93 additions & 0 deletions tests/sentry/utils/test_query.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from unittest.mock import patch

import pytest
from django.db import connections
from django.db.utils import OperationalError

from sentry.db.models.query import in_iexact
from sentry.models.commit import Commit
Expand Down Expand Up @@ -78,6 +81,96 @@ def test_wrapper_over_values_list(self) -> None:
qs = User.objects.all().values_list("id")
assert list(qs) == list(self.range_wrapper(qs, result_value_getter=lambda r: r[0]))

def test_retry_on_operational_error_success_after_failures(self) -> None:
    """With query_timeout_retries=3, two failed fetches then a success yields all rows."""
    total = 5
    for _ in range(total):
        self.create_user()

    qs = User.objects.all()
    queryset_cls = type(qs)
    real_getitem = queryset_cls.__getitem__
    # Records the call number on which the first batch finally succeeded.
    batch_attempts: list[int] = []
    calls = 0

    def flaky_getitem(queryset, slice_obj):
        # Fail the first two slice attempts; delegate to the real
        # implementation from the third attempt onwards.
        nonlocal calls
        calls += 1
        if not batch_attempts:
            if calls <= 2:
                raise OperationalError("canceling statement due to user request")
            if calls == 3:
                batch_attempts.append(calls)
        return real_getitem(queryset, slice_obj)

    with patch.object(queryset_cls, "__getitem__", flaky_getitem):
        wrapper = self.range_wrapper(
            qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01
        )
        results = list(wrapper)

    assert len(results) == total
    assert batch_attempts[0] == 3

def test_retry_exhausted_raises_exception(self) -> None:
    """Once every retry has failed, the OperationalError reaches the caller."""
    user_count = 5
    for _ in range(user_count):
        self.create_user()

    qs = User.objects.all()
    timeout_message = "canceling statement due to user request"

    def always_fail(queryset, slice_obj):
        # Every slice attempt fails, so no retry can ever succeed.
        raise OperationalError(timeout_message)

    with (
        patch.object(type(qs), "__getitem__", always_fail),
        pytest.raises(OperationalError, match=timeout_message),
    ):
        wrapper = self.range_wrapper(
            qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01
        )
        list(wrapper)

def test_retry_does_not_catch_other_exceptions(self) -> None:
    """An exception other than OperationalError propagates after a single attempt."""
    user_count = 5
    for _ in range(user_count):
        self.create_user()

    qs = User.objects.all()
    attempts = 0

    def raise_value_error(queryset, slice_obj):
        # Count each slice attempt so the test can prove nothing was retried.
        nonlocal attempts
        attempts += 1
        raise ValueError("Some other error")

    with (
        patch.object(type(qs), "__getitem__", raise_value_error),
        pytest.raises(ValueError, match="Some other error"),
    ):
        wrapper = self.range_wrapper(
            qs, step=10, query_timeout_retries=3, retry_delay_seconds=0.01
        )
        list(wrapper)

    assert attempts == 1

def test_no_retry_when_query_timeout_retries_is_none(self) -> None:
    """With query_timeout_retries=None the first OperationalError is raised immediately."""
    user_count = 5
    for _ in range(user_count):
        self.create_user()

    qs = User.objects.all()
    timeout_message = "canceling statement due to user request"
    attempts = 0

    def failing_getitem(queryset, slice_obj):
        # Fails on every call; the counter shows how many attempts were made.
        nonlocal attempts
        attempts += 1
        raise OperationalError(timeout_message)

    with (
        patch.object(type(qs), "__getitem__", failing_getitem),
        pytest.raises(OperationalError, match=timeout_message),
    ):
        wrapper = self.range_wrapper(qs, step=10, query_timeout_retries=None)
        list(wrapper)

    assert attempts == 1


@no_silo_test
class RangeQuerySetWrapperWithProgressBarTest(RangeQuerySetWrapperTest):
Expand Down
Loading