Skip to content

Commit

Permalink
basic sql assistant
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Jan 11, 2024
1 parent d5274ad commit 945626a
Show file tree
Hide file tree
Showing 15 changed files with 554 additions and 3,982 deletions.
1 change: 1 addition & 0 deletions explorer/app_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@
# If set to False will not autorun queries containing parameters when viewed
# - user will need to run by clicking the Save & Run Button to execute
EXPLORER_AUTORUN_QUERY_WITH_PARAMS = getattr(settings, "EXPLORER_AUTORUN_QUERY_WITH_PARAMS", True)
EXPLORER_OPENAI_API_KEY = getattr(settings, "OPENAI_API_KEY", None)
Empty file added explorer/assistant/__init__.py
Empty file.
24 changes: 24 additions & 0 deletions explorer/assistant/prompts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
modify_prompt = {
'system': '''
You are a data analyst and have been asked to modify an existing SQL query to assist a business user with their analysis.
The user has provided a description of the analysis they are trying to perform, and the SQL they have written so far.
You will be asked to modify the SQL to complete the analysis.
You will also receive a description of the table structure.
''',
'user': '''
## Existing SQL ##
{sql}
## Modification Request ##
{description}
## Table Structure ##
{table_structure}'''
}

PROMPT_MAP = {
"modify_query": modify_prompt,
"new_query": None,
"debug_query": None,
"schema_help": None
}
60 changes: 60 additions & 0 deletions explorer/assistant/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import logging
from openai import OpenAI
import pandas as pd
from io import StringIO
import sqlparse
from explorer.app_settings import EXPLORER_OPENAI_API_KEY


OPENAI_MODEL = 'gpt-4'
openai_client = OpenAI(
api_key=EXPLORER_OPENAI_API_KEY
)


TABLE_NAME = 'tbl'
TABLE_STRUCT_QUERY = f'PRAGMA table_info({TABLE_NAME})'
TABLE_SAMPLE_QUERY = f'SELECT * FROM {TABLE_NAME} ORDER BY RANDOM() LIMIT 20'


def do_req(prompt, history=None):
if not history:
history = []
messages = history + [
{"role": "system", "content": prompt['system']},
{"role": "user", "content": prompt['user']},
]
print(messages)
resp = openai_client.chat.completions.create(
model=OPENAI_MODEL,
messages=messages
)
messages.append(resp.choices[0].message)
logging.info(f'Response: {messages}')
return messages


def df_to_string(df):
output = StringIO()
df.to_csv(output, index=False)
return output.getvalue()


def q(query, conn):
df = pd.read_sql_query(query, conn)
return df


def format_prompt(p, format_data):
ret = {}
for k, v in p.items():
ret[k] = v.format(**format_data)
return ret


def format_sql(q):
return sqlparse.format(q, reindent=True, keyword_case='upper')


def extract_response(r):
return r[-1].content
54 changes: 54 additions & 0 deletions explorer/assistant/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
from django.http import JsonResponse
import json

from sql_metadata import Parser
from explorer.schema import schema_info
from explorer.assistant.prompts import PROMPT_MAP
from explorer.assistant.utils import format_prompt, do_req, extract_response


def tables_from_schema_info(connection, table_names):
schema = schema_info(connection)
return [table for table in schema if table[0] in table_names]


def do_modify_query(request_data):
sql = request_data['sql']
connection = request_data['connection']
assistant_request = request_data['assistant_request']
parsed = Parser(sql)
table_struct = tables_from_schema_info(connection, parsed.tables)
prompt = format_prompt(PROMPT_MAP['modify_query'],
{
'sql': sql,
'description': assistant_request,
'table_structure': table_struct
})
resp = do_req(prompt)
return extract_response(resp)


def assistant_help(request):
if request.method == 'POST':
try:
data = json.loads(request.body)
assistant_function = data['assistant_function']

if assistant_function == 'modify_query':
resp = do_modify_query(data)
else:
resp = None

response_data = {
'status': 'success',
'message': resp
}

return JsonResponse(response_data)

except json.JSONDecodeError:
return JsonResponse({'status': 'error', 'message': 'Invalid JSON'}, status=400)

else:
return JsonResponse({'status': 'error', 'message': 'Invalid request method'}, status=405)

65 changes: 65 additions & 0 deletions explorer/src/js/assistant.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import axios from "axios";
import {getCsrfToken} from "./csrf";
import { marked } from "marked";
import DOMPurify from "dompurify";

const assistantOptions = {
"modify_query": "Modify or enhance the current query to change its behavior.",
"new_query": "Construct a new SQL query from scratch, given a description of what the query should do.",
"debug_query": "Determine (and try to automatically fix) why the current query is not running.",
"schema_help": "Inquire about the current database schema."
}

const spinner = "<div class=\"spinner-border text-primary\" role=\"status\"><span class=\"visually-hidden\">Loading...</span></div>";

function getCheckedAssistantValue() {
const checkedRadio = document.querySelector('input[name="assistant-option"]:checked');
if (checkedRadio) {
return checkedRadio.id;
}
return null;
}

function updateDivText(value) {
document.getElementById('id_assistant_input_label').textContent = assistantOptions[value];
}

export function setUpAssistant() {

document.querySelectorAll('input[name="assistant-option"]').forEach(function(radio) {
radio.addEventListener('change', function() {
if (this.checked) {
updateDivText(this.id);
}
});
});

updateDivText(getCheckedAssistantValue());

document.getElementById('ask_assistant_btn').addEventListener('click', function() {
const data = {
assistant_function: getCheckedAssistantValue(),
sql: window.editor.state.doc.toString(),
connection: document.getElementById("id_connection").value,
assistant_request: document.getElementById("id_assistant_input").value
};

document.getElementById("response_block").style.display = "block";
document.getElementById("assistant_response").innerHTML = spinner;

axios.post('/assistant/', data, {
headers: {
'X-CSRFToken': getCsrfToken()
}
})
.then(function (response) {
const output = DOMPurify.sanitize(marked.parse(response.data.message));
document.getElementById("response_block").style.display = "block";
document.getElementById("assistant_response").innerHTML = output;
console.log(response.data.message)
})
.catch(function (error) {
console.error('Error:', error);
});
});
}
2 changes: 2 additions & 0 deletions explorer/src/js/explorer.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ export class ExplorerEditor {

this.editor = editorFromTextArea(document.getElementById("id_sql"));

window.editor = this.editor;

document.addEventListener('submitEventFromCM', (e) => {
this.$submit.click();
});
Expand Down
6 changes: 5 additions & 1 deletion explorer/src/js/main.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ import * as bootstrap from 'bootstrap'; // eslint-disable-line no-unused-vars
import {ExplorerEditor} from "./explorer"
import {setupQueryList} from "./query-list"
import {setupSchema} from "./schema";
import {setUpAssistant} from "./assistant";


const route_initializers = {
explorer_index: setupQueryList,
query_detail: () => new ExplorerEditor(queryId),
query_detail: () => {
new ExplorerEditor(queryId);
setUpAssistant();
},
query_create: () => new ExplorerEditor('new'),
explorer_playground: () => new ExplorerEditor('new'),
explorer_schema: setupSchema
Expand Down
5 changes: 5 additions & 0 deletions explorer/src/scss/explorer.css
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,8 @@ div.sort {
.query_favourite_detail {
float: right;
}

#assistant_description_label {
white-space: pre-wrap;

}
42 changes: 42 additions & 0 deletions explorer/templates/explorer/assistant.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
<div class ="accordion accordion-flush my-3" id="assistant_accordion">
<div class="accordion-item">
<div class="accordion-header" id="assistant_accordion_header">
<button class="accordion-button bg-light collapsed" type="button" data-bs-toggle="collapse" data-bs-target="#assistant_collapse" aria-expanded="false" aria-controls="assistant_collapse">
<label for="id_assistant_input">SQL Assistant</label>
</button>
</div>
</div>
<div id="assistant_collapse" class="accordion-collapse collapse" aria-labelledby="assistant_accordion_header" data-bs-parent="#assistant_accordion">
<div class="accordion-body">

<input type="radio" class="btn-check" name="assistant-option" id="modify_query" autocomplete="off" checked>
<label class="btn" for="modify_query">Modify Query</label>

<input type="radio" class="btn-check" name="assistant-option" id="new_query" autocomplete="off">
<label class="btn" for="new_query">New Query</label>

<input type="radio" class="btn-check" name="assistant-option" id="debug_query" autocomplete="off">
<label class="btn" for="debug_query">Debug Query</label>

<input type="radio" class="btn-check" name="assistant-option" id="schema_help" autocomplete="off">
<label class="btn" for="schema_help">Schema Help</label>

<div class="form-floating mt-3">
<textarea
class="form-control" id="id_assistant_input"
name="sql_assistant"></textarea>
<label for="id_assistant_input" class="form-label" id="id_assistant_input_label"></label>
</div>
<div class="mt-3 text-end">
<div class="btn-group" role="group">
<button type="button" class="btn btn-outline-secondary" id="ask_assistant_btn">Ask Assistant</button>
</div>
</div>
<div id="response_block" style="display: none" class="position-relative">
<div class="position-absolute start-50 translate-middle top-0 px-2 bg-white rounded-2">Response</div>
<div class="mt-3 p-2 rounded-2 border bg-light" id="assistant_response"></div>
</div>
</div>
</div>
</div>

1 change: 1 addition & 0 deletions explorer/templates/explorer/query.html
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ <h2>
{% endif %}
</div>
</div>
{% include 'explorer/assistant.html' %}
</form>
</div>
</div>
Expand Down
2 changes: 2 additions & 0 deletions explorer/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
StreamQueryView, format_sql,
)

from assistant.views import assistant_help

urlpatterns = [
path(
Expand Down Expand Up @@ -42,4 +43,5 @@
path('favorites/', QueryFavoritesView.as_view(), name='query_favorites'),
path('favorite/<int:query_id>', QueryFavoriteView.as_view(), name='query_favorite'),
path('', ListQueryView.as_view(), name='explorer_index'),
path('assistant/', assistant_help, name='assistant'),
]
Loading

0 comments on commit 945626a

Please sign in to comment.