@@ -6,9 +6,10 @@
 from openai_messages_token_helper import build_messages, get_token_limit
 
 from fastapi_app.api_models import Message, RAGContext, RetrievalResponse, ThoughtStep
+from fastapi_app.postgres_models import Item
 from fastapi_app.postgres_searcher import PostgresSearcher
 from fastapi_app.query_rewriter import build_search_function, extract_search_arguments
-from fastapi_app.rag_simple import RAGChatBase
+from fastapi_app.rag_simple import ChatParams, RAGChatBase
 
 
 class AdvancedRAGChat(RAGChatBase):
@@ -26,15 +27,10 @@ def __init__(
         self.chat_deployment = chat_deployment
         self.chat_token_limit = get_token_limit(chat_model, default_to_minimum=True)
 
-    async def run(
-        self,
-        messages: list[ChatCompletionMessageParam],
-        overrides: dict[str, Any] = {},
-    ) -> RetrievalResponse:
-        chat_params = self.get_params(messages, overrides)
-
-        # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
+    async def generate_search_query(
+        self, chat_params: ChatParams, query_response_token_limit: int
+    ) -> tuple[list[ChatCompletionMessageParam], Any | str | None, list]:
+        """Generate an optimized keyword search query based on the chat history and the last question"""
         query_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
             system_prompt=self.query_prompt_template,
@@ -57,6 +53,12 @@ async def run(
 
         query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
 
+        return query_messages, query_text, filters
+
+    async def retreive_and_build_context(
+        self, chat_params: ChatParams, query_text: str | Any | None, filters: list
+    ) -> tuple[list[ChatCompletionMessageParam], list[Item]]:
+        """Retrieve relevant items from the database and build a context for the chat model."""
         # Retrieve relevant items from the database with the GPT optimized query
         results = await self.searcher.search_and_embed(
             query_text,
@@ -70,22 +72,40 @@ async def run(
         content = "\n".join(sources_content)
 
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
         contextual_messages: list[ChatCompletionMessageParam] = build_messages(
             model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
+            system_prompt=chat_params.prompt_template,
             new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
             past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
+            max_tokens=self.chat_token_limit - chat_params.response_token_limit,
             fallback_to_default=True,
         )
+        return contextual_messages, results
+
+    async def run(
+        self,
+        messages: list[ChatCompletionMessageParam],
+        overrides: dict[str, Any] = {},
+    ) -> RetrievalResponse:
+        chat_params = self.get_params(messages, overrides)
+
+        # Generate an optimized keyword search query based on the chat history and the last question
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
+        )
+
+        # Retrieve relevant items from the database with the GPT optimized query
+        # Generate a contextual and content specific answer using the search results and chat history
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
+        )
 
         chat_completion_response: ChatCompletion = await self.openai_chat_client.chat.completions.create(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=False,
         )
@@ -141,50 +161,14 @@ async def run_stream(
         chat_params = self.get_params(messages, overrides)
 
         # Generate an optimized keyword search query based on the chat history and the last question
-        query_response_token_limit = 500
-        query_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=self.query_prompt_template,
-            new_user_content=chat_params.original_user_query,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - query_response_token_limit,  # TODO: count functions
-            fallback_to_default=True,
+        query_messages, query_text, filters = await self.generate_search_query(
+            chat_params=chat_params, query_response_token_limit=500
         )
 
-        chat_completion: ChatCompletion = await self.openai_chat_client.chat.completions.create(
-            messages=query_messages,
-            # Azure OpenAI takes the deployment name as the model name
-            model=self.chat_deployment if self.chat_deployment else self.chat_model,
-            temperature=0.0,  # Minimize creativity for search query generation
-            max_tokens=query_response_token_limit,  # Setting too low risks malformed JSON, too high risks performance
-            n=1,
-            tools=build_search_function(),
-            tool_choice="auto",
-        )
-
-        query_text, filters = extract_search_arguments(chat_params.original_user_query, chat_completion)
-
         # Retrieve relevant items from the database with the GPT optimized query
-        results = await self.searcher.search_and_embed(
-            query_text,
-            top=chat_params.top,
-            enable_vector_search=chat_params.enable_vector_search,
-            enable_text_search=chat_params.enable_text_search,
-            filters=filters,
-        )
-
-        sources_content = [f"[{(item.id)}]:{item.to_str_for_rag()}\n\n" for item in results]
-        content = "\n".join(sources_content)
-
         # Generate a contextual and content specific answer using the search results and chat history
-        response_token_limit = 1024
-        contextual_messages: list[ChatCompletionMessageParam] = build_messages(
-            model=self.chat_model,
-            system_prompt=overrides.get("prompt_template") or self.answer_prompt_template,
-            new_user_content=chat_params.original_user_query + "\n\nSources:\n" + content,
-            past_messages=chat_params.past_messages,
-            max_tokens=self.chat_token_limit - response_token_limit,
-            fallback_to_default=True,
+        contextual_messages, results = await self.retreive_and_build_context(
+            chat_params=chat_params, query_text=query_text, filters=filters
         )
 
         chat_completion_async_stream: AsyncStream[
@@ -193,8 +177,8 @@ async def run_stream(
             # Azure OpenAI takes the deployment name as the model name
             model=self.chat_deployment if self.chat_deployment else self.chat_model,
             messages=contextual_messages,
-            temperature=overrides.get("temperature", 0.3),
-            max_tokens=response_token_limit,
+            temperature=chat_params.temperature,
+            max_tokens=chat_params.response_token_limit,
             n=1,
             stream=True,
         )
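
For context on the `chat_params.*` reads introduced above: `ChatParams` is imported from `fastapi_app.rag_simple`, and this diff shows it carrying `prompt_template`, `temperature`, `response_token_limit`, `past_messages`, `top`, and the vector/text search flags. A minimal sketch of what such a container could look like, inferred only from the attribute accesses in this diff and the defaults the old inline code used (`0.3`, `1024`); the real definition lives in `rag_simple.py` and may differ:

```python
# Hypothetical reconstruction of ChatParams, inferred from this diff alone.
# The actual class is defined in fastapi_app/rag_simple.py and may differ.
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ChatParams:
    original_user_query: str
    past_messages: list[dict[str, Any]] = field(default_factory=list)
    prompt_template: str | None = None  # replaces overrides.get("prompt_template") or the default template
    temperature: float = 0.3            # replaces overrides.get("temperature", 0.3)
    response_token_limit: int = 1024    # replaces the local response_token_limit = 1024
    top: int = 3                        # assumed default; consumed by search_and_embed
    enable_vector_search: bool = True   # assumed default
    enable_text_search: bool = True     # assumed default
```

Centralizing override parsing in `get_params` means both entry points see identical defaults, which is what lets the `temperature`/`max_tokens` call sites above collapse into plain attribute reads.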
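The shape of the refactor itself, reduced to a runnable toy: both the blocking and streaming entry points await the same two helper coroutines, so the query-rewrite and retrieval logic exists exactly once. All names below are illustrative stand-ins, not the repository's actual implementation:

```python
import asyncio
from collections.abc import AsyncIterator


class ChatFlow:
    """Toy stand-in for AdvancedRAGChat: two shared helpers, two entry points."""

    async def generate_search_query(self, question: str) -> str:
        # Stand-in for the tool-calling query rewrite step.
        return question.lower()

    async def retrieve_and_build_context(self, query: str) -> list[str]:
        # Stand-in for search_and_embed plus prompt assembly.
        return [f"[1]: source matching '{query}'"]

    async def run(self, question: str) -> str:
        query = await self.generate_search_query(question)
        sources = await self.retrieve_and_build_context(query)
        return " ".join(sources)  # non-streaming answer

    async def run_stream(self, question: str) -> AsyncIterator[str]:
        query = await self.generate_search_query(question)
        sources = await self.retrieve_and_build_context(query)
        for chunk in sources:  # streaming answer
            yield chunk


async def main() -> None:
    flow = ChatFlow()
    print(await flow.run("Find waterproof tents"))
    async for chunk in flow.run_stream("Find waterproof tents"):
        print(chunk)


asyncio.run(main())
```

The payoff shows in the `run_stream` hunk above: roughly fifty duplicated lines reduce to fourteen, and any future change to query rewriting or retrieval lands in one place for both paths.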