@@ -224,15 +224,20 @@ def search_tool(self, query: str, top_k: int = 5) -> str:
         return formatted_result
 
     def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3,
-            model: str = os.getenv("DEFAULT_MODEL")) -> str:
+            model: str = os.getenv("DEFAULT_MODEL")) -> Dict[str, Any]:
         """
         Ask a question using the advanced Memory Alpha RAG system with tool use.
+        Returns a dictionary with answer and token usage information.
         """
 
         if not model:
             raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.")
 
         logger.info(f"Starting tool-enabled RAG for query: {query}")
+
+        # Initialize token tracking
+        total_input_tokens = 0
+        total_output_tokens = 0
 
         # Define the search tool
         search_tool_definition = {
@@ -317,6 +322,20 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
                 logger.info(f"LLM response type: {type(response_message)}")
                 logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...")
 
+                # Estimate tokens based on content length
+                # Rough estimation: ~4 characters per token for English text
+                content = response_message.get('content', '')
+                estimated_output_tokens = len(content) // 4
+                total_output_tokens += estimated_output_tokens
+
+                # Estimate input tokens from the accumulated message content
+                input_text = ' '.join([msg.get('content', '') for msg in messages])
+                estimated_input_tokens = len(input_text) // 4
+                # Assign rather than add: messages already includes prior turns, so this avoids double counting
+                total_input_tokens = estimated_input_tokens
+
+                logger.info(f"Estimated tokens - Input: {estimated_input_tokens}, Output: {estimated_output_tokens}")
+
                 # Check if the model wants to use a tool
                 tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls')
                 if tool_calls:
@@ -377,15 +396,37 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
 
                 self._update_history(query, final_response)
                 logger.info("Returning final answer")
-                return final_response
+
+                return {
+                    "answer": final_response,
+                    "token_usage": {
+                        "input_tokens": total_input_tokens,
+                        "output_tokens": total_output_tokens,
+                        "total_tokens": total_input_tokens + total_output_tokens
+                    }
+                }
 
             except Exception as e:
                 logger.error(f"Chat failed: {e}")
-                return f"Error processing query: {str(e)}"
+                return {
+                    "answer": f"Error processing query: {str(e)}",
+                    "token_usage": {
+                        "input_tokens": total_input_tokens,
+                        "output_tokens": total_output_tokens,
+                        "total_tokens": total_input_tokens + total_output_tokens
+                    }
+                }
 
         # Fallback if max iterations reached
         logger.warning(f"Max iterations reached for query: {query}")
-        return "Query processing exceeded maximum iterations. Please try a simpler question."
+        return {
+            "answer": "Query processing exceeded maximum iterations. Please try a simpler question.",
+            "token_usage": {
+                "input_tokens": total_input_tokens,
+                "output_tokens": total_output_tokens,
+                "total_tokens": total_input_tokens + total_output_tokens
+            }
+        }
 
     def _update_history(self, question: str, answer: str):
         """Update conversation history."""
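
Since `ask()` now returns a dictionary rather than a bare string, existing callers need a small adjustment. Below is a minimal sketch of the new calling convention; the `rag` instance name and the query are hypothetical stand-ins:

result = rag.ask("Who commanded the USS Defiant?")
print(result["answer"])

usage = result["token_usage"]
print(f"tokens: input={usage['input_tokens']}, output={usage['output_tokens']}, total={usage['total_tokens']}")

Because the error and max-iterations paths return the same shape as the success path, callers can read `token_usage` unconditionally instead of type-checking the result.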
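One detail worth noting in the input-token accounting: `messages` accumulates across loop iterations, so re-estimating the whole list and assigning the result (rather than adding it) yields the current prompt size without double counting. A small worked example of the ~4 characters/token heuristic, using throwaway strings:

messages = [{"role": "user", "content": "x" * 400}]  # roughly 100 tokens
total_input_tokens = len(" ".join(m.get("content", "") for m in messages)) // 4
print(total_input_tokens)  # 100

messages.append({"role": "tool", "content": "y" * 400})  # roughly 100 more
# Assign, don't add: the new estimate already covers the first message.
total_input_tokens = len(" ".join(m.get("content", "") for m in messages)) // 4
print(total_input_tokens)  # 200 (adding would have given about 300)

The heuristic is only approximate; actual token counts vary by model and tokenizer.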