5656 GEN_AI_TOKEN_TYPE ,
5757 GEN_AI_USAGE_INPUT_TOKENS ,
5858 GEN_AI_USAGE_OUTPUT_TOKENS ,
59+ GEN_AI_AGENT_ID ,
60+ GEN_AI_AGENT_NAME ,
61+ GEN_AI_TOOL_CALL_ID ,
62+ GEN_AI_TOOL_NAME ,
63+ GEN_AI_TOOL_TYPE ,
5964 GenAiOperationNameValues ,
6065 GenAiSystemValues ,
6166 GenAiTokenTypeValues ,
6267)
68+
6369from opentelemetry .semconv ._incubating .metrics .gen_ai_metrics import (
6470 GEN_AI_CLIENT_OPERATION_DURATION ,
6571 GEN_AI_CLIENT_TOKEN_USAGE ,
@@ -118,6 +124,7 @@ class _BedrockRuntimeExtension(_AwsSdkExtension):
118124 "ConverseStream" ,
119125 "InvokeModel" ,
120126 "InvokeModelWithResponseStream" ,
127+ "InvokeAgent" ,
121128 }
122129 _DONT_CLOSE_SPAN_ON_END_OPERATIONS = {
123130 "ConverseStream" ,
@@ -147,6 +154,9 @@ def setup_metrics(self, meter: Meter, metrics: dict[str, Instrument]):
147154 def _extract_metrics_attributes (self ) -> _AttributeMapT :
148155 attributes = {GEN_AI_SYSTEM : GenAiSystemValues .AWS_BEDROCK .value }
149156
157+ if self ._call_context .operation == "InvokeAgent" :
158+ attributes [GEN_AI_OPERATION_NAME ] = "invoke_agent"
159+
150160 model_id = self ._call_context .params .get (_MODEL_ID_KEY )
151161 if not model_id :
152162 return attributes
@@ -170,6 +180,19 @@ def extract_attributes(self, attributes: _AttributeMapT):
170180
171181 attributes [GEN_AI_SYSTEM ] = GenAiSystemValues .AWS_BEDROCK .value
172182
183+ # Handle InvokeAgent
184+ if self ._call_context .operation == "InvokeAgent" :
185+ attributes [GEN_AI_OPERATION_NAME ] = "invoke_agent"
186+
187+ # Set agent attributes
188+ agent_id = self ._call_context .params .get ("agentId" )
189+ agent_alias_id = self ._call_context .params .get ("agentAliasId" )
190+
191+ self ._set_if_not_none (attributes , GEN_AI_AGENT_ID , agent_id )
192+ self ._set_if_not_none (attributes , GEN_AI_AGENT_NAME , agent_alias_id )
193+ return
194+
195+ # Handle non-agent chat completions
173196 model_id = self ._call_context .params .get (_MODEL_ID_KEY )
174197 if model_id :
175198 attributes [GEN_AI_REQUEST_MODEL ] = model_id
@@ -329,10 +352,14 @@ def before_service_call(
329352
330353 if span .is_recording ():
331354 operation_name = span .attributes .get (GEN_AI_OPERATION_NAME , "" )
332- request_model = span .attributes .get (GEN_AI_REQUEST_MODEL , "" )
333- # avoid setting to an empty string if are not available
334- if operation_name and request_model :
335- span .update_name (f"{ operation_name } { request_model } " )
355+ if self ._call_context .operation == "InvokeAgent" :
356+ if operation_name :
357+ span .update_name (f"{ operation_name } " )
358+ else :
359+ request_model = span .attributes .get (GEN_AI_REQUEST_MODEL , "" )
360+ # avoid setting to an empty string if are not available
361+ if operation_name and request_model :
362+ span .update_name (f"{ operation_name } { request_model } " )
336363
337364 # this is used to calculate the operation duration metric, duration may be skewed by request_hook
338365 # pylint: disable=attribute-defined-outside-init
@@ -472,6 +499,65 @@ def _on_stream_error_callback(
472499 attributes = metrics_attributes ,
473500 )
474501
502+ def _invoke_agent_on_success (
503+ self ,
504+ span : Span ,
505+ result : dict ,
506+ instrumentor_context : _BotocoreInstrumentorContext ,
507+ ):
508+ try :
509+ if "completion" in result and isinstance (result ["completion" ], EventStream ):
510+ event_stream = result ["completion" ]
511+
512+ # Drain the stream so we can instrument AND keep events
513+ all_events = list (event_stream )
514+
515+ # A replay generator so user code can still iterate
516+ result ["completion" ] = _replay_events (all_events )
517+
518+ for event in all_events :
519+ if "returnControl" in event :
520+ self ._handle_return_control (span , event )
521+
522+ # Record metrics
523+ metrics = instrumentor_context .metrics
524+ metrics_attributes = self ._extract_metrics_attributes ()
525+ if operation_duration_histogram := metrics .get (GEN_AI_CLIENT_OPERATION_DURATION ):
526+ duration = max ((default_timer () - self ._operation_start ), 0 )
527+ operation_duration_histogram .record (
528+ duration ,
529+ attributes = metrics_attributes ,
530+ )
531+
532+ except json .JSONDecodeError :
533+ _logger .debug ("Error: Unable to parse the response body as JSON" )
534+ except Exception as exc : # pylint: disable=broad-exception-caught
535+ _logger .debug ("Error processing response: %s" , exc )
536+
537+ def _handle_return_control (self , span : Span , event : dict ):
538+ return_control = event ["returnControl" ]
539+ invocation_id = return_control .get ("invocationId" )
540+ invocation_inputs = return_control .get ("invocationInputs" , [])
541+
542+ if span .is_recording () and invocation_id :
543+ span .set_attribute (GEN_AI_TOOL_CALL_ID , invocation_id )
544+
545+ for input_item in invocation_inputs :
546+ # Handle function invocation
547+ if "functionInvocationInput" in input_item :
548+ func_input = input_item ["functionInvocationInput" ]
549+ action_group = func_input .get ("actionGroup" )
550+ function = func_input .get ("function" )
551+ span .set_attribute (GEN_AI_TOOL_NAME , action_group )
552+ span .set_attribute (GEN_AI_TOOL_TYPE , "function" )
553+
554+ # Handle API invocation
555+ elif "apiInvocationInput" in input_item :
556+ api_input = input_item ["apiInvocationInput" ]
557+ action_group = api_input .get ("actionGroup" )
558+ span .set_attribute (GEN_AI_TOOL_NAME , action_group )
559+ span .set_attribute (GEN_AI_TOOL_TYPE , "extension" )
560+
475561 def on_success (
476562 self ,
477563 span : Span ,
@@ -481,6 +567,12 @@ def on_success(
481567 if self ._call_context .operation not in self ._HANDLED_OPERATIONS :
482568 return
483569
570+ # Handle InvokeAgent
571+ if self ._call_context .operation == "InvokeAgent" :
572+ self ._invoke_agent_on_success (span , result , instrumentor_context )
573+ return
574+
575+ # Handle non-agent chat completions
484576 capture_content = genai_capture_message_content ()
485577
486578 if self ._call_context .operation == "ConverseStream" :
@@ -754,3 +846,10 @@ def on_error(
754846 duration ,
755847 attributes = metrics_attributes ,
756848 )
849+
850+ def _replay_events (events ):
851+ """
852+ Helper so that user can still iterate EventStream
853+ """
854+ for e in events :
855+ yield e
0 commit comments