56
56
GEN_AI_TOKEN_TYPE ,
57
57
GEN_AI_USAGE_INPUT_TOKENS ,
58
58
GEN_AI_USAGE_OUTPUT_TOKENS ,
59
+ GEN_AI_AGENT_ID ,
60
+ GEN_AI_AGENT_NAME ,
61
+ GEN_AI_TOOL_CALL_ID ,
62
+ GEN_AI_TOOL_NAME ,
63
+ GEN_AI_TOOL_TYPE ,
59
64
GenAiOperationNameValues ,
60
65
GenAiSystemValues ,
61
66
GenAiTokenTypeValues ,
62
67
)
68
+
63
69
from opentelemetry .semconv ._incubating .metrics .gen_ai_metrics import (
64
70
GEN_AI_CLIENT_OPERATION_DURATION ,
65
71
GEN_AI_CLIENT_TOKEN_USAGE ,
@@ -118,6 +124,7 @@ class _BedrockRuntimeExtension(_AwsSdkExtension):
118
124
"ConverseStream" ,
119
125
"InvokeModel" ,
120
126
"InvokeModelWithResponseStream" ,
127
+ "InvokeAgent" ,
121
128
}
122
129
_DONT_CLOSE_SPAN_ON_END_OPERATIONS = {
123
130
"ConverseStream" ,
@@ -147,6 +154,9 @@ def setup_metrics(self, meter: Meter, metrics: dict[str, Instrument]):
147
154
def _extract_metrics_attributes (self ) -> _AttributeMapT :
148
155
attributes = {GEN_AI_SYSTEM : GenAiSystemValues .AWS_BEDROCK .value }
149
156
157
+ if self ._call_context .operation == "InvokeAgent" :
158
+ attributes [GEN_AI_OPERATION_NAME ] = "invoke_agent"
159
+
150
160
model_id = self ._call_context .params .get (_MODEL_ID_KEY )
151
161
if not model_id :
152
162
return attributes
@@ -170,6 +180,19 @@ def extract_attributes(self, attributes: _AttributeMapT):
170
180
171
181
attributes [GEN_AI_SYSTEM ] = GenAiSystemValues .AWS_BEDROCK .value
172
182
183
+ # Handle InvokeAgent
184
+ if self ._call_context .operation == "InvokeAgent" :
185
+ attributes [GEN_AI_OPERATION_NAME ] = "invoke_agent"
186
+
187
+ # Set agent attributes
188
+ agent_id = self ._call_context .params .get ("agentId" )
189
+ agent_alias_id = self ._call_context .params .get ("agentAliasId" )
190
+
191
+ self ._set_if_not_none (attributes , GEN_AI_AGENT_ID , agent_id )
192
+ self ._set_if_not_none (attributes , GEN_AI_AGENT_NAME , agent_alias_id )
193
+ return
194
+
195
+ # Handle non-agent chat completions
173
196
model_id = self ._call_context .params .get (_MODEL_ID_KEY )
174
197
if model_id :
175
198
attributes [GEN_AI_REQUEST_MODEL ] = model_id
@@ -329,10 +352,14 @@ def before_service_call(
329
352
330
353
if span .is_recording ():
331
354
operation_name = span .attributes .get (GEN_AI_OPERATION_NAME , "" )
332
- request_model = span .attributes .get (GEN_AI_REQUEST_MODEL , "" )
333
- # avoid setting to an empty string if are not available
334
- if operation_name and request_model :
335
- span .update_name (f"{ operation_name } { request_model } " )
355
+ if self ._call_context .operation == "InvokeAgent" :
356
+ if operation_name :
357
+ span .update_name (f"{ operation_name } " )
358
+ else :
359
+ request_model = span .attributes .get (GEN_AI_REQUEST_MODEL , "" )
360
+ # avoid setting to an empty string if are not available
361
+ if operation_name and request_model :
362
+ span .update_name (f"{ operation_name } { request_model } " )
336
363
337
364
# this is used to calculate the operation duration metric, duration may be skewed by request_hook
338
365
# pylint: disable=attribute-defined-outside-init
@@ -472,6 +499,65 @@ def _on_stream_error_callback(
472
499
attributes = metrics_attributes ,
473
500
)
474
501
502
+ def _invoke_agent_on_success (
503
+ self ,
504
+ span : Span ,
505
+ result : dict ,
506
+ instrumentor_context : _BotocoreInstrumentorContext ,
507
+ ):
508
+ try :
509
+ if "completion" in result and isinstance (result ["completion" ], EventStream ):
510
+ event_stream = result ["completion" ]
511
+
512
+ # Drain the stream so we can instrument AND keep events
513
+ all_events = list (event_stream )
514
+
515
+ # A replay generator so user code can still iterate
516
+ result ["completion" ] = _replay_events (all_events )
517
+
518
+ for event in all_events :
519
+ if "returnControl" in event :
520
+ self ._handle_return_control (span , event )
521
+
522
+ # Record metrics
523
+ metrics = instrumentor_context .metrics
524
+ metrics_attributes = self ._extract_metrics_attributes ()
525
+ if operation_duration_histogram := metrics .get (GEN_AI_CLIENT_OPERATION_DURATION ):
526
+ duration = max ((default_timer () - self ._operation_start ), 0 )
527
+ operation_duration_histogram .record (
528
+ duration ,
529
+ attributes = metrics_attributes ,
530
+ )
531
+
532
+ except json .JSONDecodeError :
533
+ _logger .debug ("Error: Unable to parse the response body as JSON" )
534
+ except Exception as exc : # pylint: disable=broad-exception-caught
535
+ _logger .debug ("Error processing response: %s" , exc )
536
+
537
+ def _handle_return_control (self , span : Span , event : dict ):
538
+ return_control = event ["returnControl" ]
539
+ invocation_id = return_control .get ("invocationId" )
540
+ invocation_inputs = return_control .get ("invocationInputs" , [])
541
+
542
+ if span .is_recording () and invocation_id :
543
+ span .set_attribute (GEN_AI_TOOL_CALL_ID , invocation_id )
544
+
545
+ for input_item in invocation_inputs :
546
+ # Handle function invocation
547
+ if "functionInvocationInput" in input_item :
548
+ func_input = input_item ["functionInvocationInput" ]
549
+ action_group = func_input .get ("actionGroup" )
550
+ function = func_input .get ("function" )
551
+ span .set_attribute (GEN_AI_TOOL_NAME , action_group )
552
+ span .set_attribute (GEN_AI_TOOL_TYPE , "function" )
553
+
554
+ # Handle API invocation
555
+ elif "apiInvocationInput" in input_item :
556
+ api_input = input_item ["apiInvocationInput" ]
557
+ action_group = api_input .get ("actionGroup" )
558
+ span .set_attribute (GEN_AI_TOOL_NAME , action_group )
559
+ span .set_attribute (GEN_AI_TOOL_TYPE , "extension" )
560
+
475
561
def on_success (
476
562
self ,
477
563
span : Span ,
@@ -481,6 +567,12 @@ def on_success(
481
567
if self ._call_context .operation not in self ._HANDLED_OPERATIONS :
482
568
return
483
569
570
+ # Handle InvokeAgent
571
+ if self ._call_context .operation == "InvokeAgent" :
572
+ self ._invoke_agent_on_success (span , result , instrumentor_context )
573
+ return
574
+
575
+ # Handle non-agent chat completions
484
576
capture_content = genai_capture_message_content ()
485
577
486
578
if self ._call_context .operation == "ConverseStream" :
@@ -754,3 +846,10 @@ def on_error(
754
846
duration ,
755
847
attributes = metrics_attributes ,
756
848
)
849
+
850
+ def _replay_events (events ):
851
+ """
852
+ Helper so that user can still iterate EventStream
853
+ """
854
+ for e in events :
855
+ yield e
0 commit comments