@@ -396,11 +396,21 @@ def _format_passthrough_output(self, result: Any, context: Dict[str, Any]) -> An
396396 return passthrough_output
397397
398398 def _format_chat_prompt_output (
399- self , result : Any , tool_calls : Optional [list ] = None
399+ self ,
400+ result : Any ,
401+ tool_calls : Optional [list ] = None ,
402+ metadata : Optional [dict ] = None ,
400403 ) -> AIMessage :
401404 """Format output for ChatPromptValue input."""
402405 content = self ._extract_content_from_result (result )
403- if tool_calls :
406+
407+ if metadata and isinstance (metadata , dict ):
408+ metadata_copy = metadata .copy ()
409+ metadata_copy .pop ("content" , None )
410+ if tool_calls :
411+ metadata_copy ["tool_calls" ] = tool_calls
412+ return AIMessage (content = content , ** metadata_copy )
413+ elif tool_calls :
404414 return AIMessage (content = content , tool_calls = tool_calls )
405415 return AIMessage (content = content )
406416
@@ -409,11 +419,21 @@ def _format_string_prompt_output(self, result: Any) -> str:
409419 return self ._extract_content_from_result (result )
410420
411421 def _format_message_output (
412- self , result : Any , tool_calls : Optional [list ] = None
422+ self ,
423+ result : Any ,
424+ tool_calls : Optional [list ] = None ,
425+ metadata : Optional [dict ] = None ,
413426 ) -> AIMessage :
414427 """Format output for BaseMessage input types."""
415428 content = self ._extract_content_from_result (result )
416- if tool_calls :
429+
430+ if metadata and isinstance (metadata , dict ):
431+ metadata_copy = metadata .copy ()
432+ metadata_copy .pop ("content" , None )
433+ if tool_calls :
434+ metadata_copy ["tool_calls" ] = tool_calls
435+ return AIMessage (content = content , ** metadata_copy )
436+ elif tool_calls :
417437 return AIMessage (content = content , tool_calls = tool_calls )
418438 return AIMessage (content = content )
419439
@@ -437,25 +457,50 @@ def _format_dict_output_for_dict_message_list(
437457 }
438458
439459 def _format_dict_output_for_base_message_list (
440- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
460+ self ,
461+ result : Any ,
462+ output_key : str ,
463+ tool_calls : Optional [list ] = None ,
464+ metadata : Optional [dict ] = None ,
441465 ) -> Dict [str , Any ]:
442466 """Format dict output when user input was a list of BaseMessage objects."""
443467 content = self ._extract_content_from_result (result )
444- if tool_calls :
468+
469+ if metadata and isinstance (metadata , dict ):
470+ metadata_copy = metadata .copy ()
471+ metadata_copy .pop ("content" , None )
472+ if tool_calls :
473+ metadata_copy ["tool_calls" ] = tool_calls
474+ return {output_key : AIMessage (content = content , ** metadata_copy )}
475+ elif tool_calls :
445476 return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
446477 return {output_key : AIMessage (content = content )}
447478
448479 def _format_dict_output_for_base_message (
449- self , result : Any , output_key : str , tool_calls : Optional [list ] = None
480+ self ,
481+ result : Any ,
482+ output_key : str ,
483+ tool_calls : Optional [list ] = None ,
484+ metadata : Optional [dict ] = None ,
450485 ) -> Dict [str , Any ]:
451486 """Format dict output when user input was a BaseMessage."""
452487 content = self ._extract_content_from_result (result )
453- if tool_calls :
488+
489+ if metadata :
490+ metadata_copy = metadata .copy ()
491+ if tool_calls :
492+ metadata_copy ["tool_calls" ] = tool_calls
493+ return {output_key : AIMessage (content = content , ** metadata_copy )}
494+ elif tool_calls :
454495 return {output_key : AIMessage (content = content , tool_calls = tool_calls )}
455496 return {output_key : AIMessage (content = content )}
456497
457498 def _format_dict_output (
458- self , input_dict : dict , result : Any , tool_calls : Optional [list ] = None
499+ self ,
500+ input_dict : dict ,
501+ result : Any ,
502+ tool_calls : Optional [list ] = None ,
503+ metadata : Optional [dict ] = None ,
459504 ) -> Dict [str , Any ]:
460505 """Format output for dictionary input."""
461506 output_key = self .passthrough_bot_output_key
@@ -474,13 +519,13 @@ def _format_dict_output(
474519 )
475520 elif all (isinstance (msg , BaseMessage ) for msg in user_input ):
476521 return self ._format_dict_output_for_base_message_list (
477- result , output_key , tool_calls
522+ result , output_key , tool_calls , metadata
478523 )
479524 else :
480525 return {output_key : result }
481526 elif isinstance (user_input , BaseMessage ):
482527 return self ._format_dict_output_for_base_message (
483- result , output_key , tool_calls
528+ result , output_key , tool_calls , metadata
484529 )
485530
486531 # Generic fallback for dictionaries
@@ -493,6 +538,7 @@ def _format_output(
493538 result : Any ,
494539 context : Dict [str , Any ],
495540 tool_calls : Optional [list ] = None ,
541+ metadata : Optional [dict ] = None ,
496542 ) -> Any :
497543 """Format the output based on the input type and rails result.
498544
@@ -515,17 +561,17 @@ def _format_output(
515561 return self ._format_passthrough_output (result , context )
516562
517563 if isinstance (input , ChatPromptValue ):
518- return self ._format_chat_prompt_output (result , tool_calls )
564+ return self ._format_chat_prompt_output (result , tool_calls , metadata )
519565 elif isinstance (input , StringPromptValue ):
520566 return self ._format_string_prompt_output (result )
521567 elif isinstance (input , (HumanMessage , AIMessage , BaseMessage )):
522- return self ._format_message_output (result , tool_calls )
568+ return self ._format_message_output (result , tool_calls , metadata )
523569 elif isinstance (input , list ) and all (
524570 isinstance (msg , BaseMessage ) for msg in input
525571 ):
526- return self ._format_message_output (result , tool_calls )
572+ return self ._format_message_output (result , tool_calls , metadata )
527573 elif isinstance (input , dict ):
528- return self ._format_dict_output (input , result , tool_calls )
574+ return self ._format_dict_output (input , result , tool_calls , metadata )
529575 elif isinstance (input , str ):
530576 return self ._format_string_prompt_output (result )
531577 else :
@@ -672,7 +718,9 @@ def _full_rails_invoke(
672718 result = result [0 ]
673719
674720 # Format and return the output based in input type
675- return self ._format_output (input , result , context , res .tool_calls )
721+ return self ._format_output (
722+ input , result , context , res .tool_calls , res .llm_metadata
723+ )
676724
677725 async def ainvoke (
678726 self ,
@@ -734,7 +782,9 @@ async def _full_rails_ainvoke(
734782 result = res .response
735783
736784 # Format and return the output based on input type
737- return self ._format_output (input , result , context , res .tool_calls )
785+ return self ._format_output (
786+ input , result , context , res .tool_calls , res .llm_metadata
787+ )
738788
739789 def stream (
740790 self ,
0 commit comments