diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py index b85c2336b0..be739702bb 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/__init__.py @@ -91,7 +91,6 @@ def response_hook(span, instance, response): """ import typing from typing import Any, Collection - import redis from wrapt import wrap_function_wrapper @@ -106,6 +105,7 @@ def response_hook(span, instance, response): from opentelemetry.instrumentation.utils import unwrap from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Span +from opentelemetry.metrics import UpDownCounter, get_meter _DEFAULT_SERVICE = "redis" @@ -119,6 +119,7 @@ def response_hook(span, instance, response): ] _REDIS_ASYNCIO_VERSION = (4, 2, 0) + if redis.VERSION >= _REDIS_ASYNCIO_VERSION: import redis.asyncio @@ -137,6 +138,7 @@ def _set_connection_attributes(span, conn): def _instrument( tracer, + connections_usage: UpDownCounter, request_hook: _RequestHookT = None, response_hook: _ResponseHookT = None, ): @@ -146,18 +148,33 @@ def _traced_execute_command(func, instance, args, kwargs): name = args[0] else: name = instance.connection_pool.connection_kwargs.get("db", 0) - with tracer.start_as_current_span( - name, kind=trace.SpanKind.CLIENT - ) as span: - if span.is_recording(): - span.set_attribute(SpanAttributes.DB_STATEMENT, query) - _set_connection_attributes(span, instance) - span.set_attribute("db.redis.args_length", len(args)) - if callable(request_hook): - request_hook(span, instance, args, kwargs) - response = func(*args, **kwargs) - if callable(response_hook): - response_hook(span, instance, response) + + try: + with tracer.start_as_current_span( + name, kind=trace.SpanKind.CLIENT + ) as span: + if span.is_recording(): + span.set_attribute(SpanAttributes.DB_STATEMENT, query) + _set_connection_attributes(span, instance) + span.set_attribute("db.redis.args_length", len(args)) + if callable(request_hook): + request_hook(span, instance, args, kwargs) + response = func(*args, **kwargs) + connections_usage.add( + 1, + { + "db.client.connection.usage.state": "used", + "db.client.connection.usage.name": instance.connection_pool.pid, + }) + if callable(response_hook): + response_hook(span, instance, response) + finally: + connections_usage.add( + -1, + { + "db.client.connection.usage.state": "idle", + "db.client.connection.usage.name": instance.connection_pool.pid, + }) return response def _traced_execute_pipeline(func, instance, args, kwargs): @@ -199,13 +216,26 @@ def _traced_execute_pipeline(func, instance, args, kwargs): response_hook(span, instance, response) return response + def _traced_get_connection(func, connection_pool, command_name, *keys, **options): + response = func(command_name, *keys, **options) + connections_usage.add( + 1, + { + "db.client.connection.usage.state": "used", + "db.client.connection.usage.name": connection_pool.pid, + }) + return response + + pipeline_class = ( "BasePipeline" if redis.VERSION < (3, 0, 0) else "Pipeline" ) redis_class = "StrictRedis" if redis.VERSION < (3, 0, 0) else "Redis" wrap_function_wrapper( - "redis", f"{redis_class}.execute_command", _traced_execute_command + "redis", + f"{redis_class}.execute_command", + _traced_execute_command ) wrap_function_wrapper( "redis.client", @@ -228,6 +258,11 @@ def _traced_execute_pipeline(func, instance, args, kwargs): "ClusterPipeline.execute", _traced_execute_pipeline, ) + wrap_function_wrapper( + "redis", + "ConnectionPool.get_connection", + _traced_get_connection + ) if redis.VERSION >= _REDIS_ASYNCIO_VERSION: wrap_function_wrapper( "redis.asyncio", @@ -277,8 +312,20 @@ def _instrument(self, **kwargs): tracer = trace.get_tracer( __name__, __version__, tracer_provider=tracer_provider ) + meter_provider = kwargs.get("meter_provider") + meter = get_meter( + __name__, + __version__, + meter_provider + ) + connections_usage = meter.create_up_down_counter( + name="db.client.connection.usage", + description="The number of connections that are currently in state described", + unit="1", + ) _instrument( tracer, + connections_usage, request_hook=kwargs.get("request_hook"), response_hook=kwargs.get("response_hook"), ) diff --git a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/package.py b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/package.py index dd2efb37b0..39005f7ddf 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/package.py +++ b/instrumentation/opentelemetry-instrumentation-redis/src/opentelemetry/instrumentation/redis/package.py @@ -14,3 +14,4 @@ _instruments = ("redis >= 2.6",) +_supports_metrics = True diff --git a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py index 3d5479e731..42842d7ecd 100644 --- a/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py +++ b/instrumentation/opentelemetry-instrumentation-redis/tests/test_redis.py @@ -146,3 +146,43 @@ def request_hook(span, conn, args, kwargs): span = spans[0] self.assertEqual(span.attributes.get(custom_attribute_name), "GET") + + +class TestRedisIntegrationMetric(TestBase): + + def setUp(self): + super().setUp() + RedisInstrumentor().instrument(meter_provider=self.meter_provider) + + def tearDown(self): + super().tearDown() + RedisInstrumentor().uninstrument() + + @staticmethod + def redis_get(): + pool = redis.ConnectionPool(host='localhost', port=6379, db=0) + redis_client = redis.Redis(connection_pool=pool) + redis_client.get('foo') + return pool.pid + + def test_multiple_connections_metric_success_redis(self): + pid = self.redis_get() + expected_metric_names = { + "db.client.connection.usage", + } + expected_metric_attributes = { + "db.client.connection.usage.state": "used", + "db.client.connection.usage.name": pid, + } + for ( + resource_metrics + ) in self.memory_metrics_reader.get_metrics_data().resource_metrics: + for scope_metrics in resource_metrics.scope_metrics: + for metric in scope_metrics.metrics: + self.assertIn(metric.name, expected_metric_names) + for data_point in metric.data.data_points: + for attr in data_point.attributes: + self.assertIn( + attr, expected_metric_attributes + ) +