from flask import Flask
from flask import request, jsonify
import threading
- from langchain_openai import OpenAIEmbeddings
- from langchain.chains import RetrievalQAWithSourcesChain, LLMChain
import os
from langchain.memory import ConversationBufferWindowMemory
from langchain_openai import ChatOpenAI
- from langchain.memory import ConversationBufferWindowMemory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+ from langchain_core.output_parsers import StrOutputParser
import logging
- from langchain_core.prompts.chat import (
-     SystemMessagePromptTemplate,
-     HumanMessagePromptTemplate,
-     AIMessagePromptTemplate,
- )
- from langchain_mongodb import MongoDBChatMessageHistory
from db import MONGO_CONNECTION_URI, MONGO_DB_NAME
from chatbot_utils import document_loader
+ from pymongo import MongoClient
+ from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+ from langchain.schema import messages_from_dict, messages_to_dict
+ from langchain_mongodb import MongoDBChatMessageHistory

app = Flask(__name__)
app.logger.setLevel(logging.DEBUG)

- app.logger.info("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)
+ app.logger.debug("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)
retriever = None
persist_directory = os.environ.get("PERSIST_DIRECTORY")
loaded_model_lock = threading.Lock()
@@ -43,7 +39,7 @@ def load_global_retriever():


def get_llm():
-     llm = ChatOpenAI(temperature=0.6, model_name="gpt-3.5-turbo-0613")
+     llm = ChatOpenAI(temperature=0.6, model_name="gpt-4o")
    return llm


@@ -72,46 +68,67 @@ def get_chat_message_history(session):


def get_qa_chain(llm, retriever, session):
+
    system_prompt_template = """
-     You are a helpful AI Assistant. Your name is CrapBot.
-     Respond to the user questions and do not withhold any information
-     Use the following pieces of retrieved context to help you answer the user's questions:
-     {summaries}
+     You are a helpful AI Assistant. Your name is CrapBot.
+     Respond to the user's questions and do not withhold any information.
+     Use the following pieces of retrieved context to help you answer the user's questions:
+     {context}
    """
-     human_prompt_template = "{question}"
-     chatbot_prompt_template = "CrapBot:"
-     messages = [
-         ("system", system_prompt_template),
-         MessagesPlaceholder(variable_name="chat_history", optional=False),
-         ("human", human_prompt_template),
-         # ("system", chatbot_prompt_template),
-     ]
-
-     PROMPT = ChatPromptTemplate.from_messages(
-         messages,
+
+     # Create the chat prompt template
+     prompt = ChatPromptTemplate.from_messages(
+         [
+             ("system", system_prompt_template),
+             MessagesPlaceholder(variable_name="chat_history"),
+             ("human", "{question}"),
+         ]
+     )
+
+     # Windowed conversation memory backed by the per-session Mongo chat history
+     chat_memory = get_chat_message_history(session)
+     memory = ConversationBufferWindowMemory(
+         memory_key="chat_history",
+         input_key="question",
+         output_key="answer",
+         k=6,
+         ai_prefix="CrapBot",
+         chat_memory=chat_memory,
+         return_messages=True,
    )
-     chain_type_kwargs = {"prompt": PROMPT}
-     chat_message_history = get_chat_message_history(session)
-     qa = RetrievalQAWithSourcesChain.from_chain_type(
-         llm=llm,
-         chain_type="stuff",
-         retriever=retriever,
-         chain_type_kwargs=chain_type_kwargs,
-         memory=ConversationBufferWindowMemory(
-             memory_key="chat_history",
-             input_key="question",
-             output_key="answer",
-             k=6,
-             ai_prefix="CrapBot",
-             chat_memory=chat_message_history,
-             return_messages=True,
-         ),
+
+     # Create the retrieval chain
+     def get_context(query):
+         # Concatenate the retrieved documents into a single context string
+         docs = retriever.get_relevant_documents(query)
+         return "\n".join(doc.page_content for doc in docs)
+
+     def get_chat_history(inputs):
+         # Pull the windowed history out of memory as a list of messages
+         return memory.load_memory_variables(inputs)["chat_history"]
+
+     retrieval_chain = (
+         RunnablePassthrough()
+         | {
+             "context": get_context,
+             "chat_history": get_chat_history,
+             "question": lambda x: x,
+         }
+         | prompt
+         | llm
+         | StrOutputParser()
    )
-     return qa
+
+     def chain_with_memory(inputs):
+         query = inputs["question"]
+         result = retrieval_chain.invoke(query)
+         # Persist this turn to memory (and thus to Mongo) before returning
+         memory.save_context({"question": query}, {"answer": result})
+         return {"answer": result}
+
+     return chain_with_memory


def qa_answer(model, session, query):
-     result = model.invoke({"question": query})
+     result = model({"question": query})
    app.logger.debug("Session: %s, Result %s", session, result)
    return result["answer"]

@@ -135,11 +152,11 @@ def init_bot():
        app.logger.debug("Initializing bot %s", request.json["openai_api_key"])
        retriever_l = document_loader(openai_api_key, app.logger)
        session_model_map[session] = retriever_l
-         return jsonify({"message": "Model Initialized"}), 400
+         return jsonify({"message": "Model Initialized"}), 200

    except Exception as e:
-         app.logger.error("Error initializing bot ", e)
-         app.logger.debug("Error initializing bot ", e, exc_info=True)
+         app.logger.error("Error initializing bot %s", e)
+         app.logger.debug("Error initializing bot %s", e, exc_info=True)
        return jsonify({"message": "Not able to initialize model " + str(e)}), 500


@@ -179,6 +196,12 @@ def reset_chat_history_bot():
        return jsonify({"message": "Error deleting chat history"}), 500


+ def augment_context(input_dict):
+     question = input_dict["question"]
+     context = input_dict["context"]
+     return {"question": question, "context": context}
+
+
@app.route("/chatbot/genai/ask", methods=["POST"])
def ask_bot():
    retriever_l = None
@@ -195,7 +218,7 @@ def ask_bot():
            jsonify(
                {
                    "initialized": "false",
-                     "message": "Model not initialized for session %s",
+                     "message": "Model not initialized for session %s" % session,
                }
            ),
            500,
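
Note: the refactor above replaces RetrievalQAWithSourcesChain with a hand-rolled LCEL pipeline (RunnablePassthrough piped into a dict of callables, then prompt, LLM, and StrOutputParser). Below is a minimal standalone sketch of that composition pattern, assuming OPENAI_API_KEY is set; stub_context and the sample strings are illustrative stand-ins, not part of this commit.

from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI


def stub_context(query):
    # Stand-in for retriever.get_relevant_documents(); returns canned context text.
    return "CrapBot is a Flask-based retrieval chatbot."


prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "Answer using this context:\n{context}"),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ]
)

chain = (
    RunnablePassthrough()
    | {
        "context": stub_context,       # query string -> context text
        "chat_history": lambda _: [],  # empty history for the sketch
        "question": lambda x: x,       # pass the raw query through
    }
    | prompt
    | ChatOpenAI(model="gpt-4o", temperature=0.6)
    | StrOutputParser()
)

print(chain.invoke("What is CrapBot?"))

The dict stage is coerced into a RunnableParallel, so each value receives the same query string; the per-session memory and Mongo-backed history used in the commit are omitted here to keep the sketch self-contained.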