Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.13"]
python-version: ["3.9", "3.12", "3.13"]

steps:
- uses: actions/checkout@v4
Expand Down
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ authors = [
{name = "Marek Szymutko", email = "mszymutk@redhat.com"},
]
readme = "README.md"
requires-python = ">=3.13"
requires-python = ">=3.9"
dependencies = [
"confluent-kafka>=2.12.2",
"pylint>=4.0.4",
"confluent-kafka==2.13.0",
]

[build-system]
Expand All @@ -26,12 +25,13 @@ pythonpath = ["src"]

[dependency-groups]
dev = [
"coverage>=7.12.0",
"coverage>=7.10.7",
"mypy>=1.18.2",
"pytest>=9.0.1",
"pylint>=3.3.9",
"pytest>=8.4.2",
"pytest-asyncio>=0.24.0",
"pytest-cov>=7.0.0",
"ruff>=0.14.5",
"tox>=4.32.0",
"tox>=4.30.3",
"types-confluent-kafka>=1.4.0",
]
37 changes: 28 additions & 9 deletions src/retriable_kafka_client/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,15 @@ def __perform_commits(self) -> None:
"""
committable = self.__tracking_manager.pop_committable()
if committable:
self._consumer.commit(offsets=committable, asynchronous=False)
try:
self._consumer.commit(offsets=committable, asynchronous=False)
except KafkaException:
LOGGER.exception(
"Temporarily failed to commit messages to partitions %s. "
"This action will be retried.",
committable,
)
self.__tracking_manager.reschedule_uncommittable(committable)

def __on_revoke(self, _: Consumer, partitions: list[TopicPartition]) -> None:
"""
Expand Down Expand Up @@ -153,6 +161,23 @@ def __ack_message(self, message: Message, finished_future: Future) -> None:
finally:
self.__tracking_manager.schedule_commit(message)

def __graceful_shutdown(self) -> None:
    """
    Finish future execution, perform commits and
    stop the Kafka consumer. This must be called from
    the same thread that polls messages.
    """
    # NOTE(review): register_revoke presumably marks tracked partitions as
    # revoked so their pending offsets become committable — confirm in
    # TrackingManager.
    self.__tracking_manager.register_revoke()
    try:
        # Flush whatever is committable before tearing the consumer down;
        # even if this raises, the close below must still run.
        self.__perform_commits()
    finally:
        LOGGER.debug("Shutting down Kafka consumer...")
        try:
            self._consumer.close()
        except (RuntimeError, KafkaException):
            # close() raises when the consumer is already closed; that is
            # fine during shutdown, so only log at debug level.
            LOGGER.debug("Consumer already closed.")
        LOGGER.info("Kafka consumer has been shut down gracefully.")

### Retry methods ###

def __process_retried_messages_from_schedule(self) -> None:
Expand Down Expand Up @@ -245,8 +270,9 @@ def run(self) -> None:
self.__perform_commits()
except BrokenProcessPool:
LOGGER.exception("Process pool got broken, stopping consumer.")
self.stop()
self.__graceful_shutdown()
sys.exit(1)
self.__graceful_shutdown()

def connection_healthcheck(self) -> bool:
"""Programmatically check if we are able to read from Kafka."""
Expand All @@ -263,10 +289,3 @@ def stop(self) -> None:
"""
LOGGER.debug("Stopping retriable consumer...")
self.__stop_flag = True
self.__tracking_manager.register_revoke()
self.__perform_commits()
try:
LOGGER.debug("Shutting down Kafka consumer...")
self._consumer.close()
except (RuntimeError, KafkaException): # pragma: no cover
LOGGER.debug("Consumer already closed.")
32 changes: 24 additions & 8 deletions src/retriable_kafka_client/consumer_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ class TrackingManager:
of this cache is to hold information about offsets that cannot be committed
yet and also about offsets that can be committed already.

It is enough to commit only the offset of the last committable messages,
It is enough to commit only the offset of the last committable message
+1 (Kafka tracks the next offset to consume, not the current one),
the cluster cannot hold more information than the latest offset for each
partition and consumer group.

Expand Down Expand Up @@ -117,7 +118,7 @@ def pop_committable(self) -> list[TopicPartition]:
TopicPartition(
topic=partition_info.topic,
partition=partition_info.partition,
offset=max_to_commit + 1,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These changes look weird, let me explain:

  • Kafka expects committed offsets to be one greater (+1) than the offset of the latest confirmed message
  • In the past, the tracker class had the exact values of offsets from messages
  • Now that I added the commit retry, there is a chance we may need to return commits back to the tracker class. I think it's cleaner to add the +1 when the offset is first stored in the tracker object; otherwise the retry (re-registration) of commits would need to subtract 1 again.

This changes the behavior of the tracker class slightly, but it was never meant to be used directly, it's just called from the BaseConsumer where it's a private attribute.

offset=max_to_commit,
)
)
self.__to_commit[partition_info] = set()
Expand All @@ -137,7 +138,7 @@ def pop_committable(self) -> list[TopicPartition]:
TopicPartition(
topic=partition_info.topic,
partition=partition_info.partition,
offset=max_to_commit + 1,
offset=max_to_commit,
)
)
# Clean up committed
Expand All @@ -146,6 +147,20 @@ def pop_committable(self) -> list[TopicPartition]:
self._cleanup()
return to_commit

def reschedule_uncommittable(
    self, failed_committable: list[TopicPartition]
) -> None:
    """
    Re-register offsets whose commit attempt failed.

    The offsets are put back into the to-commit cache exactly as they
    were popped (already +1-adjusted), so a later commit pass can retry
    them without any further offset arithmetic.

    Args:
        failed_committable: partitions/offsets that failed to be committed
    """
    for failed_partition in failed_committable:
        key = _PartitionInfo(
            topic=failed_partition.topic, partition=failed_partition.partition
        )
        bucket = self.__to_commit.setdefault(key, set())
        bucket.add(failed_partition.offset)

def process_message(self, message: Message, future: Future[Any]) -> None:
"""
Mark message as pending for processing.
Expand All @@ -159,9 +174,9 @@ def process_message(self, message: Message, future: Future[Any]) -> None:
message_offset: int = message.offset() # type: ignore[assignment]
with self.__access_lock:
# Mark the message as being processed
self.__to_process[_PartitionInfo.from_message(message)][message_offset] = (
future
)
self.__to_process[_PartitionInfo.from_message(message)][
message_offset + 1
] = future

def schedule_commit(self, message: Message) -> bool:
"""
Expand All @@ -175,9 +190,10 @@ def schedule_commit(self, message: Message) -> bool:
self.__semaphore.release()
partition_info = _PartitionInfo.from_message(message)
message_offset: int = message.offset() # type: ignore[assignment]
stored_offset = message_offset + 1
with self.__access_lock:
self.__to_process[partition_info].pop(message_offset, None)
self.__to_commit.setdefault(partition_info, set()).add(message_offset)
self.__to_process[partition_info].pop(stored_offset, None)
self.__to_commit.setdefault(partition_info, set()).add(stored_offset)
self._cleanup()
return True

Expand Down
3 changes: 0 additions & 3 deletions tests/integration/test_rebalance.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Integration tests for Kafka consumer rebalancing during message processing"""

import asyncio
import logging
from typing import Any

import pytest
Expand All @@ -16,13 +15,11 @@
async def test_rebalance_mid_processing_exactly_once(
kafka_config: dict[str, Any],
admin_client: AdminClient,
caplog: pytest.LogCaptureFixture,
) -> None:
"""
Test that when a rebalance occurs mid-processing (by adding a new consumer),
all messages are processed at least once with no message loss.
"""
caplog.set_level(logging.WARNING)
total_messages = 50

config = ScaffoldConfig(
Expand Down
60 changes: 40 additions & 20 deletions tests/unit/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,28 +140,25 @@ def test_consumer__consumer_property_reuses_instance(
assert mock_consumer_class.call_count == 1


def test_consumer_stop(
def test_consumer__graceful_shutdown(
base_consumer: BaseConsumer,
) -> None:
"""Test that stop method sets flag, drains cache, and attempts to close consumer."""
mock_consumer = base_consumer._consumer
mock_consumer.close = MagicMock()
mock_consumer.commit = MagicMock()
# Mock the executor.map to avoid actually calling it
base_consumer._executor.map = MagicMock()
base_consumer._executor = MagicMock()

# Pre-fill the tracking manager
partition_info = PartitionInfo("test-topic", 0)
tracking_manager = base_consumer._BaseConsumer__tracking_manager
tracking_manager._TrackingManager__to_commit[partition_info].update({100, 101, 102})

assert base_consumer._BaseConsumer__stop_flag is False
# Verify cache has data
assert len(tracking_manager._TrackingManager__to_commit[partition_info]) == 3

base_consumer.stop()
base_consumer._BaseConsumer__graceful_shutdown()

assert base_consumer._BaseConsumer__stop_flag is True
# Verify cache was drained - to_commit should be empty after commits
assert (
len(tracking_manager._TrackingManager__to_commit.get(partition_info, set()))
Expand Down Expand Up @@ -264,38 +261,44 @@ def poll_side_effect(*_, **__):
"Consumer error: I stubbed my toe when fetching messages" in caplog.messages
)

@patch("retriable_kafka_client.consumer.TrackingManager", MagicMock())
def test_consumer__graceful_shutdown_closed(
    base_consumer: BaseConsumer, caplog: pytest.LogCaptureFixture
) -> None:
    """Closing an already-closed consumer is logged at DEBUG, never raised."""
    caplog.set_level("DEBUG")
    consumer_mock = base_consumer._consumer
    consumer_mock.close = MagicMock(side_effect=KafkaException())
    # Commits are irrelevant here — stub them out so only close() is exercised.
    with patch.object(BaseConsumer, "_BaseConsumer__perform_commits"):
        base_consumer._BaseConsumer__graceful_shutdown()
    assert "Consumer already closed." in caplog.messages

def test_consumer_run_handles_broken_process_pool(
base_consumer: BaseConsumer,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test that BrokenProcessPool exception is handled and consumer stops."""
caplog.set_level(logging.ERROR)

mock_consumer = base_consumer._consumer
mock_consumer.poll.side_effect = BrokenProcessPool("Pool broken")
mock_consumer.subscribe = MagicMock()

with patch("sys.exit") as mock_exit:
base_consumer.run()
mock_exit.assert_called_once_with(1)
assert base_consumer._BaseConsumer__stop_flag is True
assert any("Process pool got broken" in msg for msg in caplog.messages)
with patch.object(
BaseConsumer, "_BaseConsumer__process_retried_messages_from_schedule"
) as mock_reprocess:
mock_reprocess.side_effect = BrokenProcessPool()
with pytest.raises(SystemExit):
base_consumer.run()


@pytest.mark.parametrize(
"to_process_offsets,to_commit_offsets,should_commit,expected_offset",
[
pytest.param(set(), {100, 101}, True, 102, id="no_processing_can_commit"),
pytest.param(set(), {100, 101}, True, 101, id="no_processing_can_commit"),
pytest.param(
{50}, {100, 101}, False, None, id="older_processing_blocks_newer_commits"
),
pytest.param(
{100}, {100, 101}, False, None, id="processing_same_offset_blocks_commit"
),
pytest.param({70}, {50, 60}, True, 61, id="partial_commit_below_processing"),
pytest.param({70}, {50, 60}, True, 60, id="partial_commit_below_processing"),
pytest.param(
{100}, {50, 60}, True, 61, id="commit_older_while_processing_newer"
{100}, {50, 60}, True, 60, id="commit_older_while_processing_newer"
),
pytest.param(set(), set(), False, None, id="empty_queues_nothing_to_commit"),
],
Expand Down Expand Up @@ -334,10 +337,21 @@ def test_perform_commits_logic(
mock_consumer.commit.assert_not_called()


def test_perform_commits_failed(base_consumer: BaseConsumer) -> None:
    """A failed commit must leave the offsets queued for a later retry."""
    failing_consumer = base_consumer._consumer
    failing_consumer.commit = MagicMock(side_effect=KafkaException())

    info = PartitionInfo("test-topic", 0)
    tracker = base_consumer._BaseConsumer__tracking_manager
    tracker._TrackingManager__to_commit[info] = {1}

    base_consumer._BaseConsumer__perform_commits()

    # The offset was rescheduled, not dropped, after the KafkaException.
    assert tracker._TrackingManager__to_commit[info] == {1}


@pytest.mark.parametrize(
"to_commit_offsets,expected_offset",
[
pytest.param({100, 101, 102}, 103, id="cache_filled_commits_offsets"),
pytest.param({100, 101, 102}, 102, id="cache_filled_commits_offsets"),
pytest.param(set(), None, id="cache_empty_no_commit"),
],
)
Expand Down Expand Up @@ -551,3 +565,9 @@ def test_consumer_validation_invalid_configs(
AssertionError, match="Cannot consume twice from the same topic"
):
BaseConsumer(config=config, executor=executor, max_concurrency=2)


def test_consumer_stop(base_consumer: BaseConsumer) -> None:
    """stop() only raises the stop flag; shutdown happens in the run loop."""
    flag_attr = "_BaseConsumer__stop_flag"
    assert getattr(base_consumer, flag_attr) is False
    base_consumer.stop()
    assert getattr(base_consumer, flag_attr) is True
20 changes: 10 additions & 10 deletions tests/unit/test_consumer_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_partition_info_round_trip() -> None:
),
pytest.param(
[{"partition": 0, "to_commit": {100, 101, 102}, "to_process": set()}],
[(0, 103)],
[(0, 102)],
{0: set()},
id="no_pending_to_process_commit_all_and_cleanup",
),
Expand All @@ -62,7 +62,7 @@ def test_partition_info_round_trip() -> None:
),
pytest.param(
[{"partition": 0, "to_commit": {50, 60, 100, 101}, "to_process": {70}}],
[(0, 61)],
[(0, 60)],
{0: {100, 101}},
id="some_offsets_lt_min_pending_partial_commit",
),
Expand All @@ -73,7 +73,7 @@ def test_partition_info_round_trip() -> None:
{"partition": 2, "to_commit": {30, 31, 40, 41}, "to_process": {35}},
{"partition": 3, "to_commit": {50, 51}, "to_process": {45}},
],
[(0, 13), (2, 32)],
[(0, 12), (2, 31)],
{0: set(), 1: set(), 2: {40, 41}, 3: {50, 51}},
id="multiple_partitions_different_states",
),
Expand All @@ -85,7 +85,7 @@ def test_partition_info_round_trip() -> None:
"to_process": {25, 35, 45},
}
],
[(0, 21)],
[(0, 20)],
{0: {30, 40}},
id="multiple_pending_to_process_use_minimum",
),
Expand Down Expand Up @@ -158,8 +158,8 @@ def test_offset_cache_schedule_commit_success(

# Verify it's in to_process
partition_info = PartitionInfo("test-topic", 0)
assert 42 in cache._TrackingManager__to_process[partition_info]
assert 42 not in cache._TrackingManager__to_commit.get(partition_info, set())
assert 43 in cache._TrackingManager__to_process[partition_info]
assert 43 not in cache._TrackingManager__to_commit.get(partition_info, set())

# Now schedule it for commit
result = cache.schedule_commit(mock_message)
Expand All @@ -168,8 +168,8 @@ def test_offset_cache_schedule_commit_success(
assert result is True

# Verify it moved from to_process to to_commit
assert 42 not in cache._TrackingManager__to_process[partition_info]
assert 42 in cache._TrackingManager__to_commit[partition_info]
assert 43 not in cache._TrackingManager__to_process[partition_info]
assert 43 in cache._TrackingManager__to_commit[partition_info]

# No warning should be logged
assert len(caplog.records) == 0
Expand Down Expand Up @@ -200,7 +200,7 @@ def test_offset_cache_schedule_commit_without_prior_processing(

# Verify offset was added to to_commit (new behavior)
partition_info = PartitionInfo("test-topic", 0)
assert 42 in cache._TrackingManager__to_commit.get(partition_info, set())
assert 43 in cache._TrackingManager__to_commit.get(partition_info, set())

# No warning is logged in the new implementation
assert len(caplog.records) == 0
Expand Down Expand Up @@ -420,7 +420,7 @@ def test_offset_cache_schedule_commit_offset_not_in_partition(
assert len(caplog.records) == 0

# The offset IS added to to_commit (new behavior)
assert 42 in cache._TrackingManager__to_commit.get(partition_info, set())
assert 43 in cache._TrackingManager__to_commit.get(partition_info, set())

# Verify original offsets remain in to_process (42 wasn't there to remove)
assert set(cache._TrackingManager__to_process[partition_info].keys()) == {
Expand Down
Loading