Skip to content

Commit 4108477

Browse files
committed
Switch back to the Hugging Face API embeddings and remove the unusable embedding models
1 parent 135de1f commit 4108477

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

document_qa/document_qa_engine.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -423,7 +423,7 @@ def create_memory_embeddings(
423423
if doc_id:
424424
hash = doc_id
425425
else:
426-
hash = metadata[0]['hash']
426+
hash = metadata[0]['hash'] if len(metadata) > 0 and 'hash' in metadata[0] else ""
427427

428428
self.data_storage.embed_document(hash, texts, metadata)
429429

streamlit_app.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,7 @@
66
import dotenv
77
from grobid_quantities.quantities import QuantitiesAPI
88
from langchain.memory import ConversationBufferMemory
9-
from langchain_huggingface import HuggingFaceEmbeddings
9+
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpointEmbeddings
1010
from langchain_openai import ChatOpenAI
1111
from streamlit_pdf_viewer import pdf_viewer
1212

@@ -23,9 +23,7 @@
2323
}
2424

2525
API_EMBEDDINGS = {
26-
'intfloat/e5-large-v2': 'intfloat/e5-large-v2',
27-
'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct:',
28-
'Salesforce/SFR-Embedding-2_R': 'Salesforce/SFR-Embedding-2_R'
26+
'intfloat/multilingual-e5-large-instruct': 'intfloat/multilingual-e5-large-instruct'
2927
}
3028

3129
if 'rqa' not in st.session_state:
@@ -135,8 +133,9 @@ def init_qa(model_name, embeddings_name):
135133
api_key=os.environ.get('API_KEY')
136134
)
137135

138-
embeddings = HuggingFaceEmbeddings(
139-
model_name=API_EMBEDDINGS[embeddings_name])
136+
embeddings = HuggingFaceEndpointEmbeddings(
137+
repo_id=API_EMBEDDINGS[embeddings_name]
138+
)
140139

141140
storage = DataStorage(embeddings)
142141
return DocumentQAEngine(chat, storage, grobid_url=os.environ['GROBID_URL'], memory=st.session_state['memory'])
@@ -320,7 +319,8 @@ def play_old_messages(container):
320319
st.session_state['doc_id'] = hash = st.session_state['rqa'][model].create_memory_embeddings(
321320
tmp_file.name,
322321
chunk_size=chunk_size,
323-
perc_overlap=0.1)
322+
perc_overlap=0.1
323+
)
324324
st.session_state['loaded_embeddings'] = True
325325
st.session_state.messages = []
326326

0 commit comments

Comments
 (0)