forked from HappyGO2023/simple-chatpdf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathembedding.py
77 lines (70 loc) · 2.56 KB
/
embedding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain_community.embeddings import OllamaEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
import PyPDF2
from langchain.docstore.document import Document
import re
# os.environ["OPENAI_API_KEY"] = "{your-api-key}"
# Load a PDF file and extract its text
def load_pdf(pdf_file):
    """Extract and concatenate the text of every page in a PDF.

    Args:
        pdf_file: Path to the PDF file to read.

    Returns:
        A single string with the extracted text of all pages, in order.
    """
    # Bug fix: the original ignored `pdf_file` and opened a hard-coded
    # path, and never closed the handle. Use the parameter and a context
    # manager so the file is always closed.
    text = ''
    with open(pdf_file, 'rb') as fh:
        pdf_reader = PyPDF2.PdfReader(fh)
        for page in pdf_reader.pages:
            # extract_text() can return None for image-only pages;
            # treat that as an empty string instead of raising TypeError.
            text += page.extract_text() or ''
    return text
# Custom sentence-aware chunking that never truncates a sentence mid-way
def split_paragraph(text, pdf_name, max_length=300):
    """Split raw PDF text into Documents of at most `max_length` characters,
    cutting only on sentence boundaries so no sentence is truncated.

    Args:
        text: The raw text extracted from the PDF.
        pdf_name: Source file name, stored in each Document's metadata.
        max_length: Soft cap on paragraph length in characters (default 300).
            A single sentence longer than this still becomes its own chunk.

    Returns:
        A list of langchain Documents, one per packed paragraph, each with
        metadata {"source": pdf_name}.
    """
    # Normalize whitespace: drop newlines, then collapse space runs.
    # (The original also replaced '\n\n' AFTER all '\n' were already gone —
    # that was dead code and has been removed.)
    text = text.replace('\n', '')
    text = re.sub(r'\s+', ' ', text)
    # Split on Chinese/Western sentence delimiters; the capture group keeps
    # each delimiter in the result so it can be re-attached to its sentence.
    parts = re.split(r'([;。!!.??])', text)
    # re.split with a captured group yields [chunk, delim, chunk, delim, ...];
    # glue each delimiter back onto the sentence that precedes it.
    sentences = []
    for i in range(len(parts) // 2):
        sentences.append(parts[2 * i] + parts[2 * i + 1])
    if len(parts) % 2 == 1:
        # Trailing text with no closing delimiter.
        sentences.append(parts[-1])
    # Greedily pack whole sentences into paragraphs up to max_length.
    paragraphs = []
    current_length = 0
    current_paragraph = ""
    for sentence in sentences:
        sentence_length = len(sentence)
        if current_length + sentence_length <= max_length:
            current_paragraph += sentence
            current_length += sentence_length
        else:
            paragraphs.append(current_paragraph.strip())
            current_paragraph = sentence
            current_length = sentence_length
    paragraphs.append(current_paragraph.strip())
    # Bug fix: build a fresh metadata dict per Document. The original shared
    # ONE dict object across every Document, so mutating one would have
    # silently changed them all.
    return [
        Document(page_content=paragraph, metadata={"source": pdf_name})
        for paragraph in paragraphs
    ]
# Persist the embedding vectors to disk
def persist_embedding(documents, persist_directory='db'):
    """Embed the documents and persist the vector store to local disk.

    Args:
        documents: Iterable of langchain Documents to embed.
        persist_directory: Directory for the on-disk Chroma store.
            Defaults to 'db' (the original hard-coded location), so
            existing callers are unaffected.
    """
    # embedding = OpenAIEmbeddings()  # alternative backend; needs OPENAI_API_KEY
    embedding = OllamaEmbeddings()
    vectordb = Chroma.from_documents(
        documents=documents,
        embedding=embedding,
        persist_directory=persist_directory,
    )
    vectordb.persist()
    # Drop the reference so the store is flushed/closed promptly.
    vectordb = None
if __name__ == "__main__":
    # Embed the bundled report PDF and persist its vectors to local disk.
    source_pdf = "KOS:2023中国市场招聘趋势.pdf"
    raw_text = load_pdf(source_pdf)
    docs = split_paragraph(raw_text, source_pdf)
    persist_embedding(docs)