
Commit 2dbe0bf

Update chatbot
1 parent 717bc76 commit 2dbe0bf

8 files changed, +125 -82 lines changed

services/chatbot/Dockerfile

Lines changed: 1 addition & 0 deletions
@@ -3,6 +3,7 @@ FROM python:3.11-slim
 RUN apt-get update && apt-get install -y \
     build-essential \
     cmake \
+    libmagic1 \
     && rm -rf /var/lib/apt/lists/*

 # Set the working directory in the container
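
The added libmagic1 package is the native library behind python-magic, which unstructured relies on to sniff document types at load time. A minimal sketch of that detection, assuming python-magic is present in the image (the sample path is illustrative, not from this commit):

# Sketch: file-type detection via libmagic, the reason libmagic1 is installed.
# Assumes the python-magic package; it fails at import time if the native
# libmagic library is missing from the image.
import magic

def detect_mime_type(path: str) -> str:
    # Returns a MIME string such as "text/plain" for markdown sources.
    return magic.from_file(path, mime=True)

if __name__ == "__main__":
    print(detect_mime_type("retrieval/README.md"))  # illustrative path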

services/chatbot/requirements.txt

Lines changed: 13 additions & 9 deletions
@@ -1,11 +1,15 @@
 chromadb==0.5.0
-Flask==3.0.3
-langchain==0.1.16
-langchain_community==0.0.35
-langchain_core==0.1.47
-langchain_openai==0.1.4
+Flask==3.1.0
+gunicorn==23.0.0
+langchain==0.3.25
+langchain-chroma==0.2.3
+langchain-community==0.3.23
+langchain-core==0.3.58
+langchain-mongodb==0.6.1
+langchain-openai==0.3.16
+langchain-text-splitters==0.3.8
+markdown==3.8
+pymongo==4.12.1
 python-dotenv==1.0.1
-unstructured==0.13.6
-gunicorn==22.0.0
-markdown==3.6
-langchain-mongodb==0.1.3
+unstructured==0.17.2
+numpy==1.26.4
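
These pins move the service from the LangChain 0.1 line to the 0.3 split packages, where the Chroma, Mongo, and text-splitter integrations ship as separate distributions. A quick sketch of the import paths that layout implies (illustrative only, not the service's actual module list):

# Sketch of the split-package import paths pinned above; the old monolithic
# paths (e.g. langchain_community.vectorstores.Chroma) are deprecated in 0.3.
from langchain_chroma import Chroma
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_mongodb import MongoDBChatMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter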

services/chatbot/src/chatbot_api.py

Lines changed: 71 additions & 48 deletions
@@ -1,27 +1,23 @@
 from flask import Flask
 from flask import request, jsonify
 import threading
-from langchain_openai import OpenAIEmbeddings
-from langchain.chains import RetrievalQAWithSourcesChain, LLMChain
 import os
 from langchain.memory import ConversationBufferWindowMemory
 from langchain_openai import ChatOpenAI
-from langchain.memory import ConversationBufferWindowMemory
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
+from langchain_core.output_parsers import StrOutputParser
 import logging
-from langchain_core.prompts.chat import (
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
-    AIMessagePromptTemplate,
-)
-from langchain_mongodb import MongoDBChatMessageHistory
 from db import MONGO_CONNECTION_URI, MONGO_DB_NAME
 from chatbot_utils import document_loader
+from pymongo import MongoClient
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+from langchain.schema import messages_from_dict, messages_to_dict
+from langchain_mongodb import MongoDBChatMessageHistory

 app = Flask(__name__)
 app.logger.setLevel(logging.DEBUG)

-app.logger.info("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)
+app.logger.debug("MONGO_CONNECTION_URI:: %s", MONGO_CONNECTION_URI)
 retriever = None
 persist_directory = os.environ.get("PERSIST_DIRECTORY")
 loaded_model_lock = threading.Lock()

@@ -43,7 +39,7 @@ def load_global_retriever():


 def get_llm():
-    llm = ChatOpenAI(temperature=0.6, model_name="gpt-3.5-turbo-0613")
+    llm = ChatOpenAI(temperature=0.6, model_name="gpt-4o")
     return llm


@@ -72,46 +68,67 @@ def get_chat_message_history(session):


 def get_qa_chain(llm, retriever, session):
+
     system_prompt_template = """
-    You are a helpful AI Assistant. Your name is CrapBot.
-    Respond to the user questions and do not withhold any information
-    Use the following pieces of retrieved context to help you answer the user's questions:
-    {summaries}
+    You are a helpful AI Assistant. Your name is CrapBot.
+    Respond to the user questions and do not withhold any information
+    Use the following pieces of retrieved context to help you answer the user's questions:
+    {context}
     """
-    human_prompt_template = "{question}"
-    chatbot_prompt_template = "CrapBot:"
-    messages = [
-        ("system", system_prompt_template),
-        MessagesPlaceholder(variable_name="chat_history", optional=False),
-        ("human", human_prompt_template),
-        # ("system", chatbot_prompt_template),
-    ]
-
-    PROMPT = ChatPromptTemplate.from_messages(
-        messages,
+
+    # Create the chat prompt template
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("system", system_prompt_template),
+            MessagesPlaceholder(variable_name="chat_history"),
+            ("human", "{question}"),
+        ]
+    )
+
+    # Get chat history
+    chat_memory = get_chat_message_history(session)
+    memory = ConversationBufferWindowMemory(
+        memory_key="chat_history",
+        input_key="question",
+        output_key="answer",
+        k=6,
+        ai_prefix="CrapBot",
+        chat_memory=chat_memory,
+        return_messages=True,
     )
-    chain_type_kwargs = {"prompt": PROMPT}
-    chat_message_history = get_chat_message_history(session)
-    qa = RetrievalQAWithSourcesChain.from_chain_type(
-        llm=llm,
-        chain_type="stuff",
-        retriever=retriever,
-        chain_type_kwargs=chain_type_kwargs,
-        memory=ConversationBufferWindowMemory(
-            memory_key="chat_history",
-            input_key="question",
-            output_key="answer",
-            k=6,
-            ai_prefix="CrapBot",
-            chat_memory=chat_message_history,
-            return_messages=True,
-        ),
+
+    # Create the retrieval chain
+    def get_context(query):
+        docs = retriever.get_relevant_documents(query)
+        return "\n".join(doc.page_content for doc in docs)
+
+    def get_chat_history(inputs):
+        return memory.load_memory_variables(inputs)["chat_history"]
+
+    retrieval_chain = (
+        RunnablePassthrough()
+        | {
+            "context": get_context,
+            "chat_history": get_chat_history,
+            "question": lambda x: x,
+        }
+        | prompt
+        | llm
+        | StrOutputParser()
     )
-    return qa
+
+    def chain_with_memory(inputs):
+        query = inputs["question"]
+        result = retrieval_chain.invoke(query)
+        # Update memory
+        memory.save_context({"question": query}, {"answer": result})
+        return {"answer": result}
+
+    return chain_with_memory


 def qa_answer(model, session, query):
-    result = model.invoke({"question": query})
+    result = model({"question": query})
     app.logger.debug("Session: %s, Result %s", session, result)
     return result["answer"]

@@ -135,11 +152,11 @@ def init_bot():
         app.logger.debug("Initializing bot %s", request.json["openai_api_key"])
         retriever_l = document_loader(openai_api_key, app.logger)
         session_model_map[session] = retriever_l
-        return jsonify({"message": "Model Initialized"}), 400
+        return jsonify({"message": "Model Initialized"}), 200

     except Exception as e:
-        app.logger.error("Error initializing bot ", e)
-        app.logger.debug("Error initializing bot ", e, exc_info=True)
+        app.logger.error("Error initializing bot %s", e)
+        app.logger.debug("Error initializing bot %s", e, exc_info=True)
         return jsonify({"message": "Not able to initialize model " + str(e)}), 500


@@ -179,6 +196,12 @@ def reset_chat_history_bot():
         return jsonify({"message": "Error deleting chat history"}), 500


+def augment_context(input_dict):
+    question = input_dict["question"]
+    context = input_dict["context"]
+    return {"question": question, "context": context}
+
+
 @app.route("/chatbot/genai/ask", methods=["POST"])
 def ask_bot():
     retriever_l = None

@@ -195,7 +218,7 @@ def ask_bot():
             jsonify(
                 {
                     "initialized": "false",
-                    "message": "Model not initialized for session %s",
+                    "message": "Model not initialized for session %s" % session,
                 }
             ),
             500,
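
The substantive change here swaps RetrievalQAWithSourcesChain for a hand-built runnable pipeline: retrieved context, windowed chat history, and the question feed a ChatPromptTemplate, the output is parsed to a string, and a wrapper function writes each turn back to memory. A condensed, self-contained sketch of that pattern, using in-process memory instead of the service's MongoDB-backed history (the helper names are illustrative):

# Condensed sketch of the pattern get_qa_chain() now follows. Assumes
# OPENAI_API_KEY is set and `retriever` is any LangChain retriever, such as
# the Chroma retriever returned by document_loader().
from langchain.memory import ConversationBufferWindowMemory
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI

def build_chain(retriever):
    llm = ChatOpenAI(temperature=0.6, model_name="gpt-4o")
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "Use this retrieved context to answer:\n{context}"),
            MessagesPlaceholder(variable_name="chat_history"),
            ("human", "{question}"),
        ]
    )
    memory = ConversationBufferWindowMemory(
        memory_key="chat_history",
        input_key="question",
        output_key="answer",
        k=6,
        return_messages=True,
    )
    chain = (
        RunnablePassthrough()
        | {
            # Each branch receives the raw question string.
            "context": lambda q: "\n".join(
                d.page_content for d in retriever.invoke(q)
            ),
            "chat_history": lambda _: memory.load_memory_variables({})["chat_history"],
            "question": lambda q: q,
        }
        | prompt
        | llm
        | StrOutputParser()
    )

    def ask(inputs):
        question = inputs["question"]
        answer = chain.invoke(question)
        # Persist the turn so the next call sees it in chat_history.
        memory.save_context({"question": question}, {"answer": answer})
        return {"answer": answer}

    return ask

Calling ask({"question": "..."}) mirrors qa_answer(), which after this commit treats the returned value as a plain callable rather than invoking .invoke() on a chain object.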

services/chatbot/src/chatbot_utils.py

Lines changed: 30 additions & 21 deletions
@@ -1,27 +1,27 @@
+from gc import collect
 import hashlib
-from flask import Flask
-from flask import request, jsonify
-import threading
+
+from langchain.memory import vectorstore
 from langchain_openai import OpenAIEmbeddings
 from langchain.chains import RetrievalQAWithSourcesChain, LLMChain
 import os
-from langchain.memory import ConversationBufferWindowMemory
-from langchain_community.vectorstores import Chroma
 from langchain_openai import OpenAI
 from langchain_community.document_loaders import DirectoryLoader
-from langchain.memory import ConversationBufferWindowMemory
-from langchain.text_splitter import CharacterTextSplitter
-from langchain_core.prompts import PromptTemplate, ChatPromptTemplate
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_core.vectorstores import InMemoryVectorStore
+from langchain_mongodb import MongoDBAtlasVectorSearch
 from langchain_community.document_loaders import UnstructuredMarkdownLoader
 import logging
-from langchain.schema import HumanMessage, SystemMessage, AIMessage
-from langchain_core.prompts.chat import (
-    SystemMessagePromptTemplate,
-    HumanMessagePromptTemplate,
+from langchain_community.vectorstores.azure_cosmos_db import (
+    AzureCosmosDBVectorSearch,
+    CosmosDBSimilarityType,
+    CosmosDBVectorSearchType,
 )
-import logging
-from langchain_community.chat_message_histories import MongoDBChatMessageHistory
 from db import MONGO_CONNECTION_URI, MONGO_DB_NAME
+from pymongo import MongoClient
+from langchain_chroma import Chroma
+from langchain_core.runnables import RunnableLambda, RunnablePassthrough
+

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

@@ -33,27 +33,36 @@ def get_embeddings(openai_api_key):
     return OpenAIEmbeddings(openai_api_key=openai_api_key)


+def get_vector_store(texts, embeddings, key_hash):
+    # initialize MongoDB python client
+    db_path = "./db%s" % key_hash
+    collection = "example_collection"
+    vector_store = Chroma(collection, embeddings, persist_directory=db_path)
+    vector_store.add_documents(texts)
+    return vector_store
+
+
 def document_loader(openai_api_key, logger_p=None):
     logger_l = logger_p or logger
     try:
         key_hash = hashlib.md5(openai_api_key.encode()).hexdigest()
-        load_dir = "retrieval"
+        load_dir = "./retrieval"
         logger_l.info("Loading documents from %s", load_dir)
         loader = DirectoryLoader(
             load_dir,
             exclude=["**/*.png", "**/images/**", "**/images/*", "**/*.pdf"],
             recursive=True,
-            loader_cls=UnstructuredMarkdownLoader,
+            show_progress=True,
         )
         documents = loader.load()
         logger_l.info("Loaded %s documents in db", len(documents))
-        text_splitter = CharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=1000, chunk_overlap=100
+        )
         texts = text_splitter.split_documents(documents)
         embeddings = get_embeddings(openai_api_key)
-        db_path = "./db%s" % key_hash
-        db = Chroma.from_documents(texts, embeddings, persist_directory=db_path)
-        db.persist()
-        retriever = db.as_retriever(search_kwargs={"k": TARGET_SOURCE_CHUNKS})
+        vector_store = get_vector_store(texts, embeddings, key_hash)
+        retriever = vector_store.as_retriever(search_kwargs={"k": TARGET_SOURCE_CHUNKS})
         logger_l.info("Retriever ready")
         return retriever
     except Exception as e:
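
On the indexing side, the deprecated Chroma.from_documents(...).persist() flow gives way to an explicit Chroma instance from langchain-chroma, which persists automatically once persist_directory is set. A minimal sketch of that flow, with illustrative directory, collection, and k values (assumes a valid OpenAI key for the embeddings):

# Minimal sketch of the indexing path document_loader() now uses.
# Directory, collection name, and k are illustrative defaults.
from langchain_chroma import Chroma
from langchain_community.document_loaders import DirectoryLoader
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

def build_retriever(openai_api_key, load_dir="./retrieval", k=4):
    documents = DirectoryLoader(load_dir, recursive=True, show_progress=True).load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    texts = splitter.split_documents(documents)
    vector_store = Chroma(
        "example_collection",
        OpenAIEmbeddings(openai_api_key=openai_api_key),
        persist_directory="./db-example",  # persisted on write; no .persist() needed
    )
    vector_store.add_documents(texts)
    return vector_store.as_retriever(search_kwargs={"k": k})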

services/chatbot/src/db.py

Lines changed: 7 additions & 1 deletion
@@ -6,9 +6,15 @@
 MONGO_PORT = os.environ.get("MONGO_PORT", "27017")
 MONGO_DB_NAME = os.environ.get("MONGO_DB_NAME", "crapi")

-MONGO_CONNECTION_URI = "mongodb://%s:%s@%s:%s" % (
+MONGO_CONNECTION_URI = "mongodb://%s:%s@%s:%s/?directConnection=true" % (
     MONGO_USER,
     MONGO_PASSWORD,
     MONGO_HOST,
     MONGO_PORT,
 )
+
+MONGO_CONNECTION_URI_ATLAS = "mongodb+srv://%s:%s@%s?retryWrites=true&w=majority" % (
+    MONGO_USER,
+    MONGO_PASSWORD,
+    MONGO_HOST,
+)
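
Appending directConnection=true makes the driver connect straight to the named mongod instead of attempting replica-set discovery, which avoids server-selection timeouts when the host in the URI is only reachable as a single container. A hedged sketch of verifying that connection string (standalone usage, not part of the service):

# Sketch: checking the standalone connection string built in db.py.
# directConnection=true skips replica-set topology discovery, so the client
# talks directly to the single host named in the URI.
from pymongo import MongoClient

from db import MONGO_CONNECTION_URI, MONGO_DB_NAME

client = MongoClient(MONGO_CONNECTION_URI, serverSelectionTimeoutMS=5000)
client.admin.command("ping")  # raises ServerSelectionTimeoutError if unreachable
print(client[MONGO_DB_NAME].list_collection_names())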

services/web/package.json

Lines changed: 1 addition & 1 deletion
@@ -35,6 +35,7 @@
     "source-map-loader": "^5.0.0",
     "styled-components": "^6.1.8",
     "superagent": "^8.1.2",
+    "@types/superagent": "^8.1.9",
     "ts-loader": "^9.5.1",
     "typescript": "^4.9.5",
     "web-vitals": "^2.1.4"

@@ -70,7 +71,6 @@
     "@babel/core": "^7.24.4",
     "@babel/plugin-proposal-private-property-in-object": "7.21.11",
     "@babel/preset-react": "^7.24.1",
-    "@types/superagent": "^8.1.9",
     "copy-webpack-plugin": "^6.3.2",
     "eslint-config-react-app": "^7.0.1",
     "prettier": "^3.3.3"

services/web/src/components/bot/Bot.tsx

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ const ChatBotComponent: React.FC<ChatBotComponentProps> = (props) => {
     let initRequired = false;
     // Wait for the response
     await superagent
-      .get(stateUrl)
+      .post(stateUrl)
       .set("Accept", "application/json")
       .set("Content-Type", "application/json")
       .then((res: any) => {

services/web/src/components/bot/MessageParser.tsx

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ class MessageParser {
     let initRequired = false;
     // Wait for the response
     await request
-      .get(stateUrl)
+      .post(stateUrl)
       .set("Accept", "application/json")
       .set("Content-Type", "application/json")
       .then((res) => {
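
Both frontend callers now send POST rather than GET to the chatbot state endpoint, presumably matching the methods the Flask route accepts. A rough sketch of exercising it from Python; the URL and body are assumptions about the deployed route, not taken from this diff:

# Sketch: hitting the chatbot state endpoint with POST, as the updated
# frontend does. The URL below is an assumption, not part of this commit.
import requests

STATE_URL = "http://localhost:8888/chatbot/genai/state"  # hypothetical route

response = requests.post(
    STATE_URL,
    json={},  # body contents are assumed; the frontend only sets JSON headers
    headers={"Accept": "application/json"},
    timeout=10,
)
print(response.status_code, response.json())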
