Sql assistant (#569)
* LLM-powered SQL Assistant
chrisclark authored Apr 9, 2024
1 parent 82fc21c commit 13fd67d
Showing 30 changed files with 687 additions and 21 deletions.
2 changes: 2 additions & 0 deletions HISTORY.rst
@@ -7,6 +7,8 @@ This project adheres to `Semantic Versioning <https://semver.org/>`_.

`4.1.0`_ (TBD)
===========================
* SQL Assistant: Built-in query help via OpenAI (or an LLM of your choice), with the relevant schema
  automatically injected into the prompt. Enable it by setting EXPLORER_AI_API_KEY.
* Anonymous usage telemetry. Can be disabled by setting EXPLORER_ENABLE_ANONYMOUS_STATS to False
* `#594`_: Eliminate <script> tags to prevent potential Content Security Policy issues

11 changes: 8 additions & 3 deletions docs/features.rst
@@ -25,13 +25,18 @@ Security

Easy to get started
-------------------
- Built on Django's ORM, so works with Postgresql, Mysql, and
Sqlite. And, between you and me, it works fine on RedShift as
well.
- Built on Django's ORM, so works with MySQL, Postgres, Oracle,
SQLite, Snowflake, MS SQL Server, RedShift, and MariaDB.
- Small number of dependencies.
- Just want to get in and write some ad-hoc queries? Go nuts with
the Playground area.

SQL Assistant
-------------
- Built-in integration with OpenAI (or the LLM of your choosing)
  to quickly get help with your query, with relevant schema
  automatically injected into the prompt (see the example below).
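
For instance, a minimal enabling sketch (assuming ``import os`` at the top of settings.py and an
OpenAI key in the environment; see the install docs for provider and model overrides):

.. code-block:: python

    # settings.py -- minimal sketch; any OpenAI-compatible provider can be used
    # by also overriding EXPLORER_ASSISTANT_BASE_URL
    EXPLORER_AI_API_KEY = os.environ.get("OPENAI_API_KEY")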

Snapshots
---------
- Tick the 'snapshot' box on a query, and Explorer will upload a
32 changes: 32 additions & 0 deletions docs/install.rst
@@ -84,6 +84,38 @@ And run the server:

You can now browse to http://127.0.0.1:8000/explorer/ and get exploring!

Note that Explorer expects STATIC_URL to be set appropriately. This isn't a problem
with vanilla Django setups, but if you are using e.g. Django Storages with S3, you
must set your STATIC_URL to point to your S3 bucket (e.g. s3_bucket_url + '/static/').
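
For example, a rough sketch with django-storages on S3 (the domain value is illustrative, not a
setting Explorer defines; adjust to your own storage backend):

.. code-block:: python

    # settings.py -- illustrative only
    AWS_S3_CUSTOM_DOMAIN = "my-bucket.s3.amazonaws.com"  # hypothetical bucket domain
    STATIC_URL = f"https://{AWS_S3_CUSTOM_DOMAIN}/static/"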

AI SQL Assistant
----------------
To enable AI features, you must install the OpenAI SDK and Tiktoken library from
requirements/optional.txt. By default the Assistant is configured to use OpenAI and
the `gpt-4-0125-preview` model. To use those settings, set an OpenAI API token in
your project's settings.py file:

``EXPLORER_AI_API_KEY = 'your_openai_api_key'``

Or, more likely:

``EXPLORER_AI_API_KEY = os.environ.get("OPENAI_API_KEY")``

If you would prefer to use a different provider and/or different model, you can
also override the AI API URL root and default model. For example, this would configure
the Assistant to use OpenRouter and Mixtral 8x7B Instruct:

.. code-block:: python

    EXPLORER_ASSISTANT_MODEL = {"name": "mistralai/mixtral-8x7b-instruct:nitro",
                                "max_tokens": 32768}
    EXPLORER_ASSISTANT_BASE_URL = "https://openrouter.ai/api/v1"
    EXPLORER_AI_API_KEY = os.environ.get("OPENROUTER_API_KEY")

Other Parameters
----------------

The default behavior when viewing a parameterized query is to autorun the associated
SQL with the default parameter values. This may perform poorly, and you may want to give
your users a chance to review the parameters before running. If so, you may add
6 changes: 3 additions & 3 deletions explorer/__init__.py
@@ -1,8 +1,8 @@
__version_info__ = {
"major": 4,
"minor": 0,
"patch": 2,
"releaselevel": "final",
"minor": 1,
"patch": 0,
"releaselevel": "beta",
"serial": 0
}

9 changes: 9 additions & 0 deletions explorer/app_settings.py
@@ -142,3 +142,12 @@
# - user will need to run by clicking the Save & Run Button to execute
EXPLORER_AUTORUN_QUERY_WITH_PARAMS = getattr(settings, "EXPLORER_AUTORUN_QUERY_WITH_PARAMS", True)
VITE_DEV_MODE = getattr(settings, "VITE_DEV_MODE", False)


# AI Assistant settings. Setting the first to an OpenAI key is the simplest way to enable the assistant
EXPLORER_AI_API_KEY = getattr(settings, "EXPLORER_AI_API_KEY", None)
EXPLORER_ASSISTANT_BASE_URL = getattr(settings, "EXPLORER_ASSISTANT_BASE_URL", "https://api.openai.com/v1")
EXPLORER_ASSISTANT_MODEL = getattr(settings, "EXPLORER_ASSISTANT_MODEL",
                                   # The model name and the max_tokens (context window) it supports
                                   {"name": "gpt-4-0125-preview",
                                    "max_tokens": 128000})
Empty file added explorer/assistant/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions explorer/assistant/models.py
@@ -0,0 +1,21 @@
from django.db import models
from django.conf import settings


class PromptLog(models.Model):

    class Meta:
        app_label = "explorer"

    prompt = models.TextField(blank=True)
    response = models.TextField(blank=True)
    run_by_user = models.ForeignKey(
        settings.AUTH_USER_MODEL,
        null=True,
        blank=True,
        on_delete=models.CASCADE
    )
    run_at = models.DateTimeField(auto_now_add=True)
    duration = models.FloatField(blank=True, null=True)  # seconds
    model = models.CharField(blank=True, max_length=128, default="")
    error = models.TextField(blank=True, null=True)
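
A brief usage sketch (a hypothetical helper, not part of this commit) showing how a PromptLog row
might be written by calling code; field names match the model above:

# Hypothetical helper -- not the PR's actual view code
from explorer.assistant.models import PromptLog


def log_assistant_call(user, prompt_text, response_text, elapsed_seconds, model_name, error=None):
    # Persist one assistant round-trip for later auditing; run_at is set automatically
    return PromptLog.objects.create(
        prompt=prompt_text,
        response=response_text,
        run_by_user=user,
        duration=elapsed_seconds,
        model=model_name,
        error=error,
    )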
11 changes: 11 additions & 0 deletions explorer/assistant/prompts.py
@@ -0,0 +1,11 @@
primary_prompt = {
    "system": """
You are a data analyst's assistant and will be asked to write or modify a SQL query to assist a business
user with their analysis. The user will provide a prompt of what they are looking for help with, and may also
provide SQL they have written so far, relevant table schema, and sample rows from the tables they are querying.
For complex requests, you may use Common Table Expressions (CTEs) to break down the problem into smaller parts.
CTEs are not needed for simpler requests.
""",
    "user": ""
}
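
A rough sketch of how this prompt dict might be consumed, based on the do_req and extract_response
helpers added in explorer/assistant/utils.py further down this diff (ask_assistant itself is
illustrative, not code from this commit):

# Illustrative only; relies on helpers defined later in this commit
from copy import deepcopy

from explorer.assistant.prompts import primary_prompt
from explorer.assistant.utils import do_req, extract_response


def ask_assistant(user_content):
    prompt = deepcopy(primary_prompt)   # don't mutate the module-level dict
    prompt["user"] = user_content       # e.g. the user's request plus schema and sample rows
    messages = do_req(prompt)           # system + user messages sent to the configured LLM
    return extract_response(messages)   # content of the final (assistant) message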
90 changes: 90 additions & 0 deletions explorer/assistant/tests.py
@@ -0,0 +1,90 @@
from explorer.tests.factories import SimpleQueryFactory
from unittest.mock import patch, Mock
import unittest

import json
from django.test import TestCase
from django.urls import reverse
from django.contrib.auth.models import User
from explorer.app_settings import EXPLORER_DEFAULT_CONNECTION as CONN
from explorer import app_settings


class TestAssistantViews(TestCase):

def setUp(self):
self.user = User.objects.create_superuser(
"admin", "[email protected]", "pwd"
)
self.client.login(username="admin", password="pwd")
self.request_data = {
"sql": "SELECT * FROM explorer_query",
"connection": CONN,
"assistant_request": "Test Request"
}

@unittest.skipIf(not app_settings.EXPLORER_AI_API_KEY, "assistant not enabled")
@patch("explorer.assistant.utils.openai_client")
def test_do_modify_query(self, mocked_openai_client):
from explorer.assistant.views import run_assistant

# create.return_value should match: resp.choices[0].message
mocked_openai_client.return_value.chat.completions.create.return_value = Mock(
choices=[Mock(message=Mock(content="smart computer"))])
resp = run_assistant(self.request_data, None)
self.assertEqual(resp, "smart computer")

@unittest.skipIf(not app_settings.EXPLORER_AI_API_KEY, "assistant not enabled")
def test_assistant_help(self):
resp = self.client.post(reverse("assistant"),
data=json.dumps(self.request_data),
content_type="application/json")
self.assertIsNone(json.loads(resp.content)["message"])


class TestPromptContext(TestCase):

def test_retrieves_sample_rows(self):
from explorer.assistant.utils import sample_rows_from_table, ROW_SAMPLE_SIZE
SimpleQueryFactory(title="First Query")
SimpleQueryFactory(title="Second Query")
SimpleQueryFactory(title="Third Query")
SimpleQueryFactory(title="Fourth Query")
ret = sample_rows_from_table(CONN, "explorer_query")
self.assertEqual(len(ret), ROW_SAMPLE_SIZE+1) # includes header row

def test_format_rows_from_table(self):
from explorer.assistant.utils import format_rows_from_table
d = [
["col1", "col2"],
["val1", "val2"],
]
ret = format_rows_from_table(d)
self.assertEqual(ret, "col1 | col2\n" + "-" * 50 + "\nval1 | val2\n")

def test_parsing_tables_from_query(self):
from explorer.assistant.utils import get_table_names_from_query
sql = "SELECT * FROM explorer_query"
ret = get_table_names_from_query(sql)
self.assertEqual(ret, ["explorer_query"])

def test_parsing_tables_from_query_bad_sql(self):
from explorer.assistant.utils import get_table_names_from_query
sql = "foo"
ret = get_table_names_from_query(sql)
self.assertEqual(ret, [])

def test_schema_info_from_table_names(self):
from explorer.assistant.utils import tables_from_schema_info
ret = tables_from_schema_info(CONN, ["explorer_query"])
expected = [("explorer_query", [
("id", "AutoField"),
("title", "CharField"),
("sql", "TextField"),
("description", "TextField"),
("created_at", "DateTimeField"),
("last_run_date", "DateTimeField"),
("created_by_user_id", "IntegerField"),
("snapshot", "BooleanField"),
("connection", "CharField")])]
self.assertEqual(ret, expected)
97 changes: 97 additions & 0 deletions explorer/assistant/utils.py
@@ -0,0 +1,97 @@
import logging
from explorer import app_settings
from explorer.schema import schema_info
from explorer.utils import get_valid_connection
from sql_metadata import Parser
from django.db.utils import OperationalError

if app_settings.EXPLORER_AI_API_KEY:
import tiktoken
from openai import OpenAI

OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"]
ROW_SAMPLE_SIZE = 2


def openai_client():
return OpenAI(
api_key=app_settings.EXPLORER_AI_API_KEY,
base_url=app_settings.EXPLORER_ASSISTANT_BASE_URL
)


def do_req(prompt):
messages = [
{"role": "system", "content": prompt["system"]},
{"role": "user", "content": prompt["user"]},
]
resp = openai_client().chat.completions.create(
model=OPENAI_MODEL,
messages=messages
)
messages.append(resp.choices[0].message)
logging.info(f"Response: {messages}")
return messages


def extract_response(r):
return r[-1].content


def tables_from_schema_info(connection, table_names):
schema = schema_info(connection)
return [table for table in schema if table[0] in set(table_names)]


def sample_rows_from_tables(connection, table_names):
    # Build one formatted sample block per table and concatenate them
    ret = ""
    for table_name in table_names:
        ret += f"SAMPLE FROM TABLE {table_name}:\n"
        ret += format_rows_from_table(
            sample_rows_from_table(connection, table_name)
        ) + "\n\n"
    return ret


def sample_rows_from_table(connection, table_name):
conn = get_valid_connection(connection)
cursor = conn.cursor()
try:
cursor.execute(f"SELECT * FROM {table_name} LIMIT {ROW_SAMPLE_SIZE}")
ret = [[header[0] for header in cursor.description]]
ret = ret + cursor.fetchall()
return ret
except OperationalError as e:
return [[str(e)]]


def format_rows_from_table(rows):
column_headers = list(rows[0])
ret = " | ".join(column_headers) + "\n" + "-" * 50 + "\n"
for row in rows[1:]:
row_str = " | ".join(str(item) for item in row)
ret += row_str + "\n"
return ret


def get_table_names_from_query(sql):
if sql:
try:
parsed = Parser(sql)
return parsed.tables
except ValueError:
return []
return []


def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
encoding = tiktoken.encoding_for_model(OPENAI_MODEL)
num_tokens = len(encoding.encode(string))
return num_tokens


def fits_in_window(string: str) -> bool:
# Ratchet down by 5% to account for other boilerplate and system prompt
# TODO make this better by actually looking at the token count of the system prompt
return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95)
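
Taken together, a hedged sketch of how these helpers could be chained to assemble the schema and
sample-row context for a request (assemble_context is illustrative; the actual wiring lives in the
assistant views, which are not shown here):

# Illustrative only; combines the helpers above
def assemble_context(connection, sql):
    table_names = get_table_names_from_query(sql)               # tables referenced by the user's SQL
    schema = tables_from_schema_info(connection, table_names)   # [(table, [(col, type), ...]), ...]
    samples = sample_rows_from_tables(connection, table_names) or ""
    context = f"SCHEMA: {schema}\n\n{samples}"
    # Drop the bulky sample rows if the combined context would exceed the model's window
    return context if fits_in_window(context) else f"SCHEMA: {schema}"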