Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/retriable_kafka_client/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
"""Retriable Kafka client module"""

from .consumer import BaseConsumer
from .config import ConsumerConfig, ProducerConfig, ConsumeTopicConfig
from .config import ConsumerConfig, ProducerConfig, ConsumeTopicConfig, CommonConfig
from .orchestrate import consume_topics, ConsumerThread
from .producer import BaseProducer
from .health import HealthCheckClient

__all__ = (
"BaseConsumer",
"BaseProducer",
"CommonConfig",
"consume_topics",
"ConsumerConfig",
"ConsumerThread",
"ProducerConfig",
"ConsumeTopicConfig",
"HealthCheckClient",
)
6 changes: 3 additions & 3 deletions src/retriable_kafka_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass(kw_only=True)
class _CommonConfig:
class CommonConfig:
"""
Topic configuration common for consumers and producers.
Attributes:
Expand All @@ -22,7 +22,7 @@ class _CommonConfig:


@dataclass
class ProducerConfig(_CommonConfig):
class ProducerConfig(CommonConfig):
"""
Topic configuration common each producer, including backoff settings.
Attributes:
Expand Down Expand Up @@ -63,7 +63,7 @@ class ConsumeTopicConfig:


@dataclass
class ConsumerConfig(_CommonConfig):
class ConsumerConfig(CommonConfig):
"""
Topic configuration for each consumer.
Attributes:
Expand Down
8 changes: 2 additions & 6 deletions src/retriable_kafka_client/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from confluent_kafka import Consumer, Message, KafkaException, TopicPartition

from .health import perform_healthcheck_using_client
from .kafka_utils import message_to_partition
from .kafka_settings import KafkaOptions, DEFAULT_CONSUMER_SETTINGS
from .consumer_tracking import TrackingManager
Expand Down Expand Up @@ -276,12 +277,7 @@ def run(self) -> None:

def connection_healthcheck(self) -> bool:
"""Programmatically check if we are able to read from Kafka."""
try:
self._consumer.list_topics(timeout=5)
return True
except KafkaException as e:
LOGGER.debug("Error while connecting to Kafka %s", e)
return False
return perform_healthcheck_using_client(self._consumer)

def stop(self) -> None:
"""
Expand Down
55 changes: 55 additions & 0 deletions src/retriable_kafka_client/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
"""Module for detached Kafka cluster healthcheck"""

import logging
from typing import TypeVar

from confluent_kafka import KafkaException, Consumer, Producer

from retriable_kafka_client.config import CommonConfig
from retriable_kafka_client.kafka_settings import (
KafkaOptions,
DEFAULT_PRODUCER_SETTINGS,
)

LOGGER = logging.getLogger(__name__)

Client = TypeVar("Client", Producer, Consumer)


def perform_healthcheck_using_client(client: Client) -> bool:
"""
Programmatically check if we are able to read from Kafka
using the provided client.
"""
try:
client.list_topics(timeout=5)
return True
except KafkaException as e:
LOGGER.warning("Error while connecting to Kafka %s", e)
return False


class HealthCheckClient:
"""
Class for only performing health checks on Kafka cluster.
"""

# pylint: disable=too-few-public-methods

def __init__(self, config: CommonConfig):
self.config = config
config_dict = {
KafkaOptions.KAFKA_NODES: ",".join(config.kafka_hosts),
KafkaOptions.USERNAME: config.username,
KafkaOptions.PASSWORD: config.password,
**DEFAULT_PRODUCER_SETTINGS,
}
config_dict.update(config.additional_settings)
# We use producer so the group ID can stay unfilled
self._client = Producer(
config_dict,
)

def connection_healthcheck(self) -> bool:
"""Programmatically check if we are able to read from Kafka."""
return perform_healthcheck_using_client(self._client)
5 changes: 5 additions & 0 deletions src/retriable_kafka_client/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from confluent_kafka import Producer, KafkaException

from .health import perform_healthcheck_using_client
from .kafka_settings import KafkaOptions, DEFAULT_PRODUCER_SETTINGS
from .config import ProducerConfig

Expand Down Expand Up @@ -159,6 +160,10 @@ async def send(
self._producer.flush()
self.__handle_problems(problems)

def connection_healthcheck(self) -> bool:
"""Programmatically check if we are able to read from Kafka."""
return perform_healthcheck_using_client(self._producer)

def close(self) -> None:
"""
Finish sending all messages, block until complete.
Expand Down
32 changes: 32 additions & 0 deletions tests/integration/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from copy import copy
from typing import Any

import pytest

from retriable_kafka_client.kafka_settings import KafkaOptions
from retriable_kafka_client.config import CommonConfig
from retriable_kafka_client.health import HealthCheckClient


@pytest.fixture(scope="session")
def config_object(kafka_config: dict[str, Any]) -> CommonConfig:
return CommonConfig(
kafka_hosts=[kafka_config[KafkaOptions.KAFKA_NODES]],
username=kafka_config[KafkaOptions.USERNAME],
password=kafka_config[KafkaOptions.PASSWORD],
additional_settings={
KafkaOptions.AUTH_MECHANISM: "SCRAM-SHA-512",
KafkaOptions.SECURITY_PROTO: "SASL_PLAINTEXT",
},
)


def test_health_success(config_object: CommonConfig) -> None:
healthcheck_client = HealthCheckClient(config_object)
assert healthcheck_client.connection_healthcheck() is True


def test_health_failure(config_object: CommonConfig) -> None:
copied_config = copy(config_object)
copied_config.password = "invalid"
assert HealthCheckClient(copied_config).connection_healthcheck() is False
26 changes: 0 additions & 26 deletions tests/unit/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,32 +95,6 @@ def test_consumer__process_message_valid_json(
assert mock_future.add_done_callback.called


def test_consumer_connection_healthcheck_success(
base_consumer: BaseConsumer,
) -> None:
"""Test successful connection healthcheck."""
mock_consumer = base_consumer._consumer
mock_consumer.list_topics.return_value = None

result = base_consumer.connection_healthcheck()
assert result is True
mock_consumer.list_topics.assert_called_once_with(timeout=5)


def test_consumer_connection_healthcheck_failure(
base_consumer: BaseConsumer,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test connection healthcheck failure handling."""
caplog.set_level(logging.DEBUG)
mock_consumer = base_consumer._consumer
mock_consumer.list_topics.side_effect = KafkaException("Connection failed")

result = base_consumer.connection_healthcheck()
assert result is False
assert any("Error while connecting to Kafka" in msg for msg in caplog.messages)


def test_consumer__consumer_property_reuses_instance(
sample_config: ConsumerConfig,
executor: Executor,
Expand Down
39 changes: 39 additions & 0 deletions tests/unit/test_health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest.mock import MagicMock, patch

import pytest
from confluent_kafka import KafkaException

from retriable_kafka_client.config import CommonConfig
from retriable_kafka_client.health import (
perform_healthcheck_using_client,
HealthCheckClient,
)


def test_perform_healthcheck_using_client() -> None:
mock_client = MagicMock()

result = perform_healthcheck_using_client(mock_client)
assert result is True
mock_client.list_topics.assert_called_once_with(timeout=5)


def test_perform_healthcheck_using_client_error(
caplog: pytest.LogCaptureFixture,
) -> None:
mock_client = MagicMock()
mock_client.list_topics.side_effect = KafkaException

result = perform_healthcheck_using_client(mock_client)
assert result is False
mock_client.list_topics.assert_called_once_with(timeout=5)
assert any("Error while connecting to Kafka" in msg for msg in caplog.messages)


@patch("retriable_kafka_client.health.Producer")
def test_healthcheck_client(mock_consumer_class: MagicMock) -> None:
client = HealthCheckClient(
CommonConfig(kafka_hosts=["foo"], username="bar", password="spam")
)
client.connection_healthcheck()
mock_consumer_class.return_value.list_topics.assert_called_once_with(timeout=5)