"""Chat with retrieval and embeddings."""
import os
import tempfile
from config import set_environment
from langchain.chains.base import Chain
from langchain.chains.conversational_retrieval.base import ConversationalRetrievalChain
from langchain.chains.flare.base import FlareChain
from langchain.chains.moderation import OpenAIModerationChain
from langchain.chains.sequential import SequentialChain
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import EmbeddingsFilter
from langchain_community.vectorstores.docarray import DocArrayInMemorySearch
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from chapter5.chat_with_retrieval.utils import LOGGER, MEMORY, load_document
set_environment()
LOGGER.info("setup LLM")
# Set up the LLM; a low temperature keeps hallucinations in check, and
# streaming=True lets callback handlers receive tokens as they arrive.
LLM = ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0, streaming=True)
LOGGER.info("configure_retriever")
def configure_retriever(docs: list[Document], use_compression: bool = False) -> BaseRetriever:
"""Retriever to use."""
# Split each document documents:
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=200)
splits = text_splitter.split_documents(docs)
    # Create embeddings and store them in a vector database:
    embeddings = OpenAIEmbeddings()
    # alternatively: HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    # Create the vectordb with a single call to the embedding model:
    vectordb = DocArrayInMemorySearch.from_documents(splits, embeddings)
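    # Maximal marginal relevance (MMR) search fetches the top fetch_k
    # candidates and returns the k most diverse among them, which reduces
    # near-duplicate chunks in the context: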
    retriever = vectordb.as_retriever(
        search_type="mmr",
        search_kwargs={"k": 5, "fetch_k": 7, "include_metadata": True},
    )
    if not use_compression:
        return retriever
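    # Contextual compression drops retrieved chunks whose embedding
    # similarity to the query falls below similarity_threshold: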
    embeddings_filter = EmbeddingsFilter(
        embeddings=embeddings, similarity_threshold=0.2
    )
    return ContextualCompressionRetriever(
        base_compressor=embeddings_filter,
        base_retriever=retriever,
    )


def configure_chain(retriever: BaseRetriever, use_flare: bool = True) -> Chain:
    """Configure the chain with a retriever.

    For ConversationalRetrievalChain, passing a max_tokens_limit
    automatically truncates the retrieved documents when prompting the LLM.
    """
    # FlareChain emits "response"; ConversationalRetrievalChain emits "answer".
    output_key = "response" if use_flare else "answer"
    MEMORY.output_key = output_key
    params = dict(
        llm=LLM,
        retriever=retriever,
        memory=MEMORY,
        verbose=True,
    )
    if use_flare:
        # FlareChain takes a different set of parameters and does not
        # accept max_tokens_limit, so that option is only passed below.
        return FlareChain.from_llm(**params)
    return ConversationalRetrievalChain.from_llm(max_tokens_limit=4000, **params)


def configure_retrieval_chain(
    uploaded_files,
    use_compression: bool = False,
    use_flare: bool = False,
    use_moderation: bool = False,
) -> Chain:
    """Read documents, configure the retriever, and assemble the chain."""
    docs = []
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        # Write each uploaded file to disk so the document loaders can read it:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())
        docs.extend(load_document(temp_filepath))
    retriever = configure_retriever(docs=docs, use_compression=use_compression)
    chain = configure_chain(retriever=retriever, use_flare=use_flare)
    if not use_moderation:
        return chain
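    # Route the chat output through OpenAI's moderation endpoint. The two
    # chain types expose different input and output keys, wired up below: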
    input_variables = ["user_input"] if use_flare else ["chat_history", "question"]
    moderation_input = "response" if use_flare else "answer"
    moderation_chain = OpenAIModerationChain(input_key=moderation_input)
    return SequentialChain(
        chains=[chain, moderation_chain],
        input_variables=input_variables,
    )
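

# A minimal usage sketch, not part of the original module. It assumes the
# chain consumes file-like objects exposing `.name` and `.getvalue()` (as
# Streamlit's UploadedFile does), that MEMORY is a conversation memory keyed
# on "chat_history" (as in the chapter's utils), and that "example.pdf" is
# a hypothetical local file.
if __name__ == "__main__":
    import io

    class UploadedFileStub(io.BytesIO):
        """Hypothetical stand-in for a Streamlit UploadedFile."""

        def __init__(self, path: str) -> None:
            with open(path, "rb") as f:
                super().__init__(f.read())
            self.name = os.path.basename(path)

    chain = configure_retrieval_chain(uploaded_files=[UploadedFileStub("example.pdf")])
    # With use_flare=False, the chain expects a "question" input; the
    # attached memory supplies "chat_history".
    result = chain.invoke({"question": "What is this document about?"})
    print(result["answer"])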