diff --git a/src/typesense/configuration.py b/src/typesense/configuration.py index d59ac5e..7f5d044 100644 --- a/src/typesense/configuration.py +++ b/src/typesense/configuration.py @@ -80,6 +80,8 @@ class ConfigDict(typing.TypedDict): dictionaries or URLs that represent the read replica nodes. connection_timeout_seconds (float): The connection timeout in seconds. + + round_robin_hosts (bool): Whether or not to shuffle hosts between requests """ nodes: typing.List[typing.Union[str, NodeConfigDict]] @@ -96,6 +98,7 @@ class ConfigDict(typing.TypedDict): typing.List[typing.Union[str, NodeConfigDict]] ] # deprecated connection_timeout_seconds: typing.NotRequired[float] + round_robin_hosts: typing.NotRequired[bool] class Node: @@ -184,6 +187,7 @@ class Configuration: retry_interval_seconds (float): The interval in seconds between retries. healthcheck_interval_seconds (int): The interval in seconds between health checks. verify (bool): Whether to verify the SSL certificate. + round_robin_hosts (bool): Whether or not to shuffle hosts between requests """ def __init__( @@ -219,6 +223,7 @@ def __init__( 60, ) self.verify = config_dict.get("verify", True) + self.round_robin_hosts = config_dict.get("round_robin_hosts", False) self.additional_headers = config_dict.get("additional_headers", {}) def _handle_nearest_node( diff --git a/src/typesense/node_manager.py b/src/typesense/node_manager.py index e671c8d..f660aa4 100644 --- a/src/typesense/node_manager.py +++ b/src/typesense/node_manager.py @@ -30,8 +30,15 @@ """ import copy +import random +import sys import time +if sys.version_info >= (3, 11): + import typing +else: + import typing_extensions as typing + from typesense.configuration import Configuration, Node from typesense.logger import logger @@ -71,22 +78,19 @@ def get_node(self) -> Node: Returns: Node: The selected node for the next operation. """ - if self.config.nearest_node: - if self.config.nearest_node.healthy or self._is_due_for_health_check( - self.config.nearest_node, - ): - return self.config.nearest_node + if self._should_use_nearest_node(): + return self.config.nearest_node - node_index = 0 - while node_index < len(self.nodes): - node_index += 1 - node = self.nodes[self.node_index] - self.node_index = (self.node_index + 1) % len(self.nodes) - if node.healthy or self._is_due_for_health_check(node): - return node + healthy_nodes = self._get_healthy_nodes() - logger.debug("No healthy nodes were found. Returning the next node.") - return self.nodes[self.node_index] + if not healthy_nodes: + logger.debug("No healthy nodes were found. Returning the next node.") + return self.nodes[self.node_index] + + if self.config.round_robin_hosts: + return self._get_shuffled_node(healthy_nodes) + + return self._get_next_round_robin_node() def set_node_health(self, node: Node, is_healthy: bool) -> None: """ @@ -126,3 +130,62 @@ def _initialize_nodes(self) -> None: self.set_node_health(self.config.nearest_node, is_healthy=True) for node in self.nodes: self.set_node_health(node, is_healthy=True) + + def _should_use_nearest_node(self) -> bool: + """ + Check if we should use the nearest node. + + Returns: + bool: True if nearest node should be used, False otherwise. + """ + return bool( + self.config.nearest_node + and ( + self.config.nearest_node.healthy + or self._is_due_for_health_check(self.config.nearest_node) + ), + ) + + def _get_healthy_nodes(self) -> typing.List[Node]: + """ + Get a list of all healthy nodes. + + Returns: + List[Node]: List of healthy nodes. + """ + return [ + node + for node in self.nodes + if node.healthy or self._is_due_for_health_check(node) + ] + + def _get_shuffled_node(self, healthy_nodes: typing.List[Node]) -> Node: + """ + Get a randomly shuffled node from the list of healthy nodes. + + Args: + healthy_nodes (List[Node]): List of healthy nodes to choose from. + + Returns: + Node: A randomly selected healthy node. + """ + random.shuffle(healthy_nodes) + self.node_index = (self.node_index + 1) % len(self.nodes) + return healthy_nodes[0] + + def _get_next_round_robin_node(self) -> Node: + """ + Get the next node using standard round-robin selection. + + Returns: + Node: The next node in the round-robin sequence. + """ + node_index = 0 + while node_index < len(self.nodes): + node_index += 1 + node = self.nodes[self.node_index] + self.node_index = (self.node_index + 1) % len(self.nodes) + if node.healthy or self._is_due_for_health_check(node): + return node + + return self.nodes[self.node_index] diff --git a/tests/api_call_test.py b/tests/api_call_test.py index 1d5fa11..15c8f3b 100644 --- a/tests/api_call_test.py +++ b/tests/api_call_test.py @@ -80,6 +80,28 @@ def test_get_node_round_robin_selection( assert_match_object(node3, fake_api_call.config.nodes[2]) +def test_get_node_round_robin_shuffle( + fake_api_call: ApiCall, + mocker: MockerFixture, +) -> None: + """Test that it shuffles healthy nodes when round_robin_hosts is true.""" + fake_api_call.config.nearest_node = None + fake_api_call.config.round_robin_hosts = True + mocker.patch("time.time", return_value=100) + + shuffle_mock = mocker.patch("random.shuffle") + + for _ in range(3): + fake_api_call.node_manager.get_node() + + assert shuffle_mock.call_count == 3 + + for call in shuffle_mock.call_args_list: + args = call[0][0] + assert isinstance(args, list) + assert all(node.healthy for node in args) + + def test_get_exception() -> None: """Test that it correctly returns the exception class for a given status code.""" assert RequestHandler._get_exception(0) == exceptions.HTTPStatus0Error