|
6 | 6 | import dotenv |
7 | 7 | from grobid_quantities.quantities import QuantitiesAPI |
8 | 8 | from langchain.memory import ConversationBufferMemory |
9 | | -from langchain_huggingface import HuggingFaceEmbeddings |
| 9 | +from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings |
10 | 10 | from langchain_openai import ChatOpenAI |
11 | 11 | from streamlit_pdf_viewer import pdf_viewer |
12 | 12 |
|
|
23 | 23 | } |
24 | 24 |
|
25 | 25 | API_EMBEDDINGS = { |
26 | | - 'intfloat/e5-large-v2': 'intfloat/e5-large-v2', |
27 | | - 'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct:', |
28 | | - 'Salesforce/SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R' |
| 26 | + 'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct' |
29 | 27 | } |
30 | 28 |
|
31 | 29 | if 'rqa' not in st.session_state: |
@@ -135,8 +133,9 @@ def init_qa(model_name, embeddings_name): |
135 | 133 | api_key=os.environ.get('API_KEY') |
136 | 134 | ) |
137 | 135 |
|
138 | | - embeddings = HuggingFaceEmbeddings( |
139 | | - model_name=API_EMBEDDINGS[embeddings_name]) |
| 136 | + embeddings = HuggingFaceEndpointEmbeddings( |
| 137 | + repo_id=API_EMBEDDINGS[embeddings_name] |
| 138 | + ) |
140 | 139 |
|
141 | 140 | storage = DataStorage(embeddings) |
142 | 141 | return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory']) |
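
For context, a minimal sketch of how the `HuggingFaceEndpointEmbeddings` class swapped in by this hunk can be exercised on its own. The model id comes from the `API_EMBEDDINGS` dict above; the `HF_TOKEN` environment variable is an assumption for illustration (the diff only shows `API_KEY` being used for the chat model), so treat the token handling as hypothetical rather than this app's actual configuration:

```python
import os

from langchain_huggingface import HuggingFaceEndpointEmbeddings

# Remote embeddings served by the Hugging Face Inference API instead of a
# locally loaded sentence-transformers model.
embeddings = HuggingFaceEndpointEmbeddings(
    repo_id="intfloat/multilingual-e5-large-instruct",
    huggingfacehub_api_token=os.environ.get("HF_TOKEN"),  # assumed env var name
)

# Standard LangChain Embeddings interface: embed_query for a single string,
# embed_documents for a batch of text chunks.
query_vector = embeddings.embed_query("What is the melting point of copper?")
doc_vectors = embeddings.embed_documents(["Chunk one of a PDF.", "Chunk two of a PDF."])
print(len(query_vector), len(doc_vectors))
```

Because the class implements the same `Embeddings` interface as `HuggingFaceEmbeddings`, it can be handed to `DataStorage` unchanged; only the construction site in `init_qa` needs to differ.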
@@ -320,7 +319,8 @@ def play_old_messages(container): |
320 | 319 | st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings( |
321 | 320 | tmp_file.name, |
322 | 321 | chunk_size=chunk_size, |
323 | | - perc_overlap=0.1) |
| 322 | + perc_overlap=0.1 |
| 323 | + ) |
324 | 324 | st.session_state['loaded_embeddings'] = True |
325 | 325 | st.session_state.messages = [] |
326 | 326 |