Skip to content

Telemetry server-side flag integration #646

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,6 @@ def read(self) -> Optional[OAuthToken]:
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
self._cursors = [] # type: List[Cursor]

self.server_telemetry_enabled = True
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = (
self.client_telemetry_enabled and self.server_telemetry_enabled
)

try:
self.session = Session(
server_hostname,
Expand Down Expand Up @@ -285,6 +279,10 @@ def read(self) -> Optional[OAuthToken]:
)
self.staging_allowed_local_path = kwargs.get("staging_allowed_local_path", None)

self.force_enable_telemetry = kwargs.get("force_enable_telemetry", False)
self.enable_telemetry = kwargs.get("enable_telemetry", False)
self.telemetry_enabled = TelemetryHelper.is_telemetry_enabled(self)

TelemetryClientFactory.initialize_telemetry_client(
telemetry_enabled=self.telemetry_enabled,
session_id_hex=self.get_session_id_hex(),
Expand Down
181 changes: 181 additions & 0 deletions src/databricks/sql/common/feature_flag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# TODO: Test this when server-side feature flag is available

import threading
import time
import requests
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional, List, Any, TYPE_CHECKING

if TYPE_CHECKING:
from databricks.sql.client import Connection
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is for static type checking and prevents circular imports



@dataclass
class FeatureFlagEntry:
"""Represents a single feature flag from the server response."""

name: str
value: str


@dataclass
class FeatureFlagsResponse:
"""Represents the full JSON response from the feature flag endpoint."""

flags: List[FeatureFlagEntry] = field(default_factory=list)
ttl_seconds: Optional[int] = None

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "FeatureFlagsResponse":
"""Factory method to create an instance from a dictionary (parsed JSON)."""
flags_data = data.get("flags", [])
flags_list = [FeatureFlagEntry(**flag) for flag in flags_data]
return cls(flags=flags_list, ttl_seconds=data.get("ttl_seconds"))


# --- Constants ---
FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT = (
"/api/2.0/connector-service/feature-flags/PYTHON/{}"
)
DEFAULT_TTL_SECONDS = 900 # 15 minutes
REFRESH_BEFORE_EXPIRY_SECONDS = 10 # Start proactive refresh 10s before expiry


class FeatureFlagsContext:
"""
Manages fetching and caching of server-side feature flags for a connection.

1. The very first check for any flag is a synchronous, BLOCKING operation.
2. Subsequent refreshes (triggered near TTL expiry) are done asynchronously
in the background, returning stale data until the refresh completes.
"""

def __init__(self, connection: "Connection", executor: ThreadPoolExecutor):
from databricks.sql import __version__

self._connection = connection
self._executor = executor # Used for ASYNCHRONOUS refreshes
self._lock = threading.RLock()

# Cache state: `None` indicates the cache has never been loaded.
self._flags: Optional[Dict[str, str]] = None
self._ttl_seconds: int = DEFAULT_TTL_SECONDS
self._last_refresh_time: float = 0

endpoint_suffix = FEATURE_FLAGS_ENDPOINT_SUFFIX_FORMAT.format(__version__)
self._feature_flag_endpoint = (
f"https://{self._connection.session.host}{endpoint_suffix}"
)

def _is_refresh_needed(self) -> bool:
"""Checks if the cache is due for a proactive background refresh."""
if self._flags is None:
return False # Not eligible for refresh until loaded once.

refresh_threshold = self._last_refresh_time + (
self._ttl_seconds - REFRESH_BEFORE_EXPIRY_SECONDS
)
return time.monotonic() > refresh_threshold

def is_feature_enabled(self, name: str, default_value: bool) -> bool:
"""
Checks if a feature is enabled.
- BLOCKS on the first call until flags are fetched.
- Returns cached values on subsequent calls, triggering non-blocking refreshes if needed.
"""
with self._lock:
# If cache has never been loaded, perform a synchronous, blocking fetch.
if self._flags is None:
self._refresh_flags()

# If a proactive background refresh is needed, start one. This is non-blocking.
elif self._is_refresh_needed():
# We don't check for an in-flight refresh; the executor queues the task, which is safe.
self._executor.submit(self._refresh_flags)

assert self._flags is not None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean we will just fail if we aren't able to fetch flags? shouldn't we do safe defaults?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is an exception in _refresh_flags(), _flags is set to an empty Dict. If the particular flag we are checking is set to None, we use the default value passed in the args.


# Now, return the value from the populated cache.
flag_value = self._flags.get(name)
if flag_value is None:
return default_value
return flag_value.lower() == "true"

def _refresh_flags(self):
"""Performs a synchronous network request to fetch and update flags."""
headers = {}
try:
# Authenticate the request
self._connection.session.auth_provider.add_headers(headers)
headers["User-Agent"] = self._connection.session.useragent_header

response = requests.get(
self._feature_flag_endpoint, headers=headers, timeout=30
)

if response.status_code == 200:
ff_response = FeatureFlagsResponse.from_dict(response.json())
self._update_cache_from_response(ff_response)
else:
# On failure, initialize with an empty dictionary to prevent re-blocking.
if self._flags is None:
self._flags = {}

except Exception as e:
# On exception, initialize with an empty dictionary to prevent re-blocking.
if self._flags is None:
self._flags = {}

def _update_cache_from_response(self, ff_response: FeatureFlagsResponse):
"""Atomically updates the internal cache state from a successful server response."""
with self._lock:
self._flags = {flag.name: flag.value for flag in ff_response.flags}
if ff_response.ttl_seconds is not None and ff_response.ttl_seconds > 0:
self._ttl_seconds = ff_response.ttl_seconds
self._last_refresh_time = time.monotonic()


class FeatureFlagsContextFactory:
"""
Manages a singleton instance of FeatureFlagsContext per connection session.
Also manages a shared ThreadPoolExecutor for all background refresh operations.
"""

_context_map: Dict[str, FeatureFlagsContext] = {}
_executor: Optional[ThreadPoolExecutor] = None
_lock = threading.Lock()

@classmethod
def _initialize(cls):
"""Initializes the shared executor for async refreshes if it doesn't exist."""
if cls._executor is None:
cls._executor = ThreadPoolExecutor(
max_workers=3, thread_name_prefix="feature-flag-refresher"
)

@classmethod
def get_instance(cls, connection: "Connection") -> FeatureFlagsContext:
"""Gets or creates a FeatureFlagsContext for the given connection."""
with cls._lock:
cls._initialize()
assert cls._executor is not None

# Use the unique session ID as the key
key = connection.get_session_id_hex()
if key not in cls._context_map:
cls._context_map[key] = FeatureFlagsContext(connection, cls._executor)
return cls._context_map[key]

@classmethod
def remove_instance(cls, connection: "Connection"):
"""Removes the context for a given connection and shuts down the executor if no clients remain."""
with cls._lock:
key = connection.get_session_id_hex()
if key in cls._context_map:
cls._context_map.pop(key, None)

# If this was the last active context, clean up the thread pool.
if not cls._context_map and cls._executor is not None:
cls._executor.shutdown(wait=False)
cls._executor = None
22 changes: 21 additions & 1 deletion src/databricks/sql/telemetry/telemetry_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import time
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Optional
from typing import Dict, Optional, TYPE_CHECKING
from databricks.sql.common.http import TelemetryHttpClient
from databricks.sql.telemetry.models.event import (
TelemetryEvent,
Expand Down Expand Up @@ -36,6 +36,10 @@
import uuid
import locale
from databricks.sql.telemetry.utils import BaseTelemetryClient
from databricks.sql.common.feature_flag import FeatureFlagsContextFactory

if TYPE_CHECKING:
from databricks.sql.client import Connection

logger = logging.getLogger(__name__)

Expand All @@ -44,6 +48,9 @@ class TelemetryHelper:
"""Helper class for getting telemetry related information."""

_DRIVER_SYSTEM_CONFIGURATION = None
TELEMETRY_FEATURE_FLAG_NAME = (
"databricks.partnerplatform.clientConfigsFeatureFlags.enableTelemetry"
)

@classmethod
def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
Expand Down Expand Up @@ -98,6 +105,19 @@ def get_auth_flow(auth_provider):
else:
return None

@staticmethod
def is_telemetry_enabled(connection: "Connection") -> bool:
if connection.force_enable_telemetry:
return True

if connection.enable_telemetry:
context = FeatureFlagsContextFactory.get_instance(connection)
return context.is_feature_enabled(
TelemetryHelper.TELEMETRY_FEATURE_FLAG_NAME, default_value=False
)
else:
return False


class NoopTelemetryClient(BaseTelemetryClient):
"""
Expand Down
Loading