-
Notifications
You must be signed in to change notification settings - Fork 371
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
30 changed files
with
687 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from django.db import models | ||
from django.conf import settings | ||
|
||
|
||
class PromptLog(models.Model): | ||
|
||
class Meta: | ||
app_label = "explorer" | ||
|
||
prompt = models.TextField(blank=True) | ||
response = models.TextField(blank=True) | ||
run_by_user = models.ForeignKey( | ||
settings.AUTH_USER_MODEL, | ||
null=True, | ||
blank=True, | ||
on_delete=models.CASCADE | ||
) | ||
run_at = models.DateTimeField(auto_now_add=True) | ||
duration = models.FloatField(blank=True, null=True) # seconds | ||
model = models.CharField(blank=True, max_length=128, default="") | ||
error = models.TextField(blank=True, null=True) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
primary_prompt = { | ||
"system": """ | ||
You are a data analyst's assistant and will be asked write or modify a SQL query to assist a business | ||
user with their analysis. The user will provide a prompt of what they are looking for help with, and may also | ||
provide SQL they have written so far, relevant table schema, and sample rows from the tables they are querying. | ||
For complex requests, you may use Common Table Expressions (CTEs) to break down the problem into smaller parts. | ||
CTEs are not needed for simpler requests. | ||
""", | ||
"user": "" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from explorer.tests.factories import SimpleQueryFactory | ||
from unittest.mock import patch, Mock | ||
import unittest | ||
|
||
import json | ||
from django.test import TestCase | ||
from django.urls import reverse | ||
from django.contrib.auth.models import User | ||
from explorer.app_settings import EXPLORER_DEFAULT_CONNECTION as CONN | ||
from explorer import app_settings | ||
|
||
|
||
class TestAssistantViews(TestCase): | ||
|
||
def setUp(self): | ||
self.user = User.objects.create_superuser( | ||
"admin", "[email protected]", "pwd" | ||
) | ||
self.client.login(username="admin", password="pwd") | ||
self.request_data = { | ||
"sql": "SELECT * FROM explorer_query", | ||
"connection": CONN, | ||
"assistant_request": "Test Request" | ||
} | ||
|
||
@unittest.skipIf(not app_settings.EXPLORER_AI_API_KEY, "assistant not enabled") | ||
@patch("explorer.assistant.utils.openai_client") | ||
def test_do_modify_query(self, mocked_openai_client): | ||
from explorer.assistant.views import run_assistant | ||
|
||
# create.return_value should match: resp.choices[0].message | ||
mocked_openai_client.return_value.chat.completions.create.return_value = Mock( | ||
choices=[Mock(message=Mock(content="smart computer"))]) | ||
resp = run_assistant(self.request_data, None) | ||
self.assertEqual(resp, "smart computer") | ||
|
||
@unittest.skipIf(not app_settings.EXPLORER_AI_API_KEY, "assistant not enabled") | ||
def test_assistant_help(self): | ||
resp = self.client.post(reverse("assistant"), | ||
data=json.dumps(self.request_data), | ||
content_type="application/json") | ||
self.assertIsNone(json.loads(resp.content)["message"]) | ||
|
||
|
||
class TestPromptContext(TestCase): | ||
|
||
def test_retrieves_sample_rows(self): | ||
from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE | ||
SimpleQueryFactory(title="First Query") | ||
SimpleQueryFactory(title="Second Query") | ||
SimpleQueryFactory(title="Third Query") | ||
SimpleQueryFactory(title="Fourth Query") | ||
ret = sample_rows_from_table(CONN, "explorer_query") | ||
self.assertEqual(len(ret), ROW_SAMPLE_SIZE+1) # includes header row | ||
|
||
def test_format_rows_from_table(self): | ||
from explorer.assistant.utils import format_rows_from_table | ||
d = [ | ||
["col1", "col2"], | ||
["val1", "val2"], | ||
] | ||
ret = format_rows_from_table(d) | ||
self.assertEqual(ret, "col1 | col2\n" + "-" * 50 + "\nval1 | val2\n") | ||
|
||
def test_parsing_tables_from_query(self): | ||
from explorer.assistant.utils import get_table_names_from_query | ||
sql = "SELECT * FROM explorer_query" | ||
ret = get_table_names_from_query(sql) | ||
self.assertEqual(ret, ["explorer_query"]) | ||
|
||
def test_parsing_tables_from_query_bad_sql(self): | ||
from explorer.assistant.utils import get_table_names_from_query | ||
sql = "foo" | ||
ret = get_table_names_from_query(sql) | ||
self.assertEqual(ret, []) | ||
|
||
def test_schema_info_from_table_names(self): | ||
from explorer.assistant.utils import tables_from_schema_info | ||
ret = tables_from_schema_info(CONN, ["explorer_query"]) | ||
expected = [("explorer_query", [ | ||
("id", "AutoField"), | ||
("title", "CharField"), | ||
("sql", "TextField"), | ||
("description", "TextField"), | ||
("created_at", "DateTimeField"), | ||
("last_run_date", "DateTimeField"), | ||
("created_by_user_id", "IntegerField"), | ||
("snapshot", "BooleanField"), | ||
("connection", "CharField")])] | ||
self.assertEqual(ret, expected) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
import logging | ||
from explorer import app_settings | ||
from explorer.schema import schema_info | ||
from explorer.utils import get_valid_connection | ||
from sql_metadata import Parser | ||
from django.db.utils import OperationalError | ||
|
||
if app_settings.EXPLORER_AI_API_KEY: | ||
import tiktoken | ||
from openai import OpenAI | ||
|
||
OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"] | ||
ROW_SAMPLE_SIZE = 2 | ||
|
||
|
||
def openai_client(): | ||
return OpenAI( | ||
api_key=app_settings.EXPLORER_AI_API_KEY, | ||
base_url=app_settings.EXPLORER_ASSISTANT_BASE_URL | ||
) | ||
|
||
|
||
def do_req(prompt): | ||
messages = [ | ||
{"role": "system", "content": prompt["system"]}, | ||
{"role": "user", "content": prompt["user"]}, | ||
] | ||
resp = openai_client().chat.completions.create( | ||
model=OPENAI_MODEL, | ||
messages=messages | ||
) | ||
messages.append(resp.choices[0].message) | ||
logging.info(f"Response: {messages}") | ||
return messages | ||
|
||
|
||
def extract_response(r): | ||
return r[-1].content | ||
|
||
|
||
def tables_from_schema_info(connection, table_names): | ||
schema = schema_info(connection) | ||
return [table for table in schema if table[0] in set(table_names)] | ||
|
||
|
||
def sample_rows_from_tables(connection, table_names): | ||
ret = None | ||
for table_name in table_names: | ||
ret = f"SAMPLE FROM TABLE {table_name}:\n" | ||
ret = ret + format_rows_from_table( | ||
sample_rows_from_table(connection, table_name) | ||
) + "\n\n" | ||
return ret | ||
|
||
|
||
def sample_rows_from_table(connection, table_name): | ||
conn = get_valid_connection(connection) | ||
cursor = conn.cursor() | ||
try: | ||
cursor.execute(f"SELECT * FROM {table_name} LIMIT {ROW_SAMPLE_SIZE}") | ||
ret = [[header[0] for header in cursor.description]] | ||
ret = ret + cursor.fetchall() | ||
return ret | ||
except OperationalError as e: | ||
return [[str(e)]] | ||
|
||
|
||
def format_rows_from_table(rows): | ||
column_headers = list(rows[0]) | ||
ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n" | ||
for row in rows[1:]: | ||
row_str = " | ".join(str(item) for item in row) | ||
ret += row_str + "\n" | ||
return ret | ||
|
||
|
||
def get_table_names_from_query(sql): | ||
if sql: | ||
try: | ||
parsed = Parser(sql) | ||
return parsed.tables | ||
except ValueError: | ||
return [] | ||
return [] | ||
|
||
|
||
def num_tokens_from_string(string: str) -> int: | ||
"""Returns the number of tokens in a text string.""" | ||
encoding = tiktoken.encoding_for_model(OPENAI_MODEL) | ||
num_tokens = len(encoding.encode(string)) | ||
return num_tokens | ||
|
||
|
||
def fits_in_window(string: str) -> bool: | ||
# Ratchet down by 5% to account for other boilerplate and system prompt | ||
# TODO make this better by actually looking at the token count of the system prompt | ||
return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95) |
Oops, something went wrong.