-
Notifications
You must be signed in to change notification settings - Fork 34
/
Copy pathprompts.py
144 lines (117 loc) · 4.7 KB
/
prompts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import duckdb
import pandas as pd
import openai
import ollama
class SqlPrompt:
    """
    Turn natural-language questions about one table into SQL via an LLM.

    The table schema is read with DuckDB and embedded in a ``CREATE TABLE``
    system prompt. The model's reply is cleaned into ``self.query``, which
    ``get_data`` can then execute with DuckDB.
    """

    def __init__(self, table):
        """
        Parameters
        ----------
        table : str
            Table name interpolated into DuckDB SQL and into the prompt;
            presumably the name of a pandas DataFrame variable DuckDB can
            scan — confirm with callers.
        """
        self.table = table
        self.question = None
        self.message = None
        # Generated SQL; stays None until openai_request/ask_ollama has run.
        self.query = None

    def get_table_schema(self):
        """
        Read the table's column names and types with DuckDB ``DESCRIBE``.

        Sets ``self.schema`` to a comma-separated ``"name TYPE"`` string for
        the ``CREATE TABLE`` prompt, and ``self.column_names`` /
        ``self.column_type`` to the corresponding pandas Series.
        """
        tbl_describe = duckdb.sql("DESCRIBE SELECT * FROM " + self.table + ";")
        # .copy() so the column assignment below writes to an independent frame
        # rather than a slice of the DESCRIBE result (SettingWithCopyWarning).
        col_attr = tbl_describe.df()[["column_name", "column_type"]].copy()
        col_attr["column_joint"] = col_attr["column_name"] + " " + col_attr["column_type"]
        # Join directly instead of stripping "[", "]" and "'" from a list repr:
        # that corrupted LIST types such as "VARCHAR[]" and quoted names.
        self.schema = ", ".join(col_attr["column_joint"])
        self.column_names = col_attr["column_name"]
        self.column_type = col_attr["column_type"]

    def set_prompt(self, question):
        """
        Build the system/user chat messages for ``question``.

        Reads the schema on first use. Stores the question in
        ``self.question``, the prompts in ``self.system`` / ``self.user``,
        and the two-message chat list in ``self.message``.
        """
        if not hasattr(self, "schema"):
            self.get_table_schema()
        self.question = question
        system_template = (
            "\n"
            "Given the following SQL table, your job is to write queries "
            "given a user’s request. \n\n"
            "CREATE TABLE {} ({}) \n\n"
        )
        user_template = "Write a SQL query that returns - {}"
        self.system = system_template.format(self.table, self.schema)
        self.user = user_template.format(question)
        self.message = [
            {"role": "system", "content": self.system},
            {"role": "user", "content": self.user},
        ]

    def openai_request(self,
                       openai_api_key,
                       model="gpt-3.5-turbo",
                       temperature=0,
                       max_tokens=256,
                       frequency_penalty=0,
                       presence_penalty=0):
        """
        Send ``self.message`` to OpenAI and store the cleaned SQL in ``self.query``.

        The reply text has column names quoted (``add_quotes``) and code
        fences removed (``remove_code_chunk``). The raw reply is kept in
        ``self.openai_response``.
        """
        # NOTE(review): openai.ChatCompletion was removed in openai>=1.0; this
        # call requires the pre-1.0 SDK — confirm the pinned version.
        openai.api_key = openai_api_key
        self.openai_response = openai.ChatCompletion.create(
            model=model,
            messages=self.message,
            temperature=temperature,
            max_tokens=max_tokens,
            frequency_penalty=frequency_penalty,
            presence_penalty=presence_penalty)
        self.query = add_quotes(query=self.openai_response["choices"][0]["message"]["content"],
                                col_names=self.column_names)
        self.query = remove_code_chunk(self.query)

    def get_data(self):
        """
        Execute ``self.query`` with DuckDB, print the result and keep it in ``self.data``.

        Prints a message and returns None when no prompt or no query exists yet.
        """
        if self.message is None:
            print("The prompt is not defined")
            return
        if self.query is None:
            print("The query is not defined")
            return
        # NOTE: this executes model-generated SQL without validation.
        # Build the relation once; the query used to be run twice
        # (once for printing, once for storing).
        self.data = duckdb.sql(self.query)
        print(self.data)

    def ask_question(self,
                     question,
                     openai_api_key,
                     model="gpt-3.5-turbo",
                     temperature=0,
                     max_tokens=256,
                     frequency_penalty=0,
                     presence_penalty=0):
        """
        Build the prompt for ``question`` and request SQL from OpenAI.

        All generation settings are forwarded to ``openai_request``. The query
        is not executed; call ``get_data`` for that.
        """
        self.set_prompt(question)
        # Forward the caller's settings; they were previously replaced by
        # hard-coded defaults.
        self.openai_request(openai_api_key=openai_api_key,
                            model=model,
                            temperature=temperature,
                            max_tokens=max_tokens,
                            frequency_penalty=frequency_penalty,
                            presence_penalty=presence_penalty)

    def ask_ollama(self,
                   question,
                   model):
        """
        Build the prompt for ``question`` and request SQL from an Ollama model.

        The raw reply is kept in ``self.ollama_response``. The intermediate
        text (quoted columns, prose trimmed to the ```sql block) is kept in
        ``self.markdown``, and the fence-free SQL in ``self.query``.
        """
        self.set_prompt(question)
        # Single request; the reply used to be requested twice.
        self.ollama_response = ollama.chat(model=model, messages=self.message)
        self.markdown = add_quotes(query=self.ollama_response['message']['content'],
                                   col_names=self.column_names)
        self.markdown = remove_text(query=self.markdown)
        self.query = remove_code_chunk(query=self.markdown)
def remove_text(query):
    """
    Keep only the first ```sql fenced block of an LLM reply, dropping surrounding prose.

    Returns the slice from the opening ```sql line up to and including the
    next ``` fence after it. If there is no ```sql opener, or no closing
    fence after it, ``query`` is returned unchanged.
    """
    opener = "```sql\n"
    start = query.find(opener)
    if start == -1:
        return query
    # Look for the closing fence only *after* the opener: searching from the
    # beginning could hit an earlier plain fence and give an empty slice.
    # A closing fence with no trailing newline (end of reply) is accepted.
    end = query.find("```", start + len(opener))
    if end == -1:
        return query
    return query[start:end + 3]
def remove_code_chunk(query):
    """
    Strip markdown code-fence markers from an LLM reply.

    Every "```sql" opener line is removed first, then every remaining
    "```". Text that contains neither marker is returned as-is.
    """
    for fence in ("```sql\n", "```"):
        if fence in query:
            query = str(query).replace(fence, "")
    return query
def add_quotes(query, col_names):
    """
    Wrap known column names in double quotes inside a SQL string.

    Only occurrences with a space on both sides (" name ") are rewritten to
    ' "name" '. Names next to commas, parentheses, newlines or the ends of
    the string are left untouched.
    """
    for name in col_names:
        bare = " " + name + " "
        if bare not in query:
            continue
        quoted = ' "' + name + '" '
        query = str(query).replace(bare, quoted)
    return query