-
Notifications
You must be signed in to change notification settings - Fork 117
/
main.py
70 lines (57 loc) · 1.83 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# import libraries
import instructor
import lancedb
from datasets import load_dataset
from lancedb.pydantic import LanceModel, Vector
from langchain_openai import OpenAIEmbeddings
from openai import OpenAI
from pydantic import BaseModel, Field
# Load the Medium articles CSV from the Hugging Face Hub as a pandas DataFrame,
# drop every row that has a missing value, and draw a reproducible random
# sample of 20,000 articles.
df = (
    load_dataset(
        "fabiochiu/medium-articles", data_files="medium_articles.csv", split="train"
    )
    .to_pandas()
    .dropna()
    .sample(20000, random_state=32)
)

# Truncate each article body to its first 1000 characters.
df["text"] = df["text"].str[:1000]

# Single text field per article: "<title>. <truncated body>".
df["title_text"] = df["title"] + ". " + df["text"]
# schema for table
# NOTE: field declaration order is also the column order of the LanceDB table.
class UserData(LanceModel):
    """One row of the ``instructor_lancedb`` table: an article plus LLM-extracted fields."""

    # Embedding of ``content``; the dimension must equal the length of the
    # vectors produced by the embedding model (1536 is the default output size
    # of OpenAI's text-embedding-3-small -- update both together).
    vector: Vector(1536)
    # Headline produced by the LLM for this article.
    headline: str
    # The text that was sent to the LLM and embedded ("<title>. <body prefix>").
    content: str
    # Entity string returned by the LLM.
    entity: str
    # Sentiment string returned by the LLM (free-form; not constrained to a fixed set).
    sentiment: str
    # Whether the LLM classified the article as news (copied from ``structureData.news``).
    news_article: bool
# schema for instructor output
# instructor converts this model to a JSON schema (class docstring and field
# descriptions included) and sends it to the LLM along with the request, then
# validates the reply against it. The user message is just the raw article,
# so these descriptions are the only instructions the model gets about what
# each field should contain.
class structureData(BaseModel):
    """Structured information extracted from a single Medium article."""

    headline: str = Field(
        description="A short, informative headline summarizing the article."
    )
    entity: str = Field(
        description=(
            "The main entity the article is about, such as a person, "
            "organization, product, technology or topic."
        )
    )
    sentiment: str = Field(
        description=(
            "Overall sentiment of the article: one of 'positive', "
            "'negative' or 'neutral'."
        )
    )
    news: bool = Field(
        description=(
            "True if the article is a news report about current events; "
            "False for tutorials, opinion pieces, personal essays and other "
            "non-news content."
        )
    )
# Patch the OpenAI client so chat completions can return validated pydantic
# models via ``response_model``.
client = instructor.from_openai(OpenAI())
openai_embedding = OpenAIEmbeddings(model="text-embedding-3-small")

# connect lancedb; mode="overwrite" replaces any existing table of this name.
db = lancedb.connect("~/.lancedb")
table_name = "instructor_lancedb"
table = db.create_table(table_name, schema=UserData, mode="overwrite")

# How many sampled articles to process. Each article costs one chat-completion
# request and one embedding request.
NUM_ARTICLES = 10


def _build_record(text):
    """Extract structured fields from *text* with the LLM and embed it.

    Args:
        text: The article text ("<title>. <body prefix>") to analyse.

    Returns:
        A ``UserData`` row ready to be written to the LanceDB table.
    """
    structured_info = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=structureData,
        messages=[{"role": "user", "content": text}],
    )
    return UserData(
        vector=openai_embedding.embed_query(text),
        headline=structured_info.headline,
        content=text,
        entity=structured_info.entity,
        sentiment=structured_info.sentiment,
        news_article=structured_info.news,
    )


# Build every row first, then write them with a single ``add``. Each ``add``
# is a separate write that creates a new table version, so per-row writes
# fragment the dataset. Trade-off: if a request fails midway, no rows are
# written (the original per-row version kept the rows added before the failure).
records = [_build_record(text) for text in df["title_text"].head(NUM_ARTICLES)]
if records:  # ``add`` with an empty batch has nothing to write
    table.add(records)

# show table content
print(table.to_pandas())