diff --git a/app/models.py b/app/models.py index bf27efc..b426cb2 100644 --- a/app/models.py +++ b/app/models.py @@ -1,7 +1,7 @@ from sentence_transformers import SentenceTransformer, CrossEncoder from transformers import pipeline import torch - +import nltk bi_encoder = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1') @@ -9,3 +9,5 @@ cross_encoder_large = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') qa_model = pipeline('question-answering', model='deepset/roberta-base-squad2') + +nltk.download('punkt') diff --git a/app/search_logic.py b/app/search_logic.py index 34673ea..b1595c3 100644 --- a/app/search_logic.py +++ b/app/search_logic.py @@ -25,7 +25,6 @@ BI_ENCODER_CANDIDATES = 60 if torch.cuda.is_available() else 20 SMALL_CROSS_ENCODER_CANDIDATES = 30 if torch.cuda.is_available() else 10 -nltk.download('punkt') logger = logging.getLogger(__name__)