Commit ae04b9d

Luca Foppiano authored and committed
fix import, and reformat
1 parent 0b28b48 commit ae04b9d

3 files changed: +24 −13 lines changed


document_qa/document_qa_engine.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from langchain.vectorstores import Chroma
 from tqdm import tqdm

-from grobid_processors import GrobidProcessor
+from document_qa.grobid_processors import GrobidProcessor


 class DocumentQAEngine:
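
This import fix is what makes the module importable as part of the package. A minimal sketch of the failure mode, assuming the repository layout implied by the file paths in this commit (the __init__.py is an assumption, not stated in the diff):

    # Assumed layout, inferred from the paths in this commit:
    #
    #   streamlit_app.py
    #   document_qa/
    #       __init__.py
    #       document_qa_engine.py
    #       grobid_processors.py
    #
    # streamlit_app.py runs from the repository root, so only the root is on
    # sys.path. Inside document_qa/document_qa_engine.py the bare import fails:
    #
    #   from grobid_processors import GrobidProcessor
    #   # ModuleNotFoundError: No module named 'grobid_processors'
    #
    # Qualifying it with the package name resolves it from the root instead:
    from document_qa.grobid_processors import GrobidProcessor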

document_qa/grobid_processors.py

Lines changed: 5 additions & 2 deletions
@@ -413,7 +413,8 @@ def __init__(self, grobid_superconductors_client):

     def extract_materials(self, text):
         preprocessed_text = text.strip()
-        status, result = self.grobid_superconductors_client.process_text(preprocessed_text, "processText_disable_linking")
+        status, result = self.grobid_superconductors_client.process_text(preprocessed_text,
+                                                                         "processText_disable_linking")

         if status != 200:
             result = {}
@@ -679,6 +680,7 @@ def parse_xml(self, text):

         return output_data

+
 def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):
     children = []

@@ -697,6 +699,7 @@ def get_children_list_supermat(soup, use_paragraphs=False, verbose=False):

     return children

+
 def get_children_list_grobid(soup: object, use_paragraphs: object = True, verbose: object = False) -> object:
     children = []

@@ -739,4 +742,4 @@ def get_children_figures(soup: object, use_paragraphs: object = True, verbose: o
     if verbose:
         print(str(children))

-    return children
\ No newline at end of file
+    return children

streamlit_app.py

Lines changed: 18 additions & 10 deletions
@@ -42,6 +42,7 @@
 if "messages" not in st.session_state:
     st.session_state.messages = []

+
 def new_file():
     st.session_state['loaded_embeddings'] = None
     st.session_state['doc_id'] = None
@@ -69,6 +70,7 @@ def init_qa(model):

     return DocumentQAEngine(chat, embeddings, grobid_url=os.environ['GROBID_URL'])

+
 @st.cache_resource
 def init_ner():
     quantities_client = QuantitiesAPI(os.environ['GROBID_QUANTITIES_URL'], check_server=True)
@@ -89,14 +91,16 @@ def init_ner():
     materials_client.set_config(config_materials)

     gqa = GrobidAggregationProcessor(None,
-                     grobid_quantities_client=quantities_client,
-                     grobid_superconductors_client=materials_client
-                     )
+                                     grobid_quantities_client=quantities_client,
+                                     grobid_superconductors_client=materials_client
+                                     )

     return gqa

+
 gqa = init_ner()

+
 def get_file_hash(fname):
     hash_md5 = blake2b()
     with open(fname, "rb") as f:
@@ -122,7 +126,7 @@ def play_old_messages():
 is_api_key_provided = st.session_state['api_key']

 model = st.sidebar.radio("Model (cannot be changed after selection or upload)",
-                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),#, "llama-2-70b-chat"),
+                         ("chatgpt-3.5-turbo", "mistral-7b-instruct-v0.1"),  # , "llama-2-70b-chat"),
                          index=1,
                          captions=[
                              "ChatGPT 3.5 Turbo + Ada-002-text (embeddings)",
@@ -134,13 +138,15 @@ def play_old_messages():

 if not st.session_state['api_key']:
     if model == 'mistral-7b-instruct-v0.1' or model == 'llama-2-70b-chat':
-        api_key = st.sidebar.text_input('Huggingface API Key', type="password")# if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
+        api_key = st.sidebar.text_input('Huggingface API Key',
+                                        type="password")  # if 'HUGGINGFACEHUB_API_TOKEN' not in os.environ else os.environ['HUGGINGFACEHUB_API_TOKEN']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ["HUGGINGFACEHUB_API_TOKEN"] = api_key
             st.session_state['rqa'] = init_qa(model)
     elif model == 'chatgpt-3.5-turbo':
-        api_key = st.sidebar.text_input('OpenAI API Key', type="password") #if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
+        api_key = st.sidebar.text_input('OpenAI API Key',
+                                        type="password")  # if 'OPENAI_API_KEY' not in os.environ else os.environ['OPENAI_API_KEY']
         if api_key:
             st.session_state['api_key'] = is_api_key_provided = True
             os.environ['OPENAI_API_KEY'] = api_key
@@ -177,10 +183,12 @@ def play_old_messages():
     st.markdown(
         """After entering your API Key (Open AI or Huggingface). Upload a scientific article as PDF document. You will see a spinner or loading indicator while the processing is in progress. Once the spinner stops, you can proceed to ask your questions.""")

-    st.markdown('**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.', unsafe_allow_html=True)
+    st.markdown(
+        '**NER on LLM responses**: The responses from the LLMs are post-processed to extract <span style="color:orange">physical quantities, measurements</span> and <span style="color:green">materials</span> mentions.',
+        unsafe_allow_html=True)
     if st.session_state['git_rev'] != "unknown":
         st.markdown("**Revision number**: [" + st.session_state[
-            'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")
+                        'git_rev'] + "](https://github.com/lfoppiano/document-qa/commit/" + st.session_state['git_rev'] + ")")

     st.header("Query mode (Advanced use)")
     st.markdown(
@@ -219,11 +227,11 @@ def play_old_messages():
     if mode == "Embeddings":
         with st.spinner("Generating LLM response..."):
             text_response = st.session_state['rqa'].query_storage(question, st.session_state.doc_id,
-                                                      context_size=context_size)
+                                                                  context_size=context_size)
     elif mode == "LLM":
         with st.spinner("Generating response..."):
             _, text_response = st.session_state['rqa'].query_document(question, st.session_state.doc_id,
-                                                           context_size=context_size)
+                                                                       context_size=context_size)

     if not text_response:
         st.error("Something went wrong. Contact Luca Foppiano ([email protected]) to report the issue.")
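
Aside from the import fix, every change in this commit is mechanical reformatting of the same three kinds: two blank lines around top-level definitions, continuation arguments aligned under the opening parenthesis, and two spaces before inline # comments. A minimal before/after sketch of the pattern, with hypothetical names (client.process and helper are illustrative, not taken from the diff):

    # Before: no separation between definitions, unaligned continuation,
    # inline comment flush against the code
    result = client.process(text, "fast")# linking disabled
    def helper():
        return result

    # After: the PEP 8 style applied throughout this commit
    result = client.process(text,
                            "fast")  # linking disabled


    def helper():
        return result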
