Skip to content

added healthcheck to the trino python client making it able to procce… #575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 146 additions & 0 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,3 +1447,149 @@ def delete_password(self, servicename, username):
return None

os.remove(file_path)


@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_success(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat keeps probing at the configured interval while the server answers 200.

    The heartbeat thread is driven directly through ``_start_heartbeat`` and
    ``_stop_heartbeat``; ``execute`` and ``fetch`` are never called, so no
    ``fetch`` stub is needed.
    """
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1
        return mock.Mock(status_code=200)

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.1)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    try:
        # Three intervals: the first probe fires immediately, the next ones
        # after each 0.1 s wait.
        time.sleep(0.3)
    finally:
        # Always stop the worker so it cannot outlive this test.
        query._stop_heartbeat()
    assert head_call_count >= 2

@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_failure_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat disables itself after exactly 3 consecutive failed probes."""
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1
        return mock.Mock(status_code=500)

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    try:
        # Three probes at 0.05 s spacing finish well inside this window.
        time.sleep(0.3)
        assert not query._heartbeat_enabled
        # The loop gives up on the third failure and must not probe again.
        assert head_call_count == 3
    finally:
        # Stop the worker even when an assertion above fails.
        query._stop_heartbeat()

@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_404_405_stops(mock_requests, sample_post_response_data, sample_get_response_data):
    """A 404 or 405 reply to the heartbeat probe switches the heartbeat off."""

    def head_returning(status):
        # Factory binds the status eagerly, giving each iteration its own stub.
        def head(url, timeout=10):
            return type("Resp", (), {"status_code": status})()

        return head

    for unsupported_status in (404, 405):
        mock_requests.head.side_effect = head_returning(unsupported_status)
        mock_requests.Response.return_value.json.return_value = sample_post_response_data
        mock_requests.get.return_value.json.return_value = sample_get_response_data
        mock_requests.post.return_value.json.return_value = sample_post_response_data

        session = ClientSession(user="test")
        request = TrinoRequest(
            host="coordinator",
            port=8080,
            client_session=session,
            http_scheme="http",
        )
        query = TrinoQuery(request=request, query="SELECT 1", heartbeat_interval=0.05)
        query._next_uri = "http://coordinator/v1/statement/next"
        query._row_mapper = mock.Mock(map=lambda x: [])

        query._start_heartbeat()
        time.sleep(0.2)
        assert not query._heartbeat_enabled
        query._stop_heartbeat()

@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_finish(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat thread exits on its own once the query is marked finished."""
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1
        return mock.Mock(status_code=200)

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    worker = query._heartbeat_thread
    try:
        time.sleep(0.1)
        assert head_call_count >= 1
        query._finished = True
        # Join the worker directly rather than calling _stop_heartbeat(): the
        # stop event stays unset, so only the finished flag can end the loop.
        worker.join(timeout=1)
        assert not worker.is_alive()
    finally:
        # Stop the worker even when an assertion above fails.
        query._stop_heartbeat()

@mock.patch("trino.client.TrinoRequest.http")
def test_trinoquery_heartbeat_stops_on_cancel(mock_requests, sample_post_response_data, sample_get_response_data):
    """Heartbeat thread exits on its own once the query is marked cancelled."""
    head_call_count = 0

    def fake_head(url, timeout=10):
        nonlocal head_call_count
        head_call_count += 1
        return mock.Mock(status_code=200)

    mock_requests.head.side_effect = fake_head
    mock_requests.Response.return_value.json.return_value = sample_post_response_data
    mock_requests.get.return_value.json.return_value = sample_get_response_data
    mock_requests.post.return_value.json.return_value = sample_post_response_data
    req = TrinoRequest(
        host="coordinator",
        port=8080,
        client_session=ClientSession(user="test"),
        http_scheme="http",
    )
    query = TrinoQuery(request=req, query="SELECT 1", heartbeat_interval=0.05)
    query._next_uri = "http://coordinator/v1/statement/next"
    query._row_mapper = mock.Mock(map=lambda x: [])
    query._start_heartbeat()
    worker = query._heartbeat_thread
    try:
        time.sleep(0.1)
        assert head_call_count >= 1
        query._cancelled = True
        # Join the worker directly rather than calling _stop_heartbeat(): the
        # stop event stays unset, so only the cancelled flag can end the loop.
        worker.join(timeout=1)
        assert not worker.is_alive()
    finally:
        # Stop the worker even when an assertion above fails.
        query._stop_heartbeat()
51 changes: 50 additions & 1 deletion trino/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,8 @@ def __init__(
request: TrinoRequest,
query: str,
legacy_primitive_types: bool = False,
fetch_mode: Literal["mapped", "segments"] = "mapped"
fetch_mode: Literal["mapped", "segments"] = "mapped",
heartbeat_interval: float = 60.0, # seconds
) -> None:
self._query_id: Optional[str] = None
self._stats: Dict[Any, Any] = {}
Expand All @@ -826,6 +827,11 @@ def __init__(
self._legacy_primitive_types = legacy_primitive_types
self._row_mapper: Optional[RowMapper] = None
self._fetch_mode = fetch_mode
self._heartbeat_interval = heartbeat_interval
self._heartbeat_thread = None
self._heartbeat_stop_event = threading.Event()
self._heartbeat_failures = 0
self._heartbeat_enabled = True

@property
def query_id(self) -> Optional[str]:
Expand Down Expand Up @@ -868,6 +874,39 @@ def result(self):
def info_uri(self):
return self._info_uri

def _start_heartbeat(self):
    """Start the background heartbeat thread unless one is already registered.

    Does nothing when a heartbeat thread is already set on this query. Also
    does nothing when the heartbeat interval is zero or negative: a
    non-positive timeout makes ``Event.wait`` return immediately, so the
    heartbeat loop would send requests back to back. ``None`` is treated as
    "heartbeat disabled" as well.
    """
    if self._heartbeat_thread is not None:
        return
    interval = self._heartbeat_interval
    if interval is None or interval <= 0:
        return
    # Clear a stop request left over from an earlier _stop_heartbeat() call.
    self._heartbeat_stop_event.clear()
    # Daemon thread: it must never keep the interpreter alive on exit.
    self._heartbeat_thread = threading.Thread(target=self._heartbeat_loop, daemon=True)
    self._heartbeat_thread.start()

def _stop_heartbeat(self):
    """Signal the heartbeat loop to stop and wait briefly for its thread.

    The stop event is always set. If a thread is registered it is joined for
    at most two seconds and the reference is then dropped.
    """
    self._heartbeat_stop_event.set()
    worker = self._heartbeat_thread
    if worker is None:
        return
    worker.join(timeout=2)
    self._heartbeat_thread = None

def _heartbeat_loop(self):
    """Thread target: probe the next URI with HEAD requests until told to stop.

    The loop ends when the stop event is set, the query is finished or
    cancelled, the heartbeat has been disabled, or there is no next URI. A
    404 or 405 reply disables the heartbeat immediately. Any other non-200
    reply, or an exception from the request, counts as a failure; a 200
    resets the counter, and the third failure in a row disables the
    heartbeat.
    """
    while not (
        self._heartbeat_stop_event.is_set()
        or self.finished
        or self.cancelled
        or not self._heartbeat_enabled
    ):
        # Read _next_uri once so the None check and the request use the same
        # value. NOTE(review): presumably the thread driving the query
        # reassigns it while this loop runs - confirm.
        next_uri = self._next_uri
        if next_uri is None:
            break
        try:
            response = self._request.http.head(next_uri, timeout=10)
            status = response.status_code
            if status in (404, 405):
                logger.debug(
                    "heartbeat disabled for query %s: server answered %s",
                    self.query_id,
                    status,
                )
                self._heartbeat_enabled = False
                break
            if status == 200:
                self._heartbeat_failures = 0
            else:
                self._heartbeat_failures += 1
        except Exception:
            # Best-effort background probe: count the error as a failure
            # instead of letting it end the thread, but leave a trace.
            self._heartbeat_failures += 1
            logger.debug("heartbeat request failed for query %s", self.query_id, exc_info=True)
        if self._heartbeat_failures >= 3:
            self._heartbeat_enabled = False
            break
        # Sleep for one interval, waking early if a stop is requested.
        self._heartbeat_stop_event.wait(self._heartbeat_interval)

def execute(self, additional_http_headers=None) -> TrinoResult:
"""Initiate a Trino query by sending the SQL statement

Expand Down Expand Up @@ -895,6 +934,9 @@ def execute(self, additional_http_headers=None) -> TrinoResult:
rows = self._row_mapper.map(status.rows) if self._row_mapper else status.rows
self._result = TrinoResult(self, rows)

# Start heartbeat thread
self._start_heartbeat()

# Execute should block until at least one row is received or query is finished or cancelled
while not self.finished and not self.cancelled and len(self._result.rows) == 0:
self._result.rows += self.fetch()
Expand All @@ -921,6 +963,7 @@ def fetch(self) -> List[Union[List[Any]], Any]:
self._update_state(status)
if status.next_uri is None:
self._finished = True
self._stop_heartbeat()

if not self._row_mapper:
return []
Expand Down Expand Up @@ -968,6 +1011,7 @@ def cancel(self) -> None:
if response.status_code == requests.codes.no_content:
self._cancelled = True
logger.debug("query cancelled: %s", self.query_id)
self._stop_heartbeat()
return

self._request.raise_response_error(response)
Expand All @@ -985,6 +1029,11 @@ def finished(self) -> bool:
def cancelled(self) -> bool:
return self._cancelled

@property
def is_running(self) -> bool:
    """Whether the query is still in flight, i.e. neither finished nor cancelled."""
    return not (self.finished or self.cancelled)


def _retry_with(handle_retry, handled_exceptions, conditions, max_attempts):
def wrapper(func):
Expand Down
Loading