diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index de53a86e9..f47688fab 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -254,6 +254,9 @@ def read(self) -> Optional[OAuthToken]: self.telemetry_enabled = ( self.client_telemetry_enabled and self.server_telemetry_enabled ) + self.telemetry_batch_size = kwargs.get( + "telemetry_batch_size", TelemetryClientFactory.DEFAULT_BATCH_SIZE + ) try: self.session = Session( @@ -290,6 +293,7 @@ def read(self) -> Optional[OAuthToken]: session_id_hex=self.get_session_id_hex(), auth_provider=self.session.auth_provider, host_url=self.session.host, + batch_size=self.telemetry_batch_size, ) self._telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/src/databricks/sql/telemetry/telemetry_client.py b/src/databricks/sql/telemetry/telemetry_client.py index 8462e7ffe..9960490c5 100644 --- a/src/databricks/sql/telemetry/telemetry_client.py +++ b/src/databricks/sql/telemetry/telemetry_client.py @@ -138,8 +138,6 @@ class TelemetryClient(BaseTelemetryClient): TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext" TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth" - DEFAULT_BATCH_SIZE = 100 - def __init__( self, telemetry_enabled, @@ -147,10 +145,11 @@ def __init__( auth_provider, host_url, executor, + batch_size, ): logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex) self._telemetry_enabled = telemetry_enabled - self._batch_size = self.DEFAULT_BATCH_SIZE + self._batch_size = batch_size self._session_id_hex = session_id_hex self._auth_provider = auth_provider self._user_agent = None @@ -318,7 +317,7 @@ def close(self): class TelemetryClientFactory: """ Static factory class for creating and managing telemetry clients. - It uses a thread pool to handle asynchronous operations. + It uses a thread pool to handle asynchronous operations and a single flush thread for all clients. """ _clients: Dict[ @@ -331,6 +330,13 @@ class TelemetryClientFactory: _original_excepthook = None _excepthook_installed = False + # Shared flush thread for all clients + _flush_thread = None + _flush_event = threading.Event() + _flush_interval_seconds = 90 + + DEFAULT_BATCH_SIZE = 100 + @classmethod def _initialize(cls): """Initialize the factory if not already initialized""" @@ -341,11 +347,39 @@ def _initialize(cls): max_workers=10 ) # Thread pool for async operations cls._install_exception_hook() + cls._start_flush_thread() cls._initialized = True logger.debug( "TelemetryClientFactory initialized with thread pool (max_workers=10)" ) + @classmethod + def _start_flush_thread(cls): + """Start the shared background thread for periodic flushing of all clients""" + cls._flush_event.clear() + cls._flush_thread = threading.Thread(target=cls._flush_worker, daemon=True) + cls._flush_thread.start() + + @classmethod + def _flush_worker(cls): + """Background worker thread for periodic flushing of all clients""" + while not cls._flush_event.wait(cls._flush_interval_seconds): + logger.debug("Performing periodic flush for all telemetry clients") + + with cls._lock: + clients_to_flush = list(cls._clients.values()) + + for client in clients_to_flush: + client._flush() + + @classmethod + def _stop_flush_thread(cls): + """Stop the shared background flush thread""" + if cls._flush_thread is not None: + cls._flush_event.set() + cls._flush_thread.join(timeout=1.0) + cls._flush_thread = None + @classmethod def _install_exception_hook(cls): """Install global exception handler for unhandled exceptions""" @@ -374,6 +408,7 @@ def initialize_telemetry_client( session_id_hex, auth_provider, host_url, + batch_size, ): """Initialize a telemetry client for a specific connection if telemetry is enabled""" try: @@ -395,6 +430,7 @@ def initialize_telemetry_client( auth_provider=auth_provider, host_url=host_url, executor=TelemetryClientFactory._executor, + batch_size=batch_size, ) else: TelemetryClientFactory._clients[ @@ -433,6 +469,7 @@ def close(session_id_hex): "No more telemetry clients, shutting down thread pool executor" ) try: + TelemetryClientFactory._stop_flush_thread() TelemetryClientFactory._executor.shutdown(wait=True) TelemetryHttpClient.close() except Exception as e: @@ -458,6 +495,7 @@ def connection_failure_log( session_id_hex=UNAUTH_DUMMY_SESSION_ID, auth_provider=None, host_url=host_url, + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE, ) telemetry_client = TelemetryClientFactory.get_telemetry_client( diff --git a/tests/unit/test_telemetry.py b/tests/unit/test_telemetry.py index 398387540..d0e28c18d 100644 --- a/tests/unit/test_telemetry.py +++ b/tests/unit/test_telemetry.py @@ -30,6 +30,7 @@ def mock_telemetry_client(): auth_provider=auth_provider, host_url="test-host.com", executor=executor, + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) @@ -214,6 +215,7 @@ def test_client_lifecycle_flow(self): session_id_hex=session_id_hex, auth_provider=auth_provider, host_url="test-host.com", + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -238,6 +240,7 @@ def test_disabled_telemetry_flow(self): session_id_hex=session_id_hex, auth_provider=None, host_url="test-host.com", + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) client = TelemetryClientFactory.get_telemetry_client(session_id_hex) @@ -257,6 +260,7 @@ def test_factory_error_handling(self): session_id_hex=session_id, auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) # Should fall back to NoopTelemetryClient @@ -275,6 +279,7 @@ def test_factory_shutdown_flow(self): session_id_hex=session, auth_provider=AccessTokenAuthProvider("token"), host_url="test-host.com", + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) # Factory should be initialized diff --git a/tests/unit/test_telemetry_retry.py b/tests/unit/test_telemetry_retry.py index 11055b558..b8e216ff4 100644 --- a/tests/unit/test_telemetry_retry.py +++ b/tests/unit/test_telemetry_retry.py @@ -47,6 +47,7 @@ def get_client(self, session_id, num_retries=3): session_id_hex=session_id, auth_provider=None, host_url="test.databricks.com", + batch_size=TelemetryClientFactory.DEFAULT_BATCH_SIZE ) client = TelemetryClientFactory.get_telemetry_client(session_id)