-
 from langchain_openai import OpenAIEmbeddings
 from langchain_ollama import OllamaEmbeddings
 from langchain_aws import BedrockEmbeddings
 
 from langchain_neo4j import Neo4jVector
 
-from langchain.chains import RetrievalQAWithSourcesChain
-from langchain.chains.qa_with_sources import load_qa_with_sources_chain
+from langchain_core.runnables import RunnableParallel, RunnablePassthrough
+from langchain_core.output_parsers import StrOutputParser
 
 from langchain.prompts import (
     ChatPromptTemplate,
     HumanMessagePromptTemplate,
-    SystemMessagePromptTemplate
+    SystemMessagePromptTemplate,
 )
 
 from typing import List, Any
-from utils import BaseLogger, extract_title_and_question
+from utils import BaseLogger, extract_title_and_question, format_docs
 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 
 AWS_MODELS = (
     "mistral.mi",
 )
 
+
 def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config={}):
     if embedding_model_name == "ollama":
         embeddings = OllamaEmbeddings(
@@ -47,10 +47,8 @@ def load_embedding_model(embedding_model_name: str, logger=BaseLogger(), config=
         embeddings = BedrockEmbeddings()
         dimension = 1536
         logger.info("Embedding: Using AWS")
-    elif embedding_model_name == "google-genai-embedding-001":
-        embeddings = GoogleGenerativeAIEmbeddings(
-            model="models/embedding-001"
-        )
+    elif embedding_model_name == "google-genai-embedding-001":
+        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
         dimension = 768
         logger.info("Embedding: Using Google Generative AI Embeddings")
     else:
@@ -112,17 +110,8 @@ def configure_llm_only_chain(llm):
     chat_prompt = ChatPromptTemplate.from_messages(
         [system_message_prompt, human_message_prompt]
     )
-
-    def generate_llm_output(
-        user_input: str, callbacks: List[Any], prompt=chat_prompt
-    ) -> str:
-        chain = prompt | llm
-        answer = chain.invoke(
-            {"question": user_input}, config={"callbacks": callbacks}
-        ).content
-        return {"answer": answer}
-
-    return generate_llm_output
+    chain = chat_prompt | llm | StrOutputParser()
+    return chain
 
 
 def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, password):
@@ -152,12 +141,6 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
     ]
     qa_prompt = ChatPromptTemplate.from_messages(messages)
 
-    qa_chain = load_qa_with_sources_chain(
-        llm,
-        chain_type="stuff",
-        prompt=qa_prompt,
-    )
-
     # Vector + Knowledge Graph response
     kg = Neo4jVector.from_existing_index(
         embedding=embeddings,
@@ -183,12 +166,16 @@ def configure_qa_rag_chain(llm, embeddings, embeddings_store_url, username, pass
         ORDER BY similarity ASC // so that best answers are the last
         """,
     )
-
-    kg_qa = RetrievalQAWithSourcesChain(
-        combine_documents_chain=qa_chain,
-        retriever=kg.as_retriever(search_kwargs={"k": 2}),
-        reduce_k_below_max_tokens=False,
-        max_tokens_limit=3375,
+    kg_qa = (
+        RunnableParallel(
+            {
+                "summaries": kg.as_retriever(search_kwargs={"k": 2}) | format_docs,
+                "question": RunnablePassthrough(),
+            }
+        )
+        | qa_prompt
+        | llm
+        | StrOutputParser()
     )
     return kg_qa
 
0 commit comments