Skip to content

Commit

Permalink
assistant improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Aug 19, 2024
1 parent 25543c3 commit bc2c1c7
Show file tree
Hide file tree
Showing 18 changed files with 297 additions and 101 deletions.
15 changes: 15 additions & 0 deletions explorer/assistant/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.db import models
from django.conf import settings
from explorer.ee.db_connections.models import DatabaseConnection


class PromptLog(models.Model):
Expand All @@ -19,3 +20,17 @@ class Meta:
duration = models.FloatField(blank=True, null=True) # seconds
model = models.CharField(blank=True, max_length=128, default="")
error = models.TextField(blank=True, null=True)


class TableDescription(models.Model):
    # Human-written description of a single table, stored per database
    # connection.

    # Deleting the connection also deletes its table descriptions.
    connection = models.ForeignKey(
        DatabaseConnection,
        on_delete=models.CASCADE,
    )
    table_name = models.CharField(max_length=512)
    description = models.TextField()

    class Meta:
        app_label = "explorer"
        # At most one description per table within a given connection.
        unique_together = ("connection", "table_name")

    def __str__(self):
        # Note: reading connection.alias dereferences the foreign key.
        return f"{self.connection.alias} - {self.table_name}"
14 changes: 14 additions & 0 deletions explorer/assistant/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from django.urls import path
from explorer.assistant.views import (TableDescriptionListView,
TableDescriptionCreateView,
TableDescriptionUpdateView,
TableDescriptionDeleteView,
AssistantHelpView)

# URL patterns for the assistant help page and for listing, creating,
# updating and deleting TableDescription records. The update and delete
# routes take the record's integer primary key.
assistant_urls = [
    path("assistant/", AssistantHelpView.as_view(), name="assistant"),
    path("table-descriptions/", TableDescriptionListView.as_view(), name="table_description_list"),
    path("table-descriptions/new/", TableDescriptionCreateView.as_view(), name="table_description_create"),
    path("table-descriptions/<int:pk>/update/", TableDescriptionUpdateView.as_view(), name="table_description_update"),
    path("table-descriptions/<int:pk>/delete/", TableDescriptionDeleteView.as_view(), name="table_description_delete"),
]
38 changes: 17 additions & 21 deletions explorer/assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

OPENAI_MODEL = app_settings.EXPLORER_ASSISTANT_MODEL["name"]
ROW_SAMPLE_SIZE = 2
MAX_FIELD_SAMPLE_SIZE = 500 # characters
MAX_FIELD_SAMPLE_SIZE = 200 # characters


def openai_client():
Expand Down Expand Up @@ -42,7 +42,7 @@ def tables_from_schema_info(db_connection, table_names):
def sample_rows_from_tables(connection, table_names):
ret = ""
for table_name in table_names:
ret += f"SAMPLE FROM TABLE {table_name}:\n"
ret += f"SAMPLE FROM TABLE '{table_name}':\n"
ret += format_rows_from_table(
sample_rows_from_table(connection, table_name)
) + "\n\n"
Expand Down Expand Up @@ -95,17 +95,6 @@ def format_rows_from_table(rows):
return ret


def get_table_names_from_query(sql):
    """Return the table names referenced by ``sql``.

    Args:
        sql: The SQL text to inspect; may be empty or ``None``.

    Returns:
        The ``tables`` value reported by ``sql_metadata.Parser``, or an
        empty list when ``sql`` is empty/``None`` or the parser raises
        ``ValueError``.
    """
    if not sql:
        # Nothing to parse, so do not import the third-party parser at all.
        return []

    # Imported lazily so the dependency is only loaded when actually needed.
    from sql_metadata import Parser

    try:
        return Parser(sql).tables
    except ValueError:
        # Unparseable SQL is treated as "no tables" rather than an error.
        return []


def num_tokens_from_string(string: str) -> int:
"""Returns the number of tokens in a text string."""
import tiktoken
Expand All @@ -123,30 +112,37 @@ def fits_in_window(string: str) -> bool:
return num_tokens_from_string(string) < (app_settings.EXPLORER_ASSISTANT_MODEL["max_tokens"] * 0.95)


def build_system_prompt(flavor):
    """Return the stored assistant system prompt with a SQL-dialect note appended.

    ``flavor`` is interpolated into the appended sentence as the name of the
    SQL dialect the assistant should target.
    """
    # Base prompt is read from the ASSISTANT_SYSTEM_PROMPT ExplorerValue item.
    bsp = ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value
    # NOTE(review): the string below spans two source lines, so the line break
    # (and any leading whitespace on the continuation line) is part of the prompt.
    bsp += f"""\n\nYou are an expert at writing SQL, specifically for {flavor}, and account for the nuances
of this dialect of SQL."""
    return bsp


def build_prompt(db_connection, assistant_request, included_tables, query_error=None, sql=None):
user_prompt = ""
djc = db_connection.as_django_connection()
user_prompt += f"## Database Vendor / SQL Flavor is {djc.vendor}\n\n"
sp = build_system_prompt(djc.vendor)

user_prompt = f"## Database Type is {djc.vendor}\n\n"

if query_error:
user_prompt += f"## Query Error ##\n\n{query_error}\n\n"

if sql:
user_prompt += f"## Existing SQL ##\n\n{sql}\n\n"

results_sample = sample_rows_from_tables(djc,
included_tables)
results_sample = sample_rows_from_tables(djc, included_tables)
# If it's too large with sampling, then provide *just* the structure
if fits_in_window(user_prompt + results_sample):
user_prompt += f"## Table Structure with Sampled Data ##\n\n{results_sample}\n\n"
else: # If it's too large with sampling, then provide *just* the structure
table_struct = tables_from_schema_info(db_connection,
included_tables)
else:
table_struct = tables_from_schema_info(db_connection, included_tables)
user_prompt += f"## Table Structure ##\n\n{table_struct}\n\n"

user_prompt += f"## User's Request to Assistant ##\n\n{assistant_request}\n\n"

prompt = {
"system": ExplorerValue.objects.get_item(ExplorerValue.ASSISTANT_SYSTEM_PROMPT).value,
"system": sp,
"user": user_prompt
}
return prompt
34 changes: 31 additions & 3 deletions explorer/assistant/views.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
from django.http import JsonResponse
from django.views import View
from django.utils import timezone
from django.views.generic import ListView, CreateView, UpdateView, DeleteView
from django.urls import reverse_lazy
from .models import TableDescription

import json

from explorer.telemetry import Stat, StatNames
from explorer.ee.db_connections.models import DatabaseConnection
from explorer.assistant.models import PromptLog
from explorer.assistant.utils import (
do_req, extract_response,
get_table_names_from_query,
build_prompt
)


def run_assistant(request_data, user):

sql = request_data.get("sql")
extra_tables = request_data.get("selected_tables", [])
included_tables = get_table_names_from_query(sql) + extra_tables
included_tables = request_data.get("selected_tables", [])

connection_id = request_data.get("connection_id")
try:
Expand Down Expand Up @@ -67,3 +69,29 @@ def post(self, request, *args, **kwargs):
return JsonResponse(response_data)
except json.JSONDecodeError:
return JsonResponse({"status": "error", "message": "Invalid JSON"}, status=400)


class TableDescriptionListView(ListView):
    """List TableDescription records.

    The template receives the records as ``table_descriptions``.
    """

    # NOTE(review): no login/permission mixin is visible on this view --
    # confirm that access control is enforced elsewhere.
    context_object_name = "table_descriptions"
    template_name = "assistant/table_description_list.html"
    model = TableDescription


class TableDescriptionCreateView(CreateView):
    """Create a TableDescription, then redirect to the list view."""

    # NOTE(review): no login/permission mixin is visible on this view --
    # confirm that access control is enforced elsewhere.
    success_url = reverse_lazy("table_description_list")
    template_name = "assistant/table_description_form.html"
    model = TableDescription
    fields = ["connection", "table_name", "description"]


class TableDescriptionUpdateView(UpdateView):
    """Edit an existing TableDescription, then redirect to the list view."""

    # NOTE(review): no login/permission mixin is visible on this view --
    # confirm that access control is enforced elsewhere.
    success_url = reverse_lazy("table_description_list")
    template_name = "assistant/table_description_form.html"
    model = TableDescription
    fields = ["connection", "table_name", "description"]


class TableDescriptionDeleteView(DeleteView):
    """Confirm and delete a TableDescription, then redirect to the list view."""

    # NOTE(review): no login/permission mixin is visible on this view --
    # confirm that access control is enforced elsewhere.
    success_url = reverse_lazy("table_description_list")
    template_name = "assistant/table_description_confirm_delete.html"
    model = TableDescription
26 changes: 26 additions & 0 deletions explorer/migrations/0026_tabledescription.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Generated by Django 5.0.4 on 2024-08-19 14:48

import django.db.models.deletion
from django.db import migrations, models


class Migration(migrations.Migration):
    # Schema migration that creates the TableDescription model's table.
    # This file is generated by Django; only comments have been added.

    # Must be applied after the preceding explorer migration.
    dependencies = [
        ('explorer', '0025_remove_query_connection_remove_querylog_connection_and_more'),
    ]

    operations = [
        migrations.CreateModel(
            name='TableDescription',
            fields=[
                ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
                ('table_name', models.CharField(max_length=512)),
                ('description', models.TextField()),
                # Rows are removed together with the connection they belong to.
                ('connection', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='explorer.databaseconnection')),
            ],
            options={
                # One description per (connection, table_name) pair.
                'unique_together': {('connection', 'table_name')},
            },
        ),
    ]
85 changes: 43 additions & 42 deletions explorer/src/js/assistant.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import {getCsrfToken} from "./csrf";
import { marked } from "marked";
import DOMPurify from "dompurify";
import * as bootstrap from 'bootstrap';
import List from "list.js";
import * as bootstrap from "bootstrap";
import { SchemaSvc, getConnElement } from "./schemaService"
import Choices from "choices.js"

function getErrorMessage() {
const errorElement = document.querySelector('.alert-danger.db-error');
Expand All @@ -13,79 +13,80 @@ function getErrorMessage() {
function setupTableList() {
SchemaSvc.get().then(schema => {
const keys = Object.keys(schema);
const selectElement = document.createElement('select');
selectElement.className = 'js-choice';
selectElement.toggleAttribute('multiple');
selectElement.toggleAttribute('data-trigger');

keys.forEach((key) => {
const option = document.createElement('option');
option.value = key;
option.textContent = key;
selectElement.appendChild(option);
});

const tableList = document.getElementById('table-list');
tableList.innerHTML = '';

keys.forEach((key, index) => {
const div = document.createElement('div');
div.className = 'form-check';

const input = document.createElement('input');
input.className = 'form-check-input table-checkbox';
input.type = 'checkbox';
input.value = key;
input.id = 'flexCheckDefault' + index;

const label = document.createElement('label');
label.className = 'form-check-label';
label.setAttribute('for', input.id);
label.textContent = key;

div.appendChild(input);
div.appendChild(label);
tableList.appendChild(div);
tableList.appendChild(selectElement);

const choices = new Choices('.js-choice', {
removeItemButton: true,
searchEnabled: true,
shouldSort: false,
placeholder: true,
placeholderValue: 'Relevant tables',
position: 'bottom'
});

let options = {
valueNames: ['form-check-label'],
};

new List('additional_table_container', options);

const selectAllButton = document.getElementById('select_all_button');
const checkboxes = document.querySelectorAll('.table-checkbox');

let selectState = 'all';

selectAllButton.innerHTML = 'Select All';

selectAllButton.addEventListener('click', (e) => {
e.preventDefault();
const isSelectingAll = selectState === 'all';
checkboxes.forEach((checkbox) => {
checkbox.checked = isSelectingAll;
choices.setChoiceByValue(keys);
});

const deselectAllButton = document.getElementById('deselect_all_button');
deselectAllButton.addEventListener('click', (e) => {
e.preventDefault();
keys.forEach(k => {
choices.removeActiveItemsByValue(k);
});
selectState = isSelectingAll ? 'none' : 'all';
selectAllButton.innerHTML = isSelectingAll ? 'Deselect All' : 'Select All';
});

document.addEventListener('docChanged', (e) => {
const textContent = window.editor.state.doc.toString();
const textWords = new Set(textContent.split(/\s+/));
const hasKeys = keys.filter(key => textWords.has(key));
choices.setChoiceByValue(hasKeys);
});
})
.catch(error => {
console.error('Error retrieving JSON schema:', error);
});
}


export function setUpAssistant(expand = false) {

getConnElement().addEventListener('change', setupTableList);
setupTableList();

const error = getErrorMessage();

if(expand || error) {
if (expand || error) {
const myCollapseElement = document.getElementById('assistant_collapse');
const bsCollapse = new bootstrap.Collapse(myCollapseElement, {
toggle: false
toggle: false
});
bsCollapse.show();
if(error) {
if (error) {
document.getElementById('id_error_help_message').classList.remove('d-none');
}
}

const tooltipTriggerList = document.querySelectorAll('[data-bs-toggle="tooltip"]');
[...tooltipTriggerList].map(tooltipTriggerEl => new bootstrap.Tooltip(tooltipTriggerEl));

document.getElementById('id_assistant_input').addEventListener('keydown', function(event) {
document.getElementById('id_assistant_input').addEventListener('keydown', function (event) {
if ((event.ctrlKey || event.metaKey) && (event.key === 'Enter')) {
event.preventDefault();
submitAssistantAsk();
Expand Down
7 changes: 6 additions & 1 deletion explorer/src/js/codemirror-config.js
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,14 @@ import { Prec } from "@codemirror/state";
import {sql} from "@codemirror/lang-sql";
import { SchemaSvc } from "./schemaService"

let debounceTimeout;

let updateListenerExtension = EditorView.updateListener.of((update) => {
if (update.docChanged) {
document.dispatchEvent(new CustomEvent('docChanged', {}));
clearTimeout(debounceTimeout);
debounceTimeout = setTimeout(() => {
document.dispatchEvent(new CustomEvent('docChanged', {}));
}, 500);
}
});

Expand Down
5 changes: 0 additions & 5 deletions explorer/src/scss/assistant.scss
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@
cursor: pointer;
}

#additional_table_container {
overflow-y: auto;
max-height: 10rem;
}

#assistant_input_parent {
max-height: 120px;
overflow: hidden;
Expand Down
1 change: 1 addition & 0 deletions explorer/src/scss/styles.scss
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ $bootstrap-icons-font-dir: "../../../node_modules/bootstrap-icons/font/fonts";
@import "assistant";

@import "pivot.css";
@import "choices.js/public/assets/styles/choices.css";
14 changes: 14 additions & 0 deletions explorer/templates/assistant/table_description_confirm_delete.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{% extends "explorer/base.html" %}

{% block sql_explorer_content %}
{# Confirmation page for deleting a table description: POSTing the form confirms, Cancel links back to the list. #}
<div class="container mt-5">
<h1>Confirm Delete</h1>
<p>Are you sure you want to delete the table description for "{{ object.table_name }}" in {{ object.connection.alias }}?</p>
<form method="post">
{% csrf_token %}
<button type="submit" class="btn btn-danger">Confirm Delete</button>
<a href="{% url 'table_description_list' %}" class="btn btn-secondary">Cancel</a>
</form>
</div>
{% endblock %}

Loading

0 comments on commit bc2c1c7

Please sign in to comment.