Skip to content

feat(node): add round-robin host shuffling capability #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/typesense/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class ConfigDict(typing.TypedDict):
dictionaries or URLs that represent the read replica nodes.

connection_timeout_seconds (float): The connection timeout in seconds.

round_robin_hosts (bool): Whether or not to shuffle hosts between requests
"""

nodes: typing.List[typing.Union[str, NodeConfigDict]]
Expand All @@ -96,6 +98,7 @@ class ConfigDict(typing.TypedDict):
typing.List[typing.Union[str, NodeConfigDict]]
] # deprecated
connection_timeout_seconds: typing.NotRequired[float]
round_robin_hosts: typing.NotRequired[bool]


class Node:
Expand Down Expand Up @@ -184,6 +187,7 @@ class Configuration:
retry_interval_seconds (float): The interval in seconds between retries.
healthcheck_interval_seconds (int): The interval in seconds between health checks.
verify (bool): Whether to verify the SSL certificate.
round_robin_hosts (bool): Whether or not to shuffle hosts between requests
"""

def __init__(
Expand Down Expand Up @@ -219,6 +223,7 @@ def __init__(
60,
)
self.verify = config_dict.get("verify", True)
self.round_robin_hosts = config_dict.get("round_robin_hosts", False)
self.additional_headers = config_dict.get("additional_headers", {})

def _handle_nearest_node(
Expand Down
91 changes: 77 additions & 14 deletions src/typesense/node_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,15 @@
"""

import copy
import random
import sys
import time

if sys.version_info >= (3, 11):
import typing
else:
import typing_extensions as typing

from typesense.configuration import Configuration, Node
from typesense.logger import logger

Expand Down Expand Up @@ -71,22 +78,19 @@ def get_node(self) -> Node:
Returns:
Node: The selected node for the next operation.
"""
if self.config.nearest_node:
if self.config.nearest_node.healthy or self._is_due_for_health_check(
self.config.nearest_node,
):
return self.config.nearest_node
if self._should_use_nearest_node():
return self.config.nearest_node

node_index = 0
while node_index < len(self.nodes):
node_index += 1
node = self.nodes[self.node_index]
self.node_index = (self.node_index + 1) % len(self.nodes)
if node.healthy or self._is_due_for_health_check(node):
return node
healthy_nodes = self._get_healthy_nodes()

logger.debug("No healthy nodes were found. Returning the next node.")
return self.nodes[self.node_index]
if not healthy_nodes:
logger.debug("No healthy nodes were found. Returning the next node.")
return self.nodes[self.node_index]

if self.config.round_robin_hosts:
return self._get_shuffled_node(healthy_nodes)

return self._get_next_round_robin_node()

def set_node_health(self, node: Node, is_healthy: bool) -> None:
"""
Expand Down Expand Up @@ -126,3 +130,62 @@ def _initialize_nodes(self) -> None:
self.set_node_health(self.config.nearest_node, is_healthy=True)
for node in self.nodes:
self.set_node_health(node, is_healthy=True)

def _should_use_nearest_node(self) -> bool:
"""
Check if we should use the nearest node.

Returns:
bool: True if nearest node should be used, False otherwise.
"""
return bool(
self.config.nearest_node
and (
self.config.nearest_node.healthy
or self._is_due_for_health_check(self.config.nearest_node)
),
)

def _get_healthy_nodes(self) -> typing.List[Node]:
"""
Get a list of all healthy nodes.

Returns:
List[Node]: List of healthy nodes.
"""
return [
node
for node in self.nodes
if node.healthy or self._is_due_for_health_check(node)
]

def _get_shuffled_node(self, healthy_nodes: typing.List[Node]) -> Node:
"""
Get a randomly shuffled node from the list of healthy nodes.

Args:
healthy_nodes (List[Node]): List of healthy nodes to choose from.

Returns:
Node: A randomly selected healthy node.
"""
random.shuffle(healthy_nodes)
self.node_index = (self.node_index + 1) % len(self.nodes)
return healthy_nodes[0]

def _get_next_round_robin_node(self) -> Node:
"""
Get the next node using standard round-robin selection.

Returns:
Node: The next node in the round-robin sequence.
"""
node_index = 0
while node_index < len(self.nodes):
node_index += 1
node = self.nodes[self.node_index]
self.node_index = (self.node_index + 1) % len(self.nodes)
if node.healthy or self._is_due_for_health_check(node):
return node

return self.nodes[self.node_index]
22 changes: 22 additions & 0 deletions tests/api_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,28 @@ def test_get_node_round_robin_selection(
assert_match_object(node3, fake_api_call.config.nodes[2])


def test_get_node_round_robin_shuffle(
fake_api_call: ApiCall,
mocker: MockerFixture,
) -> None:
"""Test that it shuffles healthy nodes when round_robin_hosts is true."""
fake_api_call.config.nearest_node = None
fake_api_call.config.round_robin_hosts = True
mocker.patch("time.time", return_value=100)

shuffle_mock = mocker.patch("random.shuffle")

for _ in range(3):
fake_api_call.node_manager.get_node()

assert shuffle_mock.call_count == 3

for call in shuffle_mock.call_args_list:
args = call[0][0]
assert isinstance(args, list)
assert all(node.healthy for node in args)


def test_get_exception() -> None:
"""Test that it correctly returns the exception class for a given status code."""
assert RequestHandler._get_exception(0) == exceptions.HTTPStatus0Error
Expand Down